Skip to content

Conversation

@jcaip
Copy link
Contributor

@jcaip jcaip commented Sep 26, 2025

This PR adds in support for quantizing nn.Parameter to quantize_.

bc-breaking changes

The top level quantize_ API has the following bc-breaking changes:

  1. Passing in bothfilter_fn and ModuleFqnToConfig is no longer supported and will now throw a value error if both are specified. Previously, we would quantize all modules that were both matched by filter_fn and specified in ModuleFqnToConfig. Users should now manually specify filter_fn=None when using ModuleFqnToConfig/FqnToConfig.
  2. The semantics of filter_fn=None have changed. Previously passing in None would default to _is_linear when running quantize_. Now when filter_fn=None is specified we ignore filter_fn completely and only rely on FqnToConfig to quantize the model. Note that this is equivalent to passing in filter_fn=lambda mod, fqn: True in the previous API.
  3. The default filter_fn has changed from None to _is_linear and _default in ModuleFqnToConfig now only applies to linear layers. Previously _default would apply to all modules that passed filter_fn. We plan to deprecate _default in the future, please see Deprecation for _default in FqnToConfig #3229 for more details.

Before:

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), 
    torch.nn.Linear(128, 128), 
    torch.nn.Conv2d(128, 128, 3, 1, 1), 
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "0": Float8DynamicActivationFloat8WeightConfig(), 
})

# these are equivalent
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config, filter_fn=None)
quantize_(model, config)
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

After:

# user must specify None
quantize_(model, config, filter_fn=None)
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

After:

# these now error
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config)
> ValueError: Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified.

Example for _default changes:

Before:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.randn(128, 128)))

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), 
    torch.nn.Linear(128, 128), 
    MyModule(),
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(), 
})

quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) or isinstance(mod, MyModule))
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
)

After:

# only linear is applied for default
quantize_(model, config, filter_fn=None)
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule()
)

Summary

ModuleFqnToConfig has been renamed to FqnToConfig, which now accepts both module fqn and parameter fqns. ModuleFqnToConfig has been aliased to maintain BC. The keys to FqnToConfig can be one of the following (in order of precedence):

  1. exact parameter FQN
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  1. exact module FQN
quant_config = FqnToConfig({
    "linear1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  1. regex that matches parameter FQN (prepended by re:)
quant_config = FqnToConfig({
    "re:linear*.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  1. regex that matches module FQN (prepended by re:)
quant_config = FqnToConfig({
    "re:linear*": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  1. _default, only applies to nn.Linear layers
quant_config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})

To enable support for parameter fqn for a paticular config, we need to add the parameter_name kwarg into the config signature, and update CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS. See the changes here for more details.

Float8DynamicActivationFloat8WeightConfig has been enabled by this PR, but other configs will throw an NotImplementedError.

Test Plan

  1. unit tests for new config:
pytest test/quantization/test_quant_api.py::TestFqnToConfig
  1. regression test for ModuleFqnToConfig
pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config
  1. Make sure that we can load old HF checkpoints to maintain BC, run this

  2. Make sure that this doesn't break BC with transformers

pytest tests/quantization/torchao_integration/test_torchao.py -k test_module_fqn_to_config
  1. make sure that this doesn't break BC in VLLM:
pytest tests/quantization/test_torchao.py

How do our configs translate for MoEs?

Currently, we define a bunch of configs that are for dense nn.Linear modules, how do these configs translate in the case of MoE inference?

Some background on MoE inference

There are two ways that forwards is implemented for MoE

  • For loop of nn.Linear - In this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen here.

In this case, I argue that the semantics of the configs do not change at all from the normal nn.Linear case, as we are just doing a bunch of normal 2d linear matmuls.

  • bmm/grouped mm on the 3d weights / activations directly.

For this case, we'd need to add additional op support (bmm) for forwards. Depending on whether the subclass is an AQT subclass or non AQT subclass this will be added differently.

I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).

Below is a breakdown of what Configs map to AQT / non-AQT subclasses:

not using AQT AffineQuantizedTensor
Float8DynamicActivationFloat8WeightConfig FPXWeightOnlyConfig
Float8DynamicActivationInt4WeightConfig Float8WeightOnlyConfig
Float8StaticActivationFloat8WeightConfig Float8DynamicActivationFloat8SemiSparseWeightConfig
Int4WeightOnlyConfig (v2) GemliteUIntXWeightOnlyConfig
Int4DynamicActivationInt4WeightConfig
Int8DynamicActivationInt4WeightConfig
Int8DynamicActivationInt8WeightConfig
Int8WeightOnlyConfig
IntxWeightOnlyConfig
UIntXWeightOnlyConfig

For these the majority of the semantics remain the same, the only semantics that really changes is PerRow granularity. and there's a very natural extension of PerRow to the 3d case (apply on the last dimension).

I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.

Float8DynamicActivationFloat8WeightConfig

[('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity',
  typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('activation_value_lb', typing.Optional[float]),
 ('activation_value_ub', typing.Optional[float]),
 ('kernel_preference', <enum 'KernelPreference'>),
 ('set_inductor_config', <class 'bool'>),
 ('version', <class 'int'>)]

activation_dtype, weight_dtype, activation_value_lb, activation_value_ub all do not change meaning semantically.
granularity=PerTensor() does not change semantic meaning - we still use a single tensor to scale the entire weight tensor.
granularity=PerRow() does change meaning - we now calculate a scale for each row for the last dimension [-1] i.e for a weight of (E, N, K) we would expect PerRow to create scales of block size (1, 1, K).
mm_config kernel_preference and set_inductor_config stay the same as well.

Float8StaticActivationFloat8WeightConfig

[('scale', <class 'torch.Tensor'>),
 ('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity',
  typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('set_inductor_config', <class 'bool'>)]

scale should be passed in as a 3d tensor instead of a 2d tensor in the case of PerRow granularity

Float8DynamicActivationInt4WeightConfig

[('int4_packing_format', <enum 'Int4PackingFormat'>)]

int4_packing_format - Only "preshuffled" is supported and Int4PreshuffledTensor supports 3d weights.

Int4WeightOnlyConfig

[('group_size', <class 'int'>),
 ('layout',
  typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
 ('use_hqq', <class 'bool'>),
 ('zero_point_domain',
  typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
 ('set_inductor_config', <class 'bool'>),
 ('preserve_zero', typing.Optional[bool]),
 ('int4_packing_format', <enum 'Int4PackingFormat'>),
 ('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
 ('version', <class 'int'>)]

group_size, int4_packing_format, int4_choose_qparams_algorithm, set_inductor_config are the only things that are set for v2 config,

I don't think these semantics of these change, although there are some packing formats that do not support 3d weights. It looks like (Int4PackingFormat.PLAIN_INT32, Int4PackingFormat.MARLIN_SPARSE).

Summary:

This PR adds in a simple 2d and 3d moe implementation and tests
`quantize_` on them to see if we get the same results.

Test Plan:

```
pytest test/prototype/test_parameter.py -k test_quantize_parameter
```

Reviewers:

Subscribers:

Tasks:

Tags:
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3083

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6a6093e with merge base f3fc5e7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 26, 2025
@jcaip jcaip requested review from jerryzh168 and vkuzo September 26, 2025 21:00
@jerryzh168
Copy link
Contributor

current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?

@vkuzo
Copy link
Contributor

vkuzo commented Sep 29, 2025

Add in ParamFqnToConfig config
This new config is very similar to ModuleFqnToConfig except it takes in nn.Parameter FQNs and also supports regexs.

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object? Pseudocode of what it could do:

def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ... new behavior for parameters, apply parameter swap config ...
    elif has_parameter(model, fqn + '.weight'):
        ... old behavior, apply parameter swap config ...
    elif has_module(model, fqn):
        ... old behavior, apply module swap ...

@jcaip
Copy link
Contributor Author

jcaip commented Sep 29, 2025

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object?

Yeah, we can do this. Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

@jcaip
Copy link
Contributor Author

jcaip commented Sep 29, 2025

current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?

Yes I believe so, especially in the case of the Config object itself. We attach everything to the weight parameter for nn.Linear, so this allows us to specify the parameter name instead of assuming it's "weight".

The only thing that does not map cleanly IMO is the module_registration:

        # non user facing code
        @register_quantize_module_handler(WorkflowFooConfig)
        def _transform(
            mod: torch.nn.Module,
            config: WorkflowFooConfig,
        ) -> torch.nn.Module:
            # the transform is implemented here, usually a tensor sublass
            # weight swap or a module swap

I think we should define the transform for parameters as the base case (aka @register_quantize_handler) , and use that for the module flow (assuming the parameter is module.weight), since it's the more general case.

@vkuzo
Copy link
Contributor

vkuzo commented Sep 29, 2025

Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

IMO we should change the current name and keep the old name for BC:

ParamOrModuleFqnToConfig = ...

# for bc
ModuleFqnToConfig = ParamOrModuleFqnToConfig

@vkuzo
Copy link
Contributor

vkuzo commented Sep 29, 2025

I think we should define the transform for parameters as the base case

To me it seems that the transform has to be for modules, because it is inplace. User can target a parameter if they want to, but the transform function always runs on a module that owns the parameter.

# skip if not direct child
if "." not in name:
for pattern in config.param_fqn_to_config:
if re.match(pattern, f"{fqn}.{name}"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it applies to all params, regardless of what it is? e.g. bias? should we be more specific in what people are configuring?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should consider the regex syntax separately, I can remove from this PR.

One thing I would like would be for quantize_ log the modules/params it's swapping so it's easy to see what the difference is.

@andrewor14
Copy link
Contributor

Does this mean we need to refactor all supported configs to use this structure?

@register_quantized_param_handler(config)
def _float8_dynamic_activation_float8_weight_quantize_tensor(...):
    # returns quantized tensor

def _float8_dynamic_activation_float8_weight_transform(...):
    module.weight = _float8_dynamic_activation_float8_weight_quantize_tensor(...)
    return module

# Create config with unsupported parameter handler
quant_config = FqnToConfig(
{
"0.weight": Float8DynamicActivationFloat8WeightConfig(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check the precendence of "0" v.s. "0.param" as well? I think "0.param" should have precendence over "0"

Config key ordered by precedence:
* fully qualified module name, e.g. `language.layers.0.q_proj`
* fully qualified parameter name, e.g. `language.layers.0.q_proj.weight`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this have higher precedence over the first one?

whichever regex fully matches the module fqn first will be applied
(order of keys for dictionary are kept consistent since we are using OrderedDict)
* "_default", fallback for **all modules** if no match for all previous keys
* regex for parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj.weight`.
Copy link
Contributor

@jerryzh168 jerryzh168 Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this one should also have higher precedence over the regex for modules, although not sure if it's easy to implement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we can just say param fqn > module_fqn in general. I think it should be easy to change this ontop of the recent changes.

# TODO we plan to deprecate `_default later, so raise a warning if we find it passed in`
if "_default" in self.fqn_to_config:
warnings.warn(
"Config Deprecation: _default is deprecated and will no longer be supported in a future release."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can create an issue to track, similar to #2948

@jerryzh168
Copy link
Contributor

jerryzh168 commented Oct 20, 2025

also bc-breaking changes is still not clear to me, can you list one note for each API call?

Before:

API 1

After:

API1 meaning change or updated way to have the same behavior of API1

?

@jcaip
Copy link
Contributor Author

jcaip commented Oct 20, 2025

also bc-breaking changes is still not clear to me, can you list one note for each API call?

Before:

API 1

After:

API1 meaning change or updated way to have the same behavior of API1

?

Yeah will do, sorry I'm still resolving some some bugs with the vllm integration so haven't gotten around to it.

assert hasattr(module, "weight"), (
"applying int8 weight only quant requires module to have weight attribute"
assert hasattr(module, parameter_name), (
"applying float8 weight only quant requires module to have {parametr_name} attribute"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: fix typo, parameter_name instead of parametr_name


# for now, we need to keep track of what configs support custom param quantization.
# Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check
CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS to CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS

also, can we add a TODO to migrate the rest of the callsites, and make sure it has an owner and a timeline

model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
quantize_(model, config)
quantize_(model, config, filter_fn=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the user is using AOBaseConfig which is not an instance of FQNToConfig, then there is no BC breaking change, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, filter_fn=None will default to _is_linear when it's not an instance of FQNToConfig

Copy link
Contributor

@vkuzo vkuzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for working on this!

@jcaip jcaip merged commit 50a555b into main Oct 28, 2025
21 checks passed
namgyu-youn pushed a commit to namgyu-youn/ao that referenced this pull request Nov 21, 2025
This PR adds in support for quantizing `nn.Parameter` to `quantize_`. 

### bc-breaking changes

The top level `quantize_` API has the following bc-breaking changes:

1) Passing in both`filter_fn` and `ModuleFqnToConfig` is no longer supported and will now throw a value error if both are specified. Previously, we would quantize all modules that were both matched by `filter_fn` and specified in `ModuleFqnToConfig`. Users should now manually specify `filter_fn=None` when using `ModuleFqnToConfig`/`FqnToConfig`. 
2) The semantics of `filter_fn=None` have changed. Previously passing in `None` would default to `_is_linear` when running `quantize_`. Now when `filter_fn=None` is specified we ignore `filter_fn` completely and only rely on `FqnToConfig` to quantize the model. Note that this is equivalent to passing in `filter_fn=lambda mod, fqn: True` in the previous API. 
3) The default `filter_fn` has changed from `None` to `_is_linear` and `_default` in `ModuleFqnToConfig` now only applies to linear layers. Previously `_default` would apply to all modules that passed `filter_fn`. We plan to deprecate `_default` in the future, please see pytorch#3229 for more details. 

Before:
```python
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), 
    torch.nn.Linear(128, 128), 
    torch.nn.Conv2d(128, 128, 3, 1, 1), 
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "0": Float8DynamicActivationFloat8WeightConfig(), 
})

# these are equivalent
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config, filter_fn=None)
quantize_(model, config)
```
```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```
After:
```python
# user must specify None
quantize_(model, config, filter_fn=None)
```
```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```
After:
```python
# these now error
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config)
```
```
> ValueError: Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified.
```

#### Example for _default changes:
Before:
```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.randn(128, 128)))

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), 
    torch.nn.Linear(128, 128), 
    MyModule(),
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(), 
})

quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) or isinstance(mod, MyModule))
```
```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
)
```

After:
```python
# only linear is applied for default
quantize_(model, config, filter_fn=None)
```
```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule()
)
```


### Summary

`ModuleFqnToConfig` has been renamed to `FqnToConfig`, which now accepts both module fqn and parameter fqns. `ModuleFqnToConfig` has been aliased to maintain BC.  The keys to `FqnToConfig` can be one of the following (in order of precedence):

1) exact parameter FQN
```python
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
2) exact module FQN 
```python
quant_config = FqnToConfig({
    "linear1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
3) regex that matches parameter FQN (prepended by `re:`)
```python
quant_config = FqnToConfig({
    "re:linear*.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
4) regex that matches module FQN (prepended by `re:`)
```python
quant_config = FqnToConfig({
    "re:linear*": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
5) _default, only applies to `nn.Linear` layers
```python
quant_config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

To enable support for parameter fqn for a paticular config, we need to add the `parameter_name` kwarg into the config signature, and update `CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS`. See the changes [here](https://github.com/pytorch/ao/pull/3083/files#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726R1874) for more details. 

`Float8DynamicActivationFloat8WeightConfig` has been enabled by this PR, but other configs will throw an `NotImplementedError`. 


### Test Plan

1) unit tests for new config:
```
pytest test/quantization/test_quant_api.py::TestFqnToConfig
```

2) regression test for ModuleFqnToConfig
```
pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config
```
3) Make sure that we can load old HF checkpoints to maintain BC, run [this](https://huggingface.co/torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev#test-loading)

4) Make sure that this doesn't break BC with transformers
```
pytest tests/quantization/torchao_integration/test_torchao.py -k test_module_fqn_to_config
```

5) make sure that this doesn't break BC in VLLM:
```
pytest tests/quantization/test_torchao.py
```
___

## How do our configs translate for MoEs?

Currently, we define a bunch of configs that are for dense nn.Linear modules, how do these configs translate in the case of MoE inference? 

### Some background on MoE inference
There are two ways that forwards is implemented for MoE

- For loop of `nn.Linear` - In this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen [here](https://github.com/huggingface/transformers/blob/6cade29278c4aee3f174f8950f97a3873bdb212f/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L117). 

**In this case, I argue that the semantics of the configs do not change at all from the normal `nn.Linear` case, as we are just doing a bunch of normal 2d linear matmuls.** 

- bmm/grouped mm on the 3d weights / activations directly. 

**For this case, we'd need to add additional op support (bmm) for forwards. Depending on whether the subclass is an AQT subclass or non AQT subclass this will be added differently.**

I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly). 

Below is a breakdown of what Configs map to AQT / non-AQT subclasses:
| not using AQT | AffineQuantizedTensor |
|-----------|---------------|
| Float8DynamicActivationFloat8WeightConfig | FPXWeightOnlyConfig |
| Float8DynamicActivationInt4WeightConfig | Float8WeightOnlyConfig |
| Float8StaticActivationFloat8WeightConfig | Float8DynamicActivationFloat8SemiSparseWeightConfig |
| Int4WeightOnlyConfig (v2) | GemliteUIntXWeightOnlyConfig |
|  | Int4DynamicActivationInt4WeightConfig |
|  | Int8DynamicActivationInt4WeightConfig |
|  | Int8DynamicActivationInt8WeightConfig |
|  | Int8WeightOnlyConfig |
|  | IntxWeightOnlyConfig |
|  | UIntXWeightOnlyConfig |

For these the majority of the semantics remain the same, the only semantics that really changes is `PerRow` granularity. and there's a very natural extension of `PerRow` to the 3d case (apply on the last dimension). 

I took a look at the keys of the non-AQT configs below and what they would mean for MoEs. 

#### Float8DynamicActivationFloat8WeightConfig
```
[('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity',
  typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('activation_value_lb', typing.Optional[float]),
 ('activation_value_ub', typing.Optional[float]),
 ('kernel_preference', <enum 'KernelPreference'>),
 ('set_inductor_config', <class 'bool'>),
 ('version', <class 'int'>)]
```

`activation_dtype`, `weight_dtype`, `activation_value_lb`, `activation_value_ub` all do not change meaning semantically. 
`granularity=PerTensor()` does not change semantic meaning - we still use a single tensor to scale the entire weight tensor. 
`granularity=PerRow()` does change meaning - we now calculate a scale for each row for the last dimension [-1] i.e for a weight of (E, N, K) we would expect PerRow to create scales of block size (1, 1, K). 
`mm_config` `kernel_preference` and `set_inductor_config` stay the same as well. 

#### Float8StaticActivationFloat8WeightConfig
```
[('scale', <class 'torch.Tensor'>),
 ('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity',
  typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('set_inductor_config', <class 'bool'>)]
 ```
`scale` should be passed in as a 3d tensor instead of a 2d tensor in the case of `PerRow` granularity

#### Float8DynamicActivationInt4WeightConfig
```
[('int4_packing_format', <enum 'Int4PackingFormat'>)]
```

int4_packing_format - Only "preshuffled" is supported and Int4PreshuffledTensor [supports](https://github.com/pytorch/ao/blob/895573980e085b02a2c6abbc82239bae7f1318d6/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py#L154) 3d weights. 

#### Int4WeightOnlyConfig
```
[('group_size', <class 'int'>),
 ('layout',
  typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
 ('use_hqq', <class 'bool'>),
 ('zero_point_domain',
  typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
 ('set_inductor_config', <class 'bool'>),
 ('preserve_zero', typing.Optional[bool]),
 ('int4_packing_format', <enum 'Int4PackingFormat'>),
 ('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
 ('version', <class 'int'>)]
 ```

`group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm`, `set_inductor_config` are the only things that are set for v2 config, 

I don't think these semantics of these change, although there are some packing formats that do not support 3d weights.  It looks like (`Int4PackingFormat.PLAIN_INT32`,  `Int4PackingFormat.MARLIN_SPARSE`).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: bc-breaking Use this tag if this PR breaks backward compatibility topic: new feature Use this tag if this PR adds a new feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants