load_module_from_path(*, model_path: str, module_type: str, training_config: TrainingConfig, disable_custom_init_weights: bool = False, override_transformer_cls_name: str | None = None, transformer_override_safetensor: str | None = None) -> Module
Load a single pipeline component module.
Accepts a `TrainingConfig` and internally builds the
`TrainingArgs` needed by `PipelineComponentLoader`.
Source code in fastvideo/train/utils/moduleloader.py
def load_module_from_path(
    *,
    model_path: str,
    module_type: str,
    training_config: TrainingConfig,
    disable_custom_init_weights: bool = False,
    override_transformer_cls_name: str | None = None,
    transformer_override_safetensor: str | None = None,
) -> torch.nn.Module:
    """Load a single pipeline component module.

    Builds the ``TrainingArgs`` needed by ``PipelineComponentLoader`` from
    ``training_config``. Resolves ``model_path`` to a local directory with
    ``maybe_download_model``, checks it with
    ``verify_model_config_and_directory``, then loads the ``module_type``
    sub-directory of that model.

    Args:
        model_path: Model identifier or path handed to
            ``maybe_download_model``.
        module_type: Key looked up in the model config; also the name of the
            component sub-directory that is loaded.
        training_config: Config used to build the per-call ``TrainingArgs``.
        disable_custom_init_weights: When True, ``_loading_teacher_critic_model``
            is set on the args for the duration of the load and deleted
            afterwards. (Presumably this makes the loader skip custom weight
            init — confirm in ``PipelineComponentLoader``.)
        override_transformer_cls_name: When not None, replaces
            ``override_transformer_cls_name`` on the args during the load.
            Afterwards the previous value is put back; if there was no
            previous (non-None) value, it is reset to None.
        transformer_override_safetensor: When truthy, stored as
            ``init_weights_from_safetensors`` on the args. This one is not
            reverted after loading.

    Returns:
        The loaded component, guaranteed to be a ``torch.nn.Module``.

    Raises:
        ValueError: ``module_type`` is absent from the model config, or its
            entry is None.
        TypeError: The loader returned something that is not a
            ``torch.nn.Module``.
    """
    args: Any = _make_training_args(training_config, model_path=model_path)
    resolved_path = maybe_download_model(model_path)
    model_config = verify_model_config_and_directory(resolved_path)

    # Validate the config entry before touching the filesystem layout.
    if module_type not in model_config:
        raise ValueError(f"Module {module_type!r} not found in "
                         f"config at {resolved_path}")
    entry = model_config[module_type]
    if entry is None:
        raise ValueError(f"Module {module_type!r} has null value in "
                         f"config at {resolved_path}")
    # Entry is unpacked as a pair; only the first element is used here.
    library_kind, _ = entry
    component_dir = os.path.join(resolved_path, module_type)

    # Apply per-call overrides to the args, remembering what must be undone.
    overriding_cls = override_transformer_cls_name is not None
    previous_cls_name: str | None = None
    if overriding_cls:
        previous_cls_name = getattr(args, "override_transformer_cls_name", None)
        args.override_transformer_cls_name = str(override_transformer_cls_name)
    if transformer_override_safetensor:
        args.init_weights_from_safetensors = str(transformer_override_safetensor)
    if disable_custom_init_weights:
        args._loading_teacher_critic_model = True

    try:
        loaded = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_dir,
            transformers_or_diffusers=library_kind,
            fastvideo_args=args,
        )
    finally:
        # Undo the temporary flags even if loading raised.
        if disable_custom_init_weights and hasattr(args, "_loading_teacher_critic_model"):
            del args._loading_teacher_critic_model
        if overriding_cls:
            if previous_cls_name is not None:
                args.override_transformer_cls_name = previous_cls_name
            elif hasattr(args, "override_transformer_cls_name"):
                args.override_transformer_cls_name = None

    if isinstance(loaded, torch.nn.Module):
        return loaded
    raise TypeError(f"Loaded {module_type!r} is not a "
                    f"torch.nn.Module: {type(loaded)}")
|