def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Set up a DiffusionNFT method from its config and role models.

    Problems that can be detected from ``role_models`` and
    ``self.method_config`` alone are reported first, before the student's
    preprocessors, the samplers, or the optimizer/scheduler are created.

    Args:
        cfg: Configuration object, forwarded unchanged to the base class.
        role_models: Mapping from role name to model. Must contain the
            roles ``"old"`` and ``"reference"``.

    Raises:
        ValueError: If role ``"old"`` or ``"reference"`` is missing, if
            ``self.student`` is not trainable, if ``method.adv_mode`` is
            not one of the supported modes, or if ``method.reward_fn`` is
            not a non-empty dict, has a non-numeric weight, or names a
            reward other than ``pickscore`` / ``clipscore``.
    """
    super().__init__(cfg=cfg, role_models=role_models)

    # --- Required roles -------------------------------------------------
    if "old" not in role_models:
        raise ValueError("DiffusionNFTMethod requires role 'old'")
    if "reference" not in role_models:
        raise ValueError("DiffusionNFTMethod requires role 'reference'")
    if not self.student._trainable:
        raise ValueError("DiffusionNFTMethod requires a trainable student")
    self.old = role_models["old"]
    self.reference = role_models["reference"]

    # --- Fail-fast config validation -------------------------------------
    # These checks only read self.method_config, so they run before any
    # preprocessors/samplers are built; a bad config then fails without
    # leaving partially initialised state behind.
    # A falsy value (None, "") falls back to "all"; the mode is compared
    # case-insensitively and with surrounding whitespace stripped.
    self._adv_mode = str(self.method_config.get("adv_mode", "all") or "all").strip().lower()
    if self._adv_mode not in {"all", "positive_only", "negative_only", "one_only", "binary"}:
        raise ValueError("method.adv_mode must be one of "
                         "{all, positive_only, negative_only, one_only, binary}"
                         f", got {self._adv_mode!r}")

    # NOTE(review): isinstance(..., dict) rejects non-dict Mapping types;
    # confirm method_config always holds plain dicts here.
    reward_fn = self.method_config.get("reward_fn", None)
    if not isinstance(reward_fn, dict) or not reward_fn:
        raise ValueError("method.reward_fn must be a non-empty mapping, "
                         "for example {pickscore: 1.0, clipscore: 1.0}")
    reward_weights: dict[str, float] = {}
    for reward_name, weight in reward_fn.items():
        try:
            reward_weights[str(reward_name)] = float(weight)
        except (TypeError, ValueError) as err:
            # Name the offending entry instead of surfacing a bare float() error.
            raise ValueError(
                f"method.reward_fn weight for {reward_name!r} must be a number, "
                f"got {weight!r}"
            ) from err
    self._reward_fn_config = reward_weights
    unsupported = sorted(set(self._reward_fn_config) - {"pickscore", "clipscore"})
    if unsupported:
        raise ValueError(f"Unsupported DiffusionNFT reward(s): {unsupported}. "
                         "Only pickscore and clipscore are currently ported.")

    # --- Student preprocessing and samplers -------------------------------
    self.student.init_preprocessors(self.training_config)
    self._sampling_config = self._parse_sampling_config()
    self._sampler = DiffusionSampler(self._sampling_config)

    # Validation uses its own sampling config and sampler instance.
    self._validation_config = RLValidationConfig.from_mapping(self.method_config.get("validation"))
    self._validation_sampling_config = self._parse_validation_sampling_config()
    self._validation_sampler = DiffusionSampler(self._validation_sampling_config)
    # Starts unset; presumably populated lazily elsewhere — TODO confirm.
    self._validation_items: list[tuple[int, bool, dict[str, Any]]] | None = None

    # --- Scalar hyperparameters -------------------------------------------
    # Each call passes (name, fallback); presumably the helpers look the
    # name up in the method config — confirm against _read_int/_read_float.
    self._sample_steps = int(self._sampling_config.num_steps)
    # Fallback is the data-loader train batch size, clamped to at least 1
    # (a missing/zero value is treated as 1).
    self._sample_train_batch_size = self._read_int(
        "sample_train_batch_size",
        max(1, int(self.training_config.data.train_batch_size or 1)),
    )
    self._train_batch_size = self._read_int("train_batch_size", self._sample_train_batch_size)
    self._num_batches_per_epoch = self._read_int("num_batches_per_epoch", 48)
    self._num_inner_epochs = self._read_int("num_inner_epochs", 1)
    self._num_video_per_prompt = self._read_int("num_video_per_prompt", 24)
    self._adv_clip_max = self._read_float("adv_clip_max", 5.0)
    self._timestep_fraction = self._read_float("timestep_fraction", 0.99)
    self._kl_beta = self._read_float("kl_beta", 0.0001)
    # Note the config key is "beta", not "nft_beta".
    self._nft_beta = self._read_float("beta", 0.1)
    self._max_grad_norm = self._read_float("max_grad_norm", 1.0)
    self._decay_type = self._read_int("decay_type", 1)
    self._terminal_progress = bool(self.method_config.get("terminal_progress", True))

    # --- EMA of the student -------------------------------------------------
    ema_config = self._parse_ema_config()
    self._ema_enabled = bool(ema_config["enabled"])
    self._ema_decay = float(ema_config["decay"])
    self._ema_update_after_step = int(ema_config["update_after_step"])
    self._validation_use_ema = bool(ema_config["validation"])
    # The EMA object itself is not created here; it starts as None.
    self._student_ema: EMA_FSDP | None = None
    self._ema_update_count = 0

    # --- Bookkeeping ----------------------------------------------------------
    self._trained_prompt_hashes: set[int] = set()
    # The reward scorer is not built here; it starts as None.
    self._reward_scorer: Any | None = None

    # Kept last, as in the original ordering, so every attribute above is
    # already set when the optimizer/scheduler are created.
    self._init_optimizer_and_scheduler()