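This post records how I went about optimizing Triton kernels. I don't know CUDA programming and only half-understand GPU architecture, so the optimization ideas covered here are all fairly generic (aka naive).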
For how to write kernels, see the Triton language guide, the Triton tutorials, and projects such as FlagGems. There is plenty of good material online~
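IMO, optimizing a Triton kernel can be roughly divided into the following two kinds (these two steps are all I know so far). This post only covers the first:
- Shallow optimization: first-pass speedups from replacing operators, fusing kernels, and splitting work into time-slice loops (splitting along the sequence axis, i.e. a grid-stride loop).
- Deep optimization: improving how the kernel is lowered by analyzing the lowered IR, using perf tools, comparing against operator-library implementations, and so on.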
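The walkthrough follows the optimization of the backward pass of the layernorm kernel in FlagGems. I only used fairly simple, generic methods, so the forward pass is optimized in exactly the same way and is not covered separately.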
perf test
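This post does not cover environment setup. If you want to try it yourself, following the README should work without trouble.
A recent commit happens to add perf testing for operator backward passes. The test functions are not complete yet, so you have to add your own. For example, to test the layernorm backward kernel, add the following to `benchmark/test_reduction_perf.py`: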
```python
# Added to benchmark/test_reduction_perf.py. Benchmark, FLOAT_DTYPES,
# REDUCTION_BATCH and SIZES are assumed to be the helpers that the other
# tests in that file already use.
def test_perf_layernorm_backward():
    def layer_norm_args(dtype, batch, size):
        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
        weight = torch.randn([size], dtype=dtype, device="cuda")
        bias = torch.randn([size], dtype=dtype, device="cuda")
        # arguments for torch.layer_norm(input, normalized_shape, weight, bias)
        return (inp, [size], weight, bias)

    bench = Benchmark(
        op_name="layernorm",
        torch_op=torch.layer_norm,
        arg_func=layer_norm_args,
        dtypes=FLOAT_DTYPES,
        batch=REDUCTION_BATCH,
        sizes=SIZES,
        is_backward=True,
    )
    bench.run()
```
Then run:
```shell
cd benchmark
pytest test_reduction_perf.py::test_perf_layernorm_backward -s
```
kernel optim
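In FlagGems, the layernorm backward is implemented as two kernels. One computes `in_grad`, and the other computes `weight_grad` and `bias_grad`.
The split exists because each value of `in_grad` requires a full pass over `col` (i.e. N), while each value of `weight_grad` and `bias_grad` requires a full pass over `row` (i.e. M). For a clearer picture of the computation, see the derivation of the layernorm backward in this blog.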
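To make the two traversal directions concrete, here is a minimal pure-Python reference for the formulas the kernels below use (`xhat`, `wdy`, `c1`, `c2`). It is a sketch for checking results on small inputs, not FlagGems code. The function names are hypothetical, and `rstd` is assumed to be `1 / sqrt(var + eps)`.

```python
import math


def layernorm_forward(x, w, b, eps=1e-5):
    """Row-wise LayerNorm on a list of rows; returns (y, means, rstds)."""
    y, means, rstds = [], [], []
    for row in x:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n
        rstd = 1.0 / math.sqrt(var + eps)
        y.append([(v - mean) * rstd * wj + bj for v, wj, bj in zip(row, w, b)])
        means.append(mean)
        rstds.append(rstd)
    return y, means, rstds


def layernorm_backward(dy, x, w, means, rstds):
    """dx reduces along each row (over N); dw/db accumulate down each column (over M)."""
    n = len(w)
    dx = []
    dw = [0.0] * n
    db = [0.0] * n
    for row, dy_row, mean, rstd in zip(x, dy, means, rstds):
        xhat = [(v - mean) * rstd for v in row]
        wdy = [wj * g for wj, g in zip(w, dy_row)]
        c1 = sum(h * g for h, g in zip(xhat, wdy))  # reduction over N
        c2 = sum(wdy)                                # reduction over N
        dx.append([(g - (h * c1 + c2) / n) * rstd for g, h in zip(wdy, xhat)])
        for j in range(n):                           # accumulated over M
            dw[j] += dy_row[j] * xhat[j]
            db[j] += dy_row[j]
    return dx, dw, db
```

A quick way to validate a kernel's output on tiny shapes is to compare `dx`, `dw`, `db` against finite differences of `sum(y * dy)`.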
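Functionally, the current implementation covers basically every case. As for performance, I honestly don't know, because I haven't measured it on a GPU yet, hhh. Still, we can push some optimizations through anyway. In my environment they do give a speedup, and the accuracy tests pass.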
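Fusing kernels
Seeing the backward split into two kernels, my first reaction was to fuse them. However, `in_grad`, `weight_grad` and `bias_grad` each depend on a different traversal, which makes fusion hard.
At this point it helps to look at the official tutorial's layernorm backward. It also uses two kernels, but the second kernel essentially only performs a sum. So if the first kernel applies `atomic_add` to `partial_dw` and `partial_db`, the two can be fused into one kernel. `atomic_add` is a read-modify-write: once the add completes, it stores the result back to memory itself, so no separate store is needed.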
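Here is a toy model of that fusion in plain Python. It shows why a zero-initialised buffer plus atomic accumulation gives the same `dw` as the tutorial's separate sum kernel. In the model, threads stand in for programs and a `threading.Lock` stands in for the hardware atomic. Both are assumptions of the model, not how Triton implements `atomic_add`. One real consequence it does reflect is that the order of atomic adds is not fixed, so floating-point results may differ in the last bits from run to run.

```python
import threading


def dw_two_kernels(dy, xhat):
    """Tutorial style: kernel 1 writes per-row partials, kernel 2 sums them over rows."""
    partial = [[g * h for g, h in zip(dy_row, x_row)] for dy_row, x_row in zip(dy, xhat)]
    return [sum(col) for col in zip(*partial)]


def dw_fused_atomic(dy, xhat):
    """Fused style: each program adds its row's partial straight into a shared buffer."""
    n = len(dy[0])
    dw = [0.0] * n           # must start at zero, like torch.zeros in the launch code
    lock = threading.Lock()  # stands in for the hardware atomic

    def program(row):        # one "program" per row, like grid = (M,)
        partial = [g * h for g, h in zip(dy[row], xhat[row])]
        with lock:           # read-modify-write; no separate store afterwards
            for j in range(n):
                dw[j] += partial[j]

    threads = [threading.Thread(target=program, args=(r,)) for r in range(len(dy))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dw
```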
kernel
The optimized kernel is similar to the tutorial's implementation:
```python
import triton
import triton.language as tl


@triton.jit
def layer_norm_backward_kernel(
    DX,      # pointer to the input gradient
    DY,      # pointer to the output gradient
    DW,      # pointer to the weight gradient (accumulated with atomic_add)
    DB,      # pointer to the bias gradient (accumulated with atomic_add)
    X,       # pointer to the input
    W,       # pointer to the weights
    Mean,    # pointer to the mean
    Rstd,    # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,       # number of columns in X
    BLOCK_COL_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_COL_SIZE)
    mask = cols < N
    X += row * stride
    DY += row * stride
    DX += row * stride
    # Load data to SRAM
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # Compute dx
    xhat = (x - mean) * rstd
    wdy = w * dy
    c1 = tl.sum(xhat * wdy, axis=0)
    c2 = tl.sum(wdy, axis=0)
    dx = (wdy - (xhat * c1 + c2) / N) * rstd
    # Write dx
    tl.store(DX + cols, dx, mask=mask)
    # Per-row contributions to dw/db
    partial_dw = (dy * xhat).to(tl.float32)
    partial_db = dy.to(tl.float32)
    # atomic_add replaces the tutorial's second (sum) kernel; the mask keeps
    # lanes with cols >= N from touching memory past the end of DW/DB
    tl.atomic_add(DW + cols, partial_dw, mask=mask)
    tl.atomic_add(DB + cols, partial_db, mask=mask)
```
launch func
```python
import logging

import torch
import triton


class LayerNorm(torch.autograd.Function):
    ...  # forward goes here

    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        logging.debug("GEMS LAYERNORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N = ctx.M, ctx.N
        # Hyperparameters taken from the tutorial; tune them for your hardware!!
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # allocate outputs
        in_grad = torch.empty_like(x)
        # float32 so that atomic_add supports the accumulation buffers' dtype
        weight_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        bias_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        layer_norm_backward_kernel[(M,)](
            in_grad, out_grad, weight_grad, bias_grad, x, weight, mean, rstd,
            N,  # stride: rows are assumed contiguous, so the row stride equals N
            N,
            BLOCK_COL_SIZE=BLOCK_SIZE, num_warps=num_warps,
        )
        weight_grad = weight_grad.to(x.dtype)
        bias_grad = bias_grad.to(x.dtype)
        return in_grad, None, weight_grad, bias_grad, None, None
```
tuning config
The current kernel has no hyperparameters that need tuning, so no tuning config is needed.
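Problem analysis
The launch function above shows that the current kernel has two main problems:
(1) M is effectively limited. The launch grid is simply (M, 1, 1), which can exceed the grid-size limit.
(2) N is limited, so the kernel cannot cover every case.
For problem (1), we split M with a time-slice loop.
For problem (2), we adapt FlagGems' official implementation (also adding a time-slice loop) as a fallback kernel, and choose which kernel to use based on the size of N.
The next section tackles these two problems in turn.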
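Time-slice loop splitting
Set the kernel's grid as follows so that it never exceeds the maximum grid size. `MAX_GRID_NUM` is a manually chosen hyperparameter that you set according to your hardware.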
```python
grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
```
With this grid, a program no longer necessarily handles exactly one [1, BLOCK_COL_SIZE] block. Each `pid` handles one or more blocks of size [BLOCK_ROW_SIZE, BLOCK_COL_SIZE], striding over the rows:
```python
pid = tl.program_id(0)
row_start = pid * BLOCK_ROW_SIZE
total_num = tl.num_programs(axis=0)  # number of programs actually launched
step = total_num * BLOCK_ROW_SIZE
cols = tl.arange(0, BLOCK_COL_SIZE)
for row in range(row_start, M, step):
    # each iteration handles one [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] block
    row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
```
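You can check this indexing without a GPU. The plain-Python sketch below (hypothetical helper names) replays the grid lambda and the loop for every `pid`, using the number of launched programs as the stride. It lists which rows each program touches, and every row should appear exactly once.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def rows_per_program(M, BLOCK_ROW_SIZE, MAX_GRID_NUM):
    """For each pid, the row blocks it visits in the time-slice loop."""
    num_programs = min(cdiv(M, BLOCK_ROW_SIZE), MAX_GRID_NUM)  # the grid lambda
    step = num_programs * BLOCK_ROW_SIZE                        # total_num * BLOCK_ROW_SIZE
    plan = {}
    for pid in range(num_programs):
        row_start = pid * BLOCK_ROW_SIZE
        blocks = []
        for row in range(row_start, M, step):
            row_off = range(row, row + BLOCK_ROW_SIZE)
            blocks.append([r for r in row_off if r < M])  # rows kept by row_mask
        plan[pid] = blocks
    return plan
```

For example, with M = 10, BLOCK_ROW_SIZE = 2 and MAX_GRID_NUM = 3, only 3 programs are launched. Program 0 handles rows [0, 1] and then [6, 7].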
kernel
Next, starting from the previous kernel, add a loop that splits `row` (i.e. M). Each iteration processes one [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] block:
```python
@triton.jit
def layer_norm_backward_kernel(
    DX,    # pointer to the input gradient
    DY,    # pointer to the output gradient
    DW,    # pointer to the weight gradient (accumulated with atomic_add)
    DB,    # pointer to the bias gradient (accumulated with atomic_add)
    X,     # pointer to the input
    W,     # pointer to the weights
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,     # number of rows in X
    N,     # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    cols = tl.arange(0, BLOCK_COL_SIZE)
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE
    col_mask = cols < N
    X += cols[None, :]
    DY += cols[None, :]
    W += cols[None, :]
    DX += cols[None, :]
    w = tl.load(W, mask=col_mask, other=0.0).to(tl.float32)
    partial_dw = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    partial_db = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = row_off < M
        # Load data to SRAM
        off = row_off[:, None] * N  # rows are contiguous, so the row stride is N
        mask = row_mask[:, None] & col_mask
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + off, mask, other=0.0).to(tl.float32)
        # other=0.0 keeps out-of-range rows finite so they add 0 to partial_dw
        mean = tl.load(Mean + row_off, mask=row_mask, other=0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + row_off, mask=row_mask, other=0.0)[:, None].to(tl.float32)
        # Compute dx
        x_hat = (x - mean) * rstd
        wdy = w * dy
        # [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] -> [BLOCK_ROW_SIZE, 1]
        c1 = tl.sum(x_hat * wdy, axis=1)[:, None]
        c2 = tl.sum(wdy, axis=1)[:, None]
        dx = (wdy - (x_hat * c1 + c2) / N) * rstd
        # Accumulate partial sums for dw/db
        partial_dw += dy * x_hat
        partial_db += dy
        # Write dx (tl.store casts to DX's element type)
        tl.store(DX + off, dx, mask=mask)
    # [BLOCK_ROW_SIZE, BLOCK_COL_SIZE] -> [BLOCK_COL_SIZE]
    dw = tl.sum(partial_dw, axis=0)
    db = tl.sum(partial_db, axis=0)
    tl.atomic_add(DW + cols, dw, mask=col_mask)
    tl.atomic_add(DB + cols, db, mask=col_mask)
```
launch func
The backward launch function is also modeled on the tutorial, with a few manually set hyperparameters:
```python
class LayerNorm(torch.autograd.Function):
    ...  # forward goes here

    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        logging.debug("GEMS LAYERNORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N = ctx.M, ctx.N
        # Hyperparameters taken from the tutorial; tune them for your hardware!!
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        in_grad = torch.empty_like(x)
        # float32 so that atomic_add supports the accumulation buffers' dtype
        weight_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        bias_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
        # MAX_GRID_NUM is a module-level hyperparameter chosen per hardware;
        # BLOCK_ROW_SIZE comes from the autotune config
        grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
        layer_norm_backward_kernel[grid](
            in_grad, out_grad, weight_grad, bias_grad, x, weight, mean, rstd,
            M, N, BLOCK_COL_SIZE=BLOCK_SIZE, num_warps=num_warps,
        )
        weight_grad = weight_grad.to(x.dtype)
        bias_grad = bias_grad.to(x.dtype)
        return in_grad, None, weight_grad, bias_grad, None, None
```
tuning config
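Once `row` is split, we need to tune `BLOCK_ROW_SIZE`. Each program still processes a full `col`, though, so `BLOCK_ROW_SIZE` cannot be made very large either. Tuning parameters are a matter of taste: pick a trade-off between compile time and performance that suits your scenario.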
```python
def cfggen_bw():
    block_m = [1, 4, 16, 32]
    # I won't suggest values for num_stages here
    num_stages = [...]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m}, num_stages=s)
        for m in block_m
        for s in num_stages
    ]  # no trailing comma: that would turn configs into a tuple holding a list
    return configs
```
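Note that once `atomic_add` is used, configuring several tuning configs at the same time causes accuracy problems. The `atomic_add` target is not reset to 0 each time a new config is tried, so contributions from different runs pile up. Add `reset_to_zero` when setting up the autotuner, roughly as follows (an expert pointed this out to me):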
```python
@libentry()  # required by FlagGems
@triton.autotune(configs=cfggen_bw(), key=["M", "N"], reset_to_zero=["DW", "DB"])
```
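A toy model shows why the missing reset corrupts the result. It only assumes that benchmarking launches the kernel several times on the same output buffers before the real launch, and that the reset also happens before that real launch. The helper names are hypothetical, and this is not Triton's autotuner code.

```python
def launch(dw, contributions):
    """One launch of a kernel that accumulates into dw, like tl.atomic_add."""
    for j, v in enumerate(contributions):
        dw[j] += v


def zero_(buf):
    for j in range(len(buf)):
        buf[j] = 0.0


def autotuned_call(dw, contributions, configs, runs_per_config, reset_to_zero):
    for _config in configs:               # benchmark every candidate config
        for _ in range(runs_per_config):  # benchmarking repeats each launch
            if reset_to_zero:
                zero_(dw)
            launch(dw, contributions)
    if reset_to_zero:
        zero_(dw)                         # clean buffer before the real launch
    launch(dw, contributions)             # the real launch with the chosen config
    return dw
```

With three configs benchmarked twice each, the buffer without a reset ends up holding seven launches' worth of gradient instead of one.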
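Problem analysis
Let's revisit the two problems raised after the first round of kernel optimization:
(1) M is effectively limited. The launch grid is simply (M, 1, 1), which can exceed the grid-size limit.
(2) N is limited, so the kernel cannot cover every case.
Problem (1) is solved above, so now we turn to problem (2). We adapt FlagGems' original official implementation as a fallback kernel.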
- kernel: the implementation is close to the official one, with time-slice splitting added, so I won't walk through it in detail.
```python
def cfggen_input_bw():
    block_m = [1, 4, 16, 32]
    block_n = [32, 256, 1024, 2048]
    # I won't suggest values for num_stages and num_warps here
    num_stages = [...]
    num_warps = [...]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w, num_stages=s)
        for m in block_m
        for n in block_n
        for s in num_stages
        for w in num_warps
    ]
    return configs


@libentry()  # required by FlagGems
@triton.autotune(configs=cfggen_input_bw(), key=["M", "N"])
@triton.jit
def input_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_ROW_SIZE
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mean = tl.load(Mean + row_off, mask=row_off < M, other=0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + row_off, mask=row_off < M, other=0.0)[:, None].to(tl.float32)
        row_mask = row_off[:, None] < M
        row_base = row_off[:, None] * N  # row stride is N (contiguous rows)
        new_dY = dY + row_base
        new_X = X + row_base
        new_DX = dX + row_base
        # Pass 1 over the columns: accumulate sum(w*dy) and sum(w*dy*x_hat) per row
        dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        for col_start in range(0, N, BLOCK_COL_SIZE):
            cols = col_start + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask & col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other=0.0).to(tl.float32)
            x_hat = (x - mean) * rstd
            w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
            wdy = dy * w
            dx_part2 += wdy
            dx_part3 += wdy * x_hat
        dx_2 = tl.sum(dx_part2, axis=1)[:, None]
        dx_3 = tl.sum(dx_part3, axis=1)[:, None]
        # Pass 2 over the columns: compute and store dx
        for col_start in range(0, N, BLOCK_COL_SIZE):
            cols = col_start + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask & col_mask
            dy = tl.load(new_dY + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + cols[None, :], mask, other=0.0).to(tl.float32)
            w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
            x_hat = (x - mean) * rstd
            wdy = dy * w
            dx = rstd * (wdy - (dx_2 + x_hat * dx_3) / N)
            tl.store(new_DX + cols[None, :], dx, mask=mask)  # cast to dX's dtype on store


def cfggen_wb_bw():
    block_m = [32, 64, 128, 512, 1024]
    block_n = [1, 4, 16, 32]
    # I won't suggest values for num_stages and num_warps here
    num_stages = [...]
    num_warps = [...]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w, num_stages=s)
        for m in block_m
        for n in block_n
        for s in num_stages
        for w in num_warps
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen_wb_bw(), key=["M", "N"])
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    col_start = pid * BLOCK_COL_SIZE
    num_jobs = tl.num_programs(axis=0)
    step = num_jobs * BLOCK_COL_SIZE
    for col in range(col_start, N, step):
        col_off = col + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = col_off < N
        new_dY = dY + col_off
        new_X = X + col_off
        new_dW = dW + col_off
        new_dB = dB + col_off
        accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
        # walk all M rows for this column block
        for off in range(0, M, BLOCK_ROW_SIZE):
            rows = off + tl.arange(0, BLOCK_ROW_SIZE)
            row_mask = rows[:, None] < M
            mask = row_mask & col_mask
            dy = tl.load(new_dY + rows[:, None] * N, mask, other=0.0).to(tl.float32)
            x = tl.load(new_X + rows[:, None] * N, mask, other=0.0).to(tl.float32)
            mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
            rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
            x_hat = (x - mean) * rstd
            accW += dy * x_hat
            accB += dy
        dw = tl.sum(accW, axis=0)
        db = tl.sum(accB, axis=0)
        tl.store(new_dW, dw[None, :], mask=col_mask)
        tl.store(new_dB, db[None, :], mask=col_mask)
```
- launch func
```python
# Manually set hyperparameters. They all depend on the hardware, and the values
# here were picked more or less at random, so don't copy them as-is
MAX_COL_LEN_BACKWARD = 16392
MAX_GRID_NUM = 65535


class LayerNorm(torch.autograd.Function):
    ...  # forward goes here

    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        logging.debug("GEMS LAYERNORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N = ctx.M, ctx.N
        in_grad = torch.empty_like(x)
        # Hyperparameters taken from the tutorial; tune them for your hardware!!
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if (N <= BLOCK_SIZE) and (BLOCK_SIZE <= MAX_COL_LEN_BACKWARD):
            # fused kernel: a whole row fits in one block
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
            # float32 so that atomic_add supports the accumulation buffers' dtype
            weight_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
            bias_grad = torch.zeros((weight.shape[0],), dtype=torch.float, device=weight.device)
            grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
            layer_norm_backward_kernel[grid](
                in_grad, out_grad, weight_grad, bias_grad, x, weight, mean, rstd,
                M, N, BLOCK_COL_SIZE=BLOCK_SIZE, num_warps=num_warps,
            )
            weight_grad = weight_grad.to(x.dtype)
            bias_grad = bias_grad.to(x.dtype)
        else:
            # fallback: two kernels
            grid = lambda META: (min(triton.cdiv(M, META['BLOCK_ROW_SIZE']), MAX_GRID_NUM),)
            # each program walks the full N for its rows
            input_backward_kernel[grid](
                out_grad, x, weight, mean, rstd, in_grad, M, N,
            )
            weight_grad = torch.empty_like(weight)
            bias_grad = torch.empty_like(weight)
            grid = lambda META: (min(triton.cdiv(N, META['BLOCK_COL_SIZE']), MAX_GRID_NUM),)
            # each program walks the full M for its columns
            weight_bias_backward_kernel[grid](
                out_grad, x, mean, rstd, weight_grad, bias_grad, M, N,
            )
        return in_grad, None, weight_grad, bias_grad, None, None
```
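Operator replacement
Simple operator replacements:
- `tl.maximum(a, 0.0)` can be replaced with `tl.where(a > 0, a, 0.0)`. Note that `tl.max` is a reduction, so the elementwise form is `tl.maximum`.
- If `x` and `y` were loaded by `tl.load` with a mask and `other=0.0`, the masked lanes already give `x - y == 0`, so a following `tl.where(mask, x - y, 0.0)` can be removed.
- Large reduce (10000->1) -> multi-level reduce (10000->100->1).
- Turn low-dimension reductions into high-dimension reductions (data fetched along the higher dimension is more contiguous).
- …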
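The multi-level reduce item has the same shape as FlagGems' `prod_kernel_mid` later in this post. A first pass writes one partial result per program into a small buffer, and a second pass reduces that buffer. Below is a sequential plain-Python sketch of the 10000->100->1 split. On a GPU the first-stage blocks would run as parallel programs, and the helper names are hypothetical.

```python
def one_level_sum(xs):
    total = 0.0
    for v in xs:  # a single program walks all elements one after another
        total += v
    return total


def two_level_sum(xs, block=100):
    # stage 1: one partial per block of `block` elements (10000 -> 100)
    partials = [sum(xs[i:i + block]) for i in range(0, len(xs), block)]
    # stage 2: reduce the partials (100 -> 1)
    return sum(partials), len(partials)
```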
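Manual hints
```python
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
```
The compiler cannot always infer the contiguity of offsets like these (for example after the `% M` wraparound), so it handles the loads as scattered accesses. If you know while writing the kernel that the data is contiguous, you can mark the contiguity of the loaded data with `tl.max_contiguous` and `tl.multiple_of`. The compiler can then process that segment contiguously.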
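`input` and `values` have the same number of dimensions:
- `max_contiguous(input, values)`: for each dimension i, marks that every `values[i]` adjacent elements of `input[i]` are contiguous.
  For example, with values = [4], input could be [0, 1, 2, 3, 8, 9, 10, 11].
- `max_constancy(input, values)`: for each dimension i, marks that every `values[i]` adjacent elements of `input[i]` are constant.
  For example, with values = [4], input could be [0, 0, 0, 0, 1, 1, 1, 1].
- `multiple_of(input, values)`: for each dimension i, marks that all elements of `input[i]` are multiples of `values[i]`.
  For example, with values = [2], input could be [0, 2, 4, 6, 8].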
```python
offs_am = tl.max_contiguous(tl.multiple_of((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M, BLOCK_SIZE_M), BLOCK_SIZE_M)
```
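To make the three promises concrete, these plain-Python checkers test a list of offsets against each hint as described in the list above, using its examples. They are hypothetical helpers and do nothing inside Triton. They can be useful for sanity-checking an offset pattern before hinting it, because the compiler trusts a hint, and a hint that does not hold may produce incorrect results.

```python
def _groups(xs, v):
    # split xs into adjacent groups of v elements (the last may be shorter)
    return [xs[i:i + v] for i in range(0, len(xs), v)]


def is_max_contiguous(xs, v):
    # every group of v adjacent elements is x, x + 1, x + 2, ...
    return all(g[k] == g[0] + k for g in _groups(xs, v) for k in range(len(g)))


def is_max_constancy(xs, v):
    # every group of v adjacent elements holds a single value
    return all(e == g[0] for g in _groups(xs, v) for e in g)


def is_multiple_of(xs, v):
    # every element is a multiple of v
    return all(e % v == 0 for e in xs)
```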
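Algorithm replacement
For example: sequential product -> binary (divide-and-conquer) multiplication.
At the algorithm level: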
```python
def mul_acc(x, l, h):
    # sequential product: every multiply depends on the previous one
    tmp = 1
    for i in range(l, h):
        tmp *= x[i]
    return tmp

# becomes:

def binary_mul(x, l, h):
    # binary product: the two halves are independent and can run in parallel
    if l >= h:
        return 1
    if h - l == 1:
        return x[l]
    mid = (l + h) // 2
    return binary_mul(x, l, mid) * binary_mul(x, mid, h)
```
Take optimizing the prod kernel in FlagGems as an example:
```python
# reduce_mul is FlagGems' combine function for tl.reduce (multiplies two values)
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul)
    mid_ptr = mid + pid
    tl.store(mid_ptr, mid_value.to(inp_val.dtype))
```
First, split it into a time-slice loop:
```python
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # running elementwise product over the blocks this program visits
    _tmp = tl.full([BLOCK_SIZE], value=1.0, dtype=tl.float32)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=1.0).to(tl.float32)
        _tmp = _tmp * inp_val
    mid_value = tl.reduce(_tmp, axis=0, combine_fn=reduce_mul)
    tl.store(mid + pid, mid_value.to(mid.dtype.element_ty))

# launch func
# grid = lambda META: (min(triton.cdiv(M, META['BLOCK_SIZE']), MAX_GRID_NUM),)
```
Then turn the product over `_tmp` into binary multiplication, replacing the `reduce_mul` reduction with an explicit halving reduction:
```python
# before:
mid_value = tl.reduce(_tmp, axis=0, combine_fn=reduce_mul)
tl.store(mid_ptr + pid, mid_value.to(inp_val.dtype))

# after (pseudocode: Triton tensors do not support slicing or slice assignment,
# so this only illustrates the idea):
# multiply the front half of _tmp by the back half and keep the result in the front half
# triton.Config({"BLOCK_SIZE": m} for m in [...])
# "ITER_NUM" is how many times BLOCK_SIZE can be divided by 2,
# e.g. BLOCK_SIZE = 16 gives ITER_NUM = 4
#  x   _tmp[:BLOCK_SIZE // (2 ** x)]   _tmp[BLOCK_SIZE // (2 ** x):BLOCK_SIZE // (2 ** (x - 1))]
#  1   _tmp[:8]                        _tmp[8:16]
#  2   _tmp[:4]                        _tmp[4:8]
#  3   _tmp[:2]                        _tmp[2:4]
#  4   _tmp[:1]                        _tmp[1:2]
for x in tl.static_range(1, int(ITER_NUM), 1):
    _tmp[:BLOCK_SIZE // (2 ** x)] = _tmp[:BLOCK_SIZE // (2 ** x)] * _tmp[BLOCK_SIZE // (2 ** x):(BLOCK_SIZE // (2 ** x)) * 2]
# reduce(_tmp[:2])
res = tl.reduce(_tmp[:BLOCK_SIZE // (2 ** (ITER_NUM - 1))], axis=0, combine_fn=reduce_mul)
tl.store(mid_ptr + pid, res)
# If every BLOCK_SIZE is a power of two and {"ITER_NUM": math.log2(m) + 1},
# _tmp[0] already holds the product and can be stored directly:
# tl.store(mid_ptr + pid, _tmp[0])
```
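The pseudocode cannot run as-is, so here is a plain-Python mirror of the same halving loop. It uses the `ITER_NUM = log2(BLOCK_SIZE) + 1` variant, so `_tmp[0]` ends up holding the full product. It assumes a power-of-two `BLOCK_SIZE` and only checks the index arithmetic, not performance. `halving_product` is a hypothetical name.

```python
import math


def halving_product(values):
    block_size = len(values)
    assert block_size > 0 and block_size & (block_size - 1) == 0, "BLOCK_SIZE must be a power of two"
    iter_num = int(math.log2(block_size)) + 1
    tmp = list(values)
    for x in range(1, iter_num):
        half = block_size // (2 ** x)
        # front half *= back half; the result stays in the front half
        for i in range(half):
            tmp[i] = tmp[i] * tmp[half + i]
    return tmp[0]
```

For [1, 2, 3, 4, 0.5, 1.5, 2, 1], the rounds give [0.5, 3, 6, 4], then [3, 12], then [36].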
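Note that this kind of parallel reduction generalizes better when it is done in the lowering of `tl.reduce` (i.e. in the compiler) than when it is hand-written in each kernel.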
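Question: why does the binary form improve performance?
In the software-pipelining stage, the binary form breaks the dependency chain apart. A sequential product is a chain of n-1 multiplies, each waiting for the previous one, while the binary form has a depth of only about log2(n). This exposes much more parallelism. The downside is that it uses more resources.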
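A rough way to quantify the dependency argument is to count the critical path: the longest chain of multiplies that must execute one after another. This sketch only does that counting (hypothetical helper names). It says nothing about how a particular compiler or GPU schedules the instructions.

```python
def sequential_depth(n):
    # tmp *= x[i] for each element: every multiply waits for the previous one
    return max(n - 1, 0)


def tree_depth(n):
    # binary_mul: the two halves are independent, then one final multiply
    if n <= 1:
        return 0
    mid = n // 2
    return 1 + max(tree_depth(mid), tree_depth(n - mid))


def multiply_count(n):
    # both forms perform the same number of multiplies; only the depth differs
    return max(n - 1, 0)
```

For 1024 elements, both forms do 1023 multiplies, but the sequential chain is 1023 steps deep while the tree is only 10.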
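Conclusion
Last light
So far, the Triton kernel optimizations in this post only cover first-pass techniques such as operator replacement, kernel fusion, and time-slice loop splitting. There is further room even within these, for example:
- When splitting into time-slice loops, the chosen tuning config already tells us whether a loop is needed at all, so a hyperparameter can decide this ahead of time. Take `max_kernel` as an example: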
```python
@triton.autotune(...)
@triton.heuristics(
    values={
        # true when the grid can cover M with one tile per program
        "ONE_TILE_PER_CTA": lambda args: args["M"] <= args["BLOCK_SIZE"] * MAX_GRID_NUM,
    },
)
@triton.jit
def max_kernel(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = -float("inf")
    if ONE_TILE_PER_CTA:
        # constexpr branch: no loop is compiled at all
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf")).to(tl.float32)
        res = tl.max(inp_val)
    else:
        # time-slice loop over the tiles assigned to this program
        _tmp = tl.full([BLOCK_SIZE], value=-float("inf"), dtype=tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=-float("inf")).to(tl.float32)
            _tmp = tl.where(_tmp > inp_val, _tmp, inp_val)
        res = tl.max(_tmp)
    tl.atomic_max(out, res.to(tl.float32))
```
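- Alternatively, when splitting M there is really no need to guard against going out of bounds along M, i.e. no M-related mask is needed. Instead, use `pid` to assign each program its exact data range.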
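The `ONE_TILE_PER_CTA` condition can be sanity-checked on the host. The launch code for `max_kernel` is not shown, so this sketch assumes it uses the same grid pattern as the layernorm kernel, `(min(cdiv(M, BLOCK_SIZE), MAX_GRID_NUM),)`. It counts how many iterations of the `else` branch's loop each program would run. The helper names are hypothetical.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def loop_trips(M, BLOCK_SIZE, MAX_GRID_NUM):
    """Iterations of `for off in range(block_start, M, step)` for every pid."""
    grid = min(cdiv(M, BLOCK_SIZE), MAX_GRID_NUM)
    step = grid * BLOCK_SIZE
    return [len(range(pid * BLOCK_SIZE, M, step)) for pid in range(grid)]


def one_tile_per_cta(M, BLOCK_SIZE, MAX_GRID_NUM):
    # same condition as the heuristic
    return M <= BLOCK_SIZE * MAX_GRID_NUM
```

For example, with M = 100, BLOCK_SIZE = 16 and MAX_GRID_NUM = 3, the heuristic is false and the three programs loop 3, 2 and 2 times. Whenever the heuristic is true, every program runs exactly one iteration, so the loop can be dropped.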
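The max kernel above reduces `inp` to a single number. What if we need to keep one dimension and reduce the rest, for example reducing a 256x65536 input to a 256x1 output?
In that case we need to split not only N (processing 65536 elements at once is too much) but also M (to reduce the total number of tasks).
For example, if we launch 256 tasks in total, each task could be assigned: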
| Data per task (rows x cols) | Computation in kernel |
|---|---|
| 1x65536 | 1 x reduce(65536)->1 + store |
| 4x16384 | 4 x reduce(16384)->1 + atomic_max |
| 64x1024 | 64 x reduce(1024)->1 + atomic_max |
| … | … |
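With that settled, the launch function looks roughly like this:
```python
# shape = [M, 1]
out = torch.full(shape, -float("inf"), dtype=torch.float32, device=inp.device)
# MAX_GRID_NUM plays the role of the total task count here (256 in the example);
# NUM_BLOCK_N programs share each group of rows
grid = lambda meta: (triton.cdiv(MAX_GRID_NUM, meta["NUM_BLOCK_N"]), meta["NUM_BLOCK_N"])
max_kernel[grid](inp, out, M, N)
```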
Next, modify `max_kernel` as follows:
```python
def cfggen():
    block_nums = [1, 4, 16, 64, ...]
    configs = [
        triton.Config({"NUM_BLOCK_N": block_num}, num_warps=..., num_stages=...)
        for block_num in block_nums
    ]
    return configs


@triton.autotune(configs=cfggen(), key=["N"])
@triton.heuristics(
    values={
        # tl.arange needs a power-of-two length, so round the tile width up
        "BLOCK_SIZE_N": lambda args: triton.next_power_of_2(
            (args["N"] + args["NUM_BLOCK_N"] - 1) // args["NUM_BLOCK_N"]
        ),
    },
)
@triton.jit
def max_kernel(
    inp,
    out,
    M,
    N,
    NUM_BLOCK_N: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # The comments below use inp (256x65536) and launch_grid = [4, 64, 1] as the example:
    # each program handles (64x1024), i.e. NUM_BLOCK_N = 64, BLOCK_SIZE_N = 1024
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_jobs_m = tl.num_programs(axis=0)  # triton.cdiv(MAX_GRID_NUM, meta["NUM_BLOCK_N"])
    row_per_job = (M + num_jobs_m - 1) // num_jobs_m  # 64 rows per program
    row_begin = pid_m * row_per_job
    # clamp to M so that no program touches rows past the end (no M mask needed)
    row_end = tl.minimum(row_begin + row_per_job, M)
    off_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = off_n < N
    for row_idx in range(row_begin, row_end):  # loops 64 times in the example
        offset = row_idx * N + off_n
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf")).to(tl.float32)
        res = tl.max(inp_val)
        tl.atomic_max(out + row_idx, res)
```
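Below is a plain-Python model of this launch, with grid `(cdiv(total_tasks, NUM_BLOCK_N), NUM_BLOCK_N)`. It checks that every element is read exactly once and that `out` matches a direct row-wise max. The model rounds `BLOCK_SIZE_N` up to a power of two (as `tl.arange` requires) and clamps each program's row range with `min(...)`. It runs sequentially, so `atomic_max` is modeled as a plain `max`. `simulate_row_max` is a hypothetical helper.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def next_power_of_2(n):
    return 1 << (n - 1).bit_length()


def simulate_row_max(inp, total_tasks, NUM_BLOCK_N):
    M, N = len(inp), len(inp[0])
    num_jobs_m = cdiv(total_tasks, NUM_BLOCK_N)            # grid axis 0
    BLOCK_SIZE_N = next_power_of_2(cdiv(N, NUM_BLOCK_N))   # heuristic, rounded up
    row_per_job = cdiv(M, num_jobs_m)
    out = [float("-inf")] * M
    visits = {}
    for pid_m in range(num_jobs_m):
        row_begin = pid_m * row_per_job
        row_end = min(row_begin + row_per_job, M)          # exact range per pid, no M mask
        for pid_n in range(NUM_BLOCK_N):
            start = pid_n * BLOCK_SIZE_N
            cols = [c for c in range(start, start + BLOCK_SIZE_N) if c < N]  # N mask
            for r in range(row_begin, row_end):
                if cols:
                    out[r] = max(out[r], max(inp[r][c] for c in cols))  # atomic_max
                for c in cols:
                    visits[(r, c)] = visits.get((r, c), 0) + 1
    return out, visits
```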
To sum up the flow, take a 2-D softmax task with m rows and n cols of data in total. The optimization goes:
- Original kernel: each program handles BLOCK_SIZE_ROW rows and a single column block (the whole row)
  - `grid = lambda META: (triton.cdiv(n_rows, META['BLOCK_SIZE_ROW']), 1, 1)`
- Split col across jobs
  - `grid = lambda META: (triton.cdiv(n_rows, META['BLOCK_SIZE_ROW']), triton.cdiv(n_cols, META['BLOCK_SIZE_COL']), 1)`
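Looking ahead
Finally, the optimizations described in this post are all fairly naive and do not include:
- modifying the lowering source code
- hardware-specific optimizations
Both are usually tied to the hardware architecture, and every vendor has its own approach. Besides, I really don't understand hardware architecture well, and the source changes can't be shared here. So I can only use this post to record my naive optimizations, and I hope more experts will share their optimization experience.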
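There are, of course, some simple optimizations too. For example, on SIMD architectures, load and store do not benefit from the memory-coalescing optimization that SIMT architectures provide, so the following help:
- Increase the amount of contiguous IO per access: `load %ptr, %mask, %other` can be rewritten as `select %mask, %load, %other` (an unmasked load followed by a select) to keep the `load` (IO operation) at high throughput. This is only valid when the unmasked addresses are safe to read.
- Reduce the actual amount of IO per access: when many address-contiguous elements are constancy (share the same value), you can first `extract_slice %ptr`, then `load`, and afterwards either broadcast or expand_shape the result.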