vllm.model_executor.layers.fused_moe.modular_kernel

PrepareResultType module-attribute

ReceiverType module-attribute

ReceiverType = Callable[[], PrepareResultType]
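
ReceiverType is a no-argument callable that, once an asynchronous prepare has completed, yields the same 5-tuple that prepare returns. A minimal sketch of the assumed shape of these aliases, inferred from the return contract documented for FusedMoEPrepareAndFinalize.prepare below (illustrative, not the verbatim module source):

from typing import Callable, Optional

import torch

# Assumed structure, mirroring the documented return values of
# FusedMoEPrepareAndFinalize.prepare():
#   (dispatched activations, optional activation scales,
#    optional ExpertTokensMetadata, optional dispatched topk ids,
#    optional dispatched topk weights)
PrepareResultType = tuple[
    torch.Tensor,
    Optional[torch.Tensor],
    Optional["ExpertTokensMetadata"],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]

# A receiver is simply called with no arguments to collect those results.
ReceiverType = Callable[[], PrepareResultType]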

ExpertTokensMetadata dataclass

Metadata regarding expert-token routing.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@dataclass
class ExpertTokensMetadata:
    """
  Metadata regarding expert-token routing.
  """
    expert_num_tokens: torch.Tensor
    expert_num_tokens_cpu: Optional[torch.Tensor]

    @staticmethod
    def make_from_list(expert_num_tokens_list: list[int],
                       device: str) -> "ExpertTokensMetadata":
        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
                                             device="cpu",
                                             dtype=torch.int32)
        return ExpertTokensMetadata(
            expert_num_tokens=expert_num_tokens_cpu.to(device,
                                                       non_blocking=True),
            expert_num_tokens_cpu=expert_num_tokens_cpu)

expert_num_tokens instance-attribute

expert_num_tokens: Tensor

expert_num_tokens_cpu instance-attribute

expert_num_tokens_cpu: Optional[Tensor]

__init__

__init__(
    expert_num_tokens: Tensor,
    expert_num_tokens_cpu: Optional[Tensor],
) -> None

make_from_list staticmethod

make_from_list(
    expert_num_tokens_list: list[int], device: str
) -> ExpertTokensMetadata
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@staticmethod
def make_from_list(expert_num_tokens_list: list[int],
                   device: str) -> "ExpertTokensMetadata":
    expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
                                         device="cpu",
                                         dtype=torch.int32)
    return ExpertTokensMetadata(
        expert_num_tokens=expert_num_tokens_cpu.to(device,
                                                   non_blocking=True),
        expert_num_tokens_cpu=expert_num_tokens_cpu)
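
A small usage sketch of make_from_list, assuming vLLM is installed and a CUDA device is available:

from vllm.model_executor.layers.fused_moe.modular_kernel import (
    ExpertTokensMetadata)

# Four local experts were assigned 3, 0, 5 and 2 tokens respectively.
meta = ExpertTokensMetadata.make_from_list([3, 0, 5, 2], device="cuda")

assert meta.expert_num_tokens.device.type == "cuda"
assert meta.expert_num_tokens_cpu is not None
assert meta.expert_num_tokens_cpu.tolist() == [3, 0, 5, 2]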

FusedMoEActivationFormat

Bases: Enum

Activation layouts supported by the modular kernels: the standard format (num_tokens, hidden dim) and the batched experts format (num experts, max tokens per expert, hidden dim).

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
class FusedMoEActivationFormat(Enum):
    """
    The standard activation format (num_tokens, hidden dim).
    """
    Standard = "standard",
    """
    The batched experts format (num experts, max tokens per expert, hidden dim)
    """
    BatchedExperts = "batched_experts",

BatchedExperts class-attribute instance-attribute

BatchedExperts = ('batched_experts',)

The batched experts format (num experts, max tokens per expert, hidden dim).

Standard class-attribute instance-attribute

Standard = ('standard',)

The standard activation format (num_tokens, hidden dim).
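
To illustrate the two layouts, a short sketch with made-up sizes:

import torch

num_tokens, hidden_dim = 8, 4096
num_local_experts, max_tokens_per_expert = 4, 16

# Standard: a flat (num_tokens, hidden_dim) activation tensor.
standard_acts = torch.randn(num_tokens, hidden_dim)

# BatchedExperts: activations grouped per expert as
# (num_experts, max_tokens_per_expert, hidden_dim); rows beyond each expert's
# real token count are padding.
batched_acts = torch.zeros(num_local_experts, max_tokens_per_expert, hidden_dim)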

FusedMoEModularKernel

Bases: Module

This class combines a FusedMoEPrepareAndFinalize instance and a FusedMoEPermuteExpertsUnpermute instance to provide an interface that is compatible with the fused_experts function in fused_moe.py.

It takes care of managing any required scratch space.

Note: Instances of this class should only be used for a single model layer due to any layer specific state that may be used by the component objects.
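
A minimal composition sketch; MyPrepareAndFinalize and MyExperts are hypothetical placeholders for whichever concrete FusedMoEPrepareAndFinalize and FusedMoEPermuteExpertsUnpermute subclasses are actually in use, and the tensor arguments are assumed to already exist:

# Hypothetical concrete implementations (placeholders, not real vLLM classes).
prepare_finalize = MyPrepareAndFinalize()   # quantize + dispatch / combine
experts = MyExperts(quant_config)           # permute + expert gemms + unpermute

kernel = FusedMoEModularKernel(prepare_finalize, experts)

# The kernel is then invoked like fused_experts() in fused_moe.py:
out = kernel(
    hidden_states,              # (num_tokens, hidden_dim)
    w1, w2,                     # per-expert weight tensors
    topk_weights, topk_ids,     # router outputs
    activation="silu",
    global_num_experts=global_num_experts,
    expert_map=expert_map,
)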

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@final
class FusedMoEModularKernel(torch.nn.Module):
    """
    This class combines a FusedMoEPrepareAndFinalize instance and
    a FusedMoEPermuteExpertsUnpermute to provide an interface that
    is compatible with the `fused_experts` function in fused_moe.py.

    It takes care of managing any required scratch space.

    Note: Instances of this class should only be used for a single model
    layer due to any layer specific state that may be used by the component
    objects.
    """

    class SharedBuffers:

        def __init__(self) -> None:
            self.fused_out = SharedResizableBuffer()
            self.workspace13 = SharedResizableBuffer()
            self.workspace2 = SharedResizableBuffer()

    # Persistent buffers that are shared across `FusedMoEModularKernel`
    # instances (layers), to save memory and allocations.
    #
    # We have two sets of buffers to support dual batch overlap (DBO) where each
    # microbatch (ubatch) should use its own set of buffers to avoid
    # cross-ubatch contamination.
    # NOTE that memory is lazily allocated for these buffers, meaning that if
    # DBO isn't being used, the second SharedBuffers will be empty.
    shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]

    def __init__(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        fused_experts: FusedMoEPermuteExpertsUnpermute,
        shared_experts: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts
        self.shared_experts = shared_experts
        assert prepare_finalize.activation_format == \
            fused_experts.activation_formats[0], (
                f"{prepare_finalize.__class__.__name__}."
                f"{prepare_finalize.activation_format} == "
                f"{fused_experts.__class__.__name__}."
                f"{fused_experts.activation_formats[0]}")

    def _do_fused_experts(
        self,
        fused_out: Optional[torch.Tensor],
        a1: torch.Tensor,
        a1q: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        local_num_experts: int,
        expert_map: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        expert_tokens_meta: Optional[ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

        (workspace13_shape, workspace2_shape, fused_out_shape,
         workspace_dtype) = self.fused_experts.workspace_shapes(
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)

        # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
        ubatch_idx = dbo_current_ubatch_id()
        buffers = self.shared_buffers[ubatch_idx]

        # We can reuse the memory between cache1 and cache3 because by the
        # time we need cache3, we're done with cache1.
        workspace13 = buffers.workspace13.get(workspace13_shape,
                                              device=a1.device,
                                              dtype=workspace_dtype)
        workspace2 = buffers.workspace2.get(workspace2_shape,
                                            device=a1.device,
                                            dtype=workspace_dtype)

        assert fused_out is None or fused_out.shape == fused_out_shape, (
            f"fused_out {fused_out.shape} but expected {fused_out_shape}")
        if fused_out is None:
            # reuse workspace13 for the output
            fused_out = _resize_cache(workspace13, fused_out_shape)

        self.fused_experts.apply(
            fused_out,
            a1q,
            w1,
            w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=a1q_scale,
            a2_scale=a2_scale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

        return fused_out

    def _maybe_chunk_fused_experts(
        self,
        a1: torch.Tensor,
        a1q: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        local_num_experts: int,
        expert_map: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        expert_tokens_meta: Optional[ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

        CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
        num_chunks = cdiv(M, CHUNK_SIZE)

        # TODO(bnell): get rid of one level here, update slice functions
        # to nops on num_chunks==1

        if not self.fused_experts.supports_chunking() or num_chunks == 1:
            return self._do_fused_experts(
                fused_out=None,
                a1=a1,
                a1q=a1q,
                w1=w1,
                w2=w2,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                local_num_experts=local_num_experts,
                expert_map=expert_map,
                a1q_scale=a1q_scale,
                a2_scale=self.fused_experts.a2_scale,
                expert_tokens_meta=expert_tokens_meta,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Chunking required case
        assert num_chunks > 1

        # Construct the entire output that can then be processed in chunks.
        (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
            a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
            expert_tokens_meta)
        ubatch_idx = dbo_current_ubatch_id()
        buffers = self.shared_buffers[ubatch_idx]
        fused_out = buffers.fused_out.get(fused_out_shape,
                                          device=a1q.device,
                                          dtype=a1.dtype)

        def slice_input_tensors(
            chunk_idx: int
        ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                   Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
            s = chunk_idx * CHUNK_SIZE
            e = min(s + CHUNK_SIZE, M)
            return (
                a1q[s:e],
                _chunk_scales(a1q_scale, s, e),
                _chunk_scales(self.fused_experts.a2_scale, s, e),
                topk_ids[s:e],
                topk_weights[s:e],
            )

        def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
            assert fused_out.size(0) % M == 0, (
                f"fused_out shape {fused_out.shape} vs M {M}")
            factor = fused_out.size(0) // M
            out_chunk_size = CHUNK_SIZE * factor
            s = chunk_idx * out_chunk_size
            e = min(s + out_chunk_size, fused_out.size(0))
            return fused_out[s:e]

        def slice_expert_tokens_metadata(
                full_expert_tokens_meta: ExpertTokensMetadata,
                chunk_topk_ids: torch.Tensor, local_num_experts: int,
                expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
            # The existing expert_num_tokens is for the entire a1q
            # input. Chunking forces recomputation of the number
            # of tokens assigned to each expert.
            c_expert_num_tokens = count_expert_num_tokens(
                chunk_topk_ids, local_num_experts, expert_map)

            c_expert_num_tokens_cpu = None
            need_expert_num_tokens_cpu = (
                full_expert_tokens_meta.expert_num_tokens_cpu is not None)
            if need_expert_num_tokens_cpu:
                # This is blocking as some implementations need the count
                # on the CPU to determine appropriate input/out fused-moe
                # buffers
                c_expert_num_tokens_cpu = c_expert_num_tokens.to(
                    "cpu", non_blocking=False)

            return ExpertTokensMetadata(
                expert_num_tokens=c_expert_num_tokens,
                expert_num_tokens_cpu=c_expert_num_tokens_cpu)

        for chunk_idx in range(num_chunks):
            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
                slice_input_tensors(chunk_idx))

            c_expert_tokens_meta = None
            if expert_tokens_meta is not None:
                c_expert_tokens_meta = slice_expert_tokens_metadata(
                    expert_tokens_meta, c_topk_ids, local_num_experts,
                    expert_map)

            self._do_fused_experts(
                fused_out=slice_output_tensor(chunk_idx),
                a1=a1,
                a1q=c_a1q,
                w1=w1,
                w2=w2,
                topk_weights=c_topk_weights,
                topk_ids=c_topk_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                local_num_experts=local_num_experts,
                expert_map=expert_map,
                a1q_scale=c_a1q_scale,
                a2_scale=c_a2_scale,
                expert_tokens_meta=c_expert_tokens_meta,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        return fused_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        This function computes a Mixture of Experts (MoE) layer using two sets
        of weights, w1 and w2, and top-k gating mechanism.

        Parameters:
        - hidden_states: (torch.Tensor): The input tensor to the MoE layer.
        - w1 (torch.Tensor): The first set of expert weights.
        - w2 (torch.Tensor): The second set of expert weights.
        - topk_weights (torch.Tensor): The topk weights applied at the end of
          the layer.
        - topk_ids (torch.Tensor): A map of row to expert id.
        - inplace (bool): If True, perform the operation in-place.
          Defaults to False.
        - activation (str): The activation function to apply after the first
          MoE layer.
        - global_num_experts (int): The total number of experts in the global
          expert space.
        - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
          from the global expert space to the local expert space of the expert
          parallel shard.
        - apply_router_weight_on_input (bool): When true, the topk weights are
          applied directly on the inputs. This is only applicable when topk is
          1.

        Returns:
        - torch.Tensor: The output tensor after applying the MoE layer.
        """

        a1 = hidden_states
        if inplace and self.shared_experts is None:
            output = a1
        else:
            output = torch.zeros_like(a1)

        local_num_experts = w1.size(0)
        if global_num_experts == -1:
            global_num_experts = local_num_experts

        if not self.prepare_finalize.supports_async():
            # We shouldn't be running an a2a kernel that doesn't
            # support async prepare/finalize
            # TODO(lucas): enable in follow-up
            assert not dbo_enabled()

            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
             _expert_topk_weights) = self.prepare_finalize.prepare(
                 a1,
                 topk_weights,
                 topk_ids,
                 global_num_experts,
                 expert_map,
                 apply_router_weight_on_input,
                 self.fused_experts.quant_config,
             )
        else:
            # Overlap shared expert compute with all2all dispatch.
            dbo_maybe_run_recv_hook()
            prepare_ret = self.prepare_finalize.prepare_async(
                a1,
                topk_weights,
                topk_ids,
                global_num_experts,
                expert_map,
                apply_router_weight_on_input,
                self.fused_experts.quant_config,
            )

            # TODO(lucas): refactor this in the alternative schedules followup
            # currently unpack if we have hook + receiver pair or just
            # receiver (see finalize_async docstring)
            hook, receiver = prepare_ret \
                if isinstance(prepare_ret, tuple) else (None, prepare_ret)

            if hook is not None:
                if dbo_enabled():
                    # If DBO is being used, register the hook with the ubatch
                    # context and call it in dbo_maybe_run_recv_hook instead of
                    #  passing it to the receiver.
                    dbo_register_recv_hook(hook)
                    dbo_yield()
                else:
                    hook()

            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
             _expert_topk_weights) = receiver()

        # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
        topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
        topk_weights = (topk_weights if _expert_topk_weights is None else
                        _expert_topk_weights)

        fused_out = None

        if a1q.numel() == 0:
            # This happens when none of the tokens from the all2all reach this
            # EP rank. Also, note that this is only relevant for CUDAGraph
            # incompatible all2all kernels like the DeepEP high-throughput
            # kernels. CUDAGraph compatible all2all kernels like the pplx
            # kernels and the DeepEP low-latency kernels are always batched
            # and can never run into the tensor.numel() == 0 case.
            fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
        else:
            fused_out = self._maybe_chunk_fused_experts(
                a1=a1,
                a1q=a1q,
                w1=w1,
                w2=w2,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                local_num_experts=local_num_experts,
                expert_map=expert_map,
                a1q_scale=a1q_scale,
                expert_tokens_meta=expert_tokens_meta,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        shared_output: Optional[torch.Tensor] = None

        if not self.prepare_finalize.supports_async():
            assert not dbo_enabled()

            self.prepare_finalize.finalize(
                output,
                fused_out,
                topk_weights,
                topk_ids,
                apply_router_weight_on_input,
                self.fused_experts.finalize_weight_and_reduce_impl(),
            )
            if self.shared_experts is not None:
                shared_output = self.shared_experts(a1)
        else:
            finalize_ret = self.prepare_finalize.finalize_async(
                output,
                fused_out,
                topk_weights,
                topk_ids,
                apply_router_weight_on_input,
                self.fused_experts.finalize_weight_and_reduce_impl(),
            )

            if self.shared_experts is not None:
                shared_output = self.shared_experts(a1)

            # TODO(lucas): refactor this in the alternative schedules followup
            # currently unpack if we have hook + receiver pair or just
            # receiver (see finalize_async docstring)
            hook, receiver = finalize_ret \
                if isinstance(finalize_ret, tuple) else (None, finalize_ret)

            if hook is not None:
                if dbo_enabled():
                    # If DBO is being used, register the hook with the ubatch
                    # context and call it in dbo_maybe_run_recv_hook instead of
                    #  passing it to the receiver.
                    dbo_register_recv_hook(hook)
                    dbo_yield()
                else:
                    hook()

            receiver()

        if self.shared_experts is None:
            return output
        else:
            assert shared_output is not None
            return shared_output, output

fused_experts instance-attribute

fused_experts = fused_experts

prepare_finalize instance-attribute

prepare_finalize = prepare_finalize

shared_buffers class-attribute instance-attribute

shared_buffers: list[SharedBuffers] = [
    SharedBuffers(),
    SharedBuffers(),
]

shared_experts instance-attribute

shared_experts = shared_experts

SharedBuffers

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
class SharedBuffers:

    def __init__(self) -> None:
        self.fused_out = SharedResizableBuffer()
        self.workspace13 = SharedResizableBuffer()
        self.workspace2 = SharedResizableBuffer()

fused_out instance-attribute

fused_out = SharedResizableBuffer()

workspace13 instance-attribute

workspace13 = SharedResizableBuffer()

workspace2 instance-attribute

workspace2 = SharedResizableBuffer()

__init__

__init__() -> None
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def __init__(self) -> None:
    self.fused_out = SharedResizableBuffer()
    self.workspace13 = SharedResizableBuffer()
    self.workspace2 = SharedResizableBuffer()

__init__

__init__(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    fused_experts: FusedMoEPermuteExpertsUnpermute,
    shared_experts: Optional[Module] = None,
)
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def __init__(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    fused_experts: FusedMoEPermuteExpertsUnpermute,
    shared_experts: Optional[torch.nn.Module] = None,
):
    super().__init__()
    self.prepare_finalize = prepare_finalize
    self.fused_experts = fused_experts
    self.shared_experts = shared_experts
    assert prepare_finalize.activation_format == \
        fused_experts.activation_formats[0], (
            f"{prepare_finalize.__class__.__name__}."
            f"{prepare_finalize.activation_format} == "
            f"{fused_experts.__class__.__name__}."
            f"{fused_experts.activation_formats[0]}")

_do_fused_experts

_do_fused_experts(
    fused_out: Optional[Tensor],
    a1: Tensor,
    a1q: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    local_num_experts: int,
    expert_map: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def _do_fused_experts(
    self,
    fused_out: Optional[torch.Tensor],
    a1: torch.Tensor,
    a1q: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    local_num_experts: int,
    expert_map: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
) -> torch.Tensor:

    _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

    (workspace13_shape, workspace2_shape, fused_out_shape,
     workspace_dtype) = self.fused_experts.workspace_shapes(
         a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
         expert_tokens_meta)

    # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
    ubatch_idx = dbo_current_ubatch_id()
    buffers = self.shared_buffers[ubatch_idx]

    # We can reuse the memory between cache1 and cache3 because by the
    # time we need cache3, we're done with cache1.
    workspace13 = buffers.workspace13.get(workspace13_shape,
                                          device=a1.device,
                                          dtype=workspace_dtype)
    workspace2 = buffers.workspace2.get(workspace2_shape,
                                        device=a1.device,
                                        dtype=workspace_dtype)

    assert fused_out is None or fused_out.shape == fused_out_shape, (
        f"fused_out {fused_out.shape} but expected {fused_out_shape}")
    if fused_out is None:
        # reuse workspace13 for the output
        fused_out = _resize_cache(workspace13, fused_out_shape)

    self.fused_experts.apply(
        fused_out,
        a1q,
        w1,
        w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        a1q_scale=a1q_scale,
        a2_scale=a2_scale,
        workspace13=workspace13,
        workspace2=workspace2,
        expert_tokens_meta=expert_tokens_meta,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )

    return fused_out

_maybe_chunk_fused_experts

_maybe_chunk_fused_experts(
    a1: Tensor,
    a1q: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    local_num_experts: int,
    expert_map: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def _maybe_chunk_fused_experts(
    self,
    a1: torch.Tensor,
    a1q: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    local_num_experts: int,
    expert_map: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
) -> torch.Tensor:

    _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    num_chunks = cdiv(M, CHUNK_SIZE)

    # TODO(bnell): get rid of one level here, update slice functions
    # to nops on num_chunks==1

    if not self.fused_experts.supports_chunking() or num_chunks == 1:
        return self._do_fused_experts(
            fused_out=None,
            a1=a1,
            a1q=a1q,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            a1q_scale=a1q_scale,
            a2_scale=self.fused_experts.a2_scale,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    # Chunking required case
    assert num_chunks > 1

    # Construct the entire output that can then be processed in chunks.
    (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
        a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
        expert_tokens_meta)
    ubatch_idx = dbo_current_ubatch_id()
    buffers = self.shared_buffers[ubatch_idx]
    fused_out = buffers.fused_out.get(fused_out_shape,
                                      device=a1q.device,
                                      dtype=a1.dtype)

    def slice_input_tensors(
        chunk_idx: int
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
        s = chunk_idx * CHUNK_SIZE
        e = min(s + CHUNK_SIZE, M)
        return (
            a1q[s:e],
            _chunk_scales(a1q_scale, s, e),
            _chunk_scales(self.fused_experts.a2_scale, s, e),
            topk_ids[s:e],
            topk_weights[s:e],
        )

    def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
        assert fused_out.size(0) % M == 0, (
            f"fused_out shape {fused_out.shape} vs M {M}")
        factor = fused_out.size(0) // M
        out_chunk_size = CHUNK_SIZE * factor
        s = chunk_idx * out_chunk_size
        e = min(s + out_chunk_size, fused_out.size(0))
        return fused_out[s:e]

    def slice_expert_tokens_metadata(
            full_expert_tokens_meta: ExpertTokensMetadata,
            chunk_topk_ids: torch.Tensor, local_num_experts: int,
            expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
        # The existing expert_num_tokens is for the entire a1q
        # input. Chunking forces recomputation of the number
        # of tokens assigned to each expert.
        c_expert_num_tokens = count_expert_num_tokens(
            chunk_topk_ids, local_num_experts, expert_map)

        c_expert_num_tokens_cpu = None
        need_expert_num_tokens_cpu = (
            full_expert_tokens_meta.expert_num_tokens_cpu is not None)
        if need_expert_num_tokens_cpu:
            # This is blocking as some implementations need the count
            # on the CPU to determine appropriate input/out fused-moe
            # buffers
            c_expert_num_tokens_cpu = c_expert_num_tokens.to(
                "cpu", non_blocking=False)

        return ExpertTokensMetadata(
            expert_num_tokens=c_expert_num_tokens,
            expert_num_tokens_cpu=c_expert_num_tokens_cpu)

    for chunk_idx in range(num_chunks):
        c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
            slice_input_tensors(chunk_idx))

        c_expert_tokens_meta = None
        if expert_tokens_meta is not None:
            c_expert_tokens_meta = slice_expert_tokens_metadata(
                expert_tokens_meta, c_topk_ids, local_num_experts,
                expert_map)

        self._do_fused_experts(
            fused_out=slice_output_tensor(chunk_idx),
            a1=a1,
            a1q=c_a1q,
            w1=w1,
            w2=w2,
            topk_weights=c_topk_weights,
            topk_ids=c_topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            a1q_scale=c_a1q_scale,
            a2_scale=c_a2_scale,
            expert_tokens_meta=c_expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    return fused_out
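
The chunk bookkeeping above can be illustrated in isolation; the numbers are arbitrary and cdiv is inlined:

# Plain-Python illustration of how _maybe_chunk_fused_experts splits M tokens
# into CHUNK_SIZE-sized slices (mirrors cdiv() and slice_input_tensors()).
M = 10_000            # number of dispatched tokens
CHUNK_SIZE = 4_096    # e.g. VLLM_FUSED_MOE_CHUNK_SIZE

num_chunks = (M + CHUNK_SIZE - 1) // CHUNK_SIZE   # cdiv(M, CHUNK_SIZE) == 3

bounds = []
for chunk_idx in range(num_chunks):
    s = chunk_idx * CHUNK_SIZE
    e = min(s + CHUNK_SIZE, M)
    bounds.append((s, e))

assert bounds == [(0, 4096), (4096, 8192), (8192, 10000)]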

forward

forward(
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> Union[Tensor, tuple[Tensor, Tensor]]

This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and a top-k gating mechanism.

Parameters:

- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
- activation (str): The activation function to apply after the first MoE layer.
- global_num_experts (int): The total number of experts in the global expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1.

Returns:

- torch.Tensor: The output tensor after applying the MoE layer.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def forward(
    self,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets
    of weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states: (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - topk_weights (torch.Tensor): The topk weights applied at the end of
      the layer.
    - topk_ids (torch.Tensor): A map of row to expert id.
    - inplace (bool): If True, perform the operation in-place.
      Defaults to False.
    - activation (str): The activation function to apply after the first
      MoE layer.
    - global_num_experts (int): The total number of experts in the global
      expert space.
    - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
      from the global expert space to the local expert space of the expert
      parallel shard.
    - apply_router_weight_on_input (bool): When true, the topk weights are
      applied directly on the inputs. This is only applicable when topk is
      1.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """

    a1 = hidden_states
    if inplace and self.shared_experts is None:
        output = a1
    else:
        output = torch.zeros_like(a1)

    local_num_experts = w1.size(0)
    if global_num_experts == -1:
        global_num_experts = local_num_experts

    if not self.prepare_finalize.supports_async():
        # We shouldn't be running an a2a kernel that doesn't
        # support async prepare/finalize
        # TODO(lucas): enable in follow-up
        assert not dbo_enabled()

        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
         _expert_topk_weights) = self.prepare_finalize.prepare(
             a1,
             topk_weights,
             topk_ids,
             global_num_experts,
             expert_map,
             apply_router_weight_on_input,
             self.fused_experts.quant_config,
         )
    else:
        # Overlap shared expert compute with all2all dispatch.
        dbo_maybe_run_recv_hook()
        prepare_ret = self.prepare_finalize.prepare_async(
            a1,
            topk_weights,
            topk_ids,
            global_num_experts,
            expert_map,
            apply_router_weight_on_input,
            self.fused_experts.quant_config,
        )

        # TODO(lucas): refactor this in the alternative schedules followup
        # currently unpack if we have hook + receiver pair or just
        # receiver (see finalize_async docstring)
        hook, receiver = prepare_ret \
            if isinstance(prepare_ret, tuple) else (None, prepare_ret)

        if hook is not None:
            if dbo_enabled():
                # If DBO is being used, register the hook with the ubatch
                # context and call it in dbo_maybe_run_recv_hook instead of
                #  passing it to the receiver.
                dbo_register_recv_hook(hook)
                dbo_yield()
            else:
                hook()

        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
         _expert_topk_weights) = receiver()

    # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
    topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
    topk_weights = (topk_weights if _expert_topk_weights is None else
                    _expert_topk_weights)

    fused_out = None

    if a1q.numel() == 0:
        # This happens when none of the tokens from the all2all reach this
        # EP rank. Also, note that this is only relevant for CUDAGraph
        # incompatible all2all kernels like the DeepEP high-throughput
        # kernels. CUDAGraph compatible all2all kernels like the pplx
        # kernels and the DeepEP low-latency kernels are always batched
        # and can never run into the tensor.numel() == 0 case.
        fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
    else:
        fused_out = self._maybe_chunk_fused_experts(
            a1=a1,
            a1q=a1q,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            a1q_scale=a1q_scale,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    shared_output: Optional[torch.Tensor] = None

    if not self.prepare_finalize.supports_async():
        assert not dbo_enabled()

        self.prepare_finalize.finalize(
            output,
            fused_out,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            self.fused_experts.finalize_weight_and_reduce_impl(),
        )
        if self.shared_experts is not None:
            shared_output = self.shared_experts(a1)
    else:
        finalize_ret = self.prepare_finalize.finalize_async(
            output,
            fused_out,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            self.fused_experts.finalize_weight_and_reduce_impl(),
        )

        if self.shared_experts is not None:
            shared_output = self.shared_experts(a1)

        # TODO(lucas): refactor this in the alternative schedules followup
        # currently unpack if we have hook + receiver pair or just
        # receiver (see finalize_async docstring)
        hook, receiver = finalize_ret \
            if isinstance(finalize_ret, tuple) else (None, finalize_ret)

        if hook is not None:
            if dbo_enabled():
                # If DBO is being used, register the hook with the ubatch
                # context and call it in dbo_maybe_run_recv_hook instead of
                #  passing it to the receiver.
                dbo_register_recv_hook(hook)
                dbo_yield()
            else:
                hook()

        receiver()

    if self.shared_experts is None:
        return output
    else:
        assert shared_output is not None
        return shared_output, output

FusedMoEPermuteExpertsUnpermute

Bases: ABC

An abstract base class for the [Permute-Experts-Unpermute] step described above.
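
A skeletal subclass showing the members a concrete implementation must provide; the shapes and bodies below are illustrative placeholders, not a real kernel:

from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute)


class MyExperts(FusedMoEPermuteExpertsUnpermute):
    """Hypothetical skeleton of the abstract surface."""

    @property
    def activation_formats(
            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
        # Consumes and produces standard (num_tokens, hidden_dim) activations.
        return (FusedMoEActivationFormat.Standard,
                FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def workspace_shapes(self, a, aq, M, N, K, topk, global_num_experts,
                         local_num_experts, expert_tokens_meta):
        # Illustrative shapes only; the first dimension is the token count so
        # that activation chunking can slice along it.
        workspace13 = (M, topk, max(N, K))
        workspace2 = (M, topk, N // 2)
        output = (M, topk, K)
        return workspace13, workspace2, output, a.dtype

    def apply(self, output, hidden_states, w1, w2, topk_weights, topk_ids,
              activation, global_num_experts, expert_map, a1q_scale, a2_scale,
              workspace13, workspace2, expert_tokens_meta,
              apply_router_weight_on_input):
        # A real implementation writes the unweighted, unreduced expert
        # outputs into `output`, using workspace13/workspace2 as scratch.
        raise NotImplementedError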

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
class FusedMoEPermuteExpertsUnpermute(ABC):
    """
    An abstract base class for the [Permute-Experts-Unpermute] step described
    above.
    """

    def __init__(
        self,
        quant_config: FusedMoEQuantConfig,
    ):
        """
        quant_config: Quantization parameters for this experts instance.
        """
        self.quant_config = quant_config

    @property
    @abstractmethod
    def activation_formats(
            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
        """
        A property which is a tuple of the input and output activation formats
        for the 'apply' method.
        """
        raise NotImplementedError

    #
    # Various helpers for accessing quantization parameters from the
    # quant_config.
    #

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        return self.quant_config.quant_dtype

    @property
    def block_shape(self) -> Optional[list[int]]:
        return self.quant_config.block_shape

    @property
    def per_act_token_quant(self) -> bool:
        return self.quant_config.per_act_token_quant

    @property
    def per_out_ch_quant(self) -> bool:
        return self.quant_config.per_out_ch_quant

    @property
    def a1_scale(self) -> Optional[torch.Tensor]:
        return self.quant_config.a1_scale

    @property
    def a2_scale(self) -> Optional[torch.Tensor]:
        return self.quant_config.a2_scale

    @property
    def a1_gscale(self) -> Optional[torch.Tensor]:
        return self.quant_config.a1_gscale

    @property
    def a2_gscale(self) -> Optional[torch.Tensor]:
        return self.quant_config.a2_gscale

    @property
    def w1_scale(self) -> Optional[torch.Tensor]:
        return self.quant_config.w1_scale

    @property
    def w2_scale(self) -> Optional[torch.Tensor]:
        return self.quant_config.w2_scale

    @property
    def w1_zp(self) -> Optional[torch.Tensor]:
        return self.quant_config.w1_zp

    @property
    def w2_zp(self) -> Optional[torch.Tensor]:
        return self.quant_config.w2_zp

    @property
    def w1_bias(self) -> Optional[torch.Tensor]:
        return self.quant_config.w1_bias

    @property
    def w2_bias(self) -> Optional[torch.Tensor]:
        return self.quant_config.w2_bias

    @property
    def g1_alphas(self) -> Optional[torch.Tensor]:
        return self.quant_config.g1_alphas

    @property
    def g2_alphas(self) -> Optional[torch.Tensor]:
        return self.quant_config.g2_alphas

    # TODO (bnell): make this return a CHUNK_SIZE or None instead?
    @abstractmethod
    def supports_chunking(self) -> bool:
        """
        A flag indicating whether or not this class supports activation
        chunking.
        """
        raise NotImplementedError

    @abstractmethod
    def supports_expert_map(self) -> bool:
        """
        A flag indicating whether or not this class supports expert maps
        """
        raise NotImplementedError

    @abstractmethod
    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function.  Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        raise NotImplementedError

    def activation(self, activation: str, output: torch.Tensor,
                   input: torch.Tensor) -> None:
        assert output.size(-1) * 2 == input.size(-1)
        if activation == "silu":
            torch.ops._C.silu_and_mul(output, input)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(output, input)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    def enable_chunking(self):
        return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
          self.supports_chunking()

    def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ):
        """
        This function computes the intermediate result of a Mixture of Experts
        (MoE) layer using two sets of weights, w1 and w2.

        Parameters:
        - output: (torch.Tensor): The unweighted, unreduced output tensor.
        - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
          layer.
        - w1 (torch.Tensor): The first set of expert weights.
        - w2 (torch.Tensor): The second set of expert weights.
        - topk_weights: A map of row to expert weights. Some implementations
          choose to do weight application.
        - topk_ids (torch.Tensor): A map of row to expert id.
        - activation (str): The activation function to apply after the first
          MoE layer.
        - global_num_experts (int): The total number of experts in the global
          expert space.
        - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
          from the global expert space to the local expert space of the expert
          parallel shard.
        - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
          used for a1.  Result of quantization from prepare/finalize and not
          from the FusedMoEQuantConfig.
        - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
          must be large enough to hold output of either MoE gemm.
        - workspace2 (torch.Tensor): A scratch tensor used for the activation
          function.
        - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
          ExpertTokensMetadata object containing gpu/cpu tensors
          as big as the number of local experts with the information about the
          number of tokens assigned to each local expert.
        - apply_router_weight_on_input: True if router weights are already
          applied on the input. This is relevant if the implementation
          chooses to do weight application.
        """
        raise NotImplementedError

a1_gscale property

a1_gscale: Optional[Tensor]

a1_scale property

a1_scale: Optional[Tensor]

a2_gscale property

a2_gscale: Optional[Tensor]

a2_scale property

a2_scale: Optional[Tensor]

activation_formats abstractmethod property

A property which is a tuple of the input and output activation formats for the 'apply' method.

block_shape property

block_shape: Optional[list[int]]

g1_alphas property

g1_alphas: Optional[Tensor]

g2_alphas property

g2_alphas: Optional[Tensor]

per_act_token_quant property

per_act_token_quant: bool

per_out_ch_quant property

per_out_ch_quant: bool

quant_config instance-attribute

quant_config = quant_config

quant_dtype property

quant_dtype: Optional[dtype]

w1_bias property

w1_bias: Optional[Tensor]

w1_scale property

w1_scale: Optional[Tensor]

w1_zp property

w1_zp: Optional[Tensor]

w2_bias property

w2_bias: Optional[Tensor]

w2_scale property

w2_scale: Optional[Tensor]

w2_zp property

w2_zp: Optional[Tensor]

__init__

__init__(quant_config: FusedMoEQuantConfig)

quant_config: Quantization parameters for this experts instance.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def __init__(
    self,
    quant_config: FusedMoEQuantConfig,
):
    """
    quant_config: Quantization parameters for this experts instance.
    """
    self.quant_config = quant_config

activation

activation(
    activation: str, output: Tensor, input: Tensor
) -> None
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def activation(self, activation: str, output: torch.Tensor,
               input: torch.Tensor) -> None:
    assert output.size(-1) * 2 == input.size(-1)
    if activation == "silu":
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "gelu":
        torch.ops._C.gelu_and_mul(output, input)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")
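
For reference, a plain-PyTorch sketch of the gated activation that silu_and_mul is assumed to compute, with the first half of the last dimension acting as the gate (this matches the size assertion above but is not the vLLM kernel itself):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * d)  ->  (..., d), so output.size(-1) * 2 == input.size(-1).
    d = x.size(-1) // 2
    return F.silu(x[..., :d]) * x[..., d:]


out = silu_and_mul_reference(torch.randn(4, 8))
assert out.shape == (4, 4)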

apply abstractmethod

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    workspace13: Tensor,
    workspace2: Tensor,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
)

This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2.

Parameters:

- output (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states (torch.Tensor): The (quantized) input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first MoE layer.
- global_num_experts (int): The total number of experts in the global expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be used for a1. Result of quantization from prepare/finalize and not from the FusedMoEQuantConfig.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs; must be large enough to hold the output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation function.
- expert_tokens_meta (Optional[ExpertTokensMetadata]): An optional ExpertTokensMetadata object containing gpu/cpu tensors as big as the number of local experts with the information about the number of tokens assigned to each local expert.
- apply_router_weight_on_input: True if router weights are already applied on the input. This is relevant if the implementation chooses to do weight application.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
):
    """
    This function computes the intermediate result of a Mixture of Experts
    (MoE) layer using two sets of weights, w1 and w2.

    Parameters:
    - output: (torch.Tensor): The unweighted, unreduced output tensor.
    - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
      layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - topk_weights: A map of row to expert weights. Some implementations
      choose to do weight application.
    - topk_ids (torch.Tensor): A map of row to expert id.
    - activation (str): The activation function to apply after the first
      MoE layer.
    - global_num_experts (int): The total number of experts in the global
      expert space.
    - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
      from the global expert space to the local expert space of the expert
      parallel shard.
    - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
      used for a1.  Result of quantization from prepare/finalize and not
      from the FusedMoEQuantConfig.
    - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
      must be large enough to hold output of either MoE gemm.
    - workspace2 (torch.Tensor): A scratch tensor used for the activation
      function.
    - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
      ExpertTokensMetadata object containing gpu/cpu tensors
      as big as the number of local experts with the information about the
      number of tokens assigned to each local expert.
    - apply_router_weight_on_input: True if router weights are already
      applied on the input. This is relevant if the implementation
      chooses to do weight application.
    """
    raise NotImplementedError

enable_chunking

enable_chunking()
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def enable_chunking(self):
    return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
      self.supports_chunking()

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
    raise NotImplementedError

supports_chunking abstractmethod

supports_chunking() -> bool

A flag indicating whether or not this class supports activation chunking.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def supports_chunking(self) -> bool:
    """
    A flag indicating whether or not this class supports activation
    chunking.
    """
    raise NotImplementedError

supports_expert_map abstractmethod

supports_expert_map() -> bool

A flag indicating whether or not this class supports expert maps

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def supports_expert_map(self) -> bool:
    """
    A flag indicating whether or not this class supports expert maps
    """
    raise NotImplementedError

workspace_shapes abstractmethod

workspace_shapes(
    a: Tensor,
    aq: Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...], dtype
]

Compute the shapes for the temporary and final outputs of the two gemms and activation in the fused expert function. Since the gemms are independent, the workspace for the first gemm can be shared with the workspace for the last gemm.

Returns a tuple of:

- workspace13 shape tuple: must be large enough to hold the result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the result of the activation function.
- output shape tuple: must be the exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens.
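
A sketch of how a caller consumes the returned tuple (as _do_fused_experts does above); the concrete shape values are invented for the example:

import torch

# Hypothetical values, as if returned by some workspace_shapes() call.
workspace13_shape = (1024, 8, 4096)   # holds either expert gemm output
workspace2_shape = (1024, 8, 2048)    # holds the activation output
fused_out_shape = (1024, 8, 4096)     # exact final gemm output shape
workspace_dtype = torch.bfloat16

workspace13 = torch.empty(workspace13_shape, dtype=workspace_dtype)
workspace2 = torch.empty(workspace2_shape, dtype=workspace_dtype)
fused_out = torch.empty(fused_out_shape, dtype=workspace_dtype)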

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def workspace_shapes(
    self,
    a: torch.Tensor,
    aq: torch.Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
    """
    Compute the shapes for the temporary and final outputs of the two gemms
    and activation in the fused expert function.  Since the gemms are
    independent, the workspace for the first gemm can be shared with the
    workspace for the last gemm.

    Returns a tuple of:
    - workspace13 shape tuple: must be large enough to hold the
      result of either expert gemm.
    - workspace2 shape tuple: must be large enough to hold the
      result of the activation function.
    - output shape tuple: must be exact size of the final gemm output.
    - Workspace type: The dtype to use for the workspace tensors.
    - Note: in order for activation chunking to work, the first dimension
      of each tuple must be the number of tokens.
    """
    raise NotImplementedError
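
For intuition, here is a minimal sketch of what a concrete implementation might return, assuming the first gemm produces (M, topk, 2 * N), the activation halves that to (M, topk, N), and the final output is (M, K); the shapes and method body are illustrative, not taken from any subclass in the module:

# Illustrative sketch only: shapes follow the assumptions stated above.
def workspace_shapes(self, a, aq, M, N, K, topk, global_num_experts,
                     local_num_experts, expert_tokens_meta):
    # workspace13 is shared by both gemms, so it must fit the larger of the
    # two gemm outputs; keeping the token count as the first dimension of
    # every shape is what allows activation chunking.
    workspace13 = (M, topk, max(2 * N, K))
    workspace2 = (M, topk, N)
    output = (M, K)
    return (workspace13, workspace2, output, a.dtype)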

FusedMoEPrepareAndFinalize

Bases: ABC

An abstract base class for the [Quantize-Prepare] and [Finalize] steps described above.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
class FusedMoEPrepareAndFinalize(ABC):
    """
    An abstract base class for the [Quantize-Prepare] and [Finalize] steps
    described above.
    """

    @abstractmethod
    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> PrepareResultType:
        """
        Perform any quantization (and/or) dispatching needed for this kernel.
        - a1: The (unquantized) input to the MoE layer.
        - topk_ids: The topk ids.
        - topk_weights: The topk weights.
        - num_experts: The total number of experts in the global expert space.
        - expert_map: A tensor mapping expert indices from the global expert
          space to the local expert space of the expert parallel shard.
        - apply_router_weight_on_input: When True, apply the weights to the
          activations, before quantization + dispatching.
        - quant_config: Quantization info provided by the fused experts.

        Returns a tuple of:
        - quantized + dispatched a.
        - Optional quantized + dispatched a1_scales.
        - Optional ExpertTokensMetadata containing gpu/cpu tensors
          as big as the number of local experts with the information about the
          number of tokens assigned to each local expert.
        - Optional dispatched expert topk IDs
        - Optional dispatched expert topk weight
        """
        raise NotImplementedError

    def supports_async(self) -> bool:
        """
        Indicates whether or not this class implements prepare_async and
        finalize_async.
        """
        return False

    def prepare_async(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
        """
        Perform any quantization (and/or) dispatching needed for this kernel
        but do not wait for results from other workers.
        - a1: The (unquantized) input to the MoE layer.
        - topk_ids: The topk ids.
        - topk_weights: The topk weights.
        - num_experts: The total number of experts in the global expert space.
        - expert_map: A tensor mapping expert indices from the global expert
          space to the local expert space of the expert parallel shard.
        - apply_router_weight_on_input: When True, apply the weights to the
          activations, before quantization + dispatching.
        - quant_config: Quantization info provided by the fused experts.

        Returns a receiver callback, or a (hook, receiver) pair, that when
        invoked waits for results from other workers and has the same return
        signature as `prepare`.  When a hook is returned, it is a
        lighter-weight check that the recv is complete, without doing extra
        work (used by DBO, will be refactored in the near future).

        e.g.

        ret = obj.prepare_async(...)

        if isinstance(ret, tuple):
            hook, receiver = ret
            hook()
        else:
            receiver = ret

        a, a_scales, expert_meta, topk_ids, topk_weights = receiver()

        is equivalent to:

        a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
        """
        raise NotImplementedError

    @abstractmethod
    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: TopKWeightAndReduce,
    ) -> None:
        """
        Perform any combine plus apply weights and perform a reduction on the
        fused experts output.
        - output: The output tensor, written in place.  Must be (M, K) shape.
        - fused_expert_output: The unweighted, unreduced output of the fused
          experts, it will have (M, topk, K) shape.
        - topk_weights: The weights to be applied to the fused_experts_output.
        - topk_ids: The topk_ids.
        - apply_router_weight_on_input: When False, apply the weights to
          fused_expert_output.
        - weight_and_reduce_impl: An optional TopKWeightAndReduce
          implementation.
        """
        raise NotImplementedError

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: TopKWeightAndReduce,
    ) -> Union[tuple[Callable, Callable], Callable]:
        """
        Perform any combine plus apply weights and perform a reduction on the
        fused experts output but do not wait for results from other workers.
        - output: The output tensor, written in place.  Must be (M, K) shape.
        - fused_expert_output: The unweighted, unreduced output of the fused
          experts, it will have (M, topk, K) shape.
        - topk_weights: The weights to be applied to the fused_experts_output.
        - topk_ids: The topk_ids.
        - apply_router_weight_on_input: When False, apply the weights to
          fused_expert_output.
        - weight_and_reduce_impl: An optional TopKWeightAndReduce
          implementation.

        Returns a receiver callback, or a (hook, receiver) pair, that when
        invoked waits for results from other workers and has the same return
        signature as `finalize`.  When a hook is returned, it is a
        lighter-weight check that the recv is complete, without doing extra
        work (used by DBO, will be refactored in the near future).

        ret = obj.finalize_async(output, ...)
        ... output not valid yet ...
        if isinstance(ret, tuple):
            hook, receiver = ret
            hook()
        else:
            receiver = ret
        receiver()
        ... output valid here ...

        is equivalent to:

        obj.finalize(output, ...)
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def activation_format(self) -> FusedMoEActivationFormat:
        """
        A property indicating the output format of the activations for the
        'prepare' method.
        """
        raise NotImplementedError

    @abstractmethod
    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """
        The PrepareFinalize All2All implementations generally constrain the
        dtype of the topk_ids they support. This function returns the
        required topk indices dtype so it can be respected.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError

    @abstractmethod
    def max_num_tokens_per_rank(self) -> Optional[int]:
        """
        Some PrepareFinalize All2All implementations are batched, meaning
        they can process only a fixed number of tokens at a time. This
        function returns that batch size, i.e. the maximum number of tokens
        the implementation can process at a time.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError

    @abstractmethod
    def num_dispatchers(self) -> int:
        raise NotImplementedError

activation_format abstractmethod property

activation_format: FusedMoEActivationFormat

A property indicating the output format of the activations for the 'prepare' method.

finalize abstractmethod

finalize(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
) -> None

Perform any combine plus apply weights and perform a reduction on the fused experts output.

- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused experts; it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce implementation.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
) -> None:
    """
    Perform any combine plus apply weights and perform a reduction on the
    fused experts output.
    - output: The output tensor, written in place.  Must be (M, K) shape.
    - fused_expert_output: The unweighted, unreduced output of the fused
      experts, it will have (M, topk, K) shape.
    - topk_weights: The weights to be applied to the fused_experts_output.
    - topk_ids: The topk_ids.
    - apply_router_weight_on_input: When False, apply the weights to
      fused_expert_output.
    - weight_and_reduce_impl: An optional TopKWeightAndReduce
      implementation.
    """
    raise NotImplementedError
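
For reference, a single-worker sketch (one with no combine step) would simply hand the weight application and reduction off to the supplied TopKWeightAndReduce object; this is illustrative only, not the behavior of any particular subclass:

# Hypothetical single-worker finalize: no combine, just weight + reduce
# into the caller-provided output tensor.
def finalize(self, output, fused_expert_output, topk_weights, topk_ids,
             apply_router_weight_on_input, weight_and_reduce_impl):
    weight_and_reduce_impl.apply(
        output=output,
        fused_expert_output=fused_expert_output,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        apply_router_weight_on_input=apply_router_weight_on_input)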

finalize_async

finalize_async(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
) -> Union[tuple[Callable, Callable], Callable]

Perform any combine plus apply weights and perform a reduction on the fused experts output, but do not wait for results from other workers.

- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused experts; it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_expert_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce implementation.

Returns a receiver callback, or a (hook, receiver) pair, that when invoked waits for results from other workers and has the same return signature as finalize. When a hook is returned, it is a lighter-weight check that the recv is complete, without doing extra work (used by DBO, will be refactored in the near future).

ret = obj.finalize_async(output, ...)
... output not valid yet ...
if isinstance(ret, tuple):
    hook, receiver = ret
    hook()
else:
    receiver = ret
receiver()
... output valid here ...

is equivalent to:

obj.finalize(output, ...)

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def finalize_async(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
) -> Union[tuple[Callable, Callable], Callable]:
    """
    Perform any combine plus apply weights and perform a reduction on the
    fused experts output but do not wait for results from other workers.
    - output: The output tensor, written in place.  Must be (M, K) shape.
    - fused_expert_output: The unweighted, unreduced output of the fused
      experts, it will have (M, topk, K) shape.
    - topk_weights: The weights to be applied to the fused_experts_output.
    - topk_ids: The topk_ids.
    - apply_router_weight_on_input: When False, apply the weights to
      fused_expert_output.
    - weight_and_reduce_impl: An optional TopKWeightAndReduce
      implementation.

    Returns a receiver callback, or a (hook, receiver) pair, that when
    invoked waits for results from other workers and has the same return
    signature as `finalize`.  When a hook is returned, it is a
    lighter-weight check that the recv is complete, without doing extra
    work (used by DBO, will be refactored in the near future).

    ret = obj.finalize_async(output, ...)
    ... output not valid yet ...
    if isinstance(ret, tuple):
        hook, receiver = ret
        hook()
    else:
        receiver = ret
    receiver()
    ... output valid here ...

    is equivalent to:

    obj.finalize(output, ...)
    """
    raise NotImplementedError

max_num_tokens_per_rank abstractmethod

max_num_tokens_per_rank() -> Optional[int]

Some PrepareFinalize All2All implementations are batched, meaning they can process only a fixed number of tokens at a time. This function returns that batch size, i.e. the maximum number of tokens the implementation can process at a time. Return None if there are no such restrictions.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def max_num_tokens_per_rank(self) -> Optional[int]:
    """
    Some PrepareFinalize All2All implementations are batched, meaning
    they can process only a fixed number of tokens at a time. This
    function returns that batch size, i.e. the maximum number of tokens
    the implementation can process at a time.
    Return None if there are no such restrictions.
    """
    raise NotImplementedError

num_dispatchers abstractmethod

num_dispatchers() -> int
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def num_dispatchers(self) -> int:
    raise NotImplementedError

prepare abstractmethod

prepare(
    a1: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> PrepareResultType

Perform any quantization and/or dispatching needed for this kernel.

- a1: The (unquantized) input to the MoE layer.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.

Returns a tuple of:

- quantized + dispatched a.
- Optional quantized + dispatched a1_scales.
- Optional ExpertTokensMetadata containing gpu/cpu tensors as big as the number of local experts, with the number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs.
- Optional dispatched expert topk weights.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def prepare(
    self,
    a1: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> PrepareResultType:
    """
    Perform any quantization (and/or) dispatching needed for this kernel.
    - a1: The (unquantized) input to the MoE layer.
    - topk_ids: The topk ids.
    - topk_weights: The topk weights.
    - num_experts: The total number of experts in the global expert space.
    - expert_map: A tensor mapping expert indices from the global expert
      space to the local expert space of the expert parallel shard.
    - apply_router_weight_on_input: When True, apply the weights to the
      activations, before quantization + dispatching.
    - quant_config: Quantization info provided by the fused experts.

    Returns a tuple of:
    - quantized + dispatched a.
    - Optional quantized + dispatched a1_scales.
    - Optional ExpertTokensMetadata containing gpu/cpu tensors
      as big as the number of local experts with the information about the
      number of tokens assigned to each local expert.
    - Optional dispatched expert topk IDs
    - Optional dispatched expert topk weight
    """
    raise NotImplementedError
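
For intuition, a single-worker sketch that performs no quantization and no dispatch might look like the following; it only fills in the per-expert token counts and passes everything else through (illustrative, not an actual subclass in the module):

import torch

# Hypothetical pass-through prepare: no quantization, no all-to-all.
def prepare(self, a1, topk_weights, topk_ids, num_experts, expert_map,
            apply_router_weight_on_input, quant_config):
    # Count how many tokens were routed to each (global) expert.
    expert_num_tokens = torch.bincount(topk_ids.flatten(),
                                       minlength=num_experts)
    expert_tokens_meta = ExpertTokensMetadata(
        expert_num_tokens=expert_num_tokens.to(torch.int32),
        expert_num_tokens_cpu=None)
    # Return the activations unchanged; the expert kernel can quantize
    # using quant_config if needed.
    return a1, None, expert_tokens_meta, None, None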

prepare_async

prepare_async(
    a1: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> Union[tuple[Callable, ReceiverType], ReceiverType]

Perform any quantization and/or dispatching needed for this kernel, but do not wait for results from other workers.

- a1: The (unquantized) input to the MoE layer.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.

Returns a receiver callback, or a (hook, receiver) pair, that when invoked waits for results from other workers and has the same return signature as prepare. When a hook is returned, it is a lighter-weight check that the recv is complete, without doing extra work (used by DBO, will be refactored in the near future).

e.g.

ret = obj.prepare_async(...)

if isinstance(ret, tuple):
    hook, receiver = ret
    hook()
else:
    receiver = ret

a, a_scales, expert_meta, topk_ids, topk_weights = receiver()

is equivalent to:

a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def prepare_async(
    self,
    a1: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
    """
    Perform any quantization (and/or) dispatching needed for this kernel
    but do not wait for results from other workers.
    - a1: The (unquantized) input to the MoE layer.
    - topk_ids: The topk ids.
    - topk_weights: The topk weights.
    - num_experts: The total number of experts in the global expert space.
    - expert_map: A tensor mapping expert indices from the global expert
      space to the local expert space of the expert parallel shard.
    - apply_router_weight_on_input: When True, apply the weights to the
      activations, before quantization + dispatching.
    - quant_config: Quantization info provided by the fused experts.

    Returns a receiver callback, or a (hook, receiver) pair, that when
    invoked waits for results from other workers and has the same return
    signature as `prepare`.  When a hook is returned, it is a
    lighter-weight check that the recv is complete, without doing extra
    work (used by DBO, will be refactored in the near future).

    e.g.

    ret = obj.prepare_async(...)

    if isinstance(ret, tuple):
        hook, receiver = ret
        hook()
    else:
        receiver = ret

    a, a_scales, expert_meta, topk_ids, topk_weights = receiver()

    is equivalent to:

    a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
    """
    raise NotImplementedError
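
Because prepare_async (and finalize_async) may return either a bare receiver or a (hook, receiver) pair, callers typically normalize the result before waiting. A small caller-side sketch (the helper name is hypothetical):

# Hypothetical helper: run the lightweight hook, if any, then wait on the
# receiver and return its result.
def wait_for_async(ret):
    if isinstance(ret, tuple):
        hook, receiver = ret
        hook()
    else:
        receiver = ret
    return receiver()

# Usage: equivalent to calling obj.prepare(...) directly.
# a, a_scales, expert_meta, topk_ids, topk_weights = \
#     wait_for_async(obj.prepare_async(...))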

supports_async

supports_async() -> bool

Indicates whether or not this class implements prepare_async and finalize_async.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def supports_async(self) -> bool:
    """
    Indicates whether or not this class implements prepare_async and
    finalize_async.
    """
    return False

topk_indices_dtype abstractmethod

topk_indices_dtype() -> Optional[dtype]

The PrepareFinalize All2All implementations generally constrain the dtype of the topk_ids they support. This function returns the required topk indices dtype so it can be respected. Return None if there are no such restrictions.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    """
    The PrepareFinalize All2All implementations generally constrain the
    dtype of the topk_ids they support. This function returns the
    required topk indices dtype so it can be respected.
    Return None if there are no such restrictions.
    """
    raise NotImplementedError

SharedResizableBuffer

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
class SharedResizableBuffer:

    def __init__(self):
        self.buffer = None

    def get(self, shape: tuple[int, ...], device: torch.device,
            dtype: torch.dtype):
        if shape == () or shape is None:
            return None
        shape_numel = prod(shape)
        if (self.buffer is None or self.buffer.numel() < shape_numel
                or self.buffer.device != device or self.buffer.dtype != dtype):
            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
        return self.buffer[:shape_numel].view(*shape)

buffer instance-attribute

buffer = None

__init__

__init__()
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def __init__(self):
    self.buffer = None

get

get(shape: tuple[int, ...], device: device, dtype: dtype)
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def get(self, shape: tuple[int, ...], device: torch.device,
        dtype: torch.dtype):
    if shape == () or shape is None:
        return None
    shape_numel = prod(shape)
    if (self.buffer is None or self.buffer.numel() < shape_numel
            or self.buffer.device != device or self.buffer.dtype != dtype):
        self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
    return self.buffer[:shape_numel].view(*shape)
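
A short usage sketch: as long as later requests do not exceed the current capacity (and keep the same device and dtype), the underlying allocation is reused and simply viewed with the requested shape. The device and shapes below are illustrative:

import torch

buf = SharedResizableBuffer()
w1 = buf.get((64, 64), device=torch.device("cpu"), dtype=torch.float32)
w2 = buf.get((32, 32), device=torch.device("cpu"), dtype=torch.float32)
# Both views share the same storage; no new allocation was made for w2.
assert w1.data_ptr() == w2.data_ptr()
assert w2.shape == (32, 32)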

TopKWeightAndReduce

Bases: ABC

An abstract base class for weight application and reduction implementations.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
class TopKWeightAndReduce(ABC):
    """
    An abstract base class for weight application and reduction implementations.
    """

    @abstractmethod
    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        """
        Apply topk_weights to the fused_experts_outputs and/or reduce.
        If an output tensor is not passed, it will be created in the
        function.
        """
        raise NotImplementedError

apply abstractmethod

apply(
    output: Optional[Tensor],
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
) -> Tensor

Apply topk_weights to the fused_experts_outputs and/or reduce. If an output tensor is not passed, it will be created in the function.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
@abstractmethod
def apply(self, output: Optional[torch.Tensor],
          fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
          topk_ids: torch.Tensor,
          apply_router_weight_on_input: bool) -> torch.Tensor:
    """
    Apply topk_weights to the fused_experts_outputs and/or reduce.
    If an output tensor is not passed, it will be created in the
    function.
    """
    raise NotImplementedError
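
A minimal sketch of a concrete implementation, assuming the standard (M, topk, K) layout of fused_expert_output and a plain weighted sum over the topk dimension; this is illustrative, not one of the implementations shipped with the module:

import torch

class NaiveTopKWeightAndReduce(TopKWeightAndReduce):
    # Hypothetical example implementation.

    def apply(self, output, fused_expert_output, topk_weights, topk_ids,
              apply_router_weight_on_input):
        M, topk, K = fused_expert_output.shape
        if apply_router_weight_on_input:
            # The router weights were already folded into the activations.
            reduced = fused_expert_output.sum(dim=1)
        else:
            reduced = (fused_expert_output *
                       topk_weights.view(M, topk, 1)).sum(dim=1)
        if output is None:
            return reduced.to(fused_expert_output.dtype)
        output.copy_(reduced)
        return output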

_chunk_scales

_chunk_scales(
    scales: Optional[Tensor], start: int, end: int
) -> Optional[Tensor]
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
                  end: int) -> Optional[torch.Tensor]:
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None
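
The helper treats a single-element tensor as a per-tensor scale and returns it unchanged, while per-token scales are sliced to the chunk; a small sketch with illustrative values:

import torch

per_tensor = torch.tensor([0.5])
per_token = torch.ones(8, 1)

assert _chunk_scales(per_tensor, 2, 6) is per_tensor   # numel() == 1
assert _chunk_scales(per_token, 2, 6).shape == (4, 1)  # rows 2..5
assert _chunk_scales(None, 2, 6) is None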

_moe_problem_size

_moe_problem_size(
    a1: Tensor, w1: Tensor, w2: Tensor, topk_ids: Tensor
) -> tuple[int, int, int, int, int]

Extract the MoE problem size from the given tensor arguments:

- a1: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.

Note: extracting the problem shape from the weight and activation tensors is not obvious. It needs to be done this way specifically due to subtle issues with particular kernels, e.g. the int4 kernels divide the trailing dimension by two, so it's not "correct" to extract N or K from the trailing dimension of w1 or w2. Similarly, some kernels transpose the weights, so this needs to be kept in mind.

Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
def _moe_problem_size(
    a1: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
    """
    Extract the MoE problem size from the given tensor arguments:
    - a: The hidden states, input to the MoE layer.
    - w1: The first set of expert weights.
    - w2: The second set of expert weights.
    - topk_ids: The topk ids.

    Note: extracting the problem shape from the weight and activation tensors is
    not obvious.  It needs to be done this way specifically due to subtle issues
    with particular kernels, e.g. the int4 kernels divide the trailing dimension
    by two, so it's not "correct" to extract N or K from the trailing dimension
    of w1 or w2.  Similarly, some kernels transpose the weights, so this needs
    to be kept in mind.
    """
    assert w1.dim() == 3 and w2.dim() == 3
    E, N, _ = w1.size()
    K = a1.size(-1)

    if a1.dim() == 2:
        # Make sure we are using the correct a1 (pre-permute).
        assert topk_ids.size(0) == a1.size(0), \
            f"{topk_ids.size(0)} != {a1.size(0)}"
        M = a1.size(0)
    else:
        assert a1.dim() == 3
        assert a1.size(0) == E, f"{a1.size(0)} != {E}"
        M = a1.size(1)  # This is max_num_tokens

    assert topk_ids.dim() == 2
    topk = topk_ids.size(1)

    return E, M, N, K, topk
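
A small sketch with dummy tensors in the standard activation format; the shapes below are illustrative and only need to satisfy the assertions inside the function:

import torch

E, M, N, K, topk = 8, 16, 256, 128, 2
a1 = torch.randn(M, K)
w1 = torch.randn(E, N, K)
w2 = torch.randn(E, K, N // 2)
topk_ids = torch.randint(0, E, (M, topk))

assert _moe_problem_size(a1, w1, w2, topk_ids) == (E, M, N, K, topk)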