1/create_data.py
2025-04-18 19:56:58 +08:00

105 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import time
from multiprocessing import Pool, cpu_count
from datetime import timedelta
from pydub import AudioSegment
# Build the data list
def get_data_list(infodata_path, zhvoice_path):
    """Read the JSON-lines annotation file and build ``[sound_path, label]`` rows.

    Each non-blank line of ``infodata_path`` is parsed as a JSON object with
    the keys ``duration_ms``, ``speaker`` and ``index``. Records whose
    ``duration_ms`` is below 1300 are dropped. Every remaining speaker gets an
    integer label in order of first appearance (0, 1, 2, ...).

    Args:
        infodata_path: Path of the UTF-8 annotation file, one JSON object per line.
        zhvoice_path: Directory that each record's ``index`` is joined onto.

    Returns:
        A list of ``[sound_path, label]`` lists; ``sound_path`` always uses
        forward slashes.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    print('正在读取标注文件...')
    data = []
    # speaker name -> integer label; a dict keeps the lookup O(1) per record.
    speaker_labels = {}
    with open(infodata_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            if not raw_line:
                # Tolerate blank lines (e.g. an extra newline at end of file).
                continue
            record = json.loads(raw_line)
            if record['duration_ms'] < 1300:
                # Skipped before a label is assigned, so speakers that only
                # have short clips never consume a label number.
                continue
            label = speaker_labels.setdefault(record['speaker'], len(speaker_labels))
            sound_path = os.path.join(zhvoice_path, record['index'])
            # Normalise Windows separators so the list is portable.
            data.append([sound_path.replace('\\', '/'), label])
    print(f'一共有{len(data)}条数据!')
    return data
def mp32wav(num, data_list):
    """Convert the MP3 files listed in ``data_list`` to WAV, in place.

    For every ``(sound_path, label)`` entry whose file exists and whose WAV
    counterpart (``.mp3`` replaced by ``.wav`` in the path) does not, the MP3
    is decoded, exported as WAV and then deleted. A progress line is printed
    on every 100th item.

    Args:
        num: Identifier shown in the progress output.
        data_list: Sequence of two-item ``(sound_path, label)`` entries.
    """
    total = len(data_list)
    last_report = time.time()
    for idx, (mp3_path, _label) in enumerate(data_list):
        wav_path = mp3_path.replace('.mp3', '.wav')
        if os.path.exists(mp3_path) and not os.path.exists(wav_path):
            AudioSegment.from_mp3(mp3_path).export(wav_path, format="wav")
            os.remove(mp3_path)
        if idx % 100 != 0:
            continue
        # Estimate: time since the previous report is treated as the cost of
        # 100 items and scaled to the number of items still to go.
        remaining_sec = (time.time() - last_report) / 100 * (total - idx)
        last_report = time.time()
        eta_text = str(timedelta(seconds=int(remaining_sec)))
        print(f'进程{num}进度:[{idx}/{total}],剩余时间:{eta_text}')
def split_data(list_temp, n):
    """Yield consecutive slices of ``list_temp``, producing at most ``n`` chunks.

    The chunk length is ``ceil(len(list_temp) / n)``, clamped to at least 1,
    so the input never yields more than ``n`` chunks and a list shorter than
    ``n`` (or an empty list) is handled instead of failing on a zero step.

    Args:
        list_temp: Sliceable sequence to split.
        n: Maximum number of chunks; must be >= 1.

    Yields:
        Slices of ``list_temp`` that together cover it in order.

    Raises:
        ValueError: If ``n`` is smaller than 1 (raised on first iteration).
    """
    if n < 1:
        raise ValueError('n must be a positive integer')
    # Ceiling division without floats; max() guards against a zero step when
    # the list is empty.
    length = max(1, -(-len(list_temp) // n))
    for i in range(0, len(list_temp), length):
        yield list_temp[i:i + length]
def main(infodata_path, list_path, zhvoice_path, to_wav=True, num_workers=2):
    """Build ``train_list.txt`` and ``test_list.txt`` under ``list_path``.

    Asks for confirmation on stdin first; any answer other than ``y`` returns
    without doing anything. With ``to_wav`` the MP3 files are converted to WAV
    in a process pool before listing. Only files that exist on disk are
    listed. Every 200th entry (indices 0, 200, 400, ...) goes to the test
    list, the rest to the train list, one ``path<TAB>label`` per line.

    Args:
        infodata_path: Annotation file passed to ``get_data_list``.
        list_path: Directory in which the two list files are written.
        zhvoice_path: Dataset root passed to ``get_data_list``.
        to_wav: Convert MP3 files to WAV before building the lists.
        num_workers: Upper bound on the number of conversion processes.
    """
    if to_wav:
        prompt = '音频文件将会转换为wav格式这个过程可能很长而且最终文件大小接近100G是否继续(y/n)'
    else:
        prompt = '将会直接使用MP3格式文件但读取速度会比wav格式慢是否继续(y/n)'
    # Only an exact 'y' continues.
    if input(prompt) != 'y':
        return

    data = get_data_list(infodata_path=infodata_path, zhvoice_path=zhvoice_path)
    if to_wav:
        print('准备把MP3总成WAV格式...')
        if data:
            # Never ask for more chunks than there are items, so the chunk
            # length computed by split_data cannot drop to zero.
            workers = max(1, min(num_workers, len(data)))
            pool = Pool(workers)
            try:
                for i, chunk in enumerate(split_data(data, workers)):
                    pool.apply_async(
                        mp32wav,
                        (i, chunk),
                        # Without a callback a failing worker is silently
                        # ignored; report it but keep the other chunks going.
                        # idx is bound as a default to avoid late binding.
                        error_callback=lambda err, idx=i: print(f'进程{idx}转换失败:{err!r}'),
                    )
            finally:
                pool.close()
                pool.join()
        candidates = [[path.replace('.mp3', '.wav'), label] for path, label in data]
    else:
        candidates = data
    data_all = [row for row in candidates if os.path.exists(row[0])]

    train_path = os.path.join(list_path, 'train_list.txt')
    test_path = os.path.join(list_path, 'test_list.txt')
    # NOTE(review): no explicit encoding is given, so the platform default is
    # used as before; the code that reads these lists is outside this file.
    with open(train_path, 'w') as f_train, open(test_path, 'w') as f_test:
        for i, (sound_path, label) in enumerate(data_all):
            target = f_test if i % 200 == 0 else f_train
            target.write(f'{sound_path}\t{label}\n')
if __name__ == '__main__':
    # Script entry point: build the lists from the MP3 files directly (no WAV
    # conversion), passing one worker per CPU core as num_workers.
    run_options = {
        'infodata_path': 'dataset/zhvoice/text/infodata.json',
        'list_path': 'dataset',
        'zhvoice_path': 'dataset/zhvoice',
        'to_wav': False,
        'num_workers': cpu_count(),
    }
    main(**run_options)