2017-01-25 4 views
0

私はTensorflowを使って、テンソルのインデックスに対応するもの以外のすべてのサブ要素を抽出する方法を探しています。傾いたテンソルのI番目のサブ要素を削除またはマスキングしますか?

(例えば、インデックス1を見てのみサブ要素0および2が存在する場合)

numpyのを使用してthis approachに非常に類似します。

ここではタイル張りのテンソルとブールマスクを作成するために、いくつかのサンプルコードです:

import tensorflow as tf 
import numpy as np 

_coordinates = np.array([ 
    [1.0, 7.0, 0.0], 
    [2.0, 7.0, 0.0], 
    [3.0, 7.0, 0.0], 
]) 

verts_coord = _coordinates 
n = verts_coord.shape[0] 

mat_loc = tf.Variable(verts_coord) 

tile = tf.tile(mat_loc, [n, 1]) 
tile = tf.reshape(tile, [n, n, n]) 

mask = tf.constant(~np.eye(n, dtype=bool)) 

result = tf.somefunc(tile, mask) #somehow extract only the elements where mask == true 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    print(sess.run(tile)) 
    print(sess.run(mask)) 

出力例テンソル:

>>> print(tile) 
[[[ 1. 7. 0.] 
    [ 2. 7. 0.] 
    [ 3. 7. 0.]] 

[[ 1. 7. 0.] 
    [ 2. 7. 0.] 
    [ 3. 7. 0.]] 

[[ 1. 7. 0.] 
    [ 2. 7. 0.] 
    [ 3. 7. 0.]]] 

>>> print(mask) 
[[False True True] 
[ True False True] 
[ True True False]] 

所望の出力:

>>> print(result) 
[[[ 2. 7. 0.] 
    [ 3. 7. 0.]] 

[[ 1. 7. 0.] 
    [ 3. 7. 0.]] 

[[ 1. 7. 0.] 
    [ 2. 7. 0.]]] 

私も好奇心大規模なテンソルを作成し、それをマスキングするのではなく、これを行うためのより効率的な方法があれば?

ありがとうございます!

答えて

0

Tensorflowは、私はすでに:)

result = tf.boolean_mask(tile, mask) 
result = tf.reshape(result, [n, n-1, -1]) 

>>> print(result) 
[[[ 2. 7. 0.] 
    [ 3. 7. 0.]] 

[[ 1. 7. 0.] 
    [ 3. 7. 0.]] 

[[ 1. 7. 0.] 
    [ 2. 7. 0.]]] 
建て探していたまさにでしたが判明
関連する問題