2016-03-23 9 views
2

テンソルフローは、双方向RNNの可変バッチサイズをサポートしていないようです。テンソルフローの双方向RNNに可変バッチサイズを使用する方法

_seq_len = tf.fill([batch_size], tf.constant(n_steps, dtype=tf.int64)) 
    outputs, state1,state2 = rnn.bidirectional_rnn(rnn_fw_cell, rnn_bw_cell, input, 
            dtype="float", 
            sequence_length=_seq_len) 

がどのように私はトレーニングやテストのためのさまざまなバッチサイズを使用することができます。この例ではsequence_lengthは、Pythonの整数である、batch_sizeに結びついていますか?

答えて

5

双方向コードは可変バッチサイズで動作します。たとえば、this test codeを見ると、tf.placeholder(..., shape=(None, input_size))が作成されます(Noneはバッチサイズが可変であることを意味します)。

あなたはあなたのコードスニペットは、小さな変更で可変のバッチサイズで動作するように変換することができます:

# Compute the batch size based on the shape of the (presumably fed-in) `input` 
# tensor. (Assumes that `input = tf.placeholder(..., shape=[None, input_size])`.) 
batch_size = tf.shape(input)[0] 

_seq_len = tf.fill(tf.expand_dims(batch_size, 0), 
        tf.constant(n_steps, dtype=tf.int64)) 
outputs, state1, state2 = rnn.bidirectional_rnn(rnn_fw_cell, rnn_bw_cell, input, 
               dtype=tf.float32, 
               sequence_length=_seq_len) 
関連する問題