Skip to content

common

Reusable RL training primitives.

Classes

fastvideo.train.methods.rl.common.DiffusionSampler

DiffusionSampler(config: SamplingConfig)

Thin model/scheduler sampler used by RL methods.

This intentionally does not call FastVideo's full inference pipelines. RL training needs a reusable sampling primitive that works with ModelBase wrappers and scheduler math without binding a method to model-family pipeline classes such as WanDMDPipeline.

Source code in fastvideo/train/methods/rl/common/sampling.py
def __init__(self, config: SamplingConfig) -> None:
    """Keep a reference to the sampling configuration on the instance."""
    # The config object is stored as-is; no copy or validation happens here.
    self.config = config

fastvideo.train.methods.rl.common.KRepeatSample dataclass

KRepeatSample(local_indices: list[int], unique_prompt_count: int)

Local prompt indices for one distributed K-repeat sampling batch.

fastvideo.train.methods.rl.common.SamplingConfig dataclass

SamplingConfig(num_steps: int = 25, scheduler: SchedulerName = 'model_default', trajectory: TrajectoryName = 'ode', flow_shift: float | None = None, timesteps: list[float] | None = None, sigmas: list[float] | None = None)

YAML-backed sampling knobs shared by RL methods.

Functions:

fastvideo.train.methods.rl.common.distributed_k_repeat_indices

distributed_k_repeat_indices(*, dataset_length: int, batch_size: int, repeats_per_prompt: int, world_size: int, rank: int, seed: int) -> KRepeatSample

Mirror DiffusionNFT's distributed K-repeat prompt sampler.

Adapted from DiffusionNFT's scripts/train_nft_sd3.py::DistributedKRepeatSampler.

Source code in fastvideo/train/methods/rl/common/prompt_sampling.py
def distributed_k_repeat_indices(
    *,
    dataset_length: int,
    batch_size: int,
    repeats_per_prompt: int,
    world_size: int,
    rank: int,
    seed: int,
) -> KRepeatSample:
    """Pick this rank's slice of a seeded, globally shuffled K-repeat batch.

    A permutation of ``range(dataset_length)`` is drawn from a generator
    seeded with ``seed``; its first ``world_size * batch_size //
    repeats_per_prompt`` entries are the unique prompt rows. Each row is
    repeated ``repeats_per_prompt`` times, the repeated list is shuffled with
    the same generator, and the contiguous window
    ``[rank * batch_size, (rank + 1) * batch_size)`` is returned.

    Adapted from DiffusionNFT's
    ``scripts/train_nft_sd3.py::DistributedKRepeatSampler``.

    Raises:
        ValueError: if any size argument is not positive, ``rank`` is outside
            ``[0, world_size)``, the global batch is not divisible by
            ``repeats_per_prompt``, or the dataset has fewer rows than the
            number of unique prompts required.
    """
    dataset_length, batch_size, repeats_per_prompt, world_size, rank = (
        int(value) for value in (dataset_length, batch_size, repeats_per_prompt, world_size, rank))

    positivity_checks = (
        (dataset_length, "dataset_length must be positive"),
        (batch_size, "batch_size must be positive"),
        (repeats_per_prompt, "repeats_per_prompt must be positive"),
        (world_size, "world_size must be positive"),
    )
    for value, message in positivity_checks:
        if value <= 0:
            raise ValueError(message)
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")

    # The global batch must split evenly into groups of identical prompts.
    unique_prompt_count, leftover = divmod(world_size * batch_size, repeats_per_prompt)
    if leftover != 0:
        raise ValueError("world_size * batch_size must be divisible by repeats_per_prompt "
                         f"({world_size} * {batch_size} vs {repeats_per_prompt})")
    if dataset_length < unique_prompt_count:
        raise ValueError("K-repeat sampling needs at least as many rows as unique prompts "
                         f"per sampling batch ({dataset_length} < {unique_prompt_count})")

    # Both permutations come from one generator seeded only by `seed`, so the
    # global order does not depend on `rank`.
    rng = torch.Generator()
    rng.manual_seed(int(seed))
    chosen_rows = torch.randperm(dataset_length, generator=rng)[:unique_prompt_count].tolist()
    repeated_rows: list[int] = []
    for row in chosen_rows:
        repeated_rows.extend([row] * repeats_per_prompt)
    global_order = torch.randperm(len(repeated_rows), generator=rng).tolist()

    # Only this rank's window of the shuffled order is materialized.
    offset = rank * batch_size
    local_rows = [int(repeated_rows[pos]) for pos in global_order[offset:offset + batch_size]]
    return KRepeatSample(local_indices=local_rows, unique_prompt_count=unique_prompt_count)

fastvideo.train.methods.rl.common.media_to_video_array

media_to_video_array(media: Tensor) -> Any

Convert decoded media to a tracker video array.

Accepts [C, T, H, W] tensors. [C, H, W] tensors are treated as T=1 media. Output follows the existing tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.

Source code in fastvideo/train/methods/rl/common/validation.py
def media_to_video_array(media: torch.Tensor) -> Any:
    """Turn a decoded media tensor into a uint8 array for the tracker.

    ``[C, T, H, W]`` input is used as-is; ``[C, H, W]`` input gets a
    length-1 axis inserted at position 1 (``T=1``). Values are clamped to
    ``[0, 1]``, scaled by 255, rounded and cast to ``uint8``. The first two
    axes are then swapped, giving the ``[T, C, H, W]`` layout that follows
    the existing tracker convention used elsewhere in FastVideo, and the
    result is returned as a NumPy array on the CPU.

    Raises:
        ValueError: if ``media`` is neither 3- nor 4-dimensional.
    """
    rank = media.ndim
    if rank == 3:
        # Still image: add the missing time axis.
        media = media.unsqueeze(1)
    elif rank != 4:
        raise ValueError("media must have shape [C, T, H, W] or [C, H, W], "
                         f"got {tuple(media.shape)}")

    unit_range = media.detach().float().clamp(0, 1)
    pixels = (unit_range * 255).round().to(torch.uint8)
    frames_first = pixels.permute(1, 0, 2, 3).contiguous()
    return frames_first.cpu().numpy()

fastvideo.train.methods.rl.common.validation_shard_indices

validation_shard_indices(num_prompts: int, *, rank: int, world_size: int) -> list[tuple[int, bool]]

Return fixed validation prompt indices for one distributed rank.

Source code in fastvideo/train/methods/rl/common/validation.py
def validation_shard_indices(
    num_prompts: int,
    *,
    rank: int,
    world_size: int,
) -> list[tuple[int, bool]]:
    """List the validation prompt slots handled by one distributed rank.

    Slots ``rank, rank + world_size, ...`` are taken up to a total that is
    rounded up to a multiple of ``world_size``. Each entry is
    ``(prompt_index, is_real)``: slots beyond ``num_prompts`` wrap around via
    modulo and carry ``is_real=False``. ``num_prompts`` and ``world_size``
    are clamped to at least 1; ``rank`` is used as given.
    """
    prompt_total = max(1, int(num_prompts))
    rank_total = max(1, int(world_size))
    slots_per_rank = int(math.ceil(prompt_total / rank_total))

    shard = []
    for slot in range(rank, slots_per_rank * rank_total, rank_total):
        is_real = slot < prompt_total
        shard.append((slot % prompt_total, is_real))
    return shard

Modules

fastvideo.train.methods.rl.common.prompt_sampling

Prompt-row sampling helpers for online RL methods.

This module chooses and repeats dataset prompt rows across ranks for RL training batches. Here, "sampling" means selecting prompt rows, not sampling outputs from the generative model.

Classes

fastvideo.train.methods.rl.common.prompt_sampling.KRepeatSample dataclass
KRepeatSample(local_indices: list[int], unique_prompt_count: int)

Local prompt indices for one distributed K-repeat sampling batch.

Functions:

fastvideo.train.methods.rl.common.prompt_sampling.distributed_k_repeat_indices
distributed_k_repeat_indices(*, dataset_length: int, batch_size: int, repeats_per_prompt: int, world_size: int, rank: int, seed: int) -> KRepeatSample

Mirror DiffusionNFT's distributed K-repeat prompt sampler.

Adapted from DiffusionNFT's scripts/train_nft_sd3.py::DistributedKRepeatSampler.

Source code in fastvideo/train/methods/rl/common/prompt_sampling.py
def distributed_k_repeat_indices(
    *,
    dataset_length: int,
    batch_size: int,
    repeats_per_prompt: int,
    world_size: int,
    rank: int,
    seed: int,
) -> KRepeatSample:
    """Pick this rank's slice of a seeded, globally shuffled K-repeat batch.

    A permutation of ``range(dataset_length)`` is drawn from a generator
    seeded with ``seed``; its first ``world_size * batch_size //
    repeats_per_prompt`` entries are the unique prompt rows. Each row is
    repeated ``repeats_per_prompt`` times, the repeated list is shuffled with
    the same generator, and the contiguous window
    ``[rank * batch_size, (rank + 1) * batch_size)`` is returned.

    Adapted from DiffusionNFT's
    ``scripts/train_nft_sd3.py::DistributedKRepeatSampler``.

    Raises:
        ValueError: if any size argument is not positive, ``rank`` is outside
            ``[0, world_size)``, the global batch is not divisible by
            ``repeats_per_prompt``, or the dataset has fewer rows than the
            number of unique prompts required.
    """
    dataset_length, batch_size, repeats_per_prompt, world_size, rank = (
        int(value) for value in (dataset_length, batch_size, repeats_per_prompt, world_size, rank))

    positivity_checks = (
        (dataset_length, "dataset_length must be positive"),
        (batch_size, "batch_size must be positive"),
        (repeats_per_prompt, "repeats_per_prompt must be positive"),
        (world_size, "world_size must be positive"),
    )
    for value, message in positivity_checks:
        if value <= 0:
            raise ValueError(message)
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")

    # The global batch must split evenly into groups of identical prompts.
    unique_prompt_count, leftover = divmod(world_size * batch_size, repeats_per_prompt)
    if leftover != 0:
        raise ValueError("world_size * batch_size must be divisible by repeats_per_prompt "
                         f"({world_size} * {batch_size} vs {repeats_per_prompt})")
    if dataset_length < unique_prompt_count:
        raise ValueError("K-repeat sampling needs at least as many rows as unique prompts "
                         f"per sampling batch ({dataset_length} < {unique_prompt_count})")

    # Both permutations come from one generator seeded only by `seed`, so the
    # global order does not depend on `rank`.
    rng = torch.Generator()
    rng.manual_seed(int(seed))
    chosen_rows = torch.randperm(dataset_length, generator=rng)[:unique_prompt_count].tolist()
    repeated_rows: list[int] = []
    for row in chosen_rows:
        repeated_rows.extend([row] * repeats_per_prompt)
    global_order = torch.randperm(len(repeated_rows), generator=rng).tolist()

    # Only this rank's window of the shuffled order is materialized.
    offset = rank * batch_size
    local_rows = [int(repeated_rows[pos]) for pos in global_order[offset:offset + batch_size]]
    return KRepeatSample(local_indices=local_rows, unique_prompt_count=unique_prompt_count)

fastvideo.train.methods.rl.common.sampling

Configurable diffusion samplers for RL training methods.

Classes

fastvideo.train.methods.rl.common.sampling.DiffusionSampler
DiffusionSampler(config: SamplingConfig)

Thin model/scheduler sampler used by RL methods.

This intentionally does not call FastVideo's full inference pipelines. RL training needs a reusable sampling primitive that works with ModelBase wrappers and scheduler math without binding a method to model-family pipeline classes such as WanDMDPipeline.

Source code in fastvideo/train/methods/rl/common/sampling.py
def __init__(self, config: SamplingConfig) -> None:
    """Keep a reference to the sampling configuration on the instance."""
    # The config object is stored as-is; no copy or validation happens here.
    self.config = config
fastvideo.train.methods.rl.common.sampling.SamplingConfig dataclass
SamplingConfig(num_steps: int = 25, scheduler: SchedulerName = 'model_default', trajectory: TrajectoryName = 'ode', flow_shift: float | None = None, timesteps: list[float] | None = None, sigmas: list[float] | None = None)

YAML-backed sampling knobs shared by RL methods.

fastvideo.train.methods.rl.common.validation

Shared validation helpers for RL training methods.

Functions:

fastvideo.train.methods.rl.common.validation.media_to_video_array
media_to_video_array(media: Tensor) -> Any

Convert decoded media to a tracker video array.

Accepts [C, T, H, W] tensors. [C, H, W] tensors are treated as T=1 media. Output follows the existing tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.

Source code in fastvideo/train/methods/rl/common/validation.py
def media_to_video_array(media: torch.Tensor) -> Any:
    """Turn a decoded media tensor into a uint8 array for the tracker.

    ``[C, T, H, W]`` input is used as-is; ``[C, H, W]`` input gets a
    length-1 axis inserted at position 1 (``T=1``). Values are clamped to
    ``[0, 1]``, scaled by 255, rounded and cast to ``uint8``. The first two
    axes are then swapped, giving the ``[T, C, H, W]`` layout that follows
    the existing tracker convention used elsewhere in FastVideo, and the
    result is returned as a NumPy array on the CPU.

    Raises:
        ValueError: if ``media`` is neither 3- nor 4-dimensional.
    """
    rank = media.ndim
    if rank == 3:
        # Still image: add the missing time axis.
        media = media.unsqueeze(1)
    elif rank != 4:
        raise ValueError("media must have shape [C, T, H, W] or [C, H, W], "
                         f"got {tuple(media.shape)}")

    unit_range = media.detach().float().clamp(0, 1)
    pixels = (unit_range * 255).round().to(torch.uint8)
    frames_first = pixels.permute(1, 0, 2, 3).contiguous()
    return frames_first.cpu().numpy()
fastvideo.train.methods.rl.common.validation.validation_shard_indices
validation_shard_indices(num_prompts: int, *, rank: int, world_size: int) -> list[tuple[int, bool]]

Return fixed validation prompt indices for one distributed rank.

Source code in fastvideo/train/methods/rl/common/validation.py
def validation_shard_indices(
    num_prompts: int,
    *,
    rank: int,
    world_size: int,
) -> list[tuple[int, bool]]:
    """List the validation prompt slots handled by one distributed rank.

    Slots ``rank, rank + world_size, ...`` are taken up to a total that is
    rounded up to a multiple of ``world_size``. Each entry is
    ``(prompt_index, is_real)``: slots beyond ``num_prompts`` wrap around via
    modulo and carry ``is_real=False``. ``num_prompts`` and ``world_size``
    are clamped to at least 1; ``rank`` is used as given.
    """
    prompt_total = max(1, int(num_prompts))
    rank_total = max(1, int(world_size))
    slots_per_rank = int(math.ceil(prompt_total / rank_total))

    shard = []
    for slot in range(rank, slots_per_rank * rank_total, rank_total):
        is_real = slot < prompt_total
        shard.append((slot % prompt_total, is_real))
    return shard