Skip to content

flux_2

Flux2 pipeline module.

Classes

fastvideo.pipelines.basic.flux_2.Flux2KleinPipeline

Flux2KleinPipeline(*args, **kwargs)

Bases: Flux2Pipeline

Flux2 Klein image diffusion pipeline (distilled, 4-step, no guidance).

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Initialise the base pipeline, then configure LoRA.

    ``*args``/``**kwargs`` are forwarded unchanged to the parent constructor.
    Afterwards, every trainable transformer listed in
    ``trainable_transformer_names`` that is actually loaded (plus its optional
    ``<name>_2`` companion) is registered, and its ``exclude_lora_layers`` is
    recorded. LoRA settings are copied from ``self.fastvideo_args``, and then:

    * training with ``lora_training`` enabled: ``lora_alpha`` defaults to
      ``lora_rank``, target modules fall back to a built-in list when none are
      configured, and layers are converted to LoRA;
    * inference with ``lora_path`` set: layers are converted to LoRA and the
      adapter is loaded under ``lora_nickname``.

    Raises:
        ValueError: if a name in ``trainable_transformer_names`` already ends
            with ``_2``.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    self.lora_adapter_paths = {}

    # Register each loaded trainable transformer together with its optional
    # "_2" companion (e.g. Wan2.2 MoE or fake_score_transformer_2).
    loaded = self.modules
    for base_name in self.trainable_transformer_names:
        if base_name in loaded and loaded[base_name] is not None:
            self.trainable_transformer_modules[base_name] = loaded[base_name]
        if base_name.endswith("_2"):
            # Companions are discovered automatically; overrides list base names only.
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {base_name}")
        companion_name = base_name + "_2"
        if companion_name in loaded and loaded[companion_name] is not None:
            self.trainable_transformer_modules[companion_name] = loaded[companion_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for module_name, module in self.trainable_transformer_modules.items():
        arch_config = module.config.arch_config
        self.exclude_lora_layers[module_name] = arch_config.exclude_lora_layers

    fv_args = self.fastvideo_args
    self.lora_target_modules = fv_args.lora_target_modules
    self.lora_path = fv_args.lora_path
    self.lora_nickname = fv_args.lora_nickname
    self.training_mode = fv_args.training_mode

    lora_training = self.training_mode and getattr(fv_args, "lora_training", False)
    if lora_training:
        assert isinstance(fv_args, TrainingArgs)
        if fv_args.lora_alpha is None:
            fv_args.lora_alpha = fv_args.lora_rank
        self.lora_rank = fv_args.lora_rank  # type: ignore
        self.lora_alpha = fv_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is not None:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )
        else:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        self.convert_to_lora_layers()
    elif not self.training_mode and self.lora_path is not None:
        # Inference: wrap layers, then attach the adapter found at lora_path.
        self.convert_to_lora_layers()
        self.set_lora_adapter(self.lora_nickname, self.lora_path)  # type: ignore

fastvideo.pipelines.basic.flux_2.Flux2Pipeline

Flux2Pipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Flux2 image diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Initialise the base pipeline, then configure LoRA.

    ``*args``/``**kwargs`` are forwarded unchanged to the parent constructor.
    Afterwards, every trainable transformer listed in
    ``trainable_transformer_names`` that is actually loaded (plus its optional
    ``<name>_2`` companion) is registered, and its ``exclude_lora_layers`` is
    recorded. LoRA settings are copied from ``self.fastvideo_args``, and then:

    * training with ``lora_training`` enabled: ``lora_alpha`` defaults to
      ``lora_rank``, target modules fall back to a built-in list when none are
      configured, and layers are converted to LoRA;
    * inference with ``lora_path`` set: layers are converted to LoRA and the
      adapter is loaded under ``lora_nickname``.

    Raises:
        ValueError: if a name in ``trainable_transformer_names`` already ends
            with ``_2``.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    self.lora_adapter_paths = {}

    # Register each loaded trainable transformer together with its optional
    # "_2" companion (e.g. Wan2.2 MoE or fake_score_transformer_2).
    loaded = self.modules
    for base_name in self.trainable_transformer_names:
        if base_name in loaded and loaded[base_name] is not None:
            self.trainable_transformer_modules[base_name] = loaded[base_name]
        if base_name.endswith("_2"):
            # Companions are discovered automatically; overrides list base names only.
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {base_name}")
        companion_name = base_name + "_2"
        if companion_name in loaded and loaded[companion_name] is not None:
            self.trainable_transformer_modules[companion_name] = loaded[companion_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for module_name, module in self.trainable_transformer_modules.items():
        arch_config = module.config.arch_config
        self.exclude_lora_layers[module_name] = arch_config.exclude_lora_layers

    fv_args = self.fastvideo_args
    self.lora_target_modules = fv_args.lora_target_modules
    self.lora_path = fv_args.lora_path
    self.lora_nickname = fv_args.lora_nickname
    self.training_mode = fv_args.training_mode

    lora_training = self.training_mode and getattr(fv_args, "lora_training", False)
    if lora_training:
        assert isinstance(fv_args, TrainingArgs)
        if fv_args.lora_alpha is None:
            fv_args.lora_alpha = fv_args.lora_rank
        self.lora_rank = fv_args.lora_rank  # type: ignore
        self.lora_alpha = fv_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is not None:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )
        else:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        self.convert_to_lora_layers()
    elif not self.training_mode and self.lora_path is not None:
        # Inference: wrap layers, then attach the adapter found at lora_path.
        self.convert_to_lora_layers()
        self.set_lora_adapter(self.lora_nickname, self.lora_path)  # type: ignore

Methods:

fastvideo.pipelines.basic.flux_2.Flux2Pipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/flux_2/flux_2_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection.

    Registers, in execution order: input validation, Flux2 text encoding,
    conditioning, Flux2 latent preparation, Flux2 timestep preparation,
    denoising and decoding. Each stage is constructed immediately before it
    is added, with the modules it needs fetched through ``self.get_module``.
    ``fastvideo_args`` is not used by this method.
    """
    fetch = self.get_module

    stage_builders = (
        ("input_validation_stage", InputValidationStage),
        ("prompt_encoding_stage", lambda: Flux2TextEncodingStage(
            text_encoders=[fetch("text_encoder")],
            tokenizers=[fetch("tokenizer")],
        )),
        ("conditioning_stage", ConditioningStage),
        # get_module is given None as a second argument for the transformer
        # here (presumably a fallback), unlike in the denoising stage.
        ("latent_preparation_stage", lambda: Flux2LatentPreparationStage(
            scheduler=fetch("scheduler"),
            transformer=fetch("transformer", None),
        )),
        ("timestep_preparation_stage", lambda: Flux2TimestepPreparationStage(scheduler=fetch("scheduler"))),
        ("denoising_stage", lambda: DenoisingStage(
            transformer=fetch("transformer"),
            scheduler=fetch("scheduler"),
            vae=fetch("vae"),
            pipeline=self,
        )),
        ("decoding_stage", lambda: DecodingStage(vae=fetch("vae"), pipeline=self)),
    )

    for stage_name, build_stage in stage_builders:
        self.add_stage(stage_name=stage_name, stage=build_stage())

Modules

fastvideo.pipelines.basic.flux_2.flux_2_klein_pipeline

Flux2 Klein image generation pipeline (distilled, 4-step, no guidance).

Classes

fastvideo.pipelines.basic.flux_2.flux_2_klein_pipeline.Flux2KleinPipeline
Flux2KleinPipeline(*args, **kwargs)

Bases: Flux2Pipeline

Flux2 Klein image diffusion pipeline (distilled, 4-step, no guidance).

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Initialise the base pipeline, then configure LoRA.

    ``*args``/``**kwargs`` are forwarded unchanged to the parent constructor.
    Afterwards, every trainable transformer listed in
    ``trainable_transformer_names`` that is actually loaded (plus its optional
    ``<name>_2`` companion) is registered, and its ``exclude_lora_layers`` is
    recorded. LoRA settings are copied from ``self.fastvideo_args``, and then:

    * training with ``lora_training`` enabled: ``lora_alpha`` defaults to
      ``lora_rank``, target modules fall back to a built-in list when none are
      configured, and layers are converted to LoRA;
    * inference with ``lora_path`` set: layers are converted to LoRA and the
      adapter is loaded under ``lora_nickname``.

    Raises:
        ValueError: if a name in ``trainable_transformer_names`` already ends
            with ``_2``.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    self.lora_adapter_paths = {}

    # Register each loaded trainable transformer together with its optional
    # "_2" companion (e.g. Wan2.2 MoE or fake_score_transformer_2).
    loaded = self.modules
    for base_name in self.trainable_transformer_names:
        if base_name in loaded and loaded[base_name] is not None:
            self.trainable_transformer_modules[base_name] = loaded[base_name]
        if base_name.endswith("_2"):
            # Companions are discovered automatically; overrides list base names only.
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {base_name}")
        companion_name = base_name + "_2"
        if companion_name in loaded and loaded[companion_name] is not None:
            self.trainable_transformer_modules[companion_name] = loaded[companion_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for module_name, module in self.trainable_transformer_modules.items():
        arch_config = module.config.arch_config
        self.exclude_lora_layers[module_name] = arch_config.exclude_lora_layers

    fv_args = self.fastvideo_args
    self.lora_target_modules = fv_args.lora_target_modules
    self.lora_path = fv_args.lora_path
    self.lora_nickname = fv_args.lora_nickname
    self.training_mode = fv_args.training_mode

    lora_training = self.training_mode and getattr(fv_args, "lora_training", False)
    if lora_training:
        assert isinstance(fv_args, TrainingArgs)
        if fv_args.lora_alpha is None:
            fv_args.lora_alpha = fv_args.lora_rank
        self.lora_rank = fv_args.lora_rank  # type: ignore
        self.lora_alpha = fv_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is not None:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )
        else:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        self.convert_to_lora_layers()
    elif not self.training_mode and self.lora_path is not None:
        # Inference: wrap layers, then attach the adapter found at lora_path.
        self.convert_to_lora_layers()
        self.set_lora_adapter(self.lora_nickname, self.lora_path)  # type: ignore

fastvideo.pipelines.basic.flux_2.flux_2_latent_preparation

Flux2 latent preparation stage using packed 2x2 layout.

Flux2 uses packed latents: the transformer sees 128 channels (32 × 4, from 2×2 patch packing) at half the latent spatial resolution; after denoising, the latents are unpatchified back to 32 channels at full latent resolution for VAE decoding. This stage prepares latents of shape (B, 128, T, H_latent//2, W_latent//2).

Classes

fastvideo.pipelines.basic.flux_2.flux_2_latent_preparation.Flux2LatentPreparationStage
Flux2LatentPreparationStage(scheduler, transformer, use_btchw_layout: bool = False)

Bases: LatentPreparationStage

Latent preparation for Flux2: packed layout with half spatial dimensions.

Matches diffusers Flux2Pipeline.prepare_latents: the shape is (B, num_channels_latents, T, H_latent//2, W_latent//2), so the transformer sees 128 channels at half spatial resolution. After denoising, the latents are unpatchified to (B, 32, H_latent, W_latent) before the VAE.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, use_btchw_layout: bool = False) -> None:
    """Record the modules and layout flag used when preparing latents.

    Args:
        scheduler: Scheduler, kept on ``self.scheduler``.
        transformer: Transformer module, kept on ``self.transformer``.
        use_btchw_layout: Latent layout flag read by ``forward`` (the Flux2
            ``forward`` builds (B, T, C, H, W) latents when True and
            (B, C, T, H, W) latents otherwise).
    """
    super().__init__()
    self.scheduler, self.transformer, self.use_btchw_layout = (
        scheduler,
        transformer,
        use_btchw_layout,
    )
Methods:
fastvideo.pipelines.basic.flux_2.flux_2_latent_preparation.Flux2LatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare latents with Flux2 packed half-spatial shape.

Source code in fastvideo/pipelines/basic/flux_2/flux_2_latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Prepare latents with Flux2 packed half-spatial shape.

    The latent grid is ``(height // s) // 2`` by ``(width // s) // 2``, where
    ``s`` is the VAE ``spatial_compression_ratio``; the extra halving comes
    from 2x2 patch packing. Fills in on ``batch``:

    * ``latents``: fresh noise from ``randn_tensor`` (or the caller-supplied
      latents moved to the local device), multiplied by the scheduler's
      ``init_noise_sigma`` when it has one. Caller-supplied latents are left
      unscaled when ``refine_from`` or ``stage1_video`` is set.
    * ``raw_latent_shape``: the latent shape in (B, C, T, H, W) order.
    * ``extra["flux2_img_ids"]``: per-token (t, h, w, 0) position ids, one
      copy per batch element.
    * ``n_tokens``: packed spatial token count (``latent_h * latent_w``).

    When ``batch.prompt_embeds`` is empty, a zero-length dummy embedding is
    installed and classifier-free guidance is switched off.

    Raises:
        ValueError: if ``batch.height``/``batch.width`` is missing, or if a
            list of generators does not match the effective batch size.
    """
    from fastvideo.distributed import get_local_torch_device

    latent_num_frames = (self.adjust_video_length(batch, fastvideo_args)
                         if hasattr(self, "adjust_video_length") else None)

    # Effective batch size: taken from the prompts when text embeddings
    # exist, otherwise from the first available conditioning input, else 1.
    has_prompt_embeds = bool(batch.prompt_embeds)
    if has_prompt_embeds:
        if isinstance(batch.prompt, list):
            batch_size = len(batch.prompt)
        elif batch.prompt is None:
            batch_size = batch.prompt_embeds[0].shape[0]
        else:
            batch_size = 1
    elif batch.keyboard_cond is not None:
        batch_size = batch.keyboard_cond.shape[0]
    elif batch.mouse_cond is not None:
        batch_size = batch.mouse_cond.shape[0]
    elif batch.image_embeds:
        batch_size = batch.image_embeds[0].shape[0]
    else:
        batch_size = 1
    batch_size *= batch.num_videos_per_prompt

    device = get_local_torch_device()

    if not has_prompt_embeds:
        # No text embeddings: hand downstream stages an empty (sequence length
        # 0) embedding in the transformer's dtype and disable CFG.
        transformer_dtype = next(self.transformer.parameters()).dtype
        batch.prompt_embeds = [
            torch.zeros(
                batch_size,
                0,
                self.transformer.hidden_size,
                device=device,
                dtype=transformer_dtype,
            )
        ]
        batch.negative_prompt_embeds = []
        batch.do_classifier_free_guidance = False

    dtype = batch.prompt_embeds[0].dtype
    generator = batch.generator
    latents = batch.latents
    num_frames = batch.num_frames if latent_num_frames is None else latent_num_frames
    height, width = batch.height, batch.width
    if height is None or width is None:
        raise ValueError("Height and width must be provided")

    # Flux2 packed: 2x2 patch packing halves each latent spatial dimension.
    scale = fastvideo_args.pipeline_config.vae_config.arch_config.spatial_compression_ratio
    latent_h = (height // scale) // 2
    latent_w = (width // scale) // 2

    channels = self.transformer.num_channels_latents
    bcthw_shape = (batch_size, channels, num_frames, latent_h, latent_w)
    shape = ((batch_size, num_frames, channels, latent_h, latent_w)
             if self.use_btchw_layout else bcthw_shape)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(f"You have passed a list of generators of length {len(generator)}, "
                         f"but requested an effective batch size of {batch_size}.")

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        scale_by_sigma = True
    else:
        latents = latents.to(device)
        # LongCat-style refinement inputs are used as-is (no sigma scaling).
        scale_by_sigma = batch.refine_from is None and batch.stage1_video is None
    if scale_by_sigma and hasattr(self.scheduler, "init_noise_sigma"):
        latents = latents * self.scheduler.init_noise_sigma

    batch.latents = latents
    batch.raw_latent_shape = bcthw_shape

    # Position ids: every (frame, row, column, 0) combination, row-major.
    grid_sizes = (num_frames, latent_h, latent_w, 1)
    latent_ids = torch.cartesian_prod(*(torch.arange(size, device=device) for size in grid_sizes))
    batch.extra["flux2_img_ids"] = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
    # Flux2 mu depends on image_seq_len; use the packed spatial token count.
    batch.n_tokens = latent_h * latent_w
    return batch

fastvideo.pipelines.basic.flux_2.flux_2_pipeline

Flux2 image generation pipeline implementation.

This module contains an implementation of the Flux2 image diffusion pipeline using the modular pipeline architecture.

Classes

fastvideo.pipelines.basic.flux_2.flux_2_pipeline.Flux2Pipeline
Flux2Pipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Flux2 image diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Initialise the base pipeline, then configure LoRA.

    ``*args``/``**kwargs`` are forwarded unchanged to the parent constructor.
    Afterwards, every trainable transformer listed in
    ``trainable_transformer_names`` that is actually loaded (plus its optional
    ``<name>_2`` companion) is registered, and its ``exclude_lora_layers`` is
    recorded. LoRA settings are copied from ``self.fastvideo_args``, and then:

    * training with ``lora_training`` enabled: ``lora_alpha`` defaults to
      ``lora_rank``, target modules fall back to a built-in list when none are
      configured, and layers are converted to LoRA;
    * inference with ``lora_path`` set: layers are converted to LoRA and the
      adapter is loaded under ``lora_nickname``.

    Raises:
        ValueError: if a name in ``trainable_transformer_names`` already ends
            with ``_2``.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    self.lora_adapter_paths = {}

    # Register each loaded trainable transformer together with its optional
    # "_2" companion (e.g. Wan2.2 MoE or fake_score_transformer_2).
    loaded = self.modules
    for base_name in self.trainable_transformer_names:
        if base_name in loaded and loaded[base_name] is not None:
            self.trainable_transformer_modules[base_name] = loaded[base_name]
        if base_name.endswith("_2"):
            # Companions are discovered automatically; overrides list base names only.
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {base_name}")
        companion_name = base_name + "_2"
        if companion_name in loaded and loaded[companion_name] is not None:
            self.trainable_transformer_modules[companion_name] = loaded[companion_name]

    logger.info(
        "trainable_transformer_modules: %s",
        self.trainable_transformer_modules.keys(),
    )

    for module_name, module in self.trainable_transformer_modules.items():
        arch_config = module.config.arch_config
        self.exclude_lora_layers[module_name] = arch_config.exclude_lora_layers

    fv_args = self.fastvideo_args
    self.lora_target_modules = fv_args.lora_target_modules
    self.lora_path = fv_args.lora_path
    self.lora_nickname = fv_args.lora_nickname
    self.training_mode = fv_args.training_mode

    lora_training = self.training_mode and getattr(fv_args, "lora_training", False)
    if lora_training:
        assert isinstance(fv_args, TrainingArgs)
        if fv_args.lora_alpha is None:
            fv_args.lora_alpha = fv_args.lora_rank
        self.lora_rank = fv_args.lora_rank  # type: ignore
        self.lora_alpha = fv_args.lora_alpha  # type: ignore
        logger.info(
            "Using LoRA training with rank %d and alpha %d",
            self.lora_rank,
            self.lora_alpha,
        )
        if self.lora_target_modules is not None:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules,
            )
        else:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules,
            )
        self.convert_to_lora_layers()
    elif not self.training_mode and self.lora_path is not None:
        # Inference: wrap layers, then attach the adapter found at lora_path.
        self.convert_to_lora_layers()
        self.set_lora_adapter(self.lora_nickname, self.lora_path)  # type: ignore
Methods:
fastvideo.pipelines.basic.flux_2.flux_2_pipeline.Flux2Pipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/flux_2/flux_2_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection.

    Registers, in execution order: input validation, Flux2 text encoding,
    conditioning, Flux2 latent preparation, Flux2 timestep preparation,
    denoising and decoding. Each stage is constructed immediately before it
    is added, with the modules it needs fetched through ``self.get_module``.
    ``fastvideo_args`` is not used by this method.
    """
    fetch = self.get_module

    stage_builders = (
        ("input_validation_stage", InputValidationStage),
        ("prompt_encoding_stage", lambda: Flux2TextEncodingStage(
            text_encoders=[fetch("text_encoder")],
            tokenizers=[fetch("tokenizer")],
        )),
        ("conditioning_stage", ConditioningStage),
        # get_module is given None as a second argument for the transformer
        # here (presumably a fallback), unlike in the denoising stage.
        ("latent_preparation_stage", lambda: Flux2LatentPreparationStage(
            scheduler=fetch("scheduler"),
            transformer=fetch("transformer", None),
        )),
        ("timestep_preparation_stage", lambda: Flux2TimestepPreparationStage(scheduler=fetch("scheduler"))),
        ("denoising_stage", lambda: DenoisingStage(
            transformer=fetch("transformer"),
            scheduler=fetch("scheduler"),
            vae=fetch("vae"),
            pipeline=self,
        )),
        ("decoding_stage", lambda: DecodingStage(vae=fetch("vae"), pipeline=self)),
    )

    for stage_name, build_stage in stage_builders:
        self.add_stage(stage_name=stage_name, stage=build_stage())

Functions:

fastvideo.pipelines.basic.flux_2.flux_2_text_encoding

Flux2 text encoding stages.

Classes

fastvideo.pipelines.basic.flux_2.flux_2_text_encoding.Flux2TextEncodingStage
Flux2TextEncodingStage(text_encoders, tokenizers)

Bases: TextEncodingStage

Text encoding for Flux2 full and Klein variants.

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        text_encoders: Text encoder modules, kept on ``self.text_encoders``
            (the Flux2 pipeline passes a one-element list).
        tokenizers: Tokenizers, kept on ``self.tokenizers`` (the Flux2
            pipeline passes a one-element list).
    """
    super().__init__()
    self.text_encoders = text_encoders
    self.tokenizers = tokenizers
    # Starts empty; presumably filled with the most recent audio embeddings
    # by the encoding logic — confirm against TextEncodingStage.forward.
    self._last_audio_embeds: list[torch.Tensor] | None = None

Functions:

fastvideo.pipelines.basic.flux_2.flux_2_timestep_preparation

Flux2-specific timestep preparation.

Classes

fastvideo.pipelines.basic.flux_2.flux_2_timestep_preparation.Flux2TimestepPreparationStage
Flux2TimestepPreparationStage(scheduler)

Bases: TimestepPreparationStage

Flux2 timestep preparation matching the diffusers Flux2 timestep schedule.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    """Store the scheduler used by the timestep preparation stage.

    Args:
        scheduler: Scheduler instance, kept on ``self.scheduler``.
    """
    # Run the base-class initialiser first, as the other stages do, so any
    # state the base stage sets up is not silently skipped.
    super().__init__()
    self.scheduler = scheduler

Functions:

fastvideo.pipelines.basic.flux_2.flux_2_timestep_preparation.compute_empirical_mu
compute_empirical_mu(image_seq_len: int, num_steps: int) -> float

Resolution-dependent mu for the Flux2 flow-match scheduler, ported from `sampling.compute_empirical_mu` in the official Black Forest Labs flux2 repository.

Source code in fastvideo/pipelines/basic/flux_2/flux_2_timestep_preparation.py
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """
    Resolution-dependent mu for Flux2 flow-match scheduler.
    From Black Forest Labs flux2 official repo: sampling.compute_empirical_mu.

    Two linear functions of ``image_seq_len`` give mu at 10 and at 200
    sampling steps. Mu for ``num_steps`` is the line through those two
    points, so it equals the 10-step value at 10 steps and the 200-step value
    at 200 steps. For sequences longer than 4300 tokens, the 200-step value
    is returned regardless of ``num_steps``.
    """
    slope_10, intercept_10 = 8.73809524e-05, 1.89833333
    slope_200, intercept_200 = 0.00016927, 0.45666666

    mu_at_200_steps = slope_200 * image_seq_len + intercept_200
    if image_seq_len > 4300:
        return float(mu_at_200_steps)

    mu_at_10_steps = slope_10 * image_seq_len + intercept_10
    step_slope = (mu_at_200_steps - mu_at_10_steps) / 190.0
    step_intercept = mu_at_200_steps - 200.0 * step_slope
    return float(step_slope * num_steps + step_intercept)

fastvideo.pipelines.basic.flux_2.presets

Flux2 model family pipeline presets.

Each preset is a named inference configuration that declares the user-facing stage topology, default sampling values, and which per-stage overrides are allowed. Presets are registered explicitly from `fastvideo.registry._register_presets`.

Classes