Beispiel #1
0
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
          epoch):
    """Run one training epoch of the Glow-TTS model.

    Iterates over ``data_loader``, runs the forward pass (under mixed-precision
    autocast when ``c.mixed_precision`` is set), back-propagates the loss,
    clips gradients to ``c.grad_clip``, steps the optimizer (and the scheduler
    when ``c.noam_schedule`` is set) and logs running statistics. On rank 0 it
    periodically writes TensorBoard stats, saves checkpoints and plots
    diagnostic predictions.

    Relies on module-level globals: ``c`` (config), ``use_cuda``,
    ``num_gpus``, ``args``, ``c_logger``, ``tb_logger``, ``OUT_PATH`` and
    ``model_characters``.

    Args:
        data_loader: iterable of training batches consumed by ``format_data``.
        model: model to train; may be a wrapper exposing ``.module``.
        criterion: loss callable returning a dict that contains at least
            ``'loss'``, ``'log_mle'`` and ``'loss_dur'``.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: LR scheduler, stepped only when ``c.noam_schedule`` is set.
        ap: audio processor used for plotting and spectrogram inversion.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (stored in checkpoints).

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)`` with the running
        averages of all logged values and the updated step counter.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # number of batches one process iterates per epoch (used for progress logs)
    if use_cuda:
        batch_n_iter = len(data_loader.dataset) // (c.batch_size * num_gpus)
    else:
        batch_n_iter = len(data_loader.dataset) // c.batch_size
    end_time = time.time()
    c_logger.print_train_start()
    # NOTE: a new GradScaler is built on every call, so the loss scale starts
    # from its default value at the beginning of each epoch.
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
            avg_text_length, avg_spec_length, attn_mask, _ = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                attn_mask,
                g=speaker_c)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss_dict['loss']).backward()
            # unscale first so gradient clipping sees the true gradient values
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict['loss'].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            optimizer.step()

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            for key in ('log_mle', 'loss_dur', 'loss'):
                loss_dict[key] = reduce_tensor(loss_dict[key].data, num_gpus)

        # detach loss values: tensors -> Python numbers for logging
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            'avg_' + key: value
            for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model,
                                    optimizer,
                                    global_step,
                                    epoch,
                                    1,
                                    OUT_PATH,
                                    model_characters,
                                    model_loss=loss_dict['loss'])

                # wait for all kernels to complete; only meaningful on GPU and
                # torch.cuda.synchronize() can fail when CUDA is unavailable
                if use_cuda:
                    torch.cuda.synchronize()

                # Diagnostic visualizations
                # direct pass on model for spec predictions
                target_speaker = None if speaker_c is None else speaker_c[:1]
                # unwrap a container model (one exposing `.module`) if present
                inference_model = model.module if hasattr(model, 'module') else model
                # diagnostics only: no autograd graph is needed here
                with torch.no_grad():
                    spec_pred, *_ = inference_model.inference(
                        text_input[:1], text_lengths[:1], g=target_speaker)

                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Beispiel #2
0
@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Evaluate the Glow-TTS model and synthesize test sentences.

    Runs without autograd. For every evaluation batch the loss and alignment
    error are computed and averaged. On rank 0, diagnostic plots and audio of
    the last batch are logged, provided at least one batch was seen. Once
    ``epoch >= c.test_delay_epochs``, rank 0 also synthesizes the test
    sentences, saves them under ``AUDIO_PATH/<global_step>`` and logs them.

    Relies on module-level globals: ``c``, ``use_cuda``, ``num_gpus``,
    ``args``, ``c_logger``, ``tb_logger``, ``AUDIO_PATH`` and
    ``speaker_mapping``.

    Args:
        data_loader: iterable of evaluation batches, or ``None`` to skip the
            evaluation loop.
        model: model to evaluate; may be a wrapper exposing ``.module``.
        criterion: loss callable returning a dict containing at least
            ``'loss'``, ``'log_mle'`` and ``'loss_dur'``.
        ap: audio processor used for plotting, inversion and saving wavs.
        global_step: current training step (used as the logging step).
        epoch: current epoch index (gates test-sentence synthesis).

    Returns:
        dict: running averages of the evaluation values.
    """
    model.eval()
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        seen_batch = False
        for num_iter, data in enumerate(data_loader):
            # format data
            text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
                _, _, attn_mask, _ = format_data(data)

            # forward pass model
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                attn_mask,
                g=speaker_c)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                for key in ('log_mle', 'loss_dur', 'loss'):
                    loss_dict[key] = reduce_tensor(loss_dict[key].data,
                                                   num_gpus)

            # detach loss values: tensors -> Python numbers for logging
            loss_dict = {
                key: value if isinstance(value, (int, float)) else value.item()
                for key, value in loss_dict.items()
            }

            # update avg stats
            keep_avg.update_values(
                {'avg_' + key: value for key, value in loss_dict.items()})

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)
            seen_batch = True

        # the diagnostics below use the last batch; skip them if there was none
        if args.rank == 0 and seen_batch:
            # Diagnostic visualizations
            # direct pass on model for spec predictions
            target_speaker = None if speaker_c is None else speaker_c[:1]
            # unwrap a container model (one exposing `.module`) if present
            inference_model = model.module if hasattr(model, 'module') else model
            spec_pred, *_ = inference_model.inference(text_input[:1],
                                                      text_lengths[:1],
                                                      g=target_speaker)
            spec_pred = spec_pred.permute(0, 2, 1)
            gt_spec = mel_input.permute(0, 2, 1)

            const_spec = spec_pred[0].data.cpu().numpy()
            gt_spec = gt_spec[0].data.cpu().numpy()
            align_img = alignments[0].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec, ap),
                "ground_truth": plot_spectrogram(gt_spec, ap),
                "alignment": plot_alignment(align_img)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963."
            ]
        else:
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                # skip blank lines (e.g. a trailing newline at end of file)
                test_sentences = [s.strip() for s in f if s.strip()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = None
        speaker_embedding = None
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                # pick a random speaker; randrange(n) covers all n entries
                # (randrange(n - 1) skipped the last one and failed for n == 1)
                speaker_key = list(speaker_mapping.keys())[randrange(
                    len(speaker_mapping))]
                speaker_embedding = speaker_mapping[speaker_key]['embedding']
            else:
                speaker_id = 0

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:  # pylint: disable=broad-except
                # best effort: one failing sentence must not abort evaluation,
                # but Ctrl-C / SystemExit still propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Beispiel #3
0
@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Evaluate the Tacotron model and synthesize test sentences.

    Runs without autograd. For every evaluation batch the Tacotron losses and
    alignment error are computed and averaged. On rank 0, diagnostic plots
    and audio of a random sample from the last batch are logged, provided at
    least one batch was seen. Once ``epoch > c.test_delay_epochs``, rank 0
    also synthesizes the test sentences, saves them under
    ``AUDIO_PATH/<global_step>`` and logs them.

    Relies on module-level globals: ``c``, ``use_cuda``, ``num_gpus``,
    ``args``, ``c_logger``, ``tb_logger``, ``AUDIO_PATH`` and
    ``speaker_mapping``.

    Args:
        data_loader: iterable of evaluation batches, or ``None`` to skip the
            evaluation loop.
        model: Tacotron-style model exposing ``decoder.r``.
        criterion: loss callable returning a dict with ``'loss'``,
            ``'postnet_loss'``, ``'decoder_loss'`` and, with ``c.stopnet``,
            ``'stopnet_loss'``.
        ap: audio processor used for plotting, inversion and saving wavs.
        global_step: current training step (used as the logging step).
        epoch: current epoch index (gates test-sentence synthesis).

    Returns:
        dict: running averages of the evaluation values.
    """
    model.eval()
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        seen_batch = False
        for num_iter, data in enumerate(data_loader):
            # format data
            (
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                linear_input,
                stop_targets,
                speaker_ids,
                speaker_embeddings,
                _,
                _,
            ) = format_data(data)
            assert mel_input.shape[1] % model.decoder.r == 0

            # forward pass model; mel_lengths is passed exactly as in train()
            # so outputs are produced under the same conditions as training
            if c.bidirectional_decoder or c.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(text_input,
                          text_lengths,
                          mel_input,
                          mel_lengths,
                          speaker_ids=speaker_ids,
                          speaker_embeddings=speaker_embeddings)
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings)
                decoder_backward_output = None
                alignments_backward = None

            # set the alignment lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes (same keys as in train())
            if num_gpus > 1:
                loss_dict["postnet_loss"] = reduce_tensor(
                    loss_dict["postnet_loss"].data, num_gpus)
                loss_dict["decoder_loss"] = reduce_tensor(
                    loss_dict["decoder_loss"].data, num_gpus)
                loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data,
                                                  num_gpus)
                if c.stopnet:
                    loss_dict["stopnet_loss"] = reduce_tensor(
                        loss_dict["stopnet_loss"].data, num_gpus)

            # detach loss values: tensors -> Python numbers for logging
            loss_dict = {
                key: value if isinstance(value, (int, float)) else value.item()
                for key, value in loss_dict.items()
            }

            # update avg stats
            keep_avg.update_values(
                {"avg_" + key: value for key, value in loss_dict.items()})

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)
            seen_batch = True

        # the diagnostics below use the last batch; skip them if there was none
        if args.rank == 0 and seen_batch:
            # Diagnostic visualizations
            idx = np.random.randint(mel_input.shape[0])
            const_spec = postnet_output[idx].data.cpu().numpy()
            gt_spec = (linear_input[idx].data.cpu().numpy() if c.model in [
                "Tacotron", "TacotronGST"
            ] else mel_input[idx].data.cpu().numpy())
            align_img = alignments[idx].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec,
                                               ap,
                                               output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap,
                                                 output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False),
            }

            # Sample audio
            if c.model in ["Tacotron", "TacotronGST"]:
                eval_audio = ap.inv_spectrogram(const_spec.T)
            else:
                eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            if c.bidirectional_decoder or c.double_decoder_consistency:
                align_b_img = alignments_backward[idx].data.cpu().numpy()
                eval_figures["alignment2"] = plot_alignment(align_b_img,
                                                            output_fig=False)
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]
        else:
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                # skip blank lines (e.g. a trailing newline at end of file)
                test_sentences = [s.strip() for s in f if s.strip()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        if c.use_external_speaker_embedding_file and c.use_speaker_embedding:
            # pick a random speaker; randrange(n) covers all n entries
            # (randrange(n - 1) skipped the last one and failed for n == 1)
            speaker_key = list(speaker_mapping.keys())[randrange(
                len(speaker_mapping))]
            speaker_embedding = speaker_mapping[speaker_key]["embedding"]
        else:
            speaker_embedding = None
        style_wav = c.get("gst_style_input")
        if style_wav is None and c.use_gst:
            # initialize GST with a zero dict (one entry per style token).
            # This value must reach synthesis(); it was previously overwritten
            # by a second c.get("gst_style_input") call.
            print(
                "WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
            )
            style_wav = {str(i): 0 for i in range(c.gst["gst_style_tokens"])}
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(
                    postnet_output, ap, output_fig=False)
                test_figures["{}-alignment".format(idx)] = plot_alignment(
                    alignment, output_fig=False)
            except Exception:  # pylint: disable=broad-except
                # best effort: one failing sentence must not abort evaluation,
                # but Ctrl-C / SystemExit still propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Beispiel #4
0
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch, scaler, scaler_st):
    """Run one training epoch of a Tacotron-style model.

    For each batch: optionally step the Noam scheduler, run the forward pass
    (under autocast when ``c.mixed_precision`` is set), compute the Tacotron
    losses, and raise on NaN. Then step the main optimizer and, when
    ``c.separate_stopnet`` is set, the separate stopnet optimizer, before
    logging stats. On rank 0 it periodically writes TensorBoard stats, saves
    checkpoints and plots diagnostics.

    Relies on module-level globals: ``c``, ``use_cuda``, ``num_gpus``,
    ``args``, ``c_logger``, ``tb_logger``, ``OUT_PATH`` and
    ``model_characters``.

    Args:
        data_loader: iterable of training batches consumed by ``format_data``.
        model: Tacotron-style model exposing ``decoder.r`` and
            ``decoder.stopnet``.
        criterion: loss callable returning a dict with ``'loss'``,
            ``'postnet_loss'``, ``'decoder_loss'`` and ``'stopnet_loss'``.
        optimizer: optimizer for the main model.
        optimizer_st: optimizer for the stopnet (may be falsy when unused).
        scheduler: LR scheduler, stepped only when ``c.noam_schedule`` is set.
        ap: audio processor used for plotting and spectrogram inversion.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (stored in checkpoints).
        scaler: GradScaler for the main loss (used with mixed precision).
        scaler_st: GradScaler for the stopnet loss (used with mixed precision
            and ``c.separate_stopnet``).

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``.

    Raises:
        RuntimeError: if the total loss becomes NaN.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # number of batches one process iterates per epoch (used for progress logs)
    if use_cuda:
        batch_n_iter = len(data_loader.dataset) // (c.batch_size * num_gpus)
    else:
        batch_n_iter = len(data_loader.dataset) // c.batch_size
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            max_text_length,
            max_spec_length,
        ) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # forward pass model
            if c.bidirectional_decoder or c.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the [alignment] lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        # optimizer step
        if c.mixed_precision:
            # model optimizer step in mixed precision mode
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model,
                                        c.grad_clip,
                                        ignore_stopnet=True)
            scaler.step(optimizer)
            scaler.update()

            # stopnet optimizer step: every call must go through scaler_st and
            # optimizer_st (the loss was scaled by scaler_st, and the main
            # optimizer has already been stepped above)
            if c.separate_stopnet:
                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                scaler_st.unscale_(optimizer_st)
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                scaler_st.step(optimizer_st)
                scaler_st.update()
            else:
                grad_norm_st = 0
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model,
                                        c.grad_clip,
                                        ignore_stopnet=True)
            optimizer.step()

            # stopnet optimizer step
            if c.separate_stopnet:
                loss_dict["stopnet_loss"].backward()
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                optimizer_st.step()
            else:
                grad_norm_st = 0

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["postnet_loss"] = reduce_tensor(
                loss_dict["postnet_loss"].data, num_gpus)
            loss_dict["decoder_loss"] = reduce_tensor(
                loss_dict["decoder_loss"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
            loss_dict["stopnet_loss"] = (reduce_tensor(
                loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else
                                         loss_dict["stopnet_loss"])

        # detach loss values: tensors -> Python numbers for logging
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            "avg_" + key: value
            for key, value in loss_dict.items()
        }
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "max_spec_length": [max_spec_length, 1],  # value, precision
                "max_text_length": [max_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict["postnet_loss"],
                        characters=model_characters,
                        scaler=scaler.state_dict()
                        if c.mixed_precision else None,
                    )

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = (linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy())
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction":
                    plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth":
                    plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment":
                    plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {"TrainAudio": train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Beispiel #5
0
@torch.no_grad()  # evaluation never back-propagates; skip building autograd graphs
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Run one evaluation pass and synthesize the test sentences.

    Iterates ``data_loader`` (when given) with the model in eval mode,
    computes the criterion losses plus an alignment error for each batch,
    averages them with ``KeepAverage`` and, on rank 0, logs spectrogram /
    alignment figures, an audio sample and the averaged stats for a random
    item of the last batch.  From ``c.test_delay_epochs`` on, every
    ``c.test_every_epochs`` epochs, rank 0 also synthesizes the test
    sentences, saves them under ``AUDIO_PATH/<global_step>`` and logs them.

    Args:
        data_loader: evaluation loader, or ``None`` to skip the loss pass.
        model: model whose ``forward`` returns
            ``(decoder_output, dur_output, alignments)``.
        criterion: loss module returning a dict of losses; ``loss``,
            ``loss_l1``, ``loss_ssim`` and ``loss_dur`` are read when
            aggregating across GPUs.
        ap: audio processor used for plotting and mel-spectrogram inversion.
        global_step: current training step, used as the logging step.
        epoch: current epoch; gates test-sentence synthesis.

    Returns:
        The averaged evaluation values (``KeepAverage.avg_values``).
    """
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    # Tensors of the most recent batch. They stay None when the loader yields
    # no batches, so the rank-0 diagnostics below are skipped instead of
    # failing with NameError.
    mel_targets = decoder_output = alignments = None
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
                _, _, _, dur_target, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, alignments = model.forward(
                    text_input,
                    text_lengths,
                    mel_lengths,
                    dur_target,
                    g=speaker_c)

                # compute loss (duration targets are passed as log(1 + d))
                loss_dict = criterion(decoder_output, mel_targets,
                                      mel_lengths, dur_output,
                                      torch.log(1 + dur_target), text_lengths)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score (the lower the error, the better)
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                for key in ('loss_l1', 'loss_ssim', 'loss_dur', 'loss'):
                    loss_dict[key] = reduce_tensor(loss_dict[key].data,
                                                   num_gpus)

            # detach loss values: tensors become plain Python numbers
            loss_dict = {
                key: value if isinstance(value, (int, float)) else value.item()
                for key, value in loss_dict.items()
            }

            # update avg stats
            keep_avg.update_values(
                {'avg_' + key: value for key, value in loss_dict.items()})

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)

        if args.rank == 0 and decoder_output is not None:
            # Diagnostic visualizations for a random item of the last batch
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap,
                                               output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap,
                                                 output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs and epoch % c.test_every_epochs == 0:
        if c.test_sentences_file is None:
            test_sentences = [
                "ජනක ප්‍රදීප් ලියනගේ.",
                "රගර් ගැහුවා කියල කොහොමද බූරුවො වොලි බෝල් නැති වෙන්නෙ.",
                "රට්ඨපාල කුමරු ගිහිගෙය හැර පැවිදි වී සිටියි.",
                "අජාසත් රජතුමාගේ ඇත් සේනාවේ අති භයානක ඇතෙක් සිටියා."
            ]
        else:
            # Explicit UTF-8 so non-ASCII sentences load independently of the
            # platform's default locale encoding.
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                # Pick a random speaker. randrange's upper bound is exclusive,
                # so len(...) covers every speaker ("len - 1" skipped the last
                # one and raised ValueError for a single-speaker mapping).
                speaker_embedding = speaker_mapping[list(
                    speaker_mapping.keys())[randrange(
                        len(speaker_mapping))]]['embedding']
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            # Best effort: one failing sentence must not abort evaluation, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            except Exception:  # pylint: disable=broad-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Beispiel #6
0
@torch.no_grad()  # evaluation never back-propagates; skip building autograd graphs
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Run one evaluation pass and synthesize the test sentences.

    Iterates ``data_loader`` (when given) with the model in eval mode,
    computes the criterion losses plus an alignment error for each batch,
    averages them with ``KeepAverage`` and, on rank 0, logs spectrogram /
    alignment figures, an audio sample and the averaged stats for a random
    item of the last batch.  From ``c.test_delay_epochs`` on, rank 0 also
    synthesizes the test sentences, saves them under
    ``AUDIO_PATH/<global_step>`` and logs them.

    Args:
        data_loader: evaluation loader, or ``None`` to skip the loss pass.
        model: model whose ``forward`` returns
            ``(decoder_output, dur_output, alignments)``.
        criterion: loss module returning a dict of losses; ``loss``,
            ``loss_l1``, ``loss_ssim`` and ``loss_dur`` are read when
            aggregating across GPUs.
        ap: audio processor used for plotting and mel-spectrogram inversion.
        global_step: current training step, used as the logging step.
        epoch: current epoch; gates test-sentence synthesis.

    Returns:
        The averaged evaluation values (``KeepAverage.avg_values``).
    """
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    # Tensors of the most recent batch. They stay None when the loader yields
    # no batches, so the rank-0 diagnostics below are skipped instead of
    # failing with NameError.
    mel_targets = decoder_output = alignments = None
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, alignments = model.forward(
                    text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
                )

                # compute loss (duration targets are passed as log(1 + d))
                loss_dict = criterion(
                    decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
                )

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score (the lower the error, the better)
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                for key in ("loss_l1", "loss_ssim", "loss_dur", "loss"):
                    loss_dict[key] = reduce_tensor(loss_dict[key].data, num_gpus)

            # detach loss values: tensors become plain Python numbers
            loss_dict = {
                key: value if isinstance(value, (int, float)) else value.item() for key, value in loss_dict.items()
            }

            # update avg stats
            keep_avg.update_values({"avg_" + key: value for key, value in loss_dict.items()})

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0 and decoder_output is not None:
            # Diagnostic visualizations for a random item of the last batch
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False),
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]
        else:
            # Explicit UTF-8 so the file loads independently of the platform's
            # default locale encoding.
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                # Pick a random speaker. randrange's upper bound is exclusive,
                # so len(...) covers every speaker ("len - 1" skipped the last
                # one and raised ValueError for a single-speaker mapping).
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]][
                    "embedding"
                ]
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
            # Best effort: one failing sentence must not abort evaluation, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            except Exception:  # pylint: disable=broad-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
          epoch):
    """Run one training epoch over ``data_loader``.

    For every batch: forward pass (optionally under AMP autocast), loss,
    backward with gradient clipping to ``config.grad_clip`` (through a
    ``GradScaler`` when ``config.mixed_precision`` is set), optimizer step and,
    with ``config.noam_schedule``, a scheduler step.  Losses are averaged
    across GPUs when ``num_gpus > 1``, accumulated in ``KeepAverage`` and
    printed every ``config.print_step`` steps.  On rank 0 it logs iteration
    stats every ``config.tb_plot_step`` steps and, every ``config.save_step``
    steps, optionally saves a checkpoint and logs figures and an audio sample.

    Args:
        data_loader: training loader.
        model: model whose ``forward`` returns
            ``(decoder_output, dur_output, alignments)``.
        criterion: loss module returning a dict of losses including ``loss``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, stepped per batch only with
            ``config.noam_schedule``.
        ap: audio processor used for plotting and mel-spectrogram inversion.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (logged and stored in checkpoints).

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)`` with the updated step.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # approximate number of iterations per epoch, used for progress printing
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (config.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_targets,
            mel_lengths,
            speaker_c,
            avg_text_length,
            avg_spec_length,
            _,
            dur_target,
            _,
        ) = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            decoder_output, dur_output, alignments = model.forward(
                text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)

            # compute loss (duration targets are passed as log(1 + d))
            loss_dict = criterion(decoder_output, mel_targets,
                                  mel_lengths, dur_output,
                                  torch.log(1 + dur_target), text_lengths)

        # backward pass with loss scaling
        if config.mixed_precision:
            scaler.scale(loss_dict["loss"]).backward()
            # unscale first so clipping sees the true gradient magnitudes
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       config.grad_clip)
            optimizer.step()

        # setup lr
        if config.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]["lr"]

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            for key in ("loss_l1", "loss_ssim", "loss_dur", "loss"):
                loss_dict[key] = reduce_tensor(loss_dict[key].data, num_gpus)

        # detach loss values: tensors become plain Python numbers
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            "avg_" + key: value for key, value in loss_dict.items()
        }
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % config.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % config.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % config.save_step == 0:
                if config.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        1,
                        OUT_PATH,
                        model_characters,
                        model_loss=loss_dict["loss"],
                    )

                # wait for all kernels to complete; only meaningful (and only
                # valid) when training on CUDA, so CPU runs no longer crash here
                if use_cuda:
                    torch.cuda.synchronize()

                # Diagnostic visualizations for a random item of the batch
                idx = np.random.randint(mel_targets.shape[0])
                pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
                gt_spec = mel_targets[idx].data.cpu().numpy().T
                align_img = alignments[idx].data.cpu()

                figures = {
                    "prediction": plot_spectrogram(pred_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(pred_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {"TrainAudio": train_audio},
                                          config.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if config.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step