banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

Softmax in OpenAI Triton

本文是对 @sotadeeplearningtutorials9598 的 Youtube 教程学习的总结,感谢老师深入浅出的指导让我这个从未接触过 GPU 编程的小白能够编写出第一个有实际效果的 Kernel。

Softmax 是一种常用的激活函数,通常用于多分类任务的神经网络输出层。它将输入的实数向量转换为概率分布,满足所有输出值在 0 到 1 之间,且总和为 1。Karpathy 的形容是它将 logits squash 为了 0-1 之间的概率分布。

公式:
Softmax(zi)=ezij=1Kezj\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}

为什么要在 GPU 上实现 Softmax?#

GPU 擅长处理并行计算任务。深度学习模型通常需要处理大量的数据和计算,使用 GPU 可以显著提高计算速度。

为什么选 Triton?#

Triton 是 OpenAI 开发的一个编译器和编程语言,旨在让开发者更容易地编写高性能的 kernel。它提供了类似于 Python 的高级语法,相对于 CUDA 等降低了 GPU 编程的复杂性,PyTorch 对 Triton 的支持也为愿意为生态系统做贡献的开发者提供了更多机会。

下图很好地体现出了 Triton 在内核开发上能兼顾性能和效率:

Pasted image 20240913134958

图源自知乎杨军老师的回答:# 谈谈对 OpenAI Triton 的一些理解

接下来的操作是我在 WSL2,Ubuntu 20.04,Python3.10 上进行的:

import torch
import triton
import triton.language as tl
  • triton:Triton 主库。
  • triton.language as tl:Triton 的编程语言模块,包含了编写 Triton 内核所需的函数和操作。

GPU 基本知识#

在 GPU 编程中,Kernel 是一个特殊的函数,它定义了需要并行执行的计算任务。为了高效地利用 GPU 的并行处理能力,这个 Kernel 会被分解成多个执行单元,称为 Block。这种结构允许 GPU 高度并行地处理大量数据,通过将一个大型计算任务分解成众多小型并行任务,GPU 能够实现显著的性能提升。

Pasted image 20240913102053

  • Kernel :是程序员编写的核心算法,它描述了每个并行执行单元应该执行的操作。。这段代码被设计为在大量 Thread 上执行相同的操作。
  • Block:GPU 会将这个 Kernel 任务划分成多个 Block,每个 Block 内部包含许多 Thread(线程)。这些 Thread 同时运行,各自处理一部分数据,但都执行相同的Kernel代码。

简而言之,Kernel 定义了 "做什么",而 BlockThread 决定了 "如何并行地做"。这种方法充分利用了 GPU 的硬件特性,实现了高效的并行计算。

Softmax 实现#

Eager Mode#

先用纯 Python 实现 Softmax,来参考和验证其他实现的正确性:

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1, keepdim=True)[0]
    safe_x = x - x_max
    numerator = torch.exp(safe_x) 
    denominator = numerator.sum(dim=1, keepdim=True)
    softmax_out = numerator / denominator
    return softmax_out
  • 每一行代码都会立即执行,计算结果会立即产生,类似于 Python 的普通代码执行方式。
  • 计算图是动态构建的,每次运行代码时,都会即时创建和执行计算图。
  • 与之相对的是 Graph Mode,区别在于是否立即执行(而非预构建静态计算图)

数值稳定性#

这里要提一嘴是 safe_x = x - x_max 这一步,这样减去最大值可以将所有值都变为非正数,防止 $e^x$ 计算溢出,提高数值稳定性,核心在于下面等式成立:

softmax(ximax(x))=eximax(x)jexjmax(x)=exi/emax(x)j(exj/emax(x))=exi/emax(x)(jexj)/emax(x)=exijexj=softmax(xi)\begin{align*} \text{softmax}(x_i - \max(x)) &= \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \\[2ex] &= \frac{e^{x_i} / e^{\max(x)}}{\sum_j (e^{x_j} / e^{\max(x)})} \\[2ex] &= \frac{e^{x_i} / e^{\max(x)}}{(\sum_j e^{x_j}) / e^{\max(x)}} \\[2ex] &= \frac{e^{x_i}}{\sum_j e^{x_j}} \\[2ex] &= \text{softmax}(x_i) \end{align*}

Triton Implement#

内核的开发实际上分为两部分,内核本身和使并行化(parallelizing) 的驱动,让它同时处理大量实例。

  • 驱动程序(Driver Program):这是在 CPU 上运行的 Python 代码,用于准备数据、配置内核参数,并调用 Triton 内核。

  • 算子(Operator)本身:这是用 Triton 编写的 GPU 内核,实际执行 Softmax 计算。

驱动程序#

这里用自顶向下的学习思路,首先你需要一个 driver program(驱动程序),它会设置大量 meta information,如 block size,共享内存分配等。

def softmax(x: torch.Tensor) -> torch.Tensor:
    """ Triton 实现的 Softmax,只有前向传播 """
    rows, cols = x.shape
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"

    # 计算 block_size,为大于等于 cols 的最小 2 的幂
    block_size = triton.next_power_of_2(cols)

    # 根据 block_size 动态调整 num_warps
    num_warps = 4  # 每个 warp 有 32 个线程
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # 定义网格大小,每个线程块(Block)处理一行数据
    grid = (rows,) # 这个写法会创建一个只包含 rows 的 tuple

    # 创建一个与输入张量形状相同的空张量,用于存储输出
    sm_out = torch.empty_like(x)

    # 调用 Triton 内核(使用方括号传入 grid,然后将参数传递给内核)
    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

    return sm_out

后面你会发现,这里驱动程序中传入 kernel 的参数 kernel 函数本身声明的参数多,原因先按下不表

这里一个比较巧妙的点在于:GPU 通常在处理 2 的幂次大小的数据块时性能最佳。

使用 next_power_of_2 可以将数据大小向上取整到最近的 2 的幂,这有助于优化内存访问模式和对齐,后面还可以根据 block_size 的大小动态调整 num_warps—— 较小的问题使用较少的 warps 避免资源浪费,较大的则充分利用 GPU 并行能力。

算子(Triton Kernel)#

内核在 GPU 上执行实际的计算。

装饰器#

在 Triton 中开发一个内核,需要使用 @triton.jit 装饰器使其进入 Triton 编译器,

@triton.jit
def _softmax_fwd_kernel():
    pass

内核参数#

  • output_ptr:输出张量在内存中的起始地址。
  • stride_output_row:输出张量在行方向上的步幅(即每行在内存中的间隔)。
  • input_ptr:输入张量在内存中的起始地址。
  • stride_input_row:输入张量在行方向上的步幅。
  • num_cols:输入张量的列数。
  • block_size: tl.constexpr:块大小,编译时常量,决定每个线程块处理的元素数量。

获取当前线程块处理的行索引#

获取当前线程块在第 0 维度(行维度)的 ID,即处理的行索引。

row_index = tl.program_id(0)

计算当前行的数据指针#

row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
  • row_start_ptr:当前行在内存中的起始地址。
  • col_offsets:生成一个从 0 到 block_size - 1 的序列,表示列的偏移量。
  • input_ptrs:当前行中每个元素在内存中的地址。

创建掩码(Mask)#

掩码用于在并行计算中避免越界访问。当处理的元素数量不是线程块大小的整数倍时,使用掩码屏蔽无效的线程。

这里当列数小于 block_size 时,需要掩码来避免访问越界的内存地址。

mask = col_offsets < num_cols

从全局内存加载数据到片上存储器(SRAM)#

row = tl.load(input_ptrs, mask=mask, other=float("-inf"))
  • tl.load:从内存加载数据的 API。
  • mask:指示哪些地址是有效的。
  • other=float("-inf"):对于无效地址,填充为负无穷,以确保在后续计算最大值时不影响结果。

Softmax 计算#

利用 Triton 库提供的高效并行计算 API 实现逐元素地将分子除以分母,得到 Softmax 的输出

row_max = tl.max(row, axis=0)
safe_row = row - row_max
numerator = tl.exp(safe_row) 
denominator = tl.sum(numerator, axis=0)
sm_output = numerator / denominator

将结果写回全局内存#

output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_output, mask=mask)

  • output_row_ptr:输出张量当前行的起始地址。
  • output_ptrs:输出张量当前行中每个元素的地址。
  • tl.store:将结果写回内存,使用与加载相同的掩码,确保只写回有效的数据。

总体来看,我们的 kernel 长这样:

@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # 获取当前程序的 ID(行索引)
    row_index = tl.program_id(0)

    # 计算当前行的起始指针
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets

    # 创建掩码,防止越界访问
    row_mask = col_offsets < num_cols

    # 从全局内存加载数据到片上 SRAM
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))

    # Softmax 计算
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # 将结果写回全局内存
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)

驱动程序与算子的交互#

Grid & Block#

在我们的驱动程序代码中:

  • grid = (rows,):定义了网格大小,即一维的 rowsBlock,每个 Block 处理输入张量的一行数据。

参数传递#

当我们调用内核时,实际上传递了以下参数使内核能够正确地定位和处理输入和输出数据:

_softmax_fwd_kernel[grid](
    sm_out,                # 输出张量的指针
    sm_out.stride(0),      # 输出张量在行方向上的步幅
    x,                     # 输入张量的指针
    x.stride(0),           # 输入张量在行方向上的步幅
    cols,                  # 输入张量的列数
    # 内核的配置参数
    block_size=block_size,
    num_warps=num_warps
)

内核的执行#

每个线程块处理一行数据,通过 row_index = tl.program_id(0),每个线程块知道自己应该处理哪一行。

GPU 上的多个线程块同时执行,使得多行数据可以并行处理,大大加快了计算速度。

特殊的 API 回顾#

  • tl.arange(start, end):生成一个从 startend - 1 的序列,用于创建列偏移量。
  • tl.program_id(axis):获取当前线程块在指定维度的 ID。
  • tl.constexpr:表示一个在编译时已知的常量,用于优化。

Benchmark#

完整代码见:triton_kernels_for_fun_and_profit/demos/demo_softmax.py

Screenshot 2024-09-12 at 17.02.38

3090 Ti 上的表现性能(GB/s)

在原视频中 Triton 最多能比后者快近三倍,现在 24 年 9 月份测试中 Triton 还是比 Torch Native 快一点,而且十分稳定。

元参数#

还记得我们前面提到驱动程序中传入 kernel 的参数比 kernel 函数声明的数量多吗?

# Driver
_softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

# Kernel
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):

可以看到驱动有 7 个参数,而后者只有 6 个

原因在于其中一些参数是 Triton 保留的关键字,也被称为 Meta-parameters(元参数)

Screenshot 2024-09-14 at 21.48.38

triton/python/triton/runtime/interpreter.py,可以看到实际上有 6 个保留的关键字

网下翻阅,可以发现这些关键字在后面的 GridExecutor 调用时被从参数中过滤掉了:

class GridExecutor:
	"""省略初始化等部分内容"""
	def __call__(self, *args_dev, **kwargs):
        # removes reserved keywords from kwargs
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}

Triton 编译器会吸收这些参数,这就是参数量对不上的原因。

Triton Reserved Keywords#

  • Num Warps:内核本身在 GPU 上使用的线程束数量(默认每个 warp 有 32 个线程);
  • Num Stages:决定编译器为软件流水线循环(software-pipelining loops)分配的阶段数。主要用途是用于 SM80+(Ampere)架构的 GPU 上执行矩阵乘法等运算。流水线技术允许多个循环迭代同时执行,每个迭代可以部分重叠执行以提高计算性能。(CSAPP 里学过但死去的记忆再次出现);
  • Num CTAS:每个 SM(流多处理器)上并发执行的线程块(CTA)数量
  • Warps Sepecialization(bool)(已被取代):也称为空间分区(Spatial Partitioning),是一种允许 Warp 执行独立计算 的技术。当启用时,多个 Warp 可以并行执行不同的任务,而不必同步执行相同的指令,如使用在生产者 / 消费者模式。在如今 Triton 中已经被下面三个关键字取代;
  • enable_fp_fusion:启用浮点运算融合,将多个浮点操作融合在同一流水线中执行,进一步提升性能,减少多次执行的开销;
  • grid:控制 Triton 内核的 Grid 结构;
  • maxnreg:用于控制每个线程块(Block)所能使用的最大寄存器数量

参考资料#

感谢:

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。