/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import com.sun.tools.javac.util.List;
import java.util.Objects;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;

/**
 * Federated instruction for ternary operations (operator parsed by
 * {@link InstructionUtils#parseTernaryOperator}) over any mix of matrix and scalar inputs.
 *
 * <p>Matrix inputs that are not federated, or whose federation map is not aligned with the
 * federated input chosen to run the operation, are broadcast in slices to that input's workers
 * before the instruction is executed remotely. The result either stays federated (its mapping is
 * copied from the executing input with the output's ID), or, when the federated output is forced
 * local, is fetched from the workers and bound on the coordinator.
 *
 * <p>NOTE(review): the code assumes at least one matrix input is federated; a purely local set
 * of matrix inputs would fail on a null federation map.
 */
public class TernaryFEDInstruction extends ComputationFEDInstruction {

    private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3,
            CPOperand out, String opcode, String str, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str, fedOut);
    }

    /**
     * Parses a federated ternary instruction.
     *
     * <p>Parts as split by {@link InstructionUtils#getInstructionPartsWithValueType}:
     * {@code opcode, in1, in2, in3, out[, numThreads[, federatedOutput]]}.
     *
     * @param str the instruction string
     * @return the parsed instruction; numThreads defaults to 1 and the federated output to NONE
     */
    public static TernaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand operand1 = new CPOperand(parts[1]);
        CPOperand operand2 = new CPOperand(parts[2]);
        CPOperand operand3 = new CPOperand(parts[3]);
        CPOperand outOperand = new CPOperand(parts[4]);
        int numThreads = parts.length > 5 ? Integer.parseInt(parts[5]) : 1;
        // parts[6] holds the federated output type. Guarding with "> 7" ignored it whenever it
        // was the last part, so read it whenever it is present.
        FEDInstruction.FederatedOutput fedOut = parts.length > 6
            ? FEDInstruction.FederatedOutput.valueOf(parts[6])
            : FEDInstruction.FederatedOutput.NONE;
        TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode, numThreads);
        return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str, fedOut);
    }

    /**
     * Dispatches on which inputs are matrices: three matrices; exactly one matrix; or two
     * matrices plus one scalar. In the last case, a non-literal scalar is first inlined into the
     * instruction string as an FP64 literal. Nothing is executed when no input is a matrix.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
        MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
        MatrixObject mo3 = input3 != null && input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
        // Stream.of tolerates the null placeholders of scalar inputs (java.util.List.of would throw).
        long matrixInputsCount = Stream.of(mo1, mo2, mo3).filter(Objects::nonNull).count();

        if (matrixInputsCount == 3) {
            processMatrixInput(ec, mo1, mo2, mo3);
        } else if (matrixInputsCount == 1) {
            // Pair the single matrix with the operand it was read from.
            if (mo1 != null) {
                processMatrixScalarInput(ec, mo1, input1);
            } else if (mo2 != null) {
                processMatrixScalarInput(ec, mo2, input2);
            } else {
                processMatrixScalarInput(ec, mo3, input3);
            }
        } else if (mo1 != null && mo2 != null) {
            inlineScalarOperand(ec, input3, 4);
            process2MatrixScalarInput(ec, mo1, mo2, input1, input2);
        } else if (mo2 != null && mo3 != null) {
            inlineScalarOperand(ec, input1, 2);
            process2MatrixScalarInput(ec, mo2, mo3, input2, input3);
        } else if (mo1 != null && mo3 != null) {
            inlineScalarOperand(ec, input2, 3);
            process2MatrixScalarInput(ec, mo1, mo3, input1, input3);
        }
    }

    /**
     * Replaces operand {@code position} of the instruction string with the current value of
     * {@code scalar} as an FP64 literal. Positions 2, 3 and 4 correspond to input1, input2 and
     * input3. Absent or already-literal operands are left unchanged.
     */
    private void inlineScalarOperand(ExecutionContext ec, CPOperand scalar, int position) {
        if (scalar == null || scalar.isLiteral()) {
            return;
        }
        String literal = InstructionUtils.createLiteralOperand(
            ec.getScalarInput(scalar).getStringValue(), Types.ValueType.FP64);
        instString = InstructionUtils.replaceOperand(instString, position, literal);
    }

    /**
     * Runs the instruction on the workers of {@code mo1}. The operand {@code in} is replaced by
     * the ID of mo1's federation map.
     */
    private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
        FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
            new CPOperand[]{in}, new long[]{mo1.getFedMapping().getID()});
        sendFederatedRequests(ec, mo1, fr1.getID(), fr1);
    }

    /**
     * Runs the instruction for two matrix operands {@code in1}/{@code in2} (matrices mo1/mo2);
     * the scalar operand is expected to be inlined already. If mo1 is federated, it executes
     * the instruction, and mo2 is broadcast in slices unless both maps are aligned. Otherwise,
     * mo2 is assumed federated and mo1 is broadcast to it. Any broadcast is cleaned up on the
     * workers after the call.
     */
    private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2,
            CPOperand in1, CPOperand in2) {
        CPOperand[] varOldIn = new CPOperand[]{in1, in2};
        MatrixObject fedObj;
        FederatedRequest[] broadcast = null;
        long[] varNewIn;
        if (mo1.isFederated()) {
            fedObj = mo1;
            FederationMap fedMap = mo1.getFedMapping();
            if (mo2.isFederated() && fedMap.isAligned(mo2.getFedMapping(), false)) {
                varNewIn = new long[]{fedMap.getID(), mo2.getFedMapping().getID()};
            } else {
                broadcast = fedMap.broadcastSliced(mo2, false);
                varNewIn = new long[]{fedMap.getID(), broadcast[0].getID()};
            }
        } else {
            fedObj = mo2;
            FederationMap fedMap = mo2.getFedMapping();
            broadcast = fedMap.broadcastSliced(mo1, false);
            varNewIn = new long[]{broadcast[0].getID(), fedMap.getID()};
        }

        FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
        if (broadcast == null) {
            sendFederatedRequests(ec, fedObj, fr2.getID(), fr2);
        } else {
            // The broadcast slices are only consumed by fr2, so drop them afterwards in both directions.
            FederatedRequest fr3 = fedObj.getFedMapping().cleanup(getTID(), broadcast[0].getID());
            sendFederatedRequests(ec, fedObj, fr2.getID(), broadcast, new FederatedRequest[]{fr2, fr3});
        }
    }

    /** Sends {@code federatedRequests} without any sliced broadcasts. */
    private void sendFederatedRequests(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID,
            FederatedRequest... federatedRequests) {
        sendFederatedRequests(ec, fedMapObj, fedOutputID, (FederatedRequest[]) null, (FederatedRequest[]) null, federatedRequests);
    }

    /** Sends {@code federatedRequests} together with one set of sliced broadcasts. */
    private void sendFederatedRequests(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID,
            FederatedRequest[] federatedSlices, FederatedRequest... federatedRequests) {
        sendFederatedRequests(ec, fedMapObj, fedOutputID, federatedSlices, (FederatedRequest[]) null, federatedRequests);
    }

    /**
     * Executes the requests on the workers of {@code fedMapObj}. Unless the federated output is
     * forced local, the output keeps a federation map copied from fedMapObj with ID
     * {@code fedOutputID}. Otherwise, the result variable is fetched and bound locally.
     */
    private void sendFederatedRequests(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID,
            FederatedRequest[] federatedSlices1, FederatedRequest[] federatedSlices2,
            FederatedRequest... federatedRequests) {
        if (!_fedOut.isForcedLocal()) {
            fedMapObj.getFedMapping().execute(getTID(), true, federatedSlices1, federatedSlices2, federatedRequests);
            setOutputFedMapping(ec, fedMapObj, fedOutputID);
        } else {
            processAndRetrieve(ec, fedMapObj, fedOutputID, federatedSlices1, federatedSlices2, federatedRequests);
        }
    }

    /**
     * Executes the requests followed by a GET_VAR of {@code fedOutputID}. The fetched partial
     * results are bound into the local matrix output. Results are bound column-wise when
     * fedMapObj is column-federated.
     */
    private void processAndRetrieve(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID,
            FederatedRequest[] federatedSlices1, FederatedRequest[] federatedSlices2,
            FederatedRequest... federatedRequests) {
        FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fedOutputID);
        Future<FederatedResponse>[] executionResponse = fedMapObj.getFedMapping().execute(getTID(), true,
            federatedSlices1, federatedSlices2, collectRequests(federatedRequests, getRequest));
        ec.setMatrixOutput(output.getName(),
            FederationUtils.bind(executionResponse, fedMapObj.isFederated(FederationMap.FType.COL)));
    }

    /** Returns a new array holding {@code fedRequests} followed by {@code fedRequest1}. */
    private static FederatedRequest[] collectRequests(FederatedRequest[] fedRequests, FederatedRequest fedRequest1) {
        FederatedRequest[] allRequests = new FederatedRequest[fedRequests.length + 1];
        System.arraycopy(fedRequests, 0, allRequests, 0, fedRequests.length);
        allRequests[fedRequests.length] = fedRequest1;
        return allRequests;
    }

    /**
     * Handles three matrix inputs. If all are aligned, the call runs directly on mo1's
     * workers. If exactly one aligned pair exists, the third matrix is broadcast to that pair's
     * workers. Otherwise, both non-executing inputs are broadcast to the first federated input.
     */
    private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        CPOperand[] inputs = new CPOperand[]{input1, input2, input3};
        RetAlignedValues aligned = getAlignedInputs(mo1, mo2, mo3);
        if (aligned._allAligned) {
            long[] vars = new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()};
            FederatedRequest fr3 = FederationUtils.callInstruction(instString, output, inputs, vars);
            sendFederatedRequests(ec, aligned._fedMapObj, fr3.getID(), fr3);
        } else if (aligned._twoAligned) {
            // Use the map the slices were broadcast over; this is not mo1 when only in2/in3 align.
            MatrixObject fedObj = aligned._fedMapObj;
            FederatedRequest fr3 = FederationUtils.callInstruction(instString, output, inputs, aligned._vars);
            FederatedRequest fr4 = fedObj.getFedMapping().cleanup(getTID(), aligned._fr[0].getID());
            sendFederatedRequests(ec, fedObj, fr3.getID(), aligned._fr, new FederatedRequest[]{fr3, fr4});
        } else {
            processUnalignedMatrixInput(ec, mo1, mo2, mo3, inputs);
        }
    }

    /**
     * Runs the call on the first federated matrix among mo1, mo2, mo3. The other two matrices
     * are broadcast in slices, in operand order, so each broadcast ID lands at its own input's
     * position.
     */
    private void processUnalignedMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2,
            MatrixObject mo3, CPOperand[] inputs) {
        MatrixObject fedObj;
        FederatedRequest[] frA;
        FederatedRequest[] frB;
        long[] vars;
        if (mo1.isFederated()) {
            fedObj = mo1;
            FederationMap fedMap = mo1.getFedMapping();
            frA = fedMap.broadcastSliced(mo2, false);
            frB = fedMap.broadcastSliced(mo3, false);
            vars = new long[]{fedMap.getID(), frA[0].getID(), frB[0].getID()};
        } else if (mo2.isFederated()) {
            fedObj = mo2;
            FederationMap fedMap = mo2.getFedMapping();
            frA = fedMap.broadcastSliced(mo1, false);
            frB = fedMap.broadcastSliced(mo3, false);
            vars = new long[]{frA[0].getID(), fedMap.getID(), frB[0].getID()};
        } else {
            // NOTE(review): assumes mo3 is federated when neither mo1 nor mo2 is.
            fedObj = mo3;
            FederationMap fedMap = mo3.getFedMapping();
            frA = fedMap.broadcastSliced(mo1, false);
            frB = fedMap.broadcastSliced(mo2, false);
            vars = new long[]{frA[0].getID(), frB[0].getID(), fedMap.getID()};
        }
        FederatedRequest fr3 = FederationUtils.callInstruction(instString, output, inputs, vars);
        FederatedRequest fr4 = fedObj.getFedMapping().cleanup(getTID(), frA[0].getID(), frB[0].getID());
        sendFederatedRequests(ec, fedObj, fr3.getID(), frA, frB, new FederatedRequest[]{fr3, fr4});
    }

    /**
     * Determines how the three matrix inputs are aligned.
     *
     * <p>If all inputs are aligned, no broadcast is created. Otherwise, the aligned pairs are
     * checked in priority order (in2/in3, then in1/in3, then in1/in2), and the first match gets
     * the remaining matrix broadcast over its map.
     */
    private RetAlignedValues getAlignedInputs(MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        boolean aligned12 = areAligned(mo1, mo2);
        boolean aligned13 = areAligned(mo1, mo3);
        if (aligned12 && aligned13) {
            return new RetAlignedValues(true, true, new long[]{}, new FederatedRequest[]{}, mo1);
        }
        if (areAligned(mo2, mo3)) {
            FederatedRequest[] fr = mo2.getFedMapping().broadcastSliced(mo1, false);
            long[] vars = new long[]{fr[0].getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()};
            return new RetAlignedValues(true, false, vars, fr, mo2);
        }
        if (aligned13) {
            FederatedRequest[] fr = mo1.getFedMapping().broadcastSliced(mo2, false);
            long[] vars = new long[]{mo1.getFedMapping().getID(), fr[0].getID(), mo3.getFedMapping().getID()};
            return new RetAlignedValues(true, false, vars, fr, mo1);
        }
        if (aligned12) {
            FederatedRequest[] fr = mo1.getFedMapping().broadcastSliced(mo3, false);
            long[] vars = new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), fr[0].getID()};
            return new RetAlignedValues(true, false, vars, fr, mo1);
        }
        return new RetAlignedValues(false, false, new long[]{}, new FederatedRequest[]{}, null);
    }

    /** True when both matrices are federated and their federation maps are aligned. */
    private static boolean areAligned(MatrixObject a, MatrixObject b) {
        return a.isFederated() && b.isFederated() && a.getFedMapping().isAligned(b.getFedMapping(), false);
    }

    /** Sets the output's federation map to a copy of fedMapObj's map under {@code fedOutputID}. */
    private void setOutputFedMapping(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID) {
        MatrixObject out = ec.getMatrixObject(output);
        out.setFedMapping(fedMapObj.getFedMapping().copyWithNewID(fedOutputID));
    }

    /**
     * Result of {@link #getAlignedInputs}. It holds the alignment flags, the variable IDs that
     * replace input1..input3, and the sliced broadcast requests. It also holds the matrix whose
     * federation map executes the call (null when nothing is aligned).
     */
    private static final class RetAlignedValues {
        final boolean _twoAligned;
        final boolean _allAligned;
        final long[] _vars;
        final FederatedRequest[] _fr;
        final MatrixObject _fedMapObj;

        RetAlignedValues(boolean twoAligned, boolean allAligned, long[] vars, FederatedRequest[] fr,
                MatrixObject fedMapObj) {
            _twoAligned = twoAligned;
            _allAligned = allAligned;
            _vars = vars;
            _fr = fr;
            _fedMapObj = fedMapObj;
        }
    }
}

