def train_log(self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"):  # pylint: disable=no-self-use
    """Create visualizations and waveform examples.

    For example, here you can plot spectrograms and generate sample waveforms
    from these spectrograms to be projected onto Tensorboard.

    Args:
        ap (AudioProcessor): Audio processor used at training.
        batch (Dict): Model inputs used at the previous training step.
        outputs (List): Model outputs generated at the previous training step.
        name_prefix (str): Prefix for the logged figure/audio keys. Defaults to "train".

    Returns:
        Tuple[Dict, Dict]: Training figures and audio samples to be logged.
    """
    y_hat = outputs[0]["model_outputs"]
    y = outputs[0]["waveform_seg"]
    figures = plot_results(y_hat, y, ap, name_prefix)
    # log only the first sample of the batch as an audio example
    sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
    audios = {f"{name_prefix}/audio": sample_voice}
    # attention alignment of the first sample, transposed for plotting
    alignments = outputs[0]["alignments"]
    align_img = alignments[0].data.cpu().numpy().T
    figures["alignment"] = plot_alignment(align_img, output_fig=False)
    return figures, audios
def _log(self, ap, batch, outputs, name_prefix="train"):  # pylint: disable=unused-argument,no-self-use
    """Build log figures and audio samples from the previous step's outputs.

    Returns the plot figures (including the attention alignment of the first
    sample) and a dict holding one audio example, both keyed by `name_prefix`.
    """
    step_out = outputs[0]
    prediction = step_out["model_outputs"]
    target = step_out["waveform_seg"]
    figures = plot_results(prediction, target, ap, name_prefix)
    # alignment of the first sample, transposed for plotting
    align_img = step_out["alignments"][0].data.cpu().numpy().T
    figures["alignment"] = plot_alignment(align_img, output_fig=False)
    # first sample of the batch as the audio example
    voice = prediction[0].squeeze(0).detach().cpu().numpy()
    return figures, {f"{name_prefix}/audio": voice}
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
    """Logging shared by the training and evaluation.

    Args:
        name (str): Name of the run. `train` or `eval`,
        ap (AudioProcessor): Audio processor used in training.
        batch (Dict): Batch used in the last train/eval step.
        outputs (Dict): Model outputs from the last train/eval step.

    Returns:
        Tuple[Dict, Dict]: log figures and audio samples.
    """
    predicted = outputs[0]["model_outputs"]
    reference = batch["waveform"]
    # first predicted waveform of the batch is used as the audio example
    sample = predicted[0].squeeze(0).detach().cpu().numpy()
    return plot_results(predicted, reference, ap, name), {f"{name}/audio": sample}
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
    """Run diffusion inference on test samples and build log figures/audios.

    NOTE(review): the ``return`` inside the ``for`` loop means only the FIRST
    sample is ever processed; the remaining test samples are silently ignored.
    Confirm whether this is intentional before changing it. Also note the typo
    in the ``ouputs`` parameter name — kept as-is for interface compatibility.

    Args:
        ap (AudioProcessor): Audio processor used at training.
        samples (List[Dict]): Test samples; each sample is indexed positionally
            (sample[0] -> conditioning features, sample[1] -> waveform) —
            presumably tuples despite the ``List[Dict]`` annotation; verify
            against the caller.
        ouputs (Dict): Unused; kept for the trainer's call signature.

    Returns:
        Tuple[Dict, Dict]: figures and one audio sample for the first test
        sample, or implicitly ``None`` when ``samples`` is empty.
    """
    # setup noise schedule and inference
    noise_schedule = self.config["test_noise_schedule"]
    betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
    self.compute_noise_level(betas)
    for sample in samples:
        x = sample[0]
        # add a batch dimension and move to the model's device
        x = x[None, :, :].to(next(self.parameters()).device)
        y = sample[1]
        y = y[None, :]
        # compute voice
        y_pred = self.inference(x)
        # compute spectrograms
        figures = plot_results(y_pred, y, ap, "test")
        # Sample audio
        sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
        return figures, {"test/audio": sample_voice}
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
    """Run diffusion inference on one test sample and build log figures/audios.

    NOTE(review): as in ``test_run``, the ``return`` inside the ``for`` loop
    processes only the FIRST sample. Since ``load_test_samples(1)`` requests a
    single sample this is harmless here, but the loop is misleading — confirm
    and simplify upstream if possible.

    Args:
        assets (Dict): Trainer assets; only ``assets["audio_processor"]`` is read.
        test_loader (DataLoader): Loader whose dataset provides ``load_test_samples``.
        outputs: Unused; kept for the trainer's call signature.

    Returns:
        Tuple[Dict, Dict]: figures and one audio sample, or implicitly ``None``
        when no test samples are returned.
    """
    # setup noise schedule and inference
    ap = assets["audio_processor"]
    noise_schedule = self.config["test_noise_schedule"]
    betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
    self.compute_noise_level(betas)
    samples = test_loader.dataset.load_test_samples(1)
    for sample in samples:
        x = sample[0]
        # add a batch dimension and move to the model's device
        x = x[None, :, :].to(next(self.parameters()).device)
        y = sample[1]
        y = y[None, :]
        # compute voice
        y_pred = self.inference(x)
        # compute spectrograms
        figures = plot_results(y_pred, y, ap, "test")
        # Sample audio
        sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
        return figures, {"test/audio": sample_voice}
def evaluate(model, criterion, ap, global_step, epoch):
    """Run one evaluation epoch for the (WaveGrad-style) diffusion vocoder.

    Computes the noise-prediction loss over the validation loader, keeps
    running averages, and — on rank 0 — synthesizes one test sample and logs
    figures/audio to Tensorboard.

    Relies on module-level globals: ``c`` (config), ``args`` (CLI args),
    ``c_logger``, ``tb_logger``, ``setup_loader``, ``format_data``,
    ``format_test_data``, ``KeepAverage``, ``plot_results``.

    Args:
        model: Vocoder model (possibly wrapped; see ``hasattr(model, 'module')``).
        criterion: Loss between true and predicted noise.
        ap: Audio processor used to build the data loader and plots.
        global_step (int): Global step counter; incremented per batch locally.
        epoch (int): Current epoch (only used for loader verbosity).

    Returns:
        Dict: running-average eval stats from ``KeepAverage``.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # compute noisy input
        # NOTE(review): the 'module' attribute check presumably unwraps a
        # (Distributed)DataParallel wrapper — confirm against the trainer setup.
        if hasattr(model, 'module'):
            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
        else:
            noise, x_noisy, noise_scale = model.compute_y_n(x)

        # forward pass
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss':loss}

        # convert tensor losses to plain floats for logging
        loss_dict = dict()
        for key, value in loss_wavegrad_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values['avg_' + key] = value
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    # rank 0 only: synthesize one full test sample and log it
    if args.rank == 0:
        # temporarily switch the dataset to full (non-segmented) samples
        data_loader.dataset.return_segments = False
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])

        # setup noise schedule and inference
        noise_schedule = c['test_noise_schedule']
        betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
        if hasattr(model, 'module'):
            model.module.compute_noise_level(betas)
            # compute voice
            x_pred = model.module.inference(m)
        else:
            model.compute_noise_level(betas)
            # compute voice
            x_pred = model.inference(m)

        # compute spectrograms
        figures = plot_results(x_pred, x, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
        # restore segmented loading for the next epoch
        data_loader.dataset.return_segments = True

    return keep_avg.avg_values
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
    """Run one GAN training epoch (generator + discriminator).

    The discriminator path is enabled only after
    ``c.steps_to_start_discriminator`` steps. On rank 0, periodically logs
    iteration stats, saves checkpoints, and logs spectrogram figures plus an
    audio example to Tensorboard.

    Relies on module-level globals: ``c`` (config), ``args``, ``use_cuda``,
    ``num_gpus``, ``OUT_PATH``, ``c_logger``, ``tb_logger``, ``setup_loader``,
    ``format_data``, ``KeepAverage``, ``save_checkpoint``, ``plot_results``,
    ``signature`` (from ``inspect``).

    Returns:
        Tuple[Dict, int]: running-average train stats and the updated global step.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # number of iterations per epoch, accounting for multi-GPU batch scaling
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # for visualization

        # PQMF formatting
        # NOTE(review): a channel dim > 1 is taken to mean multi-band output;
        # sub-band predictions are kept and the full band is re-synthesized.
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            # (2 forward params => D is conditional and takes the features too)
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict['G_loss']

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
        optimizer_G.step()
        if scheduler_G is not None:
            scheduler_G.step()

        # convert tensor losses to plain floats for logging
        # NOTE(review): this branch checks `int` only, while the D branch below
        # checks `(int, float)` — confirm whether the asymmetry is deliberate.
        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, int):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass (regenerate with D's conditioning, no G grads)
            with torch.no_grad():
                y_hat = model_G(c_D)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict['D_loss']

            # optimizer discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
            optimizer_D.step()
            if scheduler_D is not None:
                scheduler_D.step()

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr_G": current_lr_G,
                "current_lr_D": current_lr_D
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr_G": current_lr_G,
                    "lr_D": current_lr_D,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model_G, optimizer_G, scheduler_G, model_D,
                                    optimizer_D, scheduler_D, global_step,
                                    epoch, OUT_PATH, model_losses=loss_dict)

                # compute spectrograms
                figures = plot_results(y_hat_vis, y_G, ap, global_step, 'train')
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_train_audios(global_step, {'train/audio': sample_voice},
                                          c.audio["sample_rate"])

        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
    """Run one GAN evaluation epoch (generator + discriminator losses only).

    Mirrors ``train`` but without optimizer steps. On rank 0, logs figures,
    one audio example, and averaged stats to Tensorboard.

    Relies on module-level globals: ``c`` (config), ``args``, ``c_logger``,
    ``tb_logger``, ``setup_loader``, ``format_data``, ``KeepAverage``,
    ``plot_results``, ``signature`` (from ``inspect``).

    Returns:
        Dict: running-average eval stats from ``KeepAverage``.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting (multi-band output re-synthesized to full band)
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:
            # run D with or without cond. features
            # (2 forward params => D is conditional and takes the features too)
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                feats_fake, feats_real = None, None

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)

        # convert tensor losses to plain floats for logging
        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass (no gradients needed during eval)
            with torch.no_grad():
                y_hat = model_G(c_G)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach(), c_G)
                D_out_real = model_D(y_G, c_G)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values['avg_' + key] = value
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    if args.rank == 0:
        # compute spectrograms
        figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    # synthesize a full voice
    # NOTE(review): this sets `return_segments` on the DataLoader itself,
    # while the wavegrad `evaluate` above sets `data_loader.dataset.return_segments`
    # — likely a bug (a DataLoader has no such attribute by default), and the
    # flag is never restored to True. Verify against the loader implementation.
    data_loader.return_segments = False

    return keep_avg.avg_values