2016-03-18 11 views
6

寸法が[batch_size, num_rows, num_coordinates]のテンソルlogitsです(つまり、ロット内の各ロジットはマトリックスです)。私の場合、バッチサイズは2、4行4座標です。TensorFlowで3Dテンソルから行を選択する方法は?

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0], 
         [12.0, 10.0, 10.0, 20.0], 
         [13.0, 10.0, 10.0, 20.0]], 
        [[14.0, 11.0, 21.0, 31.0], 
         [15.0, 11.0, 11.0, 21.0], 
         [16.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

第1バッチの第1行と第2行と第2バッチの第2行と第4行を選択したいとします。

indices = tf.constant([[0, 1], [1, 3]]) 

だから、所望の出力は、私は、この使用してTensorFlowを行うにはどうすればよい

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0]], 
        [[15.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

は可能でしょうか?私はtf.gather(logits, indices)を使ってみましたが、期待したものを返せませんでした。ありがとう!

答えて

7

これはTensorFlowでは可能ですが、tf.gather()は現在、1次元インデックスでのみ機能し、テンソルの0次元からスライスのみを選択するため、少し不便です。しかし、それは彼らがtf.gather()に渡すことができるように引数を変換することによって、効率的に問題を解決することは可能である:これはtf.reshape()なくtf.transpose()を使用しているため、

logits = ... # [2 x 4 x 4] tensor 
indices = tf.constant([[0, 1], [1, 3]]) 

# Use tf.shape() to make this work with dynamic shapes. 
batch_size = tf.shape(logits)[0] 
rows_per_batch = tf.shape(logits)[1] 
indices_per_batch = tf.shape(indices)[1] 

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately. 
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1) 

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1]) 
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]])) 

selected_rows = tf.gather(flattened_logits, flattened_indices) 

result = tf.reshape(selected_rows, 
        tf.concat(0, [tf.pack([batch_size, indices_per_batch]), 
            tf.shape(logits)[2:]])) 

注意、それは変更する必要はありません。 logitsテンソルの(潜在的に大きい)データなので、かなり効率的です。

+0

あなたの答えは素晴らしいですが、今日は、あなたが書いている時点ではまだ入手できなかった 'tf.gather_nd'と置き換えることができると思います(私の答えを参照してください) – kaufmanu

4

mrryの答えは素晴らしいですが、私は(この関数はmrryの執筆時点ではまだ利用できませんでしたおそらく)問題は、コードの多くの少ない線で解決することができる機能tf.gather_ndと思う:

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0], 
         [12.0, 10.0, 10.0, 20.0], 
         [13.0, 10.0, 10.0, 20.0]], 
        [[14.0, 11.0, 21.0, 31.0], 
         [15.0, 11.0, 11.0, 21.0], 
         [16.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]]) 

result = tf.gather_nd(logits, indices) 
with tf.Session() as sess: 
    print(sess.run(result)) 

これは印刷されます

[[[ 10. 10. 20. 20.] 
    [ 11. 10. 10. 30.]] 

[[ 15. 11. 11. 21.] 
    [ 17. 11. 11. 21.]]] 

tf.gather_ndがv0.10から利用可能になるはずです。詳細については、this github issueを参照してください。

+0

どのようにして2d (質問されるように)? – Tulsi

+0

@トルシー私はあなたの質問を理解していません。質問に3D指標の言及はありませんか、それともありますか? – kaufmanu

関連する問題