README.md (6 additions, 0 deletions)
@@ -60,6 +60,12 @@ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works

Check out our [docs](https://docs.pytorch.org/ao/main/) for more details!

## Third-party Pipeline Status
**Contributor:** Can we add this and mention the NPU support requirement (torch_npu >= 2.7.1) in the quantization README instead? I would put it here: https://github.com/pytorch/ao/tree/main/torchao/quantization#a16w4-weightonly-quantization

**Contributor Author:** Sure, I'll add a subheading under the a16w4-weightonly-quantization section for this.

**Contributor Author:** Moved.
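
For reference, a minimal sketch of the kind of version guard such a README note implies, assuming the usual torch_npu packaging; the helper name and the `packaging` dependency are illustrative, not part of this PR:

```python
# Hypothetical helper: gate the NPU quantization path on torch_npu >= 2.7.1.
# Assumes torch_npu exposes __version__, as most PyPI packages do.
import importlib.util


def npu_int4_supported() -> bool:
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu
    from packaging.version import Version

    return Version(torch_npu.__version__) >= Version("2.7.1")
```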


| Backend | Inference |
| ----------- | -------------------------------------------------------------------------------------------------------------------- |
| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml) |

## 🚀 Quick Start

First, install TorchAO. We recommend installing the latest stable version:
@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import pytest
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
@@ -33,9 +33,19 @@ def get_config(group_size):
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
_MIN_VER = {
"xpu": "2.8.0",
"npu": "2.7.1",
}

def setUp(self):
min_req = type(self)._MIN_VER.get(self.device_type)
if not torch_version_at_least(min_req):
self.skipTest(
f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
)

@parametrize(
"sizes",
[
@@ -46,24 +56,36 @@ class Int4PlainInt32Tensor(TestCase):
)
@parametrize("dtype", [torch.bfloat16, torch.half])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, sizes, dtype, group_size):
device = "xpu"
@parametrize("thresholds", [{"xpu": 20, "npu": 10}])
def test_linear(self, device, sizes, dtype, group_size, thresholds):
M, N, K = sizes
if "npu" in device and group_size == K:
pytest.skip(
f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
)
threshold = thresholds.get(device.split(":")[0])

input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(group_size))
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)
self.assertTrue(compute_error(original, quantized) > threshold)

compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
if "xpu" in device:
compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)

@parametrize("dtype", [torch.bfloat16, torch.half])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
quantize_(linear, get_config(group_size=128))
def test_module_path(self, device, dtype):
device = self.device_type
K, N, group_size = 128, 256, 128
if "npu" in device:
group_size = 64

linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
quantize_(linear, get_config(group_size))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -78,13 +100,22 @@ def test_module_path(self, dtype):
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

def test_activation_prescaling(self):
dtype = torch.bfloat16
device = "xpu"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("thresholds", [{"xpu": 20, "npu": 10}])
def test_activation_prescaling(self, device, dtype, thresholds):
device = self.device_type
if "xpu" in device and dtype == torch.float16:
pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")

threshold = thresholds.get(device.split(":")[0])
**Contributor:** Does `device_type` have `:`? I thought it should only be things like xpu, npu, cuda, not cuda:0.

**Contributor Author:** Ah, good catch, you're right! I actually meant to use the function argument `device` there, but forgot to remove `device = self.device_type` from the setup. `device` can include a suffix like `:0`, while `device_type` should not. I'll fix that, thanks for pointing it out!

**Contributor Author:** Fixed.
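
A quick illustration of the distinction under discussion; this is pure `torch.device` parsing and needs no accelerator to run:

```python
import torch

# A device string may carry an ordinal suffix; torch.device(...).type never does.
dev = torch.device("cuda:0")
print(str(dev))   # "cuda:0"  -> what the harness passes as `device`
print(dev.type)   # "cuda"    -> what `device_type` should look like

# Equivalent string-level stripping, as the test above does:
device = "npu:0"
device_type = device.split(":")[0]  # "npu" -> safe key for _MIN_VER / thresholds
```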

K, N, group_size = 128, 256, 128
if "npu" in device:
group_size = 64

input = torch.randn(1, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(128))
quantize_(linear, get_config(group_size))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor supports activation prescaling"
@@ -95,10 +126,12 @@ def test_activation_prescaling(self):
quantized = linear(input)

# making sure activation pre scaling is successfully applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)


instantiate_parametrized_tests(Int4PlainInt32Tensor)
instantiate_device_type_tests(
Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
)


if __name__ == "__main__":
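For context on the harness this diff moves to, here is a standalone sketch of the `instantiate_device_type_tests` pattern; the class and test names are illustrative, and it assumes the target backend (e.g. torch_npu for `npu`) has registered itself with torch's device-type test framework:

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleTensorTest(TestCase):
    # The harness injects a concrete device string (e.g. "xpu:0") as `device`
    # and exposes the bare backend name (e.g. "xpu") as `self.device_type`.
    def test_roundtrip(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual(x.sum().item(), 4.0)


# Generates one class per available backend (e.g. ExampleTensorTestXPU);
# backends that are not present simply produce no test class.
instantiate_device_type_tests(
    ExampleTensorTest, globals(), only_for=("xpu", "npu"), allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```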