Align memory_format for conv2d/3d in Float8Tensor with hp Tensor #3352
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3352
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit fe4167f with merge base ff0e461.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from a25ffda to 8bbb423.
Force-pushed from 8bbb423 to 16dba36.
    # output should use channels_last format as long as any of the
    # input or weight is channels_last
    if is_input_channels_last or is_weight_channels_last:
        output = output.to(memory_format=torch.channels_last_3d)
I think this is the right thing to do semantics-wise, but note that this will incur a copy if the output isn't already in channels_last. Ideally, the kernel itself would output into channels_last directly to avoid the copy.
Edit: oh, I think you're already aware of this :)
Yes, the fbgemm kernel should already output a tensor in this format, so this becomes a no-op when either the input or the weight is in channels_last format.
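For context, a minimal sketch of the memory_format propagation discussed above (not the PR's exact code; the wrapper function and how the two booleans are computed are illustrative assumptions): the output is converted to channels_last_3d whenever either operand already uses that layout, and the conversion is a no-op when the kernel already produced that layout.

    import torch

    # Illustrative sketch only: mirror the high-precision conv3d behavior, where the
    # output picks up channels_last_3d if either the input or the weight uses it.
    def align_conv3d_output_memory_format(input, weight, output):
        is_input_channels_last = input.is_contiguous(memory_format=torch.channels_last_3d)
        is_weight_channels_last = weight.is_contiguous(memory_format=torch.channels_last_3d)
        if is_input_channels_last or is_weight_channels_last:
            # .to() does not copy when output is already channels_last_3d
            output = output.to(memory_format=torch.channels_last_3d)
        return output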
    act_qdata = act_qdata.contiguous()
    weight_qdata = weight_qdata.contiguous()
I don't think these contiguous() calls are the right thing to do here. Note that this will clobber channels_last for the activation and weight; calling contiguous(memory_format=torch.channels_last) would be more correct.
Edit: from offline discussion, we can't forget the permute()! We want a contiguous() (N, D, H, W, C_in) tensor, which is equivalent to a properly-permuted, contiguous(channels_last) (N, C_in, D, H, W) tensor.
Refactored this to first do contiguous(memory_format=torch.channels_last_3d) and then do the permute, to make it easier to follow.
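To illustrate the equivalence mentioned in the comment above, a small sketch (the tensor shapes and names are made up, not from the PR): an (N, C, D, H, W) tensor made contiguous in channels_last_3d and then permuted to (N, D, H, W, C) is plain-contiguous, i.e. it has the same memory layout as a contiguous NDHWC tensor.

    import torch

    # Hypothetical shapes for illustration: (N, C, D, H, W)
    x = torch.randn(2, 8, 4, 4, 4)

    # Make the underlying storage channels_last_3d (C becomes the fastest-moving dim) ...
    x_cl = x.contiguous(memory_format=torch.channels_last_3d)

    # ... then permute the logical dims to (N, D, H, W, C); no copy happens here.
    x_ndhwc = x_cl.permute(0, 2, 3, 4, 1)

    # The permuted view is plain-contiguous, matching the contiguous
    # (N, D, H, W, C_in) layout described in the comment above.
    assert x_ndhwc.is_contiguous()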
Force-pushed from 58bf5d1 to 642e3d0.
Force-pushed from 642e3d0 to fe4167f (Align memory_format for conv2d and conv3d in Float8Tensor with high precision Tensors, pytorch#3352).
Summary:
As the title says, we want to make sure the output of `F.conv3d(input, weight, ...)` and `F.conv3d(input, fp8_weight, ...)` have the same memory_format.

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
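For reference, a hedged sketch of the kind of memory_format parity check the test plan exercises (the real assertions live in test_fp8_conv_variants; the helper below and its arguments are hypothetical, and fp8_weight is assumed to be a torchao Float8Tensor wrapping the same weight, with its construction not shown):

    import torch
    import torch.nn.functional as F

    # Hypothetical helper, not the actual test code.
    def check_conv3d_memory_format_parity(input, weight, fp8_weight):
        out_hp = F.conv3d(input, weight)
        out_fp8 = F.conv3d(input, fp8_weight)
        # Both outputs should report the same channels_last_3d-ness.
        assert out_hp.is_contiguous(memory_format=torch.channels_last_3d) == \
               out_fp8.is_contiguous(memory_format=torch.channels_last_3d)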