vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe

logger module-attribute

logger = init_logger(__name__)

FlashInferExperts

Bases: FusedMoEPermuteExpertsUnpermute

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        out_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        ep_rank: int = 0,
        ep_size: int = 1,
        tp_rank: int = 0,
        tp_size: int = 1,
    ):
        super().__init__(quant_config)
        assert quant_config.quant_dtype in (
            "nvfp4", torch.float8_e4m3fn,
            None), ("Only nvfp4, fp8, bfloat16 and"
                    " float16 quantization are currently supported.")
        self.ep_rank = ep_rank
        self.ep_size = ep_size
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.out_dtype = out_dtype

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_expert_map(self) -> bool:
        return False

    def supports_chunking(self) -> bool:
        # This refers to TP chunking; DP chunking is handled separately.
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function.  Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        aq_m, aq_n = aq.shape
        workspace2 = (0, )
        output_shape = ((aq_m, aq_n * 2)
                        if self.quant_dtype == "nvfp4" else (aq_m, aq_n))
        workspace_dtype = a.dtype
        workspace1 = output_shape
        # The workspace is determined by `aq`, since it comes after any
        # potential communication op and is involved in the expert computation.
        return (workspace1, workspace2, output_shape, workspace_dtype)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: Optional[torch.Tensor],
        workspace2: Optional[torch.Tensor],
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: Optional[bool],
    ):

        assert activation == "silu", ("Only activation silu is supported in "
                                      "FlashInferExperts")

        if self.quant_dtype == torch.float8_e4m3fn:
            quant_scales = [
                self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
            ]

            a1q_scale = None  # not passing input_sf in fp8
            fc1_expert_weights = w1
            fc2_expert_weights = w2
        elif self.quant_dtype == "nvfp4":
            # Ensure w1_scale and w2_scale are not None before calling view
            assert self.w1_scale is not None and self.w2_scale is not None, (
                "w1_scale and w2_scale must not "
                "be None for FlashInferExperts")
            # FlashInfer CUTLASS kernel takes scalar global scales;
            # a min is used because these are inverse scales.
            quant_scales = [
                self.a1_gscale,
                self.w1_scale.view(torch.int32),
                self.g1_alphas,
                self.a2_gscale,
                self.w2_scale.view(torch.int32),
                self.g2_alphas,
            ]
            # FlashInfer API requires weight to be long for nvfp4
            fc1_expert_weights = w1.view(torch.long)
            fc2_expert_weights = w2.view(torch.long)
        else:
            quant_scales = None
            a1q_scale = None
            fc1_expert_weights = w1
            fc2_expert_weights = w2

        _ = flashinfer_cutlass_fused_moe(
            input=hidden_states,
            token_selected_experts=topk_ids.to(torch.int),
            token_final_scales=topk_weights,
            fc1_expert_weights=fc1_expert_weights,
            fc2_expert_weights=fc2_expert_weights,
            output_dtype=self.out_dtype,
            quant_scales=quant_scales,
            input_sf=a1q_scale,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            output=output,
        )
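
FlashInferExperts implements only the expert-compute stage; it is meant to be paired with a FlashInfer prepare/finalize object inside mk.FusedMoEModularKernel, exactly as the flashinfer_cutlass_moe helper further down this page does. A minimal wiring sketch, assuming quant_config and the parallel ranks come from the caller and that the import paths match those used by this module:

import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts)

# Hypothetical wiring, mirroring flashinfer_cutlass_moe below.
experts = FlashInferExperts(
    out_dtype=torch.bfloat16,   # dtype of the final MoE output
    quant_config=quant_config,  # assumed: a FusedMoEQuantConfig built by the caller
    tp_rank=0, tp_size=1,
    ep_rank=0, ep_size=1,
)
kernel = mk.FusedMoEModularKernel(
    create_flashinfer_prepare_finalize(use_dp=False),  # imported in this module
    experts,
)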

activation_formats property

ep_rank instance-attribute

ep_rank = ep_rank

ep_size instance-attribute

ep_size = ep_size

out_dtype instance-attribute

out_dtype = out_dtype

tp_rank instance-attribute

tp_rank = tp_rank

tp_size instance-attribute

tp_size = tp_size

__init__

__init__(
    out_dtype: dtype,
    quant_config: FusedMoEQuantConfig,
    ep_rank: int = 0,
    ep_size: int = 1,
    tp_rank: int = 0,
    tp_size: int = 1,
)
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def __init__(
    self,
    out_dtype: torch.dtype,
    quant_config: FusedMoEQuantConfig,
    ep_rank: int = 0,
    ep_size: int = 1,
    tp_rank: int = 0,
    tp_size: int = 1,
):
    super().__init__(quant_config)
    assert quant_config.quant_dtype in (
        "nvfp4", torch.float8_e4m3fn,
        None), ("Only nvfp4, fp8, bfloat16 and"
                " float16 quantization are currently supported.")
    self.ep_rank = ep_rank
    self.ep_size = ep_size
    self.tp_rank = tp_rank
    self.tp_size = tp_size
    self.out_dtype = out_dtype

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    workspace13: Optional[Tensor],
    workspace2: Optional[Tensor],
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: Optional[bool],
)
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    workspace13: Optional[torch.Tensor],
    workspace2: Optional[torch.Tensor],
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    apply_router_weight_on_input: Optional[bool],
):

    assert activation == "silu", ("Only activation silu is supported in "
                                  "FlashInferExperts")

    if self.quant_dtype == torch.float8_e4m3fn:
        quant_scales = [
            self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
        ]

        a1q_scale = None  # not passing input_sf in fp8
        fc1_expert_weights = w1
        fc2_expert_weights = w2
    elif self.quant_dtype == "nvfp4":
        # Ensure w1_scale and w2_scale are not None before calling view
        assert self.w1_scale is not None and self.w2_scale is not None, (
            "w1_scale and w2_scale must not "
            "be None for FlashInferExperts")
        # FlashInfer CUTLASS kernel takes scalar global scales;
        # a min is used because these are inverse scales.
        quant_scales = [
            self.a1_gscale,
            self.w1_scale.view(torch.int32),
            self.g1_alphas,
            self.a2_gscale,
            self.w2_scale.view(torch.int32),
            self.g2_alphas,
        ]
        # FlashInfer API requires weight to be long for nvfp4
        fc1_expert_weights = w1.view(torch.long)
        fc2_expert_weights = w2.view(torch.long)
    else:
        quant_scales = None
        a1q_scale = None
        fc1_expert_weights = w1
        fc2_expert_weights = w2

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=topk_ids.to(torch.int),
        token_final_scales=topk_weights,
        fc1_expert_weights=fc1_expert_weights,
        fc2_expert_weights=fc2_expert_weights,
        output_dtype=self.out_dtype,
        quant_scales=quant_scales,
        input_sf=a1q_scale,
        tp_size=self.tp_size,
        tp_rank=self.tp_rank,
        ep_size=self.ep_size,
        ep_rank=self.ep_rank,
        output=output,
    )

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    return TopKWeightAndReduceNoOP()

supports_chunking

supports_chunking() -> bool
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def supports_chunking(self) -> bool:
    # This refers to TP chunking; DP chunking is handled separately.
    return True

supports_expert_map

supports_expert_map() -> bool
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def supports_expert_map(self) -> bool:
    return False

workspace_shapes

workspace_shapes(
    a: Tensor,
    aq: Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...], dtype
]

Compute the shapes for the temporary and final outputs of the two gemms and activation in the fused expert function. Since the gemms are independent, the workspace for the first gemm can be shared with the workspace for the last gemm.

Returns a tuple of:

- workspace13 shape tuple: must be large enough to hold the result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens.

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def workspace_shapes(
    self,
    a: torch.Tensor,
    aq: torch.Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
    # We use global_num_experts due to how moe_align_block_size handles
    # expert_maps.
    """
    Compute the shapes for the temporary and final outputs of the two gemms
    and activation in the fused expert function.  Since the gemms are
    independent, the workspace for the first gemm can be shared with the
    workspace for the last gemm.

    Returns a tuple of:
    - workspace13 shape tuple: must be large enough to hold the
      result of either expert gemm.
    - workspace2 shape tuple: must be large enough to hold the
      result of the activation function.
    - output shape tuple: must be exact size of the final gemm output.
    - Workspace type: The dtype to use for the workspace tensors.
    - Note: in order for activation chunking to work, the first dimension
      of each tuple must be the number of tokens.
    """
    aq_m, aq_n = aq.shape
    workspace2 = (0, )
    output_shape = ((aq_m, aq_n * 2)
                    if self.quant_dtype == "nvfp4" else (aq_m, aq_n))
    workspace_dtype = a.dtype
    workspace1 = output_shape
    # The workspace is determined by `aq`, since it comes after any
    # potential communication op and is involved in the expert computation.
    return (workspace1, workspace2, output_shape, workspace_dtype)
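
For intuition, a small sketch of the shapes this method produces, using made-up sizes; the doubling of the output width for nvfp4 presumably reflects two FP4 values being packed per byte of aq:

import torch

num_tokens, aq_width = 8, 2048                             # hypothetical sizes
a = torch.empty(num_tokens, 4096, dtype=torch.bfloat16)    # unquantized input
aq = torch.empty(num_tokens, aq_width, dtype=torch.uint8)  # quantized input

aq_m, aq_n = aq.shape
# nvfp4: output width is aq_n * 2; fp8 or unquantized: output width is aq_n.
output_shape_nvfp4 = (aq_m, aq_n * 2)   # (8, 4096)
output_shape_fp8 = (aq_m, aq_n)         # (8, 2048)
# workspace13 mirrors the output shape, workspace2 is unused, and the
# workspace dtype follows the unquantized activation `a`.
workspaces_nvfp4 = (output_shape_nvfp4, (0, ), output_shape_nvfp4, a.dtype)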

flashinfer_cutlass_moe

flashinfer_cutlass_moe(
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    tp_rank: int = 0,
    tp_size: int = 1,
    ep_rank: int = 0,
    ep_size: int = 1,
    use_dp: bool = False,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def flashinfer_cutlass_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    tp_rank: int = 0,
    tp_size: int = 1,
    ep_rank: int = 0,
    ep_size: int = 1,
    use_dp: bool = False,
) -> torch.Tensor:
    fused_experts = mk.FusedMoEModularKernel(
        create_flashinfer_prepare_finalize(use_dp=use_dp),
        FlashInferExperts(
            out_dtype=hidden_states.dtype,
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
            ep_rank=ep_rank,
            ep_size=ep_size,
        ))

    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
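
A schematic call of this helper; the tensors and quant_config are assumed to be produced by the caller's quantization path, so only the argument wiring is shown:

# Hypothetical wrapper showing how the arguments are typically wired up.
def run_flashinfer_moe(hidden_states, w1, w2, topk_weights, topk_ids,
                       quant_config):
    return flashinfer_cutlass_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        quant_config=quant_config,
        activation="silu",               # the only activation FlashInferExperts accepts
        global_num_experts=w1.shape[0],  # assumption: experts stacked along dim 0
        tp_rank=0, tp_size=1,
        ep_rank=0, ep_size=1,
        use_dp=False,
    )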

flashinfer_cutlass_moe_fp4

flashinfer_cutlass_moe_fp4(
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def flashinfer_cutlass_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    fused_experts = mk.FusedMoEModularKernel(
        create_flashinfer_prepare_finalize(use_dp=False),
        FlashInferExperts(
            out_dtype=hidden_states.dtype,
            quant_config=quant_config,
        ))

    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
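
This is a thin convenience wrapper around the same modular kernel with use_dp=False and the default tp/ep ranks. A sketch of the call, assuming the caller already holds nvfp4-quantized expert weights and a matching quant_config:

# Hypothetical call; w1/w2 are assumed to be nvfp4-packed expert weights.
out = flashinfer_cutlass_moe_fp4(
    hidden_states=hidden_states,   # unquantized activations, e.g. bfloat16
    w1=w1,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    quant_config=quant_config,
)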

is_valid_flashinfer_cutlass_fused_moe

is_valid_flashinfer_cutlass_fused_moe(
    hidden_states: Tensor, w1: Tensor, w2: Tensor
) -> bool

Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel.

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
                                          w1: torch.Tensor,
                                          w2: torch.Tensor) -> bool:
    """
    Check if the given problem size is supported by the FlashInfer CUTLASS MoE
    kernel.
    """
    if not has_flashinfer_cutlass_fused_moe():
        logger.debug_once("FlashInferExperts disabled: "
                          "flashinfer_cutlass_fused_moe not available.")
        return False
    # Data type checks
    if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8
            or hidden_states.dtype
            not in [torch.float32, torch.float16, torch.bfloat16]):
        logger.debug_once(
            "FlashInferExperts disabled: w1/w2 must be torch.uint8 "
            f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
            f"float32, float16, or bfloat16 (got {hidden_states.dtype}).")
        return False
    return True
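
A minimal sketch of how a caller might gate on this check before dispatching to the FlashInfer path; fallback_fused_moe is a hypothetical stand-in for whatever path the caller would use otherwise:

if is_valid_flashinfer_cutlass_fused_moe(hidden_states, w1, w2):
    out = flashinfer_cutlass_moe_fp4(
        hidden_states, w1, w2, topk_weights, topk_ids, quant_config)
else:
    # Hypothetical fallback; the real fallback depends on the caller.
    out = fallback_fused_moe(hidden_states, w1, w2, topk_weights, topk_ids)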