Context manager for micro-batching synchronization using threading events.
Source code in vllm/v1/worker/ubatching.py
class UBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.cuda.Stream,
                 compute_stream: torch.cuda.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.cuda.Event,
                 gpu_compute_done_event: torch.cuda.Event,
                 schedule: str = "default"):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        self.recv_hook = None

    def __enter__(self):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        return False

    def _restore_context(self):
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.cuda.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()
        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
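The cpu_wait_event / cpu_signal_event pair implements a simple ping-pong handshake: a micro-batch thread wakes its peer, then blocks until the peer yields back, so at most one thread runs at a time. The standalone sketch below (a hypothetical illustration only, not a use of UBatchContext itself, which additionally needs CUDA streams and a ForwardContext) shows that handshake with plain threading.Event objects:

import threading

# Two "micro-batch" threads take turns, mirroring the
# cpu_wait_event / cpu_signal_event handshake in UBatchContext.
events = [threading.Event(), threading.Event()]

def worker(my_id: int, steps: int = 3) -> None:
    other_id = 1 - my_id
    for step in range(steps):
        events[my_id].wait()    # sleep until the peer yields to us
        events[my_id].clear()   # consume the wake-up signal
        print(f"ubatch {my_id}: step {step}")
        events[other_id].set()  # wake the peer, then loop back and wait

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
events[0].set()  # kick off micro-batch 0 first
for t in threads:
    t.join()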
comm_stream instance-attribute
comm_stream = comm_stream
compute_stream instance-attribute
compute_stream = compute_stream
cpu_signal_event instance-attribute
cpu_signal_event = cpu_signal_event
cpu_wait_event instance-attribute
cpu_wait_event = cpu_wait_event
current_stream instance-attribute
current_stream = compute_stream
forward_context instance-attribute
forward_context = forward_context
gpu_comm_done_event instance-attribute
gpu_comm_done_event = gpu_comm_done_event
gpu_compute_done_event instance-attribute
gpu_compute_done_event = gpu_compute_done_event
ready_barrier instance-attribute
ready_barrier = ready_barrier
recv_hook instance-attribute
recv_hook = None
schedule instance-attribute
schedule = schedule
__enter__
Source code in vllm/v1/worker/ubatching.py
def __enter__(self):
    global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
    _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
    _CURRENT_CONTEXTS[self.id] = self
    self.ready_barrier.wait()
    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()
    # Assume we want to start on the compute stream
    self.update_stream(self.compute_stream)
    return self
__exit__
__exit__(exc_type, exc_val, exc_tb)
Source code in vllm/v1/worker/ubatching.py
def __exit__(self, exc_type, exc_val, exc_tb):
    global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
    _CURRENT_CONTEXTS[self.id] = None
    del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
    self.maybe_run_recv_hook()
    self.cpu_signal_event.set()
    self.cpu_wait_event.clear()
    return False
__init__
__init__(
id: int,
comm_stream: Stream,
compute_stream: Stream,
forward_context: ForwardContext,
ready_barrier: Barrier,
cpu_wait_event: Event,
cpu_signal_event: Event,
gpu_comm_done_event: Event,
gpu_compute_done_event: Event,
schedule: str = "default",
)
Source code in vllm/v1/worker/ubatching.py
def __init__(self,
             id: int,
             comm_stream: torch.cuda.Stream,
             compute_stream: torch.cuda.Stream,
             forward_context: ForwardContext,
             ready_barrier: threading.Barrier,
             cpu_wait_event: threading.Event,
             cpu_signal_event: threading.Event,
             gpu_comm_done_event: torch.cuda.Event,
             gpu_compute_done_event: torch.cuda.Event,
             schedule: str = "default"):
    self.id = id
    self.comm_stream = comm_stream
    self.compute_stream = compute_stream
    self.forward_context = forward_context
    self.ready_barrier = ready_barrier
    self.cpu_wait_event = cpu_wait_event
    self.cpu_signal_event = cpu_signal_event
    self.current_stream = compute_stream
    self.gpu_comm_done_event = gpu_comm_done_event
    self.gpu_compute_done_event = gpu_compute_done_event
    self.schedule = schedule
    self.recv_hook = None
_cpu_yield
Source code in vllm/v1/worker/ubatching.py
def _cpu_yield(self):
    # It is critical for correctness that only one thread is running
    # at a time. These asserts just make sure that this is the only
    # thread running before waking the other one up and going to sleep
    assert forward_context._forward_context == self.forward_context
    assert current_stream() == self.current_stream
    assert not self.cpu_wait_event.is_set()
    self.cpu_signal_event.set()
    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()
_restore_context
Source code in vllm/v1/worker/ubatching.py
def _restore_context(self):
    forward_context._forward_context = self.forward_context
_signal_comm_done
Source code in vllm/v1/worker/ubatching.py
def _signal_comm_done(self):
    self.gpu_comm_done_event.record(self.comm_stream)
_signal_compute_done
Source code in vllm/v1/worker/ubatching.py
def _signal_compute_done(self):
    self.gpu_compute_done_event.record(self.compute_stream)
_wait_comm_done
Source code in vllm/v1/worker/ubatching.py
def _wait_comm_done(self):
    self.compute_stream.wait_event(self.gpu_comm_done_event)
_wait_compute_done
Source code in vllm/v1/worker/ubatching.py
def _wait_compute_done(self):
    self.comm_stream.wait_event(self.gpu_compute_done_event)
maybe_run_recv_hook
Source code in vllm/v1/worker/ubatching.py
def maybe_run_recv_hook(self):
    if self.recv_hook is not None:
        self.recv_hook()
        self.recv_hook = None
switch_to_comm
Source code in vllm/v1/worker/ubatching.py
def switch_to_comm(self):
    self.update_stream(self.comm_stream)
switch_to_comm_sync
Source code in vllm/v1/worker/ubatching.py
def switch_to_comm_sync(self):
    self._signal_compute_done()
    self.update_stream(self.comm_stream)
    self._wait_compute_done()
switch_to_compute
Source code in vllm/v1/worker/ubatching.py
def switch_to_compute(self):
    self.update_stream(self.compute_stream)
switch_to_compute_sync
Source code in vllm/v1/worker/ubatching.py
def switch_to_compute_sync(self):
    self._signal_comm_done()
    self.update_stream(self.compute_stream)
    self._wait_comm_done()
update_stream
Source code in vllm/v1/worker/ubatching.py
def update_stream(self, stream):
    self.current_stream = stream
    if current_stream() != self.current_stream:
        torch.cuda.set_stream(self.current_stream)
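update_stream changes the ambient CUDA stream only when it differs from the target, so subsequent kernel launches from this thread land on the requested stream. A small demonstration of that effect (using torch.cuda.current_stream directly; the bare current_stream() in the source is whatever helper the module imports; assumes a CUDA device):

import torch

s = torch.cuda.Stream()
print(torch.cuda.current_stream() == s)  # False: the default stream is ambient
torch.cuda.set_stream(s)
print(torch.cuda.current_stream() == s)  # True: s is now the ambient stream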
yield_
Source code in vllm/v1/worker/ubatching.py
def yield_(self):
    self.current_stream = current_stream()
    self._cpu_yield()
    self.update_stream(self.current_stream)
yield_and_switch_from_comm_to_compute
yield_and_switch_from_comm_to_compute()
Source code in vllm/v1/worker/ubatching.py
def yield_and_switch_from_comm_to_compute(self):
    assert current_stream() == self.comm_stream
    self._signal_comm_done()
    self._cpu_yield()
    assert self.current_stream == self.comm_stream
    self.update_stream(self.compute_stream)
    self._wait_comm_done()
yield_and_switch_from_compute_to_comm
yield_and_switch_from_compute_to_comm()
Source code in vllm/v1/worker/ubatching.py
def yield_and_switch_from_compute_to_comm(self):
    assert current_stream() == self.compute_stream
    self._signal_compute_done()
    self._cpu_yield()
    assert self.current_stream == self.compute_stream
    self.update_stream(self.comm_stream)
    self._wait_compute_done()