2017-01-05 3 views
0

後で簡単に取得できるように自分のオブジェクトを保存するにはtf.add_to_collection()を使用します。ここ は、コードセグメントである:テンソルフローでコレクションを使用して自分自身のオブジェクトを保持する方法

class Model(object): 
    def __init__(self, scope, is_training=True): 

は、オブジェクトをコレクションに追加します。

for i in xrange(num_gpus): 
    with tf.device("/gpu:%d"%i): 
     with tf.name_scope("tower_%d"%i) as scope: 
      m = Model.Model(scope) 
      tf.add_to_collection("train_model", m) 

コレクションからオブジェクトを取得:

models = tf.get_collection("train_model") 

コードが正常に動作しますが、私は警告が出ます:

WARNING:tensorflow:Error encountered when serializing train_model. 
Type is unsupported, or the types of the items don't match field type in CollectionDef. 
'Model' object has no attribute 'name 

この警告を避けるにはどうすればよいですか?

答えて

0

tf.train.Saver.save()を呼び出すと警告が(おそらく)生成され、tf.Graphの内容を表す「MetaGraph」(すべてのグラフのコレクションの内容を含む)を書き出しようとします。

saver.save()と呼び出すときに、警告を回避する最も簡単な方法は、write_meta_graph=Falseを渡すことです。ただし、後でインポートするMetaGraphがないままになります。

あなたがメタグラフを保存して警告を回避したい場合は、tf.train.MetaGraphDefシリアライズ形式のプロトコルバッファとしてあなたModelオブジェクトをシリアル化するために必要なフック(to_protofrom_proto)を実装する必要があります。 MetaGraph tutorialは、これを実行する方法について説明しますが、次のように基本的な考え方は次のとおりです。

  1. Modelオブジェクトの内容を記述プロトコルバッファ(ModelProto)を定義します。

  2. ModelProtoModelをシリアライズmodel_to_proto()機能定義:あなたのための機能を登録し

    def model_from_proto(model_proto): 
        # Construct a `Model` from the fields of `model_proto`. 
        return Model(...) 
    
  3. ModelProtoをデシリアライズし、Modelを返すmodel_from_proto()関数を定義

    def model_to_proto(model): 
        ret = ModelProto() 
        # Set fields of `ret` from `model`. 
        return ret 
    
  4. "train_model"コレクション。これは現在、文書化されていない機能を使用して、register_proto_function()呼ば:

    from tensorflow.python.framework import ops 
    
    ops.register_proto_function("train_model", 
              proto_type=ModelProto, 
              to_proto=model_to_proto, 
              from_proto=model_from_proto) 
    
関連する問題