torch.onnx¶
开放神经网络交换 (ONNX) 是一种用于表示机器学习模型的开放式标准格式。torch.onnx 模块可以将 PyTorch 模型导出为 ONNX 格式。然后,该模型可以被任何支持 ONNX 的 运行时环境 使用。
示例:从 PyTorch 到 ONNX 的 AlexNet¶
这是一个简单的脚本,用于将预训练的AlexNet导出为名为 alexnet.onnx 的ONNX文件。
对 torch.onnx.export 的调用会运行模型一次以追踪其执行过程,然后将追踪后的模型导出到指定文件中:
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
生成的 alexnet.onnx 文件包含一个二进制 协议缓冲区
其中包含了您导出的模型(在此情况下为 AlexNet)的网络结构和参数。
参数 verbose=True 会导致导出器打印出模型的人类可读表示:
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
你也可以使用 ONNX 库来验证输出结果, 你可以通过 conda 安装它:
conda install -c conda-forge onnx
然后,你可以运行:
import onnx
# Load the ONNX model
model = onnx.load("alexnet.onnx")
# Check that the model is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
您还可以使用支持ONNX的众多运行时之一来运行导出的模型。 例如,在安装ONNX Runtime之后,您可以 加载并运行该模型:
import onnxruntime as ort
ort_session = ort.InferenceSession("alexnet.onnx")
outputs = ort_session.run(
None,
{"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])
这里有一个更详细的关于导出模型并使用ONNX Runtime运行的教程。
追踪与脚本编写¶
在内部,torch.onnx.export() 需要一个 torch.jit.ScriptModule 而不是
一个 torch.nn.Module。如果传入的模型还不是 ScriptModule,
export() 将使用 追踪 来将其转换为一个:
追踪: 如果
torch.onnx.export()被调用时传入的 Module 不是一个ScriptModule,它会首先执行类似于torch.jit.trace()的操作,这会使用给定的args运行模型一次,并记录该运行过程中发生的所有操作。这意味着如果你的模型是动态的,例如根据输入数据改变行为,导出的 模型将 不会 捕获这种动态行为。同样,追踪可能仅适用于特定的输入尺寸。我们建议检查导出的模型并确保操作符看起来合理。追踪会展开循环和 if 语句,导出一个静态图,与追踪运行完全相同。如果你想以动态控制流导出模型,你需要使用 脚本化。脚本化: 通过脚本化编译模型可以保留动态控制流,并适用于不同大小的输入。要使用脚本化:
使用
torch.jit.script()来生成一个ScriptModule。调用
torch.onnx.export()并将ScriptModule作为模型。args仍然需要, 但它们将仅用于内部生成示例输出,以便捕获输出的类型和形状。不会执行任何追踪操作。
请参阅 TorchScript简介 和 TorchScript 以获取更多详细信息,包括如何组合追踪和脚本来满足不同模型的特定需求。
避免常见陷阱¶
避免使用 NumPy 和内置的 Python 类型¶
PyTorch模型可以使用NumPy或Python类型和函数编写,但在追踪过程中,任何NumPy或Python类型的变量(而不是torch.Tensor)都会被转换为常量,如果这些值应根据输入变化,则会产生错误结果。
例如,与其在 numpy.ndarrays 上使用 numpy 函数:
# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)
在 torch.Tensors 上使用 torch 操作符:
# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)
而且,与其使用 torch.Tensor.item()(它将张量转换为Python内置数字):
# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
return x.reshape(y.item(), -1)
利用 torch 对单元素张量的隐式类型转换支持:
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
return x.reshape(y, -1)
避免使用 Tensor.data¶
使用 Tensor.data 字段可能会产生错误的跟踪记录,从而导致生成不正确的 ONNX 图。
请改用 torch.Tensor.detach()。 (目前正在努力
完全移除 Tensor.data)。
在跟踪模式下使用 tensor.shape 时,避免进行原地操作¶
在追踪模式下,从 tensor.shape 获得的形状值会被追踪为张量, 并且共享相同的内存。这可能会导致最终输出值的不匹配。 作为一种解决方法,在这些情况下避免使用原地操作。 例如,在模型中:
class Model(torch.nn.Module):
def forward(self, states):
batch_size, seq_length = states.shape[:2]
real_seq_length = seq_length
real_seq_length += 2
return real_seq_length + seq_length
real_seq_length 和 seq_length 在追踪模式下共享相同的内存。
通过重写原地操作可以避免这种情况:
real_seq_length = real_seq_length + 2
限制条件¶
类型¶
仅支持torch.Tensors、可以简单转换为torch.Tensors的数值类型(例如float、int)以及这些类型的元组和列表作为模型输入或输出。在追踪模式下接受字典和字符串输入和输出,但:
任何依赖于字典或字符串输入值的计算都将被替换为在一次跟踪执行过程中观察到的常数值。
任何输出为字典的内容将被静默替换为其值的扁平化序列(键将被移除)。例如
{"foo": 1, "bar": 2}变为(1, 2)。任何输出为 str 类型的内容都将被静默移除。
某些涉及元组和列表的操作在脚本模式下不受支持,因为ONNX对嵌套序列的支持有限。 特别是将元组附加到列表的操作不受支持。在追踪模式下,嵌套序列将在追踪过程中自动展平。
算子实现的差异¶
由于操作符的实现存在差异,在不同的运行时上运行导出的模型,可能会产生彼此之间或与 PyTorch 不同的结果。通常这些差异在数值上很小,因此只有在您的应用程序对这些微小差异敏感时,才需要关注这一点。
不支持的张量索引模式¶
无法导出的张量索引模式列表如下。
如果你在导出一个不包含以下任何不支持模式的模型时遇到问题,请确认你是否使用最新版本的 opset_version 进行导出。
读取 / 获取¶
当对张量进行读取索引时,不支持以下模式:
# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.
写入 / 设置¶
当对张量进行写入操作时,以下索引模式不被支持:
# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
# or multiple consecutive tensor indices with rank == 1.
# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.
# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.
# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
# data shape: [3, 4, 5]
# new_data shape: [5]
# expected new_data shape after broadcasting: [2, 2, 2, 5]
添加对运算符的支持¶
当导出包含不支持操作符的模型时,你会看到类似以下的错误信息:
RuntimeError: ONNX export failed: Couldn't export operator foo
当发生这种情况时,你需要要么修改模型以不使用该运算符, 要么为该运算符添加支持。
添加对运算符的支持需要对PyTorch的源代码进行贡献。 请参阅 CONTRIBUTING 以获取一般性指导,以下部分则提供针对支持运算符所需代码更改的具体说明。
在导出过程中,TorchScript 图中的每个节点都会按拓扑顺序被访问。
访问一个节点时,导出器会尝试查找该节点的已注册符号函数。
符号函数是用 Python 实现的。对于名为 foo 的操作,其符号函数可能如下所示:
def foo(
g: torch._C.Graph,
input_0: torch._C.Value,
input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
"""
Modifies g (e.g., using "g.op()"), adding the ONNX operations representing
this PyTorch function.
Args:
g (Graph): graph to write the ONNX representation into.
input_0 (Value): value representing the variables which contain
the first input for this operator.
input_1 (Value): value representing the variables which contain
the second input for this operator.
Returns:
A Value or List of Values specifying the ONNX nodes that compute something
equivalent to the original PyTorch operator with the given inputs.
Returns None if it cannot be converted to ONNX.
"""
...
The torch._C types are Python wrappers around the types defined in C++ in
ir.h.
添加符号函数的过程取决于操作符的类型。
ATen 操作符¶
ATen 是 PyTorch 内置的张量库。
如果该操作是 ATen 操作(在 TorchScript 图中以前缀
aten:: 显示),请确保它尚未被支持。
支持的操作符列表¶
访问自动生成的 支持的ATen操作符列表
以了解每个 opset_version 中支持的操作符详情。
添加对操作符的支持¶
如果操作符不在上述列表中:
在
torch/onnx/symbolic_opset<version>.py中定义符号函数,例如 torch/onnx/symbolic_opset9.py。 确保该函数与ATen函数的名称相同,ATen函数可能在torch/_C/_VariableFunctions.pyi或torch/nn/functional.pyi中声明(这些文件在构建时生成,因此在你检出代码后直到构建PyTorch之前都不会出现)。默认情况下,第一个参数是ONNX图。 其他参数名称必须与
.pyi文件中的名称完全匹配, 因为使用关键字参数进行分发。一个符号函数,其第一个参数(在 Graph 对象之前)具有类型注解 torch.onnx.SymbolicContext,将使用该附加上下文进行调用。请参见下面的示例。
在符号函数中,如果运算符在 ONNX标准运算符集中, 我们只需要创建一个节点来表示图中的ONNX运算符。 如果不是,我们可以创建一个由几个具有等效语义的标准运算符组成的图来表示ATen运算符。
如果输入参数是一个张量,但 ONNX 需要一个标量,我们必须显式地进行转换。
symbolic_helper._scalar()可以将标量张量转换为 Python 标量,而symbolic_helper._if_scalar_type_as()可以将 Python 标量转换为 PyTorch 张量。
这是一个处理 ELU 运算符缺失符号函数的示例。
如果我们运行以下代码:
print(
torch.jit.trace(torch.nn.ELU(), # module
torch.ones(1) # example input
).graph)
我们看到类似这样的内容:
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%4 : float = prim::Constant[value=1.]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
return (%7)
由于我们在图表中看到 aten::elu,我们知道这是一个ATen操作符。
我们检查了ONNX运算符列表,
并确认Elu在ONNX中是标准化的。
我们在 elu 中找到 torch/nn/functional.pyi 的特征:
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
我们向 symbolic_opset9.py 添加以下行:
def elu(g, input, alpha, inplace=False):
return g.op("Elu", input, alpha_f=_scalar(alpha))
现在 PyTorch 可以导出包含 aten::elu 运算符的模型!
查看 symbolic_opset*.py 个文件以获取更多示例。
torch.autograd.Functions¶
如果操作符是 torch.autograd.Function 的子类,有两种方法可以导出它。
静态符号方法¶
你可以向函数类中添加一个名为 symbolic 的静态方法。它应该返回
表示该函数在 ONNX 中行为的 ONNX 运算符。例如:
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def symbolic(g: torch._C.graph, input: torch._C.Value) -> torch._C.Value:
return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
PythonOp 符号¶
或者,您可以注册一个自定义的符号函数。
这使符号函数可以通过
torch.onnx.SymbolicContext 对象访问更多信息,该对象作为第一个参数传入(在 Graph 对象之前)。
所有 autograd Function 都会以 prim::PythonOp 节点的形式出现在 TorchScript 图中。
为了区分不同的 Function 子类,符号函数应使用 name 参数,该参数会被设置为类的名称。
自定义符号函数应在返回 Value 对象之前,通过调用 setType(...)
为 Value 对象添加类型和形状信息(由 C++ 中的
torch::jit::Value::setType 实现)。这并非必需,但它可以帮助导出器对下游节点进行形状和类型推断。关于 setType 的非平凡示例,请参见
test_aten_embedding_2 在
test_operators.py 中的实现。
以下示例展示了如何通过 requires_grad 访问 Node 对象:
class MyClip(torch.autograd.Function):
@staticmethod
def forward(ctx, input, min):
ctx.save_for_backward(input)
return input.clamp(min=min)
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
def symbolic_python_op(ctx: torch.onnx.SymbolicContext, g: torch._C.Graph, *args, **kwargs):
n = ctx.cur_node
print("original node: ", n)
for i, out in enumerate(n.outputs()):
print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
import torch.onnx.symbolic_helper as sym_helper
for i, arg in enumerate(args):
requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
name = kwargs["name"]
ret = None
if name == "MyClip":
ret = g.op("Clip", args[0], args[1])
elif name == "MyRelu":
ret = g.op("Relu", args[0])
else:
# Logs a warning and returns None
return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
# Copy type and shape from original node.
ret.setType(n.type())
return ret
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
自定义操作符¶
如果模型使用了如使用自定义C++运算符扩展TorchScript中描述的自定义C++运算符实现, 您可以按照此示例进行导出:
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
# Define custom symbolic function
@parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
# Register custom symbolic function
register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)
class FooModel(torch.nn.Module):
def __init__(self, attr1, attr2):
super(FooModule, self).__init__()
self.attr1 = attr1
self.attr2 = attr2
def forward(self, input1, input2):
# Calling custom op
return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)
model = FooModel(attr1, attr2)
torch.onnx.export(
model,
(example_input1, example_input1),
"model.onnx",
# only needed if you want to specify an opset version > 1.
custom_opsets={"custom_domain": 2})
您可以将其导出为一个或多个标准ONNX操作符,或者作为自定义操作符。
上面的例子将其导出为“custom_domain”操作集中的自定义操作符。
在导出自定义操作符时,您可以在导出时使用custom_opsets字典指定自定义域版本。如果没有指定,默认的自定义操作集版本为1。
消耗模型的运行时需要支持自定义操作符。请参阅
Caffe2自定义操作符,
ONNX Runtime自定义操作符,
或您选择的运行时文档。
一次性发现所有无法转换的 ATen 操作¶
当由于无法转换的 ATen 操作导致导出失败时,实际上可能存在多个这样的操作,但错误信息只会提到第一个。若要一次性发现所有无法转换的操作,你可以:
from torch.onnx import utils as onnx_utils
# prepare model, args, opset_version
...
torch_script_graph, unconvertible_ops = onnx_utils.unconvertible_ops(
model, args, opset_version=opset_version)
print(set(unconvertible_ops))
常见问题解答¶
Q: 我已经导出了我的LSTM模型,但它的输入大小似乎被固定了?
The tracer records the shapes of the example inputs. If the model should accept inputs of dynamic shapes, set
dynamic_axeswhen callingtorch.onnx.export().
Q: 如何导出包含循环的模型?
See Tracing vs Scripting.
Q: 如何导出带有原始类型输入(例如 int、float)的模型?
Support for primitive numeric type inputs was added in PyTorch 1.9. However, the exporter does not support models with str inputs.
Q: ONNX 是否支持隐式的标量数据类型转换?
No, but the exporter will try to handle that part. Scalars are exported as constant tensors. The exporter will try to figure out the right datatype for scalars. However when it is unable to do so, you will need to manually specify the datatype. This often happens with scripted models, where the datatypes are not recorded. For example:
class ImplicitCastType(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): # Exporter knows x is float32, will export "2" as float32 as well. y = x + 2 # Currently the exporter doesn't know the datatype of y, so # "3" is exported as int64, which is wrong! return y + 3 # To fix, replace the line above with: # return y + torch.tensor([3], dtype=torch.float32) x = torch.tensor([1.0], dtype=torch.float32) torch.onnx.export(ImplicitCastType(), x, "implicit_cast.onnx", example_outputs=ImplicitCastType()(x))We are trying to improve the datatype propagation in the exporter such that implicit casting is supported in more cases.
Q: Tensor 列表可以导出到 ONNX 吗?
Yes, for
opset_version>= 11, since ONNX introduced the Sequence type in opset 11.
函数¶
-
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False)[source]¶ 将模型导出为ONNX格式。如果
model既不是torch.jit.ScriptModule也不是torch.jit.ScriptFunction,则会运行model一次以将其转换为TorchScript图以便导出 (等同于torch.jit.trace())。因此,这与torch.jit.trace()一样,对动态控制流的支持也有限。- Parameters
模型 (torch.nn.Module, torch.jit.ScriptModule 或 torch.jit.ScriptFunction) – 要导出的模型。
参数 (元组 或 torch.Tensor) –
参数可以以以下两种方式之一进行结构化:
仅包含一组参数:
args = (x, y, z)
该元组应包含模型输入,使得
model(*args)是对模型的有效调用。任何非张量参数将被硬编码到导出的模型中;任何张量参数将成为导出模型的输入,按照它们在元组中出现的顺序。一个张量:
args = torch.Tensor([1])
这相当于该 Tensor 的一个一元元组。
以一个包含命名参数字典的参数元组结尾:
args = (x, {'y': input_y, 'z': input_z})
元组中除最后一个元素外的所有元素都将作为非关键字参数传递, 而命名参数将从最后一个元素中设置。如果字典中没有某个命名参数, 则将其赋值为默认值,如果没有提供默认值,则赋值为 None。
注意
如果一个字典是 args 元组的最后一个元素,它将被解释为包含命名参数。如果要将字典作为最后一个非关键字参数传递,请在 args 元组的最后一个元素中提供一个空字典。例如,不要使用:
torch.onnx.export( model, (x, # WRONG: will be interpreted as named arguments {y: z}), "test.onnx.pb")
Write:
torch.onnx.export( model, (x, {y: z}, {}), "test.onnx.pb")
f – 一个类似文件的对象(例如
f.fileno()返回一个文件描述符) 或包含文件名的字符串。二进制协议缓冲区将被写入此文件。export_params (bool, default True) – 如果为 True,将导出所有参数。如果你想导出一个未训练的模型,请将此值设为 False。 在这种情况下,导出的模型将首先将其所有参数作为参数传入,顺序由
model.state_dict().values()指定verbose (bool, 默认 False) – 如果为 True,会将导出模型的描述打印到标准输出。此外,最终的 ONNX 图将包含从导出模型中获取的字段
doc_string`,该字段提到了model的源代码位置。如果为 True,将启用 ONNX 导出器日志记录。训练 (枚举, 默认 TrainingMode.EVAL) –
TrainingMode.EVAL: 以推理模式导出模型。TrainingMode.PRESERVE: 在模型处于推理模式(即 model.training 为 False)时导出模型,在模型处于训练模式(即 model.training 为 True)时进行训练。TrainingMode.TRAINING: 以训练模式导出模型。禁用可能干扰训练的优化选项。
input_names (str 列表, 默认为空列表) – 按顺序分配给图输入节点的名称。
output_names (str 列表, 默认为空列表) – 按顺序分配给图输出节点的名称。
operator_export_type (枚举, 默认 OperatorExportTypes.ONNX) –
OperatorExportTypes.ONNX: 将所有操作导出为常规ONNX操作 (在默认的opset域中)。OperatorExportTypes.ONNX_FALLTHROUGH: 尝试将所有操作转换为默认opset域中的标准ONNX操作。如果无法这样做(例如,因为尚未添加将特定torch操作转换为ONNX的支持),则回退到将操作导出到自定义opset域而不进行转换。适用于自定义操作以及ATen操作。对于导出的模型要可用,运行时必须支持这些非标准操作。OperatorExportTypes.ONNX_ATEN: 所有 ATen 操作(在 TorchScript 命名空间 “aten” 中) 都会以 ATen 操作的形式导出(在 opset 域 “org.pytorch.aten” 中)。 ATen 是 PyTorch 内置的张量库,因此 这指示运行时使用 PyTorch 对这些操作的实现。警告
以这种方式导出的模型可能只能由 Caffe2 运行。
这在操作符实现中的数值差异导致 PyTorch 和 Caffe2 之间行为出现较大差异时可能会有帮助(这种情况在未训练的模型中更为常见)。
OperatorExportTypes.ONNX_ATEN_FALLBACK: 尝试将每个ATen操作 (在TorchScript命名空间“aten”中)导出为常规的ONNX操作。如果无法做到这一点 (例如,因为尚未添加将特定torch操作转换为ONNX的支持), 则回退到导出ATen操作。有关OperatorExportTypes.ONNX_ATEN的上下文,请参阅文档。 例如:graph(%0 : Float): %3 : int = prim::Constant[value=0]() # conversion unsupported %4 : Float = aten::triu(%0, %3) # conversion supported %5 : Float = aten::mul(%4, %0) return (%5)
假设
aten::triu不被 ONNX 支持,这将被导出为:graph(%0 : Float): %1 : Long() = onnx::Constant[value={0}]() # not converted %2 : Float = aten::ATen[operator="triu"](%0, %1) # converted %3 : Float = onnx::Mul(%2, %0) return (%3)
如果PyTorch是使用Caffe2构建的(即使用
BUILD_CAFFE2=1),那么将启用Caffe2特有的行为,包括对由量化模块描述的操作的支持。警告
以这种方式导出的模型可能只能由 Caffe2 运行。
opset_version (int, 默认值 13) – 要针对的 默认 (ai.onnx) opset 版本。必须 >= 7 且 <= 16。
do_constant_folding (bool, default True) – 应用常量折叠优化。 常量折叠将用预先计算的常量节点替换所有输入均为常量的操作。
dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict) –
默认情况下,导出的模型将把所有输入和输出张量的形状设置为与
args中给出的完全一致。若要指定张量的某些轴为动态(即仅在运行时才知道),请将dynamic_axes设置为具有以下模式的字典:KEY (str): 一个输入或输出名称。每个名称也必须在
input_names或output_names中提供。VALUE (字典或列表):如果是字典,键是轴索引,值是轴名称。如果是列表,每个元素是一个轴索引。
例如:
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"])
Produces:
input { name: "x" ... shape { dim { dim_value: 2 # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_value: 2 # axis 0 ...
While:
torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], })
Produces:
input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0 ...
keep_initializers_as_inputs (bool, 默认 None) –
如果为 True,则导出图中的所有初始化器(通常对应参数)也将作为图的输入添加。如果为 False,则初始化器不会作为图的输入添加,仅非参数输入会被添加为输入。 这可能会允许后端/运行时进行更好的优化(例如常量折叠)。
如果
opset_version < 9,初始化器必须是图输入的一部分,此参数将被忽略,行为将等同于将此参数设置为 True。如果为 None,则行为将按以下方式自动选择:
如果
operator_export_type=OperatorExportTypes.ONNX,则行为等同 于将此参数设置为 False。否则,此参数的行为等同于将其设置为 True。
custom_opsets (dict<str, int>, default empty dict) –
一个具有模式的字典:
KEY (str): opset 域名
VALUE (int): opset 版本
如果自定义 opset 被
model引用但未在此字典中提及, 则 opset 版本将被设置为 1。仅应通过此参数指定自定义 opset 的域名和版本。export_modules_as_functions (bool 或 set of python:type of nn.Module, 默认 False) –
启用标志以将所有
nn.Module前向调用作为 ONNX 中的本地函数导出。或者是一个集合,用于指定要作为 ONNX 中本地函数导出的特定模块类型。 此功能需要opset_version>= 15,否则导出将失败。这是因为opset_version< 15 表示 IR 版本 < 8,这意味着不支持本地函数。 模块变量将作为函数属性导出。函数属性分为两类。1. 带注释的属性:通过PEP 526风格进行类型注解的类变量将作为属性导出。 带注释的属性不会在ONNX本地函数的子图中使用,因为它们不是由PyTorch JIT跟踪创建的,但它们可能被消费者用于确定是否用特定融合内核替换该函数。
2. 推断属性:在模块内部操作符中使用的变量。属性名称将带有前缀“inferred::”。这是为了与从Python模块注解中获取的预定义属性区分开来。推断属性用于ONNX本地函数的子图内部。
False``(default): export ``nn.Module前向调用作为细粒度节点。True: 导出所有nn.Module前向调用作为本地函数节点。Set of type of nn.Module: export
nn.Moduleforward calls as local function nodes, only if the type of thenn.Moduleis found in the set。
- Raises
CheckerError – 如果 ONNX 检查器检测到无效的 ONNX 图。即使引发此错误,仍会将模型导出到文件
f。
-
torch.onnx.export_to_pretty_string(*args, **kwargs)[source]¶ 与
export()类似,但返回 ONNX 模型的文本表示形式。仅列出参数的不同之处。其他所有参数与export()相同。
-
torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]¶ Registers
symbolic_fnto handlesymbolic_name。请参阅模块文档中的“自定义操作符”以了解示例用法。
-
torch.onnx.select_model_mode_for_export(model, mode)[source]¶ 一个上下文管理器,用于临时将
model的训练模式设置为mode,在退出 with 块时将其重置。如果 mode 为 None,则不执行任何操作。