diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 671186c7c7..265d890847 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -881,7 +881,10 @@ def _get_mxfp8_quant_autotune_configs(): # can be improved in the future. results = [] for ROW_TILE_SIZE in (128, 256, 512): - for COL_TILE_SIZE in (128, 256, 512): + # TODO: we can't use 512 for COL_TILE_SIZE. + # This is likely a triton bug, tracked in + # https://github.com/pytorch/ao/issues/3362 + for COL_TILE_SIZE in (128, 256): for num_warps in (4, 8): for num_stages in (2, 3): config = triton.Config(