vllm.model_executor.layers.fla.ops.layernorm_guard ¶
LayerNormFn ¶
Bases: Function
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
forward staticmethod ¶
forward(
ctx,
x,
weight,
bias,
z=None,
eps=1e-06,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
)
If z is not None, computes norm(x) * silu(z) when norm_before_gate is True, otherwise norm(x * silu(z)).
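A minimal NumPy sketch of the two gating orders described in the docstring (assumed semantics; silu(z) = z * sigmoid(z), and the real implementation runs a fused Triton kernel rather than this eager path):

```python
import numpy as np

def silu(z):
    # SiLU/Swish activation: z * sigmoid(z)
    return z / (1.0 + np.exp(-z))

def layer_norm(x, weight, bias, eps=1e-6):
    # Plain LayerNorm over the last dimension
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

def gated_norm(x, weight, bias, z=None, norm_before_gate=True, eps=1e-6):
    # Mirrors the docstring: norm(x) * silu(z) vs norm(x * silu(z))
    if z is None:
        return layer_norm(x, weight, bias, eps)
    if norm_before_gate:
        return layer_norm(x, weight, bias, eps) * silu(z)
    return layer_norm(x * silu(z), weight, bias, eps)

x = np.random.randn(2, 8)
z = np.random.randn(2, 8)
w = np.ones(8)
b = np.zeros(8)
y1 = gated_norm(x, w, b, z, norm_before_gate=True)
y2 = gated_norm(x, w, b, z, norm_before_gate=False)
```

Note that the two orders are not equivalent: gating before the norm changes the statistics the normalization is computed over.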
LayerNormGated ¶
Bases: Module
__init__ ¶
__init__(
hidden_size,
eps: float = 1e-05,
group_size: Optional[int] = None,
norm_before_gate: bool = True,
device: Optional[device] = None,
dtype: Optional[dtype] = None,
)
If group_size is not None, GroupNorm is applied with group_size elements per group. group_size=None is equivalent to group_size=hidden_size (i.e. there is only one group).
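The group semantics can be sketched as follows (assumed from the docstring: the hidden dimension is split into contiguous groups of group_size elements, and the normalization statistic is computed per group; shown here for the RMS variant):

```python
import numpy as np

def group_rms_norm(x, weight, eps=1e-5, group_size=None):
    # group_size=None behaves like group_size = hidden_size (one group)
    hidden = x.shape[-1]
    gs = hidden if group_size is None else group_size
    assert hidden % gs == 0, "hidden_size must be divisible by group_size"
    xg = x.reshape(*x.shape[:-1], hidden // gs, gs)
    # RMS statistic computed per group, not over the whole row
    rstd = 1.0 / np.sqrt((xg ** 2).mean(axis=-1, keepdims=True) + eps)
    return (xg * rstd).reshape(x.shape) * weight

x = np.random.randn(4, 16)
w = np.ones(16)
y_one_group = group_rms_norm(x, w)             # single group over all 16 elements
y_groups = group_rms_norm(x, w, group_size=4)  # four independent groups of 4
```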
forward ¶
If z is not None, computes norm(x) * silu(z) when norm_before_gate is True, otherwise norm(x * silu(z)).
RMSNormGated ¶
Bases: Module
__init__ ¶
__init__(
hidden_size,
eps: float = 1e-05,
group_size: Optional[int] = None,
norm_before_gate: bool = False,
device: Optional[device] = None,
dtype: Optional[dtype] = None,
)
If group_size is not None, GroupNorm is applied with group_size elements per group. group_size=None is equivalent to group_size=hidden_size (i.e. there is only one group).
forward ¶
If z is not None, computes norm(x) * silu(z) when norm_before_gate is True, otherwise norm(x * silu(z)).
layer_norm_fwd ¶
layer_norm_fwd(
x: Tensor,
weight: Tensor,
bias: Tensor,
eps: float,
z: Tensor = None,
out: Tensor = None,
group_size: int = None,
norm_before_gate: bool = True,
is_rms_norm: bool = False,
)
layer_norm_fwd_kernel ¶
layer_norm_fwd_kernel(
X,
Y,
W,
B,
Z,
Mean,
Rstd,
stride_x_row,
stride_y_row,
stride_z_row,
M,
N,
eps,
BLOCK_N: constexpr,
HAS_BIAS: constexpr,
HAS_Z: constexpr,
NORM_BEFORE_GATE: constexpr,
IS_RMS_NORM: constexpr,
)
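Each kernel program handles one row of X; the Mean and Rstd buffers hold the per-row statistics that are saved for the backward pass. A hypothetical per-row sketch of what the kernel computes (ignoring the Z gate and strides):

```python
import numpy as np

def layer_norm_row(x_row, w, b=None, eps=1e-6, is_rms_norm=False):
    # One row of the forward pass; also returns the saved statistics
    # (Mean, Rstd) that the kernel writes out.
    if is_rms_norm:
        mean = 0.0  # RMSNorm skips the mean subtraction
    else:
        mean = x_row.mean()
    var = ((x_row - mean) ** 2).mean()
    rstd = 1.0 / np.sqrt(var + eps)
    y = (x_row - mean) * rstd * w
    if b is not None:  # HAS_BIAS path
        y = y + b
    return y, mean, rstd

x = np.random.randn(8)
w = np.ones(8)
y, mean, rstd = layer_norm_row(x, w)
```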
layernorm_fn ¶
layernorm_fn(
x,
weight,
bias,
z=None,
eps=1e-06,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
)
rms_norm_ref ¶
rms_norm_ref(
x,
weight,
bias,
z=None,
eps=1e-06,
group_size=None,
norm_before_gate=True,
upcast=True,
)