def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Serialize a snapshot of *model* plus optimizer state to *filepath*.

    The checkpoint layout ({'model', 'iteration', 'optimizer',
    'learning_rate'}) matches what the loading functions in this file
    (e.g. net_init / DNN_test) expect.

    Args:
        model: the live training model whose weights are snapshotted.
        optimizer: optimizer whose state_dict is saved for resuming.
        learning_rate: current learning rate, stored alongside the state.
        iteration: training iteration counter, stored for resuming.
        filepath: destination path for torch.save.
    """
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    # Snapshot into a fresh instance so the saved object is decoupled from
    # the live training model.  Keeping the copy on CPU (the original
    # called .cuda() here) avoids allocating a second model's worth of GPU
    # memory just to serialize it, and produces a checkpoint that can be
    # loaded on CPU-only machines without map_location tricks.
    model_for_saving = DNNnet(**DNN_net_config)
    model_for_saving.load_state_dict(model.state_dict())
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)
def DNN_test(checkpoint_path, filename):
    """Load a DNN checkpoint, run it on a random fixed-length segment of
    *filename*, and plot the generated spectrum against the target.

    Args:
        checkpoint_path: path to a checkpoint produced by save_checkpoint.
        filename: path of the wav file to evaluate on.
    """
    model = DNNnet(**DNN_net_config)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' ".format(checkpoint_path))
    model.eval()

    Data_gen = spec2load(**DNN_data_config)
    audio, sr = ms.load_wav_to_torch(filename)

    # Take a fixed 2240-sample segment: random crop when the file is long
    # enough, otherwise zero-pad on the left up to 2240 samples.
    if audio.size(0) >= 2240:
        max_audio_start = audio.size(0) - 2240
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start + 2240]
    else:
        audio = torch.nn.functional.pad(
            audio, (2240 - audio.size(0), 0), 'constant').data

    spec = Data_gen.get_spec(audio)
    # The last two frames are the prediction target; everything before
    # them is the network input.
    feed_spec = spec[:, :-2]
    targ_spec = spec[:, -2:]
    feed_spec = feed_spec.unsqueeze(0)

    # Bug fix: inference previously ran with autograd enabled and called
    # model.forward() directly.  no_grad() skips graph construction, and
    # calling the module invokes any registered hooks.
    with torch.no_grad():
        gener_spec = model(feed_spec)
    gener_spec = gener_spec.numpy()

    # Flatten the two target frames into one vector, matching the
    # network's concatenated output layout.
    targ_spec = torch.cat((targ_spec[:, 0], targ_spec[:, 1]), 0)
    targ_spec = targ_spec.numpy()

    plt.figure()
    plt.plot(targ_spec)
    plt.plot(gener_spec[0], 'r')
    plt.show()
def net_init(DNN_path, waveglow_path):
    """Load the DNN front-end and the WaveGlow vocoder from checkpoints.

    Args:
        DNN_path: checkpoint produced by save_checkpoint.
        waveglow_path: WaveGlow checkpoint containing a 'model' entry.

    Returns:
        (DNN_model, waveglow), both moved to GPU and set to eval mode.

    Raises:
        AssertionError: if either path is not an existing file.
    """
    assert os.path.isfile(DNN_path)
    assert os.path.isfile(waveglow_path)

    DNN_checkpoint_dict = torch.load(DNN_path)
    DNN_model = DNNnet(**DNN_net_config)
    iteration = DNN_checkpoint_dict['iteration']
    model_for_loading = DNN_checkpoint_dict['model']
    DNN_model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' (iteration {})".format(DNN_path, iteration))
    DNN_model.cuda().eval()

    waveglow = torch.load(waveglow_path)['model']
    # Fold the weight-norm parametrization into plain weights for
    # inference.
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow.cuda().eval()
    # Bug fix: the old message reported the DNN checkpoint's iteration
    # number for the WaveGlow checkpoint, which was misleading.
    print("Loaded checkpoint '{}'".format(waveglow_path))
    return DNN_model, waveglow
def DNN_stretch(feed_audio, DNN_path):
    """Predict two extra spectrogram frames for *feed_audio* and append
    them to its spectrogram.

    Args:
        feed_audio: audio tensor fed to the spectrogram generator.
        DNN_path: checkpoint produced by save_checkpoint.

    Returns:
        The input spectrogram (batched) with the two predicted frames
        concatenated along the time axis.

    Raises:
        AssertionError: if DNN_path is not an existing file.
    """
    assert os.path.isfile(DNN_path)

    ckpt = torch.load(DNN_path, map_location='cpu')
    net = DNNnet(**DNN_net_config)
    step = ckpt['iteration']
    net.load_state_dict(ckpt['model'].state_dict())
    print("Loaded checkpoint '{}' (iteration {})".format(DNN_path, step))
    net.eval()

    generator = spec2load(**DNN_data_config)
    spec = generator.get_spec(feed_audio)
    batch = spec.unsqueeze(0)
    prediction = net.forward(batch)

    # The network emits both predicted frames concatenated along the
    # feature axis; split them and stack as two separate time frames.
    n_feat = spec.size(0)
    first_frame = prediction[:, :n_feat]
    second_frame = prediction[:, n_feat:]
    out_frames = torch.stack([first_frame, second_frame], dim=-1)

    # Append the predicted frames after the original spectrogram.
    return torch.cat((batch, out_frames), dim=2)
def DNN_train(output_directory, epochs, learning_rate,
              iters_per_checkpoint, batch_size, seed, checkpoint_path):
    """Train the DNN, checkpointing once per epoch.

    Args:
        output_directory: directory where checkpoints are written
            (created with mode 0o775 if missing).
        epochs: total number of epochs to train to.
        learning_rate: Adam learning rate.
        iters_per_checkpoint: iteration interval for loss printouts.
        batch_size: DataLoader batch size (incomplete batches dropped).
        seed: seed for torch CPU and CUDA RNGs.
        checkpoint_path: checkpoint to resume from, or "" to start fresh.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    criterion = Dnn_net_Loss()  # MMSE loss
    model = DNNnet(**DNN_net_config).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from a checkpoint if one was given.
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        model.cuda()
        iteration += 1  # next iteration is iteration + 1

    trainSet = spec2load(**DNN_data_config)
    trainSet.load_buffer()
    train_loader = DataLoader(trainSet, num_workers=1, shuffle=True,
                              batch_size=batch_size, pin_memory=False,
                              drop_last=True)

    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory)

    model.train(mode=True)
    # Skip epochs already completed by the resumed iteration counter.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    for epoch in range(epoch_offset, epochs):
        epoch_ave_loss = 0
        num_batches = 0
        for i, batch in tqdm(enumerate(train_loader)):
            model.zero_grad()
            feed_in, targ_in = batch
            # Tensors are autograd-enabled by default since PyTorch 0.4;
            # the old torch.autograd.Variable wrapper was a no-op.
            feed_in = feed_in.cuda()
            targ_in = targ_in.cuda()

            outputs = model(feed_in)
            loss = criterion(outputs, targ_in)
            reduced_loss = loss.item()
            loss.backward()
            optimizer.step()

            epoch_ave_loss += reduced_loss
            num_batches += 1
            if iteration % iters_per_checkpoint == 0:
                print("{}:\t{:.9f}".format(iteration, reduced_loss))
            iteration += 1

        checkpoint_path = "{}/DNN_net_{}".format(output_directory, epoch)
        save_checkpoint(model, optimizer, learning_rate, iteration,
                        checkpoint_path)
        # Bug fix: the old code divided by the last enumerate index `i`,
        # which is off by one (enumerate starts at 0) and raises
        # ZeroDivisionError when the loader yields a single batch.
        if num_batches > 0:
            epoch_ave_loss = epoch_ave_loss / num_batches
        print("Epoch: {}, the average epoch loss: {}".format(
            epoch, epoch_ave_loss))