4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -43,8 +43,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
assert b_data.is_contiguous()
b_data = b_data.transpose(-1, -2)

- a_scale = a_mx._scale_e8m0.view(M, K // 32)
- b_scale = b_mx._scale_e8m0.view(N, K // 32)
+ a_scale = a_mx.scale.view(M, K // 32)
+ b_scale = b_mx.scale.view(N, K // 32)

a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)
30 changes: 15 additions & 15 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
assert data_mx.qdata.shape == (*prev_dims, K // 2)
else:
assert data_mx.qdata.shape == (*prev_dims, K)
- assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+ assert data_mx.scale.shape == (*prev_dims, K // block_size)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -146,7 +146,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
assert torch.isnan(data_mx.qdata[0])
assert torch.all(data_mx.qdata[1:] == 0)
# fp32 denorm
@@ -168,7 +168,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# bf16 denorm
# fmt: off
@@ -189,7 +189,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# fp32 some denorm
# fmt: off
@@ -220,7 +220,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# bf16 some denorm
# fmt: off
@@ -251,7 +251,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# zero
data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
@@ -262,7 +262,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# fp32 normal
# fmt: off
@@ -293,7 +293,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# bf16 normal
# fmt: off
@@ -324,7 +324,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
- torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+ torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)


@@ -340,8 +340,8 @@ def test_exponent_nan_in(elem_dtype):
)
block_size = 4
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
- assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
- assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))
+ assert torch.all(torch.isnan(tensor_mx.scale[0]))
+ assert not torch.any(torch.isnan(tensor_mx.scale[1:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -507,8 +507,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
x_mx_c = to_mx_c(x, elem_dtype, block_size)
torch.testing.assert_close(
- x_mx._scale_e8m0,
- x_mx_c._scale_e8m0,
+ x_mx.scale,
+ x_mx_c.scale,
atol=0,
rtol=0,
)
@@ -519,15 +519,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
pack_fp6 = False
x_mx_dq = to_dtype(
x_mx.qdata,
- x_mx._scale_e8m0,
+ x_mx.scale,
x_mx._elem_dtype,
x_mx._block_size,
hp_dtype, # noqa: E501
pack_fp6,
)
x_mx_c_dq = to_dtype_c(
x_mx_c.qdata,
- x_mx_c._scale_e8m0,
+ x_mx_c.scale,
x_mx_c._elem_dtype,
x_mx_c._block_size,
hp_dtype,
24 changes: 12 additions & 12 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -470,7 +470,7 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):


class MXTensor(TorchAOBaseTensor):
tensor_data_names = ["qdata", "_scale_e8m0"]
tensor_data_names = ["qdata", "scale"]
tensor_attribute_names = [
"_elem_dtype",
"_block_size",
@@ -548,10 +548,10 @@ def __new__(
# TODO investigate
assert target_numel == qdata.numel(), f"{target_numel} != {qdata.numel()}"

- # `_scale_e8m0` has rank 1 and applies to a row-major memory layout of
+ # `scale` has rank 1 and applies to a row-major memory layout of
# `qdata`
self.qdata = qdata
- self._scale_e8m0 = scale_e8m0_bits
+ self.scale = scale_e8m0_bits
self._elem_dtype = elem_dtype
self._block_size = block_size
self._orig_dtype = orig_dtype
@@ -562,15 +562,15 @@ def __new__(

def __repr__(self):
# TODO better elem dtype print for fp4
return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501
return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501

def _quantization_type(self):
return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"

def to_dtype(self, target_dtype):
return to_dtype(
self.qdata,
- self._scale_e8m0,
+ self.scale,
self._elem_dtype,
self._block_size,
target_dtype,
@@ -685,8 +685,8 @@ def _addmm_mx_dispatch(
assert a._block_size == 32, f"Invalid block size {a._block_size}"
assert b._block_size == 32, f"Invalid block size {b._block_size}"

- a_scale = a._scale_e8m0.view(M, K // a._block_size)
- b_scale = b._scale_e8m0.view(N, K // b._block_size)
+ a_scale = a.scale.view(M, K // a._block_size)
+ b_scale = b.scale.view(N, K // b._block_size)
a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)

@@ -757,7 +757,7 @@ def mx_t(func, types, args, kwargs):
old = args[0]
new = MXTensor(
old.qdata.t(),
- old._scale_e8m0,
+ old.scale,
old._elem_dtype,
old._block_size,
old._orig_dtype,
@@ -801,7 +801,7 @@ def mx_view_op(func, types, args, kwargs):
new_data = func(data, new_size, *args[2:], **kwargs)
return MXTensor(
new_data,
- args[0]._scale_e8m0,
+ args[0].scale,
args[0]._elem_dtype,
args[0]._block_size,
args[0]._orig_dtype,
@@ -821,7 +821,7 @@ def mx_slice(func, types, args, kwargs):
M, K = x.shape[0], x.shape[1]

# TODO why doesn't scale have shape?
- scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+ scale_shaped = x.scale.view(M, K // x._block_size)

if dim == 0:
# Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now
@@ -888,12 +888,12 @@ def mx_clone(func, types, args, kwargs):
def mx_select(func, types, args, kwargs):
old_mx_tensor, dim, index = args
assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported"
- assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), (
+ assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor.scale.shape), (
"unsupported"
)
new_mx_tensor = old_mx_tensor.__class__(
old_mx_tensor.qdata[index],
- old_mx_tensor._scale_e8m0[index],
+ old_mx_tensor.scale[index],
old_mx_tensor._elem_dtype,
old_mx_tensor._block_size,
old_mx_tensor._orig_dtype,
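
The net effect of this change is a mechanical rename of MXTensor's scale tensor attribute from the private `_scale_e8m0` to the public `scale`; numerics are unchanged. Below is a minimal sketch of what callers see after the rename. It is hypothetical usage, not code from this PR: the import path is inferred from the file layout above, and it assumes a CUDA build with fp8 support, as required by the tests above; shapes are illustrative.

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Quantize a high-precision tensor to an MX format with 32-element blocks,
# matching the to_mx(data_hp, elem_dtype, block_size) calls in the tests.
x_hp = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)

# Before this PR the E8M0 scales were reached via the private attribute:
#   scales = x_mx._scale_e8m0
# After this PR they are exposed as `scale`, alongside the quantized data:
scales = x_mx.scale   # one E8M0 scale per 32-element block of `qdata`
packed = x_mx.qdata   # quantized element data

# One scale per block, regardless of how the scale tensor is shaped.
assert scales.numel() == x_hp.numel() // 32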