vllm.model_executor.models.vision

VisionFeatureSelectStrategy module-attribute

VisionFeatureSelectStrategy = Union[
    Literal["class", "default", "full"],
    Callable[[Tensor], Tensor],
]

_C module-attribute

_C = TypeVar('_C', bound=PretrainedConfig)

logger module-attribute

logger = init_logger(__name__)

VisionEncoderInfo

Bases: ABC, Generic[_C]

Source code in vllm/model_executor/models/vision.py
class VisionEncoderInfo(ABC, Generic[_C]):

    def __init__(self, hf_config: _C) -> None:
        super().__init__()

        self.hf_config = hf_config
        self.vision_config = hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        raise NotImplementedError

hf_config instance-attribute

hf_config = hf_config

vision_config instance-attribute

vision_config = vision_config

__init__

__init__(hf_config: _C) -> None
Source code in vllm/model_executor/models/vision.py
def __init__(self, hf_config: _C) -> None:
    super().__init__()

    self.hf_config = hf_config
    self.vision_config = hf_config.vision_config

get_image_size abstractmethod

get_image_size() -> int
Source code in vllm/model_executor/models/vision.py
@abstractmethod
def get_image_size(self) -> int:
    raise NotImplementedError

get_num_image_tokens abstractmethod

get_num_image_tokens(
    *, image_width: int, image_height: int
) -> int
Source code in vllm/model_executor/models/vision.py
@abstractmethod
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    raise NotImplementedError

get_patch_grid_length abstractmethod

get_patch_grid_length() -> int
Source code in vllm/model_executor/models/vision.py
@abstractmethod
def get_patch_grid_length(self) -> int:
    raise NotImplementedError

get_patch_size abstractmethod

get_patch_size() -> int
Source code in vllm/model_executor/models/vision.py
@abstractmethod
def get_patch_size(self) -> int:
    raise NotImplementedError
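
For orientation, a minimal sketch of a concrete subclass (the MyEncoderInfo name is hypothetical; it assumes a CLIP-style fixed-resolution vision_config that exposes image_size and patch_size, similar to the real implementations in clip.py and siglip.py):

```
from transformers import PretrainedConfig

from vllm.model_executor.models.vision import VisionEncoderInfo


class MyEncoderInfo(VisionEncoderInfo[PretrainedConfig]):
    """Hypothetical example, not part of vLLM."""

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        # Number of patches along one side of the (square) input image.
        return self.get_image_size() // self.get_patch_size()

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # A fixed-resolution encoder resizes every image to the same grid,
        # so the token count does not depend on the input dimensions.
        return self.get_patch_grid_length() ** 2
```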

VisionLanguageConfig

Bases: Protocol

Source code in vllm/model_executor/models/vision.py
class VisionLanguageConfig(Protocol):
    vision_config: Final[PretrainedConfig]

vision_config instance-attribute

vision_config: Final[PretrainedConfig]

_get_vision_feature_selector

_get_vision_feature_selector(
    strategy: VisionFeatureSelectStrategy,
) -> Callable[[Tensor], Tensor]
Source code in vllm/model_executor/models/vision.py
def _get_vision_feature_selector(
    strategy: VisionFeatureSelectStrategy,
) -> Callable[[torch.Tensor], torch.Tensor]:
    if callable(strategy):
        return strategy

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":
        return lambda feats: feats[:, 0, :]

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":
        return lambda feats: feats[:, 1:, :]

    if strategy == "full":
        return lambda feats: feats

    assert_never(strategy)
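
For illustration, a short sketch of how the three literal strategies behave on a dummy [batch, 1 + num_patches, hidden] feature tensor (the shapes below are invented for the example):

```
import torch

feats = torch.randn(2, 577, 1024)  # 1 CLS token + 576 patch tokens

# "class":   keep only the CLS embedding       -> shape (2, 1024)
# "default": drop the CLS token, keep patches  -> shape (2, 576, 1024)
# "full":    keep everything                   -> shape (2, 577, 1024)
assert _get_vision_feature_selector("class")(feats).shape == (2, 1024)
assert _get_vision_feature_selector("default")(feats).shape == (2, 576, 1024)
assert _get_vision_feature_selector("full")(feats).shape == (2, 577, 1024)

# A custom callable is returned unchanged, e.g. mean-pooling the patch tokens:
pool = _get_vision_feature_selector(lambda f: f[:, 1:, :].mean(dim=1))
assert pool(feats).shape == (2, 1024)
```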

get_load_balance_assignment

get_load_balance_assignment(
    sizes: list[int], num_gpus: int = 2
) -> tuple[list[int], list[int], list[int]]

Generate load balancing assignment and metadata for distributing data across GPUs. The load is determined by the total image sizes, not the number of images.

Parameters:

sizes (list[int], required): The size of each image.
num_gpus (int, default: 2): Number of GPUs to balance across.

Returns:

shuffle_indices (list[int]): Indices to reorder data for balanced loading.
gpu_sample_counts (list[int]): Number of samples assigned to each GPU.
grouped_sizes_per_gpu (list[int]): Total size assigned to each GPU.

Example
sizes = [1000, 100, 200, 50]
num_gpus=2
Source code in vllm/model_executor/models/vision.py
def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Generate load balancing assignment and metadata 
    for distributing data across GPUs.
    The load is determined by the total image sizes,
    not the number of images.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across

    Returns:
        shuffle_indices: 
            Indices to reorder data for balanced loading
        gpu_sample_counts: 
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu: 
            Total size assigned to each GPU

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus=2
        ```

    """

    n_samples = len(sizes)

    # Handle edge cases
    if n_samples == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    # Use greedy algorithm - balance by total size, not sample count
    gpu_assignments = [list[int]() for _ in range(num_gpus)]
    gpu_loads = [0] * num_gpus  # This tracks total SIZE, not sample count

    # Sort indices by size (largest first for better load balancing)
    # sizes = [1000, 100, 200, 50]
    # large_to_small_indices = [0, 2, 1, 3]
    large_to_small_indices = sorted(range(n_samples),
                                    key=lambda i: sizes[i],
                                    reverse=True)

    for idx in large_to_small_indices:
        # Find GPU with minimum current load (by total size)
        min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
        gpu_assignments[min_gpu].append(idx)
        gpu_loads[min_gpu] += sizes[idx]

    # Create shuffle indices and counts
    shuffle_indices = list[int]()
    gpu_sample_counts = list[int]()
    for gpu_id in range(num_gpus):
        # GPU_0 = [1000] = [0]
        # GPU_1 = [200, 100, 50] = [2, 1, 3]
        # shuffle_indices = [0, 2, 1, 3]
        shuffle_indices.extend(gpu_assignments[gpu_id])
        # GPU_0 = [1]
        # GPU_1 = [3]
        # gpu_sample_counts = [1, 3]
        gpu_sample_counts.append(len(gpu_assignments[gpu_id]))

    return (shuffle_indices, gpu_sample_counts, gpu_loads)
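
The docstring example worked through (a sketch; the expected values follow from the greedy largest-first assignment traced in the inline comments above):

```
sizes = [1000, 100, 200, 50]

shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu = (
    get_load_balance_assignment(sizes, num_gpus=2))

# The largest image (index 0) fills GPU 0; the remaining three images
# (indices 2, 1, 3 in largest-first order) all land on the lighter GPU 1.
assert shuffle_indices == [0, 2, 1, 3]
assert gpu_sample_counts == [1, 3]
assert grouped_sizes_per_gpu == [1000, 350]
```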

get_vision_encoder_info

get_vision_encoder_info(
    hf_config: VisionLanguageConfig,
) -> VisionEncoderInfo
Source code in vllm/model_executor/models/vision.py
def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    if isinstance(hf_config.vision_config, CLIPVisionConfig):
        return CLIPEncoderInfo(hf_config)
    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFEncoderInfo(hf_config)
    if isinstance(hf_config.vision_config, SiglipVisionConfig):
        return SiglipEncoderInfo(hf_config)

    msg = f"Unsupported vision config: {type(hf_config.vision_config)}"
    raise NotImplementedError(msg)
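
A hedged usage sketch, assuming a checkpoint whose vision tower uses one of the supported CLIP, Pixtral, or SigLIP configs (the model name below is only an example):

```
from transformers import AutoConfig

# e.g. a LLaVA-1.5 checkpoint, whose vision_config is a CLIPVisionConfig
hf_config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")

encoder_info = get_vision_encoder_info(hf_config)
num_tokens = encoder_info.get_num_image_tokens(
    image_width=336,
    image_height=336,
)
```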

get_vit_attn_backend

get_vit_attn_backend(
    head_size: int, dtype: dtype
) -> _Backend

Get the available attention backend for Vision Transformer.

Source code in vllm/model_executor/models/vision.py
def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
    """
    Get the available attention backend for Vision Transformer.
    """
    # Lazy import to avoid circular dependency
    from vllm.attention.selector import get_env_variable_attn_backend

    selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
    if selected_backend is not None:
        return selected_backend

    return current_platform.get_vit_attn_backend(head_size, dtype)

resolve_visual_encoder_outputs

resolve_visual_encoder_outputs(
    encoder_outputs: Union[Tensor, list[Tensor]],
    post_layer_norm: Optional[LayerNorm],
    *,
    select_layers: Optional[list[int]] = None,
    max_possible_layers: Optional[int] = None,
    feature_select_strategy: Optional[
        VisionFeatureSelectStrategy
    ] = None,
) -> Tensor

Given the outputs of a visual encoder module, which may be the output of the last layer or a list of hidden states to be stacked, handle post normalization and resolve them into a single output tensor.

Parameters:

encoder_outputs (Union[Tensor, list[Tensor]], required): Output of encoder's last layer or all hidden states.
post_layer_norm (Optional[LayerNorm], required): Post norm to apply to the output of the encoder.
select_layers (Optional[list[int]], default: None): Optional layer indices to grab from the encoder outputs; if provided, encoder outputs must be a list.
max_possible_layers (Optional[int], default: None): Total layers in the fully loaded visual encoder.
feature_select_strategy (Optional[VisionFeatureSelectStrategy], default: None): Defines how to select the hidden states from each layer.

Source code in vllm/model_executor/models/vision.py
def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    *,
    select_layers: Optional[list[int]] = None,
    max_possible_layers: Optional[int] = None,
    feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
    """Given the outputs a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        post_layer_norm: Post norm to apply to the output of the encoder.
        select_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        max_possible_layers: Total layers in the fully loaded visual encoder.
        feature_select_strategy: Defines how to select the hidden states
            from each layer.
    """
    if select_layers is None:
        if not isinstance(encoder_outputs, torch.Tensor):
            raise ValueError("Expected only a single encoder output when "
                             "`select_layers` is not provided")

        if feature_select_strategy is not None:
            select_features = _get_vision_feature_selector(
                feature_select_strategy)
            encoder_outputs = select_features(encoder_outputs)

        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)

        return encoder_outputs

    if max_possible_layers is None:
        raise ValueError("`max_possible_layers` must be provided "
                         "alongside `select_layers`")

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in select_layers
    ]

    if feature_select_strategy is not None:
        select_features = _get_vision_feature_selector(feature_select_strategy)
        hs_pool = [select_features(hs) for hs in hs_pool]

    # Apply post-norm on the final hidden state if we are using it
    uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
    if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(hs_pool[-1])

    return torch.cat(hs_pool, dim=-1)
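
To make the negative-index offset concrete, a sketch with invented shapes: a notionally 24-layer encoder from which only 23 layers were loaded (the final layer may be skipped when only an earlier feature layer is needed), so encoder_outputs holds the input embeddings plus 23 hidden states:

```
import torch

hidden_states = [torch.randn(1, 577, 1024) for _ in range(1 + 23)]

feats = resolve_visual_encoder_outputs(
    hidden_states,
    post_layer_norm=None,
    select_layers=[-2],              # relative to the full 24-layer encoder
    max_possible_layers=24,
    feature_select_strategy="default",
)

# offset = 24 - 23 = 1, so index -2 resolves to hidden_states[-1], the last
# loaded layer; "default" then drops the CLS token from its 577 tokens.
assert feats.shape == (1, 576, 1024)
```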

run_dp_sharded_mrope_vision_model

run_dp_sharded_mrope_vision_model(
    vision_model: Module,
    pixel_values: Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[Tensor, ...]

Run a vision model with data parallelism (DP) sharding. The function shards the input image tensor along the first dimension, runs the vision model on each shard, and gathers the results; this variant is used for vision models that use mrope.

Parameters:

vision_model (Module, required): Vision model.
pixel_values (Tensor, required): Image/video input tensor.
grid_thw_list (list[list[int]], required): List of grid dimensions for each image.
rope_type (Literal["rope_3d", "rope_2d"], required): Type of rope used in the vision model. Different rope types use different dimensions in the ViT: "rope_3d" for 3D rope (e.g., Qwen2.5-VL), "rope_2d" for 2D rope (e.g., Kimi-VL).

Returns:

tuple[torch.Tensor, ...]: Output image embeddings, one tensor per input image, in the original input order.

Example

vision_model.out_hidden_size = 64
vision_model.spatial_merge_size = 2
pixel_values.shape = (1350, channel)
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
tp_size = 2
Source code in vllm/model_executor/models/vision.py
def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
    """Run a vision model with data parallelism (DP) sharding. 
    The function will shard the input image tensor on the 
    first dimension and run the vision model.
    This function is used to run the vision model with mrope.

    Args:
        vision_model (torch.nn.Module): Vision model.
        pixel_values (torch.Tensor): Image/Video input tensor.
        grid_thw_list: List of grid dimensions for each image
        rope_type: Type of rope used in the vision model.
                   Different rope types use different dimensions in the ViT.
                   "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
                   "rope_2d" for 2D rope (e.g., Kimi-VL)
    Returns:
        tuple[torch.Tensor, ...]: Output image embeddings, one per image

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
        tp_size=2
        ```

    """
    tp_size = get_tensor_model_parallel_world_size()

    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata
    # image_to_tp_rank = [0, 2, 1, 3]
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts,
     grouped_pixel_values_len) = get_load_balance_assignment(
         patches_per_image, tp_size)

    # cu_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
                                        cum_gpu_sample_counts[tp_rank_local +
                                                              1]]

    # Get the pixel values for the local images based on the image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat([
            pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
            for i in image_idxs_local
        ])
    else:
        # Handle case where this rank has no images
        pixel_values_local = torch.empty((0, pixel_values.shape[1]),
                                         device=pixel_values.device,
                                         dtype=pixel_values.dtype)
    # embed_dim_reduction_factor = 2 * 2
    if rope_type == "rope_2d":
        embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
                                      vision_model.merge_kernel_size[1])
    else:
        embed_dim_reduction_factor = (vision_model.spatial_merge_size *
                                      vision_model.spatial_merge_size)

    # Find the max length across all ranks
    # The output embedding of every DP rank has to be
    # padded to this length for tensor_model_parallel_all_gather
    # to work
    max_len_per_rank = max(
        grouped_pixel_values_len) // embed_dim_reduction_factor
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if rope_type == "rope_2d":
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(
                pixel_values_local, torch.tensor(local_grid_thw_list))
            if isinstance(image_embeds_local, list):
                image_embeds_local = torch.cat(image_embeds_local, dim=0)
        else:
            out_dim = getattr(vision_model.config, "hidden_size", None)
            image_embeds_local = torch.empty(
                (0, embed_dim_reduction_factor, out_dim),
                device=pixel_values.device,
                dtype=pixel_values.dtype)
    else:
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(pixel_values_local,
                                              local_grid_thw_list)
        else:
            # Handle empty case
            image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
                                             device=pixel_values.device,
                                             dtype=pixel_values.dtype)

    # Pad the output based on max_len_per_rank
    # for tensor_model_parallel_all_gather to work
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding_size = max_len_per_rank - current_len
        if rope_type == "rope_2d":
            padding = torch.empty((padding_size, image_embeds_local.shape[1],
                                   image_embeds_local.shape[2]),
                                  dtype=image_embeds_local.dtype,
                                  device=image_embeds_local.device)
        else:
            padding = torch.empty((padding_size, image_embeds_local.shape[1]),
                                  dtype=image_embeds_local.dtype,
                                  device=image_embeds_local.device)
        image_embeds_local_padded = torch.cat([image_embeds_local, padding],
                                              dim=0)
    else:
        image_embeds_local_padded = image_embeds_local

    # Do all_gather to collect embeddings from all ranks
    gathered_embeds = tensor_model_parallel_all_gather(
        image_embeds_local_padded, dim=0)

    # Remove padding and reconstruct per-rank embeddings
    rank_embeddings = list[torch.Tensor]()
    for rank in range(tp_size):
        start_idx = rank * max_len_per_rank
        end_idx = start_idx + (grouped_pixel_values_len[rank] //
                               embed_dim_reduction_factor)
        rank_embeddings.append(gathered_embeds[start_idx:end_idx])

    patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
                                for patch_size in patches_per_image]

    # Reconstruct embeddings in the original order
    original_order_embeddings = [None] * len(grid_thw_list)
    current_idx = 0
    for rank in range(tp_size):
        count = gpu_sample_counts[rank]
        if count > 0:
            # Get images assigned to this rank in shuffled order
            # GPU_0 = image_idxs_local  [0]
            # GPU_1 = image_idxs_local  [2, 1, 3]
            rank_images = image_to_tp_rank[current_idx:current_idx + count]

            rank_embed = rank_embeddings[rank]
            # Split rank embeddings back to individual images
            embed_start = 0
            for img_idx in rank_images:
                img_patches = patches_per_output_image[img_idx]
                original_order_embeddings[img_idx] = rank_embed[
                    embed_start:embed_start + img_patches]
                embed_start += img_patches
            current_idx += count
    out_embeddings = tuple(embed for embed in original_order_embeddings
                           if embed is not None)
    assert len(out_embeddings) == len(
        original_order_embeddings), "Found unassigned embeddings"
    return out_embeddings
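
The sharding bookkeeping from the docstring example can be reproduced without a distributed setup (a sketch of the index math only; the real function additionally runs the vision model on each rank and all-gathers the padded embeddings):

```
import itertools
import math

grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
spatial_merge_size = 2
tp_size = 2

patches_per_image = [math.prod(g) for g in grid_thw_list]    # [1000, 100, 200, 50]
cum_patches = [0, *itertools.accumulate(patches_per_image)]  # [0, 1000, 1100, 1300, 1350]

image_to_tp_rank, gpu_sample_counts, grouped_len = (
    get_load_balance_assignment(patches_per_image, tp_size))
# image_to_tp_rank = [0, 2, 1, 3], gpu_sample_counts = [1, 3],
# grouped_len = [1000, 350]

embed_dim_reduction_factor = spatial_merge_size * spatial_merge_size  # 4
max_len_per_rank = max(grouped_len) // embed_dim_reduction_factor     # 250
# Rank 0 emits 1000 // 4 = 250 output rows (no padding needed); rank 1's
# 350 // 4 = 87 rows are padded up to 250 before the all_gather.
```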

run_dp_sharded_vision_model

run_dp_sharded_vision_model(
    image_input: Tensor, vision_model: Module
) -> Tensor

Run a vision model with data parallelism (DP) sharding. The function shards the input image tensor along the first dimension and runs the vision model on each shard.

Parameters:

image_input (Tensor, required): Image input tensor.
vision_model (Module, required): Vision model.

Returns:

torch.Tensor: Output image embeddings.

Source code in vllm/model_executor/models/vision.py
def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding. The function 
    will shard the input image tensor on the first dimension and run the vision
    model.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.
    Returns:
        torch.Tensor: Output image embeddings
    """

    num_chunks = image_input.shape[0]
    mp_world_size = get_tensor_model_parallel_world_size()
    num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
    num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
    pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
    image_input_padded = torch.nn.functional.pad(image_input, pad)
    rank = get_tensor_model_parallel_rank()
    image_input_per_rank = image_input_padded[rank *
                                              num_chunks_per_rank:(rank + 1) *
                                              num_chunks_per_rank, ...]

    vision_embeddings = vision_model(image_input_per_rank)
    # Ensure tensor is contiguous before all_gather
    vision_embeddings = vision_embeddings.contiguous()
    vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
                                                         dim=0)
    vision_embeddings = vision_embeddings[:num_chunks, ...]
    return vision_embeddings
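
The chunk-padding arithmetic is easier to see with concrete numbers (a sketch; an actual call additionally requires an initialized tensor-parallel group and a real vision model):

```
import torch

num_chunks, mp_world_size = 5, 2
image_input = torch.randn(num_chunks, 3, 336, 336)

num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size  # 3
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks     # 1

# F.pad specifies padding from the last dimension backwards, so the leading
# zeros leave the trailing dims alone and (0, 1) pads only the chunk dim.
pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
image_input_padded = torch.nn.functional.pad(image_input, pad)

assert image_input_padded.shape[0] == 6  # rank 0 takes chunks [0:3], rank 1 [3:6]
```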