Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import StreamingResponse | |
| import json | |
| import torch | |
| import numpy as np | |
| import re | |
| from io import BytesIO | |
| import soundfile as sf | |
| from pydantic import BaseModel | |
| import os | |
| from huggingface_hub import login | |
| from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer | |
| from transformers import AutoTokenizer | |
| from threading import Thread | |
| import queue | |
# Log in to the Hugging Face Hub only when HF_TOKEN holds a non-empty value.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Try to import `spaces` (Hugging Face Spaces helpers). When it is not
# installed, fall back to a no-op stand-in so `spaces.GPU` decorators still work.
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class _NoOpSpaces:
        """Stand-in for the `spaces` module when running outside HF Spaces."""

        def GPU(self, *args, **kwargs):
            """Return the function unchanged, in either decorator form.

            Supports the bare form (``@spaces.GPU``), where the decorated
            function is passed directly, and the called form
            (``@spaces.GPU(...)``), which must return a decorator.
            """
            # Bare usage: the only argument is the function being decorated.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _NoOpSpaces()
# --- Model Loading ---
# Runs once at import time: loads the Indic Parler-TTS checkpoint and both
# tokenizers into module-level names shared by every request handler.
MODEL_ID = "ai4bharat/indic-parler-tts"
# Use CUDA when PyTorch can see a GPU, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("Loading Indic Parler-TTS model...")
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
# Optimize model for inference
if DEVICE == "cuda":
    # The fp16 cast happens only on GPU; on CPU the model keeps the dtype
    # from_pretrained gave it.
    # NOTE(review): generated audio is then presumably fp16 as well; callers
    # should upcast before scaling to int16 -- confirm generate() output dtype.
    model = model.half()  # Use half precision (fp16) for faster inference
model.eval()
# Tokenizer for the text to be spoken (passed as prompt_input_ids to generate()).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Separate tokenizer for the voice-description text, resolved from the
# model's configured text encoder (passed as input_ids to generate()).
description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
# Output sample rate from the model config; used when writing WAV data and
# sizing inter-sentence silence.
SAMPLE_RATE = model.config.sampling_rate
# Disable gradients for inference
# NOTE(review): PyTorch grad mode is thread-local, so this call presumably does
# not affect worker threads (e.g. the streaming generation Thread); those rely on
# explicit no_grad / generate()'s own handling -- confirm.
torch.set_grad_enabled(False)
print("Model loaded and optimized!")
# Display name -> speaker name embedded in the voice-description prompt.
# Named voices map to themselves; an empty value selects a generic voice.
SPEAKERS = {
    **{voice: voice for voice in ("Divya", "Rani", "Rohit", "Aman")},
    "Generic Female": "",
    "Generic Male": "",
}
# FastAPI application instance for this service.
app = FastAPI(title="Parler TTS API", version="1.0")


class TTSRequest(BaseModel):
    """JSON body accepted by the HTTP TTS handlers; omitted fields use these defaults."""

    # Text to synthesize; handlers reject empty/whitespace-only text with HTTP 400.
    text: str
    # Key into SPEAKERS; handlers reject unknown names with HTTP 400.
    speaker: str = "Divya"
    # Pitch word, lowercased into the voice-description prompt.
    pitch: str = "Moderate"
    # Speaking-rate word, lowercased into the voice-description prompt.
    rate: str = "Moderate"
    # Sampling temperature; generation uses 1.0 instead when do_sample is False.
    temperature: float = 0.8
    # Whether model.generate samples (True) or decodes without sampling (False).
    do_sample: bool = True
def build_description(speaker_name, gender, pitch, rate):
    """Return the natural-language voice description fed to the text encoder.

    Args:
        speaker_name: Named voice to request; falsy selects a generic voice.
        gender: Used only for the generic voice (lowercased into the text).
        pitch: Pitch word, lowercased into the text.
        rate: Speaking-rate word, lowercased into the text.

    Returns:
        The description string.
    """
    if not speaker_name:
        opening = f"A {gender.lower()} speaker delivers a slightly expressive and clear speech "
    else:
        opening = f"{speaker_name}'s voice delivers a slightly expressive speech "
    prosody = f"with a {pitch.lower()} pitch and a {rate.lower()} speaking rate. "
    quality = (
        "The recording is of very high quality, with the speaker's voice sounding clear "
        "and very close up. Very clear audio."
    )
    return opening + prosody + quality
# Runs of sentence-ending punctuation: Urdu full stop (۔), Devanagari danda (।),
# Arabic-script question mark (؟, the Urdu question mark) and ASCII . ! ?
_SENTENCE_END_RE = re.compile(r"[۔।؟.!?]+")


def split_sentences(text):
    """Split Urdu text into sentences.

    Splits on runs of sentence-ending punctuation, which is dropped, and
    keeps the non-empty, whitespace-stripped pieces in order.

    Args:
        text: Input text.

    Returns:
        List of sentences without their terminators. The list is empty when
        ``text`` holds nothing but whitespace and terminators.
    """
    return [part.strip() for part in _SENTENCE_END_RE.split(text) if part.strip()]
def clean_urdu_text(text):
    """Minimal text cleaning - preserve content.

    Collapses every whitespace run to a single space and trims both ends.
    Then appends the Urdu full stop ``۔`` unless the text is empty or already
    ends with sentence-ending punctuation (or the Arabic comma ``،``).

    Args:
        text: Raw input text.

    Returns:
        The normalized text.
    """
    text = re.sub(r"\s+", " ", text).strip()
    # The terminator set includes the Arabic question mark (؟) and the
    # Devanagari danda (।), which split_sentences also treats as sentence ends.
    # Without them, "…؟" would become "…؟۔".
    if text and text[-1] not in "۔।؟.!?،":
        text += "۔"
    return text
def generate_speech_internal(text, speaker, pitch, rate, temperature, do_sample):
    """Synthesize ``text`` sentence by sentence and return 16-bit PCM samples.

    Every sentence is generated with the same random seed, so the voice stays
    consistent across sentences. 0.3 s of silence is placed between
    consecutive sentences.

    Args:
        text: Text to speak; whitespace-only input returns ``None``.
        speaker: Key into ``SPEAKERS``; unknown keys fall back to a generic voice.
        pitch: Pitch word for the voice description.
        rate: Speaking-rate word for the voice description.
        temperature: Sampling temperature (1.0 is used when not sampling).
        do_sample: Whether ``model.generate`` samples.

    Returns:
        1-D ``np.int16`` array at ``SAMPLE_RATE``. Returns ``None`` for empty
        input or when generation fails (the traceback is printed).
    """
    if not text.strip():
        return None
    try:
        text = clean_urdu_text(text)
        speaker_name = SPEAKERS.get(speaker, "")
        gender = "female" if "Female" in speaker or speaker in ["Divya", "Rani"] else "male"
        description = build_description(speaker_name, gender, pitch, rate)
        sentences = split_sentences(text) or [text.strip()]

        # The description is identical for every sentence: tokenize it once.
        desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
        silence = np.zeros(int(SAMPLE_RATE * 0.3), dtype=np.int16)
        seed = torch.randint(0, 2**32, (1,)).item()

        all_audio = []
        for index, sentence in enumerate(sentences):
            if index:
                # Pause between consecutive sentences (not after the last one).
                all_audio.append(silence)
            prompt_tokens = tokenizer(sentence, return_tensors="pt").to(DEVICE)
            # Re-seed per sentence so each one is sampled with the same voice.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            with torch.no_grad():
                generation = model.generate(
                    input_ids=desc_tokens.input_ids,
                    attention_mask=desc_tokens.attention_mask,
                    prompt_input_ids=prompt_tokens.input_ids,
                    prompt_attention_mask=prompt_tokens.attention_mask,
                    do_sample=do_sample,
                    temperature=temperature if do_sample else 1.0,
                    min_new_tokens=10,
                )
            # Upcast before scaling: with the fp16 model, 32767 is not representable
            # in fp16 (it rounds to 32768 and overflows int16). Clip so peaks slightly
            # above full scale cannot wrap around. reshape(-1) avoids the 0-d array
            # squeeze() would give for a single sample.
            samples = generation.cpu().float().numpy().reshape(-1)
            all_audio.append((np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16))

        if not all_audio:
            return None
        return np.concatenate(all_audio)
    except Exception as e:
        print(f"Error generating speech: {e}")
        import traceback
        traceback.print_exc()
        return None
@app.get("/")
async def root():
    """Health check endpoint.

    Reports service status, model name, speaker names, output sample rate,
    the endpoint catalogue, the compute device and the precision in use.
    """
    return {
        "status": "ok",
        "model": "Indic Parler-TTS",
        "speakers": list(SPEAKERS.keys()),
        "sample_rate": SAMPLE_RATE,
        "endpoints": {
            "POST /tts": "Standard TTS (wait for full audio)",
            "POST /tts/stream": "HTTP Streaming TTS (audio chunks in real-time)",
            "WS /ws/tts": "WebSocket TTS (BEST FOR PIPECAT - true real-time bidirectional)",
            "GET /speakers": "List available speakers"
        },
        "device": DEVICE,
        # The model is cast to fp16 only on CUDA; on CPU it keeps its loaded dtype.
        "optimization": (
            "fp16 (half precision)" if DEVICE == "cuda" else "none (default precision)"
        ),
    }
def generate_audio_chunks_streaming(text, speaker, pitch, rate, temperature, do_sample,
                                    chunk_seconds=0.5):
    """Generate audio using official ParlerTTSStreamer for true streaming.

    ``model.generate`` runs in a daemon thread that feeds a
    ``ParlerTTSStreamer``. This generator yields each non-empty chunk the
    streamer produces, converted to 16-bit PCM.

    Args:
        text: Text to speak (the whole text is generated in one pass).
        speaker: Key into ``SPEAKERS``; unknown keys fall back to a generic voice.
        pitch: Pitch word for the voice description.
        rate: Speaking-rate word for the voice description.
        temperature: Sampling temperature (1.0 is used when not sampling).
        do_sample: Whether ``model.generate`` samples.
        chunk_seconds: Approximate audio duration per yielded chunk.

    Yields:
        ``np.int16`` arrays of PCM samples.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has ended.
    """
    text = clean_urdu_text(text)
    speaker_name = SPEAKERS.get(speaker, "")
    gender = "female" if "Female" in speaker or speaker in ["Divya", "Rani"] else "male"
    description = build_description(speaker_name, gender, pitch, rate)

    # play_steps counts decoder steps (audio-codec frames), not audio samples,
    # so derive it from the codec frame rate. Deriving it from the sampling rate
    # asked for far more steps than an utterance has, so no audio was emitted
    # until generation had finished.
    frame_rate = model.audio_encoder.config.frame_rate
    play_steps = max(1, int(frame_rate * chunk_seconds))
    streamer = ParlerTTSStreamer(model, device=DEVICE, play_steps=play_steps)

    desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    prompt_tokens = tokenizer(text, return_tensors="pt").to(DEVICE)

    seed = torch.randint(0, 2**32, (1,)).item()
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generation_kwargs = dict(
        input_ids=desc_tokens.input_ids,
        attention_mask=desc_tokens.attention_mask,
        prompt_input_ids=prompt_tokens.input_ids,
        prompt_attention_mask=prompt_tokens.attention_mask,
        streamer=streamer,
        do_sample=do_sample,
        temperature=temperature if do_sample else 1.0,
        min_new_tokens=10,
    )

    worker_errors = []

    def _run_generation():
        """Run generate(); on failure, record the error and end the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            worker_errors.append(exc)
            # Without a stop signal the `for ... in streamer` loop blocks forever.
            # NOTE(review): relies on ParlerTTSStreamer.on_finalized_audio(audio,
            # stream_end=True) -- confirm against the installed parler_tts.
            streamer.on_finalized_audio(np.zeros(0, dtype=np.float32), stream_end=True)

    thread = Thread(target=_run_generation, daemon=True)
    thread.start()

    for audio_chunk in streamer:
        if audio_chunk.shape[0] > 0:
            # Upcast and clip so the int16 conversion cannot overflow or wrap.
            samples = np.asarray(audio_chunk, dtype=np.float32)
            yield (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    thread.join()
    if worker_errors:
        raise worker_errors[0]
def _render_wav_sentences(text, speaker, pitch, rate, temperature, do_sample):
    """Blocking worker for generate_audio_stream: synthesize every sentence, encode a WAV."""
    text = clean_urdu_text(text)
    speaker_name = SPEAKERS.get(speaker, "")
    gender = "female" if "Female" in speaker or speaker in ["Divya", "Rani"] else "male"
    description = build_description(speaker_name, gender, pitch, rate)
    sentences = split_sentences(text) or [text.strip()]

    # The description is identical for every sentence: tokenize it once.
    desc_tokens = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    silence = np.zeros(int(SAMPLE_RATE * 0.3), dtype=np.int16)
    seed = torch.randint(0, 2**32, (1,)).item()

    all_audio = []
    for index, sentence in enumerate(sentences):
        if index:
            # Pause between consecutive sentences (not after the last one).
            all_audio.append(silence)
        prompt_tokens = tokenizer(sentence, return_tensors="pt").to(DEVICE)
        # Re-seed per sentence so each one is sampled with the same voice.
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        # Explicit no_grad: this runs in an executor thread, and grad mode is thread-local.
        with torch.no_grad():
            generation = model.generate(
                input_ids=desc_tokens.input_ids,
                attention_mask=desc_tokens.attention_mask,
                prompt_input_ids=prompt_tokens.input_ids,
                prompt_attention_mask=prompt_tokens.attention_mask,
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                min_new_tokens=10,
            )
        # Upcast (fp16 cannot represent 32767) and clip before converting to int16.
        samples = generation.cpu().float().numpy().reshape(-1)
        all_audio.append((np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16))

    audio_buffer = BytesIO()
    sf.write(audio_buffer, np.concatenate(all_audio), SAMPLE_RATE, format='WAV')
    return audio_buffer.getvalue()


async def generate_audio_stream(text, speaker, pitch, rate, temperature, do_sample):
    """Synthesize the full text sentence by sentence and return WAV file bytes.

    The blocking model work runs in the default thread-pool executor. This
    keeps the event loop free to serve other requests and WebSocket streams
    in the meantime.

    Args:
        text: Text to speak.
        speaker: Key into ``SPEAKERS``; unknown keys fall back to a generic voice.
        pitch: Pitch word for the voice description.
        rate: Speaking-rate word for the voice description.
        temperature: Sampling temperature (1.0 is used when not sampling).
        do_sample: Whether ``model.generate`` samples.

    Returns:
        Bytes of a complete WAV file of int16 samples at ``SAMPLE_RATE``.

    Raises:
        Exception: any tokenizer/model/encoding error, propagated to the caller.
    """
    import asyncio
    import functools

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        functools.partial(
            _render_wav_sentences, text, speaker, pitch, rate, temperature, do_sample
        ),
    )
@app.post("/tts/stream")
async def text_to_speech_streaming(request: TTSRequest):
    """Generate speech with real-time streaming (fastest latency).

    The response body is a WAV header whose RIFF and data size fields are
    0xFFFFFFFF ("length unknown"). It is followed by raw 16-bit PCM bytes for
    each chunk as the model produces it.

    Raises:
        HTTPException: 400 for empty text or an unknown speaker.
    """
    import struct

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if request.speaker not in SPEAKERS:
        raise HTTPException(status_code=400, detail=f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}")

    def audio_generator():
        """Yield a streaming WAV header, then PCM bytes per generated chunk.

        This is a plain (sync) generator, so Starlette iterates it in a worker
        thread. Waiting on the model therefore does not block the event loop.
        """
        channels = 1
        sample_width = 2  # bytes per int16 sample
        framerate = int(SAMPLE_RATE)
        # A WAV file written for the first chunk alone would carry size fields
        # covering only that chunk, and clients honoring them would drop the rest.
        # Use the "unknown length" sizes instead.
        yield struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF", 0xFFFFFFFF, b"WAVE",
            b"fmt ", 16, 1, channels, framerate,
            framerate * channels * sample_width,  # byte rate
            channels * sample_width,              # block align
            sample_width * 8,                     # bits per sample
            b"data", 0xFFFFFFFF,
        )
        try:
            for audio_chunk in generate_audio_chunks_streaming(
                request.text,
                request.speaker,
                request.pitch,
                request.rate,
                request.temperature,
                request.do_sample
            ):
                yield audio_chunk.tobytes()
        except Exception as e:
            # Headers are already sent, so the status code cannot change; log and end the body.
            print(f"Error in streaming: {e}")
            import traceback
            traceback.print_exc()

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Content-Disposition": "inline; filename=speech.wav"}
    )
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Generate speech from Urdu text and return the complete WAV file.

    Raises:
        HTTPException: 400 for empty text or an unknown speaker; 500 when
            generation fails.
    """
    from fastapi import Response

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if request.speaker not in SPEAKERS:
        raise HTTPException(status_code=400, detail=f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}")
    # Keep the try narrow so our own HTTPException below is not re-caught and rewrapped.
    try:
        audio_data = await generate_audio_stream(
            request.text,
            request.speaker,
            request.pitch,
            request.rate,
            request.temperature,
            request.do_sample
        )
    except Exception as e:
        print(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    if audio_data is None:
        raise HTTPException(status_code=500, detail="Failed to generate speech")
    # The WAV is already fully in memory: a plain Response also sets Content-Length.
    return Response(
        content=audio_data,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=speech.wav"}
    )
@app.get("/speakers")
async def get_speakers():
    """Get list of available speakers (the keys of SPEAKERS, in definition order)."""
    return {"speakers": list(SPEAKERS.keys())}
@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time audio streaming.

    Usage:
        1. Connect to ws://localhost:7860/ws/tts
        2. Send JSON: {"text": "سلام دنیا", "speaker": "Divya"}
           (optional keys: "pitch", "rate", "temperature", "do_sample")
        3. Receive a {"status": "generating", ...} JSON message, then int16 PCM
           audio chunks as binary frames in real time
        4. Receive {"status": "complete", "chunks_sent": N} or {"error": ...};
           the server then closes the connection
    """
    import asyncio

    await websocket.accept()
    try:
        data = await websocket.receive_text()
        request_data = json.loads(data)
        if not isinstance(request_data, dict):
            # Valid JSON but not an object: the .get() calls below would fail.
            await websocket.send_json({"error": "Request must be a JSON object"})
            return
        text = request_data.get("text", "").strip()
        speaker = request_data.get("speaker", "Divya")
        pitch = request_data.get("pitch", "Moderate")
        rate = request_data.get("rate", "Moderate")
        temperature = request_data.get("temperature", 0.8)
        do_sample = request_data.get("do_sample", True)
        # Validate inputs
        if not text:
            await websocket.send_json({"error": "Text cannot be empty"})
            return
        if speaker not in SPEAKERS:
            await websocket.send_json({
                "error": f"Invalid speaker. Choose from: {list(SPEAKERS.keys())}"
            })
            return
        # Send status message
        await websocket.send_json({
            "status": "generating",
            "message": f"Generating audio for speaker {speaker}..."
        })
        chunk_count = 0
        try:
            chunks = generate_audio_chunks_streaming(
                text, speaker, pitch, rate, temperature, do_sample
            )
            loop = asyncio.get_running_loop()
            finished = object()  # sentinel next() returns once the generator is exhausted
            while True:
                # The generator blocks while the model works; pull each chunk in a
                # worker thread so the event loop keeps serving other clients.
                audio_chunk = await loop.run_in_executor(None, next, chunks, finished)
                if audio_chunk is finished:
                    break
                # Send audio chunk as binary data
                await websocket.send_bytes(audio_chunk.tobytes())
                chunk_count += 1
            # Send completion message
            await websocket.send_json({
                "status": "complete",
                "chunks_sent": chunk_count
            })
        except WebSocketDisconnect:
            # Client went away: don't try to send an error over a closed socket.
            raise
        except Exception as e:
            await websocket.send_json({
                "error": f"Generation failed: {str(e)}"
            })
    except WebSocketDisconnect:
        print("WebSocket client disconnected")
    except json.JSONDecodeError:
        await websocket.send_json({"error": "Invalid JSON format"})
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Best-effort close on every path; closing an already-closed socket raises,
        # and there is nothing further to do in that case.
        try:
            await websocket.close()
        except Exception:
            pass
if __name__ == "__main__":
    import uvicorn

    # Development entry point: listen on every interface on port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)