2016-01-02 9 views
5

私はTensor FlowでLSTMを扱おうとしています。私はチュートリアルをオンラインで見つけました。一連のシーケンスが取り込まれ、目的関数はLSTMの最後の出力と既知の値で構成されています。しかし、私は目的関数が各出力からの情報を使用するようにしたいと思います。具体的には、私は(つまり、文章中の単語のすべての文字を学ぶ)LSTMは、シーケンスのセットを学ぶ持ってしようとしています。:Tensorflow:コストのためのテンソルの一覧

cell = rnn_cell.BasicLSTMCell(num_units) 
inputs = [tf.placeholder(tf.float32,shape=[batch_size,input_size]) for _ in range(seq_len)] 
result = [tf.placeholder(tf.float32, shape=[batch_size,input_size]) for _ in range(seq_len)] 

W_o = tf.Variable(tf.random_normal([num_units,input_size], stddev=0.01))  
b_o = tf.Variable(tf.random_normal([input_size], stddev=0.01)) 

outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) 

losses = [] 

for i in xrange(len(outputs)): 
    final_transformed_val = tf.matmul(outputs[i],W_o) + b_o 
    losses.append(tf.nn.softmax(final_transformed_val)) 

cost = tf.reduce_mean(losses) 

これを行うとエラーになります:

どう
TypeError: List of Tensors when single Tensor expected 

この問題を解決する必要がありますか? tf.reduce_mean()はテンソル値のリストを取りますか、それともそれを取る特別なテンソルオブジェクトがありますか?

答えて

3

コードでは、lossesはPythonのリストです。 TensorFlowのreduce_mean()は、Pythonのリストではなく、単一のテンソルしか必要としません。

losses = tf.reshape(tf.concat(1, losses), [-1, size]) 

ここで、sizeは、softmax以上を取っている値の数です。 concat()

the TensorFlow Tutorialのコードでは、入力に3テンソルのオーダーが使用されていますが、入力にはプレースホルダーのリストがあることがわかりました。入力は2次テンソルのリストです。チュートリアルでコードを調べることをお勧めします。なぜなら、ほぼ正確にあなたが尋ねていることを行うからです。

このチュートリアルのメインファイルの1つはhereです。特に、139行目はコストを作り出す場所です。 入力に関しては、90行目と91行目は、入力プレースホルダとターゲットプレースホルダが設定されている場所です。これらの2行の主な目的は、シーケンス全体がプレースホルダのリストではなく単一のプレースホルダに渡されることです。

ptb_word_lm.pyファイルの120行目を参照して、連結の場所を確認してください。例作業

+1

これは私が誤解しています。私は自分の答えを削除しました:)初心者の助けになるかもしれないので、あなたのポストのチュートリアルのコード例を表示したいと思うかもしれません(私も理解したいと思います)。 – Will

+1

私が話していたコードはhttps://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/rnn/ptb/ptb_word_lm.py#139であり、入力に関する部分はその90行目ですファイル。そのファイルの行90のnum_stepsは、OPのコードのseq_lenと同じです(大まかに)。 __init__メソッド全体をインライン展開することを提案するか、ファイルへのリンクを提供するだけですか? –

+0

ああ、わかりました、ええ、これは非常に大きな方法です。私はファイルへのリンクは、いくつかの関連する行と説明の小さなビットで、それが最も明確になると思う。あなたの投稿は素晴らしい出発点であるようですが、OPが正しい方法でやっていないコードの部分の「正しい例」を見ていると、私は思っています。 私が時には使う方法は、メソッド/クラス全体をペーストするだけですが、無関係な行を '#... 'に置き換えて短縮します。 – Will

2

notebookチェック:

import tensorflow as tf 
from tensorflow.models.rnn import rnn, rnn_cell 
print(tf.__version__) 
#> 0.8.0 

batch_size = 2 
output_size = input_size = 2 
seq_len  = 10 
num_units = 2 

cell = rnn_cell.BasicLSTMCell(num_units) 
inputs = [tf.placeholder(tf.float32, shape=[batch_size,input_size ]) for _ in xrange(seq_len)] 
result = [tf.placeholder(tf.float32, shape=[batch_size,output_size]) for _ in xrange(seq_len)] 

W_o = tf.Variable(tf.random_normal([num_units,input_size], stddev=0.01))  
b_o = tf.Variable(tf.random_normal([input_size],   stddev=0.01)) 

outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) 

losses = [] 

for i in xrange(seq_len): 
    final_transformed_val = tf.matmul(outputs[i],W_o) + b_o 
    losses.append(tf.squared_difference(result[i],final_transformed_val)) 

losses = tf.reshape(tf.concat(1, losses), [-1, seq_len]) 
cost = tf.reduce_mean(losses) 

は、この動作を確認するには、あなたがハック方法でグラフを養うことができます。tensorflow-初心者として

import matplotlib.pyplot as plt 
import numpy as np 

step = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost) 
sess = tf.InteractiveSession() 

sess.run(tf.initialize_all_variables()) 

costs = [] 

# EXAMPLE 
# Learn cumsum over each sequence in x 
# | t  | 0 | 1 | 2 | 3 | 4 | ...| 
# |----------|---|---|---|---|----|----| 
# | x[:,0,0] | 1 | 1 | 1 | 1 | 1 | ...| 
# | x[:,0,1] | 1 | 1 | 1 | 1 | 1 | ...| 
# |   | | | | | | | 
# | y[:,0,0] | 1 | 2 | 3 | 4 | 5 | ...| 
# | y[:,0,1] | 1 | 2 | 3 | 4 | 5 | ...| 

n_iterations = 300 
for _ in xrange(n_iterations): 
    x = np.random.uniform(0,1,[seq_len,batch_size,input_size]) 
    y = np.cumsum(x,axis=0) 

    x_list = {key: value for (key, value) in zip(inputs, x)} 
    y_list = {key: value for (key, value) in zip(result, y)} 

    err,_ = sess.run([cost, step], feed_dict=dict(x_list.items()+y_list.items())) 
    costs.append(err) 

plt.plot(costs) 
plt.show() 

enter image description here

I RNNを扱う統一された方法/ベストプラクティスの方法をまだ見つけていないが、上記のように私はこれが推奨されないと確信している。スニペットのおかげで、あなたのスクリプトが非常に素晴らしいイントロとして好きだった。また、w.r.gで進行中のことがimplementation of scan and RNN-tuple-friendlinessになっているので注意してください。

関連する問題