私はバッチ内の各サンプルについてフォワードを行い、サンプルのモデル出力上のいくつかの条件に基づいて一部のサンプルについてのみ損失を累積するユースケースを持っています。ここにコードを例示します。私はいくつかのサンプルを転送するだけで、計算グラフはいつ解放されますか?
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
total_loss = 0
loss_count_local = 0
for i in range(len(target)):
im = Variable(data[i].unsqueeze(0).cuda())
y = Variable(torch.FloatTensor([target[i]]).cuda())
out = model(im)
# if out satisfy some condtion, we will calculate loss
# for this sample, else proceed to next sample
if some_condition(out):
loss = criterion(out, y)
else:
continue
total_loss += loss
loss_count_local += 1
if loss_count_local == 32 or i == (len(target)-1):
total_loss /= loss_count_local
total_loss.backward()
total_loss = 0
loss_count_local = 0
optimizer.step()
私の質問は、すべてのサンプルについては転送しますが、いくつかのサンプルについては後で行います。損失に寄与しないサンプルのグラフはいつ解放されますか?これらのグラフは、forループが終了した後、または次のサンプルを転送した直後に解放されますか?私はここで少し混乱しています。
total_loss
に寄与するサンプルについては、total_loss.backward()
の直後にグラフが解放されます。そうですか?