Skip to content

nvfp4_qat_config

Classes

fastvideo.layers.quantization.nvfp4_qat_config.NVFP4QATQuantizeMethod

NVFP4QATQuantizeMethod()

Bases: QuantizeMethodBase

Source code in fastvideo/layers/quantization/nvfp4_qat_config.py
def __init__(self) -> None:
    """Initialize the quantize method with empty FP4 weight slots."""
    super().__init__()
    # Both attributes start as None; this constructor never populates them.
    self.weight_fp4 = self.weight_scale = None

Functions

fastvideo.layers.quantization.nvfp4_qat_config.NVFP4QATQuantizeMethod.apply
apply(layer: Module, x: Tensor, bias: Tensor | None = None) -> Tensor

Apply NVFP4 QAT quantized computation.

Source code in fastvideo/layers/quantization/nvfp4_qat_config.py
@torch.compile
def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
    """Apply NVFP4 QAT quantized computation.

    The activation is flattened to 2-D, quantized to FP4 with a per-call
    global scale derived from its absolute maximum, and multiplied against
    the FP4 weight tensors read from ``layer`` using flashinfer's ``mm_fp4``
    (cutlass backend). The optional bias is added afterwards and the result
    is reshaped so its leading dimensions match those of ``x``.

    Args:
        layer: Module providing ``weight`` (only its first dimension is
            read, as the output feature count) plus the ``_fp4_weight``,
            ``_fp4_weight_scale`` and ``_weight_global_sf`` attributes.
        x: bf16 or fp16 activation tensor whose last dimension is the
            input feature dimension.
        bias: Optional bias; moved to the output's device/dtype if needed.

    Returns:
        A bfloat16 tensor of shape ``(*x.shape[:-1], layer.weight.shape[0])``.

    Raises:
        AssertionError: If ``x`` is neither bfloat16 nor float16.
    """
    flashinfer_mod = _require_flashinfer()
    out_dim = layer.weight.shape[0]
    original_shape = x.shape
    assert x.dtype == torch.bfloat16 or x.dtype == torch.float16, f"only allow bf16/fp16 inputs to fp4 linear, got {x.dtype}"

    # reshape (not view) so non-contiguous activations are accepted; it is
    # still a zero-copy view whenever the original view() would have worked.
    x = x.reshape(-1, original_shape[-1])

    # Floor the absolute maximum so an all-zero (or vanishingly small)
    # activation cannot produce an infinite global scale, which would make
    # the quantized values and the dequantization factor below non-finite.
    x_amax = x.float().abs().nan_to_num().max().clamp_min(1e-12)
    # NOTE(review): 448 * 6 is presumably max(FP8 E4M3) * max(FP4 E2M1),
    # the usual NVFP4 global-scale numerator -- confirm against flashinfer.
    x_global_sf = (448 * 6) / x_amax
    x_fp4, x_scale = flashinfer_mod.nvfp4_quantize(
        x,
        x_global_sf,
        sfLayout=flashinfer_mod.SfLayout.layout_128x4,
        do_shuffle=False,
    )
    weight_fp4 = layer._fp4_weight
    weight_scale = layer._fp4_weight_scale
    weight_global_sf = layer._weight_global_sf

    # The last scalar undoes both global scales applied during quantization.
    # NOTE(review): the output dtype is fixed to bfloat16 even when x is
    # float16 -- confirm this is intended.
    out = flashinfer_mod.mm_fp4(
        x_fp4,
        weight_fp4.T,
        x_scale,
        weight_scale.T,
        1.0 / (x_global_sf * weight_global_sf),
        torch.bfloat16,
        None,
        backend="cutlass",
    )

    if bias is not None:
        if bias.device != out.device or bias.dtype != out.dtype:
            bias = bias.to(device=out.device, dtype=out.dtype)
        out = out + bias

    # Restore every leading dimension of the input, not just the 3-D case,
    # so 1-D and >=4-D activations do not come back flattened to 2-D.
    out = out.view(*original_shape[:-1], out_dim)

    return out
fastvideo.layers.quantization.nvfp4_qat_config.NVFP4QATQuantizeMethod.create_weights
create_weights(layer: Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: dtype, **extra_weight_attrs)

Create weights for a linear layer. The signature matches LinearMethodBase.

Source code in fastvideo/layers/quantization/nvfp4_qat_config.py
def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int],
                   input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate and register the ``weight`` parameter of a linear layer.

    The parameter wraps an uninitialized tensor of shape
    ``(sum(output_partition_sizes), input_size_per_partition)`` in
    ``params_dtype`` and is created with ``requires_grad=False``. It is
    registered on ``layer`` under the name ``"weight"``.

    ``input_size`` and ``output_size`` are accepted but not used here.
    Every entry of ``extra_weight_attrs`` is forwarded to
    ``set_weight_attrs`` after the parameter has been registered.
    """
    out_features = sum(output_partition_sizes)
    storage = torch.empty(out_features, input_size_per_partition, dtype=params_dtype)
    weight = Parameter(storage, requires_grad=False)

    # Layout attributes are set first, then the caller-supplied extras,
    # in two separate calls.
    set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, extra_weight_attrs)

Functions