-
Notifications
You must be signed in to change notification settings - Fork 638
[JAX] TE Permutation integration to Maxtext #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <[email protected]>
Signed-off-by: tdophung <[email protected]>
…ger than num tokens Signed-off-by: tdophung <[email protected]>
Signed-off-by: JAX Toolbox <[email protected]>
for more information, see https://pre-commit.ci
|
This PR contain changes cherry-picked from #2651 . I can wait until this gets merged and then merge mine, but if my PR is needed more urgently, happy to remove the cherry picked change |
Greptile OverviewGreptile SummaryThis PR adds necessary changes to support MaxText integration with TE permutation operations when Expert Parallelism (EP) > 1. Key Changes:
Technical Implementation:
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User as MaxText User
participant API as sort_chunks_by_index
participant FWD as _sort_chunks_by_index_fwd_rule
participant Kernel as _make_chunk_sort_map_kernel
participant BWD as _sort_chunks_by_index_bwd_rule
User->>API: sort_chunks_by_index(inp, split_sizes, sorted_indices)
API->>FWD: Forward pass
FWD->>Kernel: Generate row_id_map with padding handling
Note over Kernel: Compute total_valid_tokens<br/>Apply identity mapping for pid >= total_valid_tokens
Kernel-->>FWD: row_id_map (with padding masked)
FWD->>FWD: sort_chunks_by_map(inp, row_id_map)
FWD-->>API: (output, row_id_map), residuals
Note over FWD: residuals now include split_sizes<br/>and sorted_indices (not nondiff_argnums)
User->>BWD: Backward pass (gradient)
BWD->>BWD: Extract split_sizes, sorted_indices from residuals
BWD->>BWD: sort_chunks_by_map(output_grad, row_id_map, is_forward=False)
BWD-->>User: (inp_grad, zeros_like(split_sizes), zeros_like(sorted_indices))
Last reviewed commit: 11a45d3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
6 files reviewed, 3 comments
transformer_engine/jax/inspect.py
Outdated
| _inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) | ||
|
|
||
|
|
||
| def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name parameter is unused - not passed to C++ backend or used in filename
| def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: | |
| def inspect_array(x: jnp.ndarray) -> jnp.ndarray: |
| std::ofstream file(filename, std::ios::binary); | ||
| if (file.is_open()) { | ||
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | ||
| file.close(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No error handling if file fails to open - silently continues without writing data
| std::ofstream file(filename, std::ios::binary); | |
| if (file.is_open()) { | |
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | |
| file.close(); | |
| } | |
| std::ofstream file(filename, std::ios::binary); | |
| if (!file.is_open()) { | |
| return ffi::Error(ffi::ErrorCode::kInternal, "Failed to open file for writing"); | |
| } | |
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | |
| file.close(); |
|
/te-ci |
|
/te_ci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, no comments
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Signed-off-by: JAX Toolbox <[email protected]>
…rmerEngine into maxtext_integ_2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, no comments
|
/te_ci |
|
/te-ci |
Description
Changes needed on TE side to make maxtext integration works
Issue # 2585
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: