
vllm.v1.attention.backends.mla.indexer

logger module-attribute

logger = init_logger(__name__)

DeepSeekV32IndexerDecodeMetadata dataclass

Source code in vllm/v1/attention/backends/mla/indexer.py
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    decode_lens: torch.Tensor
    requires_padding: bool
    schedule_metadata: torch.Tensor

block_table instance-attribute

block_table: Tensor

decode_lens instance-attribute

decode_lens: Tensor

requires_padding instance-attribute

requires_padding: bool

schedule_metadata instance-attribute

schedule_metadata: Tensor

seq_lens instance-attribute

seq_lens: Tensor

__init__

__init__(
    block_table: Tensor,
    seq_lens: Tensor,
    decode_lens: Tensor,
    requires_padding: bool,
    schedule_metadata: Tensor,
) -> None
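
A minimal construction sketch with hypothetical sizes (4 decode requests, an assumed block-table width and SM count); in practice these tensors are filled in by DeepseekV32IndexerMetadataBuilder.build() below, and the dtypes shown here are assumptions for illustration.

import torch

from vllm.v1.attention.backends.mla.indexer import DeepSeekV32IndexerDecodeMetadata

# Hypothetical sizes for illustration only.
num_decodes, max_blocks_per_req, num_sms = 4, 8, 132

decode_meta = DeepSeekV32IndexerDecodeMetadata(
    block_table=torch.zeros(num_decodes, max_blocks_per_req, dtype=torch.int32),
    seq_lens=torch.tensor([17, 33, 65, 9], dtype=torch.int32),
    decode_lens=torch.ones(num_decodes, dtype=torch.int32),
    requires_padding=False,
    # Shape (num_sms + 1, 2), mirroring the builder's schedule buffer.
    schedule_metadata=torch.zeros(num_sms + 1, 2, dtype=torch.int32),
)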

DeepseekV32IndexerBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mla/indexer.py
class DeepseekV32IndexerBackend(AttentionBackend):

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return DeepseekV32IndexerMetadata

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 128]

    @staticmethod
    def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
        return DeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        assert num_kv_heads == 1
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        return (0, 1, 2)

get_builder_cls staticmethod

get_builder_cls() -> type[
    DeepseekV32IndexerMetadataBuilder
]
Source code in vllm/v1/attention/backends/mla/indexer.py
@staticmethod
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
    return DeepseekV32IndexerMetadataBuilder

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/indexer.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    assert num_kv_heads == 1
    return (num_blocks, block_size, head_size)
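
A small usage sketch. The block count, block size, head size and cache dtype are hypothetical; only num_kv_heads == 1 is required by the assert above.

import torch

from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend

# Hypothetical cache geometry.
shape = DeepseekV32IndexerBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=64, num_kv_heads=1, head_size=128)
assert shape == (1024, 64, 128)

# dtype is an assumption for illustration; the real cache dtype is
# selected elsewhere based on cache_dtype_str.
indexer_kv_cache = torch.empty(shape, dtype=torch.uint8)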

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order() -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/indexer.py
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    return (0, 1, 2)

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/mla/indexer.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return DeepseekV32IndexerMetadata

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/mla/indexer.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return [32, 64, 128]

DeepseekV32IndexerMetadata dataclass

Source code in vllm/v1/attention/backends/mla/indexer.py
@dataclass
class DeepseekV32IndexerMetadata:

    # FIXME (zyongye)
    # hacky way to access the data now, need to be in chunked meta
    seq_lens: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    # The dimension of the attention heads
    head_dim: int

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
    prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None

decode class-attribute instance-attribute

decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None

head_dim instance-attribute

head_dim: int

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

num_reqs instance-attribute

num_reqs: int

prefill class-attribute instance-attribute

prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None

query_start_loc instance-attribute

query_start_loc: Tensor

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

__init__

__init__(
    seq_lens: Tensor,
    num_reqs: int,
    max_query_len: int,
    max_seq_len: int,
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    head_dim: int,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    num_prefill_tokens: int,
    decode: Optional[
        DeepSeekV32IndexerDecodeMetadata
    ] = None,
    prefill: Optional[
        DeepseekV32IndexerPrefillMetadata
    ] = None,
) -> None

DeepseekV32IndexerMetadataBuilder

Bases: AttentionMetadataBuilder

Source code in vllm/v1/attention/backends/mla/indexer.py
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        #NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(
            self.vllm_config)
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config else 0)
        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

        props = torch.cuda.get_device_properties(self.device)
        sm_count = props.multi_processor_count
        self.num_sms = sm_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs, ),
            dtype=torch.int32,
            device=self.device)

        # See: DeepGMM/csrc/apis/attention.hpp
        self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2),
                                                     dtype=torch.int32,
                                                     device=self.device)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:

        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        device = self.device
        block_table_tensor = common_attn_metadata.block_table_tensor

        query_start_loc = common_attn_metadata.query_start_loc

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold)

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]
            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
                prefill_query_start_loc,
                common_attn_metadata.seq_lens[reqs_start:])
            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
            assert total_seq_lens < self.max_prefill_buffer_size
            cu_seq_lens = torch.cat([
                torch.zeros(1, dtype=torch.int32, device=device),
                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
            ]).to(torch.int32).cuda()
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=common_attn_metadata.max_query_len,
                cu_seqlen_ks=cu_seqlen_ks,
                cu_seqlen_ke=cu_seqlen_ke,
                cu_seq_lens=cu_seq_lens,
                total_seq_lens=total_seq_lens,
            )

        decode_metadata = None
        if num_decodes > 0:
            torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
                       out=self.decode_lens_buffer[:num_decodes])
            decode_lens = self.decode_lens_buffer[:num_decodes]
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])

            # Use CPU to avoid GPU sync; breaking async scheduling
            requires_padding = (decode_lens_cpu.max()
                                > decode_lens_cpu.min()).item()

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]

            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                seq_lens, self.kv_cache_spec.block_size, self.num_sms)
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=common_attn_metadata.
                block_table_tensor[:num_decodes, ...],
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        attn_metadata = DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        # if get_tensor_model_parallel_rank() == 0:
        #     logger.info(f"attn_metadata: {attn_metadata}")
        return attn_metadata

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = UNIFORM_BATCH

decode_lens_buffer instance-attribute

decode_lens_buffer = empty(
    (max_num_seqs,), dtype=int32, device=device
)

max_prefill_buffer_size instance-attribute

max_prefill_buffer_size = get_max_prefill_buffer_size(
    vllm_config
)

num_sms instance-attribute

num_sms = sm_count

num_speculative_tokens instance-attribute

num_speculative_tokens = (
    num_speculative_tokens if speculative_config else 0
)

reorder_batch_threshold class-attribute instance-attribute

reorder_batch_threshold: int = 1

scheduler_metadata_buffer instance-attribute

scheduler_metadata_buffer = empty(
    (num_sms + 1, 2), dtype=int32, device=device
)

__init__

__init__(*args, **kwargs)
Source code in vllm/v1/attention/backends/mla/indexer.py
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    scheduler_config = self.vllm_config.scheduler_config
    #NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
    self.max_prefill_buffer_size = get_max_prefill_buffer_size(
        self.vllm_config)
    self.num_speculative_tokens = (
        self.vllm_config.speculative_config.num_speculative_tokens
        if self.vllm_config.speculative_config else 0)
    # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
    self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

    props = torch.cuda.get_device_properties(self.device)
    sm_count = props.multi_processor_count
    self.num_sms = sm_count

    self.decode_lens_buffer = torch.empty(
        (scheduler_config.max_num_seqs, ),
        dtype=torch.int32,
        device=self.device)

    # See: DeepGMM/csrc/apis/attention.hpp
    self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2),
                                                 dtype=torch.int32,
                                                 device=self.device)
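
The threshold adjustment caps the decode query length at next_n = 2: for example, with a hypothetical num_speculative_tokens = 3 the builder still only adds min(3, 1) = 1, giving reorder_batch_threshold = 2, which matches the next_n > 2 limitation of deepgemm fp8_paged_mqa_logits noted in the comment above.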

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> DeepseekV32IndexerMetadata
Source code in vllm/v1/attention/backends/mla/indexer.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> DeepseekV32IndexerMetadata:

    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens

    device = self.device
    block_table_tensor = common_attn_metadata.block_table_tensor

    query_start_loc = common_attn_metadata.query_start_loc

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold)

    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_tokens

    prefill_metadata = None
    if num_prefills > 0:
        reqs_start = num_decodes
        prefill_query_start_loc = query_start_loc[
            reqs_start:] - query_start_loc[reqs_start]
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc,
            common_attn_metadata.seq_lens[reqs_start:])
        total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
        assert total_seq_lens < self.max_prefill_buffer_size
        cu_seq_lens = torch.cat([
            torch.zeros(1, dtype=torch.int32, device=device),
            common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
        ]).to(torch.int32).cuda()
        prefill_metadata = DeepseekV32IndexerPrefillMetadata(
            block_table=block_table_tensor[reqs_start:, ...],
            query_start_loc=prefill_query_start_loc,
            max_query_len=common_attn_metadata.max_query_len,
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            total_seq_lens=total_seq_lens,
        )

    decode_metadata = None
    if num_decodes > 0:
        torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
                   out=self.decode_lens_buffer[:num_decodes])
        decode_lens = self.decode_lens_buffer[:num_decodes]
        decode_lens_cpu = torch.diff(
            common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])

        # Use CPU to avoid GPU sync; breaking async scheduling
        requires_padding = (decode_lens_cpu.max()
                            > decode_lens_cpu.min()).item()

        seq_lens = common_attn_metadata.seq_lens[:num_decodes]

        self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
            seq_lens, self.kv_cache_spec.block_size, self.num_sms)
        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.
            block_table_tensor[:num_decodes, ...],
            seq_lens=common_attn_metadata.seq_lens[:num_decodes],
            decode_lens=decode_lens,
            requires_padding=requires_padding,
            schedule_metadata=self.scheduler_metadata_buffer,
        )

    attn_metadata = DeepseekV32IndexerMetadata(
        seq_lens=common_attn_metadata.seq_lens,
        num_reqs=common_attn_metadata.num_reqs,
        max_query_len=common_attn_metadata.max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        query_start_loc=common_attn_metadata.query_start_loc,
        slot_mapping=common_attn_metadata.slot_mapping,
        head_dim=128,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        prefill=prefill_metadata,
        decode=decode_metadata,
    )

    # if get_tensor_model_parallel_rank() == 0:
    #     logger.info(f"attn_metadata: {attn_metadata}")
    return attn_metadata
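
To illustrate the decode-side bookkeeping in build() in isolation, here is a minimal sketch assuming a hypothetical query_start_loc_cpu for three decode requests with 1, 2 and 1 query tokens (e.g. uneven speculative acceptance). It shows how the decode lengths and requires_padding are derived on the CPU copy so no GPU sync is forced.

import torch

# Hypothetical CPU-side cumulative query starts for 3 decode requests.
query_start_loc_cpu = torch.tensor([0, 1, 3, 4], dtype=torch.int32)
num_decodes = 3

decode_lens_cpu = torch.diff(query_start_loc_cpu[:num_decodes + 1])
# Padding is only needed when decode lengths differ across requests.
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
# decode_lens_cpu -> tensor([1, 2, 1], dtype=torch.int32)
# requires_padding -> True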

DeepseekV32IndexerPrefillMetadata dataclass

Source code in vllm/v1/attention/backends/mla/indexer.py
@dataclass
class DeepseekV32IndexerPrefillMetadata:
    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    max_query_len: int
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    cu_seq_lens: torch.Tensor
    total_seq_lens: int

block_table instance-attribute

block_table: Tensor

cu_seq_lens instance-attribute

cu_seq_lens: Tensor

cu_seqlen_ke instance-attribute

cu_seqlen_ke: Tensor

cu_seqlen_ks instance-attribute

cu_seqlen_ks: Tensor

max_query_len instance-attribute

max_query_len: int

query_start_loc instance-attribute

query_start_loc: Tensor

total_seq_lens instance-attribute

total_seq_lens: int

__init__

__init__(
    block_table: Tensor,
    query_start_loc: Tensor,
    max_query_len: int,
    cu_seqlen_ks: Tensor,
    cu_seqlen_ke: Tensor,
    cu_seq_lens: Tensor,
    total_seq_lens: int,
) -> None

get_max_prefill_buffer_size

get_max_prefill_buffer_size(vllm_config: VllmConfig)
Source code in vllm/v1/attention/backends/mla/indexer.py
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    max_model_len = vllm_config.model_config.max_model_len
    # max_num_batched_tokens = \
    #     vllm_config.scheduler_config.max_num_batched_tokens
    max_num_seq = vllm_config.scheduler_config.max_num_seqs
    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
    return max_model_len * max_num_seq
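
As a worked example with hypothetical limits, max_model_len = 32768 and max_num_seqs = 16 give an estimated flattened-KV bound of 32768 * 16 = 524288 token slots; the assert in build() checks that the summed prefill sequence lengths stay below this bound.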

kv_spans_from_batches

kv_spans_from_batches(
    start_seq_loc: Tensor, seq_len_per_batch: Tensor
) -> tuple[Tensor, Tensor]

Parameters:

start_seq_loc (Tensor, required): 1D long tensor [B+1], cumulative counts of selected tokens per batch. Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total.

seq_len_per_batch (Tensor, required): 1D long tensor [B], full sequence length (KV length) of each batch. Example: [5, 9, 4].

Returns:

start_tensor (Tensor): 1D long tensor [N], start offset in the concatenated KV cache for each token's batch.

end_location (Tensor): 1D long tensor [N], exclusive end = start + token's local position. (So the attended KV slice is kv[start:end].)

Assumes each batch contributes its full seq_len_per_batch[i] keys to the KV cache, and the selected tokens within a batch are the last counts[i] positions of that sequence.

Source code in vllm/v1/attention/backends/mla/indexer.py
def kv_spans_from_batches(
        start_seq_loc: torch.Tensor,
        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
      start_seq_loc: 1D long tensor [B+1], cumulative counts of 
                     selected tokens per batch.
            Example: [0, 2, 4, 7] -> 
                     batch sizes (selected) [2, 2, 3], N=7 tokens total.
      seq_len_per_batch: 1D long tensor [B], 
                         full sequence length (KV length) of each batch.
                         Example: [5, 9, 4].

    Returns:
      start_tensor: 1D long tensor [N], start offset in the 
                    concatenated KV cache for each token's batch.
      end_location: 1D long tensor [N], 
                    **exclusive** end = start + token's local position.
                    (So the attended KV slice is kv[start:end].)

    Assumes each batch contributes its full `seq_len_per_batch[i]` 
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and totals
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())  # total selected tokens
    B = L.numel()
    device = L.device

    if N == 0:
        return (torch.empty(0, dtype=torch.long, device=device),
                torch.empty(0, dtype=torch.long, device=device))

    # KV start offsets per batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For each selected token, which batch does it belong to?
    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
                                       counts)  # [N]

    # Map batch KV start to each token
    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # End-align local positions inside each batch:
    # local_pos = L[b] - counts[b] + (1..counts[b])  for each batch b
    L_expand = torch.repeat_interleave(L, counts)  # [N]
    m_expand = torch.repeat_interleave(counts, counts)  # [N]
    # position within the selected block: 1..counts[b]
    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
                  torch.repeat_interleave(q[:-1], counts) + 1)

    local_pos = L_expand - m_expand + pos_within  # [N], 1-based
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor.int(), end_location.int()
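
A usage sketch reusing the docstring example; the commented outputs are what the end-aligned span construction above produces for these inputs.

import torch

from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches

# Three batches with 2, 2 and 3 selected (query) tokens and full KV
# lengths 5, 9 and 4, as in the docstring example.
start_seq_loc = torch.tensor([0, 2, 4, 7], dtype=torch.long)
seq_len_per_batch = torch.tensor([5, 9, 4], dtype=torch.long)

cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(start_seq_loc, seq_len_per_batch)
# cu_seqlen_ks -> tensor([ 0,  0,  5,  5, 14, 14, 14], dtype=torch.int32)
# cu_seqlen_ke -> tensor([ 4,  5, 13, 14, 16, 17, 18], dtype=torch.int32)
# Token i attends to the concatenated KV slice kv[ks[i]:ke[i]].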