Skip to content

hunyuan15

Classes

fastvideo.models.upsamplers.hunyuan15.SRTo1080pUpsampler

SRTo1080pUpsampler(config: SRTo1080pUpsamplerConfig)

Bases: Module

Source code in fastvideo/models/upsamplers/hunyuan15.py
def __init__(
    self,
    config: SRTo1080pUpsamplerConfig,
):
    """Build the latent super-resolution decoder described by ``config``.

    Layout:
      * ``conv_in``: causal 3-D conv, ``z_channels`` -> ``block_out_channels[0]``.
      * ``up``: one stage per entry of ``block_out_channels``. Each stage holds
        ``num_res_blocks + 1`` resnet blocks under its ``.block`` attribute.
        Only the first block of a stage changes the channel count.
      * ``norm_out`` + ``conv_out``: RMS norm, then causal 3-D conv to
        ``out_channels``.

    Args:
        config: Supplies ``num_res_blocks``, ``block_out_channels``,
            ``z_channels``, ``out_channels`` and ``is_residual``.
    """
    super().__init__()
    self.num_res_blocks = config.num_res_blocks
    self.block_out_channels = config.block_out_channels
    self.z_channels = config.z_channels

    # Running channel count as we walk through the network.
    cur_ch = config.block_out_channels[0]
    self.conv_in = HunyuanVideo15CausalConv3d(config.z_channels, cur_ch, kernel_size=3)

    self.up = nn.ModuleList()
    blocks_per_stage = self.num_res_blocks + 1
    for stage_ch in config.block_out_channels:
        resnets = nn.ModuleList()
        for _ in range(blocks_per_stage):
            resnets.append(
                HunyuanVideo15ResnetBlock(in_channels=cur_ch, out_channels=stage_ch)
            )
            cur_ch = stage_ch
        # A bare Module container, so submodules are addressed as up.<i>.block.<j>.
        stage = nn.Module()
        stage.block = resnets
        self.up.append(stage)

    self.norm_out = HunyuanVideo15RMS_norm(cur_ch, images=False)
    self.conv_out = HunyuanVideo15CausalConv3d(cur_ch, config.out_channels, kernel_size=3)

    self.gradient_checkpointing = False
    # NOTE(review): not read by the forward() shown here; presumably used elsewhere — confirm.
    self.is_residual = config.is_residual

Functions

fastvideo.models.upsamplers.hunyuan15.SRTo1080pUpsampler.forward
forward(z: Tensor, target_shape: Sequence[int] = None) -> Tensor

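Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `z` | `Tensor` | Latent tensor of shape `(B, C, T, H, W)`. | *required* |
| `target_shape` | `Sequence[int]` | Optional target spatial size `(H, W)`; when it differs from `z`'s spatial size, `z` is bilinearly resized to it before decoding. | `None` |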
Source code in fastvideo/models/upsamplers/hunyuan15.py
def forward(self, z: Tensor, target_shape: Sequence[int] = None) -> Tensor:
    """Decode latent ``z``, optionally resizing it spatially first.

    Args:
        z: Latent tensor laid out as (B, C, T, H, W).
        target_shape: Optional spatial size (H, W); may be ``None``. Any
            sequence is accepted (tuple, list, ``torch.Size``). When it differs
            from ``z``'s spatial size, every frame is bilinearly resized to it
            before decoding.

    Returns:
        The output of ``conv_out`` after the resnet stages, ``norm_out`` and a
        swish activation.
    """
    if target_shape is not None:
        target_hw = tuple(target_shape)
        # Compare as plain tuples: ``z.shape[-2:]`` is a ``torch.Size`` (a tuple
        # subclass), and a tuple never equals a list. A list target such as
        # ``[h, w]`` would otherwise always look "different" and force a
        # redundant resize.
        if tuple(z.shape[-2:]) != target_hw:
            bsz = z.shape[0]
            # Fold time into batch so 2-D bilinear interpolation runs per frame.
            z = rearrange(z, "b c f h w -> (b f) c h w")
            z = F.interpolate(z, size=target_hw, mode="bilinear", align_corners=False)
            z = rearrange(z, "(b f) c h w -> b c f h w", b=bsz)

    # z -> block_in. The shortcut repeats each latent channel so its channel
    # count matches conv_in's output. Assumes block_out_channels[0] is a
    # multiple of z_channels — otherwise the addition fails on shape mismatch.
    repeats = self.block_out_channels[0] // self.z_channels
    h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

    # Resnet stages; __init__ builds num_res_blocks + 1 blocks per stage.
    for i_level in range(len(self.block_out_channels)):
        for i_block in range(self.num_res_blocks + 1):
            h = self.up[i_level].block[i_block](h)
        # NOTE(review): __init__ does not create an ``upsample`` submodule, so
        # this branch appears unused unless one is attached elsewhere — confirm.
        if hasattr(self.up[i_level], "upsample"):
            h = self.up[i_level].upsample(h)

    # Output head: norm -> swish -> projection to out_channels.
    h = self.norm_out(h)
    h = get_act_fn("swish")(h)
    h = self.conv_out(h)
    return h

Functions