vllm.model_executor.models.glm4_1v

Inference-only GLM-4V model compatible with HuggingFace weights.
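
A minimal usage sketch with vLLM's offline LLM API (the checkpoint name, image
path, and sampling settings are illustrative; in practice the prompt would
normally be built via the model's chat template rather than by hand-writing the
image placeholder):

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="zai-org/GLM-4.1V-9B-Thinking")  # any GLM-4.1V checkpoint
image = Image.open("example.jpg")

outputs = llm.generate(
    {
        # Placeholder string matches get_placeholder_str("image", 0) below.
        "prompt": "<|begin_of_image|><|image|><|end_of_image|>"
                  "Describe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)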

Glm4vImageInputs module-attribute

Glm4vImageInputs = Union[
    Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs
]

Glm4vVideoInputs module-attribute

Glm4vVideoInputs = Union[
    Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs
]

_MAX_FRAMES_PER_VIDEO module-attribute

_MAX_FRAMES_PER_VIDEO = 600

logger module-attribute

logger = init_logger(__name__)

Glm4vDummyInputsBuilder

Bases: BaseDummyInputsBuilder[Glm4vProcessingInfo]

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_config = self.info.get_hf_config()
        hf_processor = self.info.get_hf_processor()
        tokenizer = self.info.get_tokenizer()

        image_token: str = hf_processor.image_token
        video_token_ids = [
            hf_config.video_start_token_id,
            hf_processor.video_token_id,
            hf_config.video_end_token_id,
        ]
        video_token = tokenizer.decode(video_token_ids)

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = (
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)
        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        video_items = []
        for i in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": [i for i in range(num_frames)],
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_item = (video.copy(), video_metadata)
            video_items.append(video_item)

        return video_items

_get_dummy_videos

_get_dummy_videos(
    *,
    width: int,
    height: int,
    num_frames: int,
    num_videos: int,
) -> list[VideoItem]
Source code in vllm/model_executor/models/glm4_1v.py
def _get_dummy_videos(
    self,
    *,
    width: int,
    height: int,
    num_frames: int,
    num_videos: int,
) -> list[VideoItem]:
    video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
    video_items = []
    for i in range(num_videos):
        video_metadata = {
            "fps": 2.0,
            "duration": num_frames / 2.0,
            "total_num_frames": num_frames,
            "frames_indices": [i for i in range(num_frames)],
            "video_backend": "opencv",
            "do_sample_frames": False,
        }
        video_item = (video.copy(), video_metadata)
        video_items.append(video_item)

    return video_items

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int, mm_counts: Mapping[str, int]
) -> MultiModalDataDict
Source code in vllm/model_executor/models/glm4_1v.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
    num_images = mm_counts.get("image", 0)
    num_videos = mm_counts.get("video", 0)

    target_width, target_height = (
        self.info.get_image_size_with_most_features())
    target_num_frames = self.info.get_num_frames_with_most_features(
        seq_len, mm_counts)
    return {
        "image":
        self._get_dummy_images(width=target_width,
                               height=target_height,
                               num_images=num_images),
        "video":
        self._get_dummy_videos(
            width=target_width,
            height=target_height,
            num_frames=target_num_frames,
            num_videos=num_videos,
        ),
    }

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/glm4_1v.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    num_images = mm_counts.get("image", 0)
    num_videos = mm_counts.get("video", 0)

    hf_config = self.info.get_hf_config()
    hf_processor = self.info.get_hf_processor()
    tokenizer = self.info.get_tokenizer()

    image_token: str = hf_processor.image_token
    video_token_ids = [
        hf_config.video_start_token_id,
        hf_processor.video_token_id,
        hf_config.video_end_token_id,
    ]
    video_token = tokenizer.decode(video_token_ids)

    return image_token * num_images + video_token * num_videos
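
Illustration of the dummy prompt this builds. The literal token strings are
assumptions for readability; the real image token comes from the HF processor,
and the video span is decoded from the config's start/token/end ids:

image_token = "<|image|>"
video_token = "<|begin_of_video|><|video|><|end_of_video|>"
mm_counts = {"image": 2, "video": 1}
dummy_text = (image_token * mm_counts["image"]
              + video_token * mm_counts["video"])
# '<|image|><|image|><|begin_of_video|><|video|><|end_of_video|>'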

Glm4vForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsLoRA, SupportsPP

Source code in vllm/model_executor/models/glm4_1v.py
@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                    SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_up_proj"]
    }

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
        })

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|begin_of_image|><|image|><|end_of_image|>"
        if modality.startswith("video"):
            return "<|begin_of_video|><|video|><|end_of_video|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Glm4vVisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )

        if config.model_type == "glm4v":
            architectures = ["Glm4ForCausalLM"]
        elif config.model_type == "glm4v_moe":
            architectures = ["Glm4MoeForCausalLM"]
        else:
            architectures = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=architectures)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Glm4vImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Glm4vImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Glm4vImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Glm4vVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Glm4vVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Glm4vVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
            self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values,
                                                         grid_thw.tolist(),
                                                         rope_type="rope_3d")
            else:
                image_embeds = self.visual(pixel_values,
                                           grid_thw=grid_thw.tolist())
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
            self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]:
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values_videos,
                                                         grid_thw.tolist(),
                                                         rope_type="rope_3d")
            else:
                video_embeds = self.visual(pixel_values_videos,
                                           grid_thw=grid_thw.tolist())
        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "image" not in mm_input_by_modality):
                mm_input_by_modality["image"] = (
                    self._parse_and_validate_image_input(**kwargs))
            if (input_key in ("pixel_values_videos", "video_embeds")
                    and "video" not in mm_input_by_modality):
                mm_input_by_modality["video"] = (
                    self._parse_and_validate_video_input(**kwargs))
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            **kwargs)
        if not mm_input_by_modality:
            return None

        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for GLM-4V.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (the default for GLM-4V
                open-source models), the shape will be `(3, seq_len)`;
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Optional intermediate tensors for pipeline
                parallelism.
            inputs_embeds: Optional pre-computed input embeddings.
            **kwargs: Additional keyword arguments.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.model",
            connector="visual.merger.",
            tower_model="visual.",
        )

config instance-attribute

config = config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.language_model.": "language_model.model.",
        "model.visual.": "visual.",
    }
)
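
The mapper rewrites HuggingFace checkpoint prefixes into vLLM module paths. A
small stand-alone sketch of the same prefix rules (the parameter names are
hypothetical examples, not taken from an actual checkpoint):

orig_to_new_prefix = {
    "lm_head.": "language_model.lm_head.",
    "model.language_model.": "language_model.model.",
    "model.visual.": "visual.",
}

def remap(name: str) -> str:
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

remap("model.visual.patch_embed.proj.weight")
# -> 'visual.patch_embed.proj.weight'
remap("model.language_model.embed_tokens.weight")
# -> 'language_model.model.embed_tokens.weight'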

language_model instance-attribute

language_model = init_vllm_registered_model(
    vllm_config=vllm_config,
    hf_config=text_config,
    prefix=maybe_prefix(prefix, "language_model"),
    architectures=architectures,
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_up_proj"],
}

supports_encoder_tp_data class-attribute instance-attribute

supports_encoder_tp_data = True

use_data_parallel instance-attribute

use_data_parallel = mm_encoder_tp_mode == 'data'

visual instance-attribute

visual = Glm4vVisionTransformer(
    vision_config,
    norm_eps=getattr(config, "rms_norm_eps", 1e-05),
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "visual"),
    use_data_parallel=use_data_parallel,
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config

    self.config = config
    self.multimodal_config = multimodal_config
    self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

    self.visual = Glm4vVisionTransformer(
        config.vision_config,
        norm_eps=getattr(config, "rms_norm_eps", 1e-5),
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
        use_data_parallel=self.use_data_parallel,
    )

    if config.model_type == "glm4v":
        architectures = ["Glm4ForCausalLM"]
    elif config.model_type == "glm4v_moe":
        architectures = ["Glm4MoeForCausalLM"]
    else:
        architectures = None

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=architectures)

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)

_parse_and_validate_image_input

_parse_and_validate_image_input(
    **kwargs: object,
) -> Optional[Glm4vImageInputs]
Source code in vllm/model_executor/models/glm4_1v.py
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[Glm4vImageInputs]:
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)
    image_grid_thw = kwargs.pop("image_grid_thw", None)

    if pixel_values is None and image_embeds is None:
        return None

    if pixel_values is not None:
        pixel_values = self._validate_and_reshape_mm_tensor(
            pixel_values, "image pixel values")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")

        return Glm4vImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )

    if image_embeds is not None:
        image_embeds = self._validate_and_reshape_mm_tensor(
            image_embeds, "image embeds")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")

        return Glm4vImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )

_parse_and_validate_multimodal_inputs

_parse_and_validate_multimodal_inputs(
    **kwargs: object,
) -> dict
Source code in vllm/model_executor/models/glm4_1v.py
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    mm_input_by_modality = {}

    # Preserve the order of modalities if there are multiple of them
    # from the order of kwargs.
    for input_key in kwargs:
        if (input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality):
            mm_input_by_modality["image"] = (
                self._parse_and_validate_image_input(**kwargs))
        if (input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality):
            mm_input_by_modality["video"] = (
                self._parse_and_validate_video_input(**kwargs))
    return mm_input_by_modality

_parse_and_validate_video_input

_parse_and_validate_video_input(
    **kwargs: object,
) -> Optional[Glm4vVideoInputs]
Source code in vllm/model_executor/models/glm4_1v.py
def _parse_and_validate_video_input(
        self, **kwargs: object) -> Optional[Glm4vVideoInputs]:
    pixel_values_videos = kwargs.pop("pixel_values_videos", None)
    video_embeds = kwargs.pop("video_embeds", None)
    video_grid_thw = kwargs.pop("video_grid_thw", None)

    if pixel_values_videos is None and video_embeds is None:
        return None

    if pixel_values_videos is not None:
        pixel_values_videos = self._validate_and_reshape_mm_tensor(
            pixel_values_videos, "video pixel values")
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")

        return Glm4vVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )

    if video_embeds is not None:
        video_embeds = self._validate_and_reshape_mm_tensor(
            video_embeds, "video embeds")
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")

        return Glm4vVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=video_embeds,
            video_grid_thw=video_grid_thw,
        )

_process_image_input

_process_image_input(
    image_input: Glm4vImageInputs,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/glm4_1v.py
def _process_image_input(
        self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]:
    grid_thw = image_input["image_grid_thw"]
    assert grid_thw.ndim == 2
    grid_thw_list = grid_thw.tolist()

    if image_input["type"] == "image_embeds":
        image_embeds = image_input["image_embeds"].type(self.visual.dtype)
    else:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(self.visual,
                                                     pixel_values,
                                                     grid_thw.tolist(),
                                                     rope_type="rope_3d")
        else:
            image_embeds = self.visual(pixel_values,
                                       grid_thw=grid_thw.tolist())
    merge_size = self.visual.spatial_merge_size
    sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
             (merge_size * merge_size)).tolist()
    return image_embeds.split(sizes)
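
Worked example of the split sizes above, assuming two images with grid_thw of
(1, 36, 52) and (1, 24, 24) and spatial_merge_size = 2 (illustrative values
only):

import torch

grid_thw_list = [[1, 36, 52], [1, 24, 24]]
merge_size = 2
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
         (merge_size * merge_size)).tolist()
# sizes == [468, 144]: the concatenated image embeddings are split into one
# tensor of 468 tokens and one of 144 tokens, one per image.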

_process_video_input

_process_video_input(
    video_input: Glm4vVideoInputs,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/glm4_1v.py
def _process_video_input(
        self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]:
    grid_thw = video_input["video_grid_thw"]
    assert grid_thw.ndim == 2
    grid_thw_list = grid_thw.tolist()

    if video_input["type"] == "video_embeds":
        video_embeds = video_input["video_embeds"].type(self.visual.dtype)
    else:
        pixel_values_videos = video_input["pixel_values_videos"].type(
            self.visual.dtype)
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(self.visual,
                                                     pixel_values_videos,
                                                     grid_thw.tolist(),
                                                     rope_type="rope_3d")
        else:
            video_embeds = self.visual(pixel_values_videos,
                                       grid_thw=grid_thw.tolist())
    # Split concatenated embeddings for each video item.
    merge_size = self.visual.spatial_merge_size
    sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
             (merge_size * merge_size)).tolist()
    return video_embeds.split(sizes)

_validate_and_reshape_mm_tensor

_validate_and_reshape_mm_tensor(
    mm_input: object, name: str
) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                    name: str) -> torch.Tensor:
    if not isinstance(mm_input, (torch.Tensor, list)):
        raise ValueError(
            f"Incorrect type of {name}. Got type: {type(mm_input)}")
    if isinstance(mm_input, torch.Tensor):
        if mm_input.ndim == 2:
            return mm_input
        if mm_input.ndim != 3:
            raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                             f"Got ndim: {mm_input.ndim} "
                             f"(shape={mm_input.shape})")
        return mm_input.reshape(-1, mm_input.shape[-1])
    else:
        return torch.concat(mm_input)
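
A quick sketch of the three accepted input forms (dummy shapes; the feature
width 1176 is arbitrary here):

import torch

flat = torch.zeros(10, 1176)        # 2D tensor: returned unchanged
batched = torch.zeros(2, 5, 1176)   # batched 3D tensor: flattened to 2D
as_list = [torch.zeros(3, 1176),    # list of tensors: concatenated along dim 0
           torch.zeros(7, 1176)]

batched.reshape(-1, batched.shape[-1]).shape  # torch.Size([10, 1176])
torch.concat(as_list).shape                   # torch.Size([10, 1176])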

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/glm4_1v.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    return self.language_model.compute_logits(hidden_states)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]

Run forward pass for GLM-4V.

Parameters:

  • input_ids (Tensor, required): Flattened (concatenated) input_ids
    corresponding to a batch.
  • positions (Tensor, required): Flattened (concatenated) position ids
    corresponding to a batch. NOTE: If mrope is enabled (the default for
    GLM-4V open-source models), the shape will be (3, seq_len); otherwise
    it will be (seq_len,).
  • intermediate_tensors (Optional[IntermediateTensors], default None):
    Optional intermediate tensors for pipeline parallelism.
  • inputs_embeds (Optional[Tensor], default None): Optional pre-computed
    input embeddings.
  • **kwargs (object, default {}): Additional keyword arguments.
Source code in vllm/model_executor/models/glm4_1v.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run forward pass for GLM-4V.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        positions: Flattened (concatenated) position ids corresponding to a
            batch.
            **NOTE**: If mrope is enabled (the default for GLM-4V
            open-source models), the shape will be `(3, seq_len)`;
            otherwise it will be `(seq_len,)`.
        intermediate_tensors: Optional intermediate tensors for pipeline
            parallelism.
        inputs_embeds: Optional pre-computed input embeddings.
        **kwargs: Additional keyword arguments.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None

    hidden_states = self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
    return hidden_states
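
A small sketch of the two positions layouts the docstring describes (seq_len is
arbitrary; interpreting the three mrope rows as temporal, height, and width
position components is an assumption not stated in the code above):

import torch

seq_len = 8
positions_mrope = torch.zeros(3, seq_len, dtype=torch.long)  # mrope: (3, seq_len)
positions_plain = torch.arange(seq_len)                      # otherwise: (seq_len,)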

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/glm4_1v.py
def get_language_model(self) -> torch.nn.Module:
    return self.language_model

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/glm4_1v.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    return MultiModelKeys.from_string_field(
        language_model="language_model.model",
        connector="visual.merger.",
        tower_model="visual.",
    )

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> Optional[MultiModalEmbeddings]
Source code in vllm/model_executor/models/glm4_1v.py
def get_multimodal_embeddings(
        self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
    mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
        **kwargs)
    if not mm_input_by_modality:
        return None

    # The result multimodal_embeddings is a tuple of tensors, with each
    # tensor corresponding to a multimodal data item (image or video).
    multimodal_embeddings: tuple[torch.Tensor, ...] = ()

    # NOTE: It is important to iterate over the keys in this dictionary
    # to preserve the order of the modalities.
    for modality in mm_input_by_modality:
        multimodal_input = mm_input_by_modality[modality]
        if modality == "image":
            vision_embeddings = self._process_image_input(multimodal_input)
            multimodal_embeddings += vision_embeddings
        if modality == "video":
            video_embeddings = self._process_video_input(multimodal_input)
            multimodal_embeddings += video_embeddings
    return multimodal_embeddings

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> Optional[str]
Source code in vllm/model_executor/models/glm4_1v.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    if modality.startswith("image"):
        return "<|begin_of_image|><|image|><|end_of_image|>"
    if modality.startswith("video"):
        return "<|begin_of_video|><|video|><|end_of_video|>"

    raise ValueError("Only image or video modality is supported")
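
Usage sketch for the classmethod above:

Glm4vForConditionalGeneration.get_placeholder_str("image", 0)
# -> '<|begin_of_image|><|image|><|end_of_image|>'
Glm4vForConditionalGeneration.get_placeholder_str("video", 0)
# -> '<|begin_of_video|><|video|><|end_of_video|>'
Glm4vForConditionalGeneration.get_placeholder_str("audio", 0)
# -> raises ValueError: Only image or video modality is supported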

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/glm4_1v.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

Glm4vImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • f: Number of image features (varies based on image resolution)
  • h: Hidden size (must match language model backbone)
  • n: Number of images
  • g: Grid dimensions (3 for grid_t, grid_h, grid_w)
Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - f: Number of image features (varies based on image resolution)
        - h: Hidden size (must match language model backbone)
        - n: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """
    type: Literal["image_embeds"] = "image_embeds"

    image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]

image_embeds instance-attribute

image_embeds: Annotated[Tensor, TensorShape(f, h)]

image_grid_thw instance-attribute

image_grid_thw: Annotated[Tensor, TensorShape(n, 3)]

type class-attribute instance-attribute

type: Literal['image_embeds'] = 'image_embeds'

Glm4vImagePixelInputs

Bases: TensorSchema

Dimensions
  • np: Number of patches
  • cpp: Number of channels * patch_size * patch_size
  • ni: Number of images
  • g: Grid dimensions (3 for grid_t, grid_h, grid_w)
Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: Number of patches
        - cpp: Number of channels * patch_size * patch_size
        - ni: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """
    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

image_grid_thw instance-attribute

image_grid_thw: Annotated[Tensor, TensorShape(ni, 3)]

pixel_values instance-attribute

pixel_values: Annotated[Tensor, TensorShape(np, cpp)]

type class-attribute instance-attribute

type: Literal['pixel_values'] = 'pixel_values'
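
Shape-consistency sketch for this schema (patch_size = 14 and the grid values
are assumptions for illustration; the real geometry comes from the vision
config and the processor output):

import torch

image_grid_thw = torch.tensor([[1, 36, 52],
                               [1, 24, 24]])     # TensorShape("ni", 3)
np_patches = int(image_grid_thw.prod(-1).sum())  # np: 1872 + 576 = 2448 patches
cpp = 3 * 14 * 14                                # channels * patch_size ** 2
pixel_values = torch.zeros(np_patches, cpp)      # TensorShape("np", "cpp")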

Glm4vMoeForConditionalGeneration

Bases: Glm4vForConditionalGeneration

Source code in vllm/model_executor/models/glm4_1v.py
@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

Glm4vMultiModalProcessor

Bases: BaseMultiModalProcessor[Glm4vProcessingInfo]

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        return MultiModalDataParser(video_needs_metadata=True)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # GLM-4.1V uses `image_token_id` as the video placeholder, so we
        # need to replace it with `video_token_id` for video processing.
        # That is why video processing is separated from image processing.
        if ("videos" in mm_data and isinstance(mm_data["videos"], list)
                and len(mm_data["videos"]) > 0):
            video_grid_thw_lst = []
            pixel_values_videos_lst = []
            for item in mm_data.pop("videos", []):
                video_array, metadata = item

                # don't update mm_kwargs inplace
                video_mm_kwargs = dict(**mm_kwargs)
                video_mm_kwargs["do_sample_frames"] = metadata.get(
                    "do_sample_frames", True)

                video_mm_data = dict()
                video_mm_data["videos"] = [[video_array]]

                # backward compatibility for Transformers 4.55
                unuse_metadata = ["do_sample_frames"]
                if not hasattr(
                        VideoMetadata,
                        "frames_indices") and "frames_indices" in metadata:
                    unuse_metadata.append("frames_indices")

                video_mm_data["video_metadata"] = [[
                    VideoMetadata(
                        **{
                            k: metadata[k]
                            for k in metadata if k not in unuse_metadata
                        })
                ]]

                video_outputs = super()._call_hf_processor(
                    prompt="<|begin_of_video|><|video|><|end_of_video|>",
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                if not video_mm_kwargs["do_sample_frames"] and Version(
                        TRANSFORMERS_VERSION) < Version("4.56.0"):
                    # Transformers v4.55 has an incorrect-timestamps issue
                    # when frame sampling is skipped, so we construct the
                    # placeholder manually to get correct timestamps.
                    placeholder = self.info._construct_video_placeholder(
                        video_array,
                        metadata,
                        video_outputs["video_grid_thw"].squeeze(0),
                    )
                    video_placeholder = processor.tokenizer.decode(placeholder)
                else:
                    input_ids = video_outputs.pop("input_ids")
                    input_ids[input_ids == processor.image_token_id] = (
                        processor.video_token_id)
                    video_placeholder = processor.tokenizer.batch_decode(
                        input_ids)[0]
                prompt = prompt.replace(
                    "<|begin_of_video|><|video|><|end_of_video|>",
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(video_outputs["video_grid_thw"])
                pixel_values_videos_lst.append(
                    video_outputs["pixel_values_videos"])
            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
            )
        else:
            video_outputs = dict()

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size)(
                hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)

        merge_length = image_processor.merge_size**2

        def get_image_replacement_glm4v(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_glm4v(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            video, metadata = mm_items["video"][item_idx]
            placeholder = self.info._construct_video_placeholder(
                video, metadata, grid_thw)
            return PromptUpdateDetails.select_token_id(
                placeholder,
                embed_token_id=hf_processor.video_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_glm4v,
            ),
            PromptReplacement(
                modality="video",
                target="<|begin_of_video|><|video|><|end_of_video|>",
                replacement=get_video_replacement_glm4v,
            ),
        ]

_call_hf_processor

_call_hf_processor(
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature
Source code in vllm/model_executor/models/glm4_1v.py
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    mm_data = dict(mm_data)
    processor = self.info.get_hf_processor(**mm_kwargs)

    # GLM-4.1V uses `image_token_id` as the video placeholder, so we
    # need to replace it with `video_token_id` for video processing.
    # That is why video processing is separated from image processing.
    if ("videos" in mm_data and isinstance(mm_data["videos"], list)
            and len(mm_data["videos"]) > 0):
        video_grid_thw_lst = []
        pixel_values_videos_lst = []
        for item in mm_data.pop("videos", []):
            video_array, metadata = item

            # don't update mm_kwargs inplace
            video_mm_kwargs = dict(**mm_kwargs)
            video_mm_kwargs["do_sample_frames"] = metadata.get(
                "do_sample_frames", True)

            video_mm_data = dict()
            video_mm_data["videos"] = [[video_array]]

            # backward compatibility for Transformers 4.55
            unuse_metadata = ["do_sample_frames"]
            if not hasattr(
                    VideoMetadata,
                    "frames_indices") and "frames_indices" in metadata:
                unuse_metadata.append("frames_indices")

            video_mm_data["video_metadata"] = [[
                VideoMetadata(
                    **{
                        k: metadata[k]
                        for k in metadata if k not in unuse_metadata
                    })
            ]]

            video_outputs = super()._call_hf_processor(
                prompt="<|begin_of_video|><|video|><|end_of_video|>",
                mm_data=video_mm_data,
                mm_kwargs=video_mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
            if not video_mm_kwargs["do_sample_frames"] and Version(
                    TRANSFORMERS_VERSION) < Version("4.56.0"):
                # Transformers v4.55 has an incorrect-timestamps issue
                # when frame sampling is skipped, so we construct the
                # placeholder manually to get correct timestamps.
                placeholder = self.info._construct_video_placeholder(
                    video_array,
                    metadata,
                    video_outputs["video_grid_thw"].squeeze(0),
                )
                video_placeholder = processor.tokenizer.decode(placeholder)
            else:
                input_ids = video_outputs.pop("input_ids")
                input_ids[input_ids == processor.image_token_id] = (
                    processor.video_token_id)
                video_placeholder = processor.tokenizer.batch_decode(
                    input_ids)[0]
            prompt = prompt.replace(
                "<|begin_of_video|><|video|><|end_of_video|>",
                video_placeholder,
                1,
            )

            video_grid_thw_lst.append(video_outputs["video_grid_thw"])
            pixel_values_videos_lst.append(
                video_outputs["pixel_values_videos"])
        video_outputs = dict(
            pixel_values_videos=torch.cat(pixel_values_videos_lst),
            video_grid_thw=torch.cat(video_grid_thw_lst),
        )
    else:
        video_outputs = dict()

    processed_outputs = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )
    combined_outputs = dict(
        processed_outputs,
        **video_outputs,
    )
    return BatchFeature(combined_outputs)

_get_data_parser

_get_data_parser() -> MultiModalDataParser
Source code in vllm/model_executor/models/glm4_1v.py
def _get_data_parser(self) -> MultiModalDataParser:
    return MultiModalDataParser(video_needs_metadata=True)

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/glm4_1v.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return _create_qwen2vl_field_factory(
        self.info.get_hf_config().vision_config.spatial_merge_size)(
            hf_inputs)

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/glm4_1v.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    image_processor = self.info.get_image_processor(
        **hf_processor_mm_kwargs)

    merge_length = image_processor.merge_size**2

    def get_image_replacement_glm4v(item_idx: int):
        out_item = out_mm_kwargs["image"][item_idx]
        grid_thw = out_item["image_grid_thw"].data
        assert isinstance(grid_thw, torch.Tensor)

        num_tokens = int(grid_thw.prod()) // merge_length
        return [hf_processor.image_token_id] * num_tokens

    def get_video_replacement_glm4v(item_idx: int):
        out_item = out_mm_kwargs["video"][item_idx]
        grid_thw = out_item["video_grid_thw"].data
        assert isinstance(grid_thw, torch.Tensor)

        video, metadata = mm_items["video"][item_idx]
        placeholder = self.info._construct_video_placeholder(
            video, metadata, grid_thw)
        return PromptUpdateDetails.select_token_id(
            placeholder,
            embed_token_id=hf_processor.video_token_id,
        )

    return [
        PromptReplacement(
            modality="image",
            target=hf_processor.image_token,
            replacement=get_image_replacement_glm4v,
        ),
        PromptReplacement(
            modality="video",
            target="<|begin_of_video|><|video|><|end_of_video|>",
            replacement=get_video_replacement_glm4v,
        ),
    ]

Glm4vPatchMerger

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vPatchMerger(nn.Module):

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = d_model
        self.proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=bias,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[context_dim] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor):
        x, _ = self.proj(x)
        x = self.extra_activation_func(self.post_projection_norm(x))
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    context_dim,
    hidden_size,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
    disable_tp=use_data_parallel,
)

extra_activation_func instance-attribute

extra_activation_func = GELU()

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    input_size=hidden_size,
    output_sizes=[context_dim] * 2,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
    disable_tp=use_data_parallel,
)

hidden_size instance-attribute

hidden_size = d_model

post_projection_norm instance-attribute

post_projection_norm = LayerNorm(hidden_size)

proj instance-attribute

proj = ColumnParallelLinear(
    hidden_size,
    hidden_size,
    bias=bias,
    gather_output=True,
    quant_config=quant_config,
    prefix=f"{prefix}.proj",
    disable_tp=use_data_parallel,
)

__init__

__init__(
    d_model: int,
    context_dim: int,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(
    self,
    d_model: int,
    context_dim: int,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    super().__init__()
    self.hidden_size = d_model
    self.proj = ColumnParallelLinear(
        self.hidden_size,
        self.hidden_size,
        bias=bias,
        gather_output=True,
        quant_config=quant_config,
        prefix=f"{prefix}.proj",
        disable_tp=use_data_parallel,
    )
    self.post_projection_norm = nn.LayerNorm(self.hidden_size)
    self.gate_up_proj = MergedColumnParallelLinear(
        input_size=self.hidden_size,
        output_sizes=[context_dim] * 2,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
        disable_tp=use_data_parallel,
    )
    self.down_proj = RowParallelLinear(
        context_dim,
        self.hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
        disable_tp=use_data_parallel,
    )
    self.act_fn = SiluAndMul()
    self.extra_activation_func = nn.GELU()

forward

forward(x: Tensor)
Source code in vllm/model_executor/models/glm4_1v.py
def forward(self, x: torch.Tensor):
    x, _ = self.proj(x)
    x = self.extra_activation_func(self.post_projection_norm(x))
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x
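
Shape flow through forward, as a sketch (n is the number of tokens; the
SiluAndMul step is assumed to halve the last dimension via silu(gate) * up):

# x: (n, d_model)
# proj ................ (n, d_model) -> (n, d_model)
# LayerNorm + GELU .... (n, d_model)
# gate_up_proj ........ (n, 2 * context_dim)
# SiluAndMul .......... (n, context_dim)
# down_proj ........... (n, d_model)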

Glm4vProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_tokenizer(self):
        return self.ctx.tokenizer

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": 1}

    def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 16,
        do_resize: bool = True,
        max_image_pixels: int = 28 * 28 * 2 * 30000,
    ) -> tuple[ImageSize, int]:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size
        if do_resize:
            resized_height, resized_width = smart_resize(
                num_frames=num_frames
                if num_frames > temporal_patch_size else temporal_patch_size,
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                max_pixels=max_image_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(image_width=9999999,
                                                  image_height=9999999)
        return max_image_size

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            max_image_pixels=28 * 28 * 2 * 6144,
        )
        return num_image_tokens

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            max_image_pixels=28 * 28 * 2 * 30000,
        )
        return num_video_tokens

    def _get_max_video_frames(self, max_tokens: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )
            if next_max_tokens > max_tokens or next_max_tokens == 0:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)

    def _get_video_second_idx(self, metadata: dict[str, Any],
                              total_frames: int) -> list[int]:
        video_processor = self.get_video_processor()

        video_fps = metadata.get("fps", video_processor.fps)
        meta_frames = metadata.get("total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = metadata.get("duration",
                                round(max_frame_idx / video_fps) + 1)
        do_sample_frames = metadata["do_sample_frames"]
        if not do_sample_frames:
            frame_indices = metadata["frames_indices"]
        else:
            if duration <= video_processor.max_duration:
                n = int(math.floor(duration * video_processor.fps))
                frame_indices = [
                    min(
                        max_frame_idx,
                        int(math.ceil(i * video_fps / video_processor.fps)),
                    ) for i in range(n)
                ]
            else:
                num_samples = int(video_processor.max_duration *
                                  video_processor.fps)
                if num_samples >= meta_frames:
                    frame_indices = list(range(meta_frames))
                else:
                    target_seconds = np.linspace(0,
                                                 duration,
                                                 num_samples,
                                                 endpoint=True)
                    frame_indices = [
                        min(max_frame_idx, int(math.ceil(t * video_fps)))
                        for t in target_seconds
                    ]

        seen, uniq = set(), []
        for idx in frame_indices:
            if idx not in seen:
                seen.add(idx)
                uniq.append(idx)
        if len(uniq) & 1:
            uniq.append(uniq[-1])
        frame_indices = uniq

        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
        # Every temporal pair of frames shares one timestamp label.
        return full_second_idxs[::2]

    def _construct_video_placeholder(
        self,
        video_array: np.ndarray,
        metadata: dict[str, Any],
        grid_thw: torch.Tensor,
    ) -> list[int]:
        hf_processor = self.get_hf_processor()
        tokenizer = self.get_tokenizer()
        image_processor = hf_processor.image_processor

        hf_config = self.get_hf_config()
        boi_token_id = hf_config.image_start_token_id
        eoi_token_id = hf_config.image_end_token_id
        bov_token_id = hf_config.video_start_token_id
        eov_token_id = hf_config.video_end_token_id
        merge_length = image_processor.merge_size**2

        assert isinstance(grid_thw, torch.Tensor)
        timestamps = self._get_video_second_idx(metadata, len(video_array))
        frames_idx_token = [
            tokenizer.encode(str(i), add_special_tokens=False)
            for i in timestamps
        ]
        T, H, W = grid_thw
        num_tokens_per_frame = int(H * W) // merge_length
        placeholder = []
        placeholder.append(bov_token_id)
        for frame_idx in frames_idx_token:
            placeholder.append(boi_token_id)
            placeholder.extend([hf_processor.video_token_id] *
                               num_tokens_per_frame)
            placeholder.append(eoi_token_id)
            placeholder.extend(frame_idx)
        placeholder.append(eov_token_id)

        return placeholder

_construct_video_placeholder

_construct_video_placeholder(
    video_array: ndarray,
    metadata: dict[str, Any],
    grid_thw: Tensor,
) -> list[int]
Source code in vllm/model_executor/models/glm4_1v.py
def _construct_video_placeholder(
    self,
    video_array: np.ndarray,
    metadata: dict[str, Any],
    grid_thw: torch.Tensor,
) -> list[int]:
    hf_processor = self.get_hf_processor()
    tokenizer = self.get_tokenizer()
    image_processor = hf_processor.image_processor

    hf_config = self.get_hf_config()
    boi_token_id = hf_config.image_start_token_id
    eoi_token_id = hf_config.image_end_token_id
    bov_token_id = hf_config.video_start_token_id
    eov_token_id = hf_config.video_end_token_id
    merge_length = image_processor.merge_size**2

    assert isinstance(grid_thw, torch.Tensor)
    timestamps = self._get_video_second_idx(metadata, len(video_array))
    frames_idx_token = [
        tokenizer.encode(str(i), add_special_tokens=False)
        for i in timestamps
    ]
    T, H, W = grid_thw
    num_tokens_per_frame = int(H * W) // merge_length
    placeholder = []
    placeholder.append(bov_token_id)
    for frame_idx in frames_idx_token:
        placeholder.append(boi_token_id)
        placeholder.extend([hf_processor.video_token_id] *
                           num_tokens_per_frame)
        placeholder.append(eoi_token_id)
        placeholder.extend(frame_idx)
    placeholder.append(eov_token_id)

    return placeholder
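
The resulting placeholder interleaves per-frame image-boundary tokens with the encoded timestamp of each frame. A layout sketch with made-up token IDs (the real IDs come from the HF config and tokenizer, and num_tokens_per_frame from the grid):

BOV, EOV, BOI, EOI, VID = 9001, 9002, 9003, 9004, 9005  # hypothetical IDs
num_tokens_per_frame = 4            # stands in for (H * W) // merge_size**2
timestamp_tokens = [[11], [12]]     # stands in for tokenizer.encode("0"), encode("1")

placeholder = [BOV]
for ts in timestamp_tokens:
    placeholder += [BOI] + [VID] * num_tokens_per_frame + [EOI] + ts
placeholder += [EOV]
# [9001, 9003, 9005, 9005, 9005, 9005, 9004, 11, 9003, ..., 9004, 12, 9002]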

_get_max_video_frames

_get_max_video_frames(max_tokens: int) -> int
Source code in vllm/model_executor/models/glm4_1v.py
def _get_max_video_frames(self, max_tokens: int) -> int:
    target_width, target_height = self.get_image_size_with_most_features()

    num_frames = 0

    while True:
        next_num_frames = num_frames + 1
        next_max_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=next_num_frames,
        )
        if next_max_tokens > max_tokens or next_max_tokens == 0:
            break

        num_frames = next_num_frames

    return num_frames
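
A quick sketch of how the probe behaves, using a made-up per-frame token cost in place of get_num_video_tokens (illustrative only):

def max_video_frames(max_tokens, tokens_for_frames):
    num_frames = 0
    while True:
        nxt = tokens_for_frames(num_frames + 1)
        if nxt > max_tokens or nxt == 0:
            return num_frames
        num_frames += 1

# assume 64 tokens per temporal pair of frames
print(max_video_frames(1000, lambda f: 64 * ((f + f % 2) // 2)))  # 30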

_get_video_second_idx

_get_video_second_idx(
    metadata: dict[str, Any], total_frames: int
) -> list[int]
Source code in vllm/model_executor/models/glm4_1v.py
def _get_video_second_idx(self, metadata: dict[str, Any],
                          total_frames: int) -> list[int]:
    video_processor = self.get_video_processor()

    video_fps = metadata.get("fps", video_processor.fps)
    meta_frames = metadata.get("total_num_frames", total_frames)
    max_frame_idx = meta_frames - 1
    duration = metadata.get("duration",
                            round(max_frame_idx / video_fps) + 1)
    do_sample_frames = metadata["do_sample_frames"]
    if not do_sample_frames:
        frame_indices = metadata["frames_indices"]
    else:
        if duration <= video_processor.max_duration:
            n = int(math.floor(duration * video_processor.fps))
            frame_indices = [
                min(
                    max_frame_idx,
                    int(math.ceil(i * video_fps / video_processor.fps)),
                ) for i in range(n)
            ]
        else:
            num_samples = int(video_processor.max_duration *
                              video_processor.fps)
            if num_samples >= meta_frames:
                frame_indices = list(range(meta_frames))
            else:
                target_seconds = np.linspace(0,
                                             duration,
                                             num_samples,
                                             endpoint=True)
                frame_indices = [
                    min(max_frame_idx, int(math.ceil(t * video_fps)))
                    for t in target_seconds
                ]

    seen, uniq = set(), []
    for idx in frame_indices:
        if idx not in seen:
            seen.add(idx)
            uniq.append(idx)
    if len(uniq) & 1:
        uniq.append(uniq[-1])
    frame_indices = uniq

    full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
    # Every temporal pair of frames shares one timestamp label.
    return full_second_idxs[::2]
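
For metadata in the shape the dummy-inputs builder produces (fps=2.0, do_sample_frames=False, frames_indices covering every frame), the bookkeeping reduces to one second-label per temporal pair of frames. A minimal sketch:

video_fps = 2.0
frame_indices = list(range(8))  # 8 unique frames, even count: dedup/pad are no-ops
full_second_idxs = [int(i / video_fps) for i in frame_indices]  # [0, 0, 1, 1, 2, 2, 3, 3]
print(full_second_idxs[::2])                                    # [0, 1, 2, 3]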

_get_vision_info

_get_vision_info(
    *,
    image_width: int,
    image_height: int,
    num_frames: int = 16,
    do_resize: bool = True,
    max_image_pixels: int = 28 * 28 * 2 * 30000,
) -> tuple[ImageSize, int]
Source code in vllm/model_executor/models/glm4_1v.py
def _get_vision_info(
    self,
    *,
    image_width: int,
    image_height: int,
    num_frames: int = 16,
    do_resize: bool = True,
    max_image_pixels: int = 28 * 28 * 2 * 30000,
) -> tuple[ImageSize, int]:
    hf_config = self.get_hf_config()
    vision_config = hf_config.vision_config
    patch_size = vision_config.patch_size
    merge_size = vision_config.spatial_merge_size
    temporal_patch_size = vision_config.temporal_patch_size
    if do_resize:
        resized_height, resized_width = smart_resize(
            num_frames=num_frames
            if num_frames > temporal_patch_size else temporal_patch_size,
            height=image_height,
            width=image_width,
            factor=patch_size * merge_size,
            max_pixels=max_image_pixels,
        )
        preprocessed_size = ImageSize(width=resized_width,
                                      height=resized_height)
    else:
        preprocessed_size = ImageSize(width=image_width,
                                      height=image_height)

    # NOTE: Frames are padded to be divisible by `temporal_patch_size`
    # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
    padded_num_frames = num_frames + num_frames % temporal_patch_size

    grid_t = max(padded_num_frames // temporal_patch_size, 1)
    grid_h = preprocessed_size.height // patch_size
    grid_w = preprocessed_size.width // patch_size

    num_patches = grid_t * grid_h * grid_w
    num_vision_tokens = num_patches // (merge_size**2)

    return preprocessed_size, num_vision_tokens
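
A worked example of the token arithmetic, assuming GLM-4.1V-style vision settings (patch_size=14, spatial_merge_size=2, temporal_patch_size=2); the real values live in the model's vision_config:

patch_size, merge_size, temporal_patch_size = 14, 2, 2

def vision_tokens(width, height, num_frames=1):
    # assumes width/height are already multiples of patch_size * merge_size
    padded = num_frames + num_frames % temporal_patch_size
    grid_t = max(padded // temporal_patch_size, 1)
    grid_h, grid_w = height // patch_size, width // patch_size
    return (grid_t * grid_h * grid_w) // (merge_size ** 2)

print(vision_tokens(1120, 784))      # 80 * 56 patches -> 1120 vision tokens
print(vision_tokens(1120, 784, 16))  # 16 frames -> grid_t = 8 -> 8960 tokens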

get_hf_config

get_hf_config()
Source code in vllm/model_executor/models/glm4_1v.py
def get_hf_config(self):
    return self.ctx.get_hf_config()

get_image_processor

get_image_processor(
    **kwargs: object,
) -> Glm4vImageProcessor
Source code in vllm/model_executor/models/glm4_1v.py
def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
    return self.get_hf_processor(**kwargs).image_processor

get_image_size_with_most_features

get_image_size_with_most_features() -> ImageSize
Source code in vllm/model_executor/models/glm4_1v.py
def get_image_size_with_most_features(self) -> ImageSize:
    max_image_size, _ = self._get_vision_info(image_width=9999999,
                                              image_height=9999999)
    return max_image_size

get_max_image_tokens

get_max_image_tokens() -> int
Source code in vllm/model_executor/models/glm4_1v.py
def get_max_image_tokens(self) -> int:
    target_width, target_height = self.get_image_size_with_most_features()

    return self.get_num_image_tokens(
        image_width=target_width,
        image_height=target_height,
    )

get_num_frames_with_most_features

get_num_frames_with_most_features(
    seq_len: int, mm_counts: Mapping[str, int]
) -> int
Source code in vllm/model_executor/models/glm4_1v.py
def get_num_frames_with_most_features(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> int:
    max_images = mm_counts.get("image", 0)
    max_videos = mm_counts.get("video", 0)

    max_image_tokens = self.get_max_image_tokens() * max_images
    max_total_frames = self._get_max_video_frames(seq_len -
                                                  max_image_tokens)
    max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                               _MAX_FRAMES_PER_VIDEO)

    return max(max_frames_per_video, 1)
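
The clamping at the end is easiest to see with concrete (made-up) numbers: suppose the frame probe says 900 frames fit after reserving image tokens, and the request carries two videos.

_MAX_FRAMES_PER_VIDEO = 600
max_total_frames, max_videos = 900, 2  # illustrative values
per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO)
print(max(per_video, 1))  # 450 frames per video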

get_num_image_tokens

get_num_image_tokens(
    *, image_width: int, image_height: int
) -> int
Source code in vllm/model_executor/models/glm4_1v.py
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    _, num_image_tokens = self._get_vision_info(
        image_width=image_width,
        image_height=image_height,
        max_image_pixels=28 * 28 * 2 * 6144,
    )
    return num_image_tokens

get_num_video_tokens

get_num_video_tokens(
    *, image_width: int, image_height: int, num_frames: int
) -> int
Source code in vllm/model_executor/models/glm4_1v.py
def get_num_video_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
    num_frames: int,
) -> int:
    _, num_video_tokens = self._get_vision_info(
        image_width=image_width,
        image_height=image_height,
        num_frames=num_frames,
        max_image_pixels=28 * 28 * 2 * 30000,
    )
    return num_video_tokens

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, Optional[int]]
Source code in vllm/model_executor/models/glm4_1v.py
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    return {"image": None, "video": 1}

get_tokenizer

get_tokenizer()
Source code in vllm/model_executor/models/glm4_1v.py
def get_tokenizer(self):
    return self.ctx.tokenizer

get_video_processor

get_video_processor(
    **kwargs: object,
) -> Glm4vVideoProcessor
Source code in vllm/model_executor/models/glm4_1v.py
def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
    return self.get_hf_processor(**kwargs).video_processor

Glm4vVideoEmbeddingInputs

Bases: TensorSchema

Dimensions
  • p: Number of video patches across all frames
  • h: Hidden size (must match language model backbone)
  • f: Number of frames
  • g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w)
Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - p: Number of video patches across all frames
        - h: Hidden size (must match language model backbone)
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """
    type: Literal["video_embeds"] = "video_embeds"

    video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]

type class-attribute instance-attribute

type: Literal['video_embeds'] = 'video_embeds'

video_embeds instance-attribute

video_embeds: Annotated[Tensor, TensorShape(p, h)]

video_grid_thw instance-attribute

video_grid_thw: Annotated[Tensor, TensorShape(f, 3)]

Glm4vVideoPixelInputs

Bases: TensorSchema

Dimensions
  • np: Number of patches
  • ctpp: Number of channels * temporal_patch_size * patch_size * patch_size
  • f: Number of frames
  • g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w)
Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - np: Number of patches
        - ctpp: Number of channels * temporal_patch_size *
            patch_size * patch_size
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """
    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]

pixel_values_videos instance-attribute

pixel_values_videos: Annotated[
    Tensor, TensorShape(np, ctpp)
]

type class-attribute instance-attribute

type: Literal["pixel_values_videos"] = "pixel_values_videos"

video_grid_thw instance-attribute

video_grid_thw: Annotated[Tensor, TensorShape(f, 3)]
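
A shape sketch with illustrative values (one video with grid_thw = (4, 16, 16), 3 channels, temporal_patch_size=2, patch_size=14):

import torch

t, h, w = 4, 16, 16
ctpp = 3 * 2 * 14 * 14                      # channels * temporal_patch * patch * patch
pixel_values_videos = torch.zeros(t * h * w, ctpp)
video_grid_thw = torch.tensor([[t, h, w]])  # one (t, h, w) row per video
print(pixel_values_videos.shape, video_grid_thw.shape)
# torch.Size([1024, 1176]) torch.Size([1, 3])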

Glm4vVisionAttention

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_rank())
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=False,
            quant_config=quant_config,
            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            bias=False,
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN,
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
        }:
            raise RuntimeError(f"GLM-4V does not currently support "
                               f"the {self.attn_backend} backend.")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        if rotary_pos_emb is not None:
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.attn_backend == _Backend.FLASH_ATTN:
            # from vllm_flash_attn.flash_attn_interface import (
            #   flash_attn_varlen_func)
            if self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0,
                causal=False,
            )

            context_layer = rearrange(output,
                                      "(b s) h d -> s b (h d)",
                                      b=batch_size).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
                                 for x in [q_i, k_i, v_i])
                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=hidden_size_per_attention_head,
    dtype=get_default_dtype(),
)

hidden_size_per_attention_head instance-attribute

hidden_size_per_attention_head = divide(
    projection_size, num_heads
)

num_attention_heads_per_partition instance-attribute

num_attention_heads_per_partition = divide(
    num_heads, tp_size
)

proj instance-attribute

proj = RowParallelLinear(
    input_size=projection_size,
    output_size=embed_dim,
    quant_config=quant_config,
    prefix=f"{prefix}.proj",
    bias=False,
    disable_tp=use_data_parallel,
)

qkv instance-attribute

qkv = QKVParallelLinear(
    hidden_size=embed_dim,
    head_size=hidden_size_per_attention_head,
    total_num_heads=num_heads,
    total_num_kv_heads=num_heads,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj"
    if quant_config
    else f"{prefix}.qkv",
    disable_tp=use_data_parallel,
)

tp_rank instance-attribute

tp_rank = (
    0
    if use_data_parallel
    else get_tensor_model_parallel_rank()
)

tp_size instance-attribute

tp_size = (
    1
    if use_data_parallel
    else get_tensor_model_parallel_world_size()
)

use_upstream_fa instance-attribute

use_upstream_fa = False

__init__

__init__(
    embed_dim: int,
    num_heads: int,
    projection_size: int,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    projection_size: int,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    super().__init__()
    # Per attention head and per partition values.
    self.tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
    self.tp_rank = (0 if use_data_parallel else
                    parallel_state.get_tensor_model_parallel_rank())
    self.hidden_size_per_attention_head = dist_utils.divide(
        projection_size, num_heads)
    self.num_attention_heads_per_partition = dist_utils.divide(
        num_heads, self.tp_size)

    self.qkv = QKVParallelLinear(
        hidden_size=embed_dim,
        head_size=self.hidden_size_per_attention_head,
        total_num_heads=num_heads,
        total_num_kv_heads=num_heads,
        bias=False,
        quant_config=quant_config,
        # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
        prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
        disable_tp=use_data_parallel,
    )
    self.proj = RowParallelLinear(
        input_size=projection_size,
        output_size=embed_dim,
        quant_config=quant_config,
        prefix=f"{prefix}.proj",
        bias=False,
        disable_tp=use_data_parallel,
    )

    # Detect attention implementation.
    self.attn_backend = get_vit_attn_backend(
        head_size=self.hidden_size_per_attention_head,
        dtype=torch.get_default_dtype())
    self.use_upstream_fa = False
    if self.attn_backend != _Backend.FLASH_ATTN and \
        check_upstream_fa_availability(torch.get_default_dtype()):
        self.attn_backend = _Backend.FLASH_ATTN
        self.use_upstream_fa = True

    if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
    }:
        raise RuntimeError(f"GLM-4V does not currently support "
                           f"the {self.attn_backend} backend.")

forward

forward(
    x: Tensor,
    cu_seqlens: Tensor,
    rotary_pos_emb: Tensor,
    max_seqlen: Optional[int] = None,
    seqlens: Optional[list[int]] = None,
) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    # [s, b, c] --> [s, b, head * 3 * head_dim]
    x, _ = self.qkv(x)

    # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
    q, k, v = self.split_qkv(x)
    batch_size = q.shape[1]

    q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
               for x in (q, k, v))
    if rotary_pos_emb is not None:
        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
        q, k = torch.chunk(qk_rotated, 2, dim=0)

    if self.attn_backend == _Backend.FLASH_ATTN:
        # from vllm_flash_attn.flash_attn_interface import (
        #   flash_attn_varlen_func)
        if self.use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func

        q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

        output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            dropout_p=0,
            causal=False,
        )

        context_layer = rearrange(output,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()
    elif self.attn_backend == _Backend.TORCH_SDPA:
        # Execute attention entry by entry for speed & less VRAM.
        outputs = []
        for i in range(1, len(cu_seqlens)):
            start_idx = cu_seqlens[i - 1]
            end_idx = cu_seqlens[i]
            q_i = q[:, start_idx:end_idx]
            k_i = k[:, start_idx:end_idx]
            v_i = v[:, start_idx:end_idx]
            q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
                             for x in [q_i, k_i, v_i])
            output_i = F.scaled_dot_product_attention(q_i,
                                                      k_i,
                                                      v_i,
                                                      dropout_p=0.0)
            output_i = rearrange(output_i, "b h s d -> b s h d ")
            outputs.append(output_i)
        context_layer = torch.cat(outputs, dim=1)
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()
    elif self.attn_backend == _Backend.XFORMERS:
        from xformers import ops as xops
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask

        attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                   kv_seqlen=None,
                                                   device=q.device)

        context_layer = xops.memory_efficient_attention_forward(
            q, k, v, attn_bias=attn_bias, p=0, scale=None)
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

    output, _ = self.proj(context_layer)
    return output

split_qkv

split_qkv(qkv: Tensor) -> tuple[Tensor, ...]
Source code in vllm/model_executor/models/glm4_1v.py
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # [s, b, 3 * head * head_dim]
    seq_len, bs, _ = qkv.shape

    # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
    q, k, v = qkv.chunk(3, dim=2)

    # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
    new_shape = (
        seq_len,
        bs,
        self.num_attention_heads_per_partition,
        self.hidden_size_per_attention_head,
    )
    q, k, v = (x.view(*new_shape) for x in (q, k, v))
    return q, k, v
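
A shape walkthrough with small illustrative numbers (2 attention heads per tensor-parallel partition, head_dim=4):

import torch

seq_len, bs, heads, head_dim = 5, 1, 2, 4
qkv = torch.randn(seq_len, bs, 3 * heads * head_dim)  # [s, b, 3 * head * head_dim]
q, k, v = qkv.chunk(3, dim=2)                          # 3 x [s, b, head * head_dim]
q = q.view(seq_len, bs, heads, head_dim)               # [s, b, head, head_dim]
print(q.shape)                                         # torch.Size([5, 1, 2, 4])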

Glm4vVisionBlock

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionBlock(nn.Module):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Glm4vVisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            mlp_hidden_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        x_attn = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x_fused_norm, residual = self.norm2(x, residual=x_attn)
        x = residual + self.mlp(x_fused_norm)

        return x

attn instance-attribute

attn = Glm4vVisionAttention(
    embed_dim=dim,
    num_heads=num_heads,
    projection_size=dim,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
    use_data_parallel=use_data_parallel,
)

mlp instance-attribute

mlp = Glm4vVisionMLP(
    dim,
    mlp_hidden_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
    use_data_parallel=use_data_parallel,
)

norm1 instance-attribute

norm1 = norm_layer(dim)

norm2 instance-attribute

norm2 = norm_layer(dim)

__init__

__init__(
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    norm_layer: Optional[Callable[[int], Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(
    self,
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    norm_layer: Optional[Callable[[int], nn.Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    super().__init__()
    if norm_layer is None:
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    self.norm1 = norm_layer(dim)
    self.norm2 = norm_layer(dim)
    self.attn = Glm4vVisionAttention(
        embed_dim=dim,
        num_heads=num_heads,
        projection_size=dim,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        use_data_parallel=use_data_parallel,
    )
    self.mlp = Glm4vVisionMLP(
        dim,
        mlp_hidden_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
        use_data_parallel=use_data_parallel,
    )

forward

forward(
    x: Tensor,
    cu_seqlens: Tensor,
    rotary_pos_emb: Tensor,
    max_seqlen: Optional[int] = None,
    seqlens: Optional[list[int]] = None,
) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    x_attn = self.attn(
        self.norm1(x),
        cu_seqlens=cu_seqlens,
        rotary_pos_emb=rotary_pos_emb,
        max_seqlen=max_seqlen,
        seqlens=seqlens,
    )
    x_fused_norm, residual = self.norm2(x, residual=x_attn)
    x = residual + self.mlp(x_fused_norm)

    return x

Glm4vVisionEmbeddings

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionEmbeddings(nn.Module):

    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, embeddings, lengths, image_shapes, h_coords,
                w_coords) -> torch.Tensor:
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        # Move coordinates to correct device
        h_coords, w_coords = h_coords.to(device), w_coords.to(device)

        # Handle empty sequence case
        if total_seq == 0:
            adapted_pos_embed = torch.empty(0,
                                            hidden_size,
                                            device=device,
                                            dtype=pos_embed_weight.dtype)
        else:
            # Convert inputs to tensors if needed
            if isinstance(lengths, list):
                lengths = torch.tensor(lengths,
                                       device=device,
                                       dtype=torch.long)
            if not isinstance(image_shapes, torch.Tensor):
                image_shapes = torch.tensor(image_shapes,
                                            device=device,
                                            dtype=torch.long)

            # Prepare 2D position embedding
            orig_size_sq = pos_embed_weight.shape[0]
            orig_size = int(orig_size_sq**0.5)
            pos_embed_2d = (pos_embed_weight.view(
                orig_size, orig_size,
                hidden_size).permute(2, 0,
                                     1).unsqueeze(0).to(device=device,
                                                        dtype=torch.float32))

            # Calculate target dimensions for each patch
            # Add bounds checking for data parallel mode
            if len(lengths) > image_shapes.shape[0]:
                # In data parallel mode, some GPUs might not have all
                # image shapes
                # Use available image shapes, cycling if necessary
                target_h_list = []
                target_w_list = []
                for i in range(len(lengths)):
                    # Cycle through available shapes
                    shape_idx = i % image_shapes.shape[0]
                    target_h_list.append(image_shapes[shape_idx,
                                                      1].repeat(lengths[i]))
                    target_w_list.append(image_shapes[shape_idx,
                                                      2].repeat(lengths[i]))
                target_h = torch.cat(target_h_list).to(device=device,
                                                       dtype=torch.float32)
                target_w = torch.cat(target_w_list).to(device=device,
                                                       dtype=torch.float32)
            else:
                target_h = torch.cat([
                    image_shapes[i, 1].repeat(lengths[i])
                    for i in range(len(lengths))
                ]).to(device=device, dtype=torch.float32)
                target_w = torch.cat([
                    image_shapes[i, 2].repeat(lengths[i])
                    for i in range(len(lengths))
                ]).to(device=device, dtype=torch.float32)

            # Normalize coordinates to [-1, 1] range for grid_sample
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # Create sampling grid
            grid = (torch.stack((norm_w, norm_h),
                                dim=-1).unsqueeze(0).unsqueeze(2))

            # Perform bicubic interpolation
            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d,
                grid,
                mode="bicubic",
                align_corners=False,
                padding_mode="border",
            )

            # Reshape and convert back to original dtype
            adapted_pos_embed_fp32 = (
                interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0))
            adapted_pos_embed = adapted_pos_embed_fp32.to(
                pos_embed_weight.dtype).to(embeddings.device)

        # Add adapted position encoding to embeddings
        embeddings = embeddings + adapted_pos_embed
        return embeddings

config instance-attribute

config = config

embed_dim instance-attribute

embed_dim = hidden_size

image_size instance-attribute

image_size = image_size

num_patches instance-attribute

num_patches = (image_size // patch_size) ** 2

num_positions instance-attribute

num_positions = num_patches

patch_size instance-attribute

patch_size = patch_size

position_embedding instance-attribute

position_embedding = Embedding(num_positions, embed_dim)

__init__

__init__(config: Glm4vVisionConfig)
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(self, config: Glm4vVisionConfig):
    super().__init__()
    self.config = config
    self.embed_dim = config.hidden_size
    self.image_size = config.image_size
    self.patch_size = config.patch_size

    self.num_patches = (self.image_size // self.patch_size)**2
    self.num_positions = self.num_patches
    self.position_embedding = nn.Embedding(self.num_positions,
                                           self.embed_dim)
    self.register_buffer(
        "position_ids",
        torch.arange(self.num_positions).expand((1, -1)),
        persistent=False,
    )

forward

forward(
    embeddings, lengths, image_shapes, h_coords, w_coords
) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def forward(self, embeddings, lengths, image_shapes, h_coords,
            w_coords) -> torch.Tensor:
    pos_embed_weight = self.position_embedding.weight
    hidden_size = pos_embed_weight.shape[1]
    total_seq = h_coords.shape[0]
    device = pos_embed_weight.device

    # Move coordinates to correct device
    h_coords, w_coords = h_coords.to(device), w_coords.to(device)

    # Handle empty sequence case
    if total_seq == 0:
        adapted_pos_embed = torch.empty(0,
                                        hidden_size,
                                        device=device,
                                        dtype=pos_embed_weight.dtype)
    else:
        # Convert inputs to tensors if needed
        if isinstance(lengths, list):
            lengths = torch.tensor(lengths,
                                   device=device,
                                   dtype=torch.long)
        if not isinstance(image_shapes, torch.Tensor):
            image_shapes = torch.tensor(image_shapes,
                                        device=device,
                                        dtype=torch.long)

        # Prepare 2D position embedding
        orig_size_sq = pos_embed_weight.shape[0]
        orig_size = int(orig_size_sq**0.5)
        pos_embed_2d = (pos_embed_weight.view(
            orig_size, orig_size,
            hidden_size).permute(2, 0,
                                 1).unsqueeze(0).to(device=device,
                                                    dtype=torch.float32))

        # Calculate target dimensions for each patch
        # Add bounds checking for data parallel mode
        if len(lengths) > image_shapes.shape[0]:
            # In data parallel mode, some GPUs might not have all
            # image shapes
            # Use available image shapes, cycling if necessary
            target_h_list = []
            target_w_list = []
            for i in range(len(lengths)):
                # Cycle through available shapes
                shape_idx = i % image_shapes.shape[0]
                target_h_list.append(image_shapes[shape_idx,
                                                  1].repeat(lengths[i]))
                target_w_list.append(image_shapes[shape_idx,
                                                  2].repeat(lengths[i]))
            target_h = torch.cat(target_h_list).to(device=device,
                                                   dtype=torch.float32)
            target_w = torch.cat(target_w_list).to(device=device,
                                                   dtype=torch.float32)
        else:
            target_h = torch.cat([
                image_shapes[i, 1].repeat(lengths[i])
                for i in range(len(lengths))
            ]).to(device=device, dtype=torch.float32)
            target_w = torch.cat([
                image_shapes[i, 2].repeat(lengths[i])
                for i in range(len(lengths))
            ]).to(device=device, dtype=torch.float32)

        # Normalize coordinates to [-1, 1] range for grid_sample
        h_coords = h_coords.to(device=device, dtype=torch.float32)
        w_coords = w_coords.to(device=device, dtype=torch.float32)
        norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
        norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

        # Create sampling grid
        grid = (torch.stack((norm_w, norm_h),
                            dim=-1).unsqueeze(0).unsqueeze(2))

        # Perform bicubic interpolation
        interpolated_embed_fp32 = F.grid_sample(
            pos_embed_2d,
            grid,
            mode="bicubic",
            align_corners=False,
            padding_mode="border",
        )

        # Reshape and convert back to original dtype
        adapted_pos_embed_fp32 = (
            interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0))
        adapted_pos_embed = adapted_pos_embed_fp32.to(
            pos_embed_weight.dtype).to(embeddings.device)

    # Add adapted position encoding to embeddings
    embeddings = embeddings + adapted_pos_embed
    return embeddings
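
The core of the adaptation step is a single bicubic grid_sample over the learned (orig x orig) position grid. A minimal standalone sketch with toy sizes (not the model's real dimensions):

import torch
import torch.nn.functional as F

orig, hidden = 4, 6
pos_embed_2d = torch.randn(1, hidden, orig, orig)        # [1, C, H, W]
norm_h = torch.tensor([-1.0, 0.0, 1.0])                  # 3 patch rows in [-1, 1]
norm_w = torch.tensor([-1.0, 0.0, 1.0])                  # 3 patch cols in [-1, 1]
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
out = F.grid_sample(pos_embed_2d, grid, mode="bicubic",
                    align_corners=False, padding_mode="border")
print(out.squeeze(0).squeeze(-1).permute(1, 0).shape)    # torch.Size([3, 6])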

Glm4vVisionMLP

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionMLP(nn.Module):

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    hidden_features,
    in_features,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
    disable_tp=use_data_parallel,
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    input_size=in_features,
    output_sizes=[hidden_features] * 2,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
    disable_tp=use_data_parallel,
)

__init__

__init__(
    in_features: int,
    hidden_features: int,
    bias: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(
    self,
    in_features: int,
    hidden_features: int,
    bias: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        input_size=in_features,
        output_sizes=[hidden_features] * 2,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
        disable_tp=use_data_parallel,
    )
    self.down_proj = RowParallelLinear(
        hidden_features,
        in_features,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
        disable_tp=use_data_parallel,
    )
    self.act_fn = SiluAndMul()

forward

forward(x: Tensor)
Source code in vllm/model_executor/models/glm4_1v.py
def forward(self, x: torch.Tensor):
    x, _ = self.gate_up_proj(x)
    x = self.act_fn(x)
    x, _ = self.down_proj(x)
    return x
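
Functionally, gate_up_proj followed by SiluAndMul is a gated-SiLU MLP. A plain-torch sketch of the same computation, with vLLM's fused and parallelized layers replaced by ordinary matmuls:

import torch
import torch.nn.functional as F

def gated_silu_mlp(x, w_gate, w_up, w_down):
    gate = F.silu(x @ w_gate.T)     # SiLU on the gate branch
    up = x @ w_up.T                 # linear "up" branch
    return (gate * up) @ w_down.T   # elementwise gate, then down-projection

d, h = 8, 16
y = gated_silu_mlp(torch.randn(2, d), torch.randn(h, d),
                   torch.randn(h, d), torch.randn(d, h))
print(y.shape)                      # torch.Size([2, 8])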

Glm4vVisionPatchEmbed

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
                   self.patch_size)
        x = self.proj(x).view(L, self.hidden_size)
        return x

hidden_size instance-attribute

hidden_size = hidden_size

patch_size instance-attribute

patch_size = patch_size

proj instance-attribute

proj = Conv3d(
    in_channels,
    hidden_size,
    kernel_size=kernel_size,
    stride=kernel_size,
    bias=True,
)

temporal_patch_size instance-attribute

temporal_patch_size = temporal_patch_size

__init__

__init__(
    patch_size: int = 14,
    temporal_patch_size: int = 1,
    in_channels: int = 3,
    hidden_size: int = 1536,
) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(
    self,
    patch_size: int = 14,
    temporal_patch_size: int = 1,
    in_channels: int = 3,
    hidden_size: int = 1536,
) -> None:
    super().__init__()
    self.patch_size = patch_size
    self.temporal_patch_size = temporal_patch_size
    self.hidden_size = hidden_size

    kernel_size = (temporal_patch_size, patch_size, patch_size)
    self.proj = nn.Conv3d(
        in_channels,
        hidden_size,
        kernel_size=kernel_size,
        stride=kernel_size,
        bias=True,
    )

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    L, C = x.shape
    x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
               self.patch_size)
    x = self.proj(x).view(L, self.hidden_size)
    return x
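
Each input row is one flattened patch of size in_channels * temporal_patch_size * patch_size**2; the Conv3d with kernel equal to stride collapses it to a single hidden_size vector. A shape check using the constructor defaults:

import torch
import torch.nn as nn

patch_size, temporal_patch_size, in_channels, hidden_size = 14, 1, 3, 1536
L = 4
x = torch.randn(L, in_channels * temporal_patch_size * patch_size * patch_size)
proj = nn.Conv3d(in_channels, hidden_size,
                 kernel_size=(temporal_patch_size, patch_size, patch_size),
                 stride=(temporal_patch_size, patch_size, patch_size))
x = x.view(L, -1, temporal_patch_size, patch_size, patch_size)  # [L, C, T, P, P]
print(proj(x).view(L, hidden_size).shape)                       # torch.Size([4, 1536])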

Glm4vVisionRotaryEmbedding

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
            self.inv_freq = 1.0 / (self.theta**(torch.arange(
                0,
                self.dim,
                2,
                dtype=torch.float,
                device=self.inv_freq.device,
            ) / self.dim))
            seq = torch.arange(seqlen,
                               device=self.inv_freq.device,
                               dtype=self.inv_freq.dtype)
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]

_freqs_cached instance-attribute

_freqs_cached = None

_seq_len_cached instance-attribute

_seq_len_cached = 0

dim instance-attribute

dim = dim

theta instance-attribute

theta = theta

__init__

__init__(dim: int, theta: float = 10000.0) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(self, dim: int, theta: float = 10000.0) -> None:
    super().__init__()
    self.dim = dim
    self.theta = theta
    inv_freq = 1.0 / (theta
                      **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self._seq_len_cached = 0
    self._freqs_cached = None

forward

forward(seqlen: int) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def forward(self, seqlen: int) -> torch.Tensor:
    self.update_freqs_cache(seqlen)
    return self._freqs_cached[:seqlen]

update_freqs_cache

update_freqs_cache(seqlen: int) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def update_freqs_cache(self, seqlen: int) -> None:
    if seqlen > self._seq_len_cached:
        seqlen *= 2
        self._seq_len_cached = seqlen
        self.inv_freq = 1.0 / (self.theta**(torch.arange(
            0,
            self.dim,
            2,
            dtype=torch.float,
            device=self.inv_freq.device,
        ) / self.dim))
        seq = torch.arange(seqlen,
                           device=self.inv_freq.device,
                           dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        self._freqs_cached = freqs
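
A usage sketch, assuming an environment where vLLM is installed so the class can be imported: the cache is built at twice the requested length, so growing requests rarely force a rebuild.

from vllm.model_executor.models.glm4_1v import Glm4vVisionRotaryEmbedding

rope = Glm4vVisionRotaryEmbedding(dim=8)  # dim is the per-head rotary dimension
freqs = rope(16)                          # caches freqs for up to 32 positions
print(freqs.shape)                        # torch.Size([16, 4])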

Glm4vVisionTransformer

Bases: Module

Source code in vllm/model_executor/models/glm4_1v.py
class Glm4vVisionTransformer(nn.Module):

    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(RMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([
            Glm4vVisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.out_hidden_size,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=self.use_data_parallel,
            ) for layer_idx in range(depth)
        ])
        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.merger",
            use_data_parallel=self.use_data_parallel,
        )
        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = RMSNorm(vision_config.hidden_size,
                                           eps=vision_config.rms_norm_eps)
        self.downsample = nn.Conv2d(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                      eps=vision_config.rms_norm_eps)

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self,
        grid_thw: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten())
            wpos_ids = (wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten())
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> tuple[Optional[int], Optional[list[int]]]:
        max_seqlen, seqlens = None, None
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        if self.attn_backend == _Backend.FLASH_ATTN:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        # Convert grid_thw to tensor (always expecting list format now)
        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)

        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        # compute position embedding
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0],
                            image_type_ids[:, 1])

        # transformer blocks
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.post_layernorm(x)

        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size,
                   x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
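
In forward(), each item contributes t * h * w patch embeddings, and the downsample step (a conv whose kernel and stride both equal spatial_merge_size) followed by the merger collapses every spatial_merge_size x spatial_merge_size window into one visual token of width out_hidden_size. A minimal arithmetic sketch with hypothetical sizes, not taken from any real checkpoint:

# Hypothetical sizes for illustration; real values come from the checkpoint's vision config.
spatial_merge_size = 2
out_hidden_size = 4096
grid_thw = [[1, 8, 8], [2, 4, 4]]  # one image and one two-frame video

num_patches = sum(t * h * w for t, h, w in grid_thw)        # 64 + 32 = 96
num_visual_tokens = num_patches // spatial_merge_size ** 2  # 96 // 4 = 24
print(num_patches, num_visual_tokens)  # forward() would return a (24, out_hidden_size) tensor here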

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_dim, dtype=get_default_dtype()
)

blocks instance-attribute

blocks = ModuleList(
    [
        Glm4vVisionBlock(
            dim=hidden_size,
            num_heads=num_heads,
            mlp_hidden_dim=out_hidden_size,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.blocks.{layer_idx}",
            use_data_parallel=use_data_parallel,
        )
        for layer_idx in range(depth)
    ]
)

device property

device: device

downsample instance-attribute

downsample = Conv2d(
    in_channels=hidden_size,
    out_channels=out_hidden_size,
    kernel_size=spatial_merge_size,
    stride=spatial_merge_size,
)

dtype property

dtype: dtype

embeddings instance-attribute

embeddings = Glm4vVisionEmbeddings(vision_config)

hidden_size instance-attribute

hidden_size = hidden_size

merger instance-attribute

merger = Glm4vPatchMerger(
    d_model=out_hidden_size,
    context_dim=intermediate_size,
    quant_config=quant_config,
    bias=False,
    prefix=f"{prefix}.merger",
    use_data_parallel=use_data_parallel,
)

num_heads instance-attribute

num_heads = num_heads

out_hidden_size instance-attribute

out_hidden_size = out_hidden_size

patch_embed instance-attribute

patch_embed = Glm4vVisionPatchEmbed(
    patch_size=patch_size,
    temporal_patch_size=temporal_patch_size,
    in_channels=in_channels,
    hidden_size=hidden_size,
)

patch_size instance-attribute

patch_size = patch_size

post_conv_layernorm instance-attribute

post_conv_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

post_layernorm instance-attribute

post_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

rotary_pos_emb instance-attribute

rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)

spatial_merge_size instance-attribute

spatial_merge_size = spatial_merge_size

use_data_parallel instance-attribute

use_data_parallel = use_data_parallel

__init__

__init__(
    vision_config: Glm4vVisionConfig,
    norm_eps: float = 1e-06,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None
Source code in vllm/model_executor/models/glm4_1v.py
def __init__(
    self,
    vision_config: Glm4vVisionConfig,
    norm_eps: float = 1e-6,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    super().__init__()

    patch_size = vision_config.patch_size
    temporal_patch_size = vision_config.temporal_patch_size
    in_channels = vision_config.in_channels
    depth = vision_config.depth
    self.hidden_size = vision_config.hidden_size
    self.num_heads = vision_config.num_heads
    self.use_data_parallel = use_data_parallel

    self.patch_size = vision_config.patch_size
    self.spatial_merge_size = vision_config.spatial_merge_size
    self.out_hidden_size = vision_config.out_hidden_size

    self.patch_embed = Glm4vVisionPatchEmbed(
        patch_size=patch_size,
        temporal_patch_size=temporal_patch_size,
        in_channels=in_channels,
        hidden_size=self.hidden_size,
    )

    norm_layer = partial(RMSNorm, eps=norm_eps)
    head_dim = self.hidden_size // self.num_heads
    self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
    self.blocks = nn.ModuleList([
        Glm4vVisionBlock(
            dim=self.hidden_size,
            num_heads=self.num_heads,
            mlp_hidden_dim=vision_config.out_hidden_size,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.blocks.{layer_idx}",
            use_data_parallel=self.use_data_parallel,
        ) for layer_idx in range(depth)
    ])
    self.merger = Glm4vPatchMerger(
        d_model=vision_config.out_hidden_size,
        context_dim=vision_config.intermediate_size,
        quant_config=quant_config,
        bias=False,
        prefix=f"{prefix}.merger",
        use_data_parallel=self.use_data_parallel,
    )
    self.embeddings = Glm4vVisionEmbeddings(vision_config)

    self.post_conv_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)
    self.downsample = nn.Conv2d(
        in_channels=vision_config.hidden_size,
        out_channels=vision_config.out_hidden_size,
        kernel_size=vision_config.spatial_merge_size,
        stride=vision_config.spatial_merge_size,
    )
    self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                  eps=vision_config.rms_norm_eps)

    self.attn_backend = get_vit_attn_backend(
        head_size=head_dim, dtype=torch.get_default_dtype())
    if self.attn_backend != _Backend.FLASH_ATTN and \
        check_upstream_fa_availability(torch.get_default_dtype()):
        self.attn_backend = _Backend.FLASH_ATTN
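
The constructor reads only a handful of fields from vision_config. A minimal stand-in with hypothetical values sketches what it expects; the real Glm4vVisionConfig comes from the HuggingFace checkpoint, and the module is normally constructed inside vLLM where the parallel state is already initialized:

from types import SimpleNamespace

# Hypothetical values for illustration only; real ones are loaded from the checkpoint.
vision_config = SimpleNamespace(
    patch_size=14,
    temporal_patch_size=2,
    in_channels=3,
    depth=24,
    hidden_size=1536,
    num_heads=12,
    spatial_merge_size=2,
    out_hidden_size=4096,
    intermediate_size=13696,
    rms_norm_eps=1e-5,
)
# vision_tower = Glm4vVisionTransformer(vision_config)  # requires vLLM's distributed state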

compute_attn_mask_seqlen

compute_attn_mask_seqlen(
    cu_seqlens: Tensor,
) -> tuple[Optional[int], Optional[list[int]]]
Source code in vllm/model_executor/models/glm4_1v.py
def compute_attn_mask_seqlen(
    self,
    cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
    max_seqlen, seqlens = None, None
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    if self.attn_backend == _Backend.FLASH_ATTN:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    return max_seqlen, seqlens
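
For two items contributing 4 and 6 patches, cu_seqlens is [0, 4, 10]; the method recovers the per-item lengths, and only when flash attention is the backend also their maximum. The same arithmetic as a standalone sketch:

import torch

cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()         # [4, 6]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 6 (computed only for FLASH_ATTN)
print(seqlens, max_seqlen)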

forward

forward(x: Tensor, grid_thw: list[list[int]]) -> Tensor
Source code in vllm/model_executor/models/glm4_1v.py
def forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    # Convert grid_thw to tensor (always expecting list format now)
    grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)

    # patchify
    x = x.to(device=self.device, dtype=self.dtype)
    x = self.patch_embed(x)
    x = self.post_conv_layernorm(x)

    # compute position embedding
    rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
    # compute cu_seqlens
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                         grid_thw[:, 0]).cumsum(
                                             dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

    # pre-compute seqlens for attn mask to reduce cuMemcpy operations
    max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
    x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0],
                        image_type_ids[:, 1])

    # transformer blocks
    x = x.unsqueeze(1)
    for blk in self.blocks:
        x = blk(
            x,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

    # adapter
    x = self.post_layernorm(x)

    x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size,
               x.shape[-1])
    x = x.permute(0, 3, 1, 2)
    x = self.downsample(x).view(-1, self.out_hidden_size)
    x = self.merger(x)

    return x
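
The cu_seqlens computation repeats each item's per-frame patch count (h * w) by its temporal length t, takes a cumulative sum, and prefixes a zero, so every frame becomes one attention segment. A standalone sketch with two hypothetical items, a 1x4x6 image and a 2x4x4 video:

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6], [2, 4, 4]], dtype=torch.long)
cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2],  # per-frame patch counts: [24, 16]
    grid_thw[:, 0],                   # repeated by t: [24, 16, 16]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
print(cu_seqlens.tolist())  # [0, 24, 40, 56]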

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/glm4_1v.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
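
The stacked-parameter mapping folds the checkpoint's separate q/k/v (and gate/up) tensors into the fused qkv and gate_up_proj parameters created by vLLM's parallel linear layers. A sketch of just the renaming, using a hypothetical fused_name helper and made-up weight names:

stacked_params_mapping = [
    ("attn.qkv.", "attn.q.", "q"),
    ("attn.qkv.", "attn.k.", "k"),
    ("attn.qkv.", "attn.v.", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def fused_name(name: str):
    # Mirrors the loop above: the first matching entry wins; unmatched names pass through.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(fused_name("blocks.0.attn.q.weight"))         # ('blocks.0.attn.qkv.weight', 'q')
print(fused_name("blocks.0.mlp.gate_proj.weight"))  # ('blocks.0.mlp.gate_up_proj.weight', 0)
print(fused_name("blocks.0.norm1.weight"))          # ('blocks.0.norm1.weight', None)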

rot_pos_emb

rot_pos_emb(grid_thw: Tensor) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/glm4_1v.py
def rot_pos_emb(
    self,
    grid_thw: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    pos_ids = []
    for t, h, w in grid_thw:
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        hpos_ids = (hpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        ).permute(0, 2, 1, 3).flatten())
        wpos_ids = (wpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        ).permute(0, 2, 1, 3).flatten())
        pos_ids.append(
            torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    pos_ids = torch.cat(pos_ids, dim=0)
    max_grid_size = grid_thw[:, 1:].max()
    rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
    rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
    return rotary_pos_emb, pos_ids
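
The reshape/permute pair reorders the (h, w) indices so that the patches of each spatial_merge_size x spatial_merge_size window end up contiguous before the rotary table is indexed. A standalone sketch for a single 1x4x4 grid with spatial_merge_size=2:

import torch

spatial_merge_size = 2
t, h, w = 1, 4, 4

def reorder(ids: torch.Tensor) -> torch.Tensor:
    # Group indices window-by-window, matching the reshape/permute above.
    return (ids.reshape(h // spatial_merge_size, spatial_merge_size,
                        w // spatial_merge_size, spatial_merge_size)
               .permute(0, 2, 1, 3).flatten())

hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
pos_ids = torch.stack([reorder(hpos), reorder(wpos)], dim=-1).repeat(t, 1)
print(pos_ids[:4].tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]] -- the top-left 2x2 window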

all_gather_interleave

all_gather_interleave(
    local_tensor, hidden_size: int, tp_size: int
)

All-gather the input tensor across the model parallel group in an interleaved manner.

Source code in vllm/model_executor/models/glm4_1v.py
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group."""
    import torch.distributed as dist

    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )

    gathered_tensors_split = [
        torch.split(tensor, hidden_size // tp_size, -1)
        for tensor in gathered_tensors
    ]
    ordered_tensors = [
        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
    ]
    result_tensor = torch.cat(ordered_tensors, dim=-1)
    return result_tensor
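
After gathering, each rank's tensor is split into hidden_size // tp_size chunks along the last dimension, and the chunks are re-ordered so that every rank's first chunk precedes every rank's second chunk, and so on. A CPU-only sketch of that reordering with tp_size=2 and hidden_size=8, standing in for the dist.all_gather result with pre-made tensors:

import torch

tp_size, hidden_size = 2, 8
# Pretend these two tensors came from dist.all_gather: one per rank.
gathered_tensors = [torch.arange(8), torch.arange(8, 16)]

gathered_tensors_split = [
    torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
]
ordered_tensors = [
    tensor for pair in zip(*gathered_tensors_split) for tensor in pair
]
print(torch.cat(ordered_tensors, dim=-1).tolist())
# [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]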