kd

Knowledge Distillation method for ODE-init training.

Trains a student model with MSE loss to reproduce a teacher model's multi-step ODE denoising trajectories. The resulting checkpoint (exported via dcp_to_diffusers) serves as the ode_init weight initialization for downstream Self-Forcing training.

Teacher path generation is cached to disk so it only runs once. Interrupted generation resumes from the last completed sample.
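The caching-and-resume behaviour described above can be sketched as one file per sample, skipping anything already on disk. The file layout, names, and `run_teacher`/`save_atomic` helpers are illustrative assumptions, not the actual cache format:

```python
import os
import pickle


def save_atomic(obj, path: str) -> None:
    # Write-then-rename so an interrupt never leaves a partial sample file;
    # os.replace is atomic on POSIX filesystems.
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp, path)


def generate_teacher_cache(cache_dir: str, samples, run_teacher) -> None:
    """Generate teacher ODE trajectories, resuming past completed samples."""
    os.makedirs(cache_dir, exist_ok=True)
    for idx, sample in enumerate(samples):
        path = os.path.join(cache_dir, f"sample_{idx:08d}.pkl")
        if os.path.exists(path):
            continue  # already generated on a previous run -> resume point
        trajectory = run_teacher(sample)  # multi-step ODE denoising path
        save_atomic(trajectory, path)
```

Because completeness is per-file, a rerun after an interruption only pays for the samples that were not yet written.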

Typical YAML:

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:           # omit once cache is complete
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-14B-Diffusers
    trainable: false
    disable_custom_init_weights: true

method:
  _target_: fastvideo.train.methods.knowledge_distillation.kd.KDMethod
  teacher_path_cache: /data/kd_cache/wan14b_4step
  t_list: [999, 937, 833, 624, 0]   # integer timesteps
  student_sample_steps: 4
  teacher_guidance_scale: 1.0

Classes

fastvideo.train.methods.knowledge_distillation.kd.KDCausalMethod

KDCausalMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: KDMethod

KD for causal Wan: per-frame block-quantized timestep sampling.

Identical to KDMethod except that single_train_step samples a per-frame denoising step index (block-quantized so that groups of num_frames_per_block consecutive frames share one index) instead of a single index per batch. This matches the legacy ODEInitTrainingPipeline training scheme required by causal / streaming student models.

Additional YAML field under method:

num_frames_per_block: 3   # frames sharing the same noise level
Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)
    self._num_frames_per_block: int = int(self.method_config.get("num_frames_per_block", 3))
    if self._num_frames_per_block < 1:
        raise ValueError("num_frames_per_block must be >= 1")

fastvideo.train.methods.knowledge_distillation.kd.KDMethod

KDMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

Knowledge Distillation training method.

Trains the student with MSE loss on teacher ODE trajectories cached to method_config.teacher_path_cache.

Roles
  • student (required, trainable): the model being distilled.
  • teacher (optional, non-trainable): used to generate the cache on first run; freed from GPU memory afterwards.

If the cache is incomplete and no teacher is configured, an error is raised at the start of training.
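A minimal sketch of that startup check, assuming completeness is judged by comparing cached files against the expected sample count (the layout and function name are assumptions, not the real implementation):

```python
import os


def check_cache_or_teacher(cache_dir: str, expected_samples: int, teacher) -> None:
    """Fail fast when the cache cannot be completed without a teacher."""
    cached = len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
    if cached < expected_samples and teacher is None:
        raise ValueError(
            f"teacher_path_cache holds {cached}/{expected_samples} samples "
            "and no teacher is configured; add models.teacher or complete "
            "the cache first"
        )
```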

Source code in fastvideo/train/methods/knowledge_distillation/kd.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)

    if "student" not in role_models:
        raise ValueError("KDMethod requires role 'student'")
    if not self.student._trainable:
        raise ValueError("KDMethod requires student to be trainable")

    mcfg = self.method_config

    # --- Parse method config ---
    raw_t_list = mcfg.get("t_list")
    if not isinstance(raw_t_list, list) or not raw_t_list:
        raise ValueError("method_config.t_list must be a non-empty list "
                         "of integer timestep values, e.g. "
                         "[999, 937, 833, 624, 0]")
    self._t_list: list[int] = [int(t) for t in raw_t_list]

    raw_steps = mcfg.get("student_sample_steps")
    if raw_steps is None:
        raw_steps = len(self._t_list) - 1
    self._num_steps: int = int(raw_steps)
    if len(self._t_list) != self._num_steps + 1:
        raise ValueError(f"len(t_list)={len(self._t_list)} must equal "
                         f"student_sample_steps+1={self._num_steps + 1}")

    cache_dir = mcfg.get("teacher_path_cache")
    if not cache_dir:
        raise ValueError("method_config.teacher_path_cache must be set")
    self._cache_dir: str = str(cache_dir)

    self._teacher_guidance_scale: float = float(mcfg.get("teacher_guidance_scale", 1.0))
    self._teacher_inference_steps: int = int(mcfg.get("teacher_inference_steps", 48))

    # --- Optional teacher ---
    self.teacher: ModelBase | None = role_models.get("teacher")
    if self.teacher is not None and getattr(self.teacher, "_trainable", False):
        raise ValueError("KDMethod requires teacher to be non-trainable "
                         "(set trainable: false in models.teacher)")

    # --- Build parquet dataloader via student.init_preprocessors ---
    self.student.init_preprocessors(self.training_config)

    # Wrap the parquet dataloader so we can swap it after generation.
    self._source_loader = self.student.dataloader
    self._kd_wrapper = _KDDataLoaderWrapper(self._source_loader)
    self.student.dataloader = self._kd_wrapper  # builder captures this

    # --- Build SelfForcingFlowMatchScheduler for sigma lookups ---
    # num_inference_steps=1000 gives a dense grid accurate for any t.
    tc = self.training_config
    self._flow_shift = float(getattr(tc.pipeline_config, "flow_shift", 0.0) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=self._flow_shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=False,
    )

    # --- Student optimizer / scheduler (same as FineTuneMethod) ---
    self._init_optimizers_and_schedulers()

Functions