vllm.model_executor.layers.mla

MLAModules dataclass

Modules used in MLA.

Source code in vllm/model_executor/layers/mla.py
@dataclass
class MLAModules:
    """Modules used in MLA.
    """
    kv_a_layernorm: torch.nn.Module
    kv_b_proj: torch.nn.Module
    rotary_emb: torch.nn.Module
    o_proj: torch.nn.Module
    fused_qkv_a_proj: Optional[torch.nn.Module]
    kv_a_proj_with_mqa: Optional[torch.nn.Module]
    q_a_layernorm: Optional[torch.nn.Module]
    q_b_proj: Optional[torch.nn.Module]
    q_proj: Optional[torch.nn.Module]
    indexer: Optional[torch.nn.Module]
    is_sparse: bool
    topk_indices_buffer: Optional[torch.Tensor]

fused_qkv_a_proj instance-attribute

fused_qkv_a_proj: Optional[Module]

indexer instance-attribute

indexer: Optional[Module]

is_sparse instance-attribute

is_sparse: bool

kv_a_layernorm instance-attribute

kv_a_layernorm: Module

kv_a_proj_with_mqa instance-attribute

kv_a_proj_with_mqa: Optional[Module]

kv_b_proj instance-attribute

kv_b_proj: Module

o_proj instance-attribute

o_proj: Module

q_a_layernorm instance-attribute

q_a_layernorm: Optional[Module]

q_b_proj instance-attribute

q_b_proj: Optional[Module]

q_proj instance-attribute

q_proj: Optional[Module]

rotary_emb instance-attribute

rotary_emb: Module

topk_indices_buffer instance-attribute

topk_indices_buffer: Optional[Tensor]

__init__

__init__(
    kv_a_layernorm: Module,
    kv_b_proj: Module,
    rotary_emb: Module,
    o_proj: Module,
    fused_qkv_a_proj: Optional[Module],
    kv_a_proj_with_mqa: Optional[Module],
    q_a_layernorm: Optional[Module],
    q_b_proj: Optional[Module],
    q_proj: Optional[Module],
    indexer: Optional[Module],
    is_sparse: bool,
    topk_indices_buffer: Optional[Tensor],
) -> None
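
A minimal sketch of how a model might bundle its sub-modules into MLAModules. Plain torch.nn placeholders stand in for the vLLM parallel-linear, RMSNorm, and rotary-embedding layers a real model would pass (whose forwards typically return an (output, bias) tuple); the field names and types come from the dataclass above, while all dimensions and module choices are illustrative assumptions.

import torch
from vllm.model_executor.layers.mla import MLAModules

hidden_size, q_lora_rank, kv_lora_rank = 2048, 1536, 512
num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 16, 128, 64, 128

# Placeholder sub-modules; a real model wires in vLLM layers instead.
mla_modules = MLAModules(
    kv_a_layernorm=torch.nn.LayerNorm(kv_lora_rank),
    kv_b_proj=torch.nn.Linear(
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)),
    rotary_emb=torch.nn.Identity(),
    o_proj=torch.nn.Linear(num_heads * v_head_dim, hidden_size),
    fused_qkv_a_proj=torch.nn.Linear(
        hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
    kv_a_proj_with_mqa=None,
    q_a_layernorm=torch.nn.LayerNorm(q_lora_rank),
    q_b_proj=torch.nn.Linear(
        q_lora_rank, num_heads * (qk_nope_head_dim + qk_rope_head_dim)),
    q_proj=None,
    indexer=None,
    is_sparse=False,
    topk_indices_buffer=None,
)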

MultiHeadLatentAttention

Bases: CustomOp

MLA layer registered as CustomOp. Note that currently MLA ignores the enable/disable mechanism of CustomOp because there is only one in-tree implementation in forward_native. TODO: implement this with a new PluggableLayer mechanism.

This class takes positions and hidden_states as input. The input tensors can either contain prefill tokens or decode tokens. The class does the following:

  1. MLA Preprocess.
  2. Perform multi-head attention to prefill tokens and multi-query attention to decode tokens separately.
  3. Return the output tensor.
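
At the model level the layer is invoked with flattened token tensors. The following shape-level sketch of the call interface is an illustration under assumed toy sizes and a hypothetical attribute name (self.mla), not code taken from vLLM:

import torch

num_tokens, hidden_size = 8, 2048                      # toy sizes, illustrative only
positions = torch.arange(num_tokens)                   # [num_tokens] token positions
hidden_states = torch.randn(num_tokens, hidden_size)   # [num_tokens, hidden_size]

# Inside a decoder layer the call would look like:
#     attn_output = self.mla(positions, hidden_states)
# attn_output comes back as [num_tokens, hidden_size]; the flattened batch may
# contain prefill tokens or decode tokens, which are attended to separately.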
Source code in vllm/model_executor/layers/mla.py
@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
    """MLA layer registered as CustomOp.
    Note that currently MLA ignores the enable/disable mechanism of CustomOp
    because there is only one in-tree implementation in forward_native.
    TODO: implement this with a new PluggableLayer mechanism.

    This class takes positions and hidden_states as input. 
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse

        if self.indexer is not None:
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scale,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sparse=mla_modules.is_sparse,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
            indexer=self.indexer,
        )

        self.prefix = prefix

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)

        if self.indexer and self.is_sparse:
            _topk_indices = self.indexer(hidden_states, q_c, positions,
                                         self.rotary_emb)

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0],
                          self.num_heads * self.v_head_dim))
        return self.o_proj(attn_out)[0]

    def forward_cuda(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

fused_qkv_a_proj instance-attribute

fused_qkv_a_proj = fused_qkv_a_proj

hidden_size instance-attribute

hidden_size = hidden_size

indexer instance-attribute

indexer = indexer

is_sparse instance-attribute

is_sparse = is_sparse

kv_a_layernorm instance-attribute

kv_a_layernorm = kv_a_layernorm

kv_a_proj_with_mqa instance-attribute

kv_a_proj_with_mqa = kv_a_proj_with_mqa

kv_b_proj instance-attribute

kv_b_proj = kv_b_proj

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

mla_attn instance-attribute

mla_attn = Attention(
    num_heads=num_heads,
    head_size=kv_lora_rank + qk_rope_head_dim,
    scale=scale,
    num_kv_heads=1,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
    use_mla=True,
    use_sparse=is_sparse,
    q_lora_rank=q_lora_rank,
    kv_lora_rank=kv_lora_rank,
    qk_nope_head_dim=qk_nope_head_dim,
    qk_rope_head_dim=qk_rope_head_dim,
    qk_head_dim=qk_head_dim,
    v_head_dim=v_head_dim,
    kv_b_proj=kv_b_proj,
    indexer=indexer,
)

num_heads instance-attribute

num_heads = num_heads

o_proj instance-attribute

o_proj = o_proj

prefix instance-attribute

prefix = prefix

q_a_layernorm instance-attribute

q_a_layernorm = q_a_layernorm

q_b_proj instance-attribute

q_b_proj = q_b_proj

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

q_proj instance-attribute

q_proj = q_proj

qk_head_dim instance-attribute

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

rotary_emb instance-attribute

rotary_emb = rotary_emb

topk_indices_buffer instance-attribute

topk_indices_buffer = topk_indices_buffer

topk_tokens instance-attribute

topk_tokens = topk_tokens

v_head_dim instance-attribute

v_head_dim = v_head_dim

__init__

__init__(
    hidden_size: int,
    num_heads: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    mla_modules: MLAModules,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/layers/mla.py
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    mla_modules: MLAModules,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.hidden_size = hidden_size
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.num_heads = num_heads
    self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
    self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
    self.q_a_layernorm = mla_modules.q_a_layernorm
    self.q_b_proj = mla_modules.q_b_proj
    self.q_proj = mla_modules.q_proj
    self.kv_a_layernorm = mla_modules.kv_a_layernorm
    self.kv_b_proj = mla_modules.kv_b_proj
    self.rotary_emb = mla_modules.rotary_emb
    self.o_proj = mla_modules.o_proj
    self.indexer = mla_modules.indexer
    self.is_sparse = mla_modules.is_sparse

    if self.indexer is not None:
        assert hasattr(self.indexer, "topk_tokens")
        self.topk_tokens = self.indexer.topk_tokens
        self.topk_indices_buffer = mla_modules.topk_indices_buffer

    # In the MLA backend, kv_cache includes both k_c and
    # pe (i.e. decoupled position embeddings). In particular,
    # the concat_and_cache_mla op requires
    #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
    # i.e.
    #     kv_lora_rank + qk_rope_head_dim == head_size
    self.mla_attn = Attention(
        num_heads=self.num_heads,
        head_size=self.kv_lora_rank + self.qk_rope_head_dim,
        scale=scale,
        num_kv_heads=1,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        use_mla=True,
        use_sparse=mla_modules.is_sparse,
        # MLA Args
        q_lora_rank=self.q_lora_rank,
        kv_lora_rank=self.kv_lora_rank,
        qk_nope_head_dim=self.qk_nope_head_dim,
        qk_rope_head_dim=self.qk_rope_head_dim,
        qk_head_dim=self.qk_head_dim,
        v_head_dim=self.v_head_dim,
        kv_b_proj=self.kv_b_proj,
        indexer=self.indexer,
    )

    self.prefix = prefix
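
To make the comment above concrete: the kv-cache head size is the concatenated width of the compressed KV latent (k_c) and the decoupled RoPE key (k_pe). A small worked check with DeepSeek-V2-style values (kv_lora_rank=512, qk_rope_head_dim=64), assumed here purely for illustration:

kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed DeepSeek-V2-style values
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576
# concat_and_cache_mla then requires
#     k_c.size(1) + k_pe.size(1) == kv_cache.size(2) == head_size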

forward_cuda

forward_cuda(*args, **kwargs)
Source code in vllm/model_executor/layers/mla.py
def forward_cuda(self, *args, **kwargs):
    return self.forward_native(*args, **kwargs)

forward_native

forward_native(
    positions: Tensor, hidden_states: Tensor
) -> Tensor
Source code in vllm/model_executor/layers/mla.py
def forward_native(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    q_c = None
    kv_lora = None

    if self.q_lora_rank is not None:
        assert self.fused_qkv_a_proj is not None, \
            "fused_qkv_a_proj is required when q_lora_rank is not None"
        assert self.q_a_layernorm is not None, \
            "q_a_layernorm is required when q_lora_rank is not None"
        assert self.q_b_proj is not None, \
            "q_b_proj is required when q_lora_rank is not None"
        qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
        q_c, kv_lora = qkv_lora.split(
            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
            dim=-1,
        )
        q_c = self.q_a_layernorm(q_c)
        q = self.q_b_proj(q_c)[0]
    else:
        assert self.kv_a_proj_with_mqa is not None, \
            "kv_a_proj_with_mqa is required when q_lora_rank is None"
        assert self.q_proj is not None, \
            "q_proj is required when q_lora_rank is None"
        kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
        q = self.q_proj(hidden_states)[0]

    kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                               dim=-1)
    kv_c_normed = self.kv_a_layernorm(kv_c)

    q = q.view(-1, self.num_heads, self.qk_head_dim)
    # Add head dim of 1 to k_pe
    k_pe = k_pe.unsqueeze(1)

    q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
        positions, q[..., self.qk_nope_head_dim:], k_pe)

    if self.indexer and self.is_sparse:
        _topk_indices = self.indexer(hidden_states, q_c, positions,
                                     self.rotary_emb)

    attn_out = self.mla_attn(
        q,
        kv_c_normed,
        k_pe,
        output_shape=(hidden_states.shape[0],
                      self.num_heads * self.v_head_dim))
    return self.o_proj(attn_out)[0]
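
A pure-torch shape walk-through of the preprocessing path above for the q_lora_rank branch. The dimensions are illustrative assumptions (DeepSeek-V2-like, with a reduced head count), and the random tensors only stand in for projection outputs; this is a sketch of the tensor bookkeeping, not the actual kernel path.

import torch

# Assumed illustrative dimensions.
num_tokens, hidden_size = 4, 5120
q_lora_rank, kv_lora_rank = 1536, 512
num_heads, qk_nope_head_dim, qk_rope_head_dim = 16, 128, 64
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

# Output of fused_qkv_a_proj: q latent and kv latent (+ RoPE part) fused together.
qkv_lora = torch.randn(num_tokens, q_lora_rank + kv_lora_rank + qk_rope_head_dim)
q_c, kv_lora = qkv_lora.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

# Output of q_b_proj, reshaped to per-head queries.
q = torch.randn(num_tokens, num_heads * qk_head_dim).view(-1, num_heads, qk_head_dim)

# The kv latent splits into the compressed KV (kv_c) and the decoupled RoPE key (k_pe).
kv_c, k_pe = kv_lora.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
k_pe = k_pe.unsqueeze(1)  # add a singleton head dim, as in forward_native

assert q.shape == (num_tokens, num_heads, qk_head_dim)
assert kv_c.shape == (num_tokens, kv_lora_rank)
assert k_pe.shape == (num_tokens, 1, qk_rope_head_dim)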