2017-02-09 4 views
2

私はTheanosバックエンドでKeras NNを使用しています。私は14の出力クラスで分類問題に取り組んでいます。私は予測されたクラスに関連する確率を加えたい。問題は、predict_proba()の確率がpredict()の予測クラスと一致しないように見えることです。ここでは、コードに1サンプルの出力を加えたものです。Keras分類子predict_proba()がpredict()と一致しません

PPRANK = ['pp1', 'pp2', 'pp3', 'pp4', 'pp5', 'pp6', 'pp7', 'pp8', 'pp9', 'pp10', 'pp11', 'pp12', 'pp13', 'pp14', 'pp15'] 

FEATURES = (PPRANK) 

# fix random seed for reproducibility 
seed = 7 
np.random.seed(seed) 

data_df = pd.DataFrame.from_csv("data.csv") 
X = np.array(data_df[FEATURES].values) 
Y = (data_df["bres"].replace(14,13).values) 


# define baseline model 
def baseline_model(): 
    # create model 
    model = Sequential() 
    model.add(Dense(8, input_dim=(len(FEATURES)), init='normal', activation='relu')) 
    model.add(Dense(14, init='normal', activation='softmax')) 
    # Compile model 
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 
    return model 
#build model 
estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=200, batch_size=5, verbose=0) 

#split train and test 
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=seed) 
estimator.fit(X_train, Y_train) 

#get probabilities 
predictions = estimator.predict_proba(X_test) 

#convert expon to floats 
probs = [[] for x in range(21)] 
tick2 = 0 
for i in range(len(predictions)): 
    tick = 0 
    for x in xrange(14): 
     (predictions[i][(tick)]) = '%.4f' % (predictions[i][(tick)]) 
     probs[(tick2)].append((predictions[i][(tick)])) 
     tick += 1 
    tick2 += 1 

# pprint probabilities 
pp = pprint.PrettyPrinter(indent=0) 
pp.pprint(probs) 

#print class predictions 
print estimator.predict(X_test) 
print Y_test 

確率

[0.00000、0.00030、0.02360、0.04329、0.00019、0.00069、0.00120、0.00030、0.00559、0.00410、0.00510、0.91549、0.0、0.0]

クラス

予測

実際のクラス

predict()から11よりむしろpredict_proba()から最も高い確率を持つ12を示します。何か助けてくれてありがとう。

答えて

3

Python配列(およびここではクラス)のインデックスは1からではなく0からカウントアップします。もう一度見てみましょう.0.91は人数をカウントする12番目の値ですが、インデックス= 11なのでpredictとpredict_probaは一貫しています

なぜ13でないのかは、予測が間違っている可能性があります(ただし、同じ種類のエラーがないことを確認してください)

関連する問題