vllm.attention.ops.common ¶
CPTritonContext ¶
CPTritonContext is used to avoid recompiling the Triton JIT kernel on every call.
Source code in vllm/attention/ops/common.py
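A minimal usage sketch, assuming the shapes documented for correct_attn_out below: the same context object is reused across calls so the underlying Triton kernel is compiled only once. Tensor names and sizes here are illustrative, and a CUDA device is assumed.

```python
import torch

from vllm.attention.ops.common import CPTritonContext, correct_attn_out

# Illustrative sizes: batch, heads, head dim, context-parallel world size.
B, H, D, N = 4, 8, 128, 2

ctx = CPTritonContext()  # create once, reuse across calls

for _ in range(3):  # e.g. repeated decode steps
    out = torch.randn(B, H, D, device="cuda")
    lses = torch.randn(N, B, H, device="cuda")
    corrected, lse = correct_attn_out(out, lses, cp_rank=0, ctx=ctx)
```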
_correct_attn_cp_out_kernel ¶
_correct_attn_cp_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: constexpr,
    N_ROUNDED: constexpr,
)
Apply the all-gathered lses to correct each local rank's attention output. A cross-rank reduction is still needed afterwards to obtain the final attention output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
outputs_ptr | PointerType | Pointer to input tensor of shape [ B, H, D ] | required |
lses_ptr | PointerType | Pointer to input tensor of shape [ N, B, H ] | required |
new_output_ptr | PointerType | Pointer to output tensor of shape [ B, H, D ] | required |
vlse_ptr | PointerType | Pointer to output tensor of shape [ B, H ] | required |
Source code in vllm/attention/ops/common.py
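The per-rank correction this kernel applies can be written as a plain PyTorch reference. The sketch below is hedged: it assumes the standard log-sum-exp recombination (rescale the local partial output by exp(lse_local - logsumexp over ranks)) and ignores any special-case handling the Triton kernel may do; `correct_attn_out_reference` is a hypothetical helper name.

```python
import torch


def correct_attn_out_reference(
    outputs: torch.Tensor,  # [B, H, D] local partial attention output
    lses: torch.Tensor,     # [N, B, H] all-gathered log-sum-exp values
    lse_idx: int,           # this rank's index into lses
) -> tuple[torch.Tensor, torch.Tensor]:
    # Combined log-sum-exp across all context-parallel ranks: [B, H].
    vlse = torch.logsumexp(lses, dim=0)
    # Per-(batch, head) weight of this rank's contribution: [B, H].
    scale = torch.exp(lses[lse_idx] - vlse)
    # Rescaled local output; summing these across ranks gives the final output.
    new_output = outputs * scale.unsqueeze(-1)
    return new_output, vlse
```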
_pack_seq_kernel ¶
_pack_seq_kernel(
    x_ptr,
    out_ptr,
    lengths_ptr,
    N: constexpr,
    D: constexpr,
    Lmax: constexpr,
    PAD_VALUE: constexpr,
    BLOCK_T: constexpr,
    BLOCK_D: constexpr,
)
Source code in vllm/attention/ops/common.py
_unpack_seq_triton_kernel ¶
_unpack_seq_triton_kernel(
    packed_ptr,
    out_ptr,
    lengths_ptr,
    B: constexpr,
    Lmax: constexpr,
    D: constexpr,
    BLOCK_T: constexpr,
    BLOCK_D: constexpr,
)
Source code in vllm/attention/ops/common.py
correct_attn_out ¶
correct_attn_out(
    out: Tensor,
    lses: Tensor,
    cp_rank: int,
    ctx: CPTritonContext,
) -> tuple[Tensor, Tensor]
Correct the attention output using the all-gathered lses.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
out | Tensor | Tensor of shape [ B, H, D ] | required |
lses | Tensor | Tensor of shape [ N, B, H ] | required |
cp_rank | int | Current rank in the context-parallel group | required |
ctx | CPTritonContext | Triton context to avoid recompilation | required |
Returns:
Type | Description |
---|---|
tuple[Tensor, Tensor] | Tuple of (out, lse) with corrected attention and final log-sum-exp. |
Source code in vllm/attention/ops/common.py
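A hedged single-process sketch of how the corrected outputs combine: it simulates N ranks with random partial outputs, corrects each with correct_attn_out, then sums across ranks for the final result. In a real deployment the reduction happens across devices (e.g. via reduce-scatter); sizes below are illustrative and a CUDA device is assumed.

```python
import torch

from vllm.attention.ops.common import CPTritonContext, correct_attn_out

N, B, H, D = 2, 4, 8, 128  # illustrative sizes
partial_outs = [torch.randn(B, H, D, device="cuda") for _ in range(N)]
lses = torch.randn(N, B, H, device="cuda")  # stand-in for all-gathered LSEs

ctx = CPTritonContext()
corrected = []
for rank in range(N):
    out_r, lse = correct_attn_out(partial_outs[rank], lses, cp_rank=rank, ctx=ctx)
    corrected.append(out_r)

# Cross-rank reduction: sum of the rescaled partial outputs.
final_out = torch.stack(corrected).sum(dim=0)
```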
cp_lse_ag_out_rs ¶
cp_lse_ag_out_rs(
    cp_attn_out: Tensor,
    cp_attn_lse: Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext = None,
)
cp_attn_out: [B, H, D]
cp_attn_lse: [B, H]
Source code in vllm/attention/ops/common.py
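A hedged sketch of where this helper sits in a context-parallel attention backend. `cp_group` is assumed to be the GroupCoordinator for the context-parallel ranks (how it is obtained depends on the distributed setup), and the wrapper name below is hypothetical.

```python
import torch

from vllm.attention.ops.common import CPTritonContext, cp_lse_ag_out_rs
from vllm.distributed.parallel_state import GroupCoordinator

_CTX = CPTritonContext()  # reused so the Triton kernel is compiled only once


def combine_cp_attention(
    cp_attn_out: torch.Tensor,   # [B, H, D] this rank's partial attention output
    cp_attn_lse: torch.Tensor,   # [B, H] this rank's log-sum-exp
    cp_group: GroupCoordinator,  # context-parallel process group
) -> torch.Tensor:
    # Combine the per-rank partial results across the context-parallel group.
    return cp_lse_ag_out_rs(cp_attn_out, cp_attn_lse, cp_group, ctx=_CTX)
```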
pack_seq_triton ¶
pack_seq_triton(
    x: Tensor,
    lengths: Tensor,
    pad_value: float = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> Tensor
Pack sequences of different lengths into a batched tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | [N, ...] - input tensor where N is the total number of tokens | required |
lengths | Tensor | [B] - sequence lengths for each batch | required |
pad_value | float | value to use for padding | -float('inf') |
block_t | int | block size for time dimension | 64 |
block_d | int | block size for feature dimension | 64 |
Returns:
Name | Type | Description |
---|---|---|
packed | Tensor | [B, Lmax, ...] - packed tensor |
Source code in vllm/attention/ops/common.py
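A minimal usage sketch with illustrative shapes, assuming a CUDA device (the Triton kernel runs on GPU): a flat [N, H, D] tensor holding all tokens of three sequences is packed into a padded [B, Lmax, H, D] batch.

```python
import torch

from vllm.attention.ops.common import pack_seq_triton

lengths = torch.tensor([3, 5, 2], device="cuda")  # B = 3 sequences
H, D = 8, 64
x = torch.randn(int(lengths.sum()), H, D, device="cuda")  # N = 10 tokens total

# Packed shape: [B, Lmax, H, D] = [3, 5, 8, 64]; positions past each
# sequence's length are filled with pad_value (default -inf).
packed = pack_seq_triton(x, lengths)
```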
unpack_seq_triton ¶
unpack_seq_triton(
    packed_tensor: Tensor,
    lengths: Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> Tensor
Unpack a packed decode query tensor back to the original [N, ...] format using an efficient Triton implementation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
packed_tensor | Tensor | [B, Lmax, ...] - packed tensor from pack_seq_triton | required |
lengths | Tensor | [B] - sequence lengths for each batch | required |
block_t | int | block size for time dimension | 64 |
block_d | int | block size for feature dimension | 64 |
Returns:
Name | Type | Description |
---|---|---|
unpacked_tensor | Tensor | [N, ...] where N = sum(lengths) |
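A hedged round-trip sketch: unpacking with the same lengths used for packing is expected to recover the original flat [N, ...] tensor. Shapes are illustrative and a CUDA device is assumed.

```python
import torch

from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton

lengths = torch.tensor([3, 5, 2], device="cuda")
x = torch.randn(int(lengths.sum()), 8, 64, device="cuda")  # [N, H, D]

packed = pack_seq_triton(x, lengths)           # [B, Lmax, H, D]
unpacked = unpack_seq_triton(packed, lengths)  # [N, H, D]

# Expected to hold: padding is dropped and the original values are restored.
assert torch.allclose(unpacked, x)
```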