vllm.v1.attention.backends.triton_attn

High-Performance Triton-only Attention layer.

logger module-attribute

logger = init_logger(__name__)

TritonAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/triton_attn.py
class TritonAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16, torch.float32]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        # Triton Attention supports any head size of at least 32
        if head_size < 32:
            raise ValueError(
                f"Head size {head_size} is not supported by TritonAttention. "
                "Head sizes need to be larger than or equal to 32 for this "
                "backend. Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "the FlexAttention backend, which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        return "TRITON_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        return TritonAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TritonAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        return TritonAttentionMetadataBuilder
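
For orientation, a minimal sketch of the backend's static helpers; the head sizes used below are hypothetical:

from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

print(TritonAttentionBackend.get_name())              # TRITON_ATTN
print(TritonAttentionBackend.get_supported_dtypes())  # [torch.float16, torch.bfloat16, torch.float32]

TritonAttentionBackend.validate_head_size(128)        # head sizes of at least 32 pass silently

try:
    TritonAttentionBackend.validate_head_size(16)     # head sizes below 32 are rejected
except ValueError as err:
    print(err)                                        # suggests VLLM_ATTENTION_BACKEND=FLEX_ATTENTION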

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[TritonAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
    return TritonAttentionMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[TritonAttentionImpl]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_impl_cls() -> type["TritonAttentionImpl"]:
    return TritonAttentionImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    if block_size % 16 != 0:
        raise ValueError("Block size must be a multiple of 16.")
    return (num_blocks, 2, block_size, num_kv_heads, head_size)
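
Example (a minimal sketch with hypothetical cache dimensions; any block size that is a multiple of 16 works):

from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

# Key and value planes are stacked along dim 1 of the returned shape.
shape = TritonAttentionBackend.get_kv_cache_shape(
    num_blocks=1024,
    block_size=16,
    num_kv_heads=8,
    head_size=128,
)
print(shape)  # (1024, 2, 16, 8, 128)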

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return TritonAttentionMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_name() -> str:
    return "TRITON_ATTN"

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/triton_attn.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    return [torch.float16, torch.bfloat16, torch.float32]

use_cascade_attention staticmethod

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    return False

validate_head_size classmethod

validate_head_size(head_size: int) -> None
Source code in vllm/v1/attention/backends/triton_attn.py
@classmethod
def validate_head_size(cls, head_size: int) -> None:
    # Triton Attention supports any head size of at least 32
    if head_size < 32:
        raise ValueError(
            f"Head size {head_size} is not supported by TritonAttention. "
            "Head sizes need to be larger than or equal to 32 for this "
            "backend. Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "the FlexAttention backend, which supports all head sizes.")

TritonAttentionImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/triton_attn.py
class TritonAttentionImpl(AttentionImpl):

    def fused_output_quant_supported(self, quant_key: QuantKey):
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[int] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        TritonAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Paged Attention impl. in Triton.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [num_blocks, 2, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for TritonAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(1)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if self.kv_cache_dtype.startswith("fp8"):
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
                # triton kernel does not support uint8 kv_cache
                #  (because some explicit casts (e.g. float8_e4m3fnuz)
                #   are not supported)
            triton_reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            if key_cache.dtype != self.fp8_dtype:
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale_float == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if current_platform.is_cuda():
                # Skip Q quantization on ROCm and XPU, enable this on cuda
                # only, since dequantizing back to f32 in the attention kernel
                # is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        unified_attention(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            sinks=self.sinks,
            output_scale=output_scale,
        )

        return output

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

fp8_dtype instance-attribute

fp8_dtype = fp8_dtype()

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sinks instance-attribute

sinks = sinks

sliding_window instance-attribute

sliding_window = (-1, -1)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = DECODER,
    kv_sharing_target_layer_name: Optional[int] = None,
    sinks: Optional[Tensor] = None,
) -> None
Source code in vllm/v1/attention/backends/triton_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[int] = None,
    sinks: Optional[torch.Tensor] = None,
) -> None:
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    self.alibi_slopes = alibi_slopes
    if sliding_window is None:
        self.sliding_window = (-1, -1)
    else:
        self.sliding_window = (sliding_window - 1, 0)
    self.kv_cache_dtype = kv_cache_dtype
    if logits_soft_cap is None:
        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        logits_soft_cap = 0
    self.logits_soft_cap = logits_soft_cap
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    TritonAttentionBackend.validate_head_size(head_size)

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "TritonAttentionImpl")

    self.fp8_dtype = current_platform.fp8_dtype()

    self.sinks = sinks
    if sinks is not None:
        assert sinks.shape[0] == num_heads, (
            "Sinks must have the same number of heads as the number of "
            f"heads in the layer. Sinks shape: {sinks.shape}, "
            f"num_heads: {num_heads}.")
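
A minimal sketch of the argument normalisation performed above, with hypothetical values (the helper name normalize_attn_args is illustrative, not part of vLLM):

import torch

def normalize_attn_args(sliding_window, logits_soft_cap, alibi_slopes):
    # A window of N tokens becomes (N - 1, 0): N - 1 tokens of left context,
    # no right context; None means an unbounded window.
    window = (-1, -1) if sliding_window is None else (sliding_window - 1, 0)
    # As in flash-attn, a soft cap of 0 disables soft capping.
    cap = 0 if logits_soft_cap is None else logits_soft_cap
    # ALiBi slopes, if given, are stored as a float32 tensor.
    slopes = (torch.tensor(alibi_slopes, dtype=torch.float32)
              if alibi_slopes is not None else None)
    return window, cap, slopes

print(normalize_attn_args(sliding_window=128, logits_soft_cap=None, alibi_slopes=None))
# ((127, 0), 0, None)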

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: TritonAttentionMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
    output_block_scale: Optional[Tensor] = None,
) -> Tensor

Forward pass with Paged Attention impl. in Triton.

Parameters:

    query (Tensor, required): shape = [num_tokens, num_heads, head_size]
    key (Tensor, required): shape = [num_tokens, num_kv_heads, head_size]
    value (Tensor, required): shape = [num_tokens, num_kv_heads, head_size]
    kv_cache (Tensor, required): shape = [num_blocks, 2, block_size, num_kv_heads, head_size]
    attn_metadata (TritonAttentionMetadata, required): Metadata for attention.

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm/v1/attention/backends/triton_attn.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with Paged Attention impl. in Triton.

    Args:
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            [num_blocks, 2, block_size, num_kv_heads, head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    assert output is not None, "Output tensor must be provided."

    if output_block_scale is not None:
        raise NotImplementedError(
            "fused block_scale output quantization is not yet supported"
            " for TritonAttentionImpl")

    if attn_metadata is None:
        # Profiling run.
        return output

    assert attn_metadata.use_cascade is False

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(1)

    if self.kv_sharing_target_layer_name is None:
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            # triton kernel does not support uint8 kv_cache
            #  (because some explicit casts (e.g. float8_e4m3fnuz)
            #   are not supported)
        triton_reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    if self.kv_cache_dtype.startswith("fp8"):
        if key_cache.dtype != self.fp8_dtype:
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
        num_tokens, num_heads, head_size = query.shape
        assert layer._q_scale_float == 1.0, \
            "A non 1.0 q_scale is not currently supported."
        if current_platform.is_cuda():
            # Skip Q quantization on ROCm and XPU, enable this on cuda
            # only, since dequantizing back to f32 in the attention kernel
            # is not supported.
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

    cu_seqlens_q = attn_metadata.query_start_loc
    seqused_k = attn_metadata.seq_lens
    max_seqlen_q = attn_metadata.max_query_len
    max_seqlen_k = attn_metadata.max_seq_len
    block_table = attn_metadata.block_table

    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

    unified_attention(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        q_descale=None,  # Not supported
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
        sinks=self.sinks,
        output_scale=output_scale,
    )

    return output
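
A minimal sketch (hypothetical shapes) of how the per-layer K/V scales are broadcast to the per-request descale_shape expected by unified_attention:

import torch

num_kv_heads, head_size = 8, 128
cu_seqlens_q = torch.tensor([0, 1, 5], dtype=torch.int32)  # 2 requests: 1 and 4 query tokens
key = torch.randn(5, num_kv_heads, head_size)              # [num_tokens, num_kv_heads, head_size]

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])  # (num_requests, num_kv_heads)
k_scale = torch.tensor([1.0])                              # layer._k_scale is a one-element tensor
print(k_scale.expand(descale_shape).shape)                 # torch.Size([2, 8])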

fused_output_quant_supported

fused_output_quant_supported(quant_key: QuantKey)
Source code in vllm/v1/attention/backends/triton_attn.py
def fused_output_quant_supported(self, quant_key: QuantKey):
    return quant_key == kFp8StaticTensorSym

TritonAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/triton_attn.py
@dataclass
class TritonAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
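
A minimal sketch (hypothetical values) of the metadata for a batch of two requests: request 0 decodes one new token on top of 7 cached ones, request 1 prefills 4 tokens from scratch, so seq_len = context_len + query_len for each:

import torch

from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata

device = "cpu"  # in a real run these tensors live on the GPU
metadata = TritonAttentionMetadata(
    num_actual_tokens=5,                   # 1 + 4 scheduled tokens, excluding padding
    max_query_len=4,
    query_start_loc=torch.tensor([0, 1, 5], dtype=torch.int32, device=device),
    max_seq_len=8,
    seq_lens=torch.tensor([8, 4], dtype=torch.int32, device=device),
    block_table=torch.tensor([[0], [1]], dtype=torch.int32, device=device),  # one 16-token block each
    slot_mapping=torch.tensor([7, 16, 17, 18, 19], dtype=torch.int64, device=device),
    use_cascade=False,
    common_prefix_len=0,
    cu_prefix_query_lens=None,
    prefix_kv_lens=None,
    suffix_kv_lens=None,
)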

block_table instance-attribute

block_table: Tensor

common_prefix_len instance-attribute

common_prefix_len: int

cu_prefix_query_lens instance-attribute

cu_prefix_query_lens: Optional[Tensor]

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

prefix_kv_lens instance-attribute

prefix_kv_lens: Optional[Tensor]

prefix_scheduler_metadata class-attribute instance-attribute

prefix_scheduler_metadata: Optional[Tensor] = None

query_start_loc instance-attribute

query_start_loc: Tensor

scheduler_metadata class-attribute instance-attribute

scheduler_metadata: Optional[Tensor] = None

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

suffix_kv_lens instance-attribute

suffix_kv_lens: Optional[Tensor]

use_cascade instance-attribute

use_cascade: bool

__init__

__init__(
    num_actual_tokens: int,
    max_query_len: int,
    query_start_loc: Tensor,
    max_seq_len: int,
    seq_lens: Tensor,
    block_table: Tensor,
    slot_mapping: Tensor,
    use_cascade: bool,
    common_prefix_len: int,
    cu_prefix_query_lens: Optional[Tensor],
    prefix_kv_lens: Optional[Tensor],
    suffix_kv_lens: Optional[Tensor],
    scheduler_metadata: Optional[Tensor] = None,
    prefix_scheduler_metadata: Optional[Tensor] = None,
) -> None

TritonAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[TritonAttentionMetadata]

Source code in vllm/v1/attention/backends/triton_attn.py
class TritonAttentionMetadataBuilder(
        AttentionMetadataBuilder[TritonAttentionMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        model_config = vllm_config.model_config
        self.num_heads_q = model_config.get_num_attention_heads(
            vllm_config.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            vllm_config.parallel_config)
        self.headdim = model_config.get_head_size()

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TritonAttentionMetadata:
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens.fill_(1)
        return attn_metadata

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> TritonAttentionMetadata:
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        use_cascade = common_prefix_len > 0

        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                              common_prefix_len)
            suffix_kv_lens = suffix_kv_lens.to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
        # Always defined so the metadata constructor below never sees an unbound name.
        prefix_scheduler_metadata = None

        attn_metadata = TritonAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
        return attn_metadata

block_size instance-attribute

block_size = block_size

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = ALWAYS

headdim instance-attribute

headdim = get_head_size()

num_heads_kv instance-attribute

num_heads_kv = get_num_kv_heads(parallel_config)

num_heads_q instance-attribute

num_heads_q = get_num_attention_heads(parallel_config)

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/triton_attn.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    self.block_size = kv_cache_spec.block_size

    model_config = vllm_config.model_config
    self.num_heads_q = model_config.get_num_attention_heads(
        vllm_config.parallel_config)
    self.num_heads_kv = model_config.get_num_kv_heads(
        vllm_config.parallel_config)
    self.headdim = model_config.get_head_size()

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> TritonAttentionMetadata
Source code in vllm/v1/attention/backends/triton_attn.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> TritonAttentionMetadata:
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    max_query_len = common_attn_metadata.max_query_len

    max_seq_len = common_attn_metadata.max_seq_len
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    block_table_tensor = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping

    use_cascade = common_prefix_len > 0

    if use_cascade:
        cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device=self.device)
        prefix_kv_lens = torch.tensor([common_prefix_len],
                                      dtype=torch.int32,
                                      device=self.device)
        suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                          common_prefix_len)
        suffix_kv_lens = suffix_kv_lens.to(self.device)
    else:
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
    # Always defined so the metadata constructor below never sees an unbound name.
    prefix_scheduler_metadata = None

    attn_metadata = TritonAttentionMetadata(
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        query_start_loc=query_start_loc,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        block_table=block_table_tensor,
        slot_mapping=slot_mapping,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        cu_prefix_query_lens=cu_prefix_query_lens,
        prefix_kv_lens=prefix_kv_lens,
        suffix_kv_lens=suffix_kv_lens,
        prefix_scheduler_metadata=prefix_scheduler_metadata,
    )
    return attn_metadata
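
A minimal sketch (hypothetical values) of the cascade-related fields derived in build(); since use_cascade_attention() returns False for this backend, common_prefix_len is 0 in practice and the else-branch is taken:

import torch

common_prefix_len = 32
seq_lens_cpu = torch.tensor([48, 40], dtype=torch.int32)

use_cascade = common_prefix_len > 0                  # True only if a shared prefix exists
suffix_kv_lens = seq_lens_cpu - common_prefix_len    # per-request KV beyond the shared prefix
print(use_cascade, suffix_kv_lens.tolist())          # True [16, 8]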

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
) -> TritonAttentionMetadata
Source code in vllm/v1/attention/backends/triton_attn.py
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
    attn_metadata = self.build(0, common_attn_metadata)
    # When doing full graph capture, setting seq_lens to
    # max_model_len will cause graph capture to be extremely
    # slow, so here we set it to 1.
    attn_metadata.seq_lens.fill_(1)
    return attn_metadata