common
¶
Reusable RL training primitives.
Classes¶
fastvideo.train.methods.rl.common.DiffusionSampler
¶
DiffusionSampler(config: SamplingConfig)
Thin model/scheduler sampler used by RL methods.
This class intentionally does not call FastVideo's full inference pipelines.
RL training needs a reusable sampling primitive that works with ModelBase wrappers and scheduler math without tying an RL method to model-family pipeline classes such as WanDMDPipeline.
Source code in fastvideo/train/methods/rl/common/sampling.py
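To illustrate what a pipeline-free sampling primitive can look like, here is a minimal sketch of a deterministic Euler loop over a sigma schedule. It assumes a flow-matching model that predicts a velocity, and a commonly used flow-shift warp of the sigmas. The names `shifted_sigmas`, `euler_ode_sample`, and the toy `velocity` callable are hypothetical and are not FastVideo APIs; the actual DiffusionSampler may work differently.

```python
from typing import Callable, List

Vector = List[float]


def shifted_sigmas(num_steps: int, flow_shift: float = 1.0) -> List[float]:
    """Linear sigmas from 1.0 down to 0.0, warped by a flow shift (assumed form)."""
    raw = [1.0 - i / num_steps for i in range(num_steps + 1)]
    return [flow_shift * s / (1.0 + (flow_shift - 1.0) * s) for s in raw]


def euler_ode_sample(
    velocity: Callable[[Vector, float], Vector],
    x: Vector,
    sigmas: List[float],
) -> Vector:
    """Integrate dx/dsigma = velocity(x, sigma) with explicit Euler steps."""
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        v = velocity(x, s_cur)
        # s_next < s_cur, so each step moves x from noise toward data.
        x = [xi + (s_next - s_cur) * vi for xi, vi in zip(x, v)]
    return x


# Toy stand-in for a model: a constant velocity pointing from target to noise.
noise = [0.5, -1.0, 2.0]
target = [1.0, 1.0, 1.0]


def velocity(x: Vector, sigma: float) -> Vector:
    return [n - t for n, t in zip(noise, target)]


sample = euler_ode_sample(velocity, noise, shifted_sigmas(num_steps=8, flow_shift=3.0))
```

Because the loop depends only on a callable and a list of sigmas, it stays independent of any model-family pipeline class, which is the design goal described above.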
fastvideo.train.methods.rl.common.KRepeatSample
dataclass
¶
Local prompt indices for one distributed K-repeat sampling batch.
fastvideo.train.methods.rl.common.SamplingConfig
dataclass
¶
SamplingConfig(num_steps: int = 25, scheduler: SchedulerName = 'model_default', trajectory: TrajectoryName = 'ode', flow_shift: float | None = None, timesteps: list[float] | None = None, sigmas: list[float] | None = None)
YAML-backed sampling knobs shared by RL methods.
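As a rough illustration of how YAML-backed knobs can map onto such a dataclass, the sketch below uses a stand-in class with the same field names and defaults as the signature above. `SamplingConfigSketch` and `config_from_mapping` are hypothetical names, and `SchedulerName` and `TrajectoryName` are replaced by plain `str` because their definitions are not shown here. The mapping stands in for an already-parsed YAML section.

```python
from dataclasses import dataclass, fields
from typing import Any, List, Mapping, Optional


@dataclass
class SamplingConfigSketch:
    """Stand-in mirroring the documented SamplingConfig fields (not the real class)."""

    num_steps: int = 25
    scheduler: str = "model_default"
    trajectory: str = "ode"
    flow_shift: Optional[float] = None
    timesteps: Optional[List[float]] = None
    sigmas: Optional[List[float]] = None


def config_from_mapping(raw: Mapping[str, Any]) -> SamplingConfigSketch:
    """Build a config from a parsed YAML mapping, rejecting unknown keys."""
    known = {f.name for f in fields(SamplingConfigSketch)}
    unknown = set(raw) - known
    if unknown:
        raise ValueError(f"unknown sampling keys: {sorted(unknown)}")
    return SamplingConfigSketch(**raw)


# Keys that are omitted fall back to the dataclass defaults.
cfg = config_from_mapping({"num_steps": 10, "flow_shift": 3.0})
```

Rejecting unknown keys is a design choice in this sketch so that a typo in a config file fails loudly instead of being silently ignored.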
Functions:¶
fastvideo.train.methods.rl.common.distributed_k_repeat_indices
¶
distributed_k_repeat_indices(*, dataset_length: int, batch_size: int, repeats_per_prompt: int, world_size: int, rank: int, seed: int) -> KRepeatSample
Mirror DiffusionNFT's distributed K-repeat prompt sampler.
Adapted from DiffusionNFT's scripts/train_nft_sd3.py::DistributedKRepeatSampler.
Source code in fastvideo/train/methods/rl/common/prompt_sampling.py
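The idea behind a distributed K-repeat sampler can be sketched as follows: every rank builds the same global plan from a shared seed, each selected prompt appears `repeats_per_prompt` times in that plan, and each rank takes its own contiguous slice. This is a minimal sketch under those assumptions, not the FastVideo implementation. The function name `k_repeat_indices_sketch` is hypothetical, and it returns a plain list because the fields of KRepeatSample are not shown here.

```python
import random
from typing import List


def k_repeat_indices_sketch(
    *,
    dataset_length: int,
    batch_size: int,
    repeats_per_prompt: int,
    world_size: int,
    rank: int,
    seed: int,
) -> List[int]:
    """Return this rank's slice of a globally shuffled K-repeat index plan."""
    total = world_size * batch_size
    if total % repeats_per_prompt != 0:
        raise ValueError("world_size * batch_size must be divisible by repeats_per_prompt")
    num_unique = total // repeats_per_prompt
    if num_unique > dataset_length:
        raise ValueError("not enough prompts in the dataset for one global batch")

    # The same seed on every rank yields the same global plan on every rank.
    rng = random.Random(seed)
    unique = rng.sample(range(dataset_length), num_unique)
    repeated = [idx for idx in unique for _ in range(repeats_per_prompt)]
    rng.shuffle(repeated)

    start = rank * batch_size
    return repeated[start:start + batch_size]


# Example: 2 ranks, 4 rows per rank, each prompt repeated 4 times.
per_rank = [
    k_repeat_indices_sketch(
        dataset_length=100, batch_size=4, repeats_per_prompt=4,
        world_size=2, rank=r, seed=0,
    )
    for r in range(2)
]
```

Since the repeats of one prompt may land on different ranks, any per-prompt statistic (for example a group-relative reward) would need to be gathered across ranks before it is computed.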
fastvideo.train.methods.rl.common.media_to_video_array
¶
media_to_video_array(media: Tensor) -> Any
Convert decoded media to a tracker video array.
Accepts [C, T, H, W] tensors; [C, H, W] tensors are treated as single-frame (T=1) media.
The output follows the tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.
Source code in fastvideo/train/methods/rl/common/validation.py
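The layout change described above can be sketched with NumPy. This is an illustration only: the helper name `to_tracker_video_sketch` is hypothetical, it operates on NumPy arrays rather than torch tensors, and it assumes float input in the range [0, 1], which the description above does not state.

```python
import numpy as np


def to_tracker_video_sketch(media: np.ndarray) -> np.ndarray:
    """Convert [C, T, H, W] (or [C, H, W]) media to a [T, C, H, W] uint8 array."""
    if media.ndim == 3:
        # Treat an image as a one-frame video: [C, H, W] -> [C, 1, H, W].
        media = media[:, None]
    if media.ndim != 4:
        raise ValueError(f"expected 3 or 4 dimensions, got {media.ndim}")

    # Swap the channel and time axes: [C, T, H, W] -> [T, C, H, W].
    frames = np.transpose(media, (1, 0, 2, 3))

    # Assumption: values are floats in [0, 1]; clip before scaling to bytes.
    frames = np.clip(frames, 0.0, 1.0)
    return np.round(frames * 255.0).astype(np.uint8)


video = to_tracker_video_sketch(np.ones((3, 2, 4, 5), dtype=np.float32))
image = to_tracker_video_sketch(np.zeros((3, 4, 5), dtype=np.float32))
```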
fastvideo.train.methods.rl.common.validation_shard_indices
¶
Return fixed validation prompt indices for one distributed rank.
Source code in fastvideo/train/methods/rl/common/validation.py
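Neither the signature nor the sharding scheme of this function is shown here. One common way to give each rank a fixed, non-overlapping subset of validation prompts is strided sharding, sketched below. The name `strided_shard_indices_sketch` and its parameters are hypothetical, and the real function may shard differently.

```python
from typing import List


def strided_shard_indices_sketch(num_prompts: int, world_size: int, rank: int) -> List[int]:
    """Return indices rank, rank + world_size, ... below num_prompts."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    # No randomness is involved, so each rank sees the same prompts every run.
    return list(range(rank, num_prompts, world_size))


shards = [strided_shard_indices_sketch(10, 4, r) for r in range(4)]
```

With this scheme the shards are disjoint and together cover every prompt, although ranks may differ by one in how many prompts they receive.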
Modules¶
fastvideo.train.methods.rl.common.prompt_sampling
¶
Prompt-row sampling helpers for online RL methods.
This module selects dataset prompt rows and repeats them across ranks to build RL training batches. Here, "sampling" refers to selecting prompt rows, not to sampling from the generator.
Classes¶
fastvideo.train.methods.rl.common.prompt_sampling.KRepeatSample
dataclass
¶
Local prompt indices for one distributed K-repeat sampling batch.
Functions:¶
fastvideo.train.methods.rl.common.prompt_sampling.distributed_k_repeat_indices
¶
distributed_k_repeat_indices(*, dataset_length: int, batch_size: int, repeats_per_prompt: int, world_size: int, rank: int, seed: int) -> KRepeatSample
Mirror DiffusionNFT's distributed K-repeat prompt sampler.
Adapted from DiffusionNFT's scripts/train_nft_sd3.py::DistributedKRepeatSampler.
Source code in fastvideo/train/methods/rl/common/prompt_sampling.py
fastvideo.train.methods.rl.common.sampling
¶
Configurable diffusion samplers for RL training methods.
Classes¶
fastvideo.train.methods.rl.common.sampling.DiffusionSampler
¶
DiffusionSampler(config: SamplingConfig)
Thin model/scheduler sampler used by RL methods.
This class intentionally does not call FastVideo's full inference pipelines.
RL training needs a reusable sampling primitive that works with ModelBase wrappers and scheduler math without tying an RL method to model-family pipeline classes such as WanDMDPipeline.
Source code in fastvideo/train/methods/rl/common/sampling.py
fastvideo.train.methods.rl.common.sampling.SamplingConfig
dataclass
¶
SamplingConfig(num_steps: int = 25, scheduler: SchedulerName = 'model_default', trajectory: TrajectoryName = 'ode', flow_shift: float | None = None, timesteps: list[float] | None = None, sigmas: list[float] | None = None)
YAML-backed sampling knobs shared by RL methods.
fastvideo.train.methods.rl.common.validation
¶
Shared validation helpers for RL training methods.
Functions:¶
fastvideo.train.methods.rl.common.validation.media_to_video_array
¶
media_to_video_array(media: Tensor) -> Any
Convert decoded media to a tracker video array.
Accepts [C, T, H, W] tensors; [C, H, W] tensors are treated as single-frame (T=1) media.
The output follows the tracker convention used elsewhere in FastVideo: [T, C, H, W] uint8.
Source code in fastvideo/train/methods/rl/common/validation.py
fastvideo.train.methods.rl.common.validation.validation_shard_indices
¶
Return fixed validation prompt indices for one distributed rank.