2016-03-30 7 views

答えて

6

は、実際にはtreesの属性があります。

trees.head.transform(data).show(3) 
// +-----+--------------------+-------------+-----------+----------+ 
// |label|   features|rawPrediction|probability|prediction| 
// +-----+--------------------+-------------+-----------+----------+ 
// | 0.0|(692,[127,128,129...| [33.0,0.0]| [1.0,0.0]|  0.0| 
// | 1.0|(692,[158,159,160...| [0.0,59.0]| [0.0,1.0]|  1.0| 
// | 1.0|(692,[124,125,126...| [0.0,59.0]| [0.0,1.0]|  1.0| 
// +-----+--------------------+-------------+-----------+----------+ 
// only showing top 3 rows 

import org.apache.spark.ml.attribute.NominalAttribute 
import org.apache.spark.ml.classification.{ 
    RandomForestClassificationModel, RandomForestClassifier, 
    DecisionTreeClassificationModel 
} 

val meta = NominalAttribute 
    .defaultAttr 
    .withName("label") 
    .withValues("0.0", "1.0") 
    .toMetadata 

val data = sqlContext.read.format("libsvm") 
    .load("data/mllib/sample_libsvm_data.txt") 
    .withColumn("label", $"label".as("label", meta)) 

val rf: RandomForestClassifier = new RandomForestClassifier() 
    .setLabelCol("label") 
    .setFeaturesCol("features") 

val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect { 
    case t: DecisionTreeClassificationModel => t 
} 

を、あなたは唯一の問題は、右のように、我々は実際にこれらの使用できるタイプを取得することで見ることができるように

パイプラインで作業する場合は、個々のツリーも抽出できます。

import org.apache.spark.ml.Pipeline 

val model = new Pipeline().setStages(Array(rf)).fit(data) 

// There is only one stage and know its type 
// but lets be thorough 
val rfModelOption = model.stages.headOption match { 
    case Some(m: RandomForestClassificationModel) => Some(m) 
    case _ => None 
} 

val trees = rfModelOption.map { 
    _.trees // ... as before 
}.getOrElse(Array()) 
+0

こんにちはzero323、ありがとうございます。私はフォローアップの質問があります。私は木ノードから予測確率が高い(たとえば0.3以上)ルールを抽出したいと考えています。 'spark.ml'では、オブジェクトの不純物値はツリーの内部ノードでプライベートであり、メソッドtoOldとfromOldもそうです。私は何かを引き出すことができるように、私はそれらがプライベートなのでアクセスできない詳細が必要です。同様に、ノードの分割は、そのカテゴリーおよび機能閾値についての情報を提供しない。 'spark.ml'の高確率ノードからルールを抽出する方法はありますか? –

+0

私は些細な解決策に気づいていません。あなたはそれを別として尋ねるべきです - 多分誰かがすでに解決策を持っているかもしれません。もしあなたがリンクを私にpingしてください。 – zero323

+0

ありがとうございましたzero323。 Spark ML RandomForestClassifierモデル(Scalaバージョン)からルールを抽出する方法を質問しました。私は答えが得られたら更新していきます。 –

関連する問題