
vllm.v1.kv_offload.abstract

OffloadingManager class for managing KV data offloading in vLLM v1

This class runs in the scheduler and tracks which blocks are offloaded and their addresses.

The class provides the following primitives:

lookup() - find the length of the maximal series of blocks, starting from the first one, that are all offloaded.

prepare_load() - prepare the given blocks to be read. The given blocks will be protected from eviction. This function returns a LoadStoreSpec, which encapsulates the information required to perform the load.

touch() - mark the given blocks as recently used. This can be used to track the blocks' LRU order. It is separate from prepare_load() so that block recency can be updated even for blocks that do not need to be read from the offloaded cache, such as blocks already cached by the GPU prefix cache.

complete_load() - mark blocks that were previously prepared for loading as done loading. This re-allows their eviction.

prepare_store() - prepare the given blocks to be written. Returns a PrepareStoreOutput encapsulating the offloading information (a LoadStoreSpec) as well as a list of blocks that were evicted as a result, or None if the blocks cannot be stored.

complete_store() - mark a previous store as completed. Following this call, the given blocks become loadable.
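To make the lifecycle above concrete, here is a minimal in-memory sketch. It is not vLLM's implementation: the class names (ToyLRUOffloadingManager, ToySpec, ToyPrepareStoreOutput) are hypothetical, BlockHash is modeled as plain bytes, and the "medium" is just a fixed pool of integer slots with LRU eviction. In this sketch blocks are pinned between the prepare_* and complete_* calls, and they only become visible to lookup() after a successful complete_store().

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Iterable, Optional

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


@dataclass
class ToySpec:
    """Hypothetical LoadStoreSpec stand-in: one pool slot per block."""
    slots: list[int]


@dataclass
class ToyPrepareStoreOutput:  # same fields as PrepareStoreOutput
    block_hashes_to_store: list[BlockHash]
    store_spec: ToySpec
    block_hashes_evicted: list[BlockHash]


@dataclass
class _Entry:
    slot: int
    ref_cnt: int  # > 0 means the block is protected from eviction
    ready: bool  # True once complete_store succeeded (block is loadable)


class ToyLRUOffloadingManager:
    """Hypothetical manager: a fixed number of slots with LRU eviction."""

    def __init__(self, num_slots: int):
        self.free_slots = list(range(num_slots))
        # Dict order doubles as LRU order (least recently used first).
        self.entries: OrderedDict = OrderedDict()

    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        hits = 0
        for h in block_hashes:
            entry = self.entries.get(h)
            if entry is None or not entry.ready:
                break  # the offloaded prefix ends here
            hits += 1
        return hits

    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> ToySpec:
        slots = []
        for h in block_hashes:
            entry = self.entries[h]  # assumes the block is offloaded
            entry.ref_cnt += 1  # pin until complete_load
            slots.append(entry.slot)
        return ToySpec(slots)

    def touch(self, block_hashes: Iterable[BlockHash]) -> None:
        for h in reversed(list(block_hashes)):  # first block ends up most recent
            if h in self.entries:
                self.entries.move_to_end(h)

    def complete_load(self, block_hashes: Iterable[BlockHash]) -> None:
        for h in block_hashes:
            self.entries[h].ref_cnt -= 1  # evictable again once it reaches 0

    def prepare_store(
        self, block_hashes: Iterable[BlockHash]
    ) -> Optional[ToyPrepareStoreOutput]:
        requested = list(dict.fromkeys(block_hashes))  # dedupe, keep order
        to_store = [h for h in requested if h not in self.entries]
        shortfall = len(to_store) - len(self.free_slots)
        # Eviction candidates in LRU order: loadable, unpinned, not requested.
        candidates = [h for h, e in self.entries.items()
                      if e.ready and e.ref_cnt == 0 and h not in requested]
        if shortfall > len(candidates):
            return None  # cannot free enough slots
        evicted = candidates[:max(shortfall, 0)]
        for h in evicted:
            self.free_slots.append(self.entries.pop(h).slot)
        slots = []
        for h in to_store:
            slot = self.free_slots.pop()
            self.entries[h] = _Entry(slot, ref_cnt=1, ready=False)  # pinned
            slots.append(slot)
        return ToyPrepareStoreOutput(to_store, ToySpec(slots), evicted)

    def complete_store(
        self, block_hashes: Iterable[BlockHash], success: bool = True
    ) -> None:
        for h in block_hashes:
            entry = self.entries.get(h)
            if entry is None or entry.ready:
                continue  # unknown or already stored: nothing to do
            if success:
                entry.ready = True
                entry.ref_cnt -= 1
            else:  # drop the pending block and free its slot
                self.free_slots.append(self.entries.pop(h).slot)
```

With num_slots=2, for instance, storing a and b, touching a, and then storing c evicts b: it is the least recently used block that is neither pinned nor part of the request.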

LoadStoreSpec

Bases: ABC

Abstract metadata that encapsulates information allowing a worker to load, and optionally also to store, blocks of KV data.

Source code in vllm/v1/kv_offload/abstract.py
class LoadStoreSpec(ABC):
    """
    Abstract metadata that encapsulates information allowing a worker
    to load, and optionally also to store, blocks of KV data.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Returns a string representation of the medium type
        this store/load targets.
        """
        pass

medium abstractmethod staticmethod

medium() -> str

Returns a string representation of the medium type this store/load targets.

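A concrete spec only has to implement medium(); any addressing data it carries is up to the implementation. The sketch below redefines LoadStoreSpec locally so it runs without vLLM installed. CPUBlockIdsSpec, its block_ids field, and the "CPU" string are hypothetical choices for illustration, not vLLM's own classes.

```python
from abc import ABC, abstractmethod


class LoadStoreSpec(ABC):
    """Local copy of the abstract interface, so this block runs standalone."""

    @staticmethod
    @abstractmethod
    def medium() -> str:
        pass


class CPUBlockIdsSpec(LoadStoreSpec):
    """Hypothetical spec: block ids inside a CPU-side buffer."""

    def __init__(self, block_ids: list[int]):
        self.block_ids = block_ids  # where each block lives in the medium

    @staticmethod
    def medium() -> str:
        return "CPU"
```

Because medium() is a static method, it can be queried on the class itself as well as on an instance, while LoadStoreSpec alone stays abstract and cannot be instantiated.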

OffloadingEvent dataclass

Source code in vllm/v1/kv_offload/abstract.py
@dataclass
class OffloadingEvent:
    block_hashes: list[BlockHash]
    block_size: int
    medium: str
    # True if blocks are removed, False if stored
    removed: bool

block_hashes instance-attribute

block_hashes: list[BlockHash]

block_size instance-attribute

block_size: int

medium instance-attribute

medium: str

removed instance-attribute

removed: bool

__init__

__init__(
    block_hashes: list[BlockHash],
    block_size: int,
    medium: str,
    removed: bool,
) -> None
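Each OffloadingEvent records whether blocks were stored (removed=False) or removed (removed=True) on a given medium. A consumer can therefore replay a stream of events to reconstruct which blocks each medium holds. A minimal sketch, assuming BlockHash is modeled as bytes; replay() is a hypothetical helper, and the block_size and medium values in the usage are illustrative.

```python
from dataclasses import dataclass

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


@dataclass
class OffloadingEvent:  # local copy of the dataclass documented on this page
    block_hashes: list[BlockHash]
    block_size: int
    medium: str
    removed: bool  # True if blocks are removed, False if stored


def replay(events) -> dict[str, set[BlockHash]]:
    """Hypothetical consumer: rebuild the set of blocks held per medium."""
    state: dict[str, set[BlockHash]] = {}
    for ev in events:
        held = state.setdefault(ev.medium, set())
        if ev.removed:
            held.difference_update(ev.block_hashes)
        else:
            held.update(ev.block_hashes)
    return state
```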

OffloadingManager

Bases: ABC

Source code in vllm/v1/kv_offload/abstract.py
class OffloadingManager(ABC):

    @abstractmethod
    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to look up.

        Returns:
            An integer representing the maximal number of blocks that
            are currently offloaded.
        """
        pass

    @abstractmethod
    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker to locate and load
            the actual offloaded KV data.
        """
        pass

    def touch(self, block_hashes: Iterable[BlockHash]):
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    def complete_load(self, block_hashes: Iterable[BlockHash]):
        """
        Marks blocks that were previously prepared for loading as done loading.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    @abstractmethod
    def prepare_store(
            self,
            block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and a list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """
        pass

    def complete_store(self,
                       block_hashes: Iterable[BlockHash],
                       success: bool = True):
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not marked as stored will be
        removed.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """
        return

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Take the offloading events from the manager.

        Yields:
            New OffloadingEvents collected since the last call.
        """
        return ()

complete_load

complete_load(block_hashes: Iterable[BlockHash])

Marks blocks that were previously prepared for loading as done loading, which re-allows their eviction.

Parameters:

block_hashes (Iterable[BlockHash], required): the hashes identifying the blocks.
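The abstract interface does not say how protection from eviction is tracked between prepare_load() and complete_load(). One plausible approach, sketched here under the assumption that the same block may be prepared for several concurrent loads, is to reference-count pins so a block becomes evictable only after its last load completes. PinTracker and its methods are hypothetical, and BlockHash is modeled as bytes.

```python
from collections import Counter
from typing import Iterable

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


class PinTracker:
    """Hypothetical pin bookkeeping between prepare_load and complete_load."""

    def __init__(self) -> None:
        self.pins: Counter = Counter()

    def on_prepare_load(self, block_hashes: Iterable[BlockHash]) -> None:
        self.pins.update(block_hashes)  # one pin per pending load

    def on_complete_load(self, block_hashes: Iterable[BlockHash]) -> None:
        self.pins.subtract(block_hashes)
        self.pins += Counter()  # drop entries whose count fell to zero

    def evictable(self, block_hash: BlockHash) -> bool:
        return self.pins[block_hash] <= 0  # missing keys count as 0
```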

complete_store

complete_store(
    block_hashes: Iterable[BlockHash], success: bool = True
)

Marks blocks which were previously prepared to be stored as stored. Following this call, the blocks become loadable. If success is False, blocks that were not marked as stored will be removed.

Parameters:

block_hashes (Iterable[BlockHash], required): the hashes identifying the blocks.

success (bool, default True): whether the blocks were stored successfully.
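The success flag distinguishes two outcomes for blocks that are still pending: on success they become loadable, and on failure they are dropped. This sketch reads the docstring as leaving blocks that were already marked as stored untouched. It is minimal bookkeeping under stated assumptions: StoreLedger is hypothetical, tracks only membership (not storage addresses), and BlockHash is modeled as bytes.

```python
from typing import Iterable

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


class StoreLedger:
    """Hypothetical bookkeeping for prepare_store / complete_store."""

    def __init__(self) -> None:
        self.pending: set[BlockHash] = set()
        self.stored: set[BlockHash] = set()

    def prepare_store(self, block_hashes: Iterable[BlockHash]) -> None:
        self.pending.update(h for h in block_hashes if h not in self.stored)

    def complete_store(
        self, block_hashes: Iterable[BlockHash], success: bool = True
    ) -> None:
        for h in block_hashes:
            if h not in self.pending:
                continue  # already stored (or never prepared): unchanged
            self.pending.discard(h)
            if success:
                self.stored.add(h)  # now loadable
            # On failure the block is dropped and never becomes loadable.
```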

lookup abstractmethod

lookup(block_hashes: Iterable[BlockHash]) -> int

Finds the length of the maximal series of blocks, starting from the first one, that are all offloaded.

Parameters:

block_hashes (Iterable[BlockHash], required): the hashes identifying the blocks to look up.

Returns:

int: the maximal number of consecutive blocks, starting from the first one, that are currently offloaded.

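lookup() counts only the leading run of offloaded blocks: a hit that follows a miss does not count, presumably because a block's KV data is only reusable together with the blocks before it. A standalone sketch over a plain set of offloaded hashes; count_offloaded_prefix() and the set representation are hypothetical, and BlockHash is modeled as bytes.

```python
from itertools import takewhile
from typing import Iterable

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


def count_offloaded_prefix(
    offloaded: set[BlockHash], block_hashes: Iterable[BlockHash]
) -> int:
    """Count leading blocks that are offloaded, stopping at the first miss."""
    return sum(1 for _ in takewhile(lambda h: h in offloaded, block_hashes))
```

For example, with a, b, and d offloaded, looking up [a, b, c, d] yields 2: d is offloaded but sits after the gap at c.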

prepare_load abstractmethod

prepare_load(
    block_hashes: Iterable[BlockHash],
) -> LoadStoreSpec

Prepare the given blocks to be read. The given blocks will be protected from eviction until complete_load is called. It assumes all given blocks are offloaded.

Parameters:

block_hashes (Iterable[BlockHash], required): the hashes identifying the blocks.

Returns:

LoadStoreSpec: a LoadStoreSpec that can be used by a worker to locate and load the actual offloaded KV data.

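A sketch of the two jobs prepare_load() has: translating hashes into medium-specific addresses for the worker, and pinning the blocks. It assumes a hypothetical address map and pin-count dict, and uses a hypothetical BlockIdsSpec in place of a real LoadStoreSpec subclass; BlockHash is modeled as bytes. Since the interface assumes every block is offloaded, the sketch validates first and raises before pinning anything, so a bad request leaves no stray pins.

```python
from dataclasses import dataclass
from typing import Iterable

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


@dataclass
class BlockIdsSpec:
    """Hypothetical spec: the block id of each requested block in the medium."""
    block_ids: list[int]


def build_load_spec(
    addresses: dict[BlockHash, int],  # offloaded block -> block id
    pins: dict[BlockHash, int],  # block -> number of in-flight loads
    block_hashes: Iterable[BlockHash],
) -> BlockIdsSpec:
    hashes = list(block_hashes)
    missing = [h for h in hashes if h not in addresses]
    if missing:
        # prepare_load assumes every block is offloaded; fail before pinning.
        raise KeyError(f"blocks not offloaded: {missing!r}")
    for h in hashes:
        pins[h] = pins.get(h, 0) + 1  # protected until complete_load
    return BlockIdsSpec([addresses[h] for h in hashes])
```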

prepare_store abstractmethod

prepare_store(
    block_hashes: Iterable[BlockHash],
) -> Optional[PrepareStoreOutput]

Prepare the given blocks to be offloaded. The given blocks will be protected from eviction until complete_store is called.

Parameters:

block_hashes (Iterable[BlockHash], required): the hashes identifying the blocks.

Returns:

Optional[PrepareStoreOutput]: a PrepareStoreOutput indicating which blocks need storing, where to store them (LoadStoreSpec), and a list of blocks that were evicted as a result. None is returned if the blocks cannot be stored.

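The core of prepare_store() is deciding which blocks actually need storing and what to evict to make room. Below is a hypothetical planner, assuming a fixed number of free slots, an LRU list of stored blocks, and a set of pinned blocks (BlockHash modeled as bytes). Blocks already stored are skipped. Pinned blocks and blocks in the current request are never evicted. None signals that the request cannot fit, matching the Optional return type.

```python
from typing import Iterable, Optional

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


def plan_store(
    block_hashes: Iterable[BlockHash],
    lru_order: list[BlockHash],  # stored blocks, least recently used first
    pinned: set[BlockHash],  # blocks with a load or store in flight
    num_free_slots: int,
) -> Optional[tuple[list[BlockHash], list[BlockHash]]]:
    """Hypothetical planner: returns (to_store, evicted), or None if it cannot fit."""
    requested = list(dict.fromkeys(block_hashes))  # dedupe, keep order
    requested_set = set(requested)
    stored = set(lru_order)
    to_store = [h for h in requested if h not in stored]
    shortfall = len(to_store) - num_free_slots
    if shortfall <= 0:
        return to_store, []  # enough free space, nothing to evict
    # Never evict pinned blocks or blocks the caller is asking about.
    candidates = [
        h for h in lru_order if h not in pinned and h not in requested_set
    ]
    if len(candidates) < shortfall:
        return None
    return to_store, candidates[:shortfall]
```

The two returned lists correspond to the block_hashes_to_store and block_hashes_evicted fields of PrepareStoreOutput. A real manager would also allocate the store_spec.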

take_events

take_events() -> Iterable[OffloadingEvent]

Take the offloading events from the manager.

Yields:

Iterable[OffloadingEvent]: new OffloadingEvents collected since the last call.

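One simple way to back take_events() is a buffer that is swapped out on each call, so every event is handed over exactly once. A sketch with hypothetical names (EventBuffer, record()); OffloadingEvent is redefined locally so the block runs standalone, and BlockHash is modeled as bytes.

```python
from dataclasses import dataclass
from typing import Iterable

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


@dataclass
class OffloadingEvent:  # local copy of the dataclass documented on this page
    block_hashes: list[BlockHash]
    block_size: int
    medium: str
    removed: bool  # True if blocks are removed, False if stored


class EventBuffer:
    """Hypothetical helper that a concrete manager could use for take_events()."""

    def __init__(self) -> None:
        self._events: list[OffloadingEvent] = []

    def record(self, block_hashes: Iterable[BlockHash], block_size: int,
               medium: str, removed: bool) -> None:
        self._events.append(
            OffloadingEvent(list(block_hashes), block_size, medium, removed))

    def take_events(self) -> Iterable[OffloadingEvent]:
        # Swap in a fresh list so each event is handed over exactly once.
        events, self._events = self._events, []
        return events
```

Although the docstring says "Yields", this sketch deliberately returns a list instead of writing a generator: a generator would not drain the buffer until the caller started iterating.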

touch

touch(block_hashes: Iterable[BlockHash])

Mark the given blocks as recently used. This could in practice mean moving them to the end of an LRU list.

Parameters:

block_hashes (Iterable[BlockHash], required): the hashes identifying the blocks.
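An LRU index built on collections.OrderedDict turns touch() into a sequence of O(1) move_to_end() calls. In this sketch (LRUIndex is hypothetical; BlockHash is modeled as bytes), the hashes are touched in reverse, so the first block of a sequence ends up most recently used. That ordering is a design choice, not something the interface mandates. Because lookup() only counts a leading run of blocks, evicting tail blocks before the head keeps the longest usable prefix.

```python
from collections import OrderedDict
from typing import Iterable

BlockHash = bytes  # hypothetical stand-in for vLLM's BlockHash type


class LRUIndex:
    """Hypothetical recency index; the first key is the next eviction victim."""

    def __init__(self) -> None:
        self._order: OrderedDict = OrderedDict()

    def add(self, block_hash: BlockHash) -> None:
        self._order[block_hash] = None
        self._order.move_to_end(block_hash)  # newly stored = most recent

    def touch(self, block_hashes: Iterable[BlockHash]) -> None:
        # Walk backwards so the first block of the sequence ends up most recent.
        for h in reversed(list(block_hashes)):
            if h in self._order:  # unknown blocks are ignored
                self._order.move_to_end(h)

    def eviction_order(self) -> list[BlockHash]:
        return list(self._order)
```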

PrepareStoreOutput dataclass

Source code in vllm/v1/kv_offload/abstract.py
@dataclass
class PrepareStoreOutput:
    block_hashes_to_store: list[BlockHash]
    store_spec: LoadStoreSpec
    block_hashes_evicted: list[BlockHash]

block_hashes_evicted instance-attribute

block_hashes_evicted: list[BlockHash]

block_hashes_to_store instance-attribute

block_hashes_to_store: list[BlockHash]

store_spec instance-attribute

store_spec: LoadStoreSpec

__init__

__init__(
    block_hashes_to_store: list[BlockHash],
    store_spec: LoadStoreSpec,
    block_hashes_evicted: list[BlockHash],
) -> None