Skip to content

lora

Training-side LoRA utilities for fastvideo.train model plugins.

Classes

fastvideo.train.utils.lora.LoraConfig dataclass

LoraConfig(enable: bool = False, rank: int | None = None, alpha: int | None = None, target_modules: list[str] | None = None)

Structured LoRA settings for one fastvideo.train model role.

Parsed from the nested `models.<role>.lora` YAML block:

lora:
  enable: true                       # default false
  rank: 16
  alpha: 32                          # defaults to rank when omitted
  target_modules: [to_q, to_k, to_v, to_out]

`enable` is an explicit on/off switch, so a config states its intent plainly: the presence of `rank` alone never silently flips a run into LoRA-only training. When `enable` is false, any `rank` that is still present is ignored (with an INFO log), so a configured-but-disabled block is valid.

Functions

fastvideo.train.utils.lora.LoraConfig.coerce classmethod
coerce(obj: LoraConfig | dict[str, Any] | None) -> LoraConfig | None

Normalize a raw YAML mapping (or existing config) into a LoraConfig.

Returns None when no lora block was given, which callers treat as "LoRA not configured" — identical in effect to enable: false.

Source code in fastvideo/train/utils/lora.py
@classmethod
def coerce(
    cls,
    obj: LoraConfig | dict[str, Any] | None,
) -> LoraConfig | None:
    """Normalize a raw YAML mapping (or existing config) into a LoraConfig.

    Returns ``None`` when no ``lora`` block was given, which callers treat
    as "LoRA not configured" — identical in effect to ``enable: false``.

    An existing ``LoraConfig`` is returned unchanged. For a ``dict``, keys
    outside ``_LORA_CONFIG_KEYS`` are logged as a warning and ignored, and
    ``rank``, ``alpha`` and ``target_modules`` are passed through as-is
    (``None`` when absent).

    ``enable`` defaults to ``False``. Non-string values go through ``bool()``.
    Strings are parsed explicitly, ignoring case and surrounding whitespace:
    ``"true"/"yes"/"on"/"1"`` mean on and ``"false"/"no"/"off"/"0"/""``
    mean off. This matters because ``bool("false")`` is ``True``, which
    would silently turn LoRA on for a quoted YAML value or a string CLI
    override.

    Raises:
        TypeError: If ``obj`` is not ``None``, a ``LoraConfig`` or a ``dict``.
        ValueError: If ``enable`` is a string that is not one of the
            recognized boolean spellings.
    """
    if obj is None:
        return None
    if isinstance(obj, LoraConfig):
        return obj
    if not isinstance(obj, dict):
        raise TypeError("models.<role>.lora must be a mapping or LoraConfig, got "
                        f"{type(obj).__name__}")
    unknown = set(obj) - set(_LORA_CONFIG_KEYS)
    if unknown:
        logger.warning("LoraConfig: ignoring unrecognized lora keys %s "
                       "(valid keys: %s)", sorted(unknown), list(_LORA_CONFIG_KEYS))

    # ``enable`` is meant to be an explicit switch. Never let a string such as
    # "false" become True through plain truthiness.
    raw_enable = obj.get("enable", False)
    if isinstance(raw_enable, str):
        spelled = raw_enable.strip().lower()
        if spelled in ("true", "yes", "on", "1"):
            enable = True
        elif spelled in ("false", "no", "off", "0", ""):
            enable = False
        else:
            raise ValueError("models.<role>.lora.enable must be a boolean, got "
                             f"{raw_enable!r}")
    else:
        enable = bool(raw_enable)

    return cls(
        enable=enable,
        rank=obj.get("rank"),
        alpha=obj.get("alpha"),
        target_modules=obj.get("target_modules"),
    )

Functions

fastvideo.train.utils.lora.enable_lora_training

enable_lora_training(transformer: Module, *, lora_rank: int, lora_alpha: int | None = None, lora_target_modules: Sequence[str] | None = None) -> int

Replace supported linear layers with trainable LoRA wrappers.

Returns the number of layers converted to LoRA.

Source code in fastvideo/train/utils/lora.py
def enable_lora_training(
    transformer: torch.nn.Module,
    *,
    lora_rank: int,
    lora_alpha: int | None = None,
    lora_target_modules: Sequence[str] | None = None,
) -> int:
    """Replace supported linear layers with trainable LoRA wrappers.

    The whole ``transformer`` is first frozen with ``requires_grad_(False)``.
    Every named submodule accepted by ``_is_target_layer`` (and not rejected
    by ``_is_excluded_layer`` against
    ``transformer.config.arch_config.exclude_lora_layers``) is then offered to
    ``get_lora_layer(..., training_mode=True)``. Submodules for which that
    returns ``None`` are skipped; the rest are swapped in place via
    ``replace_submodule``. Finally ``_replicate_lora_parameters`` is called
    and the model is put into ``train()`` mode.

    Args:
        transformer: Model to modify in place.
        lora_rank: LoRA rank. Converted with ``int()`` and must be > 0.
        lora_alpha: LoRA alpha. Defaults to the rank when ``None``.
        lora_target_modules: Target module names. A single non-empty string
            is treated as one name. ``None`` or an empty value falls back to
            ``DEFAULT_LORA_TARGET_MODULES``.

    Returns:
        The number of layers converted to LoRA.

    Raises:
        ValueError: If ``lora_rank`` is not positive, or if no LoRA-compatible
            layer matches the requested target modules. In the latter case
            the model has already been frozen when the error is raised.
    """

    rank = int(lora_rank)
    if rank <= 0:
        raise ValueError(f"lora_rank must be > 0, got {lora_rank!r}")

    alpha = int(lora_alpha) if lora_alpha is not None else rank

    # A bare string is itself a Sequence[str]: list("to_q") would yield the
    # single-character names ['t', 'o', '_', 'q']. Treat it as one name.
    if isinstance(lora_target_modules, str) and lora_target_modules:
        target_modules = [lora_target_modules]
    else:
        target_modules = list(lora_target_modules or DEFAULT_LORA_TARGET_MODULES)

    arch_config = getattr(getattr(transformer, "config", None), "arch_config", None)
    # The attribute may be missing OR explicitly None; both mean
    # "nothing excluded". A bare string is wrapped for the same reason as above.
    excluded = getattr(arch_config, "exclude_lora_layers", None) or []
    excluded_modules = [excluded] if isinstance(excluded, str) else list(excluded)

    transformer.requires_grad_(False)

    # Collect first, replace afterwards: mutating the module tree while
    # named_modules() is iterating over it is unsafe.
    replacements: list[tuple[str, BaseLayerWithLoRA]] = []
    for module_name, module in transformer.named_modules():
        if not module_name:
            continue  # the root module itself is never wrapped
        if not _is_target_layer(module_name, target_modules):
            continue
        if _is_excluded_layer(module_name, excluded_modules):
            continue

        lora_layer = get_lora_layer(
            module,
            lora_rank=rank,
            lora_alpha=alpha,
            training_mode=True,
        )
        if lora_layer is None:
            continue  # layer type not supported by get_lora_layer
        replacements.append((module_name, lora_layer))

    if not replacements:
        raise ValueError("No LoRA-compatible layers were found for the requested "
                         f"target modules: {target_modules}")

    for module_name, lora_layer in replacements:
        replace_submodule(transformer, module_name, lora_layer)

    _replicate_lora_parameters(transformer)
    transformer.train()

    logger.info(
        "Enabled LoRA training with rank=%d alpha=%d on %d layers",
        rank,
        alpha,
        len(replacements),
    )
    return len(replacements)