Skip to content

diffusion_nft

DiffusionNFT multi-reward policy optimization method.

Classes

fastvideo.train.methods.rl.diffusion_nft.DiffusionNFTMethod

DiffusionNFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DiffusionNFT-style RL for diffusion models.

This method owns the algorithm's sample-then-inner-train loop. One Trainer step corresponds to one DiffusionNFT outer epoch.

Source code in fastvideo/train/methods/rl/diffusion_nft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Wire up roles, samplers, hyperparameters and reward settings.

    Args:
        cfg: Configuration object, forwarded unchanged to the base class.
        role_models: Mapping of role name to model. It must contain the
            ``"old"`` and ``"reference"`` roles.

    Raises:
        ValueError: If a required role is missing, the student is not
            trainable, ``method.adv_mode`` is not one of the accepted
            modes, ``method.reward_fn`` is not a non-empty mapping, or
            it names a reward other than pickscore / clipscore.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # Required roles, checked in the same order as the messages are listed.
    required_roles = (
        ("old", "DiffusionNFTMethod requires role 'old'"),
        ("reference", "DiffusionNFTMethod requires role 'reference'"),
    )
    for role_name, missing_message in required_roles:
        if role_name not in role_models:
            raise ValueError(missing_message)
    # NOTE(review): `self.student` is presumably populated by the base
    # class constructor above -- confirm against TrainingMethod.
    if not self.student._trainable:
        raise ValueError("DiffusionNFTMethod requires a trainable student")

    self.old, self.reference = role_models["old"], role_models["reference"]
    self.student.init_preprocessors(self.training_config)

    # --- Samplers: one for training rollouts, one for validation. ---
    train_sampling = self._parse_sampling_config()
    self._sampling_config = train_sampling
    self._sampler = DiffusionSampler(train_sampling)

    validation_mapping = self.method_config.get("validation")
    self._validation_config = RLValidationConfig.from_mapping(validation_mapping)
    validation_sampling = self._parse_validation_sampling_config()
    self._validation_sampling_config = validation_sampling
    self._validation_sampler = DiffusionSampler(validation_sampling)
    self._validation_items: list[tuple[int, bool, dict[str, Any]]] | None = None
    self._sample_steps = int(train_sampling.num_steps)

    # --- Batch sizing: the sampling batch falls back to the data-loader
    # batch size (at least 1); the training batch falls back to it in turn.
    configured_batch = self.training_config.data.train_batch_size
    fallback_sample_batch = max(1, int(configured_batch or 1))
    self._sample_train_batch_size = self._read_int(
        "sample_train_batch_size",
        fallback_sample_batch,
    )
    self._train_batch_size = self._read_int(
        "train_batch_size",
        self._sample_train_batch_size,
    )

    # --- Scalar hyperparameters (second argument is the fallback value). ---
    self._num_batches_per_epoch = self._read_int("num_batches_per_epoch", 48)
    self._num_inner_epochs = self._read_int("num_inner_epochs", 1)
    self._num_video_per_prompt = self._read_int("num_video_per_prompt", 24)
    self._adv_clip_max = self._read_float("adv_clip_max", 5.0)
    self._timestep_fraction = self._read_float("timestep_fraction", 0.99)
    self._kl_beta = self._read_float("kl_beta", 0.0001)
    self._nft_beta = self._read_float("beta", 0.1)
    self._max_grad_norm = self._read_float("max_grad_norm", 1.0)
    self._decay_type = self._read_int("decay_type", 1)

    # A missing or falsy adv_mode is treated as "all"; the value is then
    # normalised to stripped lower-case text before validation below.
    raw_adv_mode = self.method_config.get("adv_mode", "all") or "all"
    self._adv_mode = str(raw_adv_mode).strip().lower()
    self._terminal_progress = bool(self.method_config.get("terminal_progress", True))

    # --- EMA bookkeeping; the EMA object itself starts out unset. ---
    ema_settings = self._parse_ema_config()
    self._ema_enabled = bool(ema_settings["enabled"])
    self._ema_decay = float(ema_settings["decay"])
    self._ema_update_after_step = int(ema_settings["update_after_step"])
    self._validation_use_ema = bool(ema_settings["validation"])
    self._student_ema: EMA_FSDP | None = None
    self._ema_update_count = 0
    self._trained_prompt_hashes: set[int] = set()

    accepted_adv_modes = {"all", "positive_only", "negative_only", "one_only", "binary"}
    if self._adv_mode not in accepted_adv_modes:
        raise ValueError("method.adv_mode must be one of "
                         "{all, positive_only, negative_only, one_only, binary}")

    # --- Reward configuration: {reward name -> weight}. ---
    reward_spec = self.method_config.get("reward_fn", None)
    if not (isinstance(reward_spec, dict) and reward_spec):
        raise ValueError("method.reward_fn must be a non-empty mapping, "
                         "for example {pickscore: 1.0, clipscore: 1.0}")
    reward_weights: dict[str, float] = {}
    for reward_name, weight in reward_spec.items():
        reward_weights[str(reward_name)] = float(weight)
    self._reward_fn_config = reward_weights

    known_rewards = {"pickscore", "clipscore"}
    unsupported = sorted(name for name in reward_weights if name not in known_rewards)
    if unsupported:
        raise ValueError(f"Unsupported DiffusionNFT reward(s): {unsupported}. "
                         "Only pickscore and clipscore are currently ported.")

    # The scorer is left unset here; optimizer/scheduler setup runs last,
    # after every attribute above has been assigned.
    self._reward_scorer: Any | None = None
    self._init_optimizer_and_scheduler()

Functions: