98 lines
3.7 KiB
Python
98 lines
3.7 KiB
Python
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class AdditiveAngularMargin(nn.Module):
|
|
def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
|
|
"""The Implementation of Additive Angular Margin (AAM) proposed
|
|
in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
|
|
(https://arxiv.org/abs/1906.07317)
|
|
|
|
Args:
|
|
margin (float, optional): margin factor. Defaults to 0.0.
|
|
scale (float, optional): scale factor. Defaults to 1.0.
|
|
easy_margin (bool, optional): easy_margin flag. Defaults to False.
|
|
"""
|
|
super(AdditiveAngularMargin, self).__init__()
|
|
self.margin = margin
|
|
self.scale = scale
|
|
self.easy_margin = easy_margin
|
|
|
|
self.cos_m = math.cos(self.margin)
|
|
self.sin_m = math.sin(self.margin)
|
|
self.th = math.cos(math.pi - self.margin)
|
|
self.mm = math.sin(math.pi - self.margin) * self.margin
|
|
|
|
def forward(self, outputs, targets):
|
|
cosine = outputs.float()
|
|
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
|
|
phi = cosine * self.cos_m - sine * self.sin_m
|
|
if self.easy_margin:
|
|
phi = torch.where(cosine > 0, phi, cosine)
|
|
else:
|
|
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
|
|
outputs = (targets * phi) + ((1.0 - targets) * cosine)
|
|
return self.scale * outputs
|
|
|
|
|
|
class AAMLoss(nn.Module):
|
|
def __init__(self, margin=0.2, scale=30, easy_margin=False):
|
|
super(AAMLoss, self).__init__()
|
|
self.loss_fn = AdditiveAngularMargin(margin=margin, scale=scale, easy_margin=easy_margin)
|
|
self.criterion = torch.nn.KLDivLoss(reduction="sum")
|
|
|
|
def forward(self, outputs, targets):
|
|
targets = F.one_hot(targets, outputs.shape[1]).float()
|
|
predictions = self.loss_fn(outputs, targets)
|
|
predictions = F.log_softmax(predictions, dim=1)
|
|
loss = self.criterion(predictions, targets) / targets.sum()
|
|
return loss
|
|
|
|
|
|
class AMLoss(nn.Module):
|
|
def __init__(self, margin=0.2, scale=30):
|
|
super(AMLoss, self).__init__()
|
|
self.m = margin
|
|
self.s = scale
|
|
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
|
|
|
def forward(self, outputs, targets):
|
|
label_view = targets.view(-1, 1)
|
|
delt_costh = torch.zeros(outputs.size(), device=targets.device).scatter_(1, label_view, self.m)
|
|
costh_m = outputs - delt_costh
|
|
predictions = self.s * costh_m
|
|
loss = self.criterion(predictions, targets) / targets.shape[0]
|
|
return loss
|
|
|
|
|
|
class ARMLoss(nn.Module):
|
|
def __init__(self, margin=0.2, scale=30):
|
|
super(ARMLoss, self).__init__()
|
|
self.m = margin
|
|
self.s = scale
|
|
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
|
|
|
def forward(self, outputs, targets):
|
|
label_view = targets.view(-1, 1)
|
|
delt_costh = torch.zeros(outputs.size(), device=targets.device).scatter_(1, label_view, self.m)
|
|
costh_m = outputs - delt_costh
|
|
costh_m_s = self.s * costh_m
|
|
delt_costh_m_s = costh_m_s.gather(1, label_view).repeat(1, costh_m_s.size()[1])
|
|
costh_m_s_reduct = costh_m_s - delt_costh_m_s
|
|
predictions = torch.where(costh_m_s_reduct < 0.0, torch.zeros_like(costh_m_s), costh_m_s)
|
|
loss = self.criterion(predictions, targets) / targets.shape[0]
|
|
return loss
|
|
|
|
|
|
class CELoss(nn.Module):
|
|
def __init__(self):
|
|
super(CELoss, self).__init__()
|
|
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
|
|
|
def forward(self, outputs, targets):
|
|
loss = self.criterion(outputs, targets) / targets.shape[0]
|
|
return loss
|