vllm.v1.worker.ubatching

_CURRENT_CONTEXTS module-attribute

_CURRENT_CONTEXTS: list[Optional[UBatchContext]] = [
    None,
    None,
]

_THREAD_ID_TO_CONTEXT module-attribute

_THREAD_ID_TO_CONTEXT: dict = {}

dbo_maybe_run_recv_hook module-attribute

dbo_maybe_run_recv_hook = _register_ubatch_function(
    maybe_run_recv_hook
)

dbo_switch_to_comm module-attribute

dbo_switch_to_comm = _register_ubatch_function(
    switch_to_comm
)

dbo_switch_to_comm_sync module-attribute

dbo_switch_to_comm_sync = _register_ubatch_function(
    switch_to_comm_sync
)

dbo_switch_to_compute module-attribute

dbo_switch_to_compute = _register_ubatch_function(
    switch_to_compute
)

dbo_switch_to_compute_sync module-attribute

dbo_switch_to_compute_sync = _register_ubatch_function(
    switch_to_compute_sync
)

dbo_yield module-attribute

dbo_yield = _register_ubatch_function(yield_)

dbo_yield_and_switch_from_comm_to_compute module-attribute

dbo_yield_and_switch_from_comm_to_compute = (
    _register_ubatch_function(
        yield_and_switch_from_comm_to_compute
    )
)

dbo_yield_and_switch_from_compute_to_comm module-attribute

dbo_yield_and_switch_from_compute_to_comm = (
    _register_ubatch_function(
        yield_and_switch_from_compute_to_comm
    )
)

UBatchContext

Context manager for micro-batching synchronization using threading events.

Source code in vllm/v1/worker/ubatching.py
class UBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.cuda.Stream,
                 compute_stream: torch.cuda.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.cuda.Event,
                 gpu_compute_done_event: torch.cuda.Event,
                 schedule: str = "default"):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        self.recv_hook = None

    def __enter__(self):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()

        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        return False

    def _restore_context(self):
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.cuda.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
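
The stream-switching and yield methods above are normally reached through the module-level dbo_* wrappers. A minimal, illustrative sketch (not taken from vLLM's model code) of the intended compute/communication ping-pong inside a layer, assuming dual-batch overlap is active on the calling thread:

from vllm.v1.worker.ubatching import (
    dbo_yield_and_switch_from_compute_to_comm,
    dbo_yield_and_switch_from_comm_to_compute,
)

def layer_forward_sketch(hidden_states):
    # ... dense compute for this micro-batch, issued on the compute stream ...

    # Record compute-done, wake the other micro-batch thread, sleep until
    # woken again, then continue on the comm stream (which first waits for
    # the recorded compute work).
    dbo_yield_and_switch_from_compute_to_comm()

    # ... issue communication (e.g. an all-to-all) on the comm stream ...

    # Record comm-done, yield to the other micro-batch again, then resume
    # on the compute stream once the communication has been waited on.
    dbo_yield_and_switch_from_comm_to_compute()

    # ... remaining compute for this micro-batch ...
    return hidden_states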

comm_stream instance-attribute

comm_stream = comm_stream

compute_stream instance-attribute

compute_stream = compute_stream

cpu_signal_event instance-attribute

cpu_signal_event = cpu_signal_event

cpu_wait_event instance-attribute

cpu_wait_event = cpu_wait_event

current_stream instance-attribute

current_stream = compute_stream

forward_context instance-attribute

forward_context = forward_context

gpu_comm_done_event instance-attribute

gpu_comm_done_event = gpu_comm_done_event

gpu_compute_done_event instance-attribute

gpu_compute_done_event = gpu_compute_done_event

id instance-attribute

id = id

ready_barrier instance-attribute

ready_barrier = ready_barrier

recv_hook instance-attribute

recv_hook = None

schedule instance-attribute

schedule = schedule

__enter__

__enter__()
Source code in vllm/v1/worker/ubatching.py
def __enter__(self):
    global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
    _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
    _CURRENT_CONTEXTS[self.id] = self
    self.ready_barrier.wait()

    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()
    # Assume we want to start on the compute stream
    self.update_stream(self.compute_stream)
    return self

__exit__

__exit__(exc_type, exc_val, exc_tb)
Source code in vllm/v1/worker/ubatching.py
def __exit__(self, exc_type, exc_val, exc_tb):
    global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
    _CURRENT_CONTEXTS[self.id] = None
    del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
    self.maybe_run_recv_hook()
    self.cpu_signal_event.set()
    self.cpu_wait_event.clear()
    return False

__init__

__init__(
    id: int,
    comm_stream: Stream,
    compute_stream: Stream,
    forward_context: ForwardContext,
    ready_barrier: Barrier,
    cpu_wait_event: Event,
    cpu_signal_event: Event,
    gpu_comm_done_event: Event,
    gpu_compute_done_event: Event,
    schedule: str = "default",
)
Source code in vllm/v1/worker/ubatching.py
def __init__(self,
             id: int,
             comm_stream: torch.cuda.Stream,
             compute_stream: torch.cuda.Stream,
             forward_context: ForwardContext,
             ready_barrier: threading.Barrier,
             cpu_wait_event: threading.Event,
             cpu_signal_event: threading.Event,
             gpu_comm_done_event: torch.cuda.Event,
             gpu_compute_done_event: torch.cuda.Event,
             schedule: str = "default"):
    self.id = id
    self.comm_stream = comm_stream
    self.compute_stream = compute_stream
    self.forward_context = forward_context
    self.ready_barrier = ready_barrier
    self.cpu_wait_event = cpu_wait_event
    self.cpu_signal_event = cpu_signal_event
    self.current_stream = compute_stream
    self.gpu_comm_done_event = gpu_comm_done_event
    self.gpu_compute_done_event = gpu_compute_done_event
    self.schedule = schedule
    self.recv_hook = None

_cpu_yield

_cpu_yield()
Source code in vllm/v1/worker/ubatching.py
def _cpu_yield(self):
    # It is critical for correctness that only one thread is running
    # at a time. These asserts just make sure that this is the only
    # thread running before waking the other one up and going to sleep
    assert forward_context._forward_context == self.forward_context
    assert current_stream() == self.current_stream
    assert not self.cpu_wait_event.is_set()

    self.cpu_signal_event.set()
    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()

_restore_context

_restore_context()
Source code in vllm/v1/worker/ubatching.py
def _restore_context(self):
    forward_context._forward_context = self.forward_context

_signal_comm_done

_signal_comm_done()
Source code in vllm/v1/worker/ubatching.py
def _signal_comm_done(self):
    self.gpu_comm_done_event.record(self.comm_stream)

_signal_compute_done

_signal_compute_done()
Source code in vllm/v1/worker/ubatching.py
def _signal_compute_done(self):
    self.gpu_compute_done_event.record(self.compute_stream)

_wait_comm_done

_wait_comm_done()
Source code in vllm/v1/worker/ubatching.py
def _wait_comm_done(self):
    self.compute_stream.wait_event(self.gpu_comm_done_event)

_wait_compute_done

_wait_compute_done()
Source code in vllm/v1/worker/ubatching.py
def _wait_compute_done(self):
    self.comm_stream.wait_event(self.gpu_compute_done_event)

maybe_run_recv_hook

maybe_run_recv_hook()
Source code in vllm/v1/worker/ubatching.py
def maybe_run_recv_hook(self):
    if self.recv_hook is not None:
        self.recv_hook()
        self.recv_hook = None

switch_to_comm

switch_to_comm()
Source code in vllm/v1/worker/ubatching.py
def switch_to_comm(self):
    self.update_stream(self.comm_stream)

switch_to_comm_sync

switch_to_comm_sync()
Source code in vllm/v1/worker/ubatching.py
def switch_to_comm_sync(self):
    self._signal_compute_done()
    self.update_stream(self.comm_stream)
    self._wait_compute_done()

switch_to_compute

switch_to_compute()
Source code in vllm/v1/worker/ubatching.py
def switch_to_compute(self):
    self.update_stream(self.compute_stream)

switch_to_compute_sync

switch_to_compute_sync()
Source code in vllm/v1/worker/ubatching.py
def switch_to_compute_sync(self):
    self._signal_comm_done()
    self.update_stream(self.compute_stream)
    self._wait_comm_done()

update_stream

update_stream(stream)
Source code in vllm/v1/worker/ubatching.py
def update_stream(self, stream):
    self.current_stream = stream
    if current_stream() != self.current_stream:
        torch.cuda.set_stream(self.current_stream)

yield_

yield_()
Source code in vllm/v1/worker/ubatching.py
def yield_(self):
    self.current_stream = current_stream()
    self._cpu_yield()
    self.update_stream(self.current_stream)

yield_and_switch_from_comm_to_compute

yield_and_switch_from_comm_to_compute()
Source code in vllm/v1/worker/ubatching.py
def yield_and_switch_from_comm_to_compute(self):
    assert current_stream() == self.comm_stream
    self._signal_comm_done()
    self._cpu_yield()
    assert self.current_stream == self.comm_stream
    self.update_stream(self.compute_stream)
    self._wait_comm_done()

yield_and_switch_from_compute_to_comm

yield_and_switch_from_compute_to_comm()
Source code in vllm/v1/worker/ubatching.py
def yield_and_switch_from_compute_to_comm(self):
    assert current_stream() == self.compute_stream
    self._signal_compute_done()
    self._cpu_yield()
    assert self.current_stream == self.compute_stream
    self.update_stream(self.comm_stream)
    self._wait_compute_done()

_register_ubatch_function

_register_ubatch_function(func)
Source code in vllm/v1/worker/ubatching.py
def _register_ubatch_function(func):

    def wrapper(*args, **kwargs):
        if len(_THREAD_ID_TO_CONTEXT) > 0:
            ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
            ctx = _CURRENT_CONTEXTS[ctx_idx]
            func(ctx, *args, **kwargs)

    return wrapper
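
Because the wrapper checks the thread registry, the dbo_* helpers become no-ops when dual-batch overlap is not active, so model code can call them unconditionally. For example:

from vllm.v1.worker.ubatching import dbo_maybe_run_recv_hook

# A no-op when dual-batch overlap is not active; otherwise runs any pending
# receive hook on the current thread's UBatchContext.
dbo_maybe_run_recv_hook()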

dbo_current_ubatch_id

dbo_current_ubatch_id() -> int
Source code in vllm/v1/worker/ubatching.py
def dbo_current_ubatch_id() -> int:
    if len(_THREAD_ID_TO_CONTEXT) == 0:
        return 0
    return _THREAD_ID_TO_CONTEXT[threading.get_ident()]

dbo_enabled

dbo_enabled() -> bool
Source code in vllm/v1/worker/ubatching.py
def dbo_enabled() -> bool:
    return len(_THREAD_ID_TO_CONTEXT) > 0
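
dbo_current_ubatch_id() falls back to 0 when no micro-batch context is registered, so it can also index per-micro-batch state regardless of whether overlap is active. A small, hypothetical sketch:

import torch

from vllm.v1.worker.ubatching import dbo_current_ubatch_id

# `workspaces` is a hypothetical pair of per-micro-batch buffers; when
# dual-batch overlap is disabled, only index 0 is ever used.
workspaces = [torch.empty(1024), torch.empty(1024)]

buf = workspaces[dbo_current_ubatch_id()]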

dbo_register_recv_hook

dbo_register_recv_hook(recv_hook)
Source code in vllm/v1/worker/ubatching.py
def dbo_register_recv_hook(recv_hook):
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
        next_ctx.recv_hook = recv_hook
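
The hook is stored on the other micro-batch's context ((ctx_idx + 1) % 2), so it is executed on that thread the next time it calls maybe_run_recv_hook (or when its context exits). An illustrative, hypothetical usage:

from vllm.v1.worker.ubatching import dbo_register_recv_hook

def after_dispatch_sketch(handle):
    # `handle` is a hypothetical async-communication handle with a blocking
    # wait() method; the other micro-batch will run it via
    # dbo_maybe_run_recv_hook() or on its context exit.
    dbo_register_recv_hook(lambda: handle.wait())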

make_ubatch_contexts

make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: Stream,
    comm_stream: Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: Barrier,
    schedule: str = "default",
) -> list[UBatchContext]
Source code in vllm/v1/worker/ubatching.py
def make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.cuda.Stream,
    comm_stream: torch.cuda.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the context managers for micro-batching synchronization.
    """
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.cuda.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.cuda.Event() for _ in range(num_micro_batches)
    ]

    assert len(forward_contexts) == 2

    ctxs = []
    for i in range(num_micro_batches):
        ctx = UBatchContext(id=i,
                            compute_stream=compute_stream,
                            comm_stream=comm_stream,
                            forward_context=forward_contexts[i],
                            ready_barrier=ready_barrier,
                            cpu_wait_event=cpu_events[i],
                            cpu_signal_event=cpu_events[(i + 1) %
                                                        num_micro_batches],
                            gpu_comm_done_event=gpu_comm_done_events[i],
                            gpu_compute_done_event=gpu_compute_done_events[i],
                            schedule=schedule)
        ctxs.append(ctx)

    return ctxs
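
A minimal sketch, not taken from vLLM's worker code, of how the returned contexts could be driven: one thread per micro-batch plus a coordinator that owns the ready barrier and wakes micro-batch 0 first. The forward_contexts, micro_batch_inputs, and run_model names are hypothetical placeholders.

import threading

import torch

from vllm.v1.worker.ubatching import make_ubatch_contexts

def launch_micro_batches(forward_contexts, micro_batch_inputs, run_model):
    compute_stream = torch.cuda.current_stream()
    comm_stream = torch.cuda.Stream()
    ready_barrier = threading.Barrier(3)  # two micro-batch threads + this one

    ctxs = make_ubatch_contexts(
        num_micro_batches=2,
        compute_stream=compute_stream,
        comm_stream=comm_stream,
        forward_contexts=forward_contexts,
        ready_barrier=ready_barrier,
    )

    def worker(ctx, inputs):
        # __enter__ waits on the barrier and then on ctx.cpu_wait_event.
        with ctx:
            run_model(inputs)  # model code calls the dbo_* helpers internally

    threads = [
        threading.Thread(target=worker, args=(ctx, inp))
        for ctx, inp in zip(ctxs, micro_batch_inputs)
    ]
    for t in threads:
        t.start()
    ready_barrier.wait()          # both workers have registered their contexts
    ctxs[0].cpu_wait_event.set()  # wake micro-batch 0; the two then ping-pong
    for t in threads:
        t.join()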