105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
from multiprocessing import Pool, cpu_count
|
|||
|
|
from datetime import timedelta
|
|||
|
|
|
|||
|
|
from pydub import AudioSegment
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Build the data list
def get_data_list(infodata_path, zhvoice_path):
    """Read the annotation file and build ``[sound_path, label]`` pairs.

    Every non-blank line of ``infodata_path`` is parsed as one JSON object
    with the keys ``duration_ms``, ``speaker`` and ``index``. Entries whose
    ``duration_ms`` is below 1300 are dropped. Speakers are numbered from 0
    in order of first appearance among the kept entries; that number is the
    label of each of the speaker's entries.

    Args:
        infodata_path: Path of the UTF-8 annotation file (one JSON object
            per line).
        zhvoice_path: Directory that each entry's ``index`` is joined onto.

    Returns:
        A list of ``[sound_path, label]`` lists. Backslashes in the joined
        path are replaced with forward slashes.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an entry lacks one of the expected keys.
    """
    print('正在读取标注文件...')

    data = []
    # Speaker name -> integer label. The dict doubles as the "seen" set, so
    # the membership test is O(1) per line instead of scanning a list.
    speakers_dict = {}
    with open(infodata_path, 'r', encoding='utf-8') as f:
        # Stream the file instead of loading every line into memory first.
        for raw_line in f:
            raw_line = raw_line.strip()
            # A blank line (e.g. a trailing newline) is not valid JSON.
            if not raw_line:
                continue
            info = json.loads(raw_line)
            if info['duration_ms'] < 1300:
                continue
            # len() is evaluated before insertion, so a new speaker receives
            # the next unused label.
            label = speakers_dict.setdefault(info['speaker'], len(speakers_dict))
            sound_path = os.path.join(zhvoice_path, info['index'])
            data.append([sound_path.replace('\\', '/'), label])
    print(f'一共有{len(data)}条数据!')
    return data
|
|||
|
|
|
|||
|
|
|
|||
|
|
def mp32wav(num, data_list):
    """Convert the MP3 files named in ``data_list`` to WAV, deleting each MP3.

    For every path that exists and whose WAV counterpart (the path with
    ``.mp3`` replaced by ``.wav``) does not exist yet, the audio is decoded
    with pydub, written as WAV, and the MP3 is removed. A progress line with
    an estimated remaining time is printed every 100 items.

    Args:
        num: Worker number; used only in the progress message.
        data_list: Sequence of ``(sound_path, label)`` pairs. Only the path
            is used here.
    """
    total = len(data_list)
    start = time.time()
    for i, (sound_path, _label) in enumerate(data_list):
        if os.path.exists(sound_path):
            save_path = sound_path.replace('.mp3', '.wav')
            if not os.path.exists(save_path):
                # Export under a temporary name and rename afterwards, so an
                # interrupted export never leaves a truncated file at
                # save_path that a later run would mistake for a finished one.
                tmp_path = save_path + '.tmp'
                wav = AudioSegment.from_mp3(sound_path)
                # export() hands back the output file handle; close it
                # explicitly so the rename also works where open files
                # cannot be renamed.
                wav.export(tmp_path, format="wav").close()
                os.replace(tmp_path, save_path)
                # The MP3 is deleted only after its WAV is fully in place.
                os.remove(sound_path)
        if i % 100 == 0:
            # The timing window holds 100 items, except for the very first
            # report, which has seen a single item.
            window = 100 if i else 1
            eta_sec = (time.time() - start) / window * (total - i)
            start = time.time()
            eta_str = str(timedelta(seconds=int(eta_sec)))
            print(f'进程{num}进度:[{i}/{total}],剩余时间:{eta_str}')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def split_data(list_temp, n):
    """Yield ``list_temp`` as at most ``n`` consecutive slices.

    The chunk size is ``ceil(len(list_temp) / n)``, so the slices cover the
    whole list in order, every slice except possibly the last has the same
    size, and never more than ``n`` slices are produced. An empty list
    yields nothing.

    Args:
        list_temp: Sequence to split (anything supporting ``len`` and slicing).
        n: Maximum number of slices; must be at least 1.

    Yields:
        Consecutive slices of ``list_temp``.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')
    if not list_temp:
        return
    # Ceiling division: flooring would give a step of 0 when there are fewer
    # items than n, and an extra (n + 1)-th chunk whenever a remainder exists.
    length = (len(list_temp) + n - 1) // n
    for i in range(0, len(list_temp), length):
        yield list_temp[i:i + length]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main(infodata_path, list_path, zhvoice_path, to_wav=True, num_workers=2):
    """Build ``train_list.txt`` and ``test_list.txt`` for the dataset.

    The user is asked to confirm first; anything other than ``y`` (ignoring
    case and surrounding whitespace) aborts without doing any work. The
    annotation file is then read with ``get_data_list``. When ``to_wav`` is
    true, the MP3 files are converted to WAV in a process pool and the listed
    paths use the ``.wav`` name. Only files that exist on disk are listed.
    Every 200th listed entry (positions 0, 200, 400, ...) is written to the
    test list and the rest to the train list, one tab-separated
    ``path<TAB>label`` per line.

    Args:
        infodata_path: Path of the annotation file passed to ``get_data_list``.
        list_path: Directory in which the two list files are written.
        zhvoice_path: Dataset root passed to ``get_data_list``.
        to_wav: Convert MP3 files to WAV before listing them.
        num_workers: Upper bound on the number of conversion processes.
    """
    if to_wav:
        prompt = '音频文件将会转换为wav格式,这个过程可能很长,而且最终文件大小接近100G,是否继续?(y/n)'
    else:
        prompt = '将会直接使用MP3格式文件,但读取速度会比wav格式慢,是否继续?(y/n)'
    if input(prompt).strip().lower() != 'y':
        return

    data = get_data_list(infodata_path=infodata_path, zhvoice_path=zhvoice_path)
    if to_wav:
        print('准备把MP3总成WAV格式...')
        if data:
            # Never ask for more chunks than there are items; that keeps the
            # per-chunk size at 1 or more.
            workers = max(1, min(num_workers, len(data)))
            with Pool(workers) as pool:
                results = [pool.apply_async(mp32wav, (i, chunk))
                           for i, chunk in enumerate(split_data(data, workers))]
                pool.close()
                pool.join()
                # Fetch every result so a failure inside a worker is reported
                # rather than silently discarded. Files that were converted
                # are still listed below, so one failure does not abort.
                for i, result in enumerate(results):
                    try:
                        result.get()
                    except Exception as exc:
                        print(f'进程{i}转换失败:{exc!r}')
        data = [[sound_path.replace('.mp3', '.wav'), label]
                for sound_path, label in data]

    # Keep only the entries whose audio file is actually present.
    data_all = [d for d in data if os.path.exists(d[0])]

    train_path = os.path.join(list_path, 'train_list.txt')
    test_path = os.path.join(list_path, 'test_list.txt')
    # "with" guarantees both files are closed even if a write fails.
    with open(train_path, 'w') as f_train, open(test_path, 'w') as f_test:
        for i, (sound_path, label) in enumerate(data_all):
            target = f_test if i % 200 == 0 else f_train
            target.write(f'{sound_path}\t{label}\n')
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Script entry point: build the list files from the zhvoice dataset,
    # keeping the MP3 files as they are (no WAV conversion) and allowing one
    # worker per CPU.
    main(
        'dataset/zhvoice/text/infodata.json',
        'dataset',
        'dataset/zhvoice',
        to_wav=False,
        num_workers=cpu_count(),
    )
|