Skip to content

matrixgame2_self_forcing_distillation_pipeline

Classes

fastvideo.training.matrixgame2_self_forcing_distillation_pipeline.MatrixGame2SelfForcingDistillationPipeline

MatrixGame2SelfForcingDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: SelfForcingDistillationPipeline

A distillation pipeline for Matrix-Game 2.0 that applies the self-forcing methodology together with DMD (Distribution Matching Distillation) to video generation.

Source code in fastvideo/training/training_pipeline.py
def __init__(
    self,
    model_path: str,
    fastvideo_args: TrainingArgs,
    required_config_modules: list[str] | None = None,
    loaded_modules: dict[str, torch.nn.Module] | None = None,
) -> None:
    """Construct the pipeline with inference mode switched off.

    Sets ``fastvideo_args.inference_mode`` to ``False``, records the LoRA
    flag on the instance, calls ``set_random_seed`` with the configured
    seed, delegates to the parent constructor, and finally installs a
    ``DummyTracker`` and clears the reference-video logging flag.

    Args:
        model_path: Forwarded unchanged to the parent constructor.
        fastvideo_args: Training arguments; mutated in place
            (``inference_mode`` is set to ``False``).
        required_config_modules: Forwarded unchanged to the parent constructor.
        loaded_modules: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If LoRA training is enabled but ``lora_rank`` is ``None``.
    """
    # Force inference mode off on the args object before anything else reads it.
    fastvideo_args.inference_mode = False

    self.lora_training = fastvideo_args.lora_training
    if self.lora_training:
        # The rank is only inspected when LoRA training is actually requested.
        if fastvideo_args.lora_rank is None:
            raise ValueError("lora rank must be set when using lora training")

    # Seed before the parent constructor runs (for lora param init).
    set_random_seed(fastvideo_args.seed)
    super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules)  # type: ignore

    self.validation_ref_videos_logged = False
    self.tracker = DummyTracker()

Functions

fastvideo.training.matrixgame2_self_forcing_distillation_pipeline.MatrixGame2SelfForcingDistillationPipeline.faker_score_forward
faker_score_forward(training_batch: TrainingBatch) -> tuple[TrainingBatch, Tensor]

Forward pass for critic training with Matrix-Game 2.0 action conditioning.

Source code in fastvideo/training/matrixgame2_self_forcing_distillation_pipeline.py
def faker_score_forward(self, training_batch: TrainingBatch) -> tuple[TrainingBatch, torch.Tensor]:
    """Run one critic (fake-score) forward pass and return its flow-matching loss.

    The generator output is produced under ``torch.no_grad()``, re-noised at a
    single randomly drawn timestep, and passed to the fake-score transformer.
    The loss is the mean squared difference between the transformer output
    and ``noise - generator_output``.

    Args:
        training_batch: Batch state. Its ``timesteps``, ``attn_metadata_vsa``,
            ``attn_metadata``, ``input_kwargs`` and ``latents`` attributes are
            read here.

    Returns:
        A tuple of the same ``training_batch`` object (with
        ``fake_score_latent_vis_dict`` set) and the loss tensor.
    """
    # Generator rollout: gradients are disabled, the critic is the model being trained here.
    with torch.no_grad(), set_forward_context(current_timestep=training_batch.timesteps,
                                              attn_metadata=training_batch.attn_metadata_vsa):
        if self.training_args.simulate_generator_forward:
            run_generator = self._generator_multi_step_simulation_forward
        else:
            run_generator = self._generator_forward
        clean_video = run_generator(training_batch)

    # Draw one timestep (shape [1]), apply the configured shift, then clamp to the allowed range.
    raw_timestep = torch.randint(0, self.num_train_timestep, [1], device=self.device, dtype=torch.long)
    shifted_timestep = shift_timestep(raw_timestep, self.timestep_shift, self.num_train_timestep)
    critic_timestep = shifted_timestep.clamp(self.min_timestep, self.max_timestep)

    noise = torch.randn(self.video_latent_shape, device=self.device, dtype=clean_video.dtype)

    # The scheduler is called with the first two dims merged; they are restored afterwards.
    add_noise = self.noise_scheduler.add_noise
    noised_flat = add_noise(clean_video.flatten(0, 1), noise.flatten(0, 1), critic_timestep)
    leading_dims = (clean_video.shape[0], clean_video.shape[1])
    noisy_video = noised_flat.unflatten(0, leading_dims)

    # Non-causal critic expects 1D timestep (batch_size,), not 2D (batch_size, num_frames).
    per_sample_timestep = critic_timestep.expand(noisy_video.shape[0])

    # Return value is intentionally unused here; presumably this call populates
    # ``training_batch.input_kwargs``, which is read below -- confirm against the helper.
    self._build_distill_input_kwargs(noisy_video, per_sample_timestep, None, training_batch)

    with set_forward_context(current_timestep=training_batch.timesteps, attn_metadata=training_batch.attn_metadata):
        critic = self._get_fake_score_transformer(critic_timestep)
        critic_output = critic(**training_batch.input_kwargs)
        # Swap dims 1 and 2 of the critic output before comparing with the target.
        predicted = critic_output.permute(0, 2, 1, 3, 4)

    regression_target = noise - clean_video
    residual = predicted - regression_target
    flow_matching_loss = torch.mean(residual**2)

    # Stash tensors for later visualisation/logging.
    training_batch.fake_score_latent_vis_dict = {
        "training_batch_fakerscore_fwd_clean_latent": training_batch.latents,
        "generator_pred_video": clean_video,
        "fake_score_timestep": critic_timestep,
    }

    return training_batch, flow_matching_loss

Functions