Set a module's train/eval mode and the `requires_grad` flag of its parameters according to a role's `trainable` flag.
Source code in fastvideo/train/utils/module_state.py
def apply_trainable(module: torch.nn.Module, *, trainable: bool) -> torch.nn.Module:
    """Configure ``module`` for training or frozen use from a single flag.

    Gradient tracking is switched on or off for the module via
    ``requires_grad_``, and the module is then put in training mode when
    ``trainable`` is truthy or in evaluation mode otherwise.

    Args:
        module: The module to configure; it is modified in place.
        trainable: Keyword-only flag; coerced with ``bool`` before being
            passed to ``requires_grad_``.

    Returns:
        The same ``module`` object that was passed in.
    """
    enabled = bool(trainable)
    module.requires_grad_(enabled)

    # Select the bound method first, then invoke it with no arguments so the
    # module's own train()/eval() implementations are the ones that run.
    switch_mode = module.train if enabled else module.eval
    switch_mode()
    return module
|