Skip to content

ltx2_upsampler

LTX-2 latent upsampler (spatial/temporal) implementation.

Classes

fastvideo.models.upsamplers.ltx2_upsampler.BlurDownsample

BlurDownsample(dims: int, stride: int, kernel_size: int = 5)

Bases: Module

Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. Works for dims=2 or dims=3 (per-frame).

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
    """Validate the arguments and register the fixed 2D binomial blur kernel.

    Args:
        dims: Dimensionality selector; must be 2 or 3.
        stride: Integer downsampling stride; must be at least 1.
        kernel_size: Side length of the blur kernel; must be odd and >= 3.

    Raises:
        ValueError: If any argument fails the checks above.
    """
    super().__init__()
    if dims not in (2, 3):
        raise ValueError("dims must be 2 or 3")
    if stride < 1:
        raise ValueError("stride must be >= 1")
    if not (kernel_size >= 3 and kernel_size % 2 == 1):
        raise ValueError("kernel_size must be an odd integer >= 3")

    self.dims = dims
    self.stride = stride
    self.kernel_size = kernel_size

    # Binomial coefficients C(kernel_size - 1, i) form the 1D taps.
    order = kernel_size - 1
    taps = torch.tensor([math.comb(order, i) for i in range(kernel_size)])
    # Column-times-row broadcast is the outer product: the separable 2D kernel.
    weights = taps.unsqueeze(1) * taps.unsqueeze(0)
    # Normalize so the kernel sums to 1, then store as float32.
    weights = (weights / weights.sum()).float()
    # Two leading singleton axes give shape (1, 1, kernel_size, kernel_size).
    self.register_buffer("kernel", weights[None, None])

fastvideo.models.upsamplers.ltx2_upsampler.LTX2LatentUpsampler

LTX2LatentUpsampler(config: dict[str, Any])

Bases: Module

Public wrapper for the LTX-2 latent upsampler.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, config: dict[str, Any]):
    """Build the wrapped upsampler from a configuration dictionary.

    Args:
        config: Configuration passed unchanged to
            ``LatentUpsamplerConfigurator.from_config``. The expected keys
            are defined by that configurator and are not visible here.
    """
    super().__init__()
    # The configurator owns all construction logic; this wrapper only holds the result.
    self.model: LatentUpsampler = LatentUpsamplerConfigurator.from_config(config)

fastvideo.models.upsamplers.ltx2_upsampler.LatentUpsampler

LatentUpsampler(in_channels: int = 128, mid_channels: int = 512, num_blocks_per_stage: int = 4, dims: int = 3, spatial_upsample: bool = True, temporal_upsample: bool = False, spatial_scale: float = 2.0, rational_resampler: bool = False)

Bases: Module

Model to upsample VAE latents spatially and/or temporally.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(
    self,
    in_channels: int = 128,
    mid_channels: int = 512,
    num_blocks_per_stage: int = 4,
    dims: int = 3,
    spatial_upsample: bool = True,
    temporal_upsample: bool = False,
    spatial_scale: float = 2.0,
    rational_resampler: bool = False,
) -> None:
    """Build the conv / res-block / upsample / res-block / conv stack.

    Args:
        in_channels: Channel count of the input; ``final_conv`` maps back
            to this many channels.
        mid_channels: Channel count used by the internal layers.
        num_blocks_per_stage: Number of ``ResBlock`` modules created both
            before and after the upsampling module.
        dims: ``2`` selects ``nn.Conv2d`` for the initial/final convs and
            res blocks; any other value selects ``nn.Conv3d``.
        spatial_upsample: Enable the spatial upsampling branch.
        temporal_upsample: Enable the temporal upsampling branch.
        spatial_scale: Scale forwarded to ``SpatialRationalResampler``;
            only read in the spatial-only branch when
            ``rational_resampler`` is true.
        rational_resampler: Use ``SpatialRationalResampler`` instead of
            conv + ``PixelShuffleND(2)`` in the spatial-only branch.

    Raises:
        ValueError: If both ``spatial_upsample`` and ``temporal_upsample``
            are false.
    """
    super().__init__()

    # Keep the constructor arguments on the instance.
    self.in_channels = in_channels
    self.mid_channels = mid_channels
    self.num_blocks_per_stage = num_blocks_per_stage
    self.dims = dims
    self.spatial_upsample = spatial_upsample
    self.temporal_upsample = temporal_upsample
    self.spatial_scale = float(spatial_scale)
    self.rational_resampler = rational_resampler

    # Anything other than dims == 2 falls back to 3D convolutions.
    conv = nn.Conv2d if dims == 2 else nn.Conv3d

    self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
    self.initial_norm = nn.GroupNorm(32, mid_channels)
    self.initial_activation = nn.SiLU()

    self.res_blocks = nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])

    # Exactly one upsampling module is chosen. Note that rational_resampler and
    # spatial_scale are ignored when both spatial and temporal upsampling are on.
    if spatial_upsample and temporal_upsample:
        # 8x channel expansion — presumably consumed by a 2x2x2 shuffle; confirm
        # against PixelShuffleND.forward.
        self.upsampler = nn.Sequential(
            nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
            PixelShuffleND(3),
        )
    elif spatial_upsample:
        if rational_resampler:
            self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
        else:
            # NOTE(review): Conv2d is used here regardless of `dims`; presumably
            # forward() feeds this branch per-frame 2D input — confirm.
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
    elif temporal_upsample:
        # 2x channel expansion paired with a 1-D pixel shuffle.
        self.upsampler = nn.Sequential(
            nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
            PixelShuffleND(1),
        )
    else:
        raise ValueError("Either spatial_upsample or temporal_upsample must be True")

    self.post_upsample_res_blocks = nn.ModuleList(
        [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
    )

    # Project back to the input channel count.
    self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)

fastvideo.models.upsamplers.ltx2_upsampler.LatentUpsamplerConfigurator

Configurator for LatentUpsampler from a config dict.

fastvideo.models.upsamplers.ltx2_upsampler.PixelShuffleND

PixelShuffleND(dims: int, upscale_factors: Tuple[int, int, int] = (2, 2, 2))

Bases: Module

N-dimensional pixel shuffle for upsampling.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, dims: int, upscale_factors: Tuple[int, ...] = (2, 2, 2)) -> None:
    """Store the shuffle dimensionality and per-axis upscale factors.

    Args:
        dims: Number of axes to shuffle; must be 1, 2, or 3.
        upscale_factors: Upscale factors stored as given. The default has
            three entries, but ``SpatialRationalResampler`` passes a
            two-entry tuple, so the length is not fixed at three.
            How entries map to axes is decided by code not shown here.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3.
    """
    super().__init__()
    if dims not in (1, 2, 3):
        raise ValueError("dims must be 1, 2, or 3")
    self.dims = dims
    self.upscale_factors = upscale_factors

fastvideo.models.upsamplers.ltx2_upsampler.ResBlock

ResBlock(channels: int, mid_channels: Optional[int] = None, dims: int = 3)

Bases: Module

Residual block with two convolutional layers, group norm, and SiLU.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3) -> None:
    """Create the two conv + group-norm pairs and the SiLU activation.

    Args:
        channels: Channel count of the block's input and output.
        mid_channels: Channel count between the two convolutions;
            defaults to ``channels`` when ``None``.
        dims: ``2`` selects ``nn.Conv2d``; any other value selects
            ``nn.Conv3d``.
    """
    super().__init__()
    hidden = channels if mid_channels is None else mid_channels

    if dims == 2:
        conv_cls = nn.Conv2d
    else:
        conv_cls = nn.Conv3d

    # First stage: channels -> hidden, normalized over 32 groups.
    self.conv1 = conv_cls(channels, hidden, kernel_size=3, padding=1)
    self.norm1 = nn.GroupNorm(32, hidden)
    # Second stage: hidden -> channels, normalized over 32 groups.
    self.conv2 = conv_cls(hidden, channels, kernel_size=3, padding=1)
    self.norm2 = nn.GroupNorm(32, channels)
    self.activation = nn.SiLU()

fastvideo.models.upsamplers.ltx2_upsampler.SpatialRationalResampler

SpatialRationalResampler(mid_channels: int, scale: float)

Bases: Module

Fully-learned rational spatial scaling: upsample by 'num' via PixelShuffle, then apply an anti-aliased downsample by 'den' using a fixed blur plus stride. Operates on H,W only. For dims==3, it works per-frame for spatial scaling (the temporal axis is untouched).

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, mid_channels: int, scale: float) -> None:
    """Set up the conv, pixel-shuffle, and blur-downsample stages.

    Args:
        mid_channels: Channel count of the convolution's input.
        scale: Target spatial scale; converted to ``float`` and resolved
            into an integer pair ``(num, den)`` by ``_rational_for_scale``
            (defined elsewhere in the module).
    """
    super().__init__()
    self.scale = float(scale)
    # presumably num/den approximates `scale`; verify against _rational_for_scale.
    self.num, self.den = _rational_for_scale(self.scale)
    # Expand channels by num**2 ahead of the (num, num) pixel shuffle.
    self.conv = nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
    self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
    # Fixed blur with stride `den` performs the downsampling step.
    self.blur_down = BlurDownsample(dims=2, stride=self.den)

Functions

fastvideo.models.upsamplers.ltx2_upsampler.upsample_video

upsample_video(latent: Tensor, video_encoder: Any, upsampler: LatentUpsampler) -> Tensor

Upsample a latent tensor with normalization based on the video encoder's per-channel statistics.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def upsample_video(latent: torch.Tensor, video_encoder: Any, upsampler: LatentUpsampler) -> torch.Tensor:
    """
    Run the upsampler on a latent between the encoder's un-normalize and normalize steps.

    Args:
        latent: Latent to upsample.
        video_encoder: Object exposing ``per_channel_statistics``, which must
            provide ``un_normalize`` and ``normalize``.
        upsampler: Callable applied to the un-normalized latent.

    Returns:
        The upsampled latent after ``normalize`` has been applied.

    Raises:
        ValueError: If ``video_encoder`` has no ``per_channel_statistics``.
    """
    has_stats = hasattr(video_encoder, "per_channel_statistics")
    if not has_stats:
        raise ValueError("video_encoder must expose per_channel_statistics for normalization")

    channel_stats = video_encoder.per_channel_statistics
    # un_normalize -> upsample -> normalize, evaluated innermost first.
    return channel_stats.normalize(upsampler(channel_stats.un_normalize(latent)))