2017-01-31 8 views
5

テンソルフローをセマンティックセグメンテーションに使用しています。ピクセル単位の損失を計算する際にテンソルフローに特定のラベルを無視するように指示するにはどうすればよいですか?Tensorflow:セマンティックセグメンテーションで特定のラベルを無視する方法は?

画像の分類には-1というラベルを設定でき、無視されるin this postと読みました。ラベルテンソルが与えられている場合は、ラベルを変更して特定の値を-1に変更するにはどうすればよいですか? MATLABで

それはのようになります。

ignore_label = 255 
myLabelTensor(myLabelTensor == ignore_label) = -1 

しかし、私はTFでこれを行う方法がわかりませんか?

いくつかの背景情報:これは、ラベルがロードされている方法です

label_contents = tf.read_file(input_queue[1]) 
label = tf.image.decode_png(label_contents, channels=1) 

これは、損失が現在の計算方法です。

raw_output = net.layers['fc1_voc12'] 
prediction = tf.reshape(raw_output, [-1, n_classes]) 
label_proc = prepare_label(label_batch, tf.pack(raw_output.get_shape()[1:3]),n_classes) 
gt = tf.reshape(label_proc, [-1, n_classes]) 

# Pixel-wise softmax loss. 
loss = tf.nn.softmax_cross_entropy_with_logits(prediction, gt) 
reduced_loss = tf.reduce_mean(loss) 

I

def prepare_label(input_batch, new_size, n_classes): 
    """Resize masks and perform one-hot encoding. 

    Args: 
     input_batch: input tensor of shape [batch_size H W 1]. 
     new_size: a tensor with new height and width. 

    Returns: 
     Outputs a tensor of shape [batch_size h w 21] 
     with last dimension comprised of 0's and 1's only. 
    """ 
    with tf.name_scope('label_encode'): 
     input_batch = tf.image.resize_nearest_neighbor(input_batch, new_size) # as labels are integer numbers, need to use NN interp. 
     input_batch = tf.squeeze(input_batch, squeeze_dims=[3]) # reducing the channel dimension. 
     input_batch = tf.one_hot(input_batch, depth=n_classes) 
    return input_batch 

でResnetを転送する tensorflow-deeplab-resnet modelを使用しています caffe-tensorflowを使用してCaffeでテンソルフローに実装されたモデル。

+0

の可能性のある重複した[TensorFlow:?画像分割でのボイドラベルされたデータを処理する方法](https://stackoverflow.com/questions/46097968/tensorflow-how-to-handle-void-labeled-data-画像内セグメンテーション) – Shai

答えて

0

tf.nn.softmax_cross_entropy_with_logitslabelsの有効確率分布で呼び出されなければなりません。そうでなければ計算が不正確になり、tf.nn.sparse_softmax_cross_entropy_with_logits(あなたのケースではもっと便利かもしれません)を使用すると、 NaN値を返します。私はいくつかのラベルを無視するためにそれに頼るつもりはない。私はどうなるのか

は正しいクラスは無視一つであり、それらをピクセル単位で無限大と無視クラスのlogitsを交換することであるので、彼らは損失に何も貢献しないでしょう。

ignore_label = ... 
# Make zeros everywhere except for the ignored label 
input_batch_ignored = tf.concat(input_batch.ndims - 1, 
    [tf.zeros_like(input_batch[:, :, :, :ignore_label]), 
    tf.expand_dims(input_batch[:, :, :, ignore_label], -1), 
    tf.zeros_like(input_batch[:, :, :, ignore_label + 1:])]) 
# Make corresponding logits "infinity" (a big enough number) 
predictions_fix = tf.select(input_batch_ignored > 0, 
    1e30 * tf.ones_like(predictions), predictions) 
# Compute loss with fixed logits 
loss = tf.nn.softmax_cross_entropy_with_logits(prediction, gt) 

唯一の問題これは、無視されたクラスのピクセルが常に正しく予測されることを考慮していることです。つまり、多くのピクセルを含む画像の損失は人為的に小さくなります。場合によっては、これが重要な場合もあれば重要でない場合もありますが、実際に正確にしたい場合は、平均をとるのではなく無視されないピクセルの数に応じて各画像の損失を重み付けする必要があります。

# Count relevant pixels on each image 
input_batch_relevant = 1 - input_batch_ignored 
input_batch_weight = tf.reduce_sum(input_batch_relevant, [1, 2, 3]) 
# Compute relative weights 
input_batch_weight = input_batch_weight/tf.reduce_sum(input_batch_weight) 
# Compute reduced loss according to weights 
reduced_loss = tf.reduce_sum(loss * input_batch_weight) 
+0

私は申し訳ありませんが、私は答えを完全に理解していません: 'input_batch_ignored = tf.concat(...)'の出力はどのように見えますか?ラベルチャネルを無視する以外は、すべてのチャネル(C)に ''ゼロ 'を含む 'predict'(N x H x W x C)と同じ形をしているようです。しかし、それはイメージのすべてのピクセルのignore_labelクラスを正しく予測することを意味しますか?私は 'gt_label'として' ignore_label'を持っているピクセルだけを選択する必要があると思います。だから私はそれらのラベルのインデックスを取得するためにMatlabの '(myLabelTensor == ignore_label)のような操作が必要です... – mcExchange

+0

@mcExchangeあなたが言ったように、' input_batch_ignored'は無視されたクラスを除いてすべてゼロです。 'input_batch'は保持されます。これは無限大で乗算され、ロジットに追加され、無視されたクラスのピクセルが常に正しいように予測を効果的に変更します(無限大が悪い結果をもたらす可能性があり、その代わりに十分大きな数値を使用する必要があります)。これは、これらのピクセルが最終コストに0を与えることを意味します。 – jdehesa

+1

@mcExchange無視されたラベルを-1に置き換えたい場合、 'label_wo_ignored = tf.select(label!= ignore_label、label、-1 * tf.ones_like(label))'のようなことをすることができますが、確かにそれはあなたが望む損失を与える(私は少なくともドキュメントによるとは限りません)。 – jdehesa

関連する問題