banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

OpenAI TritonにおけるSoftmax

本文は @sotadeeplearningtutorials9598 の YouTube チュートリアル学習のまとめであり、先生のわかりやすい指導に感謝します。GPU プログラミングに触れたことがない私でも、実際に効果のある Kernel を初めて書くことができました。

Softmax は一般的な活性化関数であり、通常は多クラス分類タスクの神経ネットワークの出力層で使用されます。これは、入力の実数ベクトルを確率分布に変換し、すべての出力値が 0 から 1 の間にあり、合計が 1 になるようにします。Karpathy の表現によれば、これは logits を 0-1 の間の確率分布に圧縮します。

公式:
Softmax(zi)=ezij=1Kezj\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}

なぜ GPU で Softmax を実装するのか?#

GPU は並列計算タスクの処理に優れています。深層学習モデルは通常、大量のデータと計算を処理する必要があり、GPU を使用することで計算速度を大幅に向上させることができます。

なぜ Triton を選ぶのか?#

Triton は OpenAI が開発したコンパイラおよびプログラミング言語であり、開発者が高性能な kernel をより簡単に記述できるように設計されています。Python に似た高級構文を提供し、CUDA などに比べて GPU プログラミングの複雑さを軽減しています。PyTorch の Triton へのサポートも、エコシステムに貢献したい開発者にさらなる機会を提供しています。

下の図は、Triton が内核開発において性能と効率を両立できることをよく示しています:

Pasted image 20240913134958

図は 知乎の杨军先生の回答: # OpenAI Triton に関する理解 から引用

次の操作は、WSL2、Ubuntu 20.04、Python3.10 で行ったものです:

import torch
import triton
import triton.language as tl
  • triton:Triton のメインライブラリ。
  • triton.language as tl:Triton のプログラミング言語モジュールで、Triton 内核を記述するために必要な関数と操作が含まれています。

GPU の基本知識#

GPU プログラミングにおいて、Kernel は特別な関数であり、並列に実行される計算タスクを定義します。GPU の並列処理能力を効率的に利用するために、この Kernel は複数の実行ユニットに分解され、Block と呼ばれます。この構造により、GPU は大量のデータを高度に並列処理でき、大規模な計算タスクを多数の小さな並列タスクに分解することで、顕著な性能向上を実現します。

Pasted image 20240913102053

  • Kernel:プログラマーが記述したコアアルゴリズムで、各並列実行ユニットが実行すべき操作を記述します。このコードは多数の Thread 上で同じ操作を実行するように設計されています。
  • Block:GPU はこの Kernel タスクを複数の Block に分割し、各 Block 内には多くの Thread(スレッド)が含まれます。これらの Thread は同時に実行され、それぞれがデータの一部を処理しますが、同じ Kernel コードを実行します。

簡単に言えば、Kernel は「何をするか」を定義し、BlockThread は「どのように並列に行うか」を決定します。この方法は GPU のハードウェア特性を最大限に活用し、高効率な並列計算を実現します。

Softmax の実装#

Eager Mode#

まず、純粋な Python で Softmax を実装し、他の実装の正確性を参照および検証します:

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1, keepdim=True)[0]
    safe_x = x - x_max
    numerator = torch.exp(safe_x) 
    denominator = numerator.sum(dim=1, keepdim=True)
    softmax_out = numerator / denominator
    return softmax_out
  • 各行のコードは即座に実行され、計算結果が即座に生成されます。これは Python の通常のコード実行方法に似ています。
  • 計算グラフは動的に構築され、コードを実行するたびに即座に計算グラフが作成され、実行されます。
  • これに対して Graph Mode があり、即座に実行されるかどうか(事前に構築された静的計算グラフではない)で違いがあります。

数値安定性#

ここで言及すべきは safe_x = x - x_max というステップで、これにより最大値を引くことで全ての値が非正数になり、$e^x$ の計算オーバーフローを防ぎ、数値の安定性を向上させます。核心は以下の等式が成り立つことです:

softmax(ximax(x))=eximax(x)jexjmax(x)=exi/emax(x)j(exj/emax(x))=exi/emax(x)(jexj)/emax(x)=exijexj=softmax(xi)\begin{align*} \text{softmax}(x_i - \max(x)) &= \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \\[2ex] &= \frac{e^{x_i} / e^{\max(x)}}{\sum_j (e^{x_j} / e^{\max(x)})} \\[2ex] &= \frac{e^{x_i} / e^{\max(x)}}{(\sum_j e^{x_j}) / e^{\max(x)}} \\[2ex] &= \frac{e^{x_i}}{\sum_j e^{x_j}} \\[2ex] &= \text{softmax}(x_i) \end{align*}

Triton の実装#

内核の開発は実際には二つの部分に分かれます。内核自体と、並列化(parallelizing)を行うドライバです。これにより大量のインスタンスを同時に処理します。

  • ドライバプログラム(Driver Program):これは CPU 上で実行される Python コードで、データの準備、内核パラメータの設定、および Triton 内核の呼び出しを行います。

  • 演算子(Operator)自体:これは Triton で記述された GPU 内核で、実際に Softmax 計算を実行します。

ドライバプログラム#

ここではトップダウンの学習アプローチを用います。まず、driver program(ドライバプログラム)が必要であり、これが多くのメタ情報を設定します。例えば、ブロックサイズ、共有メモリの割り当てなどです。

def softmax(x: torch.Tensor) -> torch.Tensor:
    """ Triton 実装の Softmax、前方伝播のみ """
    rows, cols = x.shape
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"

    # block_size を計算し、cols 以上の最小の 2 の累乗を求める
    block_size = triton.next_power_of_2(cols)

    # block_size に応じて num_warps を動的に調整
    num_warps = 4  # 各 warp は 32 スレッドを持つ
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # グリッドサイズを定義し、各スレッドブロック(Block)が一行のデータを処理
    grid = (rows,) # この書き方は rows のみを含むタプルを作成します

    # 入力テンソルと同じ形状の空のテンソルを作成し、出力を格納
    sm_out = torch.empty_like(x)

    # Triton 内核を呼び出す(グリッドを角括弧で渡し、パラメータを内核に渡す)
    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

    return sm_out

後でわかるように、ここでドライバプログラムに渡される内核のパラメータは、内核関数自体が宣言したパラメータよりも多いです。その理由は後で説明します。

ここでの巧妙な点は、GPU が通常 2 の累乗サイズのデータブロックを処理する際に最も性能が良いことです。

next_power_of_2 を使用することで、データサイズを最近の 2 の累乗に切り上げることができ、メモリアクセスパターンとアライメントを最適化するのに役立ちます。後で block_size のサイズに応じて num_warps を動的に調整できます。小さな問題には少ないワープを使用してリソースの無駄を避け、大きな問題には GPU の並列能力を十分に活用します。

演算子(Triton Kernel)#

内核は GPU 上で実際の計算を実行します。

デコレーター#

Triton で内核を開発するには、@triton.jit デコレーターを使用して Triton コンパイラに入れる必要があります。

@triton.jit
def _softmax_fwd_kernel():
    pass

内核パラメータ#

  • output_ptr:出力テンソルのメモリ内の開始アドレス。
  • stride_output_row:出力テンソルの行方向のストライド(つまり、各行のメモリ内の間隔)。
  • input_ptr:入力テンソルのメモリ内の開始アドレス。
  • stride_input_row:入力テンソルの行方向のストライド。
  • num_cols:入力テンソルの列数。
  • block_size: tl.constexpr:ブロックサイズ、コンパイル時定数で、各スレッドブロックが処理する要素の数を決定します。

現在のスレッドブロックが処理する行インデックスを取得#

現在のスレッドブロックの第 0 次元(行次元)の ID を取得し、処理する行インデックスを得ます。

row_index = tl.program_id(0)

現在の行のデータポインタを計算#

row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
  • row_start_ptr:現在の行のメモリ内の開始アドレス。
  • col_offsets:0 から block_size - 1 までのシーケンスを生成し、列のオフセットを表します。
  • input_ptrs:現在の行の各要素のメモリ内のアドレス。

マスクの作成#

マスクは並列計算において越境アクセスを避けるために使用されます。処理する要素の数がスレッドブロックサイズの整数倍でない場合、マスクを使用して無効なスレッドを隠します。

ここで、列数が block_size より小さい場合、越境アクセスを避けるためにマスクが必要です。

mask = col_offsets < num_cols

グローバルメモリからデータをオンチップストレージ(SRAM)にロード#

row = tl.load(input_ptrs, mask=mask, other=float("-inf"))
  • tl.load:メモリからデータをロードするための API。
  • mask:どのアドレスが有効であるかを示します。
  • other=float("-inf"):無効なアドレスには負の無限大で埋め込み、後続の計算で最大値に影響を与えないようにします。

Softmax 計算#

Triton ライブラリが提供する効率的な並列計算 API を利用して、逐次的に分子を分母で割り、Softmax の出力を得ます。

row_max = tl.max(row, axis=0)
safe_row = row - row_max
numerator = tl.exp(safe_row) 
denominator = tl.sum(numerator, axis=0)
sm_output = numerator / denominator

結果をグローバルメモリに書き戻す#

output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_output, mask=mask)
  • output_row_ptr:出力テンソルの現在の行の開始アドレス。
  • output_ptrs:出力テンソルの現在の行の各要素のアドレス。
  • tl.store:結果をメモリに書き戻し、ロードと同じマスクを使用して、有効なデータのみを書き戻します。

全体的に見て、私たちの内核は次のようになります:

@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # 現在のプログラムの ID(行インデックス)を取得
    row_index = tl.program_id(0)

    # 現在の行の開始ポインタを計算
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets

    # マスクを作成し、越境アクセスを防ぐ
    row_mask = col_offsets < num_cols

    # グローバルメモリからデータをオンチップ SRAM にロード
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))

    # Softmax 計算
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # 結果をグローバルメモリに書き戻す
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)

ドライバプログラムと演算子の相互作用#

グリッド & ブロック#

私たちのドライバプログラムコードでは:

  • grid = (rows,):グリッドサイズを定義し、一次元の rows 個の Block を作成し、各 Block が入力テンソルの一行のデータを処理します。

パラメータの渡し方#

内核を呼び出すとき、実際には以下のパラメータを渡して内核が入力と出力データを正しく位置付けて処理できるようにします:

_softmax_fwd_kernel[grid](
    sm_out,                # 出力テンソルのポインタ
    sm_out.stride(0),      # 出力テンソルの行方向のストライド
    x,                     # 入力テンソルのポインタ
    x.stride(0),           # 入力テンソルの行方向のストライド
    cols,                  # 入力テンソルの列数
    # 内核の設定パラメータ
    block_size=block_size,
    num_warps=num_warps
)

内核の実行#

各スレッドブロックが一行のデータを処理し、row_index = tl.program_id(0) により、各スレッドブロックは自分が処理すべき行を知っています。

GPU 上の複数のスレッドブロックが同時に実行されるため、多行データを並列処理でき、計算速度が大幅に向上します。

特殊な API の振り返り#

  • tl.arange(start, end)start から end - 1 までのシーケンスを生成し、列オフセットを作成します。
  • tl.program_id(axis):指定された次元で現在のスレッドブロックの ID を取得します。
  • tl.constexpr:コンパイル時に既知の定数を表し、最適化に使用されます。

ベンチマーク#

完全なコードは次を参照:triton_kernels_for_fun_and_profit/demos/demo_softmax.py

Screenshot 2024-09-12 at 17.02.38

3090 Ti 上の性能(GB/s)

元の動画では、Triton は後者よりも最大で 3 倍速いことができましたが、現在の 24 年 9 月のテストでは、Triton は Torch Native よりも少し速く、非常に安定しています。

メタパラメータ#

前述の通り、ドライバプログラムに渡される内核のパラメータが内核関数の宣言された数よりも多いことを覚えていますか?

# ドライバ
_softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

# 内核
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):

ドライバには 7 つのパラメータがあり、後者には 6 つしかありません。

その理由は、一部のパラメータが Triton の予約語であり、Meta-parameters(メタパラメータ)と呼ばれるからです。

Screenshot 2024-09-14 at 21.48.38

triton/python/triton/runtime/interpreter.py を見ると、実際には 6 つの予約語があることがわかります。

調べてみると、これらの予約語は後の GridExecutor 呼び出し時にパラメータからフィルタリングされます:

class GridExecutor:
	"""省略した初期化などの部分"""
	def __call__(self, *args_dev, **kwargs):
        # kwargs から予約語を削除
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}

Triton コンパイラはこれらのパラメータを吸収します。これがパラメータ数が合わない理由です。

Triton 予約語#

  • Num Warps:内核自体が GPU 上で使用するスレッド束の数(デフォルトでは各 warp に 32 スレッド);
  • Num Stages:コンパイラがソフトウェアパイプラインループ(software-pipelining loops)に割り当てる段階数を決定します。主に SM80+(Ampere)アーキテクチャの GPU で行列乗算などの演算を実行するために使用されます。パイプライン技術により、複数のループ反復を同時に実行でき、各反復は部分的に重複して実行され、計算性能が向上します(CSAPP で学んだが、死んだ記憶が再び浮かび上がる);
  • Num CTAS:各 SM(ストリーム多処理器)で同時に実行されるスレッドブロック(CTA)の数
  • Warps Specialization(bool)(廃止された):** 空間分割(Spatial Partitioning)** とも呼ばれ、Warp が独立した計算を実行できる技術です。有効にすると、複数の Warp が異なるタスクを並行して実行でき、同じ命令を同期して実行する必要がなくなります。現在の Triton では、以下の 3 つのキーワードに置き換えられています;
  • enable_fp_fusion:浮動小数点演算の融合を有効にし、複数の浮動小数点操作を同じパイプラインで実行し、性能をさらに向上させ、複数回の実行のオーバーヘッドを削減します;
  • grid:Triton 内核のグリッド構造を制御します;
  • maxnreg:各スレッドブロック(Block)が使用できる最大レジスタ数を制御します。

参考資料#

感謝:

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。