/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.transform.nlpmodel.llm.remote;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.format.json.RowToJsonConverters;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;

/**
 * Base class for {@link Model} implementations that serialize a batch of rows into a JSON array
 * and pass it, together with a rule-augmented prompt, to
 * {@link #chatWithModel(String, String)}.
 *
 * <p>When projection columns are configured, only those columns (in the configured order) are
 * serialized; otherwise each row is serialized using the full row type.
 */
public abstract class AbstractModel implements Model {
    /** Shared JSON mapper used to build and serialize the request payload. */
    protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    private final RowToJsonConverters.RowToJsonConverter rowToJsonConverter;
    private final SeaTunnelRowType rowType;
    private final String prompt;
    private final SqlType outputType;
    private final List<String> projectionColumns;

    /**
     * Creates the model and builds the row-to-JSON converter.
     *
     * <p>NOTE(review): this constructor invokes {@link #getRowToJsonConverter()}, which is public
     * and non-final. A subclass override therefore runs before the subclass's own fields are
     * initialized; overrides must not rely on subclass state.
     *
     * @param rowType type of the rows passed to {@link #inference(List)}
     * @param outputType element type announced to the model in the prompt rules
     * @param projectionColumns names of the columns to send; {@code null} or empty sends all columns
     * @param prompt user prompt that precedes the generated rules
     * @throws IllegalArgumentException if a projection column is not found in {@code rowType}
     */
    public AbstractModel(SeaTunnelRowType rowType, SqlType outputType, List<String> projectionColumns, String prompt) {
        this.rowType = rowType;
        this.prompt = prompt;
        this.outputType = outputType;
        this.projectionColumns = projectionColumns;
        // Must stay last: the converter is derived from rowType and projectionColumns.
        this.rowToJsonConverter = this.getRowToJsonConverter();
    }

    /**
     * Builds a converter for the projected row type when projection columns are configured, or
     * for the full row type otherwise.
     *
     * @return a new row-to-JSON converter
     * @throws IllegalArgumentException if a projection column is not found in the row type
     */
    public RowToJsonConverters.RowToJsonConverter getRowToJsonConverter() {
        RowToJsonConverters converters = new RowToJsonConverters();
        // Declared as SeaTunnelDataType<?> so the same createConverter overload is selected
        // regardless of which row type is used.
        SeaTunnelDataType<?> converterType = this.hasProjection() ? this.buildProjectionRowType() : this.rowType;
        return converters.createConverter(converterType, null);
    }

    /** Returns {@code true} when at least one projection column is configured. */
    private boolean hasProjection() {
        return this.projectionColumns != null && !this.projectionColumns.isEmpty();
    }

    /** Builds a row type holding only the projection columns, in their configured order. */
    private SeaTunnelRowType buildProjectionRowType() {
        List<SeaTunnelDataType<?>> fieldTypes = new ArrayList<>(this.projectionColumns.size());
        for (String fieldName : this.projectionColumns) {
            int fieldIndex = this.rowType.indexOf(fieldName);
            if (fieldIndex == -1) {
                throw new IllegalArgumentException("Field name " + fieldName + " does not exist in the row type.");
            }
            fieldTypes.add(this.rowType.getFieldType(fieldIndex));
        }
        return new SeaTunnelRowType(this.projectionColumns.toArray(new String[0]), fieldTypes.toArray(new SeaTunnelDataType<?>[0]));
    }

    /** Appends the fixed output-format rules (array in, array out, element type) to the prompt. */
    private String getPromptWithLimit() {
        return this.prompt + "\n The following rules need to be followed: \n 1. The received data is an array, and the result is returned in the form of an array.\n 2. Only the result needs to be returned, and no other information can be returned.\n 3. The element type of the array is " + this.outputType.toString() + ".\n Eg: [\"value1\", \"value2\"]";
    }

    /**
     * Serializes the given rows into a JSON array (one object per row) and sends it to the model
     * together with the rule-augmented prompt.
     *
     * @param rows rows to send to the model
     * @return whatever {@link #chatWithModel(String, String)} returns for this batch
     * @throws IOException if serialization or the model call fails
     */
    @Override
    public List<String> inference(List<SeaTunnelRow> rows) throws IOException {
        ArrayNode rowsNode = OBJECT_MAPPER.createArrayNode();
        for (SeaTunnelRow row : rows) {
            ObjectNode rowNode = OBJECT_MAPPER.createObjectNode();
            this.rowToJsonConverter.convert(OBJECT_MAPPER, rowNode, this.createProjectionSeaTunnelRow(row));
            rowsNode.add(rowNode);
        }
        return this.chatWithModel(this.getPromptWithLimit(), OBJECT_MAPPER.writeValueAsString(rowsNode));
    }

    /**
     * Copies the projection columns of {@code row} into a new row, in projection order.
     *
     * @param row source row; may be {@code null}
     * @return {@code row} itself when it is {@code null} or no projection is configured,
     *     otherwise a new row containing only the projected fields
     * @throws IllegalArgumentException if a projection column is not found in the row type
     */
    @VisibleForTesting
    public SeaTunnelRow createProjectionSeaTunnelRow(SeaTunnelRow row) {
        if (row == null || !this.hasProjection()) {
            return row;
        }
        int projectedSize = this.projectionColumns.size();
        SeaTunnelRow projectionRow = new SeaTunnelRow(projectedSize);
        for (int i = 0; i < projectedSize; ++i) {
            String fieldName = this.projectionColumns.get(i);
            int fieldIndex = this.rowType.indexOf(fieldName);
            if (fieldIndex == -1) {
                throw new IllegalArgumentException("Field name " + fieldName + " does not exist in the row type.");
            }
            projectionRow.setField(i, row.getField(fieldIndex));
        }
        return projectionRow;
    }

    /**
     * Sends the prompt and the serialized rows to the concrete model.
     *
     * @param promptWithLimit user prompt followed by the generated output-format rules
     * @param rowsJson JSON array string with one object per input row
     * @return the model's results as strings
     * @throws IOException if the model call fails
     */
    protected abstract List<String> chatWithModel(String promptWithLimit, String rowsJson) throws IOException;

    /**
     * Normalizes a single model result: lower-cases it when the output type is
     * {@link SqlType#BOOLEAN}, otherwise returns it unchanged.
     *
     * @param data raw value; may be {@code null}
     * @return the normalized value, or {@code null} when {@code data} is {@code null}
     */
    protected String convertData(String data) {
        if (data == null) {
            // Same result as the non-BOOLEAN path, instead of an NPE only for BOOLEAN.
            return null;
        }
        // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
        return this.outputType == SqlType.BOOLEAN ? data.toLowerCase(Locale.ROOT) : data;
    }
}

