longcat

LongCat model plugin package.

Classes

Modules

fastvideo.train.models.longcat.longcat

LongCat model plugin (per-role instance).

Classes

fastvideo.train.models.longcat.longcat.LongCatModel
LongCatModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 12.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

LongCat per-role model for training and distillation.

Source code in fastvideo/train/models/longcat/longcat.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 12.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=self._validate_flow_shift(flow_shift),
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
Functions
fastvideo.train.models.longcat.longcat.LongCatModel.predict_noise
predict_noise(noisy_latents: Tensor, timestep: Tensor, batch: TrainingBatch, *, conditional: bool, cfg_uncond: dict[str, Any] | None = None, attn_kind: Literal['dense', 'vsa'] = 'dense') -> Tensor

Adapt LongCat's sign convention to FineTuneMethod's target.

LongCatTransformer3DModel is pretrained to output the clean - noise direction; LongCatDenoisingStage (the bidirectional inference pipeline) explicitly negates the transformer output before handing it to FlowMatchEulerDiscreteScheduler.step. The training methods (FineTuneMethod, DiffusionForcingSFTMethod), by contrast, target noise - clean directly (the standard rectified-flow velocity that Wan uses).

Without the negation here, the MSE loss pushes the transformer toward noise - clean, flipping its native output sign over the course of training. Inference then applies its own negation on top, so the scheduler receives the wrong direction and produces noise even while the training loss is dropping. This was verified empirically on a 100-step LongCat overfit run: step 0 generated meaningful video, while step 100 was pure noise despite a low loss.

Negating in predict_noise keeps the transformer's pretrained sign convention intact while presenting the training methods with a Wan-compatible pred ≈ noise - clean for MSE.
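
The sign argument above can be checked on a toy problem. The following is a minimal sketch in plain Python, not the FastVideo implementation: `native_transformer`, `euler_step`, and the list-based "latents" are hypothetical stand-ins, and the Euler update is assumed to take the simplified form `x + (sigma_next - sigma) * v`. It shows that negating a `clean - noise` output matches the `noise - clean` training target, and that a transformer whose sign has flipped sends a single inference step away from the clean sample.

```python
import random


def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def native_transformer(clean, noise):
    """Hypothetical stand-in for an ideally pretrained LongCat transformer.

    Assumption: it outputs exactly the ``clean - noise`` direction.
    """
    return [c - n for c, n in zip(clean, noise)]


def predict_noise(clean, noise):
    """Toy analogue of LongCatModel.predict_noise: negate the native output."""
    return [-v for v in native_transformer(clean, noise)]


def euler_step(sample, velocity, sigma, sigma_next):
    """Simplified flow-matching Euler update (an assumption of this sketch)."""
    return [x + (sigma_next - sigma) * v for x, v in zip(sample, velocity)]


random.seed(0)
clean = [random.gauss(0.0, 1.0) for _ in range(8)]
noise = [random.gauss(0.0, 1.0) for _ in range(8)]

# Training target used by the fine-tuning methods: noise - clean.
target = [n - c for n, c in zip(noise, clean)]

# With the negation, the prediction already matches the target;
# without it, MSE would pull the transformer toward the opposite sign.
loss_with_negation = mse(predict_noise(clean, noise), target)
loss_without_negation = mse(native_transformer(clean, noise), target)

# Inference negates the native output, then steps from sigma=1 (pure
# noise) to sigma=0. With the pretrained sign this lands on `clean`.
velocity = [-v for v in native_transformer(clean, noise)]
recovered = euler_step(noise, velocity, sigma=1.0, sigma_next=0.0)
recovery_error = max(abs(r - c) for r, c in zip(recovered, clean))

# A transformer whose sign was flipped by training (it now outputs
# noise - clean natively) is still negated at inference, so the step
# moves away from `clean` instead of toward it.
flipped_velocity = [-(n - c) for n, c in zip(noise, clean)]
drifted = euler_step(noise, flipped_velocity, sigma=1.0, sigma_next=0.0)
drift_error = max(abs(d - c) for d, c in zip(drifted, clean))

print(loss_with_negation, loss_without_negation)
print(recovery_error, drift_error)
```

In this sketch `loss_with_negation` and `recovery_error` are negligible, while `loss_without_negation` and `drift_error` are clearly nonzero, which mirrors the "low loss but pure noise" failure described above.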

Source code in fastvideo/train/models/longcat/longcat.py
def predict_noise(
    self,
    noisy_latents: torch.Tensor,
    timestep: torch.Tensor,
    batch: TrainingBatch,
    *,
    conditional: bool,
    cfg_uncond: dict[str, Any] | None = None,
    attn_kind: Literal["dense", "vsa"] = "dense",
) -> torch.Tensor:
    """Adapt LongCat's sign convention to FineTuneMethod's target.

    ``LongCatTransformer3DModel`` is pretrained to output the
    ``clean - noise`` direction; ``LongCatDenoisingStage`` (the
    bidirectional inference pipeline) explicitly negates the
    transformer output before handing it to
    ``FlowMatchEulerDiscreteScheduler.step``. The training methods
    (``FineTuneMethod``, ``DiffusionForcingSFTMethod``), by
    contrast, target ``noise - clean`` directly (the standard
    rectified-flow velocity that Wan uses).

    Without the negation here, the MSE loss pushes the transformer
    toward ``noise - clean``, flipping its native output sign over
    the course of training. Inference then applies its own negation
    on top, so the scheduler receives the wrong direction and
    produces noise even while the training loss is dropping. This
    was verified empirically on a 100-step LongCat overfit run:
    step 0 generated meaningful video, while step 100 was pure
    noise despite a low loss.

    Negating in ``predict_noise`` keeps the transformer's
    pretrained sign convention intact while presenting the
    training methods with a Wan-compatible
    ``pred ≈ noise - clean`` for MSE.
    """
    pred = super().predict_noise(
        noisy_latents,
        timestep,
        batch,
        conditional=conditional,
        cfg_uncond=cfg_uncond,
        attn_kind=attn_kind,
    )
    return -pred