Add NPU (Ascend) backend support for INT4 weight-only quantization workflow #3172
The change renames the existing XPU-only test class and adds a parallel NPU test class, shown here as a unified diff:

```diff
@@ -35,7 +35,7 @@ def get_config(group_size):
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
-class Int4PlainInt32Tensor(TestCase):
+class Int4PlainInt32TensorXPU(TestCase):
     @parametrize(
         "sizes",
         [
@@ -98,8 +98,74 @@ def test_activation_prescaling(self):
         self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
 
 
-instantiate_parametrized_tests(Int4PlainInt32Tensor)
+@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
+@unittest.skipIf(
+    torch.accelerator.current_accelerator().type != "npu"
+    or not torch.accelerator.is_available(),
+    "NPU not available",
+)
+class Int4PlainInt32TensorNPU(TestCase):
+    @parametrize("device", ["npu"])
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 128),
+        ],
+    )
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    @parametrize("group_size", [32, 64])
+    def test_linear(self, device, sizes, dtype, group_size):
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        orig_output = linear(input)
+        quantize_(linear, get_config(group_size))
+        quantized_output = linear(input)
+        self.assertTrue(compute_error(orig_output, quantized_output) > 10)
+
+    @parametrize("device", ["npu"])
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_module_path(self, device, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size=64))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+            )
+
+    @parametrize("device", ["npu"])
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_activation_prescaling(self, device, dtype):
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(64))
+        qw = linear.weight
+        assert isinstance(qw, SupportsActivationPreScaling), (
+            "Expected int4 tensor supports activation prescaling"
+        )
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        quantized = linear(input)
+
+        # making sure activation pre scaling is successfully applied to the activation
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)
+
+
+instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
+instantiate_parametrized_tests(Int4PlainInt32TensorNPU)
 
 
 if __name__ == "__main__":
     run_tests()
```
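For context, the tests above call a `get_config` helper defined near the top of this file (visible in the first hunk header) but not included in the diff. A minimal sketch of what it plausibly looks like, assuming torchao's `Int4WeightOnlyConfig` with an `int4_packing_format` field selecting the plain-int32 layout (both the field name and its value are assumptions inferred from the tensor type the tests check for):

```python
# Sketch only -- the real get_config lives near line 35 of the test file
# and is not part of this diff. Int4WeightOnlyConfig and its
# int4_packing_format="plain_int32" field are assumed from torchao's API.
from torchao.quantization import Int4WeightOnlyConfig


def get_config(group_size):
    # "plain_int32" would select the Int4PlainInt32Tensor subclass exercised
    # by these tests; group_size controls the quantization granularity.
    return Int4WeightOnlyConfig(
        group_size=group_size,
        int4_packing_format="plain_int32",
    )
```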
Reviewer: Can we add this and mention the NPU support requirements (torch_npu >= 2.7.1) in the quantization README instead? I would put it here: https://github.com/pytorch/ao/tree/main/torchao/quantization#a16w4-weightonly-quantization
Author: Sure, I'll add a subheading under the `a16w4-weightonly-quantization` section for this.

Author: Moved.
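For reference, the end-to-end NPU workflow such a README subheading might document could look roughly like the sketch below. This is hedged, not the merged documentation: it assumes torch_npu >= 2.7.1 is installed (per the review thread) and reuses the same assumed `Int4WeightOnlyConfig` fields as the sketch above; the toy model is a placeholder.

```python
# Hedged sketch of A16W4 (INT4 weight-only) quantization on an Ascend NPU.
# Assumes torch_npu >= 2.7.1; importing torch_npu registers the "npu" device.
import torch
import torch_npu  # noqa: F401
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Placeholder model; any module containing nn.Linear layers works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.Linear(4096, 1024),
).to(device="npu", dtype=torch.bfloat16)

# Quantize all linear weights to int4 with the (assumed) plain-int32 packing.
quantize_(
    model,
    Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32"),
)

x = torch.randn(8, 1024, device="npu", dtype=torch.bfloat16)
y = model(x)  # forward pass runs with int4-quantized weights on the NPU
```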