
ema

EMA (Exponential Moving Average) callback.

Owns the full EMA lifecycle: creation, per-step updates, weight swapping for validation, and checkpoint state. All EMA config lives under callbacks.ema in the YAML file.
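An exponential moving average keeps a shadow copy of each weight. After every update it moves that copy a small fraction, (1 - decay), toward the live value. The sketch below applies the standard rule to plain Python floats. The function name ema_update is hypothetical, and this is not the EMA_FSDP implementation, which operates on sharded tensors.

```python
def ema_update(shadow: dict[str, float], params: dict[str, float], decay: float) -> None:
    """In-place EMA step: shadow <- decay * shadow + (1 - decay) * param (hypothetical helper)."""
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value


shadow = {"w": 0.0}
ema_update(shadow, {"w": 1.0}, decay=0.9)  # shadow["w"] is now about 0.1
first = shadow["w"]
ema_update(shadow, {"w": 1.0}, decay=0.9)  # about 0.9 * 0.1 + 0.1 = 0.19
second = shadow["w"]
```

The closer decay is to 1, the more slowly the shadow weights follow the live weights.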

Classes

fastvideo.train.callbacks.ema.EMACallback

EMACallback(*, decay: float = 0.9999, start_iter: int = 0)

Bases: Callback

Manage EMA shadow weights for the student transformer.

All configuration lives in the YAML callbacks.ema section:

callbacks:
  ema:
    decay: 0.9999
    start_iter: 0
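Two standard quantities help explain what a value such as decay: 0.9999 means. The first is the effective averaging window, roughly 1 / (1 - decay) updates. The second is the half-life, ln(0.5) / ln(decay): the number of updates after which an old weight's contribution has shrunk by half. The calculation below assumes one EMA update per optimizer step. The helper names are illustrative only.

```python
import math


def ema_window(decay: float) -> float:
    """Approximate number of recent updates the EMA effectively averages over."""
    return 1.0 / (1.0 - decay)


def ema_half_life(decay: float) -> float:
    """Number of updates until a past value's weight in the average falls to one half."""
    return math.log(0.5) / math.log(decay)


for d in (0.999, 0.9999):
    print(f"decay={d}: window ~ {ema_window(d):.0f} steps, half-life ~ {ema_half_life(d):.0f} steps")
```

With the default decay of 0.9999, the window is about 10000 steps and the half-life is about 6931 steps.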

The callback creates an EMA_FSDP instance at train start, updates it after each optimizer step, and exposes an ema_context() context manager for temporarily swapping EMA weights into the live model (used by validation).
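The lifecycle described above (create, update, checkpoint) can be sketched with a toy single-process stand-in. Everything in it is hypothetical: ToyEMACallback, the hook names on_train_start and on_optimizer_step, and the state-dict keys are illustrative and are not the fastvideo Callback API. Weights are plain floats rather than FSDP-sharded tensors. The sketch also assumes that start_iter means updates are skipped until the step counter reaches it. The parameter name suggests this, but this page does not state it.

```python
from __future__ import annotations


class ToyEMACallback:
    """Hypothetical stand-in illustrating the EMA lifecycle on a dict of floats."""

    def __init__(self, *, decay: float = 0.9999, start_iter: int = 0) -> None:
        self._decay = float(decay)
        self._start_iter = int(start_iter)
        self._ema_started = False
        self.shadow: dict[str, float] | None = None

    def on_train_start(self, params: dict[str, float]) -> None:
        # Creation: initialize the shadow copy from the current weights.
        self.shadow = dict(params)

    def on_optimizer_step(self, step: int, params: dict[str, float]) -> None:
        # Per-step update; assumption: skipped until step >= start_iter.
        if self.shadow is None or step < self._start_iter:
            return
        self._ema_started = True
        for name, value in params.items():
            self.shadow[name] = self._decay * self.shadow[name] + (1.0 - self._decay) * value

    def state_dict(self) -> dict:
        # Checkpoint state: shadow weights plus the "started" flag.
        return {"shadow": dict(self.shadow or {}), "ema_started": self._ema_started}

    def load_state_dict(self, state: dict) -> None:
        self.shadow = dict(state["shadow"])
        self._ema_started = bool(state["ema_started"])


cb = ToyEMACallback(decay=0.5, start_iter=2)
weights = {"w": 0.0}
cb.on_train_start(weights)
for step in range(4):
    weights["w"] += 1.0  # pretend optimizer step
    cb.on_optimizer_step(step, weights)

restored = ToyEMACallback()
restored.load_state_dict(cb.state_dict())
```

Here steps 0 and 1 are skipped. Steps 2 and 3 blend the live weights into the shadow copy. Checkpointing the started flag alongside the weights lets a resumed run know whether EMA weights are meaningful yet.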

Source code in fastvideo/train/callbacks/ema.py
def __init__(
    self,
    *,
    decay: float = 0.9999,
    start_iter: int = 0,
) -> None:
    self._decay = float(decay)
    self._start_iter = int(start_iter)
    self._ema_started = False
    self.student_ema: EMA_FSDP | None = None

Functions

fastvideo.train.callbacks.ema.EMACallback.ema_context
ema_context(transformer: Module) -> Generator[Module, None, None]

Temporarily swap EMA weights into transformer.

If EMA is not active, yields the transformer unchanged.

Source code in fastvideo/train/callbacks/ema.py
@contextlib.contextmanager
def ema_context(
    self,
    transformer: torch.nn.Module,
) -> Generator[torch.nn.Module, None, None]:
    """Temporarily swap EMA weights into *transformer*.

    If EMA is not active, yields the transformer unchanged.
    """
    if self.student_ema is not None and self._ema_started:
        with self.student_ema.apply_to_model(transformer):
            yield transformer
    else:
        yield transformer
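The swap-and-restore pattern behind ema_context can be illustrated with a hypothetical stand-in. The names toy_ema_context and the dict-of-floats "model" are assumptions for illustration only. The real EMA_FSDP.apply_to_model works on sharded module parameters, and its internals are not shown on this page. The sketch mirrors the fallback above: when EMA is not active, it yields the model unchanged. Otherwise it swaps in the shadow weights, and a try/finally restores the live weights even if validation raises.

```python
from __future__ import annotations

import contextlib
from typing import Iterator


@contextlib.contextmanager
def toy_ema_context(
    model: dict[str, float],
    shadow: dict[str, float] | None,
    started: bool,
) -> Iterator[dict[str, float]]:
    """Temporarily swap shadow weights into model (hypothetical illustration)."""
    if shadow is None or not started:
        # EMA not active: hand back the live model untouched.
        yield model
        return
    backup = dict(model)  # save live weights
    model.update(shadow)  # swap in EMA weights
    try:
        yield model
    finally:
        # Restore the live weights exactly, even on exceptions.
        model.clear()
        model.update(backup)


live = {"w": 4.0}
with toy_ema_context(live, {"w": 2.75}, started=True) as m:
    during = m["w"]  # validation would see the EMA weights here
after = live["w"]
with toy_ema_context(live, None, started=False) as m:
    unchanged = m is live
```

Swapping in place rather than building a second model avoids holding two full copies of the weights in the live module. That is a plausible motivation for the context-manager design, though this page does not state it.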

Functions