Cosmos model plugin (per-role instance).
Subclasses WanModel since Cosmos uses the same
FlowMatchEulerDiscreteScheduler. Differences:
- transformer class name: CosmosTransformer3DModel
- normalize_dit_input("cosmos", ...) instead of ("wan", ...)
- forward kwargs: no encoder_attention_mask, needs
condition_mask + padding_mask + fps
- hidden_states in (B,C,T,H,W) — no permute needed
- default flow_shift = 1.0
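- single T5 text encoder (not dual like Hunyuan). NOTE: the Cosmos 2.5 notes further down name Reason1 (Qwen2.5-VL) as the text encoder, so confirm which encoder applies.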
Classes
fastvideo.train.models.cosmos.cosmos.CosmosModel
CosmosModel(*, init_from: str, training_config: TrainingConfig, trainable: bool = True, disable_custom_init_weights: bool = False, flow_shift: float = 1.0, enable_gradient_checkpointing_type: str | None = None, transformer_override_safetensor: str | None = None)
Bases: WanModel
Cosmos 2.5 per-role model.
Inherits most behaviour from WanModel (noise scheduler,
timestep sampling, attention metadata, backward). Overrides
only the pieces that differ for Cosmos 2.5.
Cosmos 2.5 uses:
- Cosmos25Transformer3DModel (velocity prediction)
- EDM noise schedule: x_t = x_0 + sigma * eps
- No input/output preconditioning (raw latents)
- Timestep = raw sigma value
- Model output = velocity ≈ noise
Source code in fastvideo/train/models/cosmos/cosmos.py
def __init__(
    self,
    *,
    init_from: str,
    training_config: TrainingConfig,
    trainable: bool = True,
    disable_custom_init_weights: bool = False,
    flow_shift: float = 1.0,
    enable_gradient_checkpointing_type: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> None:
    """Construct a Cosmos 2.5 per-role model.

    All arguments are keyword-only and are forwarded unchanged to the
    parent (``WanModel``) constructor; this class adds no extra state here.
    """
    forwarded = dict(
        init_from=init_from,
        training_config=training_config,
        trainable=trainable,
        disable_custom_init_weights=disable_custom_init_weights,
        flow_shift=flow_shift,
        enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
        transformer_override_safetensor=transformer_override_safetensor,
    )
    super().__init__(**forwarded)
|
Functions
fastvideo.train.models.cosmos.cosmos.CosmosModel.ensure_negative_conditioning
ensure_negative_conditioning() -> None
Create negative (unconditional) prompt embeddings.
Cosmos 2.5 uses Reason1 (Qwen2.5-VL) which is expensive
to load. This method only supports training_cfg_rate=0
(no classifier-free guidance dropout), in which case the
negative embedding is never used and a zero placeholder
sized to match the text embedding dimension is sufficient.
training_cfg_rate>0 would require real Reason1 negative
embeddings and is rejected here to avoid silently training
with zero-vector "unconditional" inputs.
Source code in fastvideo/train/models/cosmos/cosmos.py
def ensure_negative_conditioning(self) -> None:
    """Create negative (unconditional) prompt embeddings for Cosmos 2.5.

    Real negative embeddings would come from Reason1 (Qwen2.5-VL), which is
    expensive to load and is not wired up here. Only
    ``training_cfg_rate=0`` (no classifier-free guidance dropout) is
    supported. In that case the negative embedding is never consumed, so a
    zero placeholder is sufficient. A positive ``training_cfg_rate`` raises
    instead of silently training against zero-vector "unconditional"
    inputs.

    The call is a no-op once ``self.negative_prompt_embeds`` is set.
    Otherwise it stores the following, on ``self.device`` and in the
    training dtype:

    * ``self.negative_prompt_embeds`` -- zeros of shape ``(1, 512, dim)``
    * ``self.negative_prompt_attention_mask`` -- ones of shape ``(1, 512)``

    Raises:
        NotImplementedError: if ``training_config.data.training_cfg_rate``
            is greater than zero.
    """
    if self.negative_prompt_embeds is not None:  # type: ignore[has-type]
        return  # embeddings already present; keep them

    assert self.training_config is not None
    cfg = self.training_config

    # A falsy rate (None or 0) is treated as 0.0.
    cfg_rate = float(cfg.data.training_cfg_rate or 0.0)
    if cfg_rate > 0.0:
        raise NotImplementedError(
            "Cosmos 2.5 currently only supports training_cfg_rate=0; "
            f"got training_cfg_rate={cfg_rate}. Real negative-prompt "
            "embeddings via Reason1 (Qwen2.5-VL) are not implemented "
            "yet — using the zero placeholder with CFG dropout would "
            "train against zero-vector \"unconditional\" inputs and "
            "produce wrong gradients. Set "
            "training.data.training_cfg_rate=0.")

    device = self.device
    dtype = self._get_training_dtype()

    # Token width: the first text encoder's ``arch_config.hidden_size`` when
    # that attribute exists, else 100352 (Reason1 full_concat width).
    # NOTE(review): confirm ``hidden_size`` matches the width of the real
    # text embeddings. A mismatch is harmless while cfg_rate == 0, but it
    # would matter once CFG dropout is supported.
    fallback_dim = 100352
    encoder_cfgs = cfg.pipeline_config.text_encoder_configs
    embed_dim = (getattr(encoder_cfgs[0].arch_config, "hidden_size",
                         fallback_dim) if encoder_cfgs else fallback_dim)

    pad_len = 512  # Reason1 default padding length
    placeholder = torch.zeros((1, pad_len, embed_dim),
                              device=device,
                              dtype=dtype)
    mask = torch.ones((1, pad_len), device=device, dtype=dtype)

    self.negative_prompt_embeds = placeholder
    self.negative_prompt_attention_mask = mask
|
fastvideo.train.models.cosmos.cosmos.CosmosModel.prepare_batch
prepare_batch(raw_batch: dict[str, Any], *, generator: Generator, latents_source: Literal['data', 'zeros'] = 'data') -> TrainingBatch
Same flow as Wan, but uses Cosmos VAE
normalisation.
Source code in fastvideo/train/models/cosmos/cosmos.py
def prepare_batch(
    self,
    raw_batch: dict[str, Any],
    *,
    generator: torch.Generator,
    latents_source: Literal["data", "zeros"] = "data",
) -> TrainingBatch:
    """Build a ``TrainingBatch`` from a raw dataloader batch.

    Same flow as Wan; the Cosmos-specific step is normalising the latents
    with ``normalize_dit_input("cosmos", ...)``.

    Args:
        raw_batch: Must contain ``"text_embedding"`` and
            ``"text_attention_mask"``. ``"vae_latent"`` is required when
            ``latents_source == "data"``. ``"info_list"`` is optional.
        generator: RNG forwarded to ``self._prepare_dit_inputs``.
        latents_source: ``"data"`` takes ``raw_batch["vae_latent"]`` sliced
            to the first ``num_latent_t`` entries along dim 2. ``"zeros"``
            builds an all-zero latent sized from the VAE arch config and the
            training data dimensions.

    Returns:
        The populated ``TrainingBatch``, including DiT inputs and attention
        metadata.

    Raises:
        ValueError: if ``latents_source`` is unrecognised, or is ``"data"``
            while ``raw_batch`` lacks ``"vae_latent"``.
    """
    self.ensure_negative_conditioning()
    assert self.training_config is not None
    cfg = self.training_config
    dtype = self._get_training_dtype()
    device = self.device

    batch = TrainingBatch()
    text_embeds = raw_batch["text_embedding"]
    text_mask = raw_batch["text_attention_mask"]
    infos = raw_batch.get("info_list")

    if latents_source == "data":
        if "vae_latent" not in raw_batch:
            raise ValueError("vae_latent not found in batch "
                             "and latents_source='data'")
        latents = raw_batch["vae_latent"][:, :, :cfg.data.num_latent_t]
        latents = latents.to(device, dtype=dtype)
    elif latents_source == "zeros":
        n = text_embeds.shape[0]
        vae_arch = cfg.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
        # Prefer ``z_dim``, then ``latent_channels``, else 16 channels.
        channels = getattr(vae_arch, "z_dim",
                           getattr(vae_arch, "latent_channels", 16))
        ratio = vae_arch.spatial_compression_ratio
        height = cfg.data.num_height // ratio
        width = cfg.data.num_width // ratio
        latents = torch.zeros(
            n,
            channels,
            cfg.data.num_latent_t,
            height,
            width,
            device=device,
            dtype=dtype,
        )
    else:
        raise ValueError(f"Unknown latents_source: "
                         f"{latents_source!r}")

    batch.latents = latents
    batch.encoder_hidden_states = text_embeds.to(device, dtype=dtype)
    batch.encoder_attention_mask = text_mask.to(device, dtype=dtype)
    batch.infos = infos

    # Cosmos-specific: "cosmos" normalisation instead of Wan's "wan".
    batch.latents = normalize_dit_input("cosmos", batch.latents, self.vae)

    batch = self._prepare_dit_inputs(batch, generator)
    batch = self._build_attention_metadata(batch)

    # The VSA view is a shallow copy, so it shares the lru_cache'd LongTensor
    # index fields with the base metadata instead of duplicating them every
    # step (as deepcopy would). Only the float ``VSA_sparsity`` differs: it
    # is zeroed on the base view below.
    batch.attn_metadata_vsa = copy.copy(batch.attn_metadata)
    if batch.attn_metadata is not None:
        batch.attn_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]
    return batch
|
Functions