Skip to content

dmd2

DMD2 distillation method (algorithm layer).

Classes

fastvideo.train.methods.distribution_matching.dmd2.DMD2Method

DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly: `self.student` — the trainable student `ModelBase`; `self.teacher` — the frozen teacher `ModelBase`; `self.critic` — the trainable critic `ModelBase`.

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate DMD2 role models and initialize training state.

    Requires exactly three roles in ``role_models``: a trainable
    ``student``, a frozen (non-trainable) ``teacher``, and a trainable
    ``critic``. Raises ``ValueError`` if any role is missing or its
    trainability does not match the DMD2 contract.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # All three roles are mandatory for DMD2.
    for role in ("student", "teacher", "critic"):
        if role not in role_models:
            raise ValueError(f"DMD2Method requires role '{role}'")

    # NOTE(review): self.student is read below but assigned only for
    # teacher/critic here — presumably the base class binds it; confirm.
    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # Trainability contract: student and critic learn, teacher is frozen.
    if not self.student._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if self.teacher._trainable:
        raise ValueError("DMD2Method requires teacher to be "
                         "non-trainable")
    if not self.critic._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")

    # Parse method-specific configuration knobs up front.
    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    # Lazily populated later; None until the step list is built.
    self._denoising_step_list: torch.Tensor | None = None

    # Preprocessors live on the student model.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()

Functions