vllm.multimodal.evs

compute_mrope_for_media

compute_mrope_for_media(
    video_size_thw: LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> Tensor

Computes the mrope for video embeddings based on the grid dimensions. Computed mrope positions match original qwen 2.5 implementation, but positions are built for media being the first element in sequence.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| video_size_thw | LongTensor | Media size (num frames, rows, cols). | required |
| spatial_merge_size | int | Size reduction for rows & cols dimensions. | required |
| tokens_per_second | float | Number of tokens per second. | 1.0 |
| video_second_per_grid | float | Number of seconds per temporal grid step of the video. | 1.0 |

Returns:

| Type | Description |
| ---- | ----------- |
| Tensor | Tensor of shape (T * H * W, 4): channels 0-2 hold the t/h/w mrope positions, while the last channel contains the value of llm_grid_w repeated for all positions. |

Source code in vllm/multimodal/evs.py
def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Computes the mrope for video embeddings based on the grid dimensions.
    Computed mrope positions match original qwen 2.5 implementation,
    but positions are built for media being the first element in sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols)
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per video.

    Returns:
        Tensor of shape `(T * H * W, 4)` where last dimension
        represents mrope positions [0:3), while the last channel
        contains value of llm_grid_w repeated for all positions.
    """
    llm_grid_t = video_size_thw[0]
    llm_grid_h = video_size_thw[1] // spatial_merge_size
    llm_grid_w = video_size_thw[2] // spatial_merge_size

    t_index = ((torch.arange(llm_grid_t).view(-1, 1).expand(
        -1, llm_grid_h * llm_grid_w).mul(
            tokens_per_second * video_second_per_grid)).long().flatten())
    h_index = (torch.arange(llm_grid_h).view(1, -1,
                                             1).expand(llm_grid_t, -1,
                                                       llm_grid_w).flatten())
    w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
        llm_grid_t, llm_grid_h, -1).flatten())
    llm_grid_w = (torch.tensor([llm_grid_w
                                ]).view(1, 1,
                                        1).expand(llm_grid_t, llm_grid_h,
                                                  llm_grid_w).flatten())

    positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
    return positions
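
A minimal usage sketch with hypothetical sizes, assuming the helper is importable from vllm.multimodal.evs as listed above: for a 4-frame video with a 28 x 28 patch grid and spatial_merge_size=2, the call below yields one row of positions per merged media token.

import torch

from vllm.multimodal.evs import compute_mrope_for_media

# Hypothetical 4-frame video whose vision tower produced a 28 x 28 patch grid.
video_size_thw = torch.tensor([4, 28, 28], dtype=torch.long)

positions = compute_mrope_for_media(
    video_size_thw,
    spatial_merge_size=2,
    tokens_per_second=1.0,
    video_second_per_grid=1.0,
)

# 4 frames * (28 // 2) * (28 // 2) = 784 merged media tokens, 4 channels each.
assert positions.shape == (4 * 14 * 14, 4)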

compute_retained_tokens_count

compute_retained_tokens_count(
    video_size_thw: LongTensor,
    spatial_merge_size: int,
    q: float,
) -> int

Compute the number of retained tokens for a given video. The method ensures that all tokens from the first frame are retained regardless of the pruning rate.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| video_size_thw | LongTensor | The size of the video in the format of (T, H, W). | required |
| spatial_merge_size | int | The size of the spatial merge. | required |
| q | float | The pruning rate. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| int | The number of retained tokens. |

Source code in vllm/multimodal/evs.py
def compute_retained_tokens_count(video_size_thw: torch.LongTensor,
                                  spatial_merge_size: int, q: float) -> int:
    """
    Compute the number of retained tokens for a given video.
    Method ensures that we retain all the tokens from the first frame
    regardless of the pruning rate.

    Args:
        video_size_thw: The size of the video in the format of (T, H, W).
        spatial_merge_size: The size of the spatial merge.
        q: The pruning rate.

    Returns:
        The number of retained tokens.
    """
    T, H, W = map(int, video_size_thw)
    min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
    evs_num_tokens = int(T * min_num_tokens * (1 - q))
    return max(min_num_tokens, evs_num_tokens)
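
A worked example with hypothetical sizes, assuming the helper is importable from vllm.multimodal.evs: for an 8-frame video with a 28 x 28 patch grid, spatial_merge_size=2 and q=0.75, one merged frame has 14 * 14 = 196 tokens, EVS would keep int(8 * 196 * 0.25) = 392 tokens, and the larger of the two counts is returned.

import torch

from vllm.multimodal.evs import compute_retained_tokens_count

retained = compute_retained_tokens_count(
    video_size_thw=torch.tensor([8, 28, 28], dtype=torch.long),
    spatial_merge_size=2,
    q=0.75,
)
# max(196, int(8 * 196 * (1 - 0.75))) == max(196, 392)
assert retained == 392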

compute_retention_mask

compute_retention_mask(
    video_embeds: Tensor,
    video_size_thw: LongTensor,
    spatial_merge_size: int,
    q: float,
) -> Tensor

Computes the retention mask for input video embeddings.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| video_embeds | Tensor | The input video embeddings of shape (T * H * W // spatial_merge_size ^ 2, hidden_size). | required |
| video_size_thw | LongTensor of shape (3,) | The temporal, height and width of the video. | required |
| spatial_merge_size | int | Size reduction for rows & cols dimensions. | required |
| q | float | Pruning rate factor in [0, 1). | required |

Returns:

| Type | Description |
| ---- | ----------- |
| Tensor | The retention mask for the video embeddings, of shape (T * H * W // spatial_merge_size ^ 2). |

Source code in vllm/multimodal/evs.py
def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    q: float,
) -> torch.Tensor:
    """
    Computes the retention mask for input video embeddings.

    Args:
        video_embeds (`torch.Tensor`): The input video embeddings
            of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
        video_size_thw (`torch.LongTensor` of shape `(3)`):
            The temporal, height and width of video.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        q: (`float`): Pruning rate factor [0,1)

    Returns:
        `torch.Tensor`: The retention mask for the video embeddings of
            `(T * H * W // spatial_merge_size ^ 2)` shape.
    """
    T, H, W = video_size_thw

    # Use reshape instead of einops to avoid graph breaks
    video_embeds = video_embeds.reshape(
        T,
        H // spatial_merge_size,
        W // spatial_merge_size,
        video_embeds.size(-1),
    )

    # Core EVS
    similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...],
                                                       video_embeds[:-1, ...],
                                                       dim=-1)
    dissimilarity = 1 - similarity

    # Always ensure we include all tokens from the first frame
    dissimilarity = torch.cat(
        [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity],
        dim=0)

    dissimilarity_flat = dissimilarity.view(-1)
    order = torch.argsort(dissimilarity_flat,
                          dim=-1,
                          descending=True,
                          stable=True)
    retain_num_tokens = compute_retained_tokens_count(video_size_thw,
                                                      spatial_merge_size, q)
    topk_indices = order[:retain_num_tokens]

    retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
    retention_mask[topk_indices] = True
    retention_mask = retention_mask.reshape(dissimilarity.size())

    mask = retention_mask.view(-1)  # "T H W -> (T H W)"
    return mask
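
A minimal sketch with random (hypothetical) embeddings, one per merged patch, assuming the helper is importable from vllm.multimodal.evs: prune half of the tokens of a 4-frame video and index the embeddings with the resulting mask.

import torch

from vllm.multimodal.evs import compute_retention_mask

T, H, W, hidden = 4, 28, 28, 64
video_size_thw = torch.tensor([T, H, W], dtype=torch.long)
# One embedding per merged patch: (4 * 14 * 14, hidden) = (784, 64).
video_embeds = torch.randn(T * (H // 2) * (W // 2), hidden)

mask = compute_retention_mask(video_embeds,
                              video_size_thw,
                              spatial_merge_size=2,
                              q=0.5)
# q=0.5 keeps int(784 * 0.5) = 392 of the most dissimilar tokens,
# and the whole first frame is always among them.
pruned_embeds = video_embeds[mask]
assert mask.shape == (784,)
assert pruned_embeds.shape == (392, 64)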

recompute_mrope_positions

recompute_mrope_positions(
    input_ids: LongTensor,
    multimodal_positions: list[Tensor],
    mrope_positions: LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[LongTensor, int]

Updates part of the input mrope positions. The original mrope_positions are computed without accounting for media token pruning, so once media tokens are pruned we must reflect this in the mrope positions passed to the LLM.

This method supports the chunked-prefill approach, in which multimodal embeddings are passed to the LLM in chunks, so the input multimodal_embeddings may contain none, some, or only part of the multimodal embeddings for a given prompt.

Each entry of multimodal_positions has 4 channels: the first 3 channels correspond to the original 3 mrope positions, and the last channel holds the maximum width of the media, repeated for all positions. The provided multimodal_positions do not reflect the location of the media in the sequence; they are computed as if the media were at position 0.

The method works as follows: it recomputes mrope_positions starting from num_computed_tokens for the total length of the multimodal embeddings, and then shifts all text tokens that come after them.

It also handles the case where the multimodal embeddings are partial (e.g. one media item is split across two prefill stages).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| input_ids | LongTensor | (N,) All input tokens of the prompt (entire sequence). | required |
| multimodal_positions | list[Tensor] | List of mrope positions for each media item. | required |
| mrope_positions | LongTensor | Existing mrope positions (3, N) for the entire sequence. | required |
| num_computed_tokens | int | The number of tokens computed so far. | required |
| vision_start_token_id | int | Token id indicating the start of vision media. | required |
| image_token_id | int | Image token id. | required |
| video_token_id | int | Video token id. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| tuple[LongTensor, int] | Tuple of (mrope_positions, mrope_position_delta). |

Source code in vllm/multimodal/evs.py
def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.
    Original mrope_positions are computed incorrectly, so once we prune media
    tokens we should reflect this in the mrope positions for the LLM.

    This method supports chunked prefill approach where
    multimodal_embeddings are passed to LLM in chunks, so input
    multimodal_embeddings may contain zero, some or even some part of all
    multimodal_embeddings for a given prompt.

    Each multimodal_positions has 4 extra channels
    (First 3 channels corresponds to original 3 mrope positions, last channel
    is the maximum width of the media repeated). Provided multimodal_positions
    do not reflect location of media position in sequence - they are computed
    like the media is in the 0-th position in the sequence.

    Method works as follows: it recomputes mrope_positions starting from the
    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
    shifts all text tokens that goes after total_len_of_multimodal_embeddings.

    It also handles case when multimodal_embeddings is partial
    (e.g. one media is split into two prefill stages)

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
        mrope_positions: Existing mrope positions (3, N) for entire sequence.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta).
    """

    # Tensors
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone())  # (3, N)
    N = input_ids.numel()

    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk
    if len(multimodal_positions) == 0:
        delta = (int((positions.max().item() + 1) -
                     N) if positions.numel() else -N)
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = (int((positions.max().item() + 1) -
                     N) if positions.numel() else -N)
        return positions, delta

    vision_start_indices = (input_ids == vision_start_token_id).nonzero(
        as_tuple=True)[0]

    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)

        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        seen_vision_start_indices = vision_start_indices[vision_start_indices <
                                                         num_computed_tokens]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token])
            in_the_middle_of_media = (
                seen_mm_tokens > seem_mm_tokens_before_last_vision_start)

            if in_the_middle_of_media:
                mm_embeddings_seen = (seen_mm_tokens -
                                      seem_mm_tokens_before_last_vision_start)
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token

        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens][0]

            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # Offset right after vision_start_token
        base = positions[-1, global_mm_start] + 1
        local_start = global_mm_start + 1 + mm_embeddings_seen
        local_end = local_start + mm_pos.shape[1]
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # mm_pos[3, 0] is the max width of the media
        offset = mm_pos[3, 0] + base

        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)

        positions[:, local_end:N] = text_pos_sum + offset - 1

        # Include distance to the next vision start token
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta
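
An end-to-end sketch tying the helpers above together, with hypothetical token ids and sizes; masking the media positions and transposing them to a (4, num_media_tokens) layout are assumptions inferred from the indexing in this function, not an official API. The example prunes a small video, builds a toy prompt around the surviving media tokens, and recomputes the mrope positions in a single prefill chunk.

import torch

from vllm.multimodal.evs import (compute_mrope_for_media,
                                 compute_retention_mask,
                                 recompute_mrope_positions)

VISION_START, VIDEO, IMAGE, TEXT = 100, 200, 300, 1  # hypothetical token ids
T, H, W, merge = 2, 4, 4, 2  # merged grid: 2 * 2 * 2 = 8 media tokens
video_size_thw = torch.tensor([T, H, W], dtype=torch.long)

# Prune half of the media tokens (the first frame is always kept).
video_embeds = torch.randn(T * (H // merge) * (W // merge), 32)
mask = compute_retention_mask(video_embeds, video_size_thw, merge, q=0.5)
media_pos = compute_mrope_for_media(video_size_thw, merge)[mask].t()  # (4, kept)

kept = int(mask.sum())  # 4 media tokens survive with q=0.5
input_ids = torch.tensor([TEXT, TEXT, VISION_START] + [VIDEO] * kept + [TEXT] * 3)
N = input_ids.numel()

# Naive positions that ignore pruning; shape (3, N) as in the code above.
naive_positions = torch.arange(N).expand(3, N).clone()

positions, delta = recompute_mrope_positions(
    input_ids,
    [media_pos],
    naive_positions,
    num_computed_tokens=0,
    vision_start_token_id=VISION_START,
    image_token_id=IMAGE,
    video_token_id=VIDEO,
)
assert positions.shape == (3, N)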