Skip to content

flash_attn

Classes

fastvideo.attention.backends.flash_attn.FlashAttentionImpl

FlashAttentionImpl(num_heads: int, head_size: int, causal: bool, softmax_scale: float, num_kv_heads: int | None = None, prefix: str = '', **extra_impl_args)

Bases: AttentionImpl

Source code in fastvideo/attention/backends/flash_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    causal: bool,
    softmax_scale: float,
    num_kv_heads: int | None = None,
    prefix: str = "",
    **extra_impl_args,
) -> None:
    self.causal = causal
    self.softmax_scale = softmax_scale
    self.nvfp4_fa4 = extra_impl_args.get("nvfp4_fa4", False) or os.environ.get("FASTVIDEO_NVFP4_FA4", "0") == "1"
    if self.nvfp4_fa4:
        cap = torch.cuda.get_device_capability()
        assert cap in [(10, 0), (10, 3)], (f"NVFP4 FA4 requires Blackwell (sm100a/sm103a), got sm{cap[0]}{cap[1]}")
        assert _FA4_FP4_AVAILABLE, ("NVFP4 FA4 requires flash-attention-fp4 (flash_attn.cute). "
                                    "Install via instructions in docs/inference/optimizations.md")
        logger.info("NVFP4 FA4 enabled for FlashAttentionImpl (quant_qk only)")

Functions: