mvector/models/fc.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class SpeakerIdetification(nn.Module):
    def __init__(self,
                 backbone,
                 num_class=1,
                 loss_type='AAMLoss',
                 lin_blocks=0,
                 lin_neurons=192,
                 dropout=0.1):
        """Speaker identification model: a speaker embedding backbone plus a
        linear classifier sized to the number of speakers in the training set.

        Args:
            backbone (nn.Module): Speaker embedding backbone network.
            num_class (int): Number of speaker classes in the training dataset.
            loss_type (str): Training loss; determines how the classification
                head is built. Defaults to 'AAMLoss'.
            lin_blocks (int, optional): Number of linear transform blocks between
                the embedding and the final linear layer. Defaults to 0.
            lin_neurons (int, optional): Output dimension of the linear blocks.
                Defaults to 192.
            dropout (float, optional): Dropout factor applied to the embedding.
                Defaults to 0.1.
        """
        super(SpeakerIdetification, self).__init__()
        # The backbone's output is the target speaker embedding.
        self.backbone = backbone
        self.loss_type = loss_type
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        # Build the speaker classifier on top of the embedding.
        input_size = self.backbone.emb_size
        # nn.ModuleList (rather than a plain list) registers the block
        # parameters with the module so they train and move with .to().
        self.blocks = nn.ModuleList()
        # Add the linear transform blocks.
        for i in range(lin_blocks):
            self.blocks.extend([
                nn.BatchNorm1d(input_size),
                nn.Linear(in_features=input_size, out_features=lin_neurons),
            ])
            input_size = lin_neurons
        # Final layer: the margin-based losses use a bare weight matrix to
        # produce cosine logits; plain cross-entropy uses a linear layer.
        if self.loss_type == 'AAMLoss':
            self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
            self.weight = Parameter(torch.FloatTensor(input_size, num_class), requires_grad=True)
            nn.init.xavier_normal_(self.weight, gain=1)
        elif self.loss_type == 'CELoss':
            self.output = nn.Linear(input_size, num_class)
        else:
            raise Exception(f'No such loss function: {self.loss_type}!')

    def forward(self, x):
        """Forward pass of the speaker identification model: the speaker
        embedding backbone followed by the classifier network.

        Args:
            x (torch.Tensor): Input audio features,
                shape=[batch_size, time, dim]

        Returns:
            torch.Tensor: Logits over the speaker classes.
        """
        # x.shape: (N, L, C)
        x = self.backbone(x)  # (N, emb_size)
        if self.dropout is not None:
            x = self.dropout(x)
        for fc in self.blocks:
            x = fc(x)
        if self.loss_type == 'AAMLoss':
            # Cosine similarity between L2-normalized embeddings and class
            # weights; the angular margin itself is applied inside the loss.
            logits = F.linear(F.normalize(x), F.normalize(self.weight, dim=-1))
        elif self.loss_type == 'AMLoss' or self.loss_type == 'ARMLoss':
            # Manual L2 normalization of embeddings (rows) and weight columns
            # produces the same cosine logits for the additive-margin losses.
            x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            x_norm = torch.div(x, x_norm)
            w_norm = torch.norm(self.weight, p=2, dim=0, keepdim=True).clamp(min=1e-12)
            w_norm = torch.div(self.weight, w_norm)
            logits = torch.mm(x_norm, w_norm)
        else:
            logits = self.output(x)
        return logits
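

# A minimal, self-contained usage sketch (not part of the original module).
# `DummyBackbone`, its feature sizes, and the margin/scale values below are
# illustrative assumptions; the only contract assumed of a real backbone is
# an `emb_size` attribute and a (N, L, C) -> (N, emb_size) forward.
if __name__ == '__main__':
    class DummyBackbone(nn.Module):
        def __init__(self, feat_dim=80, emb_size=192):
            super().__init__()
            self.emb_size = emb_size
            self.proj = nn.Linear(feat_dim, emb_size)

        def forward(self, x):
            # (N, L, C) -> mean-pool over time -> (N, emb_size)
            return self.proj(x.mean(dim=1))

    model = SpeakerIdetification(backbone=DummyBackbone(), num_class=10)
    feats = torch.randn(4, 100, 80)  # (batch, time, dim)
    logits = model(feats)
    print(logits.shape)  # torch.Size([4, 10]); cosine values in [-1, 1]

    # Generic sketch of how an AAM-style loss could consume these cosine
    # logits; the repo's own AAMLoss implementation may differ in details.
    labels = torch.tensor([0, 1, 2, 3])
    margin, scale = 0.2, 30.0  # typical values, assumed here
    theta = torch.acos(logits.clamp(-1 + 1e-7, 1 - 1e-7))
    target = F.one_hot(labels, num_classes=10).bool()
    margined = torch.where(target, torch.cos(theta + margin), logits)
    loss = F.cross_entropy(scale * margined, labels)
    print(loss.item())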