Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works

Check out our [docs](https://docs.pytorch.org/ao/main/) for more details!

## Third-party Pipeline Status
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this and mention NPU support requirements (torch_npu >=2.7.1) in the quantization README instead? I would put here: https://github.com/pytorch/ao/tree/main/torchao/quantization#a16w4-weightonly-quantization

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I’ll add a subheading under the a16w4-weightonly-quantization section for this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved.


| Backend | Inference |
| ----------- | -------------------------------------------------------------------------------------------------------------------- |
| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml) |

## 🚀 Quick Start

First, install TorchAO. We recommend installing the latest stable version:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_config(group_size):

@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
class Int4PlainInt32TensorXPU(TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: revert change

@parametrize(
"sizes",
[
Expand Down Expand Up @@ -98,8 +98,74 @@ def test_activation_prescaling(self):
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)


instantiate_parametrized_tests(Int4PlainInt32Tensor)
@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(
torch.accelerator.current_accelerator().type != "npu"
or not torch.accelerator.is_available(),
"NPU not available",
)
class Int4PlainInt32TensorNPU(TestCase):
Copy link
Contributor

@jerryzh168 jerryzh168 Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be merged with the xpu case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I’ll combine them into a single test class.

@parametrize("device", ["npu"])
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 128),
],
)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("group_size", [32, 64])
def test_linear(self, device, sizes, dtype, group_size):
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
orig_output = linear(input)
quantize_(linear, get_config(group_size))
quantized_output = linear(input)
self.assertTrue(compute_error(orig_output, quantized_output) > 10)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_module_path(self, device, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
quantize_(linear, get_config(group_size=64))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_activation_prescaling(self, device, dtype):
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(64))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor supports activation prescaling"
)
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
_ACT_PRE_SCALE = 2
qw.act_pre_scale = _ACT_PRE_SCALE
quantized = linear(input)

# making sure activation pre scaling is successfully applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)


instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
instantiate_parametrized_tests(Int4PlainInt32TensorNPU)

if __name__ == "__main__":
run_tests()
Loading
Loading