vllm.model_executor.layers.fla.ops.l2norm

BT_LIST module-attribute

BT_LIST = [8, 16, 32, 64, 128]

USE_DEFAULT_FLA_NORM module-attribute

USE_DEFAULT_FLA_NORM = int(
    getenv("USE_DEFAULT_FLA_NORM", "0")
)
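
The flag is read once at import time via os.getenv, so the environment variable must be set before this module is first imported. A minimal sketch (assumption: the value "1" routes l2norm_fwd to the l2norm_fwd_kernel / l2norm_fwd_kernel1 path instead of l2norm_fwd_kernel2, per the branch in l2norm_fwd below):

import os

os.environ["USE_DEFAULT_FLA_NORM"] = "1"  # must be set before the first import

from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd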

l2norm_fwd

l2norm_fwd(
    x: Tensor,
    eps: float = 1e-06,
    output_dtype: Optional[dtype] = None,
)
Source code in vllm/model_executor/layers/fla/ops/l2norm.py
def l2norm_fwd(x: torch.Tensor,
               eps: float = 1e-6,
               output_dtype: Optional[torch.dtype] = None):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if not USE_DEFAULT_FLA_NORM:
        MBLOCK = 32
        # M, N = x.shape
        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
            x,
            y,
            eps,
            T,
            D,
            MBLOCK,
        )
    else:
        if D <= 512:
            NB = triton.cdiv(T, 2048)

            def grid(meta):
                return (triton.cdiv(T, meta['BT']), )

            l2norm_fwd_kernel[grid](
                x,
                y,
                eps,
                NB=NB,
                T=T,
                D=D,
                BD=BD,
            )
        else:
            l2norm_fwd_kernel1[(T, )](
                x,
                y,
                eps=eps,
                D=D,
                BD=BD,
            )

    return y.view(x_shape_og)
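
A minimal usage sketch, assuming a CUDA device with Triton available; the reference computation and tolerances below are illustrative, not part of the module:

import torch

from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd

x = torch.randn(4, 1024, 128, device="cuda", dtype=torch.bfloat16)
y = l2norm_fwd(x, eps=1e-6, output_dtype=torch.float32)

# reference: normalize each vector along the last dimension in fp32
ref = x.float() / torch.sqrt((x.float() ** 2).sum(-1, keepdim=True) + 1e-6)
torch.testing.assert_close(y, ref, rtol=1e-3, atol=1e-3)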

l2norm_fwd_kernel

l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: constexpr,
    BT: constexpr,
    BD: constexpr,
)
Source code in vllm/model_executor/layers/fla/ops/l2norm.py
@triton.autotune(configs=[
    triton.Config({'BT': BT}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
                 key=['D'])
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    # each program normalizes a (BT, BD) tile of rows starting at row i_t * BT
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # per-row sum of squares, then scale each row by 1 / sqrt(sum + eps)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
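
For intuition, a pure-PyTorch emulation of what a single program instance computes; emulate_l2norm_tile is a hypothetical helper, not part of the module. Program i_t normalizes rows i_t * BT through i_t * BT + BT - 1 of the flattened (T, D) input, with boundary_check masking the ragged last tile:

import torch

def emulate_l2norm_tile(x2d: torch.Tensor, i_t: int, BT: int, eps: float) -> torch.Tensor:
    # rows [i_t * BT, i_t * BT + BT) of a (T, D) input, promoted to fp32
    b_x = x2d[i_t * BT:(i_t + 1) * BT].float()
    b_var = (b_x * b_x).sum(dim=1, keepdim=True)   # per-row sum of squares
    return b_x / torch.sqrt(b_var + eps)           # matches b_y in the kernel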

l2norm_fwd_kernel1

l2norm_fwd_kernel1(x, y, D, BD: constexpr, eps)
Source code in vllm/model_executor/layers/fla/ops/l2norm.py
@triton.autotune(configs=[
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16, 32]
],
                 key=['D'])
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute the sum of squares over the feature dimension
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize the row by the reciprocal norm
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)
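
On the USE_DEFAULT_FLA_NORM path, l2norm_fwd dispatches here when D > 512, launching one program per row; the row must still fit under the 64 KiB-per-feature cap enforced in l2norm_fwd. A quick sketch of that bound (values derived from MAX_FUSED_SIZE = 65536 // element_size; the example D is arbitrary):

import triton

MAX_FUSED_SIZE_FP16 = 65536 // 2   # 32768 elements per row for fp16 / bf16
MAX_FUSED_SIZE_FP32 = 65536 // 4   # 16384 elements per row for fp32

D = 3000                           # example feature dim (> 512, so this kernel is used)
BD = triton.next_power_of_2(D)     # 4096; columns >= D are masked out inside the kernel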

l2norm_fwd_kernel2

l2norm_fwd_kernel2(
    X, Y, eps, M, N: constexpr, MBLOCK: constexpr
)
Source code in vllm/model_executor/layers/fla/ops/l2norm.py
@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
    # each program normalizes MBLOCK rows; rows beyond M are masked out
    xoffset = tl.program_id(0) * MBLOCK
    row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
    xmask = row_idx < M
    rindex = tl.arange(0, N)[None, :]
    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
    # per-row sum of squares, then scale by rsqrt(sum + eps)
    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
    square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
    rsqrt = tl.rsqrt(square_sum + eps)
    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
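
When USE_DEFAULT_FLA_NORM is 0 (the default), l2norm_fwd launches this kernel with MBLOCK = 32 rows per program. A small sketch of the launch arithmetic, using an arbitrary example T:

import triton

T, MBLOCK = 1000, 32
grid = (triton.cdiv(T, MBLOCK),)   # (32,); the last program covers rows 992..1023 and masks rows >= 1000 via xmask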