2017-02-08 3 views
0

ディスクからテンソルフローモデルを読み込み、値を予測しようとしています。マップをpysparkを使用してmapPartitionに変換する

コード

def get_value(row): 
    print("**********************************************") 
    graph = tf.Graph() 
    rowkey = row[0] 
    checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300" 
    print("Loading model................................") 
    with graph.as_default(): 
     session_conf = tf.ConfigProto(
      allow_soft_placement=allow_soft_placement, 
      log_device_placement=log_device_placement) 
     sess = tf.Session(config=session_conf) 
     with sess.as_default(): 
      # Load the saved meta graph and restore variables 
      saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 
      saver.restore(sess, checkpoint_file) 
      input_x = graph.get_operation_by_name("X_train").outputs[0] 
      dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 
      predictions = graph.get_operation_by_name("output/predictions").outputs[0] 
      batch_predictions = sess.run(predictions, {input_x: [row[1]], dropout_keep_prob: 1.0}) 
      print(batch_predictions) 
      return (rowkey, batch_predictions) 

Iタプル(のrowKey、input_vector)から成るRDDを有します。ロードされたモデルを使用して、入力のスコア/クラスを予測したいと思います。

(GET_VALUEを呼び出すコード)

result = data_rdd.map(lambda iter: get_value(iter)) 
result.foreach(print) 

問題は、私はマップを呼び出すたびにある、モデルは、各タプルのためにロードされ、毎回あり、それは多くの時間を要します。

私はmapPartitionsを使用してモデルをロードすることを考えた後、のget_value機能を呼び出すためにマップを使用しています。 私はテンソルフローモデルを分割ごとに一度だけロードし、実行時間を短縮するmapPartitionにコードを変換する方法のヒントはありません。

ありがとうございます。

答えて

0

以下のコードは、mapPartitionsを使用するため、大幅に改善されたと思います。 session_pickle =はcPickleで、 ファイル "/home/sahil/Desktop/Relation_Extraction/temp.py"、行465:

コード

def predict(rows): 
    graph = tf.Graph() 
    checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300" 
    print("Loading model................................") 
    with graph.as_default(): 
     session_conf = tf.ConfigProto(
      allow_soft_placement=allow_soft_placement, 
      log_device_placement=log_device_placement) 
     sess = tf.Session(config=session_conf) 
     with sess.as_default(): 
      # Load the saved meta graph and restore variables 
      saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 
      saver.restore(sess, checkpoint_file) 
     print("**********************************************") 
     # Get the placeholders from the graph by name 
     input_x = graph.get_operation_by_name("X_train").outputs[0] 
     dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 
     # Tensors we want to evaluate 
     predictions = graph.get_operation_by_name("output/predictions").outputs[0] 

     # Generate batches for one epoch 
     for row in rows: 
      X_test = [row[1]] 
      batch_predictions = sess.run(predictions, {input_x: X_test, dropout_keep_prob: 
      yield (row[0], batch_predictions) 


result = data_rdd.mapPartitions(lambda iter: predict(iter)) 
result.foreach(print) 
1

ご質問が正しく行われたかどうかは不明ですが、こちらでコードを最適化することができます。

graph = tf.Graph() 

checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300" 

with graph.as_default(): 
     session_conf = tf.ConfigProto(
      allow_soft_placement=allow_soft_placement, 
      log_device_placement=log_device_placement) 
     sess = tf.Session(config=session_conf) 

s = sess.as_default() 
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 
saver.restore(sess, checkpoint_file) 


input_x = graph.get_operation_by_name("X_train").outputs[0] 
dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 
predictions = graph.get_operation_by_name("output/predictions").outputs[0] 

session_pickle = cPickle.dumps(sess) 

def get_value(key, vector, session_pickle): 
    sess = cPickle.loads(session_pickle) 
    rowkey = key 
    batch_predictions = sess.run(predictions, {input_x: [vector], dropout_keep_prob: 1.0}) 
    print(batch_predictions) 
    return (rowkey, batch_predictions 



result = data_rdd.map(lambda (key, row): get_value(key=key, vector = row , session_pickle = session_pickle)) 
result.foreach(print) 

テンソルフローのセッションをシリアル化できます。私はここであなたのコードをテストしていませんが。これを実行してコメントを残してください。

+0

エラーが 'トレースバック(最新の呼び出しの最後)をポップアップ表示します.dumps(sess) TypeError:SwigPyObjectオブジェクトをpickleできません ' – wadhwasahil

関連する問題