Benchmark for model weight loading speed.
Measures the time to load model weights from safetensors files using
different strategies (CPU vs GPU, broadcast vs independent).
Usage (single GPU):
python fastvideo/models/loader/benchmarks/benchmark_weight_loading.py --model-path /path/to/model
Usage (multi-GPU, e.g. 4 GPUs):
torchrun --nproc_per_node=4 fastvideo/models/loader/benchmarks/benchmark_weight_loading.py --model-path /path/to/model
Functions
fastvideo.models.loader.benchmarks.benchmark_weight_loading.benchmark_loading
Run the weight-loading benchmark and log the results (on rank 0 only).
Source code in fastvideo/models/loader/benchmarks/benchmark_weight_loading.py
def benchmark_loading(
    files: list[str],
    to_cpu: bool,
    broadcast: bool,
    warmup: int,
    repeats: int,
    label: str,
) -> None:
    """Run the weight loading benchmark and log the results (rank 0 only).

    Makes one untimed pass over ``safetensors_weights_iterator(files, ...)``
    to count tensors and bytes, then ``warmup`` untimed passes, then
    ``repeats`` timed passes. Average time, best time and throughput are
    logged on rank 0.

    Args:
        files: safetensors file paths handed to the iterator.
        to_cpu: forwarded unchanged to ``safetensors_weights_iterator``.
        broadcast: forwarded unchanged to ``safetensors_weights_iterator``.
        warmup: number of untimed passes before measuring.
        repeats: number of timed passes; must be at least 1.
        label: tag prefixed to every log line.

    Raises:
        ValueError: if ``repeats`` is less than 1 (average and best time
            would be undefined). Raised on every rank before any collective
            call, so no rank is left waiting in a barrier.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    # Total number of ranks; with torchrun (see usage) that is one process
    # per GPU, which is what the "GPU(s)" field in the summary reports.
    world_size = dist.get_world_size() if distributed else 1
    node_group = get_node_group()

    def _drain() -> None:
        """Consume one full pass of the weights iterator, discarding tensors."""
        for _ in safetensors_weights_iterator(
                files, to_cpu=to_cpu, broadcast=broadcast):
            pass

    def _barrier() -> None:
        """Block until every rank arrives (no-op when not distributed)."""
        if distributed:
            dist.barrier()

    def _cuda_sync() -> None:
        """Wait for queued CUDA work so it is included in the timing."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Untimed counting pass: number of tensors and their total byte size.
    num_tensors = 0
    total_bytes = 0
    for _, tensor in safetensors_weights_iterator(
            files, to_cpu=to_cpu, broadcast=broadcast):
        num_tensors += 1
        total_bytes += tensor.nelement() * tensor.element_size()
    if rank == 0:
        logger.info("[%s] %d tensors, %.2f GB total",
                    label, num_tensors, total_bytes / 1e9)

    # Warmup
    for _ in range(warmup):
        _drain()
    _barrier()

    # Timed runs. Barriers bracket each run so that every rank starts
    # together and stops only once all ranks are done; the measured time is
    # therefore approximately that of the slowest rank.
    times = []
    for _ in range(repeats):
        _barrier()
        _cuda_sync()
        t0 = time.perf_counter()
        _drain()
        _cuda_sync()
        _barrier()
        times.append(time.perf_counter() - t0)

    if rank == 0:
        avg = sum(times) / len(times)
        best = min(times)
        # total_bytes is what this rank's iterator yielded, so this is the
        # per-rank throughput.
        throughput = total_bytes / avg / 1e9
        logger.info(
            "[%s] avg %.3fs | best %.3fs | throughput %.2f GB/s "
            "(over %d runs, %d warmup, %d GPU(s), node_size=%d)",
            label, avg, best, throughput,
            repeats, warmup, world_size, node_group.world_size,
        )
|