2016-08-14 10 views
7

変数とその状態をチェックポイントから読み取るにはどうすればよいですか?Tensorflow。チェックポイントのリスト変数

私は自動エンコーダで作業しています。チェックポイントには、エンコーダ、デコーダ、オプティマイザなどのネットワークの完全な状態が含まれています。エンコーディングをばかにしたいので、評価モードでネットワークのデコーダ部分のみが必要です。

より抽象的な方法で同じ質問:別のモデルで再利用するために既存のチェックポイントから特定の変数のみを読み取るにはどうすればよいですか?

私は変数に対応して名前を付けるべきですか? かのようなものを取得する方法があります:

w_init = read_from_state(state_location, var_name) 

def read_from_state(state_location, var_name): 
    # the magic goes here 
    pass 

答えて

14

あなたが保存されたすべての変数を参照することができますcheckpoint_utils.pylist_variables方法があります。

ただし、使用する場合は、Saverで復元する方が簡単です。チェックポイントを保存したときに変数の名前がわかっている場合は、新しいセーバーを作成し、その名前を新しいVariableオブジェクト(おそらく名前が異なる)に初期化するように指示できます。これはCIFARの例で復元を選択するために使用されますsubset of variables。 (指定されている場合、または単に1)自分のコンテンツと一緒にすべてのチェックポイントテンソルを印刷することを、ハウツー

0

別の方法でChoosing which Variables to Save and Restoreを参照してください:

from tensorflow.python.tools import inspect_checkpoint as inch 
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True) 
""" 
Args: 
    file_name: Name of the checkpoint file. 
    tensor_name: Name of the tensor in the checkpoint file to print. 
    all_tensors: Boolean indicating whether to print all tensors. 
""" 

それは常にテンソルの内容を印刷します。

そして、我々はそれでありながら、ここでは前の回答で提案、checkpoint_utilsを使用する方法である:

from tensorflow.contrib.framework.python.framework import checkpoint_utils 
    var_list = checkpoint_utils.list_variables('path/to/ckpt') 
    for v in var_list: print(v) 
関連する問題