/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInList;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInRange;
import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDSortUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for the matrix reorganization opcodes transpose, rev, roll, diag and
 * sort. Each opcode is executed as a transformation over the blocked input RDD; the result
 * is registered as the RDD handle of the output variable together with its lineage.
 */
public class ReorgSPInstruction
extends UnarySPInstruction {
    private static final Log LOG = LogFactory.getLog(ReorgSPInstruction.class.getName());

    // sort-only operands: order-by column(s), descending flag, index-return flag
    private CPOperand _col = null;
    private CPOperand _desc = null;
    private CPOperand _ixret = null;
    // sort-only: when true, RDDSortUtils.sortDataByValMemSort is used instead of sortDataByVal
    private boolean _bSortIndInMem = false;
    // roll-only: scalar operand holding the shift amount
    private CPOperand _shift = null;

    /** Creates a reorg instruction without extra operands (transpose, rev, diag). */
    private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.Reorg, op, in, out, opcode, istr);
    }

    /** Creates a sort instruction with its order-by, direction and index-return operands. */
    private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand desc, CPOperand ixret, CPOperand out, String opcode, boolean bSortIndInMem, String istr) {
        this(op, in, out, opcode, istr);
        this._col = col;
        this._desc = desc;
        this._ixret = ixret;
        this._bSortIndInMem = bSortIndInMem;
    }

    /** Creates a roll instruction; the shift operand is also passed as second input. */
    private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
        super(SPInstruction.SPType.Reorg, op, in, shift, null, out, opcode, istr);
        this._shift = shift;
    }

    /**
     * Parses a reorg instruction string.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not a supported reorg opcode
     */
    public static ReorgSPInstruction parseInstruction(String str) {
        CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        String opcode = InstructionUtils.getOpCode(str);
        if (opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase(Opcodes.REV.toString())) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(str, 3);
            in.split(parts[1]);
            out.split(parts[3]);
            CPOperand shift = new CPOperand(parts[2]);
            return new ReorgSPInstruction(new ReorgOperator(new RollIndex(0)), in, out, shift, opcode, str);
        }
        if (opcode.equalsIgnoreCase(Opcodes.DIAG.toString())) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase(Opcodes.SORT.toString())) {
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(parts, 5, 6);
            in.split(parts[1]);
            out.split(parts[5]);
            CPOperand col = new CPOperand(parts[2]);
            CPOperand desc = new CPOperand(parts[3]);
            CPOperand ixret = new CPOperand(parts[4]);
            // The in-memory sort flag is the optional 6th field. parts[0] holds the opcode,
            // so the flag only exists when parts has 7 entries; the 5-field form must not
            // read parts[6] (that would throw ArrayIndexOutOfBoundsException).
            boolean bSortIndInMem = parts.length > 6 && Boolean.parseBoolean(parts[6]);
            return new ReorgSPInstruction(new ReorgOperator(new SortIndex(1, false, false)), in, col, desc, ixret, out, opcode, bSortIndInMem, str);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a ReorgInstruction: " + str);
    }

    /**
     * Builds the output RDD for this instruction's opcode, updates the output data
     * characteristics and registers the output RDD handle and lineage.
     *
     * @param ec the execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the opcode is not supported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        String opcode = getOpcode();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;

        if (opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) {
            out = in1.mapToPair(new ReorgMapFunction(opcode));
        } else if (opcode.equalsIgnoreCase(Opcodes.REV.toString())) {
            out = in1.flatMapToPair(new RDDRevFunction(mcIn));
            // Partial blocks emitted by RDDRevFunction are merged only when the row count
            // is not a multiple of the blocksize (presumably the only case in which
            // reversed rows do not align with block boundaries).
            if (mcIn.getRows() % mcIn.getBlocksize() != 0) {
                out = RDDAggregateUtils.mergeByKey(out, false);
            }
        } else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
            int shift = (int) sec.getScalarInput(_shift).getLongValue();
            out = in1.flatMapToPair(new RDDRollFunction(mcIn, shift));
            out = RDDAggregateUtils.mergeByKey(out, false);
        } else if (opcode.equalsIgnoreCase(Opcodes.DIAG.toString())) {
            if (mcIn.getCols() == 1) {
                // vector -> diagonal matrix
                out = in1.flatMapToPair(new RDDDiagV2MFunction(mcIn));
            } else {
                // matrix -> vector: only blocks selected by FilterDiagMatrixBlocksFunction contribute
                out = in1.filter(new FilterDiagMatrixBlocksFunction()).mapToPair(new ReorgMapFunction(opcode));
            }
        } else if (opcode.equalsIgnoreCase(Opcodes.SORT.toString())) {
            out = processSort(sec, in1, mcIn);
        } else {
            throw new DMLRuntimeException("Error: Incorrect opcode in ReorgSPInstruction:" + opcode);
        }

        updateReorgDataCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
    }

    /**
     * Builds the output RDD of the sort opcode.
     * A single-column input or a single order-by column uses the single-value sort
     * utilities; multiple order-by columns use the multi-value variants.
     */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> processSort(SparkExecutionContext sec,
            JavaPairRDD<MatrixIndexes, MatrixBlock> in1, DataCharacteristics mcIn) {
        long[] cols = readSortColumns(sec);
        boolean desc = sec.getScalarInput(_desc).getBooleanValue();
        boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
        long rlen = mcIn.getRows();
        long clen = mcIn.getCols();
        int blen = mcIn.getBlocksize();
        boolean singleCol = clen == 1;

        if (cols.length > blen) {
            LOG.warn("Unsupported sort with number of order-by columns large than blocksize: " + cols.length);
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> out = in1;
        if (singleCol || cols.length == 1) {
            if (!singleCol) {
                // keep only the column blocks holding the order-by column and slice that column out
                out = out.filter(new IsBlockInRange(1L, rlen, cols[0], cols[0], mcIn))
                    .mapValues(new ExtractColumn(UtilFunctions.computeCellInBlock(cols[0], blen)));
            }
            if (ixret) {
                return RDDSortUtils.sortIndexesByVal(out, !desc, rlen, blen);
            }
            if (singleCol && !desc) {
                return RDDSortUtils.sortByVal(out, rlen, blen);
            }
            if (_bSortIndInMem) {
                return RDDSortUtils.sortDataByValMemSort(out, in1, !desc, rlen, clen, blen, sec, (ReorgOperator) _optr);
            }
            return RDDSortUtils.sortDataByVal(out, in1, !desc, rlen, clen, blen);
        }

        // multiple order-by columns: project them (if not all columns are used) ...
        if (cols.length < clen) {
            out = out.filter(new IsBlockInList(cols, mcIn)).mapToPair(new ExtractColumns(cols, mcIn));
        }
        // ... and combine partial blocks of the same row block when the input spans several column blocks
        if (clen > blen) {
            out = RDDAggregateUtils.mergeByKey(out);
        }
        if (ixret) {
            return RDDSortUtils.sortIndexesByVals(out, !desc, rlen, cols.length, blen);
        }
        if (cols.length == clen && !desc) {
            return RDDSortUtils.sortByVals(out, rlen, cols.length, blen);
        }
        return RDDSortUtils.sortDataByVals(out, in1, !desc, rlen, clen, cols.length, blen);
    }

    /**
     * Reads the order-by column indexes, either from a scalar or from a matrix operand.
     * A matrix operand is released right after conversion (also on failure) so that it
     * does not stay acquired if a later step throws.
     */
    private long[] readSortColumns(SparkExecutionContext sec) {
        if (!_col.getDataType().isMatrix()) {
            return new long[] {sec.getScalarInput(_col).getLongValue()};
        }
        MatrixBlock colBlock = sec.getMatrixInput(_col.getName());
        try {
            return DataConverter.convertToLongVector(colBlock);
        } finally {
            sec.releaseMatrixInput(_col.getName());
        }
    }

    /**
     * Infers unknown output dimensions and non-zero count from the input characteristics.
     * Dimensions are set via the 3-argument {@code set(rows, cols, blen)} so the non-zero
     * count is left untouched and inferred separately below; the 4-argument overload takes
     * nnz as its last argument and must not receive the blocksize.
     *
     * @throws DMLRuntimeException if both output and input dimensions are unknown
     */
    private void updateReorgDataCharacteristics(SparkExecutionContext sec) {
        DataCharacteristics mc1 = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        String opcode = getOpcode();
        boolean isSort = opcode.equalsIgnoreCase(Opcodes.SORT.toString());

        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown()) {
                throw new DMLRuntimeException("Unable to compute output matrix characteristics from input.");
            }
            int blen = mc1.getBlocksize();
            if (opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) {
                mcOut.set(mc1.getCols(), mc1.getRows(), blen);
            } else if (opcode.equalsIgnoreCase(Opcodes.DIAG.toString())) {
                mcOut.set(mc1.getRows(), mc1.getCols() > 1L ? 1L : mc1.getRows(), blen);
            } else if (isSort) {
                boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
                mcOut.set(mc1.getRows(), ixret ? 1L : mc1.getCols(), blen);
            } else {
                // e.g., rev and roll keep the input shape
                mcOut.set(mc1);
            }
        }
        if (!mcOut.nnzKnown() && mc1.nnzKnown()) {
            // sort with index return produces one (non-zero) row index per input row
            boolean sortIx = isSort && sec.getScalarInput(_ixret).getBooleanValue();
            mcOut.setNonZeros(sortIx ? mc1.getRows() : mc1.getNonZeros());
        }
    }

    /** @return the index-return operand of a sort instruction, or null for other opcodes */
    public CPOperand getIxRet() {
        return this._ixret;
    }

    /**
     * Projects the order-by columns stored in one input block into a block of
     * {@code in.rows x cols.length}; order-by columns not stored in this block stay zero.
     * All outputs of a row block share key (rowIndex, 1) so that partial blocks coming
     * from different column blocks can be combined by key afterwards.
     */
    private static class ExtractColumns
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2902729186431711506L;
        private final long[] _cols;
        private final int _blen;

        public ExtractColumns(long[] cols, DataCharacteristics mc) {
            this._cols = cols;
            this._blen = mc.getBlocksize();
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            MatrixIndexes ix = arg0._1();
            MatrixBlock in = arg0._2();
            int lastRow = in.getNumRows() - 1;
            MatrixBlock out = new MatrixBlock(in.getNumRows(), _cols.length, true);
            for (int i = 0; i < _cols.length; ++i) {
                // skip order-by columns that live in a different column block
                if (!UtilFunctions.isInBlockRange(ix, _blen, new IndexRange(1L, Long.MAX_VALUE, _cols[i], _cols[i]))) {
                    continue;
                }
                int index = UtilFunctions.computeCellInBlock(_cols[i], _blen);
                MatrixBlock column = in.slice(0, lastRow, index, index, new MatrixBlock());
                out.leftIndexingOperations(column, 0, lastRow, i, i, out, MatrixObject.UpdateType.INPLACE);
            }
            return new Tuple2<>(new MatrixIndexes(ix.getRowIndex(), 1L), out);
        }
    }

    /** Slices a single column (block-local index) out of a matrix block. */
    private static class ExtractColumn
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -1472164797288449559L;
        private final int _col;

        public ExtractColumn(int col) {
            this._col = col;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            return arg0.slice(0, arg0.getNumRows() - 1, _col, _col, new MatrixBlock());
        }
    }

    /** Applies LibMatrixReorg.roll to one input block; may emit several output blocks. */
    private static class RDDRollFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1183373828539843938L;
        private final DataCharacteristics _mcIn;
        private final int _shift;

        public RDDRollFunction(DataCharacteristics mcIn, int shift) {
            this._mcIn = mcIn;
            this._shift = shift;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);
            ArrayList<IndexedMatrixValue> out = new ArrayList<>();
            LibMatrixReorg.roll(in, _mcIn.getRows(), _mcIn.getBlocksize(), _shift, out);
            return SparkUtils.fromIndexedMatrixBlock(out).iterator();
        }
    }

    /** Applies LibMatrixReorg.rev to one input block; may emit several output blocks. */
    private static class RDDRevFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1183373828539843938L;
        private final DataCharacteristics _mcIn;

        public RDDRevFunction(DataCharacteristics mcIn) {
            this._mcIn = mcIn;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);
            ArrayList<IndexedMatrixValue> out = new ArrayList<>();
            LibMatrixReorg.rev(in, _mcIn.getRows(), _mcIn.getBlocksize(), out);
            return SparkUtils.fromIndexedMatrixBlock(out).iterator();
        }
    }

    /**
     * Turns one block of a column vector (row-block index r) into the diagonal block (r, r)
     * plus empty blocks (r, i) for every other column-block index i, so that the output
     * covers the full square result matrix.
     */
    private static class RDDDiagV2MFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 31065772250744103L;
        private final ReorgOperator _reorgOp = new ReorgOperator(DiagIndex.getDiagIndexFnObject());
        private final DataCharacteristics _mcIn;

        public RDDDiagV2MFunction(DataCharacteristics mcIn) {
            this._mcIn = mcIn;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            MatrixBlock blkIn = arg0._2();
            long rix = arg0._1().getRowIndex();
            long rlen = _mcIn.getRows();
            int blen = _mcIn.getBlocksize();

            MatrixBlock blkOut = blkIn.reorgOperations(_reorgOp, new MatrixBlock(), -1, -1, -1);
            ret.add(new Tuple2<>(new MatrixIndexes(rix, rix), blkOut));

            int numBlocks = (int) Math.ceil((double) rlen / (double) blen);
            // the row extent is the same for all empty blocks of this row block
            int lrlen = UtilFunctions.computeBlockSize(rlen, rix, blen);
            for (int i = 1; i <= numBlocks; ++i) {
                if (i == rix) {
                    continue;
                }
                int lclen = UtilFunctions.computeBlockSize(rlen, i, blen);
                ret.add(new Tuple2<>(new MatrixIndexes(rix, i), new MatrixBlock(lrlen, lclen, true)));
            }
            return ret.iterator();
        }
    }
}

