vllm.model_executor.models.longcat_flash

Inference-only Flash model compatible with HuggingFace weights.
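
The module plugs into vLLM's model registry, so a LongCat-Flash checkpoint can be served through the usual entry points. A minimal, illustrative usage sketch (the checkpoint path is a placeholder, and flags such as tensor_parallel_size depend on your deployment):

from vllm import LLM, SamplingParams

# Hypothetical checkpoint path; LongCat-Flash ships a custom config class,
# so trust_remote_code is typically required.
llm = LLM(model="/path/to/longcat-flash-checkpoint", trust_remote_code=True)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)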

logger module-attribute

logger = init_logger(__name__)

FlashConfig

Bases: PretrainedConfig

Flash model configuration.

Source code in vllm/model_executor/models/longcat_flash.py
class FlashConfig(PretrainedConfig):
    """Flash model configuration."""
    model_type = "longcat_flash"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=4096,
        intermediate_size=8192,
        num_layers=28,
        num_hidden_layers=None,
        num_attention_heads=96,
        num_key_value_heads=128,
        ep_size=1,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        num_experts_per_tok=None,
        norm_topk_prob=False,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mla_scale_q_lora=False,
        mla_scale_kv_lora=False,
        torch_dtype="bfloat16",
        params_dtype="bfloat16",
        router_dtype="float32",
        router_bias=False,
        topk_method=None,
        routed_scaling_factor=None,
        zero_expert_num=0,
        zero_expert_type=None,
        nextn_use_scmoe=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            params_dtype=params_dtype,
            router_dtype=router_dtype,
            topk_method=topk_method,
            router_bias=router_bias,
            nextn_use_scmoe=nextn_use_scmoe,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = (num_hidden_layers if num_hidden_layers
                                  is not None else num_layers)
        self.num_attention_heads = num_attention_heads
        self.ep_size = ep_size
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mla_scale_q_lora = mla_scale_q_lora
        self.mla_scale_kv_lora = mla_scale_kv_lora
        self.zero_expert_num = zero_expert_num
        self.zero_expert_type = zero_expert_type
        self.routed_scaling_factor = routed_scaling_factor
        self.hidden_act = "silu"
        self.intermediate_size = self.ffn_hidden_size if hasattr(
            self, "ffn_hidden_size") else self.intermediate_size
        if hasattr(self, "moe_intermediate_size"):
            self.moe_intermediate_size = self.moe_intermediate_size
        elif hasattr(self, "expert_ffn_hidden_size"):
            self.moe_intermediate_size = self.expert_ffn_hidden_size
        else:
            self.moe_intermediate_size = self.intermediate_size
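
An illustrative sketch (not part of the vLLM source) of how the constructor resolves checkpoint-style field names: ffn_hidden_size, when present, overrides intermediate_size, and expert_ffn_hidden_size (or moe_intermediate_size) feeds moe_intermediate_size:

from vllm.model_executor.models.longcat_flash import FlashConfig

cfg = FlashConfig(
    hidden_size=4096,
    ffn_hidden_size=12288,        # checkpoint alias that wins over intermediate_size
    expert_ffn_hidden_size=2048,  # checkpoint alias for moe_intermediate_size
    num_experts_per_tok=12,
)
assert cfg.intermediate_size == 12288
assert cfg.moe_intermediate_size == 2048
assert cfg.hidden_act == "silu"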

attention_bias instance-attribute

attention_bias = attention_bias

attention_dropout instance-attribute

attention_dropout = attention_dropout

ep_size instance-attribute

ep_size = ep_size

hidden_act instance-attribute

hidden_act = 'silu'

hidden_size instance-attribute

hidden_size = hidden_size

initializer_range instance-attribute

initializer_range = initializer_range

intermediate_size instance-attribute

intermediate_size = (
    ffn_hidden_size
    if hasattr(self, "ffn_hidden_size")
    else intermediate_size
)

keys_to_ignore_at_inference class-attribute instance-attribute

keys_to_ignore_at_inference = ['past_key_values']

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

mla_scale_kv_lora instance-attribute

mla_scale_kv_lora = mla_scale_kv_lora

mla_scale_q_lora instance-attribute

mla_scale_q_lora = mla_scale_q_lora

model_type class-attribute instance-attribute

model_type = 'longcat_flash'

moe_intermediate_size instance-attribute

moe_intermediate_size = moe_intermediate_size

norm_topk_prob instance-attribute

norm_topk_prob = norm_topk_prob

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

num_experts_per_tok instance-attribute

num_experts_per_tok = num_experts_per_tok

num_hidden_layers instance-attribute

num_hidden_layers = (
    num_hidden_layers
    if num_hidden_layers is not None
    else num_layers
)

num_key_value_heads instance-attribute

num_key_value_heads = num_key_value_heads

pretraining_tp instance-attribute

pretraining_tp = pretraining_tp

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

rms_norm_eps instance-attribute

rms_norm_eps = rms_norm_eps

rope_scaling instance-attribute

rope_scaling = rope_scaling

rope_theta instance-attribute

rope_theta = rope_theta

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

use_cache instance-attribute

use_cache = use_cache

v_head_dim instance-attribute

v_head_dim = v_head_dim

vocab_size instance-attribute

vocab_size = vocab_size

zero_expert_num instance-attribute

zero_expert_num = zero_expert_num

zero_expert_type instance-attribute

zero_expert_type = zero_expert_type

__init__

__init__(
    vocab_size=131072,
    hidden_size=4096,
    intermediate_size=8192,
    num_layers=28,
    num_hidden_layers=None,
    num_attention_heads=96,
    num_key_value_heads=128,
    ep_size=1,
    kv_lora_rank=512,
    q_lora_rank=1536,
    qk_rope_head_dim=64,
    v_head_dim=128,
    qk_nope_head_dim=128,
    num_experts_per_tok=None,
    norm_topk_prob=False,
    max_position_embeddings=8192,
    initializer_range=0.02,
    rms_norm_eps=1e-05,
    use_cache=True,
    pad_token_id=None,
    bos_token_id=100000,
    eos_token_id=100001,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=1000000.0,
    rope_scaling=None,
    attention_bias=False,
    attention_dropout=0.0,
    mla_scale_q_lora=False,
    mla_scale_kv_lora=False,
    torch_dtype="bfloat16",
    params_dtype="bfloat16",
    router_dtype="float32",
    router_bias=False,
    topk_method=None,
    routed_scaling_factor=None,
    zero_expert_num=0,
    zero_expert_type=None,
    nextn_use_scmoe=False,
    **kwargs,
)
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(
    self,
    vocab_size=131072,
    hidden_size=4096,
    intermediate_size=8192,
    num_layers=28,
    num_hidden_layers=None,
    num_attention_heads=96,
    num_key_value_heads=128,
    ep_size=1,
    kv_lora_rank=512,
    q_lora_rank=1536,
    qk_rope_head_dim=64,
    v_head_dim=128,
    qk_nope_head_dim=128,
    num_experts_per_tok=None,
    norm_topk_prob=False,
    max_position_embeddings=8192,
    initializer_range=0.02,
    rms_norm_eps=1e-05,
    use_cache=True,
    pad_token_id=None,
    bos_token_id=100000,
    eos_token_id=100001,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=1000000.0,
    rope_scaling=None,
    attention_bias=False,
    attention_dropout=0.0,
    mla_scale_q_lora=False,
    mla_scale_kv_lora=False,
    torch_dtype="bfloat16",
    params_dtype="bfloat16",
    router_dtype="float32",
    router_bias=False,
    topk_method=None,
    routed_scaling_factor=None,
    zero_expert_num=0,
    zero_expert_type=None,
    nextn_use_scmoe=False,
    **kwargs,
):
    super().__init__(
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        tie_word_embeddings=tie_word_embeddings,
        torch_dtype=torch_dtype,
        params_dtype=params_dtype,
        router_dtype=router_dtype,
        topk_method=topk_method,
        router_bias=router_bias,
        nextn_use_scmoe=nextn_use_scmoe,
        **kwargs,
    )
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.num_hidden_layers = (num_hidden_layers if num_hidden_layers
                              is not None else num_layers)
    self.num_attention_heads = num_attention_heads
    self.ep_size = ep_size
    self.kv_lora_rank = kv_lora_rank
    self.q_lora_rank = q_lora_rank
    self.qk_rope_head_dim = qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.qk_nope_head_dim = qk_nope_head_dim
    self.num_experts_per_tok = num_experts_per_tok
    self.norm_topk_prob = norm_topk_prob
    # for backward compatibility
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads

    self.num_key_value_heads = num_key_value_heads
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.pretraining_tp = pretraining_tp
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout
    self.mla_scale_q_lora = mla_scale_q_lora
    self.mla_scale_kv_lora = mla_scale_kv_lora
    self.zero_expert_num = zero_expert_num
    self.zero_expert_type = zero_expert_type
    self.routed_scaling_factor = routed_scaling_factor
    self.hidden_act = "silu"
    self.intermediate_size = self.ffn_hidden_size if hasattr(
        self, "ffn_hidden_size") else self.intermediate_size
    if hasattr(self, "moe_intermediate_size"):
        self.moe_intermediate_size = self.moe_intermediate_size
    elif hasattr(self, "expert_ffn_hidden_size"):
        self.moe_intermediate_size = self.expert_ffn_hidden_size
    else:
        self.moe_intermediate_size = self.intermediate_size

FlashDecoderLayer

Bases: Module

Flash decoder layer with dual attention and MLP structure.

Source code in vllm/model_executor/models/longcat_flash.py
class FlashDecoderLayer(nn.Module):
    """Flash decoder layer with dual attention and MLP structure."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: FlashConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.layer_idx = int(prefix.split(sep='.')[-1])
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)

        # Dual attention structure
        self.self_attn = nn.ModuleList([
            DeepseekV2MLAAttention(
                vllm_config=vllm_config,
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=config.v_head_dim,
                q_lora_rank=(config.q_lora_rank if hasattr(
                    config, "q_lora_rank") else None),
                kv_lora_rank=config.kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                cache_config=cache_config,
                quant_config=None if "self_attn" in getattr(
                    config, "disable_quant_module", []) else quant_config,
                prefix=f"{prefix}.self_attn.{i}",
            ) for i in range(2)
        ])
        self.input_layernorm = nn.ModuleList([
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            for i in range(2)
        ])
        self.post_attention_layernorm = nn.ModuleList([
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            for i in range(2)
        ])

        # Dual MLP structure
        self.mlps = nn.ModuleList([
            FlashMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=None if "mlps" in getattr(
                    config, "disable_quant_module", []) else quant_config,
                prefix=f"{prefix}.mlps.{i}",
            ) for i in range(2)
        ])

        self.mlp = LongcatMoe(
            config=config,
            num_experts=config.n_routed_experts if hasattr(
                config, "n_routed_experts") else
            config.num_experts[self.layer_idx],
            top_k=config.moe_topk
            if hasattr(config, "moe_topk") else config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            quant_config=quant_config,
            prefix=(f"{prefix}.mlp"),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:

        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm[0](hidden_states)
        else:
            hidden_states, residual = self.input_layernorm[0](hidden_states,
                                                              residual)

        hidden_states = self.self_attn[0](
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm[0](
            hidden_states, residual)

        # moe
        hidden_states_copy = hidden_states.clone()
        moe_hidden_states = self.mlp(hidden_states_copy)

        # first mlp
        hidden_states = self.mlps[0](hidden_states)

        hidden_states, residual = self.input_layernorm[1](hidden_states,
                                                          residual)

        # second_attn
        hidden_states = self.self_attn[1](
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm[1](
            hidden_states, residual)

        # second_mlp
        hidden_states = self.mlps[1](hidden_states)

        hidden_states = hidden_states + moe_hidden_states

        return hidden_states, residual
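
The forward pass above interleaves two attention/dense-MLP sub-blocks with a single shortcut-connected MoE branch: the MoE input is taken right after the first attention block, and its output is added back after the second dense MLP. A simplified functional sketch (the residual bookkeeping of the fused RMSNorms is omitted; the callables are stand-ins, not vLLM APIs):

def flash_layer_sketch(h, attn, dense_mlps, moe, in_norms, post_norms):
    # first attention sub-block
    h = post_norms[0](attn[0](in_norms[0](h)))
    moe_out = moe(h)         # MoE branch splits off after the first attention block
    h = dense_mlps[0](h)     # first dense MLP
    # second attention sub-block
    h = post_norms[1](attn[1](in_norms[1](h)))
    h = dense_mlps[1](h)     # second dense MLP
    return h + moe_out       # MoE output rejoins after the second dense MLP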

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = ModuleList(
    [
        (RMSNorm(hidden_size, eps=rms_norm_eps))
        for i in (range(2))
    ]
)

layer_idx instance-attribute

layer_idx = int(prefix.split(sep='.')[-1])

mlp instance-attribute

mlp = LongcatMoe(
    config=config,
    num_experts=n_routed_experts
    if hasattr(config, "n_routed_experts")
    else num_experts[layer_idx],
    top_k=moe_topk
    if hasattr(config, "moe_topk")
    else num_experts_per_tok,
    hidden_size=hidden_size,
    intermediate_size=moe_intermediate_size,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

mlps instance-attribute

mlps = ModuleList(
    [
        (
            FlashMLP(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=hidden_act,
                quant_config=None
                if "mlps"
                in getattr(
                    config, "disable_quant_module", []
                )
                else quant_config,
                prefix=f"{prefix}.mlps.{i}",
            )
        )
        for i in (range(2))
    ]
)

post_attention_layernorm instance-attribute

post_attention_layernorm = ModuleList(
    [
        (RMSNorm(hidden_size, eps=rms_norm_eps))
        for i in (range(2))
    ]
)

self_attn instance-attribute

self_attn = ModuleList(
    [
        (
            DeepseekV2MLAAttention(
                vllm_config=vllm_config,
                config=config,
                hidden_size=hidden_size,
                num_heads=num_attention_heads,
                qk_nope_head_dim=qk_nope_head_dim,
                qk_rope_head_dim=qk_rope_head_dim,
                v_head_dim=v_head_dim,
                q_lora_rank=q_lora_rank
                if hasattr(config, "q_lora_rank")
                else None,
                kv_lora_rank=kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                cache_config=cache_config,
                quant_config=None
                if "self_attn"
                in getattr(
                    config, "disable_quant_module", []
                )
                else quant_config,
                prefix=f"{prefix}.self_attn.{i}",
            )
        )
        for i in (range(2))
    ]
)

__init__

__init__(
    vllm_config: VllmConfig,
    config: FlashConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
) -> None
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(
    self,
    vllm_config: VllmConfig,
    config: FlashConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
) -> None:
    super().__init__()
    self.layer_idx = int(prefix.split(sep='.')[-1])
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      8192)
    if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None):
        rope_scaling["original_max_position_embeddings"] = (
            config.original_max_position_embeddings)

    # Dual attention structure
    self.self_attn = nn.ModuleList([
        DeepseekV2MLAAttention(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=(config.q_lora_rank if hasattr(
                config, "q_lora_rank") else None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=None if "self_attn" in getattr(
                config, "disable_quant_module", []) else quant_config,
            prefix=f"{prefix}.self_attn.{i}",
        ) for i in range(2)
    ])
    self.input_layernorm = nn.ModuleList([
        RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        for i in range(2)
    ])
    self.post_attention_layernorm = nn.ModuleList([
        RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        for i in range(2)
    ])

    # Dual MLP structure
    self.mlps = nn.ModuleList([
        FlashMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=None if "mlps" in getattr(
                config, "disable_quant_module", []) else quant_config,
            prefix=f"{prefix}.mlps.{i}",
        ) for i in range(2)
    ])

    self.mlp = LongcatMoe(
        config=config,
        num_experts=config.n_routed_experts if hasattr(
            config, "n_routed_experts") else
        config.num_experts[self.layer_idx],
        top_k=config.moe_topk
        if hasattr(config, "moe_topk") else config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.moe_intermediate_size,
        quant_config=quant_config,
        prefix=(f"{prefix}.mlp"),
    )

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    residual: Optional[Tensor],
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/longcat_flash.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:

    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm[0](hidden_states)
    else:
        hidden_states, residual = self.input_layernorm[0](hidden_states,
                                                          residual)

    hidden_states = self.self_attn[0](
        positions=positions,
        hidden_states=hidden_states,
    )

    hidden_states, residual = self.post_attention_layernorm[0](
        hidden_states, residual)

    # moe
    hidden_states_copy = hidden_states.clone()
    moe_hidden_states = self.mlp(hidden_states_copy)

    # first mlp
    hidden_states = self.mlps[0](hidden_states)

    hidden_states, residual = self.input_layernorm[1](hidden_states,
                                                      residual)

    # second_attn
    hidden_states = self.self_attn[1](
        positions=positions,
        hidden_states=hidden_states,
    )
    hidden_states, residual = self.post_attention_layernorm[1](
        hidden_states, residual)

    # second_mlp
    hidden_states = self.mlps[1](hidden_states)

    hidden_states = hidden_states + moe_hidden_states

    return hidden_states, residual

FlashMLP

Bases: Module

Flash MLP layer.

Source code in vllm/model_executor/models/longcat_flash.py
class FlashMLP(nn.Module):
    """Flash MLP layer."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.numel() == 0:
            return x

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
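
A minimal functional reference for the computation above (illustrative weights; tensor-parallel sharding and quantization are omitted): the merged gate_up projection produces [gate | up], and SiluAndMul computes silu(gate) * up before the down projection.

import torch
import torch.nn.functional as F

def flash_mlp_reference(x, w_gate_up, w_down):
    gate_up = x @ w_gate_up.t()              # [..., 2 * intermediate_size]
    gate, up = gate_up.chunk(2, dim=-1)
    return (F.silu(gate) * up) @ w_down.t()  # [..., hidden_size]

hidden_size, intermediate_size = 8, 16
x = torch.randn(2, hidden_size)
w_gate_up = torch.randn(2 * intermediate_size, hidden_size)
w_down = torch.randn(hidden_size, intermediate_size)
assert flash_mlp_reference(x, w_gate_up, w_down).shape == (2, hidden_size)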

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    reduce_results=reduce_results,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    reduce_results: bool = True,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    reduce_results: bool = True,
    prefix: str = "",
) -> None:
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        reduce_results=reduce_results,
        prefix=f"{prefix}.down_proj",
    )
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/longcat_flash.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    if x.numel() == 0:
        return x

    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x

FlashModel

Bases: Module

Flash model.

Source code in vllm/model_executor/models/longcat_flash.py
@support_torch_compile
class FlashModel(nn.Module):
    """Flash model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.padding_idx = getattr(config, "pad_token_id", None)
        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "embed_tokens"),
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: FlashDecoderLayer(
                vllm_config,
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers")
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
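
A simplified sketch of the pipeline-parallel handoff in forward (stand-in code, not vLLM APIs): non-last ranks pass both the hidden states and the pending residual along as IntermediateTensors, and only the last rank folds the residual into the final RMSNorm.

def pp_handoff_sketch(hidden_states, residual, is_last_rank, final_norm):
    if not is_last_rank:
        # mirrors IntermediateTensors({"hidden_states": ..., "residual": ...})
        return {"hidden_states": hidden_states, "residual": residual}
    # last rank: approximates the fused add-and-normalize of
    # self.norm(hidden_states, residual)
    return final_norm(hidden_states + residual)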

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size,
    hidden_size,
    prefix=maybe_prefix(prefix, "embed_tokens"),
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

padding_idx instance-attribute

padding_idx = getattr(config, 'pad_token_id', None)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.config = config

    self.padding_idx = getattr(config, "pad_token_id", None)
    self.vocab_size = config.vocab_size

    if get_pp_group().is_first_rank:
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
    else:
        self.embed_tokens = PPMissingLayer()
    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: FlashDecoderLayer(
            vllm_config,
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        ),
        prefix=f"{prefix}.layers")
    if get_pp_group().is_last_rank:
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    else:
        self.norm = PPMissingLayer()
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size))

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/longcat_flash.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/longcat_flash.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids)

LongcatFlashForCausalLM

Bases: Module, SupportsLoRA, SupportsPP

Flash model for causal language modeling.

Source code in vllm/model_executor/models/longcat_flash.py
class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Flash model for causal language modeling."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        config.intermediate_size = config.ffn_hidden_size if hasattr(
            config, "ffn_hidden_size") else config.intermediate_size
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = FlashModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts if hasattr(
                self.config, "n_routed_experts") else
            self.config.num_experts[0],
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:

        stacked_params_mapping = [
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        expert_params_mapping = self.get_expert_mapping()
        loaded_params: set[str] = set()

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp" in name and "mlps" not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (name.endswith(".bias")
                        or name.endswith("_bias")) and name not in params_dict:
                    continue
                # Skip mtp
                if ".mtp." in name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    # Skip mtp
                    if ".mtp." in name_mapped:
                        continue
                    if (name_mapped.endswith(".bias")
                            or name_mapped.endswith("_bias")
                        ) and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name_mapped]
                    weight_loader = param.weight_loader
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip loading kv_scale from ckpts towards new design.
                    if name.endswith(".kv_scale") and name not in params_dict:
                        continue
                    # Skip mtp
                    if ".mtp." in name:
                        continue
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        for layer_id in range(self.config.num_hidden_layers):
            for i in range(2):
                if isinstance(self.model.layers[layer_id], PPMissingLayer):
                    continue
                self_attn = self.model.layers[layer_id].self_attn[i]
                if hasattr(self.quant_config, "weight_block_size"
                           ) and self_attn.kv_b_proj.weight.dtype in (
                               torch.float8_e4m3fn,
                               torch.float8_e4m3fnuz,
                           ):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        dtype = torch.get_default_dtype()
                        w = block_dequant(self_attn.kv_b_proj.weight,
                                          self_attn.kv_b_proj.weight_scale_inv,
                                          weight_block_size).to(dtype)
                else:
                    w = self_attn.kv_b_proj.weight

                w_kc, w_vc = w.unflatten(
                    0,
                    (-1,
                     self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
                         [self_attn.qk_nope_head_dim, self_attn.v_head_dim],
                         dim=1)
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(
                    1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if self.config.mla_scale_q_lora:
                    self_attn.q_a_layernorm.weight.data *= (
                        self.config.hidden_size / self.config.q_lora_rank)**0.5
                if self.config.mla_scale_kv_lora:
                    self_attn.kv_a_layernorm.weight.data *= (
                        self.config.hidden_size /
                        self.config.kv_lora_rank)**0.5
        return loaded_params
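
After loading, load_weights splits each attention block's kv_b_proj weight into the MLA absorption matrices w_kc and w_vc; the shape bookkeeping is easier to see in isolation (head count and dims below are illustrative). When mla_scale_q_lora / mla_scale_kv_lora are set, the corresponding LoRA layernorm weights are additionally scaled by sqrt(hidden_size / lora_rank).

import torch

n_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512
w = torch.randn(n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1)
assert w_kc.shape == (n_heads, qk_nope_head_dim, kv_lora_rank)
assert w_vc.shape == (n_heads, v_head_dim, kv_lora_rank)

# LoRA layernorm rescale applied when mla_scale_q_lora is enabled (illustrative values)
hidden_size, q_lora_rank = 4096, 1536
q_norm_scale = (hidden_size / q_lora_rank) ** 0.5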

config instance-attribute

config = config

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "lm_head"),
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

lora_config instance-attribute

lora_config = lora_config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = FlashModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config

    self.config = config
    config.intermediate_size = config.ffn_hidden_size if hasattr(
        config, "ffn_hidden_size") else config.intermediate_size
    self.lora_config = lora_config
    self.quant_config = quant_config

    self.model = FlashModel(vllm_config=vllm_config,
                            prefix=maybe_prefix(prefix, "model"))

    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(
                                          prefix, "lm_head"))
    else:
        self.lm_head = PPMissingLayer()

    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/longcat_flash.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/longcat_flash.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds)
    return hidden_states

get_expert_mapping

get_expert_mapping() -> list[tuple[str, str, int, str]]
Source code in vllm/model_executor/models/longcat_flash.py
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    return FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts if hasattr(
            self.config, "n_routed_experts") else
        self.config.num_experts[0],
    )

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/longcat_flash.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.model.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/longcat_flash.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:

    stacked_params_mapping = [
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    expert_params_mapping = self.get_expert_mapping()
    loaded_params: set[str] = set()

    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            if "mlp" in name and "mlps" not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if (name.endswith(".bias")
                    or name.endswith("_bias")) and name not in params_dict:
                continue
            # Skip mtp
            if ".mtp." in name:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                is_expert_weight = True
                name_mapped = name.replace(weight_name, param_name)
                # Skip mtp
                if ".mtp." in name_mapped:
                    continue
                if (name_mapped.endswith(".bias")
                        or name_mapped.endswith("_bias")
                    ) and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name_mapped]
                weight_loader = param.weight_loader
                weight_loader = typing.cast(Callable[..., bool],
                                            param.weight_loader)
                success = weight_loader(param,
                                        loaded_weight,
                                        name_mapped,
                                        shard_id=shard_id,
                                        expert_id=expert_id,
                                        return_success=True)
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                # Skip mtp
                if ".mtp." in name:
                    continue
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
        loaded_params.add(name)
    for layer_id in range(self.config.num_hidden_layers):
        for i in range(2):
            if isinstance(self.model.layers[layer_id], PPMissingLayer):
                continue
            self_attn = self.model.layers[layer_id].self_attn[i]
            if hasattr(self.quant_config, "weight_block_size"
                       ) and self_attn.kv_b_proj.weight.dtype in (
                           torch.float8_e4m3fn,
                           torch.float8_e4m3fnuz,
                       ):
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    dtype = torch.get_default_dtype()
                    w = block_dequant(self_attn.kv_b_proj.weight,
                                      self_attn.kv_b_proj.weight_scale_inv,
                                      weight_block_size).to(dtype)
            else:
                w = self_attn.kv_b_proj.weight

            w_kc, w_vc = w.unflatten(
                0,
                (-1,
                 self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
                     [self_attn.qk_nope_head_dim, self_attn.v_head_dim],
                     dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(
                1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if self.config.mla_scale_q_lora:
                self_attn.q_a_layernorm.weight.data *= (
                    self.config.hidden_size / self.config.q_lora_rank)**0.5
            if self.config.mla_scale_kv_lora:
                self_attn.kv_a_layernorm.weight.data *= (
                    self.config.hidden_size /
                    self.config.kv_lora_rank)**0.5
    return loaded_params

LongcatMoe

Bases: Module

Source code in vllm/model_executor/models/longcat_flash.py
class LongcatMoe(nn.Module):

    def __init__(
        self,
        config: FlashConfig,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.zero_expert_num = config.zero_expert_num
        self.zero_expert_type = config.zero_expert_type
        self.routed_scaling_factor = config.routed_scaling_factor
        self.enable_eplb = enable_eplb
        # Gate always runs at half / full precision for now.
        self.rounter_params_dtype = params_dtype
        if config.router_dtype == "float32":
            self.rounter_params_dtype = torch.float32

        self.router = LongcatRouter(
            config=config,
            zero_expert_num=self.zero_expert_num,
            rounter_params_dtype=self.rounter_params_dtype,
            prefix=f"{prefix}.gate")

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,
            params_dtype=params_dtype,
            e_score_correction_bias=self.router.e_score_correction_bias,
            renormalize=False,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            zero_expert_num=self.zero_expert_num,
            zero_expert_type=self.zero_expert_type,
            enable_eplb=self.enable_eplb,
            routed_scaling_factor=config.routed_scaling_factor,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = self.router(hidden_states.to(
            self.rounter_params_dtype))
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)

        return final_hidden_states.view(num_tokens, hidden_dim)
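
A small dtype/shape sketch of the routing path in forward (illustrative, not vLLM APIs): hidden states are cast to the router dtype (float32 when router_dtype == "float32") for the gate, while the expert computation stays in the model dtype, and the output keeps the [num_tokens, hidden] shape.

import torch

router_dtype = torch.float32
hidden_states = torch.randn(8, 4096, dtype=torch.bfloat16)
router_input = hidden_states.to(router_dtype)   # gate runs in float32
assert router_input.dtype == torch.float32
assert hidden_states.dtype == torch.bfloat16    # experts see the original dtype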

enable_eplb instance-attribute

enable_eplb = enable_eplb

experts instance-attribute

experts = FusedMoE(
    num_experts=num_experts,
    top_k=top_k,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    reduce_results=True,
    params_dtype=params_dtype,
    e_score_correction_bias=e_score_correction_bias,
    renormalize=False,
    quant_config=quant_config,
    prefix=f"{prefix}.experts",
    zero_expert_num=zero_expert_num,
    zero_expert_type=zero_expert_type,
    enable_eplb=enable_eplb,
    routed_scaling_factor=routed_scaling_factor,
)

hidden_size instance-attribute

hidden_size = hidden_size

rounter_params_dtype instance-attribute

rounter_params_dtype = params_dtype

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

router instance-attribute

router = LongcatRouter(
    config=config,
    zero_expert_num=zero_expert_num,
    rounter_params_dtype=rounter_params_dtype,
    prefix=f"{prefix}.gate",
)

zero_expert_num instance-attribute

zero_expert_num = zero_expert_num

zero_expert_type instance-attribute

zero_expert_type = zero_expert_type

__init__

__init__(
    config: FlashConfig,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
)
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(
    self,
    config: FlashConfig,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
):
    super().__init__()
    self.hidden_size = hidden_size
    self.zero_expert_num = config.zero_expert_num
    self.zero_expert_type = config.zero_expert_type
    self.routed_scaling_factor = config.routed_scaling_factor
    self.enable_eplb = enable_eplb
    # Gate always runs at half / full precision for now.
    self.rounter_params_dtype = params_dtype
    if config.router_dtype == "float32":
        self.rounter_params_dtype = torch.float32

    self.router = LongcatRouter(
        config=config,
        zero_expert_num=self.zero_expert_num,
        rounter_params_dtype=self.rounter_params_dtype,
        prefix=f"{prefix}.gate")

    self.experts = FusedMoE(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        reduce_results=True,
        params_dtype=params_dtype,
        e_score_correction_bias=self.router.e_score_correction_bias,
        renormalize=False,
        quant_config=quant_config,
        prefix=f"{prefix}.experts",
        zero_expert_num=self.zero_expert_num,
        zero_expert_type=self.zero_expert_type,
        enable_eplb=self.enable_eplb,
        routed_scaling_factor=config.routed_scaling_factor,
    )

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/longcat_flash.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    router_logits = self.router(hidden_states.to(
        self.rounter_params_dtype))
    final_hidden_states = self.experts(hidden_states=hidden_states,
                                       router_logits=router_logits)

    return final_hidden_states.view(num_tokens, hidden_dim)

LongcatRouter

Bases: Module

Source code in vllm/model_executor/models/longcat_flash.py
class LongcatRouter(nn.Module):

    def __init__(self,
                 config,
                 zero_expert_num=0,
                 rounter_params_dtype=torch.bfloat16,
                 prefix: str = ""):
        super().__init__()
        self.n_routed_experts = config.n_routed_experts if hasattr(
            config, "n_routed_experts") else config.num_experts[0]
        self.n_routed_experts = self.n_routed_experts + zero_expert_num
        self.classifier = ReplicatedLinear(
            config.hidden_size,
            self.n_routed_experts,
            bias=config.router_bias,
            params_dtype=rounter_params_dtype,
            quant_config=None,
            prefix=f"{prefix}.classifier",
        )
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype))

    def forward(self, hidden_states):
        logits, _ = self.classifier(hidden_states)
        return logits
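
The router's classifier scores the routed experts and the configured zero-computation experts together, so its logits have n_routed_experts + zero_expert_num columns. A standalone shape check with illustrative sizes (plain torch.nn.Linear standing in for ReplicatedLinear):

import torch

hidden_size, n_routed_experts, zero_expert_num = 4096, 512, 256
classifier = torch.nn.Linear(hidden_size, n_routed_experts + zero_expert_num,
                             bias=False, dtype=torch.float32)
logits = classifier(torch.randn(8, hidden_size, dtype=torch.float32))
assert logits.shape == (8, n_routed_experts + zero_expert_num)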

classifier instance-attribute

classifier = ReplicatedLinear(
    hidden_size,
    n_routed_experts,
    bias=router_bias,
    params_dtype=rounter_params_dtype,
    quant_config=None,
    prefix=f"{prefix}.classifier",
)

e_score_correction_bias instance-attribute

e_score_correction_bias = Parameter(
    zeros(n_routed_experts, dtype=rounter_params_dtype)
)

n_routed_experts instance-attribute

n_routed_experts = n_routed_experts + zero_expert_num

__init__

__init__(
    config,
    zero_expert_num=0,
    rounter_params_dtype=bfloat16,
    prefix: str = "",
)
Source code in vllm/model_executor/models/longcat_flash.py
def __init__(self,
             config,
             zero_expert_num=0,
             rounter_params_dtype=torch.bfloat16,
             prefix: str = ""):
    super().__init__()
    self.n_routed_experts = config.n_routed_experts if hasattr(
        config, "n_routed_experts") else config.num_experts[0]
    self.n_routed_experts = self.n_routed_experts + zero_expert_num
    self.classifier = ReplicatedLinear(
        config.hidden_size,
        self.n_routed_experts,
        bias=config.router_bias,
        params_dtype=rounter_params_dtype,
        quant_config=None,
        prefix=f"{prefix}.classifier",
    )
    self.e_score_correction_bias = nn.Parameter(
        torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype))

forward

forward(hidden_states)
Source code in vllm/model_executor/models/longcat_flash.py
def forward(self, hidden_states):
    logits, _ = self.classifier(hidden_states)
    return logits