Skip to content

moduleloader

Classes

Functions

fastvideo.train.utils.moduleloader.load_module_from_path

load_module_from_path(*, model_path: str, module_type: str, training_config: TrainingConfig, disable_custom_init_weights: bool = False, override_transformer_cls_name: str | None = None, transformer_override_safetensor: str | None = None) -> Module

Load a single pipeline component module.

Accepts a TrainingConfig and internally builds the TrainingArgs needed by PipelineComponentLoader.

Source code in fastvideo/train/utils/moduleloader.py
def load_module_from_path(
    *,
    model_path: str,
    module_type: str,
    training_config: TrainingConfig,
    disable_custom_init_weights: bool = False,
    override_transformer_cls_name: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> torch.nn.Module:
    """Load a single pipeline component module.

    Accepts a ``TrainingConfig`` and internally builds the
    ``TrainingArgs`` needed by ``PipelineComponentLoader``.

    Args:
        model_path: Model identifier or path; resolved to a local
            directory with ``maybe_download_model``.
        module_type: Key of the component in the model config. It is also
            used as the component's sub-directory name under the local
            model path.
        training_config: Configuration used to build the ``TrainingArgs``
            handed to ``PipelineComponentLoader.load_module``.
        disable_custom_init_weights: When true, the private
            ``_loading_teacher_critic_model`` flag is set on the args for
            the duration of the load and removed afterwards (presumably
            consulted by the loader to skip custom weight init — confirm).
        override_transformer_cls_name: If not ``None``, its ``str()`` is
            set as ``override_transformer_cls_name`` on the args during the
            load; the previous value is restored afterwards.
        transformer_override_safetensor: If non-empty, its ``str()`` is
            set as ``init_weights_from_safetensors`` on the args (this one
            is not reverted after loading).

    Returns:
        The loaded component, guaranteed to be a ``torch.nn.Module``.

    Raises:
        ValueError: If ``module_type`` is absent from the config, maps to
            ``None``, or its entry is not a two-element
            ``(transformers_or_diffusers, architecture)`` pair.
        TypeError: If the loaded object is not a ``torch.nn.Module``.
    """
    fastvideo_args: Any = _make_training_args(training_config, model_path=model_path)

    local_model_path = maybe_download_model(model_path)
    config = verify_model_config_and_directory(local_model_path)

    if module_type not in config:
        raise ValueError(f"Module {module_type!r} not found in "
                         f"config at {local_model_path}")

    module_info = config[module_type]
    if module_info is None:
        raise ValueError(f"Module {module_type!r} has null value in "
                         f"config at {local_model_path}")

    # Each config entry is expected to be a (library, architecture) pair.
    # Report a malformed entry with the module name and config location
    # rather than a bare "cannot unpack" error.
    try:
        transformers_or_diffusers, _architecture = module_info
    except (TypeError, ValueError) as err:
        raise ValueError(f"Module {module_type!r} has malformed entry "
                         f"{module_info!r} in config at {local_model_path}; "
                         "expected a (transformers_or_diffusers, architecture) "
                         "pair") from err
    component_path = os.path.join(local_model_path, module_type)

    # Remember any pre-existing override so it can be put back afterwards.
    old_override: str | None = None
    if override_transformer_cls_name is not None:
        old_override = getattr(fastvideo_args, "override_transformer_cls_name", None)
        fastvideo_args.override_transformer_cls_name = str(override_transformer_cls_name)

    if transformer_override_safetensor:
        fastvideo_args.init_weights_from_safetensors = str(transformer_override_safetensor)

    if disable_custom_init_weights:
        fastvideo_args._loading_teacher_critic_model = True
    try:
        module = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=fastvideo_args,
        )
    finally:
        # Undo the temporary flags even when loading fails.
        if disable_custom_init_weights and hasattr(fastvideo_args, "_loading_teacher_critic_model"):
            del fastvideo_args._loading_teacher_critic_model
        if override_transformer_cls_name is not None:
            if old_override is not None:
                fastvideo_args.override_transformer_cls_name = old_override
            elif hasattr(fastvideo_args, "override_transformer_cls_name"):
                fastvideo_args.override_transformer_cls_name = None

    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"Loaded {module_type!r} is not a "
                        f"torch.nn.Module: {type(module)}")
    return module

fastvideo.train.utils.moduleloader.make_inference_args

make_inference_args(tc: TrainingConfig, *, model_path: str) -> TrainingArgs

Build a TrainingArgs for inference (validation / pipelines).

Source code in fastvideo/train/utils/moduleloader.py
def make_inference_args(
    tc: TrainingConfig,
    *,
    model_path: str,
) -> TrainingArgs:
    """Return ``TrainingArgs`` prepared for inference (validation / pipelines).

    The base arguments come from ``_make_training_args``; they are then
    switched from training to inference settings.

    Args:
        tc: Training configuration the arguments are derived from; its
            ``vsa_sparsity`` value is copied onto the result.
        model_path: Model identifier or path forwarded to
            ``_make_training_args``.

    Returns:
        The adjusted args, with:

        - ``inference_mode`` set to ``True``;
        - ``mode`` set to ``ExecutionMode.INFERENCE``;
        - ``dit_cpu_offload`` set to ``True``;
        - ``VSA_sparsity`` taken from ``tc.vsa_sparsity``.
    """
    inference_args = _make_training_args(tc, model_path=model_path)

    # Switch the execution mode over from training to inference.
    inference_args.inference_mode = True
    inference_args.mode = ExecutionMode.INFERENCE

    # Inference-specific runtime knobs.
    inference_args.dit_cpu_offload = True
    inference_args.VSA_sparsity = tc.vsa_sparsity

    return inference_args