Conversation

@ksivaman

@ksivaman ksivaman commented Feb 6, 2026

Description

Pieces taken from #2600.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • NVFP4 quantization for grouped tensor.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
@ksivaman ksivaman added the MoE label Feb 6, 2026
@ksivaman ksivaman marked this pull request as draft February 6, 2026 06:38
@greptile-apps

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds NVFP4 quantization support for GroupedTensor by introducing new graph-safe grouped Hadamard transform CUDA implementations and exposing new C API entrypoints in the public headers. The build system is updated to compile the new CUDA sources, and the headers are extended to provide grouped amax + hadamard+cast fusion functions usable from the existing TransformerEngine API surface.

Confidence Score: 2/5

  • This PR has merge-blocking correctness and build-configuration issues.
  • The score is reduced due to a definite undefined-behavior bug (an uninitialized QuantizationConfig can be passed on when quant_config is null) and a build-configuration gap (a new CUTLASS-based CUDA source misses the special compile flags applied to the other CUTLASS kernels, risking debug-build hangs per the existing comment).
  • transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu; transformer_engine/common/CMakeLists.txt

Important Files Changed

  • transformer_engine/common/CMakeLists.txt: Adds two new hadamard_transform graph_safe CUDA sources to the build; the new CUTLASS-based source is not listed in CUTLASS_KERNEL_SOURCES, so it misses the debug anti-hang compile flags.
  • transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu: Introduces the graph-safe grouped Hadamard transform path for GroupedTensor; review focuses on the new device-side tensor-id selection and wrapper API wiring.
  • transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu: Adds a CUTLASS/CUTE-based graph-safe grouped hadamard+cast+quant fusion; contains a wrapper that can pass an uninitialized QuantizationConfig when quant_config is null.
  • transformer_engine/common/include/transformer_engine/hadamard_transform.h: Adds new GroupedTensor graph-safe hadamard/quantization C API declarations; signatures appear consistent with the new .cu implementations.
  • transformer_engine/common/include/transformer_engine/multi_tensor.h: Extends the multi_tensor API with a graph-safe group amax entrypoint for NVTEGroupedTensor; the declaration matches the new implementation wiring.
  • transformer_engine/common/include/transformer_engine/transformer_engine.h: Extends the public C API for GroupedTensor with new param/quantization attributes; the changes appear coherent but should be checked against ABI compatibility expectations across releases.

Sequence Diagram

sequenceDiagram
  participant C as Client (C/C++)
  participant TE as TransformerEngine C API
  participant GT as GroupedTensor wrapper
  participant K as CUDA kernel(s)

  C->>TE: nvte_group_hadamard_transform_cast_fusion_graph_safe(input, output, hadamard, quant_config, workspace, stream)
  TE->>GT: convertNVTEGroupedTensorCheck(input/output)
  TE->>TE: QuantizationConfig quant_config_cpp (from quant_config)
  TE->>TE: group_hadamard_transform_cast_fusion_graph_safe(..., quant_config_cpp, ...)
  TE->>K: launch graph-safe grouped hadamard+cast+quant fusion kernels
  K-->>TE: write quantized outputs + amax/scales
  TE-->>C: return


@greptile-apps greptile-apps bot left a comment


3 files reviewed, 3 comments


Comment on lines +709 to +724
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here

Multiple TODO comments request logic verification in the group-index calculation and tensor-layout code, which sit on a critical path. Verify that the group_idx calculation and the tensor-layout logic are correct before merging.

Suggested change
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
// Determine the current tensor group index based on tile offset
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
make_shape(M, cur_N), DStride{}); // (M,packed_N)

Comment on lines +778 to +795
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
group_idx = cur_group_idx;
c_global_amax_val = shared_storage.global_d_amax[group_idx];
// update amax
global_encode_scale = c_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / c_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
global_decode_scale = 1.0f / global_encode_scale;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
// TODO(zhongbo): double check the logic here

More TODO comments in the epilogue loop: verify the group-index recalculation and the amax scaling logic before merging.


using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here

Add the input sanity checks noted in the TODO before merging.

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@ksivaman
Member Author

ksivaman commented Feb 9, 2026

Fixes #2510

@ksivaman ksivaman marked this pull request as ready for review February 11, 2026 06:23

@greptile-apps greptile-apps bot left a comment


6 files reviewed, 2 comments


Comment on lines +1500 to +1512
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}

if (input_tensor->num_tensors == 0) {
return;
}

// Call the multi-tensor Hadamard transform amax implementation.
group_hadamard_transform_cast_fusion_graph_safe(
input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
*quant_workspace_tensor, stream);

Uninitialized quant_config used

QuantizationConfig quant_config_cpp; is left uninitialized when quant_config == nullptr, but it’s still passed into group_hadamard_transform_cast_fusion_graph_safe(...) (line 1511). That’s undefined behavior and can lead to garbage config values being read inside the kernel launch path. Consider default-initializing (QuantizationConfig quant_config_cpp{};) or returning an error when quant_config is required.

@greptile-apps

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

transformer_engine/common/CMakeLists.txt
Missing CUTLASS debug flags

hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu includes CUTLASS/CUTE headers, but it’s not listed in CUTLASS_KERNEL_SOURCES, so it won’t get the -g0;-dopt=on compile options that this file group relies on to avoid debug-build hangs. Add this new source to CUTLASS_KERNEL_SOURCES so it is compiled with the same options as the other CUTLASS kernels.

