mvector/data_utils/reader.py
import numpy as np
from torch.utils.data import Dataset

from mvector.data_utils.audio import AudioSegment
from mvector.data_utils.augmentor.augmentation import AugmentationPipeline
from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class CustomDataset(Dataset):
    def __init__(self,
                 data_list_path,
                 do_vad=True,
                 max_duration=6,
                 min_duration=0.5,
                 augmentation_config='{}',
                 mode='train',
                 sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        """Audio data loader.

        Args:
            data_list_path: Path to the data list file containing audio paths and labels.
            do_vad: Whether to apply voice activity detection (VAD) to trim silent parts of the audio.
            max_duration: Maximum audio duration; audio longer than this is cropped.
            min_duration: Minimum audio duration; shorter clips are filtered out.
            augmentation_config: Configuration specifying the audio augmentation to apply.
            mode: Dataset mode. In training mode the dataset may apply augmentation as preprocessing.
            sample_rate: Sample rate.
            use_dB_normalization: Whether to apply volume (decibel) normalization to the audio.
            target_dB: Target level for volume normalization.
        """
        super(CustomDataset, self).__init__()
        self.do_vad = do_vad
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.mode = mode
        self._target_sample_rate = sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB
        self._augmentation_pipeline = AugmentationPipeline(augmentation_config=augmentation_config)
        # Load the data list
        with open(data_list_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
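
    # A data list line has the form "<audio_path>\t<integer_label>", e.g.
    # "dataset/audio/0001.wav\t3" (the example path is purely illustrative).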
    def __getitem__(self, idx):
        # Split the line into audio path and label
        audio_path, label = self.lines[idx].replace('\n', '').split('\t')
        # Load the audio
        audio_segment = AudioSegment.from_file(audio_path)
        # Trim silence with VAD
        if self.do_vad:
            audio_segment.vad()
        # Audio that is too short is not useful for training, so skip to the next sample
        if self.mode == 'train':
            if audio_segment.duration < self.min_duration:
                return self.__getitem__(idx + 1 if idx < len(self.lines) - 1 else 0)
        # Resample to the target sample rate
        if audio_segment.sample_rate != self._target_sample_rate:
            audio_segment.resample(self._target_sample_rate)
        # Decibel (volume) normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # Crop to the required duration
        audio_segment.crop(duration=self.max_duration, mode=self.mode)
        # Apply audio augmentation
        self._augmentation_pipeline.transform_audio(audio_segment)
        return np.array(audio_segment.samples, dtype=np.float32), np.array(int(label), dtype=np.int64)

    def __len__(self):
        return len(self.lines)
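

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the mvector package): wrapping
# CustomDataset in a torch DataLoader. The data list path below is
# hypothetical, and collate_fn is an assumed helper that pads 1-D waveforms of
# different lengths to the longest clip in the batch.
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader

    def collate_fn(batch):
        """Pad variable-length 1-D waveforms in a batch to the same length."""
        audios, labels = zip(*batch)
        max_len = max(audio.shape[0] for audio in audios)
        padded = np.zeros((len(audios), max_len), dtype=np.float32)
        for i, audio in enumerate(audios):
            padded[i, :audio.shape[0]] = audio
        return torch.from_numpy(padded), torch.from_numpy(np.array(labels, dtype=np.int64))

    # Hypothetical data list path, for illustration only
    train_dataset = CustomDataset(data_list_path='dataset/train_list.txt',
                                  max_duration=3,
                                  mode='train')
    train_loader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              collate_fn=collate_fn)
    for audio, label in train_loader:
        print(audio.shape, label.shape)
        break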