Skip to content

Conversation

@tdophung
Copy link
Collaborator

@tdophung tdophung commented Feb 11, 2026

Description

Changes needed on TE side to make maxtext integration works

Issue # 2585

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Masking out padding tokens on each local EP (will be used in local permute step)
  • Pass along split_sizes and sorted_indices in residual of sort_chunks_by_index (local_permute) to avoid mismatch in size issue during tracing when EP>1

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung
Copy link
Collaborator Author

This PR contain changes cherry-picked from #2651 . I can wait until this gets merged and then merge mine, but if my PR is needed more urgently, happy to remove the cherry picked change

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR adds necessary changes to support MaxText integration with TE permutation operations when Expert Parallelism (EP) > 1.

Key Changes:

  • Padding token masking in Triton kernel: Modified _make_chunk_sort_map_kernel to compute total_valid_tokens and apply identity mapping for padding tokens (positions beyond the valid token count), preventing corruption of valid output positions
  • VJP tracing fix: Removed nondiff_argnums=(1, 2) from @jax.custom_vjp decorator and instead passed split_sizes and sorted_indices through residuals with zero gradients, resolving size mismatch issues during JAX tracing with EP>1

Technical Implementation:

  • The Triton kernel now checks pid < total_valid_tokens before applying the permutation mapping, using tl.where(pid < total_valid_tokens, dst_row, pid) to preserve padding tokens at their original positions
  • The backward pass now correctly unpacks split_sizes and sorted_indices from residuals and returns zero gradients for these integer arrays, maintaining gradient flow correctness

Confidence Score: 4/5

  • This PR is safe to merge with minor risk - changes are well-targeted to specific integration issues
  • The changes are focused and address specific technical issues (padding token handling and JAX tracing) for MaxText integration. The padding logic is sound and the VJP changes follow standard JAX patterns. Confidence reduced from 5 to 4 due to lack of new tests mentioned in the checklist.
  • No files require special attention - both changes are straightforward bug fixes

Important Files Changed

Filename Overview
transformer_engine/common/triton/permutation.py Added padding token masking logic to _make_chunk_sort_map_kernel to handle buffers larger than valid token count with identity mapping
transformer_engine/jax/permutation.py Removed nondiff_argnums from @jax.custom_vjp and passed split_sizes/sorted_indices through residuals with zero gradients to fix tracing issues

Sequence Diagram

sequenceDiagram
    participant User as MaxText User
    participant API as sort_chunks_by_index
    participant FWD as _sort_chunks_by_index_fwd_rule
    participant Kernel as _make_chunk_sort_map_kernel
    participant BWD as _sort_chunks_by_index_bwd_rule

    User->>API: sort_chunks_by_index(inp, split_sizes, sorted_indices)
    API->>FWD: Forward pass
    FWD->>Kernel: Generate row_id_map with padding handling
    Note over Kernel: Compute total_valid_tokens<br/>Apply identity mapping for pid >= total_valid_tokens
    Kernel-->>FWD: row_id_map (with padding masked)
    FWD->>FWD: sort_chunks_by_map(inp, row_id_map)
    FWD-->>API: (output, row_id_map), residuals
    Note over FWD: residuals now include split_sizes<br/>and sorted_indices (not nondiff_argnums)
    
    User->>BWD: Backward pass (gradient)
    BWD->>BWD: Extract split_sizes, sorted_indices from residuals
    BWD->>BWD: sort_chunks_by_map(output_grad, row_id_map, is_forward=False)
    BWD-->>User: (inp_grad, zeros_like(split_sizes), zeros_like(sorted_indices))
Loading

Last reviewed commit: 11a45d3

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

6 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name parameter is unused - not passed to C++ backend or used in filename

Suggested change
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
def inspect_array(x: jnp.ndarray) -> jnp.ndarray:

Comment on lines 116 to 120
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No error handling if file fails to open - silently continues without writing data

Suggested change
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
}
std::ofstream file(filename, std::ios::binary);
if (!file.is_open()) {
return ffi::Error(ffi::ErrorCode::kInternal, "Failed to open file for writing");
}
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();

@tdophung
Copy link
Collaborator Author

/te-ci

@tdophung
Copy link
Collaborator Author

/te_ci

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@tdophung
Copy link
Collaborator Author

/te_ci

@tdophung
Copy link
Collaborator Author

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants