Skip to content

torch.func

译者:片刻小哥哥

项目地址:https://pytorch.apachecn.org/2.0/docs/func

原始地址:https://pytorch.org/docs/stable/func.html

torch.func,以前称为“functorch”,是 PyTorch 的类似 JAX 可组合函数转换。

笔记

该库目前处于测试阶段。这意味着这些功能通常可以工作(除非另有说明)并且我们(PyTorch团队)致力于推动这个图书馆的发展。但是,API 可能会根据用户反馈进行更改,并且我们无法全面覆盖 PyTorch 操作。

如果您对希望涵盖的 API 或用例有建议,请打开 GitHub 问题或联系我们。我们很想听听您如何使用图书馆。

什么是可组合函数变换?

  • “函数变换”是一个高阶函数,它接受数值函数并返回计算不同数量的新函数。
  • torch.func 具有自动微分变换( grad(f) 返回一个计算“f”梯度的函数),一个矢量化/批处理变换( vmap(f) 返回一个在批次上计算 f 的函数输入),以及其他。*这些函数变换可以任意组合。例如,组合vmap(grad(f))会计算一个称为“每个样本梯度”的量,而 PyTorch 目前无法有效计算该量。

为什么可组合函数会发生变换?

如今,PyTorch 中有许多用例很难实现:

  • 计算每个样本的梯度(或其他每个样本的量)
  • 在单台机器上运行模型集合
  • 在 MAML 内循环中高效地将任务批量组合在一起
  • 高效计算雅克比矩阵和海森矩阵
  • 高效计算批量雅克比矩阵和海森矩阵

编写 vmap()grad()vjp() 转换允许我们来表达上述内容,而不需要为每个设计单独的子系统。这种可组合函数转换的想法来自 JAX 框架

阅读更多内容


我们一直在努力

apachecn/AiLearning

【布客】中文翻译组