vllm.distributed.device_communicators.xpu_communicator

logger module-attribute

logger = init_logger(__name__)

XpuCommunicator

Bases: DeviceCommunicatorBase

Source code in vllm/distributed/device_communicators/xpu_communicator.py
class XpuCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        if self.use_all2all:
            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
            if all2all_backend != "naive":
                logger.warning(
                    "`%s` all2all manager is not supported on XPU."
                    "Falling back to `naive` all2all manager for XPU.",
                    all2all_backend)
                all2all_backend = "naive"
            if all2all_backend == "naive":
                from .all2all import NaiveAll2AllManager
                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                logger.info("Using naive all2all manager.")

    def all_reduce(self, input_) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # On the XPU path, gather doesn't work properly with Ray clusters,
        # so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        if self.rank_in_group == dst:
            # Reshape
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
        dist.broadcast(input_, src=src, group=self.device_group)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert self.all2all_manager is not None
        hidden_states, router_logits = self.all2all_manager.dispatch(
            hidden_states, router_logits, is_sequence_parallel)
        return hidden_states, router_logits

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        assert self.all2all_manager is not None
        hidden_states = self.all2all_manager.combine(hidden_states,
                                                     is_sequence_parallel)
        return hidden_states

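A minimal usage sketch, assuming torch.distributed has already been initialized with an XPU-capable backend (as vLLM's own distributed setup does); the process groups, device index, and unique_name below are illustrative placeholders rather than the exact objects vLLM passes internally:

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.xpu_communicator import (
    XpuCommunicator)

# Placeholder groups; in vLLM these come from the parallel-state setup.
comm = XpuCommunicator(cpu_group=dist.group.WORLD,
                       device=torch.device("xpu:0"),
                       device_group=dist.group.WORLD,
                       unique_name="tp")

x = torch.ones(8, device="xpu:0")
x = comm.all_reduce(x)  # element-wise sum across all ranks in device_group
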
all2all_manager instance-attribute

all2all_manager = NaiveAll2AllManager(cpu_group)

__init__

__init__(
    cpu_group: ProcessGroup,
    device: Optional[device] = None,
    device_group: Optional[ProcessGroup] = None,
    unique_name: str = "",
)
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: Optional[torch.device] = None,
             device_group: Optional[ProcessGroup] = None,
             unique_name: str = ""):
    super().__init__(cpu_group, device, device_group, unique_name)
    if self.use_all2all:
        all2all_backend = envs.VLLM_ALL2ALL_BACKEND
        if all2all_backend != "naive":
            logger.warning(
                "`%s` all2all manager is not supported on XPU."
                "Falling back to `naive` all2all manager for XPU.",
                all2all_backend)
            all2all_backend = "naive"
        if all2all_backend == "naive":
            from .all2all import NaiveAll2AllManager
            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
            logger.info("Using naive all2all manager.")

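The all-to-all backend fallback only matters when expert parallelism enables all-to-all (self.use_all2all). A sketch of the observable behavior; the non-"naive" value below is illustrative:

import os

# On XPU, any VLLM_ALL2ALL_BACKEND other than "naive" is ignored with a
# warning and the naive all2all manager is used instead.
os.environ["VLLM_ALL2ALL_BACKEND"] = "pplx"  # illustrative non-naive value
# Constructing XpuCommunicator with all2all enabled then falls back to
# NaiveAll2AllManager and logs a warning to that effect.
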
all_reduce

all_reduce(input_) -> Tensor
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def all_reduce(self, input_) -> torch.Tensor:
    dist.all_reduce(input_, group=self.device_group)
    return input_

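The reduction happens in place with torch.distributed's default reduce op (a sum), and the same tensor is returned. A sketch, reusing the imports and the comm handle from the construction sketch above:

# Each rank contributes its own partial values; afterwards every rank in
# device_group holds the element-wise sum.
partial = torch.full((4,), float(comm.rank_in_group), device="xpu:0")
total = comm.all_reduce(partial)
assert total is partial  # reduced in place, not copied
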
broadcast

broadcast(input_: Tensor, src: int = 0) -> None
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
    dist.broadcast(input_, src=src, group=self.device_group)

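broadcast is also in place and, unlike all_reduce, returns None: the tensor on src overwrites the same-shaped tensor on every other rank, with src forwarded directly to torch.distributed.broadcast. A sketch, again reusing comm from above:

buf = torch.empty(16, device="xpu:0")
if comm.rank_in_group == 0:
    # Assumes group rank 0 is also the rank that dist.broadcast expects as src.
    buf = torch.arange(16, dtype=torch.float32, device="xpu:0")
comm.broadcast(buf, src=0)
# every rank's buf now holds 0.0 .. 15.0
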
combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def combine(self,
            hidden_states: torch.Tensor,
            is_sequence_parallel: bool = False) -> torch.Tensor:
    assert self.all2all_manager is not None
    hidden_states = self.all2all_manager.combine(hidden_states,
                                                 is_sequence_parallel)
    return hidden_states

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    assert self.all2all_manager is not None
    hidden_states, router_logits = self.all2all_manager.dispatch(
        hidden_states, router_logits, is_sequence_parallel)
    return hidden_states, router_logits

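dispatch and combine (documented above) form the expert-parallel round trip for MoE layers: dispatch hands the tokens and their router logits to the all2all manager so they reach the ranks that will run the experts (the naive manager simply gathers them across the group), and combine brings the expert outputs back to the originating ranks. A shape-level sketch with illustrative sizes, assuming all2all was enabled so all2all_manager is set, and reusing comm from above:

num_tokens, hidden_size, num_experts = 16, 1024, 8  # illustrative sizes
hidden_states = torch.randn(num_tokens, hidden_size, device="xpu:0")
router_logits = torch.randn(num_tokens, num_experts, device="xpu:0")

# Send tokens (and their routing scores) toward the ranks running the experts.
hidden_states, router_logits = comm.dispatch(hidden_states, router_logits)

# ... run the local experts on the dispatched tokens (not shown) ...

# Return the expert outputs to the ranks that originally held the tokens.
hidden_states = comm.combine(hidden_states)
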
gather

gather(
    input_: Tensor, dst: int = 0, dim: int = -1
) -> Optional[Tensor]
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> Optional[torch.Tensor]:
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # On the XPU path, gather doesn't work properly with Ray clusters,
    # so we use all_gather instead for now.
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((self.world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    dist.all_gather_into_tensor(output_tensor,
                                input_,
                                group=self.device_group)
    if self.rank_in_group == dst:
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
    else:
        output_tensor = None
    return output_tensor
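
Because the XPU path emulates gather with all_gather_into_tensor, the dst rank has to undo the leading world_size dimension itself. The movedim/reshape step can be checked without a process group by simulating the stacked all-gather output; a runnable sketch with illustrative sizes:

import torch

world_size, dim = 2, -1
input_size = torch.Size([3, 4])  # per-rank tensor shape (illustrative)
if dim < 0:
    dim += len(input_size)

# Pretend each rank contributed a distinct chunk; all_gather_into_tensor
# stacks them along a new leading dimension of size world_size.
chunks = [torch.full(input_size, float(rank)) for rank in range(world_size)]
output_tensor = torch.stack(chunks, dim=0)  # shape (world_size, 3, 4)

# Same reshape as the dst branch of gather():
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
                                      (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])
print(output_tensor.shape)  # torch.Size([3, 8]): per-rank chunks concatenated on dim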