Skip to content

gen3c_stages

GEN3C-specific pipeline stages, allowing us to keep conditioning/latent/denoising logic separate from pipeline orchestration.

Classes

fastvideo.pipelines.stages.gen3c_stages.Gen3CCFGPolicyStage

Bases: PipelineStage

Explicitly control when GEN3C runs a conditional/unconditional pair (CFG).

Policies:

- legacy: enable CFG only when guidance_scale > 1.0 (the current FastVideo behavior).
- official_uncond_at_unity: additionally run CFG when guidance_scale == 1.0.

fastvideo.pipelines.stages.gen3c_stages.Gen3CConditioningStage

Gen3CConditioningStage(vae=None)

Bases: PipelineStage

3D cache conditioning stage for GEN3C.

This stage performs the core GEN3C innovation:

1. Loads the input image.
2. Predicts depth via MoGe.
3. Initializes a 3D point-cloud cache.
4. Generates a camera trajectory.
5. Renders warped frames from the cache at each target camera pose.
6. Stores the rendered warps on the batch for VAE encoding in the latent preparation stage.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, vae=None) -> None:
    """Create the GEN3C 3D-cache conditioning stage.

    Args:
        vae: Optional VAE handle, kept on the stage as ``_vae``.
    """
    super().__init__()
    self._vae = vae
    # The MoGe depth model is not built here; presumably it is created on
    # demand by ``_get_moge_model`` and released by ``_offload_moge``.
    self._moge_model: Any | None = None

Functions

fastvideo.pipelines.stages.gen3c_stages.Gen3CConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run 3D cache conditioning pipeline.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Run 3D cache conditioning pipeline.

    Loads the conditioning image, predicts depth with MoGe, seeds a 3D
    point-cloud cache, generates a camera trajectory and renders the cache at
    every target pose. Results are stored on ``batch`` as
    ``rendered_warp_images``, ``rendered_warp_masks``,
    ``input_image_conditioning`` and ``cache_3d``. When no ``image_path`` is
    available the batch is returned unchanged (zero conditioning downstream).

    Per-request camera settings are looked up on ``batch``, then
    ``batch.extra``, then ``pipeline_config``. Only ``None`` counts as "not
    provided", so explicit falsy values such as ``movement_distance=0.0`` are
    honoured instead of being replaced by the default.
    """

    def _first_set(*candidates):
        """Return the first candidate that is not None (None if all are)."""
        return next((value for value in candidates if value is not None), None)

    pipeline_config = fastvideo_args.pipeline_config
    device = get_local_torch_device()
    batch_extra = getattr(batch, "extra", {}) or {}

    # An empty path is as useless as a missing one, so truthiness is intended here.
    image_path = getattr(batch, 'image_path', None) or batch_extra.get("image_path")
    if image_path is None:
        logger.info("No image_path provided - skipping 3D cache conditioning "
                    "(will use zero conditioning)")
        return batch

    logger.info("Running 3D cache conditioning with image: %s", image_path)

    height = getattr(batch, 'height', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[0]
    width = getattr(batch, 'width', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[1]
    num_frames = getattr(batch, 'num_frames', None) or getattr(pipeline_config, 'num_frames', 121)
    if isinstance(num_frames, (list, tuple)):
        # Per-prompt frame counts: a single trajectory is rendered, so use the first
        # (mirrors the list handling in the latent preparation stage).
        num_frames = int(num_frames[0]) if num_frames else 121

    trajectory_type = _first_set(
        getattr(batch, 'trajectory_type', None),
        batch_extra.get("trajectory_type"),
        getattr(pipeline_config, 'default_trajectory_type', 'left'),
    )
    movement_distance = _first_set(
        getattr(batch, 'movement_distance', None),
        batch_extra.get("movement_distance"),
        getattr(pipeline_config, 'default_movement_distance', 0.3),
    )
    camera_rotation = _first_set(
        getattr(batch, 'camera_rotation', None),
        batch_extra.get("camera_rotation"),
        getattr(pipeline_config, 'default_camera_rotation', 'center_facing'),
    )

    frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2)
    noise_aug_strength = getattr(pipeline_config, 'noise_aug_strength', 0.0)
    filter_points_threshold = getattr(pipeline_config, 'filter_points_threshold', 0.05)

    moge_model_name = getattr(pipeline_config, 'moge_model_name', 'Ruicheng/moge-vitl')

    from fastvideo.pipelines.basic.gen3c.depth_estimation import (predict_depth_from_path)

    moge_model = self._get_moge_model(device, moge_model_name)
    try:
        (
            image_b1chw,
            depth_b11hw,
            _mask_b11hw,
            w2c_b144,
            intrinsics_b133,
        ) = predict_depth_from_path(image_path, height, width, device, moge_model)
    finally:
        # MoGe is not used after depth prediction; release it right away (also on
        # failure) so its memory is free for cache rendering.
        if getattr(pipeline_config, "offload_moge_after_depth", True):
            self._offload_moge()

    logger.info(
        "Depth prediction complete. Depth range: [%.3f, %.3f]",
        depth_b11hw.min().item(),
        depth_b11hw.max().item(),
    )

    from fastvideo.pipelines.basic.gen3c.cache_3d import Cache3DBuffer

    seed = getattr(batch, 'seed', None)
    if seed is None:
        seed = 42
    generator = torch.Generator(device=device).manual_seed(seed)

    cache = Cache3DBuffer(
        frame_buffer_max=frame_buffer_max,
        generator=generator,
        noise_aug_strength=noise_aug_strength,
        input_image=image_b1chw[:, 0].clone(),
        input_depth=depth_b11hw[:, 0],
        input_w2c=w2c_b144[:, 0],
        input_intrinsics=intrinsics_b133[:, 0],
        filter_points_threshold=filter_points_threshold,
    )

    logger.info("3D cache initialized with %d frame buffer(s)", frame_buffer_max)

    from fastvideo.pipelines.basic.gen3c.camera_utils import (generate_camera_trajectory)

    # The trajectory starts from the camera estimated for the input image.
    initial_w2c = w2c_b144[0, 0]
    initial_intrinsics = intrinsics_b133[0, 0]

    generated_w2cs, generated_intrinsics = generate_camera_trajectory(
        trajectory_type=trajectory_type,
        initial_w2c=initial_w2c,
        initial_intrinsics=initial_intrinsics,
        num_frames=num_frames,
        movement_distance=movement_distance,
        camera_rotation=camera_rotation,
        center_depth=1.0,
        device=device.type if isinstance(device, torch.device) else device,
    )

    logger.info(
        "Camera trajectory generated: type=%s, frames=%d, distance=%.3f",
        trajectory_type,
        num_frames,
        movement_distance,
    )

    rendered_warp_images, rendered_warp_masks = cache.render_cache(
        generated_w2cs[:, :num_frames],
        generated_intrinsics[:, :num_frames],
    )

    logger.info(
        "Cache rendered. Warped images shape: %s, non-zero mask ratio: %.3f",
        list(rendered_warp_images.shape),
        (rendered_warp_masks > 0).float().mean().item(),
    )

    batch.rendered_warp_images = rendered_warp_images.to(device)
    batch.rendered_warp_masks = rendered_warp_masks.to(device)
    # Input image with a singleton time axis inserted at dim 2.
    batch.input_image_conditioning = image_b1chw[:, 0].unsqueeze(2).contiguous().to(device)
    batch.cache_3d = cache

    return batch

fastvideo.pipelines.stages.gen3c_stages.Gen3CDenoisingStage

Gen3CDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for GEN3C models.

This stage extends the base denoising stage with support for:

- condition_video_input_mask: binary mask indicating the conditioning frames.
- condition_video_pose: VAE-encoded 3D cache buffers.
- condition_video_augment_sigma: noise augmentation sigma.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    """Create the GEN3C denoising stage.

    Args:
        transformer: Forwarded unchanged to ``DenoisingStage``.
        scheduler: Forwarded unchanged to ``DenoisingStage``.
        pipeline: Optional owning pipeline, forwarded unchanged.
    """
    # Pure pass-through: this initializer adds no GEN3C-specific state.
    super().__init__(transformer, scheduler, pipeline)

fastvideo.pipelines.stages.gen3c_stages.Gen3CLatentPreparationStage

Gen3CLatentPreparationStage(scheduler, transformer, vae)

Bases: LatentPreparationStage

Latent preparation stage for GEN3C.

This stage prepares latents and encodes 3D cache buffers through the VAE. If rendered warped frames are available on the batch (from Gen3CConditioningStage), they are VAE-encoded to produce real conditioning; otherwise the stage falls back to zero conditioning.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def __init__(self, scheduler, transformer, vae) -> None:
    """Create the GEN3C latent preparation stage.

    Args:
        scheduler: Forwarded to ``LatentPreparationStage``.
        transformer: Forwarded to ``LatentPreparationStage``.
        vae: VAE used to encode rendered warps and the input image; stored
            as ``self.vae``.
    """
    # The base initializer is called without the VAE, so it is attached here.
    super().__init__(scheduler, transformer)
    self.vae = vae

Functions

fastvideo.pipelines.stages.gen3c_stages.Gen3CLatentPreparationStage.encode_warped_frames
encode_warped_frames(condition_state: Tensor, condition_state_mask: Tensor, vae: Any, frame_buffer_max: int, dtype: dtype) -> Tensor

Encode rendered 3D cache buffers through VAE.

Parameters:

Name Type Description Default
condition_state Tensor

(B, T, N, 3, H, W) rendered RGB images in [-1, 1].

required
condition_state_mask Tensor

(B, T, N, 1, H, W) rendered masks in [0, 1].

required
vae Any

VAE encoder.

required
frame_buffer_max int

Maximum number of buffers.

required
dtype dtype

Target dtype.

required

Returns:

Name Type Description
latent_condition Tensor

(B, buffer_channels, T_latent, H_latent, W_latent)

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def encode_warped_frames(
    self,
    condition_state: torch.Tensor,
    condition_state_mask: torch.Tensor,
    vae: Any,
    frame_buffer_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Encode rendered 3D cache buffers through VAE.

    For each buffer, the RGB warp and its mask (rescaled from [0, 1] to
    [-1, 1] and tiled to three channels) are encoded in one VAE call by
    stacking them along the batch axis. The two latents are appended in
    (warp, mask) order. Buffers missing up to ``frame_buffer_max`` are padded
    with zero latents, and everything is concatenated along the channel axis.

    Args:
        condition_state: (B, T, N, 3, H, W) rendered RGB images in [-1, 1].
        condition_state_mask: (B, T, N, 1, H, W) rendered masks in [0, 1].
        vae: VAE encoder.
        frame_buffer_max: Maximum number of buffers.
        dtype: Target dtype.

    Returns:
        latent_condition: (B, buffer_channels, T_latent, H_latent, W_latent)

    Raises:
        ValueError: if ``condition_state`` is not 6-D, holds no buffers
            (N == 0, so the padding latent shape is unknown), or holds more
            than ``frame_buffer_max`` buffers (the output would have more
            channels than the model expects).
    """
    if condition_state.dim() != 6:
        raise ValueError(
            f"condition_state must be 6-D (B, T, N, 3, H, W), got shape {tuple(condition_state.shape)}")
    num_buffers = condition_state.shape[2]
    if num_buffers == 0:
        raise ValueError("condition_state contains no buffers (N == 0)")
    if num_buffers > frame_buffer_max:
        raise ValueError(f"condition_state has {num_buffers} buffers but frame_buffer_max is {frame_buffer_max}")

    # Map masks to the VAE's [-1, 1] input range and tile to RGB.
    condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)

    latent_condition = []
    for i in range(num_buffers):
        # (B, T, C, H, W) -> (B, C, T, H, W) before encoding.
        img_input = condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
        mask_input = condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
        # One VAE call for warp + mask, split back apart along the batch axis.
        batched_input = torch.cat([img_input, mask_input], dim=0)
        batched_latent = self._retrieve_latents(vae.encode(batched_input)).contiguous()
        current_video_latent, current_mask_latent = batched_latent.chunk(2, dim=0)

        latent_condition.append(current_video_latent)
        latent_condition.append(current_mask_latent)

    num_missing = frame_buffer_max - num_buffers
    if num_missing > 0:
        # torch.cat copies its inputs, so one zero tensor per kind can be reused.
        zero_video = torch.zeros_like(current_video_latent)
        zero_mask = torch.zeros_like(current_mask_latent)
        latent_condition.extend([zero_video, zero_mask] * num_missing)

    return torch.cat(latent_condition, dim=1)
fastvideo.pipelines.stages.gen3c_stages.Gen3CLatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare latents and encode 3D cache buffers.

Source code in fastvideo/pipelines/stages/gen3c_stages.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Prepare latents and encode 3D cache buffers.

    Samples the initial noise latents and populates the GEN3C conditioning
    tensors on ``batch``:

    - ``condition_video_pose``: the VAE-encoded warp/mask buffers when both
      ``batch.rendered_warp_images`` and ``batch.rendered_warp_masks`` are
      set. Otherwise a zero tensor with ``frame_buffer_max * 32`` channels.
    - ``conditioning_latents``: the VAE latent of the input image, written
      into the leading latent frames (``None`` in the zero fallback).
    - ``condition_video_input_mask``: marks latent frame 0 when real warps are
      available; all zeros otherwise.
    - ``cond_indicator`` / ``cond_mask`` (plus ``uncond_*`` copies when CFG
      is enabled): always mark latent frame 0 as the conditioning frame.

    Height/width fall back to ``pipeline_config.video_resolution`` (default
    720x1280) when the batch does not carry a value, including an explicit
    ``None``. This matches the conditioning stage that rendered the warps.
    """
    pipeline_config = fastvideo_args.pipeline_config
    device = get_local_torch_device()

    # Effective batch size = number of prompts x videos per prompt.
    if isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]
    batch_size *= batch.num_videos_per_prompt

    num_channels_latents = getattr(self.transformer, 'num_channels_latents', 16)

    fallback_num_frames = getattr(pipeline_config, 'num_frames', 121)
    if not isinstance(fallback_num_frames, int):
        fallback_num_frames = 121

    num_frames_raw = getattr(batch, 'num_frames', None)
    if isinstance(num_frames_raw, list):
        num_frames = int(num_frames_raw[0]) if num_frames_raw else fallback_num_frames
    elif isinstance(num_frames_raw, int):
        num_frames = num_frames_raw
    else:
        # None or an unsupported type.
        num_frames = fallback_num_frames

    vae_arch_config = pipeline_config.vae_config.arch_config
    if hasattr(self.vae, "get_latent_num_frames"):
        latent_frames = int(self.vae.get_latent_num_frames(num_frames))
    else:
        temporal_ratio = getattr(vae_arch_config, "temporal_compression_ratio", 4)
        latent_frames = int((num_frames - 1) // temporal_ratio + 1)

    # A getattr default alone would let an explicit None through and crash on
    # the integer division below.
    height = getattr(batch, 'height', None)
    width = getattr(batch, 'width', None)
    if not height or not width:
        default_resolution = getattr(pipeline_config, 'video_resolution', None) or (720, 1280)
        height = height or default_resolution[0]
        width = width or default_resolution[1]

    spatial_ratio = getattr(vae_arch_config, "spatial_compression_ratio", 8)
    latent_height = height // spatial_ratio
    latent_width = width // spatial_ratio

    latent_shape = (batch_size, num_channels_latents, latent_frames, latent_height, latent_width)
    # Single-channel, per-latent-frame shape shared by all mask/indicator tensors.
    mask_shape = (batch_size, 1, latent_frames, latent_height, latent_width)

    generator = getattr(batch, "generator", None)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(f"Expected {batch_size} generators, got {len(generator)}.")

    latents = randn_tensor(
        latent_shape,
        generator=generator,
        device=device,
        dtype=torch.float32,
    )

    if hasattr(self.scheduler, 'init_noise_sigma'):
        latents = latents * self.scheduler.init_noise_sigma

    batch.latents = latents
    batch.batch_size = batch_size
    batch.height = height
    batch.width = width
    batch.latent_height = latent_height
    batch.latent_width = latent_width
    batch.latent_frames = latent_frames
    batch.raw_latent_shape = latents.shape

    frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2)
    # encode_warped_frames emits a warp latent and a mask latent per buffer;
    # 32 presumably = 2 x 16 VAE latent channels.
    channels_per_buffer = 32
    buffer_channels = frame_buffer_max * channels_per_buffer

    rendered_warp_images = getattr(batch, 'rendered_warp_images', None)
    rendered_warp_masks = getattr(batch, 'rendered_warp_masks', None)

    if rendered_warp_images is not None and rendered_warp_masks is not None:
        logger.info(
            "Encoding rendered warped frames through VAE (%d buffers)...",
            rendered_warp_images.shape[2],
        )

        self.vae = self.vae.to(device)
        try:
            # Unwrap a wrapper module (if any) to read the parameter dtype.
            vae_dtype = next(getattr(self.vae, 'module', self.vae).parameters()).dtype

            condition_video_pose = self.encode_warped_frames(
                rendered_warp_images,
                rendered_warp_masks,
                self.vae,
                frame_buffer_max,
                vae_dtype,
            )
            batch.condition_video_pose = condition_video_pose.to(device)

            logger.info(
                "condition_video_pose encoded. Shape: %s, non-zero: %.4f",
                list(batch.condition_video_pose.shape),
                (batch.condition_video_pose != 0).float().mean().item(),
            )

            # Prefer the clean input image; otherwise fall back to the first
            # rendered warp of the first buffer.
            source_image = getattr(batch, "input_image_conditioning", None)
            if source_image is None:
                source_image = rendered_warp_images[:, 0, 0].unsqueeze(2)
            first_frame = source_image.to(device=device, dtype=vae_dtype)
            first_latent = self._retrieve_latents(self.vae.encode(first_frame))
        finally:
            # Release the VAE even when encoding fails.
            if fastvideo_args.vae_cpu_offload:
                self.vae.to("cpu")

        conditioning_latents = torch.zeros(latent_shape, device=device, dtype=first_latent.dtype)
        conditioning_latents[:, :, :first_latent.shape[2]] = first_latent
        batch.conditioning_latents = conditioning_latents

        batch.condition_video_input_mask = torch.zeros(mask_shape, device=device, dtype=torch.float32)
        batch.condition_video_input_mask[:, :, 0] = 1.0
    else:
        logger.info("No rendered warps available - using zero conditioning")
        batch.condition_video_pose = torch.zeros(
            batch_size,
            buffer_channels,
            latent_frames,
            latent_height,
            latent_width,
            device=device,
            dtype=torch.float32,
        )
        batch.condition_video_input_mask = torch.zeros(mask_shape, device=device, dtype=torch.float32)
        batch.conditioning_latents = None

    batch.condition_video_augment_sigma = torch.zeros(batch_size, device=device, dtype=torch.float32)

    # Latent frame 0 is the conditioning frame.
    batch.cond_indicator = torch.zeros(mask_shape, device=device, dtype=torch.float32)
    batch.cond_indicator[:, :, 0] = 1.0
    # cond_mask is 1 on conditioning frames and 0 elsewhere, i.e. identical to
    # the indicator; kept as a separate tensor for downstream consumers.
    batch.cond_mask = batch.cond_indicator.clone()

    if batch.do_classifier_free_guidance:
        batch.uncond_indicator = batch.cond_indicator.clone()
        batch.uncond_mask = batch.cond_mask.clone()
    else:
        batch.uncond_indicator = None
        batch.uncond_mask = None

    return batch

Functions