0

約1万のつぶやきのサンプルを「関連する」と「関連しない」カテゴリに分類したいと思っています。私はPythonのscikit-learnをこのモデルに使用しています。私は手動で1,000個のつぶやきを「関連性のある」または「関連性のない」ものとしてコード化しました。次に、手作業でコード化されたデータの80%をトレーニングデータとして使用し、残りをテストデータとして使用してSVMモデルを実行しました。私は良い結果(予測精度〜0.90)を得ましたが、過適合を避けるために、1,000個の手動でコーディングされたすべてのツイートにクロスバリデーションを使用することに決めました。cross_val_predictの後に新しい文書を分類する

以下は、私のサンプルのツイートのtf-idfマトリックスを取得した後のコードです。 "target"は、ツイートが "関連"または "非関連"としてマークされているかどうかを示す配列です。

from sklearn.linear_model import SGDClassifier 
from sklearn.model_selection import cross_val_score 
from sklearn.model_selection import cross_val_predict 

clf = SGDClassifier() 
scores = cross_val_score(clf, X_tfidf, target, cv=10) 
predicted = cross_val_predict(clf, X_tfidf, target, cv=10) 

このコードでは、1,000個のツイートが属するクラスを予測できました。私はこれを手動でのコーディングと比較できます。

私のモデルを使って手動でコード化しなかった他の〜9,000個のつぶやきを分類するために、私は次に何をすべきかに固執しています。 cross_val_predictをもう一度使うことを考えていましたが、クラスが私が予測しようとしているものなので、3番目の引数に何を入れるべきか分かりません。

ご協力いただきありがとうございます。

答えて

4

cross_val_predictは、実際にモデルから予測を得る方法であるではなくです。クロスバリデーションは、モデル選択/評価のテクニックであり、モデルには適用されません。 cross_val_predictは非常に特殊な関数です(相互検証の手順中に訓練された多くのモデルの予測を提供します)。実際のモデル構築には、フィットモデルを使用してモデルをトレーニングし、予測を予測する必要があります。前にも述べたように、クロスバリデーションは必要ありません。これはモデル選択(クラシファイア、ハイパーパラメータなどの選択)であり、実際のモデルを訓練するものではありません。

+0

ありがとう、@lejlot!私はまだcross_val_predictに関してちょっと混乱しています。正確に何を返すのですか?また、SVMとナイーブなBayes分類子に対して 'cross_val_predict'と' cross_val_score'を実行したとしましょう。これらの関数の出力を使って、分類器を決めるとしましょうか? – Eunice

+0

cross_val_predictは、内部的には、提供する分割数(自分の場合は10)と同じ数のモデルを、それぞれ異なるデータ分割でトレーニングします。その後、各モデルを使用して見えないデータを予測し、すべての結果を連結して予測のリストを作成します。これは主に*デバッグ*/*解析*ツールであり、トレーニングに使用するものではありません。あなたの使用例では、** cross_val_predictをまったく使用しないでください**。しかし、cross_val_scoreを使って(SVMやNBのような)モデル間で決めることができます。 – lejlot

+0

それは感謝します@lejlot !!どのようにしてテキスト分類モデルのオーバーフィットを防ぐのですか? – Eunice

関連する問題