Skip to content

test_matrixgame3_similarity

Classes

Functions

fastvideo.tests.ssim.test_matrixgame3_similarity.test_matrixgame3_similarity

test_matrixgame3_similarity(prompt, ATTENTION_BACKEND, model_id)

Test that runs MG3 inference, with action conditions auto-generated from the seed, and compares the output to reference videos using MS-SSIM. The test fails if the mean score is below 0.98.

Source code in fastvideo/tests/ssim/test_matrixgame3_similarity.py
@pytest.mark.parametrize("prompt", TEST_PROMPTS)
@pytest.mark.parametrize("ATTENTION_BACKEND", ["FLASH_ATTN"])
@pytest.mark.parametrize("model_id", list(MODEL_TO_PARAMS.keys()))
def test_matrixgame3_similarity(prompt, ATTENTION_BACKEND, model_id):
    """
    Test that runs MG3 inference (action conditions auto-generated from seed)
    and compares the output to reference videos using SSIM.

    The generated video is compared against the stored reference video with
    MS-SSIM; the test fails if the mean score is below 0.98.

    Args:
        prompt: Text prompt for generation. Its first 100 characters
            (whitespace-stripped, trailing periods removed) are used to locate
            the matching ``.mp4`` files in the reference and output folders.
        ATTENTION_BACKEND: Exported as ``FASTVIDEO_ATTENTION_BACKEND`` and used
            to build the output/reference folder paths.
        model_id: Key into the SSIM parameter tables.

    Raises:
        FileNotFoundError: If the reference folder, the reference video, or
            the generated video cannot be found.
        AssertionError: If the mean MS-SSIM is below the threshold.
    """
    # NOTE: this mutates the process-wide environment and is not restored
    # when the test finishes.
    os.environ["FASTVIDEO_ATTENTION_BACKEND"] = ATTENTION_BACKEND

    script_dir = os.path.dirname(os.path.abspath(__file__))

    output_dir = build_generated_output_dir(
        script_dir,
        device_reference_folder,
        model_id,
        ATTENTION_BACKEND,
    )

    os.makedirs(output_dir, exist_ok=True)

    params_map = select_ssim_params(
        MODEL_TO_PARAMS,
        FULL_QUALITY_MODEL_TO_PARAMS,
    )
    BASE_PARAMS = params_map[model_id]
    num_inference_steps = BASE_PARAMS["num_inference_steps"]

    init_kwargs = {
        "num_gpus": BASE_PARAMS["num_gpus"],
        "use_fsdp_inference": True,
        "dit_layerwise_offload": False,
        "dit_cpu_offload": False,
        "vae_cpu_offload": False,
        "text_encoder_cpu_offload": True,
        "pin_cpu_memory": True,
    }

    generation_kwargs = {
        "num_inference_steps": num_inference_steps,
        "output_path": output_dir,
        "image_path": TEST_IMAGE_PATHS[0],
        "height": BASE_PARAMS["height"],
        "width": BASE_PARAMS["width"],
        "num_frames": BASE_PARAMS["num_frames"],
        "guidance_scale": BASE_PARAMS["guidance_scale"],
        "seed": BASE_PARAMS["seed"],
        "save_video": True,
    }

    generator = VideoGenerator.from_pretrained(model_path=BASE_PARAMS["model_path"], **init_kwargs)
    try:
        generator.generate_video(prompt, **generation_kwargs)
    finally:
        # Shut the executor down even when generation raises. Otherwise a
        # failing parametrization leaves it running into the next test case.
        if isinstance(generator.executor, MultiprocExecutor):
            generator.executor.shutdown()

    # output_dir was created by os.makedirs above, so checking that it exists
    # proves nothing; the generated-video lookup below is the real check.

    reference_folder = build_reference_folder_path(
        script_dir,
        device_reference_folder,
        model_id,
        ATTENTION_BACKEND,
    )

    if not os.path.exists(reference_folder):
        logger.error("Reference folder missing")
        raise FileNotFoundError(f"Reference video folder does not exist: {reference_folder}")

    # Presumably the pipeline names output files after (a prefix of) the
    # prompt; TODO confirm against the filename sanitization in the generator.
    prompt_prefix = prompt[:100].strip().rstrip(".")

    def _find_mp4(folder: str) -> str | None:
        """Return the first (sorted) .mp4 in *folder* whose name contains the prompt prefix."""
        for filename in sorted(os.listdir(folder)):
            if filename.endswith(".mp4") and prompt_prefix in filename:
                return filename
        return None

    reference_video_name = _find_mp4(reference_folder)
    if not reference_video_name:
        logger.error(f"Reference video not found for model: {model_id} with backend: {ATTENTION_BACKEND}")
        raise FileNotFoundError("Reference video missing")

    generated_video_name = _find_mp4(output_dir)
    if not generated_video_name:
        logger.error(f"Generated video not found for model: {model_id} with backend: {ATTENTION_BACKEND}")
        raise FileNotFoundError(f"Generated video missing in {output_dir}")

    reference_video_path = os.path.join(reference_folder, reference_video_name)
    generated_video_path = os.path.join(output_dir, generated_video_name)

    logger.info(f"Computing SSIM between {reference_video_path} and {generated_video_path}")
    ssim_values = compute_video_ssim_torchvision(reference_video_path, generated_video_path, use_ms_ssim=True)

    # The first element is treated as the mean score (per the log line below).
    mean_ssim = ssim_values[0]
    logger.info(f"SSIM mean value: {mean_ssim}")
    logger.info(f"Writing SSIM results to directory: {output_dir}")

    success = write_ssim_results(
        output_dir,
        ssim_values,
        reference_video_path,
        generated_video_path,
        num_inference_steps,
        prompt,
    )

    # Writing the results file is best-effort: a failure is logged but does
    # not fail the test.
    if not success:
        logger.error("Failed to write SSIM results to file")

    min_acceptable_ssim = 0.98
    assert mean_ssim >= min_acceptable_ssim, (
        f"SSIM value {mean_ssim} is below threshold {min_acceptable_ssim} for {model_id} with backend {ATTENTION_BACKEND}"
    )