Skip to content

gen3c

GEN3C is a 3D-informed world-consistent video generation model with precise camera control.

Classes

fastvideo.pipelines.basic.gen3c.Cache3DBase

Cache3DBase(input_image: Tensor, input_depth: Tensor, input_w2c: Tensor, input_intrinsics: Tensor, input_mask: Tensor | None = None, input_format: list[str] | None = None, input_points: Tensor | None = None, weight_dtype: dtype = float32, is_depth: bool = True, device: str = 'cuda', filter_points_threshold: float = 1.0)

Base class for 3D cache management.

The cache maintains:

- input_image: RGB images stored in the cache
- input_points: 3D world coordinates for each pixel
- input_mask: Validity mask for each pixel

Initialize the 3D cache.

Parameters:

Name Type Description Default
input_image Tensor

Input image tensor with varying dimensions

required
input_depth Tensor

Depth map tensor

required
input_w2c Tensor

World-to-camera transformation matrix

required
input_intrinsics Tensor

Camera intrinsic matrix

required
input_mask Tensor | None

Optional validity mask

None
input_format list[str] | None

Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])

None
input_points Tensor | None

Pre-computed 3D world points (alternative to depth)

None
weight_dtype dtype

Data type for computations

float32
is_depth bool

If True, input_depth is z-depth; if False, it's distance

True
device str

Computation device

'cuda'
filter_points_threshold float

Threshold for filtering unreliable depth

1.0
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    input_image: torch.Tensor,
    input_depth: torch.Tensor,
    input_w2c: torch.Tensor,
    input_intrinsics: torch.Tensor,
    input_mask: torch.Tensor | None = None,
    input_format: list[str] | None = None,
    input_points: torch.Tensor | None = None,
    weight_dtype: torch.dtype = torch.float32,
    is_depth: bool = True,
    device: str = "cuda",
    filter_points_threshold: float = 1.0,
):
    """
    Initialize the 3D cache.

    Normalizes `input_image` (and the optional mask) to a canonical
    B x F x N x V x C x H x W layout, unprojects depth into per-pixel 3D
    world points (unless `input_points` is supplied), and stores the
    image/points/mask tensors on the CPU.

    Args:
        input_image: Input image tensor with varying dimensions
        input_depth: Depth map tensor
        input_w2c: World-to-camera transformation matrix
        input_intrinsics: Camera intrinsic matrix
        input_mask: Optional validity mask
        input_format: Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])
        input_points: Pre-computed 3D world points (alternative to depth)
        weight_dtype: Data type for computations
        is_depth: If True, input_depth is z-depth; if False, it's distance
        device: Computation device
        filter_points_threshold: Threshold for filtering unreliable depth
    """
    self.weight_dtype = weight_dtype
    self.is_depth = is_depth
    self.device = device
    self.filter_points_threshold = filter_points_threshold

    # Default layout: a plain image batch, B x C x H x W.
    if input_format is None:
        assert input_image.dim() == 4
        input_format = ["B", "C", "H", "W"]

    # Map dimension names to indices
    format_to_indices = {dim: idx for idx, dim in enumerate(input_format)}
    # Captured BEFORE the mask concat below; only the channel axis changes
    # there and C is never read from input_shape.
    input_shape = input_image.shape

    # Ride the mask along the channel axis so it undergoes the exact same
    # permute/unsqueeze reordering as the image; it is split off again below.
    if input_mask is not None:
        input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C"))

    # Extract dimensions
    B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1
    F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1
    N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1
    V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1
    H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None
    W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None

    # Reorder dimensions to B x F x N x V x C x H x W
    desired_dims = ["B", "F", "N", "V", "C", "H", "W"]
    permute_order: list[int | None] = []
    for dim in desired_dims:
        idx = format_to_indices.get(dim)
        permute_order.append(idx)

    # Permute the axes that exist, then unsqueeze size-1 axes for the
    # missing ones so the tensor always ends up 7-D.
    permute_indices = [idx for idx in permute_order if idx is not None]
    input_image = input_image.permute(*permute_indices)

    for i, idx in enumerate(permute_order):
        if idx is None:
            input_image = input_image.unsqueeze(i)

    # Now input_image has shape B x F x N x V x C x H x W
    if input_mask is not None:
        # First 3 channels are RGB; the remainder is the appended mask.
        self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:]
        self.input_mask = self.input_mask.to("cpu")
    else:
        self.input_mask = None
        self.input_image = input_image
    # Cache tensors live on the CPU; render_cache streams chunks to device.
    self.input_image = self.input_image.to(weight_dtype).to("cpu")

    # Compute 3D world points
    if input_points is not None:
        self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu")
        self.input_depth = None
    else:
        # Sanitize depth: NaNs become far-plane values and everything is
        # clamped to [0, 100] (tighter for fp16 to avoid overflow downstream).
        input_depth = torch.nan_to_num(input_depth, nan=100)
        input_depth = torch.clamp(input_depth, min=0, max=100)
        if weight_dtype == torch.float16:
            input_depth = torch.clamp(input_depth, max=70)

        self.input_points = (unproject_points(
            input_depth.reshape(-1, 1, H, W),
            input_w2c.reshape(-1, 4, 4),
            input_intrinsics.reshape(-1, 3, 3),
            is_depth=self.is_depth,
        ).to(weight_dtype).reshape(B, F, N, V, H, W, 3).to("cpu"))
        self.input_depth = input_depth

    # Filter unreliable depth
    # NOTE(review): when input_points is supplied, this still reads the raw
    # input_depth argument, which skipped the nan_to_num/clamp sanitization
    # above — confirm callers pass clean depth in that case.
    if self.filter_points_threshold < 1.0 and input_depth is not None:
        input_depth = input_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(input_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(
                                                         B, F, N, V, 1, H, W)
        if self.input_mask is None:
            self.input_mask = depth_mask.to("cpu")
        else:
            self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device)

Functions

fastvideo.pipelines.basic.gen3c.Cache3DBase.input_frame_count
input_frame_count() -> int

Return the number of frames in the cache.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def input_frame_count(self) -> int:
    """Number of frames currently held by the cache.

    Frames live along dim 1 of the B x F x N x V x C x H x W image tensor.
    """
    return self.input_image.size(1)
fastvideo.pipelines.basic.gen3c.Cache3DBase.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cached 3D points from new camera viewpoints.

Parameters:

Name Type Description Default
target_w2cs Tensor

(b, F_target, 4, 4) target camera transformations

required
target_intrinsics Tensor

(b, F_target, 3, 3) target camera intrinsics

required
render_depth bool

If True, return depth instead of RGB

False
start_frame_idx int

Starting frame index in the cache

0

Returns:

Name Type Description
pixels Tensor

(b, F_target, N, c, h, w) rendered images or depth

masks Tensor

(b, F_target, N, 1, h, w) validity masks

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cached 3D points from new camera viewpoints.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index in the cache

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images, or
            (b, F_target, N, h, w) depth maps (no channel axis) when
            render_depth is True
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    bs, F_target, _, _ = target_w2cs.shape
    B, F, N, V, C, H, W = self.input_image.shape
    assert bs == B

    # Broadcast each target camera across the N cache buffers and flatten
    # to a single (B * F_target * N) batch for the chunked warp below.
    target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4)
    target_intrinsics = target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3,
                                                                               3).reshape(-1, 3, 3)

    # Prepare inputs
    first_images = rearrange(
        self.input_image[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, C, H, W),
        "B F N V C H W -> (B F N) V C H W")
    first_points = rearrange(
        self.input_points[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, H, W, 3),
        "B F N V H W C -> (B F N) V H W C")
    first_masks = rearrange(
        self.input_mask[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, 1, H, W),
        "B F N V C H W -> (B F N) V C H W") if self.input_mask is not None else None

    # Process in chunks for memory efficiency
    # Only the single-view (V == 1) path is implemented.
    if first_images.shape[1] == 1:
        warp_chunk_size = 2
        rendered_warp_images = []
        rendered_warp_masks = []
        rendered_warp_depth = []

        # Drop the singleton view axis before warping.
        first_images = first_images.squeeze(1)
        first_points = first_points.squeeze(1)
        first_masks = first_masks.squeeze(1) if first_masks is not None else None

        for i in range(0, first_images.shape[0], warp_chunk_size):
            with torch.no_grad():
                # Move only the current chunk onto the compute device;
                # the cache itself stays on the CPU.
                imgs_chunk = first_images[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                pts_chunk = first_points[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                masks_chunk = (first_masks[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                               if first_masks is not None else None)

                # depth1/transformation1 are None because pre-computed
                # world points (world_points1) are supplied instead.
                (
                    rendered_warp_images_chunk,
                    rendered_warp_masks_chunk,
                    rendered_warp_depth_chunk,
                    _,
                ) = forward_warp(
                    imgs_chunk,
                    mask1=masks_chunk,
                    depth1=None,
                    transformation1=None,
                    transformation2=target_w2cs[i:i + warp_chunk_size],
                    intrinsic1=target_intrinsics[i:i + warp_chunk_size],
                    intrinsic2=target_intrinsics[i:i + warp_chunk_size],
                    render_depth=render_depth,
                    world_points1=pts_chunk,
                )

                # Results go straight back to the CPU to bound GPU memory.
                rendered_warp_images.append(rendered_warp_images_chunk.to("cpu"))
                rendered_warp_masks.append(rendered_warp_masks_chunk.to("cpu"))
                if render_depth:
                    rendered_warp_depth.append(rendered_warp_depth_chunk.to("cpu"))

                del imgs_chunk, pts_chunk, masks_chunk
                torch.cuda.empty_cache()

        rendered_warp_images = torch.cat(rendered_warp_images, dim=0)
        rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0)
        if render_depth:
            rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0)
    else:
        raise NotImplementedError("Multi-view rendering not yet supported")

    # Restore the (b, f, n, ...) layout from the flattened warp batch.
    pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
    masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)

    if render_depth:
        # Depth replaces RGB in the pixels slot; note it has no channel axis.
        pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N)

    return pixels.to(self.device), masks.to(self.device)
fastvideo.pipelines.basic.gen3c.Cache3DBase.update_cache
update_cache(**kwargs)

Update the cache with new frames. To be implemented by subclasses.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(self, **kwargs):
    """Insert new frames into the cache.

    The base class defines no storage policy, so this is a pure hook:
    subclasses must override it, and calling it directly always fails.
    """
    raise NotImplementedError

fastvideo.pipelines.basic.gen3c.Cache3DBuffer

Cache3DBuffer(frame_buffer_max: int = 2, noise_aug_strength: float = 0.0, generator: Generator | None = None, **kwargs)

Bases: Cache3DBase

3D cache with frame buffer support.

This class manages multiple frame buffers for temporal consistency and supports noise augmentation for training stability.

Initialize the buffered 3D cache.

Parameters:

Name Type Description Default
frame_buffer_max int

Maximum number of frames to buffer

2
noise_aug_strength float

Strength of noise augmentation per buffer

0.0
generator Generator | None

Random generator for reproducibility

None
**kwargs

Arguments passed to Cache3DBase

{}
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    frame_buffer_max: int = 2,
    noise_aug_strength: float = 0.0,
    generator: torch.Generator | None = None,
    **kwargs,
):
    """
    Set up a frame-buffered 3D cache.

    Args:
        frame_buffer_max: Upper bound on the number of buffered frames.
        noise_aug_strength: Per-buffer noise-augmentation strength.
        generator: Optional RNG for reproducible noise.
        **kwargs: Forwarded verbatim to Cache3DBase.
    """
    super().__init__(**kwargs)
    # Buffer bookkeeping used by render_cache / update_cache.
    self.generator = generator
    self.noise_aug_strength = noise_aug_strength
    self.frame_buffer_max = frame_buffer_max

Functions

fastvideo.pipelines.basic.gen3c.Cache3DBuffer.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cache with optional noise augmentation.

Parameters:

Name Type Description Default
target_w2cs Tensor

(b, F_target, 4, 4) target camera transformations

required
target_intrinsics Tensor

(b, F_target, 3, 3) target camera intrinsics

required
render_depth bool

If True, return depth instead of RGB

False
start_frame_idx int

Starting frame index (must be 0 for this class)

0

Returns:

Name Type Description
pixels Tensor

(b, F_target, N, c, h, w) rendered images

masks Tensor

(b, F_target, N, 1, h, w) validity masks

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cache with optional noise augmentation.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index (must be 0 for this class)

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3DBuffer"

    # Warp on self.device, but return results on the caller's device.
    output_device = target_w2cs.device
    target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device)
    target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device)

    pixels, masks = super().render_cache(target_w2cs, target_intrinsics, render_depth)

    pixels = pixels.to(output_device)
    masks = masks.to(output_device)

    # Apply noise augmentation (stronger for older buffers)
    # Depth renders are never noise-augmented.
    if not render_depth and self.noise_aug_strength > 0:
        noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype)
        # Multipliers run from (N - 1) down to 0 along the buffer axis, so
        # buffer index 0 gets the largest noise scale.
        # NOTE(review): update_cache stores the NEWEST frame at index 0,
        # which here receives the strongest noise — that appears to
        # contradict the "older buffers" comment above; confirm intent.
        per_buffer_noise = (torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device) *
                            self.noise_aug_strength)
        pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1)

    return pixels, masks
fastvideo.pipelines.basic.gen3c.Cache3DBuffer.update_cache
update_cache(new_image: Tensor, new_depth: Tensor, new_w2c: Tensor, new_mask: Tensor | None = None, new_intrinsics: Tensor | None = None)

Update the cache with a new frame.

Parameters:

Name Type Description Default
new_image Tensor

(B, C, H, W) new RGB image

required
new_depth Tensor

(B, 1, H, W) new depth map

required
new_w2c Tensor

(B, 4, 4) new world-to-camera transformation

required
new_mask Tensor | None

Optional (B, 1, H, W) validity mask

None
new_intrinsics Tensor | None

(B, 3, 3) camera intrinsics (optional)

None
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(
    self,
    new_image: torch.Tensor,
    new_depth: torch.Tensor,
    new_w2c: torch.Tensor,
    new_mask: torch.Tensor | None = None,
    new_intrinsics: torch.Tensor | None = None,
):
    """
    Update the cache with a new frame.

    The new frame is unprojected to 3D points and inserted at buffer
    index 0 (newest first); when the buffer is full, slot 0 is
    overwritten in place.

    Args:
        new_image: (B, C, H, W) new RGB image
        new_depth: (B, 1, H, W) new depth map
        new_w2c: (B, 4, 4) new world-to-camera transformation
        new_mask: Optional (B, 1, H, W) validity mask
        new_intrinsics: (B, 3, 3) camera intrinsics (optional)
    """
    new_image = new_image.to(self.weight_dtype).to(self.device)
    new_depth = new_depth.to(self.weight_dtype).to(self.device)
    new_w2c = new_w2c.to(self.weight_dtype).to(self.device)
    if new_intrinsics is not None:
        new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device)

    # Sanitize depth: NaNs become far values, range clamped to [0, 1e4].
    new_depth = torch.nan_to_num(new_depth, nan=1e4)
    new_depth = torch.clamp(new_depth, min=0, max=1e4)

    B, F, N, V, C, H, W = self.input_image.shape

    # Compute new 3D points
    new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu()
    new_image = new_image.cpu()

    # Optionally mask out pixels whose depth looks unreliable.
    if self.filter_points_threshold < 1.0:
        new_depth = new_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(new_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W)
        new_mask = depth_mask.to("cpu") if new_mask is None else new_mask * depth_mask.to(new_mask.device)
    if new_mask is not None:
        new_mask = new_mask.cpu()

    # Update buffer (newest frame first)
    if self.frame_buffer_max > 1:
        if self.input_image.shape[2] < self.frame_buffer_max:
            # Buffer not full yet: prepend along the N (buffer) axis.
            self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2)
            self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2)
            if self.input_mask is not None:
                self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2)
        else:
            # Buffer full: overwrite the newest slot in place.
            self.input_image[:, :, 0] = new_image[:, None, None]
            self.input_points[:, :, 0] = new_points[:, None, None]
            if self.input_mask is not None:
                self.input_mask[:, :, 0] = new_mask[:, None, None]
    else:
        # Single-slot buffer: replace image and points wholesale.
        # NOTE(review): self.input_mask is left untouched on this path even
        # though new_mask may have been computed above — confirm intended.
        self.input_image = new_image[:, None, None, None]
        self.input_points = new_points[:, None, None, None]

fastvideo.pipelines.basic.gen3c.Gen3CConditioningStage

Gen3CConditioningStage(vae=None)

Bases: PipelineStage

3D cache conditioning stage for GEN3C.

This stage performs the core GEN3C innovation:

1. Loads the input image
2. Predicts depth via MoGe
3. Initializes a 3D point cloud cache
4. Generates a camera trajectory
5. Renders warped frames from the cache at each target camera pose
6. Stores rendered warps on the batch for VAE encoding in the latent prep stage

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, vae=None) -> None:
    """Create the conditioning stage; the MoGe depth model is loaded lazily."""
    super().__init__()
    self._vae = vae
    # Populated on demand by _get_moge_model() during forward().
    self._moge_model: Any | None = None

Functions

fastvideo.pipelines.basic.gen3c.Gen3CConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run 3D cache conditioning pipeline.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Run 3D cache conditioning pipeline.

    Predicts depth for the input image with MoGe, initializes a
    Cache3DBuffer point-cloud cache, generates a camera trajectory,
    renders warped frames from the cache at each target pose, and
    attaches the results to the batch for VAE encoding in the latent
    preparation stage. Returns the batch unchanged when no image_path
    is provided.
    """
    pipeline_config = fastvideo_args.pipeline_config
    device = get_local_torch_device()
    batch_extra = getattr(batch, "extra", {}) or {}

    image_path = getattr(batch, 'image_path', None) or batch_extra.get("image_path")
    if image_path is None:
        logger.info("No image_path provided - skipping 3D cache conditioning "
                    "(will use zero conditioning)")
        return batch

    logger.info("Running 3D cache conditioning with image: %s", image_path)

    # Resolve generation settings, preferring batch values over
    # pipeline-config defaults.
    height = getattr(batch, 'height', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[0]
    width = getattr(batch, 'width', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[1]
    num_frames = getattr(batch, 'num_frames', None) or getattr(pipeline_config, 'num_frames', 121)

    # Trajectory options resolve batch attr -> batch.extra -> config default.
    trajectory_type = (getattr(batch, 'trajectory_type', None) or batch_extra.get("trajectory_type")
                       or getattr(pipeline_config, 'default_trajectory_type', 'left'))
    movement_distance = (getattr(batch, 'movement_distance', None) or batch_extra.get("movement_distance")
                         or getattr(pipeline_config, 'default_movement_distance', 0.3))
    camera_rotation = (getattr(batch, 'camera_rotation', None) or batch_extra.get("camera_rotation")
                       or getattr(pipeline_config, 'default_camera_rotation', 'center_facing'))

    frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2)
    noise_aug_strength = getattr(pipeline_config, 'noise_aug_strength', 0.0)
    filter_points_threshold = getattr(pipeline_config, 'filter_points_threshold', 0.05)

    moge_model_name = getattr(pipeline_config, 'moge_model_name', 'Ruicheng/moge-vitl')

    from fastvideo.pipelines.basic.gen3c.depth_estimation import (predict_depth_from_path)

    # Steps 1-2: load the image and predict depth plus camera parameters.
    moge_model = self._get_moge_model(device, moge_model_name)

    (
        image_b1chw,
        depth_b11hw,
        mask_b11hw,
        w2c_b144,
        intrinsics_b133,
    ) = predict_depth_from_path(image_path, height, width, device, moge_model)

    logger.info(
        "Depth prediction complete. Depth range: [%.3f, %.3f]",
        depth_b11hw.min().item(),
        depth_b11hw.max().item(),
    )

    from fastvideo.pipelines.basic.gen3c.cache_3d import Cache3DBuffer

    seed = getattr(batch, 'seed', None)
    if seed is None:
        seed = 42
    generator = torch.Generator(device=device).manual_seed(seed)

    # Step 3: initialize the 3D point-cloud cache from the first frame.
    cache = Cache3DBuffer(
        frame_buffer_max=frame_buffer_max,
        generator=generator,
        noise_aug_strength=noise_aug_strength,
        input_image=image_b1chw[:, 0].clone(),
        input_depth=depth_b11hw[:, 0],
        input_w2c=w2c_b144[:, 0],
        input_intrinsics=intrinsics_b133[:, 0],
        filter_points_threshold=filter_points_threshold,
    )

    logger.info("3D cache initialized with %d frame buffer(s)", frame_buffer_max)

    from fastvideo.pipelines.basic.gen3c.camera_utils import (generate_camera_trajectory)

    # Step 4: generate the target camera trajectory from the initial pose.
    initial_w2c = w2c_b144[0, 0]
    initial_intrinsics = intrinsics_b133[0, 0]

    generated_w2cs, generated_intrinsics = generate_camera_trajectory(
        trajectory_type=trajectory_type,
        initial_w2c=initial_w2c,
        initial_intrinsics=initial_intrinsics,
        num_frames=num_frames,
        movement_distance=movement_distance,
        camera_rotation=camera_rotation,
        center_depth=1.0,
        device=device.type if isinstance(device, torch.device) else device,
    )

    logger.info(
        "Camera trajectory generated: type=%s, frames=%d, distance=%.3f",
        trajectory_type,
        num_frames,
        movement_distance,
    )

    # Step 5: forward-warp the cache to every target camera pose.
    rendered_warp_images, rendered_warp_masks = cache.render_cache(
        generated_w2cs[:, :num_frames],
        generated_intrinsics[:, :num_frames],
    )

    logger.info(
        "Cache rendered. Warped images shape: %s, non-zero mask ratio: %.3f",
        list(rendered_warp_images.shape),
        (rendered_warp_masks > 0).float().mean().item(),
    )

    # Step 6: stash results on the batch for the latent preparation stage.
    batch.rendered_warp_images = rendered_warp_images.to(device)
    batch.rendered_warp_masks = rendered_warp_masks.to(device)
    batch.input_image_conditioning = image_b1chw[:, 0].unsqueeze(2).contiguous().to(device)
    batch.cache_3d = cache

    # Optionally release the depth model to reclaim memory.
    if getattr(pipeline_config, "offload_moge_after_depth", True):
        self._offload_moge()

    return batch

fastvideo.pipelines.basic.gen3c.Gen3CDenoisingStage

Gen3CDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for GEN3C models.

This stage extends the base denoising stage with support for:

- condition_video_input_mask: Binary mask indicating conditioning frames
- condition_video_pose: VAE-encoded 3D cache buffers
- condition_video_augment_sigma: Noise augmentation sigma

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    # Thin wrapper: all initialization is handled by the base DenoisingStage.
    super().__init__(transformer, scheduler, pipeline)

fastvideo.pipelines.basic.gen3c.Gen3CLatentPreparationStage

Gen3CLatentPreparationStage(scheduler, transformer, vae)

Bases: LatentPreparationStage

Latent preparation stage for GEN3C.

This stage prepares latents and encodes 3D cache buffers through the VAE. If rendered warped frames are available on the batch (from Gen3CConditioningStage), they are VAE-encoded to produce real conditioning. Otherwise falls back to zeros.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, scheduler, transformer, vae) -> None:
    # The base stage keeps scheduler/transformer; the VAE is retained here
    # to encode rendered warp buffers in forward().
    super().__init__(scheduler, transformer)
    self.vae = vae

Functions

fastvideo.pipelines.basic.gen3c.Gen3CLatentPreparationStage.encode_warped_frames
encode_warped_frames(condition_state: Tensor, condition_state_mask: Tensor, vae: Any, frame_buffer_max: int, dtype: dtype) -> Tensor

Encode rendered 3D cache buffers through VAE.

Parameters:

Name Type Description Default
condition_state Tensor

(B, T, N, 3, H, W) rendered RGB images in [-1, 1].

required
condition_state_mask Tensor

(B, T, N, 1, H, W) rendered masks in [0, 1].

required
vae Any

VAE encoder.

required
frame_buffer_max int

Maximum number of buffers.

required
dtype dtype

Target dtype.

required

Returns:

Name Type Description
latent_condition Tensor

(B, buffer_channels, T_latent, H_latent, W_latent)

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def encode_warped_frames(
    self,
    condition_state: torch.Tensor,
    condition_state_mask: torch.Tensor,
    vae: Any,
    frame_buffer_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Run each rendered 3D-cache buffer (RGB + mask) through the VAE and
    stack the resulting latents along the channel axis.

    Args:
        condition_state: (B, T, N, 3, H, W) rendered RGB images in [-1, 1].
        condition_state_mask: (B, T, N, 1, H, W) rendered masks in [0, 1].
        vae: VAE encoder.
        frame_buffer_max: Maximum number of buffers; unused slots are
            zero-padded so the output channel count stays fixed.
        dtype: Target dtype fed to the VAE.

    Returns:
        latent_condition: (B, buffer_channels, T_latent, H_latent, W_latent)
    """
    assert condition_state.dim() == 6

    # Remap the mask into [-1, 1] and broadcast to three channels so it can
    # share a single VAE forward pass with the RGB frames.
    three_channel_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)

    num_buffers = condition_state.shape[2]
    encoded: list[torch.Tensor] = []
    for buffer_idx in range(num_buffers):
        # (B, T, C, H, W) -> (B, C, T, H, W) as the VAE expects.
        rgb = condition_state[:, :, buffer_idx].permute(0, 2, 1, 3, 4).to(dtype)
        mask = three_channel_mask[:, :, buffer_idx].permute(0, 2, 1, 3, 4).to(dtype)
        # Encode image and mask in one batched call, then split the halves.
        joint = self._retrieve_latents(vae.encode(torch.cat([rgb, mask], dim=0))).contiguous()
        video_latent, mask_latent = joint.chunk(2, dim=0)

        encoded.append(video_latent)
        encoded.append(mask_latent)

    # Zero-pad unused buffer slots to keep the channel layout constant.
    for _ in range(frame_buffer_max - num_buffers):
        encoded.append(torch.zeros_like(video_latent))
        encoded.append(torch.zeros_like(mask_latent))

    return torch.cat(encoded, dim=1)
fastvideo.pipelines.basic.gen3c.Gen3CLatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare latents and encode 3D cache buffers.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Prepare latents and encode 3D cache buffers.

    Samples the initial noise latents, then either VAE-encodes the
    rendered warp buffers produced by the conditioning stage into
    condition_video_pose (plus first-frame conditioning latents), or
    falls back to all-zero conditioning tensors.
    """
    pipeline_config = fastvideo_args.pipeline_config
    device = get_local_torch_device()

    # Derive the effective batch size from the prompt(s).
    if isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]
    batch_size *= batch.num_videos_per_prompt

    num_channels_latents = getattr(self.transformer, 'num_channels_latents', 16)

    # Resolve the number of video frames, tolerating int/list/None inputs.
    fallback_num_frames = getattr(pipeline_config, 'num_frames', 121)
    if not isinstance(fallback_num_frames, int):
        fallback_num_frames = 121

    num_frames_raw = getattr(batch, 'num_frames', None)
    if num_frames_raw is None:
        num_frames_raw = fallback_num_frames
    if isinstance(num_frames_raw, list):
        num_frames = int(num_frames_raw[0]) if len(num_frames_raw) > 0 else fallback_num_frames
    elif isinstance(num_frames_raw, int):
        num_frames = num_frames_raw
    else:
        num_frames = fallback_num_frames
    # Convert pixel-space frame count / size to latent-space dimensions.
    if hasattr(self.vae, "get_latent_num_frames"):
        latent_frames = int(self.vae.get_latent_num_frames(num_frames))
    else:
        temporal_ratio = getattr(
            pipeline_config.vae_config.arch_config,
            "temporal_compression_ratio",
            4,
        )
        latent_frames = int((num_frames - 1) // temporal_ratio + 1)
    height = getattr(batch, 'height', 720)
    width = getattr(batch, 'width', 1280)

    spatial_ratio = getattr(
        pipeline_config.vae_config.arch_config,
        "spatial_compression_ratio",
        8,
    )
    latent_height = height // spatial_ratio
    latent_width = width // spatial_ratio

    generator = getattr(batch, "generator", None)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(f"Expected {batch_size} generators, got {len(generator)}.")

    # Sample the initial noise latents.
    latents = randn_tensor(
        (
            batch_size,
            num_channels_latents,
            latent_frames,
            latent_height,
            latent_width,
        ),
        generator=generator,
        device=device,
        dtype=torch.float32,
    )

    if hasattr(self.scheduler, 'init_noise_sigma'):
        latents = latents * self.scheduler.init_noise_sigma

    batch.latents = latents
    batch.batch_size = batch_size
    batch.height = height
    batch.width = width
    batch.latent_height = latent_height
    batch.latent_width = latent_width
    batch.latent_frames = latent_frames
    batch.raw_latent_shape = latents.shape

    # Each cache buffer contributes a fixed slice of latent channels
    # (video latent + mask latent).
    frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2)
    channels_per_buffer = 32
    buffer_channels = frame_buffer_max * channels_per_buffer

    rendered_warp_images = getattr(batch, 'rendered_warp_images', None)
    rendered_warp_masks = getattr(batch, 'rendered_warp_masks', None)

    if rendered_warp_images is not None and rendered_warp_masks is not None:
        logger.info(
            "Encoding rendered warped frames through VAE (%d buffers)...",
            rendered_warp_images.shape[2],
        )

        self.vae = self.vae.to(device)

        # Unwrap a possibly-wrapped VAE to read its parameter dtype.
        if hasattr(self.vae, 'module'):
            vae_dtype = next(self.vae.module.parameters()).dtype
        else:
            vae_dtype = next(self.vae.parameters()).dtype

        condition_video_pose = self.encode_warped_frames(
            rendered_warp_images,
            rendered_warp_masks,
            self.vae,
            frame_buffer_max,
            vae_dtype,
        )
        batch.condition_video_pose = condition_video_pose.to(device)

        logger.info(
            "condition_video_pose encoded. Shape: %s, non-zero: %.4f",
            list(batch.condition_video_pose.shape),
            (batch.condition_video_pose != 0).float().mean().item(),
        )

        # Encode the first frame as image conditioning; fall back to the
        # first rendered warp when no explicit conditioning image is set.
        source_image = getattr(batch, "input_image_conditioning", None)
        if source_image is None:
            source_image = rendered_warp_images[:, 0, 0].unsqueeze(2)
        first_frame = source_image.to(device=device, dtype=vae_dtype)
        first_latent = self._retrieve_latents(self.vae.encode(first_frame))
        # Zero-pad the first-frame latent along time to the full clip length.
        conditioning_latents = torch.zeros(
            batch_size,
            num_channels_latents,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=first_latent.dtype,
        )
        conditioning_latents[:, :, :first_latent.shape[2], :, :] = first_latent
        batch.conditioning_latents = conditioning_latents

        if fastvideo_args.vae_cpu_offload:
            self.vae.to("cpu")

        # Only the first latent frame is a conditioning frame.
        batch.condition_video_input_mask = torch.zeros(
            batch_size,
            1,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.condition_video_input_mask[:, :, 0, :, :] = 1.0
    else:
        logger.info("No rendered warps available - using zero conditioning")
        batch.condition_video_pose = torch.zeros(
            batch_size,
            buffer_channels,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.condition_video_input_mask = torch.zeros(
            batch_size,
            1,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.conditioning_latents = None

    # No extra noise augmentation on conditioning frames by default.
    batch.condition_video_augment_sigma = torch.zeros(batch_size, device=device, dtype=torch.float32)
    # Indicator marks the first latent frame as the conditioned frame.
    batch.cond_indicator = torch.zeros(
        batch_size,
        1,
        latent_frames,
        latent_height,
        latent_width,
        device=device,
        dtype=torch.float32,
    )
    batch.cond_indicator[:, :, 0, :, :] = 1.0

    # cond_mask reduces to cond_indicator itself (ones where indicated,
    # zeros elsewhere).
    ones_padding = torch.ones_like(batch.cond_indicator)
    zeros_padding = torch.zeros_like(batch.cond_indicator)
    batch.cond_mask = batch.cond_indicator * ones_padding + (1 - batch.cond_indicator) * zeros_padding

    if batch.do_classifier_free_guidance:
        batch.uncond_indicator = batch.cond_indicator.clone()
        batch.uncond_mask = batch.cond_mask.clone()
    else:
        batch.uncond_indicator = None
        batch.uncond_mask = None

    return batch

fastvideo.pipelines.basic.gen3c.Gen3CPipeline

Gen3CPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

GEN3C Video Generation Pipeline.

This pipeline extends Cosmos with 3D cache support for camera-controlled video generation. When an input image is provided, it runs the full 3D cache conditioning pipeline (depth estimation -> point cloud -> camera trajectory -> forward warping -> VAE encoding).

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Build a ready-to-use, batch-stateless pipeline instance.

    Wires up stage bookkeeping, validates the required module list,
    initializes the distributed environment, configures the optional
    torch profiler, and eagerly loads all pipeline modules.
    """
    self.model_path: str = model_path
    self.fastvideo_args = fastvideo_args

    # Stage registry: ordered execution list plus name -> stage lookup.
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules
    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Profiling is opt-in, configured through the
    # FASTVIDEO_TORCH_PROFILER_DIR environment variable.
    self.profiler_controller = get_or_create_profiler(envs.FASTVIDEO_TORCH_PROFILER_DIR)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load every module up front so the pipeline is usable after __init__.
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)

Functions

fastvideo.pipelines.basic.gen3c.Gen3CPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/gen3c/gen3c_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Register all GEN3C stages, injecting shared modules where needed."""
    # Declare the stages in execution order, then register them in one pass.
    planned_stages = [
        ("cfg_policy_stage", Gen3CCFGPolicyStage()),
        ("input_validation_stage", InputValidationStage()),
        ("prompt_encoding_stage",
         TextEncodingStage(
             text_encoders=[self.get_module("text_encoder")],
             tokenizers=[self.get_module("tokenizer")],
         )),
        ("conditioning_stage", Gen3CConditioningStage(vae=self.get_module("vae"))),
        ("timestep_preparation_stage",
         TimestepPreparationStage(scheduler=self.get_module("scheduler"))),
        ("latent_preparation_stage",
         Gen3CLatentPreparationStage(scheduler=self.get_module("scheduler"),
                                     transformer=self.get_module("transformer"),
                                     vae=self.get_module("vae"))),
        ("denoising_stage",
         Gen3CDenoisingStage(transformer=self.get_module("transformer"),
                             scheduler=self.get_module("scheduler"))),
        ("decoding_stage", DecodingStage(vae=self.get_module("vae"))),
    ]
    for name, stage in planned_stages:
        self.add_stage(stage_name=name, stage=stage)

Functions

fastvideo.pipelines.basic.gen3c.forward_warp

forward_warp(frame1: Tensor, mask1: Tensor | None, depth1: Tensor | None, transformation1: Tensor | None, transformation2: Tensor, intrinsic1: Tensor | None, intrinsic2: Tensor | None, is_image: bool = True, is_depth: bool = True, render_depth: bool = False, world_points1: Tensor | None = None) -> tuple[Tensor, Tensor, Tensor | None, Tensor]

Forward warp frame1 to a new view defined by transformation2.

Parameters:

Name Type Description Default
frame1 Tensor

(b, c, h, w) source frame in range [-1, 1] for images

required
mask1 Tensor | None

(b, 1, h, w) valid pixel mask

required
depth1 Tensor | None

(b, 1, h, w) depth map (required if world_points1 is None)

required
transformation1 Tensor | None

(b, 4, 4) source camera w2c (required if depth1 is provided)

required
transformation2 Tensor

(b, 4, 4) target camera w2c

required
intrinsic1 Tensor | None

(b, 3, 3) source camera intrinsics

required
intrinsic2 Tensor | None

(b, 3, 3) target camera intrinsics

required
is_image bool

If True, output will be clipped to (-1, 1)

True
is_depth bool

If True, depth1 is z-depth; if False, it's distance

True
render_depth bool

If True, also return the warped depth map

False
world_points1 Tensor | None

(b, h, w, 3) pre-computed world points (alternative to depth1)

None

Returns:

Name Type Description
warped_frame2 Tensor

(b, c, h, w) warped frame

mask2 Tensor

(b, 1, h, w) validity mask

warped_depth2 Tensor | None

(b, h, w) warped depth (if render_depth=True)

flow12 Tensor

(b, 2, h, w) optical flow

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def forward_warp(
    frame1: torch.Tensor,
    mask1: torch.Tensor | None,
    depth1: torch.Tensor | None,
    transformation1: torch.Tensor | None,
    transformation2: torch.Tensor,
    intrinsic1: torch.Tensor | None,
    intrinsic2: torch.Tensor | None,
    is_image: bool = True,
    is_depth: bool = True,
    render_depth: bool = False,
    world_points1: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """
    Splat frame1 into the view described by transformation2.

    Args:
        frame1: (b, c, h, w) source frame, in [-1, 1] when it is an image.
        mask1: (b, 1, h, w) valid-pixel mask; all-ones when None.
        depth1: (b, 1, h, w) depth map (required when world_points1 is None).
        transformation1: (b, 4, 4) source w2c (required with depth1).
        transformation2: (b, 4, 4) target w2c.
        intrinsic1: (b, 3, 3) source intrinsics.
        intrinsic2: (b, 3, 3) target intrinsics; defaults to intrinsic1.
        is_image: clip the warped output to (-1, 1).
        is_depth: depth1 is z-depth when True, ray distance when False.
        render_depth: also splat the depth channel when True.
        world_points1: (b, h, w, 3) pre-computed world points; skips the
            unprojection step when given.

    Returns:
        warped_frame2: (b, c, h, w) warped frame.
        mask2: (b, 1, h, w) validity mask.
        warped_depth2: (b, h, w) warped depth, or None when render_depth=False.
        flow12: (b, 2, h, w) optical flow from source to target pixels.
    """
    dev = frame1.device
    dt = frame1.dtype
    b, c, h, w = frame1.shape

    if mask1 is None:
        mask1 = torch.ones(size=(b, 1, h, w), device=dev, dtype=dt)
    if intrinsic2 is None:
        assert intrinsic1 is not None
        intrinsic2 = intrinsic1.clone()

    if world_points1 is None:
        # Lift the depth map to world space first.
        assert depth1 is not None and transformation1 is not None
        assert depth1.shape == (b, 1, h, w)
        depth1 = torch.clamp(torch.nan_to_num(depth1, nan=1e4), min=0, max=1e4)
        world_points1 = unproject_points(depth1, transformation1, intrinsic1, is_depth=is_depth)
    else:
        assert world_points1.shape == (b, h, w, 3)

    # Re-project the world points into the target camera.
    target_pts = project_points(world_points1, transformation2, intrinsic2)

    # Discard points that land behind the target camera.
    target_z = target_pts[:, :, :, 2, 0].unsqueeze(1)  # (b, 1, h, w)
    mask1 = mask1 * (target_z > 0)

    # Perspective divide to get pixel coordinates, then a flow field.
    pix_xy = (target_pts[:, :, :, :2, 0] / (target_pts[:, :, :, 2:3, 0] + 1e-7)).permute(0, 3, 1, 2)
    flow12 = pix_xy - create_grid(b, h, w, device=dev, dtype=dt)

    warped_frame2, mask2 = bilinear_splatting(frame1, mask1, target_z, flow12, None, is_image=is_image)

    warped_depth2 = None
    if render_depth:
        warped_depth2 = bilinear_splatting(target_z, mask1, target_z, flow12, None, is_image=False)[0][:, 0]

    return warped_frame2, mask2, warped_depth2, flow12

fastvideo.pipelines.basic.gen3c.generate_camera_trajectory

generate_camera_trajectory(trajectory_type: str, initial_w2c: Tensor, initial_intrinsics: Tensor, num_frames: int, movement_distance: float, camera_rotation: str = 'center_facing', center_depth: float = 1.0, device: str = 'cuda') -> tuple[Tensor, Tensor]

Generate camera trajectory for GEN3C video generation.

Parameters:

Name Type Description Default
trajectory_type str

One of "left", "right", "up", "down", "zoom_in", "zoom_out", "clockwise", "counterclockwise".

required
initial_w2c Tensor

Initial world-to-camera matrix (4, 4).

required
initial_intrinsics Tensor

Camera intrinsics matrix (3, 3).

required
num_frames int

Number of frames in the trajectory.

required
movement_distance float

Distance factor for camera movement.

required
camera_rotation str

"center_facing", "no_rotation", or "trajectory_aligned".

'center_facing'
center_depth float

Depth of the scene center point.

1.0
device str

Computation device.

'cuda'

Returns:

Name Type Description
generated_w2cs Tensor

(1, num_frames, 4, 4) world-to-camera matrices.

generated_intrinsics Tensor

(1, num_frames, 3, 3) camera intrinsics.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def generate_camera_trajectory(
    trajectory_type: str,
    initial_w2c: torch.Tensor,
    initial_intrinsics: torch.Tensor,
    num_frames: int,
    movement_distance: float,
    camera_rotation: str = "center_facing",
    center_depth: float = 1.0,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build a per-frame camera path for GEN3C video generation.

    Args:
        trajectory_type: "left", "right", "up", "down", "zoom_in",
            "zoom_out", "clockwise", "counterclockwise", or "none"
            (static camera).
        initial_w2c: Initial world-to-camera matrix (4, 4).
        initial_intrinsics: Camera intrinsics matrix (3, 3).
        num_frames: Number of frames in the trajectory.
        movement_distance: Distance factor for camera movement.
        camera_rotation: "center_facing", "no_rotation", or "trajectory_aligned".
        center_depth: Depth of the scene center point.
        device: Computation device.

    Returns:
        generated_w2cs: (1, num_frames, 4, 4) world-to-camera matrices.
        generated_intrinsics: (1, num_frames, 3, 3) camera intrinsics.

    Raises:
        ValueError: If trajectory_type is not one of the supported names.
    """
    # Linear moves: (positive direction?, axis) per trajectory name.
    linear_moves = {
        "left": (False, "x"),
        "right": (True, "x"),
        "up": (False, "y"),
        "down": (True, "y"),
        "zoom_in": (True, "z"),
        "zoom_out": (False, "z"),
    }

    if trajectory_type in ("clockwise", "counterclockwise"):
        traj = create_spiral_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=trajectory_type == "clockwise",
            device=device,
            camera_rotation=camera_rotation,
            radius_x=movement_distance,
            radius_y=movement_distance,
        )
    elif trajectory_type == "none":
        # Static camera: repeat the initial pose for every frame.
        traj = initial_w2c.unsqueeze(0).expand(num_frames, -1, -1)
    elif trajectory_type in linear_moves:
        positive, axis = linear_moves[trajectory_type]
        traj = create_horizontal_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=positive,
            axis=axis,
            distance=movement_distance,
            device=device,
            camera_rotation=camera_rotation,
        )
    else:
        raise ValueError(f"Unsupported trajectory type: {trajectory_type}")

    generated_w2cs = traj.unsqueeze(0)  # (1, num_frames, 4, 4)

    # Intrinsics: broadcast a single (3, 3) matrix over all frames, or just
    # add the batch axis when a per-frame sequence was already supplied.
    if initial_intrinsics.dim() == 2:
        generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1)
    else:
        generated_intrinsics = initial_intrinsics.unsqueeze(0)

    return generated_w2cs, generated_intrinsics

fastvideo.pipelines.basic.gen3c.project_points

project_points(world_points: Tensor, w2c: Tensor, intrinsic: Tensor) -> Tensor

Project 3D world points to 2D pixel coordinates.

Parameters:

Name Type Description Default
world_points Tensor

(b, h, w, 3) 3D world coordinates

required
w2c Tensor

(b, 4, 4) world-to-camera transformation matrix

required
intrinsic Tensor

(b, 3, 3) camera intrinsic matrix

required

Returns:

Name Type Description
projected_points Tensor

(b, h, w, 3, 1) projected 2D coordinates (x, y, z)

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def project_points(
    world_points: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
) -> torch.Tensor:
    """
    Project 3D world points through a camera's extrinsics and intrinsics.

    Args:
        world_points: (b, h, w, 3) world-space coordinates.
        w2c: (b, 4, 4) world-to-camera transform.
        intrinsic: (b, 3, 3) camera intrinsic matrix.

    Returns:
        (b, h, w, 3, 1) un-normalized projected coordinates (x, y, z);
        divide the first two rows by z to obtain pixel positions.
    """
    pts = world_points.unsqueeze(-1)  # column vectors: (b, h, w, 3, 1)
    b, h, w = pts.shape[:3]

    # Promote to homogeneous coordinates.
    ones_row = torch.ones((b, h, w, 1, 1), device=pts.device, dtype=pts.dtype)
    pts_homo = torch.cat([pts, ones_row], dim=3)  # (b, h, w, 4, 1)

    # Rigid transform into camera space, then drop the homogeneous row.
    cam_pts = torch.matmul(w2c[:, None, None], pts_homo)[:, :, :, :3]

    # Apply the intrinsics to reach (un-normalized) image coordinates.
    return torch.matmul(intrinsic[:, None, None], cam_pts)

fastvideo.pipelines.basic.gen3c.unproject_points

unproject_points(depth: Tensor, w2c: Tensor, intrinsic: Tensor, is_depth: bool = True, mask: Tensor | None = None) -> Tensor

Unproject depth map to 3D world points.

Parameters:

Name Type Description Default
depth Tensor

(b, 1, h, w) depth map

required
w2c Tensor

(b, 4, 4) world-to-camera transformation matrix

required
intrinsic Tensor

(b, 3, 3) camera intrinsic matrix

required
is_depth bool

If True, depth is z-depth; if False, depth is distance to camera

True
mask Tensor | None

Optional (b, h, w) or (b, 1, h, w) mask for valid pixels

None

Returns:

Name Type Description
world_points Tensor

(b, h, w, 3) 3D world coordinates

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def unproject_points(
    depth: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
    is_depth: bool = True,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Lift a depth map into 3D world coordinates.

    Args:
        depth: (b, 1, h, w) per-pixel depth.
        w2c: (b, 4, 4) world-to-camera transform.
        intrinsic: (b, 3, 3) camera intrinsic matrix.
        is_depth: If True, depth is z-depth; if False, it is the distance
            from the camera center along the pixel ray.
        mask: Optional (b, h, w) or (b, 1, h, w) validity mask; defaults
            to depth > 0.

    Returns:
        (b, h, w, 3) world points; invalid pixels remain zero.
    """
    b, _, h, w = depth.shape
    dev = depth.device
    dt = depth.dtype

    if mask is None:
        mask = depth > 0
    # Squeeze a singleton channel axis so nonzero() yields (batch, y, x).
    if mask.dim() == depth.dim() and mask.shape[1] == 1:
        mask = mask[:, 0]

    world_out = torch.zeros((b, h, w, 3), device=dev, dtype=dt)
    valid = torch.nonzero(mask)
    if valid.numel() == 0:
        return world_out
    b_i, y_i, x_i = valid[:, 0], valid[:, 1], valid[:, 2]

    # Back-project pixel coordinates through the inverse intrinsics.
    k_inv = inverse_with_conversion(intrinsic)  # (b, 3, 3)
    pix = torch.stack(
        [x_i.to(dt), y_i.to(dt), torch.ones(x_i.shape[0], device=dev, dtype=dt)],
        dim=1,
    ).unsqueeze(-1)  # (N, 3, 1)
    rays = torch.matmul(k_inv[b_i], pix)  # (N, 3, 1)

    d = depth[b_i, 0, y_i, x_i].view(-1, 1, 1)
    if is_depth:
        cam_pts = d * rays
    else:
        # Depth is a ray length: scale the normalized ray direction.
        ray_len = torch.norm(rays, dim=1, keepdim=True)
        cam_pts = d * (rays / (ray_len + 1e-8))

    # Move camera-space points into world space via the inverse extrinsics.
    homo = torch.cat(
        [cam_pts, torch.ones((cam_pts.shape[0], 1, 1), device=dev, dtype=dt)],
        dim=1,
    )  # (N, 4, 1)
    c2w = inverse_with_conversion(w2c)  # (b, 4, 4)
    world_pts = torch.matmul(c2w[b_i], homo)[:, :3, 0]  # (N, 3)

    world_out[b_i, y_i, x_i, :] = world_pts
    return world_out

Modules

fastvideo.pipelines.basic.gen3c.cache_3d

This module implements the 3D cache system for GEN3C video generation with camera control. The cache maintains a point cloud representation of the scene, enabling: - Unprojecting depth maps to 3D world points - Forward warping rendered views to new camera poses - Managing multiple frame buffers for temporal consistency

Classes

fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase
Cache3DBase(input_image: Tensor, input_depth: Tensor, input_w2c: Tensor, input_intrinsics: Tensor, input_mask: Tensor | None = None, input_format: list[str] | None = None, input_points: Tensor | None = None, weight_dtype: dtype = float32, is_depth: bool = True, device: str = 'cuda', filter_points_threshold: float = 1.0)

Base class for 3D cache management.

The cache maintains: - input_image: RGB images stored in the cache - input_points: 3D world coordinates for each pixel - input_mask: Validity mask for each pixel

Initialize the 3D cache.

Parameters:

Name Type Description Default
input_image Tensor

Input image tensor with varying dimensions

required
input_depth Tensor

Depth map tensor

required
input_w2c Tensor

World-to-camera transformation matrix

required
input_intrinsics Tensor

Camera intrinsic matrix

required
input_mask Tensor | None

Optional validity mask

None
input_format list[str] | None

Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])

None
input_points Tensor | None

Pre-computed 3D world points (alternative to depth)

None
weight_dtype dtype

Data type for computations

float32
is_depth bool

If True, input_depth is z-depth; if False, it's distance

True
device str

Computation device

'cuda'
filter_points_threshold float

Threshold for filtering unreliable depth

1.0
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    input_image: torch.Tensor,
    input_depth: torch.Tensor,
    input_w2c: torch.Tensor,
    input_intrinsics: torch.Tensor,
    input_mask: torch.Tensor | None = None,
    input_format: list[str] | None = None,
    input_points: torch.Tensor | None = None,
    weight_dtype: torch.dtype = torch.float32,
    is_depth: bool = True,
    device: str = "cuda",
    filter_points_threshold: float = 1.0,
):
    """
    Initialize the 3D cache.

    Normalizes input_image to a canonical B x F x N x V x C x H x W layout,
    computes (or accepts) per-pixel 3D world points, and stores image, points
    and mask on CPU in weight_dtype.

    Args:
        input_image: Input image tensor with varying dimensions
        input_depth: Depth map tensor
        input_w2c: World-to-camera transformation matrix
        input_intrinsics: Camera intrinsic matrix
        input_mask: Optional validity mask
        input_format: Dimension labels for input_image (e.g., ['B', 'C', 'H', 'W'])
        input_points: Pre-computed 3D world points (alternative to depth)
        weight_dtype: Data type for computations
        is_depth: If True, input_depth is z-depth; if False, it's distance
        device: Computation device
        filter_points_threshold: Threshold for filtering unreliable depth
    """
    self.weight_dtype = weight_dtype
    self.is_depth = is_depth
    self.device = device
    self.filter_points_threshold = filter_points_threshold

    # Default layout when no labels are given: a plain 4D image batch.
    if input_format is None:
        assert input_image.dim() == 4
        input_format = ["B", "C", "H", "W"]

    # Map dimension names to indices
    format_to_indices = {dim: idx for idx, dim in enumerate(input_format)}
    input_shape = input_image.shape  # captured BEFORE the mask concat below

    # Fold the mask into the channel axis so image and mask go through the
    # same permute/unsqueeze reordering; they are split apart again below.
    if input_mask is not None:
        input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C"))

    # Extract dimensions; any axis absent from input_format defaults to 1.
    B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1
    F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1
    N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1
    V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1
    H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None
    W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None

    # Reorder dimensions to B x F x N x V x C x H x W
    # permute_order holds the source index for each target axis, or None
    # for axes the input does not have.
    desired_dims = ["B", "F", "N", "V", "C", "H", "W"]
    permute_order: list[int | None] = []
    for dim in desired_dims:
        idx = format_to_indices.get(dim)
        permute_order.append(idx)

    # Permute the axes that exist, then insert size-1 axes for the rest
    # (unsqueeze positions follow the canonical order).
    permute_indices = [idx for idx in permute_order if idx is not None]
    input_image = input_image.permute(*permute_indices)

    for i, idx in enumerate(permute_order):
        if idx is None:
            input_image = input_image.unsqueeze(i)

    # Now input_image has shape B x F x N x V x C x H x W
    if input_mask is not None:
        # Split the concatenated tensor back into RGB (first 3 channels)
        # and mask (remaining channels); both are kept on CPU.
        self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:]
        self.input_mask = self.input_mask.to("cpu")
    else:
        self.input_mask = None
        self.input_image = input_image
    self.input_image = self.input_image.to(weight_dtype).to("cpu")

    # Compute 3D world points
    if input_points is not None:
        # Pre-computed points provided: skip unprojection entirely.
        self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu")
        self.input_depth = None
    else:
        # Sanitize depth before unprojection.
        input_depth = torch.nan_to_num(input_depth, nan=100)
        input_depth = torch.clamp(input_depth, min=0, max=100)
        if weight_dtype == torch.float16:
            # Tighter clamp for half precision — presumably to keep
            # downstream products within fp16 range; TODO confirm.
            input_depth = torch.clamp(input_depth, max=70)

        self.input_points = (unproject_points(
            input_depth.reshape(-1, 1, H, W),
            input_w2c.reshape(-1, 4, 4),
            input_intrinsics.reshape(-1, 3, 3),
            is_depth=self.is_depth,
        ).to(weight_dtype).reshape(B, F, N, V, H, W, 3).to("cpu"))
        self.input_depth = input_depth

    # Filter unreliable depth
    # NOTE(review): this filter reads the raw input_depth argument, so it
    # still runs when input_points was supplied (self.input_depth is None
    # in that case) as long as a depth tensor was passed in.
    if self.filter_points_threshold < 1.0 and input_depth is not None:
        input_depth = input_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(input_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(
                                                         B, F, N, V, 1, H, W)
        if self.input_mask is None:
            self.input_mask = depth_mask.to("cpu")
        else:
            self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device)
Functions
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase.input_frame_count
input_frame_count() -> int

Return the number of frames in the cache.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def input_frame_count(self) -> int:
    """Number of frames (F axis) currently stored in the cache."""
    # input_image is laid out as B x F x N x V x C x H x W; axis 1 is F.
    return int(self.input_image.size(1))
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cached 3D points from new camera viewpoints.

Parameters:

Name Type Description Default
target_w2cs Tensor

(b, F_target, 4, 4) target camera transformations

required
target_intrinsics Tensor

(b, F_target, 3, 3) target camera intrinsics

required
render_depth bool

If True, return depth instead of RGB

False
start_frame_idx int

Starting frame index in the cache

0

Returns:

Name Type Description
pixels Tensor

(b, F_target, N, c, h, w) rendered images or depth

masks Tensor

(b, F_target, N, 1, h, w) validity masks

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the cached 3D points from new camera viewpoints.

    Splats the cached images (stored on CPU) through their world points into
    each target camera, processing small chunks on the compute device to
    bound peak memory. Only single-view caches (V == 1) are supported.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics
        render_depth: If True, return depth instead of RGB
        start_frame_idx: Starting frame index in the cache

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images or depth
        masks: (b, F_target, N, 1, h, w) validity masks
    """
    bs, F_target, _, _ = target_w2cs.shape
    B, F, N, V, C, H, W = self.input_image.shape
    assert bs == B

    # Broadcast each target camera across the N buffers, then flatten
    # (B, F_target, N) into a single batch axis for chunked warping below.
    target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4)
    target_intrinsics = target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3,
                                                                               3).reshape(-1, 3, 3)

    # Prepare inputs
    # Slice the frame window [start_frame_idx, start_frame_idx + F_target)
    # and flatten (B, F, N) into the batch axis, keeping the view axis V.
    first_images = rearrange(
        self.input_image[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, C, H, W),
        "B F N V C H W -> (B F N) V C H W")
    first_points = rearrange(
        self.input_points[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, H, W, 3),
        "B F N V H W C -> (B F N) V H W C")
    first_masks = rearrange(
        self.input_mask[:, start_frame_idx:start_frame_idx + F_target].expand(B, F_target, N, V, 1, H, W),
        "B F N V C H W -> (B F N) V C H W") if self.input_mask is not None else None

    # Process in chunks for memory efficiency
    if first_images.shape[1] == 1:
        # Single-view path: drop the V axis and warp 2 items at a time,
        # moving each chunk CPU -> device -> CPU around the warp call.
        warp_chunk_size = 2
        rendered_warp_images = []
        rendered_warp_masks = []
        rendered_warp_depth = []

        first_images = first_images.squeeze(1)
        first_points = first_points.squeeze(1)
        first_masks = first_masks.squeeze(1) if first_masks is not None else None

        for i in range(0, first_images.shape[0], warp_chunk_size):
            with torch.no_grad():
                imgs_chunk = first_images[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                pts_chunk = first_points[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                masks_chunk = (first_masks[i:i + warp_chunk_size].to(self.device, non_blocking=True)
                               if first_masks is not None else None)

                # Warp via pre-computed world points (depth1/transformation1
                # unused); source and target intrinsics are the same here.
                (
                    rendered_warp_images_chunk,
                    rendered_warp_masks_chunk,
                    rendered_warp_depth_chunk,
                    _,
                ) = forward_warp(
                    imgs_chunk,
                    mask1=masks_chunk,
                    depth1=None,
                    transformation1=None,
                    transformation2=target_w2cs[i:i + warp_chunk_size],
                    intrinsic1=target_intrinsics[i:i + warp_chunk_size],
                    intrinsic2=target_intrinsics[i:i + warp_chunk_size],
                    render_depth=render_depth,
                    world_points1=pts_chunk,
                )

                # Stash results on CPU so only one chunk lives on the device.
                rendered_warp_images.append(rendered_warp_images_chunk.to("cpu"))
                rendered_warp_masks.append(rendered_warp_masks_chunk.to("cpu"))
                if render_depth:
                    rendered_warp_depth.append(rendered_warp_depth_chunk.to("cpu"))

                del imgs_chunk, pts_chunk, masks_chunk
                torch.cuda.empty_cache()

        rendered_warp_images = torch.cat(rendered_warp_images, dim=0)
        rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0)
        if render_depth:
            rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0)
    else:
        raise NotImplementedError("Multi-view rendering not yet supported")

    # Unflatten the batch axis back to (b, F_target, N).
    pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
    masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)

    # When depth is requested, it replaces the RGB output (no channel axis).
    if render_depth:
        pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N)

    return pixels.to(self.device), masks.to(self.device)
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBase.update_cache
update_cache(**kwargs)

Update the cache with new frames. To be implemented by subclasses.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(self, **kwargs):
    """Update the cache with new frames. To be implemented by subclasses.

    Raises:
        NotImplementedError: Always; this is an abstract hook that
            subclasses must override.
    """
    raise NotImplementedError
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBuffer
Cache3DBuffer(frame_buffer_max: int = 2, noise_aug_strength: float = 0.0, generator: Generator | None = None, **kwargs)

Bases: Cache3DBase

3D cache with frame buffer support.

This class manages multiple frame buffers for temporal consistency and supports noise augmentation for training stability.

Initialize the buffered 3D cache.

Parameters:

Name Type Description Default
frame_buffer_max int

Maximum number of frames to buffer

2
noise_aug_strength float

Strength of noise augmentation per buffer

0.0
generator Generator | None

Random generator for reproducibility

None
**kwargs

Arguments passed to Cache3DBase

{}
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def __init__(
    self,
    frame_buffer_max: int = 2,
    noise_aug_strength: float = 0.0,
    generator: torch.Generator | None = None,
    **kwargs,
):
    """
    Create a 3D cache that maintains multiple frame buffers.

    Args:
        frame_buffer_max: Maximum number of frames to keep buffered.
        noise_aug_strength: Per-buffer noise augmentation strength
            (0 disables augmentation).
        generator: Optional RNG for reproducible noise.
        **kwargs: Forwarded verbatim to Cache3DBase.
    """
    super().__init__(**kwargs)
    self.generator = generator
    self.noise_aug_strength = noise_aug_strength
    self.frame_buffer_max = frame_buffer_max
Functions
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBuffer.render_cache
render_cache(target_w2cs: Tensor, target_intrinsics: Tensor, render_depth: bool = False, start_frame_idx: int = 0) -> tuple[Tensor, Tensor]

Render the cache with optional noise augmentation.

Parameters:

Name Type Description Default
target_w2cs Tensor

(b, F_target, 4, 4) target camera transformations

required
target_intrinsics Tensor

(b, F_target, 3, 3) target camera intrinsics

required
render_depth bool

If True, return depth instead of RGB

False
start_frame_idx int

Starting frame index (must be 0 for this class)

0

Returns:

Name Type Description
pixels Tensor

(b, F_target, N, c, h, w) rendered images

masks Tensor

(b, F_target, N, 1, h, w) validity masks

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def render_cache(
    self,
    target_w2cs: torch.Tensor,
    target_intrinsics: torch.Tensor,
    render_depth: bool = False,
    start_frame_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Render the buffered cache into the requested target views.

    Args:
        target_w2cs: (b, F_target, 4, 4) target camera transformations.
        target_intrinsics: (b, F_target, 3, 3) target camera intrinsics.
        render_depth: When True, depth maps are rendered instead of RGB.
        start_frame_idx: Starting frame index; must be 0 for this class.

    Returns:
        pixels: (b, F_target, N, c, h, w) rendered images.
        masks: (b, F_target, N, 1, h, w) validity masks.
    """
    assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3DBuffer"

    result_device = target_w2cs.device
    pixels, masks = super().render_cache(
        target_w2cs.to(self.weight_dtype).to(self.device),
        target_intrinsics.to(self.weight_dtype).to(self.device),
        render_depth,
    )
    pixels = pixels.to(result_device)
    masks = masks.to(result_device)

    # Noise augmentation is applied to RGB renders only.
    if self.noise_aug_strength > 0 and not render_depth:
        num_buffers = pixels.shape[2]
        # Per-buffer multipliers [N-1, ..., 0] along the buffer axis.
        # NOTE(review): buffer index 0 receives the largest multiplier; since
        # update_cache stores the newest frame at index 0, confirm this
        # matches the intended "stronger noise for older buffers" behavior.
        scale = torch.arange(num_buffers - 1, -1, -1, device=pixels.device) * self.noise_aug_strength
        noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype)
        pixels = pixels + noise * scale.reshape(1, 1, -1, 1, 1, 1)

    return pixels, masks
fastvideo.pipelines.basic.gen3c.cache_3d.Cache3DBuffer.update_cache
update_cache(new_image: Tensor, new_depth: Tensor, new_w2c: Tensor, new_mask: Tensor | None = None, new_intrinsics: Tensor | None = None)

Update the cache with a new frame.

Parameters:

Name Type Description Default
new_image Tensor

(B, C, H, W) new RGB image

required
new_depth Tensor

(B, 1, H, W) new depth map

required
new_w2c Tensor

(B, 4, 4) new world-to-camera transformation

required
new_mask Tensor | None

Optional (B, 1, H, W) validity mask

None
new_intrinsics Tensor | None

(B, 3, 3) camera intrinsics (optional)

None
Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def update_cache(
    self,
    new_image: torch.Tensor,
    new_depth: torch.Tensor,
    new_w2c: torch.Tensor,
    new_mask: torch.Tensor | None = None,
    new_intrinsics: torch.Tensor | None = None,
):
    """
    Update the cache with a new frame.

    The newest frame is stored at buffer index 0 (dim 2 of the cache
    tensors); once the buffer is full, slot 0 is overwritten in place.

    Args:
        new_image: (B, C, H, W) new RGB image
        new_depth: (B, 1, H, W) new depth map
        new_w2c: (B, 4, 4) new world-to-camera transformation
        new_mask: Optional (B, 1, H, W) validity mask
        new_intrinsics: (B, 3, 3) camera intrinsics (optional)
    """
    # Move inputs onto the compute device in the cache's working dtype.
    new_image = new_image.to(self.weight_dtype).to(self.device)
    new_depth = new_depth.to(self.weight_dtype).to(self.device)
    new_w2c = new_w2c.to(self.weight_dtype).to(self.device)
    if new_intrinsics is not None:
        new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device)

    # Sanitize depth: replace NaNs and clamp into [0, 1e4].
    new_depth = torch.nan_to_num(new_depth, nan=1e4)
    new_depth = torch.clamp(new_depth, min=0, max=1e4)

    B, F, N, V, C, H, W = self.input_image.shape

    # Compute new 3D points; results are moved to CPU like the stored cache.
    new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu()
    new_image = new_image.cpu()

    if self.filter_points_threshold < 1.0:
        # Drop pixels whose local depth variation is too large to trust.
        new_depth = new_depth.reshape(-1, 1, H, W)
        depth_mask = reliable_depth_mask_range_batch(new_depth,
                                                     ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W)
        new_mask = depth_mask.to("cpu") if new_mask is None else new_mask * depth_mask.to(new_mask.device)
    if new_mask is not None:
        new_mask = new_mask.cpu()

    # Update buffer (newest frame first)
    if self.frame_buffer_max > 1:
        if self.input_image.shape[2] < self.frame_buffer_max:
            # Buffer not yet full: prepend the new frame along dim 2.
            self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2)
            self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2)
            if self.input_mask is not None:
                # NOTE(review): if input_mask is set but new_mask is still None
                # here, this line would fail — presumably callers always supply
                # a mask (or filtering is enabled) in that configuration; confirm.
                self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2)
        else:
            # Buffer full: overwrite the newest slot (index 0) in place.
            self.input_image[:, :, 0] = new_image[:, None, None]
            self.input_points[:, :, 0] = new_points[:, None, None]
            if self.input_mask is not None:
                self.input_mask[:, :, 0] = new_mask[:, None, None]
    else:
        # Single-frame buffer: replace contents (note: input_mask is not
        # updated on this path).
        self.input_image = new_image[:, None, None, None]
        self.input_points = new_points[:, None, None, None]

Functions

fastvideo.pipelines.basic.gen3c.cache_3d.bilinear_splatting
bilinear_splatting(frame1: Tensor, mask1: Tensor | None, depth1: Tensor, flow12: Tensor, flow12_mask: Tensor | None = None, is_image: bool = False, depth_weight_scale: float = 50.0) -> tuple[Tensor, Tensor]

Bilinear splatting for forward warping.

Parameters:

Name Type Description Default
frame1 Tensor

(b, c, h, w) source frame

required
mask1 Tensor | None

(b, 1, h, w) valid pixel mask (1 for known, 0 for unknown)

required
depth1 Tensor

(b, 1, h, w) depth map

required
flow12 Tensor

(b, 2, h, w) optical flow from frame1 to frame2

required
flow12_mask Tensor | None

(b, 1, h, w) flow validity mask

None
is_image bool

If True, output will be clipped to (-1, 1) range

False
depth_weight_scale float

Scale factor for depth weighting

50.0

Returns:

Name Type Description
warped_frame2 Tensor

(b, c, h, w) warped frame

mask2 Tensor

(b, 1, h, w) validity mask for warped frame

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def bilinear_splatting(
    frame1: torch.Tensor,
    mask1: torch.Tensor | None,
    depth1: torch.Tensor,
    flow12: torch.Tensor,
    flow12_mask: torch.Tensor | None = None,
    is_image: bool = False,
    depth_weight_scale: float = 50.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Bilinear splatting for forward warping.

    Each source pixel is shifted by its flow vector and its value is
    accumulated into the four surrounding integer positions of the target
    image with bilinear weights. Contributions are divided by an
    exponential function of depth, so nearer pixels dominate where several
    sources land on the same target pixel.

    Args:
        frame1: (b, c, h, w) source frame
        mask1: (b, 1, h, w) valid pixel mask (1 for known, 0 for unknown)
        depth1: (b, 1, h, w) depth map
        flow12: (b, 2, h, w) optical flow from frame1 to frame2
        flow12_mask: (b, 1, h, w) flow validity mask
        is_image: If True, output will be clipped to (-1, 1) range
        depth_weight_scale: Scale factor for depth weighting

    Returns:
        warped_frame2: (b, c, h, w) warped frame
        mask2: (b, 1, h, w) validity mask for warped frame
    """
    b, c, h, w = frame1.shape
    device = frame1.device
    dtype = frame1.dtype

    # Missing masks default to all-valid.
    if mask1 is None:
        mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)
    if flow12_mask is None:
        flow12_mask = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)

    # Absolute target position of every source pixel.
    grid = create_grid(b, h, w, device=device, dtype=dtype)
    trans_pos = flow12 + grid

    # Shift by +1 so out-of-range splats fall into a 1-pixel border ring
    # of the padded canvas, which is cropped away at the end.
    trans_pos_offset = trans_pos + 1
    trans_pos_floor = torch.floor(trans_pos_offset).long()
    trans_pos_ceil = torch.ceil(trans_pos_offset).long()

    # Clamp all coordinates into the padded (h + 2, w + 2) canvas.
    trans_pos_offset = torch.stack(
        [torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
         torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)],
        dim=1)
    trans_pos_floor = torch.stack(
        [torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
         torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)],
        dim=1)
    trans_pos_ceil = torch.stack(
        [torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
         torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)],
        dim=1)

    # Bilinear weights for the four neighbouring target pixels
    # (north-west, south-west, north-east, south-east).
    prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
    prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

    # Depth weighting for occlusion handling: contributions are divided by
    # exp(scaled log-depth), so larger depth => smaller contribution.
    clamped_depth1 = torch.clamp(depth1, min=0)
    log_depth1 = torch.log1p(clamped_depth1)
    exponent = log_depth1 / (log_depth1.max() + 1e-7) * depth_weight_scale
    # Cap the exponent so exp() stays finite in the working dtype
    # (float16 overflows far earlier than float32/bfloat16).
    max_exponent = 80.0 if dtype in [torch.float32, torch.bfloat16] else 10.0
    clamped_exponent = torch.clamp(exponent, max=max_exponent)
    depth_weights = torch.exp(clamped_exponent) + 1e-7

    # Move channels last, (b, 1, h, w) -> (b, h, w, 1), for index_put_ below.
    weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])

    # Padded channels-last accumulation buffers.
    warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=dtype, device=device)
    warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=dtype, device=device)

    frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])
    batch_indices = torch.arange(b, device=device, dtype=torch.long)[:, None, None]

    # Scatter-add weighted pixel values into the four neighbour positions.
    warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
                            frame1_cl * weight_nw,
                            accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
                            frame1_cl * weight_sw,
                            accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
                            frame1_cl * weight_ne,
                            accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
                            frame1_cl * weight_se,
                            accumulate=True)

    # Accumulate the weights themselves for the normalization step below.
    warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), weight_nw, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), weight_sw, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), weight_ne, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), weight_se, accumulate=True)

    # Back to channels-first and crop off the 1-pixel border ring.
    warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1])
    warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1])
    cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
    cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]
    cropped_weights = torch.nan_to_num(cropped_weights, nan=1000.0)

    # Normalize by accumulated weight; pixels that received no contribution
    # get the "zero" value (-1 for images in [-1, 1], 0 otherwise).
    mask = cropped_weights > 0
    zero_value = -1 if is_image else 0
    zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
    warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor)
    mask2 = mask.to(frame1)

    if is_image:
        warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)

    return warped_frame2, mask2
fastvideo.pipelines.basic.gen3c.cache_3d.create_grid
create_grid(b: int, h: int, w: int, device: str = 'cpu', dtype: dtype = float32) -> Tensor

Create a dense grid of (x, y) coordinates of shape (b, 2, h, w).

Parameters:

Name Type Description Default
b int

Batch size

required
h int

Height

required
w int

Width

required
device str

Device for tensor creation

'cpu'
dtype dtype

Data type for tensor

float32

Returns:

Type Description
Tensor

Grid tensor of shape (b, 2, h, w)

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def create_grid(b: int, h: int, w: int, device: str = "cpu", dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Build a batch of dense pixel-coordinate grids.

    Channel 0 holds the x (column) coordinate and channel 1 the y (row)
    coordinate of every pixel.

    Args:
        b: Batch size
        h: Height
        w: Width
        device: Device for tensor creation
        dtype: Data type for tensor

    Returns:
        Grid tensor of shape (b, 2, h, w)
    """
    ys, xs = torch.meshgrid(
        torch.arange(h, device=device, dtype=dtype),
        torch.arange(w, device=device, dtype=dtype),
        indexing="ij",
    )
    single = torch.stack([xs, ys], dim=0)  # (2, h, w), x first
    return single.unsqueeze(0).repeat(b, 1, 1, 1)
fastvideo.pipelines.basic.gen3c.cache_3d.forward_warp
forward_warp(frame1: Tensor, mask1: Tensor | None, depth1: Tensor | None, transformation1: Tensor | None, transformation2: Tensor, intrinsic1: Tensor | None, intrinsic2: Tensor | None, is_image: bool = True, is_depth: bool = True, render_depth: bool = False, world_points1: Tensor | None = None) -> tuple[Tensor, Tensor, Tensor | None, Tensor]

Forward warp frame1 to a new view defined by transformation2.

Parameters:

Name Type Description Default
frame1 Tensor

(b, c, h, w) source frame in range [-1, 1] for images

required
mask1 Tensor | None

(b, 1, h, w) valid pixel mask

required
depth1 Tensor | None

(b, 1, h, w) depth map (required if world_points1 is None)

required
transformation1 Tensor | None

(b, 4, 4) source camera w2c (required if depth1 is provided)

required
transformation2 Tensor

(b, 4, 4) target camera w2c

required
intrinsic1 Tensor | None

(b, 3, 3) source camera intrinsics

required
intrinsic2 Tensor | None

(b, 3, 3) target camera intrinsics

required
is_image bool

If True, output will be clipped to (-1, 1)

True
is_depth bool

If True, depth1 is z-depth; if False, it's distance

True
render_depth bool

If True, also return the warped depth map

False
world_points1 Tensor | None

(b, h, w, 3) pre-computed world points (alternative to depth1)

None

Returns:

Name Type Description
warped_frame2 Tensor

(b, c, h, w) warped frame

mask2 Tensor

(b, 1, h, w) validity mask

warped_depth2 Tensor | None

(b, h, w) warped depth (if render_depth=True)

flow12 Tensor

(b, 2, h, w) optical flow

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def forward_warp(
    frame1: torch.Tensor,
    mask1: torch.Tensor | None,
    depth1: torch.Tensor | None,
    transformation1: torch.Tensor | None,
    transformation2: torch.Tensor,
    intrinsic1: torch.Tensor | None,
    intrinsic2: torch.Tensor | None,
    is_image: bool = True,
    is_depth: bool = True,
    render_depth: bool = False,
    world_points1: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """
    Forward warp frame1 to a new view defined by transformation2.

    Pixels are lifted to 3D (from ``depth1`` or the pre-computed
    ``world_points1``), re-projected into the target camera, and splatted
    onto the target image with ``bilinear_splatting``.

    Args:
        frame1: (b, c, h, w) source frame in range [-1, 1] for images
        mask1: (b, 1, h, w) valid pixel mask
        depth1: (b, 1, h, w) depth map (required if world_points1 is None)
        transformation1: (b, 4, 4) source camera w2c (required if depth1 is provided)
        transformation2: (b, 4, 4) target camera w2c
        intrinsic1: (b, 3, 3) source camera intrinsics
        intrinsic2: (b, 3, 3) target camera intrinsics (defaults to intrinsic1)
        is_image: If True, output will be clipped to (-1, 1)
        is_depth: If True, depth1 is z-depth; if False, it's distance
        render_depth: If True, also return the warped depth map
        world_points1: (b, h, w, 3) pre-computed world points (alternative to depth1)

    Returns:
        warped_frame2: (b, c, h, w) warped frame
        mask2: (b, 1, h, w) validity mask
        warped_depth2: (b, h, w) warped depth (if render_depth=True, else None)
        flow12: (b, 2, h, w) optical flow
    """
    device = frame1.device
    b, c, h, w = frame1.shape
    dtype = frame1.dtype

    if mask1 is None:
        # No mask supplied: treat every source pixel as valid.
        mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype)
    if intrinsic2 is None:
        # Reuse the source intrinsics for the target view.
        assert intrinsic1 is not None
        intrinsic2 = intrinsic1.clone()

    if world_points1 is not None:
        # Use pre-computed world points
        assert world_points1.shape == (b, h, w, 3)
        trans_points1 = project_points(world_points1, transformation2, intrinsic2)
    else:
        # Compute from depth
        assert depth1 is not None and transformation1 is not None
        assert depth1.shape == (b, 1, h, w)

        # Sanitize depth: replace NaNs and clamp into [0, 1e4].
        depth1 = torch.nan_to_num(depth1, nan=1e4)
        depth1 = torch.clamp(depth1, min=0, max=1e4)

        # Unproject to world, then project to target view
        world_points1 = unproject_points(depth1, transformation1, intrinsic1, is_depth=is_depth)
        trans_points1 = project_points(world_points1, transformation2, intrinsic2)

    # Filter points that end up behind the target camera (z <= 0).
    mask1 = mask1 * (trans_points1[:, :, :, 2, 0].unsqueeze(1) > 0)
    # Perspective divide to get target-view pixel coordinates.
    trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7)
    trans_coordinates = trans_coordinates.permute(0, 3, 1, 2)  # b, 2, h, w
    trans_depth1 = trans_points1[:, :, :, 2, 0].unsqueeze(1)

    # Flow is the displacement from each source pixel to its target position.
    grid = create_grid(b, h, w, device=device, dtype=dtype)
    flow12 = trans_coordinates - grid

    warped_frame2, mask2 = bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image)

    warped_depth2 = None
    if render_depth:
        # Splat the target-view depth itself to obtain a warped depth map.
        warped_depth2 = bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0][:, 0]

    return warped_frame2, mask2, warped_depth2, flow12
fastvideo.pipelines.basic.gen3c.cache_3d.inverse_with_conversion
inverse_with_conversion(mtx: Tensor) -> Tensor

Compute matrix inverse with float32 conversion for numerical stability.

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def inverse_with_conversion(mtx: torch.Tensor) -> torch.Tensor:
    """Invert a matrix in float32 for numerical stability, then restore its dtype."""
    original_dtype = mtx.dtype
    inverted = torch.linalg.inv(mtx.float())
    return inverted.to(original_dtype)
fastvideo.pipelines.basic.gen3c.cache_3d.project_points
project_points(world_points: Tensor, w2c: Tensor, intrinsic: Tensor) -> Tensor

Project 3D world points to 2D pixel coordinates.

Parameters:

Name Type Description Default
world_points Tensor

(b, h, w, 3) 3D world coordinates

required
w2c Tensor

(b, 4, 4) world-to-camera transformation matrix

required
intrinsic Tensor

(b, 3, 3) camera intrinsic matrix

required

Returns:

Name Type Description
projected_points Tensor

(b, h, w, 3, 1) projected 2D coordinates (x, y, z)

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def project_points(
    world_points: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
) -> torch.Tensor:
    """
    Project 3D world points to 2D pixel coordinates.

    The result is un-normalized: divide the first two channels by the
    third (z) channel to obtain pixel positions.

    Args:
        world_points: (b, h, w, 3) 3D world coordinates
        w2c: (b, 4, 4) world-to-camera transformation matrix
        intrinsic: (b, 3, 3) camera intrinsic matrix

    Returns:
        projected_points: (b, h, w, 3, 1) projected coordinates (x, y, z)
    """
    b, h, w, _ = world_points.shape
    dev = world_points.device
    dt = world_points.dtype

    # Flatten to (b, h*w, 3) and append a homogeneous coordinate.
    flat = world_points.reshape(b, h * w, 3)
    homo = torch.cat([flat, torch.ones((b, h * w, 1), device=dev, dtype=dt)], dim=-1)

    # Row-vector form: p @ M^T equals M @ p for column vectors.
    cam = torch.bmm(homo, w2c.transpose(1, 2))[..., :3]   # world -> camera
    projected = torch.bmm(cam, intrinsic.transpose(1, 2))  # camera -> image

    return projected.reshape(b, h, w, 3, 1)
fastvideo.pipelines.basic.gen3c.cache_3d.reliable_depth_mask_range_batch
reliable_depth_mask_range_batch(depth: Tensor, window_size: int = 5, ratio_thresh: float = 0.05, eps: float = 1e-06) -> Tensor

Compute a mask for reliable depth values based on local variation.

Parameters:

Name Type Description Default
depth Tensor

(b, h, w) or (b, 1, h, w) depth map

required
window_size int

Size of the local window (must be odd)

5
ratio_thresh float

Threshold for depth variation ratio

0.05
eps float

Small epsilon for numerical stability

1e-06

Returns:

Name Type Description
reliable_mask Tensor

Boolean mask where True indicates reliable depth

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def reliable_depth_mask_range_batch(
    depth: torch.Tensor,
    window_size: int = 5,
    ratio_thresh: float = 0.05,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Flag depth pixels whose local min-max spread is small relative to the
    local mean, marking them as reliable.

    Args:
        depth: (b, h, w) or (b, 1, h, w) depth map
        window_size: Size of the local window (must be odd)
        ratio_thresh: Threshold for depth variation ratio
        eps: Small epsilon for numerical stability

    Returns:
        Boolean mask (b, 1, h, w), True where depth is locally consistent
        and strictly positive.
    """
    assert window_size % 2 == 1, "Window size must be odd."

    if depth.dim() == 4:
        d = depth
    elif depth.dim() == 3:
        d = depth.unsqueeze(1)
    else:
        raise ValueError("depth tensor must be of shape (b, h, w) or (b, 1, h, w)")

    pad = window_size // 2
    # Local max via max-pool; local min via max-pool of the negated map.
    local_hi = F.max_pool2d(d, kernel_size=window_size, stride=1, padding=pad)
    local_lo = -F.max_pool2d(-d, kernel_size=window_size, stride=1, padding=pad)
    local_avg = F.avg_pool2d(d, kernel_size=window_size, stride=1, padding=pad)

    spread_ratio = (local_hi - local_lo) / (local_avg + eps)
    return (spread_ratio < ratio_thresh) & (d > 0)
fastvideo.pipelines.basic.gen3c.cache_3d.unproject_points
unproject_points(depth: Tensor, w2c: Tensor, intrinsic: Tensor, is_depth: bool = True, mask: Tensor | None = None) -> Tensor

Unproject depth map to 3D world points.

Parameters:

Name Type Description Default
depth Tensor

(b, 1, h, w) depth map

required
w2c Tensor

(b, 4, 4) world-to-camera transformation matrix

required
intrinsic Tensor

(b, 3, 3) camera intrinsic matrix

required
is_depth bool

If True, depth is z-depth; if False, depth is distance to camera

True
mask Tensor | None

Optional (b, h, w) or (b, 1, h, w) mask for valid pixels

None

Returns:

Name Type Description
world_points Tensor

(b, h, w, 3) 3D world coordinates

Source code in fastvideo/pipelines/basic/gen3c/cache_3d.py
def unproject_points(
    depth: torch.Tensor,
    w2c: torch.Tensor,
    intrinsic: torch.Tensor,
    is_depth: bool = True,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Unproject depth map to 3D world points.

    Only pixels selected by ``mask`` (by default, those with positive
    depth) are unprojected; all other output positions remain zero.

    Args:
        depth: (b, 1, h, w) depth map
        w2c: (b, 4, 4) world-to-camera transformation matrix
        intrinsic: (b, 3, 3) camera intrinsic matrix
        is_depth: If True, depth is z-depth; if False, depth is distance to camera
        mask: Optional (b, h, w) or (b, 1, h, w) mask for valid pixels

    Returns:
        world_points: (b, h, w, 3) 3D world coordinates (zeros at masked-out pixels)
    """
    b, _, h, w = depth.shape
    device = depth.device
    dtype = depth.dtype

    if mask is None:
        mask = depth > 0
    if mask.dim() == depth.dim() and mask.shape[1] == 1:
        # Squeeze the channel axis so the mask is (b, h, w).
        mask = mask[:, 0]

    # Sparse (batch, y, x) indices of the valid pixels.
    idx = torch.nonzero(mask)
    if idx.numel() == 0:
        # No valid pixels: return an all-zero point map.
        return torch.zeros((b, h, w, 3), device=device, dtype=dtype)

    b_idx, y_idx, x_idx = idx[:, 0], idx[:, 1], idx[:, 2]

    intrinsic_inv = inverse_with_conversion(intrinsic)  # (b, 3, 3)

    # Homogeneous pixel coordinates (x, y, 1) of the valid pixels.
    x_valid = x_idx.to(dtype)
    y_valid = y_idx.to(dtype)
    ones = torch.ones_like(x_valid)
    pos = torch.stack([x_valid, y_valid, ones], dim=1).unsqueeze(-1)  # (N, 3, 1)

    # Back-project through the inverse intrinsics to camera-space rays.
    intrinsic_inv_valid = intrinsic_inv[b_idx]  # (N, 3, 3)
    unnormalized_pos = torch.matmul(intrinsic_inv_valid, pos)  # (N, 3, 1)

    depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1, 1)
    if is_depth:
        # z-depth: scale the un-normalized ray (z == 1 for standard
        # intrinsics) so the resulting z equals the depth value.
        world_points_cam = depth_valid * unnormalized_pos
    else:
        # Euclidean distance: scale the unit ray direction instead.
        norm_val = torch.norm(unnormalized_pos, dim=1, keepdim=True)
        direction = unnormalized_pos / (norm_val + 1e-8)
        world_points_cam = depth_valid * direction

    ones_h = torch.ones((world_points_cam.shape[0], 1, 1), device=device, dtype=dtype)
    world_points_homo = torch.cat([world_points_cam, ones_h], dim=1)  # (N, 4, 1)

    # Camera-to-world is the inverse of the provided world-to-camera matrix.
    trans = inverse_with_conversion(w2c)  # (b, 4, 4)
    trans_valid = trans[b_idx]  # (N, 4, 4)
    world_points_transformed = torch.matmul(trans_valid, world_points_homo)  # (N, 4, 1)
    sparse_points = world_points_transformed[:, :3, 0]  # (N, 3)

    # Scatter the sparse results back into a dense (b, h, w, 3) map.
    out_points = torch.zeros((b, h, w, 3), device=device, dtype=dtype)
    out_points[b_idx, y_idx, x_idx, :] = sparse_points
    return out_points

fastvideo.pipelines.basic.gen3c.camera_utils

Camera trajectory generation utilities for GEN3C 3D cache conditioning.

Functions

fastvideo.pipelines.basic.gen3c.camera_utils.apply_transformation
apply_transformation(Bx4x4: Tensor, another_matrix: Tensor) -> Tensor

Apply batch transformation to a matrix.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def apply_transformation(Bx4x4: torch.Tensor, another_matrix: torch.Tensor) -> torch.Tensor:
    """
    Left-multiply a (possibly shared) matrix by a batch of transforms.

    Args:
        Bx4x4: (B, 4, 4) batch of transformation matrices.
        another_matrix: (4, 4) matrix shared across the batch, or a
            (B, 4, 4) batch of matrices.

    Returns:
        (B, 4, 4) batch of products.
    """
    rhs = another_matrix
    if rhs.dim() == 2:
        # Broadcast the single matrix across the whole batch.
        rhs = rhs.expand(Bx4x4.shape[0], -1, -1)
    return torch.matmul(Bx4x4, rhs)
fastvideo.pipelines.basic.gen3c.camera_utils.create_horizontal_trajectory
create_horizontal_trajectory(world_to_camera_matrix: Tensor, center_depth: float, positive: bool = True, n_steps: int = 13, distance: float = 0.1, device: str = 'cuda', axis: str = 'x', camera_rotation: str = 'center_facing') -> Tensor

Create a linear camera trajectory along a specified axis.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def create_horizontal_trajectory(
    world_to_camera_matrix: torch.Tensor,
    center_depth: float,
    positive: bool = True,
    n_steps: int = 13,
    distance: float = 0.1,
    device: str = "cuda",
    axis: str = "x",
    camera_rotation: str = "center_facing",
) -> torch.Tensor:
    """
    Create a linear camera trajectory along a specified axis.

    The camera translates in equal increments along ``axis``; the look-at
    target for each step is chosen according to ``camera_rotation``.
    """
    look_at_target = torch.tensor([0.0, 0.0, center_depth]).to(device)
    base_position = torch.tensor([0, 0, 0], device=device, dtype=torch.float32)
    axis_index = {"x": 0, "y": 1, "z": 2}

    views = []
    for step in range(n_steps):
        offset = step * distance * center_depth / n_steps * (1 if positive else -1)
        if axis not in axis_index:
            raise ValueError(f"Axis should be x, y or z, got {axis}")
        shift = torch.zeros(3, device=device)
        shift[axis_index[axis]] = offset

        camera_pos = base_position + shift
        if camera_rotation == "center_facing":
            target = look_at_target
        elif camera_rotation == "trajectory_aligned":
            target = look_at_target + shift * 2
        elif camera_rotation == "no_rotation":
            target = look_at_target + shift
        else:
            raise ValueError(f"camera_rotation should be center_facing, trajectory_aligned, "
                             f"or no_rotation, got {camera_rotation}")
        views.append(look_at_matrix(camera_pos, target))

    return apply_transformation(torch.stack(views), world_to_camera_matrix)
fastvideo.pipelines.basic.gen3c.camera_utils.create_spiral_trajectory
create_spiral_trajectory(world_to_camera_matrix: Tensor, center_depth: float, radius_x: float = 0.03, radius_y: float = 0.02, radius_z: float = 0.0, positive: bool = True, camera_rotation: str = 'center_facing', n_steps: int = 13, device: str = 'cuda', start_from_zero: bool = True, num_circles: int = 1) -> Tensor

Create a spiral/circular camera trajectory.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def create_spiral_trajectory(
    world_to_camera_matrix: torch.Tensor,
    center_depth: float,
    radius_x: float = 0.03,
    radius_y: float = 0.02,
    radius_z: float = 0.0,
    positive: bool = True,
    camera_rotation: str = "center_facing",
    n_steps: int = 13,
    device: str = "cuda",
    start_from_zero: bool = True,
    num_circles: int = 1,
) -> torch.Tensor:
    """
    Create a spiral/circular camera trajectory.

    The camera sweeps ``num_circles`` full revolutions over ``n_steps``
    positions; ``start_from_zero`` shifts the x component so the first
    position coincides with the origin.
    """
    look_at_target = torch.tensor([0.0, 0.0, center_depth]).to(device)
    origin = torch.tensor([0, 0, 0], device=device, dtype=torch.float32)
    full_angle = 2 * math.pi * num_circles

    views = []
    for step in range(n_steps):
        theta = full_angle * step / (n_steps - 1)
        if start_from_zero:
            x = radius_x * (math.cos(theta) - 1) * (1 if positive else -1) * center_depth
        else:
            x = radius_x * math.cos(theta) * center_depth
        y = radius_y * math.sin(theta) * center_depth
        z = radius_z * math.sin(theta) * center_depth

        camera_pos = origin + torch.tensor([x, y, z], device=device)
        pos = torch.tensor([x, y, z], device=device)
        if camera_rotation == "center_facing":
            target = look_at_target
        elif camera_rotation == "trajectory_aligned":
            target = look_at_target + pos * 2
        elif camera_rotation == "no_rotation":
            target = look_at_target + pos
        else:
            raise ValueError(f"camera_rotation should be center_facing, trajectory_aligned, "
                             f"or no_rotation, got {camera_rotation}")
        views.append(look_at_matrix(camera_pos, target))

    return apply_transformation(torch.stack(views), world_to_camera_matrix)
fastvideo.pipelines.basic.gen3c.camera_utils.generate_camera_trajectory
generate_camera_trajectory(trajectory_type: str, initial_w2c: Tensor, initial_intrinsics: Tensor, num_frames: int, movement_distance: float, camera_rotation: str = 'center_facing', center_depth: float = 1.0, device: str = 'cuda') -> tuple[Tensor, Tensor]

Generate camera trajectory for GEN3C video generation.

Parameters:

Name Type Description Default
trajectory_type str

One of "left", "right", "up", "down", "zoom_in", "zoom_out", "clockwise", "counterclockwise", or "none" (static camera).

required
initial_w2c Tensor

Initial world-to-camera matrix (4, 4).

required
initial_intrinsics Tensor

Camera intrinsics matrix (3, 3).

required
num_frames int

Number of frames in the trajectory.

required
movement_distance float

Distance factor for camera movement.

required
camera_rotation str

"center_facing", "no_rotation", or "trajectory_aligned".

'center_facing'
center_depth float

Depth of the scene center point.

1.0
device str

Computation device.

'cuda'

Returns:

Name Type Description
generated_w2cs Tensor

(1, num_frames, 4, 4) world-to-camera matrices.

generated_intrinsics Tensor

(1, num_frames, 3, 3) camera intrinsics.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def generate_camera_trajectory(
    trajectory_type: str,
    initial_w2c: torch.Tensor,
    initial_intrinsics: torch.Tensor,
    num_frames: int,
    movement_distance: float,
    camera_rotation: str = "center_facing",
    center_depth: float = 1.0,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build a per-frame camera path for GEN3C video generation.

    Args:
        trajectory_type: "left", "right", "up", "down", "zoom_in", "zoom_out",
            "clockwise", "counterclockwise", or "none" (static camera).
        initial_w2c: Starting world-to-camera matrix of shape (4, 4).
        initial_intrinsics: Camera intrinsics of shape (3, 3).
        num_frames: Number of frames along the path.
        movement_distance: Travel distance factor (also the spiral radius).
        camera_rotation: "center_facing", "no_rotation", or "trajectory_aligned".
        center_depth: Depth of the scene center point.
        device: Computation device.

    Returns:
        generated_w2cs: (1, num_frames, 4, 4) world-to-camera matrices.
        generated_intrinsics: (1, num_frames, 3, 3) camera intrinsics.

    Raises:
        ValueError: If ``trajectory_type`` is not one of the supported values.
    """
    # (positive, axis) for each straight-line trajectory type.
    linear_moves = {
        "left": (False, "x"),
        "right": (True, "x"),
        "up": (False, "y"),
        "down": (True, "y"),
        "zoom_in": (True, "z"),
        "zoom_out": (False, "z"),
    }

    if trajectory_type in ("clockwise", "counterclockwise"):
        w2c_sequence = create_spiral_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=trajectory_type == "clockwise",
            device=device,
            camera_rotation=camera_rotation,
            radius_x=movement_distance,
            radius_y=movement_distance,
        )
    elif trajectory_type == "none":
        # Static camera: reuse the initial pose for every frame (a view, no copy).
        w2c_sequence = initial_w2c.unsqueeze(0).expand(num_frames, -1, -1)
    elif trajectory_type in linear_moves:
        positive, axis = linear_moves[trajectory_type]
        w2c_sequence = create_horizontal_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=positive,
            axis=axis,
            distance=movement_distance,
            device=device,
            camera_rotation=camera_rotation,
        )
    else:
        raise ValueError(f"Unsupported trajectory type: {trajectory_type}")

    generated_w2cs = w2c_sequence.unsqueeze(0)  # (1, num_frames, 4, 4)

    # Intrinsics are constant over time: broadcast a single (3, 3) matrix to
    # every frame; otherwise assume the caller supplied a per-frame batch.
    if initial_intrinsics.dim() == 2:
        generated_intrinsics = initial_intrinsics[None, None].repeat(1, num_frames, 1, 1)
    else:
        generated_intrinsics = initial_intrinsics.unsqueeze(0)

    return generated_w2cs, generated_intrinsics
fastvideo.pipelines.basic.gen3c.camera_utils.look_at_matrix
look_at_matrix(camera_pos: Tensor, target: Tensor, invert_pos: bool = True) -> Tensor

Create a 4x4 look-at view matrix pointing camera toward target.

Source code in fastvideo/pipelines/basic/gen3c/camera_utils.py
def look_at_matrix(camera_pos: torch.Tensor, target: torch.Tensor, invert_pos: bool = True) -> torch.Tensor:
    """Create a 4x4 look-at view matrix pointing camera toward target.

    Args:
        camera_pos: (3,) camera position in world coordinates.
        target: (3,) world point the camera should face.
        invert_pos: If True, store ``-camera_pos`` as the translation
            (view-matrix convention); otherwise store ``camera_pos`` as-is.

    Returns:
        (4, 4) matrix whose top three rows are the right / up / forward
        basis vectors and whose last column holds the translation.

    NOTE(review): if ``target - camera_pos`` is parallel to the world up
    axis (0, 1, 0), ``right`` degenerates to zero and the result contains
    NaNs — callers are assumed to avoid straight-up/straight-down views.
    """
    forward = (target - camera_pos).float()
    forward = forward / torch.norm(forward)

    # World-space up. torch.linalg.cross (explicit dim) replaces the
    # deprecated dim-less torch.cross call.
    up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device)
    right = torch.linalg.cross(up, forward, dim=-1)
    right = right / torch.norm(right)
    # Re-orthogonalize up so {right, up, forward} is an orthonormal basis.
    up = torch.linalg.cross(forward, right, dim=-1)

    look_at = torch.eye(4, device=camera_pos.device)
    look_at[0, :3] = right
    look_at[1, :3] = up
    look_at[2, :3] = forward
    look_at[:3, 3] = (-camera_pos) if invert_pos else camera_pos

    return look_at

fastvideo.pipelines.basic.gen3c.depth_estimation

MoGe-based monocular depth estimation for GEN3C 3D cache conditioning.

Functions

fastvideo.pipelines.basic.gen3c.depth_estimation.load_moge_model
load_moge_model(model_name: str = 'Ruicheng/moge-vitl', device: str | device = 'cuda') -> MoGeModel

Load MoGe depth estimation model from HuggingFace.

Parameters:

Name Type Description Default
model_name str

HuggingFace model identifier.

'Ruicheng/moge-vitl'
device str | device

Device to load model on.

'cuda'

Returns:

Type Description
MoGeModel

Loaded MoGe model.

Source code in fastvideo/pipelines/basic/gen3c/depth_estimation.py
def load_moge_model(
    model_name: str = "Ruicheng/moge-vitl",
    device: str | torch.device = "cuda",
) -> MoGeModel:
    """Load a MoGe depth-estimation model from the HuggingFace hub.

    Args:
        model_name: HuggingFace model identifier.
        device: Device to load model on.

    Returns:
        The loaded MoGe model, moved to ``device`` and switched to eval mode.

    Raises:
        ImportError: If the optional ``moge`` package is not installed.
    """
    try:
        from moge.model.v1 import MoGeModel
    except ImportError as exc:
        # MoGe is an optional dependency; give actionable install guidance.
        message = ("MoGe is required for GEN3C 3D cache conditioning. "
                   "Install it with: pip install git+https://github.com/microsoft/MoGe.git. "
                   "If import fails with libGL.so.1, install system deps: "
                   "sudo apt-get install -y libgl1 libglib2.0-0 libsm6 libxext6 libxrender1")
        raise ImportError(message) from exc

    logger.info("Loading MoGe depth model: %s", model_name)
    moge = MoGeModel.from_pretrained(model_name).to(device)
    moge.eval()
    logger.info("MoGe model loaded successfully")
    return moge
fastvideo.pipelines.basic.gen3c.depth_estimation.predict_depth_from_path
predict_depth_from_path(image_path: str, target_h: int, target_w: int, device: device, moge_model: MoGeModel) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]

Predict depth, intrinsics, and mask from an image file path.

Parameters:

Name Type Description Default
image_path str

Path to input image (RGB or BGR, any format cv2 supports).

required
target_h int

Target height for output tensors.

required
target_w int

Target width for output tensors.

required
device device

Computation device.

required
moge_model MoGeModel

Loaded MoGe model.

required

Returns:

Name Type Description
image Tensor

(1, 1, 3, target_h, target_w) image tensor in [-1, 1].

depth Tensor

(1, 1, 1, target_h, target_w) depth map.

mask Tensor

(1, 1, 1, target_h, target_w) confidence mask.

w2c Tensor

(1, 1, 4, 4) world-to-camera matrix (identity).

intrinsics Tensor

(1, 1, 3, 3) camera intrinsics.

Source code in fastvideo/pipelines/basic/gen3c/depth_estimation.py
def predict_depth_from_path(
    image_path: str,
    target_h: int,
    target_w: int,
    device: torch.device,
    moge_model: MoGeModel,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Predict depth, intrinsics, and mask from an image file path.

    Args:
        image_path: Path to input image (any format cv2 can decode).
        target_h: Target height for output tensors.
        target_w: Target width for output tensors.
        device: Computation device.
        moge_model: Loaded MoGe model.

    Returns:
        image: (1, 1, 3, target_h, target_w) image tensor in [-1, 1].
        depth: (1, 1, 1, target_h, target_w) depth map.
        mask: (1, 1, 1, target_h, target_w) confidence mask.
        w2c: (1, 1, 4, 4) world-to-camera matrix (identity).
        intrinsics: (1, 1, 3, 3) camera intrinsics.

    Raises:
        FileNotFoundError: If ``image_path`` cannot be read.
    """
    # Imported lazily so the module loads without OpenCV installed.
    import cv2

    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Input image not found: {image_path}")
    # cv2 decodes to BGR; the rest of the pipeline expects RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    return _predict_depth_core(rgb, target_h, target_w, device, moge_model)
fastvideo.pipelines.basic.gen3c.depth_estimation.predict_depth_from_tensor
predict_depth_from_tensor(image_tensor: Tensor, moge_model: MoGeModel) -> tuple[Tensor, Tensor]

Predict depth and mask from an image tensor (for autoregressive generation).

Parameters:

Name Type Description Default
image_tensor Tensor

(C, H, W) image tensor in [0, 1] range.

required
moge_model MoGeModel

Loaded MoGe model.

required

Returns:

Name Type Description
depth Tensor

(1, 1, H, W) depth map.

mask Tensor

(1, 1, H, W) confidence mask.

Source code in fastvideo/pipelines/basic/gen3c/depth_estimation.py
def predict_depth_from_tensor(
    image_tensor: torch.Tensor,
    moge_model: MoGeModel,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Predict depth and mask from an image tensor (for autoregressive generation).

    Args:
        image_tensor: (C, H, W) image tensor in [0, 1] range.
        moge_model: Loaded MoGe model.

    Returns:
        depth: (1, 1, H, W) depth map.
        mask: (1, 1, H, W) confidence mask.
    """
    output = moge_model.infer(image_tensor)

    # (H, W) -> (1, 1, H, W); replace NaNs and cap depth at a finite ceiling.
    depth = output["depth"][None, None]
    depth = torch.nan_to_num(depth, nan=1e4).clamp(min=0, max=1e4)

    mask = output["mask"][None, None]
    # Pixels the model marked invalid are pushed to a fixed far depth.
    depth = torch.where(mask == 0, torch.tensor(1000.0, device=depth.device), depth)

    return depth, mask

fastvideo.pipelines.basic.gen3c.gen3c_pipeline

GEN3C video diffusion pipeline wiring.

Classes

fastvideo.pipelines.basic.gen3c.gen3c_pipeline.Gen3CPipeline
Gen3CPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

GEN3C Video Generation Pipeline.

This pipeline extends Cosmos with 3D cache support for camera-controlled video generation. When an input image is provided, it runs the full 3D cache conditioning pipeline (depth estimation -> point cloud -> camera trajectory -> forward warping -> VAE encoding).

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.

    Args:
        model_path: Path or identifier of the model to load.
        fastvideo_args: Runtime configuration (inference or training variant).
        required_config_modules: Module names this pipeline needs; when None,
            the subclass-level ``_required_config_modules`` attribute is used.
        loaded_modules: Pre-loaded modules to reuse instead of loading from
            ``model_path``; passed straight through to ``load_modules``.

    Raises:
        NotImplementedError: If neither the constructor argument nor the
            subclass provides ``_required_config_modules``.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    # Ordered stage list plus a name->stage index for lookup by name.
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    # Constructor argument overrides the class-level default, if given.
    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    # Distributed/model-parallel state must exist before modules are loaded.
    maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    # (wrapped in a profiler region so model-loading time is traceable).
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.basic.gen3c.gen3c_pipeline.Gen3CPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/gen3c/gen3c_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection.

    Registration order defines execution order: CFG policy, input validation,
    text encoding, 3D-cache conditioning, timestep prep, latent prep,
    denoising, and finally VAE decoding.
    """
    # Shared modules injected into several stages.
    vae = self.get_module("vae")
    transformer = self.get_module("transformer")
    scheduler = self.get_module("scheduler")

    self.add_stage(stage_name="cfg_policy_stage", stage=Gen3CCFGPolicyStage())
    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    text_stage = TextEncodingStage(
        text_encoders=[self.get_module("text_encoder")],
        tokenizers=[self.get_module("tokenizer")],
    )
    self.add_stage(stage_name="prompt_encoding_stage", stage=text_stage)

    self.add_stage(stage_name="conditioning_stage",
                   stage=Gen3CConditioningStage(vae=vae))
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(scheduler=scheduler))
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=Gen3CLatentPreparationStage(scheduler=scheduler,
                                                     transformer=transformer,
                                                     vae=vae))
    self.add_stage(stage_name="denoising_stage",
                   stage=Gen3CDenoisingStage(transformer=transformer,
                                             scheduler=scheduler))
    self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=vae))

Functions

fastvideo.pipelines.basic.gen3c.presets

GEN3C model family pipeline presets.

Classes