vllm.model_executor.layers.quantization.utils.gptq_utils

get_dynamic_override

get_dynamic_override(
    config: Union[GPTQConfig, GPTQMarlinConfig],
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[dict, int, bool, None]
Source code in vllm/model_executor/layers/quantization/utils/gptq_utils.py
def get_dynamic_override(
    config: Union[GPTQConfig, GPTQMarlinConfig],
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool,
                         None] = None) -> Union[dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules' quant properties override
        # the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
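
Example (sketch): the function scans config.dynamic in order. A pattern prefixed with "-:" excludes matching modules (returns False), a pattern prefixed with "+:" (or unprefixed) supplies per-module overrides, and anything else falls back to default_value. The stand-in config below is hypothetical; only its .dynamic mapping is read here, and the example patterns are illustrative.

from types import SimpleNamespace

from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_dynamic_override)

# Hypothetical stand-in: get_dynamic_override only reads `.dynamic`.
config = SimpleNamespace(dynamic={
    # "-:" pattern: exclude matching modules from quantized init
    r"-:.*\.lm_head$": {},
    # "+:" pattern: per-module overrides of the base quant config
    r"+:.*\.down_proj$": {"bits": 8, "group_size": 64},
})

print(get_dynamic_override(config, "model.lm_head"))                  # False
print(get_dynamic_override(config, "model.layers.0.mlp.down_proj"))
# {'bits': 8, 'group_size': 64}
print(get_dynamic_override(config, "model.layers.0.mlp.down_proj",
                           key="bits", default_value=4))              # 8
print(get_dynamic_override(config, "model.layers.0.self_attn.q_proj",
                           key="bits", default_value=4))              # 4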

get_linear_quant_method

get_linear_quant_method(
    config: Union[GPTQConfig, GPTQMarlinConfig],
    layer: Module,
    prefix: str,
    linear_method_cls: type,
)
Source code in vllm/model_executor/layers/quantization/utils/gptq_utils.py
def get_linear_quant_method(
    config: Union[GPTQConfig, GPTQMarlinConfig],
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = isinstance(
        layer, ParallelLMHead) and cloned_config.lm_head_quantized
    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        is_layer_quantized = is_layer_gptq_quantized(
            prefix=prefix,
            quantized_layers=cloned_config.modules_in_block_to_quantize,
            fused_mapping=cloned_config.packed_modules_mapping)
        # False = skip module, None = no override, else = positive match
        if (get_dynamic_override(cloned_config, layer_name=prefix) is False
                or not is_layer_quantized):
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None
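
Example (sketch): a typical call site is a GPTQ-style config's get_quant_method hook delegating to this helper. The GPTQConfig / GPTQLinearMethod names and import path below are assumptions based on vLLM's layout and may differ between versions.

from typing import Optional

import torch

from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
                                                           GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)


def get_quant_method_for(config: GPTQConfig, layer: torch.nn.Module,
                         prefix: str) -> Optional[object]:
    # LinearBase layers (and a quantized ParallelLMHead) get either
    # GPTQLinearMethod(cloned_config) or an Unquantized*Method when a "-:"
    # dynamic rule or modules_in_block_to_quantize excludes them; any other
    # module type gets None.
    return get_linear_quant_method(config, layer, prefix,
                                   linear_method_cls=GPTQLinearMethod)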

is_layer_gptq_quantized

is_layer_gptq_quantized(
    prefix: str,
    quantized_layers: list[str],
    fused_mapping: Mapping[
        str, list[str]
    ] = MappingProxyType({}),
) -> bool
Source code in vllm/model_executor/layers/quantization/utils/gptq_utils.py
def is_layer_gptq_quantized(
    prefix: str,
    quantized_layers: list[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj

    # GPTQ's `modules_in_block_to_quantize`:
    # Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
    # Full prefix ["model.layers.0.self_attn.q_proj"]

    proj_name = prefix.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in fused_mapping[proj_name]
        ]

        is_quantized = None
        for shard_prefix in shard_prefixes:
            is_shard_quantized = any(layer in shard_prefix
                                     for layer in quantized_layers)

            if is_quantized is None:
                is_quantized = is_shard_quantized
            elif is_shard_quantized != is_quantized:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "to have the same precision.")
    else:
        is_quantized = any(layer in prefix for layer in quantized_layers)

    assert is_quantized is not None
    return is_quantized
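
Example (sketch): a minimal call, assuming the usual fused-layer mapping (qkv_proj expands to q_proj / k_proj / v_proj) and substring-style entries in modules_in_block_to_quantize.

from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    is_layer_gptq_quantized)

fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
quantized_layers = ["self_attn.q_proj", "self_attn.k_proj",
                    "self_attn.v_proj"]

# The fused qkv_proj prefix is expanded into its three shards; every shard
# matches an entry, so the fused layer counts as quantized.
print(is_layer_gptq_quantized("model.layers.0.self_attn.qkv_proj",
                              quantized_layers, fused_mapping))    # True

# A prefix that matches no entry is not quantized.
print(is_layer_gptq_quantized("model.layers.0.mlp.down_proj",
                              quantized_layers, fused_mapping))    # False

# If only some shards of a fused layer were listed, a ValueError is raised.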

override_config

override_config(
    config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str
)
Source code in vllm/model_executor/layers/quantization/utils/gptq_utils.py
def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits",
                                       config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size",
                                      config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act",
                                    config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = Fraction(32, config.weight_bits)  # packed into int32
    if config.get_name() == "gptq_marlin":
        assert isinstance(config, GPTQMarlinConfig)
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={config.weight_bits}, sym={config.is_sym}")

        config.quant_type = config.TYPE_MAP[(config.weight_bits,
                                             config.is_sym)]
    elif config.get_name() == "gptq":
        assert isinstance(config, GPTQConfig)
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits.")