Skip to content

AI编程转录稿:20251030自动微分2

· 82 min

自动微分#

课程回顾:自动微分与计算图#

我们来上本周的 AI 中的编程这门课。回顾一下,我们上周进入了面向 AI 框架的部分,介绍了 AI 框架的一些主要任务。并且我们从一个很重要的任务,也就是自动微分入手,聊了一下怎么做相关的微分。我们也讲了一些经典的过去的做法,比如说数值微分的计算,基于符号的微分计算。

上节课最重要的事情是,我们讲到了自动微分的一个基于计算图的实现。并且我们提出不仅仅是要基于反向遍历我们的计算图,实现自动微分,我们还提出要在反向遍历计算图的同时,对它进行扩展,把反向计算微分的图,在执行代码的时候也给它扩展出来,构建出我们的计算图,去进行相关的微分操作。这是我们上节课一个比较核心的内容,也是大家习题要去实现的一部分工作,是比较重要的内容。

回到计算图这边,我们接下来的课程,一方面是要给自动微分进行一个收尾,进行一些扩展性的讨论。除了讨论以外,我们就要逐步步入到更多跟计算图相关的内容。

至少在上一节课当中,我们对计算图已经有了一个比较明确的具体的认知。我们认为计算图本身是定义了计算规则的一种计算的数据表达形式。同时通过图的扩展,它也可以把我们的反向自动微分的计算,也作为计算图的表达形式包含在这里面。所以这是为什么计算图在我们整个 AI 框架当中是非常重要的事情。

前向自动微分 vs 反向自动微分#

上节课稍微有一点没有提的,就是我们整个做的是一个反向的自动微分,做的是一个基于 adjoint 这个伴随量不断的叠乘、按照链式法则推算的一个想法。但事实上,这个算法也非常类似地可以去做前向的自动微分的一种计算方式。当然前向的话,你算的就不再是伴随量,而是一些切向量。大家可以想象一下,如果你要做一个前向的切向量的自动微分计算的计算图扩展的话,那可能这套方法也是完全适用的。有可能反向遍历会改成前向的遍历,查找 input 的时候有可能要反向去,有可能要前向找它的 output 等等,但总之这套方法也是可以实现的。

上节课到后面讲得比较仓促了,但我依然觉得这里面的一些概念对我们来说是非常重要的,大家可能在很多场合也会看到,为了避免以后看到的时候觉得困惑,课程也有一个目的是把这些概念解释清楚。

我们想说的概念就是关于大家常见的反向的微分计算,同时我们还经常讨论的与它相对立的,也有前向的自动微分计算的一种方式。我们也提到了反向的微分主要是伴随量,前向的微分主要是切向量。

有的时候你会看到一些地方说,反向的自动微分计算经常算的是向量和雅各比矩阵(Vector-Jacobian Product, VJP)的这种方式。前向 mode 是反过来做雅各比和向量乘(Jacobian-Vector Product, JVP)的计算方式。如果只听了这两句,也经常会让人感到困惑。

首先,向量和雅各比矩阵乘,这里的向量其实指的就是我们的伴随量,即因变量对我们中间各种各样 Operator 计算出来的这个量的微分。与之相对应的,前向的时候涉及到的雅各比矩阵和向量的乘,这个向量其实指的是我们的一个切向量,即中间变量对我们的输入自变量的一个微分。

另外一个是雅各比的定义。雅各比矩阵肯定是因变量对自变量微分形成的矩阵。但在我们局限在这样的一个环节下,你关心的雅各比矩阵,它应该不是整个网络(比如说 Y=f(X)Y = f(X))的 YX\frac{\partial Y}{\partial X} 产生的雅各比矩阵。这里的雅各比矩阵,其实单指某一步当中,你有一些 input 是 vjv_j,有一些 output 是 viv_i,那么这些所有 output viv_i 对于所有的 input vjv_j 它都会产生一个微分,整个微分它是一个雅各比矩阵。

自动微分的一个前置要求,就是要求这些微分都是有定义的、好算的。具体来说,它们可以是矩阵乘、sine、cosine,总之是一些基础的微分计算操作,都是我们底层算子能够提供的一些微分计算。而自动微分所做的事情,就是把算子提供的这些微分组织成雅各比矩阵以后,给我不停地去做反向的伴随向量和这个雅各比矩阵乘的计算,最终直到我推导出我的因变量对于我所有的自变量θ\theta的一个微分。

雅各比矩阵的计算与 PyTorch API#

对于整体的雅各比矩阵来说,我们的反向 mode,考虑的是一个一个 row 的情形。如果是前向的话,你考虑的就是一个一个 column 的情形。

这些概念在什么时候比较有用呢?在一些计算当中,我们经常还是会需要算一下雅各比矩阵的。你已经有了自动微分的工具,比如说 PyTorch 提供了 torch.autograd.grad 这个函数,那它是不是能够服务于我们的雅各比的计算?

我们知道,如果你正常的调用 PyTorch 的一个自动微分计算,它一般会采用的是 Adjoint 模式,也就是你算的是某一行。这样的话,如果你的任务是想要去构建一个雅各比矩阵,你真正做的事情肯定是需要调用它给你提供某一行。这时候,因为我们考虑的是雅各比矩阵,所以我们的 Y 有可能不是一个 scale,它是有很多个维度的。我就需要针对它每一个维度 YjY_j,去考虑这个维度的 YjY_j 相对于我所有的自变量 X 它的自动微分是什么形式的。我把每一行都算出来了以后,我可能是需要按照行去做一个 concatenate,一个堆叠,最终就能完成这个雅各比矩阵的计算。

整体我们认为自动微分跟 AI learning 都跟我们的微分计算有莫大的关系。最重要的是自动微分,涉及到了 backward 以及 grad 的 API 等等。大家现在在很多 AI 相关的写代码当中,对这些 API 都已经不陌生了。现在我们已经学习了自动微分的形式,大家也可以再回顾一下 PyTorch 的文档,对于这些 API 有什么更具体的定义,就能有些更深刻的理解。

除了我们了解到的默认用反向微分算伴随量的 AutoDiff 以外,我们发现 PyTorch 2.0 也已经开始支持一些前向的自动微分了。当然从 PyTorch 2.0 开始,这是一个实验性质的功能,即使到目前为止,它也依然还是一个实验性质的功能。甚至它还有更高级的一些提供,比如说就直接给你提供了 Jacobian、Hessian 矩阵的运算等等。了解它的微分机制以后,你也可以想象一下这些 API 都是如何实现的。

课程作业:操作符重载与自动微分实现#

回到我们的代码习题上,大家的任务是实现一下操作符重载,来构建一个基于反向的自动微分算法。一方面大家习题就需要做这个事情,另外一方面,大家在大家的大作业当中,可能也可以根据这个习题,对我们之前的 lab 有一些回顾。比如说之前我们的 Tensor 是怎么设计的,我们之前每个 Operator 它的一些 gradient 是怎么反馈的,我怎么把它用一个 AutoDiff 的反向微分操作,在 Python 里面能够把它连接起来。

Tensor 这个类,它其实是我们面向用户的接口,经常是一些神经网络的参数等等。总之,很多时候我们认为它是需要被优化的,它就会有一个 requires_gradient=True 的参数。大家其实就是要支持给它算自动微分的功能。

我们知道了,要做微分的计算,如果要做 gradient 的计算,这个东西是需要在操作符重载当中完成的。所以你就会发现在这个 Tensor 类当中,你有可能会要做一些操作符的重载。我们这里也是举了一个最简单的拿 PyTorch 举的例子。比如说你两个可以优化的是 A 和 B,Q 可能是一个你最终要用来被优化的因变量。Q 我们也使用了最简单的规则,比如说就是一些幂、乘、加、减等等。如果它在做这些操作的时候,它就会自动地去调这些操作符的重载。我们就知道在操作符重载的时候,应该要记录这个 Tensor,它的 Operator 对应的 input 是谁,把它记录下来,才能够服务于我们后续的一个自动微分的计算。

我们认为在你刚声明这两个 Tensor 的时候,A 和 B 显然都还不具备 gradient 函数,应该还是 None 的状态。主要就是因为你还没有去做反向图的构建。我们认为反向图应该是从 backward 这个函数入口开始构建的。如果开始调 backward,这时候就意味着我们可以去构建整个反向图,去进行相关的反向图构建操作。这里就应该涉及到我们课上提到的,根据 gradient 如何去用 adjoint 的方法去算我们的 gradient 的相关操作。

Tensor 类与 Value 类#

大家在构建图的时候,会发现图里面肯定是需要各种各样的表达的。这里面的无论是 adjoint 还是计算图中间的一些 viv_i,它都是需要一些表达的。这时候你可以用的就是一个 Value 的类。我们注意到 Tensor 也是 Value 类的一个继承关系。那 Tensor 和 Value 肯定是有所区别的。

我们认为 Value 更多的是我们计算图当中的一些节点,它很重要的一个作用就是来存储它是一个怎样的 creator(运算),同时它又包括哪些 input。有了这些 input,这就是我们在反向遍历图的时候一个主要的依据,你是怎么查找每一个节点它的 input 的。所以 Value 是我们辅助计算图运算,作为计算图节点的这样一种数据结构。而 Tensor 更多的是用户的接口,让用户可以去调 gradient,自动微分。写一写这个习题,都会对 Tensor 和 Value 它们服务于怎样的目的有了更清晰的了解。

gradient 最开始的时候它可能是 None,但是你过完 backward 以后,这个 gradient 就应该已经有了我们赋予的一些节点,那么这些节点在需要的时候,就应该可以被计算。

操作符重载的优缺点#

我们大概讲的就是基于操作符重载的一个反向图扩展的自动微分计算。总结来说,我们觉得它的优点是实现起来比较简单。你只需要去做一个操作符的重载,就意味着你可以比较好的,在 Python 语言上去把这样的思路实现。这样用户在编程的时候,他就直接用 Python 原生的加减乘除就可以用,而不需要去调一些很特别的 API,也就意味着用户使用的一个应用性是比较高的。

相应的就是我们也会分析一些它的缺点。比如说,你需要在构造的时候,在每个 variable 可能会需要有它的 list 来标注 input 和 output。当然如果你不用 variable list,你也可以全局有一个 tape,但无论如何,需要有一个数据结构来记录每个运算的输入和输出,这些是需要一些额外的构建。

上节课还提到了一个缺点,就是说扩展计算图的方式来计算自动微分的话,虽然它的好处是你可以拿它算高阶微分,但它的一个不足也在于,当你不停的基于这个计算图去扩展更高阶的微分的时候,图的扩展它是指数增长的计算图复杂度。这也是我们当前神经网络可能在处理高阶微分的时候,效率会显著下降的一个很关键的挑战。

除此之外,我们接下来的课程马上就要解决的一个很重要的问题,是对于我们的分支语句。大家可以想象,我们之前讲的这套操作符重载的思路,它对于加减乘除肯定是非常友好的。但是如果你是 ifelse 的分支语句的话,它就并不能够通过操作符重载来很好的支持。你有的时候按照原生语句,你可能只走了某一个分支,另外一个分支有可能不会被我们涉及到等等。它就给我们的计算图的构建带来了很大的一些挑战,也是我们后面课程讲计算图这个部分,最重要围绕的一个核心。

自动微分的扩展讨论:高阶微分#

现在我们对自动微分做一些简单的扩展性讨论。第一个事情是,我们刚才已经提到了,有的时候你是需要计算高阶微分的。我们想看一下哪些场合需要计算高阶微分,以及带来了怎样的挑战。

可能我们经常做的是梯度下降的优化。对于梯度下降优化来说,我们只需要考虑一阶微分就够了。但是其实对于数值方法来说,我们过去常用的优化手段并不是梯度下降。在很传统的工作当中,牛顿优化也是非常常见的一类方法。

我们可以先简单的提一下牛顿优化。它可能不仅仅是沿着梯度这个方向去下降,它还有一个固定的步长。一阶梯度下降,你的步长是一个用户可以调的超参,是一个 learning rate。但是对于牛顿法来说,它就会假设有一个更合理的步长,就是你可以一步走到最优的解。他认为这个步长就应该等于当前函数 Hessian 矩阵的逆。这里就涉及到了 Hessian 的逆,其实就是梯度的梯度,涉及到了二阶微分。所以如果你要用牛顿优化的话,你是需要二阶微分的。在很多时候,如果说跟物理相关的一些问题,大家就往往会在优化中发现牛顿优化的收敛会比梯度下降更好。

除此之外还有一些,比如说元学习(Meta-Learning)。如果大家感兴趣可以搜一下 MAML (Model-Agnostic Meta-Learning) 这样的名字。具体来说,他可能用一个网络学的是另外一个网络的 weight。这个网络 weight 的优化本身是一个一阶微分,你希望把这个梯度进一步传到你这个元网络上,你就会涉及到梯度的梯度这样的一个概念。

还有一大类就是我们的物理。比如说很多物理本身是一个偏微分方程,那如果你要基于这个偏微分方程去构建某一个损失函数,比如说你认为这个压强场的梯度应该等于某一个值,那你就建立一个压强场的梯度减去某一个值的 L2 Loss。假如你这么做的话,因为本身这里就涉及到一些微分计算,那你有可能也会涉及到这个微分的微分。

自动微分的扩展讨论:其他探索#

所谓的其他探索就是,自动微分到底是不是最终的方案?这里面想提到的一篇工作,也是图形学的一些人做的,它比较有意思的是从图形学角度运用这个数值微分,来去进行微分的计算。这个数值微分就是后面的 Finite Difference,前面的 CS 指的是 Complex-Step 的 Finite Difference。

我们之前提过说数值积分 Finite Difference(有限差分)也是能够算微分的。但我们提过如果你用的是有限差分,它算的微分的误差可能是呈现出这种折线的一种形式。我们之前也讲过有限差分的主要原因,可能在这个小量(ϵ\epsilon)比较大的时候,原因是因为你丢掉的小量有点太大了。而在小量比较小的时候,我们认为有限差分产生误差的一个主要原因,是因为浮点数的精度问题。两个大数作差的时候,有可能小精度已经不是浮点数能够表达的了,就产生了一些相减的浮点数误差。

这篇文章主要攻克的是这样的一个点。也就是说,如果你用一个浮点数去表达,你就会存在浮点数的精度不足。这篇文章想的想法是,我这个数本身就是一个复数,我利用复数域来做这个小量的相关信息的存储。这样的话,原数无论是多大,我都能保证复数域额外有它的一个小量的自由度,能够保证它不会被原来的精度太大、保存不下的这个问题所干扰。它的一个效果也是这个橙色的这条线的效果。也就是说,他已经避免了只要我们这个小量足够小,他就一定能够保证这个误差被抵消,一定能够收敛到非常接近我们真实的微分计算情况。

扩展讲这样的内容,其实是想说,未来到底我们该用怎样的方式去计算微分,还是有非常多的一种可能的。

这篇文章它也是跟现有的 PyTorch、TensorFlow、JAX 等等的 framework 去进行了对比。在这些对比当中,你经常会发现,如果你只需要去计算 gradient 本身的话,通常可能 PyTorch 或者是 JAX 都会给你一个非常不错的计算效率。但是如果说你是在计算一个 Hessian,一个二阶的,那往往我们这样的一个 performance(指 CSFD)就会非常好。

这里其实主要就是想说,传统的这种反向图构建的自动微分方法,它就会在二阶微分的工作当中面临指数构建图膨胀的问题。而对于我们 Finite Difference 来说,它就不会遇到这样的一个问题。所以大家实际的使用当中,其实应该要根据你的微分计算的实际需求,到底你是自变量比较多还是因变量比较多,到底你是要算一阶的微分还是二阶的微分,根据你的需求去选择最为高效的实现手段。

自动微分的扩展讨论:自动积分#

另外跟自动微分也比较有趣的事情是,我们现在知道神经网络能够做自动微分了,那在底层的数学当中,除了一系列微分问题,可能还有一些积分的问题也是我们很感兴趣的。我们很可能也很希望能有一些工具能够做一些自动积分类的工作。对于一些比如说渲染等这一类的工作,可能积分也是比较需要解决的一些难题。

这篇 paper (AutoInt) 它的思路也是比较新颖的。简单来说,它也是运用了一个神经网络。通常我们会拿神经网络的自动微分去做梯度下降。这篇文章提的思路是,我就拿网络自动微分所构建出来的计算图来做我的神经网络。这样的话,它的输出就直接是我网络的输出。

如果说这个本来是用来做微分计算的一种输出,你要求它去拟合了某一种你知道的一些数据,做了一些网络的训练。你可以想象,这时候与它相对应的就是有一个与它的积分版本。因为你这个是自动微分的图,所以它本来的这个网络就是一个它的积分版本。这样的话,只要我的微分计算满足了你想要的一些微分量,那这些微分量积分的结果就可以通过原始的这个图去进行获得。原始的图就可以给我们提供一些很便捷的积分操作。

计算图#

计算图的核心地位#

我们接下来就进行我们的计算图的部分。

我们认为计算图是我们 AI 框架当中最核心的一个部分,它是我们最核心的一个表达。我们上节课已经讲过了,AI framework 它要完成很多功能,所以它中间有很多模块。这节课我们就可以把所有跟计算图相关的模块拿出来。

最核心的部分就是它是我们一个统一的计算表征。它不仅可以去拿来做前向的运算,同时也可以扩展跟神经网络相关的反向的优化的计算,也可以在我们的计算图上完成。这就使得计算图成为了我们 AI 框架最重要的、最统一的表征之一。

围绕这个计算图,我们要学的事情、AI 框架要完成的事情,就包括如何基于计算图去做自动的微分求导。这个求导我们希望整个过程能够很好的去支撑我们前端的编程语言,比如说 Python,这就是我们上次课的内容。

我们在后面的课当中要讲的内容就是,有了这个计算图,为什么大家都选择去构建计算图?我们其实也提到了一点点,就是我一定要做图的扩展,主要是因为有了这样的一个数据结构,我就可以在它的基础上去做很多的优化,就可以使得我们的跟 AI 相关的计算能够更高效地去执行,甚至可以面对不同硬件去做一个高效的调度。所以图的优化和调度都是我们讲完计算图之后的课程要去重点介绍的内容。这些内容其实它跟编译有很多的相关性,也一般都是被我们称作 AI 中的编译的这样一部分内容。

计算图的构成:节点、边与控制流#

我们来看一下计算图有哪些构成,以及我们关心计算图关心哪些事情。首先是计算图的一些定义。我们上节课其实也讲到了,计算图它就是一个有向无环图 (DAG)。有向主要是计算本身是有向的,有输入有输出就决定了它有向。无环主要是因为我们需要有一个完整的从输入到输出的便利方式,如果有环的话,本身在计算上是一些很复杂的逻辑,也不利于我们的图的构建以及自动微分的计算。

如果你关注这个计算图,你会发现它的所有的顶点其实都应该是我们的算子(Operators)。也就是这些节点都是比如说加减乘除、卷积计算等等。它的边才是我们的这些数据,就是输入输出,计算的结果。这可能跟我们以前的这种数据流的感觉有所不同。所以我们整体这个计算图更像是一个计算流或控制流的图。

我们认为顶点描述的是所有的相关计算,并且每一个计算以后,它都会决定这个计算对应的反向的一些计算。所有的边它其实是意味着我们整个图当中的一个数据的流通。

这里面也要提一些特别的。计算图里面最重要的事情,也是不像加减乘除那么容易做的事情,就是我们的控制流。控制流它从逻辑上是和我们的计算类似的,它应该是一个 Operator,也就是说它应该是一些节点,但是显然它不像其他的一些节点可以操作符重载去实现。所以我们认为控制流是我们计算图当中一类特殊的 Operator,是我们后续要重点讲的。

另外就是除了我们作为数据流动的这样的一些边以外,我们还有一些特别的边,是我们的这个依赖关系(Dependency)的边,这都是我们后面要去讲的内容。

特殊的边:依赖关系#

我们现在先看一下这个依赖的边到底是怎样的一种特殊的边。

首先我们最常见的,我们认为 Operator 是我们的节点。如果你的 Operator A 直接使用了 Operator B 的一些输出,这时候我们都能理解它们之间一定是存在数据流的,一定有一个 Tensor 从 B 走到 A,所以这时候你会有一个直接的边。这一类的就不属于我们特殊的边,是属于一个普通的边。

这种特殊的 dependency 什么时候是存在的呢?就是说假如 A 并不是 B 的 output,但是它是依赖于 B 的,那时候你就会存在一些间接的有依赖关系的边。如果通通都没有,那么他俩才是完全独立的节点。

举例来说,你可能有这样的一些 Operator。Operator A 的输入是 Tensor 1,输出 Tensor 2。Operator B 输入是 Tensor 2,输出 Tensor 3。还有一个 Operator C,输入是 2 和 3,然后输出是 4。

此处口误,根据上下文逻辑,B 的输入应该不是 Tensor 2,而是其他独立的 Tensor,才能和 A 并行。我们假设 Operator B 输入是 Tensor 1’,输出是 Tensor 3

有了这样的一个计算图,你可以想象在后续去做优化的时候,你很可能就会选择 Operator A 给一个 GPU 去算,Operator B 给另外一个 GPU 去算,因为你认为他们两个是完全独立的。他俩都算好以后,我就可以去执行 Operator C。

但是有的时候你可能中间会存在一些依赖的关系。举一个很极端的例子,可能说比如说这两个(A 和 B)它共用了同样的一些 GPU 的显存。如果说存在这种依赖关系的话,那对于我们当前的这个计算图,你应该构建的就并不是上面这个结构了,而是底下的这个结构。我们认为 Operator A 和 B 存在依赖关系,那你就多了一个 dependency 的这样的一条边。对于这一类的依赖的边,它主要的作用是用在我们后面图优化等等。有了这种 dependency 就能保证我们在优化的时候,也不破坏我们本来的计算逻辑。

特殊的节点:控制流#

比较复杂的地方还是我们的控制流。本身控制流它应该是一个操作符,我们现在也都理解它一定是一个很复杂的操作符。它的复杂性就主要体现在它很难完全使用操作符重载来实现。

控制流实现其实有一些不同的思路。

第一种,我们可以用一些专门的原生的实现去实现这些控制流。所谓原生的实现,比如 TensorFlow 1.0,他们就会说你不能够调用 Python 自己的 if else,你是需要去调用我自己专门的 TensorFlow.if,TensorFlow.else。如果你调用的是 TensorFlow 自己的 API,那 TensorFlow 就可以专门地对这个控制流去做一个计算图的完整构建。这一类的话,它的主要好处是在我们构建计算图时,它是非常流畅的一种构建,思路是很完整的。但它的不好的地方很显然就是你引入了一些额外的 API,对于 Python 本身的 if/else,你可能会破坏你的整体逻辑。这也是为什么 TensorFlow 1.0 很快就被 AI 框架所淘汰的原因之一。

第二种,我们肯定也比较了解现在 PyTorch 的一种实现方案。PyTorch 的实现方案是沿用 Python 前端的语言去完成这样的工作。那如果说你要沿用 Python 本身的一些 if else 这一类的关键语句的话,我们肯定就需要想办法去兼容这种逻辑,使得我们能够构建一个完整的算图。如果说你不能构建一个完整的计算图,它往往会带来很多计算的问题。事实上,PyTorch 1.0 确实会面临非常多的问题,这也是我们后面要去讲的。

第三种,也就是最新的版本是怎么想的呢?我首先肯定想要类似于第二条,让大家能够用 Python 这种前端框架,自然地去写 if else 这样的编程代码。但是我又希望能够做更多的事情,完成我的一个完整的图的构建。那这时候他就会涉及到很多跟编译相关的。我能不能把 Python 的语言去编译成我的计算图,然后就能够获得更好的计算图的支撑。

这三类都是我们控制流的一些实现,相信大家也对他们这个复杂程度有了一定的理解。

计算图的构建方式:历史演进#

我们这次课主要介绍的内容就是怎么来构建我们的计算图。对于这个问题来说,其实大家习题就已经要去自己去构建计算图了。大家就会发现,对于一般普通的 Operator、普通的边来说,计算图的构建都没有特别多值得讲的,就是一个具体的工程实现的问题。但是到了控制流以后,它使得我们的计算图变得复杂了起来。研究怎么构建计算图,等价于研究怎么去支撑控制流。

相应地,从历史发展上,他主要经历了这样的发展历程:

  1. 最开始是以 TensorFlow 1.0 为代表的这种声明式编程 (Declarative Programming) ,对应的一些静态图 (Static Graphs) 的构建。
  2. 接着就是 PyTorch 为代表的这种命令式编程 (Imperative Programming) ,对应的是一种动态图 (Dynamic Graphs) 的构建。
  3. 最后是我们现在最主流的实现路径,静态图和动态图的混合 (Hybrid)

命令式编程 (动态图)#

我们就先从我们特别熟悉的这种命令式的动态图的构建方式来了解该怎么构图。这个其实是我们上节课就已经在介绍的内容,对应的版本就是 PyTorch v0.x 和 v1.x。

具体来说,大家就可以想象你可能会有 Tensor 类,有 Value 类,Value 就是我们计算图当中的节点。当用户在声明一些语言的时候,我们整个是一个即时(Just-in-Time)的编程,随着语言自动的一起去走。走到对应的每一行,对应的它就会有它的一个即时的构建。当你调用了某一个 Operator,比如说矩阵乘,或者是一些 clamp,或者是 power,当你调用了任何一个这种 Operator 的时候,我们自动地就去记录我的计算图当前的 input 是什么,把计算图的输入输出连接关系都给它构建好。

整个这样的一个随着用户的调用语句、自动地利用操作符重载等去同时构建我们计算图,并且在用户调用 backward 的时候,采用我们上节课提到的这种扩展式的计算图构建,这一类的就是我们所谓的命令式的动态图构建。

对于这一类的动态图构建,它最大的特色是什么呢?就是随着你代码的写,这个图才逐渐的被构建出来。当你调 backward 的时候,就在这个时刻我把反向的图构建了出来。当你算完 gradient 之后,马上我就会把这个计算图清空掉。

大家可以想象这个计算过程,神经网络通常是说有一个主循环,你要不断地去做 loss 的微分计算,去做梯度下降。这个主循环就意味着每一次我都要重新按照我的语句去做矩阵乘、做幂的运算等等,我要重新按照这个语句去动态地、每次都要重新构图。构完图以后我就算一次微分,算完了以后这个图就被我消掉了。所谓动态图的构建,就是图最终没有被存下来,整个都是在随着我们的语句一次一次的在被进行构建。

它的优缺点也是非常明显的。优势当然是非常的灵活,尤其是如果你 debug,你可以随时中断。它的劣势肯定是性能并非最优。

动态图中的控制流#

我们来看一下这里的控制流它是怎么实现的。这里控制流它也是一个自然的逻辑。比如说你是有在前向的有运算,然后有可能有一些控制流 if else 的出现。那对于 PyTorch 来说,它因为是一个命令式的,它最自然的选择就是支持原生的 Python 自己的控制流 if else 的操作。

所谓命令式,就是说我到这里的时候,我才去做相应的图的构建。比如说我就会根据我当前的 flag 的状况,来去判断我走的是上面的这个分支,还是走的是下面的这个分支。事实上就意味着对于这样的一段代码,我每次构建的都只是这个代码中的一部分的计算图。我每次其实会根据当前的情况,去构建我自己的一个当下的、代表我当下计算的计算图。

你就会发现你的计算图其实是不完整的。但是总的来说,因为你的逻辑是不停地每一次都重新构图,所以只要这一次的逻辑是对的,那就够用了。等到下一次的时候,你可能要走其他的逻辑了,没关系,我重新构图,它很自然的就把其他的部分构建出来。

所以说我们看到,对于这种命令式的动态图构建的话,它的控制流是能够支持原生语言,然后根据动态的情况去进行相应的一个构建。这种模式也经常被我们称作 Eager Mode。它最主要的特色就是我们最终构建出来的图,一定不是一张完整的图。

声明式编程 (静态图)#

现在我们来看一下已经被淘汰了的这种静态图的构建思想。之所以要讲它,一方面是因为它是一个历史上发生过的比较重要的事情,另外一方面是因为我们后面会看到,我们逐渐开始去做一些静态图和动态图相融合的事情。

对于静态图的构建,我们把它叫做一个声明式的,或者说你花了一部分时间来去构图,构完图之后再去做计算。这一类的主流的代表就是 TensorFlow v0.x 和 v1.x的系列。

在这个代码形式里面,你会看到 TensorFlow 它会设计它自己的一些原语,用来做图的构建。这就是声明式命名的来源,就是你每次都要声明这个图有哪些节点,有哪些连接关系。

你会看到所有上面的这些地方,全都是在做声明式的静态图构建。比如说 TensorFlow 有一个 placeholder,这就是说我要有一个输入,它应该是一个 float 型的,它的大小是什么样的。至于它具体是多少,我还没有告诉你,我只是告诉你有这样的一个节点。同样的 Y 也是一个节点,还有一些中间能优化的 weights,W1、W2 等等。中间我会进行各种各样的计算。

所有的这个过程都没有实际的计算在发生,而是 TensorFlow 利用他自己的这些原语去搭建了这样的一个图。真正构图,把所有 input/output 连成一张图的操作,在过去的 TensorFlow 其实是在 tf.Session 的 API 下去支持的。也就是说当你去调 tf.Session 这个 API 的时候,它才真正的把整个图按照一个静态图的思路把它构建出来。

我们后面会频繁的使用 Graph IR 的描述,它就是一个图的临时表达(Intermediate Representation)。

对于 TensorFlow 来说,它的设计逻辑就是说,用户先使用这些原语来提供这些操作,接着它会调 Session。在 Session 的时候我会构建这个图,并且我就构建这一次。等我构建完了以后,如果你想去跑这个图的话,你需要做的事情就是调 Session 去 run 某一些节点。比如说有一些节点是做初始化的,有一些节点是算 loss 的,有一些节点是做梯度下降优化的。当你有了这些图以后,这些对象就都是一些可以让 Session 去 run 的。你看到他的逻辑就是说我构图就够一次,但是后面我就会 run 非常多次。

静态图中的控制流#

静态图我们也可以总结一下它的优缺点。它的优点主要就是说我只构建了一次计算图,并且我允许优化,我整体的效率就会是非常高的。同时这个图构建完了以后,它作为一种计算的表达形式,它也可以面向硬件去做硬件专属的优化,去做硬件专属的计算分割,把相关的子图部署在不同的硬件上。

它的不足主要体现在这样的编程手段不够灵活。对于用户来说,他首先要学习一些专门的原语,一些本身不是 Python 原生的 API。然后另外,在这些原语的设计当中,我们也要特殊的去处理控制流。

对于 TensorFlow 来说,如果你有用它的原语去构建相应的功能,你会发现对于控制流的话,你要调 TensorFlow 自己设计的一些原语。比如说对于这种 if else,它其实就涉及 tf.condition,然后根据某一个 condition 去决定走分支 1 还是分支 2。当你要 Session 的时候,TensorFlow 会构建图,他会根据那个逻辑构建,你就会发现它同时会把 ifelse 的两个分支都构建在我完整的一个算图当中。这就是 TensorFlow 静态图的一个最基础的形式。

动态图 vs 静态图#

我们就可以把它去做一些相应的对比。

第一个问题,这两个究竟谁能够给你计时的计算和反馈?答案也是比较显然的,就是动态图,你可以一边动态执行,一边每个值都在计算,所以它是支持这种临时调度的结果。相应地,你就可以很自然的在这里面去做一些 debug、检查错误的操作。这两个对于静态图来说都是非常有挑战的事情。

我们刚才也额外关注了一下控制流。我们就会发现 TensorFlow 需要特殊的语言结构去支持控制流。而对于动态图来说,它就使用前端语言的语法,同时它每次重新构图也能保证我们控制流的逻辑是正确的。

更多的我们其实也关心一些面向实际使用效率上的事情。总的来说,在效率上,其实静态图都是会优于动态图的。比如说内存的占用,因为你可以做优化,所以你可能占用的会低一些。而且你不用每一次都构建,所以你的效率往往也会是更加好一些的。最后它还可以面向部署的时候做特殊硬件的优化,所以它的部署也会是更直接。所有在 performance 性能上都是静态图比动态图更加占优。

这里面我们提到了一个稍微复杂一点的事情,就是关于控制流和数据流的相关性。有的时候你会有一些控制流,比如说分支语句。这时候分支语句你可能依赖一些数据流。比如说某一个数据大于零你走这个分支,某一个数据小于零你走另外一个分支。这就是一类数据流和控制流发生了混合的情况。

如果说你发生了混合,对于我们静态图来说,因为我直接就把整个图全都构建出来了,所以你走哪个分支它都是存在的,并且我可以根据这个分支去准确的优化我的图。

但是如果你用的是 PyTorch 的话,你可以想象一下,在你读到这个数据之前,你并不知道你在走哪一个分支。等你读到这个数据之后,你决定走哪一个分支以后,你才知道你有哪些具体的运算。那这里就有可能出现一些问题,就比如说你的数据流可能在 CPU 上,然后你的底下的控制流里面的那些数据在 GPU 上。这时候你就会出现一些不同硬件的切换,这也是带来我们效率上损失的一个比较重要的点。

走向融合:动静态图混合#

我们已经感受到了 PyTorch 灵活性的好处,我们也感受到了过去的 TensorFlow 高效性的优势。现在我们想做的事情,其实是我们是想要这两个优点都能够并存的。这就逐渐产生了去把静态图和动态图的技术相融合的手段。而且这确实是我们现在主流的 AI framework 非常一致的选择,大家最终都已经殊途同归走到了动静态图互相混合的流程。

如果想要做动静态图混合,其实它的核心思想也比较容易理解。我们肯定是希望想要利用动态图的灵活性,又想要去利用静态图的高效性。最终大家选择的一个方案就是说,我在整体上我希望我依然是一个动态图,也就是说我整体上是需要非常灵活,我有可能各种模块是需要各种各样的连接。但是对于每一个局部来说,有可能我会希望,比如说这个子模块,它相对来说就完成这样的一个事情,它就比较固定,那我可能就希望把它做一个静态图的部署和优化。相应地,我就可以有子图的优化、子图面向硬件的部署等等。

混合图的实现(一):Tracing (追踪)#

这里面一个比较重要的事情就是你的力度怎么去区分。究竟是谁来定义怎样是一个子图?很自然的一个选择就是我们的 function,这个函数式的编程。大家在 Python 里面 define 一个 function,这样的一个 function,我大概就认为它就可以作为一个子图,被固化下来。你把这个静态图固化下来的过程,我们就可以把它叫做图捕捉 (Graph Capture)。

我们现在的静态图创建,已经和我们刚才介绍的 TensorFlow 1.0 的静态图创建有了一些本质的差异。最大的差异就是说,我们并不想要用一套另外的原语 (API),我们已经发现,如果还需要用户再去记一遍,是很影响用户使用的一件事情。所以我们就决定即使是一个静态图的构建,我依然还是要尽量去支持前端的一个语言,只不过这个时候我前端的语言支持也是有限制的,就比如说你必须得是一个 function 的形式。

我们接下来要讲的就是这样的一种,主要的图是一个动态图,然后中间会有一些 function 被我们捕捉构建成一些静态图的混合式编程逻辑。我们也提供两种这种静态图捕捉的思路。第一类的思路是一些追踪式(Tracing)的思路,它跟我们前面讲的命令式的动态图构建更相关。另外一个其实是更主流的,是基于我们编译(Compilation)的方式的思路去进行构建。

我们最先想看的就是一些基于追踪去进行图捕捉的方式。它最接近我们动态图的构建,也是最容易理解的方式。我们前面已经说到了,你跟着命令式,你就能动态的构建出来这个图。如果你想把这个图捕捉到,你最简单的方式就是你不要删掉这个图就可以了,你把这个图留下。

这个例子,Trace-based 的这类方法是在 PyTorch 一点几就已经有的一些 API。我们要讲的是 PyTorch 的 symbolic_trace (torch.fx.symbolic_trace)。这个 F 就是大家会在 Python 里面 define 的一个函数。然后你可能认为它是一个比较重要的图,你想把它固化下来,你就可以调用 torch.fx.symbolic_trace 的方式。这样的话你就会得到一个 FX model,这个东西它就是一个 PyTorch 已经固化了的计算图。

似乎看起来就是把动态图固定一下就完事儿了。之所以我们后面还有很多内容要讲,我们现在就是想关心一下这种 symbolic(符号)追踪,它到底有什么问题。

我们后面会频繁的举这几个例子:

  1. 控制流:如果你出现了一些不同的分支,固化出来的图到底够不够用?
  2. 动态形状:如果你形状总是在不停的变,你这个代码能不能灵活的应用?
  3. 外部代码:假如你的程序涉及到一些外部的数据、外部的代码,比如说 NumPy 的代码,它会对我们计算图的构建带来什么影响?

我们首先看符号追踪 symbolic_trace。如果说你执行一下这些相关的代码的话,你会发现它基本上就都报错了。

第一个,关心一下条件分支。我们刚才讲动态图构建的时候就已经提到了,对于动态图构建,你每一次只能构建其中的一个分支,你不可能构建一个完整的图。那如果这时候你去调用 symbolic_trace,然后他就会给你一个 error,他就会告诉你,当前在你没有提供任何输入的情况下,我不知道你要走哪个分支,所以我完全没有办法给你动态地去构建这张图。

接着是形状相关的。它同样也是报错了,理由也是非常一致的。你这个代码它具体的执行逻辑跟你的输入的形状有关,而你这时候你是没有输入的,你不知道输入,所以同样的我也没有办法按照输入的逻辑去动态的构建我的图。

第三个是外部代码的支持。这时候比较有趣的一件事情就是说,他确实执行了这一步并没有报错。这个外部代码它本身的逻辑,是需要外部随机的给我生成一个数值,然后我把它这个随机数和某一个数相加去做一个返回。正常的逻辑每次跑它,都应该有一个不同的结果,因为它是随机的。

但是如果说我们去给它静态构建一个图,并且把这个图捕捉变成一个静态图存下来,对于 PyTorch,你在执行这个静态图的时候,你让他执行两次,看这两次结果一不一样,结果就很显然,这两次的结果它确实完全相同。这就说明我们固定的这个静态图的逻辑,和我本身用户写的 API 的逻辑就发生了冲突。

这其实是我们最不希望的一个状况。我们通常对于一个编程框架来说,我们会希望尽可能的支持用户所有需求。如果说支持不了,我希望在最早的时刻去进行报错。我最不希望的情况就是说我不能支持,但是我执行了,并且我完全不报错,然后我执行的逻辑完全不符合用户的需求,这是最糟糕的一个状况。

混合图的实现(二):JIT Trace#

大家就可以想象一下你该怎么突破我们刚才提到的这些困难。我们刚才提到了很多困难,都是因为我们的输入是不知道的。

PyTorch 最开始的一个想法是,我就不做符号的追踪了,我做 Just-in-Time (JIT) 的追踪。因为我认为在 JIT 的时候,你就应该已经知道 input 是什么。只要你知道 input 是什么,我做一个即时的追踪,我在即时的那个时刻,我动态的构建图,并把那个图保存下来,这个逻辑它应该就能够支撑的更好一些。

按照这个逻辑,我们就到了 torch.jit.trace。我们依然去把我们之前提的控制流、形状改变、外部代码的这些 function,作为我们想要固定的子图,采用 torch.jit.trace 的思路去固化它们。

首先是控制流。这时候你有 input 了,你就完全可以根据你的 input 来去选择我当时的那个代码逻辑里面的某一条分支。但这时候其实它是存在一定的风险的。我们就看到 PyTorch 也很体贴,相应地把这个风险的 warning 在执行的时候就会打给你。他就是说,根据你的代码逻辑里面你是有分支的,我可以根据你的 input、根据你的分支去构建一个图。但是我告诉你,这件事情它对于我当下的输入是匹配的,但如果你将来要用这个静态图去计算其他的输入,他有可能就不再匹配了。

接着我们想看的是形状发生变化的。这时候如果说我即时编译下已经提供了某一种输入,这个其实是非常智能的一个实现。大家可以看到,本身你有一个东西它是基于跟你的输入相关,结果他自己就支撑了一下,他用自己的原语支撑了一下这个形状大小获取的事情。有了这个大小的获取,我就能够去实现你想做的、基于当前输入的一个 shape 来做形变的操作。在 shape 的这个方面上,他就做的非常的好了。

除此之外我们关心的是外部代码。我们计时的这样的一个编译,他至少也是做到了一件事情,就是他做不了的事情会给你来一个 error。也就是我在 trace 的时候,我就发现你掉了一个外部的代码。对于外部的代码的实现来说,我会遇到一些没有办法去进行追踪的代码,所以我没有办法把这个事情作为一个静态图一同的存下来。

我们就认为这个基于 JIT 即时编译的一个追踪的静态图固定路数,它最大的优势就是所有的事情都做的是比较正确的。当然它最大的不足,就是它还是没有完美的去支持外部的代码和控制流。

混合图的实现(三):Source Code Transformation (源码转换)#

我们接下来要讲的就是源码转换的这样一种把静态图捕捉下来的思路。这里我们就已经用到了一些基于编译的思想。也就是说,我首先我愿意支持的是用户使用他自己的源代码原语,比如说 Python 的 API。但是最终我要构建图,那我就需要做一些编译的技术,把这个用户的 API 转换成我的一个语法树相关的构建。基于语法树去得到各种各样的连接关系,去构建我的图。

在 PyTorch 1.0 的时候,它就有了 jit.script 这一类算法。面对这个我们依然还是去测试我们之前提到的三类 function。这里 torch.jit.script 一种逻辑就是你可以使用 @torch.jit.script 这个 decoration(装饰器)去修饰它。如果你做了这个修饰,其实就是在告诉 PyTorch 的编译组件,这些 function 将来是我希望能够固化成静态图的内容,你就需要去对这些 function 进行一些子图的、基于源码编译的构建。

如果你做一个图的代码的打印,你就会发现之前这个分支语句,在这里你就会发现它其实整个分支全部都在这里了。这就是一个源码转换的核心好处。因为整体我看到了整个的源码,我其实做的就是一个 Python 语言向我的计算图语言的一个翻译工作。我翻译的时候我就看到了这些语言,我就把它都在我的图当中去进行一个准确的翻译。我就可以在这个时刻完整地把我的整个的控制流构建下来。

我们也可以看一下跟这个形状相关的。作为一个源码的翻译,它的好处当然说控制流我就可以完成支持。但是如果说你有些其他的东西是依赖输入的,这个时候我翻译的工作依然是没有办法做。就是因为你没有输入的话,我不知道这个东西到底应该变成什么样子。这时候你就会发现,它在执行的时候,它会出现一个报错。对于这种形状相关的代码,它还是依然没有办法依靠一个简单的翻译工作去进行完成。

最后还有外部代码的操作。你也会发现它虽然我尽可能的去支持各种各样的 Python 语言,但是如果说你调用的是一些我并不支持的 Python 语言,比如说你调的 NumPy 这个包,那对于这些外部代码来说,我还是没有办法去执行翻译操作。

混合图的实现(四):PyTorch 2.0 (Bytecode Transformation)#

这个事情最终是在 PyTorch 2.0 的版本被解决的。PyTorch 在 2.0 的时候提出了它的 torch.compile 功能。PyTorch 也是在提出了 torch.compile 这个功能之后,非常骄傲地宣布,我们认为我们已经完整的解决了所有的动静态图混合构建的问题。

所有跟 torch.compile 相关的,其实就是 PyTorch 在 Dynamo 这样的一些研究工作。

我们刚才介绍的源码转化,它存在各种各样的问题,最主要的一个问题有可能就是我的输入依然是不知道的。那到了 PyTorch 2.0,对于 torch.compile,它该怎么来解决这样的一个问题?

他这里的一个思路是非常有趣的。就是他说我与其去源码翻译 Python 的语言到我的图,我选择更进一步。因为 Python 本身你也有你的编译器,你 Python 编译的时候,也会产生各种各样的一些编译的中间语言。所以对于 torch.compile,他就说我现在还是要等一等,我要等你的输入都完全确定了以后,再去做我的源码翻译。

而等你的输入都确定以后,Python 本身自己会做什么事情?Python 本身是一个即时编译的语言,他在确定了它的输入以后,它自己会产生 Python 自己的字节码 (Bytecode)

torch.compile 它的一个逻辑就是说,我不去翻译 Python 的 API,我去翻译你的字节码,把你的字节码翻译成我的图。因为这个时候你的字节码就是根据你各种各样的输入去构建的,这个时候我再翻译构建的图,就一定是符合我现在输入的逻辑的图。

我们这边也可以来看一下。我们调 torch.compile 的时候,最前面是你要做的 function,比如说就是我们刚才的分支语句的测试函数。接着你还可以放一个 backend。这个 backend 其实它做的是计算图的优化,我们现在还没有讲。我们索性在这里就直接放一个不做任何优化的后端。我们后端做的事情就是把这个计算图打印一下。

我们调 torch.compile,把这个函数输给他。执行完的时候,其实什么都没有发生,因为他要截断的是 Python 的字节码,所以在你没有给任何输入的时候,他什么都不做。

那接下来我们就可以执行它了。你可以调这个 function,给它各种各样的不同的输入。当然这种不同的输入就意味着在这个 function 当中,我们可能会走不同的分支语句。比如说我们的条件是小于 0,那可能前两个(输入 0 和 1)都不满足,我们走的都是 else 这条语句。然后后面(输入-1)走的是前面的这条语句。

当你执行第一个逻辑(输入 0)的时候,你就会发现 torch.compile 就开始建图固化子图了,并且把这个子图打印出来。在第一步的时候打印出来了部分子图。如果你仔细看一下的话,这一部分其实只有这个判断,两个分支它都还没有走。但是执行完这个判断以后,马上他就会把第二个子图打出来,那他可能是走了 else 这个分支。

这整个事情就是说,在执行这个函数的时候,首先 Python 它自己就生成了走 else 这条分支的字节码。然后 torch.compile 就捕捉到这些字节码,并且进行语言的转换,构建成了自己的两个图(一个判断图,一个分支图)。这两个图还是两个子图,最大的优势就是说上面这个子图(判断图)可能会频繁的被复用,等到第二次我就不用再一次构建了。

你就会发现在你第一步(输入 0)的时候,他先完成构图,完成构图以后它执行你的代码逻辑,打印的是-1。接着你再去执行第二条(输入 1),他就会发现它还是走的 else 这个分支,所以无事发生,然后就把 0 打印出来。

直到你有一个输入(输入-1),突然发现它触发了另外一条分支。这个触发的其实是 Python 本身触发的。这样一来,我就捕捉到了新的字节码,这个新的字节码就构建了一个新的子图,然后这个子图让你去走上面这条分支,然后你就会返回 0。

我们就是以这个简单的例子来感受一下 torch.compile。它其实就是更深入地去捕捉字节码,去做 CPython 的代码语言的转换,转换成他自己的一个计算图构建的思路,最终比较完美的解决了这个控制流的问题。

我们还有一些形状、外部代码的一些例子,我们可以下节课再聊。但具体来说,它比较完美地解决了很多问题,所以我们认为这是现在最好的动静态图混合的一个方案。

附录#

关键点和注意事项#

提示词和模型#

请严格按照以下要求处理并输出我的请求:
根据以上转录稿以及PPT的内容,帮我整理一份文字清晰、术语准确、与原文逐字逐句对应的转录稿。原文本中有非常多识别错误,需要你结合上下文进行猜测。首先确定文本主题,然后想想都有哪些关键词,最后想一想原文中错误的字和哪个关键词相似,然后替换。整理过程中,你还需要去除重复的话和语气词。仅用三级标题 (###) 来创建清晰的章节结构,与主题内容相对应。如果转录稿中讲师特别提醒要注意某些内容,将这些内容汇总在输出的最后一个三级标题 ###关键点和注意事项 中。

将转录文本和经过 MinerU 处理的课件放入提示词中。

模型:Gemini 2.5 Pro.