1/mvector/data_utils/reader.py

74 lines
3.0 KiB
Python
Raw Permalink Normal View History

2025-04-18 19:56:58 +08:00
import numpy as np
from torch.utils.data import Dataset
from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.augmentation import AugmentationPipeline
from mvector.utils.logger import setup_logger
logger = setup_logger(__name__)
class CustomDataset(Dataset):
def __init__(self,
data_list_path,
do_vad=True,
max_duration=6,
min_duration=0.5,
augmentation_config='{}',
mode='train',
sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
"""音频数据加载器
Args:
data_list_path: 包含音频路径和标签的数据列表文件的路径
do_vad: 是否对音频进行语音活动检测VAD来裁剪静音部分
max_duration: 最长的音频长度大于这个长度会裁剪掉
min_duration: 过滤最短的音频长度
augmentation_config: 用于指定音频增强的配置
mode: 数据集模式在训练模式下数据集可能会进行一些数据增强的预处理
sample_rate: 采样率
use_dB_normalization: 是否对音频进行音量归一化
target_dB: 音量归一化的大小
"""
super(CustomDataset, self).__init__()
self.do_vad = do_vad
self.max_duration = max_duration
self.min_duration = min_duration
self.mode = mode
self._target_sample_rate = sample_rate
self._use_dB_normalization = use_dB_normalization
self._target_dB = target_dB
self._augmentation_pipeline = AugmentationPipeline(augmentation_config=augmentation_config)
# 获取数据列表
with open(data_list_path, 'r') as f:
self.lines = f.readlines()
def __getitem__(self, idx):
# 分割音频路径和标签
audio_path, label = self.lines[idx].replace('\n', '').split('\t')
# 读取音频
audio_segment = AudioSegment.from_file(audio_path)
# 裁剪静音
if self.do_vad:
audio_segment.vad()
# 数据太短不利于训练
if self.mode == 'train':
if audio_segment.duration < self.min_duration:
return self.__getitem__(idx + 1 if idx < len(self.lines) - 1 else 0)
# 重采样
if audio_segment.sample_rate != self._target_sample_rate:
audio_segment.resample(self._target_sample_rate)
# decibel normalization
if self._use_dB_normalization:
audio_segment.normalize(target_db=self._target_dB)
# 裁剪需要的数据
audio_segment.crop(duration=self.max_duration, mode=self.mode)
# 音频增强
self._augmentation_pipeline.transform_audio(audio_segment)
return np.array(audio_segment.samples, dtype=np.float32), np.array(int(label), dtype=np.int64)
def __len__(self):
return len(self.lines)