# Parler_TTS_API / test_api.py
# Committed by aspirant312
# Commit b483241: Add comprehensive test suite for TTS API audio validation
import io
import json
import os
from pathlib import Path

import numpy as np
import requests
import soundfile as sf
# Test configuration.
# The server URL can be overridden with the TTS_API_BASE_URL environment
# variable; it defaults to a local server on port 7860. A trailing slash is
# stripped so that f"{BASE_URL}/tts" never produces a "//tts" path.
BASE_URL = os.environ.get("TTS_API_BASE_URL", "http://localhost:7860").rstrip("/")
TEST_OUTPUT_DIR = Path("test_outputs")
# Create output directory for test audio files (relative to the current
# working directory; this runs at import time).
TEST_OUTPUT_DIR.mkdir(exist_ok=True)
def test_health_check():
    """Check that the root health endpoint reports a ready server.

    Asserts an HTTP 200 response whose JSON body has ``status == "ok"`` and
    contains the ``model``, ``speakers`` and ``sample_rate`` keys, then prints
    a short summary.

    Returns:
        The decoded JSON body of the health response.
    """
    # Without a timeout, requests waits indefinitely on a stalled server.
    response = requests.get(f"{BASE_URL}/", timeout=10)
    assert response.status_code == 200, f"Unexpected status {response.status_code}"
    data = response.json()
    # .get() turns a missing key into an assertion failure instead of a KeyError.
    assert data.get("status") == "ok", f"Unexpected status field: {data.get('status')!r}"
    # sample_rate is printed below, so check it here rather than fail with KeyError.
    for key in ("model", "speakers", "sample_rate"):
        assert key in data, f"Health response is missing {key!r}"
    print("✓ Health check passed")
    print(f" Available speakers: {data['speakers']}")
    print(f" Sample rate: {data['sample_rate']} Hz")
    return data
def test_speakers_endpoint():
    """Check that ``GET /speakers`` returns a non-empty speaker list.

    Returns:
        The ``speakers`` value from the JSON response.
    """
    # Without a timeout, requests waits indefinitely on a stalled server.
    response = requests.get(f"{BASE_URL}/speakers", timeout=10)
    assert response.status_code == 200, f"Unexpected status {response.status_code}"
    data = response.json()
    assert "speakers" in data, "Response is missing 'speakers'"
    speakers = data["speakers"]
    assert speakers, "Speaker list is empty"
    print("✓ Speakers endpoint passed")
    print(f" Speakers: {speakers}")
    return speakers
def test_tts_generation(text, speaker="Divya", pitch="Moderate", rate="Moderate"):
    """Request speech for *text* from ``POST /tts`` and validate the audio.

    The response body must be non-empty and decodable by soundfile into a
    non-empty numpy array with a positive sample rate. On success a summary
    is printed and the raw response bytes are written to ``TEST_OUTPUT_DIR``.

    Args:
        text: Text to synthesize.
        speaker: Speaker name sent to the API.
        pitch: Pitch setting sent to the API.
        rate: Speaking-rate setting sent to the API.

    Returns:
        True if generation and validation succeeded, False otherwise. Every
        failure path prints a line starting with "✗" that explains why.
    """
    payload = {
        "text": text,
        "speaker": speaker,
        "pitch": pitch,
        "rate": rate,
        "temperature": 0.8,
        "do_sample": True,
    }
    # Synthesis is slower than the metadata endpoints, so the timeout is
    # generous, but a stalled server must not hang the suite forever.
    response = requests.post(f"{BASE_URL}/tts", json=payload, timeout=300)
    if response.status_code != 200:
        print(f"✗ TTS generation failed: {response.status_code}")
        print(f" Response: {response.text}")
        return False

    audio_data = response.content
    if not audio_data:
        # Report this like every other failure instead of raising, so the
        # caller can continue with its remaining test cases.
        print("✗ Audio data is empty")
        return False

    # Only the decode goes inside the try. Any exception here means the payload
    # is not readable audio; validation problems are reported separately below.
    try:
        audio, sample_rate = sf.read(io.BytesIO(audio_data))
    except Exception as e:
        print(f"✗ Failed to read audio: {e}")
        return False

    if not isinstance(audio, np.ndarray):
        problem = "Audio is not a numpy array"
    elif len(audio) == 0:
        problem = "Audio array is empty"
    elif sample_rate <= 0:
        problem = "Invalid sample rate"
    else:
        problem = None
    if problem is not None:
        print(f"✗ Invalid audio: {problem}")
        return False

    duration = len(audio) / sample_rate
    print("✓ TTS generation successful")
    print(f" Text: {text}")
    print(f" Speaker: {speaker}")
    print(f" Audio shape: {audio.shape}")
    print(f" Sample rate: {sample_rate} Hz")
    print(f" Duration: {duration:.2f} seconds")
    print(f" Audio size: {len(audio_data) / 1024:.2f} KB")

    # Save the raw WAV bytes exactly as the server returned them.
    test_file = TEST_OUTPUT_DIR / f"test_{speaker}_{len(text)}_chars.wav"
    test_file.write_bytes(audio_data)
    print(f" Saved to: {test_file}")
    return True


# This is a parameterized helper driven by run_all_tests(), not a standalone
# test. If pytest collects this module, mark it so pytest skips it instead of
# erroring on the missing "text" fixture.
test_tts_generation.__test__ = False
def test_empty_text():
    """Check that ``POST /tts`` rejects an empty ``text`` with HTTP 400."""
    payload = {"text": "", "speaker": "Divya"}
    # A timeout keeps a misbehaving server from hanging the suite.
    response = requests.post(f"{BASE_URL}/tts", json=payload, timeout=30)
    assert response.status_code == 400, (
        f"Should return 400 for empty text, got {response.status_code}"
    )
    print("✓ Empty text validation passed")
def test_invalid_speaker():
    """Check that ``POST /tts`` rejects an unknown speaker with HTTP 400."""
    payload = {"text": "Hello world", "speaker": "InvalidSpeaker"}
    # A timeout keeps a misbehaving server from hanging the suite.
    response = requests.post(f"{BASE_URL}/tts", json=payload, timeout=30)
    assert response.status_code == 400, (
        f"Should return 400 for invalid speaker, got {response.status_code}"
    )
    print("✓ Invalid speaker validation passed")
def run_all_tests():
    """Run every check in order against ``BASE_URL`` and print a report.

    Steps: health check, speakers endpoint, input validation, then TTS
    generation for a few sample texts. A connection error, a timeout or any
    exception from steps 1-3 stops the run. Failed TTS generations are
    collected and reported without stopping the remaining cases.

    Returns:
        True if every check passed, False otherwise. The ``__main__`` entry
        point uses this to set the process exit status.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("TTS API Test Suite")
    print(separator + "\n")
    try:
        # Basic endpoint tests
        print("1. Testing health check...")
        test_health_check()
        print()
        print("2. Testing speakers endpoint...")
        test_speakers_endpoint()
        print()

        # Validation tests
        print("3. Testing input validation...")
        test_empty_text()
        test_invalid_speaker()
        print()

        # TTS generation tests with different speakers and text
        print("4. Testing TTS generation with different speakers...")
        test_cases = [
            ("سلام دنیا", "Divya"),  # Urdu text
            ("ہیلو ورلڈ", "Rani"),  # Urdu transliteration of "hello world"
            ("مرحبا العالم", "Generic Female"),  # Arabic text
        ]
        failed_speakers = []
        for text, speaker in test_cases:
            if not test_tts_generation(text, speaker=speaker):
                print(f" WARNING: Test failed for {speaker}")
                failed_speakers.append(speaker)
            print()

        print(separator)
        if failed_speakers:
            print(f"✗ TTS generation failed for: {', '.join(failed_speakers)}")
        print(f"All tests completed! Test outputs saved to: {TEST_OUTPUT_DIR}")
        print(separator + "\n")
        return not failed_speakers
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to server at {BASE_URL}")
        print(" Make sure the API is running: python api.py")
    except requests.exceptions.Timeout:
        print(f"✗ Request to server at {BASE_URL} timed out")
    except Exception as e:
        print(f"✗ Test suite failed with error: {e}")
        import traceback
        traceback.print_exc()
    return False


if __name__ == "__main__":
    # Exit non-zero on failure so shell scripts and CI can detect it.
    raise SystemExit(0 if run_all_tests() else 1)