Skip to content

vllm.lora.layers.vocal_parallel_embedding

VocabParallelEmbeddingWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/vocal_parallel_embedding.py
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding`` layer.

    Holds per-slot LoRA A/B buffers plus per-slot embeddings for extra
    vocabulary tokens. When this partition owns added-vocabulary rows of
    the wrapped layer's weight, those rows are kept in sync with the
    extra-vocab embeddings so the base lookup can serve added tokens.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Both are assigned in create_lora_weights(); they are None when
        # this partition owns no added-vocabulary rows.
        # (start, end) range into the flattened per-slot extra-vocab
        # embeddings that maps onto this partition's added rows.
        self.embeddings_slice: Optional[tuple[int, int]]
        # View of the wrapped layer's weight rows reserved for added vocab.
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate the per-slot LoRA buffers for ``max_loras`` adapters.

        Args:
            max_loras: Number of adapter slots to allocate.
            lora_config: Supplies ``lora_extra_vocab_size``,
                ``max_lora_rank`` and ``lora_dtype``.
            model_config: Not used by this implementation.

        Side effect: if this partition owns added-vocabulary rows, every
        row of the wrapped weight after the original-vocab rows is zeroed.
        """
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights. This is a view (no copy),
            # so later writes go straight into the base layer's weight.
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            # Added-vocab range of this shard, re-based to start at 0.
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        # (max_loras, extra_vocab, embedding_dim) extra-token embeddings.
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        # (max_loras, org_vocab + extra_vocab, rank) LoRA A per slot.
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # (max_loras, 1, embedding_dim, rank) LoRA B per slot.
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # 2-D view of lora_a_stacked so F.embedding can index every slot.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero slot ``index`` of the LoRA A/B and extra-vocab buffers.

        The wrapped layer's added-vocab rows are not touched here.
        set_lora() re-syncs them after loading a slot.
        """
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``, replacing its contents.

        Args:
            index: Adapter slot to overwrite; it is cleared first.
            lora_a: LoRA A laid out as (rank, vocab_rows); its transpose is
                copied into the slot.
            lora_b: LoRA B laid out as (embedding_dim, rank).
            embeddings_tensor: Embeddings for the adapter's extra vocabulary
                tokens, or None if it adds no tokens.
            bias: Not used by this layer.
        """
        self.reset_lora(index)
        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
        # so we need transpose here
        self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
            lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                lora_b, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ].copy_(embeddings_tensor, non_blocking=True)
        # Re-sync even when this adapter brings no extra embeddings.
        # reset_lora() only cleared embeddings_tensors, so otherwise the base
        # layer would keep serving the previous occupant's added-token rows.
        if self.embeddings_slice is not None:
            self._sync_embeddings_to_base_layer()

    def _sync_embeddings_to_base_layer(self) -> None:
        """Copy this partition's share of ``embeddings_tensors`` into the
        added-vocabulary rows of the wrapped layer's weight."""
        assert self.embeddings_slice is not None
        assert self.embeddings_weights is not None
        # TODO(yard1): Optimize this copy, we don't need to copy
        # everything, just the modified part
        start, end = self.embeddings_slice
        embeddings = self.embeddings_tensors.view(
            self.embeddings_tensors.shape[0] *
            self.embeddings_tensors.shape[1],
            self.embeddings_tensors.shape[2],
        )[start:end]
        self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` and add the active LoRA contributions.

        Ids at or above ``org_vocab_size`` are shifted by the offsets in
        row 0 of ``punica_wrapper._embeddings_indices`` before the base
        lookup, presumably landing them in the added-vocab rows. Row 1
        offsets ``x`` into the flattened ``lora_a_stacked_2d`` table.
        ``punica_wrapper`` is not set in this class; presumably the LoRA
        framework attaches it (confirm).

        Returns:
            A tensor with the same shape as the base layer's output.
        """
        # 1 where the id is outside the original vocabulary, else 0.
        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                        1, 0)

        # NB: Don't use torch.narrow here. torch.narrow triggers some
        # Dynamic Shape specialization in torch.compile
        num_tokens = x.shape[0]
        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]

        full_lora_a_embeddings = F.embedding(
            x + indices_1,
            self.lora_a_stacked_2d,
        )
        full_output = self.base_layer.forward(x +
                                              (indices_0 * added_tokens_mask))

        # Flatten any 3-D results to 2-D for the punica kernel; the original
        # shape is restored on return.
        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_embedding(
                full_output,
                full_lora_a_embeddings,
                self.lora_b_stacked,
                add_input=True)

        # Platforms that cannot update in place return the result instead.
        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only for an exact ``VocabParallelEmbedding``.

        The check uses ``type(...) is``, so subclasses are not matched.
        """
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        """The wrapped layer's weight."""
        return self.base_layer.weight

base_layer instance-attribute

base_layer = base_layer

embeddings_slice instance-attribute

embeddings_slice: Optional[tuple[int, int]]

embeddings_weights instance-attribute

embeddings_weights: Optional[Tensor]

weight property

weight

__init__

__init__(base_layer: VocabParallelEmbedding) -> None
Source code in vllm/lora/layers/vocal_parallel_embedding.py
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
    """Wrap ``base_layer``.

    The LoRA buffers are allocated later, by create_lora_weights().
    """
    super().__init__()
    self.base_layer = base_layer
    # Declared here without a value. create_lora_weights() assigns both,
    # using None when this partition owns no added-vocabulary rows.
    self.embeddings_weights: Optional[torch.Tensor]
    self.embeddings_slice: Optional[tuple[int, int]]

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: Optional[PretrainedConfig],
) -> bool
Source code in vllm/lora/layers/vocal_parallel_embedding.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: Optional[PretrainedConfig],
) -> bool:
    """Report whether ``source_layer`` should be wrapped by this class.

    Only an instance whose type is exactly ``VocabParallelEmbedding``
    qualifies. The comparison uses ``is``, so subclasses are not matched.
    The remaining arguments are not consulted.
    """
    layer_type = type(source_layer)
    return layer_type is VocabParallelEmbedding

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None
Source code in vllm/lora/layers/vocal_parallel_embedding.py
def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> None:
    """Allocate per-slot LoRA buffers for up to ``max_loras`` adapters.

    Args:
        max_loras: Number of adapter slots.
        lora_config: Provides ``lora_extra_vocab_size``, ``max_lora_rank``
            and ``lora_dtype``.
        model_config: Not used by this implementation.

    If this partition owns added-vocabulary rows, ``embeddings_weights``
    becomes a view of those rows of the base weight. Every base-weight row
    after the original-vocab rows is then zeroed. Otherwise
    ``embeddings_slice`` and ``embeddings_weights`` are set to None.
    """
    base = self.base_layer
    if base.num_added_embeddings_per_partition > 0:
        org_rows = base.num_org_embeddings_per_partition
        added_rows = base.num_added_embeddings_per_partition
        # A view, not a copy: writes land directly in the base weight.
        self.embeddings_weights = base.weight.data[org_rows:org_rows +
                                                   added_rows]
        shard = base.shard_indices
        # This shard's added-vocab range, re-based so it starts at 0.
        self.embeddings_slice = (
            shard.added_vocab_start_index - base.org_vocab_size,
            shard.added_vocab_end_index - base.org_vocab_size,
        )
        base.weight.data[org_rows:].fill_(0)
    else:
        self.embeddings_slice = None
        self.embeddings_weights = None

    device = base.weight.device
    extra_vocab = lora_config.lora_extra_vocab_size
    rank = lora_config.max_lora_rank

    self.embeddings_tensors = torch.zeros(
        (max_loras, extra_vocab, base.embedding_dim),
        dtype=base.weight.dtype,
        device=device,
    )
    self.lora_a_stacked = torch.zeros(
        (max_loras, base.org_vocab_size + extra_vocab, rank),
        dtype=lora_config.lora_dtype,
        device=device,
    )
    self.lora_b_stacked = torch.zeros(
        (max_loras, 1, base.embedding_dim, rank),
        dtype=lora_config.lora_dtype,
        device=device,
    )
    # Collapse (slot, vocab_row) into one axis for F.embedding lookups.
    num_slots, rows_per_slot, rank_dim = self.lora_a_stacked.shape
    self.lora_a_stacked_2d = self.lora_a_stacked.view(
        num_slots * rows_per_slot, rank_dim)

forward

forward(x: Tensor) -> Tensor
Source code in vllm/lora/layers/vocal_parallel_embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Embed token ids ``x`` through the base layer and add the LoRA term.

    Ids at or above ``org_vocab_size`` get the row-0 offset from
    ``punica_wrapper._embeddings_indices`` added before the base lookup.
    The row-1 offset indexes the flattened ``lora_a_stacked_2d``.
    ``punica_wrapper`` is not set in this block; presumably the LoRA
    framework attaches it (confirm).

    Returns:
        A tensor shaped like the base layer's output.
    """
    # 1 for ids outside the original vocabulary, 0 otherwise.
    is_added_token = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)

    # NB: plain slicing on purpose. torch.narrow triggers some
    # Dynamic Shape specialization in torch.compile.
    token_count = x.shape[0]
    lora_a_offsets = self.punica_wrapper._embeddings_indices[1][:token_count]
    added_offsets = self.punica_wrapper._embeddings_indices[0][:token_count]

    lora_a_rows = F.embedding(x + lora_a_offsets, self.lora_a_stacked_2d)
    base_out = self.base_layer.forward(x + (added_offsets * is_added_token))

    # Keep a handle on the original shape; the kernel works on 2-D inputs.
    shape_ref = base_out
    if base_out.ndim == 3:
        base_out = base_out.view(base_out.shape[0] * base_out.shape[1], -1)
    if lora_a_rows.ndim == 3:
        lora_a_rows = lora_a_rows.view(
            lora_a_rows.shape[0] * lora_a_rows.shape[1], -1)

    updated = self.punica_wrapper.add_lora_embedding(
        base_out,
        lora_a_rows,
        self.lora_b_stacked,
        add_input=True,
    )

    # In-place-capable platforms already updated base_out.
    result = base_out if current_platform.can_update_inplace() else updated
    return result.view_as(shape_ref)

reset_lora

reset_lora(index: int)
Source code in vllm/lora/layers/vocal_parallel_embedding.py
def reset_lora(self, index: int):
    """Clear adapter slot ``index`` in every per-slot buffer.

    The slot's LoRA A entries, LoRA B entries and extra-vocabulary
    embeddings are zeroed in place, in that order.
    """
    for buffer_name in ("lora_a_stacked", "lora_b_stacked",
                        "embeddings_tensors"):
        getattr(self, buffer_name)[index] = 0

set_lora

set_lora(
    index: int,
    lora_a: Tensor,
    lora_b: Tensor,
    embeddings_tensor: Optional[Tensor],
    bias: Optional[Tensor] = None,
)
Source code in vllm/lora/layers/vocal_parallel_embedding.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    embeddings_tensor: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
):
    """Load one adapter into slot ``index``, replacing its contents.

    Args:
        index: Adapter slot to overwrite; it is cleared via reset_lora first.
        lora_a: LoRA A laid out as (rank, vocab_rows); its transpose is
            copied into ``lora_a_stacked[index]``.
        lora_b: LoRA B laid out as (embedding_dim, rank), copied into
            ``lora_b_stacked[index, 0]``.
        embeddings_tensor: Embeddings for the adapter's extra vocabulary
            tokens, or None if it adds no tokens.
        bias: Not used by this layer.
    """
    self.reset_lora(index)
    # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
    # so we need transpose here
    self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
        lora_a.T, non_blocking=True)
    self.lora_b_stacked[index,
                        0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                            lora_b, non_blocking=True)
    if embeddings_tensor is not None:
        self.embeddings_tensors[
            index,
            :embeddings_tensor.shape[0],
            :embeddings_tensor.shape[1],
        ].copy_(embeddings_tensor, non_blocking=True)
    # Re-sync even when this adapter brings no extra embeddings.
    # reset_lora() only cleared embeddings_tensors, so otherwise the base
    # layer would keep serving the previous occupant's added-token rows.
    if self.embeddings_slice is not None:
        # TODO(yard1): Optimize this copy, we don't need to copy
        # everything, just the modified part
        start, end = self.embeddings_slice
        embeddings = self.embeddings_tensors.view(
            self.embeddings_tensors.shape[0] *
            self.embeddings_tensors.shape[1],
            self.embeddings_tensors.shape[2],
        )[start:end]
        assert self.embeddings_weights is not None
        self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)