
vllm.transformers_utils.configs.qwen3_next

Qwen3-Next model configuration

__all__ module-attribute

__all__ = ['Qwen3NextConfig']

logger module-attribute

logger = get_logger(__name__)

Qwen3NextConfig

Bases: PretrainedConfig

This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a Qwen3-Next model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of Qwen3-Next-80B-A3B-Instruct ([Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct)).

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information.
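Before the full parameter reference, here is a minimal, illustrative sketch (an editorial addition, not part of the generated reference; it assumes `vllm` is installed) showing how the class can be instantiated with overridden arguments and serialized through the `PretrainedConfig` helpers it inherits.

```python
# Minimal sketch: construct the config with a few overridden arguments and
# serialize it. to_json_string() is inherited from PretrainedConfig.
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig

config = Qwen3NextConfig(
    num_hidden_layers=8,  # smaller than the 48-layer default, for illustration
    num_experts=64,       # fewer routed experts than the 512 default
)

assert config.model_type == "qwen3_next"
print(config.layer_types[:4])          # default hybrid attention pattern set in __init__
print(config.to_json_string()[:200])   # JSON serialization of the configuration
```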

Parameters:

- `vocab_size` (`int`, *optional*, defaults to `151936`): Vocabulary size of the model. Defines the number of different tokens that can be represented by the `input_ids`.
- `hidden_size` (`int`, *optional*, defaults to `2048`): Dimension of the hidden representations.
- `intermediate_size` (`int`, *optional*, defaults to `5632`): Dimension of the MLP representations.
- `num_hidden_layers` (`int`, *optional*, defaults to `48`): Number of hidden layers in the Transformer decoder.
- `num_attention_heads` (`int`, *optional*, defaults to `16`): Number of attention heads for each attention layer in the Transformer decoder.
- `num_key_value_heads` (`int`, *optional*, defaults to `2`): Number of key/value heads used to implement Grouped Query Attention (GQA). If `num_key_value_heads=num_attention_heads`, the model uses Multi-Head Attention (MHA); if `num_key_value_heads=1`, it uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, see [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to `2`.
- `hidden_act` (`str`, *optional*, defaults to `"silu"`): The non-linear activation function in the decoder.
- `max_position_embeddings` (`int`, *optional*, defaults to `32768`): The maximum sequence length that this model might ever be used with.
- `initializer_range` (`float`, *optional*, defaults to `0.02`): The standard deviation of the truncated normal initializer for initializing all weight matrices.
- `rms_norm_eps` (`float`, *optional*, defaults to `1e-06`): The epsilon used by the RMS normalization layers.
- `use_cache` (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/value attentions (not used by all models). Only relevant if `config.is_decoder=True`.
- `tie_word_embeddings` (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied.
- `rope_theta` (`float`, *optional*, defaults to `10000.0`): The base period of the RoPE embeddings.
- `rope_scaling` (`dict`, *optional*, defaults to `None`): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value accordingly. Expected contents:
    - `rope_type` (`str`): The sub-variant of RoPE to use. Can be one of `['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3']`, with `'default'` being the original RoPE implementation.
    - `factor` (`float`, *optional*): Used with all rope types except `'default'`. The scaling factor to apply to the RoPE embeddings. In most scaling types, a `factor` of x enables the model to handle sequences of length x * the original maximum pre-trained length.
    - `original_max_position_embeddings` (`int`, *optional*): Used with `'dynamic'`, `'longrope'` and `'llama3'`. The original max position embeddings used during pretraining.
    - `attention_factor` (`float`, *optional*): Used with `'yarn'` and `'longrope'`. The scaling factor applied to the attention computation. If unspecified, it defaults to the value recommended by the implementation, using the `factor` field to infer the suggested value.
    - `beta_fast` (`float`, *optional*): Only used with `'yarn'`. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.
    - `beta_slow` (`float`, *optional*): Only used with `'yarn'`. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.
    - `short_factor` (`list[float]`, *optional*): Only used with `'longrope'`. The scaling factor applied to short contexts (shorter than `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
    - `long_factor` (`list[float]`, *optional*): Only used with `'longrope'`. The scaling factor applied to long contexts (longer than `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
    - `low_freq_factor` (`float`, *optional*): Only used with `'llama3'`. Scaling factor applied to low-frequency components of the RoPE.
    - `high_freq_factor` (`float`, *optional*): Only used with `'llama3'`. Scaling factor applied to high-frequency components of the RoPE.
- `partial_rotary_factor` (`float`, *optional*, defaults to `0.25`): Fraction of the query and key dimensions that rotary embedding is applied to.
- `attention_bias` (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention.
- `attention_dropout` (`float`, *optional*, defaults to `0.0`): The dropout ratio for the attention probabilities.
- `head_dim` (`int`, *optional*, defaults to `256`): Dimension of each attention head's projection in multi-head attention.
- `linear_conv_kernel_dim` (`int`, *optional*, defaults to `4`): Kernel size of the convolution used in linear attention layers.
- `linear_key_head_dim` (`int`, *optional*, defaults to `128`): Dimension of each key head in linear attention.
- `linear_value_head_dim` (`int`, *optional*, defaults to `128`): Dimension of each value head in linear attention.
- `linear_num_key_heads` (`int`, *optional*, defaults to `16`): Number of key heads used in linear attention layers.
- `linear_num_value_heads` (`int`, *optional*, defaults to `32`): Number of value heads used in linear attention layers.
- `decoder_sparse_step` (`int`, *optional*, defaults to `1`): How frequently MoE layers occur among the decoder layers; with the default of `1`, every layer not listed in `mlp_only_layers` uses MoE.
- `moe_intermediate_size` (`int`, *optional*, defaults to `512`): Intermediate size of each routed expert.
- `shared_expert_intermediate_size` (`int`, *optional*, defaults to `512`): Intermediate size of the shared expert.
- `num_experts_per_tok` (`int`, *optional*, defaults to `10`): Number of experts selected per token.
- `num_experts` (`int`, *optional*, defaults to `512`): Number of routed experts.
- `norm_topk_prob` (`bool`, *optional*, defaults to `True`): Whether to normalize the top-k routing probabilities.
- `output_router_logits` (`bool`, *optional*, defaults to `False`): Whether or not the router logits should be returned by the model. Enabling this will also allow the model to output the auxiliary loss, including the load-balancing loss and the router z-loss.
- `router_aux_loss_coef` (`float`, *optional*, defaults to `0.001`): The auxiliary loss factor for the total loss.
- `mlp_only_layers` (`list[int]`, *optional*, defaults to `[]`): Indices of layers that use `Qwen3NextMLP` rather than `Qwen3NextSparseMoeBlock`. The list contains layer indices from 0 to `num_hidden_layers - 1`. Passing `None` is treated as `[]`. If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
- `layer_types` (`list[str]`, *optional*, defaults to `None`): The type of each layer (`"full_attention"` or `"linear_attention"`). If not specified, a default pattern of three `"linear_attention"` layers followed by one `"full_attention"` layer is used, as shown in the sketch below.
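To make the attention-related defaults above concrete, the following illustrative arithmetic (an editorial sketch, not from the source) shows how the default head counts and `partial_rotary_factor` are conventionally interpreted.

```python
# Illustrative arithmetic using the default values documented above.
num_attention_heads = 16
num_key_value_heads = 2
head_dim = 256
partial_rotary_factor = 0.25

# Grouped Query Attention: each key/value head is shared by a group of query heads.
query_heads_per_kv_head = num_attention_heads // num_key_value_heads  # 8

# Only a fraction of each head's dimensions receives rotary position embeddings.
rotary_dim = int(head_dim * partial_rotary_factor)  # 64

print(query_heads_per_kv_head, rotary_dim)
```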
```python
>>> from transformers import Qwen3NextModel, Qwen3NextConfig

>>> # Initializing a Qwen3-Next style configuration
>>> configuration = Qwen3NextConfig()

>>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
>>> model = Qwen3NextModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```
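When `layer_types` is not supplied, `__init__` (shown in the source below) derives it from `num_hidden_layers`. This standalone sketch reproduces that logic to show the resulting hybrid attention pattern.

```python
# Reproduces the default layer_types construction from Qwen3NextConfig.__init__:
# every fourth layer uses full attention, the rest use linear attention.
num_hidden_layers = 48

layer_types = [
    "linear_attention" if bool((i + 1) % 4) else "full_attention"
    for i in range(num_hidden_layers)
]

print(layer_types[:8])
# ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
#  'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']
print(layer_types.count("full_attention"))  # 12 of the 48 layers use full attention
```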
Source code in vllm/transformers_utils/configs/qwen3_next.py
````python
class Qwen3NextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
    Qwen3-Next model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of
    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
            Percentage of the query and keys which will have rotary embedding.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 256):
            Projection weights dimension in multi-head attention.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the convolution used in linear attention layers.
        linear_key_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each key head in linear attention.
        linear_value_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head in linear attention.
        linear_num_key_heads (`int`, *optional*, defaults to 16):
            Number of key heads used in linear attention layers.
        linear_num_value_heads (`int`, *optional*, defaults to 32):
            Number of value heads used in linear attention layers.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 10):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 512):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
        layer_types (`list[str]`, *optional*):
            Types of each layer (attention or linear).

    ```python
    >>> from transformers import Qwen3NextModel, Qwen3NextConfig

    >>> # Initializing a Qwen3Next style configuration
    >>> configuration =  Qwen3NextConfig()

    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
    >>> model = Qwen3NextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """  # noqa: E501

    model_type = "qwen3_next"
    keys_to_ignore_at_inference = ["past_key_values"]

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.shared_experts.gate_proj": "colwise",
        "layers.*.mlp.shared_experts.up_proj": "colwise",
        "layers.*.mlp.shared_experts.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.25,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        decoder_sparse_step=1,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts_per_tok=10,
        num_experts=512,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        layer_types=None,
        **kwargs,
    ):
        if mlp_only_layers is None:
            mlp_only_layers = []
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim
        rope_config_validation(self)

        self.layer_types = layer_types
        if self.layer_types is None:
            self.layer_types = [
                "linear_attention" if bool((i + 1) % 4) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

        # linear attention part
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = mlp_only_layers
````

attention_bias instance-attribute

attention_bias = attention_bias

attention_dropout instance-attribute

attention_dropout = attention_dropout

base_model_pp_plan class-attribute instance-attribute

```python
base_model_pp_plan = {
    "embed_tokens": (["input_ids"], ["inputs_embeds"]),
    "layers": (
        ["hidden_states", "attention_mask"],
        ["hidden_states"],
    ),
    "norm": (["hidden_states"], ["hidden_states"]),
}
```

base_model_tp_plan class-attribute instance-attribute

```python
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.experts.*.gate_proj": "colwise",
    "layers.*.mlp.experts.*.up_proj": "colwise",
    "layers.*.mlp.experts.*.down_proj": "rowwise",
    "layers.*.mlp.shared_experts.gate_proj": "colwise",
    "layers.*.mlp.shared_experts.up_proj": "colwise",
    "layers.*.mlp.shared_experts.down_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
```
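The wildcard keys in `base_model_tp_plan` map module-name patterns to a tensor-parallel sharding style. The sketch below only illustrates the pattern semantics using `fnmatch`; it is not how `transformers` or vLLM resolves the plan internally, and `resolve_tp_style` is a hypothetical helper introduced here for illustration.

```python
from fnmatch import fnmatch

# Hypothetical helper, for illustration only: resolve the sharding style
# ("colwise"/"rowwise") that a tensor-parallel plan assigns to a module name.
example_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.experts.*.down_proj": "rowwise",
}

def resolve_tp_style(module_name, plan):
    # Return the style of the first wildcard pattern that matches the module name.
    for pattern, style in plan.items():
        if fnmatch(module_name, pattern):
            return style
    return None

print(resolve_tp_style("layers.3.self_attn.q_proj", example_tp_plan))          # colwise
print(resolve_tp_style("layers.7.mlp.experts.11.down_proj", example_tp_plan))  # rowwise
```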

decoder_sparse_step instance-attribute

decoder_sparse_step = decoder_sparse_step

head_dim instance-attribute

head_dim = head_dim

hidden_act instance-attribute

hidden_act = hidden_act

hidden_size instance-attribute

hidden_size = hidden_size

initializer_range instance-attribute

initializer_range = initializer_range

intermediate_size instance-attribute

intermediate_size = intermediate_size

keys_to_ignore_at_inference class-attribute instance-attribute

keys_to_ignore_at_inference = ['past_key_values']

layer_types instance-attribute

layer_types = layer_types

linear_conv_kernel_dim instance-attribute

linear_conv_kernel_dim = linear_conv_kernel_dim

linear_key_head_dim instance-attribute

linear_key_head_dim = linear_key_head_dim

linear_num_key_heads instance-attribute

linear_num_key_heads = linear_num_key_heads

linear_num_value_heads instance-attribute

linear_num_value_heads = linear_num_value_heads

linear_value_head_dim instance-attribute

linear_value_head_dim = linear_value_head_dim

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

mlp_only_layers instance-attribute

mlp_only_layers = mlp_only_layers

model_type class-attribute instance-attribute

model_type = 'qwen3_next'

moe_intermediate_size instance-attribute

moe_intermediate_size = moe_intermediate_size

norm_topk_prob instance-attribute

norm_topk_prob = norm_topk_prob

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

num_experts instance-attribute

num_experts = num_experts

num_experts_per_tok instance-attribute

num_experts_per_tok = num_experts_per_tok

num_hidden_layers instance-attribute

num_hidden_layers = num_hidden_layers

num_key_value_heads instance-attribute

num_key_value_heads = num_key_value_heads

output_router_logits instance-attribute

output_router_logits = output_router_logits

partial_rotary_factor instance-attribute

partial_rotary_factor = partial_rotary_factor

rms_norm_eps instance-attribute

rms_norm_eps = rms_norm_eps

rope_scaling instance-attribute

rope_scaling = rope_scaling

rope_theta instance-attribute

rope_theta = rope_theta

router_aux_loss_coef instance-attribute

router_aux_loss_coef = router_aux_loss_coef

shared_expert_intermediate_size instance-attribute

shared_expert_intermediate_size = (
    shared_expert_intermediate_size
)

use_cache instance-attribute

use_cache = use_cache

vocab_size instance-attribute

vocab_size = vocab_size

__init__

```python
__init__(
    vocab_size=151936,
    hidden_size=2048,
    intermediate_size=5632,
    num_hidden_layers=48,
    num_attention_heads=16,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000.0,
    rope_scaling=None,
    partial_rotary_factor=0.25,
    attention_bias=False,
    attention_dropout=0.0,
    head_dim=256,
    linear_conv_kernel_dim=4,
    linear_key_head_dim=128,
    linear_value_head_dim=128,
    linear_num_key_heads=16,
    linear_num_value_heads=32,
    decoder_sparse_step=1,
    moe_intermediate_size=512,
    shared_expert_intermediate_size=512,
    num_experts_per_tok=10,
    num_experts=512,
    norm_topk_prob=True,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
    mlp_only_layers=None,
    layer_types=None,
    **kwargs,
)
```
Source code in vllm/transformers_utils/configs/qwen3_next.py
```python
def __init__(
    self,
    vocab_size=151936,
    hidden_size=2048,
    intermediate_size=5632,
    num_hidden_layers=48,
    num_attention_heads=16,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=1e-6,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000.0,
    rope_scaling=None,
    partial_rotary_factor=0.25,
    attention_bias=False,
    attention_dropout=0.0,
    head_dim=256,
    linear_conv_kernel_dim=4,
    linear_key_head_dim=128,
    linear_value_head_dim=128,
    linear_num_key_heads=16,
    linear_num_value_heads=32,
    decoder_sparse_step=1,
    moe_intermediate_size=512,
    shared_expert_intermediate_size=512,
    num_experts_per_tok=10,
    num_experts=512,
    norm_topk_prob=True,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
    mlp_only_layers=None,
    layer_types=None,
    **kwargs,
):
    if mlp_only_layers is None:
        mlp_only_layers = []
    super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.num_key_value_heads = num_key_value_heads
    self.hidden_act = hidden_act
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.partial_rotary_factor = partial_rotary_factor
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout
    self.head_dim = head_dim
    rope_config_validation(self)

    self.layer_types = layer_types
    if self.layer_types is None:
        self.layer_types = [
            "linear_attention" if bool((i + 1) % 4) else "full_attention"
            for i in range(self.num_hidden_layers)
        ]
    layer_type_validation(self.layer_types)

    # linear attention part
    self.linear_conv_kernel_dim = linear_conv_kernel_dim
    self.linear_key_head_dim = linear_key_head_dim
    self.linear_value_head_dim = linear_value_head_dim
    self.linear_num_key_heads = linear_num_key_heads
    self.linear_num_value_heads = linear_num_value_heads

    # MoE arguments
    self.decoder_sparse_step = decoder_sparse_step
    self.moe_intermediate_size = moe_intermediate_size
    self.shared_expert_intermediate_size = shared_expert_intermediate_size
    self.num_experts_per_tok = num_experts_per_tok
    self.num_experts = num_experts
    self.norm_topk_prob = norm_topk_prob
    self.output_router_logits = output_router_logits
    self.router_aux_loss_coef = router_aux_loss_coef
    self.mlp_only_layers = mlp_only_layers
```
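As a closing usage note, a hedged sketch (it assumes network access to the Hugging Face Hub or a local cache of the repository): the inherited `from_pretrained` classmethod can load the published checkpoint's configuration directly into this class.

```python
# Hedged sketch: load the published configuration into this class.
# Requires network access (or a local cache) for the Hub repository.
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig

config = Qwen3NextConfig.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
print(config.num_hidden_layers, config.num_experts, config.num_experts_per_tok)
```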