1/mvector/data_utils/featurizer.py

104 lines
3.8 KiB
Python
Raw Normal View History

2025-04-18 19:56:58 +08:00
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC
import torchaudio.compliance.kaldi as Kaldi
class AudioFeaturizer(nn.Module):
    """Audio feature extractor.

    Wraps one of several feature transforms. ``forward`` standardizes the
    features along the time axis and zeroes out frames beyond each sample's
    valid length.

    :param feature_method: Name of the transform to use: ``'MelSpectrogram'``,
        ``'Spectrogram'``, ``'MFCC'`` or ``'Fbank'``.
    :type feature_method: str
    :param feature_conf: Keyword arguments for the chosen transform. Any
        mapping works (a plain ``dict`` or an attribute-style config dict).
        For ``'MFCC'`` it must contain ``sample_rate`` and ``n_mfcc``; the
        remaining keys are forwarded to ``MFCC`` as ``melkwargs``.
        ``None`` (the default) means no extra arguments.
    :type feature_conf: dict
    :raises ValueError: If ``feature_method`` is not supported.
    """

    def __init__(self, feature_method='MelSpectrogram', feature_conf=None):
        super().__init__()
        # Use None as the default to avoid sharing one mutable dict between
        # instances.
        if feature_conf is None:
            feature_conf = {}
        self._feature_conf = feature_conf
        self._feature_method = feature_method
        if feature_method == 'MelSpectrogram':
            self.feat_fun = MelSpectrogram(**feature_conf)
        elif feature_method == 'Spectrogram':
            self.feat_fun = Spectrogram(**feature_conf)
        elif feature_method == 'MFCC':
            # MFCC takes sample_rate / n_mfcc directly; every other key is
            # passed through as melkwargs. Subscript access (not attribute
            # access) so a plain dict config works too.
            melkwargs = dict(feature_conf)
            sample_rate = melkwargs.pop('sample_rate')
            n_mfcc = melkwargs.pop('n_mfcc')
            self.feat_fun = MFCC(sample_rate=sample_rate,
                                 n_mfcc=n_mfcc,
                                 melkwargs=melkwargs)
        elif feature_method == 'Fbank':
            self.feat_fun = KaldiFbank(**feature_conf)
        else:
            raise ValueError(f'预处理方法 {self._feature_method} 不存在!')

    def forward(self, waveforms, input_lens_ratio):
        """Extract standardized, length-masked features from a batch of audio.

        :param waveforms: Batch passed straight to the underlying transform;
            presumably shaped ``(batch, samples)`` -- TODO confirm against
            callers.
        :type waveforms: torch.Tensor
        :param input_lens_ratio: 1-D tensor with, per sample, the fraction of
            the padded length that holds real audio.
        :type input_lens_ratio: torch.Tensor
        :return: ``(features, lengths)``. ``features`` has shape
            ``(batch, frames, feature_dim)``, with frames at or beyond each
            sample's length set to zero. ``lengths`` is an int32 tensor of
            valid frame counts, computed the same way as the mask.
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """
        feature = self.feat_fun(waveforms)
        # The transforms emit (batch, feature, time); work in
        # (batch, time, feature).
        feature = feature.transpose(2, 1)
        # Standardize each feature dimension over the time axis.
        # NOTE(review): the statistics include padded frames, and with a
        # single frame the unbiased torch.std yields NaN.
        mean = torch.mean(feature, 1, keepdim=True)
        std = torch.std(feature, 1, keepdim=True)
        feature = (feature - mean) / (std + 1e-5)
        # Convert length ratios to frame counts. Round once and use the same
        # value for both the mask and the returned lengths, so they agree
        # (truncating the returned lengths could be one frame short).
        num_frames = feature.shape[1]
        input_lens = torch.round(input_lens_ratio * num_frames).int()
        # valid[b, t] is True for frames t < input_lens[b].
        frame_idx = torch.arange(num_frames, device=feature.device)
        valid = frame_idx.unsqueeze(0) < input_lens.unsqueeze(1)
        # Zero out the padded frames.
        feature_masked = feature.masked_fill(~valid.unsqueeze(-1), 0.0)
        return feature_masked, input_lens

    @property
    def feature_dim(self):
        """Size of the last (feature) axis produced by ``forward``.

        The value is read from the configuration given at construction time.

        :return: Feature dimension.
        :rtype: int
        :raises KeyError: If the required size key is missing from the config.
        :raises ValueError: If the configured method is unknown.
        """
        # Subscript access works for plain dicts and attribute-style config
        # dicts alike.
        conf = self._feature_conf
        method = self._feature_method
        # 'LogMelSpectrogram' is kept for backward compatibility only;
        # __init__ does not accept it.
        if method in ('MelSpectrogram', 'LogMelSpectrogram'):
            return conf['n_mels']
        if method == 'Spectrogram':
            return conf['n_fft'] // 2 + 1
        if method == 'MFCC':
            return conf['n_mfcc']
        if method == 'Fbank':
            return conf['num_mel_bins']
        raise ValueError('没有{}预处理方法'.format(self._feature_method))
class KaldiFbank(nn.Module):
    """Batched wrapper around ``torchaudio.compliance.kaldi.fbank``.

    :param kwargs: Keyword arguments forwarded unchanged to
        ``torchaudio.compliance.kaldi.fbank`` for every waveform.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, waveforms):
        """Compute Kaldi-compatible filter-bank features for each waveform.

        :param waveforms: Batch of waveforms, ``[Batch, Length]``. A 1-D item
            is given a leading channel axis before extraction.
        :return: ``[Batch, Feature, Frames]``. The feature axis comes before
            time, like torchaudio's spectrogram transforms.
        """
        fbanks = []
        for waveform in waveforms:
            if waveform.dim() == 1:
                # kaldi.fbank expects a (channel, time) waveform.
                waveform = waveform.unsqueeze(0)
            # kaldi.fbank returns (frames, num_mel_bins); transpose to
            # (num_mel_bins, frames).
            fbanks.append(Kaldi.fbank(waveform, **self.kwargs).transpose(0, 1))
        # Stacking requires every item to yield the same number of frames.
        return torch.stack(fbanks)