"""FastAPI service exposing the Indic Parler-TTS model.

Endpoints:
    GET  /            health check / service description
    GET  /speakers    list of selectable speakers
    POST /tts         full-utterance synthesis, returned as one WAV file
    POST /tts/stream  chunked HTTP streaming of 16-bit mono PCM in a WAV wrapper
    WS   /ws/tts      raw 16-bit PCM chunks over a WebSocket
"""
import json
import os
import queue
import re
import struct
import traceback
from io import BytesIO
from threading import Thread

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool
from fastapi.responses import StreamingResponse
from huggingface_hub import login
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from pydantic import BaseModel
from transformers import AutoTokenizer

# Authenticate with HuggingFace if token is available
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Try to import spaces for HF Spaces deployment
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class _NoOpSpaces:
        """Stand-in for the `spaces` module whose GPU() decorator does nothing."""

        def GPU(self, *args, **kwargs):
            def decorator(fn):
                return fn
            return decorator

    spaces = _NoOpSpaces()

# --- Model Loading ---
MODEL_ID = "ai4bharat/indic-parler-tts"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

print("Loading Indic Parler-TTS model...")
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)

# Optimize model for inference
if DEVICE == "cuda":
    model = model.half()  # Use half precision (fp16) for faster inference
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
SAMPLE_RATE = model.config.sampling_rate

# Disable gradients for inference
torch.set_grad_enabled(False)

print("Model loaded and optimized!")

# Named speakers: display name -> name used in the description prompt
# (an empty string selects the generic, gender-based description).
SPEAKERS = {
    "Divya": "Divya",
    "Rani": "Rani",
    "Rohit": "Rohit",
    "Aman": "Aman",
    "Generic Female": "",
    "Generic Male": "",
}

# Speakers that are described as female when the generic prompt is used.
_FEMALE_SPEAKERS = ("Divya", "Rani")

# Seconds of silence appended after every synthesized sentence.
_SENTENCE_PAUSE_S = 0.3

# Sentence terminators: Urdu full stop, danda, and ASCII . ! ?
_SENTENCE_SPLIT_RE = re.compile(r'[۔।\.\!\?]+')

app = FastAPI(title="Parler TTS API", version="1.0")


class TTSRequest(BaseModel):
    """Request body shared by the HTTP TTS endpoints."""

    text: str
    speaker: str = "Divya"
    pitch: str = "Moderate"
    rate: str = "Moderate"
    temperature: float = 0.8
    do_sample: bool = True


def build_description(speaker_name, gender, pitch, rate):
    """Build voice description prompt.

    Args:
        speaker_name: Named speaker to put in the prompt; when empty, a generic
            prompt based on ``gender`` is produced instead.
        gender: Gender word used only for the generic prompt.
        pitch: Pitch descriptor (lower-cased in the prompt).
        rate: Speaking-rate descriptor (lower-cased in the prompt).

    Returns:
        The description string fed to the description tokenizer.
    """
    if speaker_name:
        return (
            f"{speaker_name}'s voice delivers a slightly expressive speech "
            f"with a {pitch.lower()} pitch and a {rate.lower()} speaking rate. "
            f"The recording is of very high quality, with the speaker's voice sounding clear "
            f"and very close up. Very clear audio."
        )
    else:
        return (
            f"A {gender.lower()} speaker delivers a slightly expressive and clear speech "
            f"with a {pitch.lower()} pitch and a {rate.lower()} speaking rate. "
            f"The recording is of very high quality, with the speaker's voice sounding clear "
            f"and very close up. Very clear audio."
        )


def split_sentences(text):
    """Split Urdu text into non-empty, stripped sentences.

    The terminators themselves are dropped from the returned sentences.
    """
    return [part.strip() for part in _SENTENCE_SPLIT_RE.split(text) if part.strip()]


def clean_urdu_text(text):
    """Minimal text cleaning - preserve content.

    Collapses whitespace runs to single spaces and appends an Urdu full stop
    when the text does not already end in a terminator or comma.
    """
    text = re.sub(r'\s+', ' ', text).strip()
    if text and text[-1] not in '۔.!?،':
        text += '۔'
    return text


def _voice_description(speaker, pitch, rate):
    """Resolve a speaker key to its description prompt."""
    speaker_name = SPEAKERS.get(speaker, "")
    is_female = "Female" in speaker or speaker in _FEMALE_SPEAKERS
    return build_description(speaker_name, "female" if is_female else "male", pitch, rate)


def _to_int16_pcm(audio):
    """Convert float audio (tensor or array) to flat 16-bit PCM samples.

    The conversion is done in float32 with clipping to [-1, 1]: scaling fp16
    samples directly rounds full-scale values to 32768, which wraps around to
    -32768 when cast to int16.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().to(torch.float32).cpu().numpy()
    samples = np.asarray(audio, dtype=np.float32).reshape(-1)
    return (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)


def _seed_everything(seed):
    """Seed the torch CPU generator and, when available, the CUDA generators."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def _new_seed():
    """Draw a fresh random seed for one request."""
    return torch.randint(0, 2**32, (1,)).item()


def _streaming_wav_header(sample_rate, channels=1, sample_width=2):
    """Return a 44-byte PCM WAV header for a stream of unknown length.

    Both size fields are set to 0xFFFFFFFF because the total length is not
    known when streaming starts.
    NOTE(review): this "unknown length" marker is a widely used convention
    rather than part of the RIFF spec — confirm the target players accept it.
    """
    byte_rate = sample_rate * channels * sample_width
    block_align = channels * sample_width
    return struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 0xFFFFFFFF, b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate, byte_rate, block_align, sample_width * 8,
        b"data", 0xFFFFFFFF,
    )


@spaces.GPU()
def generate_speech_internal(text, speaker, pitch, rate, temperature, do_sample):
    """Synthesize ``text`` sentence by sentence and return int16 PCM samples.

    Every sentence is generated with the same seed and is followed by a short
    silence. This function blocks for the whole generation.

    Returns:
        A 1-D ``np.int16`` array at ``SAMPLE_RATE``, or ``None`` when the text
        is blank or generation fails (the error is printed, not raised).
    """
    if not text.strip():
        return None

    try:
        text = clean_urdu_text(text)
        description = _voice_description(speaker, pitch, rate)
        sentences = split_sentences(text) or [text.strip()]

        # The description is identical for every sentence: tokenize it once.
        desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
        silence = np.zeros(int(SAMPLE_RATE * _SENTENCE_PAUSE_S), dtype=np.int16)
        seed = _new_seed()

        all_audio = []
        for sentence in sentences:
            prompt_tokens = tokenizer(sentence, return_tensors="pt").to(DEVICE)

            # Re-seed before each sentence so all sentences share one seed.
            _seed_everything(seed)

            with torch.no_grad():
                generation = model.generate(
                    input_ids=desc_tokens.input_ids,
                    attention_mask=desc_tokens.attention_mask,
                    prompt_input_ids=prompt_tokens.input_ids,
                    prompt_attention_mask=prompt_tokens.attention_mask,
                    do_sample=do_sample,
                    temperature=temperature if do_sample else 1.0,
                    min_new_tokens=10,
                )

            all_audio.append(_to_int16_pcm(generation))
            # Add 0.3s silence between sentences
            all_audio.append(silence)

        if not all_audio:
            return None

        return np.concatenate(all_audio)

    except Exception as e:
        print(f"Error generating speech: {e}")
        traceback.print_exc()
        return None


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "ok",
        "model": "Indic Parler-TTS",
        "speakers": list(SPEAKERS.keys()),
        "sample_rate": SAMPLE_RATE,
        "endpoints": {
            "POST /tts": "Standard TTS (wait for full audio)",
            "POST /tts/stream": "HTTP Streaming TTS (audio chunks in real-time)",
            "WS /ws/tts": "WebSocket TTS (BEST FOR PIPECAT - true real-time bidirectional)",
            "GET /speakers": "List available speakers"
        },
        "device": DEVICE,
        # Half precision is only enabled when the model runs on CUDA.
        "optimization": "fp16 (half precision)" if DEVICE == "cuda" else "none (fp32)"
    }


def generate_audio_chunks_streaming(text, speaker, pitch, rate, temperature, do_sample):
    """Generate audio using official ParlerTTSStreamer for true streaming.

    Generation runs in a background thread while this (blocking) generator
    yields 1-D ``np.int16`` chunks as they become available.

    Raises:
        Exception: whatever ``model.generate`` raised in the background
            thread, re-raised once the stream has been drained.
    """
    text = clean_urdu_text(text)
    description = _voice_description(speaker, pitch, rate)

    # play_steps is counted in decoder frames, so derive it from the audio
    # encoder's frame rate (not from the audio sampling rate).
    frame_rate = model.audio_encoder.config.frame_rate
    play_steps = int(frame_rate * 0.5)  # 0.5 second chunks
    streamer = ParlerTTSStreamer(model, device=DEVICE, play_steps=play_steps)

    # Tokenize
    desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    prompt_tokens = tokenizer(text, return_tensors="pt").to(DEVICE)

    # Set seed
    _seed_everything(_new_seed())

    generation_kwargs = dict(
        input_ids=desc_tokens.input_ids,
        attention_mask=desc_tokens.attention_mask,
        prompt_input_ids=prompt_tokens.input_ids,
        prompt_attention_mask=prompt_tokens.attention_mask,
        streamer=streamer,
        do_sample=do_sample,
        temperature=temperature if do_sample else 1.0,
        min_new_tokens=10,
    )

    errors = []

    def _run_generation():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Without an end-of-stream signal the consumer loop below would
            # wait forever; end() is the streamer's way of signalling it.
            try:
                streamer.end()
            except Exception:
                traceback.print_exc()

    # Generate in background thread
    thread = Thread(target=_run_generation)
    thread.daemon = True
    thread.start()

    # Yield audio chunks as they're generated
    for audio_chunk in streamer:
        if audio_chunk.shape[0] > 0:
            yield _to_int16_pcm(audio_chunk)

    thread.join()
    if errors:
        raise errors[0]


async def generate_audio_stream(text, speaker, pitch, rate, temperature, do_sample):
    """Synthesize the whole text and return it as WAV bytes.

    The blocking synthesis is delegated to ``generate_speech_internal`` and
    run in the threadpool so the event loop stays responsive.

    Returns:
        WAV file contents as ``bytes``, or ``None`` when synthesis failed.
    """
    audio = await run_in_threadpool(
        generate_speech_internal, text, speaker, pitch, rate, temperature, do_sample
    )
    if audio is None:
        return None

    # Convert to WAV
    audio_buffer = BytesIO()
    sf.write(audio_buffer, audio, SAMPLE_RATE, format='WAV')
    return audio_buffer.getvalue()


@app.post("/tts/stream")
async def text_to_speech_streaming(request: TTSRequest):
    """Generate speech with real-time streaming (fastest latency)."""
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if request.speaker not in SPEAKERS:
        raise HTTPException(status_code=400, detail=f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}")

    def audio_generator():
        """Yield a streaming WAV header followed by raw PCM chunks.

        This is a plain (sync) generator on purpose: StreamingResponse
        iterates sync iterables in a threadpool, so the blocking model
        generation does not stall the event loop.
        """
        try:
            # A single header with unknown-length markers covers all chunks.
            yield _streaming_wav_header(SAMPLE_RATE)
            for audio_chunk in generate_audio_chunks_streaming(
                request.text,
                request.speaker,
                request.pitch,
                request.rate,
                request.temperature,
                request.do_sample
            ):
                yield audio_chunk.tobytes()
        except Exception as e:
            # The response has already started, so the stream is simply cut.
            print(f"Error in streaming: {e}")
            traceback.print_exc()

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Content-Disposition": "inline; filename=speech.wav"}
    )


@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Generate speech from Urdu text.

    Raises:
        HTTPException: 400 for blank text or an unknown speaker, 500 when
            synthesis fails.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if request.speaker not in SPEAKERS:
        raise HTTPException(status_code=400, detail=f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}")

    try:
        audio_data = await generate_audio_stream(
            request.text,
            request.speaker,
            request.pitch,
            request.rate,
            request.temperature,
            request.do_sample
        )
    except Exception as e:
        print(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try so it is not caught and re-wrapped above.
    if audio_data is None:
        raise HTTPException(status_code=500, detail="Failed to generate speech")

    return StreamingResponse(
        iter([audio_data]),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=speech.wav"}
    )


@app.get("/speakers")
async def get_speakers():
    """Get list of available speakers."""
    return {"speakers": list(SPEAKERS.keys())}


async def _close_quietly(websocket):
    """Close the WebSocket, ignoring the error raised if it is already closed."""
    try:
        await websocket.close()
    except (RuntimeError, WebSocketDisconnect):
        # Already closed or the peer is gone: nothing left to do.
        pass


@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time audio streaming.

    Usage:
    1. Connect to ws://localhost:7860/ws/tts
    2. Send JSON: {"text": "سلام دنیا", "speaker": "Divya"}
    3. Receive audio chunks in real-time
    4. Connection closes when generation completes
    """
    await websocket.accept()

    try:
        data = await websocket.receive_text()
        try:
            request_data = json.loads(data)
        except json.JSONDecodeError:
            await websocket.send_json({"error": "Invalid JSON format"})
            return

        # The payload must be a JSON object; arrays/scalars have no fields.
        if not isinstance(request_data, dict):
            await websocket.send_json({"error": "Invalid JSON format"})
            return

        text = request_data.get("text", "")
        speaker = request_data.get("speaker", "Divya")
        pitch = request_data.get("pitch", "Moderate")
        rate = request_data.get("rate", "Moderate")
        temperature = request_data.get("temperature", 0.8)
        do_sample = request_data.get("do_sample", True)

        # Validate inputs
        if not isinstance(text, str) or not text.strip():
            await websocket.send_json({"error": "Text cannot be empty"})
            return
        text = text.strip()

        if speaker not in SPEAKERS:
            await websocket.send_json({
                "error": f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}"
            })
            return

        # Send status message
        await websocket.send_json({
            "status": "generating",
            "message": f"Generating audio for speaker {speaker}..."
        })

        # Generate and stream audio chunks
        chunk_count = 0
        try:
            # The generator blocks on model output, so pull each chunk in the
            # threadpool instead of on the event loop.
            chunks = generate_audio_chunks_streaming(
                text, speaker, pitch, rate, temperature, do_sample
            )
            async for audio_chunk in iterate_in_threadpool(chunks):
                # Send audio chunk as binary data
                await websocket.send_bytes(audio_chunk.tobytes())
                chunk_count += 1
        except WebSocketDisconnect:
            # Nobody left to report to; let the outer handler log it.
            raise
        except Exception as e:
            traceback.print_exc()
            await websocket.send_json({
                "error": f"Generation failed: {str(e)}"
            })
        else:
            # Send completion message
            await websocket.send_json({
                "status": "complete",
                "chunks_sent": chunk_count
            })

    except WebSocketDisconnect:
        print("WebSocket client disconnected")
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        await _close_quietly(websocket)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)