2016-12-22 7 views
0

私は桁認識のために畳み込みニューラルネットワークを行っています。私は画像データセットを訓練したいが、どのように"バッチ"トレーニングデータを知りません。Tensorflow - データセットをバッチする方法

私はtrain_imageとtrain_labelを保存する二つの配列を取得:私はtrain_image_batch

を印刷するときのバッチサイズ= 50

sess.run(tf.initialize_all_variables()) 
    train_image_batch, train_label_batch = tf.train.shuffle_batch([train_image, 
     train_label, batch_size = 50, capacity = 50000, min_after_dequeue = 10000) 

と私はトレーニングデータバッチにしたい、今すぐ

print train_image.shape 
# (73257, 1024) 
# where I have 73257 images with size 32x32=1024 

print train_label.shape 
# (73257, 10) 
# Digit '1' has label 1, '9' has label 9 and '0' has label 10 

print train_image_batch 
# Tensor("shuffle_batch:0", shape=(50, 73257, 1024), dtype=unit8) 

形状が(50, 1024)

ここで何か問題がありますか?

+0

参照してください。 'total_size * image_size'次元は期待していません。トレーニングサンプルは、 'RecordReader()'のようなものを通してエンキューされることを期待しています。 [このブログ](https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/)をチェックしてください。 –

答えて

1

shuffle_batchは、デフォルトで1つのサンプルが必要です。複数のサンプルを受け入れるように強制するには、enqueue_many=Trueを渡します。あなたは() ``間違ったtf.train.shuffle_batchを使用しているdoc

train_image_batch, train_label_batch = tf.train.shuffle_batch(
    [train_image, train_label], batch_size = 50, enqueue_many=True, capacity = 50000, min_after_dequeue = 10000) 

print(train_image_batch.shape) 

Output: 
(50, 1024) 
関連する問題