1/mvector/predict.py
2025-04-18 19:56:58 +08:00

189 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pickle
import shutil
from io import BufferedReader
import numpy as np
import torch
import yaml
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from mvector import SUPPORT_MODEL
from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.featurizer import AudioFeaturizer
from mvector.models.ecapa_tdnn import EcapaTdnn
from mvector.models.fc import SpeakerIdetification
from mvector.utils.logger import setup_logger
from mvector.utils.utils import dict_to_object, print_arguments
# Module-level logger for this file, created through the project's setup_logger helper.
logger = setup_logger(__name__)
class MVectorPredictor:
    """Voiceprint (speaker) recognition predictor backed by an exported EcapaTdnn model.

    ``predict`` / ``predict_batch`` extract embedding vectors from audio, ``contrast`` /
    ``compare`` score two voiceprints by cosine similarity, and ``recognition`` looks a
    voice up in an in-memory voiceprint library made of ``audio_feature`` (an ``(N, D)``
    array of registered embeddings) and ``users_name`` (the N matching user names).
    No method in this class fills that library; callers are expected to assign both.
    """

    def __init__(self,
                 configs,
                 threshold=0.6,
                 model_path='models/ecapa_tdnn_FBank/best_model/',
                 use_gpu=True):
        """Build the predictor and load the trained model weights.

        :param configs: configuration, either a dict or the path of a YAML config file
        :param threshold: cosine-similarity threshold above which two voiceprints are
                          considered to belong to the same person
        :param model_path: exported model directory (``model.pt`` inside it is loaded),
                           or the path of the weights file itself
        :param use_gpu: run inference on the GPU; asserts that CUDA is available
        """
        if use_gpu:
            assert (torch.cuda.is_available()), 'GPU不可用'
            self.device = torch.device("cuda")
        else:
            # Hide every GPU from this process so nothing is placed on one by accident.
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
            self.device = torch.device("cpu")
        # Number of nearest library entries considered during retrieval.
        self.cdd_num = 5
        self.threshold = threshold
        # Read the configuration file.
        if isinstance(configs, str):
            with open(configs, 'r', encoding='utf-8') as f:
                # NOTE(review): FullLoader can build arbitrary Python objects; only load
                # trusted config files (yaml.safe_load would be stricter).
                configs = yaml.load(f.read(), Loader=yaml.FullLoader)
        self.configs = dict_to_object(configs)
        assert 'max_duration' in self.configs.dataset_conf, \
            '【警告】,您貌似使用了旧的配置文件,如果你同时使用了旧的模型,这是错误的,请重新下载或者重新训练,否则只能回滚代码。'
        assert self.configs.use_model in SUPPORT_MODEL, f'没有该模型:{self.configs.use_model}'
        self._audio_featurizer = AudioFeaturizer(feature_conf=self.configs.feature_conf, **self.configs.preprocess_conf)
        self._audio_featurizer.to(self.device)
        # Build the backbone network; only EcapaTdnn is supported here.
        if self.configs.use_model in ('EcapaTdnn', 'ecapa_tdnn'):
            backbone = EcapaTdnn(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
        else:
            raise Exception(f'{self.configs.use_model} 模型不存在!')
        model = SpeakerIdetification(backbone=backbone, num_class=self.configs.dataset_conf.num_speakers)
        model.to(self.device)
        # Load the weights directly onto the target device.
        if os.path.isdir(model_path):
            model_path = os.path.join(model_path, 'model.pt')
        assert os.path.exists(model_path), f"{model_path} 模型不存在!"
        model_state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(model_state_dict)
        logger.info(f"成功加载模型参数:{model_path}")
        # Evaluation mode; only the backbone (embedding extractor) is used for inference.
        model.eval()
        self.predictor = model.backbone
        # Voiceprint library: registered embeddings and the user name for each row.
        self.audio_feature = None
        self.users_name = []

    def _load_audio(self, audio_data, sample_rate=16000):
        """Load audio and apply the configured preprocessing.

        :param audio_data: a file path, a ``BufferedReader`` file object, raw bytes of a
                           complete audio file, or a numpy array of samples
        :param sample_rate: sample rate of ``audio_data``; only used for numpy input
        :return: the preprocessed ``AudioSegment`` (resampled and optionally dB-normalized)
        :raises Exception: if ``audio_data`` has an unsupported type
        """
        if isinstance(audio_data, (str, BufferedReader)):
            audio_segment = AudioSegment.from_file(audio_data)
        elif isinstance(audio_data, np.ndarray):
            audio_segment = AudioSegment.from_ndarray(audio_data, sample_rate)
        elif isinstance(audio_data, bytes):
            audio_segment = AudioSegment.from_bytes(audio_data)
        else:
            raise Exception(f'不支持该数据类型,当前数据类型为:{type(audio_data)}')
        assert audio_segment.duration >= self.configs.dataset_conf.min_duration, \
            f'音频太短,最小应该为{self.configs.dataset_conf.min_duration}s,当前音频为{audio_segment.duration}s'
        # Resample to the rate the model was trained on.
        if audio_segment.sample_rate != self.configs.dataset_conf.sample_rate:
            audio_segment.resample(self.configs.dataset_conf.sample_rate)
        # Decibel normalization.
        if self.configs.dataset_conf.use_dB_normalization:
            audio_segment.normalize(target_db=self.configs.dataset_conf.target_dB)
        return audio_segment

    @torch.no_grad()
    def predict(self,
                audio_data,
                sample_rate=16000):
        """Extract the voiceprint embedding of one audio input.

        :param audio_data: file path, file object, bytes of a complete audio file, or numpy array
        :param sample_rate: sample rate of ``audio_data``; only used for numpy input
        :return: the embedding as a 1-D numpy array
        """
        audio_segment = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
        input_data = torch.tensor(audio_segment.samples, dtype=torch.float32, device=self.device).unsqueeze(0)
        # A single unpadded utterance, so its length ratio is 1.
        input_len_ratio = torch.tensor([1], dtype=torch.float32, device=self.device)
        audio_feature, _ = self._audio_featurizer(input_data, input_len_ratio)
        feature = self.predictor(audio_feature).cpu().numpy()[0]
        return feature

    @torch.no_grad()
    def predict_batch(self, audios_data, sample_rate=16000):
        """Extract voiceprint embeddings for a batch of audio inputs.

        :param audios_data: iterable of file paths, file objects, complete audio bytes, or numpy arrays
        :param sample_rate: sample rate of numpy inputs
        :return: numpy array of embeddings, one row per input, in input order
        :raises ValueError: if ``audios_data`` is empty
        """
        samples = [self._load_audio(audio_data=audio_data, sample_rate=sample_rate).samples
                   for audio_data in audios_data]
        if not samples:
            raise ValueError('audios_data不能为空')
        max_audio_length = max(s.shape[0] for s in samples)
        # Zero-pad every clip to the longest one and record the real-length ratio.
        inputs = np.zeros((len(samples), max_audio_length), dtype='float32')
        input_lens_ratio = []
        for i, s in enumerate(samples):
            seq_length = s.shape[0]
            inputs[i, :seq_length] = s
            input_lens_ratio.append(seq_length / max_audio_length)
        inputs = torch.tensor(inputs, dtype=torch.float32, device=self.device)
        input_lens_ratio = torch.tensor(input_lens_ratio, dtype=torch.float32, device=self.device)
        audio_feature, _ = self._audio_featurizer(inputs, input_lens_ratio)
        features = self.predictor(audio_feature).cpu().numpy()
        return features

    def contrast(self, audio_data1, audio_data2, sample_rate=16000):
        """Compare two audio inputs by the cosine similarity of their voiceprints.

        :param audio_data1: first audio input (any type accepted by ``predict``)
        :param audio_data2: second audio input (any type accepted by ``predict``)
        :param sample_rate: sample rate of numpy inputs
        :return: cosine similarity of the two embeddings
        """
        feature1 = self.predict(audio_data1, sample_rate=sample_rate)
        feature2 = self.predict(audio_data2, sample_rate=sample_rate)
        return self.compare(feature1, feature2)

    def __retrieval(self, np_feature, threshold=None):
        """Match query embeddings against the voiceprint library.

        For each query, the ``cdd_num`` most similar library rows (by cosine similarity)
        are taken, rows below the threshold are dropped, and the most frequent remaining
        user name wins (ties go to the name with the most similar row).

        :param np_feature: iterable of 1-D query embeddings
        :param threshold: similarity threshold; ``self.threshold`` is used when None
        :return: list with one entry per query, a user name or None when nothing matched
        """
        threshold = self.threshold if threshold is None else threshold
        users = getattr(self, 'users_name', None) or []
        if self.audio_feature is None or len(users) == 0:
            return [None for _ in np_feature]
        library = np.atleast_2d(np.asarray(self.audio_feature, dtype=np.float32))
        # Normalize rows once so a dot product gives the cosine similarity.
        library = library / np.maximum(np.linalg.norm(library, axis=1, keepdims=True), 1e-12)
        top_k = min(self.cdd_num, library.shape[0])
        labels = []
        for feature in np_feature:
            query = np.asarray(feature, dtype=np.float32).reshape(-1)
            query = query / max(float(np.linalg.norm(query)), 1e-12)
            similarity = library @ query
            # Candidate rows ordered from most to least similar.
            candidates = np.argsort(-similarity, kind='stable')[:top_k]
            names = [users[i] for i in candidates if similarity[i] >= threshold]
            labels.append(max(names, key=names.count) if names else None)
        return labels

    def recognition(self, audio_data, threshold=None, sample_rate=16000):
        """Identify which registered user an audio input belongs to.

        :param audio_data: file path, file object, bytes of a complete audio file, or numpy array
        :param threshold: similarity threshold for this call only; when None the threshold
                          given at construction is used (``self.threshold`` is not modified)
        :param sample_rate: sample rate of ``audio_data``; only used for numpy input
        :return: the recognized user name, or None if no user was recognized
        """
        feature = self.predict(audio_data, sample_rate=sample_rate)
        name = self.__retrieval(np_feature=[feature], threshold=threshold)[0]
        return name

    def compare(self, feature1, feature2):
        """Compare two voiceprint embeddings.

        :param feature1: first embedding (1-D array)
        :param feature2: second embedding (1-D array)
        :return: cosine similarity of the two embeddings
        """
        dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
        return dist