1/mvector/utils/utils.py

86 lines
2.6 KiB
Python
Raw Normal View History

2025-04-18 19:56:58 +08:00
import distutils.util
import numpy as np
from tqdm import tqdm
from mvector.utils.logger import setup_logger
logger = setup_logger(__name__)
def print_arguments(args=None, configs=None):
    """Log command-line arguments and/or a configuration mapping.

    Args:
        args: Object whose attributes (read through ``vars``) are logged one
            per line in sorted name order. Skipped when falsy.
        configs: Mapping logged in sorted key order. Values that are dicts
            are expanded up to two nested levels, each level indented by one
            tab; anything deeper is logged with its plain string form.
            Skipped when falsy.
    """
    separator = "------------------------------------------------"

    if args:
        logger.info("----------- 额外配置参数 -----------")
        for name, val in sorted(vars(args).items()):
            logger.info("%s: %s" % (name, val))
        logger.info(separator)

    if not configs:
        return

    logger.info("----------- 配置文件参数 -----------")
    for section, content in sorted(configs.items()):
        # Scalars at the top level are printed directly.
        if not isinstance(content, dict):
            logger.info("%s: %s" % (section, content))
            continue

        logger.info(f"{section}:")
        for key, item in sorted(content.items()):
            if not isinstance(item, dict):
                logger.info("\t%s: %s" % (key, item))
                continue

            # Second nesting level: one more tab of indentation.
            logger.info(f"\t{key}:")
            for sub_key, sub_item in sorted(item.items()):
                logger.info("\t\t%s: %s" % (sub_key, sub_item))
    logger.info(separator)
def _strtobool(value):
    """Convert a truth-value string to ``1`` or ``0``.

    Same contract as ``distutils.util.strtobool``: the comparison is
    case-insensitive, the result is an int, and unknown strings raise
    ``ValueError``. It is implemented locally because ``distutils`` was
    removed from the standard library in Python 3.12.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError("invalid truth value %r" % (normalized,))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--<argname>`` option on an argument parser.

    Args:
        argname: Option name without the leading ``--``.
        type: Callable used to convert the command-line string. ``bool`` is
            replaced by a truth-string parser, because ``bool("False")`` is
            ``True`` and would make every non-empty value truthy.
        default: Default value for the option.
        help: Help text; the default value is appended to it.
        argparser: Parser object providing ``add_argument``.
        **kwargs: Extra keyword arguments forwarded to ``add_argument``.
    """
    type = _strtobool if type is bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' 默认: %(default)s.',
        **kwargs,
    )
class Dict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``obj.key = value`` stores ``obj["key"]`` and ``obj.key`` reads it back.
    """

    # Every attribute assignment is stored as a dictionary item.
    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        """Return the item stored under ``name``.

        Raises:
            AttributeError: If ``name`` is not a key. Raising ``KeyError``
                here would break ``hasattr``, ``getattr`` with a default and
                ``copy.deepcopy``, which all rely on ``AttributeError`` to
                detect a missing attribute.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
def dict_to_object(dict_obj):
    """Recursively convert a dict into a ``Dict``.

    Every value that is itself a dict is converted the same way. Any
    argument that is not a dict (including a list that contains dicts) is
    returned unchanged.
    """
    if isinstance(dict_obj, dict):
        return Dict((key, dict_to_object(val)) for key, val in dict_obj.items())
    return dict_obj
def cal_accuracy_threshold(y_score, y_true):
    """Find the decision threshold that gives the highest accuracy.

    Thresholds 0.00, 0.01, ..., 0.99 are tried in order. A score counts as a
    positive prediction when it is ``>=`` the threshold, and accuracy is the
    fraction of predictions equal to ``y_true``.
    NOTE(review): scores are presumably cosine similarities - confirm.

    Args:
        y_score: Array-like of scores.
        y_true: Array-like of ground-truth labels compared against the
            boolean predictions.

    Returns:
        Tuple ``(best_accuracy, best_threshold)``. On ties the lowest
        threshold wins; ``(0, 0)`` if no threshold scores above zero.
    """
    scores = np.asarray(y_score)
    labels = np.asarray(y_true)

    best = (0, 0)  # (accuracy, threshold)
    for step in tqdm(range(100)):
        candidate = step * 0.01
        predicted = scores >= candidate
        accuracy = np.mean((predicted == labels).astype(int))
        # Strict comparison keeps the earliest threshold on ties.
        if accuracy > best[0]:
            best = (accuracy, candidate)
    return best
def cal_accuracy(y_score, y_true, threshold=0.5):
    """Compute accuracy at a fixed decision threshold.

    Args:
        y_score: Array-like of scores.
        y_true: Array-like of ground-truth labels.
        threshold: A score ``>= threshold`` counts as a positive prediction.

    Returns:
        Fraction of predictions that equal ``y_true``.
    """
    scores = np.asarray(y_score)
    labels = np.asarray(y_true)
    matches = (scores >= threshold) == labels
    return np.mean(matches.astype(int))
def cosin_metric(x1, x2):
    """Return the cosine similarity of two vectors.

    Computed as the dot product of ``x1`` and ``x2`` divided by the product
    of their norms.
    """
    inner = np.dot(x1, x2)
    scale = np.linalg.norm(x1) * np.linalg.norm(x2)
    return inner / scale