开发者

Pytorch中retain_graph的坑及解决

开发者 https://www.devze.com 2023-02-22 09:21 出处:网络 作者: Longlongaaago
目录Pytorch中retain_graph的坑Pytorch中有多次backward时需要retain_graph参数解决办法总结Pytorch中retain_graph的坑
目录
  • Pytorch中retain_graph的坑
  • Pytorch中有多次backward时需要retain_graph参数
    • 解决办法
  • 总结

    Pytorch中retain_graph的坑

    在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用编程客栈就是

    在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

      ###########php#################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target)
        if torch.cuda.is_available():
          real_img = real_img.cuda()
        z = Variable(data)
        if torch.cuda.is_avxhmKovDailable():
          z = z.cuda()
        fake_img = netG(z)
    
        nphpetD.zero_grad()
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward(retain_graph=True) #####
        optimizerD.step()
    
        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()
    
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        running_results['g_loss'] += g_loss.data[0] * BATch_size
        d_loss = 1 - real_out + fake_out
        running_results['d_loss'] += d_loss.data[0] * batch_size
        running_results['d_score'] += real_out.data[0] * batch_size
        running_results['g_score'] += fake_out.data[0] * batch_size

    也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True)  让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。

    但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。

    而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!

    Pytorch中有多次backward时需要retain_graph参数

    Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中编程客栈可能有多次backward()时,因为前一次调用backward()时已经释放掉了buf开发者_Go培训fer,所以下一次调用时会因为buffers不存在而报错

    解决办法

    loss.backward(retain_graph=True)

    错误使用

    • optimizer.zero_grad() 清空过往梯度;
    • loss1.backward(retain_graph=True) 反向传播,计算当前梯度;
    • loss2.backward(retain_graph=True) 反向传播,计算当前梯度;
    • optimizer.step() 根据梯度更新网络参数

    因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)

    正确使用

    • optimizer.zero_grad() 清空过往梯度;
    • loss1.backward(retain_graph=True) 反向传播,计算当前梯度;
    • loss2.backward() 反向传播,计算当前梯度;
    • optimizer.step() 根据梯度更新网络参数

    最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了

    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

    0

    精彩评论

    暂无评论...
    验证码 换一张
    取 消