2017-12-08 1 views
1

私は5D入力テンソルで3D畳み込みを実行するネットワークを持っています。サイズ(1,12,60,36,60)が(BatchSize、NumClasses、x-dim、y-dim、z-dim)に対応する場合、私のネットワークの出力。私は、ボクセルのクロスエントロピー損失を計算する必要があります。しかし、私は間違いを続けています。例PytorchでのセマンティックセグメンテーションのためのCrossEntropyLoss

torch.nn.CrossEntropyLoss()を用いたクロスエントロピー損失を計算しようと、私は次のエラーメッセージを取得し続ける:ここ

RuntimeError: multi-target not supported at .../src/THCUNN/generic/ClassNLLCriterion.cu:16 

は私のコードの抜粋です:私が作成したときに

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
criterion = torch.nn.CrossEntropyLoss() 
images = Variable(torch.randn(1, 12, 60, 36, 60)).cuda() 
labels = Variable(torch.zeros(1, 12, 60, 36, 60).random_(2)).long().cuda() 
loss = criterion(images.view(1,-1), labels.view(1,-1)) 

同じことが起こりますラベルのためのワンホットテンソル:

nclasses = 12 
labels = (np.random.randint(0,12,(1,60,36,60))) # Random labels with values between [0..11] 
labels = (np.arange(nclasses) == labels[..., None] - 1).astype(int) # Converts labels to one_hot_tensor 
a = np.transpose(labels,(0,4,3,2,1)) # Reorder dimensions to match shape of "images" ([1, 12, 60, 36, 60]) 
b = Variable(torch.from_numpy(a)).cuda() 
loss = criterion(images.view(1,-1), b.view(1,-1)) 

私は何ですか?間違っている? 誰かが5D出力テンソルでクロスエントロピーを計算する例を提供できますか?

答えて

0

docsこの動作を説明するには、(一番下の行は、それが実際には、これにより、出力のすべての次元のターゲットを必要としない、スパースクロスエントロピー損失を計算しているように見えますが、必要な1のインデックスのみ)...彼らは、具体的な状態:

Input: (N,C), where C = number of classes 
Target: (N), where each value is 0 <= targets[i] <= C-1 
Output: scalar. If reduce is False, then (N) instead. 

私はあなたのユースケースについてはよく分からないんだけど、あなたの代わりにKL DivergenceまたはBinary Cross Entropy Lossを使用する場合があります。どちらも入力と同じサイズのターゲットで定義されます。

0

2Dセマンティックセグメンテーションの実装(fcn)を確認し、3Dセマンティックセグメンテーションに適応させようとしました。これが正しいことを保証するものではありません。私は再度確認する必要があります。

import torch 
import torch.nn.functional as F 
def cross_entropy3d(input, target, weight=None, size_average=True): 
    # input: (n, c, h, w, z), target: (n, h, w, z) 
    n, c, h, w , z = input.size() 
    # log_p: (n, c, h, w, z) 
    log_p = F.log_softmax(input, dim=1) 
    # log_p: (n*h*w*z, c) 
    log_p = log_p.permute(0, 4, 3, 2, 1).contiguous().view(-1, c) # make class dimension last dimension 
    log_p = log_p[target.view(n, h, w, z, 1).repeat(1, 1, 1, 1, c) >= 0] # this looks wrong -> Should rather be a one-hot vector 
    log_p = log_p.view(-1, c) 
    # target: (n*h*w*z,) 
    mask = target >= 0 
    target = target[mask] 
    loss = F.nll_loss(log_p, target.view(-1), weight=weight, size_average=False) 
    if size_average: 
     loss /= mask.data.sum() 
    return loss 
images = Variable(torch.randn(5, 3, 16, 16, 16)) 
labels = Variable(torch.LongTensor(5, 16, 16, 16).random_(3)) 
cross_entropy3d(images, labels, weight=None, size_average=True) 
関連する問題