@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(granularity):
    return Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=granularity,
        float8_packing_format="opaque",
    )


class ToyLinearModel(torch.nn.Module):
    def __init__(self, K=64, N=32, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (
            torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
            * 0.1,
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestFloat8OpaqueTensor(TestCase):
    """Test cases for Float8OpaqueTensor on CPU"""

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize(
        "x_granularity",
        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
    )
    @common_utils.parametrize(
        "w_granularity",
        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
    )
    def test_dynamic_float8_linear(
        self, dtype, x_dim, bias, bs, x_granularity, w_granularity
    ):
        if isinstance(x_granularity, PerGroup):
            if not isinstance(w_granularity, PerGroup):
                return
            if w_granularity.group_size != x_granularity.group_size:
                return
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config([x_granularity, w_granularity]),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [4, 128])
    def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
        device = "cpu"
        # the shape is not supported by cpp kernel, so the ref path will be used.
        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config(PerRow()),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype)
        quantize_(linear, get_config(PerRow()))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8OpaqueTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8OpaqueTensor'>",
            )


common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)


if __name__ == "__main__":
    run_tests()
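
For reference, a minimal usage sketch of the path exercised by the tests above: applying Float8DynamicActivationFloat8WeightConfig with the "opaque" packing format to a plain Linear on CPU. This is an illustrative sketch, not part of the PR; it assumes the torchao CPU cpp kernels are built, and the layer sizes and batch size are arbitrary.

import torch
from torchao import quantize_
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# any Linear-bearing module works; 256x256 is just an example shape
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()

# dynamic float8 activation + float8 weight, per-row, packed for the CPU kernel
config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=PerRow(),
    float8_packing_format="opaque",
)
quantize_(model, config)

with torch.no_grad():
    out = model(torch.rand(8, 256) * 0.1)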
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -92,6 +92,7 @@
    quantize_affine,
)
from .quantize_.workflows import (
    Float8OpaqueTensor,
    Float8Tensor,
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
@@ -174,6 +175,7 @@
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"Int4OpaqueTensor",
"Float8OpaqueTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
75 changes: 55 additions & 20 deletions torchao/quantization/quant_api.py
@@ -73,6 +73,8 @@
    KernelPreference,
)
from torchao.quantization.quantize_.workflows import (
    Float8OpaqueTensor,
    Float8PackingFormat,
    Float8Tensor,
    Int4ChooseQParamsAlgorithm,
    Int4MarlinSparseTensor,
@@ -1716,6 +1718,22 @@ def _input_activation_quant_func_fp8(
    return activation


[Review comment — Contributor] is this function used?
[Reply — Collaborator (Author)] Thanks. Removed.

def _input_activation_quant_cpu_fp8(
    x: torch.Tensor,
    activation_granularity: FP8Granularity,
    activation_dtype: torch.dtype,
):
    """Dynamically quantize activation to fp8 for CPU."""
    block_size = get_block_size(x.shape, activation_granularity)
    return to_affine_quantized_floatx(
        input_float=x,
        block_size=block_size,
        target_dtype=activation_dtype,
        scale_dtype=torch.float32,
        _layout=PlainLayout(),
    )


def _fp8_mm_compat(weight: torch.Tensor) -> bool:
"""
Check if a weight tensor meets float8 quantization requirements.
@@ -1774,14 +1792,23 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
    kernel_preference: KernelPreference = KernelPreference.AUTO
    set_inductor_config: bool = True
    version: int = 2
    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
        )
        activation_granularity, weight_granularity = _normalize_granularity(
            self.granularity
        )
        if (
            self.version == 2
            and self.float8_packing_format == Float8PackingFormat.OPAQUE
        ):
            activation_granularity, weight_granularity = (
                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
            )
        else:
            activation_granularity, weight_granularity = _normalize_granularity(
                self.granularity
            )
        self.granularity = [activation_granularity, weight_granularity]

        default_use_fast_accum = True
@@ -1811,10 +1838,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
    activation_value_lb = config.activation_value_lb
    activation_value_ub = config.activation_value_ub
    kernel_preference = config.kernel_preference
    float8_packing_format = config.float8_packing_format

    # Ensure works on device
    _check_hardware_support(granularity)
    activation_granularity, weight_granularity = granularity
    is_cpu = weight.device.type == "cpu"

    # Note: right now we assume it's weights of conv2d and conv3d purely based
    # on the dimension of weight, currently there is no conflict with linear 2d
@@ -1834,12 +1862,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
            return weight

    elif not _fp8_mm_compat(weight):
    elif not is_cpu and not _fp8_mm_compat(weight):
        # TODO(future PR): this should really throw an exception instead of silently
        # not doing what the user asked
        return weight

    if isinstance(weight_granularity, PerRow):
    if not is_cpu and isinstance(weight_granularity, PerRow):
        assert weight.dtype == torch.bfloat16, (
            "PerRow quantization only works for bfloat16 precision input weight"
        )
@@ -1849,6 +1877,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
"Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
)

_check_hardware_support(granularity)
block_size = get_block_size(weight.shape[-2:], weight_granularity)
if weight.dim() == 3:
block_size = tuple([1] + list(block_size))
@@ -1879,14 +1908,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
        kernel_preference=kernel_preference,
    )

    quantized_weight = Float8Tensor.from_hp(
        weight,
        float8_dtype=weight_dtype,
        granularity=weight_granularity,
        mm_config=mm_config,
        kernel_preference=kernel_preference,
        act_quant_kwargs=act_quant_kwargs,
    )
    if float8_packing_format == Float8PackingFormat.PLAIN:
        quantized_weight = Float8Tensor.from_hp(
            weight,
            float8_dtype=weight_dtype,
            granularity=weight_granularity,
            mm_config=mm_config,
            kernel_preference=kernel_preference,
            act_quant_kwargs=act_quant_kwargs,
        )
    elif float8_packing_format == Float8PackingFormat.OPAQUE:
        block_size = get_block_size(weight.shape, weight_granularity)
        quantized_weight = Float8OpaqueTensor.from_hp(
            weight,
            block_size=block_size,
            act_quant_kwargs=act_quant_kwargs,
        )
    else:
        raise ValueError(
            f"Unsupported float8 packing format: {float8_packing_format}"
        )

    return quantized_weight

@@ -1898,12 +1939,6 @@ def _float8_dynamic_activation_float8_weight_transform(
    *,
    parameter_name: str = "weight",
):
    assert is_sm_at_least_89() or is_MI300(), (
        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
    )
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    assert hasattr(module, parameter_name), (
        f"applying float8 dynamic activation quant requires module to have parameter {parameter_name} attribute"
        + f" but {module} does not have one"
6 changes: 6 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,7 @@
from .float8.float8_opaque_tensor import (
    Float8OpaqueTensor,
)
from .float8.float8_packing_format import Float8PackingFormat
from .float8.float8_tensor import (
    Float8Tensor,
    QuantizeTensorToFloat8Kwargs,
@@ -37,7 +41,9 @@
"Int4MarlinSparseTensor",
"Int4PlainInt32Tensor",
"Int4TilePackedTo4dTensor",
"Float8OpaqueTensor",
"Float8Tensor",
"Float8PackingFormat",
"QuantizeTensorToFloat8Kwargs",
"Int4OpaqueTensor",
"Int4ChooseQParamsAlgorithm",