@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Chunk-local cumulative sum over the time axis of a [T, S] slab.

    Each program handles one [BT, BS] tile: it loads the tile from ``s``,
    multiplies it by a BT x BT triangular 0/1 matrix (so every output row is
    the sum of the tile rows up to and including it, or from it onward when
    ``REVERSE`` is set), and stores the result to the same tile of ``o``.
    The sum never crosses a tile boundary.

    Grid axes: program_id(0) selects the BS-wide column block, program_id(1)
    the chunk, program_id(2) the fused batch*head index.

    Args:
        s: Input pointer.
        o: Output pointer; addressed exactly like ``s``.
        cu_seqlens: Sequence boundary offsets, or None. When not None the
            ``IS_VARLEN`` heuristic is enabled.
        chunk_indices: Read only when ``IS_VARLEN``; entry pair ``2*i``,
            ``2*i + 1`` holds the sequence index and the chunk index within
            that sequence for program ``i`` along grid axis 1.
        T: Sequence length (overwritten per sequence when ``IS_VARLEN``).
        B: Part of the autotune key; not read by the kernel body.
        H: Number of heads.
        S: Size of the last (feature) axis.
        BT: Chunk length along time (tile rows).
        BS: Tile width along the feature axis (autotuned).
        REVERSE: Select the suffix sum instead of the prefix sum.
        IS_VARLEN: Set by the heuristic above.
        HEAD_FIRST: Selects the addressing scheme. True uses row stride ``S``
            (consistent with a contiguous [B, H, T, S] layout); False uses row
            stride ``H * S`` (consistent with [B, T, H, S]).
    """
    pid_s = tl.program_id(0)
    pid_t = tl.program_id(1)
    pid_bh = tl.program_id(2)
    head = pid_bh % H

    if IS_VARLEN:
        # Both loads must use the raw program id, so the sequence index is
        # fetched before pid_t is replaced by the in-sequence chunk index.
        seq = tl.load(chunk_indices + pid_t * 2).to(tl.int32)
        pid_t = tl.load(chunk_indices + pid_t * 2 + 1).to(tl.int32)
        seq_start = tl.load(cu_seqlens + seq).to(tl.int32)
        seq_end = tl.load(cu_seqlens + seq + 1).to(tl.int32)
        T = seq_end - seq_start
    else:
        seq_start = (pid_bh // H) * T

    # Triangular selector: lower (row >= col) gives an inclusive prefix sum,
    # upper (row <= col) gives an inclusive suffix sum.
    idx = tl.arange(0, BT)
    row = idx[:, None]
    col = idx[None, :]
    if REVERSE:
        tri = tl.where(row <= col, 1., 0.)
    else:
        tri = tl.where(row >= col, 1., 0.)

    # Tile origin inside the (T, S) slab.
    row_off = pid_t * BT
    col_off = pid_s * BS

    if HEAD_FIRST:
        slab = (seq_start * H + head * T) * S
        src = tl.make_block_ptr(
            base=s + slab,
            shape=(T, S),
            strides=(S, 1),
            offsets=(row_off, col_off),
            block_shape=(BT, BS),
            order=(1, 0),
        )
        dst = tl.make_block_ptr(
            base=o + slab,
            shape=(T, S),
            strides=(S, 1),
            offsets=(row_off, col_off),
            block_shape=(BT, BS),
            order=(1, 0),
        )
    else:
        slab = (seq_start * H + head) * S
        src = tl.make_block_ptr(
            base=s + slab,
            shape=(T, S),
            strides=(H * S, 1),
            offsets=(row_off, col_off),
            block_shape=(BT, BS),
            order=(1, 0),
        )
        dst = tl.make_block_ptr(
            base=o + slab,
            shape=(T, S),
            strides=(H * S, 1),
            offsets=(row_off, col_off),
            block_shape=(BT, BS),
            order=(1, 0),
        )

    # [BT, BS] tile, accumulated in float32.
    # NOTE(review): no padding_option is given, so in REVERSE mode the sums
    # for valid rows include out-of-range rows of the last chunk; presumably
    # those read as zero -- confirm.
    tile = tl.load(src, boundary_check=(0, 1)).to(tl.float32)
    acc = tl.dot(tri, tile, allow_tf32=False)
    tl.store(dst, acc.to(dst.dtype.element_ty), boundary_check=(0, 1))