Skip to content

matrixgame2

Modules

fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline

Classes

fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline.PreprocessPipeline_MatrixGame2
PreprocessPipeline_MatrixGame2(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

I2V preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Construct the pipeline and load its modules.

    Once construction finishes the pipeline is meant to be usable as-is; it
    is intended to be stateless and to keep no per-batch state.

    Args:
        model_path: Stored on the instance as ``model_path``.
        fastvideo_args: Stored on the instance; its ``tp_size`` and
            ``sp_size`` are handed to the distributed/model-parallel setup
            and the object itself is forwarded to ``load_modules``.
        required_config_modules: When not None, replaces
            ``_required_config_modules`` on this instance.
        loaded_modules: Forwarded unchanged to ``load_modules``.

    Raises:
        NotImplementedError: If ``_required_config_modules`` is still None
            after the optional override.
    """
    self.fastvideo_args = fastvideo_args
    self.model_path: str = model_path

    # Stage registry starts out empty.
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}
    self._trace_mgr = None

    # An explicit argument takes precedence over whatever value is already set.
    if required_config_modules is not None:
        self._required_config_modules = required_config_modules
    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    tp_size, sp_size = fastvideo_args.tp_size, fastvideo_args.sp_size
    maybe_init_distributed_environment_and_model_parallel(tp_size, sp_size)

    # Torch profiler, configured through the environment:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    controller = get_or_create_profiler(envs.FASTVIDEO_TORCH_PROFILER_DIR)
    self.profiler_controller = controller
    self.profiler = controller.profiler

    self.local_rank = get_world_group().local_rank

    # Modules are loaded eagerly, inside a named profiler region.
    logger.info("Loading pipeline modules...")
    with controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline.PreprocessPipeline_MatrixGame2.create_record
create_record(video_name: str, vae_latent: ndarray, text_embedding: ndarray, valid_data: dict[str, Any], idx: int, extra_features: dict[str, Any] | None = None) -> dict[str, Any]

Create a record for the Parquet dataset, extending the base record with the CLIP feature, first-frame latent, image, and keyboard/mouse condition features.

Source code in fastvideo/pipelines/preprocess/matrixgame2/matrixgame2_preprocess_pipeline.py
def create_record(self,
                  video_name: str,
                  vae_latent: np.ndarray,
                  text_embedding: np.ndarray,
                  valid_data: dict[str, Any],
                  idx: int,
                  extra_features: dict[str, Any] | None = None) -> dict[str, Any]:
    """Build one Parquet record: the parent's record plus optional array features.

    The record returned by the parent ``create_record`` is extended, in this
    order, with the ``clip_feature``, ``first_frame_latent``, ``pil_image``,
    ``keyboard_cond`` and ``mouse_cond`` entries of ``extra_features``. Each
    feature contributes three columns: ``<name>_bytes``, ``<name>_shape`` and
    ``<name>_dtype``. A feature that is missing (or an empty/None
    ``extra_features``) is written as ``b""``, ``[]`` and ``""``.

    Feature values must provide ``tobytes()``, ``shape`` and ``dtype``
    (presumably NumPy arrays; confirm against the caller).
    """
    record = super().create_record(video_name=video_name,
                                   vae_latent=vae_latent,
                                   text_embedding=text_embedding,
                                   valid_data=valid_data,
                                   idx=idx,
                                   extra_features=extra_features)

    # (feature key, bytes column, shape column, dtype column); the column
    # names are spelled out in full so they stay searchable.
    feature_columns = (
        ("clip_feature", "clip_feature_bytes", "clip_feature_shape", "clip_feature_dtype"),
        ("first_frame_latent", "first_frame_latent_bytes", "first_frame_latent_shape",
         "first_frame_latent_dtype"),
        ("pil_image", "pil_image_bytes", "pil_image_shape", "pil_image_dtype"),
        ("keyboard_cond", "keyboard_cond_bytes", "keyboard_cond_shape", "keyboard_cond_dtype"),
        ("mouse_cond", "mouse_cond_bytes", "mouse_cond_shape", "mouse_cond_dtype"),
    )

    for feature_key, bytes_col, shape_col, dtype_col in feature_columns:
        if extra_features and feature_key in extra_features:
            feature = extra_features[feature_key]
            columns = {
                bytes_col: feature.tobytes(),
                shape_col: list(feature.shape),
                dtype_col: str(feature.dtype),
            }
        else:
            # Placeholders for an absent feature.
            columns = {bytes_col: b"", shape_col: [], dtype_col: ""}
        record.update(columns)

    return record
fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline.PreprocessPipeline_MatrixGame2.get_pyarrow_schema
get_pyarrow_schema()

Return the PyArrow schema for the MatrixGame2 I2V pipeline.

Source code in fastvideo/pipelines/preprocess/matrixgame2/matrixgame2_preprocess_pipeline.py
def get_pyarrow_schema(self):
    """Return the module-level ``pyarrow_schema_matrixgame2`` schema for this I2V pipeline."""
    schema = pyarrow_schema_matrixgame2
    return schema

Functions

fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline_ode_trajectory

ODE Trajectory Data Preprocessing pipeline implementation.

This module contains an implementation of the ODE Trajectory Data Preprocessing pipeline using the modular pipeline architecture.

Sec 4.3 of CausVid paper: https://arxiv.org/pdf/2412.07772

Classes

fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline_ode_trajectory.PreprocessPipeline_MatrixGame2_ODE_Trajectory
PreprocessPipeline_MatrixGame2_ODE_Trajectory(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

ODE Trajectory preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Construct the pipeline and load its modules.

    Once construction finishes the pipeline is meant to be usable as-is; it
    is intended to be stateless and to keep no per-batch state.

    Args:
        model_path: Stored on the instance as ``model_path``.
        fastvideo_args: Stored on the instance; its ``tp_size`` and
            ``sp_size`` are handed to the distributed/model-parallel setup
            and the object itself is forwarded to ``load_modules``.
        required_config_modules: When not None, replaces
            ``_required_config_modules`` on this instance.
        loaded_modules: Forwarded unchanged to ``load_modules``.

    Raises:
        NotImplementedError: If ``_required_config_modules`` is still None
            after the optional override.
    """
    self.fastvideo_args = fastvideo_args
    self.model_path: str = model_path

    # Stage registry starts out empty.
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}
    self._trace_mgr = None

    # An explicit argument takes precedence over whatever value is already set.
    if required_config_modules is not None:
        self._required_config_modules = required_config_modules
    if self._required_config_modules is None:
        raise NotImplementedError("Subclass must set _required_config_modules")

    tp_size, sp_size = fastvideo_args.tp_size, fastvideo_args.sp_size
    maybe_init_distributed_environment_and_model_parallel(tp_size, sp_size)

    # Torch profiler, configured through the environment:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    controller = get_or_create_profiler(envs.FASTVIDEO_TORCH_PROFILER_DIR)
    self.profiler_controller = controller
    self.profiler = controller.profiler

    self.local_rank = get_world_group().local_rank

    # Modules are loaded eagerly, inside a named profiler region.
    logger.info("Loading pipeline modules...")
    with controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline_ode_trajectory.PreprocessPipeline_MatrixGame2_ODE_Trajectory.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/matrixgame2/matrixgame2_preprocess_pipeline_ode_trajectory.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Install the scheduler module and register the pipeline stages.

    A ``SelfForcingFlowMatchScheduler`` built from the configured flow shift
    (asserted to equal 5) is stored under ``self.modules["scheduler"]`` and
    set to 48 inference steps. Stages are then registered in this order:
    input validation, image encoding, timestep preparation, latent
    preparation, causal denoising and decoding.

    Args:
        fastvideo_args: Supplies ``pipeline_config.flow_shift``.
    """
    flow_shift = fastvideo_args.pipeline_config.flow_shift
    assert flow_shift == 5

    scheduler = SelfForcingFlowMatchScheduler(shift=flow_shift, sigma_min=0.0, extra_one_step=True)
    self.modules["scheduler"] = scheduler
    scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)

    self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

    image_encoding = MatrixGame2ImageEncodingStage(
        image_encoder=self.get_module("image_encoder"),
        image_processor=self.get_module("image_processor"),
    )
    self.add_stage(stage_name="image_encoding_stage", stage=image_encoding)

    timestep_preparation = TimestepPreparationStage(scheduler=self.get_module("scheduler"))
    self.add_stage(stage_name="timestep_preparation_stage", stage=timestep_preparation)

    # The transformer is looked up with a None default for this stage only.
    latent_preparation = LatentPreparationStage(
        scheduler=self.get_module("scheduler"),
        transformer=self.get_module("transformer", None),
    )
    self.add_stage(stage_name="latent_preparation_stage", stage=latent_preparation)

    denoising = MatrixGame2CausalDenoisingStage(
        transformer=self.get_module("transformer"),
        scheduler=self.get_module("scheduler"),
        pipeline=self,
        vae=self.get_module("vae"),
    )
    self.add_stage(stage_name="denoising_stage", stage=denoising)

    decoding = DecodingStage(vae=self.get_module("vae"))
    self.add_stage(stage_name="decoding_stage", stage=decoding)
fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline_ode_trajectory.PreprocessPipeline_MatrixGame2_ODE_Trajectory.get_pyarrow_schema
get_pyarrow_schema() -> Schema

Return the PyArrow schema for ODE Trajectory pipeline.

Source code in fastvideo/pipelines/preprocess/matrixgame2/matrixgame2_preprocess_pipeline_ode_trajectory.py
def get_pyarrow_schema(self) -> pa.Schema:
    """Return the module-level ``pyarrow_schema_matrixgame2_ode_trajectory`` schema."""
    schema = pyarrow_schema_matrixgame2_ode_trajectory
    return schema
fastvideo.pipelines.preprocess.matrixgame2.matrixgame2_preprocess_pipeline_ode_trajectory.PreprocessPipeline_MatrixGame2_ODE_Trajectory.preprocess_action_and_trajectory
preprocess_action_and_trajectory(fastvideo_args: FastVideoArgs, args)

Preprocess data and generate trajectory information.

Source code in fastvideo/pipelines/preprocess/matrixgame2/matrixgame2_preprocess_pipeline_ode_trajectory.py
def preprocess_action_and_trajectory(self, fastvideo_args: FastVideoArgs, args):
    """Generate ODE trajectories for the valid samples of each batch and store them as Parquet records.

    For every batch yielded by ``self.pbar`` this method:

    1. drops samples whose ``pixel_values`` are entirely zero;
    2. obtains the CLIP feature, first-frame latent, image and optional
       keyboard/mouse conditions from ``self.get_extra_features``;
    3. runs, one sample at a time, the input-validation, timestep-preparation,
       latent-preparation, denoising and decoding stages, collecting the
       trajectory latents and timesteps each run returns;
    4. converts every sample into a record via
       ``matrixgame2_ode_record_creator`` and appends the resulting table to
       ``self.dataset_writer`` (created on first use);
    5. flushes the writer once ``self.num_processed_samples`` reaches
       ``args.flush_frequency``, then resets that counter.

    After the last batch any remaining samples are flushed.

    Args:
        fastvideo_args: Forwarded to ``get_extra_features`` and to every
            stage call.
        args: Preprocessing options. The attributes read here are
            ``num_frames``, ``num_latent_t``, ``model_path``, ``max_height``,
            ``max_width``, ``train_fps``, ``samples_per_file`` and
            ``flush_frequency``.
    """

    for data in self.pbar:
        if data is None:
            continue

        with torch.inference_mode():
            # Filter out invalid samples (those with all zeros)
            valid_indices = [
                i for i, pixel_values in enumerate(data["pixel_values"]) if not torch.all(pixel_values == 0)
            ]
            self.num_processed_samples += len(valid_indices)

            if not valid_indices:
                continue

            # Create new batch with only valid samples
            valid_data = {
                "pixel_values": torch.stack([data["pixel_values"][i] for i in valid_indices]),
                "path": [data["path"][i] for i in valid_indices],
            }

            # Optional per-sample metadata is carried over only when present.
            if "fps" in data:
                valid_data["fps"] = [data["fps"][i] for i in valid_indices]
            if "duration" in data:
                valid_data["duration"] = [data["duration"][i] for i in valid_indices]
            if "action_path" in data:
                valid_data["action_path"] = [data["action_path"][i] for i in valid_indices]

            # When dimension 2 has size 1, repeat it args.num_frames times.
            pixel_values = valid_data["pixel_values"]
            if pixel_values.shape[2] == 1 and args.num_frames is not None:
                pixel_values = pixel_values.repeat(1, 1, args.num_frames, 1, 1)
                valid_data["pixel_values"] = pixel_values

            # Get extra features if needed
            extra_features = self.get_extra_features(valid_data, fastvideo_args)

            clip_features = extra_features['clip_feature']
            image_latents = extra_features['first_frame_latent']
            # Keep at most args.num_latent_t entries along dimension 2.
            image_latents = image_latents[:, :, :args.num_latent_t]
            pil_image = extra_features['pil_image']
            # Keyboard/mouse conditions are optional and may be None.
            keyboard_cond = extra_features.get('keyboard_cond')
            mouse_cond = extra_features.get('mouse_cond')

            sampling_params = SamplingParam.from_pretrained(args.model_path)

            trajectory_latents = []
            trajectory_timesteps = []
            trajectory_decoded = []

            device = get_local_torch_device()
            for i in range(len(valid_indices)):
                # Collect the trajectory data
                batch = ForwardBatch(**shallow_asdict(sampling_params), )
                batch.image_embeds = [clip_features[i].unsqueeze(0)]
                batch.image_latent = image_latents[i].unsqueeze(0)
                batch.keyboard_cond = (torch.from_numpy(keyboard_cond[i]).unsqueeze(0).to(device)
                                       if keyboard_cond is not None else None)
                batch.mouse_cond = (torch.from_numpy(mouse_cond[i]).unsqueeze(0).to(device)
                                    if mouse_cond is not None else None)
                batch.num_inference_steps = 48
                batch.return_trajectory_latents = True
                # Enabling this will save the decoded trajectory videos.
                # Used for debugging.
                batch.return_trajectory_decoded = False
                batch.height = args.max_height
                batch.width = args.max_width
                batch.fps = args.train_fps
                batch.num_frames = valid_data["pixel_values"].shape[2]
                batch.guidance_scale = 6.0
                batch.do_classifier_free_guidance = False
                batch.prompt = ""

                # Feed each stage the batch returned by the previous stage, so
                # that whatever input validation returns is not dropped.
                result_batch = self.input_validation_stage(batch, fastvideo_args)
                result_batch = self.timestep_preparation_stage(result_batch, fastvideo_args)
                result_batch.timesteps = result_batch.timesteps.to(device)
                result_batch = self.latent_preparation_stage(result_batch, fastvideo_args)
                result_batch = self.denoising_stage(result_batch, fastvideo_args)
                result_batch = self.decoding_stage(result_batch, fastvideo_args)

                trajectory_latents.append(result_batch.trajectory_latents.cpu())
                trajectory_timesteps.append(result_batch.trajectory_timesteps.cpu())
                trajectory_decoded.append(result_batch.trajectory_decoded)

            # Prepare extra features
            extra_features = {
                "trajectory_latents": trajectory_latents,
                "trajectory_timesteps": trajectory_timesteps
            }

            # `batch` is the last per-sample batch built above; the flag is set
            # to the same value for every sample.
            if batch.return_trajectory_decoded:
                for i, decoded_frames in enumerate(trajectory_decoded):
                    for j, decoded_frame in enumerate(decoded_frames):
                        save_decoded_latents_as_video(decoded_frame,
                                                      f"decoded_videos/trajectory_decoded_{i}_{j}.mp4",
                                                      args.train_fps)

            # Prepare batch data for Parquet dataset
            batch_data: list[dict[str, Any]] = []

            # Add progress bar for saving outputs
            save_pbar = tqdm(enumerate(valid_data["path"]), desc="Saving outputs", unit="item", leave=False)

            for idx, video_path in save_pbar:
                # The name is the basename up to its first "."
                video_name = os.path.basename(video_path).split(".")[0]

                clip_feature_np = clip_features[idx].cpu().numpy()
                first_frame_latent_np = image_latents[idx].cpu().numpy()
                pil_image_np = pil_image[idx].cpu().numpy()
                keyboard_cond_np = keyboard_cond[idx] if keyboard_cond is not None else None
                mouse_cond_np = mouse_cond[idx] if mouse_cond is not None else None

                # Get trajectory features for this sample
                traj_latents = extra_features["trajectory_latents"][idx]
                traj_timesteps = extra_features["trajectory_timesteps"][idx]
                if isinstance(traj_latents, torch.Tensor):
                    traj_latents = traj_latents.cpu().float().numpy()
                if isinstance(traj_timesteps, torch.Tensor):
                    traj_timesteps = traj_timesteps.cpu().float().numpy()

                # Create record for Parquet dataset
                record: dict[str, Any] = matrixgame2_ode_record_creator(video_name=video_name,
                                                                        clip_feature=clip_feature_np,
                                                                        first_frame_latent=first_frame_latent_np,
                                                                        trajectory_latents=traj_latents,
                                                                        trajectory_timesteps=traj_timesteps,
                                                                        pil_image=pil_image_np,
                                                                        keyboard_cond=keyboard_cond_np,
                                                                        mouse_cond=mouse_cond_np,
                                                                        caption="")
                batch_data.append(record)

            if batch_data:
                write_pbar = tqdm(total=1, desc="Writing to Parquet dataset", unit="batch")
                table = records_to_table(batch_data, self.get_pyarrow_schema())
                write_pbar.update(1)
                write_pbar.close()

                # The writer is created on first use and reused afterwards.
                if not hasattr(self, 'dataset_writer'):
                    self.dataset_writer = ParquetDatasetWriter(
                        out_dir=self.combined_parquet_dir,
                        samples_per_file=args.samples_per_file,
                    )
                self.dataset_writer.append_table(table)

                logger.info("Collected batch with %s samples", len(table))

            # Reaching this point means the batch had at least one valid sample,
            # so the writer has been created above.
            if self.num_processed_samples >= args.flush_frequency:
                written = self.dataset_writer.flush()
                logger.info("Flushed %s samples to parquet", written)
                self.num_processed_samples = 0

    # Final flush for any remaining samples
    if hasattr(self, 'dataset_writer'):
        written = self.dataset_writer.flush(write_remainder=True)
        if written:
            logger.info("Final flush wrote %s samples", written)

Functions