2017-11-21 3 views
0

私は実装したいQ &注目のメカニズムを持つシステム。私は2つの入力を持っています。 contextおよびqueryの形状は、(batch_size, context_seq_len, embd_size)および(batch_size, query_seq_len, embd_size)です。
私は以下の論文に従っています。 マッチLSTMとアンサーポインタを使用した機械理解。 https://arxiv.org/abs/1608.07905PyTorchの多次元配列への `for`ループ

次に、形状が(batch_size, context_seq_len, query_seq_len, embd_size)の注意行列を取得します。論文では、各行の値を計算します(これは、各コンテキスト・ワード、G_i、α_iを意味します)。

私のコードは以下のとおりで、動作しています。しかし私は自分の道が良いかどうか分からない。たとえば、シーケンスデータ(for i in range(T):)の生成にはfor loopを使用します。そして、各行を取得するために、私はG[:,i,:,:]のようなインプレース演算子を使用し、embd_context[:,i,:].clone()は良い方法ですか?そうでない場合は、コードをどこで変更する必要がありますか?

その他の点に気づいたら、教えてください。私はこのフィールドとピンクの新しいです。私のあいまいな質問を申し訳ありません。

class MatchLSTM(nn.Module): 
    def __init__(self, args): 
     super(MatchLSTM, self).__init__() 
     self.embd_size = args.embd_size 
     d = self.embd_size 
     self.answer_token_len = args.answer_token_len 

     self.embd = WordEmbedding(args) 
     self.ctx_rnn = nn.GRU(d, d, dropout = 0.2) 
     self.query_rnn = nn.GRU(d, d, dropout = 0.2) 

     self.ptr_net = PointerNetwork(d, d, self.answer_token_len) # TBD 

     self.w = nn.Parameter(torch.rand(1, d, 1).type(torch.FloatTensor), requires_grad=True) # (1, 1, d) 
     self.Wq = nn.Parameter(torch.rand(1, d, d).type(torch.FloatTensor), requires_grad=True) # (1, d, d) 
     self.Wp = nn.Parameter(torch.rand(1, d, d).type(torch.FloatTensor), requires_grad=True) # (1, d, d) 
     self.Wr = nn.Parameter(torch.rand(1, d, d).type(torch.FloatTensor), requires_grad=True) # (1, d, d) 

     self.match_lstm_cell = nn.LSTMCell(2*d, d) 

    def forward(self, context, query): 
     # params 
     d = self.embd_size 
     bs = context.size(0) # batch size 
     T = context.size(1) # context length 
     J = query.size(1) # query length 

     # LSTM Preprocessing Layer 
     shape = (bs, T, J, d) 
     embd_context  = self.embd(context)   # (N, T, d) 
     embd_context, _h = self.ctx_rnn(embd_context) # (N, T, d) 
     embd_context_ex = embd_context.unsqueeze(2).expand(shape).contiguous() # (N, T, J, d) 
     embd_query  = self.embd(query)   # (N, J, d) 
     embd_query, _h = self.query_rnn(embd_query) # (N, J, d) 
     embd_query_ex = embd_query.unsqueeze(1).expand(shape).contiguous() # (N, T, J, d) 

     # Match-LSTM layer 
     G = to_var(torch.zeros(bs, T, J, d)) # (N, T, J, d) 

     wh_q = torch.bmm(embd_query, self.Wq.expand(bs, d, d)) # (N, J, d) = (N, J, d)(N, d, d) 

     hidden  = to_var(torch.randn([bs, d])) # (N, d) 
     cell_state = to_var(torch.randn([bs, d])) # (N, d) 
     # TODO bidirectional 
     H_r = [hidden] 
     for i in range(T): 
      wh_p_i = torch.bmm(embd_context[:,i,:].clone().unsqueeze(1), self.Wp.expand(bs, d, d)).squeeze() # (N, 1, d) -> (N, d) 
      wh_r_i = torch.bmm(hidden.unsqueeze(1), self.Wr.expand(bs, d, d)).squeeze() # (N, 1, d) -> (N, d) 
      sec_elm = (wh_p_i + wh_r_i).unsqueeze(1).expand(bs, J, d) # (N, J, d) 

      G[:,i,:,:] = F.tanh((wh_q + sec_elm).view(-1, d)).view(bs, J, d) # (N, J, d) # TODO bias 

      attn_i = torch.bmm(G[:,i,:,:].clone(), self.w.expand(bs, d, 1)).squeeze() # (N, J) 
      attn_query = torch.bmm(attn_i.unsqueeze(1), embd_query).squeeze() # (N, d) 
      z = torch.cat((embd_context[:,i,:], attn_query), 1) # (N, 2d) 

      hidden, cell_state = self.match_lstm_cell(z, (hidden, cell_state)) # (N, d), (N, d) 
      H_r.append(hidden) 
     H_r = torch.stack(H_r, dim=1) # (N, T, d) 

     indices = self.ptr_net(H_r) # (N, M, T) , M means (start, end) 
     return indices 
+2

https://codereview.stackexchange.comに送信してください。実際には、作業コードの品質を確認するための適切な場所ではありません。 –

+0

ああ、私はそのサイトを知らない。私はサイトに移動します。ありがとうございました。 https://codereview.stackexchange.com/questions/180984/for-loop-to-a-multi-dimensional-array-in-pytorch – jef

+0

問題ありません。おそらくここから質問を削除するべきでしょう。 –

答えて

1

あなたのコードは問題ありません。論文(https://openreview.net/pdf?id=B1-q5Pqxl)の方程式(2)では、G_ialpha_iベクトルの計算に関与するMatch-LSTMセルからの隠れ状態があり、入力を計算するために使用されているため、ループを避けることはできません。for i in range(T): Match-LSTMの次のタイムステップのために。ですから、Match-LSTMのすべてのタイムステップごとにループを実行する必要がありますが、私はforループを避けるための代替方法は見当たりません。

+0

毎回ありがとうございます。はい、そう思います。しかし、私はこの分野とPyTorchで新しいです。だから私はあなたからそれを聞いて安心しています。 – jef