Skip to content

wan_causal

Wan causal model plugin (per-role instance, streaming/cache).

Classes

fastvideo.train.models.wan.wan_causal.WanCausalModel

WanCausalModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 3.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel, CausalModelBase

Wan per-role model with causal/streaming primitives.

Source code in fastvideo/train/models/wan/wan_causal.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 3.0,
    enable_gradient_checkpointing_type: str
    | None = None,
    transformer_override_safetensor: str
    | None = None,
) -> None:
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=(disable_custom_init_weights),
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type),
        transformer_override_safetensor=(transformer_override_safetensor),
    )
    self._streaming_caches: (dict[tuple[int, str], _StreamingCaches]) = {}

Functions