bsa_attn¶
Bidirectional Sparse Attention (BSA) backend for FastVideo.
Pure-PyTorch reference implementation from: "Bidirectional Sparse Attention for Faster Video Diffusion Training" (arXiv:2509.01085)
BSA sparsifies both queries (pruning redundant tokens per block) and key-value pairs (keeping only relevant KV blocks per query block).
This is a training-free inference backend: it works with any model trained with full attention by applying BSA sparsity at inference time.
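The KV-pruning half of this idea can be sketched as follows: pool each query/key block to a single representative vector, score query blocks against key blocks, and keep only the top-scoring KV blocks per query block. This is a minimal illustration, not FastVideo's actual selection code; the function name `select_kv_blocks`, the mean-pooling criterion, and the `keep_ratio` parameter are assumptions for exposition.

```python
import torch

def select_kv_blocks(q, k, block_size, keep_ratio=0.25):
    # q, k: [B, H, L, D]; L assumed divisible by block_size (illustrative).
    B, H, L, D = q.shape
    nb = L // block_size
    # Block-mean pooling gives one representative vector per block.
    q_blk = q.view(B, H, nb, block_size, D).mean(dim=3)
    k_blk = k.view(B, H, nb, block_size, D).mean(dim=3)
    # Block-level similarity approximates which KV blocks matter
    # for each query block.
    scores = q_blk @ k_blk.transpose(-1, -2)   # [B, H, nb, nb]
    topk = max(1, int(keep_ratio * nb))
    keep = scores.topk(topk, dim=-1).indices   # KV-block ids per query block
    return keep
```

The paper's actual selection statistic may differ; the point is that selection happens at block granularity, so the cost of scoring is `O(nb^2)` rather than `O(L^2)`.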
Classes¶
fastvideo.attention.backends.bsa_attn.BSAAttentionImpl
¶
BSAAttentionImpl(num_heads: int, head_size: int, causal: bool, softmax_scale: float, num_kv_heads: int | None = None, prefix: str = '', **extra_impl_args)
Bases: AttentionImpl
Source code in fastvideo/attention/backends/bsa_attn.py
Functions¶
fastvideo.attention.backends.bsa_attn.BSAAttentionImpl.forward
¶
BSA attention forward pass.
Input tensors are already in tile-contiguous order from preprocess_qkv.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query | Tensor | [B, L, num_heads, D] (tile-ordered) | required |
| key | Tensor | [B, L, num_heads, D] (tile-ordered) | required |
| value | Tensor | [B, L, num_heads, D] (tile-ordered) | required |
| attn_metadata | BSAAttentionMetadata | BSA metadata | required |
Returns:
| Name | Type | Description |
|---|---|---|
| output | Tensor | [B, L, num_heads, D] (tile-ordered) |
Source code in fastvideo/attention/backends/bsa_attn.py
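Once KV blocks have been selected per query block, the forward pass amounts to gathering those blocks and running dense attention inside each query block. The sketch below illustrates that mechanism with a `[B, H, L, D]` layout and a hypothetical `block_sparse_attention` helper; the real backend operates on tile-ordered `[B, L, num_heads, D]` tensors and metadata from `BSAAttentionMetadata`.

```python
import torch
import torch.nn.functional as F

def block_sparse_attention(q, k, v, keep, block_size, scale):
    # q, k, v: [B, H, L, D]; keep: [B, H, nb, topk] KV-block ids per query block.
    B, H, L, D = q.shape
    nb = L // block_size
    topk = keep.shape[-1]
    qb = q.view(B, H, nb, block_size, D)
    kb = k.view(B, H, nb, block_size, D)
    vb = v.view(B, H, nb, block_size, D)
    # Gather only the selected KV blocks for each query block.
    idx = keep[..., None, None].expand(B, H, nb, topk, block_size, D)
    k_sel = kb[:, :, None].expand(B, H, nb, nb, block_size, D).gather(3, idx)
    v_sel = vb[:, :, None].expand(B, H, nb, nb, block_size, D).gather(3, idx)
    k_sel = k_sel.reshape(B, H, nb, topk * block_size, D)
    v_sel = v_sel.reshape(B, H, nb, topk * block_size, D)
    # Dense attention within each query block over its kept KV tokens.
    attn = torch.softmax((qb @ k_sel.transpose(-1, -2)) * scale, dim=-1)
    return (attn @ v_sel).reshape(B, H, L, D)
```

When `keep` lists every KV block, this reduces exactly to full attention, which is a useful sanity check for any sparse-attention implementation.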
fastvideo.attention.backends.bsa_attn.BSAAttentionImpl.postprocess_output
¶
Reorder tokens from tile-contiguous order back to raster order.
Source code in fastvideo/attention/backends/bsa_attn.py
fastvideo.attention.backends.bsa_attn.BSAAttentionImpl.preprocess_qkv
¶
Reorder tokens from raster order to tile-contiguous order.
Source code in fastvideo/attention/backends/bsa_attn.py
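Together, `preprocess_qkv` and `postprocess_output` are an index round trip along the sequence dimension: one applies the tile-partition permutation, the other applies its inverse. A minimal illustration, where `tile_idx` is a stand-in for the cached index tensor and `rev_idx` for its inverse:

```python
import torch

B, L, H, D = 1, 8, 2, 4
x = torch.randn(B, L, H, D)                 # raster-order tokens
tile_idx = torch.tensor([0, 1, 4, 5, 2, 3, 6, 7])  # example permutation
rev_idx = torch.argsort(tile_idx)           # inverse permutation
x_tiled = x[:, tile_idx]                    # raster -> tile-contiguous
x_back = x_tiled[:, rev_idx]                # tile-contiguous -> raster
assert torch.equal(x_back, x)
```

Reordering up front lets the sparse kernel treat each tile as a contiguous slab of tokens, so block indexing becomes simple slicing.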
Functions¶
fastvideo.attention.backends.bsa_attn.get_reverse_tile_partition_indices
cached
¶
get_reverse_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: device) -> LongTensor
Inverse mapping: tile-contiguous order back to raster order.
Source code in fastvideo/attention/backends/bsa_attn.py
fastvideo.attention.backends.bsa_attn.get_tile_partition_indices
cached
¶
get_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: device) -> LongTensor
Map raster-order tokens to tile-contiguous order.
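The forward mapping can be reconstructed with a view/permute over a 3-D index grid: split each axis into (num_tiles, tile_len), then bring the tile-grid axes in front of the within-tile axes. This is a hypothetical reconstruction assuming each dimension divides evenly by its tile size; the actual implementation may handle remainders differently.

```python
import torch

def tile_partition_indices(seq_shape, tile_size):
    # seq_shape = (T, H, W) latent grid; tile_size = (tt, th, tw).
    T, H, W = seq_shape
    tt, th, tw = tile_size  # assume each dim divides evenly (illustrative)
    idx = torch.arange(T * H * W).view(T, H, W)
    # Split every axis into (num_tiles, tile_len), then order by tile first.
    idx = idx.view(T // tt, tt, H // th, th, W // tw, tw)
    idx = idx.permute(0, 2, 4, 1, 3, 5).reshape(-1)
    return idx  # idx[j] = raster position of the j-th token in tile order
```

For a `(1, 2, 4)` grid with `(1, 2, 2)` tiles this yields `[0, 1, 4, 5, 2, 3, 6, 7]`: the left 2x2 tile's tokens first, then the right tile's. The reverse mapping is just `torch.argsort` of this tensor, which is why both functions are cheap enough to cache per shape.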