Checkpoint save/load FSDP and ShardTensor support #1472

pzharrington wants to merge 7 commits into NVIDIA:main
Conversation
Greptile Summary

This PR introduces FSDP and ShardTensor-aware checkpoint save/load by centralising distributed state-dict logic in …

A minor style issue was also identified: a test file saves a plain …

Last reviewed commit: 52d03eb
@pzharrington Does this need review from shard tensor side too?
```diff
- from physicsnemo.distributed import DistributedManager, scatter_tensor
+ from physicsnemo.distributed import DistributedManager
+ from physicsnemo.domain_parallel import scatter_tensor
```
Should have caught this sooner and I appreciate you got it. Thanks!
I took a look through this PR - overall I think this is much needed (and long overdue) functionality, thank you for finally taking action when no one else would! I have one concern to discuss: over the next release I think it's important to decouple the ShardTensor(DTensor) inheritance structure (I know that's not a surprise, just raising it in this context...). Will that break anything you've implemented here? We will probably need to use distributed tooling, but there could be weird behavior introduced if I do that. What do you think?
Copying response from Slack here for posterity
```diff
  attn_kernel_size: int = 31,
  lead_time_steps: int = 0,
- layernorm_backend: Literal["torch", "apex"] = "apex",
+ layernorm_backend: Literal["torch", "apex"] = "torch",
```
I assume this change is intentional? It's not mentioned in the PR description, but that's fine by me since it can be overridden by the config file.
PhysicsNeMo Pull Request
Description
Summary
- `save_checkpoint` and `load_checkpoint` now automatically detect FSDP-wrapped and DTensor/ShardTensor-distributed models and use PyTorch's Distributed Checkpoint (DCP) state-dict APIs to gather/scatter model and optimizer state. In distributed mode all ranks call the functions collectively, while only rank 0 performs file I/O. This eliminates the manual parameter gathering/scattering that recipe code (e.g. StormCast) previously had to implement.
- New `load_model_weights` utility: a convenience function for loading a single `.mdlus` or `.pt` file directly into a (potentially distributed) model, handling FSDP + DTensor redistribution automatically.
- Removed the recipe-side helpers in `parallel.py` (`gather_training_state`, `scatter_optimizer_state`, `shard_state_dict`, `scatter_object`, `get_state_dict_shard`) and ~50 lines of rank-0 CPU model/optimizer bookkeeping from `trainer.py`. All ranks now participate symmetrically in `_resume_or_init`, calling `load_checkpoint`/`save_checkpoint` directly.
- `physicsnemo.core.Module.save`: added an optional `state_dict` parameter so `save_checkpoint` can pass a pre-gathered full state dictionary for FSDP/DTensor models without calling `self.state_dict()` on the distributed module.
- `StateDictOptions.broadcast_from_rank0` (used in the pure-FSDP load path) was introduced in PyTorch 2.5. This option lets rank 0 broadcast the full state dict to all other ranks without manual scatter, which is the standard non-DTensor distributed load mechanism.

Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.