예제 #1
0
def prepare_input_sequence(fields,
                           device,
                           batch_size=1,
                           dataset=None,
                           load_mels=False,
                           load_pitch=False,
                           text_cleaners=None):
    """Encode text inputs and cut all fields into padded, on-device batches.

    ``fields`` is modified in place:

    * ``'text'`` is replaced by ``torch.LongTensor`` sequences produced by
      ``text_to_sequence``, and ``'text_lens'`` is added.
    * With ``load_mels`` / ``load_pitch``, ``'mel'`` / ``'pitch'`` are
      replaced by tensors loaded from disk, and ``'mel_lens'`` /
      ``'pitch_lens'`` are added.

    Args:
        fields: dict of per-item values. Must contain ``'text'`` (inputs to
            ``text_to_sequence``). Every value is sliced per batch, so each
            should be indexable and aligned with ``'text'``.
        device: target passed to ``Tensor.to`` for every tensor in a batch.
        batch_size: number of items per batch. The last batch may be smaller.
        dataset: directory joined with each ``'mel'`` / ``'pitch'`` entry to
            build the path given to ``torch.load``.
        load_mels: load mel files and pad them per batch. The loaded tensor is
            transposed, padded along its (post-transpose) first dim, then
            permuted back. This assumes 2-D files, presumably
            (n_mels, frames); TODO confirm.
        load_pitch: load pitch files and pad them along dim 0 per batch.
        text_cleaners: cleaner names forwarded to ``text_to_sequence``.
            Defaults to ``['english_cleaners']``.

    Returns:
        A list of dicts with the same keys as ``fields`` (plus the ``*_lens``
        keys). Each value is the slice for that batch. ``'text'`` (and
        ``'mel'`` / ``'pitch'`` when loaded) are zero-padded tensors.

    Raises:
        KeyError: if ``load_mels`` is set without a ``'mel'`` field, or
            ``load_pitch`` is set without a ``'pitch'`` field.
    """
    # Validate before touching ``fields`` so a bad call leaves it unmodified.
    # Explicit raises (unlike ``assert``) still run under ``python -O``.
    if load_mels and 'mel' not in fields:
        raise KeyError("load_mels=True requires a 'mel' field")
    if load_pitch and 'pitch' not in fields:
        raise KeyError("load_pitch=True requires a 'pitch' field")

    cleaners = (['english_cleaners'] if text_cleaners is None
                else list(text_cleaners))

    fields['text'] = [
        torch.LongTensor(text_to_sequence(t, cleaners))
        for t in fields['text']
    ]
    # Identity permutation: length-sorting is disabled, so items keep their
    # input order. Only 'text', 'mel', 'pitch' and 'output' are routed through
    # ``order``. Any other field is batched in its original order.
    order = range(len(fields['text']))

    fields['text'] = [fields['text'][i] for i in order]
    fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])

    if load_mels:
        # Transposed so pad_sequence (which pads dim 0) pads the file's dim 1.
        fields['mel'] = [
            torch.load(Path(dataset, fields['mel'][i])).t() for i in order
        ]
        fields['mel_lens'] = torch.LongTensor(
            [t.size(0) for t in fields['mel']])

    if load_pitch:
        fields['pitch'] = [
            torch.load(Path(dataset, fields['pitch'][i])) for i in order
        ]
        fields['pitch_lens'] = torch.LongTensor(
            [t.size(0) for t in fields['pitch']])

    if 'output' in fields:
        fields['output'] = [fields['output'][i] for i in order]

    # Cut into batches and pad.
    batches = []
    for start in range(0, len(order), batch_size):
        batch = {f: values[start:start + batch_size]
                 for f, values in fields.items()}
        for f in batch:
            if f == 'text':
                batch[f] = pad_sequence(batch[f], batch_first=True)
            elif f == 'mel' and load_mels:
                # Permute back to the per-file layout, with the batch dim first.
                batch[f] = pad_sequence(batch[f],
                                        batch_first=True).permute(0, 2, 1)
            elif f == 'pitch' and load_pitch:
                batch[f] = pad_sequence(batch[f], batch_first=True)

            if isinstance(batch[f], torch.Tensor):
                batch[f] = batch[f].to(device)
        batches.append(batch)

    return batches
 def get_text(self, text):
     """Encode ``text`` into a ``torch.IntTensor``.

     The values are whatever ``text_to_sequence`` returns when called with
     this object's ``text_cleaners``.
     """
     encoded = text_to_sequence(text, self.text_cleaners)
     return torch.IntTensor(encoded)