/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.execution.operator.process.ai;

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
import org.apache.iotdb.commons.client.ainode.AINodeClient;
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException;
import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
import org.apache.iotdb.db.queryengine.execution.operator.Operator;
import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext;
import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator;
import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter;
import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter;
import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType;
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.read.common.block.TsBlock;
import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.read.common.block.column.TimeColumnBuilder;
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
import org.apache.tsfile.utils.RamUsageEstimator;

public class InferenceOperator
implements ProcessOperator {
    private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceOperator.class);
    private final OperatorContext operatorContext;
    private final Operator child;
    private final ModelInferenceDescriptor modelInferenceDescriptor;
    private final TsBlockBuilder inputTsBlockBuilder;
    private final ExecutorService modelInferenceExecutor;
    private ListenableFuture<TInferenceResp> inferenceExecutionFuture;
    private boolean finished = false;
    private final long maxRetainedSize;
    private final long maxReturnSize;
    private final int[] columnIndexes;
    private long totalRow;
    private int resultIndex = 0;
    private List<ByteBuffer> results;
    private final TsBlockSerde serde = new TsBlockSerde();
    private InferenceWindowType windowType = null;
    private final boolean generateTimeColumn;
    private long maxTimestamp;
    private long minTimestamp;
    private long interval;
    private long currentRowIndex;

    public InferenceOperator(OperatorContext operatorContext, Operator child, ModelInferenceDescriptor modelInferenceDescriptor, ExecutorService modelInferenceExecutor, List<String> targetColumnNames, List<String> inputColumnNames, boolean generateTimeColumn, long maxRetainedSize, long maxReturnSize) {
        this.operatorContext = operatorContext;
        this.child = child;
        this.modelInferenceDescriptor = modelInferenceDescriptor;
        this.inputTsBlockBuilder = new TsBlockBuilder(Arrays.asList(modelInferenceDescriptor.getModelInformation().getInputDataType()));
        this.modelInferenceExecutor = modelInferenceExecutor;
        this.columnIndexes = new int[inputColumnNames.size()];
        for (int i = 0; i < inputColumnNames.size(); ++i) {
            this.columnIndexes[i] = targetColumnNames.indexOf(inputColumnNames.get(i));
        }
        this.maxRetainedSize = maxRetainedSize;
        this.maxReturnSize = maxReturnSize;
        this.totalRow = 0L;
        if (modelInferenceDescriptor.getInferenceWindowParameter() != null) {
            this.windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType();
        }
        if (generateTimeColumn) {
            this.interval = 0L;
            this.minTimestamp = Long.MAX_VALUE;
            this.maxTimestamp = Long.MIN_VALUE;
            this.currentRowIndex = 0L;
        }
        this.generateTimeColumn = generateTimeColumn;
    }

    @Override
    public OperatorContext getOperatorContext() {
        return this.operatorContext;
    }

    @Override
    public ListenableFuture<?> isBlocked() {
        ListenableFuture<?> childBlocked = this.child.isBlocked();
        boolean executionDone = this.forecastExecutionDone();
        if (executionDone && childBlocked.isDone()) {
            return NOT_BLOCKED;
        }
        if (childBlocked.isDone()) {
            return this.inferenceExecutionFuture;
        }
        if (executionDone) {
            return childBlocked;
        }
        return Futures.successfulAsList(Arrays.asList(this.inferenceExecutionFuture, childBlocked));
    }

    private boolean forecastExecutionDone() {
        if (this.inferenceExecutionFuture == null) {
            return true;
        }
        return this.inferenceExecutionFuture.isDone();
    }

    @Override
    public boolean hasNext() throws Exception {
        return !this.finished || this.results != null && this.results.size() != this.resultIndex;
    }

    private void fillTimeColumn(TsBlock tsBlock) {
        Column timeColumn = tsBlock.getTimeColumn();
        long[] time = timeColumn.getLongs();
        for (int i = 0; i < time.length; ++i) {
            time[i] = this.maxTimestamp + this.interval * this.currentRowIndex;
            ++this.currentRowIndex;
        }
    }

    @Override
    public TsBlock next() throws Exception {
        if (this.inferenceExecutionFuture == null) {
            if (this.child.hasNextWithTimer()) {
                TsBlock inputTsBlock = this.child.nextWithTimer();
                if (inputTsBlock != null) {
                    this.appendTsBlockToBuilder(inputTsBlock);
                }
            } else {
                this.submitInferenceTask();
            }
            return null;
        }
        if (this.results != null && this.resultIndex != this.results.size()) {
            TsBlock tsBlock = this.serde.deserialize(this.results.get(this.resultIndex));
            if (this.generateTimeColumn) {
                this.fillTimeColumn(tsBlock);
            }
            ++this.resultIndex;
            return tsBlock;
        }
        try {
            if (!this.inferenceExecutionFuture.isDone()) {
                throw new IllegalStateException("The operator cannot continue until the forecast execution is done.");
            }
            TInferenceResp inferenceResp = (TInferenceResp)this.inferenceExecutionFuture.get();
            if (inferenceResp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
                String message = String.format("Error occurred while executing inference:[%s]", inferenceResp.getStatus().getMessage());
                throw new ModelInferenceProcessException(message);
            }
            this.finished = true;
            TsBlock resultTsBlock = this.serde.deserialize((ByteBuffer)inferenceResp.inferenceResult.get(0));
            if (this.generateTimeColumn) {
                this.fillTimeColumn(resultTsBlock);
            }
            this.results = inferenceResp.inferenceResult;
            ++this.resultIndex;
            return resultTsBlock;
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ModelInferenceProcessException(e.getMessage());
        }
        catch (ExecutionException e) {
            throw new ModelInferenceProcessException(e.getMessage());
        }
    }

    private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
        TimeColumnBuilder timeColumnBuilder = this.inputTsBlockBuilder.getTimeColumnBuilder();
        ColumnBuilder[] columnBuilders = this.inputTsBlockBuilder.getValueColumnBuilders();
        this.totalRow += (long)inputTsBlock.getPositionCount();
        for (int i = 0; i < inputTsBlock.getPositionCount(); ++i) {
            long timestamp = inputTsBlock.getTimeByIndex(i);
            if (this.generateTimeColumn) {
                this.minTimestamp = Math.min(this.minTimestamp, timestamp);
                this.maxTimestamp = Math.max(this.maxTimestamp, timestamp);
            }
            timeColumnBuilder.writeLong(timestamp);
            for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); ++columnIndex) {
                columnBuilders[this.columnIndexes[columnIndex]].write(inputTsBlock.getColumn(columnIndex), i);
            }
            this.inputTsBlockBuilder.declarePosition();
        }
    }

    private TWindowParams getWindowParams() {
        TWindowParams windowParams;
        if (this.windowType == null) {
            return null;
        }
        if (this.windowType == InferenceWindowType.COUNT) {
            CountInferenceWindowParameter countInferenceWindowParameter = (CountInferenceWindowParameter)this.modelInferenceDescriptor.getInferenceWindowParameter();
            windowParams = new TWindowParams();
            windowParams.setWindowInterval((int)countInferenceWindowParameter.getInterval());
            windowParams.setWindowStep((int)countInferenceWindowParameter.getStep());
        } else {
            windowParams = null;
        }
        return windowParams;
    }

    private TsBlock preProcess(TsBlock inputTsBlock) {
        boolean notBuiltIn = false;
        if (this.windowType == null || this.windowType == InferenceWindowType.HEAD) {
            if (notBuiltIn && this.totalRow != (long)this.modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {
                throw new ModelInferenceProcessException(String.format("The number of rows %s in the input data does not match the model input %s. Try to use LIMIT in SQL or WINDOW in CALL INFERENCE", this.totalRow, this.modelInferenceDescriptor.getModelInformation().getInputShape()[0]));
            }
            return inputTsBlock;
        }
        if (this.windowType == InferenceWindowType.COUNT) {
            if (notBuiltIn && this.totalRow < (long)this.modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {
                throw new ModelInferenceProcessException(String.format("The number of rows %s in the input data is less than the model input %s. ", this.totalRow, this.modelInferenceDescriptor.getModelInformation().getInputShape()[0]));
            }
        } else if (this.windowType == InferenceWindowType.TAIL) {
            if (notBuiltIn && this.totalRow < (long)this.modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {
                throw new ModelInferenceProcessException(String.format("The number of rows %s in the input data is less than the model input %s. ", this.totalRow, this.modelInferenceDescriptor.getModelInformation().getInputShape()[0]));
            }
            long windowSize = (int)((BottomInferenceWindowParameter)this.modelInferenceDescriptor.getInferenceWindowParameter()).getWindowSize();
            return inputTsBlock.subTsBlock((int)(this.totalRow - windowSize));
        }
        return inputTsBlock;
    }

    private void submitInferenceTask() {
        if (this.generateTimeColumn) {
            this.interval = (this.maxTimestamp - this.minTimestamp) / this.totalRow;
        }
        TsBlock inputTsBlock = this.inputTsBlockBuilder.build();
        TsBlock finalInputTsBlock = this.preProcess(inputTsBlock);
        TWindowParams windowParams = this.getWindowParams();
        this.inferenceExecutionFuture = Futures.submit(() -> {
            TInferenceResp tInferenceResp;
            block8: {
                AINodeClient client = (AINodeClient)AINodeClientManager.getInstance().borrowClient((Object)this.modelInferenceDescriptor.getTargetAINode());
                try {
                    tInferenceResp = client.inference(this.modelInferenceDescriptor.getModelName(), finalInputTsBlock, this.modelInferenceDescriptor.getInferenceAttributes(), windowParams);
                    if (client == null) break block8;
                }
                catch (Throwable throwable) {
                    try {
                        if (client != null) {
                            try {
                                client.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (Exception e) {
                        throw new ModelInferenceProcessException(e.getMessage());
                    }
                }
                client.close();
            }
            return tInferenceResp;
        }, (Executor)this.modelInferenceExecutor);
    }

    @Override
    public boolean isFinished() throws Exception {
        return this.finished && !this.hasNext();
    }

    @Override
    public void close() throws Exception {
        if (this.inferenceExecutionFuture != null) {
            this.inferenceExecutionFuture.cancel(true);
        }
        this.child.close();
    }

    @Override
    public long calculateMaxPeekMemory() {
        return this.maxReturnSize + this.maxRetainedSize + this.child.calculateMaxPeekMemory();
    }

    @Override
    public long calculateMaxReturnSize() {
        return this.maxReturnSize;
    }

    @Override
    public long calculateRetainedSizeAfterCallingNext() {
        return this.maxRetainedSize + this.child.calculateRetainedSizeAfterCallingNext();
    }

    public long ramBytesUsed() {
        return INSTANCE_SIZE + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(this.child) + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(this.operatorContext) + this.inputTsBlockBuilder.getRetainedSizeInBytes() + (long)this.columnIndexes.length * 4L;
    }
}

