Per-rank negative-prompt encoding shared by training model plugins.
Encoding the negative prompt only on rank 0 and broadcasting (the
previous Wan path) ran Pipeline.from_pretrained asymmetrically across
ranks, which deadlocked on any collective fired during text-encoder load
(FSDP device-mesh init, weight broadcast, etc.). The text encoder is
small and only loaded once at startup, so loading it on every rank
sidesteps the deadlock entirely.
Classes
Functions
fastvideo.train.utils.negative_prompt.encode_negative_prompt
encode_negative_prompt(training_config: TrainingConfig, *, prompt: str, device: torch.device, dtype: torch.dtype, encoder_index: int = 0) -> tuple[torch.Tensor, torch.Tensor]
Per-rank encode of `prompt` using encoder `encoder_index`.
Reads `pipeline_config.text_encoder_configs[encoder_index]` so the
encoder class (e.g. UMT5 for Wan) and tokenizer kwargs match the
inference path, and applies the matching `postprocess_text_funcs`
entry. Returns `(embeds, mask)` on `device` cast to `dtype`.
Source code in fastvideo/train/utils/negative_prompt.py
def encode_negative_prompt(
    training_config: TrainingConfig,
    *,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    encoder_index: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode ``prompt`` locally with text encoder ``encoder_index``.

    The encoder config, tokenizer kwargs and post-processing function are
    all taken from ``training_config.pipeline_config`` at position
    ``encoder_index``; an optional ``preprocess_text_funcs`` entry is
    applied to the prompt first when the pipeline config defines one.
    The text encoder and tokenizer are loaded from the
    ``text_encoder*`` / ``tokenizer*`` sub-directories of the resolved
    model path.

    Args:
        training_config: Supplies ``pipeline_config`` and ``model_path``.
        prompt: Text to encode.
        device: Device the encoder runs on and the outputs are moved to.
        dtype: Dtype both returned tensors are cast to.
        encoder_index: Position in ``text_encoder_configs`` to use.

    Returns:
        ``(embeds, mask)``: the post-processed encoder output and the
        tokenizer attention mask, both on ``device`` and cast to
        ``dtype``.

    Raises:
        ValueError: If ``training_config.pipeline_config`` is ``None``.
        IndexError: If ``encoder_index`` is outside
            ``text_encoder_configs``.
    """
    pipeline_config = training_config.pipeline_config
    if pipeline_config is None:
        raise ValueError("training_config.pipeline_config is required for negative "
                         "prompt encoding")

    text_encoder_configs = pipeline_config.text_encoder_configs
    all_postprocess = pipeline_config.postprocess_text_funcs
    # Pre-processing hooks are optional on the pipeline config.
    all_preprocess = getattr(pipeline_config, "preprocess_text_funcs", None)

    if encoder_index < 0 or encoder_index >= len(text_encoder_configs):
        raise IndexError(f"encoder_index {encoder_index} out of range for "
                         f"text_encoder_configs (len={len(text_encoder_configs)})")

    selected_config = text_encoder_configs[encoder_index]
    postprocess = all_postprocess[encoder_index]
    if all_preprocess is None:
        preprocess = None
    else:
        preprocess = all_preprocess[encoder_index]

    # HF sub-directory naming: index 0 -> "text_encoder" / "tokenizer",
    # index 1 -> "text_encoder_2" / "tokenizer_2", and so on.
    if encoder_index == 0:
        subdir_suffix = ""
    else:
        subdir_suffix = f"_{encoder_index + 1}"

    model_root = maybe_download_model(training_config.model_path)
    inference_args = make_inference_args(training_config, model_path=model_root)
    # The encoder must stay on-device: CPU offload would init an FSDP
    # device mesh and bring back a collective at load time.
    inference_args.text_encoder_cpu_offload = False

    encoder_dir = os.path.join(model_root, f"text_encoder{subdir_suffix}")
    tokenizer_dir = os.path.join(model_root, f"tokenizer{subdir_suffix}")

    text_encoder = TextEncoderLoader().load(encoder_dir, inference_args)
    text_encoder = text_encoder.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    # Copy so the config's own kwargs mapping is not shared with the call.
    tokenizer_kwargs = dict(selected_config.tokenizer_kwargs)
    text = prompt if preprocess is None else preprocess(prompt)

    with torch.no_grad():
        with set_forward_context(current_timestep=0, attn_metadata=None):
            tokens = tokenizer(text, **tokenizer_kwargs).to(device)
            encoded = text_encoder(
                input_ids=tokens.input_ids,
                attention_mask=tokens.attention_mask,
            )
            attention_mask = tokens["attention_mask"]
            # Same as TextEncodingStage: the post-processing function
            # reads the mask off the encoder output object.
            encoded.attention_mask = attention_mask
            embeds = postprocess(encoded).to(device=device, dtype=dtype)
            mask = attention_mask.to(device=device, dtype=dtype)

    # Drop the local references to the encoder and tokenizer before returning.
    del text_encoder, tokenizer
    return embeds, mask