vllm.v1.attention.backends.mla.flashmla_sparse

logger module-attribute

logger = init_logger(__name__)

NOTE: FlashMLA Sparse uses an FP8 KV cache with the following format.

In the "FP8 with scale" format, each token's KV cache entry is 656 bytes, structured as:

- First 512 bytes: the "quantized NoPE" part, containing 512 float8_e4m3 values.
- Next 16 bytes: scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
- Last 128 bytes: the "RoPE" part, containing 64 bfloat16 values. This part is left unquantized to preserve accuracy.

FlashMLASparseBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
class FlashMLASparseBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        return FlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        return FlashMLASparseImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if cache_dtype_str == "fp8_ds_mla":
            # custom storage format is 656 bytes
            #  see FlashMLA readme.md for details
            return (num_blocks, block_size, 656)
        else:
            return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[FlashMLASparseMetadataBuilder]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
    return FlashMLASparseMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[FlashMLASparseImpl]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_impl_cls() -> type["FlashMLASparseImpl"]:
    return FlashMLASparseImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,  # assumed to be 1 for MLA
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    if cache_dtype_str == "fp8_ds_mla":
        # custom storage format is 656 bytes
        #  see FlashMLA readme.md for details
        return (num_blocks, block_size, 656)
    else:
        return (num_blocks, block_size, head_size)
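
For example (the block counts here are illustrative):

num_blocks, block_size = 1024, 64
FlashMLASparseBackend.get_kv_cache_shape(num_blocks, block_size,
                                         num_kv_heads=1, head_size=576)
# -> (1024, 64, 576): one bf16 latent + rope vector per token slot
FlashMLASparseBackend.get_kv_cache_shape(num_blocks, block_size,
                                         num_kv_heads=1, head_size=576,
                                         cache_dtype_str="fp8_ds_mla")
# -> (1024, 64, 656): one raw 656-byte entry per token slot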

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
    return FlashMLASparseMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_name() -> str:
    return "FLASHMLA_SPARSE_VLLM_V1"

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16]

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return [576]
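
The single supported head size follows from the MLA cache layout described at the top of the page: a 512-wide latent ("NoPE") part concatenated with a 64-wide RoPE part. A quick arithmetic check (the variable names are illustrative):

kv_lora_rank = 512       # latent (NoPE) width, quantized in the fp8 cache
qk_rope_head_dim = 64    # RoPE width, kept in bf16
assert kv_lora_rank + qk_rope_head_dim == 576  # the one supported head size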

FlashMLASparseDecodeAndContextMetadata dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FlashMLASparseDecodeAndContextMetadata:
    scheduler_metadata: torch.Tensor = None
    num_splits: torch.Tensor = None
    cache_lens: torch.Tensor = None
    prefill_context_lengths: Optional[torch.Tensor] = None
    prefill_new_k_start_locs: Optional[torch.Tensor] = None
    dummy_block_table: torch.Tensor = None

    def filter_prefill_indices(
            self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        assert self.prefill_context_lengths is not None
        prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
        context_indices = torch.where(indices < prefill_context_lengths,
                                      indices, -1)
        new_token_indices = torch.where(indices >= prefill_context_lengths,
                                        indices - prefill_context_lengths, -1)
        return context_indices, new_token_indices

cache_lens class-attribute instance-attribute

cache_lens: Tensor = None

dummy_block_table class-attribute instance-attribute

dummy_block_table: Tensor = None

num_splits class-attribute instance-attribute

num_splits: Tensor = None

prefill_context_lengths class-attribute instance-attribute

prefill_context_lengths: Optional[Tensor] = None

prefill_new_k_start_locs class-attribute instance-attribute

prefill_new_k_start_locs: Optional[Tensor] = None

scheduler_metadata class-attribute instance-attribute

scheduler_metadata: Tensor = None

__init__

__init__(
    scheduler_metadata: Tensor = None,
    num_splits: Tensor = None,
    cache_lens: Tensor = None,
    prefill_context_lengths: Optional[Tensor] = None,
    prefill_new_k_start_locs: Optional[Tensor] = None,
    dummy_block_table: Tensor = None,
) -> None

filter_prefill_indices

filter_prefill_indices(
    indices: Tensor,
) -> tuple[Tensor, Tensor]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def filter_prefill_indices(
        self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    assert self.prefill_context_lengths is not None
    prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
    context_indices = torch.where(indices < prefill_context_lengths,
                                  indices, -1)
    new_token_indices = torch.where(indices >= prefill_context_lengths,
                                    indices - prefill_context_lengths, -1)
    return context_indices, new_token_indices
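
A quick illustration, assuming one prefill request with 4 tokens of context. Note that -1 sentinels come out as -1 in both halves:

import torch

meta = FlashMLASparseDecodeAndContextMetadata(
    prefill_context_lengths=torch.tensor([4]))
indices = torch.tensor([[0, 3, 4, 6, -1]])
ctx, new = meta.filter_prefill_indices(indices)
# ctx -> tensor([[ 0,  3, -1, -1, -1]])  (indices into the cached context)
# new -> tensor([[-1, -1,  0,  2, -1]])  (indices re-based onto the new tokens)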

FlashMLASparseImpl

Bases: MLACommonBaseImpl[FlashMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        self.padding = 128 if current_platform.is_device_capability(
            100) else 64

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1])

        # NOTE(Chen): the kernel requires num_local_heads to be a multiple
        # of 64 on Hopper and 128 on Blackwell
        if self.num_heads % self.padding != 0:
            assert self.padding % self.num_heads == 0
            logger.warning_once(
                f"padding num_heads to {self.padding} "
                "due to sparse attn kernel requirement")
            q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
            q_padded[:, :self.num_heads, :] = q
            q = q_padded

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices,
                                          self.softmax_scale)[0]
        output = output[:, :self.num_heads, :]
        return output

    def _forward_fp8_kv(self, q: torch.Tensor,
                        kv_c_and_k_pe_cache: torch.Tensor,
                        topk_indices: torch.Tensor,
                        attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:

        assert attn_metadata.fp8_extra_metadata is not None
        extra_metadata = attn_metadata.fp8_extra_metadata

        _attn_out, _ = flash_mla_with_kvcache(
            q=q.unsqueeze(0),  # unsqueeze to add batch_dim
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=extra_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=extra_metadata.cache_lens,
            tile_scheduler_metadata=extra_metadata.scheduler_metadata,
            num_splits=extra_metadata.num_splits,
            is_fp8_kvcache=True,
            indices=topk_indices.unsqueeze(0),  # unsqueeze to add batch_dim
            softmax_scale=self.softmax_scale,
        )

        return _attn_out

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(lucas): the sparse FlashMLA kernels want to use the
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs

        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # TODO: handle index / kv_cache correctly
        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if self.kv_cache_dtype != "fp8_ds_mla":
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global,
                                             attn_metadata)
        else:
            attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
                                            attn_metadata)

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output

padding instance-attribute

padding = 128 if is_device_capability(100) else 64

softmax_scale instance-attribute

softmax_scale = scale

topk_indices_buffer instance-attribute

topk_indices_buffer = topk_indices_buffer

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    topk_indice_buffer: Optional[Tensor] = None,
    indexer: Optional[Indexer] = None,
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        topk_indice_buffer: Optional[torch.Tensor] = None,
        indexer: Optional["Indexer"] = None,
        **mla_args) -> None:
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)
    self.softmax_scale = scale
    assert indexer is not None
    self.topk_indices_buffer = indexer.topk_indices_buffer
    self.padding = 128 if current_platform.is_device_capability(
        100) else 64

_forward_bf16_kv

_forward_bf16_kv(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _forward_bf16_kv(
        self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
    num_tokens = q.shape[0]
    kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
        -1, 1, kv_c_and_k_pe_cache.shape[-1])

    # NOTE(Chen): the kernel requires num_local_heads to be a multiple
    # of 64 on Hopper and 128 on Blackwell
    if self.num_heads % self.padding != 0:
        assert self.padding % self.num_heads == 0
        logger.warning_once(
            f"padding num_heads to {self.padding} "
            "due to sparse attn kernel requirement")
        q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
        q_padded[:, :self.num_heads, :] = q
        q = q_padded

    topk_indices = topk_indices.view(num_tokens, 1, -1)
    output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices,
                                      self.softmax_scale)[0]
    output = output[:, :self.num_heads, :]
    return output
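
For example, with 16 local query heads on Hopper (padding = 64) and illustrative shapes, q is copied into an uninitialized (num_tokens, 64, 576) buffer; the kernel runs over all 64 head slots and the padded slots are sliced away afterwards:

import torch

num_tokens, num_heads, head_dim, padding = 8, 16, 576, 64
q = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
q_padded = q.new_empty((num_tokens, padding, head_dim))
q_padded[:, :num_heads, :] = q  # real heads first; the rest stays uninitialized
# output = kernel(q_padded, ...)[:, :num_heads, :]  # drop the padded slots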

_forward_fp8_kv

_forward_fp8_kv(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _forward_fp8_kv(self, q: torch.Tensor,
                    kv_c_and_k_pe_cache: torch.Tensor,
                    topk_indices: torch.Tensor,
                    attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:

    assert attn_metadata.fp8_extra_metadata is not None
    extra_metadata = attn_metadata.fp8_extra_metadata

    _attn_out, _ = flash_mla_with_kvcache(
        q=q.unsqueeze(0),  # unsqueeze to add batch_dim
        k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
        block_table=extra_metadata.dummy_block_table,
        head_dim_v=512,
        cache_seqlens=extra_metadata.cache_lens,
        tile_scheduler_metadata=extra_metadata.scheduler_metadata,
        num_splits=extra_metadata.num_splits,
        is_fp8_kvcache=True,
        indices=topk_indices.unsqueeze(0),  # unsqueeze to add batch_dim
        softmax_scale=self.softmax_scale,
    )

    return _attn_out

forward

forward(
    layer: AttentionLayer,
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashMLASparseMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
    output_block_scale: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # NOTE(lucas): the sparse FlashMLA kernels want to use the
    # MQA 576/512 approach for both prefill and decode

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for MLACommonImpl")

    if attn_metadata is None:
        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs

    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                           dim=-1)
    # Convert from (B, N, P) to (N, B, P)
    q_nope = q_nope.transpose(0, 1)
    # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
    ql_nope = torch.bmm(q_nope, self.W_UK_T)
    # Convert from (N, B, L) to (B, N, L)
    ql_nope = ql_nope.transpose(0, 1)

    topk_indices = self.topk_indices_buffer[:num_actual_toks]

    # TODO: handle index / kv_cache correctly
    topk_indices_global = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
    )

    q = torch.cat([ql_nope, q_pe], dim=-1)

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    if self.kv_cache_dtype != "fp8_ds_mla":
        attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global,
                                         attn_metadata)
    else:
        attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
                                        attn_metadata)

    self._v_up_proj(attn_out, out=output[:num_actual_toks])
    return output
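
The query "absorption" step in the middle of forward is a batched matmul between the per-head NoPE queries and the absorbed up-projection W_UK_T; a shape-only sketch with illustrative dims (P = qk_nope_head_dim, L = kv_lora_rank):

import torch

B, N, P, L = 8, 16, 128, 512    # tokens, heads, nope dim, latent dim
q_nope = torch.randn(B, N, P)
W_UK_T = torch.randn(N, P, L)   # per-head absorbed up-projection
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)
assert ql_nope.shape == (B, N, L)  # latent-space queries, ready for MQA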

FlashMLASparseMetadata dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FlashMLASparseMetadata:
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8KernelMetadata:
        scheduler_metadata: Optional[torch.Tensor]
        num_splits: torch.Tensor
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    fp8_extra_metadata: Optional[FP8KernelMetadata] = None

block_size class-attribute instance-attribute

block_size: int = 64

block_table instance-attribute

block_table: Tensor

fp8_extra_metadata class-attribute instance-attribute

fp8_extra_metadata: Optional[FP8KernelMetadata] = None

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_reqs instance-attribute

num_reqs: int

query_start_loc instance-attribute

query_start_loc: Tensor

req_id_per_token instance-attribute

req_id_per_token: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

topk_tokens class-attribute instance-attribute

topk_tokens: int = 2048

FP8KernelMetadata dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FP8KernelMetadata:
    scheduler_metadata: Optional[torch.Tensor]
    num_splits: torch.Tensor
    dummy_block_table: torch.Tensor
    cache_lens: torch.Tensor

cache_lens instance-attribute

cache_lens: Tensor

dummy_block_table instance-attribute

dummy_block_table: Tensor

num_splits instance-attribute

num_splits: Tensor

scheduler_metadata instance-attribute

scheduler_metadata: Optional[Tensor]

__init__

__init__(
    scheduler_metadata: Optional[Tensor],
    num_splits: Tensor,
    dummy_block_table: Tensor,
    cache_lens: Tensor,
) -> None

__init__

__init__(
    num_reqs: int,
    max_query_len: int,
    max_seq_len: int,
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    block_table: Tensor,
    req_id_per_token: Tensor,
    block_size: int = 64,
    topk_tokens: int = 2048,
    fp8_extra_metadata: Optional[FP8KernelMetadata] = None,
) -> None

FlashMLASparseMetadataBuilder dataclass

Bases: AttentionMetadataBuilder[FlashMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FlashMLASparseMetadataBuilder(
        AttentionMetadataBuilder[FlashMLASparseMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):

        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        props = torch.cuda.get_device_properties(device)
        sm_count = props.multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(
            parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
                                               device=device,
                                               dtype=torch.int32)
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len],
            device=device,
            dtype=torch.int32)
        # this is ignored by `flash_mla_with_kvcache` if indices is not None
        self.dummy_block_table = torch.empty((1, 1),
                                             dtype=torch.int32,
                                             device=self.device)

        # Equation taken from FlashMLA/csrc/pybind.cpp
        h_q, h_k = self.num_heads, 1
        s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
        max_num_sm_parts = int(
            max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
        if current_platform.is_device_capability(100):
            max_num_sm_parts *= 2
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device)
        self.num_splits_buffer = torch.empty(
            # We pack all the tokens into one batch for sparse attention.
            # Otherwise, we can exceed the SM limit of `get_mla_metadata`.
            (
                2, ),
            dtype=torch.int32,
            device=device)
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens, ),
            dtype=torch.int32,
            device=device)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashMLASparseMetadata:

        num_tokens = common_attn_metadata.num_actual_tokens
        starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                            dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
            .copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata = None
        if self.use_fp8_kv_cache:
            tile_scheduler_metadata, num_splits = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor,
                num_q_tokens_per_head_k=num_tokens * self.num_heads,
                topk=self.topk_tokens,
                num_heads_q=self.num_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            num_sm_parts = tile_scheduler_metadata.size(0)
            # Copy to persistent buffer for full-CG support
            tile_scheduler_metadata_buffer = \
                self.tile_scheduler_metadata_buffer[:num_sm_parts]
            tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
            self.num_splits_buffer.copy_(num_splits)

            fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=tile_scheduler_metadata_buffer,
                num_splits=self.num_splits_buffer,
                # cache_lens and block_table are basically unused in the
                # sparse case, but the decode kernel treats -1 and indices
                # >= cache_lens as invalid, so we make cache_lens large
                # enough that it never accidentally marks indices invalid;
                # we use -1 exclusively to mark invalid indices
                cache_lens=self.max_model_len_tensor,
                dummy_block_table=self.dummy_block_table)

        metadata = FlashMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
        )
        return metadata

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = UNIFORM_BATCH

device instance-attribute

device = device

dummy_block_table instance-attribute

dummy_block_table = empty(
    (1, 1), dtype=int32, device=device
)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

max_model_len_tensor instance-attribute

max_model_len_tensor = tensor(
    [max_model_len], device=device, dtype=int32
)

mla_dims instance-attribute

mla_dims = get_mla_dims(model_config)

model_config instance-attribute

model_config = model_config

num_heads instance-attribute

num_heads = get_num_attention_heads(parallel_config)

num_splits_buffer instance-attribute

num_splits_buffer = empty((2,), dtype=int32, device=device)

req_id_per_token_buffer instance-attribute

req_id_per_token_buffer = empty(
    (max_num_batched_tokens,), dtype=int32, device=device
)

tile_scheduler_metadata_buffer instance-attribute

tile_scheduler_metadata_buffer = empty(
    (max_num_sm_parts, 8), dtype=int32, device=device
)

topk_tokens instance-attribute

topk_tokens = index_topk

topk_tokens_tensor instance-attribute

topk_tokens_tensor = tensor(
    [topk_tokens], device=device, dtype=int32
)

use_fp8_kv_cache instance-attribute

use_fp8_kv_cache = cache_dtype == 'fp8_ds_mla'

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):

    cache_config = vllm_config.cache_config
    self.kv_cache_spec = kv_cache_spec
    self.model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    self.device = device

    props = torch.cuda.get_device_properties(device)
    sm_count = props.multi_processor_count

    self.num_heads = self.model_config.get_num_attention_heads(
        parallel_config)
    self.mla_dims = get_mla_dims(self.model_config)
    self.topk_tokens = vllm_config.model_config.hf_config.index_topk
    self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
    self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
                                           device=device,
                                           dtype=torch.int32)
    self.max_model_len_tensor = torch.tensor(
        [self.model_config.max_model_len],
        device=device,
        dtype=torch.int32)
    # this is ignored by `flash_mla_with_kvcache` if indices is not None
    self.dummy_block_table = torch.empty((1, 1),
                                         dtype=torch.int32,
                                         device=self.device)

    # Equation taken from FlashMLA/csrc/pybind.cpp
    h_q, h_k = self.num_heads, 1
    s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
    max_num_sm_parts = int(
        max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
    if current_platform.is_device_capability(100):
        max_num_sm_parts *= 2
    self.tile_scheduler_metadata_buffer = torch.empty(
        # TileSchedulerMetaDataSize = 8
        # see: FlashMLA/csrc/params.h
        (max_num_sm_parts, 8),
        dtype=torch.int32,
        device=device)
    self.num_splits_buffer = torch.empty(
        # We pack all the tokens into one batch for sparse attention.
        # Otherwise, we can exceed the SM limit of `get_mla_metadata`.
        (
            2, ),
        dtype=torch.int32,
        device=device)
    self.req_id_per_token_buffer = torch.empty(
        (vllm_config.scheduler_config.max_num_batched_tokens, ),
        dtype=torch.int32,
        device=device)
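
To make the sizing concrete, a worked instance of the max_num_sm_parts equation; the sm_count and head count below are assumptions for illustration (e.g. an H100-class GPU with 128 query heads at TP=8), not values taken from this module:

def cdiv(a: int, b: int) -> int:
    # ceiling division, matching vLLM's cdiv helper
    return -(a // -b)

sm_count, h_q, h_k, s_q = 132, 16, 1, 1
max_num_sm_parts = int(
    max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
# cdiv(16, 128) == 1, so max_num_sm_parts == 66 (doubled again on SM 10.0)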

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashMLASparseMetadata
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> FlashMLASparseMetadata:

    num_tokens = common_attn_metadata.num_actual_tokens
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                        dtype=np.int32)
    seg_lengths = np.diff(starts)
    req_id_per_token = np.repeat(
        np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
    # Zero-fill for cudagraphs
    self.req_id_per_token_buffer.fill_(0)
    self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
        .copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
    req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

    fp8_extra_metadata = None
    if self.use_fp8_kv_cache:
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor,
            num_q_tokens_per_head_k=num_tokens * self.num_heads,
            topk=self.topk_tokens,
            num_heads_q=self.num_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        num_sm_parts = tile_scheduler_metadata.size(0)
        # Copy to persistent buffer for full-CG support
        tile_scheduler_metadata_buffer = \
            self.tile_scheduler_metadata_buffer[:num_sm_parts]
        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
        self.num_splits_buffer.copy_(num_splits)

        fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=tile_scheduler_metadata_buffer,
            num_splits=self.num_splits_buffer,
            # cache_lens and block_table are basically unused in the
            # sparse case, but the decode kernel treats -1 and indices
            # >= cache_lens as invalid, so we make cache_lens large
            # enough that it never accidentally marks indices invalid;
            # we use -1 exclusively to mark invalid indices
            cache_lens=self.max_model_len_tensor,
            dummy_block_table=self.dummy_block_table)

    metadata = FlashMLASparseMetadata(
        num_reqs=common_attn_metadata.num_reqs,
        max_query_len=common_attn_metadata.max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        query_start_loc=common_attn_metadata.query_start_loc,
        slot_mapping=common_attn_metadata.slot_mapping,
        block_table=common_attn_metadata.block_table_tensor,
        req_id_per_token=req_id_per_token,
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
        fp8_extra_metadata=fp8_extra_metadata,
    )
    return metadata
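
The req_id_per_token expansion at the top of build is a plain repeat over per-request query lengths; for example:

import numpy as np

starts = np.asarray([0, 3, 5, 9], dtype=np.int32)  # query_start_loc, 3 requests
seg_lengths = np.diff(starts)                      # [3, 2, 4]
req_id_per_token = np.repeat(
    np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
# -> [0, 0, 0, 1, 1, 2, 2, 2, 2]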

MLASparsePrefillMetadata dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class MLASparsePrefillMetadata:
    # NOTE(Chen): not called "FlashMLASparsePrefillMetadata" because
    # the kernel is not from FlashMLA
    block_table: torch.Tensor
    has_context: bool = False
    context_lens: Optional[torch.Tensor] = None

block_table instance-attribute

block_table: Tensor

context_lens class-attribute instance-attribute

context_lens: Optional[Tensor] = None

has_context class-attribute instance-attribute

has_context: bool = False

__init__

__init__(
    block_table: Tensor,
    has_context: bool = False,
    context_lens: Optional[Tensor] = None,
) -> None

_convert_req_index_to_global_index_kernel

_convert_req_index_to_global_index_kernel(
    req_id_ptr,
    block_table_ptr,
    token_indices_ptr,
    out_ptr,
    max_num_blocks_per_req: constexpr,
    BLOCK_SIZE: constexpr,
    BLOCK_N: constexpr,
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
)
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Only token == -1 should propagate as -1
    is_invalid_tok = tok < 0

    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = block_id < max_num_blocks_per_req
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    base = tl.load(bt_ptr, mask=valid_block, other=0)

    # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
    out_val = tl.where(is_invalid_tok | (~valid_block), -1,
                       base * BLOCK_SIZE + inblock_off)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)

_lse2_to_lse

_lse2_to_lse(lse_base2: Tensor) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
    # Convert base-2 LSE to natural-log LSE
    # Keep FP32 for numerical stability during the merge.
    return (lse_base2.to(torch.float32) * math.log(2.0))
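
The conversion rests on the identity log2(x) * ln 2 = ln(x); a quick numerical check:

import math
import torch

scores = torch.randn(4, 8)
lse_nat = torch.logsumexp(scores, dim=-1)           # ln(sum(exp(s)))
lse_base2 = torch.log2(torch.exp(scores).sum(-1))   # log2(sum(exp(s)))
assert torch.allclose(lse_base2 * math.log(2.0), lse_nat, atol=1e-5)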

triton_convert_req_index_to_global_index

triton_convert_req_index_to_global_index(
    req_id: Tensor,
    block_table: Tensor,
    token_indices: Tensor,
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,
)

out[token_id, indice_id] =
    block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
    + token_indices[token_id, indice_id] % BLOCK_SIZE

Only when token_indices[token_id, indice_id] == -1 do we output -1. For safety, we also output -1 if the derived block_id would be out-of-bounds.
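
Equivalently, in plain PyTorch (a reference sketch of the mapping the Triton kernel computes, without the tiling; the function name is illustrative):

import torch

def convert_req_index_to_global_index_ref(
        req_id: torch.Tensor,         # int32 [num_tokens]
        block_table: torch.Tensor,    # int32 [num_requests, max_blocks]
        token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        BLOCK_SIZE: int = 64) -> torch.Tensor:
    max_blocks = block_table.shape[1]
    block_id = token_indices // BLOCK_SIZE
    valid = (token_indices >= 0) & (block_id < max_blocks)
    base = block_table[req_id[:, None], block_id.clamp(0, max_blocks - 1)]
    return torch.where(valid, base * BLOCK_SIZE + token_indices % BLOCK_SIZE, -1)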

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def triton_convert_req_index_to_global_index(
        req_id: torch.Tensor,  # int32 [num_tokens]
        block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
        token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        BLOCK_SIZE: int = 64,
        NUM_TOPK_TOKENS: int = 2048,
        BLOCK_N: int = 128,  # tile width along columns
):
    """
    out[token_id, indice_id] =
        block_table[req_id[token_id], 
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Only when token_indices[token_id, indice_id] == -1 do we output -1.
    For safety, we also output -1 if the derived block_id would be 
        out-of-bounds.
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \
        f"BLOCK_N ({BLOCK_N})"

    num_tokens = req_id.shape[0]
    num_requests, max_num_blocks_per_req = block_table.shape
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)

    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out