Apache TVM 是一个端到端的深度学习编译框架,适用于 CPU、GPU 和各种机器学习加速芯片。更多 TVM 中文文档可访问 → https://tvm.hyper.ai/
作者:Tianqi Chen
下面介绍如何在 TVM 中进行递归计算(神经网络中的典型模式)。
TVM 用线性算子来描述符号循环。以下线性算子计算 X 列上的累积和。
线性在张量的最高维度上进行。s_state
是描述线性转换状态的占位符。s_init
描述如何初始化前 k 个时间步长,其第一个维度为 1,描述了如何初始化第一个时间步长的状态。
s_update
描述了如何更新时间步长 t 处的值,更新的值可通过状态占位符引用上一个时间步长的值。注意在当前或之后的时间步长引用 s_state
是无效的。
线性包含状态占位符、初始值和更新描述。推荐列出线性单元的输入,线性的结果是一个张量—— s_state
在时域更新后的结果。
通过分别调度 update 和 init 部分来调度线性体。注意,调度更新部分的第一个迭代维度是无效的。要在时间迭代上拆分,用户可以在 scan_op.scan_axis 上进行调度。
输出结果:
可以像其他 TVM 内核一样构建线性内核,这里用 numpy 来验证结果的正确性。
以上示例用 s_update 中的一个张量计算阶段描述了线性单元,可以在线性单元中使用多个张量级。
以下代码演示了有两个阶段操作的线性单元中的线性过程:
这些中间张量可以正常调度。为了确保正确性,TVM 创建了一个组约束——禁用线性循环之外的 compute_at 位置的线性体。
输出结果:
对于像 RNN 这样的复杂应用,需要多个递归状态。线性支持多个递归状态,以下示例演示如何构建具有两种状态的递归。
输出结果:
本教程演示了如何使用线性原语。
用 init 和 update 描述线性。
将线性单元当作正常 schedule 进行调度。
对于复杂的工作负载,在线性单元中使用多个状态和步骤。