vllm.v1.attention.backends.mla.common
MLA Common Components
This file implements common components for MLA implementations.
First we define:
Sq  as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing: a data-movement friendly approach and a compute friendly approach. We generally want to use the compute friendly approach for "prefill" (i.e. the ratio Sq / Skv is "large", near 1) and the data-movement friendly approach for "decode" (i.e. the ratio Sq / Skv is "small", since Sq is typically 1 while Skv spans the whole context).
NOTE: what we deem small and large is currently determined by whether the scheduler labels a request as prefill or decode, but this threshold should probably be tuned.
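To make the ratio concrete, the sketch below computes Sq / Skv for a fresh prefill and for a single-token decode step, with Skv = C + Sq. The `pick_path` helper and its 0.5 threshold are hypothetical. They stand in for the tuning mentioned in the note; vLLM itself currently dispatches on the scheduler's prefill/decode label.

```python
def sq_over_skv(num_new_tokens: int, num_cached_tokens: int) -> float:
    """Ratio Sq / Skv for one request, with Skv = C + Sq."""
    skv = num_cached_tokens + num_new_tokens
    return num_new_tokens / skv


def pick_path(ratio: float, threshold: float = 0.5) -> str:
    """Hypothetical ratio-based dispatch; the threshold is a placeholder, not a tuned value."""
    if ratio >= threshold:
        return "compute_friendly"        # prefill-like: Sq close to Skv
    return "data_movement_friendly"      # decode-like: Sq much smaller than Skv


prefill_ratio = sq_over_skv(num_new_tokens=2048, num_cached_tokens=0)
decode_ratio = sq_over_skv(num_new_tokens=1, num_cached_tokens=4095)
print(prefill_ratio, pick_path(prefill_ratio))  # 1.0 compute_friendly
print(decode_ratio, pick_path(decode_ratio))    # 0.000244140625 data_movement_friendly
```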
Main references: the DeepSeek-V2 paper (https://arxiv.org/abs/2405.04434) and the FlashInfer implementation (https://github.com/flashinfer-ai/flashinfer/pull/551).
DeepSeek's MLA attention works as follows:
* A single latent vector represents the per-token entry of the KV cache.
* For decode (i.e. the data-movement friendly approach), the attention "simulates" multi-head attention, while the compute is similar to multi-query attention.
Below is an example of both paths, assuming batch size = 1.
More Extent Definitions:
C   Context length, Skv - Sq
H   hidden size
N   number of attention heads
Lq  latent dimension for Q, 1536 in DSV3
Lkv latent dimension for K/V, 512 in DSV3
P   nope dimension, no rope, 128 in DSV3
R   rope dimension, goes through rope, 64 in DSV3
V   V head dim, 128 in DSV3
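Plugging the DSV3 values above into a quick count shows why caching the latent is attractive. The sketch counts cached elements per token per layer. The MHA figure is for a hypothetical standard MHA layer with the same N, P, R and V, an assumption used only for comparison.

```python
N, Lkv, P, R, V = 128, 512, 128, 64, 128  # DSV3 values listed above


def mla_cache_elems_per_token(Lkv: int, R: int) -> int:
    # MLA caches the latent kv_c plus one rope key k_pe shared by all heads
    return Lkv + R


def mha_cache_elems_per_token(N: int, P: int, R: int, V: int) -> int:
    # A standard MHA layer caches a (P + R)-dim key and a V-dim value per head
    return N * (P + R) + N * V


mla = mla_cache_elems_per_token(Lkv, R)
mha = mha_cache_elems_per_token(N, P, R, V)
print(mla, mha, round(mha / mla, 1))  # 576 40960 71.1
```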
Vector/Matrix Definitions
Name | Description | Shape |
---|---|---|
h_t | hidden states (input to attention) | [Sq, H] |
q_c | latent/compressed Q | [Sq, Lq] |
q_nope | uncompressed Q (no-rope) | [Sq, N, P] |
q_pe | uncompressed Q (rope) | [Sq, N, R] |
kv_c | latent/compressed KV | [Skv, Lkv] |
k_pe | decoupled k position embeddings | [Skv, R] |
new_kv_c | new kv_c from current iter | [Sq, Lkv] |
new_k_pe | new k_pe from current iter | [Sq, R] |
cache_kv_c | cached kv_c from previous iters | [C, Lkv] |
cache_k_pe | cached k_pe from previous iters | [C, R] |
W_DQ | project h_t to q_c | [H, Lq] |
W_UQ | project q_c to q_nope | [Lq, N * P] |
W_QR | project q_c to q_pe | [Lq, N * R] |
W_DKV | project h_t to kv_c | [H, Lkv] |
W_UK | project kv_c to k_nope | [Lkv, N, P] |
W_KR | project h_t to k_pe | [H, R] |
W_UV | project kv_c to v | [Lkv, N, V] |
W_O | project v to h_t | [N * V, H] |
Compute Friendly Approach (i.e. "_forward_prefill"):
q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

# MHA with QK headdim = P + R
#           V headdim = V
# sdpa_o shape [Sq, N, V]
# (causal masking among the new tokens is omitted for brevity)
sdpa_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
)
return sdpa_o.reshape(Sq, N * V) @ W_O
In the actual code:
kv_b_proj is [W_UK; W_UV] concatenated per head
q_b_proj is [W_UQ; W_QR] concatenated per head
out_proj is W_O
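The sketch below shows in pure Python what "concatenated per head" means for kv_b_proj. It assumes the weight has been transposed to [Lkv, N * (P + V)], with each head's P key columns immediately followed by its V value columns. The helper name `split_kv_b_proj` is hypothetical; it mimics a view to [Lkv, N, P + V] followed by a split into [P, V] on the last dim.

```python
def split_kv_b_proj(weight, N, P, V):
    """weight: [Lkv][N * (P + V)], columns grouped per head as [W_UK_n | W_UV_n].

    Returns W_UK as [Lkv][N][P] and W_UV as [Lkv][N][V] nested lists.
    """
    W_UK, W_UV = [], []
    for row in weight:
        assert len(row) == N * (P + V)
        # view(Lkv, N, P + V): each head owns a contiguous block of P + V columns
        heads = [row[n * (P + V):(n + 1) * (P + V)] for n in range(N)]
        # split([P, V], dim=-1): first P columns are keys, the remaining V are values
        W_UK.append([h[:P] for h in heads])
        W_UV.append([h[P:] for h in heads])
    return W_UK, W_UV


# Lkv = 1, N = 2, P = 2, V = 1: columns are [uk, uk, uv] for head 0, then head 1
W_UK, W_UV = split_kv_b_proj([[10, 11, 12, 20, 21, 22]], N=2, P=2, V=1)
print(W_UK)  # [[[10, 11], [20, 21]]]
print(W_UV)  # [[[12], [22]]]
```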
Data-Movement Friendly Approach (i.e. "_forward_decode"):
Runtime
q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)

# MQA with QK headdim = Lkv + R
#           V headdim = Lkv
# sdpa_o shape [Sq, N, Lkv]
# NOTE: this is less compute-friendly since Lkv > P,
# but is more data-movement friendly since it's MQA vs MHA
sdpa_o = scaled_dot_product_attention(
    torch.cat([ql_nope, q_pe], dim=-1),
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)

o = einsum("snl,lnv->snv", sdpa_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
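Both paths compute the same attention output. For each head, q_nope · (kv_c @ W_UK) equals (q_nope @ W_UK^T) · kv_c, and applying W_UV after the attention-weighted sum of kv_c equals weighting the up-projected values directly. The pure-Python sketch below checks this numerically for one query token (Sq = 1) with tiny random dimensions. It uses nested lists instead of torch tensors, and the helper names (`mha_path`, `mqa_latent_path`, `rand_tensor`) are hypothetical. Both paths use the same softmax scale, 1 / sqrt(P + R). The absorption only holds if the latent path keeps that scale rather than one derived from Lkv + R.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def rand_tensor(rng, *shape):
    # Nested lists of uniform random floats with the given shape
    if len(shape) == 1:
        return [rng.uniform(-1.0, 1.0) for _ in range(shape[0])]
    return [rand_tensor(rng, *shape[1:]) for _ in range(shape[0])]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def mha_path(q_nope, q_pe, kv_c, k_pe, W_UK, W_UV, scale):
    """Compute-friendly path: up-project kv_c into per-head k_nope and v, then MHA."""
    N, P, V, Lkv = len(q_nope), len(q_nope[0]), len(W_UV[0][0]), len(W_UK)
    out = []
    for n in range(N):
        scores, values = [], []
        for t in range(len(kv_c)):
            k_nope = [sum(kv_c[t][l] * W_UK[l][n][p] for l in range(Lkv)) for p in range(P)]
            v = [sum(kv_c[t][l] * W_UV[l][n][j] for l in range(Lkv)) for j in range(V)]
            scores.append(scale * (dot(q_nope[n], k_nope) + dot(q_pe[n], k_pe[t])))
            values.append(v)
        attn = softmax(scores)
        out.append([sum(attn[t] * values[t][j] for t in range(len(kv_c))) for j in range(V)])
    return out


def mqa_latent_path(q_nope, q_pe, kv_c, k_pe, W_UK, W_UV, scale):
    """Data-movement-friendly path: absorb W_UK into q and W_UV into the output."""
    N, P, V, Lkv = len(q_nope), len(q_nope[0]), len(W_UV[0][0]), len(W_UK)
    out = []
    for n in range(N):
        # ql_nope = einsum("snh,lnh->snl", q_nope, W_UK) for one token and one head
        ql_nope = [sum(q_nope[n][p] * W_UK[l][n][p] for p in range(P)) for l in range(Lkv)]
        # Every head attends over the same latent keys [kv_c | k_pe] (MQA)
        scores = [scale * (dot(ql_nope, kv_c[t]) + dot(q_pe[n], k_pe[t])) for t in range(len(kv_c))]
        attn = softmax(scores)
        latent_o = [sum(attn[t] * kv_c[t][l] for t in range(len(kv_c))) for l in range(Lkv)]
        # o = einsum("snl,lnv->snv", latent_o, W_UV)
        out.append([sum(latent_o[l] * W_UV[l][n][j] for l in range(Lkv)) for j in range(V)])
    return out


rng = random.Random(0)
N, P, R, V, Lkv, Skv = 2, 3, 2, 3, 4, 5
q_nope, q_pe = rand_tensor(rng, N, P), rand_tensor(rng, N, R)
kv_c, k_pe = rand_tensor(rng, Skv, Lkv), rand_tensor(rng, Skv, R)
W_UK, W_UV = rand_tensor(rng, Lkv, N, P), rand_tensor(rng, Lkv, N, V)
scale = 1.0 / math.sqrt(P + R)  # same scale on both paths

o_mha = mha_path(q_nope, q_pe, kv_c, k_pe, W_UK, W_UV, scale)
o_mqa = mqa_latent_path(q_nope, q_pe, kv_c, k_pe, W_UK, W_UV, scale)
max_diff = max(abs(a - b) for ra, rb in zip(o_mha, o_mqa) for a, b in zip(ra, rb))
print(max_diff)
```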
Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We assume a sufficiently large Sq / Skv ratio; in the future we may want to switch to the data-movement friendly approach if the chunk (i.e. Sq) is small.
However, the compute-friendly approach can run out of memory when Skv is large, because it materializes the up-projected keys (and values) for every head over the whole context: k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
To mitigate this, we chunk the attention computation over the cached context (i.e. cache_kv_c and cache_k_pe) so that we can use a fixed workspace size.
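A rough byte count shows the scale of the problem. The sketch assumes 2-byte (bf16) elements and the DSV3 dimensions listed earlier. The 128K-token context and the 8192-token chunk are arbitrary illustrative values, not vLLM defaults. It compares three cases: materializing k_nope and v for the whole context, storing only kv_c and k_pe, and up-projecting a single context chunk.

```python
def materialized_kv_bytes(num_tokens, N, P, V, bytes_per_elem=2):
    # Per-head k_nope [num_tokens, N, P] plus v [num_tokens, N, V]
    return num_tokens * N * (P + V) * bytes_per_elem


def latent_kv_bytes(num_tokens, Lkv, R, bytes_per_elem=2):
    # Latent kv_c [num_tokens, Lkv] plus shared k_pe [num_tokens, R]
    return num_tokens * (Lkv + R) * bytes_per_elem


GiB, MiB = 1024 ** 3, 1024 ** 2
Skv = 128 * 1024
full = materialized_kv_bytes(Skv, N=128, P=128, V=128)
latent = latent_kv_bytes(Skv, Lkv=512, R=64)
per_chunk = materialized_kv_bytes(8192, N=128, P=128, V=128)  # one 8192-token chunk
print(full / GiB, latent / MiB, per_chunk / MiB)  # 8.0 144.0 512.0
```

Chunking bounds the transient up-projection by the chunk length instead of Skv, at the cost of one extra attention call and merge per chunk.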
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage
q_c        = h_t @ W_DQ
q_nope     = (q_c @ W_UQ).view(Sq, N, P)
q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c   = h_t @ W_DKV
new_k_pe   = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)

# MHA between queries and new KV
#     with QK headdim = P + R
#           V headdim = V
#    curr_o   shape [Sq, N, V]
#    curr_lse shape [N, Sq] (this is just the order FA returns it in)
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    causal=True,
    return_softmax_lse=True
)

# Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start = chunk_idx * MCC
    chunk_end   = min(chunk_start + MCC, C)
    Sc          = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK.view(Lkv, N * P)).view(Sc, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV.view(Lkv, N * V)).view(Sc, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                  dim=-1),
        cache_v_chunk,
        causal=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o.reshape(Sq, N * V) @ W_O
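The merge step relies on the log-sum-exp (LSE) returned alongside each partial output. Two softmax-attention results over disjoint key sets can be combined exactly by reweighting each one by exp(lse_part - lse_total). Below is a minimal pure-Python sketch for one head and one query with no causal mask. It illustrates the idea behind merge_attn_states in the loop above and is not its actual implementation. A real kernel works on batched tensors and would also need to handle edge cases such as a key set with no valid entries.

```python
import math


def attend(scores, values):
    """Softmax attention for one query: returns (output, log-sum-exp of scores)."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    weights = [math.exp(s - lse) for s in scores]
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return out, lse


def merge_states(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """Combine two partial attention results computed over disjoint key sets."""
    m = max(prefix_lse, suffix_lse)
    lse = m + math.log(math.exp(prefix_lse - m) + math.exp(suffix_lse - m))
    wp, ws = math.exp(prefix_lse - lse), math.exp(suffix_lse - lse)
    out = [wp * a + ws * b for a, b in zip(prefix_out, suffix_out)]
    return out, lse


scores = [0.3, -1.2, 2.0, 0.7, -0.4]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]

full_out, full_lse = attend(scores, values)
# cached context (prefix) = first 3 keys, new tokens (suffix) = last 2 keys
p_out, p_lse = attend(scores[:3], values[:3])
s_out, s_lse = attend(scores[3:], values[3:])
merged_out, merged_lse = merge_states(p_out, p_lse, s_out, s_lse)
print(merged_out, full_out)
```

Because the merged LSE is carried forward, the same merge can be applied once per context chunk, which is what keeps the workspace bounded.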
FLASHINFER_WORKSPACE_BUFFER_SIZE module-attribute
CudnnPrefillMetadata dataclass
Bases: MLACommonPrefillMetadata
Source code in vllm/v1/attention/backends/mla/common.py
ChunkedContextMetadata dataclass
FlashInferPrefillMetadata dataclass
Bases: MLACommonPrefillMetadata
MLACommonBackend
Bases: AttentionBackend
get_builder_cls staticmethod
get_builder_cls() -> type[MLACommonMetadataBuilder]
get_kv_cache_shape staticmethod
get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]
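MLA caches one latent vector per token rather than per-head K and V, so a natural layout is one row of Lkv + R values per cache slot. The sketch below assumes the shape (num_blocks, block_size, head_size), with head_size = Lkv + R = 576 for DSV3 and num_kv_heads unused. It also shows how a token position maps to a flat slot index through a block table. Both helpers are hypothetical illustrations, not the backend's actual code.

```python
def mla_kv_cache_shape(num_blocks: int, block_size: int, head_size: int) -> tuple[int, ...]:
    # Assumed layout: one latent row of head_size = Lkv + R per cached token
    return (num_blocks, block_size, head_size)


def slot_for_position(block_table: list[int], pos: int, block_size: int) -> int:
    """Flat slot of token `pos`, given the physical block ids of its request."""
    physical_block = block_table[pos // block_size]
    return physical_block * block_size + pos % block_size


shape = mla_kv_cache_shape(num_blocks=1024, block_size=16, head_size=512 + 64)
slot = slot_for_position([7, 3, 42], pos=20, block_size=16)
print(shape, slot)  # (1024, 16, 576) 52
```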
get_metadata_cls staticmethod
get_metadata_cls() -> type[AttentionMetadata]
get_supported_dtypes classmethod
get_supported_head_sizes classmethod
validate_head_size classmethod
validate_head_size(head_size: int) -> None
MLACommonBaseImpl
Bases: MLAAttentionImpl[A], Generic[A]
NOTE: Please read the comment at the top of the file before trying to understand this class
__init__
__init__(
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
indexer=None,
q_pad_num_heads: Optional[int] = None,
) -> None
_v_up_proj
process_weights_after_loading
process_weights_after_loading(act_dtype: dtype)
MLACommonDecodeMetadata dataclass
MLACommonImpl
Bases: MLACommonBaseImpl[M], Generic[M]
NOTE: Please read the comment at the top of the file before trying to understand this class
_run_prefill_context_chunk instance-attribute
chunked_prefill_workspace_size instance-attribute
chunked_prefill_workspace_size = (
determine_chunked_prefill_workspace_size(
get_current_vllm_config()
)
)
__init__
_compute_prefill_context
_compute_prefill_context(
q: Tensor,
kv_c_and_k_pe_cache: Tensor,
attn_metadata: MLACommonMetadata,
k_scale: Tensor,
)
_context_parallel_compute_prefill_context
_context_parallel_compute_prefill_context(
q: Tensor,
kv_c_and_k_pe_cache: Tensor,
attn_metadata: MLACommonMetadata,
k_scale: Tensor,
dcp_world_size: int,
)
_flash_attn_varlen_diff_headdims
_flash_attn_varlen_diff_headdims(
q,
k,
v,
return_softmax_lse=False,
softmax_scale=None,
**kwargs,
)
_forward_decode abstractmethod
_forward_decode(
q: Union[Tensor, tuple[Tensor, Tensor]],
kv_c_and_k_pe_cache: Tensor,
attn_metadata: M,
layer: AttentionLayer,
) -> tuple[Tensor, Optional[Tensor]]
_forward_prefill
_forward_prefill(
q: Tensor,
kv_c_normed: Tensor,
k_pe: Tensor,
kv_c_and_k_pe_cache: Tensor,
attn_metadata: MLACommonMetadata,
k_scale: Tensor,
) -> Tensor
_run_prefill_context_chunk_cudnn
_run_prefill_context_chunk_cudnn(
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
)
_run_prefill_context_chunk_fa
_run_prefill_context_chunk_fa(
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
)
_run_prefill_context_chunk_fi
_run_prefill_context_chunk_fi(
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
)
_run_prefill_new_tokens_cudnn
_run_prefill_new_tokens_cudnn(
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
)
_run_prefill_new_tokens_fa
_run_prefill_new_tokens_fa(
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
)
_run_prefill_new_tokens_fi
_run_prefill_new_tokens_fi(
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
)
forward
forward(
layer: AttentionLayer,
q: Tensor,
k_c_normed: Tensor,
k_pe: Tensor,
kv_cache: Tensor,
attn_metadata: M,
output: Optional[Tensor] = None,
output_scale: Optional[Tensor] = None,
output_block_scale: Optional[Tensor] = None,
) -> Tensor
process_weights_after_loading
process_weights_after_loading(act_dtype: dtype)
MLACommonMetadata dataclass
Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to understand this class
prefill class-attribute instance-attribute
prefill: Optional[
Union[
MLACommonPrefillMetadata,
FlashInferPrefillMetadata,
CudnnPrefillMetadata,
]
] = None
__init__
__init__(
num_reqs: int,
max_query_len: int,
max_seq_len: int,
num_actual_tokens: int,
query_start_loc: Tensor,
slot_mapping: Tensor,
num_decodes: int,
num_decode_tokens: int,
num_prefills: int,
head_dim: Optional[int] = None,
decode: Optional[D] = None,
prefill: Optional[
Union[
MLACommonPrefillMetadata,
FlashInferPrefillMetadata,
CudnnPrefillMetadata,
]
] = None,
) -> None
MLACommonMetadataBuilder
Bases: AttentionMetadataBuilder[M]
NOTE: Please read the comment at the top of the file before trying to understand this class
_fi_prefill_chunks instance-attribute
_fi_prefill_chunks: list[
BatchPrefillWithRaggedKVCacheWrapper
] = []
_fi_prefill_main instance-attribute
_fi_prefill_main: Optional[
BatchPrefillWithRaggedKVCacheWrapper
] = None
_global_hyperparameters instance-attribute
_global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(
vllm_config, layer_names, MLACommonImpl
)
)
_workspace_buffer instance-attribute
_workspace_buffer = empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=uint8,
device=device,
)
chunked_prefill_workspace instance-attribute
chunked_prefill_workspace = empty(
(
chunked_prefill_workspace_size
+ chunked_prefill_workspace_size // dcp_world_size,
get_head_size(),
),
dtype=dtype,
device=device,
)
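For a sense of scale, the arithmetic below evaluates the allocation shape shown above. The workspace size of 65536 rows, head_size = 576 and the 2-byte dtype are assumed values for illustration only. The real size comes from determine_chunked_prefill_workspace_size and the model config.

```python
def chunked_prefill_workspace_bytes(workspace_size, dcp_world_size, head_size, bytes_per_elem=2):
    # Row count follows the expression above: size + size // dcp_world_size
    rows = workspace_size + workspace_size // dcp_world_size
    return rows * head_size * bytes_per_elem


MiB = 1024 ** 2
for dcp in (1, 2, 4):
    # prints 144.0, 108.0 and 90.0 MiB respectively
    print(dcp, chunked_prefill_workspace_bytes(65536, dcp, 576) / MiB)
```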
chunked_prefill_workspace_size instance-attribute
cudnn_workspace instance-attribute
cudnn_workspace = empty(
CUDNN_WORKSPACE_SIZE * max_num_seqs,
dtype=int8,
device=device,
)
metadata_cls instance-attribute
metadata_cls = (
metadata_cls
if metadata_cls is not None
else MLACommonMetadata
)
prefill_metadata_cls instance-attribute
prefill_metadata_cls = (
FlashInferPrefillMetadata
if _use_fi_prefill
else CudnnPrefillMetadata
if _use_cudnn_prefill
else MLACommonPrefillMetadata
)
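The nested conditional above gives FlashInfer prefill precedence over cuDNN, with MLACommonPrefillMetadata as the fallback. The hypothetical standalone helper below uses the equivalent if/elif form, returning class names as strings, to make that precedence explicit.

```python
def pick_prefill_metadata_cls(use_fi_prefill: bool, use_cudnn_prefill: bool) -> str:
    # Same precedence as the chained conditional expression
    if use_fi_prefill:
        return "FlashInferPrefillMetadata"
    if use_cudnn_prefill:
        return "CudnnPrefillMetadata"
    return "MLACommonPrefillMetadata"


# FlashInfer wins even when cuDNN prefill is also enabled
print(pick_prefill_metadata_cls(True, True))  # FlashInferPrefillMetadata
```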
__init__
__init__(
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: device,
metadata_cls: Optional[type[M]] = None,
)
_build_decode
_build_decode(
block_table_tensor: Tensor,
seq_lens_cpu: Tensor,
seq_lens_device: Tensor,
query_start_loc_cpu: Tensor,
query_start_loc_device: Tensor,
num_decode_tokens: int,
) -> MLACommonDecodeMetadata
_build_fi_prefill_wrappers
_build_fi_prefill_wrappers(
prefill: FlashInferPrefillMetadata,
)
build
build(
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> M
build_for_cudagraph_capture
build_for_cudagraph_capture(
common_attn_metadata: CommonAttentionMetadata,
) -> M
This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA.
determine_chunked_prefill_workspace_size staticmethod
determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig,
) -> int
MLACommonPrefillMetadata dataclass
Prefill Specific Metadata
chunked_context class-attribute instance-attribute
chunked_context: Optional[ChunkedContextMetadata] = None
ChunkedContextMetadata dataclass
cp_chunk_seq_lens class-attribute instance-attribute
cu_seq_lens_lst class-attribute instance-attribute
origin_context_lens class-attribute instance-attribute
__init__
__init__(
cu_seq_lens: Tensor,
starts: Tensor,
seq_tot: list[int],
max_seq_lens: list[int],
seq_lens: Tensor,
workspace: Tensor,
cp_chunk_seq_lens: Optional[list[list[int]]] = None,
origin_context_lens: Optional[list[int]] = None,
cp_cu_seq_lens: Optional[Tensor] = None,
chunk_size: Optional[int] = None,
cu_seq_lens_lst: Optional[list[list[int]]] = None,
) -> None
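One plausible reading of how the per-chunk fields fit together, for a batch of prefill requests with different cached-context lengths, goes as follows. Every request's context is cut at the same chunk boundaries. A request whose context ends early contributes zero tokens to later chunks. cu_seq_lens is the exclusive prefix sum of the per-request chunk lengths. This pure-Python sketch is an assumption-based illustration, and the helper name and dict keys are hypothetical. The builder works on tensors, and the chunk size there is computed dynamically.

```python
from itertools import accumulate


def chunk_layout(context_lens: list[int], max_context_chunk: int) -> dict[str, list]:
    """Hypothetical per-chunk layout for a batch of prefill requests."""
    num_chunks = -(-max(context_lens, default=0) // max_context_chunk)  # ceil division
    layout = {"starts": [], "seq_lens": [], "cu_seq_lens": [], "seq_tot": [], "max_seq_lens": []}
    for chunk_idx in range(num_chunks):
        starts = [chunk_idx * max_context_chunk] * len(context_lens)
        ends = [min(c, s + max_context_chunk) for c, s in zip(context_lens, starts)]
        seq_lens = [max(e - s, 0) for s, e in zip(starts, ends)]  # 0 once a context is exhausted
        cu_seq_lens = [0, *accumulate(seq_lens)]
        layout["starts"].append(starts)
        layout["seq_lens"].append(seq_lens)
        layout["cu_seq_lens"].append(cu_seq_lens)
        layout["seq_tot"].append(cu_seq_lens[-1])
        layout["max_seq_lens"].append(max(seq_lens))
    return layout


layout = chunk_layout([10, 3, 6], max_context_chunk=4)
print(layout["seq_lens"])     # [[4, 3, 4], [4, 0, 2], [2, 0, 0]]
print(layout["cu_seq_lens"])  # [[0, 4, 7, 11], [0, 4, 4, 6], [0, 2, 2, 2]]
print(layout["seq_tot"])      # [11, 6, 2]
```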
dynamic_per_batched_tensor_quant
reorg_kvcache
reorg_kvcache(
allgatered_kv_c_normed: Tensor,
allgatered_k_pe: Tensor,
cp_chunk_seq_lens_lst: list[int],
origin_context_lens: list[int],
cp_world_size: int,
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[Tensor, Tensor]
Reorganize the KV cache, after the CP local gather, into the TP layout expected by the attention kernel.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cp_chunk_seq_lens_lst | list[int] | chunk context lengths under CP. | required |
origin_context_lens | list[int] | original full context lengths under CP. | required |
cp_world_size | int | CP size. | required |
sum_seq_len | int | the sum of cp_chunk_seq_lens_lst. | required |
max_seq_len | int | the max value of cp_chunk_seq_lens_lst. | required |
chunk_size | int | equal to max_context_chunk from the chunked_context_metadata build. | required |
chunk_idx | int | index of the current chunk in chunked prefill. | required |
toks | int | the number of tokens in the locally gathered cache. | required |