Add quantize_ nn.Parameter support #3083
Conversation
Summary: This PR adds a simple 2d and 3d MoE implementation and tests `quantize_` on them to check that we get the same results.

Test Plan:
```
pytest test/prototype/test_parameter.py -k test_quantize_parameter
```
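For context, a rough sketch of the kind of toy modules such a test might compare (names and shapes here are hypothetical, not the actual contents of `test/prototype/test_parameter.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE2d(nn.Module):
    """Experts stored as separate 2d weights; forward is a loop of 2d matmuls."""
    def __init__(self, num_experts: int, dim: int):
        super().__init__()
        self.experts = nn.ParameterList(
            [nn.Parameter(torch.randn(dim, dim)) for _ in range(num_experts)]
        )

    def forward(self, x):  # x: (num_experts, tokens, dim)
        return torch.stack([F.linear(x[i], w) for i, w in enumerate(self.experts)])

class MoE3d(nn.Module):
    """Experts stored as a single 3d nn.Parameter; forward is a bmm over all experts."""
    def __init__(self, num_experts: int, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_experts, dim, dim))

    def forward(self, x):  # x: (num_experts, tokens, dim)
        return torch.bmm(x, self.weight.transpose(-2, -1))
```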
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3083
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6a6093e with merge base f3fc5e7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The current `AOBaseConfig` is more for linear weights; can it be extended to param config cleanly?
Would it work to stick with:

```python
def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ...  # new behavior for parameters, apply parameter swap config
    elif has_parameter(model, fqn + '.weight'):
        ...  # old behavior, apply parameter swap config
    elif has_module(model, fqn):
        ...  # old behavior, apply module swap
```
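For reference, `has_parameter` / `has_module` above are assumed helpers; a minimal sketch of how they could be implemented on top of standard `nn.Module` APIs:

```python
import torch.nn as nn

def has_parameter(model: nn.Module, fqn: str) -> bool:
    # True if `fqn` names a parameter, e.g. "layers.0.q_proj.weight".
    try:
        model.get_parameter(fqn)
        return True
    except AttributeError:
        return False

def has_module(model: nn.Module, fqn: str) -> bool:
    # True if `fqn` names a submodule, e.g. "layers.0.q_proj".
    try:
        model.get_submodule(fqn)
        return True
    except AttributeError:
        return False
```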
Yeah, we can do this. Do you think we should keep the …?
Yes, I believe so, especially in the case of the Config object itself. We attach everything to the weight parameter for nn.Linear, so this allows us to specify the parameter name instead of assuming it's "weight". The only thing that does not map cleanly IMO is the …. I think we should define the transform for parameters as the base case (aka …).
IMO we should change the current name and keep the old name for BC:

```python
ParamOrModuleFqnToConfig = ...

# for BC
ModuleFqnToConfig = ParamOrModuleFqnToConfig
```
To me it seems that the transform has to be for modules, because it is in place. The user can target a parameter if they want to, but the transform function always runs on a module that owns the parameter.
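A minimal sketch of that shape of transform, assuming the `parameter_name` kwarg this PR adds (the `_to_quantized` helper below is a placeholder, not torchao's actual conversion):

```python
import torch
import torch.nn as nn

def _to_quantized(t: torch.Tensor, config) -> torch.Tensor:
    # Placeholder for the real conversion to a quantized tensor subclass (e.g. Float8Tensor).
    return t

def _transform(module: nn.Module, config, *, parameter_name: str = "weight") -> nn.Module:
    # The transform always receives the module that owns the parameter, even when the
    # user targeted the parameter by fqn, because the swap happens in place via setattr.
    old_param = getattr(module, parameter_name)
    new_param = nn.Parameter(_to_quantized(old_param.data, config), requires_grad=False)
    setattr(module, parameter_name, new_param)
    return module
```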
Force-pushed from fe12f23 to 7c5ab04.
torchao/quantization/quant_api.py (outdated)

```python
# skip if not direct child
if "." not in name:
    for pattern in config.param_fqn_to_config:
        if re.match(pattern, f"{fqn}.{name}"):
```
so it applies to all params, regardless of what it is? e.g. bias? should we be more specific in what people are configuring?
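For illustration, a broad pattern would hit biases too, while anchoring the suffix restricts it to weights (hypothetical patterns, written against the `re:` key convention in this PR; the sketch uses `re.fullmatch` directly just to show what each pattern matches):

```python
import re

params = ["experts.0.weight", "experts.0.bias", "experts.1.weight"]

loose = r"experts\..*"           # as an `re:` key, this would match biases as well
strict = r"experts\..*\.weight"  # anchoring on .weight keeps it to weight parameters

print([p for p in params if re.fullmatch(loose, p)])   # all three
print([p for p in params if re.fullmatch(strict, p)])  # only the two weights
```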
I think we should consider the regex syntax separately; I can remove it from this PR.

One thing I would like is for `quantize_` to log the modules/params it's swapping, so it's easy to see what the difference is.
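Not part of this PR, but as a rough sketch of how that visibility could be approximated today by diffing parameter types around the call (the helper itself is hypothetical):

```python
import torch.nn as nn

def log_swapped_params(model: nn.Module, apply_quantize) -> None:
    # Snapshot parameter types before the swap, apply quantization, then report what changed.
    before = {name: type(p).__name__ for name, p in model.named_parameters()}
    apply_quantize(model)
    for name, p in model.named_parameters():
        if type(p).__name__ != before.get(name):
            print(f"swapped {name}: {before.get(name)} -> {type(p).__name__}")

# e.g. log_swapped_params(model, lambda m: quantize_(m, config, filter_fn=None))
```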
Does this mean we need to refactor all supported configs to use this structure?
```python
# Create config with unsupported parameter handler
quant_config = FqnToConfig(
    {
        "0.weight": Float8DynamicActivationFloat8WeightConfig(
```
should we check the precedence of "0" vs. "0.param" as well? I think "0.param" should have precedence over "0"
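For example, with both keys present (import paths assumed; `FqnToConfig` is the class this PR introduces), the exact parameter fqn `"0.weight"` should win over the module fqn `"0"` for that weight:

```python
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, PerTensor
from torchao.quantization.quant_api import FqnToConfig  # new in this PR

quant_config = FqnToConfig({
    # exact parameter fqn: should take precedence for module "0"'s weight
    "0.weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    # exact module fqn: applies only if no parameter-level key matches
    "0": Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
})
```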
```
Config key ordered by precedence:
    * fully qualified module name, e.g. `language.layers.0.q_proj`
    * fully qualified parameter name, e.g. `language.layers.0.q_proj.weight`
```
should this have higher precedence than the first one?
torchao/quantization/quant_api.py (outdated)

```
whichever regex fully matches the module fqn first will be applied
(order of keys for dictionary are kept consistent since we are using OrderedDict)
* "_default", fallback for **all modules** if no match for all previous keys
* regex for parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj.weight`.
```
maybe this one should also have higher precedence than the regex for modules, although I'm not sure if it's easy to implement
Yeah, we can just say param fqn > module fqn in general. I think it should be easy to change this on top of the recent changes.
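A minimal sketch of what that resolution order could look like (illustrative only; helper names are hypothetical, the `_default` nn.Linear restriction is omitted, and the real implementation in `quant_api.py` may differ):

```python
import re

def resolve_config(fqn_to_config: dict, module_fqn: str, param_fqn: str):
    # param fqn > module fqn, exact matches > regex matches, then "_default".
    if param_fqn in fqn_to_config:
        return fqn_to_config[param_fqn]
    if module_fqn in fqn_to_config:
        return fqn_to_config[module_fqn]
    for key, config in fqn_to_config.items():
        if key.startswith("re:") and re.fullmatch(key[3:], param_fqn):
            return config
    for key, config in fqn_to_config.items():
        if key.startswith("re:") and re.fullmatch(key[3:], module_fqn):
            return config
    return fqn_to_config.get("_default")
```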
torchao/quantization/quant_api.py (outdated)

```python
# TODO we plan to deprecate `_default` later, so raise a warning if we find it passed in
if "_default" in self.fqn_to_config:
    warnings.warn(
        "Config Deprecation: _default is deprecated and will no longer be supported in a future release."
```
nit: we can create an issue to track, similar to #2948
Also, Before: / After: ?
Yeah, will do. Sorry, I'm still resolving some bugs with the vLLM integration so I haven't gotten around to it.
torchao/quantization/quant_api.py (outdated)

```diff
-        assert hasattr(module, "weight"), (
-            "applying int8 weight only quant requires module to have weight attribute"
+        assert hasattr(module, parameter_name), (
+            "applying float8 weight only quant requires module to have {parametr_name} attribute"
```
nit: fix typo, parameter_name instead of parametr_name
torchao/quantization/quant_api.py (outdated)

```python
# for now, we need to keep track of what configs support custom param quantization.
# Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check
CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS = {
```
nit: CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS to CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS
also, can we add a TODO to migrate the rest of the callsites, and make sure it has an owner and a timeline
```diff
     model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
     example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
-    quantize_(model, config)
+    quantize_(model, config, filter_fn=None)
```
if the user is using AOBaseConfig which is not an instance of FQNToConfig, then there is no BC breaking change, right?
Yeah, `filter_fn=None` will default to `_is_linear` when the config is not an instance of `FqnToConfig`.
vkuzo left a comment:
thanks for working on this!
This PR adds support for quantizing `nn.Parameter` with `quantize_`.

### bc-breaking changes

The top level `quantize_` API has the following bc-breaking changes:

1) Passing in both `filter_fn` and `ModuleFqnToConfig` is no longer supported and will now throw a `ValueError` if both are specified. Previously, we would quantize all modules that were both matched by `filter_fn` and specified in `ModuleFqnToConfig`. Users should now manually specify `filter_fn=None` when using `ModuleFqnToConfig`/`FqnToConfig`.
2) The semantics of `filter_fn=None` have changed. Previously, passing in `None` would default to `_is_linear` when running `quantize_`. Now when `filter_fn=None` is specified we ignore `filter_fn` completely and rely only on `FqnToConfig` to quantize the model. Note that this is equivalent to passing in `filter_fn=lambda mod, fqn: True` in the previous API.
3) The default `filter_fn` has changed from `None` to `_is_linear`, and `_default` in `ModuleFqnToConfig` now only applies to linear layers. Previously `_default` would apply to all modules that passed `filter_fn`. We plan to deprecate `_default` in the future; please see pytorch#3229 (Deprecation for `_default` in FqnToConfig) for more details.

Before:

```python
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
    torch.nn.Conv2d(128, 128, 3, 1, 1),
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "0": Float8DynamicActivationFloat8WeightConfig(),
})

# these are equivalent
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config, filter_fn=None)
quantize_(model, config)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```

After:

```python
# user must specify None
quantize_(model, config, filter_fn=None)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, bias=True)
  (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```

After:

```python
# these now error
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config)
```

```
> ValueError: Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified.
```
#### Example for _default changes:

Before:

```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.randn(128, 128)))

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
    MyModule(),
).cuda().to(torch.bfloat16)

config = ModuleFqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(),
})

quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) or isinstance(mod, MyModule))
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
)
```

After:

```python
# only linear is applied for default
quantize_(model, config, filter_fn=None)
```

```
> Sequential(
  (0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
  (2): MyModule()
)
```

### Summary

`ModuleFqnToConfig` has been renamed to `FqnToConfig`, which now accepts both module fqns and parameter fqns. `ModuleFqnToConfig` has been aliased to maintain BC.
The keys to `FqnToConfig` can be one of the following (in order of precedence):

1) exact parameter FQN

```python
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

2) exact module FQN

```python
quant_config = FqnToConfig({
    "linear1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

3) regex that matches a parameter FQN (prefixed with `re:`)

```python
quant_config = FqnToConfig({
    "re:linear*.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

4) regex that matches a module FQN (prefixed with `re:`)

```python
quant_config = FqnToConfig({
    "re:linear*": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

5) `_default`, which only applies to `nn.Linear` layers

```python
quant_config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

To enable support for parameter fqns for a particular config, we need to add the `parameter_name` kwarg into the config signature and update `CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS`. See the changes [here](https://github.com/pytorch/ao/pull/3083/files#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726R1874) for more details. `Float8DynamicActivationFloat8WeightConfig` has been enabled by this PR, but other configs will throw a `NotImplementedError`.

### Test Plan

1) unit tests for the new config:
```
pytest test/quantization/test_quant_api.py::TestFqnToConfig
```
2) regression test for ModuleFqnToConfig:
```
pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config
```
3) Make sure that we can load old HF checkpoints to maintain BC, run [this](https://huggingface.co/torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev#test-loading)
4) Make sure that this doesn't break BC with transformers:
```
pytest tests/quantization/torchao_integration/test_torchao.py -k test_module_fqn_to_config
```
5) Make sure that this doesn't break BC in vLLM:
```
pytest tests/quantization/test_torchao.py
```

___

## How do our configs translate for MoEs?

Currently, we define a bunch of configs that are for dense nn.Linear modules; how do these configs translate in the case of MoE inference?

### Some background on MoE inference

There are two ways that forward is implemented for MoE:

- For loop of `nn.Linear` - in this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen [here](https://github.com/huggingface/transformers/blob/6cade29278c4aee3f174f8950f97a3873bdb212f/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L117). **In this case, I argue that the semantics of the configs do not change at all from the normal `nn.Linear` case, as we are just doing a bunch of normal 2d linear matmuls.**
- bmm/grouped mm on the 3d weights / activations directly. **For this case, we'd need to add additional op support (bmm) for forward. Depending on whether the subclass is an AQT subclass or a non-AQT subclass, this will be added differently.** (See the sketch below.)

I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).
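To make the two forward styles above concrete, a rough sketch in plain PyTorch (hypothetical shapes; the router/token dispatch that a real MoE does first is omitted):

```python
import torch

E, N, K, T = 8, 256, 128, 16          # experts, out_features, in_features, tokens per expert
w = torch.randn(E, N, K)              # 3d expert weights
x = torch.randn(E, T, K)              # activations already grouped per expert

# 1) for loop of nn.Linear-style 2d matmuls: configs keep their usual 2d semantics
out_loop = torch.stack([x[e] @ w[e].t() for e in range(E)])

# 2) bmm directly on the 3d weights: this is the path that needs bmm support on the subclass
out_bmm = torch.bmm(x, w.transpose(-2, -1))

assert torch.allclose(out_loop, out_bmm, rtol=1e-4, atol=1e-4)
```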
Below is a breakdown of which configs map to AQT / non-AQT subclasses:

| not using AQT | AffineQuantizedTensor |
|-----------|---------------|
| Float8DynamicActivationFloat8WeightConfig | FPXWeightOnlyConfig |
| Float8DynamicActivationInt4WeightConfig | Float8WeightOnlyConfig |
| Float8StaticActivationFloat8WeightConfig | Float8DynamicActivationFloat8SemiSparseWeightConfig |
| Int4WeightOnlyConfig (v2) | GemliteUIntXWeightOnlyConfig |
| | Int4DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt8WeightConfig |
| | Int8WeightOnlyConfig |
| | IntxWeightOnlyConfig |
| | UIntXWeightOnlyConfig |

For these, the majority of the semantics remain the same; the only semantic that really changes is `PerRow` granularity, and there's a very natural extension of `PerRow` to the 3d case (apply it along the last dimension). I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.

#### Float8DynamicActivationFloat8WeightConfig

```
[('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity', typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('activation_value_lb', typing.Optional[float]),
 ('activation_value_ub', typing.Optional[float]),
 ('kernel_preference', <enum 'KernelPreference'>),
 ('set_inductor_config', <class 'bool'>),
 ('version', <class 'int'>)]
```

`activation_dtype`, `weight_dtype`, `activation_value_lb`, `activation_value_ub` all do not change meaning semantically.

`granularity=PerTensor()` does not change semantic meaning - we still use a single tensor to scale the entire weight tensor.

`granularity=PerRow()` does change meaning - we now calculate a scale for each row along the last dimension [-1], i.e. for a weight of (E, N, K) we would expect PerRow to create scales with block size (1, 1, K).

`mm_config`, `kernel_preference`, and `set_inductor_config` stay the same as well.

#### Float8StaticActivationFloat8WeightConfig

```
[('scale', <class 'torch.Tensor'>),
 ('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity', typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('set_inductor_config', <class 'bool'>)]
```

`scale` should be passed in as a 3d tensor instead of a 2d tensor in the case of `PerRow` granularity.

#### Float8DynamicActivationInt4WeightConfig

```
[('int4_packing_format', <enum 'Int4PackingFormat'>)]
```

`int4_packing_format` - only "preshuffled" is supported, and Int4PreshuffledTensor [supports](https://github.com/pytorch/ao/blob/895573980e085b02a2c6abbc82239bae7f1318d6/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py#L154) 3d weights.
#### Int4WeightOnlyConfig

```
[('group_size', <class 'int'>),
 ('layout', typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
 ('use_hqq', <class 'bool'>),
 ('zero_point_domain', typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
 ('set_inductor_config', <class 'bool'>),
 ('preserve_zero', typing.Optional[bool]),
 ('int4_packing_format', <enum 'Int4PackingFormat'>),
 ('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
 ('version', <class 'int'>)]
```

`group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm`, and `set_inductor_config` are the only things that are set for the v2 config. I don't think the semantics of these change, although there are some packing formats that do not support 3d weights - it looks like `Int4PackingFormat.PLAIN_INT32` and `Int4PackingFormat.MARLIN_SPARSE` do not.
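As a rough illustration of the `PerRow` extension discussed above, a sketch of what dynamic float8 rowwise scaling could look like for a 3d expert weight (illustrative only; the actual scale computation lives in torchao's Float8 workflow and may differ in details like clamping and dtype handling):

```python
import torch

E, N, K = 8, 256, 128
weight = torch.randn(E, N, K, dtype=torch.bfloat16)

# PerRow on a 3d weight: one scale per row along the last dimension,
# i.e. block size (1, 1, K) -> scales of shape (E, N, 1).
f8_max = torch.finfo(torch.float8_e4m3fn).max
scale = weight.abs().amax(dim=-1, keepdim=True).float() / f8_max
scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero rows
weight_f8 = (weight.float() / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)

print(scale.shape)      # torch.Size([8, 256, 1])
print(weight_f8.shape)  # torch.Size([8, 256, 128])
```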