2017-02-09 11 views
0

LSTMを使用してパラメータを追跡するTensorflow NN(時系列データ回帰の問題)に取り組んでいます。トレーニングデータのバッチには、の連続したのバッチサイズが含まれています。私は次のサンプルへの入力としてLSTM状態を使用したいと思います。したがって、私がデータ観測のバッチを持っていれば、第1の観測の状態を第2の観測への入力としてフィードしたいと思います。以下では、lstmの状態をsize = batch_sizeのテンソルとして定義します。私はバッチ内の状態を再利用したいと思います:Tensorflow - バッチ内のLSTM状態の再利用

state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False) 
cell = tf.nn.rnn_cell.BasicLSTMCell(100) 
output, curr_state = tf.nn.rnn(cell, data, initial_state=state) 

APIでありtf.nn.state_saving_rnnがあるが、ドキュメントがちょっと曖昧です。 私の質問:curr_state を再利用する方法のトレーニングバッチ。

+0

明確にするために、最初のバッチ要素の結果から状態を次のバッチ要素の開始状態にスレッド化するなどしますか?その場合、バッチディメンションは正確に時間ディメンションではありませんか? –

+0

@Allen Lavoie、そうだよ。バッチ内の各データ観測は、(多次元の)時系列ウィンドウです。バッチには、連続して配置された重複ウィンドウが含まれます。バッチディメンションは、オーバーラップとストライドを持つ時間ディメンションです。 – Leeor

+1

その場合は、バッチ次元は実際には1です。複数のシーケンスがある場合を除いて、バッチ処理を行うことができます。これは比較的遅くなります。単一のより長い時系列のバッチ処理を可能にする近似をサポートする努力が進行中ですが、まだ公開されていないものはありません。 –

答えて

1

あなただけcurr_statestateを更新する必要があり、基本的にあります。

state_update = tf.assign(state, curr_state) 

次に、あなたがstate_update自体や依存関係としてstate_updateを持っている操作にrunを呼び出すか、割り当てがないだろうことを確認してください実際に起こる。たとえば、次のコメントで示唆したように

with tf.control_dependencies([state_update]): 
    model_output = ... 

、のRNNのための典型的なケースでは、最初の次元(0)配列の数であり、二次元(1)最大長さバッチを有することです(これらの2つが交換されたときにRNNを構築するときにtime_major=Trueを渡した場合)。理想的には、良いパフォーマンスを得るために、複数のシーケンスを1つのバッチにスタックし、そのバッチを時間的に分割します。しかし、それはまったく別の話題です。

関連する問題