2016-09-06 17 views
0

私は強化学習に取り組んでおり、学習中にsess.run()を使用して学習したデータの量を減らして学習をスピードアップしたいと考えています。TensorFlow:グラフ内のLSTM状態の保存/更新

私はLSTMに探していたし、楽しみにして適切なQ値を見つけるためにリセットする必要があると、私は)(tf.caseでこのようなソリューションを細工:

CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState') 
with tf.name_scope("LSTMLayer") as scope: 
     initializer = tf.random_uniform_initializer(-.1, .1) 
     lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True) 
     self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True) 
     self.state = self.cell_L1.zero_state(1,tf.float64) 

     self.SavedState = self.cell_L1.zero_state(1,tf.float64) #tf.Variable(state, trainable=False, name='SavedState') 

     #SaveCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState) 
     #RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState) 
     #ZeroCond = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState) 

     self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState, 
      tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True) 

     RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond]) 

     self.Xinputs = [tf.concat(1,[Xinputs])] 

     outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32) 
def RestoreState(self): 
    #self.state = self.state.assign(self.SavedState) 
    self.state = self.SavedState 
    return self.state 
def ZeroState(self): 
    self.state = self.cell_L1.zero_state(1,tf.float64) 
    return self.state 
def SaveState(self): 
    #self.SavedState = self.SavedState.assign(self.state) 
    self.SavedState = self.state 
    return self.SavedState 
def SameState(self): 
    return self.state 

これは私がwが何をすべきかLSTMグラフを指示するINTを供給できるようになりました概念にうまく動作するようですith状態。パス "1"を実行する前に状態を保存します。 "-1"を渡すと最後に保存された状態に戻ります。 "< -1"を渡すと状態はゼロになります。 "0"の場合は、最後に実行したときのLSTM(推論)を使用します。私は単純なtf.cond()アプローチを含むいくつかの異なるアプローチを試みました。

テンソルを必要とするtf.case()オペレーションが、LSTMの状態がタプルである(タプル以外のタプルが減価償却される)から生じると考えられる問題です。これは、グラフ変数に値をtf.assign()しようとすると明らかになりました。

私の最終目標はグラフ内に「状態」を残し、INTを渡して状態をどう処理するかを指示することです。将来、私はさまざまなルックバックのために複数の "店舗"の場所を持っていたいと思います。

タフルとテンソルの構造体のtf.case()型を扱う方法はありますか?

答えて

0

タプルは単なるpythonタプルなので、状態タプルで要素ごとに1つのtf.case()を使用すると効果があります。

関連する問題