vllm.distributed.kv_transfer.kv_connector.v1.multi_connector

logger module-attribute

logger = init_logger(__name__)

MultiConnector

Bases: KVConnectorBase_V1

A wrapper for using multiple KVConnectors at the same time.

The current logic is:

- Load KV from the first connector that advertises available tokens from
  get_num_new_matched_tokens(), based on the order in the config.
- Save to all connectors.
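
For context, a minimal sketch of how such a connector stack might be configured. The connector names and the shared_storage_path key are illustrative assumptions; MultiConnector itself only reads the "connectors" list inside kv_connector_extra_config, in priority order:

from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="MultiConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "connectors": [
            # Queried first for matched tokens, so it has load priority.
            {"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"},
            # Lower load priority, but still receives every save.
            {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both",
             "kv_connector_extra_config": {"shared_storage_path": "/tmp/kv"}},
        ],
    },
)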

Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
class MultiConnector(KVConnectorBase_V1):
    """
    A wrapper for using multiple KVConnectors at the same time.

    The current logic is:
    - Load KV from the first connector that advertises available tokens from
      get_num_new_matched_tokens(), based on the order in the config.
    - Save to all connectors.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._connectors: list[KVConnectorBase_V1] = []
        self._ktc_kv_transfer_config = []
        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors")
        assert ktcs is not None
        for ktc in ktcs:
            temp_config = copy.copy(vllm_config)
            engine_id = ktc.get("engine_id",
                                vllm_config.kv_transfer_config.engine_id)
            temp_config.kv_transfer_config = KVTransferConfig(
                **ktc, engine_id=engine_id)
            self._connectors.append(
                KVConnectorFactory.create_connector(temp_config, role))
            self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

        # A mapping from request id to the index of the connector chosen to
        # load the request from (if any).
        self._requests_to_connector: dict[str, int] = {}

        # Keeps track of *additional* remaining async saves (beyond 1) to be
        # finished per request. Not needed for async loads since we only allow
        # a single connector to load.
        # Propagated from scheduler to worker side via the connector metadata.
        self._extra_async_saves: dict[str, int] = {}

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        for c in self._connectors:
            c.register_kv_caches(kv_caches)

    # We must override the base class method here because we need to bind
    # the metadata to each connector in the order of the connectors in the
    # MultiKVConnectorMetadata.
    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, MultiKVConnectorMetadata)
        if connector_metadata.extra_async_saves:
            self._extra_async_saves.update(
                connector_metadata.extra_async_saves)
        for c, cm in zip(self._connectors, connector_metadata.metadata):
            c.bind_connector_metadata(cm)

    def clear_connector_metadata(self) -> None:
        for c in self._connectors:
            c.clear_connector_metadata()

    def shutdown(self):
        exception: Optional[Exception] = None
        for c in self._connectors:
            try:
                c.shutdown()
            except Exception as e:
                logger.exception("Exception during connector %s shutdown.",
                                 c.__class__.__name__)
                exception = e
        if exception:
            raise exception

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        for c in self._connectors:
            c.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        for c in self._connectors:
            c.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        for c in self._connectors:
            c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

    def wait_for_save(self):
        for c in self._connectors:
            c.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        finished_sending: set[str] = set()
        finished_recving: set[str] = set()
        for c in self._connectors:
            sending, recving = c.get_finished(finished_req_ids)
            if not recving and not sending:
                continue
            # Aggregate finished recving request ids.
            finished_recving.update(recving or ())
            # Aggregate finished sending request ids - only include
            # once we've drained the "extra" count (for cases where
            # more than one connector is async-saving the same request).
            for req_id in sending or ():
                extra_pending = self._extra_async_saves.get(req_id)
                if extra_pending is None:
                    finished_sending.add(req_id)
                    continue
                assert extra_pending > 0
                if extra_pending == 1:
                    del self._extra_async_saves[req_id]
                else:
                    self._extra_async_saves[req_id] = extra_pending - 1

        return finished_sending or None, finished_recving or None

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[Optional[int], bool]:
        to_return = (0, False)
        for i, c in enumerate(self._connectors):
            toks, load_async = c.get_num_new_matched_tokens(
                request, num_computed_tokens)
            # If there is a connector still looking up the matches,
            # we return None to indicate that we are not done yet.
            if toks is None:
                return (None, False)
            # The first connector that has new matched tokens will be assigned
            # to this request.
            if to_return[0] == 0 and toks > 0:
                self._requests_to_connector[request.request_id] = i
                to_return = (toks, load_async)
        return to_return

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        chosen_connector = self._requests_to_connector.get(
            request.request_id, -1)
        empty_blocks = blocks.new_empty()
        for i, c in enumerate(self._connectors):
            if i == chosen_connector:
                # Forward call to the chosen connector (if any).
                c.update_state_after_alloc(request, blocks,
                                           num_external_tokens)
            else:
                # Call with empty blocks for other connectors.
                c.update_state_after_alloc(request, empty_blocks, 0)

    def build_connector_meta(
            self,
            scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
        metadata = MultiKVConnectorMetadata(metadata=tuple(
            c.build_connector_meta(scheduler_output)
            for c in self._connectors))
        if self._extra_async_saves:
            metadata.extra_async_saves = self._extra_async_saves
            self._extra_async_saves = {}
        return metadata

    def update_connector_output(self, connector_output: KVConnectorOutput):
        for c in self._connectors:
            c.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        blocks: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        async_saves = 0
        kv_txfer_params = None
        for c in self._connectors:
            async_save, txfer_params = c.request_finished(request, blocks)
            if async_save:
                async_saves += 1
            if txfer_params is not None:
                if kv_txfer_params is not None:
                    # TODO we can probably change this to merge the dicts here,
                    # checking for key clashes.
                    raise RuntimeError(
                        "Only one connector can produce KV transfer params")
                kv_txfer_params = txfer_params
        if async_saves > 1:
            self._extra_async_saves[request.request_id] = async_saves - 1

        # Clean up other state for this request.
        self._requests_to_connector.pop(request.request_id, None)

        return async_saves > 0, kv_txfer_params

    def take_events(self) -> Iterable["KVCacheEvent"]:
        for c in self._connectors:
            yield from c.take_events()

    @classmethod
    def get_required_kvcache_layout(
            cls, vllm_config: "VllmConfig") -> Optional[str]:
        """
        Get the required KV cache layout for this connector.
        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout. e.g. HND, or NHD.
            None if the connector does not require a specific layout.
        """
        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors")
        assert ktcs is not None
        layouts: set[str] = set()
        temp_vllm_config = copy.copy(vllm_config)
        for ktc in ktcs:
            kv_transfer_config = KVTransferConfig(**ktc)
            temp_vllm_config.kv_transfer_config = kv_transfer_config
            connector_cls = KVConnectorFactory.get_connector_class(
                kv_transfer_config)
            required_kvcache_layout = (
                connector_cls.get_required_kvcache_layout(temp_vllm_config))
            if required_kvcache_layout is not None:
                layouts.add(required_kvcache_layout)

        if len(layouts) > 1:
            raise ValueError(f"KV cache layout mismatch: "
                             f"found {len(layouts)} different layouts "
                             f"({', '.join(layouts)}). "
                             f"All connectors must use the same layout.")
        return next(iter(layouts), None)

    @classmethod
    def build_kv_connector_stats(
            cls,
            data: Optional[dict[str,
                                Any]] = None) -> Optional[KVConnectorStats]:
        return MultiKVConnectorStats(data=data) if data is not None \
            else MultiKVConnectorStats()

    def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
        # Group connector stats by connector type.
        stats_by_connector: Optional[MultiKVConnectorStats] = None
        for c in self._connectors:
            stats = c.get_kv_connector_stats()
            if stats is None:
                continue
            if stats_by_connector is None:
                # Lazy init to allow optional return value.
                stats_by_connector = MultiKVConnectorStats()
            stats_by_connector[c.__class__.__name__] = stats
        return stats_by_connector

_connectors instance-attribute

_connectors: list[KVConnectorBase_V1] = []

_extra_async_saves instance-attribute

_extra_async_saves: dict[str, int] = {}

_ktc_kv_transfer_config instance-attribute

_ktc_kv_transfer_config = []

_requests_to_connector instance-attribute

_requests_to_connector: dict[str, int] = {}

__init__

__init__(vllm_config: VllmConfig, role: KVConnectorRole)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    super().__init__(vllm_config=vllm_config, role=role)
    self._connectors: list[KVConnectorBase_V1] = []
    self._ktc_kv_transfer_config = []
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors")
    assert ktcs is not None
    for ktc in ktcs:
        temp_config = copy.copy(vllm_config)
        engine_id = ktc.get("engine_id",
                            vllm_config.kv_transfer_config.engine_id)
        temp_config.kv_transfer_config = KVTransferConfig(
            **ktc, engine_id=engine_id)
        self._connectors.append(
            KVConnectorFactory.create_connector(temp_config, role))
        self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

    # A mapping from request id to the index of the connector chosen to
    # load the request from (if any).
    self._requests_to_connector: dict[str, int] = {}

    # Keeps track of *additional* remaining async saves (beyond 1) to be
    # finished per request. Not needed for async loads since we only allow
    # a single connector to load.
    # Propagated from scheduler to worker side via the connector metadata.
    self._extra_async_saves: dict[str, int] = {}

bind_connector_metadata

bind_connector_metadata(
    connector_metadata: KVConnectorMetadata,
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def bind_connector_metadata(
        self, connector_metadata: KVConnectorMetadata) -> None:
    assert isinstance(connector_metadata, MultiKVConnectorMetadata)
    if connector_metadata.extra_async_saves:
        self._extra_async_saves.update(
            connector_metadata.extra_async_saves)
    for c, cm in zip(self._connectors, connector_metadata.metadata):
        c.bind_connector_metadata(cm)

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> MultiKVConnectorMetadata
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
    metadata = MultiKVConnectorMetadata(metadata=tuple(
        c.build_connector_meta(scheduler_output)
        for c in self._connectors))
    if self._extra_async_saves:
        metadata.extra_async_saves = self._extra_async_saves
        self._extra_async_saves = {}
    return metadata

build_kv_connector_stats classmethod

build_kv_connector_stats(
    data: Optional[dict[str, Any]] = None,
) -> Optional[KVConnectorStats]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@classmethod
def build_kv_connector_stats(
        cls,
        data: Optional[dict[str,
                            Any]] = None) -> Optional[KVConnectorStats]:
    return MultiKVConnectorStats(data=data) if data is not None \
        else MultiKVConnectorStats()

clear_connector_metadata

clear_connector_metadata() -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def clear_connector_metadata(self) -> None:
    for c in self._connectors:
        c.clear_connector_metadata()

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[Optional[set[str]], Optional[set[str]]]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    finished_sending: set[str] = set()
    finished_recving: set[str] = set()
    for c in self._connectors:
        sending, recving = c.get_finished(finished_req_ids)
        if not recving and not sending:
            continue
        # Aggregate finished recving request ids.
        finished_recving.update(recving or ())
        # Aggregate finished sending request ids - only include
        # once we've drained the "extra" count (for cases where
        # more than one connector is async-saving the same request).
        for req_id in sending or ():
            extra_pending = self._extra_async_saves.get(req_id)
            if extra_pending is None:
                finished_sending.add(req_id)
                continue
            assert extra_pending > 0
            if extra_pending == 1:
                del self._extra_async_saves[req_id]
            else:
                self._extra_async_saves[req_id] = extra_pending - 1

    return finished_sending or None, finished_recving or None
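
A small worked example of the draining behavior, assuming two connectors both async-saved request "req-1", so one extra save was recorded by request_finished():

extra_async_saves = {"req-1": 1}      # one save beyond the first still pending
finished_sending: set[str] = set()

# Two connectors report "req-1" finished sending, one per get_finished() call.
for _ in range(2):
    extra = extra_async_saves.get("req-1")
    if extra is None:
        finished_sending.add("req-1")   # no extras left: report it finished
    elif extra == 1:
        del extra_async_saves["req-1"]  # drain the last extra count
    else:
        extra_async_saves["req-1"] = extra - 1

assert finished_sending == {"req-1"} and not extra_async_saves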

get_kv_connector_stats

get_kv_connector_stats() -> Optional[MultiKVConnectorStats]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
    # Group connector stats by connector type.
    stats_by_connector: Optional[MultiKVConnectorStats] = None
    for c in self._connectors:
        stats = c.get_kv_connector_stats()
        if stats is None:
            continue
        if stats_by_connector is None:
            # Lazy init to allow optional return value.
            stats_by_connector = MultiKVConnectorStats()
        stats_by_connector[c.__class__.__name__] = stats
    return stats_by_connector

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[Optional[int], bool]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[Optional[int], bool]:
    to_return = (0, False)
    for i, c in enumerate(self._connectors):
        toks, load_async = c.get_num_new_matched_tokens(
            request, num_computed_tokens)
        # If there is a connector still looking up the matches,
        # we return None to indicate that we are not done yet.
        if toks is None:
            return (None, False)
        # The first connector that has new matched tokens will be assigned
        # to this request.
        if to_return[0] == 0 and toks > 0:
            self._requests_to_connector[request.request_id] = i
            to_return = (toks, load_async)
    return to_return
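
To make the first-match selection concrete, here is a minimal standalone sketch of the loop above, with hypothetical per-connector lookup results:

# Hypothetical (num new matched tokens, load_async) results, in config order.
results = [(0, False), (128, True), (256, False)]

chosen = None
to_return = (0, False)
for i, (toks, load_async) in enumerate(results):
    if toks is None:                  # a connector is still looking up matches
        to_return = (None, False)
        break
    if to_return[0] == 0 and toks > 0:
        chosen = i                    # first connector with a hit is assigned
        to_return = (toks, load_async)

assert chosen == 1                    # the later 256-token match is ignored
assert to_return == (128, True)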

get_required_kvcache_layout classmethod

get_required_kvcache_layout(
    vllm_config: VllmConfig,
) -> Optional[str]

Get the required KV cache layout for this connector.

Args:

    vllm_config (VllmConfig): the vLLM config.

Returns:

    Optional[str]: the required KV cache layout, e.g. HND or NHD, or None
    if the connector does not require a specific layout.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Get the required KV cache layout for this connector.
    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout. e.g. HND, or NHD.
        None if the connector does not require a specific layout.
    """
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors")
    assert ktcs is not None
    layouts: set[str] = set()
    temp_vllm_config = copy.copy(vllm_config)
    for ktc in ktcs:
        kv_transfer_config = KVTransferConfig(**ktc)
        temp_vllm_config.kv_transfer_config = kv_transfer_config
        connector_cls = KVConnectorFactory.get_connector_class(
            kv_transfer_config)
        required_kvcache_layout = (
            connector_cls.get_required_kvcache_layout(temp_vllm_config))
        if required_kvcache_layout is not None:
            layouts.add(required_kvcache_layout)

    if len(layouts) > 1:
        raise ValueError(f"KV cache layout mismatch: "
                         f"found {len(layouts)} different layouts "
                         f"({', '.join(layouts)}). "
                         f"All connectors must use the same layout.")
    return next(iter(layouts), None)
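
The resolution rule itself is easy to exercise in isolation. A minimal sketch, independent of any vLLM objects:

from typing import Optional

def resolve_layout(layouts: set[str]) -> Optional[str]:
    # Mirrors the tail of get_required_kvcache_layout() above.
    if len(layouts) > 1:
        raise ValueError(f"KV cache layout mismatch: found {len(layouts)} "
                         f"different layouts ({', '.join(layouts)}). "
                         f"All connectors must use the same layout.")
    return next(iter(layouts), None)

assert resolve_layout({"HND"}) == "HND"
assert resolve_layout(set()) is None    # no connector requires a layout
try:
    resolve_layout({"HND", "NHD"})
except ValueError as e:
    assert "layout mismatch" in str(e)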

register_kv_caches

register_kv_caches(kv_caches: dict[str, Tensor])
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    for c in self._connectors:
        c.register_kv_caches(kv_caches)

request_finished

request_finished(
    request: Request, blocks: list[int]
) -> tuple[bool, Optional[dict[str, Any]]]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def request_finished(
    self,
    request: "Request",
    blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    async_saves = 0
    kv_txfer_params = None
    for c in self._connectors:
        async_save, txfer_params = c.request_finished(request, blocks)
        if async_save:
            async_saves += 1
        if txfer_params is not None:
            if kv_txfer_params is not None:
                # TODO we can probably change this to merge the dicts here,
                # checking for key clashes.
                raise RuntimeError(
                    "Only one connector can produce KV transfer params")
            kv_txfer_params = txfer_params
    if async_saves > 1:
        self._extra_async_saves[request.request_id] = async_saves - 1

    # Clean up other state for this request.
    self._requests_to_connector.pop(request.request_id, None)

    return async_saves > 0, kv_txfer_params
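
This is where the worker-side "extra" counts originate. A sketch with hypothetical per-connector answers, assuming all three connectors async-save request "req-1":

async_save_results = [True, True, True]   # hypothetical request_finished() answers
async_saves = sum(async_save_results)

extra_async_saves: dict[str, int] = {}
if async_saves > 1:
    # Saves beyond the first are drained in get_finished() before the
    # request is reported as finished sending.
    extra_async_saves["req-1"] = async_saves - 1

assert async_saves > 0 and extra_async_saves == {"req-1": 2}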

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    for c in self._connectors:
        c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

shutdown

shutdown()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def shutdown(self):
    exception: Optional[Exception] = None
    for c in self._connectors:
        try:
            c.shutdown()
        except Exception as e:
            logger.exception("Exception during connector %s shutdown.",
                             c.__class__.__name__)
            exception = e
    if exception:
        raise exception

start_load_kv

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    for c in self._connectors:
        c.start_load_kv(forward_context, **kwargs)

take_events

take_events() -> Iterable[KVCacheEvent]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def take_events(self) -> Iterable["KVCacheEvent"]:
    for c in self._connectors:
        yield from c.take_events()

update_connector_output

update_connector_output(
    connector_output: KVConnectorOutput,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def update_connector_output(self, connector_output: KVConnectorOutput):
    for c in self._connectors:
        c.update_connector_output(connector_output)

update_state_after_alloc

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    chosen_connector = self._requests_to_connector.get(
        request.request_id, -1)
    empty_blocks = blocks.new_empty()
    for i, c in enumerate(self._connectors):
        if i == chosen_connector:
            # Forward call to the chosen connector (if any).
            c.update_state_after_alloc(request, blocks,
                                       num_external_tokens)
        else:
            # Call with empty blocks for other connectors.
            c.update_state_after_alloc(request, empty_blocks, 0)
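
Only the connector chosen during get_num_new_matched_tokens() sees the real allocation; the rest are called with empty blocks so they know they were not selected. A minimal sketch of that dispatch with a hypothetical two-connector setup:

requests_to_connector = {"req-1": 1}      # request assigned to connector 1
chosen = requests_to_connector.get("req-1", -1)

calls = []
for i in range(2):
    if i == chosen:
        calls.append((i, "allocated blocks", 64))   # forwarded the real state
    else:
        calls.append((i, "empty blocks", 0))        # signalled: not chosen

assert calls == [(0, "empty blocks", 0), (1, "allocated blocks", 64)]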

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def wait_for_layer_load(self, layer_name: str) -> None:
    for c in self._connectors:
        c.wait_for_layer_load(layer_name)

wait_for_save

wait_for_save()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def wait_for_save(self):
    for c in self._connectors:
        c.wait_for_save()

MultiKVConnectorMetadata dataclass

Bases: KVConnectorMetadata

Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
    metadata: tuple[KVConnectorMetadata, ...]
    extra_async_saves: Optional[dict[str, int]] = None

extra_async_saves class-attribute instance-attribute

extra_async_saves: Optional[dict[str, int]] = None

metadata instance-attribute

metadata: tuple[KVConnectorMetadata, ...]

__init__

__init__(
    metadata: tuple[KVConnectorMetadata, ...],
    extra_async_saves: Optional[dict[str, int]] = None,
) -> None

MultiKVConnectorStats dataclass

Bases: KVConnectorStats

Maintain a dict of KVConnectorStats objects, one for each connector. This is used to aggregate the stats from all connectors separately.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
    """
    Maintain a dict of KVConnectorStats objects, one for each connector.
    This is used to aggregate the stats from all connectors separately.
    """

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        for connector_id, stats in other.data.items():
            if connector_id not in self.data:
                self[connector_id] = stats
            else:
                assert isinstance(stats, type(self.data[connector_id]))
                self[connector_id] = self[connector_id].aggregate(stats)
        return self

    def reset(self):
        for stats in self.data.values():
            stats.reset()

    def reduce(self) -> dict[str, Any]:
        # TODO (NickLucche) Adjust for logging on separate lines
        return {
            connector_id: stats.reduce()
            for connector_id, stats in self.data.items()
        }

    def is_empty(self) -> bool:
        return all(stats.is_empty() for stats in self.data.values())

    def __getitem__(self, connector_id: str) -> KVConnectorStats:
        return self.data[connector_id]

    def __setitem__(self, connector_id: str, stats: KVConnectorStats):
        self.data[connector_id] = stats

__getitem__

__getitem__(connector_id: str) -> KVConnectorStats
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def __getitem__(self, connector_id: str) -> KVConnectorStats:
    return self.data[connector_id]

__init__

__init__(data: dict[str, Any] = dict()) -> None

__setitem__

__setitem__(connector_id: str, stats: KVConnectorStats)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
    self.data[connector_id] = stats

aggregate

aggregate(other: KVConnectorStats) -> KVConnectorStats
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
    for connector_id, stats in other.data.items():
        if connector_id not in self.data:
            self[connector_id] = stats
        else:
            assert isinstance(stats, type(self.data[connector_id]))
            self[connector_id] = self[connector_id].aggregate(stats)
    return self
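
Aggregation is keyed by connector class name, so stats for the same connector type merge while distinct connectors stay separate. A self-contained sketch of that rule, with a toy dict-of-counters standing in for real KVConnectorStats objects (connector names and stat keys are illustrative):

def aggregate(self_data: dict[str, dict], other_data: dict[str, dict]) -> dict:
    for connector_id, stats in other_data.items():
        if connector_id not in self_data:
            self_data[connector_id] = stats           # first time seen: adopt
        else:
            merged = self_data[connector_id]
            for k, v in stats.items():                # same connector: combine
                merged[k] = merged.get(k, 0) + v
    return self_data

a = {"NixlConnector": {"num_blocks_sent": 3}}
b = {"NixlConnector": {"num_blocks_sent": 5},
     "SharedStorageConnector": {"num_saves": 1}}
assert aggregate(a, b) == {
    "NixlConnector": {"num_blocks_sent": 8},
    "SharedStorageConnector": {"num_saves": 1},
}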

is_empty

is_empty() -> bool
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def is_empty(self) -> bool:
    return all(stats.is_empty() for stats in self.data.values())

reduce

reduce() -> dict[str, Any]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def reduce(self) -> dict[str, Any]:
    # TODO (NickLucche) Adjust for logging on separate lines
    return {
        connector_id: stats.reduce()
        for connector_id, stats in self.data.items()
    }

reset

reset()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
def reset(self):
    for stats in self.data.values():
        stats.reset()