Bases: CustomOp
Process logits and apply logits processors from sampling metadata.
This layer does the following:

1. Gather logits from model hidden_states.
2. Scale logits if needed.
3. Apply logits processors (if any).
Source code in vllm/model_executor/layers/logits_processor.py
```python
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
    """Process logits and apply logits processors from sampling metadata.

    This layer does the following:
    1. Gather logits from model hidden_states.
    2. Scale logits if needed.
    3. Apply logits processors (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: float = 1.0,
                 logits_as_input: bool = False,
                 soft_cap: Optional[float] = None) -> None:
        """
        Args:
            scale: A scaling factor to apply to the logits.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Whether the input is logits (default is hidden states).
        self.logits_as_input = logits_as_input
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
        self.use_all_gather = current_platform.use_all_gather()

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        if self.logits_as_input:
            logits = hidden_states
        else:
            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
        if logits is not None:
            if self.soft_cap is not None:
                logits = logits / self.soft_cap
                logits = torch.tanh(logits)
                logits = logits * self.soft_cap
            if self.scale != 1.0:
                logits *= self.scale
        return logits

    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Gather/all-gather the logits tensor across the model parallel group."""
        if self.use_all_gather:
            # Gather is not supported for some devices such as TPUs.
            # Use all-gather instead.
            # NOTE(woosuk): Here, the outputs of every device should not be
            # None because XLA requires strict SPMD among all devices. Every
            # device should execute the same operations after gathering the
            # logits.
            logits = tensor_model_parallel_all_gather(logits)
        else:
            # None may be returned for rank > 0.
            logits = tensor_model_parallel_gather(logits)
        return logits

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head,
                                            hidden_states,
                                            bias=embedding_bias)
        # Gather logits for TP.
        logits = self._gather_logits(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[..., :self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
        return s
```
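For orientation, here is a minimal usage sketch. It takes the `logits_as_input=True` path, which skips the `lm_head` projection entirely, so no `VocabParallelEmbedding` is needed and `lm_head=None` is acceptable per the `forward` code above. Whether the layer constructs cleanly outside a running vLLM engine depends on `CustomOp`'s platform dispatch, so treat this as illustrative rather than a supported standalone recipe; all values are hypothetical.

```python
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor

# Sketch only: with logits_as_input=True, forward() treats hidden_states as
# already-computed logits and never touches lm_head.
processor = LogitsProcessor(vocab_size=32000,
                            scale=0.7,
                            logits_as_input=True,
                            soft_cap=30.0)
raw_logits = torch.randn(4, 32000)  # (num_tokens, vocab_size)
processed = processor(lm_head=None, hidden_states=raw_logits)
```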
logits_as_input instance-attribute
logits_as_input = logits_as_input

org_vocab_size instance-attribute
org_vocab_size = org_vocab_size or vocab_size

scale instance-attribute
scale = scale

soft_cap instance-attribute
soft_cap = soft_cap

use_all_gather instance-attribute
use_all_gather = current_platform.use_all_gather()

vocab_size instance-attribute
vocab_size = vocab_size
__init__
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scale | float | A scaling factor to apply to the logits. | 1.0 |
Source code in vllm/model_executor/layers/logits_processor.py
```python
def __init__(self,
             vocab_size: int,
             org_vocab_size: Optional[int] = None,
             scale: float = 1.0,
             logits_as_input: bool = False,
             soft_cap: Optional[float] = None) -> None:
    """
    Args:
        scale: A scaling factor to apply to the logits.
    """
    super().__init__()
    self.scale = scale
    self.vocab_size = vocab_size
    # Whether the input is logits (default is hidden states).
    self.logits_as_input = logits_as_input
    # Original vocabulary size (without LoRA).
    self.org_vocab_size = org_vocab_size or vocab_size
    # Soft cap the logits. Used in Gemma 2.
    self.soft_cap = soft_cap
    # Whether to use gather or all-gather to gather the logits.
    self.use_all_gather = current_platform.use_all_gather()
```
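A small illustration of the `org_vocab_size or vocab_size` defaulting above (values hypothetical): when the embedding is padded, e.g. so the vocabulary divides evenly across tensor-parallel ranks, `org_vocab_size` records the real tokenizer size, and omitting it makes the two sizes coincide.

```python
from typing import Optional

def effective_org_vocab(vocab_size: int, org_vocab_size: Optional[int]) -> int:
    # Mirrors the constructor's defaulting: fall back to the (possibly
    # padded) vocab_size when no original size is given.
    return org_vocab_size or vocab_size

print(effective_org_vocab(32256, 32000))  # -> 32000 (padded lm_head, real vocab)
print(effective_org_vocab(32000, None))   # -> 32000 (no padding)
```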
_gather_logits
Gather/all-gather the logits tensor across the model parallel group.
Source code in vllm/model_executor/layers/logits_processor.py
```python
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
    """Gather/all-gather the logits tensor across the model parallel group."""
    if self.use_all_gather:
        # Gather is not supported for some devices such as TPUs.
        # Use all-gather instead.
        # NOTE(woosuk): Here, the outputs of every device should not be
        # None because XLA requires strict SPMD among all devices. Every
        # device should execute the same operations after gathering the
        # logits.
        logits = tensor_model_parallel_all_gather(logits)
    else:
        # None may be returned for rank > 0.
        logits = tensor_model_parallel_gather(logits)
    return logits
```
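The contrast between the two collectives can be sketched with plain `torch.distributed` (a single-process `gloo` group here, purely schematic): `gather` delivers the concatenated tensor only to the destination rank, which is why the gather path may return `None` for rank > 0, while `all_gather` leaves every rank holding the full vocabulary, matching the strict SPMD requirement noted for XLA.

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

shard = torch.randn(2, 4)  # this rank's shard of the vocab dimension
out = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
dist.all_gather(out, shard)           # every rank receives every shard
full_logits = torch.cat(out, dim=-1)  # identical result on all ranks

dist.destroy_process_group()
```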
_get_logits
Source code in vllm/model_executor/layers/logits_processor.py
```python
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # Get the logits for the next tokens.
    logits = lm_head.quant_method.apply(lm_head,
                                        hidden_states,
                                        bias=embedding_bias)
    # Gather logits for TP.
    logits = self._gather_logits(logits)
    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[..., :self.org_vocab_size]
    return logits
```
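The final slice deserves a concrete example. Assuming a hypothetical lm_head padded from 32000 real tokens up to 32256 columns (e.g. for clean tensor-parallel sharding), the trailing padded columns correspond to no real token and are dropped before sampling:

```python
import torch

org_vocab_size = 32000                 # real tokenizer vocabulary (hypothetical)
padded_logits = torch.randn(2, 32256)  # lm_head output incl. padding columns
logits = padded_logits[..., :org_vocab_size]
assert logits.shape == (2, org_vocab_size)
```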
extra_repr
Source code in vllm/model_executor/layers/logits_processor.py

```python
def extra_repr(self) -> str:
    s = f"vocab_size={self.vocab_size}"
    s += f", org_vocab_size={self.org_vocab_size}"
    s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
    return s
```
forward
Source code in vllm/model_executor/layers/logits_processor.py
```python
def forward(
    self,
    lm_head: VocabParallelEmbedding,
    hidden_states: torch.Tensor,
    embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    if self.logits_as_input:
        logits = hidden_states
    else:
        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, lm_head, embedding_bias)
    if logits is not None:
        if self.soft_cap is not None:
            logits = logits / self.soft_cap
            logits = torch.tanh(logits)
            logits = logits * self.soft_cap
        if self.scale != 1.0:
            logits *= self.scale
    return logits
```
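To make the post-processing concrete, here is a standalone re-implementation of the soft-cap-then-scale arithmetic from `forward` (not the vLLM API itself, just the same math on a toy tensor). Dividing by the cap, applying `tanh`, and multiplying back bounds every logit to `(-soft_cap, soft_cap)` while leaving small values nearly untouched, which is the Gemma 2 final-logit soft-capping scheme:

```python
from typing import Optional

import torch

def postprocess(logits: torch.Tensor,
                soft_cap: Optional[float] = None,
                scale: float = 1.0) -> torch.Tensor:
    # Same arithmetic as LogitsProcessor.forward's post-processing.
    if soft_cap is not None:
        logits = soft_cap * torch.tanh(logits / soft_cap)
    if scale != 1.0:
        logits = logits * scale
    return logits

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(postprocess(x, soft_cap=30.0))
# ≈ tensor([-29.92, -4.95, 0.00, 4.95, 29.92]) -- extremes saturate near ±30
```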