import torch
import torch.nn as nn


class TemporalAveragePooling(nn.Module):
    def __init__(self):
        """TAP
        Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
        Link: https://arxiv.org/pdf/1903.12058.pdf
        """
        super(TemporalAveragePooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Average Pooling.

        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).

        Returns:
            torch.Tensor: Output tensor (#batch, channels).
        """
        x = torch.mean(x, dim=2)
        return x
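
# A minimal usage sketch (added for illustration; the shapes are assumptions):
# TAP averages frame-level features over the frame axis, e.g.
#   x = torch.randn(4, 256, 200)            # (batch, channels, frames)
#   TemporalAveragePooling()(x).shape       # -> torch.Size([4, 256])
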
class TemporalStatisticsPooling(nn.Module):
    def __init__(self):
        """TSP
        Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
        Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        """
        super(TemporalStatisticsPooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Statistics Pooling.

        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).

        Returns:
            torch.Tensor: Output tensor (#batch, channels*2).
        """
        mean = torch.mean(x, dim=2)
        # Note: the variance is pooled here; the x-vector paper pools the
        # standard deviation instead.
        var = torch.var(x, dim=2)
        x = torch.cat((mean, var), dim=1)
        return x
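
# A minimal usage sketch (added for illustration; the shapes are assumptions):
# TSP concatenates the per-channel mean and variance over frames, so the
# channel dimension doubles, e.g.
#   x = torch.randn(4, 256, 200)            # (batch, channels, frames)
#   TemporalStatisticsPooling()(x).shape    # -> torch.Size([4, 512])
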
class SelfAttentivePooling(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        """SAP: self-attentive pooling, an attention-weighted mean over frames.

        Conv1d with kernel_size == 1 is used rather than Linear so the
        (#batch, channels, frames) input does not need to be transposed.
        The attention bottleneck dimension defaults to 128.
        """
        super(SelfAttentivePooling, self).__init__()
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # Don't use ReLU here: in experiments, ReLU made convergence harder.
        alpha = torch.tanh(self.linear1(x))
        # Softmax over the frame axis yields per-frame attention weights.
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        # Attention-weighted mean over frames.
        mean = torch.sum(alpha * x, dim=2)
        return mean
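
# A minimal usage sketch (added for illustration; the shapes are assumptions):
# SAP's softmax runs over the frame axis, so each channel's attention weights
# sum to 1 across frames and the output is a weighted mean, e.g.
#   sap = SelfAttentivePooling(in_dim=256)
#   sap(torch.randn(4, 256, 200)).shape     # -> torch.Size([4, 256])
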
class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        """ASP: attentive statistics pooling (attention-weighted mean and std).

        Paper: Attentive Statistics Pooling for Deep Speaker Embedding
        Link: https://arxiv.org/pdf/1803.10963.pdf
        """
        super().__init__()
        # Use Conv1d with kernel_size == 1 rather than Linear, then we don't need to transpose inputs.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # Don't use ReLU here: in experiments, ReLU made convergence harder.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        # Weighted mean and weighted variance, Var[x] = E[x^2] - E[x]^2 under
        # the attention weights; the clamp guards against small negative values
        # from floating-point error before the square root.
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
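

if __name__ == "__main__":
    # Minimal smoke test (added for illustration, not part of the original
    # file); batch=4, channels=256, frames=200 are assumed example shapes.
    x = torch.randn(4, 256, 200)
    print(TemporalAveragePooling()(x).shape)          # torch.Size([4, 256])
    print(TemporalStatisticsPooling()(x).shape)       # torch.Size([4, 512])
    print(SelfAttentivePooling(in_dim=256)(x).shape)  # torch.Size([4, 256])
    print(AttentiveStatsPool(in_dim=256)(x).shape)    # torch.Size([4, 512])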