Skip to content

input_validation

Input validation stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.input_validation.InputValidationStage

Bases: PipelineStage

Stage for validating and preparing inputs for diffusion pipelines.

This stage validates that all required inputs are present and properly formatted before proceeding with the diffusion process.

Functions

fastvideo.pipelines.stages.input_validation.InputValidationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Validate and prepare inputs.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The validated batch information.

Source code in fastvideo/pipelines/stages/input_validation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Validate and prepare inputs.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The validated batch information.
    """

    self._generate_seeds(batch, fastvideo_args)

    # Ensure prompt is properly formatted
    if batch.prompt is None and batch.prompt_embeds is None:
        raise ValueError("Either `prompt` or `prompt_embeds` must be provided")

    # Ensure negative prompt is properly formatted if using classifier-free guidance
    if (batch.do_classifier_free_guidance and batch.negative_prompt is None
            and batch.negative_prompt_embeds is None):
        raise ValueError("For classifier-free guidance, either `negative_prompt` or "
                         "`negative_prompt_embeds` must be provided")

    # Validate height and width
    if batch.height is None or batch.width is None:
        raise ValueError("Height and width must be provided. Please set `height` and `width`.")
    if batch.height % 8 != 0 or batch.width % 8 != 0:
        raise ValueError(f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}.")

    # Validate number of inference steps
    if batch.num_inference_steps <= 0:
        raise ValueError(f"Number of inference steps must be positive, but got {batch.num_inference_steps}")

    # Validate guidance scale if using classifier-free guidance
    if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
        raise ValueError(f"Guidance scale must be positive, but got {batch.guidance_scale}")

    # for i2v, get image from image_path
    # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
    if batch.image_path is not None:
        if batch.image_path.endswith(".mp4"):
            image = load_video(batch.image_path)[0]
        else:
            image = load_image(batch.image_path)
        batch.pil_image = image

    # further processing for ti2v task
    if (fastvideo_args.pipeline_config.ti2v_task
            or fastvideo_args.pipeline_config.is_causal) and batch.pil_image is not None:
        img = batch.pil_image
        ih, iw = img.height, img.width

        pipeline_class_name = type(fastvideo_args.pipeline_config).__name__
        if 'MatrixGame' in pipeline_class_name or 'MatrixCausal' in pipeline_class_name:
            oh, ow = batch.height, batch.width
            img = img.resize((ow, oh), Image.LANCZOS)
        else:
            # Standard Wan logic
            patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
            vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
            dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
            max_area = 480 * 832
            ow, oh = best_output_size(iw, ih, dw, dh, max_area)

            scale = max(ow / iw, oh / ih)
            img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)

            # center-crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        assert img.width == ow and img.height == oh
        logger.info("final processed img height: %s, img width: %s", img.height, img.width)

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
        img = img.unsqueeze(0)
        batch.height = oh
        batch.width = ow
        batch.pil_image = img

    # for v2v, get control video from video path
    if batch.video_path is not None:
        pil_images, original_fps = load_video(batch.video_path, return_fps=True)
        logger.info("Loaded video with %s frames, original FPS: %s", len(pil_images), original_fps)

        # Get target parameters from batch
        target_fps = batch.fps
        target_num_frames = batch.num_frames
        target_height = batch.height
        target_width = batch.width

        if target_fps is not None and original_fps is not None:
            frame_skip = max(1, int(original_fps // target_fps))
            if frame_skip > 1:
                pil_images = pil_images[::frame_skip]
                effective_fps = original_fps / frame_skip
                logger.info("Resampled video from %.1f fps to %.1f fps (skip=%s)", original_fps, effective_fps,
                            frame_skip)

        # Limit to target number of frames
        if target_num_frames is not None and len(pil_images) > target_num_frames:
            pil_images = pil_images[:target_num_frames]
            logger.info("Limited video to %s frames (from %s total)", target_num_frames, len(pil_images))

        # Resize each PIL image to target dimensions
        resized_images = []
        for pil_img in pil_images:
            resized_img = resize(pil_img, target_height, target_width, resize_mode="default", resample="lanczos")
            resized_images.append(resized_img)

        # Convert PIL images to numpy array
        video_numpy = pil_to_numpy(resized_images)
        video_numpy = normalize(video_numpy)
        video_tensor = numpy_to_pt(video_numpy)

        # Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W]
        input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)

        batch.video_latent = input_video

    # Validate action control inputs (Matrix-Game)
    if batch.mouse_cond is not None:
        if batch.mouse_cond.dim() != 3 or batch.mouse_cond.shape[-1] != 2:
            raise ValueError(f"mouse_cond must have shape (B, T, 2), but got {batch.mouse_cond.shape}")
        logger.info("Action control: mouse_cond validated - shape %s", batch.mouse_cond.shape)

    if batch.keyboard_cond is not None:
        if batch.keyboard_cond.dim() != 3:
            raise ValueError(f"keyboard_cond must have 3 dimensions (B, T, K), but got {batch.keyboard_cond.dim()}")
        keyboard_dim = batch.keyboard_cond.shape[-1]
        if keyboard_dim not in {2, 3, 4, 6, 7}:
            raise ValueError(f"keyboard_cond last dimension must be 2, 3, 4, 6, or 7, but got {keyboard_dim}")
        logger.info("Action control: keyboard_cond validated - shape %s (dim=%d)", batch.keyboard_cond.shape,
                    keyboard_dim)

    if batch.grid_sizes is not None:
        if not isinstance(batch.grid_sizes, list | tuple | torch.Tensor):
            raise ValueError("grid_sizes must be a list, tuple, or tensor")
        if isinstance(batch.grid_sizes, torch.Tensor):
            if batch.grid_sizes.numel() != 3:
                raise ValueError("grid_sizes must have 3 elements [F, H, W]")
        else:
            if len(batch.grid_sizes) != 3:
                raise ValueError("grid_sizes must have 3 elements [F, H, W]")
        logger.info("Action control: grid_sizes validated - %s", batch.grid_sizes)

    return batch
fastvideo.pipelines.stages.input_validation.InputValidationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage inputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage inputs."""
    result = VerificationResult()
    # Cosmos-Predict2.5 default seed is 0; allow non-negative seeds here.
    result.add_check("seed", batch.seed, [V.not_none, V.non_negative_int])
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, V.positive_int)
    result.add_check("prompt_or_embeds", None,
                     lambda _: V.string_or_list_strings(batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     lambda x: not batch.do_classifier_free_guidance or V.positive_float(x))
    return result
fastvideo.pipelines.stages.input_validation.InputValidationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage outputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage outputs."""
    result = VerificationResult()
    result.add_check("seeds", batch.seeds, V.list_not_empty)
    result.add_check("generator", batch.generator, V.generator_or_list_generators)
    return result

Functions