在 C++ 中注册调度运算符

原文:https://pytorch.org/tutorials/advanced/dispatcher.html

调度器是 PyTorch 的内部组件,负责确定调用torch::add之类的函数时应实际运行哪些代码。 这是不平凡的,因为 PyTorch 操作需要处理很多交叉关注点,这些关注点“层叠”在另一个之上。 以下是其处理的一些示例:

  • 根据输入张量的设备,在运算符的 CPU 和 CUDA 实现之间切换。
  • 在运算符的自动微分和后端实现之间切换,这取决于是否需要自动微分处理。
  • 必要时应用自动广播来实现自动混合精度。
  • 当运算符在vmap调用下运行时,应用批量规则。
  • 如果要跟踪导出的模型,则跟踪操作的执行。

如果在自定义运算符代码中发现自己手动编写了if语句来处理这些情况,则调度器 API 可以帮助组织代码。 (相反,如果您的自定义运算符非常简单并且仅用于 CPU 推断,则可能不需要使用调度器,只需使用基本 API。)

在本教程中,我们将描述如何构造自定义运算符注册以使用调度器来组织各种组件。 我们假设您熟悉如何注册运算符以及如何编写自定义自动微分函数

定义模式和后端实现

调度器背后的一般原理是将一个运算符的实现分为多个内核,每个内核都为特定的调度键实现功能; 例如,CPU,CUDA 或 Autograd。 调度器在您调用运算符时确定最高优先级的调度键是什么(这通过查看张量参数和某些线程本地状态来完成),并将控制权传递给内核以使用该调度键。 最终结果是,当您调用运算符时,我们首先执行 Autograd 内核,然后根据传入的张量的设备类型将其重新分配到 CPU 或 CUDA 内核。

让我们看一下实现这一目标所涉及的各个部分。 首先,我们必须为所讨论的运算符定义架构。 与简单的pybind11样式的运算符注册不同,我们目前实际上并未提供运算符的实现; 我们只提供一个模式字符串,指定所有其他内核将遵守的运算符的类型签名:

TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}

接下来,我们需要实际提供此运算符的一些实现。 具体来说,这是一个非常简单的 CPU 实现:

Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
  Tensor self = self_.contiguous();
  Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}

我们想将此函数注册为myops::myadd的实现。 但是,简单的注册方法(def("myadd", myadd_cpu))将注册内核以在所有情况下都可以运行,即使张量不是 CPU 张量! (在内部,我们将它们称为“全部捕获”内核,因为它们捕获所有情况。)为确保仅针对 CPU 张量运行myadd_cpu,我们可以使用TORCH_LIBRARY_IMPL宏:

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}

通过TORCH_LIBRARY_IMPL,我们可以在特定的调度键(在本例中为 CPU)上为运算符注册实现。 每次对impl的调用都会将 CPU 内核与相应的运算符(我们先前在TORCH_LIBRARY块中定义)相关联。 如果我们还有 CUDA 实现myadd_cuda,我们可以将其注册在单独的TORCH_LIBRARY_IMPL块中:

TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}

这些注册可以跨文件甚至跨库边界拆分; 因此,例如,您可以将这两个TORCH_LIBRARY_IMPL块编译为单独的myops_cpumyops_cuda动态库。 一般来说,您的注册结构如下所示:

  1. 单个TORCH_LIBRARY在集中位置列出名称空间中的每个自定义运算符。
  2. 每个调度键的TORCH_LIBRARY_IMPL,用于注册该键的实现(例如,CPU 或 CUDA)。 如果愿意,可以按每个运算符将TORCH_LIBRARY_IMPL块进一步细分为一个块。 如果每个运算符的实现都有一个单独的文件,但是又不想在标头中显示运算符,这将很方便。 您只需将注册内容放入定义您的运算符的 cpp 文件中。

注意

您知道吗,您还可以为 PyTorch 中的现有核心运算符编写TORCH_LIBRARY_IMPL块? 这就是实现 XLA 对 PyTorch 的支持的方式:torch_xla库包含一个TORCH_LIBRARY_IMPL,该库为 XLA 调度键上的所有基本运算符提供实现。

添加 Autograd 支持

至此,我们有了一个同时具有 CPU 和 CUDA 实现的运算符。 我们如何为它添加 Autograd 支持? 您可能会猜到,我们将注册一个 Autograd 内核(类似于自定义 Autograd 函数教程中描述的内容)! 但是,有一个变数:与 CPU 和 CUDA 内核不同,Autograd 内核需要重新分发:它需要回调调度器才能到达最终的 CPU 和 CUDA 实现。

因此,在编写 Autograd 内核之前,让我们编写一个调度函数,该函数调用调度器以为您的运算符找到合适的内核。 该函数构成了供您的运算符使用的公共 C++ API,实际上,PyTorch C++ API 中的所有张量函数都在后台完全以相同的方式调用了调度器。 调度函数如下所示:

Tensor myadd(const Tensor& self, const Tensor& other) {
  static auto op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("myops::myadd", "")
    .typed<decltype(myadd)>();
  return op.call(self, other);
}

让我们分解一下:

  • 在第一行中,我们从调度器中查找与要调度到的运算符相对应的类型化运算符句柄。 findSchemaOrThrow具有两个参数:运算符的(名称空间限定)名称和运算符的重载名称(通常只是空字符串)。 typed将动态类型的句柄转换为静态类型的句柄(进行运行时测试以确保您提供了正确的 C++ 类型),以便我们可以对其进行常规的 C++ 调用。 我们将其传递给decltype(myadd),因为调度函数的类型与注册到调度器的基础内核的类型相同。

    为了提高性能,此计算是在静态变量中完成的,因此我们只需要进行一次(慢速)查找。 如果键入了要调用的运算符的名称,则第一次调用此函数时,此查找将出错。

  • 在第二行中,我们只需将所有参数传递到调度函数中,就可以简单地call运算符句柄。 这实际上将调用调度器,最终控制权将转移到适合此调用的任何内核。

有了分发函数,我们现在可以编写 Autograd 内核:

class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  return MyAddFunction::apply(self, other)[0];
}

使用torch::autograd::Function正常编写 Autograd 函数,除了代替直接在forward()中编写实现,我们:

  1. 使用at::AutoNonVariableTypeMode RAII 保护器关闭 Autograd 处理,然后
  2. 调用调度函数myadd以回调调度器。

如果没有(1),您的调用将无限循环(并且栈溢出),因为myadd将使您返回此函数(因为最高优先级分配键仍将是自动微分的。)对于(1),自动微分从一组正在考虑的调度键中排除,我们将转到下一个处理器,即 CPU 和 CUDA。

现在,我们可以按照注册 CPU/CUDA 函数的相同方式注册此函数:

TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);
}

超越 Autograd

从某种意义上说,调度员并没有做太多事情:它所做的只是实现一种美化的if语句,其方法如下:

class MyAddFunction : ... {
public:
  static Tensor forward(
    AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {

    if (self.device().type() == DeviceType::CPU) {
      return add_cpu(self, other);
    } else if (self.device().type() == DeviceType::CUDA) {
      return add_cuda(self, other);
    } else {
      TORCH_CHECK(0, "Unsupported device ", self.device().type());
    }
  }
  ...
}

那么为什么要使用调度器呢? 有几个原因:

  1. 它是分散的。 您可以组装运算符的所有部分(CPU,CUDA,Autograd),而不必编写引用所有元素的集中式if语句。 重要的是,第三方可以注册其他方面的额外实现,而不必修补运算符的原始定义。
  2. 它比 CPU,CUDA 和 Autograd 支持更多的调度键。 您可以在c10/core/DispatchKey.h中查看 PyTorch 中当前实现的调度键的完整列表。 这些调度键为运算符实现了多种可选功能,如果您决定希望自定义运算符支持该功能,则只需为相应的键注册内核即可。
  3. 调度器实现对盒装后备函数的支持,后者是可以一次实现并应用于系统中所有运算符的函数。 盒装后备可用于提供调度键的默认行为。 如果您使用调度器来实现您的运算符,那么您还可以选择所有这些操作的备用。

这是一些特定的调度键,您可能需要为其定义一个运算符。

Autocast

Autocast 调度键实现对自动混合精度(AMP)的支持。 自动广播包装器内核通常会在运行操作之前将传入的float16float32 CUDA 张量转换为某些首选精度。 例如,浮点 CUDA 张量上的积和卷积通常运行得更快,并且在float16中使用较少的内存,而不会影响收敛。 自动广播包装器仅在启用自动广播的上下文中有效。

这是假设的自定义Matmul的自动广播包装器及其注册信息:

// Autocast-specific helper functions
#include <ATen/autocast_mode.h>

Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return mymatmul(at::autocast::cached_cast(at::kHalf, self),
                  at::autocast::cached_cast(at::kHalf, other));
}

TORCH_LIBRARY_IMPL(myops, Autocast, m) {
  m.impl("mymatmul", mymatmul_autocast);
}

如果tensor为 CUDA 和float32,则cached_cast(kHalf, tensor)tensor强制转换为float16,否则,tensor保持不变(参见资格策略对于本地自动播报的操作)。 这样可以确保网络是否在float16float32 CUDA 张量的任何混合形式上调用mymatmulmymatmulfloat16中运行。 同时,使用非 CUDA,整数类型或float64输入的对mymatmul的调用不受影响。 建议使用cached_cast在您自己的自动广播包装程序中遵循本机资格策略,但不是必需的。 例如,如果要对所有输入类型强制执行float16,则可以使用return mymatmul(self.half(), other.half());而不是使用cached_cast

请注意,就像我们的 Autograd 内核一样,我们在重新分配之前从分配中排除Autocast键。

默认情况下,如果未提供自动广播包装器,我们将直接进入常规的运算符实现(不进行自动广播)。 (在此示例中,我们没有使用myadd,因为逐点加法不需要自动广播,因此应该会失败。)

什么时候应该注册自动广播包装器? 不幸的是,对于运算符的首选精度并没有严格的规定。 通过查看运算符列表,您可以了解某些本机运算符的首选精度。 一般指导:

  • 进行减少操作的操作可能应该在float32中执行,
  • 在幕后进行卷积或宝石运算的任何操作都应在float16中执行,并且
  • 具有多个浮点张量输入的其他运算符应将它们标准化为通用精度(除非实现支持具有不同精度的输入)。

如果您的自定义操作属于第三类,则promote_type模板有助于找出输入张量中存在的最宽浮点类型,这是执行类型的最安全选择:

#include <ATen/autocast_mode.h>

Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  // The required at::kHalf argument is an optimistic initial guess.
  auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
  return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                              at::autocast::cached_cast(exec_type, t1));
}

如果您的自定义操作已启用 Autograd,则只需编写和注册自动广播包装器,其名称与注册自动梯度包装器的名称相同。 例如,如果您想为 Autograd 部分中显示的myadd函数使用自动广播包装,那么您所需要做的就是

Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return myadd(at::autocast::cached_cast(<desired dtype>, self),
               at::autocast::cached_cast(<desired dtype>, other));
}

TORCH_LIBRARY_IMPL(myops, Autocast, m) {
  m.impl("myadd", myadd_autocast);
}

没有单独的体操可使后向方法自动广播兼容。 但是,在自定义 Autograd 函数中定义的向后方法将以与正向方法的自动广播集相同的dtype运行,因此您应该选择既适合于正向方法又适合于向后方法的<desired dtype>

批量

批量张量允许您按示例方式编写代码,然后在vmap调用下运行时自动对其进行批量。 当前正在开发用于编写批量规则的 API,但是一旦稳定该 API,就可以通过在 Batched 调度键处注册内核来为运算符添加对vmap的支持。

追踪器

追踪器调度键实现了对在运行torch.jit.trace时将运算符调用记录到跟踪中的支持。 我们打算提供一个盒装后备,它将实现对任意操作的跟踪,请参阅 ISSUE#41478 以跟踪进度。


Copyright © ibooker.org.cn 2019 all right reserved,由 ApacheCN 团队提供支持该文件修订时间: 2021-04-12 06:14:20

results matching ""

    No results matching ""

    results matching ""

      No results matching ""