Skip to content

video_sparse_attn

Classes

fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl

VideoSparseAttentionImpl(num_heads: int, head_size: int, causal: bool, softmax_scale: float, num_kv_heads: int | None = None, prefix: str = '', **extra_impl_args)

Bases: AttentionImpl

Source code in fastvideo/attention/backends/video_sparse_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    causal: bool,
    softmax_scale: float,
    num_kv_heads: int | None = None,
    prefix: str = "",
    **extra_impl_args,
) -> None:
    """Initialise the VSA attention implementation.

    Only two pieces of state are kept:

    * ``self.prefix`` -- the layer name prefix passed by the caller.
    * ``self.sp_size`` -- the world size of the sequence-parallel group
      returned by ``get_sp_group()``.

    ``num_heads``, ``head_size``, ``causal``, ``softmax_scale``,
    ``num_kv_heads`` and ``extra_impl_args`` are accepted but not stored
    here (presumably to match the ``AttentionImpl`` constructor signature --
    confirm against the base class).
    """
    self.prefix = prefix
    sequence_parallel_group = get_sp_group()
    self.sp_size = sequence_parallel_group.world_size

Methods:

fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl.preprocess_qkv
preprocess_qkv(qkv: Tensor, attn_metadata: VideoSparseAttentionMetadata) -> Tensor

Tile QKV; aliasing contract: see tile().

Source code in fastvideo/attention/backends/video_sparse_attn.py
def preprocess_qkv(
    self,
    qkv: torch.Tensor,
    attn_metadata: VideoSparseAttentionMetadata,
) -> torch.Tensor:
    """Lay out a QKV tensor in VSA tile order.

    Delegates directly to ``self.tile``.  The result obeys the same
    buffer-aliasing contract as ``tile()``: it may share storage with
    ``attn_metadata.tile_buf`` and must be consumed or copied before the
    next tiling call on the same metadata.
    """
    tiled_qkv = self.tile(qkv, attn_metadata)
    return tiled_qkv
fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl.tile
tile(x: Tensor, attn_metadata: VideoSparseAttentionMetadata) -> Tensor

Tile x into attn_metadata.tile_buf and return it.

The returned tensor aliases the per-metadata buffer and is only valid until the next tile() / preprocess_qkv call on the same attn_metadata. Callers must consume (or copy) the result before invoking another VSA layer with the same metadata. Today both call sites materialize copies via .transpose(...).contiguous() inside forward(), so the contract holds; future callers must preserve it.

Source code in fastvideo/attention/backends/video_sparse_attn.py
def tile(self, x: torch.Tensor, attn_metadata: VideoSparseAttentionMetadata) -> torch.Tensor:
    """Scatter ``x`` into the zero-padded VSA tile layout and return it.

    Row ``attn_metadata.tile_partition_indices[i]`` of ``x`` (along dim 1)
    is written to slot ``attn_metadata.non_pad_index[i]`` of the output;
    every other slot is zero.

    Aliasing contract: when ``attn_metadata.cache_tile_buf`` is truthy, the
    returned tensor *is* ``attn_metadata.tile_buf`` and is only valid until
    the next ``tile()`` / ``preprocess_qkv`` call on the same
    ``attn_metadata``.  Callers must consume (or copy) the result before
    invoking another VSA layer with the same metadata.  Today both call
    sites materialize copies via ``.transpose(...).contiguous()`` inside
    ``forward()``, so the contract holds; future callers must preserve it.
    When caching is off, a freshly allocated tensor is returned instead.
    """
    tiles = attn_metadata.num_tiles
    # Token slots after padding each axis up to a whole number of tiles.
    padded_tokens = ((tiles[0] * VSA_TILE_SIZE[0]) * (tiles[1] * VSA_TILE_SIZE[1]) *
                     (tiles[2] * VSA_TILE_SIZE[2]))
    target_shape = (x.shape[0], padded_tokens, x.shape[-2], x.shape[-1])

    if attn_metadata.cache_tile_buf:
        # Reuse the buffer stashed on the metadata (allocated lazily by the
        # first VSA layer of a denoising step).  Pad slots stay zero: they
        # come from torch.zeros and, because non_pad_index is fixed for a
        # given metadata instance, the scatter below never writes them.
        # Keeping the buffer on the metadata rather than on the layer scopes
        # reuse to a single request.
        out = attn_metadata.tile_buf
        reusable = (out is not None and out.shape == target_shape and out.dtype == x.dtype
                    and out.device == x.device)
        if not reusable:
            out = torch.zeros(target_shape, device=x.device, dtype=x.dtype)
            attn_metadata.tile_buf = out
    else:
        out = torch.zeros(target_shape, device=x.device, dtype=x.dtype)

    out[:, attn_metadata.non_pad_index] = x[:, attn_metadata.tile_partition_indices]
    return out

Functions:

fastvideo.attention.backends.video_sparse_attn.construct_variable_block_sizes cached

construct_variable_block_sizes(dit_seq_shape: tuple[int, int, int], num_tiles: tuple[int, int, int], device: device) -> LongTensor

Compute the number of valid (non-padded) tokens inside every (ts_t × ts_h × ts_w) tile after padding — flattened in the order (t-tile, h-tile, w-tile) that rearrange uses.

Returns

torch.Tensor of dtype torch.int32 (torch.int) # shape: [n_t * n_h * n_w], one entry per tile

Source code in fastvideo/attention/backends/video_sparse_attn.py
@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.Tensor:
    """Count the valid (non-padded) tokens inside every VSA tile.

    The ``(t, h, w)`` token grid is padded up to ``num_tiles`` tiles of size
    ``VSA_TILE_SIZE = (ts_t, ts_h, ts_w)`` along each axis.  For each tile
    this returns how many of its slots hold real tokens.  Entries are
    flattened in (t-tile, h-tile, w-tile) order, the order ``rearrange``
    uses.

    Args:
        dit_seq_shape: unpadded ``(t, h, w)`` token grid.
        num_tiles: tile counts ``(n_t, n_h, n_w)`` along each axis.
        device: device on which the result is allocated.

    Returns:
        A ``torch.int32`` (``torch.int``) tensor of shape
        ``[n_t * n_h * n_w]``.  Results are memoized by
        ``functools.lru_cache``, so the same tensor object is returned for
        repeated arguments; callers must not modify it in place.
    """
    t, h, w = dit_seq_shape
    ts_t, ts_h, ts_w = VSA_TILE_SIZE
    n_t, n_h, n_w = num_tiles

    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.Tensor:
        """Valid-token count of each tile along one axis.

        Tile ``i`` covers positions ``[i * tile, (i + 1) * tile)``, so it
        holds ``clamp(dim_len - i * tile, 0, tile)`` real tokens.  The
        formula gives the usual "full tiles, then one partial tile" result
        when ``n_tiles == ceil(dim_len / tile)``.  It also gives 0 for
        pure-padding tiles when more tiles are requested, never exceeds the
        tile capacity, and yields an empty tensor when ``n_tiles == 0``.
        """
        starts = torch.arange(n_tiles, dtype=torch.int, device=device) * tile
        return (dim_len - starts).clamp(min=0, max=tile)

    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]

    # Broadcast-multiply to get the valid-token count per tile, then flatten
    # row-major so the w-tile index varies fastest.
    block_sizes = (
        t_sizes[:, None, None]  # [n_t, 1,   1]
        * h_sizes[None, :, None]  # [1,   n_h, 1]
        * w_sizes[None, None, :]  # [1,   1,   n_w]
    ).reshape(-1)  # [n_t * n_h * n_w]

    return block_sizes