"""Miscellaneous helpers: argument printing/registration, attribute-style
dictionaries, and accuracy / cosine-similarity utilities."""

try:
    # ``distutils`` was removed from the standard library in Python 3.12
    # (PEP 632).  The import is kept (guarded) only so the module-level name
    # stays available on interpreters that still ship it; nothing below
    # depends on it any more.
    import distutils.util  # noqa: F401
except ImportError:
    distutils = None

import numpy as np
from tqdm import tqdm

from mvector.utils.logger import setup_logger

logger = setup_logger(__name__)

_SEPARATOR = "------------------------------------------------"
_TRUE_STRINGS = frozenset(("y", "yes", "t", "true", "on", "1"))
_FALSE_STRINGS = frozenset(("n", "no", "f", "false", "off", "0"))


def _strtobool(val):
    """Convert a truth-value string to 1 or 0.

    Drop-in replacement for the removed ``distutils.util.strtobool``: accepts
    y/yes/t/true/on/1 and n/no/f/false/off/0 (case-insensitive), returns the
    integers 1 or 0, and raises ``ValueError`` for anything else.
    """
    val = val.lower()
    if val in _TRUE_STRINGS:
        return 1
    if val in _FALSE_STRINGS:
        return 0
    raise ValueError("invalid truth value %r" % (val,))


def print_arguments(args=None, configs=None):
    """Log command-line arguments and/or configuration values, sorted by key.

    Args:
        args: An object whose ``vars()`` yields the extra command-line
            arguments (e.g. an ``argparse.Namespace``); skipped when falsy.
        configs: A mapping of configuration values; nested dicts are expanded
            up to two levels with tab indentation. Skipped when falsy.
    """
    # NOTE: messages are formatted eagerly with ``%`` (not passed as lazy
    # logger arguments) because the formatting style of the logger returned by
    # ``setup_logger`` is not visible from this module.
    if args:
        logger.info("----------- 额外配置参数 -----------")
        for arg, value in sorted(vars(args).items()):
            logger.info("%s: %s" % (arg, value))
        logger.info(_SEPARATOR)
    if configs:
        logger.info("----------- 配置文件参数 -----------")
        for arg, value in sorted(configs.items()):
            if not isinstance(value, dict):
                logger.info("%s: %s" % (arg, value))
                continue
            logger.info(f"{arg}:")
            for a, v in sorted(value.items()):
                if not isinstance(v, dict):
                    logger.info("\t%s: %s" % (a, v))
                    continue
                logger.info(f"\t{a}:")
                for a1, v1 in sorted(v.items()):
                    logger.info("\t\t%s: %s" % (a1, v1))
        logger.info(_SEPARATOR)


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register ``--<argname>`` on ``argparser``.

    Args:
        argname: Option name without the leading ``--``.
        type: Converter for the option value. ``bool`` is replaced by a
            string-to-truth-value parser, because ``bool("False")`` is True.
        default: Default value of the option.
        help: Help text; the default value is appended to it.
        argparser: Parser (or group) exposing ``add_argument``.
        **kwargs: Forwarded unchanged to ``add_argument``.
    """
    # Parameter names ``type`` and ``help`` shadow builtins but are part of
    # the public signature, so they are kept as-is.
    type = _strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' 默认: %(default)s.',
                           **kwargs)


class _MissingKeyAttributeError(AttributeError, KeyError):
    """Raised by ``Dict`` for a missing attribute.

    Subclasses ``AttributeError`` so ``hasattr``/``getattr(..., default)``/
    ``copy``/``pickle`` behave correctly, and ``KeyError`` so callers that
    already catch ``KeyError`` keep working.
    """


class Dict(dict):
    """A ``dict`` whose items can also be read and written as attributes."""

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails; fall back to the
        # item of the same name.
        try:
            return self[name]
        except KeyError:
            raise _MissingKeyAttributeError(name) from None


def dict_to_object(dict_obj):
    """Recursively convert a ``dict`` (and nested dict values) into ``Dict``.

    Non-dict inputs are returned unchanged; dicts nested inside other
    containers (e.g. lists) are not converted.
    """
    if not isinstance(dict_obj, dict):
        return dict_obj
    inst = Dict()
    for k, v in dict_obj.items():
        inst[k] = dict_to_object(v)
    return inst


# Compute the accuracy and the best threshold from the cosine scores.
def cal_accuracy_threshold(y_score, y_true):
    """Search thresholds 0.00, 0.01, ..., 0.99 for the highest accuracy.

    Args:
        y_score: Array-like of scores.
        y_true: Array-like of ground-truth labels, compared element-wise with
            ``y_score >= threshold``.

    Returns:
        ``(best_accuracy, best_threshold)``. On ties the lowest threshold
        wins; both stay ``0`` if no threshold yields an accuracy above zero.
    """
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_accuracy = 0
    best_threshold = 0
    for i in tqdm(range(0, 100)):
        threshold = i * 0.01
        acc = cal_accuracy(y_score, y_true, threshold)
        # Strict ">" keeps the first (lowest) threshold among equals.
        if acc > best_accuracy:
            best_accuracy = acc
            best_threshold = threshold
    return best_accuracy, best_threshold


# Compute the accuracy from the cosine scores.
def cal_accuracy(y_score, y_true, threshold=0.5):
    """Return the fraction of entries where ``y_score >= threshold`` equals
    ``y_true``."""
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    y_test = (y_score >= threshold)
    accuracy = np.mean((y_test == y_true).astype(int))
    return accuracy


# Compute the cosine similarity.
def cosin_metric(x1, x2):
    """Return ``dot(x1, x2) / (||x1|| * ||x2||)``.

    NOTE(review): there is no guard for zero-norm inputs; the division is
    performed as-is.
    """
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))