torch.export¶
警告
此功能是一个正在积极开发的原型,未来将会有重大变更。
概述¶
torch.export.export() 接受任意的Python可调用对象(一个
torch.nn.Module、函数或方法),并生成一个仅表示函数中Tensor计算的跟踪图,以提前编译(AOT)的方式进行,随后可以使用不同的输出执行或序列化。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)
# code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
return (add,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {}
torch.export 生成具有以下不变量的干净中间表示(IR)。更多关于 IR 的规范可以在
这里找到。
正确性: 它保证是原始程序的正确表示,并且保持了原始程序相同的调用约定。
已规范化: 图中没有Python语义。来自原始程序的子模块被内联以形成一个完全扁平化的计算图。
图属性: 该图是纯粹的功能性,意味着它不包含具有副作用的操作,如突变或别名。它不会改变任何中间值、参数或缓冲区。
元数据: 图形包含在跟踪过程中捕获的元数据,例如来自用户代码的堆栈跟踪。
在幕后,torch.export 利用了以下最新技术:
TorchDynamo (torch._dynamo) 是一个内部API,它使用CPython的一个特性 称为帧评估API来安全地追踪PyTorch图。这 提供了一个极大改进的图捕获体验,需要重写的次数大大减少 以便完全追踪PyTorch代码。
AOT Autograd 提供了一个功能化的PyTorch图,并确保该图被分解/降低到ATen操作符集。
PyTorch FX (torch.fx) 是图的底层表示形式,允许灵活的基于Python的转换。
现有框架¶
torch.compile() 同样使用了与 torch.export 相同的PT2堆栈,但
略有不同:
即时编译与提前编译:
torch.compile()是一个即时编译器,而 则不打算用于在部署之外生成编译工件。部分图捕获与全图捕获: 当
torch.compile()遇到模型中无法追踪的部分时,它将“图中断”并回退到急切的Python运行时执行程序。相比之下,torch.export旨在获取PyTorch模型的完整图表示,因此当遇到无法追踪的内容时会报错。由于torch.export生成的图与任何Python特性或运行时无关,因此该图可以保存、加载并在不同的环境和语言中运行。可用性权衡: 由于
torch.compile()在遇到无法追踪的内容时能够回退到Python运行时,因此它更加灵活。而torch.export则需要用户提供更多信息或重写代码以使其可追踪。
与torch.fx.symbolic_trace()相比,torch.export使用TorchDynamo进行追踪,它在Python字节码级别运行,因此具有追踪任意Python构造的能力,而不受Python运算符重载支持的限制。此外,torch.export对张量元数据进行细粒度跟踪,因此基于张量形状等条件的追踪不会失败。一般来说,torch.export预计可以在更多用户程序上工作,并生成较低级别的图(在torch.ops.aten运算符级别)。请注意,用户仍然可以将torch.fx.symbolic_trace()作为torch.export之前的预处理步骤。
与 torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图更简单,只有直线控制流(除了显式的控制流操作符)。
与 torch.jit.trace() 相比,torch.export 是可靠的:它能够追踪执行整数计算的代码,并记录所有必要的附加条件,以证明特定的追踪对其他输入也是有效的。
导出PyTorch模型¶
一个示例¶
主要入口点是通过 torch.export.export(),它接受一个可调用对象(torch.nn.Module、函数或方法)和示例输入,并将计算图捕获到一个 torch.export.ExportedProgram 中。一个例子:
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# code: a = self.conv(x)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])
# code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)
# code: return self.maxpool(self.relu(a))
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
return (max_pool2d,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='constant'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='max_pool2d'),
target=None
)
]
)
Range constraints: {}
检查ExportedProgram,我们可以注意到以下内容:
该
torch.fx.Graph包含原始程序的计算图,以及原始代码记录,便于调试。图表仅包含在此处找到的
torch.ops.aten个操作符 和自定义操作符,并且完全可用,没有任何原地操作符, 例如torch.add_。参数(权重和偏差到卷积)被提升为图的输入,
导致图中没有get_attr个节点,这些节点以前存在于torch.fx.symbolic_trace()的结果中。该
torch.export.ExportGraphSignature模型描述了输入和输出的签名,并指定了哪些输入是参数。图中每个节点生成的张量的形状和数据类型都被标注出来。例如,
convolution节点将生成一个数据类型为torch.float32且形状为 (1, 16, 256, 256) 的张量。
非严格导出¶
在PyTorch 2.3中,我们引入了一种新的跟踪模式,称为非严格模式。 它仍在进行强化,因此如果您遇到任何问题,请将它们提交到Github,并附上“oncall: export”标签。
在非严格模式下,我们使用Python解释器跟踪程序。 你的代码将完全按照急切模式执行;唯一的区别是 所有Tensor对象将被ProxyTensors替换,后者会将所有操作记录到图中。
在严格模式下,这是当前的默认模式,我们首先使用TorchDynamo(一个字节码分析引擎)来追踪程序。TorchDynamo实际上并不执行你的Python代码。相反,它会对其进行符号分析,并根据结果构建一个图。这种分析使torch.export能够提供更强的安全性保证,但并非所有的Python代码都受支持。
一个可能需要使用非严格模式的例子是,如果你遇到了一个不支持的TorchDynamo特性,而这个问题可能不容易解决,并且你知道Python代码并不完全用于计算。例如:
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully
export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在这个示例中,第一个调用使用非严格模式(通过strict=False标志)成功追踪,而第二个调用使用严格模式(默认)导致失败,因为TorchDynamo无法支持上下文管理器。一种选择是重写代码(参见torch.export的限制),但由于上下文管理器不影响模型中的张量计算,我们可以采用非严格模式的结果。
导出用于训练和推理¶
在 PyTorch 2.5 中,我们引入了一个名为 export_for_training() 的新 API。
它仍在完善中,所以如果您遇到任何问题,请使用“oncall: export”标签提交到 GitHub。
在此API中,我们生成包含所有ATen运算符(包括函数式和非函数式)的最通用中间表示形式(IR),可用于在急切模式下使用PyTorch Autograd进行训练。此API旨在用于急切训练用例,例如PT2量化,并将成为torch.export.export的默认中间表示形式。有关此更改背后的动机,请参阅 https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206
当此API与run_decompositions()结合使用时,你应该能够获得具有任何所需分解行为的推理IR。
以下是一些示例:
class ConvBatchnorm(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)
ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
return (batch_norm,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='batch_norm'),
target=None
)
]
)
Range constraints: {}
从上述输出中,你可以看到export_for_training()生成的ExportedProgram与export()几乎相同,除了图中的操作符。你可以看到我们以最通用的形式捕获了batch_norm。此操作符在功能上是无效的,并将在推理运行时降级为不同的操作符。
您也可以通过run_decompositions()从该中间表示转换为推理中间表示,并进行任意自定义。
# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
return (getitem_3, getitem_4, add, getitem)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_3'),
target='bn.running_mean'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_4'),
target='bn.running_var'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='bn.num_batches_tracked'
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='getitem'),
target=None
)
]
)
Range constraints: {}
在这里你可以看到我们在IR中保留了conv2d个操作,而分解了其余部分。现在IR是一个功能性的IR,包含核心aten操作,除了conv2d。
您可以通过直接注册所选的分解行为来进行更多自定义。
您可以通过直接注册自定义分解行为来进行更多自定义设置。
# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
return (getitem_3, getitem_4, add, getitem)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_3'),
target='bn.running_mean'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_4'),
target='bn.running_var'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='bn.num_batches_tracked'
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='getitem'),
target=None
)
]
)
Range constraints: {}
表达动态性¶
默认情况下,torch.export 会假设所有输入形状都是静态的,并将导出的程序专门化为这些维度。然而,某些维度(例如批处理维度)可以是动态的,并且在每次运行时都可能变化。这些维度必须通过使用
torch.export.Dim() API 来创建,并通过 dynamic_shapes 参数传递给
torch.export.export()。以下是一个示例:
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
# code: out1 = self.branch1(x1)
linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)
# code: out2 = self.branch2(x2)
linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)
# code: return (out1 + self.buffer, out2)
add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
return (add, relu_1)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch1_0_weight'),
target='branch1.0.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch1_0_bias'),
target='branch1.0.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch2_0_weight'),
target='branch2.0.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch2_0_bias'),
target='branch2.0.bias',
persistent=None
),
InputSpec(
kind=<InputKind.CONSTANT_TENSOR: 4>,
arg=TensorArgument(name='c_buffer'),
target='buffer',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x1'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x2'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='relu_1'),
target=None
)
]
)
Range constraints: {s0: VR[0, int_oo]}
需要注意的一些其他事项:
通过
torch.export.Dim()API和dynamic_shapes参数,我们指定了每个输入的第一个维度为动态。查看输入x1和x2,它们具有符号形状(s0, 64)和(s0, 128),而不是我们作为示例输入传递的(32, 64)和(32, 128)形状的张量。s0是一个符号,表示这个维度可以是一系列值。exported_program.range_constraints描述了图表中每个符号的范围。 在这种情况下,我们看到s0的范围是 [0, int_oo]。由于技术原因,这里难以解释,它们被假定为不是 0 或 1。 这并不是一个错误,并不一定意味着导出的程序在维度为 0 或 1 时无法工作。 详见 The 0/1 特殊化问题 以深入了解此主题。
我们还可以指定输入形状之间更复杂的关联关系,例如 一对形状可能相差一个单位,一个形状可能是另一个的两倍, 或者一个形状是偶数。一个例子:
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
return (add,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}
需要注意的一些事项:
通过为第一个输入指定
{0: dimx},我们看到第一个输入的形状现在是动态的,变为[s0]。现在通过为第二个输入指定{0: dimy},我们看到第二个输入的形状也是动态的。然而,因为我们表示了dimy = dimx + 1,而不是y的形状包含一个新符号,我们看到它现在用在x,s0中使用的相同符号表示。我们可以看到dimy = dimx + 1的关系通过s0 + 1显示出来。查看范围约束,我们看到
s0的范围是 [3, 6], 这是最初指定的,并且我们可以看到s0 + 1的已解决 范围为 [4, 7]。
序列化¶
要保存ExportedProgram,用户可以使用torch.export.save()和
torch.export.load() API。一种约定是使用ExportedProgram
并以.pt2文件扩展名保存。
一个示例:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
专业领域¶
理解 torch.export 行为的一个关键概念是 静态 和 动态 值之间的区别。
一个 动态 值是指在每次运行时都可能发生变化的值。这些行为类似于Python函数的普通参数——你可以为一个参数传递不同的值,并期望你的函数能够正确处理。张量 数据 被视为动态。
一个 静态 值是在导出时固定的值,在导出程序的执行之间不能改变。当在跟踪过程中遇到该值时,导出器会将其视为常量并将其硬编码到图中。
当执行一个操作(例如 x + y)并且所有输入都是静态的,那么
该操作的输出将直接硬编码到图中,并且该操作不会显示(即它将被常量折叠)。
当一个值被硬编码到图中时,我们说该图已被 专门化为该值。
以下值是静态的:
输入张量形状¶
默认情况下,torch.export 将跟踪程序以专注于输入张量的形状,除非通过 dynamic_shapes 参数指定维度为动态的 torch.export。这意味着如果存在依赖于形状的控制流,torch.export 将专注于给定样本输入所采取的分支。例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 2]"):
# code: return x + 1
add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
return (add,)
条件(x.shape[0] > 5)不在
ExportedProgram 中出现,因为示例输入具有静态形状 (10, 2)。由于 torch.export 专门针对输入的静态形状进行了优化,因此 else 分支(x - 1)永远不会被触发。为了在跟踪图中保留基于张量形状的动态分支行为,需要使用 torch.export.Dim() 来指定输入张量(x.shape[0])的维度为动态,并且需要重写源代码。
请注意,作为模块状态一部分的张量(例如参数和缓冲区)总是具有静态形状。
Python 基本类型¶
torch.export 还专门针对Python原语进行优化,
例如 int, float, bool, 和 str. 但是它们也有动态
变体,如 SymInt, SymFloat, 和 SymBool。
例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[2, 2]", const, times):
# code: x = x + const
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
return (add_2,)
由于整数是专门化的,torch.ops.aten.add.Tensor 操作都是使用硬编码的常量 1 进行计算的,而不是 const。如果用户在运行时传递一个与导出时使用的值不同的值,例如 2 而不是 1,这将导致错误。
此外,在 for 循环中使用的 times 迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor 调用“内联”到图中,并且输入 times 从未被使用。
Python 容器¶
Python 容器(List, Dict, NamedTuple 等)被认为具有静态结构。
torch.export的限制¶
图中断点¶
由于torch.export是一个一次性过程,用于从PyTorch程序中捕获计算图,因此它最终可能会遇到程序中无法追踪的部分,因为几乎不可能支持追踪所有PyTorch和Python功能。在torch.compile的情况下,不支持的操作将导致“图形中断”,并且不支持的操作将使用默认的Python评估运行。相比之下,torch.export将要求用户提供额外的信息或重写代码的部分以使其可追踪。由于追踪是基于TorchDynamo的,后者在Python字节码级别进行评估,因此与以前的追踪框架相比,所需的重写将显著减少。
当遇到图中断时,ExportDB 是一个很好的资源,可以了解支持和不支持的程序类型,以及如何重写程序以使其可跟踪。
通过使用非严格导出选项,可以解决处理此图形中断的问题
数据/形状依赖的控制流¶
图中断点也可能在数据依赖的控制流(if
x.shape[0] > 2)中遇到,当形状没有被专门化时,因为跟踪编译器不可能处理这种情况,而不会生成组合爆炸数量的路径代码。在这种情况下,用户需要使用特殊的控制流操作符重写他们的代码。目前,我们支持 torch.cond
来表达类似if-else的控制流(更多即将推出!)。
缺少操作符的伪/元/抽象内核¶
在进行追踪时,所有操作符都需要一个假张量内核(也称为元内核或抽象实现)。这用于推理该操作符的输入/输出形状。
请参见torch.library.register_fake()以获取更多详情。
在不幸的情况下,如果你的模型使用了一个尚未实现 FakeTensor 内核的 ATen 操作,请提交一个问题报告。
API 参考¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source][source]¶
export()可以接受任何 nn.Module 和示例输入,并生成一个表示该函数张量计算的追踪图, 以预编译方式(AOT)呈现,随后可以使用不同的输入执行或序列化。该追踪图 (1) 生成功能性的 ATen 操作集中的标准化操作 (以及任何用户指定的自定义操作),(2) 消除了所有的 Python 控制流和数据结构(某些例外情况除外),并且 (3) 记录了所需的形状约束集 以证明这种标准化和控制流消除对于未来的输入是有效的。健全性保证
在跟踪过程中,
export()会记录用户程序和底层PyTorch运算符内核所做的与形状相关的假设。 只有当这些假设成立时,输出ExportedProgram才被认为是有效的。跟踪对输入张量的形状(而不是值)做出假设。 这些假设必须在图形捕获时进行验证,以使
export()成功。具体来说:对输入张量的静态形状假设会自动进行验证,无需额外的努力。
关于输入张量动态形状的假设需要通过使用
Dim()API来显式指定构建动态维度,并通过dynamic_shapes参数将它们与示例输入关联。
如果任何假设无法验证,将引发致命错误。当这种情况发生时, 错误消息将包括对规范的建议修复,这些修复是验证假设所需的。 例如
export()可能建议对动态维度定义进行以下修复dim0_x, 比如说出现在与输入x相关的形状中,该输入之前被定义为Dim("dim0_x"):dim = Dim("dim0_x", max=5)
这个示例意味着生成的代码要求输入的第0维
x必须小于或等于5才能有效。您可以检查对动态维度定义的建议修复,然后将它们逐字复制到您的代码中,而无需更改dynamic_shapes参数到您的export()调用。- Parameters
模块 (模块) – 我们将跟踪此模块的前向方法。
动态形状 (可选[联合[字典[字符串, 任意], 元组[任意], 列表[任意]]]) –
一个可选参数,其类型应为: 1) 一个从
f的参数名到它们动态形状规范的字典, 2) 一个元组,按原始顺序为每个输入指定动态形状规范。 如果你在关键字参数上指定动态性,你需要按照原始函数签名中定义的顺序传递它们。张量参数的动态形状可以指定为以下两种方式之一: (1) 一个从动态维度索引到
Dim()种类型的字典,其中不需要在该字典中包含静态维度索引,但如果包含,则应映射到None;或 (2) 一个Dim()种类型或None的元组/列表,其中Dim()种类型对应于动态维度,而静态维度则用None表示。字典或张量的元组/列表形式的参数通过使用包含的规范的映射或序列递归地指定。严格 (布尔值) – 当启用(默认)时,导出函数将通过TorchDynamo跟踪程序,这将确保生成的图的健全性。否则,导出的程序将不会验证图中隐含的假设,可能会导致原始模型和导出模型之间的行为差异。当用户需要绕过跟踪器中的错误,或者只是希望逐步在其模型中启用安全性时,这很有用。请注意,这不会影响生成的IR规范不同,模型将以相同的方式序列化,而不论此处传递的值是什么。 警告:此选项是实验性的,使用它需自行承担风险。
- Returns
一个
ExportedProgram包含被跟踪的可调用对象。- Return type
可接受的输入/输出类型
可接受的输入类型(对于
args和kwargs)和输出包括:基本类型,即
torch.Tensor、int、float、bool和str。数据类,但它们必须通过调用
register_dataclass()首先进行注册。(嵌套) 包含
dict,list,tuple,namedtuple和OrderedDict的数据结构,包含以上所有类型。
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source][source]¶
警告
正在积极开发中,保存的文件可能无法在较新版本的PyTorch中使用。
将一个
ExportedProgram保存到类似文件的对象中。然后可以使用Python APItorch.export.load加载它。- Parameters
ep (导出的程序) – 要保存的导出程序。
f (Union[str, os.PathLike, io.BytesIO) – 一个类似文件的对象(必须实现写入和刷新功能)或包含文件名的字符串。
extra_files (可选[Dict[str, Any]]) – 从文件名到内容的映射,这些内容将作为f的一部分进行存储。
Example:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source][source]¶
警告
正在积极开发中,保存的文件可能无法在较新版本的PyTorch中使用。
加载一个
ExportedProgram之前使用torch.export.save保存的。- Parameters
ep (导出的程序) – 要保存的导出程序。
f (Union[str, os.PathLike, io.BytesIO) – 一个类似文件的对象(必须实现写入和刷新功能)或包含文件名的字符串。
extra_files (可选[Dict[str, Any]]) – 此映射中给出的额外文件名将被加载,其内容将存储在提供的映射中。
- Returns
一个
ExportedProgram对象- Return type
Example:
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source][source]¶
将一个数据类注册为
torch.export.export()的有效输入/输出类型。- Parameters
Example:
import torch from dataclasses import dataclass @dataclass class InputDataClass: feature: torch.Tensor bias: int @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source][source]¶
Dim()构建了一种类似于命名符号整数的类型,具有一定的范围。 它可以用来描述动态张量维度的多种可能值。 请注意,同一张量或不同张量的不同动态维度, 可以用相同的类型来描述。
- torch.export.exported_program.default_decompositions()[source][source]¶
这是默认的分解表,其中包含所有ATEN操作符到核心aten操作集的分解。请与此API一起使用
run_decompositions()- Return type
- class torch.export.dynamic_shapes.ShapesCollection[source][source]¶
动态形状构建器。 用于为输入中出现的张量分配动态形状规格。
- Example::
args = ({“x”: 张量_x, “others”: [张量_y, 张量_z]})
dim = torch.export.Dim(…) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # 这相当于以下内容(现在自动生成): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}
torch.export(…, args, dynamic_shapes=dynamic_shapes)
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source][source]¶
针对处理导出的动态形状建议修复以及/或自动动态形状,根据约束违规错误信息和原始动态形状优化给定的动态形状规格。
在大多数情况下,行为是直截了当的——即对于专门化或细化 Dim 的范围的建议修复,或建议派生关系的修复,新的动态形状规范将相应更新。
请提供需要翻译的 Pytorch 深度学习框架网站的具体文本内容。
dim = Dim(‘dim’, min=3, max=6) -> this just refines the dim’s range dim = 4 -> this specializes to a constant dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation
然而,与派生维度相关的建议修复可能更为复杂。 例如,如果为根维度提供了建议修复,则新的派生维度值是基于根维度来评估的。
dx = Dim('dx') dy = dx + 2 dynamic_shapes = {"x": (dx,), "y": (dy,)}
建议的修改:
dx = 4 # specialization will lead to dy also specializing = 6 dx = Dim(‘dx’, max=6) # dy now has max = 8
建议的派生维度修复也可以用于表示可除性约束。 这涉及到创建不依赖于特定输入形状的新根维度。 在这种情况下,根维度不会直接出现在新规范中,而是作为某个维度的根。
请提供需要翻译的 Pytorch 深度学习框架网站的具体文本内容。
_dx = Dim(‘_dx’, max=1024) # this won’t appear in the return result, but dx will dx = 4*_dx # dx is now divisible by 4, with a max value of 4096
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source][source]¶
程序包来自
export()。它包含一个torch.fx.Graph,表示张量计算,一个state_dict包含所有提升参数和缓冲区的张量值,以及各种元数据。你可以像调用原始可追踪的
export()一样,以相同的调用约定调用ExportedProgram。要对图进行转换,请使用
.module属性来访问 一个torch.fx.GraphModule。然后你可以使用 FX转换 来重写图。之后,你可以简单地再次使用export()来构建一个正确的ExportedProgram。- run_decompositions(decomp_table=None)[source][source]¶
对导出的程序运行一组分解,并返回一个新的导出程序。默认情况下,我们将运行Core ATen分解以获取 Core ATen 操作集中的操作符。
目前,我们不分解联合图。
- Parameters
decomp_table (可选[字典[OperatorBase, 可调用对象]]) – 一个指定Aten操作分解行为的可选参数 (1) 如果为None,我们将分解为核心aten分解 (2) 如果为空,我们不分解任何操作符
- Return type
一些示例:
如果你不想分解任何内容
ep = torch.export.export(model, ...) ep = ep.run_decompositions(decomp_table={})
如果你想要获取除了某些操作之外的核心aten操作集,可以这样做:
ep = torch.export.export(model, ...) decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table)
- class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[source][source]¶
- class torch.export.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature模型的输入/输出签名是导出图的,这是一个具有更强不变性保证的fx.Graph。导出图是功能性的,并不通过
getattr节点访问像参数或缓冲区这样的“状态”。相反,export()确保参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不包含在图中,而是将更新后的缓冲区值建模为导出图的额外输出。所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出了以下模块:
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图将为:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的ExportGraphSignature将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[List[str]] = None)[source][source]¶
- class torch.export.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source][source]¶
- class torch.export.decomp_utils.CustomDecompTable[source][source]¶
这是一个自定义词典,专门用于处理导出中的decomp_table。 我们需要这个的原因是因为在新世界中,你只能删除decomp表中的一个操作以保留它。这对于自定义操作来说是个问题,因为我们不知道自定义操作何时会被加载到调度器中。因此,我们需要记录自定义操作,直到真正需要实现它(即我们在运行分解传递时)。
- Invariants we hold are:
所有 Aten 解构都在初始化时加载。
我们会在用户每次读取表时实现所有操作,以便调度程序更有可能拾取自定义操作。
如果是写入操作,我们并不一定需要将其实体化。
我们在导出前最后一次加载,就在调用 run_decompositions() 之前。
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[source][source]¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source][source]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature模型的输入/输出签名是导出图的,这是一个具有更强不变性保证的fx.Graph。导出图是功能性的,并不通过
getattr节点访问像参数或缓冲区这样的“状态”。相反,export()保证参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不包含在图中,而是将更新后的缓冲区值建模为导出图的额外输出。所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出了以下模块:
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图将为:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的ExportGraphSignature将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source][source]¶
- class torch.export.unflatten.InterpreterModule(graph)[source][source]¶
一个模块,它使用 torch.fx.Interpreter 来执行,而不是通常 GraphModule 使用的代码生成。这提供了更好的堆栈跟踪信息,并使调试执行变得更加容易。
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source][source]¶
一个模块,包含一系列与该模块调用序列相对应的 InterpreterModules。每次对该模块的调用都会分派到下一个 InterpreterModule,在最后一个之后循环返回。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source][source]¶
展开一个ExportedProgram,生成一个与原始急切模块具有相同模块层次结构的模块。这在你尝试使用
torch.export与其他期望模块层次结构而不是torch.export通常生成的扁平图的系统一起使用时可能很有用。注意
未展平模块的args/kwargs不一定与急切模式下的模块匹配,因此进行模块交换(例如
self.submod = new_mod)可能无法正常工作。如果你需要替换一个模块,你需要设置preserve_module_call_signature参数为torch.export.export()。- Parameters
模块 (导出的程序) – 需要展开的ExportedProgram。
flat_args_adapter (可选[FlatArgsAdapter]) – 如果输入的TreeSpec与导出模块不匹配,则适应扁平参数。
- Returns
UnflattenedModule的一个实例,它具有与原始急切模式模块导出前相同的模块层次结构。- Return type
UnflattenedModule