vllm.utils.deep_gemm ¶
Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import only these wrappers.
__all__ module-attribute ¶
__all__ = [
"calc_diff",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"fp8_mqa_logits",
"fp8_paged_mqa_logits",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used",
"is_deep_gemm_supported",
"get_num_sms",
"should_use_deepgemm_for_fp8_linear",
"get_col_major_tma_aligned_tensor",
]
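A minimal import sketch following the guidance above; the specific wrappers pulled in are just examples from __all__, and the guard at the end is illustrative rather than required:

```python
# Import through the compatibility layer instead of `import deep_gemm` directly,
# so call sites keep working across upstream DeepGEMM API changes.
from vllm.utils.deep_gemm import (
    fp8_gemm_nt,
    is_deep_gemm_supported,
    per_block_cast_to_fp8,
)

if not is_deep_gemm_supported():
    # Illustrative guard: fall back or raise when DeepGEMM cannot be used.
    raise RuntimeError("DeepGEMM requires a Hopper- or Blackwell-class GPU")
```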
_fp8_paged_mqa_logits_impl module-attribute ¶
_get_mn_major_tma_aligned_tensor_impl module-attribute ¶
_get_paged_mqa_logits_metadata_impl module-attribute ¶
_align ¶
_lazy_init ¶
Import deep_gemm and resolve symbols on first use.
_missing ¶
Placeholder for unavailable DeepGEMM backend.
calc_diff ¶
Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves, this helper can be removed.
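The exact implementation lives in vllm/utils/deep_gemm.py; a minimal sketch of the kind of metric described above (one minus a cosine-style similarity over the flattened tensors) could look like this, with the name cosine_style_diff chosen only for illustration:

```python
import torch

def cosine_style_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # Global similarity over the whole tensor: tolerant of scattered
    # per-element error, unlike torch.testing.assert_close.
    x = x.double().flatten()
    y = y.double().flatten()
    sim = 2 * (x * y).sum() / (x.square().sum() + y.square().sum())
    return float(1 - sim)
```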
fp8_gemm_nt ¶
fp8_m_grouped_gemm_nt_masked ¶
fp8_mqa_logits ¶
fp8_mqa_logits(
q: Tensor,
kv: tuple[Tensor, Tensor],
weights: Tensor,
cu_seqlen_ks: Tensor,
cu_seqlen_ke: Tensor,
) -> Tensor
Compute FP8 MQA logits for a single sequence without KV paging.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Tensor | Query tensor of shape [M, H, D]; cast to FP8 by the caller. | required |
kv | tuple[Tensor, Tensor] | Tuple of the FP8 key/value tensor and its scale tensor. | required |
weights | Tensor | Weights of shape [M, H], dtype float32. | required |
cu_seqlen_ks | Tensor | Start indices (inclusive) for valid K per query position, shape [M], dtype int32. | required |
cu_seqlen_ke | Tensor | End indices (exclusive) for valid K per query position, shape [M], dtype int32. | required |
Returns:
Type | Description |
---|---|
Tensor | Logits tensor of shape [M, N], dtype float32. |
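A hedged usage sketch for fp8_mqa_logits. The kv tuple is assumed here to be (k_fp8 of shape [N, D], k_scales of shape [N]); the exact scale layout and dtypes come from DeepGEMM, not from this sketch, and the sizes are arbitrary:

```python
import torch
from vllm.utils.deep_gemm import fp8_mqa_logits

M, H, D, N = 4, 16, 128, 1024  # queries, heads, head dim, keys (illustrative)

q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)  # assumed layout
k_scales = torch.ones(N, device="cuda", dtype=torch.float32)      # assumed layout
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)

# Every query position attends to the full key range [0, N).
cu_seqlen_ks = torch.zeros(M, device="cuda", dtype=torch.int32)
cu_seqlen_ke = torch.full((M,), N, device="cuda", dtype=torch.int32)

logits = fp8_mqa_logits(q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke)
print(logits.shape)  # [M, N]
```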
fp8_paged_mqa_logits ¶
fp8_paged_mqa_logits(
q_fp8: Tensor,
kv_cache_fp8: Tensor,
weights: Tensor,
context_lens: Tensor,
block_tables: Tensor,
schedule_metadata: Tensor,
max_model_len: int,
) -> Tensor
Compute FP8 MQA logits using paged KV-cache.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q_fp8 | Tensor | Query tensor of shape [B, next_n, H, D]; cast to FP8 by the caller. | required |
kv_cache_fp8 | Tensor | Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4]. | required |
weights | Tensor | Tensor of shape [B * next_n, H], dtype float32. | required |
context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length for each batch element. | required |
block_tables | Tensor | Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. | required |
schedule_metadata | Tensor | Scheduling tensor returned by get_paged_mqa_logits_metadata. | required |
max_model_len | int | Maximum sequence length used to size the logits output. | required |
Returns:
Type | Description |
---|---|
Tensor | Logits tensor of shape [B * next_n, max_model_len], dtype float32. |
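An end-to-end sketch of the paged call, combining it with get_paged_mqa_logits_metadata and get_num_sms (documented below). The sizes, the uint8 dtype used for the packed KV cache, and the zero-argument get_num_sms() call are assumptions for illustration only:

```python
import torch
from vllm.utils.deep_gemm import (
    fp8_paged_mqa_logits,
    get_num_sms,
    get_paged_mqa_logits_metadata,
)

B, next_n, H, D = 2, 1, 16, 128                 # illustrative sizes
block_size, max_model_len = 64, 2048
num_blocks = B * (max_model_len // block_size)  # enough physical blocks

q_fp8 = torch.randn(B, next_n, H, D, device="cuda").to(torch.float8_e4m3fn)
# Packed FP8+scale layout [num_blocks, block_size, 1, D + 4]; uint8 is assumed.
kv_cache_fp8 = torch.zeros(num_blocks, block_size, 1, D + 4,
                           device="cuda", dtype=torch.uint8)
weights = torch.rand(B * next_n, H, device="cuda", dtype=torch.float32)
context_lens = torch.tensor([512, 1024], device="cuda", dtype=torch.int32)
block_tables = torch.arange(num_blocks, device="cuda",
                            dtype=torch.int32).view(B, -1)

schedule_metadata = get_paged_mqa_logits_metadata(
    context_lens, block_size, get_num_sms()
)
logits = fp8_paged_mqa_logits(
    q_fp8, kv_cache_fp8, weights, context_lens, block_tables,
    schedule_metadata, max_model_len,
)
print(logits.shape)  # [B * next_n, max_model_len]
```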
get_col_major_tma_aligned_tensor ¶
Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor.
get_paged_mqa_logits_metadata ¶
Build scheduling metadata for paged MQA logits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length per batch element. | required |
block_size | int | KV-cache block size in tokens (e.g., 64). | required |
num_sms | int | Number of SMs available (e.g., 132 on Hopper). | required |
Returns:
Type | Description |
---|---|
Tensor | Backend-specific tensor consumed by fp8_paged_mqa_logits to schedule work across SMs. |
is_deep_gemm_e8m0_used cached ¶
is_deep_gemm_e8m0_used() -> bool
Return True if vLLM is configured to use DeepGEMM E8M0 scale on a Hopper or Blackwell-class GPU.
is_deep_gemm_supported cached ¶
is_deep_gemm_supported() -> bool
Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.
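A dispatch sketch using the two predicates above; the backend labels and fallback path are invented for illustration and are not part of this module:

```python
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported

def pick_fp8_gemm_backend() -> str:
    # Both helpers are cached, so repeated calls are cheap.
    if not is_deep_gemm_supported():
        return "fallback-fp8"      # e.g. a CUTLASS/Triton FP8 path
    if is_deep_gemm_e8m0_used():
        return "deepgemm-e8m0"     # Hopper/Blackwell with E8M0 scales
    return "deepgemm"
```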
m_grouped_fp8_gemm_nt_contiguous ¶
per_block_cast_to_fp8 ¶
per_block_cast_to_fp8(
x: Tensor,
block_size: list[int] = DEFAULT_BLOCK_SIZE,
use_ue8m0: bool = False,
) -> tuple[Tensor, Tensor]