
training_pipeline

Classes

fastvideo.training.training_pipeline.TrainingPipeline

TrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: LoRAPipeline, ABC

A pipeline for training a model. All training pipelines should inherit from this class, and shared, reusable components and code should be implemented here.

Source code in fastvideo/training/training_pipeline.py
def __init__(self,
             model_path: str,
             fastvideo_args: TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
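    # Training pipelines always run with inference mode disabled.
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    # LoRA training requires an explicit rank; fail before any modules are loaded.
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")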

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
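The constructor does its argument handling before any model loading: it forces inference mode off, rejects LoRA training without a rank, and seeds the RNG so LoRA parameter initialization is reproducible. The sketch below mirrors only those preconditions in isolation. TrainingArgsStub and prepare_training_args are hypothetical names, not part of fastvideo, and stdlib random.seed stands in for fastvideo's set_random_seed (which presumably also seeds other RNGs).

```python
import random
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingArgsStub:
    """Hypothetical stand-in for the TrainingArgs fields that __init__ reads."""
    lora_training: bool = False
    lora_rank: Optional[int] = None
    seed: int = 42
    inference_mode: bool = True


def prepare_training_args(args: TrainingArgsStub) -> TrainingArgsStub:
    """Sketch of the precondition handling in TrainingPipeline.__init__."""
    # Training always runs with inference mode switched off.
    args.inference_mode = False
    # LoRA training needs an explicit rank; fail fast before loading modules.
    if args.lora_training and args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")
    # Seed before module construction so LoRA init is reproducible.
    # Assumption: stdlib random stands in for fastvideo's set_random_seed.
    random.seed(args.seed)
    return args


if __name__ == "__main__":
    ok = prepare_training_args(TrainingArgsStub(lora_training=True, lora_rank=16))
    print(ok.inference_mode)  # False
    try:
        prepare_training_args(TrainingArgsStub(lora_training=True))
    except ValueError as err:
        print(err)
```

Validating the rank up front means a misconfigured LoRA run fails immediately instead of after the (potentially expensive) super().__init__ call loads model modules.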

Functions

fastvideo.training.training_pipeline.TrainingPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

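Add visualization data to tracker logging and save frames to disk. The base TrainingPipeline implementation raises NotImplementedError, so pipelines that support visualization must override this method.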

Source code in fastvideo/training/training_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch, training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    raise NotImplementedError("Visualize intermediate latents is not implemented for training pipeline")
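Because the base method only raises NotImplementedError, visualization is an opt-in hook: a subclass overrides it, and calling code may want to tolerate pipelines that do not. The sketch below shows that pattern with hypothetical classes (BasePipelineSketch, MyPipelineSketch) and a hypothetical helper maybe_visualize. It does not claim that fastvideo's trainer calls the hook this way, and the override only records what would be logged instead of decoding latents or writing frames.

```python
from abc import ABC
from typing import Any, Dict, List


class BasePipelineSketch(ABC):
    """Hypothetical stand-in for TrainingPipeline's default hook."""

    def __init__(self) -> None:
        # Stands in for tracker logging in this sketch.
        self.logged: List[Dict[str, Any]] = []

    def visualize_intermediate_latents(self, training_batch: Any, training_args: Any, step: int) -> None:
        # Same default as the base class: visualization is not implemented.
        raise NotImplementedError("Visualize intermediate latents is not implemented for training pipeline")


class MyPipelineSketch(BasePipelineSketch):
    """Hypothetical subclass that opts in to visualization."""

    def visualize_intermediate_latents(self, training_batch: Any, training_args: Any, step: int) -> None:
        # A real override would decode latents, log them to the tracker, and
        # save frames to disk; here we only record what would be logged.
        self.logged.append({"step": step, "num_latents": len(training_batch)})


def maybe_visualize(pipeline: BasePipelineSketch, batch: Any, args: Any, step: int) -> bool:
    """Hypothetical helper: returns False if the pipeline has no visualization."""
    try:
        pipeline.visualize_intermediate_latents(batch, args, step)
        return True
    except NotImplementedError:
        return False


if __name__ == "__main__":
    print(maybe_visualize(BasePipelineSketch(), [0, 1], None, step=0))  # False
    p = MyPipelineSketch()
    print(maybe_visualize(p, [0, 1, 2], None, step=5))  # True
    print(p.logged)  # [{'step': 5, 'num_latents': 3}]
```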

Functions