2016-06-24 7 views
1

Tensorflowにあるコードを書いて、1つの文字列と文字列の間の編集距離を計算しました。私はエラーを理解することはできません。編集距離を計算する(feed_dictエラー)

import tensorflow as tf 
sess = tf.Session() 

# Create input data 
test_string = ['foo'] 
ref_strings = ['food', 'bar'] 

def create_sparse_vec(word_list): 
    num_words = len(word_list) 
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)] 
    chars = list(''.join(word_list)) 
    return(tf.SparseTensor(indices, chars, [num_words,1,1])) 


test_string_sparse = create_sparse_vec(test_string*len(ref_strings)) 
ref_string_sparse = create_sparse_vec(ref_strings) 

sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True)) 

このコードは動作したときに実行し、それが出力を生成します

array([[ 0.25], 
     [ 1. ]], dtype=float32) 

しかし、私はまばらなプレースホルダ経由でスパーステンソルを供給することにより、これを行うにしようとしたとき、私はエラーを取得します。ここで

test_input = tf.sparse_placeholder(dtype=tf.string) 
ref_input = tf.sparse_placeholder(dtype=tf.string) 

edit_distances = tf.edit_distance(test_input, ref_input, normalize=True) 

feed_dict = {test_input: test_string_sparse, 
      ref_input: ref_string_sparse} 

sess.run(edit_distances, feed_dict=feed_dict) 

エラートレースバックです:

Traceback (most recent call last): 

    File "<ipython-input-29-4e06de0b7af3>", line 1, in <module> 
    sess.run(edit_distances, feed_dict=feed_dict) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run 
run_metadata_ptr) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run 
    for subfeed, subfeed_val in _feed_fn(feed, feed_val): 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn 
    return feed_fn(feed, feed_val) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda> 
    [feed.indices, feed.values, feed.shape], feed_val)), 

TypeError: zip argument #2 must support iteration 

ここで何が起こっているすべてのアイデア?

+0

エラーはおそらく完璧に動作し、あなたがありがとう、作成 –

答えて

2

TL; DR:create_sparse_vec()の戻り型についてtf.SparseTensorValue代わりにtf.SparseTensorを使用します。

問題はここtf.SparseTensorあり、そしてsess.run()への呼び出しでフィードとして理解されていないcreate_sparse_vec()の戻り値の型、から来ています。

(稠密)tf.Tensorを入力すると、期待値タイプはNumPy配列(または配列に変換できる特定のオブジェクト)です。 tf.SparseTensorを入力すると、期待値タイプはtf.SparseTensorValueで、tf.SparseTensorに似ていますが、indices,valuesshapeのプロパティはNumPy配列(または例のリストのように配列に変換できる特定のオブジェクト)です。

次のコードは動作するはずです:?!

def create_sparse_vec(word_list): 
    num_words = len(word_list) 
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)] 
    chars = list(''.join(word_list)) 
    return tf.SparseTensorValue(indices, chars, [num_words,1,1]) 
+0

のためのコードを提供することができますtest_string_parse'または 'ref_string_parse'値'から来ています。 – nfmcclure