vllm.model_executor.layers.fla.ops.chunk_scaled_dot_kkt
chunk_scaled_dot_kkt_fwd
chunk_scaled_dot_kkt_fwd(
    k: Tensor,
    beta: Tensor,
    g_cumsum: Optional[Tensor] = None,
    cu_seqlens: Optional[LongTensor] = None,
    chunk_size: int = 64,
    output_dtype: dtype = float32,
) -> Tensor
Compute beta * K * K^T.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
k | Tensor | The key tensor of shape | required |
beta | Tensor | The beta tensor of shape | required |
g_cumsum | Optional[Tensor] | The cumulative sum of the gate tensor of shape | None |
cu_seqlens | Optional[LongTensor] | The cumulative sequence lengths of the input tensor. | None |
chunk_size | int | The chunk size. | 64 |
output_dtype | dtype | The dtype of the output tensor. | float32 |
Returns:
Type | Description |
---|---|
Tensor | beta * K * K^T of shape |
Source code in vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
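To make "beta * K * K^T" concrete, here is a naive PyTorch sketch of the per-chunk computation. It is not the Triton implementation, and the name `chunk_scaled_dot_kkt_ref` is hypothetical. Several details are assumptions, because the page above does not state them: `k` is taken to have shape `[B, T, H, K]`, `beta` and `g_cumsum` shape `[B, T, H]`, the output shape `[B, T, H, chunk_size]`, the gate is assumed to enter as a factor `exp(g_i - g_j)`, and each chunk block is assumed to be masked to its strictly lower triangle. Variable-length input (`cu_seqlens`) and grouped key heads (`Hg`) are left out.

```python
import torch


def chunk_scaled_dot_kkt_ref(k, beta, g_cumsum=None, chunk_size=64):
    """Naive reference sketch: per-chunk beta * K @ K^T.

    Assumed shapes (not confirmed by the documentation above):
      k:        [B, T, H, K]
      beta:     [B, T, H]
      g_cumsum: [B, T, H] or None
    Returns a float32 tensor of assumed shape [B, T, H, chunk_size].
    """
    B, T, H, _ = k.shape
    out = torch.zeros(B, T, H, chunk_size, dtype=torch.float32, device=k.device)
    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        n = end - start  # the last chunk may be shorter than chunk_size
        k_c = k[:, start:end].float().permute(0, 2, 1, 3)     # [B, H, n, K]
        beta_c = beta[:, start:end].float().permute(0, 2, 1)  # [B, H, n]
        # Row i of each block is scaled by beta_i: A[i, j] = beta_i * <k_i, k_j>.
        a = beta_c.unsqueeze(-1) * (k_c @ k_c.transpose(-1, -2))  # [B, H, n, n]
        if g_cumsum is not None:
            g_c = g_cumsum[:, start:end].float().permute(0, 2, 1)  # [B, H, n]
            # Assumption: gating multiplies entry (i, j) by exp(g_i - g_j).
            a = a * torch.exp(g_c.unsqueeze(-1) - g_c.unsqueeze(-2))
        # Assumption: only entries with i > j inside a chunk are kept.
        a = a.tril(diagonal=-1)
        out[:, start:end, :, :n] = a.permute(0, 2, 1, 3)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    k = torch.randn(2, 10, 3, 4)
    beta = torch.rand(2, 10, 3)
    A = chunk_scaled_dot_kkt_ref(k, beta, chunk_size=4)
    print(A.shape)
```

Under these assumptions, the last dimension of the result indexes positions within the same chunk, so row `t` only holds products with keys from the chunk that contains `t`. A sketch like this can serve as a slow comparison baseline when testing a fused kernel on small inputs.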
chunk_scaled_dot_kkt_fwd_kernel
chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: constexpr,
    Hg: constexpr,
    K: constexpr,
    BT: constexpr,
    BK: constexpr,
    IS_VARLEN: constexpr,
    USE_G: constexpr,
)