Example #1
def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    #pids = [id2sp[hparams.speaker_A], id2sp[hparams.speaker_B]]
    trainset = TextMelIDLoader(hparams.training_list,
                               hparams.mel_mean_std,
                               hparams.speaker_A,
                               hparams.speaker_B,
                               pids=None)
    valset = TextMelIDLoader(hparams.validation_list,
                             hparams.mel_mean_std,
                             hparams.speaker_A,
                             hparams.speaker_B,
                             pids=None)
    collate_fn = TextMelIDCollate(
        lcm(hparams.n_frames_per_step_encoder,
            hparams.n_frames_per_step_decoder))

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    train_loader = DataLoader(trainset,
                              num_workers=1,
                              # shuffle must be disabled when a DistributedSampler is supplied
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              batch_size=hparams.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)
    return train_loader, valset, collate_fn
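
# Usage sketch (an assumption, not part of the original example): the loaders are
# normally built from an hparams object like the one created in Example #3, e.g.
#
#   hparams = create_hparams(args.hparams)
#   train_loader, valset, collate_fn = prepare_dataloaders(hparams)
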
def gen_embedding(speaker):
    # Run the speaker encoder over every utterance of one speaker and save the
    # resulting embeddings (relies on the module-level hparams and model objects).
    training_list = hparams.training_list

    train_set_A = TextMelIDLoader(training_list,
                                  hparams.mel_mean_std,
                                  hparams.speaker_A,
                                  hparams.speaker_B,
                                  shuffle=False,
                                  pids=[speaker])

    collate_fn = TextMelIDCollate(
        lcm(hparams.n_frames_per_step_encoder,
            hparams.n_frames_per_step_decoder))

    train_loader_A = DataLoader(train_set_A,
                                num_workers=1,
                                shuffle=False,
                                sampler=None,
                                batch_size=1,
                                pin_memory=False,
                                drop_last=True,
                                collate_fn=collate_fn)

    with torch.no_grad():

        speaker_embeddings = []

        # batch_size=1, so each batch holds a single utterance
        for i, batch in enumerate(train_loader_A):
            x, y = model.parse_batch(batch)
            text_input_padded, mel_padded, text_lengths, mel_lengths, speaker_id = x
            # keep only the embedding; the predicted speaker id is not used here
            _, speaker_embedding = model.speaker_encoder.inference(mel_padded)

            speaker_embedding = speaker_embedding.data.cpu().numpy()

            #speaker_embedding = speaker_embedding * 0.1

            speaker_embeddings.append(speaker_embedding)

        speaker_embeddings = np.vstack(speaker_embeddings)

    print(speaker_embeddings.shape)
    if not os.path.exists('outdir/embeddings'):
        os.makedirs('outdir/embeddings')

    np.save('outdir/embeddings/%s.npy' % speaker, speaker_embeddings)
    plot_data([speaker_embeddings], 'outdir/embeddings/%s.pdf' % speaker)
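
A minimal usage sketch (an assumption, not part of the original example), with hparams and model already set up as in Example #3:

# export one embedding file per training speaker
for spk in [hparams.speaker_A, hparams.speaker_B]:
    gen_embedding(spk)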
Example #3
hparams = create_hparams(args.hparams)

test_list = hparams.validation_list
checkpoint_path = args.checkpoint_path
gen_num = args.num
ISMEL = (not hparams.predict_spectrogram)

model = load_model(hparams)

model.load_state_dict(torch.load(checkpoint_path)['state_dict'], strict=False)
_ = model.eval()

train_set_A = TextMelIDLoader(hparams.training_list,
                              hparams.mel_mean_std,
                              hparams.speaker_A,
                              hparams.speaker_B,
                              shuffle=False,
                              pids=[hparams.speaker_A])

train_set_B = TextMelIDLoader(hparams.training_list,
                              hparams.mel_mean_std,
                              hparams.speaker_A,
                              hparams.speaker_B,
                              shuffle=False,
                              pids=[hparams.speaker_B])

test_set_A = TextMelIDLoader(test_list,
                             hparams.mel_mean_std,
                             hparams.speaker_A,
                             hparams.speaker_B,
                             shuffle=False,
                             pids=[hparams.speaker_A])  # closing argument inferred from the train_set_A pattern
Example #4
def plot_data(data, fn):
    # signature follows the plot_data(data, fn) call shown in Example #1
    fig, axes = plt.subplots(1, len(data))
    for i in range(len(data)):
        # plt.subplots returns a bare Axes object when there is only one panel
        if len(data) == 1:
            ax = axes
        else:
            ax = axes[i]
        # origin='bottom' no longer working after matplotlib 3.3.2
        g = ax.imshow(data[i],
                      aspect='auto',
                      origin='lower',
                      interpolation='none')
        plt.colorbar(g, ax=ax)
    plt.savefig(fn)


model = load_model(hparams)

model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
_ = model.eval()

test_set = TextMelIDLoader(test_list, hparams.mel_mean_std, shuffle=True)
sample_list = test_set.file_path_list
collate_fn = TextMelIDCollate(
    lcm(hparams.n_frames_per_step_encoder, hparams.n_frames_per_step_decoder))

test_loader = DataLoader(test_set,
                         num_workers=1,
                         shuffle=False,
                         sampler=None,
                         batch_size=1,
                         pin_memory=False,
                         drop_last=True,
                         collate_fn=collate_fn)

task = 'tts' if input_text else 'vc'  # text-to-speech when input text is given, otherwise voice conversion
path_save = os.path.join(checkpoint_path.replace('checkpoint', 'test'), task)
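
A minimal continuation sketch (an assumption, not shown in the original example): create the output directory and walk the test loader, parsing batches the same way as in Example #1.

if not os.path.exists(path_save):
    os.makedirs(path_save)

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        x, y = model.parse_batch(batch)
        # ... run the model on x and write the results under path_save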