|
|
| """
|
| RMVPE 模型 - 用于高质量 F0 提取
|
| """
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from typing import Optional
|
|
|
|
|
class BiGRU(nn.Module):
    """Bidirectional, batch-first GRU that returns only the output sequence.

    Args:
        input_features: Size of the last dimension of the input.
        hidden_features: Hidden size per direction; the output's last
            dimension is ``2 * hidden_features`` (both directions concatenated).
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_features,
            hidden_size=hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Run the GRU over ``x`` and discard the final hidden state."""
        outputs, _final_hidden = self.gru(x)
        return outputs
|
|
|
|
|
class ConvBlockRes(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm -> ReLU) stages plus a skip path.

    The skip path is the identity when the channel counts match and
    ``force_shortcut`` is False; otherwise it is a learned 1x1 convolution.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        momentum: BatchNorm momentum.
        force_shortcut: Use the 1x1 projection even when channel counts match.
    """

    def __init__(self, in_channels: int, out_channels: int, momentum: float = 0.01,
                 force_shortcut: bool = False):
        super().__init__()
        # Build both stages in order so Sequential indices stay
        # 0/1/2 (first stage) and 3/4/5 (second stage); checkpoint keys such as
        # ``conv.0.weight`` and ``conv.4.running_mean`` depend on these indices.
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1,
                          padding=1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=momentum),
                nn.ReLU(),
            ])
        self.conv = nn.Sequential(*layers)

        needs_projection = in_channels != out_channels or bool(force_shortcut)
        if needs_projection:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        self.has_shortcut = needs_projection

    def forward(self, x):
        """Return ``conv(x)`` plus the input (projected if a shortcut exists)."""
        main_path = self.conv(x)
        return main_path + (self.shortcut(x) if self.has_shortcut else x)
|
|
|
|
|
class EncoderBlock(nn.Module):
    """Encoder stage: a stack of ``ConvBlockRes`` followed by average pooling.

    Args:
        in_channels: Channels entering the first residual block.
        out_channels: Channels produced by every residual block.
        kernel_size: Kernel size of the ``nn.AvgPool2d`` downsampler.
        n_blocks: Number of residual blocks (at least one is always built);
            the first maps ``in_channels -> out_channels``.
        momentum: BatchNorm momentum forwarded to each block.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        block_inputs = [in_channels] + [out_channels] * (n_blocks - 1)
        self.conv = nn.ModuleList(
            [ConvBlockRes(width, out_channels, momentum) for width in block_inputs]
        )
        self.pool = nn.AvgPool2d(kernel_size)

    def forward(self, x):
        """Return ``(pooled, features)``; ``features`` is the pre-pool map used as a skip."""
        features = x
        for block in self.conv:
            features = block(features)
        return self.pool(features), features
|
|
|
|
|
class Encoder(nn.Module):
    """RMVPE encoder: input BatchNorm followed by ``n_encoders`` downsampling stages.

    Stage ``i`` outputs ``out_channels * 2**i`` channels and consumes the
    previous stage's output (the first stage consumes ``in_channels``).

    Args:
        in_channels: Channels of the input tensor.
        in_size: Accepted for interface compatibility; not used by this class.
        n_encoders: Number of encoder stages.
        kernel_size: Pooling kernel of each stage.
        n_blocks: Residual blocks per stage.
        out_channels: Channel width of the first stage.
        momentum: BatchNorm momentum.
    """

    def __init__(self, in_channels: int, in_size: int, n_encoders: int,
                 kernel_size: int, n_blocks: int, out_channels: int = 16,
                 momentum: float = 0.01):
        super().__init__()

        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []

        stage_in = in_channels
        for depth in range(n_encoders):
            stage_out = out_channels * 2 ** depth
            self.layers.append(
                EncoderBlock(stage_in, stage_out, kernel_size, n_blocks, momentum)
            )
            self.latent_channels.append(stage_out)
            stage_in = stage_out

    def forward(self, x):
        """Return the deepest pooled tensor and the pre-pool skips, shallowest first."""
        skips = []
        out = self.bn(x)
        for stage in self.layers:
            out, pre_pool = stage(out)
            skips.append(pre_pool)
        return out, skips
|
|
|
|
|
class Intermediate(nn.Module):
    """Bottleneck between encoder and decoder made of ``n_inters`` ``IntermediateBlock``s.

    Only the first block changes width (``in_channels -> out_channels``) and
    is built with ``first_block_shortcut=True``; the rest keep ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, n_inters: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()

        self.layers = nn.ModuleList([
            IntermediateBlock(
                out_channels if position else in_channels,
                out_channels,
                n_blocks,
                momentum,
                first_block_shortcut=(position == 0),
            )
            for position in range(n_inters)
        ])

    def forward(self, x):
        """Apply every bottleneck block in sequence."""
        out = x
        for block in self.layers:
            out = block(out)
        return out
|
|
|
|
|
class IntermediateBlock(nn.Module):
    """A stack of ``ConvBlockRes`` with no pooling, used inside the bottleneck.

    Args:
        in_channels: Channels entering the first residual block.
        out_channels: Channels produced by every residual block.
        n_blocks: Number of residual blocks (at least one is always built).
        momentum: BatchNorm momentum.
        first_block_shortcut: Passed as ``force_shortcut`` to the first block only.
    """

    def __init__(self, in_channels: int, out_channels: int, n_blocks: int,
                 momentum: float = 0.01, first_block_shortcut: bool = False):
        super().__init__()
        head = ConvBlockRes(in_channels, out_channels, momentum,
                            force_shortcut=first_block_shortcut)
        tail = [ConvBlockRes(out_channels, out_channels, momentum)
                for _ in range(n_blocks - 1)]
        self.conv = nn.ModuleList([head, *tail])

    def forward(self, x):
        """Apply the residual blocks in order."""
        out = x
        for block in self.conv:
            out = block(out)
        return out
|
|
|
|
|
class DecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, residual blocks.

    Args:
        in_channels: Channels of the incoming (coarser) feature map.
        out_channels: Channels after upsampling. The skip tensor is expected to
            carry the same number of channels, so the first residual block
            receives ``2 * out_channels``.
        stride: Upsampling factor of the transposed convolution; an int, or a
            per-axis ``(h, w)`` tuple.
        n_blocks: Residual blocks applied after concatenation (at least one).
        momentum: BatchNorm momentum.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()

        # A kernel-3 / padding-1 transposed conv yields
        # (in - 1) * s + 1 + output_padding per axis, so output_padding = s - 1
        # gives exactly in * s. For the default stride 2 this is 1 (unchanged);
        # it also makes stride 1 and tuple strides such as (1, 2) work.
        if isinstance(stride, int):
            output_padding = stride - 1
        else:
            output_padding = tuple(s - 1 for s in stride)

        # ConvTranspose -> BatchNorm -> ReLU, as in the reference RMVPE
        # ResDecoderBlock. ReLU has no parameters, so state-dict keys
        # (conv1.0.*, conv1.1.*) are unchanged and existing checkpoints load.
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride, padding=1,
                               output_padding=output_padding, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )

        # The first residual block fuses upsampled + skip (2 * out_channels)
        # back down to out_channels.
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        """Upsample ``x``, concatenate ``concat_tensor`` on dim 1, then refine."""
        x = self.conv1(x)

        # Align spatial size with the skip tensor. F.pad adds (or, for a
        # negative difference, removes) rows/columns at the bottom/right only.
        diff_h = concat_tensor.size(2) - x.size(2)
        diff_w = concat_tensor.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(x, [0, diff_w, 0, diff_h])
        x = torch.cat([x, concat_tensor], dim=1)
        for block in self.conv2:
            x = block(x)
        return x
|
|
|
|
|
class Decoder(nn.Module):
    """RMVPE decoder: ``n_decoders`` upsampling stages, each halving the width.

    Stage ``i`` outputs ``out_channels * 2**(n_decoders - 1 - i)`` channels and
    takes the previous stage's output as input (the first stage takes
    ``in_channels``). Skip tensors are consumed from the end of the list,
    i.e. deepest first.
    """

    def __init__(self, in_channels: int, n_decoders: int, stride: int,
                 n_blocks: int, out_channels: int = 16, momentum: float = 0.01):
        super().__init__()

        self.layers = nn.ModuleList()
        width_in = in_channels
        for level in range(n_decoders - 1, -1, -1):
            width_out = out_channels * 2 ** level
            self.layers.append(
                DecoderBlock(width_in, width_out, stride, n_blocks, momentum)
            )
            width_in = width_out

    def forward(self, x, concat_tensors):
        """Run each stage, pairing stage ``k`` (from 1) with ``concat_tensors[-k]``."""
        for depth, stage in enumerate(self.layers, start=1):
            x = stage(x, concat_tensors[-depth])
        return x
|
|
|
|
|
class DeepUnet(nn.Module):
    """U-Net: ``Encoder`` -> ``Intermediate`` bottleneck -> ``Decoder``.

    The bottleneck widens the deepest encoder width by 2x. With the defaults
    (``en_de_layers=5``, ``en_out_channels=16``) that is 256 -> 512, and the
    decoder brings it back down to 16 channels.
    """

    def __init__(self, kernel_size: int, n_blocks: int, en_de_layers: int = 5,
                 inter_layers: int = 4, in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()

        deepest_width = en_out_channels * 2 ** (en_de_layers - 1)  # last encoder stage
        bottleneck_width = 2 * deepest_width

        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            deepest_width, bottleneck_width, inter_layers, n_blocks
        )
        self.decoder = Decoder(
            bottleneck_width, en_de_layers, kernel_size, n_blocks, en_out_channels
        )

    def forward(self, x):
        """Encode, pass through the bottleneck, decode using the encoder skips."""
        bottom, skips = self.encoder(x)
        return self.decoder(self.intermediate(bottom), skips)
|
|
|
|
|
class E2E(nn.Module):
    """End-to-end RMVPE network: DeepUnet -> 3-channel conv -> per-frame head.

    The head emits 360 sigmoid outputs per frame. Its input width is
    hard-coded as ``3 * 128``, so 128 mel bins are assumed.

    Args:
        n_blocks: Residual blocks per U-Net stage.
        n_gru: Number of BiGRU layers; 0/falsy uses a plain linear head.
        kernel_size: U-Net pooling kernel / decoder stride.
        en_de_layers: Encoder/decoder depth.
        inter_layers: Number of bottleneck blocks.
        in_channels: U-Net input channels.
        en_out_channels: Width of the first encoder stage and of the U-Net output.
    """

    def __init__(self, n_blocks: int, n_gru: int, kernel_size: int,
                 en_de_layers: int = 5, inter_layers: int = 4,
                 in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()

        self.unet = DeepUnet(
            kernel_size, n_blocks, en_de_layers, inter_layers,
            in_channels, en_out_channels
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, 3, 1, 1)

        if n_gru:
            # BiGRU with 256 hidden units per direction -> 512 features per frame.
            projection = [BiGRU(3 * 128, 256, n_gru), nn.Linear(512, 360)]
        else:
            projection = [nn.Linear(3 * 128, 360)]
        self.fc = nn.Sequential(*projection, nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, mel):
        """Map mel features to per-frame salience.

        A 3-D ``mel`` (presumably ``(batch, n_mels, frames)``, as produced by
        ``MelSpectrogram`` for batched audio) has its last two axes swapped
        and gains a channel axis. A 4-D ``mel`` whose second axis is 1 only
        has its last two axes swapped. Anything else goes to the U-Net as is.
        """
        rank = mel.dim()
        if rank == 3:
            mel = mel.transpose(-1, -2).unsqueeze(1)
        elif rank == 4 and mel.shape[1] == 1:
            mel = mel.transpose(-1, -2)

        features = self.cnn(self.unet(mel))
        # Swap axes 1 and 2, then merge the last two axes into one feature
        # vector per step of axis 1.
        per_frame = features.transpose(1, 2).flatten(-2)
        return self.fc(per_frame)
|
|
|
|
|
class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front end for RMVPE.

    Computes ``log(clamp(mel_basis @ |STFT(audio)|, min=1e-5))``. The mel
    projection is applied to the STFT *magnitude* (not power), matching the
    reference RMVPE/RVC feature extractor that the pretrained weights expect.

    Args:
        n_mel: Number of mel bands.
        n_fft: FFT size.
        win_size: Hann window length.
        hop_length: Hop between frames, in samples.
        sample_rate: Audio sample rate in Hz.
        fmin: Lowest mel filter frequency in Hz.
        fmax: Highest mel filter frequency in Hz.
    """

    def __init__(self, n_mel: int = 128, n_fft: int = 1024, win_size: int = 1024,
                 hop_length: int = 160, sample_rate: int = 16000,
                 fmin: int = 30, fmax: int = 8000):
        super().__init__()

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_size = win_size
        self.sample_rate = sample_rate
        self.n_mel = n_mel

        # Registered as buffers so .to(device) moves them with the module.
        mel_basis = self._mel_filterbank(sample_rate, n_fft, n_mel, fmin, fmax)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", torch.hann_window(win_size))

    def _mel_filterbank(self, sr, n_fft, n_mels, fmin, fmax):
        """Build an HTK-style librosa mel filterbank as a float32 tensor.

        It must be ``(n_mels, n_fft // 2 + 1)`` to left-multiply the one-sided STFT.
        """
        import librosa  # local import: only needed when the module is constructed

        mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=True)
        return torch.from_numpy(mel).float()

    def forward(self, audio):
        """Return the log-mel spectrogram of ``audio``.

        ``audio`` is a 1-D waveform or a ``(batch, samples)`` tensor; the
        result is ``(n_mel, frames)`` or ``(batch, n_mel, frames)``.
        """
        spec = torch.stft(
            audio,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True
        )

        # Magnitude, not power: the RMVPE weights were trained on mel(|STFT|).
        # Squaring here would shift every log-mel value and degrade F0 output.
        magnitude = torch.abs(spec)

        mel = torch.matmul(self.mel_basis, magnitude)
        # Clamp before log to avoid -inf on silent bins.
        mel = torch.log(torch.clamp(mel, min=1e-5))

        return mel
|
|
|
|
|
class RMVPE:
    """RMVPE F0 extractor: audio -> log-mel -> ``E2E`` salience -> F0 in Hz.

    Args:
        model_path: Path to a checkpoint containing an ``E2E`` state dict.
        device: Torch device string for the model and the mel front end.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        self.device = device

        # Architecture hyper-parameters must match the checkpoint being loaded.
        self.model = E2E(n_blocks=4, n_gru=1, kernel_size=2)
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary objects, i.e. run code embedded in the file. Only load
        # trusted checkpoints. If the file is a plain state dict,
        # weights_only=True should suffice; confirm before switching.
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(device).eval()

        self.mel_extractor = MelSpectrogram().to(device)

        # Cents value of each of the 360 salience bins (20 cents apart),
        # zero-padded by 4 on each side so 9-bin windows never go out of range.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    @torch.no_grad()
    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03) -> np.ndarray:
        """Extract an F0 contour from audio.

        Args:
            audio: 16 kHz waveform (the default ``MelSpectrogram`` sample rate).
            thred: Salience threshold; frames whose peak salience is <= thred
                are reported as unvoiced (0 Hz).

        Returns:
            np.ndarray: One F0 value in Hz per mel frame; 0 marks unvoiced frames.
        """
        audio = torch.from_numpy(audio).float().to(self.device)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        mel = self.mel_extractor(audio)
        n_frames = mel.shape[-1]

        # The U-Net pools by 2 at each of its 5 encoder stages (E2E defaults
        # with kernel_size=2), so pad frames up to a multiple of 32 and crop
        # the padded frames from the output afterwards.
        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
        if n_pad > 0:
            mel = F.pad(mel, (0, n_pad), mode='constant', value=0)

        hidden = self.model(mel)
        hidden = hidden[:, :n_frames, :]
        hidden = hidden.squeeze(0).cpu().numpy()

        return self._decode(hidden, thred)

    def _decode(self, hidden: np.ndarray, thred: float) -> np.ndarray:
        """Convert per-frame salience ``(frames, 360)`` to F0 in Hz (0 = unvoiced)."""
        cents = self._to_local_average_cents(hidden, thred)

        # Cents are relative to 10 Hz: f = 10 * 2 ** (cents / 1200).
        f0 = 10 * (2 ** (cents / 1200))
        # Unvoiced frames have cents == 0, which maps to exactly 10 Hz.
        f0[f0 == 10] = 0

        return f0

    def _to_local_average_cents(self, salience: np.ndarray, thred: float) -> np.ndarray:
        """Salience-weighted mean of cents over a 9-bin window around each frame's peak.

        Gives the same result as the reference RVC ``to_local_average_cents``,
        using vectorized indexing instead of a per-frame Python loop. It also
        handles zero-frame input.

        Args:
            salience: ``(frames, 360)`` array of bin activations.
            thred: Frames whose peak salience is <= thred get 0 cents.

        Returns:
            np.ndarray: ``(frames,)`` cents; 0 marks unvoiced frames.
        """
        center = np.argmax(salience, axis=1)

        # Pad 4 bins on each side (cents_mapping is padded the same way). In
        # padded coordinates the window [peak - 4, peak + 4] starts at
        # ``center`` and spans 9 bins.
        salience = np.pad(salience, ((0, 0), (4, 4)))
        window = center[:, None] + np.arange(9)

        todo_salience = np.take_along_axis(salience, window, axis=1)
        todo_cents_mapping = self.cents_mapping[window]

        product_sum = np.sum(todo_salience * todo_cents_mapping, axis=1)
        # Epsilon avoids division by zero for all-zero windows.
        weight_sum = np.sum(todo_salience, axis=1) + 1e-9
        cents = product_sum / weight_sum

        maxx = np.max(salience, axis=1)
        cents[maxx <= thred] = 0

        return cents
|
|
|