2016-11-23 9 views
3

を予測する私は、次のコードを書いた:Sklearnは、複数の出力

from sklearn import tree 

# Dataset & labels 
# Using metric units 
# features = [height, weight, style] 
styles = ['modern', 'classic'] 
features = [[1.65, 65, 1], 
      [1.55, 50, 1], 
      [1.76, 64, 0], 
      [1.68, 77, 0] ] 
labels = ['Yellow dress', 'Red dress', 'Blue dress', 'Green dress'] 

# Decision Tree 
clf = tree.DecisionTreeClassifier() 
clf = clf.fit(features, labels) 

# Returns the dress 
height = input('Height: ') 
weight = input('Weight: ') 
style = input('Modern [0] or Classic [1]: ') 
print(clf.predict([[height,weight,style]])) 

このコードは、その彼女に、より良いフィットドレスを返し、その後、使用者の身長と体重を受けます。複数のオプションを返す方法はありますか?たとえば、2つ以上のドレスを返します。

UPDATE

from sklearn import tree 
import numpy as np 

# Dataset & labels 
# features = [height, weight, style] 
# styles = ['modern', 'classic'] 
features = [[1.65, 65, 1], 
      [1.55, 50, 1], 
      [1.76, 64, 1], 
      [1.72, 68, 0], 
      [1.73, 68, 0], 
      [1.68, 77, 0]] 
labels = ['Yellow dress', 
      'Red dress', 
      'Blue dress', 
      'Green dress', 
      'Purple dress', 
      'Orange dress'] 

# Decision Tree 
clf = tree.DecisionTreeClassifier() 
clf = clf.fit(features, labels) 

# Returns the dress 
height = input('Height: ') 
weight = input('Weight: ') 
style = input('Modern [0] or Classic [1]: ') 

print(clf.predict_proba([[height,weight,style]])) 

ユーザーは1.72メートルと68キロであれば、私は緑と紫のドレスの両方を表示したいです。この例では、緑のドレスの100%を返します。

+0

それが1以上のものを返します:あなたはそれでこのような何かを行うことができ

?最も可能な順序でそれらを返すことを意味しますか? – erip

答えて

5

はいできます。実際にあなたができることは、各クラスの確率を得ることができるということです。いくつかのクラシファイアには.predict_proba()という機能が実装されています。

here、sklearnのドキュメントを参照してください。

各クラスのサンプルのメンバーシップの確率を返します。

たとえば、2つ、3つの最も高い確率に関連付けられたラベルを返すことができます。

+0

この方法は、条件に完全に適合するオプションを選択するだけです。例えば、身長が1.72m、体重が68kgの場合は、1.73mと68kgのドレスと1.72と68kgのドレスを表示したいと思います。 – bodruk

2

predict()は高い確率でのみクラスを返します。 predict_proba()を代わりに使用すると、各クラスの確率で配列が返されるため、たとえば、特定のしきい値を超えるものを選択できます。

Hereは、このメソッドのドキュメントです。

probs = clf.predict_proba([[height, weight, style]]) 
threshold = 0.25 # change this accordingly 
for index, prob in enumerate(probs[0]): 
    if prob > threshold: 
     print(styles[index]) 
+1

あなたのしきい値はうまく選択されていません。 'threshold = 0.5'の場合、1つのラベルしか返しません。確率の合計は1に等しく、他のラベルは0.5より大きいprobaを持つことはできません... – MMF

+1

はい、そうです。私はちょっと考えずに乱数を選んだだけです。 –