26 lines
940 B
Python
26 lines
940 B
Python
import numpy as np
|
||
import torch
|
||
|
||
|
||
# Collate one batch of (audio, label) samples into padded tensors.
def collate_fn(batch):
    """Pad a batch of variable-length samples into dense tensors.

    Samples are sorted by length (longest first) and zero-padded along their
    first axis to the length of the longest sample in the batch.

    Args:
        batch: Sequence of samples. ``sample[0]`` is an array-like with a
            ``.shape`` whose first axis is the sequence length; ``sample[1]``
            is an integer label. Any trailing axes must agree across samples
            (plain 1-D audio yields a ``(batch, max_len)`` output).

    Returns:
        Tuple ``(inputs, labels, input_lens_ratio)``:

        - ``inputs``: float32 tensor of shape ``(batch, max_len, *trailing)``,
          zero-padded after each sample's own length.
        - ``labels``: int64 tensor of shape ``(batch,)``, in the same sorted
          order as ``inputs``.
        - ``input_lens_ratio``: float32 tensor of shape ``(batch,)``; each
          sample's length divided by ``max_len`` (0.0 when ``max_len`` is 0).

        An empty ``batch`` yields tensors of shape ``(0, 0)``, ``(0,)`` and
        ``(0,)`` instead of raising.
    """
    if not batch:
        return (
            torch.zeros((0, 0), dtype=torch.float32),
            torch.zeros(0, dtype=torch.int64),
            torch.zeros(0, dtype=torch.float32),
        )

    # Longest sample first, so its length is the padding target.
    batch = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)
    max_audio_length = batch[0][0].shape[0]
    # Trailing axes (e.g. feature dims) are carried through; () for 1-D audio.
    trailing_shape = tuple(batch[0][0].shape[1:])
    batch_size = len(batch)

    # Pre-allocate zeros at the max length; copying each sample into the
    # front of its row leaves the remainder as padding.
    inputs = np.zeros((batch_size, max_audio_length) + trailing_shape, dtype='float32')
    input_lens_ratio = np.zeros(batch_size, dtype='float32')
    labels = []
    for i, sample in enumerate(batch):
        tensor = sample[0]
        labels.append(sample[1])
        seq_length = tensor.shape[0]
        inputs[i, :seq_length] = tensor
        # Guard against an all-empty batch (max length 0): ratio stays 0.0.
        if max_audio_length > 0:
            input_lens_ratio[i] = seq_length / max_audio_length

    labels = np.array(labels, dtype='int64')

    # The arrays are local, so sharing memory via from_numpy is safe and
    # avoids the extra copy torch.tensor() would make.
    return torch.from_numpy(inputs), torch.from_numpy(labels), torch.from_numpy(input_lens_ratio)