Skip to content

base

Classes

fastvideo.train.methods.base.TrainingMethod

TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

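The constructor receives `role_models` (a `dict[str, ModelBase]`) and a `cfg` object. It stores the `"student"` role as `self.student` and copies the `training`, `method` and (optional) `validation` config sections. It also builds `self.role_modules`, an `nn.ModuleDict` of each role's transformer, for FSDP wrapping and checkpoint visibility.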

A single shared CUDA RNG generator (`cuda_generator`) is created in `on_train_start`. All `torch.randn` / `torch.randint` calls in methods and models must use this generator instead of relying on the global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Set up the training method from its role models and run config.

    Args:
        cfg: Run configuration. ``cfg.training`` is stored as-is,
            ``cfg.method`` is copied into a plain dict, and
            ``cfg.validation`` is optional (a missing or falsy value
            becomes an empty dict).
        role_models: Mapping from role name (``"student"``,
            ``"teacher"``, ...) to its model. A ``"student"`` entry is
            required.

    Raises:
        KeyError: If ``role_models`` has no ``"student"`` entry. The
            message lists the roles that were provided.
    """
    super().__init__()
    self.tracker: Any | None = None
    # Shallow copy: later edits to the caller's dict do not change the
    # set of roles this method iterates over (e.g. when checkpointing).
    self._role_models: dict[str, ModelBase] = dict(role_models)

    if "student" not in self._role_models:
        # Fail early with the available roles instead of a bare
        # KeyError('student'); keeps the KeyError type callers may catch.
        raise KeyError(
            "role_models must contain a 'student' entry; "
            f"got roles {sorted(self._role_models)}"
        )
    self.student = self._role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    self.validation_config: dict[str, Any] = dict(getattr(cfg, "validation", {}) or {})

    # Register each role's transformer as a submodule of this nn.Module so
    # it is visible to FSDP wrapping and checkpointing. Roles whose
    # ``transformer`` is absent or not an ``nn.Module`` get no entry.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        mods: dict[str, torch.nn.Module] = {}
        transformer = getattr(model, "transformer", None)
        if isinstance(transformer, torch.nn.Module):
            mods["transformer"] = transformer
        if mods:
            self.role_modules[role] = torch.nn.ModuleDict(mods)

Functions

fastvideo.train.methods.base.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Return DCP-ready checkpoint state for all trainable roles.

    Keys follow the convention:
    ``roles.<role>.<module>``, ``optimizers.<role>``,
    ``schedulers.<role>``, ``random_state.*``. This method emits
    only the ``roles.*``, ``optimizers.*`` and ``schedulers.*``
    entries; no ``random_state.*`` keys are produced here.

    Roles whose model is not marked ``_trainable`` are skipped. A
    role's ``transformer`` is included only when it is an
    ``nn.Module``, which is the same rule the constructor uses to
    build ``role_modules``. A model without a ``transformer``
    attribute therefore contributes no ``roles.*`` entry instead of
    raising ``AttributeError``.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.

    Returns:
        Mapping from checkpoint key to a stateful wrapper
        (``ModelWrapper``, ``OptimizerWrapper`` or ``SchedulerWrapper``).
    """
    states: dict[str, Any] = {}

    for role, model in self._role_models.items():
        if not getattr(model, "_trainable", False):
            continue

        modules: dict[str, torch.nn.Module] = {}
        transformer = getattr(model, "transformer", None)
        if isinstance(transformer, torch.nn.Module):
            modules["transformer"] = transformer

        # The optimizer wrapper receives the role's modules alongside the
        # optimizer (presumably to resolve parameter names; confirm).
        container = _RoleModuleContainer(modules)

        for module_name, module in modules.items():
            states[f"roles.{role}.{module_name}"] = ModelWrapper(module)

        opt = self._optimizer_dict.get(role)
        if opt is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, opt)

        sched = self._lr_scheduler_dict.get(role)
        if sched is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(sched)

    return states
fastvideo.train.methods.base.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Select the modules whose gradients get clipped this iteration.

    Subclasses override this to add targets (e.g. a critic) or to
    include the student only on some iterations. The default policy
    ignores ``iteration`` and always returns the student's
    transformer under the ``"student"`` key.
    """
    student_transformer = self.student.transformer
    targets = dict(student=student_transformer)
    return targets
fastvideo.train.methods.base.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (exp_avg, exp_avg_sq,
    step are only created on the first optimizer.step()).
    DCP needs matching entries to load into; without them
    the saved optimizer state is silently dropped.

    The seeded entries mirror the lazy initialization of
    ``torch.optim.Adam`` / ``AdamW``:

    * ``step``: a scalar tensor. It is placed on the parameter's
      device when the param group is ``capturable`` or ``fused``,
      because Adam keeps it there in those modes. Otherwise it is
      created on the default (CPU) device.
    * ``exp_avg`` / ``exp_avg_sq``: zeros shaped like the parameter.
    * ``max_exp_avg_sq``: added only for ``amsgrad`` groups. Without
      it, DCP would drop the saved value, and the next
      ``optimizer.step()`` would raise ``KeyError``: non-empty state
      makes Adam skip its lazy init, yet amsgrad reads this key.

    Parameters that do not require grad, or that already have
    optimizer state, are left untouched.
    """
    for opt in self.get_optimizers(0):
        for group in opt.param_groups:
            # These flags are per-group, so evaluate them once per group.
            # ``.get`` tolerates optimizers whose groups lack these keys.
            step_on_param_device = bool(
                group.get("capturable", False) or group.get("fused", False)
            )
            amsgrad = bool(group.get("amsgrad", False))
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                # ``opt.state`` is a defaultdict; ``.get`` avoids inserting
                # an empty entry for params we only inspect.
                if len(opt.state.get(p, {})) > 0:
                    continue
                step_device = p.device if step_on_param_device else None
                state = {
                    "step": torch.tensor(0.0, device=step_device),
                    "exp_avg": torch.zeros_like(p),
                    "exp_avg_sq": torch.zeros_like(p),
                }
                if amsgrad:
                    state["max_exp_avg_sq"] = torch.zeros_like(p)
                opt.state[p] = state

Functions