Skip to content

vllm.compilation.vllm_inductor_pass

logger module-attribute

logger = init_logger(__name__)

PrinterInductorPass

Bases: VllmInductorPass

Source code in vllm/compilation/vllm_inductor_pass.py
class PrinterInductorPass(VllmInductorPass):
    """Debug-only pass: dumps the graph under a configurable name.

    It performs no transformation; `__call__` just emits a debug dump
    tagged with the name supplied at construction time.
    """

    def __init__(self, name: str, config: VllmConfig):
        super().__init__(config)
        # Stage label used when dumping the graph.
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        # Leave the graph untouched; only emit a debug dump of it.
        self.dump_graph(graph, self.name)

name instance-attribute

name = name

__call__

__call__(graph: Graph)
Source code in vllm/compilation/vllm_inductor_pass.py
def __call__(self, graph: torch.fx.Graph):
    # Dump-only pass: emits the graph under self.name; the graph itself
    # is not modified.
    self.dump_graph(graph, self.name)

__init__

__init__(name: str, config: VllmConfig)
Source code in vllm/compilation/vllm_inductor_pass.py
def __init__(self, name: str, config: VllmConfig):
    """Create a printer pass that dumps graphs under the given name."""
    super().__init__(config)
    # Stage label used when dumping the graph.
    self.name = name

VllmInductorPass

Bases: InductorPass

An inductor pass with access to vLLM PassConfig. It provides timing, logging, and dumping utilities.

Source code in vllm/compilation/vllm_inductor_pass.py
class VllmInductorPass(InductorPass):
    """
    An inductor pass with access to vLLM PassConfig.
    It provides timing, logging, and dumping utilities.
    """
    dump_prefix: ClassVar[Optional[int]] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        self.pass_config = config.compilation_config.pass_config
        # model_config / device_config may be unset; fall back to None.
        model_config = config.model_config
        self.model_dtype = model_config.dtype if model_config else None
        device_config = config.device_config
        self.device = device_config.device if device_config else None
        # Class name; used for log messages and dump-file naming.
        self.pass_name = type(self).__name__

    @staticmethod
    def time_and_log(call_fn):
        """Decorate a pass __call__ to time it and dump the graph
        before and after the wrapped pass runs."""

        @functools.wraps(call_fn)
        def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return wrapped

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        """Dump the owning module, tagged with the pass name and stage."""
        prefix = VllmInductorPass.dump_prefix
        prefix_str = f".{prefix}" if prefix is not None else ""
        lazy_format_graph_code(
            f"post_grad{prefix_str}.{self.pass_name}.{stage}",
            graph.owning_module)

    def begin(self):
        """Record the pass start timestamp in nanoseconds."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        """Record the pass end timestamp and log the elapsed time."""
        self._end_time = time.perf_counter_ns()
        elapsed_ms = (self._end_time - self._start_time) / 1e6
        logger.debug("%s completed in %.1f ms", self.pass_name, elapsed_ms)

device instance-attribute

device = device if device_config else None

dump_prefix class-attribute

dump_prefix: Optional[int] = None

Keep track of pass index for debug dump ordering.

model_dtype instance-attribute

model_dtype = dtype if model_config else None

pass_config instance-attribute

pass_config = pass_config

pass_name instance-attribute

pass_name = __name__

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/vllm_inductor_pass.py
def __init__(self, config: VllmConfig):
    """Capture pass config, model dtype, and device from the vLLM config."""
    self.pass_config = config.compilation_config.pass_config
    # model_config / device_config may be unset; fall back to None.
    self.model_dtype = config.model_config.dtype if config.model_config \
        else None
    self.device = config.device_config.device if config.device_config \
        else None
    # Class name; used for log messages and dump-file naming.
    self.pass_name = self.__class__.__name__

begin

begin()
Source code in vllm/compilation/vllm_inductor_pass.py
def begin(self):
    """Record the wall-clock start of the pass, in nanoseconds."""
    self._start_time = time.perf_counter_ns()

dump_graph

dump_graph(graph: Graph, stage: str)
Source code in vllm/compilation/vllm_inductor_pass.py
def dump_graph(self, graph: torch.fx.Graph, stage: str):
    """Dump the owning module, tagged with the pass name and stage."""
    prefix = VllmInductorPass.dump_prefix
    prefix_str = f".{prefix}" if prefix is not None else ""
    lazy_format_graph_code(
        f"post_grad{prefix_str}.{self.pass_name}.{stage}",
        graph.owning_module)

end_and_log

end_and_log()
Source code in vllm/compilation/vllm_inductor_pass.py
def end_and_log(self):
    """Record the pass end timestamp and log the elapsed time in ms."""
    self._end_time = time.perf_counter_ns()
    elapsed_ms = (self._end_time - self._start_time) / 1e6
    logger.debug("%s completed in %.1f ms", self.pass_name, elapsed_ms)

time_and_log staticmethod

time_and_log(call_fn)
Source code in vllm/compilation/vllm_inductor_pass.py
@staticmethod
def time_and_log(call_fn):
    """Decorator: wrap a pass __call__ with timing and before/after dumps."""

    @functools.wraps(call_fn)
    def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
        # Dump the graph on both sides of the wrapped pass and time it.
        self.begin()
        self.dump_graph(graph, "before")
        call_fn(self, graph)
        self.dump_graph(graph, "after")
        self.end_and_log()

    return wrapped

VllmPatternMatcherPass

Bases: VllmInductorPass

A VllmInductorPass that uses the Inductor pattern matcher. Its main use is providing the dump_patterns utility that dumps the Inductor pattern matcher patterns into a file, which greatly aids debugging.

TODO(luka) move more utilities to this pass.

Source code in vllm/compilation/vllm_inductor_pass.py
class VllmPatternMatcherPass(VllmInductorPass):
    """
    A VllmInductorPass that uses the Inductor pattern matcher.
    Its main use is providing the dump_patterns utility that dumps the
    Inductor pattern matcher patterns into a file, which greatly aids debugging.

    TODO(luka) move more utilities to this pass.
    """
    matched_count: int = 0
    """The number of matched patterns in the pass."""

    # Matches reprs like <OpOverload(op='aten.add', overload='Tensor')>,
    # capturing the op name and the overload name.
    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")

    def _replace_op_overloads(self, string: str) -> str:
        """Replace <OpOverload(..., ...)> with nicer formulations"""
        # Rewrite each match as torch.ops.<op>.<overload> so the dumped
        # patterns read like valid Python.
        return self._OP_OVERLOAD_PATTERN.sub(
            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
            string,
        )

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
        """
        If debug dumping is enabled, dump the Inductor pattern-matcher patterns
        into the debug_dump_path folder next to the dumped fx graphs.

        This method does its best to print something that looks like Python code
        for easier debugging and potentially navigation. If any errors appear in
        the output, please add to this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compile_debug_dump_path()
        if not debug_dump_path:
            return

        debug_dump_path.mkdir(parents=True, exist_ok=True)

        # NOTE(review): local import — presumably avoids an import cycle at
        # module load time; confirm before hoisting to the top of the file.
        from vllm.utils import unique_filepath
        file_path = unique_filepath(
            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")

        with file_path.open("w") as f:
            # Header: identify the producer and emit imports so the dump is
            # as close to runnable Python as practical.
            print(
                f'# This file was produced by VllmPatternMatcherPass.'
                f'dump_patterns for {self.pass_name}.\n'
                f'# It does its best to produce valid-Python-looking code but'
                f' please add to dump_patterns if there are any errors.\n\n'
                f'from torch._higher_order_ops.auto_functionalize import '
                f'auto_functionalized as auto_functionalized\n'
                f'from torch._inductor.pattern_matcher import *',
                file=f)

            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)

                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern: memoized sub-expressions first as
                    # assignments, then the final return; the trailing
                    # .replace() indents the body under pattern_<i>().
                    out_node = pp.pretty_print(pattern.pattern)
                    pattern_repr = "\n".join([f"def pattern_{i}():"] + [
                        f"{pp.memoized_objs_names[key]} = "
                        f"{pp.memoized_objs_pp[key]}"
                        for key in pp.memoized_objs_names
                    ] + [f"return {out_node}"]).replace("\n", "\n    ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)

_OP_OVERLOAD_PATTERN class-attribute

_OP_OVERLOAD_PATTERN: Pattern = compile(
    "<OpOverload\\(op='([^']*)', overload='([^']*)'\\)>"
)

matched_count class-attribute instance-attribute

matched_count: int = 0

The number of matched patterns in the pass.

_replace_op_overloads

_replace_op_overloads(string: str) -> str

Replace `<OpOverload(..., ...)>` reprs with nicer formulations

Source code in vllm/compilation/vllm_inductor_pass.py
def _replace_op_overloads(self, string: str) -> str:
    """Replace <OpOverload(..., ...)> with nicer formulations"""
    return self._OP_OVERLOAD_PATTERN.sub(
        lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
        string,
    )

dump_patterns

dump_patterns(
    config: VllmConfig, pm_pass: PatternMatcherPass
)

If debug dumping is enabled, dump the Inductor pattern-matcher patterns into the debug_dump_path folder next to the dumped fx graphs.

This method does its best to print something that looks like Python code for easier debugging and potentially navigation. If any errors appear in the output, please add to this method.

TODO(luka): use pattern object to manually produce pattern graph

Source code in vllm/compilation/vllm_inductor_pass.py
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
    """
    If debug dumping is enabled, dump the Inductor pattern-matcher patterns
    into the debug_dump_path folder next to the dumped fx graphs.

    This method does its best to print something that looks like Python code
    for easier debugging and potentially navigation. If any errors appear in
    the output, please add to this method.

    TODO(luka): use pattern object to manually produce pattern graph
    """
    debug_dump_path = config.compile_debug_dump_path()
    if not debug_dump_path:
        return

    debug_dump_path.mkdir(parents=True, exist_ok=True)

    # NOTE(review): local import — presumably avoids an import cycle at
    # module load time; confirm before hoisting to the top of the file.
    from vllm.utils import unique_filepath
    file_path = unique_filepath(
        lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")

    with file_path.open("w") as f:
        # Header: identify the producer and emit imports so the dump is
        # as close to runnable Python as practical.
        print(
            f'# This file was produced by VllmPatternMatcherPass.'
            f'dump_patterns for {self.pass_name}.\n'
            f'# It does its best to produce valid-Python-looking code but'
            f' please add to dump_patterns if there are any errors.\n\n'
            f'from torch._higher_order_ops.auto_functionalize import '
            f'auto_functionalized as auto_functionalized\n'
            f'from torch._inductor.pattern_matcher import *',
            file=f)

        for node, patterns in pm_pass.patterns.items():
            # fix the operator.getitem repr
            if node[1] == operator.getitem:
                node_repr = f"({repr(node[0])}, operator.getitem)"
            else:
                node_repr = repr(node)

            node_repr = self._replace_op_overloads(node_repr)

            print(f"\n\n# Patterns for op: {node_repr}", file=f)
            for i, pattern in enumerate(patterns):
                # reserve auto_functionalized ahead of time
                pp = PatternPrettyPrinter()
                pp.namespace.create_name("auto_functionalized", None)

                # Assemble pattern: memoized sub-expressions first as
                # assignments, then the final return; the trailing
                # .replace() indents the body under pattern_<i>().
                out_node = pp.pretty_print(pattern.pattern)
                pattern_repr = "\n".join([f"def pattern_{i}():"] + [
                    f"{pp.memoized_objs_names[key]} = "
                    f"{pp.memoized_objs_pp[key]}"
                    for key in pp.memoized_objs_names
                ] + [f"return {out_node}"]).replace("\n", "\n    ")

                pattern_repr = self._replace_op_overloads(pattern_repr)
                print(f"{pattern_repr}\n", file=f)