vllm.model_executor.models.keye

KeyeImageInputs module-attribute

KeyeImageInputs = Union[
    KeyeImagePixelInputs, KeyeImageEmbeddingInputs
]

KeyeVideoInputs module-attribute

KeyeVideoInputs = Union[
    KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs
]

_I module-attribute

_I = TypeVar('_I', bound=KeyeProcessingInfo)

logger module-attribute

logger = init_logger(__name__)

BaseKeyeModule

Bases: Module

Source code in vllm/model_executor/models/keye.py
class BaseKeyeModule(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: PretrainedConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = KeyeSiglipVisionModel(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
        )

        self.mlp_AR = self._build_projector(
            config,
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "mlp_AR"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @abstractmethod
    def _build_projector(self,
                         text_config: PretrainedConfig,
                         vision_config: PretrainedConfig,
                         quant_config: Optional[QuantizationConfig] = None,
                         prefix: str = "") -> nn.Module:
        raise ValueError("Need projector")

    def _process_image_input(self,
                             image_input: Any) -> tuple[torch.Tensor, ...]:
        siglip_position_ids = list()
        image_grid_hws = list()
        sample_indices = list()
        cu_seqlens = [0]

        image_grid_thw = image_input["image_grid_thw"]
        assert image_grid_thw.ndim == 2

        for idx, thaw in enumerate(image_grid_thw):
            thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)
            image_grid_hws.append(thw_tuple)
            image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
            siglip_position_ids.append(image_position_ids)
            sample_indices.append(torch.full((numel, ), idx,
                                             dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        if image_input["type"] == "image_embeds":
            raise ValueError(
                "Image embeddings are not supported for this processing path.")
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            siglip_position_ids = torch.concat(siglip_position_ids,
                                               dim=0).to(pixel_values.device)
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
                pixel_values.device)
            sample_indices = torch.concat(sample_indices,
                                          dim=0).to(pixel_values.device)

            image_embeds = self.visual(
                pixel_values=pixel_values,
                image_grid_thw=image_grid_hws,
                position_ids=siglip_position_ids,
                vision_return_embed_list=False,
                interpolate_pos_encoding=True,
                sample_indices=sample_indices,
                cu_seqlens=cu_seqlens,
                use_rope=True,
                window_size=-1,
            )
            image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
            return image_embeds

    def _process_video_embeds(
        self,
        video_type: Literal["video_embeds", "pixel_values_videos"],
        video_grid_thw: list[torch.Tensor],
        pixel_values_videos: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        siglip_position_ids = list()
        video_grid_hws = list()
        sample_indices = list()
        cu_seqlens = [0]

        assert video_grid_thw.ndim == 2
        for idx, sub_thw in enumerate(video_grid_thw):
            thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)

            video_grid_hws.append(thw_tuple)
            video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
            siglip_position_ids.append(video_position_ids)
            sample_indices.append(torch.full((numel, ), idx,
                                             dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        if video_type == "video_embeds":
            raise ValueError(
                "Video embeddings are not supported for this processing path.")
        else:
            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
            siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
                pixel_values_videos.device)
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
                pixel_values_videos.device)
            sample_indices = torch.concat(sample_indices,
                                          dim=0).to(pixel_values_videos.device)

            video_embeds = self.visual(
                pixel_values=pixel_values_videos,
                image_grid_thw=video_grid_hws,
                position_ids=siglip_position_ids,
                vision_return_embed_list=True,
                interpolate_pos_encoding=True,
                sample_indices=sample_indices,
                cu_seqlens=cu_seqlens,
                use_rope=True,
                window_size=-1,
            )
            video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
            return video_embeds

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "images" not in modalities):
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if (input_key in ("pixel_values_videos", "video_embeds")
                    and "videos" not in modalities):
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Keye-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp_AR.",
            tower_model="visual.",
        )

config instance-attribute

config = config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    }
)
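
A minimal sketch of the effect of this prefix map on checkpoint weight names (plain string rewriting with illustrative names; the real WeightsMapper in vLLM handles more cases):

orig_to_new_prefix = {
    "lm_head.": "language_model.lm_head.",
    "model.": "language_model.model.",
}

def remap(name: str) -> str:
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

print(remap("model.layers.0.self_attn.qkv_proj.weight"))
# -> language_model.model.layers.0.self_attn.qkv_proj.weight
print(remap("visual.patch_embedding.weight"))
# -> unchanged: no prefix matches, so vision-tower weights keep their name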

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "language_model"),
    architectures=["Qwen3ForCausalLM"],
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

mlp_AR instance-attribute

mlp_AR = _build_projector(
    config,
    vision_config,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "mlp_AR"),
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

visual instance-attribute

visual = KeyeSiglipVisionModel(
    vision_config,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "visual"),
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/keye.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config: PretrainedConfig = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config

    self.config = config
    self.multimodal_config = multimodal_config

    self.visual = KeyeSiglipVisionModel(
        config.vision_config,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
    )

    self.mlp_AR = self._build_projector(
        config,
        config.vision_config,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "mlp_AR"),
    )

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=["Qwen3ForCausalLM"],
    )

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)

_build_projector abstractmethod

_build_projector(
    text_config: PretrainedConfig,
    vision_config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> Module
Source code in vllm/model_executor/models/keye.py
@abstractmethod
def _build_projector(self,
                     text_config: PretrainedConfig,
                     vision_config: PretrainedConfig,
                     quant_config: Optional[QuantizationConfig] = None,
                     prefix: str = "") -> nn.Module:
    raise ValueError("Need projector")

_parse_and_validate_multimodal_inputs

_parse_and_validate_multimodal_inputs(
    **kwargs: object,
) -> dict
Source code in vllm/model_executor/models/keye.py
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    modalities = {}

    for input_key in kwargs:
        if (input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities):
            modalities["images"] = self._parse_and_validate_image_input(
                **kwargs)
        if (input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities):
            modalities["videos"] = self._parse_and_validate_video_input(
                **kwargs)

    return modalities

_process_image_input

_process_image_input(
    image_input: Any,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/keye.py
def _process_image_input(self,
                         image_input: Any) -> tuple[torch.Tensor, ...]:
    siglip_position_ids = list()
    image_grid_hws = list()
    sample_indices = list()
    cu_seqlens = [0]

    image_grid_thw = image_input["image_grid_thw"]
    assert image_grid_thw.ndim == 2

    for idx, thaw in enumerate(image_grid_thw):
        thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
        numel = np.prod(thw_tuple)
        image_grid_hws.append(thw_tuple)
        image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
        siglip_position_ids.append(image_position_ids)
        sample_indices.append(torch.full((numel, ), idx,
                                         dtype=torch.int64))
        cu_seqlens.append(cu_seqlens[-1] + numel)

    if image_input["type"] == "image_embeds":
        raise ValueError(
            "Image embeddings are not supported for this processing path.")
    else:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        siglip_position_ids = torch.concat(siglip_position_ids,
                                           dim=0).to(pixel_values.device)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
            pixel_values.device)
        sample_indices = torch.concat(sample_indices,
                                      dim=0).to(pixel_values.device)

        image_embeds = self.visual(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=False,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
        return image_embeds
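
The loop above is pure index bookkeeping ahead of the vision tower call. A standalone sketch of the same computation with made-up grid sizes:

import torch

image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])   # (num_images, 3)

position_ids, sample_indices, cu_seqlens = [], [], [0]
for idx, thw in enumerate(image_grid_thw):
    t, h, w = (int(x) for x in thw)
    numel = t * h * w                                    # patches in this image
    position_ids.append(torch.arange(numel) % (h * w))   # repeats per temporal slice
    sample_indices.append(torch.full((numel,), idx, dtype=torch.int64))
    cu_seqlens.append(cu_seqlens[-1] + numel)

print(torch.concat(position_ids).shape)   # torch.Size([20])
print(torch.concat(sample_indices))       # 16 zeros followed by 4 ones
print(cu_seqlens)                         # [0, 16, 20]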

_process_video_embeds

_process_video_embeds(
    video_type: Literal[
        "video_embeds", "pixel_values_videos"
    ],
    video_grid_thw: list[Tensor],
    pixel_values_videos: Optional[Tensor] = None,
) -> Union[Tensor, list[Tensor]]
Source code in vllm/model_executor/models/keye.py
def _process_video_embeds(
    self,
    video_type: Literal["video_embeds", "pixel_values_videos"],
    video_grid_thw: list[torch.Tensor],
    pixel_values_videos: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
    siglip_position_ids = list()
    video_grid_hws = list()
    sample_indices = list()
    cu_seqlens = [0]

    assert video_grid_thw.ndim == 2
    for idx, sub_thw in enumerate(video_grid_thw):
        thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
        numel = np.prod(thw_tuple)

        video_grid_hws.append(thw_tuple)
        video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
        siglip_position_ids.append(video_position_ids)
        sample_indices.append(torch.full((numel, ), idx,
                                         dtype=torch.int64))
        cu_seqlens.append(cu_seqlens[-1] + numel)

    if video_type == "video_embeds":
        raise ValueError(
            "Video embeddings are not supported for this processing path.")
    else:
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
            pixel_values_videos.device)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
            pixel_values_videos.device)
        sample_indices = torch.concat(sample_indices,
                                      dim=0).to(pixel_values_videos.device)

        video_embeds = self.visual(
            pixel_values=pixel_values_videos,
            image_grid_thw=video_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=True,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
        return video_embeds

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/keye.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    return self.language_model.compute_logits(hidden_states)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]

Run forward pass for Keye-VL.

Parameters:

Name Type Description Default
input_ids Tensor

Flattened (concatenated) input_ids corresponding to a batch.

required
positions Tensor

Flattened (concatenated) position ids corresponding to a batch. NOTE: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be (3, seq_len), otherwise it will be (seq_len,).

required
intermediate_tensors Optional[IntermediateTensors]

Intermediate tensors from prior forward pass.

None
inputs_embeds Optional[Tensor]

Optional tensor of input embeddings.

None
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run forward pass for Keye-VL.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        positions: Flattened (concatenated) position ids corresponding to a
            batch.
            **NOTE**: If mrope is enabled (default setting for Qwen2-VL
            opensource models), the shape will be `(3, seq_len)`,
            otherwise it will be `(seq_len,)`.
        intermediate_tensors: Intermediate tensors from prior forward pass.
        inputs_embeds: Optional tensor of input embeddings.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None

    hidden_states = self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )

    return hidden_states
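
Shape sketch for the positions argument described in the docstring (values are illustrative):

import torch

seq_len = 7
positions_mrope = torch.zeros(3, seq_len, dtype=torch.long)  # mrope enabled: (3, seq_len)
positions_plain = torch.arange(seq_len)                      # otherwise: (seq_len,)
print(positions_mrope.shape, positions_plain.shape)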

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/keye.py
def get_language_model(self) -> torch.nn.Module:
    return self.language_model

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models.

Source code in vllm/model_executor/models/keye.py
def get_mm_mapping(self) -> MultiModelKeys:
    """Get the module prefix in multimodal models."""
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="mlp_AR.",
        tower_model="visual.",
    )

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> Optional[MultiModalEmbeddings]
Source code in vllm/model_executor/models/keye.py
def get_multimodal_embeddings(
        self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

    modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not modalities:
        return None

    multimodal_embeddings: tuple[torch.Tensor, ...] = ()

    for modality in modalities:
        if modality == "images":
            image_input = modalities["images"]
            vision_embeddings = self._process_image_input(image_input)
            multimodal_embeddings += vision_embeddings
        if modality == "videos":
            video_input = modalities["videos"]
            video_embeddings = self._process_video_input(video_input)
            multimodal_embeddings += video_embeddings
    return multimodal_embeddings

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> Optional[str]
Source code in vllm/model_executor/models/keye.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    if modality.startswith("image"):
        return "<|vision_start|><|image_pad|><|vision_end|>"
    if modality.startswith("video"):
        return "<|vision_start|><|video_pad|><|vision_end|>"

    raise ValueError("Only image or video modality is supported")

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/keye.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

KeyeBaseDummyInputsBuilder

Bases: BaseDummyInputsBuilder[_I]

Source code in vllm/model_executor/models/keye.py
class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = (
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len)

        mm_data = {
            "image":
            self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
            ),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

        return mm_data

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int, mm_counts: Mapping[str, int]
) -> MultiModalDataDict
Source code in vllm/model_executor/models/keye.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
    num_images = mm_counts.get("image", 0)
    num_videos = mm_counts.get("video", 0)

    target_width, target_height = (
        self.info.get_image_size_with_most_features())
    target_num_frames = self.info.get_num_frames_with_most_features(
        seq_len)

    mm_data = {
        "image":
        self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=num_images,
        ),
        "video":
        self._get_dummy_videos(
            width=target_width,
            height=target_height,
            num_frames=target_num_frames,
            num_videos=num_videos,
        ),
    }

    return mm_data

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/keye.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    num_images = mm_counts.get("image", 0)
    num_videos = mm_counts.get("video", 0)

    hf_processor = self.info.get_hf_processor()
    image_token: str = hf_processor.image_token
    video_token: str = hf_processor.video_token

    return image_token * num_images + video_token * num_videos
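
Illustration of the dummy prompt this builds; the placeholder token strings below are assumptions, the real values come from the HF processor:

image_token, video_token = "<|image_pad|>", "<|video_pad|>"  # assumed values
num_images, num_videos = 2, 1
print(image_token * num_images + video_token * num_videos)
# <|image_pad|><|image_pad|><|video_pad|>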

KeyeDummyInputsBuilder

Bases: KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]

Source code in vllm/model_executor/models/keye.py
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
    ...

KeyeForConditionalGeneration

Bases: BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP

Source code in vllm/model_executor/models/keye.py
@MULTIMODAL_REGISTRY.register_processor(
    KeyeMultiModalProcessor,
    info=KeyeProcessingInfo,
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                                   SupportsLoRA, SupportsPP):

    def _build_projector(self,
                         text_config: PretrainedConfig,
                         vision_config: PretrainedConfig,
                         quant_config: Optional[QuantizationConfig] = None,
                         prefix: str = "") -> nn.Module:
        return Projector(text_config, vision_config, quant_config, prefix)

    def _validate_and_reshape_mm_tensor(
            self, mm_input: NestedTensors,
            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim == 5:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])
        elif is_list_of(mm_input, torch.Tensor):
            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
                                                          for p in mm_input):
                return mm_input
        return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[KeyeImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return KeyeImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return KeyeImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[KeyeVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos,
                "video pixel values",
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_video_input(
            self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
        video_type = video_input["type"]
        video_grid_thw = video_input["video_grid_thw"]
        pixel_values_videos = video_input.get("pixel_values_videos", None)

        return tuple(
            self._process_video_embeds(video_type, video_grid_thw,
                                       pixel_values_videos))

_build_projector

_build_projector(
    text_config: PretrainedConfig,
    vision_config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> Module
Source code in vllm/model_executor/models/keye.py
def _build_projector(self,
                     text_config: PretrainedConfig,
                     vision_config: PretrainedConfig,
                     quant_config: Optional[QuantizationConfig] = None,
                     prefix: str = "") -> nn.Module:
    return Projector(text_config, vision_config, quant_config, prefix)

_parse_and_validate_image_input

_parse_and_validate_image_input(
    **kwargs: object,
) -> Optional[KeyeImageInputs]
Source code in vllm/model_executor/models/keye.py
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[KeyeImageInputs]:
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)
    image_grid_thw = kwargs.pop("image_grid_thw", None)

    if pixel_values is None and image_embeds is None:
        return None

    if pixel_values is not None:
        pixel_values = self._validate_and_reshape_mm_tensor(
            pixel_values, "image pixel values")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")

        return KeyeImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )

    if image_embeds is not None:
        image_embeds = self._validate_and_reshape_mm_tensor(
            image_embeds, "image embeds")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")

        return KeyeImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )

_parse_and_validate_video_input

_parse_and_validate_video_input(
    **kwargs: object,
) -> Optional[KeyeVideoInputs]
Source code in vllm/model_executor/models/keye.py
def _parse_and_validate_video_input(
        self, **kwargs: object) -> Optional[KeyeVideoInputs]:
    pixel_values_videos = kwargs.pop("pixel_values_videos", None)
    video_embeds = kwargs.pop("video_embeds", None)
    video_grid_thw = kwargs.pop("video_grid_thw", None)

    if pixel_values_videos is None and video_embeds is None:
        return None

    if pixel_values_videos is not None:
        pixel_values_videos = self._validate_and_reshape_mm_tensor(
            pixel_values_videos,
            "video pixel values",
        )
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")

        return KeyeVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )

    if video_embeds is not None:
        video_embeds = self._validate_and_reshape_mm_tensor(
            video_embeds, "video embeds")
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")

        return KeyeVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=video_embeds,
            video_grid_thw=video_grid_thw,
        )

_process_video_input

_process_video_input(
    video_input: KeyeVideoInputs,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/keye.py
def _process_video_input(
        self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
    video_type = video_input["type"]
    video_grid_thw = video_input["video_grid_thw"]
    pixel_values_videos = video_input.get("pixel_values_videos", None)

    return tuple(
        self._process_video_embeds(video_type, video_grid_thw,
                                   pixel_values_videos))

_validate_and_reshape_mm_tensor

_validate_and_reshape_mm_tensor(
    mm_input: NestedTensors, name: str
) -> Union[Tensor, list[Tensor]]
Source code in vllm/model_executor/models/keye.py
def _validate_and_reshape_mm_tensor(
        self, mm_input: NestedTensors,
        name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
    if not isinstance(mm_input, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of {name}. "
                         f"Got type: {type(mm_input)}")
    if isinstance(mm_input, torch.Tensor):
        if mm_input.ndim == 2:
            return mm_input
        if mm_input.ndim == 5:
            return mm_input
        if mm_input.ndim != 3:
            raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                             f"Got ndim: {mm_input.ndim} "
                             f"(shape={mm_input.shape})")
        return mm_input.reshape(-1, mm_input.shape[-1])
    elif is_list_of(mm_input, torch.Tensor):
        if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
                                                      for p in mm_input):
            return mm_input
    return torch.concat(mm_input)
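
Behavioural sketch of the reshape rule with dummy shapes: a batched 3D input (batch, patches, dim) is flattened to (batch * patches, dim), while 2D and 5D tensors pass through unchanged:

import torch

batched = torch.zeros(2, 5, 1176)             # e.g. 2 images, 5 patches each (made-up sizes)
flat = batched.reshape(-1, batched.shape[-1])
print(flat.shape)                             # torch.Size([10, 1176])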

KeyeImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • nf: Number of image features
  • hs: Hidden size (must match the hidden size of language model backbone)
  • ni: Number of images
  • g: Grid dimensions (3 for t, h, w)
Source code in vllm/model_executor/models/keye.py
class KeyeImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of image features
        - hs: Hidden size (must match the hidden size of language model 
          backbone)
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["image_embeds"]
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

image_embeds instance-attribute

image_embeds: Annotated[Tensor, TensorShape(nf, hs)]

image_grid_thw instance-attribute

image_grid_thw: Annotated[Tensor, TensorShape(ni, 3)]

type instance-attribute

type: Literal['image_embeds']

KeyeImagePixelInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • np: Number of patches
  • c: Number of channels
  • ps: Patch size
  • ni: Number of images
  • g: Grid dimensions (3 for t, h, w)
Source code in vllm/model_executor/models/keye.py
class KeyeImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - np: Number of patches
        - c: Number of channels
        - ps: Patch size
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

image_grid_thw instance-attribute

image_grid_thw: Annotated[Tensor, TensorShape(ni, 3)]

pixel_values instance-attribute

pixel_values: Annotated[
    Tensor,
    TensorShape(b, np, 3, ps, ps, dynamic_dims={np}),
]

type instance-attribute

type: Literal['pixel_values']
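
Dummy tensors matching the documented dimensions (all sizes invented): one image with a (t, h, w) patch grid of (1, 6, 8), i.e. np = 48 patches:

import torch

pixel_values = torch.zeros(1, 48, 3, 14, 14)    # (b, np, c, ps, ps)
image_grid_thw = torch.tensor([[1, 6, 8]])      # (ni, 3)
print(pixel_values.shape, image_grid_thw.shape)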

KeyeMultiModalDataParser

Bases: MultiModalDataParser

Source code in vllm/model_executor/models/keye.py
class KeyeMultiModalDataParser(MultiModalDataParser):

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={
                    "image_embeds",
                    "image_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={
                    "video_embeds",
                    "video_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_video_data(data)

_parse_image_data

_parse_image_data(
    data: Union[dict[str, Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]
Source code in vllm/model_executor/models/keye.py
def _parse_image_data(
    self,
    data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
    if isinstance(data, dict):
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={
                "image_embeds",
                "image_grid_thw",
            },
            fields_factory=_keye_field_config,
        )

    return super()._parse_image_data(data)

_parse_video_data

_parse_video_data(
    data: Union[dict[str, Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]
Source code in vllm/model_executor/models/keye.py
def _parse_video_data(
    self,
    data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
    if isinstance(data, dict):
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={
                "video_embeds",
                "video_grid_thw",
            },
            fields_factory=_keye_field_config,
        )

    return super()._parse_video_data(data)

KeyeMultiModalProcessor

Bases: BaseMultiModalProcessor[KeyeProcessingInfo]

Source code in vllm/model_executor/models/keye.py
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        return KeyeMultiModalDataParser()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_keye(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_keye, modality=modality),
            ) for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _keye_field_config(hf_inputs)

_get_data_parser

_get_data_parser() -> MultiModalDataParser
Source code in vllm/model_executor/models/keye.py
def _get_data_parser(self) -> MultiModalDataParser:
    return KeyeMultiModalDataParser()

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/keye.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return _keye_field_config(hf_inputs)

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/keye.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    image_processor = self.info.get_image_processor(
        **hf_processor_mm_kwargs)
    tokenizer = self.info.get_tokenizer()
    vocab = tokenizer.get_vocab()

    placeholder = {
        "image": vocab[hf_processor.image_token],
        "video": vocab[hf_processor.video_token],
    }

    merge_length = image_processor.merge_size**2

    def get_replacement_keye(item_idx: int, modality: str):
        out_item = out_mm_kwargs[modality][item_idx]
        grid_thw = out_item[f"{modality}_grid_thw"].data
        assert isinstance(grid_thw, torch.Tensor)

        num_tokens = int(grid_thw.prod()) // merge_length
        return [placeholder[modality]] * num_tokens

    return [
        PromptReplacement(
            modality=modality,
            target=[placeholder[modality]],
            replacement=partial(get_replacement_keye, modality=modality),
        ) for modality in ("image", "video")
    ]
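
Worked example of the replacement length above (numbers are made up): each placeholder token is expanded to grid_thw.prod() // merge_size**2 tokens, one per merged patch:

import torch

grid_thw = torch.tensor([1, 28, 42])   # hypothetical (t, h, w) patch grid
merge_size = 2                         # image_processor.merge_size (assumed)

num_tokens = int(grid_thw.prod()) // merge_size ** 2
print(num_tokens)                      # 294 placeholder tokens for this item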

KeyeProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/keye.py
class KeyeProcessingInfo(BaseProcessingInfo):

    def get_max_image_size(self) -> int:
        return 9999999  #_MAX_IMAGE_SIZE

    def get_max_frame_per_video(self) -> int:
        return 16  #_MAX_FRAMES_PER_VIDEO

    def get_image_processor(self, **kwargs: object):
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor,
    ) -> tuple[ImageSize, int]:
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = 1

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self, ) -> ImageSize:
        max_image_size, _ = self._get_vision_info(
            image_width=self.get_max_image_size(),
            image_height=self.get_max_image_size(),
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1),
            self.get_max_frame_per_video(),
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(self, seq_len: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len),
            image_processor=None,
        )

_get_max_video_frames

_get_max_video_frames(max_tokens: int) -> int
Source code in vllm/model_executor/models/keye.py
def _get_max_video_frames(self, max_tokens: int) -> int:
    target_width, target_height = self.get_image_size_with_most_features()

    num_frames = 0

    while True:
        next_num_frames = num_frames + 1
        next_max_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=next_num_frames,
            image_processor=None,
        )

        if next_max_tokens > max_tokens:
            break

        num_frames = next_num_frames

    return num_frames
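
Self-contained restatement of the linear scan with a stubbed per-frame token count (256 tokens per frame, assumed): keep adding frames until the next one would exceed the token budget:

def max_video_frames(max_tokens: int, tokens_per_frame: int = 256) -> int:
    num_frames = 0
    while (num_frames + 1) * tokens_per_frame <= max_tokens:
        num_frames += 1
    return num_frames

print(max_video_frames(1000))   # 3 frames fit into a 1000-token budget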

_get_vision_info

_get_vision_info(
    *,
    image_width: int,
    image_height: int,
    num_frames: int = 1,
    do_resize: bool = True,
    image_processor,
) -> tuple[ImageSize, int]
Source code in vllm/model_executor/models/keye.py
def _get_vision_info(
    self,
    *,
    image_width: int,
    image_height: int,
    num_frames: int = 1,
    do_resize: bool = True,
    image_processor,
) -> tuple[ImageSize, int]:
    if image_processor is None:
        image_processor = self.get_image_processor()

    hf_config = self.get_hf_config()
    vision_config = hf_config.vision_config
    patch_size = vision_config.patch_size
    merge_size = vision_config.spatial_merge_size
    temporal_patch_size = 1

    if do_resize:
        resized_height, resized_width = smart_resize(
            height=image_height,
            width=image_width,
            factor=patch_size * merge_size,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
        )
        preprocessed_size = ImageSize(width=resized_width,
                                      height=resized_height)
    else:
        preprocessed_size = ImageSize(width=image_width,
                                      height=image_height)

    padded_num_frames = num_frames + num_frames % temporal_patch_size

    grid_t = max(padded_num_frames // temporal_patch_size, 1)
    grid_h = preprocessed_size.height // patch_size
    grid_w = preprocessed_size.width // patch_size

    num_patches = grid_t * grid_h * grid_w
    num_vision_tokens = num_patches // (merge_size**2)

    return preprocessed_size, num_vision_tokens
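
Hedged re-derivation of the patch/token math above, skipping smart_resize and plugging in assumed numbers (patch_size=14, spatial_merge_size=2, one 448x448 frame):

patch_size, merge_size, temporal_patch_size = 14, 2, 1
width, height, num_frames = 448, 448, 1

padded_num_frames = num_frames + num_frames % temporal_patch_size   # 1
grid_t = max(padded_num_frames // temporal_patch_size, 1)           # 1
grid_h = height // patch_size                                       # 32
grid_w = width // patch_size                                        # 32

num_patches = grid_t * grid_h * grid_w                               # 1024
num_vision_tokens = num_patches // (merge_size ** 2)                 # 256
print(num_vision_tokens)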

get_image_processor

get_image_processor(**kwargs: object)
Source code in vllm/model_executor/models/keye.py
def get_image_processor(self, **kwargs: object):
    return self.get_hf_processor(**kwargs).image_processor

get_image_size_with_most_features

get_image_size_with_most_features() -> ImageSize
Source code in vllm/model_executor/models/keye.py
def get_image_size_with_most_features(self, ) -> ImageSize:
    max_image_size, _ = self._get_vision_info(
        image_width=self.get_max_image_size(),
        image_height=self.get_max_image_size(),
        image_processor=None,
    )
    return max_image_size

get_max_frame_per_video

get_max_frame_per_video() -> int
Source code in vllm/model_executor/models/keye.py
def get_max_frame_per_video(self) -> int:
    return 16  #_MAX_FRAMES_PER_VIDEO

get_max_image_size

get_max_image_size() -> int
Source code in vllm/model_executor/models/keye.py
def get_max_image_size(self) -> int:
    return 9999999  #_MAX_IMAGE_SIZE

get_max_image_tokens

get_max_image_tokens() -> int
Source code in vllm/model_executor/models/keye.py
def get_max_image_tokens(self) -> int:
    target_width, target_height = self.get_image_size_with_most_features()

    return self.get_num_image_tokens(
        image_width=target_width,
        image_height=target_height,
        image_processor=None,
    )

get_max_video_tokens

get_max_video_tokens(seq_len: int) -> int
Source code in vllm/model_executor/models/keye.py
def get_max_video_tokens(self, seq_len: int) -> int:
    target_width, target_height = self.get_image_size_with_most_features()

    return self.get_num_video_tokens(
        image_width=target_width,
        image_height=target_height,
        num_frames=self.get_num_frames_with_most_features(seq_len),
        image_processor=None,
    )

get_mm_max_tokens_per_item

get_mm_max_tokens_per_item(
    seq_len: int, mm_counts: Mapping[str, int]
) -> Mapping[str, int]
Source code in vllm/model_executor/models/keye.py
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    return {
        "image": self.get_max_image_tokens(),
        "video": self.get_max_video_tokens(seq_len),
    }

get_num_frames_with_most_features

get_num_frames_with_most_features(seq_len: int) -> int
Source code in vllm/model_executor/models/keye.py
def get_num_frames_with_most_features(self, seq_len: int) -> int:
    mm_config = self.ctx.get_mm_config()
    max_images = mm_config.get_limit_per_prompt("image")
    max_videos = mm_config.get_limit_per_prompt("video")

    max_image_tokens = self.get_max_image_tokens() * max_images
    max_total_frames = self._get_max_video_frames(seq_len -
                                                  max_image_tokens)
    max_frames_per_video = min(
        max_total_frames // max(max_videos, 1),
        self.get_max_frame_per_video(),
    )

    return max(max_frames_per_video, 1)
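
Worked example of the frame budgeting above (all numbers assumed): with a sequence budget of 8192 tokens, 1 image and 1 video allowed per prompt, and 256 tokens per max-size image or frame, the leftover token budget caps the frame count, which is then clipped to the 16-frame maximum:

seq_len = 8192
max_images, max_videos = 1, 1
max_image_tokens = 256 * max_images         # assumed per-image token count
tokens_per_frame = 256                      # assumed per-frame token count

max_total_frames = (seq_len - max_image_tokens) // tokens_per_frame  # 31
max_frames_per_video = min(max_total_frames // max(max_videos, 1), 16)
print(max(max_frames_per_video, 1))         # 16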

get_num_image_tokens

get_num_image_tokens(
    *, image_width: int, image_height: int, image_processor
) -> int
Source code in vllm/model_executor/models/keye.py
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
    image_processor,
) -> int:
    _, num_image_tokens = self._get_vision_info(
        image_width=image_width,
        image_height=image_height,
        image_processor=image_processor,
    )
    return num_image_tokens

get_num_video_tokens

get_num_video_tokens(
    *,
    image_width: int,
    image_height: int,
    num_frames: int,
    image_processor,
) -> int
Source code in vllm/model_executor/models/keye.py
def get_num_video_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
    num_frames: int,
    image_processor,
) -> int:
    _, num_video_tokens = self._get_vision_info(
        image_width=image_width,
        image_height=image_height,
        num_frames=num_frames,
        image_processor=image_processor,
    )
    return num_video_tokens

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, Optional[int]]
Source code in vllm/model_executor/models/keye.py
def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
    return {"image": None, "video": None}

KeyeSiglipAttention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper.

Source code in vllm/model_executor/models/keye.py
class KeyeSiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You
    Need' paper."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim, dtype=torch.get_default_dtype())

        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split(
            [self.q_size, self.kv_size, self.kv_size],
            dim=-1,
        )

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        batch_size = q.shape[0]

        if rope_emb is None:
            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
            k = k.view(
                *k.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )
            v = v.view(
                *v.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )
        else:
            if cu_seqlens is None:
                raise ValueError(
                    "cu_seqlens cannot be None when rope_emb is not None.")
            cos, sin = rope_emb
            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
            k = k.view(
                *k.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )
            q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
            v = v.view(
                *v.shape[:-1],
                self.num_kv_heads,
                self.head_dim,
            )

        if self.attn_backend == _Backend.FLASH_ATTN:
            if self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=False,
                softmax_scale=self.scale,
            )
            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        context_layer = rearrange(context_layer,
                                  "b s h d -> b s (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        return output
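The attention above consumes flattened patch sequences that are described only by cumulative sequence lengths. A minimal sketch (helper name and grids invented) of how such a cu_seqlens tensor can be built from per-sample (t, h, w) grids:

import torch

def build_cu_seqlens(grid_thw):
    # One packed sequence per image/video; its length is t * h * w patches.
    lengths = torch.tensor([t * h * w for t, h, w in grid_thw],
                           dtype=torch.int32)
    cu_seqlens = torch.zeros(len(grid_thw) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens

print(build_cu_seqlens([(1, 4, 6), (2, 3, 3)]))
# tensor([ 0, 24, 42], dtype=torch.int32)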

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_dim, dtype=get_default_dtype()
)

config instance-attribute

config = config

head_dim instance-attribute

head_dim = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

out_proj instance-attribute

out_proj = RowParallelLinear(
    input_size=hidden_size,
    output_size=hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

scale instance-attribute

scale = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = num_attention_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_attention_heads

use_upstream_fa instance-attribute

use_upstream_fa = False

__init__

__init__(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/keye.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()
    self.config = config

    hidden_size = config.hidden_size
    self.hidden_size = config.hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = config.num_attention_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = config.num_attention_heads
    if self.total_num_kv_heads >= tp_size:
        assert self.total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = config.hidden_size // self.total_num_heads
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scale = self.head_dim**-0.5

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=True,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.out_proj = RowParallelLinear(
        input_size=hidden_size,
        output_size=hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )

    # Detect attention implementation.
    self.attn_backend = get_vit_attn_backend(
        head_size=self.head_dim, dtype=torch.get_default_dtype())

    self.use_upstream_fa = False
    if self.attn_backend != _Backend.FLASH_ATTN and \
        check_upstream_fa_availability(
            torch.get_default_dtype()):
        self.attn_backend = _Backend.FLASH_ATTN
        self.use_upstream_fa = True

    if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
        raise RuntimeError(
            f"Keye-VL does not support {self.attn_backend} backend now.")

forward

forward(
    hidden_states: Tensor,
    attention_mask: Optional[Tensor] = None,
    output_attentions: Optional[bool] = False,
    cu_seqlens: Optional[list[Tensor]] = None,
    rope_emb: Optional[tuple[Tensor, Tensor]] = None,
) -> Tensor
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
    cu_seqlens: Optional[list[torch.Tensor]] = None,
    rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split(
        [self.q_size, self.kv_size, self.kv_size],
        dim=-1,
    )

    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    batch_size = q.shape[0]

    if rope_emb is None:
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(
            *k.shape[:-1],
            self.num_kv_heads,
            self.head_dim,
        )
        v = v.view(
            *v.shape[:-1],
            self.num_kv_heads,
            self.head_dim,
        )
    else:
        if cu_seqlens is None:
            raise ValueError(
                "cu_seqlens cannot be None when rope_emb is not None.")
        cos, sin = rope_emb
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(
            *k.shape[:-1],
            self.num_kv_heads,
            self.head_dim,
        )
        q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
        v = v.view(
            *v.shape[:-1],
            self.num_kv_heads,
            self.head_dim,
        )

    if self.attn_backend == _Backend.FLASH_ATTN:
        if self.use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func

        q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

        output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
            softmax_scale=self.scale,
        )
        context_layer = rearrange(output,
                                  "(b s) ... -> b s ...",
                                  b=batch_size)
    elif self.attn_backend == _Backend.XFORMERS:
        from xformers import ops as xops
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask

        attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                   kv_seqlen=None,
                                                   device=q.device)

        context_layer = xops.memory_efficient_attention_forward(
            q, k, v, attn_bias=attn_bias, p=0, scale=None)

    context_layer = rearrange(context_layer,
                              "b s h d -> b s (h d)").contiguous()

    output, _ = self.out_proj(context_layer)
    return output

KeyeSiglipEncoder

Bases: Module

Source code in vllm/model_executor/models/keye.py
class KeyeSiglipEncoder(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList([
            KeyeSiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            ) for layer_idx in range(config.num_hidden_layers)
        ])
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
        vision_or_text: str = "vision",
    ) -> BaseModelOutput:
        device = inputs_embeds.device
        hidden_states = inputs_embeds
        if use_rope is True:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)

            if width_position_ids is None or height_position_ids is None:
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    image_pids = torch.arange(t * h * w,
                                              device=device) % (h * w)
                    sample_hids = image_pids // w
                    sample_wids = image_pids % w
                    split_hids.append(sample_hids)
                    split_wids.append(sample_wids)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)

            pids = torch.stack(
                [height_position_ids, width_position_ids],
                dim=-1,
            )
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None

        attn_cu_seqlens = cu_seqlens
        hidden_states = inputs_embeds
        assert attention_mask is None

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=attn_cu_seqlens,
                rope_emb=rope_emb,
            )
        return hidden_states
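When use_rope is enabled, the encoder derives a row and a column index for every patch from the (t, h, w) grids, repeating the same 2-D grid for each frame. A standalone sketch of that derivation (grids invented):

import torch

def hw_position_ids(grid_thw, device="cpu"):
    hids, wids = [], []
    for t, h, w in grid_thw:
        pids = torch.arange(t * h * w, device=device) % (h * w)
        hids.append(pids // w)   # row index of each patch
        wids.append(pids % w)    # column index of each patch
    return torch.cat(hids), torch.cat(wids)

h_ids, w_ids = hw_position_ids([(1, 2, 3)])
print(h_ids.tolist(), w_ids.tolist())   # [0, 0, 0, 1, 1, 1] [0, 1, 2, 0, 1, 2]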

config instance-attribute

config = config

layers instance-attribute

layers = ModuleList(
    [
        KeyeSiglipEncoderLayer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.layers.{layer_idx}",
        )
        for layer_idx in range(num_hidden_layers)
    ]
)

rotary_pos_emb instance-attribute

rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

__init__

__init__(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/keye.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()
    self.config = config
    embed_dim = config.hidden_size
    num_heads = config.num_attention_heads
    head_dim = embed_dim // num_heads
    self.layers = nn.ModuleList([
        KeyeSiglipEncoderLayer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.layers.{layer_idx}",
        ) for layer_idx in range(config.num_hidden_layers)
    ])
    self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

flatten_list staticmethod

flatten_list(image_grid_thw)
Source code in vllm/model_executor/models/keye.py
@staticmethod
def flatten_list(image_grid_thw):
    tmp_image_grid_thw = list()
    for image_grid in image_grid_thw:
        if isinstance(image_grid, list):
            tmp_image_grid_thw.extend(image_grid)
        else:
            tmp_image_grid_thw.append(image_grid)
    return tmp_image_grid_thw

forward

forward(
    inputs_embeds,
    attention_mask: Optional[Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cu_seqlens: Optional[list[Tensor]] = None,
    image_grid_thw: Optional[
        list[
            Union[
                tuple[int, int, int],
                list[tuple[int, int, int]],
            ]
        ]
    ] = None,
    height_position_ids: Optional[Tensor] = None,
    width_position_ids: Optional[Tensor] = None,
    use_rope: Optional[bool] = False,
    window_size: Optional[bool] = -1,
    vision_or_text: str = "vision",
) -> BaseModelOutput
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    inputs_embeds,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cu_seqlens: Optional[list[torch.Tensor]] = None,
    image_grid_thw: Optional[list[Union[
        tuple[int, int, int],
        list[tuple[int, int, int]],
    ]]] = None,
    height_position_ids: Optional[torch.Tensor] = None,
    width_position_ids: Optional[torch.Tensor] = None,
    use_rope: Optional[bool] = False,
    window_size: Optional[bool] = -1,
    vision_or_text: str = "vision",
) -> BaseModelOutput:
    device = inputs_embeds.device
    hidden_states = inputs_embeds
    if use_rope is True:
        flatten_image_grid_thw = self.flatten_list(image_grid_thw)

        if width_position_ids is None or height_position_ids is None:
            split_hids = list()
            split_wids = list()
            for t, h, w in flatten_image_grid_thw:
                image_pids = torch.arange(t * h * w,
                                          device=device) % (h * w)
                sample_hids = image_pids // w
                sample_wids = image_pids % w
                split_hids.append(sample_hids)
                split_wids.append(sample_wids)
            width_position_ids = torch.concat(split_wids, dim=0)
            height_position_ids = torch.concat(split_hids, dim=0)

        pids = torch.stack(
            [height_position_ids, width_position_ids],
            dim=-1,
        )
        max_grid_size = pids.max() + 1
        rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
        rope_emb = rope_emb_max_grid[pids].flatten(1)
        rope_emb = rope_emb.repeat(1, 2)
        rope_emb = (rope_emb.cos(), rope_emb.sin())
    else:
        rope_emb = None

    attn_cu_seqlens = cu_seqlens
    hidden_states = inputs_embeds
    assert attention_mask is None

    for encoder_layer in self.layers:
        hidden_states = encoder_layer(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=attn_cu_seqlens,
            rope_emb=rope_emb,
        )
    return hidden_states

KeyeSiglipEncoderLayer

Bases: Module

Source code in vllm/model_executor/models/keye.py
class KeyeSiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: Union[PretrainedConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.self_attn = KeyeSiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.FloatTensor]:

        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states
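The layer above is a standard pre-norm transformer block: each sub-layer normalizes its input, applies attention or the MLP, and adds the result back to the residual stream. A minimal sketch with the sub-layers replaced by identity placeholders:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim, attn, mlp):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn, self.mlp = attn, mlp

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # attention sub-layer + residual
        x = x + self.mlp(self.norm2(x))    # MLP sub-layer + residual
        return x

block = PreNormBlock(16, nn.Identity(), nn.Identity())
print(block(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])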

embed_dim instance-attribute

embed_dim = hidden_size

layer_norm1 instance-attribute

layer_norm1 = LayerNorm(embed_dim, eps=layer_norm_eps)

layer_norm2 instance-attribute

layer_norm2 = LayerNorm(embed_dim, eps=layer_norm_eps)

mlp instance-attribute

mlp = SiglipMLP(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

self_attn instance-attribute

self_attn = KeyeSiglipAttention(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: Union[PretrainedConfig],
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/keye.py
def __init__(
    self,
    config: Union[PretrainedConfig],
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()
    self.embed_dim = config.hidden_size
    self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                    eps=config.layer_norm_eps)
    self.self_attn = KeyeSiglipAttention(
        config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
    )
    self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                    eps=config.layer_norm_eps)
    self.mlp = SiglipMLP(
        config,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )

forward

forward(
    hidden_states: Tensor,
    attention_mask: Tensor,
    output_attentions: Optional[bool] = False,
    cu_seqlens: Optional[list[Tensor]] = None,
    rope_emb: Optional[tuple[Tensor, Tensor]] = None,
) -> tuple[FloatTensor]
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    output_attentions: Optional[bool] = False,
    cu_seqlens: Optional[list[torch.Tensor]] = None,
    rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> tuple[torch.FloatTensor]:

    residual = hidden_states

    hidden_states = self.layer_norm1(hidden_states)
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        cu_seqlens=cu_seqlens,
        rope_emb=rope_emb,
    )

    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = self.layer_norm2(hidden_states)
    hidden_states = self.mlp(hidden_states)

    hidden_states = residual + hidden_states

    return hidden_states

KeyeSiglipVisionModel

Bases: Module

Source code in vllm/model_executor/models/keye.py
class KeyeSiglipVisionModel(nn.Module):
    config_class = PretrainedConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.vision_model = KeyeSiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
        )
        self.quant_config = quant_config

    @property
    def dtype(self) -> torch.dtype:
        return self.vision_model.embeddings.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.vision_model.embeddings.patch_embedding.weight.device

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
    ) -> BaseModelOutputWithPooling:

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "head.attention" in name or "head.layernorm" in name:
                continue
            if "head.mlp" in name or "head.probe" in name:
                continue
            if self.quant_config is not None and (
                    scale_name := self.quant_config.get_cache_scale(name)):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (
                    param_name,
                    weight_name,
                    shard_id,
            ) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
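load_weights folds the separate q_proj / k_proj / v_proj checkpoint tensors into the fused qkv_proj parameter via stacked_params_mapping. A toy sketch of the renaming step only (checkpoint name invented; the real loader additionally passes the shard id to the parameter's weight_loader):

STACKED = [("qkv_proj", "q_proj", "q"),
           ("qkv_proj", "k_proj", "k"),
           ("qkv_proj", "v_proj", "v")]

def remap(name):
    for fused, split, shard in STACKED:
        if split in name:
            return name.replace(split, fused), shard
    return name, None

print(remap("vision_model.encoder.layers.0.self_attn.k_proj.weight"))
# ('vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'k')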

config_class class-attribute instance-attribute

config_class = PretrainedConfig

device property

device: device

dtype property

dtype: dtype

main_input_name class-attribute instance-attribute

main_input_name = 'pixel_values'

quant_config instance-attribute

quant_config = quant_config

vision_model instance-attribute

vision_model = KeyeSiglipVisionTransformer(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.vision_model",
)

__init__

__init__(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/keye.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()

    self.vision_model = KeyeSiglipVisionTransformer(
        config,
        quant_config=quant_config,
        prefix=f"{prefix}.vision_model",
    )
    self.quant_config = quant_config

forward

forward(
    pixel_values,
    sample_indices: Optional[Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
    position_ids: Optional[Tensor] = None,
    vision_return_embed_list: Optional[bool] = False,
    image_grid_thw: Optional[
        list[
            Union[
                tuple[int, int, int],
                list[tuple[int, int, int]],
            ]
        ]
    ] = None,
    cu_seqlens: Optional[list[Tensor]] = None,
    return_pooler_output: Optional[bool] = True,
    use_rope: Optional[bool] = False,
    window_size: Optional[bool] = -1,
) -> BaseModelOutputWithPooling
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    pixel_values,
    sample_indices: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
    position_ids: Optional[torch.Tensor] = None,
    vision_return_embed_list: Optional[bool] = False,
    image_grid_thw: Optional[list[Union[
        tuple[int, int, int],
        list[tuple[int, int, int]],
    ]]] = None,
    cu_seqlens: Optional[list[torch.Tensor]] = None,
    return_pooler_output: Optional[bool] = True,
    use_rope: Optional[bool] = False,
    window_size: Optional[bool] = -1,
) -> BaseModelOutputWithPooling:

    return self.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        interpolate_pos_encoding=interpolate_pos_encoding,
        position_ids=position_ids,
        vision_return_embed_list=vision_return_embed_list,
        image_grid_thw=image_grid_thw,
        sample_indices=sample_indices,
        cu_seqlens=cu_seqlens,
        return_pooler_output=return_pooler_output,
        use_rope=use_rope,
        window_size=window_size,
    )

get_input_embeddings

get_input_embeddings() -> Module
Source code in vllm/model_executor/models/keye.py
def get_input_embeddings(self) -> nn.Module:
    return self.vision_model.embeddings.patch_embedding

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/keye.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if "head.attention" in name or "head.layernorm" in name:
            continue
        if "head.mlp" in name or "head.probe" in name:
            continue
        if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)):
            param = params_dict[scale_name]
            weight_loader = getattr(
                param,
                "weight_loader",
                default_weight_loader,
            )
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (
                param_name,
                weight_name,
                shard_id,
        ) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            if name.endswith(".bias") and name not in params_dict:
                continue
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(
                param,
                "weight_loader",
                default_weight_loader,
            )
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

KeyeSiglipVisionTransformer

Bases: Module

Source code in vllm/model_executor/models/keye.py
class KeyeSiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = KeyeVisionEmbeddings(config)
        self.encoder = KeyeSiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask: Optional[torch.Tensor] = None,
        sample_indices: Optional[torch.Tensor] = None,
        image_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
    ) -> BaseModelOutputWithPooling:

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        sample_hidden_state = list()
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None for "
                             "SiglipVisionTransformer output processing.")
        for i in range(cu_seqlens.shape[0] - 1):
            start = cu_seqlens[i]
            end = cu_seqlens[i + 1]
            tensor = last_hidden_state[:, start:end, :].squeeze(0)
            sample_hidden_state.append(tensor)

        return sample_hidden_state
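The transformer returns one tensor per sample by slicing the packed hidden states at the cu_seqlens boundaries. An illustrative, self-contained version of that split (shapes invented):

import torch

last_hidden_state = torch.randn(1, 42, 8)      # (1, total_patches, dim)
cu_seqlens = [0, 24, 42]

samples = [last_hidden_state[:, s:e, :].squeeze(0)
           for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]
print([tuple(t.shape) for t in samples])       # [(24, 8), (18, 8)]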

config instance-attribute

config = config

embeddings instance-attribute

embeddings = KeyeVisionEmbeddings(config)

encoder instance-attribute

encoder = KeyeSiglipEncoder(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.encoder",
)

post_layernorm instance-attribute

post_layernorm = LayerNorm(embed_dim, eps=layer_norm_eps)

__init__

__init__(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/keye.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()
    self.config = config
    embed_dim = config.hidden_size

    self.embeddings = KeyeVisionEmbeddings(config)
    self.encoder = KeyeSiglipEncoder(
        config,
        quant_config=quant_config,
        prefix=f"{prefix}.encoder",
    )
    self.post_layernorm = nn.LayerNorm(embed_dim,
                                       eps=config.layer_norm_eps)

forward

forward(
    pixel_values,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    interpolate_pos_encoding: Optional[bool] = False,
    attention_mask: Optional[Tensor] = None,
    sample_indices: Optional[Tensor] = None,
    image_indices: Optional[Tensor] = None,
    position_ids: Optional[Tensor] = None,
    height_position_ids: Optional[Tensor] = None,
    width_position_ids: Optional[Tensor] = None,
    cu_seqlens: Optional[list[Tensor]] = None,
    padding_mask: Optional[Tensor] = None,
    vision_return_embed_list: Optional[bool] = False,
    image_grid_thw: Optional[
        list[
            Union[
                tuple[int, int, int],
                list[tuple[int, int, int]],
            ]
        ]
    ] = None,
    return_pooler_output: Optional[bool] = True,
    use_rope: Optional[bool] = False,
    window_size: Optional[bool] = -1,
) -> BaseModelOutputWithPooling
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    pixel_values,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    interpolate_pos_encoding: Optional[bool] = False,
    attention_mask: Optional[torch.Tensor] = None,
    sample_indices: Optional[torch.Tensor] = None,
    image_indices: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    height_position_ids: Optional[torch.Tensor] = None,
    width_position_ids: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[list[torch.Tensor]] = None,
    padding_mask: Optional[torch.Tensor] = None,
    vision_return_embed_list: Optional[bool] = False,
    image_grid_thw: Optional[list[Union[
        tuple[int, int, int],
        list[tuple[int, int, int]],
    ]]] = None,
    return_pooler_output: Optional[bool] = True,
    use_rope: Optional[bool] = False,
    window_size: Optional[bool] = -1,
) -> BaseModelOutputWithPooling:

    hidden_states = self.embeddings(
        pixel_values,
        interpolate_pos_encoding=interpolate_pos_encoding,
        position_ids=position_ids,
        image_grid_thw=image_grid_thw,
    )

    last_hidden_state = self.encoder(
        inputs_embeds=hidden_states,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        attention_mask=attention_mask,
        cu_seqlens=cu_seqlens,
        image_grid_thw=image_grid_thw,
        use_rope=use_rope,
        height_position_ids=height_position_ids,
        width_position_ids=width_position_ids,
        window_size=window_size,
        vision_or_text="vision",
    )

    last_hidden_state = self.post_layernorm(last_hidden_state)

    sample_hidden_state = list()
    if cu_seqlens is None:
        raise ValueError("cu_seqlens cannot be None for "
                         "SiglipVisionTransformer output processing.")
    for i in range(cu_seqlens.shape[0] - 1):
        start = cu_seqlens[i]
        end = cu_seqlens[i + 1]
        tensor = last_hidden_state[:, start:end, :].squeeze(0)
        sample_hidden_state.append(tensor)

    return sample_hidden_state

KeyeVideoEmbeddingInputs

Bases: TensorSchema

Dimensions
  • nf: Number of video features
  • hs: Hidden size (must match the hidden size of language model backbone)
  • nv: Number of videos
  • g: Grid dimensions (3 for t, h, w)
Source code in vllm/model_executor/models/keye.py
class KeyeVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of video features
        - hs: Hidden size (must match the hidden size of language model 
          backbone)
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["video_embeds"]
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

type instance-attribute

type: Literal['video_embeds']

video_embeds instance-attribute

video_embeds: Annotated[Tensor, TensorShape(nf, hs)]

video_grid_thw instance-attribute

video_grid_thw: Annotated[Tensor, TensorShape(nv, 3)]

KeyeVideoPixelInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • np: Number of patches
  • c: Number of channels
  • ps: Patch size
  • nv: Number of videos
  • g: Grid dimensions (3 for t, h, w)
Source code in vllm/model_executor/models/keye.py
class KeyeVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - np: Number of patches
        - c: Number of channels
        - ps: Patch size
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
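Hypothetical tensors that satisfy this schema, assuming (as elsewhere in this module) that the packed patch dimension np equals the sum of t * h * w over all videos, with an invented patch size of 14:

import torch

video_grid_thw = torch.tensor([[4, 12, 16], [2, 12, 16]])      # (nv, 3)
num_patches = int(video_grid_thw.prod(dim=-1).sum())           # 768 + 384
pixel_values_videos = torch.randn(1, num_patches, 3, 14, 14)   # (b, np, c, ps, ps)
print(num_patches, tuple(pixel_values_videos.shape))           # 1152 (1, 1152, 3, 14, 14)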

pixel_values_videos instance-attribute

pixel_values_videos: Annotated[
    Tensor,
    TensorShape(b, np, 3, ps, ps, dynamic_dims={np}),
]

type instance-attribute

type: Literal['pixel_values_videos']

video_grid_thw instance-attribute

video_grid_thw: Annotated[Tensor, TensorShape(nv, 3)]

KeyeVisionEmbeddings

Bases: Module

Source code in vllm/model_executor/models/keye.py
class KeyeVisionEmbeddings(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:

        num_positions = self.position_embedding.weight.shape[0]

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
                                                  sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def fetch_position_embedding_lfu_cache(self,
                                           embeddings,
                                           h,
                                           w,
                                           max_cache: int = 20):
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]

        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)

        position_embedding = self.interpolate_pos_encoding(
            embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() == 5:
            if position_ids is None:
                raise ValueError(
                    "position_ids cannot be None when pixel_values.dim() is 5."
                )
            (
                batch_size,
                sequence_len,
                channel,
                height,
                width,
            ) = pixel_values.shape
            target_dtype = self.patch_embedding.weight.dtype
            pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
            embeddings = patch_embeds.flatten(-2).squeeze(-1)

            if interpolate_pos_encoding and image_grid_thw is not None:
                start = 0
                tmp_embeddings = list()
                for image_grid in image_grid_thw:
                    t, h, w = image_grid
                    end = start + t * h * w
                    image_embeddings = embeddings[start:end, :]
                    position_embedding = (self.interpolate_pos_encoding(
                        image_embeddings, h, w, True).squeeze(0).repeat(t, 1))
                    image_embeddings = image_embeddings + position_embedding
                    tmp_embeddings.append(image_embeddings)
                    start = end
                embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
            else:
                embeddings = embeddings + self.packing_position_embedding(
                    position_ids)
            return embeddings
        else:
            raise ValueError("Unsupported pixel_values dimension:"
                             f" {pixel_values.dim()}. Expected 4 or 5.")

cache_position_count instance-attribute

cache_position_count = dict()

cache_position_embedding instance-attribute

cache_position_embedding = dict()

config instance-attribute

config = config

embed_dim instance-attribute

embed_dim = hidden_size

image_size instance-attribute

image_size = image_size

num_patches instance-attribute

num_patches = (image_size // patch_size) ** 2

num_positions instance-attribute

num_positions = num_patches

packing_position_embedding instance-attribute

packing_position_embedding = Embedding(32768, embed_dim)

patch_embedding instance-attribute

patch_embedding = Conv2d(
    in_channels=num_channels,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
    padding="valid",
)

patch_size instance-attribute

patch_size = patch_size

position_embedding instance-attribute

position_embedding = Embedding(num_positions, embed_dim)

__init__

__init__(config: PretrainedConfig)
Source code in vllm/model_executor/models/keye.py
def __init__(self, config: PretrainedConfig):
    super().__init__()
    self.config = config
    self.embed_dim = config.hidden_size
    self.image_size = config.image_size
    self.patch_size = config.patch_size

    self.patch_embedding = nn.Conv2d(
        in_channels=config.num_channels,
        out_channels=self.embed_dim,
        kernel_size=self.patch_size,
        stride=self.patch_size,
        padding="valid",
    )

    self.num_patches = (self.image_size // self.patch_size)**2
    self.num_positions = self.num_patches
    self.cache_position_embedding = dict()
    self.cache_position_count = dict()
    self.position_embedding = nn.Embedding(self.num_positions,
                                           self.embed_dim)
    self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

    self.register_buffer(
        "position_ids",
        torch.arange(self.num_positions).expand((1, -1)),
        persistent=False,
    )
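For concreteness (config values hypothetical), an image_size of 224 with patch_size 14 yields a 16 x 16 patch grid, so 256 pretrained position-embedding slots:

image_size, patch_size = 224, 14          # hypothetical config values
num_patches = (image_size // patch_size) ** 2
print(num_patches)                        # 256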

fetch_position_embedding_lfu_cache

fetch_position_embedding_lfu_cache(
    embeddings, h, w, max_cache: int = 20
)
Source code in vllm/model_executor/models/keye.py
def fetch_position_embedding_lfu_cache(self,
                                       embeddings,
                                       h,
                                       w,
                                       max_cache: int = 20):
    grid = (h, w)
    if grid in self.cache_position_embedding:
        self.cache_position_count[grid] += 1
        return self.cache_position_embedding[grid]

    if len(self.cache_position_embedding) >= max_cache:
        min_hit_grid = min(
            self.cache_position_count,
            key=self.cache_position_count.get,
        )
        self.cache_position_count.pop(min_hit_grid)
        self.cache_position_embedding.pop(min_hit_grid)

    position_embedding = self.interpolate_pos_encoding(
        embeddings, h, w, True)
    self.cache_position_count[grid] = 1
    self.cache_position_embedding[grid] = position_embedding
    return position_embedding

forward

forward(
    pixel_values: FloatTensor,
    position_ids: Optional[Tensor] = None,
    image_grid_thw: Optional[
        list[
            Union[
                tuple[int, int, int],
                list[tuple[int, int, int]],
            ]
        ]
    ] = None,
    interpolate_pos_encoding=False,
) -> Tensor
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    pixel_values: torch.FloatTensor,
    position_ids: Optional[torch.Tensor] = None,
    image_grid_thw: Optional[list[Union[
        tuple[int, int, int],
        list[tuple[int, int, int]],
    ]]] = None,
    interpolate_pos_encoding=False,
) -> torch.Tensor:
    if pixel_values.dim() == 4:
        pixel_values = pixel_values.unsqueeze(0)
    if pixel_values.dim() == 5:
        if position_ids is None:
            raise ValueError(
                "position_ids cannot be None when pixel_values.dim() is 5."
            )
        (
            batch_size,
            sequence_len,
            channel,
            height,
            width,
        ) = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))
        embeddings = patch_embeds.flatten(-2).squeeze(-1)

        if interpolate_pos_encoding and image_grid_thw is not None:
            start = 0
            tmp_embeddings = list()
            for image_grid in image_grid_thw:
                t, h, w = image_grid
                end = start + t * h * w
                image_embeddings = embeddings[start:end, :]
                position_embedding = (self.interpolate_pos_encoding(
                    image_embeddings, h, w, True).squeeze(0).repeat(t, 1))
                image_embeddings = image_embeddings + position_embedding
                tmp_embeddings.append(image_embeddings)
                start = end
            embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
        else:
            embeddings = embeddings + self.packing_position_embedding(
                position_ids)
        return embeddings
    else:
        raise ValueError("Unsupported pixel_values dimension:"
                         f" {pixel_values.dim()}. Expected 4 or 5.")

interpolate_pos_encoding

interpolate_pos_encoding(
    embeddings: Tensor,
    height: int,
    width: int,
    is_after_patchify: bool = False,
) -> Tensor
Source code in vllm/model_executor/models/keye.py
def interpolate_pos_encoding(
    self,
    embeddings: torch.Tensor,
    height: int,
    width: int,
    is_after_patchify: bool = False,
) -> torch.Tensor:

    num_positions = self.position_embedding.weight.shape[0]

    patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

    dim = embeddings.shape[-1]

    if is_after_patchify:
        new_height = height
        new_width = width
    else:
        new_height = height // self.patch_size
        new_width = width // self.patch_size

    sqrt_num_positions = torch_int(num_positions**0.5)
    patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
                                              sqrt_num_positions, dim)
    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(new_height, new_width),
        mode="bilinear",
        align_corners=False,
    )

    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return patch_pos_embed
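The method resizes the pretrained square grid of position embeddings to an arbitrary (h, w) patch grid with bilinear interpolation. A self-contained sketch of the same reshaping (sizes hypothetical):

import torch
import torch.nn.functional as F

num_positions, dim = 16 * 16, 32          # pretrained 16x16 position grid
pos = torch.randn(1, num_positions, dim)

h, w = 12, 20                             # target patch grid
side = int(num_positions ** 0.5)
grid = pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)        # NCHW
grid = F.interpolate(grid, size=(h, w), mode="bilinear", align_corners=False)
print(grid.permute(0, 2, 3, 1).reshape(1, -1, dim).shape)         # torch.Size([1, 240, 32])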

Projector

Bases: Module

Source code in vllm/model_executor/models/keye.py
class Projector(nn.Module):

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)

        self.hidden_size = (self.vision_config.hidden_size *
                            self.merge_kernel_size[0] *
                            self.merge_kernel_size[1])

        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size,
                                           eps=1e-05)
        self.act = GELUActivation()

        self.linear_1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.linear_2 = RowParallelLinear(
            self.hidden_size,
            self.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(
        self,
        image_features: Union[torch.Tensor, list[torch.Tensor]],
        image_grid_thw: list[tuple[int, int, int]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = list()
            for image_feature, image_grid in zip(image_features,
                                                 image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid

                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states, _ = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states, _ = self.linear_2(hidden_states)
                processed_features.append(hidden_states)

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.view(np.prod(dims), dim)
        hidden_states = self.pre_norm(image_features).view(
            -1, self.hidden_size)
        # ColumnParallelLinear/RowParallelLinear return (output, output_bias),
        # so unpack here as in the list branch above.
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)

        return hidden_states.view(*dims, -1)

act instance-attribute

act = GELUActivation()

hidden_size instance-attribute

hidden_size = (
    vision_config.hidden_size
    * merge_kernel_size[0]
    * merge_kernel_size[1]
)

linear_1 instance-attribute

linear_1 = ColumnParallelLinear(
    hidden_size,
    hidden_size,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.linear_1",
)

linear_2 instance-attribute

linear_2 = RowParallelLinear(
    hidden_size,
    text_config.hidden_size,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.linear_2",
)

merge_kernel_size instance-attribute

merge_kernel_size = (2, 2)

pre_norm instance-attribute

pre_norm = LayerNorm(vision_config.hidden_size, eps=1e-05)

text_config instance-attribute

text_config = text_config

vision_config instance-attribute

vision_config = vision_config

__init__

__init__(
    text_config: PretrainedConfig,
    vision_config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/keye.py
def __init__(
    self,
    text_config: PretrainedConfig,
    vision_config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()
    self.text_config = text_config
    self.vision_config = vision_config
    self.merge_kernel_size = (2, 2)

    self.hidden_size = (self.vision_config.hidden_size *
                        self.merge_kernel_size[0] *
                        self.merge_kernel_size[1])

    self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size,
                                       eps=1e-05)
    self.act = GELUActivation()

    self.linear_1 = ColumnParallelLinear(
        self.hidden_size,
        self.hidden_size,
        bias=True,
        quant_config=quant_config,
        prefix=f"{prefix}.linear_1",
    )
    self.linear_2 = RowParallelLinear(
        self.hidden_size,
        self.text_config.hidden_size,
        bias=True,
        quant_config=quant_config,
        prefix=f"{prefix}.linear_2",
    )
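
The width of the two projection layers follows directly from the 2x2 merge: linear_1 is square in the merged dimension and linear_2 maps it down to the language model's hidden size. A rough shape check with placeholder sizes (neither value is read from the Keye config):

vision_hidden, text_hidden = 1152, 4096        # assumed, for illustration only
merge_kernel_size = (2, 2)
merged_dim = vision_hidden * merge_kernel_size[0] * merge_kernel_size[1]
print(merged_dim, "->", merged_dim, "->", text_hidden)   # 4608 -> 4608 -> 4096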

forward

forward(
    image_features: Union[Tensor, list[Tensor]],
    image_grid_thw: list[tuple[int, int, int]],
) -> Union[Tensor, list[Tensor]]
Source code in vllm/model_executor/models/keye.py
def forward(
    self,
    image_features: Union[torch.Tensor, list[torch.Tensor]],
    image_grid_thw: list[tuple[int, int, int]],
) -> Union[torch.Tensor, list[torch.Tensor]]:
    m1, m2 = self.merge_kernel_size
    if isinstance(image_features, (list, tuple)):
        processed_features = list()
        for image_feature, image_grid in zip(image_features,
                                             image_grid_thw):
            image_feature = self.pre_norm(image_feature)
            t, h, w = image_grid

            image_feature = rearrange(
                image_feature,
                "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                t=t,
                h=h // m1,
                p1=m1,
                w=w // m2,
                p2=m2,
            )
            hidden_states, _ = self.linear_1(image_feature)
            hidden_states = self.act(hidden_states)
            hidden_states, _ = self.linear_2(hidden_states)
            processed_features.append(hidden_states)

        return processed_features

    dims = image_features.shape[:-1]
    dim = image_features.shape[-1]
    image_features = image_features.view(np.prod(dims), dim)
    hidden_states = self.pre_norm(image_features).view(
        -1, self.hidden_size)
    # ColumnParallelLinear/RowParallelLinear return (output, output_bias),
    # so unpack here as in the list branch above.
    hidden_states, _ = self.linear_1(hidden_states)
    hidden_states = self.act(hidden_states)
    hidden_states, _ = self.linear_2(hidden_states)

    return hidden_states.view(*dims, -1)
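
The rearrange pattern is the heart of the merge: every 2x2 spatial block of patch features is folded into a single token, so a (t*h*w, d) sequence becomes (t*(h//2)*(w//2), 4*d) before the MLP. A standalone sketch of that reshaping with made-up sizes:

import torch
from einops import rearrange

t, h, w, d = 1, 8, 12, 64                # illustrative grid and feature dim
m1, m2 = 2, 2
feats = torch.randn(t * h * w, d)        # flattened patch features

merged = rearrange(
    feats,
    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
    t=t, h=h // m1, p1=m1, w=w // m2, p2=m2,
)
print(merged.shape)                      # torch.Size([24, 256]) == (t * h//2 * w//2, 4 * d)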

SigLIPRotaryEmbedding

Bases: Module

Source code in vllm/model_executor/models/keye.py
class SigLIPRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        inv_freq = 1.0 / (self.theta**(
            torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        freqs = torch.outer(seq, self.inv_freq)
        return freqs

dim instance-attribute

dim = dim

theta instance-attribute

theta = theta

__init__

__init__(dim: int, theta: float = 10000.0) -> None
Source code in vllm/model_executor/models/keye.py
def __init__(self, dim: int, theta: float = 10000.0) -> None:
    super().__init__()
    self.dim = dim
    self.theta = theta
    self.rope_init()

forward

forward(seqlen: int) -> Tensor
Source code in vllm/model_executor/models/keye.py
def forward(self, seqlen: int) -> torch.Tensor:
    seq = torch.arange(
        seqlen,
        device=self.inv_freq.device,
        dtype=self.inv_freq.dtype,
    )
    freqs = torch.outer(seq, self.inv_freq)
    return freqs

rope_init

rope_init()
Source code in vllm/model_executor/models/keye.py
def rope_init(self):
    inv_freq = 1.0 / (self.theta**(
        torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
    self.register_buffer("inv_freq", inv_freq, persistent=False)
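
Note that forward returns only the raw angle table of shape (seqlen, dim // 2); turning it into cos/sin and widening it to the full head dimension is left to the caller. A small sketch of that caller-side step, assuming the usual duplicate-and-concatenate convention (the duplication is an assumption here, not shown in this class):

import torch

dim, seqlen, theta = 64, 16, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq)   # (16, 32)

emb = torch.cat((freqs, freqs), dim=-1)   # assumed caller-side widening to (16, 64)
cos, sin = emb.cos(), emb.sin()
print(freqs.shape, cos.shape)             # torch.Size([16, 32]) torch.Size([16, 64])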

_keye_field_config

_keye_field_config(hf_inputs: Mapping[str, Tensor])
Source code in vllm/model_executor/models/keye.py
def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
    image_grid_sizes = image_grid_thw.prod(-1)

    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
    video_grid_sizes = video_grid_thw.prod(-1)

    return dict(
        pixel_values=MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes),
        image_embeds=MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes),
        image_grid_thw=MultiModalFieldConfig.batched("image"),
        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes),
        video_embeds=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes),
        video_grid_thw=MultiModalFieldConfig.batched("video"),
    )
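
flat_from_sizes needs one length per item so the flat pixel_values tensor can be split back into per-image (or per-video) chunks; that length is simply t*h*w from the corresponding grid. A quick check of the size computation with made-up grids:

import torch

image_grid_thw = torch.tensor([[1, 28, 28],
                               [1, 16, 24]])   # two hypothetical images
image_grid_sizes = image_grid_thw.prod(-1)     # tensor([784, 384])
# pixel_values is then expected to contain 784 + 384 patch rows in order
print(image_grid_sizes, int(image_grid_sizes.sum()))   # tensor([784, 384]) 1168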

apply_rotary_pos_emb_flashatt

apply_rotary_pos_emb_flashatt(
    q: Tensor, k: Tensor, cos: Tensor, sin: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/keye.py
def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed
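
Because the actual rotation happens inside the vllm_flash_attn kernel, it can be useful to see a plain-PyTorch equivalent. The sketch below assumes the non-interleaved rotate-half (GPT-NeoX style) convention and is an illustrative reference, not the kernel itself:

import torch

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin have shape (seq, rotary_dim // 2); only the first rotary_dim
    # channels of x are rotated, the rest pass through unchanged.
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos = torch.cat((cos, cos), dim=-1)[:, None, :]   # broadcast over heads
    sin = torch.cat((sin, sin), dim=-1)[:, None, :]
    rotated = torch.cat((-x2, x1), dim=-1)            # rotate-half
    return torch.cat((x_rot * cos + rotated * sin, x_pass), dim=-1)

seqlen, heads, head_dim = 16, 8, 64                   # made-up sizes
angles = torch.outer(torch.arange(seqlen, dtype=torch.float),
                     1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim)))
q = torch.randn(seqlen, heads, head_dim)
print(rotary_ref(q, angles.cos(), angles.sin()).shape)   # torch.Size([16, 8, 64])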

smart_resize

smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
)
Source code in vllm/model_executor/models/keye.py
def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
):
    if height < factor:
        logger.warning(
            "smart_resize: height=%s < factor=%s, reset height=factor",
            height,
            factor,
        )
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logger.warning(
            "smart_resize: width=%s < factor=%s, reset width=factor",
            width,
            factor,
        )
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200, got "
                         "{max(height, width) / min(height, width)}")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
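
A usage sketch that exercises the documented invariants: both outputs are multiples of factor and the pixel count lands inside [min_pixels, max_pixels]. The argument values below are arbitrary, not Keye defaults, and the import path simply mirrors the "Source code in" note above:

from vllm.model_executor.models.keye import smart_resize

h_bar, w_bar = smart_resize(height=1080, width=1920, factor=28,
                            min_pixels=56 * 56, max_pixels=28 * 28 * 1280)
assert h_bar % 28 == 0 and w_bar % 28 == 0
assert 56 * 56 <= h_bar * w_bar <= 28 * 28 * 1280
print(h_bar, w_bar)   # 728 1316 for these inputs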