banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

OpenAI Triton中的Softmax

本文是對 @sotadeeplearningtutorials9598 的 Youtube 教程學習的總結,感謝老師深入淺出的指導讓我這個從未接觸過 GPU 編程的小白能夠編寫出第一個有實際效果的 Kernel。

Softmax 是一種常用的激活函數,通常用於多分類任務的神經網絡輸出層。它將輸入的實數向量轉換為概率分佈,滿足所有輸出值在 0 到 1 之間,且總和為 1。Karpathy 的形容是它將 logits squash 為了 0-1 之間的概率分佈。

公式:
Softmax(zi)=ezij=1Kezj\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}

為什麼要在 GPU 上實現 Softmax?#

GPU 擅長處理並行計算任務。深度學習模型通常需要處理大量的數據和計算,使用 GPU 可以顯著提高計算速度。

為什麼選 Triton?#

Triton 是 OpenAI 開發的一個編譯器和編程語言,旨在讓開發者更容易地編寫高性能的 kernel。它提供了類似於 Python 的高級語法,相對於 CUDA 等降低了 GPU 編程的複雜性,PyTorch 對 Triton 的支持也為願意為生態系統做貢獻的開發者提供了更多機會。

下圖很好地體現出了 Triton 在內核開發上能兼顧性能和效率:

Pasted image 20240913134958

圖源自知乎楊軍老師的回答:# 談談對 OpenAI Triton 的一些理解

接下來的操作是我在 WSL2,Ubuntu 20.04,Python3.10 上進行的:

import torch
import triton
import triton.language as tl
  • triton:Triton 主庫。
  • triton.language as tl:Triton 的編程語言模塊,包含了編寫 Triton 內核所需的函數和操作。

GPU 基本知識#

在 GPU 編程中,Kernel 是一個特殊的函數,它定義了需要並行執行的計算任務。為了高效地利用 GPU 的並行處理能力,這個 Kernel 會被分解成多個執行單元,稱為 Block。這種結構允許 GPU 高度並行地處理大量數據,通過將一個大型計算任務分解成眾多小型並行任務,GPU 能夠實現顯著的性能提升。

Pasted image 20240913102053

  • Kernel :是程序員編寫的核心算法,它描述了每個並行執行單元應該執行的操作。。這段代碼被設計為在大量 Thread 上執行相同的操作。
  • Block:GPU 會將這個 Kernel 任務劃分成多個 Block,每個 Block 內部包含許多 Thread(線程)。這些 Thread 同時運行,各自處理一部分數據,但都執行相同的Kernel代碼。

簡而言之,Kernel 定義了 "做什麼",而 BlockThread 決定了 "如何並行地做"。這種方法充分利用了 GPU 的硬件特性,實現了高效的並行計算。

Softmax 實現#

Eager Mode#

先用純 Python 實現 Softmax,來參考和驗證其他實現的正確性:

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1, keepdim=True)[0]
    safe_x = x - x_max
    numerator = torch.exp(safe_x) 
    denominator = numerator.sum(dim=1, keepdim=True)
    softmax_out = numerator / denominator
    return softmax_out
  • 每一行代碼都會立即執行,計算結果會立即產生,類似於 Python 的普通代碼執行方式。
  • 計算圖是動態構建的,每次運行代碼時,會即時創建和執行計算圖。
  • 與之相對的是 Graph Mode,區別在於是否立即執行(而非預構建靜態計算圖)

數值穩定性#

這裡要提一嘴是 safe_x = x - x_max 這一步,這樣減去最大值可以將所有值都變為非正數,防止 $e^x$ 計算溢出,提高數值穩定性,核心在於下面等式成立:

softmax(ximax(x))=eximax(x)jexjmax(x)=exi/emax(x)j(exj/emax(x))=exi/emax(x)(jexj)/emax(x)=exijexj=softmax(xi)\begin{align*} \text{softmax}(x_i - \max(x)) &= \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \\[2ex] &= \frac{e^{x_i} / e^{\max(x)}}{\sum_j (e^{x_j} / e^{\max(x)})} \\[2ex] &= \frac{e^{x_i} / e^{\max(x)}}{(\sum_j e^{x_j}) / e^{\max(x)}} \\[2ex] &= \frac{e^{x_i}}{\sum_j e^{x_j}} \\[2ex] &= \text{softmax}(x_i) \end{align*}

Triton Implement#

內核的開發實際上分為兩部分,內核本身和使並行化(parallelizing)的驅動,讓它同時處理大量實例。

  • 驅動程序(Driver Program):這是在 CPU 上運行的 Python 代碼,用於準備數據、配置內核參數,並調用 Triton 內核。

  • 算子(Operator)本身:這是用 Triton 編寫的 GPU 內核,實際執行 Softmax 計算。

驅動程序#

這裡用自頂向下的學習思路,首先你需要一個 driver program(驅動程序),它會設置大量 meta information,如 block size,共享內存分配等。

def softmax(x: torch.Tensor) -> torch.Tensor:
    """ Triton 實現的 Softmax,只有前向傳播 """
    rows, cols = x.shape
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"

    # 計算 block_size,為大於等於 cols 的最小 2 的幂
    block_size = triton.next_power_of_2(cols)

    # 根據 block_size 動態調整 num_warps
    num_warps = 4  # 每個 warp 有 32 個線程
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # 定義網格大小,每個線程塊(Block)處理一行數據
    grid = (rows,) # 這個寫法會創建一個只包含 rows 的 tuple

    # 創建一個與輸入張量形狀相同的空張量,用於存儲輸出
    sm_out = torch.empty_like(x)

    # 調用 Triton 內核(使用方括號傳入 grid,然後將參數傳遞給內核)
    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

    return sm_out

後面你會發現,這裡驅動程序中傳入 kernel 的參數 kernel 函數本身聲明的參數多,原因先按下不表

這裡一個比較巧妙的點在於:GPU 通常在處理 2 的幂次大小的數據塊時性能最佳。

使用 next_power_of_2 可以將數據大小向上取整到最近的 2 的幂,這有助於優化內存訪問模式和對齊,後面還可以根據 block_size 的大小動態調整 num_warps—— 較小的問題使用較少的 warps 避免資源浪費,較大的則充分利用 GPU 並行能力。

算子(Triton Kernel)#

內核在 GPU 上執行實際的計算。

裝飾器#

在 Triton 中開發一個內核,需要使用 @triton.jit 裝飾器使其進入 Triton 編譯器,

@triton.jit
def _softmax_fwd_kernel():
    pass

內核參數#

  • output_ptr:輸出張量在內存中的起始地址。
  • stride_output_row:輸出張量在行方向上的步幅(即每行在內存中的間隔)。
  • input_ptr:輸入張量在內存中的起始地址。
  • stride_input_row:輸入張量在行方向上的步幅。
  • num_cols:輸入張量的列數。
  • block_size: tl.constexpr:塊大小,編譯時常量,決定每個線程塊處理的元素數量。

獲取當前線程塊處理的行索引#

獲取當前線程塊在第 0 維度(行維度)的 ID,即處理的行索引。

row_index = tl.program_id(0)

計算當前行的數據指針#

row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
  • row_start_ptr:當前行在內存中的起始地址。
  • col_offsets:生成一個從 0 到 block_size - 1 的序列,表示列的偏移量。
  • input_ptrs:當前行中每個元素在內存中的地址。

創建掩碼(Mask)#

掩碼用於在並行計算中避免越界訪問。當處理的元素數量不是線程塊大小的整數倍時,使用掩碼屏蔽無效的線程。

這裡當列數小於 block_size 時,需要掩碼來避免訪問越界的內存地址。

mask = col_offsets < num_cols

從全局內存加載數據到片上存儲器(SRAM)#

row = tl.load(input_ptrs, mask=mask, other=float("-inf"))
  • tl.load:從內存加載數據的 API。
  • mask:指示哪些地址是有效的。
  • `other=float("-inf"):對於無效地址,填充為負無窮,以確保在後續計算最大值時不影響結果。

Softmax 計算#

利用 Triton 庫提供的高效並行計算 API 實現逐元素地將分子除以分母,得到 Softmax 的輸出

row_max = tl.max(row, axis=0)
safe_row = row - row_max
numerator = tl.exp(safe_row) 
denominator = tl.sum(numerator, axis=0)
sm_output = numerator / denominator

將結果寫回全局內存#

output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_output, mask=mask)
  • output_row_ptr:輸出張量當前行的起始地址。
  • output_ptrs:輸出張量當前行中每個元素的地址。
  • tl.store:將結果寫回內存,使用與加載相同的掩碼,確保只寫回有效的數據。

總體來看,我們的 kernel 長這樣:

@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # 獲取當前程序的 ID(行索引)
    row_index = tl.program_id(0)

    # 計算當前行的起始指針
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets

    # 創建掩碼,防止越界訪問
    row_mask = col_offsets < num_cols

    # 從全局內存加載數據到片上 SRAM
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))

    # Softmax 計算
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # 將結果寫回全局內存
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)

驅動程序與算子的交互#

Grid & Block#

在我們的驅動程序代碼中:

  • grid = (rows,):定義了網格大小,即一維的 rowsBlock,每個 Block 處理輸入張量的一行數據。

參數傳遞#

當我們調用內核時,實際上傳遞了以下參數使內核能夠正確地定位和處理輸入和輸出數據:

_softmax_fwd_kernel[grid](
    sm_out,                # 輸出張量的指針
    sm_out.stride(0),      # 輸出張量在行方向上的步幅
    x,                     # 輸入張量的指針
    x.stride(0),           # 輸入張量在行方向上的步幅
    cols,                  # 輸入張量的列數
    # 內核的配置參數
    block_size=block_size,
    num_warps=num_warps
)

內核的執行#

每個線程塊處理一行數據,通過 row_index = tl.program_id(0),每個線程塊知道自己應該處理哪一行。

GPU 上的多個線程塊同時執行,使得多行數據可以並行處理,大大加快了計算速度。

特殊的 API 回顧#

  • tl.arange(start, end):生成一個從 startend - 1 的序列,用於創建列偏移量。
  • tl.program_id(axis):獲取當前線程塊在指定維度的 ID。
  • tl.constexpr:表示一個在編譯時已知的常量,用於優化。

Benchmark#

完整代碼見:triton_kernels_for_fun_and_profit/demos/demo_softmax.py

Screenshot 2024-09-12 at 17.02.38

3090 Ti 上的表現性能(GB/s)

在原視頻中 Triton 最多能比後者快近三倍,現在 24 年 9 月份測試中 Triton 還是比 Torch Native 快一點,而且十分穩定。

元參數#

還記得我們前面提到驅動程序中傳入 kernel 的參數比 kernel 函數聲明的數量多嗎?

# Driver
_softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

# Kernel
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):

可以看到驅動有 7 個參數,而後者只有 6 個

原因在於其中一些參數是 Triton 保留的關鍵字,也被稱為 Meta-parameters(元參數)。

Screenshot 2024-09-14 at 21.48.38

triton/python/triton/runtime/interpreter.py,可以看到實際上有 6 個保留的關鍵字

網下翻閱,可以發現這些關鍵字在後面的 GridExecutor 調用時被從參數中過濾掉了:

class GridExecutor:
	"""省略初始化等部分內容"""
	def __call__(self, *args_dev, **kwargs):
        # removes reserved keywords from kwargs
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}

Triton 編譯器會吸收這些參數,這就是參數量對不上的原因。

Triton Reserved Keywords#

  • Num Warps:內核本身在 GPU 上使用的線程束數量(默認每個 warp 有 32 個線程);
  • Num Stages:決定編譯器為軟件流水線循環(software-pipelining loops)分配的階段數。主要用途是用於 SM80+(Ampere)架構的 GPU 上執行矩陣乘法等運算。流水線技術允許多個循環迭代同時執行,每個迭代可以部分重疊執行以提高計算性能。(CSAPP 里學過但死去的記憶再次出現);
  • Num CTAS:每個 SM(流多處理器)上並發執行的線程塊(CTA)數量
  • Warps Sepecialization(bool)(已被取代):也稱為空間分區(Spatial Partitioning),是一種允許 Warp 執行獨立計算的技術。當啟用時,多個 Warp 可以並行執行不同的任務,而不必同步執行相同的指令,如使用在生產者 / 消費者模式。在如今 Triton 中已經被下面三個關鍵字取代;
  • enable_fp_fusion:啟用浮點運算融合,將多個浮點操作融合在同一流水線中執行,進一步提升性能,減少多次執行的開銷;
  • grid:控制 Triton 內核的 Grid 結構;
  • maxnreg:用於控制每個線程塊(Block)所能使用的最大寄存器數量

參考資料#

感謝:

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。