def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Parse and validate self-forcing method configuration.

    Reads options from ``self.method_config`` (chunk size, student
    sampling type, rollout gradient flags, context noise) and builds
    the flow-matching scheduler used for self-forcing training.

    Raises:
        ValueError: if the student is not causal, the rollout mode is
            not ``"simulate"``, or any config value is out of range.
    """
    super().__init__(
        cfg=cfg,
        role_models=role_models,
    )

    # Self-forcing rollouts require a frame-causal student model.
    if not isinstance(self.student, CausalModelBase):
        raise ValueError(
            "SelfForcingMethod requires a causal student "
            "implementing CausalModelBase."
        )
    if self._rollout_mode != "simulate":
        raise ValueError(
            "SelfForcingMethod only supports "
            "method_config.rollout_mode='simulate'"
        )

    method_cfg = self.method_config

    def _bool_option(key: str, fallback: bool) -> bool:
        # Shared parse for boolean options: an explicit null is
        # treated the same as an absent key.
        value = method_cfg.get(key, fallback)
        if value is None:
            value = fallback
        return _require_bool(value, where=f"method_config.{key}")

    # Number of frames denoised together per rollout block (default 3).
    raw_chunk = get_optional_int(
        method_cfg,
        "chunk_size",
        where="method_config.chunk_size",
    )
    if raw_chunk is None:
        raw_chunk = 3
    if raw_chunk <= 0:
        raise ValueError(
            "method_config.chunk_size must be a positive "
            f"integer, got {raw_chunk}"
        )
    self._chunk_size = int(raw_chunk)

    # Student sampling mode: stochastic ("sde") or deterministic ("ode").
    raw_sample_type = method_cfg.get("student_sample_type", "sde")
    normalized_sample_type = _require_str(
        raw_sample_type,
        where="method_config.student_sample_type",
    ).strip().lower()
    if normalized_sample_type not in {"sde", "ode"}:
        raise ValueError(
            "method_config.student_sample_type must be one "
            f"of {{sde, ode}}, got {raw_sample_type!r}"
        )
    self._student_sample_type: Literal["sde", "ode"] = (
        normalized_sample_type  # type: ignore[assignment]
    )

    self._same_step_across_blocks = _bool_option(
        "same_step_across_blocks", False
    )
    self._last_step_only = _bool_option("last_step_only", False)

    # Noise level re-applied to already-generated context frames (>= 0).
    noise_level = get_optional_float(
        method_cfg,
        "context_noise",
        where="method_config.context_noise",
    )
    if noise_level is None:
        noise_level = 0.0
    if noise_level < 0.0:
        raise ValueError(
            "method_config.context_noise must be >= 0, "
            f"got {noise_level}"
        )
    self._context_noise = float(noise_level)

    self._enable_gradient_in_rollout = _bool_option(
        "enable_gradient_in_rollout", True
    )

    # First frame index from which gradients flow during rollout.
    raw_start_frame = get_optional_int(
        method_cfg,
        "start_gradient_frame",
        where="method_config.start_gradient_frame",
    )
    if raw_start_frame is None:
        raw_start_frame = 0
    if raw_start_frame < 0:
        raise ValueError(
            "method_config.start_gradient_frame must be "
            f">= 0, got {raw_start_frame}"
        )
    self._start_gradient_frame = int(raw_start_frame)

    # Flow shift falls back to 0.0 when missing or falsy on the
    # pipeline config (matches float(x or 0.0) semantics).
    pipeline_cfg = self.training_config.pipeline_config
    flow_shift = getattr(pipeline_cfg, "flow_shift", 0.0)
    shift = float(flow_shift) if flow_shift else 0.0

    self._sf_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        num_train_timesteps=int(self.student.num_train_timesteps),
        shift=shift,
        sigma_min=0.0,
        extra_one_step=True,
        training=True,
    )
    # Denoising timestep schedule; populated lazily after construction.
    self._sf_denoising_step_list: torch.Tensor | None = None