basic

Basic inference pipelines for fastvideo.

This package contains basic pipelines for video and image generation.

Modules

fastvideo.pipelines.basic.cosmos

Modules

fastvideo.pipelines.basic.cosmos.cosmos2_5_pipeline

Cosmos 2.5 pipeline entry (staged pipeline).

Classes
fastvideo.pipelines.basic.cosmos.cosmos2_5_pipeline.Cosmos2_5Pipeline
Cosmos2_5Pipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Cosmos 2.5 video generation pipeline.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.cosmos.cosmos_pipeline

Cosmos video diffusion pipeline implementation.

This module contains an implementation of the Cosmos video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.cosmos.cosmos_pipeline.Cosmos2VideoToWorldPipeline
Cosmos2VideoToWorldPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Source code in fastvideo/pipelines/composed_pipeline_base.py (inherited __init__; identical to the listing under Cosmos2_5Pipeline above)
Functions
fastvideo.pipelines.basic.cosmos.cosmos_pipeline.Cosmos2VideoToWorldPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=CosmosLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                      transformer=self.get_module("transformer"),
                                                      vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=CosmosDenoisingStage(transformer=self.get_module("transformer"),
                                              scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.cosmos.presets

Cosmos model family pipeline presets.

Covers both Cosmos Predict2 and Cosmos Predict2.5, which share the same pipeline directory but are distinct model families.

Classes

fastvideo.pipelines.basic.gamecraft

HunyuanGameCraft pipeline implementations.

Modules

fastvideo.pipelines.basic.gamecraft.gamecraft_pipeline

HunyuanGameCraft video diffusion pipeline implementation.

This module implements the HunyuanGameCraft pipeline for camera/action-conditioned video generation with the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.gamecraft.gamecraft_pipeline.HunyuanGameCraftPipeline
HunyuanGameCraftPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Pipeline for HunyuanGameCraft video generation.

This pipeline supports:

- Text-to-video generation with camera/action conditioning
- Autoregressive generation with history frames
- 33-channel input (16 latent + 16 gt_latent + 1 mask)
- CameraNet for encoding Plücker coordinates

Source code in fastvideo/pipelines/composed_pipeline_base.py (inherited __init__; identical to the listing under Cosmos2_5Pipeline above)
Functions
fastvideo.pipelines.basic.gamecraft.gamecraft_pipeline.HunyuanGameCraftPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/gamecraft/gamecraft_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(
        stage_name="input_validation_stage",
        stage=InputValidationStage(),
    )

    self.add_stage(
        stage_name="prompt_encoding_stage_primary",
        stage=TextEncodingStage(
            text_encoders=[
                self.get_module("text_encoder"),
                self.get_module("text_encoder_2"),
            ],
            tokenizers=[
                self.get_module("tokenizer"),
                self.get_module("tokenizer_2"),
            ],
        ),
    )

    self.add_stage(
        stage_name="conditioning_stage",
        stage=ConditioningStage(),
    )

    self.add_stage(
        stage_name="timestep_preparation_stage",
        stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
    )

    self.add_stage(
        stage_name="latent_preparation_stage",
        stage=LatentPreparationStage(
            scheduler=self.get_module("scheduler"),
            transformer=self.get_module("transformer"),
        ),
    )

    self.add_stage(
        stage_name="denoising_stage",
        stage=GameCraftDenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.get_module("scheduler"),
        ),
    )

    self.add_stage(
        stage_name="decoding_stage",
        stage=DecodingStage(vae=self.get_module("vae")),
    )
Functions
fastvideo.pipelines.basic.gamecraft.presets

HunyuanGameCraft model family pipeline presets.

Classes

fastvideo.pipelines.basic.gen3c

GEN3C is a 3D-informed world-consistent video generation model with precise camera control.

Classes

fastvideo.pipelines.basic.gen3c.Cache3DBase
Cache3DBase(input_image: Tensor, input_depth: Tensor, input_w2c: Tensor, input_intrinsics: Tensor, input_mask: Tensor | None = None, input_format: list[str] | None = None, input_points: Tensor | None = None, weight_dtype: dtype = float32, is_depth: bool = True, device: str = 'cuda', filter_points_threshold: float = 1.0)

Base class for 3D cache management.

The cache maintains:

- input_image: RGB images stored in the cache
- input_points: 3D world coordinates for each pixel
- input_mask: Validity mask for each pixel

Initialize the 3D cache.

Parameters:

input_image (Tensor, required): Input image tensor with varying dimensions
input_depth (Tensor, required): Depth map tensor
input_w2c (Tensor, required): World-to-camera transformation matrix
input_intrinsics (Tensor, required): Camera intrinsic matrix
input_mask (Tensor | None, default None): Optional validity mask
input_format (list[str] | None, default None): Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])
input_points (Tensor | None, default None): Pre-computed 3D world points (alternative to depth)
weight_dtype (dtype, default float32): Data type for computations
is_depth (bool, default True): If True, input_depth is z-depth; if False, it is distance
device (str, default 'cuda'): Computation device
filter_points_threshold (float, default 1.0): Threshold for filtering unreliable depth
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    input_image: torch.Tensor,
    input_depth: torch.Tensor,
    input_w2c: torch.Tensor,
    input_intrinsics: torch.Tensor,
    input_mask: torch.Tensor | None = None,
    input_format: list[str] | None = None,
    input_points: torch.Tensor | None = None,
    weight_dtype: torch.dtype = torch.float32,
    is_depth: bool = True,
    device: str = "cuda",
    filter_points_threshold: float = 1.0,
):
    """
    Initialize the 3D cache.

    Args:
        input_image: Input image tensor with varying dimensions
        input_depth: Depth map tensor
        input_w2c: World-to-camera transformation matrix
        input_intrinsics: Camera intrinsic matrix
        input_mask: Optional validity mask
        input_format: Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])
        input_points: Pre-computed 3D world points (alternative to depth)
        weight_dtype: Data type for computations
        is_depth: If True, input_depth is z-depth; if False, it's distance
        device: Computation device
        filter_points_threshold: Threshold for filtering unreliable depth
    """
    self.weight_dtype = weight_dtype
    self.is_depth = is_depth
    self.device = device
    self.filter_points_threshold = filter_points_threshold

    if input_format is None:
        assert input_image.dim() == 4
        input_format = ["B", "C", "H", "W"]

    # Map dimension names to indices
    format_to_indices = {dim: idx for idx, dim in enumerate(input_format)}
    input_shape = input_image.shape

    if input_mask is not None:
        input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C"))

    # Extract dimensions
    B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1
    F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1
    N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1
    V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1
    H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None
    W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None

    # Reorder dimensions to B x F x N x V x C x H x W
    desired_dims = ["B", "F", "N", "V", "C", "H", "W"]
    permute_order: list[int | None] = []
    for dim in desired_dims:
        idx = format_to_indices.get(dim)
        permute_order.append(idx)

    permute_indices = [idx for idx in permute_order if idx is not None]
    input_image = input_image.permute(*permute_indices)

    for i, idx in enumerate(permute_order):
        if idx is None:
            input_image = input_image.unsqueeze(i)

    # Now input_image has shape B x F x N x V x C x H x W
    if input_mask is not None:
        self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:]
        self.input_mask = self.input_mask.to("cpu")
    else:
        self.input_mask = None
        self.input_image = input_image
    self.input_image = self.input_image.to(weight_dtype).to("cpu")

    # Compute 3D world points
    if input_points is not None:
        self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu")
        self.input_depth = None
    else:
        input_depth = torch.nan_to_num(input_depth, nan=100)
        input_depth = torch.clamp(input_depth, min=0, max=100)
        if weight_dtype == torch.float16:
            input_depth = torch.clamp(input_depth, max=70)

        self.input_points = (unproject_points(
            input_depth.reshape(-1, 1, H, W),
            input_w2c.reshape(-1, 4, 4),
            input_intrinsics.reshape(-1, 3, 3),
            is_depth=self.is_depth,
        ).to(weight_dtype).reshape(B, F, N, V, H, W, 3).to("cpu"))
        self.input_depth = input_depth

    # Filter unreliable depth
    if self.filter_points_threshold < 1.0 and input_depth is not None:
        input_depth = input_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(input_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(
                                                         B, F, N, V, 1, H, W)
        if self.input_mask is None:
            self.input_mask = depth_mask.to("cpu")
        else:
            self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device)
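The dimension normalization above can be hard to follow. This standalone sketch (assuming only torch, with illustrative shapes) reproduces how an input in ['B', 'C', 'H', 'W'] format is permuted and unsqueezed into the canonical B x F x N x V x C x H x W layout:

import torch

img = torch.randn(2, 3, 8, 8)              # input_format = ["B", "C", "H", "W"]
fmt = ["B", "C", "H", "W"]
idx = {d: i for i, d in enumerate(fmt)}
desired = ["B", "F", "N", "V", "C", "H", "W"]
order = [idx.get(d) for d in desired]      # [0, None, None, None, 1, 2, 3]
img = img.permute(*[i for i in order if i is not None])
for pos, i in enumerate(order):
    if i is None:
        img = img.unsqueeze(pos)           # insert singleton F, N, V axes
print(img.shape)                           # torch.Size([2, 1, 1, 1, 3, 8, 8])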
Functions
fastvideo.pipelines.basic.gen3c.Cache3DBase.input_frame_count
input_frame_count() -> int

Return the number of frames in the cache.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def input_frame_count(self) -> int:
    """Return the number of frames in the cache."""
    return self.input_image.shape[1]
fastvideo.pipelines.basic.gen3c.Cache3DBase.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cached 3D points from new camera viewpoints.

Parameters:

target_w2cs (Tensor, required): (b, F_target, 4, 4) target camera transformations
target_intrinsics (Tensor, required): (b, F_target, 3, 3) target camera intrinsics
render_depth (bool, default False): If True, return depth instead of RGB
start_frame_idx (int, default 0): Starting frame index in the cache

Returns:

pixels (Tensor): (b, F_target, N, c, h, w) rendered images or depth
masks (Tensor): (b, F_target, N, 1, h, w) validity masks
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cached 3D points from new camera viewpoints.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index in the cache

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images or depth
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    bs, F_target, _, _ = target_w2cs.shape
    B, F, N, V, C, H, W = self.input_image.shape
    assert bs == B

    target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4)
    target_intrinsics = target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3,
                                                                               3).reshape(-1, 3, 3)

    # Prepare inputs
    first_images = rearrange(
        self.input_image[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, C, H, W),
        "B F N V C H W -> (B F N) V C H W")
    first_points = rearrange(
        self.input_points[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, H, W, 3),
        "B F N V H W C -> (B F N) V H W C")
    first_masks = rearrange(
        self.input_mask[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, 1, H, W),
        "B F N V C H W -> (B F N) V C H W") if self.input_mask is not None else None

    # Process in chunks for memory efficiency
    if first_images.shape[1] == 1:
        warp_chunk_size = 2
        rendered_warp_images = []
        rendered_warp_masks = []
        rendered_warp_depth = []

        first_images = first_images.squeeze(1)
        first_points = first_points.squeeze(1)
        first_masks = first_masks.squeeze(1) if first_masks is not None else None

        for i in range(0, first_images.shape[0], warp_chunk_size):
            with torch.no_grad():
                imgs_chunk = first_images[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                pts_chunk = first_points[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                masks_chunk = (first_masks[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                               if first_masks is not None else None)

                (
                    rendered_warp_images_chunk,
                    rendered_warp_masks_chunk,
                    rendered_warp_depth_chunk,
                    _,
                ) = forward_warp(
                    imgs_chunk,
                    mask1=masks_chunk,
                    depth1=None,
                    transformation1=None,
                    transformation2=target_w2cs[i:i + warp_chunk_size],
                    intrinsic1=target_intrinsics[i:i + warp_chunk_size],
                    intrinsic2=target_intrinsics[i:i + warp_chunk_size],
                    render_depth=render_depth,
                    world_points1=pts_chunk,
                )

                rendered_warp_images.append(rendered_warp_images_chunk.to("cpu"))
                rendered_warp_masks.append(rendered_warp_masks_chunk.to("cpu"))
                if render_depth:
                    rendered_warp_depth.append(rendered_warp_depth_chunk.to("cpu"))

                del imgs_chunk, pts_chunk, masks_chunk
                torch.cuda.empty_cache()

        rendered_warp_images = torch.cat(rendered_warp_images, dim=0)
        rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0)
        if render_depth:
            rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0)
    else:
        raise NotImplementedError("Multi-view rendering not yet supported")

    pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
    masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)

    if render_depth:
        pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N)

    return pixels.to(self.device), masks.to(self.device)
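render_cache folds the frame and buffer axes into the batch axis before warping each chunk, then unfolds them afterwards. A small sketch of that round trip using einops (shapes are illustrative):

import torch
from einops import rearrange

B, F, N, C, H, W = 1, 4, 2, 3, 8, 8
x = torch.randn(B, F, N, C, H, W)
flat = rearrange(x, "b f n c h w -> (b f n) c h w")   # fold for per-chunk warping
back = rearrange(flat, "(b f n) c h w -> b f n c h w", b=B, f=F, n=N)
assert torch.equal(x, back)                           # lossless round trip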
fastvideo.pipelines.basic.gen3c.Cache3DBase.update_cache
update_cache(**kwargs)

Update the cache with new frames. To be implemented by subclasses.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(self, **kwargs):
    """Update the cache with new frames. To be implemented by subclasses."""
    raise NotImplementedError
fastvideo.pipelines.basic.gen3c.Cache3DBuffer
Cache3DBuffer(frame_buffer_max: int = 2, noise_aug_strength: float = 0.0, generator: Generator | None = None, **kwargs)

Bases: Cache3DBase

3D cache with frame buffer support.

This class manages multiple frame buffers for temporal consistency and supports noise augmentation for training stability.

Initialize the buffered 3D cache.

Parameters:

frame_buffer_max (int, default 2): Maximum number of frames to buffer
noise_aug_strength (float, default 0.0): Strength of noise augmentation per buffer
generator (Generator | None, default None): Random generator for reproducibility
**kwargs (default {}): Arguments passed to Cache3DBase
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    frame_buffer_max: int = 2,
    noise_aug_strength: float = 0.0,
    generator: torch.Generator | None = None,
    **kwargs,
):
    """
    Initialize the buffered 3D cache.

    Args:
        frame_buffer_max: Maximum number of frames to buffer
        noise_aug_strength: Strength of noise augmentation per buffer
        generator: Random generator for reproducibility
        **kwargs: Arguments passed to Cache3DBase
    """
    super().__init__(**kwargs)
    self.frame_buffer_max = frame_buffer_max
    self.noise_aug_strength = noise_aug_strength
    self.generator = generator
Functions
fastvideo.pipelines.basic.gen3c.Cache3DBuffer.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cache with optional noise augmentation.

Parameters:

target_w2cs (Tensor, required): (b, F_target, 4, 4) target camera transformations
target_intrinsics (Tensor, required): (b, F_target, 3, 3) target camera intrinsics
render_depth (bool, default False): If True, return depth instead of RGB
start_frame_idx (int, default 0): Starting frame index (must be 0 for this class)

Returns:

pixels (Tensor): (b, F_target, N, c, h, w) rendered images
masks (Tensor): (b, F_target, N, 1, h, w) validity masks
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cache with optional noise augmentation.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index (must be 0 for this class)

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3DBuffer"

    output_device = target_w2cs.device
    target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device)
    target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device)

    pixels, masks = super().render_cache(target_w2cs, target_intrinsics, render_depth)

    pixels = pixels.to(output_device)
    masks = masks.to(output_device)

    # Apply noise augmentation (stronger for older buffers)
    if not render_depth and self.noise_aug_strength > 0:
        noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype)
        per_buffer_noise = (torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device) *
                            self.noise_aug_strength)
        pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1)

    return pixels, masks
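The augmentation ramp is linear in the buffer index: with N buffers and strength s, buffer i receives additive noise scaled by (N - 1 - i) * s. A quick check of that arithmetic in isolation:

import torch

N, s = 3, 0.1
per_buffer = torch.arange(start=N - 1, end=-1, step=-1) * s
print(per_buffer)   # tensor([0.2000, 0.1000, 0.0000])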
fastvideo.pipelines.basic.gen3c.Cache3DBuffer.update_cache
update_cache(new_image: Tensor, new_depth: Tensor, new_w2c: Tensor, new_mask: Tensor | None = None, new_intrinsics: Tensor | None = None)

Update the cache with a new frame.

Parameters:

new_image (Tensor, required): (B, C, H, W) new RGB image
new_depth (Tensor, required): (B, 1, H, W) new depth map
new_w2c (Tensor, required): (B, 4, 4) new world-to-camera transformation
new_mask (Tensor | None, default None): Optional (B, 1, H, W) validity mask
new_intrinsics (Tensor | None, default None): (B, 3, 3) camera intrinsics
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(
    self,
    new_image: torch.Tensor,
    new_depth: torch.Tensor,
    new_w2c: torch.Tensor,
    new_mask: torch.Tensor | None = None,
    new_intrinsics: torch.Tensor | None = None,
):
    """
    Update the cache with a new frame.

    Args:
        new_image: (B, C, H, W) new RGB image
        new_depth: (B, 1, H, W) new depth map
        new_w2c: (B, 4, 4) new world-to-camera transformation
        new_mask: Optional (B, 1, H, W) validity mask
        new_intrinsics: (B, 3, 3) camera intrinsics (optional)
    """
    new_image = new_image.to(self.weight_dtype).to(self.device)
    new_depth = new_depth.to(self.weight_dtype).to(self.device)
    new_w2c = new_w2c.to(self.weight_dtype).to(self.device)
    if new_intrinsics is not None:
        new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device)

    new_depth = torch.nan_to_num(new_depth, nan=1e4)
    new_depth = torch.clamp(new_depth, min=0, max=1e4)

    B, F, N, V, C, H, W = self.input_image.shape

    # Compute new 3D points
    new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu()
    new_image = new_image.cpu()

    if self.filter_points_threshold < 1.0:
        new_depth = new_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(new_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W)
        new_mask = depth_mask.to("cpu") if new_mask is None else new_mask * depth_mask.to(new_mask.device)
    if new_mask is not None:
        new_mask = new_mask.cpu()

    # Update buffer (newest frame first)
    if self.frame_buffer_max > 1:
        if self.input_image.shape[2] < self.frame_buffer_max:
            self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2)
            self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2)
            if self.input_mask is not None:
                self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2)
        else:
            self.input_image[:, :, 0] = new_image[:, None, None]
            self.input_points[:, :, 0] = new_points[:, None, None]
            if self.input_mask is not None:
                self.input_mask[:, :, 0] = new_mask[:, None, None]
    else:
        self.input_image = new_image[:, None, None, None]
        self.input_points = new_points[:, None, None, None]
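The buffer update above prepends the new frame along the buffer axis until frame_buffer_max is reached, then overwrites slot 0 in place. A minimal shape-level sketch of the at-capacity case (dummy tensors, not real frames):

import torch

input_image = torch.zeros(1, 1, 2, 1, 3, 4, 4)   # B F N V C H W, N == frame_buffer_max == 2
new_image = torch.ones(1, 3, 4, 4)               # B C H W

# At capacity, the slot at index 0 along N is overwritten in place:
input_image[:, :, 0] = new_image[:, None, None]  # broadcasts over F and V
print(input_image[0, 0, 0, 0, 0])                # a 4x4 block of ones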
fastvideo.pipelines.basic.gen3c.Gen3CConditioningStage
Gen3CConditioningStage(vae=None)

Bases: PipelineStage

3D cache conditioning stage for GEN3C.

This stage performs the core GEN3C innovation:

1. Loads the input image
2. Predicts depth via MoGe
3. Initializes a 3D point cloud cache
4. Generates a camera trajectory
5. Renders warped frames from the cache at each target camera pose
6. Stores rendered warps on the batch for VAE encoding in the latent prep stage

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, vae=None) -> None:
    super().__init__()
    self._moge_model: Any | None = None
    self._vae = vae
Functions
fastvideo.pipelines.basic.gen3c.Gen3CConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run 3D cache conditioning pipeline.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Run 3D cache conditioning pipeline."""
    pipeline_config = fastvideo_args.pipeline_config
    device = get_local_torch_device()
    batch_extra = getattr(batch, "extra", {}) or {}

    image_path = getattr(batch, 'image_path', None) or batch_extra.get("image_path")
    if image_path is None:
        logger.info("No image_path provided - skipping 3D cache conditioning "
                    "(will use zero conditioning)")
        return batch

    logger.info("Running 3D cache conditioning with image: %s", image_path)

    height = getattr(batch, 'height', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[0]
    width = getattr(batch, 'width', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[1]
    num_frames = getattr(batch, 'num_frames', None) or getattr(pipeline_config, 'num_frames', 121)

    trajectory_type = (getattr(batch, 'trajectory_type', None) or batch_extra.get("trajectory_type")
                       or getattr(pipeline_config, 'default_trajectory_type', 'left'))
    movement_distance = (getattr(batch, 'movement_distance', None) or batch_extra.get("movement_distance")
                         or getattr(pipeline_config, 'default_movement_distance', 0.3))
    camera_rotation = (getattr(batch, 'camera_rotation', None) or batch_extra.get("camera_rotation")
                       or getattr(pipeline_config, 'default_camera_rotation', 'center_facing'))

    frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2)
    noise_aug_strength = getattr(pipeline_config, 'noise_aug_strength', 0.0)
    filter_points_threshold = getattr(pipeline_config, 'filter_points_threshold', 0.05)

    moge_model_name = getattr(pipeline_config, 'moge_model_name', 'Ruicheng/moge-vitl')

    from fastvideo.pipelines.basic.gen3c.depth_estimation import (predict_depth_from_path)

    moge_model = self._get_moge_model(device, moge_model_name)

    (
        image_b1chw,
        depth_b11hw,
        mask_b11hw,
        w2c_b144,
        intrinsics_b133,
    ) = predict_depth_from_path(image_path, height, width, device, moge_model)

    logger.info(
        "Depth prediction complete. Depth range: [%.3f, %.3f]",
        depth_b11hw.min().item(),
        depth_b11hw.max().item(),
    )

    from fastvideo.pipelines.basic.gen3c.cache_3d import Cache3DBuffer

    seed = getattr(batch, 'seed', None)
    if seed is None:
        seed = 42
    generator = torch.Generator(device=device).manual_seed(seed)

    cache = Cache3DBuffer(
        frame_buffer_max=frame_buffer_max,
        generator=generator,
        noise_aug_strength=noise_aug_strength,
        input_image=image_b1chw[:, 0].clone(),
        input_depth=depth_b11hw[:, 0],
        input_w2c=w2c_b144[:, 0],
        input_intrinsics=intrinsics_b133[:, 0],
        filter_points_threshold=filter_points_threshold,
    )

    logger.info("3D cache initialized with %d frame buffer(s)", frame_buffer_max)

    from fastvideo.pipelines.basic.gen3c.camera_utils import (generate_camera_trajectory)

    initial_w2c = w2c_b144[0, 0]
    initial_intrinsics = intrinsics_b133[0, 0]

    generated_w2cs, generated_intrinsics = generate_camera_trajectory(
        trajectory_type=trajectory_type,
        initial_w2c=initial_w2c,
        initial_intrinsics=initial_intrinsics,
        num_frames=num_frames,
        movement_distance=movement_distance,
        camera_rotation=camera_rotation,
        center_depth=1.0,
        device=device.type if isinstance(device, torch.device) else device,
    )

    logger.info(
        "Camera trajectory generated: type=%s, frames=%d, distance=%.3f",
        trajectory_type,
        num_frames,
        movement_distance,
    )

    rendered_warp_images, rendered_warp_masks = cache.render_cache(
        generated_w2cs[:, :num_frames],
        generated_intrinsics[:, :num_frames],
    )

    logger.info(
        "Cache rendered. Warped images shape: %s, non-zero mask ratio: %.3f",
        list(rendered_warp_images.shape),
        (rendered_warp_masks > 0).float().mean().item(),
    )

    batch.rendered_warp_images = rendered_warp_images.to(device)
    batch.rendered_warp_masks = rendered_warp_masks.to(device)
    batch.input_image_conditioning = image_b1chw[:, 0].unsqueeze(2).contiguous().to(device)
    batch.cache_3d = cache

    if getattr(pipeline_config, "offload_moge_after_depth", True):
        self._offload_moge()

    return batch
fastvideo.pipelines.basic.gen3c.Gen3CDenoisingStage
Gen3CDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for GEN3C models.

This stage extends the base denoising stage with support for:

- condition_video_input_mask: Binary mask indicating conditioning frames
- condition_video_pose: VAE-encoded 3D cache buffers
- condition_video_augment_sigma: Noise augmentation sigma

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)
fastvideo.pipelines.basic.gen3c.Gen3CLatentPreparationStage
Gen3CLatentPreparationStage(scheduler, transformer, vae)

Bases: LatentPreparationStage

Latent preparation stage for GEN3C.

This stage prepares latents and encodes 3D cache buffers through the VAE. If rendered warped frames are available on the batch (from Gen3CConditioningStage), they are VAE-encoded to produce real conditioning; otherwise the stage falls back to zero conditioning.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, scheduler, transformer, vae) -> None:
    super().__init__(scheduler, transformer)
    self.vae = vae
Functions
fastvideo.pipelines.basic.gen3c.Gen3CLatentPreparationStage.encode_warped_frames
encode_warped_frames(condition_state: Tensor, condition_state_mask: Tensor, vae: Any, frame_buffer_max: int, dtype: dtype) -> Tensor

Encode rendered 3D cache buffers through VAE.

Parameters:

condition_state (Tensor, required): (B, T, N, 3, H, W) rendered RGB images in [-1, 1]
condition_state_mask (Tensor, required): (B, T, N, 1, H, W) rendered masks in [0, 1]
vae (Any, required): VAE encoder
frame_buffer_max (int, required): Maximum number of buffers
dtype (dtype, required): Target dtype

Returns:

latent_condition (Tensor): (B, buffer_channels, T_latent, H_latent, W_latent)

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def encode_warped_frames(
    self,
    condition_state: torch.Tensor,
    condition_state_mask: torch.Tensor,
    vae: Any,
    frame_buffer_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Encode rendered 3D cache buffers through VAE.

    Args:
        condition_state: (B, T, N, 3, H, W) rendered RGB images in [-1, 1].
        condition_state_mask: (B, T, N, 1, H, W) rendered masks in [0, 1].
        vae: VAE encoder.
        frame_buffer_max: Maximum number of buffers.
        dtype: Target dtype.

    Returns:
        latent_condition: (B, buffer_channels, T_latent, H_latent, W_latent)
    """
    assert condition_state.dim() == 6

    condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)

    latent_condition = []
    num_buffers = condition_state.shape[2]
    for i in range(num_buffers):
        img_input = condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
        mask_input = condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
        batched_input = torch.cat([img_input, mask_input], dim=0)
        batched_latent = self._retrieve_latents(vae.encode(batched_input)).contiguous()
        current_video_latent, current_mask_latent = batched_latent.chunk(2, dim=0)

        latent_condition.append(current_video_latent)
        latent_condition.append(current_mask_latent)

    for _ in range(frame_buffer_max - num_buffers):
        latent_condition.append(torch.zeros_like(current_video_latent))
        latent_condition.append(torch.zeros_like(current_mask_latent))

    return torch.cat(latent_condition, dim=1)
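Before encoding, the [0, 1] validity mask is rescaled to [-1, 1] and tiled to three channels so the VAE can treat it like an RGB input. That preprocessing step in isolation (illustrative shapes):

import torch

mask = torch.rand(1, 8, 2, 1, 32, 32)            # B T N 1 H W, values in [0, 1]
mask3 = (mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)  # now in [-1, 1] with 3 channels
print(mask3.shape)                               # torch.Size([1, 8, 2, 3, 32, 32])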
fastvideo.pipelines.basic.gen3c.Gen3CLatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare latents and encode 3D cache buffers.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Prepare latents and encode 3D cache buffers."""
    pipeline_config = fastvideo_args.pipeline_config
    device = get_local_torch_device()

    if isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]
    batch_size *= batch.num_videos_per_prompt

    num_channels_latents = getattr(self.transformer, 'num_channels_latents', 16)

    fallback_num_frames = getattr(pipeline_config, 'num_frames', 121)
    if not isinstance(fallback_num_frames, int):
        fallback_num_frames = 121

    num_frames_raw = getattr(batch, 'num_frames', None)
    if num_frames_raw is None:
        num_frames_raw = fallback_num_frames
    if isinstance(num_frames_raw, list):
        num_frames = int(num_frames_raw[0]) if len(num_frames_raw) > 0 else fallback_num_frames
    elif isinstance(num_frames_raw, int):
        num_frames = num_frames_raw
    else:
        num_frames = fallback_num_frames
    if hasattr(self.vae, "get_latent_num_frames"):
        latent_frames = int(self.vae.get_latent_num_frames(num_frames))
    else:
        temporal_ratio = getattr(
            pipeline_config.vae_config.arch_config,
            "temporal_compression_ratio",
            4,
        )
        latent_frames = int((num_frames - 1) // temporal_ratio + 1)
    height = getattr(batch, 'height', 720)
    width = getattr(batch, 'width', 1280)

    spatial_ratio = getattr(
        pipeline_config.vae_config.arch_config,
        "spatial_compression_ratio",
        8,
    )
    latent_height = height // spatial_ratio
    latent_width = width // spatial_ratio

    generator = getattr(batch, "generator", None)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(f"Expected {batch_size} generators, got {len(generator)}.")

    latents = randn_tensor(
        (
            batch_size,
            num_channels_latents,
            latent_frames,
            latent_height,
            latent_width,
        ),
        generator=generator,
        device=device,
        dtype=torch.float32,
    )

    if hasattr(self.scheduler, 'init_noise_sigma'):
        latents = latents * self.scheduler.init_noise_sigma

    batch.latents = latents
    batch.batch_size = batch_size
    batch.height = height
    batch.width = width
    batch.latent_height = latent_height
    batch.latent_width = latent_width
    batch.latent_frames = latent_frames
    batch.raw_latent_shape = latents.shape

    frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2)
    channels_per_buffer = 32
    buffer_channels = frame_buffer_max * channels_per_buffer

    rendered_warp_images = getattr(batch, 'rendered_warp_images', None)
    rendered_warp_masks = getattr(batch, 'rendered_warp_masks', None)

    if rendered_warp_images is not None and rendered_warp_masks is not None:
        logger.info(
            "Encoding rendered warped frames through VAE (%d buffers)...",
            rendered_warp_images.shape[2],
        )

        self.vae = self.vae.to(device)

        if hasattr(self.vae, 'module'):
            vae_dtype = next(self.vae.module.parameters()).dtype
        else:
            vae_dtype = next(self.vae.parameters()).dtype

        condition_video_pose = self.encode_warped_frames(
            rendered_warp_images,
            rendered_warp_masks,
            self.vae,
            frame_buffer_max,
            vae_dtype,
        )
        batch.condition_video_pose = condition_video_pose.to(device)

        logger.info(
            "condition_video_pose encoded. Shape: %s, non-zero: %.4f",
            list(batch.condition_video_pose.shape),
            (batch.condition_video_pose != 0).float().mean().item(),
        )

        source_image = getattr(batch, "input_image_conditioning", None)
        if source_image is None:
            source_image = rendered_warp_images[:, 0, 0].unsqueeze(2)
        first_frame = source_image.to(device=device, dtype=vae_dtype)
        first_latent = self._retrieve_latents(self.vae.encode(first_frame))
        conditioning_latents = torch.zeros(
            batch_size,
            num_channels_latents,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=first_latent.dtype,
        )
        conditioning_latents[:, :, :first_latent.shape[2], :, :] = first_latent
        batch.conditioning_latents = conditioning_latents

        if fastvideo_args.vae_cpu_offload:
            self.vae.to("cpu")

        batch.condition_video_input_mask = torch.zeros(
            batch_size,
            1,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.condition_video_input_mask[:, :, 0, :, :] = 1.0
    else:
        logger.info("No rendered warps available - using zero conditioning")
        batch.condition_video_pose = torch.zeros(
            batch_size,
            buffer_channels,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.condition_video_input_mask = torch.zeros(
            batch_size,
            1,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.conditioning_latents = None

    batch.condition_video_augment_sigma = torch.zeros(batch_size, device=device, dtype=torch.float32)
    batch.cond_indicator = torch.zeros(
        batch_size,
        1,
        latent_frames,
        latent_height,
        latent_width,
        device=device,
        dtype=torch.float32,
    )
    batch.cond_indicator[:, :, 0, :, :] = 1.0

    ones_padding = torch.ones_like(batch.cond_indicator)
    zeros_padding = torch.zeros_like(batch.cond_indicator)
    batch.cond_mask = batch.cond_indicator * ones_padding + (1 - batch.cond_indicator) * zeros_padding

    if batch.do_classifier_free_guidance:
        batch.uncond_indicator = batch.cond_indicator.clone()
        batch.uncond_mask = batch.cond_mask.clone()
    else:
        batch.uncond_indicator = None
        batch.uncond_mask = None

    return batch
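With the default compression ratios assumed in the fallbacks above (temporal 4, spatial 8), the latent dimensions for the default 121-frame, 720x1280 request work out as follows:

# Sketch of the latent-shape arithmetic above, assuming the default ratios.
num_frames, height, width = 121, 720, 1280
latent_frames = (num_frames - 1) // 4 + 1   # 31
latent_height = height // 8                 # 90
latent_width = width // 8                   # 160
print(latent_frames, latent_height, latent_width)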
fastvideo.pipelines.basic.gen3c.Gen3CPipeline
Gen3CPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

GEN3C Video Generation Pipeline.

This pipeline extends Cosmos with 3D cache support for camera-controlled video generation. When an input image is provided, it runs the full 3D cache conditioning pipeline (depth estimation -> point cloud -> camera trajectory -> forward warping -> VAE encoding).

Source code in fastvideo/pipelines/composed_pipeline_base.py (inherited __init__; identical to the listing under Cosmos2_5Pipeline above)
Functions
fastvideo.pipelines.basic.gen3c.Gen3CPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/gen3c/gen3c_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="cfg_policy_stage", stage=Gen3CCFGPolicyStage())

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=Gen3CConditioningStage(vae=self.get_module("vae")))

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=Gen3CLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                     transformer=self.get_module("transformer"),
                                                     vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=Gen3CDenoisingStage(transformer=self.get_module("transformer"),
                                             scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))

Functions

fastvideo.pipelines.basic.gen3c.forward_warp
forward_warp(frame1: Tensor, mask1: Tensor | None, depth1: Tensor | None, transformation1: Tensor | None, transformation2: Tensor, intrinsic1: Tensor | None, intrinsic2: Tensor | None, is_image: bool = True, is_depth: bool = True, render_depth: bool = False, world_points1: Tensor | None = None) -> tuple[Tensor, Tensor, Tensor | None, Tensor]

Forward warp frame1 to a new view defined by transformation2.

Parameters:

frame1 (Tensor, required): (b, c, h, w) source frame in range [-1, 1] for images
mask1 (Tensor | None, required): (b, 1, h, w) valid pixel mask
depth1 (Tensor | None, required): (b, 1, h, w) depth map (required if world_points1 is None)
transformation1 (Tensor | None, required): (b, 4, 4) source camera w2c (required if depth1 is provided)
transformation2 (Tensor, required): (b, 4, 4) target camera w2c
intrinsic1 (Tensor | None, required): (b, 3, 3) source camera intrinsics
intrinsic2 (Tensor | None, required): (b, 3, 3) target camera intrinsics
is_image (bool, default True): If True, output will be clipped to (-1, 1)
is_depth (bool, default True): If True, depth1 is z-depth; if False, it is distance
render_depth (bool, default False): If True, also return the warped depth map
world_points1 (Tensor | None, default None): (b, h, w, 3) pre-computed world points (alternative to depth1)

Returns:

warped_frame2 (Tensor): (b, c, h, w) warped frame
mask2 (Tensor): (b, 1, h, w) validity mask
warped_depth2 (Tensor | None): (b, h, w) warped depth (if render_depth=True)
flow12 (Tensor): (b, 2, h, w) optical flow

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def forward_warp(
    frame1: torch.Tensor,
    mask1: torch.Tensor | None,
    depth1: torch.Tensor | None,
    transformation1: torch.Tensor | None,
    transformation2: torch.Tensor,
    intrinsic1: torch.Tensor | None,
    intrinsic2: torch.Tensor | None,
    is_image: bool = True,
    is_depth: bool = True,
    render_depth: bool = False,
    world_points1: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """
    Forward warp frame1 to a new view defined by transformation2.

    Args:
        frame1: (b, c, h, w) source frame in range [-1, 1] for images
        mask1: (b, 1, h, w) valid pixel mask
        depth1: (b, 1, h, w) depth map (required if world_points1 is None)
        transformation1: (b, 4, 4) source camera w2c (required if depth1 is provided)
        transformation2: (b, 4, 4) target camera w2c
        intrinsic1: (b, 3, 3) source camera intrinsics
        intrinsic2: (b, 3, 3) target camera intrinsics
        is_image: If True, output will be clipped to (-1, 1)
        is_depth: If True, depth1 is z-depth; if False, it's distance
        render_depth: If True, also return the warped depth map
        world_points1: (b, h, w, 3) pre-computed world points (alternative to depth1)

    Returns:
        warped_frame2: (b, c, h, w) warped frame
        mask2: (b, 1, h, w) validity mask
        warped_depth2: (b, h, w) warped depth (if render_depth=True)
        flow12: (b, 2, h, w) optical flow
    """
    device = frame1.device
    b, c, h, w = frame1.shape
    dtype = frame1.dtype

    if mask1 is None:
        mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)
    if intrinsic2 is None:
        assert intrinsic1 is not None
        intrinsic2 = intrinsic1.clone()

    if world_points1 is not None:
        # Use pre-computed world points
        assert world_points1.shape == (b, h, w, 3)
        trans_points1 = project_points(world_points1, transformation2, intrinsic2)
    else:
        # Compute from depth
        assert depth1 is not None and transformation1 is not None
        assert depth1.shape == (b, 1, h, w)

        depth1 = torch.nan_to_num(depth1, nan=1e4)
        depth1 = torch.clamp(depth1, min=0, max=1e4)

        # Unproject to world, then project to target view
        world_points1 = unproject_points(depth1, transformation1, intrinsic1, is_depth=is_depth)
        trans_points1 = project_points(world_points1, transformation2, intrinsic2)

    # Filter points behind camera
    mask1 = mask1 * (trans_points1[:, :, :, 2, 0].unsqueeze(1) > 0)
    trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7)
    trans_coordinates = trans_coordinates.permute(0, 3, 1, 2)  # b, 2, h, w
    trans_depth1 = trans_points1[:, :, :, 2, 0].unsqueeze(1)

    grid = create_grid(b, h, w, device=device, dtype=dtype)
    flow12 = trans_coordinates - grid

    warped_frame2, mask2 = bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image)

    warped_depth2 = None
    if render_depth:
        warped_depth2 = bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0][:, 0]

    return warped_frame2, mask2, warped_depth2, flow12
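The optical flow is simply the perspective-divided target-view coordinates minus the identity pixel grid. A standalone sketch of that step; the (x, y) grid convention used here is an assumption standing in for create_grid, and the projected points are random stand-ins:

import torch

b, h, w = 1, 4, 4
proj = torch.randn(b, h, w, 3, 1).abs() + 0.1   # stand-in homogeneous (x, y, z) per pixel
coords = proj[:, :, :, :2, 0] / (proj[:, :, :, 2:3, 0] + 1e-7)  # perspective divide
coords = coords.permute(0, 3, 1, 2)             # b, 2, h, w
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
grid = torch.stack([xs, ys]).float().expand(b, 2, h, w)         # assumed (x, y) order
flow = coords - grid                            # flow12 = projected - identity grid
print(flow.shape)                               # torch.Size([1, 2, 4, 4])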
fastvideo.pipelines.basic.gen3c.generate_camera_trajectory
generate_camera_trajectory(trajectory_type: str, initial_w2c: Tensor, initial_intrinsics: Tensor, num_frames: int, movement_distance: float, camera_rotation: str = 'center_facing', center_depth: float = 1.0, device: str = 'cuda') -> tuple[Tensor, Tensor]

Generate camera trajectory for GEN3C video generation.

Parameters:

trajectory_type (str, required): One of "left", "right", "up", "down", "zoom_in", "zoom_out", "clockwise", "counterclockwise"
initial_w2c (Tensor, required): Initial world-to-camera matrix (4, 4)
initial_intrinsics (Tensor, required): Camera intrinsics matrix (3, 3)
num_frames (int, required): Number of frames in the trajectory
movement_distance (float, required): Distance factor for camera movement
camera_rotation (str, default 'center_facing'): "center_facing", "no_rotation", or "trajectory_aligned"
center_depth (float, default 1.0): Depth of the scene center point
device (str, default 'cuda'): Computation device

Returns:

generated_w2cs (Tensor): (1, num_frames, 4, 4) world-to-camera matrices
generated_intrinsics (Tensor): (1, num_frames, 3, 3) camera intrinsics

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def generate_camera_trajectory(
    trajectory_type: str,
    initial_w2c: torch.Tensor,
    initial_intrinsics: torch.Tensor,
    num_frames: int,
    movement_distance: float,
    camera_rotation: str = "center_facing",
    center_depth: float = 1.0,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate camera trajectory for GEN3C video generation.

    Args:
        trajectory_type: One of "left", "right", "up", "down", "zoom_in",
            "zoom_out", "clockwise", "counterclockwise".
        initial_w2c: Initial world-to-camera matrix (4, 4).
        initial_intrinsics: Camera intrinsics matrix (3, 3).
        num_frames: Number of frames in the trajectory.
        movement_distance: Distance factor for camera movement.
        camera_rotation: "center_facing", "no_rotation", or "trajectory_aligned".
        center_depth: Depth of the scene center point.
        device: Computation device.

    Returns:
        generated_w2cs: (1, num_frames, 4, 4) world-to-camera matrices.
        generated_intrinsics: (1, num_frames, 3, 3) camera intrinsics.
    """
    if trajectory_type in ["clockwise", "counterclockwise"]:
        new_w2cs_seq = create_spiral_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=trajectory_type == "clockwise",
            device=device,
            camera_rotation=camera_rotation,
            radius_x=movement_distance,
            radius_y=movement_distance,
        )
    elif trajectory_type == "none":
        # Static camera - repeat identity
        new_w2cs_seq = initial_w2c.unsqueeze(0).expand(num_frames, -1, -1)
    else:
        axis_map = {
            "left": (False, "x"),
            "right": (True, "x"),
            "up": (False, "y"),
            "down": (True, "y"),
            "zoom_in": (True, "z"),
            "zoom_out": (False, "z"),
        }
        if trajectory_type not in axis_map:
            raise ValueError(f"Unsupported trajectory type: {trajectory_type}")
        positive, axis = axis_map[trajectory_type]

        new_w2cs_seq = create_horizontal_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=positive,
            axis=axis,
            distance=movement_distance,
            device=device,
            camera_rotation=camera_rotation,
        )

    generated_w2cs = new_w2cs_seq.unsqueeze(0)  # (1, num_frames, 4, 4)
    if initial_intrinsics.dim() == 2:
        generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1)
    else:
        generated_intrinsics = initial_intrinsics.unsqueeze(0)

    return generated_w2cs, generated_intrinsics
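Example:

A minimal usage sketch (hypothetical values, not from the source): generate a 13-frame clockwise orbit from an identity starting pose and a fixed pinhole intrinsics matrix.

import torch
from fastvideo.pipelines.basic.gen3c import generate_camera_trajectory

initial_w2c = torch.eye(4)                        # (4, 4) starting pose
intrinsics = torch.tensor([[500.0, 0.0, 320.0],
                           [0.0, 500.0, 240.0],
                           [0.0, 0.0, 1.0]])      # (3, 3) pinhole intrinsics (made-up values)

w2cs, Ks = generate_camera_trajectory(
    trajectory_type="clockwise",
    initial_w2c=initial_w2c,
    initial_intrinsics=intrinsics,
    num_frames=13,
    movement_distance=0.1,
    device="cpu",
)
assert w2cs.shape == (1, 13, 4, 4) and Ks.shape == (1, 13, 3, 3)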
fastvideo.pipelines.basic.gen3c.project_points
project_points(world_points: Tensor, w2c: Tensor, intrinsic: Tensor) -> Tensor

Project 3D world points to 2D pixel coordinates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| world_points | Tensor | (b, h, w, 3) 3D world coordinates | required |
| w2c | Tensor | (b, 4, 4) world-to-camera transformation matrix | required |
| intrinsic | Tensor | (b, 3, 3) camera intrinsic matrix | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| projected_points | Tensor | (b, h, w, 3, 1) projected 2D coordinates (x, y, z) |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def project_points(
    world_points: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
) -> torch.Tensor:
    """
    Project 3D world points to 2D pixel coordinates.

    Args:
        world_points: (b, h, w, 3) 3D world coordinates
        w2c: (b, 4, 4) world-to-camera transformation matrix
        intrinsic: (b, 3, 3) camera intrinsic matrix

    Returns:
        projected_points: (b, h, w, 3, 1) projected 2D coordinates (x, y, z)
    """
    world_points = world_points.unsqueeze(-1)  # (b, h, w, 3, 1)
    b, h, w, _, _ = world_points.shape

    ones_4d = torch.ones((b, h, w, 1, 1), device=world_points.device, dtype=world_points.dtype)
    world_points_homo = torch.cat([world_points, ones_4d], dim=3)  # (b, h, w, 4, 1)

    trans_4d = w2c[:, None, None]  # (b, 1, 1, 4, 4)
    camera_points_homo = torch.matmul(trans_4d, world_points_homo)  # (b, h, w, 4, 1)

    camera_points = camera_points_homo[:, :, :, :3]  # (b, h, w, 3, 1)
    intrinsic_4d = intrinsic[:, None, None]  # (b, 1, 1, 3, 3)
    projected_points = torch.matmul(intrinsic_4d, camera_points)  # (b, h, w, 3, 1)

    return projected_points
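In pinhole terms (a restatement of the code above, not additional source): for a world point $X$, world-to-camera transform $[R \mid t]$, and intrinsics $K$,

$$\tilde{p} = K\,(R X + t), \qquad (x, y) = \left(\tilde{p}_x / \tilde{p}_z,\; \tilde{p}_y / \tilde{p}_z\right).$$

project_points returns the unnormalized $\tilde{p}$; callers such as forward_warp perform the perspective division by $\tilde{p}_z$.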
fastvideo.pipelines.basic.gen3c.unproject_points
unproject_points(depth: Tensor, w2c: Tensor, intrinsic: Tensor, is_depth: bool = True, mask: Tensor | None = None) -> Tensor

Unproject depth map to 3D world points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| depth | Tensor | (b, 1, h, w) depth map | required |
| w2c | Tensor | (b, 4, 4) world-to-camera transformation matrix | required |
| intrinsic | Tensor | (b, 3, 3) camera intrinsic matrix | required |
| is_depth | bool | If True, depth is z-depth; if False, depth is distance to camera | True |
| mask | Tensor \| None | Optional (b, h, w) or (b, 1, h, w) mask for valid pixels | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| world_points | Tensor | (b, h, w, 3) 3D world coordinates |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def unproject_points(
    depth: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
    is_depth: bool = True,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Unproject depth map to 3D world points.

    Args:
        depth: (b, 1, h, w) depth map
        w2c: (b, 4, 4) world-to-camera transformation matrix
        intrinsic: (b, 3, 3) camera intrinsic matrix
        is_depth: If True, depth is z-depth; if False, depth is distance to camera
        mask: Optional (b, h, w) or (b, 1, h, w) mask for valid pixels

    Returns:
        world_points: (b, h, w, 3) 3D world coordinates
    """
    b, _, h, w = depth.shape
    device = depth.device
    dtype = depth.dtype

    if mask is None:
        mask = depth > 0
    if mask.dim() == depth.dim() and mask.shape[1] == 1:
        mask = mask[:, 0]

    idx = torch.nonzero(mask)
    if idx.numel() == 0:
        return torch.zeros((b, h, w, 3), device=device, dtype=dtype)

    b_idx, y_idx, x_idx = idx[:, 0], idx[:, 1], idx[:, 2]

    intrinsic_inv = inverse_with_conversion(intrinsic)  # (b, 3, 3)

    x_valid = x_idx.to(dtype)
    y_valid = y_idx.to(dtype)
    ones = torch.ones_like(x_valid)
    pos = torch.stack([x_valid, y_valid, ones], dim=1).unsqueeze(-1)  # (N, 3, 1)

    intrinsic_inv_valid = intrinsic_inv[b_idx]  # (N, 3, 3)
    unnormalized_pos = torch.matmul(intrinsic_inv_valid, pos)  # (N, 3, 1)

    depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1, 1)
    if is_depth:
        world_points_cam = depth_valid * unnormalized_pos
    else:
        norm_val = torch.norm(unnormalized_pos, dim=1, keepdim=True)
        direction = unnormalized_pos / (norm_val + 1e-8)
        world_points_cam = depth_valid * direction

    ones_h = torch.ones((world_points_cam.shape[0], 1, 1), device=device, dtype=dtype)
    world_points_homo = torch.cat([world_points_cam, ones_h], dim=1)  # (N, 4, 1)

    trans = inverse_with_conversion(w2c)  # (b, 4, 4)
    trans_valid = trans[b_idx]  # (N, 4, 4)
    world_points_transformed = torch.matmul(trans_valid, world_points_homo)  # (N, 4, 1)
    sparse_points = world_points_transformed[:, :3, 0]  # (N, 3)

    out_points = torch.zeros((b, h, w, 3), device=device, dtype=dtype)
    out_points[b_idx, y_idx, x_idx, :] = sparse_points
    return out_points
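Example:

A round-trip sanity check (a sketch with made-up shapes, assuming both helpers above): unprojecting a constant-depth map with an identity pose and re-projecting recovers the original pixel grid.

import torch

b, h, w = 1, 4, 6
depth = torch.full((b, 1, h, w), 2.0)
w2c = torch.eye(4).unsqueeze(0)
K = torch.tensor([[[100.0, 0.0, 3.0],
                   [0.0, 100.0, 2.0],
                   [0.0, 0.0, 1.0]]])           # (1, 3, 3), made-up intrinsics

pts = unproject_points(depth, w2c, K)           # (1, 4, 6, 3)
proj = project_points(pts, w2c, K)              # (1, 4, 6, 3, 1)
xy = proj[..., :2, 0] / proj[..., 2:3, 0]       # perspective divide
x_grid = torch.arange(w, dtype=torch.float32).expand(h, w)
assert torch.allclose(xy[0, ..., 0], x_grid, atol=1e-4)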

Modules

fastvideo.pipelines.basic.gen3c.cache_3d

This module implements the 3D cache system for GEN3C video generation with camera control. The cache maintains a point cloud representation of the scene, enabling:

- Unprojecting depth maps to 3D world points
- Forward warping rendered views to new camera poses
- Managing multiple frame buffers for temporal consistency

Classes
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase
Cache3DBase(input_image: Tensor, input_depth: Tensor, input_w2c: Tensor, input_intrinsics: Tensor, input_mask: Tensor | None = None, input_format: list[str] | None = None, input_points: Tensor | None = None, weight_dtype: dtype = float32, is_depth: bool = True, device: str = 'cuda', filter_points_threshold: float = 1.0)

Base class for 3D cache management.

The cache maintains:

- input_image: RGB images stored in the cache
- input_points: 3D world coordinates for each pixel
- input_mask: Validity mask for each pixel

Initialize the 3D cache.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_image | Tensor | Input image tensor with varying dimensions | required |
| input_depth | Tensor | Depth map tensor | required |
| input_w2c | Tensor | World-to-camera transformation matrix | required |
| input_intrinsics | Tensor | Camera intrinsic matrix | required |
| input_mask | Tensor \| None | Optional validity mask | None |
| input_format | list[str] \| None | Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W']) | None |
| input_points | Tensor \| None | Pre-computed 3D world points (alternative to depth) | None |
| weight_dtype | dtype | Data type for computations | float32 |
| is_depth | bool | If True, input_depth is z-depth; if False, it's distance | True |
| device | str | Computation device | 'cuda' |
| filter_points_threshold | float | Threshold for filtering unreliable depth | 1.0 |
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    input_image: torch.Tensor,
    input_depth: torch.Tensor,
    input_w2c: torch.Tensor,
    input_intrinsics: torch.Tensor,
    input_mask: torch.Tensor | None = None,
    input_format: list[str] | None = None,
    input_points: torch.Tensor | None = None,
    weight_dtype: torch.dtype = torch.float32,
    is_depth: bool = True,
    device: str = "cuda",
    filter_points_threshold: float = 1.0,
):
    """
    Initialize the 3D cache.

    Args:
        input_image: Input image tensor with varying dimensions
        input_depth: Depth map tensor
        input_w2c: World-to-camera transformation matrix
        input_intrinsics: Camera intrinsic matrix
        input_mask: Optional validity mask
        input_format: Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])
        input_points: Pre-computed 3D world points (alternative to depth)
        weight_dtype: Data type for computations
        is_depth: If True, input_depth is z-depth; if False, it's distance
        device: Computation device
        filter_points_threshold: Threshold for filtering unreliable depth
    """
    self.weight_dtype = weight_dtype
    self.is_depth = is_depth
    self.device = device
    self.filter_points_threshold = filter_points_threshold

    if input_format is None:
        assert input_image.dim() == 4
        input_format = ["B", "C", "H", "W"]

    # Map dimension names to indices
    format_to_indices = {dim: idx for idx, dim in enumerate(input_format)}
    input_shape = input_image.shape

    if input_mask is not None:
        input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C"))

    # Extract dimensions
    B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1
    F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1
    N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1
    V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1
    H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None
    W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None

    # Reorder dimensions to B x F x N x V x C x H x W
    desired_dims = ["B", "F", "N", "V", "C", "H", "W"]
    permute_order: list[int | None] = []
    for dim in desired_dims:
        idx = format_to_indices.get(dim)
        permute_order.append(idx)

    permute_indices = [idx for idx in permute_order if idx is not None]
    input_image = input_image.permute(*permute_indices)

    for i, idx in enumerate(permute_order):
        if idx is None:
            input_image = input_image.unsqueeze(i)

    # Now input_image has shape B x F x N x V x C x H x W
    if input_mask is not None:
        self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:]
        self.input_mask = self.input_mask.to("cpu")
    else:
        self.input_mask = None
        self.input_image = input_image
    self.input_image = self.input_image.to(weight_dtype).to("cpu")

    # Compute 3D world points
    if input_points is not None:
        self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu")
        self.input_depth = None
    else:
        input_depth = torch.nan_to_num(input_depth, nan=100)
        input_depth = torch.clamp(input_depth, min=0, max=100)
        if weight_dtype == torch.float16:
            input_depth = torch.clamp(input_depth, max=70)

        self.input_points = (unproject_points(
            input_depth.reshape(-1, 1, H, W),
            input_w2c.reshape(-1, 4, 4),
            input_intrinsics.reshape(-1, 3, 3),
            is_depth=self.is_depth,
        ).to(weight_dtype).reshape(B, F, N, V, H, W, 3).to("cpu"))
        self.input_depth = input_depth

    # Filter unreliable depth
    if self.filter_points_threshold < 1.0 and input_depth is not None:
        input_depth = input_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(input_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(
                                                         B, F, N, V, 1, H, W)
        if self.input_mask is None:
            self.input_mask = depth_mask.to("cpu")
        else:
            self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device)
Functions
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase.input_frame_count
input_frame_count() -> int

Return the number of frames in the cache.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def input_frame_count(self) -> int:
    """Return the number of frames in the cache."""
    return self.input_image.shape[1]
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cached 3D points from new camera viewpoints.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_w2cs | Tensor | (b, F_target, 4, 4) target camera transformations | required |
| target_intrinsics | Tensor | (b, F_target, 3, 3) target camera intrinsics | required |
| render_depth | bool | If True, return depth instead of RGB | False |
| start_frame_idx | int | Starting frame index in the cache | 0 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| pixels | Tensor | (b, F_target, N, c, h, w) rendered images or depth |
| masks | Tensor | (b, F_target, N, 1, h, w) validity masks |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cached 3D points from new camera viewpoints.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index in the cache

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images or depth
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    bs, F_target, _, _ = target_w2cs.shape
    B, F, N, V, C, H, W = self.input_image.shape
    assert bs == B

    target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4)
    target_intrinsics = target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3,
                                                                               3).reshape(-1, 3, 3)

    # Prepare inputs
    first_images = rearrange(
        self.input_image[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, C, H, W),
        "B F N V C H W -> (B F N) V C H W")
    first_points = rearrange(
        self.input_points[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, H, W, 3),
        "B F N V H W C -> (B F N) V H W C")
    first_masks = rearrange(
        self.input_mask[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, 1, H, W),
        "B F N V C H W -> (B F N) V C H W") if self.input_mask is not None else None

    # Process in chunks for memory efficiency
    if first_images.shape[1] == 1:
        warp_chunk_size = 2
        rendered_warp_images = []
        rendered_warp_masks = []
        rendered_warp_depth = []

        first_images = first_images.squeeze(1)
        first_points = first_points.squeeze(1)
        first_masks = first_masks.squeeze(1) if first_masks is not None else None

        for i in range(0, first_images.shape[0], warp_chunk_size):
            with torch.no_grad():
                imgs_chunk = first_images[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                pts_chunk = first_points[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                masks_chunk = (first_masks[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                               if first_masks is not None else None)

                (
                    rendered_warp_images_chunk,
                    rendered_warp_masks_chunk,
                    rendered_warp_depth_chunk,
                    _,
                ) = forward_warp(
                    imgs_chunk,
                    mask1=masks_chunk,
                    depth1=None,
                    transformation1=None,
                    transformation2=target_w2cs[i:i + warp_chunk_size],
                    intrinsic1=target_intrinsics[i:i + warp_chunk_size],
                    intrinsic2=target_intrinsics[i:i + warp_chunk_size],
                    render_depth=render_depth,
                    world_points1=pts_chunk,
                )

                rendered_warp_images.append(rendered_warp_images_chunk.to("cpu"))
                rendered_warp_masks.append(rendered_warp_masks_chunk.to("cpu"))
                if render_depth:
                    rendered_warp_depth.append(rendered_warp_depth_chunk.to("cpu"))

                del imgs_chunk, pts_chunk, masks_chunk
                torch.cuda.empty_cache()

        rendered_warp_images = torch.cat(rendered_warp_images, dim=0)
        rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0)
        if render_depth:
            rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0)
    else:
        raise NotImplementedError("Multi-view rendering not yet supported")

    pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
    masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)

    if render_depth:
        pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N)

    return pixels.to(self.device), masks.to(self.device)
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase.update_cache
update_cache(**kwargs)

Update the cache with new frames. To be implemented by subclasses.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(self, **kwargs):
    """Update the cache with new frames. To be implemented by subclasses."""
    raise NotImplementedError
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBuffer
Cache3DBuffer(frame_buffer_max: int = 2, noise_aug_strength: float = 0.0, generator: Generator | None = None, **kwargs)

Bases: Cache3DBase

3D cache with frame buffer support.

This class manages multiple frame buffers for temporal consistency and supports noise augmentation for training stability.

Initialize the buffered 3D cache.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| frame_buffer_max | int | Maximum number of frames to buffer | 2 |
| noise_aug_strength | float | Strength of noise augmentation per buffer | 0.0 |
| generator | Generator \| None | Random generator for reproducibility | None |
| **kwargs |  | Arguments passed to Cache3DBase | {} |
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    frame_buffer_max: int = 2,
    noise_aug_strength: float = 0.0,
    generator: torch.Generator | None = None,
    **kwargs,
):
    """
    Initialize the buffered 3D cache.

    Args:
        frame_buffer_max: Maximum number of frames to buffer
        noise_aug_strength: Strength of noise augmentation per buffer
        generator: Random generator for reproducibility
        **kwargs: Arguments passed to Cache3DBase
    """
    super().__init__(**kwargs)
    self.frame_buffer_max = frame_buffer_max
    self.noise_aug_strength = noise_aug_strength
    self.generator = generator
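Example:

A minimal construction sketch (hypothetical shapes, identity pose; import path per the headings above): seed a two-slot buffer from a single RGB-D view.

import torch
from fastvideo.pipelines.basic.gen3c.cache_3d import Cache3DBuffer

image = torch.rand(1, 3, 64, 64) * 2 - 1        # (B, C, H, W) in [-1, 1]
depth = torch.rand(1, 1, 64, 64) + 1.0          # (B, 1, H, W)
w2c = torch.eye(4).unsqueeze(0)                 # (B, 4, 4)
K = torch.tensor([[[80.0, 0.0, 32.0],
                   [0.0, 80.0, 32.0],
                   [0.0, 0.0, 1.0]]])           # (B, 3, 3), made-up intrinsics

cache = Cache3DBuffer(
    frame_buffer_max=2,
    input_image=image,
    input_depth=depth,
    input_w2c=w2c,
    input_intrinsics=K,
    device="cpu",
)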
Functions
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBuffer.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cache with optional noise augmentation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_w2cs | Tensor | (b, F_target, 4, 4) target camera transformations | required |
| target_intrinsics | Tensor | (b, F_target, 3, 3) target camera intrinsics | required |
| render_depth | bool | If True, return depth instead of RGB | False |
| start_frame_idx | int | Starting frame index (must be 0 for this class) | 0 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| pixels | Tensor | (b, F_target, N, c, h, w) rendered images |
| masks | Tensor | (b, F_target, N, 1, h, w) validity masks |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cache with optional noise augmentation.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index (must be 0 for this class)

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3DBuffer"

    output_device = target_w2cs.device
    target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device)
    target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device)

    pixels, masks = super().render_cache(target_w2cs, target_intrinsics, render_depth)

    pixels = pixels.to(output_device)
    masks = masks.to(output_device)

    # Apply noise augmentation (stronger for older buffers)
    if not render_depth and self.noise_aug_strength > 0:
        noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype)
        per_buffer_noise = (torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device) *
                            self.noise_aug_strength)
        pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1)

    return pixels, masks
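Continuing the construction sketch above (same hypothetical tensors): render the cached points from a slightly shifted camera pose.

target_w2cs = w2c.clone().reshape(1, 1, 4, 4)   # (b, F_target, 4, 4)
target_w2cs[:, :, 0, 3] = 0.05                  # shift the target camera along x
target_Ks = K.reshape(1, 1, 3, 3)               # (b, F_target, 3, 3)
pixels, masks = cache.render_cache(target_w2cs, target_Ks)
# pixels: (1, 1, N, 3, 64, 64); masks: (1, 1, N, 1, 64, 64)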
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBuffer.update_cache
update_cache(new_image: Tensor, new_depth: Tensor, new_w2c: Tensor, new_mask: Tensor | None = None, new_intrinsics: Tensor | None = None)

Update the cache with a new frame.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| new_image | Tensor | (B, C, H, W) new RGB image | required |
| new_depth | Tensor | (B, 1, H, W) new depth map | required |
| new_w2c | Tensor | (B, 4, 4) new world-to-camera transformation | required |
| new_mask | Tensor \| None | Optional (B, 1, H, W) validity mask | None |
| new_intrinsics | Tensor \| None | (B, 3, 3) camera intrinsics (optional) | None |
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(
    self,
    new_image: torch.Tensor,
    new_depth: torch.Tensor,
    new_w2c: torch.Tensor,
    new_mask: torch.Tensor | None = None,
    new_intrinsics: torch.Tensor | None = None,
):
    """
    Update the cache with a new frame.

    Args:
        new_image: (B, C, H, W) new RGB image
        new_depth: (B, 1, H, W) new depth map
        new_w2c: (B, 4, 4) new world-to-camera transformation
        new_mask: Optional (B, 1, H, W) validity mask
        new_intrinsics: (B, 3, 3) camera intrinsics (optional)
    """
    new_image = new_image.to(self.weight_dtype).to(self.device)
    new_depth = new_depth.to(self.weight_dtype).to(self.device)
    new_w2c = new_w2c.to(self.weight_dtype).to(self.device)
    if new_intrinsics is not None:
        new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device)

    new_depth = torch.nan_to_num(new_depth, nan=1e4)
    new_depth = torch.clamp(new_depth, min=0, max=1e4)

    B, F, N, V, C, H, W = self.input_image.shape

    # Compute new 3D points
    new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu()
    new_image = new_image.cpu()

    if self.filter_points_threshold < 1.0:
        new_depth = new_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(new_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W)
        new_mask = depth_mask.to("cpu") if new_mask is None else new_mask * depth_mask.to(new_mask.device)
    if new_mask is not None:
        new_mask = new_mask.cpu()

    # Update buffer (newest frame first)
    if self.frame_buffer_max > 1:
        if self.input_image.shape[2] < self.frame_buffer_max:
            self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2)
            self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2)
            if self.input_mask is not None:
                self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2)
        else:
            self.input_image[:, :, 0] = new_image[:, None, None]
            self.input_points[:, :, 0] = new_points[:, None, None]
            if self.input_mask is not None:
                self.input_mask[:, :, 0] = new_mask[:, None, None]
    else:
        self.input_image = new_image[:, None, None, None]
        self.input_points = new_points[:, None, None, None]
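Continuing the sketch: push a newly generated RGB-D frame into the buffer. Slots fill newest-first up to frame_buffer_max; once full, the newest slot is overwritten in place.

new_image = torch.rand(1, 3, 64, 64) * 2 - 1    # hypothetical new frame
new_depth = torch.rand(1, 1, 64, 64) + 1.0
cache.update_cache(new_image, new_depth, w2c, new_intrinsics=K)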
Functions
fastvideo.pipelines.basic.gen3c.cache_3d.bilinear_splatting
bilinear_splatting(frame1: Tensor, mask1: Tensor | None, depth1: Tensor, flow12: Tensor, flow12_mask: Tensor | None = None, is_image: bool = False, depth_weight_scale: float = 50.0) -> tuple[Tensor, Tensor]

Bilinear splatting for forward warping.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| frame1 | Tensor | (b, c, h, w) source frame | required |
| mask1 | Tensor \| None | (b, 1, h, w) valid pixel mask (1 for known, 0 for unknown) | required |
| depth1 | Tensor | (b, 1, h, w) depth map | required |
| flow12 | Tensor | (b, 2, h, w) optical flow from frame1 to frame2 | required |
| flow12_mask | Tensor \| None | (b, 1, h, w) flow validity mask | None |
| is_image | bool | If True, output will be clipped to (-1, 1) range | False |
| depth_weight_scale | float | Scale factor for depth weighting | 50.0 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| warped_frame2 | Tensor | (b, c, h, w) warped frame |
| mask2 | Tensor | (b, 1, h, w) validity mask for warped frame |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def bilinear_splatting(
    frame1: torch.Tensor,
    mask1: torch.Tensor | None,
    depth1: torch.Tensor,
    flow12: torch.Tensor,
    flow12_mask: torch.Tensor | None = None,
    is_image: bool = False,
    depth_weight_scale: float = 50.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Bilinear splatting for forward warping.

    Args:
        frame1: (b, c, h, w) source frame
        mask1: (b, 1, h, w) valid pixel mask (1 for known, 0 for unknown)
        depth1: (b, 1, h, w) depth map
        flow12: (b, 2, h, w) optical flow from frame1 to frame2
        flow12_mask: (b, 1, h, w) flow validity mask
        is_image: If True, output will be clipped to (-1, 1) range
        depth_weight_scale: Scale factor for depth weighting

    Returns:
        warped_frame2: (b, c, h, w) warped frame
        mask2: (b, 1, h, w) validity mask for warped frame
    """
    b, c, h, w = frame1.shape
    device = frame1.device
    dtype = frame1.dtype

    if mask1 is None:
        mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)
    if flow12_mask is None:
        flow12_mask = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)

    grid = create_grid(b, h, w, device=device, dtype=dtype)
    trans_pos = flow12 + grid

    trans_pos_offset = trans_pos + 1
    trans_pos_floor = torch.floor(trans_pos_offset).long()
    trans_pos_ceil = torch.ceil(trans_pos_offset).long()

    trans_pos_offset = torch.stack(
        [torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
         torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)],
        dim=1)
    trans_pos_floor = torch.stack(
        [torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
         torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)],
        dim=1)
    trans_pos_ceil = torch.stack(
        [torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
         torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)],
        dim=1)

    # Bilinear weights
    prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
    prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

    # Depth weighting for occlusion handling
    clamped_depth1 = torch.clamp(depth1, min=0)
    log_depth1 = torch.log1p(clamped_depth1)
    exponent = log_depth1 / (log_depth1.max() + 1e-7) * depth_weight_scale
    max_exponent = 80.0 if dtype in [torch.float32, torch.bfloat16] else 10.0
    clamped_exponent = torch.clamp(exponent, max=max_exponent)
    depth_weights = torch.exp(clamped_exponent) + 1e-7

    weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])

    warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=dtype, device=device)
    warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=dtype, device=device)

    frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])
    batch_indices = torch.arange(b, device=device, dtype=torch.long)[:, None, None]

    warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
                            frame1_cl * weight_nw,
                            accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
                            frame1_cl * weight_sw,
                            accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
                            frame1_cl * weight_ne,
                            accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
                            frame1_cl * weight_se,
                            accumulate=True)

    warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), weight_nw, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), weight_sw, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), weight_ne, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), weight_se, accumulate=True)

    warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1])
    warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1])
    cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
    cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]
    cropped_weights = torch.nan_to_num(cropped_weights, nan=1000.0)

    mask = cropped_weights > 0
    zero_value = -1 if is_image else 0
    zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
    warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor)
    mask2 = mask.to(frame1)

    if is_image:
        warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)

    return warped_frame2, mask2
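The depth weighting above acts as a soft z-buffer (this is a restatement of the code, not additional source): each source pixel splats into its four integer neighbours with bilinear proximity weights $b_k$, scaled by the combined validity/flow mask $m$ and divided by an exponential depth term, so nearer points dominate wherever several points land on the same target pixel:

$$w_k = \frac{b_k \cdot m}{\exp\!\left(s \cdot \frac{\log(1+d)}{\max \log(1+d)}\right) + \epsilon}, \qquad s = \texttt{depth\_weight\_scale},$$

where $d$ is the pixel's depth and the exponent is clamped (80 for float32/bfloat16, 10 otherwise) to avoid overflow.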
fastvideo.pipelines.basic.gen3c.cache_3d.create_grid
create_grid(b: int, h: int, w: int, device: str = 'cpu', dtype: dtype = float32) -> Tensor

Create a dense grid of (x, y) coordinates of shape (b, 2, h, w).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| b | int | Batch size | required |
| h | int | Height | required |
| w | int | Width | required |
| device | str | Device for tensor creation | 'cpu' |
| dtype | dtype | Data type for tensor | float32 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Grid tensor of shape (b, 2, h, w) |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def create_grid(b: int, h: int, w: int, device: str = "cpu", dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Create a dense grid of (x, y) coordinates of shape (b, 2, h, w).

    Args:
        b: Batch size
        h: Height
        w: Width
        device: Device for tensor creation
        dtype: Data type for tensor

    Returns:
        Grid tensor of shape (b, 2, h, w)
    """
    x = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w)
    y = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w)
    return torch.cat([x, y], dim=1)
fastvideo.pipelines.basic.gen3c.cache_3d.forward_warp
forward_warp(frame1: Tensor, mask1: Tensor | None, depth1: Tensor | None, transformation1: Tensor | None, transformation2: Tensor, intrinsic1: Tensor | None, intrinsic2: Tensor | None, is_image: bool = True, is_depth: bool = True, render_depth: bool = False, world_points1: Tensor | None = None) -> tuple[Tensor, Tensor, Tensor | None, Tensor]

Forward warp frame1 to a new view defined by transformation2.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| frame1 | Tensor | (b, c, h, w) source frame in range [-1, 1] for images | required |
| mask1 | Tensor \| None | (b, 1, h, w) valid pixel mask | required |
| depth1 | Tensor \| None | (b, 1, h, w) depth map (required if world_points1 is None) | required |
| transformation1 | Tensor \| None | (b, 4, 4) source camera w2c (required if depth1 is provided) | required |
| transformation2 | Tensor | (b, 4, 4) target camera w2c | required |
| intrinsic1 | Tensor \| None | (b, 3, 3) source camera intrinsics | required |
| intrinsic2 | Tensor \| None | (b, 3, 3) target camera intrinsics | required |
| is_image | bool | If True, output will be clipped to (-1, 1) | True |
| is_depth | bool | If True, depth1 is z-depth; if False, it's distance | True |
| render_depth | bool | If True, also return the warped depth map | False |
| world_points1 | Tensor \| None | (b, h, w, 3) pre-computed world points (alternative to depth1) | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| warped_frame2 | Tensor | (b, c, h, w) warped frame |
| mask2 | Tensor | (b, 1, h, w) validity mask |
| warped_depth2 | Tensor \| None | (b, h, w) warped depth (if render_depth=True) |
| flow12 | Tensor | (b, 2, h, w) optical flow |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def forward_warp(
    frame1: torch.Tensor,
    mask1: torch.Tensor | None,
    depth1: torch.Tensor | None,
    transformation1: torch.Tensor | None,
    transformation2: torch.Tensor,
    intrinsic1: torch.Tensor | None,
    intrinsic2: torch.Tensor | None,
    is_image: bool = True,
    is_depth: bool = True,
    render_depth: bool = False,
    world_points1: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """
    Forward warp frame1 to a new view defined by transformation2.

    Args:
        frame1: (b, c, h, w) source frame in range [-1, 1] for images
        mask1: (b, 1, h, w) valid pixel mask
        depth1: (b, 1, h, w) depth map (required if world_points1 is None)
        transformation1: (b, 4, 4) source camera w2c (required if depth1 is provided)
        transformation2: (b, 4, 4) target camera w2c
        intrinsic1: (b, 3, 3) source camera intrinsics
        intrinsic2: (b, 3, 3) target camera intrinsics
        is_image: If True, output will be clipped to (-1, 1)
        is_depth: If True, depth1 is z-depth; if False, it's distance
        render_depth: If True, also return the warped depth map
        world_points1: (b, h, w, 3) pre-computed world points (alternative to depth1)

    Returns:
        warped_frame2: (b, c, h, w) warped frame
        mask2: (b, 1, h, w) validity mask
        warped_depth2: (b, h, w) warped depth (if render_depth=True)
        flow12: (b, 2, h, w) optical flow
    """
    device = frame1.device
    b, c, h, w = frame1.shape
    dtype = frame1.dtype

    if mask1 is None:
        mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)
    if intrinsic2 is None:
        assert intrinsic1 is not None
        intrinsic2 = intrinsic1.clone()

    if world_points1 is not None:
        # Use pre-computed world points
        assert world_points1.shape == (b, h, w, 3)
        trans_points1 = project_points(world_points1, transformation2, intrinsic2)
    else:
        # Compute from depth
        assert depth1 is not None and transformation1 is not None
        assert depth1.shape == (b, 1, h, w)

        depth1 = torch.nan_to_num(depth1, nan=1e4)
        depth1 = torch.clamp(depth1, min=0, max=1e4)

        # Unproject to world, then project to target view
        world_points1 = unproject_points(depth1, transformation1, intrinsic1, is_depth=is_depth)
        trans_points1 = project_points(world_points1, transformation2, intrinsic2)

    # Filter points behind camera
    mask1 = mask1 * (trans_points1[:, :, :, 2, 0].unsqueeze(1) > 0)
    trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7)
    trans_coordinates = trans_coordinates.permute(0, 3, 1, 2)  # b, 2, h, w
    trans_depth1 = trans_points1[:, :, :, 2, 0].unsqueeze(1)

    grid = create_grid(b, h, w, device=device, dtype=dtype)
    flow12 = trans_coordinates - grid

    warped_frame2, mask2 = bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image)

    warped_depth2 = None
    if render_depth:
        warped_depth2 = bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0][:, 0]

    return warped_frame2, mask2, warped_depth2, flow12
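Example:

A small end-to-end sketch (made-up values): warp a frame to a camera translated slightly along x, computing geometry from the depth map rather than pre-computed world points.

import torch

b, c, h, w = 1, 3, 32, 32
frame = torch.rand(b, c, h, w) * 2 - 1          # image in [-1, 1]
depth = torch.full((b, 1, h, w), 2.0)
K = torch.tensor([[[40.0, 0.0, 16.0],
                   [0.0, 40.0, 16.0],
                   [0.0, 0.0, 1.0]]])           # made-up intrinsics
src_w2c = torch.eye(4).unsqueeze(0)
tgt_w2c = src_w2c.clone()
tgt_w2c[:, 0, 3] = 0.1                          # translate the target camera

warped, mask, _, flow = forward_warp(
    frame, None, depth, src_w2c, tgt_w2c, K, K)
# warped: (1, 3, 32, 32); mask marks pixels that received a splat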
fastvideo.pipelines.basic.gen3c.cache_3d.inverse_with_conversion
inverse_with_conversion(mtx: Tensor) -> Tensor

Compute matrix inverse with float32 conversion for numerical stability.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def inverse_with_conversion(mtx: torch.Tensor) -> torch.Tensor:
    """Compute matrix inverse with float32 conversion for numerical stability."""
    return torch.linalg.inv(mtx.to(torch.float32)).to(mtx.dtype)
fastvideo.pipelines.basic.gen3c.cache_3d.project_points
project_points(world_points: Tensor, w2c: Tensor, intrinsic: Tensor) -> Tensor

Project 3D world points to 2D pixel coordinates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| world_points | Tensor | (b, h, w, 3) 3D world coordinates | required |
| w2c | Tensor | (b, 4, 4) world-to-camera transformation matrix | required |
| intrinsic | Tensor | (b, 3, 3) camera intrinsic matrix | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| projected_points | Tensor | (b, h, w, 3, 1) projected 2D coordinates (x, y, z) |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def project_points(
    world_points: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
) -> torch.Tensor:
    """
    Project 3D world points to 2D pixel coordinates.

    Args:
        world_points: (b, h, w, 3) 3D world coordinates
        w2c: (b, 4, 4) world-to-camera transformation matrix
        intrinsic: (b, 3, 3) camera intrinsic matrix

    Returns:
        projected_points: (b, h, w, 3, 1) projected 2D coordinates (x, y, z)
    """
    world_points = world_points.unsqueeze(-1)  # (b, h, w, 3, 1)
    b, h, w, _, _ = world_points.shape

    ones_4d = torch.ones((b, h, w, 1, 1), device=world_points.device, dtype=world_points.dtype)
    world_points_homo = torch.cat([world_points, ones_4d], dim=3)  # (b, h, w, 4, 1)

    trans_4d = w2c[:, None, None]  # (b, 1, 1, 4, 4)
    camera_points_homo = torch.matmul(trans_4d, world_points_homo)  # (b, h, w, 4, 1)

    camera_points = camera_points_homo[:, :, :, :3]  # (b, h, w, 3, 1)
    intrinsic_4d = intrinsic[:, None, None]  # (b, 1, 1, 3, 3)
    projected_points = torch.matmul(intrinsic_4d, camera_points)  # (b, h, w, 3, 1)

    return projected_points
fastvideo.pipelines.basic.gen3c.cache_3d.reliable_depth_mask_range_batch
reliable_depth_mask_range_batch(depth: Tensor, window_size: int = 5, ratio_thresh: float = 0.05, eps: float = 1e-06) -> Tensor

Compute a mask for reliable depth values based on local variation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| depth | Tensor | (b, h, w) or (b, 1, h, w) depth map | required |
| window_size | int | Size of the local window (must be odd) | 5 |
| ratio_thresh | float | Threshold for depth variation ratio | 0.05 |
| eps | float | Small epsilon for numerical stability | 1e-06 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| reliable_mask | Tensor | Boolean mask where True indicates reliable depth |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def reliable_depth_mask_range_batch(
    depth: torch.Tensor,
    window_size: int = 5,
    ratio_thresh: float = 0.05,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Compute a mask for reliable depth values based on local variation.

    Args:
        depth: (b, h, w) or (b, 1, h, w) depth map
        window_size: Size of the local window (must be odd)
        ratio_thresh: Threshold for depth variation ratio
        eps: Small epsilon for numerical stability

    Returns:
        reliable_mask: Boolean mask where True indicates reliable depth
    """
    assert window_size % 2 == 1, "Window size must be odd."

    if depth.dim() == 3:
        depth_unsq = depth.unsqueeze(1)
    elif depth.dim() == 4:
        depth_unsq = depth
    else:
        raise ValueError("depth tensor must be of shape (b, h, w) or (b, 1, h, w)")

    local_max = F.max_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2)
    local_min = -F.max_pool2d(-depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2)
    local_mean = F.avg_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2)

    ratio = (local_max - local_min) / (local_mean + eps)
    reliable_mask = (ratio < ratio_thresh) & (depth_unsq > 0)

    return reliable_mask
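A quick behavioural sketch (synthetic depth, assuming the function above): pixels near a sharp depth discontinuity exceed the 5% local-variation threshold and are marked unreliable.

import torch

depth = torch.ones(1, 1, 32, 32)
depth[0, 0, 16:, :] = 5.0                       # sharp depth edge at row 16
mask = reliable_depth_mask_range_batch(depth, ratio_thresh=0.05)
# mask is False in a band around the edge, True in the flat regions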
fastvideo.pipelines.basic.gen3c.cache_3d.unproject_points
unproject_points(depth: Tensor, w2c: Tensor, intrinsic: Tensor, is_depth: bool = True, mask: Tensor | None = None) -> Tensor

Unproject depth map to 3D world points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| depth | Tensor | (b, 1, h, w) depth map | required |
| w2c | Tensor | (b, 4, 4) world-to-camera transformation matrix | required |
| intrinsic | Tensor | (b, 3, 3) camera intrinsic matrix | required |
| is_depth | bool | If True, depth is z-depth; if False, depth is distance to camera | True |
| mask | Tensor \| None | Optional (b, h, w) or (b, 1, h, w) mask for valid pixels | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| world_points | Tensor | (b, h, w, 3) 3D world coordinates |

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def unproject_points(
    depth: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
    is_depth: bool = True,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Unproject depth map to 3D world points.

    Args:
        depth: (b, 1, h, w) depth map
        w2c: (b, 4, 4) world-to-camera transformation matrix
        intrinsic: (b, 3, 3) camera intrinsic matrix
        is_depth: If True, depth is z-depth; if False, depth is distance to camera
        mask: Optional (b, h, w) or (b, 1, h, w) mask for valid pixels

    Returns:
        world_points: (b, h, w, 3) 3D world coordinates
    """
    b, _, h, w = depth.shape
    device = depth.device
    dtype = depth.dtype

    if mask is None:
        mask = depth > 0
    if mask.dim() == depth.dim() and mask.shape[1] == 1:
        mask = mask[:, 0]

    idx = torch.nonzero(mask)
    if idx.numel() == 0:
        return torch.zeros((b, h, w, 3), device=device, dtype=dtype)

    b_idx, y_idx, x_idx = idx[:, 0], idx[:, 1], idx[:, 2]

    intrinsic_inv = inverse_with_conversion(intrinsic)  # (b, 3, 3)

    x_valid = x_idx.to(dtype)
    y_valid = y_idx.to(dtype)
    ones = torch.ones_like(x_valid)
    pos = torch.stack([x_valid, y_valid, ones], dim=1).unsqueeze(-1)  # (N, 3, 1)

    intrinsic_inv_valid = intrinsic_inv[b_idx]  # (N, 3, 3)
    unnormalized_pos = torch.matmul(intrinsic_inv_valid, pos)  # (N, 3, 1)

    depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1, 1)
    if is_depth:
        world_points_cam = depth_valid * unnormalized_pos
    else:
        norm_val = torch.norm(unnormalized_pos, dim=1, keepdim=True)
        direction = unnormalized_pos / (norm_val + 1e-8)
        world_points_cam = depth_valid * direction

    ones_h = torch.ones((world_points_cam.shape[0], 1, 1), device=device, dtype=dtype)
    world_points_homo = torch.cat([world_points_cam, ones_h], dim=1)  # (N, 4, 1)

    trans = inverse_with_conversion(w2c)  # (b, 4, 4)
    trans_valid = trans[b_idx]  # (N, 4, 4)
    world_points_transformed = torch.matmul(trans_valid, world_points_homo)  # (N, 4, 1)
    sparse_points = world_points_transformed[:, :3, 0]  # (N, 3)

    out_points = torch.zeros((b, h, w, 3), device=device, dtype=dtype)
    out_points[b_idx, y_idx, x_idx, :] = sparse_points
    return out_points
fastvideo.pipelines.basic.gen3c.camera_utils

Camera trajectory generation utilities for GEN3C 3D cache conditioning.

Functions
fastvideo.pipelines.basic.gen3c.camera_utils.apply_transformation
apply_transformation(Bx4x4: Tensor, another_matrix: Tensor) -> Tensor

Apply batch transformation to a matrix.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def apply_transformation(Bx4x4: torch.Tensor, another_matrix: torch.Tensor) -> torch.Tensor:
    """Apply batch transformation to a matrix."""
    B = Bx4x4.shape[0]
    if another_matrix.dim() == 2:
        another_matrix = another_matrix.unsqueeze(0).expand(B, -1, -1)
    return torch.bmm(Bx4x4, another_matrix)
fastvideo.pipelines.basic.gen3c.camera_utils.create_horizontal_trajectory
create_horizontal_trajectory(world_to_camera_matrix: Tensor, center_depth: float, positive: bool = True, n_steps: int = 13, distance: float = 0.1, device: str = 'cuda', axis: str = 'x', camera_rotation: str = 'center_facing') -> Tensor

Create a linear camera trajectory along a specified axis.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def create_horizontal_trajectory(
    world_to_camera_matrix: torch.Tensor,
    center_depth: float,
    positive: bool = True,
    n_steps: int = 13,
    distance: float = 0.1,
    device: str = "cuda",
    axis: str = "x",
    camera_rotation: str = "center_facing",
) -> torch.Tensor:
    """Create a linear camera trajectory along a specified axis."""
    look_at_target = torch.tensor([0.0, 0.0, center_depth]).to(device)
    trajectory = []
    initial_camera_pos = torch.tensor([0, 0, 0], device=device, dtype=torch.float32)

    translation_positions = []
    for i in range(n_steps):
        offset = i * distance * center_depth / n_steps * (1 if positive else -1)
        if axis == "x":
            pos = torch.tensor([offset, 0, 0], device=device)
        elif axis == "y":
            pos = torch.tensor([0, offset, 0], device=device)
        elif axis == "z":
            pos = torch.tensor([0, 0, offset], device=device)
        else:
            raise ValueError(f"Axis should be x, y or z, got {axis}")
        translation_positions.append(pos)

    for pos in translation_positions:
        camera_pos = initial_camera_pos + pos
        if camera_rotation == "trajectory_aligned":
            _look_at = look_at_target + pos * 2
        elif camera_rotation == "center_facing":
            _look_at = look_at_target
        elif camera_rotation == "no_rotation":
            _look_at = look_at_target + pos
        else:
            raise ValueError(f"camera_rotation should be center_facing, trajectory_aligned, "
                             f"or no_rotation, got {camera_rotation}")
        view_matrix = look_at_matrix(camera_pos, _look_at)
        trajectory.append(view_matrix)

    trajectory = torch.stack(trajectory)
    return apply_transformation(trajectory, world_to_camera_matrix)
fastvideo.pipelines.basic.gen3c.camera_utils.create_spiral_trajectory
create_spiral_trajectory(world_to_camera_matrix: Tensor, center_depth: float, radius_x: float = 0.03, radius_y: float = 0.02, radius_z: float = 0.0, positive: bool = True, camera_rotation: str = 'center_facing', n_steps: int = 13, device: str = 'cuda', start_from_zero: bool = True, num_circles: int = 1) -> Tensor

Create a spiral/circular camera trajectory.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def create_spiral_trajectory(
    world_to_camera_matrix: torch.Tensor,
    center_depth: float,
    radius_x: float = 0.03,
    radius_y: float = 0.02,
    radius_z: float = 0.0,
    positive: bool = True,
    camera_rotation: str = "center_facing",
    n_steps: int = 13,
    device: str = "cuda",
    start_from_zero: bool = True,
    num_circles: int = 1,
) -> torch.Tensor:
    """Create a spiral/circular camera trajectory."""
    look_at_target = torch.tensor([0.0, 0.0, center_depth]).to(device)
    trajectory = []
    initial_camera_pos = torch.tensor([0, 0, 0], device=device, dtype=torch.float32)

    theta_max = 2 * math.pi * num_circles
    spiral_positions = []

    for i in range(n_steps):
        theta = theta_max * i / (n_steps - 1)
        if start_from_zero:
            x = radius_x * (math.cos(theta) - 1) * (1 if positive else -1) * center_depth
        else:
            x = radius_x * math.cos(theta) * center_depth
        y = radius_y * math.sin(theta) * center_depth
        z = radius_z * math.sin(theta) * center_depth
        spiral_positions.append(torch.tensor([x, y, z], device=device))

    for pos in spiral_positions:
        camera_pos = initial_camera_pos + pos
        if camera_rotation == "center_facing":
            view_matrix = look_at_matrix(camera_pos, look_at_target)
        elif camera_rotation == "trajectory_aligned":
            view_matrix = look_at_matrix(camera_pos, look_at_target + pos * 2)
        elif camera_rotation == "no_rotation":
            view_matrix = look_at_matrix(camera_pos, look_at_target + pos)
        else:
            raise ValueError(f"camera_rotation should be center_facing, trajectory_aligned, "
                             f"or no_rotation, got {camera_rotation}")
        trajectory.append(view_matrix)

    trajectory = torch.stack(trajectory)
    return apply_transformation(trajectory, world_to_camera_matrix)
fastvideo.pipelines.basic.gen3c.camera_utils.generate_camera_trajectory
generate_camera_trajectory(trajectory_type: str, initial_w2c: Tensor, initial_intrinsics: Tensor, num_frames: int, movement_distance: float, camera_rotation: str = 'center_facing', center_depth: float = 1.0, device: str = 'cuda') -> tuple[Tensor, Tensor]

Generate camera trajectory for GEN3C video generation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| trajectory_type | str | One of "left", "right", "up", "down", "zoom_in", "zoom_out", "clockwise", "counterclockwise", or "none" (static camera). | required |
| initial_w2c | Tensor | Initial world-to-camera matrix (4, 4). | required |
| initial_intrinsics | Tensor | Camera intrinsics matrix (3, 3). | required |
| num_frames | int | Number of frames in the trajectory. | required |
| movement_distance | float | Distance factor for camera movement. | required |
| camera_rotation | str | "center_facing", "no_rotation", or "trajectory_aligned". | 'center_facing' |
| center_depth | float | Depth of the scene center point. | 1.0 |
| device | str | Computation device. | 'cuda' |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| generated_w2cs | Tensor | (1, num_frames, 4, 4) world-to-camera matrices. |
| generated_intrinsics | Tensor | (1, num_frames, 3, 3) camera intrinsics. |

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def generate_camera_trajectory(
    trajectory_type: str,
    initial_w2c: torch.Tensor,
    initial_intrinsics: torch.Tensor,
    num_frames: int,
    movement_distance: float,
    camera_rotation: str = "center_facing",
    center_depth: float = 1.0,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate camera trajectory for GEN3C video generation.

    Args:
        trajectory_type: One of "left", "right", "up", "down", "zoom_in",
            "zoom_out", "clockwise", "counterclockwise", or "none" (static camera).
        initial_w2c: Initial world-to-camera matrix (4, 4).
        initial_intrinsics: Camera intrinsics matrix (3, 3).
        num_frames: Number of frames in the trajectory.
        movement_distance: Distance factor for camera movement.
        camera_rotation: "center_facing", "no_rotation", or "trajectory_aligned".
        center_depth: Depth of the scene center point.
        device: Computation device.

    Returns:
        generated_w2cs: (1, num_frames, 4, 4) world-to-camera matrices.
        generated_intrinsics: (1, num_frames, 3, 3) camera intrinsics.
    """
    if trajectory_type in ["clockwise", "counterclockwise"]:
        new_w2cs_seq = create_spiral_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=trajectory_type == "clockwise",
            device=device,
            camera_rotation=camera_rotation,
            radius_x=movement_distance,
            radius_y=movement_distance,
        )
    elif trajectory_type == "none":
        # Static camera: repeat the initial pose
        new_w2cs_seq = initial_w2c.unsqueeze(0).expand(num_frames, -1, -1)
    else:
        axis_map = {
            "left": (False, "x"),
            "right": (True, "x"),
            "up": (False, "y"),
            "down": (True, "y"),
            "zoom_in": (True, "z"),
            "zoom_out": (False, "z"),
        }
        if trajectory_type not in axis_map:
            raise ValueError(f"Unsupported trajectory type: {trajectory_type}")
        positive, axis = axis_map[trajectory_type]

        new_w2cs_seq = create_horizontal_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=positive,
            axis=axis,
            distance=movement_distance,
            device=device,
            camera_rotation=camera_rotation,
        )

    generated_w2cs = new_w2cs_seq.unsqueeze(0)  # (1, num_frames, 4, 4)
    if initial_intrinsics.dim() == 2:
        generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1)
    else:
        generated_intrinsics = initial_intrinsics.unsqueeze(0)

    return generated_w2cs, generated_intrinsics
fastvideo.pipelines.basic.gen3c.camera_utils.look_at_matrix
look_at_matrix(camera_pos: Tensor, target: Tensor, invert_pos: bool = True) -> Tensor

Create a 4x4 look-at view matrix pointing camera toward target.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def look_at_matrix(camera_pos: torch.Tensor, target: torch.Tensor, invert_pos: bool = True) -> torch.Tensor:
    """Create a 4x4 look-at view matrix pointing camera toward target."""
    forward = (target - camera_pos).float()
    forward = forward / torch.norm(forward)

    up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device)
    right = torch.cross(up, forward)
    right = right / torch.norm(right)
    up = torch.cross(forward, right)

    look_at = torch.eye(4, device=camera_pos.device)
    look_at[0, :3] = right
    look_at[1, :3] = up
    look_at[2, :3] = forward
    look_at[:3, 3] = (-camera_pos) if invert_pos else camera_pos

    return look_at
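A quick check of the convention (restating the code above): a camera at the origin looking down +z with y-up yields the identity view matrix.

import torch

eye = torch.zeros(3)
target = torch.tensor([0.0, 0.0, 1.0])
assert torch.allclose(look_at_matrix(eye, target), torch.eye(4))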
fastvideo.pipelines.basic.gen3c.depth_estimation

MoGe-based monocular depth estimation for GEN3C 3D cache conditioning.

Functions
fastvideo.pipelines.basic.gen3c.depth_estimation.load_moge_model
load_moge_model(model_name: str = 'Ruicheng/moge-vitl', device: str | device = 'cuda') -> MoGeModel

Load MoGe depth estimation model from HuggingFace.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_name | str | HuggingFace model identifier. | 'Ruicheng/moge-vitl' |
| device | str \| device | Device to load model on. | 'cuda' |

Returns:

| Type | Description |
| --- | --- |
| MoGeModel | Loaded MoGe model. |

Source code in fastvideo/pipelines/basic/gen3c/depth_estimation.py
def load_moge_model(
    model_name: str = "Ruicheng/moge-vitl",
    device: str | torch.device = "cuda",
) -> MoGeModel:
    """Load MoGe depth estimation model from HuggingFace.

    Args:
        model_name: HuggingFace model identifier.
        device: Device to load model on.

    Returns:
        Loaded MoGe model.
    """
    try:
        from moge.model.v1 import MoGeModel
    except ImportError as exc:
        raise ImportError("MoGe is required for GEN3C 3D cache conditioning. "
                          "Install it with: pip install git+https://github.com/microsoft/MoGe.git. "
                          "If import fails with libGL.so.1, install system deps: "
                          "sudo apt-get install -y libgl1 libglib2.0-0 libsm6 libxext6 libxrender1") from exc

    logger.info("Loading MoGe depth model: %s", model_name)
    model = MoGeModel.from_pretrained(model_name).to(device)
    model.eval()
    logger.info("MoGe model loaded successfully")
    return model
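Example:

A minimal sketch (requires the optional MoGe dependency and a CUDA device): load the model once and reuse it with the prediction helpers below.

model = load_moge_model("Ruicheng/moge-vitl", device="cuda")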
fastvideo.pipelines.basic.gen3c.depth_estimation.predict_depth_from_path
predict_depth_from_path(image_path: str, target_h: int, target_w: int, device: device, moge_model: MoGeModel) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]

Predict depth, intrinsics, and mask from an image file path.

Parameters:

image_path (str): Path to input image (RGB or BGR, any format cv2 supports). Required.
target_h (int): Target height for output tensors. Required.
target_w (int): Target width for output tensors. Required.
device (device): Computation device. Required.
moge_model (MoGeModel): Loaded MoGe model. Required.

Returns:

image (Tensor): (1, 1, 3, target_h, target_w) image tensor in [-1, 1].
depth (Tensor): (1, 1, 1, target_h, target_w) depth map.
mask (Tensor): (1, 1, 1, target_h, target_w) confidence mask.
w2c (Tensor): (1, 1, 4, 4) world-to-camera matrix (identity).
intrinsics (Tensor): (1, 1, 3, 3) camera intrinsics.
Source code in fastvideo/pipelines/basic/gen3c/depth_estimation.py
def predict_depth_from_path(
    image_path: str,
    target_h: int,
    target_w: int,
    device: torch.device,
    moge_model: MoGeModel,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Predict depth, intrinsics, and mask from an image file path.

    Args:
        image_path: Path to input image (RGB or BGR, any format cv2 supports).
        target_h: Target height for output tensors.
        target_w: Target width for output tensors.
        device: Computation device.
        moge_model: Loaded MoGe model.

    Returns:
        image: (1, 1, 3, target_h, target_w) image tensor in [-1, 1].
        depth: (1, 1, 1, target_h, target_w) depth map.
        mask: (1, 1, 1, target_h, target_w) confidence mask.
        w2c: (1, 1, 4, 4) world-to-camera matrix (identity).
        intrinsics: (1, 1, 3, 3) camera intrinsics.
    """
    import cv2

    input_image_bgr = cv2.imread(image_path)
    if input_image_bgr is None:
        raise FileNotFoundError(f"Input image not found: {image_path}")
    input_image_rgb = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)

    return _predict_depth_core(input_image_rgb, target_h, target_w, device, moge_model)
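
An end-to-end sketch (the image path and target resolution are hypothetical): estimate depth for a single frame and verify the documented output shapes.

import torch

from fastvideo.pipelines.basic.gen3c.depth_estimation import (
    load_moge_model,
    predict_depth_from_path,
)

device = torch.device("cuda")
moge = load_moge_model(device=device)

# "frame0.png" is a placeholder path for this sketch.
image, depth, mask, w2c, intrinsics = predict_depth_from_path(
    "frame0.png", target_h=480, target_w=832, device=device, moge_model=moge)

assert image.shape == (1, 1, 3, 480, 832)  # pixel values in [-1, 1]
assert depth.shape == (1, 1, 1, 480, 832)
assert w2c.shape == (1, 1, 4, 4)           # identity pose for the first frame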
fastvideo.pipelines.basic.gen3c.depth_estimation.predict_depth_from_tensor
predict_depth_from_tensor(image_tensor: Tensor, moge_model: MoGeModel) -> tuple[Tensor, Tensor]

Predict depth and mask from an image tensor (for autoregressive generation).

Parameters:

image_tensor (Tensor): (C, H, W) image tensor in [0, 1] range. Required.
moge_model (MoGeModel): Loaded MoGe model. Required.

Returns:

depth (Tensor): (1, 1, H, W) depth map.
mask (Tensor): (1, 1, H, W) confidence mask.
Source code in fastvideo/pipelines/basic/gen3c/depth_estimation.py
def predict_depth_from_tensor(
    image_tensor: torch.Tensor,
    moge_model: MoGeModel,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Predict depth and mask from an image tensor (for autoregressive generation).

    Args:
        image_tensor: (C, H, W) image tensor in [0, 1] range.
        moge_model: Loaded MoGe model.

    Returns:
        depth: (1, 1, H, W) depth map.
        mask: (1, 1, H, W) confidence mask.
    """
    moge_output = moge_model.infer(image_tensor)
    depth = moge_output["depth"]
    mask = moge_output["mask"]

    depth = depth.unsqueeze(0).unsqueeze(0)
    depth = torch.nan_to_num(depth, nan=1e4)
    depth = torch.clamp(depth, min=0, max=1e4)

    mask = mask.unsqueeze(0).unsqueeze(0)
    depth = torch.where(mask == 0, torch.tensor(1000.0, device=depth.device), depth)

    return depth, mask
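
For autoregressive generation, a follow-up sketch (the frame tensor here is random stand-in data in [0, 1]):

import torch

from fastvideo.pipelines.basic.gen3c.depth_estimation import (
    load_moge_model,
    predict_depth_from_tensor,
)

moge = load_moge_model(device="cuda")

# Stand-in for a decoded (C, H, W) frame in [0, 1].
last_frame = torch.rand(3, 480, 832, device="cuda")

depth, mask = predict_depth_from_tensor(last_frame, moge_model=moge)
assert depth.shape == (1, 1, 480, 832)
# Pixels with mask == 0 are pushed to a far-plane depth of 1000.0.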
fastvideo.pipelines.basic.gen3c.gen3c_pipeline

GEN3C video diffusion pipeline wiring.

Classes
fastvideo.pipelines.basic.gen3c.gen3c_pipeline.Gen3CPipeline
Gen3CPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

GEN3C Video Generation Pipeline.

This pipeline extends Cosmos with 3D cache support for camera-controlled video generation. When an input image is provided, it runs the full 3D cache conditioning pipeline (depth estimation -> point cloud -> camera trajectory -> forward warping -> VAE encoding).

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.gen3c.gen3c_pipeline.Gen3CPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/gen3c/gen3c_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="cfg_policy_stage", stage=Gen3CCFGPolicyStage())

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=Gen3CConditioningStage(vae=self.get_module("vae")))

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=Gen3CLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                     transformer=self.get_module("transformer"),
                                                     vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=Gen3CDenoisingStage(transformer=self.get_module("transformer"),
                                             scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.gen3c.presets

GEN3C model family pipeline presets.

Classes

fastvideo.pipelines.basic.hunyuan

Modules

fastvideo.pipelines.basic.hunyuan.hunyuan_pipeline

Hunyuan video diffusion pipeline implementation.

This module contains an implementation of the Hunyuan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.hunyuan.hunyuan_pipeline.HunyuanVideoPipeline
HunyuanVideoPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.hunyuan.hunyuan_pipeline.HunyuanVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/hunyuan/hunyuan_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage_primary",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder"),
                                      self.get_module("text_encoder_2")],
                       tokenizers=[self.get_module("tokenizer"),
                                   self.get_module("tokenizer_2")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.hunyuan.presets

Hunyuan model family pipeline presets.

Classes

fastvideo.pipelines.basic.hunyuan15

Modules

fastvideo.pipelines.basic.hunyuan15.hunyuan15_2sr_pipeline

Hunyuan video diffusion pipeline implementation.

This module contains an implementation of the Hunyuan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.hunyuan15.hunyuan15_2sr_pipeline.HunyuanVideo152SRPipeline
HunyuanVideo152SRPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_2sr_pipeline.HunyuanVideo152SRPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/hunyuan15/hunyuan15_2sr_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage_primary",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder"),
                                      self.get_module("text_encoder_2")],
                       tokenizers=[self.get_module("tokenizer"),
                                   self.get_module("tokenizer_2")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="image_encoding_stage",
                   stage=Hy15ImageEncodingStage(image_encoder=None, image_processor=None))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="sr_720p_latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer_2")))

    self.add_stage(stage_name="sr_720p_denoising_stage",
                   stage=SRDenoisingStage(transformer=self.get_module("transformer_2"),
                                          scheduler=self.get_module("scheduler"),
                                          upsampler=self.get_module("upsampler")))

    self.add_stage(stage_name="sr_1080p_latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer_3")))

    self.add_stage(stage_name="sr_1080p_denoising_stage",
                   stage=SRDenoisingStage(transformer=self.get_module("transformer_3"),
                                          scheduler=self.get_module("scheduler"),
                                          upsampler=self.get_module("upsampler_2")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
fastvideo.pipelines.basic.hunyuan15.hunyuan15_2sr_pipeline.HunyuanVideo152SRPipeline.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Generate a video or image using the pipeline.

Parameters:

batch (ForwardBatch): The batch to generate from. Required.
fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

ForwardBatch: The batch with the generated video or image.

Source code in fastvideo/pipelines/basic/hunyuan15/hunyuan15_2sr_pipeline.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Generate a video or image using the pipeline.

    Args:
        batch: The batch to generate from.
        fastvideo_args: The inference arguments.
    Returns:
        ForwardBatch: The batch with the generated video or image.
    """
    if not self.post_init_called:
        self.post_init()

    self.get_module("transformer").to(get_local_torch_device())
    # Execute each stage
    logger.info("Running pipeline stages: %s", self._stage_name_mapping.keys())
    # logger.info("Batch: %s", batch)
    batch = self.input_validation_stage(batch, fastvideo_args)
    batch = self.prompt_encoding_stage_primary(batch, fastvideo_args)
    batch = self.conditioning_stage(batch, fastvideo_args)
    batch = self.timestep_preparation_stage(batch, fastvideo_args)
    batch = self.latent_preparation_stage(batch, fastvideo_args)
    batch = self.image_encoding_stage(batch, fastvideo_args)
    batch = self.denoising_stage(batch, fastvideo_args)
    self.get_module("transformer").to("cpu")

    # 720p SR
    self.get_module("transformer_2").to(get_local_torch_device())
    batch.lq_latents = batch.latents
    batch.latents = None
    batch.height = 720
    batch.width = 1280
    batch.num_inference_steps_sr = 6
    batch = self.sr_720p_latent_preparation_stage(batch, fastvideo_args)
    batch = self.image_encoding_stage(batch, fastvideo_args)
    batch = self.sr_720p_denoising_stage(batch, fastvideo_args)
    self.get_module("transformer_2").to("cpu")

    # 1080p SR
    self.get_module("transformer_3").to(get_local_torch_device())
    batch.lq_latents = batch.latents
    batch.latents = None
    batch.height = 1072
    batch.width = 1920
    batch.num_inference_steps_sr = 8
    batch = self.sr_1080p_latent_preparation_stage(batch, fastvideo_args)
    batch = self.image_encoding_stage(batch, fastvideo_args)
    batch = self.sr_1080p_denoising_stage(batch, fastvideo_args)
    self.get_module("transformer_3").to("cpu")

    start_time = time.time()
    batch = self.decoding_stage(batch, fastvideo_args)
    end_time = time.time()
    logger.info("Decoding time: %s seconds", end_time - start_time)

    # Return the output
    return batch
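
The forward pass above keeps only one transformer resident on the accelerator at a time, offloading each to CPU once its stage finishes. A generic sketch of that pattern (the helper below is illustrative, not library API):

import torch
from contextlib import contextmanager

@contextmanager
def on_device(module: torch.nn.Module, device: torch.device):
    """Temporarily move a module to device, then park it back on CPU."""
    module.to(device)
    try:
        yield module
    finally:
        module.to("cpu")
        if device.type == "cuda":
            torch.cuda.empty_cache()  # release freed weight memory

# Usage sketch mirroring the 720p SR step:
# with on_device(transformer_2, torch.device("cuda")):
#     batch = sr_720p_denoising_stage(batch, fastvideo_args)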
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_i2v_pipeline

Hunyuan video diffusion pipeline implementation.

This module contains an implementation of the Hunyuan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.hunyuan15.hunyuan15_i2v_pipeline.HunyuanVideo15ImageToVideoPipeline
HunyuanVideo15ImageToVideoPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_i2v_pipeline.HunyuanVideo15ImageToVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/hunyuan15/hunyuan15_i2v_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage_primary",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder"),
                                      self.get_module("text_encoder_2")],
                       tokenizers=[self.get_module("tokenizer"),
                                   self.get_module("tokenizer_2")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="image_encoding_stage",
                   stage=Hy15ImageEncodingStage(image_encoder=None, image_processor=None))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_pipeline

Hunyuan video diffusion pipeline implementation.

This module contains an implementation of the Hunyuan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.hunyuan15.hunyuan15_pipeline.HunyuanVideo15Pipeline
HunyuanVideo15Pipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_pipeline.HunyuanVideo15Pipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/hunyuan15/hunyuan15_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage_primary",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder"),
                                      self.get_module("text_encoder_2")],
                       tokenizers=[self.get_module("tokenizer"),
                                   self.get_module("tokenizer_2")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="image_encoding_stage",
                   stage=Hy15ImageEncodingStage(image_encoder=None, image_processor=None))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_sr_pipeline

Hunyuan video diffusion pipeline implementation.

This module contains an implementation of the Hunyuan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.hunyuan15.hunyuan15_sr_pipeline.HunyuanVideo15SRPipeline
HunyuanVideo15SRPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.hunyuan15.hunyuan15_sr_pipeline.HunyuanVideo15SRPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/hunyuan15/hunyuan15_sr_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage_primary",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder"),
                                      self.get_module("text_encoder_2")],
                       tokenizers=[self.get_module("tokenizer"),
                                   self.get_module("tokenizer_2")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="image_encoding_stage",
                   stage=Hy15ImageEncodingStage(image_encoder=None, image_processor=None))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="sr_latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer_2")))

    self.add_stage(stage_name="sr_denoising_stage",
                   stage=SRDenoisingStage(transformer=self.get_module("transformer_2"),
                                          scheduler=self.get_module("scheduler"),
                                          upsampler=self.get_module("upsampler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
fastvideo.pipelines.basic.hunyuan15.hunyuan15_sr_pipeline.HunyuanVideo15SRPipeline.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Generate a video or image using the pipeline.

Parameters:

batch (ForwardBatch): The batch to generate from. Required.
fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

ForwardBatch: The batch with the generated video or image.

Source code in fastvideo/pipelines/basic/hunyuan15/hunyuan15_sr_pipeline.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Generate a video or image using the pipeline.

    Args:
        batch: The batch to generate from.
        fastvideo_args: The inference arguments.
    Returns:
        ForwardBatch: The batch with the generated video or image.
    """
    if not self.post_init_called:
        self.post_init()

    # Execute each stage
    logger.info("Running pipeline stages: %s", self._stage_name_mapping.keys())
    # logger.info("Batch: %s", batch)
    self.get_module("transformer").to(get_local_torch_device())
    batch = self.input_validation_stage(batch, fastvideo_args)
    batch = self.prompt_encoding_stage_primary(batch, fastvideo_args)
    batch = self.conditioning_stage(batch, fastvideo_args)
    batch = self.timestep_preparation_stage(batch, fastvideo_args)
    batch = self.latent_preparation_stage(batch, fastvideo_args)
    batch = self.image_encoding_stage(batch, fastvideo_args)
    batch = self.denoising_stage(batch, fastvideo_args)
    self.get_module("transformer").to("cpu")

    self.get_module("transformer_2").to(get_local_torch_device())
    batch.lq_latents = batch.latents
    batch.latents = None
    batch.height = batch.height_sr
    batch.width = batch.width_sr
    batch = self.sr_latent_preparation_stage(batch, fastvideo_args)
    batch = self.image_encoding_stage(batch, fastvideo_args)
    batch = self.sr_denoising_stage(batch, fastvideo_args)
    self.get_module("transformer_2").to("cpu")

    start_time = time.time()
    batch = self.decoding_stage(batch, fastvideo_args)
    end_time = time.time()
    logger.info("Decoding time: %s seconds", end_time - start_time)

    # Return the output
    return batch
Functions
fastvideo.pipelines.basic.hunyuan15.presets

Hunyuan 1.5 model family pipeline presets.

Classes

fastvideo.pipelines.basic.hyworld

Modules

fastvideo.pipelines.basic.hyworld.hyworld_pipeline

HYWorld video diffusion pipeline implementation.

This module contains an implementation of the HYWorld video diffusion pipeline using the modular pipeline architecture, with a HYWorld-specific denoising stage for chunk-based video generation with context frame selection.

Classes
fastvideo.pipelines.basic.hyworld.hyworld_pipeline.HYWorldPipeline
HYWorldPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

HYWorld video diffusion pipeline.

This pipeline implements chunk-based video generation with context frame selection for 3D-aware generation using HYWorldDenoisingStage.

Note: HYWorld only uses a single LLM-based text encoder, unlike SDXL-style dual encoder setups. The text_encoder_2/tokenizer_2 are not used.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.hyworld.hyworld_pipeline.HYWorldPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with HYWorld-specific denoising stage.

Source code in fastvideo/pipelines/basic/hyworld/hyworld_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with HYWorld-specific denoising stage."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage_primary",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder"),
                                      self.get_module("text_encoder_2")],
                       tokenizers=[self.get_module("tokenizer"),
                                   self.get_module("tokenizer_2")]))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="image_encoding_stage",
                   stage=HYWorldImageEncodingStage(image_encoder=self.get_module("image_encoder"),
                                                   image_processor=self.get_module("feature_extractor"),
                                                   vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=HYWorldDenoisingStage(transformer=self.get_module("transformer"),
                                               scheduler=self.get_module("scheduler"),
                                               pipeline=self))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.hyworld.presets

HYWorld model family pipeline presets.

Classes

fastvideo.pipelines.basic.lingbotworld

Modules

fastvideo.pipelines.basic.lingbotworld.lingbotworld_pipeline

Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.lingbotworld.presets

LingBotWorld model family pipeline presets.

Classes

fastvideo.pipelines.basic.longcat

LongCat pipeline module.

Classes

fastvideo.pipelines.basic.longcat.LongCatImageToVideoPipeline
LongCatImageToVideoPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat Image-to-Video pipeline.

Generates video from a single input image using Tier 3 I2V conditioning (see the sketch after this list):

- Per-frame timestep masking (timestep[:, 0] = 0)
- num_cond_latents parameter passed to the transformer
- RoPE skipping for conditioning frames
- Selective denoising (skips the first frame in the scheduler)
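
A minimal sketch of the per-frame timestep masking named above (shapes and values are illustrative):

import torch

batch_size, num_latent_frames = 1, 13
# One timestep per latent frame instead of a single scalar per sample.
t = torch.full((batch_size, num_latent_frames), 999.0)

# The first latent frame carries the clean conditioning image, so its
# timestep is forced to 0 and the scheduler skips denoising it.
t[:, 0] = 0
print(t[0, :3])  # tensor([  0., 999., 999.])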

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.longcat.LongCatImageToVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up I2V-specific pipeline stages.

Source code in fastvideo/pipelines/basic/longcat/longcat_i2v_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up I2V-specific pipeline stages."""

    # 1. Input validation
    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    # 2. Text encoding (same as T2V)
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # 3. Image VAE encoding (for I2V - skipped in refinement mode)
    self.add_stage(stage_name="image_vae_encoding_stage",
                   stage=LongCatImageVAEEncodingStage(vae=self.get_module("vae")))

    # 4. Refinement initialization (skipped if not refining)
    self.add_stage(stage_name="longcat_refine_init_stage", stage=LongCatRefineInitStage(vae=self.get_module("vae")))

    # 5. Timestep preparation (generic)
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    # 6. Refinement timestep override (skipped if not refining)
    self.add_stage(stage_name="longcat_refine_timestep_stage",
                   stage=LongCatRefineTimestepStage(scheduler=self.get_module("scheduler")))

    # 7. Latent preparation with I2V conditioning
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LongCatI2VLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                          transformer=self.get_module("transformer")))

    # 8. Denoising with I2V support
    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatI2VDenoisingStage(transformer=self.get_module("transformer"),
                                                  transformer_2=self.get_module("transformer_2", None),
                                                  scheduler=self.get_module("scheduler"),
                                                  vae=self.get_module("vae"),
                                                  pipeline=self))

    # 9. Decoding
    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.longcat.LongCatImageToVideoPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_i2v_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""
    # Same BSA initialization as base LongCat pipeline
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        return

    # Enable BSA if configured
    if pipeline_config.enable_bsa:
        bsa_params_cfg = getattr(pipeline_config, 'bsa_params', None) or {}
        sparsity = getattr(pipeline_config, 'bsa_sparsity', None)
        cdf_threshold = getattr(pipeline_config, 'bsa_cdf_threshold', None)
        chunk_q = getattr(pipeline_config, 'bsa_chunk_q', None)
        chunk_k = getattr(pipeline_config, 'bsa_chunk_k', None)

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k

        # Provide defaults
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info("Enabling BSA for LongCat I2V transformer")
            transformer.enable_bsa()
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning("Failed to set BSA params: %s", e)
            logger.info("BSA parameters: %s", effective_bsa_params)
    else:
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()
fastvideo.pipelines.basic.longcat.LongCatPipeline
LongCatPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat video diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.longcat.LongCatPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/longcat/longcat_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # Add refine initialization stage (will be skipped if not refining)
    self.add_stage(stage_name="longcat_refine_init_stage", stage=LongCatRefineInitStage(vae=self.get_module("vae")))

    # First prepare generic timesteps (for non-refine paths)
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    # Then override timesteps for refinement (will be a no-op if not refining),
    # matching LongCat's generate_refine schedule.
    self.add_stage(stage_name="longcat_refine_timestep_stage",
                   stage=LongCatRefineTimestepStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None)))

    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatDenoisingStage(transformer=self.get_module("transformer"),
                                               transformer_2=self.get_module("transformer_2", None),
                                               scheduler=self.get_module("scheduler"),
                                               vae=self.get_module("vae"),
                                               pipeline=self))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.longcat.LongCatPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""

    # Enable BSA (Block Sparse Attention) if configured
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        raise RuntimeError("Transformer module not found during initializing LongCat pipeline.")
    # If user toggles BSA via CLI/config
    if pipeline_config.enable_bsa:
        # Build effective BSA params:
        # 1) from explicit CLI overrides if provided
        # 2) else from pipeline_config.bsa_params
        # 3) else fall back to reasonable defaults
        bsa_params_cfg = pipeline_config.bsa_params
        sparsity = pipeline_config.bsa_sparsity
        cdf_threshold = pipeline_config.bsa_cdf_threshold
        chunk_q = pipeline_config.bsa_chunk_q
        chunk_k = pipeline_config.bsa_chunk_k

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k
        # Provide defaults if still missing
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info("Enabling Block Sparse Attention (BSA) for LongCat transformer")
            transformer.enable_bsa()
            # Propagate params to all attention modules
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning("Failed to set BSA params on all blocks: %s", e)
            logger.info("BSA parameters in effect: %s", effective_bsa_params)
        else:
            logger.warning("BSA is enabled in config but transformer does not support it")
    else:
        # Explicitly disable if present
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()
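
The override precedence implemented above (explicit CLI field, then pipeline_config.bsa_params, then built-in defaults) reduces to a plain dict merge; a hedged sketch with hypothetical values:

bsa_params_cfg = {"sparsity": 0.9, "cdf_threshold": 0.99}  # from pipeline_config.bsa_params
cli_sparsity = 0.95                                        # explicit CLI override

effective = dict(bsa_params_cfg)
if cli_sparsity is not None:
    effective["sparsity"] = cli_sparsity                   # CLI value wins over config
effective.setdefault("chunk_3d_shape_q", [4, 4, 4])        # defaults fill remaining gaps
effective.setdefault("chunk_3d_shape_k", [4, 4, 4])
print(effective["sparsity"])  # 0.95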
fastvideo.pipelines.basic.longcat.LongCatVideoContinuationPipeline
LongCatVideoContinuationPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat Video Continuation pipeline.

Generates a video continuation from multiple conditioning frames, optionally using a KV cache for a 2-3x speedup.

Key features:

- Takes video input (typically 13+ frames)
- Encodes conditioning frames via VAE
- Optionally pre-computes a KV cache for the conditioning frames
- Uses cached K/V during denoising for a speedup
- Concatenates the conditioning frames back after denoising
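To make the KV-cache feature concrete: the conditioning tokens are clean and unchanged across denoising steps, so their attention K/V can be computed once and reused at every step. A minimal sketch of that idea with made-up shapes and projections (not the transformer's real ones):

import torch
import torch.nn.functional as F

d = 64
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))

cond_tokens = torch.randn(1, 13, d)   # conditioning-frame tokens, fixed across steps
cached_k = cond_tokens @ w_k          # "KV cache init", computed once
cached_v = cond_tokens @ w_v

noisy = torch.randn(1, 48, d)         # tokens being denoised
for _ in range(4):                    # denoising loop (scheduler math elided)
    q = noisy @ w_q
    # Only the noisy tokens need fresh K/V each step; the conditioning
    # prefix reuses the cache, which is where the speedup comes from.
    k = torch.cat([cached_k, noisy @ w_k], dim=1)
    v = torch.cat([cached_v, noisy @ w_v], dim=1)
    noisy = F.scaled_dot_product_attention(q, k, v)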

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.longcat.LongCatVideoContinuationPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up VC-specific pipeline stages.

Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up VC-specific pipeline stages."""

    # 1. Input validation
    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    # 2. Text encoding
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # 3. Video VAE encoding (encodes conditioning frames)
    self.add_stage(stage_name="video_vae_encoding_stage",
                   stage=LongCatVideoVAEEncodingStage(vae=self.get_module("vae")))

    # 4. Timestep preparation
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    # 5. Latent preparation (reuse I2V stage - it handles video_latent too)
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LongCatVCLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                         transformer=self.get_module("transformer")))

    # 6. KV cache initialization (optional, based on config)
    # The stage is always added but is a no-op when use_kv_cache=False
    self.add_stage(stage_name="kv_cache_init_stage",
                   stage=LongCatKVCacheInitStage(transformer=self.get_module("transformer")))

    # 7. Denoising with VC and KV cache support
    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatVCDenoisingStage(transformer=self.get_module("transformer"),
                                                 transformer_2=self.get_module("transformer_2", None),
                                                 scheduler=self.get_module("scheduler"),
                                                 vae=self.get_module("vae"),
                                                 pipeline=self))

    # 8. Decoding
    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.longcat.LongCatVideoContinuationPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        return

    # Enable BSA if configured (for VC, BSA may not be needed)
    if getattr(pipeline_config, 'enable_bsa', False):
        bsa_params_cfg = getattr(pipeline_config, 'bsa_params', None) or {}
        sparsity = getattr(pipeline_config, 'bsa_sparsity', None)
        cdf_threshold = getattr(pipeline_config, 'bsa_cdf_threshold', None)
        chunk_q = getattr(pipeline_config, 'bsa_chunk_q', None)
        chunk_k = getattr(pipeline_config, 'bsa_chunk_k', None)

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k

        # Provide defaults
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info("Enabling BSA for LongCat VC transformer")
            transformer.enable_bsa()
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning("Failed to set BSA params: %s", e)
            logger.info("BSA parameters: %s", effective_bsa_params)
    else:
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()

Modules

fastvideo.pipelines.basic.longcat.longcat_i2v_pipeline

LongCat Image-to-Video pipeline implementation.

This module implements I2V (Image-to-Video) generation for LongCat using Tier 3 conditioning with timestep masking, num_cond_latents support, and RoPE skipping.

Supports:

- Basic I2V (50 steps, guidance_scale=4.0)
- Distilled I2V with LoRA (16 steps, guidance_scale=1.0)
- Refinement I2V for 720p upscaling (with refinement LoRA + BSA)

Classes
fastvideo.pipelines.basic.longcat.longcat_i2v_pipeline.LongCatImageToVideoPipeline
LongCatImageToVideoPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat Image-to-Video pipeline.

Generates video from a single input image using Tier 3 I2V conditioning:

- Per-frame timestep masking (timestep[:, 0] = 0)
- num_cond_latents parameter passed to the transformer
- RoPE skipping for conditioning frames
- Selective denoising (skip the first frame in the scheduler)
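A sketch of what the per-frame timestep masking above means in practice; the shapes and values are illustrative, not the pipeline's real ones:

import torch

batch, num_frames = 2, 21
num_cond_latents = 1                  # the encoded input image

# One timestep per latent frame instead of a single scalar per sample.
t = torch.full((batch, num_frames), 981.0)
t[:, :num_cond_latents] = 0.0         # the docstring's timestep[:, 0] = 0

# The transformer also receives num_cond_latents so it can skip RoPE for
# those positions, and the scheduler skips denoising them.
print(t[0, :3])                       # tensor([  0., 981., 981.])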

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.longcat.longcat_i2v_pipeline.LongCatImageToVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up I2V-specific pipeline stages.

Source code in fastvideo/pipelines/basic/longcat/longcat_i2v_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up I2V-specific pipeline stages."""

    # 1. Input validation
    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    # 2. Text encoding (same as T2V)
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # 3. Image VAE encoding (for I2V - skipped in refinement mode)
    self.add_stage(stage_name="image_vae_encoding_stage",
                   stage=LongCatImageVAEEncodingStage(vae=self.get_module("vae")))

    # 4. Refinement initialization (skipped if not refining)
    self.add_stage(stage_name="longcat_refine_init_stage", stage=LongCatRefineInitStage(vae=self.get_module("vae")))

    # 5. Timestep preparation (generic)
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    # 6. Refinement timestep override (skipped if not refining)
    self.add_stage(stage_name="longcat_refine_timestep_stage",
                   stage=LongCatRefineTimestepStage(scheduler=self.get_module("scheduler")))

    # 7. Latent preparation with I2V conditioning
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LongCatI2VLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                          transformer=self.get_module("transformer")))

    # 8. Denoising with I2V support
    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatI2VDenoisingStage(transformer=self.get_module("transformer"),
                                                  transformer_2=self.get_module("transformer_2", None),
                                                  scheduler=self.get_module("scheduler"),
                                                  vae=self.get_module("vae"),
                                                  pipeline=self))

    # 9. Decoding
    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.longcat.longcat_i2v_pipeline.LongCatImageToVideoPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_i2v_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""
    # Same BSA initialization as base LongCat pipeline
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        return

    # Enable BSA if configured
    if pipeline_config.enable_bsa:
        bsa_params_cfg = getattr(pipeline_config, 'bsa_params', None) or {}
        sparsity = getattr(pipeline_config, 'bsa_sparsity', None)
        cdf_threshold = getattr(pipeline_config, 'bsa_cdf_threshold', None)
        chunk_q = getattr(pipeline_config, 'bsa_chunk_q', None)
        chunk_k = getattr(pipeline_config, 'bsa_chunk_k', None)

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k

        # Provide defaults
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info("Enabling BSA for LongCat I2V transformer")
            transformer.enable_bsa()
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning("Failed to set BSA params: %s", e)
            logger.info("BSA parameters: %s", effective_bsa_params)
    else:
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()
Functions
fastvideo.pipelines.basic.longcat.longcat_pipeline

LongCat video diffusion pipeline implementation.

This module implements the LongCat video diffusion pipeline using FastVideo's modular pipeline architecture.

Classes
fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline
LongCatPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat video diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/longcat/longcat_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # Add refine initialization stage (will be skipped if not refining)
    self.add_stage(stage_name="longcat_refine_init_stage", stage=LongCatRefineInitStage(vae=self.get_module("vae")))

    # First prepare generic timesteps (for non-refine paths)
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    # Then override timesteps for refinement (will be a no-op if not refining),
    # matching LongCat's generate_refine schedule.
    self.add_stage(stage_name="longcat_refine_timestep_stage",
                   stage=LongCatRefineTimestepStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None)))

    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatDenoisingStage(transformer=self.get_module("transformer"),
                                               transformer_2=self.get_module("transformer_2", None),
                                               scheduler=self.get_module("scheduler"),
                                               vae=self.get_module("vae"),
                                               pipeline=self))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""

    # Enable BSA (Block Sparse Attention) if configured
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        raise RuntimeError("Transformer module not found during initializing LongCat pipeline.")
    # If user toggles BSA via CLI/config
    if pipeline_config.enable_bsa:
        # Build effective BSA params:
        # 1) from explicit CLI overrides if provided
        # 2) else from pipeline_config.bsa_params
        # 3) else fall back to reasonable defaults
        bsa_params_cfg = pipeline_config.bsa_params
        sparsity = pipeline_config.bsa_sparsity
        cdf_threshold = pipeline_config.bsa_cdf_threshold
        chunk_q = pipeline_config.bsa_chunk_q
        chunk_k = pipeline_config.bsa_chunk_k

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k
        # Provide defaults if still missing
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info("Enabling Block Sparse Attention (BSA) for LongCat transformer")
            transformer.enable_bsa()
            # Propagate params to all attention modules
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning("Failed to set BSA params on all blocks: %s", e)
            logger.info("BSA parameters in effect: %s", effective_bsa_params)
        else:
            logger.warning("BSA is enabled in config but transformer does not support it")
    else:
        # Explicitly disable if present
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()
Functions
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline

LongCat Video Continuation (VC) pipeline implementation.

This module implements VC (Video Continuation) generation for LongCat with KV cache optimization for 2-3x speedup.

Supports:

- Basic VC (50 steps, guidance_scale=4.0)
- Distilled VC with LoRA (16 steps, guidance_scale=1.0)
- KV cache for conditioning frames

Classes
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVCLatentPreparationStage
LongCatVCLatentPreparationStage(scheduler, transformer, use_btchw_layout: bool = False)

Bases: LongCatI2VLatentPreparationStage

Prepare latents with video conditioning for first N frames.

Extends I2V latent preparation to handle video_latent (multiple frames) instead of image_latent (single frame).

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, use_btchw_layout: bool = False) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.use_btchw_layout = use_btchw_layout
Functions
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVCLatentPreparationStage.forward
forward(batch, fastvideo_args)

Prepare latents with VC conditioning.

Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def forward(self, batch, fastvideo_args):
    """Prepare latents with VC conditioning."""

    # Check if we have video_latent (from VC encoding stage)
    video_latent = getattr(batch, 'video_latent', None)
    if video_latent is not None:
        # Set image_latent to video_latent for parent class compatibility
        batch.image_latent = video_latent

    # Call parent class forward
    return super().forward(batch, fastvideo_args)
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVideoContinuationPipeline
LongCatVideoContinuationPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat Video Continuation pipeline.

Generates a video continuation from multiple conditioning frames, optionally using a KV cache for a 2-3x speedup.

Key features:

- Takes video input (typically 13+ frames)
- Encodes conditioning frames via VAE
- Optionally pre-computes a KV cache for the conditioning frames
- Uses cached K/V during denoising for a speedup
- Concatenates the conditioning frames back after denoising

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVideoContinuationPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up VC-specific pipeline stages.

Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up VC-specific pipeline stages."""

    # 1. Input validation
    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    # 2. Text encoding
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # 3. Video VAE encoding (encodes conditioning frames)
    self.add_stage(stage_name="video_vae_encoding_stage",
                   stage=LongCatVideoVAEEncodingStage(vae=self.get_module("vae")))

    # 4. Timestep preparation
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    # 5. Latent preparation (reuse I2V stage - it handles video_latent too)
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LongCatVCLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                         transformer=self.get_module("transformer")))

    # 6. KV cache initialization (optional, based on config)
    # The stage is always added but is a no-op when use_kv_cache=False
    self.add_stage(stage_name="kv_cache_init_stage",
                   stage=LongCatKVCacheInitStage(transformer=self.get_module("transformer")))

    # 7. Denoising with VC and KV cache support
    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatVCDenoisingStage(transformer=self.get_module("transformer"),
                                                 transformer_2=self.get_module("transformer_2", None),
                                                 scheduler=self.get_module("scheduler"),
                                                 vae=self.get_module("vae"),
                                                 pipeline=self))

    # 8. Decoding
    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVideoContinuationPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        return

    # Enable BSA if configured (for VC, BSA may not be needed)
    if getattr(pipeline_config, 'enable_bsa', False):
        bsa_params_cfg = getattr(pipeline_config, 'bsa_params', None) or {}
        sparsity = getattr(pipeline_config, 'bsa_sparsity', None)
        cdf_threshold = getattr(pipeline_config, 'bsa_cdf_threshold', None)
        chunk_q = getattr(pipeline_config, 'bsa_chunk_q', None)
        chunk_k = getattr(pipeline_config, 'bsa_chunk_k', None)

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k

        # Provide defaults
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info("Enabling BSA for LongCat VC transformer")
            transformer.enable_bsa()
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning("Failed to set BSA params: %s", e)
            logger.info("BSA parameters: %s", effective_bsa_params)
    else:
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()
Functions
fastvideo.pipelines.basic.longcat.presets

LongCat model family pipeline presets.

Classes

fastvideo.pipelines.basic.ltx2

Modules

fastvideo.pipelines.basic.ltx2.continuation

Typed continuation state for the LTX-2 streaming pipeline.

Segment N+1 conditions on segment N's trailing decoded frames and denoised audio latents. The streaming runtime used to hold this state as per-worker globals; lifting it into a typed, JSON-serializable object lets clients snapshot, migrate, or round-trip it through an HTTP/RPC boundary. The envelope ContinuationState(kind, payload) is the shared public API; the typed class here owns the LTX-2 payload shape.

Serialization contract:

  • Video frames → PNG bytes + base64, or a BlobStore id.
  • Audio latents → a self-describing safetensors blob + base64, or a BlobStore id. safetensors preserves bfloat16, which a raw-numpy round-trip cannot.
  • The returned payload is always a plain JSON-serializable dict.
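A hedged sketch of the audio-latent leg of this contract, showing only the encode/decode core (blob-store handling elided; the tensor shape is illustrative):

import base64
import torch
from safetensors.torch import save, load

latents = torch.randn(1, 8, 24, 16, dtype=torch.bfloat16)   # [B, C, T, mel]

blob = save({"audio_latents": latents})          # self-describing bytes
inline = base64.b64encode(blob).decode("ascii")  # JSON-safe string

restored = load(base64.b64decode(inline))["audio_latents"]
assert restored.dtype == torch.bfloat16          # dtype survives the round trip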
Attributes
fastvideo.pipelines.basic.ltx2.continuation.DEFAULT_INLINE_THRESHOLD_BYTES module-attribute
DEFAULT_INLINE_THRESHOLD_BYTES = 2 * 1024 * 1024

Tensors larger than this go to the blob store (if available). 2 MiB is below typical single-JSON-message limits (Dynamo: 4 MiB, Postgres TOAST: 1 GiB) and well above per-frame PNG payloads (~200 KiB at 512x512).
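The size check this threshold implies, as a sketch; should_offload is a hypothetical name, since the real helper is internal:

import torch

DEFAULT_INLINE_THRESHOLD_BYTES = 2 * 1024 * 1024

def should_offload(t: torch.Tensor, blob_store_available: bool) -> bool:
    nbytes = t.numel() * t.element_size()
    return blob_store_available and nbytes > DEFAULT_INLINE_THRESHOLD_BYTES

# ~3 MiB of bf16 audio latents goes to the blob store when one exists:
big = torch.zeros(1, 64, 512, 48, dtype=torch.bfloat16)
print(should_offload(big, blob_store_available=True))    # True
print(should_offload(big, blob_store_available=False))   # False -> inline base64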

fastvideo.pipelines.basic.ltx2.continuation.LTX2_CONTINUATION_KIND module-attribute
LTX2_CONTINUATION_KIND = 'ltx2.v1'

Public ContinuationState.kind for LTX-2 payloads.

fastvideo.pipelines.basic.ltx2.continuation.LTX2_CONTINUATION_SCHEMA_VERSION module-attribute
LTX2_CONTINUATION_SCHEMA_VERSION = 1

Payload schema version carried inside payload.schema_version.

Classes
fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState dataclass
LTX2ContinuationState(segment_index: int = 0, video_frames: list[ndarray] | None = None, video_frames_blob_id: str | None = None, video_conditioning_frame_idx: int = 0, video_conditioning_strength: float = 1.0, audio_latents: Tensor | None = None, audio_latents_blob_id: str | None = None, audio_sample_rate: int | None = None, audio_conditioning_num_frames: int = 0, audio_conditioning_strength: float = 1.0, video_position_offset_sec: float = 0.0, metadata: dict[str, Any] = dict())

Typed LTX-2 continuation state carried between streaming segments.

video_frames holds trailing decoded RGB frames (uint8 HxWx3) from segment N, used to condition segment N+1 via the VAE encode path. audio_latents is the cached denoised audio latent tensor of shape [B, C, T, mel] that segment N+1 copies into the overlap region of its clean-latent conditioning.

Most fields map 1:1 onto the internal gpu_pool's per-worker state; the only new concept is the *_blob_id fields, which allow large tensors to live outside the JSON payload. See module docstring.

Attributes
fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.audio_conditioning_num_frames class-attribute instance-attribute
audio_conditioning_num_frames: int = 0

Number of trailing audio frames that carry over as clean context into segment N+1.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.audio_conditioning_strength class-attribute instance-attribute
audio_conditioning_strength: float = 1.0

Clean-latent mask value applied to the overlap region; 0.0 keeps the cached audio entirely, 1.0 renoises from scratch.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.audio_latents class-attribute instance-attribute
audio_latents: Tensor | None = None

Denoised audio latent tensor of shape [B, C, T, mel]. None when the state is blob-backed or unset.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.audio_latents_blob_id class-attribute instance-attribute
audio_latents_blob_id: str | None = None

Blob store id when audio latents live outside the payload.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.audio_sample_rate class-attribute instance-attribute
audio_sample_rate: int | None = None

Sample rate for the audio side (e.g. 24000).

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.metadata class-attribute instance-attribute
metadata: dict[str, Any] = field(default_factory=dict)

Opaque metadata bag for forward-compat fields that don't need their own typed slot yet (e.g. custom knob experiments).

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.segment_index class-attribute instance-attribute
segment_index: int = 0

Index of the just-completed segment. Segment 0 has no history; state returned after segment 0 carries segment_index=0 and the caller uses segment_index + 1 as the next segment number.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.video_conditioning_frame_idx class-attribute instance-attribute
video_conditioning_frame_idx: int = 0

Target frame index inside the next segment that the trailing frames align with (matches the LTX-2 ltx2_video_conditions tuple's frame_idx slot).

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.video_conditioning_strength class-attribute instance-attribute
video_conditioning_strength: float = 1.0

Conditioning strength in [0, 1]. Matches the ltx2_video_conditions tuple's strength slot.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.video_frames class-attribute instance-attribute
video_frames: list[ndarray] | None = None

Trailing decoded frames, each an RGB uint8 np.ndarray shaped (H, W, 3). None when the state is blob-backed or unset.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.video_frames_blob_id class-attribute instance-attribute
video_frames_blob_id: str | None = None

Blob store id when the frames live outside the payload.

fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.video_position_offset_sec class-attribute instance-attribute
video_position_offset_sec: float = 0.0

Seconds by which video RoPE is shifted forward so the audio prefix can sit at t >= 0 when audio conditioning is longer than video conditioning.

Functions
fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.from_continuation_state classmethod
from_continuation_state(state: ContinuationState, *, blob_store: BlobStore | None = None) -> LTX2ContinuationState

Rebuild a typed state from a public ContinuationState.

Raises ValueError when the kind doesn't match or the schema version is unsupported.

Source code in fastvideo/pipelines/basic/ltx2/continuation.py
@classmethod
def from_continuation_state(
    cls,
    state: ContinuationState,
    *,
    blob_store: BlobStore | None = None,
) -> LTX2ContinuationState:
    """Rebuild a typed state from a public :class:`ContinuationState`.

    Raises :class:`ValueError` when the kind doesn't match or the
    schema version is unsupported.
    """
    if state.kind != LTX2_CONTINUATION_KIND:
        raise ValueError(f"Expected ContinuationState.kind={LTX2_CONTINUATION_KIND!r}, "
                         f"got {state.kind!r}")
    payload = state.payload or {}
    version = int(payload.get("schema_version", LTX2_CONTINUATION_SCHEMA_VERSION))
    if version != LTX2_CONTINUATION_SCHEMA_VERSION:
        raise ValueError(f"Unsupported LTX-2 continuation schema_version={version}; "
                         f"this build expects {LTX2_CONTINUATION_SCHEMA_VERSION}")

    out = cls(
        segment_index=int(payload.get("segment_index", 0)),
        video_conditioning_frame_idx=int(payload.get("video_conditioning_frame_idx", 0)),
        video_conditioning_strength=float(payload.get("video_conditioning_strength", 1.0)),
        audio_sample_rate=(int(payload["audio_sample_rate"]) if "audio_sample_rate" in payload else None),
        audio_conditioning_num_frames=int(payload.get("audio_conditioning_num_frames", 0)),
        audio_conditioning_strength=float(payload.get("audio_conditioning_strength", 1.0)),
        video_position_offset_sec=float(payload.get("video_position_offset_sec", 0.0)),
        metadata=dict(payload.get("metadata") or {}),
    )

    video = payload.get("video")
    if isinstance(video, Mapping):
        cls._decode_video_frames(out, video, blob_store=blob_store)

    audio = payload.get("audio")
    if isinstance(audio, Mapping):
        cls._decode_audio_latents(out, audio, blob_store=blob_store)

    return out
fastvideo.pipelines.basic.ltx2.continuation.LTX2ContinuationState.to_continuation_state
to_continuation_state(*, blob_store: BlobStore | None = None, inline_threshold_bytes: int = DEFAULT_INLINE_THRESHOLD_BYTES) -> ContinuationState

Serialize into a public ContinuationState.

When blob_store is given, tensors larger than inline_threshold_bytes are stored via BlobStore.put and referenced by id; otherwise all data is base64-encoded inline. The payload is always a plain JSON-serializable dict.

Source code in fastvideo/pipelines/basic/ltx2/continuation.py
def to_continuation_state(
    self,
    *,
    blob_store: BlobStore | None = None,
    inline_threshold_bytes: int = DEFAULT_INLINE_THRESHOLD_BYTES,
) -> ContinuationState:
    """Serialize into a public :class:`ContinuationState`.

    When ``blob_store`` is given, tensors larger than
    ``inline_threshold_bytes`` are stored via
    :meth:`BlobStore.put` and referenced by id; otherwise all data
    is base64-encoded inline. The payload is always a plain
    JSON-serializable dict.
    """
    payload: dict[str, Any] = {
        "schema_version": LTX2_CONTINUATION_SCHEMA_VERSION,
        "segment_index": int(self.segment_index),
        "video_conditioning_frame_idx": int(self.video_conditioning_frame_idx),
        "video_conditioning_strength": float(self.video_conditioning_strength),
        "audio_conditioning_num_frames": int(self.audio_conditioning_num_frames),
        "audio_conditioning_strength": float(self.audio_conditioning_strength),
        "video_position_offset_sec": float(self.video_position_offset_sec),
        "metadata": dict(self.metadata),
    }
    if self.audio_sample_rate is not None:
        payload["audio_sample_rate"] = int(self.audio_sample_rate)

    video_payload = self._encode_video_frames(
        blob_store=blob_store,
        inline_threshold_bytes=inline_threshold_bytes,
    )
    if video_payload is not None:
        payload["video"] = video_payload

    audio_payload = self._encode_audio_latents(
        blob_store=blob_store,
        inline_threshold_bytes=inline_threshold_bytes,
    )
    if audio_payload is not None:
        payload["audio"] = audio_payload

    return ContinuationState(
        kind=LTX2_CONTINUATION_KIND,
        payload=payload,
    )
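Putting the two methods together, a round-trip sketch that uses only the documented surface; field values are illustrative, and the frame/latent payloads are omitted for brevity:

from fastvideo.pipelines.basic.ltx2.continuation import (
    LTX2_CONTINUATION_KIND,
    LTX2ContinuationState,
)

state = LTX2ContinuationState(
    segment_index=0,
    audio_sample_rate=24000,
    audio_conditioning_num_frames=8,
)

envelope = state.to_continuation_state()        # no blob store: inline base64
assert envelope.kind == LTX2_CONTINUATION_KIND  # 'ltx2.v1'
assert envelope.payload["schema_version"] == 1

# The next segment's worker, possibly on the far side of an HTTP/RPC
# boundary, rebuilds the typed state from the plain-JSON payload:
restored = LTX2ContinuationState.from_continuation_state(envelope)
assert restored.segment_index + 1 == 1          # next segment number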
Functions
fastvideo.pipelines.basic.ltx2.ltx2_pipeline

LTX-2 text-to-video pipeline.

Classes
Functions
fastvideo.pipelines.basic.ltx2.pipeline_configs
Classes
fastvideo.pipelines.basic.ltx2.pipeline_configs.LTX2T2VConfig dataclass
LTX2T2VConfig(model_path: str = '', pipeline_config_path: str | None = None, embedded_cfg_scale: float = 6.0, flow_shift: float | None = None, flow_shift_sr: float | None = None, disable_autocast: bool = False, is_causal: bool = False, dit_config: DiTConfig = LTX2VideoConfig(), dit_precision: str = 'bf16', upsampler_config: UpsamplerConfig = UpsamplerConfig(), upsampler_precision: str = 'fp32', vae_config: VAEConfig = LTX2VAEConfig(), vae_precision: str = 'bf16', vae_tiling: bool = True, vae_sp: bool = False, image_encoder_config: EncoderConfig = EncoderConfig(), image_encoder_precision: str = 'fp32', text_encoder_configs: tuple[EncoderConfig, ...] = (lambda: (LTX2GemmaConfig(),))(), text_encoder_precisions: tuple[str, ...] = (lambda: ('bf16',))(), preprocess_text_funcs: tuple[Callable[[str], str], ...] = (lambda: (preprocess_text,))(), postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], Tensor], ...] = (lambda: (ltx2_postprocess_text,))(), dmd_denoising_steps: list[int] | None = None, ti2v_task: bool = False, boundary_ratio: float | None = None, audio_decoder_config: ModelConfig = LTX2AudioDecoderConfig(), vocoder_config: ModelConfig = LTX2VocoderConfig(), audio_decoder_precision: str = 'bf16', vocoder_precision: str = 'bf16')

Bases: PipelineConfig

Configuration for LTX-2 T2V pipeline.

fastvideo.pipelines.basic.ltx2.presets

LTX2 model family pipeline presets.

Classes
fastvideo.pipelines.basic.ltx2.stage_overrides

Typed override surfaces for the LTX-2 two-stage refine flow.

  • preset_overrides.refine — init-time knobs (see LTX2RefinePresetOverride).
  • stage_overrides.refine — per-request knobs (see LTX2RefineStageOverride).

Asset paths live on fastvideo.api.schema.ComponentConfig (upsampler_weights and lora_path).

Classes
fastvideo.pipelines.basic.ltx2.stage_overrides.LTX2RefinePresetOverride dataclass
LTX2RefinePresetOverride(enabled: bool | None = None, add_noise: bool | None = None)

Init-time refine wiring under preset_overrides.refine.

fastvideo.pipelines.basic.ltx2.stage_overrides.LTX2RefineStageOverride dataclass
LTX2RefineStageOverride(num_inference_steps: int | None = None, guidance_scale: float | None = None, image_crf: int | None = None, video_position_offset_sec: float | None = None)

Per-request refine tuning under stage_overrides.refine.

Functions
fastvideo.pipelines.basic.ltx2.stage_overrides.refine_override_to_dict
refine_override_to_dict(override: LTX2RefinePresetOverride | LTX2RefineStageOverride) -> dict[str, Any]

Serialise a refine override, dropping None entries so only user-set fields reach preset_overrides.refine or stage_overrides.refine.

Source code in fastvideo/pipelines/basic/ltx2/stage_overrides.py
def refine_override_to_dict(override: LTX2RefinePresetOverride | LTX2RefineStageOverride) -> dict[str, Any]:
    """Serialise a refine override, dropping ``None`` entries so only
    user-set fields reach ``preset_overrides.refine`` or
    ``stage_overrides.refine``."""
    return {k: v for k, v in asdict(override).items() if v is not None}
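Usage sketch with illustrative values: only fields the user actually set survive, so unset knobs keep their upstream defaults.

from fastvideo.pipelines.basic.ltx2.stage_overrides import (
    LTX2RefineStageOverride,
    refine_override_to_dict,
)

override = LTX2RefineStageOverride(num_inference_steps=8, image_crf=23)
print(refine_override_to_dict(override))
# {'num_inference_steps': 8, 'image_crf': 23} -- guidance_scale and
# video_position_offset_sec stayed None, so they are dropped.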
fastvideo.pipelines.basic.ltx2.stages

LTX-2 family pipeline stages.

Classes
fastvideo.pipelines.basic.ltx2.stages.LTX2AudioDecodingStage
LTX2AudioDecodingStage(audio_decoder, vocoder)

Bases: PipelineStage

Decode LTX-2 audio latents into a waveform.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_audio_decoding.py
def __init__(self, audio_decoder, vocoder) -> None:
    super().__init__()
    self.audio_decoder = audio_decoder
    self.vocoder = vocoder
fastvideo.pipelines.basic.ltx2.stages.LTX2DenoisingStage
LTX2DenoisingStage(transformer)

Bases: PipelineStage

Run the LTX-2 denoising loop over the sigma schedule.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_denoising.py
def __init__(self, transformer) -> None:
    super().__init__()
    self.transformer = transformer
fastvideo.pipelines.basic.ltx2.stages.LTX2LatentPreparationStage
LTX2LatentPreparationStage(transformer)

Bases: PipelineStage

Prepare initial LTX-2 latents without relying on a diffusers scheduler.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_latent_preparation.py
def __init__(self, transformer) -> None:
    super().__init__()
    self.transformer = transformer
fastvideo.pipelines.basic.ltx2.stages.LTX2TextEncodingStage
LTX2TextEncodingStage(text_encoders, tokenizers)

Bases: TextEncodingStage

LTX2 text encoding stage with sequence parallelism support.

When SP is enabled (sp_world_size > 1), only rank 0 runs the text encoder and broadcasts embeddings to other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously, which can cause text encoding to take 100+ seconds instead of ~5 seconds.

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary text encoder.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
    self._last_audio_embeds: list[torch.Tensor] | None = None
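The rank-0-encodes-then-broadcasts pattern described above, sketched generically with torch.distributed rather than the stage's actual code; it assumes an initialized process group and that every rank knows the embedding shape:

import torch
import torch.distributed as dist

def encode_with_broadcast(prompt: str, encoder, embed_shape, device) -> torch.Tensor:
    if dist.get_rank() == 0:
        embeds = encoder(prompt).to(device)   # only rank 0 touches the text encoder
    else:
        embeds = torch.empty(embed_shape, device=device)  # receive buffer
    dist.broadcast(embeds, src=0)             # every rank ends with the same embeds
    return embeds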
Modules
fastvideo.pipelines.basic.ltx2.stages.ltx2_audio_decoding

Audio decoding stage for LTX-2 pipelines.

Classes
fastvideo.pipelines.basic.ltx2.stages.ltx2_audio_decoding.LTX2AudioDecodingStage
LTX2AudioDecodingStage(audio_decoder, vocoder)

Bases: PipelineStage

Decode LTX-2 audio latents into a waveform.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_audio_decoding.py
def __init__(self, audio_decoder, vocoder) -> None:
    super().__init__()
    self.audio_decoder = audio_decoder
    self.vocoder = vocoder
Functions
fastvideo.pipelines.basic.ltx2.stages.ltx2_denoising

LTX-2 denoising stage using the native sigma schedule.

Classes
fastvideo.pipelines.basic.ltx2.stages.ltx2_denoising.LTX2DenoisingStage
LTX2DenoisingStage(transformer)

Bases: PipelineStage

Run the LTX-2 denoising loop over the sigma schedule.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_denoising.py
def __init__(self, transformer) -> None:
    super().__init__()
    self.transformer = transformer
Functions
fastvideo.pipelines.basic.ltx2.stages.ltx2_latent_preparation

Latent preparation stage for LTX-2 pipelines.

Classes
fastvideo.pipelines.basic.ltx2.stages.ltx2_latent_preparation.LTX2LatentPreparationStage
LTX2LatentPreparationStage(transformer)

Bases: PipelineStage

Prepare initial LTX-2 latents without relying on a diffusers scheduler.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_latent_preparation.py
def __init__(self, transformer) -> None:
    super().__init__()
    self.transformer = transformer
Functions
fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding

LTX2-specific text encoding stage with sequence parallelism broadcast support.

When running with sequence parallelism (SP), the Gemma text encoder is only executed on rank 0, and the embeddings are broadcast to all other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously.

Classes
fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding.LTX2TextEncodingStage
LTX2TextEncodingStage(text_encoders, tokenizers)

Bases: TextEncodingStage

LTX2 text encoding stage with sequence parallelism support.

When SP is enabled (sp_world_size > 1), only rank 0 runs the text encoder and broadcasts embeddings to other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously, which can cause text encoding to take 100+ seconds instead of ~5 seconds.

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary text encoder.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
    self._last_audio_embeds: list[torch.Tensor] | None = None
Functions

fastvideo.pipelines.basic.matrixgame

Modules

fastvideo.pipelines.basic.matrixgame.matrixgame_causal_dmd_pipeline

Matrix-Game causal DMD pipeline implementation.

Classes
Functions
fastvideo.pipelines.basic.matrixgame.matrixgame_i2v_pipeline

Matrix-Game I2V pipeline implementation.

Classes
Functions
fastvideo.pipelines.basic.matrixgame.presets

MatrixGame model family pipeline presets.

Classes

fastvideo.pipelines.basic.sd35

Modules

fastvideo.pipelines.basic.sd35.presets

Stable Diffusion 3.5 model family pipeline presets.

Classes
fastvideo.pipelines.basic.sd35.sd35_pipeline
Classes
fastvideo.pipelines.basic.sd35.sd35_pipeline.SD35Pipeline
SD35Pipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Minimal SD3.5 Medium text-to-image pipeline (treated as num_frames=1).

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
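As the constructor above shows, profiling is controlled entirely by an environment variable, so it can be enabled without touching pipeline code. The directory below is illustrative:

import os

# Must be set before the pipeline is constructed, since the profiler is
# created during __init__; traces are written under this directory.
os.environ["FASTVIDEO_TORCH_PROFILER_DIR"] = "/tmp/fastvideo_traces"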
fastvideo.pipelines.basic.sd35.sd35_pipeline.StableDiffusion3Pipeline
StableDiffusion3Pipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: SD35Pipeline

Alias whose name matches the _class_name field in SD3.5 diffusers model_index.json.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
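The alias exists because diffusers-style checkpoints record their pipeline class in model_index.json, and SD3.5 checkpoints carry "StableDiffusion3Pipeline" there. A sketch of the lookup this supports (the file contents are an assumption about a typical checkpoint, not a specific one):

import json

# A diffusers model_index.json names the pipeline class under "_class_name";
# registering StableDiffusion3Pipeline as an alias lets that name resolve
# straight to the SD35Pipeline implementation.
with open("model_index.json") as f:
    class_name = json.load(f)["_class_name"]  # e.g. "StableDiffusion3Pipeline"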
Functions

fastvideo.pipelines.basic.turbodiffusion

Classes

fastvideo.pipelines.basic.turbodiffusion.TurboDiffusionI2VPipeline
TurboDiffusionI2VPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

TurboDiffusion I2V pipeline for 1-4 step image-to-video generation.

Uses RCM scheduler, SLA attention, and dual model switching for high-quality I2V generation.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
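Two defaults in the code above are worth spelling out: lora_alpha falls back to lora_rank, which pins the effective LoRA scaling alpha/rank at 1.0, and the default target-module list covers the common attention projections. A minimal sketch of the scaling rule:

def lora_scaling(lora_rank: int, lora_alpha: int | None) -> float:
    # Mirrors the fallback above: with alpha unset, alpha == rank and the
    # LoRA update is applied at scale 1.0; setting alpha explicitly lets you
    # strengthen or weaken an adapter without retraining it.
    if lora_alpha is None:
        lora_alpha = lora_rank
    return lora_alpha / lora_rank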
Functions
fastvideo.pipelines.basic.turbodiffusion.TurboDiffusionI2VPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/turbodiffusion/turbodiffusion_i2v_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None)))

    # I2V: Encode initial image to latent space
    self.add_stage(stage_name="image_latent_preparation_stage",
                   stage=ImageVAEEncodingStage(vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        transformer_2=self.get_module("transformer_2", None),
                                        scheduler=self.get_module("scheduler"),
                                        vae=self.get_module("vae"),
                                        pipeline=self))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
fastvideo.pipelines.basic.turbodiffusion.TurboDiffusionPipeline
TurboDiffusionPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

TurboDiffusion video pipeline for 1-4 step generation.

Uses RCM scheduler and SLA attention for fast, high-quality video generation.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.turbodiffusion.TurboDiffusionPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/turbodiffusion/turbodiffusion_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None)))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        transformer_2=self.get_module("transformer_2", None),
                                        scheduler=self.get_module("scheduler"),
                                        vae=self.get_module("vae"),
                                        pipeline=self))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))

Modules

fastvideo.pipelines.basic.turbodiffusion.presets

TurboDiffusion model family pipeline presets.

Classes
fastvideo.pipelines.basic.turbodiffusion.turbodiffusion_i2v_pipeline

TurboDiffusion I2V (Image-to-Video) Pipeline Implementation.

This module contains an implementation of the TurboDiffusion I2V pipeline for 1-4 step image-to-video generation using rCM (recurrent Consistency Model) sampling with SLA (Sparse-Linear Attention).

Key differences from T2V:

- Uses dual models (high/low noise) with boundary switching (see the sketch below)
- sigma_max=200 (vs 80 for T2V)
- Mask conditioning with encoded first frame

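The dual-model switch is a per-step decision on the current noise level. A minimal sketch, where boundary is a hypothetical threshold standing in for whatever switch point the pipeline config supplies:

import torch

def pick_transformer(sigma: float, boundary: float,
                     high_noise_model: torch.nn.Module,
                     low_noise_model: torch.nn.Module) -> torch.nn.Module:
    # The high-noise expert handles the early, large-sigma steps (I2V starts
    # from sigma_max=200); the low-noise expert finishes the trajectory.
    return high_noise_model if sigma >= boundary else low_noise_model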
Classes
fastvideo.pipelines.basic.turbodiffusion.turbodiffusion_i2v_pipeline.TurboDiffusionI2VPipeline

Re-exported as fastvideo.pipelines.basic.turbodiffusion.TurboDiffusionI2VPipeline; see the class documentation above.
Functions
fastvideo.pipelines.basic.turbodiffusion.turbodiffusion_pipeline

TurboDiffusion Video Pipeline Implementation.

This module contains an implementation of the TurboDiffusion video diffusion pipeline for 1-4 step video generation using rCM (recurrent Consistency Model) sampling with SLA (Sparse-Linear Attention).

Classes
fastvideo.pipelines.basic.turbodiffusion.turbodiffusion_pipeline.TurboDiffusionPipeline

Re-exported as fastvideo.pipelines.basic.turbodiffusion.TurboDiffusionPipeline; see the class documentation above.
Functions

fastvideo.pipelines.basic.wan

Modules

fastvideo.pipelines.basic.wan.presets

Wan model family pipeline presets.

Each preset is a named inference configuration that declares the user-facing stage topology, default sampling values, and which per-stage overrides are allowed. Presets are registered explicitly from :func:`fastvideo.registry._register_presets`.
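For orientation, a preset can be pictured as a small declarative record; the field names below are illustrative assumptions, not the real fastvideo preset schema:

from dataclasses import dataclass

@dataclass(frozen=True)
class PipelinePreset:
    # Hypothetical shape of a preset: a name, default sampling values, and
    # the per-stage settings a user is allowed to override.
    name: str
    num_inference_steps: int = 50
    guidance_scale: float = 5.0
    allowed_overrides: tuple[str, ...] = ("num_inference_steps", "guidance_scale")

wan_t2v_default = PipelinePreset(name="wan-t2v-default")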

Classes
fastvideo.pipelines.basic.wan.wan_causal_dmd_pipeline

Wan causal DMD pipeline implementation.

This module wires the causal DMD denoising stage into the modular pipeline.

Classes
fastvideo.pipelines.basic.wan.wan_causal_dmd_pipeline.WanCausalDMDPipeline
WanCausalDMDPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_causal_dmd_pipeline.WanCausalDMDPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_causal_dmd_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None)))

    self.add_stage(stage_name="denoising_stage",
                   stage=CausalDMDDenosingStage(transformer=self.get_module("transformer"),
                                                transformer_2=self.get_module("transformer_2", None),
                                                scheduler=self.get_module("scheduler"),
                                                vae=self.get_module("vae")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.wan.wan_causal_pipeline

Wan causal pipeline with standard multi-step denoising.

Block-by-block causal inference with KV caching, using the full scheduler timestep schedule (40-50 steps) rather than DMD-style few-step sampling.
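A sketch of the control flow, with denoise_block as a hypothetical callable that runs the full multi-step schedule for one latent block while reading and extending the shared KV cache:

import torch

def generate_blocks(num_blocks: int, denoise_block, kv_cache: dict) -> torch.Tensor:
    # Each block attends to cached keys/values from every earlier block, so
    # generation is causal in time; the 40-50 step schedule runs per block.
    outputs = []
    for block_index in range(num_blocks):
        latent = denoise_block(block_index=block_index, kv_cache=kv_cache)
        outputs.append(latent)
    return torch.cat(outputs, dim=2)  # temporal axis, assuming [B, C, T, H, W]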

Classes
fastvideo.pipelines.basic.wan.wan_causal_pipeline.WanCausalPipeline
WanCausalPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Wan causal pipeline with standard multi-step denoising.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_dmd_pipeline

Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.wan.wan_dmd_pipeline.WanDMDPipeline
WanDMDPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Wan video diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_dmd_pipeline.WanDMDPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None),
                                                use_btchw_layout=True))

    self.add_stage(stage_name="denoising_stage",
                   stage=DmdDenoisingStage(transformer=self.get_module("transformer"),
                                           scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.wan.wan_i2v_dmd_pipeline

Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.wan.wan_i2v_dmd_pipeline.WanImageToVideoDmdPipeline
WanImageToVideoDmdPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_i2v_dmd_pipeline.WanImageToVideoDmdPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="image_encoding_stage",
                   stage=ImageEncodingStage(
                       image_encoder=self.get_module("image_encoder"),
                       image_processor=self.get_module("image_processor"),
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer"),
                                                use_btchw_layout=True))

    self.add_stage(stage_name="image_latent_preparation_stage",
                   stage=ImageVAEEncodingStage(vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=DmdDenoisingStage(transformer=self.get_module("transformer"),
                                           scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.wan.wan_i2v_pipeline

Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.wan.wan_i2v_pipeline.WanImageToVideoPipeline
WanImageToVideoPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_i2v_pipeline.WanImageToVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_i2v_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    if (self.get_module("image_encoder") is not None and self.get_module("image_processor") is not None):
        self.add_stage(stage_name="image_encoding_stage",
                       stage=ImageEncodingStage(
                           image_encoder=self.get_module("image_encoder"),
                           image_processor=self.get_module("image_processor"),
                       ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="image_latent_preparation_stage",
                   stage=ImageVAEEncodingStage(vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        transformer_2=self.get_module("transformer_2"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions
fastvideo.pipelines.basic.wan.wan_pipeline

Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline
WanPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Wan video diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer", None)))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        transformer_2=self.get_module("transformer_2", None),
                                        scheduler=self.get_module("scheduler"),
                                        vae=self.get_module("vae"),
                                        pipeline=self))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
Functions
fastvideo.pipelines.basic.wan.wan_v2v_pipeline

Wan video-to-video diffusion pipeline implementation.

This module contains an implementation of the Wan video-to-video diffusion pipeline using the modular pipeline architecture.

Classes
fastvideo.pipelines.basic.wan.wan_v2v_pipeline.WanVideoToVideoPipeline
WanVideoToVideoPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if (transformer_name in self.modules and self.modules[transformer_name] is not None):
            self.trainable_transformer_modules[transformer_name] = (self.modules[transformer_name])
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if (secondary_transformer_name in self.modules and self.modules[secondary_transformer_name] is not None):
            self.trainable_transformer_modules[secondary_transformer_name] = self.modules[
                secondary_transformer_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for (
            transformer_name,
            transformer_module,
    ) in self.trainable_transformer_modules.items():
        self.exclude_lora_layers[transformer_name] = (transformer_module.config.arch_config.exclude_lora_layers)
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training", False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )

        self.convert_to_lora_layers()
    # Inference: convert layers to LoRA and activate the adapter from lora_path
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,
        )  # type: ignore
Functions
fastvideo.pipelines.basic.wan.wan_v2v_pipeline.WanVideoToVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_v2v_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    if (self.get_module("image_encoder") is not None and self.get_module("image_processor") is not None):
        self.add_stage(stage_name="ref_image_encoding_stage",
                       stage=RefImageEncodingStage(
                           image_encoder=self.get_module("image_encoder"),
                           image_processor=self.get_module("image_processor"),
                       ))

    self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                transformer=self.get_module("transformer")))

    self.add_stage(stage_name="video_latent_preparation_stage",
                   stage=VideoVAEEncodingStage(vae=self.get_module("vae")))

    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(transformer=self.get_module("transformer"),
                                        transformer_2=self.get_module("transformer_2"),
                                        scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")))
Functions