vllm.v1.attention.backends.mla.flashattn_mla

logger module-attribute

logger = init_logger(__name__)

FlashAttnMLABackend

Bases: MLACommonBackend

Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
class FlashAttnMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        return FlashAttnMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        return FlashAttnMLAImpl

get_builder_cls staticmethod

get_builder_cls() -> type[FlashAttnMLAMetadataBuilder]
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
@staticmethod
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
    return FlashAttnMLAMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[FlashAttnMLAImpl]
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
@staticmethod
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
    return FlashAttnMLAImpl

get_metadata_cls staticmethod

get_metadata_cls() -> type[FlashAttnMLAMetadata]
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
@staticmethod
def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
    return FlashAttnMLAMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
@staticmethod
def get_name() -> str:
    return "FLASH_ATTN_MLA"

FlashAttnMLADecodeMetadata dataclass

Bases: MLACommonDecodeMetadata

Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    query_start_loc: torch.Tensor
    max_query_len: int
    max_seq_len: int
    scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

max_num_splits class-attribute instance-attribute

max_num_splits: int = 0

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

query_start_loc instance-attribute

query_start_loc: Tensor

scheduler_metadata class-attribute instance-attribute

scheduler_metadata: Optional[Tensor] = None

__init__

__init__(
    block_table: Tensor,
    seq_lens: Tensor,
    query_start_loc: Tensor,
    max_query_len: int,
    max_seq_len: int,
    scheduler_metadata: Optional[Tensor] = None,
    max_num_splits: int = 0,
) -> None
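
A minimal construction sketch with placeholder tensors (shapes, dtypes, and values are assumptions for illustration only; in practice FlashAttnMLAMetadataBuilder._build_decode fills these fields):

import torch

decode_meta = FlashAttnMLADecodeMetadata(
    block_table=torch.zeros((2, 16), dtype=torch.int32),       # [num_reqs, max_blocks_per_req]
    seq_lens=torch.tensor([128, 64], dtype=torch.int32),
    query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32),
    max_query_len=1,
    max_seq_len=128,
    scheduler_metadata=None,  # no FA3 AOT schedule in this sketch
    max_num_splits=0,         # 0 means no upper bound on the number of splits
)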

FlashAttnMLAImpl

Bases: MLACommonImpl[FlashAttnMLAMetadata]

Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert flash_attn_supports_mla(), \
            "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttnMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if type(q) is tuple:
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

        attn_out = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
            block_table=attn_metadata.decode.block_table,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=self.need_to_return_lse_for_decode,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
            num_splits=attn_metadata.decode.max_num_splits,
        )

        if self.need_to_return_lse_for_decode:
            o, lse = attn_out
            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
        else:
            o = attn_out
            return o, None

can_return_lse_for_decode class-attribute instance-attribute

can_return_lse_for_decode: bool = True

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args) -> None:
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)

    assert flash_attn_supports_mla(), \
        "FlashAttnMLA is not supported on this device"

    unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
    if any(unsupported_features):
        raise NotImplementedError(
            "FlashAttnMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashAttnMLAImpl")

    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "FlashAttnMLA V1 with FP8 KV cache not yet supported")

_forward_decode

_forward_decode(
    q: Union[Tensor, tuple[Tensor, Tensor]],
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: FlashAttnMLAMetadata,
    layer: AttentionLayer,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
def _forward_decode(
    self,
    q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: FlashAttnMLAMetadata,
    layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if type(q) is tuple:
        q_nope, q_pe = q
    else:
        q_nope, q_pe = torch.split(
            q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 FlashAttention MLA not yet supported")

    kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
    k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

    # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
    # kernel uses this to calculate grid dimensions. Ensure it's at least 1
    # to prevent invalid grid configuration during graph capture.
    max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

    attn_out = flash_attn_varlen_func(
        q=q_pe,
        k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
        v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
        q_v=q_nope,
        max_seqlen_q=max_seqlen_q,
        cu_seqlens_q=attn_metadata.decode.query_start_loc,
        max_seqlen_k=attn_metadata.decode.max_seq_len,
        seqused_k=attn_metadata.decode.seq_lens,
        block_table=attn_metadata.decode.block_table,
        softmax_scale=self.scale,
        causal=True,
        return_softmax_lse=self.need_to_return_lse_for_decode,
        fa_version=3,  # only version 3 is supported
        scheduler_metadata=attn_metadata.decode.scheduler_metadata,
        num_splits=attn_metadata.decode.max_num_splits,
    )

    if self.need_to_return_lse_for_decode:
        o, lse = attn_out
        # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
        return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
    else:
        o = attn_out
        return o, None
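
The tensor bookkeeping in _forward_decode can be seen in isolation: the packed query splits into a latent part (q_nope) of width kv_lora_rank and a rotary part (q_pe) of width qk_rope_head_dim, and when the LSE is requested it is transposed from FlashAttention's [num_heads, num_tokens] layout to the [num_tokens, num_heads] layout expected downstream. A standalone sketch using assumed DeepSeek-style dimensions (512 / 64), not values read from any particular model config:

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed MLA dims, for illustration
num_tokens, num_heads = 4, 16

q = torch.randn(num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim)
q_nope, q_pe = torch.split(q, [kv_lora_rank, qk_rope_head_dim], dim=-1)
assert q_nope.shape == (num_tokens, num_heads, kv_lora_rank)
assert q_pe.shape == (num_tokens, num_heads, qk_rope_head_dim)

lse = torch.randn(num_heads, num_tokens)   # FA layout: [H, B]
lse_dcp = lse.transpose(0, 1)              # -> [B, H], as returned above
assert lse_dcp.shape == (num_tokens, num_heads)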

FlashAttnMLAMetadata dataclass

Bases: MLACommonMetadata[FlashAttnMLADecodeMetadata]

Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    pass

__init__

__init__(
    num_reqs: int,
    max_query_len: int,
    max_seq_len: int,
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    head_dim: Optional[int] = None,
    decode: Optional[D] = None,
    prefill: Optional[
        Union[
            MLACommonPrefillMetadata,
            FlashInferPrefillMetadata,
            CudnnPrefillMetadata,
        ]
    ] = None,
) -> None
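
Since the class body is just pass, every field comes from MLACommonMetadata; the subclass only pins the decode field to FlashAttnMLADecodeMetadata. A minimal sketch for a single decode request and no prefill (placeholder tensors; values are assumptions for illustration, with decode_meta being the FlashAttnMLADecodeMetadata built in the earlier sketch):

import torch

meta = FlashAttnMLAMetadata(
    num_reqs=1,
    max_query_len=1,
    max_seq_len=128,
    num_actual_tokens=1,
    query_start_loc=torch.tensor([0, 1], dtype=torch.int32),
    slot_mapping=torch.zeros(1, dtype=torch.int64),
    num_decodes=1,
    num_decode_tokens=1,
    num_prefills=0,
    decode=decode_meta,
    prefill=None,
)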

FlashAttnMLAMetadataBuilder

Bases: MLACommonMetadataBuilder[FlashAttnMLAMetadata]

Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
class FlashAttnMLAMetadataBuilder(
        MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 512

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashAttnMLAMetadata)
        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.fa_aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

        # TODO(lucas): Until we add support for the DCP custom masking we need
        #   to restrict decodes to q_len == 1 when DCP is enabled.
        self.reorder_batch_threshold = 1 \
            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold

    def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                         max_seq_len, causal):
        if self.fa_aot_schedule:
            return get_scheduler_metadata(
                batch_size=num_reqs,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads,
                num_heads_kv=1,
                headdim=self.mla_dims.qk_rope_head_dim,
                cache_seqlens=seqlens,
                qkv_dtype=self.kv_cache_spec.dtype,
                headdim_v=self.mla_dims.kv_lora_rank,
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
                num_splits=self.max_num_splits,
            )
        return None

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
        )

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert n <= self.scheduler_metadata.shape[0], \
                f"Scheduler metadata size {n} exceeds buffer size " + \
                f"{self.scheduler_metadata.shape[0]}"
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_decode_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
        )

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = UNIFORM_BATCH

fa_aot_schedule instance-attribute

fa_aot_schedule = get_flash_attn_version() == 3

max_cudagraph_size instance-attribute

max_cudagraph_size = max_capture_size

max_num_splits instance-attribute

max_num_splits = 0

reorder_batch_threshold class-attribute instance-attribute

reorder_batch_threshold: int = (
    1 if world_size > 1 else reorder_batch_threshold
)

scheduler_metadata instance-attribute

scheduler_metadata = zeros(
    max_num_seqs + 1, dtype=int32, device=device
)

use_full_cuda_graph instance-attribute

use_full_cuda_graph = has_full_cudagraphs()

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                     FlashAttnMLAMetadata)
    self.max_num_splits = 0  # No upper bound on the number of splits.
    self.fa_aot_schedule = (get_flash_attn_version() == 3)

    self.use_full_cuda_graph = \
        self.compilation_config.cudagraph_mode.has_full_cudagraphs()

    if self.use_full_cuda_graph and self.fa_aot_schedule:
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        if self.max_cudagraph_size > 992:
            # This condition derives from FA3's internal heuristic.
            # TODO(woosuk): Support larger cudagraph sizes.
            raise ValueError(
                "Capture size larger than 992 is not supported for "
                "full cuda graph.")

        self.scheduler_metadata = torch.zeros(
            vllm_config.scheduler_config.max_num_seqs + 1,
            dtype=torch.int32,
            device=self.device,
        )
        # When using cuda graph, we need to set the upper bound of the
        # number of splits so that large enough intermediate buffers are
        # pre-allocated during capture.
        self.max_num_splits = (
            envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

    # TODO(lucas): Until we add support for the DCP custom masking we need
    #   to restrict decodes to q_len == 1 when DCP is enabled.
    self.reorder_batch_threshold = 1 \
        if get_dcp_group().world_size > 1 else self.reorder_batch_threshold

_build_decode

_build_decode(
    block_table_tensor: Tensor,
    seq_lens_cpu: Tensor,
    seq_lens_device: Tensor,
    query_start_loc_cpu: Tensor,
    query_start_loc_device: Tensor,
    num_decode_tokens: int,
) -> FlashAttnMLADecodeMetadata
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
def _build_decode(self, block_table_tensor: torch.Tensor,
                  seq_lens_cpu: torch.Tensor,
                  seq_lens_device: torch.Tensor,
                  query_start_loc_cpu: torch.Tensor,
                  query_start_loc_device: torch.Tensor,
                  num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
    query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
    max_query_len = query_lens_cpu.max().item()
    max_seq_len = seq_lens_cpu.max().item()

    scheduler_metadata = self._schedule_decode(
        num_reqs=seq_lens_cpu.numel(),
        cu_query_lens=query_start_loc_device,
        max_query_len=max_query_len,
        seqlens=seq_lens_device,
        max_seq_len=max_seq_len,
        causal=True,
    )

    # For FA3 + full cudagraph
    max_num_splits = 0
    if self.use_full_cuda_graph and scheduler_metadata is not None:
        n = scheduler_metadata.shape[0]
        # Ensure the persistent buffer is large enough
        assert n <= self.scheduler_metadata.shape[0], \
            f"Scheduler metadata size {n} exceeds buffer size " + \
            f"{self.scheduler_metadata.shape[0]}"
        self.scheduler_metadata[:n] = scheduler_metadata
        # NOTE(woosuk): We should zero out the rest of the scheduler
        # metadata to guarantee the correctness. Otherwise, some thread
        # blocks may use the invalid scheduler metadata and overwrite the
        # output buffer.
        self.scheduler_metadata[n:] = 0
        scheduler_metadata = self.scheduler_metadata[:n]

        if num_decode_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

    return FlashAttnMLADecodeMetadata(
        block_table=block_table_tensor,
        seq_lens=seq_lens_device,
        query_start_loc=query_start_loc_device,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        scheduler_metadata=scheduler_metadata,
        max_num_splits=max_num_splits,
    )
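
The copy-then-zero handling of scheduler_metadata above exists so that a fixed, pre-allocated buffer can be reused across full-CUDA-graph replays without stale entries leaking into the kernel. The same pattern in isolation (buffer size and contents are made up for illustration):

import torch

persistent = torch.zeros(9, dtype=torch.int32)           # max_num_seqs + 1 slots (assumed)
fresh = torch.tensor([0, 3, 7, 12], dtype=torch.int32)   # dummy stand-in for _schedule_decode output

n = fresh.shape[0]
assert n <= persistent.shape[0]
persistent[:n] = fresh
persistent[n:] = 0                    # zero the tail so no thread block reads stale data
scheduler_metadata = persistent[:n]   # a view into the persistent buffer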

_schedule_decode

_schedule_decode(
    num_reqs,
    cu_query_lens,
    max_query_len,
    seqlens,
    max_seq_len,
    causal,
)
Source code in vllm/v1/attention/backends/mla/flashattn_mla.py
def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
    if self.fa_aot_schedule:
        return get_scheduler_metadata(
            batch_size=num_reqs,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_seq_len,
            num_heads_q=self.num_heads,
            num_heads_kv=1,
            headdim=self.mla_dims.qk_rope_head_dim,
            cache_seqlens=seqlens,
            qkv_dtype=self.kv_cache_spec.dtype,
            headdim_v=self.mla_dims.kv_lora_rank,
            page_size=self.page_size,
            cu_seqlens_q=cu_query_lens,
            causal=causal,
            num_splits=self.max_num_splits,
        )
    return None
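
Taken together, the builder is normally obtained and driven through the backend handle rather than imported directly. A hedged sketch of that wiring; kv_cache_spec, layer_names, vllm_config, and device are assumed to come from the engine's initialization and are not constructed here:

builder_cls = FlashAttnMLABackend.get_builder_cls()   # FlashAttnMLAMetadataBuilder
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
# builder._build_decode(...) then produces the FlashAttnMLADecodeMetadata
# consumed by FlashAttnMLAImpl._forward_decode during decode steps.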