vllm.distributed.device_communicators.all2all

logger module-attribute

logger = init_logger(__name__)

AgRsAll2AllManager

Bases: All2AllManagerBase

An implementation of all2all communication based on all-gather (dispatch) and reduce-scatter (combine).

Source code in vllm/distributed/device_communicators/all2all.py
class AgRsAll2AllManager(All2AllManagerBase):
    """
    An implementation of all2all communication based on
    all-gather (dispatch) and reduce-scatter (combine).
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather hidden_states and router_logits from all dp ranks.
        """
        sizes = get_forward_context(
        ).dp_metadata.get_chunk_sizes_across_dp_rank()

        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
        hidden_states, router_logits = dist_group.all_gatherv(
            [hidden_states, router_logits],
            dim=0,
            sizes=sizes,
        )
        return hidden_states, router_logits

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        """
        Reduce-scatter hidden_states across all dp ranks.
        """
        sizes = get_forward_context(
        ).dp_metadata.get_chunk_sizes_across_dp_rank()

        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
        hidden_states = dist_group.reduce_scatterv(hidden_states,
                                                   dim=0,
                                                   sizes=sizes)
        return hidden_states

    def destroy(self):
        pass
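
The dispatch/combine pair preserves per-rank token counts: dispatch concatenates the variable-sized chunks from every DP rank along dim 0, and combine sums the full tensor across ranks and hands each rank back its own slice. Below is a local, single-process sketch of these shape semantics; it does not use vLLM's group APIs, and sizes merely stands in for get_chunk_sizes_across_dp_rank().

import torch

# Hypothetical per-DP-rank token counts (stand-in for
# dp_metadata.get_chunk_sizes_across_dp_rank()).
sizes = [3, 1, 2]
hidden = 4

# Each "rank" holds a chunk of shape (sizes[r], hidden).
chunks = [torch.randn(s, hidden) for s in sizes]

# dispatch ~ all_gatherv along dim 0: every rank ends up with the
# concatenation of all chunks, i.e. shape (sum(sizes), hidden).
gathered = torch.cat(chunks, dim=0)
assert gathered.shape == (sum(sizes), hidden)

# combine ~ reduce_scatterv along dim 0: per-rank partial results of
# shape (sum(sizes), hidden) are summed elementwise, then each rank
# keeps only its own sizes[r] rows.
partials = [gathered * (r + 1) for r in range(len(sizes))]  # stand-in partial outputs
reduced = torch.stack(partials).sum(dim=0)
offsets = [0]
for s in sizes:
    offsets.append(offsets[-1] + s)
combined = [reduced[offsets[r]:offsets[r + 1]] for r in range(len(sizes))]
assert all(c.shape == (sizes[r], hidden) for r, c in enumerate(combined))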

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    super().__init__(cpu_group)

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor

Reduce-scatter hidden_states across all dp ranks.

Source code in vllm/distributed/device_communicators/all2all.py
def combine(self,
            hidden_states: torch.Tensor,
            is_sequence_parallel: bool = False) -> torch.Tensor:
    """
    Reduce-scatter hidden_states across all dp ranks.
    """
    sizes = get_forward_context(
    ).dp_metadata.get_chunk_sizes_across_dp_rank()

    dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
    hidden_states = dist_group.reduce_scatterv(hidden_states,
                                               dim=0,
                                               sizes=sizes)
    return hidden_states

destroy

destroy()
Source code in vllm/distributed/device_communicators/all2all.py
def destroy(self):
    pass

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
) -> tuple[Tensor, Tensor]

Gather hidden_states and router_logits from all dp ranks.

Source code in vllm/distributed/device_communicators/all2all.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gather hidden_states and router_logits from all dp ranks.
    """
    sizes = get_forward_context(
    ).dp_metadata.get_chunk_sizes_across_dp_rank()

    dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
    assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
    hidden_states, router_logits = dist_group.all_gatherv(
        [hidden_states, router_logits],
        dim=0,
        sizes=sizes,
    )
    return hidden_states, router_logits

DeepEPAll2AllManagerBase

Bases: All2AllManagerBase

All2All communication based on DeepEP kernels; shared base for the High-Throughput and Low-Latency managers.

Source code in vllm/distributed/device_communicators/all2all.py
class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    All2All communication based on DeepEP kernels; shared base for the
    High-Throughput and Low-Latency managers.
    """

    def __init__(self, cpu_group):
        assert has_deep_ep(
        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # This is the DeepEP default. Stick to it till we can establish
        # reasonable defaults based on profiling.
        self.num_sms = 20

    def get_handle(self, kwargs):
        raise NotImplementedError

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        pass

handle_cache instance-attribute

handle_cache = Cache()

num_sms instance-attribute

num_sms = 20

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    assert has_deep_ep(
    ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
    super().__init__(cpu_group)
    self.handle_cache = Cache()

    # This is the DeepEP default. Stick to it till we can establish
    # reasonable defaults based on profiling.
    self.num_sms = 20

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor
Source code in vllm/distributed/device_communicators/all2all.py
def combine(self,
            hidden_states: torch.Tensor,
            is_sequence_parallel: bool = False) -> torch.Tensor:
    raise NotImplementedError

destroy

destroy()
Source code in vllm/distributed/device_communicators/all2all.py
def destroy(self):
    pass

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/distributed/device_communicators/all2all.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    raise NotImplementedError

get_handle

get_handle(kwargs)
Source code in vllm/distributed/device_communicators/all2all.py
def get_handle(self, kwargs):
    raise NotImplementedError

DeepEPHTAll2AllManager

Bases: DeepEPAll2AllManagerBase

All2All communication based on DeepEP High-Throughput kernels.

Source code in vllm/distributed/device_communicators/all2all.py
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP High-Throughput kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(self) -> dict[Any, Any]:
        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
        num_rdma_bytes = None
        num_qps_per_rank = None

        if self.internode:
            num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
            num_qps_per_rank = self.num_sms // 2
        else:
            num_rdma_bytes = 0
            num_qps_per_rank = 1

        assert num_rdma_bytes is not None
        assert num_qps_per_rank is not None
        return dict(group=self.cpu_group,
                    num_nvl_bytes=num_nvl_bytes,
                    num_rdma_bytes=num_rdma_bytes,
                    low_latency_mode=False,
                    num_qps_per_rank=num_qps_per_rank)

    def get_handle(self, kwargs):

        assert len(kwargs) == 0, (
            "DeepEPHTAll2AllManager expects no arguments. All the required "
            "args are computed in the Manager itself.")

        import deep_ep
        buffer_kwargs = self._make_all2all_kwargs()
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer)
        return handle

    def set_num_sms(self, num_sms: int):
        import deep_ep

        # Right now the buffers are sized only for what the kernels were
        # created with, so we can only reduce the number of SMs used,
        # not increase it.
        if num_sms > self.num_sms:
            num_sms = self.num_sms
        deep_ep.Buffer.set_num_sms(num_sms)

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    super().__init__(cpu_group)

_make_all2all_kwargs

_make_all2all_kwargs() -> dict[Any, Any]
Source code in vllm/distributed/device_communicators/all2all.py
def _make_all2all_kwargs(self) -> dict[Any, Any]:
    # Defaults for internode and intranode are taken from DeepEP tests.
    num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
    num_rdma_bytes = None
    num_qps_per_rank = None

    if self.internode:
        num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
        num_qps_per_rank = self.num_sms // 2
    else:
        num_rdma_bytes = 0
        num_qps_per_rank = 1

    assert num_rdma_bytes is not None
    assert num_qps_per_rank is not None
    return dict(group=self.cpu_group,
                num_nvl_bytes=num_nvl_bytes,
                num_rdma_bytes=num_rdma_bytes,
                low_latency_mode=False,
                num_qps_per_rank=num_qps_per_rank)

get_handle

get_handle(kwargs)
Source code in vllm/distributed/device_communicators/all2all.py
def get_handle(self, kwargs):

    assert len(kwargs) == 0, (
        "DeepEPHTAll2AllManager expects no arguments. All the required "
        "args are computed in the Manager itself.")

    import deep_ep
    buffer_kwargs = self._make_all2all_kwargs()
    logger.debug("DeepEP all2all args %s", buffer_kwargs)
    handle: deep_ep.Buffer = self.handle_cache.get_or_create(
        buffer_kwargs, deep_ep.Buffer)
    return handle
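
The returned handle is memoized in handle_cache keyed by the buffer kwargs, so repeated get_handle() calls with an identical configuration share one deep_ep.Buffer. A minimal sketch of that get-or-create pattern follows; the key construction is an assumption, and vLLM's actual Cache helper may differ.

import threading

class KwargsCache:
    # Minimal sketch of a get_or_create cache keyed by kwargs.
    # (Assumption: vLLM's internal Cache behaves similarly; this is
    # not its real API.)
    def __init__(self):
        self._lock = threading.Lock()
        self._cache = {}

    def get_or_create(self, kwargs, factory):
        key = tuple(sorted(kwargs.items()))  # values must be hashable
        with self._lock:
            if key not in self._cache:
                self._cache[key] = factory(**kwargs)
            return self._cache[key]

cache = KwargsCache()
a = cache.get_or_create({"num_sms": 20}, dict)  # dict as a stand-in factory
b = cache.get_or_create({"num_sms": 20}, dict)
assert a is b  # the second call reuses the cached object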

set_num_sms

set_num_sms(num_sms: int)
Source code in vllm/distributed/device_communicators/all2all.py
def set_num_sms(self, num_sms: int):
    import deep_ep

    # Right now the buffers are sized only for what the kernels were
    # created with, so we can only reduce the number of SMs used,
    # not increase it.
    if num_sms > self.num_sms:
        num_sms = self.num_sms
    deep_ep.Buffer.set_num_sms(num_sms)

DeepEPLLAll2AllManager

Bases: DeepEPAll2AllManagerBase

All2All communication based on DeepEP Low-Latency kernels.

Source code in vllm/distributed/device_communicators/all2all.py
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP Low-Latency kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_ep_ranks: int,
        num_global_experts: int,
        num_local_experts: int,
    ) -> dict[Any, Any]:
        """
        max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank
          can dispatch; all ranks must hold the same value.
        token_hidden_size: the hidden dimension of each token.
        num_ep_ranks: the number of EP group ranks.
        num_global_experts: Number of experts in the model.
        num_local_experts: Number of experts in an EP rank.
        """
        import deep_ep

        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
        num_qps_per_rank = num_local_experts
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=num_ep_ranks,
            num_experts=num_global_experts)

        assert num_rdma_bytes is not None
        return dict(group=self.cpu_group,
                    num_nvl_bytes=num_nvl_bytes,
                    num_rdma_bytes=num_rdma_bytes,
                    low_latency_mode=True,
                    num_qps_per_rank=num_qps_per_rank)

    def get_handle(self, kwargs):
        """
        The kwargs for DeepEPLLAll2AllManager are dictated by
        _make_all2all_kwargs.
        """
        import deep_ep
        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer)
        return handle

    # DeepEP LL uses RDMA so no SMs are used for communication
    def max_sms_used(self) -> Optional[int]:
        return 0

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    super().__init__(cpu_group)

_make_all2all_kwargs

_make_all2all_kwargs(
    max_num_tokens_per_dp_rank: int,
    token_hidden_size: int,
    num_ep_ranks: int,
    num_global_experts: int,
    num_local_experts: int,
) -> dict[Any, Any]
max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank can dispatch; all ranks must hold the same value.
token_hidden_size: the hidden dimension of each token.
num_ep_ranks: the number of EP group ranks.
num_global_experts: the number of experts in the model.
num_local_experts: the number of experts on an EP rank.

Source code in vllm/distributed/device_communicators/all2all.py
def _make_all2all_kwargs(
    self,
    max_num_tokens_per_dp_rank: int,
    token_hidden_size: int,
    num_ep_ranks: int,
    num_global_experts: int,
    num_local_experts: int,
) -> dict[Any, Any]:
    """
    max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank
      can dispatch; all ranks must hold the same value.
    token_hidden_size: the hidden dimension of each token.
    num_ep_ranks: the number of EP group ranks.
    num_global_experts: Number of experts in the model.
    num_local_experts: Number of experts in an EP rank.
    """
    import deep_ep

    # Defaults for internode and intranode are taken from DeepEP tests.
    num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
    num_qps_per_rank = num_local_experts
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
        hidden=token_hidden_size,
        num_ranks=num_ep_ranks,
        num_experts=num_global_experts)

    assert num_rdma_bytes is not None
    return dict(group=self.cpu_group,
                num_nvl_bytes=num_nvl_bytes,
                num_rdma_bytes=num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=num_qps_per_rank)

get_handle

get_handle(kwargs)

The kwargs for DeepEPLLAll2AllManager are dictated by _make_all2all_kwargs.

Source code in vllm/distributed/device_communicators/all2all.py
def get_handle(self, kwargs):
    """
    The kwargs for DeepEPLLAll2AllManager are dictated by
    _make_all2all_kwargs.
    """
    import deep_ep
    buffer_kwargs = self._make_all2all_kwargs(**kwargs)
    logger.debug("DeepEP all2all args %s", buffer_kwargs)
    handle: deep_ep.Buffer = self.handle_cache.get_or_create(
        buffer_kwargs, deep_ep.Buffer)
    return handle
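
Unlike the high-throughput manager, get_handle here expects the caller to supply the arguments of _make_all2all_kwargs. A hedged example of such a kwargs dict; all values are hypothetical, and deep_ep must be installed for the commented-out call to work.

# Hypothetical arguments for DeepEPLLAll2AllManager.get_handle();
# the values are illustrative, not defaults.
ll_kwargs = dict(
    max_num_tokens_per_dp_rank=256,  # must be identical on every rank
    token_hidden_size=4096,
    num_ep_ranks=8,
    num_global_experts=64,
    num_local_experts=8,             # 64 global experts / 8 EP ranks
)
# handle = manager.get_handle(ll_kwargs)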

max_sms_used

max_sms_used() -> Optional[int]
Source code in vllm/distributed/device_communicators/all2all.py
def max_sms_used(self) -> Optional[int]:
    return 0

FlashInferAllToAllManager

Bases: All2AllManagerBase

All2All communication based on flashinfer kernels.

Source code in vllm/distributed/device_communicators/all2all.py
class FlashInferAllToAllManager(All2AllManagerBase):
    """
    All2All communication based on flashinfer kernels.
    """

    def __init__(self, cpu_group):
        assert has_flashinfer_all2all(
        ), "flashinfer all2all module not found. Please install/check flashinfer"  # noqa
        super().__init__(cpu_group)
        logger.debug(
            "Initialize for flashinfer All2All "
            "rank=%d, world size=%d", self.rank, self.world_size)
        self.initialized = False
        self.alltoall_info = None

    def initialize(
        self,
        world_size: int,
        rank: int,
        gpus_per_node: int,
    ):
        """Initialize workspace"""
        if self.initialized:
            return

        self.cleanup()
        logger.debug("making map: "
                     "rank=%d, world size=%d", rank, world_size)
        self.mapping = Mapping(
            world_size,
            rank,
            gpus_per_node,
            tp_size=world_size,
        )

        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator)
        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
            fabric_page_size=1 << 29,  # 512MB
            allocation_granularity=0  # Auto-detect
        )

        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
            self.mapping, dp_config)
        self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
            self.mapping, dp_config)

        self.world_size = world_size
        self.rank = rank
        self.gpus_per_node = gpus_per_node
        self.initialized = True

        logger.info("FlashInfer All2All initialized for rank %s, size %s",
                    rank, world_size)

    def ensure_alltoall_workspace_initialized(self):
        """Ensure workspace is initialized"""
        if not has_flashinfer_all2all():
            return False

        if self.world_size <= 1:
            return False

        if not self.initialized:
            self.initialize(
                world_size=self.world_size,
                rank=self.rank,
                gpus_per_node=torch.cuda.device_count(),
            )
        return self.initialized

    def get_handle(self, kwargs):
        return self

    def cleanup(self):
        """Clean up workspace"""
        if self.initialized and self.workspace_tensor is not None \
            and self.prepare_workspace_tensor is not None:
            try:
                del self.workspace_tensor
                del self.prepare_workspace_tensor
            except Exception as e:
                logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
            finally:
                self.workspace_tensor = None
                self.prepare_workspace_tensor = None
                self.mapping = None
                self.initialized = False
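
The manager allocates its workspaces lazily on first use and tears them down through cleanup(). A minimal sketch of that initialize/cleanup lifecycle (illustrative only; not the flashinfer or MnnvlMoe API):

class LazyWorkspace:
    # Sketch of the lazy-init lifecycle used above (illustrative only).
    def __init__(self):
        self.initialized = False
        self.workspace = None

    def ensure_initialized(self, world_size: int) -> bool:
        if world_size <= 1:            # single rank: nothing to exchange
            return False
        if not self.initialized:       # first use: allocate once
            self.workspace = object()  # stand-in for the MnnvlMoe workspaces
            self.initialized = True
        return self.initialized

    def cleanup(self):
        self.workspace = None
        self.initialized = False

ws = LazyWorkspace()
assert ws.ensure_initialized(world_size=1) is False
assert ws.ensure_initialized(world_size=4) is True
ws.cleanup()
assert ws.initialized is False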

alltoall_info instance-attribute

alltoall_info = None

initialized instance-attribute

initialized = False

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    assert has_flashinfer_all2all(
    ), "flashinfer all2all module not found. Please install/check flashinfer"  # noqa
    super().__init__(cpu_group)
    logger.debug(
        "Initialize for flashinfer All2All "
        "rank=%d, world size=%d", self.rank, self.world_size)
    self.initialized = False
    self.alltoall_info = None

cleanup

cleanup()

Clean up workspace

Source code in vllm/distributed/device_communicators/all2all.py
def cleanup(self):
    """Clean up workspace"""
    if self.initialized and self.workspace_tensor is not None \
        and self.prepare_workspace_tensor is not None:
        try:
            del self.workspace_tensor
            del self.prepare_workspace_tensor
        except Exception as e:
            logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
        finally:
            self.workspace_tensor = None
            self.prepare_workspace_tensor = None
            self.mapping = None
            self.initialized = False

ensure_alltoall_workspace_initialized

ensure_alltoall_workspace_initialized()

Ensure workspace is initialized

Source code in vllm/distributed/device_communicators/all2all.py
def ensure_alltoall_workspace_initialized(self):
    """Ensure workspace is initialized"""
    if not has_flashinfer_all2all():
        return False

    if self.world_size <= 1:
        return False

    if not self.initialized:
        self.initialize(
            world_size=self.world_size,
            rank=self.rank,
            gpus_per_node=torch.cuda.device_count(),
        )
    return self.initialized

get_handle

get_handle(kwargs)
Source code in vllm/distributed/device_communicators/all2all.py
def get_handle(self, kwargs):
    return self

initialize

initialize(world_size: int, rank: int, gpus_per_node: int)

Initialize workspace

Source code in vllm/distributed/device_communicators/all2all.py
def initialize(
    self,
    world_size: int,
    rank: int,
    gpus_per_node: int,
):
    """Initialize workspace"""
    if self.initialized:
        return

    self.cleanup()
    logger.debug("making map: "
                 "rank=%d, world size=%d", rank, world_size)
    self.mapping = Mapping(
        world_size,
        rank,
        gpus_per_node,
        tp_size=world_size,
    )

    from vllm.distributed.device_communicators.mnnvl_compat import (
        CustomCommunicator)
    dp_config = MnnvlConfig(
        comm_backend=CustomCommunicator(get_dp_group().cpu_group),
        fabric_page_size=1 << 29,  # 512MB
        allocation_granularity=0  # Auto-detect
    )

    self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
        self.mapping, dp_config)
    self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
        self.mapping, dp_config)

    self.world_size = world_size
    self.rank = rank
    self.gpus_per_node = gpus_per_node
    self.initialized = True

    logger.info("FlashInfer All2All initialized for rank %s, size %s",
                rank, world_size)

NaiveAll2AllManager

Bases: All2AllManagerBase

A naive implementation of all2all communication. It uses all-reduce under the hood, which is not efficient at all. The main purpose is for testing and debugging.

Source code in vllm/distributed/device_communicators/all2all.py
class NaiveAll2AllManager(All2AllManagerBase):
    """
    A naive implementation of all2all communication.
    It uses all-reduce under the hood, which is not
    efficient at all. The main purpose is for testing and
    debugging.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def naive_multicast(self, x: torch.Tensor,
                        cu_tokens_across_sp_cpu: torch.Tensor,
                        is_sequence_parallel: bool) -> torch.Tensor:
        assert (len(x.shape) == 2)
        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                             device=x.device,
                             dtype=x.dtype)

        rank = self.rank if is_sequence_parallel else self.dp_rank
        world_size = (self.world_size
                      if is_sequence_parallel else self.dp_world_size)

        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
        end = cu_tokens_across_sp_cpu[rank]
        buffer[start:end, :].copy_(x)
        for idx in range(world_size):
            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
            end = cu_tokens_across_sp_cpu[idx]
            get_ep_group().broadcast(buffer[start:end, :], idx)

        return buffer

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        dp_metadata = get_forward_context().dp_metadata
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        hidden_states = self.naive_multicast(hidden_states,
                                             cu_tokens_across_sp_cpu,
                                             is_sequence_parallel)
        router_logits = self.naive_multicast(router_logits,
                                             cu_tokens_across_sp_cpu,
                                             is_sequence_parallel)
        return hidden_states, router_logits

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:

        ep_rank = self.rank if is_sequence_parallel else self.dp_rank

        dp_metadata = get_forward_context().dp_metadata
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
        end = cu_tokens_across_sp_cpu[ep_rank]

        all_hidden_states = get_ep_group().all_reduce(hidden_states)
        hidden_states = all_hidden_states[start:end, :]
        return hidden_states

    def destroy(self):
        pass

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    super().__init__(cpu_group)

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor
Source code in vllm/distributed/device_communicators/all2all.py
def combine(self,
            hidden_states: torch.Tensor,
            is_sequence_parallel: bool = False) -> torch.Tensor:

    ep_rank = self.rank if is_sequence_parallel else self.dp_rank

    dp_metadata = get_forward_context().dp_metadata
    sp_size = self.tp_group.world_size if is_sequence_parallel else 1
    cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

    start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
    end = cu_tokens_across_sp_cpu[ep_rank]

    all_hidden_states = get_ep_group().all_reduce(hidden_states)
    hidden_states = all_hidden_states[start:end, :]
    return hidden_states

destroy

destroy()
Source code in vllm/distributed/device_communicators/all2all.py
def destroy(self):
    pass

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/distributed/device_communicators/all2all.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    sp_size = self.tp_group.world_size if is_sequence_parallel else 1
    dp_metadata = get_forward_context().dp_metadata
    cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

    hidden_states = self.naive_multicast(hidden_states,
                                         cu_tokens_across_sp_cpu,
                                         is_sequence_parallel)
    router_logits = self.naive_multicast(router_logits,
                                         cu_tokens_across_sp_cpu,
                                         is_sequence_parallel)
    return hidden_states, router_logits

naive_multicast

naive_multicast(
    x: Tensor,
    cu_tokens_across_sp_cpu: Tensor,
    is_sequence_parallel: bool,
) -> Tensor
Source code in vllm/distributed/device_communicators/all2all.py
def naive_multicast(self, x: torch.Tensor,
                    cu_tokens_across_sp_cpu: torch.Tensor,
                    is_sequence_parallel: bool) -> torch.Tensor:
    assert (len(x.shape) == 2)
    buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                         device=x.device,
                         dtype=x.dtype)

    rank = self.rank if is_sequence_parallel else self.dp_rank
    world_size = (self.world_size
                  if is_sequence_parallel else self.dp_world_size)

    start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
    end = cu_tokens_across_sp_cpu[rank]
    buffer[start:end, :].copy_(x)
    for idx in range(world_size):
        start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
        end = cu_tokens_across_sp_cpu[idx]
        get_ep_group().broadcast(buffer[start:end, :], idx)

    return buffer
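
The slicing arithmetic is the crux: rank r's rows sit between consecutive entries of the cumulative token-count tensor. A small standalone sketch of that indexing (counts are hypothetical):

import torch

# Per-rank token counts and their running sum, standing in for
# cu_tokens_across_sp_cpu in the method above.
tokens_per_rank = [3, 1, 2]
cu_tokens = torch.tensor(tokens_per_rank).cumsum(0)  # tensor([3, 4, 6])

for rank in range(len(tokens_per_rank)):
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    end = int(cu_tokens[rank])
    # rank 0 owns rows [0, 3), rank 1 owns [3, 4), rank 2 owns [4, 6)
    print(f"rank {rank}: buffer[{start}:{end}]")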

PPLXAll2AllManager

Bases: All2AllManagerBase

All2All communication based on PPLX kernels.

Source code in vllm/distributed/device_communicators/all2all.py
class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on PPLX kernels.
    """

    def __init__(self, cpu_group):
        assert has_pplx(
        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
        super().__init__(cpu_group)

        if self.internode:
            # inter-node communication needs nvshmem,
            # intra-node communication uses p2p mapping directly
            from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                              nvshmem_get_unique_id,
                                              nvshmem_init)
            logger.debug(
                "Initialize NVSHMEM for pplx_kernels: "
                "rank=%d, world size=%d", self.rank, self.world_size)
            uid = nvshmem_get_unique_id(
            ) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
            dist.broadcast(uid,
                           src=dist.get_process_group_ranks(self.cpu_group)[0],
                           group=self.cpu_group)
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, self.rank, self.world_size)

        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        import pplx_kernels as pplx
        return self.handle_cache.get_or_create(
            kwargs, pplx.AllToAll.internode
            if self.internode else pplx.AllToAll.intranode)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        with self.handle_cache._lock:
            for _, handle in self.handle_cache._cache.items():
                handle.destroy()

        if self.internode:
            from pplx_kernels.nvshmem import nvshmem_finalize
            logger.debug("PPLX NVSHMEM finalize")
            nvshmem_finalize()
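
For inter-node setups, __init__ bootstraps NVSHMEM by creating a unique ID on rank 0 and broadcasting it over the CPU group. A generic torch.distributed sketch of that rank-0 handshake (not the pplx/nvshmem API; broadcast_object_list is used here instead of the tensor broadcast in the source):

import torch.distributed as dist

def share_unique_id_from_rank0(cpu_group):
    # Rank 0 produces the ID, every rank receives it over the CPU group.
    rank = dist.get_rank(cpu_group)
    payload = [b"unique-id" if rank == 0 else None]  # stand-in for nvshmem_get_unique_id()
    src = dist.get_process_group_ranks(cpu_group)[0]
    dist.broadcast_object_list(payload, src=src, group=cpu_group)
    return payload[0]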

handle_cache instance-attribute

handle_cache = Cache()

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/all2all.py
def __init__(self, cpu_group):
    assert has_pplx(
    ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
    super().__init__(cpu_group)

    if self.internode:
        # inter-node communication needs nvshmem,
        # intra-node communication uses p2p mapping directly
        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                          nvshmem_get_unique_id,
                                          nvshmem_init)
        logger.debug(
            "Initialize NVSHMEM for pplx_kernels: "
            "rank=%d, world size=%d", self.rank, self.world_size)
        uid = nvshmem_get_unique_id(
        ) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
        dist.broadcast(uid,
                       src=dist.get_process_group_ranks(self.cpu_group)[0],
                       group=self.cpu_group)
        logger.debug("PPLX NVSHMEM UID = %s", uid)
        nvshmem_init(uid, self.rank, self.world_size)

    self.handle_cache = Cache()

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor
Source code in vllm/distributed/device_communicators/all2all.py
def combine(self,
            hidden_states: torch.Tensor,
            is_sequence_parallel: bool = False) -> torch.Tensor:
    raise NotImplementedError

destroy

destroy()
Source code in vllm/distributed/device_communicators/all2all.py
def destroy(self):
    with self.handle_cache._lock:
        for _, handle in self.handle_cache._cache.items():
            handle.destroy()

    if self.internode:
        from pplx_kernels.nvshmem import nvshmem_finalize
        logger.debug("PPLX NVSHMEM finalize")
        nvshmem_finalize()

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/distributed/device_communicators/all2all.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    raise NotImplementedError

get_handle

get_handle(kwargs)
Source code in vllm/distributed/device_communicators/all2all.py
def get_handle(self, kwargs):
    import pplx_kernels as pplx
    return self.handle_cache.get_or_create(
        kwargs, pplx.AllToAll.internode
        if self.internode else pplx.AllToAll.intranode)