/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.tokenize.applier;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
import org.apache.sysds.runtime.transform.tokenize.Token;
import org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplier;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Tokenizer applier that emits, for every document, one output row per distinct
 * token together with the number of times that token occurs in the document.
 *
 * <p>The per-document counts are computed in {@link #build} and written to the
 * output frame in {@link #applyInternalRepresentation}. Only the long (non-wide)
 * output format has a schema; {@link #getOutSchema()} rejects the wide format.
 */
public class TokenizerApplierCount
extends TokenizerApplier {
    private static final long serialVersionUID = 6382000606237705019L;

    /** When {@code true}, the distinct tokens of each document are emitted in natural (sorted) String order. */
    public boolean sort_alpha = false;

    /** Token-to-occurrence-count map per document, indexed by the document's position in the internal representation. */
    private List<Map<String, Integer>> counts;

    /**
     * Creates the applier and reads the optional {@code "sort_alpha"} flag from {@code params}.
     *
     * @param numIdCols    forwarded to the superclass constructor
     * @param maxTokens    forwarded to the superclass constructor; used here as the per-document row cap
     * @param wideFormat   forwarded to the superclass constructor
     * @param applyPadding forwarded to the superclass constructor
     * @param params       optional parameters; may be {@code null}
     * @throws JSONException if {@code "sort_alpha"} is present but cannot be read as a boolean
     */
    public TokenizerApplierCount(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding, JSONObject params) throws JSONException {
        super(numIdCols, maxTokens, wideFormat, applyPadding);
        if (params != null && params.has("sort_alpha")) {
            this.sort_alpha = params.getBoolean("sort_alpha");
        }
    }

    /**
     * Returns the number of output rows for the given documents.
     *
     * <ul>
     *   <li>wide format: one row per document;</li>
     *   <li>padding enabled: {@code maxTokens} rows per document;</li>
     *   <li>otherwise: the sum over all documents of
     *       {@code min(distinctTokens, maxTokens)}, read from the counts filled by {@link #build}.</li>
     * </ul>
     *
     * @param internalRepresentation the tokenized documents
     * @return the number of output rows
     */
    @Override
    public int getNumRows(DocumentRepresentation[] internalRepresentation) {
        if (this.wideFormat) {
            return internalRepresentation.length;
        }
        if (this.applyPadding) {
            return this.maxTokens * internalRepresentation.length;
        }
        // Relies on every entry of counts having been populated by build().
        return this.counts.stream()
            .mapToInt(docCounts -> Math.min(docCounts.size(), this.maxTokens))
            .sum();
    }

    /**
     * Allocates the per-document count list with {@code numDocuments} {@code null}
     * placeholders so that {@link #build} can fill entries by index.
     *
     * @param numDocuments number of documents to reserve slots for
     */
    @Override
    public void allocateInternalMeta(int numDocuments) {
        // Diamond lets the element type be inferred from the field; an explicit
        // ArrayList<Object> is not assignable to List<Map<String, Integer>>.
        this.counts = new ArrayList<>(Collections.nCopies(numDocuments, null));
    }

    /**
     * Computes the token occurrence counts for the documents in the block
     * starting at {@code inputRowStart} and stores them by document index.
     *
     * @param internalRepresentation the tokenized documents
     * @param inputRowStart          index of the first document of the block
     * @param blk                    block size passed to {@code UtilFunctions.getEndIndex}
     */
    @Override
    public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk) {
        int endIndex = UtilFunctions.getEndIndex(internalRepresentation.length, inputRowStart, blk);
        for (int i = inputRowStart; i < endIndex; ++i) {
            Map<String, Integer> tokenCounts = new HashMap<>();
            for (Token token : internalRepresentation[i].tokens) {
                // Inserts 1 for a new token, otherwise increments the stored count.
                tokenCounts.merge(token.toString(), 1, Integer::sum);
            }
            this.counts.set(i, tokenCounts);
        }
    }

    /**
     * Writes one output row per distinct token (at most {@code maxTokens} per
     * document) for the documents of the given block. Each row receives the
     * document keys via {@code setKeys}, followed by the token and its count.
     *
     * @param internalRepresentation the tokenized documents
     * @param out                    the output frame to write into
     * @param inputRowStart          index of the first document of the block
     * @param blk                    block size passed to {@code UtilFunctions.getEndIndex}
     * @return the output row index following the last row written
     */
    @Override
    public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) {
        int endIndex = UtilFunctions.getEndIndex(internalRepresentation.length, inputRowStart, blk);
        // NOTE(review): presumably the first output row belonging to this block,
        // derived from the preceding documents' counts - confirm in TokenizerApplier.
        int outputRow = this.getOutputRow(inputRowStart, this.counts);
        for (int i = inputRowStart; i < endIndex; ++i) {
            List<Object> keys = internalRepresentation[i].keys;
            Map<String, Integer> tokenCounts = this.counts.get(i);
            Set<String> distinctTokens = tokenCounts.keySet();
            if (this.sort_alpha) {
                // Copy into a TreeSet to iterate in natural String order.
                distinctTokens = new TreeSet<>(distinctTokens);
            }
            int numTokens = 0;
            for (String token : distinctTokens) {
                if (numTokens >= this.maxTokens) {
                    break;
                }
                // The value returned by setKeys is used as the token column.
                int col = this.setKeys(outputRow, keys, out);
                // Widen to long so the value handed to the frame is boxed as a Long.
                long count = tokenCounts.get(token).intValue();
                out.set(outputRow, col, token);
                out.set(outputRow, col + 1, count);
                ++outputRow;
                ++numTokens;
            }
            if (this.applyPadding) {
                // NOTE(review): presumably fills the remaining rows of this document
                // with "" / -1 placeholders - confirm in TokenizerApplier.
                outputRow = this.applyPaddingLong(outputRow, numTokens, keys, out, "", -1);
            }
        }
        return outputRow;
    }

    /**
     * Returns the output schema: {@code numIdCols + 2} columns, all
     * {@code STRING} except the last one, which is {@code INT64}.
     *
     * @return the output schema
     * @throws IllegalArgumentException if wide format is enabled
     */
    @Override
    public Types.ValueType[] getOutSchema() {
        if (this.wideFormat) {
            throw new IllegalArgumentException("Wide Format is not supported for Count Representation.");
        }
        Types.ValueType[] schema = UtilFunctions.nCopies(this.numIdCols + 2, Types.ValueType.STRING);
        schema[this.numIdCols + 1] = Types.ValueType.INT64;
        return schema;
    }
}

