/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.confignode.manager.load.balancer.router.leader;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.confignode.manager.load.balancer.router.leader.AbstractLeaderBalancer;
import org.apache.iotdb.confignode.manager.load.cache.node.NodeStatistics;
import org.apache.iotdb.confignode.manager.load.cache.region.RegionStatistics;

/**
 * Leader balancer that models leader selection as a min-cost max-flow problem.
 *
 * <p>The flow network built by {@link #constructMCFGraph()} has four layers of edges, all with
 * capacity 1:
 *
 * <ul>
 *   <li>{@code S_NODE -> rNode}: one edge per RegionGroup, cost 0.
 *   <li>{@code rNode -> sDNode}: one edge per available replica; cost 0 when the replica's
 *       DataNode is the current leader, 1 otherwise.
 *   <li>{@code sDNode -> tDNode}: per database, the k-th edge added for a DataNode costs k*k.
 *   <li>{@code tDNode -> T_NODE}: across all RegionGroups, the k-th edge added for a DataNode
 *       costs k*k.
 * </ul>
 *
 * <p>Edges are stored in a single list together with a per-node "head edge" index; every edge
 * records the index of the previously added edge leaving the same node, forming a linked
 * adjacency list. The graph state lives in instance fields and is reset by {@link #clear()} at
 * the end of each {@link #generateOptimalLeaderDistribution} call.
 */
public class MinCostFlowLeaderBalancer
extends AbstractLeaderBalancer {
    private static final int INFINITY = Integer.MAX_VALUE;

    /* Graph nodes */
    // Super source node
    private static final int S_NODE = 0;
    // Super terminal node
    private static final int T_NODE = 1;
    // Next free graph node index
    private int maxNode = T_NODE + 1;
    // Map<RegionGroupId, rNode>
    private final Map<TConsensusGroupId, Integer> rNodeMap = new TreeMap<>();
    // Map<Database, Map<DataNodeId, sDNode>>
    private final Map<String, Map<Integer, Integer>> sDNodeMap = new TreeMap<>();
    // Map<Database, Map<sDNode, DataNodeId>>
    private final Map<String, Map<Integer, Integer>> sDNodeReflect = new TreeMap<>();
    // Map<DataNodeId, tDNode>
    private final Map<Integer, Integer> tDNodeMap = new TreeMap<>();

    /* Graph edges */
    // Next free graph edge index; always equals minCostFlowEdges.size()
    private int maxEdge = 0;
    private final List<MinCostFlowEdge> minCostFlowEdges = new ArrayList<>();
    // nodeHeadEdge[v]: index of the most recently added edge leaving v, or -1 if none
    private int[] nodeHeadEdge;
    // nodeCurrentEdge[v]: edge at which the DFS resumes scanning v within one augmentation phase
    private int[] nodeCurrentEdge;

    private boolean[] isNodeVisited;
    private int[] nodeMinimumCost;

    // Results of the most recent run; reset when the next graph is constructed, not by clear()
    private int maximumFlow = 0;
    private int minimumCost = 0;

    /**
     * Builds the flow network from the given cluster state, solves it and translates the
     * resulting flow into a leader assignment.
     *
     * @return a map from each RegionGroup listed in {@code databaseRegionGroupMap} to the id of
     *     the DataNode selected as its leader
     */
    @Override
    public Map<TConsensusGroupId, Integer> generateOptimalLeaderDistribution(Map<String, List<TConsensusGroupId>> databaseRegionGroupMap, Map<TConsensusGroupId, Set<Integer>> regionLocationMap, Map<TConsensusGroupId, Integer> regionLeaderMap, Map<Integer, NodeStatistics> dataNodeStatisticsMap, Map<TConsensusGroupId, Map<Integer, RegionStatistics>> regionStatisticsMap) {
        initialize(databaseRegionGroupMap, regionLocationMap, regionLeaderMap, dataNodeStatisticsMap, regionStatisticsMap);
        constructMCFGraph();
        dinicAlgorithm();
        Map<TConsensusGroupId, Integer> result = collectLeaderDistribution();
        clear();
        return result;
    }

    /**
     * Drops all graph state so the instance can be reused. {@code maximumFlow} and
     * {@code minimumCost} are intentionally left untouched so the getters still report the
     * last run.
     */
    @Override
    protected void clear() {
        super.clear();
        rNodeMap.clear();
        sDNodeMap.clear();
        sDNodeReflect.clear();
        tDNodeMap.clear();
        minCostFlowEdges.clear();
        nodeHeadEdge = null;
        nodeCurrentEdge = null;
        isNodeVisited = null;
        nodeMinimumCost = null;
        maxNode = T_NODE + 1;
        maxEdge = 0;
    }

    /** Assigns graph node indices, allocates the per-node arrays and adds all edges. */
    private void constructMCFGraph() {
        maximumFlow = 0;
        minimumCost = 0;

        /* Indicate nodes in the graph */
        for (Map.Entry<String, List<TConsensusGroupId>> databaseEntry : databaseRegionGroupMap.entrySet()) {
            String database = databaseEntry.getKey();
            Map<Integer, Integer> dataNodeToSDNode = new TreeMap<>();
            Map<Integer, Integer> sDNodeToDataNode = new TreeMap<>();
            sDNodeMap.put(database, dataNodeToSDNode);
            sDNodeReflect.put(database, sDNodeToDataNode);
            for (TConsensusGroupId regionGroupId : databaseEntry.getValue()) {
                if (!regionGroupIntersection.contains(regionGroupId)) {
                    continue;
                }
                // Map the RegionGroup to a region node
                rNodeMap.put(regionGroupId, maxNode++);
                for (Integer dataNodeId : regionLocationMap.get(regionGroupId)) {
                    if (!isDataNodeAvailable(dataNodeId)) {
                        continue;
                    }
                    // One sDNode per (database, DataNode) pair
                    if (!dataNodeToSDNode.containsKey(dataNodeId)) {
                        dataNodeToSDNode.put(dataNodeId, maxNode);
                        sDNodeToDataNode.put(maxNode, dataNodeId);
                        maxNode++;
                    }
                    // One tDNode per DataNode, shared by all databases
                    if (!tDNodeMap.containsKey(dataNodeId)) {
                        tDNodeMap.put(dataNodeId, maxNode);
                        maxNode++;
                    }
                }
            }
        }

        /* Prepare arrays */
        isNodeVisited = new boolean[maxNode];
        nodeMinimumCost = new int[maxNode];
        nodeCurrentEdge = new int[maxNode];
        nodeHeadEdge = new int[maxNode];
        Arrays.fill(nodeHeadEdge, -1);

        /* Construct edges: S_NODE -> rNodes */
        for (int rNode : rNodeMap.values()) {
            // Capacity 1, cost 0: at most one unit of flow (one leader) per RegionGroup
            addAdjacentEdges(S_NODE, rNode, 1, 0);
        }

        /* Construct edges: rNodes -> sDNodes */
        for (Map.Entry<String, List<TConsensusGroupId>> databaseEntry : databaseRegionGroupMap.entrySet()) {
            Map<Integer, Integer> dataNodeToSDNode = sDNodeMap.get(databaseEntry.getKey());
            for (TConsensusGroupId regionGroupId : databaseEntry.getValue()) {
                if (!regionGroupIntersection.contains(regionGroupId)) {
                    continue;
                }
                int rNode = rNodeMap.get(regionGroupId);
                for (Integer dataNodeId : regionLocationMap.get(regionGroupId)) {
                    if (isDataNodeAvailable(dataNodeId) && isRegionAvailable(regionGroupId, dataNodeId)) {
                        int sDNode = dataNodeToSDNode.get(dataNodeId);
                        // Cost 0 keeps the current leader, cost 1 moves it to another replica
                        int cost = Objects.equals(regionLeaderMap.getOrDefault(regionGroupId, -1), dataNodeId) ? 0 : 1;
                        addAdjacentEdges(rNode, sDNode, 1, cost);
                    }
                }
            }
        }

        /* Construct edges: sDNodes -> tDNodes */
        for (Map.Entry<String, List<TConsensusGroupId>> databaseEntry : databaseRegionGroupMap.entrySet()) {
            Map<Integer, Integer> dataNodeToSDNode = sDNodeMap.get(databaseEntry.getKey());
            // Number of edges added so far for each DataNode within this database
            Map<Integer, Integer> leaderCounter = new TreeMap<>();
            for (TConsensusGroupId regionGroupId : databaseEntry.getValue()) {
                if (!regionGroupIntersection.contains(regionGroupId)) {
                    continue;
                }
                for (Integer dataNodeId : regionLocationMap.get(regionGroupId)) {
                    if (isDataNodeAvailable(dataNodeId)) {
                        int sDNode = dataNodeToSDNode.get(dataNodeId);
                        int tDNode = tDNodeMap.get(dataNodeId);
                        int leaderCount = leaderCounter.merge(dataNodeId, 1, Integer::sum);
                        // Cost x^2 for the x-th edge at the current sDNode
                        addAdjacentEdges(sDNode, tDNode, 1, leaderCount * leaderCount);
                    }
                }
            }
        }

        /* Construct edges: tDNodes -> T_NODE */
        // Number of edges added so far for each DataNode across all RegionGroups
        Map<Integer, Integer> maxLeaderCounter = new TreeMap<>();
        for (Set<Integer> dataNodeIds : regionLocationMap.values()) {
            for (Integer dataNodeId : dataNodeIds) {
                if (isDataNodeAvailable(dataNodeId) && tDNodeMap.containsKey(dataNodeId)) {
                    int tDNode = tDNodeMap.get(dataNodeId);
                    int leaderCount = maxLeaderCounter.merge(dataNodeId, 1, Integer::sum);
                    // Cost x^2 for the x-th edge at the current tDNode
                    addAdjacentEdges(tDNode, T_NODE, 1, leaderCount * leaderCount);
                }
            }
        }
    }

    /**
     * Adds a forward edge and its residual (reverse) edge as a consecutive pair. Because edges
     * are only ever added through this method, the forward edge sits at an even index and its
     * reverse at the following odd index, so {@code index ^ 1} yields the paired edge.
     */
    private void addAdjacentEdges(int fromNode, int destNode, int capacity, int cost) {
        addEdge(fromNode, destNode, capacity, cost);
        addEdge(destNode, fromNode, 0, -cost);
    }

    /**
     * Appends a single directed edge and makes it the new head of {@code fromNode}'s adjacency
     * list; the previous head becomes the edge's {@code nextEdge}.
     */
    private void addEdge(int fromNode, int destNode, int capacity, int cost) {
        MinCostFlowEdge edge = new MinCostFlowEdge(destNode, capacity, cost, nodeHeadEdge[fromNode]);
        minCostFlowEdges.add(edge);
        // Without this link the adjacency lists stay empty and no edge is ever traversed
        nodeHeadEdge[fromNode] = maxEdge++;
    }

    /**
     * Computes, with a queue-based Bellman-Ford relaxation over edges that still have residual
     * capacity, the minimum cost from {@code S_NODE} to every node.
     *
     * @return {@code true} if {@code T_NODE} is still reachable from {@code S_NODE}
     */
    private boolean bellmanFordCheck() {
        Arrays.fill(isNodeVisited, false);
        Arrays.fill(nodeMinimumCost, INFINITY);

        LinkedList<Integer> queue = new LinkedList<>();
        nodeMinimumCost[S_NODE] = 0;
        // During this method isNodeVisited[v] means "v is currently in the queue"
        isNodeVisited[S_NODE] = true;
        queue.offer(S_NODE);
        while (!queue.isEmpty()) {
            int currentNode = queue.poll();
            isNodeVisited[currentNode] = false;
            for (int currentEdge = nodeHeadEdge[currentNode]; currentEdge >= 0; currentEdge = minCostFlowEdges.get(currentEdge).nextEdge) {
                MinCostFlowEdge edge = minCostFlowEdges.get(currentEdge);
                if (edge.capacity > 0 && nodeMinimumCost[currentNode] + edge.cost < nodeMinimumCost[edge.destNode]) {
                    nodeMinimumCost[edge.destNode] = nodeMinimumCost[currentNode] + edge.cost;
                    if (!isNodeVisited[edge.destNode]) {
                        isNodeVisited[edge.destNode] = true;
                        queue.offer(edge.destNode);
                    }
                }
            }
        }
        return nodeMinimumCost[T_NODE] < INFINITY;
    }

    /**
     * Pushes up to {@code inputFlow} units from {@code currentNode} towards {@code T_NODE},
     * following only edges that lie on a minimum-cost path according to
     * {@code nodeMinimumCost}, and updates residual capacities and {@code minimumCost}.
     *
     * @param currentNode the node the flow currently sits at
     * @param inputFlow the amount of flow available at {@code currentNode}
     * @return the amount of flow actually delivered to {@code T_NODE}
     */
    private int dfsAugmentation(int currentNode, int inputFlow) {
        if (currentNode == T_NODE || inputFlow == 0) {
            return inputFlow;
        }

        int currentEdge;
        int outputFlow = 0;
        isNodeVisited[currentNode] = true;
        for (currentEdge = nodeCurrentEdge[currentNode]; currentEdge >= 0; currentEdge = minCostFlowEdges.get(currentEdge).nextEdge) {
            MinCostFlowEdge edge = minCostFlowEdges.get(currentEdge);
            if (nodeMinimumCost[currentNode] + edge.cost == nodeMinimumCost[edge.destNode]
                    && edge.capacity > 0
                    && !isNodeVisited[edge.destNode]) {
                int subOutputFlow = dfsAugmentation(edge.destNode, Math.min(inputFlow, edge.capacity));

                minimumCost += subOutputFlow * edge.cost;

                // Move the flow from the edge to its paired residual edge
                edge.capacity -= subOutputFlow;
                minCostFlowEdges.get(currentEdge ^ 1).capacity += subOutputFlow;

                inputFlow -= subOutputFlow;
                outputFlow += subOutputFlow;

                if (inputFlow == 0) {
                    // Stay on this edge: it may still have capacity for the next DFS
                    break;
                }
            }
        }
        // Remember where to resume so exhausted edges are not rescanned in this phase
        nodeCurrentEdge[currentNode] = currentEdge;

        // A node that delivered nothing stays marked so later DFS calls skip it
        if (outputFlow > 0) {
            isNodeVisited[currentNode] = false;
        }
        return outputFlow;
    }

    /**
     * Repeats shortest-path labelling followed by DFS augmentation until {@code T_NODE} is no
     * longer reachable, accumulating the total into {@code maximumFlow}.
     */
    private void dinicAlgorithm() {
        while (bellmanFordCheck()) {
            int currentFlow;
            System.arraycopy(nodeHeadEdge, 0, nodeCurrentEdge, 0, maxNode);
            while ((currentFlow = dfsAugmentation(S_NODE, INFINITY)) > 0) {
                maximumFlow += currentFlow;
            }
        }
    }

    /**
     * Reads the leader assignment out of the solved graph. For a RegionGroup in the graph, a
     * saturated (capacity 0) edge from its rNode to an sDNode marks that DataNode as leader.
     * RegionGroups outside {@code regionGroupIntersection}, or without any saturated edge, keep
     * their original leader ({@code -1} when {@code regionLeaderMap} has no entry).
     */
    private Map<TConsensusGroupId, Integer> collectLeaderDistribution() {
        Map<TConsensusGroupId, Integer> result = new ConcurrentHashMap<>();

        for (Map.Entry<String, List<TConsensusGroupId>> databaseEntry : databaseRegionGroupMap.entrySet()) {
            String database = databaseEntry.getKey();
            for (TConsensusGroupId regionGroupId : databaseEntry.getValue()) {
                int originalLeader = regionLeaderMap.getOrDefault(regionGroupId, -1);
                if (!regionGroupIntersection.contains(regionGroupId)) {
                    result.put(regionGroupId, originalLeader);
                    continue;
                }
                boolean matchLeader = false;
                for (int currentEdge = nodeHeadEdge[rNodeMap.get(regionGroupId)]; currentEdge >= 0; currentEdge = minCostFlowEdges.get(currentEdge).nextEdge) {
                    MinCostFlowEdge edge = minCostFlowEdges.get(currentEdge);
                    // Skip the residual edge back to S_NODE; it is not a leader candidate
                    if (edge.destNode != S_NODE && edge.capacity == 0) {
                        matchLeader = true;
                        result.put(regionGroupId, sDNodeReflect.get(database).get(edge.destNode));
                    }
                }
                if (!matchLeader) {
                    result.put(regionGroupId, originalLeader);
                }
            }
        }
        return result;
    }

    /** Returns the total flow pushed during the most recent run. */
    public int getMaximumFlow() {
        return maximumFlow;
    }

    /** Returns the total cost of the flow pushed during the most recent run. */
    public int getMinimumCost() {
        return minimumCost;
    }

    /** A directed edge of the flow network, chained to the previous edge of the same source. */
    private static class MinCostFlowEdge {
        private final int destNode;
        // Residual capacity; mutated while flow is pushed
        private int capacity;
        private final int cost;
        // Index of the next edge leaving the same source node, or -1 at the end of the chain
        private final int nextEdge;

        private MinCostFlowEdge(int destNode, int capacity, int cost, int nextEdge) {
            this.destNode = destNode;
            this.capacity = capacity;
            this.cost = cost;
            this.nextEdge = nextEdge;
        }
    }
}

