4
Spark MLのRandomForestClassifierによって生成されたモデル内の個々のツリーにアクセスする方法は? ScalaバージョンのRandomForestClassifierを使用しています。RandomForestClassifier(spark.ml-version)によって作成されたモデル内の個々のツリーにアクセスする方法は?
Spark MLのRandomForestClassifierによって生成されたモデル内の個々のツリーにアクセスする方法は? ScalaバージョンのRandomForestClassifierを使用しています。RandomForestClassifier(spark.ml-version)によって作成されたモデル内の個々のツリーにアクセスする方法は?
は、実際にはtrees
の属性があります。
trees.head.transform(data).show(3)
// +-----+--------------------+-------------+-----------+----------+
// |label| features|rawPrediction|probability|prediction|
// +-----+--------------------+-------------+-----------+----------+
// | 0.0|(692,[127,128,129...| [33.0,0.0]| [1.0,0.0]| 0.0|
// | 1.0|(692,[158,159,160...| [0.0,59.0]| [0.0,1.0]| 1.0|
// | 1.0|(692,[124,125,126...| [0.0,59.0]| [0.0,1.0]| 1.0|
// +-----+--------------------+-------------+-----------+----------+
// only showing top 3 rows
注:
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{
RandomForestClassificationModel, RandomForestClassifier,
DecisionTreeClassificationModel
}
val meta = NominalAttribute
.defaultAttr
.withName("label")
.withValues("0.0", "1.0")
.toMetadata
val data = sqlContext.read.format("libsvm")
.load("data/mllib/sample_libsvm_data.txt")
.withColumn("label", $"label".as("label", meta))
val rf: RandomForestClassifier = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
case t: DecisionTreeClassificationModel => t
}
を、あなたは唯一の問題は、右のように、我々は実際にこれらの使用できるタイプを取得することで見ることができるように
パイプラインで作業する場合は、個々のツリーも抽出できます。
import org.apache.spark.ml.Pipeline
val model = new Pipeline().setStages(Array(rf)).fit(data)
// There is only one stage and know its type
// but lets be thorough
val rfModelOption = model.stages.headOption match {
case Some(m: RandomForestClassificationModel) => Some(m)
case _ => None
}
val trees = rfModelOption.map {
_.trees // ... as before
}.getOrElse(Array())
こんにちはzero323、ありがとうございます。私はフォローアップの質問があります。私は木ノードから予測確率が高い(たとえば0.3以上)ルールを抽出したいと考えています。 'spark.ml'では、オブジェクトの不純物値はツリーの内部ノードでプライベートであり、メソッドtoOldとfromOldもそうです。私は何かを引き出すことができるように、私はそれらがプライベートなのでアクセスできない詳細が必要です。同様に、ノードの分割は、そのカテゴリーおよび機能閾値についての情報を提供しない。 'spark.ml'の高確率ノードからルールを抽出する方法はありますか? –
私は些細な解決策に気づいていません。あなたはそれを別として尋ねるべきです - 多分誰かがすでに解決策を持っているかもしれません。もしあなたがリンクを私にpingしてください。 – zero323
ありがとうございましたzero323。 Spark ML RandomForestClassifierモデル(Scalaバージョン)からルールを抽出する方法を質問しました。私は答えが得られたら更新していきます。 –