
vllm.attention.ops.common

CPTritonContext

CPTritonContext caches the compiled kernel handle so that repeated calls avoid recompilation by the Triton JIT.

Source code in vllm/attention/ops/common.py
class CPTritonContext:
    """ The CPTritonContext is used to avoid recompilation of the Triton JIT.
    """

    def __init__(self):
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        if self.inner_kernel is None:
            self.inner_kernel = kernel[grid](*regular_args, **const_args)
        else:
            self.inner_kernel[grid](*regular_args)

inner_kernel instance-attribute

inner_kernel = None

__init__

__init__()
Source code in vllm/attention/ops/common.py
def __init__(self):
    self.inner_kernel = None

call_kernel

call_kernel(kernel, grid, *regular_args, **const_args)
Source code in vllm/attention/ops/common.py
def call_kernel(self, kernel, grid, *regular_args, **const_args):
    if self.inner_kernel is None:
        self.inner_kernel = kernel[grid](*regular_args, **const_args)
    else:
        self.inner_kernel[grid](*regular_args)
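
A minimal usage sketch, not taken from the module: the kernel, its name, and the SCALE/BLOCK constants below are hypothetical, but the calling pattern mirrors how call_kernel is meant to be used. The first call passes the constexpr arguments and triggers JIT compilation; the second call passes only the regular arguments and reuses the cached kernel handle.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, n, SCALE: tl.constexpr, BLOCK: tl.constexpr):
    # Hypothetical kernel: scales a 1-D tensor by a compile-time constant.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * SCALE, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
ctx = CPTritonContext()
grid = (triton.cdiv(x.numel(), 256), )
# First call: compiles the kernel and caches the handle on the context.
ctx.call_kernel(_scale_kernel, grid, x, out, x.numel(), SCALE=2.0, BLOCK=256)
# Later calls: only the regular arguments are passed; the cached kernel is reused.
ctx.call_kernel(_scale_kernel, grid, x, out, x.numel())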

_correct_attn_cp_out_kernel

_correct_attn_cp_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: constexpr,
    N_ROUNDED: constexpr,
)

Apply the all-gathered lses to correct each local rank's attention output. A cross-rank reduction is still required afterwards to obtain the final attention output.

Parameters:

outputs_ptr (PointerType): Pointer to the input tensor of shape [B, H, D]. Required.
lses_ptr (PointerType): Pointer to the input tensor of shape [N, B, H]. Required.
new_output_ptr (PointerType): Pointer to the output tensor of shape [B, H, D]. Required.
vlse_ptr (PointerType): Pointer to the output tensor of shape [B, H]. Required.

Source code in vllm/attention/ops/common.py
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain
    the final attention output.

    Args:
        outputs_ptr (triton.PointerType):
            Pointer to input tensor of shape [ B, H, D ]
        lses_ptr (triton.PointerType):
            Pointer to input tensor of shape [ N, B, H ]
        new_output_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H, D ]
        vlse_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H ]
    """
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # calc final lse
    lse = tl.load(lses_ptr + lse_offsets)
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = batch_idx * outputs_stride_B + \
                    head_idx * outputs_stride_H + \
                    d_offsets * outputs_stride_D

    # correct output
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)
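
For intuition, the per-(batch, head) math of this kernel can be written as a rough pure-PyTorch sketch (not part of the module): a numerically guarded log-sum-exp over the N gathered lses, followed by rescaling of the local output by exp(lse_rank - lse_final).

import torch

def _correct_attn_cp_out_reference(out: torch.Tensor, lses: torch.Tensor,
                                   cp_rank: int):
    """Rough reference of the kernel math: out is [B, H, D], lses is [N, B, H]."""
    neg_inf = float('-inf')
    # NaN / +inf entries are treated as -inf so they drop out of the reduction.
    safe = torch.where(lses.isnan() | (lses == float('inf')),
                       torch.full_like(lses, neg_inf), lses)
    lse = torch.logsumexp(safe, dim=0)          # final LSE, shape [B, H]
    scale = safe[cp_rank] - lse                 # log-space correction factor
    scale = torch.where(scale.isnan() | (scale == float('inf')),
                        torch.full_like(scale, neg_inf), scale)
    return out * scale.exp().unsqueeze(-1), lse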

_pack_seq_kernel

_pack_seq_kernel(
    x_ptr,
    out_ptr,
    lengths_ptr,
    N: constexpr,
    D: constexpr,
    Lmax: constexpr,
    PAD_VALUE: constexpr,
    BLOCK_T: constexpr,
    BLOCK_D: constexpr,
)
Source code in vllm/attention/ops/common.py
@triton.jit
def _pack_seq_kernel(
        x_ptr,  # [N, D]
        out_ptr,  # [B, Lmax, D]
        lengths_ptr,  # *i32, [B]
        N: tl.constexpr,
        D: tl.constexpr,
        Lmax: tl.constexpr,
        PAD_VALUE: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
                                                   None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)

_unpack_seq_triton_kernel

_unpack_seq_triton_kernel(
    packed_ptr,
    out_ptr,
    lengths_ptr,
    B: constexpr,
    Lmax: constexpr,
    D: constexpr,
    BLOCK_T: constexpr,
    BLOCK_D: constexpr,
)
Source code in vllm/attention/ops/common.py
@triton.jit
def _unpack_seq_triton_kernel(
        packed_ptr,  # [B, Lmax, D]
        out_ptr,  # [N, D]
        lengths_ptr,  # *i32, [B]
        B: tl.constexpr,
        Lmax: tl.constexpr,
        D: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax +
                                   off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)

correct_attn_out

correct_attn_out(
    out: Tensor,
    lses: Tensor,
    cp_rank: int,
    ctx: CPTritonContext,
) -> tuple[Tensor, Tensor]

Correct the attention output using the all-gathered lses.

Parameters:

out (Tensor): Tensor of shape [B, H, D]. Required.
lses (Tensor): Tensor of shape [N, B, H]. Required.
cp_rank (int): Current rank in the context-parallel group. Required.
ctx (CPTritonContext): Triton context used to avoid recompilation. Required.

Returns:

tuple[Tensor, Tensor]: Tuple of (out, lse) with the corrected attention output and the final log-sum-exp.

Source code in vllm/attention/ops/common.py
def correct_attn_out(
        out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
        ctx: CPTritonContext) -> tuple[torch.Tensor, torch.Tensor]:
    """Correct the attention output using the all-gathered lses.

    Args:
        out: Tensor of shape [ B, H, D ]
        lses: Tensor of shape [ N, B, H ]
        cp_rank: Current rank in the context-parallel group
        ctx: Triton context to avoid recompilation

    Returns:
        Tuple of (out, lse) with corrected attention and final log-sum-exp.
    """
    if ctx is None:
        ctx = CPTritonContext()

    lse = torch.empty_like(lses[0])

    grid = (out.shape[0], out.shape[1], 1)
    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
                    cp_rank)
    const_args = {
        "HEAD_DIM": out.shape[-1],
        "N_ROUNDED": lses.shape[0],
    }

    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
                    **const_args)
    return out, lse
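
A hedged usage sketch with illustrative shapes (a CUDA device is required since the correction runs as a Triton kernel):

import torch

B, H, D, N = 4, 8, 128, 2   # N = context-parallel world size (illustrative)
out = torch.randn(B, H, D, device="cuda")
lses = torch.randn(N, B, H, device="cuda")
out, lse = correct_attn_out(out, lses, cp_rank=0, ctx=CPTritonContext())
assert lse.shape == (B, H)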

cp_lse_ag_out_rs

cp_lse_ag_out_rs(
    cp_attn_out: Tensor,
    cp_attn_lse: Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext = None,
)

cp_attn_out: [B, H, D]
cp_attn_lse: [B, H]

Source code in vllm/attention/ops/common.py
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: CPTritonContext = None):
    """
    cp_attn_out: [ B, H, D ]
    cp_attn_lse: [ B, H ]
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
                       dtype=cp_attn_lse.dtype,
                       device=cp_attn_lse.device)

    cp_attn_lse = cp_attn_lse.contiguous()
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = cp_group.reduce_scatter(out, dim=1)
    return out
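
For orientation, the same all-gather / correct / reduce-scatter flow can be sketched against raw torch.distributed collectives. This is an assumption-laden reference, not the module's implementation: vLLM's GroupCoordinator wraps these collectives, and the sketch assumes an initialized process group and a head count divisible by the world size.

import torch
import torch.distributed as dist

def cp_lse_ag_out_rs_reference(cp_attn_out: torch.Tensor,
                               cp_attn_lse: torch.Tensor,
                               group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return cp_attn_out
    # All-gather the per-rank LSEs into a [N, B, H] tensor.
    gathered = [torch.empty_like(cp_attn_lse) for _ in range(world_size)]
    dist.all_gather(gathered, cp_attn_lse.contiguous(), group=group)
    lses = torch.stack(gathered)
    # Rescale this rank's partial output by exp(lse_rank - lse_final).
    out, _ = correct_attn_out(cp_attn_out, lses, dist.get_rank(group),
                              CPTritonContext())
    # Reduce-scatter along the head dimension: sum the partial outputs
    # across ranks and keep only this rank's shard of heads.
    chunks = [c.contiguous() for c in out.chunk(world_size, dim=1)]
    shard = torch.empty_like(chunks[0])
    dist.reduce_scatter(shard, chunks, group=group)
    return shard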

pack_seq_triton

pack_seq_triton(
    x: Tensor,
    lengths: Tensor,
    pad_value: float = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> Tensor

Pack sequences of different lengths into a batched tensor.

Parameters:

x (Tensor): [N, ...] input tensor, where N is the total number of tokens. Required.
lengths (Tensor): [B] sequence lengths for each batch. Required.
pad_value (float): Value used for padding. Default: -float('inf').
block_t (int): Block size for the time dimension. Default: 64.
block_d (int): Block size for the feature dimension. Default: 64.

Returns:

packed (Tensor): [B, Lmax, ...] packed tensor.

Source code in vllm/attention/ops/common.py
def pack_seq_triton(x: torch.Tensor,
                    lengths: torch.Tensor,
                    pad_value: float = -float('inf'),
                    block_t: int = 64,
                    block_d: int = 64) -> torch.Tensor:
    """
    Pack sequences of different lengths into a batched tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens
        lengths: [B] - sequence lengths for each batch
        pad_value: value to use for padding
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor
    """

    # Handle multi-dimensional input by reshaping to (N, -1)
    original_shape = x.shape
    if len(original_shape) > 2:
        N = original_shape[0]
        x_reshaped = x.reshape(N, -1)
        D = x_reshaped.shape[1]
    else:
        N, D = x.shape
        x_reshaped = x

    B = lengths.numel()
    Lmax = int(lengths.max().item())

    # Starts are computed inside the kernel from lengths

    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _pack_seq_kernel[grid](x_reshaped,
                           out,
                           lengths.int(),
                           N,
                           D,
                           Lmax,
                           PAD_VALUE=float(pad_value),
                           BLOCK_T=block_t,
                           BLOCK_D=block_d,
                           num_warps=4,
                           num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 2:
        output_shape = (B, Lmax) + original_shape[1:]
        out = out.reshape(output_shape)

    return out
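
A small usage sketch (shapes are illustrative; a CUDA device is required for the Triton kernel):

import torch

lengths = torch.tensor([3, 5, 2], device="cuda", dtype=torch.int32)
x = torch.randn(int(lengths.sum().item()), 8, 16, device="cuda")  # [N=10, 8, 16]
packed = pack_seq_triton(x, lengths)  # [B=3, Lmax=5, 8, 16]; padded slots hold -inf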

unpack_seq_triton

unpack_seq_triton(
    packed_tensor: Tensor,
    lengths: Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> Tensor

Unpack a packed decode query tensor back to the original format using an efficient Triton implementation.

Parameters:

packed_tensor (Tensor): [B, Lmax, ...] packed tensor from pack_seq_triton. Required.
lengths (Tensor): [B] sequence lengths for each batch. Required.
block_t (int): Block size for the time dimension. Default: 64.
block_d (int): Block size for the feature dimension. Default: 64.

Returns:

unpacked_tensor (Tensor): [N, ...] unpacked tensor, where N = sum(lengths).

Source code in vllm/attention/ops/common.py
def unpack_seq_triton(packed_tensor: torch.Tensor,
                      lengths: torch.Tensor,
                      block_t: int = 64,
                      block_d: int = 64) -> torch.Tensor:
    """
    Unpack a packed decode query tensor back to the original format.
    Efficient Triton implementation.

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths)
    """

    # Handle multi-dimensional input by reshaping to (B, Lmax, -1)
    original_shape = packed_tensor.shape
    if len(original_shape) > 3:
        B, Lmax = original_shape[:2]
        packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
        D = packed_reshaped.shape[2]
    else:
        B, Lmax, D = packed_tensor.shape
        packed_reshaped = packed_tensor

    # Calculate total number of elements
    N = int(lengths.sum().item())

    out = torch.empty((N, D),
                      device=packed_tensor.device,
                      dtype=packed_tensor.dtype)

    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _unpack_seq_triton_kernel[grid](packed_reshaped,
                                    out,
                                    lengths.int(),
                                    B,
                                    Lmax,
                                    D,
                                    BLOCK_T=block_t,
                                    BLOCK_D=block_d,
                                    num_warps=4,
                                    num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 3:
        output_shape = (N, ) + original_shape[2:]
        out = out.reshape(output_shape)

    return out
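
A round-trip sketch (same illustrative setup as the pack_seq_triton example above): unpacking recovers exactly the valid rows of the original tensor.

import torch

lengths = torch.tensor([3, 5, 2], device="cuda", dtype=torch.int32)
x = torch.randn(int(lengths.sum().item()), 8, 16, device="cuda")  # [N=10, 8, 16]
packed = pack_seq_triton(x, lengths)           # [3, 5, 8, 16]
unpacked = unpack_seq_triton(packed, lengths)  # [10, 8, 16]
assert torch.equal(unpacked, x)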