/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Example showing how an Estimator ({@link LogisticRegression}), the Transformer it
 * produces ({@link LogisticRegressionModel}) and {@link ParamMap}s fit together.
 *
 * <p>The program fits a logistic regression twice on a tiny in-memory training set:
 * once with parameters configured through setters on the estimator, and once with
 * parameters supplied through a {@link ParamMap} at fit time. The second model is
 * then applied to a small test set and the selected result columns are printed to
 * standard output.
 */
public class JavaEstimatorTransformerParamExample {
    /**
     * Runs the example end to end.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("JavaEstimatorTransformerParamExample")
            .getOrCreate();
        try {
            // Training data as (label, features) rows.
            List<Row> dataTraining = Arrays.asList(
                RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)),
                RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)),
                RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)),
                RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5)));
            StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())});
            Dataset<Row> training = spark.createDataFrame(dataTraining, schema);

            // The estimator; print its parameter documentation before configuring it.
            LogisticRegression lr = new LogisticRegression();
            System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");

            // First fit: parameters are set directly on the estimator via setters.
            lr.setMaxIter(10).setRegParam(0.01);
            LogisticRegressionModel model1 = lr.fit(training);
            System.out.println("Model 1 was fit using parameters: "
                + model1.parent().extractParamMap());

            // Second fit: parameters are passed in a ParamMap at fit time.
            ParamMap paramMap = new ParamMap()
                .put(lr.maxIter().w(20))     // a single ParamPair
                .put(lr.maxIter(), 30)       // same Param put again; this later put replaces the 20
                .put(lr.regParam().w(0.1), lr.threshold().w(0.55));  // several ParamPairs at once

            // A second map that renames the probability output column.
            ParamMap paramMap2 = new ParamMap()
                .put(lr.probabilityCol().w("myProbability"));

            // $plus$plus is the JVM name of Scala's ParamMap "++" operator (merge two maps).
            ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

            LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
            System.out.println("Model 2 was fit using parameters: "
                + model2.parent().extractParamMap());

            // Test data, using the same (label, features) schema as the training data.
            List<Row> dataTest = Arrays.asList(
                RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
                RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)),
                RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5)));
            Dataset<Row> test = spark.createDataFrame(dataTest, schema);

            // The probability column is selected as "myProbability" because
            // probabilityCol was set to that name in paramMap2 above.
            Dataset<Row> results = model2.transform(test);
            Dataset<Row> rows = results.select("features", "label", "myProbability", "prediction");
            for (Row r : rows.collectAsList()) {
                System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
                    + ", prediction=" + r.get(3));
            }
        } finally {
            // Stop the session even when fitting or transforming throws.
            spark.stop();
        }
    }
}

