2017-11-03 3 views
-1

以下は私のCNNです。それの入力は(3,64)行列です。x、y、z軸それぞれを処理するために3つの畳み込みカーネルを使用したいと思います。Pytorchの3 * nマトリックスの3軸に3つのConv1dを使用するにはどうすればよいですか?

class Char_CNN(nn.Module): 
    def __init__(self): 
     super(Char_CNN, self).__init__() 
     self.convdx = nn.Conv1d(1, 12, 20) 
     self.convdy = nn.Conv1d(1, 12, 20) 
     self.convdz = nn.Conv1d(1, 12, 20) 
     self.fc1 = nn.Linear(540, 1024) 
     self.fc2 = nn.Linear(1024, 30) 
     self.fc3 = nn.Linear(30, 13) 

    def forward(self, x): 
     after_convd = [self.convdx(x[:, :, 0]), self.convdy(x[:, :, 1]), self.convdz(x[:, :, 2])] 
     after_pool = [F.max_pool1d(F.relu(value), 3) for value in after_convd] 

     x = torch.cat(after_pool, 1) 
     x = x.view(x.size(0), -1) 
     x = self.fc1(x) 
     x = self.fc2(x) 
     x = self.fc3(x) 
     x = F.softmax(x) 
     return x 

しかしloss = criterion(out, target)の走行時に、ランタイムエラーが発生します。

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.

私は私のコードの間違いを見つけることができないように、私はpytorchに非常に新しいです。 私を助けることができますか?

答えて

0

畳み方は大丈夫です。問題は私のラベルが1から13の間であり、正しい範囲が0から12です。 それを修正した後、私のCNNは正常に動作します。 しかし、Pytorchや深い学習にはより充実しているので、私の畳み込みモードはより明確で簡単になると思います。ようこそ私のエラーを指摘!

+0

あなたの質問は不明です。あなたが遭遇したエラーは、あなたのcriterion関数がターゲットラベルを0から12と期待しているが、ラベルを1から13として提供しているからです。あなたがやっていることの詳細を提供できますか?入力は何ですか?私たちが助けることができるように説明してください。 –

+0

申し訳ありませんが、問題は解決され、コードは正常に実行されています。 –

関連する問題