-
Notifications
You must be signed in to change notification settings - Fork 400
Add NPU (Ascend) backend support for INT4 weight-only quantization workflow #3172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+292
−78
Merged
Changes from 4 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
f3aefca
Add NPU (Ascend) backend support for INT4 weight-only quantization wo…
orangeH25 68eea61
use torch.ops.npu prefix and drop redundant torch_npu import
orangeH25 164435e
Merge branch 'pytorch:main' into quant/int4/wo/0
orangeH25 06c77d1
Modify test file and update comments
orangeH25 498f052
Merge branch 'pytorch:main' into quant/int4/wo/0
orangeH25 ea2aa7a
add: merge NPU(Ascend) backend logic in Int4PlainInt32Tensor subclass
orangeH25 ca8f056
ruff format cleanup, replace error types, add torch version check
orangeH25 05af947
add torch_npu version assertion and show downstream testing result
orangeH25 25360da
add downstream testing result
orangeH25 fa3220f
unify NPU and XPU test cases into a single class
orangeH25 623c589
move CI display to quantization README and update test file
orangeH25 89ad729
Merge branch 'pytorch:main' into quant/int4/wo/0
orangeH25 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
107 changes: 107 additions & 0 deletions
107
test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
| import tempfile | ||
| from packaging import version | ||
|
|
||
| import torch | ||
| from torch.testing._internal.common_utils import ( | ||
| TestCase, | ||
| instantiate_parametrized_tests, | ||
| parametrize, | ||
| run_tests, | ||
| ) | ||
|
|
||
| from torchao.quantization import ( | ||
| Int4WeightOnlyConfig, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.quantize_.common import SupportsActivationPreScaling | ||
| from torchao.quantization.utils import compute_error | ||
| from torchao.utils import ( | ||
| torch_version_at_least, | ||
| ) | ||
|
|
||
|
|
||
| def get_config(group_size): | ||
| return Int4WeightOnlyConfig( | ||
| group_size=group_size, | ||
| int4_packing_format="plain_int32", | ||
| ) | ||
|
|
||
|
|
||
| @unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") | ||
| @unittest.skipIf( | ||
| torch.accelerator.current_accelerator(True).type == "npu" | ||
| and torch.accelerator.is_available(), | ||
| "NPU not available", | ||
| ) | ||
| class Int4PlainInt32TensorNPU(TestCase): | ||
|
|
||
| @parametrize("device", ["npu"]) | ||
| @parametrize( | ||
| "sizes", | ||
| [ | ||
| ((128,), 256, 128), | ||
| ((32, 128), 512, 128), | ||
| ((2, 32, 128), 256, 128), | ||
| ], | ||
| ) | ||
| @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| @parametrize("group_size", [32, 64]) | ||
| def test_linear(self, device, sizes, dtype, group_size): | ||
| M, N, K = sizes | ||
| input = torch.randn(*M, K, dtype=dtype, device=device) | ||
| linear = torch.nn.Linear(K, N, dtype=dtype, device=device) | ||
| orig_output = linear(input) | ||
| quantize_(linear, get_config(group_size)) | ||
| quantized_output = linear(input) | ||
| self.assertTrue(compute_error(orig_output, quantized_output) > 10) | ||
|
|
||
| @parametrize("device", ["npu"]) | ||
| @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| def test_module_path(self, device, dtype): | ||
| linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) | ||
| quantize_(linear, get_config(group_size=64)) | ||
| self.assertEqual( | ||
| str(type(linear.weight)), | ||
| "<class 'torchao.quantization.Int4PlainInt32TensorNPU'>", | ||
| ) | ||
|
|
||
| with tempfile.NamedTemporaryFile() as f: | ||
| torch.save(linear.state_dict(), f) | ||
| f.seek(0) | ||
| state_dict = torch.load(f) | ||
| self.assertEqual( | ||
| str(type(state_dict["weight"])), | ||
| "<class 'torchao.quantization.Int4PlainInt32TensorNPU'>", | ||
| ) | ||
|
|
||
| @parametrize("device", ["npu"]) | ||
| @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| def test_activation_prescaling(self, device, dtype): | ||
| input = torch.randn(1, 128, dtype=dtype, device=device) | ||
| linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) | ||
| original = linear(input) | ||
| quantize_(linear, get_config(64)) | ||
| qw = linear.weight | ||
| assert isinstance( | ||
| qw, SupportsActivationPreScaling | ||
| ), "Expected int4 tensor supports activation prescaling" | ||
| assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" | ||
| _ACT_PRE_SCALE = 2 | ||
| qw.act_pre_scale = _ACT_PRE_SCALE | ||
| quantized = linear(input) | ||
|
|
||
| # making sure activation pre scaling is successfully applied to the activation | ||
| self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10) | ||
|
|
||
|
|
||
| instantiate_parametrized_tests(Int4PlainInt32TensorNPU) | ||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
232 changes: 232 additions & 0 deletions
232
torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
|
|
||
| from torchao.quantization.quant_primitives import ( | ||
| MappingType, | ||
| choose_qparams_affine, | ||
| quantize_affine, | ||
| ) | ||
| from torchao.utils import ( | ||
| TorchAOBaseTensor, | ||
| ) | ||
|
|
||
| __all__ = ["Int4PlainInt32TensorNPU"] | ||
|
|
||
| aten = torch.ops.aten | ||
|
|
||
|
|
||
| class Int4PlainInt32TensorNPU(TorchAOBaseTensor): | ||
| """ | ||
| int4 weight-only quantization on Ascend NPU backend (groupwise quantization only) | ||
|
|
||
| Tensor Attributes: | ||
| qdata: (N, K/8), packed int4 weight, the data type is int32 here with 8*int4, the original dtype can be float16 or bfloat16 | ||
|
||
| scale: (K/group_size, N), dtype is the same as the original Tensor type (float16 or bfloat16) | ||
| zero_point: (K/group_size, N), dtype is the same as the original Tensor type (float16 or bfloat16) | ||
|
|
||
| Non-Tensor Attributes: | ||
| block_size: the block size for quantization, representing the granularity | ||
| shape: shape of the original Tensor | ||
|
|
||
| Optional Tensor Data Attributes: | ||
| act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, | ||
| we'll multiply activation Tensor with act_pre_scale before applying dynamic | ||
| quantization to activation or running quantized mm op | ||
|
|
||
| """ | ||
|
|
||
| tensor_data_names = ["qdata", "scale", "zero_point"] | ||
| tensor_attribute_names = ["block_size", "shape"] | ||
| optional_tensor_data_names = ["act_pre_scale"] | ||
|
|
||
| def __new__( | ||
| cls, | ||
| qdata, | ||
| scale, | ||
| zero_point, | ||
| block_size, | ||
| shape, | ||
| act_pre_scale: Optional[torch.Tensor] = None, | ||
| ): | ||
| kwargs = {} | ||
| kwargs["device"] = qdata.device | ||
| kwargs["dtype"] = scale.dtype | ||
| kwargs["requires_grad"] = False | ||
| return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] | ||
|
|
||
| def __init__( | ||
| self, | ||
| qdata, | ||
| scale, | ||
| zero_point, | ||
| block_size, | ||
| shape, | ||
| act_pre_scale: Optional[torch.Tensor] = None, | ||
| ): | ||
| self.qdata = qdata | ||
| self.scale = scale | ||
| self.zero_point = zero_point | ||
| self.block_size = block_size | ||
| self.act_pre_scale = act_pre_scale | ||
|
|
||
| def _quantization_type(self): | ||
| s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" | ||
| if self.act_pre_scale is not None: | ||
| s += f", act_pre_scale.shape={self.act_pre_scale.shape}" | ||
| return s | ||
|
|
||
| @classmethod | ||
| def from_hp( | ||
| cls, | ||
| w: torch.Tensor, | ||
| block_size: List[int], | ||
| ): | ||
| assert w.ndim == 2 and w.device.type == "npu", ( | ||
| f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" | ||
| ) | ||
| assert len(block_size) == w.ndim | ||
| assert w.dtype in [torch.float16, torch.bfloat16], ( | ||
| f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" | ||
| ) | ||
|
|
||
| original_shape = w.shape | ||
| mapping_type = MappingType.ASYMMETRIC | ||
| target_dtype = torch.int32 | ||
| quant_min = -8 | ||
| quant_max = 7 | ||
| eps = 1e-6 | ||
| scale_dtype = w.dtype | ||
| zero_point_dtype = w.dtype | ||
|
|
||
| scale, zero_point = choose_qparams_affine( | ||
| w, | ||
| mapping_type, | ||
| block_size, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| eps, | ||
| scale_dtype, | ||
| zero_point_dtype, | ||
| ) | ||
|
|
||
| int_data = quantize_affine( | ||
| w, | ||
| block_size, | ||
| scale, | ||
| zero_point, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| ) | ||
|
|
||
| assert int_data.dtype == torch.int32, ( | ||
| f"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype" | ||
| ) | ||
|
|
||
| assert int_data.shape[-1] % 8 == 0, ( | ||
| f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" | ||
| ) | ||
|
|
||
| packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( | ||
| int_data.contiguous(), 0 | ||
| ) | ||
|
|
||
| scale = scale.reshape(int_data.shape[0], -1) | ||
| zero_point = zero_point.reshape(int_data.shape[0], -1) | ||
|
|
||
| return Int4PlainInt32TensorNPU( | ||
| packed_weight, | ||
| scale.transpose(0, 1).contiguous(), | ||
| zero_point.transpose(0, 1).contiguous(), | ||
| block_size, | ||
| original_shape, | ||
| act_pre_scale=None, | ||
| ) | ||
|
|
||
|
|
||
| implements = Int4PlainInt32TensorNPU.implements | ||
| implements_torch_function = Int4PlainInt32TensorNPU.implements_torch_function | ||
|
|
||
|
|
||
| @implements(aten.linear.default) | ||
| @implements_torch_function(torch.nn.functional.linear) | ||
| def _(func, types, args, kwargs): | ||
|
|
||
| input_tensor, weight_tensor, bias = ( | ||
| args[0], | ||
| args[1], | ||
| args[2] if len(args) > 2 else None, | ||
| ) | ||
|
|
||
| assert input_tensor.device.type == "npu", ( | ||
| f"For NPU device only but got: {input_tensor.device.type}" | ||
| ) | ||
| assert isinstance(weight_tensor, Int4PlainInt32TensorNPU), ( | ||
| f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}" | ||
| ) | ||
| assert weight_tensor.block_size[0] == 1, ( | ||
| f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" | ||
| ) | ||
| assert input_tensor.shape[-1] == weight_tensor.shape[1], ( | ||
| f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" | ||
| ) | ||
|
|
||
| if weight_tensor.act_pre_scale is not None: | ||
| input_tensor = input_tensor * weight_tensor.act_pre_scale | ||
|
|
||
| act_mat = input_tensor | ||
| packed_weight = weight_tensor.qdata | ||
| scale = weight_tensor.scale | ||
| zero_point = weight_tensor.zero_point | ||
|
|
||
| orig_act_size = act_mat.shape | ||
| orig_dtype = act_mat.dtype | ||
|
|
||
| # dtype alignment | ||
| if act_mat.dtype == torch.float16: | ||
| scale = scale.to(torch.float16) | ||
| zero_point = zero_point.to(torch.float16) | ||
| if bias is not None: | ||
| bias = bias.to(torch.float16) | ||
| elif act_mat.dtype == torch.bfloat16: | ||
| scale = scale.to(torch.bfloat16) | ||
| zero_point = zero_point.to(torch.bfloat16) | ||
| if bias is not None: | ||
| bias = bias.to(torch.float32) | ||
|
|
||
| # reshape to 2D | ||
| act_mat = act_mat.reshape(-1, act_mat.shape[-1]) | ||
|
|
||
| # groupwise int4 quantization | ||
| groupsize = weight_tensor.block_size[1] | ||
|
|
||
| y = torch.ops.npu.npu_weight_quant_batchmatmul( | ||
| x=act_mat, | ||
| weight=packed_weight.contiguous().transpose(-1, -2), | ||
| antiquant_scale=scale, | ||
| antiquant_offset=zero_point, | ||
| antiquant_group_size=groupsize, | ||
| bias=bias, | ||
| ) | ||
|
|
||
| # remove out_feature padding | ||
| assert weight_tensor.ndim == 2 | ||
| orig_out_features = weight_tensor.shape[-2] | ||
| y = y[:, :orig_out_features] | ||
| y = y.reshape(*orig_act_size[:-1], orig_out_features) | ||
|
|
||
| return y.to(orig_dtype) | ||
|
|
||
|
|
||
| Int4PlainInt32TensorNPU.__module__ = "torchao.quantization" | ||
|
|
||
| # Allow a model with Int4PlainInt32TensorNPU weights to be loaded with `weights_only=True` | ||
| torch.serialization.add_safe_globals([Int4PlainInt32TensorNPU]) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, do we need NPUs to test this? I don't think we have any in CI.