import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class SpeakerIdentification(nn.Module):
    def __init__(
            self,
            backbone,
            num_class=1,
            loss_type='AAMLoss',
            lin_blocks=0,
            lin_neurons=192,
            dropout=0.1,
    ):
        """Initialize the speaker identification model, consisting of a speaker
        embedding backbone and a linear transform over the number of speaker
        classes seen during training.

        Args:
            backbone (nn.Module): speaker embedding backbone network.
            num_class (int): number of speaker classes in the training set.
            loss_type (str, optional): training loss; one of 'AAMLoss',
                'AMLoss', 'ARMLoss' or 'CELoss'. Defaults to 'AAMLoss'.
            lin_blocks (int, optional): number of linear transform blocks
                between the embedding and the final linear layer. Defaults to 0.
            lin_neurons (int, optional): output dimension of the final linear
                layer. Defaults to 192.
            dropout (float, optional): dropout factor applied to the embedding.
                Defaults to 0.1.
        """
        super(SpeakerIdentification, self).__init__()
        # The backbone network produces the target speaker embedding.
        self.backbone = backbone
        self.loss_type = loss_type
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        # Build the speaker classifier on top of the embedding.
        input_size = self.backbone.emb_size
        # Use nn.ModuleList (not a plain list) so the sub-layers are registered
        # and their parameters are trained and moved with the model.
        self.blocks = nn.ModuleList()
        # Add the linear transform blocks.
        for _ in range(lin_blocks):
            self.blocks.extend([
                nn.BatchNorm1d(input_size),
                nn.Linear(in_features=input_size, out_features=lin_neurons),
            ])
            input_size = lin_neurons

        # Initialize the final classification layer.
        if self.loss_type == 'AAMLoss':
            self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
            self.weight = Parameter(torch.FloatTensor(input_size, num_class), requires_grad=True)
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'CELoss':
            self.output = nn.Linear(input_size, num_class)
        else:
            raise ValueError(f'Unsupported loss type: {self.loss_type}!')

    def forward(self, x):
        """Run the forward pass of the speaker identification model:
        the speaker embedding backbone followed by the classifier.

        Args:
            x (torch.Tensor): input audio features of shape
                [batch_size, time, feature_dim].

        Returns:
            torch.Tensor: logits over the speaker classes.
        """
        # x.shape: (N, L, C)
        x = self.backbone(x)  # (N, emb_size)
        if self.dropout is not None:
            x = self.dropout(x)

        for fc in self.blocks:
            x = fc(x)

        if self.loss_type == 'AAMLoss':
            # Cosine logits: L2-normalize both the embedding and each class
            # weight vector before the linear projection.
            logits = F.linear(F.normalize(x), F.normalize(self.weight, dim=-1))
        elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
            # Normalize the embedding along the feature dim and each class
            # column of the weight matrix, then take their inner products.
            x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            x_norm = torch.div(x, x_norm)
            w_norm = torch.norm(self.weight, p=2, dim=0, keepdim=True).clamp(min=1e-12)
            w_norm = torch.div(self.weight, w_norm)
            logits = torch.mm(x_norm, w_norm)
        else:
            logits = self.output(x)

        return logits
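

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). `DummyBackbone` is a
# hypothetical stand-in for a real speaker-embedding network (e.g. an
# ECAPA-TDNN): any backbone only needs an `emb_size` attribute and a forward
# that maps (N, L, C) features to (N, emb_size) embeddings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class DummyBackbone(nn.Module):
        def __init__(self, feat_dim=80, emb_size=192):
            super().__init__()
            self.emb_size = emb_size
            self.proj = nn.Linear(feat_dim, emb_size)

        def forward(self, x):
            # x: (N, L, C) -> mean-pool over time -> (N, emb_size)
            return self.proj(x.mean(dim=1))

    model = SpeakerIdentification(DummyBackbone(), num_class=10, loss_type='AAMLoss')
    feats = torch.randn(4, 100, 80)  # (batch, time, feature_dim)
    logits = model(feats)
    print(logits.shape)  # torch.Size([4, 10]); cosine logits in [-1, 1]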