Skip to content

cosmos

Cosmos model plugin package.

Classes

Modules

fastvideo.train.models.cosmos.cosmos

Cosmos model plugin (per-role instance).

Subclasses WanModel since Cosmos uses the same FlowMatchEulerDiscreteScheduler. Differences:

- transformer class name: CosmosTransformer3DModel
- normalize_dit_input("cosmos", ...) instead of ("wan", ...)
- forward kwargs: no encoder_attention_mask; needs condition_mask + padding_mask + fps
- hidden_states in (B, C, T, H, W) — no permute needed
- default flow_shift = 1.0
- single T5 text encoder (not dual like Hunyuan)

Classes

fastvideo.train.models.cosmos.cosmos.CosmosModel
CosmosModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 1.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)

Bases: WanModel

Cosmos 2.5 per-role model.

Inherits most behaviour from WanModel (noise scheduler, timestep sampling, attention metadata, backward). Overrides only the pieces that differ for Cosmos 2.5.

Cosmos 2.5 uses:

- Cosmos25Transformer3DModel (velocity prediction)
- EDM noise schedule: `x_t = x_0 + sigma * eps`
- No input/output preconditioning (raw latents)
- Timestep = raw sigma value
- Model output = velocity ≈ noise

Source code in fastvideo/train/models/cosmos/cosmos.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 1.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Create a Cosmos per-role model.

    All arguments are keyword-only and are handed unchanged to the
    parent ``WanModel`` constructor; this class adds no extra state of
    its own here. ``flow_shift`` defaults to ``1.0`` for Cosmos.

    Args:
        init_from: Checkpoint / model source forwarded to ``WanModel``.
        training_config: Training configuration forwarded to ``WanModel``.
        trainable: Forwarded to ``WanModel``.
        disable_custom_init_weights: Forwarded to ``WanModel``.
        flow_shift: Forwarded to ``WanModel``; defaults to ``1.0``.
        enable_gradient_checkpointing_type: Forwarded to ``WanModel``.
        transformer_override_safetensor: Forwarded to ``WanModel``.
    """
    super().__init__(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
Functions
fastvideo.train.models.cosmos.cosmos.CosmosModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None

Create negative (unconditional) prompt embeddings.

Cosmos 2.5 uses Reason1 (Qwen2.5-VL) which is expensive to load. This method only supports training_cfg_rate=0 (no classifier-free guidance dropout), in which case the negative embedding is never used and a zero placeholder sized to match the text embedding dimension is sufficient. training_cfg_rate>0 would require real Reason1 negative embeddings and is rejected here to avoid silently training with zero-vector "unconditional" inputs.

Source code in fastvideo/train/models/cosmos/cosmos.py
def ensure_negative_conditioning(self) -> None:
    """Lazily install placeholder negative-prompt conditioning.

    If ``self.negative_prompt_embeds`` is already set, this is a no-op.
    Otherwise it stores:

    * ``negative_prompt_embeds``: zeros of shape ``(1, 512, width)``
    * ``negative_prompt_attention_mask``: ones of shape ``(1, 512)``

    Both are created on ``self.device`` in the training dtype.

    Real negative embeddings would require running the Reason1
    (Qwen2.5-VL) text encoder, which is expensive to load. Therefore only
    ``training_cfg_rate=0`` (no classifier-free-guidance dropout) is
    supported. In that setting the zero placeholder is never fed to the
    model as an "unconditional" training input.

    Raises:
        NotImplementedError: If ``training_cfg_rate`` is positive, because
            training against zero-vector "unconditional" inputs would
            silently produce wrong gradients.
    """
    # Fast path: already initialised (by an earlier call or by the caller).
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return

    assert self.training_config is not None
    config = self.training_config

    # An unset / falsy rate is treated as 0.0.
    cfg_rate = float(config.data.training_cfg_rate or 0.0)
    if cfg_rate > 0.0:
        raise NotImplementedError("Cosmos 2.5 currently only supports training_cfg_rate=0; "
                                  f"got training_cfg_rate={cfg_rate}. Real negative-prompt "
                                  "embeddings via Reason1 (Qwen2.5-VL) are not implemented "
                                  "yet — using the zero placeholder with CFG dropout would "
                                  "train against zero-vector \"unconditional\" inputs and "
                                  "produce wrong gradients. Set "
                                  "training.data.training_cfg_rate=0.")

    target_device = self.device
    target_dtype = self._get_training_dtype()

    # Placeholder width:
    # - the first text encoder's ``hidden_size``, if its arch config has one;
    # - otherwise the Reason1 full_concat width (100352).
    # NOTE(review): if the arch config exposes a per-layer ``hidden_size``,
    # this may be narrower than the full_concat embeddings actually fed to
    # the model; confirm the two agree if the placeholder is ever consumed.
    fallback_width = 100352
    encoder_cfgs = config.pipeline_config.text_encoder_configs
    width = (getattr(encoder_cfgs[0].arch_config, "hidden_size", fallback_width)
             if encoder_cfgs else fallback_width)

    seq_len = 512  # Reason1 default padding length

    # Build both tensors before touching ``self`` so a failure leaves the
    # attributes unset.
    zero_embeds = torch.zeros((1, seq_len, width), device=target_device, dtype=target_dtype)
    full_mask = torch.ones((1, seq_len), device=target_device, dtype=target_dtype)
    self.negative_prompt_embeds, self.negative_prompt_attention_mask = zero_embeds, full_mask
fastvideo.train.models.cosmos.cosmos.CosmosModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch

Same flow as Wan, but uses Cosmos VAE normalisation.

Source code in fastvideo/train/models/cosmos/cosmos.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Turn a raw dataloader batch into a ``TrainingBatch``.

    Follows the same steps as the Wan model, but normalises the latents
    with the ``"cosmos"`` variant of ``normalize_dit_input``.

    Args:
        raw_batch: Mapping with the following entries:

            * ``"text_embedding"`` and ``"text_attention_mask"`` tensors
              (required);
            * ``"info_list"`` (optional);
            * ``"vae_latent"`` tensor (required when
              ``latents_source == "data"``).

        generator: RNG passed on to ``self._prepare_dit_inputs``.
        latents_source: How the latents are obtained:

            * ``"data"``: use ``raw_batch["vae_latent"]`` truncated
              along dim 2 to ``num_latent_t``.
            * ``"zeros"``: synthesise an all-zero latent of shape
              ``(B, C, num_latent_t, num_height // r, num_width // r)``.
              ``B`` is the text-embedding batch size. ``C`` is the VAE's
              ``z_dim`` (else ``latent_channels``, else 16). ``r`` is the
              VAE spatial compression ratio.

    Returns:
        The batch after ``_prepare_dit_inputs`` and
        ``_build_attention_metadata``. ``attn_metadata_vsa`` holds a
        shallow copy of the attention metadata. ``attn_metadata`` itself
        has ``VSA_sparsity`` set to 0.0 when it is not ``None``.

    Raises:
        ValueError: If ``"vae_latent"`` is missing in ``"data"`` mode, or
            if ``latents_source`` is not a recognised value.
    """
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    cfg = self.training_config

    target_dtype = self._get_training_dtype()
    target_device = self.device

    batch = TrainingBatch()
    text_embeds = raw_batch["text_embedding"]
    text_mask = raw_batch["text_attention_mask"]
    info_list = raw_batch.get("info_list")

    if latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        # Keep only the first ``num_latent_t`` entries along dim 2.
        frames = cfg.data.num_latent_t
        latents = raw_batch["vae_latent"][:, :, :frames]
        latents = latents.to(target_device, dtype=target_dtype)
    elif latents_source == "zeros":
        n_samples = text_embeds.shape[0]
        vae_arch = cfg.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
        # Channel count: prefer ``z_dim``, then ``latent_channels``, then 16.
        channels = getattr(vae_arch, "z_dim", getattr(vae_arch, "latent_channels", 16))
        ratio = vae_arch.spatial_compression_ratio
        zero_shape = (
            n_samples,
            channels,
            cfg.data.num_latent_t,
            cfg.data.num_height // ratio,
            cfg.data.num_width // ratio,
        )
        latents = torch.zeros(zero_shape, device=target_device, dtype=target_dtype)
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    batch.latents = latents
    batch.encoder_hidden_states = text_embeds.to(target_device, dtype=target_dtype)
    batch.encoder_attention_mask = text_mask.to(target_device, dtype=target_dtype)
    batch.infos = info_list

    # Cosmos-specific step: "cosmos" latent normalisation rather than "wan".
    batch.latents = normalize_dit_input("cosmos", batch.latents, self.vae)
    batch = self._prepare_dit_inputs(batch, generator)
    batch = self._build_attention_metadata(batch)

    # Two views of the attention metadata:
    # - ``attn_metadata_vsa`` keeps the VSA_sparsity value that
    #   _build_attention_metadata produced;
    # - ``attn_metadata`` gets VSA_sparsity = 0.0.
    # A shallow copy is enough: only the float ``VSA_sparsity`` differs, and
    # the lru_cache'd LongTensor index fields stay shared. A deepcopy would
    # re-materialise all of those cached index tensors on every training
    # step.
    base_meta = batch.attn_metadata
    batch.attn_metadata_vsa = copy.copy(base_meta)
    if base_meta is not None:
        base_meta.VSA_sparsity = 0.0  # type: ignore[attr-defined]

    return batch

Functions