Skip to content

vllm.model_executor.layers.logits_processor

A layer that computes logits from hidden_states.

LogitsProcessor

Bases: CustomOp

Compute logits from hidden states and post-process them (soft-capping and scaling).

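This layer does the following: 1. Compute logits from the model hidden_states with the lm_head and gather them across the tensor-parallel group (skipped when the input already is the logits). 2. Soft-cap the logits if a soft cap is configured. 3. Scale the logits if needed.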

Source code in vllm/model_executor/layers/logits_processor.py
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
    """Compute logits from hidden states and post-process them.

    This layer does the following:
    1. Compute logits from the model hidden_states with ``lm_head`` and
       gather them across the tensor-parallel group (skipped when
       ``logits_as_input`` is set, in which case the input already is the
       logits).
    2. Soft-cap the logits if ``soft_cap`` is set.
    3. Scale the logits if ``scale != 1.0``.
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: float = 1.0,
                 logits_as_input: bool = False,
                 soft_cap: Optional[float] = None) -> None:
        """
        Args:
            vocab_size: Full vocabulary size. Only reported by
                ``extra_repr`` here; per the ``org_vocab_size`` note it may
                include LoRA-added tokens.
            org_vocab_size: Original vocabulary size (without LoRA).
                Gathered logits are truncated to this many entries along the
                last dimension. Falls back to ``vocab_size`` when ``None``
                (or any other falsy value).
            scale: A scaling factor to apply to the logits.
            logits_as_input: If True, ``forward`` treats its
                ``hidden_states`` argument as precomputed logits and skips
                the ``lm_head`` projection and gather.
            soft_cap: If set, logits are soft-capped as
                ``soft_cap * tanh(logits / soft_cap)``. Used in Gemma 2.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Whether the input is logits (default is hidden states).
        self.logits_as_input = logits_as_input
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
        self.use_all_gather = current_platform.use_all_gather()

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute (or take) the logits and apply soft-capping and scaling.

        Args:
            lm_head: Output projection used to compute the logits. Unused
                when ``logits_as_input`` is set.
            hidden_states: Model hidden states, or the logits themselves
                when ``logits_as_input`` is set. Never modified in place.
            embedding_bias: Optional bias passed to the ``lm_head``
                projection.

        Returns:
            The processed logits, or ``None`` on ranks that receive no
            logits from the tensor-parallel gather.
        """
        if self.logits_as_input:
            logits = hidden_states
        else:
            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
        if logits is None:
            return None

        if self.soft_cap is not None:
            logits = torch.tanh(logits / self.soft_cap) * self.soft_cap

        if self.scale != 1.0:
            if logits is hidden_states:
                # logits still aliases the caller's tensor (logits_as_input
                # without soft-capping); scale out of place so the caller's
                # input is not silently modified.
                logits = logits * self.scale
            else:
                # logits was produced by this layer; scale in place to avoid
                # another full-size allocation.
                logits *= self.scale
        return logits

    def _gather_logits(self,
                       logits: torch.Tensor) -> Optional[torch.Tensor]:
        """Gather/all-gather the logits tensor across the model-parallel group.

        Returns:
            The gathered logits. With all-gather every rank gets them; with
            plain gather ``None`` may be returned for rank > 0, hence the
            ``Optional`` return type.
        """
        if self.use_all_gather:
            # Gather is not supported for some devices such as TPUs.
            # Use all-gather instead.
            # NOTE(woosuk): Here, the outputs of every device should not be None
            # because XLA requires strict SPMD among all devices. Every device
            # should execute the same operations after gathering the logits.
            logits = tensor_model_parallel_all_gather(logits)
        else:
            # None may be returned for rank > 0
            logits = tensor_model_parallel_gather(logits)
        return logits

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Project hidden states through ``lm_head``, gather, drop padding.

        Returns ``None`` when the gather yields nothing on this rank.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head,
                                            hidden_states,
                                            bias=embedding_bias)

        # Gather logits for TP
        logits = self._gather_logits(logits)

        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[..., :self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        """Describe this layer's configuration for the module repr.

        Includes ``soft_cap`` because it changes the layer's output.
        """
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
        s += f", soft_cap={self.soft_cap}"
        return s

logits_as_input instance-attribute

logits_as_input = logits_as_input

org_vocab_size instance-attribute

org_vocab_size = org_vocab_size or vocab_size

scale instance-attribute

scale = scale

soft_cap instance-attribute

soft_cap = soft_cap

use_all_gather instance-attribute

use_all_gather = use_all_gather()

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    vocab_size: int,
    org_vocab_size: Optional[int] = None,
    scale: float = 1.0,
    logits_as_input: bool = False,
    soft_cap: Optional[float] = None,
) -> None

Parameters:

Name Type Description Default
scale float

A scaling factor to apply to the logits.

1.0
Source code in vllm/model_executor/layers/logits_processor.py
def __init__(self,
             vocab_size: int,
             org_vocab_size: Optional[int] = None,
             scale: float = 1.0,
             logits_as_input: bool = False,
             soft_cap: Optional[float] = None) -> None:
    """
    Args:
        vocab_size: Full vocabulary size (may include LoRA-added tokens;
            ``org_vocab_size`` is the size without LoRA).
        org_vocab_size: Original vocabulary size (without LoRA). Falls back
            to ``vocab_size`` when not given (any falsy value).
        scale: A scaling factor to apply to the logits.
        logits_as_input: If True, the input is already logits rather than
            hidden states.
        soft_cap: Optional soft cap for the logits. Used in Gemma 2.
    """
    super().__init__()
    # Vocabulary sizes: the original (non-LoRA) size defaults to the full
    # size whenever it is not provided.
    self.vocab_size = vocab_size
    self.org_vocab_size = org_vocab_size if org_vocab_size else vocab_size
    # Output post-processing settings.
    self.scale = scale
    self.soft_cap = soft_cap
    # Whether the input is logits (default is hidden states).
    self.logits_as_input = logits_as_input
    # The platform decides between gather and all-gather for the logits.
    self.use_all_gather = current_platform.use_all_gather()

_gather_logits

_gather_logits(logits: Tensor) -> Tensor

Gather or all-gather the logits tensor across the model-parallel group.

Source code in vllm/model_executor/layers/logits_processor.py
def _gather_logits(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
    """Gather/all-gather the logits tensor across the model-parallel group.

    Args:
        logits: This rank's portion of the logits (presumably sharded along
            the vocab dimension, since lm_head is a VocabParallelEmbedding).

    Returns:
        The gathered logits. With all-gather every rank receives them; with
        plain gather ``None`` may be returned for rank > 0, which is why the
        return type is ``Optional``.
    """
    if self.use_all_gather:
        # Gather is not supported for some devices such as TPUs, so use
        # all-gather instead.
        # NOTE(woosuk): Here, the outputs of every device should not be None
        # because XLA requires strict SPMD among all devices. Every device
        # should execute the same operations after gathering the logits.
        return tensor_model_parallel_all_gather(logits)
    # None may be returned for rank > 0
    return tensor_model_parallel_gather(logits)

_get_logits

_get_logits(
    hidden_states: Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[Tensor],
) -> Optional[Tensor]
Source code in vllm/model_executor/layers/logits_processor.py
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # Get the logits for the next tokens.
    logits = lm_head.quant_method.apply(lm_head,
                                        hidden_states,
                                        bias=embedding_bias)

    # Gather logits for TP
    logits = self._gather_logits(logits)

    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[..., :self.org_vocab_size]
    return logits

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/logits_processor.py
def extra_repr(self) -> str:
    """Describe this layer's configuration for the module repr.

    Lists the sizes, scale, input mode and soft cap. ``soft_cap`` is
    included because it changes the layer's output.
    """
    s = f"vocab_size={self.vocab_size}"
    s += f", org_vocab_size={self.org_vocab_size}"
    s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
    s += f", soft_cap={self.soft_cap}"
    return s

forward

forward(
    lm_head: VocabParallelEmbedding,
    hidden_states: Tensor,
    embedding_bias: Optional[Tensor] = None,
) -> Optional[Tensor]
Source code in vllm/model_executor/layers/logits_processor.py
def forward(
    self,
    lm_head: VocabParallelEmbedding,
    hidden_states: torch.Tensor,
    embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Compute (or take) the logits and apply soft-capping and scaling.

    Args:
        lm_head: Output projection used to compute the logits. Unused when
            ``logits_as_input`` is set.
        hidden_states: Model hidden states, or the logits themselves when
            ``logits_as_input`` is set. Never modified in place.
        embedding_bias: Optional bias passed to the lm_head projection.

    Returns:
        The processed logits, or ``None`` on ranks that receive no logits
        from the tensor-parallel gather.
    """
    if self.logits_as_input:
        logits = hidden_states
    else:
        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, lm_head, embedding_bias)
    if logits is None:
        return None

    if self.soft_cap is not None:
        logits = torch.tanh(logits / self.soft_cap) * self.soft_cap

    if self.scale != 1.0:
        if logits is hidden_states:
            # logits still aliases the caller's tensor (logits_as_input
            # without soft-capping); scale out of place so the caller's
            # input is not silently modified.
            logits = logits * self.scale
        else:
            # logits was produced by this layer; scale in place to avoid
            # another full-size allocation.
            logits *= self.scale
    return logits