
benchmark_weight_loading_comparison

A/B benchmark: independent-read vs rank-0-broadcast weight loading.

Compares two strategies:
  • "before" (independent): every rank reads every safetensors file from disk directly onto its own GPU
  • "after" (broadcast): rank 0 reads from disk and broadcasts each tensor to the other ranks

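
A quick way to see why broadcasting can pay off as the rank count grows is to count how much checkpoint data one node asks storage for under each strategy. The sketch below is a back-of-the-envelope model under simple assumptions: a single node, every rank needs the full checkpoint, and the counts are logical reads (the OS page cache may absorb repeated reads of the same file, so real disk traffic under "independent" can be lower). The function names and the 10 GiB checkpoint size are hypothetical.

```python
def disk_bytes_per_node(checkpoint_bytes: int, ranks_per_node: int, strategy: str) -> int:
    """Logical bytes requested from local storage on one node for one full load."""
    if strategy == "independent":
        # Every rank reads the whole checkpoint itself.
        return checkpoint_bytes * ranks_per_node
    if strategy == "broadcast":
        # Only rank 0 reads; the other ranks receive the data through the process group.
        return checkpoint_bytes
    raise ValueError(f"unknown strategy: {strategy!r}")


def broadcast_bytes_received(checkpoint_bytes: int, ranks_per_node: int) -> int:
    """Bytes the broadcast delivers to non-root ranks (zero when running on one GPU)."""
    return checkpoint_bytes * (ranks_per_node - 1)


if __name__ == "__main__":
    gib = 1024**3
    ckpt = 10 * gib  # hypothetical 10 GiB checkpoint, for illustration only
    for n in (1, 2, 4):
        print(
            f"{n} GPU(s): independent reads "
            f"{disk_bytes_per_node(ckpt, n, 'independent') // gib} GiB, "
            f"broadcast reads {disk_bytes_per_node(ckpt, n, 'broadcast') // gib} GiB"
        )
```

For 1, 2 and 4 GPUs this model gives 10, 20 and 40 GiB of logical reads for "independent" against a constant 10 GiB for "broadcast", which in turn moves 0, 10 and 30 GiB between GPUs.
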
Usage

1 GPU

torchrun --nproc_per_node=1 fastvideo/models/loader/benchmarks/benchmark_weight_loading_comparison.py --model-path /path/to/model --subfolder transformer

2 GPUs

torchrun --nproc_per_node=2 fastvideo/models/loader/benchmarks/benchmark_weight_loading_comparison.py --model-path /path/to/model --subfolder transformer

4 GPUs

torchrun --nproc_per_node=4 fastvideo/models/loader/benchmarks/benchmark_weight_loading_comparison.py --model-path /path/to/model --subfolder transformer
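
torchrun starts --nproc_per_node worker processes and tells each one where it sits through environment variables such as RANK, LOCAL_RANK, WORLD_SIZE and LOCAL_WORLD_SIZE. The sketch below shows how a script can read them; read_torchrun_env and LaunchInfo are hypothetical helpers, and the benchmark itself may obtain its ranks through FastVideo's own distributed utilities instead.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class LaunchInfo:
    rank: int              # global rank across all nodes
    local_rank: int        # rank within this node; typically selects cuda:<local_rank>
    world_size: int        # total number of worker processes
    local_world_size: int  # workers on this node (matches --nproc_per_node)


def read_torchrun_env(env: Optional[Mapping[str, str]] = None) -> LaunchInfo:
    """Read the rank variables torchrun exports, defaulting to a single process."""
    env = os.environ if env is None else env
    return LaunchInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
        local_world_size=int(env.get("LOCAL_WORLD_SIZE", "1")),
    )
```

Under the 2-GPU command, the worker with LOCAL_RANK=1 would use device cuda:1; running the script without torchrun falls back to a single rank.
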

Functions

fastvideo.models.loader.benchmarks.benchmark_weight_loading_comparison.load_broadcast

load_broadcast(files: list[str], device: str, node_group, async_op: bool = False)

After-PR behavior: local rank 0 of node_group reads each tensor from disk and broadcasts it to the other ranks in the group.
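
Non-zero ranks never read tensor bytes here: to allocate a receive buffer they need only each tensor's shape and dtype, and in the safetensors format those live in a small JSON header at the start of the file (an 8-byte little-endian length followed by the JSON itself). The standard-library sketch below parses that header directly to show why this step is cheap. read_safetensors_header, write_tiny_safetensors and demo_header are hypothetical helpers for illustration only; the safetensors package exposes the same metadata through safe_open(...).get_slice(name).

```python
from __future__ import annotations

import json
import os
import struct
import tempfile


def read_safetensors_header(path: str) -> dict[str, tuple[str, list[int]]]:
    """Return {tensor_name: (dtype_string, shape)} without reading any tensor data."""
    with open(path, "rb") as fh:
        # First 8 bytes: little-endian u64 giving the JSON header length.
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    return {
        name: (info["dtype"], info["shape"])
        for name, info in header.items()
        if name != "__metadata__"  # optional free-form string metadata
    }


def write_tiny_safetensors(path: str) -> None:
    """Write a one-tensor file (2x3 float32 zeros) in the safetensors layout."""
    data = struct.pack("<6f", *([0.0] * 6))
    header = {"w": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, len(data)]}}
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # pad header to an 8-byte boundary
    with open(path, "wb") as fh:
        fh.write(struct.pack("<Q", len(header_bytes)))
        fh.write(header_bytes)
        fh.write(data)


def demo_header() -> dict[str, tuple[str, list[int]]]:
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "tiny.safetensors")
        write_tiny_safetensors(path)
        return read_safetensors_header(path)
```
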

Source code in fastvideo/models/loader/benchmarks/benchmark_weight_loading_comparison.py
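import torch
import torch.distributed as dist
from safetensors import safe_open

# SAFETENSORS_TO_TORCH_DTYPE is defined elsewhere in this module: it maps the
# dtype strings returned by get_dtype() (e.g. "BF16") to torch dtypes.


def load_broadcast(files: list[str], device: str, node_group,
                   async_op: bool = False):
    """After-PR behavior: rank 0 reads from disk, broadcasts to other ranks."""
    # Rank within the node group; only local rank 0 reads tensor data from disk.
    local_rank = node_group.local_rank
    handles = []
    for st_file in files:
        with safe_open(st_file, framework="pt", device=device) as f:
            for name in f:
                if local_rank == 0:
                    # Rank 0 loads the tensor from disk directly onto `device`.
                    param = f.get_tensor(name)
                else:
                    # Other ranks read only shape/dtype from the file header and
                    # allocate an uninitialized buffer to receive the broadcast.
                    sl = f.get_slice(name)
                    shape = sl.get_shape()
                    dtype = SAFETENSORS_TO_TORCH_DTYPE[sl.get_dtype()]
                    param = torch.empty(shape, device=device, dtype=dtype)
                if node_group.world_size > 1:
                    group = node_group.device_group
                    # dist.broadcast takes a global rank as `src`, so the group's
                    # local rank 0 is mapped to its global rank.
                    if async_op:
                        handle = dist.broadcast(
                            param,
                            src=dist.get_global_rank(group, 0),
                            async_op=True,
                            group=group,
                        )
                        handles.append(handle)
                    else:
                        dist.broadcast(
                            param,
                            src=dist.get_global_rank(group, 0),
                            group=group,
                        )
                # NOTE: with async_op=True the tensor is yielded before its
                # broadcast is guaranteed to have completed; on non-zero ranks it
                # holds rank 0's data only after the wait loop below has run.
                yield name, param
        if async_op:
            # Wait for every broadcast issued for this file before moving on.
            for handle in handles:
                handle.wait()
            handles.clear()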
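
With async_op=True, load_broadcast yields each tensor as soon as its broadcast is launched and only waits on the handles once the whole file has been processed, so a consumer on a non-zero rank that reads a yielded tensor right away may see the uninitialized torch.empty buffer. One way to keep the overlap while making every yielded tensor safe is to defer the yields until the file's handles have completed. The sketch below models that ordering with the standard library only: a thread pool stands in for the asynchronous collective and a Python list stands in for the receive buffer. It illustrates the ordering under those assumptions and is not FastVideo code; all names in it are hypothetical.

```python
from __future__ import annotations

import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Iterator


def _receive(buf: list[float], value: float) -> None:
    """Stand-in for the collective: fills the receive buffer after a short delay."""
    time.sleep(0.01)
    buf[:] = [value] * len(buf)


def load_file_deferred(pool: ThreadPoolExecutor, tensors: dict[str, float],
                       numel: int = 4) -> Iterator[tuple[str, list[float]]]:
    """Yield (name, buffer) pairs for one file only after all its receives finish."""
    pending: list[tuple[str, list[float], Future]] = []
    for name, value in tensors.items():
        buf = [0.0] * numel                         # plays the role of torch.empty(...)
        handle = pool.submit(_receive, buf, value)  # like dist.broadcast(..., async_op=True)
        pending.append((name, buf, handle))
    for _, _, handle in pending:
        handle.result()                             # like handle.wait(); re-raises errors
    for name, buf, _ in pending:
        yield name, buf                             # every buffer now holds its data


def demo_deferred() -> dict[str, list[float]]:
    with ThreadPoolExecutor(max_workers=4) as pool:
        return dict(load_file_deferred(pool, {"a": 1.0, "b": 2.0}))
```

Applied to load_broadcast, this would mean collecting (name, param, handle) per file and yielding only after the wait loop; the trade-off is that the first tensor of each file reaches the consumer later, and every tensor of the file stays referenced until it is yielded.
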

fastvideo.models.loader.benchmarks.benchmark_weight_loading_comparison.load_independent

load_independent(files: list[str], device: str)

Before-PR behavior: every rank reads every tensor from disk to GPU.

Source code in fastvideo/models/loader/benchmarks/benchmark_weight_loading_comparison.py
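from safetensors import safe_open


def load_independent(files: list[str], device: str):
    """Before-PR behavior: every rank reads every tensor from disk to GPU."""
    for st_file in files:
        # Each rank opens every file itself; there is no inter-rank communication.
        with safe_open(st_file, framework="pt", device=device) as f:
            for name in f:
                # Read the full tensor from disk directly onto `device`.
                param = f.get_tensor(name)
                yield name, param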
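
Both loaders are generators, so nothing is read or broadcast until they are iterated; a fair A/B timing therefore has to drain the generator completely, and on GPU it should synchronize before stopping the clock so queued work (including async broadcasts) is counted. The sketch below is a hypothetical timing helper under those assumptions, not the benchmark's actual timing code, which is not shown here.

```python
import time
from typing import Any, Callable, Iterator, Tuple


def time_loader(make_iter: Callable[[], Iterator[Tuple[str, Any]]],
                sync: Callable[[], None] = lambda: None) -> Tuple[float, int]:
    """Drain a (name, tensor) generator and return (elapsed_seconds, tensor_count).

    Pass e.g. torch.cuda.synchronize as `sync` so that pending GPU work is
    finished both before the timer starts and before it stops.
    """
    sync()
    start = time.perf_counter()
    count = 0
    for _name, _tensor in make_iter():
        count += 1
    sync()
    return time.perf_counter() - start, count
```

A call would look like time_loader(lambda: load_independent(files, device), torch.cuda.synchronize), with the same pattern wrapped around load_broadcast for the other arm of the comparison.
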