Skip to content

Source: examples/inference/optimizations

Optimization Examples

python examples/inference/optimizations/attention_example.py

Additional Files

attention_example.py
import os
import time

from fastvideo import VideoGenerator

def main():
    """Load Wan2.1-T2V-1.3B with the FLASH_ATTN backend and time one generation.

    Prints the model-loading time, the video-generation time and the total
    time. The video is written under ``example_outputs/``. The generator is
    always shut down, even if generation raises.

    NOTE(review): ``generate_video`` is described as deprecated in favor of
    ``generate(request)`` in torch_compile_example.py; consider migrating.
    """
    # Select the attention backend. Set before from_pretrained() so it is in
    # place when the model is constructed (presumably read at load time —
    # confirm against fastvideo's env handling).
    os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "FLASH_ATTN"

    start_time = time.perf_counter()
    gen = VideoGenerator.from_pretrained(
        model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        num_gpus=1,
        dit_cpu_offload=False,
        vae_cpu_offload=False,
        text_encoder_cpu_offload=True,
        # Set to False if CPU RAM is low or you hit an obscure
        # "CUDA error: Invalid argument".
        pin_cpu_memory=True,
    )
    load_time = time.perf_counter() - start_time
    print(f"Model loading time: {load_time:.2f} seconds")

    try:
        gen_start_time = time.perf_counter()

        gen.generate_video(
            prompt=
            "Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting.",
            seed=1024,
            output_path="example_outputs/")

        generation_time = time.perf_counter() - gen_start_time
        print(f"Video generation time: {generation_time:.2f} seconds")

        total_time = time.perf_counter() - start_time
        print(f"Total execution time: {total_time:.2f} seconds")
    finally:
        # Release the generator even if generation fails, as the other
        # optimization examples do.
        gen.shutdown()

if __name__ == "__main__":
    main()
fp4_attn_wan2_1_1_3b.py
"""FP4 Flash Attention 4 inference example on Blackwell GPUs.

Quantizes Q and K to NVFP4 E2M1 with per-block E4M3 scale factors,
achieving up to a 1.39x attention-kernel speedup over BF16 FA4.

Requirements:
    - Blackwell GPU (B200/B300, sm100a/sm103a)
    - flash-attention-fp4, cutlass-dsl, flashinfer
    - See docs/inference/optimizations.md for installation

Usage:
    python fp4_attn_wan2_1_1_3b.py --nvfp4_fa4
    python fp4_attn_wan2_1_1_3b.py  # BF16 baseline
"""

import argparse
import os
import time

from fastvideo import VideoGenerator

OUTPUT_PATH = "video_samples"


def main():
    """Benchmark Wan video generation with NVFP4 FA4 attention or the BF16 baseline.

    Command-line flags:
        --nvfp4_fa4    Enable NVFP4 FP4-quantized QK flash attention. When
                       set, FSDP inference is disabled
                       (``use_fsdp_inference=not args.nvfp4_fa4``).
        --model        Model path or HuggingFace ID.
        --compile      Enable torch.compile for the DiT.
        --num_gpus     Number of GPUs (default 1).
        --infer_steps  Denoising steps for the timed run (default 50; must
                       be positive).

    Runs unsaved 2-step warmup generations, then one timed generation whose
    video is saved to ``OUTPUT_PATH/raccoon_<mode>.mp4``. The generator is
    always shut down, even if a generation raises.
    """
    parser = argparse.ArgumentParser(description="FP4 FA4 video generation benchmark")
    parser.add_argument("--nvfp4_fa4", action="store_true",
                        help="Enable NVFP4 FP4 quantized QK flash attention")
    parser.add_argument("--model", default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
                        help="Model path or HuggingFace ID")
    parser.add_argument("--compile", action="store_true",
                        help="Enable torch.compile for DIT")
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--infer_steps", type=int, default=50)
    args = parser.parse_args()
    if args.infer_steps <= 0:
        # A zero/negative step count would produce a meaningless benchmark.
        parser.error("--infer_steps must be a positive integer")

    mode = "nvfp4" if args.nvfp4_fa4 else "bf16"
    if args.compile:
        mode += "_compile"
    print(f"Mode: {mode.upper()}")

    generator = VideoGenerator.from_pretrained(
        args.model,
        num_gpus=args.num_gpus,
        nvfp4_fa4=args.nvfp4_fa4,
        use_fsdp_inference=not args.nvfp4_fa4,
        dit_cpu_offload=False,
        dit_layerwise_offload=False,
        vae_cpu_offload=True,
        text_encoder_cpu_offload=True,
        enable_torch_compile=args.compile,
    )

    prompt = (
        "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes "
        "wide with interest. The playful yet serene atmosphere is complemented by soft "
        "natural light filtering through the petals. Mid-shot, warm and cheerful tones."
    )

    try:
        # Untimed warmup. Compile mode gets an extra pass (presumably so
        # compilation has settled before timing — confirm).
        n_warmup = 2 if args.compile else 1
        for _ in range(n_warmup):
            generator.generate(request={"prompt": prompt, "sampling": {"num_inference_steps": 2},
                                        "output": {"save_video": False}})

        os.makedirs(OUTPUT_PATH, exist_ok=True)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        generator.generate(request={
            "prompt": prompt,
            "sampling": {"num_inference_steps": args.infer_steps},
            "output": {"save_video": True, "output_path": os.path.join(OUTPUT_PATH, f"raccoon_{mode}.mp4")},
        })
        elapsed = time.perf_counter() - start
        # elapsed covers the whole generate() call, including saving the
        # video, so this it/s is an end-to-end rate, not a pure step rate.
        print(f"[{mode.upper()}] {args.infer_steps} steps in {elapsed:.2f}s "
              f"({args.infer_steps / elapsed:.2f} it/s)")
    finally:
        generator.shutdown()


if __name__ == "__main__":
    main()
text_encoder_quant_example.py
from fastvideo import VideoGenerator
import argparse

OUTPUT_PATH = "video_samples_wan2_2_5B_ti2v"


def main(text_encoder_path: str):
    """Generate a Wan2.2 TI2V-5B video using an FP8-quantized text encoder.

    Args:
        text_encoder_path: Path to the quantized text-encoder safetensors
            file (for Wan 2.2, ``umt5_xxl_fp8_e4m3fn_scaled.safetensors``).
            It is loaded with the ``AbsMaxFP8`` quantization override.

    The video is conditioned on a prompt and an input image fetched from a
    URL, and is saved under ``OUTPUT_PATH``. The generator is always shut
    down, even if generation raises.
    """
    # FastVideo will automatically use the optimal default arguments for the
    # model.
    # If a local path is provided, FastVideo will make a best effort
    # attempt to identify the optimal arguments.
    model_name = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
    generator = VideoGenerator.from_pretrained(
        model_name,
        # FastVideo will automatically handle distributed setup
        num_gpus=1,
        use_fsdp_inference=True,
        dit_cpu_offload=True,
        vae_cpu_offload=False,
        text_encoder_cpu_offload=False,
        # AbsMaxFP8 is the quantization method used by ComfyUI;
        # check fastvideo/layers/quantization/* for more quantization methods
        override_text_encoder_quant="AbsMaxFP8",
        # for Wan 2.2, this is the path to "umt5_xxl_fp8_e4m3fn_scaled.safetensors"
        override_text_encoder_safetensors=text_encoder_path,
        # Set to False if CPU RAM is low or you hit an obscure
        # "CUDA error: Invalid argument".
        pin_cpu_memory=True,
    )

    # I2V is triggered just by passing in an image_path argument
    prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
    image_path = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
    try:
        # The return value is not needed; the video is saved to OUTPUT_PATH.
        generator.generate_video(
            prompt, output_path=OUTPUT_PATH, save_video=True, image_path=image_path
        )
    finally:
        # Release the generator even on failure, as the other examples do.
        generator.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        required=True,
        help="Path to the quantized text encoder safetensors file.",
    )
    args = parser.parse_args()
    main(args.text_encoder_path)
torch_compile_example.py
"""torch.compile A/B example for FastVideo.

`enable_torch_compile=True` compiles the DiT submodules that declare
`_compile_conditions` for a substantial end-to-end speedup (e.g.
roughly 24% lower end-to-end latency for Wan2.1-T2V-1.3B on an A100).
It is off by default.

The first compiled generation pays a one-time graph-build cost; it
amortizes over later generations with the same input shapes. This
script does one un-measured warmup then a measured run so the reported
number is steady-state, not graph-build — measuring the warmup is the
most common way to wrongly conclude compile is slower.

Usage:
    # baseline (eager)
    python torch_compile_example.py
    # compiled
    python torch_compile_example.py --compile
"""

import argparse
import os
import time

from fastvideo import VideoGenerator

PROMPT = (
    "A high-definition video of a robotic arm welding a metal structure, "
    "bright sparks and smoke, industrial setting."
)


def main() -> None:
    """Time one warmup and one measured generation, eager or compiled.

    Reads ``--model``, ``--compile`` and ``--num_gpus`` from the command line
    and loads a ``VideoGenerator`` with ``enable_torch_compile`` taken from
    ``--compile``. It then times two generations of ``PROMPT`` with seed 1024.
    Only the second ("measured") run saves its video, to
    ``video_samples/torch_compile_measured.mp4``. The generator is shut down
    even if a generation raises.
    """
    cli = argparse.ArgumentParser(description="torch.compile A/B")
    cli.add_argument("--model", default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
    cli.add_argument("--compile", action="store_true",
                     help="Enable torch.compile for the DiT")
    cli.add_argument("--num_gpus", type=int, default=1)
    opts = cli.parse_args()

    use_compile = opts.compile
    label = "COMPILE" if use_compile else "BASELINE"
    print(f"Mode: {label}  (enable_torch_compile={use_compile})")

    os.makedirs("video_samples", exist_ok=True)

    generator = VideoGenerator.from_pretrained(
        opts.model,
        num_gpus=opts.num_gpus,
        enable_torch_compile=use_compile,
    )

    def _timed_generate(tag: str) -> float:
        """Run one generation and return its wall time in seconds.

        Only the ``"measured"`` run writes a video file.
        """
        output_cfg: dict = {"save_video": tag == "measured"}
        if output_cfg["save_video"]:
            output_cfg["output_path"] = (
                f"video_samples/torch_compile_{tag}.mp4")
        # Typed-request API (generate_video is deprecated). Every call uses
        # the same prompt/seed/shapes so a compiled graph gets reused.
        req: dict = {
            "prompt": PROMPT,
            "sampling": {"seed": 1024},
            "output": output_cfg,
        }
        started = time.perf_counter()
        generator.generate(req)
        return time.perf_counter() - started

    try:
        # Warmup absorbs the one-time graph build under --compile; its
        # number is reported but not treated as the result.
        warmup_s = _timed_generate("warmup")
        note = "incl. graph build" if use_compile else "cold start"
        print(f"warmup: {warmup_s:.2f}s ({note})")

        # Steady state: the compiled graph is reused (same shapes/seed).
        measured_s = _timed_generate("measured")
        print(f"=== {label} steady-state e2e: {measured_s:.2f}s ===")
    finally:
        generator.shutdown()


if __name__ == "__main__":
    main()