Skip to content

vllm.compilation.fusion

FP4_DTYPE module-attribute

# dtype constant for FP4 data (set to uint8); not referenced by the patterns
# shown in this module section.
FP4_DTYPE = uint8

FP8_DTYPE module-attribute

# FP8 dtype chosen by fp8_dtype() at import time; RMSNormQuantFusionPass
# registers every pattern with this as the quantized output dtype.
FP8_DTYPE = fp8_dtype()

FUSED_OPS module-attribute

# Fused RMSNorm+quant op for each supported FusedRMSQuantKey. Keys are
# positional (quant, fused_add) pairs; RMSNormQuantPattern asserts that its
# key is present here and uses the value as the replacement op.
# NOTE(review): this rendering shows only the overload name `default` for
# every value; the concrete ops behind them are not visible in this view.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False): default,
    FusedRMSQuantKey(kFp8StaticTensorSym, True): default,
    FusedRMSQuantKey(kFp8DynamicTokenSym, False): default,
    FusedRMSQuantKey(kFp8DynamicTokenSym, True): default,
}

QUANT_OPS module-attribute

# Standalone quant op for each supported QuantKey; RMSNormQuantPattern
# stores the matching value as QUANT_OP, the second op of the unfused
# pattern. kFp8DynamicTensorSym has no FUSED_OPS entry, so a fusion key
# built from it passes this lookup but fails the FUSED_OPS assertion.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: default,
    kFp8DynamicTensorSym: default,
    kFp8DynamicTokenSym: default,
}

RMS_ADD_OP module-attribute

# Residual-add RMSNorm op (called with input/residual/weight/epsilon) that
# the FusedAddRMSNorm*QuantPattern classes match as the first op.
RMS_ADD_OP = default

RMS_OP module-attribute

# Plain RMSNorm op (called with result/input/weight/epsilon) that the
# RMSNorm*QuantPattern classes match as the first op.
RMS_OP = default

logger module-attribute

# Module logger; RMSNormQuantFusionPass.__call__ logs the replacement count
# at debug level.
logger = init_logger(__name__)

FusedAddRMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse the residual-add RMSNorm op with a following dynamic quant op.

    The unfused subgraph (``RMS_ADD_OP`` whose output feeds
    ``self.QUANT_OP``) is replaced by a single ``self.FUSED_OP`` call that
    also takes ``residual``. Both sides return (result, residual, scale).
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        """Build a fused-add key with a float32, group-shaped scale.

        Args:
            epsilon: RMSNorm epsilon, passed as a constant into both the
                pattern and the replacement.
            quant_dtype: dtype of the quantized ``result`` tensor.
            group_shape: scale group shape (``GroupShape.PER_TOKEN`` by
                default).
            symmetric: forwarded to ``QuantKey``.
        """
        # NOTE(review): the ``False`` argument presumably marks the scale
        # as non-static (computed at runtime) - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              scale=scale,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the unfused -> fused rewrite on ``pm_pass``."""

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: residual-add RMSNorm, then quantize its output.
            # NOTE(review): at[1] / at[2] are presumably the updated
            # `input` / `residual` produced by auto_functionalized - confirm.
            at = auto_functionalized(RMS_ADD_OP,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     epsilon=self.epsilon)
            at1 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at[1],
                                      scale=scale,
                                      scale_ub=None)

            # result, residual, scale
            return at1[1], at[2], at1[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # Single fused op; its outputs are reordered below so they line
            # up with the tuple returned by pattern().
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon,
                                     scale_ub=None,
                                     residual=residual)

            # result, residual, scale
            return at[1], at[3], at[2]

        # Example tensors for tracing, in the parameter order of pattern()
        # and replacement().
        # NOTE(review): weight is (1, 5) while input's last dim is 4;
        # presumably harmless because these only drive tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        # pm.fwd_only is presumably the forward-only trace function used to
        # build both graphs - confirm in torch._inductor.pattern_matcher.
        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(self,
             epsilon: float,
             quant_dtype: torch.dtype,
             group_shape: GroupShape = GroupShape.PER_TOKEN,
             symmetric=True):
    """Configure a fused-add RMSNorm + dynamic-scale quant pattern.

    Args:
        epsilon: RMSNorm epsilon forwarded to the base pattern.
        quant_dtype: dtype of the quantized output.
        group_shape: scale group shape (``GroupShape.PER_TOKEN`` by default).
        symmetric: forwarded to ``QuantKey``.
    """
    # float32 scale descriptor over the requested group shape.
    dynamic_scale = ScaleDesc(torch.float32, False, group_shape)
    quant_key = QuantKey(dtype=quant_dtype,
                         scale=dynamic_scale,
                         symmetric=symmetric)
    fusion_key = FusedRMSQuantKey(fused_add=True, quant=quant_key)
    super().__init__(epsilon, fusion_key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the fused-add RMSNorm + dynamic quant rewrite on ``pm_pass``.

    Both pattern and replacement return (result, residual, scale).
    """

    def pattern(result: torch.Tensor, input: torch.Tensor,
                residual: torch.Tensor, weight: torch.Tensor,
                scale: torch.Tensor):
        # Unfused form: residual-add RMSNorm, then quantize its output.
        # NOTE(review): at[1] / at[2] are presumably the updated
        # `input` / `residual` produced by auto_functionalized - confirm.
        at = auto_functionalized(RMS_ADD_OP,
                                 input=input,
                                 residual=residual,
                                 weight=weight,
                                 epsilon=self.epsilon)
        at1 = auto_functionalized(self.QUANT_OP,
                                  result=result,
                                  input=at[1],
                                  scale=scale,
                                  scale_ub=None)

        # result, residual, scale
        return at1[1], at[2], at1[2]

    def replacement(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
        # Single fused op; outputs are reordered below to match pattern().
        at = auto_functionalized(self.FUSED_OP,
                                 result=result,
                                 input=input,
                                 weight=weight,
                                 scale=scale,
                                 epsilon=self.epsilon,
                                 scale_ub=None,
                                 residual=residual)

        # result, residual, scale
        return at[1], at[3], at[2]

    # Example tensors for tracing, in the parameter order of pattern() and
    # replacement().
    # NOTE(review): weight is (1, 5) while input's last dim is 4; presumably
    # harmless because these only drive tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # input
        empty_bf16(5, 4),  # residual
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1)  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

FusedAddRMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse the residual-add RMSNorm op with a following static quant op.

    The unfused subgraph (``RMS_ADD_OP`` whose output feeds
    ``self.QUANT_OP`` with a caller-supplied ``scale``) is replaced by one
    ``self.FUSED_OP`` call. Both sides return (result, residual).
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        """Build a fused-add key using the ``kStaticTensorScale`` scale.

        Args:
            epsilon: RMSNorm epsilon, passed as a constant into both the
                pattern and the replacement.
            quant_dtype: dtype of the quantized ``result`` tensor.
            symmetric: forwarded to ``QuantKey``.
        """
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              scale=kStaticTensorScale,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the unfused -> fused rewrite on ``pm_pass``."""

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: residual-add RMSNorm, then quantize with the
            # given `scale` (an input here, not an output).
            at = auto_functionalized(RMS_ADD_OP,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     epsilon=self.epsilon)
            at1 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at[1],
                                      scale=scale)

            # result, residual
            return at1[1], at[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # Single fused op producing the same (result, residual) pair.
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon)

            # result, residual
            return at[1], at[2]

        # Example tensors for tracing, in the parameter order of pattern()
        # and replacement().
        # NOTE(review): weight is (1, 5) while input's last dim is 4;
        # presumably harmless because these only drive tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self,
             epsilon: float,
             quant_dtype: torch.dtype,
             symmetric=True):
    """Configure a fused-add RMSNorm + static-scale quant pattern.

    Args:
        epsilon: RMSNorm epsilon forwarded to the base pattern.
        quant_dtype: dtype of the quantized output.
        symmetric: forwarded to ``QuantKey``.
    """
    static_quant = QuantKey(dtype=quant_dtype,
                            scale=kStaticTensorScale,
                            symmetric=symmetric)
    fusion_key = FusedRMSQuantKey(fused_add=True, quant=static_quant)
    super().__init__(epsilon, fusion_key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the fused-add RMSNorm + static quant rewrite on ``pm_pass``.

    Both pattern and replacement return (result, residual).
    """

    def pattern(result: torch.Tensor, input: torch.Tensor,
                residual: torch.Tensor, weight: torch.Tensor,
                scale: torch.Tensor):
        # Unfused form: residual-add RMSNorm, then quantize with the given
        # `scale` (an input here, not an output).
        at = auto_functionalized(RMS_ADD_OP,
                                 input=input,
                                 residual=residual,
                                 weight=weight,
                                 epsilon=self.epsilon)
        at1 = auto_functionalized(self.QUANT_OP,
                                  result=result,
                                  input=at[1],
                                  scale=scale)

        # result, residual
        return at1[1], at[2]

    def replacement(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
        # Single fused op producing the same (result, residual) pair.
        at = auto_functionalized(self.FUSED_OP,
                                 result=result,
                                 input=input,
                                 residual=residual,
                                 weight=weight,
                                 scale=scale,
                                 epsilon=self.epsilon)

        # result, residual
        return at[1], at[2]

    # Example tensors for tracing, in the parameter order of pattern() and
    # replacement().
    # NOTE(review): weight is (1, 5) while input's last dim is 4; presumably
    # harmless because these only drive tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # input
        empty_bf16(5, 4),  # residual
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1)  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

FusedRMSQuantKey

Bases: NamedTuple

Named tuple for identifying the type of RMSNorm + quant fusion. `quant`: the type of quantization; `fused_add`: whether the op also performs the residual add.

Source code in vllm/compilation/fusion.py
class FusedRMSQuantKey(NamedTuple):
    """
    Identify which RMSNorm + quantization fusion a pattern targets.

    quant: the quantization scheme being fused
    fused_add: whether the op also performs the residual add
    """
    quant: QuantKey
    fused_add: bool

    def __str__(self):
        # "with residual" when fused_add is truthy, "without residual" else.
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"

fused_add instance-attribute

fused_add: bool

quant instance-attribute

quant: QuantKey

__str__

__str__()
Source code in vllm/compilation/fusion.py
def __str__(self):
    """Render as ``FusedQuantKey(<quant>, with[out] residual)``."""
    suffix = "" if self.fused_add else "out"
    return f"FusedQuantKey({self.quant}, with{suffix} residual)"

RMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse the plain RMSNorm op with a following dynamic quant op.

    The unfused subgraph (``RMS_OP`` writing into ``result_rms`` whose
    output feeds ``self.QUANT_OP``) is replaced by one ``self.FUSED_OP``
    call with ``residual=None``. Both sides return (result, scale).
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        """Build a no-residual key with a float32, group-shaped scale.

        Args:
            epsilon: RMSNorm epsilon, passed as a constant into both the
                pattern and the replacement.
            quant_dtype: dtype of the quantized ``result`` tensor.
            group_shape: scale group shape (``GroupShape.PER_TOKEN`` by
                default).
            symmetric: forwarded to ``QuantKey``.
        """
        # NOTE(review): the ``False`` argument presumably marks the scale
        # as non-static (computed at runtime) - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(fused_add=False,
                               quant=QuantKey(dtype=quant_dtype,
                                              scale=scale,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the unfused -> fused rewrite on ``pm_pass``."""

        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: RMSNorm into `result_rms`, then quantize it.
            at1 = auto_functionalized(RMS_OP,
                                      result=result_rms,
                                      input=input,
                                      weight=weight,
                                      epsilon=self.epsilon)
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale,
                                      scale_ub=None)

            # result, scale
            return at2[1], at2[2]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # The fused op skips the intermediate `result_rms` buffer; it is
            # only a parameter so the signature matches pattern().
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon,
                                     scale_ub=None,
                                     residual=None)

            # result, scale
            return at[1], at[2]

        # Example tensors for tracing, in the parameter order of pattern()
        # and replacement().
        # NOTE(review): weight is (1, 5) while input's last dim is 4;
        # presumably harmless because these only drive tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(self,
             epsilon: float,
             quant_dtype: torch.dtype,
             group_shape: GroupShape = GroupShape.PER_TOKEN,
             symmetric=True):
    """Configure an RMSNorm (no residual) + dynamic-scale quant pattern.

    Args:
        epsilon: RMSNorm epsilon forwarded to the base pattern.
        quant_dtype: dtype of the quantized output.
        group_shape: scale group shape (``GroupShape.PER_TOKEN`` by default).
        symmetric: forwarded to ``QuantKey``.
    """
    # float32 scale descriptor over the requested group shape.
    dynamic_scale = ScaleDesc(torch.float32, False, group_shape)
    quant_key = QuantKey(dtype=quant_dtype,
                         scale=dynamic_scale,
                         symmetric=symmetric)
    fusion_key = FusedRMSQuantKey(fused_add=False, quant=quant_key)
    super().__init__(epsilon, fusion_key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the RMSNorm + dynamic quant rewrite on ``pm_pass``.

    Both pattern and replacement return (result, scale).
    """

    def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                input: torch.Tensor, weight: torch.Tensor,
                scale: torch.Tensor):
        # Unfused form: RMSNorm into `result_rms`, then quantize it.
        at1 = auto_functionalized(RMS_OP,
                                  result=result_rms,
                                  input=input,
                                  weight=weight,
                                  epsilon=self.epsilon)
        at2 = auto_functionalized(self.QUANT_OP,
                                  result=result,
                                  input=at1[1],
                                  scale=scale,
                                  scale_ub=None)

        # result, scale
        return at2[1], at2[2]

    def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
        # The fused op does not use `result_rms`; it is only a parameter so
        # the signature matches pattern().
        at = auto_functionalized(self.FUSED_OP,
                                 result=result,
                                 input=input,
                                 weight=weight,
                                 scale=scale,
                                 epsilon=self.epsilon,
                                 scale_ub=None,
                                 residual=None)

        # result, scale
        return at[1], at[2]

    # Example tensors for tracing, in the parameter order of pattern() and
    # replacement().
    # NOTE(review): weight is (1, 5) while input's last dim is 4; presumably
    # harmless because these only drive tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # result_rms
        empty_bf16(5, 4),  # input
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1)  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

RMSNormQuantFusionPass

Bases: VllmPatternMatcherPass

This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.

Source code in vllm/compilation/fusion.py
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.

    For each epsilon in (1e-5, 1e-6) it registers four FP8 patterns: static
    quant and dynamic per-token quant, each with and without the residual add.
    """

    # NOTE(review): enable_fake_mode presumably makes the example tensors
    # created while registering patterns fake (no real device allocation) -
    # confirm against its definition.
    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass")

        # Epsilon is a constant inside each traced pattern, so a separate
        # set of patterns is registered for every supported value.
        for epsilon in [1e-5, 1e-6]:
            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon,
                                      FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns)

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon,
                                       FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        # Apply every registered rewrite to `graph`; keep the count of
        # replacements on the instance and log it.
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        # Hash over this pass and every pattern class it registers,
        # presumably so cached compilations are invalidated when any of
        # their source changes.
        return self.hash_source(self, RMSNormQuantPattern,
                                RMSNormStaticQuantPattern,
                                RMSNormDynamicQuantPattern,
                                FusedAddRMSNormStaticQuantPattern,
                                FusedAddRMSNormDynamicQuantPattern)

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rmsnorm_quant_fusion_pass"
)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
    """Run every registered RMSNorm+quant rewrite over ``graph``.

    The number of replacements is stored on ``self.matched_count`` and
    logged at debug level.
    """
    replaced = self.patterns.apply(graph)
    self.matched_count = replaced
    logger.debug("Replaced %s patterns", replaced)

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig):
    """Create the pattern pass and register all FP8 RMSNorm+quant patterns.

    For each supported epsilon, the patterns are registered in this order:
    static quant, fused-add static quant, dynamic per-token quant, and
    fused-add dynamic per-token quant.
    """
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rmsnorm_quant_fusion_pass")

    # Registration order inside each epsilon is preserved exactly.
    pattern_classes = (
        RMSNormStaticQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
    )
    for eps in [1e-5, 1e-6]:
        for pattern_cls in pattern_classes:
            pattern_cls(eps, FP8_DTYPE).register(self.patterns)

    self.dump_patterns(config, self.patterns)

uuid

uuid() -> Any
Source code in vllm/compilation/fusion.py
def uuid(self) -> Any:
    """Return ``self.hash_source`` over this pass and its pattern classes.

    Presumably used so that cached results are invalidated when any of
    these classes' source changes.
    """
    hashed_classes = (
        RMSNormQuantPattern,
        RMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
    )
    return self.hash_source(self, *hashed_classes)

RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormQuantPattern:
    """Shared setup for the RMSNorm + quant fusion patterns.

    Stores the RMSNorm epsilon and the quantized dtype, and resolves from
    the module tables the standalone quant op (``QUANT_OP``, from
    ``QUANT_OPS``) and the fused replacement op (``FUSED_OP``, from
    ``FUSED_OPS``) for ``key``. The subclasses in this module build ``key``
    and implement ``register``.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        quant_key = key.quant
        self.epsilon = epsilon
        self.quant_dtype = quant_key.dtype

        # Unsupported keys fail with an AssertionError naming the key; under
        # `python -O` the asserts are skipped and the lookup raises KeyError.
        assert quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {quant_key}")
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert key in FUSED_OPS, (
            f"unsupported fused rmsnorm+quant op for {key}")
        self.FUSED_OP = FUSED_OPS[key]

FUSED_OP instance-attribute

FUSED_OP = FUSED_OPS[key]

QUANT_OP instance-attribute

QUANT_OP = QUANT_OPS[quant]

epsilon instance-attribute

epsilon = epsilon

quant_dtype instance-attribute

quant_dtype = dtype

__init__

__init__(epsilon: float, key: FusedRMSQuantKey)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
    """Store epsilon/dtype and resolve the quant and fused ops for ``key``.

    Raises:
        AssertionError: if ``key.quant`` is not in ``QUANT_OPS`` or ``key``
            is not in ``FUSED_OPS`` (KeyError instead under ``python -O``).
    """
    quant_key = key.quant
    self.epsilon = epsilon
    self.quant_dtype = quant_key.dtype

    assert quant_key in QUANT_OPS, (
        f"unsupported quantization scheme {quant_key}")
    self.QUANT_OP = QUANT_OPS[quant_key]

    assert key in FUSED_OPS, (
        f"unsupported fused rmsnorm+quant op for {key}")
    self.FUSED_OP = FUSED_OPS[key]

RMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse the plain RMSNorm op with a following static quant op.

    The unfused subgraph (``RMS_OP`` writing into ``result_rms`` whose
    output feeds ``self.QUANT_OP`` with a caller-supplied ``scale``) is
    replaced by one ``self.FUSED_OP`` call. Both sides return ``result``.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        """Build a no-residual key using the ``kStaticTensorScale`` scale.

        Args:
            epsilon: RMSNorm epsilon, passed as a constant into both the
                pattern and the replacement.
            quant_dtype: dtype of the quantized ``result`` tensor.
            symmetric: forwarded to ``QuantKey``.
        """
        fused_key = FusedRMSQuantKey(fused_add=False,
                                     quant=QuantKey(dtype=quant_dtype,
                                                    scale=kStaticTensorScale,
                                                    symmetric=symmetric))
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the unfused -> fused rewrite on ``pm_pass``."""
        # Cannot use methods, as the self argument affects tracing
        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            # Unfused form: RMSNorm into `result_rms`, then quantize it with
            # the given `scale` (an input here, not an output).
            at1 = auto_functionalized(RMS_OP,
                                      result=result_rms,
                                      input=input,
                                      weight=weight,
                                      epsilon=self.epsilon)
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale)

            # result
            return at2[1]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # The fused op does not use `result_rms`; it is only a parameter
            # so the signature matches pattern().
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon)

            # result
            return at[1]

        # Example tensors for tracing, in the parameter order of pattern()
        # and replacement().
        # NOTE(review): weight is (1, 5) while input's last dim is 4;
        # presumably harmless because these only drive tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
                                pm_pass)

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self,
             epsilon: float,
             quant_dtype: torch.dtype,
             symmetric=True):
    """Configure an RMSNorm (no residual) + static-scale quant pattern.

    Args:
        epsilon: RMSNorm epsilon forwarded to the base pattern.
        quant_dtype: dtype of the quantized output.
        symmetric: forwarded to ``QuantKey``.
    """
    static_quant = QuantKey(dtype=quant_dtype,
                            scale=kStaticTensorScale,
                            symmetric=symmetric)
    super().__init__(epsilon,
                     FusedRMSQuantKey(fused_add=False, quant=static_quant))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the RMSNorm + static quant rewrite on ``pm_pass``.

    Both pattern and replacement return only ``result``.
    """
    # Cannot use methods, as the self argument affects tracing
    def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                input: torch.Tensor, weight: torch.Tensor,
                scale: torch.Tensor):
        # Unfused form: RMSNorm into `result_rms`, then quantize it with the
        # given `scale` (an input here, not an output).
        at1 = auto_functionalized(RMS_OP,
                                  result=result_rms,
                                  input=input,
                                  weight=weight,
                                  epsilon=self.epsilon)
        at2 = auto_functionalized(self.QUANT_OP,
                                  result=result,
                                  input=at1[1],
                                  scale=scale)

        # result
        return at2[1]

    def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
        # The fused op does not use `result_rms`; it is only a parameter so
        # the signature matches pattern().
        at = auto_functionalized(self.FUSED_OP,
                                 result=result,
                                 input=input,
                                 weight=weight,
                                 scale=scale,
                                 epsilon=self.epsilon)

        # result
        return at[1]

    # Example tensors for tracing, in the parameter order of pattern() and
    # replacement().
    # NOTE(review): weight is (1, 5) while input's last dim is 4; presumably
    # harmless because these only drive tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # result_rms
        empty_bf16(5, 4),  # input
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1)  # scale
    ]

    pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
                            pm_pass)

empty_bf16

empty_bf16(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_bf16(*args, device="cuda", **kwargs):
    """Allocate an uninitialized bfloat16 tensor via ``torch.empty``.

    Positional ``args`` (the shape) and extra ``kwargs`` are forwarded to
    ``torch.empty``. ``device`` defaults to ``"cuda"``, the previously
    hard-coded value. It can now be overridden (e.g. ``device="cpu"``);
    before, passing it raised a duplicate-keyword TypeError. The dtype is
    fixed, so passing ``dtype`` still raises TypeError.
    """
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device=device)

empty_fp32

empty_fp32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_fp32(*args, device="cuda", **kwargs):
    """Allocate an uninitialized float32 tensor via ``torch.empty``.

    Positional ``args`` (the shape) and extra ``kwargs`` are forwarded to
    ``torch.empty``. ``device`` defaults to ``"cuda"``, the previously
    hard-coded value. It can now be overridden (e.g. ``device="cpu"``);
    before, passing it raised a duplicate-keyword TypeError. The dtype is
    fixed, so passing ``dtype`` still raises TypeError.
    """
    return torch.empty(*args, **kwargs, dtype=torch.float32, device=device)

empty_i32

empty_i32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_i32(*args, device="cuda", **kwargs):
    """Allocate an uninitialized int32 tensor via ``torch.empty``.

    Positional ``args`` (the shape) and extra ``kwargs`` are forwarded to
    ``torch.empty``. ``device`` defaults to ``"cuda"``, the previously
    hard-coded value. It can now be overridden (e.g. ``device="cpu"``);
    before, passing it raised a duplicate-keyword TypeError. The dtype is
    fixed, so passing ``dtype`` still raises TypeError.
    """
    return torch.empty(*args, **kwargs, dtype=torch.int32, device=device)