2013-07-29 16 views
9

Pipeline()の中に包まれたscikit-learnクラシファイアでpartial_fit()をどのように呼び出しますか?Scikitパイプラインでpartial_fitを使用

私のようなSGDClassifierを使用してインクリメンタルトレーニング可能なテキスト分類器を構築しようとしている:

from sklearn.linear_model import SGDClassifier 
from sklearn.pipeline import Pipeline 
from sklearn.feature_extraction.text import HashingVectorizer 
from sklearn.feature_extraction.text import TfidfTransformer 
from sklearn.multiclass import OneVsRestClassifier 

classifier = Pipeline([ 
    ('vectorizer', HashingVectorizer(ngram_range=(1,4), non_negative=True)), 
    ('tfidf', TfidfTransformer()), 
    ('clf', OneVsRestClassifier(SGDClassifier())), 
]) 

が、私はAttributeErrorclassifier.partial_fit(x,y)を呼び出そうとします。

fit()をサポートしていますので、なぜ私は表示されませんpartial_fit()は利用できません。パイプラインをイントロスペクトし、データトランスフォーマを呼び出してから、自分のクラシファイアに直接partial_fit()を呼び出すことは可能でしょうか?

+0

で素敵果たし変換ステップを可能にしますが、最終的ソリューを思い付くでしたこれについては? – GreenGodot

答えて

5

パイプラインはpartial_fitを使用しないため、公開しません。コア外の計算に専用のパイプライン方式が必要なのかもしれませんが、これも以前のモデルの機能に依存しています。

特にこのケースでは、パイプラインの各ステージに合ったデータをいくつかパスし、次にステートレスな最初のステージを除いて、次のデータセットに合わせてデータセットを変換することをお勧めします。したがって、データのパラメータに適合しません。

あなたのニーズに合わせて独自のラッパーコードをロールするほうが簡単です。

+1

自分自身を転がす方法をお勧めできますか?パイプラインのtransform()メソッドを使ってみましたが、分類子を抽出し、変換されたデータをpartial_fit()に渡しましたが、tdfベクトルが定義されていないというエラーが出ます。 – Cerin

+3

[Pipelineクラスのソースコード](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/pipeline.py#L26)と[この例](http:// scikit-learn.org/dev/auto_examples/applications/plot_out_of_core_classification.html)。次に、[テキストの特徴抽出とハッシュトリック]のドキュメントを読んでください(http://scikit-learn.org/dev/modules/feature_extraction.html#vectorizing-a-large-text-corpus-with-the-hashing-trick )を使用して、ステートフルなフィーチャ抽出に関する問題を完全に理解できるようにします。実装は、あなたが解決しようとしている問題によって異なります。 – ogrisel

+0

特に、ステートフルトランスを 'TfidfTransformer'として使用する場合は、データをいくつか渡す必要があります。 – ogrisel

6

私がやっていることは - ここでmapperとclfはPipeline objの2ステップです。

def partial_pipe_fit(pipeline_obj, df): 
    X = pipeline_obj.named_steps['mapper'].fit_transform(df) 
    Y = df['class'] 
    pipeline_obj.named_steps['clf'].partial_fit(X,Y) 

あなたはおそらく、あなたのクラシファイアを更新/調整保つよう、パフォーマンスを追跡したい - それはそうより具体的に二点

ある -

を以下のように、元のパイプライン(複数可)を構築しました
to_vect = Pipeline([('vect', CountVectorizer(min_df=2, max_df=.9, ngram_range=(1, 1), max_features = 100)), 
          ('tfidf', TfidfTransformer())]) 
full_mapper = DataFrameMapper([ 
      ('norm_text', to_vect), 
      ('norm_fname', to_vect), ]) 

full_pipe = Pipeline([('mapper', full_mapper), ('clf', SGDClassifier(n_iter=15, warm_start=True, 
                   n_jobs=-1, random_state=self.random_state))]) 

グーグルDataFrameMapperそれについての詳細を学ぶために - しかし、ここではそれだけでパンダ

関連する問題