Skip to content

vllm.compilation.pass_manager

logger module-attribute

logger = init_logger(__name__)

PostGradPassManager

Bases: CustomGraphPass

The pass manager for post-grad passes. It handles configuration, adding custom passes, and running passes. It supports uuid for the Inductor code cache. That includes torch<2.6 support using pickling (in .inductor_pass.CustomGraphPass).

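The order of the post-grad post-passes is: 1. the passes in `passes`, in list order: the passes enabled in the pass config are appended by `configure` (NoOpEliminationPass, SequenceParallelismPass, AsyncTPPass, AllReduceFusionPass, RMSNormQuantFusionPass, ActivationQuantFusionPass, AttnFusionPass) and custom passes are appended by `add` (e.g. config["post_grad_custom_post_pass"], if it exists) 2. post_cleanup 3. fix_functionalization This way, all passes operate on a functionalized graph.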

Source code in vllm/compilation/pass_manager.py
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    When the manager is called on a graph, it runs, in order:
    1. every pass in ``self.passes`` that is applicable for the current
       runtime shape, in list order. ``configure`` appends the passes
       enabled in the pass config (NoOpEliminationPass,
       SequenceParallelismPass, AsyncTPPass, AllReduceFusionPass,
       RMSNormQuantFusionPass, ActivationQuantFusionPass, AttnFusionPass)
       and ``add`` appends custom passes
       (e.g. config["post_grad_custom_post_pass"], if it exists)
    2. post_cleanup
    3. fix_functionalization
    This way, all passes operate on a functionalized graph.

    ``configure`` must be called before the manager is called or ``uuid``
    is requested: it is the only place that sets ``pass_config``,
    ``post_cleanup`` and ``fix_functionalization``.
    """

    def __init__(self):
        """Create a manager with an empty pass list."""
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph):
        """
        Run all registered passes on ``graph``, then post_cleanup, then
        fix_functionalization.

        ``VllmInductorPass.dump_prefix`` is used as a running dump index: it
        is reset to 0 on entry, incremented after every pass that ran, and
        set back to ``None`` on exit, including when a pass raises.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            shape = get_pass_context().runtime_shape
            for pass_ in self.passes:
                # Passes that are not applicable for this shape are skipped
                # and do not consume a dump index.
                if pass_.is_applicable_for_shape(shape):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index even if a pass raised, so the class-level
            # counter does not keep a stale value after a failed run.
            VllmInductorPass.dump_prefix = None

    def configure(self, config: VllmConfig):
        """
        Store the pass config and append every pass it enables.

        Passes are appended to ``self.passes`` in the order of the checks
        below; AsyncTPPass is only considered when sequence parallelism is
        enabled. ``post_cleanup`` and ``fix_functionalization`` are always
        created and are kept outside ``self.passes``.
        """
        self.pass_config = config.compilation_config.pass_config
        if self.pass_config.enable_noop:
            self.passes += [NoOpEliminationPass(config)]

        if self.pass_config.enable_sequence_parallelism:
            self.passes += [SequenceParallelismPass(config)]
            if self.pass_config.enable_async_tp:
                self.passes += [AsyncTPPass(config)]

        if self.pass_config.enable_fi_allreduce_fusion:
            self.passes += [AllReduceFusionPass(config)]

        if self.pass_config.enable_fusion:
            self.passes += [RMSNormQuantFusionPass(config)]
            self.passes += [ActivationQuantFusionPass(config)]

        if self.pass_config.enable_attn_fusion:
            self.passes += [AttnFusionPass(config)]

        # needs a functional graph
        self.post_cleanup = PostCleanupPass(config)
        self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass):
        """Append a custom pass to the end of ``self.passes``."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        state = {"pass_config": self.pass_config.uuid(), "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        # NOTE(review): post_cleanup is not part of the hashed state;
        # confirm that this is intentional.
        state["passes"].append(self.fix_functionalization.uuid())
        return InductorPass.hash_dict(state)

passes instance-attribute

passes: list[InductorPass] = []

__call__

__call__(graph: Graph)
Source code in vllm/compilation/pass_manager.py
@with_pattern_match_debug
def __call__(self, graph: fx.Graph):
    """
    Run all registered passes on ``graph``, then post_cleanup, then
    fix_functionalization.

    ``VllmInductorPass.dump_prefix`` is used as a running dump index: it is
    reset to 0 on entry, incremented after every pass that ran, and set back
    to ``None`` on exit, including when a pass raises.
    """
    VllmInductorPass.dump_prefix = 0  # reset dump index
    try:
        shape = get_pass_context().runtime_shape
        for pass_ in self.passes:
            # Passes that are not applicable for this shape are skipped
            # and do not consume a dump index.
            if pass_.is_applicable_for_shape(shape):
                pass_(graph)
                VllmInductorPass.dump_prefix += 1

        # post-cleanup goes before fix_functionalization
        # because it requires a functional graph
        self.post_cleanup(graph)
        VllmInductorPass.dump_prefix += 1

        # always run fix_functionalization last
        self.fix_functionalization(graph)
    finally:
        # Cleanup index even if a pass raised, so the class-level counter
        # does not keep a stale value after a failed run.
        VllmInductorPass.dump_prefix = None

__init__

__init__()
Source code in vllm/compilation/pass_manager.py
def __init__(self):
    """Create a manager with an empty pass list."""
    # Filled later by configure() (config-enabled passes) and add()
    # (custom passes); __call__ runs the entries in list order.
    self.passes: list[InductorPass] = []

add

add(pass_: InductorPass)
Source code in vllm/compilation/pass_manager.py
def add(self, pass_: InductorPass):
    """
    Register a custom pass at the end of the pass list.

    The argument must be an InductorPass instance; this is checked with an
    assertion.
    """
    assert isinstance(pass_, InductorPass)
    self.passes += [pass_]

configure

configure(config: VllmConfig)
Source code in vllm/compilation/pass_manager.py
def configure(self, config: VllmConfig):
    """
    Store the pass config and register every pass it enables.

    Passes are appended to ``self.passes`` in the order of the checks below.
    The cleanup and functionalization passes are always created and are kept
    as separate attributes rather than entries of ``self.passes``.
    """
    pass_config = config.compilation_config.pass_config
    self.pass_config = pass_config

    if pass_config.enable_noop:
        self.passes.append(NoOpEliminationPass(config))

    if pass_config.enable_sequence_parallelism:
        self.passes.append(SequenceParallelismPass(config))
        # Async TP is only considered together with sequence parallelism.
        if pass_config.enable_async_tp:
            self.passes.append(AsyncTPPass(config))

    if pass_config.enable_fi_allreduce_fusion:
        self.passes.append(AllReduceFusionPass(config))

    if pass_config.enable_fusion:
        self.passes.append(RMSNormQuantFusionPass(config))
        self.passes.append(ActivationQuantFusionPass(config))

    if pass_config.enable_attn_fusion:
        self.passes.append(AttnFusionPass(config))

    # Both of these need a functional graph.
    self.post_cleanup = PostCleanupPass(config)
    self.fix_functionalization = FixFunctionalizationPass(config)

uuid

uuid()

The PostGradPassManager is set as a custom pass in the Inductor and affects compilation caching. Its uuid depends on the UUIDs of all dependent passes and the pass config. See InductorPass for more info.

Source code in vllm/compilation/pass_manager.py
def uuid(self):
    """
    Return the hash that identifies this manager for compilation caching.

    The PostGradPassManager is set as a custom pass in the Inductor, so its
    uuid is built from the pass config uuid, the uuid of every pass in
    ``self.passes`` (in order) and the fix_functionalization uuid.
    See InductorPass for more info.
    """
    config_id = self.pass_config.uuid()
    pass_ids = [registered.uuid() for registered in self.passes]
    pass_ids.append(self.fix_functionalization.uuid())
    return InductorPass.hash_dict({"pass_config": config_id, "passes": pass_ids})

with_pattern_match_debug

with_pattern_match_debug(fn)

Function decorator that turns on inductor pattern match debug for the duration of the call. Used to avoid logging builtin Inductor pattern matching.

Source code in vllm/compilation/pass_manager.py
def with_pattern_match_debug(fn):
    """
    Decorate ``fn`` so that inductor pattern match debug is turned on
    only while ``fn`` runs.
    Used to avoid logging builtin Inductor pattern matching.

    When envs.VLLM_PATTERN_MATCH_DEBUG is None, ``fn`` is called unchanged.
    Otherwise its value is set as TORCHINDUCTOR_PATTERN_MATCH_DEBUG for the
    duration of the call.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            # Debugging not requested: plain pass-through.
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper