# mvector/models/ecapa_tdnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mvector.models.pooling import AttentiveStatsPool, TemporalAveragePooling
from mvector.models.pooling import SelfAttentivePooling, TemporalStatisticsPooling


class Res2Conv1dReluBn(nn.Module):
    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1
        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        # Iterate over the branches
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                # For the remaining branches, add the current split to the previous branch output to form a residual connection
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        # Concatenate all branch outputs along the channel dimension
        out = torch.cat(out, dim=1)
        return out
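

# Illustrative sketch, not part of the original module: it checks that a Res2Conv1dReluBn block
# preserves the (batch, channels, time) shape when padding matches the dilated kernel. The
# concrete sizes (4, 512, 200) are arbitrary example values.
def _demo_res2conv1d_relubn():
    block = Res2Conv1dReluBn(channels=512, kernel_size=3, padding=1, dilation=1, scale=8)
    x = torch.randn(4, 512, 200)  # (batch, channels, time)
    y = block(x)
    assert y.shape == x.shape  # channels are split into `scale` groups and re-concatenated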


class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


class SE_Connect(nn.Module):
    def __init__(self, channels, s=2):
        super().__init__()
        assert channels % s == 0, "{} % {} != 0".format(channels, s)
        self.linear1 = nn.Linear(channels, channels // s)
        self.linear2 = nn.Linear(channels // s, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)
        return out


def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
    """
    Build one SE-Res2 block: 1x1 Conv -> Res2 dilated Conv -> 1x1 Conv -> SE.

    Args:
        channels: number of input/output channels.
        kernel_size: kernel size of the Res2 convolution.
        stride: stride of the Res2 convolution.
        padding: padding of the Res2 convolution.
        dilation: dilation rate of the Res2 convolution.
        scale: number of Res2 splits.

    Returns:
        nn.Sequential: the assembled SE-Res2 block.
    """
    return nn.Sequential(
        Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
        Res2Conv1dReluBn(channels, kernel_size, stride, padding, dilation, scale=scale),
        Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
        SE_Connect(channels)
    )
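

# Illustrative sketch, not part of the original module: an SE_Res2Block keeps the channel and time
# dimensions unchanged because the dilated convolution uses padding = dilation with a kernel size
# of 3, exactly as in EcapaTdnn.layer2/3/4 below. Sizes are arbitrary example values.
def _demo_se_res2block():
    block = SE_Res2Block(channels=512, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
    x = torch.randn(4, 512, 200)  # (batch, channels, time)
    y = block(x)
    assert y.shape == x.shape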


class EcapaTdnn(nn.Module):
    """
    ECAPA-TDNN speaker embedding model.

    Args:
        input_size: input feature dimension, default 80.
        channels: number of channels, default 512.
        embd_dim: embedding dimension, default 192.
        pooling_type: pooling type, default "ASP"; options are "ASP", "SAP", "TAP", "TSP".
    """

    def __init__(self, input_size=80, channels=512, embd_dim=192, pooling_type="ASP"):
        super().__init__()
        self.layer1 = Conv1dReluBn(input_size, channels, kernel_size=5, padding=2, dilation=1)
        self.layer2 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
        self.layer3 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8)
        self.layer4 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8)
        cat_channels = channels * 3
        self.emb_size = embd_dim
        self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
        if pooling_type == "ASP":
            self.pooling = AttentiveStatsPool(cat_channels, 128)
            self.bn1 = nn.BatchNorm1d(cat_channels * 2)
            self.linear = nn.Linear(cat_channels * 2, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        elif pooling_type == "SAP":
            self.pooling = SelfAttentivePooling(cat_channels, 128)
            self.bn1 = nn.BatchNorm1d(cat_channels)
            self.linear = nn.Linear(cat_channels, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        elif pooling_type == "TAP":
            self.pooling = TemporalAveragePooling()
            self.bn1 = nn.BatchNorm1d(cat_channels)
            self.linear = nn.Linear(cat_channels, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        elif pooling_type == "TSP":
            self.pooling = TemporalStatisticsPooling()
            self.bn1 = nn.BatchNorm1d(cat_channels * 2)
            self.linear = nn.Linear(cat_channels * 2, embd_dim)
            self.bn2 = nn.BatchNorm1d(embd_dim)
        else:
            raise Exception(f'No such pooling layer: {pooling_type}!')

    def forward(self, x):
        """
        Compute embeddings.

        Args:
            x (torch.Tensor): Input data with shape (N, time, freq), where N is the batch size,
                time is the number of frames and freq is the feature dimension.

        Returns:
            torch.Tensor: Output embeddings with shape (N, self.emb_size).
        """
        # Swap the time and frequency dimensions: (N, time, freq) -> (N, freq, time)
        x = x.transpose(2, 1)
        # First convolutional layer
        out1 = self.layer1(x)
        # Second layer, with a residual connection to the first layer's output
        out2 = self.layer2(out1) + out1
        # Third layer, with residual connections to both previous outputs
        out3 = self.layer3(out1 + out2) + out1 + out2
        # Fourth layer, with residual connections to the three previous outputs
        out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3
        # Concatenate the outputs of layers 2, 3 and 4 along the channel dimension
        out = torch.cat([out2, out3, out4], dim=1)
        # Apply ReLU and the 1x1 convolution
        out = F.relu(self.conv(out))
        # Pooling followed by batch normalization
        out = self.bn1(self.pooling(out))
        # Linear projection followed by batch normalization
        out = self.bn2(self.linear(out))
        return out
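

# Minimal usage sketch (assumption: feature extraction and training code live elsewhere in the
# mvector package). It only checks that a dummy batch of 80-dim features produces embeddings of
# the expected size; the batch size and number of frames are arbitrary example values.
if __name__ == '__main__':
    model = EcapaTdnn(input_size=80, embd_dim=192, pooling_type="ASP")
    model.eval()
    features = torch.randn(2, 300, 80)  # (batch, time, freq), e.g. 300 frames of 80-dim Fbank
    with torch.no_grad():
        embeddings = model(features)
    print(embeddings.shape)  # expected: torch.Size([2, 192])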