vllm.forward_context

_forward_context module-attribute

_forward_context: Optional[ForwardContext] = None

batchsize_forward_time module-attribute

batchsize_forward_time: defaultdict = defaultdict(list)

batchsize_logging_interval module-attribute

batchsize_logging_interval: float = (
    VLLM_LOG_BATCHSIZE_INTERVAL
)

forward_start_time module-attribute

forward_start_time: float = 0

last_logging_time module-attribute

last_logging_time: float = 0

logger module-attribute

logger = init_logger(__name__)

track_batchsize module-attribute

track_batchsize: bool = VLLM_LOG_BATCHSIZE_INTERVAL >= 0

BatchDescriptor

Bases: NamedTuple

Batch descriptor for cudagraph dispatching. We should keep the number of items as minimal as possible to properly and uniquely describe the padded batch for cudagraph.

Source code in vllm/forward_context.py
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the number of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
    num_tokens: int
    uniform_decode: bool = False
    """
    False can also be used for a uniform decode batch to dispatch to the
    cudagraph supporting non-uniform batches.
    """

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of current batch descriptor.
        """
        return BatchDescriptor(self.num_tokens, uniform_decode=False)

non_uniform property

non_uniform: BatchDescriptor

Return a non-uniform version of current batch descriptor.

num_tokens instance-attribute

num_tokens: int

uniform_decode class-attribute instance-attribute

uniform_decode: bool = False

False can also be used for a uniform decode batch to dispatch to the cudagraph supporting non-uniform batches.
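
Example (a minimal sketch, assuming vLLM is installed; the dispatch_table dict is a hypothetical stand-in for the runtime's cudagraph lookup):

from vllm.forward_context import BatchDescriptor

# Two descriptors for the same padded token count: one marked as a uniform
# decode batch, one not. Being a NamedTuple, they are hashable and compared
# by value, so they can key different cudagraphs.
uniform = BatchDescriptor(num_tokens=256, uniform_decode=True)
general = uniform.non_uniform
assert general == BatchDescriptor(num_tokens=256, uniform_decode=False)

dispatch_table = {uniform: "uniform-decode graph", general: "general graph"}
print(dispatch_table[uniform])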

DPMetadata dataclass

Source code in vllm/forward_context.py
@dataclass
class DPMetadata:
    max_tokens_across_dp_cpu: torch.Tensor
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes context manager
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.
        """
        from vllm.distributed.parallel_state import get_dp_group
        device = current_platform.device_type
        group = get_dp_group().device_group

        # Transferring this tensor from GPU to CPU will introduce a GPU sync
        # point that could adversely affect performance of vllm with async
        # scheduling. This environment variable exists to quickly disable
        # this optimization if we run into this case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to synchronize DP padding between ranks.")
            device = "cpu"
            group = get_dp_group().cpu_group
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device=device,
                                         dtype=torch.int32)
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t both
    # DP and TP rank.
    # When sp_size==1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        num_tokens_across_sp_cpu = (
            (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
        num_tokens_across_sp_cpu = (
            num_tokens_across_sp_cpu.repeat_interleave(sp_size))
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

    @staticmethod
    def should_ubatch_across_dp(
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
            padded_num_tokens_per_ubatch: int, dp_size: int,
            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching. It will "abort" meaning that no padding
        information will be returned to the caller. It will return (False, None)

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that the run with the same number
        of tokens

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch if False
        ]
        """

        device = current_platform.device_type
        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
        tensor[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(tensor, group=get_dp_group().device_group)

        result: bool = bool(torch.all(tensor[2] == 1).item())
        if not result:
            return result, None

        orig_num_tokens_tensor = tensor[0, :]
        padded_num_tokens_tensor = tensor[1, :]

        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
                         padded_max_num_tokens)
            return False, None
        return result, padded_num_tokens_tensor.cpu()

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
    ) -> "DPMetadata":

        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp_cpu is None
                or num_tokens_across_dp_cpu[dp_rank] == batchsize
                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        if num_tokens_across_dp_cpu is None:
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(self, sequence_parallel_size: int,
                      max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is 
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size,
            max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        assert self.local_sizes is not None
        return self.local_sizes

local_sizes class-attribute instance-attribute

local_sizes: Optional[list[int]] = None

max_tokens_across_dp_cpu instance-attribute

max_tokens_across_dp_cpu: Tensor

num_tokens_across_dp_cpu instance-attribute

num_tokens_across_dp_cpu: Tensor

__init__

__init__(
    max_tokens_across_dp_cpu: Tensor,
    num_tokens_across_dp_cpu: Tensor,
    local_sizes: Optional[list[int]] = None,
) -> None

chunked_sizes

chunked_sizes(
    sequence_parallel_size: int,
    max_chunk_size_per_rank: int,
    chunk_idx: int,
)

Context manager to compute and temporarily set the per-rank local token sizes for a specific chunk during chunked forward execution.

This is necessary to ensure each DP (data parallel) rank processes its designated portion of tokens in lockstep with others, even when the token counts are uneven or some ranks have completed their input early.

For chunked execution, we break up the total tokens on each rank into multiple chunks (of at most max_chunk_size_per_rank), and for a given chunk_idx, this context manager sets self.local_sizes to the number of tokens to process in that chunk on each rank.

self.local_sizes is only valid inside the context.

Parameters:

Name Type Description Default
sequence_parallel_size int

When Attn is TP and MoE layers are EP, we use SP between the layers to avoid redundant ops. We need this value to compute the chunked sizes.

required
max_chunk_size_per_rank int

The max number of tokens each rank is allowed to process in this chunk.

required
chunk_idx int

The index of the chunk to compute sizes for.

required
Source code in vllm/forward_context.py
@contextmanager
def chunked_sizes(self, sequence_parallel_size: int,
                  max_chunk_size_per_rank: int, chunk_idx: int):
    """
    Context manager to compute and temporarily set the per-rank local token
    sizes for a specific chunk during chunked forward execution.

    This is necessary to ensure each DP (data parallel) rank processes its
    designated portion of tokens in lockstep with others, even when the
    token counts are uneven or some ranks have completed their input early.

    For chunked execution, we break up the total tokens on each rank into
    multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
    `chunk_idx`, this context manager sets `self.local_sizes` to the number
    of tokens to process in that chunk on each rank.

    `self.local_sizes` is only valid inside the context.

    Args:
        sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                we use SP between the layers to avoid
                                redundant ops. We need this value to
                                compute the chunked sizes.
        max_chunk_size_per_rank: The max number of tokens each rank is 
                                 allowed to process in this chunk.
        chunk_idx: The index of the chunk to compute sizes for.
    """
    self.local_sizes = _compute_chunked_local_num_tokens(
        self.num_tokens_across_dp_cpu, sequence_parallel_size,
        max_chunk_size_per_rank, chunk_idx)
    try:
        yield self.local_sizes
    finally:
        self.local_sizes = None
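
Example (a minimal sketch, assuming vLLM and torch are installed; the DPMetadata below is built by hand with example token counts rather than via DPMetadata.make, which needs an initialized DP process group):

import torch
from vllm.forward_context import DPMetadata

# 2 DP ranks with 100 and 37 tokens, no sequence parallelism.
num_tokens = torch.tensor([100, 37], dtype=torch.int32)
meta = DPMetadata(max_tokens_across_dp_cpu=num_tokens.max(),
                  num_tokens_across_dp_cpu=num_tokens)

# Walk the batch in chunks of at most 64 tokens per rank.
for chunk_idx in range(2):
    with meta.chunked_sizes(sequence_parallel_size=1,
                            max_chunk_size_per_rank=64,
                            chunk_idx=chunk_idx) as sizes:
        print(chunk_idx, sizes)    # chunk 0: [64, 37], chunk 1: [36, 1]

assert meta.local_sizes is None    # reset once the context exits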

cu_tokens_across_sp

cu_tokens_across_sp(sp_size: int) -> Tensor
Source code in vllm/forward_context.py
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
    num_tokens_across_sp_cpu = (
        (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
    num_tokens_across_sp_cpu = (
        num_tokens_across_sp_cpu.repeat_interleave(sp_size))
    return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
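
For intuition, a worked example of the arithmetic (assuming vLLM and torch are installed; the DPMetadata is again built by hand):

import torch
from vllm.forward_context import DPMetadata

num_tokens = torch.tensor([5, 8], dtype=torch.int32)   # tokens per DP rank
meta = DPMetadata(num_tokens.max(), num_tokens)

# With sp_size=2, each DP rank's tokens are split over 2 SP ranks, rounded
# up: ceil(5/2)=3 and ceil(8/2)=4, repeated per SP rank -> [3, 3, 4, 4].
# The cumulative sum is therefore [3, 6, 10, 14].
print(meta.cu_tokens_across_sp(sp_size=2))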

get_chunk_sizes_across_dp_rank

get_chunk_sizes_across_dp_rank() -> Optional[list[int]]
Source code in vllm/forward_context.py
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
    assert self.local_sizes is not None
    return self.local_sizes

make staticmethod

make(
    parallel_config: ParallelConfig,
    attn_metadata: Any,
    num_tokens: int,
    num_tokens_across_dp_cpu: Optional[Tensor] = None,
) -> DPMetadata
Source code in vllm/forward_context.py
@staticmethod
def make(
    parallel_config: ParallelConfig,
    attn_metadata: Any,
    num_tokens: int,
    num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
) -> "DPMetadata":

    assert parallel_config.data_parallel_size > 1
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends or no attn_metadata
        batchsize = num_tokens

    # If num_tokens_across_dp is None, it will be computed by all_reduce
    # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
    assert (num_tokens_across_dp_cpu is None
            or num_tokens_across_dp_cpu[dp_rank] == batchsize
            ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
    if num_tokens_across_dp_cpu is None:
        num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
            batchsize, dp_size, dp_rank)
    max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
    return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

num_tokens_across_dp staticmethod

num_tokens_across_dp(
    num_tokens: int, dp_size: int, dp_rank: int
) -> Tensor

Gather the num_tokens across all DP ranks and return results in a CPU tensor of size dp_size.

Source code in vllm/forward_context.py
@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
                         dp_rank: int) -> torch.Tensor:
    """
    Gather the num_tokens across all DP ranks and return results in a
    CPU tensor of size dp_size.
    """
    from vllm.distributed.parallel_state import get_dp_group
    device = current_platform.device_type
    group = get_dp_group().device_group

    # Transferring this tensor from GPU to CPU will introduce a GPU sync
    # point that could adversely affect performance of vllm with async
    # scheduling. This environment variable exists to quickly disable
    # this optimization if we run into this case.
    if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
        logger.info_once(
            "Using CPU all reduce to synchronize DP padding between ranks.")
        device = "cpu"
        group = get_dp_group().cpu_group
    num_tokens_across_dp = [0] * dp_size
    num_tokens_across_dp[dp_rank] = num_tokens
    num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                     device=device,
                                     dtype=torch.int32)
    dist.all_reduce(num_tokens_tensor, group=group)
    return num_tokens_tensor.cpu()
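
The gather is implemented as an all-reduce of a one-hot-style vector: every rank writes its own count into its slot and leaves zeros elsewhere, so summing across ranks yields the full per-rank vector. A minimal sketch of that pattern with a single-rank gloo group (no vLLM or GPU required; the port is an arbitrary choice for illustration):

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)

dp_size, dp_rank, num_tokens = 1, 0, 42
buf = [0] * dp_size
buf[dp_rank] = num_tokens              # other ranks' slots stay 0
t = torch.tensor(buf, dtype=torch.int32)
dist.all_reduce(t)                     # sum across ranks fills every slot
print(t)                               # tensor([42], dtype=torch.int32)
dist.destroy_process_group()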

should_ubatch_across_dp staticmethod

should_ubatch_across_dp(
    should_ubatch: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[bool, Optional[Tensor]]
  1. Decides if each DP rank is going to microbatch. Either all ranks run with microbatching or none of them do. If this function decides not to run with microbatching, it will "abort", meaning that no padding information will be returned to the caller; it will return (False, None).

  2. Determines the total number of tokens that each rank will run. All ranks will be padded out so that they run with the same number of tokens.

Returns:

Name Type Description
should_ubatch bool

Whether all DP ranks are going to microbatch.

num_tokens_after_padding Optional[Tensor]

A tensor containing the total number of tokens per microbatch for each DP rank, including padding. Will be None if should_ubatch is False.

Source code in vllm/forward_context.py
@staticmethod
def should_ubatch_across_dp(
        should_ubatch: bool, orig_num_tokens_per_ubatch: int,
        padded_num_tokens_per_ubatch: int, dp_size: int,
        dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
    """
    1. Decides if each DP rank is going to microbatch. Either all ranks
    run with microbatching or none of them do. If this function decides
    not to run with microbatching. It will "abort" meaning that no padding
    information will be returned to the caller. It will return (False, None)

    2. Determines the total number of tokens that each rank will run.
    All ranks will be padded out so that the run with the same number
    of tokens

    Returns: tuple[
        should_ubatch: Are all DP ranks going to microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if should_ubatch if False
    ]
    """

    device = current_platform.device_type
    tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
    tensor[0][dp_rank] = orig_num_tokens_per_ubatch
    tensor[1][dp_rank] = padded_num_tokens_per_ubatch
    tensor[2][dp_rank] = 1 if should_ubatch else 0

    from vllm.distributed.parallel_state import get_dp_group
    dist.all_reduce(tensor, group=get_dp_group().device_group)

    result: bool = bool(torch.all(tensor[2] == 1).item())
    if not result:
        return result, None

    orig_num_tokens_tensor = tensor[0, :]
    padded_num_tokens_tensor = tensor[1, :]

    orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
    padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
    if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
        logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
                     padded_max_num_tokens)
        return False, None
    return result, padded_num_tokens_tensor.cpu()

sp_local_sizes

sp_local_sizes(sequence_parallel_size: int)

Context manager for setting self.local_sizes. Same as self.chunked_sizes but without any chunking.

Source code in vllm/forward_context.py
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
    """
    Context manager for setting self.local_sizes. Same as self.chunked_sizes
    but without any chunking.
    """
    self.local_sizes = _compute_sp_num_tokens(
        self.num_tokens_across_dp_cpu, sequence_parallel_size)
    try:
        yield self.local_sizes
    finally:
        self.local_sizes = None
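
Example (a minimal sketch, assuming vLLM and torch are installed; the DPMetadata is built by hand):

import torch
from vllm.forward_context import DPMetadata

num_tokens = torch.tensor([9, 4], dtype=torch.int32)
meta = DPMetadata(num_tokens.max(), num_tokens)

# ceil(9/2)=5 and ceil(4/2)=2, each repeated for that DP rank's 2 SP ranks.
with meta.sp_local_sizes(sequence_parallel_size=2) as sizes:
    print(sizes)    # [5, 5, 2, 2]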

ForwardContext dataclass

Source code in vllm/forward_context.py
@dataclass
class ForwardContext:
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    """
    Type AttentionMetadata for v0, 
    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each 
    attention layer to its attention metadata
    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
    for each microbatch.
    Set dynamically for each forward pass
    """
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
                         list[dict[str, "AttentionMetadata"]]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

attn_metadata instance-attribute

batch_descriptor class-attribute instance-attribute

batch_descriptor: Optional[BatchDescriptor] = None

cudagraph_runtime_mode class-attribute instance-attribute

cudagraph_runtime_mode: CUDAGraphMode = NONE

dp_metadata class-attribute instance-attribute

dp_metadata: Optional[DPMetadata] = None

no_compile_layers instance-attribute

no_compile_layers: dict[str, Any]

Type AttentionMetadata for v0; Type Dict[str, AttentionMetadata] for v1, a map from the layer_name of each attention layer to its attention metadata; Type List[Dict[str, AttentionMetadata]] for DBO, a list of size two, one for each microbatch. Set dynamically for each forward pass.

ubatch_slices class-attribute instance-attribute

ubatch_slices: Optional[UBatchSlices] = None

virtual_engine instance-attribute

virtual_engine: int

__init__

__init__(
    no_compile_layers: dict[str, Any],
    attn_metadata: Union[
        AttentionMetadata,
        dict[str, AttentionMetadata],
        list[dict[str, AttentionMetadata]],
    ],
    virtual_engine: int,
    dp_metadata: Optional[DPMetadata] = None,
    cudagraph_runtime_mode: CUDAGraphMode = NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
    ubatch_slices: Optional[UBatchSlices] = None,
) -> None
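
In normal operation a ForwardContext is built by create_forward_context or set_forward_context; constructing one by hand is mainly useful in tests. A minimal sketch (assuming vLLM is installed; the empty dicts are placeholder values):

from vllm.forward_context import ForwardContext

ctx = ForwardContext(
    no_compile_layers={},   # normally compilation_config.static_forward_context
    attn_metadata={},       # per-layer attention metadata (empty placeholder)
    virtual_engine=0,
)
# Defaults: no DP metadata, CUDAGraphMode.NONE, no ubatching.
print(ctx.cudagraph_runtime_mode, ctx.dp_metadata, ctx.ubatch_slices)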

__post_init__

__post_init__()
Source code in vllm/forward_context.py
def __post_init__(self):
    assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
        f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

_compute_chunked_local_num_tokens

_compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]
Source code in vllm/forward_context.py
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                                      sequence_parallel_size: int,
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:

    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
                                       sequence_parallel_size)
    sp_size = len(sp_tokens)

    local_size = [-1] * sp_size
    for i in range(sp_size):
        # Take into account sharding if MoE activation is sequence parallel.
        local_size[i] = min(max_num_tokens,
                            sp_tokens[i] - (max_num_tokens * chunk_idx))
        if local_size[i] <= 0:
            local_size[i] = 1  # ensure lockstep even if done
    return local_size
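
A worked example of the clamping behavior (assuming vLLM and torch are installed): a rank with no tokens left in a chunk still gets 1 so all ranks stay in lockstep.

import torch
from vllm.forward_context import _compute_chunked_local_num_tokens

tokens = torch.tensor([10, 3], dtype=torch.int32)
# chunk 0: min(8, 10) = 8 and min(8, 3) = 3
print(_compute_chunked_local_num_tokens(tokens, 1, 8, chunk_idx=0))  # [8, 3]
# chunk 1: rank 0 has 2 tokens left; rank 1 is done, so it is clamped to 1
print(_compute_chunked_local_num_tokens(tokens, 1, 8, chunk_idx=1))  # [2, 1]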

_compute_sp_num_tokens

_compute_sp_num_tokens(
    num_tokens_across_dp_cpu: Tensor,
    sequence_parallel_size: int,
) -> list[int]
Source code in vllm/forward_context.py
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                           sequence_parallel_size: int) -> list[int]:
    sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
                 sequence_parallel_size)

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()
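
A worked example of the rounding and repetition (assuming vLLM and torch are installed):

import torch
from vllm.forward_context import _compute_sp_num_tokens

tokens = torch.tensor([7, 2], dtype=torch.int32)
# ceil(7/4)=2 and ceil(2/4)=1, each repeated for that DP rank's 4 SP ranks.
print(_compute_sp_num_tokens(tokens, sequence_parallel_size=4))
# [2, 2, 2, 2, 1, 1, 1, 1]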

create_forward_context

create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: Optional[DPMetadata] = None,
    cudagraph_runtime_mode: CUDAGraphMode = NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
    ubatch_slices: Optional[UBatchSlices] = None,
)
Source code in vllm/forward_context.py
def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
                          static_forward_context,
                          virtual_engine=virtual_engine,
                          attn_metadata=attn_metadata,
                          dp_metadata=dp_metadata,
                          cudagraph_runtime_mode=cudagraph_runtime_mode,
                          batch_descriptor=batch_descriptor,
                          ubatch_slices=ubatch_slices)

get_forward_context

get_forward_context() -> ForwardContext

Get the current forward context.

Source code in vllm/forward_context.py
def get_forward_context() -> ForwardContext:
    """Get the current forward context."""
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context

override_forward_context

override_forward_context(
    forward_context: Optional[ForwardContext],
)

A context manager that overrides the current forward context. This is used to override the forward context for a specific forward pass.

Source code in vllm/forward_context.py
@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """A context manager that overrides the current forward context.
    This is used to override the forward context for a specific
    forward pass.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context
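
Example (a minimal sketch, assuming vLLM is installed; the hand-built ForwardContext uses placeholder values):

from vllm.forward_context import (ForwardContext, get_forward_context,
                                  override_forward_context)

ctx = ForwardContext(no_compile_layers={}, attn_metadata={}, virtual_engine=0)

with override_forward_context(ctx):
    assert get_forward_context() is ctx   # visible anywhere inside the block

# On exit the previous context (here: None) is restored, so calling
# get_forward_context() again would trip its assertion.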

set_forward_context

set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[Tensor] = None,
    cudagraph_runtime_mode: CUDAGraphMode = NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
    ubatch_slices: Optional[UBatchSlices] = None,
)

A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass.

Source code in vllm/forward_context.py
@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, dp_metadata,
                                             cudagraph_runtime_mode,
                                             batch_descriptor, ubatch_slices)

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = num_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            from vllm.platforms import current_platform
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append(
                (now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    medium = torch.quantile(torch.tensor(times), q=0.5).item()
                    medium = round(medium, 2)
                    forward_stats.append((bs, len(times), medium))
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)
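
A typical usage sketch for a model forward pass (assuming vLLM is installed and that a default-constructed VllmConfig() is acceptable for illustration; in practice the engine supplies the real config, attention metadata, and token counts):

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context

vllm_config = VllmConfig()   # placeholder; the engine builds the real config

# attn_metadata=None is allowed, e.g. for profiling or dummy runs.
with set_forward_context(attn_metadata=None,
                         vllm_config=vllm_config,
                         num_tokens=16):
    ctx = get_forward_context()
    # Model layers read the context instead of taking it as an argument,
    # e.g. to look up their per-layer attention metadata.
    print(ctx.virtual_engine, ctx.cudagraph_runtime_mode)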