def clotho_collate_fn(batch: MutableSequence[ndarray],
                      nb_t_steps: Union[AnyStr, Tuple[int, int]],
                      input_pad_at: str,
                      output_pad_at: str,
                      pad_token: int = 4367) \
        -> Tuple[Tensor, Tensor]:
    """Pads or truncates a batch of (input, output) pairs to common lengths.

    :param batch: Batch data; each item unpacks to an input array
                  (indexed as [time, feature]) and an output array
                  (indexed as [time]).
    :type batch: list[numpy.ndarray]
    :param nb_t_steps: Number of time steps to pad/truncate to. Can use
                       'max', 'min', or exact numbers for input and
                       output, e.g. (1024, 10). Any string other than
                       'max' (case-insensitive) is treated as 'min'.
    :type nb_t_steps: str|(int, int)
    :param input_pad_at: Pad input at the start or at the end?
                         'start' (case-insensitive) pads at the start,
                         anything else pads at the end.
    :type input_pad_at: str
    :param output_pad_at: Pad output at the start or at the end?
                          Same convention as `input_pad_at`.
    :type output_pad_at: str
    :param pad_token: Value written into padded output positions.
                      Defaults to 4367, the value this function used to
                      hard-code. NOTE(review): presumably the padding
                      index of the vocabulary in use -- confirm.
    :type pad_token: int
    :return: Inputs stacked along a new leading batch axis (float) and
             outputs stacked along a new leading batch axis (long).
    :rtype: torch.Tensor, torch.Tensor
    """
    if isinstance(nb_t_steps, str):
        pick = max if nb_t_steps.lower() == 'max' else min
        in_t_steps = pick(item[0].shape[0] for item in batch)
        out_t_steps = pick(item[1].shape[0] for item in batch)
    else:
        in_t_steps, out_t_steps = nb_t_steps

    # Feature size is taken from the first item of the batch.
    in_dim = batch[0][0].shape[-1]
    pad_input_at_start = input_pad_at.lower() == 'start'
    pad_output_at_start = output_pad_at.lower() == 'start'

    input_rows, output_rows = [], []
    for in_b, out_b in batch:
        # Slicing is a no-op for sequences that are already short enough,
        # so one code path handles both truncation and padding.
        feats = from_numpy(in_b[:in_t_steps, :]).float()
        n_missing = in_t_steps - feats.shape[0]
        if n_missing > 0:
            filler = pt_zeros(n_missing, in_dim).float()
            feats = pt_cat(
                [filler, feats] if pad_input_at_start else [feats, filler])
        input_rows.append(feats.unsqueeze(0))

        tokens = from_numpy(out_b[:out_t_steps]).long()
        n_missing = out_t_steps - tokens.shape[0]
        if n_missing > 0:
            filler = pt_ones(n_missing).mul(pad_token).long()
            tokens = pt_cat(
                [filler, tokens] if pad_output_at_start else [tokens, filler])
        output_rows.append(tokens.unsqueeze(0))

    return pt_cat(input_rows), pt_cat(output_rows)
def clotho_collate_fn_eval(batch: MutableSequence[ndarray],
                           nb_t_steps: Union[AnyStr, Tuple[int, int]],
                           input_pad_at: str,
                           output_pad_at: str,
                           split: str,
                           augment: bool,
                           pad_token: int = 4367) \
        -> Tuple[Tensor, Tensor, Tensor, list]:
    """Pads or truncates an evaluation batch and collects its references.

    Each batch item must unpack to exactly five elements; the first is
    used as the input array, the second as the output array, the third
    is returned as the reference, and the last is used as the sort key
    and returned as the target length.

    :param batch: Batch data.
    :type batch: list[numpy.ndarray]
    :param nb_t_steps: Number of time steps to pad/truncate to. Can use
                       'max', 'min', or exact numbers for input and
                       output, e.g. (1024, 10). Any string other than
                       'max' (case-insensitive) is treated as 'min'.
    :type nb_t_steps: str|(int, int)
    :param input_pad_at: Pad input at the start or at the end?
                         'start' (case-insensitive) pads at the start,
                         anything else pads at the end.
    :type input_pad_at: str
    :param output_pad_at: Pad output at the start or at the end?
                          Same convention as `input_pad_at`.
    :type output_pad_at: str
    :param split: Not used by this function; kept so existing callers
                  keep working.
    :type split: str
    :param augment: If true, the stacked input tensor is passed through
                    `spec_augment` before being returned.
    :type augment: bool
    :param pad_token: Value written into padded output positions.
                      Defaults to 4367, the value this function used to
                      hard-code. NOTE(review): presumably the padding
                      index of the vocabulary in use -- confirm.
    :type pad_token: int
    :return: Stacked inputs (float), stacked outputs (long), the last
             element of every item as a LongTensor, and the list of
             third elements of every item. All four follow the batch
             order after sorting by the last element, largest first.
    :rtype: torch.Tensor, torch.Tensor, torch.Tensor, list
    """
    if isinstance(nb_t_steps, str):
        pick = max if nb_t_steps.lower() == 'max' else min
        in_t_steps = pick(item[0].shape[0] for item in batch)
        out_t_steps = pick(item[1].shape[0] for item in batch)
    else:
        in_t_steps, out_t_steps = nb_t_steps

    # Feature size is taken from the first item of the incoming
    # (not yet sorted) batch.
    in_dim = batch[0][0].shape[-1]

    # Largest last-element first; sorted() is stable, so ties keep their
    # incoming relative order.
    batch = sorted(batch, key=lambda item: item[-1], reverse=True)

    pad_input_at_start = input_pad_at.lower() == 'start'
    pad_output_at_start = output_pad_at.lower() == 'start'

    input_rows, output_rows = [], []
    for in_b, out_b, _ref, _file_name, _out_len in batch:
        # Slicing is a no-op for sequences that are already short enough,
        # so one code path handles both truncation and padding.
        feats = from_numpy(in_b[:in_t_steps, :]).float()
        n_missing = in_t_steps - feats.shape[0]
        if n_missing > 0:
            filler = pt_zeros(n_missing, in_dim).float()
            feats = pt_cat(
                [filler, feats] if pad_input_at_start else [feats, filler])
        input_rows.append(feats.unsqueeze(0))

        tokens = from_numpy(out_b[:out_t_steps]).long()
        n_missing = out_t_steps - tokens.shape[0]
        if n_missing > 0:
            filler = pt_ones(n_missing).mul(pad_token).long()
            tokens = pt_cat(
                [filler, tokens] if pad_output_at_start else [tokens, filler])
        output_rows.append(tokens.unsqueeze(0))

    input_tensor = pt_cat(input_rows)
    if augment:
        input_tensor = spec_augment(input_tensor)
    output_tensor = pt_cat(output_rows)

    all_ref = [item[2] for item in batch]
    target_len = torch.LongTensor([item[-1] for item in batch])

    return input_tensor, output_tensor, target_len, all_ref
def clotho_train_collate_fn(batch: MutableSequence[ndarray],
                            nb_t_steps: Union[AnyStr, Tuple[int, int]],
                            input_pad_at: str,
                            output_pad_at: str) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor, list]:
    """Pads a training batch and orders it by decreasing input length.

    Each batch item must unpack to (input array, output array, file id).
    Output sequences are padded with the last token of the first item's
    output array.

    :param batch: Batch data.
    :type batch: list[numpy.ndarray]
    :param nb_t_steps: Number of time steps to pad/truncate to. Can use
                       'max', 'min', or exact numbers for input and
                       output, e.g. (1024, 10). Any string other than
                       'max' (case-insensitive) is treated as 'min'.
    :type nb_t_steps: str|(int, int)
    :param input_pad_at: Pad input at the start or at the end?
                         'start' (case-insensitive) pads at the start,
                         anything else pads at the end.
    :type input_pad_at: str
    :param output_pad_at: Pad output at the start or at the end?
                          Same convention as `input_pad_at`.
    :type output_pad_at: str
    :return: In this order: the inputs combined with
             `rnn_utils.pad_sequence` (time-major: T, B, F); the outputs
             stacked along a new leading batch axis (long); the input
             lengths; the output lengths; the file ids. Everything is
             ordered by decreasing original input length. Lengths never
             exceed the number of steps kept after truncation.
    :rtype: torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list
    """

    def make_seq_even(sequences, audio_lengths):
        """Drop the last step of odd-length sequences.

        The matching length is decremented whenever a step is dropped.
        """
        even_seqs, even_len = [], []
        for seq, length in zip(sequences, audio_lengths):
            if len(seq) % 2 != 0:
                seq, length = seq[:-1], length - 1
            even_seqs.append(seq)
            even_len.append(length)
        return even_seqs, even_len

    if isinstance(nb_t_steps, str):
        pick = max if nb_t_steps.lower() == 'max' else min
        in_t_steps = pick(item[0].shape[0] for item in batch)
        out_t_steps = pick(item[1].shape[0] for item in batch)
    else:
        in_t_steps, out_t_steps = nb_t_steps

    # Feature size and output padding value both come from the first item.
    in_dim = batch[0][0].shape[-1]
    # Plain int so the fill value is independent of the numpy scalar type.
    eos_token = int(batch[0][1][-1])

    pad_input_at_start = input_pad_at.lower() == 'start'
    pad_output_at_start = output_pad_at.lower() == 'start'

    audio_rows, text_rows = [], []
    raw_audio_lengths, audio_lengths, text_lengths = [], [], []
    file_ids_list = []

    for in_b, out_b, fileid_b in batch:
        raw_audio_lengths.append(in_b.shape[0])
        # Reported lengths are clamped to what survives truncation so they
        # never exceed the time dimension of the returned tensors.
        audio_lengths.append(min(in_b.shape[0], in_t_steps))
        text_lengths.append(min(out_b.shape[0], out_t_steps))

        # A bare string must be kept whole; extend() would split it into
        # single characters.
        if isinstance(fileid_b, str):
            file_ids_list.append(fileid_b)
        else:
            file_ids_list.extend(fileid_b)

        # Slicing is a no-op for sequences that are already short enough,
        # so one code path handles both truncation and padding.
        feats = from_numpy(in_b[:in_t_steps, :]).float()
        n_missing = in_t_steps - feats.shape[0]
        if n_missing > 0:
            filler = pt_zeros(n_missing, in_dim).float()
            feats = pt_cat(
                [filler, feats] if pad_input_at_start else [feats, filler])
        audio_rows.append(feats)

        tokens = from_numpy(out_b[:out_t_steps]).long()
        n_missing = out_t_steps - tokens.shape[0]
        if n_missing > 0:
            filler = pt_ones(n_missing).mul(eos_token).long()
            tokens = pt_cat(
                [filler, tokens] if pad_output_at_start else [tokens, filler])
        text_rows.append(tokens)

    # Largest sequence first (needed for packed sequences). The order is a
    # stable ascending sort on the untruncated lengths, then reversed, so
    # ties come out in reverse incoming order.
    order = sorted(range(len(raw_audio_lengths)),
                   key=lambda k: raw_audio_lengths[k])[::-1]

    # Make all audio tensors an even number of steps long.
    even_audio_batch, even_audio_lengths = make_seq_even(
        [audio_rows[i] for i in order],
        [audio_lengths[i] for i in order])

    text_batch_sorted = pt_cat([text_rows[i].unsqueeze(0) for i in order])
    text_lengths_sorted = LongTensor([text_lengths[i] for i in order])
    even_audio_lengths_sorted = LongTensor(even_audio_lengths)

    # Time-major result: (T, B, F).
    input_tensor = rnn_utils.pad_sequence(even_audio_batch)

    file_ids_list_sorted = [file_ids_list[i] for i in order]

    return (input_tensor, text_batch_sorted, even_audio_lengths_sorted,
            text_lengths_sorted, file_ids_list_sorted)