2017-07-31 3 views
0

私はスパイアmllibを使用してナイーブベイのクラシファイアモデルをトレーニングしています。ここでは、文字列フィーチャのインデックスを作成するパイプラインを作成し、次元削減のためにPCAを正規化して適用します。私はパイプラインを実行すると、私はPCAのコンポーネントvector.Onグーグルで私は正ベクトルを得るためにNMF(非負行列の因子分解)を適用する必要があり、私はALSがメソッドでNMFを実装することがわかったことがわかった.setnonnegative(true)しかし、私はPCAの後に私のパイプラインにALSを統合する方法を知らない。どんな助けもありがたい。ありがとう。ここ非負行列分解を実装するためのスパークパイプラインにALSを統合する方法は?

はコード

import org.apache.spark.SparkConf; 
import org.apache.spark.SparkContext; 
import org.apache.spark.api.java.JavaSparkContext; 
import org.apache.spark.ml.Pipeline; 
import org.apache.spark.ml.PipelineModel; 
import org.apache.spark.ml.PipelineStage; 
import org.apache.spark.ml.classification.NaiveBayes; 
import org.apache.spark.ml.feature.IndexToString; 
import org.apache.spark.ml.feature.Normalizer; 
import org.apache.spark.ml.feature.PCA; 
import org.apache.spark.ml.feature.StringIndexer; 
import org.apache.spark.ml.feature.StringIndexerModel; 
import org.apache.spark.ml.feature.VectorAssembler; 
import org.apache.spark.ml.recommendation.ALS; 
import org.apache.spark.sql.DataFrame; 
import org.apache.spark.sql.SQLContext; 

public class NBTrainPCA { 
    public static void main(String args[]){ 
     try{ 
      SparkConf conf = new SparkConf().setAppName("NBTrain"); 
      SparkContext scc = new SparkContext(conf); 
      scc.setLogLevel("ERROR"); 
      JavaSparkContext sc = new JavaSparkContext(scc); 
      SQLContext sqlc = new SQLContext(scc); 
      DataFrame traindata = sqlc.read().format("parquet").load(args[0]).filter("user_email!='NA' and user_email!='00' and user_email!='0ed709b5bec77b6bff96ea5b5e334a8e5' and user_email is not null and ip is not null and region_code is not null and city is not null and browser_name is not null and os_name is not null"); 
      traindata.registerTempTable("master"); 
      //DataFrame data = sqlc.sql("select user_email,user_device,ip,country_code,region_code,city,zip_code,time_zone,browser_name,browser_manf,os_name,os_manf from master where user_email!='NA' and user_email is not null and user_device is not null and ip is not null and country_code is not null and region_code is not null and city is not null and browser_name is not null and browser_manf is not null and zip_code is not null and time_zone is not null and os_name is not null and os_manf is not null"); 
      StringIndexerModel emailIndexer = new StringIndexer() 
       .setInputCol("user_email") 
       .setOutputCol("email_index") 
       .setHandleInvalid("skip") 
       .fit(traindata); 
      StringIndexer udevIndexer = new StringIndexer() 
       .setInputCol("user_device") 
       .setOutputCol("udev_index") 
       .setHandleInvalid("skip"); 
      StringIndexer ipIndexer = new StringIndexer() 
       .setInputCol("ip") 
       .setOutputCol("ip_index") 
       .setHandleInvalid("skip"); 
      StringIndexer ccodeIndexer = new StringIndexer() 
       .setInputCol("country_code") 
       .setOutputCol("ccode_index") 
       .setHandleInvalid("skip"); 
      StringIndexer rcodeIndexer = new StringIndexer() 
       .setInputCol("region_code") 
       .setOutputCol("rcode_index") 
       .setHandleInvalid("skip"); 
      StringIndexer cyIndexer = new StringIndexer() 
       .setInputCol("city") 
       .setOutputCol("cy_index") 
       .setHandleInvalid("skip"); 
      StringIndexer zpIndexer = new StringIndexer() 
       .setInputCol("zip_code") 
       .setOutputCol("zp_index") 
       .setHandleInvalid("skip"); 
      StringIndexer tzIndexer = new StringIndexer() 
       .setInputCol("time_zone") 
       .setOutputCol("tz_index") 
       .setHandleInvalid("skip"); 
      StringIndexer bnIndexer = new StringIndexer() 
       .setInputCol("browser_name") 
       .setOutputCol("bn_index") 
       .setHandleInvalid("skip"); 
      StringIndexer bmIndexer = new StringIndexer() 
       .setInputCol("browser_manf") 
       .setOutputCol("bm_index") 
       .setHandleInvalid("skip"); 
      StringIndexer bvIndexer = new StringIndexer() 
       .setInputCol("browser_version") 
       .setOutputCol("bv_index") 
       .setHandleInvalid("skip"); 
      StringIndexer onIndexer = new StringIndexer() 
       .setInputCol("os_name") 
       .setOutputCol("on_index") 
       .setHandleInvalid("skip"); 
      StringIndexer omIndexer = new StringIndexer() 
       .setInputCol("os_manf") 
       .setOutputCol("om_index") 
       .setHandleInvalid("skip"); 
      VectorAssembler assembler = new VectorAssembler() 
       .setInputCols(new String[]{ "udev_index","ip_index","ccode_index","rcode_index","cy_index","zp_index","tz_index","bn_index","bm_index","bv_index","on_index","om_index"}) 
       .setOutputCol("ffeatures"); 
      Normalizer normalizer = new Normalizer() 
       .setInputCol("ffeatures") 
       .setOutputCol("sfeatures") 
       .setP(1.0); 
      PCA pca = new PCA() 
       .setInputCol("sfeatures") 
       .setOutputCol("pcafeatures") 
       .setK(5); 
      NaiveBayes nbcl = new NaiveBayes() 
      .setFeaturesCol("pcafeatures") 
      .setLabelCol("email_index") 
      .setSmoothing(1.0); 
      IndexToString is = new IndexToString() 
      .setInputCol("prediction") 
      .setOutputCol("op") 
      .setLabels(emailIndexer.labels()); 
      Pipeline pipeline = new Pipeline() 
       .setStages(new PipelineStage[] {emailIndexer,udevIndexer,ipIndexer,ccodeIndexer,rcodeIndexer,cyIndexer,zpIndexer,tzIndexer,bnIndexer,bmIndexer,bvIndexer,onIndexer,omIndexer,assembler,normalizer,pca,nbcl,is}); 
      PipelineModel model = pipeline.fit(traindata); 
      //DataFrame chidata = model.transform(data); 
      //chidata.write().format("com.databricks.spark.csv").save(args[1]); 
      model.write().overwrite().save(args[1]); 
      sc.close(); 
      } 
      catch(Exception e){ 

      } 
    } 
} 

答えて

0

は、あなたがそれをやっているのより良い感覚を得ることができますので、私はPCAについて少し読むことをお勧めします。ここではいくつかのリンク:あなただけの他の後に一つのことをプラグインするようにALSの統合に

https://stats.stackexchange.com/questions/26352/interpreting-positive-and-negative-signs-of-the-elements-of-pca-eigenvectors

https://stats.stackexchange.com/questions/2691/making-sense-of-principal-component-analysis-eigenvectors-eigenvalues

あなたのパイプラインには思えます。それぞれが何をして、何をしているのかを理解することをお勧めします。ALSとPCAは全く異なるものです。 ALSは、誤差最小化のためにAlSを使用して行列分解を行っており、データに変換を適用するための主成分を見つけていないか、次元削減を行っています。

BTW:PCAコンポーネントのベクトルに負の値が表示されることはありません。これは上記のリンクで確認できます。データに線形変換を適用しています。新しいベクトルは今や変換の結果です。 私はそれが助けてくれることを願っています。

+0

PCAコンポーネントベクトルで負の値を取得する際に問題が発生しました。純粋なベイは機能セットで負の値を占めていません。それは正確な問題です。 –

+0

このリンクを参照してくださいhttps://stackoverflow.com/questions/36491852/using-pca-before-bayes-classificition/36491982 –

+0

コメントを読む:「スパークで実装されたNMFは、元の行列を因数分解する際に直交性を考慮しませんあなたのアプリケーションではうまく機能しないかもしれません。 ALSマトリックス因子分解は、PCAの近くに何もしていない。 –

関連する問題