# Parler_TTS_API / api.py
# aspirant312's picture
# Fix WebSocket streaming - ParlerTTSStreamer yields numpy arrays not tensors
# 97e03da
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
import json
import torch
import numpy as np
import re
from io import BytesIO
import soundfile as sf
from pydantic import BaseModel
import os
from huggingface_hub import login
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import AutoTokenizer
from threading import Thread
import queue
# Log in to the Hugging Face Hub, but only when HF_TOKEN is set in the environment
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Try to import spaces for HF Spaces deployment; otherwise install a no-op shim
# so that the `@spaces.GPU(...)` decorators in this file keep working locally.
try:
    import spaces

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class _NoOpSpaces:
        """Stand-in for the ``spaces`` module when it is not installed."""

        def GPU(self, *args, **kwargs):
            """Leave the decorated function untouched.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            decorator factory (``@spaces.GPU()`` / ``@spaces.GPU(duration=60)``).
            """
            # Bare use: the single positional argument is the function itself.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _NoOpSpaces()
# --- Model Loading ---
# Module-level setup: the statements below run once, when this file is imported.
MODEL_ID = "ai4bharat/indic-parler-tts"
# Prefer CUDA when PyTorch reports it as available; otherwise run on CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("Loading Indic Parler-TTS model...")
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
# Optimize model for inference
if DEVICE == "cuda":
    model = model.half()  # Use half precision (fp16) for faster inference
model.eval()
# `tokenizer` encodes the text to be spoken; `description_tokenizer` encodes the
# voice-description prompt and is loaded from the text encoder's checkpoint name
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
# Sample rate reported by the model config; used when writing WAV output
SAMPLE_RATE = model.config.sampling_rate
# Disable gradients for inference
# NOTE(review): this setting may only apply to the calling thread - confirm
# before relying on it inside worker threads.
torch.set_grad_enabled(False)
print("Model loaded and optimized!")
# Speaker label accepted by the API -> name inserted into the voice description.
# The generic voices map to "" so the description falls back to a gendered phrase.
_NAMED_VOICES = ("Divya", "Rani", "Rohit", "Aman")
SPEAKERS = {
    **{voice: voice for voice in _NAMED_VOICES},
    "Generic Female": "",
    "Generic Male": "",
}
# ASGI application object that the route decorators in this file register onto
app = FastAPI(
    title="Parler TTS API",
    version="1.0",
)
class TTSRequest(BaseModel):
    """Request body for the POST TTS endpoints."""

    # Text to synthesize (required)
    text: str
    # Speaker label; the endpoints reject values that are not keys of SPEAKERS
    speaker: str = "Divya"
    # Pitch word, lower-cased and inserted into the voice description
    pitch: str = "Moderate"
    # Speaking-rate word, lower-cased and inserted into the voice description
    rate: str = "Moderate"
    # Sampling temperature; only forwarded to generation when do_sample is true
    temperature: float = 0.8
    # Whether generation samples (True) or not (False)
    do_sample: bool = True
def build_description(speaker_name, gender, pitch, rate):
    """Compose the natural-language voice description given to the model.

    When ``speaker_name`` is non-empty the sentence opens with that speaker's
    name; otherwise it opens with a generic speaker of the given ``gender``.
    ``pitch``, ``rate`` and ``gender`` are lower-cased before insertion.
    """
    if speaker_name:
        opening = f"{speaker_name}'s voice delivers a slightly expressive speech "
    else:
        opening = f"A {gender.lower()} speaker delivers a slightly expressive and clear speech "

    delivery = f"with a {pitch.lower()} pitch and a {rate.lower()} speaking rate. "
    quality = (
        "The recording is of very high quality, with the speaker's voice sounding clear "
        "and very close up. Very clear audio."
    )
    return opening + delivery + quality
def split_sentences(text):
    """Split Urdu text into sentences.

    Splits on runs of sentence-ending punctuation: the Urdu full stop (۔),
    the danda (।), the Urdu question mark (؟) and the ASCII ``.``, ``!``
    and ``?``. The delimiters are dropped, each piece is stripped, and empty
    pieces are discarded.

    Returns:
        A list of non-empty sentence strings (possibly empty).
    """
    pieces = re.split(r'[۔।؟.!?]+', text)
    return [piece.strip() for piece in pieces if piece.strip()]
def clean_urdu_text(text):
    """Minimal text cleaning - preserve content.

    Collapses every run of whitespace to a single space, strips the ends, and
    appends an Urdu full stop (۔) unless the text is empty or already ends in
    punctuation: ۔ . ! ? ، or the Urdu question mark (؟) or the danda (।).
    """
    text = re.sub(r'\s+', ' ', text).strip()
    # ؟ and । are treated as sentence terminators elsewhere in this file, so
    # they must not get an extra full stop appended after them.
    if text and text[-1] not in '۔.!?،؟।':
        text += '۔'
    return text
@spaces.GPU()
def generate_speech_internal(text, speaker, pitch, rate, temperature, do_sample):
    """Internal function for speech generation - sentence by sentence.

    The text is cleaned, split into sentences, and each sentence is generated
    separately with the same voice description and the same RNG seed. The
    pieces are joined with 0.3 s of silence between consecutive sentences.

    Returns:
        A 1-D ``np.int16`` array of PCM samples, or ``None`` when ``text`` is
        blank or generation raises (the error is printed, not propagated).
    """
    if not text.strip():
        return None
    try:
        text = clean_urdu_text(text)
        speaker_name = SPEAKERS.get(speaker, "")
        gender = "female" if "Female" in speaker or speaker in ["Divya", "Rani"] else "male"
        description = build_description(speaker_name, gender, pitch, rate)
        sentences = split_sentences(text) or [text.strip()]

        # The description is the same for every sentence, so tokenize it once.
        desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
        seed = torch.randint(0, 2**32, (1,)).item()
        silence = np.zeros(int(SAMPLE_RATE * 0.3), dtype=np.int16)

        all_audio = []
        for sentence in sentences:
            prompt_tokens = tokenizer(sentence, return_tensors="pt").to(DEVICE)
            # Re-seed before each sentence so all sentences start from the same RNG state.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            with torch.no_grad():
                generation = model.generate(
                    input_ids=desc_tokens.input_ids,
                    attention_mask=desc_tokens.attention_mask,
                    prompt_input_ids=prompt_tokens.input_ids,
                    prompt_attention_mask=prompt_tokens.attention_mask,
                    do_sample=do_sample,
                    temperature=temperature if do_sample else 1.0,
                    min_new_tokens=10,
                )
            # Promote to float32 (the model may run in fp16) and clip to [-1, 1]
            # so the int16 conversion cannot wrap around on out-of-range samples.
            samples = generation.cpu().numpy().squeeze().astype(np.float32)
            audio_chunk = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
            # Add 0.3s silence between sentences (not after the last one)
            if all_audio:
                all_audio.append(silence)
            all_audio.append(audio_chunk)

        if not all_audio:
            return None
        return np.concatenate(all_audio)
    except Exception as e:
        # Best-effort: report the failure and signal it to the caller with None.
        print(f"Error generating speech: {e}")
        import traceback
        traceback.print_exc()
        return None
@app.get("/")
async def root():
    """Health check endpoint.

    Returns static service metadata: the speaker list, sample rate, the
    available endpoints, the device in use and the precision actually applied
    (fp16 is only enabled when the model runs on CUDA).
    """
    return {
        "status": "ok",
        "model": "Indic Parler-TTS",
        "speakers": list(SPEAKERS.keys()),
        "sample_rate": SAMPLE_RATE,
        "endpoints": {
            "POST /tts": "Standard TTS (wait for full audio)",
            "POST /tts/stream": "HTTP Streaming TTS (audio chunks in real-time)",
            "WS /ws/tts": "WebSocket TTS (BEST FOR PIPECAT - true real-time bidirectional)",
            "GET /speakers": "List available speakers",
        },
        "device": DEVICE,
        # Half precision is only applied on CUDA; do not claim it on CPU.
        "optimization": "fp16 (half precision)" if DEVICE == "cuda" else "fp32 (full precision)",
    }
def generate_audio_chunks_streaming(text, speaker, pitch, rate, temperature, do_sample,
                                    chunk_seconds=0.5):
    """Generate audio using official ParlerTTSStreamer for true streaming.

    ``model.generate`` runs in a background daemon thread and pushes audio
    into the streamer; this generator yields each non-empty chunk as it
    becomes available.

    Args:
        text, speaker, pitch, rate, temperature, do_sample: same meaning as
            in the non-streaming generation functions of this file.
        chunk_seconds: approximate duration of audio per yielded chunk.

    Yields:
        1-D ``np.int16`` arrays of PCM samples.
    """
    text = clean_urdu_text(text)
    speaker_name = SPEAKERS.get(speaker, "")
    gender = "female" if "Female" in speaker or speaker in ["Divya", "Rani"] else "male"
    description = build_description(speaker_name, gender, pitch, rate)

    # play_steps counts decoder steps (codec frames), not audio samples, so it
    # must be derived from the audio encoder's frame rate. Deriving it from the
    # sampling rate asks for tens of thousands of steps per chunk, which
    # delays the first chunk until generation is essentially finished.
    frame_rate = model.audio_encoder.config.frame_rate
    play_steps = max(1, int(frame_rate * chunk_seconds))
    streamer = ParlerTTSStreamer(model, device=DEVICE, play_steps=play_steps)

    # Tokenize
    desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    prompt_tokens = tokenizer(text, return_tensors="pt").to(DEVICE)

    # Set seed
    seed = torch.randint(0, 2**32, (1,)).item()
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Generate in background thread
    generation_kwargs = dict(
        input_ids=desc_tokens.input_ids,
        attention_mask=desc_tokens.attention_mask,
        prompt_input_ids=prompt_tokens.input_ids,
        prompt_attention_mask=prompt_tokens.attention_mask,
        streamer=streamer,
        do_sample=do_sample,
        temperature=temperature if do_sample else 1.0,
        min_new_tokens=10,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.daemon = True
    thread.start()

    # Yield audio chunks as they're generated
    for audio_chunk in streamer:
        if audio_chunk.shape[0] > 0:
            # ParlerTTSStreamer yields numpy arrays directly, not tensors.
            # Promote to float32 and clip so the int16 conversion cannot wrap.
            samples = np.clip(np.asarray(audio_chunk, dtype=np.float32), -1.0, 1.0)
            yield (samples * 32767).astype(np.int16)
    thread.join()
async def generate_audio_stream(text, speaker, pitch, rate, temperature, do_sample):
    """Generate the full utterance sentence by sentence and return it as WAV bytes.

    Despite the name, nothing is streamed: every sentence is generated, the
    pieces are joined with 0.3 s of silence between consecutive sentences,
    and the complete WAV file is returned in one ``bytes`` object.
    Exceptions from tokenization or generation propagate to the caller.
    """
    text = clean_urdu_text(text)
    speaker_name = SPEAKERS.get(speaker, "")
    gender = "female" if "Female" in speaker or speaker in ["Divya", "Rani"] else "male"
    description = build_description(speaker_name, gender, pitch, rate)
    sentences = split_sentences(text) or [text.strip()]

    # The description is the same for every sentence, so tokenize it once.
    desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    seed = torch.randint(0, 2**32, (1,)).item()
    silence = np.zeros(int(SAMPLE_RATE * 0.3), dtype=np.int16)

    # Collect all audio chunks
    all_audio = []
    for sentence in sentences:
        prompt_tokens = tokenizer(sentence, return_tensors="pt").to(DEVICE)
        # Re-seed before each sentence so all sentences start from the same RNG state.
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        with torch.no_grad():
            generation = model.generate(
                input_ids=desc_tokens.input_ids,
                attention_mask=desc_tokens.attention_mask,
                prompt_input_ids=prompt_tokens.input_ids,
                prompt_attention_mask=prompt_tokens.attention_mask,
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                min_new_tokens=10,
            )
        # Promote to float32 (the model may run in fp16) and clip to [-1, 1]
        # so the int16 conversion cannot wrap around on out-of-range samples.
        samples = generation.cpu().numpy().squeeze().astype(np.float32)
        audio_chunk = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
        # Add 0.3s silence between sentences (not after the last one)
        if all_audio:
            all_audio.append(silence)
        all_audio.append(audio_chunk)

    # Convert to WAV
    audio = np.concatenate(all_audio)
    audio_buffer = BytesIO()
    sf.write(audio_buffer, audio, SAMPLE_RATE, format='WAV')
    return audio_buffer.getvalue()
def _streaming_wav_header(sample_rate, channels=1, sample_width=2):
    """Build a 44-byte PCM WAV header for a stream whose length is unknown."""
    import struct

    # The total size is not known up front, so both size fields are set to
    # the maximum 32-bit value - a common convention for streamed WAV data.
    unknown_size = 0xFFFFFFFF
    block_align = channels * sample_width
    return struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", unknown_size, b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate,
        sample_rate * block_align, block_align, sample_width * 8,
        b"data", unknown_size,
    )


@app.post("/tts/stream")
async def text_to_speech_streaming(request: TTSRequest):
    """Generate speech with real-time streaming (fastest latency).

    The response body is one WAV header followed by raw 16-bit mono PCM,
    sent chunk by chunk as the model produces audio.

    Raises:
        HTTPException(400): empty text or unknown speaker.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if request.speaker not in SPEAKERS:
        raise HTTPException(status_code=400, detail=f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}")

    # A plain (sync) generator on purpose: StreamingResponse iterates sync
    # generators in a worker thread, so the blocking wait for each audio
    # chunk does not stall the event loop.
    def audio_generator():
        """Yield the WAV header once, then each chunk's PCM bytes."""
        try:
            wav_header_written = False
            for audio_chunk in generate_audio_chunks_streaming(
                request.text,
                request.speaker,
                request.pitch,
                request.rate,
                request.temperature,
                request.do_sample,
            ):
                if not wav_header_written:
                    # One header for the whole stream; writing a complete WAV
                    # for the first chunk would declare only that chunk's length.
                    yield _streaming_wav_header(SAMPLE_RATE)
                    wav_header_written = True
                # WAV PCM is little-endian 16-bit
                yield audio_chunk.astype("<i2", copy=False).tobytes()
        except Exception as e:
            # Response headers are already sent; log and end the stream.
            print(f"Error in streaming: {e}")
            import traceback
            traceback.print_exc()

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Content-Disposition": "inline; filename=speech.wav"},
    )
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Generate speech from Urdu text and return it as a WAV attachment.

    Raises:
        HTTPException(400): empty text or unknown speaker.
        HTTPException(500): generation raised or produced no audio.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if request.speaker not in SPEAKERS:
        raise HTTPException(status_code=400, detail=f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}")

    # Only the generation call is guarded, so the HTTPException raised below
    # is not caught and re-wrapped by this handler's own except clause.
    try:
        audio_data = await generate_audio_stream(
            request.text,
            request.speaker,
            request.pitch,
            request.rate,
            request.temperature,
            request.do_sample,
        )
    except Exception as e:
        print(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if audio_data is None:
        raise HTTPException(status_code=500, detail="Failed to generate speech")
    return StreamingResponse(
        iter([audio_data]),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=speech.wav"},
    )
@app.get("/speakers")
async def get_speakers():
    """Return the labels of all selectable speakers."""
    speaker_names = list(SPEAKERS)
    return {"speakers": speaker_names}
@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time audio streaming.

    Usage:
    1. Connect to ws://localhost:7860/ws/tts
    2. Send JSON: {"text": "سلام دنیا", "speaker": "Divya"}
    3. Receive audio chunks in real-time
    4. Connection closes when generation completes

    Text frames carry JSON status or error messages; binary frames carry
    the raw bytes of each int16 audio chunk.
    """
    import asyncio

    await websocket.accept()
    try:
        data = await websocket.receive_text()
        try:
            request_data = json.loads(data)
        except json.JSONDecodeError:
            await websocket.send_json({"error": "Invalid JSON format"})
            await websocket.close()
            return
        # Valid JSON that is not an object (a list, a number, ...) has no fields to read.
        if not isinstance(request_data, dict):
            await websocket.send_json({"error": "Invalid JSON format"})
            await websocket.close()
            return

        raw_text = request_data.get("text", "")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        speaker = request_data.get("speaker", "Divya")
        pitch = request_data.get("pitch", "Moderate")
        rate = request_data.get("rate", "Moderate")
        temperature = request_data.get("temperature", 0.8)
        do_sample = request_data.get("do_sample", True)

        # Validate inputs
        if not text:
            await websocket.send_json({"error": "Text cannot be empty"})
            await websocket.close()
            return
        if speaker not in SPEAKERS:
            await websocket.send_json({
                "error": f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}"
            })
            await websocket.close()
            return

        # Send status message
        await websocket.send_json({
            "status": "generating",
            "message": f"Generating audio for speaker {speaker}..."
        })

        # Generate and stream audio chunks. Each blocking next() on the
        # generator runs in the default executor so the event loop stays
        # responsive while the model is producing audio.
        loop = asyncio.get_running_loop()
        chunks = generate_audio_chunks_streaming(
            text, speaker, pitch, rate, temperature, do_sample
        )
        exhausted = object()  # sentinel: next() returns it instead of raising StopIteration
        chunk_count = 0
        try:
            while True:
                audio_chunk = await loop.run_in_executor(None, next, chunks, exhausted)
                if audio_chunk is exhausted:
                    break
                # Send audio chunk as binary data
                await websocket.send_bytes(audio_chunk.tobytes())
                chunk_count += 1
        except WebSocketDisconnect:
            # The client is gone; nothing more can be sent to it.
            raise
        except Exception as e:
            await websocket.send_json({
                "error": f"Generation failed: {str(e)}"
            })
        else:
            # Send completion message
            await websocket.send_json({
                "status": "complete",
                "chunks_sent": chunk_count
            })
        await websocket.close()
    except WebSocketDisconnect:
        print("WebSocket client disconnected")
    except Exception as e:
        print(f"WebSocket error: {e}")
        try:
            await websocket.close()
        except Exception:
            # The socket is already closed or broken; closing is best-effort.
            pass
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)