@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr,
    out_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    dim: tl.constexpr,
    nchunks,
    seqlen,
    chunk_size: tl.constexpr,
    # Strides
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_dim: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_out_dim: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_initstates_batch: tl.int64,
    stride_initstates_head: tl.int64,
    stride_initstates_dim: tl.constexpr,
    stride_seq_idx_chunk: tl.constexpr,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Sequentially propagate per-chunk SSM states across the chunk axis.

    Each program instance handles one head (grid axis 1) and one
    BLOCK_SIZE-wide slice of the flattened state dimension (grid axis 0).
    It runs the linear recurrence over chunks c = 0..nchunks-1:

        running = exp(dA_cs[c]) * running + states[c]
        out[c]  = running

    where dA_cs[c] is the scalar loaded from the LAST position
    (chunk_size - 1) of chunk c's cumulative-dA buffer, i.e. the total
    decay accumulated over the whole chunk.  `out[c]` therefore holds the
    state AFTER chunk c has been absorbed.

    Whenever the per-chunk sequence index (seq_idx_ptr) changes, the
    running state is reset: to that sequence's initial state from
    initstates_ptr when HAS_INITSTATES, otherwise to zeros.

    NOTE(review): `seqlen` and `cu_chunk_seqlens_ptr` are never referenced
    in this kernel body — presumably kept for signature parity with related
    kernels in this file; confirm before removing.
    """
    pid_h = tl.program_id(axis=1)
    pid_m = tl.program_id(axis=0)
    # Offset base pointers to this head; dA_cs additionally points at the
    # final (chunk_size - 1) cumulative entry of each chunk.
    states_ptr += pid_h * stride_states_head
    dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
                                              1) * stride_dA_cs_csize
    out_ptr += pid_h * stride_out_head
    # This program's slice of the state dimension; offs_m < dim masks the
    # ragged tail when dim is not a multiple of BLOCK_SIZE.
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    if HAS_INITSTATES:
        # Initial state for the first sequence.  No batch-stride offset is
        # applied here, i.e. batch index 0 — consistent with prev_seq_idx
        # starting at 0 below (the first chunk is assumed to belong to
        # sequence 0).
        initstates_ptrs = initstates_ptr \
            + pid_h * stride_initstates_head \
            + offs_m * stride_initstates_dim
        states = tl.load(initstates_ptrs, mask=offs_m < dim,
                         other=0.0).to(tl.float32)
    else:
        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    prev_seq_idx = 0
    # Sequential scan over chunks (inherently serial recurrence).
    for c in range(nchunks):
        # Local (within-chunk) state contribution for chunk c.
        new_states = tl.load(states_ptrs, mask=offs_m < dim,
                             other=0.0).to(tl.float32)
        # Scalar total decay for chunk c (log-space cumulative sum value).
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
        # we have started a new sequence
        if prev_seq_idx != seq_idx:
            if HAS_INITSTATES:
                # Re-seed from this sequence's initial state (indexed by
                # seq_idx along the batch stride).
                initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \
                    + pid_h * stride_initstates_head \
                    + offs_m * stride_initstates_dim
                states = tl.load(initstates_ptrs, mask=offs_m < dim,
                                 other=0.0).to(tl.float32)
            else:
                states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
            prev_seq_idx = seq_idx
        # Recurrence step: decay the carried state, add chunk-local state.
        states = tl.exp(dA_cs) * states + new_states
        tl.store(out_ptrs, states, mask=offs_m < dim)
        # Advance all per-chunk pointers.
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk