vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe

flashinfer_fused_moe_blockscale_fp8

flashinfer_fused_moe_blockscale_fp8(
    routing_logits: Tensor,
    routing_bias: Tensor,
    x: Tensor,
    w13_weight: Tensor,
    w13_weight_scale_inv: Tensor,
    w2_weight: Tensor,
    w2_weight_scale_inv: Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: List[int],
    routed_scaling: float = 1.0,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
def flashinfer_fused_moe_blockscale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: torch.Tensor,
        x: torch.Tensor,
        w13_weight: torch.Tensor,
        w13_weight_scale_inv: torch.Tensor,
        w2_weight: torch.Tensor,
        w2_weight_scale_inv: torch.Tensor,
        global_num_experts: int,
        top_k: int,
        num_expert_group: int,
        topk_group: int,
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: List[int],  # noqa: UP006
        routed_scaling: float = 1.0) -> torch.Tensor:
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
    assert top_k <= global_num_experts
    assert top_k <= 8
    assert topk_group <= 4
    assert global_num_experts > num_expert_group
    assert global_num_experts % num_expert_group == 0
    assert global_num_experts % 4 == 0
    assert top_k < (topk_group * global_num_experts / num_expert_group)
    assert block_shape == [128, 128]
    # The routing kernel expects #experts <= #threads (256)
    assert global_num_experts <= 256

    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
    # NOTE: scales of hidden states have to be transposed!
    a_sf_t = a_sf.t().contiguous()
    return flashinfer_trtllm_fp8_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=a_q,
        hidden_states_scale=a_sf_t,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
                                                  global_num_experts),
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )

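A minimal usage sketch on a FlashInfer-supported GPU, assuming DeepSeek-V3-like routing parameters. All shapes, dtypes, and the weight/scale layouts below are illustrative assumptions about what the underlying kernel expects, not guarantees from this module; the random FP8 weights only show the call shape, so the numerical output is meaningless.

import torch
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
    flashinfer_fused_moe_blockscale_fp8)

# Illustrative sizes (DeepSeek-V3-like), chosen to satisfy the asserts above.
T, H, inter, E = 16, 7168, 2048, 256       # tokens, hidden, intermediate, experts
top_k, n_group, topk_group = 8, 8, 4

x = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")
routing_logits = torch.randn(T, E, dtype=torch.float32, device="cuda")
routing_bias = torch.zeros(E, dtype=torch.bfloat16, device="cuda")

# Assumed layouts: w13 is (E, 2*inter, H), w2 is (E, H, inter), with one scale
# per 128x128 weight block. Real weights come from a block-quantized checkpoint.
w13 = torch.randn(E, 2 * inter, H, device="cuda").to(torch.float8_e4m3fn)
w13_scale = torch.ones(E, 2 * inter // 128, H // 128,
                       dtype=torch.float32, device="cuda")
w2 = torch.randn(E, H, inter, device="cuda").to(torch.float8_e4m3fn)
w2_scale = torch.ones(E, H // 128, inter // 128,
                      dtype=torch.float32, device="cuda")

out = flashinfer_fused_moe_blockscale_fp8(
    routing_logits, routing_bias, x,
    w13, w13_scale, w2, w2_scale,
    global_num_experts=E, top_k=top_k,
    num_expert_group=n_group, topk_group=topk_group,
    intermediate_size=inter, expert_offset=0, local_num_experts=E,
    block_shape=[128, 128], routed_scaling=1.0)
print(out.shape)  # expected (T, H)
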
flashinfer_fused_moe_blockscale_fp8_fake

flashinfer_fused_moe_blockscale_fp8_fake(
    routing_logits: Tensor,
    routing_bias: Tensor,
    x: Tensor,
    w13_weight: Tensor,
    w13_weight_scale_inv: Tensor,
    w2_weight: Tensor,
    w2_weight_scale_inv: Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: list[int],
    routed_scaling: float = 1.0,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
def flashinfer_fused_moe_blockscale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: torch.Tensor,
        x: torch.Tensor,
        w13_weight: torch.Tensor,
        w13_weight_scale_inv: torch.Tensor,
        w2_weight: torch.Tensor,
        w2_weight_scale_inv: torch.Tensor,
        global_num_experts: int,
        top_k: int,
        num_expert_group: int,
        topk_group: int,
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: list[int],
        routed_scaling: float = 1.0) -> torch.Tensor:
    return torch.empty_like(x)

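The _fake variant only returns an empty tensor with the same shape and dtype as x. It serves as the shape-only (meta) implementation when the real op is registered as a PyTorch custom op, so torch.compile can trace the graph without launching the FlashInfer kernel. Below is a minimal sketch of that general pattern using the public torch.library API; the op name myops::scale_rows and its body are made up for illustration, and vLLM's actual registration may go through its own helper.

import torch

# Toy op used only to illustrate the pattern; "myops::scale_rows" is hypothetical.
@torch.library.custom_op("myops::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Fake (meta) implementation: only shapes/dtypes matter, no real compute.
@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

@torch.compile
def f(x):
    return torch.ops.myops.scale_rows(x, 2.0)

print(f(torch.ones(4)))
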
flashinfer_fused_moe_per_tensor_scale_fp8

flashinfer_fused_moe_per_tensor_scale_fp8(
    routing_logits: Tensor,
    routing_bias: Optional[Tensor],
    hidden_states: Tensor,
    input_scale: Tensor,
    gemm1_weights: Tensor,
    gemm2_weights: Tensor,
    output1_scales_scalar: Tensor,
    output1_scales_gate_scalar: Tensor,
    output2_scales_scalar: Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
def flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    num_expert_group = num_expert_group if num_expert_group is not None else 0
    topk_group = topk_group if topk_group is not None else 0

    quant_hidden_states, _ = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False)

    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)
    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=quant_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
                                                  top_k, num_experts),
        routing_method_type=routing_method_type)

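Before the kernel call, the activations are quantized to FP8 with a single per-tensor scale, and a missing num_expert_group/topk_group is passed to the kernel as 0. A minimal sketch of just that preprocessing step, calling the helper with the same arguments as above (the import path and the scale value are assumptions for illustration):

import torch
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input

hidden_states = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
input_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # assumed precomputed

quant_hidden_states, _ = moe_kernel_quantize_input(
    hidden_states,
    input_scale,
    quant_dtype=torch.float8_e4m3fn,
    per_act_token_quant=False)

print(quant_hidden_states.dtype)  # torch.float8_e4m3fn

Note that unlike the block-scale path, routing_method_type is a parameter here rather than being hard-coded to 2 (the DeepSeek-style routing used above), so the caller selects FlashInfer's routing scheme.
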
flashinfer_fused_moe_per_tensor_scale_fp8_fake

flashinfer_fused_moe_per_tensor_scale_fp8_fake(
    routing_logits: Tensor,
    routing_bias: Optional[Tensor],
    hidden_states: Tensor,
    input_scale: Tensor,
    gemm1_weights: Tensor,
    gemm2_weights: Tensor,
    output1_scales_scalar: Tensor,
    output1_scales_gate_scalar: Tensor,
    output2_scales_scalar: Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    return torch.empty_like(hidden_states)