2017-12-14 3 views
0

私はpix2pixHDの事前訓練モデルを使用して自分の画像を生成しようとしています。 Github repo found here自分のデータセットでpix2pixHDエラー

データセット内の画像は、アルファチャンネルのないグレースケールである必要があります。レポ内の画像は16 bitPerSampleのサイズを持ち、サイズ8と16 bitPerSampleの両方の画像を持っています。

sips -g allを使用して自分の画像とレポの画像をチェックすると、

pixelWidth: 2048 
pixelHeight: 1024 
typeIdentifier: public.png 
format: png 
formatOptions: default 
dpiWidth: 72.000 
dpiHeight: 72.000 
samplesPerPixel: 1 
bitsPerSample: 16 
hasAlpha: no 
space: Gray 

奇妙なことは、それが8 bitPerSampleを持っているイメージで動作することである:これは私が得る結果です。私はtest.py 16とは、bitsPerSampleイメージを実行すると、それは動作しません output

グレースケール入力 grayscale 変換されたラベルマップ Input 最終的な出力: これは私が得る結果です。 これは、私を与えるエラーです:

model [Pix2PixHDModel] was created 
Traceback (most recent call last): 
    File "test.py", line 26, in <module> 
    for i, data in enumerate(dataset): 
    File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 210, in __next__ 
    return self._process_next_batch(batch) 
    File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 230, in _process_next_batch 
    raise batch.exc_type(batch.exc_msg) 
TypeError: Traceback (most recent call last): 
    File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop 
    samples = collate_fn([dataset[i] for i in batch_indices]) 
    File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in <listcomp> 
    samples = collate_fn([dataset[i] for i in batch_indices]) 
    File "/home/paperspace/Documents/pix2pixHD/data/aligned_dataset.py", line 41, in __getitem__ 
    label_tensor = transform_label(label) * 255.0 
    File "/usr/local/lib/python3.5/dist-packages/torch/tensor.py", line 309, in __mul__ 
    return self.mul(other) 
TypeError: mul received an invalid combination of arguments - got (float), but expected one of: 
* (int value) 
     didn't match because some of the arguments have invalid types: (float) 
* (torch.IntTensor other) 
     didn't match because some of the arguments have invalid types: (float) 

私はかなりTensorflowに新しいですし、私が前にpytorchを使用したことがありません。

このエラーの意味は何ですか、どうすれば解決できますか?

答えて

0

はい、私はあなたを助けることができると思います。 私はリポジトリをチェックし、エラーから問題を追跡していないが、以下のように見える:

あなたはtransform_label(label)の出力(おそらくテンソル)とスカラー255.0 betweenn乗算演算を実行しています。あなたのスカラーとテンソルが同じである限り、これは問題ありませんdatatype。しかし、エラートレースからは、transform_label()の出力がデータ型がInt/Longであるかのように見えますが、255.0は浮動小数点です。

255.0の代わりに255またはint(255.0)をお試しください。

これで問題が解決しない場合は、transform_label()の出力のデータタイプを教えてください。

関連する問題