Skip to content

weight_utils

Utilities for downloading and initializing model weights.

Functions:

fastvideo.models.loader.weight_utils.default_weight_loader

default_weight_loader(param: Tensor, loaded_weight: Tensor) -> None

Default weight loader.

Source code in fastvideo/models/loader/weight_utils.py
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise

fastvideo.models.loader.weight_utils.enable_hf_transfer

enable_hf_transfer() -> None

automatically activates hf_transfer

Source code in fastvideo/models/loader/weight_utils.py
def enable_hf_transfer() -> None:
    """automatically activates hf_transfer
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa
            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass

fastvideo.models.loader.weight_utils.filter_files_not_needed_for_inference

filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]

Exclude files that are not needed for inference.

See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233

Source code in fastvideo/models/loader/weight_utils.py
def filter_files_not_needed_for_inference(
        hf_weights_files: list[str]) -> list[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files
        if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files

fastvideo.models.loader.weight_utils.maybe_remap_kv_scale_name

maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None

Remap the name of FP8 k/v_scale parameters.

This function handles the remapping of FP8 k/v_scale parameter names. It detects if the given name ends with a suffix and attempts to remap it to the expected name format in the model. If the remapped name is not found in the params_dict, a warning is printed and None is returned.

Parameters:

Name Type Description Default
name str

The original loaded checkpoint parameter name.

required
params_dict dict

Dictionary containing the model's named parameters.

required

Returns:

Name Type Description
str str | None

The remapped parameter name if successful, or the original name if no remapping is needed.

None str | None

If the remapped name is not found in params_dict.

Source code in fastvideo/models/loader/weight_utils.py
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
             if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale")
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            logger.warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded.")
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo_scale_name in name
                   for mo_scale_name in modelopt_scale_names):
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                logger.warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
                    "not loaded.")
                return None
            return remapped_name

    # If there were no matches, return the untouched param name
    return name

fastvideo.models.loader.weight_utils.pt_weights_iterator

pt_weights_iterator(hf_weights_files: list[str], to_cpu: bool = False, broadcast: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model bin/pt files.

Parameters:

Name Type Description Default
hf_weights_files list[str]

List of bin/pt files to load.

required
to_cpu bool

Whether to load the weights to CPU.

False
broadcast bool

Accepted for API symmetry. PT weights are loaded through torch.load and do not use the safetensors broadcast path.

True
Source code in fastvideo/models/loader/weight_utils.py
def pt_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = False,
    broadcast: bool = True
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files.

    Args:
        hf_weights_files: List of bin/pt files to load.
        to_cpu: Whether to load the weights to CPU.
        broadcast: Accepted for API symmetry. PT weights are loaded through
            torch.load and do not use the safetensors broadcast path.
    """
    node_group = _get_initialized_node_group()
    local_rank = node_group.local_rank if node_group is not None else int(
        os.environ.get("LOCAL_RANK", 0))
    device = str(parallel_state.get_local_torch_device()) if not to_cpu else "cpu"
    enable_tqdm = not torch.distributed.is_initialized() or local_rank == 0
    for bin_file in tqdm(
            hf_weights_files,
            desc="Loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location=device, weights_only=True)
        yield from state.items()
        del state

fastvideo.models.loader.weight_utils.resolve_safetensors_files

resolve_safetensors_files(model_path: str) -> list[str]

Discover safetensors files in a model directory.

Source code in fastvideo/models/loader/weight_utils.py
def resolve_safetensors_files(model_path: str) -> list[str]:
    """Discover safetensors files in a model directory."""
    files = sorted(
        glob.glob(os.path.join(model_path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(
            f"No .safetensors files found in {model_path}")
    index_file = os.path.join(
        model_path, SAFE_WEIGHTS_INDEX_NAME)
    if os.path.exists(index_file):
        files = filter_duplicate_safetensors_files(
            files, model_path, SAFE_WEIGHTS_INDEX_NAME)
    return files

fastvideo.models.loader.weight_utils.safetensors_weights_iterator

safetensors_weights_iterator(hf_weights_files: list[str], to_cpu: bool = False, broadcast: bool = True, async_broadcast: bool = False) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensor files. Args: hf_weights_files: List of safetensor files to load. to_cpu: Whether to load the weights to CPU. If False, will load to the GPU device bound to the current process. broadcast: Whether local rank 0 should read GPU weights and broadcast them to the other local ranks. async_broadcast: Whether to overlap loading from disk and broadcasting to other ranks. If True, must iterate over all the weights before use. Only used when broadcast is True and to_cpu is False.

Source code in fastvideo/models/loader/weight_utils.py
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = False,
    broadcast: bool = True,
    async_broadcast: bool = False
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files.
    Args:
        hf_weights_files: List of safetensor files to load.
        to_cpu: Whether to load the weights to CPU. If False, will load to the GPU device bound to the current
            process.
        broadcast: Whether local rank 0 should read GPU weights and broadcast them to the other local ranks.
        async_broadcast: Whether to overlap loading from disk and broadcasting to other ranks. If True,
            must iterate over all the weights before use. Only used when broadcast is True and to_cpu is False.
    """
    node_group = _get_initialized_node_group()
    local_rank = node_group.local_rank if node_group is not None else int(
        os.environ.get("LOCAL_RANK", 0))
    device = str(parallel_state.get_local_torch_device()) if not to_cpu else "cpu"
    enable_tqdm = not torch.distributed.is_initialized() or local_rank == 0
    if to_cpu or not broadcast or node_group is None:
        async_broadcast = False

    handles = []
    for st_file in tqdm(
            hf_weights_files,
            desc="Loading safetensors checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        with safe_open(st_file, framework="pt", device=device) as f:
            for name in f.keys():  # noqa: SIM118
                if to_cpu:
                    param = f.get_tensor(name)
                elif broadcast and node_group is not None:
                    if local_rank == 0:
                        param = f.get_tensor(name)
                    else:
                        sl = f.get_slice(name)
                        shape = sl.get_shape()
                        dtype = SAFETENSORS_TO_TORCH_DTYPE[sl.get_dtype()]
                        param = torch.empty(shape, device=device, dtype=dtype)
                    # broadcast to local ranks
                    # TODO(Wenxuan): scatter instead of broadcast
                    if node_group.world_size > 1:
                        group = node_group.device_group
                        if async_broadcast:
                            handle = dist.broadcast(param,
                                                    src=dist.get_global_rank(
                                                        group, 0),
                                                    async_op=True,
                                                    group=group)
                            handles.append(handle)
                        else:
                            dist.broadcast(param,
                                           src=dist.get_global_rank(group, 0),
                                           group=group)
                else:
                    param = f.get_tensor(name)
                yield name, param

        if async_broadcast:
            for handle in handles:
                handle.wait()
            handles.clear()

Modules