Example #1
0
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Run validation over ``data_loader`` and synthesize the test sentences.

    For every validation batch the criterion losses and the alignment error
    (``1 - alignment_diagonal_score``) are computed, converted to Python
    numbers and accumulated as running averages (keys prefixed ``avg_``).
    On rank 0, diagnostic figures and a sample audio of one random item of the
    last batch are sent to TensorBoard. Once ``epoch`` exceeds
    ``c.test_delay_epochs`` the test sentences are synthesized, saved under
    ``AUDIO_PATH/<global_step>`` and logged as audio and plots.

    Everything runs under ``torch.no_grad()``: evaluation never
    back-propagates, so building the autograd graph would only waste memory.

    Args:
        data_loader: validation data loader, or ``None`` to skip the
            validation loop.
        model: TTS model to evaluate (switched to eval mode).
        criterion: loss module returning a dict of losses.
        ap: audio processor used for plotting and spectrogram inversion.
        global_step: current training step, used as the TensorBoard step and
            as the name of the test-audio output directory.
        epoch: current epoch; test sentences are synthesized only when it is
            greater than ``c.test_delay_epochs``.

    Returns:
        dict: ``keep_avg.avg_values``, the running averages of the logged
        values.
    """
    import torch  # local import: needed for the no_grad() context below

    # Evaluation never back-propagates; disabling autograd avoids building and
    # holding on to the computation graph of every batch.
    with torch.no_grad():
        model.eval()
        keep_avg = KeepAverage()
        c_logger.print_eval_start()
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
                # format data
                (
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    linear_input,
                    stop_targets,
                    speaker_ids,
                    speaker_embeddings,
                    _,
                    _,
                ) = format_data(data)
                # the padded spectrogram length must be divisible by the
                # decoder reduction factor r
                assert mel_input.shape[1] % model.decoder.r == 0

                # forward pass model
                if c.bidirectional_decoder or c.double_decoder_consistency:
                    (
                        decoder_output,
                        postnet_output,
                        alignments,
                        stop_tokens,
                        decoder_backward_output,
                        alignments_backward,
                    ) = model(text_input,
                              text_lengths,
                              mel_input,
                              speaker_ids=speaker_ids,
                              speaker_embeddings=speaker_embeddings)
                else:
                    decoder_output, postnet_output, alignments, stop_tokens = model(
                        text_input,
                        text_lengths,
                        mel_input,
                        speaker_ids=speaker_ids,
                        speaker_embeddings=speaker_embeddings)
                    decoder_backward_output = None
                    alignments_backward = None

                # set the alignment lengths wrt reduction factor for guided attention
                if mel_lengths.max() % model.decoder.r != 0:
                    alignment_lengths = (
                        mel_lengths +
                        (model.decoder.r -
                         (mel_lengths.max() % model.decoder.r))) // model.decoder.r
                else:
                    alignment_lengths = mel_lengths // model.decoder.r

                # compute loss
                loss_dict = criterion(
                    postnet_output,
                    decoder_output,
                    mel_input,
                    linear_input,
                    stop_tokens,
                    stop_targets,
                    mel_lengths,
                    decoder_backward_output,
                    alignments,
                    alignment_lengths,
                    alignments_backward,
                    text_lengths,
                )

                # compute alignment error (the lower the better)
                align_error = 1 - alignment_diagonal_score(alignments)
                loss_dict["align_error"] = align_error

                # aggregate losses from processes
                if num_gpus > 1:
                    loss_dict["postnet_loss"] = reduce_tensor(
                        loss_dict["postnet_loss"].data, num_gpus)
                    loss_dict["decoder_loss"] = reduce_tensor(
                        loss_dict["decoder_loss"].data, num_gpus)
                    if c.stopnet:
                        loss_dict["stopnet_loss"] = reduce_tensor(
                            loss_dict["stopnet_loss"].data, num_gpus)

                # detach loss values: tensors -> Python numbers
                loss_dict = {
                    key: value if isinstance(value, (int, float)) else value.item()
                    for key, value in loss_dict.items()
                }

                # update avg stats
                keep_avg.update_values(
                    {"avg_" + key: value for key, value in loss_dict.items()})

                if c.print_eval:
                    c_logger.print_eval_step(num_iter, loss_dict,
                                             keep_avg.avg_values)

            if args.rank == 0:
                # Diagnostic visualizations of a random item of the last batch
                idx = np.random.randint(mel_input.shape[0])
                const_spec = postnet_output[idx].data.cpu().numpy()
                gt_spec = (linear_input[idx].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[idx].data.cpu().numpy())
                align_img = alignments[idx].data.cpu().numpy()

                eval_figures = {
                    "prediction": plot_spectrogram(const_spec,
                                                   ap,
                                                   output_fig=False),
                    "ground_truth": plot_spectrogram(gt_spec, ap,
                                                     output_fig=False),
                    "alignment": plot_alignment(align_img, output_fig=False),
                }

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    eval_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    eval_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                         c.audio["sample_rate"])

                # Plot Validation Stats
                if c.bidirectional_decoder or c.double_decoder_consistency:
                    align_b_img = alignments_backward[idx].data.cpu().numpy()
                    eval_figures["alignment2"] = plot_alignment(align_b_img,
                                                                output_fig=False)
                tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
                tb_logger.tb_eval_figures(global_step, eval_figures)

        if args.rank == 0 and epoch > c.test_delay_epochs:
            if c.test_sentences_file is None:
                test_sentences = [
                    "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                    "Be a voice, not an echo.",
                    "I'm sorry Dave. I'm afraid I can't do that.",
                    "This cake is great. It's so delicious and moist.",
                    "Prior to November 22, 1963.",
                ]
            else:
                with open(c.test_sentences_file, "r") as f:
                    test_sentences = [s.strip() for s in f.readlines()]

            # test sentences
            test_audios = {}
            test_figures = {}
            print(" | > Synthesizing test sentences")
            speaker_id = 0 if c.use_speaker_embedding else None
            if c.use_external_speaker_embedding_file and c.use_speaker_embedding:
                # pick a random speaker; randrange's upper bound is exclusive,
                # so every speaker (including the last one) can be selected
                speaker_names = list(speaker_mapping.keys())
                speaker_embedding = speaker_mapping[speaker_names[randrange(
                    len(speaker_names))]]["embedding"]
            else:
                speaker_embedding = None
            style_wav = c.get("gst_style_input")
            if style_wav is None and c.use_gst:
                # initialize GST with a zero dict; this fallback is what gets
                # passed to synthesis(), so the config is not re-read below
                style_wav = {}
                print(
                    "WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
                )
                for i in range(c.gst["gst_style_tokens"]):
                    style_wav[str(i)] = 0
            for idx, test_sentence in enumerate(test_sentences):
                try:
                    wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
                        model,
                        test_sentence,
                        c,
                        use_cuda,
                        ap,
                        speaker_id=speaker_id,
                        speaker_embedding=speaker_embedding,
                        style_wav=style_wav,
                        truncated=False,
                        enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
                        use_griffin_lim=True,
                        do_trim_silence=False,
                    )

                    file_path = os.path.join(AUDIO_PATH, str(global_step))
                    os.makedirs(file_path, exist_ok=True)
                    file_path = os.path.join(file_path,
                                             "TestSentence_{}.wav".format(idx))
                    ap.save_wav(wav, file_path)
                    test_audios["{}-audio".format(idx)] = wav
                    test_figures["{}-prediction".format(idx)] = plot_spectrogram(
                        postnet_output, ap, output_fig=False)
                    test_figures["{}-alignment".format(idx)] = plot_alignment(
                        alignment, output_fig=False)
                except Exception:  # pylint: disable=broad-except
                    # best effort: one failing sentence must not abort the
                    # evaluation, but KeyboardInterrupt/SystemExit propagate
                    print(" !! Error creating Test Sentence -", idx)
                    traceback.print_exc()
            tb_logger.tb_test_audios(global_step, test_audios,
                                     c.audio["sample_rate"])
            tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Example #2
0
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
    """Validate ``model`` on the validation split and synthesize test sentences.

    Builds the validation loader and, under ``torch.no_grad()``, computes for
    every batch the decoder, postnet and stop losses. When
    ``c.bidirectional_decoder`` is set it also computes the backward-decoder
    and decoder-consistency losses. The alignment score is computed too, and
    all of these are kept as running averages. On rank 0, figures, a sample
    audio of one random item of the last batch and the epoch stats are logged
    to TensorBoard. Once ``epoch`` exceeds ``c.test_delay_epochs`` the test
    sentences are synthesized, saved under ``AUDIO_PATH/<global_step>`` and
    logged.

    Args:
        model: TTS model to evaluate (switched to eval mode).
        criterion: spectrogram loss; called with a third ``lengths`` argument
            when ``c.loss_masking`` is enabled.
        criterion_st: stop-token loss, used only when ``c.stopnet`` is set.
        ap: audio processor used for plotting and spectrogram inversion.
        global_step: current training step (TensorBoard step / output dir).
        epoch: current epoch number.

    Returns:
        The running average of the postnet loss (``avg_postnet_loss``).
    """
    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
    if c.use_speaker_embedding:
        # NOTE(review): the mapping is loaded but not used below; test
        # synthesis uses a fixed speaker_id instead.
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.eval()
    epoch_time = 0
    eval_values_dict = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0
    }
    if c.bidirectional_decoder:
        eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
        eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
    keep_avg = KeepAverage()
    keep_avg.add_values(eval_values_dict)
    print("\n > Validation")

    with torch.no_grad():
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
                start_time = time.time()

                # format data
                text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(
                    data)
                assert mel_input.shape[1] % model.decoder.r == 0

                # forward pass model
                if c.bidirectional_decoder:
                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                        text_input,
                        text_lengths,
                        mel_input,
                        speaker_ids=speaker_ids)
                else:
                    decoder_output, postnet_output, alignments, stop_tokens = model(
                        text_input,
                        text_lengths,
                        mel_input,
                        speaker_ids=speaker_ids)

                # loss computation
                if c.stopnet:
                    stop_loss = criterion_st(stop_tokens, stop_targets)
                else:
                    # placeholder on the outputs' device: a 1-element CPU
                    # tensor cannot be added to CUDA losses in ``loss`` below
                    stop_loss = torch.zeros(1, device=decoder_output.device)
                # Tacotron variants predict linear spectrograms with the postnet
                if c.model in ["Tacotron", "TacotronGST"]:
                    postnet_target = linear_input
                else:
                    postnet_target = mel_input
                if c.loss_masking:
                    decoder_loss = criterion(decoder_output, mel_input,
                                             mel_lengths)
                    postnet_loss = criterion(postnet_output, postnet_target,
                                             mel_lengths)
                else:
                    decoder_loss = criterion(decoder_output, mel_input)
                    postnet_loss = criterion(postnet_output, postnet_target)
                loss = decoder_loss + postnet_loss + stop_loss

                # backward decoder loss
                if c.bidirectional_decoder:
                    # the backward decoder runs in reverse time; flip it back
                    decoder_b_output = torch.flip(decoder_backward_output,
                                                  dims=(1, ))
                    if c.loss_masking:
                        decoder_backward_loss = criterion(
                            decoder_b_output, mel_input, mel_lengths)
                    else:
                        decoder_backward_loss = criterion(
                            decoder_b_output, mel_input)
                    decoder_c_loss = torch.nn.functional.l1_loss(
                        decoder_b_output, decoder_output)
                    loss += decoder_backward_loss + decoder_c_loss
                    keep_avg.update_values({
                        'avg_decoder_b_loss':
                        decoder_backward_loss.item(),
                        'avg_decoder_c_loss':
                        decoder_c_loss.item()
                    })

                step_time = time.time() - start_time
                epoch_time += step_time

                # compute alignment score
                align_score = alignment_diagonal_score(alignments)
                keep_avg.update_value('avg_align_score', align_score)

                # aggregate losses from processes
                if num_gpus > 1:
                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
                    if c.stopnet:
                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)

                keep_avg.update_values({
                    'avg_postnet_loss':
                    float(postnet_loss.item()),
                    'avg_decoder_loss':
                    float(decoder_loss.item()),
                    'avg_stop_loss':
                    float(stop_loss.item()),
                })

                if num_iter % c.print_step == 0:
                    print(
                        "   | > TotalLoss: {:.5f}   PostnetLoss: {:.5f} - {:.5f}  DecoderLoss:{:.5f} - {:.5f} "
                        "StopLoss: {:.5f} - {:.5f}  AlignScore: {:.4f} : {:.4f}"
                        .format(loss.item(), postnet_loss.item(),
                                keep_avg['avg_postnet_loss'],
                                decoder_loss.item(),
                                keep_avg['avg_decoder_loss'], stop_loss.item(),
                                keep_avg['avg_stop_loss'], align_score,
                                keep_avg['avg_align_score']),
                        flush=True)

            if args.rank == 0:
                # Diagnostic visualizations of a random item of the last batch
                idx = np.random.randint(mel_input.shape[0])
                const_spec = postnet_output[idx].data.cpu().numpy()
                gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[idx].data.cpu().numpy()
                align_img = alignments[idx].data.cpu().numpy()

                eval_figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    eval_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    eval_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                         c.audio["sample_rate"])

                # Plot Validation Stats
                epoch_stats = {
                    "loss_postnet": keep_avg['avg_postnet_loss'],
                    "loss_decoder": keep_avg['avg_decoder_loss'],
                    "stop_loss": keep_avg['avg_stop_loss'],
                    "alignment_score": keep_avg['avg_align_score']
                }

                if c.bidirectional_decoder:
                    epoch_stats['loss_decoder_backward'] = keep_avg[
                        'avg_decoder_b_loss']
                    align_b_img = alignments_backward[idx].data.cpu().numpy()
                    eval_figures['alignment_backward'] = plot_alignment(
                        align_b_img)
                tb_logger.tb_eval_stats(global_step, epoch_stats)
                tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist."
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav)
                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:  # pylint: disable=broad-except
                # best effort: one failing sentence must not abort the
                # evaluation, but KeyboardInterrupt/SystemExit propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg['avg_postnet_loss']
Example #3
0
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch, scaler, scaler_st):
    """Train ``model`` for one epoch over ``data_loader``.

    Each batch goes through these steps:

    - Run the forward pass, under ``torch.cuda.amp.autocast`` when
      ``c.mixed_precision`` is set.
    - Compute the criterion losses and abort on a NaN total loss.
    - Step the main optimizer, and the stopnet optimizer when
      ``c.separate_stopnet`` is set.
    - Log progress.

    On rank 0 it also writes TensorBoard stats, checkpoints and diagnostic
    figures/audio.

    Args:
        data_loader: training data loader.
        model: TTS model to train (switched to train mode).
        criterion: loss module returning a dict of losses; ``loss`` is the
            total that is back-propagated.
        optimizer: optimizer of the main model.
        optimizer_st: optimizer of the stopnet, used when
            ``c.separate_stopnet`` is set (may be falsy otherwise).
        scheduler: learning-rate scheduler, stepped every batch when
            ``c.noam_schedule`` is set.
        ap: audio processor used for plotting and spectrogram inversion.
        global_step: step counter at the start of the epoch.
        epoch: current epoch number.
        scaler: grad scaler (``torch.cuda.amp.GradScaler`` API) of the main
            optimizer, used in mixed-precision mode.
        scaler_st: grad scaler of the stopnet optimizer, used in
            mixed-precision mode with ``c.separate_stopnet``.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``, i.e. the running
        averages of the logged values and the updated step counter.

    Raises:
        RuntimeError: if the total loss becomes NaN.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            max_text_length,
            max_spec_length,
        ) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # forward pass model
            if c.bidirectional_decoder or c.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the [alignment] lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        # optimizer step
        if c.mixed_precision:
            # model optimizer step in mixed precision mode; gradients are
            # unscaled before weight decay and clipping in check_update
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model,
                                        c.grad_clip,
                                        ignore_stopnet=True)
            scaler.step(optimizer)
            scaler.update()

            # stopnet optimizer step: the stopnet loss is scaled by
            # ``scaler_st``, so the same scaler must unscale and step it, and
            # it must step ``optimizer_st`` (not the main optimizer)
            if c.separate_stopnet:
                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                scaler_st.unscale_(optimizer_st)
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                scaler_st.step(optimizer_st)
                scaler_st.update()
            else:
                grad_norm_st = 0
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model,
                                        c.grad_clip,
                                        ignore_stopnet=True)
            optimizer.step()

            # stopnet optimizer step
            if c.separate_stopnet:
                loss_dict["stopnet_loss"].backward()
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                optimizer_st.step()
            else:
                grad_norm_st = 0

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["postnet_loss"] = reduce_tensor(
                loss_dict["postnet_loss"].data, num_gpus)
            loss_dict["decoder_loss"] = reduce_tensor(
                loss_dict["decoder_loss"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
            loss_dict["stopnet_loss"] = (reduce_tensor(
                loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else
                                         loss_dict["stopnet_loss"])

        # detach loss values: tensors -> Python numbers
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            "avg_" + key: value
            for key, value in loss_dict.items()
        }
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "max_spec_length": [max_spec_length, 1],  # value, precision
                "max_text_length": [max_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict["postnet_loss"],
                        characters=model_characters,
                        scaler=scaler.state_dict()
                        if c.mixed_precision else None,
                    )

                # Diagnostic visualizations of the first item of the batch
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = (linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy())
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction":
                    plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth":
                    plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment":
                    plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {"TrainAudio": train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Example #4
0
def train(model, criterion, optimizer, optimizer_st, scheduler, ap,
          global_step, epoch):
    """Run one training epoch of a Tacotron-style model.

    Builds the training loader, iterates it once, runs the forward/backward
    passes (plus an optional separate stopnet update), keeps running averages
    of the losses and, on rank 0, logs iteration/epoch stats, checkpoints,
    figures and audio samples.

    Args:
        model: TTS model; ``model.decoder.r`` is used as the reduction factor.
        criterion: callable returning a dict of losses (``'loss'``,
            ``'postnet_loss'``, ``'decoder_loss'``, ``'stopnet_loss'``, ...).
        optimizer: optimizer for the main model parameters.
        optimizer_st: optional optimizer for the separate stopnet.
        scheduler: LR scheduler, stepped every iteration when
            ``c.noam_schedule`` is set.
        ap: audio processor used for plotting and audio reconstruction.
        global_step: global step counter at the start of the epoch.
        epoch: current epoch index.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)`` with the averaged epoch
        stats and the updated global step.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stopnet_loss': 0,
        'avg_align_error': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    if c.ga_alpha > 0:
        train_values['avg_ga_loss'] = 0  # guided attention loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(
            data)
        # time spent waiting on the loader since the previous step finished
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
            decoder_backward_output = None

        # set the alignment lengths wrt reduction factor for guided attention;
        # lengths are shifted so the longest one becomes a multiple of r
        # before dividing
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (
                mel_lengths +
                (model.decoder.r -
                 (mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output, alignments,
                              alignment_lengths, text_lengths)
        if c.bidirectional_decoder:
            keep_avg.update_values({
                'avg_decoder_b_loss':
                loss_dict['decoder_backward_loss'].item(),
                'avg_decoder_c_loss':
                loss_dict['decoder_c_loss'].item()
            })
        if c.ga_alpha > 0:
            keep_avg.update_values(
                {'avg_ga_loss': loss_dict['ga_loss'].item()})

        # backward pass
        loss_dict['loss'].backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
        optimizer.step()

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_error', align_error)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats; the stopnet loss may already be a plain float
        # (hence the isinstance check) and floats have no ``.item()``
        stopnet_loss = loss_dict['stopnet_loss']
        update_train_values = {
            'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
            'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
            'avg_stopnet_loss': stopnet_loss
            if isinstance(stopnet_loss, float) else float(stopnet_loss.item()),
            'avg_step_time': step_time,
            'avg_loader_time': loader_time
        }
        keep_avg.update_values(update_train_values)

        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      avg_spec_length, avg_text_length,
                                      step_time, loader_time, current_lr,
                                      loss_dict, keep_avg.avg_values)

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(
                loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(
                loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(
                loss_dict['stopnet_loss'].data,
                num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": loss_dict['postnet_loss'].item(),
                    "loss_decoder": loss_dict['decoder_loss'].item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict['postnet_loss'].item())

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy())

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stopnet_loss": keep_avg['avg_stopnet_loss'],
            "alignment_score": keep_avg['avg_align_error'],
            "epoch_time": epoch_time
        }
        if c.ga_alpha > 0:
            epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Exemple #5
0
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch):
    """Run one training epoch of a Tacotron-style model (print-logging variant).

    Computes decoder/postnet/stop losses inline (with optional loss masking
    and a backward decoder), updates the model and the optional separate
    stopnet, prints progress and, on rank 0, logs stats, checkpoints, figures
    and audio samples.

    Args:
        model: TTS model; ``model.decoder.r`` is used as the reduction factor.
        criterion: spectrogram loss, called as ``criterion(out, target)`` or
            ``criterion(out, target, lengths)`` when ``c.loss_masking`` is set.
        criterion_st: stop-token loss, ``criterion_st(stop_tokens, targets)``.
        optimizer: optimizer for the main model parameters.
        optimizer_st: optional optimizer for the separate stopnet.
        scheduler: LR scheduler, stepped every iteration when
            ``c.noam_schedule`` is set.
        ap: audio processor used for plotting and audio reconstruction.
        global_step: global step counter at the start of the epoch.
        epoch: current epoch index.

    Returns:
        tuple: ``(avg_postnet_loss, global_step)``.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(
            data)
        # time spent waiting on the loader since the previous step finished
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation; without a stopnet the stop loss is a constant zero
        stop_loss = criterion_st(stop_tokens,
                                 stop_targets) if c.stopnet else torch.zeros(1)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input,
                                         mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input,
                                         mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        # backward decoder: its output is flipped along dim 1 before comparing
        # it with the target and with the forward decoder output
        if c.bidirectional_decoder:
            if c.loss_masking:
                decoder_backward_loss = criterion(
                    torch.flip(decoder_backward_output, dims=(1, )), mel_input,
                    mel_lengths)
            else:
                decoder_backward_loss = criterion(
                    torch.flip(decoder_backward_output, dims=(1, )), mel_input)
            decoder_c_loss = torch.nn.functional.l1_loss(
                torch.flip(decoder_backward_output, dims=(1, )),
                decoder_output)
            loss += decoder_backward_loss + decoder_c_loss
            keep_avg.update_values({
                'avg_decoder_b_loss':
                decoder_backward_loss.item(),
                'avg_decoder_c_loss':
                decoder_c_loss.item()
            })

        loss.backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model,
                                    c.grad_clip,
                                    ignore_stopnet=True)
        optimizer.step()

        # compute alignment score
        align_score = alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_score', align_score)

        # backpass and check the grad norm for stop loss. When the stopnet is
        # disabled ``stop_loss`` is a graph-less ``torch.zeros(1)`` and calling
        # ``backward`` on it would raise, so the update is skipped.
        if c.separate_stopnet and c.stopnet:
            stop_loss.backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f}  AlignScore:{:.4f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  "
                "LoaderTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, postnet_loss.item(),
                    decoder_loss.item(), stop_loss.item(), align_score,
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
                    step_time, loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data,
                                      num_gpus) if c.stopnet else stop_loss

        if args.rank == 0:
            update_train_values = {
                'avg_postnet_loss':
                float(postnet_loss.item()),
                'avg_decoder_loss':
                float(decoder_loss.item()),
                'avg_stop_loss':
                stop_loss
                if isinstance(stop_loss, float) else float(stop_loss.item()),
                'avg_step_time':
                step_time,
                'avg_loader_time':
                loader_time
            }
            keep_avg.update_values(update_train_values)

            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": postnet_loss.item(),
                    "loss_decoder": decoder_loss.item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy())

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    print("   | > EPOCH END -- GlobalStep:{}  "
          "AvgPostnetLoss:{:.5f}  AvgDecoderLoss:{:.5f}  "
          "AvgStopLoss:{:.5f}  AvgAlignScore:{:.3f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}  AvgLoaderTime:{:.2f}".format(
              global_step, keep_avg['avg_postnet_loss'],
              keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
              keep_avg['avg_align_score'], epoch_time,
              keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
          flush=True)
    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stop_loss": keep_avg['avg_stop_loss'],
            "alignment_score": keep_avg['avg_align_score'],
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg['avg_postnet_loss'], global_step
Exemple #6
0
def evaluate(model, criterion, ap, global_step, epoch):
    """Evaluate a WaveGrad-style vocoder on the validation set.

    Computes the noise-prediction loss (reported as ``wavegrad_loss``) over the
    validation loader. On rank 0 it also synthesizes one test sample with the
    configured test noise schedule and logs figures, audio and averaged stats.

    Args:
        model: vocoder model, possibly wrapped in a container that exposes the
            underlying network as ``model.module``.
        criterion: loss called as ``criterion(noise, noise_hat)``.
        ap: audio processor passed to ``plot_results``.
        global_step: current training step; eval logs are written at this step
            (it is not advanced by evaluation).
        epoch: current epoch index (controls loader verbosity).

    Returns:
        dict: averaged evaluation values (``keep_avg.avg_values``).
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    # the helper methods live on the unwrapped network
    net = model.module if hasattr(model, 'module') else model
    end_time = time.time()
    c_logger.print_eval_start()
    # evaluation never back-propagates, so skip building autograd graphs
    with torch.no_grad():
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            m, x = format_data(data)
            # time spent waiting on the loader since the previous step ended
            loader_time = time.time() - end_time

            # compute noisy input
            noise, x_noisy, noise_scale = net.compute_y_n(x)

            # forward pass
            noise_hat = model(x_noisy, m, noise_scale)

            # compute losses
            loss = criterion(noise, noise_hat)
            loss_wavegrad_dict = {'wavegrad_loss': loss}

            # detach loss values to plain Python numbers
            loss_dict = dict()
            for key, value in loss_wavegrad_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

            step_time = time.time() - start_time
            epoch_time += step_time

            # update avg stats
            update_eval_values = dict()
            for key, value in loss_dict.items():
                update_eval_values['avg_' + key] = value
            update_eval_values['avg_loader_time'] = loader_time
            update_eval_values['avg_step_time'] = step_time
            keep_avg.update_values(update_eval_values)

            # print eval stats
            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)

            # restart the loader timer so loader_time stays per-batch
            end_time = time.time()

        if args.rank == 0:
            data_loader.dataset.return_segments = False
            try:
                samples = data_loader.dataset.load_test_samples(1)
                m, x = format_test_data(samples[0])

                # setup noise schedule and inference
                noise_schedule = c['test_noise_schedule']
                betas = np.linspace(noise_schedule['min_val'],
                                    noise_schedule['max_val'],
                                    noise_schedule['num_steps'])
                net.compute_noise_level(betas)
                # compute voice
                x_pred = net.inference(m)

                # compute spectrograms
                figures = plot_results(x_pred, x, ap, global_step, 'eval')
                tb_logger.tb_eval_figures(global_step, figures)

                # Sample audio
                sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_eval_audios(global_step,
                                         {'eval/audio': sample_voice},
                                         c.audio["sample_rate"])

                tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            finally:
                # always restore segment sampling for subsequent training
                data_loader.dataset.return_segments = True

    return keep_avg.avg_values
Exemple #7
0
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
          epoch):
    """Run one training epoch of a flow-based TTS model.

    The model's forward pass returns ``z, logdet, y_mean, y_log_scale,
    alignments, o_dur_log, o_total_dur``. The loop supports optional mixed
    precision via ``torch.cuda.amp``. On rank 0 it logs iteration stats,
    saves checkpoints and, at ``c.save_step`` intervals, runs a diagnostic
    inference to plot spectrograms/alignments and log an audio sample.

    Args:
        data_loader: training data loader.
        model: TTS model, possibly wrapped in a container exposing ``.module``.
        criterion: loss returning a dict containing at least ``'loss'``.
        optimizer: optimizer for the model parameters.
        scheduler: LR scheduler, stepped after each update when
            ``c.noam_schedule`` is set.
        ap: audio processor used for plotting and audio reconstruction.
        global_step: global step counter at the start of the epoch.
        epoch: current epoch index.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
            avg_text_length, avg_spec_length, attn_mask, _ = format_data(data)

        # time spent waiting on the loader since the previous step finished
        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                attn_mask,
                g=speaker_c)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

        # backward pass with loss scaling; gradients are unscaled before
        # clipping so the clip threshold applies to the true gradient norm
        if c.mixed_precision:
            scaler.scale(loss_dict['loss']).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict['loss'].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            optimizer.step()

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data,
                                                 num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data,
                                                  num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

        # detach loss values to plain Python numbers
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model,
                                    optimizer,
                                    global_step,
                                    epoch,
                                    1,
                                    OUT_PATH,
                                    model_characters,
                                    model_loss=loss_dict['loss'])

                # wait for all kernels to complete; torch.cuda.synchronize()
                # is only valid when CUDA is in use
                if use_cuda:
                    torch.cuda.synchronize()

                # Diagnostic visualizations
                # direct pass on model for spec predictions; gradients are not
                # needed for this diagnostic-only inference
                target_speaker = None if speaker_c is None else speaker_c[:1]
                net = model.module if hasattr(model, 'module') else model
                with torch.no_grad():
                    spec_pred, *_ = net.inference(text_input[:1],
                                                  text_lengths[:1],
                                                  g=target_speaker)

                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Evaluate a duration-based TTS model and synthesize test sentences.

    The model's forward pass returns ``decoder_output, dur_output,
    alignments``. Validation losses are averaged over ``data_loader`` (when
    given). On rank 0 it logs a random validation example, and, once
    ``epoch >= c.test_delay_epochs`` and every ``c.test_every_epochs`` epochs,
    it synthesizes test sentences, saves them as wav files and logs them.

    Args:
        data_loader: validation loader, or ``None`` to skip validation.
        model: TTS model.
        criterion: loss returning a dict containing at least ``'loss'``.
        ap: audio processor used for plotting, inversion and saving audio.
        global_step: current training step used for logging.
        epoch: current epoch index.

    Returns:
        dict: averaged evaluation values (``keep_avg.avg_values``).
    """
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        # stays None if the loader yields no batch (diagnostics are skipped)
        decoder_output = None
        # evaluation never back-propagates, so skip building autograd graphs
        with torch.no_grad():
            for num_iter, data in enumerate(data_loader):
                start_time = time.time()

                # format data
                text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
                    _, _, _, dur_target, _ = format_data(data)

                # forward pass model
                with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                    decoder_output, dur_output, alignments = model.forward(
                        text_input,
                        text_lengths,
                        mel_lengths,
                        dur_target,
                        g=speaker_c)

                    # compute loss
                    loss_dict = criterion(decoder_output, mel_targets,
                                          mel_lengths, dur_output,
                                          torch.log(1 + dur_target),
                                          text_lengths)

                # step time
                step_time = time.time() - start_time
                epoch_time += step_time

                # compute alignment score
                align_error = 1 - alignment_diagonal_score(alignments,
                                                           binary=True)
                loss_dict['align_error'] = align_error

                # aggregate losses from processes
                if num_gpus > 1:
                    loss_dict['loss_l1'] = reduce_tensor(
                        loss_dict['loss_l1'].data, num_gpus)
                    loss_dict['loss_ssim'] = reduce_tensor(
                        loss_dict['loss_ssim'].data, num_gpus)
                    loss_dict['loss_dur'] = reduce_tensor(
                        loss_dict['loss_dur'].data, num_gpus)
                    loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
                                                      num_gpus)

                # detach loss values to plain Python numbers
                loss_dict_new = dict()
                for key, value in loss_dict.items():
                    if isinstance(value, (int, float)):
                        loss_dict_new[key] = value
                    else:
                        loss_dict_new[key] = value.item()
                loss_dict = loss_dict_new

                # update avg stats
                update_train_values = dict()
                for key, value in loss_dict.items():
                    update_train_values['avg_' + key] = value
                keep_avg.update_values(update_train_values)

                if c.print_eval:
                    c_logger.print_eval_step(num_iter, loss_dict,
                                             keep_avg.avg_values)

        if args.rank == 0 and decoder_output is not None:
            # Diagnostic visualizations on a random item of the last batch
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap,
                                               output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap,
                                                 output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs and epoch % c.test_every_epochs == 0:
        if c.test_sentences_file is None:
            test_sentences = [
                "ජනක ප්‍රදීප් ලියනගේ.",
                "රගර් ගැහුවා කියල කොහොමද බූරුවො වොලි බෝල් නැති වෙන්නෙ.",
                "රට්ඨපාල කුමරු ගිහිගෙය හැර පැවිදි වී සිටියි.",
                "අජාසත් රජතුමාගේ ඇත් සේනාවේ අති භයානක ඇතෙක් සිටියා."
            ]
        else:
            # read explicitly as UTF-8: test sentences may be non-ASCII and the
            # platform default encoding is not guaranteed to handle them
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                # pick any speaker uniformly; randrange's upper bound is
                # exclusive, so len(...) covers every entry
                speaker_embedding = speaker_mapping[list(
                    speaker_mapping.keys())[randrange(
                        len(speaker_mapping))]]['embedding']
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:  #pylint: disable=broad-except
                # best effort: one failing sentence must not abort evaluation,
                # but KeyboardInterrupt/SystemExit still propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Exemple #9
0
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
    """Run one validation pass and, after the test delay, synthesize test sentences.

    Relies on module-level state defined elsewhere in the script (``c``,
    ``use_cuda``, ``num_gpus``, ``args``, ``tb_logger``, ``OUT_PATH``,
    ``AUDIO_PATH``) and on helpers such as ``setup_loader`` and ``synthesis``.

    Args:
        model: Tacotron-family model to evaluate; switched to eval mode.
        criterion: spectrogram loss, called as ``criterion(output, target)``
            or ``criterion(output, target, lengths)`` when ``c.loss_masking``.
        criterion_st: stop-token loss, used only when ``c.stopnet`` is set.
        ap: audio processor used for plotting, spectrogram inversion and
            saving wav files.
        global_step: current training step; used as the TensorBoard step and
            as the sub-folder name for the synthesized test audio.
        epoch: current epoch; test sentences are synthesized only when it is
            greater than ``c.test_delay_epochs``.

    Returns:
        The running average postnet loss over the validation batches.
    """
    data_loader = setup_loader(ap, is_val=True)
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.eval()
    epoch_time = 0
    eval_values_dict = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0
    }
    keep_avg = KeepAverage()
    keep_avg.add_values(eval_values_dict)
    print("\n > Validation")
    if c.test_sentences_file is None:
        test_sentences = [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist."
        ]
    else:
        with open(c.test_sentences_file, "r") as f:
            test_sentences = [s.strip() for s in f.readlines()]

    # A test line may carry an explicit speaker as "text|speaker_id".
    test_sentences_with_speaker_id = []
    for sentence in test_sentences:
        ss = sentence.split("|")
        if len(ss) == 1:
            speaker_id = 0 if c.use_speaker_embedding else None
        else:
            speaker_id = int(ss[1])
        # Keep only the text part: passing the whole line would make the
        # "|<speaker_id>" suffix part of the text that gets synthesized.
        test_sentences_with_speaker_id.append((ss[0], speaker_id))

    with torch.no_grad():
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
                start_time = time.time()

                # setup input data
                text_input = data[0]
                text_lengths = data[1]
                speaker_names = data[2]
                linear_input = data[3] if c.model in [
                    "Tacotron", "TacotronGST"
                ] else None
                mel_input = data[4]
                mel_lengths = data[5]
                stop_targets = data[6]

                if c.use_speaker_embedding:
                    speaker_ids = [
                        speaker_mapping[speaker_name]
                        for speaker_name in speaker_names
                    ]
                    speaker_ids = torch.LongTensor(speaker_ids)
                else:
                    speaker_ids = None

                # set stop targets view, we predict a single stop token per r frames prediction
                stop_targets = stop_targets.view(text_input.shape[0],
                                                 stop_targets.size(1) // c.r,
                                                 -1)
                stop_targets = (stop_targets.sum(2) >
                                0.0).unsqueeze(2).float().squeeze(2)

                # dispatch data to GPU
                if use_cuda:
                    text_input = text_input.cuda()
                    mel_input = mel_input.cuda()
                    mel_lengths = mel_lengths.cuda()
                    linear_input = linear_input.cuda() if c.model in [
                        "Tacotron", "TacotronGST"
                    ] else None
                    stop_targets = stop_targets.cuda()
                    if speaker_ids is not None:
                        speaker_ids = speaker_ids.cuda()

                # forward pass
                decoder_output, postnet_output, alignments, stop_tokens =\
                    model.forward(text_input, text_lengths, mel_input,
                                  speaker_ids=speaker_ids)

                # loss computation
                stop_loss = criterion_st(
                    stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                if c.loss_masking:
                    decoder_loss = criterion(decoder_output, mel_input,
                                             mel_lengths)
                    if c.model in ["Tacotron", "TacotronGST"]:
                        postnet_loss = criterion(postnet_output, linear_input,
                                                 mel_lengths)
                    else:
                        postnet_loss = criterion(postnet_output, mel_input,
                                                 mel_lengths)
                else:
                    decoder_loss = criterion(decoder_output, mel_input)
                    if c.model in ["Tacotron", "TacotronGST"]:
                        postnet_loss = criterion(postnet_output, linear_input)
                    else:
                        postnet_loss = criterion(postnet_output, mel_input)
                loss = decoder_loss + postnet_loss + stop_loss

                step_time = time.time() - start_time
                epoch_time += step_time

                # compute alignment score
                align_score = alignment_diagonal_score(alignments)
                keep_avg.update_value('avg_align_score', align_score)

                # aggregate losses from processes
                if num_gpus > 1:
                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
                    if c.stopnet:
                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)

                keep_avg.update_values({
                    'avg_postnet_loss':
                    float(postnet_loss.item()),
                    'avg_decoder_loss':
                    float(decoder_loss.item()),
                    'avg_stop_loss':
                    float(stop_loss.item())
                })

                if num_iter % c.print_step == 0:
                    print(
                        "   | > TotalLoss: {:.5f}   PostnetLoss: {:.5f} - {:.5f}  DecoderLoss:{:.5f} - {:.5f} "
                        "StopLoss: {:.5f} - {:.5f}  AlignScore: {:.4f} : {:.4f}"
                        .format(loss.item(), postnet_loss.item(),
                                keep_avg['avg_postnet_loss'],
                                decoder_loss.item(),
                                keep_avg['avg_decoder_loss'], stop_loss.item(),
                                keep_avg['avg_stop_loss'], align_score,
                                keep_avg['avg_align_score']),
                        flush=True)

            if args.rank == 0:
                # Diagnostic visualizations (from the last validation batch)
                idx = np.random.randint(mel_input.shape[0])
                const_spec = postnet_output[idx].data.cpu().numpy()
                gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[idx].data.cpu().numpy()
                align_img = alignments[idx].data.cpu().numpy()

                eval_figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_eval_figures(global_step, eval_figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    eval_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    eval_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                         c.audio["sample_rate"])

                # Plot Validation Stats
                epoch_stats = {
                    "loss_postnet": keep_avg['avg_postnet_loss'],
                    "loss_decoder": keep_avg['avg_decoder_loss'],
                    "stop_loss": keep_avg['avg_stop_loss']
                }
                tb_logger.tb_eval_stats(global_step, epoch_stats)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        style_wav = c.get("style_wav_for_test")
        for idx, (test_sentence,
                  speaker_id) in enumerate(test_sentences_with_speaker_id):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav)
                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(
                    file_path,
                    "TestSentence_{}_{}.wav".format(idx, speaker_id))
                ap.save_wav(wav, file_path)
                test_audios['{}-{}-audio'.format(idx, speaker_id)] = wav
                test_figures['{}-{}-prediction'.format(
                    idx, speaker_id)] = plot_spectrogram(postnet_output, ap)
                test_figures['{}-{}-alignment'.format(
                    idx, speaker_id)] = plot_alignment(alignment)
            except Exception:  # pylint: disable=broad-except
                # best effort: one failing sentence must not abort evaluation,
                # but Ctrl-C / SystemExit still propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg['avg_postnet_loss']
Exemple #10
0
def train(model,
          criterion,
          optimizer,
          optimizer_st,
          scheduler,
          ap,
          global_step,
          epoch,
          amp,
          speaker_mapping=None):
    """Train ``model`` for one epoch and return the epoch's averaged stats.

    Relies on module-level state defined elsewhere in the script (``c``,
    ``use_cuda``, ``num_gpus``, ``args``, ``c_logger``, ``tb_logger``,
    ``OUT_PATH``) and helpers such as ``setup_loader`` and ``format_data``.

    Args:
        model: Tacotron-family model; switched to train mode.
        criterion: callable returning a dict of losses. ``'loss'`` is
            back-propagated; with ``c.separate_stopnet`` the
            ``'stopnet_loss'`` entry is back-propagated separately.
        optimizer: optimizer for the model parameters.
        optimizer_st: optimizer for the stopnet, used when
            ``c.separate_stopnet`` is set (may be falsy otherwise).
        scheduler: LR scheduler, stepped every iteration when
            ``c.noam_schedule`` is set.
        ap: audio processor used for figures and audio samples.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (loader is verbose on epoch 0; stored in
            checkpoints).
        amp: apex ``amp`` handle for mixed precision, or ``None``.
        speaker_mapping: optional speaker mapping forwarded to the data
            loader and ``format_data``.

    Returns:
        ``(keep_avg.avg_values, global_step)``: running averages for the
        epoch and the updated step counter.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0),
                               speaker_mapping=speaker_mapping)
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # approximate iterations per epoch; only used for progress logging
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(
            data, speaker_mapping)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder or c.double_decoder_consistency:
            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                speaker_ids=speaker_ids,
                speaker_embeddings=speaker_embeddings)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                speaker_ids=speaker_ids,
                speaker_embeddings=speaker_embeddings)
            decoder_backward_output = None
            alignments_backward = None

        # set the [alignment] lengths wrt reduction factor for guided attention
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (
                mel_lengths +
                (model.decoder.r -
                 (mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output, alignments,
                              alignment_lengths, alignments_backward,
                              text_lengths)

        # backward pass
        if amp is not None:
            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_dict['loss'].backward()

        optimizer, current_lr = adam_weight_decay(optimizer)
        if amp:
            amp_opt_params = amp.master_params(optimizer)
        else:
            amp_opt_params = None
        grad_norm, _ = check_update(model,
                                    c.grad_clip,
                                    ignore_stopnet=True,
                                    amp_opt_params=amp_opt_params)
        optimizer.step()

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            # Take the params from the stopnet optimizer: using the main
            # ``optimizer`` here would clip the model's params (at 1.0)
            # instead of the stopnet's.
            if amp:
                amp_opt_params = amp.master_params(optimizer_st)
            else:
                amp_opt_params = None
            grad_norm_st, _ = check_update(model.decoder.stopnet,
                                           1.0,
                                           amp_opt_params=amp_opt_params)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(
                loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(
                loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(
                loss_dict['stopnet_loss'].data,
                num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict['postnet_loss'],
                        amp_state_dict=amp.state_dict() if amp else None)

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction":
                    plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth":
                    plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment":
                    plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Exemple #11
0
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Run one validation pass and, after the test delay, synthesize test sentences.

    Relies on module-level state defined elsewhere in the script (``c``,
    ``use_cuda``, ``num_gpus``, ``args``, ``c_logger``, ``tb_logger``,
    ``AUDIO_PATH``). When ``c.use_external_speaker_embedding_file`` is set it
    also reads ``speaker_mapping``, which is not a parameter here.
    NOTE(review): presumably a module-level global; verify it is defined.

    Args:
        data_loader: validation loader; if ``None`` the loss pass is skipped.
        model: model called as ``model.forward(text, text_lengths,
            mel_lengths, dur_target, g=speaker_c)``; switched to eval mode.
        criterion: loss callable returning a dict with (at least) ``loss``,
            ``loss_l1``, ``loss_ssim`` and ``loss_dur`` entries.
        ap: audio processor used for plotting, inversion and saving wavs.
        global_step: current training step (TensorBoard step and audio
            sub-folder name).
        epoch: current epoch; test sentences are synthesized only when
            ``epoch >= c.test_delay_epochs``.

    Returns:
        ``keep_avg.avg_values``: the running averages of the logged values.
    """
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, alignments = model.forward(
                    text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
                )

                # compute loss (durations are compared in log(1 + d) space)
                loss_dict = criterion(
                    decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
                )

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
                loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
                loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
                loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations (from the last validation batch)
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False),
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                # Pick a random speaker. randrange(n) covers indices 0..n-1;
                # the previous ``len - 1`` never picked the last speaker and
                # raised ValueError for a single-speaker mapping.
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]][
                    "embedding"
                ]
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
            except Exception:  # pylint: disable=broad-except
                # best effort: one failing sentence must not abort evaluation,
                # but Ctrl-C / SystemExit still propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
          epoch):
    """Train ``model`` for one epoch and return the epoch's averaged stats.

    Relies on module-level state defined elsewhere in the script (``config``,
    ``use_cuda``, ``num_gpus``, ``args``, ``c_logger``, ``tb_logger``,
    ``OUT_PATH``, ``model_characters``) and helpers such as ``format_data``.

    Args:
        data_loader: training data loader.
        model: model called as ``model.forward(text, text_lengths,
            mel_lengths, dur_target, g=speaker_c)``; switched to train mode.
        criterion: loss callable returning a dict with (at least) ``loss``,
            ``loss_l1``, ``loss_ssim`` and ``loss_dur`` entries.
        optimizer: optimizer for the model parameters.
        scheduler: LR scheduler, stepped after each optimizer step when
            ``config.noam_schedule`` is set.
        ap: audio processor used for figures and audio samples.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (logged and stored in checkpoints).

    Returns:
        ``(keep_avg.avg_values, global_step)``: running averages for the
        epoch and the updated step counter.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # approximate iterations per epoch; only used for progress logging
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (config.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_targets,
            mel_lengths,
            speaker_c,
            avg_text_length,
            avg_spec_length,
            _,
            dur_target,
            _,
        ) = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            decoder_output, dur_output, alignments = model.forward(
                text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)

            # compute loss (durations are compared in log(1 + d) space)
            loss_dict = criterion(decoder_output, mel_targets,
                                  mel_lengths, dur_output,
                                  torch.log(1 + dur_target), text_lengths)

        # backward pass with loss scaling
        if config.mixed_precision:
            scaler.scale(loss_dict["loss"]).backward()
            # unscale first so gradient clipping sees the true gradient norm
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       config.grad_clip)
            optimizer.step()

        # setup lr
        if config.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]["lr"]

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data,
                                                 num_gpus)
            loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data,
                                                   num_gpus)
            loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data,
                                                  num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % config.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % config.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % config.save_step == 0:
                if config.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        1,
                        OUT_PATH,
                        model_characters,
                        model_loss=loss_dict["loss"],
                    )

                # wait all kernels to be completed; only meaningful (and only
                # safe to call) when running on CUDA
                if use_cuda:
                    torch.cuda.synchronize()

                # Diagnostic visualizations
                idx = np.random.randint(mel_targets.shape[0])
                pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
                gt_spec = mel_targets[idx].data.cpu().numpy().T
                align_img = alignments[idx].data.cpu()

                figures = {
                    "prediction": plot_spectrogram(pred_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(pred_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {"TrainAudio": train_audio},
                                          config.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if config.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Exemple #13
0
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
    """Run a validation pass and, after the warm-up epochs, synthesize test sentences.

    Nothing here calls ``backward``, so the whole function runs under
    ``torch.no_grad()`` to avoid building autograd graphs during evaluation.
    Relies on module-level state (``c``, ``args``, ``num_gpus``, ``use_cuda``,
    ``c_logger``, ``tb_logger``, ``AUDIO_PATH``).

    Args:
        model: TTS model; ``model.decoder.r`` is its reduction factor.
        criterion: loss module returning a dict of named losses.
        ap: audio processor used for plotting, audio inversion and saving wavs.
        global_step: current training step, used as the logging step.
        epoch: current epoch; test sentences are synthesized only when
            ``epoch > c.test_delay_epochs``.
        speaker_mapping: optional mapping forwarded to ``setup_loader`` and
            ``format_data``.

    Returns:
        dict: running averages from ``KeepAverage``, keyed ``avg_<loss name>``.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=True,
                               speaker_mapping=speaker_mapping)
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(
                data, speaker_mapping)
            assert mel_input.shape[1] % model.decoder.r == 0

            # forward pass model; backward-decoder outputs are only returned
            # for bidirectional / double-decoder-consistency configurations
            if c.bidirectional_decoder or c.double_decoder_consistency:
                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings)
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings)
                decoder_backward_output = None
                alignments_backward = None

            # set the alignment lengths wrt reduction factor for guided attention:
            # pad by the amount the longest spectrogram falls short of a
            # multiple of r, then convert frames to decoder steps
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(postnet_output, decoder_output, mel_input,
                                  linear_input, stop_tokens, stop_targets,
                                  mel_lengths, decoder_backward_output,
                                  alignments, alignment_lengths,
                                  alignments_backward, text_lengths)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score (logged as an error: 1 - diagonal score)
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['postnet_loss'] = reduce_tensor(
                    loss_dict['postnet_loss'].data, num_gpus)
                loss_dict['decoder_loss'] = reduce_tensor(
                    loss_dict['decoder_loss'].data, num_gpus)
                if c.stopnet:
                    loss_dict['stopnet_loss'] = reduce_tensor(
                        loss_dict['stopnet_loss'].data, num_gpus)

            # detach loss values: plain numbers pass through, tensors -> .item()
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values['avg_' + key] = value
            keep_avg.update_values(update_train_values)

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations from a random sample of the last batch
            idx = np.random.randint(mel_input.shape[0])
            const_spec = postnet_output[idx].data.cpu().numpy()
            gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
                "Tacotron", "TacotronGST"
            ] else mel_input[idx].data.cpu().numpy()
            align_img = alignments[idx].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec,
                                               ap,
                                               output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap,
                                                 output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False)
            }

            # Sample audio (linear spectrogram for Tacotron, mel otherwise)
            if c.model in ["Tacotron", "TacotronGST"]:
                eval_audio = ap.inv_spectrogram(const_spec.T)
            else:
                eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats

            if c.bidirectional_decoder or c.double_decoder_consistency:
                align_b_img = alignments_backward[idx].data.cpu().numpy()
                eval_figures['alignment2'] = plot_alignment(align_b_img,
                                                            output_fig=False)
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "Unabhängig davon, wer gewinnt, bestehen erhebliche Zweifel, ob die Präsidentschaftswahlen überhaupt verfassungskonform sind.",
            ]
        else:
            # explicit UTF-8 so non-ASCII sentences do not depend on the
            # locale; blank lines (e.g. a trailing newline) are skipped so no
            # empty sentence is sent to synthesis
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                test_sentences = [s.strip() for s in f if s.strip()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("gst_style_input")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap, output_fig=False)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment, output_fig=False)
            except Exception:  #pylint: disable=broad-except
                # best effort: a failing sentence is reported and skipped, but
                # KeyboardInterrupt / SystemExit are no longer swallowed
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Exemple #14
0
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Run a validation pass and, after the warm-up epochs, synthesize test sentences.

    Nothing here calls ``backward``, so the whole function runs under
    ``torch.no_grad()`` to avoid building autograd graphs during evaluation.
    Relies on module-level state (``c``, ``args``, ``num_gpus``, ``use_cuda``,
    ``c_logger``, ``tb_logger``, ``AUDIO_PATH``).

    Args:
        model: TTS model; ``model.decoder.r`` is its reduction factor.
        criterion: loss module returning a dict of named losses.
        ap: audio processor used for plotting, audio inversion and saving wavs.
        global_step: current training step, used as the logging step.
        epoch: current epoch; test sentences are synthesized only when
            ``epoch > c.test_delay_epochs``.

    Returns:
        dict: running averages from ``KeepAverage`` (pre-registered
        ``avg_*`` keys).
    """
    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
    model.eval()
    epoch_time = 0
    eval_values_dict = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stopnet_loss': 0,
        'avg_align_error': 0
    }
    if c.bidirectional_decoder:
        eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
        eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
    if c.ga_alpha > 0:
        eval_values_dict['avg_ga_loss'] = 0  # guided attention loss
    keep_avg = KeepAverage()
    keep_avg.add_values(eval_values_dict)

    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(
                data)
            assert mel_input.shape[1] % model.decoder.r == 0

            # forward pass model; backward-decoder outputs only exist for a
            # bidirectional decoder
            if c.bidirectional_decoder:
                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    speaker_ids=speaker_ids)
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    speaker_ids=speaker_ids)
                decoder_backward_output = None
                # keep both backward outputs bound, as in the other branch
                alignments_backward = None

            # set the alignment lengths wrt reduction factor for guided attention:
            # pad by the amount the longest spectrogram falls short of a
            # multiple of r, then convert frames to decoder steps
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(postnet_output, decoder_output, mel_input,
                                  linear_input, stop_tokens, stop_targets,
                                  mel_lengths, decoder_backward_output,
                                  alignments, alignment_lengths, text_lengths)
            if c.bidirectional_decoder:
                keep_avg.update_values({
                    'avg_decoder_b_loss':
                    loss_dict['decoder_b_loss'].item(),
                    'avg_decoder_c_loss':
                    loss_dict['decoder_c_loss'].item()
                })
            if c.ga_alpha > 0:
                keep_avg.update_values(
                    {'avg_ga_loss': loss_dict['ga_loss'].item()})

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score (tracked as an error: 1 - diagonal score)
            align_error = 1 - alignment_diagonal_score(alignments)
            keep_avg.update_value('avg_align_error', align_error)

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['postnet_loss'] = reduce_tensor(
                    loss_dict['postnet_loss'].data, num_gpus)
                loss_dict['decoder_loss'] = reduce_tensor(
                    loss_dict['decoder_loss'].data, num_gpus)
                if c.stopnet:
                    loss_dict['stopnet_loss'] = reduce_tensor(
                        loss_dict['stopnet_loss'].data, num_gpus)

            keep_avg.update_values({
                'avg_postnet_loss':
                float(loss_dict['postnet_loss'].item()),
                'avg_decoder_loss':
                float(loss_dict['decoder_loss'].item()),
                'avg_stopnet_loss':
                float(loss_dict['stopnet_loss'].item()),
            })

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations from a random sample of the last batch
            idx = np.random.randint(mel_input.shape[0])
            const_spec = postnet_output[idx].data.cpu().numpy()
            gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
                "Tacotron", "TacotronGST"
            ] else mel_input[idx].data.cpu().numpy()
            align_img = alignments[idx].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec, ap),
                "ground_truth": plot_spectrogram(gt_spec, ap),
                "alignment": plot_alignment(align_img)
            }

            # Sample audio (linear spectrogram for Tacotron, mel otherwise)
            if c.model in ["Tacotron", "TacotronGST"]:
                eval_audio = ap.inv_spectrogram(const_spec.T)
            else:
                eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            # NOTE(review): "alignment_score" is fed the alignment *error*
            # (1 - diagonal score); the key is kept for dashboard continuity.
            epoch_stats = {
                "loss_postnet": keep_avg['avg_postnet_loss'],
                "loss_decoder": keep_avg['avg_decoder_loss'],
                "stopnet_loss": keep_avg['avg_stopnet_loss'],
                "alignment_score": keep_avg['avg_align_error'],
            }

            if c.bidirectional_decoder:
                epoch_stats['loss_decoder_backward'] = keep_avg[
                    'avg_decoder_b_loss']
                align_b_img = alignments_backward[idx].data.cpu().numpy()
                eval_figures['alignment_backward'] = plot_alignment(
                    align_b_img)
            if c.ga_alpha > 0:
                epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
            tb_logger.tb_eval_stats(global_step, epoch_stats)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "Con la mia voce posso dire cose splendide.",
                "Ciao Marco ed Alice, come state?",
                "Ora che ho una voce, voglio solo parlare.",
                "Tra tutte le cose che ho letto, in tanti anni, questo libro è davvero il mio preferito."
            ]
        else:
            # explicit UTF-8 so non-ASCII sentences do not depend on the
            # locale; blank lines (e.g. a trailing newline) are skipped so no
            # empty sentence is sent to synthesis
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                test_sentences = [s.strip() for s in f if s.strip()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:  #pylint: disable=broad-except
                # best effort: a failing sentence is reported and skipped, but
                # KeyboardInterrupt / SystemExit are no longer swallowed
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Exemple #15
0
@torch.no_grad()
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step,
             epoch):
    """Run one validation pass of the GAN vocoder and log the results.

    Mirrors the training step without parameter updates. The generator loss
    is always computed. The adversarial generator terms are added once
    ``global_step > c.steps_to_start_discriminator``, and the discriminator
    loss once ``global_step >= c.steps_to_start_discriminator``. Nothing here
    calls ``backward``, so the function runs under ``torch.no_grad()``.

    ``global_step`` is advanced locally once per batch (it drives the
    discriminator start checks), but the advanced value is not returned.

    Args:
        model_G: generator; called on the conditioning features ``c_G``.
        criterion_G: generator loss module returning a dict of losses.
        model_D: discriminator; if its ``forward`` takes two parameters it is
            also given the conditioning features.
        criterion_D: discriminator loss module returning a dict of losses.
        ap: audio processor passed to ``setup_loader`` and ``plot_results``.
        global_step: training step at the start of evaluation.
        epoch: current epoch; the loader is verbose only for epoch 0.

    Returns:
        dict: running averages from ``KeepAverage`` (``avg_<loss name>``,
        ``avg_loader_time``, ``avg_step_time``).
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    # Inspect the discriminator signature once instead of on every call.
    d_is_conditional = len(signature(model_D.forward).parameters) == 2

    def _run_D(y, cond):
        """Call the discriminator, adding the conditioning features when it takes them."""
        if d_is_conditional:
            return model_D(y, cond)
        return model_D(y)

    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting: an output with more than one channel is merged by
        # PQMF synthesis and the target split by PQMF analysis, so the loss
        # also gets the per-channel signals
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            D_out_fake = _run_D(y_hat, c_G)
            D_out_real = None

            if c.use_feat_match_loss:
                # Real-sample features for feature matching. A conditional D
                # must also get c_G here; calling it with y_G alone would fail
                # (or silently drop the conditioning).
                D_out_real = _run_D(y_G, c_G)

            # format D outputs: a tuple carries (scores, features)
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is not None:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)

        # plain numbers pass through, tensors are converted with .item()
        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################

        # NOTE(review): this check uses '>=' while the adversarial generator
        # terms above use '>'. Confirm the one-step offset is intended.
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            y_hat = model_G(c_G)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            D_out_fake = _run_D(y_hat.detach(), c_G)
            D_out_real = _run_D(y_G, c_G)

            # format D outputs; only the scores feed the discriminator loss
            if isinstance(D_out_fake, tuple):
                scores_fake, _ = D_out_fake
                scores_real, _ = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values['avg_' + key] = value
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # Reset the reference point so the next loader_time measures only
        # that batch's loading time, not the whole elapsed evaluation.
        end_time = time.time()

    if args.rank == 0:
        # compute spectrograms (from the last evaluated batch)
        figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    # synthesize a full voice
    # NOTE(review): only the loader flag is changed here; no synthesis happens
    # in this function. Confirm the flag is consumed elsewhere.
    data_loader.return_segments = False

    return keep_avg.avg_values
Exemple #16
0
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch):
    """Train the model for one epoch.

    Relies on module-level state (``c``, ``args``, ``num_gpus``, ``use_cuda``,
    ``OUT_PATH``, ``tb_logger``).

    Args:
        model: TTS model returning decoder output, postnet output, alignments
            and stop tokens.
        criterion: spectrogram loss; also given the lengths when
            ``c.loss_masking`` is set.
        criterion_st: stop-token loss, used only when ``c.stopnet``.
        optimizer: optimizer for the main (decoder + postnet) loss.
        optimizer_st: stopnet optimizer, stepped when ``c.separate_stopnet``;
            may be falsy otherwise.
        scheduler: learning-rate scheduler stepped each iteration when
            ``c.lr_decay``.
        ap: audio processor used for plots and audio samples.
        global_step: global step at the start of the epoch.
        epoch: current epoch index.

    Returns:
        tuple: (average postnet loss over the epoch, updated global step).
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    # Tacotron models compare the postnet output against linear spectrograms,
    # the others against mel spectrograms.
    use_linear = c.model in ["Tacotron", "TacotronGST"]
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0,
    }
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        speaker_names = data[2]
        linear_input = data[3] if use_linear else None
        mel_input = data[4]
        mel_lengths = data[5]
        stop_targets = data[6]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())
        loader_time = time.time() - end_time

        if c.use_speaker_embedding:
            speaker_ids = [
                speaker_mapping[speaker_name] for speaker_name in speaker_names
            ]
            speaker_ids = torch.LongTensor(speaker_ids)
        else:
            speaker_ids = None

        # set stop targets view, we predict a single stop token per r frames
        # prediction: a group of r frames is a stop target if any frame in it is
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).float()

        global_step += 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(
                non_blocking=True) if use_linear else None
            stop_targets = stop_targets.cuda(non_blocking=True)
            if speaker_ids is not None:
                speaker_ids = speaker_ids.cuda(non_blocking=True)

        # forward pass model
        decoder_output, postnet_output, alignments, stop_tokens = model(
            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation
        stop_loss = criterion_st(stop_tokens,
                                 stop_targets) if c.stopnet else torch.zeros(1)
        postnet_target = linear_input if use_linear else mel_input
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            postnet_loss = criterion(postnet_output, postnet_target,
                                     mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            postnet_loss = criterion(postnet_output, postnet_target)
        loss = decoder_loss + postnet_loss
        # with a separate stopnet the stop loss is back-propagated on its own below
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        loss.backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        # compute alignment score
        align_score = alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_score', align_score)

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            stop_loss.backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f}  AlignScore:{:.4f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  "
                "LoaderTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, postnet_loss.item(),
                    decoder_loss.item(), stop_loss.item(), align_score,
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
                    step_time, loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data,
                                      num_gpus) if c.stopnet else stop_loss

        if args.rank == 0:
            update_train_values = {
                'avg_postnet_loss':
                float(postnet_loss.item()),
                'avg_decoder_loss':
                float(decoder_loss.item()),
                'avg_stop_loss':
                stop_loss
                if isinstance(stop_loss, float) else float(stop_loss.item()),
                'avg_step_time':
                step_time,
                'avg_loader_time':
                loader_time
            }
            keep_avg.update_values(update_train_values)

            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": postnet_loss.item(),
                    "loss_decoder": decoder_loss.item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations (first sample of the batch)
                const_spec = postnet_output[0].data.cpu().numpy()
                if use_linear:
                    gt_spec = linear_input[0].data.cpu().numpy()
                else:
                    gt_spec = mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if use_linear:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats; every label now matches the value printed next to it
    # (previously each label was shifted one argument to the right)
    print("   | > EPOCH END -- GlobalStep:{}  AvgPostnetLoss:{:.5f}  "
          "AvgDecoderLoss:{:.5f}  AvgStopLoss:{:.5f}  "
          "AvgAlignScore:{:.5f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}  AvgLoaderTime:{:.2f}".format(
              global_step, keep_avg['avg_postnet_loss'],
              keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
              keep_avg['avg_align_score'], epoch_time,
              keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stop_loss": keep_avg['avg_stop_loss'],
            "alignment_score": keep_avg['avg_align_score'],
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg['avg_postnet_loss'], global_step
Exemple #17
0
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
    """Run one training epoch of a GAN vocoder.

    Each batch first updates the generator (passing discriminator scores and
    features to ``criterion_G`` once the discriminator is active), then, from
    ``c.steps_to_start_discriminator`` on, updates the discriminator.
    Progress is printed every ``c.print_step`` steps. On rank 0, step stats
    go to TensorBoard every 10 steps, and every ``c.save_step`` steps a
    checkpoint (if ``c.checkpoint``) plus sample figures/audio are written.

    Args:
        model_G: generator network.
        criterion_G: generator criterion; returns a dict with key 'G_loss'.
        optimizer_G: optimizer of ``model_G``.
        model_D: discriminator network.
        criterion_D: discriminator criterion; returns a dict with key 'D_loss'.
        optimizer_D: optimizer of ``model_D``.
        scheduler_G: generator LR scheduler stepped every batch, or None.
        scheduler_D: discriminator LR scheduler stepped after every
            discriminator update, or None.
        ap: audio processor passed to the loader and the plotting helper.
        global_step: global step count at the start of the epoch.
        epoch: index of the current epoch.

    Returns:
        tuple: (epoch-averaged stats dict, updated ``global_step``).
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data: (c_G, y_G) feed the generator update,
        # (c_D, y_D) the discriminator update
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # for visualization

        # PQMF formatting: with more than one output channel, keep the raw
        # multi-channel output and synthesize a single waveform from it
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        # NOTE(review): the generator uses D from step > steps_to_start while
        # D itself trains from step >= steps_to_start; confirm the one-step
        # offset is intended.
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            # real-audio features are only needed for feature matching and
            # must not receive gradients
            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict['G_loss']

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
        optimizer_G.step()
        if scheduler_G is not None:
            scheduler_G.step()

        # detach loss values for logging; plain Python numbers (int or float)
        # are kept as-is, tensors are converted with .item()
        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass on a fresh generator output, without
            # tracking generator gradients
            with torch.no_grad():
                y_hat = model_G(c_D)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs (D_out_real is always computed in this branch)
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict['D_loss']

            # optimizer discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                               c.disc_clip_grad)
            optimizer_D.step()
            if scheduler_D is not None:
                scheduler_D.step()

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr_G": current_lr_G,
                "current_lr_D": current_lr_D
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr_G": current_lr_G,
                    "lr_D": current_lr_D,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model_G,
                                    optimizer_G,
                                    scheduler_G,
                                    model_D,
                                    optimizer_D,
                                    scheduler_D,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=loss_dict)

                # compute spectrograms
                figures = plot_results(y_hat_vis, y_G, ap, global_step,
                                       'train')
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_train_audios(global_step,
                                          {'train/audio': sample_voice},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    # tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Exemple #18
0
def evaluate(model, criterion, ap, global_step, epoch):
    """Evaluate the WaveRNN vocoder on the validation set.

    Computes the average validation loss. Every ``c.test_every_epochs``
    epochs (except epoch 0) it also re-synthesizes one random validation
    clip and logs its audio and mel spectrograms to TensorBoard.

    Args:
        model: vocoder model to evaluate.
        criterion: loss applied to ``(y_hat, y_coarse)``.
        ap: audio processor used to load audio and compute mel spectrograms.
        global_step: current training step; used unchanged as the
            TensorBoard step for every eval log.
        epoch: index of the current epoch.

    Returns:
        dict: epoch-averaged evaluation stats.
    """
    # create eval loader
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    with torch.no_grad():
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()
            # format data
            x_input, mels, y_coarse = format_data(data)
            # time spent waiting for this batch since the previous step ended
            loader_time = time.time() - end_time

            y_hat = model(x_input, mels)
            if isinstance(model.mode, int):
                # integer mode: swap axes 1 and 2 and add a trailing axis
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            else:
                y_coarse = y_coarse.float()
            y_coarse = y_coarse.unsqueeze(-1)
            loss = criterion(y_hat, y_coarse)
            # Compute avg loss
            # if num_gpus > 1:
            #     loss = reduce_tensor(loss.data, num_gpus)
            loss_dict = dict()
            loss_dict["model_loss"] = loss.item()

            step_time = time.time() - start_time
            epoch_time += step_time

            # update avg stats
            update_eval_values = dict()
            for key, value in loss_dict.items():
                update_eval_values["avg_" + key] = value
            update_eval_values["avg_loader_time"] = loader_time
            update_eval_values["avg_step_time"] = step_time
            keep_avg.update_values(update_eval_values)

            # print eval stats
            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)
            # reset the reference point so the next loader_time is per batch
            end_time = time.time()

    if epoch % c.test_every_epochs == 0 and epoch != 0:
        # synthesize a full voice
        rand_idx = random.randrange(0, len(eval_data))
        wav_path = eval_data[rand_idx] if not isinstance(
            eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav)
        # inference only: do not build an autograd graph
        with torch.no_grad():
            sample_wav = model.generate(ground_mel, c.batched,
                                        c.target_samples, c.overlap_samples,
                                        use_cuda)
        predict_mel = ap.melspectrogram(sample_wav)

        # Sample audio
        tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav},
                                 c.audio["sample_rate"])

        # compute spectrograms
        figures = {
            "eval/ground_truth": plot_spectrogram(ground_mel.T),
            "eval/prediction": plot_spectrogram(predict_mel.T)
        }
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values
Exemple #19
0
def train(model, criterion, optimizer,
          scheduler, scaler, ap, global_step, epoch):
    """Train the WaveGrad vocoder for a single epoch.

    The noise-level table is (re)computed from ``c['train_noise_schedule']``
    before iterating. Each batch is forwarded under autocast when
    ``c.mixed_precision`` is set, gradients are clipped to ``c.clip_grad``,
    and a NaN loss aborts training. On rank 0, step stats go to TensorBoard
    every 10 steps and a checkpoint is written every ``c.save_step`` steps
    (if ``c.checkpoint``).

    Args:
        model: WaveGrad model, possibly wrapped (exposing ``.module``).
        criterion: loss applied to ``(noise, noise_hat)``.
        optimizer: model optimizer.
        scheduler: LR scheduler stepped every batch, or None.
        scaler: gradient scaler used when ``c.mixed_precision`` is set.
        ap: audio processor passed to the data loader.
        global_step: global step count at the start of the epoch.
        epoch: index of the current epoch.

    Returns:
        tuple: (epoch-averaged stats dict, updated ``global_step``).

    Raises:
        RuntimeError: if the loss contains NaN.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    n_replicas = num_gpus if use_cuda else 1
    batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * n_replicas))
    end_time = time.time()
    c_logger.print_train_start()

    # methods such as compute_noise_level live on the bare network; unwrap
    # it once if the model is wrapped (presumably a data-parallel wrapper)
    net = model.module if hasattr(model, 'module') else model

    # noise schedule: linearly spaced betas from the config
    schedule_cfg = c['train_noise_schedule']
    betas = np.linspace(schedule_cfg['min_val'], schedule_cfg['max_val'],
                        schedule_cfg['num_steps'])
    net.compute_noise_level(betas)

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        m, x = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # build the noisy input, then let the model predict the noise
            noise, x_noisy, noise_scale = net.compute_y_n(x)
            noise_hat = model(x_noisy, m, noise_scale)
            loss = criterion(noise, noise_hat)
        raw_losses = {'wavegrad_loss': loss}

        if torch.isnan(loss).any():
            raise RuntimeError(f'Detected NaN loss at step {global_step}.')

        optimizer.zero_grad()
        if c.mixed_precision:
            # unscale before clipping so the clip threshold and the reported
            # norm refer to the real (unscaled) gradients
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # plain ints pass through; tensors are turned into Python numbers
        loss_dict = {
            name: val if isinstance(val, int) else val.item()
            for name, val in raw_losses.items()
        }

        step_time = time.time() - start_time
        epoch_time += step_time

        current_lr = optimizer.param_groups[0]['lr']

        running = {'avg_' + name: val for name, val in loss_dict.items()}
        running['avg_loader_time'] = loader_time
        running['avg_step_time'] = step_time
        keep_avg.update_values(running)

        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr": current_lr,
                "grad_norm": grad_norm.item()
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            if global_step % 10 == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm.item(),
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0 and c.checkpoint:
                scaler_state = scaler.state_dict() if c.mixed_precision else None
                save_checkpoint(model,
                                optimizer,
                                scheduler,
                                None,
                                None,
                                None,
                                global_step,
                                epoch,
                                OUT_PATH,
                                model_losses=loss_dict,
                                scaler=scaler_state)

        end_time = time.time()

    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # epoch-level TensorBoard stats and (optionally) weight histograms
    epoch_stats = dict(epoch_time=epoch_time)
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    if c.tb_model_param_stats and args.rank == 0:
        tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Exemple #20
0
def train(model, optimizer, criterion, scheduler, scaler, ap, global_step,
          epoch):
    """Train the WaveRNN vocoder for one epoch.

    Every ``c.save_step`` steps a checkpoint is saved (if ``c.checkpoint``)
    and a random training clip is re-synthesized and logged to TensorBoard.

    Args:
        model: WaveRNN model.
        optimizer: model optimizer.
        criterion: loss applied to ``(y_hat, y_coarse)``.
        scheduler: LR scheduler stepped every batch, or None.
        scaler: gradient scaler used when ``c.mixed_precision`` is set.
        ap: audio processor used by the loader and for mel computation.
        global_step: global step count at the start of the epoch.
        epoch: index of the current epoch.

    Returns:
        tuple: (epoch-averaged stats dict, updated ``global_step``).

    Raises:
        RuntimeError: if the loss is NaN during full-precision training.
    """
    # create train loader
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    # train loop
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        optimizer.zero_grad()

        # forward pass; autocast is only enabled for mixed precision
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            y_hat = model(x_input, mels)
            if isinstance(model.mode, int):
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            else:
                y_coarse = y_coarse.float()
            y_coarse = y_coarse.unsqueeze(-1)
            # compute losses
            loss = criterion(y_hat, y_coarse)

        if c.mixed_precision:
            # unscale before clipping so the threshold applies to the real
            # gradient norm
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            # full precision training. loss.item() never returns None, so
            # the divergence guard has to test for NaN explicitly.
            if torch.isnan(loss).any():
                raise RuntimeError(" [!] NaN loss. Exiting ...")
            loss.backward()
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # get the current learning rate
        cur_lr = list(optimizer.param_groups)[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        update_train_values = dict()
        loss_dict = dict()
        loss_dict["model_loss"] = loss.item()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr": cur_lr,
            }
            c_logger.print_train_step(
                batch_n_iter,
                num_iter,
                global_step,
                log_dict,
                loss_dict,
                keep_avg.avg_values,
            )

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                    scaler=scaler.state_dict() if c.mixed_precision else None)

            # synthesize a full voice
            rand_idx = random.randrange(0, len(train_data))
            wav_path = train_data[rand_idx] if not isinstance(
                train_data[rand_idx],
                (tuple, list)) else train_data[rand_idx][0]
            wav = ap.load_wav(wav_path)
            ground_mel = ap.melspectrogram(wav)
            # inference only: do not build an autograd graph
            with torch.no_grad():
                sample_wav = model.generate(ground_mel, c.batched,
                                            c.target_samples,
                                            c.overlap_samples, use_cuda)
            # NOTE(review): generate() is not visible here and may switch the
            # model to eval mode; restore training mode for the next batches.
            model.train()
            predict_mel = ap.melspectrogram(sample_wav)

            # compute spectrograms
            figures = {
                "train/ground_truth": plot_spectrogram(ground_mel.T),
                "train/prediction": plot_spectrogram(predict_mel.T)
            }
            tb_logger.tb_train_figures(global_step, figures)

            # Sample audio
            tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav},
                                      c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    # tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
Exemple #21
0
@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    """Evaluate the Glow-TTS model and synthesize the test sentences.

    The whole function runs without autograd (``torch.no_grad``), since
    nothing here is back-propagated.

    Args:
        data_loader: validation loader, or None to skip loss evaluation.
        model: TTS model, possibly wrapped (exposing ``.module``).
        criterion: loss returning a dict (with keys such as 'loss',
            'log_mle', 'loss_dur').
        ap: audio processor used for plotting, inversion and saving wavs.
        global_step: current training step, used as the TensorBoard step.
        epoch: current epoch; test sentences are synthesized once
            ``epoch >= c.test_delay_epochs``.

    Returns:
        dict: epoch-averaged evaluation stats.
    """
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
                _, _, attn_mask, _ = format_data(data)

            # forward pass model
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                attn_mask,
                g=speaker_c)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data,
                                                     num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(
                    loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
                                                  num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values['avg_' + key] = value
            keep_avg.update_values(update_train_values)

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations
            # direct pass on model for spec predictions, using the first
            # sample of the last evaluated batch
            target_speaker = None if speaker_c is None else speaker_c[:1]
            if hasattr(model, 'module'):
                spec_pred, *_ = model.module.inference(text_input[:1],
                                                       text_lengths[:1],
                                                       g=target_speaker)
            else:
                spec_pred, *_ = model.inference(text_input[:1],
                                                text_lengths[:1],
                                                g=target_speaker)
            spec_pred = spec_pred.permute(0, 2, 1)
            gt_spec = mel_input.permute(0, 2, 1)

            const_spec = spec_pred[0].data.cpu().numpy()
            gt_spec = gt_spec[0].data.cpu().numpy()
            align_img = alignments[0].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec, ap),
                "ground_truth": plot_spectrogram(gt_spec, ap),
                "alignment": plot_alignment(align_img)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963."
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                # pick a random speaker; randrange(n) covers every index
                # 0..n-1 (the former "n - 1" skipped the last speaker and
                # raised ValueError for a single-speaker mapping)
                speaker_keys = list(speaker_mapping.keys())
                speaker_embedding = speaker_mapping[speaker_keys[randrange(
                    len(speaker_keys))]]['embedding']
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:  # pylint: disable=broad-except
                # best effort: a failing sentence must not abort evaluation,
                # but KeyboardInterrupt/SystemExit still propagate
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
Exemple #22
0
def train(model, criterion, optimizer, scheduler,
          ap, global_step, epoch, amp):
    """Run one training epoch over the training data loader.

    Each batch: forward pass, loss computation, backward pass (through
    ``amp.scale_loss`` when mixed precision is enabled), gradient clipping via
    ``check_update``, optimizer step, then logging. On rank 0 it additionally
    writes TensorBoard stats every ``c.tb_plot_step`` steps and, every
    ``c.save_step`` steps, saves a checkpoint (if ``c.checkpoint``) and logs
    diagnostic spectrogram/alignment figures and audio.

    Args:
        model: the model being trained. Called as
            ``model(text_input, text_lengths, mel_input, mel_lengths, attn_mask)``
            and, for diagnostics, through its ``inference`` method. If it is a
            wrapper exposing ``.module`` (e.g. DistributedDataParallel), the
            wrapped module is used for inference.
        criterion: loss callable returning a dict containing ``'loss'`` (and
            ``'log_mle'`` / ``'loss_dur'``, which are reduced across GPUs when
            ``num_gpus > 1``).
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler, stepped per batch only when
            ``c.noam_schedule`` is set.
        ap: audio processor passed to ``setup_loader``, used for spectrogram
            plots and for ``ap.inv_melspectrogram`` when logging sample audio.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index; epoch 0 sets up the loader verbosely.
        amp: apex-style ``amp`` module when mixed precision is enabled,
            otherwise ``None``.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``. The first item holds the
        running averages of the logged values; the second is the updated step
        counter.

    Relies on module-level globals: ``c``, ``use_cuda``, ``num_gpus``,
    ``args``, ``c_logger``, ``tb_logger`` and ``OUT_PATH``.
    """
    import torch  # local import: needed for torch.no_grad() in diagnostics

    data_loader = setup_loader(ap, 1, is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    # Approximate number of batches per epoch; only used for progress logs.
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, _,\
            avg_text_length, avg_spec_length, attn_mask = format_data(data)

        # Time spent waiting on the loader since the previous iteration ended.
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()

        # forward pass model; call the module itself (not .forward) so that
        # any registered hooks run as PyTorch intends.
        z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model(
            text_input, text_lengths, mel_input, mel_lengths, attn_mask)

        # compute loss
        loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                              o_dur_log, o_total_dur, text_lengths)

        # backward pass
        if amp is not None:
            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_dict['loss'].backward()

        # With amp, gradient clipping must act on amp's master parameters.
        if amp is not None:
            amp_opt_params = amp.master_params(optimizer)
        else:
            amp_opt_params = None
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
        optimizer.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

        # detach loss values: plain numbers pass through, tensors -> .item()
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        # Only the main process writes logs and checkpoints.
        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
                                    model_loss=loss_dict['loss'],
                                    amp_state_dict=amp.state_dict() if amp is not None else None)

                # Diagnostic visualizations
                # Direct pass on the model for spec predictions. A distributed
                # wrapper does not expose `inference`, so unwrap `.module` when
                # present. No gradients are needed, so skip building the graph.
                inference_model = model.module if hasattr(model, 'module') else model
                with torch.no_grad():
                    spec_pred, *_ = inference_model.inference(
                        text_input[:1], text_lengths[:1])
                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step