Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。
更多 Triton 中文文档可访问 →https://triton.hyper.ai/
在本教程中,您将编写一个非常简短的高性能 FP16 矩阵乘法内核,其性能可以与 cuBLAS 或 rocBLAS 相媲美。
您将具体学习以下内容:
块级矩阵乘法。
多维指针算术。
为提高 L2 缓存命中率而进行的程序重排序。
自动性能调优。
矩阵乘法是现代大多数高性能计算系统的关键构建块。
矩阵乘法难以优化,因此其实现通常由硬件供应商自行完成,作为所谓「内核库」(例如 cuBLAS )的一部分。
这些库通常是专有的,不能轻易定制以满足现代深度学习工作负载的需求(例如融合激活函数)。
在本教程中,您将学习如何借助一种更易于定制和扩展的方法,用 Triton 实现高效的矩阵乘法。
整体来说,我们将编写的内核将实现以下的分块算法,用于计算一个 (M, K) 乘以一个 (K, N) 的矩阵:
其中,每次双重嵌套的循环迭代都由专用的 Triton 程序实例执行。
实际上,上述算法在 Triton 中实现起来相当简单。
主要困难在于计算内循环中必须读取 A
和 B
块的内存位置。为此,我们需要多维指针算术。
因此,对于行主序的二维张量 X
,X[i, j]
的内存位置由 &X[i, j] = X + i*stride_xi + j*stride_xj
给出。
因此,A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]
和B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N]
的指针块可以用伪代码定义为:
这意味着在 Triton 中可以将 A 和 B 块的指针初始化(即 k=0
)为以下代码。还要注意,当 M
不是 BLOCK_SIZE_M
的倍数或 N
不是 BLOCK_SIZE_N
的倍数时,我们需要额外的取模运算来应对,这种情况下我们可以用一些无用的值填充数据,这些值不会对结果有影响。对于 K
维度,我们将在后面使用掩码加载语义来处理。
然后在内循环中更新如下:
正如上面提到的,每个程序实例计算 C
的一个 [BLOCK_SIZE_M, BLOCK_SIZE_N]
块。
重点要记住这些块的计算顺序,因为它会影响我们程序的 L2 缓存命中率,而且,简单的行主序排序是行不通的。
一种可能的解决方案是以促进数据重用的顺序启动块。
在转向下一列之前,可以通过将 GROUP_M
行的块进行「超级分组」来实现此目的:
例如,在以下的矩阵乘法示例中,每个矩阵都是 9*9 个块。可以看到,如果按行主序计算输出,我们需要加载 90 个块到 SRAM 中来计算前 9 个输出块,但如果按组顺序计算,我们只需要加载 54 个块。
实际上,这种做法可以在某些硬件架构上显著提升我们的矩阵乘法核心性能,例如在 A100 上,性能提升可以超过 10%,从 220 到 245 TFLOPS 不等。
现在我们可以创建一个方便的 wrapper 函数,只接受两个输入张量,并且:(1) 检查任何 shape 约束;(2) 分配输出;(3) 启动上述的内核。
对自定义矩阵乘法操作进行测试,与 原生 torch 实现(例如 cuBLAS)进行对比。
Out:
triton_output_with_fp16_inputs=tensor([[-10.9531, -4.7109, 15.6953, ..., -28.4062, 4.3320, -26.4219],
torch_output_with_fp16_inputs=tensor([[-10.9531, -4.7109, 15.6953, ..., -28.4062, 4.3320, -26.4219],
[ 26.8438, 10.0469, -5.4297, ..., -11.2969, -8.5312, 30.7500],
[-13.2578, 15.8516, 18.0781, ..., -21.7656, -8.6406, 10.2031],
...,
[ 40.2812, 18.6094, -25.6094, ..., -2.7598, -3.2441, 41.0000],
[ -6.1211, -16.8281, 4.4844, ..., -21.0312, 24.7031, 15.0234],
[-17.0938, -19.0000, -0.3831, ..., 21.5469, -30.2344, -13.2188]], device='cuda:0', dtype=torch.float16)✅ Triton and Torch matchtriton_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., 28.7031, 8.6719, -40.7500],
[ 10.0000, 37.0000, -5.5664, ..., 20.9844, 46.8125, 30.8281],
[ 19.5625, -3.0078, -20.0469, ..., -2.1309, -8.0625, 12.5625],
...,
[-18.1562, -34.1562, -27.4219, ..., -27.3906, -24.0938, -12.3516],
[ -3.3945, -8.6250, -23.6562, ..., -4.1094, -3.5332, -16.0781],
[-23.9688, -3.2637, -33.6875, ..., 17.3125, -36.6250, 25.8594]], device='cuda:0', dtype=torch.float16)torch_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., 28.7031, 8.6719, -40.7500],
[ 10.0000, 37.0000, -5.5664, ..., 20.9844, 46.8125, 30.8281],
[ 19.5625, -3.0078, -20.0469, ..., -2.1309, -8.0625, 12.5625],
...,
[-18.1562, -34.1562, -27.4219, ..., -27.3906, -24.0938, -12.3516],
[ -3.3945, -8.6250, -23.6562, ..., -4.1094, -3.5332, -16.0781],
[-23.9688, -3.2637, -33.6875, ..., 17.3125, -36.6250, 25.8594]], device='cuda:0', dtype=torch.float16)✅ Triton and Torch match
比较内核与 cuBLAS 或 rocBLAS 的性能差异。此处以方阵为例进行讲解,也可以可以按需调整脚本,对其他 matrix shape 进行基准测试。
Out:
matmul-performance-fp16:
matmul-performance-fp8:
小琪学姐吖 2024-11-02
成哥的科技生活 2024-11-02
恒点信息 2024-11-02