mvector/models/pooling.py

import torch
import torch.nn as nn


class TemporalAveragePooling(nn.Module):
    def __init__(self):
        """TAP
        Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
        Link: https://arxiv.org/pdf/1903.12058.pdf
        """
        super(TemporalAveragePooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Average Pooling.
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels).
        """
        x = torch.mean(x, dim=2)
        return x
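
# Illustrative example (not part of the original file; all dimensions are
# arbitrary): averaging over the frame axis collapses (batch, channels, frames)
# to (batch, channels).
#   >>> x = torch.randn(4, 512, 100)            # (batch, channels, frames)
#   >>> TemporalAveragePooling()(x).shape
#   torch.Size([4, 512])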


class TemporalStatisticsPooling(nn.Module):
    def __init__(self):
        """TSP
        Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
        Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        """
        super(TemporalStatisticsPooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Statistics Pooling.
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels*2).
        """
        # Concatenate the per-channel mean and variance over the frame axis.
        mean = torch.mean(x, dim=2)
        var = torch.var(x, dim=2)
        x = torch.cat((mean, var), dim=1)
        return x
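
# Illustrative example (not part of the original file): concatenating mean and
# variance doubles the channel dimension.
#   >>> x = torch.randn(4, 512, 100)
#   >>> TemporalStatisticsPooling()(x).shape
#   torch.Size([4, 1024])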


class SelfAttentivePooling(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        """SAP
        Paper: Self-Attentive Speaker Embeddings for Text-Independent Speaker Verification
        """
        super(SelfAttentivePooling, self).__init__()
        # Conv1d with kernel_size=1 acts as a per-frame Linear layer, so the
        # (#batch, channels, frames) input does not need to be transposed.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """Computes Self-Attentive Pooling.
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels).
        """
        # tanh is used instead of ReLU: in experiments, ReLU was hard to converge.
        alpha = torch.tanh(self.linear1(x))
        # Attention weights are normalized over the frame axis.
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        return mean
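
# Illustrative example (not part of the original file): SAP keeps the channel
# dimension, and the softmax makes the attention weights sum to 1 over the
# frame axis for every (batch, channel) pair.
#   >>> x = torch.randn(4, 512, 100)
#   >>> SelfAttentivePooling(in_dim=512)(x).shape
#   torch.Size([4, 512])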


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        """ASP
        Paper: Attentive Statistics Pooling for Deep Speaker Embedding
        Link: https://arxiv.org/abs/1803.10963
        """
        super().__init__()
        # Conv1d with kernel_size=1 acts as a per-frame Linear layer, so the
        # input does not need to be transposed.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """Computes Attentive Statistics Pooling.
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels*2).
        """
        # tanh is used instead of ReLU: in experiments, ReLU was hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        # Attention-weighted mean and standard deviation; the clamp guards
        # against small negative values from floating-point error before sqrt.
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
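

# A minimal smoke test, added as a sketch (not part of the original module).
# It feeds one random (batch, channels, frames) tensor through each pooling
# layer and checks the expected output shapes; all dimensions are arbitrary.
if __name__ == "__main__":
    batch, channels, frames = 4, 512, 100
    x = torch.randn(batch, channels, frames)

    assert TemporalAveragePooling()(x).shape == (batch, channels)
    assert TemporalStatisticsPooling()(x).shape == (batch, channels * 2)
    assert SelfAttentivePooling(channels)(x).shape == (batch, channels)
    assert AttentiveStatsPool(channels)(x).shape == (batch, channels * 2)
    print("All pooling layers produce the expected output shapes.")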