2017-10-24 11 views
0

私は非常にpytorchに新しく、CNNに画像ではなく行列を入力する方法を理解したいと思います。 私は次のように試しましたが、いくつかのエラーが発生します。pytorchでCNNに行列を入力する方法

class FrameDataSet(tud.Dataset): 
    def __init__(self, data): 
     targets = data['class'].values.tolist() 
     features = data.drop('class', axis=1).astype(np.int64).values 

     self.datalist = features.reshape((-1, feature_num, frame_size)) 
     self.labellist = targets 

    def __getitem__(self, index): 
     return torch.Tensor(self.datalist[index].astype(float)), self.labellist[index] 

    def __len__(self): 
     return self.datalist.shape[0] 

そして、私のCNNは、次のとおりです: 私は次のように私のデータセットを定義

self.conv = nn.Sequential(
     nn.Conv2d(1, 12, 3), 
     nn.ReLU(True), 
     nn.MaxPool2d(3, 3)) 
self.fc1 = nn.Linear(80, 100) 
self.fc2 = nn.Linear(100, 30) 
self.fc3 = nn.Linear(30, 5) 

しかし、データがCNNに入力されたときに、エラーがもたらし:

ファイル「/家庭を/sparks/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py "、48行目、conv2d raise ValueError("予想される4Dテンソルを入力として、代わりに{} Dテンソルを得ました。 " (input.dim())) Exp 4Dテンソルを入力として、代わりに3Dテンソルを得ました。

答えて

2

あなたの入力におそらく1次元がありません。それは次のようになります。

(BATCH_SIZE、チャンネル、幅、高さ)

あなただけのバッチで一つの要素を持っている場合は、テンソルは、あなたのケース

などでなければなりません(1,1,28,28)

あなたの最初のconv2dレイヤーは1チャンネルの入力を期待していたからです。

+0

Cool!できます。ありがとう〜 –

関連する問題