Spaces:
Sleeping
Sleeping
import io
import json
import os
import sys
from pathlib import Path

import numpy as np
import requests
import soundfile as sf
# Test configuration.
# Base URL of the TTS API under test. Defaults to a local server on port 7860;
# set the TTS_API_URL environment variable to target another deployment
# without editing this file. A trailing slash is stripped because every
# request below appends its own "/path".
BASE_URL = os.environ.get("TTS_API_URL", "http://localhost:7860").rstrip("/")
# Directory that receives the WAV files written by the generation tests.
TEST_OUTPUT_DIR = Path("test_outputs")
# Create the output directory at import time so the tests can write into it.
TEST_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def test_health_check(timeout=60):
    """Check that the root endpoint reports a healthy service.

    Sends ``GET {BASE_URL}/`` and validates the JSON body.

    Args:
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error, so a stalled server cannot hang the suite.

    Returns:
        The decoded JSON body of the health response.

    Raises:
        AssertionError: If the status code is not 200, ``status`` is not
            ``"ok"``, or one of ``model``, ``speakers``, ``sample_rate`` is
            missing.
    """
    response = requests.get(f"{BASE_URL}/", timeout=timeout)
    assert response.status_code == 200, (
        f"Health check returned {response.status_code}: {response.text}"
    )
    data = response.json()
    assert data.get("status") == "ok", (
        f"Unexpected health status: {data.get('status')!r}"
    )
    # 'sample_rate' is printed below, so require it here instead of failing
    # later with a bare KeyError.
    for key in ("model", "speakers", "sample_rate"):
        assert key in data, f"Health response is missing {key!r}"
    print("✓ Health check passed")
    print(f" Available speakers: {data['speakers']}")
    print(f" Sample rate: {data['sample_rate']} Hz")
    return data
def test_speakers_endpoint(timeout=60):
    """Check that ``GET {BASE_URL}/speakers`` lists at least one speaker.

    Args:
        timeout: Seconds to wait for the response, so a stalled server cannot
            hang the suite.

    Returns:
        The value of the ``speakers`` key from the JSON response.

    Raises:
        AssertionError: If the status code is not 200, the ``speakers`` key is
            missing, or the listed speakers are empty.
    """
    response = requests.get(f"{BASE_URL}/speakers", timeout=timeout)
    assert response.status_code == 200, (
        f"/speakers returned {response.status_code}: {response.text}"
    )
    data = response.json()
    assert "speakers" in data, "Speakers response is missing 'speakers'"
    speakers = data["speakers"]
    assert speakers, "Speakers endpoint returned no speakers"
    print("✓ Speakers endpoint passed")
    print(f" Speakers: {speakers}")
    return speakers
def test_tts_generation(text, speaker="Divya", pitch="Moderate", rate="Moderate",
                        temperature=0.8, do_sample=True, timeout=300):
    """Request speech for *text* from ``POST {BASE_URL}/tts`` and validate it.

    On success the raw response bytes are written to
    ``TEST_OUTPUT_DIR / f"test_{speaker}_{len(text)}_chars.wav"``.

    NOTE(review): despite its ``test_`` prefix this is a parameterized helper
    called by ``run_all_tests``. If pytest collects this module, it will try to
    resolve ``text`` as a fixture.

    Args:
        text: Text to synthesize.
        speaker: Value sent as ``speaker``.
        pitch: Value sent as ``pitch``.
        rate: Value sent as ``rate``.
        temperature: Value sent as ``temperature``. The default of 0.8 was
            previously hard-coded.
        do_sample: Value sent as ``do_sample``. The default of True was
            previously hard-coded.
        timeout: Seconds to wait for the response. Generation can be slow, so
            the default is generous, but it prevents an indefinite hang.

    Returns:
        True if the request returned 200, the body decoded as non-empty audio,
        and the file was saved. False otherwise; the reason is printed.
    """
    payload = {
        "text": text,
        "speaker": speaker,
        "pitch": pitch,
        "rate": rate,
        "temperature": temperature,
        "do_sample": do_sample,
    }
    response = requests.post(f"{BASE_URL}/tts", json=payload, timeout=timeout)
    if response.status_code != 200:
        print(f"✗ TTS generation failed: {response.status_code}")
        print(f" Response: {response.text}")
        return False

    # Validate that we got audio data. Report an empty body like every other
    # failure (return False) so a caller looping over cases can continue.
    audio_data = response.content
    if not audio_data:
        print("✗ TTS generation failed: Audio data is empty")
        return False

    # Keep the try narrow: only decoding failures are "failed to read audio".
    try:
        audio, sample_rate = sf.read(io.BytesIO(audio_data))
    except Exception as e:  # soundfile's exception types differ across versions
        print(f"✗ Failed to read audio: {e}")
        return False

    # Validate audio properties.
    if not isinstance(audio, np.ndarray):
        problem = "Audio is not a numpy array"
    elif len(audio) == 0:
        problem = "Audio array is empty"
    elif sample_rate <= 0:
        problem = "Invalid sample rate"
    else:
        problem = None
    if problem is not None:
        print(f"✗ Invalid audio: {problem}")
        return False

    # soundfile returns frames along axis 0, so len() is the frame count for
    # both mono and multi-channel audio.
    duration = len(audio) / sample_rate
    print("✓ TTS generation successful")
    print(f" Text: {text}")
    print(f" Speaker: {speaker}")
    print(f" Audio shape: {audio.shape}")
    print(f" Sample rate: {sample_rate} Hz")
    print(f" Duration: {duration:.2f} seconds")
    print(f" Audio size: {len(audio_data) / 1024:.2f} KB")

    # Save the test audio file. A later call with the same speaker and text
    # length overwrites it.
    test_file = TEST_OUTPUT_DIR / f"test_{speaker}_{len(text)}_chars.wav"
    try:
        TEST_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        test_file.write_bytes(audio_data)
    except OSError as e:
        print(f"✗ Failed to save audio to {test_file}: {e}")
        return False
    print(f" Saved to: {test_file}")
    return True
def test_empty_text(timeout=60):
    """Check that ``POST {BASE_URL}/tts`` rejects empty ``text`` with HTTP 400.

    Args:
        timeout: Seconds to wait for the response, so a stalled server cannot
            hang the suite.

    Raises:
        AssertionError: If the server answers with any status other than 400.
    """
    payload = {"text": "", "speaker": "Divya"}
    response = requests.post(f"{BASE_URL}/tts", json=payload, timeout=timeout)
    assert response.status_code == 400, (
        f"Should return 400 for empty text, got {response.status_code}"
    )
    print("✓ Empty text validation passed")
def test_invalid_speaker(timeout=60):
    """Check that ``POST {BASE_URL}/tts`` rejects an unknown speaker with HTTP 400.

    Args:
        timeout: Seconds to wait for the response, so a stalled server cannot
            hang the suite.

    Raises:
        AssertionError: If the server answers with any status other than 400.
    """
    payload = {"text": "Hello world", "speaker": "InvalidSpeaker"}
    response = requests.post(f"{BASE_URL}/tts", json=payload, timeout=timeout)
    assert response.status_code == 400, (
        f"Should return 400 for invalid speaker, got {response.status_code}"
    )
    print("✓ Invalid speaker validation passed")
def run_all_tests():
    """Run every check against ``BASE_URL`` and print a report.

    The stages run in order: health check, speakers endpoint, input
    validation, then one TTS generation per entry in ``test_cases``. An
    exception in any stage stops the run. A failed TTS generation is recorded
    and the remaining cases still run.

    Returns:
        True if everything passed. False if any stage failed or the server
        could not be reached.
    """
    print("\n" + "=" * 60)
    print("TTS API Test Suite")
    print("=" * 60 + "\n")

    # (text, speaker) pairs for the generation stage.
    test_cases = [
        ("سلام دنیا", "Divya"),  # "hello world" (Urdu/Persian)
        ("ہیلو ورلڈ", "Rani"),  # "hello world" transliterated into Urdu script
        ("مرحبا العالم", "Generic Female"),  # "hello world" (Arabic)
    ]
    failed_speakers = []

    try:
        # Basic endpoint tests
        print("1. Testing health check...")
        test_health_check()
        print()

        print("2. Testing speakers endpoint...")
        test_speakers_endpoint()
        print()

        # Validation tests
        print("3. Testing input validation...")
        test_empty_text()
        test_invalid_speaker()
        print()

        # TTS generation tests with different speakers and text
        print("4. Testing TTS generation with different speakers...")
        for text, speaker in test_cases:
            if not test_tts_generation(text, speaker=speaker):
                print(f" WARNING: Test failed for {speaker}")
                failed_speakers.append(speaker)
            print()
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to server at {BASE_URL}")
        print(" Make sure the API is running: python api.py")
        return False
    except requests.exceptions.Timeout:
        print(f"✗ Request to {BASE_URL} timed out")
        return False
    except Exception as e:
        # Several checks are bare asserts with no message, so include the
        # exception type to keep the one-line summary meaningful.
        print(f"✗ Test suite failed with error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("=" * 60)
    if failed_speakers:
        print(
            f"✗ {len(failed_speakers)} of {len(test_cases)} TTS generation "
            f"test(s) failed: {', '.join(failed_speakers)}"
        )
        print(f"Test outputs saved to: {TEST_OUTPUT_DIR}")
    else:
        print(f"All tests completed! Test outputs saved to: {TEST_OUTPUT_DIR}")
    print("=" * 60 + "\n")
    return not failed_speakers


if __name__ == "__main__":
    # Surface the result as the process exit status so scripts and CI can
    # detect failures.
    sys.exit(0 if run_all_tests() else 1)