def rocm_aiter_rotary_emb(positions: torch.Tensor, query: torch.Tensor,
                          key: torch.Tensor, cos_sin_cache: torch.Tensor,
                          head_size: int, rotary_dim: int,
                          is_neox_style: bool):
    """Apply rotary position embedding to ``query`` and ``key`` via AITER.

    Both tensors are viewed as ``(num_tokens, -1, head_size)``. The leading
    ``rotary_dim`` channels of every head are passed to the custom op
    ``torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton``.

    The op's return value is discarded, so this function relies on the op
    writing its result into those views in place. NOTE(review): confirm this
    against the op's registration. Nothing is returned.

    Args:
        positions: Token positions. ``positions.numel()`` is used as the
            token count, and the tensor is flattened to 1-D before the call.
        query: Query tensor. Its element count must be divisible by
            ``num_tokens * head_size`` for the view to succeed.
        key: Key tensor, with the same divisibility requirement as ``query``.
        cos_sin_cache: Split into two halves along its last dimension. The
            first half is used as cos and the second half as sin. The op
            receives sin before cos.
        head_size: Size of the trailing dimension each head is viewed with.
        rotary_dim: Number of leading channels per head handed to the op.
        is_neox_style: Selects rotate style ``0`` when true, ``1`` otherwise.
    """
    n_tokens = positions.numel()
    cos_half, sin_half = cos_sin_cache.chunk(2, dim=-1)
    original_q_shape, original_k_shape = query.shape, key.shape

    # Integer style flag forwarded to the op; the meaning of 0/1 is defined
    # by the op itself (presumably NeoX vs GPT-J rotation -- confirm).
    if is_neox_style:
        style = 0
    else:
        style = 1

    q_heads = query.view(n_tokens, -1, head_size)
    k_heads = key.view(n_tokens, -1, head_size)

    # Basic slicing keeps these as views sharing storage with the inputs, so
    # in-place writes by the op land in the caller's tensors.
    q_rot = q_heads[:, :, :rotary_dim]
    k_rot = k_heads[:, :, :rotary_dim]

    pos_flat = positions.view(n_tokens)

    torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
        pos_flat,
        sin_half,
        cos_half,
        q_rot,
        k_rot,
        style,
        False,
    )

    # NOTE(review): these rebinds restore the original shapes but are neither
    # returned nor used afterwards in the visible code; they do not affect
    # the caller's tensors.
    query = q_heads.view(original_q_shape)
    key = k_heads.view(original_k_shape)