Пример #1
0
    def train_log(self,
                  ap: AudioProcessor,
                  batch: Dict,
                  outputs: List,
                  name_prefix: str = "train"):  # pylint: disable=no-self-use,unused-argument
        """Create visualizations and waveform examples for the logger.

        Plots the predicted vs. target waveforms with ``plot_results``, takes the first
        predicted waveform of the batch as an audio sample and adds a plot of the first
        alignment matrix. The results can then be projected onto Tensorboard.

        Args:
            ap (AudioProcessor): audio processor used at training.
            batch (Dict): model inputs used at the previous training step (not used here).
            outputs (List[Dict]): model outputs generated at the previous training step.
                Only ``outputs[0]`` is read; it must provide the ``"model_outputs"``,
                ``"waveform_seg"`` and ``"alignments"`` entries.
            name_prefix (str): prefix passed to ``plot_results`` and used for the audio
                key. Defaults to ``"train"``.

        Returns:
            Tuple[Dict, Dict]: the figures (including an ``"alignment"`` plot) and
            ``{f"{name_prefix}/audio": waveform}`` with the first predicted waveform.
        """
        first_outputs = outputs[0]
        y_hat = first_outputs["model_outputs"]
        y = first_outputs["waveform_seg"]
        figures = plot_results(y_hat, y, ap, name_prefix)
        # first item of the batch with its leading size-1 dim squeezed, as a NumPy array
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name_prefix}/audio": sample_voice}

        # `.detach()` is the supported way to drop autograd history (`.data` is legacy);
        # the matrix is transposed before plotting.
        align_img = first_outputs["alignments"][0].detach().cpu().numpy().T
        # NOTE: the alignment figure key is not prefixed with `name_prefix`
        figures.update({
            "alignment": plot_alignment(align_img, output_fig=False),
        })

        return figures, audios
Пример #2
0
    def _log(self, ap, batch, outputs, name_prefix="train"):  # pylint: disable=unused-argument,no-self-use
        """Build the figures and the audio sample logged after a train/eval step.

        Only ``outputs[0]`` is consulted; ``batch`` is accepted for interface
        compatibility and ignored.

        Returns:
            Tuple[Dict, Dict]: the figures produced by ``plot_results`` extended with an
            ``"alignment"`` plot, and a one-entry dict mapping ``"<name_prefix>/audio"``
            to the first predicted waveform as a NumPy array.
        """
        step_out = outputs[0]
        predicted = step_out["model_outputs"]
        target = step_out["waveform_seg"]
        figs = plot_results(predicted, target, ap, name_prefix)

        waveform = predicted[0].squeeze(0).detach().cpu().numpy()
        audio_log = {f"{name_prefix}/audio": waveform}

        # first alignment matrix of the batch, transposed for plotting
        first_alignment = step_out["alignments"][0].data.cpu().numpy().T
        figs.update({"alignment": plot_alignment(first_alignment, output_fig=False)})
        return figs, audio_log
Пример #3
0
    def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
        """Shared logging helper for the training and evaluation loops.

        Args:
            name (str): run name used as the logging prefix, e.g. ``"train"`` or ``"eval"``.
            ap (AudioProcessor): audio processor used in training.
            batch (Dict): batch of the last train/eval step; its ``"waveform"`` entry is
                used as the target signal.
            outputs (Dict): model outputs of the last train/eval step; ``outputs[0]`` must
                hold the ``"model_outputs"`` prediction.

        Returns:
            Tuple[Dict, Dict]: figures from ``plot_results`` and ``{f"{name}/audio": wav}``
            holding the first predicted waveform as a NumPy array.
        """
        prediction = outputs[0]["model_outputs"]
        figures = plot_results(prediction, batch["waveform"], ap, name)
        first_wave = prediction[0].squeeze(0).detach().cpu().numpy()
        return figures, {f"{name}/audio": first_wave}
Пример #4
0
 def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
     """Synthesize the test samples and return figures and audio for the last one.

     The test noise schedule from ``self.config["test_noise_schedule"]`` is applied via
     ``compute_noise_level`` before inference. Inference runs for every sample, but each
     iteration overwrites the previous results, so only the last sample is returned.

     Args:
         ap (AudioProcessor): audio processor forwarded to ``plot_results``.
         samples (List[Dict]): test samples, indexed positionally: ``sample[0]`` is the
             model input (a batch dim is prepended and it is moved to the model device)
             and ``sample[1]`` is the target (a batch dim is prepended).
             NOTE(review): presumably (features, waveform) pairs despite the ``Dict``
             annotation -- confirm.
         ouputs (Dict): unused; the (misspelled) name is kept for caller compatibility.

     Returns:
         Tuple[Dict, Dict]: figures and ``{"test/audio": waveform}`` of the last sample,
         or two empty dicts when ``samples`` is empty.
     """
     # setup noise schedule and inference
     noise_schedule = self.config["test_noise_schedule"]
     betas = np.linspace(noise_schedule["min_val"],
                         noise_schedule["max_val"],
                         noise_schedule["num_steps"])
     self.compute_noise_level(betas)
     # Defaults so an empty `samples` no longer raises UnboundLocalError at the return.
     figures, audios = {}, {}
     for sample in samples:
         x = sample[0][None, :, :].to(next(self.parameters()).device)
         y = sample[1][None, :]
         # compute voice
         y_pred = self.inference(x)
         # compute spectrograms
         figures = plot_results(y_pred, y, ap, "test")
         # sample audio: first item of the batch as a NumPy array
         audios = {"test/audio": y_pred[0].squeeze(0).detach().cpu().numpy()}
     return figures, audios
Пример #5
0
 def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
     """Synthesize test samples from the loader's dataset and return figures and audio.

     The test noise schedule from ``self.config["test_noise_schedule"]`` is applied via
     ``compute_noise_level`` before inference. Samples come from
     ``test_loader.dataset.load_test_samples(1)``. If several samples are returned, each
     iteration overwrites the previous results, so only the last one is returned.

     Args:
         assets (Dict): must contain the ``"audio_processor"`` used by ``plot_results``.
         test_loader (DataLoader): loader whose dataset provides ``load_test_samples``.
         outputs: unused.

     Returns:
         Tuple[Dict, Dict]: figures and ``{"test/audio": waveform}`` of the last sample,
         or two empty dicts when no test samples are available.
     """
     # setup noise schedule and inference
     ap = assets["audio_processor"]
     noise_schedule = self.config["test_noise_schedule"]
     betas = np.linspace(noise_schedule["min_val"],
                         noise_schedule["max_val"],
                         noise_schedule["num_steps"])
     self.compute_noise_level(betas)
     samples = test_loader.dataset.load_test_samples(1)
     # Defaults so an empty sample list no longer raises UnboundLocalError at the return.
     figures, audios = {}, {}
     for sample in samples:
         # sample[0]: model input, sample[1]: target -- a batch dim is prepended to both
         x = sample[0][None, :, :].to(next(self.parameters()).device)
         y = sample[1][None, :]
         # compute voice
         y_pred = self.inference(x)
         # compute spectrograms
         figures = plot_results(y_pred, y, ap, "test")
         # sample audio: first item of the batch as a NumPy array
         audios = {"test/audio": y_pred[0].squeeze(0).detach().cpu().numpy()}
     return figures, audios
Пример #6
0
def evaluate(model, criterion, ap, global_step, epoch):
    """Run one evaluation epoch of the noise-prediction vocoder and log the results.

    For every validation batch, ``compute_y_n`` produces the noise, the noisy input and
    the noise scale; the model predicts the noise and ``criterion`` scores it. On rank 0
    a test sample is then synthesized with ``c['test_noise_schedule']``, and figures,
    audio and the averaged stats are written to TensorBoard.

    Args:
        model: the model, possibly wrapped; ``model.module`` is used for
            ``compute_y_n`` / ``compute_noise_level`` / ``inference`` when present.
        criterion: loss, called as ``criterion(noise, noise_hat)``.
        ap: audio processor passed to the data loader and to ``plot_results``.
        global_step (int): step counter; it is incremented per batch for logging only
            and is not returned.
        epoch (int): current epoch; loader verbosity is enabled on epoch 0.

    Returns:
        dict: ``keep_avg.avg_values``, the running averages of the eval stats.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    # helper methods live on the inner module when the model is wrapped
    core_model = model.module if hasattr(model, 'module') else model
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # compute noisy input
        noise, x_noisy, noise_scale = core_model.compute_y_n(x)

        # forward pass (through the possibly-wrapped model)
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses; plain numbers are kept, tensors are converted with .item()
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss': loss}
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        step_time = time.time() - start_time

        # update avg stats
        update_eval_values = {'avg_' + key: value for key, value in loss_dict.items()}
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # Reset the timer so the next `loader_time` measures only the wait for the next
        # batch; without this it accumulated the time since the epoch started.
        end_time = time.time()

    if args.rank == 0:
        # presumably makes the dataset yield whole utterances instead of training
        # segments -- restored in `finally` even if synthesis or logging fails
        data_loader.dataset.return_segments = False
        try:
            samples = data_loader.dataset.load_test_samples(1)
            m, x = format_test_data(samples[0])

            # setup noise schedule and inference
            noise_schedule = c['test_noise_schedule']
            betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
                                noise_schedule['num_steps'])
            core_model.compute_noise_level(betas)
            # compute voice
            x_pred = core_model.inference(m)

            # compute spectrograms
            figures = plot_results(x_pred, x, ap, global_step, 'eval')
            tb_logger.tb_eval_figures(global_step, figures)

            # Sample audio
            sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
            tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                     c.audio["sample_rate"])

            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
        finally:
            data_loader.dataset.return_segments = True

    return keep_avg.avg_values
Пример #7
0
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
    """Train the generator/discriminator pair for one epoch.

    Every step updates the generator. Its adversarial / feature-matching inputs are only
    computed once ``global_step > c.steps_to_start_discriminator``. The discriminator
    itself is updated from ``global_step >= c.steps_to_start_discriminator``. On rank 0,
    step stats go to TensorBoard every 10 steps. Every ``c.save_step`` steps a checkpoint
    is saved (when ``c.checkpoint``) and train figures/audio are logged.

    Args:
        model_G, criterion_G, optimizer_G, scheduler_G: generator, its loss, optimizer
            and LR scheduler (may be ``None``; otherwise stepped every iteration).
        model_D, criterion_D, optimizer_D, scheduler_D: the same for the discriminator.
        ap: audio processor used by the data loader and plotting.
        global_step (int): step counter at the start of the epoch.
        epoch (int): current epoch; loader verbosity is enabled on epoch 0.

    Returns:
        Tuple[dict, int]: averaged training stats and the updated ``global_step``.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    # D takes conditioning features when its forward() has two parameters; this cannot
    # change during the epoch, so inspect it once instead of on every call.
    d_is_conditional = len(signature(model_D.forward).parameters) == 2
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # for visualization

        # PQMF formatting: multi-band output is synthesized back to a full-band signal
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if d_is_conditional:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                # The real pass must get the same conditioning as the fake pass; calling a
                # conditional D with the waveform only fails unless its second parameter
                # has a default.
                with torch.no_grad():
                    if d_is_conditional:
                        D_out_real = model_D(y_G, c_G)
                    else:
                        D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict['G_loss']

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
        optimizer_G.step()
        if scheduler_G is not None:
            scheduler_G.step()

        # Plain numbers are kept as-is, tensors are converted with .item(). Checking only
        # `int` here (unlike everywhere else) made a float loss value raise AttributeError.
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_G_dict.items()
        }

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            with torch.no_grad():
                y_hat = model_G(c_D)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if d_is_conditional:
                D_out_fake = model_D(y_hat.detach(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict['D_loss']

            # optimizer discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                               c.disc_clip_grad)
            optimizer_D.step()
            if scheduler_D is not None:
                scheduler_D.step()

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates (param_groups is already a list)
        current_lr_G = optimizer_G.param_groups[0]['lr']
        current_lr_D = optimizer_D.param_groups[0]['lr']

        # update avg stats
        update_train_values = {'avg_' + key: value for key, value in loss_dict.items()}
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr_G": current_lr_G,
                "current_lr_D": current_lr_D
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr_G": current_lr_G,
                    "lr_D": current_lr_D,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model_G,
                                    optimizer_G,
                                    scheduler_G,
                                    model_D,
                                    optimizer_D,
                                    scheduler_D,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=loss_dict)

                # compute spectrograms
                figures = plot_results(y_hat_vis, y_G, ap, global_step,
                                       'train')
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_train_audios(global_step,
                                          {'train/audio': sample_voice},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    # tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Пример #8
0
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step,
             epoch):
    """Compute generator/discriminator losses on the validation set and log them.

    Mirrors the training step without optimizer updates. Generator adversarial inputs
    are used once ``global_step > c.steps_to_start_discriminator``; discriminator losses
    are computed from ``global_step >= c.steps_to_start_discriminator``. On rank 0 the
    last batch's prediction is plotted and logged as audio, together with the averaged
    stats.

    Args:
        model_G, criterion_G: generator and its loss.
        model_D, criterion_D: discriminator and its loss.
        ap: audio processor used by the data loader and plotting.
        global_step (int): step counter; it is incremented per batch for logging only
            and is not returned.
        epoch (int): current epoch; loader verbosity is enabled on epoch 0.

    Returns:
        dict: ``keep_avg.avg_values``, the running averages of the eval stats.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    keep_avg = KeepAverage()
    # D takes conditioning features when its forward() has two parameters; inspect once.
    d_is_conditional = len(signature(model_D.forward).parameters) == 2
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting: multi-band output is synthesized back to a full-band signal
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            if d_is_conditional:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                # The real pass must get the same conditioning as the fake pass; calling a
                # conditional D with the waveform only fails unless its second parameter
                # has a default.
                with torch.no_grad():
                    if d_is_conditional:
                        D_out_real = model_D(y_G, c_G)
                    else:
                        D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                feats_fake, feats_real = None, None

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)

        # plain numbers are kept as-is, tensors are converted with .item()
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_G_dict.items()
        }

        ##############################
        # DISCRIMINATOR
        ##############################

        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            with torch.no_grad():
                y_hat = model_G(c_G)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if d_is_conditional:
                D_out_fake = model_D(y_hat.detach(), c_G)
                D_out_real = model_D(y_G, c_G)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time

        # update avg stats
        update_eval_values = {'avg_' + key: value for key, value in loss_dict.items()}
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # Reset the timer so the next `loader_time` measures only the wait for the next
        # batch; without this it accumulated the time since the epoch started.
        end_time = time.time()

    if args.rank == 0:
        # compute spectrograms of the last evaluated batch
        figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    # NOTE: a former trailing `data_loader.return_segments = False` was removed. It set an
    # attribute on this local DataLoader (not on its dataset) right before returning, so it
    # had no effect.
    return keep_avg.avg_values