# 正向模式自动微分（测试版） ¶

## 基本用法 ¶

import torch

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    # Element-wise sum of squares of both inputs.
    # (The original line lost its indentation during extraction.)
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a dual_level context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
# To create a dual tensor we associate a tensor, which we call the
# primal with another tensor of the same size, which we call the tangent.
# If the layout of the tangent is different from that of the primal,
# The values of the tangent are copied into a new tensor with the same
# metadata as the primal. Otherwise, the tangent itself is used as-is.
#
# It is also important to note that the dual tensor created by
# make_dual is a view of the primal.

# To demonstrate the case where the copy of the tangent happens,
# we pass in a tangent with a layout different from that of the primal

# Tensors that do not have an associated tangent are automatically
# considered to have a zero-filled tangent of the same shape.
# A plain tensor (no tangent attached) passed alongside a dual tensor.
plain_tensor = torch.randn(10, 10)
# NOTE(review): `dual_input` is not defined anywhere in this excerpt.
# Presumably it was created above via fwAD.make_dual(primal, tangent)
# inside a fwAD.dual_level() context that was lost in extraction --
# confirm against the original tutorial.
dual_output = fn(dual_input, plain_tensor)

# Unpacking the dual returns a namedtuple with primal and tangent
# as attributes



## 与模块一起使用 ¶

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

for name, p in params.items():
delattr(model, name)

out = model(input)


## 使用功能模块 API（测试版） ¶

将 nn.Module 与正向模式 AD 结合使用的另一种方法是利用函数式模块 API（也称为无状态模块 API）。

from torch.func import functional_call
import torch.autograd.forward_ad as fwAD

# We need a fresh module because the functional call requires the
# model to have parameters registered.
model = nn.Linear(5, 5)

# NOTE(review): the loop at this point lost its body during extraction
# (a for-loop with no statements is a syntax error). Reconstructed the
# missing make_dual line following the original tutorial.
dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same tangents from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = out.tangent

# Check our results: both approaches compute the same JVP
assert torch.allclose(jvp, jvp2)


class Fn(torch.autograd.Function):
    """exp() as a custom autograd Function with a forward-mode (jvp) rule.

    (The method bodies lost their indentation during extraction;
    restored here.)
    """

    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ctx can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        # d/dx exp(x) = exp(x): the JVP is the incoming tangent scaled by
        # the primal output saved in forward().
        gO = gI * ctx.result
        # If the tensor stored in ctx will not also be used in the backward pass,
        # one can manually free it using del
        del ctx.result
        return gO


fn = Fn.apply

# Double precision + requires_grad so the tensor is suitable for gradcheck.
primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

# NOTE(review): `dual_input` is undefined in this excerpt; presumably
# fwAD.make_dual(primal, tangent) inside a fwAD.dual_level() context was
# dropped during extraction -- confirm against the original tutorial.
dual_output = fn(dual_input)

# It is important to use autograd.gradcheck to verify that your custom
# autograd Function computes the gradients correctly. By default,
# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify
# check_forward_ad=True to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell gradcheck
# to skip the tests that require backward-mode AD by specifying
# check_backward_ad=False, check_undefined_grad=False, and
# check_batched_grad=False.
# NOTE(review): the source showed only the captured REPL output (`True`);
# the gradcheck call that produced it is reconstructed here.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)


## 功能 API（测试版） ¶

import functorch as ft

primal0 = torch.randn(10, 10)
tangent0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)
tangent1 = torch.randn(10, 10)

def fn(x, y):
    # Element-wise sum of squares; its JVP at (x, y) along (dx, dy) is
    # 2*x*dx + 2*y*dy. (The body lost its indentation during extraction.)
    return x ** 2 + y ** 2

# Here is a basic example to compute the JVP of the above function.
# The jvp(func, primals, tangents) returns func(*primals) as well as the
# computed Jacobian-vector product (JVP). Each primal must be associated with a tangent of the same shape.
primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))

# functorch.jvp requires every primal to be associated with a tangent.
# If we only want to associate certain inputs to fn with tangents,
# then we'll need to create a new function that captures inputs without tangents:
# Only `primal` should carry a tangent here; `y` is held fixed, so we
# differentiate a one-argument wrapper around `fn` instead of `fn` itself.
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
y = torch.randn(10, 10)

def new_fn(x):
    # `fn` with the tangent-free input `y` captured as a constant.
    return fn(x, y)

primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))

/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/deprecated.py:77: UserWarning:

We've integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html


## 将功能 API 与模块结合使用 ¶

model = nn.Linear(5, 5)
input = torch.randn(16, 5)
# One random tangent per parameter, in parameter order.
tangents = tuple(torch.rand_like(p) for p in model.parameters())

# Given a torch.nn.Module, ft.make_functional_with_buffers extracts the state
# (params and buffers) and returns a functional version of the model that
# can be invoked like a function.
# That is, the returned func can be invoked like
# func(params, buffers, input).
# ft.make_functional_with_buffers is analogous to the nn.Modules stateless API
# that you saw previously and we're working on consolidating the two.
func, params, buffers = ft.make_functional_with_buffers(model)

# Because jvp requires every input to be associated with a tangent, we need to
# create a new function that, when given the parameters, produces the output.
# (The body below lost its indentation during extraction; restored.)
def func_params_only(params):
    return func(params, buffers, input)

model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))

/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/deprecated.py:104: UserWarning:

We've integrated functorch into PyTorch. As the final step of the integration, functorch.make_functional_with_buffers is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.functional_call instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html

/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/deprecated.py:77: UserWarning:

We've integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html