46 changes: 46 additions & 0 deletions benchmarks/mx_formats/cast_bench.py
@@ -17,6 +17,7 @@
triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

torch.manual_seed(0)

@@ -76,6 +77,18 @@ def to_mx_dim1_reference(
return data_d1.t(), scale_d1


def to_nvfp4_reference(x_hp, use_triton_kernel=False):
    nvfp4_tensor = NVFP4Tensor.to_nvfp4(x_hp, use_triton_kernel=use_triton_kernel)
    return nvfp4_tensor.qdata, nvfp4_tensor.scale


def to_nvfp4_reference_triton_swizzle(x_hp):
nvfp4_tensor = NVFP4Tensor.to_nvfp4(
x_hp, use_triton_kernel=True, is_swizzled_scales=True
)
return nvfp4_tensor.qdata, nvfp4_tensor.scale


def benchmark_cuda_function_in_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3

@@ -99,6 +112,8 @@ def run(
"dim0_mxfp4_floor",
"dim0_mxfp8_rceil",
"dim0_mxfp8_triton_floor",
"dim0_nvfp4",
"dim0_nvfp4_triton_swizzle",
"dim1_mxfp8_floor",
"dim1_mxfp8_rceil",
"dim1_mxfp8_triton_floor",
@@ -240,6 +255,37 @@ def run(
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_nvfp4":
to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
Contributor:
can you set fullgraph = True

Contributor Author:
can do in a separate PR

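For reference, a minimal sketch of what the suggested follow-up might look like (an assumption, not part of this PR): passing fullgraph=True makes torch.compile raise on graph breaks instead of silently falling back to eager.

# hypothetical follow-up, not in this diff: fail loudly on graph breaks
to_nvfp4_reference_c = torch.compile(to_nvfp4_reference, fullgraph=True)
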
y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)

for _ in range(2):
__ = to_nvfp4_reference_c(x, use_triton_kernel=False)
time_us = benchmark_cuda_function_in_microseconds(
lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
x,
)
assert y_d0.dtype == torch.uint8
assert s_d0.dtype == torch.float8_e4m3fn
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_nvfp4_triton_swizzle":
y_d0, s_d0 = to_nvfp4_reference_triton_swizzle(x)

for _ in range(2):
__ = to_nvfp4_reference_triton_swizzle(x)
time_us = benchmark_cuda_function_in_microseconds(
lambda x: to_nvfp4_reference_triton_swizzle(x),
x,
)
assert y_d0.dtype == torch.uint8
assert s_d0.dtype == torch.float8_e4m3fn
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mxfp8_floor":
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)