[CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight #3075
Merged: jerryzh168 merged 12 commits into pytorch:main from Xia-Weiwen:float8_opaque_tensor_new on Nov 14, 2025.
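In short: this PR teaches Float8DynamicActivationFloat8WeightConfig to target CPU through a new "opaque" float8 packing format backed by Float8OpaqueTensor and the torchao::float8_linear_cpu op. A minimal usage sketch based on the new test below (the toy model and shapes are illustrative, not part of the PR):

import torch
from torchao import quantize_
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# Toy CPU model; 256x256 satisfies the cpp kernel's shape constraints per the tests.
m = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()

config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=PerRow(),            # applied to both activation and weight
    float8_packing_format="opaque",  # selects the Float8OpaqueTensor CPU path
)
with torch.no_grad():
    quantize_(m, config)
    y = m(torch.rand(4, 256) * 0.1)  # dispatches to torch.ops.torchao.float8_linear_cpu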
Changes are shown from the first 8 of the 12 commits (all by Xia-Weiwen):
d460134  [CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight
cf8dc09  Update _normalize_granularity
4333727  Update torchao/quantization/quant_api.py
6e1c2a2  Fix CI
7980de8  Merge branch 'main' into float8_opaque_tensor_new
ecf5e1a  Merge branch 'main' into float8_opaque_tensor_new
1044dca  remove unnecessary changes
8044d4a  Merge branch 'main' into float8_opaque_tensor_new
b1e715f  Merge branch 'main' into float8_opaque_tensor_new
52906df  Refine code
c7524ea  Refine code
c26d34b  Refine code
test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py (new file, 164 additions)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(granularity):
    return Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=granularity,
        float8_packing_format="opaque",
    )


class ToyLinearModel(torch.nn.Module):
    def __init__(self, K=64, N=32, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (
            torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
            * 0.1,
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestFloat8OpaqueTensor(TestCase):
    """Test cases for Float8OpaqueTensor on CPU"""

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize(
        "x_granularity",
        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
    )
    @common_utils.parametrize(
        "w_granularity",
        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
    )
    def test_dynamic_float8_linear(
        self, dtype, x_dim, bias, bs, x_granularity, w_granularity
    ):
        if isinstance(x_granularity, PerGroup):
            if not isinstance(w_granularity, PerGroup):
                return
            if w_granularity.group_size != x_granularity.group_size:
                return
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config([x_granularity, w_granularity]),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [4, 128])
    def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
        device = "cpu"
        # the shape is not supported by the cpp kernel, so the ref path will be used.
        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config(PerRow()),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype)
        quantize_(linear, get_config(PerRow()))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8OpaqueTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8OpaqueTensor'>",
            )


common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)


if __name__ == "__main__":
    run_tests()
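A note on the assertions: compute_error is torchao's SQNR helper, so "> 20" requires the quantized output to stay within roughly 20 dB of the float reference, i.e. the error norm is at most about a tenth of the signal norm. In essence it computes the following (a sketch, not the exact source):

import torch

def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer to the reference.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - actual))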
torchao/quantization/quant_api.py
@@ -73,6 +73,8 @@
     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
+    Float8OpaqueTensor,
+    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -1716,6 +1718,22 @@ def _input_activation_quant_func_fp8(
     return activation


+def _input_activation_quant_cpu_fp8(
+    x: torch.Tensor,
+    activation_granularity: FP8Granularity,
+    activation_dtype: torch.dtype,
+):
+    """Dynamically quantize activation to fp8 for CPU."""
+    block_size = get_block_size(x.shape, activation_granularity)
+    return to_affine_quantized_floatx(
+        input_float=x,
+        block_size=block_size,
+        target_dtype=activation_dtype,
+        scale_dtype=torch.float32,
+        _layout=PlainLayout(),
+    )
+
+
 def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     """
     Check if a weight tensor meets float8 quantization requirements.
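Here get_block_size translates a granularity into per-dimension quantization block sizes for a given shape. A sketch of the mapping as I read torchao's granularity semantics (block_size_for is a hypothetical stand-in, not the library function):

from torchao.quantization import PerGroup, PerRow, PerTensor

def block_size_for(shape, granularity):
    # Mirrors my reading of get_block_size for the granularities used in this PR.
    if isinstance(granularity, PerTensor):
        return tuple(shape)  # one scale for the whole tensor
    if isinstance(granularity, PerRow):
        return (1,) * (len(shape) - 1) + (shape[-1],)  # one scale per row/token
    if isinstance(granularity, PerGroup):
        return (1,) * (len(shape) - 1) + (granularity.group_size,)  # groups along last dim
    raise ValueError(f"unsupported granularity: {granularity}")

assert block_size_for((4, 256), PerRow()) == (1, 256)
assert block_size_for((4, 256), PerGroup(64)) == (1, 64)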
@@ -1774,14 +1792,23 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
     version: int = 2
+    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN

     def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        activation_granularity, weight_granularity = _normalize_granularity(
-            self.granularity
-        )
+        if (
+            self.version == 2
+            and self.float8_packing_format == Float8PackingFormat.OPAQUE
+        ):
+            activation_granularity, weight_granularity = (
+                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
+            )
+        else:
+            activation_granularity, weight_granularity = _normalize_granularity(
+                self.granularity
+            )
         self.granularity = [activation_granularity, weight_granularity]

         default_use_fast_accum = True
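With the opaque format, granularity may still be a single value (used for both activation and weight) or an [activation, weight] pair, but validation now goes through Float8OpaqueTensor._normalize_and_check_granularity, which runs at config construction time via __post_init__. The test parametrization above only exercises per-group pairs with matching group sizes, so treat the rejected case below as an assumption rather than confirmed behavior:

import torch
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

def opaque_config(granularity):
    # Same shape of helper as get_config in the new test file.
    return Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=granularity,
        float8_packing_format="opaque",
    )

opaque_config(PerTensor())                     # one granularity for both act and weight
opaque_config([PerRow(), PerGroup(128)])       # per-row act with per-group weight (tested)
opaque_config([PerGroup(64), PerGroup(64)])    # per-group on both sides, equal group sizes
# opaque_config([PerGroup(32), PerGroup(64)])  # mismatched group sizes, presumably rejected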
@@ -1811,10 +1838,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
+    float8_packing_format = config.float8_packing_format

-    # Ensure works on device
-    _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity
+    is_cpu = weight.device.type == "cpu"

     # Note: right now we assume it's weights of conv2d and conv3d purely based
     # on the dimension of weight, currently there is no conflict with linear 2d
@@ -1834,12 +1862,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
             return weight

-    elif not _fp8_mm_compat(weight):
+    elif not is_cpu and not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight

-    if isinstance(weight_granularity, PerRow):
+    if not is_cpu and isinstance(weight_granularity, PerRow):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
         )
@@ -1849,6 +1877,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )

+        _check_hardware_support(granularity)
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
             block_size = tuple([1] + list(block_size))
@@ -1879,14 +1908,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             kernel_preference=kernel_preference,
         )

-    quantized_weight = Float8Tensor.from_hp(
-        weight,
-        float8_dtype=weight_dtype,
-        granularity=weight_granularity,
-        mm_config=mm_config,
-        kernel_preference=kernel_preference,
-        act_quant_kwargs=act_quant_kwargs,
-    )
+    if float8_packing_format == Float8PackingFormat.PLAIN:
+        quantized_weight = Float8Tensor.from_hp(
+            weight,
+            float8_dtype=weight_dtype,
+            granularity=weight_granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+    elif float8_packing_format == Float8PackingFormat.OPAQUE:
+        block_size = get_block_size(weight.shape, weight_granularity)
+        quantized_weight = Float8OpaqueTensor.from_hp(
+            weight,
+            block_size=block_size,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+    else:
+        raise ValueError(
+            f"Unsupported float8 packing format: {float8_packing_format}"
+        )

     return quantized_weight
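The packing format thus decides which tensor subclass backs the weight: PLAIN keeps the existing Float8Tensor, while OPAQUE produces a Float8OpaqueTensor with a CPU-specific packed layout. A quick check mirroring test_module_path above, including the state_dict round trip:

import tempfile

import torch
from torchao import quantize_
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

linear = torch.nn.Linear(128, 256)
quantize_(
    linear,
    Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=PerRow(),
        float8_packing_format="opaque",
    ),
)
print(type(linear.weight))  # <class 'torchao.quantization.Float8OpaqueTensor'>

# The subclass survives state_dict serialization as well:
with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    print(type(torch.load(f)["weight"]))  # Float8OpaqueTensor again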
@@ -1898,12 +1939,6 @@ def _float8_dynamic_activation_float8_weight_transform(
     *,
     parameter_name: str = "weight",
 ):
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    )
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

     assert hasattr(module, parameter_name), (
         f"applying float8 dynamic activation quant requires module to have parameter {parameter_name} attribute"
         + f" but {module} does not have one"