# AI-RVC / tests/test_model_defaults.py
# Uploaded by mason369 -- Release v1.2.1 (commit a9536c4, verified)
import sys
import tempfile
import unittest
from pathlib import Path
import numpy as np
import soundfile as sf
# The repository root is the directory one level above this file's folder.
# It is put at the front of sys.path (only if not already listed) so the
# function-local imports in the tests below resolve when this file is run
# directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
class ModelDefaultTests(unittest.TestCase):
    """Pin the model-name constants exposed by ``infer.separator``.

    Every test imports ``infer.separator`` inside the test body and compares
    module-level constants against literal expected values.
    """

    def test_roformer_default_uses_public_audio_separator_sota_model(self):
        # Default and SOTA list are compared for exact equality (the list
        # comparison is order-sensitive); the legacy constant only has to
        # contain the expected checkpoint name.
        from infer import separator

        expected_default = "ensemble:vocal_rvc"
        expected_sota = [
            "melband_roformer_big_beta6x.ckpt",
            "mel_band_roformer_vocals_fv4_gabox.ckpt",
        ]
        legacy_checkpoint = "vocals_mel_band_roformer.ckpt"

        self.assertEqual(separator.ROFORMER_DEFAULT_MODEL, expected_default)
        self.assertEqual(separator.ROFORMER_SOTA_MODELS, expected_sota)
        self.assertIn(legacy_checkpoint, separator.ROFORMER_LEGACY_SINGLE_MODEL)

    def test_karaoke_default_uses_public_sota_ensemble(self):
        # The default and the SOTA model constants must both equal the same
        # ensemble key; the member list and the legacy single model are
        # compared for exact equality.
        from infer import separator

        ensemble_key = "ensemble:karaoke"
        expected_members = [
            "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt",
            "mel_band_roformer_karaoke_gabox_v2.ckpt",
            "mel_band_roformer_karaoke_becruily.ckpt",
        ]
        expected_legacy = "mel_band_roformer_karaoke_gabox.ckpt"

        self.assertEqual(separator.KARAOKE_DEFAULT_MODEL, ensemble_key)
        self.assertEqual(separator.KARAOKE_SOTA_MODEL, ensemble_key)
        self.assertEqual(separator.KARAOKE_SOTA_MODELS, expected_members)
        self.assertEqual(separator.KARAOKE_LEGACY_SINGLE_MODEL, expected_legacy)

    def test_deecho_default_uses_public_roformer_dereverb_model(self):
        # The dereverb default must equal this exact checkpoint file name.
        from infer import separator

        expected_checkpoint = "dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt"
        self.assertEqual(
            separator.ROFORMER_DEREVERB_DEFAULT_MODEL,
            expected_checkpoint,
        )

    def test_strict_sota_defaults_do_not_expose_model_fallback_lists(self):
        # None of these fallback-list attributes may exist on the module.
        # They are checked in this order and the first one found fails the
        # test.
        from infer import separator

        forbidden_attributes = (
            "ROFORMER_FALLBACK_MODELS",
            "KARAOKE_FALLBACK_MODELS",
            "ROFORMER_DEREVERB_FALLBACK_MODELS",
        )
        for attribute_name in forbidden_attributes:
            self.assertFalse(hasattr(separator, attribute_name))
class KaraokeCandidateScoringTests(unittest.TestCase):
    """Exercise the scoring helpers in ``tools.evaluate_karaoke_models``.

    Each test synthesizes one second of sine-wave "lead" and "backing"
    signals at 16 kHz, writes them as WAV files into a temporary directory
    and checks the metrics returned by the scoring functions.
    """

    def test_karaoke_candidate_score_rewards_reconstruction_and_low_correlation(self):
        from tools.evaluate_karaoke_models import score_karaoke_stems

        sample_rate = 16000
        time_axis = np.arange(sample_rate, dtype=np.float32) / sample_rate
        clean_lead = 0.18 * np.sin(2 * np.pi * 220 * time_axis)
        clean_backing = 0.05 * np.sin(2 * np.pi * 330 * time_axis + 0.4)
        mixture = clean_lead + clean_backing

        # File name -> samples, in the order the files are written.
        # "Good" stems are the true components of the mixture; "bad" stems
        # are both copies of the whole mixture (the backing one scaled by
        # 0.7), i.e. no separation happened.
        fixtures = {
            "input.wav": mixture,
            "lead_good.wav": clean_lead,
            "backing_good.wav": clean_backing,
            "lead_bad.wav": mixture,
            "backing_bad.wav": 0.7 * mixture,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)
            for file_name, samples in fixtures.items():
                sf.write(workdir / file_name, samples, sample_rate)

            # Scoring reads the files, so it must run before the directory
            # is removed.
            good = score_karaoke_stems(
                workdir / "input.wav",
                workdir / "lead_good.wav",
                workdir / "backing_good.wav",
            )
            bad = score_karaoke_stems(
                workdir / "input.wav",
                workdir / "lead_bad.wav",
                workdir / "backing_bad.wav",
            )

        self.assertGreater(good["score"], bad["score"])
        self.assertLess(good["reconstruction_error"], bad["reconstruction_error"])
        self.assertLess(good["lead_backing_abs_corr"], bad["lead_backing_abs_corr"])

    def test_karaoke_candidate_score_penalizes_truncated_stems(self):
        from tools.evaluate_karaoke_models import score_karaoke_stems

        sample_rate = 16000
        time_axis = np.arange(sample_rate, dtype=np.float32) / sample_rate
        clean_lead = 0.18 * np.sin(2 * np.pi * 220 * time_axis)
        clean_backing = 0.04 * np.sin(2 * np.pi * 330 * time_axis + 0.4)
        mixture = clean_lead + clean_backing
        # The truncated stems keep only the first quarter of the samples.
        truncated_len = sample_rate // 4

        # File name -> samples, in the order the files are written.
        # The full-length stems are attenuated to 0.97x, so they are not an
        # exact reconstruction but still cover the whole input.
        fixtures = {
            "input.wav": mixture,
            "lead_short.wav": clean_lead[:truncated_len],
            "backing_short.wav": clean_backing[:truncated_len],
            "lead_full.wav": 0.97 * clean_lead,
            "backing_full.wav": 0.97 * clean_backing,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)
            for file_name, samples in fixtures.items():
                sf.write(workdir / file_name, samples, sample_rate)

            short = score_karaoke_stems(
                workdir / "input.wav",
                workdir / "lead_short.wav",
                workdir / "backing_short.wav",
            )
            full = score_karaoke_stems(
                workdir / "input.wav",
                workdir / "lead_full.wav",
                workdir / "backing_full.wav",
            )

        self.assertIn("length_coverage", short)
        self.assertLess(short["length_coverage"], 0.999)
        self.assertGreaterEqual(full["length_coverage"], 0.999)
        self.assertGreater(full["score"], short["score"])

    def test_reference_karaoke_score_uses_true_si_sdr_when_refs_exist(self):
        from tools.evaluate_karaoke_models import score_reference_stems

        sample_rate = 16000
        time_axis = np.arange(sample_rate, dtype=np.float32) / sample_rate
        lead = 0.18 * np.sin(2 * np.pi * 220 * time_axis)
        backing = 0.04 * np.sin(2 * np.pi * 330 * time_axis + 0.4)

        # File name -> samples, in the order the files are written.
        # Estimates are identical to their references, so the reported mean
        # SI-SDR is expected to exceed 100.
        fixtures = {
            "reference_lead.wav": lead,
            "reference_backing.wav": backing,
            "lead.wav": lead,
            "backing.wav": backing,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)
            for file_name, samples in fixtures.items():
                sf.write(workdir / file_name, samples, sample_rate)

            metrics = score_reference_stems(
                workdir / "reference_lead.wav",
                workdir / "reference_backing.wav",
                workdir / "lead.wav",
                workdir / "backing.wav",
            )

        self.assertGreater(metrics["mean_si_sdr"], 100.0)
        self.assertIn("lead", metrics["stems"])
        self.assertIn("backing", metrics["stems"])
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()