104 lines
3.8 KiB
Python
104 lines
3.8 KiB
Python
import torch
|
|
from torch import nn
|
|
from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC
|
|
import torchaudio.compliance.kaldi as Kaldi
|
|
|
|
|
|
class AudioFeaturizer(nn.Module):
    """Turn a batch of waveforms into normalised, length-masked features.

    The underlying transform is selected by ``feature_method``:

    * ``'MelSpectrogram'`` -> ``torchaudio.transforms.MelSpectrogram``
    * ``'Spectrogram'``    -> ``torchaudio.transforms.Spectrogram``
    * ``'MFCC'``           -> ``torchaudio.transforms.MFCC``
    * ``'Fbank'``          -> ``KaldiFbank``

    :param feature_method: name of the preprocessing method to use
    :type feature_method: str
    :param feature_conf: keyword arguments for the chosen method; ``None``
        (the default) means "use the transform's own defaults"
    :type feature_conf: dict
    :raises ValueError: if ``feature_method`` is not one of the names above
    """

    # Library defaults (torchaudio transforms / kaldi.fbank). They are used
    # when the configuration omits a key, so that ``feature_dim`` reports the
    # size the transform really produces.
    _DEFAULT_SAMPLE_RATE = 16000
    _DEFAULT_N_MFCC = 40
    _DEFAULT_N_MELS = 128
    _DEFAULT_N_FFT = 400
    _DEFAULT_NUM_MEL_BINS = 23

    def __init__(self, feature_method='MelSpectrogram', feature_conf=None):
        super().__init__()
        # Build a fresh dict per instance; a ``{}`` default in the signature
        # would be shared (and mutable) across every instance.
        if feature_conf is None:
            feature_conf = {}
        self._feature_conf = feature_conf
        self._feature_method = feature_method
        if feature_method == 'MelSpectrogram':
            self.feat_fun = MelSpectrogram(**feature_conf)
        elif feature_method == 'Spectrogram':
            self.feat_fun = Spectrogram(**feature_conf)
        elif feature_method == 'MFCC':
            # MFCC takes ``sample_rate`` and ``n_mfcc`` itself; every other
            # key is forwarded to its internal mel spectrogram. ``pop`` with
            # a default tolerates configurations that omit either key.
            melkwargs = dict(feature_conf)
            sample_rate = melkwargs.pop('sample_rate', self._DEFAULT_SAMPLE_RATE)
            n_mfcc = melkwargs.pop('n_mfcc', self._DEFAULT_N_MFCC)
            self.feat_fun = MFCC(sample_rate=sample_rate,
                                 n_mfcc=n_mfcc,
                                 melkwargs=melkwargs)
        elif feature_method == 'Fbank':
            self.feat_fun = KaldiFbank(**feature_conf)
        else:
            raise ValueError(f'预处理方法 {self._feature_method} 不存在!')

    def _conf_value(self, key, default):
        """Read ``key`` from the configuration, falling back to ``default``.

        Works for plain dicts (and dict subclasses) as well as for
        attribute-style configuration objects.
        """
        conf = self._feature_conf
        if isinstance(conf, dict):
            return conf.get(key, default)
        return getattr(conf, key, default)

    def forward(self, waveforms, input_lens_ratio):
        """Extract features, normalise them over time and zero out padding.

        :param waveforms: batch of waveforms handed to the selected transform
            (presumably ``[Batch, Length]`` -- confirm against the caller)
        :type waveforms: torch.Tensor
        :param input_lens_ratio: per-sample ratio of valid length to padded
            length, one value per batch item
        :type input_lens_ratio: torch.Tensor
        :return: ``(features, input_lens)`` where ``features`` has the time
            axis on dim 1 and the feature axis on dim 2, with frames past
            each sample's valid length set to zero, and ``input_lens`` is the
            per-sample frame count as an int tensor
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """
        feature = self.feat_fun(waveforms)
        # Move the time axis to dim 1.
        feature = feature.transpose(2, 1)
        # Per-sample, per-bin normalisation along the time axis.
        mean = torch.mean(feature, 1, keepdim=True)
        if feature.shape[1] > 1:
            std = torch.std(feature, 1, keepdim=True)
        else:
            # The sample standard deviation of a single frame is NaN; use 0
            # so a one-frame input yields finite (zero) features.
            std = torch.zeros_like(mean)
        feature = (feature - mean) / (std + 1e-5)
        # Scale the length ratios to a number of frames.
        input_lens = input_lens_ratio * feature.shape[1]
        mask_lens = torch.round(input_lens).long().unsqueeze(1)
        # NOTE(review): the returned lengths are truncated while the mask
        # lengths are rounded, so they can differ by one frame. Kept as-is
        # because callers may rely on it -- confirm which one is intended.
        input_lens = input_lens.int()
        # Frame indices broadcast against the per-sample lengths.
        idxs = torch.arange(feature.shape[1], device=feature.device).unsqueeze(0)
        mask = (idxs < mask_lens).unsqueeze(-1)
        # Zero every frame that lies beyond the valid length.
        feature_masked = torch.where(mask, feature, torch.zeros_like(feature))
        return feature_masked, input_lens

    @property
    def feature_dim(self):
        """Size of the feature axis produced by the selected method.

        Keys missing from the configuration fall back to the library
        defaults. The lookup deliberately avoids attribute access on the
        configuration: an ``AttributeError`` raised inside a property of an
        ``nn.Module`` is re-reported as "no attribute 'feature_dim'".

        :return: feature size
        :rtype: int
        :raises ValueError: if the stored method name is unknown
        """
        method = self._feature_method
        if method in ('LogMelSpectrogram', 'MelSpectrogram'):
            return self._conf_value('n_mels', self._DEFAULT_N_MELS)
        if method == 'Spectrogram':
            return self._conf_value('n_fft', self._DEFAULT_N_FFT) // 2 + 1
        if method == 'MFCC':
            return self._conf_value('n_mfcc', self._DEFAULT_N_MFCC)
        if method == 'Fbank':
            return self._conf_value('num_mel_bins', self._DEFAULT_NUM_MEL_BINS)
        raise ValueError('没有{}预处理方法'.format(self._feature_method))
|
|
|
|
class KaldiFbank(nn.Module):
    """Batch wrapper around ``torchaudio.compliance.kaldi.fbank``.

    All keyword arguments given at construction time are stored unchanged
    and forwarded to ``kaldi.fbank`` for every waveform in the batch.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, waveforms):
        """Compute filter-bank features for each waveform and stack them.

        :param waveforms: [Batch, Length]
        :return: [Batch, Feature, Frames] -- each per-waveform result is
            transposed before stacking, so the feature axis precedes the
            frame axis
        """
        per_waveform = []
        for wav in waveforms:
            # A 1-D waveform gets a leading axis before the fbank call.
            channel_first = wav.unsqueeze(0) if wav.ndim == 1 else wav
            fbank = Kaldi.fbank(channel_first, **self.kwargs)
            per_waveform.append(fbank.transpose(0, 1))
        return torch.stack(per_waveform)
|