Skip to content

distribution_matching

Classes

fastvideo.train.methods.distribution_matching.DMD2Method

DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- `self.student` — trainable student `ModelBase`
- `self.teacher` — frozen teacher `ModelBase`
- `self.critic` — trainable critic `ModelBase`

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Set up the DMD2 student/teacher/critic roles and their optimizers.

    Args:
        cfg: Configuration forwarded unchanged to the base
            ``TrainingMethod`` constructor.
        role_models: Mapping from role name to model. Must contain the
            ``"student"``, ``"teacher"`` and ``"critic"`` roles.

    Raises:
        ValueError: If a required role is missing, if the student or
            critic is not trainable, or if the teacher is trainable.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # All three roles are mandatory; report the first one that is absent.
    for required_role in ("student", "teacher", "critic"):
        if required_role not in role_models:
            raise ValueError(f"DMD2Method requires role '{required_role}'")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # (model, expected trainability, error message), checked in order:
    # student and critic must be trainable, the teacher must be frozen.
    trainability_rules = (
        (self.student, True,
         "DMD2Method requires student to be trainable"),
        (self.teacher, False,
         "DMD2Method requires teacher to be non-trainable"),
        (self.critic, True,
         "DMD2Method requires critic to be trainable"),
    )
    for model, must_be_trainable, message in trainability_rules:
        if bool(model._trainable) is not must_be_trainable:
            raise ValueError(message)

    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    # Starts unset; this constructor does not populate it.
    self._denoising_step_list: torch.Tensor | None = None

    # Preprocessors are attached to the student model.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()

fastvideo.train.methods.distribution_matching.SelfForcingMethod

SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Configure Self-Forcing on top of the DMD2 setup.

    Runs the full ``DMD2Method`` constructor first. It then requires a
    causal student and ``rollout_mode == "simulate"``, parses the
    Self-Forcing options from ``method_config`` and builds the
    Self-Forcing flow-match scheduler.

    Args:
        cfg: Configuration forwarded unchanged to ``DMD2Method``.
        role_models: Role-name -> model mapping forwarded to
            ``DMD2Method``.

    Raises:
        ValueError: If the student is not a ``CausalModelBase``, if the
            rollout mode is not ``"simulate"``, or if any
            ``method_config`` option has an invalid type or value. This
            includes a non-finite ``context_noise``.
    """
    import math

    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Validate causal student.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")

    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    def _bool_option(key: str, default: bool) -> bool:
        # Read a boolean option; an explicit ``None`` means "use default".
        raw = mcfg.get(key, default)
        if raw is None:
            raw = default
        return _require_bool(raw, where=f"method_config.{key}")

    # chunk_size: defaults to 3 when unset; must be positive.
    chunk_size = get_optional_int(
        mcfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if chunk_size is None:
        chunk_size = 3
    if chunk_size <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk_size}")
    self._chunk_size = int(chunk_size)

    # student_sample_type: matched after strip() + lower(); the error
    # message echoes the raw configured value.
    sample_type_raw = mcfg.get("student_sample_type", "sde")
    sample_type = _require_str(
        sample_type_raw,
        where="method_config.student_sample_type",
    )
    sample_type = sample_type.strip().lower()
    if sample_type not in {"sde", "ode"}:
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {sample_type_raw!r}")
    self._student_sample_type: Literal["sde", "ode"] = (
        sample_type  # type: ignore[assignment]
    )

    self._same_step_across_blocks = _bool_option(
        "same_step_across_blocks", False)
    self._last_step_only = _bool_option("last_step_only", False)

    # context_noise: defaults to 0.0; must be finite and non-negative.
    context_noise = get_optional_float(
        mcfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if context_noise is None:
        context_noise = 0.0
    # ``nan < 0.0`` is False, so NaN would otherwise pass the range check
    # below. Reject NaN/inf explicitly.
    if not math.isfinite(context_noise):
        raise ValueError("method_config.context_noise must be finite, "
                         f"got {context_noise}")
    if context_noise < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {context_noise}")
    self._context_noise = float(context_noise)

    self._enable_gradient_in_rollout = _bool_option(
        "enable_gradient_in_rollout", True)

    # start_gradient_frame: defaults to 0; must be non-negative.
    start_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if start_grad_frame is None:
        start_grad_frame = 0
    if start_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {start_grad_frame}")
    self._start_gradient_frame = int(start_grad_frame)

    # A missing or falsy pipeline ``flow_shift`` (e.g. None) becomes 0.0.
    shift = float(getattr(
        self.training_config.pipeline_config,
        "flow_shift",
        0.0,
    ) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    # Starts unset; this constructor does not populate it.
    self._sf_denoising_step_list: torch.Tensor | None = None

Modules

fastvideo.train.methods.distribution_matching.dmd2

DMD2 distillation method (algorithm layer).

Classes

fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
DMD2Method(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DMD2 distillation algorithm (method layer).

Owns role model instances directly:

- `self.student` — trainable student `ModelBase`
- `self.teacher` — frozen teacher `ModelBase`
- `self.critic` — trainable critic `ModelBase`

Source code in fastvideo/train/methods/distribution_matching/dmd2.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Set up the DMD2 student/teacher/critic roles and their optimizers.

    Args:
        cfg: Configuration forwarded unchanged to the base
            ``TrainingMethod`` constructor.
        role_models: Mapping from role name to model. Must contain the
            ``"student"``, ``"teacher"`` and ``"critic"`` roles.

    Raises:
        ValueError: If a required role is missing, if the student or
            critic is not trainable, or if the teacher is trainable.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # All three roles are mandatory; report the first one that is absent.
    for required_role in ("student", "teacher", "critic"):
        if required_role not in role_models:
            raise ValueError(f"DMD2Method requires role '{required_role}'")

    self.teacher = role_models["teacher"]
    self.critic = role_models["critic"]

    # (model, expected trainability, error message), checked in order:
    # student and critic must be trainable, the teacher must be frozen.
    trainability_rules = (
        (self.student, True,
         "DMD2Method requires student to be trainable"),
        (self.teacher, False,
         "DMD2Method requires teacher to be non-trainable"),
        (self.critic, True,
         "DMD2Method requires critic to be trainable"),
    )
    for model, must_be_trainable, message in trainability_rules:
        if bool(model._trainable) is not must_be_trainable:
            raise ValueError(message)

    self._cfg_uncond = self._parse_cfg_uncond()
    self._rollout_mode = self._parse_rollout_mode()
    # Starts unset; this constructor does not populate it.
    self._denoising_step_list: torch.Tensor | None = None

    # Preprocessors are attached to the student model.
    self.student.init_preprocessors(self.training_config)

    self._init_optimizers_and_schedulers()

Functions

fastvideo.train.methods.distribution_matching.self_forcing

Self-Forcing distillation method (algorithm layer).

Classes

fastvideo.train.methods.distribution_matching.self_forcing.SelfForcingMethod
SelfForcingMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: DMD2Method

Self-Forcing DMD2 (distribution matching) method.

Requires a causal student implementing CausalModelBase.

Source code in fastvideo/train/methods/distribution_matching/self_forcing.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Configure Self-Forcing on top of the DMD2 setup.

    Runs the full ``DMD2Method`` constructor first. It then requires a
    causal student and ``rollout_mode == "simulate"``, parses the
    Self-Forcing options from ``method_config`` and builds the
    Self-Forcing flow-match scheduler.

    Args:
        cfg: Configuration forwarded unchanged to ``DMD2Method``.
        role_models: Role-name -> model mapping forwarded to
            ``DMD2Method``.

    Raises:
        ValueError: If the student is not a ``CausalModelBase``, if the
            rollout mode is not ``"simulate"``, or if any
            ``method_config`` option has an invalid type or value. This
            includes a non-finite ``context_noise``.
    """
    import math

    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Validate causal student.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError("SelfForcingMethod requires a causal student "
                         "implementing CausalModelBase.")

    if self._rollout_mode != "simulate":
        raise ValueError("SelfForcingMethod only supports "
                         "method_config.rollout_mode='simulate'")

    mcfg = self.method_config

    def _bool_option(key: str, default: bool) -> bool:
        # Read a boolean option; an explicit ``None`` means "use default".
        raw = mcfg.get(key, default)
        if raw is None:
            raw = default
        return _require_bool(raw, where=f"method_config.{key}")

    # chunk_size: defaults to 3 when unset; must be positive.
    chunk_size = get_optional_int(
        mcfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if chunk_size is None:
        chunk_size = 3
    if chunk_size <= 0:
        raise ValueError("method_config.chunk_size must be a positive "
                         f"integer, got {chunk_size}")
    self._chunk_size = int(chunk_size)

    # student_sample_type: matched after strip() + lower(); the error
    # message echoes the raw configured value.
    sample_type_raw = mcfg.get("student_sample_type", "sde")
    sample_type = _require_str(
        sample_type_raw,
        where="method_config.student_sample_type",
    )
    sample_type = sample_type.strip().lower()
    if sample_type not in {"sde", "ode"}:
        raise ValueError("method_config.student_sample_type must be one "
                         f"of {{sde, ode}}, got {sample_type_raw!r}")
    self._student_sample_type: Literal["sde", "ode"] = (
        sample_type  # type: ignore[assignment]
    )

    self._same_step_across_blocks = _bool_option(
        "same_step_across_blocks", False)
    self._last_step_only = _bool_option("last_step_only", False)

    # context_noise: defaults to 0.0; must be finite and non-negative.
    context_noise = get_optional_float(
        mcfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if context_noise is None:
        context_noise = 0.0
    # ``nan < 0.0`` is False, so NaN would otherwise pass the range check
    # below. Reject NaN/inf explicitly.
    if not math.isfinite(context_noise):
        raise ValueError("method_config.context_noise must be finite, "
                         f"got {context_noise}")
    if context_noise < 0.0:
        raise ValueError("method_config.context_noise must be >= 0, "
                         f"got {context_noise}")
    self._context_noise = float(context_noise)

    self._enable_gradient_in_rollout = _bool_option(
        "enable_gradient_in_rollout", True)

    # start_gradient_frame: defaults to 0; must be non-negative.
    start_grad_frame = get_optional_int(
        mcfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if start_grad_frame is None:
        start_grad_frame = 0
    if start_grad_frame < 0:
        raise ValueError("method_config.start_gradient_frame must be "
                         f">= 0, got {start_grad_frame}")
    self._start_gradient_frame = int(start_grad_frame)

    # A missing or falsy pipeline ``flow_shift`` (e.g. None) becomes 0.0.
    shift = float(getattr(
        self.training_config.pipeline_config,
        "flow_shift",
        0.0,
    ) or 0.0)
    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )

    # Starts unset; this constructor does not populate it.
    self._sf_denoising_step_list: torch.Tensor | None = None

Functions