Add 2d quant for mxfp8 #2634
Conversation
[pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
Greptile Summary

This PR adds an MXFP8 “2D block scaling” mode (one E8M0 scale per 32×32 block, shared between the rowwise and colwise outputs), plumbed end-to-end from the Python recipes/quantizers through the C++ QuantizationConfig and down into the CUDA MXFP8 quantization kernel. It also adds a new PyTorch unit test that compares GPU output against a CPU reference and wires the test into the L0 pytest runner. Key integration points:
Issues to address before merge:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python MXFP8Quantizer
    participant RS as MXFP8BlockScalingRecipeState
    participant CppQ as C++ MXFP8Quantizer
    participant QC as QuantizationConfigWrapper
    participant Disp as quantize_fwd/bwd_helper
    participant MX as mxfp8::quantize
    participant Ker as quantize_mxfp8_kernel
    RS->>Py: make_quantizers(mode)
    Py->>CppQ: call quantize/update_quantized
    CppQ->>QC: set_mxfp8_2d_quantization(with_2d_quantization)
    CppQ->>Disp: nvte_quantize_v2(input, output, QuantizationConfig)
    Disp->>MX: mxfp8::quantize(..., mxfp8_2d_quantization)
    MX->>MX: if use_2d_quantization, force ScalingType=BIDIMENSIONAL
    MX->>Ker: launch with kIs2DBlockScaling
    Ker->>Ker: colwise pass computes E8M0 per 32x32 block
    Ker->>Ker: write block_scales_2d[] in shared mem
    Ker->>Ker: rowwise pass shfl-broadcast scale from shared mem
    Ker-->>MX: write row/col scales + FP8 data
    MX-->>Disp: return
    Disp-->>CppQ: return
    CppQ-->>Py: return MXFP8Tensor
```
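To make the hand-off in the last kernel steps concrete, here is a minimal self-contained CUDA sketch of the shared-memory-plus-shuffle pattern. The names (THREADS_X, block_scales_2d, tid_X_rowwise) mirror the snippets quoted below, but the constants and the helper itself are illustrative assumptions, not the PR's exact code:

```cuda
#include <cstdint>

using e8m0_t = uint8_t;
constexpr int THREADS_X = 4;  // assumed: 32-wide blocks handled per warp row

// The colwise pass is assumed to have written one E8M0 scale per 32x32 block
// into block_scales_2d (shared memory). The rowwise pass then broadcasts each
// block's scale to every thread that quantizes elements of that block.
__device__ e8m0_t broadcast_block_scale(const e8m0_t *block_scales_2d,
                                        int thread_lane, int tid_X_rowwise) {
  int scale = 0;  // initialized so no lane shuffles an indeterminate value
  if (thread_lane < THREADS_X) {
    scale = block_scales_2d[thread_lane];  // lanes 0..THREADS_X-1 hold scales
  }
  // Every lane reads the value held by lane tid_X_rowwise (always < THREADS_X).
  return static_cast<e8m0_t>(__shfl_sync(0xffffffffu, scale, tid_X_rowwise));
}
```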
4 files reviewed, 1 comment
```cuda
e8m0_t scale_from_shmem;
if (thread_lane < THREADS_X) {
  scale_from_shmem = block_scales_2d[thread_lane];
}
// Broadcast: each thread gets scale from lane matching its tid_X_rowwise
biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);
```
scale_from_shmem is potentially uninitialized for threads where thread_lane >= THREADS_X. While __shfl_sync only reads from lanes specified by tid_X_rowwise (which should be < THREADS_X), it's safer to initialize this variable.
Suggested change:

```diff
-e8m0_t scale_from_shmem;
+e8m0_t scale_from_shmem = 0;
 if (thread_lane < THREADS_X) {
   scale_from_shmem = block_scales_2d[thread_lane];
 }
```
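In practice the zero initialization should never be consumed, since tid_X_rowwise stays below THREADS_X, but it avoids a formally undefined read and keeps tools such as compute-sanitizer's initcheck from flagging the shuffle.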
```cuda
if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; }
```
Is this intentional? I think it will fail if the output does not have one of the rowwise/columnwise pieces allocated. Also the 2D quantization should not care about that?
There are two reasons for this:
- First, the current implementation of this quantize kernel makes it easy for me to obtain the 32x32 block amax from the columnwise code path, but it's more difficult from the rowwise code path. Therefore, I cannot implement 2D quant when using rowwise alone.
- Second, this 2D quant is currently only used for weights, so in practice it's only needed when using BIDIMENSIONAL mode.
However, I think one of your concerns is valid: I shouldn't directly modify the scaling type here; I should check it instead. I've now changed it so that when using 2D quant, it asserts that the scaling type == BIDIMENSIONAL.
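A minimal sketch of what that check could look like at the dispatch point, assuming TransformerEngine's NVTE_CHECK macro and the ScalingType/use_2d_quantization names from the diff; the message text is illustrative:

```cuda
// Validate rather than override: the kernel only derives 32x32 block amaxes
// on the colwise path, so 2D quantization requires the bidimensional mode.
if (use_2d_quantization) {
  NVTE_CHECK(scaling_type == ScalingType::BIDIMENSIONAL,
             "MXFP8 2D quantization requires both rowwise and colwise outputs.");
}
```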
In the inference case the weights would only ask for the rowwise quantization, so we need to be able to support that.
Signed-off-by: kunlunl <kunlunl@nvidia.com>
[pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
12 files reviewed, 3 comments
```diff
@@ -420,7 +421,8 @@ struct QuantizationConfig {
     sizeof(NVTETensor),  // rng_seed and offset
     sizeof(uint8_t),     // nvfp4_2d_quantization
     sizeof(uint8_t),     // stochastic_rounding
+    sizeof(uint8_t),     // mxfp8_2d_quantization
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
QuantizationConfig layout mismatch
QuantizationConfig::attr_sizes[] was extended with mxfp8_2d_quantization (common.h:415-423), but QuantizationConfig itself uses bool fields. If nvte_set/get_quantization_config_attribute assumes all attributes are packed according to attr_sizes, adding an extra byte-sized attribute without updating any corresponding packing/unpacking logic can break attribute indexing for subsequent attributes (or any code that iterates kNVTEQuantizationConfigNumAttributes). Please double-check the code that uses attr_sizes to ensure the new attribute is reflected everywhere it’s consumed (and that kNVTEQuantizationConfigNumAttributes matches the size of attr_sizes).
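One inexpensive guard, sketched under the assumption that attr_sizes is a static constexpr array as the diff suggests, is a compile-time size check so the table and the attribute enum cannot drift apart:

```cuda
#include <iterator>  // std::size (C++17)

// Fails the build if an attribute is added to the enum without extending
// attr_sizes (or vice versa), instead of silently mis-indexing attributes.
static_assert(std::size(QuantizationConfig::attr_sizes) ==
                  static_cast<size_t>(kNVTEQuantizationConfigNumAttributes),
              "attr_sizes must have one entry per quantization-config attribute");
```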
```python
)  # (num_block_rows, num_block_cols, 32, 32)

# Compute amax for each 32x32 block
block_amax = torch.amax(
    torch.abs(x_blocks.to(torch.float32)), dim=(-1, -2)
)  # (num_block_rows, num_block_cols)

# Convert to E8M0 scale inverse
block_scale_e8m0 = float_to_e8m0(block_amax)  # (num_block_rows, num_block_cols)
```
Check the float bit-cast
float_to_e8m0 does val_u32 = val.view(torch.int32) (test_mxfp8_2d_quantize.py:104-106). On PyTorch, Tensor.view(dtype) with a same-sized dtype is a bit reinterpretation of the underlying storage; the numeric cast would be .to(torch.int32). So the pattern is correct provided val is float32 at that point; an accidental .to() there would compute wrong exponents/mantissas and could cause false failures/passes. Worth double-checking the dtype so the reference faithfully models the GPU's IEEE-754 bit extraction.
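For comparison, the IEEE-754 exponent extraction that the reference is meant to model looks like this in a hedged C++/CUDA sketch (float_to_e8m0 here is a stand-in sharing the test helper's name, not the PR's code):

```cuda
#include <cstdint>
#include <cstring>

using e8m0_t = uint8_t;

// Reinterpret the float's bytes (a bit-cast, not a numeric conversion) and
// keep the 8-bit biased exponent, which is exactly the E8M0 encoding.
inline e8m0_t float_to_e8m0(float amax) {
  uint32_t bits;
  std::memcpy(&bits, &amax, sizeof(bits));  // safe bit reinterpretation
  return static_cast<e8m0_t>((bits >> 23) & 0xFFu);
}
```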
Additional Comments (1)
If 2D block scaling is only valid when both rowwise and colwise outputs are requested, it should be validated (raise an error) instead of overriding scaling_type.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: