first commit
commit 9dbcf5c730
97
README.md
Normal file
@ -0,0 +1,97 @@
# Preface

Environment:

- Anaconda 3
- Python 3.8
- Pytorch 1.13.1
- Windows 10 or Ubuntu 18.04

# Project Features

1. Supported models: EcapaTdnn, TDNN, Res2Net, ResNetSE
2. Supported pooling layers: AttentiveStatsPool (ASP), SelfAttentivePooling (SAP), TemporalStatisticsPooling (TSP), TemporalAveragePooling (TAP)
3. Supported loss functions: AAMLoss, AMLoss, ARMLoss, CELoss
4. Supported preprocessing methods: MelSpectrogram, Spectrogram, MFCC

## Installation

- First install the GPU version of Pytorch; skip this step if it is already installed.

```shell
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
```

- Install the mvector library.

Install it with pip:

```shell
python -m pip install mvector -U -i https://pypi.tuna.tsinghua.edu.cn/simple
```

# Usage Guide

## 1. Environment Preparation

### 1.1 Install Dependencies

```shell
# Create a conda environment (optional)
conda create -n voiceprint python=3.8
conda activate voiceprint

# Install project dependencies
pip install -r requirements.txt
```

### 1.2 Prepare Audio Data

- Put enrollment audio in the `audio_db/` directory (16 kHz mono WAV is recommended); the sketch below shows one way to convert a recording into that format.
- Test audio should be placed in the `test_audio/` directory.
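
A minimal conversion sketch (not one of the project's scripts; it uses `soundfile` and `resampy`, which the project already depends on, and the file names are placeholders):

```python
# Convert an arbitrary recording to the recommended 16 kHz mono WAV
# before dropping it into audio_db/. Illustrative only.
import numpy as np
import resampy
import soundfile

samples, sr = soundfile.read("some_recording.wav", dtype="float32")
if samples.ndim > 1:          # down-mix to mono
    samples = np.mean(samples, axis=1)
if sr != 16000:               # resample to 16 kHz
    samples = resampy.resample(samples, sr, 16000)
soundfile.write("audio_db/user1.wav", samples, 16000, subtype="PCM_16")
```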

## 2. Core Features

### 2.1 Train a Voiceprint Model

```shell
python train.py \
    --config_path configs/ecapa_tdnn.yml \
    --augmentation_config configs/augmentation.json \
    --save_dir models/
```

### 2.2 Voiceprint Enrollment

```python
from mvector import MVector
mvector = MVector()
mvector.register_user(name="user1", audio_path="audio_db/user1.wav")
```

### 2.3 Real-Time Voiceprint Recognition

```shell
python infer_recognition.py \
    --model_path models/ecapa_tdnn.pth \
    --audio_path test_audio/unknown.wav
```

### 2.4 Voiceprint Comparison

```shell
python infer_contrast.py \
    --audio1 audio_db/user1.wav \
    --audio2 test_audio/sample.wav \
    --threshold 0.7
```

## 3. Noise-Reduction Preprocessing

```python
from Reduction_Noise import NoiseReducer
reducer = NoiseReducer("Reduction_Noise/pytorch_model.bin")
clean_audio = reducer.process("noisy_audio.wav")
```

## 4. Model Evaluation

```shell
python eval.py \
    --model_path models/ecapa_tdnn.pth \
    --test_csv eval_samples.csv \
    --batch_size 32
```
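
This commit also adds a FastAPI service (`main.py`, listening on port 9001) that exposes `/recognize` and `/compare`. A minimal client sketch (assumes the `requests` package is installed; host and file names are placeholders):

```python
# Hypothetical client for the service started by `python main.py`.
import requests

base = "http://127.0.0.1:9001"

# Identify the speaker of one uploaded audio file.
with open("test_audio/unknown.wav", "rb") as f:
    r = requests.post(f"{base}/recognize", files={"file": f})
print(r.json())

# Compare two audio files and get a similarity score.
with open("audio_db/user1.wav", "rb") as f1, open("test_audio/sample.wav", "rb") as f2:
    r = requests.post(f"{base}/compare", files={"file1": f1, "file2": f2})
print(r.json())
```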
1
Reduction_Noise
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit cfc4b6a2433a4a6f0d2d1cda5f944d677e072ef4
|
||||
BIN
audio_db/output1.wav
Normal file
Binary file not shown.
BIN
audio_db/output2.wav
Normal file
Binary file not shown.
BIN
audio_db/test.wav
Normal file
Binary file not shown.
BIN
audio_db/test_Re.wav
Normal file
Binary file not shown.
72
configs/augmentation.json
Normal file
@ -0,0 +1,72 @@
|
||||
[
|
||||
{
|
||||
"type": "noise",
|
||||
"aug_type": "audio",
|
||||
"params": {
|
||||
"min_snr_dB": 10,
|
||||
"max_snr_dB": 50,
|
||||
"repetition": 2,
|
||||
"noise_dir": "dataset/noise/"
|
||||
},
|
||||
"prob": 0.0
|
||||
},
|
||||
{
|
||||
"type": "resample",
|
||||
"aug_type": "audio",
|
||||
"params": {
|
||||
"new_sample_rate": [8000, 32000, 44100, 48000]
|
||||
},
|
||||
"prob": 0.0
|
||||
},
|
||||
{
|
||||
"type": "speed",
|
||||
"aug_type": "audio",
|
||||
"params": {
|
||||
"min_speed_rate": 0.9,
|
||||
"max_speed_rate": 1.1,
|
||||
"num_rates": 3
|
||||
},
|
||||
"prob": 0.5
|
||||
},
|
||||
{
|
||||
"type": "shift",
|
||||
"aug_type": "audio",
|
||||
"params": {
|
||||
"min_shift_ms": -5,
|
||||
"max_shift_ms": 5
|
||||
},
|
||||
"prob": 0.0
|
||||
},
|
||||
{
|
||||
"type": "volume",
|
||||
"aug_type": "audio",
|
||||
"params": {
|
||||
"min_gain_dBFS": -15,
|
||||
"max_gain_dBFS": 15
|
||||
},
|
||||
"prob": 0.5
|
||||
},
|
||||
{
|
||||
"type": "specaug",
|
||||
"aug_type": "feature",
|
||||
"params": {
|
||||
"inplace": true,
|
||||
"max_time_warp": 5,
|
||||
"max_t_ratio": 0.01,
|
||||
"n_freq_masks": 2,
|
||||
"max_f_ratio": 0.05,
|
||||
"n_time_masks": 2,
|
||||
"replace_with_zero": true
|
||||
},
|
||||
"prob": 0.5
|
||||
},
|
||||
{
|
||||
"type": "specsub",
|
||||
"aug_type": "feature",
|
||||
"params": {
|
||||
"max_t": 10,
|
||||
"num_t_sub": 2
|
||||
},
|
||||
"prob": 0.0
|
||||
}
|
||||
]
|
||||
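For context, the "prob" field of each entry above is the probability that the corresponding augmentor fires on a given sample; `AugmentationPipeline.transform_audio` (shown later in this commit) applies it essentially as in this simplified sketch:

```python
import random

# Simplified view of how the pipeline gates each augmentor by its "prob" value;
# every augmentor that wins the coin flip modifies the segment in place.
def apply_augmentors(audio_segment, augmentors, probs):
    for augmentor, prob in zip(augmentors, probs):
        if random.random() < prob:
            augmentor.transform_audio(audio_segment)
```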
54
configs/ecapa_tdnn.yml
Normal file
@ -0,0 +1,54 @@
|
||||
# 数据集参数
|
||||
dataset_conf:
|
||||
# 训练的批量大小
|
||||
batch_size: 256
|
||||
# 说话人数量,即分类大小
|
||||
num_speakers: 3242
|
||||
# 读取数据的线程数量
|
||||
num_workers: 12
|
||||
# 过滤最短的音频长度
|
||||
min_duration: 0.5
|
||||
# 最长的音频长度,大于这个长度会裁剪掉
|
||||
max_duration: 6
|
||||
# 是否裁剪静音片段
|
||||
do_vad: False
|
||||
# 音频的采样率
|
||||
sample_rate: 16000
|
||||
# 是否对音频进行音量归一化
|
||||
use_dB_normalization: False
|
||||
# 对音频进行音量归一化的音量分贝值
|
||||
target_dB: -20
|
||||
# 训练数据的数据列表路径
|
||||
train_list: 'dataset/train_list.txt'
|
||||
# 测试数据的数据列表路径
|
||||
test_list: 'dataset/test_list.txt'
|
||||
# 标签列表
|
||||
label_list_path: 'dataset/label_list.txt'
|
||||
|
||||
# 数据预处理参数
|
||||
preprocess_conf:
|
||||
# 音频预处理方法,支持:MelSpectrogram、Spectrogram、MFCC、Fbank
|
||||
feature_method: 'Fbank'
|
||||
|
||||
feature_conf:
|
||||
sample_frequency: 16000
|
||||
num_mel_bins: 80
|
||||
|
||||
optimizer_conf:
|
||||
# 优化方法,支持Adam、AdamW、SGD
|
||||
optimizer: 'Adam'
|
||||
# 初始学习率的大小
|
||||
learning_rate: 0.001
|
||||
weight_decay: 1e-6
|
||||
|
||||
model_conf:
|
||||
embd_dim: 192
|
||||
channels: 512
|
||||
|
||||
train_conf:
|
||||
# 训练的轮数
|
||||
max_epoch: 30
|
||||
log_interval: 100
|
||||
|
||||
# 所使用的模型
|
||||
use_model: 'ecapa_tdnn'
|
||||
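A small sketch of loading and inspecting this configuration (assumes PyYAML; the trainer may parse it differently, but the key names follow the file above):

```python
import yaml

# Load the YAML config and read a few of the fields defined above.
with open("configs/ecapa_tdnn.yml", "r", encoding="utf-8") as f:
    configs = yaml.safe_load(f)

print(configs["use_model"])                           # 'ecapa_tdnn'
print(configs["dataset_conf"]["batch_size"])          # 256
print(configs["preprocess_conf"]["feature_method"])   # 'Fbank'
```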
104
create_data.py
Normal file
@ -0,0 +1,104 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from datetime import timedelta
|
||||
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
# 生成数据列表
|
||||
def get_data_list(infodata_path, zhvoice_path):
|
||||
print('正在读取标注文件...')
|
||||
with open(infodata_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
data = []
|
||||
speakers = []
|
||||
speakers_dict = {}
|
||||
for line in lines:
|
||||
line = json.loads(line.replace('\n', ''))
|
||||
duration_ms = line['duration_ms']
|
||||
if duration_ms < 1300:
|
||||
continue
|
||||
speaker = line['speaker']
|
||||
if speaker not in speakers:
|
||||
speakers_dict[speaker] = len(speakers)
|
||||
speakers.append(speaker)
|
||||
label = speakers_dict[speaker]
|
||||
sound_path = os.path.join(zhvoice_path, line['index'])
|
||||
data.append([sound_path.replace('\\', '/'), label])
|
||||
print(f'一共有{len(data)}条数据!')
|
||||
return data
|
||||
|
||||
|
||||
def mp32wav(num, data_list):
|
||||
start = time.time()
|
||||
for i, data in enumerate(data_list):
|
||||
sound_path, label = data
|
||||
if os.path.exists(sound_path):
|
||||
save_path = sound_path.replace('.mp3', '.wav')
|
||||
if not os.path.exists(save_path):
|
||||
wav = AudioSegment.from_mp3(sound_path)
|
||||
wav.export(save_path, format="wav")
|
||||
os.remove(sound_path)
|
||||
if i % 100 == 0:
|
||||
eta_sec = ((time.time() - start) / 100 * (len(data_list) - i))
|
||||
start = time.time()
|
||||
eta_str = str(timedelta(seconds=int(eta_sec)))
|
||||
print(f'进程{num}进度:[{i}/{len(data_list)}],剩余时间:{eta_str}')
|
||||
|
||||
|
||||
def split_data(list_temp, n):
|
||||
length = len(list_temp) // n
|
||||
for i in range(0, len(list_temp), length):
|
||||
yield list_temp[i:i + length]
|
||||
|
||||
|
||||
def main(infodata_path, list_path, zhvoice_path, to_wav=True, num_workers=2):
|
||||
if to_wav:
|
||||
text = input(f'音频文件将会转换为wav格式,这个过程可能很长,而且最终文件大小接近100G,是否继续?(y/n)')
|
||||
if text is None or text != 'y':
|
||||
return
|
||||
else:
|
||||
text = input(f'将会直接使用MP3格式文件,但读取速度会比wav格式慢,是否继续?(y/n)')
|
||||
if text is None or text != 'y':
|
||||
return
|
||||
data_all = []
|
||||
data = get_data_list(infodata_path=infodata_path, zhvoice_path=zhvoice_path)
|
||||
if to_wav:
|
||||
print('准备把MP3转成WAV格式...')
|
||||
split_d = split_data(data, num_workers)
|
||||
pool = Pool(num_workers)
|
||||
for i, d in enumerate(split_d):
|
||||
pool.apply_async(mp32wav, (i, d))
|
||||
pool.close()
|
||||
pool.join()
|
||||
for d in data:
|
||||
sound_path, label = d
|
||||
sound_path = sound_path.replace('.mp3', '.wav')
|
||||
if os.path.exists(sound_path):
|
||||
data_all.append([sound_path, label])
|
||||
else:
|
||||
for d in data:
|
||||
sound_path, label = d
|
||||
if os.path.exists(sound_path):
|
||||
data_all.append(d)
|
||||
f_train = open(os.path.join(list_path, 'train_list.txt'), 'w')
|
||||
f_test = open(os.path.join(list_path, 'test_list.txt'), 'w')
|
||||
for i, d in enumerate(data_all):
|
||||
sound_path, label = d
|
||||
if i % 200 == 0:
|
||||
f_test.write(f'{sound_path}\t{label}\n')
|
||||
else:
|
||||
f_train.write(f'{sound_path}\t{label}\n')
|
||||
f_test.close()
|
||||
f_train.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(infodata_path='dataset/zhvoice/text/infodata.json',
|
||||
list_path='dataset',
|
||||
zhvoice_path='dataset/zhvoice',
|
||||
to_wav=False,
|
||||
num_workers=cpu_count())
|
||||
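The script above writes tab-separated `path<TAB>label` lines into `dataset/train_list.txt` and `dataset/test_list.txt`. A small sketch of reading such a list back (illustrative only):

```python
# Each line of the generated lists is "<audio path>\t<speaker label>".
def load_data_list(list_path):
    with open(list_path, "r", encoding="utf-8") as f:
        return [(path, int(label)) for path, label in
                (line.rstrip("\n").split("\t") for line in f if line.strip())]

samples = load_data_list("dataset/train_list.txt")
print(len(samples), samples[0])
```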
25
eval.py
Normal file
@ -0,0 +1,25 @@
|
||||
import argparse
|
||||
import functools
|
||||
import time
|
||||
|
||||
from mvector.trainer import MVectorTrainer
|
||||
from mvector.utils.utils import add_arguments, print_arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
add_arg('configs', str, 'configs/ecapa_tdnn.yml', "配置文件")
|
||||
add_arg("use_gpu", bool, True, "是否使用GPU评估模型")
|
||||
add_arg('save_image_path', str, 'output/images/', "保存结果图的路径")
|
||||
add_arg('resume_model', str, 'models/ecapa_tdnn_MFCC/best_model/', "模型的路径")
|
||||
args = parser.parse_args()
|
||||
print_arguments(args=args)
|
||||
|
||||
# 获取训练器
|
||||
trainer = MVectorTrainer(configs=args.configs, use_gpu=args.use_gpu)
|
||||
|
||||
# 开始评估
|
||||
start = time.time()
|
||||
tpr, fpr, eer, threshold = trainer.evaluate(resume_model=args.resume_model, save_image_path=args.save_image_path)
|
||||
end = time.time()
|
||||
print('评估消耗时间:{}s,threshold:{:.2f},tpr:{:.5f}, fpr: {:.5f}, eer: {:.5f}'
|
||||
.format(int(end - start), threshold, tpr, fpr, eer))
|
||||
52
infer_contrast.py
Normal file
@ -0,0 +1,52 @@
|
||||
import argparse
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
from mvector.predict import MVectorPredictor
|
||||
from mvector.utils.utils import add_arguments, print_arguments
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
add_arg('configs', str, 'configs/ecapa_tdnn.yml', '配置文件')
|
||||
add_arg('use_gpu', bool, True, '是否使用GPU预测')
|
||||
add_arg('audio_path1', str, 'dataset/source/王翔/wx1_5.wav', '预测第一个音频')
|
||||
add_arg('audio_path2', str, 'dataset/source/刘云杰/lyj_no_5.wav', '预测第二个音频')
|
||||
add_arg('threshold', float, 0.7, '判断是否为同一个人的阈值')
|
||||
add_arg('model_path', str, 'models/test_model', '导出的预测模型文件路径')
|
||||
args = parser.parse_args()
|
||||
# print_arguments(args=args)
|
||||
|
||||
|
||||
|
||||
# 获取识别器
|
||||
predictor = MVectorPredictor(configs=args.configs,
|
||||
model_path=args.model_path,
|
||||
use_gpu=args.use_gpu)
|
||||
|
||||
def load_audio_paths(file_path):
|
||||
with open(file_path, 'r',encoding='utf-8') as file:
|
||||
return [line.strip() for line in file if line.strip()]
|
||||
|
||||
|
||||
def compare_audio_files(audio_paths, threshold):
|
||||
# itertools.combinations 生成所有可能的两两组合,无重复
|
||||
for audio1, audio2 in itertools.combinations(audio_paths, 2):
|
||||
dist = predictor.contrast(audio1, audio2)
|
||||
if dist > threshold:
|
||||
print(f"{audio1} 和 {audio2} 为同一个人,相似度为:{dist}")
|
||||
# else:
|
||||
# print(f"{audio1} 和 {audio2} 不是同一个人,相似度为:{dist}")
|
||||
|
||||
file_path = 'dataset/ces.txt' # 假设音频路径存储在此文件中
|
||||
|
||||
# 执行比对
|
||||
audio_paths = load_audio_paths(file_path)
|
||||
compare_audio_files(audio_paths, args.threshold)
|
||||
# # AudioSegment.silent_semoval(args.audio_path1, args.audio_path1)
|
||||
# # AudioSegment.silent_semoval(args.audio_path2, args.audio_path2)
|
||||
# dist = predictor.contrast(args.audio_path1, args.audio_path2)
|
||||
# if dist > args.threshold:
|
||||
# print(f"{args.audio_path1} 和 {args.audio_path2} 为同一个人,相似度为:{dist}")
|
||||
# else:
|
||||
# print(f"{args.audio_path1} 和 {args.audio_path2} 不是同一个人,相似度为:{dist}")
|
||||
49
infer_recognition.py
Normal file
@ -0,0 +1,49 @@
|
||||
import argparse
|
||||
import functools
|
||||
|
||||
from mvector.predict import MVectorPredictor
|
||||
from mvector.utils.record import RecordAudio
|
||||
from mvector.utils.utils import add_arguments, print_arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
add_arg('configs', str, 'configs/ecapa_tdnn.yml', '配置文件')
|
||||
add_arg('use_gpu', bool, True, '是否使用GPU预测')
|
||||
add_arg('audio_db_path', str, 'audio_db/', '音频库的路径')
|
||||
add_arg('record_seconds', int, 3, '录音长度')
|
||||
add_arg('threshold', float, 0.6, '判断是否为同一个人的阈值')
|
||||
add_arg('model_path', str, 'models/ecapa_tdnn_MelSpectrogram/best_model/', '导出的预测模型文件路径')
|
||||
args = parser.parse_args()
|
||||
print_arguments(args=args)
|
||||
|
||||
# 获取识别器
|
||||
predictor = MVectorPredictor(configs=args.configs,
|
||||
threshold=args.threshold,
|
||||
audio_db_path=args.audio_db_path,
|
||||
model_path=args.model_path,
|
||||
use_gpu=args.use_gpu)
|
||||
|
||||
record_audio = RecordAudio()
|
||||
|
||||
while True:
|
||||
select_fun = int(input("请选择功能,0为注册音频到声纹库,1为执行声纹识别,2为删除用户:"))
|
||||
if select_fun == 0:
|
||||
input(f"按下回车键开始录音,录音{args.record_seconds}秒中:")
|
||||
audio_data = record_audio.record(record_seconds=args.record_seconds)
|
||||
name = input("请输入该音频用户的名称:")
|
||||
if name == '': continue
|
||||
predictor.register(user_name=name, audio_data=audio_data, sample_rate=record_audio.sample_rate)
|
||||
elif select_fun == 1:
|
||||
input(f"按下回车键开始录音,录音{args.record_seconds}秒中:")
|
||||
audio_data = record_audio.record(record_seconds=args.record_seconds)
|
||||
name = predictor.recognition(audio_data, sample_rate=record_audio.sample_rate)
|
||||
if name:
|
||||
print(f"识别说话的为:{name}")
|
||||
else:
|
||||
print(f"没有识别到说话人,可能是没注册。")
|
||||
elif select_fun == 2:
|
||||
name = input("请输入该音频用户的名称:")
|
||||
if name == '': continue
|
||||
predictor.remove_user(user_name=name)
|
||||
else:
|
||||
print('请正确选择功能')
|
||||
40
main.py
Normal file
@ -0,0 +1,40 @@
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mvector.predict import MVectorPredictor
|
||||
import os
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 允许跨域
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
predictor = MVectorPredictor(configs='configs/ecapa_tdnn.yml')
|
||||
|
||||
@app.post("/recognize")
|
||||
async def recognize(file: UploadFile = File(...)):
|
||||
try:
|
||||
audio_bytes = await file.read()
|
||||
result = predictor.recognition(audio_bytes)
|
||||
return {"status": 200, "data": result}
|
||||
except Exception as e:
|
||||
return {"status": 500, "error": str(e)}
|
||||
|
||||
@app.post("/compare")
|
||||
async def compare(file1: UploadFile = File(...), file2: UploadFile = File(...)):
|
||||
try:
|
||||
score = predictor.contrast(
|
||||
await file1.read(),
|
||||
await file2.read()
|
||||
)
|
||||
return {"similarity": float(score)}
|
||||
except Exception as e:
|
||||
return {"status": 500, "error": str(e)}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=9001)
|
||||
BIN
models/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
models/pytorch_model.bin
Normal file
Binary file not shown.
3
mvector/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
__version__ = "0.3.9"
|
||||
# 项目支持的模型
|
||||
SUPPORT_MODEL = ['ecapa_tdnn', 'EcapaTdnn', 'Res2Net', 'ResNetSE', 'TDNN']
|
||||
BIN
mvector/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
mvector/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
mvector/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
mvector/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/__pycache__/predict.cpython-37.pyc
Normal file
BIN
mvector/__pycache__/predict.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/__pycache__/trainer.cpython-311.pyc
Normal file
BIN
mvector/__pycache__/trainer.cpython-311.pyc
Normal file
Binary file not shown.
BIN
mvector/__pycache__/trainer.cpython-37.pyc
Normal file
BIN
mvector/__pycache__/trainer.cpython-37.pyc
Normal file
Binary file not shown.
0
mvector/data_utils/__init__.py
Normal file
BIN
mvector/data_utils/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
mvector/data_utils/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/data_utils/__pycache__/audio.cpython-37.pyc
Normal file
BIN
mvector/data_utils/__pycache__/audio.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/data_utils/__pycache__/collate_fn.cpython-37.pyc
Normal file
BIN
mvector/data_utils/__pycache__/collate_fn.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/data_utils/__pycache__/featurizer.cpython-37.pyc
Normal file
BIN
mvector/data_utils/__pycache__/featurizer.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/data_utils/__pycache__/reader.cpython-37.pyc
Normal file
BIN
mvector/data_utils/__pycache__/reader.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/data_utils/__pycache__/utils.cpython-37.pyc
Normal file
BIN
mvector/data_utils/__pycache__/utils.cpython-37.pyc
Normal file
Binary file not shown.
565
mvector/data_utils/audio.py
Normal file
@ -0,0 +1,565 @@
|
||||
import copy
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import resampy
|
||||
import soundfile
|
||||
|
||||
from mvector.data_utils.utils import buf_to_float, vad, decode_audio
|
||||
|
||||
|
||||
class AudioSegment(object):
|
||||
"""Monaural audio segment abstraction.
|
||||
|
||||
:param samples: Audio samples [num_samples x num_channels].
|
||||
:type samples: ndarray.float32
|
||||
:param sample_rate: Audio sample rate.
|
||||
:type sample_rate: int
|
||||
:raises TypeError: If the sample data type is not float or int.
|
||||
"""
|
||||
|
||||
def __init__(self, samples, sample_rate):
|
||||
"""Create audio segment from samples.
|
||||
|
||||
Samples are converted to float32 internally, with ints scaled to [-1, 1].
|
||||
"""
|
||||
self._samples = self._convert_samples_to_float32(samples)
|
||||
self._sample_rate = sample_rate
|
||||
if self._samples.ndim >= 2:
|
||||
self._samples = np.mean(self._samples, 1)
|
||||
|
||||
def __eq__(self, other):
|
||||
"""返回两个对象是否相等"""
|
||||
if type(other) is not type(self):
|
||||
return False
|
||||
if self._sample_rate != other._sample_rate:
|
||||
return False
|
||||
if self._samples.shape != other._samples.shape:
|
||||
return False
|
||||
if np.any(self.samples != other._samples):
|
||||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other):
|
||||
"""返回两个对象是否不相等"""
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __str__(self):
|
||||
"""返回该音频的信息"""
|
||||
return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
|
||||
"rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate, self.duration, self.rms_db))
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file):
|
||||
"""从音频文件创建音频段
|
||||
|
||||
:param file: 文件路径,或者文件对象
|
||||
:type file: str, BufferedReader
|
||||
:return: 音频片段实例
|
||||
:rtype: AudioSegment
|
||||
"""
|
||||
assert os.path.exists(file), f'文件不存在,请检查路径:{file}'
|
||||
try:
|
||||
samples, sample_rate = soundfile.read(file, dtype='float32')
|
||||
except:
|
||||
# 支持更多格式数据
|
||||
sample_rate = 16000
|
||||
samples = decode_audio(file=file, sample_rate=sample_rate)
|
||||
return cls(samples, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def slice_from_file(cls, file, start=None, end=None):
|
||||
"""只加载一小段音频,而不需要将整个文件加载到内存中,这是非常浪费的。
|
||||
|
||||
:param file: 输入音频文件路径或文件对象
|
||||
:type file: str|file
|
||||
:param start: 开始时间,单位为秒。如果start是负的,则它从末尾开始计算。如果没有提供,这个函数将从最开始读取。
|
||||
:type start: float
|
||||
:param end: 结束时间,单位为秒。如果end是负的,则它从末尾开始计算。如果没有提供,默认的行为是读取到文件的末尾。
|
||||
:type end: float
|
||||
:return: AudioSegment输入音频文件的指定片的实例。
|
||||
:rtype: AudioSegment
|
||||
:raise ValueError: 如开始或结束的设定不正确,例如时间不允许。
|
||||
"""
|
||||
sndfile = soundfile.SoundFile(file)
|
||||
sample_rate = sndfile.samplerate
|
||||
duration = round(float(len(sndfile)) / sample_rate, 3)
|
||||
start = 0. if start is None else round(start, 3)
|
||||
end = duration if end is None else round(end, 3)
|
||||
# 从末尾开始计
|
||||
if start < 0.0: start += duration
|
||||
if end < 0.0: end += duration
|
||||
# 保证数据不越界
|
||||
if start < 0.0: start = 0.0
|
||||
if end > duration: end = duration
|
||||
if end < 0.0:
|
||||
raise ValueError("切片结束位置(%f s)越界" % end)
|
||||
if start > end:
|
||||
raise ValueError("切片开始位置(%f s)晚于切片结束位置(%f s)" % (start, end))
|
||||
start_frame = int(start * sample_rate)
|
||||
end_frame = int(end * sample_rate)
|
||||
sndfile.seek(start_frame)
|
||||
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
|
||||
return cls(data, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
"""从包含音频样本的字节创建音频段
|
||||
|
||||
:param data: 包含音频样本的字节
|
||||
:type data: bytes
|
||||
:return: 音频部分实例
|
||||
:rtype: AudioSegment
|
||||
"""
|
||||
samples, sample_rate = soundfile.read(io.BytesIO(data), dtype='float32')
|
||||
return cls(samples, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def from_pcm_bytes(cls, data, channels=1, samp_width=2, sample_rate=16000):
|
||||
"""从包含无格式PCM音频的字节创建音频
|
||||
|
||||
:param data: 包含音频样本的字节
|
||||
:type data: bytes
|
||||
:param channels: 音频的通道数
|
||||
:type channels: int
|
||||
:param samp_width: 音频采样的宽度,如np.int16为2
|
||||
:type samp_width: int
|
||||
:param sample_rate: 音频样本采样率
|
||||
:type sample_rate: int
|
||||
:return: 音频部分实例
|
||||
:rtype: AudioSegment
|
||||
"""
|
||||
samples = buf_to_float(data, n_bytes=samp_width)
|
||||
if channels > 1:
|
||||
samples = samples.reshape(-1, channels)
|
||||
return cls(samples, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def from_ndarray(cls, data, sample_rate=16000):
|
||||
"""从numpy.ndarray创建音频段
|
||||
|
||||
:param data: numpy.ndarray类型的音频数据
|
||||
:type data: ndarray
|
||||
:param sample_rate: 音频样本采样率
|
||||
:type sample_rate: int
|
||||
:return: 音频部分实例
|
||||
:rtype: AudioSegment
|
||||
"""
|
||||
return cls(data, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, *segments):
|
||||
"""将任意数量的音频片段连接在一起
|
||||
|
||||
:param *segments: 输入音频片段被连接
|
||||
:type *segments: tuple of AudioSegment
|
||||
:return: Audio segment instance as concatenating results.
|
||||
:rtype: AudioSegment
|
||||
:raises ValueError: If the number of segments is zero, or if the
|
||||
sample_rate of any segments does not match.
|
||||
:raises TypeError: If any segment is not AudioSegment instance.
|
||||
"""
|
||||
# Perform basic sanity-checks.
|
||||
if len(segments) == 0:
|
||||
raise ValueError("没有音频片段被给予连接")
|
||||
sample_rate = segments[0]._sample_rate
|
||||
for seg in segments:
|
||||
if sample_rate != seg._sample_rate:
|
||||
raise ValueError("不能用不同的采样率连接片段")
|
||||
if type(seg) is not cls:
|
||||
raise TypeError("只有相同类型的音频片段可以连接")
|
||||
samples = np.concatenate([seg.samples for seg in segments])
|
||||
return cls(samples, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def make_silence(cls, duration, sample_rate):
|
||||
"""创建给定持续时间和采样率的静音音频段
|
||||
|
||||
:param duration: 静音的时间,以秒为单位
|
||||
:type duration: float
|
||||
:param sample_rate: 音频采样率
|
||||
:type sample_rate: float
|
||||
:return: 给定持续时间的静音AudioSegment实例
|
||||
:rtype: AudioSegment
|
||||
"""
|
||||
samples = np.zeros(int(duration * sample_rate))
|
||||
return cls(samples, sample_rate)
|
||||
|
||||
def to_wav_file(self, filepath, dtype='float32'):
|
||||
"""保存音频段到磁盘为wav文件
|
||||
|
||||
:param filepath: WAV文件路径或文件对象,以保存音频段
|
||||
:type filepath: str|file
|
||||
:param dtype: Subtype for audio file. Options: 'int16', 'int32',
|
||||
'float32', 'float64'. Default is 'float32'.
|
||||
:type dtype: str
|
||||
:raises TypeError: If dtype is not supported.
|
||||
"""
|
||||
samples = self._convert_samples_from_float32(self._samples, dtype)
|
||||
subtype_map = {
|
||||
'int16': 'PCM_16',
|
||||
'int32': 'PCM_32',
|
||||
'float32': 'FLOAT',
|
||||
'float64': 'DOUBLE'
|
||||
}
|
||||
soundfile.write(
|
||||
filepath,
|
||||
samples,
|
||||
self._sample_rate,
|
||||
format='WAV',
|
||||
subtype=subtype_map[dtype])
|
||||
|
||||
def superimpose(self, other):
|
||||
"""将另一个段的样本添加到这个段的样本中(以样本方式添加,而不是段连接)。
|
||||
|
||||
:param other: 包含样品的片段被添加进去
|
||||
:type other: AudioSegments
|
||||
:raise TypeError: 如果两个片段的类型不匹配
|
||||
:raise ValueError: 不能添加不同类型的段
|
||||
"""
|
||||
if not isinstance(other, type(self)):
|
||||
raise TypeError("不能添加不同类型的段: %s 和 %s" % (type(self), type(other)))
|
||||
if self._sample_rate != other._sample_rate:
|
||||
raise ValueError("采样率必须匹配才能添加片段")
|
||||
if len(self._samples) != len(other._samples):
|
||||
raise ValueError("段长度必须匹配才能添加段")
|
||||
self._samples += other._samples
|
||||
|
||||
def to_bytes(self, dtype='float32'):
|
||||
"""创建包含音频内容的字节字符串
|
||||
|
||||
:param dtype: Data type for export samples. Options: 'int16', 'int32',
|
||||
'float32', 'float64'. Default is 'float32'.
|
||||
:type dtype: str
|
||||
:return: Byte string containing audio content.
|
||||
:rtype: str
|
||||
"""
|
||||
samples = self._convert_samples_from_float32(self._samples, dtype)
|
||||
return samples.tostring()
|
||||
|
||||
def to(self, dtype='int16'):
|
||||
"""类型转换
|
||||
|
||||
:param dtype: Data type for export samples. Options: 'int16', 'int32',
|
||||
'float32', 'float64'. Default is 'float32'.
|
||||
:type dtype: str
|
||||
:return: np.ndarray containing `dtype` audio content.
|
||||
:rtype: str
|
||||
"""
|
||||
samples = self._convert_samples_from_float32(self._samples, dtype)
|
||||
return samples
|
||||
|
||||
def gain_db(self, gain):
|
||||
"""对音频施加分贝增益。
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param gain: Gain in decibels to apply to samples.
|
||||
:type gain: float|1darray
|
||||
"""
|
||||
self._samples *= 10.**(gain / 20.)
|
||||
|
||||
def change_speed(self, speed_rate):
|
||||
"""通过线性插值改变音频速度
|
||||
|
||||
:param speed_rate: Rate of speed change:
|
||||
speed_rate > 1.0, speed up the audio;
|
||||
speed_rate = 1.0, unchanged;
|
||||
speed_rate < 1.0, slow down the audio;
|
||||
speed_rate <= 0.0, not allowed, raise ValueError.
|
||||
:type speed_rate: float
|
||||
:raises ValueError: If speed_rate <= 0.0.
|
||||
"""
|
||||
if speed_rate == 1.0:
|
||||
return
|
||||
if speed_rate <= 0:
|
||||
raise ValueError("速度速率应大于零")
|
||||
old_length = self._samples.shape[0]
|
||||
new_length = int(old_length / speed_rate)
|
||||
old_indices = np.arange(old_length)
|
||||
new_indices = np.linspace(start=0, stop=old_length, num=new_length)
|
||||
self._samples = np.interp(new_indices, old_indices, self._samples).astype(np.float32)
|
||||
|
||||
def normalize(self, target_db=-20, max_gain_db=300.0):
|
||||
"""将音频归一化,使其具有所需的有效值(以分贝为单位)
|
||||
|
||||
:param target_db: Target RMS value in decibels. This value should be
|
||||
less than 0.0 as 0.0 is full-scale audio.
|
||||
:type target_db: float
|
||||
:param max_gain_db: Max amount of gain in dB that can be applied for
|
||||
normalization. This is to prevent nans when
|
||||
attempting to normalize a signal consisting of
|
||||
all zeros.
|
||||
:type max_gain_db: float
|
||||
:raises ValueError: If the required gain to normalize the segment to
|
||||
the target_db value exceeds max_gain_db.
|
||||
"""
|
||||
if -np.inf == self.rms_db: return
|
||||
gain = target_db - self.rms_db
|
||||
if gain > max_gain_db:
|
||||
raise ValueError(
|
||||
"无法将段规范化到 %f dB,因为可能的增益已经超过max_gain_db (%f dB)" % (target_db, max_gain_db))
|
||||
self.gain_db(min(max_gain_db, target_db - self.rms_db))
|
||||
|
||||
def resample(self, target_sample_rate, filter='kaiser_best'):
|
||||
"""按目标采样率重新采样音频
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param target_sample_rate: Target sample rate.
|
||||
:type target_sample_rate: int
|
||||
:param filter: The resampling filter to use one of {'kaiser_best', 'kaiser_fast'}.
|
||||
:type filter: str
|
||||
"""
|
||||
self._samples = resampy.resample(self.samples, self.sample_rate, target_sample_rate, filter=filter)
|
||||
self._sample_rate = target_sample_rate
|
||||
|
||||
def pad_silence(self, duration, sides='both'):
|
||||
"""在这个音频样本上加一段静音
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param duration: Length of silence in seconds to pad.
|
||||
:type duration: float
|
||||
:param sides: Position for padding:
|
||||
'beginning' - adds silence in the beginning;
|
||||
'end' - adds silence in the end;
|
||||
'both' - adds silence in both the beginning and the end.
|
||||
:type sides: str
|
||||
:raises ValueError: If sides is not supported.
|
||||
"""
|
||||
if duration == 0.0:
|
||||
return self
|
||||
cls = type(self)
|
||||
silence = self.make_silence(duration, self._sample_rate)
|
||||
if sides == "beginning":
|
||||
padded = cls.concatenate(silence, self)
|
||||
elif sides == "end":
|
||||
padded = cls.concatenate(self, silence)
|
||||
elif sides == "both":
|
||||
padded = cls.concatenate(silence, self, silence)
|
||||
else:
|
||||
raise ValueError("Unknown value for the sides %s" % sides)
|
||||
self._samples = padded._samples
|
||||
|
||||
def shift(self, shift_ms):
|
||||
"""音频偏移。如果shift_ms为正,则随时间提前移位;如果为负,则随时间延迟移位。填补静音以保持持续时间不变。
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param shift_ms: Shift time in milliseconds. If positive, shift with
|
||||
time advance; if negative; shift with time delay.
|
||||
:type shift_ms: float
|
||||
:raises ValueError: If shift_ms is longer than audio duration.
|
||||
"""
|
||||
if abs(shift_ms) / 1000.0 > self.duration:
|
||||
raise ValueError("shift_ms的绝对值应该小于音频持续时间")
|
||||
shift_samples = int(shift_ms * self._sample_rate / 1000)
|
||||
if shift_samples > 0:
|
||||
# time advance
|
||||
self._samples[:-shift_samples] = self._samples[shift_samples:]
|
||||
self._samples[-shift_samples:] = 0
|
||||
elif shift_samples < 0:
|
||||
# time delay
|
||||
self._samples[-shift_samples:] = self._samples[:shift_samples]
|
||||
self._samples[:-shift_samples] = 0
|
||||
|
||||
def subsegment(self, start_sec=None, end_sec=None):
|
||||
"""在给定的边界之间切割音频片段
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param start_sec: Beginning of subsegment in seconds.
|
||||
:type start_sec: float
|
||||
:param end_sec: End of subsegment in seconds.
|
||||
:type end_sec: float
|
||||
:raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
|
||||
of bounds in time.
|
||||
"""
|
||||
start_sec = 0.0 if start_sec is None else start_sec
|
||||
end_sec = self.duration if end_sec is None else end_sec
|
||||
if start_sec < 0.0:
|
||||
start_sec = self.duration + start_sec
|
||||
if end_sec < 0.0:
|
||||
end_sec = self.duration + end_sec
|
||||
if start_sec < 0.0:
|
||||
raise ValueError("切片起始位置(%f s)越界" % start_sec)
|
||||
if end_sec < 0.0:
|
||||
raise ValueError("切片结束位置(%f s)越界" % end_sec)
|
||||
if start_sec > end_sec:
|
||||
raise ValueError("切片的起始位置(%f s)晚于结束位置(%f s)" % (start_sec, end_sec))
|
||||
if end_sec > self.duration:
|
||||
raise ValueError("切片结束位置(%f s)越界(> %f s)" % (end_sec, self.duration))
|
||||
start_sample = int(round(start_sec * self._sample_rate))
|
||||
end_sample = int(round(end_sec * self._sample_rate))
|
||||
self._samples = self._samples[start_sample:end_sample]
|
||||
|
||||
def random_subsegment(self, subsegment_length):
|
||||
"""随机剪切指定长度的音频片段
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param subsegment_length: Subsegment length in seconds.
|
||||
:type subsegment_length: float
|
||||
:raises ValueError: If the length of subsegment is greater than
|
||||
the original segment.
|
||||
"""
|
||||
if subsegment_length > self.duration:
|
||||
raise ValueError("Length of subsegment must not be greater "
|
||||
"than original segment.")
|
||||
start_time = random.uniform(0.0, self.duration - subsegment_length)
|
||||
self.subsegment(start_time, start_time + subsegment_length)
|
||||
|
||||
def add_noise(self,
|
||||
noise,
|
||||
snr_dB,
|
||||
max_gain_db=300.0):
|
||||
"""以特定的信噪比添加给定的噪声段。如果噪声段比该噪声段长,则从该噪声段中采样匹配长度的随机子段。
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param noise: Noise signal to add.
|
||||
:type noise: AudioSegment
|
||||
:param snr_dB: Signal-to-Noise Ratio, in decibels.
|
||||
:type snr_dB: float
|
||||
:param max_gain_db: Maximum amount of gain to apply to noise signal
|
||||
before adding it in. This is to prevent attempting
|
||||
to apply infinite gain to a zero signal.
|
||||
:type max_gain_db: float
|
||||
:raises ValueError: If the sample rate does not match between the two
|
||||
audio segments, or if the duration of noise segments
|
||||
is shorter than original audio segments.
|
||||
"""
|
||||
if noise.sample_rate != self.sample_rate:
|
||||
raise ValueError("噪声采样率(%d Hz)不等于基信号采样率(%d Hz)" % (noise.sample_rate, self.sample_rate))
|
||||
if noise.duration < self.duration:
|
||||
raise ValueError("噪声信号(%f秒)必须至少与基信号(%f秒)一样长" % (noise.duration, self.duration))
|
||||
noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
|
||||
noise_new = copy.deepcopy(noise)
|
||||
noise_new.random_subsegment(self.duration)
|
||||
noise_new.gain_db(noise_gain_db)
|
||||
self.superimpose(noise_new)
|
||||
|
||||
def vad(self, top_db=20, overlap=0):
|
||||
self._samples = vad(wav=self._samples, top_db=top_db, overlap=overlap)
|
||||
|
||||
def crop(self, duration, mode='eval'):
|
||||
if self.duration > duration:
|
||||
if mode == 'train':
|
||||
self.random_subsegment(duration)
|
||||
else:
|
||||
self.subsegment(end_sec=duration)
|
||||
|
||||
@property
|
||||
def samples(self):
|
||||
"""返回音频样本
|
||||
|
||||
:return: Audio samples.
|
||||
:rtype: ndarray
|
||||
"""
|
||||
return self._samples.copy()
|
||||
|
||||
@property
|
||||
def sample_rate(self):
|
||||
"""返回音频采样率
|
||||
|
||||
:return: Audio sample rate.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._sample_rate
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
"""返回样品数量
|
||||
|
||||
:return: Number of samples.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._samples.shape[0]
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
"""返回音频持续时间
|
||||
|
||||
:return: Audio duration in seconds.
|
||||
:rtype: float
|
||||
"""
|
||||
return self._samples.shape[0] / float(self._sample_rate)
|
||||
|
||||
@property
|
||||
def rms_db(self):
|
||||
"""返回以分贝为单位的音频均方根能量
|
||||
|
||||
:return: Root mean square energy in decibels.
|
||||
:rtype: float
|
||||
"""
|
||||
# square root => multiply by 10 instead of 20 for dBs
|
||||
mean_square = np.mean(self._samples ** 2)
|
||||
return 10 * np.log10(mean_square)
|
||||
|
||||
def _convert_samples_to_float32(self, samples):
|
||||
"""Convert sample type to float32.
|
||||
|
||||
Audio sample type is usually integer or float-point.
|
||||
Integers will be scaled to [-1, 1] in float32.
|
||||
"""
|
||||
float32_samples = samples.astype('float32')
|
||||
if samples.dtype in np.sctypes['int']:
|
||||
bits = np.iinfo(samples.dtype).bits
|
||||
float32_samples *= (1. / 2 ** (bits - 1))
|
||||
elif samples.dtype in np.sctypes['float']:
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Unsupported sample type: %s." % samples.dtype)
|
||||
return float32_samples
|
||||
|
||||
def _convert_samples_from_float32(self, samples, dtype):
|
||||
"""Convert sample type from float32 to dtype.
|
||||
|
||||
Audio sample type is usually integer or float-point. For integer
|
||||
type, float32 will be rescaled from [-1, 1] to the maximum range
|
||||
supported by the integer type.
|
||||
|
||||
This is for writing an audio file.
|
||||
"""
|
||||
dtype = np.dtype(dtype)
|
||||
output_samples = samples.copy()
|
||||
if dtype in np.sctypes['int']:
|
||||
bits = np.iinfo(dtype).bits
|
||||
output_samples *= (2 ** (bits - 1) / 1.)
|
||||
min_val = np.iinfo(dtype).min
|
||||
max_val = np.iinfo(dtype).max
|
||||
output_samples[output_samples > max_val] = max_val
|
||||
output_samples[output_samples < min_val] = min_val
|
||||
elif samples.dtype in np.sctypes['float']:
|
||||
min_val = np.finfo(dtype).min
|
||||
max_val = np.finfo(dtype).max
|
||||
output_samples[output_samples > max_val] = max_val
|
||||
output_samples[output_samples < min_val] = min_val
|
||||
else:
|
||||
raise TypeError("Unsupported sample type: %s." % samples.dtype)
|
||||
return output_samples.astype(dtype)
|
||||
|
||||
def save(self, path, dtype='float32'):
|
||||
"""保存音频段到磁盘为wav文件
|
||||
|
||||
:param path: WAV文件路径或文件对象,以保存音频段
|
||||
:type path: str|file
|
||||
:param dtype: Subtype for audio file. Options: 'int16', 'int32',
|
||||
'float32', 'float64'. Default is 'float32'.
|
||||
:type dtype: str
|
||||
:raises TypeError: If dtype is not supported.
|
||||
"""
|
||||
self.to_wav_file(path, dtype)
|
||||
|
||||
# 静音去除
|
||||
@classmethod
|
||||
def silent_semoval(self, inputpath, outputpath):
|
||||
# 读取音频文件
|
||||
audio = AudioSegment.from_file(inputpath)
|
||||
# 语音活动检测
|
||||
audio.vad()
|
||||
# 保存裁剪后的音频
|
||||
audio.save(outputpath)
|
||||
0
mvector/data_utils/augmentor/__init__.py
Normal file
BIN
mvector/data_utils/augmentor/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
mvector/data_utils/augmentor/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
mvector/data_utils/augmentor/__pycache__/base.cpython-37.pyc
Normal file
BIN
mvector/data_utils/augmentor/__pycache__/base.cpython-37.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
mvector/data_utils/augmentor/__pycache__/resample.cpython-37.pyc
Normal file
BIN
mvector/data_utils/augmentor/__pycache__/resample.cpython-37.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
115
mvector/data_utils/augmentor/augmentation.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""Contains the data augmentation pipeline."""
|
||||
|
||||
import json
|
||||
import random
|
||||
|
||||
from mvector.data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
|
||||
from mvector.data_utils.augmentor.resample import ResampleAugmentor
|
||||
from mvector.data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
|
||||
from mvector.data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
|
||||
from mvector.data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
|
||||
from mvector.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class AugmentationPipeline(object):
"""Build a pre-processing pipeline with various augmentation models. Such a
|
||||
data augmentation pipeline is often leveraged to augment the training
|
||||
samples to make the model invariant to certain types of perturbations in the
|
||||
real world, improving model's generalization ability.
|
||||
|
||||
The pipeline is built according to the augmentation configuration in json
|
||||
string, e.g.
|
||||
|
||||
.. code-block::
|
||||
[
|
||||
{
|
||||
"type": "noise",
|
||||
"params": {
|
||||
"min_snr_dB": 10,
|
||||
"max_snr_dB": 50,
|
||||
"noise_manifest_path": "dataset/manifest.noise"
|
||||
},
|
||||
"prob": 0.5
|
||||
},
|
||||
{
|
||||
"type": "speed",
|
||||
"params": {
|
||||
"min_speed_rate": 0.9,
|
||||
"max_speed_rate": 1.1,
|
||||
"num_rates": 3
|
||||
},
|
||||
"prob": 1.0
|
||||
},
|
||||
{
|
||||
"type": "shift",
|
||||
"params": {
|
||||
"min_shift_ms": -5,
|
||||
"max_shift_ms": 5
|
||||
},
|
||||
"prob": 1.0
|
||||
},
|
||||
{
|
||||
"type": "volume",
|
||||
"params": {
|
||||
"min_gain_dBFS": -15,
|
||||
"max_gain_dBFS": 15
|
||||
},
|
||||
"prob": 1.0
|
||||
}
|
||||
]
|
||||
This augmentation configuration inserts two augmentation models
|
||||
into the pipeline, with one is VolumePerturbAugmentor and the other
|
||||
SpeedPerturbAugmentor. "prob" indicates the probability of the current
|
||||
augmentor to take effect. If "prob" is zero, the augmentor does not take
|
||||
effect.
|
||||
|
||||
:param augmentation_config: Augmentation configuration in json string.
|
||||
:type augmentation_config: str
|
||||
"""
|
||||
|
||||
def __init__(self, augmentation_config):
|
||||
self._augmentors, self._rates = self._parse_pipeline_from(augmentation_config, aug_type='audio')
|
||||
|
||||
def transform_audio(self, audio_segment):
|
||||
"""Run the pre-processing pipeline for data augmentation.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to process.
|
||||
:type audio_segment: AudioSegment|SpeechSegment
|
||||
"""
|
||||
for augmentor, rate in zip(self._augmentors, self._rates):
|
||||
if random.random() < rate:
|
||||
augmentor.transform_audio(audio_segment)
|
||||
|
||||
def _parse_pipeline_from(self, config_json, aug_type):
"""Parse the config json to build an augmentation pipeline."""
|
||||
try:
|
||||
configs = []
|
||||
configs_temp = json.loads(config_json)
|
||||
for config in configs_temp:
|
||||
if config['aug_type'] != aug_type: continue
|
||||
logger.info('数据增强配置:%s' % config)
|
||||
configs.append(config)
|
||||
augmentors = [self._get_augmentor(config["type"], config["params"]) for config in configs]
|
||||
rates = [config["prob"] for config in configs]
|
||||
except Exception as e:
|
||||
raise ValueError("Failed to parse the augmentation config json: %s" % str(e))
|
||||
return augmentors, rates
|
||||
|
||||
def _get_augmentor(self, augmentor_type, params):
|
||||
"""Return an augmentation model by the type name, and pass in params."""
|
||||
if augmentor_type == "volume":
|
||||
return VolumePerturbAugmentor(**params)
|
||||
elif augmentor_type == "shift":
|
||||
return ShiftPerturbAugmentor(**params)
|
||||
elif augmentor_type == "speed":
|
||||
return SpeedPerturbAugmentor(**params)
|
||||
elif augmentor_type == "resample":
|
||||
return ResampleAugmentor(**params)
|
||||
elif augmentor_type == "noise":
|
||||
return NoisePerturbAugmentor(**params)
|
||||
else:
|
||||
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
|
||||
30
mvector/data_utils/augmentor/base.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Contains the abstract base class for augmentation models."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class AugmentorBase(object):
|
||||
"""Abstract base class for augmentation model (augmentor) class.
|
||||
All augmentor classes should inherit from this class, and implement the
|
||||
following abstract methods.
|
||||
"""
|
||||
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_audio(self, audio_segment):
|
||||
"""Adds various effects to the input audio segment. Such effects
|
||||
will augment the training data to make the model invariant to certain
|
||||
types of perturbations in the real world, improving model's
|
||||
generalization ability.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to add effects to.
|
||||
:type audio_segment: AudioSegment|SpeechSegment
|
||||
"""
|
||||
pass
|
||||
57
mvector/data_utils/augmentor/noise_perturb.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Contains the noise perturb augmentation model."""
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mvector.data_utils.augmentor.base import AugmentorBase
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
|
||||
|
||||
class NoisePerturbAugmentor(AugmentorBase):
|
||||
"""用于添加背景噪声的增强模型
|
||||
|
||||
:param min_snr_dB: Minimal signal noise ratio, in decibels.
|
||||
:type min_snr_dB: float
|
||||
:param max_snr_dB: Maximal signal noise ratio, in decibels.
|
||||
:type max_snr_dB: float
|
||||
:param repetition: repetition noise sum
|
||||
:type repetition: int
|
||||
:param noise_dir: noise audio file dir.
|
||||
:type noise_dir: str
|
||||
"""
|
||||
|
||||
def __init__(self, min_snr_dB, max_snr_dB, repetition, noise_dir):
|
||||
self._min_snr_dB = min_snr_dB
|
||||
self._max_snr_dB = max_snr_dB
|
||||
self.repetition = repetition
|
||||
self.noises_path = []
|
||||
if os.path.exists(noise_dir):
|
||||
for file in os.listdir(noise_dir):
|
||||
self.noises_path.append(os.path.join(noise_dir, file))
|
||||
|
||||
def transform_audio(self, audio_segment: AudioSegment):
|
||||
"""Add background noise audio.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to add effects to.
|
||||
:type audio_segment: AudioSegment
|
||||
"""
|
||||
if len(self.noises_path) > 0:
|
||||
for _ in range(random.randint(1, self.repetition)):
|
||||
# 随机选择一个noises_path中的一个
|
||||
noise_path = random.sample(self.noises_path, 1)[0]
|
||||
# 读取噪声音频
|
||||
noise_segment = AudioSegment.from_file(noise_path)
|
||||
# 如果噪声采样率不等于audio_segment的采样率,则重采样
|
||||
if noise_segment.sample_rate != audio_segment.sample_rate:
|
||||
noise_segment.resample(audio_segment.sample_rate)
|
||||
# 随机生成snr_dB的值
|
||||
snr_dB = random.uniform(self._min_snr_dB, self._max_snr_dB)
|
||||
# 如果噪声的长度小于audio_segment的长度,则将噪声的前面的部分填充噪声末尾补长
|
||||
if noise_segment.duration < audio_segment.duration:
|
||||
diff_duration = audio_segment.num_samples - noise_segment.num_samples
|
||||
noise_segment._samples = np.pad(noise_segment.samples, (0, diff_duration), 'wrap')
|
||||
# 将噪声添加到audio_segment中,并将snr_dB调整到最小值和最大值之间
|
||||
audio_segment.add_noise(noise_segment, snr_dB)
|
||||
31
mvector/data_utils/augmentor/resample.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""Contain the resample augmentation model."""
|
||||
import numpy as np
|
||||
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
|
||||
from mvector.data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class ResampleAugmentor(AugmentorBase):
|
||||
"""重采样的增强模型
|
||||
|
||||
See more info here:
|
||||
https://ccrma.stanford.edu/~jos/resample/index.html
|
||||
|
||||
:param new_sample_rate: New sample rate in Hz.
|
||||
:type new_sample_rate: int
|
||||
"""
|
||||
|
||||
def __init__(self, new_sample_rate: list):
|
||||
self._new_sample_rate = new_sample_rate
|
||||
|
||||
def transform_audio(self, audio_segment: AudioSegment):
|
||||
"""Resamples the input audio to a target sample rate.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to add effects to.
|
||||
:type audio_segment: AudioSegment|SpeechSegment
|
||||
"""
|
||||
_new_sample_rate = np.random.choice(self._new_sample_rate)
|
||||
audio_segment.resample(_new_sample_rate)
|
||||
31
mvector/data_utils/augmentor/shift_perturb.py
Normal file
@ -0,0 +1,31 @@
|
"""Contains the shift perturb augmentation model."""
|
||||
import random
|
||||
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
|
||||
from mvector.data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class ShiftPerturbAugmentor(AugmentorBase):
|
||||
"""添加随机位移扰动的增强模型
|
||||
|
||||
:param min_shift_ms: Minimal shift in milliseconds.
|
||||
:type min_shift_ms: float
|
||||
:param max_shift_ms: Maximal shift in milliseconds.
|
||||
:type max_shift_ms: float
|
||||
"""
|
||||
|
||||
def __init__(self, min_shift_ms, max_shift_ms):
|
||||
self._min_shift_ms = min_shift_ms
|
||||
self._max_shift_ms = max_shift_ms
|
||||
|
||||
def transform_audio(self, audio_segment: AudioSegment):
|
||||
"""Shift audio.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to add effects to.
|
||||
:type audio_segment: AudioSegment|SpeechSegment
|
||||
"""
|
||||
shift_ms = random.uniform(self._min_shift_ms, self._max_shift_ms)
|
||||
audio_segment.shift(shift_ms)
|
||||
50
mvector/data_utils/augmentor/speed_perturb.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Contain the speech perturbation augmentation model."""
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
|
||||
from mvector.data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class SpeedPerturbAugmentor(AugmentorBase):
|
||||
"""添加速度扰动的增强模型
|
||||
|
||||
See reference paper here:
|
||||
http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf
|
||||
|
||||
:param min_speed_rate: Lower bound of new speed rate to sample and should
|
||||
not be smaller than 0.9.
|
||||
:type min_speed_rate: float
|
||||
:param max_speed_rate: Upper bound of new speed rate to sample and should
|
||||
not be larger than 1.1.
|
||||
:type max_speed_rate: float
|
||||
"""
|
||||
|
||||
def __init__(self, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=3):
|
||||
if min_speed_rate < 0.9:
|
||||
raise ValueError("Sampling speed below 0.9 can cause unnatural effects")
|
||||
if max_speed_rate > 1.1:
|
||||
raise ValueError("Sampling speed above 1.1 can cause unnatural effects")
|
||||
self._min_speed_rate = min_speed_rate
|
||||
self._max_speed_rate = max_speed_rate
|
||||
self._num_rates = num_rates
|
||||
if num_rates > 0:
|
||||
self._rates = np.linspace(self._min_speed_rate, self._max_speed_rate, self._num_rates, endpoint=True)
|
||||
|
||||
def transform_audio(self, audio_segment: AudioSegment):
|
||||
"""Sample a new speed rate from the given range and
|
||||
changes the speed of the given audio clip.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to add effects to.
|
||||
:type audio_segment: AudioSegment|SpeechSegment
|
||||
"""
|
||||
if self._num_rates < 0:
|
||||
speed_rate = random.uniform(self._min_speed_rate, self._max_speed_rate)
|
||||
else:
|
||||
speed_rate = random.choice(self._rates)
|
||||
|
||||
if speed_rate == 1.0: return
|
||||
audio_segment.change_speed(speed_rate)
|
||||
37
mvector/data_utils/augmentor/volume_perturb.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Contains the volume perturb augmentation model."""
|
||||
import random
|
||||
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
|
||||
from mvector.data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class VolumePerturbAugmentor(AugmentorBase):
|
||||
"""添加随机音量扰动的增强模型
|
||||
|
||||
This is used for multi-loudness training of PCEN. See
|
||||
|
||||
https://arxiv.org/pdf/1607.05666v1.pdf
|
||||
|
||||
for more details.
|
||||
|
||||
:param min_gain_dBFS: Minimal gain in dBFS.
|
||||
:type min_gain_dBFS: float
|
||||
:param max_gain_dBFS: Maximal gain in dBFS.
|
||||
:type max_gain_dBFS: float
|
||||
"""
|
||||
|
||||
def __init__(self, min_gain_dBFS, max_gain_dBFS):
|
||||
self._min_gain_dBFS = min_gain_dBFS
|
||||
self._max_gain_dBFS = max_gain_dBFS
|
||||
|
||||
def transform_audio(self, audio_segment: AudioSegment):
"""Change audio loudness.
|
||||
|
||||
Note that this is an in-place transformation.
|
||||
|
||||
:param audio_segment: Audio segment to add effects to.
|
||||
:type audio_segment: AudioSegment|SpeechSegment
|
||||
"""
|
||||
gain = random.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
|
||||
audio_segment.gain_db(gain)
|
||||
25
mvector/data_utils/collate_fn.py
Normal file
@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
# 对一个batch的数据处理
|
||||
def collate_fn(batch):
|
||||
# 找出音频长度最长的
|
||||
batch = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)
|
||||
max_audio_length = batch[0][0].shape[0]
|
||||
batch_size = len(batch)
|
||||
# 以最大的长度创建0张量
|
||||
inputs = np.zeros((batch_size, max_audio_length), dtype='float32')
|
||||
input_lens_ratio = []
|
||||
labels = []
|
||||
for x in range(batch_size):
|
||||
sample = batch[x]
|
||||
tensor = sample[0]
|
||||
labels.append(sample[1])
|
||||
seq_length = tensor.shape[0]
|
||||
# 将数据插入都0张量中,实现了padding
|
||||
inputs[x, :seq_length] = tensor[:]
|
||||
input_lens_ratio.append(seq_length/max_audio_length)
|
||||
input_lens_ratio = np.array(input_lens_ratio, dtype='float32')
|
||||
labels = np.array(labels, dtype='int64')
|
||||
return torch.tensor(inputs), torch.tensor(labels), torch.tensor(input_lens_ratio)
|
||||
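For orientation, `collate_fn` is meant to be passed to a PyTorch `DataLoader` so that the variable-length waveforms in a batch are zero-padded to the longest one. A sketch of the wiring (it assumes the `CustomDataset` defined later in this commit and an existing `dataset/train_list.txt`):

```python
from torch.utils.data import DataLoader

from mvector.data_utils.collate_fn import collate_fn
from mvector.data_utils.reader import CustomDataset

# Hypothetical wiring of the padding collate_fn with the project's dataset class.
dataset = CustomDataset(data_list_path="dataset/train_list.txt", mode="train")
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

inputs, labels, input_lens_ratio = next(iter(loader))
print(inputs.shape, labels.shape, input_lens_ratio.shape)
```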
103
mvector/data_utils/featurizer.py
Normal file
@ -0,0 +1,103 @@
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC
import torchaudio.compliance.kaldi as Kaldi


class AudioFeaturizer(nn.Module):
    """音频特征器

    :param feature_method: 所使用的预处理方法
    :type feature_method: str
    :param feature_conf: 预处理方法的参数
    :type feature_conf: dict
    """

    def __init__(self, feature_method='MelSpectrogram', feature_conf={}):
        super().__init__()
        self._feature_conf = feature_conf
        self._feature_method = feature_method
        if feature_method == 'MelSpectrogram':
            self.feat_fun = MelSpectrogram(**feature_conf)
        elif feature_method == 'Spectrogram':
            self.feat_fun = Spectrogram(**feature_conf)
        elif feature_method == 'MFCC':
            melkwargs = feature_conf.copy()
            del melkwargs['sample_rate']
            del melkwargs['n_mfcc']
            self.feat_fun = MFCC(sample_rate=self._feature_conf.sample_rate,
                                 n_mfcc=self._feature_conf.n_mfcc,
                                 melkwargs=melkwargs)
        elif feature_method == 'Fbank':
            self.feat_fun = KaldiFbank(**feature_conf)
        else:
            raise Exception(f'预处理方法 {self._feature_method} 不存在!')

    def forward(self, waveforms, input_lens_ratio):
        """从AudioSegment中提取音频特征

        :param waveforms: Audio segment to extract features from.
        :type waveforms: AudioSegment
        :param input_lens_ratio: input length ratio
        :type input_lens_ratio: tensor
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        """
        feature = self.feat_fun(waveforms)
        feature = feature.transpose(2, 1)
        # 归一化
        mean = torch.mean(feature, 1, keepdim=True)
        std = torch.std(feature, 1, keepdim=True)
        feature = (feature - mean) / (std + 1e-5)
        # 对掩码比例进行扩展
        input_lens = (input_lens_ratio * feature.shape[1])
        mask_lens = torch.round(input_lens).long()
        mask_lens = mask_lens.unsqueeze(1)
        input_lens = input_lens.int()
        # 生成掩码张量
        idxs = torch.arange(feature.shape[1], device=feature.device).repeat(feature.shape[0], 1)
        mask = idxs < mask_lens
        mask = mask.unsqueeze(-1)
        # 对特征进行掩码操作
        feature_masked = torch.where(mask, feature, torch.zeros_like(feature))
        return feature_masked, input_lens

    @property
    def feature_dim(self):
        """返回特征大小

        :return: 特征大小
        :rtype: int
        """
        if self._feature_method == 'LogMelSpectrogram':
            return self._feature_conf.n_mels
        elif self._feature_method == 'MelSpectrogram':
            return self._feature_conf.n_mels
        elif self._feature_method == 'Spectrogram':
            return self._feature_conf.n_fft // 2 + 1
        elif self._feature_method == 'MFCC':
            return self._feature_conf.n_mfcc
        elif self._feature_method == 'Fbank':
            return self._feature_conf.num_mel_bins
        else:
            raise Exception('没有{}预处理方法'.format(self._feature_method))


class KaldiFbank(nn.Module):
    def __init__(self, **kwargs):
        super(KaldiFbank, self).__init__()
        self.kwargs = kwargs

    def forward(self, waveforms):
        """
        :param waveforms: [Batch, Length]
        :return: [Batch, Length, Feature]
        """
        log_fbanks = []
        for waveform in waveforms:
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            log_fbank = Kaldi.fbank(waveform, **self.kwargs)
            log_fbank = log_fbank.transpose(0, 1)
            log_fbanks.append(log_fbank)
        log_fbank = torch.stack(log_fbanks)
        return log_fbank
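下面给出 AudioFeaturizer 的一个最小形状检查草例,feature_conf 中的具体参数仅为假设,实际取值应以配置文件里的 feature_conf 为准:

```python
import torch

from mvector.data_utils.featurizer import AudioFeaturizer

# 假设的MelSpectrogram参数,仅作示意
feature_conf = {'sample_rate': 16000, 'n_fft': 400, 'hop_length': 160, 'n_mels': 80}
featurizer = AudioFeaturizer(feature_method='MelSpectrogram', feature_conf=feature_conf)

waveforms = torch.randn(4, 16000 * 3)     # [batch, 采样点数],3秒16kHz音频
input_lens_ratio = torch.ones(4)          # 假设都是完整长度
feature, input_lens = featurizer(waveforms, input_lens_ratio)
print(feature.shape)                      # [4, 帧数, 80]
```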
73
mvector/data_utils/reader.py
Normal file
@ -0,0 +1,73 @@
import numpy as np
from torch.utils.data import Dataset

from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.augmentation import AugmentationPipeline
from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class CustomDataset(Dataset):
    def __init__(self,
                 data_list_path,
                 do_vad=True,
                 max_duration=6,
                 min_duration=0.5,
                 augmentation_config='{}',
                 mode='train',
                 sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        """音频数据加载器

        Args:
            data_list_path: 包含音频路径和标签的数据列表文件的路径
            do_vad: 是否对音频进行语音活动检测(VAD)来裁剪静音部分
            max_duration: 最长的音频长度,大于这个长度会裁剪掉
            min_duration: 过滤最短的音频长度
            augmentation_config: 用于指定音频增强的配置
            mode: 数据集模式。在训练模式下,数据集可能会进行一些数据增强的预处理
            sample_rate: 采样率
            use_dB_normalization: 是否对音频进行音量归一化
            target_dB: 音量归一化的大小
        """
        super(CustomDataset, self).__init__()
        self.do_vad = do_vad
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.mode = mode
        self._target_sample_rate = sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB
        self._augmentation_pipeline = AugmentationPipeline(augmentation_config=augmentation_config)
        # 获取数据列表
        with open(data_list_path, 'r') as f:
            self.lines = f.readlines()

    def __getitem__(self, idx):
        # 分割音频路径和标签
        audio_path, label = self.lines[idx].replace('\n', '').split('\t')
        # 读取音频
        audio_segment = AudioSegment.from_file(audio_path)
        # 裁剪静音
        if self.do_vad:
            audio_segment.vad()
        # 数据太短不利于训练
        if self.mode == 'train':
            if audio_segment.duration < self.min_duration:
                return self.__getitem__(idx + 1 if idx < len(self.lines) - 1 else 0)
        # 重采样
        if audio_segment.sample_rate != self._target_sample_rate:
            audio_segment.resample(self._target_sample_rate)
        # decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # 裁剪需要的数据
        audio_segment.crop(duration=self.max_duration, mode=self.mode)
        # 音频增强
        self._augmentation_pipeline.transform_audio(audio_segment)
        return np.array(audio_segment.samples, dtype=np.float32), np.array(int(label), dtype=np.int64)

    def __len__(self):
        return len(self.lines)
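CustomDataset 按 __getitem__ 的约定读取"音频路径\t标签"格式的数据列表,下面是一个生成该列表文件的示意脚本,dataset/audio 的目录组织(每个说话人一个子目录)仅为假设:

```python
# 按"路径\t标签"的格式生成训练数据列表,目录结构为假设
import os

with open('dataset/train_list.txt', 'w', encoding='utf-8') as f:
    for label, speaker in enumerate(sorted(os.listdir('dataset/audio'))):
        speaker_dir = os.path.join('dataset/audio', speaker)
        for name in sorted(os.listdir(speaker_dir)):
            if name.endswith('.wav'):
                f.write(f'{os.path.join(speaker_dir, name)}\t{label}\n')
```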
179
mvector/data_utils/utils.py
Normal file
@ -0,0 +1,179 @@
import io
import itertools

import av
import librosa
import numpy as np
import torch


def vad(wav, top_db=10, overlap=200):
    """
    去除音频中的静音部分
    参数:
        wav: 音频数据
        top_db: 静音判断的分贝阈值
        overlap: 重叠长度
    返回值:
        wav_output: 去除静音后的音频数据
    """
    intervals = librosa.effects.split(wav, top_db=top_db)
    if len(intervals) == 0:
        return wav
    wav_output = [np.array([])]
    for sliced in intervals:
        seg = wav[sliced[0]:sliced[1]]
        if len(seg) < 2 * overlap:
            wav_output[-1] = np.concatenate((wav_output[-1], seg))
        else:
            wav_output.append(seg)
    wav_output = [x for x in wav_output if len(x) > 0]

    if len(wav_output) == 1:
        wav_output = wav_output[0]
    else:
        wav_output = concatenate(wav_output)
    return wav_output


def concatenate(wave, overlap=200):
    """
    拼接音频
    参数:
        wave: 音频数据
        overlap: 重叠长度
    返回值:
        unfolded: 拼接后的音频数据
    """
    total_len = sum([len(x) for x in wave])
    unfolded = np.zeros(total_len)

    # Equal power crossfade
    window = np.hanning(2 * overlap)
    fade_in = window[:overlap]
    fade_out = window[-overlap:]

    end = total_len
    for i in range(1, len(wave)):
        prev = wave[i - 1]
        curr = wave[i]

        if i == 1:
            end = len(prev)
            unfolded[:end] += prev

        max_idx = 0
        max_corr = 0
        pattern = prev[-overlap:]
        # slide the curr batch to match with the pattern of previous one
        for j in range(overlap):
            match = curr[j:j + overlap]
            corr = np.sum(pattern * match) / ((np.sqrt(np.sum(pattern ** 2)) * np.sqrt(np.sum(match ** 2))) + 1e-8)
            if corr > max_corr:
                max_idx = j
                max_corr = corr

        # Apply the gain to the overlap samples
        start = end - overlap
        unfolded[start:end] *= fade_out
        end = start + (len(curr) - max_idx)
        curr[max_idx:max_idx + overlap] *= fade_in
        unfolded[start:end] += curr[max_idx:]
    return unfolded[:end]


def decode_audio(file, sample_rate: int = 16000):
    """读取音频,主要用于兜底读取,支持各种数据格式

    Args:
      file: Path to the input file or a file-like object.
      sample_rate: Resample the audio to this sample rate.

    Returns:
      A float32 Numpy array.
    """
    resampler = av.audio.resampler.AudioResampler(format="s16", layout="mono", rate=sample_rate)

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(file, metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    return audio.astype(np.float32) / 32768.0


def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)


# 将音频流转换为numpy
def buf_to_float(x, n_bytes=2, dtype=np.float32):
    """Convert an integer buffer to floating point values.
    This is primarily useful when loading integer-valued wav data
    into numpy arrays.

    Parameters
    ----------
    x : np.ndarray [dtype=int]
        The integer-valued data buffer

    n_bytes : int [1, 2, 4]
        The number of bytes per sample in ``x``

    dtype : numeric type
        The target output type (default: 32-bit float)

    Returns
    -------
    x_float : np.ndarray [dtype=float]
        The input data buffer cast to floating point
    """

    # Invert the scale of the data
    scale = 1.0 / float(1 << ((8 * n_bytes) - 1))

    # Construct the format string
    fmt = "<i{:d}".format(n_bytes)

    # Rescale and format the data buffer
    return scale * np.frombuffer(x, fmt).astype(dtype)
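下面是 decode_audio 与 vad 搭配使用的示意,top_db=20 只是示例阈值,音频路径取自本仓库的 audio_db 目录:

```python
# 用 decode_audio 兜底读取音频并做VAD裁剪静音
from mvector.data_utils.utils import decode_audio, vad

samples = decode_audio('audio_db/test.wav', sample_rate=16000)  # float32,取值约在[-1, 1]
samples = vad(samples, top_db=20)                               # 去除静音段
print(len(samples) / 16000, '秒')
```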
0
mvector/metric/__init__.py
Normal file
BIN
mvector/metric/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/metric/__pycache__/metrics.cpython-37.pyc
Normal file
Binary file not shown.
63
mvector/metric/metrics.py
Normal file
@ -0,0 +1,63 @@
import numpy as np

from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class TprAtFpr(object):
    def __init__(self, max_fpr=0.01):
        self.pos_score_list = []
        self.neg_score_list = []
        self.max_fpr = max_fpr

    def add(self, y_labels, y_scores):
        for y_label, y_score in zip(y_labels, y_scores):
            if y_label == 0:
                self.neg_score_list.append(y_score)
            else:
                self.pos_score_list.append(y_score)

    def reset(self):
        self.pos_score_list = []
        self.neg_score_list = []

    def calculate_eer(self, tprs, fprs):
        # 记录所有的eer值
        eer_list = []
        n = len(tprs)
        eer = 1.0
        index = 0
        for i in range(n):
            eer_list.append(fprs[i] + (1 - tprs[i]))
            if fprs[i] + (1 - tprs[i]) < eer:
                eer = fprs[i] + (1 - tprs[i])
                index = i
        return eer, index, eer_list

    def calculate(self):
        tprs, fprs, thresholds = [], [], []
        pos_score_list = np.array(self.pos_score_list)
        neg_score_list = np.array(self.neg_score_list)
        if len(pos_score_list) == 0:
            msg = f"The number of positive samples is 0, please add positive samples."
            logger.warning(msg)
            return tprs, fprs, thresholds, None, None
        if len(neg_score_list) == 0:
            msg = f"The number of negative samples is 0, please add negative samples."
            logger.warning(msg)
            return tprs, fprs, thresholds, None, None
        for i in range(0, 100):
            threshold = i / 100.
            tpr = np.sum(pos_score_list > threshold) / len(pos_score_list)
            fpr = np.sum(neg_score_list > threshold) / len(neg_score_list)
            tprs.append(tpr)
            fprs.append(fpr)
            thresholds.append(threshold)
        eer, index, eer_list = self.calculate_eer(fprs=fprs, tprs=tprs)

        # 根据对应的eer_list输出所有对应的阈值
        for i in range(len(eer_list)):
            print(f"threshold: {thresholds[i]}, eer: {eer_list[i]}")

        return tprs, fprs, thresholds, eer, index
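TprAtFpr 的用法草例如下,打分数据为随手构造的示意值,仅用于说明 add 与 calculate 的输入输出:

```python
# 用 TprAtFpr 统计一组打分的 tpr/fpr/eer
from mvector.metric.metrics import TprAtFpr

metric = TprAtFpr()
y_true = [1, 1, 0, 0, 1, 0]                      # 1表示同一说话人,0表示不同
y_score = [0.92, 0.81, 0.35, 0.48, 0.77, 0.60]   # 余弦相似度得分
metric.add(y_true, y_score)
tprs, fprs, thresholds, eer, index = metric.calculate()
print(f'threshold={thresholds[index]}, tpr={tprs[index]:.3f}, fpr={fprs[index]:.3f}, eer={eer:.3f}')
```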
0
mvector/models/__init__.py
Normal file
BIN
mvector/models/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/ecapa_tdnn.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/fc.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/loss.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/models/__pycache__/pooling.cpython-37.pyc
Normal file
Binary file not shown.
189
mvector/models/ecapa_tdnn.py
Normal file
@ -0,0 +1,189 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mvector.models.pooling import AttentiveStatsPool, TemporalAveragePooling
|
||||
from mvector.models.pooling import SelfAttentivePooling, TemporalStatisticsPooling
|
||||
|
||||
|
||||
class Res2Conv1dReluBn(nn.Module):
|
||||
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False, scale=4):
|
||||
super().__init__()
|
||||
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
||||
self.scale = scale
|
||||
self.width = channels // scale
|
||||
self.nums = scale if scale == 1 else scale - 1
|
||||
|
||||
self.convs = []
|
||||
self.bns = []
|
||||
for i in range(self.nums):
|
||||
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
|
||||
self.bns.append(nn.BatchNorm1d(self.width))
|
||||
self.convs = nn.ModuleList(self.convs)
|
||||
self.bns = nn.ModuleList(self.bns)
|
||||
|
||||
def forward(self, x):
|
||||
out = []
|
||||
spx = torch.split(x, self.width, 1)
|
||||
# 遍历每个分支
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
# 其他分支则将当前子特征与前面所有子特征相加,形成残差连接
|
||||
sp = sp + spx[i]
|
||||
# Order: conv -> relu -> bn
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.bns[i](F.relu(sp))
|
||||
out.append(sp)
|
||||
if self.scale != 1:
|
||||
out.append(spx[self.nums])
|
||||
|
||||
# 将所有子分支的结果在通道维度上合并
|
||||
out = torch.cat(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
class Conv1dReluBn(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
||||
self.bn = nn.BatchNorm1d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn(F.relu(self.conv(x)))
|
||||
|
||||
|
||||
class SE_Connect(nn.Module):
|
||||
def __init__(self, channels, s=2):
|
||||
super().__init__()
|
||||
assert channels % s == 0, "{} % {} != 0".format(channels, s)
|
||||
self.linear1 = nn.Linear(channels, channels // s)
|
||||
self.linear2 = nn.Linear(channels // s, channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = x.mean(dim=2)
|
||||
out = F.relu(self.linear1(out))
|
||||
out = torch.sigmoid(self.linear2(out))
|
||||
out = x * out.unsqueeze(2)
|
||||
return out
|
||||
|
||||
def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
||||
"""
|
||||
初始化函数。
|
||||
|
||||
参数:
|
||||
- input_size: 输入尺寸,默认为80。
|
||||
- channels: 通道数,默认为512。
|
||||
- kernel_size: 卷积核大小, 默认为3。
|
||||
- embd_dim: 嵌入维度,默认为192。
|
||||
- pooling_type: 池化类型,默认为"ASP",可选值包括"ASP"、"SAP"、"TAP"、"TSP"。
|
||||
- dilation : 空洞卷积的空洞率,默认为1。
|
||||
- scale: SE模块的缩放比例,默认为8。
|
||||
|
||||
返回值:
|
||||
- 无。
|
||||
"""
|
||||
return nn.Sequential(
|
||||
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
||||
Res2Conv1dReluBn(channels, kernel_size, stride, padding, dilation, scale=scale),
|
||||
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
||||
SE_Connect(channels)
|
||||
)
|
||||
|
||||
|
||||
class EcapaTdnn(nn.Module):
|
||||
"""
|
||||
初始化函数。
|
||||
|
||||
参数:
|
||||
- input_size: 输入尺寸,默认为80。
|
||||
- channels: 通道数,默认为512。
|
||||
- embd_dim: 嵌入维度,默认为192。
|
||||
- pooling_type: 池化类型,默认为"ASP",可选值包括"ASP"、"SAP"、"TAP"、"TSP"。
|
||||
"""
|
||||
def __init__(self, input_size=80, channels=512, embd_dim=192, pooling_type="ASP"):
|
||||
super().__init__()
|
||||
self.layer1 = Conv1dReluBn(input_size, channels, kernel_size=5, padding=2, dilation=1)
|
||||
self.layer2 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
|
||||
self.layer3 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8)
|
||||
self.layer4 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8)
|
||||
|
||||
cat_channels = channels * 3
|
||||
self.emb_size = embd_dim
|
||||
self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
|
||||
if pooling_type == "ASP":
|
||||
self.pooling = AttentiveStatsPool(cat_channels, 128)
|
||||
self.bn1 = nn.BatchNorm1d(cat_channels * 2)
|
||||
self.linear = nn.Linear(cat_channels * 2, embd_dim)
|
||||
self.bn2 = nn.BatchNorm1d(embd_dim)
|
||||
elif pooling_type == "SAP":
|
||||
self.pooling = SelfAttentivePooling(cat_channels, 128)
|
||||
self.bn1 = nn.BatchNorm1d(cat_channels)
|
||||
self.linear = nn.Linear(cat_channels, embd_dim)
|
||||
self.bn2 = nn.BatchNorm1d(embd_dim)
|
||||
elif pooling_type == "TAP":
|
||||
self.pooling = TemporalAveragePooling()
|
||||
self.bn1 = nn.BatchNorm1d(cat_channels)
|
||||
self.linear = nn.Linear(cat_channels, embd_dim)
|
||||
self.bn2 = nn.BatchNorm1d(embd_dim)
|
||||
elif pooling_type == "TSP":
|
||||
self.pooling = TemporalStatisticsPooling()
|
||||
self.bn1 = nn.BatchNorm1d(cat_channels * 2)
|
||||
self.linear = nn.Linear(cat_channels * 2, embd_dim)
|
||||
self.bn2 = nn.BatchNorm1d(embd_dim)
|
||||
else:
|
||||
raise Exception(f'没有{pooling_type}池化层!')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
计算嵌入向量。
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入数据,形状为 (N, time, freq),其中N为样本数量,time为时间维度,freq为频率维度。
|
||||
|
||||
返回值:
|
||||
torch.Tensor: 输出嵌入向量,形状为 (N, self.emb_size, 1)
|
||||
"""
|
||||
# 将输入数据的频率和时间维度交换
|
||||
x = x.transpose(2, 1)
|
||||
# 通过第一层卷积层
|
||||
out1 = self.layer1(x)
|
||||
# 通过第二层卷积层,并与第一层输出相加
|
||||
out2 = self.layer2(out1) + out1
|
||||
# 通过第三层卷积层,并依次与前两层输出相加
|
||||
out3 = self.layer3(out1 + out2) + out1 + out2
|
||||
# 通过第四层卷积层,并依次与前三层输出相加
|
||||
out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3
|
||||
|
||||
# 将第二、三、四层的输出在特征维度上连接
|
||||
out = torch.cat([out2, out3, out4], dim=1)
|
||||
# 应用ReLU激活函数,并通过卷积层处理
|
||||
out = F.relu(self.conv(out))
|
||||
# 经过批归一化和池化操作
|
||||
out = self.bn1(self.pooling(out))
|
||||
# 经过线性变换和批归一化
|
||||
out = self.bn2(self.linear(out))
|
||||
return out
|
||||
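下面是 EcapaTdnn 的形状检查草例,输入按 forward 文档中的 (N, time, freq) 组织,98帧、80维特征只是示例值:

```python
# 检查 EcapaTdnn 的输入输出形状
import torch

from mvector.models.ecapa_tdnn import EcapaTdnn

model = EcapaTdnn(input_size=80, channels=512, embd_dim=192, pooling_type="ASP")
features = torch.randn(4, 98, 80)    # 4条样本,98帧,80维特征
embeddings = model(features)
print(embeddings.shape)              # torch.Size([4, 192])
```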
90
mvector/models/fc.py
Normal file
@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class SpeakerIdetification(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
backbone,
|
||||
num_class=1,
|
||||
loss_type='AAMLoss',
|
||||
lin_blocks=0,
|
||||
lin_neurons=192,
|
||||
dropout=0.1, ):
|
||||
|
||||
"""
|
||||
初始化说话人识别模型,包括说话人背骨网络和在训练中针对说话人类别数的线性变换。
|
||||
|
||||
参数:
|
||||
backbone (torch.nn.Module): 说话人识别主干网络模型。
|
||||
num_class (int): 训练数据集中说话人的类别数。
|
||||
lin_blocks (int, 可选): 从嵌入向量到最终线性层之间的线性层变换数量。默认为0。
|
||||
lin_neurons (int, 可选): 最终线性层的输出维度。默认为192。
|
||||
dropout (float, 可选): 嵌入向量上的dropout因子。默认为0.1。
|
||||
"""
|
||||
super(SpeakerIdetification, self).__init__()
|
||||
# 初始化背骨网络模型
|
||||
# 背骨网络的输出为目标嵌入向量
|
||||
self.backbone = backbone
|
||||
self.loss_type = loss_type
|
||||
if dropout > 0:
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
# 构建说话人分类器
|
||||
input_size = self.backbone.emb_size
|
||||
self.blocks = list()
|
||||
# 添加线性层变换
|
||||
for i in range(lin_blocks):
|
||||
self.blocks.extend([
|
||||
nn.BatchNorm1d(input_size),
|
||||
nn.Linear(in_features=input_size, out_features=lin_neurons),
|
||||
])
|
||||
input_size = lin_neurons
|
||||
|
||||
# 最终层初始化
|
||||
if self.loss_type == 'AAMLoss':
|
||||
self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
|
||||
nn.init.xavier_normal_(self.weight, gain=1)
|
||||
elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
|
||||
self.weight = Parameter(torch.FloatTensor(input_size, num_class), requires_grad=True)
|
||||
nn.init.xavier_normal_(self.weight, gain=1)
|
||||
elif self.loss_type == 'CELoss':
|
||||
self.output = nn.Linear(input_size, num_class)
|
||||
else:
|
||||
raise Exception(f'没有{self.loss_type}损失函数!')
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
执行说话人识别模型的前向传播,
|
||||
包括说话人嵌入模型和分类器模型网络
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入的音频特征,
|
||||
形状=[批大小, 时间, 维度]
|
||||
|
||||
返回值:
|
||||
torch.Tensor: 返回特征的logits
|
||||
"""
|
||||
# x.shape: (N, L, C)
|
||||
x = self.backbone(x) # (N, emb_size)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
|
||||
for fc in self.blocks:
|
||||
x = fc(x)
|
||||
if self.loss_type == 'AAMLoss':
|
||||
logits = F.linear(F.normalize(x), F.normalize(self.weight, dim=-1))
|
||||
elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
|
||||
x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
|
||||
x_norm = torch.div(x, x_norm)
|
||||
w_norm = torch.norm(self.weight, p=2, dim=0, keepdim=True).clamp(min=1e-12)
|
||||
w_norm = torch.div(self.weight, w_norm)
|
||||
logits = torch.mm(x_norm, w_norm)
|
||||
else:
|
||||
logits = self.output(x)
|
||||
|
||||
return logits
|
||||
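下面示意如何把 EcapaTdnn 作为 backbone 接入 SpeakerIdetification,说话人类别数 100 为假设值:

```python
# 组合backbone与说话人分类头
import torch

from mvector.models.ecapa_tdnn import EcapaTdnn
from mvector.models.fc import SpeakerIdetification

backbone = EcapaTdnn(input_size=80, embd_dim=192)
model = SpeakerIdetification(backbone=backbone, num_class=100, loss_type='AAMLoss')
features = torch.randn(4, 98, 80)
logits = model(features)             # AAMLoss分支输出归一化后的余弦相似度
print(logits.shape)                  # torch.Size([4, 100])
```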
97
mvector/models/loss.py
Normal file
@ -0,0 +1,97 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class AdditiveAngularMargin(nn.Module):
|
||||
def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
|
||||
"""The Implementation of Additive Angular Margin (AAM) proposed
|
||||
in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
|
||||
(https://arxiv.org/abs/1906.07317)
|
||||
|
||||
Args:
|
||||
margin (float, optional): margin factor. Defaults to 0.0.
|
||||
scale (float, optional): scale factor. Defaults to 1.0.
|
||||
easy_margin (bool, optional): easy_margin flag. Defaults to False.
|
||||
"""
|
||||
super(AdditiveAngularMargin, self).__init__()
|
||||
self.margin = margin
|
||||
self.scale = scale
|
||||
self.easy_margin = easy_margin
|
||||
|
||||
self.cos_m = math.cos(self.margin)
|
||||
self.sin_m = math.sin(self.margin)
|
||||
self.th = math.cos(math.pi - self.margin)
|
||||
self.mm = math.sin(math.pi - self.margin) * self.margin
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
cosine = outputs.float()
|
||||
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
if self.easy_margin:
|
||||
phi = torch.where(cosine > 0, phi, cosine)
|
||||
else:
|
||||
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
|
||||
outputs = (targets * phi) + ((1.0 - targets) * cosine)
|
||||
return self.scale * outputs
|
||||
|
||||
|
||||
class AAMLoss(nn.Module):
|
||||
def __init__(self, margin=0.2, scale=30, easy_margin=False):
|
||||
super(AAMLoss, self).__init__()
|
||||
self.loss_fn = AdditiveAngularMargin(margin=margin, scale=scale, easy_margin=easy_margin)
|
||||
self.criterion = torch.nn.KLDivLoss(reduction="sum")
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
targets = F.one_hot(targets, outputs.shape[1]).float()
|
||||
predictions = self.loss_fn(outputs, targets)
|
||||
predictions = F.log_softmax(predictions, dim=1)
|
||||
loss = self.criterion(predictions, targets) / targets.sum()
|
||||
return loss
|
||||
|
||||
|
||||
class AMLoss(nn.Module):
|
||||
def __init__(self, margin=0.2, scale=30):
|
||||
super(AMLoss, self).__init__()
|
||||
self.m = margin
|
||||
self.s = scale
|
||||
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
label_view = targets.view(-1, 1)
|
||||
delt_costh = torch.zeros(outputs.size(), device=targets.device).scatter_(1, label_view, self.m)
|
||||
costh_m = outputs - delt_costh
|
||||
predictions = self.s * costh_m
|
||||
loss = self.criterion(predictions, targets) / targets.shape[0]
|
||||
return loss
|
||||
|
||||
|
||||
class ARMLoss(nn.Module):
|
||||
def __init__(self, margin=0.2, scale=30):
|
||||
super(ARMLoss, self).__init__()
|
||||
self.m = margin
|
||||
self.s = scale
|
||||
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
label_view = targets.view(-1, 1)
|
||||
delt_costh = torch.zeros(outputs.size(), device=targets.device).scatter_(1, label_view, self.m)
|
||||
costh_m = outputs - delt_costh
|
||||
costh_m_s = self.s * costh_m
|
||||
delt_costh_m_s = costh_m_s.gather(1, label_view).repeat(1, costh_m_s.size()[1])
|
||||
costh_m_s_reduct = costh_m_s - delt_costh_m_s
|
||||
predictions = torch.where(costh_m_s_reduct < 0.0, torch.zeros_like(costh_m_s), costh_m_s)
|
||||
loss = self.criterion(predictions, targets) / targets.shape[0]
|
||||
return loss
|
||||
|
||||
|
||||
class CELoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(CELoss, self).__init__()
|
||||
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
loss = self.criterion(outputs, targets) / targets.shape[0]
|
||||
return loss
|
||||
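AAMLoss 期望的输入是 fc.py 中 AAMLoss 分支输出的余弦相似度 logits,下面用随机构造的数值做一个最小草例:

```python
# AAMLoss 的最小用法草例,logits为假设的余弦相似度,取值范围[-1, 1]
import torch

from mvector.models.loss import AAMLoss

criterion = AAMLoss(margin=0.2, scale=30)
logits = torch.empty(4, 100).uniform_(-1, 1)
labels = torch.randint(0, 100, (4,))
loss = criterion(logits, labels)
print(loss.item())
```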
75
mvector/models/pooling.py
Normal file
@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class TemporalAveragePooling(nn.Module):
|
||||
def __init__(self):
|
||||
"""TAP
|
||||
Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
|
||||
Link: https://arxiv.org/pdf/1903.12058.pdf
|
||||
"""
|
||||
super(TemporalAveragePooling, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
"""Computes Temporal Average Pooling Module
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, channels, frames).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, channels)
|
||||
"""
|
||||
x = torch.mean(x, dim=2)
|
||||
return x
|
||||
|
||||
|
||||
class TemporalStatisticsPooling(nn.Module):
|
||||
def __init__(self):
|
||||
"""TSP
|
||||
Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
|
||||
Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
|
||||
"""
|
||||
super(TemporalStatisticsPooling, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
"""Computes Temporal Statistics Pooling Module
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, channels, frames).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, channels*2)
|
||||
"""
|
||||
mean = torch.mean(x, dim=2)
|
||||
var = torch.var(x, dim=2)
|
||||
x = torch.cat((mean, var), dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttentivePooling(nn.Module):
|
||||
def __init__(self, in_dim, bottleneck_dim=128):
|
||||
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
||||
# attention dim = 128
|
||||
super(SelfAttentivePooling, self).__init__()
|
||||
self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
|
||||
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
|
||||
|
||||
def forward(self, x):
|
||||
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
||||
alpha = torch.tanh(self.linear1(x))
|
||||
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||
mean = torch.sum(alpha * x, dim=2)
|
||||
return mean
|
||||
|
||||
|
||||
class AttentiveStatsPool(nn.Module):
|
||||
def __init__(self, in_dim, bottleneck_dim=128):
|
||||
super().__init__()
|
||||
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
||||
self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
|
||||
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
|
||||
|
||||
def forward(self, x):
|
||||
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
||||
alpha = torch.tanh(self.linear1(x))
|
||||
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||
mean = torch.sum(alpha * x, dim=2)
|
||||
residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
|
||||
std = torch.sqrt(residuals.clamp(min=1e-9))
|
||||
return torch.cat([mean, std], dim=1)
|
||||
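下面对比两种池化层的输出维度,输入形状按文档中的 (#batch, channels, frames) 构造,具体数值仅为示例:

```python
# 池化层输出维度对比
import torch

from mvector.models.pooling import AttentiveStatsPool, TemporalAveragePooling

x = torch.randn(4, 1536, 98)
print(TemporalAveragePooling()(x).shape)       # torch.Size([4, 1536])
print(AttentiveStatsPool(1536, 128)(x).shape)  # torch.Size([4, 3072]),均值与标准差拼接
```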
189
mvector/predict.py
Normal file
@ -0,0 +1,189 @@
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
from io import BufferedReader
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from mvector import SUPPORT_MODEL
|
||||
from mvector.data_utils.audio import AudioSegment
|
||||
from mvector.data_utils.featurizer import AudioFeaturizer
|
||||
from mvector.models.ecapa_tdnn import EcapaTdnn
|
||||
from mvector.models.fc import SpeakerIdetification
|
||||
from mvector.utils.logger import setup_logger
|
||||
from mvector.utils.utils import dict_to_object, print_arguments
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class MVectorPredictor:
|
||||
def __init__(self,
|
||||
configs,
|
||||
threshold=0.6,
|
||||
model_path='models/ecapa_tdnn_FBank/best_model/',
|
||||
use_gpu=True):
|
||||
"""
|
||||
声纹识别预测工具
|
||||
:param configs: 配置参数
|
||||
:param threshold: 判断是否为同一个人的阈值
|
||||
:param model_path: 导出的预测模型文件夹路径
|
||||
:param use_gpu: 是否使用GPU预测
|
||||
"""
|
||||
if use_gpu:
|
||||
assert (torch.cuda.is_available()), 'GPU不可用'
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
self.device = torch.device("cpu")
|
||||
# 索引候选数量
|
||||
self.cdd_num = 5
|
||||
self.threshold = threshold
|
||||
# 读取配置文件
|
||||
if isinstance(configs, str):
|
||||
with open(configs, 'r', encoding='utf-8') as f:
|
||||
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
|
||||
# print_arguments(configs=configs)
|
||||
self.configs = dict_to_object(configs)
|
||||
assert 'max_duration' in self.configs.dataset_conf, \
|
||||
'【警告】,您貌似使用了旧的配置文件,如果你同时使用了旧的模型,这是错误的,请重新下载或者重新训练,否则只能回滚代码。'
|
||||
assert self.configs.use_model in SUPPORT_MODEL, f'没有该模型:{self.configs.use_model}'
|
||||
self._audio_featurizer = AudioFeaturizer(feature_conf=self.configs.feature_conf, **self.configs.preprocess_conf)
|
||||
self._audio_featurizer.to(self.device)
|
||||
# 获取模型
|
||||
if self.configs.use_model == 'EcapaTdnn' or self.configs.use_model == 'ecapa_tdnn':
|
||||
backbone = EcapaTdnn(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
|
||||
else:
|
||||
raise Exception(f'{self.configs.use_model} 模型不存在!')
|
||||
model = SpeakerIdetification(backbone=backbone, num_class=self.configs.dataset_conf.num_speakers)
|
||||
model.to(self.device)
|
||||
# 加载模型
|
||||
if os.path.isdir(model_path):
|
||||
model_path = os.path.join(model_path, 'model.pt')
|
||||
assert os.path.exists(model_path), f"{model_path} 模型不存在!"
|
||||
if torch.cuda.is_available() and use_gpu:
|
||||
model_state_dict = torch.load(model_path)
|
||||
else:
|
||||
model_state_dict = torch.load(model_path, map_location='cpu')
|
||||
# 加载模型参数
|
||||
model.load_state_dict(model_state_dict)
|
||||
print(f"成功加载模型参数:{model_path}")
|
||||
# 设置为评估模式
|
||||
model.eval()
|
||||
self.predictor = model.backbone
|
||||
# 声纹库的声纹特征
|
||||
self.audio_feature = None
|
||||
|
||||
def _load_audio(self, audio_data, sample_rate=16000):
|
||||
"""加载音频
|
||||
:param audio_data: 需要识别的数据,支持文件路径,文件对象,字节,numpy。如果是字节的话,必须是完整的字节文件
|
||||
:param sample_rate: 如果传入的是numpy数据,需要指定采样率
|
||||
:return: 预处理后的AudioSegment对象
|
||||
"""
|
||||
# 加载音频文件,并进行预处理
|
||||
if isinstance(audio_data, str):
|
||||
audio_segment = AudioSegment.from_file(audio_data)
|
||||
elif isinstance(audio_data, BufferedReader):
|
||||
audio_segment = AudioSegment.from_file(audio_data)
|
||||
elif isinstance(audio_data, np.ndarray):
|
||||
audio_segment = AudioSegment.from_ndarray(audio_data, sample_rate)
|
||||
elif isinstance(audio_data, bytes):
|
||||
audio_segment = AudioSegment.from_bytes(audio_data)
|
||||
else:
|
||||
raise Exception(f'不支持该数据类型,当前数据类型为:{type(audio_data)}')
|
||||
assert audio_segment.duration >= self.configs.dataset_conf.min_duration, \
|
||||
f'音频太短,最小应该为{self.configs.dataset_conf.min_duration}s,当前音频为{audio_segment.duration}s'
|
||||
# 重采样
|
||||
if audio_segment.sample_rate != self.configs.dataset_conf.sample_rate:
|
||||
audio_segment.resample(self.configs.dataset_conf.sample_rate)
|
||||
# decibel normalization
|
||||
if self.configs.dataset_conf.use_dB_normalization:
|
||||
audio_segment.normalize(target_db=self.configs.dataset_conf.target_dB)
|
||||
return audio_segment
|
||||
|
||||
def predict(self,
|
||||
audio_data,
|
||||
sample_rate=16000):
|
||||
"""预测一个音频的特征
|
||||
|
||||
:param audio_data: 需要识别的数据,支持文件路径,文件对象,字节,numpy。如果是字节的话,必须是完整并带格式的字节文件
|
||||
:param sample_rate: 如果传入的是numpy数据,需要指定采样率
|
||||
:return: 声纹特征向量
|
||||
"""
|
||||
# 加载音频文件,并进行预处理
|
||||
input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
|
||||
input_data = torch.tensor(input_data.samples, dtype=torch.float32, device=self.device).unsqueeze(0)
|
||||
input_len_ratio = torch.tensor([1], dtype=torch.float32, device=self.device)
|
||||
audio_feature, _ = self._audio_featurizer(input_data, input_len_ratio)
|
||||
# 执行预测
|
||||
feature = self.predictor(audio_feature).data.cpu().numpy()[0]
|
||||
return feature
|
||||
|
||||
def predict_batch(self, audios_data, sample_rate=16000):
|
||||
"""预测一批音频的特征
|
||||
|
||||
:param audios_data: 需要识别的数据,支持文件路径,文件对象,字节,numpy。如果是字节的话,必须是完整并带格式的字节文件
|
||||
:param sample_rate: 如果传入的是numpy数据,需要指定采样率
|
||||
:return: 声纹特征向量
|
||||
"""
|
||||
audios_data1 = []
|
||||
for audio_data in audios_data:
|
||||
# 加载音频文件,并进行预处理
|
||||
input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
|
||||
audios_data1.append(input_data.samples)
|
||||
# 找出音频长度最长的
|
||||
batch = sorted(audios_data1, key=lambda a: a.shape[0], reverse=True)
|
||||
max_audio_length = batch[0].shape[0]
|
||||
batch_size = len(batch)
|
||||
# 以最大的长度创建0张量
|
||||
inputs = np.zeros((batch_size, max_audio_length), dtype='float32')
|
||||
input_lens_ratio = []
|
||||
for x in range(batch_size):
|
||||
tensor = audios_data1[x]
|
||||
seq_length = tensor.shape[0]
|
||||
# 将数据插入都0张量中,实现了padding
|
||||
inputs[x, :seq_length] = tensor[:]
|
||||
input_lens_ratio.append(seq_length/max_audio_length)
|
||||
audios_data = torch.tensor(inputs, dtype=torch.float32, device=self.device)
|
||||
input_lens_ratio = torch.tensor(input_lens_ratio, dtype=torch.float32, device=self.device)
|
||||
audio_feature, _ = self._audio_featurizer(audios_data, input_lens_ratio)
|
||||
# 执行预测
|
||||
features = self.predictor(audio_feature).data.cpu().numpy()
|
||||
return features
|
||||
|
||||
# 声纹对比
|
||||
def contrast(self, audio_data1, audio_data2):
|
||||
feature1 = self.predict(audio_data1)
|
||||
feature2 = self.predict(audio_data2)
|
||||
# 余弦相似度
|
||||
dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
return dist
|
||||
|
||||
|
||||
def recognition(self, audio_data, threshold=None, sample_rate=16000):
|
||||
"""声纹识别
|
||||
:param audio_data: 需要识别的数据,支持文件路径,文件对象,字节,numpy。如果是字节的话,必须是完整的字节文件
|
||||
:param threshold: 判断的阈值,如果为None则用创建对象时使用的阈值
|
||||
:param sample_rate: 如果传入的是numpy数据,需要指定采样率
|
||||
:return: 识别的用户名称,如果为None,即没有识别到用户
|
||||
"""
|
||||
if threshold:
|
||||
self.threshold = threshold
|
||||
feature = self.predict(audio_data, sample_rate=sample_rate)
|
||||
name = self.__retrieval(np_feature=[feature])[0]
|
||||
return name
|
||||
|
||||
def compare(self, feature1, feature2):
|
||||
"""声纹对比
|
||||
|
||||
:param feature1: 特征1
|
||||
:param feature2: 特征2
|
||||
:return:
|
||||
"""
|
||||
# 余弦相似度
|
||||
dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
return dist
|
||||
|
||||
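MVectorPredictor 的一个使用草例如下,配置文件与模型路径需按实际训练产物填写,此处路径仅为假设;对比用的音频取自仓库的 audio_db 目录:

```python
# 加载预测器做声纹对比
from mvector.predict import MVectorPredictor

predictor = MVectorPredictor(configs='configs/ecapa_tdnn.yml',
                             threshold=0.6,
                             model_path='models/EcapaTdnn_MelSpectrogram/best_model/',
                             use_gpu=False)
score = predictor.contrast('audio_db/output1.wav', 'audio_db/output2.wav')
print('两段音频是否为同一人:', score > 0.6)
```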
483
mvector/trainer.py
Normal file
@ -0,0 +1,483 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import yaml
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torchinfo import summary
|
||||
from tqdm import tqdm
|
||||
from visualdl import LogWriter
|
||||
|
||||
from mvector import SUPPORT_MODEL, __version__
|
||||
from mvector.data_utils.collate_fn import collate_fn
|
||||
from mvector.data_utils.featurizer import AudioFeaturizer
|
||||
from mvector.data_utils.reader import CustomDataset
|
||||
from mvector.metric.metrics import TprAtFpr
|
||||
from mvector.models.ecapa_tdnn import EcapaTdnn
|
||||
from mvector.models.fc import SpeakerIdetification
|
||||
from mvector.models.loss import AAMLoss, CELoss, AMLoss, ARMLoss
|
||||
from mvector.utils.logger import setup_logger
|
||||
from mvector.utils.utils import dict_to_object, print_arguments
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class MVectorTrainer(object):
|
||||
def __init__(self, configs, use_gpu=True):
|
||||
""" mvector集成工具类
|
||||
|
||||
:param configs: 配置字典
|
||||
:param use_gpu: 是否使用GPU训练模型
|
||||
"""
|
||||
if use_gpu:
|
||||
assert (torch.cuda.is_available()), 'GPU不可用'
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
self.device = torch.device("cpu")
|
||||
self.use_gpu = use_gpu
|
||||
# 读取配置文件
|
||||
if isinstance(configs, str):
|
||||
with open(configs, 'r', encoding='utf-8') as f:
|
||||
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
|
||||
print_arguments(configs=configs)
|
||||
self.configs = dict_to_object(configs)
|
||||
assert self.configs.use_model in SUPPORT_MODEL, f'没有该模型:{self.configs.use_model}'
|
||||
self.model = None
|
||||
self.test_loader = None
|
||||
# 获取特征器
|
||||
self.audio_featurizer = AudioFeaturizer(feature_conf=self.configs.feature_conf, **self.configs.preprocess_conf)
|
||||
self.audio_featurizer.to(self.device)
|
||||
|
||||
if platform.system().lower() == 'windows':
|
||||
self.configs.dataset_conf.num_workers = 0
|
||||
logger.warning('Windows系统不支持多线程读取数据,已自动关闭!')
|
||||
|
||||
# 获取数据
|
||||
def __setup_dataloader(self, augment_conf_path=None, is_train=False):
|
||||
# 获取训练数据
|
||||
if augment_conf_path is not None and os.path.exists(augment_conf_path) and is_train:
|
||||
augmentation_config = io.open(augment_conf_path, mode='r', encoding='utf8').read()
|
||||
else:
|
||||
if augment_conf_path is not None and not os.path.exists(augment_conf_path):
|
||||
logger.info('数据增强配置文件{}不存在'.format(augment_conf_path))
|
||||
augmentation_config = '{}'
|
||||
# 兼容旧的配置文件
|
||||
if 'max_duration' not in self.configs.dataset_conf:
|
||||
self.configs.dataset_conf.max_duration = self.configs.dataset_conf.chunk_duration
|
||||
if is_train:
|
||||
self.train_dataset = CustomDataset(data_list_path=self.configs.dataset_conf.train_list,
|
||||
do_vad=self.configs.dataset_conf.do_vad,
|
||||
max_duration=self.configs.dataset_conf.max_duration,
|
||||
min_duration=self.configs.dataset_conf.min_duration,
|
||||
augmentation_config=augmentation_config,
|
||||
sample_rate=self.configs.dataset_conf.sample_rate,
|
||||
use_dB_normalization=self.configs.dataset_conf.use_dB_normalization,
|
||||
target_dB=self.configs.dataset_conf.target_dB,
|
||||
mode='train')
|
||||
train_sampler = None
|
||||
if torch.cuda.device_count() > 1:
|
||||
# 设置支持多卡训练
|
||||
train_sampler = DistributedSampler(dataset=self.train_dataset)
|
||||
self.train_loader = DataLoader(dataset=self.train_dataset,
|
||||
collate_fn=collate_fn,
|
||||
shuffle=(train_sampler is None),
|
||||
batch_size=self.configs.dataset_conf.batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=self.configs.dataset_conf.num_workers)
|
||||
# 获取测试数据
|
||||
self.test_dataset = CustomDataset(data_list_path=self.configs.dataset_conf.test_list,
|
||||
do_vad=self.configs.dataset_conf.do_vad,
|
||||
max_duration=self.configs.dataset_conf.max_duration,
|
||||
min_duration=self.configs.dataset_conf.min_duration,
|
||||
sample_rate=self.configs.dataset_conf.sample_rate,
|
||||
use_dB_normalization=self.configs.dataset_conf.use_dB_normalization,
|
||||
target_dB=self.configs.dataset_conf.target_dB,
|
||||
mode='eval')
|
||||
self.test_loader = DataLoader(dataset=self.test_dataset,
|
||||
batch_size=self.configs.dataset_conf.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=self.configs.dataset_conf.num_workers)
|
||||
|
||||
def __setup_model(self, input_size, is_train=False):
|
||||
|
||||
use_loss = self.configs.get('use_loss', 'AAMLoss')
|
||||
# 获取模型
|
||||
if self.configs.use_model == 'EcapaTdnn' or self.configs.use_model == 'ecapa_tdnn':
|
||||
backbone = EcapaTdnn(input_size=input_size, **self.configs.model_conf)
|
||||
else:
|
||||
raise Exception(f'{self.configs.use_model} 模型不存在!')
|
||||
|
||||
self.model = SpeakerIdetification(backbone=backbone,
|
||||
num_class=self.configs.dataset_conf.num_speakers,
|
||||
loss_type=use_loss)
|
||||
self.model.to(self.device)
|
||||
# 打印模型信息
|
||||
summary(self.model, (1, 98, self.audio_featurizer.feature_dim))
|
||||
# print(self.model)
|
||||
# 获取损失函数
|
||||
if use_loss == 'AAMLoss':
|
||||
self.loss = AAMLoss()
|
||||
elif use_loss == 'AMLoss':
|
||||
self.loss = AMLoss()
|
||||
elif use_loss == 'ARMLoss':
|
||||
self.loss = ARMLoss()
|
||||
elif use_loss == 'CELoss':
|
||||
self.loss = CELoss()
|
||||
else:
|
||||
raise Exception(f'没有{use_loss}损失函数!')
|
||||
if is_train:
|
||||
# 获取优化方法
|
||||
optimizer = self.configs.optimizer_conf.optimizer
|
||||
if optimizer == 'Adam':
|
||||
self.optimizer = torch.optim.Adam(params=self.model.parameters(),
|
||||
lr=float(self.configs.optimizer_conf.learning_rate),
|
||||
weight_decay=float(self.configs.optimizer_conf.weight_decay))
|
||||
elif optimizer == 'AdamW':
|
||||
self.optimizer = torch.optim.AdamW(params=self.model.parameters(),
|
||||
lr=float(self.configs.optimizer_conf.learning_rate),
|
||||
weight_decay=float(self.configs.optimizer_conf.weight_decay))
|
||||
elif optimizer == 'SGD':
|
||||
self.optimizer = torch.optim.SGD(params=self.model.parameters(),
|
||||
momentum=self.configs.optimizer_conf.momentum,
|
||||
lr=float(self.configs.optimizer_conf.learning_rate),
|
||||
weight_decay=float(self.configs.optimizer_conf.weight_decay))
|
||||
else:
|
||||
raise Exception(f'不支持优化方法:{optimizer}')
|
||||
# 学习率衰减函数
|
||||
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=int(self.configs.train_conf.max_epoch * 1.2))
|
||||
|
||||
def __load_pretrained(self, pretrained_model):
|
||||
# 加载预训练模型
|
||||
if pretrained_model is not None:
|
||||
if os.path.isdir(pretrained_model):
|
||||
pretrained_model = os.path.join(pretrained_model, 'model.pt')
|
||||
assert os.path.exists(pretrained_model), f"{pretrained_model} 模型不存在!"
|
||||
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
|
||||
model_dict = self.model.module.state_dict()
|
||||
else:
|
||||
model_dict = self.model.state_dict()
|
||||
model_state_dict = torch.load(pretrained_model)
|
||||
# 过滤不存在的参数
|
||||
for name, weight in model_dict.items():
|
||||
if name in model_state_dict.keys():
|
||||
if list(weight.shape) != list(model_state_dict[name].shape):
|
||||
logger.warning('{} not used, shape {} unmatched with {} in model.'.
|
||||
format(name, list(model_state_dict[name].shape), list(weight.shape)))
|
||||
model_state_dict.pop(name, None)
|
||||
else:
|
||||
logger.warning('Lack weight: {}'.format(name))
|
||||
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
|
||||
self.model.module.load_state_dict(model_state_dict, strict=False)
|
||||
else:
|
||||
self.model.load_state_dict(model_state_dict, strict=False)
|
||||
logger.info('成功加载预训练模型:{}'.format(pretrained_model))
|
||||
|
||||
def __load_checkpoint(self, save_model_path, resume_model):
|
||||
# 加载恢复模型
|
||||
last_epoch = -1
|
||||
best_eer = 1
|
||||
last_model_dir = os.path.join(save_model_path,
|
||||
f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
|
||||
'last_model')
|
||||
if resume_model is not None or (os.path.exists(os.path.join(last_model_dir, 'model.pt'))
|
||||
and os.path.exists(os.path.join(last_model_dir, 'optimizer.pt'))):
|
||||
# 自动获取最新保存的模型
|
||||
if resume_model is None: resume_model = last_model_dir
|
||||
assert os.path.exists(os.path.join(resume_model, 'model.pt')), "模型参数文件不存在!"
|
||||
assert os.path.exists(os.path.join(resume_model, 'optimizer.pt')), "优化方法参数文件不存在!"
|
||||
state_dict = torch.load(os.path.join(resume_model, 'model.pt'))
|
||||
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
|
||||
self.model.module.load_state_dict(state_dict)
|
||||
else:
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.optimizer.load_state_dict(torch.load(os.path.join(resume_model, 'optimizer.pt')))
|
||||
with open(os.path.join(resume_model, 'model.state'), 'r', encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
last_epoch = json_data['last_epoch'] - 1
|
||||
best_eer = json_data['eer']
|
||||
logger.info('成功恢复模型参数和优化方法参数:{}'.format(resume_model))
|
||||
return last_epoch, best_eer
|
||||
|
||||
# 保存模型
|
||||
def __save_checkpoint(self, save_model_path, epoch_id, best_eer=0., best_model=False):
|
||||
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = self.model.module.state_dict()
|
||||
else:
|
||||
state_dict = self.model.state_dict()
|
||||
if best_model:
|
||||
model_path = os.path.join(save_model_path,
|
||||
f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
|
||||
'best_model')
|
||||
else:
|
||||
model_path = os.path.join(save_model_path,
|
||||
f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
|
||||
'epoch_{}'.format(epoch_id))
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(model_path, 'optimizer.pt'))
|
||||
torch.save(state_dict, os.path.join(model_path, 'model.pt'))
|
||||
with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
|
||||
data = {"last_epoch": epoch_id, "eer": best_eer, "version": __version__}
|
||||
f.write(json.dumps(data))
|
||||
if not best_model:
|
||||
last_model_path = os.path.join(save_model_path,
|
||||
f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
|
||||
'last_model')
|
||||
shutil.rmtree(last_model_path, ignore_errors=True)
|
||||
shutil.copytree(model_path, last_model_path)
|
||||
# 删除旧的模型
|
||||
old_model_path = os.path.join(save_model_path,
|
||||
f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
|
||||
'epoch_{}'.format(epoch_id - 3))
|
||||
if os.path.exists(old_model_path):
|
||||
shutil.rmtree(old_model_path)
|
||||
logger.info('已保存模型:{}'.format(model_path))
|
||||
|
||||
def __train_epoch(self, epoch_id, save_model_path, local_rank, writer, nranks=0):
|
||||
# 训练一个epoch
|
||||
train_times, accuracies, loss_sum = [], [], []
|
||||
start = time.time()
|
||||
sum_batch = len(self.train_loader) * self.configs.train_conf.max_epoch
|
||||
for batch_id, (audio, label, input_lens_ratio) in enumerate(self.train_loader):
|
||||
if nranks > 1:
|
||||
audio = audio.to(local_rank)
|
||||
input_lens_ratio = input_lens_ratio.to(local_rank)
|
||||
label = label.to(local_rank).long()
|
||||
else:
|
||||
audio = audio.to(self.device)
|
||||
input_lens_ratio = input_lens_ratio.to(self.device)
|
||||
label = label.to(self.device).long()
|
||||
|
||||
# 获取音频MFCC特征
|
||||
features, _ = self.audio_featurizer(audio, input_lens_ratio)
|
||||
output = self.model(features)
|
||||
# 计算损失值
|
||||
los = self.loss(output, label)
|
||||
self.optimizer.zero_grad()
|
||||
los.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
# 计算准确率
|
||||
output = torch.nn.functional.softmax(output, dim=-1)
|
||||
output = output.data.cpu().numpy()
|
||||
output = np.argmax(output, axis=1)
|
||||
label = label.data.cpu().numpy()
|
||||
acc = np.mean((output == label).astype(int))
|
||||
accuracies.append(acc)
|
||||
loss_sum.append(los)
|
||||
train_times.append((time.time() - start) * 1000)
|
||||
|
||||
# 多卡训练只使用一个进程打印
|
||||
if batch_id % self.configs.train_conf.log_interval == 0 and local_rank == 0:
|
||||
# 计算每秒训练数据量
|
||||
train_speed = self.configs.dataset_conf.batch_size / (sum(train_times) / len(train_times) / 1000)
|
||||
# 计算剩余时间
|
||||
eta_sec = (sum(train_times) / len(train_times)) * (
|
||||
sum_batch - (epoch_id - 1) * len(self.train_loader) - batch_id)
|
||||
eta_str = str(timedelta(seconds=int(eta_sec / 1000)))
|
||||
logger.info(f'Train epoch: [{epoch_id}/{self.configs.train_conf.max_epoch}], '
|
||||
f'batch: [{batch_id}/{len(self.train_loader)}], '
|
||||
f'loss: {sum(loss_sum) / len(loss_sum):.5f}, '
|
||||
f'accuracy: {sum(accuracies) / len(accuracies):.5f}, '
|
||||
f'learning rate: {self.scheduler.get_last_lr()[0]:>.8f}, '
|
||||
f'speed: {train_speed:.2f} data/sec, eta: {eta_str}')
|
||||
writer.add_scalar('Train/Loss', sum(loss_sum) / len(loss_sum), self.train_step)
|
||||
writer.add_scalar('Train/Accuracy', (sum(accuracies) / len(accuracies)), self.train_step)
|
||||
# 记录学习率
|
||||
writer.add_scalar('Train/lr', self.scheduler.get_last_lr()[0], self.train_step)
|
||||
self.train_step += 1
|
||||
train_times = []
|
||||
# 固定步数也要保存一次模型
|
||||
if batch_id % 10000 == 0 and batch_id != 0 and local_rank == 0:
|
||||
self.__save_checkpoint(save_model_path=save_model_path, epoch_id=epoch_id)
|
||||
start = time.time()
|
||||
self.scheduler.step()
|
||||
|
||||
def train(self,
|
||||
save_model_path='models/',
|
||||
resume_model=None,
|
||||
pretrained_model=None,
|
||||
augment_conf_path='configs/augmentation.json'):
|
||||
"""
|
||||
训练模型
|
||||
:param save_model_path: 模型保存的路径
|
||||
:param resume_model: 恢复训练,当为None则不使用预训练模型
|
||||
:param pretrained_model: 预训练模型的路径,当为None则不使用预训练模型
|
||||
:param augment_conf_path: 数据增强的配置文件,为json格式
|
||||
"""
|
||||
# 获取有多少张显卡训练
|
||||
nranks = torch.cuda.device_count()
|
||||
local_rank = 0
|
||||
writer = None
|
||||
if local_rank == 0:
|
||||
# 日志记录器
|
||||
writer = LogWriter(logdir='log')
|
||||
|
||||
if nranks > 1 and self.use_gpu:
|
||||
# 初始化NCCL环境
|
||||
dist.init_process_group(backend='nccl')
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
# 获取数据
|
||||
self.__setup_dataloader(augment_conf_path=augment_conf_path, is_train=True)
|
||||
# 获取模型
|
||||
self.__setup_model(input_size=self.audio_featurizer.feature_dim, is_train=True)
|
||||
|
||||
# 支持多卡训练
|
||||
if nranks > 1 and self.use_gpu:
|
||||
self.model.to(local_rank)
|
||||
self.audio_featurizer.to(local_rank)
|
||||
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank])
|
||||
logger.info('训练数据:{}'.format(len(self.train_dataset)))
|
||||
|
||||
self.__load_pretrained(pretrained_model=pretrained_model)
|
||||
# 加载恢复模型
|
||||
last_epoch, best_eer = self.__load_checkpoint(save_model_path=save_model_path, resume_model=resume_model)
|
||||
if last_epoch > 0:
|
||||
self.optimizer.step()
|
||||
[self.scheduler.step() for _ in range(last_epoch)]
|
||||
|
||||
test_step, self.train_step = 0, 0
|
||||
last_epoch += 1
|
||||
if local_rank == 0:
|
||||
writer.add_scalar('Train/lr', self.scheduler.get_last_lr()[0], last_epoch)
|
||||
# 开始训练
|
||||
for epoch_id in range(last_epoch, self.configs.train_conf.max_epoch):
|
||||
epoch_id += 1
|
||||
start_epoch = time.time()
|
||||
# 训练一个epoch
|
||||
self.__train_epoch(epoch_id=epoch_id, save_model_path=save_model_path, local_rank=local_rank,
|
||||
writer=writer, nranks=nranks)
|
||||
# 多卡训练只使用一个进程执行评估和保存模型
|
||||
if local_rank == 0:
|
||||
logger.info('=' * 70)
|
||||
tpr, fpr, eer, threshold = self.evaluate(resume_model=None)
|
||||
logger.info('Test epoch: {}, time/epoch: {}, threshold: {:.2f}, tpr: {:.5f}, fpr: {:.5f}, '
|
||||
'eer: {:.5f}'.format(epoch_id, str(timedelta(
|
||||
seconds=(time.time() - start_epoch))), threshold, tpr, fpr, eer))
|
||||
logger.info('=' * 70)
|
||||
writer.add_scalar('Test/threshold', threshold, test_step)
|
||||
writer.add_scalar('Test/tpr', tpr, test_step)
|
||||
writer.add_scalar('Test/fpr', fpr, test_step)
|
||||
writer.add_scalar('Test/eer', eer, test_step)
|
||||
test_step += 1
|
||||
self.model.train()
|
||||
# # 保存最优模型
|
||||
if eer <= best_eer:
|
||||
best_eer = eer
|
||||
self.__save_checkpoint(save_model_path=save_model_path, epoch_id=epoch_id, best_eer=eer,
|
||||
best_model=True)
|
||||
# 保存模型
|
||||
self.__save_checkpoint(save_model_path=save_model_path, epoch_id=epoch_id, best_eer=eer)
|
||||
|
||||
    def evaluate(self, resume_model='models/EcapaTdnn_MFCC/best_model/', save_image_path=None):
        """
        Evaluate the model.
        :param resume_model: the model checkpoint to evaluate
        :param save_image_path: directory in which to save the TPR/FPR result figure
        :return: evaluation results
        """
        if self.test_loader is None:
            self.__setup_dataloader()
        if self.model is None:
            self.__setup_model(input_size=self.audio_featurizer.feature_dim)
        if resume_model is not None:
            if os.path.isdir(resume_model):
                resume_model = os.path.join(resume_model, 'model.pt')
            assert os.path.exists(resume_model), f"{resume_model} does not exist!"
            model_state_dict = torch.load(resume_model)
            self.model.load_state_dict(model_state_dict)
            logger.info(f'Successfully loaded model: {resume_model}')
        self.model.eval()
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            eval_model = self.model.module
        else:
            eval_model = self.model

        features, labels = None, None
        losses = []
        with torch.no_grad():
            for batch_id, (audio, label, input_lens_ratio) in enumerate(tqdm(self.test_loader)):
                audio = audio.to(self.device)
                input_lens_ratio = input_lens_ratio.to(self.device)
                label = label.to(self.device).long()
                audio_features, _ = self.audio_featurizer(audio, input_lens_ratio)
                # logits = eval_model(audio_features)
                # loss = self.loss(logits, label)  # note: this would use the logits, not the extracted embeddings
                # losses.append(loss.item())
                feature = eval_model.backbone(audio_features).data.cpu().numpy()
                label = label.data.cpu().numpy()
                # Accumulate embeddings and labels
                features = np.concatenate((features, feature)) if features is not None else feature
                labels = np.concatenate((labels, label)) if labels is not None else label
        # print('Test loss: {:.5f}'.format(sum(losses) / len(losses)))
        self.model.train()
        metric = TprAtFpr()
        labels = labels.astype(np.int32)
        print('Comparing audio features pairwise...')
        for i in tqdm(range(len(features))):
            feature_1 = features[i]
            feature_1 = np.expand_dims(feature_1, 0).repeat(len(features) - i, axis=0)
            feature_2 = features[i:]
            feature_1 = torch.tensor(feature_1, dtype=torch.float32)
            feature_2 = torch.tensor(feature_2, dtype=torch.float32)
            score = torch.nn.functional.cosine_similarity(feature_1, feature_2, dim=-1).data.cpu().numpy().tolist()
            y_true = np.array(labels[i] == labels[i:]).astype(np.int32).tolist()
            metric.add(y_true, score)
        tprs, fprs, thresholds, eer, index = metric.calculate()
        tpr, fpr, threshold = tprs[index], fprs[index], thresholds[index]
        if save_image_path:
            import matplotlib.pyplot as plt
            plt.plot(thresholds, tprs, color='blue', linestyle='-', label='tpr')
            plt.plot(thresholds, fprs, color='red', linestyle='-', label='fpr')
            plt.plot(threshold, tpr, 'bo-')
            plt.text(threshold, tpr, (threshold, round(tpr, 5)), color='blue')
            plt.plot(threshold, fpr, 'ro-')
            plt.text(threshold, fpr, (threshold, round(fpr, 5)), color='red')
            plt.xlabel('threshold')
            plt.title('tpr and fpr')
            plt.grid(True)  # show grid lines
            # Save the figure
            os.makedirs(save_image_path, exist_ok=True)
            plt.savefig(os.path.join(save_image_path, 'result.png'))
            logger.info(f"Result figure saved to: {os.path.join(save_image_path, 'result.png')}")
        return tpr, fpr, eer, threshold
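For reference, a minimal usage sketch of the evaluation above; the config and checkpoint paths simply follow the defaults shown and should be adapted to your setup:

```python
# Minimal sketch: run evaluation from a script (assumes a trained checkpoint already exists).
from mvector.trainer import MVectorTrainer

trainer = MVectorTrainer(configs='configs/ecapa_tdnn.yml', use_gpu=True)
tpr, fpr, eer, threshold = trainer.evaluate(resume_model='models/EcapaTdnn_MFCC/best_model/',
                                            save_image_path='output/images/')
print(f'EER: {eer:.5f} at threshold {threshold:.5f} (TPR={tpr:.5f}, FPR={fpr:.5f})')
```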
    def export(self, save_model_path='models/', resume_model='models/EcapaTdnn_MelSpectrogram/best_model/'):
        """
        Export the inference model.
        :param save_model_path: directory in which to save the exported model
        :param resume_model: path of the checkpoint to convert
        :return:
        """
        # Build the model
        self.__setup_model(input_size=self.audio_featurizer.feature_dim)
        # Load the trained weights
        if os.path.isdir(resume_model):
            resume_model = os.path.join(resume_model, 'model.pt')
        assert os.path.exists(resume_model), f"{resume_model} does not exist!"
        model_state_dict = torch.load(resume_model)
        self.model.load_state_dict(model_state_dict)
        logger.info('Successfully restored model parameters from: {}'.format(resume_model))
        self.model.eval()
        # Build the static (TorchScript) model
        infer_model = torch.jit.script(self.model.backbone)
        infer_model_path = os.path.join(save_model_path,
                                        f'{self.configs.use_model}_{self.configs.preprocess_conf.feature_method}',
                                        'inference.pt')
        os.makedirs(os.path.dirname(infer_model_path), exist_ok=True)
        torch.jit.save(infer_model, infer_model_path)
        logger.info("Inference model saved to: {}".format(infer_model_path))
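A minimal sketch of exporting the backbone and reloading it with TorchScript; the `inference.pt` path and the dummy feature shape are assumptions for illustration and must match the configured model and feature extractor:

```python
# Minimal sketch (illustrative paths and shapes): export the backbone, reload it, extract an embedding.
import torch

from mvector.trainer import MVectorTrainer

trainer = MVectorTrainer(configs='configs/ecapa_tdnn.yml', use_gpu=False)
trainer.export(save_model_path='models/', resume_model='models/EcapaTdnn_MelSpectrogram/best_model/')

infer_model = torch.jit.load('models/EcapaTdnn_MelSpectrogram/inference.pt')
infer_model.eval()
with torch.no_grad():
    # (batch, time, feature_dim) is an assumed layout; in practice use the real featurizer output
    dummy_features = torch.randn(1, 100, 80)
    embedding = infer_model(dummy_features)
print(embedding.shape)
```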
0
mvector/utils/__init__.py
Normal file
BIN
mvector/utils/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/utils/__pycache__/logger.cpython-37.pyc
Normal file
Binary file not shown.
BIN
mvector/utils/__pycache__/utils.cpython-37.pyc
Normal file
Binary file not shown.
89
mvector/utils/logger.py
Normal file
@ -0,0 +1,89 @@
import datetime
import logging
import os
import sys

import termcolor

__all__ = ['setup_logger']

logger_initialized = []


def setup_logger(name, output=None):
    """
    Initialize a logger and set its verbosity level to INFO.
    Args:
        output (str): a file name or a directory to save logs. If None, no log file is written.
            If it ends with ".txt" or ".log", it is treated as a file name.
            Otherwise, logs are saved to `output/log.txt`.
        name (str): the root module name of this logger

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    logger.setLevel(logging.INFO)
    logger.propagate = False

    formatter = "[%(asctime2)s %(levelname2)s] %(module2)s:%(funcName2)s:%(lineno2)s - %(message2)s"
    color_formatter = ColoredFormatter(formatter, datefmt="%m/%d %H:%M:%S")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(color_formatter)
    logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter())
        logger.addHandler(fh)
    logger_initialized.append(name)
    return logger


COLORS = {
    "WARNING": "yellow",
    "INFO": "white",
    "DEBUG": "blue",
    "CRITICAL": "red",
    "ERROR": "red",
}


class ColoredFormatter(logging.Formatter):
    def __init__(self, fmt, datefmt, use_color=True):
        logging.Formatter.__init__(self, fmt, datefmt=datefmt)
        self.use_color = use_color

    def format(self, record):
        levelname = record.levelname
        if self.use_color and levelname in COLORS:

            def colored(text):
                return termcolor.colored(text, color=COLORS[levelname], attrs=["bold"])

            record.levelname2 = colored("{:<7}".format(record.levelname))
            record.message2 = colored(record.msg)

            asctime2 = datetime.datetime.fromtimestamp(record.created)
            record.asctime2 = termcolor.colored(asctime2, color="green")

            record.module2 = termcolor.colored(record.module, color="cyan")
            record.funcName2 = termcolor.colored(record.funcName, color="cyan")
            record.lineno2 = termcolor.colored(record.lineno, color="cyan")
        return logging.Formatter.format(self, record)
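A minimal usage sketch of this logger; the logger name and output directory are illustrative:

```python
from mvector.utils.logger import setup_logger

# Console logging only; pass a directory (or a *.txt / *.log path) to also write a log file
logger = setup_logger('mvector.example', output='output/logs/')
logger.info('colored INFO message')
logger.warning('colored WARNING message')
```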
31
mvector/utils/record.py
Normal file
@ -0,0 +1,31 @@
import os

import soundcard
import soundfile


class RecordAudio:
    def __init__(self, channels=1, sample_rate=16000):
        # Recording parameters
        self.channels = channels
        self.sample_rate = sample_rate

        # Get the default microphone
        self.default_mic = soundcard.default_microphone()

    def record(self, record_seconds=3, save_path=None):
        """Record audio from the microphone.

        :param record_seconds: recording length in seconds, 3 by default
        :param save_path: path to save the recording as a wav file
        :return: the recorded audio as a numpy array
        """
        print("Recording started...")
        num_frames = int(record_seconds * self.sample_rate)
        data = self.default_mic.record(samplerate=self.sample_rate, numframes=num_frames, channels=self.channels)
        audio_data = data.squeeze()
        print("Recording finished!")
        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            soundfile.write(save_path, data=data, samplerate=self.sample_rate)
        return audio_data
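A minimal usage sketch; it needs a working microphone, and the save path is illustrative:

```python
from mvector.utils.record import RecordAudio

recorder = RecordAudio(channels=1, sample_rate=16000)
audio = recorder.record(record_seconds=3, save_path='audio_db/my_recording.wav')
print(audio.shape)
```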
85
mvector/utils/utils.py
Normal file
@ -0,0 +1,85 @@
import distutils.util

import numpy as np
from tqdm import tqdm

from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


def print_arguments(args=None, configs=None):
    if args:
        logger.info("----------- Extra command-line arguments -----------")
        for arg, value in sorted(vars(args).items()):
            logger.info("%s: %s" % (arg, value))
        logger.info("------------------------------------------------")
    if configs:
        logger.info("----------- Configuration file parameters -----------")
        for arg, value in sorted(configs.items()):
            if isinstance(value, dict):
                logger.info(f"{arg}:")
                for a, v in sorted(value.items()):
                    if isinstance(v, dict):
                        logger.info(f"\t{a}:")
                        for a1, v1 in sorted(v.items()):
                            logger.info("\t\t%s: %s" % (a1, v1))
                    else:
                        logger.info("\t%s: %s" % (a, v))
            else:
                logger.info("%s: %s" % (arg, value))
        logger.info("------------------------------------------------")


def add_arguments(argname, type, default, help, argparser, **kwargs):
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' Default: %(default)s.',
                           **kwargs)


class Dict(dict):
    __setattr__ = dict.__setitem__
    __getattr__ = dict.__getitem__


def dict_to_object(dict_obj):
    if not isinstance(dict_obj, dict):
        return dict_obj
    inst = Dict()
    for k, v in dict_obj.items():
        inst[k] = dict_to_object(v)
    return inst


# Compute the accuracy and the best threshold from cosine similarity scores
def cal_accuracy_threshold(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_accuracy = 0
    best_threshold = 0
    for i in tqdm(range(0, 100)):
        threshold = i * 0.01
        y_test = (y_score >= threshold)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_accuracy:
            best_accuracy = acc
            best_threshold = threshold

    return best_accuracy, best_threshold


# Compute accuracy from cosine similarity scores at a fixed threshold
def cal_accuracy(y_score, y_true, threshold=0.5):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    y_test = (y_score >= threshold)
    accuracy = np.mean((y_test == y_true).astype(int))
    return accuracy


# Compute the cosine similarity between two vectors
def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
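A minimal sketch of the similarity helpers above on toy embeddings:

```python
import numpy as np

from mvector.utils.utils import cal_accuracy, cal_accuracy_threshold, cosin_metric

# Toy data: cosine scores for a matching pair and a non-matching pair
scores = [cosin_metric(np.array([1.0, 0.2]), np.array([0.9, 0.3])),
          cosin_metric(np.array([1.0, 0.0]), np.array([0.0, 1.0]))]
labels = [1, 0]  # 1 = same speaker, 0 = different speakers

best_acc, best_th = cal_accuracy_threshold(scores, labels)
print(best_acc, best_th)
print(cal_accuracy(scores, labels, threshold=best_th))
```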
8
requirements.txt
Normal file
@ -0,0 +1,8 @@
numba>=0.52.0
librosa>=0.9.1
numpy>=1.19.2
tqdm>=4.59.0
visualdl>=2.1.1
resampy==0.2.2
soundfile>=0.12.1
soundcard>=0.4.2
45
setup.py
Normal file
@ -0,0 +1,45 @@
from setuptools import setup, find_packages

import mvector

VERSION = mvector.__version__


def readme():
    with open('README.md', encoding='utf-8') as f:
        content = f.read()
    return content


def parse_requirements():
    with open('./requirements.txt', encoding="utf-8") as f:
        requirements = f.readlines()
    return requirements


if __name__ == "__main__":
    setup(
        name='mvector',
        packages=find_packages(),
        author='yeyupiaoling',
        version=VERSION,
        install_requires=parse_requirements(),
        description='Voice Print Recognition toolkit on Pytorch',
        long_description=readme(),
        long_description_content_type='text/markdown',
        url='https://github.com/yeyupiaoling/VoiceprintRecognition_Pytorch',
        download_url='https://github.com/yeyupiaoling/VoiceprintRecognition_Pytorch.git',
        keywords=['Voice', 'Pytorch'],
        classifiers=[
            'Intended Audience :: Developers',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Natural Language :: Chinese (Simplified)',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.5',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Programming Language :: Python :: 3.9',
            'Topic :: Utilities',
        ],
        license='Apache License 2.0',
        ext_modules=[])
26
train.py
Normal file
@ -0,0 +1,26 @@
import argparse
import functools

from mvector.trainer import MVectorTrainer
from mvector.utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs',           str,  'configs/ecapa_tdnn.yml',             'configuration file')
add_arg("local_rank",        int,  0,                                    'argument required for multi-GPU training')
add_arg("use_gpu",           bool, True,                                 'whether to train on GPU')
add_arg('augment_conf_path', str,  'configs/augmentation.json',          'data augmentation configuration file (JSON)')
add_arg('save_model_path',   str,  'models/',                            'directory in which to save models')
add_arg('resume_model',      str,  None,                                 'checkpoint to resume training from; None trains from scratch')
add_arg('save_image_path',   str,  'output/images/',                     'directory in which to save result figures')
add_arg('pretrained_model',  str,  'models/ecapa_tdnn_MFCC/best_model/', 'pretrained model path; None disables pretrained weights')
args = parser.parse_args()
print_arguments(args=args)

# Build the trainer
trainer = MVectorTrainer(configs=args.configs, use_gpu=args.use_gpu)

trainer.train(save_model_path=args.save_model_path,
              resume_model=args.resume_model,
              pretrained_model=args.pretrained_model,
              augment_conf_path=args.augment_conf_path)