vllm.model_executor.model_loader.default_loader

logger module-attribute

logger = init_logger(__name__)

DefaultModelLoader

Bases: BaseModelLoader

Model loader that can load different file types from disk.
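
A hedged usage sketch (the `LoadConfig` construction and the pre-built `model`/`model_config` objects are assumptions for illustration; in practice they come from vLLM's engine setup):

from vllm.config import LoadConfig

load_config = LoadConfig(load_format="auto")  # try safetensors first, then *.bin
loader = DefaultModelLoader(load_config)

# `model` is an already-constructed nn.Module and `model_config` a vLLM
# ModelConfig naming the checkpoint; both are assumed to exist here.
loader.download_model(model_config)   # optional: fetch files without loading
loader.load_weights(model, model_config)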

Source code in vllm/model_executor/model_loader/default_loader.py
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # Default number of threads when multithreaded weight loading is enabled
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        extra_config = load_config.model_loader_extra_config
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{unexpected_keys}")

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = (maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == "auto":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif (load_format == "safetensors"
              or load_format == "fastsafetensors"):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        extra_config = self.load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        if self.load_config.load_format == "npcache":
            # Currently np_cache only supports *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == "fastsafetensors":
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                if extra_config.get("enable_multithread_load"):
                    weights_iterator = (
                        multi_thread_safetensors_weights_iterator(
                            hf_weights_files,
                            self.load_config.use_tqdm_on_load,
                            max_workers=extra_config.get(
                                "num_threads", self.DEFAULT_NUM_THREADS),
                        ))
                else:
                    weights_iterator = safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        self.load_config.safetensors_load_strategy,
                    )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                    max_workers=extra_config.get("num_threads",
                                                 self.DEFAULT_NUM_THREADS),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                )

        if current_platform.is_tpu():
            from vllm.platforms.tpu import USE_TPU_COMMONS

            if not USE_TPU_COMMONS:
                # In PyTorch XLA, we should call `torch_xla.sync`
                # frequently so that not too many ops are accumulated
                # in the XLA program.
                import torch_xla

                def _xla_weights_iterator(iterator: Generator):
                    for weights in iterator:
                        yield weights
                        torch_xla.sync(wait=False)

                weights_iterator = _xla_weights_iterator(weights_iterator)

        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(
            self.get_all_weights(model_config, model))
        self.counter_after_loading_weights = time.perf_counter()
        logger.info(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights -
            self.counter_before_loading_weights)
        # Currently, we only enable the strict check for non-quantized
        # models that track loaded weights.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")

DEFAULT_NUM_THREADS class-attribute instance-attribute

DEFAULT_NUM_THREADS = 8

counter_after_loading_weights class-attribute instance-attribute

counter_after_loading_weights: float = 0.0

counter_before_loading_weights class-attribute instance-attribute

counter_before_loading_weights: float = 0.0

Source dataclass

A source for weights.

Source code in vllm/model_executor/model_loader/default_loader.py
@dataclasses.dataclass
class Source:
    """A source for weights."""

    model_or_path: str
    """The model ID or path."""

    revision: Optional[str]
    """The optional model revision."""

    prefix: str = ""
    """A prefix to prepend to all weights."""

    fall_back_to_pt: bool = True
    """Whether .pt weights can be used."""

    allow_patterns_overrides: Optional[list[str]] = None
    """If defined, weights will load exclusively using these patterns."""

allow_patterns_overrides class-attribute instance-attribute

allow_patterns_overrides: Optional[list[str]] = None

If defined, weights will load exclusively using these patterns.

fall_back_to_pt class-attribute instance-attribute

fall_back_to_pt: bool = True

Whether .pt weights can be used.

model_or_path instance-attribute

model_or_path: str

The model ID or path.

prefix class-attribute instance-attribute

prefix: str = ''

A prefix to prepend to all weights.

revision instance-attribute

revision: Optional[str]

The optional model revision.

__init__

__init__(
    model_or_path: str,
    revision: Optional[str],
    prefix: str = "",
    fall_back_to_pt: bool = True,
    allow_patterns_overrides: Optional[list[str]] = None,
) -> None
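
A hedged construction sketch (the paths and prefixes below are placeholders, not values from vLLM):

# Primary source: a Hub model ID, keeping weight names as-is.
primary = DefaultModelLoader.Source(
    model_or_path="mistralai/Mistral-7B-Instruct-v0.3",
    revision=None,
    prefix="",
    fall_back_to_pt=True,  # allow *.pt if no safetensors/bin files match
)

# Secondary source: a local checkpoint whose names get a prefix, loaded
# exclusively via the given glob patterns.
vision_tower = DefaultModelLoader.Source(
    model_or_path="/path/to/vision/checkpoint",  # placeholder path
    revision=None,
    prefix="vision_tower.",                      # prepended to every weight name
    allow_patterns_overrides=["*.safetensors"],
)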

__init__

__init__(load_config: LoadConfig)
Source code in vllm/model_executor/model_loader/default_loader.py
def __init__(self, load_config: LoadConfig):
    super().__init__(load_config)

    extra_config = load_config.model_loader_extra_config
    allowed_keys = {"enable_multithread_load", "num_threads"}
    unexpected_keys = set(extra_config.keys()) - allowed_keys

    if unexpected_keys:
        raise ValueError(f"Unexpected extra config keys for load format "
                         f"{load_config.load_format}: "
                         f"{unexpected_keys}")

_get_weights_iterator

_get_weights_iterator(
    source: Source,
) -> Generator[tuple[str, Tensor], None, None]

Get an iterator for the model weights based on the load format.

Source code in vllm/model_executor/model_loader/default_loader.py
def _get_weights_iterator(
        self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Get an iterator for the model weights based on the load format."""
    extra_config = self.load_config.model_loader_extra_config
    hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
        source.model_or_path, source.revision, source.fall_back_to_pt,
        source.allow_patterns_overrides)
    if self.load_config.load_format == "npcache":
        # Currently np_cache only supports *.bin checkpoints
        assert use_safetensors is False
        weights_iterator = np_cache_weights_iterator(
            source.model_or_path,
            self.load_config.download_dir,
            hf_folder,
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )
    elif use_safetensors:
        if self.load_config.load_format == "fastsafetensors":
            weights_iterator = fastsafetensors_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = (
                    multi_thread_safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        max_workers=extra_config.get(
                            "num_threads", self.DEFAULT_NUM_THREADS),
                    ))
            else:
                weights_iterator = safetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.safetensors_load_strategy,
                )
    else:
        if extra_config.get("enable_multithread_load"):
            weights_iterator = multi_thread_pt_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
                self.load_config.pt_load_map_location,
                max_workers=extra_config.get("num_threads",
                                             self.DEFAULT_NUM_THREADS),
            )
        else:
            weights_iterator = pt_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
                self.load_config.pt_load_map_location,
            )

    if current_platform.is_tpu():
        from vllm.platforms.tpu import USE_TPU_COMMONS

        if not USE_TPU_COMMONS:
            # In PyTorch XLA, we should call `torch_xla.sync`
            # frequently so that not too many ops are accumulated
            # in the XLA program.
            import torch_xla

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    torch_xla.sync(wait=False)

            weights_iterator = _xla_weights_iterator(weights_iterator)

    if self.counter_before_loading_weights == 0.0:
        self.counter_before_loading_weights = time.perf_counter()
    # Apply the prefix.
    return ((source.prefix + name, tensor)
            for (name, tensor) in weights_iterator)
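
The generator yields `(name, tensor)` pairs with `source.prefix` prepended to each name. A hedged sketch of consuming it (`_get_weights_iterator` is private, so this is illustrative only; `loader` is the instance from the `__init__` example):

source = DefaultModelLoader.Source(
    "facebook/opt-125m", revision=None, prefix="decoder.")
for name, tensor in loader._get_weights_iterator(source):
    # names arrive prefixed, e.g. "decoder.model.embed_tokens.weight"
    print(name, tuple(tensor.shape))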

_prepare_weights

_prepare_weights(
    model_name_or_path: str,
    revision: Optional[str],
    fall_back_to_pt: bool,
    allow_patterns_overrides: Optional[list[str]],
) -> tuple[str, list[str], bool]

Prepare weights for the model.

If the model is not local, it will be downloaded.

Source code in vllm/model_executor/model_loader/default_loader.py
def _prepare_weights(
    self,
    model_name_or_path: str,
    revision: Optional[str],
    fall_back_to_pt: bool,
    allow_patterns_overrides: Optional[list[str]],
) -> tuple[str, list[str], bool]:
    """Prepare weights for the model.

    If the model is not local, it will be downloaded."""
    model_name_or_path = (maybe_download_from_modelscope(
        model_name_or_path, revision) or model_name_or_path)

    is_local = os.path.isdir(model_name_or_path)
    load_format = self.load_config.load_format
    use_safetensors = False
    index_file = SAFE_WEIGHTS_INDEX_NAME
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif (load_format == "safetensors"
          or load_format == "fastsafetensors"):
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "mistral":
        use_safetensors = True
        allow_patterns = ["consolidated*.safetensors"]
        index_file = "consolidated.safetensors.index.json"
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    if allow_patterns_overrides is not None:
        allow_patterns = allow_patterns_overrides

    if not is_local:
        hf_folder = download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )
    else:
        hf_folder = model_name_or_path

    hf_weights_files: list[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break

    if use_safetensors:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                self.load_config.download_dir,
                revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file)
    else:
        hf_weights_files = filter_files_not_needed_for_inference(
            hf_weights_files)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
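
For reference, the `load_format` to glob-pattern mapping that the branch above implements, summarized as a sketch (not actual vLLM code):

_PATTERNS = {
    "auto": ["*.safetensors", "*.bin"],        # safetensors preferred
    "safetensors": ["*.safetensors"],
    "fastsafetensors": ["*.safetensors"],
    "mistral": ["consolidated*.safetensors"],  # with consolidated index file
    "pt": ["*.pt"],
    "npcache": ["*.bin"],
}
# With fall_back_to_pt=True, "*.pt" is appended; a non-None
# allow_patterns_overrides replaces the list entirely.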

download_model

download_model(model_config: ModelConfig) -> None
Source code in vllm/model_executor/model_loader/default_loader.py
def download_model(self, model_config: ModelConfig) -> None:
    self._prepare_weights(model_config.model,
                          model_config.revision,
                          fall_back_to_pt=True,
                          allow_patterns_overrides=None)
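
`download_model` resolves and downloads the checkpoint files without loading any tensors, so it can serve as a pre-fetch step. A hedged one-liner (assuming the `loader` and `model_config` from earlier sketches):

loader.download_model(model_config)  # populates download_dir, loads nothing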

get_all_weights

get_all_weights(
    model_config: ModelConfig, model: Module
) -> Generator[tuple[str, Tensor], None, None]
Source code in vllm/model_executor/model_loader/default_loader.py
def get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    primary_weights = DefaultModelLoader.Source(
        model_config.model,
        model_config.revision,
        prefix="",
        fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                True),
        allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                         None),
    )
    yield from self._get_weights_iterator(primary_weights)

    secondary_weights = cast(
        Iterable[DefaultModelLoader.Source],
        getattr(model, "secondary_weights", ()),
    )
    for source in secondary_weights:
        yield from self._get_weights_iterator(source)
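
Models opt into this mechanism through plain attributes read via `getattr`. A hedged sketch of a model exposing both hooks (the class and paths are hypothetical):

import torch.nn as nn

class MyMultimodalModel(nn.Module):
    # Primary source: never fall back to *.pt files.
    fall_back_to_pt_during_load = False
    # Extra checkpoints streamed after the primary weights.
    secondary_weights = [
        DefaultModelLoader.Source(
            "/path/to/adapter",  # placeholder path
            revision=None,
            prefix="adapter.",
        ),
    ]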

load_weights

load_weights(
    model: Module, model_config: ModelConfig
) -> None
Source code in vllm/model_executor/model_loader/default_loader.py
def load_weights(self, model: nn.Module,
                 model_config: ModelConfig) -> None:
    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights(
        self.get_all_weights(model_config, model))
    self.counter_after_loading_weights = time.perf_counter()
    logger.info(
        "Loading weights took %.2f seconds",
        self.counter_after_loading_weights -
        self.counter_before_loading_weights)
    # Currently, we only enable the strict check for non-quantized
    # models that track loaded weights.
    if model_config.quantization is None and loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError("Following weights were not initialized from "
                             f"checkpoint: {weights_not_loaded}")