module_state

Functions

fastvideo.train.utils.module_state.apply_trainable

apply_trainable(module: Module, *, trainable: bool) -> Module

Put the module in train or eval mode and set requires_grad on all of its parameters according to a role's trainable flag. Returns the same module.

Source code in fastvideo/train/utils/module_state.py
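import torch


def apply_trainable(module: torch.nn.Module, *, trainable: bool) -> torch.nn.Module:
    """Apply train/eval mode + requires_grad based on a role's trainable flag."""

    # Enable or freeze gradients for every parameter of the module.
    module.requires_grad_(bool(trainable))
    # Match layer behavior (e.g. dropout, batch-norm statistics) to the role.
    if trainable:
        module.train()
    else:
        module.eval()
    # Return the same module so the call can be chained.
    return module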
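A minimal usage sketch follows, assuming PyTorch is installed. It re-declares apply_trainable locally with the same body shown above so it runs without fastvideo. The role names "student" and "teacher" and the toy models are hypothetical and illustrate a frozen-teacher setup; they are not taken from fastvideo's own role definitions.

```python
import torch
from torch import nn


def apply_trainable(module: nn.Module, *, trainable: bool) -> nn.Module:
    # Local copy of the function above so this sketch runs standalone.
    module.requires_grad_(bool(trainable))
    if trainable:
        module.train()
    else:
        module.eval()
    return module


# Hypothetical roles: the student is trained, the teacher stays frozen.
roles = {
    "student": (nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1)), True),
    "teacher": (nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1)), False),
}

for model, trainable in roles.values():
    apply_trainable(model, trainable=trainable)

student = roles["student"][0]
teacher = roles["teacher"][0]

# Student: train mode, every parameter requires gradients.
print(student.training, all(p.requires_grad for p in student.parameters()))  # True True
# Teacher: eval mode, every parameter frozen.
print(teacher.training, any(p.requires_grad for p in teacher.parameters()))  # False False

# Hand the optimizer only the parameters that still require gradients.
optimizer = torch.optim.SGD(
    [p for m, _ in roles.values() for p in m.parameters() if p.requires_grad],
    lr=0.1,
)
```

The two calls cover different things: train()/eval() change layer behavior such as dropout, while requires_grad_ controls whether parameters receive gradients. Setting both from one flag keeps a role's mode and gradient state consistent.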