Skip to content

models

Model build plugins for Phase 2/2.9 distillation.

These are "model plugins" selected by recipe.family / roles.<role>.family.

Modules

fastvideo.train.models.base

Classes

fastvideo.train.models.base.CausalModelBase

Bases: ModelBase

Extension for causal / streaming model plugins.

Cache state is internal to the model instance and keyed by cache_tag (no role handle needed).

Functions
fastvideo.train.models.base.CausalModelBase.clear_caches abstractmethod
clear_caches(*, cache_tag: str = 'pos') -> None

Clear internal caches before starting a new rollout.

Source code in fastvideo/train/models/base.py
@abstractmethod
def clear_caches(self, *, cache_tag: str = "pos") -> None:
    """Clear internal caches before starting a new rollout.

    Args:
        cache_tag: Key selecting which internal cache set to clear;
            defaults to ``"pos"`` (presumably the positive-conditioning
            cache — confirm against implementations).
    """
fastvideo.train.models.base.CausalModelBase.predict_noise_streaming abstractmethod
predict_noise_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Streaming predict-noise that may update internal caches.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming predict-noise that may update internal caches.

    Args:
        noisy_latents: Noised latents to denoise for this call.
        timestep: Diffusion timestep(s) for the prediction.
        batch: Training batch carrying conditioning inputs.
        conditional: Whether to run the conditional branch.
        cache_tag: Key of the internal cache to read/update.
        store_kv: If True, persist KV state in the cache for reuse.
        cur_start_frame: Presumably the index of the first frame of
            this chunk within the rollout — confirm in implementations.
        cfg_uncond: Optional unconditional-branch inputs for
            classifier-free guidance.
        attn_kind: Attention backend selector ("dense" or "vsa").

    Returns:
        The noise prediction, or ``None`` (implementation-defined; see
        ``predict_x0_streaming``, which propagates ``None``).
    """
fastvideo.train.models.base.CausalModelBase.predict_x0_streaming
predict_x0_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Predict x0 streaming via predict_noise_streaming + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming x0 prediction.

    Delegates to :meth:`predict_noise_streaming` and converts its noise
    prediction into a video (x0) prediction via the scheduler.
    Propagates ``None`` when the noise prediction is ``None``.
    """
    noise_pred = self.predict_noise_streaming(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cache_tag=cache_tag,
        store_kv=store_kv,
        cur_start_frame=cur_start_frame,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    if noise_pred is None:
        return None
    # The conversion helper works on tensors with the first two dims
    # flattened together; remember them so they can be restored.
    lead_dims = noise_pred.shape[:2]
    video_pred = pred_noise_to_pred_video(
        pred_noise=noise_pred.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return video_pred.unflatten(0, lead_dims)
fastvideo.train.models.base.ModelBase

Bases: ABC

Per-role model instance.

Every role (student, teacher, critic, …) gets its own ModelBase instance. Each instance owns its own transformer and noise_scheduler. Heavyweight resources (VAE, dataloader, RNG seeds) are loaded lazily via :meth:init_preprocessors, which is invoked only for the student role.

Attributes
fastvideo.train.models.base.ModelBase.device property
device: device

The local CUDA device for this rank.

fastvideo.train.models.base.ModelBase.num_train_timesteps property
num_train_timesteps: int

Return the scheduler's training timestep horizon.

Functions
fastvideo.train.models.base.ModelBase.add_noise abstractmethod
add_noise(clean_latents: Tensor, noise: Tensor, timestep: Tensor) -> Tensor

Apply forward-process noise at timestep.

Source code in fastvideo/train/models/base.py
@abstractmethod
def add_noise(
    self,
    clean_latents: torch.Tensor,
    noise: torch.Tensor,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Apply forward-process noise at *timestep*.

    Args:
        clean_latents: Noise-free latents (x0).
        noise: Noise sample to mix into the latents.
        timestep: Diffusion timestep(s) controlling the noise level.

    Returns:
        The noised latents at the given timestep.
    """
fastvideo.train.models.base.ModelBase.backward abstractmethod
backward(loss: Tensor, ctx: Any, *, grad_accum_rounds: int) -> None

Backward that may restore forward-context.

Source code in fastvideo/train/models/base.py
@abstractmethod
def backward(
    self,
    loss: torch.Tensor,
    ctx: Any,
    *,
    grad_accum_rounds: int,
) -> None:
    """Backward that may restore forward-context.

    Args:
        loss: Loss tensor to backpropagate.
        ctx: Opaque forward-context object implementations may use to
            restore state saved during the forward pass.
        grad_accum_rounds: Gradient-accumulation round count;
            presumably used to scale the loss — confirm per
            implementation.
    """
fastvideo.train.models.base.ModelBase.init_preprocessors
init_preprocessors(training_config: TrainingConfig) -> None

Load VAE, build dataloader, seed RNGs.

Invoked only for the student role during construction. The default is a no-op so teacher/critic instances skip this.

Source code in fastvideo/train/models/base.py
def init_preprocessors(  # noqa: B027
    self,
    training_config: TrainingConfig,
) -> None:
    """Load VAE, build dataloader, seed RNGs.

    Invoked only for the student role.  The base-class default is a
    deliberate no-op so teacher/critic instances skip the heavyweight
    setup entirely.
    """
fastvideo.train.models.base.ModelBase.on_train_start
on_train_start() -> None

Called once before the training loop begins.

Source code in fastvideo/train/models/base.py
def on_train_start(self) -> None:  # noqa: B027
    """Called once before the training loop begins.

    Default is a no-op; subclasses override for role-specific setup.
    """
fastvideo.train.models.base.ModelBase.predict_noise abstractmethod
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict noise/flow for the given noisy latents.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict noise/flow for the given noisy latents.

    Args:
        noisy_latents: Noised latents to denoise.
        timestep: Diffusion timestep(s) for the prediction.
        batch: Training batch carrying conditioning inputs.
        conditional: Whether to run the conditional branch.
        cfg_uncond: Optional unconditional-branch inputs for
            classifier-free guidance.
        attn_kind: Attention backend selector ("dense" or "vsa").

    Returns:
        The noise (or flow) prediction tensor.
    """
fastvideo.train.models.base.ModelBase.predict_x0
predict_x0(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict x0 via predict_noise + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict x0 by converting the output of ``predict_noise``.

    Takes the same arguments as :meth:`predict_noise`; the resulting
    noise prediction is converted to a video (x0) prediction through
    the scheduler.
    """
    noise_pred = self.predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    # The conversion helper works on tensors with the first two dims
    # flattened together; remember them so they can be restored.
    lead_dims = noise_pred.shape[:2]
    video_pred = pred_noise_to_pred_video(
        pred_noise=noise_pred.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return video_pred.unflatten(0, lead_dims)
fastvideo.train.models.base.ModelBase.prepare_batch abstractmethod
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Convert a dataloader batch into forward primitives.

Source code in fastvideo/train/models/base.py
@abstractmethod
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Convert a dataloader batch into forward primitives.

    Args:
        raw_batch: Raw dict produced by the dataloader.
        generator: RNG for any stochastic preparation steps.
        latents_source: "data" to read latents from the batch,
            "zeros" to fabricate all-zero latents.

    Returns:
        A populated :class:`TrainingBatch`.
    """
fastvideo.train.models.base.ModelBase.shift_and_clamp_timestep
shift_and_clamp_timestep(timestep: Tensor) -> Tensor

Apply model/pipeline timestep shifting and clamp.

Source code in fastvideo/train/models/base.py
def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
    """Apply model/pipeline timestep shifting and clamp.

    Base-class default is the identity; subclasses override when their
    pipeline uses a shifted/clamped timestep schedule.
    """
    return timestep

Functions

fastvideo.train.models.hunyuan

Hunyuan model plugin package.

Classes

Modules

fastvideo.train.models.hunyuan.hunyuan

Hunyuan model plugin (per-role instance).

Subclasses WanModel since HunyuanVideo uses the same FlowMatchEulerDiscreteScheduler and linear-interpolation noise schedule. Differences: - transformer class name - normalize_dit_input("hunyuan", ...) instead of ("wan", ...) - forward kwargs: no encoder_attention_mask, no return_dict - default flow_shift = 7

Classes
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel
HunyuanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 7.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

HunyuanVideo per-role model.

Inherits most behaviour from WanModel (noise scheduler, timestep sampling, attention metadata, backward). Overrides only the pieces that differ for Hunyuan.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 7.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Build a Hunyuan per-role model.

    Construction is delegated entirely to ``WanModel``; the only
    difference at this level is the default ``flow_shift`` (7.0 for
    Hunyuan vs 3.0 for Wan).
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
Functions
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None

Encode the negative prompt with dual text encoders (LLaMA + CLIP).

Every rank encodes independently to avoid NCCL deadlocks when only a subset of ranks would otherwise participate.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def ensure_negative_conditioning(self) -> None:
    """Encode the negative prompt with dual text encoders
    (LLaMA + CLIP).

    Every rank encodes independently to avoid NCCL deadlocks
    when only a subset of ranks would otherwise participate.
    """
    # Idempotent: bail out once the embeddings have been computed.
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return

    assert self.training_config is not None
    tc = self.training_config
    device = self.device
    dtype = self._get_training_dtype()

    # Local imports so the heavy text-encoder stacks load only when
    # negative conditioning is actually needed.
    from transformers import (AutoTokenizer, CLIPTextModel, LlamaModel)

    from fastvideo.configs.pipelines.hunyuan import (
        clip_preprocess_text,
        clip_postprocess_text,
        llama_preprocess_text,
        llama_postprocess_text,
    )
    from fastvideo.utils import (PRECISION_TO_TYPE, maybe_download_model)

    model_path = maybe_download_model(tc.model_path)

    # Use configured precisions for each encoder.
    precisions = tc.pipeline_config.text_encoder_precisions
    llama_dtype = PRECISION_TO_TYPE[precisions[0]]  # index 0: LLaMA
    clip_dtype = PRECISION_TO_TYPE[precisions[1]]  # index 1: CLIP

    # --- LLaMA ---
    llama_tok = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    llama_enc = LlamaModel.from_pretrained(
        os.path.join(model_path, "text_encoder"),
        torch_dtype=llama_dtype,
    ).to(device).eval()

    llama_cfg = tc.pipeline_config.text_encoder_configs[0]
    llama_tok_kwargs = dict(llama_cfg.tokenizer_kwargs)

    # The empty string serves as the negative prompt.
    negative_prompt = ""
    llama_text = llama_preprocess_text(negative_prompt)

    with torch.no_grad():
        llama_inputs = llama_tok(llama_text, **llama_tok_kwargs).to(device)
        llama_out = llama_enc(**llama_inputs, output_hidden_states=True)
        llama_embeds = llama_postprocess_text(llama_out).squeeze(0)

    # Drop the LLaMA encoder before loading CLIP to bound peak memory.
    del llama_enc, llama_tok

    # --- CLIP ---
    clip_tok = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer_2"))
    clip_enc = CLIPTextModel.from_pretrained(
        os.path.join(model_path, "text_encoder_2"),
        torch_dtype=clip_dtype,
    ).to(device).eval()

    clip_cfg = tc.pipeline_config.text_encoder_configs[1]
    clip_tok_kwargs = dict(clip_cfg.tokenizer_kwargs)
    clip_text = clip_preprocess_text(negative_prompt)

    with torch.no_grad():
        clip_inputs = clip_tok(clip_text, **clip_tok_kwargs).to(device)
        clip_out = clip_enc(**clip_inputs)
        clip_pooled = clip_postprocess_text(clip_out).squeeze(0)

    del clip_enc, clip_tok

    # --- Combine: [pooled_clip_row, llama_embeds] ---
    # The CLIP pooled vector is zero-padded up to the LLaMA hidden dim
    # and prepended as the first row of the combined sequence.
    llama_dim = llama_embeds.shape[-1]
    pooled_row = torch.zeros(llama_dim, device=device)
    pooled_row[:clip_pooled.shape[-1]] = clip_pooled
    neg_embeds = torch.cat(
        [pooled_row.unsqueeze(0), llama_embeds],
        dim=0,
    ).unsqueeze(0).to(device=device, dtype=dtype)

    # Attention mask: all ones for the combined sequence.
    # NOTE(review): mask is created in the floating training dtype
    # rather than bool/int — presumably what the transformer expects;
    # confirm against the forward signature.
    neg_mask = torch.ones(neg_embeds.shape[:2], device=device, dtype=dtype)

    self.negative_prompt_embeds = neg_embeds
    self.negative_prompt_attention_mask = neg_mask
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Same flow as Wan, but uses Hunyuan VAE normalisation.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Same flow as Wan, but uses Hunyuan VAE normalisation.

    Args:
        raw_batch: Dataloader batch; must contain "text_embedding" and
            "text_attention_mask", plus "vae_latent" when
            ``latents_source == "data"``.
        generator: RNG passed to DiT-input preparation.
        latents_source: "data" reads VAE latents from the batch;
            "zeros" fabricates an all-zero latent tensor with the
            configured shape.

    Returns:
        A populated :class:`TrainingBatch` with attention metadata.

    Raises:
        ValueError: If "vae_latent" is missing in "data" mode, or if
            ``latents_source`` is unrecognised.
    """
    # Negative conditioning must exist before batches are prepared.
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    tc = self.training_config

    dtype = self._get_training_dtype()
    device = self.device

    training_batch = TrainingBatch()
    encoder_hidden_states = raw_batch["text_embedding"]
    encoder_attention_mask = raw_batch["text_attention_mask"]
    infos = raw_batch.get("info_list")

    if latents_source == "zeros":
        batch_size = encoder_hidden_states.shape[0]
        vae_config = (
            tc.pipeline_config.vae_config  # type: ignore[union-attr]
            .arch_config)
        # VAE configs name the latent channel count differently across
        # architectures; fall back latent_channels -> z_dim -> 16.
        num_channels = getattr(
            vae_config,
            "latent_channels",
            getattr(vae_config, "z_dim", 16),
        )
        spatial_compression_ratio = (vae_config.spatial_compression_ratio)
        latent_height = (tc.data.num_height // spatial_compression_ratio)
        latent_width = (tc.data.num_width // spatial_compression_ratio)
        latents = torch.zeros(
            batch_size,
            num_channels,
            tc.data.num_latent_t,
            latent_height,
            latent_width,
            device=device,
            dtype=dtype,
        )
    elif latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        latents = raw_batch["vae_latent"]
        # Truncate the temporal axis to the configured latent length.
        latents = latents[:, :, :tc.data.num_latent_t]
        latents = latents.to(device, dtype=dtype)
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    training_batch.latents = latents
    training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype))
    training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype))
    training_batch.infos = infos

    # KEY DIFFERENCE: "hunyuan" normalisation
    training_batch.latents = normalize_dit_input(
        "hunyuan",
        training_batch.latents,
        self.vae,
    )
    training_batch = self._prepare_dit_inputs(training_batch, generator)
    training_batch = self._build_attention_metadata(training_batch)

    # Keep a deep copy with VSA sparsity intact, then zero sparsity on
    # the primary metadata — presumably so the default path runs dense
    # attention while attn_metadata_vsa keeps the sparse settings.
    training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata)
    if training_batch.attn_metadata is not None:
        training_batch.attn_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

    return training_batch

fastvideo.train.models.wan

Wan model plugin package.

Classes

Modules

fastvideo.train.models.wan.wan

Wan model plugin (per-role instance).

Classes
fastvideo.train.models.wan.wan.WanModel
WanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: ModelBase

Wan per-role model: owns transformer + noise_scheduler.

Source code in fastvideo/train/models/wan/wan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Construct a Wan per-role model: transformer + noise scheduler."""
    self._init_from = str(init_from)
    self._trainable = bool(trainable)

    self.transformer = self._load_transformer(
        init_from=self._init_from,
        trainable=self._trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        training_config=training_config,
        transformer_override_safetensor=transformer_override_safetensor,
    )

    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
        shift=float(flow_shift))

    # Heavyweight resources are filled in later by init_preprocessors
    # (student role only); they remain None for teacher/critic.
    self.vae: Any = None
    self.training_config: TrainingConfig = training_config
    self.dataloader: Any = None
    self.validator: Any = None
    self.start_step: int = 0

    # Distributed process groups; left unset (None) here.
    self.world_group: Any = None
    self.sp_group: Any = None

    # Negative-prompt conditioning; None until populated.
    self.negative_prompt_embeds: torch.Tensor | None = None
    self.negative_prompt_attention_mask: torch.Tensor | None = None

    # Timestep mechanics.
    self.timestep_shift: float = float(flow_shift)
    self.num_train_timestep: int = int(
        self.noise_scheduler.num_train_timesteps)
    self.min_timestep: int = 0
    self.max_timestep: int = self.num_train_timestep
Functions
fastvideo.train.models.wan.wan_causal

Wan causal model plugin (per-role instance, streaming/cache).

Classes
fastvideo.train.models.wan.wan_causal.WanCausalModel
WanCausalModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel, CausalModelBase

Wan per-role model with causal/streaming primitives.

Source code in fastvideo/train/models/wan/wan_causal.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Construct a causal Wan model plus an empty streaming-cache map."""
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
    # Streaming KV caches keyed by (int, cache_tag) tuples; starts
    # empty and is populated during rollouts.
    self._streaming_caches: dict[tuple[int, str], _StreamingCaches] = {}
Functions