Featured image of post Solve Multi Times Backwards In pytorch

Solve Multi Times Backwards In pytorch

本文提供了torch报错Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed的深入排查解决方案

Tip

本文提供了torch报错Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed的深入排查解决方案

Note

排查方法摘要:

  • 确认是否在一个iter中多次反向传播
  • 确认是否在多个iters之间没调用optimizer.zero_grad()
  • 可视化并确认是否及时detach()了模型参数

前言

实操pytorch也有几年了,虽然都是兴趣使然的小打小闹,但是各种问题或多或少都碰到过。
一般来说,绝大多数torch的报错都可以Google到,或者在StackOverflow中找到相关解决方案。
再不济,也可以去GitHub里相关issue中找到解决方案。

当然,作为最终武器,我们还有ChatGPT,在现在8k上下文的情况下,基本上可以解决大部分问题。

但是很不幸,我遇到的是这个错误:

1
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward. 

我一共遇到了3次这个问题,我之所以说很不幸,是因为它debug起来极其困难,有的时候还会沦落到不得不可视化计算图的地步。 甚至问GPT,它都只会给我笼统地说检查我是不是backwards了多次什么的,毫无参考性。

这个错误的意思是说在反向传播过程中,计算图在完成反向传播前就被释放了,导致无法再次反向传播。
一般来说,我们在训练GAN时比较容易遇到,因为可能会有量要被要被反向传播训练G和D两个网络。这种情况下,就老老实实使用retain_graph=True就好了。

其他情况请按照下文排查。

常规检查

废话不多说,再深入研究代码前,先确定是不是由以下错误导致的:

  • 在一个iter中多次反向传播,即类似:
    1
    2
    3
    
    loss.backwards()
    # do something
    loss.backwards()
    
  • 在多个iters之间没调用optimizer.zero_grad(),即类似:
    1
    2
    3
    4
    
    for i in range(100):
        # do something
        loss.backwards()
        optimizer.step()
    

以上这两种情况请自行排查解决。

继续深入

Note

先说结论:某些该detach().clone()的地方没有detach().clone()

如果不知道是哪里需要detach().clone(),可以看我下方的实践。

我没犯常规检查中的两个错误,但还是出现了这种问题🤔。
迫不得已,我继续深入研究。

我很快定位到问题出现在这一段代码:

1
2
3
4
5
6
7
    @property
    def get_xyz_purturbed(self):
        unc = sample_from_grid(self.grid, self._xyz.to(self._xyz.device).view(-1,3), self.lb, self.ub)
        noise = torch.normal(0, unc)
        # pdb.set_trace()
        noise = noise.to(self._xyz.device)
        return self._xyz + noise.reshape(*self._xyz.shape[:-1],1)

但是我不知道具体问题出在哪。

被折磨了好久😩,迫不得已我装了个torchviz,用它的make_dot可视化了一下计算图。

1
2
        dot = make_dot(loss)
        dot.save(f'loss.txt')

202310251555198 *不要在意我画的圈,这是我之前和学长讨论的时候留下来的
(在线可视化工具地址

其中,蓝色的方框是requires_grad=True的模型参数,橙色的方框是requires_grad=False的。灰色很显然是backwards的节点。

因为之前我测试过,能确定问题出在左边部分的计算图,我很快注意到了这里: 202310251557820

可以看到,在这里出现了两个requires_grad=True的模型参数,这就是问题所在了。
以下是我对这个问题的理解,可能有错误,欢迎指正。

我认为出现这个问题是因为,一般来说,模型的计算图是类似于树一样的结构(不是很准确,因为非叶子结点之间互相可以连接),根节点是loss,叶子结点是模型参数,而像这样的畸形结构导致torch无法在正确的时机释放计算图,从而导致了这个错误。

于是我优化了一下代码,将self._xyz(计算图中最顶上那个蓝色方块)detach并复制:

1
2
3
4
5
6
7
    @property
    def get_xyz_purturbed(self):
        unc = sample_from_grid(self.grid, self._xyz.cpu().detach().clone().to(self._xyz.device).view(-1,3), self.lb, self.ub)
        noise = torch.normal(0, unc)
        # pdb.set_trace()
        noise = noise.to(self._xyz.device)
        return self._xyz + noise.reshape(*self._xyz.shape[:-1],1)

这样就解决了问题。

可以看看我修改后的计算图: 202310251606170 十分的干净利索。

总结

多次反向传播是torch中较为棘手的问题,需要对torch的反向传播有一定的理解,并且要能完全明确自己的模型是如何backwards的。通常我们可以通过可视化计算图的方式来快速排查。

comments powered by Disqus