Skip to content

frame_rewards

Frame-based reward scorers used by RL training methods.

Classes

fastvideo.train.methods.rl.rewards.frame_rewards.ClipScoreScorer

ClipScoreScorer(*, device: device | str = 'cuda')

Bases: Module

CLIPScore reward, matching DiffusionNFT normalization.

Ported from DiffusionNFT's flow_grpo/clip_scorer.py.

Source code in fastvideo/train/methods/rl/rewards/frame_rewards.py
def __init__(
    self,
    *,
    device: torch.device | str = "cuda",
    model_path: str = "openai/clip-vit-large-patch14",
) -> None:
    """Load the CLIP model/processor and build the frame preprocessing transform.

    Args:
        device: Device the CLIP model is moved to; stored as ``self.device``.
        model_path: Hugging Face Hub id or local directory of the CLIP
            checkpoint. The same path is used for both the model and the
            processor so the two cannot drift apart. Defaults to
            ``openai/clip-vit-large-patch14``.

    Raises:
        ValueError: If the processor's ``size`` / ``crop_size`` config has a
            format ``get_size`` does not recognize.
    """
    super().__init__()
    import torchvision.transforms as T
    from transformers import CLIPModel, CLIPProcessor

    def get_size(size: Any) -> Any:
        """Translate a HF image-processor size config into a torchvision size."""
        if isinstance(size, int):
            return (size, size)
        if isinstance(size, Mapping) and "height" in size and "width" in size:
            return (size["height"], size["width"])
        if isinstance(size, Mapping) and "shortest_edge" in size:
            # A bare int makes T.Resize / T.CenterCrop treat it per torchvision
            # semantics (Resize: scale the shorter edge to this length).
            return size["shortest_edge"]
        raise ValueError(f"Invalid processor size: {size!r}")

    def get_frame_transform(processor: Any) -> T.Compose:
        """Rebuild the processor's resize / center-crop / normalize steps as tensor transforms."""
        config = processor.to_dict()
        # NOTE(review): T.Resize uses torchvision's default interpolation; the
        # processor config's own ``resample`` field is not consulted. No rescale
        # step is applied either, so frames are presumably already float tensors
        # in [0, 1] — confirm against callers before changing.
        resize = T.Resize(get_size(config.get("size"))) if config.get("do_resize") else torch.nn.Identity()
        crop = (T.CenterCrop(get_size(config.get("crop_size")))
                if config.get("do_center_crop") else torch.nn.Identity())
        normalize = (T.Normalize(mean=processor.image_mean, std=processor.image_std)
                     if config.get("do_normalize") else torch.nn.Identity())
        return T.Compose([resize, crop, normalize])

    self.device = torch.device(device)
    self.model = CLIPModel.from_pretrained(model_path).to(self.device).eval()
    self.processor = CLIPProcessor.from_pretrained(model_path)
    self.transform = get_frame_transform(self.processor.image_processor)

fastvideo.train.methods.rl.rewards.frame_rewards.PickScoreScorer

PickScoreScorer(*, device: device | str = 'cuda', dtype: dtype = float32)

Bases: Module

PickScore reward, matching DiffusionNFT normalization.

Ported from DiffusionNFT's flow_grpo/pickscore_scorer.py.

Source code in fastvideo/train/methods/rl/rewards/frame_rewards.py
def __init__(
    self,
    *,
    device: torch.device | str = "cuda",
    dtype: torch.dtype = torch.float32,
    processor_path: str = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    model_path: str = "yuvalkirstain/PickScore_v1",
) -> None:
    """Load the PickScore model and its CLIP processor.

    Args:
        device: Device the model is moved to; stored as ``self.device``.
        dtype: Dtype the model's floating-point weights are cast to; stored
            as ``self.dtype``.
        processor_path: Hugging Face Hub id or local directory of the
            processor. Defaults to ``laion/CLIP-ViT-H-14-laion2B-s32B-b79K``.
        model_path: Hugging Face Hub id or local directory of the PickScore
            weights. Defaults to ``yuvalkirstain/PickScore_v1``.
    """
    super().__init__()
    from transformers import AutoModel, AutoProcessor

    self.device = torch.device(device)
    self.dtype = dtype
    self.processor = AutoProcessor.from_pretrained(processor_path)
    # Move and cast in a single .to() call so each parameter is converted once,
    # instead of first copying the weights to ``self.device`` in their loaded
    # dtype and casting afterwards (which raises peak device memory).
    self.model = AutoModel.from_pretrained(model_path).eval().to(device=self.device, dtype=dtype)

Functions: