Change the pyTorch installation to CUDA 13 in Build All GitHub action #2308
Conversation
Pull Request Overview
This PR updates the PyTorch installation in the Build All GitHub action workflow to use CUDA 13.0, aligning it with the version used in the JAX container for consistency across the build environment.
Key changes:
- Modified PyTorch installation to explicitly use CUDA 13.0 from PyTorch's wheel repository
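As a sketch of the resulting step (reconstructed from the diff further down; the surrounding step name and exact indentation in .github/workflows/build.yml are assumptions):

```yaml
# Assumed shape of the updated "all" job step in .github/workflows/build.yml.
- name: 'Install dependencies'
  run: |
    pip install pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
    # PyTorch comes from the cu130 wheel index to match the JAX container's CUDA 13.0
    pip install torch --index-url https://download.pytorch.org/whl/cu130
```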
Greptile Overview
Greptile Summary
This PR updates the PyTorch installation in the "All" build job of the GitHub Actions workflow to explicitly use CUDA 13.0 wheels. The change modifies .github/workflows/build.yml by splitting the pip install command into two parts: one for general dependencies (pybind11[global], einops, onnxscript, nvidia-mathdx) and another specifically for PyTorch, sourcing it from PyTorch's cu130 wheel index. The motivation is to align the PyTorch CUDA version with what is supposedly used in the JAX container, ensuring consistency across the "All" job, which builds TransformerEngine with support for both frameworks. This change fits into the broader CI infrastructure by ensuring that the combined PyTorch+JAX build uses compatible CUDA-enabled wheels.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 2/5 | Modified PyTorch installation to use CUDA 13.0 wheels via explicit index URL, potentially introducing CUDA version mismatch issues |
Confidence score: 2/5
- This PR introduces potential CUDA version compatibility risks that could cause runtime failures in the "All" build job
- Score lowered due to: (1) possible mismatch between container system CUDA (likely 12.x from JAX container) and PyTorch cu130 requirements (CUDA 13.0), (2) lack of verification that the JAX container actually uses CUDA 13.0 as claimed, (3) no updates to documentation or comments explaining the version choice, and (4) potential for breaking existing builds if CUDA 13.0 is not available in the execution environment
- Pay close attention to .github/workflows/build.yml: verify that the base container actually provides CUDA 13.0 libraries and test that PyTorch cu130 wheels work correctly with the container's CUDA installation
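One way to act on that advice, as a rough sketch (the commands are assumptions about the container layout, not something in this PR):

```bash
# Minimal verification sketch for the base JAX container, assuming a
# conventional /usr/local/cuda layout; adjust paths if the image differs.
ls -d /usr/local/cuda*                                # system CUDA installs present in the image
nvcc --version | grep release                         # toolkit version, if nvcc is on PATH
python3 -c "import torch; print(torch.version.cuda)"  # CUDA version the installed torch wheel targets
```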
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant GitHub
participant Core Job
participant PyTorch Job
participant JAX Job
participant All Job
User->>GitHub: "Trigger workflow (PR or manual)"
par Core Build
GitHub->>Core Job: "Start in CUDA 12.1 container"
Core Job->>Core Job: "Install dependencies (git, python3.9, cudnn9)"
Core Job->>Core Job: "Checkout code with submodules"
Core Job->>Core Job: "Build with NVTE_FRAMEWORK=none"
Core Job->>Core Job: "Sanity check: import transformer_engine"
Core Job->>GitHub: "Report result"
and PyTorch Build
GitHub->>PyTorch Job: "Start in CUDA 12.8 container"
PyTorch Job->>PyTorch Job: "Install dependencies (torch, numpy, etc.)"
PyTorch Job->>PyTorch Job: "Checkout code with submodules"
PyTorch Job->>PyTorch Job: "Build with NVTE_FRAMEWORK=pytorch"
PyTorch Job->>PyTorch Job: "Sanity check: test_sanity_import.py"
PyTorch Job->>GitHub: "Report result"
and JAX Build
GitHub->>JAX Job: "Start in JAX container"
JAX Job->>JAX Job: "Install dependencies (pybind11, mathdx)"
JAX Job->>JAX Job: "Checkout code with submodules"
JAX Job->>JAX Job: "Build with NVTE_FRAMEWORK=jax"
JAX Job->>JAX Job: "Sanity check: test_sanity_import.py"
JAX Job->>GitHub: "Report result"
and All Frameworks Build
GitHub->>All Job: "Start in JAX container"
All Job->>All Job: "Install dependencies (pybind11, einops, onnxscript)"
All Job->>All Job: "Install torch with CUDA 13.0"
All Job->>All Job: "Checkout code with submodules"
All Job->>All Job: "Build with NVTE_FRAMEWORK=all"
All Job->>All Job: "Sanity check: both PyTorch and JAX tests"
All Job->>GitHub: "Report result"
end
GitHub->>User: "Workflow complete"
```
1 file reviewed, 1 comment
```diff
-      - run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
+      - run: |
+          pip install pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
+          pip install torch --index-url https://download.pytorch.org/whl/cu130
```
logic: Installing PyTorch from the cu130 index on the JAX image (CUDA 12.6) may create a CUDA version mismatch. The JAX container likely has CUDA 12.6 system libraries, while PyTorch cu130 expects CUDA 13.0. Does the ghcr.io/nvidia/jax:jax container have CUDA 13.0 installed, or is it using CUDA 12.6? PyTorch cu130 requires the CUDA 13.0 runtime.
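For what it's worth, recent PyTorch wheels on Linux typically vendor their CUDA runtime as nvidia-* pip packages rather than linking against the system toolkit, so a quick, hypothetical way to probe this (commands assumed, not part of the PR) would be:

```bash
# Hypothetical spot-check inside the ghcr.io/nvidia/jax:jax container:
# compare the system toolkit with what the cu130 wheel reports and bundles.
nvcc --version | grep release          # system CUDA toolkit, if nvcc is on PATH
pip install torch --index-url https://download.pytorch.org/whl/cu130
pip list | grep -i '^nvidia-'          # CUDA runtime components pulled in by the wheel, if any
python3 -c "import torch; print('torch built for CUDA', torch.version.cuda)"
```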
Commit message: match the version in the JAX container
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Force-pushed from ccf5a8e to 3ea1472
Greptile Overview
Greptile Summary
This PR modifies the "Build All" GitHub Actions workflow to install PyTorch with CUDA 13.0 binaries by adding --index-url https://download.pytorch.org/whl/cu130 to the pip install command. The change splits dependency installation into two separate pip commands: one for general dependencies (pybind11, einops, onnxscript) and another specifically for PyTorch with the cu130 index. According to the PR description, this aligns the PyTorch CUDA version with what's expected in the JAX container (ghcr.io/nvidia/jax:jax). This workflow job builds both the JAX and PyTorch backends together, so the change ensures both frameworks reference CUDA 13.0 toolchains during compilation and runtime initialization.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 3/5 | Modified the 'all' job to install PyTorch from cu130 index by splitting dependency installation into two pip commands |
Confidence score: 3/5
- This PR carries moderate risk due to the unresolved CUDA version compatibility concern between the JAX container and PyTorch cu130
- Score reflects uncertainty about whether the ghcr.io/nvidia/jax:jax container actually provides CUDA 13.0 runtime libraries, as the previous review raised a valid concern that the container may have CUDA 12.6 while PyTorch cu130 expects CUDA 13.0, which could cause runtime failures
- Close attention should be paid to .github/workflows/build.yml to verify the JAX container's actual CUDA version matches the PyTorch cu130 requirement before merging
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant GitHub
participant Container
participant AptGet
participant Pip
participant Git
participant Build
participant Tests
User->>GitHub: "Trigger Build Workflow (PR or manual)"
rect rgb(200, 220, 240)
note right of GitHub: Core Job
GitHub->>Container: "Start CUDA 12.1.0 container"
Container->>AptGet: "Install git, python3.9, pip, cudnn9"
AptGet-->>Container: "Dependencies installed"
Container->>Pip: "Install cmake, pybind11, ninja"
Pip-->>Container: "Build tools ready"
Container->>Git: "Checkout repository with submodules"
Git-->>Container: "Code ready"
Container->>Build: "pip install (NVTE_FRAMEWORK=none)"
Build-->>Container: "Core build complete"
Container->>Tests: "python3 -c 'import transformer_engine'"
Tests-->>GitHub: "Core sanity check passed"
end
rect rgb(220, 240, 200)
note right of GitHub: PyTorch Job
GitHub->>Container: "Start CUDA 12.8.0 container"
Container->>AptGet: "Install git, python3.9, pip, cudnn9"
AptGet-->>Container: "Dependencies installed"
Container->>Pip: "Install cmake, torch, ninja, pydantic, etc."
Pip-->>Container: "PyTorch dependencies ready"
Container->>Git: "Checkout repository with submodules"
Git-->>Container: "Code ready"
Container->>Build: "pip install (NVTE_FRAMEWORK=pytorch)"
Build-->>Container: "PyTorch build complete"
Container->>Tests: "python3 tests/pytorch/test_sanity_import.py"
Tests-->>GitHub: "PyTorch sanity check passed"
end
rect rgb(240, 220, 200)
note right of GitHub: JAX Job
GitHub->>Container: "Start JAX container"
Container->>Pip: "Install pybind11"
Pip-->>Container: "Dependencies ready"
Container->>Git: "Checkout repository with submodules"
Git-->>Container: "Code ready"
Container->>Build: "pip install (NVTE_FRAMEWORK=jax)"
Build-->>Container: "JAX build complete"
Container->>Tests: "python3 tests/jax/test_sanity_import.py"
Tests-->>GitHub: "JAX sanity check passed"
end
rect rgb(240, 200, 220)
note right of GitHub: All Job
GitHub->>Container: "Start JAX container"
Container->>Pip: "Install pybind11, einops, onnxscript"
Pip-->>Container: "Base dependencies ready"
Container->>Pip: "Install torch with CUDA 13.0"
Pip-->>Container: "PyTorch ready"
Container->>Git: "Checkout repository with submodules"
Git-->>Container: "Code ready"
Container->>Build: "pip install (NVTE_FRAMEWORK=all)"
Build-->>Container: "Full build complete"
Container->>Tests: "python3 tests/pytorch/test_sanity_import.py"
Tests-->>Container: "PyTorch test passed"
Container->>Tests: "python3 tests/jax/test_sanity_import.py"
Tests-->>GitHub: "JAX test passed"
end
GitHub-->>User: "All build jobs completed"
```
1 file reviewed, no comments
This PR does not actually do what is claimed (and that will only be possible once we can compile TE with nvcc from pip, so that we could start from a bare container and install JAX, PyTorch, and then TE with just pip dependencies, cc @ksivaman), since the JAX container has CUDA installed system-wide and PyTorch does not want to use that. But somehow it does resolve the original problem of the Build All GitHub action failing with a disk-space issue, so I still think it is worth merging.
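For context, the pip-only flow alluded to above might eventually look something like this hypothetical sketch (package names and extras are assumptions; only cu12 variants of the nvidia compiler packages are published on PyPI at the time of writing):

```bash
# Hypothetical future flow: start from a bare container, everything from pip.
pip install nvidia-cuda-nvcc-cu13 nvidia-cuda-runtime-cu13   # assumed package names
pip install "jax[cuda13]"                                    # assumed extra; jax currently ships cuda12
pip install torch --index-url https://download.pytorch.org/whl/cu130
pip install transformer_engine                               # would build against the pip-provided nvcc
```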
jberchtold-nvidia
left a comment
LGTM. Ack that it does not fix the core issue with multiple CUDA installations, but at least it unblocks CI, so I agree it is worth merging. Thanks!
Description
With this change, the PyTorch CUDA version will match the version in the JAX container.
Type of change