Skip to content

rl

RL training methods.

Classes

fastvideo.train.methods.rl.DiffusionNFTMethod

DiffusionNFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DiffusionNFT-style RL for diffusion models.

This method owns the algorithm's sample-then-inner-train loop. One Trainer step corresponds to one DiffusionNFT outer epoch.

Source code in fastvideo/train/methods/rl/diffusion_nft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate roles, read the ``method`` config knobs, and build optimizers.

    Args:
        cfg: Training configuration, forwarded unchanged to
            ``TrainingMethod.__init__``.
        role_models: Role name -> model wrapper. Must contain at least the
            ``"old"`` and ``"reference"`` roles, and the student
            (``self.student``) must be trainable.

    Raises:
        ValueError: If the ``"old"`` or ``"reference"`` role is missing, the
            student is not trainable, ``method.adv_mode`` is not one of
            ``all`` / ``positive_only`` / ``negative_only`` / ``one_only`` /
            ``binary``, ``method.reward_fn`` is not a non-empty dict, or
            ``method.reward_fn`` names a reward other than ``pickscore`` /
            ``clipscore``.
    """
    # NOTE(review): self.student, self.method_config and self.training_config
    # are read below but never assigned here; presumably
    # TrainingMethod.__init__ sets them up - confirm.
    super().__init__(cfg=cfg, role_models=role_models)
    if "old" not in role_models:
        raise ValueError("DiffusionNFTMethod requires role 'old'")
    if "reference" not in role_models:
        raise ValueError("DiffusionNFTMethod requires role 'reference'")
    if not self.student._trainable:
        raise ValueError("DiffusionNFTMethod requires a trainable student")

    self.old = role_models["old"]
    self.reference = role_models["reference"]
    self.student.init_preprocessors(self.training_config)

    # Rollout sampling and validation sampling are driven by separately
    # parsed configs, each wrapped in its own DiffusionSampler.
    self._sampling_config = self._parse_sampling_config()
    self._sampler = DiffusionSampler(self._sampling_config)
    self._validation_config = RLValidationConfig.from_mapping(self.method_config.get("validation"))
    self._validation_sampling_config = self._parse_validation_sampling_config()
    self._validation_sampler = DiffusionSampler(self._validation_sampling_config)
    # None until filled; presumably populated lazily when validation first
    # runs - confirm. Entries are (int, bool, dict) tuples.
    self._validation_items: list[tuple[int, bool, dict[str, Any]]] | None = None
    self._sample_steps = int(self._sampling_config.num_steps)
    # Rollout batch size; falls back to data.train_batch_size (at least 1).
    self._sample_train_batch_size = self._read_int(
        "sample_train_batch_size",
        max(1, int(self.training_config.data.train_batch_size or 1)),
    )
    # Inner-training batch size; falls back to the rollout batch size.
    self._train_batch_size = self._read_int("train_batch_size", self._sample_train_batch_size)
    self._num_batches_per_epoch = self._read_int("num_batches_per_epoch", 48)
    self._num_inner_epochs = self._read_int("num_inner_epochs", 1)
    self._num_video_per_prompt = self._read_int("num_video_per_prompt", 24)
    self._adv_clip_max = self._read_float("adv_clip_max", 5.0)
    self._timestep_fraction = self._read_float("timestep_fraction", 0.99)
    self._kl_beta = self._read_float("kl_beta", 0.0001)
    # Read from the config key "beta" (not "nft_beta").
    self._nft_beta = self._read_float("beta", 0.1)
    self._max_grad_norm = self._read_float("max_grad_norm", 1.0)
    self._decay_type = self._read_int("decay_type", 1)
    # Normalized to stripped lowercase; validated against the allowed set below.
    self._adv_mode = str(self.method_config.get("adv_mode", "all") or "all").strip().lower()
    # NOTE(review): bool() on a string makes "false" truthy; this relies on the
    # config loader producing a real bool - confirm.
    self._terminal_progress = bool(self.method_config.get("terminal_progress", True))
    ema_config = self._parse_ema_config()
    self._ema_enabled = bool(ema_config["enabled"])
    self._ema_decay = float(ema_config["decay"])
    self._ema_update_after_step = int(ema_config["update_after_step"])
    self._validation_use_ema = bool(ema_config["validation"])
    # No EMA copy exists yet; presumably created later when EMA is enabled.
    self._student_ema: EMA_FSDP | None = None
    self._ema_update_count = 0
    # NOTE(review): presumably records hashes of prompts used for training
    # (e.g. to compare against validation prompts) - confirm.
    self._trained_prompt_hashes: set[int] = set()
    if self._adv_mode not in {"all", "positive_only", "negative_only", "one_only", "binary"}:
        raise ValueError("method.adv_mode must be one of "
                         "{all, positive_only, negative_only, one_only, binary}")

    # reward_fn maps reward name -> weight; weights are coerced to float.
    reward_fn = self.method_config.get("reward_fn", None)
    if not isinstance(reward_fn, dict) or not reward_fn:
        raise ValueError("method.reward_fn must be a non-empty mapping, "
                         "for example {pickscore: 1.0, clipscore: 1.0}")
    self._reward_fn_config = {str(k): float(v) for k, v in reward_fn.items()}
    unsupported = sorted(set(self._reward_fn_config) - {"pickscore", "clipscore"})
    if unsupported:
        raise ValueError(f"Unsupported DiffusionNFT reward(s): {unsupported}. "
                         "Only pickscore and clipscore are currently ported.")

    # Not constructed here (None); presumably built lazily on first use.
    self._reward_scorer: Any | None = None
    # Called last, after every config field above has been read.
    self._init_optimizer_and_scheduler()

Modules

fastvideo.train.methods.rl.common

Reusable RL training primitives.

Classes

fastvideo.train.methods.rl.common.DiffusionSampler
DiffusionSampler(config: SamplingConfig)

Thin model/scheduler sampler used by RL methods.

This intentionally does not call FastVideo's full inference pipelines. RL training needs a reusable sampling primitive that works with ModelBase wrappers and scheduler math without binding a method to model-family pipeline classes such as WanDMDPipeline.

Source code in fastvideo/train/methods/rl/common/sampling.py
def __init__(self, config: SamplingConfig) -> None:
    """Remember the sampling knobs (steps, scheduler, trajectory, ...).

    Args:
        config: Sampling configuration, kept by reference as ``self.config``.
    """
    # Not copied: later changes to ``config`` are visible through self.config.
    sampling_config = config
    self.config = sampling_config
fastvideo.train.methods.rl.common.KRepeatSample dataclass
KRepeatSample(local_indices: list[int], unique_prompt_count: int)

Local prompt indices for one distributed K-repeat sampling batch.

fastvideo.train.methods.rl.common.SamplingConfig dataclass
SamplingConfig(num_steps: int = 25, scheduler: SchedulerName = 'model_default', trajectory: TrajectoryName = 'ode', flow_shift: float | None = None, timesteps: list[float] | None = None, sigmas: list[float] | None = None)

YAML-backed sampling knobs shared by RL methods.

Functions:

fastvideo.train.methods.rl.common.distributed_k_repeat_indices
distributed_k_repeat_indices(*, dataset_length: int, batch_size: int, repeats_per_prompt: int, world_size: int, rank: int, seed: int) -> KRepeatSample

Mirror DiffusionNFT's distributed K-repeat prompt sampler.

Adapted from DiffusionNFT's scripts/train_nft_sd3.py::DistributedKRepeatSampler.

Source code in fastvideo/train/methods/rl/common/prompt_sampling.py
def distributed_k_repeat_indices(
    *,
    dataset_length: int,
    batch_size: int,
    repeats_per_prompt: int,
    world_size: int,
    rank: int,
    seed: int,
) -> KRepeatSample:
    """Pick this rank's prompt rows for one distributed K-repeat batch.

    Mirrors ``DistributedKRepeatSampler`` from DiffusionNFT's
    ``scripts/train_nft_sd3.py``. A seeded CPU generator first draws
    ``world_size * batch_size // repeats_per_prompt`` distinct rows. Each row
    is repeated ``repeats_per_prompt`` times, and the repeated list is
    shuffled with the same generator. Rank ``r`` receives positions
    ``[r * batch_size, (r + 1) * batch_size)`` of that list. When every rank
    passes the same ``seed``, all ranks build the same global list, so their
    slices partition it.

    Args:
        dataset_length: Number of prompt rows available.
        batch_size: Samples per rank.
        repeats_per_prompt: Copies of each chosen row in the global batch.
        world_size: Number of ranks.
        rank: This process's rank, in ``[0, world_size)``.
        seed: Seed for the shared generator.

    Returns:
        ``KRepeatSample`` with ``batch_size`` local row indices and the number
        of unique rows in the global batch.

    Raises:
        ValueError: If a size is non-positive, ``rank`` is out of range,
            ``world_size * batch_size`` is not divisible by
            ``repeats_per_prompt``, or the dataset is too small.
    """
    n_rows, per_rank, k, n_ranks, my_rank = (
        int(dataset_length),
        int(batch_size),
        int(repeats_per_prompt),
        int(world_size),
        int(rank),
    )
    for label, value in (
        ("dataset_length", n_rows),
        ("batch_size", per_rank),
        ("repeats_per_prompt", k),
        ("world_size", n_ranks),
    ):
        if value <= 0:
            raise ValueError(f"{label} must be positive")
    if not 0 <= my_rank < n_ranks:
        raise ValueError(f"rank must be in [0, {n_ranks}), got {my_rank}")

    global_batch = n_ranks * per_rank
    if global_batch % k != 0:
        raise ValueError("world_size * batch_size must be divisible by repeats_per_prompt "
                         f"({n_ranks} * {per_rank} vs {k})")
    n_unique = global_batch // k
    if n_unique > n_rows:
        raise ValueError("K-repeat sampling needs at least as many rows as unique prompts "
                         f"per sampling batch ({n_rows} < {n_unique})")

    # A single seeded generator drives both permutations, in this order, so
    # the output is fully determined by the arguments.
    rng = torch.Generator()
    rng.manual_seed(int(seed))
    picked_rows = torch.randperm(n_rows, generator=rng)[:n_unique].tolist()
    expanded = [row for row in picked_rows for _ in range(k)]
    permutation = torch.randperm(len(expanded), generator=rng).tolist()

    lo = my_rank * per_rank
    hi = lo + per_rank
    return KRepeatSample(
        local_indices=[int(expanded[pos]) for pos in permutation[lo:hi]],
        unique_prompt_count=n_unique,
    )
fastvideo.train.methods.rl.common.media_to_video_array
media_to_video_array(media: Tensor) -> Any

Convert decoded media to a tracker video array.

Accepts [C, T, H, W] tensors. [C, H, W] tensors are treated as T=1 media. Output follows the existing tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.

Source code in fastvideo/train/methods/rl/common/validation.py
def media_to_video_array(media: torch.Tensor) -> Any:
    """Turn decoded media (a clip or a single image) into a tracker array.

    Args:
        media: ``[C, T, H, W]`` video or ``[C, H, W]`` image tensor. An image
            is promoted to a one-frame video. Values are clamped to
            ``[0, 1]`` before being scaled to bytes.

    Returns:
        A NumPy ``uint8`` array laid out ``[T, C, H, W]``, the tracker
        convention used elsewhere in FastVideo.

    Raises:
        ValueError: If ``media`` is neither 3-D nor 4-D.
    """
    frames = media.unsqueeze(1) if media.ndim == 3 else media
    if frames.ndim != 4:
        raise ValueError("media must have shape [C, T, H, W] or [C, H, W], "
                         f"got {tuple(frames.shape)}")
    # Quantize on the tensor's current device; only the uint8 result is
    # moved to host memory.
    unit = frames.detach().float().clamp(0, 1)
    quantized = (unit * 255).round().to(torch.uint8)
    # Swap the channel and time axes: [C, T, H, W] -> [T, C, H, W].
    time_major = quantized.permute(1, 0, 2, 3).contiguous()
    return time_major.cpu().numpy()
fastvideo.train.methods.rl.common.validation_shard_indices
validation_shard_indices(num_prompts: int, *, rank: int, world_size: int) -> list[tuple[int, bool]]

Return fixed validation prompt indices for one distributed rank.

Source code in fastvideo/train/methods/rl/common/validation.py
def validation_shard_indices(
    num_prompts: int,
    *,
    rank: int,
    world_size: int,
) -> list[tuple[int, bool]]:
    """Return fixed validation prompt indices for one distributed rank.

    Prompts are dealt round-robin: rank ``r`` takes global positions
    ``r, r + world_size, r + 2 * world_size, ...``. The global position list
    is padded so that every rank gets exactly
    ``ceil(num_prompts / world_size)`` items. Padding positions wrap back to
    real prompt indices via modulo and are flagged ``False``, presumably so
    that all ranks run the same number of items while padded results are
    discarded.

    Args:
        num_prompts: Number of validation prompts; values below 1 are
            treated as 1.
        rank: This process's rank, in ``[0, world_size)``.
        world_size: Number of ranks; values below 1 are treated as 1.

    Returns:
        ``(prompt_index, is_real)`` pairs for this rank, in dealing order.

    Raises:
        ValueError: If ``rank`` is outside ``[0, world_size)`` after
            ``world_size`` has been clamped to at least 1.
    """
    num_prompts = max(1, int(num_prompts))
    world_size = max(1, int(world_size))
    rank = int(rank)
    # A negative or too-large rank would otherwise silently produce an empty
    # list or wrapped indices wrongly marked as real.
    if rank < 0 or rank >= world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    # Exact integer ceil-division: float math.ceil(a / b) can round down for
    # very large counts and drop prompts.
    per_rank = -(-num_prompts // world_size)
    padded_total = per_rank * world_size
    return [(idx % num_prompts, idx < num_prompts) for idx in range(rank, padded_total, world_size)]

Modules

fastvideo.train.methods.rl.common.prompt_sampling

Prompt-row sampling helpers for online RL methods.

This module chooses and repeats dataset prompt rows across ranks for RL training batches. Here, "sampling" means selection, not generator sampling.

Classes
fastvideo.train.methods.rl.common.prompt_sampling.KRepeatSample dataclass
KRepeatSample(local_indices: list[int], unique_prompt_count: int)

Local prompt indices for one distributed K-repeat sampling batch.

Functions:
fastvideo.train.methods.rl.common.prompt_sampling.distributed_k_repeat_indices
distributed_k_repeat_indices(*, dataset_length: int, batch_size: int, repeats_per_prompt: int, world_size: int, rank: int, seed: int) -> KRepeatSample

Mirror DiffusionNFT's distributed K-repeat prompt sampler.

Adapted from DiffusionNFT's scripts/train_nft_sd3.py::DistributedKRepeatSampler.

Source code in fastvideo/train/methods/rl/common/prompt_sampling.py
def distributed_k_repeat_indices(
    *,
    dataset_length: int,
    batch_size: int,
    repeats_per_prompt: int,
    world_size: int,
    rank: int,
    seed: int,
) -> KRepeatSample:
    """Pick this rank's prompt rows for one distributed K-repeat batch.

    Mirrors ``DistributedKRepeatSampler`` from DiffusionNFT's
    ``scripts/train_nft_sd3.py``. A seeded CPU generator first draws
    ``world_size * batch_size // repeats_per_prompt`` distinct rows. Each row
    is repeated ``repeats_per_prompt`` times, and the repeated list is
    shuffled with the same generator. Rank ``r`` receives positions
    ``[r * batch_size, (r + 1) * batch_size)`` of that list. When every rank
    passes the same ``seed``, all ranks build the same global list, so their
    slices partition it.

    Args:
        dataset_length: Number of prompt rows available.
        batch_size: Samples per rank.
        repeats_per_prompt: Copies of each chosen row in the global batch.
        world_size: Number of ranks.
        rank: This process's rank, in ``[0, world_size)``.
        seed: Seed for the shared generator.

    Returns:
        ``KRepeatSample`` with ``batch_size`` local row indices and the number
        of unique rows in the global batch.

    Raises:
        ValueError: If a size is non-positive, ``rank`` is out of range,
            ``world_size * batch_size`` is not divisible by
            ``repeats_per_prompt``, or the dataset is too small.
    """
    n_rows, per_rank, k, n_ranks, my_rank = (
        int(dataset_length),
        int(batch_size),
        int(repeats_per_prompt),
        int(world_size),
        int(rank),
    )
    for label, value in (
        ("dataset_length", n_rows),
        ("batch_size", per_rank),
        ("repeats_per_prompt", k),
        ("world_size", n_ranks),
    ):
        if value <= 0:
            raise ValueError(f"{label} must be positive")
    if not 0 <= my_rank < n_ranks:
        raise ValueError(f"rank must be in [0, {n_ranks}), got {my_rank}")

    global_batch = n_ranks * per_rank
    if global_batch % k != 0:
        raise ValueError("world_size * batch_size must be divisible by repeats_per_prompt "
                         f"({n_ranks} * {per_rank} vs {k})")
    n_unique = global_batch // k
    if n_unique > n_rows:
        raise ValueError("K-repeat sampling needs at least as many rows as unique prompts "
                         f"per sampling batch ({n_rows} < {n_unique})")

    # A single seeded generator drives both permutations, in this order, so
    # the output is fully determined by the arguments.
    rng = torch.Generator()
    rng.manual_seed(int(seed))
    picked_rows = torch.randperm(n_rows, generator=rng)[:n_unique].tolist()
    expanded = [row for row in picked_rows for _ in range(k)]
    permutation = torch.randperm(len(expanded), generator=rng).tolist()

    lo = my_rank * per_rank
    hi = lo + per_rank
    return KRepeatSample(
        local_indices=[int(expanded[pos]) for pos in permutation[lo:hi]],
        unique_prompt_count=n_unique,
    )
fastvideo.train.methods.rl.common.sampling

Configurable diffusion samplers for RL training methods.

Classes
fastvideo.train.methods.rl.common.sampling.DiffusionSampler
DiffusionSampler(config: SamplingConfig)

Thin model/scheduler sampler used by RL methods.

This intentionally does not call FastVideo's full inference pipelines. RL training needs a reusable sampling primitive that works with ModelBase wrappers and scheduler math without binding a method to model-family pipeline classes such as WanDMDPipeline.

Source code in fastvideo/train/methods/rl/common/sampling.py
def __init__(self, config: SamplingConfig) -> None:
    """Remember the sampling knobs (steps, scheduler, trajectory, ...).

    Args:
        config: Sampling configuration, kept by reference as ``self.config``.
    """
    # Not copied: later changes to ``config`` are visible through self.config.
    sampling_config = config
    self.config = sampling_config
fastvideo.train.methods.rl.common.sampling.SamplingConfig dataclass
SamplingConfig(num_steps: int = 25, scheduler: SchedulerName = 'model_default', trajectory: TrajectoryName = 'ode', flow_shift: float | None = None, timesteps: list[float] | None = None, sigmas: list[float] | None = None)

YAML-backed sampling knobs shared by RL methods.

fastvideo.train.methods.rl.common.validation

Shared validation helpers for RL training methods.

Functions:
fastvideo.train.methods.rl.common.validation.media_to_video_array
media_to_video_array(media: Tensor) -> Any

Convert decoded media to a tracker video array.

Accepts [C, T, H, W] tensors. [C, H, W] tensors are treated as T=1 media. Output follows the existing tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.

Source code in fastvideo/train/methods/rl/common/validation.py
def media_to_video_array(media: torch.Tensor) -> Any:
    """Turn decoded media (a clip or a single image) into a tracker array.

    Args:
        media: ``[C, T, H, W]`` video or ``[C, H, W]`` image tensor. An image
            is promoted to a one-frame video. Values are clamped to
            ``[0, 1]`` before being scaled to bytes.

    Returns:
        A NumPy ``uint8`` array laid out ``[T, C, H, W]``, the tracker
        convention used elsewhere in FastVideo.

    Raises:
        ValueError: If ``media`` is neither 3-D nor 4-D.
    """
    frames = media.unsqueeze(1) if media.ndim == 3 else media
    if frames.ndim != 4:
        raise ValueError("media must have shape [C, T, H, W] or [C, H, W], "
                         f"got {tuple(frames.shape)}")
    # Quantize on the tensor's current device; only the uint8 result is
    # moved to host memory.
    unit = frames.detach().float().clamp(0, 1)
    quantized = (unit * 255).round().to(torch.uint8)
    # Swap the channel and time axes: [C, T, H, W] -> [T, C, H, W].
    time_major = quantized.permute(1, 0, 2, 3).contiguous()
    return time_major.cpu().numpy()
fastvideo.train.methods.rl.common.validation.validation_shard_indices
validation_shard_indices(num_prompts: int, *, rank: int, world_size: int) -> list[tuple[int, bool]]

Return fixed validation prompt indices for one distributed rank.

Source code in fastvideo/train/methods/rl/common/validation.py
def validation_shard_indices(
    num_prompts: int,
    *,
    rank: int,
    world_size: int,
) -> list[tuple[int, bool]]:
    """Return fixed validation prompt indices for one distributed rank.

    Prompts are dealt round-robin: rank ``r`` takes global positions
    ``r, r + world_size, r + 2 * world_size, ...``. The global position list
    is padded so that every rank gets exactly
    ``ceil(num_prompts / world_size)`` items. Padding positions wrap back to
    real prompt indices via modulo and are flagged ``False``, presumably so
    that all ranks run the same number of items while padded results are
    discarded.

    Args:
        num_prompts: Number of validation prompts; values below 1 are
            treated as 1.
        rank: This process's rank, in ``[0, world_size)``.
        world_size: Number of ranks; values below 1 are treated as 1.

    Returns:
        ``(prompt_index, is_real)`` pairs for this rank, in dealing order.

    Raises:
        ValueError: If ``rank`` is outside ``[0, world_size)`` after
            ``world_size`` has been clamped to at least 1.
    """
    num_prompts = max(1, int(num_prompts))
    world_size = max(1, int(world_size))
    rank = int(rank)
    # A negative or too-large rank would otherwise silently produce an empty
    # list or wrapped indices wrongly marked as real.
    if rank < 0 or rank >= world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    # Exact integer ceil-division: float math.ceil(a / b) can round down for
    # very large counts and drop prompts.
    per_rank = -(-num_prompts // world_size)
    padded_total = per_rank * world_size
    return [(idx % num_prompts, idx < num_prompts) for idx in range(rank, padded_total, world_size)]

fastvideo.train.methods.rl.diffusion_nft

DiffusionNFT multi-reward policy optimization method.

Classes

fastvideo.train.methods.rl.diffusion_nft.DiffusionNFTMethod
DiffusionNFTMethod(*, cfg: Any, role_models: dict[str, ModelBase])

Bases: TrainingMethod

DiffusionNFT-style RL for diffusion models.

This method owns the algorithm's sample-then-inner-train loop. One Trainer step corresponds to one DiffusionNFT outer epoch.

Source code in fastvideo/train/methods/rl/diffusion_nft.py
def __init__(
    self,
    *,
    cfg: Any,
    role_models: dict[str, ModelBase],
) -> None:
    """Validate roles, read the ``method`` config knobs, and build optimizers.

    Args:
        cfg: Training configuration, forwarded unchanged to
            ``TrainingMethod.__init__``.
        role_models: Role name -> model wrapper. Must contain at least the
            ``"old"`` and ``"reference"`` roles, and the student
            (``self.student``) must be trainable.

    Raises:
        ValueError: If the ``"old"`` or ``"reference"`` role is missing, the
            student is not trainable, ``method.adv_mode`` is not one of
            ``all`` / ``positive_only`` / ``negative_only`` / ``one_only`` /
            ``binary``, ``method.reward_fn`` is not a non-empty dict, or
            ``method.reward_fn`` names a reward other than ``pickscore`` /
            ``clipscore``.
    """
    # NOTE(review): self.student, self.method_config and self.training_config
    # are read below but never assigned here; presumably
    # TrainingMethod.__init__ sets them up - confirm.
    super().__init__(cfg=cfg, role_models=role_models)
    if "old" not in role_models:
        raise ValueError("DiffusionNFTMethod requires role 'old'")
    if "reference" not in role_models:
        raise ValueError("DiffusionNFTMethod requires role 'reference'")
    if not self.student._trainable:
        raise ValueError("DiffusionNFTMethod requires a trainable student")

    self.old = role_models["old"]
    self.reference = role_models["reference"]
    self.student.init_preprocessors(self.training_config)

    # Rollout sampling and validation sampling are driven by separately
    # parsed configs, each wrapped in its own DiffusionSampler.
    self._sampling_config = self._parse_sampling_config()
    self._sampler = DiffusionSampler(self._sampling_config)
    self._validation_config = RLValidationConfig.from_mapping(self.method_config.get("validation"))
    self._validation_sampling_config = self._parse_validation_sampling_config()
    self._validation_sampler = DiffusionSampler(self._validation_sampling_config)
    # None until filled; presumably populated lazily when validation first
    # runs - confirm. Entries are (int, bool, dict) tuples.
    self._validation_items: list[tuple[int, bool, dict[str, Any]]] | None = None
    self._sample_steps = int(self._sampling_config.num_steps)
    # Rollout batch size; falls back to data.train_batch_size (at least 1).
    self._sample_train_batch_size = self._read_int(
        "sample_train_batch_size",
        max(1, int(self.training_config.data.train_batch_size or 1)),
    )
    # Inner-training batch size; falls back to the rollout batch size.
    self._train_batch_size = self._read_int("train_batch_size", self._sample_train_batch_size)
    self._num_batches_per_epoch = self._read_int("num_batches_per_epoch", 48)
    self._num_inner_epochs = self._read_int("num_inner_epochs", 1)
    self._num_video_per_prompt = self._read_int("num_video_per_prompt", 24)
    self._adv_clip_max = self._read_float("adv_clip_max", 5.0)
    self._timestep_fraction = self._read_float("timestep_fraction", 0.99)
    self._kl_beta = self._read_float("kl_beta", 0.0001)
    # Read from the config key "beta" (not "nft_beta").
    self._nft_beta = self._read_float("beta", 0.1)
    self._max_grad_norm = self._read_float("max_grad_norm", 1.0)
    self._decay_type = self._read_int("decay_type", 1)
    # Normalized to stripped lowercase; validated against the allowed set below.
    self._adv_mode = str(self.method_config.get("adv_mode", "all") or "all").strip().lower()
    # NOTE(review): bool() on a string makes "false" truthy; this relies on the
    # config loader producing a real bool - confirm.
    self._terminal_progress = bool(self.method_config.get("terminal_progress", True))
    ema_config = self._parse_ema_config()
    self._ema_enabled = bool(ema_config["enabled"])
    self._ema_decay = float(ema_config["decay"])
    self._ema_update_after_step = int(ema_config["update_after_step"])
    self._validation_use_ema = bool(ema_config["validation"])
    # No EMA copy exists yet; presumably created later when EMA is enabled.
    self._student_ema: EMA_FSDP | None = None
    self._ema_update_count = 0
    # NOTE(review): presumably records hashes of prompts used for training
    # (e.g. to compare against validation prompts) - confirm.
    self._trained_prompt_hashes: set[int] = set()
    if self._adv_mode not in {"all", "positive_only", "negative_only", "one_only", "binary"}:
        raise ValueError("method.adv_mode must be one of "
                         "{all, positive_only, negative_only, one_only, binary}")

    # reward_fn maps reward name -> weight; weights are coerced to float.
    reward_fn = self.method_config.get("reward_fn", None)
    if not isinstance(reward_fn, dict) or not reward_fn:
        raise ValueError("method.reward_fn must be a non-empty mapping, "
                         "for example {pickscore: 1.0, clipscore: 1.0}")
    self._reward_fn_config = {str(k): float(v) for k, v in reward_fn.items()}
    unsupported = sorted(set(self._reward_fn_config) - {"pickscore", "clipscore"})
    if unsupported:
        raise ValueError(f"Unsupported DiffusionNFT reward(s): {unsupported}. "
                         "Only pickscore and clipscore are currently ported.")

    # Not constructed here (None); presumably built lazily on first use.
    self._reward_scorer: Any | None = None
    # Called last, after every config field above has been read.
    self._init_optimizer_and_scheduler()

Functions:

fastvideo.train.methods.rl.rewards

Reusable reward models for training methods.

Classes

fastvideo.train.methods.rl.rewards.ClipScoreScorer
ClipScoreScorer(*, device: device | str = 'cuda')

Bases: Module

CLIPScore reward, matching DiffusionNFT normalization.

Ported from DiffusionNFT's flow_grpo/clip_scorer.py.

Source code in fastvideo/train/methods/rl/rewards/frame_rewards.py
def __init__(
    self,
    *,
    device: torch.device | str = "cuda",
) -> None:
    """Load CLIP ViT-L/14 and derive a torchvision preprocessing pipeline.

    The HF ``CLIPProcessor``'s image-processor settings are translated into
    a resize -> center-crop -> normalize ``T.Compose`` stored on
    ``self.transform``. Each disabled step becomes ``nn.Identity``.

    Args:
        device: Device the CLIP model is moved to.

    Raises:
        ValueError: If an enabled resize/crop step has a ``size`` /
            ``crop_size`` entry that is neither an int nor a mapping with
            ``height``/``width`` or ``shortest_edge``.
    """
    super().__init__()
    import torch.nn as nn
    import torchvision.transforms as T
    from transformers import CLIPModel, CLIPProcessor

    def _size_arg(spec: Any) -> Any:
        # Translate an HF processor size entry into a torchvision size argument.
        if isinstance(spec, int):
            return (spec, spec)
        if isinstance(spec, Mapping):
            if "height" in spec and "width" in spec:
                return (spec["height"], spec["width"])
            if "shortest_edge" in spec:
                # Returned as a bare int; for T.Resize this means "match the
                # shorter edge".
                return spec["shortest_edge"]
        raise ValueError(f"Invalid processor size: {spec!r}")

    def _build_transform(image_processor: Any) -> torch.nn.Module:
        # Build the pipeline from the processor's serialized settings.
        settings = image_processor.to_dict()
        if settings.get("do_resize"):
            resize_step = T.Resize(_size_arg(settings.get("size")))
        else:
            resize_step = nn.Identity()
        if settings.get("do_center_crop"):
            crop_step = T.CenterCrop(_size_arg(settings.get("crop_size")))
        else:
            crop_step = nn.Identity()
        if settings.get("do_normalize"):
            norm_step = T.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        else:
            norm_step = nn.Identity()
        return T.Compose([resize_step, crop_step, norm_step])

    self.device = torch.device(device)
    clip_checkpoint = "openai/clip-vit-large-patch14"
    self.model = CLIPModel.from_pretrained(clip_checkpoint).to(self.device).eval()
    self.processor = CLIPProcessor.from_pretrained(clip_checkpoint)
    self.transform = _build_transform(self.processor.image_processor)
fastvideo.train.methods.rl.rewards.MultiRewardScorer
MultiRewardScorer(reward_weights: Mapping[str, float], *, scorers: Mapping[str, RewardScorer])

Weighted sum of reusable media reward scorers.

Mirrors DiffusionNFT's flow_grpo/rewards.py::multi_score behavior, while leaving frame selection to each concrete reward.

Source code in fastvideo/train/methods/rl/rewards/media.py
def __init__(
    self,
    reward_weights: Mapping[str, float],
    *,
    scorers: Mapping[str, RewardScorer],
) -> None:
    """Record per-reward weights and the scorers they will be applied to.

    Args:
        reward_weights: Reward name -> weight. Names are cast to ``str`` and
            weights to ``float``.
        scorers: Reward name -> scorer; copied into a plain ``dict``. It may
            contain scorers that carry no weight.

    Raises:
        ValueError: If ``reward_weights`` is empty, or names a reward that
            has no entry in ``scorers``.
    """
    weights: dict[str, float] = {}
    for name, weight in reward_weights.items():
        key = str(name)
        weights[key] = float(weight)
    self.reward_weights = weights
    if len(self.reward_weights) == 0:
        raise ValueError("reward_weights must contain at least one reward")

    self.scorers = dict(scorers)
    # Every weighted reward must have a scorer to compute it.
    missing = sorted(name for name in self.reward_weights if name not in self.scorers)
    if missing:
        raise ValueError(f"Unsupported reward(s): {missing}. "
                         f"Available rewards: {sorted(self.scorers)}")
fastvideo.train.methods.rl.rewards.PickScoreScorer
PickScoreScorer(*, device: device | str = 'cuda', dtype: dtype = float32)

Bases: Module

PickScore reward, matching DiffusionNFT normalization.

Ported from DiffusionNFT's flow_grpo/pickscore_scorer.py.

Source code in fastvideo/train/methods/rl/rewards/frame_rewards.py
def __init__(
    self,
    *,
    device: torch.device | str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Load the PickScore model and its processor.

    The processor and the scoring weights come from different checkpoints:
    the LAION CLIP-ViT-H-14 processor and ``yuvalkirstain/PickScore_v1``.

    Args:
        device: Device the model is moved to.
        dtype: Parameter dtype the model is cast to after it is moved.
    """
    super().__init__()
    from transformers import AutoModel, AutoProcessor

    self.device = torch.device(device)
    self.dtype = dtype
    self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    pretrained = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1")
    # Eval mode, then move to the device, then cast to the requested dtype.
    self.model = pretrained.eval().to(self.device).to(dtype=dtype)

Functions:

fastvideo.train.methods.rl.rewards.select_first_frame
select_first_frame(media: Tensor) -> Tensor

Return first-frame media as [B, C, H, W].

This is a helper for reward models that are intrinsically frame-based (for example PickScore and CLIPScore). Video-aware rewards should inspect the full [B, C, T, H, W] tensor themselves.

Source code in fastvideo/train/methods/rl/rewards/media.py
def select_first_frame(media: torch.Tensor) -> torch.Tensor:
    """Reduce a media batch to its first frame, shaped ``[B, C, H, W]``.

    Intended for rewards that score individual frames (such as PickScore and
    CLIPScore). Rewards that need temporal context should consume the full
    ``[B, C, T, H, W]`` tensor instead.

    Args:
        media: ``[B, C, T, H, W]`` video batch or ``[B, C, H, W]`` image batch.

    Returns:
        A view of frame 0 for 5-D input; the input tensor itself for 4-D input.

    Raises:
        TypeError: If ``media`` is not a tensor.
        ValueError: If ``media`` is neither 4-D nor 5-D.
    """
    if not torch.is_tensor(media):
        raise TypeError(f"media must be a torch.Tensor, got {type(media).__name__}")
    rank = media.ndim
    if rank == 4:
        return media
    if rank == 5:
        # Take index 0 along the time axis (dim 2); same view as media[:, :, 0].
        return media.select(2, 0)
    raise ValueError("media must have shape [B, C, H, W] or [B, C, T, H, W], "
                     f"got {tuple(media.shape)}")

Modules

fastvideo.train.methods.rl.rewards.frame_rewards

Frame-based reward scorers used by RL training methods.

Classes
fastvideo.train.methods.rl.rewards.frame_rewards.ClipScoreScorer
ClipScoreScorer(*, device: device | str = 'cuda')

Bases: Module

CLIPScore reward, matching DiffusionNFT normalization.

Ported from DiffusionNFT's flow_grpo/clip_scorer.py.

Source code in fastvideo/train/methods/rl/rewards/frame_rewards.py
def __init__(
    self,
    *,
    device: torch.device | str = "cuda",
) -> None:
    """Load CLIP ViT-L/14 and derive a torchvision preprocessing pipeline.

    The HF ``CLIPProcessor``'s image-processor settings are translated into
    a resize -> center-crop -> normalize ``T.Compose`` stored on
    ``self.transform``. Each disabled step becomes ``nn.Identity``.

    Args:
        device: Device the CLIP model is moved to.

    Raises:
        ValueError: If an enabled resize/crop step has a ``size`` /
            ``crop_size`` entry that is neither an int nor a mapping with
            ``height``/``width`` or ``shortest_edge``.
    """
    super().__init__()
    import torch.nn as nn
    import torchvision.transforms as T
    from transformers import CLIPModel, CLIPProcessor

    def _size_arg(spec: Any) -> Any:
        # Translate an HF processor size entry into a torchvision size argument.
        if isinstance(spec, int):
            return (spec, spec)
        if isinstance(spec, Mapping):
            if "height" in spec and "width" in spec:
                return (spec["height"], spec["width"])
            if "shortest_edge" in spec:
                # Returned as a bare int; for T.Resize this means "match the
                # shorter edge".
                return spec["shortest_edge"]
        raise ValueError(f"Invalid processor size: {spec!r}")

    def _build_transform(image_processor: Any) -> torch.nn.Module:
        # Build the pipeline from the processor's serialized settings.
        settings = image_processor.to_dict()
        if settings.get("do_resize"):
            resize_step = T.Resize(_size_arg(settings.get("size")))
        else:
            resize_step = nn.Identity()
        if settings.get("do_center_crop"):
            crop_step = T.CenterCrop(_size_arg(settings.get("crop_size")))
        else:
            crop_step = nn.Identity()
        if settings.get("do_normalize"):
            norm_step = T.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        else:
            norm_step = nn.Identity()
        return T.Compose([resize_step, crop_step, norm_step])

    self.device = torch.device(device)
    clip_checkpoint = "openai/clip-vit-large-patch14"
    self.model = CLIPModel.from_pretrained(clip_checkpoint).to(self.device).eval()
    self.processor = CLIPProcessor.from_pretrained(clip_checkpoint)
    self.transform = _build_transform(self.processor.image_processor)
fastvideo.train.methods.rl.rewards.frame_rewards.PickScoreScorer
PickScoreScorer(*, device: device | str = 'cuda', dtype: dtype = float32)

Bases: Module

PickScore reward, matching DiffusionNFT normalization.

Ported from DiffusionNFT's flow_grpo/pickscore_scorer.py.

Source code in fastvideo/train/methods/rl/rewards/frame_rewards.py
def __init__(
    self,
    *,
    device: torch.device | str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Load the PickScore model and its processor.

    The processor and the scoring weights come from different checkpoints:
    the LAION CLIP-ViT-H-14 processor and ``yuvalkirstain/PickScore_v1``.

    Args:
        device: Device the model is moved to.
        dtype: Parameter dtype the model is cast to after it is moved.
    """
    super().__init__()
    from transformers import AutoModel, AutoProcessor

    self.device = torch.device(device)
    self.dtype = dtype
    self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    pretrained = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1")
    # Eval mode, then move to the device, then cast to the requested dtype.
    self.model = pretrained.eval().to(self.device).to(dtype=dtype)
Functions:
fastvideo.train.methods.rl.rewards.media

Generic media reward composition utilities.

Classes
fastvideo.train.methods.rl.rewards.media.MultiRewardScorer
MultiRewardScorer(reward_weights: Mapping[str, float], *, scorers: Mapping[str, RewardScorer])

Weighted sum of reusable media reward scorers.

Mirrors DiffusionNFT's flow_grpo/rewards.py::multi_score behavior, while leaving frame selection to each concrete reward.

Source code in fastvideo/train/methods/rl/rewards/media.py
def __init__(
    self,
    reward_weights: Mapping[str, float],
    *,
    scorers: Mapping[str, RewardScorer],
) -> None:
    """Record per-reward weights and the scorers they will be applied to.

    Args:
        reward_weights: Reward name -> weight. Names are cast to ``str`` and
            weights to ``float``.
        scorers: Reward name -> scorer; copied into a plain ``dict``. It may
            contain scorers that carry no weight.

    Raises:
        ValueError: If ``reward_weights`` is empty, or names a reward that
            has no entry in ``scorers``.
    """
    weights: dict[str, float] = {}
    for name, weight in reward_weights.items():
        key = str(name)
        weights[key] = float(weight)
    self.reward_weights = weights
    if len(self.reward_weights) == 0:
        raise ValueError("reward_weights must contain at least one reward")

    self.scorers = dict(scorers)
    # Every weighted reward must have a scorer to compute it.
    missing = sorted(name for name in self.reward_weights if name not in self.scorers)
    if missing:
        raise ValueError(f"Unsupported reward(s): {missing}. "
                         f"Available rewards: {sorted(self.scorers)}")
Functions:
fastvideo.train.methods.rl.rewards.media.select_first_frame
select_first_frame(media: Tensor) -> Tensor

Return first-frame media as [B, C, H, W].

This is a helper for reward models that are intrinsically frame-based (for example PickScore and CLIPScore). Video-aware rewards should inspect the full [B, C, T, H, W] tensor themselves.

Source code in fastvideo/train/methods/rl/rewards/media.py
def select_first_frame(media: torch.Tensor) -> torch.Tensor:
    """Reduce a media batch to its first frame, shaped ``[B, C, H, W]``.

    Intended for rewards that score individual frames (such as PickScore and
    CLIPScore). Rewards that need temporal context should consume the full
    ``[B, C, T, H, W]`` tensor instead.

    Args:
        media: ``[B, C, T, H, W]`` video batch or ``[B, C, H, W]`` image batch.

    Returns:
        A view of frame 0 for 5-D input; the input tensor itself for 4-D input.

    Raises:
        TypeError: If ``media`` is not a tensor.
        ValueError: If ``media`` is neither 4-D nor 5-D.
    """
    if not torch.is_tensor(media):
        raise TypeError(f"media must be a torch.Tensor, got {type(media).__name__}")
    rank = media.ndim
    if rank == 4:
        return media
    if rank == 5:
        # Take index 0 along the time axis (dim 2); same view as media[:, :, 0].
        return media.select(2, 0)
    raise ValueError("media must have shape [B, C, H, W] or [B, C, T, H, W], "
                     f"got {tuple(media.shape)}")