first commit
commit 9dbcf5c730

97  README.md  Normal file
@@ -0,0 +1,97 @@
# Preface

Environment:

- Anaconda 3
- Python 3.8
- PyTorch 1.13.1
- Windows 10 or Ubuntu 18.04

# Features

1. Supported models: EcapaTdnn, TDNN, Res2Net, ResNetSE
2. Supported pooling layers: AttentiveStatsPool (ASP), SelfAttentivePooling (SAP), TemporalStatisticsPooling (TSP), TemporalAveragePooling (TAP)
3. Supported loss functions: AAMLoss, AMLoss, ARMLoss, CELoss
4. Supported preprocessing methods: MelSpectrogram, Spectrogram, MFCC

## Environment Setup

- First install the GPU version of PyTorch; if it is already installed, skip this step.

```shell
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
```

- Install the mvector library.

Install with pip using the following command:

```shell
python -m pip install mvector -U -i https://pypi.tuna.tsinghua.edu.cn/simple
```
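
Before moving on, it is worth confirming that the GPU build is actually active. A minimal sanity check, assuming the installs above succeeded:

```python
# Sanity check: verify PyTorch sees the GPU and mvector is importable.
import torch
import mvector

print(torch.__version__)            # expect 1.13.1
print(torch.cuda.is_available())    # expect True on a CUDA machine
print(mvector.__version__)          # expect 0.3.9 for this repo
```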

# Usage Guide

## 1. Environment Setup

### 1.1 Install Dependencies

```shell
# Create a conda environment (optional)
conda create -n voiceprint python=3.8
conda activate voiceprint

# Install project dependencies
pip install -r requirements.txt
```

### 1.2 Prepare Audio Data

- Place enrollment audio in the `audio_db/` directory (16 kHz mono WAV format recommended)
- Place test audio in the `test_audio/` directory (a conversion sketch follows below)
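
If your recordings are not already 16 kHz mono WAV, a small helper can normalize them first. This is a minimal sketch built on soundfile and resampy (both already dependencies of this project); the file names are placeholders:

```python
# Convert an arbitrary audio file to 16 kHz mono WAV for enrollment.
import numpy as np
import resampy
import soundfile

def to_16k_mono(src_path, dst_path, target_sr=16000):
    samples, sr = soundfile.read(src_path, dtype='float32')
    if samples.ndim > 1:                 # average channels down to mono
        samples = np.mean(samples, axis=1)
    if sr != target_sr:                  # resample to the target rate
        samples = resampy.resample(samples, sr, target_sr)
    soundfile.write(dst_path, samples, target_sr, format='WAV', subtype='PCM_16')

to_16k_mono('raw/user1.flac', 'audio_db/user1.wav')
```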

## 2. Core Features

### 2.1 Train a Voiceprint Model

```shell
python train.py \
    --config_path configs/ecapa_tdnn.yml \
    --augmentation_config configs/augmentation.json \
    --save_dir models/
```

### 2.2 Enroll a Voiceprint

```python
from mvector import MVector

mvector = MVector()
mvector.register_user(name="user1", audio_path="audio_db/user1.wav")
```

### 2.3 Real-Time Voiceprint Recognition

```shell
python infer_recognition.py \
    --model_path models/ecapa_tdnn.pth \
    --audio_path test_audio/unknown.wav
```

### 2.4 Voiceprint Comparison

```shell
python infer_contrast.py \
    --audio1 audio_db/user1.wav \
    --audio2 test_audio/sample.wav \
    --threshold 0.7
```

## 3. Noise-Reduction Preprocessing

```python
from Reduction_Noise import NoiseReducer

reducer = NoiseReducer("Reduction_Noise/pytorch_model.bin")
clean_audio = reducer.process("noisy_audio.wav")
```

## 4. Model Evaluation

```shell
python eval.py \
    --model_path models/ecapa_tdnn.pth \
    --test_csv eval_samples.csv \
    --batch_size 32
```
1  Reduction_Noise  Submodule
@@ -0,0 +1 @@
Subproject commit cfc4b6a2433a4a6f0d2d1cda5f944d677e072ef4
BIN  audio_db/output1.wav  Normal file  (binary file not shown)
BIN  audio_db/output2.wav  Normal file  (binary file not shown)
BIN  audio_db/test.wav  Normal file  (binary file not shown)
BIN  audio_db/test_Re.wav  Normal file  (binary file not shown)
72  configs/augmentation.json  Normal file
@@ -0,0 +1,72 @@

[
    {
        "type": "noise",
        "aug_type": "audio",
        "params": {
            "min_snr_dB": 10,
            "max_snr_dB": 50,
            "repetition": 2,
            "noise_dir": "dataset/noise/"
        },
        "prob": 0.0
    },
    {
        "type": "resample",
        "aug_type": "audio",
        "params": {
            "new_sample_rate": [8000, 32000, 44100, 48000]
        },
        "prob": 0.0
    },
    {
        "type": "speed",
        "aug_type": "audio",
        "params": {
            "min_speed_rate": 0.9,
            "max_speed_rate": 1.1,
            "num_rates": 3
        },
        "prob": 0.5
    },
    {
        "type": "shift",
        "aug_type": "audio",
        "params": {
            "min_shift_ms": -5,
            "max_shift_ms": 5
        },
        "prob": 0.0
    },
    {
        "type": "volume",
        "aug_type": "audio",
        "params": {
            "min_gain_dBFS": -15,
            "max_gain_dBFS": 15
        },
        "prob": 0.5
    },
    {
        "type": "specaug",
        "aug_type": "feature",
        "params": {
            "inplace": true,
            "max_time_warp": 5,
            "max_t_ratio": 0.01,
            "n_freq_masks": 2,
            "max_f_ratio": 0.05,
            "n_time_masks": 2,
            "replace_with_zero": true
        },
        "prob": 0.5
    },
    {
        "type": "specsub",
        "aug_type": "feature",
        "params": {
            "max_t": 10,
            "num_t_sub": 2
        },
        "prob": 0.0
    }
]
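
All seven augmentors are declared here, but only those with a non-zero `prob` will ever fire at training time. A minimal sketch for checking which ones are active:

```python
# List which augmentations are active (prob > 0) in the config.
import json

with open('configs/augmentation.json', 'r', encoding='utf-8') as f:
    configs = json.load(f)

for conf in configs:
    state = 'active' if conf['prob'] > 0 else 'disabled'
    print(f"{conf['type']:<10} ({conf['aug_type']}): prob={conf['prob']} -> {state}")
```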
54  configs/ecapa_tdnn.yml  Normal file
@@ -0,0 +1,54 @@

# Dataset parameters
dataset_conf:
  # Training batch size
  batch_size: 256
  # Number of speakers, i.e. the classification size
  num_speakers: 3242
  # Number of data-loading workers
  num_workers: 12
  # Minimum audio duration; shorter clips are filtered out
  min_duration: 0.5
  # Maximum audio duration; longer clips are cropped
  max_duration: 6
  # Whether to trim silent segments
  do_vad: False
  # Audio sample rate
  sample_rate: 16000
  # Whether to apply volume normalization to the audio
  use_dB_normalization: False
  # Target volume in dB for normalization
  target_dB: -20
  # Path to the training data list
  train_list: 'dataset/train_list.txt'
  # Path to the test data list
  test_list: 'dataset/test_list.txt'
  # Label list
  label_list_path: 'dataset/label_list.txt'

# Preprocessing parameters
preprocess_conf:
  # Audio preprocessing method; supports MelSpectrogram, Spectrogram, MFCC, Fbank
  feature_method: 'Fbank'

feature_conf:
  sample_frequency: 16000
  num_mel_bins: 80

optimizer_conf:
  # Optimizer; supports Adam, AdamW, SGD
  optimizer: 'Adam'
  # Initial learning rate
  learning_rate: 0.001
  weight_decay: 1e-6

model_conf:
  embd_dim: 192
  channels: 512

train_conf:
  # Number of training epochs
  max_epoch: 30
  log_interval: 100

# Model to use
use_model: 'ecapa_tdnn'
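
To poke at these values programmatically, a minimal sketch (assuming PyYAML is available, which parsing this file requires anyway):

```python
# Load the training config and inspect a couple of fields.
import yaml

with open('configs/ecapa_tdnn.yml', 'r', encoding='utf-8') as f:
    configs = yaml.safe_load(f)

print(configs['use_model'])                    # 'ecapa_tdnn'
print(configs['dataset_conf']['batch_size'])   # 256
```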
104  create_data.py  Normal file
@@ -0,0 +1,104 @@

import json
import os
import time
from multiprocessing import Pool, cpu_count
from datetime import timedelta

from pydub import AudioSegment


# Build the data list
def get_data_list(infodata_path, zhvoice_path):
    print('Reading the annotation file...')
    with open(infodata_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    data = []
    speakers = []
    speakers_dict = {}
    for line in lines:
        line = json.loads(line.replace('\n', ''))
        duration_ms = line['duration_ms']
        # Skip clips shorter than 1.3 seconds
        if duration_ms < 1300:
            continue
        speaker = line['speaker']
        if speaker not in speakers:
            speakers_dict[speaker] = len(speakers)
            speakers.append(speaker)
        label = speakers_dict[speaker]
        sound_path = os.path.join(zhvoice_path, line['index'])
        data.append([sound_path.replace('\\', '/'), label])
    print(f'Found {len(data)} samples in total.')
    return data


def mp32wav(num, data_list):
    start = time.time()
    for i, data in enumerate(data_list):
        sound_path, label = data
        if os.path.exists(sound_path):
            save_path = sound_path.replace('.mp3', '.wav')
            if not os.path.exists(save_path):
                wav = AudioSegment.from_mp3(sound_path)
                wav.export(save_path, format="wav")
                os.remove(sound_path)
        if i % 100 == 0:
            eta_sec = ((time.time() - start) / 100 * (len(data_list) - i))
            start = time.time()
            eta_str = str(timedelta(seconds=int(eta_sec)))
            print(f'Worker {num} progress: [{i}/{len(data_list)}], time remaining: {eta_str}')


def split_data(list_temp, n):
    # Guard against n > len(list_temp), which would make the step size zero
    length = max(1, len(list_temp) // n)
    for i in range(0, len(list_temp), length):
        yield list_temp[i:i + length]


def main(infodata_path, list_path, zhvoice_path, to_wav=True, num_workers=2):
    if to_wav:
        text = input('Audio files will be converted to WAV format. This can take a long time, and the result is close to 100 GB. Continue? (y/n)')
        if text is None or text != 'y':
            return
    else:
        text = input('MP3 files will be used directly, but reading them is slower than WAV. Continue? (y/n)')
        if text is None or text != 'y':
            return
    data_all = []
    data = get_data_list(infodata_path=infodata_path, zhvoice_path=zhvoice_path)
    if to_wav:
        print('Converting MP3 files to WAV format...')
        split_d = split_data(data, num_workers)
        pool = Pool(num_workers)
        for i, d in enumerate(split_d):
            pool.apply_async(mp32wav, (i, d))
        pool.close()
        pool.join()
        for d in data:
            sound_path, label = d
            sound_path = sound_path.replace('.mp3', '.wav')
            if os.path.exists(sound_path):
                data_all.append([sound_path, label])
    else:
        for d in data:
            sound_path, label = d
            if os.path.exists(sound_path):
                data_all.append(d)
    f_train = open(os.path.join(list_path, 'train_list.txt'), 'w')
    f_test = open(os.path.join(list_path, 'test_list.txt'), 'w')
    # Every 200th sample goes to the test list; the rest go to the training list
    for i, d in enumerate(data_all):
        sound_path, label = d
        if i % 200 == 0:
            f_test.write(f'{sound_path}\t{label}\n')
        else:
            f_train.write(f'{sound_path}\t{label}\n')
    f_test.close()
    f_train.close()


if __name__ == '__main__':
    main(infodata_path='dataset/zhvoice/text/infodata.json',
         list_path='dataset',
         zhvoice_path='dataset/zhvoice',
         to_wav=False,
         num_workers=cpu_count())
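
For reference, the generated `train_list.txt` / `test_list.txt` are plain tab-separated files, one `path<TAB>label` pair per line; the paths below are illustrative only:

```
dataset/zhvoice/speaker_a/clip_0001.wav	0
dataset/zhvoice/speaker_a/clip_0002.wav	0
dataset/zhvoice/speaker_b/clip_0001.wav	1
```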
25  eval.py  Normal file
@@ -0,0 +1,25 @@

import argparse
import functools
import time

from mvector.trainer import MVectorTrainer
from mvector.utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs',          str,  'configs/ecapa_tdnn.yml',             "Configuration file")
add_arg("use_gpu",          bool, True,                                 "Whether to evaluate the model on GPU")
add_arg('save_image_path',  str,  'output/images/',                     "Path for saving result plots")
add_arg('resume_model',     str,  'models/ecapa_tdnn_MFCC/best_model/', "Path to the model")
args = parser.parse_args()
print_arguments(args=args)

# Create the trainer
trainer = MVectorTrainer(configs=args.configs, use_gpu=args.use_gpu)

# Run evaluation
start = time.time()
tpr, fpr, eer, threshold = trainer.evaluate(resume_model=args.resume_model, save_image_path=args.save_image_path)
end = time.time()
print('Evaluation took {}s, threshold: {:.2f}, tpr: {:.5f}, fpr: {:.5f}, eer: {:.5f}'
      .format(int(end - start), threshold, tpr, fpr, eer))
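
The equal error rate reported here is the operating point where the false acceptance and false rejection rates meet. A minimal numpy sketch of that computation (the score/label arrays are placeholders, not this trainer's internals):

```python
# Compute an equal error rate (EER) from similarity scores and trial labels.
import numpy as np

def compute_eer(scores, labels):
    """scores: similarity per trial; labels: 1 = same speaker, 0 = different."""
    thresholds = np.sort(np.unique(scores))
    best_diff, eer, best_thr = np.inf, 1.0, thresholds[0]
    for thr in thresholds:
        far = np.mean(scores[labels == 0] >= thr)   # false acceptance rate
        frr = np.mean(scores[labels == 1] < thr)    # false rejection rate
        if abs(far - frr) < best_diff:
            best_diff, eer, best_thr = abs(far - frr), (far + frr) / 2, thr
    return eer, best_thr

scores = np.array([0.91, 0.45, 0.78, 0.30, 0.66])
labels = np.array([1, 0, 1, 0, 1])
print(compute_eer(scores, labels))
```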
52  infer_contrast.py  Normal file
@@ -0,0 +1,52 @@

import argparse
import functools
import itertools

from mvector.predict import MVectorPredictor
from mvector.utils.utils import add_arguments, print_arguments
from mvector.data_utils.audio import AudioSegment

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs',      str,   'configs/ecapa_tdnn.yml',            'Configuration file')
add_arg('use_gpu',      bool,  True,                                'Whether to predict on GPU')
add_arg('audio_path1',  str,   'dataset/source/王翔/wx1_5.wav',      'First audio file to compare')
add_arg('audio_path2',  str,   'dataset/source/刘云杰/lyj_no_5.wav', 'Second audio file to compare')
add_arg('threshold',    float, 0.7,                                 'Threshold for deciding whether two audios are the same person')
add_arg('model_path',   str,   'models/test_model',                 'Path to the exported prediction model')
args = parser.parse_args()
# print_arguments(args=args)


# Create the predictor
predictor = MVectorPredictor(configs=args.configs,
                             model_path=args.model_path,
                             use_gpu=args.use_gpu)


def load_audio_paths(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file if line.strip()]


def compare_audio_files(audio_paths, threshold):
    # itertools.combinations yields every unordered pair exactly once
    for audio1, audio2 in itertools.combinations(audio_paths, 2):
        dist = predictor.contrast(audio1, audio2)
        if dist > threshold:
            print(f"{audio1} and {audio2} are the same person, similarity: {dist}")
        # else:
        #     print(f"{audio1} and {audio2} are not the same person, similarity: {dist}")

file_path = 'dataset/ces.txt'  # assumes the audio paths are stored in this file

# Run the comparison
audio_paths = load_audio_paths(file_path)
compare_audio_files(audio_paths, args.threshold)
# # AudioSegment.silent_semoval(args.audio_path1, args.audio_path1)
# # AudioSegment.silent_semoval(args.audio_path2, args.audio_path2)
# dist = predictor.contrast(args.audio_path1, args.audio_path2)
# if dist > args.threshold:
#     print(f"{args.audio_path1} and {args.audio_path2} are the same person, similarity: {dist}")
# else:
#     print(f"{args.audio_path1} and {args.audio_path2} are not the same person, similarity: {dist}")
49  infer_recognition.py  Normal file
@@ -0,0 +1,49 @@

import argparse
import functools

from mvector.predict import MVectorPredictor
from mvector.utils.record import RecordAudio
from mvector.utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs',        str,   'configs/ecapa_tdnn.yml', 'Configuration file')
add_arg('use_gpu',        bool,  True,                     'Whether to predict on GPU')
add_arg('audio_db_path',  str,   'audio_db/',              'Path to the audio database')
add_arg('record_seconds', int,   3,                        'Recording length')
add_arg('threshold',      float, 0.6,                      'Threshold for deciding whether two audios are the same person')
add_arg('model_path',     str,   'models/ecapa_tdnn_MelSpectrogram/best_model/', 'Path to the exported prediction model')
args = parser.parse_args()
print_arguments(args=args)

# Create the predictor
predictor = MVectorPredictor(configs=args.configs,
                             threshold=args.threshold,
                             audio_db_path=args.audio_db_path,
                             model_path=args.model_path,
                             use_gpu=args.use_gpu)

record_audio = RecordAudio()

while True:
    select_fun = int(input("Select a function: 0 to enroll audio into the voiceprint database, 1 to run voiceprint recognition, 2 to delete a user: "))
    if select_fun == 0:
        input(f"Press Enter to start a {args.record_seconds}-second recording:")
        audio_data = record_audio.record(record_seconds=args.record_seconds)
        name = input("Enter a name for this audio's user: ")
        if name == '': continue
        predictor.register(user_name=name, audio_data=audio_data, sample_rate=record_audio.sample_rate)
    elif select_fun == 1:
        input(f"Press Enter to start a {args.record_seconds}-second recording:")
        audio_data = record_audio.record(record_seconds=args.record_seconds)
        name = predictor.recognition(audio_data, sample_rate=record_audio.sample_rate)
        if name:
            print(f"Recognized speaker: {name}")
        else:
            print("No speaker recognized; they may not be enrolled.")
    elif select_fun == 2:
        name = input("Enter the name of the user to delete: ")
        if name == '': continue
        predictor.remove_user(user_name=name)
    else:
        print('Please select a valid function.')
40  main.py  Normal file
@@ -0,0 +1,40 @@

from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from mvector.predict import MVectorPredictor

app = FastAPI()

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

predictor = MVectorPredictor(configs='configs/ecapa_tdnn.yml')

@app.post("/recognize")
async def recognize(file: UploadFile = File(...)):
    try:
        audio_bytes = await file.read()
        result = predictor.recognition(audio_bytes)
        return {"status": 200, "data": result}
    except Exception as e:
        return {"status": 500, "error": str(e)}

@app.post("/compare")
async def compare(file1: UploadFile = File(...), file2: UploadFile = File(...)):
    try:
        score = predictor.contrast(
            await file1.read(),
            await file2.read()
        )
        return {"similarity": float(score)}
    except Exception as e:
        return {"status": 500, "error": str(e)}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=9001)
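
Once the service is running, the endpoints can be exercised from any HTTP client. A minimal sketch using the requests library (the file paths are placeholders taken from this repo's audio_db/):

```python
# Call the /recognize and /compare endpoints of the running service.
import requests

base = "http://127.0.0.1:9001"

with open("audio_db/test.wav", "rb") as f:
    r = requests.post(f"{base}/recognize", files={"file": f})
print(r.json())

with open("audio_db/output1.wav", "rb") as f1, open("audio_db/output2.wav", "rb") as f2:
    r = requests.post(f"{base}/compare", files={"file1": f1, "file2": f2})
print(r.json())
```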
BIN  models/.DS_Store  vendored  (binary file not shown)
BIN  models/pytorch_model.bin  Normal file  (binary file not shown)
3  mvector/__init__.py  Normal file
@@ -0,0 +1,3 @@

__version__ = "0.3.9"
# Models supported by this project
SUPPORT_MODEL = ['ecapa_tdnn', 'EcapaTdnn', 'Res2Net', 'ResNetSE', 'TDNN']
BIN  mvector/__pycache__/__init__.cpython-311.pyc  Normal file  (binary file not shown)
BIN  mvector/__pycache__/__init__.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/__pycache__/predict.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/__pycache__/trainer.cpython-311.pyc  Normal file  (binary file not shown)
BIN  mvector/__pycache__/trainer.cpython-37.pyc  Normal file  (binary file not shown)
0  mvector/data_utils/__init__.py  Normal file  (empty)
BIN  mvector/data_utils/__pycache__/__init__.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/__pycache__/audio.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/__pycache__/collate_fn.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/__pycache__/featurizer.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/__pycache__/reader.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/__pycache__/utils.cpython-37.pyc  Normal file  (binary file not shown)
565  mvector/data_utils/audio.py  Normal file
@@ -0,0 +1,565 @@

import copy
import io
import os
import random

import numpy as np
import resampy
import soundfile

from mvector.data_utils.utils import buf_to_float, vad, decode_audio


class AudioSegment(object):
    """Monaural audio segment abstraction.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate):
        """Create audio segment from samples.

        Samples are converted to float32 internally, with int scaled to [-1, 1].
        """
        self._samples = self._convert_samples_to_float32(samples)
        self._sample_rate = sample_rate
        if self._samples.ndim >= 2:
            # Mix multi-channel audio down to mono
            self._samples = np.mean(self._samples, 1)

    def __eq__(self, other):
        """Return whether two instances are equal."""
        if type(other) is not type(self):
            return False
        if self._sample_rate != other._sample_rate:
            return False
        if self._samples.shape != other._samples.shape:
            return False
        if np.any(self.samples != other._samples):
            return False
        return True

    def __ne__(self, other):
        """Return whether two instances are not equal."""
        return not self.__eq__(other)

    def __str__(self):
        """Return a description of this audio segment."""
        return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
                "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate, self.duration, self.rms_db))

    @classmethod
    def from_file(cls, file):
        """Create an audio segment from an audio file.

        :param file: File path, or a file object.
        :type file: str, BufferedReader
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        assert os.path.exists(file), f'File does not exist, please check the path: {file}'
        try:
            samples, sample_rate = soundfile.read(file, dtype='float32')
        except Exception:
            # Fall back to decode_audio to support more formats
            sample_rate = 16000
            samples = decode_audio(file=file, sample_rate=sample_rate)
        return cls(samples, sample_rate)

    @classmethod
    def slice_from_file(cls, file, start=None, end=None):
        """Load only a slice of the audio, without loading the whole file into
        memory, which would be wasteful.

        :param file: Input audio file path or file object.
        :type file: str|file
        :param start: Start time in seconds. Negative values count from the end.
                      If not provided, reading starts from the very beginning.
        :type start: float
        :param end: End time in seconds. Negative values count from the end.
                    If not provided, the default is to read to the end of the file.
        :type end: float
        :return: AudioSegment instance of the specified slice of the input file.
        :rtype: AudioSegment
        :raise ValueError: If start or end is set incorrectly, e.g. out of bounds in time.
        """
        sndfile = soundfile.SoundFile(file)
        sample_rate = sndfile.samplerate
        duration = round(float(len(sndfile)) / sample_rate, 3)
        start = 0. if start is None else round(start, 3)
        end = duration if end is None else round(end, 3)
        # Negative positions count from the end
        if start < 0.0: start += duration
        if end < 0.0: end += duration
        # Keep positions within bounds
        if start < 0.0: start = 0.0
        if end > duration: end = duration
        if end < 0.0:
            raise ValueError("Slice end position (%f s) out of bounds" % end)
        if start > end:
            raise ValueError("Slice start position (%f s) is later than slice end position (%f s)" % (start, end))
        start_frame = int(start * sample_rate)
        end_frame = int(end * sample_rate)
        sndfile.seek(start_frame)
        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return cls(data, sample_rate)

    @classmethod
    def from_bytes(cls, data):
        """Create an audio segment from bytes containing audio samples.

        :param data: Bytes containing audio samples.
        :type data: bytes
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        samples, sample_rate = soundfile.read(io.BytesIO(data), dtype='float32')
        return cls(samples, sample_rate)

    @classmethod
    def from_pcm_bytes(cls, data, channels=1, samp_width=2, sample_rate=16000):
        """Create an audio segment from bytes of raw, headerless PCM audio.

        :param data: Bytes containing audio samples.
        :type data: bytes
        :param channels: Number of audio channels.
        :type channels: int
        :param samp_width: Sample width in bytes, e.g. 2 for np.int16.
        :type samp_width: int
        :param sample_rate: Audio sample rate.
        :type sample_rate: int
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        samples = buf_to_float(data, n_bytes=samp_width)
        if channels > 1:
            samples = samples.reshape(-1, channels)
        return cls(samples, sample_rate)

    @classmethod
    def from_ndarray(cls, data, sample_rate=16000):
        """Create an audio segment from a numpy.ndarray.

        :param data: Audio data as a numpy.ndarray.
        :type data: ndarray
        :param sample_rate: Audio sample rate.
        :type sample_rate: int
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        return cls(data, sample_rate)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of audio segments together.

        :param *segments: Input audio segments to concatenate.
        :type *segments: tuple of AudioSegment
        :return: Audio segment instance as concatenating results.
        :rtype: AudioSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any segments does not match.
        :raises TypeError: If any segment is not AudioSegment instance.
        """
        # Perform basic sanity-checks.
        if len(segments) == 0:
            raise ValueError("No audio segments were given to concatenate")
        sample_rate = segments[0]._sample_rate
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Cannot concatenate segments with different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only audio segments of the same type can be concatenated")
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Create a silent audio segment of the given duration and sample rate.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Audio sample rate.
        :type sample_rate: float
        :return: Silent AudioSegment instance of the given duration.
        :rtype: AudioSegment
        """
        samples = np.zeros(int(duration * sample_rate))
        return cls(samples, sample_rate)

    def to_wav_file(self, filepath, dtype='float32'):
        """Save the audio segment to disk as a wav file.

        :param filepath: WAV file path or file object to save the audio segment to.
        :type filepath: str|file
        :param dtype: Subtype for audio file. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :raises TypeError: If dtype is not supported.
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        subtype_map = {
            'int16': 'PCM_16',
            'int32': 'PCM_32',
            'float32': 'FLOAT',
            'float64': 'DOUBLE'
        }
        soundfile.write(
            filepath,
            samples,
            self._sample_rate,
            format='WAV',
            subtype=subtype_map[dtype])

    def superimpose(self, other):
        """Add the samples of another segment to this segment's samples
        (sample-wise addition, not segment concatenation).

        :param other: Segment whose samples are added in.
        :type other: AudioSegment
        :raise TypeError: If the types of the two segments do not match.
        :raise ValueError: If the sample rates or lengths do not match.
        """
        if not isinstance(other, type(self)):
            raise TypeError("Cannot add segments of different types: %s and %s" % (type(self), type(other)))
        if self._sample_rate != other._sample_rate:
            raise ValueError("Sample rates must match to add segments")
        if len(self._samples) != len(other._samples):
            raise ValueError("Segment lengths must match to add segments")
        self._samples += other._samples

    def to_bytes(self, dtype='float32'):
        """Create a byte string containing the audio content.

        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :return: Byte string containing audio content.
        :rtype: bytes
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        # tostring() was removed from NumPy; tobytes() is the equivalent
        return samples.tobytes()

    def to(self, dtype='int16'):
        """Type conversion.

        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'int16'.
        :type dtype: str
        :return: np.ndarray containing `dtype` audio content.
        :rtype: ndarray
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples

    def gain_db(self, gain):
        """Apply gain in decibels to the audio.

        Note that this is an in-place transformation.

        :param gain: Gain in decibels to apply to samples.
        :type gain: float|1darray
        """
        self._samples *= 10. ** (gain / 20.)

    def change_speed(self, speed_rate):
        """Change audio speed by linear interpolation.

        :param speed_rate: Rate of speed change:
                           speed_rate > 1.0, speed up the audio;
                           speed_rate = 1.0, unchanged;
                           speed_rate < 1.0, slow down the audio;
                           speed_rate <= 0.0, not allowed, raise ValueError.
        :type speed_rate: float
        :raises ValueError: If speed_rate <= 0.0.
        """
        if speed_rate == 1.0:
            return
        if speed_rate <= 0:
            raise ValueError("Speed rate should be greater than zero")
        old_length = self._samples.shape[0]
        new_length = int(old_length / speed_rate)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
        self._samples = np.interp(new_indices, old_indices, self._samples).astype(np.float32)

    def normalize(self, target_db=-20, max_gain_db=300.0):
        """Normalize the audio to the desired RMS value in decibels.

        :param target_db: Target RMS value in decibels. This value should be
                          less than 0.0 as 0.0 is full-scale audio.
        :type target_db: float
        :param max_gain_db: Max amount of gain in dB that can be applied for
                            normalization. This is to prevent nans when
                            attempting to normalize a signal consisting of
                            all zeros.
        :type max_gain_db: float
        :raises ValueError: If the required gain to normalize the segment to
                            the target_db value exceeds max_gain_db.
        """
        if -np.inf == self.rms_db: return
        gain = target_db - self.rms_db
        if gain > max_gain_db:
            raise ValueError(
                "Unable to normalize segment to %f dB because the required gain exceeds max_gain_db (%f dB)" % (target_db, max_gain_db))
        self.gain_db(gain)

    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.

        Note that this is an in-place transformation.

        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use one of {'kaiser_best', 'kaiser_fast'}.
        :type filter: str
        """
        self._samples = resampy.resample(self.samples, self.sample_rate, target_sample_rate, filter=filter)
        self._sample_rate = target_sample_rate

    def pad_silence(self, duration, sides='both'):
        """Pad this audio sample with a period of silence.

        Note that this is an in-place transformation.

        :param duration: Length of silence in seconds to pad.
        :type duration: float
        :param sides: Position for padding:
                      'beginning' - adds silence in the beginning;
                      'end' - adds silence in the end;
                      'both' - adds silence in both the beginning and the end.
        :type sides: str
        :raises ValueError: If sides is not supported.
        """
        if duration == 0.0:
            return self
        cls = type(self)
        silence = self.make_silence(duration, self._sample_rate)
        if sides == "beginning":
            padded = cls.concatenate(silence, self)
        elif sides == "end":
            padded = cls.concatenate(self, silence)
        elif sides == "both":
            padded = cls.concatenate(silence, self, silence)
        else:
            raise ValueError("Unknown value for the sides %s" % sides)
        self._samples = padded._samples

    def shift(self, shift_ms):
        """Shift the audio in time. If shift_ms is positive, shift with time
        advance; if negative, shift with time delay. Silence is padded to keep
        the duration unchanged.

        Note that this is an in-place transformation.

        :param shift_ms: Shift time in milliseconds. If positive, shift with
                         time advance; if negative, shift with time delay.
        :type shift_ms: float
        :raises ValueError: If shift_ms is longer than audio duration.
        """
        if abs(shift_ms) / 1000.0 > self.duration:
            raise ValueError("The absolute value of shift_ms should be less than the audio duration")
        shift_samples = int(shift_ms * self._sample_rate / 1000)
        if shift_samples > 0:
            # time advance
            self._samples[:-shift_samples] = self._samples[shift_samples:]
            self._samples[-shift_samples:] = 0
        elif shift_samples < 0:
            # time delay
            self._samples[-shift_samples:] = self._samples[:shift_samples]
            self._samples[:-shift_samples] = 0

    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the audio segment between the given boundaries.

        Note that this is an in-place transformation.

        :param start_sec: Beginning of subsegment in seconds.
        :type start_sec: float
        :param end_sec: End of subsegment in seconds.
        :type end_sec: float
        :raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
                           of bounds in time.
        """
        start_sec = 0.0 if start_sec is None else start_sec
        end_sec = self.duration if end_sec is None else end_sec
        if start_sec < 0.0:
            start_sec = self.duration + start_sec
        if end_sec < 0.0:
            end_sec = self.duration + end_sec
        if start_sec < 0.0:
            raise ValueError("Slice start position (%f s) out of bounds" % start_sec)
        if end_sec < 0.0:
            raise ValueError("Slice end position (%f s) out of bounds" % end_sec)
        if start_sec > end_sec:
            raise ValueError("Slice start position (%f s) is later than end position (%f s)" % (start_sec, end_sec))
        if end_sec > self.duration:
            raise ValueError("Slice end position (%f s) out of bounds (> %f s)" % (end_sec, self.duration))
        start_sample = int(round(start_sec * self._sample_rate))
        end_sample = int(round(end_sec * self._sample_rate))
        self._samples = self._samples[start_sample:end_sample]

    def random_subsegment(self, subsegment_length):
        """Randomly cut a subsegment of the specified length.

        Note that this is an in-place transformation.

        :param subsegment_length: Subsegment length in seconds.
        :type subsegment_length: float
        :raises ValueError: If the length of subsegment is greater than
                            the original segment.
        """
        if subsegment_length > self.duration:
            raise ValueError("Length of subsegment must not be greater "
                             "than original segment.")
        start_time = random.uniform(0.0, self.duration - subsegment_length)
        self.subsegment(start_time, start_time + subsegment_length)

    def add_noise(self,
                  noise,
                  snr_dB,
                  max_gain_db=300.0):
        """Add the given noise segment at a specific signal-to-noise ratio.
        If the noise segment is longer than this segment, a random
        matching-length subsegment is sampled from it.

        Note that this is an in-place transformation.

        :param noise: Noise signal to add.
        :type noise: AudioSegment
        :param snr_dB: Signal-to-Noise Ratio, in decibels.
        :type snr_dB: float
        :param max_gain_db: Maximum amount of gain to apply to noise signal
                            before adding it in. This is to prevent attempting
                            to apply infinite gain to a zero signal.
        :type max_gain_db: float
        :raises ValueError: If the sample rate does not match between the two
                            audio segments, or if the duration of noise segments
                            is shorter than original audio segments.
        """
        if noise.sample_rate != self.sample_rate:
            raise ValueError("Noise sample rate (%d Hz) does not equal the base signal sample rate (%d Hz)" % (noise.sample_rate, self.sample_rate))
        if noise.duration < self.duration:
            raise ValueError("The noise signal (%f s) must be at least as long as the base signal (%f s)" % (noise.duration, self.duration))
        # Gain the noise so that signal RMS minus noise RMS matches the target SNR
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)

    def vad(self, top_db=20, overlap=0):
        # Voice activity detection: drop silent frames
        self._samples = vad(wav=self._samples, top_db=top_db, overlap=overlap)

    def crop(self, duration, mode='eval'):
        # Crop to at most `duration` seconds: randomly in training, from the start in eval
        if self.duration > duration:
            if mode == 'train':
                self.random_subsegment(duration)
            else:
                self.subsegment(end_sec=duration)

    @property
    def samples(self):
        """Return the audio samples.

        :return: Audio samples.
        :rtype: ndarray
        """
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return the audio sample rate.

        :return: Audio sample rate.
        :rtype: int
        """
        return self._sample_rate

    @property
    def num_samples(self):
        """Return the number of samples.

        :return: Number of samples.
        :rtype: int
        """
        return self._samples.shape[0]

    @property
    def duration(self):
        """Return the audio duration.

        :return: Audio duration in seconds.
        :rtype: float
        """
        return self._samples.shape[0] / float(self._sample_rate)

    @property
    def rms_db(self):
        """Return the root-mean-square energy of the audio in decibels.

        :return: Root mean square energy in decibels.
        :rtype: float
        """
        # square root => multiply by 10 instead of 20 for dBs
        mean_square = np.mean(self._samples ** 2)
        return 10 * np.log10(mean_square)

    def _convert_samples_to_float32(self, samples):
        """Convert sample type to float32.

        Audio sample type is usually integer or float-point.
        Integers will be scaled to [-1, 1] in float32.
        """
        float32_samples = samples.astype('float32')
        if samples.dtype in np.sctypes['int']:
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= (1. / 2 ** (bits - 1))
        elif samples.dtype in np.sctypes['float']:
            pass
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return float32_samples

    def _convert_samples_from_float32(self, samples, dtype):
        """Convert sample type from float32 to dtype.

        Audio sample type is usually integer or float-point. For integer
        type, float32 will be rescaled from [-1, 1] to the maximum range
        supported by the integer type.

        This is for writing an audio file.
        """
        dtype = np.dtype(dtype)
        output_samples = samples.copy()
        if dtype in np.sctypes['int']:
            bits = np.iinfo(dtype).bits
            output_samples *= (2 ** (bits - 1) / 1.)
            min_val = np.iinfo(dtype).min
            max_val = np.iinfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        elif samples.dtype in np.sctypes['float']:
            min_val = np.finfo(dtype).min
            max_val = np.finfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return output_samples.astype(dtype)

    def save(self, path, dtype='float32'):
        """Save the audio segment to disk as a wav file.

        :param path: WAV file path or file object to save the audio segment to.
        :type path: str|file
        :param dtype: Subtype for audio file. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :raises TypeError: If dtype is not supported.
        """
        self.to_wav_file(path, dtype)

    # Silence removal
    @classmethod
    def silent_semoval(cls, inputpath, outputpath):
        # Read the audio file
        audio = AudioSegment.from_file(inputpath)
        # Voice activity detection
        audio.vad()
        # Save the trimmed audio
        audio.save(outputpath)
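
For orientation, a short end-to-end example of this class (a minimal sketch; the input file is one of the WAVs committed under audio_db/):

```python
# Load a file, normalize loudness, resample, trim silence, and save.
from mvector.data_utils.audio import AudioSegment

audio = AudioSegment.from_file('audio_db/test.wav')
print(audio)                      # duration, sample rate, RMS level
audio.normalize(target_db=-20)    # bring RMS to -20 dB
audio.resample(16000)             # ensure 16 kHz
audio.vad(top_db=20)              # drop silent frames
audio.save('audio_db/test_clean.wav')
```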
0  mvector/data_utils/augmentor/__init__.py  Normal file  (empty)
BIN  mvector/data_utils/augmentor/__pycache__/__init__.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/augmentor/__pycache__/base.cpython-37.pyc  Normal file  (binary file not shown)
BIN  mvector/data_utils/augmentor/__pycache__/resample.cpython-37.pyc  Normal file  (binary file not shown)
115  mvector/data_utils/augmentor/augmentation.py  Normal file
@@ -0,0 +1,115 @@

"""Contains the data augmentation pipeline."""

import json
import random

from mvector.data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
from mvector.data_utils.augmentor.resample import ResampleAugmentor
from mvector.data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from mvector.data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from mvector.data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class AugmentationPipeline(object):
    """Build a pre-processing pipeline with various augmentation models. Such a
    data augmentation pipeline is often leveraged to augment the training
    samples to make the model invariant to certain types of perturbations in the
    real world, improving the model's generalization ability.

    The pipeline is built according to the augmentation configuration in a JSON
    string, e.g.

    .. code-block::
        [
            {
                "type": "noise",
                "params": {
                    "min_snr_dB": 10,
                    "max_snr_dB": 50,
                    "noise_manifest_path": "dataset/manifest.noise"
                },
                "prob": 0.5
            },
            {
                "type": "speed",
                "params": {
                    "min_speed_rate": 0.9,
                    "max_speed_rate": 1.1,
                    "num_rates": 3
                },
                "prob": 1.0
            },
            {
                "type": "shift",
                "params": {
                    "min_shift_ms": -5,
                    "max_shift_ms": 5
                },
                "prob": 1.0
            },
            {
                "type": "volume",
                "params": {
                    "min_gain_dBFS": -15,
                    "max_gain_dBFS": 15
                },
                "prob": 1.0
            }
        ]

    This example configuration inserts several augmentation models into the
    pipeline, among them a NoisePerturbAugmentor and a SpeedPerturbAugmentor.
    "prob" indicates the probability of the current augmentor taking effect;
    if "prob" is zero, the augmentor does not take effect.

    :param augmentation_config: Augmentation configuration in json string.
    :type augmentation_config: str
    """

    def __init__(self, augmentation_config):
        self._augmentors, self._rates = self._parse_pipeline_from(augmentation_config, aug_type='audio')

    def transform_audio(self, audio_segment):
        """Run the pre-processing pipeline for data augmentation.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        for augmentor, rate in zip(self._augmentors, self._rates):
            if random.random() < rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json, aug_type):
        """Parse the config json to build an augmentation pipeline."""
        try:
            configs = []
            configs_temp = json.loads(config_json)
            for config in configs_temp:
                if config['aug_type'] != aug_type: continue
                logger.info('Data augmentation config: %s' % config)
                configs.append(config)
            augmentors = [self._get_augmentor(config["type"], config["params"]) for config in configs]
            rates = [config["prob"] for config in configs]
        except Exception as e:
            raise ValueError("Failed to parse the augmentation config json: %s" % str(e))
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params."""
        if augmentor_type == "volume":
            return VolumePerturbAugmentor(**params)
        elif augmentor_type == "shift":
            return ShiftPerturbAugmentor(**params)
        elif augmentor_type == "speed":
            return SpeedPerturbAugmentor(**params)
        elif augmentor_type == "resample":
            return ResampleAugmentor(**params)
        elif augmentor_type == "noise":
            return NoisePerturbAugmentor(**params)
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
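
A sketch of driving this pipeline directly, using the configs/augmentation.json shown earlier and this repo's own classes:

```python
# Apply the audio-augmentation pipeline to one segment, in place.
from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.augmentation import AugmentationPipeline

with open('configs/augmentation.json', 'r', encoding='utf-8') as f:
    augment_pipeline = AugmentationPipeline(augmentation_config=f.read())

audio = AudioSegment.from_file('audio_db/test.wav')
augment_pipeline.transform_audio(audio)   # each augmentor fires with its "prob"
```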
30  mvector/data_utils/augmentor/base.py  Normal file
@@ -0,0 +1,30 @@

"""Contains the abstract base class for augmentation models."""

from abc import ABCMeta, abstractmethod


class AugmentorBase(object):
    """Abstract base class for augmentation model (augmentor) class.
    All augmentor classes should inherit from this class, and implement the
    following abstract methods.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def transform_audio(self, audio_segment):
        """Adds various effects to the input audio segment. Such effects
        will augment the training data to make the model invariant to certain
        types of perturbations in the real world, improving the model's
        generalization ability.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        pass
57  mvector/data_utils/augmentor/noise_perturb.py  Normal file
@@ -0,0 +1,57 @@

"""Contains the noise perturb augmentation model."""
import os
import random

import numpy as np

from mvector.data_utils.augmentor.base import AugmentorBase
from mvector.data_utils.audio import AudioSegment


class NoisePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding background noise.

    :param min_snr_dB: Minimal signal noise ratio, in decibels.
    :type min_snr_dB: float
    :param max_snr_dB: Maximal signal noise ratio, in decibels.
    :type max_snr_dB: float
    :param repetition: Maximum number of times noise is added.
    :type repetition: int
    :param noise_dir: Noise audio file dir.
    :type noise_dir: str
    """

    def __init__(self, min_snr_dB, max_snr_dB, repetition, noise_dir):
        self._min_snr_dB = min_snr_dB
        self._max_snr_dB = max_snr_dB
        self.repetition = repetition
        self.noises_path = []
        if os.path.exists(noise_dir):
            for file in os.listdir(noise_dir):
                self.noises_path.append(os.path.join(noise_dir, file))

    def transform_audio(self, audio_segment: AudioSegment):
        """Add background noise audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment
        """
        if len(self.noises_path) > 0:
            for _ in range(random.randint(1, self.repetition)):
                # Randomly pick one file from noises_path
                noise_path = random.sample(self.noises_path, 1)[0]
                # Load the noise audio
                noise_segment = AudioSegment.from_file(noise_path)
                # Resample the noise if its sample rate differs from audio_segment's
                if noise_segment.sample_rate != audio_segment.sample_rate:
                    noise_segment.resample(audio_segment.sample_rate)
                # Draw a random snr_dB value
                snr_dB = random.uniform(self._min_snr_dB, self._max_snr_dB)
                # If the noise is shorter than audio_segment, pad it by wrapping around to its beginning
                if noise_segment.duration < audio_segment.duration:
                    diff_duration = audio_segment.num_samples - noise_segment.num_samples
                    noise_segment._samples = np.pad(noise_segment.samples, (0, diff_duration), 'wrap')
                # Add the noise to audio_segment at the drawn SNR
                audio_segment.add_noise(noise_segment, snr_dB)
31  mvector/data_utils/augmentor/resample.py  Normal file
@@ -0,0 +1,31 @@

"""Contains the resample augmentation model."""
import numpy as np

from mvector.data_utils.audio import AudioSegment

from mvector.data_utils.augmentor.base import AugmentorBase


class ResampleAugmentor(AugmentorBase):
    """Augmentation model for resampling.

    See more info here:
    https://ccrma.stanford.edu/~jos/resample/index.html

    :param new_sample_rate: New sample rates in Hz to choose from.
    :type new_sample_rate: list
    """

    def __init__(self, new_sample_rate: list):
        self._new_sample_rate = new_sample_rate

    def transform_audio(self, audio_segment: AudioSegment):
        """Resamples the input audio to a target sample rate.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        _new_sample_rate = np.random.choice(self._new_sample_rate)
        audio_segment.resample(_new_sample_rate)
31  mvector/data_utils/augmentor/shift_perturb.py  Normal file
@@ -0,0 +1,31 @@

"""Contains the shift perturb augmentation model."""
import random

from mvector.data_utils.audio import AudioSegment

from mvector.data_utils.augmentor.base import AugmentorBase


class ShiftPerturbAugmentor(AugmentorBase):
    """Augmentation model that adds random shift perturbation.

    :param min_shift_ms: Minimal shift in milliseconds.
    :type min_shift_ms: float
    :param max_shift_ms: Maximal shift in milliseconds.
    :type max_shift_ms: float
    """

    def __init__(self, min_shift_ms, max_shift_ms):
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms

    def transform_audio(self, audio_segment: AudioSegment):
        """Shift audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        shift_ms = random.uniform(self._min_shift_ms, self._max_shift_ms)
        audio_segment.shift(shift_ms)
50  mvector/data_utils/augmentor/speed_perturb.py  Normal file
@@ -0,0 +1,50 @@

"""Contains the speed perturbation augmentation model."""
import random

import numpy as np
from mvector.data_utils.audio import AudioSegment

from mvector.data_utils.augmentor.base import AugmentorBase


class SpeedPerturbAugmentor(AugmentorBase):
    """Augmentation model that adds speed perturbation.

    See reference paper here:
    http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf

    :param min_speed_rate: Lower bound of new speed rate to sample and should
                           not be smaller than 0.9.
    :type min_speed_rate: float
    :param max_speed_rate: Upper bound of new speed rate to sample and should
                           not be larger than 1.1.
    :type max_speed_rate: float
    """

    def __init__(self, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=3):
        if min_speed_rate < 0.9:
            raise ValueError("Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError("Sampling speed above 1.1 can cause unnatural effects")
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate
        self._num_rates = num_rates
        if num_rates > 0:
            self._rates = np.linspace(self._min_speed_rate, self._max_speed_rate, self._num_rates, endpoint=True)

    def transform_audio(self, audio_segment: AudioSegment):
        """Sample a new speed rate from the given range and
        change the speed of the given audio clip.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        # With num_rates <= 0 the rate is sampled continuously; otherwise it is
        # drawn from the precomputed discrete grid.
        if self._num_rates <= 0:
            speed_rate = random.uniform(self._min_speed_rate, self._max_speed_rate)
        else:
            speed_rate = random.choice(self._rates)

        if speed_rate == 1.0: return
        audio_segment.change_speed(speed_rate)
37
mvector/data_utils/augmentor/volume_perturb.py
Normal file
@ -0,0 +1,37 @@
"""Contains the volume perturb augmentation model."""
import random

from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.base import AugmentorBase


class VolumePerturbAugmentor(AugmentorBase):
    """Augmentation model that adds random volume perturbation.

    This is used for multi-loudness training of PCEN. See

    https://arxiv.org/pdf/1607.05666v1.pdf

    for more details.

    :param min_gain_dBFS: Minimal gain in dBFS.
    :type min_gain_dBFS: float
    :param max_gain_dBFS: Maximal gain in dBFS.
    :type max_gain_dBFS: float
    """

    def __init__(self, min_gain_dBFS, max_gain_dBFS):
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS

    def transform_audio(self, audio_segment: AudioSegment):
        """Change audio loudness.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        gain = random.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain)
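The three augmentors above share the same in-place `transform_audio` interface, so they can be chained on one segment. A minimal usage sketch, assuming `AudioSegment.from_file` behaves as it is used in `reader.py` below and the wav path is just an example:

```python
from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from mvector.data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from mvector.data_utils.augmentor.volume_perturb import VolumePerturbAugmentor

segment = AudioSegment.from_file('audio_db/test.wav')  # example path
for augmentor in [ShiftPerturbAugmentor(min_shift_ms=-5, max_shift_ms=5),
                  SpeedPerturbAugmentor(),
                  VolumePerturbAugmentor(min_gain_dBFS=-15, max_gain_dBFS=15)]:
    augmentor.transform_audio(segment)  # each call modifies the segment in place
```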
25
mvector/data_utils/collate_fn.py
Normal file
@ -0,0 +1,25 @@
import numpy as np
import torch


# Process one batch of data
def collate_fn(batch):
    # Sort by audio length, longest first
    batch = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)
    max_audio_length = batch[0][0].shape[0]
    batch_size = len(batch)
    # Create a zero tensor sized to the longest audio
    inputs = np.zeros((batch_size, max_audio_length), dtype='float32')
    input_lens_ratio = []
    labels = []
    for x in range(batch_size):
        sample = batch[x]
        tensor = sample[0]
        labels.append(sample[1])
        seq_length = tensor.shape[0]
        # Copy the data into the zero tensor, which implements padding
        inputs[x, :seq_length] = tensor[:]
        input_lens_ratio.append(seq_length / max_audio_length)
    input_lens_ratio = np.array(input_lens_ratio, dtype='float32')
    labels = np.array(labels, dtype='int64')
    return torch.tensor(inputs), torch.tensor(labels), torch.tensor(input_lens_ratio)
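A quick sanity check of the padding behaviour, using toy numpy arrays in place of real audio samples:

```python
import numpy as np
from mvector.data_utils.collate_fn import collate_fn

batch = [(np.ones(5, dtype=np.float32), 0), (np.ones(3, dtype=np.float32), 1)]
inputs, labels, lens_ratio = collate_fn(batch)
print(inputs.shape)   # torch.Size([2, 5]); both clips padded to the longest
print(lens_ratio)     # tensor([1.0000, 0.6000])
```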
103
mvector/data_utils/featurizer.py
Normal file
@ -0,0 +1,103 @@
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC
import torchaudio.compliance.kaldi as Kaldi


class AudioFeaturizer(nn.Module):
    """Audio featurizer

    :param feature_method: the preprocessing method to use
    :type feature_method: str
    :param feature_conf: parameters of the preprocessing method
    :type feature_conf: dict
    """

    def __init__(self, feature_method='MelSpectrogram', feature_conf={}):
        super().__init__()
        self._feature_conf = feature_conf
        self._feature_method = feature_method
        if feature_method == 'MelSpectrogram':
            self.feat_fun = MelSpectrogram(**feature_conf)
        elif feature_method == 'Spectrogram':
            self.feat_fun = Spectrogram(**feature_conf)
        elif feature_method == 'MFCC':
            melkwargs = feature_conf.copy()
            del melkwargs['sample_rate']
            del melkwargs['n_mfcc']
            self.feat_fun = MFCC(sample_rate=self._feature_conf.sample_rate,
                                 n_mfcc=self._feature_conf.n_mfcc,
                                 melkwargs=melkwargs)
        elif feature_method == 'Fbank':
            self.feat_fun = KaldiFbank(**feature_conf)
        else:
            raise Exception(f'Preprocessing method {self._feature_method} does not exist!')

    def forward(self, waveforms, input_lens_ratio):
        """Extract audio features from an AudioSegment

        :param waveforms: Audio segment to extract features from.
        :type waveforms: AudioSegment
        :param input_lens_ratio: input length ratio
        :type input_lens_ratio: tensor
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        """
        feature = self.feat_fun(waveforms)
        feature = feature.transpose(2, 1)
        # Normalize
        mean = torch.mean(feature, 1, keepdim=True)
        std = torch.std(feature, 1, keepdim=True)
        feature = (feature - mean) / (std + 1e-5)
        # Scale the length ratios up to frame counts for masking
        input_lens = (input_lens_ratio * feature.shape[1])
        mask_lens = torch.round(input_lens).long()
        mask_lens = mask_lens.unsqueeze(1)
        input_lens = input_lens.int()
        # Build the mask tensor
        idxs = torch.arange(feature.shape[1], device=feature.device).repeat(feature.shape[0], 1)
        mask = idxs < mask_lens
        mask = mask.unsqueeze(-1)
        # Apply the mask to the features
        feature_masked = torch.where(mask, feature, torch.zeros_like(feature))
        return feature_masked, input_lens

    @property
    def feature_dim(self):
        """Return the feature size

        :return: feature size
        :rtype: int
        """
        if self._feature_method == 'LogMelSpectrogram':
            return self._feature_conf.n_mels
        elif self._feature_method == 'MelSpectrogram':
            return self._feature_conf.n_mels
        elif self._feature_method == 'Spectrogram':
            return self._feature_conf.n_fft // 2 + 1
        elif self._feature_method == 'MFCC':
            return self._feature_conf.n_mfcc
        elif self._feature_method == 'Fbank':
            return self._feature_conf.num_mel_bins
        else:
            raise Exception('No preprocessing method named {}'.format(self._feature_method))


class KaldiFbank(nn.Module):
    def __init__(self, **kwargs):
        super(KaldiFbank, self).__init__()
        self.kwargs = kwargs

    def forward(self, waveforms):
        """
        :param waveforms: [Batch, Length]
        :return: [Batch, Feature, Length]
        """
        log_fbanks = []
        for waveform in waveforms:
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            log_fbank = Kaldi.fbank(waveform, **self.kwargs)
            log_fbank = log_fbank.transpose(0, 1)
            log_fbanks.append(log_fbank)
        log_fbank = torch.stack(log_fbanks)
        return log_fbank
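A minimal smoke test of the featurizer; the `feature_conf` values here are illustrative, not the project's shipped configuration:

```python
import torch
from mvector.data_utils.featurizer import AudioFeaturizer

featurizer = AudioFeaturizer(feature_method='MelSpectrogram',
                             feature_conf={'sample_rate': 16000, 'n_fft': 400, 'n_mels': 80})
waveforms = torch.randn(4, 16000)                 # four 1-second clips at 16 kHz
lens_ratio = torch.tensor([1.0, 1.0, 0.5, 0.25])  # padded clips get shorter ratios
features, input_lens = featurizer(waveforms, lens_ratio)
print(features.shape)                             # (4, num_frames, 80); padding frames are masked
```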
73
mvector/data_utils/reader.py
Normal file
@ -0,0 +1,73 @@
import numpy as np
from torch.utils.data import Dataset

from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.augmentation import AugmentationPipeline
from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class CustomDataset(Dataset):
    def __init__(self,
                 data_list_path,
                 do_vad=True,
                 max_duration=6,
                 min_duration=0.5,
                 augmentation_config='{}',
                 mode='train',
                 sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        """Audio data loader

        Args:
            data_list_path: path to the data list file containing audio paths and labels
            do_vad: whether to apply voice activity detection (VAD) to trim silence
            max_duration: maximum audio length; longer audio is cropped
            min_duration: filter out audio shorter than this length
            augmentation_config: configuration for audio augmentation
            mode: dataset mode; in training mode the dataset may apply augmentation preprocessing
            sample_rate: sample rate
            use_dB_normalization: whether to normalize the audio volume
            target_dB: target level for volume normalization
        """
        super(CustomDataset, self).__init__()
        self.do_vad = do_vad
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.mode = mode
        self._target_sample_rate = sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB
        self._augmentation_pipeline = AugmentationPipeline(augmentation_config=augmentation_config)
        # Read the data list
        with open(data_list_path, 'r') as f:
            self.lines = f.readlines()

    def __getitem__(self, idx):
        # Split the audio path and label
        audio_path, label = self.lines[idx].replace('\n', '').split('\t')
        # Read the audio
        audio_segment = AudioSegment.from_file(audio_path)
        # Trim silence
        if self.do_vad:
            audio_segment.vad()
        # Audio that is too short is not useful for training
        if self.mode == 'train':
            if audio_segment.duration < self.min_duration:
                return self.__getitem__(idx + 1 if idx < len(self.lines) - 1 else 0)
        # Resample
        if audio_segment.sample_rate != self._target_sample_rate:
            audio_segment.resample(self._target_sample_rate)
        # Decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # Crop to the required duration
        audio_segment.crop(duration=self.max_duration, mode=self.mode)
        # Audio augmentation
        self._augmentation_pipeline.transform_audio(audio_segment)
        return np.array(audio_segment.samples, dtype=np.float32), np.array(int(label), dtype=np.int64)

    def __len__(self):
        return len(self.lines)
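Each line of the data list pairs an audio path with an integer speaker label, separated by a tab. A wiring sketch with the `collate_fn` defined above; the list path is hypothetical:

```python
from torch.utils.data import DataLoader

from mvector.data_utils.collate_fn import collate_fn
from mvector.data_utils.reader import CustomDataset

# dataset/train_list.txt lines look like: path/to/audio.wav<TAB>0
dataset = CustomDataset(data_list_path='dataset/train_list.txt', mode='train')
loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)
for audio, label, lens_ratio in loader:
    print(audio.shape, label.shape, lens_ratio.shape)
    break
```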
179
mvector/data_utils/utils.py
Normal file
@ -0,0 +1,179 @@
import io
import itertools

import av
import librosa
import numpy as np
import torch


def vad(wav, top_db=10, overlap=200):
    """
    Remove the silent portions of an audio signal.

    Args:
        wav: audio data
        top_db: threshold (in dB) used to detect silence
        overlap: overlap length

    Returns:
        wav_output: audio data with silence removed
    """
    intervals = librosa.effects.split(wav, top_db=top_db)
    if len(intervals) == 0:
        return wav
    wav_output = [np.array([])]
    for sliced in intervals:
        seg = wav[sliced[0]:sliced[1]]
        if len(seg) < 2 * overlap:
            wav_output[-1] = np.concatenate((wav_output[-1], seg))
        else:
            wav_output.append(seg)
    wav_output = [x for x in wav_output if len(x) > 0]

    if len(wav_output) == 1:
        wav_output = wav_output[0]
    else:
        wav_output = concatenate(wav_output)
    return wav_output


def concatenate(wave, overlap=200):
    """
    Concatenate audio segments.

    Args:
        wave: audio data
        overlap: overlap length

    Returns:
        unfolded: the concatenated audio data
    """
    total_len = sum([len(x) for x in wave])
    unfolded = np.zeros(total_len)

    # Equal power crossfade
    window = np.hanning(2 * overlap)
    fade_in = window[:overlap]
    fade_out = window[-overlap:]

    end = total_len
    for i in range(1, len(wave)):
        prev = wave[i - 1]
        curr = wave[i]

        if i == 1:
            end = len(prev)
            unfolded[:end] += prev

        max_idx = 0
        max_corr = 0
        pattern = prev[-overlap:]
        # Slide the current segment to match the pattern of the previous one
        for j in range(overlap):
            match = curr[j:j + overlap]
            # Normalized cross-correlation (the original divided by a one-element
            # list here by mistake; parentheses are what was intended)
            corr = np.sum(pattern * match) / (np.sqrt(np.sum(pattern ** 2)) * np.sqrt(np.sum(match ** 2)) + 1e-8)
            if corr > max_corr:
                max_idx = j
                max_corr = corr

        # Apply the gain to the overlap samples
        start = end - overlap
        unfolded[start:end] *= fade_out
        end = start + (len(curr) - max_idx)
        curr[max_idx:max_idx + overlap] *= fade_in
        unfolded[start:end] += curr[max_idx:]
    return unfolded[:end]


def decode_audio(file, sample_rate: int = 16000):
    """Read audio; mainly used as a fallback reader that supports many formats

    Args:
        file: Path to the input file or a file-like object.
        sample_rate: Resample the audio to this sample rate.

    Returns:
        A float32 Numpy array.
    """
    resampler = av.audio.resampler.AudioResampler(format="s16", layout="mono", rate=sample_rate)

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(file, metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    return audio.astype(np.float32) / 32768.0


def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)


# Convert an audio byte stream to numpy
def buf_to_float(x, n_bytes=2, dtype=np.float32):
    """Convert an integer buffer to floating point values.
    This is primarily useful when loading integer-valued wav data
    into numpy arrays.

    Parameters
    ----------
    x : np.ndarray [dtype=int]
        The integer-valued data buffer

    n_bytes : int [1, 2, 4]
        The number of bytes per sample in ``x``

    dtype : numeric type
        The target output type (default: 32-bit float)

    Returns
    -------
    x_float : np.ndarray [dtype=float]
        The input data buffer cast to floating point
    """

    # Invert the scale of the data
    scale = 1.0 / float(1 << ((8 * n_bytes) - 1))

    # Construct the format string
    fmt = "<i{:d}".format(n_bytes)

    # Rescale and format the data buffer
    return scale * np.frombuffer(x, fmt).astype(dtype)
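A quick check of `buf_to_float` on two extreme 16-bit PCM samples, which should map close to +1.0 and -1.0:

```python
import numpy as np
from mvector.data_utils.utils import buf_to_float

pcm_bytes = np.array([32767, -32768], dtype='<i2').tobytes()
print(buf_to_float(pcm_bytes))  # [ 0.99996948 -1.        ]
```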
0
mvector/metric/__init__.py
Normal file
BIN
mvector/metric/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/metric/__pycache__/metrics.cpython-37.pyc
Normal file
Binary file not shown.
63
mvector/metric/metrics.py
Normal file
@ -0,0 +1,63 @@
import numpy as np

from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class TprAtFpr(object):
    def __init__(self, max_fpr=0.01):
        self.pos_score_list = []
        self.neg_score_list = []
        self.max_fpr = max_fpr

    def add(self, y_labels, y_scores):
        for y_label, y_score in zip(y_labels, y_scores):
            if y_label == 0:
                self.neg_score_list.append(y_score)
            else:
                self.pos_score_list.append(y_score)

    def reset(self):
        self.pos_score_list = []
        self.neg_score_list = []

    def calculate_eer(self, tprs, fprs):
        # Record every candidate EER value
        eer_list = []
        n = len(tprs)
        eer = 1.0
        index = 0
        for i in range(n):
            eer_list.append(fprs[i] + (1 - tprs[i]))
            if fprs[i] + (1 - tprs[i]) < eer:
                eer = fprs[i] + (1 - tprs[i])
                index = i
        return eer, index, eer_list

    def calculate(self):
        tprs, fprs, thresholds = [], [], []
        pos_score_list = np.array(self.pos_score_list)
        neg_score_list = np.array(self.neg_score_list)
        if len(pos_score_list) == 0:
            msg = "The number of positive samples is 0, please add positive samples."
            logger.warning(msg)
            return tprs, fprs, thresholds, None, None
        if len(neg_score_list) == 0:
            msg = "The number of negative samples is 0, please add negative samples."
            logger.warning(msg)
            return tprs, fprs, thresholds, None, None
        for i in range(0, 100):
            threshold = i / 100.
            tpr = np.sum(pos_score_list > threshold) / len(pos_score_list)
            fpr = np.sum(neg_score_list > threshold) / len(neg_score_list)
            tprs.append(tpr)
            fprs.append(fpr)
            thresholds.append(threshold)
        eer, index, eer_list = self.calculate_eer(fprs=fprs, tprs=tprs)

        # Print the metric value at every threshold in eer_list
        for i in range(len(eer_list)):
            print(f"threshold: {thresholds[i]}, eer: {eer_list[i]}")

        return tprs, fprs, thresholds, eer, index
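A toy run of the metric with made-up scores; note that `calculate` also prints one line per threshold:

```python
from mvector.metric.metrics import TprAtFpr

metric = TprAtFpr()
metric.add(y_labels=[1, 1, 0, 0], y_scores=[0.92, 0.71, 0.35, 0.64])
tprs, fprs, thresholds, eer, index = metric.calculate()
print(f'eer={eer:.2f} at threshold={thresholds[index]}')
```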
0
mvector/models/__init__.py
Normal file
BIN
mvector/models/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/ecapa_tdnn.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/fc.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/loss.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/pooling.cpython-37.pyc
Normal file
Binary file not shown.
189
mvector/models/ecapa_tdnn.py
Normal file
@ -0,0 +1,189 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mvector.models.pooling import AttentiveStatsPool, TemporalAveragePooling
from mvector.models.pooling import SelfAttentivePooling, TemporalStatisticsPooling


class Res2Conv1dReluBn(nn.Module):
    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        # Iterate over each branch
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                # Later branches add the current sub-feature to the previous output, forming a residual connection
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])

        # Concatenate the results of all branches along the channel dimension
        out = torch.cat(out, dim=1)
        return out


class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


class SE_Connect(nn.Module):
    def __init__(self, channels, s=2):
        super().__init__()
        assert channels % s == 0, "{} % {} != 0".format(channels, s)
        self.linear1 = nn.Linear(channels, channels // s)
        self.linear2 = nn.Linear(channels // s, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)
        return out


def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
    """
    Build an SE-Res2 block.

    Parameters:
    - channels: number of channels.
    - kernel_size: convolution kernel size.
    - stride: convolution stride.
    - padding: convolution padding.
    - dilation: dilation rate of the dilated convolution.
    - scale: scale of the Res2 convolution.

    Returns:
    - An nn.Sequential of Conv1dReluBn, Res2Conv1dReluBn, Conv1dReluBn and SE_Connect.
    """
    return nn.Sequential(
        Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
        Res2Conv1dReluBn(channels, kernel_size, stride, padding, dilation, scale=scale),
        Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
        SE_Connect(channels)
    )


class EcapaTdnn(nn.Module):
    """
    ECAPA-TDNN speaker embedding model.

    Parameters:
    - input_size: input size; defaults to 80.
    - channels: number of channels; defaults to 512.
    - embd_dim: embedding dimension; defaults to 192.
    - pooling_type: pooling type; defaults to "ASP"; available values are "ASP", "SAP", "TAP" and "TSP".
    """
    def __init__(self, input_size=80, channels=512, embd_dim=192, pooling_type="ASP"):
        super().__init__()
        self.layer1 = Conv1dReluBn(input_size, channels, kernel_size=5, padding=2, dilation=1)
        self.layer2 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
        self.layer3 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8)
        self.layer4 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8)

        cat_channels = channels * 3
        self.emb_size = embd_dim
        self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
        if pooling_type == "ASP":
            self.pooling = AttentiveStatsPool(cat_channels, 128)
            self.bn1 = nn.BatchNorm1d(cat_channels * 2)
            self.linear = nn.Linear(cat_channels * 2, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        elif pooling_type == "SAP":
            self.pooling = SelfAttentivePooling(cat_channels, 128)
            self.bn1 = nn.BatchNorm1d(cat_channels)
            self.linear = nn.Linear(cat_channels, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        elif pooling_type == "TAP":
            self.pooling = TemporalAveragePooling()
            self.bn1 = nn.BatchNorm1d(cat_channels)
            self.linear = nn.Linear(cat_channels, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        elif pooling_type == "TSP":
            self.pooling = TemporalStatisticsPooling()
            self.bn1 = nn.BatchNorm1d(cat_channels * 2)
            self.linear = nn.Linear(cat_channels * 2, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        else:
            raise Exception(f'There is no {pooling_type} pooling layer!')

    def forward(self, x):
        """
        Compute embeddings.

        Parameters:
            x (torch.Tensor): input data with shape (N, time, freq), where N is the number of samples,
                time is the time dimension and freq is the frequency dimension.

        Returns:
            torch.Tensor: output embeddings with shape (N, self.emb_size)
        """
        # Swap the frequency and time dimensions of the input
        x = x.transpose(2, 1)
        # Pass through the first convolution layer
        out1 = self.layer1(x)
        # Pass through the second layer and add the first layer's output
        out2 = self.layer2(out1) + out1
        # Pass through the third layer and add the outputs of the previous two layers
        out3 = self.layer3(out1 + out2) + out1 + out2
        # Pass through the fourth layer and add the outputs of the previous three layers
        out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3

        # Concatenate the outputs of layers two, three and four along the feature dimension
        out = torch.cat([out2, out3, out4], dim=1)
        # Apply ReLU and pass through the 1x1 convolution
        out = F.relu(self.conv(out))
        # Batch normalization after pooling
        out = self.bn1(self.pooling(out))
        # Linear transform followed by batch normalization
        out = self.bn2(self.linear(out))
        return out
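A shape check of the backbone on random 80-dimensional features (98 frames, matching the summary call in `trainer.py` below):

```python
import torch
from mvector.models.ecapa_tdnn import EcapaTdnn

model = EcapaTdnn(input_size=80, embd_dim=192, pooling_type="ASP")
x = torch.randn(2, 98, 80)   # (N, time, freq)
print(model(x).shape)        # torch.Size([2, 192])
```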
90
mvector/models/fc.py
Normal file
@ -0,0 +1,90 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class SpeakerIdetification(nn.Module):
    def __init__(
            self,
            backbone,
            num_class=1,
            loss_type='AAMLoss',
            lin_blocks=0,
            lin_neurons=192,
            dropout=0.1, ):

        """
        Initialize the speaker identification model, consisting of the speaker backbone network
        and a linear transform sized to the number of speaker classes used in training.

        Parameters:
            backbone (torch.nn.Module): the speaker embedding backbone model.
            num_class (int): the number of speaker classes in the training dataset.
            lin_blocks (int, optional): the number of linear transforms between the embedding and the final layer. Defaults to 0.
            lin_neurons (int, optional): the output dimension of the final linear layer. Defaults to 192.
            dropout (float, optional): the dropout factor applied to the embedding. Defaults to 0.1.
        """
        super(SpeakerIdetification, self).__init__()
        # Initialize the backbone network,
        # whose output is the target embedding
        self.backbone = backbone
        self.loss_type = loss_type
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        # Build the speaker classifier
        input_size = self.backbone.emb_size
        # Use nn.ModuleList (rather than a plain list) so the layers are registered as parameters
        self.blocks = nn.ModuleList()
        # Add linear transforms
        for i in range(lin_blocks):
            self.blocks.extend([
                nn.BatchNorm1d(input_size),
                nn.Linear(in_features=input_size, out_features=lin_neurons),
            ])
            input_size = lin_neurons

        # Final layer initialization
        if self.loss_type == 'AAMLoss':
            self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
            self.weight = Parameter(torch.FloatTensor(input_size, num_class), requires_grad=True)
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'CELoss':
            self.output = nn.Linear(input_size, num_class)
        else:
            raise Exception(f'There is no {self.loss_type} loss function!')

    def forward(self, x):
        """
        Run the forward pass of the speaker identification model,
        covering both the speaker embedding model and the classifier network.

        Parameters:
            x (torch.Tensor): input audio features,
                shape=[batch size, time, dimension]

        Returns:
            torch.Tensor: the logits of the features
        """
        # x.shape: (N, L, C)
        x = self.backbone(x)  # (N, emb_size)
        if self.dropout is not None:
            x = self.dropout(x)

        for fc in self.blocks:
            x = fc(x)
        if self.loss_type == 'AAMLoss':
            logits = F.linear(F.normalize(x), F.normalize(self.weight, dim=-1))
        elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
            x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            x_norm = torch.div(x, x_norm)
            w_norm = torch.norm(self.weight, p=2, dim=0, keepdim=True).clamp(min=1e-12)
            w_norm = torch.div(self.weight, w_norm)
            logits = torch.mm(x_norm, w_norm)
        else:
            logits = self.output(x)

        return logits
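An end-to-end logits check combining the backbone with this head; with the default `AAMLoss` head the outputs are cosine similarities in [-1, 1]:

```python
import torch
from mvector.models.ecapa_tdnn import EcapaTdnn
from mvector.models.fc import SpeakerIdetification

model = SpeakerIdetification(backbone=EcapaTdnn(input_size=80), num_class=10)
logits = model(torch.randn(4, 98, 80))  # cosine logits
print(logits.shape)                     # torch.Size([4, 10])
```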
97
mvector/models/loss.py
Normal file
@ -0,0 +1,97 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAngularMargin(nn.Module):
    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        """The Implementation of Additive Angular Margin (AAM) proposed
        in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
        (https://arxiv.org/abs/1906.07317)

        Args:
            margin (float, optional): margin factor. Defaults to 0.0.
            scale (float, optional): scale factor. Defaults to 1.0.
            easy_margin (bool, optional): easy_margin flag. Defaults to False.
        """
        super(AdditiveAngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin

        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        cosine = outputs.float()
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs


class AAMLoss(nn.Module):
    def __init__(self, margin=0.2, scale=30, easy_margin=False):
        super(AAMLoss, self).__init__()
        self.loss_fn = AdditiveAngularMargin(margin=margin, scale=scale, easy_margin=easy_margin)
        self.criterion = torch.nn.KLDivLoss(reduction="sum")

    def forward(self, outputs, targets):
        targets = F.one_hot(targets, outputs.shape[1]).float()
        predictions = self.loss_fn(outputs, targets)
        predictions = F.log_softmax(predictions, dim=1)
        loss = self.criterion(predictions, targets) / targets.sum()
        return loss


class AMLoss(nn.Module):
    def __init__(self, margin=0.2, scale=30):
        super(AMLoss, self).__init__()
        self.m = margin
        self.s = scale
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def forward(self, outputs, targets):
        label_view = targets.view(-1, 1)
        delt_costh = torch.zeros(outputs.size(), device=targets.device).scatter_(1, label_view, self.m)
        costh_m = outputs - delt_costh
        predictions = self.s * costh_m
        loss = self.criterion(predictions, targets) / targets.shape[0]
        return loss


class ARMLoss(nn.Module):
    def __init__(self, margin=0.2, scale=30):
        super(ARMLoss, self).__init__()
        self.m = margin
        self.s = scale
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def forward(self, outputs, targets):
        label_view = targets.view(-1, 1)
        delt_costh = torch.zeros(outputs.size(), device=targets.device).scatter_(1, label_view, self.m)
        costh_m = outputs - delt_costh
        costh_m_s = self.s * costh_m
        delt_costh_m_s = costh_m_s.gather(1, label_view).repeat(1, costh_m_s.size()[1])
        costh_m_s_reduct = costh_m_s - delt_costh_m_s
        predictions = torch.where(costh_m_s_reduct < 0.0, torch.zeros_like(costh_m_s), costh_m_s)
        loss = self.criterion(predictions, targets) / targets.shape[0]
        return loss


class CELoss(nn.Module):
    def __init__(self):
        super(CELoss, self).__init__()
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def forward(self, outputs, targets):
        loss = self.criterion(outputs, targets) / targets.shape[0]
        return loss
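A toy `AAMLoss` computation on random cosine-valued logits; values outside [-1, 1] would break the `sqrt(1 - cos^2)` step, so the inputs are drawn uniformly from that range:

```python
import torch
from mvector.models.loss import AAMLoss

criterion = AAMLoss(margin=0.2, scale=30)
logits = torch.empty(4, 10).uniform_(-1, 1)   # cosine similarities
labels = torch.tensor([0, 3, 7, 1])
print(criterion(logits, labels).item())
```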
75
mvector/models/pooling.py
Normal file
@ -0,0 +1,75 @@
import torch
import torch.nn as nn


class TemporalAveragePooling(nn.Module):
    def __init__(self):
        """TAP
        Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
        Link: https://arxiv.org/pdf/1903.12058.pdf
        """
        super(TemporalAveragePooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Average Pooling Module
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels)
        """
        x = torch.mean(x, dim=2)
        return x


class TemporalStatisticsPooling(nn.Module):
    def __init__(self):
        """TSP
        Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
        Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        """
        super(TemporalStatisticsPooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Statistics Pooling Module
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels*2)
        """
        mean = torch.mean(x, dim=2)
        var = torch.var(x, dim=2)
        x = torch.cat((mean, var), dim=1)
        return x


class SelfAttentivePooling(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        # attention dim = 128
        super(SelfAttentivePooling, self).__init__()
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        return mean


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        super().__init__()
        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
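The statistics-based poolings double the channel dimension, which is why `EcapaTdnn` sizes `bn1` and `linear` differently per pooling type. A quick output-dimension comparison on a random input:

```python
import torch
from mvector.models.pooling import (AttentiveStatsPool, SelfAttentivePooling,
                                    TemporalAveragePooling, TemporalStatisticsPooling)

x = torch.randn(2, 1536, 98)  # (#batch, channels, frames)
for pool in [TemporalAveragePooling(), TemporalStatisticsPooling(),
             SelfAttentivePooling(1536), AttentiveStatsPool(1536)]:
    print(type(pool).__name__, pool(x).shape)  # TAP/SAP -> 1536, TSP/ASP -> 3072
```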
189
mvector/predict.py
Normal file
@ -0,0 +1,189 @@
import os
import pickle
import shutil
from io import BufferedReader

import numpy as np
import torch
import yaml
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from mvector import SUPPORT_MODEL
from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.featurizer import AudioFeaturizer
from mvector.models.ecapa_tdnn import EcapaTdnn
from mvector.models.fc import SpeakerIdetification
from mvector.utils.logger import setup_logger
from mvector.utils.utils import dict_to_object, print_arguments

logger = setup_logger(__name__)


class MVectorPredictor:
    def __init__(self,
                 configs,
                 threshold=0.6,
                 model_path='models/ecapa_tdnn_FBank/best_model/',
                 use_gpu=True):
        """
        Voiceprint recognition prediction tool
        :param configs: configuration parameters
        :param threshold: threshold for deciding whether two voices belong to the same person
        :param model_path: path to the exported prediction model folder
        :param use_gpu: whether to use the GPU for prediction
        """
        if use_gpu:
            assert (torch.cuda.is_available()), 'GPU is not available'
            self.device = torch.device("cuda")
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
            self.device = torch.device("cpu")
        # Number of retrieval candidates
        self.cdd_num = 5
        self.threshold = threshold
        # Read the configuration file
        if isinstance(configs, str):
            with open(configs, 'r', encoding='utf-8') as f:
                configs = yaml.load(f.read(), Loader=yaml.FullLoader)
            # print_arguments(configs=configs)
        self.configs = dict_to_object(configs)
        assert 'max_duration' in self.configs.dataset_conf, \
            'Warning: you appear to be using an old configuration file. If you are also using an old model, this is an error; please re-download or retrain, otherwise roll back the code.'
        assert self.configs.use_model in SUPPORT_MODEL, f'Model does not exist: {self.configs.use_model}'
        self._audio_featurizer = AudioFeaturizer(feature_conf=self.configs.feature_conf, **self.configs.preprocess_conf)
        self._audio_featurizer.to(self.device)
        # Build the model
        if self.configs.use_model == 'EcapaTdnn' or self.configs.use_model == 'ecapa_tdnn':
            backbone = EcapaTdnn(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
        else:
            raise Exception(f'Model {self.configs.use_model} does not exist!')
        model = SpeakerIdetification(backbone=backbone, num_class=self.configs.dataset_conf.num_speakers)
        model.to(self.device)
        # Load the model
        if os.path.isdir(model_path):
            model_path = os.path.join(model_path, 'model.pt')
        assert os.path.exists(model_path), f"Model {model_path} does not exist!"
        if torch.cuda.is_available() and use_gpu:
            model_state_dict = torch.load(model_path)
        else:
            model_state_dict = torch.load(model_path, map_location='cpu')
        # Load the model parameters
        model.load_state_dict(model_state_dict)
        print(f"Successfully loaded model parameters: {model_path}")
        # Switch to evaluation mode
        model.eval()
        self.predictor = model.backbone
        # Voiceprint features of the enrollment database
        self.audio_feature = None

    def _load_audio(self, audio_data, sample_rate=16000):
        """Load audio
        :param audio_data: the data to recognize; supports file paths, file objects, bytes and numpy arrays. Bytes must be a complete byte file
        :param sample_rate: the sample rate, required when passing numpy data
        :return: the loaded and preprocessed audio segment
        """
        # Load the audio file and preprocess it
        if isinstance(audio_data, str):
            audio_segment = AudioSegment.from_file(audio_data)
        elif isinstance(audio_data, BufferedReader):
            audio_segment = AudioSegment.from_file(audio_data)
        elif isinstance(audio_data, np.ndarray):
            audio_segment = AudioSegment.from_ndarray(audio_data, sample_rate)
        elif isinstance(audio_data, bytes):
            audio_segment = AudioSegment.from_bytes(audio_data)
        else:
            raise Exception(f'Unsupported data type: {type(audio_data)}')
        assert audio_segment.duration >= self.configs.dataset_conf.min_duration, \
            f'The audio is too short; it should be at least {self.configs.dataset_conf.min_duration}s, but this audio is {audio_segment.duration}s'
        # Resample
        if audio_segment.sample_rate != self.configs.dataset_conf.sample_rate:
            audio_segment.resample(self.configs.dataset_conf.sample_rate)
        # Decibel normalization
        if self.configs.dataset_conf.use_dB_normalization:
            audio_segment.normalize(target_db=self.configs.dataset_conf.target_dB)
        return audio_segment

    def predict(self,
                audio_data,
                sample_rate=16000):
        """Predict the feature of one audio clip

        :param audio_data: the data to recognize; supports file paths, file objects, bytes and numpy arrays. Bytes must be a complete byte file with a format header
        :param sample_rate: the sample rate, required when passing numpy data
        :return: the voiceprint feature vector
        """
        # Load the audio file and preprocess it
        input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
        input_data = torch.tensor(input_data.samples, dtype=torch.float32, device=self.device).unsqueeze(0)
        input_len_ratio = torch.tensor([1], dtype=torch.float32, device=self.device)
        audio_feature, _ = self._audio_featurizer(input_data, input_len_ratio)
        # Run prediction
        feature = self.predictor(audio_feature).data.cpu().numpy()[0]
        return feature

    def predict_batch(self, audios_data, sample_rate=16000):
        """Predict the features of a batch of audio clips

        :param audios_data: the data to recognize; supports file paths, file objects, bytes and numpy arrays. Bytes must be a complete byte file with a format header
        :param sample_rate: the sample rate, required when passing numpy data
        :return: the voiceprint feature vectors
        """
        audios_data1 = []
        for audio_data in audios_data:
            # Load the audio file and preprocess it
            input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
            audios_data1.append(input_data.samples)
        # Find the longest audio
        batch = sorted(audios_data1, key=lambda a: a.shape[0], reverse=True)
        max_audio_length = batch[0].shape[0]
        batch_size = len(batch)
        # Create a zero tensor sized to the longest audio
        inputs = np.zeros((batch_size, max_audio_length), dtype='float32')
        input_lens_ratio = []
        for x in range(batch_size):
            tensor = audios_data1[x]
            seq_length = tensor.shape[0]
            # Copy the data into the zero tensor, which implements padding
            inputs[x, :seq_length] = tensor[:]
            input_lens_ratio.append(seq_length / max_audio_length)
        audios_data = torch.tensor(inputs, dtype=torch.float32, device=self.device)
        input_lens_ratio = torch.tensor(input_lens_ratio, dtype=torch.float32, device=self.device)
        audio_feature, _ = self._audio_featurizer(audios_data, input_lens_ratio)
        # Run prediction
        features = self.predictor(audio_feature).data.cpu().numpy()
        return features

    # Voiceprint comparison
    def contrast(self, audio_data1, audio_data2):
        feature1 = self.predict(audio_data1)
        feature2 = self.predict(audio_data2)
        # Cosine similarity
        dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
        return dist

    def recognition(self, audio_data, threshold=None, sample_rate=16000):
        """Voiceprint recognition
        :param audio_data: the data to recognize; supports file paths, file objects, bytes and numpy arrays. Bytes must be a complete byte file
        :param threshold: decision threshold; if None, the threshold given when creating the object is used
        :param sample_rate: the sample rate, required when passing numpy data
        :return: the recognized user name, or None if no user was recognized
        """
        if threshold:
            self.threshold = threshold
        feature = self.predict(audio_data, sample_rate=sample_rate)
        name = self.__retrieval(np_feature=[feature])[0]
        return name

    def compare(self, feature1, feature2):
        """Voiceprint comparison

        :param feature1: feature 1
        :param feature2: feature 2
        :return: the cosine similarity of the two features
        """
        # Cosine similarity
        dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
        return dist
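A sketch of predictor usage; the config path, model path and wav paths are placeholders, and the exact config fields depend on the YAML file you trained with:

```python
from mvector.predict import MVectorPredictor

predictor = MVectorPredictor(configs='configs/ecapa_tdnn.yml',       # placeholder config path
                             model_path='models/ecapa_tdnn_FBank/best_model/',
                             use_gpu=False)
score = predictor.contrast('audio_db/output1.wav', 'audio_db/output2.wav')
print('same speaker' if score > predictor.threshold else 'different speakers', score)
```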
483
mvector/trainer.py
Normal file
483
mvector/trainer.py
Normal file
@ -0,0 +1,483 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import yaml
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from torchinfo import summary
|
||||||
|
from tqdm import tqdm
|
||||||
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
from mvector import SUPPORT_MODEL, __version__
|
||||||
|
from mvector.data_utils.collate_fn import collate_fn
|
||||||
|
from mvector.data_utils.featurizer import AudioFeaturizer
|
||||||
|
from mvector.data_utils.reader import CustomDataset
|
||||||
|
from mvector.metric.metrics import TprAtFpr
|
||||||
|
from mvector.models.ecapa_tdnn import EcapaTdnn
|
||||||
|
from mvector.models.fc import SpeakerIdetification
|
||||||
|
from mvector.models.loss import AAMLoss, CELoss, AMLoss, ARMLoss
|
||||||
|
from mvector.utils.logger import setup_logger
|
||||||
|
from mvector.utils.utils import dict_to_object, print_arguments
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MVectorTrainer(object):
|
||||||
|
def __init__(self, configs, use_gpu=True):
|
||||||
|
""" mvector集成工具类
|
||||||
|
|
||||||
|
:param configs: 配置字典
|
||||||
|
:param use_gpu: 是否使用GPU训练模型
|
||||||
|
"""
|
||||||
|
if use_gpu:
|
||||||
|
assert (torch.cuda.is_available()), 'GPU不可用'
|
||||||
|
self.device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self.use_gpu = use_gpu
|
||||||
|
# 读取配置文件
|
||||||
|
if isinstance(configs, str):
|
||||||
|
with open(configs, 'r', encoding='utf-8') as f:
|
||||||
|
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
|
||||||
|
print_arguments(configs=configs)
|
||||||
|
self.configs = dict_to_object(configs)
|
||||||
|
assert self.configs.use_model in SUPPORT_MODEL, f'没有该模型:{self.configs.use_model}'
|
||||||
|
self.model = None
|
||||||
|
self.test_loader = None
|
||||||
|
# 获取特征器
|
||||||
|
self.audio_featurizer = AudioFeaturizer(feature_conf=self.configs.feature_conf, **self.configs.preprocess_conf)
|
||||||
|
self.audio_featurizer.to(self.device)
|
||||||
|
|
||||||
|
if platform.system().lower() == 'windows':
|
||||||
|
self.configs.dataset_conf.num_workers = 0
|
||||||
|
logger.warning('Windows系统不支持多线程读取数据,已自动关闭!')
|
||||||
|
|
||||||
|
# 获取数据
|
||||||
|
def __setup_dataloader(self, augment_conf_path=None, is_train=False):
|
||||||
|
# 获取训练数据
|
||||||
|
if augment_conf_path is not None and os.path.exists(augment_conf_path) and is_train:
|
||||||
|
augmentation_config = io.open(augment_conf_path, mode='r', encoding='utf8').read()
|
||||||
|
else:
|
||||||
|
if augment_conf_path is not None and not os.path.exists(augment_conf_path):
|
||||||
|
logger.info('数据增强配置文件{}不存在'.format(augment_conf_path))
|
||||||
|
augmentation_config = '{}'
|
||||||
|
# 兼容旧的配置文件
|
||||||
|
if 'max_duration' not in self.configs.dataset_conf:
|
||||||
|
self.configs.dataset_conf.max_duration = self.configs.dataset_conf.chunk_duration
|
||||||
|
if is_train:
|
||||||
|
self.train_dataset = CustomDataset(data_list_path=self.configs.dataset_conf.train_list,
|
||||||
|
do_vad=self.configs.dataset_conf.do_vad,
|
||||||
|
max_duration=self.configs.dataset_conf.max_duration,
|
||||||
|
min_duration=self.configs.dataset_conf.min_duration,
|
||||||
|
augmentation_config=augmentation_config,
|
||||||
|
sample_rate=self.configs.dataset_conf.sample_rate,
|
||||||
|
use_dB_normalization=self.configs.dataset_conf.use_dB_normalization,
|
||||||
|
target_dB=self.configs.dataset_conf.target_dB,
|
||||||
|
mode='train')
|
||||||
|
train_sampler = None
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
# 设置支持多卡训练
|
||||||
|
train_sampler = DistributedSampler(dataset=self.train_dataset)
|
||||||
|
self.train_loader = DataLoader(dataset=self.train_dataset,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
shuffle=(train_sampler is None),
|
||||||
|
batch_size=self.configs.dataset_conf.batch_size,
|
||||||
|
sampler=train_sampler,
|
||||||
|
num_workers=self.configs.dataset_conf.num_workers)
|
||||||
|
# 获取测试数据
|
||||||
|
self.test_dataset = CustomDataset(data_list_path=self.configs.dataset_conf.test_list,
|
||||||
|
do_vad=self.configs.dataset_conf.do_vad,
|
||||||
|
max_duration=self.configs.dataset_conf.max_duration,
|
||||||
|
min_duration=self.configs.dataset_conf.min_duration,
|
||||||
|
sample_rate=self.configs.dataset_conf.sample_rate,
|
||||||
|
use_dB_normalization=self.configs.dataset_conf.use_dB_normalization,
|
||||||
|
target_dB=self.configs.dataset_conf.target_dB,
|
||||||
|
mode='eval')
|
||||||
|
self.test_loader = DataLoader(dataset=self.test_dataset,
|
||||||
|
batch_size=self.configs.dataset_conf.batch_size,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
num_workers=self.configs.dataset_conf.num_workers)
|
||||||
|
|
||||||
|
def __setup_model(self, input_size, is_train=False):
|
||||||
|
|
||||||
|
use_loss = self.configs.get('use_loss', 'AAMLoss')
|
||||||
|
# 获取模型
|
||||||
|
if self.configs.use_model == 'EcapaTdnn' or self.configs.use_model == 'ecapa_tdnn':
|
||||||
|
backbone = EcapaTdnn(input_size=input_size, **self.configs.model_conf)
|
||||||
|
else:
|
||||||
|
raise Exception(f'{self.configs.use_model} 模型不存在!')
|
||||||
|
|
||||||
|
self.model = SpeakerIdetification(backbone=backbone,
|
||||||
|
num_class=self.configs.dataset_conf.num_speakers,
|
||||||
|
loss_type=use_loss)
|
||||||
|
self.model.to(self.device)
|
||||||
|
# 打印模型信息
|
||||||
|
summary(self.model, (1, 98, self.audio_featurizer.feature_dim))
|
||||||
|
# print(self.model)
|
||||||
|
# 获取损失函数
|
||||||
|
if use_loss == 'AAMLoss':
|
||||||
|
self.loss = AAMLoss()
|
||||||
|
elif use_loss == 'AMLoss':
|
||||||
|
self.loss = AMLoss()
|
||||||
|
elif use_loss == 'ARMLoss':
|
||||||
|
self.loss = ARMLoss()
|
||||||
|
elif use_loss == 'CELoss':
|
||||||
|
self.loss = CELoss()
|
||||||
|
else:
|
||||||
|
raise Exception(f'没有{use_loss}损失函数!')
|
||||||
|
if is_train:
|
||||||
|
# 获取优化方法
|
||||||
|
optimizer = self.configs.optimizer_conf.optimizer
|
||||||
|
if optimizer == 'Adam':
|
||||||
|
self.optimizer = torch.optim.Adam(params=self.model.parameters(),
|
||||||
|
lr=float(self.configs.optimizer_conf.learning_rate),
|
||||||
|
weight_decay=float(self.configs.optimizer_conf.weight_decay))
|
||||||
|
elif optimizer == 'AdamW':
|
||||||
|
self.optimizer = torch.optim.AdamW(params=self.model.parameters(),
|
||||||
|
lr=float(self.configs.optimizer_conf.learning_rate),
|
||||||
|
weight_decay=float(self.configs.optimizer_conf.weight_decay))
|
||||||
|
elif optimizer == 'SGD':
|
||||||
|
self.optimizer = torch.optim.SGD(params=self.model.parameters(),
|
||||||
|
momentum=self.configs.optimizer_conf.momentum,
|
||||||
|
lr=float(self.configs.optimizer_conf.learning_rate),
|
||||||
|
weight_decay=float(self.configs.optimizer_conf.weight_decay))
|
||||||
|
else:
|
||||||
|
raise Exception(f'不支持优化方法:{optimizer}')
|
||||||
|
# 学习率衰减函数
|
||||||
|
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=int(self.configs.train_conf.max_epoch * 1.2))
|
||||||
|
|
||||||
|
def __load_pretrained(self, pretrained_model):
|
||||||
|
# 加载预训练模型
|
||||||
|
if pretrained_model is not None:
|
||||||
|
if os.path.isdir(pretrained_model):
|
||||||
|
pretrained_model = os.path.join(pretrained_model, 'model.pt')
|
||||||
|
assert os.path.exists(pretrained_model), f"{pretrained_model} 模型不存在!"
|
||||||
|
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
|
||||||
|
model_dict = self.model.module.state_dict()
|
||||||
|
else:
|
||||||
|
model_dict = self.model.state_dict()
|
||||||
|
model_state_dict = torch.load(pretrained_model)
|
||||||
|
# 过滤不存在的参数
|
||||||
|
for name, weight in model_dict.items():
|
||||||
|
if name in model_state_dict.keys():
|
||||||
|
if list(weight.shape) != list(model_state_dict[name].shape):
|
||||||
|
logger.warning('{} not used, shape {} unmatched with {} in model.'.
|
||||||
|
format(name, list(model_state_dict[name].shape), list(weight.shape)))
|
||||||
|
model_state_dict.pop(name, None)
|
||||||
|
else:
|
||||||
|
logger.warning('Lack weight: {}'.format(name))
|
||||||
|
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
|
||||||
|
self.model.module.load_state_dict(model_state_dict, strict=False)
|
||||||
|
else:
|
||||||
|
self.model.load_state_dict(model_state_dict, strict=False)
|
||||||
|
logger.info('成功加载预训练模型:{}'.format(pretrained_model))
|
||||||
|
|
||||||
|
    def __load_checkpoint(self, save_model_path, resume_model):
        # Resume training from a saved checkpoint
        last_epoch = -1
        best_eer = 1
        last_model_dir = os.path.join(save_model_path,
                                      f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                      'last_model')
        if resume_model is not None or (os.path.exists(os.path.join(last_model_dir, 'model.pt'))
                                        and os.path.exists(os.path.join(last_model_dir, 'optimizer.pt'))):
            # Automatically pick up the most recently saved model
            if resume_model is None:
                resume_model = last_model_dir
            assert os.path.exists(os.path.join(resume_model, 'model.pt')), "Model parameter file does not exist!"
            assert os.path.exists(os.path.join(resume_model, 'optimizer.pt')), "Optimizer parameter file does not exist!"
            state_dict = torch.load(os.path.join(resume_model, 'model.pt'))
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            self.optimizer.load_state_dict(torch.load(os.path.join(resume_model, 'optimizer.pt')))
            with open(os.path.join(resume_model, 'model.state'), 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                last_epoch = json_data['last_epoch'] - 1
                best_eer = json_data['eer']
            logger.info('Successfully restored model and optimizer parameters: {}'.format(resume_model))
        return last_epoch, best_eer

    # Save a checkpoint
    def __save_checkpoint(self, save_model_path, epoch_id, best_eer=0., best_model=False):
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        if best_model:
            model_path = os.path.join(save_model_path,
                                      f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                      'best_model')
        else:
            model_path = os.path.join(save_model_path,
                                      f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                      'epoch_{}'.format(epoch_id))
        os.makedirs(model_path, exist_ok=True)
        torch.save(self.optimizer.state_dict(), os.path.join(model_path, 'optimizer.pt'))
        torch.save(state_dict, os.path.join(model_path, 'model.pt'))
        with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
            data = {"last_epoch": epoch_id, "eer": best_eer, "version": __version__}
            f.write(json.dumps(data))
        if not best_model:
            last_model_path = os.path.join(save_model_path,
                                           f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                           'last_model')
            shutil.rmtree(last_model_path, ignore_errors=True)
            shutil.copytree(model_path, last_model_path)
            # Delete stale checkpoints
            old_model_path = os.path.join(save_model_path,
                                          f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                          'epoch_{}'.format(epoch_id - 3))
            if os.path.exists(old_model_path):
                shutil.rmtree(old_model_path)
        logger.info('Model saved: {}'.format(model_path))

    def __train_epoch(self, epoch_id, save_model_path, local_rank, writer, nranks=0):
        # Train for one epoch
        train_times, accuracies, loss_sum = [], [], []
        start = time.time()
        sum_batch = len(self.train_loader) * self.configs.train_conf.max_epoch
        for batch_id, (audio, label, input_lens_ratio) in enumerate(self.train_loader):
            if nranks > 1:
                audio = audio.to(local_rank)
                input_lens_ratio = input_lens_ratio.to(local_rank)
                label = label.to(local_rank).long()
            else:
                audio = audio.to(self.device)
                input_lens_ratio = input_lens_ratio.to(self.device)
                label = label.to(self.device).long()

            # Extract acoustic features (e.g. MFCC) from the raw audio
            features, _ = self.audio_featurizer(audio, input_lens_ratio)
            output = self.model(features)
            # Compute the loss
            los = self.loss(output, label)
            self.optimizer.zero_grad()
            los.backward()
            self.optimizer.step()

            # Compute the accuracy
            output = torch.nn.functional.softmax(output, dim=-1)
            output = output.data.cpu().numpy()
            output = np.argmax(output, axis=1)
            label = label.data.cpu().numpy()
            acc = np.mean((output == label).astype(int))
            accuracies.append(acc)
            loss_sum.append(los.item())  # detach the scalar so the graph is not kept for the whole epoch
            train_times.append((time.time() - start) * 1000)

            # In multi-GPU training, only one process logs
            if batch_id % self.configs.train_conf.log_interval == 0 and local_rank == 0:
                # Training throughput, in samples per second
                train_speed = self.configs.dataset_conf.batch_size / (sum(train_times) / len(train_times) / 1000)
                # Estimated time remaining
                eta_sec = (sum(train_times) / len(train_times)) * (
                        sum_batch - (epoch_id - 1) * len(self.train_loader) - batch_id)
                eta_str = str(timedelta(seconds=int(eta_sec / 1000)))
                logger.info(f'Train epoch: [{epoch_id}/{self.configs.train_conf.max_epoch}], '
                            f'batch: [{batch_id}/{len(self.train_loader)}], '
                            f'loss: {sum(loss_sum) / len(loss_sum):.5f}, '
                            f'accuracy: {sum(accuracies) / len(accuracies):.5f}, '
                            f'learning rate: {self.scheduler.get_last_lr()[0]:>.8f}, '
                            f'speed: {train_speed:.2f} data/sec, eta: {eta_str}')
                writer.add_scalar('Train/Loss', sum(loss_sum) / len(loss_sum), self.train_step)
                writer.add_scalar('Train/Accuracy', (sum(accuracies) / len(accuracies)), self.train_step)
                # Log the learning rate
                writer.add_scalar('Train/lr', self.scheduler.get_last_lr()[0], self.train_step)
                self.train_step += 1
                train_times = []
            # Also save a checkpoint every fixed number of steps
            if batch_id % 10000 == 0 and batch_id != 0 and local_rank == 0:
                self.__save_checkpoint(save_model_path=save_model_path, epoch_id=epoch_id)
            start = time.time()
        self.scheduler.step()

    def train(self,
              save_model_path='models/',
              resume_model=None,
              pretrained_model=None,
              augment_conf_path='configs/augmentation.json'):
        """
        Train the model
        :param save_model_path: directory to save models to
        :param resume_model: checkpoint to resume training from; if None, the most recent checkpoint is picked up automatically when one exists
        :param pretrained_model: path to a pretrained model; if None, no pretrained model is used
        :param augment_conf_path: data augmentation configuration file, in JSON format
        """
        # Number of GPUs available for training
        nranks = torch.cuda.device_count()
        local_rank = 0
        writer = None
        if local_rank == 0:
            # Logging writer
            writer = LogWriter(logdir='log')

        if nranks > 1 and self.use_gpu:
            # Initialize the NCCL environment
            dist.init_process_group(backend='nccl')
            local_rank = int(os.environ["LOCAL_RANK"])
        # Prepare the data
        self.__setup_dataloader(augment_conf_path=augment_conf_path, is_train=True)
        # Prepare the model
        self.__setup_model(input_size=self.audio_featurizer.feature_dim, is_train=True)

        # Support multi-GPU training
        if nranks > 1 and self.use_gpu:
            self.model.to(local_rank)
            self.audio_featurizer.to(local_rank)
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank])
        logger.info('Training samples: {}'.format(len(self.train_dataset)))

        self.__load_pretrained(pretrained_model=pretrained_model)
        # Resume from a checkpoint if one exists
        last_epoch, best_eer = self.__load_checkpoint(save_model_path=save_model_path, resume_model=resume_model)
        if last_epoch > 0:
            self.optimizer.step()
            for _ in range(last_epoch):
                self.scheduler.step()

        test_step, self.train_step = 0, 0
        last_epoch += 1
        if local_rank == 0:
            writer.add_scalar('Train/lr', self.scheduler.get_last_lr()[0], last_epoch)
        # Start training
        for epoch_id in range(last_epoch, self.configs.train_conf.max_epoch):
            epoch_id += 1
            start_epoch = time.time()
            # Train one epoch
            self.__train_epoch(epoch_id=epoch_id, save_model_path=save_model_path, local_rank=local_rank,
                               writer=writer, nranks=nranks)
            # In multi-GPU training, only one process evaluates and saves models
            if local_rank == 0:
                logger.info('=' * 70)
                tpr, fpr, eer, threshold = self.evaluate(resume_model=None)
                logger.info('Test epoch: {}, time/epoch: {}, threshold: {:.2f}, tpr: {:.5f}, fpr: {:.5f}, '
                            'eer: {:.5f}'.format(epoch_id, str(timedelta(
                    seconds=(time.time() - start_epoch))), threshold, tpr, fpr, eer))
                logger.info('=' * 70)
                writer.add_scalar('Test/threshold', threshold, test_step)
                writer.add_scalar('Test/tpr', tpr, test_step)
                writer.add_scalar('Test/fpr', fpr, test_step)
                writer.add_scalar('Test/eer', eer, test_step)
                test_step += 1
                self.model.train()
                # Save the best model
                if eer <= best_eer:
                    best_eer = eer
                    self.__save_checkpoint(save_model_path=save_model_path, epoch_id=epoch_id, best_eer=eer,
                                           best_model=True)
                # Save a regular checkpoint
                self.__save_checkpoint(save_model_path=save_model_path, epoch_id=epoch_id, best_eer=eer)

    def evaluate(self, resume_model='models/EcapaTdnn_MFCC/best_model/', save_image_path=None):
        """
        Evaluate the model
        :param resume_model: model to evaluate
        :param save_image_path: directory to save the tpr/fpr result plot to
        :return: evaluation results
        """
        if self.test_loader is None:
            self.__setup_dataloader()
        if self.model is None:
            self.__setup_model(input_size=self.audio_featurizer.feature_dim)
        if resume_model is not None:
            if os.path.isdir(resume_model):
                resume_model = os.path.join(resume_model, 'model.pt')
            assert os.path.exists(resume_model), f"{resume_model} does not exist!"
            model_state_dict = torch.load(resume_model)
            self.model.load_state_dict(model_state_dict)
            logger.info(f'Successfully loaded model: {resume_model}')
        self.model.eval()
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            eval_model = self.model.module
        else:
            eval_model = self.model

        features, labels = None, None
        losses = []
        with torch.no_grad():
            for batch_id, (audio, label, input_lens_ratio) in enumerate(tqdm(self.test_loader)):
                audio = audio.to(self.device)
                input_lens_ratio = input_lens_ratio.to(self.device)
                label = label.to(self.device).long()
                audio_features, _ = self.audio_featurizer(audio, input_lens_ratio)
                # logits = eval_model(audio_features)
                # loss = self.loss(logits, label)  # note: the loss takes the logits, not the extracted embeddings
                # losses.append(loss.item())
                feature = eval_model.backbone(audio_features).data.cpu().numpy()
                label = label.data.cpu().numpy()
                # Accumulate the embeddings
                features = np.concatenate((features, feature)) if features is not None else feature
                labels = np.concatenate((labels, label)) if labels is not None else label
        # print('Test loss: {:.5f}'.format(sum(losses) / len(losses)))
        self.model.train()
        metric = TprAtFpr()
        labels = labels.astype(np.int32)
        print('Comparing audio features pairwise...')
        for i in tqdm(range(len(features))):
            feature_1 = features[i]
            feature_1 = np.expand_dims(feature_1, 0).repeat(len(features) - i, axis=0)
            feature_2 = features[i:]
            feature_1 = torch.tensor(feature_1, dtype=torch.float32)
            feature_2 = torch.tensor(feature_2, dtype=torch.float32)
            score = torch.nn.functional.cosine_similarity(feature_1, feature_2, dim=-1).data.cpu().numpy().tolist()
            y_true = np.array(labels[i] == labels[i:]).astype(np.int32).tolist()
            metric.add(y_true, score)
        tprs, fprs, thresholds, eer, index = metric.calculate()
        tpr, fpr, threshold = tprs[index], fprs[index], thresholds[index]
        if save_image_path:
            import matplotlib.pyplot as plt
            plt.plot(thresholds, tprs, color='blue', linestyle='-', label='tpr')
            plt.plot(thresholds, fprs, color='red', linestyle='-', label='fpr')
            plt.plot(threshold, tpr, 'bo-')
            plt.text(threshold, tpr, (threshold, round(tpr, 5)), color='blue')
            plt.plot(threshold, fpr, 'ro-')
            plt.text(threshold, fpr, (threshold, round(fpr, 5)), color='red')
            plt.xlabel('threshold')
            plt.title('tpr and fpr')
            plt.grid(True)  # show grid lines
            # Save the plot
            os.makedirs(save_image_path, exist_ok=True)
            plt.savefig(os.path.join(save_image_path, 'result.png'))
            logger.info(f"Result plot saved to: {os.path.join(save_image_path, 'result.png')}")
        return tpr, fpr, eer, threshold

    def export(self, save_model_path='models/', resume_model='models/EcapaTdnn_MelSpectrogram/best_model/'):
        """
        Export an inference model
        :param save_model_path: directory to save the model to
        :param resume_model: path of the model to convert
        :return:
        """
        # Prepare the model
        self.__setup_model(input_size=self.audio_featurizer.feature_dim)
        # Load the trained weights
        if os.path.isdir(resume_model):
            resume_model = os.path.join(resume_model, 'model.pt')
        assert os.path.exists(resume_model), f"{resume_model} does not exist!"
        model_state_dict = torch.load(resume_model)
        self.model.load_state_dict(model_state_dict)
        logger.info('Successfully restored model parameters: {}'.format(resume_model))
        self.model.eval()
        # Build a static (TorchScript) model
        infer_model = torch.jit.script(self.model.backbone)
        infer_model_path = os.path.join(save_model_path,
                                        f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                        'inference.pt')
        os.makedirs(os.path.dirname(infer_model_path), exist_ok=True)
        torch.jit.save(infer_model, infer_model_path)
        logger.info("Inference model saved: {}".format(infer_model_path))
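
The trainer above exposes `train`, `evaluate`, and `export` as its public surface. A minimal sketch of driving the latter two from a script, assuming a config at `configs/ecapa_tdnn.yml` and a trained checkpoint under `models/EcapaTdnn_MFCC/best_model/` (both paths are illustrative):

```python
from mvector.trainer import MVectorTrainer

# Build the trainer from the same YAML config used for training
trainer = MVectorTrainer(configs='configs/ecapa_tdnn.yml', use_gpu=True)

# Evaluate a trained checkpoint; also saves a tpr/fpr-vs-threshold plot
tpr, fpr, eer, threshold = trainer.evaluate(resume_model='models/EcapaTdnn_MFCC/best_model/',
                                            save_image_path='output/images/')
print(f'eer={eer:.5f} at threshold={threshold:.2f}')

# Export the backbone as a TorchScript model for deployment
trainer.export(save_model_path='models/', resume_model='models/EcapaTdnn_MFCC/best_model/')
```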
0
mvector/utils/__init__.py
Normal file
BIN
mvector/utils/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/utils/__pycache__/logger.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/utils/__pycache__/utils.cpython-37.pyc
Normal file
Binary file not shown.
89
mvector/utils/logger.py
Normal file
@ -0,0 +1,89 @@
import datetime
import logging
import os
import sys
import termcolor

__all__ = ['setup_logger']

logger_initialized = []


def setup_logger(name, output=None):
    """
    Initialize logger and set its verbosity level to INFO.
    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        name (str): the root module name of this logger

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    logger.setLevel(logging.INFO)
    logger.propagate = False

    formatter = ("[%(asctime2)s %(levelname2)s] %(module2)s:%(funcName2)s:%(lineno2)s - %(message2)s")
    color_formatter = ColoredFormatter(formatter, datefmt="%m/%d %H:%M:%S")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(color_formatter)
    logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        # Create the directory if needed; do not fail when it already exists
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter())
        logger.addHandler(fh)
    logger_initialized.append(name)
    return logger


COLORS = {
    "WARNING": "yellow",
    "INFO": "white",
    "DEBUG": "blue",
    "CRITICAL": "red",
    "ERROR": "red",
}


class ColoredFormatter(logging.Formatter):
    def __init__(self, fmt, datefmt, use_color=True):
        logging.Formatter.__init__(self, fmt, datefmt=datefmt)
        self.use_color = use_color

    def format(self, record):
        levelname = record.levelname
        if self.use_color and levelname in COLORS:

            def colored(text):
                return termcolor.colored(
                    text,
                    color=COLORS[levelname],
                    attrs={"bold": True},
                )

            record.levelname2 = colored("{:<7}".format(record.levelname))
            record.message2 = colored(record.msg)

            asctime2 = datetime.datetime.fromtimestamp(record.created)
            record.asctime2 = termcolor.colored(asctime2, color="green")

            record.module2 = termcolor.colored(record.module, color="cyan")
            record.funcName2 = termcolor.colored(record.funcName, color="cyan")
            record.lineno2 = termcolor.colored(record.lineno, color="cyan")
        return logging.Formatter.format(self, record)
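
A short usage sketch for `setup_logger` (the logger name and output directory are illustrative):

```python
from mvector.utils.logger import setup_logger

# Console-only colored logging
logger = setup_logger(__name__)
logger.info('hello')

# Additionally append to a file; a bare directory becomes `<dir>/log.txt`
file_logger = setup_logger('train', output='log')
file_logger.warning('low disk space')
```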
31
mvector/utils/record.py
Normal file
@ -0,0 +1,31 @@
import os

import soundcard
import soundfile


class RecordAudio:
    def __init__(self, channels=1, sample_rate=16000):
        # Recording parameters
        self.channels = channels
        self.sample_rate = sample_rate

        # Get the default microphone
        self.default_mic = soundcard.default_microphone()

    def record(self, record_seconds=3, save_path=None):
        """Record audio from the microphone.

        :param record_seconds: recording duration in seconds, 3 by default
        :param save_path: path to save the recording to, with a .wav extension
        :return: the audio as numpy data
        """
        print("Recording started......")
        num_frames = int(record_seconds * self.sample_rate)
        data = self.default_mic.record(samplerate=self.sample_rate, numframes=num_frames, channels=self.channels)
        audio_data = data.squeeze()
        print("Recording finished!")
        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            soundfile.write(save_path, data=data, samplerate=self.sample_rate)
        return audio_data
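
A quick sketch of recording a clip with `RecordAudio` (the save path is illustrative):

```python
from mvector.utils.record import RecordAudio

recorder = RecordAudio(channels=1, sample_rate=16000)
# Record 3 seconds from the default microphone and save it as a wav file
audio = recorder.record(record_seconds=3, save_path='audio_db/my_voice.wav')
print(audio.shape)  # 48000 samples for a 3-second mono clip at 16 kHz
```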
85
mvector/utils/utils.py
Normal file
@ -0,0 +1,85 @@
import distutils.util

import numpy as np
from tqdm import tqdm

from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


def print_arguments(args=None, configs=None):
    if args:
        logger.info("----------- Extra arguments -----------")
        for arg, value in sorted(vars(args).items()):
            logger.info("%s: %s" % (arg, value))
        logger.info("------------------------------------------------")
    if configs:
        logger.info("----------- Configuration file arguments -----------")
        for arg, value in sorted(configs.items()):
            if isinstance(value, dict):
                logger.info(f"{arg}:")
                for a, v in sorted(value.items()):
                    if isinstance(v, dict):
                        logger.info(f"\t{a}:")
                        for a1, v1 in sorted(v.items()):
                            logger.info("\t\t%s: %s" % (a1, v1))
                    else:
                        logger.info("\t%s: %s" % (a, v))
            else:
                logger.info("%s: %s" % (arg, value))
        logger.info("------------------------------------------------")


def add_arguments(argname, type, default, help, argparser, **kwargs):
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' Default: %(default)s.',
                           **kwargs)


class Dict(dict):
    __setattr__ = dict.__setitem__
    __getattr__ = dict.__getitem__


def dict_to_object(dict_obj):
    if not isinstance(dict_obj, dict):
        return dict_obj
    inst = Dict()
    for k, v in dict_obj.items():
        inst[k] = dict_to_object(v)
    return inst


# Compute the accuracy and the best threshold from cosine similarity scores
def cal_accuracy_threshold(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_accuracy = 0
    best_threshold = 0
    for i in tqdm(range(0, 100)):
        threshold = i * 0.01
        y_test = (y_score >= threshold)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_accuracy:
            best_accuracy = acc
            best_threshold = threshold

    return best_accuracy, best_threshold


# Compute the accuracy at a fixed threshold from cosine similarity scores
def cal_accuracy(y_score, y_true, threshold=0.5):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    y_test = (y_score >= threshold)
    accuracy = np.mean((y_test == y_true).astype(int))
    return accuracy


# Compute the cosine similarity of two vectors
def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
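
For reference, a tiny worked example of the similarity helpers above (the vectors and scores are toy data, not real embeddings):

```python
import numpy as np

from mvector.utils.utils import cal_accuracy_threshold, cosin_metric

# Cosine similarity of two toy "embeddings"
a = np.array([1.0, 0.0])
b = np.array([0.9, 0.1])
print(cosin_metric(a, b))  # close to 1.0: nearly the same direction

# Sweep thresholds 0.00..0.99 to find the one that maximizes accuracy
y_score = [0.92, 0.15, 0.78, 0.30]   # pairwise similarity scores
y_true = [1, 0, 1, 0]                # 1 = same speaker, 0 = different
best_acc, best_thr = cal_accuracy_threshold(y_score, y_true)
print(best_acc, best_thr)
```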
8
requirements.txt
Normal file
@ -0,0 +1,8 @@
numba>=0.52.0
librosa>=0.9.1
numpy>=1.19.2
tqdm>=4.59.0
visualdl>=2.1.1
resampy==0.2.2
soundfile>=0.12.1
soundcard>=0.4.2
45
setup.py
Normal file
@ -0,0 +1,45 @@
from setuptools import setup, find_packages

import mvector

VERSION = mvector.__version__


def readme():
    with open('README.md', encoding='utf-8') as f:
        content = f.read()
    return content


def parse_requirements():
    with open('./requirements.txt', encoding="utf-8") as f:
        requirements = f.readlines()
    return requirements


if __name__ == "__main__":
    setup(
        name='mvector',
        packages=find_packages(),
        author='yeyupiaoling',
        version=VERSION,
        install_requires=parse_requirements(),
        description='Voice Print Recognition toolkit on Pytorch',
        long_description=readme(),
        long_description_content_type='text/markdown',
        url='https://github.com/yeyupiaoling/VoiceprintRecognition_Pytorch',
        download_url='https://github.com/yeyupiaoling/VoiceprintRecognition_Pytorch.git',
        keywords=['Voice', 'Pytorch'],
        classifiers=[
            'Intended Audience :: Developers',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Natural Language :: Chinese (Simplified)',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.5',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Programming Language :: Python :: 3.9', 'Topic :: Utilities'
        ],
        license='Apache License 2.0',
        ext_modules=[])
26
train.py
Normal file
@ -0,0 +1,26 @@
import argparse
import functools

from mvector.trainer import MVectorTrainer
from mvector.utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs',          str,  'configs/ecapa_tdnn.yml',    'Configuration file')
add_arg("local_rank",       int,  0,                           'Argument required for multi-GPU training')
add_arg("use_gpu",          bool, True,                        'Whether to train on GPU')
add_arg('augment_conf_path',str,  'configs/augmentation.json', 'Data augmentation configuration file, in JSON format')
add_arg('save_model_path',  str,  'models/',                   'Directory to save models to')
add_arg('resume_model',     str,  None,                        'Checkpoint to resume training from; if None, training starts from scratch')
add_arg('save_image_path',  str,  'output/images/',            'Directory to save result plots to')
add_arg('pretrained_model', str,  'models/ecapa_tdnn_MFCC/best_model/', 'Path to a pretrained model; if None, no pretrained model is used')
args = parser.parse_args()
print_arguments(args=args)

# Build the trainer
trainer = MVectorTrainer(configs=args.configs, use_gpu=args.use_gpu)

trainer.train(save_model_path=args.save_model_path,
              resume_model=args.resume_model,
              pretrained_model=args.pretrained_model,
              augment_conf_path=args.augment_conf_path)
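
Launch-wise, the trainer branches on `torch.cuda.device_count()` and reads `LOCAL_RANK` when more than one GPU is visible, so a multi-GPU run needs a distributed launcher to set that variable. A hedged sketch (assuming `torchrun` from a recent PyTorch install; the flags shown are standard `torchrun` options, not defined by this repository):

```shell
# Single GPU (or CPU with --use_gpu=False)
python train.py --configs=configs/ecapa_tdnn.yml

# Two GPUs: torchrun spawns one process per card and sets LOCAL_RANK
torchrun --nproc_per_node=2 train.py --configs=configs/ecapa_tdnn.yml
```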