multimodalart HF Staff commited on
Commit
81d4f64
·
verified ·
1 Parent(s): d2d267c

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +1261 -229
pipeline.py CHANGED
@@ -1,240 +1,1272 @@
1
- import spaces
2
- import gradio as gr
3
- import torch
4
- from diffusers import DiffusionPipeline
5
- from diffusers.utils import load_image, export_to_video
6
- import os
7
- import random
 
 
 
 
 
 
 
 
 
 
 
8
  import numpy as np
9
- from moviepy import ImageSequenceClip, AudioFileClip, VideoFileClip
10
- from PIL import Image, ImageOps
11
-
12
- # --- 1. Model Setup & Configuration ---
13
-
14
- # Define the specific distilled sigmas (from LTX-2 documentation)
15
- DISTILLED_SIGMA_VALUES = [
16
- 1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875
17
- ]
18
-
19
- print("Loading LTX-2 Distilled Pipeline...")
20
- pipe = DiffusionPipeline.from_pretrained(
21
- "rootonchair/LTX-2-19b-distilled",
22
- custom_pipeline="multimodalart/ltx2-audio-to-video",
23
- torch_dtype=torch.bfloat16
24
- )
25
- pipe.to("cuda")
26
-
27
- # --- Restored LoRA Loading ---
28
- # We load the Camera Control LoRA and fuse it.
29
- # This stabilizes the video (static camera), which is usually preferred for talking heads.
30
- print("Loading and Fusing Camera Control LoRA...")
31
- pipe.load_lora_weights("Lightricks/LTX-2-19b-LoRA-Camera-Control-Static", adapter_name="camera_control")
32
- # We fuse with a scale of 1.0. This merges the weights into the base model permanently for this session.
33
- pipe.fuse_lora(lora_scale=1.0)
34
-
35
- # --- 2. Helper Functions ---
36
-
37
- def save_video_with_audio(video_frames, audio_path, fps=24):
38
- """
39
- Combines the generated video frames with the original input audio.
40
- """
41
- output_filename = f"output_{random.randint(0, 100000)}.mp4"
42
-
43
- # 1. Handle Diffusers Output Formats
44
- if isinstance(video_frames, list):
45
- if video_frames and isinstance(video_frames[0], list):
46
- frames_to_process = video_frames[0]
47
- else:
48
- frames_to_process = video_frames
49
- np_frames = [np.array(img) for img in frames_to_process]
50
- clip = ImageSequenceClip(np_frames, fps=fps)
51
-
52
- elif isinstance(video_frames, str):
53
- clip = VideoFileClip(video_frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  else:
55
- temp_path = "temp_video_no_audio.mp4"
56
- export_to_video(video_frames, temp_path, fps=fps)
57
- clip = VideoFileClip(temp_path)
58
-
59
- # 2. Load and Process Audio
60
- audio_clip = AudioFileClip(audio_path)
61
-
62
- if audio_clip.duration > clip.duration:
63
- audio_clip = audio_clip.subclipped(0, clip.duration)
64
-
65
- # 3. Combine and Save
66
- final_clip = clip.with_audio(audio_clip)
67
-
68
- final_clip.write_videofile(
69
- output_filename,
70
- fps=fps,
71
- codec="libx264",
72
- audio_codec="aac",
73
- logger="bar"
74
- )
75
-
76
- final_clip.close()
77
- audio_clip.close()
78
- if 'clip' in locals(): clip.close()
79
-
80
- return output_filename
81
-
82
- def process_image_for_aspect_ratio(image, aspect_ratio_str):
83
- """
84
- Crops and resizes the image to match the target resolution based on aspect ratio.
85
- """
86
- # Define resolutions (W, H)
87
- resolutions = {
88
- "1:1 (Square)": (512, 512),
89
- "16:9 (Cinematic)": (768, 512),
90
- "9:16 (Vertical)": (512, 768)
91
- }
92
-
93
- target_width, target_height = resolutions.get(aspect_ratio_str, (768, 512))
94
-
95
- # Use ImageOps.fit to center crop and resize automatically
96
- # This preserves aspect ratio of the content while filling the target dimensions
97
- processed_img = ImageOps.fit(
98
- image,
99
- (target_width, target_height),
100
- method=Image.LANCZOS,
101
- centering=(0.5, 0.5)
102
- )
103
-
104
- return processed_img, target_width, target_height
105
-
106
- # --- 3. Inference Function ---
107
- @spaces.GPU(duration=120, size='xlarge')
108
- def generate(
109
- image_path,
110
- audio_path,
111
- prompt,
112
- negative_prompt,
113
- aspect_ratio,
114
- video_duration,
115
- seed
116
  ):
117
- if not image_path or not audio_path:
118
- raise gr.Error("Please provide both an image and an audio file.")
119
-
120
- # Set reproducibility
121
- if seed == -1:
122
- seed = random.randint(0, 1000000)
123
- generator = torch.Generator(device="cuda").manual_seed(seed)
124
-
125
- # 1. Load and Preprocess Image
126
- original_image = load_image(image_path)
127
- # Crop/Resize logic
128
- image, width, height = process_image_for_aspect_ratio(original_image, aspect_ratio)
129
-
130
- print(f"Generating with seed: {seed}, Resolution: {width}x{height}, Duration: {video_duration}s")
131
-
132
- # 2. Calculate Frames
133
- fps = 24.0
134
- # LTX-2 constraint: (num_frames - 1) % 8 == 0
135
- total_frames = int(video_duration * fps)
136
-
137
- # Round to nearest valid block of 8, plus 1
138
- # Example: 4 seconds * 24 = 96 frames.
139
- # 96 is divisible by 8. So we take 96 + 1 = 97 frames.
140
- base_block = round(total_frames / 8) * 8
141
- num_frames = base_block + 1
142
-
143
- # Ensure sane minimum
144
- if num_frames < 9: num_frames = 9
145
-
146
- print(f"Calculated frames: {num_frames}")
147
-
148
- # 3. Run Inference
149
- video_output, _ = pipe(
150
- image=image,
151
- audio=audio_path,
152
- prompt=prompt,
153
- negative_prompt=negative_prompt,
154
- width=width,
155
- height=height,
156
- num_frames=num_frames,
157
- frame_rate=fps,
158
- num_inference_steps=8, # Distilled uses 8 steps
159
- sigmas=DISTILLED_SIGMA_VALUES,
160
- guidance_scale=1.0,
161
- generator=generator,
162
- return_dict=False,
163
- )
164
-
165
- # 4. Post-process: Add audio
166
- output_video_path = save_video_with_audio(video_output, audio_path, fps=fps)
167
-
168
- return output_video_path, seed
169
-
170
- # --- 4. Gradio Interface Definition ---
171
-
172
- css = """
173
- #col-container { max-width: 800px; margin: 0 auto; }
174
- """
175
 
176
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
177
- with gr.Column(elem_id="col-container"):
178
- gr.Markdown("# ⚡ LTX-2 Distilled Audio-to-Video")
179
- gr.Markdown("Generate lip-synced or audio-reactive video from a single image using the distilled 8-step LTX-2 model.")
180
-
181
- with gr.Row():
182
- with gr.Column():
183
- input_image = gr.Image(label="Input Image", type="filepath", height=300)
184
- input_audio = gr.Audio(label="Input Audio", type="filepath")
185
-
186
- with gr.Column():
187
- result_video = gr.Video(label="Generated Video")
188
-
189
- prompt = gr.Textbox(
190
- label="Prompt",
191
- value="A person speaking, lips moving in sync with the words, talking head",
192
- lines=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  )
194
-
195
- # New Controls
196
- with gr.Row():
197
- aspect_ratio = gr.Radio(
198
- choices=["1:1 (Square)", "16:9 (Cinematic)", "9:16 (Vertical)"],
199
- value="16:9 (Cinematic)",
200
- label="Output Aspect Ratio (Auto-Crops Image)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  )
202
- video_duration = gr.Slider(
203
- label="Video Duration (Seconds)",
204
- minimum=1.0,
205
- maximum=6.0,
206
- step=0.5,
207
- value=4.0,
208
- info="Approximate length. Longer videos require more GPU memory."
209
  )
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- with gr.Accordion("Advanced Settings", open=False):
212
- negative_prompt = gr.Textbox(
213
- label="Negative Prompt",
214
- value="low quality, worst quality, deformed, distorted",
215
- placeholder="Usually ignored by distilled models with guidance 1.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  )
217
- seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
 
 
 
218
 
219
- run_btn = gr.Button("Generate Video", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
- # Output info
222
- used_seed = gr.Number(label="Used Seed", visible=False)
223
-
224
- # Event Logic
225
- run_btn.click(
226
- fn=generate,
227
- inputs=[
228
- input_image,
229
- input_audio,
230
- prompt,
231
- negative_prompt,
232
- aspect_ratio,
233
- video_duration,
234
- seed
235
- ],
236
- outputs=[result_video, used_seed]
237
- )
238
-
239
- if __name__ == "__main__":
240
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
  import numpy as np
20
+ import torch
21
+ import torchaudio
22
+ import torchaudio.transforms as T
23
+ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
24
+
25
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from diffusers.image_processor import PipelineImageInput
27
+ from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
28
+ from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
29
+ from diffusers.models.transformers import LTX2VideoTransformer3DModel
30
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
31
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
35
+ from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
36
+ from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput
37
+ from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
38
+
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+ EXAMPLE_DOC_STRING = """
50
+ Examples:
51
+ ```py
52
+ >>> import torch
53
+ >>> from diffusers import DiffusionPipeline
54
+ >>> from diffusers.pipelines.ltx2.export_utils import encode_video
55
+ >>> from diffusers.utils import load_image
56
+
57
+ >>> pipe = DiffusionPipeline.from_pretrained(
58
+ ... "Lightricks/LTX-2",
59
+ ... custom_pipeline="pipeline_ltx2_audio2video",
60
+ ... torch_dtype=torch.bfloat16
61
+ ... )
62
+ >>> pipe.enable_model_cpu_offload()
63
+
64
+ >>> image = load_image(
65
+ ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
66
+ ... )
67
+ >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
68
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
69
+
70
+ >>> frame_rate = 24.0
71
+ >>> video, audio = pipe(
72
+ ... image=image,
73
+ ... audio="path/to/audio.wav",
74
+ ... prompt=prompt,
75
+ ... negative_prompt=negative_prompt,
76
+ ... width=768,
77
+ ... height=512,
78
+ ... num_frames=121,
79
+ ... frame_rate=frame_rate,
80
+ ... num_inference_steps=40,
81
+ ... guidance_scale=4.0,
82
+ ... output_type="np",
83
+ ... return_dict=False,
84
+ ... )
85
+ >>> video = (video * 255).round().astype("uint8")
86
+ >>> video = torch.from_numpy(video)
87
+
88
+ >>> encode_video(
89
+ ... video[0],
90
+ ... fps=frame_rate,
91
+ ... audio=audio[0].float().cpu(),
92
+ ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
93
+ ... output_path="video.mp4",
94
+ ... )
95
+ ```
96
+ """
97
+
98
+
99
+ def retrieve_latents(
100
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
101
+ ):
102
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
103
+ return encoder_output.latent_dist.sample(generator)
104
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
105
+ return encoder_output.latent_dist.mode()
106
+ elif hasattr(encoder_output, "latents"):
107
+ return encoder_output.latents
108
  else:
109
+ raise AttributeError("Could not access latents of provided encoder_output")
110
+
111
+
112
+ def calculate_shift(
113
+ image_seq_len,
114
+ base_seq_len: int = 256,
115
+ max_seq_len: int = 4096,
116
+ base_shift: float = 0.5,
117
+ max_shift: float = 1.15,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  ):
119
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
120
+ b = base_shift - m * base_seq_len
121
+ mu = image_seq_len * m + b
122
+ return mu
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+
125
+ def retrieve_timesteps(
126
+ scheduler,
127
+ num_inference_steps: Optional[int] = None,
128
+ device: Optional[Union[str, torch.device]] = None,
129
+ timesteps: Optional[List[int]] = None,
130
+ sigmas: Optional[List[float]] = None,
131
+ **kwargs,
132
+ ):
133
+ if timesteps is not None and sigmas is not None:
134
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
135
+ if timesteps is not None:
136
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
137
+ if not accepts_timesteps:
138
+ raise ValueError(
139
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
140
+ f" timestep schedules. Please check whether you are using the correct scheduler."
141
+ )
142
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ num_inference_steps = len(timesteps)
145
+ elif sigmas is not None:
146
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
147
+ if not accept_sigmas:
148
+ raise ValueError(
149
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
150
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
151
+ )
152
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
153
+ timesteps = scheduler.timesteps
154
+ num_inference_steps = len(timesteps)
155
+ else:
156
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
157
+ timesteps = scheduler.timesteps
158
+ return timesteps, num_inference_steps
159
+
160
+
161
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
162
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
163
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
164
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
165
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
166
+ return noise_cfg
167
+
168
+
169
+ class LTX2AudioToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
170
+ r"""
171
+ Pipeline for audio-to-video generation with optional image conditioning.
172
+
173
+ This pipeline generates video conditioned on input audio, forcing the video generation
174
+ to attend to specific audio cues. It also supports image conditioning for image-to-video
175
+ generation with audio.
176
+
177
+ Reference: https://github.com/Lightricks/LTX-Video
178
+ """
179
+
180
+ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
181
+ _optional_components = []
182
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
183
+
184
+ def __init__(
185
+ self,
186
+ scheduler: FlowMatchEulerDiscreteScheduler,
187
+ vae: AutoencoderKLLTX2Video,
188
+ audio_vae: AutoencoderKLLTX2Audio,
189
+ text_encoder: Gemma3ForConditionalGeneration,
190
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
191
+ connectors: LTX2TextConnectors,
192
+ transformer: LTX2VideoTransformer3DModel,
193
+ vocoder: LTX2Vocoder,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.register_modules(
198
+ vae=vae,
199
+ audio_vae=audio_vae,
200
+ text_encoder=text_encoder,
201
+ tokenizer=tokenizer,
202
+ connectors=connectors,
203
+ transformer=transformer,
204
+ vocoder=vocoder,
205
+ scheduler=scheduler,
206
  )
207
+
208
+ self.vae_spatial_compression_ratio = (
209
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
210
+ )
211
+ self.vae_temporal_compression_ratio = (
212
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
213
+ )
214
+ self.audio_vae_mel_compression_ratio = (
215
+ self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
216
+ )
217
+ self.audio_vae_temporal_compression_ratio = (
218
+ self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
219
+ )
220
+ self.transformer_spatial_patch_size = (
221
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
222
+ )
223
+ self.transformer_temporal_patch_size = (
224
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
225
+ )
226
+
227
+ self.audio_sampling_rate = (
228
+ self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
229
+ )
230
+ self.audio_hop_length = (
231
+ self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
232
+ )
233
+
234
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear")
235
+ self.tokenizer_max_length = (
236
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
237
+ )
238
+
239
+ @staticmethod
240
+ def _pack_text_embeds(
241
+ text_hidden_states: torch.Tensor,
242
+ sequence_lengths: torch.Tensor,
243
+ device: Union[str, torch.device],
244
+ padding_side: str = "left",
245
+ scale_factor: int = 8,
246
+ eps: float = 1e-6,
247
+ ) -> torch.Tensor:
248
+ batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
249
+ original_dtype = text_hidden_states.dtype
250
+
251
+ token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
252
+ if padding_side == "right":
253
+ mask = token_indices < sequence_lengths[:, None]
254
+ elif padding_side == "left":
255
+ start_indices = seq_len - sequence_lengths[:, None]
256
+ mask = token_indices >= start_indices
257
+ else:
258
+ raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
259
+ mask = mask[:, :, None, None]
260
+
261
+ masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
262
+ num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
263
+ masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
264
+
265
+ x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
266
+ x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
267
+
268
+ normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
269
+ normalized_hidden_states = normalized_hidden_states * scale_factor
270
+
271
+ normalized_hidden_states = normalized_hidden_states.flatten(2)
272
+ mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
273
+ normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
274
+ normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
275
+ return normalized_hidden_states
276
+
277
+ def _get_gemma_prompt_embeds(
278
+ self,
279
+ prompt: Union[str, List[str]],
280
+ num_videos_per_prompt: int = 1,
281
+ max_sequence_length: int = 1024,
282
+ scale_factor: int = 8,
283
+ device: Optional[torch.device] = None,
284
+ dtype: Optional[torch.dtype] = None,
285
+ ):
286
+ device = device or self._execution_device
287
+ dtype = dtype or self.text_encoder.dtype
288
+
289
+ prompt = [prompt] if isinstance(prompt, str) else prompt
290
+ batch_size = len(prompt)
291
+
292
+ if getattr(self, "tokenizer", None) is not None:
293
+ self.tokenizer.padding_side = "left"
294
+ if self.tokenizer.pad_token is None:
295
+ self.tokenizer.pad_token = self.tokenizer.eos_token
296
+
297
+ prompt = [p.strip() for p in prompt]
298
+ text_inputs = self.tokenizer(
299
+ prompt,
300
+ padding="max_length",
301
+ max_length=max_sequence_length,
302
+ truncation=True,
303
+ add_special_tokens=True,
304
+ return_tensors="pt",
305
+ )
306
+ text_input_ids = text_inputs.input_ids
307
+ prompt_attention_mask = text_inputs.attention_mask
308
+ text_input_ids = text_input_ids.to(device)
309
+ prompt_attention_mask = prompt_attention_mask.to(device)
310
+
311
+ text_encoder_outputs = self.text_encoder(
312
+ input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
313
+ )
314
+ text_encoder_hidden_states = text_encoder_outputs.hidden_states
315
+ text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
316
+ sequence_lengths = prompt_attention_mask.sum(dim=-1)
317
+
318
+ prompt_embeds = self._pack_text_embeds(
319
+ text_encoder_hidden_states,
320
+ sequence_lengths,
321
+ device=device,
322
+ padding_side=self.tokenizer.padding_side,
323
+ scale_factor=scale_factor,
324
+ )
325
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
326
+
327
+ _, seq_len, _ = prompt_embeds.shape
328
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
329
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
330
+
331
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
332
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
333
+
334
+ return prompt_embeds, prompt_attention_mask
335
+
336
+ def encode_prompt(
337
+ self,
338
+ prompt: Union[str, List[str]],
339
+ negative_prompt: Optional[Union[str, List[str]]] = None,
340
+ do_classifier_free_guidance: bool = True,
341
+ num_videos_per_prompt: int = 1,
342
+ prompt_embeds: Optional[torch.Tensor] = None,
343
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
344
+ prompt_attention_mask: Optional[torch.Tensor] = None,
345
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
346
+ max_sequence_length: int = 1024,
347
+ scale_factor: int = 8,
348
+ device: Optional[torch.device] = None,
349
+ dtype: Optional[torch.dtype] = None,
350
+ ):
351
+ device = device or self._execution_device
352
+
353
+ prompt = [prompt] if isinstance(prompt, str) else prompt
354
+ if prompt is not None:
355
+ batch_size = len(prompt)
356
+ else:
357
+ batch_size = prompt_embeds.shape[0]
358
+
359
+ if prompt_embeds is None:
360
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
361
+ prompt=prompt,
362
+ num_videos_per_prompt=num_videos_per_prompt,
363
+ max_sequence_length=max_sequence_length,
364
+ scale_factor=scale_factor,
365
+ device=device,
366
+ dtype=dtype,
367
+ )
368
+
369
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
370
+ negative_prompt = negative_prompt or ""
371
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
372
+
373
+ if prompt is not None and type(prompt) is not type(negative_prompt):
374
+ raise TypeError(
375
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
376
+ f" {type(prompt)}."
377
+ )
378
+ elif batch_size != len(negative_prompt):
379
+ raise ValueError(
380
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
381
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
382
+ " the batch size of `prompt`."
383
+ )
384
+
385
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
386
+ prompt=negative_prompt,
387
+ num_videos_per_prompt=num_videos_per_prompt,
388
+ max_sequence_length=max_sequence_length,
389
+ scale_factor=scale_factor,
390
+ device=device,
391
+ dtype=dtype,
392
+ )
393
+
394
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
395
+
396
+ def check_inputs(
397
+ self,
398
+ prompt,
399
+ height,
400
+ width,
401
+ callback_on_step_end_tensor_inputs=None,
402
+ prompt_embeds=None,
403
+ negative_prompt_embeds=None,
404
+ prompt_attention_mask=None,
405
+ negative_prompt_attention_mask=None,
406
+ ):
407
+ if height % 32 != 0 or width % 32 != 0:
408
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
409
+
410
+ if callback_on_step_end_tensor_inputs is not None and not all(
411
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
412
+ ):
413
+ raise ValueError(
414
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
415
  )
416
+
417
+ if prompt is not None and prompt_embeds is not None:
418
+ raise ValueError(
419
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
420
+ " only forward one of the two."
 
 
421
  )
422
+ elif prompt is None and prompt_embeds is None:
423
+ raise ValueError(
424
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
425
+ )
426
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
427
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
428
+
429
+ if prompt_embeds is not None and prompt_attention_mask is None:
430
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
431
+
432
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
433
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
434
 
435
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
436
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
437
+ raise ValueError(
438
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
439
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
440
+ f" {negative_prompt_embeds.shape}."
441
+ )
442
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
443
+ raise ValueError(
444
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
445
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
446
+ f" {negative_prompt_attention_mask.shape}."
447
+ )
448
+
449
+ @staticmethod
450
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
451
+ batch_size, num_channels, num_frames, height, width = latents.shape
452
+ post_patch_num_frames = num_frames // patch_size_t
453
+ post_patch_height = height // patch_size
454
+ post_patch_width = width // patch_size
455
+ latents = latents.reshape(
456
+ batch_size,
457
+ -1,
458
+ post_patch_num_frames,
459
+ patch_size_t,
460
+ post_patch_height,
461
+ patch_size,
462
+ post_patch_width,
463
+ patch_size,
464
+ )
465
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
466
+ return latents
467
+
468
+ @staticmethod
469
+ def _unpack_latents(
470
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
471
+ ) -> torch.Tensor:
472
+ batch_size = latents.size(0)
473
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
474
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
475
+ return latents
476
+
477
+ @staticmethod
478
+ def _normalize_latents(
479
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
480
+ ) -> torch.Tensor:
481
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
482
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
483
+ latents = (latents - latents_mean) * scaling_factor / latents_std
484
+ return latents
485
+
486
+ @staticmethod
487
+ def _denormalize_latents(
488
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
489
+ ) -> torch.Tensor:
490
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
491
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
492
+ latents = latents * latents_std / scaling_factor + latents_mean
493
+ return latents
494
+
495
+ @staticmethod
496
+ def _pack_audio_latents(
497
+ latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
498
+ ) -> torch.Tensor:
499
+ if patch_size is not None and patch_size_t is not None:
500
+ batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
501
+ post_patch_latent_length = latent_length / patch_size_t
502
+ post_patch_mel_bins = latent_mel_bins / patch_size
503
+ latents = latents.reshape(
504
+ batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
505
  )
506
+ latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
507
+ else:
508
+ latents = latents.transpose(1, 2).flatten(2, 3)
509
+ return latents
510
 
511
+ @staticmethod
512
+ def _unpack_audio_latents(
513
+ latents: torch.Tensor,
514
+ latent_length: int,
515
+ num_mel_bins: int,
516
+ patch_size: Optional[int] = None,
517
+ patch_size_t: Optional[int] = None,
518
+ ) -> torch.Tensor:
519
+ if patch_size is not None and patch_size_t is not None:
520
+ batch_size = latents.size(0)
521
+ latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
522
+ latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
523
+ else:
524
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
525
+ return latents
526
+
527
+ @staticmethod
528
+ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
529
+ """
530
+ Denormalize audio latents. The latents should be in patchified form [B, T, C*F]
531
+ where the last dimension matches the size of latents_mean/latents_std.
532
+ """
533
+ latents_mean = latents_mean.to(latents.device, latents.dtype)
534
+ latents_std = latents_std.to(latents.device, latents.dtype)
535
+ return (latents * latents_std) + latents_mean
536
+
537
+ @staticmethod
538
+ def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
539
+ """
540
+ Normalize audio latents. The latents should be in patchified form [B, T, C*F]
541
+ where the last dimension matches the size of latents_mean/latents_std.
542
+ """
543
+ latents_mean = latents_mean.to(latents.device, latents.dtype)
544
+ latents_std = latents_std.to(latents.device, latents.dtype)
545
+ return (latents - latents_mean) / latents_std
546
+
547
+ @staticmethod
548
+ def _patchify_audio_latents(latents: torch.Tensor) -> torch.Tensor:
549
+ """
550
+ Patchify audio latents from [B, C, T, F] to [B, T, C*F].
551
+ This is needed for normalization which operates on the flattened channel*freq dimension.
552
+ """
553
+ # latents shape: [B, C, T, F] -> [B, T, C*F]
554
+ batch, channels, time, freq = latents.shape
555
+ return latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
556
+
557
+ @staticmethod
558
+ def _unpatchify_audio_latents(latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
559
+ """
560
+ Unpatchify audio latents from [B, T, C*F] to [B, C, T, F].
561
+ """
562
+ # latents shape: [B, T, C*F] -> [B, C, T, F]
563
+ batch, time, _ = latents.shape
564
+ return latents.reshape(batch, time, channels, freq).permute(0, 2, 1, 3)
565
+
566
    def _preprocess_audio(self, audio: Union[str, torch.Tensor], target_sample_rate: int) -> torch.Tensor:
        """
        Reads audio and converts to Mel Spectrogram matching Audio VAE expectations.

        The Audio VAE encoder expects input shape: (batch_size, in_channels, time, mel_bins)
        where in_channels=2 (stereo) and mel_bins=64 by default.

        Uses the same mel spectrogram parameters as Wan2GP's AudioProcessor for compatibility.

        Args:
            audio: Path to an audio file (loaded via torchaudio) or a waveform tensor
                of shape [channels, samples].
            target_sample_rate: Sample rate the spectrogram is computed at; file input
                is resampled to this rate.

        Returns:
            Log-mel spectrogram of shape [1, 2, time, mel_bins].
        """
        if isinstance(audio, str):
            waveform, sr = torchaudio.load(audio)
        else:
            # NOTE(review): tensor input is assumed to already be at target_sample_rate
            # (no resampling happens below in this branch) — confirm with callers.
            waveform = audio
            sr = target_sample_rate

        if sr != target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, target_sample_rate)

        # Handle mono/stereo: VAE expects 2 channels
        if waveform.shape[0] == 1:
            # Duplicate mono to stereo
            waveform = waveform.repeat(2, 1)
        elif waveform.shape[0] > 2:
            # Take first 2 channels if more than stereo
            waveform = waveform[:2, :]

        # Add batch dimension: [channels, samples] -> [batch, channels, samples]
        waveform = waveform.unsqueeze(0)

        n_fft = 1024
        # Mel spectrogram parameters matching Wan2GP's AudioProcessor exactly
        mel_transform = T.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=self.audio_hop_length,
            f_min=0.0,
            f_max=target_sample_rate / 2.0,
            n_mels=self.audio_vae.config.mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,  # Important: power=1.0 (magnitude), not 2.0 (power) spectrogram
            mel_scale="slaney",
            norm="slaney",
        )

        # waveform shape: [batch, channels, samples]
        # mel_spec shape after transform: [batch, channels, mel_bins, time]
        mel_spec = mel_transform(waveform)

        # Log scaling; clamp avoids log(0) on silent frames.
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

        # Permute to [batch, channels, time, mel_bins] as expected by VAE
        mel_spec = mel_spec.permute(0, 1, 3, 2).contiguous()

        return mel_spec
624
+
625
    def prepare_latents(
        self,
        image: Optional[torch.Tensor] = None,
        batch_size: int = 1,
        num_channels_latents: int = 128,
        height: int = 512,
        width: int = 704,
        num_frames: int = 161,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the initial (partially noised) video latents and the conditioning mask.

        The first latent frame carries the VAE-encoded conditioning `image`
        (mask == 1.0); all later frames start as pure noise (mask == 0.0).

        Returns:
            Tuple of packed latents [B, S, D] and the packed conditioning mask [B, S].
            If `latents` is given, it is returned as-is (moved to device/dtype) with a
            freshly built mask.
        """
        # Convert pixel-space dimensions to latent-space dimensions.
        height = height // self.vae_spatial_compression_ratio
        width = width // self.vae_spatial_compression_ratio
        num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1

        shape = (batch_size, num_channels_latents, num_frames, height, width)
        mask_shape = (batch_size, 1, num_frames, num_frames, height, width)[:2] + (num_frames, height, width)

        if latents is not None:
            # Pre-supplied latents: only build the mask (first frame = conditioning).
            conditioning_mask = latents.new_zeros(mask_shape)
            conditioning_mask[:, :, 0] = 1.0
            conditioning_mask = self._pack_latents(
                conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
            ).squeeze(-1)
            if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
                )
            return latents.to(device=device, dtype=dtype), conditioning_mask

        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            # One generator per batch element; encode each image deterministically ("argmax").
            init_latents = [
                retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax")
                for i in range(batch_size)
            ]
        else:
            init_latents = [
                retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image
            ]

        init_latents = torch.cat(init_latents, dim=0).to(dtype)
        init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
        # Tile the single encoded frame across the whole temporal axis.
        init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)

        conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
        conditioning_mask[:, :, 0] = 1.0

        # Keep the conditioning frame clean, fill the rest with noise.
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)

        conditioning_mask = self._pack_latents(
            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        ).squeeze(-1)
        latents = self._pack_latents(
            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )

        return latents, conditioning_mask
691
+
692
    def prepare_audio_latents(
        self,
        batch_size: int = 1,
        num_channels_latents: int = 8,
        num_mel_bins: int = 64,
        num_frames: int = 121,
        frame_rate: float = 25.0,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        audio_input: Optional[Union[str, torch.Tensor]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
        """
        Prepare audio latents for the denoising loop.

        Three modes:
          1. `latents` given: returned unchanged (moved to device/dtype).
          2. `audio_input` given: the audio is converted to a mel spectrogram, encoded
             by the audio VAE, normalized and padded/trimmed to the target length; the
             CLEAN latents are returned as the third element alongside packed noise.
          3. Otherwise: pure packed noise for text-to-audio generation.

        Returns:
            Tuple of (packed latents/noise [B, T, C*F], target latent length,
            clean audio latents [B, C, T, F] or None).
        """
        # Number of audio latent frames needed to cover the video duration.
        duration_s = num_frames / frame_rate
        latents_per_second = (
            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
        )
        target_length = round(duration_s * latents_per_second)

        if latents is not None:
            return latents.to(device=device, dtype=dtype), target_length, None

        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        if audio_input is not None:
            mel_spec = self._preprocess_audio(audio_input, sampling_rate).to(device=device)

            # Encode the mel spectrogram to latents (use VAE's dtype for encoding)
            mel_spec = mel_spec.to(dtype=self.audio_vae.dtype)
            init_latents = self.audio_vae.encode(mel_spec).latent_dist.sample(generator)
            init_latents = init_latents.to(dtype=dtype)

            # Normalize: patchify -> normalize -> unpatchify
            # init_latents shape: [B, C, T, F] where C=latent_channels, F=latent_mel_bins
            latent_channels = init_latents.shape[1]
            latent_freq = init_latents.shape[3]
            init_latents_patched = self._patchify_audio_latents(init_latents)  # [B, T, C*F]
            init_latents_patched = self._normalize_audio_latents(
                init_latents_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std
            )
            init_latents = self._unpatchify_audio_latents(init_latents_patched, latent_channels, latent_freq)  # [B, C, T, F]

            # Pad (with zeros) or trim along time so audio covers exactly the video duration.
            current_len = init_latents.shape[2]
            if current_len < target_length:
                padding = target_length - current_len
                init_latents = torch.nn.functional.pad(init_latents, (0, 0, 0, padding))
            elif current_len > target_length:
                init_latents = init_latents[:, :, :target_length, :]

            noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

            if init_latents.shape[0] != batch_size:
                init_latents = init_latents.repeat(batch_size, 1, 1, 1)
                noise = noise.repeat(batch_size, 1, 1, 1)

            packed_noise = self._pack_audio_latents(noise)

            # Noise is returned first; the clean conditioning latents ride along third.
            return packed_noise, target_length, init_latents

        shape = (batch_size, num_channels_latents, target_length, latent_mel_bins)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError("Generator size mismatch")

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_audio_latents(latents)

        return latents, target_length, None
761
+
762
    @property
    def guidance_scale(self):
        """Classifier-free guidance scale set by the latest `__call__`."""
        return self._guidance_scale
765
+
766
    @property
    def guidance_rescale(self):
        """Guidance rescale factor set by the latest `__call__` (0.0 disables rescaling)."""
        return self._guidance_rescale
769
+
770
    @property
    def do_classifier_free_guidance(self):
        """Whether CFG is active: guidance scale strictly greater than 1.0."""
        return self._guidance_scale > 1.0
773
+
774
    @property
    def num_timesteps(self):
        """Number of denoising timesteps scheduled for the current/last run."""
        return self._num_timesteps
777
+
778
    @property
    def current_timestep(self):
        """Timestep currently being denoised (None outside the denoising loop)."""
        return self._current_timestep
781
+
782
    @property
    def attention_kwargs(self):
        """Extra kwargs forwarded to the transformer's attention processors."""
        return self._attention_kwargs
785
+
786
    @property
    def interrupt(self):
        """Flag checked each denoising step; when True, remaining steps are skipped."""
        return self._interrupt
789
+
790
+ def _get_audio_duration(self, audio: Union[str, torch.Tensor], sample_rate: int) -> float:
791
+ """Get duration of audio in seconds."""
792
+ if isinstance(audio, str):
793
+ info = torchaudio.info(audio)
794
+ return info.num_frames / info.sample_rate
795
+ else:
796
+ # audio is a tensor with shape [channels, samples] or [samples]
797
+ num_samples = audio.shape[-1]
798
+ return num_samples / sample_rate
799
+
800
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput = None,
        audio: Optional[Union[str, torch.Tensor]] = None,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: Optional[int] = None,
        max_frames: int = 257,
        frame_rate: float = 24.0,
        num_inference_steps: int = 40,
        timesteps: List[int] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 4.0,
        guidance_rescale: float = 0.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        audio_latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        decode_timestep: Union[float, List[float]] = 0.0,
        decode_noise_scale: Optional[Union[float, List[float]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 1024,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PipelineImageInput`):
                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
            audio (`str` or `torch.Tensor`, *optional*):
                The input audio to condition the generation on. Can be a path to an audio file or a waveform tensor.
                When provided, the generated video will be synchronized to this audio input.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation.
            height (`int`, *optional*, defaults to `512`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `768`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. If not provided and audio is given,
                it will be calculated from the audio duration. Otherwise defaults to 121.
            max_frames (`int`, *optional*, defaults to `257`):
                Maximum number of frames to generate. Used to cap the calculated frames
                when deriving from audio duration. 257 frames at 25fps is ~10 seconds.
            frame_rate (`float`, *optional*, defaults to `24.0`):
                The frames per second (FPS) of the generated video.
            num_inference_steps (`int`, *optional*, defaults to 40):
                The number of denoising steps.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to `4.0`):
                Guidance scale for classifier-free guidance.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                Random generator(s) for reproducibility.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents for video generation.
            audio_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents for audio generation.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings.
            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            decode_timestep (`float`, defaults to `0.0`):
                The timestep at which generated video is decoded.
            decode_noise_scale (`float`, defaults to `None`):
                The interpolation factor between random noise and denoised latents.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `LTX2PipelineOutput` instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                Kwargs dictionary for the attention processor.
            callback_on_step_end (`Callable`, *optional*):
                A function called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function.
            max_sequence_length (`int`, *optional*, defaults to `1024`):
                Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`LTX2PipelineOutput`] or `tuple`:
                If `return_dict` is `True`, `LTX2PipelineOutput` is returned, otherwise a `tuple`.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Calculate num_frames from audio duration if not provided
        if num_frames is None:
            if audio is not None:
                # Get audio duration and calculate frames
                audio_duration = self._get_audio_duration(audio, self.audio_sampling_rate)
                # Calculate frames needed for this audio duration
                # Add 1 because frames = duration * fps + 1 (first frame at t=0)
                calculated_frames = int(audio_duration * frame_rate) + 1
                # Cap at max_frames and ensure it's valid for the model
                # LTX2 requires (num_frames - 1) to be divisible by temporal_compression_ratio (8)
                num_frames = min(calculated_frames, max_frames)
                # Adjust to valid frame count: (num_frames - 1) % 8 == 0
                num_frames = ((num_frames - 1) // self.vae_temporal_compression_ratio) * self.vae_temporal_compression_ratio + 1
                num_frames = max(num_frames, 9)  # Minimum valid frame count
                logger.info(f"Audio duration: {audio_duration:.2f}s -> num_frames: {num_frames}")
            else:
                num_frames = 121  # Default

        # 2. Validate inputs and stash per-call state.
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False
        self._current_timestep = None

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the text prompt (and the negative prompt when CFG is on).
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if self.do_classifier_free_guidance:
            # Unconditional embeddings first, conditional second (matches chunk(2) below).
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # Convert the 0/1 mask into a large negative additive attention bias.
        additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
        connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
            prompt_embeds, additive_attention_mask, additive_mask=True
        )

        # 4. Prepare video latents (first latent frame = encoded conditioning image).
        if latents is None:
            image = self.video_processor.preprocess(image, height=height, width=width)
            image = image.to(device=device, dtype=prompt_embeds.dtype)

        num_channels_latents = self.transformer.config.in_channels
        latents, conditioning_mask = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )
        if self.do_classifier_free_guidance:
            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])

        # 5. Prepare audio latents. `clean_audio_latents` is non-None only when an
        # audio conditioning input was supplied.
        num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        num_channels_latents_audio = (
            self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
        )

        audio_latents, audio_num_frames, clean_audio_latents = self.prepare_audio_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents_audio,
            num_mel_bins=num_mel_bins,
            num_frames=num_frames,
            frame_rate=frame_rate,
            sampling_rate=self.audio_sampling_rate,
            hop_length=self.audio_hop_length,
            dtype=torch.float32,
            device=device,
            generator=generator,
            latents=audio_latents,
            audio_input=audio,
        )

        # If clean audio latents are provided, pack them for use in the transformer
        # This is the key fix: we pass clean (not noisy) audio latents to the transformer
        packed_clean_audio_latents = None
        if clean_audio_latents is not None:
            packed_clean_audio_latents = self._pack_audio_latents(clean_audio_latents)

        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        video_sequence_length = latent_num_frames * latent_height * latent_width

        # 6. Scheduler setup: separate (deep-copied) scheduler for audio so the two
        # modalities' internal step state never interferes.
        if sigmas is None:
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

        mu = calculate_shift(
            video_sequence_length,
            self.scheduler.config.get("base_image_seq_len", 1024),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.95),
            self.scheduler.config.get("max_shift", 2.05),
        )

        audio_scheduler = copy.deepcopy(self.scheduler)
        _, _ = retrieve_timesteps(
            audio_scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # NOTE(review): rope_interpolation_scale appears unused below — coords are
        # built directly via prepare_video_coords/prepare_audio_coords; confirm.
        rope_interpolation_scale = (
            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )
        video_coords = self.transformer.rope.prepare_video_coords(
            latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
        )
        audio_coords = self.transformer.audio_rope.prepare_audio_coords(
            audio_latents.shape[0], audio_num_frames, audio_latents.device
        )

        # 7. Denoising loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t

                # When audio conditioning is provided, use clean audio latents directly
                # (not noisy). This matches Wan2GP's approach where audio with denoise_mask=0
                # stays constant throughout denoising.
                if packed_clean_audio_latents is not None:
                    audio_latents_input = packed_clean_audio_latents.to(dtype=prompt_embeds.dtype)
                else:
                    audio_latents_input = audio_latents.to(dtype=prompt_embeds.dtype)

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
                audio_latent_model_input = (
                    torch.cat([audio_latents_input] * 2) if self.do_classifier_free_guidance else audio_latents_input
                )
                audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)

                timestep = t.expand(latent_model_input.shape[0])
                # Conditioning tokens (mask == 1) get timestep 0, i.e. they are treated as clean.
                video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

                # When audio conditioning is provided, set audio_timestep to 0.
                # This tells the transformer the audio is already "denoised" (clean),
                # matching Wan2GP's approach where denoise_mask=0 results in timestep=0.
                if packed_clean_audio_latents is not None:
                    audio_timestep = torch.zeros_like(timestep)
                else:
                    audio_timestep = timestep

                with self.transformer.cache_context("cond_uncond"):
                    noise_pred_video, noise_pred_audio = self.transformer(
                        hidden_states=latent_model_input,
                        audio_hidden_states=audio_latent_model_input,
                        encoder_hidden_states=connector_prompt_embeds,
                        audio_encoder_hidden_states=connector_audio_prompt_embeds,
                        timestep=video_timestep,
                        audio_timestep=audio_timestep,
                        encoder_attention_mask=connector_attention_mask,
                        audio_encoder_attention_mask=connector_attention_mask,
                        num_frames=latent_num_frames,
                        height=latent_height,
                        width=latent_width,
                        fps=frame_rate,
                        audio_num_frames=audio_num_frames,
                        video_coords=video_coords,
                        audio_coords=audio_coords,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )
                noise_pred_video = noise_pred_video.float()
                noise_pred_audio = noise_pred_audio.float()

                if self.do_classifier_free_guidance:
                    # Standard CFG: uncond + scale * (cond - uncond), for both modalities.
                    noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
                    noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
                        noise_pred_video_text - noise_pred_video_uncond
                    )

                    noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
                    noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
                        noise_pred_audio_text - noise_pred_audio_uncond
                    )

                    if self.guidance_rescale > 0:
                        noise_pred_video = rescale_noise_cfg(
                            noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
                        )
                        noise_pred_audio = rescale_noise_cfg(
                            noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
                        )

                # Unpack to [B, C, T, H, W] so the first (conditioning) latent frame
                # can be excluded from the scheduler step.
                noise_pred_video = self._unpack_latents(
                    noise_pred_video,
                    latent_num_frames,
                    latent_height,
                    latent_width,
                    self.transformer_spatial_patch_size,
                    self.transformer_temporal_patch_size,
                )
                latents = self._unpack_latents(
                    latents,
                    latent_num_frames,
                    latent_height,
                    latent_width,
                    self.transformer_spatial_patch_size,
                    self.transformer_temporal_patch_size,
                )

                noise_pred_video = noise_pred_video[:, :, 1:]
                noise_latents = latents[:, :, 1:]
                pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0]

                # Re-attach the untouched conditioning frame, then re-pack.
                latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
                latents = self._pack_latents(
                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                )

                # Only step audio latents when not using audio conditioning.
                # When audio conditioning is provided, we keep the clean latents constant.
                if packed_clean_audio_latents is None:
                    audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
                # else: audio stays constant (packed_clean_audio_latents)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 8. Post-processing: back to VAE latent space.
        latents = self._unpack_latents(
            latents,
            latent_num_frames,
            latent_height,
            latent_width,
            self.transformer_spatial_patch_size,
            self.transformer_temporal_patch_size,
        )
        latents = self._denormalize_latents(
            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
        )

        # Denormalize audio latents for decoding
        # Need to: patchify -> denormalize -> unpatchify (inverse of normalization)
        if clean_audio_latents is not None:
            # clean_audio_latents is in 4D format [B, C, T, F]
            latent_channels = clean_audio_latents.shape[1]
            latent_freq = clean_audio_latents.shape[3]
            audio_patched = self._patchify_audio_latents(clean_audio_latents)  # [B, T, C*F]
            audio_patched = self._denormalize_audio_latents(
                audio_patched, self.audio_vae.latents_mean, self.audio_vae.latents_std
            )
            audio_latents_for_decode = self._unpatchify_audio_latents(audio_patched, latent_channels, latent_freq)
        else:
            # audio_latents is in packed format [B, T, C*F] from the denoising loop
            audio_latents_for_decode = self._denormalize_audio_latents(
                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
            )
            # Unpack to 4D format [B, C, T, F]
            audio_latents_for_decode = self._unpack_audio_latents(
                audio_latents_for_decode, audio_num_frames, num_mel_bins=latent_mel_bins
            )

        if output_type == "latent":
            video = latents
            audio = audio_latents_for_decode
        else:
            latents = latents.to(prompt_embeds.dtype)

            if not self.vae.config.timestep_conditioning:
                timestep = None
            else:
                # Timestep-conditioned decoding: blend a little noise back in so the
                # VAE decodes at the requested `decode_timestep`.
                noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
                if not isinstance(decode_timestep, list):
                    decode_timestep = [decode_timestep] * batch_size
                if decode_noise_scale is None:
                    decode_noise_scale = decode_timestep
                elif not isinstance(decode_noise_scale, list):
                    decode_noise_scale = [decode_noise_scale] * batch_size

                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
                    :, None, None, None, None
                ]
                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

            latents = latents.to(self.vae.dtype)
            video = self.vae.decode(latents, timestep, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)

            # Audio: latents -> mel spectrogram (audio VAE) -> waveform (vocoder).
            audio_latents_for_decode = audio_latents_for_decode.to(self.audio_vae.dtype)
            generated_mel_spectrograms = self.audio_vae.decode(audio_latents_for_decode, return_dict=False)[0]
            audio = self.vocoder(generated_mel_spectrograms)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (video, audio)

        return LTX2PipelineOutput(frames=video, audio=audio)