Skip to content

callbacks

Classes

fastvideo.train.callbacks.Callback

Base callback with no-op hooks.

Subclasses override whichever hooks they need. The training_config and method attributes are set by CallbackDict after instantiation.

fastvideo.train.callbacks.CallbackDict

CallbackDict(callback_configs: dict[str, dict[str, Any]], training_config: TrainingConfig)

Manages a collection of named callbacks.

Instantiates each callback from its _target_ config and dispatches hook calls to all registered callbacks.

Source code in fastvideo/train/callbacks/callback.py
def __init__(
    self,
    callback_configs: dict[str, dict[str, Any]],
    training_config: TrainingConfig,
) -> None:
    """Instantiate every configured callback and register it by name.

    Args:
        callback_configs: Mapping of callback name to its config dict.
            A config without ``_target_`` falls back to the entry in
            ``_BUILTIN_CALLBACKS`` for that name; if there is no such
            entry, the callback is skipped with a warning. A ``None``
            config (e.g. a bare ``ema:`` key in YAML) is treated as an
            empty config, i.e. the callback's defaults are used.
        training_config: Attached to every instantiated callback as
            ``cb.training_config``.

    Raises:
        TypeError: If a config resolves to an object that is not a
            ``Callback`` instance.
    """
    self._callbacks: dict[str, Callback] = {}
    if not callback_configs:
        return
    for name, cb_cfg in callback_configs.items():
        # Copy so that filling in ``_target_`` never mutates the caller's
        # config. A bare YAML key (``ema:``) parses to None; treat it as
        # "use defaults" instead of crashing inside dict(None).
        cb_cfg = dict(cb_cfg) if cb_cfg is not None else {}
        if "_target_" not in cb_cfg:
            if name in _BUILTIN_CALLBACKS:
                cb_cfg["_target_"] = _BUILTIN_CALLBACKS[name]
            else:
                logger.warning(
                    "Callback %r is missing "
                    "'_target_', skipping: %s",
                    name,
                    cb_cfg,
                )
                continue
        logger.info(
            "Instantiating callback %r: %s",
            name,
            cb_cfg,
        )
        cb = instantiate(cb_cfg)
        if not isinstance(cb, Callback):
            raise TypeError(f"Callback {name!r} resolved to "
                            f"{type(cb).__name__}, expected a "
                            f"Callback subclass.")
        # Give the callback access to the shared training config and to
        # this manager.
        cb.training_config = training_config
        cb._callback_dict = self
        self._callbacks[name] = cb

fastvideo.train.callbacks.EMACallback

EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

```yaml
callbacks:
  ema:
    decay: 0.9999
    start_iter: 0
```

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    """Record EMA hyper-parameters; the EMA shadow itself is built later.

    Args:
        decay: EMA decay factor, coerced to ``float``.
        start_iter: Coerced to ``int``; presumably the iteration from which
            EMA updates begin (TODO confirm against the update hook).
    """
    decay_value = float(decay)
    first_iter = int(start_iter)
    self._decay = decay_value
    self._start_iter = first_iter
    # Both stay "inactive" until training actually begins.
    self._ema_started = False
    self.student_ema: EMA_FSDP | None = None

Functions

fastvideo.train.callbacks.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Yield *transformer* with EMA weights swapped in while EMA is live.

    EMA counts as live only once ``student_ema`` exists and
    ``_ema_started`` is truthy. Otherwise the model is yielded untouched.
    Swapping the weights in and back out is delegated to
    ``student_ema.apply_to_model``.
    """
    ema_is_live = self.student_ema is not None and self._ema_started
    # nullcontext stands in for the swap when there is nothing to apply.
    swap = (self.student_ema.apply_to_model(transformer)
            if ema_is_live else contextlib.nullcontext())
    with swap:
        yield transformer

fastvideo.train.callbacks.GradNormClipCallback

GradNormClipCallback(*, max_grad_norm: float = 1.0, log_grad_norms: bool = True)

Bases: Callback

Clip gradient norms before the optimizer step.

max_grad_norm defaults to 1.0 and can be overridden in the callback config (callbacks.grad_clip.max_grad_norm).

Source code in fastvideo/train/callbacks/grad_clip.py
def __init__(
    self,
    *,
    max_grad_norm: float = 1.0,
    log_grad_norms: bool = True,
) -> None:
    """Store the clipping threshold and the grad-norm logging switch.

    Args:
        max_grad_norm: Clipping threshold, coerced to ``float``
            (default ``1.0``).
        log_grad_norms: Coerced to ``bool``; controls whether per-module
            grad norms are logged.
    """
    threshold = float(max_grad_norm)
    should_log = bool(log_grad_norms)
    self._max_grad_norm = threshold
    self._log_grad_norms = should_log

fastvideo.train.callbacks.ValidationCallback

ValidationCallback(*, pipeline_target: str, dataset_file: str, every_steps: int = 100, sampling_steps: list[int] | None = None, guidance_scale: float | None = None, num_frames: int | None = None, output_dir: str | None = None, sampling_timesteps: list[int] | None = None, **pipeline_kwargs: Any)

Bases: Callback

Generic validation callback driven entirely by YAML config.

Works with any pipeline that follows the PipelineCls.from_pretrained(...) + pipeline.forward() contract.

Source code in fastvideo/train/callbacks/validation.py
def __init__(
    self,
    *,
    pipeline_target: str,
    dataset_file: str,
    every_steps: int = 100,
    sampling_steps: list[int] | None = None,
    guidance_scale: float | None = None,
    num_frames: int | None = None,
    output_dir: str | None = None,
    sampling_timesteps: list[int] | None = None,
    **pipeline_kwargs: Any,
) -> None:
    """Normalize and store the validation settings.

    Args:
        pipeline_target: Identifier the pipeline class is resolved from,
            coerced to ``str``.
        dataset_file: Validation dataset file, coerced to ``str``.
        every_steps: Validation interval in steps, coerced to ``int``.
        sampling_steps: Each entry coerced to ``int``; a falsy value
            (``None`` or empty) is replaced by ``[40]``.
        guidance_scale: Coerced to ``float``; ``None`` is kept.
        num_frames: Coerced to ``int``; ``None`` is kept.
        output_dir: Coerced to ``str``; ``None`` is kept.
        sampling_timesteps: Each entry coerced to ``int``; ``None`` is kept
            (an empty list stays empty, unlike ``sampling_steps``).
        **pipeline_kwargs: Any remaining options, stored as a copy.
    """
    self.pipeline_target = str(pipeline_target)
    self.dataset_file = str(dataset_file)
    self.every_steps = int(every_steps)

    if sampling_steps:
        self.sampling_steps = list(map(int, sampling_steps))
    else:
        self.sampling_steps = [40]

    self.guidance_scale = None if guidance_scale is None else float(guidance_scale)
    self.num_frames = None if num_frames is None else int(num_frames)
    self.output_dir = None if output_dir is None else str(output_dir)

    if sampling_timesteps is None:
        self.sampling_timesteps = None
    else:
        self.sampling_timesteps = list(map(int, sampling_timesteps))

    self.pipeline_kwargs = {**pipeline_kwargs}

    # Runtime state; filled in once training starts (on_train_start).
    self._pipeline: Any | None = None
    self._pipeline_key: tuple[Any, ...] | None = None
    self._sampling_param: SamplingParam | None = None
    self.tracker: Any = DummyTracker()
    self.validation_random_generator: (torch.Generator | None) = None
    self.seed: int = 0

Modules

fastvideo.train.callbacks.callback

Callback base class and CallbackDict manager.

Adapted from FastGen's callback pattern to FastVideo's types.

Classes

fastvideo.train.callbacks.callback.Callback

Base callback with no-op hooks.

Subclasses override whichever hooks they need. The training_config and method attributes are set by CallbackDict after instantiation.

fastvideo.train.callbacks.callback.CallbackDict
CallbackDict(callback_configs: dict[str, dict[str, Any]], training_config: TrainingConfig)

Manages a collection of named callbacks.

Instantiates each callback from its _target_ config and dispatches hook calls to all registered callbacks.

Source code in fastvideo/train/callbacks/callback.py
def __init__(
    self,
    callback_configs: dict[str, dict[str, Any]],
    training_config: TrainingConfig,
) -> None:
    """Instantiate every configured callback and register it by name.

    Args:
        callback_configs: Mapping of callback name to its config dict.
            A config without ``_target_`` falls back to the entry in
            ``_BUILTIN_CALLBACKS`` for that name; if there is no such
            entry, the callback is skipped with a warning. A ``None``
            config (e.g. a bare ``ema:`` key in YAML) is treated as an
            empty config, i.e. the callback's defaults are used.
        training_config: Attached to every instantiated callback as
            ``cb.training_config``.

    Raises:
        TypeError: If a config resolves to an object that is not a
            ``Callback`` instance.
    """
    self._callbacks: dict[str, Callback] = {}
    if not callback_configs:
        return
    for name, cb_cfg in callback_configs.items():
        # Copy so that filling in ``_target_`` never mutates the caller's
        # config. A bare YAML key (``ema:``) parses to None; treat it as
        # "use defaults" instead of crashing inside dict(None).
        cb_cfg = dict(cb_cfg) if cb_cfg is not None else {}
        if "_target_" not in cb_cfg:
            if name in _BUILTIN_CALLBACKS:
                cb_cfg["_target_"] = _BUILTIN_CALLBACKS[name]
            else:
                logger.warning(
                    "Callback %r is missing "
                    "'_target_', skipping: %s",
                    name,
                    cb_cfg,
                )
                continue
        logger.info(
            "Instantiating callback %r: %s",
            name,
            cb_cfg,
        )
        cb = instantiate(cb_cfg)
        if not isinstance(cb, Callback):
            raise TypeError(f"Callback {name!r} resolved to "
                            f"{type(cb).__name__}, expected a "
                            f"Callback subclass.")
        # Give the callback access to the shared training config and to
        # this manager.
        cb.training_config = training_config
        cb._callback_dict = self
        self._callbacks[name] = cb

Functions

fastvideo.train.callbacks.ema

EMA (Exponential Moving Average) callback.

Owns the full EMA lifecycle: creation, per-step updates, weight swapping for validation, and checkpoint state. All EMA config lives under callbacks.ema in the YAML file.

Classes

fastvideo.train.callbacks.ema.EMACallback
EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

```yaml
callbacks:
  ema:
    decay: 0.9999
    start_iter: 0
```

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    """Record EMA hyper-parameters; the EMA shadow itself is built later.

    Args:
        decay: EMA decay factor, coerced to ``float``.
        start_iter: Coerced to ``int``; presumably the iteration from which
            EMA updates begin (TODO confirm against the update hook).
    """
    decay_value = float(decay)
    first_iter = int(start_iter)
    self._decay = decay_value
    self._start_iter = first_iter
    # Both stay "inactive" until training actually begins.
    self._ema_started = False
    self.student_ema: EMA_FSDP | None = None
Functions
fastvideo.train.callbacks.ema.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Yield *transformer* with EMA weights swapped in while EMA is live.

    EMA counts as live only once ``student_ema`` exists and
    ``_ema_started`` is truthy. Otherwise the model is yielded untouched.
    Swapping the weights in and back out is delegated to
    ``student_ema.apply_to_model``.
    """
    ema_is_live = self.student_ema is not None and self._ema_started
    # nullcontext stands in for the swap when there is nothing to apply.
    swap = (self.student_ema.apply_to_model(transformer)
            if ema_is_live else contextlib.nullcontext())
    with swap:
        yield transformer

Functions

fastvideo.train.callbacks.grad_clip

Gradient norm clipping callback.

Clips gradients on modules returned by method.get_grad_clip_targets() before the optimizer step. Optionally logs per-module grad norms to the tracker.

Classes

fastvideo.train.callbacks.grad_clip.GradNormClipCallback
GradNormClipCallback(*, max_grad_norm: float = 1.0, log_grad_norms: bool = True)

Bases: Callback

Clip gradient norms before the optimizer step.

max_grad_norm defaults to 1.0 and can be overridden in the callback config (callbacks.grad_clip.max_grad_norm).

Source code in fastvideo/train/callbacks/grad_clip.py
def __init__(
    self,
    *,
    max_grad_norm: float = 1.0,
    log_grad_norms: bool = True,
) -> None:
    """Store the clipping threshold and the grad-norm logging switch.

    Args:
        max_grad_norm: Clipping threshold, coerced to ``float``
            (default ``1.0``).
        log_grad_norms: Coerced to ``bool``; controls whether per-module
            grad norms are logged.
    """
    threshold = float(max_grad_norm)
    should_log = bool(log_grad_norms)
    self._max_grad_norm = threshold
    self._log_grad_norms = should_log

Functions

fastvideo.train.callbacks.validation

Validation callback.

All configuration is read from the YAML callbacks.validation section. The pipeline class is resolved from pipeline_target.

Classes

fastvideo.train.callbacks.validation.ValidationCallback
ValidationCallback(*, pipeline_target: str, dataset_file: str, every_steps: int = 100, sampling_steps: list[int] | None = None, guidance_scale: float | None = None, num_frames: int | None = None, output_dir: str | None = None, sampling_timesteps: list[int] | None = None, **pipeline_kwargs: Any)

Bases: Callback

Generic validation callback driven entirely by YAML config.

Works with any pipeline that follows the PipelineCls.from_pretrained(...) + pipeline.forward() contract.

Source code in fastvideo/train/callbacks/validation.py
def __init__(
    self,
    *,
    pipeline_target: str,
    dataset_file: str,
    every_steps: int = 100,
    sampling_steps: list[int] | None = None,
    guidance_scale: float | None = None,
    num_frames: int | None = None,
    output_dir: str | None = None,
    sampling_timesteps: list[int] | None = None,
    **pipeline_kwargs: Any,
) -> None:
    """Normalize and store the validation settings.

    Args:
        pipeline_target: Identifier the pipeline class is resolved from,
            coerced to ``str``.
        dataset_file: Validation dataset file, coerced to ``str``.
        every_steps: Validation interval in steps, coerced to ``int``.
        sampling_steps: Each entry coerced to ``int``; a falsy value
            (``None`` or empty) is replaced by ``[40]``.
        guidance_scale: Coerced to ``float``; ``None`` is kept.
        num_frames: Coerced to ``int``; ``None`` is kept.
        output_dir: Coerced to ``str``; ``None`` is kept.
        sampling_timesteps: Each entry coerced to ``int``; ``None`` is kept
            (an empty list stays empty, unlike ``sampling_steps``).
        **pipeline_kwargs: Any remaining options, stored as a copy.
    """
    self.pipeline_target = str(pipeline_target)
    self.dataset_file = str(dataset_file)
    self.every_steps = int(every_steps)

    if sampling_steps:
        self.sampling_steps = list(map(int, sampling_steps))
    else:
        self.sampling_steps = [40]

    self.guidance_scale = None if guidance_scale is None else float(guidance_scale)
    self.num_frames = None if num_frames is None else int(num_frames)
    self.output_dir = None if output_dir is None else str(output_dir)

    if sampling_timesteps is None:
        self.sampling_timesteps = None
    else:
        self.sampling_timesteps = list(map(int, sampling_timesteps))

    self.pipeline_kwargs = {**pipeline_kwargs}

    # Runtime state; filled in once training starts (on_train_start).
    self._pipeline: Any | None = None
    self._pipeline_key: tuple[Any, ...] | None = None
    self._sampling_param: SamplingParam | None = None
    self.tracker: Any = DummyTracker()
    self.validation_random_generator: (torch.Generator | None) = None
    self.seed: int = 0

Functions