Skip to content

upsamplers

Classes

fastvideo.models.upsamplers.BlurDownsample

BlurDownsample(dims: int, stride: int, kernel_size: int = 5)

Bases: Module

Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. Works for dims=2 or dims=3 (per-frame).

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
    """Validate the arguments and register the fixed blur kernel.

    The kernel is the outer product of the binomial coefficients
    ``C(kernel_size - 1, i)`` with themselves, normalized to sum to one and
    stored as a float buffer named ``kernel`` with shape
    ``(1, 1, kernel_size, kernel_size)``.

    Args:
        dims: Must be 2 or 3.
        stride: Downsampling stride; must be at least 1.
        kernel_size: Kernel width; must be odd and at least 3.

    Raises:
        ValueError: If any argument fails the checks above.
    """
    super().__init__()
    if dims not in (2, 3):
        raise ValueError("dims must be 2 or 3")
    if stride < 1:
        raise ValueError("stride must be >= 1")
    if not (kernel_size >= 3 and kernel_size % 2 == 1):
        raise ValueError("kernel_size must be an odd integer >= 3")

    self.kernel_size = kernel_size
    self.stride = stride
    self.dims = dims

    # 1D binomial coefficients (one row of Pascal's triangle).
    order = kernel_size - 1
    coeffs = torch.tensor([math.comb(order, i) for i in range(kernel_size)])
    # Column vector times row vector broadcasts to the 2D outer product.
    outer = coeffs.unsqueeze(1) * coeffs.unsqueeze(0)
    weights = (outer / outer.sum()).float()
    self.register_buffer("kernel", weights.unsqueeze(0).unsqueeze(0))

fastvideo.models.upsamplers.LTX2LatentUpsampler

LTX2LatentUpsampler(config: dict[str, Any])

Bases: Module

Public wrapper for the LTX-2 latent upsampler.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, config: dict[str, Any]):
    """Build the wrapped upsampler from ``config`` and keep it as ``self.model``.

    Args:
        config: Passed unchanged to ``LatentUpsamplerConfigurator.from_config``.
    """
    super().__init__()
    upsampler: LatentUpsampler = LatentUpsamplerConfigurator.from_config(config)
    self.model = upsampler

fastvideo.models.upsamplers.LatentUpsampler

LatentUpsampler(in_channels: int = 128, mid_channels: int = 512, num_blocks_per_stage: int = 4, dims: int = 3, spatial_upsample: bool = True, temporal_upsample: bool = False, spatial_scale: float = 2.0, rational_resampler: bool = False)

Bases: Module

Model to upsample VAE latents spatially and/or temporally.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(
    self,
    in_channels: int = 128,
    mid_channels: int = 512,
    num_blocks_per_stage: int = 4,
    dims: int = 3,
    spatial_upsample: bool = True,
    temporal_upsample: bool = False,
    spatial_scale: float = 2.0,
    rational_resampler: bool = False,
) -> None:
    """Assemble the stem, the residual stages, and the upsampling head.

    Args:
        in_channels: Input channels of ``initial_conv`` and output channels
            of ``final_conv``.
        mid_channels: Channel width used between those two convolutions.
        num_blocks_per_stage: Number of ``ResBlock`` modules built before
            and again after the upsampling head.
        dims: ``2`` selects ``nn.Conv2d`` for the stem and final conv; any
            other value selects ``nn.Conv3d``. Also forwarded to ``ResBlock``.
        spatial_upsample: Together with ``temporal_upsample`` selects the head
            (see the branch comments below).
        temporal_upsample: See ``spatial_upsample``.
        spatial_scale: Stored as a float; handed to
            ``SpatialRationalResampler`` when that head is selected.
        rational_resampler: For the spatial-only case, use
            ``SpatialRationalResampler`` instead of conv + pixel shuffle.

    Raises:
        ValueError: If neither ``spatial_upsample`` nor ``temporal_upsample``
            is set.
    """
    super().__init__()

    # Keep the constructor arguments around as plain attributes.
    self.in_channels, self.mid_channels = in_channels, mid_channels
    self.num_blocks_per_stage, self.dims = num_blocks_per_stage, dims
    self.spatial_upsample, self.temporal_upsample = spatial_upsample, temporal_upsample
    self.spatial_scale = float(spatial_scale)
    self.rational_resampler = rational_resampler

    conv_cls = nn.Conv3d
    if dims == 2:
        conv_cls = nn.Conv2d

    self.initial_conv = conv_cls(in_channels, mid_channels, kernel_size=3, padding=1)
    self.initial_norm = nn.GroupNorm(32, mid_channels)
    self.initial_activation = nn.SiLU()

    pre_blocks = [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
    self.res_blocks = nn.ModuleList(pre_blocks)

    if not (spatial_upsample or temporal_upsample):
        raise ValueError("Either spatial_upsample or temporal_upsample must be True")

    if spatial_upsample and rational_resampler and not temporal_upsample:
        # Spatial-only with a rational scale factor.
        self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
    else:
        if spatial_upsample and temporal_upsample:
            head_conv, shuffle_dims = nn.Conv3d, 3
        elif spatial_upsample:
            head_conv, shuffle_dims = nn.Conv2d, 2
        else:
            head_conv, shuffle_dims = nn.Conv3d, 1
        # The conv widens the channels by 2 per shuffled axis (8x / 4x / 2x).
        self.upsampler = nn.Sequential(
            head_conv(mid_channels, (2**shuffle_dims) * mid_channels, kernel_size=3, padding=1),
            PixelShuffleND(shuffle_dims),
        )

    post_blocks = [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
    self.post_upsample_res_blocks = nn.ModuleList(post_blocks)

    self.final_conv = conv_cls(mid_channels, in_channels, kernel_size=3, padding=1)

fastvideo.models.upsamplers.LatentUpsamplerConfigurator

Configurator for LatentUpsampler from a config dict.

fastvideo.models.upsamplers.PixelShuffleND

PixelShuffleND(dims: int, upscale_factors: Tuple[int, int, int] = (2, 2, 2))

Bases: Module

N-dimensional pixel shuffle for upsampling.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, dims: int, upscale_factors: Tuple[int, int, int] = (2, 2, 2)) -> None:
    """Store the shuffle dimensionality and the per-axis upscale factors.

    Args:
        dims: Must be 1, 2, or 3.
        upscale_factors: Stored unchanged as ``self.upscale_factors``.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3.
    """
    super().__init__()
    supported_dims = (1, 2, 3)
    if dims not in supported_dims:
        raise ValueError("dims must be 1, 2, or 3")
    self.upscale_factors = upscale_factors
    self.dims = dims

fastvideo.models.upsamplers.ResBlock

ResBlock(channels: int, mid_channels: Optional[int] = None, dims: int = 3)

Bases: Module

Residual block with two convolutional layers, group norm, and SiLU.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3) -> None:
    """Create the two conv + group-norm pairs and the shared activation.

    Args:
        channels: Input channels of ``conv1`` and output channels of ``conv2``.
        mid_channels: Channels between the two convolutions; ``None`` means
            use ``channels``.
        dims: ``2`` selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
    """
    super().__init__()
    hidden = channels if mid_channels is None else mid_channels

    if dims == 2:
        conv_cls = nn.Conv2d
    else:
        conv_cls = nn.Conv3d

    self.conv1 = conv_cls(channels, hidden, kernel_size=3, padding=1)
    self.norm1 = nn.GroupNorm(32, hidden)
    self.conv2 = conv_cls(hidden, channels, kernel_size=3, padding=1)
    self.norm2 = nn.GroupNorm(32, channels)
    self.activation = nn.SiLU()

fastvideo.models.upsamplers.SpatialRationalResampler

SpatialRationalResampler(mid_channels: int, scale: float)

Bases: Module

Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased downsample by 'den' using a fixed blur + stride. Operates on H,W only. For dims==3, works per-frame for spatial scaling (temporal axis untouched).

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, mid_channels: int, scale: float) -> None:
    """Set up the conv, pixel shuffle, and blur-downsample for ``scale``.

    Args:
        mid_channels: Input channels of the expanding convolution.
        scale: Converted to ``float`` and passed to ``_rational_for_scale``,
            whose two results are stored as ``self.num`` and ``self.den``.
    """
    super().__init__()
    self.scale = float(scale)
    numerator, denominator = _rational_for_scale(self.scale)
    self.num = numerator
    self.den = denominator
    # The conv widens the channels by num**2, matching the (num, num) shuffle.
    expanded_channels = (self.num**2) * mid_channels
    self.conv = nn.Conv2d(mid_channels, expanded_channels, kernel_size=3, padding=1)
    self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
    self.blur_down = BlurDownsample(dims=2, stride=self.den)

Functions

fastvideo.models.upsamplers.upsample_video

upsample_video(latent: Tensor, video_encoder: Any, upsampler: LatentUpsampler) -> Tensor

Upsample a latent tensor with normalization based on the video encoder's per-channel statistics.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def upsample_video(latent: torch.Tensor, video_encoder: Any, upsampler: LatentUpsampler) -> torch.Tensor:
    """
    Un-normalize ``latent`` with the encoder's per-channel statistics, run the
    upsampler on it, and normalize the result with the same statistics.

    Raises:
        ValueError: If ``video_encoder`` has no ``per_channel_statistics`` attribute.
    """
    if hasattr(video_encoder, "per_channel_statistics"):
        channel_stats = video_encoder.per_channel_statistics
        raw_latent = channel_stats.un_normalize(latent)
        return channel_stats.normalize(upsampler(raw_latent))
    raise ValueError("video_encoder must expose per_channel_statistics for normalization")

Modules

fastvideo.models.upsamplers.hunyuan15

Classes

fastvideo.models.upsamplers.hunyuan15.SRTo1080pUpsampler
SRTo1080pUpsampler(config: SRTo1080pUpsamplerConfig)

Bases: Module

Source code in fastvideo/models/upsamplers/hunyuan15.py
def __init__(self, config: SRTo1080pUpsamplerConfig):
    """Build the input conv, the per-level resnet groups, and the output head.

    One group of ``config.num_res_blocks + 1`` resnet blocks is created for
    each entry of ``config.block_out_channels``. Each group is wrapped in a
    bare ``nn.Module`` that receives only a ``block`` attribute here.

    Args:
        config: Supplies ``num_res_blocks``, ``block_out_channels``,
            ``z_channels``, ``out_channels`` and ``is_residual``.
    """
    super().__init__()
    self.z_channels = config.z_channels
    self.block_out_channels = config.block_out_channels
    self.num_res_blocks = config.num_res_blocks

    in_ch = config.block_out_channels[0]
    self.conv_in = HunyuanVideo15CausalConv3d(config.z_channels, in_ch, kernel_size=3)

    self.up = nn.ModuleList()
    for out_ch in config.block_out_channels:
        resnets = nn.ModuleList()
        for _ in range(self.num_res_blocks + 1):
            resnets.append(HunyuanVideo15ResnetBlock(in_channels=in_ch, out_channels=out_ch))
            # Only the first block of a level changes the channel count.
            in_ch = out_ch
        stage = nn.Module()
        stage.block = resnets
        self.up.append(stage)

    self.norm_out = HunyuanVideo15RMS_norm(in_ch, images=False)
    self.conv_out = HunyuanVideo15CausalConv3d(in_ch, config.out_channels, kernel_size=3)

    self.is_residual = config.is_residual
    self.gradient_checkpointing = False
Functions
fastvideo.models.upsamplers.hunyuan15.SRTo1080pUpsampler.forward
forward(z: Tensor, target_shape: Sequence[int] = None) -> Tensor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | Tensor | (B, C, T, H, W) | required |
| target_shape | Sequence[int] | (H, W) | None |

Source code in fastvideo/models/upsamplers/hunyuan15.py
def forward(self, z: Tensor, target_shape: Sequence[int] = None) -> Tensor:
    """
    Optionally resize ``z`` spatially, then run it through the conv/resnet stack.

    Args:
        z: (B, C, T, H, W)
        target_shape: (H, W). When given and different from the current
            spatial size, every frame is bilinearly resized to it first.
            Any sequence type (tuple, list, ``torch.Size``) is accepted.

    Returns:
        The output of ``conv_out``.
    """
    # Compare as plain tuples: ``z.shape[-2:]`` is a tuple subclass, so comparing
    # it against a list would always report a mismatch and trigger a needless resize.
    if target_shape is not None and tuple(z.shape[-2:]) != tuple(target_shape):
        bsz = z.shape[0]
        # Fold frames into the batch axis so the 2D interpolation runs per frame.
        z = rearrange(z, "b c f h w -> (b f) c h w")
        z = F.interpolate(z, size=target_shape, mode="bilinear", align_corners=False)
        z = rearrange(z, "(b f) c h w -> b c f h w", b=bsz)

    # z to block_in: the input, with its channels repeated to match the first
    # block width, is added to the conv_in output.
    repeats = self.block_out_channels[0] // (self.z_channels)
    h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

    # upsampling
    for i_level in range(len(self.block_out_channels)):
        stage = self.up[i_level]
        for i_block in range(self.num_res_blocks + 1):
            h = stage.block[i_block](h)
        # Applied only when the level carries an ``upsample`` submodule.
        if hasattr(stage, "upsample"):
            h = stage.upsample(h)

    # end
    h = self.norm_out(h)
    h = get_act_fn("swish")(h)
    h = self.conv_out(h)
    return h

Functions

fastvideo.models.upsamplers.ltx2_upsampler

LTX-2 latent upsampler (spatial/temporal) implementation.

Classes

fastvideo.models.upsamplers.ltx2_upsampler.BlurDownsample
BlurDownsample(dims: int, stride: int, kernel_size: int = 5)

Bases: Module

Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. Works for dims=2 or dims=3 (per-frame).

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
    """Validate the arguments and register the fixed blur kernel.

    The kernel is the outer product of the binomial coefficients
    ``C(kernel_size - 1, i)`` with themselves, normalized to sum to one and
    stored as a float buffer named ``kernel`` with shape
    ``(1, 1, kernel_size, kernel_size)``.

    Args:
        dims: Must be 2 or 3.
        stride: Downsampling stride; must be at least 1.
        kernel_size: Kernel width; must be odd and at least 3.

    Raises:
        ValueError: If any argument fails the checks above.
    """
    super().__init__()
    if dims not in (2, 3):
        raise ValueError("dims must be 2 or 3")
    if stride < 1:
        raise ValueError("stride must be >= 1")
    if not (kernel_size >= 3 and kernel_size % 2 == 1):
        raise ValueError("kernel_size must be an odd integer >= 3")

    self.kernel_size = kernel_size
    self.stride = stride
    self.dims = dims

    # 1D binomial coefficients (one row of Pascal's triangle).
    order = kernel_size - 1
    coeffs = torch.tensor([math.comb(order, i) for i in range(kernel_size)])
    # Column vector times row vector broadcasts to the 2D outer product.
    outer = coeffs.unsqueeze(1) * coeffs.unsqueeze(0)
    weights = (outer / outer.sum()).float()
    self.register_buffer("kernel", weights.unsqueeze(0).unsqueeze(0))
fastvideo.models.upsamplers.ltx2_upsampler.LTX2LatentUpsampler
LTX2LatentUpsampler(config: dict[str, Any])

Bases: Module

Public wrapper for the LTX-2 latent upsampler.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, config: dict[str, Any]):
    """Build the wrapped upsampler from ``config`` and keep it as ``self.model``.

    Args:
        config: Passed unchanged to ``LatentUpsamplerConfigurator.from_config``.
    """
    super().__init__()
    upsampler: LatentUpsampler = LatentUpsamplerConfigurator.from_config(config)
    self.model = upsampler
fastvideo.models.upsamplers.ltx2_upsampler.LatentUpsampler
LatentUpsampler(in_channels: int = 128, mid_channels: int = 512, num_blocks_per_stage: int = 4, dims: int = 3, spatial_upsample: bool = True, temporal_upsample: bool = False, spatial_scale: float = 2.0, rational_resampler: bool = False)

Bases: Module

Model to upsample VAE latents spatially and/or temporally.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(
    self,
    in_channels: int = 128,
    mid_channels: int = 512,
    num_blocks_per_stage: int = 4,
    dims: int = 3,
    spatial_upsample: bool = True,
    temporal_upsample: bool = False,
    spatial_scale: float = 2.0,
    rational_resampler: bool = False,
) -> None:
    """Assemble the stem, the residual stages, and the upsampling head.

    Args:
        in_channels: Input channels of ``initial_conv`` and output channels
            of ``final_conv``.
        mid_channels: Channel width used between those two convolutions.
        num_blocks_per_stage: Number of ``ResBlock`` modules built before
            and again after the upsampling head.
        dims: ``2`` selects ``nn.Conv2d`` for the stem and final conv; any
            other value selects ``nn.Conv3d``. Also forwarded to ``ResBlock``.
        spatial_upsample: Together with ``temporal_upsample`` selects the head
            (see the branch comments below).
        temporal_upsample: See ``spatial_upsample``.
        spatial_scale: Stored as a float; handed to
            ``SpatialRationalResampler`` when that head is selected.
        rational_resampler: For the spatial-only case, use
            ``SpatialRationalResampler`` instead of conv + pixel shuffle.

    Raises:
        ValueError: If neither ``spatial_upsample`` nor ``temporal_upsample``
            is set.
    """
    super().__init__()

    # Keep the constructor arguments around as plain attributes.
    self.in_channels, self.mid_channels = in_channels, mid_channels
    self.num_blocks_per_stage, self.dims = num_blocks_per_stage, dims
    self.spatial_upsample, self.temporal_upsample = spatial_upsample, temporal_upsample
    self.spatial_scale = float(spatial_scale)
    self.rational_resampler = rational_resampler

    conv_cls = nn.Conv3d
    if dims == 2:
        conv_cls = nn.Conv2d

    self.initial_conv = conv_cls(in_channels, mid_channels, kernel_size=3, padding=1)
    self.initial_norm = nn.GroupNorm(32, mid_channels)
    self.initial_activation = nn.SiLU()

    pre_blocks = [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
    self.res_blocks = nn.ModuleList(pre_blocks)

    if not (spatial_upsample or temporal_upsample):
        raise ValueError("Either spatial_upsample or temporal_upsample must be True")

    if spatial_upsample and rational_resampler and not temporal_upsample:
        # Spatial-only with a rational scale factor.
        self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
    else:
        if spatial_upsample and temporal_upsample:
            head_conv, shuffle_dims = nn.Conv3d, 3
        elif spatial_upsample:
            head_conv, shuffle_dims = nn.Conv2d, 2
        else:
            head_conv, shuffle_dims = nn.Conv3d, 1
        # The conv widens the channels by 2 per shuffled axis (8x / 4x / 2x).
        self.upsampler = nn.Sequential(
            head_conv(mid_channels, (2**shuffle_dims) * mid_channels, kernel_size=3, padding=1),
            PixelShuffleND(shuffle_dims),
        )

    post_blocks = [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
    self.post_upsample_res_blocks = nn.ModuleList(post_blocks)

    self.final_conv = conv_cls(mid_channels, in_channels, kernel_size=3, padding=1)
fastvideo.models.upsamplers.ltx2_upsampler.LatentUpsamplerConfigurator

Configurator for LatentUpsampler from a config dict.

fastvideo.models.upsamplers.ltx2_upsampler.PixelShuffleND
PixelShuffleND(dims: int, upscale_factors: Tuple[int, int, int] = (2, 2, 2))

Bases: Module

N-dimensional pixel shuffle for upsampling.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, dims: int, upscale_factors: Tuple[int, int, int] = (2, 2, 2)) -> None:
    """Store the shuffle dimensionality and the per-axis upscale factors.

    Args:
        dims: Must be 1, 2, or 3.
        upscale_factors: Stored unchanged as ``self.upscale_factors``.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3.
    """
    super().__init__()
    supported_dims = (1, 2, 3)
    if dims not in supported_dims:
        raise ValueError("dims must be 1, 2, or 3")
    self.upscale_factors = upscale_factors
    self.dims = dims
fastvideo.models.upsamplers.ltx2_upsampler.ResBlock
ResBlock(channels: int, mid_channels: Optional[int] = None, dims: int = 3)

Bases: Module

Residual block with two convolutional layers, group norm, and SiLU.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3) -> None:
    """Create the two conv + group-norm pairs and the shared activation.

    Args:
        channels: Input channels of ``conv1`` and output channels of ``conv2``.
        mid_channels: Channels between the two convolutions; ``None`` means
            use ``channels``.
        dims: ``2`` selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
    """
    super().__init__()
    hidden = channels if mid_channels is None else mid_channels

    if dims == 2:
        conv_cls = nn.Conv2d
    else:
        conv_cls = nn.Conv3d

    self.conv1 = conv_cls(channels, hidden, kernel_size=3, padding=1)
    self.norm1 = nn.GroupNorm(32, hidden)
    self.conv2 = conv_cls(hidden, channels, kernel_size=3, padding=1)
    self.norm2 = nn.GroupNorm(32, channels)
    self.activation = nn.SiLU()
fastvideo.models.upsamplers.ltx2_upsampler.SpatialRationalResampler
SpatialRationalResampler(mid_channels: int, scale: float)

Bases: Module

Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased downsample by 'den' using a fixed blur + stride. Operates on H,W only. For dims==3, works per-frame for spatial scaling (temporal axis untouched).

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def __init__(self, mid_channels: int, scale: float) -> None:
    """Set up the conv, pixel shuffle, and blur-downsample for ``scale``.

    Args:
        mid_channels: Input channels of the expanding convolution.
        scale: Converted to ``float`` and passed to ``_rational_for_scale``,
            whose two results are stored as ``self.num`` and ``self.den``.
    """
    super().__init__()
    self.scale = float(scale)
    numerator, denominator = _rational_for_scale(self.scale)
    self.num = numerator
    self.den = denominator
    # The conv widens the channels by num**2, matching the (num, num) shuffle.
    expanded_channels = (self.num**2) * mid_channels
    self.conv = nn.Conv2d(mid_channels, expanded_channels, kernel_size=3, padding=1)
    self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
    self.blur_down = BlurDownsample(dims=2, stride=self.den)

Functions

fastvideo.models.upsamplers.ltx2_upsampler.upsample_video
upsample_video(latent: Tensor, video_encoder: Any, upsampler: LatentUpsampler) -> Tensor

Upsample a latent tensor with normalization based on the video encoder's per-channel statistics.

Source code in fastvideo/models/upsamplers/ltx2_upsampler.py
def upsample_video(latent: torch.Tensor, video_encoder: Any, upsampler: LatentUpsampler) -> torch.Tensor:
    """
    Un-normalize ``latent`` with the encoder's per-channel statistics, run the
    upsampler on it, and normalize the result with the same statistics.

    Raises:
        ValueError: If ``video_encoder`` has no ``per_channel_statistics`` attribute.
    """
    if hasattr(video_encoder, "per_channel_statistics"):
        channel_stats = video_encoder.per_channel_statistics
        raw_latent = channel_stats.un_normalize(latent)
        return channel_stats.normalize(upsampler(raw_latent))
    raise ValueError("video_encoder must expose per_channel_statistics for normalization")