Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
MultiTensorInputRecorder,
)
from .granularity import (
Granularity,
PerAxis,
PerBlock,
PerGroup,
PerRow,
PerTensor,
Expand Down Expand Up @@ -197,8 +199,10 @@
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"Granularity",
"PerTensor",
"PerAxis",
"PerBlock",
"PerGroup",
"PerRow",
"PerToken",
Expand Down
15 changes: 15 additions & 0 deletions torchao/quantization/granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class PerGroup(Granularity):
group_size: int


@dataclass(frozen=True)
class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.
Expand All @@ -83,6 +84,7 @@ class PerRow(Granularity):
pass


@dataclass(frozen=True)
class PerToken(Granularity):
"""
Represents per-token granularity in quantization.
Expand All @@ -99,3 +101,16 @@ class PerToken(Granularity):
"""

pass


@dataclass(frozen=True)
class PerBlock(Granularity):
"""
Represents per-block granularity in quantization. See
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
`block_size`
Attributes:
block_size (Tuple[int, ...]): The size of each quantization group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: match tuple type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Updated.

"""

block_size: tuple[int, ...]
22 changes: 1 addition & 21 deletions torchao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from .granularity import (
Granularity,
PerAxis,
PerRow,
PerTensor,
)
Expand All @@ -24,6 +23,7 @@
_get_reduction_params,
choose_qparams_affine_with_min_max,
)
from .utils import get_block_size

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs):
return r


def get_block_size(
input_shape: Tuple[int, ...], granularity: Granularity
) -> Tuple[int, ...]:
"""Get the block size based on the input shape and granularity type.

Args:
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):
block_size = list(input_shape)
block_size[granularity.axis] = 1
return tuple(block_size)
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
raise ValueError(f"Unsupported Granularity: {granularity}")


ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:


Expand Down
10 changes: 2 additions & 8 deletions torchao/quantization/pt2e/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
update_equivalent_types_dict,
)

from .. import Granularity, PerAxis, PerBlock, PerGroup, PerRow, PerTensor, PerToken
from ..utils import get_block_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does feel a bit weird, does importing from full path works here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Updated

from .fake_quantize import (
FakeQuantize,
FakeQuantizeBase,
Expand All @@ -48,7 +50,6 @@
from .observer import (
AffineQuantizedObserverBase,
FixedQParamsObserver,
Granularity,
HistogramObserver,
MappingType,
MinMaxObserver,
Expand All @@ -57,20 +58,13 @@
NoopObserver,
ObserverBase,
PartialWrapper,
PerAxis,
PerBlock,
PerChannelMinMaxObserver,
PerGroup,
PerRow,
PerTensor,
PerToken,
PlaceholderObserver,
RecordingObserver,
ReuseInputObserver,
TorchAODType,
UniformQuantizationObserverBase,
ZeroPointDomain,
get_block_size,
)

for _f in [
Expand Down
4 changes: 2 additions & 2 deletions torchao/quantization/pt2e/_affine_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

import torch

from torchao.quantization import Granularity
from torchao.quantization.pt2e.observer import (
AffineQuantizedObserverBase,
Granularity,
MappingType,
TorchAODType,
ZeroPointDomain,
get_block_size,
)
from torchao.quantization.utils import get_block_size

ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:

Expand Down
144 changes: 10 additions & 134 deletions torchao/quantization/pt2e/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,23 @@
from torch.fx import Node

import torchao
from torchao.quantization import (
Granularity,
PerAxis,
PerBlock,
PerGroup,
PerRow,
PerTensor,
PerToken,
)
from torchao.quantization.pt2e.utils import (
calculate_qmin_qmax,
check_min_max_valid,
is_per_channel,
is_per_tensor,
validate_qmin_qmax,
)
from torchao.quantization.utils import get_block_size

__all__ = [
"default_affine_fixed_qparams_observer",
Expand Down Expand Up @@ -1622,7 +1632,6 @@ def calculate_qparams(self):
We plan to merge the following with torchao repo after we move pt2e flow to torchao
copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
"""
from dataclasses import dataclass
from enum import Enum, auto


Expand Down Expand Up @@ -1679,139 +1688,6 @@ class TorchAODType(Enum):
INT7 = auto()


@dataclass(frozen=True)
class Granularity:
"""
Base class for representing the granularity of quantization.

This class serves as a parent for specific granularity types used in
quantization operations, such as per-tensor or per-axis quantization.
"""


@dataclass(frozen=True)
class PerBlock(Granularity):
"""
Represents per-block granularity in quantization. See
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
`block_size`

Attributes:
block_size (Tuple[int, ...]): The size of each quantization group
"""

block_size: tuple[int, ...]


@dataclass(frozen=True)
class PerTensor(Granularity):
"""
Represents per-tensor granularity in quantization.

This granularity type calculates the quantization parameters
based off the entire tensor.

"""


@dataclass(frozen=True)
class PerAxis(Granularity):
"""
Represents per-axis granularity in quantization.

This granularity type calculates different quantization parameters
along a specified axis of the tensor.

For example if the input tensor is shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor.
Giving a total of 8 quantization parameters.

Attributes:
axis (int): The axis along which reduction is performed.
"""

axis: int


@dataclass(frozen=True)
class PerGroup(Granularity):
"""
Represents per-channel group granularity in quantization.

This granularity type calculates different quantization parameters
for each group of <group_size> elements.

For example if the input tensor is shape [8, 16], and the group size is 4, then
the input tensor is reshaped to [64, 4]
quantization parameters are calculated for each group of 4 elements,
giving a total of 64 quantization parameters.

Attributes:
group_size (int): The size of each quantization group

"""

group_size: int


class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.

This is a special case of per-axis quantization and is unique to Float8 matmuls
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
is quantized with a block_size of (1, weight.shape[1]).
"""


class PerToken(Granularity):
"""
Represents per-token granularity in quantization.

This granularity type calculates a different set of quantization parameters
for each token, which is represented as the last dimension of the tensor.

For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
with 4 elements each, and we will calculate 6 sets of quantization parameters,
one for each token.

If the input tensor has only two dimensions, e.g. [8, 16], then this is
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
"""


def get_block_size(
input_shape: tuple[int, ...], granularity: Granularity
) -> tuple[int, ...]:
"""Get the block size based on the input shape and granularity type.

Args:
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
assert isinstance(granularity, Granularity), (
"Please provide an instance of Granularity, not subclass of it"
)
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):
block_size = list(input_shape)
block_size[granularity.axis] = 1
return tuple(block_size)
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerGroup):
assert len(input_shape) == 2, (
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
return (1, granularity.group_size)
elif isinstance(granularity, PerToken):
block_size = [1] * len(input_shape)
block_size[-1] = input_shape[-1]
return tuple(block_size)
raise ValueError(f"Unsupported Granularity: {granularity}")


class AffineQuantizedObserverBase(ABC, torch.nn.Module):
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/qat/fake_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
PerRow,
PerToken,
)
from torchao.quantization.observer import get_block_size
from torchao.quantization.quant_primitives import (
_DTYPE_TO_BIT_WIDTH,
_DTYPE_TO_QVALUE_BOUNDS,
Expand All @@ -28,6 +27,7 @@
)
from torchao.quantization.utils import (
_get_per_token_block_size,
get_block_size,
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
)
Expand Down
3 changes: 2 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
)
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
from torchao.quantization.observer import AffineQuantizedObserverBase
from torchao.quantization.quantize_.common import (
KernelPreference,
)
Expand All @@ -87,6 +87,7 @@
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
)
from torchao.quantization.utils import get_block_size
from torchao.quantization.weight_tensor_linear_activation_quantization import (
to_weight_tensor_with_linear_activation_quantization_metadata,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
preprocess_scale,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.observer import get_block_size
from torchao.quantization.quant_primitives import (
_choose_scale_float8,
_dequantize_affine_float8,
Expand All @@ -34,6 +33,7 @@
QuantizeTensorKwargs,
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.utils import get_block_size
from torchao.utils import (
TorchAOBaseTensor,
_is_fbgemm_genai_gpu_available,
Expand Down
Loading
Loading