@pytest.mark.parametrize("prompt", TEST_PROMPTS)
@pytest.mark.parametrize("ATTENTION_BACKEND", ["FLASH_ATTN"])
@pytest.mark.parametrize("model_id", list(MODEL_TO_PARAMS.keys()))
def test_matrixgame3_similarity(prompt, ATTENTION_BACKEND, model_id):
    """
    Test that runs MG3 inference (action conditions auto-generated from seed)
    and compares the output to reference videos using SSIM.

    Steps:
      1. Select the attention backend through the
         ``FASTVIDEO_ATTENTION_BACKEND`` environment variable.
      2. Generate a video for ``prompt`` into a per-model / per-backend
         output directory.
      3. Locate the reference and generated ``.mp4`` files whose names
         contain the (truncated) prompt.
      4. Compute SSIM between them, write the results next to the generated
         video, and assert the first returned value meets the threshold.

    Args:
        prompt: Text prompt passed to the generator; its first 100
            characters are also used to match video file names.
        ATTENTION_BACKEND: Value exported as ``FASTVIDEO_ATTENTION_BACKEND``.
        model_id: Key used to look up the parameters for this run.

    Raises:
        FileNotFoundError: If the reference folder, the reference video, or
            the generated video cannot be found.
        AssertionError: If the SSIM value is below the acceptance threshold.
    """
    os.environ["FASTVIDEO_ATTENTION_BACKEND"] = ATTENTION_BACKEND

    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = build_generated_output_dir(
        script_dir,
        device_reference_folder,
        model_id,
        ATTENTION_BACKEND,
    )
    os.makedirs(output_dir, exist_ok=True)

    params_map = select_ssim_params(
        MODEL_TO_PARAMS,
        FULL_QUALITY_MODEL_TO_PARAMS,
    )
    BASE_PARAMS = params_map[model_id]
    num_inference_steps = BASE_PARAMS["num_inference_steps"]

    init_kwargs = {
        "num_gpus": BASE_PARAMS["num_gpus"],
        "use_fsdp_inference": True,
        "dit_layerwise_offload": False,
        "dit_cpu_offload": False,
        "vae_cpu_offload": False,
        "text_encoder_cpu_offload": True,
        "pin_cpu_memory": True,
    }
    generation_kwargs = {
        "num_inference_steps": num_inference_steps,
        "output_path": output_dir,
        "image_path": TEST_IMAGE_PATHS[0],
        "height": BASE_PARAMS["height"],
        "width": BASE_PARAMS["width"],
        "num_frames": BASE_PARAMS["num_frames"],
        "guidance_scale": BASE_PARAMS["guidance_scale"],
        "seed": BASE_PARAMS["seed"],
        "save_video": True,
    }

    generator = VideoGenerator.from_pretrained(model_path=BASE_PARAMS["model_path"], **init_kwargs)
    try:
        generator.generate_video(prompt, **generation_kwargs)
    finally:
        # Always release the multi-process executor, even when generation
        # fails; otherwise it is never shut down for the failing test case.
        if isinstance(generator.executor, MultiprocExecutor):
            generator.executor.shutdown()

    # NOTE: output_dir was created above, so this only guards against the
    # directory disappearing; the real "was a video produced" check is the
    # .mp4 lookup below.
    assert os.path.exists(output_dir), f"Output video was not generated at {output_dir}"

    reference_folder = build_reference_folder_path(
        script_dir,
        device_reference_folder,
        model_id,
        ATTENTION_BACKEND,
    )
    if not os.path.exists(reference_folder):
        logger.error("Reference folder missing")
        raise FileNotFoundError(f"Reference video folder does not exist: {reference_folder}")

    # Video files are matched by a truncated prompt with surrounding
    # whitespace and trailing periods removed.
    prompt_prefix = prompt[:100].strip().rstrip(".")

    def _find_mp4(folder: str) -> str | None:
        """Return the first (sorted) .mp4 in *folder* containing the prompt prefix."""
        for filename in sorted(os.listdir(folder)):
            if filename.endswith(".mp4") and prompt_prefix in filename:
                return filename
        return None

    reference_video_name = _find_mp4(reference_folder)
    if not reference_video_name:
        logger.error(f"Reference video not found for model: {model_id} with backend: {ATTENTION_BACKEND}")
        raise FileNotFoundError("Reference video missing")

    generated_video_name = _find_mp4(output_dir)
    if not generated_video_name:
        logger.error(f"Generated video not found for model: {model_id} with backend: {ATTENTION_BACKEND}")
        raise FileNotFoundError(f"Generated video missing in {output_dir}")

    reference_video_path = os.path.join(reference_folder, reference_video_name)
    generated_video_path = os.path.join(output_dir, generated_video_name)

    logger.info(f"Computing SSIM between {reference_video_path} and {generated_video_path}")
    ssim_values = compute_video_ssim_torchvision(reference_video_path, generated_video_path, use_ms_ssim=True)
    # The first element of the result is treated as the mean SSIM.
    mean_ssim = ssim_values[0]
    logger.info(f"SSIM mean value: {mean_ssim}")

    logger.info(f"Writing SSIM results to directory: {output_dir}")
    success = write_ssim_results(
        output_dir,
        ssim_values,
        reference_video_path,
        generated_video_path,
        num_inference_steps,
        prompt,
    )
    if not success:
        # Failing to persist results is logged but does not fail the test.
        logger.error("Failed to write SSIM results to file")

    min_acceptable_ssim = 0.98
    assert mean_ssim >= min_acceptable_ssim, (
        f"SSIM value {mean_ssim} is below threshold {min_acceptable_ssim} for {model_id} with backend {ATTENTION_BACKEND}"
    )