2016-03-26 7 views
0

2つのサブグラフ間で変数を共有したいと思います。より正確には、私はfowolling操作を行うにはしたいと思います:4つのテンソルabcdw可変重み与え、W*aW*bW*cW*dを計算するが、異なるサブグラフに。私が持っているコードは、以下の通りである:テンソルフローと変数共有でのvariable_scopeとname_scopeの理解

def forward(inputs): 
    w = tf.get_variable("weights", ...) 
    return tf.matmult(w, inputs) 

with tf.name_scope("group_1"): 
    a = tf.placeholder(...) 
    b = tf.placeholder(...) 
    c = tf.placeholder(...) 

    aa = forward(a) 
    bb = forward(b) 
    cc = forward(c) 

with tf.name_scope("group_2): 
    d = tf.placeholder(...) 

    tf.get_variable_scope().reuse_variable() 
    dd = forward(d) 

この例では、実行するようだが、私は、私は何の変数がないというエラーを得たtf.get_variable_scope.reuse_variable()を追加すると、変数Wgroup_1内、特に再利用されているかどうか私はわからないんだけど共有する。 テンソルボードでグラフを視覚化すると、サブグラフの中にいくつかのweigths_*があります。

+1

スケルトンコード( 'matmult'などのタイプミスを含む)ではなく、実際のコードを提供するのに本当に役立ちます。さらに、コードが "実行されているようだ"と言っているが、再利用したいときにいつも明示的に 'reuse_variables()'を指定しなければならないので、タイプミスを修正した後でも、 'bb = forward(b)変数。作業コードについては私の答えを見てください。 – MiniQuark

答えて

1

次のコードは、あなたが欲しいものを行います。

import tensorflow as tf 

def forward(inputs): 
    init = tf.random_normal_initializer() 
    w = tf.get_variable("weights", shape=(3,2), initializer=init) 
    return tf.matmul(w, inputs) 

with tf.name_scope("group_1"): 
    a = tf.placeholder(tf.float32, shape=(2, 3), name="a") 
    b = tf.placeholder(tf.float32, shape=(2, 3), name="b") 
    c = tf.placeholder(tf.float32, shape=(2, 3), name="c") 
    with tf.variable_scope("foo", reuse=False): 
     aa = forward(a) 
    with tf.variable_scope("foo", reuse=True): 
     bb = forward(b) 
     cc = forward(c) 

with tf.name_scope("group_2"): 
    d = tf.placeholder(tf.float32, shape=(2, 3), name="d") 
    with tf.variable_scope("foo", reuse=True): 
     dd = forward(d) 

init = tf.initialize_all_variables() 

with tf.Session() as sess: 
    sess.run(init) 
    print(bb.eval(feed_dict={b: np.array([[1,2,3],[4,5,6]])})) 
    for var in tf.all_variables(): 
     print(var.name) 
     print(var.eval()) 

理解するためのいくつかの重要な事柄:

  • name_scope()get_variable()で作成された変数を除くすべてのOPS に影響を与えます。
  • スコープ内に変数を配置するには、variable_scope()を使用する必要があります。たとえば、プレースホルダab、およびcは実際に"group_1/a""group_1/b""group_1/c""group_1/d"を命名しているが、weights変数が"foo/weights"という名前です。 したがって、という名前のスコープのget_variable("weights")と、可変スコープ"foo"は実際には"foo/weights"を検索します。

all_variables()機能は、どのような変数が存在し、どのように名前が付けられているかわからない場合に便利です。

関連する問題