0
torch.nn
にはクラスBatchNorm1d
,BatchNorm2d
,BatchNorm3d
がありますが、完全に接続されたBatchNormクラスはありませんか? PyTorchで標準のバッチノルムを実行する標準的な方法は何ですか?PyTorchで完全に接続されたバッチノルムを行うには?
torch.nn
にはクラスBatchNorm1d
,BatchNorm2d
,BatchNorm3d
がありますが、完全に接続されたBatchNormクラスはありませんか? PyTorchで標準のバッチノルムを実行する標準的な方法は何ですか?PyTorchで完全に接続されたバッチノルムを行うには?
私はそれを考え出した。 BatchNorm1d
もRank-2テンソルを扱うことができるので、通常の完全接続の場合はBatchNorm1d
を使用することができます。
したがって、たとえば:
import torch.nn as nn
class Policy(nn.Module):
def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
super(Policy2, self).__init__()
self.action_space = action_space
num_outputs = action_space
self.linear1 = nn.Linear(num_inputs, hidden_size1)
self.linear2 = nn.Linear(hidden_size1, hidden_size2)
self.linear3 = nn.Linear(hidden_size2, num_outputs)
self.bn1 = nn.BatchNorm1d(hidden_size1)
self.bn2 = nn.BatchNorm1d(hidden_size2)
def forward(self, inputs):
x = inputs
x = self.bn1(F.relu(self.linear1(x)))
x = self.bn2(F.relu(self.linear2(x)))
out = self.linear3(x)
return out
例(コード)を教えてください。 – Royi
何があなたがこれらの層が完全に接続されていないと思うのですか? – Maximilian