
stages

LTX-2 family pipeline stages.

Classes

fastvideo.pipelines.basic.ltx2.stages.LTX2AudioDecodingStage

LTX2AudioDecodingStage(audio_decoder, vocoder)

Bases: PipelineStage

Decode LTX-2 audio latents into a waveform.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_audio_decoding.py
def __init__(self, audio_decoder, vocoder) -> None:
    super().__init__()
    self.audio_decoder = audio_decoder
    self.vocoder = vocoder

fastvideo.pipelines.basic.ltx2.stages.LTX2DenoisingStage

LTX2DenoisingStage(transformer, *, sigmas_override: list[float] | None = None, num_inference_steps_override: int | None = None, force_guidance_scale: float | None = None, initial_audio_latents_key: str | None = 'ltx2_audio_latents')

Bases: PipelineStage

Run the LTX-2 denoising loop over the sigma schedule.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_denoising.py
def __init__(
    self,
    transformer,
    *,
    sigmas_override: list[float] | None = None,
    num_inference_steps_override: int | None = None,
    force_guidance_scale: float | None = None,
    initial_audio_latents_key: str | None = "ltx2_audio_latents",
) -> None:
    super().__init__()
    self.transformer = transformer
    self.sigmas_override = sigmas_override
    self.num_inference_steps_override = num_inference_steps_override
    self.force_guidance_scale = force_guidance_scale
    self.initial_audio_latents_key = initial_audio_latents_key
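
The docstring says only that the stage runs the denoising loop over a sigma schedule. The sketch below shows what such a loop can look like as a first-order (Euler) sampler over a descending sigma list. It assumes a rectified-flow parameterisation in which the model predicts a velocity v = noise - x0. That parameterisation, the helper names and the toy oracle model are all assumptions, not the stage's actual code. Presumably sigmas_override would supply the `sigmas` list used here.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def euler_flow_sample(
    x: Vector,
    sigmas: Sequence[float],
    velocity_fn: Callable[[Vector, float], Vector],
) -> Vector:
    """Integrate x from sigmas[0] down to sigmas[-1] with Euler steps.

    Assumes (hypothetically) a rectified-flow path
    x_sigma = (1 - sigma) * x0 + sigma * noise, so dx/dsigma = noise - x0.
    """
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v = velocity_fn(x, sigma)
        dt = sigma_next - sigma  # negative: sigma decreases toward 0
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Hypothetical toy "oracle" model that knows the clean target x0.
x0 = [0.5, -1.0, 2.0]
noise = [1.0, 0.3, -0.7]


def oracle_velocity(x: Vector, sigma: float) -> Vector:
    # On the straight path x - x0 = sigma * (noise - x0), so v = (x - x0) / sigma.
    return [(xi - ti) / sigma for xi, ti in zip(x, x0)]


sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
result = euler_flow_sample(list(noise), sigmas, oracle_velocity)  # starts at pure noise
```

With an exact velocity, every Euler step stays on the straight path, so `result` recovers `x0`. A real transformer's prediction is only approximate, which is why the number and placement of sigmas matter.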

fastvideo.pipelines.basic.ltx2.stages.LTX2LatentPreparationStage

LTX2LatentPreparationStage(transformer, vae)

Bases: PipelineStage

Prepare initial LTX-2 latents without relying on a diffusers scheduler.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_latent_preparation.py
def __init__(self, transformer, vae) -> None:
    super().__init__()
    self.transformer = transformer
    self.vae = vae
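
Preparing initial latents mainly means deriving the latent shape from the requested video size and drawing Gaussian noise of that shape. The following is a minimal sketch of the shape arithmetic. It assumes LTX-style VAE factors of 32x spatial and 8x temporal compression with 128 latent channels, plus an 8k + 1 frame-count constraint. None of these values are read from this stage, and `ltx2_latent_shape` is a hypothetical helper.

```python
from typing import Tuple

# Assumed LTX VAE compression factors (not taken from this page).
SPATIAL_COMPRESSION = 32
TEMPORAL_COMPRESSION = 8
LATENT_CHANNELS = 128


def ltx2_latent_shape(
    batch_size: int, num_frames: int, height: int, width: int
) -> Tuple[int, int, int, int, int]:
    """Return a [B, C, T, H, W] latent shape for a pixel-space request."""
    if (num_frames - 1) % TEMPORAL_COMPRESSION != 0:
        raise ValueError("num_frames must be 8k + 1 under the assumed VAE")
    if height % SPATIAL_COMPRESSION or width % SPATIAL_COMPRESSION:
        raise ValueError("height and width must be multiples of 32")
    # The first frame is assumed to get its own latent frame (causal VAE).
    latent_frames = (num_frames - 1) // TEMPORAL_COMPRESSION + 1
    return (
        batch_size,
        LATENT_CHANNELS,
        latent_frames,
        height // SPATIAL_COMPRESSION,
        width // SPATIAL_COMPRESSION,
    )
```

For example, a 121-frame 512x768 request maps to a `(1, 128, 16, 16, 24)` latent under these assumptions.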

fastvideo.pipelines.basic.ltx2.stages.LTX2RefineInitStage

Bases: PipelineStage

Switch the request to half resolution before the stage-1 denoise.

Stashes the original target resolution on batch.extra so LTX2UpsampleStage can recover it after stage 1 runs. When the refine path is disabled, the stage is a no-op.
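
The halve-then-restore handshake between this stage and LTX2UpsampleStage can be sketched as follows. `FakeBatch` and the `TARGET_RES_KEY` key name are hypothetical stand-ins. Only the idea of stashing the original resolution on batch.extra, and the no-op when refinement is disabled, come from the docstring.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

# Hypothetical key; the real stage's key name is not documented here.
TARGET_RES_KEY = "ltx2_refine_target_hw"


@dataclass
class FakeBatch:
    height: int
    width: int
    extra: Dict[str, Any] = field(default_factory=dict)


def refine_init(batch: FakeBatch, enabled: bool) -> FakeBatch:
    """Halve the resolution for stage 1, remembering the full-res target."""
    if not enabled:
        return batch  # refine path disabled: no-op
    batch.extra[TARGET_RES_KEY] = (batch.height, batch.width)
    batch.height //= 2
    batch.width //= 2
    return batch


def recover_target(batch: FakeBatch) -> FakeBatch:
    """What an upsample step might do: restore the stashed full resolution."""
    batch.height, batch.width = batch.extra.pop(TARGET_RES_KEY)
    return batch
```

A 1024x1536 request runs stage 1 at 512x768, and the upsample side restores 1024x1536 from `extra`.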

fastvideo.pipelines.basic.ltx2.stages.LTX2RefineLoRAStage

LTX2RefineLoRAStage(*, pipeline: Any, lora_path: str | None, lora_nickname: str = 'ltx2_refine')

Bases: PipelineStage

Apply a refinement-specific LoRA before stage-2 denoising.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_refine.py
def __init__(
    self,
    *,
    pipeline: Any,
    lora_path: str | None,
    lora_nickname: str = "ltx2_refine",
) -> None:
    super().__init__()
    self._pipeline_ref = (weakref.ref(pipeline) if pipeline is not None else None)
    self._lora_path = lora_path
    self._lora_nickname = lora_nickname
    self._applied = False
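
The constructor keeps only a weak reference to the pipeline plus an `_applied` flag. A plausible reading is that the stage applies the LoRA at most once and does nothing when lora_path is None or the pipeline is gone. The sketch below illustrates that apply-once pattern. `DummyPipeline` and its `set_lora` method are hypothetical, not FastVideo's API.

```python
import weakref
from typing import Any, List, Optional, Tuple


class DummyPipeline:
    """Hypothetical stand-in pipeline that records LoRA requests."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str]] = []

    def set_lora(self, nickname: str, path: str) -> None:  # hypothetical API
        self.calls.append((nickname, path))


class ApplyLoRAOnce:
    def __init__(self, pipeline: Any, lora_path: Optional[str],
                 lora_nickname: str = "ltx2_refine") -> None:
        # A weak reference avoids a pipeline -> stage -> pipeline reference cycle.
        self._pipeline_ref = weakref.ref(pipeline) if pipeline is not None else None
        self._lora_path = lora_path
        self._lora_nickname = lora_nickname
        self._applied = False

    def __call__(self) -> bool:
        """Apply the LoRA on the first call; return True only if applied now."""
        if self._applied or self._lora_path is None or self._pipeline_ref is None:
            return False
        pipeline = self._pipeline_ref()
        if pipeline is None:  # pipeline was garbage-collected
            return False
        pipeline.set_lora(self._lora_nickname, self._lora_path)
        self._applied = True
        return True
```

Repeated calls, for example one per request, then cost nothing after the first successful application.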

fastvideo.pipelines.basic.ltx2.stages.LTX2TextEncodingStage

LTX2TextEncodingStage(text_encoders, tokenizers)

Bases: TextEncodingStage

LTX2 text encoding stage with sequence parallelism support.

When SP is enabled (sp_world_size > 1), only rank 0 runs the text encoder and broadcasts embeddings to other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously, which can cause text encoding to take 100+ seconds instead of ~5 seconds.

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        text_encoders: Text encoder model(s) used to embed the prompt.
        tokenizers: Tokenizer(s) matching ``text_encoders``.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
    self._last_audio_embeds: list[torch.Tensor] | None = None
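
The rank-0-encodes-then-broadcasts pattern can be simulated without a process group. In the hypothetical sketch below, a shared `mailbox` dict stands in for the broadcast collective: rank 0 writes to it and the other ranks read from it. In a real `torch.distributed` setup, a collective such as `torch.distributed.broadcast` over the SP group would play that role. The non-zero ranks would then need the embedding shape in advance, or receive it first, to allocate a receive buffer.

```python
from typing import Callable, Dict, List

Embedding = List[float]


def encode_on_rank0(
    rank: int,
    prompt: str,
    encode_fn: Callable[[str], Embedding],
    mailbox: Dict[str, Embedding],
) -> Embedding:
    """Only rank 0 runs the expensive encoder; other ranks reuse its output.

    `mailbox` is a hypothetical stand-in for a broadcast and assumes
    rank 0 finishes before the other ranks read.
    """
    if rank == 0:
        mailbox[prompt] = encode_fn(prompt)
    return list(mailbox[prompt])


calls: List[str] = []


def fake_encoder(prompt: str) -> Embedding:
    calls.append(prompt)  # record how often the encoder actually runs
    return [float(len(prompt)), 1.0]


shared: Dict[str, Embedding] = {}
outputs = [encode_on_rank0(r, "a cat", fake_encoder, shared) for r in range(4)]
```

Four simulated ranks end up with identical embeddings while the encoder runs only once, which is the point of avoiding concurrent Gemma loads.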

fastvideo.pipelines.basic.ltx2.stages.LTX2UpsampleStage

LTX2UpsampleStage(*, upsampler: Any, vae: Any, transformer: Any | None = None, sigmas: list[float] | None = None, add_noise: bool = True)

Bases: PipelineStage

Upsample stage-1 latents to stage-2 resolution and add refine noise.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_refine.py
def __init__(
    self,
    *,
    upsampler: Any,
    vae: Any,
    transformer: Any | None = None,
    sigmas: list[float] | None = None,
    add_noise: bool = True,
) -> None:
    super().__init__()
    self.upsampler = upsampler
    self.vae = vae
    self.transformer = transformer
    self.sigmas = sigmas or STAGE_2_DISTILLED_SIGMA_VALUES
    self.add_noise = add_noise
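
Adding refine noise after upsampling moves the upsampled latent partway back toward noise, so the stage-2 denoiser starts at the first stage-2 sigma rather than at pure noise. The following is a minimal sketch. It assumes a rectified-flow blend and that the first entry of `self.sigmas` sets the noise level. The helper name is hypothetical, and the actual STAGE_2_DISTILLED_SIGMA_VALUES are not listed on this page.

```python
import random
from typing import List, Sequence


def renoise_for_refine(
    upsampled: List[float],
    sigmas: Sequence[float],
    seed: int = 0,
    add_noise: bool = True,
) -> List[float]:
    """Return (1 - s) * upsampled + s * noise with s = sigmas[0] (assumed)."""
    if not add_noise:
        return list(upsampled)
    s = float(sigmas[0])
    rng = random.Random(seed)  # seeded so the refine noise is reproducible
    return [(1.0 - s) * x + s * rng.gauss(0.0, 1.0) for x in upsampled]
```

This blend is the same one apply_ltx2_gaussian_noiser computes when denoise_mask is 1 everywhere and noise_scale is s.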

Modules

fastvideo.pipelines.basic.ltx2.stages.ltx2_audio_decoding

Audio decoding stage for LTX-2 pipelines.

Classes

fastvideo.pipelines.basic.ltx2.stages.ltx2_audio_decoding.LTX2AudioDecodingStage

LTX2AudioDecodingStage(audio_decoder, vocoder)

Bases: PipelineStage

Decode LTX-2 audio latents into a waveform.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_audio_decoding.py
def __init__(self, audio_decoder, vocoder) -> None:
    super().__init__()
    self.audio_decoder = audio_decoder
    self.vocoder = vocoder


fastvideo.pipelines.basic.ltx2.stages.ltx2_denoising

LTX-2 denoising stage using the native sigma schedule.

Classes

fastvideo.pipelines.basic.ltx2.stages.ltx2_denoising.LTX2DenoisingStage

LTX2DenoisingStage(transformer, *, sigmas_override: list[float] | None = None, num_inference_steps_override: int | None = None, force_guidance_scale: float | None = None, initial_audio_latents_key: str | None = 'ltx2_audio_latents')

Bases: PipelineStage

Run the LTX-2 denoising loop over the sigma schedule.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_denoising.py
def __init__(
    self,
    transformer,
    *,
    sigmas_override: list[float] | None = None,
    num_inference_steps_override: int | None = None,
    force_guidance_scale: float | None = None,
    initial_audio_latents_key: str | None = "ltx2_audio_latents",
) -> None:
    super().__init__()
    self.transformer = transformer
    self.sigmas_override = sigmas_override
    self.num_inference_steps_override = num_inference_steps_override
    self.force_guidance_scale = force_guidance_scale
    self.initial_audio_latents_key = initial_audio_latents_key


fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning

FastVideo-native LTX-2 image-to-video conditioning helpers.

Public-side port of FastVideo-internal/.../ltx2_i2v_conditioning.py. The module builds a (clean_latent, denoise_mask) pair that the LTX-2 latent-preparation and denoising stages use to mix the conditioning into the noise tensor, so a generated segment can be anchored to:

  • one or more conditioning images at specific latent frame indices (ltx2_images),
  • a multi-frame conditioning video clip jointly VAE-encoded (ltx2_video_conditions),
  • a continuation latent carried over from the previous segment (ltx2_conditioning_latent_stage1 / _stage2).

The streaming server's session controller populates the continuation latents between segments; the legacy from_pretrained path passes ltx2_images and ltx2_image_crf through a compatibility translation layer.

Classes

fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning.LTX2ImageConditioningState (dataclass)

LTX2ImageConditioningState(clean_latent: Tensor, denoise_mask: Tensor, images: list[tuple[str, int, float]], latent_conditioned: bool = False)

Result of building image / continuation conditioning.

Functions:

fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning.apply_ltx2_gaussian_noiser

apply_ltx2_gaussian_noiser(*, noise: Tensor, clean_latent: Tensor, denoise_mask: Tensor, noise_scale: float = 1.0) -> Tensor

Blend noise into clean_latent, weighting the noise per element by denoise_mask * noise_scale.

Values close to 1 in the mask produce near-pure noise (used for a fresh stage-2 latent); values near 0 leave the clean latent untouched (used in conditioning regions).

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_image_conditioning.py
def apply_ltx2_gaussian_noiser(
    *,
    noise: torch.Tensor,
    clean_latent: torch.Tensor,
    denoise_mask: torch.Tensor,
    noise_scale: float = 1.0,
) -> torch.Tensor:
    """Mix ``noise`` into ``clean_latent`` along ``denoise_mask`` * scale.

    Values close to 1 in the mask produce near-pure noise (used in a
    fresh stage-2 latent), values near 0 leave the clean latent
    untouched (used in conditioning regions).
    """
    scaled_mask = denoise_mask * float(noise_scale)
    return (noise * scaled_mask + clean_latent * (1.0 - scaled_mask)).to(noise.dtype)
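
Here is a per-element worked example of this blend, with plain Python lists standing in for tensors. `gaussian_noiser` is a hypothetical helper that mirrors the formula in apply_ltx2_gaussian_noiser. It is not the library function.

```python
from typing import List


def gaussian_noiser(noise: List[float], clean: List[float],
                    mask: List[float], noise_scale: float = 1.0) -> List[float]:
    # out = noise * (mask * scale) + clean * (1 - mask * scale), element-wise
    out = []
    for n, c, m in zip(noise, clean, mask):
        w = m * noise_scale
        out.append(n * w + c * (1.0 - w))
    return out


noise = [4.0, 4.0, 4.0]
clean = [0.0, 0.0, 2.0]
mask = [1.0, 0.0, 0.5]  # fresh region, fully conditioned region, partial strength
full = gaussian_noiser(noise, clean, mask)                    # [4.0, 0.0, 3.0]
halved = gaussian_noiser(noise, clean, mask, noise_scale=0.5)  # [2.0, 0.0, 2.5]
```

The masked-out element stays exactly at its clean value whatever the noise_scale, and noise_scale below 1 pulls even the fresh region partway toward the clean latent.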

fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning.build_ltx2_image_conditioning

build_ltx2_image_conditioning(*, batch: ForwardBatch, latents: Tensor, vae: Module, height: int, width: int, image_crf: float | None = None, base_clean_latent: Tensor | None = None) -> LTX2ImageConditioningState | None

Build the (clean_latent, denoise_mask) state for the next segment.

Returns None for plain T2V (no images, no continuation, no video conditions). The denoise mask is 1 where the model should sample fresh and 0 where it should preserve the conditioning latent exactly. A base_clean_latent of None means stage 1 (a fresh half-resolution latent); a non-None base_clean_latent means stage 2 (the already-upsampled latent from the upsample stage).

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_image_conditioning.py
def build_ltx2_image_conditioning(
    *,
    batch: ForwardBatch,
    latents: torch.Tensor,
    vae: torch.nn.Module,
    height: int,
    width: int,
    image_crf: float | None = None,
    base_clean_latent: torch.Tensor | None = None,
) -> LTX2ImageConditioningState | None:
    """Build the (clean_latent, denoise_mask) state for the next segment.

    Returns ``None`` for plain T2V (no images, no continuation, no
    video conditions). The denoise mask is 1 where the model should
    sample fresh, 0 where it should preserve the conditioning latent
    exactly. ``base_clean_latent is None`` corresponds to stage 1
    (fresh half-res latent); ``base_clean_latent`` set means stage 2
    (already-upsampled latent from the upsampler stage).
    """
    images = resolve_ltx2_images(batch)
    conditioning_latent_stage1 = getattr(batch, "ltx2_conditioning_latent_stage1", None)
    conditioning_latent_stage2 = getattr(batch, "ltx2_conditioning_latent_stage2", None)
    is_stage1_conditioning = base_clean_latent is None
    is_stage2_conditioning = not is_stage1_conditioning
    has_latent_conditioning = False
    continuation_latent_to_insert: torch.Tensor | None = None
    if (conditioning_latent_stage1 is not None and not torch.is_tensor(conditioning_latent_stage1)):
        raise TypeError("LTX-2 stage1 continuation latent conditioning "
                        "expects a torch.Tensor.")
    if (conditioning_latent_stage2 is not None and not torch.is_tensor(conditioning_latent_stage2)):
        raise TypeError("LTX-2 stage2 continuation latent conditioning "
                        "expects a torch.Tensor.")

    if (conditioning_latent_stage1 is None) != (conditioning_latent_stage2 is None):
        raise ValueError("LTX-2 continuation expects both stage1 and stage2 "
                         "conditioning latents (or neither for first round).")
    if is_stage1_conditioning and conditioning_latent_stage1 is not None:
        has_latent_conditioning = True
        continuation_latent_to_insert = conditioning_latent_stage1.to(
            device=latents.device,
            dtype=latents.dtype,
        )
    elif is_stage2_conditioning and conditioning_latent_stage2 is not None:
        has_latent_conditioning = True
        continuation_latent_to_insert = conditioning_latent_stage2.to(
            device=latents.device,
            dtype=latents.dtype,
        )

    video_conditions = getattr(batch, "ltx2_video_conditions", None) or []

    if not images and not has_latent_conditioning and not video_conditions:
        return None

    clean_latent = (torch.zeros_like(latents) if base_clean_latent is None else base_clean_latent.clone())

    denoise_mask = torch.ones(
        (
            latents.shape[0],
            1,
            latents.shape[2],
            latents.shape[3],
            latents.shape[4],
        ),
        dtype=torch.float32,
        device=latents.device,
    )

    if image_crf is None:
        image_crf = getattr(batch, "ltx2_image_crf", DEFAULT_LTX2_IMAGE_CRF)

    vae_param = next(vae.parameters(), None)
    encoder_dtype = (vae_param.dtype if vae_param is not None else latents.dtype)
    encoder_device = (vae_param.device if vae_param is not None else latents.device)
    cache: dict[tuple[str, int, int, float], torch.Tensor] = {}
    latent_conditioned = False

    if has_latent_conditioning:
        if continuation_latent_to_insert is None:
            raise RuntimeError("LTX-2 continuation latent conditioning state is invalid.")
        # NOTE: frame index and strength are intentionally hard-coded.
        # We always anchor the first frame of the next clip at full
        # strength to the previous clip's last latent.
        _insert_conditioning_latent(
            conditioning_latent=continuation_latent_to_insert,
            clean_latent=clean_latent,
            denoise_mask=denoise_mask,
            frame_idx=LTX2_CONTINUATION_TARGET_FRAME_IDX,
            strength=LTX2_CONTINUATION_STRENGTH,
            source_name="continuation",
        )
        latent_conditioned = True

    for image_path, frame_idx, strength in images:
        cache_key = (image_path, height, width, float(image_crf))
        image_latent = cache.get(cache_key)
        if image_latent is None:
            image_tensor = load_ltx2_conditioning_image(
                image_path=image_path,
                height=height,
                width=width,
                dtype=encoder_dtype,
                device=encoder_device,
                image_crf=float(image_crf),
            )
            image_latent = _extract_video_latent(vae, image_tensor).to(
                device=latents.device,
                dtype=latents.dtype,
            )
            cache[cache_key] = image_latent

        _insert_conditioning_latent(
            conditioning_latent=image_latent,
            clean_latent=clean_latent,
            denoise_mask=denoise_mask,
            frame_idx=frame_idx,
            strength=strength,
            source_name="image",
        )

    for frame_paths, frame_idx, strength in video_conditions:
        video_tensor = load_ltx2_conditioning_video_clip(
            frame_paths,
            height=height,
            width=width,
            dtype=encoder_dtype,
            device=encoder_device,
            image_crf=float(image_crf),
        )
        video_latent = _extract_video_latent(vae, video_tensor).to(
            device=latents.device,
            dtype=latents.dtype,
        )
        logger.info(
            "[LTX2] Video-clip condition: %d frames -> "
            "latent T=%d at frame_idx=%d strength=%.2f",
            len(frame_paths),
            video_latent.shape[2],
            frame_idx,
            strength,
        )
        _insert_conditioning_latent(
            conditioning_latent=video_latent,
            clean_latent=clean_latent,
            denoise_mask=denoise_mask,
            frame_idx=frame_idx,
            strength=strength,
            source_name="video_clip",
        )

    return LTX2ImageConditioningState(
        clean_latent=clean_latent,
        denoise_mask=denoise_mask,
        images=images,
        latent_conditioned=latent_conditioned,
    )
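
The function delegates the actual writes to `_insert_conditioning_latent`, whose body is not shown here. Suppose it writes a T-frame conditioning latent at latent frames [frame_idx, frame_idx + T) and sets the mask there to 1 - strength. That reading is an assumption, though it fits the docstring's convention that 0 preserves and 1 samples fresh. The per-frame effect on the denoise mask would then look like the sketch below, which tracks one value per latent frame and uses a hypothetical helper name.

```python
from typing import List


def insert_condition(mask: List[float], frame_idx: int, num_frames: int,
                     strength: float) -> List[float]:
    """Set mask[frame_idx : frame_idx + num_frames] to 1 - strength."""
    if frame_idx + num_frames > len(mask):
        raise ValueError("conditioning latent does not fit in the target")
    out = list(mask)
    for t in range(frame_idx, frame_idx + num_frames):
        out[t] = 1.0 - strength
    return out


mask = [1.0] * 6                           # 6 latent frames, all sampled fresh
mask = insert_condition(mask, 0, 1, 1.0)   # first-frame image at full strength
mask = insert_condition(mask, 3, 2, 0.75)  # 2-frame clip latent, partial strength
```

A full-strength anchor, such as the continuation latent at LTX2_CONTINUATION_TARGET_FRAME_IDX, would zero the mask at its frames so they are kept exactly.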

fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning.load_ltx2_conditioning_video_clip

load_ltx2_conditioning_video_clip(frame_paths: list[str], *, height: int, width: int, dtype: dtype, device: device, image_crf: float) -> Tensor

Load multiple frames and stack as [1, C, T, H, W] for joint VAE encoding so the resulting latent captures temporal/motion info.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_image_conditioning.py
def load_ltx2_conditioning_video_clip(
    frame_paths: list[str],
    *,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    image_crf: float,
) -> torch.Tensor:
    """Load multiple frames and stack as ``[1, C, T, H, W]`` for joint
    VAE encoding so the resulting latent captures temporal/motion info.
    """
    frame_tensors: list[torch.Tensor] = []
    for path in frame_paths:
        image = load_image(path)
        image_np = np.array(image)[..., :3]
        image_np = _preprocess_conditioning_image(image_np, image_crf=image_crf)
        t = torch.tensor(image_np, dtype=torch.float32, device=device)
        # _resize_and_center_crop returns [1, C, 1, H, W]
        t = _resize_and_center_crop(t, height, width)
        frame_tensors.append(t)
    # Concat along T dimension -> [1, C, T, H, W]
    video = torch.cat(frame_tensors, dim=2)
    return (video / 127.5 - 1.0).to(device=device, dtype=dtype)
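
The last line maps 8-bit pixel values from [0, 255] to [-1, 1], presumably the input range the VAE encoder expects. The concatenation before it grows only the T axis. The quick check below uses hypothetical helpers that mirror both pieces of arithmetic.

```python
from typing import List, Sequence, Tuple


def to_vae_range(pixels: Sequence[float]) -> List[float]:
    # Same affine map as `video / 127.5 - 1.0`: 0 -> -1, 255 -> 1.
    return [p / 127.5 - 1.0 for p in pixels]


def concat_time_shape(frame_shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Shape of torch.cat(frames, dim=2) for [1, C, 1, H, W] frames."""
    first = frame_shapes[0]
    for shape in frame_shapes:
        # Every dim except T (index 2) must match for the concatenation.
        if shape[:2] != first[:2] or shape[3:] != first[3:]:
            raise ValueError("frames differ outside the T dimension")
    total_t = sum(shape[2] for shape in frame_shapes)
    return first[:2] + (total_t,) + first[3:]
```

Five resized frames of shape [1, 3, 1, 64, 96] therefore become one [1, 3, 5, 64, 96] clip before VAE encoding.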

fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning.post_process_ltx2_denoised

post_process_ltx2_denoised(*, denoised: Tensor, denoise_mask: Tensor, clean_latent: Tensor) -> Tensor

Restore the conditioning regions of clean_latent outside the denoise mask after the model has filled in the masked area.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_image_conditioning.py
def post_process_ltx2_denoised(
    *,
    denoised: torch.Tensor,
    denoise_mask: torch.Tensor,
    clean_latent: torch.Tensor,
) -> torch.Tensor:
    """Restore the conditioning regions of ``clean_latent`` outside the
    denoise mask after the model has filled in the masked area."""
    return (denoised * denoise_mask + clean_latent.float() * (1.0 - denoise_mask)).to(denoised.dtype)
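
This is the mirror image of apply_ltx2_gaussian_noiser. After the model fills the masked area, wherever the mask is 0 the clean conditioning latent is written back, and fractional mask values interpolate between the two. `restore_conditioning` is a hypothetical list-based helper that shows the same blend. It is not the library function.

```python
from typing import List


def restore_conditioning(denoised: List[float], mask: List[float],
                         clean: List[float]) -> List[float]:
    # out = denoised * mask + clean * (1 - mask), element-wise
    return [d * m + c * (1.0 - m) for d, m, c in zip(denoised, mask, clean)]


denoised = [2.0, 2.0, 2.0]
clean = [0.0, 1.0, 4.0]
mask = [1.0, 0.0, 0.5]  # model output kept, clean restored, 50/50 mix
restored = restore_conditioning(denoised, mask, clean)  # [2.0, 1.0, 3.0]
```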

fastvideo.pipelines.basic.ltx2.stages.ltx2_image_conditioning.resolve_ltx2_images

resolve_ltx2_images(batch: ForwardBatch) -> list[tuple[str, int, float]]

Collect any LTX-2 image conditioning inputs from the batch.

Falls back to batch.image_path for the simple single-image i2v case (anchors the first latent frame at full strength).

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_image_conditioning.py
def resolve_ltx2_images(batch: ForwardBatch) -> list[tuple[str, int, float]]:
    """Collect any LTX-2 image conditioning inputs from the batch.

    Falls back to ``batch.image_path`` for the simple single-image i2v
    case (anchors the first latent frame at full strength).
    """
    images = batch.ltx2_images
    if images is None and batch.image_path:
        images = [(batch.image_path, 0, 1.0)]
    if not images:
        return []

    resolved: list[tuple[str, int, float]] = []
    for item in images:
        if not isinstance(item, tuple | list) or len(item) != 3:
            raise ValueError("Each ltx2_images item must be a tuple/list of "
                             "(path, frame_idx, strength).")
        image_path, frame_idx, strength = item
        frame_idx_int = int(frame_idx)
        strength_float = float(strength)
        if frame_idx_int < 0:
            raise ValueError(f"LTX-2 frame_idx must be >= 0, got {frame_idx_int}")
        if strength_float < 0.0 or strength_float > 1.0:
            raise ValueError(f"LTX-2 image conditioning strength must be in [0, 1], "
                             f"got {strength_float}")
        resolved.append((str(image_path), frame_idx_int, strength_float))
    return resolved

fastvideo.pipelines.basic.ltx2.stages.ltx2_latent_preparation

Latent preparation stage for LTX-2 pipelines.

Classes

fastvideo.pipelines.basic.ltx2.stages.ltx2_latent_preparation.LTX2LatentPreparationStage

LTX2LatentPreparationStage(transformer, vae)

Bases: PipelineStage

Prepare initial LTX-2 latents without relying on a diffusers scheduler.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_latent_preparation.py
def __init__(self, transformer, vae) -> None:
    super().__init__()
    self.transformer = transformer
    self.vae = vae


fastvideo.pipelines.basic.ltx2.stages.ltx2_refine

LTX-2 refinement stages for 2x spatial upscaling + distilled denoising.

Public-side port of FastVideo-internal/.../stages/ltx2_refine.py. The three stages are placed around the stage-1 and stage-2 denoising passes:

  • LTX2RefineInitStage (before stage 1): halves the requested resolution so the first denoise runs at ½× and stashes the original target resolution on batch.extra so the upsample stage can recover it.
  • LTX2UpsampleStage (between the two passes): upsamples the stage-1 latents through the LTX-2 latent upsampler, optionally re-applies image conditioning, and mixes in fresh noise scaled by the stage-2 sigma so the next denoise has something to refine.
  • LTX2RefineLoRAStage (before stage 2): swaps in a refinement LoRA before the stage-2 denoise (no-op when the LoRA path is unset).

Behaviour matches the internal version 1:1 for the text-to-video path; the i2v / continuation branches inside build_ltx2_image_conditioning defer to a NotImplementedError until the rest of the i2v conditioning module is ported.

Classes

fastvideo.pipelines.basic.ltx2.stages.ltx2_refine.LTX2RefineInitStage

Bases: PipelineStage

Switch the request to half resolution before the stage-1 denoise.

Stashes the original target resolution on batch.extra so LTX2UpsampleStage can recover it after stage 1 runs. When the refine path is disabled, the stage is a no-op.

fastvideo.pipelines.basic.ltx2.stages.ltx2_refine.LTX2RefineLoRAStage

LTX2RefineLoRAStage(*, pipeline: Any, lora_path: str | None, lora_nickname: str = 'ltx2_refine')

Bases: PipelineStage

Apply a refinement-specific LoRA before stage-2 denoising.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_refine.py
def __init__(
    self,
    *,
    pipeline: Any,
    lora_path: str | None,
    lora_nickname: str = "ltx2_refine",
) -> None:
    super().__init__()
    self._pipeline_ref = (weakref.ref(pipeline) if pipeline is not None else None)
    self._lora_path = lora_path
    self._lora_nickname = lora_nickname
    self._applied = False

fastvideo.pipelines.basic.ltx2.stages.ltx2_refine.LTX2UpsampleStage

LTX2UpsampleStage(*, upsampler: Any, vae: Any, transformer: Any | None = None, sigmas: list[float] | None = None, add_noise: bool = True)

Bases: PipelineStage

Upsample stage-1 latents to stage-2 resolution and add refine noise.

Source code in fastvideo/pipelines/basic/ltx2/stages/ltx2_refine.py
def __init__(
    self,
    *,
    upsampler: Any,
    vae: Any,
    transformer: Any | None = None,
    sigmas: list[float] | None = None,
    add_noise: bool = True,
) -> None:
    super().__init__()
    self.upsampler = upsampler
    self.vae = vae
    self.transformer = transformer
    self.sigmas = sigmas or STAGE_2_DISTILLED_SIGMA_VALUES
    self.add_noise = add_noise


fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding

LTX2-specific text encoding stage with sequence parallelism broadcast support.

When running with sequence parallelism (SP), the Gemma text encoder is only executed on rank 0, and the embeddings are broadcast to all other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously.

Classes

fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding.LTX2TextEncodingStage

LTX2TextEncodingStage(text_encoders, tokenizers)

Bases: TextEncodingStage

LTX2 text encoding stage with sequence parallelism support.

When SP is enabled (sp_world_size > 1), only rank 0 runs the text encoder and broadcasts embeddings to other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously, which can cause text encoding to take 100+ seconds instead of ~5 seconds.

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        text_encoders: Text encoder model(s) used to embed the prompt.
        tokenizers: Tokenizer(s) matching ``text_encoders``.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
    self._last_audio_embeds: list[torch.Tensor] | None = None
