vllm.model_executor.layers.quantization.quark.quark_moe

__all__ module-attribute

__all__ = [
    "QuarkMoEMethod",
    "QuarkW8A8Fp8MoEMethod",
    "QuarkW4A4MXFp4MoEMethod",
]

logger module-attribute

logger = init_logger(__name__)

QuarkMoEMethod

Bases: FusedMoEMethodBase

Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
class QuarkMoEMethod(FusedMoEMethodBase):

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)

    @staticmethod
    def get_moe_method(
            quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
            module: torch.nn.Module,
            layer_name: str) -> "QuarkMoEMethod":
        layer_quant_config = quant_config._find_matched_config(
            layer_name, module)

        if (layer_quant_config.get("output_tensors")
                or layer_quant_config.get("bias")):
            raise NotImplementedError("Currently, Quark models with "
                                      "output_tensors and bias "
                                      "quantized are not supported")
        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_fp8_w8a8(weight_config, input_config):
            return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
                                         module.moe_config)
        elif quant_config._is_mx_fp4(weight_config, input_config):
            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
                                           module.moe_config)
        else:
            raise RuntimeError("Unsupported FusedMoe scheme")

__init__

__init__(moe: FusedMoEConfig)
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def __init__(self, moe: FusedMoEConfig):
    super().__init__(moe)

get_moe_method staticmethod

get_moe_method(
    quant_config: QuarkConfig,
    module: Module,
    layer_name: str,
) -> QuarkMoEMethod
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
@staticmethod
def get_moe_method(
        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
        module: torch.nn.Module,
        layer_name: str) -> "QuarkMoEMethod":
    layer_quant_config = quant_config._find_matched_config(
        layer_name, module)

    if (layer_quant_config.get("output_tensors")
            or layer_quant_config.get("bias")):
        raise NotImplementedError("Currently, Quark models with "
                                  "output_tensors and bias "
                                  "quantized are not supported")
    weight_config = layer_quant_config.get("weight")
    input_config = layer_quant_config.get("input_tensors")

    if quant_config._is_fp8_w8a8(weight_config, input_config):
        return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
                                     module.moe_config)
    elif quant_config._is_mx_fp4(weight_config, input_config):
        return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
                                       module.moe_config)
    else:
        raise RuntimeError("Unsupported FusedMoe scheme")

QuarkW4A4MXFp4MoEMethod

Bases: QuarkMoEMethod

Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_group"
                and input_qscheme == "per_group"):
            raise ValueError(
                "For MX(FP4) Fused MoE layers, only per-group scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}")  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")

        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkW4A4MXFp4MoEMethod with static input scales is currently "
                "not implemented. Please open an issue.")

        if not current_platform.supports_mx():
            self.emulate = True
            logger.warning_once(
                "The current platform does not support native MXFP4 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision.")
        else:
            self.emulate = True
            logger.warning_once(
                "The current platform supports native MXFP4 "
                "computation, but kernels are not yet integrated in vLLM. "
                "Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision.")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        params_dtype = torch.uint8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // 2,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)

        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // 2,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)

        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        return mxfp4_w4a4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=None,
            a2_scale=None,
            block_shape=None,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        out = fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
        )
        return out

emulate instance-attribute

emulate = True

input_quant instance-attribute

input_quant = input_config

static_input_scales instance-attribute

static_input_scales = not input_quant.get('is_dynamic')

weight_quant instance-attribute

weight_quant = weight_config

__init__

__init__(
    weight_config: dict[str, Any],
    input_config: dict[str, Any],
    moe: FusedMoEConfig,
)
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def __init__(
    self,
    weight_config: dict[str, Any],
    input_config: dict[str, Any],
    moe: FusedMoEConfig,
):
    super().__init__(moe)
    self.weight_quant = weight_config
    self.input_quant = input_config

    weight_qscheme = self.weight_quant.get("qscheme")
    input_qscheme = self.input_quant.get("qscheme")
    if not (weight_qscheme == "per_group"
            and input_qscheme == "per_group"):
        raise ValueError(
            "For MX(FP4) Fused MoE layers, only per-group scales "
            "for weights and activations are supported. Found "
            f"{weight_qscheme}, {input_qscheme}")  # noqa E501

    self.static_input_scales = not self.input_quant.get("is_dynamic")

    if self.static_input_scales:
        raise NotImplementedError(
            "QuarkW4A4MXFp4MoEMethod with static input scales is currently "
            "not implemented. Please open an issue.")

    if not current_platform.supports_mx():
        self.emulate = True
        logger.warning_once(
            "The current platform does not support native MXFP4 "
            "computation. Simulated weight dequantization and activation "
            "QDQ (quantize and dequantize) will be used, with the linear "
            "layers computed in high precision.")
    else:
        self.emulate = True
        logger.warning_once(
            "The current platform supports native MXFP4 "
            "computation, but kernels are not yet integrated in vLLM. "
            "Simulated weight dequantization and activation "
            "QDQ (quantize and dequantize) will be used, with the linear "
            "layers computed in high precision.")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    assert self.fused_experts is None

    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")

    from vllm.model_executor.layers.fused_moe import fused_experts

    topk_weights, topk_ids, _ = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype)

    out = fused_experts(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        global_num_experts=global_num_experts,
        apply_router_weight_on_input=apply_router_weight_on_input,
        expert_map=expert_map,
        quant_config=self.moe_quant_config,
    )
    return out
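
The first half of apply() is generic top-k routing. A toy stand-in for FusedMoE.select_experts with the default softmax scoring and renormalization (shapes and math only, no vLLM internals):

import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

scores = router_logits.softmax(dim=-1)              # scoring_func="softmax"
topk_weights, topk_ids = scores.topk(top_k, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True

assert topk_weights.shape == topk_ids.shape == (num_tokens, top_k)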

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

    params_dtype = torch.uint8

    # WEIGHTS
    w13_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        2 * intermediate_size_per_partition,
        hidden_size // 2,
        dtype=params_dtype),
                                    requires_grad=False)
    layer.register_parameter("w13_weight", w13_weight)

    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size,
        intermediate_size_per_partition // 2,
        dtype=params_dtype),
                                   requires_grad=False)
    layer.register_parameter("w2_weight", w2_weight)

    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // OCP_MX_BLOCK_SIZE,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    w2_weight_scale = torch.nn.Parameter(
        torch.ones(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    set_weight_attrs(w2_weight_scale, extra_weight_attrs)
    set_weight_attrs(w13_weight_scale, extra_weight_attrs)

    layer.register_parameter("w13_weight_scale", w13_weight_scale)
    layer.register_parameter("w2_weight_scale", w2_weight_scale)

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> Optional[FusedMoEQuantConfig]
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    return mxfp4_w4a4_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,
    )

QuarkW8A8Fp8MoEMethod

Bases: QuarkMoEMethod

Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        self.weight_qscheme = self.weight_quant.get("qscheme")
        self.input_qscheme = self.input_quant.get("qscheme")
        per_tensor = (self.weight_qscheme == "per_tensor"
                      and self.input_qscheme == "per_tensor")
        per_channel = (self.weight_qscheme == "per_channel"
                       and self.input_qscheme == "per_channel")
        self.act_quant_group_shape = GroupShape.PER_TOKEN \
            if per_channel else GroupShape.PER_TENSOR
        if not (per_tensor or per_channel):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor and per-channel "
                "scales for weights and activations are supported. Found "
                f"{self.weight_qscheme}, {self.input_qscheme}")  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")
        if self.static_input_scales and per_channel:
            raise ValueError(
                "For FP8 Fused MoE layer, we require either per tensor or "
                "channelwise, dynamic per token quantization.")

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable Marlin on ROCm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.weight_qscheme == "per_tensor":
            # Allocate 2 scales for w1 and w3 respectively.
            # They are combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        elif self.weight_qscheme == "per_channel":
            # Quark's weight scales are 1-D.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                dtype=torch.float32),
                                                  requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, hidden_size, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
                                                           requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                          requires_grad=False)

        # For per-tensor case, Fp8 moe kernel needs single weight scale
        # for w13 per expert. Use max then dequant and requant each expert.
        if self.weight_qscheme == "per_tensor":
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
        # Quark's weight scales are 1-D.
        elif self.weight_qscheme == "per_channel":
            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
        # Select the fused-experts backend: AITER (ROCm), Marlin, or default.
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
                rocm_aiter_fused_experts, shuffle_weights)

            # reshaping weights is required for aiter moe kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
        elif self.use_marlin:

            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
            self.fused_experts_func = None
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            self.fused_experts_func = fused_experts

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=self.weight_qscheme == "per_channel",
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.rocm_aiter_moe_enabled:
            return self.rocm_aiter_fused_experts_func(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                quant_config=self.moe_quant_config,
                expert_map=expert_map)
        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        assert self.fused_experts_func is not None

        return self.fused_experts_func(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config)

act_quant_group_shape instance-attribute

act_quant_group_shape = (
    GroupShape.PER_TOKEN if per_channel else GroupShape.PER_TENSOR
)

input_qscheme instance-attribute

input_qscheme = input_quant.get('qscheme')

input_quant instance-attribute

input_quant = input_config

rocm_aiter_moe_enabled instance-attribute

rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

static_input_scales instance-attribute

static_input_scales = not input_quant.get('is_dynamic')

use_marlin instance-attribute

use_marlin = (
    not current_platform.has_device_capability(89)
    or envs.VLLM_TEST_FORCE_FP8_MARLIN
)

weight_qscheme instance-attribute

weight_qscheme = weight_quant.get('qscheme')

weight_quant instance-attribute

weight_quant = weight_config

__init__

__init__(
    weight_config: dict[str, Any],
    input_config: dict[str, Any],
    moe: FusedMoEConfig,
)
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def __init__(
    self,
    weight_config: dict[str, Any],
    input_config: dict[str, Any],
    moe: FusedMoEConfig,
):
    super().__init__(moe)
    self.weight_quant = weight_config
    self.input_quant = input_config

    self.weight_qscheme = self.weight_quant.get("qscheme")
    self.input_qscheme = self.input_quant.get("qscheme")
    per_tensor = (self.weight_qscheme == "per_tensor"
                  and self.input_qscheme == "per_tensor")
    per_channel = (self.weight_qscheme == "per_channel"
                   and self.input_qscheme == "per_channel")
    self.act_quant_group_shape = GroupShape.PER_TOKEN \
        if per_channel else GroupShape.PER_TENSOR
    if not (per_tensor or per_channel):
        raise ValueError(
            "For FP8 Fused MoE layers, only per-tensor and per-channel "
            "scales for weights and activations are supported. Found "
            f"{self.weight_qscheme}, {self.input_qscheme}")  # noqa E501

    self.static_input_scales = not self.input_quant.get("is_dynamic")
    if self.static_input_scales and per_channel:
        raise ValueError(
            "For FP8 Fused MoE layer, we require either per tensor or "
            "channelwise, dynamic per token quantization.")

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable Marlin on ROCm
    if current_platform.is_rocm():
        self.use_marlin = False

    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
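
Only matching qschemes pass the validation above, and per-channel weights additionally require dynamic (per-token) activation quantization. A sketch over the two keys this class reads:

# Accepted: matching qschemes; per_channel additionally requires
# dynamic (per-token) activation quantization.
static_per_tensor = ({"qscheme": "per_tensor", "is_dynamic": False},
                     {"qscheme": "per_tensor", "is_dynamic": False})
dynamic_per_channel = ({"qscheme": "per_channel", "is_dynamic": True},
                       {"qscheme": "per_channel", "is_dynamic": True})

for weight_cfg, input_cfg in (static_per_tensor, dynamic_per_channel):
    assert weight_cfg["qscheme"] == input_cfg["qscheme"]
    per_channel = weight_cfg["qscheme"] == "per_channel"
    static_input_scales = not input_cfg.get("is_dynamic")
    assert not (static_input_scales and per_channel)  # else ValueError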

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    assert self.fused_experts is None

    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")

    topk_weights, topk_ids, _ = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype)

    if self.rocm_aiter_moe_enabled:
        return self.rocm_aiter_fused_experts_func(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            quant_config=self.moe_quant_config,
            expert_map=expert_map)
    if self.use_marlin:
        assert activation == "silu", (
            f"{activation} not supported for Marlin MoE.")
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=scalar_types.float8_e4m3fn.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    assert self.fused_experts_func is not None

    return self.fused_experts_func(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=self.moe_quant_config)
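
apply() dispatches to one of three backends chosen during init and weight processing. The priority order, sketched as plain control flow:

def pick_backend(rocm_aiter_moe_enabled: bool, use_marlin: bool,
                 activation: str) -> str:
    if rocm_aiter_moe_enabled:
        return "rocm_aiter_fused_experts"
    if use_marlin:
        assert activation == "silu"  # Marlin MoE only supports silu
        return "fused_marlin_moe"
    return "fused_experts"  # default Triton path


assert pick_backend(True, False, "silu") == "rocm_aiter_fused_experts"
assert pick_backend(False, True, "silu") == "fused_marlin_moe"
assert pick_backend(False, False, "gelu") == "fused_experts"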

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):

    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None
    params_dtype = torch.float8_e4m3fn

    # WEIGHTS
    w13_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        2 * intermediate_size_per_partition,
        hidden_size,
        dtype=params_dtype),
                                    requires_grad=False)
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size,
        intermediate_size_per_partition,
        dtype=params_dtype),
                                   requires_grad=False)
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if self.weight_qscheme == "per_tensor":
        # Allocate 2 scales for w1 and w3 respectively.
        # They are combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts, 2, dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts, dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-TENSOR quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
    elif self.weight_qscheme == "per_channel":
        # Quark's weight scales are 1-D.
        w13_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            2 * intermediate_size_per_partition,
            dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts, hidden_size, dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    # INPUT_SCALES
    if self.static_input_scales:
        w13_input_scale = torch.nn.Parameter(torch.ones(
            num_experts, dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(torch.ones(
            num_experts, dtype=torch.float32),
                                            requires_grad=False)
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)
    else:
        layer.w13_input_scale = None
        layer.w2_input_scale = None
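
The scale layouts created above differ by weight qscheme. An illustrative summary with made-up layer sizes:

import torch

num_experts, hidden_size, intermediate = 8, 4096, 1024

# per_tensor: two scales per expert for w13 (one each for w1 and w3),
# fused into a single scale in process_weights_after_loading.
w13_scale_pt = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale_pt = torch.ones(num_experts, dtype=torch.float32)

# per_channel: one scale per output channel, stored 1-D per expert; a
# trailing singleton dim is added later for per-token activation quant.
w13_scale_pc = torch.ones(num_experts, 2 * intermediate, dtype=torch.float32)
w2_scale_pc = torch.ones(num_experts, hidden_size, dtype=torch.float32)

assert w13_scale_pc.unsqueeze(-1).shape == (num_experts, 2 * intermediate, 1)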

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> Optional[FusedMoEQuantConfig]
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    return fp8_w8a8_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        per_act_token_quant=self.weight_qscheme == "per_channel",
    )

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/quark/quark_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # Fp8 moe kernels require a single activation scale.
    # We take the max of all the scales in case they differ.
    if self.static_input_scales:
        if (layer.w13_input_scale is None or layer.w2_input_scale is None):
            raise ValueError(
                "QuantConfig has static quantization, but found "
                "activation scales are None.")
        if (not all_close_1d(layer.w13_input_scale)
                or not all_close_1d(layer.w2_input_scale)):
            logger.warning_once(
                "Found input_scales that are not equal for "
                "fp8 MoE layer. Using the maximum across experts "
                "for each layer. ")
        layer.w13_input_scale = torch.nn.Parameter(
            layer.w13_input_scale.max(), requires_grad=False)
        layer.w2_input_scale = torch.nn.Parameter(
            layer.w2_input_scale.max(), requires_grad=False)

    if current_platform.is_fp8_fnuz():
        # Normalize the weights and scales
        w13_weight, w13_weight_scale, w13_input_scale = \
            normalize_e4m3fn_to_e4m3fnuz(
                layer.w13_weight, layer.w13_weight_scale,
                layer.w13_input_scale)
        w2_weight, w2_weight_scale, w2_input_scale = \
            normalize_e4m3fn_to_e4m3fnuz(
                layer.w2_weight, layer.w2_weight_scale,
                layer.w2_input_scale)
        # Reset the parameter
        layer.w13_weight = torch.nn.Parameter(w13_weight,
                                              requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                    requires_grad=False)
        if w13_input_scale is not None:
            layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
                                                       requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight,
                                             requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                   requires_grad=False)
        if w2_input_scale is not None:
            layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                      requires_grad=False)

    # For per-tensor case, Fp8 moe kernel needs single weight scale
    # for w13 per expert. Use max then dequant and requant each expert.
    if self.weight_qscheme == "per_tensor":
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)
    # Quark's weight scales are 1-D.
    elif self.weight_qscheme == "per_channel":
        if self.act_quant_group_shape == GroupShape.PER_TOKEN:
            w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
    # Select the fused-experts backend: AITER (ROCm), Marlin, or default.
    if self.rocm_aiter_moe_enabled:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
            rocm_aiter_fused_experts, shuffle_weights)

        # reshaping weights is required for aiter moe kernel.
        shuffled_w13, shuffled_w2 = shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data)

        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                             requires_grad=False)

        self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
    elif self.use_marlin:

        prepare_moe_fp8_layer_for_marlin(layer, False)
        # Activations not quantized for marlin.
        del layer.w13_input_scale
        del layer.w2_input_scale
        self.fused_experts_func = None
    else:
        from vllm.model_executor.layers.fused_moe import fused_experts
        self.fused_experts_func = fused_experts
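
The per-tensor branch above folds the two w13 shard scales into a single per-expert scale: each shard (w1, w3) is dequantized with its own scale, then requantized with the per-expert max. A self-contained numeric sketch, with plain torch standing in for per_tensor_dequantize and ops.scaled_fp8_quant (the real path round-trips through FP8 and is therefore lossy):

import torch

torch.manual_seed(0)
shard_size, hidden = 4, 8
w13 = torch.randn(2 * shard_size, hidden)  # stand-in for one expert's payload
scales = torch.tensor([0.5, 2.0])          # one scale per shard (w1, w3)
max_scale = scales.max()

requantized = torch.empty_like(w13)
for shard_id in range(2):
    rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
    dq = w13[rows] * scales[shard_id]      # dequantize with the shard scale
    requantized[rows] = dq / max_scale     # requantize with the fused scale

# Dequantizing with the single fused scale reproduces the original values.
expected = w13 * scales.repeat_interleave(shard_size).unsqueeze(1)
assert torch.allclose(requantized * max_scale, expected)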