4 changes: 3 additions & 1 deletion docs/source/quantization_overview.rst
@@ -5,7 +5,7 @@ First we want to lay out the torchao stack::

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
---------------------------------------------------------------------------------------------
Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
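As a quick illustration of how these layers fit together, here is a minimal sketch of quantizing a linear layer, based on the Int8Tensor tests added in this PR (attribute names like `qdata` and `scale`, and the default per-row granularity, follow those tests)::

    import torch

    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    linear = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, Int8WeightOnlyConfig(version=2))

    w = linear.weight            # now an Int8Tensor (derived dtype: scaled int8)
    print(w.qdata.dtype)         # torch.int8 integer data
    print(w.scale.shape)         # torch.Size([512]), one scale per output row
    print(w.dequantize().dtype)  # torch.bfloat16, back to the original dtype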
@@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by derived dtype and packing format
- scaled int4
- preshuffled (special format to optimize for loading)
- float8 act + int4 weight dynamic quantization and int4 weight only quantization
* - Int8Tensor
- plain (no packing needed)

.. note::
We don't have granularity-specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor; all granularities are implemented in the same Tensor. We typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
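For example, here is a minimal sketch of how a `block_size` is derived from a granularity, mirroring the `get_block_size` assertions in the new test file below::

    from torchao.quantization.granularity import PerRow, PerTensor
    from torchao.quantization.utils import get_block_size

    shape = (256, 512)  # (N, K) weight shape, as in the tests

    # per-row: one scale per row, so each quantization block spans a single row
    print(list(get_block_size(shape, PerRow())))     # [1, 512]

    # per-tensor: one scale overall, so the block covers the whole tensor
    print(list(get_block_size(shape, PerTensor())))  # [256, 512]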
217 changes: 217 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal import common_utils

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        super().setUp()

        self.test_shape = (32, 20)
        self.dtype = torch.bfloat16
        self.batch_size = 32

        torch.manual_seed(42)

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_creation_and_attributes(self, config):
        """Test tensor creation, dtypes, and ranges"""
        linear = torch.nn.Linear(
            self.test_shape[1],
            self.test_shape[0],
            bias=False,
            dtype=self.dtype,
            device="cuda",
        )
        quantize_(linear, config)

        w = linear.weight

        self.assertEqual(w.shape, self.test_shape)
        self.assertEqual(w.qdata.dtype, torch.int8)
        self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128), # 2D
((32, 128), 64, 256), # 3D
],
)
def test_int8_linear_variants(
self,
dtype: torch.dtype,
config,
compile: bool,
sizes: tuple,
):
"""Test linear operation supports including shape and compile"""
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
model_q = copy.deepcopy(model)

quantize_(model_q, config)

self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
self.assertEqual(model_q.linear2.weight.scale.ndim, 1)

if compile:
model_q = torch.compile(model_q, fullgraph=True)

output_fp = model(input_tensor)
output_quantized = model_q(input_tensor)

assert compute_error(output_fp, output_quantized) > 20, (
f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
)

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    @common_utils.parametrize("device", ["cpu", "cuda"])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_slice(self, config, device, dtype):
        """Test tensor slicing with per-row quantization"""
        tensor_size = 256
        slice_sizes = (64, 128)

        dummy = torch.nn.Linear(
            tensor_size, tensor_size, bias=False, dtype=dtype, device=device
        )
        quantize_(dummy, config)

        weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0])
        weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1])

        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
        self.assertEqual(weight2.scale, dummy.weight.scale)
        with self.assertRaises(NotImplementedError):
            _ = dummy.weight[::2]

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig,
            Int8WeightOnlyConfig,
        ],
    )
    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_index_select(self, config, granularity):
        """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
        N, K = 256, 512
        x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
        linear.weight.data = x

        config = config(version=2, granularity=granularity)
        quantize_(linear, config)

        x_int8 = linear.weight
        x_int8_0 = x_int8[0]

        # Test dequantization consistency
        torch.testing.assert_close(
            x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
        )

        # Test block_size granularity
        if isinstance(granularity, PerRow):
            self.assertEqual(
                list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
            )
        elif isinstance(granularity, PerTensor):
            self.assertEqual(
                list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
            )

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_dequantization_accuracy(self, config):
        """Test dequantization accuracy separately"""
        linear = torch.nn.Linear(
            256, 512, bias=False, dtype=torch.bfloat16, device="cuda"
        )
        weight_fp = copy.deepcopy(linear.weight)
        quantize_(linear, config)

        tensor = linear.weight
        dequantized = tensor.dequantize()
        self.assertEqual(dequantized.shape, weight_fp.shape)
        assert compute_error(dequantized, weight_fp) > 20, (
            f"Dequantization error is too high, got an SQNR of {compute_error(dequantized, weight_fp)}"
        )

    def test_available_gpu_kernels(self):
        """Check which GPU kernels are used"""
        torch.compiler.reset()

        M, K, N = 128, 256, 512
        m = torch.nn.Sequential(
            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
        )

        config = Int8DynamicActivationInt8WeightConfig(version=2)
        quantize_(m, config)

        m = torch.compile(m)
        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

        out, code = run_and_get_code(m, x)

        # Check expected kernels are present
        FileCheck().check_count("triton_per_fused", 1).check_count(
            "extern_kernels._int_mm", 1
        ).check_count("triton_poi_fused", 1).run(code[0])


if __name__ == "__main__":
    common_utils.run_tests()