Skip to content

vllm.v1.attention.backends.gdn_attn ¶

Backend for GatedDeltaNet attention.

GDNAttentionBackend ¶

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/gdn_attn.py
class GDNAttentionBackend(AttentionBackend):
    """Attention backend for GatedDeltaNet (GDN) attention."""

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        """Return the metadata-builder class paired with this backend."""
        builder_cls = GDNAttentionMetadataBuilder
        return builder_cls

get_builder_cls staticmethod ¶

get_builder_cls() -> type[GDNAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/gdn_attn.py
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
    """Return the metadata-builder class used by this backend."""
    builder_cls = GDNAttentionMetadataBuilder
    return builder_cls

GDNAttentionMetadata dataclass ¶

Source code in vllm/v1/attention/backends/gdn_attn.py
@dataclass
class GDNAttentionMetadata:
    """Per-batch metadata consumed by the GDN attention backend."""

    # Batch-split counts: prefill / decode / speculative-decode requests
    # and their token totals.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    # Per-prefill flag; in the builder it is computed as context_lens > 0.
    has_initial_state: Optional[torch.Tensor] = None

    # Cumulative query offsets; shape: [num_spec_decodes + 1,]
    spec_query_start_loc: Optional[torch.Tensor] = None
    # Cumulative query offsets; shape: [batch - num_spec_decodes + 1,]
    non_spec_query_start_loc: Optional[torch.Tensor] = None

    # State slot indices; shape: [batch, num_spec]
    spec_state_indices_tensor: Optional[torch.Tensor] = None
    # State slot indices; shape: [batch - num_spec_decodes,]
    non_spec_state_indices_tensor: Optional[torch.Tensor] = None
    # Per-request spec-decode flags; shape: [batch,]
    spec_sequence_masks: Optional[torch.Tensor] = None
    # Per-token spec flags; shape: [num_prefill_tokens + num_decode_tokens,]
    spec_token_masks: Optional[torch.Tensor] = None
    # Accepted-token count per request; shape: [batch,]
    num_accepted_tokens: Optional[torch.Tensor] = None

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None

batch_ptr class-attribute instance-attribute ΒΆ

batch_ptr: Optional[Tensor] = None

has_initial_state class-attribute instance-attribute ΒΆ

has_initial_state: Optional[Tensor] = None

non_spec_query_start_loc class-attribute instance-attribute ΒΆ

non_spec_query_start_loc: Optional[Tensor] = None

non_spec_state_indices_tensor class-attribute instance-attribute ΒΆ

non_spec_state_indices_tensor: Optional[Tensor] = None

num_accepted_tokens class-attribute instance-attribute ΒΆ

num_accepted_tokens: Optional[Tensor] = None

num_actual_tokens instance-attribute ΒΆ

num_actual_tokens: int

num_decode_tokens instance-attribute ΒΆ

num_decode_tokens: int

num_decodes instance-attribute ΒΆ

num_decodes: int

num_prefill_tokens instance-attribute ΒΆ

num_prefill_tokens: int

num_prefills instance-attribute ΒΆ

num_prefills: int

num_spec_decode_tokens instance-attribute ΒΆ

num_spec_decode_tokens: int

num_spec_decodes instance-attribute ΒΆ

num_spec_decodes: int

nums_dict class-attribute instance-attribute ΒΆ

nums_dict: Optional[dict] = None

spec_query_start_loc class-attribute instance-attribute ΒΆ

spec_query_start_loc: Optional[Tensor] = None

spec_sequence_masks class-attribute instance-attribute ΒΆ

spec_sequence_masks: Optional[Tensor] = None

spec_state_indices_tensor class-attribute instance-attribute ΒΆ

spec_state_indices_tensor: Optional[Tensor] = None

spec_token_masks class-attribute instance-attribute ΒΆ

spec_token_masks: Optional[Tensor] = None

token_chunk_offset_ptr class-attribute instance-attribute ΒΆ

token_chunk_offset_ptr: Optional[Tensor] = None

__init__ ¶

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    num_spec_decodes: int,
    num_spec_decode_tokens: int,
    num_actual_tokens: int,
    has_initial_state: Optional[Tensor] = None,
    spec_query_start_loc: Optional[Tensor] = None,
    non_spec_query_start_loc: Optional[Tensor] = None,
    spec_state_indices_tensor: Optional[Tensor] = None,
    non_spec_state_indices_tensor: Optional[Tensor] = None,
    spec_sequence_masks: Optional[Tensor] = None,
    spec_token_masks: Optional[Tensor] = None,
    num_accepted_tokens: Optional[Tensor] = None,
    nums_dict: Optional[dict] = None,
    batch_ptr: Optional[Tensor] = None,
    token_chunk_offset_ptr: Optional[Tensor] = None,
) -> None

GDNAttentionMetadataBuilder ¶

Bases: AttentionMetadataBuilder[GDNAttentionMetadata]

Source code in vllm/v1/attention/backends/gdn_attn.py
class GDNAttentionMetadataBuilder(
        AttentionMetadataBuilder[GDNAttentionMetadata]):
    """Builds :class:`GDNAttentionMetadata` for GatedDeltaNet attention.

    Each scheduled batch is split into prefill, decode, and
    speculative-decode requests.  When full CUDA graphs are enabled,
    decode-only batches are padded into persistent device buffers (see the
    cudagraph branches in :meth:`build`) so replays reuse the same storage.
    """

    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Read config and allocate persistent cudagraph buffers on device."""
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec
        # Number of draft tokens per request; 0 when spec decode is disabled.
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens  # noqa: E501
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self._init_reorder_batch_threshold(1, self.use_spec_decode)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        # Max decode-only batch a full cudagraph can cover: each request may
        # carry up to (num_spec + 1) tokens, capped by the capture size.
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs *
            (self.num_spec + 1), self.compilation_config.max_capture_size)

        # Persistent device buffers reused across builds; the cudagraph
        # branches in build() copy per-batch data into them.
        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_masks = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1), ),
            dtype=torch.bool,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1, ),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1, ),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        """Build the GDN attention metadata for one scheduled batch.

        Args:
            common_prefix_len: unused by this backend.
            common_attn_metadata: shared per-batch scheduling metadata.
            num_accepted_tokens: per-request accepted-token counts; must be
                provided (asserted) when any spec-decode request is present.
            num_decode_draft_tokens_cpu: per-request draft-token counts on
                CPU; entries >= 0 mark spec-decode requests.
            fast_build: unused by this backend.
        """
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        context_lens = m.num_computed_tokens_cpu
        context_lens_tensor = context_lens.to(query_start_loc.device)
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        # Identify spec-decode requests (draft-token count >= 0); fall back
        # to the plain path when no request carries draft tokens.
        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
                or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
                                               0].sum().item() == 0):
            spec_sequence_masks = None
            num_spec_decodes = 0
        else:
            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
            num_spec_decodes = spec_sequence_masks.sum().item()
            if num_spec_decodes == 0:
                spec_sequence_masks = None
            else:
                spec_sequence_masks = spec_sequence_masks.to(
                    query_start_loc.device, non_blocking=True)

        if spec_sequence_masks is None:
            # No spec decoding in this batch: simple decode/prefill split.
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1))
            num_spec_decode_tokens = 0
            spec_token_masks = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            num_accepted_tokens = None
        else:
            query_lens = query_start_loc[1:] - query_start_loc[:-1]

            # Split the non-spec requests into decodes (1 token) / prefills.
            non_spec_query_lens = query_lens[~spec_sequence_masks]
            num_decodes = (non_spec_query_lens == 1).sum().item()
            num_prefills = non_spec_query_lens.size(0) - num_decodes
            num_decode_tokens = num_decodes
            num_prefill_tokens = non_spec_query_lens.sum().item(
            ) - num_decode_tokens

            if num_prefills == 0 and num_decodes == 0:
                # Every request is a spec-decode request.
                spec_token_masks = torch.ones(
                    (min(num_spec_decodes *
                         (self.num_spec + 1), query_start_loc[-1].item())),
                    dtype=torch.bool,
                    device=query_start_loc.device)
                spec_state_indices_tensor = m.block_table_tensor[:, :self.
                                                                 num_spec + 1]
                non_spec_state_indices_tensor = None
                spec_query_start_loc = query_start_loc
                non_spec_query_start_loc = None
            else:
                # Mixed batch: partition per-token masks, state indices and
                # cumulative query offsets into spec / non-spec halves.
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens)
                spec_state_indices_tensor = m.block_table_tensor[
                    spec_sequence_masks, :self.num_spec + 1]
                non_spec_state_indices_tensor = \
                    m.block_table_tensor[~spec_sequence_masks, 0]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device)
                torch.cumsum(query_lens[spec_sequence_masks],
                             dim=0,
                             out=spec_query_start_loc[1:])
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device)
                torch.cumsum(query_lens[~spec_sequence_masks],
                             dim=0,
                             out=non_spec_query_start_loc[1:])

            num_spec_decode_tokens = (query_lens.sum().item() -
                                      num_prefill_tokens - num_decode_tokens)
            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            # A prefill resumes from existing state iff it already has
            # computed context tokens.
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
            nums_dict, batch_ptr, token_chunk_offset_ptr = \
                compute_causal_conv1d_metadata(non_spec_query_start_loc)
        else:
            has_initial_state = None
        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
            num_spec_decode_tokens

        # prepare tensors for cudagraph
        #
        # With speculative decoding, the xgrammar backend may rollback tokens
        # and causing some sequences has less draft tokens than self.num_spec.
        #
        # In above cases, the max possible batch size for n tokens, can be
        # min(n, cudagraph_max_bs).
        if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                and num_spec_decodes <= self.decode_cudagraph_max_bs
                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                m.num_actual_tokens)
            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True)
            spec_state_indices_tensor = self.spec_state_indices_tensor[:
                                                                       batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks, non_blocking=True)
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert spec_token_masks is not None
            # Capture the copied length BEFORE rebinding the name to the
            # persistent buffer so the padded tail can be cleared.  (The
            # previous code sliced the view from its own length —
            # x[x.size(0):] — which is always empty and cleared nothing,
            # leaving stale True values from earlier builds.)
            num_spec_token_copied = spec_token_masks.size(0)
            self.spec_token_masks[:num_spec_token_copied].copy_(
                spec_token_masks, non_blocking=True)
            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
            spec_token_masks[num_spec_token_copied:].fill_(False)

            self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True)
            spec_num_query_tokens = spec_query_start_loc[
                -1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1]
            # Padded requests contribute zero-length queries.
            spec_query_start_loc[num_spec_decodes +
                                 1:].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True)
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (self.use_full_cuda_graph and num_prefills == 0
                and num_spec_decodes == 0
                and num_decodes <= self.decode_cudagraph_max_bs):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                m.num_actual_tokens)
            # NOTE(review): unlike the spec branch above, batch_size is not
            # clamped to decode_cudagraph_max_bs here — verify the padded
            # size cannot exceed the persistent buffers.
            batch_size = num_actual_tokens

            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True)
            non_spec_state_indices_tensor = \
                self.non_spec_state_indices_tensor[:batch_size]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[:num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True)
            non_spec_num_query_tokens = non_spec_query_start_loc[
                -1]  # type: ignore[index]
            non_spec_query_start_loc = \
                self.non_spec_query_start_loc[:batch_size + 1]
            non_spec_query_start_loc[num_decodes +
                                     1:].fill_(non_spec_num_query_tokens)

        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            num_actual_tokens=num_actual_tokens,
            has_initial_state=has_initial_state,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            spec_sequence_masks=spec_sequence_masks,
            spec_token_masks=spec_token_masks,
            num_accepted_tokens=num_accepted_tokens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert (
            m.num_reqs <= self.decode_cudagraph_max_bs
            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
                f"GDN only supports decode-only full CUDAGraph capture. "
                f"Make sure batch size ({m.num_reqs}) <= "
                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
                f"and number of tokens ({m.num_actual_tokens}) <= "
                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")

        # Per-request token counts, treated as accepted tokens for capture.
        num_accepted_tokens = torch.diff(m.query_start_loc)
        # Tokens minus the verified one; >= 0 marks spec-decode in build().
        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

        return self.build(0, m, num_accepted_tokens,
                          num_decode_draft_tokens_cpu)

compilation_config instance-attribute ΒΆ

compilation_config = compilation_config

cudagraph_support class-attribute instance-attribute ΒΆ

cudagraph_support = UNIFORM_BATCH

decode_cudagraph_max_bs instance-attribute ΒΆ

decode_cudagraph_max_bs = min(
    max_num_seqs * (num_spec + 1), max_capture_size
)

kv_cache_spec instance-attribute ΒΆ

kv_cache_spec = kv_cache_spec

non_spec_query_start_loc instance-attribute ΒΆ

non_spec_query_start_loc = empty(
    (decode_cudagraph_max_bs + 1,),
    dtype=int32,
    device=device,
)

non_spec_state_indices_tensor instance-attribute ΒΆ

non_spec_state_indices_tensor = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

num_accepted_tokens instance-attribute ΒΆ

num_accepted_tokens = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

num_spec instance-attribute ΒΆ

num_spec = num_speculative_tokens

reorder_batch_threshold class-attribute instance-attribute ΒΆ

reorder_batch_threshold: int = 1

spec_query_start_loc instance-attribute ΒΆ

spec_query_start_loc = empty(
    (decode_cudagraph_max_bs + 1,),
    dtype=int32,
    device=device,
)

spec_sequence_masks instance-attribute ΒΆ

spec_sequence_masks = empty(
    (decode_cudagraph_max_bs,), dtype=bool, device=device
)

spec_state_indices_tensor instance-attribute ΒΆ

spec_state_indices_tensor = empty(
    (decode_cudagraph_max_bs, num_spec + 1),
    dtype=int32,
    device=device,
)

spec_token_masks instance-attribute ΒΆ

spec_token_masks = empty(
    (decode_cudagraph_max_bs * (num_spec + 1),),
    dtype=bool,
    device=device,
)

speculative_config instance-attribute ΒΆ

speculative_config = speculative_config

use_full_cuda_graph instance-attribute ΒΆ

use_full_cuda_graph = has_full_cudagraphs()

use_spec_decode instance-attribute ΒΆ

use_spec_decode = num_spec > 0

vllm_config instance-attribute ΒΆ

vllm_config = vllm_config

__init__ ¶

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/gdn_attn.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    """Read config and allocate persistent cudagraph buffers on ``device``.

    Args:
        kv_cache_spec: must be a ``MambaSpec`` (asserted below).
        layer_names: attention layer names; not referenced in this
            constructor.
        vllm_config: global vLLM config; the scheduler, compilation, and
            speculative sub-configs are read from it.
        device: device on which the persistent buffers are allocated.
    """
    assert isinstance(kv_cache_spec, MambaSpec)
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.speculative_config = vllm_config.speculative_config
    self.kv_cache_spec = kv_cache_spec
    # Number of draft tokens per request; 0 when spec decode is disabled.
    if self.speculative_config:
        self.num_spec = self.speculative_config.num_speculative_tokens  # noqa: E501
    else:
        self.num_spec = 0
    self.use_spec_decode = self.num_spec > 0
    self._init_reorder_batch_threshold(1, self.use_spec_decode)

    self.use_full_cuda_graph = \
        self.compilation_config.cudagraph_mode.has_full_cudagraphs()
    # Max decode-only batch a full cudagraph can cover: each request may
    # carry up to (num_spec + 1) tokens, capped by the capture size.
    self.decode_cudagraph_max_bs = min(
        self.vllm_config.scheduler_config.max_num_seqs *
        (self.num_spec + 1), self.compilation_config.max_capture_size)

    # Persistent device buffers reused across builds; the cudagraph
    # branches in build() copy per-batch data into them.
    self.spec_state_indices_tensor = torch.empty(
        (self.decode_cudagraph_max_bs, self.num_spec + 1),
        dtype=torch.int32,
        device=device,
    )
    self.non_spec_state_indices_tensor = torch.empty(
        (self.decode_cudagraph_max_bs, ),
        dtype=torch.int32,
        device=device,
    )
    self.spec_sequence_masks = torch.empty(
        (self.decode_cudagraph_max_bs, ),
        dtype=torch.bool,
        device=device,
    )
    self.spec_token_masks = torch.empty(
        (self.decode_cudagraph_max_bs * (self.num_spec + 1), ),
        dtype=torch.bool,
        device=device,
    )
    self.spec_query_start_loc = torch.empty(
        (self.decode_cudagraph_max_bs + 1, ),
        dtype=torch.int32,
        device=device,
    )
    self.non_spec_query_start_loc = torch.empty(
        (self.decode_cudagraph_max_bs + 1, ),
        dtype=torch.int32,
        device=device,
    )
    self.num_accepted_tokens = torch.empty(
        (self.decode_cudagraph_max_bs, ),
        dtype=torch.int32,
        device=device,
    )

build ¶

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    num_accepted_tokens: Optional[Tensor] = None,
    num_decode_draft_tokens_cpu: Optional[Tensor] = None,
    fast_build: bool = False,
) -> GDNAttentionMetadata
Source code in vllm/v1/attention/backends/gdn_attn.py
def build(  # type: ignore[override]
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
    fast_build: bool = False,
) -> GDNAttentionMetadata:
    """Build the GDN attention metadata for one scheduled batch.

    Args:
        common_prefix_len: unused by this backend.
        common_attn_metadata: shared per-batch scheduling metadata.
        num_accepted_tokens: per-request accepted-token counts; must be
            provided (asserted) when any spec-decode request is present.
        num_decode_draft_tokens_cpu: per-request draft-token counts on
            CPU; entries >= 0 mark spec-decode requests.
        fast_build: unused by this backend.
    """
    m = common_attn_metadata

    query_start_loc = m.query_start_loc
    context_lens = m.num_computed_tokens_cpu
    context_lens_tensor = context_lens.to(query_start_loc.device)
    nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

    # Identify spec-decode requests (draft-token count >= 0); fall back
    # to the plain path when no request carries draft tokens.
    if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
            or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
                                           0].sum().item() == 0):
        spec_sequence_masks = None
        num_spec_decodes = 0
    else:
        spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
        num_spec_decodes = spec_sequence_masks.sum().item()
        if num_spec_decodes == 0:
            spec_sequence_masks = None
        else:
            spec_sequence_masks = spec_sequence_masks.to(
                query_start_loc.device, non_blocking=True)

    if spec_sequence_masks is None:
        # No spec decoding in this batch: simple decode/prefill split.
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(m, decode_threshold=1))
        num_spec_decode_tokens = 0
        spec_token_masks = None
        spec_state_indices_tensor = None
        non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
        spec_query_start_loc = None
        non_spec_query_start_loc = query_start_loc
        num_accepted_tokens = None
    else:
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        # Split the non-spec requests into decodes (1 token) / prefills.
        non_spec_query_lens = query_lens[~spec_sequence_masks]
        num_decodes = (non_spec_query_lens == 1).sum().item()
        num_prefills = non_spec_query_lens.size(0) - num_decodes
        num_decode_tokens = num_decodes
        num_prefill_tokens = non_spec_query_lens.sum().item(
        ) - num_decode_tokens

        if num_prefills == 0 and num_decodes == 0:
            # Every request is a spec-decode request.
            spec_token_masks = torch.ones(
                (min(num_spec_decodes *
                     (self.num_spec + 1), query_start_loc[-1].item())),
                dtype=torch.bool,
                device=query_start_loc.device)
            spec_state_indices_tensor = m.block_table_tensor[:, :self.
                                                             num_spec + 1]
            non_spec_state_indices_tensor = None
            spec_query_start_loc = query_start_loc
            non_spec_query_start_loc = None
        else:
            # Mixed batch: partition per-token masks, state indices and
            # cumulative query offsets into spec / non-spec halves.
            spec_token_masks = torch.repeat_interleave(
                spec_sequence_masks, query_lens)
            spec_state_indices_tensor = m.block_table_tensor[
                spec_sequence_masks, :self.num_spec + 1]
            non_spec_state_indices_tensor = \
                m.block_table_tensor[~spec_sequence_masks, 0]

            spec_query_start_loc = torch.zeros(
                num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device)
            torch.cumsum(query_lens[spec_sequence_masks],
                         dim=0,
                         out=spec_query_start_loc[1:])
            non_spec_query_start_loc = torch.zeros(
                query_lens.size(0) - num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device)
            torch.cumsum(query_lens[~spec_sequence_masks],
                         dim=0,
                         out=non_spec_query_start_loc[1:])

        num_spec_decode_tokens = (query_lens.sum().item() -
                                  num_prefill_tokens - num_decode_tokens)
        assert num_accepted_tokens is not None
        num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

    if num_prefills > 0:
        # A prefill resumes from existing state iff it already has
        # computed context tokens.
        has_initial_state = context_lens_tensor > 0
        if spec_sequence_masks is not None:
            has_initial_state = has_initial_state[~spec_sequence_masks]
        nums_dict, batch_ptr, token_chunk_offset_ptr = \
            compute_causal_conv1d_metadata(non_spec_query_start_loc)
    else:
        has_initial_state = None
    num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
        num_spec_decode_tokens

    # prepare tensors for cudagraph
    #
    # With speculative decoding, the xgrammar backend may rollback tokens
    # and causing some sequences has less draft tokens than self.num_spec.
    #
    # In above cases, the max possible batch size for n tokens, can be
    # min(n, cudagraph_max_bs).
    if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
            and num_spec_decodes <= self.decode_cudagraph_max_bs
            and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(
            m.num_actual_tokens)
        batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

        self.spec_state_indices_tensor[:num_spec_decodes].copy_(
            spec_state_indices_tensor, non_blocking=True)
        spec_state_indices_tensor = self.spec_state_indices_tensor[:
                                                                   batch_size]
        spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

        self.spec_sequence_masks[:num_spec_decodes].copy_(
            spec_sequence_masks, non_blocking=True)
        spec_sequence_masks = self.spec_sequence_masks[:batch_size]
        spec_sequence_masks[num_spec_decodes:].fill_(False)

        assert spec_token_masks is not None
        # Capture the copied length BEFORE rebinding the name to the
        # persistent buffer so the padded tail can be cleared.  (The
        # previous code sliced the view from its own length —
        # x[x.size(0):] — which is always empty and cleared nothing,
        # leaving stale True values from earlier builds.)
        num_spec_token_copied = spec_token_masks.size(0)
        self.spec_token_masks[:num_spec_token_copied].copy_(
            spec_token_masks, non_blocking=True)
        spec_token_masks = self.spec_token_masks[:num_actual_tokens]
        spec_token_masks[num_spec_token_copied:].fill_(False)

        self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
            spec_query_start_loc, non_blocking=True)
        spec_num_query_tokens = spec_query_start_loc[
            -1]  # type: ignore[index]
        spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1]
        # Padded requests contribute zero-length queries.
        spec_query_start_loc[num_spec_decodes +
                             1:].fill_(spec_num_query_tokens)

        self.num_accepted_tokens[:num_spec_decodes].copy_(
            num_accepted_tokens, non_blocking=True)
        num_accepted_tokens = self.num_accepted_tokens[:batch_size]
        num_accepted_tokens[num_spec_decodes:].fill_(1)

    if (self.use_full_cuda_graph and num_prefills == 0
            and num_spec_decodes == 0
            and num_decodes <= self.decode_cudagraph_max_bs):
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(
            m.num_actual_tokens)
        # NOTE(review): unlike the spec branch above, batch_size is not
        # clamped to decode_cudagraph_max_bs here — verify the padded
        # size cannot exceed the persistent buffers.
        batch_size = num_actual_tokens

        self.non_spec_state_indices_tensor[:num_decodes].copy_(
            non_spec_state_indices_tensor, non_blocking=True)
        non_spec_state_indices_tensor = \
            self.non_spec_state_indices_tensor[:batch_size]
        non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

        self.non_spec_query_start_loc[:num_decodes + 1].copy_(
            non_spec_query_start_loc, non_blocking=True)
        non_spec_num_query_tokens = non_spec_query_start_loc[
            -1]  # type: ignore[index]
        non_spec_query_start_loc = \
            self.non_spec_query_start_loc[:batch_size + 1]
        non_spec_query_start_loc[num_decodes +
                                 1:].fill_(non_spec_num_query_tokens)

    attn_metadata = GDNAttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_spec_decodes=num_spec_decodes,
        num_spec_decode_tokens=num_spec_decode_tokens,
        num_actual_tokens=num_actual_tokens,
        has_initial_state=has_initial_state,
        spec_query_start_loc=spec_query_start_loc,
        non_spec_query_start_loc=non_spec_query_start_loc,
        spec_state_indices_tensor=spec_state_indices_tensor,
        non_spec_state_indices_tensor=non_spec_state_indices_tensor,
        spec_sequence_masks=spec_sequence_masks,
        spec_token_masks=spec_token_masks,
        num_accepted_tokens=num_accepted_tokens,
        nums_dict=nums_dict,
        batch_ptr=batch_ptr,
        token_chunk_offset_ptr=token_chunk_offset_ptr,
    )
    return attn_metadata

build_for_cudagraph_capture ¶

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
)

This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba.

Source code in vllm/v1/attention/backends/gdn_attn.py
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata):
    """
    This method builds the metadata for full cudagraph capture.
    Currently, only decode is supported for full cudagraphs with Mamba.
    """
    m = common_attn_metadata

    assert (
        m.num_reqs <= self.decode_cudagraph_max_bs
        and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
            f"GDN only supports decode-only full CUDAGraph capture. "
            f"Make sure batch size ({m.num_reqs}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
            f"and number of tokens ({m.num_actual_tokens}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")

    # Per-request token counts (diffs of the cumulative offsets), treated
    # as accepted tokens for capture.
    num_accepted_tokens = torch.diff(m.query_start_loc)
    # Tokens minus the verified one; entries >= 0 mark the request as a
    # spec-decode request inside build().
    num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
    # Back out the computed-token count so build() sees consistent state.
    m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

    return self.build(0, m, num_accepted_tokens,
                      num_decode_draft_tokens_cpu)