5 changes: 1 addition & 4 deletions test/prototype/moe_training/test_fsdp.py
@@ -46,8 +46,8 @@

 # this test requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
     from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -62,9 +62,6 @@ def device_mesh_1d() -> DeviceMesh:
     """
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    if not dist.is_initialized():
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
     device_mesh = init_device_mesh("cuda", (world_size,))
     torch.manual_seed(1)
     torch.cuda.set_device(rank)
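The same relocation repeats in the remaining test files: `set_token_group_alignment_size_m` moves from `torchtitan.distributed.expert_parallel` to `torchtitan.models.moe.utils`, and in test_tp.py `NoParallel` now comes from `torchtitan.distributed` (see the diffs below). A minimal sketch of the guarded import pattern the tests now share, assuming a recent torchtitan layout:

```python
# Sketch of the shared import pattern after this change: torchtitan's MoE
# helpers now live under torchtitan.models.moe. The skip mirrors the tests.
import pytest

try:
    from torchtitan.models.moe import MoE, MoEArgs
    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
except ImportError:
    pytest.skip(
        "torchtitan not installed, skipping MoE tests.", allow_module_level=True
    )
```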
2 changes: 1 addition & 1 deletion test/prototype/moe_training/test_fsdp_tp.py
@@ -65,9 +65,9 @@
         ExpertTensorParallel,
         NoParallel,
         TensorParallel,
-        set_token_group_alignment_size_m,
     )
     from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
7 changes: 2 additions & 5 deletions test/prototype/moe_training/test_tp.py
@@ -58,14 +58,14 @@

 # this test requires torchtitan
 try:
+    from torchtitan.distributed import NoParallel
     from torchtitan.distributed.expert_parallel import (
         ExpertParallel,
         ExpertTensorParallel,
-        NoParallel,
         TensorParallel,
-        set_token_group_alignment_size_m,
     )
     from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -80,9 +80,6 @@ def device_mesh_1d() -> DeviceMesh:
     """
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    if not dist.is_initialized():
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
     device_mesh = init_device_mesh("cuda", (world_size,))
     torch.manual_seed(1)
     torch.cuda.set_device(rank)
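The fixture change in test_fsdp.py and test_tp.py drops the explicit `dist.init_process_group` call. A sketch of the simplified fixture is shown below, under the assumption that `init_device_mesh` initializes the default process group itself from the torchrun-provided environment variables; the decorator, docstring text, and return statement are assumptions, since the diff only shows part of the body.

```python
# Sketch of the simplified fixture. Assumption: init_device_mesh sets up the
# default process group when none exists, so no explicit
# dist.init_process_group("nccl", ...) call is needed under torchrun.
import os

import pytest
import torch
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


@pytest.fixture
def device_mesh_1d() -> DeviceMesh:
    """1D device mesh spanning all ranks."""
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,))
    torch.manual_seed(1)
    torch.cuda.set_device(rank)
    return device_mesh
```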
4 changes: 2 additions & 2 deletions test/prototype/moe_training/test_training.py
@@ -22,10 +22,10 @@

 # this test requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import (
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
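For context on the helper these tests import: per the README changes below, MXFP8 scales tensors in 1x32 blocks, so each expert's token group must be padded to a multiple of 32 (the pre-change README used 16 for other recipes). A hedged sketch of the call the tests and README rely on:

```python
# Hedged sketch: pad per-expert token groups so grouped GEMM operands line up
# with the MXFP8 1x32 scaling blocks. 32 is required for MXFP8; the old README
# used 16 for other recipes.
from torchtitan.models.moe.utils import set_token_group_alignment_size_m

set_token_group_alignment_size_m(32)
```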
47 changes: 22 additions & 25 deletions torchao/prototype/moe_training/README.md
@@ -2,13 +2,10 @@

 This prototype provides:

-1. Quantized building block for low precision MoE training: [_quantize_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L42). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
+1. Quantized building block for low precision MoE training: [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
     - Using MXFP8 on a B200 GPU, this provides:
         - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
-        - **~1.15 - 1.3x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes
-    - We also provide the following convenience functions for specific recipes:
-        - [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677)
-        - [_to_fp8_rowwise_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L678)
+        - **~1.19 - 1.6x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes



@@ -28,12 +25,12 @@ This prototype provides:
 - [Limitations](#limitations)

 ## Examples
-#### _quantize_then_scaled_grouped_mm usage
+#### _to_mxfp8_then_scaled_grouped_mm usage
 ```python
 import torch
 from torch.nn import functional as F
 from torchao.prototype.moe_training import (
-    _quantize_then_scaled_grouped_mm
+    _to_mxfp8_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -48,11 +45,10 @@ B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_
 offs = generate_jagged_offs(num_groups, total_M, device="cuda")

 # Forward and backward example
-out = _quantize_then_scaled_grouped_mm(
+out = _to_mxfp8_then_scaled_grouped_mm(
     A,
     B.transpose(-2, -1),
-    offs=offs,
-    scaling_type=MoEScalingType.MXFP8,
+    offs,
 )

 # (Fake labels for demonstration purposes)
@@ -63,20 +59,20 @@ loss.backward()

 #### Model conversion API example: end-to-end training
 ```python
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 import torch
 from torch import nn
 from torch.nn import functional as F

-# this feature requires CUDA 12.8+ and SM100+
+# This feature requires CUDA 12.8+ and SM100+
 assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)

 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
 from torchao.quantization.quant_api import quantize_

-# this example uses torchtitan llama4 MoE, see
-# this benchmark requires torchtitan
+# This example uses torchtitan Llama4 MoE.
 try:
-    from torchtitan.distributed.expert_parallel import (
+    from torchtitan.models.moe.utils import (
         set_token_group_alignment_size_m,
     )
     from torchtitan.models.moe import MoE, MoEArgs
@@ -86,7 +82,7 @@ except ImportError:
     )


-# initialize model
+# Initialize model
 device = torch.device("cuda")
 moe_args = MoEArgs(
     num_experts=8,
@@ -96,7 +92,7 @@ model = MoE(moe_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)

-# module filter function to define which modules to quantize
+# Module filter function to define which modules to quantize
 target_fqns = ["experts"]


@@ -106,31 +102,32 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
             return True
     return False

-# Token group alignment size must be 32 for MXFP8 training
-alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
+# Token group sizes must be padded to multiple of MXFP8 scaling block size (1x32)
+alignment_size = 32
 set_token_group_alignment_size_m(alignment_size)

-# quantize the model
-config = MoETrainingConfig()
+# Convert model to use MXFP8 scaled grouped GEMMs
+config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
 quantize_(model, config=config, filter_fn=moe_module_filter_fn)

-# training loop
+# Training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+batch_size, seq_len = 2, 2048
 for step in range(10):
     # Simulate random batch of input data
     x = torch.randn(
-        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+        batch_size, seq_len, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )

-    # forward pass
+    # Forward pass
     out = model(x)

-    # compute loss
+    # Compute loss with fake labels for demonstration purposes
     labels = torch.ones_like(out)
     out_loss = F.mse_loss(out, labels)
     print(f"step {step} loss: {out_loss.item()}")

-    # backward pass
+    # Backward pass
     out_loss.backward()
     optimizer.step()
     optimizer.zero_grad()
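Because the README examples now call the MXFP8-specific entry point directly, a quick numerics check against the bf16 baseline is a reasonable first step when trying it out. The sketch below reuses the README's imports and tensor layout; the concrete shapes and the SQNR metric are illustrative assumptions rather than part of the repo, and it assumes a recent PyTorch with `torch._grouped_mm` on an SM100+ GPU.

```python
# Hedged sketch: compare the MXFP8 grouped GEMM against the bf16 baseline.
import torch

from torchao.prototype.moe_training import _to_mxfp8_then_scaled_grouped_mm
from torchao.prototype.moe_training.utils import generate_jagged_offs

# Illustrative shapes (assumptions): A is (total_M, K), B is (num_groups, N, K).
num_groups, total_M, N, K = 8, 8192, 8192, 5120
A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_groups, total_M, device="cuda")

ref = torch._grouped_mm(A, B.transpose(-2, -1), offs=offs)
out = _to_mxfp8_then_scaled_grouped_mm(A, B.transpose(-2, -1), offs)

# Signal-to-quantization-noise ratio in dB; higher means closer to bf16.
err = (ref - out).float()
sqnr = 10 * torch.log10(ref.float().pow(2).mean() / err.pow(2).mean())
print(f"SQNR vs bf16 torch._grouped_mm: {sqnr.item():.1f} dB")
```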
6 changes: 5 additions & 1 deletion torchao/prototype/moe_training/__init__.py
@@ -1,5 +1,9 @@
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _quantize_then_scaled_grouped_mm,
+    _to_mxfp8_then_scaled_grouped_mm,
 )

-__all__ = ["_quantize_then_scaled_grouped_mm"]
+__all__ = [
+    "_quantize_then_scaled_grouped_mm",
+    "_to_mxfp8_then_scaled_grouped_mm",
+]
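With the expanded `__all__`, both entry points are importable from the package root. A minimal sketch; per the README above, the MXFP8-specific wrapper is the one the updated examples use, while the generic function took the recipe via `scaling_type` in the old examples.

```python
# Both names are exported from torchao.prototype.moe_training after this change.
from torchao.prototype.moe_training import (
    _quantize_then_scaled_grouped_mm,  # generic entry point
    _to_mxfp8_then_scaled_grouped_mm,  # MXFP8-specific wrapper used in the README
)
```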