Skip to content

benchmark_weight_loading

Benchmark for model weight loading speed.

Measures the time to load model weights from safetensors files using different strategies (CPU vs GPU, broadcast vs independent).

Usage (single GPU): python fastvideo/models/loader/benchmarks/benchmark_weight_loading.py --model-path /path/to/model

Usage (multi-GPU, e.g. 4 GPUs): torchrun --nproc_per_node=4 fastvideo/models/loader/benchmarks/benchmark_weight_loading.py --model-path /path/to/model

Functions

fastvideo.models.loader.benchmarks.benchmark_weight_loading.benchmark_loading

benchmark_loading(files: list[str], to_cpu: bool, broadcast: bool, warmup: int, repeats: int, label: str) -> None

Run the weight loading benchmark and log the results from rank 0.

Source code in fastvideo/models/loader/benchmarks/benchmark_weight_loading.py
def _maybe_barrier() -> None:
    """Barrier across all ranks; no-op when torch.distributed is not initialized."""
    if dist.is_initialized():
        dist.barrier()


def _maybe_cuda_sync() -> None:
    """Block until queued CUDA work completes; no-op when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def benchmark_loading(
    files: list[str],
    to_cpu: bool,
    broadcast: bool,
    warmup: int,
    repeats: int,
    label: str,
) -> None:
    """Run the weight loading benchmark and log results on rank 0.

    Iterates ``safetensors_weights_iterator`` over ``files`` once (untimed) to
    count tensors and bytes, then ``warmup`` more untimed passes, then
    ``repeats`` timed passes. Each timed pass is bracketed by a CUDA
    synchronize (when CUDA is available) and, when torch.distributed is
    initialized, by barriers, so each rank's measured time includes waiting
    for the slowest rank to finish.

    Args:
        files: Safetensors file paths to load.
        to_cpu: Forwarded unchanged to ``safetensors_weights_iterator``.
        broadcast: Forwarded unchanged to ``safetensors_weights_iterator``.
        warmup: Number of untimed passes before measuring.
        repeats: Number of timed passes; must be at least 1.
        label: Tag prefixed to every log line.

    Raises:
        ValueError: If ``repeats`` is less than 1. This is checked before any
            loading so that every rank fails fast and consistently, instead
            of rank 0 alone failing with a division by zero after all the work.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    # Global number of participating ranks. This is distinct from the
    # node-group size that is reported separately below.
    world_size = dist.get_world_size() if distributed else 1
    node_group = get_node_group()

    # Untimed counting pass: gather tensor count and total payload size.
    num_tensors = 0
    total_bytes = 0
    for _name, tensor in safetensors_weights_iterator(
            files, to_cpu=to_cpu, broadcast=broadcast):
        num_tensors += 1
        total_bytes += tensor.nelement() * tensor.element_size()

    if rank == 0:
        logger.info("[%s] %d tensors, %.2f GB total",
                    label, num_tensors, total_bytes / 1e9)

    # Warmup passes: exhaust the iterator without timing.
    for _ in range(warmup):
        for _ in safetensors_weights_iterator(
                files, to_cpu=to_cpu, broadcast=broadcast):
            pass
        _maybe_barrier()

    # Timed passes. Align ranks and drain outstanding GPU work before
    # starting the clock, so earlier work is not attributed to this run.
    times: list[float] = []
    for _ in range(repeats):
        _maybe_barrier()
        _maybe_cuda_sync()

        t0 = time.perf_counter()
        for _ in safetensors_weights_iterator(
                files, to_cpu=to_cpu, broadcast=broadcast):
            pass
        _maybe_cuda_sync()
        # The barrier makes the stop time reflect the slowest rank.
        _maybe_barrier()
        times.append(time.perf_counter() - t0)

    if rank == 0:
        avg = sum(times) / len(times)
        best = min(times)
        throughput = total_bytes / avg / 1e9
        logger.info(
            "[%s] avg %.3fs | best %.3fs | throughput %.2f GB/s "
            "(over %d runs, %d warmup, %d GPU(s), node_size=%d)",
            label, avg, best, throughput,
            repeats, warmup, world_size, node_group.world_size,
        )