2016-05-12 6 views
1

ここでTensorflowのRNNサンプルを変更しようとしています。 ptb_word_lm.pyでTensorflow RNNサンプルのIDを取得する

https://www.tensorflow.org/versions/r0.8/tutorials/recurrent/index.html

私は、彼らは単語インデックス(m.input_data:X)のint型の配列を入力していると思います。

def run_epoch(session, m, data, eval_op, verbose=False): 
    """Runs the model on the given data.""" 
    epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps 
    start_time = time.time() 
    costs = 0.0 
    iters = 0 
    state = m.initial_state.eval() 
    for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size, 
                m.num_steps)): 
    cost, state, _ = session.run([m.cost, m.final_state, eval_op], 
           {m.input_data: x, 
            m.targets: y, 
            m.initial_state: state}) 

idsの代わりに実際の言葉を見たいのですが、どのように表示できますか?

答えて

1

まず、単語からidへのインデックスである語彙を保持する必要があります。

メインの先頭には、以下のようにreader.ptb_raw_data()から4番目の戻り値を保持します。

raw_data = reader.ptb_raw_data(FLAGS.data_path) 
train_data, valid_data, test_data, vocabulary = raw_data 

次に、run_epoch()に語彙を渡します。あなたはxの最初のステップで言葉にIDを変換したいrun_epoch()、内部

test_perplexity = run_epoch(session, mtest, test_data, tf.no_op(), vocabulary) 

def run_epoch(session, m, data, eval_op, vocabulary, verbose=False): 

... 
for step, (x, y) in enumerate(... 

message ="x: " 
for i in range(0, m.num_steps): 
    key = vocabulary.keys()[vocabulary.values().index(x[0][i])] 
    message += key + " " 

print(message) 

はそれがお役に立てば幸いです。

+0

ありがとうございました!私はこれを試してみる。 – Hub

関連する問題