Skip to content

vllm.v1.kv_offload.worker.cpu_gpu

logger module-attribute

logger = init_logger(__name__)

CpuGpuOffloadingHandler

Bases: OffloadingHandler

Source code in vllm/v1/kv_offload/worker/cpu_gpu.py
class CpuGpuOffloadingHandler(OffloadingHandler):
    """
    Offloading handler that moves KV cache blocks between the per-layer GPU
    tensors it is given and CPU tensors it allocates itself.

    Transfers are issued asynchronously on a dedicated CUDA stream per
    direction (GPU->CPU and CPU->GPU). Each submitted job is tracked by a
    CUDA event which is polled in get_finished().

    A CPU block may span several GPU blocks: block_size_factor is the number
    of GPU-sized sub-blocks per CPU block.
    """

    def __init__(self, gpu_block_size: int, cpu_block_size: int,
                 num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor],
                 attn_backends: dict[str, type[AttentionBackend]]):
        """
        Set up the streams, the event bookkeeping and one CPU tensor per
        GPU cache tensor.

        Args:
            gpu_block_size: block size of the GPU cache.
            cpu_block_size: block size of the CPU cache; must be a multiple
                of gpu_block_size.
            num_cpu_blocks: number of CPU blocks to allocate per tensor.
            gpu_caches: layer name -> GPU KV cache tensor.
            attn_backends: layer name -> attention backend class, queried
                through get_kv_cache_shape() to locate the num_blocks
                dimension of that layer's cache tensor.
        """
        assert cpu_block_size % gpu_block_size == 0
        self.block_size_factor = cpu_block_size // gpu_block_size

        # cuda streams for gpu->cpu and cpu->gpu
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()

        # job_id -> transfer cuda event
        self.transfer_events: dict[int, torch.cuda.Event] = {}
        # list of cuda events available for re-use
        self.events_pool: list[torch.cuda.Event] = []

        pin_memory = is_pin_memory_available()

        # allocate cpu tensors
        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
        # the three lists below are index-aligned: entry i of each one
        # describes the same layer
        self.gpu_tensors: list[torch.Tensor] = []
        self.cpu_tensors: list[torch.Tensor] = []
        self.kv_dim_before_num_blocks: list[bool] = []
        for layer_name, gpu_tensor in gpu_caches.items():
            self.gpu_tensors.append(gpu_tensor)

            gpu_shape = gpu_tensor.shape
            # probe the backend with a recognizable num_blocks value (1234);
            # the returned shape is only inspected to find which dimension
            # holds num_blocks, the other arguments are dummy values
            test_shape = attn_backends[layer_name].get_kv_cache_shape(
                num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256)
            if test_shape[0] == 1234:
                # shape is (num_blocks, ...)
                num_blocks_idx = 0
                self.kv_dim_before_num_blocks.append(False)
            else:
                # shape should be (2, num_blocks, ...)
                assert test_shape[0] == 2
                assert test_shape[1] == 1234
                assert gpu_shape[0] == 2

                num_blocks_idx = 1
                self.kv_dim_before_num_blocks.append(True)

            # the CPU tensor mirrors the GPU tensor shape, except that the
            # num_blocks dimension is sized in GPU-sized sub-blocks
            cpu_shape = list(gpu_shape)
            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
            self.cpu_tensors.append(
                torch.zeros(cpu_shape,
                            dtype=gpu_tensor.dtype,
                            device="cpu",
                            pin_memory=pin_memory))

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Issue an asynchronous block transfer and register it under job_id.

        The direction is taken from the spec types: a CPULoadStoreSpec
        source means CPU->GPU, otherwise the transfer is GPU->CPU.

        Args:
            job_id: identifier later reported by get_finished().
            spec: (source spec, destination spec) pair; both expose a
                1-dimensional block_ids array.

        Returns:
            Always True.
        """
        src_spec, dst_spec = spec
        if isinstance(src_spec, CPULoadStoreSpec):
            assert isinstance(dst_spec, GPULoadStoreSpec)
            stream = self.h2d_stream
            src_tensors = self.cpu_tensors
            dst_tensors = self.gpu_tensors
            src_block_size_factor = self.block_size_factor
            dst_block_size_factor = 1
        else:
            assert isinstance(src_spec, GPULoadStoreSpec)
            assert isinstance(dst_spec, CPULoadStoreSpec)
            stream = self.d2h_stream
            src_tensors = self.gpu_tensors
            dst_tensors = self.cpu_tensors
            src_block_size_factor = 1
            dst_block_size_factor = self.block_size_factor

        src_blocks = src_spec.block_ids
        dst_blocks = dst_spec.block_ids
        assert src_blocks.ndim == 1
        assert dst_blocks.ndim == 1

        # (-n % f) is the amount needed to round n up to a multiple of f.
        # When the number of source blocks does not fill a whole number of
        # destination blocks, that many sub-blocks are skipped at the start
        # of the first destination block (see skip_count below).
        dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor)
        src_sub_block_count = src_blocks.size * src_block_size_factor

        # both sides must describe the same number of sub-blocks
        assert (
            src_sub_block_count == dst_blocks.size * dst_block_size_factor -
            dst_sub_blocks_to_skip)

        # build the (src sub-block, dst sub-block) mapping:
        # column 0 holds source indices, column 1 destination indices
        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
        expand_block_ids(dst_blocks,
                         dst_block_size_factor,
                         src_to_dst[:, 1],
                         skip_count=dst_sub_blocks_to_skip)
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        # re-use a pooled event when available instead of creating a new one
        event = self.events_pool.pop() if self.events_pool \
            else torch.cuda.Event()
        with torch.cuda.stream(stream):
            # NOTE(review): ops.swap_blocks is assumed to copy, for each
            # (src, dst) row of the mapping, block src of the first tensor
            # into block dst of the second - confirm against the op.
            for src_tensor, dst_tensor, kv_dim in zip(
                    src_tensors, dst_tensors, self.kv_dim_before_num_blocks):
                if kv_dim:
                    # layout (2, num_blocks, ...): index 0 and index 1 of
                    # the leading dimension are transferred separately
                    src_key_cache = src_tensor[0]
                    dst_key_cache = dst_tensor[0]
                    ops.swap_blocks(src_key_cache, dst_key_cache,
                                    src_to_dst_tensor)
                    src_value_cache = src_tensor[1]
                    dst_value_cache = dst_tensor[1]
                    ops.swap_blocks(src_value_cache, dst_value_cache,
                                    src_to_dst_tensor)
                else:
                    ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
            # recorded on the transfer stream after all the calls above were
            # issued; get_finished() polls this event
            event.record(stream)

        self.transfer_events[job_id] = event

        # success
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Collect the transfers whose CUDA event reports completion.

        Finished jobs are removed from transfer_events and their events are
        returned to events_pool for re-use.

        Returns:
            A list of (job_id, True) tuples, one per finished transfer.
        """
        results: list[TransferResult] = []
        for job_id, event in self.transfer_events.items():
            if event.query():
                results.append((job_id, True))
                self.events_pool.append(event)
        # delete in a second pass: the dict must not change size while it is
        # being iterated above
        for job_id, _ in results:
            del self.transfer_events[job_id]
        return results

block_size_factor instance-attribute

block_size_factor = cpu_block_size // gpu_block_size

cpu_tensors instance-attribute

cpu_tensors: list[Tensor] = []

d2h_stream instance-attribute

d2h_stream = Stream()

events_pool instance-attribute

events_pool: list[Event] = []

gpu_tensors instance-attribute

gpu_tensors: list[Tensor] = []

h2d_stream instance-attribute

h2d_stream = Stream()

kv_dim_before_num_blocks instance-attribute

kv_dim_before_num_blocks: list[bool] = []

transfer_events instance-attribute

transfer_events: dict[int, Event] = {}

__init__

__init__(
    gpu_block_size: int,
    cpu_block_size: int,
    num_cpu_blocks: int,
    gpu_caches: dict[str, Tensor],
    attn_backends: dict[str, type[AttentionBackend]],
)
Source code in vllm/v1/kv_offload/worker/cpu_gpu.py
def __init__(self, gpu_block_size: int, cpu_block_size: int,
             num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor],
             attn_backends: dict[str, type[AttentionBackend]]):
    """
    Create the transfer streams, the event bookkeeping and one CPU tensor
    per GPU cache tensor.

    Args:
        gpu_block_size: block size of the GPU cache.
        cpu_block_size: block size of the CPU cache; must be a multiple of
            gpu_block_size.
        num_cpu_blocks: number of CPU blocks to allocate per tensor.
        gpu_caches: layer name -> GPU KV cache tensor.
        attn_backends: layer name -> attention backend class, used to find
            which dimension of the cache tensor holds num_blocks.
    """
    assert cpu_block_size % gpu_block_size == 0
    # number of GPU-sized sub-blocks that make up one CPU block
    self.block_size_factor = cpu_block_size // gpu_block_size

    # one CUDA stream per direction: device->host and host->device
    self.d2h_stream = torch.cuda.Stream()
    self.h2d_stream = torch.cuda.Stream()

    # pending transfers (job_id -> event) and events free for re-use
    self.transfer_events: dict[int, torch.cuda.Event] = {}
    self.events_pool: list[torch.cuda.Event] = []

    pin_memory = is_pin_memory_available()

    logger.info("Allocating %d CPU tensors...", len(gpu_caches))
    # index-aligned per-layer lists
    self.gpu_tensors: list[torch.Tensor] = []
    self.cpu_tensors: list[torch.Tensor] = []
    self.kv_dim_before_num_blocks: list[bool] = []

    # recognizable value used to spot the num_blocks dimension in the
    # shape reported by the attention backend
    probe_num_blocks = 1234
    # length of the num_blocks dimension of every CPU tensor
    cpu_sub_block_count = num_cpu_blocks * self.block_size_factor

    for layer_name, gpu_tensor in gpu_caches.items():
        self.gpu_tensors.append(gpu_tensor)

        layer_shape = gpu_tensor.shape
        probe_shape = attn_backends[layer_name].get_kv_cache_shape(
            num_blocks=probe_num_blocks,
            block_size=16,
            num_kv_heads=8,
            head_size=256)

        # either (num_blocks, ...) or (2, num_blocks, ...)
        kv_dim_first = probe_shape[0] != probe_num_blocks
        if kv_dim_first:
            assert probe_shape[0] == 2
            assert probe_shape[1] == probe_num_blocks
            assert layer_shape[0] == 2
        self.kv_dim_before_num_blocks.append(kv_dim_first)
        num_blocks_axis = 1 if kv_dim_first else 0

        # same shape as the GPU tensor, except for the num_blocks dimension
        cpu_shape = list(layer_shape)
        cpu_shape[num_blocks_axis] = cpu_sub_block_count

        logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
        cpu_tensor = torch.zeros(cpu_shape,
                                 dtype=gpu_tensor.dtype,
                                 device="cpu",
                                 pin_memory=pin_memory)
        self.cpu_tensors.append(cpu_tensor)

get_finished

get_finished() -> list[TransferResult]
Source code in vllm/v1/kv_offload/worker/cpu_gpu.py
def get_finished(self) -> list[TransferResult]:
    """
    Report the transfers whose CUDA event has completed.

    Every finished job is dropped from transfer_events and its event is
    put back into events_pool for re-use.

    Returns:
        A list of (job_id, True) tuples, one per finished transfer, in the
        order the jobs are stored in transfer_events.
    """
    finished_job_ids: list[int] = []
    for job_id, event in self.transfer_events.items():
        if not event.query():
            # still in flight
            continue
        finished_job_ids.append(job_id)
        self.events_pool.append(event)

    # removal is deferred so the dict is not resized while being iterated
    for job_id in finished_job_ids:
        del self.transfer_events[job_id]

    return [(job_id, True) for job_id in finished_job_ids]

transfer_async

transfer_async(job_id: int, spec: TransferSpec) -> bool
Source code in vllm/v1/kv_offload/worker/cpu_gpu.py
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
    """
    Issue an asynchronous block transfer and register it under job_id.

    The direction is taken from the spec types: a CPULoadStoreSpec source
    means CPU->GPU, otherwise the transfer is GPU->CPU.

    Args:
        job_id: identifier under which the transfer's CUDA event is stored
            in transfer_events.
        spec: (source spec, destination spec) pair; both expose a
            1-dimensional block_ids array.

    Returns:
        Always True.
    """
    src_spec, dst_spec = spec
    if isinstance(src_spec, CPULoadStoreSpec):
        assert isinstance(dst_spec, GPULoadStoreSpec)
        stream = self.h2d_stream
        src_tensors = self.cpu_tensors
        dst_tensors = self.gpu_tensors
        src_block_size_factor = self.block_size_factor
        dst_block_size_factor = 1
    else:
        assert isinstance(src_spec, GPULoadStoreSpec)
        assert isinstance(dst_spec, CPULoadStoreSpec)
        stream = self.d2h_stream
        src_tensors = self.gpu_tensors
        dst_tensors = self.cpu_tensors
        src_block_size_factor = 1
        dst_block_size_factor = self.block_size_factor

    src_blocks = src_spec.block_ids
    dst_blocks = dst_spec.block_ids
    assert src_blocks.ndim == 1
    assert dst_blocks.ndim == 1

    # (-n % f) is the amount needed to round n up to a multiple of f.
    # When the number of source blocks does not fill a whole number of
    # destination blocks, that many sub-blocks are skipped at the start of
    # the first destination block (see skip_count below).
    dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor)
    src_sub_block_count = src_blocks.size * src_block_size_factor

    # both sides must describe the same number of sub-blocks
    assert (
        src_sub_block_count == dst_blocks.size * dst_block_size_factor -
        dst_sub_blocks_to_skip)

    # build the (src sub-block, dst sub-block) mapping:
    # column 0 holds source indices, column 1 destination indices
    src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
    expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
    expand_block_ids(dst_blocks,
                     dst_block_size_factor,
                     src_to_dst[:, 1],
                     skip_count=dst_sub_blocks_to_skip)
    src_to_dst_tensor = torch.from_numpy(src_to_dst)

    # re-use a pooled event when available instead of creating a new one
    event = self.events_pool.pop() if self.events_pool \
        else torch.cuda.Event()
    with torch.cuda.stream(stream):
        # NOTE(review): ops.swap_blocks is assumed to copy, for each
        # (src, dst) row of the mapping, block src of the first tensor into
        # block dst of the second - confirm against the op.
        for src_tensor, dst_tensor, kv_dim in zip(
                src_tensors, dst_tensors, self.kv_dim_before_num_blocks):
            if kv_dim:
                # layout (2, num_blocks, ...): index 0 and index 1 of the
                # leading dimension are transferred separately
                src_key_cache = src_tensor[0]
                dst_key_cache = dst_tensor[0]
                ops.swap_blocks(src_key_cache, dst_key_cache,
                                src_to_dst_tensor)
                src_value_cache = src_tensor[1]
                dst_value_cache = dst_tensor[1]
                ops.swap_blocks(src_value_cache, dst_value_cache,
                                src_to_dst_tensor)
            else:
                ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
        # recorded on the transfer stream after all the calls above were
        # issued; the event is later polled with event.query()
        event.record(stream)

    self.transfer_events[job_id] = event

    # success
    return True

expand_block_ids

expand_block_ids(
    block_ids: ndarray,
    block_size_factor: int,
    output: ndarray,
    skip_count: int = 0,
)

Expand a list of block IDs into the IDs of their constituent sub-blocks, assuming each block is made up of block_size_factor consecutive sub-blocks. The result is written into the output array. The first skip_count sub-blocks of the first block are skipped. Note that skip_count must be less than block_size_factor.

For example, if block_ids = [0, 1, 3] and block_size_factor = 4, then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15], since 0 maps to [0, 1, 2, 3], 1 maps to [4, 5, 6, 7], and 3 maps to [12, 13, 14, 15].

Source code in vllm/v1/kv_offload/worker/cpu_gpu.py
def expand_block_ids(block_ids: np.ndarray,
                     block_size_factor: int,
                     output: np.ndarray,
                     skip_count: int = 0):
    """
    Expand block IDs into the IDs of their constituent sub-blocks,
    assuming each block is composed of block_size_factor consecutive
    sub-blocks. The result is written in place into the leading entries
    of the output array; nothing is returned.

    The first skip_count sub-blocks of the first block are skipped.
    Note that skip_count must be non-negative and less than
    block_size_factor.

    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
    then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
    since 0 maps to [0, 1, 2, 3],
    1 maps to [4, 5, 6, 7],
    and 3 maps to [12, 13, 14, 15].
    With skip_count = 2 the same input yields
    [2, 3, 4, 5, 6, 7, 12, 13, 14, 15].

    Args:
        block_ids: 1-dimensional array of block IDs.
        block_size_factor: number of sub-blocks per block.
        output: array receiving the sub-block IDs; it must hold at least
            len(block_ids) * block_size_factor - skip_count entries.
        skip_count: number of leading sub-blocks of the first block to
            leave out.
    """
    # a negative skip_count would emit sub-block IDs below the block's base
    assert 0 <= skip_count < block_size_factor

    block_ids = np.asarray(block_ids)
    if block_ids.size == 0:
        # nothing to expand; output is left untouched
        return

    # one row per block: base sub-block ID + offsets 0..block_size_factor-1
    sub_block_offsets = np.arange(block_size_factor)
    expanded = (block_ids.reshape(-1, 1) * block_size_factor +
                sub_block_offsets).reshape(-1)

    # dropping the first skip_count entries of the flattened result removes
    # exactly the leading sub-blocks of the first block
    expanded = expanded[skip_count:]
    output[:expanded.size] = expanded