2016-12-13 8 views
3

TensorFlowでK-Nearest Neighborを実装するのに苦労しています。私は間違いを見落としているか、何かひどいやり方をしていると思います。TensorFlowでのKNN実装に関する問題

次のコードは常に0

from __future__ import print_function 

import numpy as np 
import tensorflow as tf 

# Import MNIST data 
from tensorflow.examples.tutorials.mnist import input_data 

K = 4 
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) 

# In this example, we limit mnist data 
Xtr, Ytr = mnist.train.next_batch(55000) # whole training set 
Xte, Yte = mnist.test.next_batch(10000) # whole test set 

# tf Graph Input 
xtr = tf.placeholder("float", [None, 784]) 
ytr = tf.placeholder("float", [None, 10]) 
xte = tf.placeholder("float", [784]) 

# Euclidean Distance 
distance = tf.neg(tf.sqrt(tf.reduce_sum(tf.square(tf.sub(xtr, xte)), reduction_indices=1))) 
# Prediction: Get min distance neighbors 
values, indices = tf.nn.top_k(distance, k=K, sorted=False) 
nearest_neighbors = [] 
for i in range(K): 
    nearest_neighbors.append(np.argmax(ytr[indices[i]])) 

sorted_neighbors, counts = np.unique(nearest_neighbors, return_counts=True) 

pred = tf.Variable(nearest_neighbors[np.argmax(counts)]) 

# not works either 
# neighbors_tensor = tf.pack(nearest_neighbors) 
# y, idx, count = tf.unique_with_counts(neighbors_tensor) 
# pred = tf.slice(y, begin=[tf.arg_max(count, 0)], size=tf.constant([1], dtype=tf.int64))[0] 

accuracy = 0. 

# Initializing the variables 
init = tf.initialize_all_variables() 

# Launch the graph 
with tf.Session() as sess: 
    sess.run(init) 

    # loop over test data 
    for i in range(len(Xte)): 
     # Get nearest neighbor 
     nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]}) 
     # Get nearest neighbor class label and compare it to its true label 
     print("Test", i, "Prediction:", nn_index, 
       "True Class:", np.argmax(Yte[i])) 
     # Calculate accuracy 
     if nn_index == np.argmax(Yte[i]): 
      accuracy += 1./len(Xte) 
    print("Done!") 
    print("Accuracy:", accuracy) 

任意の助けを大幅に理解されるようにMnistラベルを予測します。

+0

あなたがグラフに 'numpy'機能を使用しているべきではありません。ここで

は完全な作業コードです。 – martianwars

答えて

7

一般的に、TensorFlowモデルを定義するときにnumpy関数に行くのは良い考えではありません。それが、あなたのコードがうまくいかない理由です。あなたのコードにちょうど2つの変更を加えました。私はnp.argmaxtf.argmaxに置き換えました。私は#This doesn't work eitherからコメントを削除しました。

from __future__ import print_function 

import numpy as np 
import tensorflow as tf 

# Import MNIST data 
from tensorflow.examples.tutorials.mnist import input_data 

K = 4 
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) 

# In this example, we limit mnist data 
Xtr, Ytr = mnist.train.next_batch(55000) # whole training set 
Xte, Yte = mnist.test.next_batch(10000) # whole test set 

# tf Graph Input 
xtr = tf.placeholder("float", [None, 784]) 
ytr = tf.placeholder("float", [None, 10]) 
xte = tf.placeholder("float", [784]) 

# Euclidean Distance 
distance = tf.negative(tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(xtr, xte)), reduction_indices=1))) 
# Prediction: Get min distance neighbors 
values, indices = tf.nn.top_k(distance, k=K, sorted=False) 

nearest_neighbors = [] 
for i in range(K): 
    nearest_neighbors.append(tf.argmax(ytr[indices[i]], 0)) 

neighbors_tensor = tf.stack(nearest_neighbors) 
y, idx, count = tf.unique_with_counts(neighbors_tensor) 
pred = tf.slice(y, begin=[tf.argmax(count, 0)], size=tf.constant([1], dtype=tf.int64))[0] 

accuracy = 0. 

# Initializing the variables 
init = tf.initialize_all_variables() 

# Launch the graph 
with tf.Session() as sess: 
    sess.run(init) 

    # loop over test data 
    for i in range(len(Xte)): 
     # Get nearest neighbor 
     nn_index = sess.run(pred, feed_dict={xtr: Xtr, ytr: Ytr, xte: Xte[i, :]}) 
     # Get nearest neighbor class label and compare it to its true label 
     print("Test", i, "Prediction:", nn_index, 
      "True Class:", np.argmax(Yte[i])) 
     #Calculate accuracy 
     if nn_index == np.argmax(Yte[i]): 
      accuracy += 1./len(Xte) 
    print("Done!") 
    print("Accuracy:", accuracy) 
+0

あなたのコードを修正しました。私はそれを打ち砕く時間がありませんでしたが、できるだけ早くお知らせしたいと思います:トレースバック(最新のコール最後):ライン28、 nearest_neighbors。 TypeError:タイプがのTensor( "strided_slice:0"、shape =()、dtype = int32)型エラー:テンソルflow.python .framework.ops.Tensor '> Python 3.5とtensorflow r0.12で実行する – wrecker

+0

テンソルフロー0.11でこのエラーが発生するのを覚えていません。私はもう少し詳しく見ていきますが、 ':'を取り除くことができます。 – martianwars

+0

コードを更新しました。これがうまくいくかどうかを教えてください。このコードと前のコードの両方がr0.11で私のために働いていました。基本的には、https://github.com/tensorflow/tensorflow/issues/206に起因する不一致があります。 – martianwars

関連する問題