/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Ctable;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.ReblockBuffer;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.LongLongDoubleHashMap;
import scala.Tuple2;

/**
 * Spark instruction for the {@code ctable} and {@code ctableexpand} opcodes.
 *
 * <p>The concrete variant is chosen from the data types of the three inputs (see
 * {@link Ctable.OperationTypes}). The blocked matrix inputs are joined by block index.
 * Every partition then builds a partial ctable result via {@link CTableFunction}, and the
 * partial output blocks are finally summed by block index.
 */
public class CtableSPInstruction
extends ComputationSPInstruction {
    /** Output row count: a numeric literal or the name of a scalar variable. */
    private final String _outDim1;
    /** Output column count: a numeric literal or the name of a scalar variable. */
    private final String _outDim2;
    /** True if {@code _outDim1} holds a literal rather than a variable name. */
    private final boolean _dim1Literal;
    /** True if {@code _outDim2} holds a literal rather than a variable name. */
    private final boolean _dim2Literal;
    /** True for the {@code ctableexpand} opcode. */
    private final boolean _isExpand;
    /** Forwarded to the per-block ctable kernel of the scalar-weight variants. */
    private final boolean _ignoreZeros;
    /** True if empty output blocks are explicitly added to the output RDD. */
    private final boolean _outputEmptyBlocks;

    private CtableSPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros, boolean outputEmptyBlocks, String opcode, String istr) {
        super(SPInstruction.SPType.Ctable, null, in1, in2, in3, out, opcode, istr);
        this._outDim1 = outputDim1;
        this._dim1Literal = dim1Literal;
        this._outDim2 = outputDim2;
        this._dim2Literal = dim2Literal;
        this._isExpand = isExpand;
        this._ignoreZeros = ignoreZeros;
        this._outputEmptyBlocks = outputEmptyBlocks;
    }

    /**
     * Parses a serialized ctable instruction.
     *
     * <p>Fields after the opcode, in order: {@code in1, in2, in3, dim1, dim2, out,
     * ignoreZeros, outputEmptyBlocks}. Each dimension field has the form
     * {@code value·isLiteral}.
     *
     * @param inst the serialized instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is neither {@code ctable} nor {@code ctableexpand}
     */
    public static CtableSPInstruction parseInstruction(String inst) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
        InstructionUtils.checkNumFields(parts, 8);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ctable") && !opcode.equalsIgnoreCase("ctableexpand")) {
            throw new DMLRuntimeException("Unexpected opcode in TertiarySPInstruction: " + inst);
        }
        boolean isExpand = opcode.equalsIgnoreCase("ctableexpand");
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        // dimension fields carry "<value or variable name>·<isLiteral>"
        String[] dim1Fields = parts[4].split("\u00b7");
        String[] dim2Fields = parts[5].split("\u00b7");
        CPOperand out = new CPOperand(parts[6]);
        boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
        boolean outputEmptyBlocks = Boolean.parseBoolean(parts[8]);
        return new CtableSPInstruction(in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, outputEmptyBlocks, opcode, inst);
    }

    /**
     * Executes the ctable operation and registers the resulting RDD as the output variable.
     *
     * @param ec the execution context, which must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the output dimensions cannot be determined or the
     *         ctable variant is not supported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;

        // the variant follows from the input data types; ctableexpand forces the expand variant
        Ctable.OperationTypes ctableOp = Ctable.findCtableOperationByInputDataTypes(this.input1.getDataType(), this.input2.getDataType(), this.input3.getDataType());
        if (this._isExpand) {
            ctableOp = Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT;
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = ctableOp.hasSecondInput() ? sec.getBinaryMatrixBlockRDDHandleForVariable(this.input2.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = null;
        double s2 = -1.0;
        double s3 = -1.0;
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());

        long dim1 = resolveOutputDim(sec, this._outDim1, this._dim1Literal);
        long dim2 = resolveOutputDim(sec, this._outDim2, this._dim2Literal);
        if (dim1 == -1L && dim2 == -1L) {
            // Infer the output size before any cells are built.
            // Rows come from the max value of in1. Columns come from the max value of in2.
            // For the histogram variants (no second matrix input), every cell is emitted at the
            // constant column index given by the scalar input2. input3 is the weight there,
            // and a matrix in the weighted histogram case, so it must not be read as a scalar.
            dim1 = (long)RDDAggregateUtils.max(in1);
            dim2 = ctableOp.hasSecondInput() ? (long)RDDAggregateUtils.max(in2) : sec.getScalarInput(this.input2).getLongValue();
        }
        mcOut.set(dim1, dim2, mc1.getBlocksize());
        mcOut.setNonZerosBound(mc1.getLength());
        mcOut.setNoEmptyBlocks(!this._outputEmptyBlocks);
        if (!mcOut.dimsKnown()) {
            throw new DMLRuntimeException("Unknown ctable output dimensions: " + mcOut);
        }

        // use at least 4x the input partitioning, or the preferred output partitioning if larger
        int numParts = Math.max(4 * (mc1.dimsKnown() ? SparkUtils.getNumPreferredPartitions(mc1) : in1.getNumPartitions()), SparkUtils.getNumPreferredPartitions(mcOut, this._outputEmptyBlocks));
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
        switch (ctableOp) {
            case CTABLE_TRANSFORM: {
                in3 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input3.getName());
                out = in1.join(in2, numParts).join(in3, numParts).mapValues(new MapJoinSignature3()).mapPartitionsToPair(new CTableFunction(ctableOp, s2, s3, this._ignoreZeros, mcOut));
                break;
            }
            case CTABLE_EXPAND_SCALAR_WEIGHT:
            case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
                s3 = sec.getScalarInput(this.input3).getDoubleValue();
                out = in1.join(in2, numParts).mapValues(new MapJoinSignature2()).mapPartitionsToPair(new CTableFunction(ctableOp, s2, s3, this._ignoreZeros, mcOut));
                break;
            }
            case CTABLE_TRANSFORM_HISTOGRAM: {
                s2 = sec.getScalarInput(this.input2).getDoubleValue();
                s3 = sec.getScalarInput(this.input3).getDoubleValue();
                out = in1.mapValues(new MapJoinSignature1()).mapPartitionsToPair(new CTableFunction(ctableOp, s2, s3, this._ignoreZeros, mcOut));
                break;
            }
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
                in3 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input3.getName());
                s2 = sec.getScalarInput(this.input2).getDoubleValue();
                out = in1.join(in3, numParts).mapValues(new MapJoinSignature2()).mapPartitionsToPair(new CTableFunction(ctableOp, s2, s3, this._ignoreZeros, mcOut));
                break;
            }
            default: {
                throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
            }
        }

        // optionally add explicit empty blocks, then merge partial results per block index
        if (this._outputEmptyBlocks) {
            out = out.union(SparkUtils.getEmptyBlockRDD(sec.getSparkContext(), mcOut));
        }
        out = RDDAggregateUtils.sumByKeyStable(out, numParts, false);

        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        if (ctableOp.hasSecondInput()) {
            sec.addLineageRDD(this.output.getName(), this.input2.getName());
        }
        if (ctableOp.hasThirdInput()) {
            sec.addLineageRDD(this.output.getName(), this.input3.getName());
        }
        SparkUtils.postprocessUltraSparseOutput(sec.getMatrixObject(this.output), mcOut);
    }

    /**
     * Resolves an output dimension that is either a numeric literal or a scalar variable.
     *
     * @param sec     the execution context used to look up variables
     * @param dim     the literal value or variable name
     * @param literal whether {@code dim} is a literal
     * @return the dimension as a long (the literal is parsed as a double and truncated)
     */
    private static long resolveOutputDim(SparkExecutionContext sec, String dim, boolean literal) {
        return literal ? (long)Double.parseDouble(dim) : sec.getScalarInput(dim, Types.ValueType.FP64, false).getLongValue();
    }

    /** Returns the output row dimension as an FP64 scalar operand. */
    public CPOperand getOutDim1() {
        return new CPOperand(this._outDim1, Types.ValueType.FP64, Types.DataType.SCALAR, this._dim1Literal);
    }

    /** Returns the output column dimension as an FP64 scalar operand. */
    public CPOperand getOutDim2() {
        return new CPOperand(this._outDim2, Types.ValueType.FP64, Types.DataType.SCALAR, this._dim2Literal);
    }

    /** Returns whether this instruction was parsed from the {@code ctableexpand} opcode. */
    public boolean getIsExpand() {
        return this._isExpand;
    }

    /** Returns the ignore-zeros flag forwarded to the ctable kernel. */
    public boolean getIgnoreZeros() {
        return this._ignoreZeros;
    }

    /** Flattens the nested result of a three-way join into an array {@code [in1, in2, in3]}. */
    public static class MapJoinSignature3
    implements Function<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock[]> {
        private static final long serialVersionUID = -5222678882354280164L;

        @Override
        public MatrixBlock[] call(Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> v1) throws Exception {
            return new MatrixBlock[]{v1._1()._1(), v1._1()._2(), v1._2()};
        }
    }

    /** Converts the result of a two-way join into an array {@code [left, right]}. */
    public static class MapJoinSignature2
    implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock[]> {
        private static final long serialVersionUID = 7690448020081435520L;

        @Override
        public MatrixBlock[] call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception {
            return new MatrixBlock[]{v1._1(), v1._2()};
        }
    }

    /** Wraps a single block into a one-element array. */
    public static class MapJoinSignature1
    implements Function<MatrixBlock, MatrixBlock[]> {
        private static final long serialVersionUID = -8819908424033945028L;

        @Override
        public MatrixBlock[] call(MatrixBlock v1) throws Exception {
            return new MatrixBlock[]{v1};
        }
    }

    /**
     * Builds a partial ctable result for one partition.
     *
     * <p>Cells are first collected in a {@link CTableMap}. Cells whose row or column index
     * exceeds the output dimensions are dropped. The remaining cells are reblocked into
     * output blocks of the configured block size.
     */
    private static class CTableFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 5348127596473232337L;
        private final Ctable.OperationTypes _ctableOp;
        private final double _scalar_input2;
        private final double _scalar_input3;
        private final boolean _ignoreZeros;
        private final long _dim1;
        private final long _dim2;
        private final int _blen;

        public CTableFunction(Ctable.OperationTypes ctableOp, double s2, double s3, boolean ignoreZeros, DataCharacteristics mcOut) {
            this(ctableOp, s2, s3, ignoreZeros, false, mcOut);
        }

        /**
         * @param ctableOp    the ctable variant to evaluate per block
         * @param s2          scalar input2 (used by the histogram variants)
         * @param s3          scalar input3 (used by the scalar-weight and histogram variants)
         * @param ignoreZeros forwarded to the scalar-weight kernel
         * @param emitEmpty   currently unused; kept for signature compatibility
         * @param mcOut       output characteristics providing dimensions and block size
         */
        public CTableFunction(Ctable.OperationTypes ctableOp, double s2, double s3, boolean ignoreZeros, boolean emitEmpty, DataCharacteristics mcOut) {
            this._ctableOp = ctableOp;
            this._scalar_input2 = s2;
            this._scalar_input3 = s3;
            this._ignoreZeros = ignoreZeros;
            this._dim1 = mcOut.getRows();
            this._dim2 = mcOut.getCols();
            this._blen = mcOut.getBlocksize();
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg0) throws Exception {
            CTableMap map = new CTableMap();
            // no dense result block is used; all cells go into the map
            MatrixBlock block = null;
            while (arg0.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg0.next();
                MatrixIndexes ix = tmp._1();
                MatrixBlock[] mb = tmp._2();
                switch (this._ctableOp) {
                    case CTABLE_TRANSFORM: {
                        OperationsOnMatrixValues.performCtable(ix, mb[0], ix, mb[1], ix, mb[2], map, block, null);
                        break;
                    }
                    case CTABLE_EXPAND_SCALAR_WEIGHT:
                    case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
                        mb[0].ctableOperations(null, mb[1], this._scalar_input3, this._ignoreZeros, map, block);
                        break;
                    }
                    case CTABLE_TRANSFORM_HISTOGRAM: {
                        OperationsOnMatrixValues.performCtable(ix, mb[0], this._scalar_input2, this._scalar_input3, map, block, null);
                        break;
                    }
                    case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
                        // the cast selects the (MatrixValue, scalar, MatrixValue) overload
                        OperationsOnMatrixValues.performCtable(ix, (MatrixValue)mb[0], this._scalar_input2, ix, mb[1], map, block, null);
                        break;
                    }
                    default:
                        break;
                }
            }

            // reblock the collected cells into output blocks; the buffer is capped at 4M cells
            ReblockBuffer rbuff = new ReblockBuffer(Math.min(0x400000, map.size()), this._dim1, this._dim2, this._blen);
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes, MatrixBlock>>();
            Iterator<LongLongDoubleHashMap.ADoubleEntry> iter = map.getIterator();
            while (iter.hasNext()) {
                LongLongDoubleHashMap.ADoubleEntry e = iter.next();
                // drop cells outside the requested output dimensions
                if (e.getKey1() > this._dim1 || e.getKey2() > this._dim2) {
                    continue;
                }
                if (rbuff.getSize() >= rbuff.getCapacity()) {
                    this.flushBufferToList(rbuff, ret);
                }
                rbuff.appendCell(e.getKey1(), e.getKey2(), e.value);
            }
            if (rbuff.getSize() > 0) {
                this.flushBufferToList(rbuff, ret);
            }
            return ret.iterator();
        }

        /** Converts the buffered cells into indexed blocks and appends them to {@code ret}. */
        protected void flushBufferToList(ReblockBuffer rbuff, ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret) throws DMLRuntimeException {
            rbuff.flushBufferToBinaryBlocks().stream()
                .map(b -> SparkUtils.fromIndexedMatrixBlock(b))
                .forEach(ret::add);
        }
    }
}

