Skip to content

vllm.lora.layers.base_linear

BaseLinearLayerWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/base_linear.py
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Common LoRA wrapper for ``LinearBase`` layers.

    Holds the wrapped layer plus stacked LoRA A/B (and optional bias)
    buffers with one slot per adapter. Subclasses are expected to set
    ``output_size`` and ``n_slices`` (only declared here) and to provide
    ``slice_lora_a`` / ``slice_lora_b`` / ``slice_bias``, which ``set_lora``
    calls when ``tp_size > 1``.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        # Stays None unless create_lora_weights() runs with bias_enabled.
        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
        # Declared only; values are assigned by subclasses or by
        # create_lora_weights().
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zero-filled LoRA buffers for up to ``max_loras`` adapters.

        Creates ``n_slices`` tensors each for A with shape
        ``(max_loras, 1, a_out, input_size)``, for B with shape
        ``(max_loras, 1, b_out, max_lora_rank)`` and, when
        ``lora_config.bias_enabled``, for bias with shape
        ``(max_loras, 1, b_out)``. With ``fully_sharded_loras``, A's rank is
        split across TP ranks for column-parallel layers, and B's output
        rows are split for row-parallel layers.

        Args:
            max_loras: Number of adapter slots to allocate.
            lora_config: LoRA settings (rank, dtype, bias, sharding).
            model_config: Not used by this implementation.

        Raises:
            NotImplementedError: If the base layer is not a replicated,
                column-parallel or row-parallel linear layer.
        """
        self.lora_config = lora_config
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            # The bias shares B's output dimension.
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero the A/B (and, if enabled, bias) buffers of slot ``index``."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # cast() only narrows the Optional for mypy; no runtime effect.
                lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                         self.lora_bias_stacked)
                lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copy one adapter's weights into slot ``index``.

        The slot is cleared first. When ``tp_size > 1``, the inputs are
        sliced for this rank. They are then copied into the leading rows
        and columns of the preallocated buffers.

        Args:
            index: Adapter slot to overwrite.
            lora_a: LoRA A weight.
            lora_b: LoRA B weight.
            embeddings_tensor: Not used by this implementation.
            lora_bias: Optional LoRA bias. Requires the bias buffers to have
                been allocated.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index, 0, :lora_a.shape[0],
                               :lora_a.shape[1]].copy_(lora_a,
                                                       non_blocking=True)
        self.lora_b_stacked[0][index, 0, :lora_b.shape[0],
                               :lora_b.shape[1]].copy_(lora_b,
                                                       non_blocking=True)
        if lora_bias is not None:
            # Bias buffers exist only if create_lora_weights() ran with
            # bias_enabled. Fail clearly instead of with a TypeError on None.
            assert self.lora_bias_stacked, (
                "lora_bias was provided but LoRA bias buffers are not "
                "allocated (is lora_config.bias_enabled set?)")
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer's matmul and add the LoRA contribution."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        # NOTE(review): the returned tensor stays flattened to 2-D in this
        # case; callers presumably handle that. Confirm.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_linear(
                output, x, self.lora_a_stacked, self.lora_b_stacked,
                self.lora_bias_stacked, 1.0, self.output_slices)
        # Without in-place support, the wrapper's return value holds the result.
        if not current_platform.can_update_inplace():
            output = lora_output

        return output

    @property
    def weight(self) -> torch.Tensor:
        """The base layer's weight tensor, whatever its quantized name."""
        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        # HQQ marlin
        elif hasattr(self.base_layer, "W_q"):
            return self.base_layer.W_q
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """The base layer's ``bias`` attribute, or None if it has none."""
        return getattr(self.base_layer, "bias", None)

base_layer instance-attribute

base_layer = base_layer

bias property

device instance-attribute

device = _get_lora_device(base_layer)

input_size instance-attribute

input_size = input_size

lora_bias_stacked instance-attribute

lora_bias_stacked: Optional[tuple[Tensor, ...]] = None

n_slices instance-attribute

n_slices: int

output_size instance-attribute

output_size: int

output_slices instance-attribute

output_slices: tuple[int, ...]

tp_rank instance-attribute

tp_rank = tp_rank

tp_size instance-attribute

tp_size = tp_size

weight property

weight: Tensor

__init__

__init__(base_layer: LinearBase)
Source code in vllm/lora/layers/base_linear.py
def __init__(self, base_layer: LinearBase):
    """Wrap ``base_layer`` and mirror its input size and TP placement.

    LoRA buffers are not allocated here; ``create_lora_weights`` does that.
    """
    super().__init__()
    self.base_layer = base_layer
    wrapped = self.base_layer
    self.input_size = wrapped.input_size
    # Take tensor-parallel size/rank from the wrapped layer so LoRA sharding
    # always agrees with it.
    self.tp_size = wrapped.tp_size
    self.tp_rank = wrapped.tp_rank
    self.device = _get_lora_device(wrapped)
    self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
    # Declared only; concrete values are assigned later.
    self.n_slices: int
    self.output_size: int
    self.output_slices: tuple[int, ...]

apply

apply(x: Tensor, bias: Optional[Tensor] = None) -> Tensor
Source code in vllm/lora/layers/base_linear.py
def apply(self,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the base layer's matmul and add the LoRA contribution.

    The punica wrapper adds the LoRA delta into the base output. On
    platforms that cannot update in place, the wrapper's return value is
    used instead.
    """
    base_out = self.base_layer.quant_method.apply(self.base_layer, x, bias)

    # Transformers-backend tensors carry a leading batch dim, e.g.
    # (1, seq_len, hidden_dim); punica wants (seq_len, hidden_dim).
    both_3d = x.ndim == 3 and base_out.ndim == 3
    if both_3d:
        base_out, x = base_out.flatten(0, 1), x.flatten(0, 1)

    lora_out: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear(
        base_out, x, self.lora_a_stacked, self.lora_b_stacked,
        self.lora_bias_stacked, 1.0, self.output_slices)

    return base_out if current_platform.can_update_inplace() else lora_out

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None
Source code in vllm/lora/layers/base_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None:
    """Allocate zero-filled stacked LoRA buffers for ``max_loras`` adapters.

    A buffers have shape ``(max_loras, 1, a_rows, input_size)`` and B buffers
    have shape ``(max_loras, 1, b_rows, max_lora_rank)``. Bias buffers, with
    shape ``(max_loras, 1, b_rows)``, are added when
    ``lora_config.bias_enabled``. ``model_config`` is not used.

    Raises:
        NotImplementedError: For base layers other than replicated,
            column-parallel or row-parallel linear layers.
    """
    self.lora_config = lora_config
    rank = lora_config.max_lora_rank
    layer = self.base_layer

    # Row counts of A and B depend on how the base layer is partitioned and
    # on whether LoRA weights are fully sharded across TP ranks.
    if isinstance(layer, ReplicatedLinear):
        a_rows, b_rows = rank, self.output_size
    elif isinstance(layer, ColumnParallelLinear):
        a_rows = (divide(rank, self.tp_size)
                  if lora_config.fully_sharded_loras else rank)
        b_rows = self.output_size
    elif isinstance(layer, RowParallelLinear):
        a_rows = rank
        b_rows = (divide(self.output_size, self.tp_size)
                  if lora_config.fully_sharded_loras else self.output_size)
    else:
        raise NotImplementedError

    dtype = lora_config.lora_dtype
    device = self.device
    slices = range(self.n_slices)

    self.lora_a_stacked = tuple(
        torch.zeros(max_loras, 1, a_rows, self.input_size,
                    dtype=dtype, device=device)
        for _ in slices)
    self.lora_b_stacked = tuple(
        torch.zeros(max_loras, 1, b_rows, rank, dtype=dtype, device=device)
        for _ in slices)
    if lora_config.bias_enabled:
        self.lora_bias_stacked = tuple(
            torch.zeros(max_loras, 1, b_rows, dtype=dtype, device=device)
            for _ in slices)
    self.output_slices = (self.lora_b_stacked[0].shape[2], )

reset_lora

reset_lora(index: int)
Source code in vllm/lora/layers/base_linear.py
def reset_lora(self, index: int):
    """Clear adapter slot ``index`` in every A/B (and bias) buffer."""
    for slot in range(self.n_slices):
        for stacked in (self.lora_a_stacked, self.lora_b_stacked):
            stacked[slot][index] = 0
        if not self.lora_config.bias_enabled:
            continue
        # cast() only narrows the Optional for mypy; no runtime effect.
        bias_bufs = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked)
        bias_bufs[slot][index] = 0

set_lora

set_lora(
    index: int,
    lora_a: Tensor,
    lora_b: Tensor,
    embeddings_tensor: Optional[Tensor],
    lora_bias: Optional[Tensor] = None,
)
Source code in vllm/lora/layers/base_linear.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    embeddings_tensor: Optional[torch.Tensor],
    lora_bias: Optional[torch.Tensor] = None,
):
    """Copy one adapter's weights into slot ``index``.

    The slot is cleared first. When ``tp_size > 1``, the inputs are sliced
    for this rank. They are then copied into the leading rows and columns of
    the preallocated buffers.

    Args:
        index: Adapter slot to overwrite.
        lora_a: LoRA A weight.
        lora_b: LoRA B weight.
        embeddings_tensor: Not used by this implementation.
        lora_bias: Optional LoRA bias. Requires bias buffers to have been
            allocated.
    """
    # Except for QKVParallelLinearWithLoRA and
    # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
    # store weights in a tuple of size 1. These two layers will
    # override this function.
    assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
            self.n_slices == 1)

    self.reset_lora(index)
    if self.tp_size > 1:
        lora_a = self.slice_lora_a(lora_a)
        lora_b = self.slice_lora_b(lora_b)
        if lora_bias is not None:
            lora_bias = self.slice_bias(lora_bias)

    self.lora_a_stacked[0][index, 0, :lora_a.shape[0],
                           :lora_a.shape[1]].copy_(lora_a, non_blocking=True)
    self.lora_b_stacked[0][index, 0, :lora_b.shape[0],
                           :lora_b.shape[1]].copy_(lora_b, non_blocking=True)
    if lora_bias is not None:
        # lora_bias_stacked may be None (no bias buffers allocated). Fail with
        # a clear message instead of "NoneType has no len()".
        assert self.lora_bias_stacked, (
            "lora_bias was provided but LoRA bias buffers are not "
            "allocated (is lora_config.bias_enabled set?)")
        self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
            lora_bias, non_blocking=True)