test/prototype/mx_formats/test_mx_mm.py (4 changes: 2 additions & 2 deletions)
@@ -43,8 +43,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
assert b_data.is_contiguous()
b_data = b_data.transpose(-1, -2)

-a_scale = a_mx._scale_e8m0.view(M, K // 32)
-b_scale = b_mx._scale_e8m0.view(N, K // 32)
+a_scale = a_mx.scale.view(M, K // 32)
+b_scale = b_mx.scale.view(N, K // 32)

a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)
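For context on the pattern this test exercises: the renamed `.scale` attribute holds one e8m0 scale per 32-element block along K, so it views as `(rows, K // 32)` before being swizzled into the blocked layout the mx gemm expects. A minimal sketch, assuming the `to_blocked` import path below (it is not shown in this diff) and a CUDA build of torchao:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked  # assumed location

M, K, N = 128, 256, 64
a_hp = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
b_hp = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Quantize to mxfp8 with block_size=32 (positional args, as in the test above).
a_mx = MXTensor.to_mx(a_hp, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b_hp, torch.float8_e4m3fn, 32)

# One e8m0 scale per 32-element block along K, then swizzle for the scaled gemm.
a_scale_block = to_blocked(a_mx.scale.view(M, K // 32))
b_scale_block = to_blocked(b_mx.scale.view(N, K // 32))
```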
test/prototype/mx_formats/test_mx_tensor.py (30 changes: 15 additions & 15 deletions)
@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
assert data_mx.qdata.shape == (*prev_dims, K // 2)
else:
assert data_mx.qdata.shape == (*prev_dims, K)
-assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+assert data_mx.scale.shape == (*prev_dims, K // block_size)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -146,7 +146,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
assert torch.isnan(data_mx.qdata[0])
assert torch.all(data_mx.qdata[1:] == 0)
# fp32 denorm
@@ -168,7 +168,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# bf16 denorm
# fmt: off
@@ -189,7 +189,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# fp32 some denorm
# fmt: off
@@ -220,7 +220,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# bf16 some denorm
# fmt: off
@@ -251,7 +251,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# zero
data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
@@ -262,7 +262,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# fp32 normal
# fmt: off
@@ -293,7 +293,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
# bf16 normal
# fmt: off
@@ -324,7 +324,7 @@ def test_to_mx_rceil():
data_mx = MXTensor.to_mx(
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
)
-torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+torch.testing.assert_close(data_mx.scale, ground_truth_scale)
torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)


@@ -340,8 +340,8 @@ def test_exponent_nan_in(elem_dtype):
)
block_size = 4
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
-assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))
+assert torch.all(torch.isnan(tensor_mx.scale[0]))
+assert not torch.any(torch.isnan(tensor_mx.scale[1:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -507,8 +507,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
x_mx_c = to_mx_c(x, elem_dtype, block_size)
torch.testing.assert_close(
-x_mx._scale_e8m0,
-x_mx_c._scale_e8m0,
+x_mx.scale,
+x_mx_c.scale,
atol=0,
rtol=0,
)
@@ -519,15 +519,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
pack_fp6 = False
x_mx_dq = to_dtype(
x_mx.qdata,
-x_mx._scale_e8m0,
+x_mx.scale,
x_mx._elem_dtype,
x_mx._block_size,
hp_dtype, # noqa: E501
pack_fp6,
)
x_mx_c_dq = to_dtype_c(
x_mx_c.qdata,
-x_mx_c._scale_e8m0,
+x_mx_c.scale,
x_mx_c._elem_dtype,
x_mx_c._block_size,
hp_dtype,
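The ground-truth comparisons in `test_to_mx_rceil` above all revolve around the same idea: pick a power-of-two scale whose exponent is rounded up so the block maximum still fits in the target fp8 range. A rough illustrative sketch in plain torch (not the torchao implementation, and ignoring the NaN/denorm/zero special cases the test enumerates):

```python
import torch

F8E4M3_MAX = 448.0  # largest finite torch.float8_e4m3fn value

# One 32-element block of high-precision data.
block = torch.randn(32, dtype=torch.float32).abs() + 1e-3

# Round the scale exponent *up* (ceil) so amax / scale never overflows e4m3;
# this is the "rceil" behavior the ground-truth scales above encode.
exponent = torch.ceil(torch.log2(block.amax() / F8E4M3_MAX))
scale = torch.exp2(exponent)
q = (block / scale).to(torch.float8_e4m3fn)

assert float(block.amax() / scale) <= F8E4M3_MAX
```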
test/prototype/mx_formats/test_nvfp4_tensor.py (16 changes: 7 additions & 9 deletions)
@@ -285,7 +285,7 @@ def test_nvfp4_swizzled_scales_view_semantics():

# Test full-width column slicing (should maintain views)
full_width_slice = tensor[:, 0:K]
-assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr()
+assert full_width_slice.scale.data_ptr() == tensor.scale.data_ptr()
assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr()


@@ -394,9 +394,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
use_triton_kernel=True,
)

-torch.testing.assert_close(
-nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten()
-)
+torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
torch.testing.assert_close(
@@ -523,7 +521,7 @@ def test_nvfp4_to_copy():
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
assert torch.equal(x.qdata, y.qdata)
-assert torch.equal(x._scale_e4m3, y._scale_e4m3)
+assert torch.equal(x.scale, y.scale)
assert x._per_tensor_scale is None
assert y._per_tensor_scale is None
assert x._act_per_tensor_scale is None
@@ -586,20 +584,20 @@ def test_scale_shape_matches_qdata(
if is_swizzled_scales:
# in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
expected_padded_m = ceil_div(orig_m, 128) * 32
-actual_padded_m = x._scale_e4m3.shape[m_dim]
+actual_padded_m = x.scale.shape[m_dim]
assert expected_padded_m == actual_padded_m, (
f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
)

orig_k = x_hp.shape[k_dim]
expected_padded_k = orig_k // block_size
if is_swizzled_scales:
# in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
-actual_padded_k = x._scale_e4m3.shape[k_dim]
+actual_padded_k = x.scale.shape[k_dim]

assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
)


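The padded-shape arithmetic asserted in `test_scale_shape_matches_qdata` can be stated compactly: each 128x128 unpacked data tile maps to a 32x16 swizzled scale tile, so both scale dims are rounded up to those tile multiples. A small standalone sketch of the expected shape, assuming nvfp4's usual block size of 16:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def expected_swizzled_scale_shape(orig_m: int, orig_k: int, block_size: int = 16) -> tuple[int, int]:
    # 128 data rows -> 32 scale rows; every 4 scale columns (blocks) -> 16 padded columns.
    padded_m = ceil_div(orig_m, 128) * 32
    padded_k = ceil_div(orig_k // block_size, 4) * 16
    return padded_m, padded_k

# e.g. a 300x256 nvfp4 tensor carries a 96x64 swizzled scale:
assert expected_swizzled_scale_shape(300, 256) == (96, 64)
```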
torchao/prototype/mx_formats/mx_tensor.py (24 changes: 12 additions & 12 deletions)
@@ -470,7 +470,7 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):


class MXTensor(TorchAOBaseTensor):
-tensor_data_names = ["qdata", "_scale_e8m0"]
+tensor_data_names = ["qdata", "scale"]
tensor_attribute_names = [
"_elem_dtype",
"_block_size",
@@ -548,10 +548,10 @@ def __new__(
# TODO investigate
assert target_numel == qdata.numel(), f"{target_numel} != {qdata.numel()}"

-# `_scale_e8m0` has rank 1 and applies to a row-major memory layout of
+# `scale` has rank 1 and applies to a row-major memory layout of
# `qdata`
self.qdata = qdata
-self._scale_e8m0 = scale_e8m0_bits
+self.scale = scale_e8m0_bits
self._elem_dtype = elem_dtype
self._block_size = block_size
self._orig_dtype = orig_dtype
@@ -562,15 +562,15 @@ def __new__(

def __repr__(self):
# TODO better elem dtype print for fp4
return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501
return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501

def _quantization_type(self):
return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"

def to_dtype(self, target_dtype):
return to_dtype(
self.qdata,
-self._scale_e8m0,
+self.scale,
self._elem_dtype,
self._block_size,
target_dtype,
@@ -685,8 +685,8 @@ def _addmm_mx_dispatch(
assert a._block_size == 32, f"Invalid block size {a._block_size}"
assert b._block_size == 32, f"Invalid block size {b._block_size}"

-a_scale = a._scale_e8m0.view(M, K // a._block_size)
-b_scale = b._scale_e8m0.view(N, K // b._block_size)
+a_scale = a.scale.view(M, K // a._block_size)
+b_scale = b.scale.view(N, K // b._block_size)
a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)

@@ -757,7 +757,7 @@ def mx_t(func, types, args, kwargs):
old = args[0]
new = MXTensor(
old.qdata.t(),
-old._scale_e8m0,
+old.scale,
old._elem_dtype,
old._block_size,
old._orig_dtype,
@@ -801,7 +801,7 @@ def mx_view_op(func, types, args, kwargs):
new_data = func(data, new_size, *args[2:], **kwargs)
return MXTensor(
new_data,
-args[0]._scale_e8m0,
+args[0].scale,
args[0]._elem_dtype,
args[0]._block_size,
args[0]._orig_dtype,
@@ -821,7 +821,7 @@ def mx_slice(func, types, args, kwargs):
M, K = x.shape[0], x.shape[1]

# TODO why doesn't scale have shape?
-scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+scale_shaped = x.scale.view(M, K // x._block_size)

if dim == 0:
# Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now
@@ -888,12 +888,12 @@ def mx_select(func, types, args, kwargs):
def mx_select(func, types, args, kwargs):
old_mx_tensor, dim, index = args
assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported"
-assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), (
+assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor.scale.shape), (
"unsupported"
)
new_mx_tensor = old_mx_tensor.__class__(
old_mx_tensor.qdata[index],
-old_mx_tensor._scale_e8m0[index],
+old_mx_tensor.scale[index],
old_mx_tensor._elem_dtype,
old_mx_tensor._block_size,
old_mx_tensor._orig_dtype,
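For downstream code, the net effect of this diff is that the e8m0 scales of an `MXTensor` are now exposed as the public `.scale` attribute (registered in `tensor_data_names`) instead of `_scale_e8m0`; the same applies to `NVFP4Tensor`'s former `_scale_e4m3`. A minimal usage sketch, assuming the MXTensor import path shown in this PR and a CUDA device:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_hp = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)

scales = x_mx.scale                          # was: x_mx._scale_e8m0
assert scales.numel() == x_hp.numel() // 32  # one e8m0 scale per 32-element block
x_rt = x_mx.to_dtype(torch.bfloat16)         # dequantize using qdata + scale
```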