vllm.attention.ops.flashmla ¶
flash_mla_sparse_prefill ¶
flash_mla_sparse_prefill(
q: Tensor,
kv: Tensor,
indices: Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[Tensor, Tensor, Tensor]
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or to numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512

Returns:
- (output, max_logits, lse). For the definitions of output, max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, base-2 log-sum-exp
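The rule for invalid entries in indices can be illustrated with a small pure-Python sketch. The helper name `pad_topk_indices` is hypothetical and not part of vLLM or FlashMLA. It handles a single KV head and assumes the resulting nested list would later be converted to an int32 tensor and given the h_kv axis to match [s_q, h_kv, topk].

```python
from typing import List, Sequence


def pad_topk_indices(
    selected: Sequence[Sequence[int]], topk: int, s_kv: int
) -> List[List[int]]:
    """Pad each query's selected KV positions to exactly `topk` entries.

    Hypothetical helper for illustration only. Unused slots are filled
    with -1, which the docstring above lists as an invalid index.
    """
    padded = []
    for row in selected:
        # Keep only positions that exist in kv, then truncate to topk.
        valid = [i for i in row if 0 <= i < s_kv][:topk]
        padded.append(valid + [-1] * (topk - len(valid)))
    return padded


rows = pad_topk_indices([[0, 3], [5, 1, 2, 9]], topk=4, s_kv=6)
print(rows)  # [[0, 3, -1, -1], [5, 1, 2, -1]]
```

Dropping out-of-range positions up front is a design choice of this sketch; according to the docstring, the kernel also treats numbers >= s_kv as invalid.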
Source code in vllm/attention/ops/flashmla.py
flash_mla_with_kvcache ¶
flash_mla_with_kvcache(
q: Tensor,
k_cache: Tensor,
block_table: Tensor,
cache_seqlens: Tensor,
head_dim_v: int,
tile_scheduler_metadata: Tensor,
num_splits: Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[Tensor] = None,
descale_k: Optional[Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]
Arguments:
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
- num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float. The scale applied to QK^T before softmax. Defaults to 1 / sqrt(head_dim).
- causal: bool. Whether to apply a causal attention mask.
- descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of the FP8 KV cache, please refer to README.md.
- indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention is enabled and only tokens in the indices array are attended to. Invalid indices should be set to -1 or to numbers >= total_seq_len_kv. For details about how to set up indices, please refer to README.md.

Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
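The documented shapes and the default softmax_scale can be checked with plain arithmetic before launching the kernel. The sketch below uses two hypothetical helpers, `expected_shapes` and `default_softmax_scale`, that are not part of vLLM. The concrete sizes (128 query heads, head_dim 576, head_dim_v 512) are illustrative assumptions, not values taken from this page.

```python
import math
from typing import Tuple

Shape = Tuple[int, ...]


def expected_shapes(q_shape: Shape, head_dim_v: int) -> Tuple[Shape, Shape]:
    """Return the (out, softmax_lse) shapes documented for a given q shape."""
    batch_size, seq_len_q, num_heads_q, _head_dim = q_shape
    out_shape = (batch_size, seq_len_q, num_heads_q, head_dim_v)
    # Note the axis order: heads come before the query length in softmax_lse.
    lse_shape = (batch_size, num_heads_q, seq_len_q)
    return out_shape, lse_shape


def default_softmax_scale(head_dim: int) -> float:
    """Documented default when softmax_scale is None: 1 / sqrt(head_dim)."""
    return 1.0 / math.sqrt(head_dim)


# Illustrative sizes only (assumed, not from the documentation above).
out_shape, lse_shape = expected_shapes((2, 1, 128, 576), head_dim_v=512)
print(out_shape)  # (2, 1, 128, 512)
print(lse_shape)  # (2, 128, 1)
print(default_softmax_scale(576))
```

A check like this is cheap to run in a unit test and catches the transposed head and sequence axes in softmax_lse early.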
Source code in vllm/attention/ops/flashmla.py
get_mla_metadata ¶
get_mla_metadata(
cache_seqlens: Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None,
) -> Tuple[Tensor, Tensor]
Arguments:
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k: Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled.
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention is enabled and only tokens in the indices array passed to flash_mla_with_kvcache_sm90 are attended to.

Returns:
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
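The formula for num_q_tokens_per_head_k is easy to get wrong when the head counts differ, so a worked example may help. The function below is a hypothetical helper that only restates the documented expression, and the head counts used are illustrative assumptions.

```python
def num_q_tokens_per_head_k(
    num_q_tokens_per_q_seq: int, num_heads_q: int, num_heads_k: int
) -> int:
    """Documented formula: num_q_tokens_per_q_seq * num_heads_q // num_heads_k."""
    # Python evaluates * and // left to right, so the product is formed
    # first and the integer division is applied to it.
    return num_q_tokens_per_q_seq * num_heads_q // num_heads_k


# One query token per sequence, 128 query heads sharing 1 k head (assumed sizes).
print(num_q_tokens_per_head_k(1, 128, 1))  # 128
# Two query tokens per sequence, 16 query heads, 1 k head.
print(num_q_tokens_per_head_k(2, 16, 1))   # 32
```

The result is the value that would be passed as the second argument of get_mla_metadata.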
Source code in vllm/attention/ops/flashmla.py
is_flashmla_supported ¶
Returns: (is_supported_flag, unsupported_reason). unsupported_reason is optional.