
vllm.model_executor.models.gemma3n

EPS module-attribute

EPS = tensor(min)

logger module-attribute

logger = init_logger(__name__)

Gemma3nAltUp

Bases: Module

Alternating updates (AltUp). The AltUp module wraps transformer layers. The predict step modifies the input to the transformer layer, and the correct step propagates the output of the transformer layer to the sparsely updated dimensions. See the research paper: https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
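
The following is a minimal, self-contained sketch of the predict/correct math using plain torch.nn.Linear layers in place of vLLM's ReplicatedLinear and omitting the float casts; all sizes are illustrative and not taken from any Gemma 3n config.

import torch
import torch.nn as nn

num_inputs, active_idx, tokens, hidden = 4, 0, 3, 8
router = nn.Linear(hidden, num_inputs, bias=False)
prediction_coefs = nn.Linear(num_inputs, num_inputs**2, bias=False)
correction_coefs = nn.Linear(num_inputs, num_inputs, bias=False)

x = torch.randn(num_inputs, tokens, hidden)            # altup input streams

# predict: mix the streams with router-dependent coefficients.
modalities = torch.tanh(router(x[active_idx]))         # [tokens, num_inputs]
coefs = prediction_coefs(modalities).reshape(-1, num_inputs, num_inputs)
predictions = torch.matmul(x.permute(1, 2, 0), coefs.permute(0, 2, 1))
predictions = predictions.permute(2, 0, 1) + x         # [num_inputs, tokens, hidden]

# ... the wrapped transformer layer runs on predictions[active_idx] ...
activated = predictions[active_idx] + 0.01 * torch.randn(tokens, hidden)

# correct: broadcast the layer's output ("innovation") back to every stream.
modalities = torch.tanh(router(activated))
innovation = (activated - predictions[active_idx]).repeat(num_inputs, 1, 1)
coefs = (correction_coefs(modalities) + 1.0).T.unsqueeze(-1)
corrected = innovation * coefs + predictions           # [num_inputs, tokens, hidden]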

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nAltUp(nn.Module):
    """Alternating updates (Altup)
    The AltUp module wraps transformer layers. The `predict` step modifies the
    input to the transformer layer, and the `correct` step propagates the output
    of the transformer layer to the sparsely updated dimensions.
    See more in the research paper:
    https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
    """

    def __init__(
        self,
        hidden_size: int,
        rms_norm_eps: float,
        altup_num_inputs: int,
        altup_coef_clip: float,
        altup_active_idx: int,
        quant_config: QuantizationConfig,
        prefix: str,
    ):
        super().__init__()

        self.altup_num_inputs = altup_num_inputs
        self.altup_active_idx = altup_active_idx
        self.altup_coef_clip = altup_coef_clip

        self.correction_coefs = ReplicatedLinear(
            altup_num_inputs,
            altup_num_inputs,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.correction_coefs",
            return_bias=False,
        )
        self.prediction_coefs = ReplicatedLinear(
            altup_num_inputs,
            altup_num_inputs**2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.prediction_coefs",
            return_bias=False,
        )
        self.modality_router = ReplicatedLinear(
            hidden_size,
            altup_num_inputs,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.modality_router",
            return_bias=False,
        )
        self.router_norm = RMSNorm(
            hidden_size=hidden_size,
            eps=rms_norm_eps,
        )
        self.router_input_scale = torch.tensor(
            hidden_size**-1.0, dtype=self.modality_router.weight.dtype)
        self.correct_output_scale = nn.Parameter(
            torch.zeros(hidden_size, dtype=torch.float32))

    def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(x)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        return (corrected.type_as(self.correct_output_scale) *
                self.correct_output_scale).type_as(corrected)

    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden:       [altup_num_inputs, num_tokens, hidden_size]
        # modalities:   [num_tokens, num_altup_inputs]
        # all_coefs:    [num_tokens, num_altup_inputs ** 2]
        modalities = self._compute_router_modalities(
            hidden_states[self.altup_active_idx])
        all_coefs = self.prediction_coefs(modalities)

        # Reshape and transpose the 2D matrix for the matmul.
        # all_coefs_T:  [num_tokens, num_altup_inputs, num_altup_inputs]
        all_coefs_T = all_coefs.reshape(
            -1,
            self.altup_num_inputs,
            self.altup_num_inputs,
        ).permute(0, 2, 1)

        # hidden_states to [num_tokens, hidden_size, altup_num_inputs]
        predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
        # [altup_num_inputs, num_tokens, hidden_size]
        predictions = predictions.permute(2, 0, 1)
        predictions += hidden_states
        return predictions.contiguous()

    def correct(self, predictions: torch.Tensor,
                activated: torch.Tensor) -> torch.Tensor:
        # predictions:  [altup_num_inputs, num_tokens, hidden_size]
        # activated:    [num_tokens, hidden_size]
        # modalities:   [num_tokens, altup_num_inputs]
        modalities = self._compute_router_modalities(activated)
        # innovation:   [num_tokens, hidden_size]
        innovation = activated - predictions[self.altup_active_idx]
        # innovation:   [altup_num_inputs, num_tokens, hidden_size]
        innovation = innovation.repeat(self.altup_num_inputs, 1, 1)

        # Permute to [altup_num_inputs, num_tokens] as the last dim
        # is a scalar applied to each altup input and expand on
        # num_tokens dim for broadcastability over hidden_size.
        # all_coefs:    [num_tokens, altup_num_inputs]
        all_coefs = self.correction_coefs(modalities) + 1.0
        # all_coefs:    [altup_num_inputs, num_tokens, 1]
        all_coefs = all_coefs.T.unsqueeze(-1)

        # Elementwise (broadcast over hidden_size).
        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions

        return corrected.contiguous()

altup_active_idx instance-attribute

altup_active_idx = altup_active_idx

altup_coef_clip instance-attribute

altup_coef_clip = altup_coef_clip

altup_num_inputs instance-attribute

altup_num_inputs = altup_num_inputs

correct_output_scale instance-attribute

correct_output_scale = Parameter(
    zeros(hidden_size, dtype=float32)
)

correction_coefs instance-attribute

correction_coefs = ReplicatedLinear(
    altup_num_inputs,
    altup_num_inputs,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.correction_coefs",
    return_bias=False,
)

modality_router instance-attribute

modality_router = ReplicatedLinear(
    hidden_size,
    altup_num_inputs,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.modality_router",
    return_bias=False,
)

prediction_coefs instance-attribute

prediction_coefs = ReplicatedLinear(
    altup_num_inputs,
    altup_num_inputs**2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.prediction_coefs",
    return_bias=False,
)

router_input_scale instance-attribute

router_input_scale = tensor(hidden_size**-1.0, dtype=dtype)

router_norm instance-attribute

router_norm = RMSNorm(
    hidden_size=hidden_size, eps=rms_norm_eps
)

__init__

__init__(
    hidden_size: int,
    rms_norm_eps: float,
    altup_num_inputs: int,
    altup_coef_clip: float,
    altup_active_idx: int,
    quant_config: QuantizationConfig,
    prefix: str,
)
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    hidden_size: int,
    rms_norm_eps: float,
    altup_num_inputs: int,
    altup_coef_clip: float,
    altup_active_idx: int,
    quant_config: QuantizationConfig,
    prefix: str,
):
    super().__init__()

    self.altup_num_inputs = altup_num_inputs
    self.altup_active_idx = altup_active_idx
    self.altup_coef_clip = altup_coef_clip

    self.correction_coefs = ReplicatedLinear(
        altup_num_inputs,
        altup_num_inputs,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.correction_coefs",
        return_bias=False,
    )
    self.prediction_coefs = ReplicatedLinear(
        altup_num_inputs,
        altup_num_inputs**2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.prediction_coefs",
        return_bias=False,
    )
    self.modality_router = ReplicatedLinear(
        hidden_size,
        altup_num_inputs,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.modality_router",
        return_bias=False,
    )
    self.router_norm = RMSNorm(
        hidden_size=hidden_size,
        eps=rms_norm_eps,
    )
    self.router_input_scale = torch.tensor(
        hidden_size**-1.0, dtype=self.modality_router.weight.dtype)
    self.correct_output_scale = nn.Parameter(
        torch.zeros(hidden_size, dtype=torch.float32))

_compute_router_modalities

_compute_router_modalities(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
    router_inputs = self.router_norm(x) * self.router_input_scale
    routed = self.modality_router(router_inputs)
    return torch.tanh(routed.float()).type_as(x)

correct

correct(predictions: Tensor, activated: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def correct(self, predictions: torch.Tensor,
            activated: torch.Tensor) -> torch.Tensor:
    # predictions:  [altup_num_inputs, num_tokens, hidden_size]
    # activated:    [num_tokens, hidden_size]
    # modalities:   [num_tokens, altup_num_inputs]
    modalities = self._compute_router_modalities(activated)
    # innovation:   [num_tokens, hidden_size]
    innovation = activated - predictions[self.altup_active_idx]
    # innovation:   [altup_num_inputs, num_tokens, hidden_size]
    innovation = innovation.repeat(self.altup_num_inputs, 1, 1)

    # Permute to [altup_num_inputs, num_tokens] as the last dim
    # is a scalar applied to each altup input and expand on
    # num_tokens dim for broadcastability over hidden_size.
    # all_coefs:    [num_tokens, altup_num_inputs]
    all_coefs = self.correction_coefs(modalities) + 1.0
    # all_coefs:    [altup_num_inputs, num_tokens, 1]
    all_coefs = all_coefs.T.unsqueeze(-1)

    # Elementwise (broadcast over hidden_size).
    corrected = torch.mul(innovation, all_coefs)
    corrected += predictions

    return corrected.contiguous()

predict

predict(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden:       [altup_num_inputs, num_tokens, hidden_size]
    # modalities:   [num_tokens, num_altup_inputs]
    # all_coefs:    [num_tokens, num_altup_inputs ** 2]
    modalities = self._compute_router_modalities(
        hidden_states[self.altup_active_idx])
    all_coefs = self.prediction_coefs(modalities)

    # Reshape and transpose the 2D matrix for the matmul.
    # all_coefs_T:  [num_tokens, num_altup_inputs, num_altup_inputs]
    all_coefs_T = all_coefs.reshape(
        -1,
        self.altup_num_inputs,
        self.altup_num_inputs,
    ).permute(0, 2, 1)

    # hidden_states to [num_tokens, hidden_size, altup_num_inputs]
    predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
    # [altup_num_inputs, num_tokens, hidden_size]
    predictions = predictions.permute(2, 0, 1)
    predictions += hidden_states
    return predictions.contiguous()

scale_corrected_output

scale_corrected_output(corrected: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
    return (corrected.type_as(self.correct_output_scale) *
            self.correct_output_scale).type_as(corrected)

Gemma3nAttention

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nAttention(nn.Module):

    def __init__(self,
                 config: Gemma3nTextConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps)
        self.v_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps,
                              has_weight=False)

        layer_idx = extract_layer_index(prefix)
        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding else None

        # Initialize the rotary embedding.
        if is_sliding:
            # Local attention. Override the values in config.json.
            rope_theta = config.rope_local_base_freq
            rope_scaling = {"rope_type": "default"}
        else:
            # Global attention. Use the values in config.json.
            rope_theta = config.rope_theta
            rope_scaling = config.rope_scaling

        first_kv_shared_layer_idx = (config.num_hidden_layers -
                                     config.num_kv_shared_layers)
        self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx

        kv_sharing_target_layer_name = None
        if self.is_kv_shared:
            # Last full attention layer is 1 before sharing
            # Last sliding attention layer is 2 before sharing
            offset = 2 if self.sliding_window is not None else 1
            kv_shared_layer_index = first_kv_shared_layer_idx - offset
            if kv_shared_layer_index >= 0:
                # Only the later (sharing) layer needs to specify the target.
                kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn"  # noqa: E501

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=True,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=1.0,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=self.sliding_window,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = self.k_norm(k)
        k = k.flatten(-2, -1)
        v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
        v = self.v_norm(v)
        v = v.flatten(-2, -1)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        output, _ = self.o_proj(attn_output)
        return output
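
A worked example of the KV-sharing index arithmetic in __init__, using hypothetical config values (num_hidden_layers=30, num_kv_shared_layers=10); the offsets mirror the code above and are not tied to any specific checkpoint.

from typing import Optional

num_hidden_layers, num_kv_shared_layers = 30, 10       # hypothetical values
first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # 20

def shared_target(layer_idx: int, is_sliding: bool) -> Optional[str]:
    # Layers before the boundary compute and store their own KV cache.
    if layer_idx < first_kv_shared_layer_idx:
        return None
    # Sliding layers reuse the last sliding layer (2 before the boundary);
    # full-attention layers reuse the last full layer (1 before it).
    offset = 2 if is_sliding else 1
    target = first_kv_shared_layer_idx - offset
    return f"language_model.model.layers.{target}.self_attn.attn"

print(shared_target(25, is_sliding=True))    # ...layers.18.self_attn.attn
print(shared_target(29, is_sliding=False))   # ...layers.19.self_attn.attn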

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    scale=1.0,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    per_layer_sliding_window=sliding_window,
    kv_sharing_target_layer_name=kv_sharing_target_layer_name,
    prefix=f"{prefix}.attn",
)

config instance-attribute

config = config

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

is_kv_shared instance-attribute

is_kv_shared = layer_idx >= first_kv_shared_layer_idx

k_norm instance-attribute

k_norm = RMSNorm(hidden_size=head_dim, eps=rms_norm_eps)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=attention_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

q_norm instance-attribute

q_norm = RMSNorm(hidden_size=head_dim, eps=rms_norm_eps)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=attention_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    is_neox_style=True,
    rope_scaling=rope_scaling,
)

sliding_window instance-attribute

sliding_window = sliding_window if is_sliding else None

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

v_norm instance-attribute

v_norm = RMSNorm(
    hidden_size=head_dim, eps=rms_norm_eps, has_weight=False
)

__init__

__init__(
    config: Gemma3nTextConfig,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_position_embeddings: int,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self,
             config: Gemma3nTextConfig,
             hidden_size: int,
             num_heads: int,
             num_kv_heads: int,
             head_dim: int,
             max_position_embeddings: int,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = "") -> None:
    super().__init__()
    self.config = config
    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than or equal to TP size, so we
        # partition the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = head_dim
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=config.attention_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.total_num_heads * self.head_dim,
        hidden_size,
        bias=config.attention_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.q_norm = RMSNorm(hidden_size=self.head_dim,
                          eps=config.rms_norm_eps)
    self.k_norm = RMSNorm(hidden_size=self.head_dim,
                          eps=config.rms_norm_eps)
    self.v_norm = RMSNorm(hidden_size=self.head_dim,
                          eps=config.rms_norm_eps,
                          has_weight=False)

    layer_idx = extract_layer_index(prefix)
    is_sliding = config.layer_types[layer_idx] == "sliding_attention"
    self.sliding_window = config.sliding_window if is_sliding else None

    # Initialize the rotary embedding.
    if is_sliding:
        # Local attention. Override the values in config.json.
        rope_theta = config.rope_local_base_freq
        rope_scaling = {"rope_type": "default"}
    else:
        # Global attention. Use the values in config.json.
        rope_theta = config.rope_theta
        rope_scaling = config.rope_scaling

    first_kv_shared_layer_idx = (config.num_hidden_layers -
                                 config.num_kv_shared_layers)
    self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx

    kv_sharing_target_layer_name = None
    if self.is_kv_shared:
        # Last full attention layer is 1 before sharing
        # Last sliding attention layer is 2 before sharing
        offset = 2 if self.sliding_window is not None else 1
        kv_shared_layer_index = first_kv_shared_layer_idx - offset
        if kv_shared_layer_index >= 0:
            # Only the later (sharing) layer needs to specify the target.
            kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn"  # noqa: E501

    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        is_neox_style=True,
        rope_scaling=rope_scaling,
    )

    self.attn = Attention(
        num_heads=self.num_heads,
        head_size=self.head_dim,
        scale=1.0,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        per_layer_sliding_window=self.sliding_window,
        kv_sharing_target_layer_name=kv_sharing_target_layer_name,
        prefix=f"{prefix}.attn")

forward

forward(
    positions: Tensor, hidden_states: Tensor, **kwargs
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    q = q.unflatten(-1, (self.num_heads, self.head_dim))
    q = self.q_norm(q)
    q = q.flatten(-2, -1)
    k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
    k = self.k_norm(k)
    k = k.flatten(-2, -1)
    v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
    v = self.v_norm(v)
    v = v.flatten(-2, -1)

    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)

    output, _ = self.o_proj(attn_output)
    return output
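
The unflatten/norm/flatten pattern in forward applies RMSNorm per attention head (over head_dim) rather than over the full projection width. A plain-tensor sketch with illustrative sizes, where torch.nn.RMSNorm stands in for vLLM's RMSNorm:

import torch
import torch.nn as nn

tokens, num_heads, head_dim = 3, 2, 4
q_norm = nn.RMSNorm(head_dim, eps=1e-6)      # requires torch >= 2.4

q = torch.randn(tokens, num_heads * head_dim)
q = q.unflatten(-1, (num_heads, head_dim))   # [tokens, num_heads, head_dim]
q = q_norm(q)                                # normalize each head over head_dim
q = q.flatten(-2, -1)                        # back to [tokens, num_heads * head_dim]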

Gemma3nCrossDecoder

Bases: Module

Cross-decoder layers

Source code in vllm/model_executor/models/gemma3n.py
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
                       kv_sharing_fast_prefill)
class Gemma3nCrossDecoder(nn.Module):
    """
    Cross-decoder layers
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma3nDecoderLayer],
        layer_idx_start: int,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = hidden_states.permute(2, 0, 1)
        for idx, layer in enumerate(self.decoder_layers):
            layer_idx = idx + self.layer_idx_start
            # [altup_num_inputs, num_tokens, hidden_size]
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                per_layer_input=per_layer_inputs[:, layer_idx, :],
                **kwargs,
            )
        # [num_tokens, hidden_size, altup_num_inputs]
        hidden_states = hidden_states.permute(1, 2, 0)
        return hidden_states
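
Illustrative shapes for the layer loop: per_layer_inputs is indexed with the absolute layer index (idx + layer_idx_start), and the altup streams are carried as the leading dimension while inside the loop. All sizes below are made up.

import torch

num_tokens, hidden_size, altup_num_inputs = 3, 8, 4
num_hidden_layers, hidden_size_per_layer_input = 30, 4
layer_idx_start = 20                                    # hypothetical split point

hidden_states = torch.randn(num_tokens, hidden_size, altup_num_inputs)
per_layer_inputs = torch.randn(
    num_tokens, num_hidden_layers, hidden_size_per_layer_input)

hidden_states = hidden_states.permute(2, 0, 1)          # [altup, tokens, hidden]
for idx in range(2):                                    # stand-in for the layers
    layer_idx = idx + layer_idx_start                   # absolute layer index
    per_layer_input = per_layer_inputs[:, layer_idx, :]  # [tokens, per_layer_dim]
hidden_states = hidden_states.permute(1, 2, 0)          # [tokens, hidden, altup]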

decoder_layers instance-attribute

decoder_layers = decoder_layers

layer_idx_start instance-attribute

layer_idx_start = layer_idx_start

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    decoder_layers: list[Gemma3nDecoderLayer],
    layer_idx_start: int,
)
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    decoder_layers: list[Gemma3nDecoderLayer],
    layer_idx_start: int,
):
    super().__init__()
    self.decoder_layers = decoder_layers
    self.layer_idx_start = layer_idx_start

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    per_layer_inputs: Tensor,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_inputs: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    # [altup_num_inputs, num_tokens, hidden_size]
    hidden_states = hidden_states.permute(2, 0, 1)
    for idx, layer in enumerate(self.decoder_layers):
        layer_idx = idx + self.layer_idx_start
        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = layer(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_input=per_layer_inputs[:, layer_idx, :],
            **kwargs,
        )
    # [num_tokens, hidden_size, altup_num_inputs]
    hidden_states = hidden_states.permute(1, 2, 0)
    return hidden_states

Gemma3nDecoderLayer

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nDecoderLayer(nn.Module):

    def __init__(
        self,
        config: Gemma3nTextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        assert isinstance(config, Gemma3nTextConfig)
        self.altup_active_idx = config.altup_active_idx
        assert config.altup_correct_scale

        self.altup = Gemma3nAltUp(
            hidden_size=config.hidden_size,
            rms_norm_eps=config.rms_norm_eps,
            altup_num_inputs=config.altup_num_inputs,
            altup_coef_clip=config.altup_coef_clip,
            altup_active_idx=config.altup_active_idx,
            quant_config=quant_config,
            prefix=f"{prefix}.altup",
        )
        self.self_attn = Gemma3nAttention(
            config=config,
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Gemma3nMLP(
            hidden_size=config.hidden_size,
            # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501
            intermediate_size=config.intermediate_size[extract_layer_index(
                prefix)],
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            activation_sparsity=config.activation_sparsity_pattern[
                extract_layer_index(prefix)],
            prefix=f"{prefix}.mlp",
        )
        self.laurel = Gemma3nLaurelBlock(
            hidden_size=config.hidden_size,
            laurel_rank=config.laurel_rank,
            rms_norm_eps=config.rms_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.laurel",
        )

        # NOTE(rob): should be ColumnParallelLinear and RowParallelLinear
        # But, we need to add per_layer_input_gate(x) to per_layer_input.
        # per_layer_input cannot be sharded, so we replicate for now.
        self.per_layer_input_gate = ReplicatedLinear(
            config.hidden_size,
            config.hidden_size_per_layer_input,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.per_layer_input_gate",
            return_bias=False,
        )
        self.per_layer_projection = ReplicatedLinear(
            config.hidden_size_per_layer_input,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.per_layer_projection",
            return_bias=False,
        )

        # LayerNorms.
        self.input_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.pre_feedforward_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_feedforward_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_per_layer_input_norm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )

        self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation]

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_input: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:

        # AltUp (predict).
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[self.altup_active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)

        # Attention.
        attn = self.self_attn(
            positions=positions,
            hidden_states=active_prediction_normed,
            **kwargs,
        )
        attn = self.post_attention_layernorm(attn)
        attn_gated = attn + active_prediction
        attn_laurel = (attn_gated + laurel_output) / torch.sqrt(
            torch.tensor(2.0))

        # MLP.
        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm

        # AltUp (correct).
        corrected_predictions = self.altup.correct(predictions,
                                                   attn_ffw_laurel_gated)
        first_prediction = corrected_predictions[self.altup_active_idx]
        first_prediction = self.altup.scale_corrected_output(first_prediction)

        # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.mul(first_prediction, per_layer_input)

        # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        corrected_predictions[1:] += first_prediction

        return corrected_predictions
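
A minimal sketch of the per-layer-input path at the end of forward, with plain tensors standing in for the ReplicatedLinear layers and gelu standing in for act_fn; sizes are illustrative.

import torch
import torch.nn.functional as F

tokens, hidden_size, per_layer_dim = 3, 8, 4
w_gate = torch.randn(per_layer_dim, hidden_size)        # per_layer_input_gate weight
w_proj = torch.randn(hidden_size, per_layer_dim)        # per_layer_projection weight

first_prediction = torch.randn(tokens, hidden_size)
per_layer_input = torch.randn(tokens, per_layer_dim)

# einsum("btd,dp->btp", ...) is x @ w_gate.T for a flattened token dimension.
gated = F.gelu(first_prediction @ w_gate.T, approximate="tanh") * per_layer_input
# einsum("btp,pd->btd", ...) maps back to hidden_size; vLLM then RMS-normalizes
# the result and adds it to every altup stream except the first.
projected = gated @ w_proj.T                            # [tokens, hidden_size]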

act_fn instance-attribute

act_fn = _ACTIVATION_REGISTRY[hidden_activation]

altup instance-attribute

altup = Gemma3nAltUp(
    hidden_size=hidden_size,
    rms_norm_eps=rms_norm_eps,
    altup_num_inputs=altup_num_inputs,
    altup_coef_clip=altup_coef_clip,
    altup_active_idx=altup_active_idx,
    quant_config=quant_config,
    prefix=f"{prefix}.altup",
)

altup_active_idx instance-attribute

altup_active_idx = altup_active_idx

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

laurel instance-attribute

laurel = Gemma3nLaurelBlock(
    hidden_size=hidden_size,
    laurel_rank=laurel_rank,
    rms_norm_eps=rms_norm_eps,
    quant_config=quant_config,
    prefix=f"{prefix}.laurel",
)

mlp instance-attribute

mlp = Gemma3nMLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size[
        extract_layer_index(prefix)
    ],
    hidden_activation=hidden_activation,
    quant_config=quant_config,
    activation_sparsity=activation_sparsity_pattern[
        extract_layer_index(prefix)
    ],
    prefix=f"{prefix}.mlp",
)

per_layer_input_gate instance-attribute

per_layer_input_gate = ReplicatedLinear(
    hidden_size,
    hidden_size_per_layer_input,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.per_layer_input_gate",
    return_bias=False,
)

per_layer_projection instance-attribute

per_layer_projection = ReplicatedLinear(
    hidden_size_per_layer_input,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.per_layer_projection",
    return_bias=False,
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

post_feedforward_layernorm instance-attribute

post_feedforward_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

post_per_layer_input_norm instance-attribute

post_per_layer_input_norm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

pre_feedforward_layernorm instance-attribute

pre_feedforward_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

self_attn instance-attribute

self_attn = Gemma3nAttention(
    config=config,
    hidden_size=hidden_size,
    num_heads=num_attention_heads,
    num_kv_heads=num_key_value_heads,
    head_dim=head_dim,
    max_position_embeddings=max_position_embeddings,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: Gemma3nTextConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    config: Gemma3nTextConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    assert isinstance(config, Gemma3nTextConfig)
    self.altup_active_idx = config.altup_active_idx
    assert config.altup_correct_scale

    self.altup = Gemma3nAltUp(
        hidden_size=config.hidden_size,
        rms_norm_eps=config.rms_norm_eps,
        altup_num_inputs=config.altup_num_inputs,
        altup_coef_clip=config.altup_coef_clip,
        altup_active_idx=config.altup_active_idx,
        quant_config=quant_config,
        prefix=f"{prefix}.altup",
    )
    self.self_attn = Gemma3nAttention(
        config=config,
        hidden_size=config.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        head_dim=config.head_dim,
        max_position_embeddings=config.max_position_embeddings,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
    )
    self.mlp = Gemma3nMLP(
        hidden_size=config.hidden_size,
        # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501
        intermediate_size=config.intermediate_size[extract_layer_index(
            prefix)],
        hidden_activation=config.hidden_activation,
        quant_config=quant_config,
        activation_sparsity=config.activation_sparsity_pattern[
            extract_layer_index(prefix)],
        prefix=f"{prefix}.mlp",
    )
    self.laurel = Gemma3nLaurelBlock(
        hidden_size=config.hidden_size,
        laurel_rank=config.laurel_rank,
        rms_norm_eps=config.rms_norm_eps,
        quant_config=quant_config,
        prefix=f"{prefix}.laurel",
    )

    # NOTE(rob): should be ColumnParallelLinear and RowParallelLinear
    # But, we need to add per_layer_input_gate(x) to per_layer_input.
    # per_layer_input cannot be sharded, so we replicate for now.
    self.per_layer_input_gate = ReplicatedLinear(
        config.hidden_size,
        config.hidden_size_per_layer_input,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.per_layer_input_gate",
        return_bias=False,
    )
    self.per_layer_projection = ReplicatedLinear(
        config.hidden_size_per_layer_input,
        config.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.per_layer_projection",
        return_bias=False,
    )

    # LayerNorms.
    self.input_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.post_attention_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.pre_feedforward_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.post_feedforward_layernorm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )
    self.post_per_layer_input_norm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )

    self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation]

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    per_layer_input: Tensor,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_input: torch.Tensor,
    **kwargs,
) -> torch.Tensor:

    # AltUp (predict).
    predictions = self.altup.predict(hidden_states)
    active_prediction = predictions[self.altup_active_idx]
    active_prediction_normed = self.input_layernorm(active_prediction)
    laurel_output = self.laurel(active_prediction_normed)

    # Attention.
    attn = self.self_attn(
        positions=positions,
        hidden_states=active_prediction_normed,
        **kwargs,
    )
    attn = self.post_attention_layernorm(attn)
    attn_gated = attn + active_prediction
    attn_laurel = (attn_gated + laurel_output) / torch.sqrt(
        torch.tensor(2.0))

    # MLP.
    attn_norm = self.pre_feedforward_layernorm(attn_laurel)
    attn_ffw = self.mlp(attn_norm)
    attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
    attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm

    # AltUp (correct).
    corrected_predictions = self.altup.correct(predictions,
                                               attn_ffw_laurel_gated)
    first_prediction = corrected_predictions[self.altup_active_idx]
    first_prediction = self.altup.scale_corrected_output(first_prediction)

    # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
    first_prediction = self.per_layer_input_gate(first_prediction)
    first_prediction = self.act_fn(first_prediction)
    first_prediction = torch.mul(first_prediction, per_layer_input)

    # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
    first_prediction = self.per_layer_projection(first_prediction)
    first_prediction = self.post_per_layer_input_norm(first_prediction)
    corrected_predictions[1:] += first_prediction

    return corrected_predictions

Gemma3nForCausalLM

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.cache_config = vllm_config.cache_config
        self.model = Gemma3nTextModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        *,
        per_layer_inputs: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        hidden_states = self.model(
            input_ids,
            positions,
            per_layer_inputs=per_layer_inputs,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self,
                                   skip_substrs=([
                                       "embed_audio.", "embed_vision.",
                                       "audio_tower.", "vision_tower."
                                   ]))
        return loader.load_weights(weights)
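
For reference, this class is normally constructed by vLLM's model loader rather than instantiated directly; a hedged end-to-end sketch using the high-level LLM API (the checkpoint name is only an example):

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-3n-E2B-it")   # example Gemma 3n checkpoint
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The AltUp technique works by"], params)
print(outputs[0].outputs[0].text)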

cache_config instance-attribute

cache_config = cache_config

config instance-attribute

config = config

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    vocab_size, soft_cap=final_logit_softcapping
)

model instance-attribute

model = Gemma3nTextModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    config = vllm_config.model_config.hf_config
    lora_config = vllm_config.lora_config
    del lora_config  # Unused.
    super().__init__()
    self.config = config
    self.cache_config = vllm_config.cache_config
    self.model = Gemma3nTextModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
    self.logits_processor = LogitsProcessor(
        config.vocab_size, soft_cap=config.final_logit_softcapping)

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/gemma3n.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.model.embed_tokens, hidden_states)
    return logits

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    *,
    per_layer_inputs: Optional[Tensor] = None,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    *,
    per_layer_inputs: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:

    hidden_states = self.model(
        input_ids,
        positions,
        per_layer_inputs=per_layer_inputs,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.model.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gemma3n.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(self,
                               skip_substrs=([
                                   "embed_audio.", "embed_vision.",
                                   "audio_tower.", "vision_tower."
                               ]))
    return loader.load_weights(weights)

Gemma3nLaurelBlock

Bases: Module

Learned Augmented Residual Layer
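
In effect the block adds a low-rank, RMS-normalized correction to the residual stream: x + RMSNorm(linear_right(linear_left(x))). A minimal sketch with plain PyTorch layers (illustrative sizes; torch.nn.RMSNorm requires torch >= 2.4):

import torch
import torch.nn as nn

hidden_size, laurel_rank = 8, 2
down = nn.Linear(hidden_size, laurel_rank, bias=False)   # stands in for linear_left
up = nn.Linear(laurel_rank, hidden_size, bias=False)     # stands in for linear_right
norm = nn.RMSNorm(hidden_size, eps=1e-6)                 # stands in for post_laurel_norm

x = torch.randn(3, hidden_size)
out = x + norm(up(down(x)))                              # learned augmented residual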

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nLaurelBlock(nn.Module):
    """Learned Augmented Residual Layer"""

    def __init__(
        self,
        hidden_size: int,
        laurel_rank: int,
        rms_norm_eps: float,
        *,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str,
    ) -> None:
        super().__init__()

        self.linear_left = ColumnParallelLinear(
            hidden_size,
            laurel_rank,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_left",
            return_bias=False,
        )
        self.linear_right = RowParallelLinear(
            laurel_rank,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_right",
            return_bias=False,
        )
        self.post_laurel_norm = RMSNorm(
            hidden_size=hidden_size,
            eps=rms_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        laurel_x = self.linear_left(x)
        laurel_x = self.linear_right(laurel_x)
        normed_laurel_x = self.post_laurel_norm(laurel_x)
        return x + normed_laurel_x

linear_left instance-attribute

linear_left = ColumnParallelLinear(
    hidden_size,
    laurel_rank,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.linear_left",
    return_bias=False,
)

linear_right instance-attribute

linear_right = RowParallelLinear(
    laurel_rank,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.linear_right",
    return_bias=False,
)

post_laurel_norm instance-attribute

post_laurel_norm = RMSNorm(
    hidden_size=hidden_size, eps=rms_norm_eps
)

__init__

__init__(
    hidden_size: int,
    laurel_rank: int,
    rms_norm_eps: float,
    *,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str,
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    hidden_size: int,
    laurel_rank: int,
    rms_norm_eps: float,
    *,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str,
) -> None:
    super().__init__()

    self.linear_left = ColumnParallelLinear(
        hidden_size,
        laurel_rank,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.linear_left",
        return_bias=False,
    )
    self.linear_right = RowParallelLinear(
        laurel_rank,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.linear_right",
        return_bias=False,
    )
    self.post_laurel_norm = RMSNorm(
        hidden_size=hidden_size,
        eps=rms_norm_eps,
    )

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    laurel_x = self.linear_left(x)
    laurel_x = self.linear_right(laurel_x)
    normed_laurel_x = self.post_laurel_norm(laurel_x)
    return x + normed_laurel_x

Gemma3nMLP

Bases: Module

Source code in vllm/model_executor/models/gemma3n.py
class Gemma3nMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        activation_sparsity: float = 0.0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")

        self.act_fn = GeluAndMulSparse(
            activation_sparsity=activation_sparsity,
            approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul(
                approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
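
The act_fn consumes the concatenated gate/up projection: the gate half goes through tanh-approximated GELU and scales the up half (GeluAndMulSparse additionally applies an activation-sparsity cutoff to the gate, which is omitted here). A plain-torch sketch of the dense case:

import torch
import torch.nn.functional as F

def gelu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    d = gate_up.shape[-1] // 2
    gate, up = gate_up[..., :d], gate_up[..., d:]
    return F.gelu(gate, approximate="tanh") * up

x = torch.randn(3, 2 * 16)      # [num_tokens, 2 * intermediate_size]
y = gelu_and_mul(x)             # [num_tokens, intermediate_size]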

act_fn instance-attribute

act_fn = (
    GeluAndMulSparse(
        activation_sparsity=activation_sparsity,
        approximate="tanh",
    )
    if activation_sparsity > 0.0
    else GeluAndMul(approximate="tanh")
)

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    hidden_activation: str,
    activation_sparsity: float = 0.0,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_activation: str,
    activation_sparsity: float = 0.0,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    if hidden_activation != "gelu_pytorch_tanh":
        raise ValueError(
            "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
            "function. Please set `hidden_act` and `hidden_activation` to "
            "`gelu_pytorch_tanh`.")

    self.act_fn = GeluAndMulSparse(
        activation_sparsity=activation_sparsity,
        approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul(
            approximate="tanh")

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x

Gemma3nSelfDecoder

Bases: Module

Includes altup embedding and self decoder layers

Source code in vllm/model_executor/models/gemma3n.py
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
                       kv_sharing_fast_prefill)
class Gemma3nSelfDecoder(nn.Module):
    """
    Includes altup embedding and self decoder layers
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma3nDecoderLayer],
        layer_idx_start: int,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

        config = vllm_config.model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.embed_scale = torch.tensor(
            config.hidden_size**0.5,
            dtype=self.embed_tokens.weight.dtype,
        )
        # Additional per-layer embeddings (PLE)
        self.embed_tokens_per_layer = VocabParallelEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            quant_config=quant_config,
            prefix=f"{prefix}.per_layer_embed_tokens",
        )
        self.embed_scale_per_layer = torch.tensor(
            config.hidden_size_per_layer_input**0.5,
            dtype=self.embed_tokens.weight.dtype,
        )
        self.per_layer_model_projection = ColumnParallelLinear(
            config.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            bias=False,
            gather_output=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.per_layer_model_projection",
        )
        self.per_layer_projection_norm = RMSNorm(
            hidden_size=config.hidden_size_per_layer_input,
            eps=config.rms_norm_eps,
        )
        self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(
            self.embed_tokens.weight.dtype)
        self.per_layer_projection_scale = torch.tensor(
            config.hidden_size**0.5,
            dtype=self.embed_tokens.weight.dtype,
        )
        self.altup_projections = nn.ModuleList([
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                gather_output=True,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.altup_projections.{idx-1}",
            ) for idx in range(1, self.config.altup_num_inputs)
        ])

    def get_per_layer_input_embeddings(
            self, input_ids: torch.Tensor) -> torch.Tensor:
        # Deal with the fact that vocab_size_per_layer_input < vocab_size
        # which causes us to have some out of vocab tokens by setting
        # those token ids to 0. This matches the HF implementation.
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
                                              torch.zeros_like(input_ids))
        return self.embed_tokens_per_layer(
            per_layer_inputs_tokens) * self.embed_scale_per_layer

    def get_per_layer_inputs(
        self,
        hidden_states_0: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor],
    ) -> torch.Tensor:
        per_layer_projection = self.per_layer_model_projection(hidden_states_0)
        per_layer_projection = per_layer_projection.reshape(
            *hidden_states_0.shape[:-1],
            self.config.num_hidden_layers,
            self.config.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(
            per_layer_projection)
        if per_layer_inputs is not None:
            # Profiling run does not compute per_layer_inputs
            per_layer_inputs = per_layer_projection + per_layer_inputs
            per_layer_inputs *= self.per_layer_input_scale
        else:
            per_layer_inputs = per_layer_projection
        return per_layer_inputs

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids) * self.embed_scale

    def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
        # Altup embed.
        hidden_states = [hidden_states_0] * self.config.altup_num_inputs
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
            new_magnitude = torch.mean(hidden_states[i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[i] *= target_magnitude / torch.maximum(
                new_magnitude, EPS)
        hidden_states = torch.stack(hidden_states, dim=-1)
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            hidden_states_0 = inputs_embeds
        else:
            hidden_states_0 = self.get_input_embeddings(input_ids)

        adjusted_per_layer_inputs = self.get_per_layer_inputs(
            hidden_states_0, per_layer_inputs)
        hidden_states = self.altup_embed(hidden_states_0)

        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = hidden_states.permute(2, 0, 1)

        for idx, layer in enumerate(self.decoder_layers):
            layer_idx = idx + self.layer_idx_start
            # [altup_num_inputs, num_tokens, hidden_size]
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
                **kwargs,
            )

        # [num_tokens, hidden_size, altup_num_inputs]
        hidden_states = hidden_states.permute(1, 2, 0)

        return hidden_states, adjusted_per_layer_inputs

altup_projections instance-attribute

altup_projections = ModuleList(
    [
        ColumnParallelLinear(
            hidden_size,
            hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.altup_projections.{idx - 1}",
        )
        for idx in range(1, altup_num_inputs)
    ]
)

config instance-attribute

config = config

decoder_layers instance-attribute

decoder_layers = decoder_layers

embed_scale instance-attribute

embed_scale = tensor(hidden_size ** 0.5, dtype=dtype)

embed_scale_per_layer instance-attribute

embed_scale_per_layer = tensor(
    hidden_size_per_layer_input**0.5, dtype=dtype
)

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.embed_tokens",
)

embed_tokens_per_layer instance-attribute

embed_tokens_per_layer = VocabParallelEmbedding(
    vocab_size_per_layer_input,
    num_hidden_layers * hidden_size_per_layer_input,
    quant_config=quant_config,
    prefix=f"{prefix}.per_layer_embed_tokens",
)

layer_idx_start instance-attribute

layer_idx_start = layer_idx_start

per_layer_input_scale instance-attribute

per_layer_input_scale = rsqrt(tensor(2.0)).to(dtype)

per_layer_model_projection instance-attribute

per_layer_model_projection = ColumnParallelLinear(
    hidden_size,
    num_hidden_layers * hidden_size_per_layer_input,
    bias=False,
    gather_output=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.per_layer_model_projection",
)

per_layer_projection_norm instance-attribute

per_layer_projection_norm = RMSNorm(
    hidden_size=hidden_size_per_layer_input,
    eps=rms_norm_eps,
)

per_layer_projection_scale instance-attribute

per_layer_projection_scale = tensor(
    hidden_size**0.5, dtype=dtype
)

__init__

__init__(
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    decoder_layers: list[Gemma3nDecoderLayer],
    layer_idx_start: int,
)
Source code in vllm/model_executor/models/gemma3n.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    decoder_layers: list[Gemma3nDecoderLayer],
    layer_idx_start: int,
):
    super().__init__()
    self.decoder_layers = decoder_layers
    self.layer_idx_start = layer_idx_start

    config = vllm_config.model_config.hf_config
    self.config = config
    quant_config = vllm_config.quant_config

    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.embed_tokens",
    )
    self.embed_scale = torch.tensor(
        config.hidden_size**0.5,
        dtype=self.embed_tokens.weight.dtype,
    )
    # Additional per-layer embeddings (PLE)
    self.embed_tokens_per_layer = VocabParallelEmbedding(
        config.vocab_size_per_layer_input,
        config.num_hidden_layers * config.hidden_size_per_layer_input,
        quant_config=quant_config,
        prefix=f"{prefix}.per_layer_embed_tokens",
    )
    self.embed_scale_per_layer = torch.tensor(
        config.hidden_size_per_layer_input**0.5,
        dtype=self.embed_tokens.weight.dtype,
    )
    self.per_layer_model_projection = ColumnParallelLinear(
        config.hidden_size,
        config.num_hidden_layers * config.hidden_size_per_layer_input,
        bias=False,
        gather_output=True,
        return_bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.per_layer_model_projection",
    )
    self.per_layer_projection_norm = RMSNorm(
        hidden_size=config.hidden_size_per_layer_input,
        eps=config.rms_norm_eps,
    )
    self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(
        self.embed_tokens.weight.dtype)
    self.per_layer_projection_scale = torch.tensor(
        config.hidden_size**0.5,
        dtype=self.embed_tokens.weight.dtype,
    )
    self.altup_projections = nn.ModuleList([
        ColumnParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.altup_projections.{idx-1}",
        ) for idx in range(1, self.config.altup_num_inputs)
    ])

altup_embed

altup_embed(hidden_states_0: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
    # Altup embed.
    hidden_states = [hidden_states_0] * self.config.altup_num_inputs
    target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
                                  keepdim=True)**0.5
    for i in range(1, self.config.altup_num_inputs):
        hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
        new_magnitude = torch.mean(hidden_states[i]**2,
                                   dim=-1,
                                   keepdim=True)**0.5
        hidden_states[i] *= target_magnitude / torch.maximum(
            new_magnitude, EPS)
    hidden_states = torch.stack(hidden_states, dim=-1)
    return hidden_states
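
The loop above rescales every projected stream so that its root-mean-square magnitude matches that of the original embedding. Below is a minimal, standalone sketch of that renormalization step, using an ordinary nn.Linear as a stand-in for one learned altup projection and arbitrary sizes; the small eps floor here merely plays the role of the module-level EPS constant.

import torch

x0 = torch.randn(4, 8)                    # [num_tokens, hidden_size]
proj = torch.nn.Linear(8, 8, bias=False)  # hypothetical stand-in for altup_projections[i - 1]
eps = torch.tensor(1e-6)                  # small floor to avoid division by zero

target_magnitude = torch.mean(x0**2, dim=-1, keepdim=True)**0.5
xi = proj(x0)
new_magnitude = torch.mean(xi**2, dim=-1, keepdim=True)**0.5
# Rescale the projected stream so its RMS magnitude matches the original input.
xi = xi * target_magnitude / torch.maximum(new_magnitude, eps)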

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    inputs_embeds: Optional[Tensor] = None,
    per_layer_inputs: Optional[Tensor] = None,
    **kwargs,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    if inputs_embeds is not None:
        hidden_states_0 = inputs_embeds
    else:
        hidden_states_0 = self.get_input_embeddings(input_ids)

    adjusted_per_layer_inputs = self.get_per_layer_inputs(
        hidden_states_0, per_layer_inputs)
    hidden_states = self.altup_embed(hidden_states_0)

    # [altup_num_inputs, num_tokens, hidden_size]
    hidden_states = hidden_states.permute(2, 0, 1)

    for idx, layer in enumerate(self.decoder_layers):
        layer_idx = idx + self.layer_idx_start
        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = layer(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
            **kwargs,
        )

    # [num_tokens, hidden_size, altup_num_inputs]
    hidden_states = hidden_states.permute(1, 2, 0)

    return hidden_states, adjusted_per_layer_inputs
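
For reference, here is a small shape check of the two permutes that bracket the decoder layers in the forward above; the dimension names follow the inline comments and the sizes are arbitrary.

import torch

num_tokens, hidden_size, altup_num_inputs = 5, 16, 4

stacked = torch.randn(num_tokens, hidden_size, altup_num_inputs)  # output of altup_embed
per_layer_view = stacked.permute(2, 0, 1)   # [altup_num_inputs, num_tokens, hidden_size]
restored = per_layer_view.permute(1, 2, 0)  # back to [num_tokens, hidden_size, altup_num_inputs]

assert per_layer_view.shape == (altup_num_inputs, num_tokens, hidden_size)
assert torch.equal(restored, stacked)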

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids) * self.embed_scale

get_per_layer_input_embeddings

get_per_layer_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_per_layer_input_embeddings(
        self, input_ids: torch.Tensor) -> torch.Tensor:
    # vocab_size_per_layer_input < vocab_size, so some token ids fall
    # outside the per-layer vocabulary; map those ids to 0 before the
    # lookup. This matches the HF implementation.
    per_layer_inputs_mask = torch.logical_and(
        input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
    per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
                                          torch.zeros_like(input_ids))
    return self.embed_tokens_per_layer(
        per_layer_inputs_tokens) * self.embed_scale_per_layer
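
A standalone illustration of the masking described in the comment, with an arbitrary per-layer vocabulary size: ids outside the per-layer vocabulary (or negative ids) are mapped to token id 0 before the embedding lookup.

import torch

vocab_size_per_layer_input = 10  # arbitrary example value
input_ids = torch.tensor([3, 12, 7, -1])

mask = torch.logical_and(input_ids >= 0, input_ids < vocab_size_per_layer_input)
safe_ids = torch.where(mask, input_ids, torch.zeros_like(input_ids))
# safe_ids == tensor([3, 0, 7, 0]); 12 and -1 both fall back to token id 0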

get_per_layer_inputs

get_per_layer_inputs(
    hidden_states_0: Tensor,
    per_layer_inputs: Optional[Tensor],
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_per_layer_inputs(
    self,
    hidden_states_0: torch.Tensor,
    per_layer_inputs: Optional[torch.Tensor],
) -> torch.Tensor:
    per_layer_projection = self.per_layer_model_projection(hidden_states_0)
    per_layer_projection = per_layer_projection.reshape(
        *hidden_states_0.shape[:-1],
        self.config.num_hidden_layers,
        self.config.hidden_size_per_layer_input,
    )
    per_layer_projection = self.per_layer_projection_norm(
        per_layer_projection)
    if per_layer_inputs is not None:
        # Profiling run does not compute per_layer_inputs
        per_layer_inputs = per_layer_projection + per_layer_inputs
        per_layer_inputs *= self.per_layer_input_scale
    else:
        per_layer_inputs = per_layer_projection
    return per_layer_inputs
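
A shape-only sketch of what this method computes, with small arbitrary sizes and a plain weight matrix in place of the parallel projection layer; the RMSNorm is omitted for brevity.

import torch

num_tokens, hidden_size = 4, 16
num_hidden_layers, hidden_size_per_layer_input = 3, 8  # arbitrary example sizes

hidden_states_0 = torch.randn(num_tokens, hidden_size)
projection_weight = torch.randn(num_hidden_layers * hidden_size_per_layer_input, hidden_size)

flat = hidden_states_0 @ projection_weight.T
per_layer_projection = flat.reshape(num_tokens, num_hidden_layers,
                                    hidden_size_per_layer_input)

# When per-layer embeddings are supplied, they are added to the projection and
# the sum is scaled by 1/sqrt(2), as in the branch above.
per_layer_inputs = torch.randn(num_tokens, num_hidden_layers,
                               hidden_size_per_layer_input)
combined = (per_layer_projection + per_layer_inputs) * torch.rsqrt(torch.tensor(2.0))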

Gemma3nTextModel

Bases: Module, SupportsQuant

Source code in vllm/model_executor/models/gemma3n.py
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
                       cache_config.kv_sharing_fast_prefill)
class Gemma3nTextModel(nn.Module, SupportsQuant):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.altup_unembed_projections = nn.ModuleList([
            ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                gather_output=True,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
            ) for idx in range(1, self.config.altup_num_inputs)
        ])

        # Allocate config.num_kv_shared_layers layers for self-decoder
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma3nDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

        first_kv_shared_layer_idx = (config.num_hidden_layers -
                                     config.num_kv_shared_layers)

        # NOTE(sarckk): importing this at top level seems to cause issues
        # when running tests.
        from vllm.compilation.backends import set_model_tag

        # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
        with set_model_tag("self_decoder"):
            self.self_decoder = Gemma3nSelfDecoder(
                vllm_config=vllm_config,
                prefix=f"{prefix}.self_decoder",
                decoder_layers=self.layers[:first_kv_shared_layer_idx],
                layer_idx_start=0,
            )
        # Layer idx 20-30 are cross-decoder layers in YOCO
        with set_model_tag("cross_decoder"):
            self.cross_decoder = Gemma3nCrossDecoder(
                vllm_config=vllm_config,
                prefix=f"{prefix}.cross_decoder",
                decoder_layers=self.layers[first_kv_shared_layer_idx:],
                layer_idx_start=first_kv_shared_layer_idx,
            )

        self.norm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )

        self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill

        if self.fast_prefill_enabled:
            # Allocate static buffers for CUDAGraph
            # TODO(sarckk): Extract this functionality to interface
            max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
            device = next(self.parameters()).device
            self.positions = torch.zeros(max_num_tokens,
                                         dtype=torch.int64,
                                         device=device)
            self.hidden_states = torch.zeros(
                (max_num_tokens, config.hidden_size,
                 self.config.altup_num_inputs),
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )
            self.per_layer_inputs = torch.zeros(
                (max_num_tokens, self.config.num_hidden_layers,
                 self.config.hidden_size_per_layer_input),
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )

    @property
    def embed_tokens(self):
        return self.self_decoder.embed_tokens

    def get_per_layer_input_embeddings(
            self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.self_decoder.get_per_layer_input_embeddings(input_ids)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.self_decoder.get_input_embeddings(input_ids)

    def fast_prefill_forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        logits_indices_padded, num_logits_indices = None, None
        attn_metadata = get_forward_context().attn_metadata

        # attn_metadata is None during dummy runs
        if (self.fast_prefill_enabled and attn_metadata is not None):
            assert isinstance(attn_metadata, dict)
            # Last layer is a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.layers[-1].self_attn.attn.layer_name]
            if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
                logits_indices_padded = (
                    layer_attn_metadata.logits_indices_padded)
                num_logits_indices = layer_attn_metadata.num_logits_indices

        # Copy inputs for cudagraph
        batch_size = positions.size(0)
        self.positions[:batch_size].copy_(positions)
        self_decoder_hidden_states, per_layer_inputs_adjusted = \
            self.self_decoder(
                input_ids=input_ids,
                positions=self.positions[:batch_size],
                inputs_embeds=inputs_embeds,
                per_layer_inputs=per_layer_inputs,
                **kwargs,
            )

        if logits_indices_padded is None:
            logits_indices_padded = torch.arange(
                positions.size(0),
                dtype=positions.dtype,
                device=positions.device,
            )

        # NOTE(sarckk): There is currently a bug caused by
        # vLLM converting output of last piecewise CUDA graph
        # to weakref, causing memory to be prematurely freed
        # when there are multiple compilation units
        # Keep .clone() until fix in
        # https://github.com/vllm-project/vllm/pull/22282
        hidden_states = self_decoder_hidden_states.clone()

        # Copy inputs for cudagraph
        num_padded_logits_indices = logits_indices_padded.size(0)
        self.positions[:num_padded_logits_indices].copy_(
            positions[logits_indices_padded])
        self.hidden_states[:num_padded_logits_indices].copy_(
            self_decoder_hidden_states[logits_indices_padded])
        self.per_layer_inputs[:num_padded_logits_indices].copy_(
            per_layer_inputs_adjusted[logits_indices_padded])
        cross_decoder_hidden_states = self.cross_decoder(
            positions=self.positions[:num_padded_logits_indices],
            hidden_states=self.hidden_states[:num_padded_logits_indices],
            per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
            **kwargs,
        )

        if num_logits_indices is not None:
            assert num_logits_indices > 0
            # Merge cross-decoder and self-decoder hidden states
            hidden_states[logits_indices_padded[:num_logits_indices]] = (
                cross_decoder_hidden_states[:num_logits_indices])
        else:
            hidden_states = cross_decoder_hidden_states

        return hidden_states

    def normal_forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states, per_layer_inputs = self.self_decoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )
        hidden_states = self.cross_decoder(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )
        return hidden_states

    def altup_unembed(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Altup unembed.
        target_magnitude = torch.mean(hidden_states[..., 0]**2,
                                      dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[..., i] = self.altup_unembed_projections[i - 1](
                hidden_states[..., i])
            new_magnitude = torch.mean(hidden_states[..., i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[..., i] *= target_magnitude / torch.maximum(
                new_magnitude, EPS)
        # [num_tokens, hidden_size, altup_num_inputs] -> [num_tokens, hidden_size]
        hidden_states = torch.mean(hidden_states, dim=-1)
        return hidden_states

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if self.fast_prefill_enabled:
            hidden_states = self.fast_prefill_forward(
                input_ids,
                positions,
                inputs_embeds,
                per_layer_inputs,
                **kwargs,
            )
        else:
            hidden_states = self.normal_forward(
                input_ids,
                positions,
                inputs_embeds,
                per_layer_inputs,
                **kwargs,
            )
        hidden_states = self.altup_unembed(hidden_states)
        return self.norm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Decoder layer weights, altup_unembed_projections and the final
            # norm are initialized in the text model; everything else lives
            # in the self-decoder.
            if (not name.startswith('layers')
                    and not name.startswith('altup_unembed_projections')
                    and not name.startswith('norm')):
                name = f"self_decoder.{name}"

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # Avoid spurious match with ".up_proj".
                if "altup_projections" in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params

altup_unembed_projections instance-attribute

altup_unembed_projections = ModuleList(
    [
        ColumnParallelLinear(
            hidden_size,
            hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.altup_unembed_projections.{idx - 1}",
        )
        for idx in range(1, altup_num_inputs)
    ]
)

config instance-attribute

config = config

cross_decoder instance-attribute

cross_decoder = Gemma3nCrossDecoder(
    vllm_config=vllm_config,
    prefix=f"{prefix}.cross_decoder",
    decoder_layers=layers[first_kv_shared_layer_idx:],
    layer_idx_start=first_kv_shared_layer_idx,
)

embed_tokens property

embed_tokens

fast_prefill_enabled instance-attribute

fast_prefill_enabled = kv_sharing_fast_prefill

hidden_states instance-attribute

hidden_states = zeros(
    (max_num_tokens, hidden_size, altup_num_inputs),
    dtype=dtype,
    device=device,
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

per_layer_inputs instance-attribute

per_layer_inputs = zeros(
    (
        max_num_tokens,
        num_hidden_layers,
        hidden_size_per_layer_input,
    ),
    dtype=dtype,
    device=device,
)

positions instance-attribute

positions = zeros(
    max_num_tokens, dtype=int64, device=device
)

quant_config instance-attribute

quant_config = quant_config

self_decoder instance-attribute

self_decoder = Gemma3nSelfDecoder(
    vllm_config=vllm_config,
    prefix=f"{prefix}.self_decoder",
    decoder_layers=layers[:first_kv_shared_layer_idx],
    layer_idx_start=0,
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gemma3n.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config

    self.altup_unembed_projections = nn.ModuleList([
        ColumnParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            gather_output=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
        ) for idx in range(1, self.config.altup_num_inputs)
    ])

    # Allocate config.num_kv_shared_layers layers for self-decoder
    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: Gemma3nDecoderLayer(
            config, cache_config, quant_config, prefix=prefix),
        prefix=f"{prefix}.layers")

    first_kv_shared_layer_idx = (config.num_hidden_layers -
                                 config.num_kv_shared_layers)

    # NOTE(sarckk): importing this at top level seems to cause issues
    # when running tests.
    from vllm.compilation.backends import set_model_tag

    # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
    with set_model_tag("self_decoder"):
        self.self_decoder = Gemma3nSelfDecoder(
            vllm_config=vllm_config,
            prefix=f"{prefix}.self_decoder",
            decoder_layers=self.layers[:first_kv_shared_layer_idx],
            layer_idx_start=0,
        )
    # Layer idx 20-30 are cross-decoder layers in YOCO
    with set_model_tag("cross_decoder"):
        self.cross_decoder = Gemma3nCrossDecoder(
            vllm_config=vllm_config,
            prefix=f"{prefix}.cross_decoder",
            decoder_layers=self.layers[first_kv_shared_layer_idx:],
            layer_idx_start=first_kv_shared_layer_idx,
        )

    self.norm = RMSNorm(
        config.hidden_size,
        eps=config.rms_norm_eps,
    )

    self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill

    if self.fast_prefill_enabled:
        # Allocate static buffers for CUDAGraph
        # TODO(sarckk): Extract this functionality to interface
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        device = next(self.parameters()).device
        self.positions = torch.zeros(max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (max_num_tokens, config.hidden_size,
             self.config.altup_num_inputs),
            dtype=self.embed_tokens.weight.dtype,
            device=device,
        )
        self.per_layer_inputs = torch.zeros(
            (max_num_tokens, self.config.num_hidden_layers,
             self.config.hidden_size_per_layer_input),
            dtype=self.embed_tokens.weight.dtype,
            device=device,
        )
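
The self/cross decoder split is driven entirely by two config values. A quick arithmetic sketch with made-up numbers (not taken from any real Gemma3n config): the first num_hidden_layers - num_kv_shared_layers layers go to the self-decoder and the remaining KV-sharing layers go to the cross-decoder.

num_hidden_layers = 30      # hypothetical value
num_kv_shared_layers = 10   # hypothetical value

first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # 20
layers = list(range(num_hidden_layers))
self_decoder_layers = layers[:first_kv_shared_layer_idx]   # layer indices 0-19
cross_decoder_layers = layers[first_kv_shared_layer_idx:]  # layer indices 20-29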

altup_unembed

altup_unembed(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def altup_unembed(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    # Altup unembed.
    target_magnitude = torch.mean(hidden_states[..., 0]**2,
                                  dim=-1,
                                  keepdim=True)**0.5
    for i in range(1, self.config.altup_num_inputs):
        hidden_states[..., i] = self.altup_unembed_projections[i - 1](
            hidden_states[..., i])
        new_magnitude = torch.mean(hidden_states[..., i]**2,
                                   dim=-1,
                                   keepdim=True)**0.5
        hidden_states[..., i] *= target_magnitude / torch.maximum(
            new_magnitude, EPS)
    # [num_tokens, hidden_size, altup_num_inputs] -> [num_tokens, hidden_size]
    hidden_states = torch.mean(hidden_states, dim=-1)
    return hidden_states
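
Complementing the altup_embed sketch earlier, the unembed step renormalizes streams 1..N-1 against stream 0 and then collapses the altup dimension by averaging. A toy example of that final reduction with arbitrary sizes:

import torch

num_tokens, hidden_size, altup_num_inputs = 4, 8, 3
hidden_states = torch.randn(num_tokens, hidden_size, altup_num_inputs)

# After the per-stream renormalization, the streams are averaged back into a
# single [num_tokens, hidden_size] tensor.
collapsed = torch.mean(hidden_states, dim=-1)
assert collapsed.shape == (num_tokens, hidden_size)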

fast_prefill_forward

fast_prefill_forward(
    input_ids: Tensor,
    positions: Tensor,
    inputs_embeds: Optional[Tensor] = None,
    per_layer_inputs: Optional[Tensor] = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def fast_prefill_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    logits_indices_padded, num_logits_indices = None, None
    attn_metadata = get_forward_context().attn_metadata

    # attn_metadata is None during dummy runs
    if (self.fast_prefill_enabled and attn_metadata is not None):
        assert isinstance(attn_metadata, dict)
        # Last layer is a KV sharing layer
        layer_attn_metadata = attn_metadata[
            self.layers[-1].self_attn.attn.layer_name]
        if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
            logits_indices_padded = (
                layer_attn_metadata.logits_indices_padded)
            num_logits_indices = layer_attn_metadata.num_logits_indices

    # Copy inputs for cudagraph
    batch_size = positions.size(0)
    self.positions[:batch_size].copy_(positions)
    self_decoder_hidden_states, per_layer_inputs_adjusted = \
        self.self_decoder(
            input_ids=input_ids,
            positions=self.positions[:batch_size],
            inputs_embeds=inputs_embeds,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )

    if logits_indices_padded is None:
        logits_indices_padded = torch.arange(
            positions.size(0),
            dtype=positions.dtype,
            device=positions.device,
        )

    # NOTE(sarckk): There is currently a bug caused by
    # vLLM converting output of last piecewise CUDA graph
    # to weakref, causing memory to be prematurely freed
    # when there are multiple compilation units
    # Keep .clone() until fix in
    # https://github.com/vllm-project/vllm/pull/22282
    hidden_states = self_decoder_hidden_states.clone()

    # Copy inputs for cudagraph
    num_padded_logits_indices = logits_indices_padded.size(0)
    self.positions[:num_padded_logits_indices].copy_(
        positions[logits_indices_padded])
    self.hidden_states[:num_padded_logits_indices].copy_(
        self_decoder_hidden_states[logits_indices_padded])
    self.per_layer_inputs[:num_padded_logits_indices].copy_(
        per_layer_inputs_adjusted[logits_indices_padded])
    cross_decoder_hidden_states = self.cross_decoder(
        positions=self.positions[:num_padded_logits_indices],
        hidden_states=self.hidden_states[:num_padded_logits_indices],
        per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
        **kwargs,
    )

    if num_logits_indices is not None:
        assert num_logits_indices > 0
        # Merge cross-decoder and self-decoder hidden states
        hidden_states[logits_indices_padded[:num_logits_indices]] = (
            cross_decoder_hidden_states[:num_logits_indices])
    else:
        hidden_states = cross_decoder_hidden_states

    return hidden_states
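
The core of the fast-prefill path is running the cross-decoder only on the token positions whose logits are actually needed and then scattering those rows back into the full hidden-state tensor. A plain-tensor sketch of that gather/scatter pattern, independent of vLLM's attention metadata (indices and sizes are hypothetical):

import torch

num_tokens, hidden_size = 6, 4
self_decoder_out = torch.randn(num_tokens, hidden_size)

# Hypothetical indices of the positions that need logits.
logits_indices = torch.tensor([2, 5])

# Gather only those rows and run the (here: dummy) cross-decoder on them...
cross_decoder_out = self_decoder_out[logits_indices] + 1.0

# ...then scatter the results back into a copy of the full tensor.
merged = self_decoder_out.clone()
merged[logits_indices] = cross_decoder_out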

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    per_layer_inputs: Optional[Tensor] = None,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/gemma3n.py
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    per_layer_inputs: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
    if self.fast_prefill_enabled:
        hidden_states = self.fast_prefill_forward(
            input_ids,
            positions,
            inputs_embeds,
            per_layer_inputs,
            **kwargs,
        )
    else:
        hidden_states = self.normal_forward(
            input_ids,
            positions,
            inputs_embeds,
            per_layer_inputs,
            **kwargs,
        )
    hidden_states = self.altup_unembed(hidden_states)
    return self.norm(hidden_states)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.self_decoder.get_input_embeddings(input_ids)

get_per_layer_input_embeddings

get_per_layer_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def get_per_layer_input_embeddings(
        self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.self_decoder.get_per_layer_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gemma3n.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Decoder layer weights, altup_unembed_projections and the final
        # norm are initialized in the text model; everything else lives
        # in the self-decoder.
        if (not name.startswith('layers')
                and not name.startswith('altup_unembed_projections')
                and not name.startswith('norm')):
            name = f"self_decoder.{name}"

        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache scales for compressed-tensors quantization
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            # Avoid spurious match with ".up_proj".
            if "altup_projections" in name:
                continue
            name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)

    return loaded_params
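
A simplified, standalone illustration of the checkpoint-name handling above: names that do not belong to the decoder layers, the altup unembed projections, or the final norm are re-rooted under self_decoder, and q/k/v and gate/up projections are folded into their stacked parameters. Sharding, quantization scales and bias skipping are omitted here.

def remap_name(name: str) -> str:
    # Re-root everything that lives inside the self-decoder submodule.
    if (not name.startswith("layers")
            and not name.startswith("altup_unembed_projections")
            and not name.startswith("norm")):
        name = f"self_decoder.{name}"
    # Fold per-shard names into their stacked parameters.
    for param_name, shard_name in [("qkv_proj", "q_proj"), ("qkv_proj", "k_proj"),
                                   ("qkv_proj", "v_proj"), ("gate_up_proj", "gate_proj"),
                                   ("gate_up_proj", "up_proj")]:
        if shard_name in name and "altup_projections" not in name:
            return name.replace(shard_name, param_name)
    return name

# e.g. "embed_tokens.weight"              -> "self_decoder.embed_tokens.weight"
#      "layers.0.self_attn.q_proj.weight" -> "layers.0.self_attn.qkv_proj.weight"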

normal_forward

normal_forward(
    input_ids: Tensor,
    positions: Tensor,
    inputs_embeds: Optional[Tensor] = None,
    per_layer_inputs: Optional[Tensor] = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/gemma3n.py
def normal_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    hidden_states, per_layer_inputs = self.self_decoder(
        input_ids=input_ids,
        positions=positions,
        inputs_embeds=inputs_embeds,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )
    hidden_states = self.cross_decoder(
        positions=positions,
        hidden_states=hidden_states,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )
    return hidden_states
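
Taken together, the normal path is a straight two-stage pipeline: the self-decoder produces both the altup hidden states and the adjusted per-layer inputs, and the cross-decoder, which reuses the self-decoder's KV cache through the shared-KV layers, consumes both. A minimal sketch of that calling contract with hypothetical stand-in callables:

def run_pipeline(self_decoder, cross_decoder, input_ids, positions, **kwargs):
    # self_decoder returns (hidden_states, adjusted_per_layer_inputs).
    hidden_states, per_layer_inputs = self_decoder(
        input_ids=input_ids, positions=positions, **kwargs)
    # cross_decoder only needs the hidden states plus the per-layer inputs.
    return cross_decoder(
        positions=positions,
        hidden_states=hidden_states,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )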