Skip to content

vllm.model_executor.layers.fla.ops.index

prepare_chunk_indices

prepare_chunk_indices(
    cu_seqlens: LongTensor, chunk_size: int
) -> LongTensor
Source code in vllm/model_executor/layers/fla/ops/index.py
@tensor_cache
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    """Build one ``(group, chunk)`` index row per chunk.

    Each length from ``prepare_lens(cu_seqlens)`` is ceil-divided by
    ``chunk_size`` to get a per-sequence chunk count, and a ``0..count-1``
    range is emitted for every count.

    Args:
        cu_seqlens: Boundary tensor whose adjacent differences are the
            sequence lengths (presumably cumulative offsets starting at 0;
            confirm against callers).
        chunk_size: Divisor used to split each sequence into chunks.

    Returns:
        A two-column tensor with the dtype and device of ``cu_seqlens``.
        Column 1 is the chunk index within its range; column 0 counts how
        many ranges have started so far, minus one.

        NOTE(review): column 0 equals the sequence index only when every
        sequence yields at least one chunk; a zero-length sequence emits no
        rows and is therefore not counted.
    """
    seq_lens = prepare_lens(cu_seqlens)
    # .tolist() turns the counts into Python ints so arange can use them.
    chunks_per_seq = triton.cdiv(seq_lens, chunk_size).tolist()
    per_seq_ranges = [torch.arange(count) for count in chunks_per_seq]
    chunk_ids = torch.cat(per_seq_ranges)
    # Every range restarts at 0, so a running count of zeros (minus one)
    # labels each row with the ordinal of the range it belongs to.
    group_ids = chunk_ids.eq(0).cumsum(0) - 1
    pairs = torch.stack([group_ids, chunk_ids], 1)
    # Tensor.to(other) adopts both the dtype and the device of cu_seqlens.
    return pairs.to(cu_seqlens)

prepare_chunk_offsets

prepare_chunk_offsets(
    cu_seqlens: LongTensor, chunk_size: int
) -> LongTensor
Source code in vllm/model_executor/layers/fla/ops/index.py
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    """Return running totals of per-sequence chunk counts, led by zero.

    Args:
        cu_seqlens: Boundary tensor whose adjacent differences are the
            sequence lengths (presumably cumulative offsets starting at 0;
            confirm against callers).
        chunk_size: Divisor used to split each sequence into chunks.

    Returns:
        A tensor with one more entry than there are sequences. Entry ``i``
        is the total number of chunks in the first ``i`` sequences, so the
        first entry is 0 and the last is the overall chunk count.
    """
    # new_tensor keeps the leading zero on cu_seqlens' dtype and device so
    # it can be concatenated with the counts.
    leading_zero = cu_seqlens.new_tensor([0])
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    padded_counts = torch.cat([leading_zero, chunks_per_seq])
    return padded_counts.cumsum(-1)

prepare_lens

prepare_lens(cu_seqlens: LongTensor) -> LongTensor
Source code in vllm/model_executor/layers/fla/ops/index.py
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return the differences between adjacent entries of ``cu_seqlens``.

    Args:
        cu_seqlens: Boundary tensor (presumably cumulative sequence
            offsets; confirm against callers).

    Returns:
        A tensor one element shorter than ``cu_seqlens`` along the first
        axis, where entry ``i`` is ``cu_seqlens[i + 1] - cu_seqlens[i]``.
    """
    ends = cu_seqlens[1:]
    starts = cu_seqlens[:-1]
    return ends - starts