
Conversation

@ptrendx
Member

@ptrendx ptrendx commented Oct 27, 2025

Description

With this change, the PyTorch installation in the build-all workflow will match the CUDA version in the JAX container.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

@ptrendx ptrendx requested a review from Copilot October 27, 2025 16:35
Contributor

Copilot AI left a comment


Pull Request Overview

This PR updates the PyTorch installation in the Build All GitHub action workflow to use CUDA 13.0, aligning it with the version used in the JAX container for consistency across the build environment.

Key changes:

  • Modified PyTorch installation to explicitly use CUDA 13.0 from PyTorch's wheel repository


Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR updates the PyTorch installation in the "All" build job of the GitHub Actions workflow to explicitly use CUDA 13.0 wheels. The change modifies .github/workflows/build.yml by splitting the pip install command into two parts: one for general dependencies (pybind11[global], einops, onnxscript, nvidia-mathdx) and another specifically for PyTorch, sourcing it from PyTorch's cu130 wheel index. The motivation is to align the PyTorch CUDA version with what is supposedly used in the JAX container, ensuring consistency across the "All" job, which builds TransformerEngine with support for both frameworks. This change fits into the broader CI infrastructure by ensuring that the combined PyTorch+JAX build uses compatible CUDA-enabled wheels.
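
The split described in the summary amounts to these two commands inside the job's run: block (the pip lines come from the workflow diff quoted later in this thread):

    # General build/test dependencies for the combined PyTorch+JAX job
    pip install pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
    # PyTorch built against CUDA 13.0, from PyTorch's own wheel index rather than the default wheels
    pip install torch --index-url https://download.pytorch.org/whl/cu130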

Important Files Changed

Filename | Score | Overview
.github/workflows/build.yml | 2/5 | Modified PyTorch installation to use CUDA 13.0 wheels via an explicit index URL, potentially introducing CUDA version mismatch issues

Confidence score: 2/5

  • This PR introduces potential CUDA version compatibility risks that could cause runtime failures in the "All" build job
  • Score lowered due to: (1) possible mismatch between container system CUDA (likely 12.x from JAX container) and PyTorch cu130 requirements (CUDA 13.0), (2) lack of verification that the JAX container actually uses CUDA 13.0 as claimed, (3) no updates to documentation or comments explaining the version choice, and (4) potential for breaking existing builds if CUDA 13.0 is not available in the execution environment
  • Pay close attention to .github/workflows/build.yml - verify that the base container actually provides CUDA 13.0 libraries and test that PyTorch cu130 wheels work correctly with the container's CUDA installation
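
A quick way to run that verification inside the container (a sketch; exact paths depend on the image, but these commands are standard):

    # What CUDA toolkit, if any, is installed system-wide in the container?
    ls -d /usr/local/cuda*    # lists installed toolkit directories
    nvcc --version            # reports the system toolkit version, if nvcc is on PATH

    # Which CUDA build does the installed torch wheel report, and does it initialize?
    # (is_available() also requires a visible GPU, which CI build containers may lack)
    python3 -c "import torch; print(torch.version.cuda, torch.cuda.is_available())"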

Sequence Diagram

sequenceDiagram
    participant User
    participant GitHub
    participant Core Job
    participant PyTorch Job
    participant JAX Job
    participant All Job

    User->>GitHub: "Trigger workflow (PR or manual)"
    
    par Core Build
        GitHub->>Core Job: "Start in CUDA 12.1 container"
        Core Job->>Core Job: "Install dependencies (git, python3.9, cudnn9)"
        Core Job->>Core Job: "Checkout code with submodules"
        Core Job->>Core Job: "Build with NVTE_FRAMEWORK=none"
        Core Job->>Core Job: "Sanity check: import transformer_engine"
        Core Job->>GitHub: "Report result"
    and PyTorch Build
        GitHub->>PyTorch Job: "Start in CUDA 12.8 container"
        PyTorch Job->>PyTorch Job: "Install dependencies (torch, numpy, etc.)"
        PyTorch Job->>PyTorch Job: "Checkout code with submodules"
        PyTorch Job->>PyTorch Job: "Build with NVTE_FRAMEWORK=pytorch"
        PyTorch Job->>PyTorch Job: "Sanity check: test_sanity_import.py"
        PyTorch Job->>GitHub: "Report result"
    and JAX Build
        GitHub->>JAX Job: "Start in JAX container"
        JAX Job->>JAX Job: "Install dependencies (pybind11, mathdx)"
        JAX Job->>JAX Job: "Checkout code with submodules"
        JAX Job->>JAX Job: "Build with NVTE_FRAMEWORK=jax"
        JAX Job->>JAX Job: "Sanity check: test_sanity_import.py"
        JAX Job->>GitHub: "Report result"
    and All Frameworks Build
        GitHub->>All Job: "Start in JAX container"
        All Job->>All Job: "Install dependencies (pybind11, einops, onnxscript)"
        All Job->>All Job: "Install torch with CUDA 13.0"
        All Job->>All Job: "Checkout code with submodules"
        All Job->>All Job: "Build with NVTE_FRAMEWORK=all"
        All Job->>All Job: "Sanity check: both PyTorch and JAX tests"
        All Job->>GitHub: "Report result"
    end
    
    GitHub->>User: "Workflow complete"

1 file reviewed, 1 comment


-        run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
+        run: |
+          pip install pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
+          pip install torch --index-url https://download.pytorch.org/whl/cu130
Contributor


logic: Installing PyTorch from the cu130 index on the JAX image (CUDA 12.6) may create a CUDA version mismatch. The JAX container likely has CUDA 12.6 system libraries, while PyTorch cu130 expects CUDA 13.0. Does the ghcr.io/nvidia/jax:jax container have CUDA 13.0 installed, or is it using CUDA 12.6? PyTorch cu130 requires a CUDA 13.0 runtime.
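
One detail worth checking before assuming a hard mismatch (a hedged note: recent PyTorch Linux wheels typically vendor their CUDA runtime as nvidia-* pip packages rather than linking against the system toolkit, though that should be confirmed for this image):

    # See which CUDA runtime wheels were pulled in alongside torch from the cu130 index
    pip list | grep -i nvidia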

match the version in the JAX container

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx force-pushed the pr_build_all_remove branch from ccf5a8e to 3ea1472 on October 27, 2025 at 23:18
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR modifies the "Build All" GitHub Actions workflow to install PyTorch with CUDA 13.0 binaries by adding --index-url https://download.pytorch.org/whl/cu130 to the pip install command. The change splits the dependency installation into two separate pip commands: one for general dependencies (pybind11, einops, onnxscript) and another specifically for PyTorch with the cu130 index. According to the PR description, this aligns the PyTorch CUDA version with what's expected in the JAX container (ghcr.io/nvidia/jax:jax). This workflow job builds both the JAX and PyTorch backends together, so the change ensures both frameworks reference CUDA 13.0 toolchains during compilation and runtime initialization.

Important Files Changed

Filename | Score | Overview
.github/workflows/build.yml | 3/5 | Modified the 'all' job to install PyTorch from the cu130 index by splitting dependency installation into two pip commands

Confidence score: 3/5

  • This PR carries moderate risk due to the unresolved CUDA version compatibility concern between the JAX container and PyTorch cu130
  • Score reflects uncertainty about whether the ghcr.io/nvidia/jax:jax container actually provides CUDA 13.0 runtime libraries, as the previous review raised a valid concern that the container may have CUDA 12.6 while PyTorch cu130 expects CUDA 13.0, which could cause runtime failures
  • Close attention should be paid to .github/workflows/build.yml to verify the JAX container's actual CUDA version matches the PyTorch cu130 requirement before merging

Sequence Diagram

sequenceDiagram
    participant User
    participant GitHub
    participant Container
    participant AptGet
    participant Pip
    participant Git
    participant Build
    participant Tests

    User->>GitHub: "Trigger Build Workflow (PR or manual)"
    
    rect rgb(200, 220, 240)
    note right of GitHub: Core Job
    GitHub->>Container: "Start CUDA 12.1.0 container"
    Container->>AptGet: "Install git, python3.9, pip, cudnn9"
    AptGet-->>Container: "Dependencies installed"
    Container->>Pip: "Install cmake, pybind11, ninja"
    Pip-->>Container: "Build tools ready"
    Container->>Git: "Checkout repository with submodules"
    Git-->>Container: "Code ready"
    Container->>Build: "pip install (NVTE_FRAMEWORK=none)"
    Build-->>Container: "Core build complete"
    Container->>Tests: "python3 -c 'import transformer_engine'"
    Tests-->>GitHub: "Core sanity check passed"
    end

    rect rgb(220, 240, 200)
    note right of GitHub: PyTorch Job
    GitHub->>Container: "Start CUDA 12.8.0 container"
    Container->>AptGet: "Install git, python3.9, pip, cudnn9"
    AptGet-->>Container: "Dependencies installed"
    Container->>Pip: "Install cmake, torch, ninja, pydantic, etc."
    Pip-->>Container: "PyTorch dependencies ready"
    Container->>Git: "Checkout repository with submodules"
    Git-->>Container: "Code ready"
    Container->>Build: "pip install (NVTE_FRAMEWORK=pytorch)"
    Build-->>Container: "PyTorch build complete"
    Container->>Tests: "python3 tests/pytorch/test_sanity_import.py"
    Tests-->>GitHub: "PyTorch sanity check passed"
    end

    rect rgb(240, 220, 200)
    note right of GitHub: JAX Job
    GitHub->>Container: "Start JAX container"
    Container->>Pip: "Install pybind11"
    Pip-->>Container: "Dependencies ready"
    Container->>Git: "Checkout repository with submodules"
    Git-->>Container: "Code ready"
    Container->>Build: "pip install (NVTE_FRAMEWORK=jax)"
    Build-->>Container: "JAX build complete"
    Container->>Tests: "python3 tests/jax/test_sanity_import.py"
    Tests-->>GitHub: "JAX sanity check passed"
    end

    rect rgb(240, 200, 220)
    note right of GitHub: All Job
    GitHub->>Container: "Start JAX container"
    Container->>Pip: "Install pybind11, einops, onnxscript"
    Pip-->>Container: "Base dependencies ready"
    Container->>Pip: "Install torch with CUDA 13.0"
    Pip-->>Container: "PyTorch ready"
    Container->>Git: "Checkout repository with submodules"
    Git-->>Container: "Code ready"
    Container->>Build: "pip install (NVTE_FRAMEWORK=all)"
    Build-->>Container: "Full build complete"
    Container->>Tests: "python3 tests/pytorch/test_sanity_import.py"
    Tests-->>Container: "PyTorch test passed"
    Container->>Tests: "python3 tests/jax/test_sanity_import.py"
    Tests-->>GitHub: "JAX test passed"
    end

    GitHub-->>User: "All build jobs completed"

1 file reviewed, no comments


@ptrendx
Member Author

ptrendx commented Oct 27, 2025

This PR does not actually do what is claimed (and that will only become possible once we can compile TE with nvcc from pip, so that we could start from a bare container and install JAX, PyTorch, and then TE with just pip dependencies, cc @ksivaman), since the JAX container has CUDA installed system-wide and PyTorch does not want to use that. But somehow it does resolve the original problem of the build-all GitHub action failing with a disk space issue, so I still think it is worth merging.
Tagging @jberchtold-nvidia
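
For reference, the fully pip-based flow described above might eventually look something like this (a hypothetical sketch: the cu13 package names and the jax extra are assumptions — today the cu12 equivalents such as nvidia-cuda-nvcc-cu12 exist on PyPI):

    # Hypothetical: start from a bare container and install everything via pip
    pip install nvidia-cuda-nvcc-cu13    # assumed cu13 nvcc wheel; the cu12 variant exists today
    pip install jax[cuda13]              # assumed extra; jax[cuda12] exists today
    pip install torch --index-url https://download.pytorch.org/whl/cu130
    pip install transformer_engine       # TE compiled against the pip-provided nvcc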

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM. Ack that it does not fix the core issue with multiple CUDA installations, but at least it unblocks CI, so I agree it is worth merging. Thanks!

@ptrendx ptrendx merged commit 4cf2f12 into NVIDIA:main Oct 27, 2025
12 checks passed
