bsa_attn

Bidirectional Sparse Attention (BSA) backend for FastVideo.

A pure-PyTorch reference implementation of the method from "Bidirectional Sparse Attention for Faster Video Diffusion Training" (arXiv:2509.01085).

BSA sparsifies both queries (pruning redundant tokens per block) and key-value pairs (keeping only relevant KV blocks per query block).

This is a training-free inference backend: it works with any model trained with full attention by applying BSA sparsity at inference time.
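The KV side of this sparsity can be pictured with a small self-contained sketch. The function below is hypothetical (the backend's actual _select_kv_blocks helper is not shown on this page); it illustrates the cumulative-threshold idea behind the kv_cumulative_threshold and min_kv_blocks metadata fields used further down:

import torch

def select_kv_blocks_sketch(q_blocks, k_blocks, cum_threshold=0.9, min_blocks=1):
    # q_blocks, k_blocks: [num_blocks, block_size, D]
    # Score each KV block against each query block via mean-pooled dot products.
    q_mean = q_blocks.mean(dim=1)  # [num_blocks, D]
    k_mean = k_blocks.mean(dim=1)  # [num_blocks, D]
    scores = torch.softmax(q_mean @ k_mean.T / q_mean.shape[-1] ** 0.5, dim=-1)
    sorted_scores, order = scores.sort(dim=-1, descending=True)
    cum = sorted_scores.cumsum(dim=-1)
    # Keep a block if the mass accumulated *before* it is still below the threshold.
    keep_sorted = (cum - sorted_scores) < cum_threshold
    keep_sorted[:, :min_blocks] = True  # always keep at least min_blocks per query block
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, order, keep_sorted)
    return mask  # [num_q_blocks, num_kv_blocks]; True = this KV block is attended

The real helper may score and threshold differently; this only conveys the shape of the computation.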

Classes

fastvideo.attention.backends.bsa_attn.BSAAttentionImpl

BSAAttentionImpl(num_heads: int, head_size: int, causal: bool, softmax_scale: float, num_kv_heads: int | None = None, prefix: str = '', **extra_impl_args)

Bases: AttentionImpl

Source code in fastvideo/attention/backends/bsa_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    causal: bool,
    softmax_scale: float,
    num_kv_heads: int | None = None,
    prefix: str = "",
    **extra_impl_args,
) -> None:
    self.prefix = prefix
    self.num_heads = num_heads
    self.head_size = head_size
    if num_kv_heads is not None and num_kv_heads != num_heads:
        raise ValueError("BSA backend does not support grouped-query attention")
    if causal:
        raise ValueError("BSA backend is bidirectional; causal=True is unsupported")
    if softmax_scale is not None:
        expected_scale = 1.0 / math.sqrt(self.head_size)
        if not math.isclose(softmax_scale, expected_scale, rel_tol=1e-4, abs_tol=1e-5):
            raise ValueError("softmax_scale must be default (1/sqrt(d)) for BSA")
    try:
        sp_group = get_sp_group()
        self.sp_size = sp_group.world_size
    except (AssertionError, RuntimeError):
        self.sp_size = 1
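Given these checks, a minimal hand-construction must pass causal=False and the default softmax scale (illustrative values; in practice the backend is constructed by FastVideo's attention layer, not by hand):

import math
from fastvideo.attention.backends.bsa_attn import BSAAttentionImpl

head_size = 64
impl = BSAAttentionImpl(
    num_heads=16,
    head_size=head_size,
    causal=False,                              # BSA is bidirectional; True raises ValueError
    softmax_scale=1.0 / math.sqrt(head_size),  # anything else raises ValueError
    num_kv_heads=16,                           # must equal num_heads; GQA is unsupported
)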

Functions

fastvideo.attention.backends.bsa_attn.BSAAttentionImpl.forward
forward(query: Tensor, key: Tensor, value: Tensor, attn_metadata: BSAAttentionMetadata) -> Tensor

BSA attention forward pass.

Input tensors are already in tile-contiguous order from preprocess_qkv.

Parameters:

    query (Tensor): [B, L, num_heads, D] (tile-ordered). Required.
    key (Tensor): [B, L, num_heads, D] (tile-ordered). Required.
    value (Tensor): [B, L, num_heads, D] (tile-ordered). Required.
    attn_metadata (BSAAttentionMetadata): BSA metadata. Required.

Returns:

    output (Tensor): [B, L, num_heads, D] (tile-ordered)

Source code in fastvideo/attention/backends/bsa_attn.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: BSAAttentionMetadata,
) -> torch.Tensor:
    """
    BSA attention forward pass.

    Input tensors are already in tile-contiguous order from preprocess_qkv.

    Args:
        query: [B, L, num_heads, D] (tile-ordered)
        key:   [B, L, num_heads, D] (tile-ordered)
        value: [B, L, num_heads, D] (tile-ordered)
        attn_metadata: BSA metadata

    Returns:
        output: [B, L, num_heads, D] (tile-ordered)
    """
    B, L, H, D = query.shape
    block_size = attn_metadata.block_size
    num_blocks = attn_metadata.num_blocks
    assert num_blocks * block_size == L, "Sequence length must match tiling"

    # Reshape to [B, H, L, D] for attention computation
    q = query.transpose(1, 2).contiguous()  # [B, H, L, D]
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()

    # Reshape into blocks: [B, H, num_blocks, block_size, D]
    q_blocks = q.view(B, H, num_blocks, block_size, D)
    k_blocks = k.view(B, H, num_blocks, block_size, D)
    v_blocks = v.view(B, H, num_blocks, block_size, D)

    # --- Query sparsification ---
    sparse_q, keep_indices, keep_size = _prune_queries(q_blocks, attn_metadata.query_keep_ratio)

    # --- KV block selection ---
    kv_mask = _select_kv_blocks(
        sparse_q,
        k_blocks,
        attn_metadata.kv_cumulative_threshold,
        attn_metadata.min_kv_blocks,
    )

    # --- Sparse attention ---
    sparse_output = _compute_sparse_attention(sparse_q, k_blocks, v_blocks, kv_mask)

    # --- Reconstruct pruned positions ---
    full_output = _reconstruct_pruned(sparse_output, keep_indices, block_size)

    # Reshape back: [B, H, num_blocks, block_size, D] -> [B, H, L, D] -> [B, L, H, D]
    hidden_states = full_output.view(B, H, L, D).transpose(1, 2)

    return hidden_states
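The helpers _prune_queries and _reconstruct_pruned are not reproduced on this page. A rough sketch of the query side, assuming a block-mean saliency score and zero-fill reconstruction (both are assumptions; the actual helpers may rank and fill differently), could look like:

import torch

def prune_queries_sketch(q_blocks, keep_ratio=0.5):
    # q_blocks: [B, H, num_blocks, block_size, D]
    B, H, nb, bs, D = q_blocks.shape
    keep = max(1, int(bs * keep_ratio))
    # Rank tokens in each block by similarity to the block mean (one possible saliency).
    sal = (q_blocks * q_blocks.mean(dim=3, keepdim=True)).sum(dim=-1)  # [B, H, nb, bs]
    keep_idx = sal.topk(keep, dim=-1).indices                          # [B, H, nb, keep]
    sparse_q = q_blocks.gather(3, keep_idx.unsqueeze(-1).expand(-1, -1, -1, -1, D))
    return sparse_q, keep_idx

def reconstruct_sketch(sparse_out, keep_idx, block_size):
    # Scatter attended outputs back into full blocks; pruned slots are zero-filled
    # here, whereas the real helper may use a smarter fill strategy.
    B, H, nb, keep, D = sparse_out.shape
    full = sparse_out.new_zeros(B, H, nb, block_size, D)
    full.scatter_(3, keep_idx.unsqueeze(-1).expand(-1, -1, -1, -1, D), sparse_out)
    return full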
fastvideo.attention.backends.bsa_attn.BSAAttentionImpl.postprocess_output
postprocess_output(output: Tensor, attn_metadata: BSAAttentionMetadata) -> Tensor

Reorder tokens from tile-contiguous order back to raster order.

Source code in fastvideo/attention/backends/bsa_attn.py
def postprocess_output(
    self,
    output: torch.Tensor,
    attn_metadata: BSAAttentionMetadata,
) -> torch.Tensor:
    """Reorder tokens from tile-contiguous order back to raster order."""
    return output[:, attn_metadata.reverse_tile_partition_indices]
fastvideo.attention.backends.bsa_attn.BSAAttentionImpl.preprocess_qkv
preprocess_qkv(qkv: Tensor, attn_metadata: BSAAttentionMetadata) -> Tensor

Reorder tokens from raster order to tile-contiguous order.

Source code in fastvideo/attention/backends/bsa_attn.py
def preprocess_qkv(
    self,
    qkv: torch.Tensor,
    attn_metadata: BSAAttentionMetadata,
) -> torch.Tensor:
    """Reorder tokens from raster order to tile-contiguous order."""
    # qkv: [B, L, num_heads, D]
    return qkv[:, attn_metadata.tile_partition_indices]

Functions

fastvideo.attention.backends.bsa_attn.get_reverse_tile_partition_indices cached

get_reverse_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: device) -> LongTensor

Inverse mapping: tile-contiguous order back to raster order.

Source code in fastvideo/attention/backends/bsa_attn.py
@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Inverse mapping: tile-contiguous order back to raster order."""
    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
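Because the reverse mapping is just the argsort of the forward one, the two permutations form an exact inverse pair, which can be checked directly:

import torch

shape, tile = (4, 8, 8), (2, 4, 4)
idx = get_tile_partition_indices(shape, tile, torch.device("cpu"))
rev = get_reverse_tile_partition_indices(shape, tile, torch.device("cpu"))

x = torch.randn(2, 4 * 8 * 8, 3, 16)      # [B, L, H, D]
assert torch.equal(x[:, idx][:, rev], x)  # tile order, then back to raster order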

fastvideo.attention.backends.bsa_attn.get_tile_partition_indices cached

get_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: device) -> LongTensor

Map raster-order tokens to tile-contiguous order.

Source code in fastvideo/attention/backends/bsa_attn.py
@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Map raster-order tokens to tile-contiguous order."""
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
    ls = []
    for t in range(math.ceil(T / ts)):
        for h in range(math.ceil(H / hs)):
            for w in range(math.ceil(W / ws)):
                ls.append(indices[
                    t * ts:min(t * ts + ts, T),
                    h * hs:min(h * hs + hs, H),
                    w * ws:min(w * ws + ws, W),
                ].flatten())
    return torch.cat(ls, dim=0)
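For intuition, take a single frame with (T, H, W) = (1, 4, 4) and tile size (1, 2, 2): each 2x2 spatial tile of the raster grid becomes a contiguous run in the output.

idx = get_tile_partition_indices((1, 4, 4), (1, 2, 2), torch.device("cpu"))
print(idx.tolist())
# [0, 1, 4, 5,  2, 3, 6, 7,  8, 9, 12, 13,  10, 11, 14, 15]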