2016-10-29 3 views
0

テンソルフローのキューにデータをプッシュするコードを作成しました。キューハンドラのinitとすべてのスレッドが実行するmain関数データをプッシュするスレッドの数を増やすとテンソルフローのキューがいっぱいになります

def __init__(self): 
    self.X = tf.placeholder(tf.int64) 
    self.Y = tf.placeholder(tf.int64) 
    self.queue = tf.RandomShuffleQueue(dtypes=[tf.int64, tf.int64], 
             capacity=100, 
             min_after_dequeue=20) 

    self.enqueue_op = self.queue.enqueue([self.X, self.Y]) 


def thread_main(self, sess, coord): 
    """Cycle through the dataset until the main process says stop.""" 
    train_fs = open(data_train, 'r') 
    while not coord.should_stop(): 
     X_, Y_ = get_batch(train_fs) 
     if not Y: #We're at the end of the file 
      train_fs = open(data_train, 'r') 
      X, Y = get_batch(train_fs) 
     sess.run(self.enqueue_op, feed_dict={self.X:X_, self.Y:Y_}) 

トレーニング中にキューのサイズを監視しています。何らかの理由で、データをプッシュするスレッドの数を増やすと、キューがいっぱいになります。どんな考え?それは私が同時にpythonファイルから読んでいるからですか?

編集:ここでは

は、私は、データと、それはまったく同じであるグラフの横に、使用していたコードです。このダミーデータでは、コードが期待どおりに動作しています。私は2つの観測があります

  • 私は適切にスレッドを閉じていないと思うが、彼らが実行した後、キューに立ち往生し、より多くの私はそれを取得するコードより低速を実行しますようです。
  • マルチスレッドはここで仕事をしているので、私のグラフとデータを読み取る方法は、2つの失敗の唯一の点だと思います。

まず、ダミーのデータセットを生成します

data_train = "./test.txt" 

with open(data_train, 'w') as out_stream: 
    out_stream.write("""[1,2,3,4,5,6]|1\n[1,2,3,4]|2\n[1,2,3,4,5,6]|0\n[1,2,3,4,5,6]|1\n[1,2,5,6]|1\n[1,2,5,6]|0""") 

def get_batch(fs): 
    line = fs.readline() 
    X, Y = line.split('|') 
    X = eval(X) 
    Y = eval(Y) 
    return X, Y 

を次にキューコントローラ:

queue_ctrl = QueueCtrl() 
X_, Y_ = queue_ctrl.get_batch_from_queue() 
output = Y_ * tf.reduce_sum(X_) 
init = tf.initialize_all_variables() 

最後に、我々は繰り返す:

import tensorflow as tf 
import numpy as np 
import threading 

tf.reset_default_graph()#Reset the graph essential to use with jupyter else variable conflicts 

class QueueCtrl(object): 

    def __init__(self): 
     self.X = tf.placeholder(tf.int64) 
     self.Y = tf.placeholder(tf.int64) 
     self.queue = tf.RandomShuffleQueue(dtypes=[tf.int64, tf.int64], 
              capacity=100, 
              min_after_dequeue=20) 

     self.enqueue_op = self.queue.enqueue([self.X, self.Y]) 


    def thread_main(self, sess, coord): 
     """Cycle through the dataset until the main process says stop.""" 
     train_fs = open(data_train, 'r') 
     while not coord.should_stop(): 
      X_, Y_ = get_batch(train_fs) 
      if not Y_: #We're at the end of the file 
       train_fs = open(data_train, 'r') 
       X_, Y_ = get_batch(train_fs) 
      sess.run(self.enqueue_op, feed_dict={self.X:X_, self.Y:Y_}) 

    def get_batch_from_queue(self): 
     """ 
     Return one batch 
     """ 
     return self.queue.dequeue() 

    def start_threads(self, sess, coord, num_threads=2): 
     """Start the threads""" 
     threads = [] 
     for _ in range(num_threads): 
      t = threading.Thread(target=self.thread_main, args=(sess, coord)) 
      t.daemon = True 
      t.start() 
      threads.append(t) 
     return threads 

はその後、我々はダミーグラフを構築データ:

ここ
sess = tf.Session() 

sess.run(init) 
coord = tf.train.Coordinator() 
tf.train.start_queue_runners(sess=sess, coord=coord) 
my_thread = queue_ctrl.start_threads(sess, coord, num_threads=6) 

for i in range(100): 
    out = sess.run(output) 
    print("Iter: %d, output: %d, Element in queue: %d" 
       % (i, out, sess.run(queue_ctrl.queue.size()))) 

coord.request_stop() 
for _ in range(len(my_thread)): #if the queue is full at that time then the threads won't see the coord.should_stop 
    _ = sess.run([output]) 

coord.join(my_thread, stop_grace_period_secs=10) 
sess.close() 

5つのスレッドを持つ25の最初に出力されます。一つのスレッドで

Iter: 0, output: 21, Element in queue: 27 
Iter: 1, output: 21, Element in queue: 37 
Iter: 2, output: 20, Element in queue: 51 
Iter: 3, output: 21, Element in queue: 67 
Iter: 4, output: 20, Element in queue: 81 
Iter: 5, output: 20, Element in queue: 89 
Iter: 6, output: 21, Element in queue: 100 
Iter: 7, output: 20, Element in queue: 100 
Iter: 8, output: 20, Element in queue: 100 
Iter: 9, output: 21, Element in queue: 100 
Iter: 10, output: 20, Element in queue: 100 
Iter: 11, output: 20, Element in queue: 100 
Iter: 12, output: 21, Element in queue: 100 
Iter: 13, output: 21, Element in queue: 100 
Iter: 14, output: 20, Element in queue: 100 
Iter: 15, output: 20, Element in queue: 100 
Iter: 16, output: 21, Element in queue: 100 
Iter: 17, output: 21, Element in queue: 100 
Iter: 18, output: 20, Element in queue: 100 
Iter: 19, output: 21, Element in queue: 100 
Iter: 20, output: 21, Element in queue: 100 
Iter: 21, output: 21, Element in queue: 100 
Iter: 22, output: 20, Element in queue: 100 
Iter: 23, output: 21, Element in queue: 100 
Iter: 24, output: 21, Element in queue: 100 
Iter: 25, output: 21, Element in queue: 100 

Iter: 0, output: 21, Element in queue: 22 
Iter: 1, output: 20, Element in queue: 25 
Iter: 2, output: 20, Element in queue: 27 
Iter: 3, output: 20, Element in queue: 29 
Iter: 4, output: 21, Element in queue: 31 
Iter: 5, output: 20, Element in queue: 32 
Iter: 6, output: 20, Element in queue: 34 
Iter: 7, output: 21, Element in queue: 35 
Iter: 8, output: 21, Element in queue: 36 
Iter: 9, output: 21, Element in queue: 38 
Iter: 10, output: 20, Element in queue: 40 
Iter: 11, output: 20, Element in queue: 42 
Iter: 12, output: 20, Element in queue: 43 
Iter: 13, output: 21, Element in queue: 46 
Iter: 14, output: 20, Element in queue: 47 
Iter: 15, output: 21, Element in queue: 48 
Iter: 16, output: 20, Element in queue: 53 
Iter: 17, output: 20, Element in queue: 56 
Iter: 18, output: 21, Element in queue: 57 
Iter: 19, output: 21, Element in queue: 61 
Iter: 20, output: 21, Element in queue: 63 
Iter: 21, output: 20, Element in queue: 67 
Iter: 22, output: 21, Element in queue: 70 
Iter: 23, output: 21, Element in queue: 73 
Iter: 24, output: 20, Element in queue: 76 
Iter: 25, output: 20, Element in queue: 78 
+1

スレッドがキューへのアクセスが競合しているため、最も可能性が高いです。小さな再現可能な例を投稿できますか? – fabrizioM

+0

私はオフィスを去りました。私は明日もコードに直面し、質問を更新します。また、RandomShuffleQueueの代わりにFIFOキューを使用して、前処理がCPUによって行われていることを確認します。 – user3091275

+0

もう少しコードを追加しましたが、まだ調査中です。 – user3091275

答えて

2

はちょうど私がマルチベースデータを実装し、ここで何かを追加したいですマルチタスク学習のためのパイプラインの供給。それは平均を達成することができます。 GPU使用率> 90%、クアッドコアCPU使用率> 95%。メモリリークを起こしにくく、日々のトレーニングに特に適しています。完璧だと言っているわけではありませんが、少なくとも私の場合は現在のTFキューAPI(1.1)よりも優れています。

もし興味がある人:https://hanxiao.github.io/2017/07/07/Get-10x-Speedup-in-Tensorflow-Multi-Task-Learning-using-Python-Multiprocessing/

関連する問題