torch.no_grad() and Gradient Accumulation
torch.no_grad()
with torch.no_grad():
    output = model(img)
When the network output is computed inside this context, no gradients (and no intermediate activations for backward) are stored, which saves a lot of GPU memory. The freed memory can be used to run a larger batch size, but the forward pass itself does not become faster.
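A minimal runnable sketch of the same idea (the toy model and input below are hypothetical, added here only for illustration):

import torch
import torch.nn as nn

# Hypothetical toy model and input, just to make the snippet self-contained.
model = nn.Linear(128, 10)
img = torch.randn(32, 128)

model.eval()                    # also disable dropout / batch-norm updates
with torch.no_grad():           # autograd records no computation graph here
    output = model(img)         # intermediate activations are not kept

print(output.requires_grad)     # False: the result is detached from autograd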
Gradient Accumulation
When GPU memory is not enough for the desired batch size, a larger effective batch size can still be achieved by accumulating gradients over several small batches and updating the network parameters only once every few batches, as in the snippet below (adapted from [2]).
# some code: net, crit (loss function) and opt (optimizer) are set up above
# Initialize dataset with batch size 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    # one computation graph is created here
    loss.backward()
    # the graph is freed here; gradients are accumulated into .grad
    if (i + 1) % 10 == 0:
        # every 10 batches of size 10, i.e. an effective batch size of 100
        opt.step()
        opt.zero_grad()
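A self-contained variant of the same technique is sketched below. The toy model, dummy data, and the accum_steps name are hypothetical additions; the loss is also divided by the number of accumulated batches so that the summed gradient approximates the gradient of one large batch.

import torch
import torch.nn as nn

# Hypothetical setup for illustration only.
net = nn.Linear(128, 10)
crit = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=0.1)

accum_steps = 10                      # update once every 10 small batches
dataset = [(torch.randn(10, 128), torch.randint(0, 10, (10,)))
           for _ in range(100)]       # 100 batches of size 10

opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    # Scale the loss so gradients summed over accum_steps batches
    # approximate the gradient of one batch of size 10 * accum_steps.
    loss = crit(pred, target) / accum_steps
    loss.backward()                   # gradients accumulate in .grad
    if (i + 1) % accum_steps == 0:
        opt.step()                    # effective batch size: 100
        opt.zero_grad()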
References
[1] https://www.zhihu.com/question/303070254
[2] https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20