Skip to content

models

Modules

fastvideo.train.models.base

Classes

fastvideo.train.models.base.CausalModelBase
CausalModelBase(*, trainable: bool = True, lora: LoraConfig | dict[str, Any] | None = None)

Bases: ModelBase

Extension for causal / streaming model plugins.

Cache state is internal to the model instance and keyed by cache_tag (no role handle needed).

Source code in fastvideo/train/models/base.py
def __init__(
    self,
    *,
    trainable: bool = True,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Record the per-role flags every model plugin shares.

    Args:
        trainable: Whether this role's parameters are meant to be
            trained; stored as a plain ``bool``.
        lora: LoRA settings given as a ``LoraConfig``, a plain mapping,
            or ``None``; normalised through ``LoraConfig.coerce``.
    """
    # Function-scope import (presumably to avoid an import cycle at
    # module load time).
    from fastvideo.train.utils.lora import LoraConfig

    self._trainable = True if trainable else False
    normalized_lora: LoraConfig | None = LoraConfig.coerce(lora)
    self._lora_config = normalized_lora
    # Nothing in this constructor adds LoRA layers; the counter starts at 0.
    self._num_lora_layers = 0
Methods:
fastvideo.train.models.base.CausalModelBase.clear_caches abstractmethod
clear_caches(*, cache_tag: str = 'pos') -> None

Clear internal caches before starting a new rollout.

Source code in fastvideo/train/models/base.py
@abstractmethod
def clear_caches(self, *, cache_tag: str = "pos") -> None:
    """Clear internal caches before starting a new rollout.

    Cache state lives on the model instance and is keyed by
    ``cache_tag``, so no role handle is needed.

    Args:
        cache_tag: Tag of the cache to clear. Defaults to ``"pos"``
            (presumably the positive/conditional stream — confirm
            against implementations).
    """
fastvideo.train.models.base.CausalModelBase.predict_noise_streaming abstractmethod
predict_noise_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Streaming predict-noise that may update internal caches.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming predict-noise that may update internal caches.

    Causal counterpart of ``predict_noise``. Implementations may read
    from and write to the per-instance cache selected by ``cache_tag``.

    Args:
        noisy_latents: Noisy latents for the current chunk.
        timestep: Timestep(s) associated with ``noisy_latents``.
        batch: Prepared training batch supplying conditioning.
        conditional: Conditioning flag (presumably conditional vs.
            unconditional prediction; implementation-defined).
        cache_tag: Key of the internal cache to use (default ``"pos"``).
        store_kv: Presumably whether this call should write its KV
            state into the cache — confirm in implementations.
        cur_start_frame: Presumably the frame offset of this chunk
            within the full sequence — confirm in implementations.
        cfg_uncond: Optional settings for the unconditional branch;
            schema is implementation-defined.
        attn_kind: Attention variant, ``"dense"`` or ``"vsa"``.

    Returns:
        The noise/flow prediction, or ``None``. Callers such as
        ``predict_x0_streaming`` pass a ``None`` result through unchanged.
    """
fastvideo.train.models.base.CausalModelBase.predict_x0_streaming
predict_x0_streaming(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cache_tag: str = 'pos', store_kv: bool = False, cur_start_frame: int = 0, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor | None

Predict x0 streaming via predict_noise_streaming + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0_streaming(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cache_tag: str = "pos",
    store_kv: bool = False,
    cur_start_frame: int = 0,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor | None:
    """Streaming variant of x0 prediction.

    Forwards every argument to ``predict_noise_streaming``, which may
    touch the cache selected by ``cache_tag``. A ``None`` result is
    returned as-is. Otherwise the prediction is converted to x0 with
    ``pred_noise_to_pred_video`` and this instance's
    ``noise_scheduler``. The first two dims are merged for the
    conversion and split back afterwards.
    """
    noise = self.predict_noise_streaming(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cache_tag=cache_tag,
        store_kv=store_kv,
        cur_start_frame=cur_start_frame,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    if noise is None:
        # Nothing to convert (e.g. a cache-only call).
        return None

    leading_dims = noise.shape[:2]
    x0_flat = pred_noise_to_pred_video(
        pred_noise=noise.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return x0_flat.unflatten(0, leading_dims)
fastvideo.train.models.base.ModelBase
ModelBase(*, trainable: bool = True, lora: LoraConfig | dict[str, Any] | None = None)

Bases: ABC

Per-role model instance.

Every role (student, teacher, critic, …) gets its own ModelBase instance, and each instance owns its own transformer and noise_scheduler. Heavyweight resources (VAE, dataloader) are set up lazily via `init_preprocessors`, which also seeds the RNGs. The training method calls it only on the student.

Source code in fastvideo/train/models/base.py
def __init__(
    self,
    *,
    trainable: bool = True,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Initialise the state shared by every per-role model.

    Args:
        trainable: Whether this role is meant to be trained; stored as
            a plain ``bool`` in ``_trainable``.
        lora: LoRA settings (``LoraConfig``, mapping, or ``None``),
            normalised with ``LoraConfig.coerce`` into ``_lora_config``.
    """
    # Imported inside the constructor rather than at module level
    # (presumably to sidestep an import cycle).
    from fastvideo.train.utils.lora import LoraConfig

    self._trainable = True if trainable else False
    coerced: LoraConfig | None = LoraConfig.coerce(lora)
    self._lora_config = coerced
    # No LoRA layers exist yet at construction time.
    self._num_lora_layers = 0
Attributes
fastvideo.train.models.base.ModelBase.device property
device: device

The local CUDA device for this rank.

fastvideo.train.models.base.ModelBase.num_train_timesteps property
num_train_timesteps: int

Return the scheduler's training timestep horizon.

Methods:
fastvideo.train.models.base.ModelBase.add_noise abstractmethod
add_noise(clean_latents: Tensor, noise: Tensor, timestep: Tensor) -> Tensor

Apply forward-process noise at timestep.

Source code in fastvideo/train/models/base.py
@abstractmethod
def add_noise(
    self,
    clean_latents: torch.Tensor,
    noise: torch.Tensor,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Apply forward-process noise at *timestep*.

    Args:
        clean_latents: Clean latents to be noised.
        noise: Noise sample to mix in (presumably the same shape as
            ``clean_latents`` — confirm per implementation).
        timestep: Timestep(s) selecting the noise level.

    Returns:
        The noised latents; the noise schedule is implementation-defined.
    """
fastvideo.train.models.base.ModelBase.backward abstractmethod
backward(loss: Tensor, ctx: Any, *, grad_accum_rounds: int) -> None

Backward that may restore forward-context.

Source code in fastvideo/train/models/base.py
@abstractmethod
def backward(
    self,
    loss: torch.Tensor,
    ctx: Any,
    *,
    grad_accum_rounds: int,
) -> None:
    """Backward that may restore forward-context.

    Args:
        loss: Loss tensor to backpropagate.
        ctx: Opaque context object. Presumably captured during the
            forward pass so the implementation can restore whatever
            forward context the backward pass needs — confirm.
        grad_accum_rounds: Number of gradient-accumulation rounds.
            Implementations presumably use it to scale the loss —
            confirm.
    """
fastvideo.train.models.base.ModelBase.init_preprocessors
init_preprocessors(training_config: TrainingConfig) -> None

Load VAE, build dataloader, seed RNGs.

Called only on the student, by the training method's `__init__`. The default implementation is a no-op, so teacher/critic instances skip this step.

Source code in fastvideo/train/models/base.py
def init_preprocessors(  # noqa: B027
        self,
        training_config: TrainingConfig,
) -> None:
    """Load VAE, build dataloader, seed RNGs.

    Called only on the student by the method's ``__init__``.
    Default is a no-op so teacher/critic instances skip this.

    This is intentionally a concrete no-op rather than an
    ``@abstractmethod`` (hence ``noqa: B027``). Roles that need no
    preprocessing therefore do not have to override it.

    Args:
        training_config: The full training configuration.
    """
fastvideo.train.models.base.ModelBase.on_train_start
on_train_start() -> None

Called once before the training loop begins.

Source code in fastvideo/train/models/base.py
def on_train_start(self) -> None:  # noqa: B027
    """Called once before the training loop begins.

    The default is a no-op hook. It is concrete rather than an
    ``@abstractmethod`` (hence ``noqa: B027``), so subclasses only
    override it when they need setup at training start.
    """
fastvideo.train.models.base.ModelBase.predict_noise abstractmethod
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict noise/flow for the given noisy latents.

Source code in fastvideo/train/models/base.py
@abstractmethod
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict noise/flow for the given noisy latents.

    Args:
        noisy_latents: Noisy latents fed to the transformer.
        timestep: Timestep(s) associated with ``noisy_latents``.
        batch: Prepared ``TrainingBatch`` supplying conditioning.
        conditional: Conditioning flag (presumably conditional vs.
            unconditional prediction; implementation-defined).
        cfg_uncond: Optional settings for the unconditional branch;
            schema is implementation-defined.
        attn_kind: Attention variant, ``"dense"`` or ``"vsa"``.

    Returns:
        The model's noise/flow prediction. ``predict_x0`` flattens its
        first two dims alongside ``noisy_latents``, so the result is
        expected to share that leading layout — confirm per model.
    """
fastvideo.train.models.base.ModelBase.predict_x0
predict_x0(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Predict x0 via predict_noise + conversion.

Source code in fastvideo/train/models/base.py
def predict_x0(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Predict the clean sample x0 for ``noisy_latents``.

    Runs ``predict_noise`` with the same arguments. It then converts the
    prediction using ``pred_noise_to_pred_video`` and this instance's
    ``noise_scheduler``. The first two dims are merged for the
    conversion and restored afterwards, so the result keeps the noise
    prediction's leading shape.
    """
    noise = self.predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    leading_dims = noise.shape[:2]
    x0_flat = pred_noise_to_pred_video(
        pred_noise=noise.flatten(0, 1),
        noise_input_latent=noisy_latents.flatten(0, 1),
        timestep=timestep,
        scheduler=self.noise_scheduler,
    )
    return x0_flat.unflatten(0, leading_dims)
fastvideo.train.models.base.ModelBase.prepare_batch abstractmethod
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Convert a dataloader batch into forward primitives.

Source code in fastvideo/train/models/base.py
@abstractmethod
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Convert a dataloader batch into forward primitives.

    Args:
        raw_batch: One batch from the dataloader; the key schema is
            model-specific.
        generator: RNG presumably used for any sampling performed while
            preparing inputs — confirm per implementation.
        latents_source: ``"data"`` or ``"zeros"``. Presumably "take
            latents from the batch" vs. "synthesise zero latents" —
            semantics are implementation-defined.

    Returns:
        A ``TrainingBatch`` for the forward pass.
    """
fastvideo.train.models.base.ModelBase.shift_and_clamp_timestep
shift_and_clamp_timestep(timestep: Tensor) -> Tensor

Apply model/pipeline timestep shifting and clamp.

Source code in fastvideo/train/models/base.py
def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
    """Apply model/pipeline timestep shifting and clamp.

    The base implementation is the identity: it returns ``timestep``
    unchanged (the same object). Subclasses that need shifting or
    clamping override this.
    """
    return timestep

Functions:

fastvideo.train.models.cosmos

Cosmos model plugin package.

Classes

Modules

fastvideo.train.models.cosmos.cosmos

Cosmos model plugin (per-role instance).

Subclasses WanModel since Cosmos uses the same FlowMatchEulerDiscreteScheduler. Differences:

- transformer class name: CosmosTransformer3DModel
- normalize_dit_input("cosmos", ...) instead of ("wan", ...)
- forward kwargs: no encoder_attention_mask; needs condition_mask, padding_mask, and fps
- hidden_states in (B, C, T, H, W) — no permute needed
- default flow_shift = 1.0
- single T5 text encoder (not dual like Hunyuan)

Classes
fastvideo.train.models.cosmos.cosmos.CosmosModel
CosmosModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 1.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

Cosmos 2.5 per-role model.

Inherits most behaviour from WanModel (noise scheduler, timestep sampling, attention metadata, backward). Overrides only the pieces that differ for Cosmos 2.5.

Cosmos 2.5 uses:

- Cosmos25Transformer3DModel (velocity prediction)
- EDM noise schedule: $x_t = x_0 + \sigma\,\epsilon$
- No input/output preconditioning (raw latents)
- Timestep = raw $\sigma$ value
- Model output = velocity ≈ noise

Source code in fastvideo/train/models/cosmos/cosmos.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 1.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Build a Cosmos 2.5 role by delegating to ``WanModel``.

    Every argument is forwarded unchanged. The only Cosmos-specific
    choice made here is the ``flow_shift`` default of ``1.0``. No
    ``lora`` argument is exposed, so the parent's default (``None``)
    applies.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
Methods:
fastvideo.train.models.cosmos.cosmos.CosmosModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None

Create negative (unconditional) prompt embeddings.

Cosmos 2.5 uses Reason1 (Qwen2.5-VL) which is expensive to load. This method only supports training_cfg_rate=0 (no classifier-free guidance dropout), in which case the negative embedding is never used and a zero placeholder sized to match the text embedding dimension is sufficient. training_cfg_rate>0 would require real Reason1 negative embeddings and is rejected here to avoid silently training with zero-vector "unconditional" inputs.

Source code in fastvideo/train/models/cosmos/cosmos.py
def ensure_negative_conditioning(self) -> None:
    """Install zero-valued negative-prompt embeddings for Cosmos 2.5.

    Real unconditional embeddings would require the Reason1
    (Qwen2.5-VL) text encoder, which is costly to load. This
    implementation therefore only supports ``training_cfg_rate == 0``
    (no classifier-free-guidance dropout), and uses a zero tensor of
    the right width as a stand-in. A positive ``training_cfg_rate``
    raises instead of silently training against zero-vector
    "unconditional" inputs.

    Returns immediately if ``negative_prompt_embeds`` is already set.

    Side effects:
        Sets ``negative_prompt_embeds`` to zeros of shape
        ``(1, 512, D)`` and ``negative_prompt_attention_mask`` to ones
        of shape ``(1, 512)``. Both live on ``self.device`` in the
        training dtype. ``D`` is the first text encoder's
        ``arch_config.hidden_size`` when present, else ``100352``.

    Raises:
        NotImplementedError: if ``training_cfg_rate > 0``.
    """
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return

    assert self.training_config is not None
    cfg = self.training_config

    cfg_rate = float(cfg.data.training_cfg_rate or 0.0)
    if cfg_rate > 0.0:
        raise NotImplementedError("Cosmos 2.5 currently only supports training_cfg_rate=0; "
                                  f"got training_cfg_rate={cfg_rate}. Real negative-prompt "
                                  "embeddings via Reason1 (Qwen2.5-VL) are not implemented "
                                  "yet — using the zero placeholder with CFG dropout would "
                                  "train against zero-vector \"unconditional\" inputs and "
                                  "produce wrong gradients. Set "
                                  "training.data.training_cfg_rate=0.")

    target_device = self.device
    target_dtype = self._get_training_dtype()

    # Reason1 "full_concat" width, used when the pipeline config does
    # not expose a hidden size.
    full_concat_dim = 100352
    # Reason1 default padding length.
    pad_tokens = 512

    encoder_cfgs = cfg.pipeline_config.text_encoder_configs
    if not encoder_cfgs:
        width = full_concat_dim
    else:
        width = getattr(encoder_cfgs[0].arch_config, "hidden_size", full_concat_dim)

    # Build both tensors before assigning either attribute.
    zero_embeds = torch.zeros((1, pad_tokens, width), device=target_device, dtype=target_dtype)
    full_mask = torch.ones((1, pad_tokens), device=target_device, dtype=target_dtype)

    self.negative_prompt_embeds = zero_embeds
    self.negative_prompt_attention_mask = full_mask
fastvideo.train.models.cosmos.cosmos.CosmosModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Same flow as Wan, but uses Cosmos VAE normalisation.

Source code in fastvideo/train/models/cosmos/cosmos.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Build a Cosmos ``TrainingBatch`` from a dataloader batch.

    Follows the Wan recipe. The model-specific step is that latents are
    normalised with ``normalize_dit_input("cosmos", ...)``.

    Args:
        raw_batch: Dataloader output. Reads ``"text_embedding"`` and
            ``"text_attention_mask"`` (required), ``"info_list"``
            (optional) and, for ``latents_source="data"``,
            ``"vae_latent"``.
        generator: RNG forwarded to ``_prepare_dit_inputs``.
        latents_source:
            - ``"data"``: truncates ``raw_batch["vae_latent"]`` to
              ``num_latent_t`` along dim 2.
            - ``"zeros"``: allocates zeros of shape
              ``(B, C, num_latent_t, num_height // s, num_width // s)``.
              ``C`` is the VAE's ``z_dim``, falling back to
              ``latent_channels`` and then 16.

    Returns:
        The populated batch. ``attn_metadata_vsa`` is a shallow copy of
        ``attn_metadata`` taken before the dense view's
        ``VSA_sparsity`` is set to ``0.0``.

    Raises:
        ValueError: ``"vae_latent"`` is missing for ``"data"``, or
            ``latents_source`` is unknown.
    """
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    cfg = self.training_config

    dtype = self._get_training_dtype()
    device = self.device

    batch = TrainingBatch()
    text_embeds = raw_batch["text_embedding"]
    text_mask = raw_batch["text_attention_mask"]
    info_list = raw_batch.get("info_list")

    if latents_source == "zeros":
        n_samples = text_embeds.shape[0]
        vae_arch = cfg.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
        n_channels = getattr(vae_arch, "z_dim", getattr(vae_arch, "latent_channels", 16))
        ratio = vae_arch.spatial_compression_ratio
        latents = torch.zeros(
            n_samples,
            n_channels,
            cfg.data.num_latent_t,
            cfg.data.num_height // ratio,
            cfg.data.num_width // ratio,
            device=device,
            dtype=dtype,
        )
    elif latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        latents = raw_batch["vae_latent"][:, :, :cfg.data.num_latent_t].to(device, dtype=dtype)
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    batch.latents = latents
    batch.encoder_hidden_states = text_embeds.to(device, dtype=dtype)
    batch.encoder_attention_mask = text_mask.to(device, dtype=dtype)
    batch.infos = info_list

    # Model-specific step: Cosmos latent normalisation.
    batch.latents = normalize_dit_input("cosmos", batch.latents, self.vae)
    batch = self._prepare_dit_inputs(batch, generator)
    batch = self._build_attention_metadata(batch)

    # Shallow copy on purpose: the VSA view keeps sharing the
    # lru_cache'd LongTensor index fields with the dense metadata.
    # Only the float ``VSA_sparsity`` differs between the two views, and
    # a deepcopy would duplicate the cached index tensors every step.
    dense_metadata = batch.attn_metadata
    batch.attn_metadata_vsa = copy.copy(dense_metadata)
    if dense_metadata is not None:
        dense_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

    return batch
Functions:

fastvideo.train.models.hunyuan

Hunyuan model plugin package.

Classes

Modules

fastvideo.train.models.hunyuan.hunyuan

Hunyuan model plugin (per-role instance).

Subclasses WanModel since HunyuanVideo uses the same FlowMatchEulerDiscreteScheduler and linear-interpolation noise schedule. Differences:

- transformer class name
- normalize_dit_input("hunyuan", ...) instead of ("wan", ...)
- forward kwargs: no encoder_attention_mask, no return_dict
- default flow_shift = 7

Classes
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel
HunyuanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 7.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None, lora: LoraConfig | dict[str, Any] | None = None)

Bases: WanModel

HunyuanVideo per-role model.

Inherits most behaviour from WanModel (noise scheduler, timestep sampling, attention metadata, backward). Overrides only the pieces that differ for Hunyuan.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 7.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Build a HunyuanVideo role by delegating to ``WanModel``.

    Every argument, ``lora`` included, is forwarded unchanged. The only
    Hunyuan-specific choice here is the ``flow_shift`` default of
    ``7.0``.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
        lora=lora,
    )
Methods:
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None

Encode the negative prompt with dual text encoders (LLaMA + CLIP).

Every rank encodes independently to avoid NCCL deadlocks when only a subset of ranks would otherwise participate.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def ensure_negative_conditioning(self) -> None:
    """Encode the negative prompt with dual text encoders
    (LLaMA + CLIP).

    Every rank encodes independently to avoid NCCL deadlocks
    when only a subset of ranks would otherwise participate.

    The negative prompt is the empty string. Results are stored on
    ``self.negative_prompt_embeds`` and
    ``self.negative_prompt_attention_mask``; once set, later calls
    return immediately.

    Both encoders are loaded from the resolved ``tc.model_path``:
    ``tokenizer``/``text_encoder`` for LLaMA and
    ``tokenizer_2``/``text_encoder_2`` for CLIP. Each encoder's local
    references are deleted after it has been used.

    The stored embedding has shape ``(1, 1 + L, D)``. Row 0 holds the
    pooled CLIP vector zero-padded to the LLaMA hidden size ``D``, and
    the remaining ``L`` rows hold the post-processed LLaMA states. The
    mask is all ones with shape ``(1, 1 + L)``. Both tensors are cast
    to the training dtype on ``self.device``.
    """
    # Already built on an earlier call: nothing to do.
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return

    assert self.training_config is not None
    tc = self.training_config
    device = self.device
    dtype = self._get_training_dtype()

    # Heavy encoder/tokenizer imports are kept local to this method.
    from transformers import (AutoTokenizer, CLIPTextModel, LlamaModel)

    from fastvideo.configs.pipelines.hunyuan import (
        clip_preprocess_text,
        clip_postprocess_text,
        llama_preprocess_text,
        llama_postprocess_text,
    )
    from fastvideo.utils import (PRECISION_TO_TYPE, maybe_download_model)

    # Resolve the model path to a local directory (presumably
    # downloading it if necessary — confirm in fastvideo.utils).
    model_path = maybe_download_model(tc.model_path)

    # Use configured precisions for each encoder.
    # Index 0 is used for LLaMA, index 1 for CLIP.
    precisions = tc.pipeline_config.text_encoder_precisions
    llama_dtype = PRECISION_TO_TYPE[precisions[0]]
    clip_dtype = PRECISION_TO_TYPE[precisions[1]]

    # --- LLaMA ---
    llama_tok = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    llama_enc = LlamaModel.from_pretrained(
        os.path.join(model_path, "text_encoder"),
        torch_dtype=llama_dtype,
    ).to(device).eval()

    llama_cfg = tc.pipeline_config.text_encoder_configs[0]
    # Shallow copy of the configured tokenizer kwargs.
    llama_tok_kwargs = dict(llama_cfg.tokenizer_kwargs)

    # The negative prompt is the empty string.
    negative_prompt = ""
    llama_text = llama_preprocess_text(negative_prompt)

    with torch.no_grad():
        llama_inputs = llama_tok(llama_text, **llama_tok_kwargs).to(device)
        # Hidden states are requested because llama_postprocess_text
        # consumes the full model output.
        llama_out = llama_enc(**llama_inputs, output_hidden_states=True)
        llama_embeds = llama_postprocess_text(llama_out).squeeze(0)

    # Drop the LLaMA references before loading the second encoder.
    del llama_enc, llama_tok

    # --- CLIP ---
    clip_tok = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer_2"))
    clip_enc = CLIPTextModel.from_pretrained(
        os.path.join(model_path, "text_encoder_2"),
        torch_dtype=clip_dtype,
    ).to(device).eval()

    clip_cfg = tc.pipeline_config.text_encoder_configs[1]
    clip_tok_kwargs = dict(clip_cfg.tokenizer_kwargs)
    clip_text = clip_preprocess_text(negative_prompt)

    with torch.no_grad():
        clip_inputs = clip_tok(clip_text, **clip_tok_kwargs).to(device)
        clip_out = clip_enc(**clip_inputs)
        clip_pooled = clip_postprocess_text(clip_out).squeeze(0)

    del clip_enc, clip_tok

    # --- Combine: [pooled_clip_row, llama_embeds] ---
    # Row 0 is the pooled CLIP vector zero-padded to the LLaMA width; the
    # slice assignment requires the CLIP dim to be <= llama_dim.
    llama_dim = llama_embeds.shape[-1]
    pooled_row = torch.zeros(llama_dim, device=device)
    pooled_row[:clip_pooled.shape[-1]] = clip_pooled
    neg_embeds = torch.cat(
        [pooled_row.unsqueeze(0), llama_embeds],
        dim=0,
    ).unsqueeze(0).to(device=device, dtype=dtype)

    # Attention mask: all ones for the combined sequence.
    neg_mask = torch.ones(neg_embeds.shape[:2], device=device, dtype=dtype)

    self.negative_prompt_embeds = neg_embeds
    self.negative_prompt_attention_mask = neg_mask
fastvideo.train.models.hunyuan.hunyuan.HunyuanModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Same flow as Wan, but uses Hunyuan VAE normalisation.

Source code in fastvideo/train/models/hunyuan/hunyuan.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Build a HunyuanVideo ``TrainingBatch`` from a dataloader batch.

    Follows the Wan recipe. The model-specific step is that latents are
    normalised with ``normalize_dit_input("hunyuan", ...)``.

    Args:
        raw_batch: Dataloader output. Reads ``"text_embedding"`` and
            ``"text_attention_mask"`` (required), ``"info_list"``
            (optional) and, for ``latents_source="data"``,
            ``"vae_latent"``.
        generator: RNG forwarded to ``_prepare_dit_inputs``.
        latents_source:
            - ``"data"``: truncates ``raw_batch["vae_latent"]`` to
              ``num_latent_t`` along dim 2.
            - ``"zeros"``: allocates zeros of shape
              ``(B, C, num_latent_t, num_height // s, num_width // s)``.
              ``C`` is the VAE's ``latent_channels``, falling back to
              ``z_dim`` and then 16.

    Returns:
        The populated batch. ``attn_metadata_vsa`` is a shallow copy of
        ``attn_metadata`` taken before the dense view's
        ``VSA_sparsity`` is set to ``0.0``.

    Raises:
        ValueError: ``"vae_latent"`` is missing for ``"data"``, or
            ``latents_source`` is unknown.
    """
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    cfg = self.training_config

    dtype = self._get_training_dtype()
    device = self.device

    batch = TrainingBatch()
    text_embeds = raw_batch["text_embedding"]
    text_mask = raw_batch["text_attention_mask"]
    info_list = raw_batch.get("info_list")

    if latents_source == "zeros":
        n_samples = text_embeds.shape[0]
        vae_arch = cfg.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
        n_channels = getattr(vae_arch, "latent_channels", getattr(vae_arch, "z_dim", 16))
        ratio = vae_arch.spatial_compression_ratio
        latents = torch.zeros(
            n_samples,
            n_channels,
            cfg.data.num_latent_t,
            cfg.data.num_height // ratio,
            cfg.data.num_width // ratio,
            device=device,
            dtype=dtype,
        )
    elif latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        latents = raw_batch["vae_latent"][:, :, :cfg.data.num_latent_t].to(device, dtype=dtype)
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    batch.latents = latents
    batch.encoder_hidden_states = text_embeds.to(device, dtype=dtype)
    batch.encoder_attention_mask = text_mask.to(device, dtype=dtype)
    batch.infos = info_list

    # Model-specific step: Hunyuan latent normalisation.
    batch.latents = normalize_dit_input("hunyuan", batch.latents, self.vae)
    batch = self._prepare_dit_inputs(batch, generator)
    batch = self._build_attention_metadata(batch)

    # Shallow copy on purpose: the VSA view keeps sharing the
    # lru_cache'd LongTensor index fields with the dense metadata.
    # Only the float ``VSA_sparsity`` differs between the two views, and
    # a deepcopy would duplicate the cached index tensors every step.
    dense_metadata = batch.attn_metadata
    batch.attn_metadata_vsa = copy.copy(dense_metadata)
    if dense_metadata is not None:
        dense_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

    return batch

fastvideo.train.models.longcat

LongCat model plugin package.

Classes

Modules

fastvideo.train.models.longcat.longcat

LongCat model plugin (per-role instance).

Classes
fastvideo.train.models.longcat.longcat.LongCatModel
LongCatModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 12.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

LongCat per-role model for training and distillation.

Source code in fastvideo/train/models/longcat/longcat.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 12.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Build a LongCat role on top of ``WanModel``.

    ``flow_shift`` (default ``12.0``) is passed through
    ``_validate_flow_shift`` before being forwarded. Every other
    argument reaches the parent constructor unchanged. No ``lora``
    argument is exposed, so the parent's default (``None``) applies.
    """
    checked_shift = self._validate_flow_shift(flow_shift)
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=checked_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
Methods:
fastvideo.train.models.longcat.longcat.LongCatModel.predict_noise
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Adapt LongCat's sign convention to FineTuneMethod's target.

LongCatTransformer3DModel is pretrained to output the clean - noise direction; LongCatDenoisingStage (the bidirectional inference pipeline) explicitly negates the transformer output before handing it to FlowMatchEulerDiscreteScheduler.step. Training methods on the other hand (FineTuneMethod, DiffusionForcingSFTMethod) target noise - clean directly (the standard rectified-flow velocity Wan uses).

Without the negation here, the MSE loss pushes the transformer toward noise - clean, flipping its native output sign over training. Inference then applies its own negation on top, so the scheduler receives the wrong direction and produces noise even while the training loss is dropping. This was verified empirically on a 100-step LongCat overfit run: step 0 generated meaningful video, but step 100 was pure noise despite the low loss.

Negating in predict_noise keeps the transformer's pretrained sign convention intact while presenting the training methods with a Wan-compatible pred ≈ noise - clean for MSE.

Source code in fastvideo/train/models/longcat/longcat.py
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Return the parent prediction with its sign flipped.

    Why: ``LongCatTransformer3DModel`` was pretrained to emit the
    ``clean - noise`` direction. The bidirectional inference path
    (``LongCatDenoisingStage``) negates that output before calling
    ``FlowMatchEulerDiscreteScheduler.step``. The training methods
    (``FineTuneMethod``, ``DiffusionForcingSFTMethod``) instead regress
    onto ``noise - clean``, the rectified-flow velocity Wan uses.

    Training the raw output against that target would let the MSE
    gradually invert the transformer's native sign. Inference would
    then negate it a second time, handing the scheduler the wrong
    direction and producing noise even as the loss falls. This was
    observed on a 100-step LongCat overfit run: step 0 produced
    meaningful video, while step 100 produced pure noise at low loss.

    Negating here leaves the pretrained sign convention untouched and
    gives training methods a Wan-compatible ``pred ≈ noise - clean``.
    """
    native_pred = super().predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    # clean - noise  ->  noise - clean
    return -native_pred

fastvideo.train.models.wan

Wan model plugin package.

Classes

Modules

fastvideo.train.models.wan.wan

Wan model plugin (per-role instance).

Classes
fastvideo.train.models.wan.wan.WanModel
WanModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None, lora: LoraConfig | dict[str, Any] | None = None)

Bases: ModelBase

Wan per-role model: owns transformer + noise_scheduler.

Source code in fastvideo/train/models/wan/wan.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Load the Wan transformer and set up per-role state.

    Args:
        init_from: Source the transformer is loaded from; stored as
            ``str`` in ``_init_from``.
        training_config: Training configuration; passed to the
            transformer loader and stored on ``self.training_config``.
        trainable: Forwarded to the base-class constructor.
        disable_custom_init_weights: Forwarded to ``_load_transformer``.
        flow_shift: Shift for ``FlowMatchEulerDiscreteScheduler``; also
            stored as ``timestep_shift``.
        enable_gradient_checkpointing_type: Forwarded to
            ``_load_transformer``.
        transformer_override_safetensor: Forwarded to
            ``_load_transformer``.
        lora: Forwarded to the base-class constructor.
    """
    super().__init__(trainable=trainable, lora=lora)
    self._init_from = str(init_from)

    # Trainability uses the flag normalised by the base constructor.
    self.transformer = self._load_transformer(
        init_from=self._init_from,
        trainable=self._trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        training_config=training_config,
        transformer_override_safetensor=transformer_override_safetensor,
    )

    shift = float(flow_shift)
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)

    self.training_config: TrainingConfig = training_config

    # Placeholders populated by ``init_preprocessors`` (student only).
    self.vae: Any = None
    self.dataloader: Any = None
    self.validator: Any = None
    self.start_step: int = 0

    # Presumably distributed process groups; assigned elsewhere.
    self.world_group: Any = None
    self.sp_group: Any = None

    # Negative-prompt conditioning, built lazily on first use.
    self.negative_prompt_embeds: torch.Tensor | None = None
    self.negative_prompt_attention_mask: torch.Tensor | None = None

    # Timestep mechanics derived from the scheduler.
    self.timestep_shift: float = shift
    self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
    self.min_timestep: int = 0
    self.max_timestep: int = self.num_train_timestep
Functions:
fastvideo.train.models.wan.wan_causal

Wan causal model plugin (per-role instance, streaming/cache).

Classes
fastvideo.train.models.wan.wan_causal.WanCausalModel
WanCausalModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None, lora: LoraConfig | dict[str, Any] | None = None)

Bases: WanModel, CausalModelBase

Wan per-role model with causal/streaming primitives.

Source code in fastvideo/train/models/wan/wan_causal.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
    lora: LoraConfig | dict[str, Any] | None = None,
) -> None:
    """Create a causal Wan role.

    Forwards every argument unchanged to the next ``__init__`` in the
    MRO (``WanModel`` first, since it is the leading base). It then
    attaches an empty per-instance streaming-cache table.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
        lora=lora,
    )
    # Keyed by ``(int, str)``. The str is presumably the ``cache_tag``
    # and the int some per-stream index — confirm in the streaming code.
    self._streaming_caches: dict[tuple[int, str], _StreamingCaches] = dict()
Functions: