vllm.v1.attention.backends.mla.flashinfer_mla

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE module-attribute

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

g_fi_workspace module-attribute

g_fi_workspace = zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=uint8,
    device="cuda",
)

logger module-attribute

logger = init_logger(__name__)
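
Taken together, these attributes amount to a single 128 MiB uint8 scratch tensor, allocated on the default CUDA device at import time and shared by every FlashInferMLAImpl instance via self._workspace_buffer. A standalone sketch of the equivalent allocation (with the torch prefixes the abbreviated listing above omits, and assuming a CUDA device is available):

import torch

# 128 MiB of uint8 scratch space for the FlashInfer/TRT-LLM MLA kernels,
# allocated once per process and reused by every attention layer.
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

g_fi_workspace = torch.zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=torch.uint8,
    device="cuda",
)

print(g_fi_workspace.numel() // 2**20, "MiB")  # -> 128 MiB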

FlashInferMLABackend

Bases: MLACommonBackend

Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
class FlashInferMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        return FlashInferMLAImpl

get_impl_cls staticmethod

get_impl_cls() -> type[FlashInferMLAImpl]
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
    return FlashInferMLAImpl

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
@staticmethod
def get_name() -> str:
    return "FLASHINFER_MLA"

FlashInferMLAImpl

Bases: MLACommonImpl[MLACommonMetadata]

Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferMLAImpl")

        self._workspace_buffer = g_fi_workspace
        self.bmm1_scale: Optional[float] = None
        self.bmm2_scale: Optional[float] = None

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q_nope, q_pe = q
            q = torch.cat([q_nope, q_pe], dim=-1)

        # trtllm API requires extra dimension q_len_per_request for MTP
        q = q.unsqueeze(1)

        if self.bmm1_scale is None:
            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
                               self.scale)
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return o, None

_workspace_buffer instance-attribute

_workspace_buffer = g_fi_workspace

bmm1_scale instance-attribute

bmm1_scale: Optional[float] = None

bmm2_scale instance-attribute

bmm2_scale: Optional[float] = None
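
Both scales start as None and are filled in lazily on the first decode call from the layer's quantization scales: bmm1_scale folds the query scale, key scale, and softmax scale into the first batched matmul (Q @ K^T), while bmm2_scale carries the value scale into the second (P @ V). A small numeric sketch with assumed values:

# Illustrative numbers only; in vLLM these come from layer._q_scale_float,
# layer._k_scale_float, layer._v_scale_float and self.scale (the softmax
# scale derived from the model config).
q_scale, k_scale, v_scale = 1.0, 1.0, 1.0  # fp8 scales (assumed)
softmax_scale = 192 ** -0.5                # assumed qk head dim of 192

bmm1_scale = q_scale * k_scale * softmax_scale  # applied inside Q @ K^T
bmm2_scale = v_scale                            # applied inside P @ V
print(bmm1_scale, bmm2_scale)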

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args) -> None:
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)

    unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
    if any(unsupported_features):
        raise NotImplementedError(
            "FlashInferMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashInferMLAImpl")

    self._workspace_buffer = g_fi_workspace
    self.bmm1_scale: Optional[float] = None
    self.bmm2_scale: Optional[float] = None
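
The constructor delegates all common setup to MLACommonImpl and then acts purely as a feature gate. The helper below restates that gate in isolation (check_flashinfer_mla_supported is a hypothetical name, not part of vLLM, and "decoder" stands in for AttentionType.DECODER):

from typing import Optional

def check_flashinfer_mla_supported(
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    logits_soft_cap: Optional[float],
    attn_type: str,
) -> None:
    # Any truthy value among these three features is rejected outright.
    if any([alibi_slopes, sliding_window, logits_soft_cap]):
        raise NotImplementedError(
            "FlashInferMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, logits_soft_cap")
    # Only decoder self-attention is implemented.
    if attn_type != "decoder":
        raise NotImplementedError(
            "Encoder self-attention and encoder/decoder cross-attention "
            "are not implemented for FlashInferMLAImpl")

check_flashinfer_mla_supported(None, None, None, "decoder")  # passes silently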

_forward_decode

_forward_decode(
    q: Union[Tensor, tuple[Tensor, Tensor]],
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: MLACommonMetadata,
    layer: AttentionLayer,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
def _forward_decode(
    self,
    q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
    layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if isinstance(q, tuple):
        q_nope, q_pe = q
        q = torch.cat([q_nope, q_pe], dim=-1)

    # trtllm API requires extra dimension q_len_per_request for MTP
    q = q.unsqueeze(1)

    if self.bmm1_scale is None:
        self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
                           self.scale)
    if self.bmm2_scale is None:
        self.bmm2_scale = layer._v_scale_float

    o = trtllm_batch_decode_with_kv_cache_mla(
        query=q,
        kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
        workspace_buffer=self._workspace_buffer,
        qk_nope_head_dim=self.qk_nope_head_dim,
        kv_lora_rank=self.kv_lora_rank,
        qk_rope_head_dim=self.qk_rope_head_dim,
        block_tables=attn_metadata.decode.block_table,
        seq_lens=attn_metadata.decode.seq_lens,
        max_seq_len=attn_metadata.max_seq_len,
        bmm1_scale=self.bmm1_scale,
        bmm2_scale=self.bmm2_scale,
    )

    # TODO: Return LSE pending support from Flashinfer API:
    # https://github.com/flashinfer-ai/flashinfer/pull/1566
    return o, None
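
_forward_decode reshapes the (possibly split) query into the layout the TRT-LLM MLA decode kernel expects: the rope-free and rope parts are concatenated on the last dimension, then a singleton q_len_per_request axis is inserted for MTP. A shape-only sketch, with assumed DeepSeek-style dimensions rather than values read from vLLM:

import torch

# Assumed decode shapes: q_nope carries kv_lora_rank features per head,
# q_pe carries the rotary part.
num_tokens, num_heads = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64

q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank)
q_pe = torch.randn(num_tokens, num_heads, qk_rope_head_dim)

q = torch.cat([q_nope, q_pe], dim=-1)  # (4, 16, 576)
q = q.unsqueeze(1)                     # (4, 1, 16, 576): q_len_per_request = 1

# The paged KV cache likewise gains a singleton axis before the kernel call,
# e.g. (num_blocks, block_size, 576) -> (num_blocks, 1, block_size, 576)
# under the assumed cache layout.
print(q.shape)

Note that the second element of the returned tuple is always None for now; the log-sum-exp output is pending support in the FlashInfer API (see the linked pull request in the source).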