# [C] NVFP4 quantization for GroupedTensor (#2655)
base: main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Overview

Greptile Summary: This PR adds NVFP4 quantization support for GroupedTensor.

Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant C as Client (C/C++)
    participant TE as TransformerEngine C API
    participant GT as GroupedTensor wrapper
    participant K as CUDA kernel(s)
    C->>TE: nvte_group_hadamard_transform_cast_fusion_graph_safe(input, output, hadamard, quant_config, workspace, stream)
    TE->>GT: convertNVTEGroupedTensorCheck(input/output)
    TE->>TE: QuantizationConfig quant_config_cpp (from quant_config)
    TE->>TE: group_hadamard_transform_cast_fusion_graph_safe(..., quant_config_cpp, ...)
    TE->>K: launch graph-safe grouped hadamard+cast+quant fusion kernels
    K-->>TE: write quantized outputs + amax/scales
    TE-->>C: return
```
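Read end to end, the diagram implies a one-call C API. Below is a hedged caller-side sketch; `NVTEGroupedTensor` and the exact argument list are inferred from the diagram and the hunks later in this review, not checked against the headers in this PR.

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>  // assumed location of the NVTE C API

// Hypothetical wrapper: one call fuses the Hadamard transform, cast, and NVFP4
// quantization for every tensor in the group; quantized data plus amax/scale
// factors are written by the kernels launched inside.
void quantize_grouped_nvfp4(NVTEGroupedTensor input, NVTEGroupedTensor output,
                            NVTETensor hadamard_matrix,
                            NVTEQuantizationConfig quant_config,
                            NVTETensor workspace, cudaStream_t stream) {
  nvte_group_hadamard_transform_cast_fusion_graph_safe(
      input, output, hadamard_matrix, quant_config, workspace, stream);
}
```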
3 files reviewed, 3 comments
```cpp
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
```
Multiple TODO comments request logic verification in the critical group-index calculation and tensor-layout code. Verify that the `group_idx` calculation and tensor-layout logic are correct before merging.
Suggested change (replaces the hunk quoted above):

```cpp
// Determine the current tensor group index based on tile offset
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                            reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                        make_shape(M, cur_N), DStride{});  // (M, packed_N)
```
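For context on what needs verifying: a grouped tensor packs several member tensors into one buffer, and each tile must map its linear offset back to a group index. A minimal sketch of that lookup follows; the name and the meaning of `offsets` (cumulative per-group element offsets) are assumptions read off the call site, not the real implementation.

```cpp
// Sketch in the spirit of get_current_tensor_id(); all semantics assumed.
// offsets[i] is taken to be group i's starting offset in elements, so group
// i owns the half-open range [offsets[i], offsets[i + 1]).
inline int find_group_index(int num_tensors, long linear_offset,
                            const long *offsets) {
  int group = 0;
  while (group + 1 < num_tensors && linear_offset >= offsets[group + 1]) {
    ++group;
  }
  return group;
}
```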
```cpp
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
  group_idx = cur_group_idx;
  c_global_amax_val = shared_storage.global_d_amax[group_idx];
  // update amax
  global_encode_scale = c_global_amax_val > 0.0f
                            ? cutlass::minimum_with_nan_propagation<float>{}(
                                  (fp8_max * fp4_max) / c_global_amax_val,
                                  cutlass::platform::numeric_limits<float>::max())
                            : 1.0f;
  global_decode_scale = 1.0f / global_encode_scale;
  if constexpr (kUseFastMath) {
    global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
  }
  // TODO(zhongbo): double check the logic here
```
More TODO comments in the epilogue loop: verify the group-index recalculation and the amax scaling logic.
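For reference, the scaling math in the hunk above reduces to a small function. A minimal standalone sketch, assuming the usual FP8 E4M3 maximum (448) and FP4 E2M1 maximum (6) for `fp8_max`/`fp4_max`, and using `fminf` in place of CUTLASS's NaN-propagating minimum:

```cpp
#include <cfloat>
#include <cmath>

struct GlobalScales {
  float encode;  // multiplies values before the cast to FP4
  float decode;  // inverse, kept so consumers can dequantize
};

// Mirrors the epilogue's per-group scale update. The fp8_max/fp4_max values
// are assumptions (E4M3 max = 448, E2M1 max = 6), and fminf does not
// propagate NaN the way cutlass::minimum_with_nan_propagation does; both are
// deliberate simplifications for this sketch.
GlobalScales compute_global_scales(float amax) {
  constexpr float fp8_max = 448.0f;
  constexpr float fp4_max = 6.0f;
  GlobalScales s;
  // Map the group's amax onto the representable FP8*FP4 range, clamp to
  // FLT_MAX for tiny amax, and fall back to 1.0f for an all-zero tensor.
  s.encode = amax > 0.0f ? fminf((fp8_max * fp4_max) / amax, FLT_MAX) : 1.0f;
  s.decode = 1.0f / s.encode;
  return s;
}
```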
```cpp
using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here
```
Add the input sanity checks noted in the TODO, as sketched below.
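One possible shape for those checks, using TransformerEngine's `NVTE_CHECK` assertion macro; which invariants this entry point actually requires is an assumption.

```cpp
// Hypothetical sanity checks slotted where the TODO sits; the exact
// invariants are assumptions, not taken from the PR.
NVTE_CHECK(input != nullptr, "Input grouped tensor is null.");
NVTE_CHECK(input->num_tensors == 0 || input->data.dptr != nullptr,
           "Input grouped tensor has no allocated data.");
```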
Fixes #2510
6 files reviewed, 2 comments
```cpp
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
  quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}

if (input_tensor->num_tensors == 0) {
  return;
}

// Call the multi-tensor Hadamard transform amax implementation.
group_hadamard_transform_cast_fusion_graph_safe(
    input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
    *quant_workspace_tensor, stream);
```
**Uninitialized quant_config used**

`QuantizationConfig quant_config_cpp;` is left uninitialized when `quant_config == nullptr`, but it is still passed into `group_hadamard_transform_cast_fusion_graph_safe(...)` (line 1511). That is undefined behavior and can lead to garbage config values being read inside the kernel launch path. Consider value-initializing (`QuantizationConfig quant_config_cpp{};`) or returning an error when `quant_config` is required.
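Sketched, the reviewer's first suggestion is a two-character change to the hunk above:

```cpp
// Value-initialization zeroes every member on the nullptr path instead of
// leaving them indeterminate.
QuantizationConfig quant_config_cpp{};
if (quant_config != nullptr) {
  quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
```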
Additional Comments (1)
## Description

Pieces taken from #2600.

## Type of change

## Changes

## Checklist