注意
点击 这里 下载完整示例代码
自定义 Python 运算符¶
创建日期: 2024年6月18日 | 最后更新日期: 2025年1月2日 | 最后验证日期: 2024年11月5日
如何将用Python编写的自定义操作与PyTorch集成
如何测试自定义运算符使用
torch.library.opcheck
PyTorch 2.4 或更高版本
PyTorch 提供了大量用于操作张量的运算符(例如
torch.add, torch.sum, 等等)。然而,您可能希望使用新的自定义运算符与 PyTorch 结合使用,也许是由第三方库编写的。本教程展示了如何包装 Python 函数,使其行为如同 PyTorch 内置运算符。在 PyTorch 中创建自定义运算符的原因包括:
将任意Python函数视为不透明的可调用对象,以对待
torch.compile(即,防止torch.compile追踪进入该函数)。将任意Python函数添加训练支持
请注意,如果您的操作可以表示为现有PyTorch运算符的组合,那么通常无需使用自定义运算符API – 所有内容(例如torch.compile,训练支持)都应该正常工作。
示例:将PIL的crop封装为自定义操作¶
让我们说我们正在使用PIL的crop操作。
import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt
def crop(pic, box):
img = to_pil_image(pic.cpu())
cropped_img = img.crop(box)
return pil_to_tensor(cropped_img).to(pic.device) / 255.
def display(img):
plt.imshow(img.numpy().transpose((1, 2, 0)))
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)

crop 在默认情况下无法被
torch.compile 有效处理:
torch.compile 在无法处理的函数上引发了一个
“图中断”
而图中断对性能不利。
以下代码通过抛出错误来演示这一点
(如果发生图中断,torch.compile 和 fullgraph=True 会抛出错误)。
为了将黑盒crop与torch.compile结合使用,我们需要做两件事:
将函数封装为PyTorch自定义操作符。
添加一个“
FakeTensor核心”(即“元核心”)到操作符。 给定一些FakeTensors输入(虚拟张量,没有存储空间),此函数应返回您选择的具有正确张量元数据(形状/步长/dtype/设备)的虚拟张量。
from typing import Sequence
# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
img = to_pil_image(pic.cpu())
cropped_img = img.crop(box)
return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)
# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
channels = pic.shape[0]
x0, y0, x1, y1 = box
return pic.new_empty(channels, y1 - y0, x1 - x0)
在此之后,crop 现在不会出现图中断:

/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:820: FutureWarning:
'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:820: FutureWarning:
'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
display(cropped_img)

添加训练支持 for crop ¶
使用 torch.library.register_autograd 为操作添加训练支持。
优先使用此方法而非直接使用 torch.autograd.Function;某些 autograd.Function 与 PyTorch 操作注册 API 的组合可能导致(并已导致)在与 torch.compile 组合时出现无声的错误。
如果不需要训练支持,就没有必要使用
torch.library.register_autograd.
如果您最终使用了一个没有自动求导注册的
custom_op 进行训练,我们将抛出一个错误消息。
The gradient formula for crop 是基本上是 PIL.paste (我们将会把推导留给读者)。让我们首先将 paste 包装成一个自定义操作符:
@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
assert im1.device == im2.device
assert im1.dtype == im2.dtype
im1_pil = to_pil_image(im1.cpu())
im2_pil = to_pil_image(im2.cpu())
PIL.Image.Image.paste(im1_pil, im2_pil, coord)
return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)
@paste.register_fake
def _(im1, im2, coord):
assert im1.device == im2.device
assert im1.dtype == im2.dtype
return torch.empty_like(im1)
现在让我们使用 register_autograd 来指定 crop 的梯度公式:
def backward(ctx, grad_output):
grad_input = grad_output.new_zeros(ctx.pic_shape)
grad_input = paste(grad_input, grad_output, ctx.coords)
return grad_input, None
def setup_context(ctx, inputs, output):
pic, box = inputs
ctx.coords = box[:2]
ctx.pic_shape = pic.shape
crop.register_autograd(backward, setup_context=setup_context)
请注意,反向传播必须由PyTorch能够理解的操作组成,这就是为什么我们将paste封装成了一个自定义操作,而不是直接使用PIL的paste。

这是正确的梯度,在裁剪区域为1(白色),在未使用的区域为0(黑色)。
测试 Python 自定义运算符¶
使用 torch.library.opcheck 测试自定义操作是否已正确注册。
这不测试梯度是否正确;请编写单独的测试(手动测试或 torch.autograd.gradcheck)来验证。
要使用 opcheck,请传递一组示例输入以进行测试。如果您的操作支持训练,那么示例应包括需要 grad 的 Tensor。如果您的操作支持多个设备,那么示例应包括每个设备上的 Tensor。
examples = [
[torch.randn(3, 64, 64), [0, 0, 10, 10]],
[torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
[torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
[torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]
for example in examples:
torch.library.opcheck(crop, example)
可变的Python自定义运算符¶
您也可以将一个修改其输入的Python函数包装成一个自定义操作符。
具有修改输入功能的函数很常见,因为许多低级内核就是这样编写的;例如,一个计算sin的内核可能会接收输入张量和输出张量,并将input.sin()写入输出张量。
我们将使用 numpy.sin 来演示一个可变的 Python 自定义操作符的例子。
import numpy as np
@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.device == output.device
assert input.device.type == "cpu"
input_np = input.numpy()
output_np = output.numpy()
np.sin(input_np, out=output_np)
因为操作符不返回任何内容,所以不需要注册一个 FakeTensor 内核(元内核)使其能够与 torch.compile 兼容。
@torch.compile(fullgraph=True)
def f(x):
out = torch.empty(3)
numpy_sin(x, out)
return out
x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())
这里有一个 opcheck 运行结果,告诉我们确实正确注册了操作符。
如果我们忘记将输出添加到 mutates_args,例如,opcheck 会报错。
example_inputs = [
[torch.randn(3), torch.empty(3)],
[torch.randn(0, 3), torch.empty(0, 3)],
[torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]
for example in example_inputs:
torch.library.opcheck(numpy_sin, example)
结论¶
在本教程中,我们学习了如何使用torch.library.custom_op来
创建一个与PyTorch子系统如torch.compile和autograd兼容的自定义操作符。
这个教程提供了对自定义操作的基本介绍。 更多信息,请参见:
脚本总运行时间: ( 0 分钟 4.595 秒)