import torch from torch import nn from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC import torchaudio.compliance.kaldi as Kaldi class AudioFeaturizer(nn.Module): """音频特征器 :param feature_method: 所使用的预处理方法 :type feature_method: str :param feature_conf: 预处理方法的参数 :type feature_conf: dict """ def __init__(self, feature_method='MelSpectrogram', feature_conf={}): super().__init__() self._feature_conf = feature_conf self._feature_method = feature_method if feature_method == 'MelSpectrogram': self.feat_fun = MelSpectrogram(**feature_conf) elif feature_method == 'Spectrogram': self.feat_fun = Spectrogram(**feature_conf) elif feature_method == 'MFCC': melkwargs = feature_conf.copy() del melkwargs['sample_rate'] del melkwargs['n_mfcc'] self.feat_fun = MFCC(sample_rate=self._feature_conf.sample_rate, n_mfcc=self._feature_conf.n_mfcc, melkwargs=melkwargs) elif feature_method == 'Fbank': self.feat_fun = KaldiFbank(**feature_conf) else: raise Exception(f'预处理方法 {self._feature_method} 不存在!') def forward(self, waveforms, input_lens_ratio): """从AudioSegment中提取音频特征 :param waveforms: Audio segment to extract features from. :type waveforms: AudioSegment :param input_lens_ratio: input length ratio :type input_lens_ratio: tensor :return: Spectrogram audio feature in 2darray. :rtype: ndarray """ feature = self.feat_fun(waveforms) feature = feature.transpose(2, 1) # 归一化 mean = torch.mean(feature, 1, keepdim=True) std = torch.std(feature, 1, keepdim=True) feature = (feature - mean) / (std + 1e-5) # 对掩码比例进行扩展 input_lens = (input_lens_ratio * feature.shape[1]) mask_lens = torch.round(input_lens).long() mask_lens = mask_lens.unsqueeze(1) input_lens = input_lens.int() # 生成掩码张量 idxs = torch.arange(feature.shape[1], device=feature.device).repeat(feature.shape[0], 1) mask = idxs < mask_lens mask = mask.unsqueeze(-1) # 对特征进行掩码操作 feature_masked = torch.where(mask, feature, torch.zeros_like(feature)) return feature_masked, input_lens @property def feature_dim(self): """返回特征大小 :return: 特征大小 :rtype: int """ if self._feature_method == 'LogMelSpectrogram': return self._feature_conf.n_mels elif self._feature_method == 'MelSpectrogram': return self._feature_conf.n_mels elif self._feature_method == 'Spectrogram': return self._feature_conf.n_fft // 2 + 1 elif self._feature_method == 'MFCC': return self._feature_conf.n_mfcc elif self._feature_method == 'Fbank': return self._feature_conf.num_mel_bins else: raise Exception('没有{}预处理方法'.format(self._feature_method)) class KaldiFbank(nn.Module): def __init__(self, **kwargs): super(KaldiFbank, self).__init__() self.kwargs = kwargs def forward(self, waveforms): """ :param waveforms: [Batch, Length] :return: [Batch, Length, Feature] """ log_fbanks = [] for waveform in waveforms: if len(waveform.shape) == 1: waveform = waveform.unsqueeze(0) log_fbank = Kaldi.fbank(waveform, **self.kwargs) log_fbank = log_fbank.transpose(0, 1) log_fbanks.append(log_fbank) log_fbank = torch.stack(log_fbanks) return log_fbank