Skip to content

Checkpoint save/load FSDP and ShardTensor support#1472

Draft
pzharrington wants to merge 7 commits intoNVIDIA:mainfrom
pzharrington:fsdp-shardtensor-ckpt
Draft

Checkpoint save/load FSDP and ShardTensor support#1472
pzharrington wants to merge 7 commits intoNVIDIA:mainfrom
pzharrington:fsdp-shardtensor-ckpt

Conversation

@pzharrington
Copy link
Collaborator

@pzharrington pzharrington commented Mar 5, 2026

PhysicsNeMo Pull Request

Description

Summary

  • FSDP/ShardTensor-aware checkpoint save and load: save_checkpoint and load_checkpoint now automatically detect FSDP-wrapped and DTensor/ShardTensor-distributed models and use PyTorch's Distributed Checkpoint (DCP) state-dict APIs to gather/scatter model and optimizer state. In distributed mode all ranks call the functions collectively, while only rank 0 performs file I/O. This eliminates the need for manual parameter gathering/scattering that recipe code (e.g. StormCast) previously had to implement.
  • New load_model_weights utility: A convenience function for loading a single .mdlus or .pt file directly into a (potentially distributed) model, handling FSDP + DTensor redistribution automatically.
  • StormCast recipe simplification: Removed ~200 lines of manual checkpoint gather/scatter logic from parallel.py (gather_training_state, scatter_optimizer_state, shard_state_dict, scatter_object, get_state_dict_shard) and ~50 lines of rank-0 CPU model/optimizer bookkeeping from trainer.py. All ranks now participate symmetrically in _resume_or_init, calling load_checkpoint / save_checkpoint directly.
  • physicsnemo.core.Module.save: Added an optional state_dict parameter so save_checkpoint can pass a pre-gathered full state dictionary for FSDP/DTensor models without calling self.state_dict() on the distributed module.
  • Minimum torch version bump 2.4 → 2.5: Required because StateDictOptions.broadcast_from_rank0 (used in the pure-FSDP load path) was introduced in PyTorch 2.5. This option enables rank 0 to broadcast the full state dict to all other ranks without manual scatter, which is the standard non-DTensor distributed load mechanism.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington
Copy link
Collaborator Author

@greptileai

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR introduces FSDP and ShardTensor-aware checkpoint save/load by centralising distributed state-dict logic in physicsnemo/utils/checkpoint.py, eliminating ~250 lines of manual gather/scatter code from the StormCast recipe, and adding a new load_model_weights convenience function. The approach is architecturally sound: save_checkpoint / load_checkpoint auto-detect distributed models via _is_distributed_model, enter a collective code path using PyTorch's DCP get_model_state_dict / set_model_state_dict APIs, and restrict file I/O to rank 0. One logic issue was found:

  • _redistribute_sd_for_dtensor docstring mismatch: The docstring states "all other entries are left unchanged," but when a fallback mesh is available, the implementation distributes all remaining plain tensors to the fallback mesh as Replicate DTensors, not just those expected to be DTensors. For models with both ShardTensor and plain parameters, this could cause downstream issues. The docstring should be corrected to accurately document this behavior.

A minor style issue was also identified: a test file saves a plain torch.save() weights file with a .mdlus extension, which is misleading about supported file formats even though the test passes due to type-based dispatch logic.

Last reviewed commit: 52d03eb

@coreyjadams
Copy link
Collaborator

@pzharrington Does this need review from shard tensor side too?


from physicsnemo.distributed import DistributedManager, scatter_tensor
from physicsnemo.distributed import DistributedManager
from physicsnemo.domain_parallel import scatter_tensor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have caught this sooner and I appreciate you got it. Thanks!

@coreyjadams
Copy link
Collaborator

I took a look through this PR - over all I think this is much needed (and long overdue) functionality, thank you for finally taking action when no one else would!

Overall I have one concern to discuss: over the next release I think it's important to decouple the ShardTensor(DTensor) inheritance structure (I know that's not a surprise just raising it in this context... ). Will that break anything you've implemented here? We probably will need to use distributed tooling but there could be weird behavior introduced if I do that.

What do you think?

@pzharrington
Copy link
Collaborator Author

Copying response from Slack here for posterity

yes pivoting away from ShardTensor(DTensor) inheritance will break some stuff, but the level of refactoring needed I guess would depend on how far ShardTensor ends up drifting from DTensor. I wouldn't let the checkpoint functionality determine much in the "to DTensor or not to DTensor" decision as I assume whatever we end up with, there will be a non-horrible pathway to keeping the same user-facing functionality for checkpoints by shuffling things around under the hood. Bar none, eventually it will need some refactoring to support fsdp2 instead of FSDP (which is deprecating NO_SHARD)

attn_kernel_size: int = 31,
lead_time_steps: int = 0,
layernorm_backend: Literal["torch", "apex"] = "apex",
layernorm_backend: Literal["torch", "apex"] = "torch",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this change is intentional? It's not mentioned in the PR description. But fine for me since it can be overridden by config file.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants