import numpy as np
from torch.utils.data import Dataset

from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.augmentation import AugmentationPipeline
from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class CustomDataset(Dataset):
    def __init__(self,
                 data_list_path,
                 do_vad=True,
                 max_duration=6,
                 min_duration=0.5,
                 augmentation_config='{}',
                 mode='train',
                 sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        """Audio data loader.

        Args:
            data_list_path: Path to the data list file containing audio paths and labels
            do_vad: Whether to apply voice activity detection (VAD) to trim silent segments
            max_duration: Maximum audio duration; longer audio is cropped to this length
            min_duration: Minimum audio duration; shorter audio is filtered out
            augmentation_config: Configuration used for audio augmentation
            mode: Dataset mode; in training mode the dataset may apply augmentation preprocessing
            sample_rate: Target sample rate
            use_dB_normalization: Whether to normalize the audio volume
            target_dB: Target dB value used for volume normalization
        """
        super(CustomDataset, self).__init__()
        self.do_vad = do_vad
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.mode = mode
        self._target_sample_rate = sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB
        self._augmentation_pipeline = AugmentationPipeline(augmentation_config=augmentation_config)
        # Read the data list
        with open(data_list_path, 'r') as f:
            self.lines = f.readlines()

    def __getitem__(self, idx):
        # Split the line into the audio path and the label
        audio_path, label = self.lines[idx].replace('\n', '').split('\t')
        # Read the audio
        audio_segment = AudioSegment.from_file(audio_path)
        # Trim silence
        if self.do_vad:
            audio_segment.vad()
        # Audio that is too short is not useful for training, so skip to the next sample
        if self.mode == 'train':
            if audio_segment.duration < self.min_duration:
                return self.__getitem__(idx + 1 if idx < len(self.lines) - 1 else 0)
        # Resample
        if audio_segment.sample_rate != self._target_sample_rate:
            audio_segment.resample(self._target_sample_rate)
        # Decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # Crop to the required duration
        audio_segment.crop(duration=self.max_duration, mode=self.mode)
        # Audio augmentation
        self._augmentation_pipeline.transform_audio(audio_segment)
        return np.array(audio_segment.samples, dtype=np.float32), np.array(int(label), dtype=np.int64)

    def __len__(self):
        return len(self.lines)
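

# Illustrative usage sketch (not part of the original module): shows how this
# dataset might be wired into a PyTorch DataLoader. The data list path
# 'dataset/train_list.txt' and the padding collate_fn below are assumptions
# for demonstration only; the list format ("<audio_path>\t<label>" per line)
# matches the parsing done in __getitem__.
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader

    def collate_fn(batch):
        # Pad each 1-D waveform to the length of the longest sample in the batch.
        features, labels = zip(*batch)
        max_len = max(f.shape[0] for f in features)
        padded = torch.zeros(len(features), max_len, dtype=torch.float32)
        for i, f in enumerate(features):
            padded[i, :f.shape[0]] = torch.from_numpy(f)
        return padded, torch.tensor([int(l) for l in labels], dtype=torch.int64)

    train_dataset = CustomDataset(data_list_path='dataset/train_list.txt', mode='train')
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    for inputs, labels in train_loader:
        print(inputs.shape, labels.shape)
        break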