torch.func¶
torch.func,之前称为“functorch”,是用于PyTorch的 类似于JAX的可组合函数转换。
注意
这个库目前处于测试版。 这意味着功能通常可以正常工作(除非另有文档说明), 并且我们(PyTorch 团队)致力于推进这个库的发展。然而,根据用户反馈,API 可能会发生变化, 并且我们还没有完全覆盖 PyTorch 的所有操作。
如果您有关于 API 或您希望涵盖的用例的建议,请打开 GitHub 问题或联系我们。我们很想知道您是如何使用该库的。
什么是可组合函数变换?¶
一个“函数变换”是一个高阶函数,它接受一个数值函数并返回一个新的函数,用于计算不同的量。
torch.func具有自动微分变换(grad(f)返回一个计算f梯度的函数),向量化/批量变换(vmap(f)返回一个在输入批次上计算f的函数),以及其他功能。这些函数变换可以任意组合。例如,组合
vmap(grad(f))计算一个称为每样本梯度的量, 而今天的标准PyTorch无法高效地计算这个量。
为什么使用可组合函数变换?¶
目前有一些用例在 PyTorch 中处理起来比较棘手:
计算每个样本的梯度(或其他每个样本的量)
在单台机器上运行模型集合
在 MAML 内循环中高效地批量处理任务
高效计算雅可比矩阵和黑塞矩阵
高效计算批处理雅可比矩阵和黑塞矩阵
组合vmap()、grad()和vjp()变换,使我们能够在不为每个子系统单独设计的情况下表达上述内容。
这种可组合函数变换的思想来自JAX框架。