Skip to content

denoising

Denoising stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.denoising.Cosmos25AutoDenoisingStage

Cosmos25AutoDenoisingStage(transformer, scheduler)

Bases: PipelineStage

Route Cosmos 2.5 denoising to either the Text2World (T2W) stage or the Video2World (V2W) stage; V2W and I2W requests share the latter.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    """Build the two Cosmos 2.5 sub-stages this router can dispatch to.

    Args:
        transformer: Handed to both sub-stages as ``transformer``.
        scheduler: Handed to both sub-stages as ``scheduler``.
    """
    super().__init__()
    # Both sub-stages receive the very same transformer and scheduler objects.
    shared_modules = {"transformer": transformer, "scheduler": scheduler}
    self._t2w = Cosmos25T2WDenoisingStage(**shared_modules)
    self._v2w = Cosmos25V2WDenoisingStage(**shared_modules)

fastvideo.pipelines.stages.denoising.Cosmos25DenoisingStage

Cosmos25DenoisingStage(transformer, scheduler, pipeline=None)

Bases: CosmosDenoisingStage

Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D).

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    """Forward every constructor argument unchanged to the parent stage.

    Args:
        transformer: Passed through to the parent constructor.
        scheduler: Passed through to the parent constructor.
        pipeline: Optional; passed through to the parent constructor.
    """
    super().__init__(
        transformer=transformer,
        scheduler=scheduler,
        pipeline=pipeline,
    )

fastvideo.pipelines.stages.denoising.Cosmos25T2WDenoisingStage

Cosmos25T2WDenoisingStage(transformer, scheduler, pipeline=None)

Bases: Cosmos25DenoisingStage

Cosmos 2.5 Text2World denoising stage.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    """Forward every constructor argument unchanged to the parent stage.

    Args:
        transformer: Passed through to the parent constructor.
        scheduler: Passed through to the parent constructor.
        pipeline: Optional; passed through to the parent constructor.
    """
    super().__init__(
        transformer=transformer,
        scheduler=scheduler,
        pipeline=pipeline,
    )

fastvideo.pipelines.stages.denoising.Cosmos25V2WDenoisingStage

Cosmos25V2WDenoisingStage(transformer, scheduler, pipeline=None)

Bases: Cosmos25DenoisingStage

Cosmos 2.5 Video2World denoising stage.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    """Forward every constructor argument unchanged to the parent stage.

    Args:
        transformer: Passed through to the parent constructor.
        scheduler: Passed through to the parent constructor.
        pipeline: Optional; passed through to the parent constructor.
    """
    super().__init__(
        transformer=transformer,
        scheduler=scheduler,
        pipeline=pipeline,
    )

fastvideo.pipelines.stages.denoising.CosmosDenoisingStage

CosmosDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    """Forward every constructor argument unchanged to the parent stage.

    Args:
        transformer: Passed through to the parent constructor.
        scheduler: Passed through to the parent constructor.
        pipeline: Optional; passed through to the parent constructor.
    """
    super().__init__(
        transformer=transformer,
        scheduler=scheduler,
        pipeline=pipeline,
    )

Functions

fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Register the input checks for the Cosmos denoising stage.

    Args:
        batch: Batch whose fields are validated.
        fastvideo_args: Not read by this method.

    Returns:
        A ``VerificationResult`` with one check registered per inspected field.
    """
    outcome = VerificationResult()
    # (field name, value, validator(s)) triples, registered in this order.
    field_checks = (
        ("latents", batch.latents, [V.is_tensor, V.with_dims(5)]),
        ("prompt_embeds", batch.prompt_embeds, V.list_not_empty),
        ("num_inference_steps", batch.num_inference_steps, V.positive_int),
        ("guidance_scale", batch.guidance_scale, V.positive_float),
        ("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value),
        # Negative embeddings only need to be a non-empty list when CFG is enabled.
        ("negative_prompt_embeds", batch.negative_prompt_embeds,
         lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x)),
    )
    for field_name, field_value, validators in field_checks:
        outcome.add_check(field_name, field_value, validators)
    return outcome
fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Register the output check for the Cosmos denoising stage.

    Args:
        batch: Batch whose ``latents`` field is validated.
        fastvideo_args: Not read by this method.

    Returns:
        A ``VerificationResult`` holding the single ``latents`` check.
    """
    outcome = VerificationResult()
    # Latents must be a tensor and satisfy ``V.with_dims(5)``.
    latent_validators = [V.is_tensor, V.with_dims(5)]
    outcome.add_check("latents", batch.latents, latent_validators)
    return outcome

fastvideo.pipelines.stages.denoising.DenoisingStage

DenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: PipelineStage

Stage for running the denoising loop in diffusion pipelines.

This stage handles the iterative denoising process that transforms the initial noise into the final output.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None) -> None:
    """Store the stage's modules and select an attention backend.

    Args:
        transformer: Stored as ``self.transformer``; its ``hidden_size`` and
            ``num_attention_heads`` are read to derive the attention head size.
        scheduler: Stored as ``self.scheduler``.
        pipeline: Optional owner; kept only as a weak reference.
        transformer_2: Optional second transformer, stored as ``self.transformer_2``.
        vae: Optional VAE, stored as ``self.vae``.
    """
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae

    # Keep a weak reference to the pipeline, or None when no (truthy) pipeline is given.
    if pipeline:
        self.pipeline = weakref.ref(pipeline)
    else:
        self.pipeline = None

    per_head_dim = self.transformer.hidden_size // self.transformer.num_attention_heads
    candidate_backends = (  # hack
        AttentionBackendEnum.VIDEO_SPARSE_ATTN,
        AttentionBackendEnum.VMOBA_ATTN,
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.SAGE_ATTN_THREE,
    )
    self.attn_backend = get_attn_backend(
        head_size=per_head_dim,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=candidate_backends,
    )

Functions

fastvideo.pipelines.stages.denoising.DenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Loads the transformer first if ``fastvideo_args.model_loaded["transformer"]``
    is falsy, then iterates over ``batch.timesteps``. Each step runs the selected
    transformer (plus a second pass with the negative prompt when classifier-free
    guidance is enabled) and advances ``latents`` via ``self.scheduler.step``.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents. If trajectory latents were collected
        (``batch.return_trajectory_latents``), ``batch.trajectory_latents`` and
        ``batch.trajectory_timesteps`` are also set, moved to CPU.

    Raises:
        ValueError: If ``batch.timesteps`` is None.
    """
    # Dereference the weakly-held pipeline (None if no pipeline was stored).
    pipeline = self.pipeline() if self.pipeline else None
    # Load the transformer on demand and register it with the pipeline.
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Prepare extra step kwargs for scheduler
    # (only those that self.scheduler.step actually declares are kept)
    extra_step_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.step,
        {
            "generator": batch.generator,
            "eta": batch.eta
        },
    )

    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps
    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert not torch.isnan(image_embeds[0]).any(), "image_embeds contains nan"
        image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

    # Each *_kwargs dict below is filtered down to the keyword arguments that
    # self.transformer.forward declares, so unsupported inputs are dropped.
    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
        },
    )

    # Conditioning for the positive-prompt pass.
    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Conditioning for the negative-prompt (CFG) pass.
    neg_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_neg,
            "encoder_attention_mask": batch.negative_attention_mask,
        },
    )

    action_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "mouse_cond": batch.mouse_cond,
            "keyboard_cond": batch.keyboard_cond,
            "c2ws_plucker_emb": batch.c2ws_plucker_emb,
        },
    )

    camera_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "camera_states": batch.camera_states,
        },
    )

    # Get latents and embeddings
    latents = batch.latents
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
    # neg_prompt_embeds is only bound when CFG is enabled; it is only read
    # inside the matching CFG branch of the loop below.
    if batch.do_classifier_free_guidance:
        neg_prompt_embeds = batch.negative_prompt_embeds
        assert neg_prompt_embeds is not None
        assert not torch.isnan(neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

    # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
    boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
    if batch.boundary_ratio is not None:
        logger.info("Overriding boundary ratio from %s to %s", boundary_ratio, batch.boundary_ratio)
        boundary_ratio = batch.boundary_ratio

    # None means no expert switching: self.transformer handles every timestep.
    boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.shape[0] == 1, "only support batch size 1"

    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        # TI2V directly replaces the first frame of the latent with
        # the image latent instead of appending along the channel dim
        assert batch.image_latent is None, "TI2V task should not have image latents"
        assert self.vae is not None, "VAE is not provided for TI2V task"
        # Encode the conditioning image and take the mean of the returned distribution.
        z = self.vae.encode(batch.pil_image).mean.float()
        # Normalize the encoded latent: subtract shift_factor (if any), then
        # multiply by scaling_factor. Tensor factors are moved to z's device/dtype.
        if (hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                z -= self.vae.shift_factor.to(z.device, z.dtype)
            else:
                z -= self.vae.shift_factor

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            z = z * self.vae.scaling_factor.to(z.device, z.dtype)
        else:
            z = z * self.vae.scaling_factor

        # Drop the (size-1) leading dim; note it is not restored afterwards.
        latent_model_input = latent_model_input.squeeze(0)
        _, mask2 = masks_like([latent_model_input], zero=True)

        # Blend: take z where mask2 is 0 and the existing latent where mask2 is 1.
        latent_model_input = (1. - mask2[0]) * z + mask2[0] * latent_model_input
        # latent_model_input = latent_model_input.unsqueeze(0)
        latent_model_input = latent_model_input.to(get_local_torch_device())
        latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        # Length the per-token timestep vector is padded to in the loop below.
        # NOTE(review): presumably the transformer's token count after VAE
        # compression and patchification -- confirm against the DiT implementation.
        seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * (
            batch.width // spatial_scale) // (patch_size[1] * patch_size[2])

    # Initialize lists for ODE trajectory
    trajectory_timesteps: list[torch.Tensor] = []
    trajectory_latents: list[torch.Tensor] = []

    # Run denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            # (`continue`, not `break`: the remaining timesteps are still
            # iterated, but each one does no work and no progress update)
            if hasattr(self, 'interrupt') and self.interrupt:
                continue

            # Select the model for this timestep: self.transformer when there is
            # no boundary or t is at/above it, otherwise self.transformer_2.
            if boundary_timestep is None or t >= boundary_timestep:
                # With non-layerwise CPU offload, move the unused expert off CUDA first.
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and self.transformer_2 is not None
                        and next(self.transformer_2.parameters()).device.type == 'cuda'):
                    self.transformer_2.to('cpu')
                current_model = self.transformer
                # Bring the selected model onto the local device if it currently sits on CPU.
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference and current_model is not None):
                    transformer_device = next(current_model.parameters()).device.type
                    if transformer_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale
            else:
                # low-noise stage in wan2.2
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and next(self.transformer.parameters()).device.type == 'cuda'):
                    self.transformer.to('cpu')
                current_model = self.transformer_2
                if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference and current_model is not None):
                    transformer_2_device = next(current_model.parameters()).device.type
                    if transformer_2_device == 'cpu':
                        current_model.to(get_local_torch_device())
                # The second expert uses its own guidance scale.
                current_guidance_scale = batch.guidance_scale_2
            assert current_model is not None, "current_model is None"

            # Expand latents for V2V/I2V
            # (extra conditioning is concatenated along dim=1; the V2V case
            # also appends a zeros tensor shaped like `latents`)
            latent_model_input = latents.to(target_dtype)
            if batch.video_latent is not None:
                latent_model_input = torch.cat([latent_model_input, batch.video_latent,
                                                torch.zeros_like(latents)],
                                               dim=1).to(target_dtype)
            elif batch.image_latent is not None:
                assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                latent_model_input = torch.cat([latent_model_input, batch.image_latent], dim=1).to(target_dtype)

            assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
            if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                # TI2V: build a per-position timestep vector from a strided slice
                # of the mask, then pad it with `timestep` up to seq_len entries.
                # NOTE(review): positions where mask2 is 0 get timestep 0 --
                # presumably the image-conditioned first frame; confirm with masks_like.
                timestep = torch.stack([t]).to(get_local_torch_device())
                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep])
                timestep = temp_ts.unsqueeze(0)
                t_expand = timestep.repeat(latent_model_input.shape[0], 1)
            else:
                # One timestep value per entry along dim 0 of the model input.
                t_expand = t.repeat(latent_model_input.shape[0])
            t_expand = t_expand.to(get_local_torch_device())

            # MeanFlow models additionally receive the next timestep (0.0 on the last step).
            use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
            if use_meanflow:
                if i == len(timesteps) - 1:
                    timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
                else:
                    timesteps_r = timesteps[i + 1]
                timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
            else:
                timesteps_r = None

            # Only forwarded if self.transformer.forward declares `timestep_r`.
            timesteps_r_kwarg = self.prepare_extra_func_kwargs(
                self.transformer.forward,
                {
                    "timestep_r": timesteps_r,
                },
            )

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Prepare inputs for transformer
            # (embedded CFG scale repeated per input entry and multiplied by
            # 1000.0; None when no embedded_cfg_scale is configured)
            guidance_expand = (torch.tensor(
                [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                dtype=torch.float32,
                device=get_local_torch_device(),
            ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                # Build backend-specific attention metadata for this step;
                # it stays None for backends other than VSA / V-MoBA.
                if (vsa_available and self.attn_backend == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls()
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),
                        )
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                elif (vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls()
                        # Prepare V-MoBA parameters from config
                        # (copied so the per-step update does not mutate moba_config)
                        moba_params = fastvideo_args.moba_config.copy()
                        moba_params.update({
                            "current_timestep": i,
                            "raw_latent_shape": batch.raw_latent_shape[2:5],
                            "patch_size": fastvideo_args.pipeline_config.dit_config.patch_size,
                            "device": get_local_torch_device(),
                        })
                        attn_metadata = self.attn_metadata_builder.build(**moba_params)
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None
                # TODO(will): finalize the interface. vLLM uses this to
                # support torch dynamo compilation. They pass in
                # attn_metadata, vllm_config, and num_tokens. We can pass in
                # fastvideo_args or training_args, and attn_metadata.
                # Flag on the batch marking which CFG pass is running.
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    noise_pred = current_model(
                        latent_model_input,
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                        **action_kwargs,
                        **camera_kwargs,
                        **timesteps_r_kwarg,
                    )

                if batch.do_classifier_free_guidance:
                    # Second pass with the negative prompt and negative conditioning.
                    batch.is_cfg_negative = True
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                    ):
                        noise_pred_uncond = current_model(
                            latent_model_input,
                            neg_prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **neg_cond_kwargs,
                            **action_kwargs,
                            **camera_kwargs,
                            **timesteps_r_kwarg,
                        )

                    # Classifier-free guidance: uncond + scale * (text - uncond).
                    noise_pred_text = noise_pred
                    noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=batch.guidance_rescale,
                        )
                # Compute the previous noisy sample
                # (note: the scheduler step also runs inside the autocast context)
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    # TI2V: re-apply the image latent z where mask2 is 0 after every step.
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

            # save trajectory latents if needed
            if batch.return_trajectory_latents:
                trajectory_timesteps.append(t)
                trajectory_latents.append(latents)

            # Update progress bar
            # NOTE(review): by operator precedence `progress_bar is not None` only
            # guards the second clause; on the last step update() is called
            # unconditionally -- confirm this is intended.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                           (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                progress_bar.update()

    # Stack the collected trajectory: latents along a new dim 1, timesteps along dim 0.
    trajectory_tensor: torch.Tensor | None = None
    if trajectory_latents:
        trajectory_tensor = torch.stack(trajectory_latents, dim=1)
        trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
    else:
        trajectory_tensor = None
        trajectory_timesteps_tensor = None

    # Trajectory results are stored on the batch as CPU tensors.
    if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
        batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
        batch.trajectory_latents = trajectory_tensor.cpu()

    # Update batch with final latents
    batch.latents = latents

    # With layerwise offload, ask each transformer's enabled offload manager
    # (if it has one) to release_all() now that the loop has finished.
    if fastvideo_args.dit_layerwise_offload:
        mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
        if mgr is not None and getattr(mgr, "enabled", False):
            mgr.release_all()
        if self.transformer_2 is not None:
            mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager", None)
            if mgr2 is not None and getattr(mgr2, "enabled", False):
                mgr2.release_all()

    # deallocate transformer if on mps
    # (model_loaded is reset so the next forward() call reloads the transformer)
    if torch.backends.mps.is_available():
        logger.info("Memory before deallocating transformer: %s", torch.mps.current_allocated_memory())
        del self.transformer
        if pipeline is not None and "transformer" in pipeline.modules:
            del pipeline.modules["transformer"]
        fastvideo_args.model_loaded["transformer"] = False
        logger.info("Memory after deallocating transformer: %s", torch.mps.current_allocated_memory())

    return batch
fastvideo.pipelines.stages.denoising.DenoisingStage.prepare_extra_func_kwargs
prepare_extra_func_kwargs(func, kwargs) -> dict[str, Any]

Prepare extra kwargs for the scheduler step / denoise step.

Parameters:

Name Type Description Default
func

The function to prepare kwargs for.

required
kwargs

The kwargs to prepare.

required

Returns:

Type Description
dict[str, Any]

The prepared kwargs.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
    """
    Prepare extra kwargs for the scheduler step / denoise step.

    Keeps only the entries of ``kwargs`` whose key is the name of a parameter
    declared in ``func``'s signature. Keys are matched by name only, so a
    catch-all ``**kwargs`` parameter on ``func`` does not make arbitrary keys
    acceptable.

    Args:
        func: The function to prepare kwargs for.
        kwargs: The kwargs to prepare. An empty or ``None`` mapping yields ``{}``.

    Returns:
        The prepared kwargs, in the same order as they appear in ``kwargs``.
    """
    if not kwargs:
        # Nothing to filter, so skip introspecting ``func`` entirely.
        return {}
    # Introspect the signature once instead of once per candidate key.
    accepted_names = inspect.signature(func).parameters
    return {name: value for name, value in kwargs.items() if name in accepted_names}
fastvideo.pipelines.stages.denoising.DenoisingStage.progress_bar
progress_bar(iterable: Iterable | None = None, total: int | None = None) -> tqdm

Create a progress bar for the denoising process.

Parameters:

Name Type Description Default
iterable Iterable | None

The iterable to iterate over.

None
total int | None

The total number of items.

None

Returns:

Type Description
tqdm

A tqdm progress bar.

Source code in fastvideo/pipelines/stages/denoising.py
def progress_bar(self, iterable: Iterable | None = None, total: int | None = None) -> tqdm:
    """
    Create a progress bar for the denoising process.

    The bar is only active on local rank 0; every other rank receives a
    disabled ``tqdm`` instance.

    Args:
        iterable: The iterable to iterate over.
        total: The total number of items.

    Returns:
        A tqdm progress bar.
    """
    is_secondary_rank = get_world_group().local_rank != 0
    if is_secondary_rank:
        return tqdm(iterable=iterable, total=total, disable=True)
    return tqdm(iterable=iterable, total=total)
fastvideo.pipelines.stages.denoising.DenoisingStage.rescale_noise_cfg
rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0) -> Tensor

Rescale noise prediction according to guidance_rescale.

Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

Parameters:

Name Type Description Default
noise_cfg

The noise prediction with guidance.

required
noise_pred_text

The text-conditioned noise prediction.

required
guidance_rescale

The guidance rescale factor.

0.0

Returns:

Type Description
Tensor

The rescaled noise prediction.

Source code in fastvideo/pipelines/stages/denoising.py
def rescale_noise_cfg(self, noise_cfg, noise_pred_text, guidance_rescale=0.0) -> torch.Tensor:
    """
    Rescale noise prediction according to guidance_rescale.

    Follows Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps
    are Flawed" (https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: The noise prediction with guidance.
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: The guidance rescale factor.

    Returns:
        The rescaled noise prediction.
    """

    def _std_over_trailing_dims(tensor):
        # Standard deviation over every dimension except dim 0, kept broadcastable.
        trailing_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=trailing_dims, keepdim=True)

    # Scale the guided prediction so its std matches the text-conditioned one
    # (fixes overexposure).
    std_ratio = _std_over_trailing_dims(noise_pred_text) / _std_over_trailing_dims(noise_cfg)
    std_matched = noise_cfg * std_ratio
    # Blend the std-matched and original guided predictions by guidance_rescale.
    return guidance_rescale * std_matched + (1 - guidance_rescale) * noise_cfg
fastvideo.pipelines.stages.denoising.DenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Register the input checks for the denoising stage.

    Args:
        batch: Batch whose fields are validated.
        fastvideo_args: Not read by this method.

    Returns:
        A ``VerificationResult`` with one check registered per inspected field.
    """
    outcome = VerificationResult()
    # (field name, value, validator(s)) triples, registered in this order.
    field_checks = (
        ("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)]),
        ("latents", batch.latents, [V.is_tensor, V.with_dims(5)]),
        ("prompt_embeds", batch.prompt_embeds, V.list_not_empty),
        ("image_embeds", batch.image_embeds, V.is_list),
        ("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5)),
        ("num_inference_steps", batch.num_inference_steps, V.positive_int),
        ("guidance_scale", batch.guidance_scale, V.positive_float),
        ("eta", batch.eta, V.non_negative_float),
        ("generator", batch.generator, V.generator_or_list_generators),
        ("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value),
        # Negative embeddings only need to be a non-empty list when CFG is enabled.
        ("negative_prompt_embeds", batch.negative_prompt_embeds,
         lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x)),
    )
    for field_name, field_value, validators in field_checks:
        outcome.add_check(field_name, field_value, validators)
    return outcome
fastvideo.pipelines.stages.denoising.DenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Register the output check for the denoising stage.

    Args:
        batch: Batch whose ``latents`` field is validated.
        fastvideo_args: Not read by this method.

    Returns:
        A ``VerificationResult`` holding the single ``latents`` check.
    """
    outcome = VerificationResult()
    # Latents must be a tensor and satisfy ``V.with_dims(5)``.
    latent_validators = [V.is_tensor, V.with_dims(5)]
    outcome.add_check("latents", batch.latents, latent_validators)
    return outcome

fastvideo.pipelines.stages.denoising.DmdDenoisingStage

DmdDenoisingStage(transformer, scheduler)

Bases: DenoisingStage

Denoising stage for DMD.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    """Initialise the parent stage, then install a fixed flow-match scheduler.

    Args:
        transformer: Passed through to the parent constructor.
        scheduler: Passed through to the parent constructor; the resulting
            ``self.scheduler`` is replaced immediately afterwards.
    """
    super().__init__(transformer=transformer, scheduler=scheduler)
    # The scheduler argument is superseded by a FlowMatchEulerDiscreteScheduler
    # constructed with shift=8.0.
    dmd_scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
    self.scheduler = dmd_scheduler

Functions

fastvideo.pipelines.stages.denoising.DmdDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the DMD denoising loop.

    For every entry of ``fastvideo_args.pipeline_config.dmd_denoising_steps``
    the transformer is called once, its output is converted to a predicted
    video with ``pred_noise_to_pred_video`` and, on every step but the last,
    that prediction is re-noised to the next timestep with
    ``self.scheduler.add_noise``. The final latents are written back to
    ``batch.latents`` with dims 1 and 2 swapped.

    Args:
        batch: The current batch information. ``batch.timesteps``,
            ``batch.latents`` and ``batch.prompt_embeds`` must be set.
            ``batch.generator`` may be a single generator or a list/tuple of
            generators; only the first one is used for re-noising.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.

    Raises:
        ValueError: If ``batch.timesteps`` is None.
    """
    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # batch.timesteps is only used for validation and the warmup-step count;
    # the loop itself iterates over the configured DMD steps (see below).
    timesteps = batch.timesteps

    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert torch.isnan(image_embeds[0]).sum() == 0
        image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Get latents and embeddings
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents

    video_raw_latent_shape = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"

    # From here on the loop runs over the configured DMD steps, not batch.timesteps.
    timesteps = torch.tensor(fastvideo_args.pipeline_config.dmd_denoising_steps,
                             dtype=torch.long,
                             device=get_local_torch_device())

    # Run denoising loop
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue
            # Keep the latents fed into this step: pred_noise_to_pred_video
            # needs them together with the predicted noise.
            noise_latents = latents.clone()
            latent_model_input = latents.to(target_dtype)

            # For I2V, append the (permuted) image latent along dim 2.
            if batch.image_latent is not None:
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent.permute(0, 2, 1, 3, 4)], dim=2).to(target_dtype)
            assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"

            # Prepare inputs for transformer
            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (torch.tensor(
                [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                dtype=torch.float32,
                device=get_local_torch_device(),
            ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                if (vsa_available and self.attn_backend == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls()
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer; input and output have dims 1 and 2 swapped
                    # relative to the layout used by the rest of this loop.
                    pred_noise = self.transformer(
                        latent_model_input.permute(0, 2, 1, 3, 4),
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                # The helper is called on tensors with the first two dims
                # flattened; restore them afterwards.
                pred_video = pred_noise_to_pred_video(pred_noise=pred_noise.flatten(0, 1),
                                                      noise_input_latent=noise_latents.flatten(0, 1),
                                                      timestep=t_expand,
                                                      scheduler=self.scheduler).unflatten(0, pred_noise.shape[:2])

                if i < len(timesteps) - 1:
                    # Re-noise the prediction to the next DMD timestep.
                    next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device)
                    # batch.generator may be a single generator or a list of
                    # generators; indexing a bare generator would raise, so
                    # only index when it is actually a sequence.
                    noise_generator = batch.generator
                    if isinstance(noise_generator, (list, tuple)):
                        noise_generator = noise_generator[0]
                    noise = torch.randn(video_raw_latent_shape,
                                        dtype=pred_video.dtype,
                                        generator=noise_generator).to(self.device)
                    latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1),
                                                       next_timestep).unflatten(0, pred_video.shape[:2])
                else:
                    latents = pred_video

                # Update progress bar. The None guard covers both conditions so
                # the last step cannot call update() on a missing progress bar.
                if progress_bar is not None and (i == len(timesteps) - 1 or
                                                 ((i + 1) > num_warmup_steps and
                                                  (i + 1) % self.scheduler.order == 0)):
                    progress_bar.update()

    # Swap dims 1 and 2 of the final latents before storing them.
    # NOTE(review): no sequence-parallel gather is performed here.
    latents = latents.permute(0, 2, 1, 3, 4)
    # Update batch with final latents
    batch.latents = latents

    return batch

Functions