Conversation

Contributor

@kunlunl kunlunl commented Jan 29, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

greptile-apps bot commented Jan 29, 2026

Greptile Overview

Greptile Summary

This PR adds an MXFP8 “2D block scaling” mode (one E8M0 scale per 32×32 block shared between rowwise/colwise), plumbed end-to-end from Python recipes/quantizers through the C++ QuantizationConfig and down into the CUDA MXFP8 quantization kernel. It also adds a new PyTorch unit test that compares GPU output against a CPU reference and wires the test into the L0 pytest runner.
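
As a rough illustration of the scheme (a CPU-side sketch only, not code from this PR; the real kernel additionally folds the FP8 element format's range and its rounding/NaN handling into the stored scale):

#include <cmath>
#include <cstdint>
#include <vector>

// E8M0-style biased exponent of a block amax: floor(log2(amax)) + 127,
// clamped to [0, 254] (0xFF is conventionally reserved for NaN).
inline uint8_t e8m0_from_amax(float amax) {
  if (amax <= 0.f) return 0;  // degenerate block
  int e = 0;
  std::frexp(amax, &e);        // amax = m * 2^e with m in [0.5, 1)
  int biased = (e - 1) + 127;  // floor(log2(amax)) + bias
  if (biased < 0) biased = 0;
  if (biased > 254) biased = 254;
  return static_cast<uint8_t>(biased);
}

// One scale per 32x32 block of a rows x cols row-major matrix
// (rows and cols assumed divisible by 32).
std::vector<uint8_t> block_scales_2d(const std::vector<float> &x, int rows, int cols) {
  const int B = 32;
  std::vector<uint8_t> scales((rows / B) * (cols / B));
  for (int br = 0; br < rows / B; ++br) {
    for (int bc = 0; bc < cols / B; ++bc) {
      float amax = 0.f;
      for (int r = 0; r < B; ++r) {
        for (int c = 0; c < B; ++c) {
          amax = std::fmax(amax, std::fabs(x[(br * B + r) * cols + (bc * B + c)]));
        }
      }
      // The same per-block scale is consumed by both the rowwise and the colwise output.
      scales[br * (cols / B) + bc] = e8m0_from_amax(amax);
    }
  }
  return scales;
}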

Key integration points:

  • Python: MXFP8BlockScaling recipe now carries mxfp8_2d_quantization QParams (weight-only in forward), and MXFP8Quantizer propagates with_2d_quantization.
  • C++/C API: new QuantizationConfig attribute kNVTEQuantizationConfigMXFP82DQuantization and wrapper setter, plus get/set handling.
  • CUDA: MXFP8 quantization dispatch passes the flag to mxfp8::quantize, which selects a kernel specialization for 2D and shares colwise-computed scales to the rowwise pass via shared memory.

Issues to address before merge:

  • The new CPU reference implementation in the test appears to compute IEEE754 exponent/mantissa via a numeric dtype cast (.view(torch.int32)), which doesn’t match bit-level reinterpretation and makes the test’s ground truth unreliable.
  • In the kernel wrapper, enabling use_2d_quantization currently overrides scaling_type to BIDIMENSIONAL, which can silently change the caller’s requested scaling behavior rather than validating compatibility.
  • QuantizationConfig’s new attribute size entry needs verification that any packing/indexing logic consuming attr_sizes stays consistent with the enum count and get/set paths.

Confidence Score: 3/5

  • This PR adds useful functionality but has a couple of correctness issues that should be fixed before merge.
  • Score reduced due to (1) an incorrect CPU reference implementation in the new unit test (bit-reinterpretation bug) and (2) logic that forces bidimensional scaling when the 2D flag is set, which can change caller-requested semantics. The rest of the plumbing looks coherent, but the QuantizationConfig attribute layout should be double-checked.
  • tests/pytorch/test_mxfp8_2d_quantize.py; transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh; transformer_engine/common/common.h

Important Files Changed

Filename | Overview
qa/L0_pytorch_unittest/test.sh | Adds the new MXFP8 2D quantization pytest file to the L0 PyTorch unittest script.
tests/pytorch/test_mxfp8_2d_quantize.py | Adds extensive GPU-vs-reference tests for MXFP8 2D block scaling; reference float_to_e8m0 uses an incorrect dtype view for bit extraction, making the reference potentially invalid.
transformer_engine/common/cast/dispatch/quantize.cuh | Plumbs mxfp8_2d_quantization flag from QuantizationConfig into MXFP8 quantize dispatch for fwd/bwd.
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | Implements optional 2D block scaling in the MXFP8 kernel; currently forces scaling_type to BIDIMENSIONAL when flag is set, which can change caller-requested scaling semantics.
transformer_engine/common/common.h | Adds mxfp8_2d_quantization to QuantizationConfig and attr_sizes; needs careful verification that attribute packing/indexing logic stays consistent.
transformer_engine/common/include/transformer_engine/transformer_engine.h | Extends public QuantizationConfig attribute enum and wrapper API with MXFP8 2D quantization setter.
transformer_engine/common/recipe/__init__.py | Adds mxfp8_2d_quantization flag to QParams and MXFP8BlockScaling recipe, defaulting via env var and applying 2D only to weight quantization params.
transformer_engine/common/transformer_engine.cpp | Adds get/set handling for the new MXFP8 2D quantization config attribute.
transformer_engine/pytorch/csrc/common.h | Adds with_2d_quantization field to the C++ MXFP8Quantizer wrapper.
transformer_engine/pytorch/csrc/quantizer.cpp | Reads with_2d_quantization from Python quantizer and sets the corresponding QuantizationConfig attribute before nvte_quantize_v2.
transformer_engine/pytorch/quantization.py | Updates MXFP8BlockScalingRecipeState to construct MXFP8Quantizer instances with per-qparam 2D settings (weight-only in forward, grad in backward).
transformer_engine/pytorch/tensor/mxfp8_tensor.py | Extends MXFP8Quantizer Python class to carry with_2d_quantization through init/copy.

Sequence Diagram

sequenceDiagram
    participant Py as Python MXFP8Quantizer
    participant RS as MXFP8BlockScalingRecipeState
    participant CppQ as C++ MXFP8Quantizer
    participant QC as QuantizationConfigWrapper
    participant Disp as quantize_fwd/bwd_helper
    participant MX as mxfp8::quantize
    participant Ker as quantize_mxfp8_kernel

    RS->>Py: make_quantizers(mode)
    Py->>CppQ: call quantize/update_quantized
    CppQ->>QC: set_mxfp8_2d_quantization(with_2d_quantization)
    CppQ->>Disp: nvte_quantize_v2(input, output, QuantizationConfig)
    Disp->>MX: mxfp8::quantize(..., mxfp8_2d_quantization)
    MX->>MX: if use_2d_quantization: force ScalingType=BIDIMENSIONAL
    MX->>Ker: launch with kIs2DBlockScaling
    Ker->>Ker: colwise pass computes E8M0 per 32x32 block
    Ker->>Ker: write block_scales_2d[] in shared mem
    Ker->>Ker: rowwise pass shfl-broadcast scale from shared mem
    Ker-->>MX: write row/col scales + FP8 data
    MX-->>Disp: return
    Disp-->>CppQ: return
    CppQ-->>Py: return MXFP8Tensor

Contributor

@greptile-apps greptile-apps bot left a comment


4 files reviewed, 1 comment


Comment on lines +455 to +460
e8m0_t scale_from_shmem;
if (thread_lane < THREADS_X) {
  scale_from_shmem = block_scales_2d[thread_lane];
}
// Broadcast: each thread gets scale from lane matching its tid_X_rowwise
biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);
Contributor


scale_from_shmem is potentially uninitialized for threads where thread_lane >= THREADS_X. While __shfl_sync only reads from lanes specified by tid_X_rowwise (which should be < THREADS_X), it's safer to initialize this variable.

Suggested change
-  e8m0_t scale_from_shmem;
-  if (thread_lane < THREADS_X) {
-    scale_from_shmem = block_scales_2d[thread_lane];
-  }
-  // Broadcast: each thread gets scale from lane matching its tid_X_rowwise
-  biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);
+  e8m0_t scale_from_shmem = 0;
+  if (thread_lane < THREADS_X) {
+    scale_from_shmem = block_scales_2d[thread_lane];
+  }

}
}

if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; }
Member


Is this intentional? I think it will fail if the output does not have one of the rowwise/columnwise pieces allocated. Also the 2D quantization should not care about that?

Contributor Author


There are two reasons for this:

  • First, the current implementation of this quantize kernel makes it easy for me to obtain the 32x32 block amax from the columnwise code path, but it's more difficult from the rowwise code path. Therefore, I cannot implement 2D quant when using rowwise alone.
  • Second, this 2D quant is currently only used for weights, so in practice it's only needed when using BIDIMENSIONAL mode.

However, I think one of your concerns is valid—I shouldn't directly modify the scaling type here; I should check it instead. I've now changed it so that when using 2D quant, it asserts that scaling type == BIDIMENSIONAL.
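
The check the author describes could look roughly like this (a sketch only; the exact assertion macro, message, and placement in quantize_mxfp8.cuh are assumptions rather than the actual diff):

if (use_2d_quantization) {
  // Validate instead of silently overriding the caller's requested scaling mode.
  NVTE_CHECK(scaling_type == ScalingType::BIDIMENSIONAL,
             "MXFP8 2D block scaling currently requires both rowwise and columnwise "
             "outputs (ScalingType::BIDIMENSIONAL).");
}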

Member


In the inference case the weights would only ask for the rowwise quantization, so we need to be able to support that.

kunlunl and others added 2 commits February 10, 2026 14:24
Contributor

@greptile-apps greptile-apps bot left a comment


12 files reviewed, 3 comments


Comment on lines 415 to 423
@@ -420,7 +421,8 @@ struct QuantizationConfig {
      sizeof(NVTETensor),  // rng_seed and offset
      sizeof(uint8_t),     // nvfp4_2d_quantization
      sizeof(uint8_t),     // stochastic_rounding
Contributor


QuantizationConfig layout mismatch

QuantizationConfig::attr_sizes[] was extended with mxfp8_2d_quantization (common.h:415-423), but QuantizationConfig itself uses bool fields. If nvte_set/get_quantization_config_attribute assumes all attributes are packed according to attr_sizes, adding an extra byte-sized attribute without updating any corresponding packing/unpacking logic can break attribute indexing for subsequent attributes (or any code that iterates kNVTEQuantizationConfigNumAttributes). Please double-check the code that uses attr_sizes to ensure the new attribute is reflected everywhere it’s consumed (and that kNVTEQuantizationConfigNumAttributes matches the size of attr_sizes).
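
A compile-time guard of the kind requested here could be sketched as follows (assuming attr_sizes is a constexpr array member of QuantizationConfig; the actual declaration may differ):

// Fails the build if the attribute-size table and the public enum drift apart.
static_assert(sizeof(QuantizationConfig::attr_sizes) /
                      sizeof(QuantizationConfig::attr_sizes[0]) ==
                  static_cast<size_t>(kNVTEQuantizationConfigNumAttributes),
              "attr_sizes must have one entry per NVTEQuantizationConfigAttribute");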

Comment on lines +100 to +108
) # (num_block_rows, num_block_cols, 32, 32)

# Compute amax for each 32x32 block
block_amax = torch.amax(
    torch.abs(x_blocks.to(torch.float32)), dim=(-1, -2)
) # (num_block_rows, num_block_cols)

# Convert to E8M0 scale inverse
block_scale_e8m0 = float_to_e8m0(block_amax) # (num_block_rows, num_block_cols)
Contributor


Incorrect float bit-cast

float_to_e8m0 does val_u32 = val.view(torch.int32) (test_mxfp8_2d_quantize.py:104-106). On PyTorch, .view(dtype) is a numeric cast, not a bit reinterpretation. This makes the reference implementation compute wrong exponents/mantissas and can cause false failures/passes.

Use val.view(torch.int32) only if you’ve explicitly reinterpreted bytes (e.g., via val.view(torch.uint8) + view(torch.int32) on the same storage) or use val.to(torch.int32) with torch.frexp/torch.bitwise_* alternatives. As written, the reference is not modeling the GPU’s IEEE754 bit extraction.
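
For concreteness, this is what bit-level reinterpretation of a float means in plain C++, as opposed to a numeric conversion (illustrative only; whether the test's val.view(torch.int32) already performs such a reinterpretation is left to the PR discussion):

#include <cstdint>
#include <cstring>

// Reinterpret the 4 bytes of an IEEE-754 binary32 value as an unsigned integer.
inline uint32_t bit_cast_u32(float v) {
  uint32_t u;
  std::memcpy(&u, &v, sizeof(u));
  return u;
}

// The biased exponent occupies bits 30..23 of the bit pattern.
inline uint8_t ieee754_biased_exponent(float v) {
  return static_cast<uint8_t>((bit_cast_u32(v) >> 23) & 0xFFu);
}

// By contrast, a numeric conversion such as static_cast<int32_t>(v) rounds the
// value toward an integer and does not expose the exponent/mantissa fields.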

Contributor

greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
2D flag overrides requested scaling type

In mxfp8::quantize(...), use_2d_quantization unconditionally forces scaling_type = ScalingType::BIDIMENSIONAL (quantize_mxfp8.cuh:776-779). That changes the requested scaling mode even when callers intended rowwise-only or colwise-only (e.g., weight-only paths could request rowwise-only). This will silently produce different scale tensor shapes/semantics than the caller expects.

If 2D block scaling is only valid when both rowwise+colwise outputs are requested, it should be validated (error) instead of overriding scaling_type; otherwise the override should be limited to the specific call sites that already requested bidimensional scaling.

