Skip to content

methods

Classes

fastvideo.train.methods.TrainingMethod

TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

The constructor receives role_models (a dict[str, ModelBase]) and a cfg object, and builds self.role_modules for FSDP wrapping; concrete subclasses call init_preprocessors on the student.

A single shared CUDA RNG generator (cuda_generator) is created in on_train_start(). All torch.randn / torch.randint calls in methods and models must use this generator instead of relying on global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Store role models and expose their transformers for FSDP wrapping.

    Args:
        cfg: Config object; must provide ``training`` and ``method``
            attributes, and may provide ``validation``.
        role_models: Mapping from role name to model; must include
            ``"student"``.
    """
    super().__init__()
    self.tracker: Any | None = None
    # Private copy so later mutation of the caller's dict has no effect.
    self._role_models: dict[str, ModelBase] = dict(role_models)

    # Raises KeyError if no "student" role was configured.
    self.student = role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    self.validation_config: dict[str, Any] = dict(getattr(cfg, "validation", {}) or {})

    # Build nn.ModuleDict for FSDP / checkpoint visibility.
    # Only roles exposing an nn.Module ``transformer`` are registered.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        mods: dict[str, torch.nn.Module] = {}
        transformer = getattr(model, "transformer", None)
        if isinstance(transformer, torch.nn.Module):
            mods["transformer"] = transformer
        if mods:
            self.role_modules[role] = torch.nn.ModuleDict(mods)

Functions

fastvideo.train.methods.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Return DCP-ready checkpoint state for all trainable roles.

    Keys follow the convention:
    ``roles.<role>.<module>``, ``optimizers.<role>``,
    ``schedulers.<role>``, ``random_state.*``.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.
    """
    states: dict[str, Any] = {}

    for role, model in self._role_models.items():
        # Frozen roles (e.g. teachers) carry no optimizer/scheduler
        # state and their weights never change, so skip them.
        if not getattr(model, "_trainable", False):
            continue

        modules: dict[str, torch.nn.Module] = {}
        if model.transformer is not None:
            modules["transformer"] = model.transformer

        # Groups the role's modules so the optimizer wrapper can map
        # optimizer state onto the matching parameters.
        container = _RoleModuleContainer(modules)

        for module_name, module in modules.items():
            states[f"roles.{role}.{module_name}"] = ModelWrapper(module)

        # Optimizer / scheduler entries are optional per role.
        opt = self._optimizer_dict.get(role)
        if opt is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, opt)

        sched = self._lr_scheduler_dict.get(role)
        if sched is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(sched)

    # NOTE(review): no ``random_state.*`` entries are produced here
    # despite the docstring — presumably added by the caller; verify.
    return states
fastvideo.train.methods.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Return modules whose gradients should be clipped.

    Override in subclasses to add/conditionally include
    modules (e.g. critic, conditionally student).
    Default: student transformer.
    """
    # ``iteration`` is unused in the base implementation; overrides may
    # use it to include different modules at different iterations.
    return {"student": self.student.transformer}
fastvideo.train.methods.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (exp_avg, exp_avg_sq,
    step are only created on the first optimizer.step()).
    DCP needs matching entries to load into; without them
    the saved optimizer state is silently dropped.
    """
    for opt in self.get_optimizers(0):
        for group in opt.param_groups:
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                # Leave any already-populated state untouched.
                if len(opt.state.get(p, {})) > 0:
                    continue
                # NOTE(review): these keys match Adam-family optimizers
                # (Adam/AdamW); other optimizer types would need
                # different seed entries — confirm configured optimizer.
                opt.state[p] = {
                    "step": torch.tensor(0.0),
                    "exp_avg": torch.zeros_like(p),
                    "exp_avg_sq": torch.zeros_like(p),
                }

Modules

fastvideo.train.methods.base

Classes

fastvideo.train.methods.base.TrainingMethod
TrainingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: Module, ABC

Base training method (algorithm layer).

Subclasses own their role models (student, teacher, critic, …) as plain attributes and manage optimizers directly — no RoleManager or RoleHandle.

The constructor receives role_models (a dict[str, ModelBase]) and a cfg object, and builds self.role_modules for FSDP wrapping; concrete subclasses call init_preprocessors on the student.

A single shared CUDA RNG generator (cuda_generator) is created in on_train_start(). All torch.randn / torch.randint calls in methods and models must use this generator instead of relying on global RNG state.

Source code in fastvideo/train/methods/base.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Store role models and expose their transformers for FSDP wrapping.

    Args:
        cfg: Config object; must provide ``training`` and ``method``
            attributes, and may provide ``validation``.
        role_models: Mapping from role name to model; must include
            ``"student"``.
    """
    super().__init__()
    self.tracker: Any | None = None
    # Private copy so later mutation of the caller's dict has no effect.
    self._role_models: dict[str, ModelBase] = dict(role_models)

    # Raises KeyError if no "student" role was configured.
    self.student = role_models["student"]
    self.training_config = cfg.training
    self.method_config: dict[str, Any] = dict(cfg.method)
    self.validation_config: dict[str, Any] = dict(getattr(cfg, "validation", {}) or {})

    # Build nn.ModuleDict for FSDP / checkpoint visibility.
    # Only roles exposing an nn.Module ``transformer`` are registered.
    self.role_modules = torch.nn.ModuleDict()
    for role, model in role_models.items():
        mods: dict[str, torch.nn.Module] = {}
        transformer = getattr(model, "transformer", None)
        if isinstance(transformer, torch.nn.Module):
            mods["transformer"] = transformer
        if mods:
            self.role_modules[role] = torch.nn.ModuleDict(mods)
Functions
fastvideo.train.methods.base.TrainingMethod.checkpoint_state
checkpoint_state() -> dict[str, Any]

Return DCP-ready checkpoint state for all trainable roles.

Keys follow the convention: roles.<role>.<module>, optimizers.<role>, schedulers.<role>, random_state.*.

EMA state is managed by the EMACallback and is checkpointed through the callback state mechanism.

Source code in fastvideo/train/methods/base.py
def checkpoint_state(self) -> dict[str, Any]:
    """Return DCP-ready checkpoint state for all trainable roles.

    Keys follow the convention:
    ``roles.<role>.<module>``, ``optimizers.<role>``,
    ``schedulers.<role>``, ``random_state.*``.

    EMA state is managed by the ``EMACallback`` and is
    checkpointed through the callback state mechanism.
    """
    states: dict[str, Any] = {}

    for role, model in self._role_models.items():
        # Frozen roles (e.g. teachers) carry no optimizer/scheduler
        # state and their weights never change, so skip them.
        if not getattr(model, "_trainable", False):
            continue

        modules: dict[str, torch.nn.Module] = {}
        if model.transformer is not None:
            modules["transformer"] = model.transformer

        # Groups the role's modules so the optimizer wrapper can map
        # optimizer state onto the matching parameters.
        container = _RoleModuleContainer(modules)

        for module_name, module in modules.items():
            states[f"roles.{role}.{module_name}"] = ModelWrapper(module)

        # Optimizer / scheduler entries are optional per role.
        opt = self._optimizer_dict.get(role)
        if opt is not None:
            states[f"optimizers.{role}"] = OptimizerWrapper(container, opt)

        sched = self._lr_scheduler_dict.get(role)
        if sched is not None:
            states[f"schedulers.{role}"] = SchedulerWrapper(sched)

    # NOTE(review): no ``random_state.*`` entries are produced here
    # despite the docstring — presumably added by the caller; verify.
    return states
fastvideo.train.methods.base.TrainingMethod.get_grad_clip_targets
get_grad_clip_targets(iteration: int) -> dict[str, Module]

Return modules whose gradients should be clipped.

Override in subclasses to add/conditionally include modules (e.g. critic, conditionally student). Default: student transformer.

Source code in fastvideo/train/methods/base.py
def get_grad_clip_targets(
    self,
    iteration: int,
) -> dict[str, torch.nn.Module]:
    """Return modules whose gradients should be clipped.

    Override in subclasses to add/conditionally include
    modules (e.g. critic, conditionally student).
    Default: student transformer.
    """
    # ``iteration`` is unused in the base implementation; overrides may
    # use it to include different modules at different iterations.
    return {"student": self.student.transformer}
fastvideo.train.methods.base.TrainingMethod.seed_optimizer_state_for_resume
seed_optimizer_state_for_resume() -> None

Seed optimizer state so DCP can load saved state.

A fresh optimizer has empty state (exp_avg, exp_avg_sq, step are only created on the first optimizer.step()). DCP needs matching entries to load into; without them the saved optimizer state is silently dropped.

Source code in fastvideo/train/methods/base.py
def seed_optimizer_state_for_resume(self) -> None:
    """Seed optimizer state so DCP can load saved state.

    A fresh optimizer has empty state (exp_avg, exp_avg_sq,
    step are only created on the first optimizer.step()).
    DCP needs matching entries to load into; without them
    the saved optimizer state is silently dropped.
    """
    for opt in self.get_optimizers(0):
        for group in opt.param_groups:
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                # Leave any already-populated state untouched.
                if len(opt.state.get(p, {})) > 0:
                    continue
                # NOTE(review): these keys match Adam-family optimizers
                # (Adam/AdamW); other optimizer types would need
                # different seed entries — confirm configured optimizer.
                opt.state[p] = {
                    "step": torch.tensor(0.0),
                    "exp_avg": torch.zeros_like(p),
                    "exp_avg_sq": torch.zeros_like(p),
                }

Functions

fastvideo.train.methods.distribution_matching

Classes

fastvideo.train.methods.distribution_matching.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- self.student — trainable student ModelBase
- self.teacher — frozen teacher ModelBase
- self.critic — trainable critic ModelBase

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate roles, parse DMD2 config and build optimizers.

    Args:
        cfg: Config object forwarded to ``TrainingMethod``.
        role_models: Must contain ``student`` (trainable),
            ``teacher`` (frozen) and ``critic`` (trainable).
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student surfaces as a
    # KeyError before this ValueError can fire — confirm intended.
    if "student" not in role_models:
        raise ValueError("DMD2Method requires role 'student'")
    if "teacher" not in role_models:
        raise ValueError("DMD2Method requires role 'teacher'")
    if "critic" not in role_models:
        raise ValueError("DMD2Method requires role 'critic'")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # Trainability contract: student/critic learn, teacher is frozen.
    if not self.student._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if self.teacher._trainable:
        raise ValueError("DMD2Method requires teacher to be "
                         "non-trainable")
    if not self.critic._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")
    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    # Populated later; holds the student's denoising timesteps.
    self._denoising_step_list: torch.Tensor | None = (None)

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
fastvideo.train.methods.distribution_matching.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse self-forcing rollout options and build the scheduler.

    Extends ``DMD2Method.__init__``; additionally requires a
    ``CausalModelBase`` student and
    ``method_config.rollout_mode == "simulate"``.
    """
    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Validate causal student.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")

    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    # Frames generated per autoregressive chunk (default 3).
    chunk_size = get_optional_int(
        mcfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if chunk_size is None:
        chunk_size = 3
    if chunk_size <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk_size}")
    self._chunk_size = int(chunk_size)

    # Student sampling mode: "sde" (stochastic) or "ode" (deterministic).
    sample_type_raw = mcfg.get("student_sample_type", "sde")
    sample_type = _require_str(
        sample_type_raw,
        where="method_config.student_sample_type",
    )
    sample_type = sample_type.strip().lower()
    if sample_type not in {"sde", "ode"}:
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {sample_type_raw!r}")
    self._student_sample_type: Literal["sde", "ode"] = (
        sample_type  # type: ignore[assignment]
    )

    # An explicit YAML null is treated like an omitted key (False).
    same_step_raw = mcfg.get("same_step_across_blocks", False)
    if same_step_raw is None:
        same_step_raw = False
    self._same_step_across_blocks = _require_bool(
        same_step_raw,
        where="method_config.same_step_across_blocks",
    )

    last_step_raw = mcfg.get("last_step_only", False)
    if last_step_raw is None:
        last_step_raw = False
    self._last_step_only = _require_bool(
        last_step_raw,
        where="method_config.last_step_only",
    )

    # Noise level applied to context frames; must be >= 0 (default 0.0).
    context_noise = get_optional_float(
        mcfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if context_noise is None:
        context_noise = 0.0
    if context_noise < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {context_noise}")
    self._context_noise = float(context_noise)

    # Gradients flow through the rollout by default.
    enable_grad_raw = mcfg.get("enable_gradient_in_rollout", True)
    if enable_grad_raw is None:
        enable_grad_raw = True
    self._enable_gradient_in_rollout = _require_bool(
        enable_grad_raw,
        where="method_config.enable_gradient_in_rollout",
    )

    # First frame index from which gradients are kept during rollout.
    start_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if start_grad_frame is None:
        start_grad_frame = 0
    if start_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {start_grad_frame}")
    self._start_gradient_frame = int(start_grad_frame)

    # Missing/None flow_shift collapses to 0.0.
    shift = float(getattr(
        self.training_config.pipeline_config,
        "flow_shift",
        0.0,
    ) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    # Populated later; holds the self-forcing denoising timesteps.
    self._sf_denoising_step_list: torch.Tensor | None = None

Modules

fastvideo.train.methods.distribution_matching.dmd2

DMD2 distillation method (algorithm layer).

Classes
fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- self.student — trainable student ModelBase
- self.teacher — frozen teacher ModelBase
- self.critic — trainable critic ModelBase

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate roles, parse DMD2 config and build optimizers.

    Args:
        cfg: Config object forwarded to ``TrainingMethod``.
        role_models: Must contain ``student`` (trainable),
            ``teacher`` (frozen) and ``critic`` (trainable).
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student surfaces as a
    # KeyError before this ValueError can fire — confirm intended.
    if "student" not in role_models:
        raise ValueError("DMD2Method requires role 'student'")
    if "teacher" not in role_models:
        raise ValueError("DMD2Method requires role 'teacher'")
    if "critic" not in role_models:
        raise ValueError("DMD2Method requires role 'critic'")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # Trainability contract: student/critic learn, teacher is frozen.
    if not self.student._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if self.teacher._trainable:
        raise ValueError("DMD2Method requires teacher to be "
                         "non-trainable")
    if not self.critic._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")
    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    # Populated later; holds the student's denoising timesteps.
    self._denoising_step_list: torch.Tensor | None = (None)

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions
fastvideo.train.methods.distribution_matching.self_forcing

Self-Forcing distillation method (algorithm layer).

Classes
fastvideo.train.methods.distribution_matching.self_forcing.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse self-forcing rollout options and build the scheduler.

    Extends ``DMD2Method.__init__``; additionally requires a
    ``CausalModelBase`` student and
    ``method_config.rollout_mode == "simulate"``.
    """
    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Validate causal student.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")

    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    # Frames generated per autoregressive chunk (default 3).
    chunk_size = get_optional_int(
        mcfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if chunk_size is None:
        chunk_size = 3
    if chunk_size <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk_size}")
    self._chunk_size = int(chunk_size)

    # Student sampling mode: "sde" (stochastic) or "ode" (deterministic).
    sample_type_raw = mcfg.get("student_sample_type", "sde")
    sample_type = _require_str(
        sample_type_raw,
        where="method_config.student_sample_type",
    )
    sample_type = sample_type.strip().lower()
    if sample_type not in {"sde", "ode"}:
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {sample_type_raw!r}")
    self._student_sample_type: Literal["sde", "ode"] = (
        sample_type  # type: ignore[assignment]
    )

    # An explicit YAML null is treated like an omitted key (False).
    same_step_raw = mcfg.get("same_step_across_blocks", False)
    if same_step_raw is None:
        same_step_raw = False
    self._same_step_across_blocks = _require_bool(
        same_step_raw,
        where="method_config.same_step_across_blocks",
    )

    last_step_raw = mcfg.get("last_step_only", False)
    if last_step_raw is None:
        last_step_raw = False
    self._last_step_only = _require_bool(
        last_step_raw,
        where="method_config.last_step_only",
    )

    # Noise level applied to context frames; must be >= 0 (default 0.0).
    context_noise = get_optional_float(
        mcfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if context_noise is None:
        context_noise = 0.0
    if context_noise < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {context_noise}")
    self._context_noise = float(context_noise)

    # Gradients flow through the rollout by default.
    enable_grad_raw = mcfg.get("enable_gradient_in_rollout", True)
    if enable_grad_raw is None:
        enable_grad_raw = True
    self._enable_gradient_in_rollout = _require_bool(
        enable_grad_raw,
        where="method_config.enable_gradient_in_rollout",
    )

    # First frame index from which gradients are kept during rollout.
    start_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if start_grad_frame is None:
        start_grad_frame = 0
    if start_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {start_grad_frame}")
    self._start_gradient_frame = int(start_grad_frame)

    # Missing/None flow_shift collapses to 0.0.
    shift = float(getattr(
        self.training_config.pipeline_config,
        "flow_shift",
        0.0,
    ) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    # Populated later; holds the self-forcing denoising timesteps.
    self._sf_denoising_step_list: torch.Tensor | None = None
Functions

fastvideo.train.methods.fine_tuning

Classes

fastvideo.train.methods.fine_tuning.DiffusionForcingSFTMethod
DiffusionForcingSFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Diffusion-forcing SFT (DFSFT): train only student with inhomogeneous timesteps.

Source code in fastvideo/train/methods/fine_tuning/dfsft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate the student role and parse DFSFT-specific config."""
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student raises KeyError
    # before this check — confirm intended.
    if "student" not in role_models:
        raise ValueError("DFSFT requires role 'student'")
    if not self.student._trainable:
        raise ValueError("DFSFT requires student to be trainable")
    # Attention backend: "dense" or "vsa" (video sparse attention).
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    self._chunk_size = self._parse_chunk_size(self.method_config.get("chunk_size", None))
    self._timestep_index_range = (self._parse_timestep_index_range())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
fastvideo.train.methods.fine_tuning.FineTuneMethod
FineTuneMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Supervised finetuning: only student participates.

Source code in fastvideo/train/methods/fine_tuning/finetune.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate the student role and set up supervised finetuning."""
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student raises KeyError
    # before this check — confirm intended.
    if "student" not in role_models:
        raise ValueError("FineTuneMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("FineTuneMethod requires student to be "
                         "trainable")
    # Attention backend: "dense" or "vsa" (video sparse attention).
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()

Modules

fastvideo.train.methods.fine_tuning.dfsft

Diffusion-forcing SFT method (DFSFT; algorithm layer).

Classes
fastvideo.train.methods.fine_tuning.dfsft.DiffusionForcingSFTMethod
DiffusionForcingSFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Diffusion-forcing SFT (DFSFT): train only student with inhomogeneous timesteps.

Source code in fastvideo/train/methods/fine_tuning/dfsft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate the student role and parse DFSFT-specific config."""
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student raises KeyError
    # before this check — confirm intended.
    if "student" not in role_models:
        raise ValueError("DFSFT requires role 'student'")
    if not self.student._trainable:
        raise ValueError("DFSFT requires student to be trainable")
    # Attention backend: "dense" or "vsa" (video sparse attention).
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    self._chunk_size = self._parse_chunk_size(self.method_config.get("chunk_size", None))
    self._timestep_index_range = (self._parse_timestep_index_range())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions
fastvideo.train.methods.fine_tuning.finetune

Supervised finetuning method (algorithm layer).

Classes
fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod
FineTuneMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Supervised finetuning: only student participates.

Source code in fastvideo/train/methods/fine_tuning/finetune.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate the student role and set up supervised finetuning."""
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student raises KeyError
    # before this check — confirm intended.
    if "student" not in role_models:
        raise ValueError("FineTuneMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("FineTuneMethod requires student to be "
                         "trainable")
    # Attention backend: "dense" or "vsa" (video sparse attention).
    self._attn_kind: Literal["dense", "vsa"] = (self._infer_attn_kind())

    # Initialize preprocessors on student.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()
Functions

fastvideo.train.methods.knowledge_distillation

Classes

fastvideo.train.methods.knowledge_distillation.KDCausalMethod
KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to KDMethod except single_train_step samples a per-frame denoising step index (block-quantized to groups of num_frames_per_block frames) instead of one index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method::

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse ``num_frames_per_block`` on top of ``KDMethod`` setup."""
    super().__init__(cfg=cfg, role_models=role_models)
    # Frames that share one noise level during sampling (default 3).
    self._num_frames_per_block: int = int(self.method_config.get("num_frames_per_block", 3))
    if self._num_frames_per_block < 1:
        raise ValueError("num_frames_per_block must be >= 1")
fastvideo.train.methods.knowledge_distillation.KDMethod
KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse KD config, wrap the dataloader and build optimizers.

    Args:
        cfg: Config object forwarded to ``TrainingMethod``.
        role_models: Must contain a trainable ``student``; may
            contain a non-trainable ``teacher``.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student raises KeyError
    # before this check — confirm intended.
    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    # Default: one step per interval between consecutive t_list entries.
    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()

Modules

fastvideo.train.methods.knowledge_distillation.kd

Knowledge Distillation method for ODE-init training.

Trains a student model with MSE loss to reproduce a teacher model's multi-step ODE denoising trajectories. The resulting checkpoint (exported via dcp_to_diffusers) serves as the ode_init weight initialization for downstream Self-Forcing training.

Teacher path generation is cached to disk so it only runs once. Interrupted generation resumes from the last completed sample.

Typical YAML::

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:           # omit once cache is complete
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-14B-Diffusers
    trainable: false
    disable_custom_init_weights: true

method:
  _target_: fastvideo.train.methods.knowledge_distillation.kd.KDMethod
  teacher_path_cache: /data/kd_cache/wan14b_4step
  t_list: [999, 937, 833, 624, 0]   # integer timesteps
  student_sample_steps: 4
  teacher_guidance_scale: 1.0
Classes
fastvideo.train.methods.knowledge_distillation.kd.KDCausalMethod
KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to KDMethod except single_train_step samples a per-frame denoising step index (block-quantized to groups of num_frames_per_block frames) instead of one index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method::

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse ``num_frames_per_block`` on top of ``KDMethod`` setup."""
    super().__init__(cfg=cfg, role_models=role_models)
    # Frames that share one noise level during sampling (default 3).
    self._num_frames_per_block: int = int(self.method_config.get("num_frames_per_block", 3))
    if self._num_frames_per_block < 1:
        raise ValueError("num_frames_per_block must be >= 1")
fastvideo.train.methods.knowledge_distillation.kd.KDMethod
KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse KD config, wrap the dataloader and build optimizers.

    Args:
        cfg: Config object forwarded to ``TrainingMethod``.
        role_models: Must contain a trainable ``student``; may
            contain a non-trainable ``teacher``.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # NOTE(review): the base __init__ already indexes
    # role_models["student"], so a missing student raises KeyError
    # before this check — confirm intended.
    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    # Default: one step per interval between consecutive t_list entries.
    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()
Functions