Home Triton Kernel Optim
Post
Cancel

Triton Kernel Optim

本文记录下本人优化 Triton Kernel 的思路,由于不了解 Cuda 编程以及对 GPU 体系结构知识只是一知半解,所以本文设计的优化思路都比较通用(aka naive)。

kernel写法上请参考 triton language guidetriton tutorial、以及flaggems等项目,网络资料很不错~


IMO,对 Triton Kernel 的优化过程可以简单分为以下两种(因为我目前只会这两步),本文只涉及第一种:

  • 浅层优化:通过替换算子、合并kernel、拆时间片循环(sequence轴拆分)等方式实现初步优化。
  • 深层优化:分析下降所得IR,使用perf工具,对照算子库实现等方式,优化kernel的下降行为。

以优化 flaggems 中的 layernorm kernelbackward 为主线讲解。因为我只用了较为简单通用的方法,使用前向也是一个样优化,就不再说明。

perf test

本文中不介绍相关环境的配置,想上手的同学根据 README 配置就应该不会有啥问题。

正好最近的 commit 支持了对算子的 backward 进行 perf 测试,只不过现在测试函数不全,得自己添加一下,例如测试 layernorm backward kernel 可以在 benchmark/test_reduction_perf.py 中添加:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def test_perf_layernorm_backward():
    def layer_norm_args(dtype, batch, size):
        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
        weight = torch.randn([size,], dtype=dtype, device="cuda",)
        bias = torch.randn([size,], dtype=dtype,device="cuda",)
        return (inp, [size,], weight, bias, )
    bench = Benchmark(
        op_name="layernorm",
        torch_op=torch.layer_norm,
        arg_func=layer_norm_args,
        dtypes=FLOAT_DTYPES,
        batch=REDUCTION_BATCH,
        sizes=SIZES,
        is_backward=True,
    )
    bench.run()

然后运行

1
2
cd benchmark
pytest test_reduction_perf.py::test_perf_layernorm_backward -s

kernel optim

layernormbackward kernelflaggems 中的实现分成了两个,一个计算 in_grad、一个计算 weight_gradbias_grad

因为 in_grad 的每个值都需要完整地遍历 col(即N),而 weight_gradbias_grad的每个值需要完整地遍历 row(即M)。为了更清晰理解计算行为,可以看:这篇bloglayernorm backward 的计算推导。

当前实现功能上基本能cover所有的case,性能上我也不知道如何,因为我还没在GPU测过hhh。但还是可以强行优化一下,而且在我的环境下确实有性能提升叻,并且精度测试没问题。

合并 kernel

当看到 kernel 分为了两个,第一反应是合并一下,但是由于 in_gradweight_gradbias_grad 的计算行为分别依赖不同的遍历,导致难以合并。

这时候翻看下官方 tutorial layernorm backward,虽然也是两个kernel,但是第二个kernel本质上只做了sum,那么我们在第一个kernel中对 partial_dwpartial_db 使用 atomic_add 就可以合并为一个kernelatomic_add 在完成 add 后会有 store 的行为。

kernel

优化后的kernel和 tutorial 中的实现相似:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@triton.jit
def layer_norm_backward_kernel(DX,  # pointer to the input gradient
                               DY,  # pointer to the output gradient
                               DW,  # pointer to the partial sum of weights gradient
                               DB,  # pointer to the partial sum of biases gradient
                               X,  # pointer to the input
                               W,  # pointer to the weights
                               Mean,  # pointer to the mean
                               Rstd,  # pointer to the 1/std
                               stride,  # how much to increase the pointer when moving by 1 row
                               N,  # number of columns in X
                               BLOCK_COL_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_COL_SIZE)
    mask = cols < N
    X += row * stride
    DY += row * stride
    DX += row * stride
    # Load data to SRAM
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # Compute dx
    xhat = (x - mean) * rstd
    wdy = w * dy
    c1 = tl.sum(xhat * wdy, axis=0)
    c2 = tl.sum(wdy, axis=0)
    dx = (wdy - (xhat * c1 + c2) / N) * rstd
    # Write dx
    tl.store(DX + cols, dx, mask=mask)
    # Accumulate partial sums for dw/db
    partial_dw = (dy * xhat).to(tl.float32)
    partial_db = (dy).to(tl.float32)
    # 使用 atomic_add 合并第二个 sum kernel
    tl.atomic_add(DW + cols, partial_dw)
    tl.atomic_add(DB + cols, partial_db)

launch func

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class LayerNorm(torch.autograd.Function):
    ... # 这里是forward
    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        logging.debug("GEMS LAYERNORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N = ctx.M, ctx.N

        # tutorial 中设置的超参数,这里的参数也需要根据硬件来改!!
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

        # alloc for out
        in_grad = torch.empty_like(x)
        # 为了保证使用 atomic_add 支持的数据类型,所以直接使用float32
        weight_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        bias_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        layer_norm_backward_kernel[(M, )](
            in_grad, out_grad, weight_grad, bias_grad, x, weight, mean, rstd,
            N, BLOCK_COL_SIZE=BLOCK_SIZE, num_warps = num_warps,
        )
        weight_grad = weight_grad.to(x.dtype)
        bias_grad = bias_grad.to(x.dtype)

tuning config

由于当前kernel没有需要tuning的超参数,所以不需要设置

问题分析

根据上文的 launch 函数可以注意到,当前kernel主要存在以下两个问题:

(1)对M是有大小限制,launch grid直接为(M, 1, 1),可能超出 grid 的限制。

(2)对N是有大小限制,导致kernel无法覆盖所有case

针对问题(1),我们选择对M进行拆时间片循环

针对问题(2),我们选择修改flaggems官方实现(也增加拆时间片循环),作为 fallback kernel,然后根据N的大小去选择最终使用的kernel。

下节我们将依次解决这两个问题。

拆时间片循环

将kernel的grid按如下设置,保证不超过grid的最大限制,其中MAX_GRID_NUM是一个人为设置的超参数,根据硬件设置就好。

1
grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)

使用该 grid 后,每个kernel处理的数据大小就不一定为一个 [1, BLOCK_COL_SIZE]。每个pid处理1或多个大小为 [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] 的数据块:

1
2
3
4
5
6
7
8
pid = tl.program_id(0)
row_start = pid * BLOCK_ROW_SIZE
total_num = min(triton.cdiv(M, META['BLOCK_ROW_SIZE'])
step = total_num * BLOCK_ROW_SIZE
cols = tl.arange(0, BLOCK_COL_SIZE)
for row in range(row_start, M, step):
    # 每次处理 [BLOCK_ROW_SIZE, BLOCK_COL_SIZE]
    row_off = row + tl.arange(0, BLOCK_ROW_SIZE)

kernel

然后以上一步的kernel为基础,增加拆 row(即M)的循环,循环中一次处理 [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] 大小的数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@triton.jit
def layer_norm_backward_kernel(
        DX,  # pointer to the input gradient
        DY,  # pointer to the output gradient
        DW,  # pointer to the partial sum of weights gradient
        DB,  # pointer to the partial sum of biases gradient
        X,  # pointer to the input
        W,  # pointer to the weights
        Mean,  # pointer to the mean
        Rstd,  # pointer to the 1/std
        M,  # number of rows in X
        N,  # number of columns in X
        BLOCK_ROW_SIZE: tl.constexpr, BLOCK_COL_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    cols = tl.arange(0, BLOCK_COL_SIZE)
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE
    col_mask = cols < N

    X += cols[None, :]
    DY += cols[None, :]
    W += cols[None, :]
    DX += cols[None, :]
    w = tl.load(W, mask = col_mask, other = 0.0).to(tl.float32)

    partial_dw = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    partial_db = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = row_off < M
        # Load data to SRAM
        off = row_off[:, None] * N # row的stride为 BLOCK_ROW_SIZE * N
        mask = row_mask[:, None] and col_mask
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + off, mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + row_off, mask = row_mask)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + row_off, mask = row_mask)[:, None].to(tl.float32)
        # Compute dx
        x_hat = (x - mean) * rstd
        wdy = w * dy
        #  [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] -> [BLOCK_ROW_SIZE]
        c1 = tl.sum(x_hat * wdy, axis=1)[:, None]
        c2 = tl.sum(wdy, axis=1)[:, None]
        dx = (wdy - (x_hat * c1 + c2) / N) * rstd
        # Accumulate partial sums for dw/db
        partial_dw += (dy * x_hat).to(tl.float32)
        partial_db += (dy).to(tl.float32)
        # Write dx
        tl.store(DX + off, dx.to(x.dtype), mask=mask)

    #  [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] -> [BLOCK_COL_SIZE]
    dw = tl.sum(partial_dw, axis=0)
    db = tl.sum(partial_db, axis=0)
    tl.atomic_add(DW + cols, dw)
    tl.atomic_add(DB + cols, db)

launch func

backward的launch函数部分也是模仿tutorial写的,人为设置一些超参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class LayerNorm(torch.autograd.Function):
    ... # 这里是forward
    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        logging.debug("GEMS LAYERNORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N = ctx.M, ctx.N

        # tutorial 中设置的超参数,这里的参数也需要根据硬件来改!!
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

        in_grad = torch.empty_like(x)
        # 为了保证使用 atomic_add 支持的数据类型,所以直接使用float32
        weight_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        bias_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
        layer_norm_backward_kernel[grid](
            in_grad, out_grad, weight_grad, bias_grad, x, weight, mean, rstd,
            M, N, BLOCK_COL_SIZE=BLOCK_SIZE, num_warps = num_warps,
        )
        weight_grad = weight_grad.to(x.dtype)
        bias_grad = bias_grad.to(x.dtype)

tuning config

row 进行拆分后,我们就需要tuning BLOCK_ROW_SIZE,但由于kernel一次还是处理完整的 col,所以 BLOCK_ROW_SIZE 也不能设置多大。tuning 参数仁者见仁,根据场景做 编译时间和性能的trade-off 就好

1
2
3
4
5
6
7
8
9
10
def cfggen_bw():
    block_m = [1, 4, 16, 32]
    # num_stages 这里就不提供大概设置多少了
    num_stages = [...]
    configs=[
        triton.Config({"BLOCK_ROW_SIZE": m}, num_stages=s)
        for m in block_m
        for s in num_stages
    ],
    return configs

需要注意的是,使用 atomic_add 后,若同时设置了多个 tuning config ,会有精度问题,因为每次选择新的 config 时没有对 atomic_addtarget 重置为0。需要在设计 tuning config 时加一个 reset_to_zero,大致如下。(这个是大佬告诉我的)

1
2
@libentry() # 这是 flaggems 需要加的
@triton.autotune(configs=cfggen_bw(), key=["M", "N"], reset_to_zero=["DW", "DB"])

问题分析

让我们回顾下初次优化kernel后提出的两个问题:

(1)对M是有大小限制,launch grid直接为(M, 1, 1),可能超出 grid 的限制。

(2)对N是有大小限制,导致kernel无法覆盖所有case

前文已经解决了问题(1),现在我们来考虑问题(2)。我们选择修改flaggems官方原本的实现,作为 fallback kernel

  • kernel

实现和官方比较相似,只是增加了时间片拆分,不再展开讲解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def cfggen_input_bw():
    block_m = [1, 4, 16, 32]
    block_n = [32, 256, 1024, 2048]
    # num_stages 和 num_warps 这里就不提供大概设置多少了
    num_stages = [...]
    num_warps = [...]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w, num_stages=s)
        for m in block_m
        for n in block_n
        for s in num_stages
        for w in num_warps
    ]
    return configs

@libentry() # 这是 flaggems 需要加的
@triton.autotune(configs=cfggen_input_bw(), key=["M", "N"])
@triton.jit
def input_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE

    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mean = tl.load(Mean + row_off, mask = row_off < M, other = 0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + row_off, mask = row_off < M, other = 0.0)[:, None].to(tl.float32)

        row_mask = row_off[:, None] < M
        off = row_off[:, None] * N
        new_dY = dY + off
        new_X = X + off
        new_DX = dX + off

        dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other = 0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other = 0.0).to(tl.float32)
            x_hat = (x - mean) * rstd
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
            wdy = dy * w
            dx_part2 += wdy
            dx_part3 += wdy * x_hat

        dx_2 = tl.sum(dx_part2, axis=1)[:, None]
        dx_3 = tl.sum(dx_part3, axis=1)[:, None]

        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other = 0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other = 0.0).to(tl.float32)
            w = tl.load(W + cols, mask=cols < N, other = 0.0).to(tl.float32)
            x_hat = (x - mean) * rstd
            wdy = dy * w
            dx = rstd * (wdy - (dx_2 + x_hat * dx_3) / N)
            tl.store(new_DX + cols, dx.to(x.dtype), mask=mask)


def cfggen_wb_bw():
    block_m = [32, 64, 128, 512, 1024]
    block_n = [1, 4, 16, 32]
    # num_stages 和 num_warps 这里就不提供大概设置多少了
    num_stages = [...]
    num_warps = [...]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_stages=s)
        for m in block_m
        for n in block_n
        for s in num_stages
        for w in num_warps
    ]
    return configs

@libentry()
@triton.autotune(configs=cfggen_wb_bw(), key=["M", "N"])
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    col_start = pid * BLOCK_COL_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_COL_SIZE

    for col in range(col_start, N, step):
        col_off = col + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = col_off < N

        new_dY = dY + col_off
        new_X = X + col_off
        new_dW = dW + col_off
        new_dB = dB + col_off

        accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

        for off in range(0, M, BLOCK_ROW_SIZE):
            rows = off + tl.arange(0, BLOCK_ROW_SIZE)
            row_mask = rows[:, None] < M
            mask = row_mask and col_mask
            dy = tl.load(new_dY + rows[:, None] * N, mask, other = 0.0).to(tl.float32)
            x = tl.load(new_X + rows[:, None] * N, mask, other = 0.0).to(tl.float32)
            mean = tl.load(Mean + rows, mask = rows < M, other = 0.0)[:, None].to(tl.float32)
            rstd = tl.load(Rstd + rows, mask = rows < M, other = 0.0)[:, None].to(tl.float32)
            x_hat = (x - mean) * rstd
            accW += dy * x_hat
            accB += dy
        dw = tl.sum(accW, axis=0)
        db = tl.sum(accB, axis=0)
        tl.store(new_dW, dw[None, :], mask=col_mask)
        tl.store(new_dB, db[None, :], mask=col_mask)
  • launch func
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 人为设置超参数,这些都和硬件参数有关,在这里我都是乱设,一学就废
MAX_COL_LEN_BACKWARD = 16392
MAX_GRID_NUM = 65535

class LayerNorm(torch.autograd.Function):
    ... # 这里是forward
    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        logging.debug("GEMS LAYERNORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N = ctx.M, ctx.N
        in_grad = torch.empty_like(x)
        # tutorial 中设置的超参数,这里的参数也需要根据硬件来改!!
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if (N <= BLOCK_SIZE) and (BLOCK_SIZE <= MAX_COL_LEN_BACKWARD):
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
            # 为了保证使用 atomic_add 支持的数据类型,所以直接使用float32
            weight_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
            bias_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
            grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
            layer_norm_backward_kernel[grid](
                in_grad, out_grad, weight_grad, bias_grad, x, weight, mean, rstd,
                M, N, BLOCK_COL_SIZE=BLOCK_SIZE, num_warps = num_warps,
            )
        else:
            grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
            # 每次 kernel 都要处理完整的 N
            input_backward_kernel[grid](
                out_grad, x, weight, mean, rstd, in_grad, M, N,
            )
            weight_grad = torch.empty_like(weight)
            bias_grad = torch.empty_like(weight)
            grid = lambda META: (min(triton.cdiv(N, META['BLOCK_COL_SIZE']), MAX_GRID_NUM),)
            # 每次 kernel 都要处理完整的 M
            weight_bias_backward_kernel[grid](
                out_grad, x, mean, rstd, weight_grad, bias_grad, M, N,
            )
        return in_grad, None, weight_grad, bias_grad, None, None

替换算子

简单的算子替换

  • tl.max(a, 0.0) 可以换成 tl.where(a > 0, a, 0.0)
  • xytl.load 时用了mask,随后的 tl.where(mask, x - y, 0.0) 可以删除
  • 大规模 reduce(10000->1) -> 多级 reduce(10000->100->1)
  • 把低维reduction转为高维reduction(高维取数的连续性会更好)

人为hint

1
2
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

由于编译器无法感知数据的连续性,所以加载数据时会离散地处理数据。 如果编写kernel时提前已知数据连续,可以使用 tl.max_contiguous & tl.multiple_of 去标识加载数据的连续性,这样编译器就可连续地处理该段数据。

input 和 values 是等维度的

  • max_contiguous(input, values):对于每个维度i,标识input[i]中 每values[i]个相邻元素 是连续的

例如 values = [4], 则 input 可以是 [0, 1, 2, 3, 8, 9, 10, 11]

  • max_constany(input, values):对于每个维度i,标识input[i]中 每values[i]个相邻元素 是常数

例如 values = [4], 则 input 可以是 [0, 0, 0, 0, 1, 1, 1, 1]

  • multiple_of(input, values):对于每个维度i,标识input[i]中 所有元素都是 values[i] 的倍数

例如 values = [2], 则 input 可以是 [0, 2, 4, 6, 8]

1
offs_am = tl.max_contiguous(tl.multiple_of((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M, BLOCK_SIZE_M), BLOCK_SIZE_M)

算法替换

例如:累乘 -> 二分乘法

算法实现上

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
tmp = 1
def mul_acc(x, l, h):
    tmp = 1
    for i in rang(l, h)
        tmp *= i
    return tmp

->

def binary_mul(x, l, h):
    if l >= h:
        return 1
    if h - l == 1:
        return x[l]
    mid = (l + h) // 2
    return binary_mul(x, l, mid) + binary_mul(x, mid, h)

以优化 flaggems 中的 prod 为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul)
    mid_ptr = mid + pid
    tl.store(mid_ptr, mid_value.to(inp_val.dtype))

首先拆时间片循环:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.full([BLOCK_SIZE], value=1.0, dtype=tl.float32)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        _tmp = _tmp * input_val
    mid_value = tl.reduce(_tmp, axis=0, combine_fn=reduce_mul)
    tl.store(mid_ptr + pid, mid_value.to(inp_val.dtype))

# launch func
# grid = lambda META: min((triton.cdiv(M, MEAT['BLOCK_SIZE']), MAX_GRID_NUM),)

然后将 _tmp 的 累乘优化为二分乘法(reduce_mul->归约规约)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
mid_value = tl.reduce(_tmp, axis=0, combine_fn=reduce_mul)
tl.store(mid_ptr + pid, mid_value.to(inp_val.dtype))

->

# 将数组 _tmp 前一半的元素与后一半的元素相乘,并将结果存储在前一半的位置
# triton.Config({"BLOCK_SIZE": m} for m in [...])
# "ITER_NUM" 为 BLOCK_SIZE 能整除2的次数
# 以 BLOCK_SIZE = 16 为例,ITER_NUM=4
# 例: x   _tmp[:BLOCK_SIZE // (2 ** 1)]   _tmp[BLOCK_SIZE // (2 ** 1):BLOCK_SIZE // (2 ** (x - 1))]
#     1   _tmp[:8]                        _tmp[8:16]
#     2   _tmp[:4]                        _tmp[4:8]
#     3   _tmp[:2]                        _tmp[2:4]
#     4   _tmp[:1]                        _tmp[1:2]
for x in tl.static_range(1, int(ITER_NUM), 1):
    _tmp[:BLOCK_SIZE // (2 ** x)] = _tmp[:BLOCK_SIZE // (2 ** x)] * _tmp[BLOCK_SIZE // (2 ** x):(BLOCK_SIZE // (2 ** x)) * 2]
# reduce(_tmp[:2])
res = tl.reduce(_tmp[:BLOCK_SIZE // (2 ** (ITER_NUM - 1))], axis=0, combine_fn=reduce_mul)
tl.store(mid_ptr + pid, res)

# 如果BLOCK_SIZE设置的都是二次幂,并且 {"ITER_NUM": math.log2(m)+1} ,则直接store即可
# tl.store(mid_ptr + pid, _tmp[0])

需要注意的是,上述并行归约优化在tl.reduce下降过程完成更具泛化性。

思考:为什么二分计算方式能带来性能上的提升?

在软件流水阶段,使用二分的计算形式能够将依赖关系更好地打散,实现更好地并行。但缺点就是所使用地资源更多。


结语

最后的光

自此,本文对 Triton Kernel 的优化行为只涉及替换算子、合并kernel、拆时间片循环等初步优化。并且上述优化还存在进一步的空间,例如

  • 在拆时间片循环的时候,我们也可以注意到可以根据选择的 tuning config 提前判断是否需要循环,可以使用一个超参数来控制

我们以 max_kernel 为例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@triton.autotune(...)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"] <= args["BLOCK_SIZE"] * MAX_GRID_NUM,
    },
)
@triton.jit
def max_kernel(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = -float("inf")
    if ONE_TILE_PER_CTA:
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf")).to(tl.float32)
        res = tl.max(inp_val)
    else:
        _tmp = tl.full([BLOCK_SIZE], value=-float("inf"), dtype=tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=-float("inf")).to(tl.float32)
            _tmp = tl.where(_tmp > inp_val, _tmp, inp_val)
        res = tl.max(_tmp)

    tl.atomic_max(out, res.to(tl.float32))
  • 又或者,其实拆 M 的时候本来就不需要防止 M 轴越界,即不需要关于 M 的mask,直接根据 pid 去准确地为每个 kernel 分配需要处理的数据范围

上面的 max kernel 最终将 inp 给 reduce 成一个数,那如果我们需要保留某个维度,而对剩下维度做 reduce 呢?即 256x65536 规模的输入,给 reduce 到 256x1 的输出。

这时候,我们不仅需要拆 N (单次处理65536的数据过多),也需要拆 M (减少总任务数量)。

例如我们一共起 256 个任务,那么每个任务分配可以是

data per taskcomputation in kernel
1x655361 x reduce(65536)->1 + store
4x163844 x reduce(16384)->1 + atomic_max
64x102464 x reduce(1024)->1 + atomic_max

明确了这点后,我们的 launch 函数内大致形式如下

1
2
3
4
# shape = [M, 1]
out = torch.full(shape, -float("inf"), dtype=torch.float32, device=inp.device)
grid = lambda meta: (triton.cdiv(MAX_GRID_NUM, meta["NUM_BLOCK_N"]), meta["NUM_BLOCK_N"])
max_kernel[grid](inp, out, M, N)

接下来修改 max_kernel 如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def cfggen():
    block_nums = [1, 4, 16, 64, ...]
    configs = [
        triton.Config({"NUM_BLOCK_N": block_num}, num_warps=.., num_stages=..) for block_num in block_nums
    ]
    return configs

@triton.autotune(configs=cfggen(), key=["N"])
@triton.heuristics(
    values={
        "BLOCK_SIZE_N": lambda args: (args["N"] + args["NUM_BLOCK_N"] - 1) // args["NUM_BLOCK_N"],
    },
)
@triton.jit
def max_kernel(
    inp,
    out,
    M,
    N,
    NUM_BLOCK_N: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    # 下文的注释以 inp(256x65536),launch_grid = [4, 64, 1] 为例,每个 kernel 需要处理 (64x1024)
    # 即 NUM_BLOCK_N = 64, BLOCK_SIZE_N = 1024
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_jobs_m = tl.num_programs(axis=0) # triton.cdiv(MAX_GRID_NUM, meta["NUM_BLOCK_N"])
    row_per_job = (M + num_jobs_m - 1) // num_jobs_m # 每次处理 64 行

    row_begin = pid_m * row_per_job
    row_end = row_begin + row_per_job
    if pid_m == (num_jobs_m - 1): # 注意末尾边界
        row_end = M

    for row_idx in range(row_begin, row_end): # 相当于这里要循环 64 次
        off_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        mask = off_n < N
        offset = row_idx * N + off_n
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf")).to(tl.float32)
        res = tl.max(inp_val)
        tl.atomic_max(out + row_idx, res)

以一个二维softmax任务总结优化流程,总数据量 m row, n col,优化过程:

  • 原始kernel:一次 BLOCK_SIZE_ROW 行, 1 列
    • grid = lambda META: (triton.cdiv(n_rows, META[‘BLOCK_SIZE_ROW’]), 1, 1)
  • jobs 间拆分 col
    • grid = lambda META: (triton.cdiv(n_rows, META[‘BLOCK_SIZE_ROW’]), triton.cdiv(n_cols, META[‘BLOCK_SIZE_COL’]), 1)

期待

最后,本文所描述的优化行为都比较 naive,并不包含:

  • 修改lowering源码
  • 使用硬件特性优化

这两者常常和硬件结构相关,各家有各家的说法,再说我也确实不太懂硬件架构(修改源码的地方这里也不好放出来),只能用此文记录下自己naive的优化行为,期望有更多大佬分享优化的经验。

当然也有些简易的优化方法,例如:(在 SIMD 架构上的 load 和 store 操作并没有 SIMT 架构上的 memory-coalecse 优化)

  • 增大单次连续 IO 量:对于 load %ptr, %mask, %other 可以转化为 select %mask, %load, %other 来确保 load(IO 操作) 的高吞吐。
  • 减小单次真实 IO 量:当存在大量地址连续的元素为 constancy(相同值) 时,可以先对 extract_slice %ptr,再 load,后续 broadcast 或 expand_shape 都行。
This post is licensed under CC BY-NC-SA 4.0 by the author.

Triton Linalg

CPP Note Needed for MLIR