diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 7459b2504c..8dd6410597 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -69,6 +69,7 @@
     float8_static_activation_float8_weight,
     float8_weight_only,
     fpx_weight_only,
+    fqn_matches_fqn_config,
     gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
@@ -142,6 +143,7 @@
     "float8_static_activation_float8_weight",
     "uintx_weight_only",
     "fpx_weight_only",
+    "fqn_matches_fqn_config",
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ddeb8c7ca6..85be1ebafd 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -480,7 +480,7 @@ def quantize_(

     for module_fqn, module in model.named_modules():
         if (
-            _fqn_matches_fqn_config(module_fqn, config)
+            fqn_matches_fqn_config(module_fqn, config)
             or _module_param_matches_fqn_config(module, module_fqn, config)
             or ("_default" in config.fqn_to_config and _is_linear(module))
         ):
@@ -488,7 +488,9 @@
                 module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
             )
             # this replaces inplace, so no need to reassign
-            _fqn_to_config_handler(module, module_name, config, device)
+            _fqn_to_config_handler(module, module_name, config)
+            if device is not None:
+                module.to(device=device)
         return
     if isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -2470,7 +2472,6 @@
 def _fqn_to_config_handler(
     module: torch.nn.Module,
     fqn: str,
     config: FqnToConfig,
-    device: Optional[torch.device] = None,
 ):
     """This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig.
@@ -2479,7 +2480,6 @@
         fqn (str): The fully qualified name of the module containing the parameters.
         config (FqnToConfig): Configuration object containing regex patterns / fqn mapped to quantization configurations.
-        device (Optional[torch.device]): The device to move the module to as part of quantization

     Returns:
         torch.nn.Module: The modified module with quantized parameters.

@@ -2487,9 +2487,6 @@
     Raises:
         NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
     """
-    if device is not None:
-        module = module.to(device)
-
     parameter_config_found = False
     top_level_params = []
     for i, (parameter_name, param) in enumerate(list(module.named_parameters())):
@@ -2563,7 +2560,7 @@
     return module


-def _fqn_matches_fqn_config(
+def fqn_matches_fqn_config(
     fqn: str,
     config: FqnToConfig,
 ):
@@ -2608,7 +2605,7 @@
     for name, param in module.named_parameters():
         if name in dir(module):
             parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name
-            if _fqn_matches_fqn_config(parameter_fqn, config):
+            if fqn_matches_fqn_config(parameter_fqn, config):
                 return True
     return False
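
The rename promotes `fqn_matches_fqn_config` to the public `torchao.quantization` surface. As a quick illustration for reviewers, here is a minimal usage sketch (not code from this PR): it assumes `FqnToConfig` accepts an `fqn_to_config` dict (the attribute name appears in the diff above) keyed by exact FQNs or `re:`-prefixed regex patterns; the model FQNs and the `Int4WeightOnlyConfig` pairing are illustrative only.

```python
# Minimal usage sketch (not part of this PR). The "re:" regex-key
# convention and the example FQNs are assumptions for illustration.
from torchao.quantization import (
    FqnToConfig,
    Int4WeightOnlyConfig,
    fqn_matches_fqn_config,
)

config = FqnToConfig(
    fqn_to_config={
        # exact module FQN
        "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(),
        # regex key (assumed "re:" prefix convention)
        r"re:model\.layers\.\d+\.mlp\..*": Int4WeightOnlyConfig(),
    }
)

print(fqn_matches_fqn_config("model.layers.0.self_attn.q_proj", config))  # expect True (exact key)
print(fqn_matches_fqn_config("model.layers.3.mlp.gate_proj", config))     # expect True (regex key)
print(fqn_matches_fqn_config("lm_head", config))                          # expect False (no key matches)
```

Separately, note the small behavioral refactor: the optional device move now happens in `quantize_` after `_fqn_to_config_handler` returns, so the handler no longer takes a `device` argument. Since `nn.Module.to()` modifies the module in place, `module.to(device=device)` needs no reassignment.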