Triton Example

For how to write Triton kernels, see the official tutorials.

add

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    a_ptr,       # pointer to input vector A
    b_ptr,       # pointer to input vector B
    output_ptr,  # pointer to the output vector
    N,           # vector length
    BLOCK_SIZE: tl.constexpr):

    # element indices handled by this program
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < N

    # load the elements of A and B
    a = tl.load(a_ptr + idx, mask=mask, other=0.0)
    b = tl.load(b_ptr + idx, mask=mask, other=0.0)

    # perform the addition
    result = a + b

    # write the result
    tl.store(output_ptr + idx, result, mask=mask)

# Usage example
def vector_add(a, b):
    N = a.shape[0]
    assert a.shape == b.shape, "A and B must have the same shape."

    output = torch.empty_like(a)

    BLOCK_SIZE = 128
    grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )

    add_kernel[
        grid
    ](
        a, b, output, N, BLOCK_SIZE
    )

    return output

if __name__ == "__main__":
    # randomly generate two input vectors A and B
    N = 128
    a = torch.randn(N, dtype=torch.float32, device='cuda')
    b = torch.randn(N, dtype=torch.float32, device='cuda')

    # call vector_add
    output = vector_add(a, b)
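A quick sanity check against PyTorch's native addition can be appended to the __main__ block above (a minimal sketch; it assumes the tensors from the example are still in scope on a CUDA device):

    # sanity check: compare against PyTorch's own result
    torch.testing.assert_close(output, a + b)
    print("max abs error vs torch:", (output - (a + b)).abs().max().item())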

layernorm

The simplest implementation.

import torch

import triton
import triton.language as tl

@triton.jit
def layernorm_kernel(
    input_ptr,     # pointer to the input tensor
    output_ptr,    # pointer to the output tensor
    weight_ptr,    # pointer to the scale (gamma) parameters
    bias_ptr,      # pointer to the shift (beta) parameters
    M,             # number of rows
    N: tl.constexpr,             # number of columns
    epsilon,       # small constant that avoids division by zero
    BLOCK_SIZE: tl.constexpr):

    # row index handled by this program
    row_idx = tl.program_id(0)

    # start offset of this row
    row_start = row_idx * N

    # load the whole row (N elements)
    row = tl.load(input_ptr + row_start + tl.arange(0, N))

    # Step 1: compute the mean
    mean = tl.sum(row, axis=0) / N

    # Step 2: compute the variance
    var = tl.sum((row - mean) * (row - mean), axis=0) / N

    # Step 3: normalize
    norm_row = (row - mean) / tl.sqrt(var + epsilon)

    # Step 4: apply scale and shift
    weight = tl.load(weight_ptr + tl.arange(0, N))
    bias = tl.load(bias_ptr + tl.arange(0, N))
    result = norm_row * weight + bias

    # Step 5: write the result
    tl.store(output_ptr + row_start + tl.arange(0, N), result)

# Usage example
def layernorm(input, weight, bias, epsilon=1e-5):
    M, N = input.shape
    output = torch.empty_like(input)

    # weight and bias must match the number of columns
    assert weight.shape[0] == bias.shape[0] == N

    # block size (this simple kernel loads the whole row of N elements, so BLOCK_SIZE is not used inside the kernel)
    BLOCK_SIZE = 128

    grid = (M,)  # one program per row

    # launch the Triton kernel
    layernorm_kernel[
        grid
    ](
        input,
        output,
        weight,
        bias,
        M,
        N,
        epsilon,
        BLOCK_SIZE
    )

    return output


if __name__ == "__main__":

    # randomly generate the input tensor
    M, N = 32, 128  # 32 rows of 128 columns
    input = torch.randn(M, N, dtype=torch.float32, device='cuda')
    weight = torch.randn(N, dtype=torch.float32, device='cuda')
    bias = torch.randn(N, dtype=torch.float32, device='cuda')

    # call layernorm
    output = layernorm(input, weight, bias)
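For a quick correctness check, the result can be compared with torch.nn.functional.layer_norm; a minimal sketch appended to the __main__ block above:

    # compare against PyTorch's layer_norm with the same eps
    ref = torch.nn.functional.layer_norm(input, (N,), weight=weight, bias=bias, eps=1e-5)
    print("max abs diff vs torch:", (output - ref).abs().max().item())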

matmul

A simple matmul.

import torch

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr,       # pointer to matrix A
    b_ptr,       # pointer to matrix B
    c_ptr,       # pointer to matrix C
    M,           # number of rows of A
    N,           # number of columns of B
    K,           # number of columns of A / rows of B
    BLOCK_SIZE_M: tl.constexpr,  # block size along M
    BLOCK_SIZE_N: tl.constexpr,  # block size along N
    BLOCK_SIZE_K: tl.constexpr   # block size along K
):
    # program indices of the current block
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # starting offsets of the block
    block_m = pid_m * BLOCK_SIZE_M
    block_n = pid_n * BLOCK_SIZE_N

    # initialize the accumulator
    c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # loop over the K dimension in blocks
    for k in range(0, K, BLOCK_SIZE_K):
        # load blocks of A and B, masking both dimensions of each block
        a = tl.load(
            a_ptr + (block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] * K + (k + tl.arange(0, BLOCK_SIZE_K)),
            mask=((block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ((k + tl.arange(0, BLOCK_SIZE_K))[None, :] < K),
            other=0.0
        )
        b = tl.load(
            b_ptr + (k + tl.arange(0, BLOCK_SIZE_K))[:, None] * N + (block_n + tl.arange(0, BLOCK_SIZE_N)),
            mask=((k + tl.arange(0, BLOCK_SIZE_K))[:, None] < K) & ((block_n + tl.arange(0, BLOCK_SIZE_N))[None, :] < N),
            other=0.0
        )

        # multiply the blocks and accumulate
        c += tl.dot(a, b)

    # write the result block back to C
    tl.store(
        c_ptr + (block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] * N + (block_n + tl.arange(0, BLOCK_SIZE_N)),
        c,
        mask=((block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ((block_n + tl.arange(0, BLOCK_SIZE_N)) < N)
    )

# Usage example
def matmul(a, b):
    M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, "The number of columns of A must equal the number of rows of B."

    c = torch.empty((M, N), dtype=torch.float32, device='cuda')

    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 32
    BLOCK_SIZE_K = 32

    grid = ((M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M, (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N, 1)

    matmul_kernel[
        grid
    ](
        a, b, c, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K
    )

    return c

if __name__ == "__main__":
    # randomly generate two matrices A and B
    M, K, N = 128, 128, 128
    a = torch.randn((M, K), dtype=torch.float32, device='cuda')
    b = torch.randn((K, N), dtype=torch.float32, device='cuda')

    # call matmul
    c = matmul(a, b)
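A correctness check against torch.matmul can be appended to the __main__ block above (a sketch; tl.dot may use TF32 on Ampere-class GPUs, so small numerical differences are expected):

    # compare against PyTorch's matmul
    ref = torch.matmul(a, b)
    print("max abs diff vs torch.matmul:", (c - ref).abs().max().item())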

tutorial_mm source walkthrough

From 03-mm.

This uses the matmul optimization in tutorials/03-matrix-multiplication.py as the example.

The grouped ordering below achieves better data reuse.

(Figure: row-major vs. grouped ordering of the output blocks)

Analysis: A and B are both stored row-major. Taking the computation of nine output blocks as an example, the naive row-major ordering needs $9 + 9 \times 9 = 90$ block reads and 9 writes, while the grouped ordering needs $9 \times 3 + 3 \times 9 = 54$ block reads and 9 writes.

  • num_pid_m and num_pid_n are the number of blocks the matrix is split into along the M and N dimensions (the small yellow blocks in the figure above)
pid = tl.program_id(axis=0)
# number of program ids along the M / N axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
  • num_pid_in_group is the number of yellow blocks in one group, which is GROUP_SIZE_M blocks tall and num_pid_n blocks wide
# number of program in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
  • group_id is the index of the group that the current program falls in
# id of the group which related to this program
group_id = pid // num_pid_in_group
  • first_pid_m is the global index, along the M dimension, of the first yellow block of the current group
# row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M  # = (pid // (GROUP_SIZE_M * num_pid_n)) * GROUP_SIZE_M
  • Recompute group_size_m so that the last, possibly partial, group does not run out of bounds
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
  • Determine which block [pid_m, pid_n] the current program handles (a small host-side simulation of this mapping follows the snippet below)

pid_m < first_pid_m + group_size_m

pid_n advances column by column from left to right, i.e. 0 0 0 1 1 1 2 2 2

# row-id of the program in the launch grid
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)  # row id
# col-id of the program in the launch grid
pid_n = (pid % num_pid_in_group) // group_size_m  # col id
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
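To build intuition for this mapping, it can be simulated on the host with plain Python; a small sketch with arbitrarily chosen grid sizes:

num_pid_m, num_pid_n, GROUP_SIZE_M = 6, 6, 3
num_pid_in_group = GROUP_SIZE_M * num_pid_n
for pid in range(num_pid_m * num_pid_n):
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    print(pid, "->", (pid_m, pid_n))
# The first nine programs map to (0,0) (1,0) (2,0) (0,1) (1,1) (2,1) (0,2) (1,2) (2,2):
# pid_m cycles 0,1,2 while pid_n stays constant within each column, the 000111222 pattern.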

a_ptr is the address of the first element of matrix A.

offs_am and offs_bn are the coordinates, within the full matrices, of the elements of the first block handled by this program: offs_am indexes the m dimension of A and offs_bn the n dimension of B, while offs_k indexes the k dimension.

    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
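For a contiguous row-major A (stride_am = K, stride_ak = 1) this pointer block is just 2-D fancy indexing into A; a small host-side sketch of the equivalence, with arbitrarily chosen sizes:

import torch

M, K, BLOCK_SIZE_M, BLOCK_SIZE_K, pid_m = 8, 8, 4, 4, 1
A = torch.arange(M * K).reshape(M, K)
offs_am = (pid_m * BLOCK_SIZE_M + torch.arange(BLOCK_SIZE_M)) % M
offs_k = torch.arange(BLOCK_SIZE_K)
# flat offsets computed exactly like a_ptrs in the kernel (minus the base pointer)
flat = offs_am[:, None] * K + offs_k[None, :]
assert torch.equal(A.reshape(-1)[flat], A[offs_am[:, None], offs_k[None, :]])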
The epilogue computes the output pointers and writes the accumulated block back with a bounds mask:

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)

The main computation loop; the masks keep the loads from reading past the end of the K dimension.

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # advance the pointers to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

MN-major

Inputs A and B, output C; computes $C = A \times B$.

  • A shape (M, K)
  • B shape (K, N)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c
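In the tutorial this kernel carries a @triton.autotune decorator (omitted in the listing above) that supplies BLOCK_SIZE_M/N/K and GROUP_SIZE_M; assuming that decorator is applied, usage is a plain call (a sketch):

a = torch.randn((512, 256), device='cuda', dtype=torch.float32)
b = torch.randn((256, 384), device='cuda', dtype=torch.float32)
c = matmul(a, b)
print("max abs diff vs torch:", (c - a @ b).abs().max().item())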

KK-major

Inputs A and B, output C; computes $C = A^{T} \times B^{T}$.

  • A shape (K, M)
  • B shape (N, K)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        stride_ak, stride_am,
        stride_bn, stride_bk,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # from here on it differs from the MN-major version
    # a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) ->
    a_ptrs = a_ptr + (offs_k[:, None] * stride_ak + offs_am[None, :] * stride_am)
    # b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) ->
    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_k[None, :] * stride_bk)

    # -----------------------------------------------------------
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # the masks differ as well
        a = tl.load(a_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a.T, b.T, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b, activation=""):
    # Check constraints.
    assert a.is_contiguous(), "Matrix A must be contiguous"
    K, M = a.shape
    N, K = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c
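A usage sketch for the KK-major version (again assuming the block-size meta-parameters are supplied, e.g. by an @triton.autotune decorator as in the tutorial): A is stored as (K, M), B as (N, K), and the result should match the $A^{T} \times B^{T}$ defined above:

K, M, N = 64, 128, 96
a = torch.randn((K, M), device='cuda', dtype=torch.float16)
b = torch.randn((N, K), device='cuda', dtype=torch.float16)
c = matmul(a, b)
ref = a.T @ b.T  # (M, K) @ (K, N) -> (M, N)
print("max abs diff vs reference:", (c - ref).abs().max().item())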

batch matmul

The following is an MN-major BMM; a KK-major version is easy to adapt from it.

@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def bmm_kernel(
    # Pointers to matrices
    A_ptr, B_ptr, C_ptr,
    # Matrix dimensions
    B, M, N, K,
    stride_ab, stride_am, stride_ak,
    stride_bb, stride_bk, stride_bn,
    stride_cb, stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offs_b = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M  # wrap to stay in bounds
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    A_ptr = A_ptr + (offs_b * stride_ab + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    B_ptr = B_ptr + (offs_b * stride_bb + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_mask = (offs_b < B) & (offs_k[None, :] < K - k * BLOCK_SIZE_K)
        b_mask = (offs_b < B) & (offs_k[:, None] < K - k * BLOCK_SIZE_K)
        a = tl.load(A_ptr, mask=a_mask, other=0.0)
        b = tl.load(B_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a, b, out_dtype=tl.float32)
        A_ptr += BLOCK_SIZE_K * stride_ak
        B_ptr += BLOCK_SIZE_K * stride_bk

    # Write back.
    if ACTIVATION == "leaky_relu":
        acc = leaky_relu(acc)
    c = acc.to(tl.float16)
    C_ptr = C_ptr + (offs_b * stride_cb + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    c_mask = (offs_b < B) & (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(C_ptr, c, mask=c_mask)

@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)

def bmm(a, b, activation=""):
    # Check constraints.
    assert a.shape[0] == b.shape[0], "Incompatible dimensions"
    assert a.shape[2] == b.shape[1], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    B, M, K = a.shape
    B, K, N = b.shape
    c = torch.empty((B, M, N), device=a.device, dtype=torch.float16)
    # 2D launch kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), B,)
    bmm_kernel[grid](
        a, b, c,  #
        B, M, N, K,  #
        a.stride(0), a.stride(1), a.stride(2), #
        b.stride(0), b.stride(1), b.stride(2), #
        c.stride(0), c.stride(1), c.stride(2), #
        ACTIVATION=activation  #
    )
    return c
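A usage sketch for the batched kernel (it assumes get_autotune_config() from the tutorial is in scope so the autotuner can pick the block sizes):

B, M, K, N = 4, 128, 64, 96
a = torch.randn((B, M, K), device='cuda', dtype=torch.float16)
b = torch.randn((B, K, N), device='cuda', dtype=torch.float16)
c = bmm(a, b)
print("max abs diff vs torch.bmm:", (c - torch.bmm(a, b)).abs().max().item())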

attention

scaled_dot_product_attention_kernel

import torch

import triton
import triton.language as tl

@triton.jit
def scaled_dot_product_attention_kernel(
    q_ptr,       # pointer to the Query matrix
    k_ptr,       # pointer to the Key matrix
    v_ptr,       # pointer to the Value matrix
    o_ptr,       # pointer to the output matrix
    M,           # number of rows of Q
    N,           # number of columns of K / V
    K,           # number of columns of Q / rows of K
    SCALE,       # scaling factor
    BLOCK_SIZE_M: tl.constexpr,  # block size along M
    BLOCK_SIZE_N: tl.constexpr,  # block size along N
    BLOCK_SIZE_K: tl.constexpr   # block size along K
):
    # program indices of the current block
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # starting offsets of the block
    block_m = pid_m * BLOCK_SIZE_M
    block_n = pid_n * BLOCK_SIZE_N

    # initialize the accumulator
    output = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # loop over the K dimension in blocks
    for k in range(0, K, BLOCK_SIZE_K):
        # load the Query and Key blocks
        q = tl.load(
            q_ptr + (block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] * K + (k + tl.arange(0, BLOCK_SIZE_K)),
            mask=(block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] < M,
            other=0.0
        )
        k_block = tl.load(
            k_ptr + (k + tl.arange(0, BLOCK_SIZE_K))[:, None] * N + (block_n + tl.arange(0, BLOCK_SIZE_N)),
            mask=(k + tl.arange(0, BLOCK_SIZE_K))[:, None] < K,
            other=0.0
        )

        # scaled dot product followed by a softmax (an approximation: normalized only within this block, not over the whole row)
        logits = tl.dot(q, k_block) * SCALE
        max_logits = tl.max(logits, axis=1)
        logits_exp = tl.exp(logits - max_logits[:, None])
        softmax = logits_exp / tl.sum(logits_exp, axis=1)[:, None]

        # load the Value block and accumulate the weighted sum
        v = tl.load(
            v_ptr + (k + tl.arange(0, BLOCK_SIZE_K))[:, None] * N + (block_n + tl.arange(0, BLOCK_SIZE_N)),
            mask=(k + tl.arange(0, BLOCK_SIZE_K))[:, None] < K,
            other=0.0
        )
        output += tl.dot(softmax, v)

    # write the result back
    tl.store(
        o_ptr + (block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] * N + (block_n + tl.arange(0, BLOCK_SIZE_N)),
        output,
        mask=((block_m + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ((block_n + tl.arange(0, BLOCK_SIZE_N)) < N)
    )

# Usage example
def scaled_dot_product_attention(q, k, v, scale):
    M, K = q.shape
    K_k, N = k.shape
    K_v, N_v = v.shape
    assert K == K_k == K_v, "The inner dimensions of Q, K and V must match."
    assert N == N_v, "K and V must have the same number of columns."

    o = torch.empty((M, N), dtype=torch.float32, device='cuda')

    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 32
    BLOCK_SIZE_K = 32

    grid = ((M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M, (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N)

    scaled_dot_product_attention_kernel[
        grid
    ](
        q, k, v, o, M, N, K, scale, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K
    )

    return o

if __name__ == "__main__":
    # randomly generate the Query, Key and Value matrices
    M, K, N = 128, 64, 128
    q = torch.randn((M, K), dtype=torch.float32, device='cuda')
    k = torch.randn((K, N), dtype=torch.float32, device='cuda')
    v = torch.randn((K, N), dtype=torch.float32, device='cuda')

    # scaling factor
    scale = 1.0 / (K ** 0.5)

    # call scaled_dot_product_attention
    o = scaled_dot_product_attention(q, k, v, scale)

    # # print the results
    # print("Q:\n", q)
    # print("K:\n", k)
    # print("V:\n", v)
    # print("Output:\n", o)

DL Network

resnet

import torch

import triton
import triton.language as tl

@triton.jit
def resnet_block_kernel(
    x_ptr,        # pointer to the input tensor, shape (M, K)
    w1_ptr,       # pointer to the first weight matrix, shape (K, K)
    w2_ptr,       # pointer to the second weight matrix, shape (K, N); the residual needs K == N
    out_ptr,      # pointer to the output tensor, shape (M, N)
    M, N, K: tl.constexpr,       # problem sizes (K is constexpr so it can be used in tl.arange)
    BLOCK_SIZE_M: tl.constexpr,  # rows per program
    BLOCK_SIZE_N: tl.constexpr   # output columns per program
):
    # row / column indices of the output tile handled by this program
    row_offsets = tl.arange(0, BLOCK_SIZE_M)
    col_offsets = tl.arange(0, BLOCK_SIZE_N)

    row_idx = tl.program_id(0) * BLOCK_SIZE_M + row_offsets
    col_idx = tl.program_id(1) * BLOCK_SIZE_N + col_offsets

    mask_row = row_idx < M
    mask_col = col_idx < N

    # load the (BLOCK_SIZE_M, K) slice of the input rows handled by this program
    x = tl.load(x_ptr + row_idx[:, None] * K + tl.arange(0, K)[None, :],
                mask=mask_row[:, None], other=0.0)

    # first dense layer: load the full (K, K) weight and compute h1 = x @ w1
    w1 = tl.load(w1_ptr + tl.arange(0, K)[:, None] * K + tl.arange(0, K)[None, :])
    h1 = tl.dot(x, w1)

    # ReLU activation
    h1_relu = tl.where(h1 > 0, h1, 0.0)

    # second dense layer: load the (K, BLOCK_SIZE_N) column slice of w2
    w2 = tl.load(w2_ptr + tl.arange(0, K)[:, None] * N + col_idx[None, :],
                 mask=mask_col[None, :], other=0.0)
    h2 = tl.dot(h1_relu, w2)

    # residual connection (valid because K == N, so x and the output have the same width)
    x_residual = tl.load(x_ptr + row_idx[:, None] * K + col_idx[None, :],
                         mask=(mask_row[:, None] & mask_col[None, :]), other=0.0)
    out = h2 + x_residual

    # write back the output tile
    tl.store(out_ptr + row_idx[:, None] * N + col_idx[None, :], out,
             mask=(mask_row[:, None] & mask_col[None, :]))

# Usage example
def resnet_block(x, w1, w2):
    M, K = x.shape
    K_w1, N_w1 = w1.shape
    K_w2, N = w2.shape

    assert K == K_w1 == N_w1 == K_w2, "input and weight dimensions do not match"
    assert K == N, "the residual connection requires the input width K to equal the output width N"

    output = torch.empty((M, N), dtype=torch.float32, device='cuda')

    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 64
    grid = ((M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M, (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N)

    resnet_block_kernel[
        grid
    ](
        x, w1, w2, output, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N
    )

    return output

if __name__ == "__main__":
    # initialize the input and the weights (K == N so the identity shortcut is valid)
    M, K, N = 128, 64, 64  # input: M x K, weights: K x K and K x N
    x = torch.randn((M, K), dtype=torch.float32, device='cuda')
    w1 = torch.randn((K, K), dtype=torch.float32, device='cuda')
    w2 = torch.randn((K, N), dtype=torch.float32, device='cuda')

    # run the fused residual block
    output = resnet_block(x, w1, w2)

    # # print the results
    # print("Input:\n", x)
    # print("Weight1:\n", w1)
    # print("Weight2:\n", w2)
    # print("Output:\n", output)
This post is licensed under CC BY-NC-SA 4.0 by the author.
