1/train.py
2025-04-18 19:56:58 +08:00

27 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import functools
from mvector.trainer import MVectorTrainer
from mvector.utils.utils import add_arguments, print_arguments
def _none_if_literal_none(value):
    """Map the command-line text ``None`` (any case) to Python ``None``.

    argparse hands ``--pretrained_model None`` to the program as the string
    ``'None'``, which would otherwise be forwarded to the trainer as a path.
    The help text of the model-path options documents ``None`` as "do not
    use a model", so this conversion makes that choice reachable from the
    command line. Any other value (including a real ``None`` default) is
    returned unchanged.
    """
    if isinstance(value, str) and value.strip().lower() == 'none':
        return None
    return value


parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs',           str,  'configs/ecapa_tdnn.yml',    '配置文件')
# NOTE(review): not read by this script; presumably supplied by a distributed
# launcher for multi-GPU training (per its help text) — confirm.
add_arg("local_rank",        int,  0,                           '多卡训练需要的参数')
add_arg("use_gpu",           bool, True,                        '是否使用GPU训练')
add_arg('augment_conf_path', str,  'configs/augmentation.json', '数据增强的配置文件为json格式')
add_arg('save_model_path',   str,  'models/',                   '模型保存的路径')
add_arg('resume_model',      str,  None,                        '恢复训练当为None则不使用预训练模型')
# NOTE(review): parsed but never passed to trainer.train() below; kept so that
# existing command lines which supply it continue to work.
add_arg('save_image_path',   str,  'output/images/',            "保存结果图的路径")
add_arg('pretrained_model',  str,  'models/ecapa_tdnn_MFCC/best_model/',
        '预训练模型的路径当为None则不使用预训练模型')
args = parser.parse_args()

# Allow "--resume_model None" / "--pretrained_model None" to really disable
# loading, as the help texts promise; done before printing so the printed
# configuration reflects what the trainer actually receives.
args.resume_model = _none_if_literal_none(args.resume_model)
args.pretrained_model = _none_if_literal_none(args.pretrained_model)
print_arguments(args=args)

# Build the trainer from the given config file, then start training.
trainer = MVectorTrainer(configs=args.configs, use_gpu=args.use_gpu)
trainer.train(save_model_path=args.save_model_path,
              resume_model=args.resume_model,
              pretrained_model=args.pretrained_model,
              augment_conf_path=args.augment_conf_path)