import numpy as np

from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)


class TprAtFpr(object):
    """Accumulate labelled scores and evaluate TPR/FPR over a threshold grid.

    Scores added with label ``0`` go to the negative pool; any other label
    goes to the positive pool. ``calculate`` sweeps the thresholds
    0.00, 0.01, ..., 0.99 and reports TPR/FPR at each one.
    """

    def __init__(self, max_fpr=0.01):
        """
        :param max_fpr: stored as ``self.max_fpr``. It is not read by any
            method of this class; presumably callers consume it — verify.
        """
        self.pos_score_list = []
        self.neg_score_list = []
        self.max_fpr = max_fpr

    def add(self, y_labels, y_scores):
        """Append scores to the positive or negative pool according to labels.

        :param y_labels: iterable of labels; ``0`` marks a negative sample,
            any other value a positive one.
        :param y_scores: iterable of scores paired element-wise with
            ``y_labels`` (pairing stops at the shorter iterable, as ``zip`` does).
        """
        for y_label, y_score in zip(y_labels, y_scores):
            if y_label == 0:
                self.neg_score_list.append(y_score)
            else:
                self.pos_score_list.append(y_score)

    def reset(self):
        """Discard all accumulated positive and negative scores."""
        self.pos_score_list = []
        self.neg_score_list = []

    def calculate_eer(self, tprs, fprs):
        """Find the threshold index that minimises ``FPR + (1 - TPR)``.

        NOTE(review): despite the name, the returned "eer" is the *sum*
        FPR + FNR at the best threshold, not the classical equal error rate
        (the operating point where FPR == FNR). Kept for compatibility.

        :param tprs: true-positive rates, one per threshold.
        :param fprs: false-positive rates, aligned index-by-index with ``tprs``.
        :return: ``(eer, index, eer_list)`` where ``eer_list[i]`` is
            ``fprs[i] + (1 - tprs[i])``, ``index`` is the first position of
            the minimum and ``eer == eer_list[index]``. For empty input the
            legacy result ``(1.0, 0, [])`` is returned.
        """
        # Record FPR + FNR for every threshold.
        eer_list = [fpr + (1 - tpr) for tpr, fpr in zip(tprs, fprs)]
        if not eer_list:
            return 1.0, 0, eer_list
        # The sum lies in [0, 2], so a fixed 1.0 starting sentinel could miss
        # the true minimum and return an (eer, index) pair that disagree.
        # argmin picks the first occurrence on ties, like a strict '<' scan.
        index = int(np.argmin(eer_list))
        return eer_list[index], index, eer_list

    def calculate(self):
        """Compute TPR and FPR at thresholds 0.00..0.99 (step 0.01).

        A score counts as accepted when it is strictly greater than the
        threshold. Each threshold and its ``FPR + FNR`` value is printed to
        stdout.

        :return: ``(tprs, fprs, thresholds, eer, index)``, where ``eer`` and
            ``index`` come from :meth:`calculate_eer`. If either score pool is
            empty, a warning is logged and ``([], [], [], None, None)`` is
            returned.
        """
        tprs, fprs, thresholds = [], [], []
        pos_score_list = np.array(self.pos_score_list)
        neg_score_list = np.array(self.neg_score_list)
        if len(pos_score_list) == 0:
            msg = "The number of positive samples is 0, please add positive samples."
            logger.warning(msg)
            return tprs, fprs, thresholds, None, None
        if len(neg_score_list) == 0:
            msg = "The number of negative samples is 0, please add negative samples."
            logger.warning(msg)
            return tprs, fprs, thresholds, None, None
        num_pos = len(pos_score_list)
        num_neg = len(neg_score_list)
        # NOTE(review): the grid only spans [0, 0.99]; scores outside that
        # range (e.g. negative similarities) land on the same side of every
        # threshold — confirm the expected score range.
        for i in range(100):
            threshold = i / 100.
            tprs.append(np.sum(pos_score_list > threshold) / num_pos)
            fprs.append(np.sum(neg_score_list > threshold) / num_neg)
            thresholds.append(threshold)
        eer, index, eer_list = self.calculate_eer(fprs=fprs, tprs=tprs)
        # Print every threshold together with its FPR + FNR value.
        for threshold, value in zip(thresholds, eer_list):
            print(f"threshold: {threshold}, eer: {value}")
        return tprs, fprs, thresholds, eer, index