ディスクからテンソルフローモデルを読み込み、値を予測しようとしています。マップをpysparkを使用してmapPartitionに変換する
コード
def get_value(row):
print("**********************************************")
graph = tf.Graph()
rowkey = row[0]
checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300"
print("Loading model................................")
with graph.as_default():
session_conf = tf.ConfigProto(
allow_soft_placement=allow_soft_placement,
log_device_placement=log_device_placement)
sess = tf.Session(config=session_conf)
with sess.as_default():
# Load the saved meta graph and restore variables
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess, checkpoint_file)
input_x = graph.get_operation_by_name("X_train").outputs[0]
dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
predictions = graph.get_operation_by_name("output/predictions").outputs[0]
batch_predictions = sess.run(predictions, {input_x: [row[1]], dropout_keep_prob: 1.0})
print(batch_predictions)
return (rowkey, batch_predictions)
Iタプル(のrowKey、input_vector)から成るRDDを有します。ロードされたモデルを使用して、入力のスコア/クラスを予測したいと思います。
(GET_VALUEを呼び出すコード)
result = data_rdd.map(lambda iter: get_value(iter))
result.foreach(print)
問題は、私はマップを呼び出すたびにある、モデルは、各タプルのためにロードされ、毎回あり、それは多くの時間を要します。
私はmapPartitionsを使用してモデルをロードすることを考えた後、のget_value機能を呼び出すためにマップを使用しています。 私はテンソルフローモデルを分割ごとに一度だけロードし、実行時間を短縮するmapPartitionにコードを変換する方法のヒントはありません。
ありがとうございます。
エラーが 'トレースバック(最新の呼び出しの最後)をポップアップ表示します.dumps(sess) TypeError:SwigPyObjectオブジェクトをpickleできません ' – wadhwasahil