/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Code-generation plan node for n-ary vector primitives: cbind and the DNN
 * operations max/avg pooling, im2col and conv2d-matmult.
 * <p>
 * Output dimensions are derived from the inputs at construction time. For the
 * DNN types, several inputs at fixed positions are parsed as integer literals
 * (via their variable names) to obtain dimension sizes.
 */
public class CNodeNary
extends CNode {
    /** The n-ary operation this node represents. */
    private final NaryType _type;

    /**
     * Creates an n-ary node over the given inputs and derives its output dimensions.
     *
     * @param inputs the child nodes, in operation-specific order
     * @param type   the n-ary operation type
     */
    public CNodeNary(CNode[] inputs, NaryType type) {
        for (CNode in : inputs) {
            this._inputs.add(in);
        }
        this._type = type;
        this.setOutputDims();
    }

    /**
     * Returns the n-ary operation type.
     *
     * @return the n-ary operation type of this node
     */
    public NaryType getType() {
        return this._type;
    }

    /**
     * Generates the code of all children followed by the code of this node.
     * Returns an empty string if this node was already generated.
     *
     * @param sparse whether sparse code generation is requested
     * @param api    the target generator API
     * @return the generated code fragment
     */
    @Override
    public String codegen(boolean sparse, SpoofCompiler.GeneratorAPI api) {
        if (this.isGenerated()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        // children first, so their variables are defined before use
        for (CNode in : this._inputs) {
            sb.append(in.codegen(sparse, api));
        }
        CNode first = this._inputs.get(0);
        // sparse template only if the first input is a non-literal data node whose name starts with "a"
        boolean lsparse = sparse && first instanceof CNodeData
            && first.getVarname().startsWith("a") && !first.isLiteral();
        String var = this.createVarname();
        String tmp = this._type.getTemplate(lsparse, this._cols, this._inputs, api);
        tmp = tmp.replace("%TMP%", var);
        String varj1 = first.getVarname();
        if (this._type == NaryType.VECT_CONV2DMM) {
            // only conv2dmm consumes two placeholder inputs; avoid touching input 1 otherwise
            String varj2 = this._inputs.get(1).getVarname();
            tmp = this.replaceBinaryPlaceholders(tmp, new String[]{varj1, varj2}, false, api);
        } else {
            tmp = this.replaceUnaryPlaceholders(tmp, varj1, false, api);
        }
        sb.append(tmp);
        this._generated = true;
        return sb.toString();
    }

    @Override
    public String toString() {
        switch (this._type) {
            case VECT_CBIND:
                return "n(cbind)";
            case VECT_MAX_POOL:
                return "n(maxpool)";
            case VECT_AVG_POOL:
                return "n(avgpool)";
            case VECT_IM2COL:
                return "n(im2col)";
            case VECT_CONV2DMM:
                return "n(conv2dmm)";
            default:
                return "m(" + this._type.name().toLowerCase() + ")";
        }
    }

    /**
     * Derives rows, columns and data type of this node from its inputs.
     * <p>
     * The pooling and conv2dmm cases compute P and Q with
     * {@code DnnUtils.getP/getQ(.., 1, 0)}, presumably stride 1 and padding 0.
     */
    @Override
    public void setOutputDims() {
        switch (this._type) {
            case VECT_CBIND: {
                // rows of the first input, columns summed over all inputs
                this._rows = this._inputs.get(0)._rows;
                this._cols = 0L;
                for (CNode in : this._inputs) {
                    this._cols += in._cols;
                }
                this._dataType = Types.DataType.MATRIX;
                break;
            }
            case VECT_MAX_POOL:
            case VECT_AVG_POOL: {
                int C = parseIntInput(this._inputs, 6);
                int H = parseIntInput(this._inputs, 7);
                int W = parseIntInput(this._inputs, 8);
                int R = parseIntInput(this._inputs, 11);
                int S = parseIntInput(this._inputs, 12);
                long P = DnnUtils.getP(H, R, 1L, 0L);
                long Q = DnnUtils.getQ(W, S, 1L, 0L);
                this._rows = this._inputs.get(0)._rows;
                this._cols = (long) C * P * Q;
                this._dataType = Types.DataType.MATRIX;
                break;
            }
            case VECT_IM2COL: {
                this._rows = 1L;
                this._cols = -1L;
                this._dataType = Types.DataType.MATRIX;
                break;
            }
            case VECT_CONV2DMM: {
                // input 0 is an extra leading operand, so literal indices are shifted by one
                int H = parseIntInput(this._inputs, 8);
                int W = parseIntInput(this._inputs, 9);
                int K = parseIntInput(this._inputs, 10);
                int R = parseIntInput(this._inputs, 12);
                int S = parseIntInput(this._inputs, 13);
                long P = DnnUtils.getP(H, R, 1L, 0L);
                long Q = DnnUtils.getQ(W, S, 1L, 0L);
                this._rows = this._inputs.get(0)._rows;
                this._cols = (long) K * P * Q;
                this._dataType = Types.DataType.MATRIX;
                break;
            }
        }
    }

    @Override
    public int hashCode() {
        // lazily computed and cached in _hash
        if (this._hash == 0) {
            this._hash = UtilFunctions.intHashCode(super.hashCode(), this._type.hashCode());
        }
        return this._hash;
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeNary)) {
            return false;
        }
        CNodeNary that = (CNodeNary) o;
        return super.equals(that) && this._type == that._type;
    }

    /**
     * Returns true only for the Java generator API and only if every input is
     * supported as well.
     */
    @Override
    public boolean isSupported(SpoofCompiler.GeneratorAPI api) {
        if (api != SpoofCompiler.GeneratorAPI.JAVA) {
            return false;
        }
        for (CNode in : this._inputs) {
            if (!in.isSupported(api)) {
                return false;
            }
        }
        return true;
    }

    /** Parses the variable name of the input at {@code pos} as an int literal. */
    private static int parseIntInput(List<CNode> inputs, int pos) {
        return Integer.parseInt(inputs.get(pos).getVarname());
    }

    /**
     * Builds the DNN argument suffix {@code "rix, C,P,Q,K,R,S,H,W"} from literal inputs.
     *
     * @param inputs the node inputs
     * @param unary  true if the DNN literals start at index 6; false shifts them by one
     * @return the parameter string to append to a primitive call
     */
    private static String getDnnParameterString(List<CNode> inputs, boolean unary) {
        int off = unary ? 0 : 1;
        int C = parseIntInput(inputs, off + 6);
        int H = parseIntInput(inputs, off + 7);
        int W = parseIntInput(inputs, off + 8);
        int K = parseIntInput(inputs, off + 9);
        int R = parseIntInput(inputs, off + 11);
        int S = parseIntInput(inputs, off + 12);
        int P = (int) DnnUtils.getP(H, R, 1L, 0L);
        int Q = (int) DnnUtils.getQ(W, S, 1L, 0L);
        return "rix, " + StringUtils.join(new int[]{C, P, Q, K, R, S, H, W}, ',');
    }

    /**
     * Replaces the {@code %INjv%}, {@code %INji%}, {@code %INj%}, {@code %POSj%}
     * (j = 1, 2) and {@code %LEN%} placeholders for the first two inputs.
     */
    private String replaceBinaryPlaceholders(String tmp, String[] vars, boolean vectIn, SpoofCompiler.GeneratorAPI api) {
        for (int j = 0; j < 2; ++j) {
            String varj = vars[j];
            CNode input = this._inputs.get(j);
            boolean sideInput = varj.startsWith("b");
            // "b..." inputs are accessed through values(rix) for Java, vals(0) otherwise
            String dense = !sideInput ? varj
                : (api == SpoofCompiler.GeneratorAPI.JAVA ? varj + ".values(rix)" : varj + ".vals(0)");
            String pos;
            if (input instanceof CNodeData && input.getDataType().isMatrix()) {
                if (!sideInput) {
                    pos = varj + "i";
                } else {
                    pos = TemplateUtils.isMatrix(input) && this._type != NaryType.VECT_CONV2DMM
                        ? varj + ".pos(rix)" : "0";
                }
            } else {
                pos = "0";
            }
            tmp = tmp.replace("%IN" + (j + 1) + "v%", varj + "vals");
            tmp = tmp.replace("%IN" + (j + 1) + "i%", varj + "ix");
            tmp = tmp.replace("%IN" + (j + 1) + "%", dense);
            tmp = tmp.replace("%POS" + (j + 1) + "%", pos);
        }
        if (this._inputs.get(0).getDataType().isMatrix()) {
            tmp = tmp.replace("%LEN%", this._inputs.get(0).getVectorLength(api));
        }
        return tmp;
    }

    /** The supported n-ary operation types. */
    public static enum NaryType {
        VECT_CBIND,
        VECT_MAX_POOL,
        VECT_AVG_POOL,
        VECT_IM2COL,
        VECT_CONV2DMM;


        /**
         * Checks whether {@code value} is the name of a constant of this enum.
         *
         * @param value the candidate name
         * @return true if some constant's name equals {@code value}
         */
        public static boolean contains(String value) {
            for (NaryType bt : NaryType.values()) {
                if (bt.name().equals(value)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns the code template for this operation. {@code %TMP%} marks the
         * output variable; other placeholders are resolved by the enclosing node.
         *
         * @param sparseGen whether to emit the sparse variant
         * @param len       the output vector length (used by cbind)
         * @param inputs    the node inputs
         * @param api       the target generator API
         * @return the template string
         * @throws RuntimeException for an unhandled type
         */
        public String getTemplate(boolean sparseGen, long len, ArrayList<CNode> inputs, SpoofCompiler.GeneratorAPI api) {
            switch (this) {
                case VECT_CBIND: {
                    return getCbindTemplate(sparseGen, len, inputs);
                }
                case VECT_MAX_POOL:
                case VECT_AVG_POOL: {
                    String vectName = this == VECT_MAX_POOL ? "Maxpool" : "Avgpool";
                    String paramStr = CNodeNary.getDnnParameterString(inputs, true);
                    return sparseGen
                        ? "    double[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN1i%, %POS1%, alen, len, " + paramStr + ");\n"
                        : "    double[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %POS1%, %LEN%, " + paramStr + ");\n";
                }
                case VECT_IM2COL: {
                    String paramStr = CNodeNary.getDnnParameterString(inputs, true);
                    return sparseGen
                        ? "    double[] %TMP% = LibSpoofPrimitives.vectIm2colWrite(%IN1v%, %IN1i%, %POS1%, alen, len, " + paramStr + ");\n"
                        : "    double[] %TMP% = LibSpoofPrimitives.vectIm2colWrite(%IN1%, %POS1%, %LEN%, " + paramStr + ");\n";
                }
                case VECT_CONV2DMM: {
                    return "    double[] %TMP% = LibSpoofPrimitives.vectConv2dmmWrite(%IN2%, %IN1%, %POS2%, %POS1%, %LEN%, " + CNodeNary.getDnnParameterString(inputs, false) + ");\n";
                }
            }
            throw new RuntimeException("Invalid nary type: " + this.toString());
        }

        /**
         * Builds the cbind template. It allocates the output vector of length
         * {@code len}, then writes each input at its running column offset.
         */
        private static String getCbindTemplate(boolean sparseGen, long len, List<CNode> inputs) {
            StringBuilder sb = new StringBuilder();
            sb.append("    double[] %TMP% = LibSpoofPrimitives.allocVector(" + len + ", true); //nary cbind\n");
            long off = 0;
            for (CNode input : inputs) {
                String varj = input.getVarname();
                if (input.getDataType() != Types.DataType.MATRIX) {
                    // non-matrix inputs occupy a single cell
                    sb.append("    %TMP%[" + off + "] = " + varj + ";\n");
                    off++;
                    continue;
                }
                boolean sparseInput = sparseGen && input instanceof CNodeData && varj.startsWith("a");
                // data inputs carry a row position; all other matrix inputs start at 0
                String pos = !(input instanceof CNodeData) ? "0"
                    : (varj.startsWith("b") ? varj + ".pos(rix)" : varj + "i");
                if (sparseInput) {
                    sb.append("    LibSpoofPrimitives.vectWrite(" + varj + "vals, %TMP%, " + varj + "ix, " + pos + ", " + off + ", " + input._cols + ");\n");
                } else {
                    String src = varj.startsWith("b") ? varj + ".values(rix)" : varj;
                    sb.append("    LibSpoofPrimitives.vectWrite(" + src + ", %TMP%, " + pos + ", " + off + ", " + input._cols + ");\n");
                }
                off += input._cols;
            }
            return sb.toString();
        }

        /**
         * Returns whether this type is a vector primitive; currently true for every type.
         *
         * @return true for all currently defined types
         */
        public boolean isVectorPrimitive() {
            return this == VECT_CBIND || this == VECT_MAX_POOL || this == VECT_AVG_POOL || this == VECT_IM2COL || this == VECT_CONV2DMM;
        }
    }
}

