Bases: TrainingMethod
DMD2 distillation algorithm (method layer).
Owns role model instances directly:
- self.student — trainable student ModelBase
- self.teacher — frozen teacher ModelBase
- self.critic — trainable critic ModelBase
Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    super().__init__(cfg=cfg, role_models=role_models)
    if "student" not in role_models:
        raise ValueError("DMD2Method requires role 'student'")
    if "teacher" not in role_models:
        raise ValueError("DMD2Method requires role 'teacher'")
    if "critic" not in role_models:
        raise ValueError("DMD2Method requires role 'critic'")
    self.student = role_models["student"]
    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]
    if not self.student._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if self.teacher._trainable:
        raise ValueError("DMD2Method requires teacher to be non-trainable")
    if not self.critic._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")
    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    self._denoising_step_list: torch.Tensor | None = None
    # Initialize preprocessors on the student model.
    self.student.init_preprocessors(self.training_config)
    self._init_optimizers_and_schedulers()
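The role contract enforced by `__init__` (student/teacher/critic all present; student and critic trainable, teacher frozen) can be exercised with lightweight stand-ins. A minimal sketch, assuming only the checks shown above — the `StubModel` class and `validate_dmd2_roles` helper here are hypothetical illustrations, not part of the fastvideo API:

```python
from dataclasses import dataclass


@dataclass
class StubModel:
    # Hypothetical stand-in for ModelBase; models only the
    # `_trainable` flag that DMD2Method's __init__ inspects.
    _trainable: bool


def validate_dmd2_roles(role_models: dict[str, StubModel]) -> None:
    """Mirror DMD2Method's role checks: all three roles must be
    present; student and critic trainable; teacher non-trainable."""
    for role in ("student", "teacher", "critic"):
        if role not in role_models:
            raise ValueError(f"DMD2Method requires role {role!r}")
    if not role_models["student"]._trainable:
        raise ValueError("DMD2Method requires student to be trainable")
    if role_models["teacher"]._trainable:
        raise ValueError("DMD2Method requires teacher to be non-trainable")
    if not role_models["critic"]._trainable:
        raise ValueError("DMD2Method requires critic to be trainable")


# A valid role assignment passes silently.
roles = {
    "student": StubModel(_trainable=True),
    "teacher": StubModel(_trainable=False),
    "critic": StubModel(_trainable=True),
}
validate_dmd2_roles(roles)
```

Passing a trainable teacher or omitting a role raises `ValueError` with the same messages as the real constructor, which makes misconfigured role dictionaries fail fast before any training state is built.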