def prepare_input_sequence(fields, device, batch_size=1, dataset=None,
                           load_mels=False, load_pitch=False):
    """Encode, load and pad the input fields, then cut them into batches.

    NOTE: ``fields`` is modified in place. ``fields['text']`` is replaced by
    a list of ``torch.LongTensor`` built from ``text_to_sequence`` and a
    ``'text_lens'`` entry is added; when requested, ``'mel'`` / ``'pitch'``
    entries are replaced by the loaded tensors and ``'mel_lens'`` /
    ``'pitch_lens'`` are added.

    Args:
        fields: Dict mapping a field name to a per-item sliceable sequence.
            Must contain ``'text'``; must contain ``'mel'`` / ``'pitch'``
            (file paths handed to ``torch.load``) when the matching
            ``load_*`` flag is set. Any other entries are sliced into the
            batches unchanged.
        device: Target passed to ``Tensor.to`` for every tensor in a batch.
        batch_size: Maximum number of items per batch; must be >= 1.
        dataset: Optional root directory prepended to the ``'mel'`` and
            ``'pitch'`` paths. ``None`` means the paths are used as given.
        load_mels: Load, pad and batch the files listed in ``fields['mel']``.
        load_pitch: Load, pad and batch the files listed in
            ``fields['pitch']``.

    Returns:
        A list of dicts, one per batch, with the same keys as ``fields``.
        ``'text'`` (and ``'mel'`` / ``'pitch'`` when loaded) are padded with
        ``pad_sequence(..., batch_first=True)``; every ``torch.Tensor`` value
        has been moved to ``device``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If ``'text'`` is missing, or ``'mel'`` / ``'pitch'`` is
            missing while the corresponding ``load_*`` flag is set.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    def _load(rel_path):
        # NOTE(security): torch.load unpickles the file; only point this at
        # trusted data.
        full_path = Path(rel_path) if dataset is None else Path(dataset, rel_path)
        return torch.load(full_path)

    fields['text'] = [
        torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
        for t in fields['text']
    ]
    # Items stay in input order (no sorting by length), so the position of
    # an item inside the batches matches its position in ``fields``.
    order = range(len(fields['text']))
    fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])

    if load_mels:
        if 'mel' not in fields:
            raise KeyError("load_mels=True requires a 'mel' entry in fields")
        # Loaded mels are transposed so that dim 0 is the one being padded;
        # presumably they are stored as (n_mels, frames) - TODO confirm.
        fields['mel'] = [_load(fields['mel'][i]).t() for i in order]
        fields['mel_lens'] = torch.LongTensor(
            [t.size(0) for t in fields['mel']])

    if load_pitch:
        if 'pitch' not in fields:
            raise KeyError("load_pitch=True requires a 'pitch' entry in fields")
        fields['pitch'] = [_load(fields['pitch'][i]) for i in order]
        fields['pitch_lens'] = torch.LongTensor(
            [t.size(0) for t in fields['pitch']])

    if 'output' in fields:
        fields['output'] = [fields['output'][i] for i in order]

    # Cut into batches & pad.
    batches = []
    for start in range(0, len(order), batch_size):
        batch = {}
        for name, values in fields.items():
            chunk = values[start:start + batch_size]
            if name == 'text' or (name == 'pitch' and load_pitch):
                chunk = pad_sequence(chunk, batch_first=True)
            elif name == 'mel' and load_mels:
                # Undo the per-item transpose: swap the last two dims of the
                # padded batch.
                chunk = pad_sequence(chunk, batch_first=True).permute(0, 2, 1)

            if isinstance(chunk, torch.Tensor):
                chunk = chunk.to(device)
            batch[name] = chunk
        batches.append(batch)

    return batches
def get_text(self, text):
    """Encode ``text`` as a ``torch.IntTensor``.

    The text is passed to ``text_to_sequence`` together with
    ``self.text_cleaners``; the returned sequence (presumably integer
    symbol IDs - confirm against ``text_to_sequence``) is wrapped in an
    ``IntTensor`` and returned.
    """
    encoded = text_to_sequence(text, self.text_cleaners)
    return torch.IntTensor(encoded)