Design LLM Checkpointing System
Problem Statement
Create a system that periodically captures the complete state of machine learning training jobs running across hundreds or thousands of GPUs. These snapshots must preserve everything needed to restart training from any saved point -- model parameters, optimizer momentum terms, random number generator states, and training data positions. The system should handle models with hundreds of billions of parameters (multiple terabytes of state) being trained on GPU clusters for weeks or months, where hardware failures are common and training interruptions costly.
Your design must support concurrent training jobs, each with different snapshotting policies (frequency, retention, conditional triggers based on validation metrics). The system should minimize interference with ongoing training -- checkpointing overhead should not significantly reduce GPU utilization or training throughput. Recovery from snapshots must guarantee deterministic resumption: restarting from step N should produce identical results as if training had never stopped.
Key Requirements
Functional
- Distributed state capture -- collect and persist sharded model state from hundreds of GPU workers without coordination bottlenecks
- Point-in-time consistency -- ensure all saved state corresponds to exactly the same training iteration across all workers
- Deterministic restoration -- rebuild training state such that resumed runs produce bit-identical results
- Policy-driven scheduling -- support time-based, step-based, and metric-based triggers for when to snapshot
- Version management -- track, tag, and selectively retain or expire snapshots according to configured rules
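The trigger side of policy-driven scheduling can be expressed as a small predicate the training loop evaluates each iteration. A minimal Python sketch, assuming a combined step/time/metric policy (the class and field names are illustrative, not a real API):

```python
import time
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SnapshotPolicy:
    """Hypothetical policy combining step-, time-, and metric-based triggers."""
    every_n_steps: Optional[int] = None      # step-based trigger
    every_n_seconds: Optional[float] = None  # time-based trigger
    metric_improves: bool = False            # trigger on best-so-far validation loss
    _last_time: float = field(default_factory=time.monotonic)
    _best_metric: float = float("inf")

    def should_snapshot(self, step: int, val_loss: Optional[float] = None) -> bool:
        now = time.monotonic()
        if self.every_n_steps and step % self.every_n_steps == 0:
            self._last_time = now
            return True
        if self.every_n_seconds and now - self._last_time >= self.every_n_seconds:
            self._last_time = now
            return True
        if self.metric_improves and val_loss is not None and val_loss < self._best_metric:
            self._best_metric = val_loss
            return True
        return False
```

Keeping the predicate on the worker side (evaluated against a step counter all ranks share) avoids a round-trip to a central scheduler on every iteration.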
Non-Functional
- Scalability -- handle 1000+ GPU workers per job, 5+ TB total state per snapshot, dozens of concurrent training jobs
- Reliability -- tolerate worker failures during snapshot operations, detect and reject corrupted snapshots
- Latency -- complete snapshots in under 5 minutes; impose less than 5% overhead on training throughput
- Consistency -- provide atomic snapshot publication (all-or-nothing visibility) and prevent partial or mixed-step snapshots
What Interviewers Focus On
Based on real interview experiences, these are the areas interviewers probe most deeply:
1. Handling Distributed State at Scale
Interviewers want to see if you understand that a single centralized snapshot file is completely impractical for multi-terabyte model states. They're looking for a sharded approach where each worker independently persists its portion of the model.
Hints to consider:
- Each GPU rank writes its tensor shards to object storage in parallel, avoiding a central aggregation point
- Use a two-phase commit pattern: workers upload data blobs first, then a coordinator atomically publishes a manifest referencing all shards
- Implement checksums or content hashes for each shard to detect corruption during upload or download
- Consider local NVMe staging to decouple disk write speed from network upload speed
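The two-phase commit pattern above can be sketched in a few lines. This is a hedged illustration, not a definitive implementation: an in-memory dict stands in for object storage, and the key layout and function names are assumptions.

```python
import hashlib
import json

def upload_shard(store: dict, job: str, step: int, rank: int, data: bytes) -> dict:
    """Phase 1: each rank uploads its shard under a content-addressed key
    (the dict `store` is a stand-in for an object store like S3/GCS)."""
    digest = hashlib.sha256(data).hexdigest()
    key = f"{job}/step-{step}/shard-{rank}-{digest[:12]}"
    store[key] = data
    return {"rank": rank, "key": key, "sha256": digest, "bytes": len(data)}

def publish_manifest(store: dict, job: str, step: int, shard_entries: list) -> str:
    """Phase 2: the coordinator re-verifies every shard checksum, then writes
    the manifest in a single put. Readers only trust snapshots whose manifest
    exists, which gives all-or-nothing visibility."""
    for entry in shard_entries:
        blob = store.get(entry["key"])
        if blob is None or hashlib.sha256(blob).hexdigest() != entry["sha256"]:
            raise RuntimeError(f"shard for rank {entry['rank']} missing or corrupt")
    manifest_key = f"{job}/step-{step}/MANIFEST.json"
    store[manifest_key] = json.dumps({"step": step, "shards": shard_entries}).encode()
    return manifest_key
```

Because shard uploads are independent, a failed snapshot leaves only orphaned blobs (cleaned up later by garbage collection), never a half-visible checkpoint.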
2. Achieving Cross-Worker Consistency
The trickiest part is ensuring all workers snapshot state from the exact same training iteration. Interviewers probe whether you recognize this requires explicit coordination, not just hoping workers happen to align.
Hints to consider:
- Introduce a barrier synchronization primitive (distributed lock, coordination service) at the end of target iterations
- All workers must reach the barrier before any worker begins writing state to ensure consistent step numbers
- Handle stragglers and failures gracefully: set timeouts, allow snapshot attempts to fail without crashing training
- The manifest should record not just shard locations but also the exact step, learning rate, and other training metadata
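A minimal sketch of the barrier discipline, using Python threads as stand-ins for GPU ranks (in practice this would be a coordination service or a collective over the training communicator; the function name and timeout value are assumptions):

```python
import threading

def snapshot_at_step(barrier: threading.Barrier, rank: int, step: int,
                     write_shard, timeout_s: float = 5.0) -> bool:
    """Called by each worker at the end of the target iteration. No rank
    writes until every rank has arrived, so all shards carry the same step.
    A straggler timeout aborts this snapshot attempt, not training."""
    try:
        barrier.wait(timeout=timeout_s)   # all ranks aligned on `step`
    except threading.BrokenBarrierError:
        return False                      # skip this snapshot, keep training
    write_shard(rank, step)
    return True
```

The key property: a broken barrier fails the snapshot attempt in isolation, so a slow or dead worker costs one checkpoint, not the training run.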
3. Minimizing Impact on Training Performance
Interviewers expect you to acknowledge that checkpointing competes with training for GPU memory bandwidth, network bandwidth, and storage I/O. Naive implementations can stall training for minutes.
Hints to consider:
- Decouple the critical path: copy tensors from GPU to CPU memory quickly (pinned allocations), then upload asynchronously
- Use compression (like zstd or lz4) to reduce bytes written, trading CPU for I/O -- but watch for compression becoming the bottleneck
- Stagger snapshot operations across workers to avoid network storms and storage hotspots
- Monitor and enforce I/O rate limits to prevent checkpointing from saturating storage paths used by training data loaders
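Decoupling the critical path can be sketched with a staging copy plus a background uploader. This is an assumption-laden simplification: a plain list copy stands in for the pinned-memory device-to-host transfer, and a thread plus queue stands in for the async upload pipeline.

```python
import queue
import threading

def start_uploader(upload_fn):
    """Background uploader: drains a queue so the training loop only pays
    for the fast local copy, never for the slow network upload."""
    q = queue.Queue()
    def worker():
        while True:
            item = q.get()
            if item is None:      # sentinel: shut down the uploader
                break
            upload_fn(*item)
            q.task_done()
    t = threading.Thread(target=worker, daemon=True)
    t.start()
    return q, t

def async_checkpoint(q: queue.Queue, step: int, params: list) -> None:
    """Critical path: stage the state with a cheap copy (standing in for a
    pinned-memory GPU-to-CPU copy), then hand off and return immediately."""
    staged = list(params)         # training may mutate params right after this
    q.put((step, staged))
```

The staging copy is what guarantees consistency: once `async_checkpoint` returns, later optimizer updates cannot leak into the in-flight snapshot.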
4. Snapshot Lifecycle and Storage Management

Interviewers look for awareness that long-running training jobs generate many snapshots, and storage costs spiral without lifecycle policies. You need versioning, selective retention, and efficient garbage collection.
Hints to consider:
- Maintain a metadata catalog (relational database) tracking snapshot step, timestamp, validation metrics, and retention policy
- Support rules like "keep last 5, keep best 3 by validation loss, expire after 30 days" with automated cleanup jobs
- Use incremental or differential snapshots where only changed parameters are saved (though in dense training nearly every weight changes each step, so deltas may save little while adding complexity)
- Provide tagging and pinning mechanisms so users can protect important snapshots from automatic expiration
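The retention rules above reduce to a garbage-collection predicate over the metadata catalog. A minimal sketch, assuming snapshot records shaped as dicts with `id`, `step`, `created`, and `val_loss` (the function name and parameters are illustrative):

```python
from datetime import datetime, timedelta

def select_expired(snapshots, now, keep_last=5, keep_best=3, max_age_days=30,
                   pinned=()):
    """Return ids safe to garbage-collect under a hypothetical
    'keep last N, keep best M by val_loss, expire after D days' rule.
    Pinned snapshots are always retained."""
    keep = set(pinned)
    by_step = sorted(snapshots, key=lambda s: s["step"], reverse=True)
    keep.update(s["id"] for s in by_step[:keep_last])          # keep last N
    by_loss = sorted(snapshots, key=lambda s: s["val_loss"])
    keep.update(s["id"] for s in by_loss[:keep_best])          # keep best M
    cutoff = now - timedelta(days=max_age_days)
    return [s["id"] for s in snapshots
            if s["id"] not in keep and s["created"] < cutoff]  # expire the rest
```

Running this as a periodic cleanup job against the catalog, then deleting the referenced shard blobs, keeps storage cost bounded without ever touching a pinned or best-performing checkpoint.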