1

状態全体を明らかにすることができないため、リカレントニューラルネットワークを強化して、ネットワークに過去の出来事の何らかの記憶があるようにする必要があります。わかりやすくするために、LSTMを使用しているとしましょう。PyTorchでLSTMを使って強化学習を行う方法は?

内蔵のPyTorch LSTMでは、形状Time x MiniBatch x Input Dの入力をフィードに入力する必要があり、形状テンソルTime x MiniBatch x Output Dを出力します。

しかし、強化学習では、時刻t+1の入力を知るために、私は環境内でアクションを行っているので、時刻tの出力を知る必要があります。

補強学習設定で内蔵PyTorch LSTMを使用してBPTTを実行することは可能ですか?それがあれば、どうすればいいのですか?

答えて

1

入力シーケンスをループでLSTMに入力することができます。

h, c = Variable(torch.zeros()), Variable(torch.zeros()) 
for i in range(T): 
    input = Variable(...) 
    _, (h, c) = lstm(input, (h,c)) 

アクションを評価するために(h、c)入力することができます。たとえば、アクションを評価するためのタイムステップです。計算グラフを破らない限り、変数はすべての履歴を保持するので、バックプロパゲーションすることができます。

関連する問題