Skip to content

models

Model build plugins for Phase 2/2.9 distillation.

These are "model plugins", selected via `recipe.family` or `roles.<role>.family`.

Modules

fastvideo.train.models.base

Classes

fastvideo.train.models.base.CausalModelBase

Bases: ModelBase

Extension for causal / streaming model plugins.

Cache state is internal to the model instance and keyed by cache_tag (no role handle needed).

Functions
fastvideo.train.models.base.CausalModelBase.clear_caches abstractmethod
clear_caches(*, cache_tag: str = 'pos') -> None

Clear internal caches before starting a new rollout.

Source code in fastvideo/train/models/base.py
@abstractmethod
def clear_caches(
    self,
    *,
    cache_tag: str = "pos",
) -> None:
    """Reset the internal caches associated with ``cache_tag``.

    Call this before a new rollout starts. Cache state lives on the
    model instance itself and is selected purely by ``cache_tag``
    (default ``"pos"``), so no role handle is needed.
    """
fastvideo.train.models.base.CausalModelBase.predict_noise_streaming abstractmethod
predict_noise_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Streaming variant of noise prediction; it may update internal caches.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Run one streaming noise/flow prediction, possibly updating caches.

    Caches are internal to the instance and are selected by
    ``cache_tag``. ``attn_kind`` is either ``"dense"`` or ``"vsa"``.
    Implementations return a tensor, or ``None`` when they produce no
    prediction for this call.

    NOTE(review): the precise meaning of ``store_kv``,
    ``cur_start_frame`` and ``cfg_uncond`` is implementation-defined;
    confirm against the concrete subclasses.
    """
fastvideo.train.models.base.CausalModelBase.predict_x0_streaming
predict_x0_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Predict x0 in streaming mode by calling predict_noise_streaming and converting the resulting noise prediction.

Source code in fastvideo/train/models/base.py
def predict_x0_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming x0 estimate built on ``predict_noise_streaming``.

    All arguments are forwarded unchanged to ``predict_noise_streaming``.
    Its output is converted with ``pred_noise_to_pred_video`` using this
    instance's ``noise_scheduler``. If the streaming noise prediction is
    ``None``, ``None`` is returned.
    """
    noise_pred = self.predict_noise_streaming(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cache_tag=cache_tag,
        store_kv=store_kv,
        cur_start_frame=cur_start_frame,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    if noise_pred is None:
        # No prediction was produced for this step; propagate that.
        return None

    # Merge the first two dims for the conversion, then restore them
    # using the noise prediction's original leading sizes.
    # NOTE(review): ``timestep`` is passed through un-flattened; confirm
    # its shape matches what pred_noise_to_pred_video expects here.
    lead_shape = noise_pred.shape[:2]
    flat_noise = noise_pred.flatten(0, 1)
    flat_latents = noisy_latents.flatten(0, 1)
    flat_x0 = pred_noise_to_pred_video(
        pred_noise=flat_noise,
        noise_input_latent=flat_latents,
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return flat_x0.unflatten(0, lead_shape)
fastvideo.train.models.base.ModelBase

Bases: ABC

Per-role model instance.

Every role (student, teacher, critic, …) gets its own ModelBase instance. Each instance owns its own transformer and noise_scheduler. Heavyweight resources (VAE, dataloader, RNG seeds) are set up lazily via `init_preprocessors()`, which the training method calls only on the student instance.

Attributes
fastvideo.train.models.base.ModelBase.device property
device: device

The local CUDA device for this rank.

fastvideo.train.models.base.ModelBase.num_train_timesteps property
num_train_timesteps: int

Return the scheduler's training timestep horizon.

Functions
fastvideo.train.models.base.ModelBase.add_noise abstractmethod
add_noise(clean_latents: Tensor, noise: Tensor, timestep: Tensor) -> Tensor

Apply forward-process noise at timestep.

Source code in fastvideo/train/models/base.py
@abstractmethod
def add_noise(
    self,
    clean_latents: torch.Tensor,
    noise: torch.Tensor,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Return ``clean_latents`` noised with ``noise`` at ``timestep``.

    This follows the model's forward (noising) process; concrete
    subclasses supply the implementation.
    """
fastvideo.train.models.base.ModelBase.backward abstractmethod
backward(loss: Tensor, ctx: Any, *, grad_accum_rounds: int) -> None

Run the backward pass; implementations may restore the forward context.

Source code in fastvideo/train/models/base.py
@abstractmethod
def backward(
    self,
    loss: torch.Tensor,
    ctx: Any,
    *,
    grad_accum_rounds: int,
) -> None:
    """Backpropagate ``loss``; implementations may restore forward-context.

    ``ctx`` is an opaque, implementation-defined value (presumably the
    forward context to restore; TODO confirm). ``grad_accum_rounds`` is
    presumably the number of gradient-accumulation rounds; verify
    against the concrete subclasses.
    """
fastvideo.train.models.base.ModelBase.init_preprocessors
init_preprocessors(training_config: TrainingConfig) -> None

Load VAE, build dataloader, seed RNGs.

Called only on the student, from the training method's `__init__`. The default implementation is a no-op, so teacher/critic instances skip this step.

Source code in fastvideo/train/models/base.py
def init_preprocessors(  # noqa: B027
    self,
    training_config: TrainingConfig,
) -> None:
    """Hook for loading the VAE, building the dataloader and seeding RNGs.

    Only the student instance receives this call, from the training
    method's ``__init__``. The base version does nothing, which is how
    teacher/critic instances skip it.
    """
fastvideo.train.models.base.ModelBase.on_train_start
on_train_start() -> None

Called once before the training loop begins.

Source code in fastvideo/train/models/base.py
def on_train_start(self) -> None:  # noqa: B027
    """Hook invoked a single time just before the training loop starts.

    The base implementation is intentionally empty.
    """
fastvideo.train.models.base.ModelBase.predict_noise abstractmethod
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict noise/flow for the given noisy latents.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Return the noise/flow prediction for ``noisy_latents`` at ``timestep``.

    ``attn_kind`` picks the attention variant (``"dense"`` or
    ``"vsa"``). NOTE(review): the exact semantics of ``conditional`` and
    ``cfg_uncond`` are defined by the concrete subclasses; confirm there.
    """
fastvideo.train.models.base.ModelBase.predict_x0
predict_x0(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict x0 by calling predict_noise and converting the resulting noise prediction.

Source code in fastvideo/train/models/base.py
def predict_x0(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Estimate x0 by converting the output of ``predict_noise``.

    All arguments are forwarded unchanged to ``predict_noise``. The
    result is converted with ``pred_noise_to_pred_video`` using this
    instance's ``noise_scheduler``.
    """
    noise_pred = self.predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )

    # Merge the first two dims for the conversion, then restore them
    # using the noise prediction's original leading sizes.
    # NOTE(review): ``timestep`` is passed through un-flattened; confirm
    # its shape matches what pred_noise_to_pred_video expects here.
    lead_shape = noise_pred.shape[:2]
    flat_x0 = pred_noise_to_pred_video(
        pred_noise=noise_pred.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return flat_x0.unflatten(0, lead_shape)
fastvideo.train.models.base.ModelBase.prepare_batch abstractmethod
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Convert a dataloader batch into forward primitives.

Source code in fastvideo/train/models/base.py
@abstractmethod
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Turn a raw dataloader batch into a ``TrainingBatch`` of forward inputs.

    ``latents_source`` is ``"data"`` (default) or ``"zeros"``; presumably
    it selects whether latents come from the batch or are zero-filled.
    ``generator`` is presumably the RNG used for sampling. TODO confirm
    both against the concrete subclasses.
    """
fastvideo.train.models.base.ModelBase.shift_and_clamp_timestep
shift_and_clamp_timestep(timestep: Tensor) -> Tensor

Apply model/pipeline timestep shifting and clamp.

Source code in fastvideo/train/models/base.py
def shift_and_clamp_timestep(
    self,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Hook for model/pipeline-specific timestep shifting and clamping.

    The base implementation applies neither: it hands back the very same
    ``timestep`` object. Subclasses may override it.
    """
    unchanged = timestep
    return unchanged

Functions

fastvideo.train.models.wan

Wan model plugin package.

Classes

Modules

fastvideo.train.models.wan.wan

Wan model plugin (per-role instance).

Classes
fastvideo.train.models.wan.wan.WanModel
WanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: ModelBase

Wan per-role model: owns transformer + noise_scheduler.

Source code in fastvideo/train/models/wan/wan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Build the transformer and flow-matching scheduler for one role.

    Args:
        init_from: Source handed to ``_load_transformer``; stored as ``str``.
        training_config: Forwarded to the loader and kept on the instance.
        trainable: Coerced to ``bool``, stored, and forwarded to the loader.
        disable_custom_init_weights: Forwarded to the loader.
        flow_shift: Shift for ``FlowMatchEulerDiscreteScheduler``; also
            stored as ``timestep_shift``.
        enable_gradient_checkpointing_type: Forwarded to the loader.
        transformer_override_safetensor: Forwarded to the loader.
    """
    # Normalise these first so the loader sees exactly what is stored.
    self._init_from = str(init_from)
    self._trainable = bool(trainable)

    self.transformer = self._load_transformer(
        init_from=self._init_from,
        trainable=self._trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        training_config=training_config,
        transformer_override_safetensor=transformer_override_safetensor,
    )

    shift = float(flow_shift)
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)

    # Filled by init_preprocessors (student only).
    self.vae: Any = None
    self.training_config: TrainingConfig = training_config
    self.dataloader: Any = None
    self.validator: Any = None
    self.start_step: int = 0

    self.world_group: Any = None
    self.sp_group: Any = None

    self.negative_prompt_embeds: torch.Tensor | None = None
    self.negative_prompt_attention_mask: torch.Tensor | None = None

    # Timestep mechanics: start with the scheduler's full training range.
    # NOTE(review): max_timestep equals num_train_timestep; confirm whether
    # consumers treat it as an exclusive or an inclusive upper bound.
    self.timestep_shift: float = shift
    self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
    self.min_timestep: int = 0
    self.max_timestep: int = self.num_train_timestep
Functions
fastvideo.train.models.wan.wan_causal

Wan causal model plugin (per-role instance, streaming/cache).

Classes
fastvideo.train.models.wan.wan_causal.WanCausalModel
WanCausalModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel, CausalModelBase

Wan per-role model with causal/streaming primitives.

Source code in fastvideo/train/models/wan/wan_causal.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Initialise the base Wan model, then start with no streaming caches.

    Every argument is forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
    # Streaming cache registry keyed by (int, str) tuples; the str is
    # presumably the cache_tag. TODO confirm what the int component is.
    self._streaming_caches: dict[tuple[int, str], _StreamingCaches] = {}
Functions