Skip to content

checkpoint

Classes

fastvideo.train.utils.checkpoint.CheckpointManager

CheckpointManager(*, method: Any, dataloader: Any, output_dir: str, config: CheckpointConfig, callbacks: Any | None = None, raw_config: dict[str, Any] | None = None)

Role-based checkpoint manager for training runtime.

  • Checkpoint policy lives in YAML (via TrainingArgs fields).
  • Resume path is typically provided via CLI (--resume-from-checkpoint).
Source code in fastvideo/train/utils/checkpoint.py
def __init__(
    self,
    *,
    method: Any,
    dataloader: Any,
    output_dir: str,
    config: CheckpointConfig,
    callbacks: Any | None = None,
    raw_config: dict[str, Any] | None = None,
) -> None:
    """Wire the manager to its training collaborators.

    Args:
        method: Training method object whose state is checkpointed.
        dataloader: Dataloader whose position may be saved/restored.
        output_dir: Directory that checkpoints are written under.
        config: Checkpoint policy (cadence, retention, ...).
        callbacks: Optional hooks invoked around save/load.
        raw_config: Optional raw config dict persisted alongside
            checkpoints.
    """
    self.config = config
    self.method = method
    self.dataloader = dataloader
    # Normalize to ``str`` so path-like inputs are accepted too.
    self.output_dir = str(output_dir)
    self._callbacks = callbacks
    self._raw_config = raw_config
    # Step of the most recent save; ``None`` until the first save.
    self._last_saved_step: int | None = None

Functions

fastvideo.train.utils.checkpoint.CheckpointManager.load_metadata staticmethod
load_metadata(checkpoint_dir: str | Path) -> dict[str, Any]

Read metadata.json from a checkpoint dir.

Source code in fastvideo/train/utils/checkpoint.py
@staticmethod
def load_metadata(checkpoint_dir: str | Path, ) -> dict[str, Any]:
    """Read ``metadata.json`` from a checkpoint dir."""
    meta_path = Path(checkpoint_dir) / "metadata.json"
    if not meta_path.is_file():
        raise FileNotFoundError(f"No metadata.json in {checkpoint_dir}")
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f)  # type: ignore[no-any-return]
fastvideo.train.utils.checkpoint.CheckpointManager.load_rng_snapshot
load_rng_snapshot(checkpoint_path: str) -> None

Restore per-rank RNG state from the snapshot file.

Must be called AFTER dcp.load and after iter(dataloader) so no later operation can clobber the restored state.

Source code in fastvideo/train/utils/checkpoint.py
def load_rng_snapshot(
    self,
    checkpoint_path: str,
) -> None:
    """Restore per-rank RNG state from the snapshot file.

    Must be called AFTER ``dcp.load`` **and** after
    ``iter(dataloader)`` so no later operation can
    clobber the restored state.

    Args:
        checkpoint_path: Checkpoint path (resolved relative to
            ``self.output_dir`` by ``_resolve_resume_checkpoint``).
    """
    resolved = _resolve_resume_checkpoint(
        checkpoint_path,
        output_dir=self.output_dir,
    )
    if resolved is None:
        return
    rank = _rank()
    # Prefer the per-rank snapshot file.
    rng_path = resolved / f"rng_state_rank{rank}.pt"
    if not rng_path.is_file():
        # Fall back to legacy single-file snapshot.
        rng_path = resolved / "rng_state.pt"
    if not rng_path.is_file():
        logger.warning(
            "No rng_state in %s; skipping "
            "RNG snapshot restore.",
            resolved,
        )
        return

    # weights_only=False: the snapshot holds non-tensor state
    # (tuples from ``random``/NumPy) written by our own save path.
    rng = torch.load(
        rng_path,
        map_location="cpu",
        weights_only=False,
    )
    if "torch_rng" in rng:
        torch.set_rng_state(rng["torch_rng"])
    if "python_rng" in rng:
        random.setstate(rng["python_rng"])
    if "numpy_rng" in rng:
        np.random.set_state(rng["numpy_rng"])

    # Guard the CUDA keys like the others: legacy snapshots may not
    # contain them, and an unguarded lookup raised KeyError instead of
    # degrading gracefully.
    if "cuda_rng" in rng:
        torch.cuda.set_rng_state(rng["cuda_rng"])
    else:
        logger.warning(
            "No cuda_rng key in %s; CUDA RNG state not restored.",
            rng_path,
        )
    if "gen_cuda" in rng:
        self.method.cuda_generator.set_state(rng["gen_cuda"])
    else:
        logger.warning(
            "No gen_cuda key in %s; CUDA generator state not restored.",
            rng_path,
        )
    logger.info(
        "Restored RNG snapshot from %s",
        rng_path,
    )

Functions