1つのパラメータサーバーに格納され、各ワーカーによって(非同期に)インクリメントされるグローバルカウンタを実装しようとする(分散しようとする)2つのバージョンのTensorflowコードがあります。分散Tensorflowはどのようにこの種のtf.Variable作成を処理しますか?
どちらのバージョンも同じものを表示するようですが、その理由はわかりません。バージョンの違いは、コメントで示される2行、# NEW
です。
各ワーカーがバージョン1を実行しているとき、パラメータサーバはワーカーごとに自動的にlocal_counter
tf.Variable
を保存しますか?
バージョン2では、それぞれlocal_counter
tf.Variable
を明示的にパラメーターサーバーに配置しようとしています。
次のバージョン1またはバージョン2では、実際に違いはありますか?
PS:これはすべてのインスタンスで共有されているtf.Variable
を管理する最善の方法ではないと私は確信しています。ありがとう!
バージョン1
# Standard distributed Tensorflow boilerplate
# ...
elif FLAGS.job_name == 'worker':
TASK = FLAGS.task_index
with tf.device('/job:ps/task:0/cpu:0'):
with tf.variable_scope('global'):
global_counter = tf.Variable(0, name='global_counter',
trainable=False)
local_counter = tf.Variable(0, name='local_counter_{}'.format(TASK),
trainable=False)
init_op = tf.global_variables_initializer()
with tf.device('/job:worker/task:{}'.format(TASK)):
with tf.variable_scope('local'):
local_inc_op = local_counter.assign_add(1)
global_inc_op = global_counter.assign_add(1)
with tf.Session(server.target):
sess.run(init_op)
global_count = 0
while global_count < 1000:
sess.run([local_inc_op, global_inc_op])
local_count, global_count = sess.run([local_counter, global_counter])
print('Local {}, Global {}, worker-{}'.format(
local_count, global_count, TASK))
バージョン2
# Standard distributed Tensorflow boilerplate
# ...
elif FLAGS.job_name == 'worker':
NUM_WORKERS = len(worker_hosts)
TASK = FLAGS.task_index
with tf.device('/job:ps/task:0/cpu:0'):
with tf.variable_scope('global'):
global_counter = tf.Variable(0, name='global_counter',
trainable=False)
local_counters = [tf.Variable(0, name='local_counter_{}'.format(i),
trainable=False)
for i in range(NUM_WORKERS)] # NEW
init_op = tf.global_variables_initializer()
with tf.device('/job:worker/task:{}'.format(TASK)):
with tf.variable_scope('local'):
local_counter = local_counters[TASK] # NEW
local_inc_op = local_counter.assign_add(1)
global_inc_op = global_counter.assign_add(1)
with tf.Session(server.target):
sess.run(init_op)
global_count = 0
while global_count < 1000:
sess.run([local_inc_op, global_inc_op])
local_count, global_count = sess.run([local_counter, global_counter])
print('Local {}, Global {}, worker-{}'.format(
local_count, global_count, TASK))
Gotchaので、違いは労働者のグラフです。さもなければ、機能は基本的に同一です。ありがとう、アレン!私はResourceMgrを調べます。 – awalllllll