Skip to content

causal_denoising

Classes

fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage

CausalDMDDenosingStage(transformer, scheduler, transformer_2=None, vae=None)

Bases: DenoisingStage

Denoising stage for causal diffusion.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self, transformer, scheduler, transformer_2=None, vae=None) -> None:
    """Build the stage from its transformer(s), scheduler and optional VAE.

    Args:
        transformer: Model exposing ``blocks`` and ``config.arch_config``.
        scheduler: Forwarded to the parent constructor.
        transformer_2: Optional second transformer, forwarded to the parent.
        vae: Optional VAE; stored on the instance, not passed to the parent.
    """
    super().__init__(transformer, scheduler, transformer_2)
    self.transformer, self.transformer_2, self.vae = transformer, transformer_2, vae

    # Sizes read from the transformer and its architecture config.
    self.num_transformer_blocks = len(self.transformer.blocks)
    arch_config = self.transformer.config.arch_config
    self.num_frames_per_block = arch_config.num_frames_per_block
    self.sliding_window_num_frames = arch_config.sliding_window_num_frames

    # Best effort: any failure while reading the attribute falls back to -1,
    # the same value used when the attribute is simply absent.
    try:
        wrapped_model = self.transformer.model  # type: ignore
        attn_size = getattr(wrapped_model, "local_attn_size", -1)
    except Exception:
        attn_size = -1
    self.local_attn_size = attn_size

Functions

fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Validate the batch fields consumed by the denoising stage.

    Args:
        batch: Forward batch whose fields are checked.
        fastvideo_args: Not consulted by these checks.

    Returns:
        A VerificationResult holding one recorded check per field.
    """
    outcome = VerificationResult()
    # (field name, value, validator(s)) in the order the checks are recorded.
    checks = (
        ("latents", batch.latents, [V.is_tensor, V.with_dims(5)]),
        ("prompt_embeds", batch.prompt_embeds, V.list_not_empty),
        ("image_embeds", batch.image_embeds, V.is_list),
        ("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5)),
        ("num_inference_steps", batch.num_inference_steps, V.positive_int),
        ("guidance_scale", batch.guidance_scale, V.positive_float),
        ("eta", batch.eta, V.non_negative_float),
        ("generator", batch.generator, V.generator_or_list_generators),
        ("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value),
        # Negative prompt embeddings only need to be non-empty when
        # classifier-free guidance is enabled.
        ("negative_prompt_embeds", batch.negative_prompt_embeds,
         lambda embeds: not batch.do_classifier_free_guidance or V.list_not_empty(embeds)),
    )
    for field_name, field_value, validators in checks:
        outcome.add_check(field_name, field_value, validators)
    return outcome

fastvideo.pipelines.stages.causal_denoising.CausalDenoisingStage

CausalDenoisingStage(transformer, scheduler, transformer_2=None, vae=None)

Bases: CausalDMDDenosingStage

Causal block-by-block denoising with standard multi-step flow matching (scheduler.step), not DMD few-step.

Each block is fully denoised through all scheduler timesteps before moving to the next block. After each block is denoised, the KV cache is updated with clean context so subsequent blocks can attend to prior clean frames.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self, transformer, scheduler, transformer_2=None, vae=None) -> None:
    """Build the stage from its transformer(s), scheduler and optional VAE.

    Args:
        transformer: Model exposing ``blocks`` and ``config.arch_config``.
        scheduler: Forwarded to the parent constructor.
        transformer_2: Optional second transformer, forwarded to the parent.
        vae: Optional VAE; stored on the instance, not passed to the parent.
    """
    super().__init__(transformer, scheduler, transformer_2)
    self.transformer, self.transformer_2, self.vae = transformer, transformer_2, vae

    # Sizes read from the transformer and its architecture config.
    self.num_transformer_blocks = len(self.transformer.blocks)
    arch_config = self.transformer.config.arch_config
    self.num_frames_per_block = arch_config.num_frames_per_block
    self.sliding_window_num_frames = arch_config.sliding_window_num_frames

    # Best effort: any failure while reading the attribute falls back to -1,
    # the same value used when the attribute is simply absent.
    try:
        wrapped_model = self.transformer.model  # type: ignore
        attn_size = getattr(wrapped_model, "local_attn_size", -1)
    except Exception:
        attn_size = -1
    self.local_attn_size = attn_size

Functions