91 lines
3.4 KiB
Python
91 lines
3.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn import Parameter
|
|
|
|
|
|
class SpeakerIdetification(nn.Module):
    """Speaker identification model: an embedding backbone plus a classifier head.

    The head optionally applies dropout and a stack of (BatchNorm1d -> Linear)
    blocks to the backbone's embedding. It then produces per-speaker logits in a
    way selected by ``loss_type``:

    * ``'AAMLoss'``: cosine similarity against a (num_class, dim) weight.
    * ``'AMLoss'`` / ``'ARMLoss'``: cosine similarity against a (dim, num_class) weight.
    * ``'CELoss'``: a plain ``nn.Linear`` layer.
    """

    def __init__(
            self,
            backbone,
            num_class=1,
            loss_type='AAMLoss',
            lin_blocks=0,
            lin_neurons=192,
            dropout=0.1, ):
        """Build the speaker backbone and the training-time classification head.

        Args:
            backbone (torch.nn.Module): speaker embedding network. It must expose
                an ``emb_size`` attribute equal to the size of the embedding its
                forward pass returns.
            num_class (int, optional): number of speaker classes in the training
                set. Defaults to 1.
            loss_type (str, optional): one of ``'AAMLoss'``, ``'AMLoss'``,
                ``'ARMLoss'``, ``'CELoss'``; selects how the final layer is
                built. Defaults to ``'AAMLoss'``.
            lin_blocks (int, optional): number of (BatchNorm1d, Linear) blocks
                inserted between the embedding and the final layer. Defaults to 0.
            lin_neurons (int, optional): output size of each of those
                intermediate linear blocks. Defaults to 192.
            dropout (float, optional): dropout probability applied to the
                embedding; dropout is disabled when this is <= 0. Defaults to 0.1.

        Raises:
            ValueError: if ``loss_type`` is not one of the supported values.
        """
        super(SpeakerIdetification, self).__init__()
        # The backbone's output is the speaker embedding.
        self.backbone = backbone
        self.loss_type = loss_type
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Classifier head. nn.ModuleList (not a plain list) is required so the
        # block layers are registered as submodules: otherwise their parameters
        # are not trained, not moved by .to()/.cuda(), and not saved in state_dict.
        input_size = self.backbone.emb_size
        self.blocks = nn.ModuleList()
        for _ in range(lin_blocks):
            self.blocks.extend([
                nn.BatchNorm1d(input_size),
                nn.Linear(in_features=input_size, out_features=lin_neurons),
            ])
            input_size = lin_neurons

        # Final layer.
        if self.loss_type == 'AAMLoss':
            # One row per class; rows are L2-normalized in forward().
            self.weight = Parameter(torch.empty(num_class, input_size, dtype=torch.float32))
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type in ('AMLoss', 'ARMLoss'):
            # One column per class; columns are L2-normalized in forward().
            self.weight = Parameter(torch.empty(input_size, num_class, dtype=torch.float32))
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'CELoss':
            self.output = nn.Linear(input_size, num_class)
        else:
            raise ValueError(f'没有{self.loss_type}损失函数!')

    def forward(self, x):
        """Run the backbone and classifier head.

        Args:
            x (torch.Tensor): input audio features. The backbone receives this
                tensor directly, so its expected layout is defined by the
                backbone. The original comment says (batch, time, dim);
                TODO confirm for the backbone in use.

        Returns:
            torch.Tensor: logits of shape (batch, num_class). For the
            AAM/AM/ARM losses these are cosine similarities in [-1, 1].
        """
        x = self.backbone(x)  # (N, emb_size)
        if self.dropout is not None:
            x = self.dropout(x)

        for fc in self.blocks:
            x = fc(x)

        if self.loss_type == 'AAMLoss':
            # Cosine similarity: normalize embeddings (per row) and class weights
            # (per row), then x @ W.T.
            logits = F.linear(F.normalize(x), F.normalize(self.weight, dim=-1))
        elif self.loss_type in ('AMLoss', 'ARMLoss'):
            # Cosine similarity with the transposed weight layout: normalize
            # embeddings per row and class weights per column. F.normalize
            # clamps the norm at eps=1e-12, avoiding division by zero.
            x_norm = F.normalize(x, p=2, dim=1, eps=1e-12)
            w_norm = F.normalize(self.weight, p=2, dim=0, eps=1e-12)
            logits = torch.mm(x_norm, w_norm)
        else:
            logits = self.output(x)

        return logits
|