map torchao quantized checkpoints to vLLM's MoE kernels #28421
Conversation
Code Review
This pull request is a work-in-progress to add support for mapping torchao quantized MoE checkpoints to vLLM's optimized kernels. The changes primarily involve modifying torchao.py to handle FusedMoE layers and qwen2_moe.py to load Float8Tensor weights. The approach seems reasonable for the stated goal. My review includes a few suggestions to improve code robustness and remove an internal link before this can be landed.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Force-pushed from 41f071f to 9d047b3
# TODO(before land): test other formats, and make it explicit when
# something is not supported with a nice error message.

layer.w13_weight_scale = torch.nn.Parameter(layer.w13_weight.scale)
before land: need to polish this to set requires_grad and other attrs properly
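A minimal sketch of what that polish could look like, assuming the torchao quantized weight exposes its rowwise scale as a `.scale` attribute and that only the fused expert weights need re-registering (the helper name and the `w2_weight` handling are hypothetical, not from the PR):

```python
import torch


def register_expert_scales(layer: torch.nn.Module) -> None:
    # Hypothetical helper: pull the rowwise scales out of the torchao
    # quantized weights and register them as frozen Parameters so the
    # fused MoE kernels see plain tensors with requires_grad disabled.
    for name in ("w13_weight", "w2_weight"):
        quantized_weight = getattr(layer, name)
        scale = quantized_weight.scale.contiguous()
        setattr(
            layer,
            f"{name}_scale",
            torch.nn.Parameter(scale, requires_grad=False),
        )
```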
Force-pushed from 5f32c46 to 047c44b
is_float8_rowwise = isinstance(
    torchao_config, Float8DynamicActivationFloat8WeightConfig
) and torchao_config.granularity == [PerRow(), PerRow()]
# Special case of float8 rowwise where the HuggingFace weight is stored
note: it's faster to preprocess the checkpoint to convert expert weights to MNK, but good to have this path for supporting naive HF format
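For illustration only, the preprocessing mentioned here could look roughly like the sketch below, under the assumption that the naive HF format stores each expert weight transposed relative to the `[num_experts, N, K]` layout the fused kernels expect (the layout details are my assumption, not stated in the PR):

```python
import torch


def hf_expert_weight_to_mnk(weight: torch.Tensor) -> torch.Tensor:
    # Assumed layouts: naive HF checkpoint stores [num_experts, K, N],
    # fused kernels want [num_experts, N, K]; a transpose plus a contiguous
    # copy converts between the two once, ahead of time.
    return weight.transpose(-2, -1).contiguous()
```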
kylesayrs
left a comment
I'll call out that compressed-tensors has some nice matching utilities, especially ones that might be good for matching these sets of fused weights.
There are plans to use these utils in compressed_tensors.py as well; not sure if they might be helpful here.
thank you! Will take a look, I'd like to reuse instead of reimplementing where it makes sense :)
kylesayrs
left a comment
Looks good to me! I think doing this conversion on a per-scheme basis is probably the correct approach.
You could potentially generalize this to some sort of adapter mixin for Linear and MoE layers (that way you can share conversion logic for modules which have the same parameters), but that can be thought through more later (a rough sketch of the idea follows this exchange).
thank you, good to hear this approach looks reasonable! Going to polish this PR a bit.
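A purely hypothetical sketch of that adapter-mixin idea (class and attribute names are invented for illustration and do not exist in vLLM or this PR; the `.qdata`/`.scale` attributes assume a torchao Float8Tensor-like object):

```python
import torch


class TorchAOConversionMixin:
    # Hypothetical mixin: quant methods for Linear and MoE layers that own
    # the same set of quantized parameters could share one conversion routine
    # from torchao tensor subclasses to plain weight + scale tensors.
    quantized_param_names: tuple[str, ...] = ()

    def convert_torchao_params(self, layer: torch.nn.Module) -> None:
        for name in self.quantized_param_names:
            qweight = getattr(layer, name)
            # Assumes a torchao Float8Tensor-like object with .qdata/.scale.
            setattr(layer, name,
                    torch.nn.Parameter(qweight.qdata, requires_grad=False))
            setattr(layer, f"{name}_scale",
                    torch.nn.Parameter(qweight.scale, requires_grad=False))


class MoEConversion(TorchAOConversionMixin):
    quantized_param_names = ("w13_weight", "w2_weight")


class LinearConversion(TorchAOConversionMixin):
    quantized_param_names = ("weight",)
```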
Force-pushed from d58f788 to 96c59da
…rnels

Summary: This is not ready for review yet, for now just hacking to understand quantization and MoEs in vLLM. High level, I want to map torchao generated MoE checkpoints to vLLM optimized fused kernels. Note that I **do not** want torchao to take over the kernels here; instead torchao just provides the checkpoint and there is glue code to let vLLM select the kernels.

TODO: iterate some more and write down the design in more detail.

Test Plan: Tested locally with Qwen1.5-MoE-A2.7B with experts quantized to float8 with rowwise scaling.

Signed-off-by: vasiliy <[email protected]>
Signed-off-by: <[email protected]>
Summary: Adds an e2e example of how to use torchao to quantize LLaMa 4 Scout. Note that this needs:
* a recent `transformers` version (higher than 4.57, not officially released yet, so the user needs to build from source)
* a recent `fbgemm_gpu` nightly from `2025.11.22` or after
* to run this in vLLM, vllm-project/vllm#28421 is needed (not yet landed)

Test Plan:
```bash
with-proxy time python examples/quantize_llama_4.py ~/local/tmp/20251201_test/
```

ghstack-source-id: 3c47130
ghstack-comment-id: 3599037297
Pull-Request: #3408
Purpose

This is a POC to map torchao's quantized checkpoints for Mixture-of-Experts modules to the `compressed-tensors` MoE path for the w8a8 rowwise quant scheme. We do this by:
* adding a new quant method which wraps `CompressedTensorsW8A8Fp8MoEMethod`, overriding its `create_weights` and `process_weights_after_loading` methods to map from the torchao quantized tensor to the compressed-tensors plain tensor format. This is the part where the "torchao -> compressed_tensors" conversion happens. Note that for float8 w8a8 rowwise, this is a metadata-only change.
* modifying `TorchAOConfig` to set the quant method to `TorchAOWrappingCompressedTensorsW8A8Fp8MoEMethod` when appropriate.

Note that there are no changes to the `compressed-tensors` path, since the existing vLLM APIs are already expressive enough to do this mapping. A rough sketch of the wrapping approach is shown after the list below.

For now, I only implemented one w8a8 scheme to demonstrate a proof of concept. In the future, the following could be done:
a. map to more schemes for w8a8
b. add mappings for w4a8, etc.
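A minimal sketch of the wrapping approach described above (hypothetical: the base-class import path, method signatures, and the `.qdata`/`.scale` attribute names are assumptions, and the real PR likely does more):

```python
import torch

# The import path for the compressed-tensors MoE method is assumed here.
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsW8A8Fp8MoEMethod,
)


class TorchAOWrappingCompressedTensorsW8A8Fp8MoEMethod(
    CompressedTensorsW8A8Fp8MoEMethod
):
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Unwrap the torchao Float8Tensor subclasses into the plain
        # weight + scale tensors that the compressed-tensors path expects.
        # For float8 w8a8 rowwise this is a metadata-only change.
        for name in ("w13_weight", "w2_weight"):
            qweight = getattr(layer, name)
            layer.register_parameter(
                f"{name}_scale",
                torch.nn.Parameter(qweight.scale, requires_grad=False),
            )
            layer.register_parameter(
                name,
                torch.nn.Parameter(qweight.qdata, requires_grad=False),
            )
        super().process_weights_after_loading(layer)
```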
Test Plan
* `Qwen/Qwen1.5-MoE-A2.7B` with experts quantized to float8 with rowwise scaling with torchao
* `meta-llama/Llama-4-Scout-17B-16E-Instruct` with experts quantized to float8 with rowwise scaling with torchao

> CUDA_VISIBLE_DEVICES=4,5,6,7 vllm bench throughput --model ../pytorch_scripts/hf_torchao_vllm/data/torchao/fp8-experts-only-mnk-testing-Llama-4-Scout-17B-16E-Instruct --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --num-prompts 10 --tensor-parallel-size 4 --max-model-len 2048 --gpu-memory-utilization 0.8
...
Throughput: 4.07 requests/s, 2623.30 total tokens/s, 609.79 output tokens/s
Total num prompt tokens: 4953
Total num output tokens: 1500

Test Result

see above; the new functionality did not exist before this PR
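For reference, quantizing weights to float8 with rowwise scaling in torchao looks roughly like the snippet below (a hypothetical example: the toy model and lack of expert-module filtering are placeholders, and this is not the exact script used for the test plan):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Toy module standing in for the MoE experts; a real run would load the HF
# model and restrict quantization to the expert weights.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"),
    torch.nn.Linear(512, 256, dtype=torch.bfloat16, device="cuda"),
)
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```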