私は5D入力テンソルで3D畳み込みを実行するネットワークを持っています。サイズ(1,12,60,36,60)が(BatchSize、NumClasses、x-dim、y-dim、z-dim)に対応する場合、私のネットワークの出力。私は、ボクセルのクロスエントロピー損失を計算する必要があります。しかし、私は間違いを続けています。例PytorchでのセマンティックセグメンテーションのためのCrossEntropyLoss
torch.nn.CrossEntropyLoss()
を用いたクロスエントロピー損失を計算しようと、私は次のエラーメッセージを取得し続ける:ここ
RuntimeError: multi-target not supported at .../src/THCUNN/generic/ClassNLLCriterion.cu:16
は私のコードの抜粋です:私が作成したときに
import torch
import torch.nn as nn
from torch.autograd import Variable
criterion = torch.nn.CrossEntropyLoss()
images = Variable(torch.randn(1, 12, 60, 36, 60)).cuda()
labels = Variable(torch.zeros(1, 12, 60, 36, 60).random_(2)).long().cuda()
loss = criterion(images.view(1,-1), labels.view(1,-1))
同じことが起こりますラベルのためのワンホットテンソル:
nclasses = 12
labels = (np.random.randint(0,12,(1,60,36,60))) # Random labels with values between [0..11]
labels = (np.arange(nclasses) == labels[..., None] - 1).astype(int) # Converts labels to one_hot_tensor
a = np.transpose(labels,(0,4,3,2,1)) # Reorder dimensions to match shape of "images" ([1, 12, 60, 36, 60])
b = Variable(torch.from_numpy(a)).cuda()
loss = criterion(images.view(1,-1), b.view(1,-1))
私は何ですか?間違っている? 誰かが5D出力テンソルでクロスエントロピーを計算する例を提供できますか?