/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelParams;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.servable.api.DataFrame;
import org.apache.flink.ml.servable.api.ModelServable;
import org.apache.flink.ml.servable.api.Row;
import org.apache.flink.ml.servable.types.BasicType;
import org.apache.flink.ml.servable.types.DataTypes;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ServableReadWriteUtils;
import org.apache.flink.util.Preconditions;

/**
 * A {@link ModelServable} that applies logistic regression model data to a {@link DataFrame} row
 * by row.
 *
 * <p>For each row it reads the vector in the features column, computes {@code d = coefficient .
 * features} and {@code p = 1 / (1 + exp(-d))}, and appends two columns: the predicted label
 * ({@code 1.0} if {@code d >= 0}, else {@code 0.0}) in the prediction column, and the vector
 * {@code [1 - p, p]} in the raw-prediction column.
 */
public class LogisticRegressionModelServable
        implements ModelServable<LogisticRegressionModelServable>,
                LogisticRegressionModelParams<LogisticRegressionModelServable> {

    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Model coefficients. This is {@code null} until supplied through the package-private
     * constructor, {@link #setModelData(InputStream...)} or {@link #load(String)}.
     */
    private LogisticRegressionModelData modelData;

    /** Creates a servable whose parameters are initialized to their default values. */
    public LogisticRegressionModelServable() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Creates a servable with default parameters that uses the given model data.
     *
     * @param modelData the model data used for prediction
     */
    LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
        this();
        this.modelData = modelData;
    }

    /**
     * Predicts every row of {@code input} and adds the prediction and raw-prediction columns to it.
     *
     * @param input the frame to score; it must contain the configured features column of vectors
     * @return {@code input}, after the two result columns have been added to it
     * @throws IllegalStateException if no model data has been set on this servable
     */
    @Override
    public DataFrame transform(DataFrame input) {
        // Fail with a descriptive message rather than an NPE deep inside the dot product.
        Preconditions.checkState(
                modelData != null,
                "Model data has not been set. Call setModelData(...) or load(...) first.");

        List<Double> predictionResults = new ArrayList<>();
        List<DenseVector> rawPredictionResults = new ArrayList<>();

        int featuresColIndex = input.getIndex(getFeaturesCol());
        for (Row row : input.collect()) {
            Vector features = (Vector) row.get(featuresColIndex);
            Tuple2<Double, DenseVector> dataPoint = transform(features);
            predictionResults.add(dataPoint.f0);
            rawPredictionResults.add(dataPoint.f1);
        }

        input.addColumn(getPredictionCol(), DataTypes.DOUBLE, predictionResults);
        input.addColumn(
                getRawPredictionCol(), DataTypes.VECTOR(BasicType.DOUBLE), rawPredictionResults);
        return input;
    }

    /**
     * Decodes the model data from a single input stream and installs it on this servable.
     *
     * @param modelDataInputs exactly one stream holding the encoded model data
     * @return this servable, for chaining
     * @throws IllegalArgumentException if the number of streams is not exactly one
     * @throws IOException if decoding the stream fails
     */
    @Override
    public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
            throws IOException {
        Preconditions.checkArgument(
                modelDataInputs.length == 1,
                "Expected exactly one model data input stream, but got %s.",
                modelDataInputs.length);
        modelData = LogisticRegressionModelData.decode(modelDataInputs[0]);
        return this;
    }

    /**
     * Builds a servable from the parameters and model data stored under {@code path}.
     *
     * <p>Parameters are read via {@code ServableReadWriteUtils.loadServableParam}, and model data
     * via {@code ServableReadWriteUtils.loadModelData}. The model-data stream is always closed.
     *
     * @param path the directory the servable was saved to
     * @return the loaded servable, with model data set
     * @throws IOException if reading the parameters or model data fails
     */
    public static LogisticRegressionModelServable load(String path) throws IOException {
        LogisticRegressionModelServable servable =
                ServableReadWriteUtils.loadServableParam(
                        path, LogisticRegressionModelServable.class);
        try (InputStream modelDataInput = ServableReadWriteUtils.loadModelData(path)) {
            return servable.setModelData(modelDataInput);
        }
    }

    /**
     * Predicts a single feature vector.
     *
     * <p>The result's {@code f0} is {@code 1.0} when {@code coefficient . feature >= 0}, and
     * {@code 0.0} otherwise. The result's {@code f1} is {@code [1 - p, p]}, where {@code p =
     * 1 - 1 / (1 + exp(dot))}; this equals the logistic sigmoid of the dot product.
     *
     * @param feature the feature vector; its size presumably must match the coefficient vector
     *     (NOTE(review): enforced by BLAS.dot, not here — confirm)
     * @return the predicted label and the two-class probability vector
     */
    protected Tuple2<Double, DenseVector> transform(Vector feature) {
        double dotValue = BLAS.dot(feature, modelData.coefficient);
        double prob = 1.0 - 1.0 / (1.0 + Math.exp(dotValue));
        return Tuple2.of(dotValue >= 0.0 ? 1.0 : 0.0, Vectors.dense(1.0 - prob, prob));
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }
}

