Example #1
0
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Write a training checkpoint to *filepath*.

    The checkpoint stores a fresh copy of the network (detached from the
    live training instance), the optimizer state, the learning rate, and
    the current iteration counter.
    """
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    # Copy weights into a brand-new network so the pickled model object is
    # independent of the one still being trained.
    checkpoint_model = DNNnet(**DNN_net_config).cuda()
    checkpoint_model.load_state_dict(model.state_dict())
    state = {
        'model': checkpoint_model,
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(state, filepath)
Example #2
0
def DNN_test(checkpoint_path, filename):
    """Run a trained DNN on one random segment of an audio file and plot
    the target spectrum against the network's prediction.

    Args:
        checkpoint_path: path to a checkpoint written by save_checkpoint
            (contains a pickled 'model' entry).
        filename: path to the audio file to evaluate on.
    """
    model = DNNnet(**DNN_net_config)
    # NOTE(review): the checkpoint stores a pickled model object, so
    # loading it executes arbitrary code — only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' ".format(checkpoint_path))

    model.eval()

    Data_gen = spec2load(**DNN_data_config)

    audio, sr = ms.load_wav_to_torch(filename)
    # Take a random 2240-sample segment; left-pad shorter clips with zeros.
    if audio.size(0) >= 2240:
        max_audio_start = audio.size(0) - 2240
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start + 2240]
    else:
        audio = torch.nn.functional.pad(audio, (2240 - audio.size(0), 0),
                                        'constant').data

    spec = Data_gen.get_spec(audio)
    # The last two frames are the prediction target; the rest is the input.
    feed_spec = spec[:, :-2]
    targ_spec = spec[:, -2:]
    feed_spec = feed_spec.unsqueeze(0)
    # Bug fix: inference previously ran with autograd enabled and called
    # .forward() directly; use no_grad and the standard call protocol.
    with torch.no_grad():
        gener_spec = model(feed_spec)

    gener_spec = gener_spec.numpy()
    # Flatten the two target frames into one vector for plotting.
    targ_spec = torch.cat((targ_spec[:, 0], targ_spec[:, 1]), 0)
    targ_spec = targ_spec.numpy()

    plt.figure()
    plt.plot(targ_spec)
    plt.plot(gener_spec[0], 'r')
    plt.show()
Example #3
0
def net_init(DNN_path, waveglow_path):
    """Load the DNN and WaveGlow models from their checkpoints and return
    both, moved to the GPU in eval mode.

    Args:
        DNN_path: checkpoint written by save_checkpoint (has 'model' and
            'iteration' entries).
        waveglow_path: WaveGlow checkpoint with a pickled 'model' entry.

    Returns:
        (DNN_model, waveglow) tuple, both on CUDA and in eval() mode.

    Raises:
        AssertionError: if either path is not an existing file.
    """
    assert os.path.isfile(DNN_path)
    assert os.path.isfile(waveglow_path)
    DNN_checkpoint_dict = torch.load(DNN_path)
    DNN_model = DNNnet(**DNN_net_config)
    iteration = DNN_checkpoint_dict['iteration']
    model_for_loading = DNN_checkpoint_dict['model']
    DNN_model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' (iteration {})".format(DNN_path, iteration))

    DNN_model.cuda().eval()

    waveglow = torch.load(waveglow_path)['model']
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow.cuda().eval()
    # Bug fix: the original printed the *DNN* checkpoint's iteration here,
    # misattributing it to the WaveGlow checkpoint.
    print("Loaded checkpoint '{}'".format(waveglow_path))

    return DNN_model, waveglow
Example #4
0
def DNN_stretch(feed_audio, DNN_path):
    """Predict a spectrogram continuation for *feed_audio* using the DNN
    checkpoint at *DNN_path*, and return the input spectrogram with the
    predicted frames appended.

    Args:
        feed_audio: audio tensor accepted by spec2load.get_spec.
        DNN_path: path to a checkpoint written by save_checkpoint
            (contains 'model' and 'iteration' entries).

    Returns:
        com_mel: the input spectrogram (with a leading batch dimension)
        concatenated along dim 2 with the model's predicted frames.

    Raises:
        AssertionError: if DNN_path does not exist.
    """
    assert os.path.isfile(DNN_path)
    # NOTE(review): the checkpoint stores a pickled model object; loading
    # it executes arbitrary code — only load trusted checkpoints.
    checkpoint_dict = torch.load(DNN_path, map_location='cpu')
    model = DNNnet(**DNN_net_config)
    iteration = checkpoint_dict['iteration']
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' (iteration {})".format(DNN_path, iteration))

    model.eval()

    Data_gen = spec2load(**DNN_data_config)

    spec = Data_gen.get_spec(feed_audio)
    feed_spec = spec.unsqueeze(0)
    gener_spec = model.forward(feed_spec)

    # Split the network output in half along dim 1 and stack the halves as
    # two separate frames on a new trailing axis (assumes gener_spec is
    # (batch, 2 * spec.size(0)) — TODO confirm against DNNnet's output).
    out_mel = torch.stack([gener_spec[:, :spec.size(0)], \
        gener_spec[:, spec.size(0):]], dim=-1)

    # Append the predicted frames after the input frames along dim 2.
    com_mel = torch.cat((feed_spec, out_mel), dim=2)

    return com_mel
Example #5
0
def DNN_train(output_directory, epochs, learning_rate,
              iters_per_checkpoint, batch_size, seed, checkpoint_path):
    """Train the DNN on spec2load data, saving a checkpoint every epoch.

    Args:
        output_directory: directory for per-epoch checkpoints (created and
            chmod'ed 0o775 if missing).
        epochs: total number of epochs to train to.
        learning_rate: Adam learning rate.
        iters_per_checkpoint: print the loss every this many iterations.
        batch_size: DataLoader batch size (incomplete batches are dropped).
        seed: manual seed for CPU and CUDA RNGs.
        checkpoint_path: checkpoint to resume from, or "" to start fresh.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # MMSE loss
    criterion = Dnn_net_Loss()

    model = DNNnet(**DNN_net_config).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from a checkpoint if one is given.
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        model.cuda()
        iteration += 1  # next iteration is iteration + 1

    trainSet = spec2load(**DNN_data_config)
    trainSet.load_buffer()
    train_loader = DataLoader(trainSet,
                              num_workers=1,
                              shuffle=True,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory)
    model.train(mode=True)
    epoch_offset = max(0, int(iteration / len(train_loader)))

    num_batches = len(train_loader)  # drop_last=True: full batches only
    for epoch in range(epoch_offset, epochs):
        epoch_ave_loss = 0
        for i, batch in tqdm(enumerate(train_loader)):
            model.zero_grad()

            feed_in, targ_in = batch
            # torch.autograd.Variable is deprecated; tensors track
            # gradients directly.
            feed_in = feed_in.cuda()
            targ_in = targ_in.cuda()
            outputs = model(feed_in)

            loss = criterion(outputs, targ_in)
            reduced_loss = loss.item()

            loss.backward()
            optimizer.step()

            epoch_ave_loss += reduced_loss

            if iteration % iters_per_checkpoint == 0:
                print("{}:\t{:.9f}".format(iteration, reduced_loss))

            iteration += 1

        checkpoint_path = "{}/DNN_net_{}".format(output_directory, epoch)
        save_checkpoint(model, optimizer, learning_rate, iteration,
                        checkpoint_path)
        # Bug fix: the original divided by the last batch *index* (i),
        # which is off by one and raises ZeroDivisionError when an epoch
        # has exactly one batch. Divide by the batch count instead.
        epoch_ave_loss = epoch_ave_loss / num_batches
        print("Epoch: {}, the average epoch loss: {}".format(
            epoch, epoch_ave_loss))