Tip
本文提供了torch报错Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed
的深入排查解决方案
Note
排查方法摘要:
- 确认是否在一个iter中多次反向传播
- 确认是否在多个iters之间没调用optimizer.zero_grad()
- 可视化并确认是否及时detach()了模型参数
前言
实操pytorch也有几年了,虽然都是兴趣使然的小打小闹,但是各种问题或多或少都碰到过。
一般来说,绝大多数torch的报错都可以Google到,或者在StackOverflow中找到相关解决方案。
再不济,也可以去GitHub里相关issue中找到解决方案。
当然,作为最终武器,我们还有ChatGPT
,在现在8k上下文的情况下,基本上可以解决大部分问题。
但是很不幸,我遇到的是这个错误:
|
|
我一共遇到了3次这个问题,我之所以说很不幸
,是因为它debug起来极其困难,有的时候还会沦落到不得不可视化计算图的地步。
甚至问GPT,它都只会给我笼统地说检查我是不是backwards了多次什么的,毫无参考性。
这个错误的意思是说在反向传播过程中,计算图在完成反向传播前就被释放了,导致无法再次反向传播。
一般来说,我们在训练GAN时比较容易遇到,因为可能会有量要被要被反向传播训练G和D两个网络。这种情况下,就老老实实使用retain_graph=True
就好了。
其他情况请按照下文排查。
常规检查
废话不多说,再深入研究代码前,先确定是不是由以下错误导致的:
- 在一个iter中多次反向传播,即类似:
1 2 3
loss.backwards() # do something loss.backwards()
- 在多个iters之间没调用optimizer.zero_grad(),即类似:
1 2 3 4
for i in range(100): # do something loss.backwards() optimizer.step()
以上这两种情况请自行排查解决。
继续深入
Note
先说结论:某些该detach().clone()
的地方没有detach().clone()
。
如果不知道是哪里需要detach().clone()
,可以看我下方的实践。
我没犯常规检查
中的两个错误,但还是出现了这种问题🤔。
迫不得已,我继续深入研究。
我很快定位到问题出现在这一段代码:
|
|
但是我不知道具体问题出在哪。
被折磨了好久😩,迫不得已我装了个torchviz
,用它的make_dot
可视化了一下计算图。
|
|
*不要在意我画的圈,这是我之前和学长讨论的时候留下来的
(在线可视化工具地址)
其中,蓝色的方框是requires_grad=True
的模型参数,橙色的方框是requires_grad=False
的。灰色很显然是backwards的节点。
因为之前我测试过,能确定问题出在左边部分的计算图,我很快注意到了这里:
可以看到,在这里出现了两个requires_grad=True
的模型参数,这就是问题所在了。
以下是我对这个问题的理解,可能有错误,欢迎指正。
我认为出现这个问题是因为,一般来说,模型的计算图是类似于树一样的结构(不是很准确,因为非叶子结点之间互相可以连接),根节点是loss,叶子结点是模型参数,而像这样的畸形结构导致torch无法在正确的时机释放计算图,从而导致了这个错误。
于是我优化了一下代码,将self._xyz
(计算图中最顶上那个蓝色方块)detach并复制:
|
|
这样就解决了问题。
可以看看我修改后的计算图: 十分的干净利索。
总结
多次反向传播是torch中较为棘手的问题,需要对torch的反向传播有一定的理解,并且要能完全明确自己的模型是如何backwards的。通常我们可以通过可视化计算图的方式来快速排查。