/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.ipa.FunctionCallGraph;
import org.apache.sysml.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysml.hops.ipa.IPAPass;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;

/**
 * Inter-procedural analysis (IPA) pass that eliminates dead code.
 *
 * <p>Two kinds of statement-block roots are removed:
 * <ul>
 *   <li>transient writes of variables that are not read by any subsequent statement block of the
 *       same block sequence (nested control-flow blocks included) and are not seeded as live, and</li>
 *   <li>calls to functions that the call graph reports as side-effect free and whose output
 *       variables are all unused in the same sense; such calls are also unregistered from the
 *       call graph.</li>
 * </ul>
 *
 * <p>The pass runs over the main program's top-level statement blocks and over the body of every
 * function statement block. Only last-level statement blocks at the outer level of each sequence
 * are pruned. Control-flow blocks (while/for/if) are never entered for removal, but every variable
 * they read is treated as used.
 */
public class IPAPassEliminateDeadCode extends IPAPass {

    /** This pass places no preconditions on the call graph and is therefore always applicable. */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return true;
    }

    /**
     * Removes dead code from the main program and from every function body.
     *
     * <p>For the main program the live set starts empty, so no variable is considered live after the
     * last top-level statement block. For a function, its output parameters are seeded as live,
     * since their values are handed back to the caller.
     *
     * @param prog       program whose statement blocks are rewritten in place
     * @param fgraph     function call graph; updated when function calls are removed
     * @param fcallSizes not used by this pass
     */
    @Override
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        findAndRemoveDeadCode(prog.getStatementBlocks(), new HashSet<>(), fgraph);
        for (FunctionStatementBlock fsb : prog.getFunctionStatementBlocks()) {
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            Set<String> usedVars = new HashSet<>();
            fstmt.getOutputParams().forEach(d -> usedVars.add(d.getName()));
            findAndRemoveDeadCode(fstmt.getBody(), usedVars, fgraph);
        }
    }

    /**
     * Prunes dead roots from a sequence of statement blocks using a backward traversal.
     *
     * <p>Blocks are visited from last to first. Within a last-level block, a root is removed if it
     * is either of the following:
     * <ul>
     *   <li>a transient write of a variable not in {@code usedVars}, or</li>
     *   <li>a side-effect-free function call none of whose outputs are in {@code usedVars}.</li>
     * </ul>
     * After a block is processed, every variable it reads is added to {@code usedVars`}. This
     * includes reads in nested predicates and bodies. Variables are never removed from the set, so
     * an earlier write is conservatively kept even if a later block overwrites the variable before
     * reading it.
     *
     * @param sbs      statement blocks to prune in place
     * @param usedVars variables live after the last block; mutated to accumulate all reads
     * @param fgraph   call graph from which removed function calls are unregistered
     */
    private static void findAndRemoveDeadCode(List<StatementBlock> sbs, Set<String> usedVars, FunctionCallGraph fgraph) {
        for (int i = sbs.size() - 1; i >= 0; i--) {
            StatementBlock sb = sbs.get(i);
            // A last-level block may carry no hops (the read collector guards the same case);
            // then there is nothing to remove.
            if (HopRewriteUtils.isLastLevelStatementBlock(sb) && sb.getHops() != null) {
                List<Hop> roots = sb.getHops();
                for (int j = 0; j < roots.size(); j++) {
                    Hop root = roots.get(j);
                    boolean isDeadWrite = HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE)
                        && !usedVars.contains(root.getName());
                    boolean isDeadCall = isFunctionCallWithUnusedOutputs(root, usedVars, fgraph);
                    if (!isDeadWrite && !isDeadCall) {
                        continue;
                    }
                    if (isDeadCall) {
                        FunctionOp fop = (FunctionOp) root;
                        fgraph.removeFunctionCall(fop.getFunctionKey(), fop, sb);
                    }
                    // Remove by index, then revisit the same index, which now holds the next root.
                    roots.remove(j--);
                    rRemoveOpFromDAG(root);
                }
            }
            // Any variable read anywhere in this block keeps earlier writes of it alive.
            rCollectReadVariableNames(sb, usedVars);
        }
    }

    /**
     * Tests whether {@code hop} is a call to a side-effect-free function with no output variable in
     * {@code varNames}. A call without output variables trivially qualifies.
     *
     * @param hop      candidate root
     * @param varNames names of variables that are still used
     * @param fgraph   call graph used to check whether the callee is side-effect free
     * @return {@code true} if the call can be removed without observable effect
     */
    private static boolean isFunctionCallWithUnusedOutputs(Hop hop, Set<String> varNames, FunctionCallGraph fgraph) {
        if (!(hop instanceof FunctionOp)) {
            return false;
        }
        FunctionOp fop = (FunctionOp) hop;
        return fgraph.isSideEffectFreeFunction(fop.getFunctionKey())
            && Arrays.stream(fop.getOutputVariableNames()).noneMatch(varNames::contains);
    }

    /**
     * Detaches {@code current} from its hop DAG. It is unlinked from the parent list of each input,
     * and inputs left with no parent are removed recursively. Finally, {@code current}'s own input
     * list is cleared.
     *
     * @param current hop being removed
     */
    private static void rRemoveOpFromDAG(Hop current) {
        for (Hop input : current.getInput()) {
            input.getParent().remove(current);
            if (input.getParent().isEmpty()) {
                rRemoveOpFromDAG(input);
            }
        }
        current.getInput().clear();
    }

    /**
     * Adds to {@code varNames} the names of all transient reads in {@code sb}.
     *
     * <p>What is scanned depends on the block type:
     * <ul>
     *   <li>While and if blocks: the predicate hops and, recursively, all nested statement blocks.
     *       For if blocks this includes the else branch when one is present.</li>
     *   <li>For blocks: the from/to/increment hops and, recursively, all nested statement
     *       blocks.</li>
     *   <li>Any other block: its hop DAG, if it has one.</li>
     * </ul>
     *
     * @param sb       statement block to scan
     * @param varNames accumulator for read variable names
     * @return {@code varNames}, for call chaining
     */
    private static Set<String> rCollectReadVariableNames(StatementBlock sb, Set<String> varNames) {
        if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            collectReadVariableNames(wsb.getPredicateHops(), varNames);
            for (StatementBlock child : wstmt.getBody()) {
                rCollectReadVariableNames(child, varNames);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            collectReadVariableNames(fsb.getFromHops(), varNames);
            collectReadVariableNames(fsb.getToHops(), varNames);
            collectReadVariableNames(fsb.getIncrementHops(), varNames);
            for (StatementBlock child : fstmt.getBody()) {
                rCollectReadVariableNames(child, varNames);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            collectReadVariableNames(isb.getPredicateHops(), varNames);
            for (StatementBlock child : istmt.getIfBody()) {
                rCollectReadVariableNames(child, varNames);
            }
            if (istmt.getElseBody() != null) {
                for (StatementBlock child : istmt.getElseBody()) {
                    rCollectReadVariableNames(child, varNames);
                }
            }
        } else if (sb.getHops() != null) {
            // Shared sub-DAGs are visited once; clear stale visit flags first.
            Hop.resetVisitStatus(sb.getHops());
            for (Hop root : sb.getHops()) {
                rCollectReadVariableNames(root, varNames);
            }
        }
        return varNames;
    }

    /**
     * Null-safe entry point that resets visit status starting at {@code hop} and then collects the
     * names of the transient reads reachable from it.
     *
     * @param hop      root of a predicate or loop-bound DAG; may be {@code null}
     * @param varNames accumulator for read variable names
     * @return {@code varNames}, for call chaining
     */
    private static Set<String> collectReadVariableNames(Hop hop, Set<String> varNames) {
        if (hop == null) {
            return varNames;
        }
        hop.resetVisitStatus();
        return rCollectReadVariableNames(hop, varNames);
    }

    /**
     * Performs a post-order traversal over the inputs of {@code hop} and adds the name of every
     * TRANSIENTREAD data op it reaches. Hops already marked visited are skipped, so the caller must
     * reset visit status beforehand.
     *
     * @param hop      current hop
     * @param varNames accumulator for read variable names
     * @return {@code varNames}, for call chaining
     */
    private static Set<String> rCollectReadVariableNames(Hop hop, Set<String> varNames) {
        if (hop.isVisited()) {
            return varNames;
        }
        for (Hop input : hop.getInput()) {
            rCollectReadVariableNames(input, varNames);
        }
        if (HopRewriteUtils.isData(hop, Hop.DataOpTypes.TRANSIENTREAD)) {
            varNames.add(hop.getName());
        }
        hop.setVisited();
        return varNames;
    }
}

