banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

Vector Add in Triton

单线程版本#

逐元素相加:

Screenshot 2024-09-19 at 15.34.56

Triton 实现#

在 Triton 中,向量加法内核通过将向量划分为多个块(blocks),并在每个 Grid 中的线程(threads)并行计算,实现高效的向量加法操作。每个线程负责加载两个向量中对应位置的元素,进行相加并存储结果。

Screenshot 2024-09-19 at 15.35.11

核心步骤#

  1. 线程并行计算:每个 Grid 中的线程独立处理向量中的一部分元素。
  2. 加载元素:每个线程加载向量 A 和向量 B 中对应位置的元素。
  3. 元素相加:将加载的元素进行相加。
  4. 存储结果:将相加后的结果存储到输出向量中。

tl.constexpr 的使用#

tl.constexpr 用于声明编译时常量。这意味着使用这个修饰符的变量的值在编译时就已经确定,而不是在运行时。编译器可以基于这些常量值进行更 aggressive 的优化来提升内核的执行效率

@triton.jit 
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
						   num_elems: tl.constexpr,
						   block_size: tl.constexpr): 
	# 内核代码

上述代码中,num_elemsblock_size 被声明为编译时常量,使得 Triton 可以在编译阶段优化内核代码。

确定当前块与 Program ID#

每个线程块(block)在 Triton 中都有一个唯一的 Program ID,用于标识当前线程所在的块。通过 tl.program_id,我们可以确定当前线程所在的块,从而计算处理的数据偏移量。

pid = tl.program_id(axis=0)
block_start = pid * block_size

处理最后一个 Block#

由于向量长度可能不被块大小整除,最后一个块可能只有部分线程需要工作。通过掩码操作,可以确保只有有效的线程进行计算,避免无效的内存访问和计算。

掩码的作用#

Triton 提供了掩码(mask)操作,用于屏蔽那些不需要工作的线程(最后一个 Grid 中 NA 的线程)。

mask = thread_offsets < num_elems
a_pointers = tl.load(a_ptr + thread_offsets, mask=mask, other=0.0)
b_pointers = tl.load(b_ptr + thread_offsets, mask=mask, other=0.0)

ceil_div 函数的作用#

ceil_div 函数用于计算块的数量,确保即使向量长度不被块大小整除,也能覆盖所有元素。例如 vec_size=10,block_size=3,ceil_div(10, 3)=4,这样就能确保所有 10 个元素都被处理。

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

说白了,该函数的作用就是高效实现 “向上取整”。

数值精度验证#

在实现向量加法内核后,验证数值精度是确保内核正确性的关键步骤。通过与 PyTorch 的内置加法操作进行对比,可以确认 Triton 实现的准确性。

def verify_numerics() -> bool:
    torch.manual_seed(2020) # seed both cpu and gpu
    vec_size = 8192
    a = torch.rand(vec_size, device='cuda')
    b = torch.rand_like(a)
    torch_res = a + b
    triton_res = vector_addition(a, b)
    fidelity_correct = torch.allclose(torch_res, triton_res)
    print(f"{fidelity_correct=}")
    return fidelity_correct

Screenshot 2024-09-19 at 22.49.16

验证了解到我们的 Triton 实现与 PyTorch 原生的数值精度一致,可以进行后面的操作了。

下面是完整的 Kernel 实现:

@triton.jit
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
                           num_elems: tl.constexpr,
                           block_size: tl.constexpr,):

    pid = tl.program_id(axis=0)
    # tl.device_print("pid", pid)
    block_start = pid * block_size # 0 * 2 = 0, 1 * 2 = 2,
    thread_offsets = block_start + tl.arange(0, block_size)
    mask = thread_offsets < num_elems
    a_pointers = tl.load(a_ptr + thread_offsets, mask=mask)
    b_pointers = tl.load(b_ptr + thread_offsets, mask=mask)
    res = a_pointers + b_pointers
    tl.store(out_ptr + thread_offsets, res, mask=mask)


def ceil_div(x: int,y: int) -> int:
    return (x + y - 1) // y

def vector_addition(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    output_buffer = torch.empty_like(a)
    assert a.is_cuda and b.is_cuda
    num_elems = a.numel()
    assert num_elems == b.numel() # todo - handel mismatched sizes

    block_size = 1024
    grid_size = ceil_div(num_elems, block_size)
    grid = (grid_size,)
    num_warps = 8

    k2 = kernel_vector_addition[grid](a, b, output_buffer,
                                      num_elems,
                                      block_size,
                                      num_warps=num_warps
                                      )
    return output_buffer

基准测试与性能调优#

为了评估 Triton 向量加法内核的性能,下面进行基准测试并探讨性能调优的方法。

Benchmark API 介绍#

Triton 提供了丰富的基准测试 API,允许用户测量内核的执行时间和吞吐量。以下代码是使用 triton.testing.perf_report 得到性能报告的一个示例:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # 用作图表x轴的参数名
        x_vals=[2**i for i in range(10, 28)],  # `x_name`的可能取值
        x_log=True,  # x轴使用对数刻度
        line_arg='provider',  # 图表中不同线条对应的参数名
        line_vals=['triton', 'torch'],  # `line_arg`的可能取值
        line_names=["Triton", "Torch"],  # 线条的标签名
        styles=[('blue', '-'), ('green', '-')],  # 线条颜色和样式
        ylabel='GB/s',  # y轴标签
        plot_name='vector-add-performance',  # 图表名称,也用作保存文件的文件名
        args={},  # 不在`x_names`和`y_name`中的函数参数的值
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]  # 设定分位数
    
    # 根据 provider 选择不同的计算实现
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: vector_addition(x, y), quantiles=quantiles)
    
    # 计算GB/s
    def gbps(ms):
        return 12 * size / ms * 1e-06
    
    # 返回中位数、最大值和最小值对应的GB/s
    return gbps(ms), gbps(max_ms), gbps(min_ms)

性能报告:

Screenshot 2024-09-19 at 22.51.39

性能对比:

Screenshot 2024-09-19 at 21.47.42

总的来说,vector add 只是一个相对简单的 kernel,相较于复杂的内核更难获得 Triton 实现带来的优势(大部分常用操作 PyTorch 已经通过 CUDA/cuBLAS 等优化到极质了)

调优参数:Num Warps & Block size#

调优内核性能的关键在于合理配置 Warp 数量和块大小。Warp 是 GPU 中的基本执行单元,合理的 Warp 数量和块大小能够充分利用 GPU 的并行计算能力,提升内核的执行效率。

block_size = 1024 # 决定每个线程块处理的元素数量,较大的块大小可以减少块的数量,但可能增加每个块的计算负担。
grid_size = ceil_div(num_elems, block_size)
grid = (grid_size,)
num_warps = 8 # 每个块中包含的 Warp 数量,合理配置 Warp 数量可以优化线程的调度和资源利用。

上节 (Softmax in OpenAI Triton)便给出了通过驱动程序来动态调整参数的方式:

# 计算 block_size,为大于等于 cols 的最小 2 的幂
    block_size = triton.next_power_of_2(cols)

    # 根据 block_size 动态调整 num_warps
    num_warps = 4  # 每个 warp 有 32 个线程
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

参考资料#

感谢:

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。