2016-11-27 6 views
0

私は決定木を使って分類するために次のコードを持っています。私は、テストデータセットの予測をJava配列に取得し、それらを出力する必要があります。誰かがこのコードを拡張する手助けをすることができますか?私は予測されたラベルと実際のラベルの2D配列を持って、予測されたラベルを印刷する必要があります。Apache Spark決定木の予測

public class DecisionTreeClass { 
    public static void main(String args[]){ 
     SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeClass").setMaster("local[2]"); 
     JavaSparkContext jsc = new JavaSparkContext(sparkConf); 


     // Load and parse the data file. 
     String datapath = "/home/thamali/Desktop/tlib.txt"; 
     JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();//A training example used in supervised learning is called a “labeled point” in MLlib. 
     // Split the data into training and test sets (30% held out for testing) 
     JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); 
     JavaRDD<LabeledPoint> trainingData = splits[0]; 
     JavaRDD<LabeledPoint> testData = splits[1]; 

     // Set parameters. 
     // Empty categoricalFeaturesInfo indicates all features are continuous. 
     Integer numClasses = 12; 
     Map<Integer, Integer> categoricalFeaturesInfo = new HashMap(); 
     String impurity = "gini"; 
     Integer maxDepth = 5; 
     Integer maxBins = 32; 

     // Train a DecisionTree model for classification. 
     final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, 
       categoricalFeaturesInfo, impurity, maxDepth, maxBins); 

     // Evaluate model on test instances and compute test error 
     JavaPairRDD<Double, Double> predictionAndLabel = 
       testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { 
        @Override 
        public Tuple2<Double, Double> call(LabeledPoint p) { 
         return new Tuple2(model.predict(p.features()), p.label()); 
        } 
       }); 

     Double testErr = 
       1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { 
        @Override 
        public Boolean call(Tuple2<Double, Double> pl) { 
         return !pl._1().equals(pl._2()); 
        } 
       }).count()/testData.count(); 

     System.out.println("Test Error: " + testErr); 
     System.out.println("Learned classification tree model:\n" + model.toDebugString()); 


    } 

} 

答えて

1

あなたは基本的に予測変数とラベル変数を正確に持っています。

JavaRDD<double[]> valuesAndPreds = testData.map(point -> new double[]{model.predict(point.features()), point.label()}); 

と2Dダブル配列のリストについては、その参照にcollectを実行します:あなたは本当に2Dダブルアレイのリストを必要に応じて、あなたが使用する方法を変更することができます。

List<double[]> values = valuesAndPreds.collect(); 

私はここでドキュメントを見てみましょう:https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html。データを変更して、MulticlassMetricsなどのクラスを使用して、モデルの追加の統計的パフォーマンス測定値を取得することもできます。これには、mapToPair関数をマップ関数に変更し、ジェネリックをオブジェクトに変更する必要があります。だから、のようなもの:

JavaRDD<Tuple2<Object, Object>> valuesAndPreds = testData().map(point -> new Tuple2<>(model.predict(point.features()), point.label())); 

が次に実行されている:

MulticlassMetrics multiclassMetrics = new MulticlassMetrics(JavaRDD.toRDD(valuesAndPreds)); 

このようなもののすべてが非常によくスパークのMLLibのドキュメントに記載されています。また、結果を印刷する必要があると述べました。これが宿題であれば、リストからそれを行う方法を学ぶのは良い練習になるので、私はあなたにその部分を理解させます。

編集:

も、あなたは、Java 7を使用していることに気づいた、と私は持っていることは、あなたがどうなる2Dダブル配列、にオンにする方法で、あなたの主な質問に答えるために、Java 8からである。

JavaRDD<double[]> valuesAndPreds = testData.map(new org.apache.spark.api.java.function.Function<LabeledPoint, double[]>() { 
       @Override 
       public double[] call(LabeledPoint point) { 
        return new double[]{model.predict(point.features()), point.label()}; 
       } 
      }); 

次に、collectを実行して、2つの2倍のリストを取得します。また、印刷部分にヒントを与えるには、java.util.Arrays toString実装を見てください。