1/mvector/data_utils/collate_fn.py

26 lines
940 B
Python
Raw Normal View History

2025-04-18 19:56:58 +08:00
import numpy as np
import torch
# Collate one batch of (audio, label) samples into padded tensors.
def collate_fn(batch):
    """Pad a batch of variable-length samples into dense tensors.

    Samples are sorted by length (longest first) and right-padded with zeros
    up to the length of the longest sample.

    Args:
        batch: Sequence of samples where ``sample[0]`` is an array-like with a
            ``shape`` attribute whose first axis is the length axis, and
            ``sample[1]`` is an integer label. Any trailing axes of
            ``sample[0]`` must match across samples and are kept as-is.

    Returns:
        Tuple ``(inputs, labels, input_lens_ratio)``:
            inputs: float32 tensor of shape ``(batch_size, max_len, *trailing)``,
                zero-padded, ordered longest sample first.
            labels: int64 tensor of shape ``(batch_size,)``, aligned with
                ``inputs``.
            input_lens_ratio: float32 tensor of shape ``(batch_size,)`` with
                each sample's length divided by ``max_len`` (1.0 for the
                longest). If every sample is empty (``max_len == 0``), all
                ratios are 1.0 instead of dividing by zero.

        An empty ``batch`` yields tensors of shape ``(0, 0)``, ``(0,)`` and
        ``(0,)``.
    """
    # Longest first, so batch[0] determines the padded length.
    batch = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)
    batch_size = len(batch)
    if batch_size == 0:
        # Nothing to pad; return empty tensors with the usual dtypes.
        return (torch.from_numpy(np.zeros((0, 0), dtype='float32')),
                torch.from_numpy(np.zeros(0, dtype='int64')),
                torch.from_numpy(np.zeros(0, dtype='float32')))

    longest = batch[0][0]
    max_audio_length = longest.shape[0]
    # Zero buffer of the padded shape; trailing feature axes (if any) are kept.
    inputs = np.zeros((batch_size, max_audio_length) + tuple(longest.shape[1:]),
                      dtype='float32')
    seq_lengths = []
    labels = []
    for row, sample in enumerate(batch):
        tensor = sample[0]
        seq_length = tensor.shape[0]
        # Copy into the zero buffer; the untouched tail is the padding.
        inputs[row, :seq_length] = tensor[:]
        seq_lengths.append(seq_length)
        labels.append(sample[1])

    if max_audio_length > 0:
        input_lens_ratio = np.array(
            [length / max_audio_length for length in seq_lengths], dtype='float32')
    else:
        # Every sample is empty: avoid 0/0 and keep "longest sample == 1.0".
        input_lens_ratio = np.ones(batch_size, dtype='float32')
    labels = np.array(labels, dtype='int64')
    # The arrays are local and freshly built, so share memory instead of copying.
    return (torch.from_numpy(inputs), torch.from_numpy(labels),
            torch.from_numpy(input_lens_ratio))