示例#1
0
def train(model, criterion, optimizer, scheduler, ap, global_step):
    """Run one pass over the training loader, updating ``model`` in place.

    The loader is created here with
    ``setup_loader(ap, is_val=False, verbose=True)``. Only the first element
    of every batch is used; the loss is computed from the model outputs alone.

    Args:
        model: network to optimise; switched to train mode.
        criterion: callable returning a scalar loss tensor for the reshaped
            model outputs.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, stepped once per batch when ``c.lr_decay``.
        ap: forwarded to ``setup_loader``.
        global_step: step counter at entry; incremented once per batch.

    Returns:
        Tuple ``(avg_loss, global_step)``: the smoothed loss after the last
        batch and the updated step counter.
    """
    loader = setup_loader(ap, is_val=False, verbose=True)
    model.train()
    epoch_time = 0
    best_loss = float('inf')
    avg_loss = 0
    end_time = time.time()
    for batch in loader:
        start_time = time.time()

        # only the first element of the batch is consumed by this loop
        inputs = batch[0]
        loader_time = time.time() - end_time
        global_step += 1

        # the scheduler (if enabled) advances before gradients are cleared
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()

        # move the batch to the GPU when CUDA is in use
        if use_cuda:
            inputs = inputs.cuda(non_blocking=True)

        outputs = model(inputs)

        # view the outputs as
        # (c.num_speakers_in_batch, rows // c.num_speakers_in_batch, -1)
        group_count = c.num_speakers_in_batch
        grouped = outputs.view(group_count, outputs.shape[0] // group_count,
                               -1)
        loss = criterion(grouped)
        loss.backward()
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        # running average of the loss, seeded with the first non-zero value
        loss_value = loss.item()
        if avg_loss != 0:
            avg_loss = 0.01 * loss_value + 0.99 * avg_loss
        else:
            avg_loss = loss_value
        current_lr = optimizer.param_groups[0]['lr']

        if global_step % c.steps_plot_stats == 0:
            # push scalar stats, then the embedding figure, to the logger
            tb_logger.tb_train_epoch_stats(
                global_step, {
                    "loss": avg_loss,
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                })
            # FIXME: the second argument is hard-coded
            embedding_fig = plot_embeddings(outputs.detach().cpu().numpy(),
                                            10)
            tb_logger.tb_train_figures(global_step,
                                       {"UMAP Plot": embedding_fig})

        if global_step % c.print_step == 0:
            message = (
                "   | > Step:{}  Loss:{:.5f}  AvgLoss:{:.5f}  GradNorm:{:.5f}  "
                "StepTime:{:.2f}  LoaderTime:{:.2f}  LR:{:.6f}")
            print(message.format(global_step, loss_value, avg_loss, grad_norm,
                                 step_time, loader_time, current_lr),
                  flush=True)

        # save_best_model is called on every step and returns the new best
        best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
                                    OUT_PATH, global_step)

        end_time = time.time()
    return avg_loss, global_step
示例#2
0
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch, scaler, scaler_st):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: iterable of batches; each batch is unpacked by
            ``format_data``. ``len(data_loader.dataset)`` is used for the
            progress display.
        model: network to optimise; must expose ``decoder.r`` and
            ``decoder.stopnet``.
        criterion: callable returning a dict of losses. This function reads
            the keys ``"loss"``, ``"stopnet_loss"``, ``"postnet_loss"`` and
            ``"decoder_loss"``.
        optimizer: optimizer for the main model loss.
        optimizer_st: optimizer for the stopnet; stepped only when
            ``c.separate_stopnet``. May be falsy when unused.
        scheduler: LR scheduler, stepped per batch when ``c.noam_schedule``.
        ap: audio processor used for plotting and spectrogram inversion.
        global_step: step counter at entry; incremented once per batch.
        epoch: current epoch index (logging / checkpoint metadata).
        scaler: gradient scaler paired with ``optimizer`` when
            ``c.mixed_precision`` (presumably ``torch.cuda.amp.GradScaler``
            - confirm against the caller).
        scaler_st: gradient scaler paired with ``optimizer_st``.

    Returns:
        Tuple ``(keep_avg.avg_values, global_step)``.

    Raises:
        RuntimeError: if the total loss contains NaN.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            max_text_length,
            max_spec_length,
        ) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # forward pass model; the call is identical in both decoder modes,
            # only the number of returned values differs
            model_outputs = model(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                speaker_ids=speaker_ids,
                speaker_embeddings=speaker_embeddings,
            )
            if c.bidirectional_decoder or c.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model_outputs
            else:
                (decoder_output, postnet_output, alignments,
                 stop_tokens) = model_outputs
                decoder_backward_output = None
                alignments_backward = None

            # set the [alignment] lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        # optimizer step
        if c.mixed_precision:
            # model optimizer step in mixed precision mode
            scaler.scale(loss_dict["loss"]).backward()
            # unscale before clipping so the norm is measured on real grads
            scaler.unscale_(optimizer)
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model,
                                        c.grad_clip,
                                        ignore_stopnet=True)
            scaler.step(optimizer)
            scaler.update()

            # stopnet optimizer step
            if c.separate_stopnet:
                # The stopnet loss is scaled by ``scaler_st``, so the matching
                # unscale/step must go through ``scaler_st`` and act on
                # ``optimizer_st``. Mixing in ``scaler``/``optimizer`` here
                # would leave the stopnet optimizer un-stepped and step the
                # main optimizer a second time.
                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                scaler_st.unscale_(optimizer_st)
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                scaler_st.step(optimizer_st)
                scaler_st.update()
            else:
                grad_norm_st = 0
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model,
                                        c.grad_clip,
                                        ignore_stopnet=True)
            optimizer.step()

            # stopnet optimizer step
            if c.separate_stopnet:
                loss_dict["stopnet_loss"].backward()
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                optimizer_st.step()
            else:
                grad_norm_st = 0

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["postnet_loss"] = reduce_tensor(
                loss_dict["postnet_loss"].data, num_gpus)
            loss_dict["decoder_loss"] = reduce_tensor(
                loss_dict["decoder_loss"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
            loss_dict["stopnet_loss"] = (reduce_tensor(
                loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else
                                         loss_dict["stopnet_loss"])

        # detach loss values: plain numbers pass through, tensors become floats
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            "avg_" + key: value
            for key, value in loss_dict.items()
        }
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "max_spec_length": [max_spec_length, 1],  # value, precision
                "max_text_length": [max_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        # only the rank-0 process writes logs, checkpoints and samples
        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict["postnet_loss"],
                        characters=model_characters,
                        scaler=scaler.state_dict()
                        if c.mixed_precision else None,
                    )

                # Diagnostic visualizations (first sample of the batch)
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = (linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy())
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction":
                    plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth":
                    plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment":
                    plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {"TrainAudio": train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
示例#3
0
def train(model, optimizer, scheduler, criterion, data_loader,
          eval_data_loader, global_step):
    """Train ``model`` for ``c.epochs`` epochs, optionally evaluating each one.

    Args:
        model: network to optimise.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, stepped once per batch when ``c.lr_decay``.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor; also forwarded to the checkpoint helpers.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        eval_data_loader: loader handed to ``evaluation`` when ``c.run_eval``.
        global_step: step counter at entry; incremented once per batch.

    Returns:
        Tuple ``(best_loss, global_step)``; ``best_loss`` is whatever
        ``save_best_model`` last returned (``inf`` if evaluation never ran).
    """
    model.train()
    best_loss = float("inf")
    avg_loader_time = 0
    end_time = time.time()
    for epoch in range(c.epochs):
        tot_loss = 0
        epoch_time = 0
        for batch in data_loader:
            start_time = time.time()

            inputs, labels = batch
            # Group the samples of each class together: the sampler yields
            # e.g. [3,2,1,3,2,1] whereas the loss needs [3,3,2,2,1,1].
            # TODO: cover this regrouping with a unit test.
            utter_per_class = c.num_utter_per_class
            classes_in_batch = c.num_classes_in_batch
            labels = labels.view(utter_per_class, classes_in_batch).transpose(
                0, 1).reshape(labels.shape)
            inputs = inputs.view(utter_per_class, classes_in_batch,
                                 -1).transpose(0, 1).reshape(inputs.shape)

            loader_time = time.time() - end_time
            global_step += 1

            # the scheduler (if enabled) advances before gradients are cleared
            if c.lr_decay:
                scheduler.step()
            optimizer.zero_grad()

            # move the batch to the GPU when CUDA is in use
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            outputs = model(inputs)

            # view the outputs as
            # (c.num_classes_in_batch, rows // c.num_classes_in_batch, -1)
            n_classes = c.num_classes_in_batch
            grouped = outputs.view(n_classes, outputs.shape[0] // n_classes,
                                   -1)
            loss = criterion(grouped, labels)
            loss.backward()
            grad_norm, _ = check_update(model, c.grad_clip)
            optimizer.step()

            step_time = time.time() - start_time
            epoch_time += step_time

            # accumulate the epoch total for the end-of-epoch average
            loss_value = loss.item()
            tot_loss += loss_value

            # smoothed loader time, weighted by the number of loader workers
            if c.num_loader_workers > 0:
                num_loader_workers = c.num_loader_workers
            else:
                num_loader_workers = 1
            if avg_loader_time != 0:
                avg_loader_time = (
                    1 / num_loader_workers * loader_time +
                    (num_loader_workers - 1) / num_loader_workers *
                    avg_loader_time)
            else:
                avg_loader_time = loader_time
            current_lr = optimizer.param_groups[0]["lr"]

            if global_step % c.steps_plot_stats == 0:
                # push scalar stats, then the embedding figure, to the logger
                dashboard_logger.train_epoch_stats(
                    global_step, {
                        "loss": loss_value,
                        "lr": current_lr,
                        "grad_norm": grad_norm,
                        "step_time": step_time,
                        "avg_loader_time": avg_loader_time,
                    })
                embedding_fig = plot_embeddings(
                    outputs.detach().cpu().numpy(), c.num_classes_in_batch)
                dashboard_logger.train_figures(global_step,
                                               {"UMAP Plot": embedding_fig})

            if global_step % c.print_step == 0:
                step_message = (
                    "   | > Step:{}  Loss:{:.5f}  GradNorm:{:.5f}  "
                    "StepTime:{:.2f}  LoaderTime:{:.2f}  AvGLoaderTime:{:.2f}  LR:{:.6f}"
                )
                print(
                    step_message.format(global_step, loss_value, grad_norm,
                                        step_time, loader_time,
                                        avg_loader_time, current_lr),
                    flush=True,
                )

            if global_step % c.save_step == 0:
                # periodic checkpoint
                save_checkpoint(model, optimizer, criterion, loss_value,
                                OUT_PATH, global_step, epoch)

            end_time = time.time()

        # epoch summary; grad_norm is the value from the last batch
        print("")
        epoch_message = (">>> Epoch:{}  AvgLoss: {:.5f} GradNorm:{:.5f}  "
                         "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ")
        print(
            epoch_message.format(epoch, tot_loss / len(data_loader),
                                 grad_norm, epoch_time, avg_loader_time),
            flush=True,
        )

        if not c.run_eval:
            continue

        # evaluate in eval mode, then return to train mode for the next epoch
        model.eval()
        eval_loss = evaluation(model, criterion, eval_data_loader,
                               global_step)
        print("\n\n")
        print("--> EVAL PERFORMANCE")
        print(
            "   | > Epoch:{}  AvgLoss: {:.5f} ".format(epoch, eval_loss),
            flush=True,
        )
        # keep the checkpoint with the best evaluation loss
        best_loss = save_best_model(model, optimizer, criterion, eval_loss,
                                    best_loss, OUT_PATH, global_step, epoch)
        model.train()

    return best_loss, global_step
示例#4
0
def train(model,
          criterion,
          optimizer,
          optimizer_st,
          scheduler,
          ap,
          global_step,
          epoch,
          amp,
          speaker_mapping=None):
    """Train ``model`` for one pass over a freshly built training loader.

    Args:
        model: network to optimise; must expose ``decoder.r`` and
            ``decoder.stopnet``.
        criterion: callable returning a dict of losses. This function reads
            the keys ``'loss'``, ``'stopnet_loss'``, ``'postnet_loss'`` and
            ``'decoder_loss'``.
        optimizer: optimizer for the main model loss.
        optimizer_st: optimizer for the stopnet; stepped only when
            ``c.separate_stopnet``. May be falsy when unused.
        scheduler: LR scheduler, stepped per batch when ``c.noam_schedule``.
        ap: audio processor; forwarded to ``setup_loader`` and used for
            plotting and spectrogram inversion.
        global_step: step counter at entry; incremented once per batch.
        epoch: current epoch index; the loader is verbose only for epoch 0.
        amp: object exposing ``scale_loss``, ``master_params`` and
            ``state_dict`` (presumably NVIDIA Apex ``amp`` - confirm against
            the caller), or ``None`` to train without it.
        speaker_mapping: forwarded to ``setup_loader`` and ``format_data``.

    Returns:
        Tuple ``(keep_avg.avg_values, global_step)``.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0),
                               speaker_mapping=speaker_mapping)
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            avg_text_length,
            avg_spec_length,
        ) = format_data(data, speaker_mapping)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model; the call is identical in both decoder modes,
        # only the number of returned values differs
        model_outputs = model(text_input,
                              text_lengths,
                              mel_input,
                              mel_lengths,
                              speaker_ids=speaker_ids,
                              speaker_embeddings=speaker_embeddings)
        if c.bidirectional_decoder or c.double_decoder_consistency:
            (decoder_output, postnet_output, alignments, stop_tokens,
             decoder_backward_output, alignments_backward) = model_outputs
        else:
            (decoder_output, postnet_output, alignments,
             stop_tokens) = model_outputs
            decoder_backward_output = None
            alignments_backward = None

        # set the [alignment] lengths wrt reduction factor for guided attention
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (
                mel_lengths +
                (model.decoder.r -
                 (mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output, alignments,
                              alignment_lengths, alignments_backward,
                              text_lengths)

        # backward pass
        if amp is not None:
            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_dict['loss'].backward()

        optimizer, current_lr = adam_weight_decay(optimizer)
        if amp:
            amp_opt_params = amp.master_params(optimizer)
        else:
            amp_opt_params = None
        grad_norm, _ = check_update(model,
                                    c.grad_clip,
                                    ignore_stopnet=True,
                                    amp_opt_params=amp_opt_params)
        optimizer.step()

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            if amp:
                # The stopnet gradient check must look at the parameters
                # owned by the stopnet optimizer, not those of the main
                # optimizer that was already stepped above.
                amp_opt_params = amp.master_params(optimizer_st)
            else:
                amp_opt_params = None
            grad_norm_st, _ = check_update(model.decoder.stopnet,
                                           1.0,
                                           amp_opt_params=amp_opt_params)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(
                loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(
                loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(
                loss_dict['stopnet_loss'].data,
                num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        # detach loss values: plain numbers pass through, tensors become floats
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            'avg_' + key: value
            for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        # only the rank-0 process writes logs, checkpoints and samples
        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict['postnet_loss'],
                        amp_state_dict=amp.state_dict() if amp else None)

                # Diagnostic visualizations (first sample of the batch)
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction":
                    plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth":
                    plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment":
                    plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
示例#5
0
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
    """Train ``model`` over ``data_loader`` until it ends or the step cap hits.

    Args:
        model: network to optimise; switched to train mode.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, stepped once per batch when ``c.lr_decay``.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor; also forwarded to ``save_best_model``.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        global_step: step counter at entry; incremented once per batch.

    Returns:
        Tuple ``(avg_loss, global_step)``: the smoothed loss after the last
        processed batch and the updated step counter.
    """
    model.train()
    epoch_time = 0
    best_loss = float("inf")
    avg_loss = 0
    avg_loss_all = 0  # accumulated and reset below, but never read
    avg_loader_time = 0
    end_time = time.time()

    for batch in data_loader:
        start_time = time.time()

        inputs, labels = batch
        loader_time = time.time() - end_time
        global_step += 1

        # the scheduler (if enabled) advances before gradients are cleared
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()

        # move the batch to the GPU when CUDA is in use
        if use_cuda:
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

        outputs = model(inputs)

        # view the outputs as
        # (c.num_speakers_in_batch, rows // c.num_speakers_in_batch, -1)
        group_count = c.num_speakers_in_batch
        grouped = outputs.view(group_count, outputs.shape[0] // group_count, -1)
        loss = criterion(grouped, labels)
        loss.backward()
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        # running average of the loss, seeded with the first non-zero value
        loss_value = loss.item()
        if avg_loss != 0:
            avg_loss = 0.01 * loss_value + 0.99 * avg_loss
        else:
            avg_loss = loss_value

        # smoothed loader time, weighted by the number of loader workers
        if c.num_loader_workers > 0:
            num_loader_workers = c.num_loader_workers
        else:
            num_loader_workers = 1
        if avg_loader_time != 0:
            avg_loader_time = (
                1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
            )
        else:
            avg_loader_time = loader_time
        current_lr = optimizer.param_groups[0]["lr"]

        if global_step % c.steps_plot_stats == 0:
            # push scalar stats, then the embedding figure, to the logger
            tb_logger.tb_train_epoch_stats(
                global_step,
                {
                    "loss": avg_loss,
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time,
                    "avg_loader_time": avg_loader_time,
                },
            )
            # FIXME: the second argument is hard-coded
            embedding_fig = plot_embeddings(outputs.detach().cpu().numpy(), 10)
            tb_logger.tb_train_figures(global_step, {"UMAP Plot": embedding_fig})

        if global_step % c.print_step == 0:
            message = (
                "   | > Step:{}  Loss:{:.5f}  AvgLoss:{:.5f}  GradNorm:{:.5f}  "
                "StepTime:{:.2f}  LoaderTime:{:.2f}  AvGLoaderTime:{:.2f}  LR:{:.6f}"
            )
            print(
                message.format(
                    global_step, loss_value, avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
                ),
                flush=True,
            )
        avg_loss_all += avg_loss

        # checkpoint on every save step and once more when the cap is reached
        reached_max_step = global_step >= c.max_train_step
        if reached_max_step or global_step % c.save_step == 0:
            # only the best model is kept
            best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step)
            avg_loss_all = 0
            if reached_max_step:
                break

        end_time = time.time()

    return avg_loss, global_step
示例#6
0
def train(model, criterion, optimizer, optimizer_st, scheduler, ap,
          global_step, epoch):
    """Run one training epoch and return the running averages.

    Args:
        model: network being trained. Must expose ``decoder.r`` and, when
            ``c.separate_stopnet`` is enabled, ``decoder.stopnet``.
        criterion: callable returning a dict of losses. The keys read here
            are ``loss``, ``postnet_loss``, ``decoder_loss`` and
            ``stopnet_loss``, plus ``decoder_backward_loss`` /
            ``decoder_c_loss`` when ``c.bidirectional_decoder`` is set and
            ``ga_loss`` when ``c.ga_alpha > 0``.
        optimizer: optimizer stepped after ``loss_dict['loss'].backward()``.
        optimizer_st: optimizer for the stopnet. May be falsy, but it is
            used unconditionally when ``c.separate_stopnet`` is enabled.
        scheduler: stepped once per iteration when ``c.noam_schedule`` is
            set.
        ap: object handed to ``setup_loader`` and ``plot_spectrogram``; its
            ``inv_spectrogram`` / ``inv_melspectrogram`` methods produce the
            sample audio that is logged.
        global_step: number of steps performed before this epoch.
        epoch: current epoch index; the loader is verbose only for epoch 0.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)`` where ``global_step``
        has been incremented once per processed batch.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stopnet_loss': 0,
        'avg_align_error': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    if c.ga_alpha > 0:
        train_values['avg_ga_loss'] = 0  # guided attention loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(
            data)
        # time elapsed since the end of the previous iteration
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
            decoder_backward_output = None

        # set the alignment lengths wrt reduction factor for guided attention
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (
                mel_lengths +
                (model.decoder.r -
                 (mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output, alignments,
                              alignment_lengths, text_lengths)
        if c.bidirectional_decoder:
            keep_avg.update_values({
                'avg_decoder_b_loss':
                loss_dict['decoder_backward_loss'].item(),
                'avg_decoder_c_loss':
                loss_dict['decoder_c_loss'].item()
            })
        if c.ga_alpha > 0:
            keep_avg.update_values(
                {'avg_ga_loss': loss_dict['ga_loss'].item()})

        # backward pass
        loss_dict['loss'].backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
        optimizer.step()

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_error', align_error)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # A plain float has no ``.item()``, so it is used as-is; anything
        # else is unwrapped to a Python float.
        stopnet_loss = loss_dict['stopnet_loss']
        if not isinstance(stopnet_loss, float):
            stopnet_loss = float(stopnet_loss.item())

        # update avg stats
        update_train_values = {
            'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
            'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
            'avg_stopnet_loss': stopnet_loss,
            'avg_step_time': step_time,
            'avg_loader_time': loader_time
        }
        keep_avg.update_values(update_train_values)

        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      avg_spec_length, avg_text_length,
                                      step_time, loader_time, current_lr,
                                      loss_dict, keep_avg.avg_values)

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(
                loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(
                loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(
                loss_dict['stopnet_loss'].data,
                num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        # logging, checkpointing and diagnostics happen on rank 0 only
        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": loss_dict['postnet_loss'].item(),
                    "loss_decoder": loss_dict['decoder_loss'].item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict['postnet_loss'].item())

                # Diagnostic visualizations (first item of the batch)
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy())

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stopnet_loss": keep_avg['avg_stopnet_loss'],
            "alignment_score": keep_avg['avg_align_error'],
            "epoch_time": epoch_time
        }
        if c.ga_alpha > 0:
            epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
示例#7
0
def train(model,
          criterion,
          optimizer,
          scheduler,
          ap,
          global_step,
          epoch,
          speaker_mapping=None):
    """Run one training epoch and return the running averages.

    The optimizer is stepped exactly once per iteration: through the
    ``GradScaler`` when ``c.mixed_precision`` is enabled, directly otherwise.

    Args:
        model: network being trained. It is called through ``forward`` for
            the training pass and through ``inference`` for the diagnostic
            spectrogram at save steps.
        criterion: callable returning a dict of losses. ``loss`` is
            back-propagated; ``log_mle`` and ``loss_dur`` are additionally
            reduced across processes when ``num_gpus > 1``.
        optimizer: optimizer for ``model.parameters()``.
        scheduler: stepped once per iteration, after the optimizer, when
            ``c.noam_schedule`` is set.
        ap: object handed to ``setup_loader`` and ``plot_spectrogram``; its
            ``inv_melspectrogram`` method produces the sample audio that is
            logged.
        global_step: number of steps performed before this epoch.
        epoch: current epoch index; the loader is verbose only for epoch 0.
        speaker_mapping: forwarded unchanged to ``setup_loader``.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)`` where ``global_step``
        has been incremented once per processed batch.
    """
    data_loader = setup_loader(ap,
                               1,
                               is_val=False,
                               verbose=(epoch == 0),
                               speaker_mapping=speaker_mapping)
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
            avg_text_length, avg_spec_length, attn_mask = format_data(data)

        # time elapsed since the end of the previous iteration
        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                attn_mask,
                g=speaker_ids)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss_dict['loss']).backward()
            # unscale before clipping so the threshold applies to true grads
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict['loss'].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            optimizer.step()

        # NOTE: no further clip/step here. Both branches above have already
        # applied the update; stepping again would apply it twice and, under
        # mixed precision, bypass the GradScaler.

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data,
                                                 num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data,
                                                  num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

        # detach loss values: plain numbers pass through, the rest go
        # through ``.item()``
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {
            'avg_' + key: value
            for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        # logging, checkpointing and diagnostics happen on rank 0 only
        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model,
                                    optimizer,
                                    global_step,
                                    epoch,
                                    1,
                                    OUT_PATH,
                                    model_loss=loss_dict['loss'])

                # Diagnostic visualizations
                # direct pass on model for spec predictions
                target_speaker = None if speaker_ids is None else speaker_ids[:1]
                spec_pred, *_ = model.inference(text_input[:1],
                                                text_lengths[:1],
                                                g=target_speaker)
                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step