vllm.v1.kv_cache_interface

logger module-attribute

logger = init_logger(__name__)

AttentionSpec dataclass

Bases: KVCacheSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        return 2 * self.block_size * self.num_kv_heads * self.head_size \
                * get_dtype_size(self.dtype)
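
The formula follows directly from the cache layout: a page holds block_size tokens, and each token stores one key and one value vector (the factor of 2) of num_kv_heads * head_size elements at get_dtype_size(dtype) bytes per element. A minimal sketch with hypothetical numbers, not tied to any particular model:

# Hypothetical values, purely for illustration.
block_size = 16     # tokens per KV cache block
num_kv_heads = 8
head_size = 128
dtype_size = 2      # bytes per element for torch.float16

page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size
print(page_size_bytes)  # 65536 bytes (64 KiB) of K and V per block, per layer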

dtype instance-attribute

dtype: dtype

head_size instance-attribute

head_size: int

num_kv_heads instance-attribute

num_kv_heads: int

page_size_bytes property

page_size_bytes: int

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
) -> None

ChunkedLocalAttentionSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
                         max_model_len)

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

attention_chunk_size instance-attribute

attention_chunk_size: int

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    attention_chunk_size: int,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)

    # During chunked prefill, we allocate KV cache for at most
    # `self.attention_chunk_size` computed tokens plus the newly scheduled
    # tokens. And we won't allocate KV cache for more than `max_model_len`
    # tokens.
    num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
                     max_model_len)

    return cdiv(num_tokens, self.block_size) * self.page_size_bytes
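
A worked example of this bound with hypothetical values (the page size is taken from the AttentionSpec sketch above):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used in the source.
    return -(-a // b)

# Hypothetical values, purely for illustration.
attention_chunk_size = 8192
max_num_batched_tokens = 2048
max_model_len = 131072
block_size = 16
page_size_bytes = 65536

num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)
print(num_tokens)                                      # 10240
print(cdiv(num_tokens, block_size) * page_size_bytes)  # 640 blocks -> 41943040 bytes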

CrossAttentionSpec dataclass

Bases: AttentionSpec

KV cache spec for cross-attention layers in encoder-decoder models.

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # For cross-attention, we need to cache encoder states
        # Get encoder length (e.g., 1500 for Whisper).
        max_encoder_len = vllm_config.scheduler_config.\
            max_num_encoder_input_tokens
        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    # For cross-attention, we need to cache encoder states
    # Get encoder length (e.g., 1500 for Whisper).
    max_encoder_len = vllm_config.scheduler_config.\
        max_num_encoder_input_tokens
    return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
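
Because cross-attention keys and values are computed from the encoder output, the bound depends only on the maximum encoder input length, not on max_model_len. A short sketch, reusing the hypothetical page size from above and the Whisper-like encoder length mentioned in the comment:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

max_encoder_len = 1500   # e.g. Whisper, per the comment above
block_size = 16
page_size_bytes = 65536  # hypothetical per-layer page size

print(cdiv(max_encoder_len, block_size) * page_size_bytes)  # 94 blocks -> 6160384 bytes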

EncoderOnlyAttentionSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Encoder-only layers do not need KV cache
        return 0

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    # Encoder-only layers do not need KV cache
    return 0

FullAttentionSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    sliding_window: Optional[int] = None
    attention_chunk_size: Optional[int] = None
    """
    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, the sliding window
    attention layers are regarded as full attention by the KV cache manager
    (blocks are allocated for all tokens), but are still computed as sliding
    window attention in the model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    Defaults to None when sliding window attention is not used.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = \
            vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): each dcp rank only needs to save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        if len(window_sizes) == 0:
            return None
        elif len(window_sizes) == 1:
            return window_sizes.pop()
        else:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single 
        FullAttentionSpec object.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")

        sliding_window = set(spec.sliding_window for spec in specs
                             if spec.sliding_window is not None)
        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                                   if spec.attention_chunk_size is not None)
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec.")
        assert (
            (merged_spec.sliding_window is not None) +
            (merged_spec.attention_chunk_size is not None) <= 1
        ), ("Model with both sliding window layers and chunked local attention "
            "layers is not supported.")
        return merged_spec

attention_chunk_size class-attribute instance-attribute

attention_chunk_size: Optional[int] = None

When the hybrid allocator is disabled and the model contains both full attention layers and sliding window attention layers, the sliding window attention layers are regarded as full attention by the KV cache manager (blocks are allocated for all tokens), but are still computed as sliding window attention in the model runner. In this case, we use FullAttentionSpec and record the sliding window size. Defaults to None when sliding window attention is not used.

sliding_window class-attribute instance-attribute

sliding_window: Optional[int] = None

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    sliding_window: Optional[int] = None,
    attention_chunk_size: Optional[int] = None,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    dcp_world_size = \
        vllm_config.parallel_config.decode_context_parallel_size
    # Note(hc): each dcp rank only needs to save
    # (max_model_len//dcp_world_size) tokens locally.
    if dcp_world_size > 1:
        max_model_len = cdiv(max_model_len, dcp_world_size)
    return cdiv(max_model_len, self.block_size) * self.page_size_bytes
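
With decode context parallelism (DCP), each rank stores only its slice of the sequence, so the per-rank bound shrinks by roughly the DCP world size. A sketch with hypothetical numbers:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

# Hypothetical values, purely for illustration.
max_model_len = 131072
dcp_world_size = 4
block_size = 16
page_size_bytes = 65536

if dcp_world_size > 1:
    max_model_len = cdiv(max_model_len, dcp_world_size)   # 32768 tokens per rank
print(cdiv(max_model_len, block_size) * page_size_bytes)   # 2048 blocks -> 134217728 bytes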

merge classmethod

merge(specs: list[Self]) -> Self

Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object.

Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of FullAttentionSpec objects into a single 
    FullAttentionSpec object.
    """
    assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
        "All attention layers in the same KV cache group must be "
        "FullAttentionSpec.")

    sliding_window = set(spec.sliding_window for spec in specs
                         if spec.sliding_window is not None)
    attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                               if spec.attention_chunk_size is not None)
    assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
        "MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
    merged_spec = cls(
        block_size=specs[0].block_size,
        num_kv_heads=specs[0].num_kv_heads,
        head_size=specs[0].head_size,
        dtype=specs[0].dtype,
        sliding_window=cls.merge_window_sizes(sliding_window),
        attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
    )
    for spec in specs:
        for f in fields(AttentionSpec):
            assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                "All attention layers in the same KV cache group must have "
                "the same attention spec.")
    assert (
        (merged_spec.sliding_window is not None) +
        (merged_spec.attention_chunk_size is not None) <= 1
    ), ("Model with both sliding window layers and chunked local attention "
        "layers is not supported.")
    return merged_spec
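
A minimal usage sketch, assuming torch and vLLM are importable; the layer dimensions are hypothetical. One spec records a sliding window (the hybrid-allocator-disabled case described in the class docstring), the other does not, and merge keeps the single recorded window size:

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec

# Hypothetical layer dimensions.
common = dict(block_size=16, num_kv_heads=8, head_size=128, dtype=torch.float16)
spec_a = FullAttentionSpec(**common)
spec_b = FullAttentionSpec(**common, sliding_window=4096)

merged = FullAttentionSpec.merge([spec_a, spec_b])
print(merged.sliding_window)  # 4096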

merge_window_sizes classmethod

merge_window_sizes(window_sizes: set[int]) -> Optional[int]
Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
    if len(window_sizes) == 0:
        return None
    elif len(window_sizes) == 1:
        return window_sizes.pop()
    else:
        raise ValueError(
            "All attention layers in the same KV cache group must have the "
            "same window size.")

KVCacheConfig dataclass

The KV cache configuration of a model.

Source code in vllm/v1/kv_cache_interface.py
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    num_blocks: int
    """The number of KV cache blocks"""
    kv_cache_tensors: list[KVCacheTensor]
    """How the model runner should initialize the KV cache tensors for each layer"""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The kv cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """

kv_cache_groups instance-attribute

kv_cache_groups: list[KVCacheGroupSpec]

The kv cache groups of the model. For models with only one type of attention, there is only one group that contains all layers. For models with multiple types of attention, there will be multiple groups, see _get_kv_cache_config_uniform_page_size for more details.

kv_cache_tensors instance-attribute

kv_cache_tensors: list[KVCacheTensor]

How the model runner should initialize the KV cache tensors for each layer.

num_blocks instance-attribute

num_blocks: int

The number of KV cache blocks.

__init__

__init__(
    num_blocks: int,
    kv_cache_tensors: list[KVCacheTensor],
    kv_cache_groups: list[KVCacheGroupSpec],
) -> None

KVCacheGroupSpec dataclass

Represents a group of model layers that share the same KV cache block table. These layers are regarded as one layer in the KV cache manager.

Source code in vllm/v1/kv_cache_interface.py
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """
    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec

kv_cache_spec instance-attribute

kv_cache_spec: KVCacheSpec

layer_names instance-attribute

layer_names: list[str]

__init__

__init__(
    layer_names: list[str], kv_cache_spec: KVCacheSpec
) -> None

KVCacheSpec dataclass

A base class for specifying the KV cache format of one layer.

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
        """
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same.")
        return copy.deepcopy(specs[0])
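
A subclass only needs to provide page_size_bytes and max_memory_usage_bytes; merge already covers the common case of identical specs. The sketch below is a hypothetical spec (not part of vLLM) written only to illustrate the contract:

from dataclasses import dataclass

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec


@dataclass(frozen=True)
class FixedCostSpec(KVCacheSpec):
    """Hypothetical spec with a fixed per-token byte cost (illustration only)."""
    bytes_per_token: int

    @property
    def page_size_bytes(self) -> int:
        # A page covers `block_size` tokens at a fixed per-token cost.
        return self.block_size * self.bytes_per_token

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Conservative bound: one page per block of the longest possible sequence.
        max_model_len = vllm_config.model_config.max_model_len
        return -(-max_model_len // self.block_size) * self.page_size_bytes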

block_size instance-attribute

block_size: int

page_size_bytes property

page_size_bytes: int

The size of a page with block_size tokens in bytes.

Returns:

Type Description
int

The page size

__init__

__init__(block_size: int) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int

The maximum possible memory usage of this KV cache in bytes.

Returns:

Type Description
int

The KV cache size in bytes

Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    """
    The maximum possible memory usage of this KV cache in bytes.

    Returns:
        The KV cache size in bytes
    """
    raise NotImplementedError

merge classmethod

merge(specs: list[Self]) -> Self

Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
    """
    assert all(spec == specs[0] for spec in specs[1:]), (
        "All layers in the same KV cache group must be the same.")
    return copy.deepcopy(specs[0])

KVCacheTensor dataclass

A class for specifying how the workers should initialize the KV cache.

Source code in vllm/v1/kv_cache_interface.py
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """
    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor

shared_by instance-attribute

shared_by: list[str]

size instance-attribute

size: int

__init__

__init__(size: int, shared_by: list[str]) -> None
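
Putting the three dataclasses together, a hand-built single-group configuration might look like the sketch below. Layer names and sizes are hypothetical; in practice vLLM computes these configs from the model rather than constructing them by hand:

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor)

layer_names = ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]
spec = FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.float16)

config = KVCacheConfig(
    num_blocks=1024,
    # One backing tensor per layer, each large enough for num_blocks pages.
    kv_cache_tensors=[
        KVCacheTensor(size=1024 * spec.page_size_bytes, shared_by=[name])
        for name in layer_names
    ],
    # A single group: both layers share one block table and one spec.
    kv_cache_groups=[KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)],
)
print(config.num_blocks, len(config.kv_cache_groups))  # 1024 1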

MLAAttentionSpec dataclass

Bases: FullAttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: Optional[str] = None

    @property
    def page_size_bytes(self) -> int:
        if self.cache_dtype_str == "fp8_ds_mla":
            # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
            #  for details.
            return self.block_size * 656
        return self.block_size * self.num_kv_heads * self.head_size \
                * get_dtype_size(self.dtype)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "MLAAttentionSpec.")
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method.")
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
        )
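
Note the missing factor of 2 relative to AttentionSpec: MLA caches a single compressed latent per token rather than separate keys and values, and the fp8_ds_mla cache format uses a fixed 656 bytes per token. A small arithmetic sketch with illustrative, DeepSeek-style dimensions (one latent "head" of width 576):

# Illustrative numbers only.
block_size = 64
num_kv_heads = 1     # MLA stores one latent vector per token
head_size = 576
dtype_size = 2       # bf16

default_page = block_size * num_kv_heads * head_size * dtype_size  # 73728 bytes
fp8_ds_mla_page = block_size * 656                                 # 41984 bytes
print(default_page, fp8_ds_mla_page)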

cache_dtype_str class-attribute instance-attribute

cache_dtype_str: Optional[str] = None

page_size_bytes property

page_size_bytes: int

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    sliding_window: Optional[int] = None,
    attention_chunk_size: Optional[int] = None,
    cache_dtype_str: Optional[str] = None,
) -> None

merge classmethod

merge(specs: list[Self]) -> Self
Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
        "All attention layers in the same KV cache group must be "
        "MLAAttentionSpec.")
    cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
    assert len(cache_dtype_str_set) == 1, (
        "All attention layers in the same KV cache group must use the same "
        "quantization method.")
    return cls(
        block_size=specs[0].block_size,
        num_kv_heads=specs[0].num_kv_heads,
        head_size=specs[0].head_size,
        dtype=specs[0].dtype,
        cache_dtype_str=cache_dtype_str_set.pop(),
    )

MambaSpec dataclass

Bases: KVCacheSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    shapes: tuple[tuple[int, ...], ...]
    dtypes: tuple[torch.dtype]
    page_size_padded: Optional[int] = None
    mamba_type: str = "mamba2"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes))
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
        # the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes
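
The page size is the summed byte size of all per-request state tensors, optionally rounded up to page_size_padded. A sketch with hypothetical shapes and dtypes (real values depend on the model):

from math import prod

# Hypothetical per-request state tensors: e.g. a conv state and an SSM state.
shapes = ((8192, 3), (128, 64, 128))
dtype_sizes = (2, 4)          # fp16 conv state, fp32 SSM state
page_size_padded = None

page_size = sum(prod(shape) * size for shape, size in zip(shapes, dtype_sizes))
if page_size_padded is not None:
    assert page_size_padded >= page_size
    page_size = page_size_padded
print(page_size)  # 49152 + 4194304 = 4243456 bytes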

dtypes instance-attribute

dtypes: tuple[dtype]

mamba_type class-attribute instance-attribute

mamba_type: str = 'mamba2'

num_speculative_blocks class-attribute instance-attribute

num_speculative_blocks: int = 0

page_size_bytes property

page_size_bytes: int

page_size_padded class-attribute instance-attribute

page_size_padded: Optional[int] = None

shapes instance-attribute

shapes: tuple[tuple[int, ...], ...]

__init__

__init__(
    block_size: int,
    shapes: tuple[tuple[int, ...], ...],
    dtypes: tuple[dtype],
    page_size_padded: Optional[int] = None,
    mamba_type: str = "mamba2",
    num_speculative_blocks: int = 0,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    # We allocate 1 block for each request now, so max_memory_usage_bytes is
    # the same as page_size_bytes.
    # Need to update this when supporting prefix caching.
    return self.page_size_bytes

SlidingWindowSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
            "DCP does not support sliding window."
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_tokens
        # is 4, we need two blocks [XXCD] [EF] to store the sliding
        # window [CDEF] of 4 tokens.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes

sliding_window instance-attribute

sliding_window: int

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    sliding_window: int,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
        "DCP does not support sliding window."
    max_model_len = vllm_config.model_config.max_model_len
    max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)

    # During chunked prefill, we allocate KV cache for the last
    # `self.sliding_window-1` computed tokens plus the newly scheduled
    # tokens. And we won't allocate KV cache for more than `max_model_len`
    # tokens.
    num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                     max_model_len)

    # +1 here because the sliding window may not start from the beginning
    # of the block. For example, if the block size is 4 and num_tokens
    # is 4, we need two blocks [XXCD] [EF] to store the sliding
    # window [CDEF] of 4 tokens.
    return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
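
A worked example of this bound with hypothetical values; note the extra block added because the window may straddle a block boundary:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

# Hypothetical values, purely for illustration.
sliding_window = 4096
max_num_batched_tokens = 2048
max_model_len = 131072
block_size = 16
page_size_bytes = 65536

num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)  # 6143
num_blocks = cdiv(num_tokens, block_size) + 1                                 # 385
print(num_blocks * page_size_bytes)                                           # 25231360 bytes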

UniformTypeKVCacheSpecs dataclass

Bases: KVCacheSpec

A KV cache spec for multiple layers with the same type of attention. Here, "same type" means that the layers always need the same number of token slots. For example, sliding window attention layers with different window sizes are not of the same type and should not be merged into one UniformTypeKVCacheSpecs.

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    "same type" means that the layers always need the same number of token
    slots. For example, sliding window attention layers with different window
    sizes are not of the same type and should not be merged into one
    UniformTypeKVCacheSpecs.
    """
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        return sum(spec.page_size_bytes
                   for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config),
                 spec.page_size_bytes)
            for spec in self.kv_cache_specs.values())
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.
        """
        block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = next(iter(kv_cache_specs.values()))
        if isinstance(one_spec, FullAttentionSpec):
            return all(
                isinstance(spec, FullAttentionSpec)
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(
                isinstance(spec, CrossAttentionSpec)
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
                one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values())
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}")

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str,
                                             KVCacheSpec]) -> Optional[Self]:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same type
        of KV cache spec. Return None if not.
        """
        if cls.is_uniform_type(kv_cache_specs):
            block_size = next(iter(kv_cache_specs.values())).block_size
            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
        else:
            return None
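
A minimal usage sketch, assuming torch and vLLM are importable; the layer names and dimensions are hypothetical. Two identical full-attention layers are uniform, so from_specs returns a combined spec whose page size is the sum over the layers:

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec, UniformTypeKVCacheSpecs

spec = FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.float16)
specs = {"layers.0.attn": spec, "layers.1.attn": spec}

uniform = UniformTypeKVCacheSpecs.from_specs(specs)
assert uniform is not None                                   # both layers are full attention
print(uniform.page_size_bytes == 2 * spec.page_size_bytes)   # True: sum over layers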

kv_cache_specs instance-attribute

kv_cache_specs: dict[str, KVCacheSpec]

page_size_bytes property

page_size_bytes: int

__init__

__init__(
    block_size: int, kv_cache_specs: dict[str, KVCacheSpec]
) -> None

from_specs classmethod

from_specs(
    kv_cache_specs: dict[str, KVCacheSpec],
) -> Optional[Self]

Return a UniformTypeKVCacheSpecs object if all layers have the same type of KV cache spec. Return None if not.

Source code in vllm/v1/kv_cache_interface.py
@classmethod
def from_specs(cls, kv_cache_specs: dict[str,
                                         KVCacheSpec]) -> Optional[Self]:
    """
    Return a UniformTypeKVCacheSpecs object if all layers have the same type
    of KV cache spec. Return None if not.
    """
    if cls.is_uniform_type(kv_cache_specs):
        block_size = next(iter(kv_cache_specs.values())).block_size
        return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
    else:
        return None

is_uniform_type classmethod

is_uniform_type(
    kv_cache_specs: dict[str, KVCacheSpec],
) -> bool

Whether all layers have the same type of KV cache spec.

Source code in vllm/v1/kv_cache_interface.py
@classmethod
def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
    """
    Whether all layers have the same type of KV cache spec.
    """
    block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
    if len(block_sizes) > 1:
        # Different block sizes, not uniform.
        return False
    one_spec = next(iter(kv_cache_specs.values()))
    if isinstance(one_spec, FullAttentionSpec):
        return all(
            isinstance(spec, FullAttentionSpec)
            for spec in kv_cache_specs.values())
    elif isinstance(one_spec, CrossAttentionSpec):
        return all(
            isinstance(spec, CrossAttentionSpec)
            for spec in kv_cache_specs.values())
    elif isinstance(one_spec, SlidingWindowSpec):
        return all(
            isinstance(spec, SlidingWindowSpec)
            and spec.sliding_window == one_spec.sliding_window
            for spec in kv_cache_specs.values())
    elif isinstance(one_spec, ChunkedLocalAttentionSpec):
        return all(
            isinstance(spec, ChunkedLocalAttentionSpec)
            and spec.attention_chunk_size == one_spec.attention_chunk_size
            for spec in kv_cache_specs.values())
    elif isinstance(one_spec, MambaSpec):
        return all(
            isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
            one_spec.num_speculative_blocks
            for spec in kv_cache_specs.values())
    else:
        # NOTE(Chen): Please add new branches for new KV cache spec types.
        raise NotImplementedError(
            f"Unsupported KV cache spec type: {type(one_spec)}")

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_num_pages = max(
        cdiv(spec.max_memory_usage_bytes(vllm_config),
             spec.page_size_bytes)
        for spec in self.kv_cache_specs.values())
    return max_num_pages * self.page_size_bytes