2017-02-18 5 views
3

mによってnの2つのランク2テンソルarr1arr2があります。テンソルarr2はブール値です。その行のそれぞれに正確に1つのエントリはTrueです。 Iはarr3i番目のエントリは、arr2i行目はTrueに等しい場合に対応arr1i行目のエントリに等しい長さmの新しいランク1テンソルarr3を抽出します。 numpyテンソルのブーリアンテンソルを使用してテンソルのサブセットを選択してください

次のように、私はこれを行うことができます。

arr1 = np.array([[1,2], 
       [3,4]]) 
arr2 = np.array([[0,1], 
       [1,0]], dtype="bool") 
arr3 = arr1[arr2] 

私はtensorflowに似た何かを行うことができますか?私はeval()テンソルを使用してnumpy関数を使用することができることを知っていますが、それは非効率的です。

このquestionは、tf.gathertf.selectの使用を示唆していますが、私の質問のように出力の次元を崩壊させることはありません。

答えて

2

おそらくtf.boolean_maskを使用できますか?

from __future__ import print_function 

import tensorflow as tf 

with tf.Session() as sess: 
    arr1 = tf.constant([[1,2], 
         [3,4]]) 
    arr2 = tf.constant([[False, True], 
         [True, False]]) 
    print(sess.run(tf.boolean_mask(arr1, arr2))) 

を与える必要があり:[2、3]

関連する問題