2017-02-09 9 views
0

hereからすべてのTensorFlow操作

登録OPSのためOpDefsのリストを取得するには、いくつかの方法があります。C APIの

  • TF_GetAllOpListが登録されているすべての取得は、 OpDefプロトコルメッセージ。これを使用して、 クライアント言語でジェネレータを書き込むことができます。これは、OpDefメッセージを解釈するために、クライアント言語が プロトコルバッファサポートを有することを必要とする。
  • C++関数OpRegistry :: Global() - > GetRegisteredOps()は、登録されたすべてのOpDef( [tensorflow/core/framework/op.h]で定義されている)の同じリストを返します。これは、 ジェネレータをC++で記述するために使用できます(特に、 にはプロトコルバッファがサポートされていない言語で便利です)。
  • このリストのASCIIシリアル化バージョンは、自動的に プロセスによって[tensorflow/core/ops/ops.pbtxt]にチェックインされます。

しかし、悲しいかな、私はPythonでこれをやりたい

import tensorflow as tf 
from google.protobuf import json_format 
json_string = json_format.MessageToJson(tf.GetAllOpsList()) 

、のような私はJSONを介したとして、それをダンプすることができるようにTensorflow内のすべての操作のためにいるProtobufメッセージを取得するための方法をしたいです

答えて

2

ops.txtです。文字列出力を生成するopsのすべてのメッセージをOpDefにリストする例を示します。

import tensorflow as tf 

from tensorflow.core.framework import op_def_pb2 
from google.protobuf import text_format 

def get_op_types(op): 
    for attr in op.attr: 
     if attr.type != 'type': 
      continue 
     return list(attr.allowed_values.list.type) 
    return [] 

# directory where you did "git clone" 
tensorflow_git_base = "/Users/yaroslav/tensorflow.git" 
ops_file = tensorflow_git_base+"/tensorflow/tensorflow/core/ops/ops.pbtxt" 
ops = op_def_pb2.OpList() 
text_format.Merge(open(ops_file).read(), ops) 

for op in ops.op: 
    # get templated string types 
    if tf.string in get_op_types(op): 
     print(op.name, op.summary) 
    #for arg in op.input_arg: 
    for arg in op.output_arg: 
     if arg.type == tf.string: 
      print(op.name, op.summary) 
      break 

あなたは現在のPythonラッパーはそれを行う方法エンジニアを逆転できる追加される新しいOPSに敏感になりたい場合は** ** を追加しました。たとえば、gen_array_ops.pyファイルを考えてみましょう。それはそうそれらのメッセージのprotobufsがgen_array_opsの生成中に基本的なCコードから生成されている次のスニペット

def _InitOpDefLibrary(): 
    op_list = _op_def_pb2.OpList() 
    _text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list) 
    _op_def_registry.register_op_list(op_list) 
    op_def_lib = _op_def_library.OpDefLibrary() 
    op_def_lib.add_op_list(op_list) 
    return op_def_lib 


_InitOpDefLibrary.op_list_ascii = """op { 
    name: "BatchMatrixBandPart" 
    input_arg { 
    name: "input" 
    type_attr: "T" 
    } 
    input_arg { 
    name: "num_lower" 
    type: DT_INT64 
    } 
    input_arg { 
    name: "num_upper" 
    type: DT_INT64 
    } 
    output_arg { 
    name: "band" 
    type_attr: "T" 
    } 
    attr { 
    name: "T" 
    type: "type" 
    } 
    deprecation { 
    version: 14 
    explanation: "Use MatrixBandPart" 
    } 
} 

を持っています。それらがどのように生成されたかを追跡するために、https://stackoverflow.com/a/41149557/419116

+0

を参照してください。私はカスタム操作を追加してProtobufを入手できるので、 –

+0

追加情報を追加しました。 gen _ * _ ops.py生成中に実行されたのと同じ世代のスクリプトを呼び出すことによって "ops.txt"を再生成するか、基本的なCの機能を見つけていくつかのswigラッパーを追加することができます –

+0

ありがとう!ほぼそこに。私はasciiについて気にしないが、私はJSONが必要だ。そこで、私は 'google.protobuf import json_format'から' json_format.MessageToJson(gen_array_ops._InitOpDefLibrary()._ ops ['Const']。op_def) 'が私にしたいものをほとんど見つけたことを発見しました。これはただのmetaInfoDefですが、graphDefブロックも必要です –

関連する問題