
vllm.attention.ops.flashmla

_flashmla_C_AVAILABLE module-attribute

_flashmla_C_AVAILABLE = True

_flashmla_extension_C_AVAILABLE module-attribute

_flashmla_extension_C_AVAILABLE = True

logger module-attribute

logger = init_logger(__name__)

flash_mla_sparse_prefill

flash_mla_sparse_prefill(
    q: Tensor,
    kv: Tensor,
    indices: Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[Tensor, Tensor, Tensor]

Sparse attention prefill kernel

Args:

  • q: [s_q, h_q, d_qk], bfloat16
  • kv: [s_kv, h_kv, d_qk], bfloat16
  • indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
  • sm_scale: float
  • d_v: The dimension of value vectors. Can only be 512

Returns:

  • (output, max_logits, lse). For the definitions of output, max_logits and lse, please refer to README.md
  • output: [s_q, h_q, d_v], bfloat16
  • max_logits: [s_q, h_q], float
  • lse: [s_q, h_q], float, 2-based log-sum-exp
Source code in vllm/attention/ops/flashmla.py
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32. 
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        For the definitions of output, max_logits
        and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits:  [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
                                                       sm_scale, d_v)
    return results
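
Example (a minimal sketch, not part of the source): the shapes below (d_qk = 576, h_kv = 1, topk = 128) and the choice of sm_scale = 1/sqrt(d_qk) are illustrative assumptions; only the dtypes and the d_v = 512 constraint come from the docstring above.

import torch
from vllm.attention.ops.flashmla import flash_mla_sparse_prefill

# Illustrative shapes; d_qk = 576 (512 latent + 64 rope) and h_kv = 1 are assumptions.
s_q, s_kv, h_q, h_kv, d_qk, topk = 4, 1024, 16, 1, 576, 128

q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")
# Per-query top-k token indices; slots that should be ignored are set to -1.
indices = torch.randint(0, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device="cuda")

# sm_scale chosen as 1/sqrt(d_qk) purely for illustration.
out, max_logits, lse = flash_mla_sparse_prefill(q, kv, indices, sm_scale=d_qk**-0.5)
# out: [s_q, h_q, 512] bfloat16; max_logits, lse: [s_q, h_q] float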

flash_mla_with_kvcache

flash_mla_with_kvcache(
    q: Tensor,
    k_cache: Tensor,
    block_table: Tensor,
    cache_seqlens: Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: Tensor,
    num_splits: Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[Tensor] = None,
    descale_k: Optional[Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]

Arguments:

  • q: (batch_size, seq_len_q, num_heads_q, head_dim).
  • k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
  • block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
  • cache_seqlens: (batch_size), torch.int32.
  • head_dim_v: Head dimension of v.
  • tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
  • num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
  • softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
  • causal: bool. Whether to apply a causal attention mask.
  • descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
  • descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
  • is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of the FP8 KV cache, please refer to README.md.
  • indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the indices array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up indices, please refer to README.md.

Returns:

  • out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
  • softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

Source code in vllm/attention/ops/flashmla.py
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata: 
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, 
        returned by get_mla_metadata.
    - num_splits: 
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float. 
        The scale of QK^T before applying softmax. 
        Defaults to 1 / sqrt(head_dim).
    - causal: bool. Whether to apply causal attention mask.
    - descale_q: (batch_size), 
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size), 
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool. 
        Whether the k_cache and v_cache are in fp8 format. 
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32. 
        If not None, sparse attention will be enabled, 
        and only tokens in the `indices` array will be attended to. 
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv. 
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)
    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal,
        # since each query only attends to the tokens before it,
        # but `causal` should not be specified here.
        assert not causal, \
            "causal must be `False` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    if indices is None and q.element_size() == 1:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
            indices)
    return out, softmax_lse
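
Example (a hedged decode-time sketch, not part of the source): combining get_mla_metadata with flash_mla_with_kvcache. The page size, block count, head counts, and the 576/512 head dimensions are illustrative assumptions about the surrounding MLA setup, not requirements stated by this module.

import torch
from vllm.attention.ops.flashmla import get_mla_metadata, flash_mla_with_kvcache

batch_size, seq_len_q, num_heads_q, num_heads_k = 2, 1, 128, 1
head_dim, head_dim_v = 576, 512           # assumed MLA layout: 512 latent + 64 rope
page_block_size, blocks_per_seq = 64, 16  # assumed paged-KV-cache layout

cache_seqlens = torch.full((batch_size,), 500, dtype=torch.int32, device="cuda")
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)

q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(batch_size * blocks_per_seq, page_block_size, num_heads_k,
                      head_dim, dtype=torch.bfloat16, device="cuda")
# Trivial block table: each sequence owns a contiguous run of blocks.
block_table = torch.arange(batch_size * blocks_per_seq, dtype=torch.int32,
                           device="cuda").reshape(batch_size, blocks_per_seq)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits, causal=True)
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v)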

get_mla_metadata

get_mla_metadata(
    cache_seqlens: Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: Optional[int] = None,
    is_fp8_kvcache: bool = False,
    topk: Optional[int] = None,
) -> Tuple[Tensor, Tensor]

Arguments:

  • cache_seqlens: (batch_size), dtype torch.int32.
  • num_q_tokens_per_head_k: Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
  • num_heads_k: The number of k heads.
  • num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled.
  • is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
  • topk: If not None, sparse attention will be enabled, and only tokens in the indices array passed to flash_mla_with_kvcache_sm90 will be attended to.

Returns:

  • tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
  • num_splits: (batch_size + 1), dtype torch.int32.
Source code in vllm/attention/ops/flashmla.py
def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k: 
            Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q: 
            The number of q heads. 
            This argument is optional when sparse attention is not enabled
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled, 
            and only tokens in the `indices` array 
            passed to `flash_mla_with_kvcache_sm90` will be attended to.

    Returns:
    - tile_scheduler_metadata: 
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    return torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
        is_fp8_kvcache, topk)
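
Example (an illustrative sketch, not part of the source): the sparse-attention variant of this call. Passing topk enables sparse attention; num_heads_q is assumed to be required in that case, following the note above that it is optional only when sparse attention is not enabled. All concrete values are placeholders.

import torch
from vllm.attention.ops.flashmla import get_mla_metadata

cache_seqlens = torch.tensor([1024, 2048], dtype=torch.int32, device="cuda")
seq_len_q, num_heads_q, num_heads_k, topk = 1, 128, 1, 2048

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
    num_heads_k=num_heads_k,
    num_heads_q=num_heads_q,  # assumed required once topk is given
    topk=topk,                # enables sparse attention in the decode kernel
)
# tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize)
# num_splits: (batch_size + 1,)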

is_flashmla_supported

is_flashmla_supported() -> Tuple[bool, Optional[str]]

Return: is_supported_flag, unsupported_reason (optional).

Source code in vllm/attention/ops/flashmla.py
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not current_platform.is_cuda():
        return False, "FlashMLA is only supported on CUDA devices."
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."
    return True, None
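
Example (a minimal sketch, not part of the source): guarding the FlashMLA code path with this check; the fallback handling shown is purely illustrative.

from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # Fall back to another attention backend (illustrative handling).
    print(f"FlashMLA unavailable: {reason}")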