
vllm.model_executor

Modules:

Name          Description
custom_op
layers
model_loader
models
parameter
utils         Utils for model executor.
warmup

__all__ module-attribute

__all__ = [
    "set_random_seed",
    "BasevLLMParameter",
    "PackedvLLMParameter",
]
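
The names in __all__ can be imported directly from the package; for example:

from vllm.model_executor import (BasevLLMParameter, PackedvLLMParameter,
                                 set_random_seed)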

BasevLLMParameter

Bases: Parameter

Base parameter for vLLM linear layers. Extends torch.nn.Parameter by taking in a linear weight loader; the loaded weight is copied into the parameter when the provided weight loader is called.

Source code in vllm/model_executor/parameter.py
class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: Optional[torch.Tensor], **kwargs):

        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.parameter
        """

        # During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        from vllm.platforms import current_platform
        if current_platform.use_sync_weight_loader():
            weight_loader = current_platform.make_synced_weight_loader(
                weight_loader)

        self._weight_loader = weight_loader
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

    @property
    def weight_loader(self) -> Callable:
        # NOTE(@ksayers) some models such as mamba_mixer2 override the
        # weight loader to support custom loading. In the future, model-specific
        # weight loading should be implemented via Model.load_weights. In the
        # meantime, support deleting and overriding the `weight_loader` attribute
        if self._weight_loader is None:
            raise AttributeError(f"{self.__class__.__name__} weight_loader "
                                 "attribute has been deleted")
        return self._weight_loader

    @weight_loader.setter
    def weight_loader(self, value: Callable):
        self._weight_loader = value

    @weight_loader.deleter
    def weight_loader(self):
        self._weight_loader = None  # type: ignore[assignment]

    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
        cond1 = self.data.ndim == 1 and self.data.numel() == 1
        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
        return (cond1 and cond2)

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert (self.data.shape == loaded_weight.shape
                or self._is_1d_and_scalar(loaded_weight))
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        assert isinstance(shard_id, str)
        assert shard_id in qkv_idxs
        return qkv_idxs[shard_id]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)
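
Example usage (a minimal sketch): constructing a BasevLLMParameter requires an initialized tensor-parallel group, since __init__ queries the tensor-parallel rank and world size. The weight loader below is a hypothetical stand-in for the loaders vLLM's linear layers register.

import torch
from vllm.model_executor.parameter import BasevLLMParameter

def simple_weight_loader(param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor) -> None:
    # Hypothetical loader: copy the checkpoint tensor into the parameter.
    param.data.copy_(loaded_weight)

param = BasevLLMParameter(data=torch.empty(1024, 1024),
                          weight_loader=simple_weight_loader)

# During checkpoint loading, the registered loader is retrieved via the
# weight_loader property and invoked with the tensor read from disk.
param.weight_loader(param, torch.randn(1024, 1024))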

_weight_loader instance-attribute

_weight_loader = weight_loader

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

weight_loader property (writable, deletable)

weight_loader: Callable
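
The property is writable and deletable; as the NOTE in the source above explains, some models override the loader for custom loading. A short sketch, reusing param and simple_weight_loader from the example above:

param.weight_loader = simple_weight_loader  # override (writable)
del param.weight_loader                     # delete; later access raises AttributeError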

__init__

__init__(data: Tensor, weight_loader: Callable)

Initialize the BasevLLMParameter

:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable

:returns: a torch.nn.parameter

Source code in vllm/model_executor/parameter.py
def __init__(self, data: torch.Tensor, weight_loader: Callable):
    """
    Initialize the BasevLLMParameter

    :param data: torch tensor with the parameter data
    :param weight_loader: weight loader callable

    :returns: a torch.nn.parameter
    """

    # During weight loading, we often do something like:
    # narrowed_tensor = param.data.narrow(0, offset, len)
    # narrowed_tensor.copy_(real_weight)
    # expecting narrowed_tensor and param.data to share the same storage.
    # However, on TPUs, narrowed_tensor will lazily propagate to the base
    # tensor, which is param.data, leading to the redundant memory usage.
    # This sometimes causes OOM errors during model loading. To avoid this,
    # we sync the param tensor after its weight loader is called.
    from vllm.platforms import current_platform
    if current_platform.use_sync_weight_loader():
        weight_loader = current_platform.make_synced_weight_loader(
            weight_loader)

    self._weight_loader = weight_loader
    self.tp_rank = get_tensor_model_parallel_rank()
    self.tp_size = get_tensor_model_parallel_world_size()

__new__

__new__(data: Optional[Tensor], **kwargs)
Source code in vllm/model_executor/parameter.py
def __new__(cls, data: Optional[torch.Tensor], **kwargs):

    return super().__new__(cls, data=data, requires_grad=False)

__torch_function__ classmethod

__torch_function__(func, types, args=(), kwargs=None)
Source code in vllm/model_executor/parameter.py
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    return super().__torch_function__(func, types, args, kwargs)

_assert_and_load

_assert_and_load(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def _assert_and_load(self, loaded_weight: torch.Tensor):
    assert (self.data.shape == loaded_weight.shape
            or self._is_1d_and_scalar(loaded_weight))
    self.data.copy_(loaded_weight)

_is_1d_and_scalar

_is_1d_and_scalar(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
    cond1 = self.data.ndim == 1 and self.data.numel() == 1
    cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
    return (cond1 and cond2)

_shard_id_as_int

_shard_id_as_int(shard_id: Union[str, int]) -> int
Source code in vllm/model_executor/parameter.py
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
    if isinstance(shard_id, int):
        return shard_id

    # if not int, assume shard_id for qkv
    # map to int and return
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    assert isinstance(shard_id, str)
    assert shard_id in qkv_idxs
    return qkv_idxs[shard_id]
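
String shard ids are assumed to name QKV shards and are mapped to fixed indexes, while integer ids pass through unchanged (param is the sketch instance from the example above):

param._shard_id_as_int("q")  # -> 0
param._shard_id_as_int("k")  # -> 1
param._shard_id_as_int("v")  # -> 2
param._shard_id_as_int(3)    # -> 3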

load_column_parallel_weight

load_column_parallel_weight(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
    self._assert_and_load(loaded_weight)

load_merged_column_weight

load_merged_column_weight(loaded_weight: Tensor, **kwargs)
Source code in vllm/model_executor/parameter.py
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
    self._assert_and_load(loaded_weight)

load_qkv_weight

load_qkv_weight(loaded_weight: Tensor, **kwargs)
Source code in vllm/model_executor/parameter.py
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
    self._assert_and_load(loaded_weight)

load_row_parallel_weight

load_row_parallel_weight(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
    self._assert_and_load(loaded_weight)

PackedvLLMParameter

Bases: ModelWeightParameter

Parameter for model weights which are packed on disk. Example: GPTQ Marlin weights are int4 or int8, packed into int32. Extends ModelWeightParameter to take in the packed factor, the packed dimension, and optionally the marlin tile size for marlin kernels. Adjusts the shard_size and shard_offset for fused linear layer weight loading by accounting for packing and, optionally, the marlin tile size.

Source code in vllm/model_executor/parameter.py
class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and 
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)
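
Example (a sketch of a GPTQ-style layout): int4 values packed eight per int32 along the input dimension. input_dim and output_dim are ModelWeightParameter keyword arguments forwarded through **kwargs and are not documented in this section; simple_weight_loader is the hypothetical loader from the BasevLLMParameter example above.

import torch
from vllm.model_executor.parameter import PackedvLLMParameter

qweight = PackedvLLMParameter(
    data=torch.empty(4096 // 8, 4096, dtype=torch.int32),
    input_dim=0,
    output_dim=1,
    packed_dim=0,
    packed_factor=8,  # eight int4 values per int32
    weight_loader=simple_weight_loader,
)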

_bitblas_tile_size instance-attribute

_bitblas_tile_size = bitblas_tile_size

_marlin_tile_size instance-attribute

_marlin_tile_size = marlin_tile_size

_packed_dim instance-attribute

_packed_dim = packed_dim

_packed_factor instance-attribute

_packed_factor = packed_factor

bitblas_tile_size property

bitblas_tile_size

marlin_tile_size property

marlin_tile_size

packed_dim property

packed_dim

packed_factor property

packed_factor

__init__

__init__(
    packed_factor: Union[int, Fraction],
    packed_dim: int,
    marlin_tile_size: Optional[int] = None,
    bitblas_tile_size: Optional[int] = None,
    **kwargs,
)
Source code in vllm/model_executor/parameter.py
def __init__(self,
             packed_factor: Union[int, Fraction],
             packed_dim: int,
             marlin_tile_size: Optional[int] = None,
             bitblas_tile_size: Optional[int] = None,
             **kwargs):
    self._packed_factor = packed_factor
    self._packed_dim = packed_dim
    self._marlin_tile_size = marlin_tile_size
    self._bitblas_tile_size = bitblas_tile_size
    super().__init__(**kwargs)

adjust_shard_indexes_for_packing

adjust_shard_indexes_for_packing(shard_size, shard_offset)
Source code in vllm/model_executor/parameter.py
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
    return _adjust_shard_indexes_for_packing(
        shard_size=shard_size,
        shard_offset=shard_offset,
        packed_factor=self.packed_factor,
        marlin_tile_size=self.marlin_tile_size,
        bitblas_tile_size=self.bitblas_tile_size)
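
For the packed sketch above (packed_factor=8, no marlin or bitblas tile size), shard sizes and offsets expressed in unpacked elements are expected to scale down by the packed factor; when a tile size is set, the private helper applies a further adjustment:

shard_size, shard_offset = qweight.adjust_shard_indexes_for_packing(
    shard_size=4096, shard_offset=4096)
# Expected with packed_factor=8 and no tile size: both scale to 4096 // 8 == 512.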

set_random_seed

set_random_seed(seed: int) -> None
Source code in vllm/model_executor/utils.py
def set_random_seed(seed: int) -> None:
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)
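
Example: seed the platform RNGs before model initialization for reproducible results; the call delegates to current_platform.seed_everything as shown above.

from vllm.model_executor import set_random_seed

set_random_seed(42)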