/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
 * Example of hyper-parameter selection for a {@link LinearRegression} using
 * {@link TrainValidationSplit}.
 *
 * <p>The example works in five steps:
 * <ol>
 *   <li>Load LIBSVM-formatted data from {@code data/mllib/sample_linear_regression_data.txt}.</li>
 *   <li>Split it 90/10 into training and test sets with a fixed seed.</li>
 *   <li>Evaluate a grid of parameter combinations. Each is fitted on 80% of the training
 *       data and scored on the remaining 20% with a {@link RegressionEvaluator}.</li>
 *   <li>Apply the resulting model to the held-out test set.</li>
 *   <li>Print features, labels and predictions.</li>
 * </ol>
 */
public class JavaModelSelectionViaTrainValidationSplitExample {

    /**
     * Runs the example.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("JavaModelSelectionViaTrainValidationSplitExample")
            .getOrCreate();
        // Stop the session even if loading, fitting or showing fails part-way through.
        try {
            Dataset<Row> data = spark.read()
                .format("libsvm")
                .load("data/mllib/sample_linear_regression_data.txt");

            // A fixed seed keeps the train/test split reproducible across runs.
            Dataset<Row>[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345L);
            Dataset<Row> training = splits[0];
            Dataset<Row> test = splits[1];

            LinearRegression lr = new LinearRegression();

            // Cartesian grid over regParam, fitIntercept and elasticNetParam.
            // addGrid(BooleanParam) with no explicit values adds both true and false.
            ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(lr.regParam(), new double[] {0.1, 0.01})
                .addGrid(lr.fitIntercept())
                .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
                .build();

            // Each combination is trained on 80% of `training` and validated on the other 20%.
            // Up to two combinations are evaluated concurrently.
            TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
                .setEstimator(lr)
                .setEvaluator(new RegressionEvaluator())
                .setEstimatorParamMaps(paramGrid)
                .setTrainRatio(0.8)
                .setParallelism(2);

            TrainValidationSplitModel model = trainValidationSplit.fit(training);

            // Score the held-out test rows with the selected model.
            model.transform(test)
                .select("features", "label", "prediction")
                .show();
        } finally {
            spark.stop();
        }
    }
}

