Skip to content

vllm.lora.layers.logits_processor

LogitsProcessorWithLoRA

Bases: BaseLayerWithLoRA

LoRA wrapper for LogitsProcessor, with extra logic to handle the application of the LoRA adapter and added LoRA vocabulary.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `base_layer` | `LogitsProcessor` | LogitsProcessor layer | *required* |
| `hidden_size` | `int` | hidden size of the model | *required* |
| `dtype` | `dtype` | data type of the model | *required* |
| `device` | `device` | device of the model | *required* |
| `sharded_to_full_mapping` | `Optional[list[int]]` | index mapping from sharded vocab to full vocab received from `base_layer.get_sharded_to_full_mapping()`. If `None`, no reindexing will be done. | *required* |
Source code in vllm/lora/layers/logits_processor.py
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Read-only attributes such as ``vocab_size``, ``scale`` and ``soft_cap``
    are proxied to ``base_layer``. ``forward`` runs the base layer's
    class-level ``forward`` with this wrapper bound as ``self``, presumably
    so that its ``self._get_logits`` call resolves to the LoRA-aware
    override defined here.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[list[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        # Tensor-parallel topology, queried once at construction time.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # Python list; a device tensor copy is built in create_lora_weights.
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        """Proxy for ``base_layer.logits_as_input``."""
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        """Proxy for ``base_layer.vocab_size``."""
        return self.base_layer.vocab_size

    @property
    def scale(self):
        """Proxy for ``base_layer.scale``."""
        return self.base_layer.scale

    @property
    def soft_cap(self):
        """Proxy for ``base_layer.soft_cap``."""
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        """Proxy for ``base_layer.use_all_gather``."""
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        """Proxy for ``base_layer.org_vocab_size``."""
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        """Proxy for ``base_layer.include_gpu_probs_tensor``."""
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        """Proxy for ``base_layer.should_modify_greedy_probs_inplace``."""
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the stacked LoRA buffers for ``max_loras`` adapter slots.

        Allocates, on ``self.device``:

        * ``lora_a_stacked``: zeros of shape
          ``(max_loras, 1, max_lora_rank, hidden_size)``.
        * ``lora_b_stacked``: zeros of shape
          ``(max_loras, 1, padded_vocab, max_lora_rank)``. Here
          ``padded_vocab`` is ``base_layer.vocab_size`` rounded up to a
          multiple of ``lora_config.lora_vocab_padding_size``.
        * ``embeddings_tensors``: ``-inf``-filled, shape
          ``(max_loras, lora_extra_vocab_size, hidden_size)``, in
          ``self.dtype``.

        Also converts ``sharded_to_full_mapping`` to a ``torch.long`` tensor
        on the device, or sets ``sharded_to_full_mapping_gpu`` to ``None``.

        Args:
            max_loras: number of adapter slots to allocate.
            lora_config: LoRA configuration supplying rank, dtype, vocab
                padding and extra-vocab sizes.
            model_config: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``base_layer.vocab_size`` exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # The check used to be written as ``32000 < vocab_size > 257024``,
        # a chained comparison meaning ``vocab_size > 32000 and
        # vocab_size > 257024``. Only the upper bound was ever enforced, and
        # its message ("32000 >= vocab_size <= 257024") was contradictory.
        # State the enforced bound directly so check and message agree.
        vocab_size = self.base_layer.vocab_size
        if vocab_size > 257024:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= 257024, got {vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # -inf marks "no embedding" for unused extra-vocab rows.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear adapter slot ``index``.

        Sets A/B to zero and the extra-vocab embeddings to ``-inf``.
        """
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load an adapter into slot ``index``.

        The slot is reset first. ``lora_a`` and ``lora_b`` are then copied
        into the top-left corner of that slot's buffers, and
        ``embeddings_tensor`` (if given) into the top-left corner of the
        slot's extra-vocab embeddings. ``bias`` is accepted but unused.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
                                lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                lora_b, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute next-token logits with LoRA deltas and extra vocab applied.

        Args:
            hidden_states: hidden states fed to the LM head. They are used as
                a 2-D matrix below (``hidden_states.T``,
                ``hidden_states.shape[0]``), presumably
                ``(num_tokens, hidden_size)`` -- TODO confirm.
            lm_head: output embedding whose ``quant_method.apply`` produces
                the base logits.
            embedding_bias: optional bias added in place to the base logits
                before they are gathered across TP ranks.

        Returns:
            ``None`` when ``base_layer._gather_logits`` returns ``None``.
            Otherwise, the logits with LoRA applied, truncated to the first
            ``base_layer.vocab_size`` columns.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        # NOTE(review): _gather_logits may return None, presumably on ranks
        # that do not receive the gathered result -- confirm.
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 5, 2, 6, 3, 7]
            # New column i takes old column mapping[i]:
            #   * base tokens 0-3 sit at old columns 0, 1, 4, 5;
            #   * added tokens 4-5 sit at old columns 2, 6;
            #   * padding sits at old columns 3, 7.
            # So when we reindex, we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Scratch buffer with one slot per adapter plus a trailing slot:
        # (max_loras + 1, lora_extra_vocab_size, num_tokens).
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        # Extra-vocab logits for every adapter slot at once.
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])

        neg_inf, pos_inf = current_platform.get_infinity_values(
            lora_logits.dtype)

        # Trailing slot is all neg_inf. Presumably tokens without an active
        # LoRA are indexed here so their extra-vocab logits are masked.
        lora_logits[-1] = neg_inf
        # -> (max_loras + 1, num_tokens, extra_vocab). After flattening the
        # first two dims, row ``slot * num_tokens + token`` holds one
        # token's extra-vocab logits under one adapter slot.
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded

        # NOTE(review): trimmed to the number of logits rows on TPU/XPU;
        # presumably the padded indices can be longer there -- confirm.
        if current_platform.is_tpu() or current_platform.is_xpu():
            indices_padded = indices_padded[:logits.size(0)]

        # Select one row per index. NaN (likely from the -inf-filled unused
        # embedding rows in the matmul) and +/-inf are replaced with the
        # values from current_platform.get_infinity_values.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                      posinf=pos_inf,
                                                      neginf=neg_inf))

        # Extra-vocab logits go right after the original vocabulary.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LoRA A/B contribution, delegated to punica_wrapper (scale 1.0).
        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_logits(
                logits, hidden_states, self.lora_a_stacked,
                self.lora_b_stacked, 1.0)

        # Where in-place update is unsupported, use the returned tensor.
        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        """Run the base layer's class ``forward`` with this wrapper as self."""
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Always False: the LogitsProcessor is wrapped via special handling."""
        # Special handling for the LogitsProcessor.
        return False

base_layer instance-attribute

base_layer = base_layer

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

hidden_size instance-attribute

hidden_size = hidden_size

include_gpu_probs_tensor property

include_gpu_probs_tensor

logits_as_input property

logits_as_input

org_vocab_size property

org_vocab_size

scale property

scale

sharded_to_full_mapping instance-attribute

sharded_to_full_mapping = sharded_to_full_mapping

should_modify_greedy_probs_inplace property

should_modify_greedy_probs_inplace

soft_cap property

soft_cap

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

use_all_gather property

use_all_gather

vocab_size property

vocab_size

__init__

__init__(
    base_layer: LogitsProcessor,
    hidden_size: int,
    dtype: dtype,
    device: device,
    sharded_to_full_mapping: Optional[list[int]],
) -> None
Source code in vllm/lora/layers/logits_processor.py
def __init__(
    self,
    base_layer: LogitsProcessor,
    hidden_size: int,
    dtype: torch.dtype,
    device: torch.device,
    sharded_to_full_mapping: Optional[list[int]],
) -> None:
    """Wrap ``base_layer`` and record model and tensor-parallel metadata.

    Args:
        base_layer: the LogitsProcessor being wrapped.
        hidden_size: hidden size of the model.
        dtype: data type of the model.
        device: device of the model.
        sharded_to_full_mapping: optional index mapping from sharded vocab
            to full vocab; ``None`` disables reindexing of gathered logits.
    """
    super().__init__()
    # Wrapped layer plus model-level metadata.
    self.base_layer = base_layer
    self.hidden_size = hidden_size
    self.dtype, self.device = dtype, device
    # Tensor-parallel topology, queried once at construction time.
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    # Kept as a Python list; create_lora_weights builds the device tensor.
    self.sharded_to_full_mapping = sharded_to_full_mapping

_get_logits

_get_logits(
    hidden_states: Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[Tensor] = None,
) -> Optional[Tensor]
Source code in vllm/lora/layers/logits_processor.py
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Compute next-token logits with LoRA deltas and extra vocab applied.

    Args:
        hidden_states: hidden states fed to the LM head. They are used as a
            2-D matrix below (``hidden_states.T``,
            ``hidden_states.shape[0]``), presumably
            ``(num_tokens, hidden_size)`` -- TODO confirm.
        lm_head: output embedding whose ``quant_method.apply`` produces the
            base logits.
        embedding_bias: optional bias added in place to the base logits
            before they are gathered across TP ranks.

    Returns:
        ``None`` when ``base_layer._gather_logits`` returns ``None``.
        Otherwise, the logits with LoRA applied, truncated to the first
        ``base_layer.vocab_size`` columns.
    """
    # Get the logits for the next tokens.
    logits = lm_head.quant_method.apply(lm_head, hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias

    # Gather logits for TP
    logits = self.base_layer._gather_logits(logits)

    # NOTE(review): _gather_logits may return None, presumably on ranks that
    # do not receive the gathered result -- confirm.
    if logits is None:
        return None

    if self.sharded_to_full_mapping_gpu is not None:
        # Reindex full logits tensor to ensure 1:1 mapping between
        # index and token_id
        # Example for:
        #   org_vocab_size = 4
        #   added_vocab_size = 2
        #   pad_to_size = 8
        #   tp_size = 2

        # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
        # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

        # Therefore, the mapping is expected to be:
        # [0, 1, 4, 5, 2, 6, 3, 7]
        # New column i takes old column mapping[i]:
        #   * base tokens 0-3 sit at old columns 0, 1, 4, 5;
        #   * added tokens 4-5 sit at old columns 2, 6;
        #   * padding sits at old columns 3, 7.
        # So when we reindex, we get:
        # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
        # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
        logits = logits[:, self.sharded_to_full_mapping_gpu]

    # Scratch buffer with one slot per adapter plus a trailing slot:
    # (embeddings_tensors.shape[0] + 1, embeddings_tensors.shape[1],
    #  hidden_states.shape[0]).
    lora_logits = torch.empty(
        self.embeddings_tensors.shape[0] + 1,
        self.embeddings_tensors.shape[1],
        hidden_states.shape[0],
        dtype=self.embeddings_tensors.dtype,
        device=self.embeddings_tensors.device,
    )
    # Extra-vocab logits for every adapter slot at once, written into all
    # but the trailing slot.
    torch.matmul(self.embeddings_tensors,
                 hidden_states.T,
                 out=lora_logits[:-1])

    neg_inf, pos_inf = current_platform.get_infinity_values(
        lora_logits.dtype)

    # Trailing slot is all neg_inf. Presumably tokens without an active LoRA
    # are indexed here so their extra-vocab logits are masked -- confirm in
    # punica_wrapper.
    lora_logits[-1] = neg_inf
    # Swap the last two dims -> (slots, num_tokens, extra_vocab). After
    # flattening the first two dims, row ``slot * num_tokens + token`` holds
    # one token's extra-vocab logits under one adapter slot.
    lora_logits = lora_logits.mT
    indices_padded = self.punica_wrapper.sampler_indices_padded

    # NOTE(review): trimmed to the number of logits rows on TPU/XPU;
    # presumably the padded indices can be longer there -- confirm.
    if current_platform.is_tpu() or current_platform.is_xpu():
        indices_padded = indices_padded[:logits.size(0)]

    # Select one row per index. NaN (likely produced by the matmul over
    # -inf-filled unused embedding rows) and +/-inf are replaced with the
    # values from current_platform.get_infinity_values.
    lora_logits = (lora_logits.reshape(
        lora_logits.shape[0] * lora_logits.shape[1],
        lora_logits.shape[2],
    ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                  posinf=pos_inf,
                                                  neginf=neg_inf))

    # Extra-vocab logits go into the columns right after the original
    # vocabulary: [org_vocab_size, org_vocab_size + extra_vocab).
    logits[:,
           self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
           lora_logits.shape[1]] = lora_logits

    # LoRA A/B contribution, delegated to punica_wrapper with scale 1.0.
    lora_output: Optional[
        torch.Tensor] = self.punica_wrapper.add_lora_logits(
            logits, hidden_states, self.lora_a_stacked,
            self.lora_b_stacked, 1.0)

    # On platforms that cannot update in place, use the returned tensor.
    if not current_platform.can_update_inplace():
        logits = lora_output

    # Remove paddings in vocab (if any).
    logits = logits[:, :self.base_layer.vocab_size]
    return logits

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: Optional[PretrainedConfig],
) -> bool
Source code in vllm/lora/layers/logits_processor.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: Optional[PretrainedConfig],
) -> bool:
    """Report whether ``source_layer`` can be swapped for this wrapper.

    The LogitsProcessor is wrapped through special handling elsewhere, so
    the generic replacement path never claims a layer. The result is always
    ``False``, whatever the arguments.
    """
    replaceable = False  # Special handling for the LogitsProcessor.
    return replaceable

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None
Source code in vllm/lora/layers/logits_processor.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None:
    """Allocate the stacked LoRA buffers for ``max_loras`` adapter slots.

    Allocates, on ``self.device``:

    * ``lora_a_stacked``: zeros of shape
      ``(max_loras, 1, max_lora_rank, hidden_size)``.
    * ``lora_b_stacked``: zeros of shape
      ``(max_loras, 1, padded_vocab, max_lora_rank)``. Here
      ``padded_vocab`` is ``base_layer.vocab_size`` rounded up to a multiple
      of ``lora_config.lora_vocab_padding_size``.
    * ``embeddings_tensors``: ``-inf``-filled, shape
      ``(max_loras, lora_extra_vocab_size, hidden_size)``, in ``self.dtype``.

    Also converts ``sharded_to_full_mapping`` to a ``torch.long`` tensor on
    the device, or sets ``sharded_to_full_mapping_gpu`` to ``None``.

    Args:
        max_loras: number of adapter slots to allocate.
        lora_config: LoRA configuration supplying rank, dtype, vocab padding
            and extra-vocab sizes.
        model_config: accepted for interface compatibility; unused here.

    Raises:
        ValueError: if ``base_layer.vocab_size`` exceeds 257024.
    """
    # TODO: Verify if this condition can be further relaxed
    # The check used to be written as ``32000 < vocab_size > 257024``, a
    # chained comparison meaning ``vocab_size > 32000 and
    # vocab_size > 257024``. Only the upper bound was ever enforced, and its
    # message ("32000 >= vocab_size <= 257024") was contradictory. State the
    # enforced bound directly so check and message agree.
    vocab_size = self.base_layer.vocab_size
    if vocab_size > 257024:
        raise ValueError("When using LoRA, vocab size must be "
                         f"<= 257024, got {vocab_size}")
    self.lora_a_stacked = torch.zeros(
        (
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.hidden_size,
        ),
        dtype=lora_config.lora_dtype,
        device=self.device,
    )
    self.lora_b_stacked = torch.zeros(
        (
            max_loras,
            1,
            # Pad for kernel compatibility
            math.ceil(self.base_layer.vocab_size /
                      lora_config.lora_vocab_padding_size) *
            lora_config.lora_vocab_padding_size,
            lora_config.max_lora_rank,
        ),
        dtype=lora_config.lora_dtype,
        device=self.device,
    )
    # -inf marks "no embedding" for unused extra-vocab rows.
    self.embeddings_tensors = torch.full(
        (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
        fill_value=float("-inf"),
        dtype=self.dtype,
        device=self.device,
    )
    if self.sharded_to_full_mapping is not None:
        self.sharded_to_full_mapping_gpu = torch.tensor(
            self.sharded_to_full_mapping,
            device=self.device,
            dtype=torch.long)
    else:
        self.sharded_to_full_mapping_gpu = None

forward

forward(*args, **kwargs)
Source code in vllm/lora/layers/logits_processor.py
def forward(self, *args, **kwargs):
    """Delegate to the wrapped layer's class-level ``forward``.

    The unbound ``forward`` of ``type(self.base_layer)`` is called with this
    wrapper passed as ``self``. Any attribute lookups on ``self`` inside
    that ``forward`` therefore resolve on the wrapper, not on
    ``base_layer``. All positional and keyword arguments are forwarded
    unchanged, and the result is returned as-is.
    """
    base_cls = type(self.base_layer)
    unbound_forward = base_cls.forward
    return unbound_forward(self, *args, **kwargs)

reset_lora

reset_lora(index: int)
Source code in vllm/lora/layers/logits_processor.py
def reset_lora(self, index: int):
    """Clear adapter slot ``index``.

    Zeroes the slot in both the A and B stacks, and sets the slot's
    extra-vocab embeddings to ``-inf``.
    """
    for stack in (self.lora_a_stacked, self.lora_b_stacked):
        stack[index] = 0
    self.embeddings_tensors[index] = float("-inf")

set_lora

set_lora(
    index: int,
    lora_a: Tensor,
    lora_b: Tensor,
    embeddings_tensor: Optional[Tensor],
    bias: Optional[Tensor] = None,
)
Source code in vllm/lora/layers/logits_processor.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    embeddings_tensor: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
):
    """Load one adapter into slot ``index``.

    The slot is reset first. ``lora_a`` and ``lora_b`` are then copied into
    the top-left corner of the slot's A/B buffers. When
    ``embeddings_tensor`` is given, it is written into the top-left corner
    of the slot's extra-vocab embeddings. ``bias`` is accepted but unused.
    """
    self.reset_lora(index)

    a_rows, a_cols = lora_a.shape[0], lora_a.shape[1]
    a_dst = self.lora_a_stacked[index, 0, :a_rows, :a_cols]
    a_dst.copy_(lora_a, non_blocking=True)

    b_rows, b_cols = lora_b.shape[0], lora_b.shape[1]
    b_dst = self.lora_b_stacked[index, 0, :b_rows, :b_cols]
    b_dst.copy_(lora_b, non_blocking=True)

    if embeddings_tensor is None:
        return
    n_extra, width = embeddings_tensor.shape[0], embeddings_tensor.shape[1]
    self.embeddings_tensors[index, :n_extra, :width] = embeddings_tensor