2017-01-06 11 views
1

バッチ・4DテンソルTensorflowインデクシング

  • batch_images:形状の4Dテンソル(B, H, W, C)
  • x:形状の三次元テンソル(B, H, W)
  • y:形状の三次元テンソル(B, H, W)

ゴール

どの形状B, H, W, Cの4Dテンソルを得るためxy座標を用いbatch_imagesにIインデックス缶。つまり、各バッチについて、そして各ペアについて、(x, y)テンソル形状Cを取得したいと考えています。

numpyでは、これは例えばinput_img[np.arange(B)[:,None,None], y, x]を使用して達成されますが、テンソルフローでは機能しないようです。

これまでの形状(B, H, W)の間違ったテンソルを返す

def get_pixel_value(img, x, y): 
    """ 
    Utility function to get pixel value for 
    coordinate vectors x and y from a 4D tensor image. 
    """ 
    H = tf.shape(img)[1] 
    W = tf.shape(img)[2] 
    C = tf.shape(img)[3] 

    # flatten image 
    img_flat = tf.reshape(img, [-1, C]) 

    # flatten idx 
    idx_flat = (x*W) + y 

    return tf.gather(img_flat, idx_flat) 

私の試み。

答えて

1

テンソルを平坦化することで可能ですが、インデックス計算ではバッチ次元を考慮する必要があります。 これを行うには、現在のバッチのインデックスを常に含むxyという同じ形状の追加のダミーバッチインデックステンソルを作成する必要があります。 これは基本的にnumpyの例のnp.arange(B)で、TensorFlowコードにはありません。

また、インデックスの計算を行うtf.gather_ndを使用すると、少し単純化することもできます。ここで

は例です:

import numpy as np 
import tensorflow as tf 

# Example tensors 
M = np.random.uniform(size=(3, 4, 5, 6)) 
x = np.random.randint(0, 5, size=(3, 4, 5)) 
y = np.random.randint(0, 4, size=(3, 4, 5)) 

def get_pixel_value(img, x, y): 
    """ 
    Utility function that composes a new image, with pixels taken 
    from the coordinates given in x and y. 
    The shapes of x and y have to match. 
    The batch order is preserved. 
    """ 

    # We assume that x and y have the same shape. 
    shape = tf.shape(x) 
    batch_size = shape[0] 
    height = shape[1] 
    width = shape[2] 

    # Create a tensor that indexes into the same batch. 
    # This is needed for gather_nd to work. 
    batch_idx = tf.range(0, batch_size) 
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) 
    b = tf.tile(batch_idx, (1, height, width)) 

    indices = tf.pack([b, y, x], 3) 
    return tf.gather_nd(img, indices) 

s = tf.Session() 
print(s.run(get_pixel_value(M, x, y)).shape) 
# Should print (3, 4, 5, 6). 
# We've composed a new image of the same size from randomly picked x and y 
# coordinates of each original image. 
関連する問題