
media

Generic media reward composition utilities.

Classes

fastvideo.train.methods.rl.rewards.media.MultiRewardScorer

MultiRewardScorer(reward_weights: Mapping[str, float], *, scorers: Mapping[str, RewardScorer])

Weighted sum of reusable media reward scorers.

Mirrors DiffusionNFT's flow_grpo/rewards.py::multi_score behavior, while leaving frame selection to each concrete reward.
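The call signature and return format of `MultiRewardScorer` are not shown on this page, so the following is only a minimal sketch of the weighted-sum idea, not the class's API. It assumes each scorer is a callable that takes the full media batch and returns one score per sample (shape `[B]`). The names `weighted_reward`, `Scorer` and the `"total"` key are hypothetical.

```python
from typing import Callable, Dict, Mapping

import torch

# Hypothetical scorer signature (assumption): media batch in, one score per sample ([B]) out.
Scorer = Callable[[torch.Tensor], torch.Tensor]


def weighted_reward(
    media: torch.Tensor,
    reward_weights: Mapping[str, float],
    scorers: Mapping[str, Scorer],
) -> Dict[str, torch.Tensor]:
    """Return each weighted reward's raw scores plus their weighted sum under "total"."""
    scores: Dict[str, torch.Tensor] = {}
    total = torch.zeros(media.shape[0], dtype=torch.float32, device=media.device)
    for name, weight in reward_weights.items():
        # Each scorer receives the full media tensor and decides which frames to use.
        s = scorers[name](media).to(total.dtype)
        scores[name] = s
        total = total + weight * s
    scores["total"] = total
    return scores


# Two toy scorers that return constant scores, just to show the arithmetic.
scorers = {
    "ones": lambda m: torch.ones(m.shape[0]),
    "twos": lambda m: torch.full((m.shape[0],), 2.0),
}
video = torch.zeros(2, 3, 4, 8, 8)  # [B, C, T, H, W]
out = weighted_reward(video, {"ones": 0.5, "twos": 2.0}, scorers)
print(out["total"])  # tensor([4.5000, 4.5000])
```

Each sample's total is 0.5 × 1 + 2.0 × 2 = 4.5. Because the scorers receive the untouched media tensor, a frame-based scorer and a video-aware scorer can be mixed in the same weighted sum.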

Source code in fastvideo/train/methods/rl/rewards/media.py
def __init__(
    self,
    reward_weights: Mapping[str, float],
    *,
    scorers: Mapping[str, RewardScorer],
) -> None:
    # Normalize keys to str and weights to float; reject an empty configuration.
    self.reward_weights = {str(k): float(v) for k, v in reward_weights.items()}
    if not self.reward_weights:
        raise ValueError("reward_weights must contain at least one reward")

    # Every weighted reward needs a registered scorer; extra, unweighted scorers are allowed.
    self.scorers = dict(scorers)
    unsupported = sorted(set(self.reward_weights) - set(self.scorers))
    if unsupported:
        raise ValueError(f"Unsupported reward(s): {unsupported}. "
                         f"Available rewards: {sorted(self.scorers)}")
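To show what the constructor accepts and rejects without constructing real reward models, the sketch below copies its two checks into a standalone function. `validate_reward_config` is a hypothetical helper, and the reward names `"pickscore"`, `"clipscore"` and `"aesthetic"` are illustrative keys, not necessarily the registry names fastvideo uses.

```python
from typing import Dict, Mapping


def validate_reward_config(
    reward_weights: Mapping[str, float],
    scorers: Mapping[str, object],
) -> Dict[str, float]:
    """Hypothetical standalone copy of the checks in MultiRewardScorer.__init__."""
    weights = {str(k): float(v) for k, v in reward_weights.items()}
    if not weights:
        raise ValueError("reward_weights must contain at least one reward")
    # Only weighted names must be registered; extra scorers are silently allowed.
    unsupported = sorted(set(weights) - set(scorers))
    if unsupported:
        raise ValueError(f"Unsupported reward(s): {unsupported}. "
                         f"Available rewards: {sorted(scorers)}")
    return weights


available = {"pickscore": object(), "clipscore": object()}

# A subset of the registered scorers is fine; integer weights become floats.
print(validate_reward_config({"pickscore": 1}, available))  # {'pickscore': 1.0}

# A reward with no registered scorer is rejected.
try:
    validate_reward_config({"aesthetic": 0.5}, available)
except ValueError as err:
    print(err)  # Unsupported reward(s): ['aesthetic']. Available rewards: ['clipscore', 'pickscore']
```

Note that the checks do not reject zero or negative weights, and scorer keys are not coerced to `str` the way weight keys are.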

Functions

fastvideo.train.methods.rl.rewards.media.select_first_frame

select_first_frame(media: Tensor) -> Tensor

Return first-frame media as [B, C, H, W].

This is a helper for reward models that are intrinsically frame-based (for example PickScore and CLIPScore). Video-aware rewards should inspect the full [B, C, T, H, W] tensor themselves.
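The difference between the two kinds of reward can be illustrated with two toy scorers. Both are hypothetical and stand in for real models such as PickScore: one looks only at the first frame, the other compares consecutive frames across the `T` axis of the full `[B, C, T, H, W]` tensor.

```python
import torch


def toy_first_frame_brightness(media: torch.Tensor) -> torch.Tensor:
    """Hypothetical frame-based reward: mean pixel value of the first frame, shape [B]."""
    frame = media[:, :, 0] if media.ndim == 5 else media  # [B, C, H, W]
    return frame.mean(dim=(1, 2, 3))


def toy_temporal_smoothness(media: torch.Tensor) -> torch.Tensor:
    """Hypothetical video-aware reward: negative mean absolute frame-to-frame change, shape [B]."""
    if media.ndim != 5 or media.shape[2] < 2:
        raise ValueError("expected [B, C, T, H, W] with T >= 2")
    diffs = (media[:, :, 1:] - media[:, :, :-1]).abs()  # [B, C, T-1, H, W]
    return -diffs.mean(dim=(1, 2, 3, 4))


static = torch.full((2, 3, 5, 4, 4), 0.5)
flicker = static.clone()
flicker[:, :, 1::2] = 1.0  # odd-indexed frames jump to 1.0

print(toy_first_frame_brightness(flicker))  # tensor([0.5000, 0.5000])
print(toy_temporal_smoothness(flicker))     # tensor([-0.5000, -0.5000])
```

The first-frame reward gives the static and flickering clips the same score, while the video-aware reward penalizes the flicker, which is why temporal rewards should not go through first-frame selection.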

Source code in fastvideo/train/methods/rl/rewards/media.py
import torch


def select_first_frame(media: torch.Tensor) -> torch.Tensor:
    """Return first-frame media as ``[B, C, H, W]``.

    This is a helper for reward models that are intrinsically frame-based
    (for example PickScore and CLIPScore). Video-aware rewards should inspect
    the full ``[B, C, T, H, W]`` tensor themselves.
    """
    if not torch.is_tensor(media):
        raise TypeError(f"media must be a torch.Tensor, got {type(media).__name__}")
    if media.ndim == 5:
        return media[:, :, 0]
    if media.ndim == 4:
        return media
    raise ValueError("media must have shape [B, C, H, W] or [B, C, T, H, W], "
                     f"got {tuple(media.shape)}")
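A short usage check of the helper's three paths (5-D video, 4-D image batch, invalid shape). The function body is copied from the source above so the snippet runs without fastvideo installed; in real code you would import it from `fastvideo.train.methods.rl.rewards.media`.

```python
import torch


def select_first_frame(media: torch.Tensor) -> torch.Tensor:
    # Local copy of the helper above, so this snippet runs on its own.
    if not torch.is_tensor(media):
        raise TypeError(f"media must be a torch.Tensor, got {type(media).__name__}")
    if media.ndim == 5:
        return media[:, :, 0]
    if media.ndim == 4:
        return media
    raise ValueError("media must have shape [B, C, H, W] or [B, C, T, H, W], "
                     f"got {tuple(media.shape)}")


video = torch.randn(2, 3, 8, 16, 16)  # [B, C, T, H, W]
first = select_first_frame(video)
print(first.shape)  # torch.Size([2, 3, 16, 16])

images = torch.randn(2, 3, 16, 16)  # already [B, C, H, W]
print(select_first_frame(images) is images)  # True

try:
    select_first_frame(torch.randn(3, 16, 16))  # missing batch dimension
except ValueError as err:
    print(err)  # media must have shape [B, C, H, W] or [B, C, T, H, W], got (3, 16, 16)
```

Neither valid path copies data: the 5-D path returns a view that shares storage with the input, and the 4-D path returns the input object itself. A scorer that normalizes or otherwise modifies pixels in place should call `.clone()` first.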