
flux_2

Classes

fastvideo.configs.models.dits.flux_2.Flux2ArchConfig dataclass

Flux2ArchConfig(stacked_params_mapping: list[tuple[str, str, str]] = list(), _fsdp_shard_conditions: list = list(), _compile_conditions: list = list(), param_names_mapping: dict = (lambda: {'transformer\\.(\\w*)\\.(.*)$': '\\1.\\2'})(), reverse_param_names_mapping: dict = dict(), lora_param_names_mapping: dict = dict(), cast_prompt_embeds_to_dit_dtype: bool = True, _supported_attention_backends: tuple[AttentionBackendEnum, ...] = (SAGE_ATTN, FLASH_ATTN, TORCH_SDPA, VIDEO_SPARSE_ATTN, VMOBA_ATTN, SAGE_ATTN_THREE, SLA_ATTN, SAGE_SLA_ATTN), hidden_size: int = 0, num_attention_heads: int = 24, num_channels_latents: int = 0, in_channels: int = 64, out_channels: int | None = None, exclude_lora_layers: list[str] = list(), boundary_ratio: float | None = None, patch_size: int = 1, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, joint_attention_dim: int = 4096, timestep_guidance_channels: int = 256, mlp_ratio: float = 3.0, axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32), rope_theta: int = 2000, eps: float = 1e-06, guidance_embeds: bool = True, ff_context_swiglu_fp32: bool = False)

Bases: DiTArchConfig

Architecture configuration for the Flux2 transformer model.
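The default param_names_mapping shown in the signature is a single regex rule, transformer\.(\w*)\.(.*)$ -> \1.\2, which strips a leading "transformer." prefix from checkpoint parameter names. The sketch below shows how such a rule rewrites names. It assumes a first-match, re.match-then-re.sub application. map_param_name is a hypothetical helper, and this page does not show how FastVideo's loader actually consumes the mapping.

```python
import re

# Default rule from the Flux2ArchConfig signature, written as a raw string.
PARAM_NAMES_MAPPING = {r"transformer\.(\w*)\.(.*)$": r"\1.\2"}


def map_param_name(name: str, mapping: dict[str, str]) -> str:
    """Hypothetical helper: rewrite `name` with the first matching rule.

    Assumes rules are tried in order and applied with re.sub once re.match
    succeeds; names that match no rule are returned unchanged.
    """
    for pattern, repl in mapping.items():
        if re.match(pattern, name):
            return re.sub(pattern, repl, name)
    return name


# (\w*) cannot cross a dot, so it captures exactly one name segment
# ("x_embedder", "transformer_blocks", ...) and the rest goes into group 2.
print(map_param_name("transformer.x_embedder.weight", PARAM_NAMES_MAPPING))
print(map_param_name("transformer.transformer_blocks.0.attn.to_q.weight",
                     PARAM_NAMES_MAPPING))
```

Keys that already lack the "transformer." prefix, such as "x_embedder.weight", pass through untouched, so the same rule handles both prefixed and unprefixed checkpoints.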

Methods:

fastvideo.configs.models.dits.flux_2.Flux2ArchConfig.update_from_weight_keys
update_from_weight_keys(all_keys: set[str]) -> None

Infer num_layers and num_single_layers from checkpoint weight keys so the model is built with the same number of blocks as the weights.

Source code in fastvideo/configs/models/dits/flux_2.py
def update_from_weight_keys(self, all_keys: set[str]) -> None:
    """Infer num_layers and num_single_layers from checkpoint weight keys so the model is built with the same number of blocks as the weights."""
    if not all_keys:
        return
    num_layers = 0
    num_single_layers = 0
    for k in all_keys:
        if "single_transformer_blocks." not in k and "transformer_blocks." in k:
            parts = k.split("transformer_blocks.")[-1].split(".")
            if parts[0].isdigit():
                num_layers = max(num_layers, int(parts[0]) + 1)
        if "single_transformer_blocks." in k:
            parts = k.split("single_transformer_blocks.")[-1].split(".")
            if parts[0].isdigit():
                num_single_layers = max(num_single_layers, int(parts[0]) + 1)
    if num_layers > 0:
        self.num_layers = num_layers
        logger.info("Inferred num_layers=%s from checkpoint keys", num_layers)
    if num_single_layers > 0:
        self.num_single_layers = num_single_layers
        logger.info("Inferred num_single_layers=%s from checkpoint keys", num_single_layers)
    if num_layers > 0 or num_single_layers > 0:
        self.__post_init__()
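The counting logic in update_from_weight_keys can be isolated into a small pure function to show what it infers from a key set. infer_block_counts is a hypothetical standalone helper, and the checkpoint keys below are made-up examples rather than real FLUX.2 parameter names. Unlike the method, it returns the counts instead of updating the config and calling __post_init__.

```python
def infer_block_counts(all_keys: set[str]) -> tuple[int, int]:
    """Return (num_layers, num_single_layers) inferred from checkpoint keys.

    Each count is the highest block index seen plus one; 0 means no
    matching keys were found, in which case the method keeps its default.
    """
    num_layers = 0
    num_single_layers = 0
    for k in all_keys:
        # Check the single-stream prefix first: "single_transformer_blocks."
        # also contains "transformer_blocks." as a substring.
        if "single_transformer_blocks." in k:
            idx = k.split("single_transformer_blocks.")[-1].split(".")[0]
            if idx.isdigit():
                num_single_layers = max(num_single_layers, int(idx) + 1)
        elif "transformer_blocks." in k:
            idx = k.split("transformer_blocks.")[-1].split(".")[0]
            if idx.isdigit():
                num_layers = max(num_layers, int(idx) + 1)
    return num_layers, num_single_layers


# Hypothetical key set: 8 double-stream blocks, 48 single-stream blocks,
# plus an unrelated embedder weight that is ignored.
keys = (
    {f"transformer_blocks.{i}.attn.to_q.weight" for i in range(8)}
    | {f"single_transformer_blocks.{i}.attn.to_out.weight" for i in range(48)}
    | {"x_embedder.weight"}
)
print(infer_block_counts(keys))
```

Because each count is the maximum index plus one, a checkpoint with a gap in its block indices still yields a count large enough to hold every block present.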

fastvideo.configs.models.dits.flux_2.Flux2Config dataclass

Flux2Config(arch_config: DiTArchConfig = Flux2ArchConfig(), prefix: str = 'Flux', quant_config: QuantizationConfig | None = None)

Bases: DiTConfig

Configuration for the Flux2 transformer model.

Functions: