2017-07-30 6 views
0

私の目標は、(1)初期値としてファイルから事前にトレーニングされた単語の埋め込み行列を読み込みます。 (2)単語の埋め込みを固定しないで微調整する。 (3)モデルを復元するたびに、事前に訓練されたものの代わりに微調整された単語の埋め込みをロードします。Tensorflow:再トレーニング中に事前にトレーニングされた埋め込みの初期化の問題

私は好きSTH試してみました:

class model(): 
    def __init__(self): 
    # ... 
    def _add_word_embed(self): 
     W = tf.get_variable('W', [self._vsize, self._emb_size], 
       initializer=tf.truncated_normal_initializer(stddev=1e-4)) 
     W.assign(load_and_read_w2v()) 
     # ... 
    def _add_seq2seq(self): 
     # ... 
    def build_graph(self): 
     self._add_word_embed() 
     self._add_seq2seq() 

をしかし、このアプローチは、私が訓練を停止して再起動するたびに埋め込む微調整の単語をカバーします。私もmodel.build_graphを呼び出した後にsess.run(W.assign())を試しました。しかし、それはグラフが完成したので私はもうそれを変更することはできませんエラーを投げた。それを達成する正しい方法を教えてください。前もって感謝します!

EDIT:ITは、新しい要件を持っているよう

この質問が重複していない:USE THE TRAININGの先頭にWORDの埋め込みをPREは、訓練を受け、その後FIND-TUNE ITを。これを効率的に行う方法も尋ねます。この質問で受け入れられた回答は、この要件に合致したものではありません。どんな質問にも重複して表示される前に二度考えてみてください。ここで

+1

Uがで答えを見つけることができます.com/questions/35687678/pre-training-word-embedding-word2vec-glove-in-tensorflowを使用する –

+0

@vijaymどちらの答え?受け入れられた答えは、単語埋め込みを固定したままにしておくので、私が欲しいものではありません。 – user5779223

+0

受け入れられた回答(2)で、 'trainable = False'を削除するだけです。 –

答えて

2

はそれを行う方法についてのおもちゃの例である:

# The graph 

# Inputs 
vocab_size = 2 
embed_dim = 2 
embedding_matrix = np.ones((vocab_size, embed_dim)) 

#The weight matrix to initialize with embeddings 
W = tf.get_variable(initializer=tf.zeros([vocab_size, embed_dim]), name='embed', trainable=True) 

# global step used to take care of the weight initialization 
# for the first time will be loaded from numpy array and not during retraining. 
global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') 

# Initialiazation of weights based on global_step 
initW = tf.cond(tf.equal(global_step, 0), lambda:W.assign(embedding_matrix), lambda: W) 
inc = tf.assign_add(W,[[1, 1],[1, 1]]) 

# Update global step 
update = tf.assign_add(global_step, 1) 
op = tf.group(inc, update) 

# init_fn 
def init_embed(sess): 
    sess.run(initW) 

我々はセッション中に上記を実行すると今すぐます。https:// stackoverflowの

sv = tf.train.Supervisor(logdir='tmp',init_fn=init_embed) 
with sv.managed_session() as sess: 
    print('global step:', sess.run(global_step)) 
    print('Initial weight:') 
    print(sess.run(W)) 
    for i in range(2): 
     sess.run([op]) 
    _ W, g_step= sess.run([W, global_step]) 
    print('Final weight:')   
    print(_W) 
    sv.saver.save(sess,sv.save_path, global_step=g_step) 

# Output at first run 
    Initial weight: 
    [[ 1. 1.] 
    [ 1. 1.]] 

    Final weight: 
    [[ 3. 3.] 
    [ 3. 3.]] 

#Output at second run 
    Initial weight: 
    [[ 3. 3.] 
    [ 3. 3.]] 
    Final weight: 
    [[ 5. 5.] 
    [ 5. 5.]] 
関連する問題