
denoising

Denoising stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.denoising.Cosmos25AutoDenoisingStage

Cosmos25AutoDenoisingStage(transformer, scheduler)

Bases: PipelineStage

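Routes Cosmos 2.5 denoising either to the Text2World (T2W) stage or to the Video2World/Image2World (V2W/I2W) stage. Both sub-stages are constructed up front and share the same transformer and scheduler.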

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    super().__init__()
    self._t2w = Cosmos25T2WDenoisingStage(transformer=transformer, scheduler=scheduler)
    self._v2w = Cosmos25V2WDenoisingStage(transformer=transformer, scheduler=scheduler)
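The constructor above builds both sub-stages eagerly, so routing only has to pick one and delegate to it. The sketch below shows that delegation pattern. The routing rule (any video or image conditioning selects the V2W/I2W path) and the names `SketchBatch`, `RecordingStage`, and `AutoRouter` are hypothetical; this page does not show how the real stage decides.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class SketchBatch:
    """Hypothetical stand-in for ForwardBatch with only the fields the router reads."""
    video_latent: Optional[Any] = None
    image_latent: Optional[Any] = None
    visited: list[str] = field(default_factory=list)


class RecordingStage:
    """Hypothetical sub-stage that only records that it ran."""

    def __init__(self, name: str) -> None:
        self.name = name

    def forward(self, batch: SketchBatch, args: Any) -> SketchBatch:
        batch.visited.append(self.name)
        return batch


class AutoRouter:
    """Delegates to one of two pre-built sub-stages (hypothetical routing rule)."""

    def __init__(self, t2w: RecordingStage, v2w: RecordingStage) -> None:
        self._t2w = t2w
        self._v2w = v2w

    def forward(self, batch: SketchBatch, args: Any) -> SketchBatch:
        # Assumption: any video or image conditioning selects the V2W/I2W path.
        conditioned = batch.video_latent is not None or batch.image_latent is not None
        stage = self._v2w if conditioned else self._t2w
        return stage.forward(batch, args)


router = AutoRouter(RecordingStage("t2w"), RecordingStage("v2w"))
print(router.forward(SketchBatch(), None).visited)                    # ['t2w']
print(router.forward(SketchBatch(image_latent=[0.0]), None).visited)  # ['v2w']
```

Building both sub-stages once keeps per-call routing cheap; because they share one transformer, no weights are duplicated.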

fastvideo.pipelines.stages.denoising.Cosmos25DenoisingStage

Cosmos25DenoisingStage(transformer, scheduler, pipeline=None)

Bases: CosmosDenoisingStage

Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D).

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

fastvideo.pipelines.stages.denoising.Cosmos25T2WDenoisingStage

Cosmos25T2WDenoisingStage(transformer, scheduler, pipeline=None)

Bases: Cosmos25DenoisingStage

Cosmos 2.5 Text2World denoising stage.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

fastvideo.pipelines.stages.denoising.Cosmos25V2WDenoisingStage

Cosmos25V2WDenoisingStage(transformer, scheduler, pipeline=None)

Bases: Cosmos25DenoisingStage

Cosmos 2.5 Video2World denoising stage.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

fastvideo.pipelines.stages.denoising.CosmosDenoisingStage

CosmosDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for Cosmos models.

Uses FlowMatchEulerDiscreteScheduler with manual EDM preconditioning (c_in, c_skip, c_out) to match the pretrained Cosmos model's training convention.
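For reference, the EDM preconditioning of Karras et al. (2022) wraps a raw network F as D(x; σ) = c_skip·x + c_out·F(c_in·x, ·). A minimal scalar sketch of those coefficients, assuming EDM's default σ_data = 0.5 (the value used for Cosmos is not stated here). `edm_coefficients` and `precondition` are hypothetical helpers, not FastVideo functions, and the noise conditioning passed to F is left as `sigma` rather than EDM's c_noise = ln(σ)/4.

```python
import math
from typing import Callable

SIGMA_DATA = 0.5  # assumption: EDM's default; the Cosmos checkpoint may use another value


def edm_coefficients(sigma: float, sigma_data: float = SIGMA_DATA) -> tuple[float, float, float]:
    """Return (c_in, c_skip, c_out) as defined in Karras et al. (2022), Table 1."""
    denom = sigma**2 + sigma_data**2
    c_in = 1.0 / math.sqrt(denom)                  # scales the noisy input to unit variance
    c_skip = sigma_data**2 / denom                 # weight of the skip path from the input
    c_out = sigma * sigma_data / math.sqrt(denom)  # scale applied to the network output
    return c_in, c_skip, c_out


def precondition(raw_model: Callable[[float, float], float], x: float, sigma: float) -> float:
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x, sigma)."""
    c_in, c_skip, c_out = edm_coefficients(sigma)
    return c_skip * x + c_out * raw_model(c_in * x, sigma)


print(edm_coefficients(0.0))  # (2.0, 1.0, 0.0)
```

At σ = 0 the skip path passes the input through unchanged and the network contributes nothing, which is why the preconditioned denoiser is well behaved at low noise.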

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

Methods:

fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
    result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result

fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
    return result
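In `verify_input`, the `lambda` makes `negative_prompt_embeds` mandatory only when `do_classifier_free_guidance` is set. A minimal sketch of that conditional-check idiom; `SketchVerificationResult`, `list_not_empty`, and `verify_embeds` are hypothetical stand-ins, and FastVideo's `VerificationResult` and `V` validators may differ in detail.

```python
from __future__ import annotations

from typing import Any, Callable, Iterable, Union

Check = Callable[[Any], bool]


class SketchVerificationResult:
    """Hypothetical stand-in for VerificationResult: records which fields failed."""

    def __init__(self) -> None:
        self.failures: list[str] = []

    def add_check(self, name: str, value: Any, checks: Union[Check, Iterable[Check]]) -> None:
        # Accept one predicate or a list of predicates; all of them must pass.
        predicates = [checks] if callable(checks) else list(checks)
        if not all(predicate(value) for predicate in predicates):
            self.failures.append(name)

    def is_valid(self) -> bool:
        return not self.failures


def list_not_empty(x: Any) -> bool:
    return isinstance(x, list) and len(x) > 0


def verify_embeds(prompt_embeds: Any, negative_prompt_embeds: Any, do_cfg: bool) -> SketchVerificationResult:
    result = SketchVerificationResult()
    result.add_check("prompt_embeds", prompt_embeds, list_not_empty)
    # Negative embeddings are only required when classifier-free guidance is on.
    result.add_check("negative_prompt_embeds", negative_prompt_embeds,
                     lambda x: not do_cfg or list_not_empty(x))
    return result


print(verify_embeds(["emb"], None, do_cfg=False).is_valid())  # True
print(verify_embeds(["emb"], None, do_cfg=True).failures)     # ['negative_prompt_embeds']
```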

fastvideo.pipelines.stages.denoising.DenoisingStage

DenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: PipelineStage

Stage for running the denoising loop in diffusion pipelines.

This stage handles the iterative denoising process that transforms the initial noise into the final output.
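The loop follows the usual predict-then-step pattern. A minimal sketch, assuming a flow-matching Euler scheduler (the deterministic update `x + (sigma_next - sigma) * v` used by Diffusers' `FlowMatchEulerDiscreteScheduler`) and scalar latents. `euler_step` and `denoise` are hypothetical helpers; the real stage calls `self.scheduler.step` on tensors and adds conditioning, CFG, and offloading on top.

```python
from typing import Callable, Sequence

# model(x, sigma) -> predicted velocity; scalars keep the sketch dependency-free.
VelocityModel = Callable[[float, float], float]


def euler_step(x: float, velocity: float, sigma: float, sigma_next: float) -> float:
    """Flow-matching Euler update: move x along the predicted velocity."""
    return x + (sigma_next - sigma) * velocity


def denoise(model: VelocityModel, x: float, sigmas: Sequence[float]) -> float:
    """Predict a velocity and step, from the highest sigma down to the last one."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        velocity = model(x, sigma)
        x = euler_step(x, velocity, sigma, sigma_next)
    return x


# Toy check with an oracle that knows the clean sample x0.
# With x_sigma = (1 - sigma) * x0 + sigma * noise, the velocity is noise - x0 = (x - x0) / sigma.
x0, noise = 2.0, -1.0


def oracle(x: float, sigma: float) -> float:
    return (x - x0) / sigma


result = denoise(oracle, noise, [1.0, 0.75, 0.5, 0.25, 0.0])
print(result)  # 2.0
```

With a perfect velocity predictor each Euler step lands exactly on the straight noise-to-data path, so the loop recovers x0; a learned model only approximates this.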

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN,
                                      AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.FLASH_ATTN,
                                      AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
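`__init__` stores the owning pipeline through `weakref.ref`, and `forward` later dereferences it (`self.pipeline() if self.pipeline else None`) before calling `pipeline.add_module`. A minimal sketch of why this helps: the pipeline owns its stages, while the weak back-reference keeps the stage from holding the pipeline alive. `SketchPipeline` and `SketchStage` are hypothetical.

```python
import weakref
from typing import Any, Optional


class SketchStage:
    def __init__(self, pipeline: Optional["SketchPipeline"] = None) -> None:
        # Weak back-reference: the stage does not keep its pipeline alive.
        self.pipeline = weakref.ref(pipeline) if pipeline else None

    def register(self, name: str, module: Any) -> bool:
        """Hand a lazily loaded module back to the pipeline, if it still exists."""
        pipeline = self.pipeline() if self.pipeline else None
        if pipeline is None:
            return False
        pipeline.add_module(name, module)
        return True


class SketchPipeline:
    def __init__(self) -> None:
        self.modules: dict[str, Any] = {}
        self.stage = SketchStage(self)  # strong reference: the pipeline owns its stage

    def add_module(self, name: str, module: Any) -> None:
        self.modules[name] = module


pipe = SketchPipeline()
stage = pipe.stage
print(stage.register("transformer", object()), list(pipe.modules))  # True ['transformer']
del pipe  # last strong reference gone; CPython frees the pipeline immediately
print(stage.pipeline() is None, stage.register("vae", object()))    # True False
```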

Methods:

fastvideo.pipelines.stages.denoising.DenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.
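For two-expert (Wan2.2-style) models, the loop computes `boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps` (with `batch.boundary_ratio` overriding the config value) and uses `transformer` with `guidance_scale` while `t >= boundary_timestep`, switching to `transformer_2` with `guidance_scale_2` afterwards. A minimal sketch of that selection; `select_expert` is hypothetical, and `num_train_timesteps = 1000` and `boundary_ratio = 0.875` are illustrative values, not read from a Wan2.2 config. CPU offloading is omitted.

```python
from typing import Optional


def select_expert(t: float, boundary_ratio: Optional[float],
                  num_train_timesteps: int = 1000) -> tuple[str, str]:
    """Return (model attribute, guidance-scale attribute) for timestep t."""
    if boundary_ratio is None:
        # No boundary configured: the single transformer handles every step.
        return "transformer", "guidance_scale"
    boundary_timestep = boundary_ratio * num_train_timesteps
    if t >= boundary_timestep:
        # High-noise expert for the early, noisier steps.
        return "transformer", "guidance_scale"
    # Low-noise expert for the remaining steps.
    return "transformer_2", "guidance_scale_2"


for t in (999.0, 875.0, 874.0, 250.0):
    print(t, select_expert(t, boundary_ratio=0.875))
```

The comparison is inclusive, so a timestep exactly on the boundary still goes to the high-noise expert.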

Parameters:

batch (ForwardBatch): The current batch information. Required.
fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

ForwardBatch: The batch with denoised latents.
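When `FASTVIDEO_CFG_GATE_STEP` is below 1.0, the loop stops running the unconditional pass after the gate index `int(len(timesteps) * fraction)` and reuses the last fresh `delta = cond - uncond`, relying on uncond + s·delta = cond + (s − 1)·delta. A scalar sketch of that control flow; `cfg_gate_step_index` and `gated_cfg` are hypothetical helpers, and the Wan2.2 cache invalidation and guidance rescale are omitted. The toy uncond predictions keep the delta constant, so reuse is exact here; in real runs the delta drifts, which is the quality trade-off the listing's warnings refer to.

```python
from typing import Callable, Optional, Sequence


def cfg_gate_step_index(num_timesteps: int, fraction: float) -> int:
    """Loop index from which the cached delta may be reused."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"fraction must be in [0.0, 1.0], got {fraction!r}")
    if fraction < 1.0:
        return int(num_timesteps * fraction)
    return num_timesteps + 1  # never gates


def gated_cfg(cond_preds: Sequence[float], uncond_fn: Callable[[int], float],
              scale: float, fraction: float) -> tuple[list[float], int]:
    """Combine cond/uncond predictions, skipping the uncond pass after the gate step."""
    gate = cfg_gate_step_index(len(cond_preds), fraction)
    delta_cached: Optional[float] = None
    fresh_uncond = 0
    outputs: list[float] = []
    for i, cond in enumerate(cond_preds):
        if i >= gate and delta_cached is not None:
            # Reuse: cond + (s - 1) * delta equals uncond + s * delta for an unchanged delta.
            outputs.append(cond + (scale - 1.0) * delta_cached)
        else:
            uncond = uncond_fn(i)  # stands in for the negative-prompt forward pass
            fresh_uncond += 1
            delta_cached = cond - uncond
            outputs.append(uncond + scale * delta_cached)
    return outputs, fresh_uncond


conds = [float(i) for i in range(10)]
outs, fresh = gated_cfg(conds, lambda i: float(i) - 0.5, scale=5.0, fraction=0.5)
print(fresh)                                # 5
print(outs == [c + 2.0 for c in conds])     # True
```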

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Prepare extra step kwargs for scheduler
    extra_step_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.step,
        {
            "generator": batch.generator,
            "eta": batch.eta
        },
    )

    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    # Flux2-only denoising compensations.
    #
    # `_is_flux` gates four behaviors that exist because Flux2's transformer
    # forward() does things internally that the generic pipeline must undo or
    # match. These are architectural facts about the Flux2 transformer, not
    # tunable precision policies (the precision policies #5/#6 — prompt-embed
    # casting and scheduler-step placement — were already moved to config:
    # DiTArchConfig.cast_prompt_embeds_to_dit_dtype and
    # PipelineConfig.scheduler_step_in_fp32).
    #
    # The four behaviors gated below:
    #   1. env-var bf16-reduced-precision matmul disable (4-step Klein drift)
    #   2. autocast disabled (Flux2 long-sequence attention breaks parity under autocast)
    #   3. guidance: skip the external x1000 (Flux2 multiplies guidance by 1000 internally)
    #   4. timestep: divide by 1000 with cast-before-divide (Flux2 multiplies timestep by 1000 internally)
    #
    # Contract: `prefix == "Flux"` is set ONLY by Flux2 (fastvideo/configs/
    # models/dits/flux_2.py). No other model uses that prefix, so this exact
    # match cannot false-positive. A future Flux variant that needs the same
    # compensations must either set prefix == "Flux" too, OR (preferred) these
    # gates should graduate to arch-config declarations like the precision
    # policies above.
    _is_flux = (getattr(fastvideo_args.pipeline_config.dit_config, "prefix", "") == "Flux")
    if _is_flux and os.getenv("FASTVIDEO_FLUX2_DISABLE_BF16_REDUCED_PRECISION_REDUCTION",
                              "").lower() in {"1", "true", "yes"}:
        # Gate 1: tighten bf16 matmul accumulation for the 4-step Klein model (opt-in via env var).
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    # Gate 2: Flux2 runs its bf16 transformer WITHOUT autocast — autocast perturbs long-sequence attention enough to break 4-step latent parity.
    autocast_enabled = ((target_dtype != torch.float32) and not fastvideo_args.disable_autocast and not _is_flux)
    scheduler_fp32 = getattr(fastvideo_args.pipeline_config, "scheduler_step_in_fp32", False)
    local_device = get_local_torch_device()

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps
    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert not torch.isnan(image_embeds[0]).any(), "image_embeds contains nan"
        image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    neg_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_neg,
            "encoder_attention_mask": batch.negative_attention_mask,
        },
    )

    action_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "mouse_cond": batch.mouse_cond,
            "keyboard_cond": batch.keyboard_cond,
            "c2ws_plucker_emb": batch.c2ws_plucker_emb,
        },
    )

    camera_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "camera_states": batch.camera_states,
        },
    )

    for key in ("flux2_txt_ids", "flux2_img_ids"):
        value = batch.extra.get(key)
        if torch.is_tensor(value):
            batch.extra[key] = value.to(device=local_device)

    flux2_id_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "txt_ids": batch.extra.get("flux2_txt_ids"),
            "img_ids": batch.extra.get("flux2_img_ids"),
        },
    )

    # Get latents and embeddings
    latents = batch.latents
    cast_embeds = getattr(fastvideo_args.pipeline_config.dit_config, "cast_prompt_embeds_to_dit_dtype", False)
    if cast_embeds:
        prompt_embeds = [
            embed.to(device=local_device, dtype=target_dtype) if torch.is_tensor(embed) else embed
            for embed in batch.prompt_embeds
        ]
    else:
        prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
    if batch.do_classifier_free_guidance:
        neg_prompt_embeds = batch.negative_prompt_embeds
        assert neg_prompt_embeds is not None
        if cast_embeds:
            neg_prompt_embeds = [
                embed.to(device=local_device, dtype=target_dtype) if torch.is_tensor(embed) else embed
                for embed in neg_prompt_embeds
            ]
        else:
            neg_prompt_embeds = batch.negative_prompt_embeds
        assert not torch.isnan(neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

    # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
    boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
    if batch.boundary_ratio is not None:
        logger.info("Overriding boundary ratio from %s to %s", boundary_ratio, batch.boundary_ratio)
        boundary_ratio = batch.boundary_ratio

    boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.shape[0] == 1, "only support batch size 1"

    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        # TI2V directly replaces the first frame of the latent with
        # the image latent instead of appending along the channel dim
        assert batch.image_latent is None, "TI2V task should not have image latents"
        assert self.vae is not None, "VAE is not provided for TI2V task"
        z = self.vae.encode(batch.pil_image).mean.float()
        if (hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                z -= self.vae.shift_factor.to(z.device, z.dtype)
            else:
                z -= self.vae.shift_factor

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            z = z * self.vae.scaling_factor.to(z.device, z.dtype)
        else:
            z = z * self.vae.scaling_factor

        latent_model_input = latent_model_input.squeeze(0)
        _, mask2 = masks_like([latent_model_input], zero=True)

        latent_model_input = (1. - mask2[0]) * z + mask2[0] * latent_model_input
        # latent_model_input = latent_model_input.unsqueeze(0)
        latent_model_input = latent_model_input.to(get_local_torch_device())
        latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * (
            batch.width // spatial_scale) // (patch_size[1] * patch_size[2])

    # Initialize lists for ODE trajectory
    trajectory_timesteps: list[torch.Tensor] = []
    trajectory_latents: list[torch.Tensor] = []
    is_lucy_edit = fastvideo_args.pipeline_config.lucy_edit_task

    # Hoisted out of the per-step loop: depends only on inputs that
    # are constant across denoising steps.
    use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
    # Gate 3: Flux2's transformer multiplies guidance by 1000 internally, so we
    # skip the external *1000 pre-scaling for Flux models.
    embedded_cfg_scale = fastvideo_args.pipeline_config.embedded_cfg_scale
    if _is_flux and embedded_cfg_scale is not None:
        embedded_cfg_scale = batch.guidance_scale
    if embedded_cfg_scale is not None:
        guidance_expand = (torch.tensor(
            [embedded_cfg_scale] * latents.shape[0],
            dtype=torch.float32,
            device=get_local_torch_device(),
        ).to(target_dtype) * (1.0 if _is_flux else 1000.0))
    else:
        guidance_expand = None
    # V2V padding: zero-filled tensor concatenated with each step's
    # latent_model_input.  Shape is fixed by latents and is never
    # written to, so we allocate once.
    v2v_zero_pad = torch.zeros_like(latents) if batch.video_latent is not None else None
    lucy_timestep_seq_len = None
    if is_lucy_edit:
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        assert patch_size[0] == 1, "Lucy Edit timestep expansion assumes temporal patch size 1"
        lucy_timestep_seq_len = (latents.shape[2] * (latents.shape[3] // patch_size[1]) *
                                 (latents.shape[4] // patch_size[2]))

    # CFG gating / stale-uncond reuse setup (Adaptive Guidance LinearAG
    # variant, Castillo et al. 2023).  When envs.FASTVIDEO_CFG_GATE_STEP
    # < 1.0, the uncond forward is skipped after the gating step and the
    # guidance delta (cond - uncond) is reused from the last fresh
    # compute.  See envs.py for semantics.  delta_cached_model_id tracks
    # which underlying transformer produced the cache so we invalidate on
    # Wan2.2 expert switch.
    _cfg_gate_fraction = envs.FASTVIDEO_CFG_GATE_STEP
    if not 0.0 <= _cfg_gate_fraction <= 1.0:
        raise ValueError(f"FASTVIDEO_CFG_GATE_STEP must be in [0.0, 1.0], got {_cfg_gate_fraction!r}. "
                         "Use 1.0 (default) to disable; lower values trade quality for speed.")
    _cfg_gate_active = _cfg_gate_fraction < 1.0 and batch.do_classifier_free_guidance
    _is_rank0 = get_world_group().local_rank == 0
    if _cfg_gate_active:
        # Use len(timesteps), not num_inference_steps: the loop iterates
        # over timesteps directly, and for schedulers with order > 1
        # (e.g. DPM-Solver++ 2M, Heun) len(timesteps) is a multiple of
        # num_inference_steps. Using num_inference_steps would cause the
        # gate to fire at fraction/order of the loop instead of fraction.
        _cfg_gate_step_idx = int(len(timesteps) * _cfg_gate_fraction)
        if _is_rank0:
            logger.info("CFG gating enabled: fraction=%.3f, gate_step=%d/%d", _cfg_gate_fraction,
                        _cfg_gate_step_idx, len(timesteps))
        if batch.guidance_rescale > 0.0 and _is_rank0:
            # guidance_rescale rescales CFG output stats to match cond
            # stats (Lin et al. §3.4).  When `delta_cached` goes stale,
            # the rescaling still computes but is no longer guaranteed
            # to preserve the original quality semantics.  Warn so the
            # caller knows this combo is unvalidated; tighten or
            # fallback once VBench data lands.
            logger.warning(
                "CFG gating (fraction=%.3f) combined with guidance_rescale=%.3f is unvalidated; "
                "quality may degrade beyond CFG-gating-alone expectations.", _cfg_gate_fraction,
                batch.guidance_rescale)
    else:
        _cfg_gate_step_idx = len(timesteps) + 1  # never gates
    delta_cached: torch.Tensor | None = None
    delta_cached_model_id: int | None = None
    # Telemetry — logged at end of denoising loop on rank 0.
    _cfg_gate_fresh_uncond = 0
    _cfg_gate_reused_delta = 0
    _cfg_gate_invalidations = 0

    # Run denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue

            if boundary_timestep is None or t >= boundary_timestep:
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and self.transformer_2 is not None
                        and next(self.transformer_2.parameters()).device.type == 'cuda'):
                    self.transformer_2.to('cpu')
                current_model = self.transformer
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference and current_model is not None):
                    transformer_device = next(current_model.parameters()).device.type
                    if transformer_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale
            else:
                # low-noise stage in wan2.2
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and next(self.transformer.parameters()).device.type == 'cuda'):
                    self.transformer.to('cpu')
                current_model = self.transformer_2
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference and current_model is not None):
                    transformer_2_device = next(current_model.parameters()).device.type
                    if transformer_2_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale_2
            assert current_model is not None, "current_model is None"

            # Expand latents for V2V/I2V
            latent_model_input = latents.to(target_dtype)
            if batch.video_latent is not None:
                if is_lucy_edit:
                    latent_model_input = torch.cat(
                        [latent_model_input, batch.video_latent],
                        dim=1,
                    ).to(target_dtype)
                else:
                    latent_model_input = torch.cat(
                        [latent_model_input, batch.video_latent, v2v_zero_pad],
                        dim=1,
                    ).to(target_dtype)
            elif batch.image_latent is not None:
                assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                latent_model_input = torch.cat([latent_model_input, batch.image_latent], dim=1).to(target_dtype)

            assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
            if is_lucy_edit:
                assert lucy_timestep_seq_len is not None
                t_expand = t.repeat(latent_model_input.shape[0], lucy_timestep_seq_len)
            elif fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                timestep = torch.stack([t]).to(get_local_torch_device())
                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep])
                timestep = temp_ts.unsqueeze(0)
                t_expand = timestep.repeat(latent_model_input.shape[0], 1)
            else:
                t_expand = t.repeat(latent_model_input.shape[0])
            # Gate 4: Flux2 transformer multiplies timestep by 1000 internally, so
            # the pipeline must pass timestep/1000 (matching Diffusers).
            # Diffusers casts to the latent dtype before the division; doing
            # the division in fp32 first changes BF16 rounding for the final
            # Klein timestep and breaks latent parity.
            if _is_flux:
                t_expand = t_expand.to(
                    device=get_local_torch_device(),
                    dtype=latent_model_input.dtype,
                )
                t_expand = t_expand / 1000.0
            else:
                t_expand = t_expand.to(get_local_torch_device())

            if use_meanflow:
                if i == len(timesteps) - 1:
                    timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
                else:
                    timesteps_r = timesteps[i + 1]
                timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
            else:
                timesteps_r = None

            timesteps_r_kwarg = self.prepare_extra_func_kwargs(
                self.transformer.forward,
                {
                    "timestep_r": timesteps_r,
                },
            )

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise residual
            with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                if (vsa_available and self.attn_backend == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls()
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),
                        )
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                elif (vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls()
                        # Prepare V-MoBA parameters from config
                        moba_params = fastvideo_args.moba_config.copy()
                        moba_params.update({
                            "current_timestep": i,
                            "raw_latent_shape": batch.raw_latent_shape[2:5],
                            "patch_size": fastvideo_args.pipeline_config.dit_config.patch_size,
                            "device": get_local_torch_device(),
                        })
                        attn_metadata = self.attn_metadata_builder.build(**moba_params)
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None
                # TODO(will): finalize the interface. vLLM uses this to
                # support torch dynamo compilation. They pass in
                # attn_metadata, vllm_config, and num_tokens. We can pass in
                # fastvideo_args or training_args, and attn_metadata.
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    noise_pred = current_model(
                        latent_model_input,
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                        **action_kwargs,
                        **camera_kwargs,
                        **timesteps_r_kwarg,
                        **flux2_id_kwargs,
                    )

                if batch.do_classifier_free_guidance:
                    # CFG gating: invalidate cached delta when the underlying
                    # transformer changes (Wan2.2 high/low-noise expert
                    # switch at `boundary_timestep`).  delta_cached is tied
                    # to the model that produced it; reusing it across the
                    # boundary is silently wrong.
                    if delta_cached_model_id is not None and delta_cached_model_id != id(current_model):
                        delta_cached = None
                        delta_cached_model_id = None
                        _cfg_gate_invalidations += 1

                    _use_cached_delta = (i >= _cfg_gate_step_idx and delta_cached is not None)

                    noise_pred_text = noise_pred
                    if _use_cached_delta:
                        # Reuse frozen delta = cond - uncond from the last
                        # fresh compute.  Algebra:
                        #   pred = uncond + s * (cond - uncond)
                        #        = cond + (s - 1) * (cond - uncond)
                        #        = cond + (s - 1) * delta_cached
                        noise_pred = noise_pred_text + (current_guidance_scale - 1.0) * delta_cached
                        _cfg_gate_reused_delta += 1
                    else:
                        batch.is_cfg_negative = True
                        with set_forward_context(
                                current_timestep=i,
                                attn_metadata=attn_metadata,
                                forward_batch=batch,
                        ):
                            noise_pred_uncond = current_model(
                                latent_model_input,
                                neg_prompt_embeds,
                                t_expand,
                                guidance=guidance_expand,
                                **image_kwargs,
                                **neg_cond_kwargs,
                                **action_kwargs,
                                **camera_kwargs,
                                **timesteps_r_kwarg,
                                **flux2_id_kwargs,
                            )
                        _cfg_gate_fresh_uncond += 1

                        # Refresh cache only when gating is active; under the
                        # default (FASTVIDEO_CFG_GATE_STEP=1.0, _cfg_gate_step_idx
                        # > len(timesteps)) we never reuse, so skip the
                        # tensor allocation.
                        if _cfg_gate_step_idx <= len(timesteps):
                            delta_cached = noise_pred_text - noise_pred_uncond
                            delta_cached_model_id = id(current_model)
                            noise_pred = noise_pred_uncond + current_guidance_scale * delta_cached
                        else:
                            noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text -
                                                                                       noise_pred_uncond)

                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=batch.guidance_rescale,
                        )
            if scheduler_fp32:
                # Diffusers-style: fp32 Euler update outside autocast avoids BF16 drift.
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            else:
                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                latents = latents.squeeze(0)
                latents = (1. - mask2[0]) * z + mask2[0] * latents
                # latents = latents.unsqueeze(0)

            # save trajectory latents if needed
            if batch.return_trajectory_latents:
                trajectory_timesteps.append(t)
                trajectory_latents.append(latents)

            # Update progress bar
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                           (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                progress_bar.update()

    # CFG gating telemetry — log once on rank 0 after the loop ends.  When
    # gating is disabled (or CFG itself is off) fresh_uncond equals the
    # number of CFG-on steps and reused/invalidations are zero; we still
    # emit the line so users can confirm the env var is wired through.
    if _is_rank0 and batch.do_classifier_free_guidance:
        logger.info(
            "CFG gating summary: fraction=%.3f gate_step=%d/%d "
            "fresh_uncond=%d reused=%d invalidations=%d",
            _cfg_gate_fraction,
            _cfg_gate_step_idx if _cfg_gate_active else -1,
            len(timesteps),
            _cfg_gate_fresh_uncond,
            _cfg_gate_reused_delta,
            _cfg_gate_invalidations,
        )

    trajectory_tensor: torch.Tensor | None = None
    if trajectory_latents:
        trajectory_tensor = torch.stack(trajectory_latents, dim=1)
        trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
    else:
        trajectory_tensor = None
        trajectory_timesteps_tensor = None

    if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
        batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
        batch.trajectory_latents = trajectory_tensor.cpu()

    # Update batch with final latents
    batch.latents = latents

    if fastvideo_args.dit_layerwise_offload:
        mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
        if mgr is not None and getattr(mgr, "enabled", False):
            mgr.release_all()
        if self.transformer_2 is not None:
            mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager", None)
            if mgr2 is not None and getattr(mgr2, "enabled", False):
                mgr2.release_all()

    # deallocate transformer if on mps
    if torch.backends.mps.is_available():
        logger.info("Memory before deallocating transformer: %s", torch.mps.current_allocated_memory())
        del self.transformer
        if pipeline is not None and "transformer" in pipeline.modules:
            del pipeline.modules["transformer"]
        fastvideo_args.model_loaded["transformer"] = False
        logger.info("Memory after deallocating transformer: %s", torch.mps.current_allocated_memory())

    return batch
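The cached-delta branch in the listing above rests on the identity uncond + s * (cond - uncond) = cond + (s - 1) * (cond - uncond). Below is a minimal NumPy sketch of the gating logic. It is not FastVideo code. Per-step arrays stand in for transformer outputs. `cfg_combine`, `cfg_from_cached_delta`, and `run_gated_cfg` are hypothetical names, and `gate_step_idx` plays the role of `_cfg_gate_step_idx`. The model-switch invalidation via `delta_cached_model_id` is left out.

```python
import numpy as np


def cfg_combine(cond, uncond, scale):
    # Standard classifier-free guidance combination.
    return uncond + scale * (cond - uncond)


def cfg_from_cached_delta(cond, delta, scale):
    # Equivalent form written in terms of delta = cond - uncond.
    return cond + (scale - 1.0) * delta


def run_gated_cfg(conds, unconds, scale, gate_step_idx):
    """Combine per-step predictions, reusing the last fresh delta once i >= gate_step_idx.

    unconds[i] is read only on fresh steps; in a real pipeline the unconditional
    forward pass would be skipped on reused steps.
    """
    delta = None
    fresh = reused = 0
    preds = []
    for i, cond in enumerate(conds):
        if i >= gate_step_idx and delta is not None:
            preds.append(cfg_from_cached_delta(cond, delta, scale))
            reused += 1
        else:
            uncond = unconds[i]
            delta = cond - uncond  # refresh the cache from this fresh step
            preds.append(cfg_combine(cond, uncond, scale))
            fresh += 1
    return preds, fresh, reused


# Usage: 10 steps, gating from step 6 on, with a constant cond - uncond gap.
rng = np.random.default_rng(0)
conds = [rng.standard_normal((2, 4)) for _ in range(10)]
unconds = [c - 0.5 for c in conds]
preds, fresh, reused = run_gated_cfg(conds, unconds, scale=5.0, gate_step_idx=6)
print(fresh, reused)  # 6 4
```

In this toy example, reuse is exact only because the gap between cond and uncond is constant across steps. With a real model, the cached delta is an approximation, and its error grows with how much cond - uncond drifts after the gate step.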

fastvideo.pipelines.stages.denoising.DenoisingStage.prepare_extra_func_kwargs
prepare_extra_func_kwargs(func, kwargs) -> dict[str, Any]

Prepare extra kwargs for the scheduler step / denoise step by keeping only the entries of kwargs whose names match a parameter of func.

Parameters:

- func: The function to prepare kwargs for. Required.
- kwargs: The kwargs to prepare. Required.

Returns:

- dict[str, Any]: The prepared kwargs.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
    """
    Prepare extra kwargs for the scheduler step / denoise step.

    Args:
        func: The function to prepare kwargs for.
        kwargs: The kwargs to prepare.

    Returns:
        The prepared kwargs.
    """
    extra_step_kwargs = {}
    # Look up the accepted parameter names once rather than once per key.
    accepted = set(inspect.signature(func).parameters.keys())
    for k, v in kwargs.items():
        if k in accepted:
            extra_step_kwargs[k] = v
    return extra_step_kwargs
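Here is a standalone sketch of the same name-based filtering. `filter_kwargs_for`, `toy_step`, and `toy_forward` are hypothetical stand-ins, not FastVideo or scheduler APIs. The example shows which candidate kwargs survive for two different signatures.

```python
import inspect
from typing import Any, Callable


def filter_kwargs_for(func: Callable[..., Any], kwargs: dict) -> dict:
    # Keep only the keys that name a parameter in func's signature.
    accepted = set(inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}


def toy_step(model_output, timestep, sample, generator=None):
    return sample


def toy_forward(hidden_states, **extra):
    return hidden_states


candidate = {"generator": "g", "eta": 0.0, "encoder_attention_mask": None}
print(filter_kwargs_for(toy_step, candidate))     # {'generator': 'g'}
print(filter_kwargs_for(toy_forward, candidate))  # {}
```

A consequence of matching on parameter names is that a callee accepting `**extra` receives none of the candidates, even though it could take them. Only explicitly named parameters pass the filter.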

fastvideo.pipelines.stages.denoising.DenoisingStage.progress_bar
progress_bar(iterable: Iterable | None = None, total: int | None = None) -> tqdm

Create a progress bar for the denoising process.

Parameters:

- iterable (Iterable | None): The iterable to iterate over. Defaults to None.
- total (int | None): The total number of items. Defaults to None.

Returns:

- tqdm: A tqdm progress bar.

Source code in fastvideo/pipelines/stages/denoising.py
def progress_bar(self, iterable: Iterable | None = None, total: int | None = None) -> tqdm:
    """
    Create a progress bar for the denoising process.

    Args:
        iterable: The iterable to iterate over.
        total: The total number of items.

    Returns:
        A tqdm progress bar.
    """
    local_rank = get_world_group().local_rank
    if local_rank == 0:
        return tqdm(iterable=iterable, total=total)
    else:
        return tqdm(iterable=iterable, total=total, disable=True)

fastvideo.pipelines.stages.denoising.DenoisingStage.rescale_noise_cfg
rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0) -> Tensor

Rescale noise prediction according to guidance_rescale.

Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

Parameters:

- noise_cfg: The noise prediction with guidance. Required.
- noise_pred_text: The text-conditioned noise prediction. Required.
- guidance_rescale: The guidance rescale factor. Defaults to 0.0.

Returns:

- Tensor: The rescaled noise prediction.

Source code in fastvideo/pipelines/stages/denoising.py
def rescale_noise_cfg(self, noise_cfg, noise_pred_text, guidance_rescale=0.0) -> torch.Tensor:
    """
    Rescale noise prediction according to guidance_rescale.

    Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: The noise prediction with guidance.
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: The guidance rescale factor.

    Returns:
        The rescaled noise prediction.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original results from guidance by factor guidance_rescale
    noise_cfg = (guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg)
    return noise_cfg
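Here is a NumPy re-expression of the same arithmetic (`rescale_noise_cfg_np` is a hypothetical name). The standard deviation is taken per sample over all non-batch axes. `ddof=1` is used on the assumption that it mirrors the Bessel-corrected default of `torch.Tensor.std`. With `guidance_rescale=1.0`, each sample's std is pulled back to that of the text-conditioned prediction. With 0.0, the CFG output is returned unchanged.

```python
import numpy as np


def rescale_noise_cfg_np(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample std over every non-batch axis (ddof=1 assumed to match torch's default).
    axes = tuple(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(axis=axes, keepdims=True, ddof=1)
    std_cfg = noise_cfg.std(axis=axes, keepdims=True, ddof=1)
    # Rescale the guided prediction to the text prediction's std, then blend.
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg


rng = np.random.default_rng(0)
text = rng.standard_normal((2, 4, 3, 8, 8))
uncond = rng.standard_normal((2, 4, 3, 8, 8))
cfg = uncond + 7.5 * (text - uncond)  # a large guidance scale inflates the variance
full = rescale_noise_cfg_np(cfg, text, guidance_rescale=1.0)
```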

fastvideo.pipelines.stages.denoising.DenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)])
    result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator, V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
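The checks above follow an accumulate-then-report pattern. The last check is conditional: negative prompt embeddings are required only when classifier-free guidance is on. Below is a hypothetical sketch of that pattern. `MiniVerificationResult` and `list_not_empty` are illustrative stand-ins, not the real `VerificationResult` or `V` helpers, whose internals are not shown here.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Union

Check = Callable[[Any], bool]


@dataclass
class MiniVerificationResult:
    """Hypothetical stand-in for VerificationResult: records failed field names."""

    failures: list = field(default_factory=list)

    def add_check(self, name: str, value: Any, checks: Union[Check, Iterable[Check]]):
        # Accept a single predicate or a list of predicates, as the calls above do.
        checks = [checks] if callable(checks) else list(checks)
        if not all(check(value) for check in checks):
            self.failures.append(name)
        return self

    def is_valid(self) -> bool:
        return not self.failures


def list_not_empty(x):
    return isinstance(x, list) and len(x) > 0


def verify_negative_embeds(do_cfg, negative_prompt_embeds):
    result = MiniVerificationResult()
    # Negative embeddings are only required when CFG is enabled.
    result.add_check("negative_prompt_embeds", negative_prompt_embeds,
                     lambda x: not do_cfg or list_not_empty(x))
    return result
```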

fastvideo.pipelines.stages.denoising.DenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.denoising.DmdDenoisingStage

DmdDenoisingStage(transformer, scheduler)

Bases: DenoisingStage

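Denoising stage for DMD (Distribution Matching Distillation) models; it runs only the timesteps listed in pipeline_config.dmd_denoising_steps.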

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    super().__init__(transformer, scheduler)
    self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
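This sketch assumes that this `FlowMatchEulerDiscreteScheduler` applies the diffusers-style static shift. That is an assumption, because the scheduler body is not shown here. Under it, `shift=8.0` remaps each sigma in [0, 1] as shown below. The endpoints stay fixed, and intermediate sigmas move toward the high-noise end. `shift_sigmas` is a hypothetical helper name.

```python
import numpy as np


def shift_sigmas(sigmas, shift):
    # Static timestep shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
    sigmas = np.asarray(sigmas, dtype=np.float64)
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)


base = np.linspace(1.0, 0.0, 5)  # [1.0, 0.75, 0.5, 0.25, 0.0]
shifted = shift_sigmas(base, shift=8.0)
print(np.round(shifted, 4))      # 0.5 maps to 8/9, roughly 0.8889
```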

Methods:

fastvideo.pipelines.stages.denoising.DmdDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

- batch (ForwardBatch): The current batch information. Required.
- fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

- ForwardBatch: The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps

    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert torch.isnan(image_embeds[0]).sum() == 0
        image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Get latents and embeddings
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents

    video_raw_latent_shape = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
    timesteps = torch.tensor(fastvideo_args.pipeline_config.dmd_denoising_steps,
                             dtype=torch.long,
                             device=get_local_torch_device())

    # Run denoising loop
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue
            # Keep the current noisy latents for the x0 prediction; I2V image latents are appended below
            noise_latents = latents.clone()
            latent_model_input = latents.to(target_dtype)

            if batch.image_latent is not None:
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent.permute(0, 2, 1, 3, 4)], dim=2).to(target_dtype)
            assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"

            # Prepare inputs for transformer
            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (torch.tensor(
                [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                dtype=torch.float32,
                device=get_local_torch_device(),
            ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                if (vsa_available and self.attn_backend == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls()
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    pred_noise = self.transformer(
                        latent_model_input.permute(0, 2, 1, 3, 4),
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                pred_video = pred_noise_to_pred_video(pred_noise=pred_noise.flatten(0, 1),
                                                      noise_input_latent=noise_latents.flatten(0, 1),
                                                      timestep=t_expand,
                                                      scheduler=self.scheduler).unflatten(0, pred_noise.shape[:2])

                if i < len(timesteps) - 1:
                    next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device)
                    noise = torch.randn(video_raw_latent_shape,
                                        dtype=pred_video.dtype,
                                        generator=batch.generator[0]).to(self.device)
                    latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1),
                                                       next_timestep).unflatten(0, pred_video.shape[:2])
                else:
                    latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                               (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                    progress_bar.update()

    # Swap the frame and channel dims (dims 1 and 2) before storing the final latents
    latents = latents.permute(0, 2, 1, 3, 4)
    # Update batch with final latents
    batch.latents = latents

    return batch
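The loop above alternates between two steps: it predicts a clean sample, then re-noises it with fresh noise to the next entry of `dmd_denoising_steps`. The last step returns the clean estimate directly. Below is a NumPy sketch of that pattern. It rests on three assumptions about what `pred_noise_to_pred_video` and `scheduler.add_noise` compute:

- a flow-matching parameterization x_t = (1 - sigma) * x0 + sigma * eps;
- sigma = t / 1000;
- a model that predicts the velocity eps - x0.

The oracle model and the timestep list are illustrative only.

```python
import numpy as np

NUM_TRAIN_TIMESTEPS = 1000  # assumed: sigma = t / NUM_TRAIN_TIMESTEPS


def predict_x0(x_t, velocity, sigma):
    # x_t = (1 - sigma) * x0 + sigma * eps and v = eps - x0  =>  x0 = x_t - sigma * v
    return x_t - sigma * velocity


def add_noise(x0, eps, sigma):
    return (1.0 - sigma) * x0 + sigma * eps


def dmd_style_sample(model, latents, timesteps, rng):
    for i, t in enumerate(timesteps):
        sigma = t / NUM_TRAIN_TIMESTEPS
        x0 = predict_x0(latents, model(latents, t), sigma)
        if i < len(timesteps) - 1:
            # Re-noise the clean estimate to the next (lower) timestep with fresh noise.
            next_sigma = timesteps[i + 1] / NUM_TRAIN_TIMESTEPS
            latents = add_noise(x0, rng.standard_normal(latents.shape), next_sigma)
        else:
            latents = x0  # the last step returns the clean estimate directly
    return latents


rng = np.random.default_rng(0)
target = rng.standard_normal((1, 3, 4, 8, 8))


def oracle(x_t, t):
    # Hypothetical perfect model: the exact velocity that points at `target`.
    return (x_t - target) / (t / NUM_TRAIN_TIMESTEPS)


start = rng.standard_normal(target.shape)  # pure noise at t = 1000
out = dmd_style_sample(oracle, start, [1000, 757, 522], rng)
```

With a perfect velocity predictor, every step recovers the same clean sample, so the output equals `target`. A distilled model only approximates this, which is why a few re-noising steps are used rather than one.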

Functions: