Skip to content

vllm.model_executor.layers.fla.ops.utils

COMPILER_MODE module-attribute

# Enabled only when the environment variable is exactly "1"; unset counts as off.
COMPILER_MODE = getenv("FLA_COMPILER_MODE", "0") == "1"

FLA_CI_ENV module-attribute

# True only when FLA_CI_ENV is set to exactly "1"; unset counts as off.
FLA_CI_ENV = getenv("FLA_CI_ENV", "0") == "1"

FLA_GDN_FIX_BT module-attribute

# Opt-in flag: only the exact value "1" turns it on.
FLA_GDN_FIX_BT = getenv("FLA_GDN_FIX_BT") == "1"

SUPPRESS_LEVEL module-attribute

# Integer level read from GDN_RECOMPUTE_SUPPRESS_LEVEL (0 when unset); a
# non-integer value raises ValueError when the module is imported.
SUPPRESS_LEVEL = int(getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))

device module-attribute

# Triton's "hip" backend is folded into "cuda" (presumably because ROCm builds
# of torch expose AMD GPUs through torch.cuda); other backends pass through.
device = "cuda" if get_available_device() == "hip" else get_available_device()

device_platform module-attribute

# Vendor name derived from the Triton backend: "nvidia", "amd", "intel", or
# the raw backend string when no mapping exists (e.g. "cpu").
device_platform = _check_platform()

device_torch_lib module-attribute

# torch submodule named after the active device type (e.g. torch.cuda for
# "cuda"); getattr raises AttributeError if torch has no such attribute.
device_torch_lib = getattr(torch, device)

is_amd module-attribute

# True when Triton's active backend is "hip".
is_amd = (device_platform == "amd")

is_intel module-attribute

# True when Triton's active backend is "xpu".
is_intel = (device_platform == "intel")

is_intel_alchemist module-attribute

# The device name is only queried on Intel platforms; the "Intel(R) Arc(TM) A"
# prefix presumably identifies Arc A-series (Alchemist) parts, per the name.
is_intel_alchemist = is_intel and ("Intel(R) Arc(TM) A" in get_device_name(0))

is_nvidia module-attribute

# True when Triton's active backend is "cuda".
is_nvidia = (device_platform == "nvidia")

is_nvidia_hopper module-attribute

# On NVIDIA, true when the device name contains "NVIDIA H" or when the first
# element of get_device_capability() (presumably the major compute capability)
# is >= 9, so any such device qualifies, not only Hopper-named parts.
is_nvidia_hopper = is_nvidia and ("NVIDIA H" in get_device_name(0)
                                  or get_device_capability()[0] >= 9)

logger module-attribute

# Module-level logger named after this module.
logger = getLogger(__name__)

use_cuda_graph module-attribute

# CUDA graphs are opt-in via FLA_USE_CUDA_GRAPH=1 and only honoured on NVIDIA.
use_cuda_graph = is_nvidia and get("FLA_USE_CUDA_GRAPH", "0") == "1"

Backend

Bases: Enum

Source code in vllm/model_executor/layers/fla/ops/utils.py
class Backend(Enum):
    """
    Reference shared-memory capacities per GPU generation.

    ``check_shared_mem`` compares these values with Triton's reported
    ``max_shared_mem`` (units presumably bytes -- confirm against Triton).
    """

    ADA = 101_376  # RTX 4090
    AMPERE = 166_912  # A100
    HOPPER = 232_448  # H100
    DEFAULT = 102_400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        """
        Return the capacity for ``arch``, matched case-insensitively against
        the member names; unknown names fall back to ``DEFAULT``.
        """
        members = cls.__members__
        key = arch.upper()
        if key in members:
            return members[key].value
        return cls.DEFAULT.value

ADA class-attribute instance-attribute

ADA = 101_376  # RTX 4090

AMPERE class-attribute instance-attribute

AMPERE = 166_912  # A100

DEFAULT class-attribute instance-attribute

DEFAULT = 102_400  # fallback for unknown architectures

HOPPER class-attribute instance-attribute

HOPPER = 232_448  # H100

get_shared_memory classmethod

get_shared_memory(arch: str) -> int
Source code in vllm/model_executor/layers/fla/ops/utils.py
@classmethod
def get_shared_memory(cls, arch: str) -> int:
    """
    Return the value of the member named ``arch`` (case-insensitive),
    falling back to ``cls.DEFAULT`` when no such member exists.
    """
    members = cls.__members__
    key = arch.upper()
    if key in members:
        return members[key].value
    return cls.DEFAULT.value

_check_platform cached

_check_platform() -> Literal[
    "nvidia", "amd", "intel", "musa"
]
Source code in vllm/model_executor/layers/fla/ops/utils.py
@functools.cache
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    """
    Translate Triton's backend name into a vendor name.

    "cuda" -> "nvidia", "hip" -> "amd", "xpu" -> "intel"; any other backend
    string (e.g. "musa", or "cpu" from the fallback) is returned unchanged.
    """
    backend = get_available_device()
    if backend == "cuda":
        return "nvidia"
    if backend == "hip":
        return "amd"
    if backend == "xpu":
        return "intel"
    # Unmapped backends pass through as-is.
    return backend

check_shared_mem cached

check_shared_mem(
    arch: str = "none", tensor_idx: int = 0
) -> bool
Source code in vllm/model_executor/layers/fla/ops/utils.py
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """
    Report whether device ``tensor_idx`` has at least the shared memory
    listed for ``arch`` in ``Backend`` (unknown archs use ``DEFAULT``).

    Best effort: any exception (missing device index, failed query, ...)
    yields ``False``. Results are memoized per (arch, tensor_idx).
    """
    try:
        available = get_all_max_shared_mem()[tensor_idx]
        required = Backend.get_shared_memory(arch)
        return available >= required
    except Exception:
        return False

get_all_max_shared_mem

get_all_max_shared_mem()
Source code in vllm/model_executor/layers/fla/ops/utils.py
def get_all_max_shared_mem():
    """
    Return the ``max_shared_mem`` property reported by Triton for each device.

    One entry is produced per device counted by
    ``device_torch_lib.device_count()``, in device-index order. If the query
    fails, ``[-1]`` is returned so that callers comparing against a positive
    requirement treat the device as insufficient.

    Only ``Exception`` is caught: ``KeyboardInterrupt`` and ``SystemExit``
    must keep propagating instead of being silently turned into ``[-1]``.
    """
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)
            ['max_shared_mem'] for i in range(device_torch_lib.device_count())
        ]
    except Exception:
        return [-1]

get_available_device cached

get_available_device() -> str
Source code in vllm/model_executor/layers/fla/ops/utils.py
@functools.cache
def get_available_device() -> str:
    """
    Return the backend name of Triton's active driver target
    (e.g. "cuda", "hip", "xpu").

    Falls back to 'cpu' when the backend cannot be queried; the result,
    including the fallback, is memoized for the process lifetime. Only
    ``Exception`` is caught so ``KeyboardInterrupt``/``SystemExit`` still
    propagate.
    """
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except Exception:
        return 'cpu'

input_guard

input_guard(
    fn: Callable[..., Tensor],
) -> Callable[..., Tensor]

A decorator to make sure all input tensors are contiguous and set the device based on input tensors.

Source code in vllm/model_executor/layers/fla/ops/utils.py
def input_guard(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    Decorate ``fn`` so its tensor inputs are contiguous and it runs on their device.

    Every ``torch.Tensor`` argument (positional or keyword) is replaced by
    its ``.contiguous()`` version; other arguments are forwarded unchanged.
    The call executes inside ``torch.cuda.device(...)`` for the index of the
    first tensor found, scanning positional arguments before keyword ones.
    Without any tensor argument, no device context is entered.
    """

    def _make_contiguous(obj):
        # Non-tensor arguments are passed through untouched.
        if isinstance(obj, torch.Tensor):
            return obj.contiguous()
        return obj

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Keyword tensors are converted up front; positional tensors are
        # converted lazily (via map) while the call below unpacks them.
        new_kwargs = {
            name: _make_contiguous(val)
            for name, val in kwargs.items()
        }

        anchor = next(
            (obj for obj in (*args, *kwargs.values())
             if isinstance(obj, torch.Tensor)),
            None,
        )
        if anchor is None:
            ctx = contextlib.nullcontext()
        else:
            ctx = torch.cuda.device(anchor.device.index)

        with ctx:
            return fn(*map(_make_contiguous, args), **new_kwargs)

    return wrapper

tensor_cache

tensor_cache(
    fn: Callable[..., Tensor],
) -> Callable[..., Tensor]

A decorator that caches the most recent results of a function with tensor inputs.

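This decorator stores the outputs of the decorated function for the most recent sets of input arguments, matched by identity (`is`) rather than by value. The cache holds at most 4 entries (a fixed size); a cache hit marks its entry as most recently used, and when the cache is full the least recently used entry is removed.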

Parameters:

Name Type Description Default
fn Callable[..., Tensor]

The function to be decorated. It should take tensor inputs and return tensor outputs.

required

Returns:

Type Description
Callable[..., Tensor]

Callable[..., torch.Tensor]: A wrapped version of the input function with a small (4-entry) least-recently-used cache.

Source code in vllm/model_executor/layers/fla/ops/utils.py
def tensor_cache(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    Arguments are matched by identity (``is``), not by value: a cache hit
    requires the very same argument objects (and the same keyword names) as a
    previous call. At most 4 entries are kept. A hit moves its entry to the
    most-recently-used position; when a miss finds the cache full, the least
    recently used entry is evicted.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with a small (4-entry)
            least-recently-used cache.
    """

    # Each entry is (positional args, keyword args, result), oldest first.
    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries
        for i, (last_args, last_kwargs, last_result) in enumerate(
                cache_entries):
            if (len(args) == len(last_args)
                    and len(kwargs) == len(last_kwargs)
                    and all(a is b for a, b in zip(args, last_args))
                    and all(k in last_kwargs and v is last_kwargs[k]
                            for k, v in kwargs.items())):
                # Rebuild the list with the hit moved to the MRU end; a new
                # list is bound rather than mutating the one being iterated.
                cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
                    (args, kwargs, last_result)
                ]
                return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            # Evict the least recently used entry (front of the list).
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper