2017-10-01 11 views
2

scikit-learnのNMF(別名NNMF)を使用して自然言語データのトピック抽出を行っています。私はクラスター(別名コンポーネント)の数を最適化しようとしています。これを行うには、再構成誤差を計算する必要があります。しかし、scikit-learnを使用すると、私はトレーニングセットでこのメトリックを計算する方法しか見ることができません。しかし、私はテストセットにこれらのメトリックを取得することに興味があります。助言がありますか?scikit-learnのNMF(別名NNMF)のテストセットの再構築エラー。

答えて

2

外部データに対するSklearnのメカニズムをエミュレートするのは簡単です。

このエラーメトリックは、_beta_divergence(X, W, H, self.beta_loss, square_root=True)を使用してhereと計算されます。

W, Hを取得する方法については、API-docsで概要を説明しています。

sklearn >= 0.19(ここではこれが導入されている)を想定して、使用法を単純にコピーできます。ここで

いっぱいデモです:

from sklearn.datasets import fetch_20newsgroups_vectorized 
from sklearn.decomposition import NMF 
from sklearn.decomposition.nmf import _beta_divergence # needs sklearn 0.19!!! 

""" Test-data """ 
bunch_train = fetch_20newsgroups_vectorized('train') 
bunch_test = fetch_20newsgroups_vectorized('test') 
X_train = bunch_train.data 
X_test = bunch_test.data 
X_train = X_train[:2500, :] # smaller for demo 
X_test = X_test[:2500, :] # ... 

""" NMF fitting """ 
nmf = NMF(n_components=10, random_state=0, alpha=.1, l1_ratio=.5).fit(X_train) 
print('original reconstruction error automatically calculated -> TRAIN: ', nmf.reconstruction_err_) 

""" Manual reconstruction_err_ calculation 
    -> use transform to get W 
    -> ask fitted NMF to get H 
    -> use available _beta_divergence-function to calculate desired metric 
""" 
W_train = nmf.transform(X_train) 
rec_error = _beta_divergence(X_train, W_train, nmf.components_, 'frobenius', square_root=True) 
print('Manually calculated rec-error train: ', rec_error) 

W_test = nmf.transform(X_test) 
rec_error = _beta_divergence(X_test, W_test, nmf.components_, 'frobenius', square_root=True) 
print('Manually calculated rec-error test: ', rec_error) 

出力:

('original reconstruction error automatically calculated -> TRAIN: ', 37.326794668961604) 
('Manually calculated rec-error train: ', 37.326816210011778) 
('Manually calculated rec-error test: ', 37.019526486067413) 

備考:が、おそらくFP-数学によって誘発されるいくつかの小さな誤差があるが、私は怠け者ですこれが正確にどこから来るのかを確認する。問題が小さいほど問題は小さくなり、n_featuresの点では上記の問題は大きくなります。

使用されるこの計算と関数は、開発者によって決定されたいくつかの形式であり、おそらく根本的な理論があります。 しかし一般的に言っておくと、MFはすべて再構成に関するものなので、X_orignmf.inverse_transform(nmf.transform(X_orig))を比較するアイデアに基づいて、好きなすべてのメトリックを作成できます。

+0

迅速なプロトタイプを作成していた当時の私のために働いた回避策/近似をありがとう。私は、この機能をscikit-learn APIに追加することを提案し、コミッタとして参加したいと考えています。あなたが私に加わることに興味があるなら、私に知らせてください、そして、私はリンクを送るでしょう – user179041

関連する問題