Skip to content

train

Modules

fastvideo.train.callbacks

Classes

fastvideo.train.callbacks.Callback

Base callback with no-op hooks.

Subclasses override whichever hooks they need. The training_config and method attributes are set by CallbackDict after instantiation.

fastvideo.train.callbacks.CallbackDict
CallbackDict(callback_configs: dict[str, dict[str, Any]], training_config: TrainingConfig)

Manages a collection of named callbacks.

Instantiates each callback from its _target_ config and dispatches hook calls to all registered callbacks.

Source code in fastvideo/train/callbacks/callback.py
def __init__(
    self,
    callback_configs: dict[str, dict[str, Any]],
    training_config: TrainingConfig,
) -> None:
    """Instantiate every configured callback and register it by name.

    Entries without a ``_target_`` fall back to the builtin registry;
    unresolvable entries are logged and skipped rather than raising.
    """
    self._callbacks: dict[str, Callback] = {}
    for name, raw_cfg in (callback_configs or {}).items():
        # Work on a copy so the caller's config dict is untouched.
        cfg = dict(raw_cfg)
        if "_target_" not in cfg:
            target = _BUILTIN_CALLBACKS.get(name)
            if target is None:
                logger.warning(
                    "Callback %r is missing "
                    "'_target_', skipping: %s",
                    name,
                    cfg,
                )
                continue
            cfg["_target_"] = target
        logger.info("Instantiating callback %r: %s", name, cfg)
        callback = instantiate(cfg)
        if not isinstance(callback, Callback):
            raise TypeError(f"Callback {name!r} resolved to "
                            f"{type(callback).__name__}, expected a "
                            f"Callback subclass.")
        # Wire back-references expected by the Callback base class.
        callback.training_config = training_config
        callback._callback_dict = self
        self._callbacks[name] = callback
fastvideo.train.callbacks.EMACallback
EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

.. code-block:: yaml

callbacks:
  ema:
    decay: 0.9999
    start_iter: 0

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    """Record EMA hyper-parameters; the EMA_FSDP instance is created lazily."""
    # Runtime state: created/flipped once training actually starts.
    self.student_ema: EMA_FSDP | None = None
    self._ema_started = False
    # Coerce so YAML-sourced values behave as numbers.
    self._decay = float(decay)
    self._start_iter = int(start_iter)
Functions
fastvideo.train.callbacks.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Temporarily swap EMA weights into *transformer*.

    If EMA is not active, yields the transformer unchanged.
    """
    ema_active = self.student_ema is not None and self._ema_started
    if not ema_active:
        # No-op path: EMA not created yet or not started.
        yield transformer
        return
    with self.student_ema.apply_to_model(transformer):
        yield transformer
fastvideo.train.callbacks.GradNormClipCallback
GradNormClipCallback(*, max_grad_norm: float = 1.0, log_grad_norms: bool = True)

Bases: Callback

Clip gradient norms before the optimizer step.

max_grad_norm defaults to 1.0, but it is intended to be set explicitly in the callback config (callbacks.grad_clip.max_grad_norm).

Source code in fastvideo/train/callbacks/grad_clip.py
def __init__(
    self,
    *,
    max_grad_norm: float = 1.0,
    log_grad_norms: bool = True,
) -> None:
    """Record the clipping threshold and whether to log per-module norms."""
    # Coerce YAML-sourced values to concrete types.
    self._log_grad_norms = bool(log_grad_norms)
    self._max_grad_norm = float(max_grad_norm)
fastvideo.train.callbacks.ValidationCallback
ValidationCallback(*, pipeline_target: str, dataset_file: str, every_steps: int = 100, sampling_steps: list[int] | None = None, guidance_scale: float | None = None, num_frames: int | None = None, output_dir: str | None = None, sampling_timesteps: list[int] | None = None, **pipeline_kwargs: Any)

Bases: Callback

Generic validation callback driven entirely by YAML config.

Works with any pipeline that follows the PipelineCls.from_pretrained(...) + pipeline.forward() contract.

Source code in fastvideo/train/callbacks/validation.py
def __init__(
    self,
    *,
    pipeline_target: str,
    dataset_file: str,
    every_steps: int = 100,
    sampling_steps: list[int] | None = None,
    guidance_scale: float | None = None,
    num_frames: int | None = None,
    output_dir: str | None = None,
    sampling_timesteps: list[int] | None = None,
    **pipeline_kwargs: Any,
) -> None:
    """Normalize YAML-provided validation settings and init runtime state."""
    self.pipeline_target = str(pipeline_target)
    self.dataset_file = str(dataset_file)
    self.every_steps = int(every_steps)
    # Empty/None sampling_steps falls back to a single 40-step schedule.
    if sampling_steps:
        self.sampling_steps = [int(s) for s in sampling_steps]
    else:
        self.sampling_steps = [40]
    self.guidance_scale = None if guidance_scale is None else float(guidance_scale)
    self.num_frames = None if num_frames is None else int(num_frames)
    self.output_dir = None if output_dir is None else str(output_dir)
    if sampling_timesteps is None:
        self.sampling_timesteps = None
    else:
        self.sampling_timesteps = [int(s) for s in sampling_timesteps]
    # Remaining keyword args are forwarded to the pipeline verbatim.
    self.pipeline_kwargs = dict(pipeline_kwargs)

    # Set after on_train_start.
    self._pipeline: Any | None = None
    self._pipeline_key: tuple[Any, ...] | None = None
    self._sampling_param: SamplingParam | None = None
    self.tracker: Any = DummyTracker()
    self.validation_random_generator: torch.Generator | None = None
    self.seed: int = 0

Modules

fastvideo.train.callbacks.callback

Callback base class and CallbackDict manager.

Adapted from FastGen's callback pattern to FastVideo's types.

Classes
fastvideo.train.callbacks.callback.Callback

Base callback with no-op hooks.

Subclasses override whichever hooks they need. The training_config and method attributes are set by CallbackDict after instantiation.

fastvideo.train.callbacks.callback.CallbackDict
CallbackDict(callback_configs: dict[str, dict[str, Any]], training_config: TrainingConfig)

Manages a collection of named callbacks.

Instantiates each callback from its _target_ config and dispatches hook calls to all registered callbacks.

Source code in fastvideo/train/callbacks/callback.py
def __init__(
    self,
    callback_configs: dict[str, dict[str, Any]],
    training_config: TrainingConfig,
) -> None:
    """Instantiate every configured callback and register it by name.

    Entries without a ``_target_`` fall back to the builtin registry;
    unresolvable entries are logged and skipped rather than raising.
    """
    self._callbacks: dict[str, Callback] = {}
    for name, raw_cfg in (callback_configs or {}).items():
        # Work on a copy so the caller's config dict is untouched.
        cfg = dict(raw_cfg)
        if "_target_" not in cfg:
            target = _BUILTIN_CALLBACKS.get(name)
            if target is None:
                logger.warning(
                    "Callback %r is missing "
                    "'_target_', skipping: %s",
                    name,
                    cfg,
                )
                continue
            cfg["_target_"] = target
        logger.info("Instantiating callback %r: %s", name, cfg)
        callback = instantiate(cfg)
        if not isinstance(callback, Callback):
            raise TypeError(f"Callback {name!r} resolved to "
                            f"{type(callback).__name__}, expected a "
                            f"Callback subclass.")
        # Wire back-references expected by the Callback base class.
        callback.training_config = training_config
        callback._callback_dict = self
        self._callbacks[name] = callback
Functions
fastvideo.train.callbacks.ema

EMA (Exponential Moving Average) callback.

Owns the full EMA lifecycle: creation, per-step updates, weight swapping for validation, and checkpoint state. All EMA config lives under callbacks.ema in the YAML file.

Classes
fastvideo.train.callbacks.ema.EMACallback
EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

.. code-block:: yaml

callbacks:
  ema:
    decay: 0.9999
    start_iter: 0

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    """Record EMA hyper-parameters; the EMA_FSDP instance is created lazily."""
    # Runtime state: created/flipped once training actually starts.
    self.student_ema: EMA_FSDP | None = None
    self._ema_started = False
    # Coerce so YAML-sourced values behave as numbers.
    self._decay = float(decay)
    self._start_iter = int(start_iter)
Functions
fastvideo.train.callbacks.ema.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Temporarily swap EMA weights into *transformer*.

    If EMA is not active, yields the transformer unchanged.
    """
    ema_active = self.student_ema is not None and self._ema_started
    if not ema_active:
        # No-op path: EMA not created yet or not started.
        yield transformer
        return
    with self.student_ema.apply_to_model(transformer):
        yield transformer
Functions
fastvideo.train.callbacks.grad_clip

Gradient norm clipping callback.

Clips gradients on modules returned by method.get_grad_clip_targets() before the optimizer step. Optionally logs per-module grad norms to the tracker.

Classes
fastvideo.train.callbacks.grad_clip.GradNormClipCallback
GradNormClipCallback(*, max_grad_norm: float = 1.0, log_grad_norms: bool = True)

Bases: Callback

Clip gradient norms before the optimizer step.

max_grad_norm defaults to 1.0, but it is intended to be set explicitly in the callback config (callbacks.grad_clip.max_grad_norm).

Source code in fastvideo/train/callbacks/grad_clip.py
def __init__(
    self,
    *,
    max_grad_norm: float = 1.0,
    log_grad_norms: bool = True,
) -> None:
    """Record the clipping threshold and whether to log per-module norms."""
    # Coerce YAML-sourced values to concrete types.
    self._log_grad_norms = bool(log_grad_norms)
    self._max_grad_norm = float(max_grad_norm)
Functions
fastvideo.train.callbacks.validation

Validation callback.

All configuration is read from the YAML callbacks.validation section. The pipeline class is resolved from pipeline_target.

Classes
fastvideo.train.callbacks.validation.ValidationCallback
ValidationCallback(*, pipeline_target: str, dataset_file: str, every_steps: int = 100, sampling_steps: list[int] | None = None, guidance_scale: float | None = None, num_frames: int | None = None, output_dir: str | None = None, sampling_timesteps: list[int] | None = None, **pipeline_kwargs: Any)

Bases: Callback

Generic validation callback driven entirely by YAML config.

Works with any pipeline that follows the PipelineCls.from_pretrained(...) + pipeline.forward() contract.

Source code in fastvideo/train/callbacks/validation.py
def __init__(
    self,
    *,
    pipeline_target: str,
    dataset_file: str,
    every_steps: int = 100,
    sampling_steps: list[int] | None = None,
    guidance_scale: float | None = None,
    num_frames: int | None = None,
    output_dir: str | None = None,
    sampling_timesteps: list[int] | None = None,
    **pipeline_kwargs: Any,
) -> None:
    """Normalize YAML-provided validation settings and init runtime state."""
    self.pipeline_target = str(pipeline_target)
    self.dataset_file = str(dataset_file)
    self.every_steps = int(every_steps)
    # Empty/None sampling_steps falls back to a single 40-step schedule.
    if sampling_steps:
        self.sampling_steps = [int(s) for s in sampling_steps]
    else:
        self.sampling_steps = [40]
    self.guidance_scale = None if guidance_scale is None else float(guidance_scale)
    self.num_frames = None if num_frames is None else int(num_frames)
    self.output_dir = None if output_dir is None else str(output_dir)
    if sampling_timesteps is None:
        self.sampling_timesteps = None
    else:
        self.sampling_timesteps = [int(s) for s in sampling_timesteps]
    # Remaining keyword args are forwarded to the pipeline verbatim.
    self.pipeline_kwargs = dict(pipeline_kwargs)

    # Set after on_train_start.
    self._pipeline: Any | None = None
    self._pipeline_key: tuple[Any, ...] | None = None
    self._sampling_param: SamplingParam | None = None
    self.tracker: Any = DummyTracker()
    self.validation_random_generator: torch.Generator | None = None
    self.seed: int = 0
Functions

fastvideo.train.entrypoint

Modules

fastvideo.train.entrypoint.dcp_to_diffusers

Convert a DCP training checkpoint to a diffusers-style model directory.

Works on a single GPU regardless of how many GPUs were used for training (DCP handles resharding automatically).

Usage (no torchrun needed)::

python -m fastvideo.train.entrypoint.dcp_to_diffusers \
    --checkpoint /path/to/checkpoint-1000 \
    --output-dir /path/to/diffusers_output

Or with torchrun (also fine)::

torchrun --nproc_per_node=1         -m fastvideo.train.entrypoint.dcp_to_diffusers         --checkpoint ... --output-dir ...

The checkpoint must contain metadata.json (written by CheckpointManager). If the checkpoint predates metadata support, pass --config explicitly to provide the training YAML.

Functions
fastvideo.train.entrypoint.dcp_to_diffusers.convert
convert(*, checkpoint_dir: str, output_dir: str, config_path: str | None = None, role: str = 'student', overwrite: bool = False) -> str

Load a DCP checkpoint and export as a diffusers model.

Returns the path to the exported model directory.

Source code in fastvideo/train/entrypoint/dcp_to_diffusers.py
def convert(
    *,
    checkpoint_dir: str,
    output_dir: str,
    config_path: str | None = None,
    role: str = "student",
    overwrite: bool = False,
) -> str:
    """Load a DCP checkpoint and export as a diffusers model.

    Args:
        checkpoint_dir: Checkpoint location; resolved via
            ``_resolve_resume_checkpoint``.
        output_dir: Destination directory for the diffusers export.
        config_path: Optional training YAML; when omitted, the config is
            read from the checkpoint's ``metadata.json``.
        role: Which role model to export (default ``"student"``).
        overwrite: Passed through to the export helper.

    Returns:
        The path to the exported model directory.

    Raises:
        FileNotFoundError: If the checkpoint has no ``dcp/`` directory.
        ValueError: If no config is available, or the base model path
            cannot be determined.
    """
    _ensure_distributed()

    # Deferred imports: keep this module importable without pulling in
    # the full training stack at module load time.
    from fastvideo.distributed import (
        maybe_init_distributed_environment_and_model_parallel, )
    from fastvideo.train.utils.builder import build_from_config
    from fastvideo.train.utils.checkpoint import (
        CheckpointManager,
        _resolve_resume_checkpoint,
    )
    from fastvideo.train.utils.config import (
        RunConfig,
        load_run_config,
    )

    import torch.distributed.checkpoint as dcp

    # -- Resolve checkpoint directory --
    resolved = _resolve_resume_checkpoint(
        checkpoint_dir,
        output_dir=checkpoint_dir,
    )
    dcp_dir = resolved / "dcp"
    if not dcp_dir.is_dir():
        raise FileNotFoundError(f"Missing dcp/ under {resolved}")

    # -- Obtain config --
    # Explicit --config wins; otherwise fall back to the config embedded
    # in the checkpoint's metadata.json.
    cfg: RunConfig
    if config_path is not None:
        cfg = load_run_config(config_path)
    else:
        metadata = CheckpointManager.load_metadata(resolved)
        raw_config = metadata.get("config")
        if raw_config is None:
            raise ValueError("Checkpoint metadata.json does not "
                             "contain 'config'. Pass --config "
                             "explicitly.")
        cfg = _run_config_from_raw(raw_config)

    tc = cfg.training

    # -- Init distributed (1 GPU is enough; DCP reshards) --
    maybe_init_distributed_environment_and_model_parallel(
        tp_size=1,
        sp_size=1,
    )

    # Override distributed config so model loading uses 1 GPU.
    # NOTE: mutates the loaded config in place BEFORE build_from_config
    # reads it — this ordering matters.
    tc.distributed.tp_size = 1
    tc.distributed.sp_size = 1
    tc.distributed.num_gpus = 1
    tc.distributed.hsdp_replicate_dim = 1
    tc.distributed.hsdp_shard_dim = 1

    # -- Build model (loads pretrained weights + FSDP) --
    _, method, _, _ = build_from_config(cfg)

    # -- Load DCP weights into the model --
    # dcp.load() loads in place into the state containers returned by
    # checkpoint_state(), resharding from the training world size.
    states = method.checkpoint_state()
    logger.info(
        "Loading DCP checkpoint from %s",
        resolved,
    )
    dcp.load(states, checkpoint_id=str(dcp_dir))

    # -- Export to diffusers format --
    model = method._role_models[role]
    base_model_path = str(tc.model_path)
    if not base_model_path:
        raise ValueError("Cannot determine base_model_path from "
                         "config. Ensure models.student.init_from "
                         "is set.")

    logger.info(
        "Exporting role=%s to %s (base=%s)",
        role,
        output_dir,
        base_model_path,
    )
    result = _save_role_pretrained(
        role=role,
        base_model_path=base_model_path,
        output_dir=output_dir,
        overwrite=overwrite,
        model=model,
    )
    logger.info("Export complete: %s", result)
    return result
fastvideo.train.entrypoint.misc
Modules
fastvideo.train.entrypoint.misc.wan_ode_init_conversion

Convert Self-Forcing ode_init.pt to HuggingFace diffusers format.

The official ode_init.pt from https://huggingface.co/gdhe17/Self-Forcing/resolve/main/checkpoints/ode_init.pt stores weights under {"generator": {<original_wan_keys>}}.

This script converts those keys to diffusers WanTransformer3DModel format, verifies them against a reference model, and saves a complete diffusers-compatible model directory (transformer + scheduler + vae + text_encoder + tokenizer).

Usage

python -m fastvideo.train.entrypoint.misc.wan_ode_init_conversion --input /path/to/ode_init.pt --output /path/to/WanOdeInit --base-model Wan-AI/Wan2.1-T2V-1.3B-Diffusers

Functions
fastvideo.train.entrypoint.misc.wan_ode_init_conversion.convert_state_dict
convert_state_dict(orig_sd: dict[str, Tensor]) -> dict[str, Tensor]

Convert an entire original-Wan state dict.

Source code in fastvideo/train/entrypoint/misc/wan_ode_init_conversion.py
def convert_state_dict(
    orig_sd: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Convert an entire original-Wan state dict.

    Keys are remapped via ``_convert_key``; tensor values are kept as-is.
    """
    converted: dict[str, torch.Tensor] = {}
    for key, tensor in orig_sd.items():
        converted[_convert_key(key)] = tensor
    return converted
fastvideo.train.entrypoint.train

YAML-only training entrypoint.

Usage::

torchrun --nproc_per_node=<N> -m fastvideo.train.entrypoint.train \
    --config path/to/run.yaml

Any unknown --dotted.key value arguments are applied as overrides to the YAML config before parsing. For example::

torchrun --nproc_per_node=8 -m fastvideo.train.entrypoint.train         --config path/to/run.yaml         --training.distributed.num_gpus 8         --training.optimizer.learning_rate 1e-5
Functions
fastvideo.train.entrypoint.train.run_training_from_config
run_training_from_config(config_path: str, *, dry_run: bool = False, overrides: list[str] | None = None) -> None

YAML-only training entrypoint (schema v2).

Source code in fastvideo/train/entrypoint/train.py
def run_training_from_config(
    config_path: str,
    *,
    dry_run: bool = False,
    overrides: list[str] | None = None,
) -> None:
    """YAML-only training entrypoint (schema v2).

    Args:
        config_path: Path to the run YAML config.
        dry_run: When True, stop after the config parses and
            ``build_from_config`` succeeds (no training loop).
        overrides: Optional ``--dotted.key value`` overrides applied to
            the YAML before parsing.
    """

    # Deferred imports: keep module import cheap and avoid cycles.
    from fastvideo.distributed import (
        maybe_init_distributed_environment_and_model_parallel, )
    from fastvideo.train import Trainer
    from fastvideo.train.utils.checkpoint import (
        CheckpointConfig,
        CheckpointManager,
    )
    from fastvideo.train.utils.builder import build_from_config
    from fastvideo.train.utils.config import load_run_config

    # Enable deterministic mode for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    cfg = load_run_config(config_path, overrides=overrides)
    tc = cfg.training

    # Auto-set attention backend for VSA when sparsity is configured.
    # setdefault: an explicitly exported env var still wins.
    if tc.vsa_sparsity > 0.0:
        os.environ.setdefault(
            "FASTVIDEO_ATTENTION_BACKEND",
            "VIDEO_SPARSE_ATTN",
        )

    # Must happen before build_from_config so models shard correctly.
    maybe_init_distributed_environment_and_model_parallel(
        tc.distributed.tp_size,
        tc.distributed.sp_size,
    )

    _, method, dataloader, start_step = build_from_config(cfg)

    if dry_run:
        logger.info("Dry-run: config parsed and "
                    "build_from_config succeeded.")
        return

    trainer = Trainer(
        tc,
        config=cfg.resolved_config(),
        callback_configs=cfg.callbacks,
    )

    # Attach the exact YAML used for this run to the
    # tracker (e.g., W&B Files).
    trainer.tracker.log_file(
        os.path.abspath(os.path.expanduser(config_path)),
        name="run.yaml",
    )

    # `or 0` maps None/unset to 0 (disabled) before int coercion.
    ckpt_config = CheckpointConfig(
        save_steps=int(tc.checkpoint.training_state_checkpointing_steps or 0),
        keep_last=int(tc.checkpoint.checkpoints_total_limit or 0),
    )

    checkpoint_manager = CheckpointManager(
        method=method,
        dataloader=dataloader,
        output_dir=tc.checkpoint.output_dir,
        config=ckpt_config,
        callbacks=trainer.callbacks,
        raw_config=cfg.raw,
    )

    trainer.run(
        method,
        dataloader=dataloader,
        max_steps=tc.loop.max_train_steps,
        start_step=start_step,
        checkpoint_manager=checkpoint_manager,
    )

fastvideo.train.methods

Classes

fastvideo.train.methods.TrainingMethod
TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

The constructor receives role_models (a dict[str, ModelBase]) and a cfg object. It calls init_preprocessors on the student and builds self.role_modules for FSDP wrapping.

A single shared CUDA RNG generator (cuda_generator) is created in :meth:on_train_start. All torch.randn / torch.randint calls in methods and models must use this generator instead of relying on global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Wire up role models, config views, and FSDP-visible modules."""
    super().__init__()
    self.tracker: Any | None = None
    self._role_models: dict[str, ModelBase] = dict(role_models)

    self.student = role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    validation = getattr(cfg, "validation", {}) or {}
    self.validation_config: dict[str, Any] = dict(validation)

    # Build nn.ModuleDict for FSDP / checkpoint visibility: only roles
    # whose model exposes a real transformer module are registered.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        transformer = getattr(model, "transformer", None)
        if not isinstance(transformer, torch.nn.Module):
            continue
        self.role_modules[role] = torch.nn.ModuleDict(
            {"transformer": transformer})
Functions
fastvideo.train.methods.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Return DCP-ready checkpoint state for all trainable roles.

    Keys follow the convention:
    ``roles.<role>.<module>``, ``optimizers.<role>``,
    ``schedulers.<role>``, ``random_state.*``.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.
    """
    states: dict[str, Any] = {}

    for role, model in self._role_models.items():
        # Frozen roles (e.g. a teacher) carry no trainable state.
        if not getattr(model, "_trainable", False):
            continue

        modules: dict[str, torch.nn.Module] = {}
        transformer = model.transformer
        if transformer is not None:
            modules["transformer"] = transformer

        container = _RoleModuleContainer(modules)

        states.update({
            f"roles.{role}.{module_name}": ModelWrapper(module)
            for module_name, module in modules.items()
        })

        optimizer = self._optimizer_dict.get(role)
        if optimizer is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, optimizer)

        scheduler = self._lr_scheduler_dict.get(role)
        if scheduler is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(scheduler)

    return states
fastvideo.train.methods.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Return modules whose gradients should be clipped.

    Override in subclasses to add/conditionally include
    modules (e.g. critic, conditionally student).
    Default: student transformer.
    """
    targets: dict[str, torch.nn.Module] = {}
    targets["student"] = self.student.transformer
    return targets
fastvideo.train.methods.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (exp_avg, exp_avg_sq,
    step are only created on the first optimizer.step()).
    DCP needs matching entries to load into; without them
    the saved optimizer state is silently dropped.
    """
    for optimizer in self.get_optimizers(0):
        trainable = (
            p
            for group in optimizer.param_groups
            for p in group["params"]
            if p.requires_grad
        )
        for param in trainable:
            # Leave state produced by a real optimizer.step() untouched.
            if optimizer.state.get(param):
                continue
            optimizer.state[param] = {
                "step": torch.tensor(0.0),
                "exp_avg": torch.zeros_like(param),
                "exp_avg_sq": torch.zeros_like(param),
            }

Modules

fastvideo.train.methods.base
Classes
fastvideo.train.methods.base.TrainingMethod
TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

The constructor receives role_models (a dict[str, ModelBase]) and a cfg object. It calls init_preprocessors on the student and builds self.role_modules for FSDP wrapping.

A single shared CUDA RNG generator (cuda_generator) is created in :meth:on_train_start. All torch.randn / torch.randint calls in methods and models must use this generator instead of relying on global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Wire up role models, config views, and FSDP-visible modules."""
    super().__init__()
    self.tracker: Any | None = None
    self._role_models: dict[str, ModelBase] = dict(role_models)

    self.student = role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    validation = getattr(cfg, "validation", {}) or {}
    self.validation_config: dict[str, Any] = dict(validation)

    # Build nn.ModuleDict for FSDP / checkpoint visibility: only roles
    # whose model exposes a real transformer module are registered.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        transformer = getattr(model, "transformer", None)
        if not isinstance(transformer, torch.nn.Module):
            continue
        self.role_modules[role] = torch.nn.ModuleDict(
            {"transformer": transformer})
Functions
fastvideo.train.methods.base.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Return DCP-ready checkpoint state for all trainable roles.

    Keys follow the convention:
    ``roles.<role>.<module>``, ``optimizers.<role>``,
    ``schedulers.<role>``, ``random_state.*``.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.
    """
    states: dict[str, Any] = {}

    for role, model in self._role_models.items():
        # Frozen roles (e.g. a teacher) carry no trainable state.
        if not getattr(model, "_trainable", False):
            continue

        modules: dict[str, torch.nn.Module] = {}
        transformer = model.transformer
        if transformer is not None:
            modules["transformer"] = transformer

        container = _RoleModuleContainer(modules)

        states.update({
            f"roles.{role}.{module_name}": ModelWrapper(module)
            for module_name, module in modules.items()
        })

        optimizer = self._optimizer_dict.get(role)
        if optimizer is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, optimizer)

        scheduler = self._lr_scheduler_dict.get(role)
        if scheduler is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(scheduler)

    return states
fastvideo.train.methods.base.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Return modules whose gradients should be clipped.

    Override in subclasses to add/conditionally include
    modules (e.g. critic, conditionally student).
    Default: student transformer.
    """
    targets: dict[str, torch.nn.Module] = {}
    targets["student"] = self.student.transformer
    return targets
fastvideo.train.methods.base.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (exp_avg, exp_avg_sq,
    step are only created on the first optimizer.step()).
    DCP needs matching entries to load into; without them
    the saved optimizer state is silently dropped.
    """
    for optimizer in self.get_optimizers(0):
        trainable = (
            p
            for group in optimizer.param_groups
            for p in group["params"]
            if p.requires_grad
        )
        for param in trainable:
            # Leave state produced by a real optimizer.step() untouched.
            if optimizer.state.get(param):
                continue
            optimizer.state[param] = {
                "step": torch.tensor(0.0),
                "exp_avg": torch.zeros_like(param),
                "exp_avg_sq": torch.zeros_like(param),
            }
Functions
fastvideo.train.methods.distribution_matching
Classes
fastvideo.train.methods.distribution_matching.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- self.student — trainable student :class:ModelBase
- self.teacher — frozen teacher :class:ModelBase
- self.critic — trainable critic :class:ModelBase

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("DMD2Method requires role 'student'")
    if "teacher" not in role_models:
        raise ValueError("DMD2Method requires role 'teacher'")
    if "critic" not in role_models:
        raise ValueError("DMD2Method requires role 'critic'")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    if not self.student._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if self.teacher._trainable:
        raise ValueError("DMD2Method requires teacher to be "
                         "non-trainable")
    if not self.critic._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")
    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    self._denoising_step_list: torch.Tensor | None = (None)

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
fastvideo.train.methods.distribution_matching.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Validate causal student.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")

    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    chunk_size = get_optional_int(
        mcfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if chunk_size is None:
        chunk_size = 3
    if chunk_size <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk_size}")
    self._chunk_size = int(chunk_size)

    sample_type_raw = mcfg.get("student_sample_type", "sde")
    sample_type = _require_str(
        sample_type_raw,
        where="method_config.student_sample_type",
    )
    sample_type = sample_type.strip().lower()
    if sample_type not in {"sde", "ode"}:
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {sample_type_raw!r}")
    self._student_sample_type: Literal["sde", "ode"] = (
        sample_type  # type: ignore[assignment]
    )

    same_step_raw = mcfg.get("same_step_across_blocks", False)
    if same_step_raw is None:
        same_step_raw = False
    self._same_step_across_blocks = _require_bool(
        same_step_raw,
        where="method_config.same_step_across_blocks",
    )

    last_step_raw = mcfg.get("last_step_only", False)
    if last_step_raw is None:
        last_step_raw = False
    self._last_step_only = _require_bool(
        last_step_raw,
        where="method_config.last_step_only",
    )

    context_noise = get_optional_float(
        mcfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if context_noise is None:
        context_noise = 0.0
    if context_noise < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {context_noise}")
    self._context_noise = float(context_noise)

    enable_grad_raw = mcfg.get("enable_gradient_in_rollout", True)
    if enable_grad_raw is None:
        enable_grad_raw = True
    self._enable_gradient_in_rollout = _require_bool(
        enable_grad_raw,
        where="method_config.enable_gradient_in_rollout",
    )

    start_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if start_grad_frame is None:
        start_grad_frame = 0
    if start_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {start_grad_frame}")
    self._start_gradient_frame = int(start_grad_frame)

    shift = float(getattr(
        self.training_config.pipeline_config,
        "flow_shift",
        0.0,
    ) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    self._sf_denoising_step_list: torch.Tensor | None = None
Modules
fastvideo.train.methods.distribution_matching.dmd2

DMD2 distillation method (algorithm layer).

Classes
fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- self.student — trainable student :class:`ModelBase`
- self.teacher — frozen teacher :class:`ModelBase`
- self.critic — trainable critic :class:`ModelBase`

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("DMD2Method requires role 'student'")
    if "teacher" not in role_models:
        raise ValueError("DMD2Method requires role 'teacher'")
    if "critic" not in role_models:
        raise ValueError("DMD2Method requires role 'critic'")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    if not self.student._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if self.teacher._trainable:
        raise ValueError("DMD2Method requires teacher to be "
                         "non-trainable")
    if not self.critic._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")
    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    self._denoising_step_list: torch.Tensor | None = (None)

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions
fastvideo.train.methods.distribution_matching.self_forcing

Self-Forcing distillation method (algorithm layer).

Classes
fastvideo.train.methods.distribution_matching.self_forcing.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Validate causal student.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")

    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    chunk_size = get_optional_int(
        mcfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if chunk_size is None:
        chunk_size = 3
    if chunk_size <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk_size}")
    self._chunk_size = int(chunk_size)

    sample_type_raw = mcfg.get("student_sample_type", "sde")
    sample_type = _require_str(
        sample_type_raw,
        where="method_config.student_sample_type",
    )
    sample_type = sample_type.strip().lower()
    if sample_type not in {"sde", "ode"}:
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {sample_type_raw!r}")
    self._student_sample_type: Literal["sde", "ode"] = (
        sample_type  # type: ignore[assignment]
    )

    same_step_raw = mcfg.get("same_step_across_blocks", False)
    if same_step_raw is None:
        same_step_raw = False
    self._same_step_across_blocks = _require_bool(
        same_step_raw,
        where="method_config.same_step_across_blocks",
    )

    last_step_raw = mcfg.get("last_step_only", False)
    if last_step_raw is None:
        last_step_raw = False
    self._last_step_only = _require_bool(
        last_step_raw,
        where="method_config.last_step_only",
    )

    context_noise = get_optional_float(
        mcfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if context_noise is None:
        context_noise = 0.0
    if context_noise < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {context_noise}")
    self._context_noise = float(context_noise)

    enable_grad_raw = mcfg.get("enable_gradient_in_rollout", True)
    if enable_grad_raw is None:
        enable_grad_raw = True
    self._enable_gradient_in_rollout = _require_bool(
        enable_grad_raw,
        where="method_config.enable_gradient_in_rollout",
    )

    start_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if start_grad_frame is None:
        start_grad_frame = 0
    if start_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {start_grad_frame}")
    self._start_gradient_frame = int(start_grad_frame)

    shift = float(getattr(
        self.training_config.pipeline_config,
        "flow_shift",
        0.0,
    ) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    self._sf_denoising_step_list: torch.Tensor | None = None
Functions
fastvideo.train.methods.fine_tuning
Classes
fastvideo.train.methods.fine_tuning.DiffusionForcingSFTMethod
DiffusionForcingSFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Diffusion-forcing SFT (DFSFT): train only the student with inhomogeneous timesteps.

Source code in fastvideo/train/methods/fine_tuning/dfsft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("DFSFT requires role 'student'")
    if not self.student._trainable:
        raise ValueError("DFSFT requires student to be trainable")
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    self._chunk_size = self._parse_chunk_size(self.method_config.get("chunk_size", None))
    self._timestep_index_range = (self._parse_timestep_index_range())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
fastvideo.train.methods.fine_tuning.FineTuneMethod
FineTuneMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Supervised finetuning: only student participates.

Source code in fastvideo/train/methods/fine_tuning/finetune.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("FineTuneMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("FineTuneMethod requires student to be "
                         "trainable")
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Modules
fastvideo.train.methods.fine_tuning.dfsft

Diffusion-forcing SFT method (DFSFT; algorithm layer).

Classes
fastvideo.train.methods.fine_tuning.dfsft.DiffusionForcingSFTMethod
DiffusionForcingSFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Diffusion-forcing SFT (DFSFT): train only the student with inhomogeneous timesteps.

Source code in fastvideo/train/methods/fine_tuning/dfsft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("DFSFT requires role 'student'")
    if not self.student._trainable:
        raise ValueError("DFSFT requires student to be trainable")
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    self._chunk_size = self._parse_chunk_size(self.method_config.get("chunk_size", None))
    self._timestep_index_range = (self._parse_timestep_index_range())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions
fastvideo.train.methods.fine_tuning.finetune

Supervised finetuning method (algorithm layer).

Classes
fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod
FineTuneMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Supervised finetuning: only student participates.

Source code in fastvideo/train/methods/fine_tuning/finetune.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("FineTuneMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("FineTuneMethod requires student to be "
                         "trainable")
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions
fastvideo.train.methods.knowledge_distillation
Classes
fastvideo.train.methods.knowledge_distillation.KDCausalMethod
KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to :class:KDMethod except single_train_step samples a per-frame denoising step index (block-quantized to groups of num_frames_per_block frames) instead of one index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method::

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)
    self._num_frames_per_block: int = int(self.method_config.get("num_frames_per_block", 3))
    if self._num_frames_per_block < 1:
        raise ValueError("num_frames_per_block must be >= 1")
fastvideo.train.methods.knowledge_distillation.KDMethod
KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()
Modules
fastvideo.train.methods.knowledge_distillation.kd

Knowledge Distillation method for ODE-init training.

Trains a student model with MSE loss to reproduce a teacher model's multi-step ODE denoising trajectories. The resulting checkpoint (exported via dcp_to_diffusers) serves as the ode_init weight initialization for downstream Self-Forcing training.

Teacher path generation is cached to disk so it only runs once. Interrupted generation resumes from the last completed sample.

Typical YAML::

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:           # omit once cache is complete
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-14B-Diffusers
    trainable: false
    disable_custom_init_weights: true

method:
  _target_: fastvideo.train.methods.knowledge_distillation.kd.KDMethod
  teacher_path_cache: /data/kd_cache/wan14b_4step
  t_list: [999, 937, 833, 624, 0]   # integer timesteps
  student_sample_steps: 4
  teacher_guidance_scale: 1.0
Classes
fastvideo.train.methods.knowledge_distillation.kd.KDCausalMethod
KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to :class:KDMethod except single_train_step samples a per-frame denoising step index (block-quantized to groups of num_frames_per_block frames) instead of one index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method::

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)
    self._num_frames_per_block: int = int(self.method_config.get("num_frames_per_block", 3))
    if self._num_frames_per_block < 1:
        raise ValueError("num_frames_per_block must be >= 1")
fastvideo.train.methods.knowledge_distillation.kd.KDMethod
KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()
Functions

fastvideo.train.models

Model build plugins for Phase 2/2.9 distillation.

These are "model plugins" selected by recipe.family / roles.<role>.family.

Modules

fastvideo.train.models.base
Classes
fastvideo.train.models.base.CausalModelBase

Bases: ModelBase

Extension for causal / streaming model plugins.

Cache state is internal to the model instance and keyed by cache_tag (no role handle needed).

Functions
fastvideo.train.models.base.CausalModelBase.clear_caches abstractmethod
clear_caches(*, cache_tag: str = 'pos') -> None

Clear internal caches before starting a new rollout.

Source code in fastvideo/train/models/base.py
@abstractmethod
def clear_caches(self, *, cache_tag: str = "pos") -> None:
    """Clear internal caches before starting a new rollout."""
fastvideo.train.models.base.CausalModelBase.predict_noise_streaming abstractmethod
predict_noise_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Streaming predict-noise that may update internal caches.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming predict-noise that may update internal caches."""
fastvideo.train.models.base.CausalModelBase.predict_x0_streaming
predict_x0_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Predict x0 streaming via predict_noise_streaming + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Predict x0 streaming via
    ``predict_noise_streaming`` + conversion."""
    pred_noise = self.predict_noise_streaming(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cache_tag=cache_tag,
        store_kv=store_kv,
        cur_start_frame=cur_start_frame,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    if pred_noise is None:
        return None
    return pred_noise_to_pred_video(
        pred_noise=pred_noise.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    ).unflatten(0, pred_noise.shape[:2])
fastvideo.train.models.base.ModelBase

Bases: ABC

Per-role model instance.

Every role (student, teacher, critic, …) gets its own ModelBase instance. Each instance owns its own transformer and noise_scheduler. Heavyweight resources (VAE, dataloader, RNG seeds) are loaded lazily via :meth:init_preprocessors, which the method calls only on the student.

Attributes
fastvideo.train.models.base.ModelBase.device property
device: device

The local CUDA device for this rank.

fastvideo.train.models.base.ModelBase.num_train_timesteps property
num_train_timesteps: int

Return the scheduler's training timestep horizon.

Functions
fastvideo.train.models.base.ModelBase.add_noise abstractmethod
add_noise(clean_latents: Tensor, noise: Tensor, timestep: Tensor) -> Tensor

Apply forward-process noise at timestep.

Source code in fastvideo/train/models/base.py
@abstractmethod
def add_noise(
    self,
    clean_latents: torch.Tensor,
    noise: torch.Tensor,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Apply forward-process noise at *timestep*."""
fastvideo.train.models.base.ModelBase.backward abstractmethod
backward(loss: Tensor, ctx: Any, *, grad_accum_rounds: int) -> None

Backward that may restore forward-context.

Source code in fastvideo/train/models/base.py
@abstractmethod
def backward(
    self,
    loss: torch.Tensor,
    ctx: Any,
    *,
    grad_accum_rounds: int,
) -> None:
    """Backward that may restore forward-context."""
fastvideo.train.models.base.ModelBase.init_preprocessors
init_preprocessors(training_config: TrainingConfig) -> None

Load VAE, build dataloader, seed RNGs.

Called only on the student by the method's __init__. Default is a no-op so teacher/critic instances skip this.

Source code in fastvideo/train/models/base.py
def init_preprocessors(  # noqa: B027
        self,
        training_config: TrainingConfig,
) -> None:
    """Load VAE, build dataloader, seed RNGs.

    Called only on the student by the method's ``__init__``.
    Default is a no-op so teacher/critic instances skip this.
    """
fastvideo.train.models.base.ModelBase.on_train_start
on_train_start() -> None

Called once before the training loop begins.

Source code in fastvideo/train/models/base.py
def on_train_start(self) -> None:  # noqa: B027
    """Called once before the training loop begins."""
fastvideo.train.models.base.ModelBase.predict_noise abstractmethod
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict noise/flow for the given noisy latents.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict noise/flow for the given noisy latents."""
fastvideo.train.models.base.ModelBase.predict_x0
predict_x0(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict x0 via predict_noise + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict x0 via ``predict_noise`` + conversion."""
    pred_noise = self.predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    return pred_noise_to_pred_video(
        pred_noise=pred_noise.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    ).unflatten(0, pred_noise.shape[:2])
fastvideo.train.models.base.ModelBase.prepare_batch abstractmethod
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Convert a dataloader batch into forward primitives.

Source code in fastvideo/train/models/base.py
@abstractmethod
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Convert a dataloader batch into forward primitives.

    Args:
        raw_batch: Raw dict as emitted by the dataloader.
        generator: RNG used for any sampling done during preparation.
        latents_source: ``"data"`` takes latents from the batch;
            ``"zeros"`` starts from all-zero latents.

    Returns:
        A ``TrainingBatch`` ready for the forward pass.
    """
fastvideo.train.models.base.ModelBase.shift_and_clamp_timestep
shift_and_clamp_timestep(timestep: Tensor) -> Tensor

Apply model/pipeline timestep shifting and clamp.

Source code in fastvideo/train/models/base.py
def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
    """Apply model/pipeline timestep shifting and clamp.

    The base implementation is the identity; subclasses override it to
    install their own shift/clamp schedule.
    """
    return timestep
Functions
fastvideo.train.models.wan

Wan model plugin package.

Classes
Modules
fastvideo.train.models.wan.wan

Wan model plugin (per-role instance).

Classes
fastvideo.train.models.wan.wan.WanModel
WanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: ModelBase

Wan per-role model: owns transformer + noise_scheduler.

Source code in fastvideo/train/models/wan/wan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str
    | None = None,
    transformer_override_safetensor: str
    | None = None,
) -> None:
    """Build the Wan role model: transformer + flow-match scheduler.

    Args:
        init_from: Path/identifier the transformer is loaded from.
        training_config: Shared training configuration.
        trainable: Whether the transformer receives gradients.
        disable_custom_init_weights: Skip custom weight init on load.
        flow_shift: Shift for the flow-match scheduler; also stored
            as ``timestep_shift``.
        enable_gradient_checkpointing_type: Optional gradient
            checkpointing mode forwarded to the loader.
        transformer_override_safetensor: Optional safetensors file
            overriding the transformer weights.
    """
    self._init_from = str(init_from)
    self._trainable = bool(trainable)

    self.transformer = self._load_transformer(
        init_from=self._init_from,
        trainable=self._trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        training_config=training_config,
        transformer_override_safetensor=transformer_override_safetensor,
    )

    shift = float(flow_shift)
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)

    # Filled by init_preprocessors (student only).
    self.vae: Any = None
    self.training_config: TrainingConfig = training_config
    self.dataloader: Any = None
    self.validator: Any = None
    self.start_step: int = 0

    # Distributed process groups; attached later by the runtime.
    self.world_group: Any = None
    self.sp_group: Any = None

    # Cached negative-prompt conditioning, populated on demand.
    self.negative_prompt_embeds: torch.Tensor | None = None
    self.negative_prompt_attention_mask: torch.Tensor | None = None

    # Timestep mechanics.
    self.timestep_shift: float = shift
    self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
    self.min_timestep: int = 0
    self.max_timestep: int = self.num_train_timestep
Functions
fastvideo.train.models.wan.wan_causal

Wan causal model plugin (per-role instance, streaming/cache).

Classes
fastvideo.train.models.wan.wan_causal.WanCausalModel
WanCausalModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel, CausalModelBase

Wan per-role model with causal/streaming primitives.

Source code in fastvideo/train/models/wan/wan_causal.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str
    | None = None,
    transformer_override_safetensor: str
    | None = None,
) -> None:
    """Initialize the causal Wan model plus its streaming caches.

    Every constructor argument is forwarded unchanged to
    ``WanModel.__init__``; this subclass only adds the streaming
    cache registry used by causal/streaming rollouts.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
    # Keyed by (int, str) — presumably (stream/rank index, cache name);
    # TODO(review): confirm key semantics against _StreamingCaches users.
    self._streaming_caches: dict[tuple[int, str], _StreamingCaches] = {}
Functions

fastvideo.train.utils

Distillation utilities shared across families/methods/entrypoints.

Modules

fastvideo.train.utils.builder

Assembly: build method + dataloader from a _target_-based config.

Classes
Functions
fastvideo.train.utils.builder.build_from_config
build_from_config(cfg: RunConfig) -> tuple[TrainingConfig, TrainingMethod, Any, int]

Build method + dataloader from a v3 run config.

  1. Instantiate each model in cfg.models via _target_.
  2. Resolve the method class from cfg.method["_target_"] and construct it with (cfg=cfg, role_models=...).
  3. Return (training_config, method, dataloader, start_step).
Source code in fastvideo/train/utils/builder.py
def build_from_config(cfg: RunConfig) -> tuple[TrainingConfig, TrainingMethod, Any, int]:
    """Build method + dataloader from a v3 run config.

    1. Instantiate each model in ``cfg.models`` via ``_target_``.
    2. Resolve the method class from ``cfg.method["_target_"]``
       and construct it with ``(cfg=cfg, role_models=...)``.
    3. Return ``(training_config, method, dataloader, start_step)``.

    Raises:
        TypeError: If a model config does not resolve to a
            ``ModelBase`` subclass.
        KeyError: If ``cfg.method`` has no ``_target_`` key.
    """
    from fastvideo.train.models.base import ModelBase

    # --- 1. Build role model instances ---
    role_models: dict[str, ModelBase] = {}
    for role, model_cfg in cfg.models.items():
        model = instantiate(model_cfg, training_config=cfg.training)
        if not isinstance(model, ModelBase):
            raise TypeError(f"models.{role}._target_ must resolve to a "
                            f"ModelBase subclass, got {type(model).__name__}")
        role_models[role] = model

    # --- 2. Build method ---
    # NOTE: only ``_target_`` is consumed here; the method is constructed
    # from (cfg, role_models), so any other keys under cfg.method are
    # read by the method class itself via ``cfg``.
    method_cls = resolve_target(str(cfg.method["_target_"]))
    method = method_cls(
        cfg=cfg,
        role_models=role_models,
    )

    # --- 3. Gather dataloader and start_step ---
    # The student model (if any) provides the dataloader and resume step.
    student = role_models.get("student")
    dataloader = (getattr(student, "dataloader", None) if student is not None else None)
    start_step = int(getattr(student, "start_step", 0) if student is not None else 0)

    return cfg.training, method, dataloader, start_step
fastvideo.train.utils.checkpoint
Classes
fastvideo.train.utils.checkpoint.CheckpointManager
CheckpointManager(*, method: Any, dataloader: Any, output_dir: str, config: CheckpointConfig, callbacks: Any | None = None, raw_config: dict[str, Any] | None = None)

Role-based checkpoint manager for training runtime.

  • Checkpoint policy lives in YAML (via TrainingArgs fields).
  • Resume path is typically provided via CLI (--resume-from-checkpoint).
Source code in fastvideo/train/utils/checkpoint.py
def __init__(
    self,
    *,
    method: Any,
    dataloader: Any,
    output_dir: str,
    config: CheckpointConfig,
    callbacks: Any | None = None,
    raw_config: dict[str, Any] | None = None,
) -> None:
    """Store checkpointing collaborators; no I/O happens here.

    Args:
        method: Training method whose state is checkpointed.
        dataloader: Dataloader checkpointed alongside the method.
        output_dir: Root directory checkpoints are written under.
        config: Checkpoint policy (intervals, retention, ...).
        callbacks: Optional callback container notified around
            save/load — TODO(review): confirm expected interface.
        raw_config: Optional raw run config persisted with checkpoints.
    """
    self.method = method
    self.dataloader = dataloader
    self.output_dir = str(output_dir)
    self.config = config
    self._callbacks = callbacks
    self._raw_config = raw_config
    # Step of the most recent save; None until the first save.
    self._last_saved_step: int | None = None
Functions
fastvideo.train.utils.checkpoint.CheckpointManager.load_metadata staticmethod
load_metadata(checkpoint_dir: str | Path) -> dict[str, Any]

Read metadata.json from a checkpoint dir.

Source code in fastvideo/train/utils/checkpoint.py
@staticmethod
def load_metadata(checkpoint_dir: str | Path, ) -> dict[str, Any]:
    """Read ``metadata.json`` from a checkpoint dir."""
    meta_path = Path(checkpoint_dir) / "metadata.json"
    if not meta_path.is_file():
        raise FileNotFoundError(f"No metadata.json in {checkpoint_dir}")
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f)  # type: ignore[no-any-return]
fastvideo.train.utils.checkpoint.CheckpointManager.load_rng_snapshot
load_rng_snapshot(checkpoint_path: str) -> None

Restore per-rank RNG state from the snapshot file.

Must be called AFTER dcp.load and after iter(dataloader) so no later operation can clobber the restored state.

Source code in fastvideo/train/utils/checkpoint.py
def load_rng_snapshot(
    self,
    checkpoint_path: str,
) -> None:
    """Restore per-rank RNG state from the snapshot file.

    Must be called AFTER ``dcp.load`` **and** after
    ``iter(dataloader)`` so no later operation can
    clobber the restored state.

    Args:
        checkpoint_path: Checkpoint directory (or spec resolvable by
            ``_resolve_resume_checkpoint``) to restore from.
    """
    resolved = _resolve_resume_checkpoint(
        checkpoint_path,
        output_dir=self.output_dir,
    )
    if resolved is None:
        return
    rank = _rank()
    rng_path = resolved / f"rng_state_rank{rank}.pt"
    if not rng_path.is_file():
        # Fall back to legacy single-file snapshot.
        rng_path = resolved / "rng_state.pt"
    if not rng_path.is_file():
        logger.warning(
            "No rng_state in %s; skipping "
            "RNG snapshot restore.",
            resolved,
        )
        return

    rng = torch.load(
        rng_path,
        map_location="cpu",
        weights_only=False,
    )
    if "torch_rng" in rng:
        torch.set_rng_state(rng["torch_rng"])
    if "python_rng" in rng:
        random.setstate(rng["python_rng"])
    if "numpy_rng" in rng:
        np.random.set_state(rng["numpy_rng"])
    # Guard these like the states above: a legacy snapshot may lack the
    # CUDA / generator entries, and an unguarded lookup raised KeyError.
    if "cuda_rng" in rng:
        torch.cuda.set_rng_state(rng["cuda_rng"])
    if "gen_cuda" in rng:
        self.method.cuda_generator.set_state(rng["gen_cuda"])
    logger.info(
        "Restored RNG snapshot from %s",
        rng_path,
    )
Functions
fastvideo.train.utils.config

Training run config (_target_ based YAML).

Classes
fastvideo.train.utils.config.RunConfig dataclass
RunConfig(models: dict[str, dict[str, Any]], method: dict[str, Any], training: TrainingConfig, callbacks: dict[str, dict[str, Any]], raw: dict[str, Any])

Parsed run config loaded from YAML.

Functions
fastvideo.train.utils.config.RunConfig.resolved_config
resolved_config() -> dict[str, Any]

Return a fully-resolved config dict with defaults.

Suitable for logging to W&B so that every parameter (including defaults) is visible.

Source code in fastvideo/train/utils/config.py
def resolved_config(self) -> dict[str, Any]:
    """Return a fully-resolved config dict with defaults.

    Suitable for logging to W&B so that every parameter
    (including defaults) is visible.
    """
    import dataclasses

    def _safe_asdict(obj: Any) -> Any:
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return {
                f.name: _safe_asdict(getattr(obj, f.name))
                for f in dataclasses.fields(obj) if not callable(getattr(obj, f.name))
            }
        if isinstance(obj, dict):
            return {k: _safe_asdict(v) for k, v in obj.items()}
        if isinstance(obj, list | tuple):
            return type(obj)(_safe_asdict(v) for v in obj)
        return obj

    resolved: dict[str, Any] = {}
    resolved["models"] = dict(self.models)
    resolved["method"] = dict(self.method)
    resolved["training"] = _safe_asdict(self.training)
    resolved["callbacks"] = dict(self.callbacks)
    return resolved
Functions
fastvideo.train.utils.config.load_run_config
load_run_config(path: str, overrides: list[str] | None = None) -> RunConfig

Load a run config from YAML.

Expected top-level keys: models, method, training (nested), and optionally callbacks and pipeline.

Parameters:

Name Type Description Default
path str

Path to the YAML config file.

required
overrides list[str] | None

Optional list of CLI override tokens, e.g. ["--training.distributed.num_gpus", "4"]. Dotted keys map to nested YAML paths.

None
Source code in fastvideo/train/utils/config.py
def load_run_config(
    path: str,
    overrides: list[str] | None = None,
) -> RunConfig:
    """Load a run config from YAML.

    Expected top-level keys: ``models``, ``method``,
    ``training`` (nested), and optionally ``callbacks``
    and ``pipeline``.

    Args:
        path: Path to the YAML config file.
        overrides: Optional list of CLI override tokens,
            e.g. ``["--training.distributed.num_gpus", "4"]``.
            Dotted keys map to nested YAML paths.
    """
    path = _resolve_existing_file(path)
    with open(path, encoding="utf-8") as f:
        cfg = _require_mapping(yaml.safe_load(f), where=path)

    # Apply CLI overrides before building the typed config.
    if overrides:
        parsed = _parse_cli_overrides(overrides)
        _apply_overrides(cfg, parsed)
        logger.info("Applied CLI overrides: %s", parsed)

    # --- models ---
    models: dict[str, dict[str, Any]] = {}
    for role, raw_model in _require_mapping(cfg.get("models"), where="models").items():
        role_name = _require_str(role, where="models.<role>")
        parsed_model = _require_mapping(raw_model, where=f"models.{role_name}")
        if "_target_" not in parsed_model:
            raise ValueError(f"models.{role_name} must have a "
                             "'_target_' key")
        models[role_name] = dict(parsed_model)

    # --- method ---
    method_raw = _require_mapping(cfg.get("method"), where="method")
    if "_target_" not in method_raw:
        raise ValueError("method must have a '_target_' key")
    method = dict(method_raw)

    # --- callbacks (optional section) ---
    raw_callbacks = cfg.get("callbacks", None)
    if raw_callbacks is None:
        callbacks: dict[str, dict[str, Any]] = {}
    else:
        callbacks = _require_mapping(raw_callbacks, where="callbacks")

    # --- pipeline + training config ---
    pipeline_config = _parse_pipeline_config(cfg, models=models)
    training = _build_training_config(
        dict(_require_mapping(cfg.get("training"), where="training")),
        models=models,
        pipeline_config=pipeline_config,
    )

    return RunConfig(
        models=models,
        method=method,
        training=training,
        callbacks=callbacks,
        raw=cfg,
    )
fastvideo.train.utils.config.require_bool
require_bool(mapping: dict[str, Any], key: str, *, default: bool | None = None, where: str | None = None) -> bool

Read a bool value.

Source code in fastvideo/train/utils/config.py
def require_bool(
    mapping: dict[str, Any],
    key: str,
    *,
    default: bool | None = None,
    where: str | None = None,
) -> bool:
    """Read a bool value."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            return default
        raise ValueError(f"Missing required key {loc!r}")
    if not isinstance(raw, bool):
        raise ValueError(f"{loc} must be a bool, "
                         f"got {type(raw).__name__}")
    return raw
fastvideo.train.utils.config.require_choice
require_choice(mapping: dict[str, Any], key: str, choices: set[str] | frozenset[str], *, default: str | None = None, where: str | None = None) -> str

Read a string that must be one of choices.

Source code in fastvideo/train/utils/config.py
def require_choice(
    mapping: dict[str, Any],
    key: str,
    choices: set[str] | frozenset[str],
    *,
    default: str | None = None,
    where: str | None = None,
) -> str:
    """Read a string that must be one of *choices*."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            if default not in choices:
                raise ValueError(f"Default {default!r} not in {choices}")
            return default
        raise ValueError(f"Missing required key {loc!r}")
    if not isinstance(raw, str) or not raw.strip():
        raise ValueError(f"{loc} must be a non-empty string, "
                         f"got {type(raw).__name__}")
    val = raw.strip().lower()
    if val not in choices:
        raise ValueError(f"{loc} must be one of {sorted(choices)}, "
                         f"got {raw!r}")
    return val
fastvideo.train.utils.config.require_non_negative_float
require_non_negative_float(mapping: dict[str, Any], key: str, *, default: float | None = None, where: str | None = None) -> float

Read a float that must be >= 0.

Source code in fastvideo/train/utils/config.py
def require_non_negative_float(
    mapping: dict[str, Any],
    key: str,
    *,
    default: float | None = None,
    where: str | None = None,
) -> float:
    """Read a float that must be >= 0."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is None:
            raise ValueError(f"Missing required key {loc!r}")
        return default
    # Delegate coercion/validation, then enforce the sign constraint.
    value = get_optional_float(mapping, key, where=loc)
    if value is not None and value >= 0.0:
        return value
    raise ValueError(f"{loc} must be a non-negative float, "
                     f"got {raw!r}")
fastvideo.train.utils.config.require_non_negative_int
require_non_negative_int(mapping: dict[str, Any], key: str, *, default: int | None = None, where: str | None = None) -> int

Read an int that must be >= 0.

Source code in fastvideo/train/utils/config.py
def require_non_negative_int(
    mapping: dict[str, Any],
    key: str,
    *,
    default: int | None = None,
    where: str | None = None,
) -> int:
    """Read an int that must be >= 0."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is None:
            raise ValueError(f"Missing required key {loc!r}")
        return default
    # Delegate coercion/validation, then enforce the sign constraint.
    value = get_optional_int(mapping, key, where=loc)
    if value is not None and value >= 0:
        return value
    raise ValueError(f"{loc} must be a non-negative integer, "
                     f"got {raw!r}")
fastvideo.train.utils.config.require_positive_int
require_positive_int(mapping: dict[str, Any], key: str, *, default: int | None = None, where: str | None = None) -> int

Read an int that must be > 0.

Source code in fastvideo/train/utils/config.py
def require_positive_int(
    mapping: dict[str, Any],
    key: str,
    *,
    default: int | None = None,
    where: str | None = None,
) -> int:
    """Read an int that must be > 0."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is None:
            raise ValueError(f"Missing required key {loc!r}")
        return default
    # Delegate coercion/validation, then enforce strict positivity.
    value = get_optional_int(mapping, key, where=loc)
    if value is not None and value > 0:
        return value
    raise ValueError(f"{loc} must be a positive integer, got {raw!r}")
fastvideo.train.utils.dataloader
Functions
fastvideo.train.utils.dataloader.build_parquet_t2v_train_dataloader
build_parquet_t2v_train_dataloader(data_config: DataConfig, *, text_len: int, parquet_schema: Any) -> Any

Build a parquet dataloader for T2V-style datasets.

Source code in fastvideo/train/utils/dataloader.py
def build_parquet_t2v_train_dataloader(
    data_config: DataConfig,
    *,
    text_len: int,
    parquet_schema: Any,
) -> Any:
    """Build a parquet dataloader for T2V-style datasets."""

    from fastvideo.dataset import (
        build_parquet_map_style_dataloader, )

    # The dataset object is discarded; callers only need the loader.
    _dataset, loader = build_parquet_map_style_dataloader(
        data_config.data_path,
        data_config.train_batch_size,
        num_data_workers=data_config.dataloader_num_workers,
        parquet_schema=parquet_schema,
        cfg_rate=data_config.training_cfg_rate,
        drop_last=True,
        text_padding_length=int(text_len),
        seed=int(data_config.seed or 0),
    )
    return loader
fastvideo.train.utils.instantiate

_target_-based instantiation utilities.

These helpers resolve a dotted Python path to a class and instantiate it, filtering constructor kwargs through inspect.signature so that only recognized parameters are forwarded. Unrecognized keys emit a warning rather than raising — this keeps YAML configs forward-compatible when a class drops a parameter in a later version.

Functions
fastvideo.train.utils.instantiate.instantiate
instantiate(cfg: dict[str, Any], **extra: Any) -> Any

Instantiate the class specified by cfg["_target_"].

All remaining keys in cfg (minus _target_) plus any extra keyword arguments are forwarded to the constructor. Keys that do not match an __init__ parameter trigger a warning and are dropped, so callers can safely pass a superset.

Source code in fastvideo/train/utils/instantiate.py
def instantiate(cfg: dict[str, Any], **extra: Any) -> Any:
    """Instantiate the class specified by ``cfg["_target_"]``.

    All remaining keys in *cfg* (minus ``_target_``) plus any *extra*
    keyword arguments are forwarded to the constructor.  Keys that do
    not match an ``__init__`` parameter trigger a warning and are
    dropped, so callers can safely pass a superset.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f"instantiate() expects a dict with '_target_', "
                        f"got {type(cfg).__name__}")
    target_str = cfg.get("_target_")
    if target_str is None:
        raise KeyError("Config dict is missing '_target_' key")

    cls = resolve_target(str(target_str))
    kwargs: dict[str, Any] = {key: val for key, val in cfg.items() if key != "_target_"}
    kwargs.update(extra)

    sig = inspect.signature(cls.__init__)  # type: ignore[misc]
    # A **kwargs parameter accepts anything, so no filtering is needed.
    accepts_var_kw = any(param.kind == inspect.Parameter.VAR_KEYWORD
                         for param in sig.parameters.values())
    if not accepts_var_kw:
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        accepted = {
            name
            for name, param in sig.parameters.items()
            if param.kind in keyword_kinds
        }
        accepted.discard("self")
        dropped = set(kwargs) - accepted
        if dropped:
            warnings.warn(
                f"instantiate({target_str}): dropping unrecognized "
                f"kwargs {sorted(dropped)}",
                stacklevel=2,
            )
            for key in dropped:
                del kwargs[key]

    return cls(**kwargs)
fastvideo.train.utils.instantiate.resolve_target
resolve_target(target: str) -> type

Import and return the class (or callable) at target.

target must be a fully-qualified dotted path, e.g. "fastvideo.train.models.wan.wan.WanModel".

Source code in fastvideo/train/utils/instantiate.py
def resolve_target(target: str) -> type:
    """Import and return the class (or callable) at *target*.

    *target* must be a fully-qualified dotted path, e.g.
    ``"fastvideo.train.models.wan.wan.WanModel"``.
    """
    if not isinstance(target, str) or not target.strip():
        raise ValueError(f"_target_ must be a non-empty dotted path string, "
                         f"got {target!r}")
    target = target.strip()
    # Split on the last dot: everything before is the module path.
    module_path, sep, attr_name = target.rpartition(".")
    if not sep:
        raise ValueError(f"_target_ must contain at least one dot "
                         f"(module.ClassName), got {target!r}")
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as exc:
        raise ImportError(f"Cannot import module {module_path!r} "
                          f"(from _target_={target!r})") from exc
    try:
        return getattr(module, attr_name)
    except AttributeError as exc:
        raise ImportError(f"Module {module_path!r} has no attribute "
                          f"{attr_name!r} (from _target_={target!r})") from exc
fastvideo.train.utils.module_state
Functions
fastvideo.train.utils.module_state.apply_trainable
apply_trainable(module: Module, *, trainable: bool) -> Module

Apply train/eval mode + requires_grad based on a role's trainable flag.

Source code in fastvideo/train/utils/module_state.py
def apply_trainable(module: torch.nn.Module, *, trainable: bool) -> torch.nn.Module:
    """Apply train/eval mode + requires_grad based on a role's trainable flag."""

    flag = bool(trainable)
    module.requires_grad_(flag)
    # Module.train(False) is exactly Module.eval().
    module.train(flag)
    return module
fastvideo.train.utils.moduleloader
Classes
Functions
fastvideo.train.utils.moduleloader.load_module_from_path
load_module_from_path(*, model_path: str, module_type: str, training_config: TrainingConfig, disable_custom_init_weights: bool = False, override_transformer_cls_name: str | None = None, transformer_override_safetensor: str | None = None) -> Module

Load a single pipeline component module.

Accepts a TrainingConfig and internally builds the TrainingArgs needed by PipelineComponentLoader.

Source code in fastvideo/train/utils/moduleloader.py
def load_module_from_path(
    *,
    model_path: str,
    module_type: str,
    training_config: TrainingConfig,
    disable_custom_init_weights: bool = False,
    override_transformer_cls_name: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> torch.nn.Module:
    """Load a single pipeline component module.

    Accepts a ``TrainingConfig`` and internally builds the
    ``TrainingArgs`` needed by ``PipelineComponentLoader``.

    Args:
        model_path: Model repo/path; downloaded if not local.
        module_type: Component key in the model's config
            (e.g. a pipeline component name).
        training_config: Training config used to build loader args.
        disable_custom_init_weights: If True, temporarily marks the
            args as loading a teacher/critic model so custom weight
            init is skipped.
        override_transformer_cls_name: Optional transformer class-name
            override, applied for the duration of the load only.
        transformer_override_safetensor: Optional safetensors file to
            initialize weights from.

    Raises:
        ValueError: If *module_type* is absent or null in the model
            config.
        TypeError: If the loaded component is not a torch.nn.Module.
    """
    fastvideo_args: Any = _make_training_args(training_config, model_path=model_path)

    local_model_path = maybe_download_model(model_path)
    config = verify_model_config_and_directory(local_model_path)

    if module_type not in config:
        raise ValueError(f"Module {module_type!r} not found in "
                         f"config at {local_model_path}")

    module_info = config[module_type]
    if module_info is None:
        raise ValueError(f"Module {module_type!r} has null value in "
                         f"config at {local_model_path}")

    # config entry is a (library, architecture) pair; only the library
    # tag is needed by the loader here.
    transformers_or_diffusers, _architecture = module_info
    component_path = os.path.join(local_model_path, module_type)

    # Save the previous override so it can be restored after loading —
    # fastvideo_args may be shared state downstream.
    old_override: str | None = None
    if override_transformer_cls_name is not None:
        old_override = getattr(
            fastvideo_args,
            "override_transformer_cls_name",
            None,
        )
        fastvideo_args.override_transformer_cls_name = str(override_transformer_cls_name)

    if transformer_override_safetensor:
        fastvideo_args.init_weights_from_safetensors = str(transformer_override_safetensor)

    if disable_custom_init_weights:
        # Presumably signals the loader to skip custom init; cleared in
        # the finally block below.
        fastvideo_args._loading_teacher_critic_model = True
    try:
        module = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_path,
            transformers_or_diffusers=(transformers_or_diffusers),
            fastvideo_args=fastvideo_args,
        )
    finally:
        # Undo the temporary flag/override mutations even on failure.
        if disable_custom_init_weights and hasattr(fastvideo_args, "_loading_teacher_critic_model"):
            del fastvideo_args._loading_teacher_critic_model
        if override_transformer_cls_name is not None:
            if old_override is None:
                if hasattr(
                        fastvideo_args,
                        "override_transformer_cls_name",
                ):
                    fastvideo_args.override_transformer_cls_name = (None)
            else:
                fastvideo_args.override_transformer_cls_name = (old_override)

    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"Loaded {module_type!r} is not a "
                        f"torch.nn.Module: {type(module)}")
    return module
fastvideo.train.utils.moduleloader.make_inference_args
make_inference_args(tc: TrainingConfig, *, model_path: str) -> TrainingArgs

Build a TrainingArgs for inference (validation / pipelines).

Source code in fastvideo/train/utils/moduleloader.py
def make_inference_args(
    tc: TrainingConfig,
    *,
    model_path: str,
) -> TrainingArgs:
    """Build a TrainingArgs for inference (validation / pipelines)."""
    inference_args = _make_training_args(tc, model_path=model_path)
    # Flip the shared args object into inference mode with DiT offload.
    inference_args.inference_mode = True
    inference_args.mode = ExecutionMode.INFERENCE
    inference_args.dit_cpu_offload = True
    inference_args.VSA_sparsity = tc.vsa_sparsity
    return inference_args
fastvideo.train.utils.optimizer
Functions
fastvideo.train.utils.optimizer.build_optimizer_and_scheduler
build_optimizer_and_scheduler(*, params: list[Parameter], optimizer_config: OptimizerConfig, loop_config: TrainingLoopConfig, learning_rate: float, betas: tuple[float, float], scheduler_name: str) -> tuple[Optimizer, object]

Build an AdamW optimizer and LR scheduler.

Returns (optimizer, lr_scheduler) so the caller can store them as method-level attributes.

Source code in fastvideo/train/utils/optimizer.py
def build_optimizer_and_scheduler(
    *,
    params: list[torch.nn.Parameter],
    optimizer_config: OptimizerConfig,
    loop_config: TrainingLoopConfig,
    learning_rate: float,
    betas: tuple[float, float],
    scheduler_name: str,
) -> tuple[torch.optim.Optimizer, object]:
    """Build an AdamW optimizer and LR scheduler.

    Returns ``(optimizer, lr_scheduler)`` so the caller can store them
    as method-level attributes.

    Raises:
        ValueError: If *params* is empty.
    """
    if not params:
        raise ValueError("No trainable parameters passed to "
                         "build_optimizer_and_scheduler")

    adamw = torch.optim.AdamW(
        params,
        lr=float(learning_rate),
        betas=betas,
        weight_decay=float(optimizer_config.weight_decay),
        eps=1e-8,
    )

    lr_scheduler = get_scheduler(
        str(scheduler_name),
        optimizer=adamw,
        num_warmup_steps=int(optimizer_config.lr_warmup_steps),
        num_training_steps=int(loop_config.max_train_steps),
        num_cycles=int(optimizer_config.lr_num_cycles),
        power=float(optimizer_config.lr_power),
        min_lr_ratio=float(optimizer_config.min_lr_ratio),
        last_epoch=-1,
    )

    return adamw, lr_scheduler
fastvideo.train.utils.tracking
Functions
fastvideo.train.utils.tracking.build_tracker
build_tracker(tracker_config: TrackerConfig, checkpoint_config: CheckpointConfig, *, config: dict[str, Any] | None) -> Any

Build a tracker instance for a distillation run.

Source code in fastvideo/train/utils/tracking.py
def build_tracker(
    tracker_config: TrackerConfig,
    checkpoint_config: CheckpointConfig,
    *,
    config: dict[str, Any] | None,
) -> Any:
    """Build a tracker instance for a distillation run."""

    world_group = get_world_group()

    active = list(tracker_config.trackers)
    # A non-empty project name alone implies W&B tracking.
    if not active and str(tracker_config.project_name):
        active.append(Trackers.WANDB.value)
    # Only rank 0 reports to external trackers.
    if world_group.rank != 0:
        active = []

    log_dir = checkpoint_config.output_dir or os.getcwd()
    if active:
        log_dir = os.path.join(log_dir, "tracker")

    return initialize_trackers(
        active,
        experiment_name=tracker_config.project_name or "fastvideo",
        config=config if active else None,
        log_dir=log_dir,
        run_name=tracker_config.run_name or None,
    )
fastvideo.train.utils.training_config

Typed training config — replaces TrainingArgs.

Classes