/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

public class LibMatrixAggUnarySpecialization {
    protected static final Log LOG = LogFactory.getLog(LibMatrixAggUnarySpecialization.class.getName());

    /**
     * Aggregates {@code mb} cell by cell into {@code result} using the unary aggregate operator {@code op}.
     * <p>
     * Sparse-safe operators take a path that, for sparse inputs, visits only the stored cells; all other
     * operators visit every cell of {@code mb} (including zeros).
     *
     * @param mb        input block to aggregate
     * @param op        unary aggregate operator; {@code op.indexFn} maps each input cell position to the
     *                  output cell position, {@code op.aggOp} performs the incremental update
     * @param result    output block, updated in place. NOTE(review): assumed to be pre-sized by the caller,
     *                  including room for any correction/count cells the aggregate needs — confirm with callers
     * @param blen      not used by this implementation (kept for interface compatibility)
     * @param indexesIn not used by this implementation (kept for interface compatibility)
     * @throws DMLRuntimeException if {@code op.aggOp} uses a correction location this class does not support
     */
    public static void aggregateUnary(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, MatrixIndexes indexesIn) {
        if (op.sparseSafe) {
            sparseAggregateUnaryHelp(mb, op, result, blen, indexesIn);
        } else {
            denseAggregateUnaryHelp(mb, op, result, blen, indexesIn);
        }
    }

    /**
     * Sparse-safe aggregation: for a sparse input only the stored entries are visited; for a dense input
     * every cell is visited. An input with no allocated sparse/dense storage contributes nothing beyond
     * the result initialization.
     */
    private static void sparseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, MatrixIndexes indexesIn) {
        initializeResult(op, result);
        MatrixValue.CellIndex cell = new MatrixValue.CellIndex(-1, -1);
        KahanObject buffer = new KahanObject(0.0, 0.0);

        if (mb.sparse && mb.sparseBlock != null) {
            SparseBlock a = mb.sparseBlock;
            // Only rows that exist in both the logical block and the sparse storage are scanned.
            for (int r = 0; r < Math.min(mb.rlen, a.numRows()); ++r) {
                if (a.isEmpty(r)) {
                    continue;
                }
                int apos = a.pos(r);
                int alen = a.size(r);
                int[] aix = a.indexes(r);
                double[] aval = a.values(r);
                for (int k = apos; k < apos + alen; ++k) {
                    aggregateCell(op, result, cell, r, aix[k], aval[k], buffer);
                }
            }
        } else if (!mb.sparse && mb.denseBlock != null) {
            DenseBlock a = mb.getDenseBlock();
            for (int i = 0; i < mb.rlen; ++i) {
                for (int j = 0; j < mb.clen; ++j) {
                    aggregateCell(op, result, cell, i, j, a.get(i, j), buffer);
                }
            }
        }
    }

    /**
     * Non-sparse-safe aggregation: visits every cell of {@code mb} (zeros included) via
     * {@link MatrixBlock#quickGetValue(int, int)}, regardless of the input's storage format.
     */
    private static void denseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, MatrixIndexes indexesIn) {
        initializeResult(op, result);
        MatrixValue.CellIndex cell = new MatrixValue.CellIndex(-1, -1);
        KahanObject buffer = new KahanObject(0.0, 0.0);
        for (int i = 0; i < mb.rlen; ++i) {
            for (int j = 0; j < mb.clen; ++j) {
                aggregateCell(op, result, cell, i, j, mb.quickGetValue(i, j), buffer);
            }
        }
    }

    /**
     * Fills {@code result} with the aggregate's initial value when it is non-zero.
     * NOTE(review): when the initial value is 0 the result is left untouched, so it presumably arrives
     * zero-initialized from the caller — verify.
     */
    private static void initializeResult(AggregateUnaryOperator op, MatrixBlock result) {
        if (op.aggOp.initialValue != 0.0) {
            result.reset(result.rlen, result.clen, op.aggOp.initialValue);
        }
    }

    /**
     * Maps input cell ({@code row}, {@code col}) to its output position with {@code op.indexFn} and folds
     * {@code value} into {@code result} at that position. {@code cell} is reused scratch storage for the
     * index mapping so no object is allocated per cell.
     */
    private static void aggregateCell(AggregateUnaryOperator op, MatrixBlock result, MatrixValue.CellIndex cell,
            int row, int col, double value, KahanObject buffer) {
        cell.set(row, col);
        op.indexFn.execute(cell, cell);
        incrementalAggregateUnaryHelp(op.aggOp, result, cell.row, cell.column, value, buffer);
    }

    /**
     * Folds {@code newvalue} into the aggregate stored at ({@code row}, {@code column}) of {@code result}.
     * <p>
     * Without a correction, the stored value is simply combined with {@code newvalue} by
     * {@code aggOp.increOp.fn}. With a correction, {@code buffer} is loaded with the stored sum and
     * correction term, updated, and written back:
     * <ul>
     *   <li>LASTROW / LASTCOLUMN: the correction lives one row below / one column right of the value.</li>
     *   <li>LASTTWOROWS / LASTTWOCOLUMNS: a running count lives one row/column away and the correction
     *       two rows/columns away; the count is incremented by one per call
     *       (presumably for mean-style aggregates).</li>
     * </ul>
     *
     * @param buffer scratch object reused across calls; its contents are overwritten from {@code result}
     *               before every update
     * @throws DMLRuntimeException if a correction exists but its location is none of the four above
     */
    private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, double newvalue, KahanObject buffer) {
        if (!aggOp.existsCorrection()) {
            double aggregated = aggOp.increOp.fn.execute(result.quickGetValue(row, column), newvalue);
            result.quickSetValue(row, column, aggregated);
            return;
        }

        Types.CorrectionLocationType location = aggOp.correction;
        if (location == Types.CorrectionLocationType.LASTROW || location == Types.CorrectionLocationType.LASTCOLUMN) {
            int corRow = (location == Types.CorrectionLocationType.LASTROW) ? row + 1 : row;
            int corCol = (location == Types.CorrectionLocationType.LASTCOLUMN) ? column + 1 : column;

            buffer._sum = result.quickGetValue(row, column);
            buffer._correction = result.quickGetValue(corRow, corCol);
            buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue);
            result.quickSetValue(row, column, buffer._sum);
            result.quickSetValue(corRow, corCol, buffer._correction);
        } else if (location == Types.CorrectionLocationType.LASTTWOROWS || location == Types.CorrectionLocationType.LASTTWOCOLUMNS) {
            boolean byRows = (location == Types.CorrectionLocationType.LASTTWOROWS);
            int countRow = byRows ? row + 1 : row;
            int countCol = byRows ? column : column + 1;
            int corRow = byRows ? row + 2 : row;
            int corCol = byRows ? column : column + 2;

            buffer._sum = result.quickGetValue(row, column);
            buffer._correction = result.quickGetValue(corRow, corCol);
            double count = result.quickGetValue(countRow, countCol) + 1.0;
            buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count);
            result.quickSetValue(row, column, buffer._sum);
            result.quickSetValue(corRow, corCol, buffer._correction);
            result.quickSetValue(countRow, countCol, count);
        } else {
            // Covers NONE and any correction layout this helper does not implement.
            throw new DMLRuntimeException("unrecognized correctionLocation: " + location);
        }
    }
}

