2016-08-26 14 views
5

私がしたいのは、複数のトレーニング済みTensorflowネットを同時に実行することです。各ネット内のいくつかの変数の名前は同じである可能性があるため、ネットを作成するときには名前スコープを使用するのが一般的な解決策です。しかし、問題は、これらのモデルを訓練し、訓練された変数をいくつかのチェックポイントファイルの中に保存することです。ネットを作成するときに名前スコープを使用した後、チェックポイントファイルから変数を読み込むことができません。複数のトレーニング済みTensorflowネットを同時に実行する

たとえば、私はAlexNetを訓練しましたが、2つの変数セットを比較したいと思います.1つはエポック10(ファイルepoch_10.ckptに保存されます)と別のセットはエポック50ファイルepoch_50.ckpt)。これらの2つはまったく同じネットなので、内部の変数の名前は同じです。私はこのネットの訓練を受けたとき、私は名前のスコープを使用していなかったので、しかし、私は.ckptファイルから訓練された変数をロードすることはできません

with tf.name_scope("net1"): 
    net1 = CreateAlexNet() 
with tf.name_scope("net2"): 
    net2 = CreateAlexNet() 

を使用して、2つのネットを作成することができます。ネットを訓練するときに名前スコープを "net1"に設定することはできますが、これによってnet2の変数をロードできなくなります。

私が試してみました:

with tf.name_scope("net1"): 
    mySaver.restore(sess, 'epoch_10.ckpt') 
with tf.name_scope("net2"): 
    mySaver.restore(sess, 'epoch_50.ckpt') 

これは動作しません。

この問題を解決する最善の方法は何ですか?

答えて

10

最も簡単な解決策は、各モデルのために別々のグラフを使用し、異なるセッションを作成することです。これはいくつかの理由で動作しない場合

# Build a graph containing `net1`. 
with tf.Graph().as_default() as net1_graph: 
    net1 = CreateAlexNet() 
    saver1 = tf.train.Saver(...) 
sess1 = tf.Session(graph=net1_graph) 
saver1.restore(sess1, 'epoch_10.ckpt') 

# Build a separate graph containing `net2`. 
with tf.Graph().as_default() as net2_graph: 
    net2 = CreateAlexNet() 
    saver2 = tf.train.Saver(...) 
sess2 = tf.Session(graph=net1_graph) 
saver2.restore(sess2, 'epoch_50.ckpt') 

を、そしてあなたは、単一のtf.Sessionを使用する必要があります(たとえば、

    すでに実行しているように名前スコープに異なるネットワークを作成し、0を指定すると、2つのネットワークの結果を別のTensorFlow計算で結合したいからです。
  1. 別のtf.train.Saverインスタンスを2つのネットワークに作成し、変数名を再マップするための追加の引数を付けます。セーバーをconstructingとき

は、あなたがそれぞれのモデルで作成したtf.Variableオブジェクトへのチェックポイント(すなわち、名前スコープ接頭辞なし)での変数の名前をマッピングし、var_list引数として辞書を渡すことができます。

あなたは、プログラムvar_listを構築することができ、あなたは、次のような何かを行うことができるはず:

with tf.name_scope("net1"): 
    net1 = CreateAlexNet() 
with tf.name_scope("net2"): 
    net2 = CreateAlexNet() 

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint. 
net1_varlist = {v.name.lstrip("net1/"): v 
       for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")} 
net1_saver = tf.train.Saver(var_list=net1_varlist) 

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint. 
net2_varlist = {v.name.lstrip("net2/"): v 
       for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")} 
net2_saver = tf.train.Saver(var_list=net2_varlist) 

# ... 
net1_saver.restore(sess, "epoch_10.ckpt") 
net2_saver.restore(sess, "epoch_50.ckpt") 
+0

驚くような答え! – denru

+0

lstripを使用して接頭辞を削除すると、結果が正しくない可能性があります。代わりにスライスしてください。コードの他の部分は完全に機能します。別の質問は、変数の名前が ":0"、 ":1"のような接尾辞を持つことがわかったということです。変数をチェックポイントファイルに格納する前に、この後置を取り除く必要がありますか? – denru

+0

誰でもこの回答を試みましたか?私は、何もしない 'restore'関数の問題にぶつかっています:http://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session – TheCriticalImperitive

0

私は私に長い時間を気に同じ問題を抱えています。ここで良い解決策を見つけました:Loading two models from Saver in the same Tensorflow sessionTensorFlow checkpoint save and read

tf.train.Saver()のデフォルトの動作は、各変数を対応するopの名前に関連付けることです。つまり、tf.train.Saver()を作成するたびに、前の呼び出しのすべての変数が含まれます。したがって、別のグラフを作成し、それらと異なるセッションを実行する必要があります。

関連する問題