(测试版)在FX中构建一个卷积/BATCH NORM 融合器
译者:方小生
项目地址:https://pytorch.apachecn.org/2.0/tutorials/intermediate/fx_conv_bn_fuser
原始地址:https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
作者: Horace He
在本教程中,我们将使用 FX(一个用于 PyTorch 的可组合函数转换的工具包)来执行以下操作:
-
查找数据依赖项中的conv/batch norm。
-
对于 1) 中找到的模式,将批量范数统计数据折叠到卷积权重中。
请注意,此优化仅适用于推理模式下的模型(即 mode.eval())。我们将构建此处存在的融合器:https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py。
首先,让我们进行一些导入(稍后我们将在代码中使用所有这些)。
from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn
在本教程中,我们将创建一个由卷积和batch norms组成的模型。请注意,该模型有一些棘手的组件 - 一些conv/batch norm patterns隐藏在 Sequentials 中,并且其中一个 BatchNorms 包装在另一个模块中。
class WrappedBatchNorm(nn.Module): # 定义封装BatchNorm2d的自定义模块
def __init__(self):
super().__init__()
self.mod = nn.BatchNorm2d(1) # 创建BatchNorm2d层
def forward(self, x):
return self.mod(x) # 前向传播直接调用BatchNorm2d层
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1) # 卷积层1
self.bn1 = nn.BatchNorm2d(1) # 批标准化层1
self.conv2 = nn.Conv2d(1, 1, 1) # 卷积层2
self.nested = nn.Sequential( # 序列模块
nn.BatchNorm2d(1), # 批标准化层2
nn.Conv2d(1, 1, 1),
)
self.wrapped = WrappedBatchNorm() # 封装批标准化层的自定义模块
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x) # 应用批标准化层1
x = self.conv2(x)
x = self.nested(x) # 应用序列模块,包含批标准化层2
x = self.wrapped(x) # 应用封装BatchNorm的自定义模块
return x
model = M()
model.eval()
将卷积与Batch Norm融合1
PyTorch中尝试自动融合convolution and batch norm的主要挑战之一是PyTorch无法提供访问计算图的简单方法。FX通过象征性地跟踪调用的实际操作来解决这个问题,这样我们就可以通过前向调用、嵌套在Sequential模块中或封装在用户定义的模块中来跟踪计算。
这为我们的模型提供了一个图形表示。请注意,隐藏在序列模块中的模块以及封装的模块都已内联到图中。这是默认的抽象级别,但可以由pass-writer进行配置。更多信息可在外汇概述中找到https://pytorch.org/docs/master/fx.html#module-torch.fx
卷积与Batch Norm融合2
与其他一些融合不同,融合convolution and batch norm不需要任何新的算子。相反,由于推理过程中的批处理范数由逐点加法和乘法组成,因此这些运算可以“baked”到前面的卷积的权重中。这使我们能够完全从我们的模型中删除batch norm!阅读https://nenadmarkus.com/p/fusing-batchnorm-and-conv/了解更多详细信息。此处代码https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py。
def fuse_conv_bn_eval(conv, bn):
"""
给定一个卷积模块A和批标准化模块B,返回一个卷积模块C,
使得C(x) == B(A(x)),仅在推理模式下成立。
"""
assert(not (conv.training or bn.training)), "融合仅用于推理!"
fused_conv = copy.deepcopy(conv)
# 融合卷积和BN的权重及偏置
fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
fused_conv.weight, fused_conv.bias,
bn.running_mean, bn.running_var, bn.eps,
bn.weight, bn.bias)
return fused_conv
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
# 处理None情况
if conv_b is None:
conv_b = torch.zeros_like(bn_rm)
if bn_w is None:
bn_w = torch.ones_like(bn_rm)
if bn_b is None:
bn_b = torch.zeros_like(bn_rm)
# 计算融合后的权重和偏置
bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
conv_w = conv_w * (bn_w / bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
FX Fusion 实现过程
现在我们有了计算图以及融合convolution 和 batch norm的方法,剩下的就是迭代FX图并应用所需的融合。
def _parent_name(target: str) -> Tuple[str, str]:
"""
将qualname拆分成父路径和最后一个属性名
例如:'foo.bar.baz' -> ('foo.bar', 'baz')
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
# 替换图中某个节点对应的模块
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
assert(isinstance(node.target, str))
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, new_module) # 替换模块
def fuse(model: torch.nn.Module) -> torch.nn.Module:
model = copy.deepcopy(model) # 深拷贝模型
# 获得图表示
fx_model: fx.GraphModule = fx.symbolic_trace(model)
modules = dict(fx_model.named_modules()) # 获得所有模块
# 遍历图中的每个节点
for node in fx_model.graph.nodes:
if node.op != 'call_module':
continue # 跳过非模块调用节点
# 检查节点是否为卷积->批标准化
if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
if len(node.args[0].users) > 1:
continue # 卷积层输出被多个节点使用时跳过
# 获得卷积层和BN层
conv = modules[node.args[0].target]
bn = modules[node.target]
# 融合参数
fused_conv = fuse_conv_bn_eval(conv, bn)
# 替换图中的卷积层
replace_node_module(node.args[0], modules, fused_conv)
# 替换BN层的使用为卷积层
node.replace_all_uses_with(node.args[0])
# 删除BN节点
fx_model.graph.erase_node(node)
fx_model.graph.lint() # 检查图的正确性
fx_model.recompile() # 重新编译
return fx_model
注意:
出于演示目的,我们在这里进行了一些简化,例如仅匹配2D卷积。查看https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py以获得更详细的过程。
测试我们的Fusion Pass
我们现在可以在我们最初的玩具模型上运行这个融合过程,并验证我们的结果是否相同。此外,我们可以打印出融合模型的代码,并验证是否没有更多的batch norms。
fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_allclose(fused_model(inp), model(inp))
在ResNet18上对我们的融合进行基准测试
我们可以在像ResNet18这样的更大模型上测试我们的融合过程,看看这个过程在多大程度上提高了推理性能。
import torchvision.models as models
import time
rn18 = models.resnet18()
rn18.eval()
inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)
def benchmark(model, iters=20):
for _ in range(10):
model(inp)
begin = time.time()
for _ in range(iters):
model(inp)
return str(time.time()-begin)
fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
正如我们之前看到的,我们的FX转换的输出是(“torchscriptable”)PyTorch代码,我们可以轻松地对输出进行jit.script,以尝试进一步提高我们的性能。通过这种方式,我们的FX模型转换与TorchScript完美结合。
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))
############
# Conclusion
# ----------
# As we can see, using FX we can easily write static graph transformations on
# PyTorch code.
#
# Since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.