2016-10-03 27 views
0

W(重み)とb(偏り)はtf.Variableを使用しましたが、X(入力バッチ)とY(このバッチの期待値) )。 すべて正常に動作します。しかし、今日、私はこのトピックを見つけました: Tensorflow github issues と引用:列車データにtf.placeholderの代わりにtensorflow tf.Variableを使用

Feed_dictはTensorFlowランタイムにPythonランタイムから内容のシングルスレッドのmemcpyを行います。 GPUでデータが必要な場合は、さらにCPU-> GPU転送が必要になります。私は

ネイティブTensorFlow(可変/キュー)にfeed_dictから切り替えたときにパフォーマンスが10倍向上させるまで見慣れてそして今、私は、入力データなしfeed_dictためtf.Variableまたはキューを使用する方法を見つけようとしています速度向上のために、特にバッチのために。私は1つずつデータを変更する必要があります。すべてのバッチが完了したとき - エポックの終わり。そして初めから、第2の時代より...

しかし、どうして私はそれを使うことができないのですか?ここで

+0

は、このチュートリアルhttps://www.tensorflow.org/versions/r0.11からcifar10_input.pyを見ます/tutorials/deep_cnn/index.html – mdaoust

答えて

1

は、あなたがトレーニングバッチを養うためにキューを使用する方法の自己完結型の例である:

import tensorflow as tf 

IMG_SIZE = [30, 30, 3] 
BATCH_SIZE_TRAIN = 50 

def get_training_batch(batch_size): 
    ''' training data pipeline -- normally you would read data from files here using 
    a TF reader of some kind. ''' 
    image = tf.random_uniform(shape=IMG_SIZE) 
    label = tf.random_uniform(shape=[]) 

    min_after_dequeue = 100 
    capacity = min_after_dequeue + 3 * batch_size 
    images, labels = tf.train.shuffle_batch(
     [image, label], batch_size=batch_size, capacity=capacity, 
     min_after_dequeue=min_after_dequeue) 
    return images, labels 

# define the graph 
images_train, labels_train = get_training_batch(BATCH_SIZE_TRAIN) 
'''inference, training and other ops generally are defined here too''' 

# start a session 
with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 

    ''' do something interesting here -- training, validation, etc''' 
    for _ in range(5): 
     # typical training step where batch data are drawn from the training queue 
     py_images, py_labels = sess.run([images_train, labels_train]) 
     print('\nData from queue:') 
     print('\tImages shape, first element: ', py_images.shape, py_images[0][0, 0, 0]) 
     print('\tLabels shape, first element: ', py_labels.shape, py_labels[0]) 

    # close threads 
    coord.request_stop() 
    coord.join(threads) 
関連する問題