2016-05-31 13 views
3

Tensorflowの使用中に、チェックポイントファイルを使用してCIFAR10トレーニングを再開しようとしています。いくつかの他の記事を参照して、私はtf.train.Saver()を試してみました。誰かが私にどのように進行するかについて光を当てることができますか? Tensorflow CIFAR10テンソルフローcifar10チェックポイントファイルからトレーニングを再開

def train(): 
    # methods to build graph from the cifar10_train.py 
    global_step = tf.Variable(0, trainable=False) 
    images, labels = cifar10.distorted_inputs() 
    logits = cifar10.inference(images) 
    loss = cifar10.loss(logits, labels) 
    train_op = cifar10.train(loss, global_step) 
    saver = tf.train.Saver(tf.all_variables()) 
    summary_op = tf.merge_all_summaries() 

    init = tf.initialize_all_variables() 
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) 
    sess.run(init) 


    print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir) 

    if FLAGS.checkpoint_dir is None: 
    # Start the queue runners. 
    tf.train.start_queue_runners(sess=sess) 
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) 
    else: 
    # restoring from the checkpoint file 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 

    # cur_step prints out well with the checkpointed variable value 
    cur_step = sess.run(global_step); 
    print("current step is %s" % cur_step) 

    for step in xrange(cur_step, FLAGS.max_steps): 
    start_time = time.time() 
    # **It stucks at this call ** 
    _, loss_value = sess.run([train_op, loss]) 
    # below same as original 

答えて

2

問題から

コードスニペットは、この行ことのようだ:

tf.train.start_queue_runners(sess=sess) 

...のみFLAGS.checkpoint_dir is None場合に実行されます。チェックポイントから復元する場合は、キューランナーを開始する必要があります。 tf.train.Saver(原因コードのリリースバージョンの競合状態に)作成後、私はあなたがキューがランナースタートをお勧めしたい

注意は、その優れた構造は次のようになります。

if FLAGS.checkpoint_dir is not None: 
    # restoring from the checkpoint file 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 

# Start the queue runners. 
tf.train.start_queue_runners(sess=sess) 

# ... 

for step in xrange(cur_step, FLAGS.max_steps): 
    start_time = time.time() 
    _, loss_value = sess.run([train_op, loss]) 
    # ... 
+0

感謝あなたは答えのために!それは問題を解決しました。私はqueue_runnerが(歪みによって)入力画像を作成する責任があると思っていました。チェックポイントファイルから復元するのに必要なステップではありません。 – emerson

関連する問題