[pt2e] Fix QAT annotations for special qspecs #3337
Dr. CI: ✅ No failures as of commit c19b8cd with merge base e2aab90. Artifacts and rendered test results: hud.pytorch.org/pr/pytorch/ao/3337
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this in D86983409.
jerryzh168
left a comment
ah makes sense, can you add a test
Yep, this is in progress, just wanted to push first for Xing to try
Force-pushed: 646cd49 to d76110c
    for n in nodes_to_check:
        if n.target == torch.ops.aten.batch_norm.default:
            num_batch_norm_nodes_checked += 1
        self.assertTrue(n not in old_nodes, "found old node in qspec")
is it possible to also check they are all in new_nodes?
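A minimal sketch of what the stricter check asked for here could look like, assuming `new_nodes` is a collection of the post-fusion nodes alongside `old_nodes`. The helper name `check_qspec_nodes` and the use of plain strings as stand-ins for graph nodes are hypothetical and not taken from the PR:

```python
def check_qspec_nodes(nodes_to_check, old_nodes, new_nodes):
    """Assert every node referenced by a qspec is a new node, not an old one.

    Returns the number of nodes checked so the caller can also assert
    that the loop actually ran.
    """
    checked = 0
    for n in nodes_to_check:
        # Existing check: no stale references to pre-fusion nodes.
        assert n not in old_nodes, "found old node in qspec"
        # Additional check: every reference points at a post-fusion node.
        assert n in new_nodes, "qspec node is not one of the new nodes"
        checked += 1
    return checked


# Strings stand in for graph nodes in this sketch.
num_checked = check_qspec_nodes(
    nodes_to_check=["bn_new", "conv_new"],
    old_nodes={"bn_old", "conv_old"},
    new_nodes={"bn_new", "conv_new"},
)
```

Checking membership in `new_nodes` as well catches references that point at neither set, which the negative check alone would miss.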
jerryzh168
left a comment
lg, thanks for the fix, had one inline comment
**Summary:** In pt2e QAT, we first annotate the nodes to be quantized and then perform the pattern replacement (i.e. QAT fusion) in the prepare step. After this pattern replacement, old nodes are replaced with new nodes, and any references to the old nodes must also be updated to refer to the new nodes.

This commit fixes a bug where, for special qspecs like the `SharedQuantizationSpec` and `DerivedQuantizationSpec`, we only update the values of a node's `input_qspec_map`, not the keys. As a result, the keys still refer to the old nodes that do not exist anymore after the QAT fusion.

**Test Plan:**
```
python test/quantization/pt2e/test_quantize_pt2e_qat.py -k test_qat_shared_qspec
```
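To illustrate the bug described in the summary, here is a simplified sketch that does not use torch. `Node` and `SharedSpec` are hypothetical stand-ins for an FX graph node and a qspec that refers to another node, and the two remap functions are illustrative only, not the code changed in this PR:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Stand-in for a graph node; only the name matters here."""
    name: str


@dataclass(frozen=True)
class SharedSpec:
    """Stand-in for a qspec that points at another node (hypothetical)."""
    edge_or_node: Node


def remap_values_only(qspec_map, old_to_new):
    # Buggy variant: values are updated, keys still point at old nodes.
    return {
        key: SharedSpec(old_to_new.get(spec.edge_or_node, spec.edge_or_node))
        for key, spec in qspec_map.items()
    }


def remap_keys_and_values(qspec_map, old_to_new):
    # Fixed variant: both the key and the node inside the spec are remapped.
    return {
        old_to_new.get(key, key): SharedSpec(
            old_to_new.get(spec.edge_or_node, spec.edge_or_node)
        )
        for key, spec in qspec_map.items()
    }


old_bn, new_bn = Node("bn_old"), Node("bn_new")
old_to_new = {old_bn: new_bn}           # mapping produced by the fusion step
input_qspec_map = {old_bn: SharedSpec(old_bn)}

buggy = remap_values_only(input_qspec_map, old_to_new)
fixed = remap_keys_and_values(input_qspec_map, old_to_new)
```

In `buggy`, the key is still `old_bn` even though the spec inside already refers to `new_bn`; in `fixed`, neither the key nor the spec refers to the old node.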
Force-pushed: d76110c to c19b8cd