Skip to content

validation

Shared validation helpers for RL training methods.

Functions:

fastvideo.train.methods.rl.common.validation.media_to_video_array

media_to_video_array(media: Tensor) -> Any

Convert decoded media to a tracker video array.

Accepts [C, T, H, W] tensors. [C, H, W] tensors are treated as T=1 media. Output follows the existing tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.

Source code in fastvideo/train/methods/rl/common/validation.py
def media_to_video_array(media: torch.Tensor) -> Any:
    """Turn a decoded media tensor into a uint8 video array for trackers.

    The input is expected in channel-first ``[C, T, H, W]`` layout. A
    three-dimensional ``[C, H, W]`` input is promoted to a single-frame
    clip by inserting a ``T=1`` axis after the channel axis.

    Values are clamped to ``[0, 1]``, scaled by 255, rounded and cast to
    ``uint8``. The result is returned as a NumPy array in ``[T, C, H, W]``
    order.

    Raises:
        ValueError: If ``media`` is neither 3- nor 4-dimensional.
    """
    # Promote still images to one-frame videos so a single code path follows.
    clip = media.unsqueeze(1) if media.ndim == 3 else media
    if clip.ndim != 4:
        raise ValueError("media must have shape [C, T, H, W] or [C, H, W], "
                         f"got {tuple(clip.shape)}")

    # Quantize on the tensor's current device; only the uint8 result is
    # moved to the CPU.
    unit_range = clip.detach().float().clamp(0, 1)
    quantized = (unit_range * 255).round().to(torch.uint8)

    # Swap the channel and time axes: [C, T, H, W] -> [T, C, H, W].
    time_major = quantized.permute(1, 0, 2, 3).contiguous()
    return time_major.cpu().numpy()

fastvideo.train.methods.rl.common.validation.validation_shard_indices

validation_shard_indices(num_prompts: int, *, rank: int, world_size: int) -> list[tuple[int, bool]]

Return fixed validation prompt indices for one distributed rank.

Source code in fastvideo/train/methods/rl/common/validation.py
def validation_shard_indices(
    num_prompts: int,
    *,
    rank: int,
    world_size: int,
) -> list[tuple[int, bool]]:
    """Compute the validation prompt slots assigned to a single rank.

    Prompts are dealt out in a strided fashion: the rank takes slots
    ``rank, rank + world_size, rank + 2 * world_size, ...``. The slot count
    is rounded up to a multiple of ``world_size`` so every rank receives
    the same number of entries.

    Both ``num_prompts`` and ``world_size`` are coerced to ``int`` and
    floored at 1.

    Returns:
        A list of ``(prompt_index, is_real)`` pairs. ``prompt_index`` is the
        slot wrapped modulo the prompt count; ``is_real`` is ``False`` for
        padding slots that lie past the last prompt.
    """
    prompt_count = max(1, int(num_prompts))
    shard_count = max(1, int(world_size))

    # Round up so the padded slot range divides evenly among the shards.
    slots_per_shard = math.ceil(prompt_count / shard_count)
    slot_limit = slots_per_shard * shard_count

    assigned_slots = range(rank, slot_limit, shard_count)
    return [(slot % prompt_count, slot < prompt_count) for slot in assigned_slots]