目录

概述

从用户体验的角度来看,TorchDynamo 非常易于使用。用户通过 torchdynamo.optimize 作为注解进行调用:

@torchdynamo.optimize(my_compiler)
def fn_foo(bar):

一个完整的示例看起来像这样:

from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

这使得TorchDynamo能够捕获解释型Python帧,获取 任何和所有相关信息,并在可能的地方加快速度。 加速来自几个方面,且很大程度上取决于 提供的后端(如上面示例中的 my_compiler),但本节中重要的一个加速是 缓存。缓存本身并不是直接的加速,而是一个关键的使能机制,可以防止 重新编译。我们用dynamo挖了一个洞,而缓存使我们能够爬出来。 它使我们能够在保持性能中立的同时,启用后端——这才是我们 加速的真正来源。

即使提供了通过式无操作后端:

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return gm.forward

我们可以看到,即使在普通的 Python 上,而不仅仅是 PyTorch 中,TorchDynamo 也能加速 Python 的执行。

缓存和守卫概述

TorchDynamo 通过缓存由 TorchDynamo 转换过的用户字节码来运行。当 TorchDynamo 接收到一个帧用于评估时,它会检查该帧中引用的对象是否以某种方式发生了变化,如果没有变化,TorchDynamo 将读取之前转换过的用户字节码来对其进行评估。在本节中,我们将重点介绍如何确定帧中引用的对象是否发生了变化。这是 TorchDynamo 中一个关键的功能,因为它驱动了整个失效生命周期。这个功能被称为 guards

从非常高的层次来看,流程可以总结如下:

  1. TorchDynamo 接收到一个 Python 帧。

  2. 它将帧(1)转换,通过指令 翻译。

  3. 对于(2)中捕获的对象,TorchDynamo会创建跟踪对象,这些对象是: * 在输出图上被跟踪,输出图是 torch.fx.Tracer 的一个内部特化版本 * 保护措施

  4. TorchDynamo 处理在 (3) 中创建的 guard 对象,将其转换为一个生成的 Python 函数,check_fn,并与一段代码相关联。

  5. The check_fn 是在我们再次遇到这段代码时进行评估的 - 如果 check_fn 通过并评估为 True,TorchDynamo 会识别缓存中的代码与此处遇到的代码是相同的,并且可以安全使用。如果它失败并评估为 False,TorchDynamo 会识别缓存中的代码为无效,并可以将其丢弃,转而采用新的条目,通过重新编译或图中断。

Python 框架评估和 PEP 523

TorchDynamo 的功能基于 PEP 523

TorchDynamo 通过使用 _PyInterpreterState_SetEvalFrameFunc 在 Python 上安装一个帧评估函数。 TorchDynamo 有一个钩子,可以在评估期间将控制权交还给我们。

我们安装的函数是 convert_frameconvert_frame_assertnopython=True 情况下,但暂时忽略这个细微差别,让我们来看看 convert_frame_assert, 作为 convert_frame 的代理。

我们可以在 convert_frame.py 的第 20 行 找到它, 其签名如下:

def  convert_frame_assert(compiler_fn: Callable, one_graph=True):

此函数将 Python 调用 TorchDynamo 的入口点包装在一个帧中:

def  _convert_frame_assert(frame: types.FrameType, cache_size: int):

这是该函数的作用:

  1. 检查是否已经见过这个code(参见: f_code 这里),如果已经见过则提前退出。

  2. 检查代码是否属于不支持的情况。

  3. 检查 cache_size(上面的第二个参数)是否超过配置中定义的限制 cache_size_limit。如果超过了,该函数 会丢弃该帧并记录警告信息。这有助于避免频繁 重新编译该帧,因为通常这意味着该帧以一种意外的方式被频繁调用,缓存它会产生不必要的开销, 因为它很可能在下次遇到时就会被清除。

  4. 将帧传递给一个函数,该函数通过字节码转换创建一个 InstructionTranslator,通过 transform_code_object。在此过程中发生了一些关键的事情:

    1. 新代码通过 transform_code_object 生成。

    2. 一个名为 output 的 FX 追踪器是通过 InstructionTranslator 生成的。

      这可能会有点令人困惑, 因为 InstructionTranslator 并不是一个 fx 追踪器,但它是存储 在一个名为 tracer 的变量中,并且其输出*一个 `fx`追踪器。

    3. 该函数生成保护机制,并将它们存储在output上方。

    4. 该函数生成 output_instructions 并将它们存储在 output 上面。

    5. 该函数将新生成的转换代码映射到它从帧中读取的初始代码。这种映射值得记住,我们将在后面讨论 guard 失败时多次引用它。

  5. 使用4.1节中的转换代码和4.3节中的保护语句, 该函数生成一个 GuardedCode

现在我们已经了解了帧评估,让我们回顾 InstructionTranslator,并看看它是如何将我们传递给它的帧转换为TorchDynamo内部类型的。

InstructionTranslator

InstructionTranslator 功能非常强大!我们不会详细说明它所做的一切,但对于本文来说,最重要的是它生成了一个 symbolic_locals 映射,该映射维护了从帧的 f_locals 到 TorchDynamo 内部 Variable 对象的映射(稍后将更详细地介绍这些内容。symbolic_locals 通过遍历帧的局部变量来填充:

self.symbolic_locals = collections.OrderedDict(
    (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
    for k in vars
    if k in f_locals
)

此处的重要组成部分是调用一个进入 VariableBuilder 的调用 VariableBuilder 的调用实现 代理到一个名为 _wrap 的函数,该函数又同时 构造 VariableTracker 的实例并在它们上调用 make_guards。更多 内容将在后面介绍。

这种映射至关重要,因为每个Variable都有相关的保护机制,这些保护机制随后会被传递给self.output,即OutputGraph的实例,这是一个fx追踪器,在上方章节的4.2节中提到过。如果你还记得,这个OutputGraph,存储在一个名为output的变量中,是我们存储保护机制的地方,之后这些保护机制会被传递并成为GuardedCode

如何让 InstructionTranslator 实现这一点?其核心是一个被驱动的循环,它执行一个函数 step

step 就是这样 - 一个单独的处理步骤,仅执行一条 指令,并对它进行 某种操作

注意

这些是被TorchDynamo的 transform_code_object 处理的真实指令,这非常酷。

注意

本节有意跳过 dis.get_instructions 的详细内容。

对于上面的例子,这里是一些 Instruction 可能的样子的代码片段:

Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)

这是该函数的核心功能。看看这个 opname, 然后再看看来自内部的这个小代码片段 step

if not hasattr(self, inst.opname):
    unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)

正如我们所见,该函数检查当前类是否具有与运算符名称匹配的属性设置 (例如,LOAD_CONST)。如果存在,则调用该属性,并传递整个指令对象。如果不存在,则将帧视为未实现。

对于 LOAD_CONST 示例,我们可以看到我们确实支持它, 并且定义相对简单直接:

def  LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))

我们可以看到,这个函数创建了类 ConstantVariable 的一个新实例,其值在我们的示例中为 -1,然后将其压入栈中。

有数十种这样的方法 - 请参见 symbolic_convert.py 以查看所有方法。通常,我们会尽可能多地实现与 Python 字节码指令匹配的方法。

step之后的逻辑以及调用VariableBuilder的逻辑中,我们现在有很多VariableTracker,当然,我们之前也讨论过创建保护机制的问题。让我们深入了解什么是Variables,并更进一步地理解保护机制。

变量

A ConstantVariableVariableTracker 的一个实例。 VariableTracker 表示一个被追踪的 Python 局部变量或栈值。

当涉及到在TorchDynamo中表示一个对象时, VariableTracker 正如其名所述,它会跟踪给定的变量。 这是一个极其灵活的类,但有几个需要注意的要点:

  • 它通过以下方式管理底层对象周围的 guard 关系:

    • make_guard

    • replace_guards

    • add_guard(s)

    • propagate - propagate(*vars: List[List["VariableTracker"]]) - 也许最重要的是,因为它结合了所有传入的 VariableTracker 实例所提供的保护机制。它会访问这些保护机制,并将它们合并到自身上。

  • 它作为底层对象的代理,为 TorchDynamo 的其余部分实现方法,以获取有关被跟踪对象的信息:

    • call_method

    • call_function

    • python_type

    • as_proxy

    • is/as_python_proxy

  • 它存储了类型为 source 的变量 Source,来自 torchdynamo/source.py。这种源类型是一个相对自包含的类,有助于我们组织和记录原始源的来源,并提供诸如获取名称等便捷方法,对我们来说尤为重要的是生成保护机制。

And this class (VariableTracker) is built around subclassing, somewhere between a full Abstract Base Class and fully fleshed out class - it leaves many methods raising NotImplementedError - with reliance on subclasses. See torchdynamo/variables/ for all subclasses to fulfill contracts and custom behaviors。

了解我们现在所知道的内容,我们可以看到一个例子,说明来自disBUILD_TUPLE的指令如何执行:

BUILD_TUPLE(count) Creates a tuple consuming count items from the stack, and pushes the resulting tuple onto the stack.

在我们的情况下,我们的签名会因为创建 Instruction 对象的方式而略有不同,但其核心思想是一样的。 我们不再传入 count,而是传入一个带有额外簿记信息的对象,当然,我们还需要处理将普通的 Python 对象转换为 TorchDynamo 概念的过程:

def BUILD_TUPLE(self, inst):
    items = self.popn(inst.argval)
    options = VariableTracker.propagate(items)
    self.push(TupleVariable(items, **options))

以下是这段代码的作用:

  1. 该函数读取 argval,在这种情况下,相当于 pydoc 中对应指令的 counts

  2. 该函数 popn 了这些项目,在这种情况下,签名是 def  popn(self, n: int) -> List[TensorVariable]: 这暗示了一个 底层的契约 - 我们返回的是 TensorVariables。如果我们 更仔细地查看 sybmolic_convert.pyInstructionTranslatorBase/InstructionTranslator,我们会发现 唯一被压入和弹出我们栈的是 VariableTracker

  1. 该函数调用 VariableTracker.propagate。这 会从堆栈中弹出的每个项目中获取守卫条件,然后递归遍历并合并所有守卫条件到 options: py  return {      "guards": guards,  }

  2. 该函数随后会根据VariableTrackerTupleVariableitemsoptions创建一个新的实例。这使我们能够从items安装所有适当的防护措施,这些防护措施组成了新的TupleVariable

注意

第一个守卫是从哪里来的?传播是一种很好的技术,但我们需要在传播之前先创建一些东西。VariableBuilder调用make_guards,因为它从f_locals创建VariableTracker实例。这反过来又调用了source,让它创建守卫。

在完成所有这些步骤之后,字节码翻译已经完成,我们离生成 GuardedCode 又近了一步。现在我们理解了局部变量是如何变成 VariableTrackers 的,指令是如何被处理的,以及在哪里调用 guard 来进行创建。在我们深入查看代码和 guard 如何组合成一个 GuardedCode 对象之前,我们需要稍微了解一下上面那些 make_guardsource.make_guard 的调用。这样我们就能理解,当我们在 VariableTracker 实例上创建 guard 时发生了什么。

制作守卫

Guard 是普通的 Python 对象,属于类 Guard。让我们更详细地了解它们。

查看 dataclass 的定义(因此也是构造函数的签名),我们可以看到它包含一个名称、一个来源和一个创建函数。

@dataclasses.dataclass
class Guard:
    name: str
    source: GuardSource
    create_fn: Callable

名称应该是变量的名称。

此处的源是一个枚举,表示守卫属于哪种类型的源。

注意

不要与 Source 和存储在 VariableTracker 中的其他类型混淆。

create_fn 提供了从简单的数据类过渡到实际生成有效的Python代码的主要功能,以便在调用之间确定事物是否发生了变化,以及我们是否可以安全地从代码缓存中读取。

获取护盾实例最常见的代码路径是通过 make_guardsVariableTracker 上。 make_guards->``source.make_guard``->``return Guard(self.name(), self.guard_source(), fn)``

或者,以一个具体的例子来说:

...
elif istype(value, range):
    guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
    return RangeVariable(value=value, guards=guards)

由于 source 在构建时被设置为这个 VariableTracker,这里只需要提供 fnGuardBuilder.EQUALS_MATCHcreate_fn 字段。

这个 create_fn 必须是 GuardBuilder 上的方法。这样做的原因在我们的下一步中变得显而易见。一旦我们为一个帧创建了所有守卫,我们就继续进行 CheckFunctionManagercompile_check_fn

convert_frame 函数能够生成 GuardedCode 之前, 它需要运行 CheckFunctionManager,并且带上所有的保护措施,以 生成一个 check_fn,然后将其与代码一起传递给 GuardedCode。这是我们在缓存条目中存储的同一个 check_fn, 也是我们用来判断是否检索与之存储在一起的代码的同一个。 作为参考,以下是该代码:

static CacheEntry *create_cache_entry(CacheEntry *next,
                                      PyObject *guarded_code) {
  CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
  DEBUG_NULL_CHECK(e);
  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}

我们现在知道如何使用一个check_fn函数,以及谁创建了它,它由什么组成,但我们还不知道它是如何工作的。一个Guard对象的列表是如何变成我们以后可以运行的函数的?

首先,我们迭代这些保护:

for guard in sorted(guards or [], key=Guard.sort_key):
    if not config.guard_nn_modules and guard.is_nn_module():
        continue
    guard.create(local_builder, global_builder)

调用 guard.create 次我们在 create_fn 上设置的运行,Guard 类(不要与我们正在处理的 check_fn 混淆,名称相似,因此可能会有点混淆)。在 上面的例子中,我们的 create_fnGuardBuilder.EQUALS_MATCH。 所以我们现在调用它,传入 self,即保护器本身, 在。

签名是: def EQUALS_MATCH(self, guard: Guard):

在该函数内部,我们可以使用name来获取原始对象,查询其数据和类型信息, 从而得到最重要的部分:追加代码。

在最简单的情况下,EQUALS_MATCH 只需添加一行代码: self.code.append(f"{ref} == {val!r}"). 其中 ref 是变量的名称,而 val 是值。它可能会生成如下代码:

y == 2

这是一个基本示例。但是如果我们添加一些其他类型的GuardBuilder 函数,然后将它们与 and 结合在每个语句之间(如我们所做的),我们可能会得到这样的结果:

___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)

以下是这段代码执行的内容:

  1. 一个检查 .valid

  2. 类型ID检查

  3. 值检查

  4. 张量检查

这将成为我们check_fn的核心代码,它将在下一次遇到此代码时进行评估。然后它将检查:

  1. 这段代码仍然有效吗?

  2. 如果 (1),y 是否仍然具有 94367738391392 类型?

  3. 如果 (2),y 还是 2 吗?

  4. 如果 (3),让我们检查一下张量 x 是否以某种特定方式发生了变化。

如果所有这些仍然为真,那么我们可以使用与这个check_fn一起缓存的代码。

注意

对于更深入的了解这如何以及在哪里发生 你可以阅读 static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) { 中的 _eval_frame.c

如果没有,那么我们可以继续重新编译代码,并将其存储在缓存中,与这段代码一起,以及一个全新的check_fn,再次在下一个帧上进行检查。

有许多其他这样的函数在 GuardBuilder 上,它们会被合并成有时非常大的字符串,然后作为 Python 代码进行评估并存储到 check_fn 中。上面的例子说明了一个简单的情况。要更好地理解此功能,请阅读 GuardBuilder 上的其他函数,或者更好的是,在 compile_check_fn 中转储 code 变量以查看生成的内容,特别是在较大的真实模型上。

摘要

在这一部分,我们回顾了:

  • .valid 和弱引用周围的失效作用(以及可能很快 NN 模块的失效)。

  • C++端的守护函数(___check_type_id___check_tensors等)如何工作

  • 当守卫失败时会发生什么。

  • 如果我们生成了无效的保护代码会发生什么。

我们介绍了用户提供的代码在TorchDynamo上下文中如何被追踪和跟踪,组织成VariableTrackers Sources 和随后的 Guards,并且这些 Guards 如何反过来指导缓存条目的选择和失效处理Python代码。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源