torch.func ¶
译者:片刻小哥哥
torch.func,以前称为“functorch”,是 PyTorch 的类似 JAX 可组合函数转换。
笔记
该库目前处于测试阶段。这意味着这些功能通常可以工作(除非另有说明)并且我们(PyTorch团队)致力于推动这个图书馆的发展。但是,API 可能会根据用户反馈进行更改,并且我们无法全面覆盖 PyTorch 操作。
如果您对希望涵盖的 API 或用例有建议,请打开 GitHub 问题或联系我们。我们很想听听您如何使用图书馆。
什么是可组合函数变换? ¶
- “函数变换”是一个高阶函数,它接受数值函数并返回计算不同数量的新函数。
torch.func
具有自动微分变换(grad(f)
返回一个计算“f”梯度的函数),一个矢量化/批处理变换(vmap(f)
返回一个在批次上计算f
的函数输入),以及其他。*这些函数变换可以任意组合。例如,组合vmap(grad(f))
会计算一个称为“每个样本梯度”的量,而 PyTorch 目前无法有效计算该量。
为什么可组合函数会发生变换? ¶
如今,PyTorch 中有许多用例很难实现:
- 计算每个样本的梯度(或其他每个样本的量)
- 在单台机器上运行模型集合
- 在 MAML 内循环中高效地将任务批量组合在一起
- 高效计算雅克比矩阵和海森矩阵
- 高效计算批量雅克比矩阵和海森矩阵
编写 vmap()
、 grad()
和 vjp()
转换允许我们来表达上述内容,而不需要为每个设计单独的子系统。这种可组合函数转换的想法来自 JAX 框架 。