1/mvector/models/loss.py

98 lines
3.7 KiB
Python
Raw Normal View History

2025-04-18 19:56:58 +08:00
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdditiveAngularMargin(nn.Module):
    """Apply an additive angular margin (AAM) to cosine-similarity scores.

    Proposed in 'Margin Matters: Towards More Discriminative Deep Neural
    Network Embeddings for Speaker Recognition'
    (https://arxiv.org/abs/1906.07317).
    """

    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        """Store the margin settings and precompute the trigonometric constants.

        Args:
            margin (float, optional): margin factor. Defaults to 0.0.
            scale (float, optional): scale factor. Defaults to 1.0.
            easy_margin (bool, optional): easy_margin flag. Defaults to False.
        """
        super(AdditiveAngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        # cos(m) and sin(m) are used to expand cos(theta + m) in forward().
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # Threshold on cos(theta) and the linear fallback offset used when the
        # cosine is at or below that threshold (non-easy-margin branch).
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        """Return scaled scores with the margin applied where ``targets`` is set.

        Args:
            outputs: scores treated as cosine values (cast to float32 here).
            targets: weights combined elementwise with ``outputs``; with a
                one-hot mask the margin-adjusted score is used at the target
                positions and the plain cosine everywhere else.
                Presumably the same shape as ``outputs`` -- confirm with caller.

        Returns:
            Tensor of ``scale``-multiplied scores.
        """
        cosine = outputs.float()
        # Rounding can push |cosine| marginally above 1, which would make the
        # radicand negative and turn the whole output into NaN; clamp at 0.
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), min=0.0))
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            # Only apply the margin to positive cosines.
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # At or below the threshold, fall back to a linear penalty.
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs
class AAMLoss(nn.Module):
    """Classification loss computed on additive-angular-margin adjusted scores."""

    def __init__(self, margin=0.2, scale=30, easy_margin=False):
        """Build the margin transform and the summed KL-divergence criterion.

        Args:
            margin: margin forwarded to ``AdditiveAngularMargin``.
            scale: scale forwarded to ``AdditiveAngularMargin``.
            easy_margin: easy-margin flag forwarded to ``AdditiveAngularMargin``.
        """
        super().__init__()
        self.loss_fn = AdditiveAngularMargin(
            margin=margin,
            scale=scale,
            easy_margin=easy_margin,
        )
        self.criterion = torch.nn.KLDivLoss(reduction="sum")

    def forward(self, outputs, targets):
        """Return the loss for ``outputs`` given integer class ``targets``.

        Args:
            outputs: 2-D score tensor; its second dimension is the class count.
            targets: integer class indices, one-hot encoded internally.

        Returns:
            Summed KL divergence between the log-softmax of the margin-adjusted
            scores and the one-hot targets, divided by the sum of the one-hot
            mask.
        """
        num_classes = outputs.shape[1]
        one_hot = F.one_hot(targets, num_classes).float()
        margin_logits = self.loss_fn(outputs, one_hot)
        log_probs = F.log_softmax(margin_logits, dim=1)
        summed = self.criterion(log_probs, one_hot)
        return summed / one_hot.sum()
class AMLoss(nn.Module):
    """Cross-entropy on scores with a margin subtracted at the target class."""

    def __init__(self, margin=0.2, scale=30):
        """Store the margin ``m`` and scale ``s`` and build the summed criterion.

        Args:
            margin: value subtracted from the target-class score.
            scale: multiplier applied to the margin-adjusted scores.
        """
        super().__init__()
        self.m = margin
        self.s = scale
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def forward(self, outputs, targets):
        """Return the batch-averaged loss.

        Args:
            outputs: 2-D score tensor indexed as (sample, class).
            targets: integer class index for each sample.

        Returns:
            Summed cross-entropy divided by ``targets.shape[0]``.
        """
        batch_size = targets.shape[0]
        class_idx = targets.view(-1, 1)

        # Build a tensor that holds the margin at each sample's target column
        # and zero everywhere else.
        margin_at_target = torch.zeros(outputs.size(), device=targets.device)
        margin_at_target.scatter_(1, class_idx, self.m)

        scaled_logits = self.s * (outputs - margin_at_target)
        summed = self.criterion(scaled_logits, targets)
        return summed / batch_size
class ARMLoss(nn.Module):
    """Margin-based cross-entropy that zeroes scores below the target score."""

    def __init__(self, margin=0.2, scale=30):
        """Store the margin ``m`` and scale ``s`` and build the summed criterion.

        Args:
            margin: value subtracted from the target-class score.
            scale: multiplier applied to the margin-adjusted scores.
        """
        super().__init__()
        self.m = margin
        self.s = scale
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def forward(self, outputs, targets):
        """Return the batch-averaged loss.

        Args:
            outputs: 2-D score tensor indexed as (sample, class).
            targets: integer class index for each sample.

        Returns:
            Summed cross-entropy divided by ``targets.shape[0]``.
        """
        batch_size = targets.shape[0]
        class_idx = targets.view(-1, 1)

        # Subtract the margin at each sample's target column, then scale.
        margin_at_target = torch.zeros(outputs.size(), device=targets.device)
        margin_at_target.scatter_(1, class_idx, self.m)
        scaled = self.s * (outputs - margin_at_target)

        # Compare every score against its row's target score (broadcast over
        # the class dimension) and zero the entries that are strictly smaller.
        target_score = scaled.gather(1, class_idx)
        below_target = (scaled - target_score) < 0.0
        predictions = scaled.masked_fill(below_target, 0.0)

        return self.criterion(predictions, targets) / batch_size
class CELoss(nn.Module):
    """Cross-entropy loss averaged over the batch."""

    def __init__(self):
        """Build the summed cross-entropy criterion."""
        super().__init__()
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, outputs, targets):
        """Return the summed cross-entropy divided by ``targets.shape[0]``.

        Args:
            outputs: score tensor passed to the criterion unchanged.
            targets: class targets; the first dimension is used as batch size.
        """
        batch_size = targets.shape[0]
        summed = self.criterion(outputs, targets)
        return summed / batch_size