Skip to content

train

Modules

fastvideo.train.callbacks

Classes

fastvideo.train.callbacks.Callback

Base callback with no-op hooks.

Subclasses override whichever hooks they need. The training_config and method attributes are set by CallbackDict after instantiation.

fastvideo.train.callbacks.CallbackDict
CallbackDict(callback_configs: dict[str, dict[str, Any]], training_config: TrainingConfig)

Manages a collection of named callbacks.

Instantiates each callback from its _target_ config and dispatches hook calls to all registered callbacks.

Source code in fastvideo/train/callbacks/callback.py
def __init__(
    self,
    callback_configs: dict[str, dict[str, Any]],
    training_config: TrainingConfig,
) -> None:
    """Instantiate every configured callback and register it by name.

    Args:
        callback_configs: Mapping from callback name to its config dict.
            A config without ``_target_`` falls back to the built-in
            target registered for that name; if no built-in exists, the
            entry is skipped with a warning.
        training_config: Attached to each callback as
            ``training_config``.

    Raises:
        TypeError: If a config instantiates to something that is not a
            ``Callback``.
    """
    self._callbacks: dict[str, Callback] = {}
    if not callback_configs:
        return
    for cb_name, raw_cfg in callback_configs.items():
        # Shallow copy so injecting a default ``_target_`` never mutates
        # the caller's config.
        resolved_cfg = dict(raw_cfg)
        has_target = "_target_" in resolved_cfg
        if not has_target and cb_name not in _BUILTIN_CALLBACKS:
            logger.warning(
                "Callback %r is missing "
                "'_target_', skipping: %s",
                cb_name,
                resolved_cfg,
            )
            continue
        if not has_target:
            resolved_cfg["_target_"] = _BUILTIN_CALLBACKS[cb_name]
        logger.info(
            "Instantiating callback %r: %s",
            cb_name,
            resolved_cfg,
        )
        callback = instantiate(resolved_cfg)
        if not isinstance(callback, Callback):
            raise TypeError(f"Callback {cb_name!r} resolved to "
                            f"{type(callback).__name__}, expected a "
                            f"Callback subclass.")
        # Wire the callback back to its config, owner and registry name.
        callback.training_config = training_config
        callback._callback_dict = self
        callback.name = cb_name
        self._callbacks[cb_name] = callback
fastvideo.train.callbacks.EMACallback
EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

.. code-block:: yaml

callbacks:
  ema:
    decay: 0.9999
    start_iter: 0

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    """Store EMA settings; the ``EMA_FSDP`` instance is created later.

    Args:
        decay: EMA decay factor, coerced to ``float``.
        start_iter: Coerced to ``int``; presumably the iteration from
            which EMA updates begin (TODO confirm against the update hook).
    """
    # Coerce eagerly so bad YAML values fail here rather than mid-run.
    self._decay, self._start_iter = float(decay), int(start_iter)
    # Checked by ema_context(); presumably set True once EMA has begun.
    self._ema_started = False
    # Shadow-weight holder; stays None until created at train start.
    self.student_ema: EMA_FSDP | None = None
Methods:
fastvideo.train.callbacks.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Yield *transformer* with EMA weights swapped in, if EMA is active.

    EMA counts as active only when the shadow model exists and EMA has
    started; otherwise the transformer is yielded untouched. The swap
    itself is delegated to ``student_ema.apply_to_model``.
    """
    ema_ready = self.student_ema is not None and self._ema_started
    if not ema_ready:
        yield transformer
        return
    with self.student_ema.apply_to_model(transformer):
        yield transformer
fastvideo.train.callbacks.GradNormClipCallback
GradNormClipCallback(*, max_grad_norm: float = 1.0, log_grad_norms: bool = True)

Bases: Callback

Clip gradient norms before the optimizer step.

max_grad_norm defaults to 1.0; override it in the callback config via callbacks.grad_clip.max_grad_norm.

Source code in fastvideo/train/callbacks/grad_clip.py
def __init__(
    self,
    *,
    max_grad_norm: float = 1.0,
    log_grad_norms: bool = True,
) -> None:
    """Record the clipping threshold and the grad-norm logging switch.

    Args:
        max_grad_norm: Gradient-norm threshold used when clipping;
            coerced to ``float``. Defaults to ``1.0``.
        log_grad_norms: Whether per-module grad norms are logged to the
            tracker; coerced to ``bool``.
    """
    threshold = float(max_grad_norm)
    self._max_grad_norm = threshold
    self._log_grad_norms = bool(log_grad_norms)
fastvideo.train.callbacks.ValidationCallback
ValidationCallback(*, pipeline_target: str, dataset_file: str, every_steps: int = 100, sampling_steps: list[int] | None = None, guidance_scale: float | None = None, num_frames: int | None = None, output_dir: str | None = None, sampling_timesteps: list[int] | None = None, **pipeline_kwargs: Any)

Bases: Callback

Generic validation callback driven entirely by YAML config.

Works with any pipeline that follows the PipelineCls.from_pretrained(...) + pipeline.forward() contract.

Source code in fastvideo/train/callbacks/validation.py
def __init__(
    self,
    *,
    pipeline_target: str,
    dataset_file: str,
    every_steps: int = 100,
    sampling_steps: list[int] | None = None,
    guidance_scale: float | None = None,
    num_frames: int | None = None,
    output_dir: str | None = None,
    sampling_timesteps: list[int] | None = None,
    **pipeline_kwargs: Any,
) -> None:
    """Normalise the ``callbacks.validation`` YAML settings.

    Explicit options are coerced to their declared types. Optional
    scalars stay ``None`` when omitted, and any extra keys are kept in
    ``pipeline_kwargs``. Pipeline-related state starts empty.

    Args:
        pipeline_target: Path from which the pipeline class is resolved.
        dataset_file: Validation dataset file path.
        every_steps: Presumably the validation interval in training steps.
        sampling_steps: Step counts to sample with. An empty or missing
            list falls back to ``[40]``.
        guidance_scale: Optional guidance scale.
        num_frames: Optional frame count.
        output_dir: Optional output directory.
        sampling_timesteps: Optional explicit timesteps. Unlike
            ``sampling_steps``, an empty list is kept as-is.
        **pipeline_kwargs: Remaining YAML keys, stored verbatim.
    """
    self.pipeline_target = str(pipeline_target)
    self.dataset_file = str(dataset_file)
    self.every_steps = int(every_steps)

    # Truthiness test: both None and [] fall back to a single 40-step run.
    if sampling_steps:
        self.sampling_steps = [int(step) for step in sampling_steps]
    else:
        self.sampling_steps = [40]

    self.guidance_scale = None if guidance_scale is None else float(guidance_scale)
    self.num_frames = None if num_frames is None else int(num_frames)
    self.output_dir = None if output_dir is None else str(output_dir)
    self.sampling_timesteps = (None if sampling_timesteps is None else [int(t) for t in sampling_timesteps])
    self.pipeline_kwargs = dict(pipeline_kwargs)

    # Placeholders populated once training starts (on_train_start).
    self._pipeline: Any | None = None
    self._pipeline_key: tuple[Any, ...] | None = None
    self._sampling_param: SamplingParam | None = None
    self.tracker: Any = DummyTracker()
    self.validation_random_generator: torch.Generator | None = None
    self.seed: int = 0

Modules

fastvideo.train.callbacks.callback

Callback base class and CallbackDict manager.

Adapted from FastGen's callback pattern to FastVideo's types.

Classes
fastvideo.train.callbacks.callback.Callback

Base callback with no-op hooks.

Subclasses override whichever hooks they need. The training_config and method attributes are set by CallbackDict after instantiation.

fastvideo.train.callbacks.callback.CallbackDict
CallbackDict(callback_configs: dict[str, dict[str, Any]], training_config: TrainingConfig)

Manages a collection of named callbacks.

Instantiates each callback from its _target_ config and dispatches hook calls to all registered callbacks.

Source code in fastvideo/train/callbacks/callback.py
def __init__(
    self,
    callback_configs: dict[str, dict[str, Any]],
    training_config: TrainingConfig,
) -> None:
    """Instantiate every configured callback and register it by name.

    Args:
        callback_configs: Mapping from callback name to its config dict.
            A config without ``_target_`` falls back to the built-in
            target registered for that name; if no built-in exists, the
            entry is skipped with a warning.
        training_config: Attached to each callback as
            ``training_config``.

    Raises:
        TypeError: If a config instantiates to something that is not a
            ``Callback``.
    """
    self._callbacks: dict[str, Callback] = {}
    if not callback_configs:
        return
    for cb_name, raw_cfg in callback_configs.items():
        # Shallow copy so injecting a default ``_target_`` never mutates
        # the caller's config.
        resolved_cfg = dict(raw_cfg)
        has_target = "_target_" in resolved_cfg
        if not has_target and cb_name not in _BUILTIN_CALLBACKS:
            logger.warning(
                "Callback %r is missing "
                "'_target_', skipping: %s",
                cb_name,
                resolved_cfg,
            )
            continue
        if not has_target:
            resolved_cfg["_target_"] = _BUILTIN_CALLBACKS[cb_name]
        logger.info(
            "Instantiating callback %r: %s",
            cb_name,
            resolved_cfg,
        )
        callback = instantiate(resolved_cfg)
        if not isinstance(callback, Callback):
            raise TypeError(f"Callback {cb_name!r} resolved to "
                            f"{type(callback).__name__}, expected a "
                            f"Callback subclass.")
        # Wire the callback back to its config, owner and registry name.
        callback.training_config = training_config
        callback._callback_dict = self
        callback.name = cb_name
        self._callbacks[cb_name] = callback
Functions:
fastvideo.train.callbacks.ema

EMA (Exponential Moving Average) callback.

Owns the full EMA lifecycle: creation, per-step updates, weight swapping for validation, and checkpoint state. All EMA config lives under callbacks.ema in the YAML file.

Classes
fastvideo.train.callbacks.ema.EMACallback
EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

.. code-block:: yaml

callbacks:
  ema:
    decay: 0.9999
    start_iter: 0

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    """Store EMA settings; the ``EMA_FSDP`` instance is created later.

    Args:
        decay: EMA decay factor, coerced to ``float``.
        start_iter: Coerced to ``int``; presumably the iteration from
            which EMA updates begin (TODO confirm against the update hook).
    """
    # Coerce eagerly so bad YAML values fail here rather than mid-run.
    self._decay, self._start_iter = float(decay), int(start_iter)
    # Checked by ema_context(); presumably set True once EMA has begun.
    self._ema_started = False
    # Shadow-weight holder; stays None until created at train start.
    self.student_ema: EMA_FSDP | None = None
Methods:
fastvideo.train.callbacks.ema.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Yield *transformer* with EMA weights swapped in, if EMA is active.

    EMA counts as active only when the shadow model exists and EMA has
    started; otherwise the transformer is yielded untouched. The swap
    itself is delegated to ``student_ema.apply_to_model``.
    """
    ema_ready = self.student_ema is not None and self._ema_started
    if not ema_ready:
        yield transformer
        return
    with self.student_ema.apply_to_model(transformer):
        yield transformer
Functions:
fastvideo.train.callbacks.grad_clip

Gradient norm clipping callback.

Clips gradients on modules returned by method.get_grad_clip_targets() before the optimizer step. Optionally logs per-module grad norms to the tracker.

Classes
fastvideo.train.callbacks.grad_clip.GradNormClipCallback
GradNormClipCallback(*, max_grad_norm: float = 1.0, log_grad_norms: bool = True)

Bases: Callback

Clip gradient norms before the optimizer step.

max_grad_norm defaults to 1.0; override it in the callback config via callbacks.grad_clip.max_grad_norm.

Source code in fastvideo/train/callbacks/grad_clip.py
def __init__(
    self,
    *,
    max_grad_norm: float = 1.0,
    log_grad_norms: bool = True,
) -> None:
    """Record the clipping threshold and the grad-norm logging switch.

    Args:
        max_grad_norm: Gradient-norm threshold used when clipping;
            coerced to ``float``. Defaults to ``1.0``.
        log_grad_norms: Whether per-module grad norms are logged to the
            tracker; coerced to ``bool``.
    """
    threshold = float(max_grad_norm)
    self._max_grad_norm = threshold
    self._log_grad_norms = bool(log_grad_norms)
Functions:
fastvideo.train.callbacks.validation

Validation callback.

All configuration is read from the YAML callbacks.validation section. The pipeline class is resolved from pipeline_target.

Classes
fastvideo.train.callbacks.validation.ValidationCallback
ValidationCallback(*, pipeline_target: str, dataset_file: str, every_steps: int = 100, sampling_steps: list[int] | None = None, guidance_scale: float | None = None, num_frames: int | None = None, output_dir: str | None = None, sampling_timesteps: list[int] | None = None, **pipeline_kwargs: Any)

Bases: Callback

Generic validation callback driven entirely by YAML config.

Works with any pipeline that follows the PipelineCls.from_pretrained(...) + pipeline.forward() contract.

Source code in fastvideo/train/callbacks/validation.py
def __init__(
    self,
    *,
    pipeline_target: str,
    dataset_file: str,
    every_steps: int = 100,
    sampling_steps: list[int] | None = None,
    guidance_scale: float | None = None,
    num_frames: int | None = None,
    output_dir: str | None = None,
    sampling_timesteps: list[int] | None = None,
    **pipeline_kwargs: Any,
) -> None:
    """Normalise the ``callbacks.validation`` YAML settings.

    Explicit options are coerced to their declared types. Optional
    scalars stay ``None`` when omitted, and any extra keys are kept in
    ``pipeline_kwargs``. Pipeline-related state starts empty.

    Args:
        pipeline_target: Path from which the pipeline class is resolved.
        dataset_file: Validation dataset file path.
        every_steps: Presumably the validation interval in training steps.
        sampling_steps: Step counts to sample with. An empty or missing
            list falls back to ``[40]``.
        guidance_scale: Optional guidance scale.
        num_frames: Optional frame count.
        output_dir: Optional output directory.
        sampling_timesteps: Optional explicit timesteps. Unlike
            ``sampling_steps``, an empty list is kept as-is.
        **pipeline_kwargs: Remaining YAML keys, stored verbatim.
    """
    self.pipeline_target = str(pipeline_target)
    self.dataset_file = str(dataset_file)
    self.every_steps = int(every_steps)

    # Truthiness test: both None and [] fall back to a single 40-step run.
    if sampling_steps:
        self.sampling_steps = [int(step) for step in sampling_steps]
    else:
        self.sampling_steps = [40]

    self.guidance_scale = None if guidance_scale is None else float(guidance_scale)
    self.num_frames = None if num_frames is None else int(num_frames)
    self.output_dir = None if output_dir is None else str(output_dir)
    self.sampling_timesteps = (None if sampling_timesteps is None else [int(t) for t in sampling_timesteps])
    self.pipeline_kwargs = dict(pipeline_kwargs)

    # Placeholders populated once training starts (on_train_start).
    self._pipeline: Any | None = None
    self._pipeline_key: tuple[Any, ...] | None = None
    self._sampling_param: SamplingParam | None = None
    self.tracker: Any = DummyTracker()
    self.validation_random_generator: torch.Generator | None = None
    self.seed: int = 0
Functions:

fastvideo.train.entrypoint

Modules

fastvideo.train.entrypoint.dcp_to_diffusers

Convert a DCP training checkpoint to a diffusers-style model directory.

Works on a single GPU regardless of how many GPUs were used for training (DCP handles resharding automatically).

Usage (no torchrun needed):

python -m fastvideo.train.entrypoint.dcp_to_diffusers \
    --checkpoint /path/to/checkpoint-1000 \
    --output-dir /path/to/diffusers_output

Or with torchrun (also fine):

torchrun --nproc_per_node=1 \
    -m fastvideo.train.entrypoint.dcp_to_diffusers \
    --checkpoint ... --output-dir ...

The checkpoint must contain metadata.json (written by CheckpointManager). If the checkpoint predates metadata support, pass --config explicitly to provide the training YAML.

Functions:
fastvideo.train.entrypoint.dcp_to_diffusers.convert
convert(*, checkpoint_dir: str, output_dir: str, config_path: str | None = None, role: str = 'student', overwrite: bool = False) -> str

Load a DCP checkpoint and export as a diffusers model.

Returns the path to the exported model directory.

Source code in fastvideo/train/entrypoint/dcp_to_diffusers.py
def convert(
    *,
    checkpoint_dir: str,
    output_dir: str,
    config_path: str | None = None,
    role: str = "student",
    overwrite: bool = False,
) -> str:
    """Load a DCP checkpoint and export as a diffusers model.

    Args:
        checkpoint_dir: Checkpoint location, resolved through
            ``_resolve_resume_checkpoint``; the resolved directory must
            contain a ``dcp/`` subdirectory.
        output_dir: Destination for the exported model.
        config_path: Optional training YAML. When omitted, the config
            stored in the checkpoint's ``metadata.json`` is used.
        role: Role model to export (default ``"student"``).
        overwrite: Forwarded to ``_save_role_pretrained``.

    Returns:
        The path to the exported model directory.

    Raises:
        FileNotFoundError: If the resolved checkpoint has no ``dcp/``.
        ValueError: If no config is available, or the config has no base
            model path.
        KeyError: If *role* is not one of the built role models.
    """
    _ensure_distributed()

    from fastvideo.distributed import (
        maybe_init_distributed_environment_and_model_parallel, )
    from fastvideo.train.utils.builder import build_from_config
    from fastvideo.train.utils.checkpoint import (
        CheckpointManager,
        _resolve_resume_checkpoint,
    )
    from fastvideo.train.utils.config import (
        RunConfig,
        load_run_config,
    )

    import torch.distributed.checkpoint as dcp

    # -- Resolve checkpoint directory --
    resolved = _resolve_resume_checkpoint(
        checkpoint_dir,
        output_dir=checkpoint_dir,
    )
    dcp_dir = resolved / "dcp"
    if not dcp_dir.is_dir():
        raise FileNotFoundError(f"Missing dcp/ under {resolved}")

    # -- Obtain config: an explicit YAML wins over checkpoint metadata --
    cfg: RunConfig
    if config_path is not None:
        cfg = load_run_config(config_path)
    else:
        metadata = CheckpointManager.load_metadata(resolved)
        raw_config = metadata.get("config")
        if raw_config is None:
            raise ValueError("Checkpoint metadata.json does not "
                             "contain 'config'. Pass --config "
                             "explicitly.")
        cfg = _run_config_from_raw(raw_config)

    tc = cfg.training

    # -- Init distributed (1 GPU is enough; DCP reshards) --
    maybe_init_distributed_environment_and_model_parallel(
        tp_size=1,
        sp_size=1,
    )

    # Override distributed config so model loading uses 1 GPU.
    tc.distributed.tp_size = 1
    tc.distributed.sp_size = 1
    tc.distributed.num_gpus = 1
    tc.distributed.hsdp_replicate_dim = 1
    tc.distributed.hsdp_shard_dim = 1

    # -- Build model (loads pretrained weights + FSDP) --
    _, method, _, _ = build_from_config(cfg)

    # -- Validate the export target before the (slow) DCP load --
    role_models = method._role_models
    if role not in role_models:
        raise KeyError(f"Unknown role {role!r}; available roles: "
                       f"{sorted(role_models)}")
    model = role_models[role]

    # Test the raw value: str(None) is the truthy string "None", so
    # stringifying first would let a missing model_path slip through.
    model_path = tc.model_path
    if not model_path:
        raise ValueError("Cannot determine base_model_path from "
                         "config. Ensure models.student.init_from "
                         "is set.")
    base_model_path = str(model_path)

    # -- Load DCP weights into the built model's state --
    states = method.checkpoint_state()
    logger.info(
        "Loading DCP checkpoint from %s",
        resolved,
    )
    dcp.load(states, checkpoint_id=str(dcp_dir))

    # -- Export to diffusers format --
    logger.info(
        "Exporting role=%s to %s (base=%s)",
        role,
        output_dir,
        base_model_path,
    )
    result = _save_role_pretrained(
        role=role,
        base_model_path=base_model_path,
        output_dir=output_dir,
        overwrite=overwrite,
        model=model,
    )
    logger.info("Export complete: %s", result)
    return result
fastvideo.train.entrypoint.misc
Modules
fastvideo.train.entrypoint.misc.wan_ode_init_conversion

Convert Self-Forcing ode_init.pt to HuggingFace diffusers format.

The official ode_init.pt from https://huggingface.co/gdhe17/Self-Forcing/resolve/main/checkpoints/ode_init.pt stores weights under {"generator": {<original_wan_keys>}}.

This script converts those keys to diffusers WanTransformer3DModel format, verifies them against a reference model, and saves a complete diffusers-compatible model directory (transformer + scheduler + vae + text_encoder + tokenizer).

Usage

python -m fastvideo.train.entrypoint.misc.wan_ode_init_conversion --input /path/to/ode_init.pt --output /path/to/WanOdeInit --base-model Wan-AI/Wan2.1-T2V-1.3B-Diffusers

Functions:
fastvideo.train.entrypoint.misc.wan_ode_init_conversion.convert_state_dict
convert_state_dict(orig_sd: dict[str, Tensor]) -> dict[str, Tensor]

Convert an entire original-Wan state dict.

Source code in fastvideo/train/entrypoint/misc/wan_ode_init_conversion.py
def convert_state_dict(orig_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Convert an entire original-Wan state dict.

    Every key is renamed with ``_convert_key``. Tensors are passed through
    unchanged (no copies) and insertion order is preserved.

    Raises:
        ValueError: If two original keys convert to the same new key,
            which would otherwise silently drop one of the tensors.
    """
    converted: dict[str, torch.Tensor] = {}
    source_key: dict[str, str] = {}
    for orig_key, tensor in orig_sd.items():
        new_key = _convert_key(orig_key)
        if new_key in converted:
            raise ValueError(f"Key collision: {source_key[new_key]!r} and "
                             f"{orig_key!r} both convert to {new_key!r}")
        converted[new_key] = tensor
        source_key[new_key] = orig_key
    return converted
fastvideo.train.entrypoint.train

YAML-only training entrypoint.

Usage:

torchrun --nproc_per_node=<N> -m fastvideo.train.entrypoint.train \
    --config path/to/run.yaml

Any unknown `--dotted.key value` arguments are applied as overrides to the YAML config before parsing. For example:

torchrun --nproc_per_node=8 -m fastvideo.train.entrypoint.train \
    --config path/to/run.yaml \
    --training.distributed.num_gpus 8 \
    --training.optimizer.learning_rate 1e-5
Functions:
fastvideo.train.entrypoint.train.run_training_from_config
run_training_from_config(config_path: str, *, dry_run: bool = False, overrides: list[str] | None = None) -> None

YAML-only training entrypoint (schema v2).

Source code in fastvideo/train/entrypoint/train.py
def run_training_from_config(
    config_path: str,
    *,
    dry_run: bool = False,
    overrides: list[str] | None = None,
) -> None:
    """YAML-only training entrypoint (schema v2).

    Loads the run config, initialises distributed state, builds the
    method and dataloader, then runs a ``Trainer`` with a
    ``CheckpointManager`` attached.

    Args:
        config_path: Path to the run YAML. It is also attached to the
            tracker under the name ``run.yaml``.
        dry_run: If true, return as soon as ``build_from_config``
            succeeds, without creating a trainer.
        overrides: Optional dotted-key overrides forwarded to
            ``load_run_config``.
    """
    from fastvideo.distributed import (
        maybe_init_distributed_environment_and_model_parallel, )
    from fastvideo.train import Trainer
    from fastvideo.train.utils.checkpoint import (
        CheckpointConfig,
        CheckpointManager,
    )
    from fastvideo.train.utils.builder import build_from_config
    from fastvideo.train.utils.config import load_run_config

    # Deterministic cuDNN for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    run_cfg = load_run_config(config_path, overrides=overrides)
    training_cfg = run_cfg.training

    # Some model families need a specific attention backend at load time.
    # setdefault keeps any FASTVIDEO_ATTENTION_BACKEND the user already set.
    model_path_lower = str(training_cfg.model_path).lower()
    backend: str | None = None
    if training_cfg.vsa_sparsity > 0.0:
        backend = "VIDEO_SPARSE_ATTN"
    elif "turbodiffusion" in model_path_lower or "turbowan" in model_path_lower:
        backend = "SLA_ATTN"
    if backend is not None:
        os.environ.setdefault("FASTVIDEO_ATTENTION_BACKEND", backend)

    dist_cfg = training_cfg.distributed
    maybe_init_distributed_environment_and_model_parallel(
        dist_cfg.tp_size,
        dist_cfg.sp_size,
    )

    _, method, dataloader, start_step = build_from_config(run_cfg)

    if dry_run:
        logger.info("Dry-run: config parsed and "
                    "build_from_config succeeded.")
        return

    trainer = Trainer(
        training_cfg,
        config=run_cfg.resolved_config(),
        callback_configs=run_cfg.callbacks,
    )

    # Attach the exact YAML used for this run to the tracker (e.g. W&B Files).
    run_yaml_path = os.path.abspath(os.path.expanduser(config_path))
    trainer.tracker.log_file(run_yaml_path, name="run.yaml")

    # Unset (None/0) checkpoint settings are normalised to 0.
    ckpt_settings = training_cfg.checkpoint
    ckpt_config = CheckpointConfig(
        save_steps=int(ckpt_settings.training_state_checkpointing_steps or 0),
        keep_last=int(ckpt_settings.checkpoints_total_limit or 0),
    )
    checkpoint_manager = CheckpointManager(
        method=method,
        dataloader=dataloader,
        output_dir=ckpt_settings.output_dir,
        config=ckpt_config,
        callbacks=trainer.callbacks,
        raw_config=run_cfg.raw,
    )

    trainer.run(
        method,
        dataloader=dataloader,
        max_steps=training_cfg.loop.max_train_steps,
        start_step=start_step,
        checkpoint_manager=checkpoint_manager,
    )

fastvideo.train.methods

Classes

fastvideo.train.methods.TrainingMethod
TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

The constructor receives role_models (a dict[str, ModelBase]) and a cfg object, and builds self.role_modules for FSDP wrapping. Calling init_preprocessors on the student is left to subclasses (DMD2Method, for example, does so in its constructor).

A single shared CUDA RNG generator (cuda_generator) is created in on_train_start(). All torch.randn / torch.randint calls in methods and models must use this generator instead of relying on global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Store role models and expose their transformers for FSDP.

    Args:
        cfg: Run config. ``cfg.training`` and ``cfg.method`` are read;
            ``cfg.validation`` is read when present (missing/None -> ``{}``).
        role_models: Role name -> model. Must contain ``"student"``.

    Raises:
        ValueError: If *role_models* has no ``"student"`` entry.
    """
    # Fail with a descriptive error (same form subclasses use for their
    # required roles) instead of a bare KeyError from the lookup below.
    if "student" not in role_models:
        raise ValueError(f"{type(self).__name__} requires role 'student'")
    super().__init__()
    self.tracker: Any | None = None
    # Shallow copy so later edits to the caller's dict don't leak in.
    self._role_models: dict[str, ModelBase] = dict(role_models)

    self.student = role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    self.validation_config: dict[str, Any] = dict(getattr(cfg, "validation", {}) or {})

    # Build nn.ModuleDict for FSDP / checkpoint visibility. Only roles whose
    # ``transformer`` attribute is an nn.Module get an entry.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        mods: dict[str, torch.nn.Module] = {}
        transformer = getattr(model, "transformer", None)
        if isinstance(transformer, torch.nn.Module):
            mods["transformer"] = transformer
        if mods:
            self.role_modules[role] = torch.nn.ModuleDict(mods)
Methods:
fastvideo.train.methods.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*. This method itself only emits the roles, optimizers, and schedulers entries.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Collect DCP-ready checkpoint state for every trainable role.

    For each role whose model has a truthy ``_trainable`` flag, this
    emits ``roles.<role>.transformer`` (when the model has a transformer),
    ``optimizers.<role>`` and ``schedulers.<role>`` (when registered for
    that role). No ``random_state.*`` entries are produced here.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.
    """
    states: dict[str, Any] = {}
    trainable_roles = ((role, model) for role, model in self._role_models.items()
                       if getattr(model, "_trainable", False))

    for role, model in trainable_roles:
        transformer = model.transformer
        modules: dict[str, torch.nn.Module] = ({} if transformer is None else {"transformer": transformer})
        # The optimizer wrapper needs a module view of the role's params.
        container = _RoleModuleContainer(modules)

        states.update({f"roles.{role}.{name}": ModelWrapper(mod) for name, mod in modules.items()})

        optimizer = self._optimizer_dict.get(role)
        if optimizer is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, optimizer)

        scheduler = self._lr_scheduler_dict.get(role)
        if scheduler is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(scheduler)

    return states
fastvideo.train.methods.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Map names to the modules whose gradients should be clipped.

    The base implementation ignores *iteration* and always returns the
    student transformer under the key ``"student"``. Subclasses override
    this to add modules (e.g. a critic) or include them conditionally.
    """
    student_transformer = self.student.transformer
    return dict(student=student_transformer)
fastvideo.train.methods.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (``step``, ``exp_avg``,
    ``exp_avg_sq`` -- plus ``max_exp_avg_sq`` with AMSGrad -- are only
    created on the first ``optimizer.step()``). DCP needs matching
    entries to load into; without them the saved optimizer state is
    silently dropped.

    The seeded layout mirrors Adam/AdamW:

    * ``max_exp_avg_sq`` is added for param groups with ``amsgrad``
      enabled. Otherwise that buffer is never restored, and the first
      step after resume fails to find it.
    * ``step`` is placed on the parameter's device for ``capturable`` or
      ``fused`` groups, matching where PyTorch itself creates it.

    Parameters that do not require grad, or that already have state,
    are left untouched.
    """
    for opt in self.get_optimizers(0):
        for group in opt.param_groups:
            amsgrad = bool(group.get("amsgrad", False))
            # ``fused`` defaults to None in Adam, hence the truthiness test.
            step_on_param_device = bool(group.get("capturable", False) or group.get("fused", False))
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                # .get() avoids inserting into the defaultdict-backed state.
                if len(opt.state.get(p, {})) > 0:
                    continue
                step = (torch.tensor(0.0, device=p.device) if step_on_param_device else torch.tensor(0.0))
                seeded = {
                    "step": step,
                    "exp_avg": torch.zeros_like(p),
                    "exp_avg_sq": torch.zeros_like(p),
                }
                if amsgrad:
                    seeded["max_exp_avg_sq"] = torch.zeros_like(p)
                opt.state[p] = seeded

Modules

fastvideo.train.methods.base
Classes
fastvideo.train.methods.base.TrainingMethod
TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

The constructor receives role_models (a dict[str, ModelBase]) and a cfg object, and builds self.role_modules for FSDP wrapping. Calling init_preprocessors on the student is left to subclasses (DMD2Method, for example, does so in its constructor).

A single shared CUDA RNG generator (cuda_generator) is created in on_train_start(). All torch.randn / torch.randint calls in methods and models must use this generator instead of relying on global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Store role models and expose their transformers for FSDP.

    Args:
        cfg: Run config. ``cfg.training`` and ``cfg.method`` are read;
            ``cfg.validation`` is read when present (missing/None -> ``{}``).
        role_models: Role name -> model. Must contain ``"student"``.

    Raises:
        ValueError: If *role_models* has no ``"student"`` entry.
    """
    # Fail with a descriptive error (same form subclasses use for their
    # required roles) instead of a bare KeyError from the lookup below.
    if "student" not in role_models:
        raise ValueError(f"{type(self).__name__} requires role 'student'")
    super().__init__()
    self.tracker: Any | None = None
    # Shallow copy so later edits to the caller's dict don't leak in.
    self._role_models: dict[str, ModelBase] = dict(role_models)

    self.student = role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    self.validation_config: dict[str, Any] = dict(getattr(cfg, "validation", {}) or {})

    # Build nn.ModuleDict for FSDP / checkpoint visibility. Only roles whose
    # ``transformer`` attribute is an nn.Module get an entry.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        mods: dict[str, torch.nn.Module] = {}
        transformer = getattr(model, "transformer", None)
        if isinstance(transformer, torch.nn.Module):
            mods["transformer"] = transformer
        if mods:
            self.role_modules[role] = torch.nn.ModuleDict(mods)
Methods:
fastvideo.train.methods.base.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*. This method itself only emits the roles, optimizers, and schedulers entries.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Collect DCP-ready checkpoint state for every trainable role.

    For each role whose model has a truthy ``_trainable`` flag, this
    emits ``roles.<role>.transformer`` (when the model has a transformer),
    ``optimizers.<role>`` and ``schedulers.<role>`` (when registered for
    that role). No ``random_state.*`` entries are produced here.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.
    """
    states: dict[str, Any] = {}
    trainable_roles = ((role, model) for role, model in self._role_models.items()
                       if getattr(model, "_trainable", False))

    for role, model in trainable_roles:
        transformer = model.transformer
        modules: dict[str, torch.nn.Module] = ({} if transformer is None else {"transformer": transformer})
        # The optimizer wrapper needs a module view of the role's params.
        container = _RoleModuleContainer(modules)

        states.update({f"roles.{role}.{name}": ModelWrapper(mod) for name, mod in modules.items()})

        optimizer = self._optimizer_dict.get(role)
        if optimizer is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, optimizer)

        scheduler = self._lr_scheduler_dict.get(role)
        if scheduler is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(scheduler)

    return states
fastvideo.train.methods.base.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Map names to the modules whose gradients should be clipped.

    The base implementation ignores *iteration* and always returns the
    student transformer under the key ``"student"``. Subclasses override
    this to add modules (e.g. a critic) or include them conditionally.
    """
    student_transformer = self.student.transformer
    return dict(student=student_transformer)
fastvideo.train.methods.base.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (``step``, ``exp_avg``,
    ``exp_avg_sq`` -- plus ``max_exp_avg_sq`` with AMSGrad -- are only
    created on the first ``optimizer.step()``). DCP needs matching
    entries to load into; without them the saved optimizer state is
    silently dropped.

    The seeded layout mirrors Adam/AdamW:

    * ``max_exp_avg_sq`` is added for param groups with ``amsgrad``
      enabled. Otherwise that buffer is never restored, and the first
      step after resume fails to find it.
    * ``step`` is placed on the parameter's device for ``capturable`` or
      ``fused`` groups, matching where PyTorch itself creates it.

    Parameters that do not require grad, or that already have state,
    are left untouched.
    """
    for opt in self.get_optimizers(0):
        for group in opt.param_groups:
            amsgrad = bool(group.get("amsgrad", False))
            # ``fused`` defaults to None in Adam, hence the truthiness test.
            step_on_param_device = bool(group.get("capturable", False) or group.get("fused", False))
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                # .get() avoids inserting into the defaultdict-backed state.
                if len(opt.state.get(p, {})) > 0:
                    continue
                step = (torch.tensor(0.0, device=p.device) if step_on_param_device else torch.tensor(0.0))
                seeded = {
                    "step": step,
                    "exp_avg": torch.zeros_like(p),
                    "exp_avg_sq": torch.zeros_like(p),
                }
                if amsgrad:
                    seeded["max_exp_avg_sq"] = torch.zeros_like(p)
                opt.state[p] = seeded
Functions:
fastvideo.train.methods.distribution_matching
Classes
fastvideo.train.methods.distribution_matching.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- self.student — trainable student ModelBase
- self.teacher — frozen teacher ModelBase
- self.critic — trainable critic ModelBase

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Wire up DMD2 from a student, a frozen teacher and a critic.

    After the base ``TrainingMethod`` setup, this checks that all three
    roles are present and that their trainability matches DMD2's needs
    (student and critic trainable, teacher frozen). It then parses the
    unconditional-CFG and rollout-mode options, initialises the
    student's preprocessors, and builds the optimizers/schedulers.

    Note: ``self.student`` is not assigned here; presumably the base
    class sets it from ``role_models["student"]`` (TODO confirm).

    Raises:
        ValueError: if a role is missing or has the wrong trainability.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # All three roles are mandatory; report the first one missing.
    for role in ("student", "teacher", "critic"):
        if role not in role_models:
            raise ValueError(f"DMD2Method requires role {role!r}")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # (model, must_be_trainable, error) checked in this order.
    trainability_rules = (
        (self.student, True, "DMD2Method requires student to be trainable"),
        (self.teacher, False, "DMD2Method requires teacher to be non-trainable"),
        (self.critic, True, "DMD2Method requires critic to be trainable"),
    )
    for model, must_train, message in trainability_rules:
        if bool(model._trainable) != must_train:
            raise ValueError(message)

    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    self._denoising_step_list: torch.Tensor | None = None

    # Only the student owns the VAE / dataloader / RNG preprocessors.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
fastvideo.train.methods.distribution_matching.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Configure Self-Forcing on top of the DMD2 setup.

    After ``DMD2Method.__init__`` has validated the roles, this requires
    a ``CausalModelBase`` student and ``rollout_mode == "simulate"``.
    It then reads these options from ``method_config``:

    * ``chunk_size``: int, default 3, must be > 0.
    * ``student_sample_type``: ``"sde"`` or ``"ode"``, default
      ``"sde"``; surrounding whitespace and case are ignored.
    * ``same_step_across_blocks``: bool, default False.
    * ``last_step_only``: bool, default False.
    * ``context_noise``: float, default 0.0, must be >= 0.
    * ``enable_gradient_in_rollout``: bool, default True.
    * ``start_gradient_frame``: int, default 0, must be >= 0.

    Boolean options explicitly set to ``None`` fall back to their
    defaults. Finally, a ``SelfForcingFlowMatchScheduler`` is built
    using the pipeline's ``flow_shift`` (0.0 when absent or falsy).

    Raises:
        ValueError: on a non-causal student, an unsupported rollout
            mode, or an out-of-range option.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # The student must expose the causal / streaming interface.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")
    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    def _flag(key: str, default: bool) -> bool:
        # A missing key and an explicit ``None`` both mean "use default".
        value = mcfg.get(key, default)
        if value is None:
            value = default
        return _require_bool(value, where=f"method_config.{key}")

    chunk = get_optional_int(mcfg, "chunk_size", where="method_config.chunk_size")
    if chunk is None:
        chunk = 3
    if chunk <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk}")
    self._chunk_size = int(chunk)

    raw_type = mcfg.get("student_sample_type", "sde")
    kind = _require_str(raw_type, where="method_config.student_sample_type").strip().lower()
    if kind not in ("sde", "ode"):
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {raw_type!r}")
    self._student_sample_type: Literal["sde", "ode"] = kind  # type: ignore[assignment]

    self._same_step_across_blocks = _flag("same_step_across_blocks", False)
    self._last_step_only = _flag("last_step_only", False)

    noise_level = get_optional_float(mcfg, "context_noise", where="method_config.context_noise")
    noise_level = 0.0 if noise_level is None else noise_level
    if noise_level < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {noise_level}")
    self._context_noise = float(noise_level)

    self._enable_gradient_in_rollout = _flag("enable_gradient_in_rollout", True)

    first_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    first_grad_frame = 0 if first_grad_frame is None else first_grad_frame
    if first_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {first_grad_frame}")
    self._start_gradient_frame = int(first_grad_frame)

    pipeline_cfg = self.training_config.pipeline_config
    flow_shift = float(getattr(pipeline_cfg, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    self._sf_denoising_step_list: torch.Tensor | None = None
Modules
fastvideo.train.methods.distribution_matching.dmd2

DMD2 distillation method (algorithm layer).

Classes
fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:
  • self.student — trainable student ModelBase
  • self.teacher — frozen teacher ModelBase
  • self.critic — trainable critic ModelBase

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Wire up DMD2 from a student, a frozen teacher and a critic.

    After the base ``TrainingMethod`` setup, this checks that all three
    roles are present and that their trainability matches DMD2's needs
    (student and critic trainable, teacher frozen). It then parses the
    unconditional-CFG and rollout-mode options, initialises the
    student's preprocessors, and builds the optimizers/schedulers.

    Note: ``self.student`` is not assigned here; presumably the base
    class sets it from ``role_models["student"]`` (TODO confirm).

    Raises:
        ValueError: if a role is missing or has the wrong trainability.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # All three roles are mandatory; report the first one missing.
    for role in ("student", "teacher", "critic"):
        if role not in role_models:
            raise ValueError(f"DMD2Method requires role {role!r}")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # (model, must_be_trainable, error) checked in this order.
    trainability_rules = (
        (self.student, True, "DMD2Method requires student to be trainable"),
        (self.teacher, False, "DMD2Method requires teacher to be non-trainable"),
        (self.critic, True, "DMD2Method requires critic to be trainable"),
    )
    for model, must_train, message in trainability_rules:
        if bool(model._trainable) != must_train:
            raise ValueError(message)

    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    self._denoising_step_list: torch.Tensor | None = None

    # Only the student owns the VAE / dataloader / RNG preprocessors.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions:
fastvideo.train.methods.distribution_matching.self_forcing

Self-Forcing distillation method (algorithm layer).

Classes
fastvideo.train.methods.distribution_matching.self_forcing.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Configure Self-Forcing on top of the DMD2 setup.

    After ``DMD2Method.__init__`` has validated the roles, this requires
    a ``CausalModelBase`` student and ``rollout_mode == "simulate"``.
    It then reads these options from ``method_config``:

    * ``chunk_size``: int, default 3, must be > 0.
    * ``student_sample_type``: ``"sde"`` or ``"ode"``, default
      ``"sde"``; surrounding whitespace and case are ignored.
    * ``same_step_across_blocks``: bool, default False.
    * ``last_step_only``: bool, default False.
    * ``context_noise``: float, default 0.0, must be >= 0.
    * ``enable_gradient_in_rollout``: bool, default True.
    * ``start_gradient_frame``: int, default 0, must be >= 0.

    Boolean options explicitly set to ``None`` fall back to their
    defaults. Finally, a ``SelfForcingFlowMatchScheduler`` is built
    using the pipeline's ``flow_shift`` (0.0 when absent or falsy).

    Raises:
        ValueError: on a non-causal student, an unsupported rollout
            mode, or an out-of-range option.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # The student must expose the causal / streaming interface.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")
    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    def _flag(key: str, default: bool) -> bool:
        # A missing key and an explicit ``None`` both mean "use default".
        value = mcfg.get(key, default)
        if value is None:
            value = default
        return _require_bool(value, where=f"method_config.{key}")

    chunk = get_optional_int(mcfg, "chunk_size", where="method_config.chunk_size")
    if chunk is None:
        chunk = 3
    if chunk <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk}")
    self._chunk_size = int(chunk)

    raw_type = mcfg.get("student_sample_type", "sde")
    kind = _require_str(raw_type, where="method_config.student_sample_type").strip().lower()
    if kind not in ("sde", "ode"):
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {raw_type!r}")
    self._student_sample_type: Literal["sde", "ode"] = kind  # type: ignore[assignment]

    self._same_step_across_blocks = _flag("same_step_across_blocks", False)
    self._last_step_only = _flag("last_step_only", False)

    noise_level = get_optional_float(mcfg, "context_noise", where="method_config.context_noise")
    noise_level = 0.0 if noise_level is None else noise_level
    if noise_level < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {noise_level}")
    self._context_noise = float(noise_level)

    self._enable_gradient_in_rollout = _flag("enable_gradient_in_rollout", True)

    first_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    first_grad_frame = 0 if first_grad_frame is None else first_grad_frame
    if first_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {first_grad_frame}")
    self._start_gradient_frame = int(first_grad_frame)

    pipeline_cfg = self.training_config.pipeline_config
    flow_shift = float(getattr(pipeline_cfg, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    self._sf_denoising_step_list: torch.Tensor | None = None
Functions:
fastvideo.train.methods.fine_tuning
Classes
fastvideo.train.methods.fine_tuning.DiffusionForcingSFTMethod
DiffusionForcingSFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Diffusion-forcing SFT (DFSFT): train only the student, using inhomogeneous timesteps.

Source code in fastvideo/train/methods/fine_tuning/dfsft.py
def __init__(self, *, cfg: Any, role_models: dict[str, ModelBase]) -> None:
    """Set up diffusion-forcing SFT for a single trainable student.

    Validates the ``student`` role, infers the attention kind, and
    parses ``method_config.chunk_size`` (``None`` when unset), the
    timestep-index range and the training weights. It then initialises
    the student's preprocessors and the optimizers/schedulers.

    Raises:
        ValueError: if the student role is missing or not trainable.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    has_student = "student" in role_models
    if not has_student:
        raise ValueError("DFSFT requires role 'student'")
    if not self.student._trainable:
        raise ValueError("DFSFT requires student to be trainable")
    self._attn_kind: Literal["dense", "vsa"] = self._infer_attn_kind()

    raw_chunk_size = self.method_config.get("chunk_size")
    self._chunk_size = self._parse_chunk_size(raw_chunk_size)
    self._timestep_index_range = self._parse_timestep_index_range()
    self._training_weights = self._build_training_weights()

    # Preprocessors (VAE, dataloader, RNG seeds) live on the student.
    self.student.init_preprocessors(self.training_config)
    self._init_optimizers_and_schedulers()
fastvideo.train.methods.fine_tuning.FineTuneMethod
FineTuneMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Supervised fine-tuning: only the student participates.

Source code in fastvideo/train/methods/fine_tuning/finetune.py
def __init__(self, *, cfg: Any, role_models: dict[str, ModelBase]) -> None:
    """Set up plain supervised fine-tuning of the student.

    Only the ``student`` role is used; it must be present and
    trainable. Infers the attention kind, initialises the student's
    preprocessors and builds the optimizers/schedulers.

    Raises:
        ValueError: if the student role is missing or not trainable.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    has_student = "student" in role_models
    if not has_student:
        raise ValueError("FineTuneMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("FineTuneMethod requires student to be trainable")
    self._attn_kind: Literal["dense", "vsa"] = self._infer_attn_kind()

    # Preprocessors (VAE, dataloader, RNG seeds) live on the student.
    self.student.init_preprocessors(self.training_config)
    self._init_optimizers_and_schedulers()
Modules
fastvideo.train.methods.fine_tuning.dfsft

Diffusion-forcing SFT method (DFSFT; algorithm layer).

Classes
fastvideo.train.methods.fine_tuning.dfsft.DiffusionForcingSFTMethod
DiffusionForcingSFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Diffusion-forcing SFT (DFSFT): train only the student, using inhomogeneous timesteps.

Source code in fastvideo/train/methods/fine_tuning/dfsft.py
def __init__(self, *, cfg: Any, role_models: dict[str, ModelBase]) -> None:
    """Set up diffusion-forcing SFT for a single trainable student.

    Validates the ``student`` role, infers the attention kind, and
    parses ``method_config.chunk_size`` (``None`` when unset), the
    timestep-index range and the training weights. It then initialises
    the student's preprocessors and the optimizers/schedulers.

    Raises:
        ValueError: if the student role is missing or not trainable.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    has_student = "student" in role_models
    if not has_student:
        raise ValueError("DFSFT requires role 'student'")
    if not self.student._trainable:
        raise ValueError("DFSFT requires student to be trainable")
    self._attn_kind: Literal["dense", "vsa"] = self._infer_attn_kind()

    raw_chunk_size = self.method_config.get("chunk_size")
    self._chunk_size = self._parse_chunk_size(raw_chunk_size)
    self._timestep_index_range = self._parse_timestep_index_range()
    self._training_weights = self._build_training_weights()

    # Preprocessors (VAE, dataloader, RNG seeds) live on the student.
    self.student.init_preprocessors(self.training_config)
    self._init_optimizers_and_schedulers()
Functions:
fastvideo.train.methods.fine_tuning.finetune

Supervised finetuning method (algorithm layer).

Classes
fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod
FineTuneMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Supervised fine-tuning: only the student participates.

Source code in fastvideo/train/methods/fine_tuning/finetune.py
def __init__(self, *, cfg: Any, role_models: dict[str, ModelBase]) -> None:
    """Set up plain supervised fine-tuning of the student.

    Only the ``student`` role is used; it must be present and
    trainable. Infers the attention kind, initialises the student's
    preprocessors and builds the optimizers/schedulers.

    Raises:
        ValueError: if the student role is missing or not trainable.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    has_student = "student" in role_models
    if not has_student:
        raise ValueError("FineTuneMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("FineTuneMethod requires student to be trainable")
    self._attn_kind: Literal["dense", "vsa"] = self._infer_attn_kind()

    # Preprocessors (VAE, dataloader, RNG seeds) live on the student.
    self.student.init_preprocessors(self.training_config)
    self._init_optimizers_and_schedulers()
Functions:
fastvideo.train.methods.knowledge_distillation
Classes
fastvideo.train.methods.knowledge_distillation.KDCausalMethod
KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to KDMethod except that single_train_step samples a per-frame denoising step index (block-quantized to groups of num_frames_per_block frames) instead of one index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method:

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(self, *, cfg: Any, role_models: dict[str, ModelBase]) -> None:
    """Set up KD, then read the causal block size.

    ``method_config.num_frames_per_block`` (default 3) is the number of
    consecutive frames that share one noise level; it must be >= 1.

    Raises:
        ValueError: if ``num_frames_per_block`` is below 1.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    frames_per_block = self.method_config.get("num_frames_per_block", 3)
    self._num_frames_per_block: int = int(frames_per_block)
    if not self._num_frames_per_block >= 1:
        raise ValueError("num_frames_per_block must be >= 1")
fastvideo.train.methods.knowledge_distillation.KDMethod
KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate roles, parse KD options, and wrap the student dataloader.

    Reads from ``method_config``:

    * ``t_list`` (required): a non-empty list of integer timesteps,
      e.g. ``[999, 937, 833, 624, 0]``.
    * ``student_sample_steps`` (optional): defaults to
      ``len(t_list) - 1``. It must be >= 1 and satisfy
      ``len(t_list) == student_sample_steps + 1``.
    * ``teacher_path_cache`` (required): the cache directory.
    * ``teacher_guidance_scale`` (default 1.0) and
      ``teacher_inference_steps`` (default 48).

    The optional ``teacher`` role must be non-trainable. After
    ``init_preprocessors`` builds the student's dataloader, that loader
    is stored in ``self._source_loader``. The student's ``dataloader``
    attribute is then replaced by a ``_KDDataLoaderWrapper`` around it.

    Raises:
        ValueError: on a missing/non-trainable student, a trainable
            teacher, or invalid ``t_list`` / ``student_sample_steps`` /
            ``teacher_path_cache`` values.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    # A single-entry t_list (or student_sample_steps <= 0) leaves no
    # denoising transition to distill; reject it up front.
    if self._num_steps < 1:
        raise ValueError("student_sample_steps must be >= 1 (t_list needs "
                         f"at least two timesteps), got {self._num_steps}")
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()
Modules
fastvideo.train.methods.knowledge_distillation.kd

Knowledge Distillation method for ODE-init training.

Trains a student model with MSE loss to reproduce a teacher model's multi-step ODE denoising trajectories. The resulting checkpoint (exported via dcp_to_diffusers) serves as the ode_init weight initialization for downstream Self-Forcing training.

Teacher path generation is cached to disk so it only runs once. Interrupted generation resumes from the last completed sample.

Typical YAML:

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:           # omit once cache is complete
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-14B-Diffusers
    trainable: false
    disable_custom_init_weights: true

method:
  _target_: fastvideo.train.methods.knowledge_distillation.kd.KDMethod
  teacher_path_cache: /data/kd_cache/wan14b_4step
  t_list: [999, 937, 833, 624, 0]   # integer timesteps
  student_sample_steps: 4
  teacher_guidance_scale: 1.0
Classes
fastvideo.train.methods.knowledge_distillation.kd.KDCausalMethod
KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to KDMethod except that single_train_step samples a per-frame denoising step index (block-quantized to groups of num_frames_per_block frames) instead of one index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method:

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(self, *, cfg: Any, role_models: dict[str, ModelBase]) -> None:
    """Set up KD, then read the causal block size.

    ``method_config.num_frames_per_block`` (default 3) is the number of
    consecutive frames that share one noise level; it must be >= 1.

    Raises:
        ValueError: if ``num_frames_per_block`` is below 1.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    frames_per_block = self.method_config.get("num_frames_per_block", 3)
    self._num_frames_per_block: int = int(frames_per_block)
    if not self._num_frames_per_block >= 1:
        raise ValueError("num_frames_per_block must be >= 1")
fastvideo.train.methods.knowledge_distillation.kd.KDMethod
KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate roles, parse KD options, and wrap the student dataloader.

    Reads from ``method_config``:

    * ``t_list`` (required): a non-empty list of integer timesteps,
      e.g. ``[999, 937, 833, 624, 0]``.
    * ``student_sample_steps`` (optional): defaults to
      ``len(t_list) - 1``. It must be >= 1 and satisfy
      ``len(t_list) == student_sample_steps + 1``.
    * ``teacher_path_cache`` (required): the cache directory.
    * ``teacher_guidance_scale`` (default 1.0) and
      ``teacher_inference_steps`` (default 48).

    The optional ``teacher`` role must be non-trainable. After
    ``init_preprocessors`` builds the student's dataloader, that loader
    is stored in ``self._source_loader``. The student's ``dataloader``
    attribute is then replaced by a ``_KDDataLoaderWrapper`` around it.

    Raises:
        ValueError: on a missing/non-trainable student, a trainable
            teacher, or invalid ``t_list`` / ``student_sample_steps`` /
            ``teacher_path_cache`` values.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    # A single-entry t_list (or student_sample_steps <= 0) leaves no
    # denoising transition to distill; reject it up front.
    if self._num_steps < 1:
        raise ValueError("student_sample_steps must be >= 1 (t_list needs "
                         f"at least two timesteps), got {self._num_steps}")
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()
Functions:

fastvideo.train.models

Modules

fastvideo.train.models.base
Classes
fastvideo.train.models.base.CausalModelBase
CausalModelBase(*, trainable: bool = True, lora: LoraConfig | dict[str, Any] | None = None)

Bases: ModelBase

Extension for causal / streaming model plugins.

Cache state is internal to the model instance and keyed by cache_tag (no role handle needed).

Source code in fastvideo/train/models/base.py
def __init__(
    self,
    *,
    trainable: bool = True,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Record trainability and the optional LoRA configuration.

    Args:
        trainable: stored as ``self._trainable`` after ``bool()``
            coercion.
        lora: ``None``, a ``LoraConfig`` or a plain dict; normalised
            through ``LoraConfig.coerce`` into ``self._lora_config``.
    """
    from fastvideo.train.utils.lora import LoraConfig

    self._trainable = bool(trainable)
    coerced: LoraConfig | None = LoraConfig.coerce(lora)
    self._lora_config = coerced
    # LoRA layer counter starts at zero; presumably updated when LoRA
    # layers are attached (TODO confirm).
    self._num_lora_layers = 0
Methods:
fastvideo.train.models.base.CausalModelBase.clear_caches abstractmethod
clear_caches(*, cache_tag: str = 'pos') -> None

Clear internal caches before starting a new rollout.

Source code in fastvideo/train/models/base.py
@abstractmethod
def clear_caches(self, *, cache_tag: str = "pos") -> None:
    """Clear internal caches before starting a new rollout.

    Args:
        cache_tag: which internal cache to clear. Cache state is keyed
            by this tag; the default ``"pos"`` matches the default tag
            of ``predict_noise_streaming``.
    """
fastvideo.train.models.base.CausalModelBase.predict_noise_streaming abstractmethod
predict_noise_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Streaming predict-noise that may update internal caches.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming predict-noise that may update internal caches.

    Args:
        noisy_latents: noisy latents for the current chunk.
            ``predict_x0_streaming`` flattens their first two dims, so
            they must have at least two leading dims.
        timestep: timestep(s) for the chunk.
        batch: prepared training batch (conditioning inputs).
        conditional: presumably selects the conditional vs
            unconditional branch (TODO confirm per implementation).
        cache_tag: key of the internal cache to read/update.
        store_kv: presumably whether this call writes into the cache
            (TODO confirm).
        cur_start_frame: presumably the index of the first frame of
            this chunk within the rollout (TODO confirm).
        cfg_uncond: optional settings for the unconditional branch.
        attn_kind: ``"dense"`` or ``"vsa"`` attention.

    Returns:
        The predicted noise/flow, or ``None``. ``predict_x0_streaming``
        passes ``None`` straight through.
    """
fastvideo.train.models.base.CausalModelBase.predict_x0_streaming
predict_x0_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Predict x0 streaming via predict_noise_streaming + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Predict x0 from a streaming noise prediction.

    All arguments are forwarded unchanged to ``predict_noise_streaming``.
    If that returns ``None``, so does this method. Otherwise the first
    two dims of the prediction and of ``noisy_latents`` are flattened
    together and converted with ``pred_noise_to_pred_video`` using
    ``self.noise_scheduler``. The result is then reshaped back to the
    prediction's two leading dims.
    """
    noise = self.predict_noise_streaming(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cache_tag=cache_tag,
        store_kv=store_kv,
        cur_start_frame=cur_start_frame,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    if noise is not None:
        leading_dims = noise.shape[:2]
        x0_flat = pred_noise_to_pred_video(
            pred_noise=noise.flatten(0, 1),
            noise_input_latent=noisy_latents.flatten(0, 1),
            timestep=timestep,
            scheduler=self.noise_scheduler,
        )
        return x0_flat.unflatten(0, leading_dims)
    return None
fastvideo.train.models.base.ModelBase
ModelBase(*, trainable: bool = True, lora: LoraConfig | dict[str, Any] | None = None)

Bases: ABC

Per-role model instance.

Every role (student, teacher, critic, …) gets its own ModelBase instance. Each instance owns its own transformer and noise_scheduler. Heavyweight resources (VAE, dataloader, RNG seeds) are loaded lazily via init_preprocessors(), which the method calls only on the student.

Source code in fastvideo/train/models/base.py
def __init__(
    self,
    *,
    trainable: bool = True,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Record trainability and the optional LoRA configuration.

    Args:
        trainable: stored as ``self._trainable`` after ``bool()``
            coercion.
        lora: ``None``, a ``LoraConfig`` or a plain dict; normalised
            through ``LoraConfig.coerce`` into ``self._lora_config``.
    """
    from fastvideo.train.utils.lora import LoraConfig

    self._trainable = bool(trainable)
    coerced: LoraConfig | None = LoraConfig.coerce(lora)
    self._lora_config = coerced
    # LoRA layer counter starts at zero; presumably updated when LoRA
    # layers are attached (TODO confirm).
    self._num_lora_layers = 0
Attributes
fastvideo.train.models.base.ModelBase.device property
device: device

The local CUDA device for this rank.

fastvideo.train.models.base.ModelBase.num_train_timesteps property
num_train_timesteps: int

Return the scheduler's training timestep horizon.

Methods:
fastvideo.train.models.base.ModelBase.add_noise abstractmethod
add_noise(clean_latents: Tensor, noise: Tensor, timestep: Tensor) -> Tensor

Apply forward-process noise at timestep.

Source code in fastvideo/train/models/base.py
@abstractmethod
def add_noise(
    self,
    clean_latents: torch.Tensor,
    noise: torch.Tensor,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Apply forward-process noise at *timestep*.

    Args:
        clean_latents: latents before noising.
        noise: noise sample to mix in; presumably the same shape as
            ``clean_latents`` (TODO confirm per implementation).
        timestep: timestep(s) at which to noise.

    Returns:
        The noised latents.
    """
fastvideo.train.models.base.ModelBase.backward abstractmethod
backward(loss: Tensor, ctx: Any, *, grad_accum_rounds: int) -> None

Backward that may restore forward-context.

Source code in fastvideo/train/models/base.py
@abstractmethod
def backward(
    self,
    loss: torch.Tensor,
    ctx: Any,
    *,
    grad_accum_rounds: int,
) -> None:
    """Backward that may restore forward-context.

    Args:
        loss: the loss to back-propagate.
        ctx: opaque context that implementations may restore before
            running backward; presumably captured during the forward
            pass (TODO confirm).
        grad_accum_rounds: presumably the number of gradient
            accumulation rounds; check each implementation for whether
            the loss is scaled by it.
    """
fastvideo.train.models.base.ModelBase.init_preprocessors
init_preprocessors(training_config: TrainingConfig) -> None

Load VAE, build dataloader, seed RNGs.

Called only on the student by the method's __init__. Default is a no-op so teacher/critic instances skip this.

Source code in fastvideo/train/models/base.py
def init_preprocessors(  # noqa: B027
        self,
        training_config: TrainingConfig,
) -> None:
    """Load VAE, build dataloader, seed RNGs.

    Called only on the student by the method's ``__init__``.
    Default is a no-op so teacher/critic instances skip this.
    Subclasses that own these heavyweight resources override it.

    Args:
        training_config: the run's training configuration.
    """
fastvideo.train.models.base.ModelBase.on_train_start
on_train_start() -> None

Called once before the training loop begins.

Source code in fastvideo/train/models/base.py
def on_train_start(self) -> None:  # noqa: B027
    """Called once before the training loop begins.

    The default implementation does nothing; subclasses may override it.
    """
fastvideo.train.models.base.ModelBase.predict_noise abstractmethod
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict noise/flow for the given noisy latents.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict noise/flow for the given noisy latents.

    Args:
        noisy_latents: noisy input latents.
        timestep: timestep(s) of the noisy latents.
        batch: prepared training batch (conditioning inputs).
        conditional: presumably selects the conditional vs
            unconditional branch (TODO confirm per implementation).
        cfg_uncond: optional settings for the unconditional branch.
        attn_kind: ``"dense"`` or ``"vsa"`` attention.

    Returns:
        The prediction. ``predict_x0`` flattens the first two dims of
        both this output and ``noisy_latents``, so they are expected to
        share those leading dims.
    """
fastvideo.train.models.base.ModelBase.predict_x0
predict_x0(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict x0 via predict_noise + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Estimate the clean sample x0 for ``noisy_latents``.

    Calls ``self.predict_noise`` with the same arguments, then converts
    that prediction with ``pred_noise_to_pred_video`` using
    ``self.noise_scheduler``. The conversion runs on tensors whose first
    two dimensions are merged; the result is unflattened back to the
    prediction's leading two dimensions.
    """
    noise_pred = self.predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    leading_dims = noise_pred.shape[:2]
    x0_flat = pred_noise_to_pred_video(
        pred_noise=noise_pred.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return x0_flat.unflatten(0, leading_dims)
fastvideo.train.models.base.ModelBase.prepare_batch abstractmethod
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Convert a dataloader batch into forward primitives.

Source code in fastvideo/train/models/base.py
@abstractmethod
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Build the forward-pass inputs (a ``TrainingBatch``) from a raw batch.

    Subclasses must implement this.

    Args:
        raw_batch: One batch as produced by the dataloader.
        generator: RNG to use for any sampling done while preparing.
        latents_source: ``"data"`` to take latents from the batch,
            ``"zeros"`` to synthesize zero latents.
    """
fastvideo.train.models.base.ModelBase.shift_and_clamp_timestep
shift_and_clamp_timestep(timestep: Tensor) -> Tensor

Apply model/pipeline timestep shifting and clamp.

Source code in fastvideo/train/models/base.py
def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
    """Hook for model/pipeline-specific timestep shifting and clamping.

    The base implementation is the identity and hands back the very same
    tensor object. Subclasses override it to apply their own shift/clamp.
    """
    unchanged = timestep
    return unchanged
Functions:
fastvideo.train.models.cosmos

Cosmos model plugin package.

Classes
Modules
fastvideo.train.models.cosmos.cosmos

Cosmos model plugin (per-role instance).

Subclasses WanModel since Cosmos uses the same FlowMatchEulerDiscreteScheduler. Differences:

- transformer class name: CosmosTransformer3DModel
- normalize_dit_input("cosmos", ...) instead of ("wan", ...)
- forward kwargs: no encoder_attention_mask; needs condition_mask + padding_mask + fps
- hidden_states in (B, C, T, H, W) — no permute needed
- default flow_shift = 1.0
- single T5 text encoder (not dual like Hunyuan)

Classes
fastvideo.train.models.cosmos.cosmos.CosmosModel
CosmosModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 1.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

Cosmos 2.5 per-role model.

Inherits most behaviour from WanModel (noise scheduler, timestep sampling, attention metadata, backward). Overrides only the pieces that differ for Cosmos 2.5.

Cosmos 2.5 uses:

- Cosmos25Transformer3DModel (velocity prediction)
- EDM noise schedule: x_t = x_0 + sigma * eps
- No input/output preconditioning (raw latents)
- Timestep = raw sigma value
- Model output = velocity ≈ noise

Source code in fastvideo/train/models/cosmos/cosmos.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 1.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Create a Cosmos 2.5 role model.

    Every argument is handed unchanged to ``WanModel.__init__``. The only
    Cosmos-specific choice here is the ``flow_shift`` default of 1.0. No
    ``lora`` argument is accepted, so ``WanModel``'s default is used.
    """
    wan_kwargs = {
        "init_from": init_from,
        "training_config": training_config,
        "trainable": trainable,
        "disable_custom_init_weights": disable_custom_init_weights,
        "flow_shift": flow_shift,
        "enable_gradient_checkpointing_type": enable_gradient_checkpointing_type,
        "transformer_override_safetensor": transformer_override_safetensor,
    }
    super().__init__(**wan_kwargs)
Methods:
fastvideo.train.models.cosmos.cosmos.CosmosModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None

Create negative (unconditional) prompt embeddings.

Cosmos 2.5 uses Reason1 (Qwen2.5-VL) which is expensive to load. This method only supports training_cfg_rate=0 (no classifier-free guidance dropout), in which case the negative embedding is never used and a zero placeholder sized to match the text embedding dimension is sufficient. training_cfg_rate>0 would require real Reason1 negative embeddings and is rejected here to avoid silently training with zero-vector "unconditional" inputs.

Source code in fastvideo/train/models/cosmos/cosmos.py
def ensure_negative_conditioning(self) -> None:
    """Lazily install placeholder negative (unconditional) prompt tensors.

    Cosmos 2.5 encodes prompts with Reason1 (Qwen2.5-VL), which is costly
    to load, so no real negative prompt is encoded. Instead, a zero
    embedding of shape ``(1, 512, embed_dim)`` and an all-ones mask of
    shape ``(1, 512)`` are stored. This placeholder is only sound when
    ``training_cfg_rate`` is 0 (no classifier-free guidance dropout),
    because then the negative embedding is never consumed. A positive
    rate is rejected so training never silently sees zero-vector
    "unconditional" inputs.

    Raises:
        NotImplementedError: if ``training.data.training_cfg_rate > 0``.
    """
    # Already populated by an earlier call.
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return

    assert self.training_config is not None
    tc = self.training_config

    cfg_rate = float(tc.data.training_cfg_rate or 0.0)
    if cfg_rate > 0.0:
        raise NotImplementedError("Cosmos 2.5 currently only supports training_cfg_rate=0; "
                                  f"got training_cfg_rate={cfg_rate}. Real negative-prompt "
                                  "embeddings via Reason1 (Qwen2.5-VL) are not implemented "
                                  "yet — using the zero placeholder with CFG dropout would "
                                  "train against zero-vector \"unconditional\" inputs and "
                                  "produce wrong gradients. Set "
                                  "training.data.training_cfg_rate=0.")

    device = self.device
    dtype = self._get_training_dtype()

    # Width: the first text encoder's ``arch_config.hidden_size`` when
    # available, else 100352 (the Reason1 full_concat width).
    # NOTE(review): if ``hidden_size`` is a per-layer width while dataset
    # embeddings are full_concat, the placeholder width would differ from
    # real embeddings; harmless while unused (cfg_rate == 0), but confirm.
    fallback_dim = 100352
    encoder_cfgs = tc.pipeline_config.text_encoder_configs
    embed_dim = (getattr(encoder_cfgs[0].arch_config, "hidden_size", fallback_dim)
                 if encoder_cfgs else fallback_dim)

    token_shape = (1, 512)  # 512 = Reason1 default padding length
    zero_embeds = torch.zeros(*token_shape, embed_dim, device=device, dtype=dtype)
    full_mask = torch.ones(*token_shape, device=device, dtype=dtype)

    self.negative_prompt_embeds = zero_embeds
    self.negative_prompt_attention_mask = full_mask
fastvideo.train.models.cosmos.cosmos.CosmosModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Same flow as Wan, but uses Cosmos VAE normalisation.

Source code in fastvideo/train/models/cosmos/cosmos.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Turn a raw dataloader batch into a Cosmos ``TrainingBatch``.

    Follows the Wan flow. The one difference is that latents go through
    ``normalize_dit_input("cosmos", ...)``.

    Args:
        raw_batch: Dataloader output. Must hold ``"text_embedding"`` and
            ``"text_attention_mask"``. ``"info_list"`` is optional.
            ``"vae_latent"`` is required when ``latents_source="data"``.
        generator: RNG handed to ``_prepare_dit_inputs``.
        latents_source: ``"data"`` takes the batch's VAE latents, cut to
            ``num_latent_t`` along dim 2. ``"zeros"`` builds zero latents
            sized from the VAE arch config and ``training.data``.

    Returns:
        The populated batch. ``attn_metadata_vsa`` is a shallow copy of
        ``attn_metadata``, and the dense view gets ``VSA_sparsity = 0.0``.

    Raises:
        ValueError: ``vae_latent`` is missing for ``"data"``, or
            ``latents_source`` is unknown.
    """
    # Negative-prompt tensors must exist before any batch is built.
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    tc = self.training_config

    dtype = self._get_training_dtype()
    device = self.device

    training_batch = TrainingBatch()
    text_embeds = raw_batch["text_embedding"]
    text_mask = raw_batch["text_attention_mask"]
    infos = raw_batch.get("info_list")

    if latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        frames = tc.data.num_latent_t
        latents = raw_batch["vae_latent"][:, :, :frames].to(device, dtype=dtype)
    elif latents_source == "zeros":
        vae_arch = tc.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
        # Channel count: ``z_dim``, then ``latent_channels``, then 16.
        channels = getattr(vae_arch, "z_dim", getattr(vae_arch, "latent_channels", 16))
        ratio = vae_arch.spatial_compression_ratio
        latents = torch.zeros(
            text_embeds.shape[0],
            channels,
            tc.data.num_latent_t,
            tc.data.num_height // ratio,
            tc.data.num_width // ratio,
            device=device,
            dtype=dtype,
        )
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    training_batch.encoder_hidden_states = text_embeds.to(device, dtype=dtype)
    training_batch.encoder_attention_mask = text_mask.to(device, dtype=dtype)
    training_batch.infos = infos
    # Cosmos-specific step: "cosmos" VAE normalisation (Wan uses "wan").
    training_batch.latents = normalize_dit_input("cosmos", latents, self.vae)

    training_batch = self._prepare_dit_inputs(training_batch, generator)
    training_batch = self._build_attention_metadata(training_batch)

    # copy.copy (not deepcopy) keeps the lru_cache'd LongTensor index fields
    # shared between the dense and VSA views; only the float VSA_sparsity
    # differs. deepcopy would re-materialize all four cached index tensors
    # every step.
    dense_metadata = training_batch.attn_metadata
    training_batch.attn_metadata_vsa = copy.copy(dense_metadata)
    if dense_metadata is not None:
        dense_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

    return training_batch
Functions:
fastvideo.train.models.hunyuan

Hunyuan model plugin package.

Classes
Modules
fastvideo.train.models.hunyuan.hunyuan

Hunyuan model plugin (per-role instance).

Subclasses WanModel since HunyuanVideo uses the same FlowMatchEulerDiscreteScheduler and linear-interpolation noise schedule. Differences:

- transformer class name
- normalize_dit_input("hunyuan", ...) instead of ("wan", ...)
- forward kwargs: no encoder_attention_mask, no return_dict
- default flow_shift = 7

Classes
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel
HunyuanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 7.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None, lora: LoraConfig | dict[str, Any] | None = None)

Bases: WanModel

HunyuanVideo per-role model.

Inherits most behaviour from WanModel (noise scheduler, timestep sampling, attention metadata, backward). Overrides only the pieces that differ for Hunyuan.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 7.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Create a HunyuanVideo role model.

    All arguments, including ``lora``, are forwarded verbatim to
    ``WanModel.__init__``. The Hunyuan-specific part is the
    ``flow_shift`` default of 7.0.
    """
    wan_kwargs = {
        "init_from": init_from,
        "training_config": training_config,
        "trainable": trainable,
        "disable_custom_init_weights": disable_custom_init_weights,
        "flow_shift": flow_shift,
        "enable_gradient_checkpointing_type": enable_gradient_checkpointing_type,
        "transformer_override_safetensor": transformer_override_safetensor,
        "lora": lora,
    }
    super().__init__(**wan_kwargs)
Methods:
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None

Encode the negative prompt with dual text encoders (LLaMA + CLIP).

Every rank encodes independently to avoid NCCL deadlocks when only a subset of ranks would otherwise participate.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def ensure_negative_conditioning(self) -> None:
    """Encode the negative prompt with dual text encoders
    (LLaMA + CLIP).

    Every rank encodes independently to avoid NCCL deadlocks
    when only a subset of ranks would otherwise participate.

    The negative prompt is the empty string. The LLaMA encoder and
    tokenizer are loaded from ``text_encoder`` / ``tokenizer`` under
    ``tc.model_path``, and CLIP from ``text_encoder_2`` /
    ``tokenizer_2``. Each uses the dtype from
    ``pipeline_config.text_encoder_precisions`` (index 0 for LLaMA, 1 for
    CLIP). Each encoder runs once under ``torch.no_grad()`` and its local
    references are then deleted. The combined embedding and an all-ones
    mask are cached on ``self.negative_prompt_embeds`` /
    ``self.negative_prompt_attention_mask``, so later calls return
    immediately.
    """
    # Already computed by an earlier call: nothing to do.
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return

    assert self.training_config is not None
    tc = self.training_config
    device = self.device
    dtype = self._get_training_dtype()

    # Heavy imports are deferred to the (one-time) encoding path.
    from transformers import (AutoTokenizer, CLIPTextModel, LlamaModel)

    from fastvideo.configs.pipelines.hunyuan import (
        clip_preprocess_text,
        clip_postprocess_text,
        llama_preprocess_text,
        llama_postprocess_text,
    )
    from fastvideo.utils import (PRECISION_TO_TYPE, maybe_download_model)

    model_path = maybe_download_model(tc.model_path)

    # Use configured precisions for each encoder.
    precisions = tc.pipeline_config.text_encoder_precisions
    llama_dtype = PRECISION_TO_TYPE[precisions[0]]
    clip_dtype = PRECISION_TO_TYPE[precisions[1]]

    # --- LLaMA ---
    llama_tok = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    llama_enc = LlamaModel.from_pretrained(
        os.path.join(model_path, "text_encoder"),
        torch_dtype=llama_dtype,
    ).to(device).eval()

    llama_cfg = tc.pipeline_config.text_encoder_configs[0]
    llama_tok_kwargs = dict(llama_cfg.tokenizer_kwargs)

    negative_prompt = ""
    llama_text = llama_preprocess_text(negative_prompt)

    with torch.no_grad():
        llama_inputs = llama_tok(llama_text, **llama_tok_kwargs).to(device)
        llama_out = llama_enc(**llama_inputs, output_hidden_states=True)
        # NOTE(review): after squeeze(0) this is assumed to be
        # (seq_len, llama_dim); confirm against llama_postprocess_text.
        llama_embeds = llama_postprocess_text(llama_out).squeeze(0)

    # Drop local references to the encoder/tokenizer. This does not call
    # torch.cuda.empty_cache(), so cached allocator blocks may remain.
    del llama_enc, llama_tok

    # --- CLIP ---
    clip_tok = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer_2"))
    clip_enc = CLIPTextModel.from_pretrained(
        os.path.join(model_path, "text_encoder_2"),
        torch_dtype=clip_dtype,
    ).to(device).eval()

    clip_cfg = tc.pipeline_config.text_encoder_configs[1]
    clip_tok_kwargs = dict(clip_cfg.tokenizer_kwargs)
    clip_text = clip_preprocess_text(negative_prompt)

    with torch.no_grad():
        clip_inputs = clip_tok(clip_text, **clip_tok_kwargs).to(device)
        clip_out = clip_enc(**clip_inputs)
        clip_pooled = clip_postprocess_text(clip_out).squeeze(0)

    del clip_enc, clip_tok

    # --- Combine: [pooled_clip_row, llama_embeds] ---
    # Row 0 holds the CLIP pooled vector zero-padded to the LLaMA width;
    # the remaining rows are the LLaMA embeddings. unsqueeze(0) then adds a
    # leading batch dimension of size 1.
    llama_dim = llama_embeds.shape[-1]
    # NOTE(review): pooled_row uses torch's default dtype, while
    # llama_embeds is in llama_dtype; torch.cat relies on dtype promotion
    # before the final .to(dtype). The slice assignment below also assumes
    # the CLIP pooled width does not exceed llama_dim.
    pooled_row = torch.zeros(llama_dim, device=device)
    pooled_row[:clip_pooled.shape[-1]] = clip_pooled
    neg_embeds = torch.cat(
        [pooled_row.unsqueeze(0), llama_embeds],
        dim=0,
    ).unsqueeze(0).to(device=device, dtype=dtype)

    # Attention mask: all ones for the combined sequence.
    neg_mask = torch.ones(neg_embeds.shape[:2], device=device, dtype=dtype)

    self.negative_prompt_embeds = neg_embeds
    self.negative_prompt_attention_mask = neg_mask
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Same flow as Wan, but uses Hunyuan VAE normalisation.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Turn a raw dataloader batch into a Hunyuan ``TrainingBatch``.

    Same flow as Wan. The one difference is that latents go through
    ``normalize_dit_input("hunyuan", ...)``.

    Args:
        raw_batch: Dataloader output. Must hold ``"text_embedding"`` and
            ``"text_attention_mask"``. ``"info_list"`` is optional.
            ``"vae_latent"`` is required when ``latents_source="data"``.
        generator: RNG handed to ``_prepare_dit_inputs``.
        latents_source: ``"data"`` takes the batch's VAE latents, cut to
            ``num_latent_t`` along dim 2. ``"zeros"`` builds zero latents
            sized from the VAE arch config and ``training.data``.

    Returns:
        The populated batch. ``attn_metadata_vsa`` is a shallow copy of
        ``attn_metadata``, and the dense view gets ``VSA_sparsity = 0.0``.

    Raises:
        ValueError: ``vae_latent`` is missing for ``"data"``, or
            ``latents_source`` is unknown.
    """
    # Negative-prompt tensors must exist before any batch is built.
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    tc = self.training_config

    dtype = self._get_training_dtype()
    device = self.device

    training_batch = TrainingBatch()
    text_embeds = raw_batch["text_embedding"]
    text_mask = raw_batch["text_attention_mask"]
    infos = raw_batch.get("info_list")

    if latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        frames = tc.data.num_latent_t
        latents = raw_batch["vae_latent"][:, :, :frames].to(device, dtype=dtype)
    elif latents_source == "zeros":
        vae_arch = tc.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
        # Channel count: ``latent_channels``, then ``z_dim``, then 16.
        channels = getattr(vae_arch, "latent_channels", getattr(vae_arch, "z_dim", 16))
        ratio = vae_arch.spatial_compression_ratio
        latents = torch.zeros(
            text_embeds.shape[0],
            channels,
            tc.data.num_latent_t,
            tc.data.num_height // ratio,
            tc.data.num_width // ratio,
            device=device,
            dtype=dtype,
        )
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    training_batch.encoder_hidden_states = text_embeds.to(device, dtype=dtype)
    training_batch.encoder_attention_mask = text_mask.to(device, dtype=dtype)
    training_batch.infos = infos
    # Hunyuan-specific step: "hunyuan" VAE normalisation (Wan uses "wan").
    training_batch.latents = normalize_dit_input("hunyuan", latents, self.vae)

    training_batch = self._prepare_dit_inputs(training_batch, generator)
    training_batch = self._build_attention_metadata(training_batch)

    # copy.copy (not deepcopy) keeps the lru_cache'd LongTensor index fields
    # shared between the dense and VSA views; only the float VSA_sparsity
    # differs. deepcopy would re-materialize all four cached index tensors
    # every step.
    dense_metadata = training_batch.attn_metadata
    training_batch.attn_metadata_vsa = copy.copy(dense_metadata)
    if dense_metadata is not None:
        dense_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

    return training_batch
fastvideo.train.models.longcat

LongCat model plugin package.

Classes
Modules
fastvideo.train.models.longcat.longcat

LongCat model plugin (per-role instance).

Classes
fastvideo.train.models.longcat.longcat.LongCatModel
LongCatModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 12.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

LongCat per-role model for training and distillation.

Source code in fastvideo/train/models/longcat/longcat.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 12.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Create a LongCat role model for training and distillation.

    ``flow_shift`` (default 12.0) is first passed through
    ``self._validate_flow_shift``; the checked value and every other
    argument are then forwarded to ``WanModel.__init__``.
    """
    checked_shift = self._validate_flow_shift(flow_shift)
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=checked_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
Methods:
fastvideo.train.models.longcat.longcat.LongCatModel.predict_noise
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Adapt LongCat's sign convention to FineTuneMethod's target.

LongCatTransformer3DModel is pretrained to output the clean - noise direction; LongCatDenoisingStage (the bidirectional inference pipeline) explicitly negates the transformer output before handing it to FlowMatchEulerDiscreteScheduler.step. Training methods on the other hand (FineTuneMethod, DiffusionForcingSFTMethod) target noise - clean directly (the standard rectified-flow velocity Wan uses).

Without the negation here, the loss MSE pushes the transformer toward noise - clean, flipping its native output sign over training. Inference then applies its own negation on top, so the scheduler receives the wrong direction and produces noise even while the training loss is dropping. Verified empirically on a 100-step LongCat overfit run: step 0 generated meaningful video, step 100 was pure noise despite low loss.

Negating in predict_noise keeps the transformer's pretrained sign convention intact while presenting the training methods with a Wan-compatible pred ≈ noise - clean for MSE.

Source code in fastvideo/train/models/longcat/longcat.py
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Return LongCat's prediction negated into the ``noise - clean`` convention.

    ``LongCatTransformer3DModel`` is pretrained to emit the
    ``clean - noise`` direction. ``LongCatDenoisingStage`` (the
    bidirectional inference pipeline) negates that output itself before
    calling ``FlowMatchEulerDiscreteScheduler.step``. The training
    methods (``FineTuneMethod``, ``DiffusionForcingSFTMethod``) instead
    regress onto ``noise - clean``, the standard rectified-flow velocity
    Wan uses.

    If the raw output were trained against that target, MSE would push
    the transformer to flip its native sign. Inference would then negate
    again, so the scheduler would step in the wrong direction and produce
    noise while the loss kept falling. This was observed on a 100-step
    LongCat overfit run: meaningful video at step 0, pure noise at step
    100 despite low loss.

    Negating here leaves the pretrained sign convention untouched while
    giving training code a Wan-compatible ``pred ≈ noise - clean``.
    """
    base_pred = super().predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    # clean - noise  ->  noise - clean
    return base_pred.neg()
fastvideo.train.models.wan

Wan model plugin package.

Classes
Modules
fastvideo.train.models.wan.wan

Wan model plugin (per-role instance).

Classes
fastvideo.train.models.wan.wan.WanModel
WanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None, lora: LoraConfig | dict[str, Any] | None = None)

Bases: ModelBase

Wan per-role model: owns transformer + noise_scheduler.

Source code in fastvideo/train/models/wan/wan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Load the Wan transformer and create its flow-matching scheduler.

    Preprocessing state (``vae``, ``dataloader``, ``validator``,
    ``start_step``) starts empty and is filled by ``init_preprocessors``
    on the student role only. Negative-prompt tensors start as ``None``.
    Timestep bounds are ``[0, noise_scheduler.num_train_timesteps]``.
    """
    super().__init__(trainable=trainable, lora=lora)
    self._init_from = str(init_from)

    self.transformer = self._load_transformer(
        init_from=self._init_from,
        trainable=self._trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        training_config=training_config,
        transformer_override_safetensor=transformer_override_safetensor,
    )

    shift = float(flow_shift)
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)

    # Populated later by init_preprocessors (student role only).
    self.vae: Any = None
    self.training_config: TrainingConfig = training_config
    self.dataloader: Any = None
    self.validator: Any = None
    self.start_step: int = 0

    # Distributed groups; None until assigned (presumably elsewhere).
    self.world_group: Any = None
    self.sp_group: Any = None

    # Unconditional prompt tensors; None until built.
    self.negative_prompt_embeds: torch.Tensor | None = None
    self.negative_prompt_attention_mask: torch.Tensor | None = None

    # Timestep mechanics.
    self.timestep_shift: float = shift
    self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
    self.min_timestep: int = 0
    self.max_timestep: int = self.num_train_timestep
Functions:
fastvideo.train.models.wan.wan_causal

Wan causal model plugin (per-role instance, streaming/cache).

Classes
fastvideo.train.models.wan.wan_causal.WanCausalModel
WanCausalModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None, lora: LoraConfig | dict[str, Any] | None = None)

Bases: WanModel, CausalModelBase

Wan per-role model with causal/streaming primitives.

Source code in fastvideo/train/models/wan/wan_causal.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Create a Wan role model that also supports causal/streaming caches.

    All arguments are forwarded verbatim to ``WanModel.__init__``. After
    that, an empty streaming-cache registry is created.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
        lora=lora,
    )
    # Streaming caches keyed by (int, str) tuples; starts empty.
    self._streaming_caches: dict[tuple[int, str], _StreamingCaches] = {}
Functions:

fastvideo.train.trainer

Classes

Functions:

fastvideo.train.utils

Modules

fastvideo.train.utils.builder

Assembly: build method + dataloader from a _target_-based config.

Classes
Functions:
fastvideo.train.utils.builder.build_from_config
build_from_config(cfg: RunConfig) -> tuple[TrainingConfig, TrainingMethod, Any, int]

Build method + dataloader from a v3 run config.

  1. Instantiate each model in cfg.models via _target_.
  2. Resolve the method class from cfg.method["_target_"] and construct it with (cfg=cfg, role_models=...).
  3. Return (training_args, method, dataloader, start_step).
Source code in fastvideo/train/utils/builder.py
def build_from_config(cfg: RunConfig) -> tuple[TrainingConfig, TrainingMethod, Any, int]:
    """Assemble the training method and its dataloader from a v3 run config.

    Steps:
      1. Instantiate every entry of ``cfg.models`` via its ``_target_``,
         passing ``training_config=cfg.training``, and check that each
         result is a ``ModelBase``.
      2. Resolve ``cfg.method["_target_"]`` to a class and construct it as
         ``method_cls(cfg=cfg, role_models=role_models)``.
      3. Read ``dataloader`` and ``start_step`` from the ``"student"`` role
         once the method has been constructed.

    Returns:
        ``(cfg.training, method, dataloader, start_step)``. ``dataloader``
        is ``None`` and ``start_step`` is ``0`` when there is no student
        or it lacks those attributes.

    Raises:
        TypeError: if a model ``_target_`` does not resolve to a
            ``ModelBase`` subclass.
    """
    from fastvideo.train.models.base import ModelBase

    role_models: dict[str, ModelBase] = {}
    for role, model_cfg in cfg.models.items():
        model = instantiate(model_cfg, training_config=cfg.training)
        if not isinstance(model, ModelBase):
            raise TypeError(f"models.{role}._target_ must resolve to a "
                            f"ModelBase subclass, got {type(model).__name__}")
        role_models[role] = model

    method_cls = resolve_target(str(dict(cfg.method)["_target_"]))

    # Grab the student before the method is built. Its dataloader and
    # start_step are read only afterwards, because the method's __init__
    # is what invokes the student's init_preprocessors.
    student = role_models.get("student")
    method = method_cls(cfg=cfg, role_models=role_models)

    if student is None:
        dataloader, start_step = None, 0
    else:
        dataloader = getattr(student, "dataloader", None)
        start_step = int(getattr(student, "start_step", 0))

    return cfg.training, method, dataloader, start_step
fastvideo.train.utils.checkpoint
Classes
fastvideo.train.utils.checkpoint.CheckpointManager
CheckpointManager(*, method: Any, dataloader: Any, output_dir: str, config: CheckpointConfig, callbacks: Any | None = None, raw_config: dict[str, Any] | None = None)

Role-based checkpoint manager for training runtime.

  • Checkpoint policy lives in YAML (via TrainingArgs fields).
  • Resume path is typically provided via CLI (--resume-from-checkpoint).
Source code in fastvideo/train/utils/checkpoint.py
def __init__(
    self,
    *,
    method: Any,
    dataloader: Any,
    output_dir: str,
    config: CheckpointConfig,
    callbacks: Any | None = None,
    raw_config: dict[str, Any] | None = None,
) -> None:
    """Keep references to everything needed to save and restore checkpoints.

    Args:
        method: Training method; its ``cuda_generator`` is used when
            restoring RNG state.
        dataloader: Dataloader associated with the run.
        output_dir: Checkpoint root directory; stored as ``str``.
        config: Checkpoint policy.
        callbacks: Optional callbacks container, kept privately.
        raw_config: Optional raw config dict, kept privately.
    """
    self.method = method
    self.dataloader = dataloader
    self.output_dir = str(output_dir)
    self.config = config
    self._callbacks, self._raw_config = callbacks, raw_config
    # No checkpoint has been written by this manager yet.
    self._last_saved_step: int | None = None
Methods:
fastvideo.train.utils.checkpoint.CheckpointManager.load_metadata staticmethod
load_metadata(checkpoint_dir: str | Path) -> dict[str, Any]

Read metadata.json from a checkpoint dir.

Source code in fastvideo/train/utils/checkpoint.py
@staticmethod
def load_metadata(checkpoint_dir: str | Path, ) -> dict[str, Any]:
    """Read ``metadata.json`` from a checkpoint dir."""
    meta_path = Path(checkpoint_dir) / "metadata.json"
    if not meta_path.is_file():
        raise FileNotFoundError(f"No metadata.json in {checkpoint_dir}")
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f)  # type: ignore[no-any-return]
fastvideo.train.utils.checkpoint.CheckpointManager.load_rng_snapshot
load_rng_snapshot(checkpoint_path: str) -> None

Restore per-rank RNG state from the snapshot file.

Must be called AFTER dcp.load and after iter(dataloader) so no later operation can clobber the restored state.

Source code in fastvideo/train/utils/checkpoint.py
def load_rng_snapshot(
    self,
    checkpoint_path: str,
) -> None:
    """Restore per-rank RNG state from the snapshot file.

    Must be called AFTER ``dcp.load`` **and** after
    ``iter(dataloader)`` so no later operation can
    clobber the restored state.

    The snapshot is ``rng_state_rank{rank}.pt`` in the resolved checkpoint
    directory, falling back to the legacy ``rng_state.pt``. Each stream is
    restored only when its key is present: torch CPU (``torch_rng``),
    Python ``random`` (``python_rng``), NumPy (``numpy_rng``), CUDA
    (``cuda_rng``) and the method's ``cuda_generator`` (``gen_cuda``).
    Missing CUDA entries are logged as warnings instead of raising
    ``KeyError`` part-way through the restore.

    Args:
        checkpoint_path: Checkpoint path as provided (e.g. via
            ``--resume-from-checkpoint``), resolved with
            ``_resolve_resume_checkpoint``. Does nothing if it does not
            resolve.
    """
    resolved = _resolve_resume_checkpoint(
        checkpoint_path,
        output_dir=self.output_dir,
    )
    if resolved is None:
        return
    rank = _rank()
    rng_path = resolved / f"rng_state_rank{rank}.pt"
    if not rng_path.is_file():
        # Fall back to legacy single-file snapshot.
        rng_path = resolved / "rng_state.pt"
    if not rng_path.is_file():
        logger.warning(
            "No rng_state in %s; skipping "
            "RNG snapshot restore.",
            resolved,
        )
        return

    # weights_only=False unpickles arbitrary objects (needed for the Python
    # / NumPy RNG tuples). Only resume from checkpoints you trust.
    rng = torch.load(
        rng_path,
        map_location="cpu",
        weights_only=False,
    )
    if "torch_rng" in rng:
        torch.set_rng_state(rng["torch_rng"])
    if "python_rng" in rng:
        random.setstate(rng["python_rng"])
    if "numpy_rng" in rng:
        np.random.set_state(rng["numpy_rng"])

    # Guard the CUDA entries like the others: legacy or partial snapshots
    # may lack them, and an unconditional lookup would abort the restore
    # with a KeyError after the CPU streams were already overwritten.
    if "cuda_rng" in rng:
        torch.cuda.set_rng_state(rng["cuda_rng"])
    else:
        logger.warning(
            "No cuda_rng in %s; CUDA RNG state not restored.",
            rng_path,
        )
    if "gen_cuda" in rng:
        self.method.cuda_generator.set_state(rng["gen_cuda"])
    else:
        logger.warning(
            "No gen_cuda in %s; method cuda_generator not restored.",
            rng_path,
        )
    logger.info(
        "Restored RNG snapshot from %s",
        rng_path,
    )
Functions:
fastvideo.train.utils.config

Training run config (_target_ based YAML).

Classes
fastvideo.train.utils.config.RunConfig dataclass
RunConfig(models: dict[str, dict[str, Any]], method: dict[str, Any], training: TrainingConfig, callbacks: dict[str, dict[str, Any]], raw: dict[str, Any])

Parsed run config loaded from YAML.

Methods:
fastvideo.train.utils.config.RunConfig.resolved_config
resolved_config() -> dict[str, Any]

Return a fully-resolved config dict with defaults.

Suitable for logging to W&B so that every parameter (including defaults) is visible.

Source code in fastvideo/train/utils/config.py
def resolved_config(self) -> dict[str, Any]:
    """Build a plain-dict snapshot of this run config, defaults included.

    ``models``, ``method`` and ``callbacks`` are shallow-copied as-is.
    ``training`` is expanded recursively: dataclass instances become
    dicts of their fields (fields holding callables are left out), and
    dicts, lists and tuples are walked element by element. The result
    is meant for logging (e.g. to W&B), so every parameter is visible.
    """
    import dataclasses

    def _to_plain(value: Any) -> Any:
        # Dataclass *instances* only; a dataclass type is returned untouched.
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            expanded = {}
            for field in dataclasses.fields(value):
                attr = getattr(value, field.name)
                if callable(attr):
                    continue
                expanded[field.name] = _to_plain(attr)
            return expanded
        if isinstance(value, dict):
            return {k: _to_plain(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            # Rebuild with the same container type (list stays list, tuple stays tuple).
            return type(value)(map(_to_plain, value))
        return value

    return {
        "models": dict(self.models),
        "method": dict(self.method),
        "training": _to_plain(self.training),
        "callbacks": dict(self.callbacks),
    }
Functions:
fastvideo.train.utils.config.load_run_config
load_run_config(path: str, overrides: list[str] | None = None) -> RunConfig

Load a run config from YAML.

Expected top-level keys: models, method, training (nested), and optionally callbacks and pipeline.

Parameters:

Name Type Description Default
path str

Path to the YAML config file.

required
overrides list[str] | None

Optional list of CLI override tokens, e.g. ["--training.distributed.num_gpus", "4"]. Dotted keys map to nested YAML paths.

None
Source code in fastvideo/train/utils/config.py
def load_run_config(
    path: str,
    overrides: list[str] | None = None,
) -> RunConfig:
    """Parse a YAML run config into a ``RunConfig``.

    Sections: ``models`` (each role needs a ``_target_``), ``method``
    (needs a ``_target_``) and a nested ``training`` mapping;
    ``callbacks`` and ``pipeline`` are optional. Section shapes are
    checked with ``_require_mapping``.

    Args:
        path: Location of the YAML config file.
        overrides: Optional CLI override tokens such as
            ``["--training.distributed.num_gpus", "4"]``. Dotted keys
            address nested YAML paths; they are applied to the raw
            mapping before any typed config is built.

    Returns:
        A ``RunConfig`` with the validated sections and the raw
        (override-applied) mapping in ``raw``.

    Raises:
        ValueError: if a ``models`` entry or ``method`` lacks ``_target_``.
    """
    path = _resolve_existing_file(path)
    with open(path, encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    cfg = _require_mapping(loaded, where=path)

    if overrides:
        # Overrides land on the raw mapping, so every section below sees them.
        parsed_overrides = _parse_cli_overrides(overrides)
        _apply_overrides(cfg, parsed_overrides)
        logger.info("Applied CLI overrides: %s", parsed_overrides)

    # models: role name -> constructor mapping with a mandatory _target_.
    models: dict[str, dict[str, Any]] = {}
    models_section = _require_mapping(cfg.get("models"), where="models")
    for role_key, entry in models_section.items():
        role = _require_str(role_key, where="models.<role>")
        entry_map = _require_mapping(entry, where=f"models.{role}")
        if "_target_" not in entry_map:
            raise ValueError(f"models.{role} must have a "
                             "'_target_' key")
        models[role] = dict(entry_map)

    method_section = _require_mapping(cfg.get("method"), where="method")
    if "_target_" not in method_section:
        raise ValueError("method must have a '_target_' key")
    method = dict(method_section)

    callbacks_section = cfg.get("callbacks", None)
    callbacks: dict[str, dict[str, Any]] = ({} if callbacks_section is None else _require_mapping(
        callbacks_section, where="callbacks"))

    # The pipeline config feeds into the typed training config below.
    pipeline_config = _parse_pipeline_config(cfg, models=models)

    training_section = _require_mapping(cfg.get("training"), where="training")
    training = _build_training_config(
        dict(training_section),
        models=models,
        pipeline_config=pipeline_config,
    )

    return RunConfig(
        models=models,
        method=method,
        training=training,
        callbacks=callbacks,
        raw=cfg,
    )
fastvideo.train.utils.config.require_bool
require_bool(mapping: dict[str, Any], key: str, *, default: bool | None = None, where: str | None = None) -> bool

Read a bool value.

Source code in fastvideo/train/utils/config.py
def require_bool(
    mapping: dict[str, Any],
    key: str,
    *,
    default: bool | None = None,
    where: str | None = None,
) -> bool:
    """Read a bool value."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            return default
        raise ValueError(f"Missing required key {loc!r}")
    if not isinstance(raw, bool):
        raise ValueError(f"{loc} must be a bool, "
                         f"got {type(raw).__name__}")
    return raw
fastvideo.train.utils.config.require_choice
require_choice(mapping: dict[str, Any], key: str, choices: set[str] | frozenset[str], *, default: str | None = None, where: str | None = None) -> str

Read a string that must be one of choices.

Source code in fastvideo/train/utils/config.py
def require_choice(
    mapping: dict[str, Any],
    key: str,
    choices: set[str] | frozenset[str],
    *,
    default: str | None = None,
    where: str | None = None,
) -> str:
    """Read a string that must be one of *choices*."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            if default not in choices:
                raise ValueError(f"Default {default!r} not in {choices}")
            return default
        raise ValueError(f"Missing required key {loc!r}")
    if not isinstance(raw, str) or not raw.strip():
        raise ValueError(f"{loc} must be a non-empty string, "
                         f"got {type(raw).__name__}")
    val = raw.strip().lower()
    if val not in choices:
        raise ValueError(f"{loc} must be one of {sorted(choices)}, "
                         f"got {raw!r}")
    return val
fastvideo.train.utils.config.require_non_negative_float
require_non_negative_float(mapping: dict[str, Any], key: str, *, default: float | None = None, where: str | None = None) -> float

Read a float that must be >= 0.

Source code in fastvideo/train/utils/config.py
def require_non_negative_float(
    mapping: dict[str, Any],
    key: str,
    *,
    default: float | None = None,
    where: str | None = None,
) -> float:
    """Read a float that must be >= 0."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            return default
        raise ValueError(f"Missing required key {loc!r}")
    val = get_optional_float(mapping, key, where=loc)
    if val is None or val < 0.0:
        raise ValueError(f"{loc} must be a non-negative float, "
                         f"got {raw!r}")
    return val
fastvideo.train.utils.config.require_non_negative_int
require_non_negative_int(mapping: dict[str, Any], key: str, *, default: int | None = None, where: str | None = None) -> int

Read an int that must be >= 0.

Source code in fastvideo/train/utils/config.py
def require_non_negative_int(
    mapping: dict[str, Any],
    key: str,
    *,
    default: int | None = None,
    where: str | None = None,
) -> int:
    """Read an int that must be >= 0."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            return default
        raise ValueError(f"Missing required key {loc!r}")
    val = get_optional_int(mapping, key, where=loc)
    if val is None or val < 0:
        raise ValueError(f"{loc} must be a non-negative integer, "
                         f"got {raw!r}")
    return val
fastvideo.train.utils.config.require_positive_int
require_positive_int(mapping: dict[str, Any], key: str, *, default: int | None = None, where: str | None = None) -> int

Read an int that must be > 0.

Source code in fastvideo/train/utils/config.py
def require_positive_int(
    mapping: dict[str, Any],
    key: str,
    *,
    default: int | None = None,
    where: str | None = None,
) -> int:
    """Read an int that must be > 0."""
    loc = where or key
    raw = mapping.get(key)
    if raw is None:
        if default is not None:
            return default
        raise ValueError(f"Missing required key {loc!r}")
    val = get_optional_int(mapping, key, where=loc)
    if val is None or val <= 0:
        raise ValueError(f"{loc} must be a positive integer, got {raw!r}")
    return val
fastvideo.train.utils.dataloader
Functions:
fastvideo.train.utils.dataloader.build_parquet_matrixgame2_train_dataloader
build_parquet_matrixgame2_train_dataloader(data_config: DataConfig, *, parquet_schema: Any) -> Any

Build a parquet dataloader for Matrix-Game 2.0 datasets.

Source code in fastvideo/train/utils/dataloader.py
def build_parquet_matrixgame2_train_dataloader(
    data_config: DataConfig,
    *,
    parquet_schema: Any,
) -> Any:
    """Create the training dataloader for Matrix-Game 2.0 parquet data.

    Thin wrapper over ``build_parquet_map_style_dataloader``. It uses a
    fixed text padding length of 512 and ``drop_last=True``, and
    coerces an unset ``training_cfg_rate`` to ``0.0`` and an unset
    ``seed`` to ``0``. Only the dataloader is returned; the dataset is
    discarded.
    """
    from fastvideo.dataset import build_parquet_map_style_dataloader

    cfg_rate = float(data_config.training_cfg_rate or 0.0)
    seed = int(data_config.seed or 0)

    _, dataloader = build_parquet_map_style_dataloader(
        data_config.data_path,
        data_config.train_batch_size,
        num_data_workers=data_config.dataloader_num_workers,
        parquet_schema=parquet_schema,
        cfg_rate=cfg_rate,
        drop_last=True,
        text_padding_length=512,
        seed=seed,
    )
    return dataloader
fastvideo.train.utils.dataloader.build_parquet_t2v_train_dataloader
build_parquet_t2v_train_dataloader(data_config: DataConfig, *, text_len: int, parquet_schema: Any) -> Any

Build a parquet dataloader for T2V-style datasets.

Source code in fastvideo/train/utils/dataloader.py
def build_parquet_t2v_train_dataloader(
    data_config: DataConfig,
    *,
    text_len: int,
    parquet_schema: Any,
) -> Any:
    """Build a parquet dataloader for T2V-style datasets.

    Args:
        data_config: Supplies ``data_path``, ``train_batch_size``,
            ``dataloader_num_workers``, ``training_cfg_rate`` (forwarded
            as ``cfg_rate``; ``None`` becomes ``0.0``) and ``seed``
            (``None`` becomes ``0``).
        text_len: Forwarded as ``text_padding_length``.
        parquet_schema: Schema forwarded to the parquet loader.

    Returns:
        The dataloader from ``build_parquet_map_style_dataloader``
        (built with ``drop_last=True``); the dataset is discarded.
    """

    from fastvideo.dataset import (
        build_parquet_map_style_dataloader, )

    _dataset, dataloader = (build_parquet_map_style_dataloader(
        data_config.data_path,
        data_config.train_batch_size,
        num_data_workers=(data_config.dataloader_num_workers),
        parquet_schema=parquet_schema,
        # Normalize an unset (None) rate to 0.0 instead of forwarding None,
        # matching the Matrix-Game 2.0 builder.
        cfg_rate=float(data_config.training_cfg_rate or 0.0),
        drop_last=True,
        text_padding_length=int(text_len),
        seed=int(data_config.seed or 0),
    ))
    return dataloader
fastvideo.train.utils.instantiate

_target_-based instantiation utilities.

These helpers resolve a dotted Python path to a class and instantiate it, filtering constructor kwargs through inspect.signature so that only recognized parameters are forwarded. Unrecognized keys emit a warning rather than raising — this keeps YAML configs forward-compatible when a class drops a parameter in a later version.

Functions:
fastvideo.train.utils.instantiate.instantiate
instantiate(cfg: dict[str, Any], **extra: Any) -> Any

Instantiate the class specified by cfg["_target_"].

All remaining keys in cfg (minus _target_) plus any extra keyword arguments are forwarded to the constructor. Keys that do not match an __init__ parameter trigger a warning and are dropped, so callers can safely pass a superset.

Source code in fastvideo/train/utils/instantiate.py
def instantiate(cfg: dict[str, Any], **extra: Any) -> Any:
    """Instantiate the class (or call the callable) named by ``cfg["_target_"]``.

    All remaining keys in *cfg* (minus ``_target_``) plus any *extra*
    keyword arguments are forwarded to the target. *extra* wins on key
    clashes. Keys the target does not accept trigger a ``warnings.warn``
    and are dropped, so callers can safely pass a superset.

    For classes, the signature of ``__init__`` is inspected. For other
    callables (e.g. a factory function), the callable's own signature is
    used. If the signature has ``**kwargs``, nothing is dropped.

    Raises:
        TypeError: if *cfg* is not a dict.
        KeyError: if *cfg* has no ``_target_`` key.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f"instantiate() expects a dict with '_target_', "
                        f"got {type(cfg).__name__}")
    target_str = cfg.get("_target_")
    if target_str is None:
        raise KeyError("Config dict is missing '_target_' key")

    cls = resolve_target(str(target_str))
    kwargs: dict[str, Any] = {k: v for k, v in cfg.items() if k != "_target_"}
    kwargs.update(extra)

    if isinstance(cls, type):
        sig = inspect.signature(cls.__init__)  # type: ignore[misc]
    else:
        # A non-class callable's ``__init__`` is the generic wrapper accepting
        # ``**kwargs``, which would disable filtering; inspect the callable.
        try:
            sig = inspect.signature(cls)
        except (TypeError, ValueError):
            # No introspectable signature (some C callables): keep the
            # previous behaviour.
            sig = inspect.signature(cls.__init__)  # type: ignore[misc]
    params = sig.parameters
    has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

    if not has_var_keyword:
        valid_names = {
            name
            for name, p in params.items() if p.kind in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
        }
        # ``self`` shows up in the unbound ``__init__`` signature.
        valid_names.discard("self")
        unrecognized = set(kwargs) - valid_names
        if unrecognized:
            warnings.warn(
                f"instantiate({target_str}): dropping unrecognized "
                f"kwargs {sorted(unrecognized)}",
                stacklevel=2,
            )
            for key in unrecognized:
                del kwargs[key]

    return cls(**kwargs)
fastvideo.train.utils.instantiate.resolve_target
resolve_target(target: str) -> type

Import and return the class (or callable) at target.

target must be a fully-qualified dotted path, e.g. "fastvideo.train.models.wan.wan.WanModel".

Source code in fastvideo/train/utils/instantiate.py
def resolve_target(target: str) -> type:
    """Import a module and return the object at the end of a dotted path.

    *target* looks like ``"package.module.Name"`` (surrounding
    whitespace ignored), e.g. ``"fastvideo.train.models.wan.wan.WanModel"``.
    Everything before the last dot is imported as a module and the last
    component is fetched from it.

    Raises:
        ValueError: if *target* is not a non-empty string or has no dot.
        ImportError: if the module is not found or lacks the attribute.
    """
    if not (isinstance(target, str) and target.strip()):
        raise ValueError(f"_target_ must be a non-empty dotted path string, "
                         f"got {target!r}")
    dotted = target.strip()
    module_path, dot, attr_name = dotted.rpartition(".")
    if not dot:
        raise ValueError(f"_target_ must contain at least one dot "
                         f"(module.ClassName), got {dotted!r}")
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as exc:
        raise ImportError(f"Cannot import module {module_path!r} "
                          f"(from _target_={dotted!r})") from exc
    try:
        return getattr(module, attr_name)
    except AttributeError as exc:
        raise ImportError(f"Module {module_path!r} has no attribute "
                          f"{attr_name!r} (from _target_={dotted!r})") from exc
fastvideo.train.utils.lora

Training-side LoRA utilities for fastvideo.train model plugins.

Classes
fastvideo.train.utils.lora.LoraConfig dataclass
LoraConfig(enable: bool = False, rank: int | None = None, alpha: int | None = None, target_modules: list[str] | None = None)

Structured LoRA settings for one fastvideo.train model role.

Parsed from the nested models.<role>.lora YAML block::

lora:
  enable: true                       # default false
  rank: 16
  alpha: 32                          # defaults to rank when omitted
  target_modules: [to_q, to_k, to_v, to_out]

enable is an explicit on/off switch so a config states its intent plainly: the presence of rank alone never silently flips a run into LoRA-only training. When enable is false a still-present rank is ignored (with an INFO log), so a configured-but-off block is valid.

Methods:
fastvideo.train.utils.lora.LoraConfig.coerce classmethod
coerce(obj: LoraConfig | dict[str, Any] | None) -> LoraConfig | None

Normalize a raw YAML mapping (or existing config) into a LoraConfig.

Returns None when no lora block was given, which callers treat as "LoRA not configured" — identical in effect to enable: false.

Source code in fastvideo/train/utils/lora.py
@classmethod
def coerce(
    cls,
    obj: LoraConfig | dict[str, Any] | None,
) -> LoraConfig | None:
    """Normalize a raw YAML mapping (or existing config) into a LoraConfig.

    Returns ``None`` when no ``lora`` block was given, which callers treat
    as "LoRA not configured" — identical in effect to ``enable: false``.
    An existing ``LoraConfig`` is returned unchanged. Unknown keys are
    logged and ignored.

    ``enable`` must be a real YAML boolean (or ``null``/absent, meaning
    off). A quoted string such as ``"false"`` is rejected rather than
    coerced with ``bool()``, which would silently turn LoRA on.

    Raises:
        TypeError: if *obj* is neither ``None``, a dict, nor a LoraConfig.
        ValueError: if ``enable`` is present but not a bool.
    """
    if obj is None:
        return None
    if isinstance(obj, LoraConfig):
        return obj
    if not isinstance(obj, dict):
        raise TypeError("models.<role>.lora must be a mapping or LoraConfig, got "
                        f"{type(obj).__name__}")
    unknown = set(obj) - set(_LORA_CONFIG_KEYS)
    if unknown:
        logger.warning("LoraConfig: ignoring unrecognized lora keys %s "
                       "(valid keys: %s)", sorted(unknown), list(_LORA_CONFIG_KEYS))
    enable = obj.get("enable", False)
    if enable is None:
        # ``enable: null`` keeps its previous meaning: off.
        enable = False
    elif not isinstance(enable, bool):
        raise ValueError("models.<role>.lora.enable must be a bool, got "
                         f"{type(enable).__name__} ({enable!r})")
    return cls(
        enable=enable,
        rank=obj.get("rank"),
        alpha=obj.get("alpha"),
        target_modules=obj.get("target_modules"),
    )
Functions:
fastvideo.train.utils.lora.enable_lora_training
enable_lora_training(transformer: Module, *, lora_rank: int, lora_alpha: int | None = None, lora_target_modules: Sequence[str] | None = None) -> int

Replace supported linear layers with trainable LoRA wrappers.

Returns the number of layers converted to LoRA.

Source code in fastvideo/train/utils/lora.py
def enable_lora_training(
    transformer: torch.nn.Module,
    *,
    lora_rank: int,
    lora_alpha: int | None = None,
    lora_target_modules: Sequence[str] | None = None,
) -> int:
    """Replace supported linear layers with trainable LoRA wrappers.

    The whole ``transformer`` is frozen first with ``requires_grad_(False)``.
    Every named sub-module selected by ``_is_target_layer`` and not
    excluded by ``config.arch_config.exclude_lora_layers`` is swapped in
    place for the layer ``get_lora_layer`` returns. Modules for which it
    returns ``None`` are skipped. The model is left in train mode.

    Args:
        transformer: Model to modify in place.
        lora_rank: LoRA rank; must be > 0.
        lora_alpha: LoRA alpha; defaults to ``lora_rank`` when ``None``.
        lora_target_modules: Target selectors passed to
            ``_is_target_layer``; ``None`` or empty falls back to
            ``DEFAULT_LORA_TARGET_MODULES``.

    Returns:
        The number of layers converted to LoRA.

    Raises:
        ValueError: if ``lora_rank`` <= 0 or no compatible layer matched.
    """

    rank = int(lora_rank)
    if rank <= 0:
        raise ValueError(f"lora_rank must be > 0, got {lora_rank!r}")

    alpha = int(lora_alpha) if lora_alpha is not None else rank
    target_modules = list(lora_target_modules or DEFAULT_LORA_TARGET_MODULES)
    arch_config = getattr(
        getattr(transformer, "config", None),
        "arch_config",
        None,
    )
    # ``or []`` also covers an arch config that sets exclude_lora_layers=None,
    # which ``list(None)`` would reject with a TypeError.
    excluded_modules = list(getattr(arch_config, "exclude_lora_layers", None) or [])

    # Freeze the base model; presumably get_lora_layer(training_mode=True)
    # makes only the adapter weights trainable — confirm in its implementation.
    transformer.requires_grad_(False)

    # Collect first, replace afterwards, so the module tree is not mutated
    # while named_modules() is still iterating over it.
    replacements: list[tuple[str, BaseLayerWithLoRA]] = []
    for module_name, module in transformer.named_modules():
        if not module_name:
            continue
        if not _is_target_layer(module_name, target_modules):
            continue
        if _is_excluded_layer(module_name, excluded_modules):
            continue

        lora_layer = get_lora_layer(
            module,
            lora_rank=rank,
            lora_alpha=alpha,
            training_mode=True,
        )
        if lora_layer is None:
            continue
        replacements.append((module_name, lora_layer))

    if not replacements:
        raise ValueError("No LoRA-compatible layers were found for the requested "
                         f"target modules: {target_modules}")

    for module_name, lora_layer in replacements:
        replace_submodule(transformer, module_name, lora_layer)

    _replicate_lora_parameters(transformer)
    transformer.train()

    logger.info(
        "Enabled LoRA training with rank=%d alpha=%d on %d layers",
        rank,
        alpha,
        len(replacements),
    )
    return len(replacements)
fastvideo.train.utils.module_state
Functions:
fastvideo.train.utils.module_state.apply_trainable
apply_trainable(module: Module, *, trainable: bool) -> Module

Apply train/eval mode + requires_grad based on a role's trainable flag.

Source code in fastvideo/train/utils/module_state.py
def apply_trainable(module: torch.nn.Module, *, trainable: bool) -> torch.nn.Module:
    """Put *module* into train or eval mode and toggle ``requires_grad`` to match.

    A truthy *trainable* enables gradients on all parameters and calls
    ``module.train()``; a falsy one disables them and calls
    ``module.eval()``. The same module object is returned.
    """
    wants_grad = bool(trainable)
    module.requires_grad_(wants_grad)
    set_mode = module.train if wants_grad else module.eval
    set_mode()
    return module
fastvideo.train.utils.moduleloader
Classes
Functions:
fastvideo.train.utils.moduleloader.load_module_from_path
load_module_from_path(*, model_path: str, module_type: str, training_config: TrainingConfig, disable_custom_init_weights: bool = False, override_transformer_cls_name: str | None = None, transformer_override_safetensor: str | None = None) -> Module

Load a single pipeline component module.

Accepts a TrainingConfig and internally builds the TrainingArgs needed by PipelineComponentLoader.

Source code in fastvideo/train/utils/moduleloader.py
def load_module_from_path(
    *,
    model_path: str,
    module_type: str,
    training_config: TrainingConfig,
    disable_custom_init_weights: bool = False,
    override_transformer_cls_name: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> torch.nn.Module:
    """Load a single pipeline component module.

    Accepts a ``TrainingConfig`` and internally builds the
    ``TrainingArgs`` needed by ``PipelineComponentLoader``.

    Args:
        model_path: Model identifier or path; resolved to a local
            directory with ``maybe_download_model``.
        module_type: Key of the component in the model config; also
            the sub-directory of the local model path that is loaded.
        training_config: Source for the ``TrainingArgs`` handed to the
            loader.
        disable_custom_init_weights: When true, sets
            ``_loading_teacher_critic_model`` on the args for the
            duration of the load and removes it afterwards.
            NOTE(review): presumably makes the loader skip custom
            init weights (teacher/critic loading) — confirm in the loader.
        override_transformer_cls_name: If given, temporarily replaces
            ``override_transformer_cls_name`` on the args; the previous
            value is restored after loading.
        transformer_override_safetensor: If truthy, stored as
            ``init_weights_from_safetensors`` on the args (this one is
            not restored afterwards).

    Returns:
        The loaded component as a ``torch.nn.Module``.

    Raises:
        ValueError: if ``module_type`` is missing from the model config
            or maps to ``None``.
        TypeError: if the loader returns something that is not a
            ``torch.nn.Module``.
    """
    fastvideo_args: Any = _make_training_args(training_config, model_path=model_path)

    local_model_path = maybe_download_model(model_path)
    config = verify_model_config_and_directory(local_model_path)

    if module_type not in config:
        raise ValueError(f"Module {module_type!r} not found in "
                         f"config at {local_model_path}")

    module_info = config[module_type]
    if module_info is None:
        raise ValueError(f"Module {module_type!r} has null value in "
                         f"config at {local_model_path}")

    # Config entries are unpacked as a 2-item pair; only the first
    # (library kind) is used by the loader call below.
    transformers_or_diffusers, _architecture = module_info
    component_path = os.path.join(local_model_path, module_type)

    # Remember the previous override so it can be restored in ``finally``.
    old_override: str | None = None
    if override_transformer_cls_name is not None:
        old_override = getattr(
            fastvideo_args,
            "override_transformer_cls_name",
            None,
        )
        fastvideo_args.override_transformer_cls_name = str(override_transformer_cls_name)

    if transformer_override_safetensor:
        fastvideo_args.init_weights_from_safetensors = str(transformer_override_safetensor)

    if disable_custom_init_weights:
        fastvideo_args._loading_teacher_critic_model = True
    try:
        module = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_path,
            transformers_or_diffusers=(transformers_or_diffusers),
            fastvideo_args=fastvideo_args,
        )
    finally:
        # Undo the temporary attribute changes even if loading failed.
        if disable_custom_init_weights and hasattr(fastvideo_args, "_loading_teacher_critic_model"):
            del fastvideo_args._loading_teacher_critic_model
        if override_transformer_cls_name is not None:
            if old_override is None:
                # No prior override: reset to None rather than deleting the attribute.
                if hasattr(
                        fastvideo_args,
                        "override_transformer_cls_name",
                ):
                    fastvideo_args.override_transformer_cls_name = (None)
            else:
                fastvideo_args.override_transformer_cls_name = (old_override)

    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"Loaded {module_type!r} is not a "
                        f"torch.nn.Module: {type(module)}")
    return module
fastvideo.train.utils.moduleloader.make_inference_args
make_inference_args(tc: TrainingConfig, *, model_path: str) -> TrainingArgs

Build a TrainingArgs for inference (validation / pipelines).

Source code in fastvideo/train/utils/moduleloader.py
def make_inference_args(
    tc: TrainingConfig,
    *,
    model_path: str,
) -> TrainingArgs:
    """Derive inference-mode ``TrainingArgs`` for validation / pipelines.

    Starts from ``_make_training_args(tc, model_path=model_path)`` and
    switches it over to inference:
    - ``inference_mode`` is turned on;
    - ``mode`` is set to ``ExecutionMode.INFERENCE``;
    - DiT CPU offload is enabled;
    - ``VSA_sparsity`` is copied from ``tc.vsa_sparsity``.
    """
    inference_args = _make_training_args(tc, model_path=model_path)
    inference_overrides = {
        "inference_mode": True,
        "mode": ExecutionMode.INFERENCE,
        "dit_cpu_offload": True,
        "VSA_sparsity": tc.vsa_sparsity,
    }
    for attr_name, attr_value in inference_overrides.items():
        setattr(inference_args, attr_name, attr_value)
    return inference_args
fastvideo.train.utils.negative_prompt

Per-rank negative-prompt encoding shared by training model plugins.

Encoding the negative prompt only on rank 0 and broadcasting (the previous Wan path) ran Pipeline.from_pretrained asymmetrically across ranks, which deadlocked on any collective fired during text-encoder load (FSDP device-mesh init, weight broadcast, etc.). The text encoder is small and only loaded once at startup, so loading it on every rank sidesteps the deadlock entirely.

Classes
Functions:
fastvideo.train.utils.negative_prompt.encode_negative_prompt
encode_negative_prompt(training_config: TrainingConfig, *, prompt: str, device: device, dtype: dtype, encoder_index: int = 0) -> tuple[Tensor, Tensor]

Per-rank encode of prompt using encoder encoder_index.

Reads pipeline_config.text_encoder_configs[encoder_index] so the encoder class (e.g. UMT5 for Wan) and tokenizer kwargs match the inference path, and applies the matching postprocess_text_funcs entry. Returns (embeds, mask) on device cast to dtype.

Source code in fastvideo/train/utils/negative_prompt.py
def encode_negative_prompt(
    training_config: TrainingConfig,
    *,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    encoder_index: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-rank encode of ``prompt`` using encoder ``encoder_index``.

    Reads ``pipeline_config.text_encoder_configs[encoder_index]`` so the
    encoder class (e.g. UMT5 for Wan) and tokenizer kwargs match the
    inference path, and applies the matching ``postprocess_text_funcs``
    entry. Returns ``(embeds, mask)`` on ``device`` cast to ``dtype``.

    The text encoder and tokenizer are loaded from ``tc.model_path`` on
    every call and released before returning.

    Args:
        training_config: Must provide ``pipeline_config`` and ``model_path``.
        prompt: Text to encode. It is run through the matching
            ``preprocess_text_funcs`` entry first when the pipeline
            config defines that attribute.
        device: Device the encoder runs on and the outputs are moved to.
        dtype: dtype of both returned tensors (the mask is cast too).
        encoder_index: Which text encoder to use. Index 0 maps to the
            ``text_encoder`` / ``tokenizer`` sub-directories; index ``i > 0``
            maps to ``text_encoder_{i+1}`` / ``tokenizer_{i+1}``.

    Returns:
        ``(embeds, mask)``: the postprocessed encoder output and the
        tokenizer attention mask.

    Raises:
        ValueError: if ``training_config.pipeline_config`` is ``None``.
        IndexError: if ``encoder_index`` is outside ``text_encoder_configs``.
    """
    tc = training_config
    pipeline_config = tc.pipeline_config
    if pipeline_config is None:
        raise ValueError("training_config.pipeline_config is required for negative "
                         "prompt encoding")

    encoder_configs = pipeline_config.text_encoder_configs
    postprocess_funcs = pipeline_config.postprocess_text_funcs
    # Optional on the pipeline config; absent means "no preprocessing".
    preprocess_funcs = getattr(pipeline_config, "preprocess_text_funcs", None)

    if encoder_index < 0 or encoder_index >= len(encoder_configs):
        raise IndexError(f"encoder_index {encoder_index} out of range for "
                         f"text_encoder_configs (len={len(encoder_configs)})")
    # NOTE(review): only text_encoder_configs is range-checked; this assumes the
    # postprocess/preprocess function lists are at least as long — confirm.
    encoder_config = encoder_configs[encoder_index]
    postprocess_text = postprocess_funcs[encoder_index]
    preprocess_text = (preprocess_funcs[encoder_index] if preprocess_funcs is not None else None)

    # HF convention: text_encoder / tokenizer for index 0,
    # text_encoder_2 / tokenizer_2 for index 1, etc.
    suffix = "" if encoder_index == 0 else f"_{encoder_index + 1}"
    encoder_subdir = f"text_encoder{suffix}"
    tokenizer_subdir = f"tokenizer{suffix}"

    model_path = maybe_download_model(tc.model_path)
    inference_args = make_inference_args(tc, model_path=model_path)
    # Keep the encoder on-device; CPU offload would init an FSDP device
    # mesh and reintroduce the collective at load time.
    inference_args.text_encoder_cpu_offload = False

    loader = TextEncoderLoader()
    text_encoder = loader.load(
        os.path.join(model_path, encoder_subdir),
        inference_args,
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, tokenizer_subdir))

    # Copy so the pipeline config's tokenizer kwargs are never mutated.
    tok_kwargs = dict(encoder_config.tokenizer_kwargs)
    text = preprocess_text(prompt) if preprocess_text is not None else prompt

    # Inference only: no autograd; the forward context mirrors what the
    # encoder sees during pipeline text encoding (timestep 0, no attention metadata).
    with torch.no_grad(), set_forward_context(
            current_timestep=0,
            attn_metadata=None,
    ):
        text_inputs = tokenizer(text, **tok_kwargs).to(device)
        outputs = text_encoder(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
        )
        # Mirror TextEncodingStage: postprocess reads outputs.attention_mask.
        outputs.attention_mask = text_inputs["attention_mask"]
        embeds = postprocess_text(outputs).to(device=device, dtype=dtype)
        mask = text_inputs["attention_mask"].to(device=device, dtype=dtype)

    # Drop local references so the encoder can be freed; no explicit
    # device-cache cleanup is done here.
    del text_encoder, tokenizer

    return embeds, mask
fastvideo.train.utils.optimizer
Functions:
fastvideo.train.utils.optimizer.build_optimizer_and_scheduler
build_optimizer_and_scheduler(*, params: list[Parameter], optimizer_config: OptimizerConfig, loop_config: TrainingLoopConfig, learning_rate: float, betas: tuple[float, float], scheduler_name: str) -> tuple[Optimizer, object]

Build an AdamW optimizer and LR scheduler.

Returns (optimizer, lr_scheduler) so the caller can store them as method-level attributes.

Source code in fastvideo/train/utils/optimizer.py
def build_optimizer_and_scheduler(
    *,
    params: list[torch.nn.Parameter],
    optimizer_config: OptimizerConfig,
    loop_config: TrainingLoopConfig,
    learning_rate: float,
    betas: tuple[float, float],
    scheduler_name: str,
) -> tuple[torch.optim.Optimizer, object]:
    """Create an AdamW optimizer over *params* plus its LR scheduler.

    The optimizer uses *learning_rate*, *betas*,
    ``optimizer_config.weight_decay`` and a fixed ``eps`` of 1e-8. The
    scheduler comes from ``get_scheduler(scheduler_name, ...)``, which
    receives warmup, cycle, power and min-LR settings from
    *optimizer_config* and ``loop_config.max_train_steps`` as the total
    step count. It starts fresh (``last_epoch=-1``).

    Returns:
        ``(optimizer, lr_scheduler)`` for the caller to keep as
        method-level attributes.

    Raises:
        ValueError: if *params* is empty.
    """
    if not params:
        raise ValueError("No trainable parameters passed to "
                         "build_optimizer_and_scheduler")

    adamw_kwargs = {
        "lr": float(learning_rate),
        "betas": betas,
        "weight_decay": float(optimizer_config.weight_decay),
        "eps": 1e-8,
    }
    optimizer = torch.optim.AdamW(params, **adamw_kwargs)

    name = str(scheduler_name)
    schedule_kwargs = {
        "optimizer": optimizer,
        "num_warmup_steps": int(optimizer_config.lr_warmup_steps),
        "num_training_steps": int(loop_config.max_train_steps),
        "num_cycles": int(optimizer_config.lr_num_cycles),
        "power": float(optimizer_config.lr_power),
        "min_lr_ratio": float(optimizer_config.min_lr_ratio),
        "last_epoch": -1,
    }
    lr_scheduler = get_scheduler(name, **schedule_kwargs)

    return optimizer, lr_scheduler
fastvideo.train.utils.tracking
Functions:
fastvideo.train.utils.tracking.build_tracker
build_tracker(tracker_config: TrackerConfig, checkpoint_config: CheckpointConfig, *, config: dict[str, Any] | None) -> Any

Build a tracker instance for a distillation run.

Source code in fastvideo/train/utils/tracking.py
def build_tracker(
    tracker_config: TrackerConfig,
    checkpoint_config: CheckpointConfig,
    *,
    config: dict[str, Any] | None,
) -> Any:
    """Build a tracker instance for a distillation run.

    Trackers come from ``tracker_config.trackers``. When that list is
    empty and a non-empty ``project_name`` is set, W&B is added. Ranks
    other than 0 always get an empty tracker list.

    When any tracker is active, logs go under
    ``<checkpoint output_dir or cwd>/tracker`` and *config* is forwarded;
    otherwise *config* is replaced with ``None``.

    Args:
        tracker_config: Tracker list, project and run name.
        checkpoint_config: Supplies ``output_dir`` for tracker logs.
        config: Run config to record with the trackers.

    Returns:
        Whatever ``initialize_trackers`` returns.
    """

    world_group = get_world_group()

    trackers = list(tracker_config.trackers)
    # Test the project name's own truthiness: ``str(None)`` is the truthy
    # string "None", which would enable W&B with no project configured.
    if not trackers and tracker_config.project_name:
        trackers.append(Trackers.WANDB.value)
    if world_group.rank != 0:
        trackers = []

    tracker_log_dir = (checkpoint_config.output_dir or os.getcwd())
    if trackers:
        tracker_log_dir = os.path.join(tracker_log_dir, "tracker")

    tracker_config_dict = config if trackers else None
    tracker_run_name = tracker_config.run_name or None
    project = (tracker_config.project_name or "fastvideo")

    return initialize_trackers(
        trackers,
        experiment_name=project,
        config=tracker_config_dict,
        log_dir=tracker_log_dir,
        run_name=tracker_run_name,
    )
fastvideo.train.utils.training_config

Typed training config — replaces TrainingArgs.

Classes