torch.no_grad()与梯度叠加

torch.no_grad()与梯度叠加

torch.no_grad

with torch.no_grad():
	output = model(img)

在计算网络输出的时候,不存储梯度,能节约很多显存。但是多出来的空间用于增加batch size, 速度并没有提升。

梯度叠加

在显存不足的时候,若是想要实现增大batch size, 可累积梯度,在多个batch之后一起更新网络参数。

# some code
# Initialize dataset with batch size 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    loss.backward()
    # graph is cleared here
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.step()
        opt.zero_grad()

参考
[1] https://www.zhihu.com/question/303070254
[2] https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20

评论

此博客中的热门博文

使用ssh反向代理+shadowsocks实现内网穿透

shadowsocks中转

ubuntu 16.04 reboot命令慢的原因