Ejemplo n.º 1
0
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
    """Run one training epoch.

    The spectrogram losses (mel + linear) are optimized with ``optimizer``
    and the stop-token loss with ``optimizer_st`` in two separate backward
    passes. Per-iteration and per-epoch statistics are written to the
    module-level tensorboard writer ``tb``.

    Args:
        model: model whose decoder exposes ``module.decoder.stopnet``
            (i.e. it is wrapped, e.g. by DataParallel — TODO confirm).
        criterion: masked spectrogram loss ``(output, target, lengths)``.
        criterion_st: stop-token loss ``(stop_tokens, stop_targets)``.
        data_loader: training DataLoader yielding
            (text, text_lengths, linear, mel, mel_lengths, stop_targets).
        optimizer: optimizer for the full model.
        optimizer_st: optimizer for the stopnet only.
        epoch: current epoch index.

    Returns:
        tuple: ``(avg_linear_loss, current_step)`` — epoch-averaged linear
        loss and the last global step reached.
    """
    model = model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    print(" | > Epoch {}/{}".format(epoch, c.epochs))
    progbar = Progbar(len(data_loader.dataset) / c.batch_size)
    # number of low-frequency bins (below 3 kHz) that receive an extra
    # "priority" loss term in the linear loss below
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]  # NOTE: unused in this loop; kept for clarity
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]
        stop_targets = data[5]

        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        # global step, continuous across restarts via args.restore_step
        current_step = num_iter + args.restore_step + \
                       epoch * len(data_loader) + 1

        # setup lr — same warmup/decay schedule for both optimizers
        current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
        current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)

        for params_group in optimizer.param_groups:
            params_group['lr'] = current_lr

        for params_group in optimizer_st.param_groups:
            params_group['lr'] = current_lr_st

        optimizer.zero_grad()
        optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda()
            mel_input = mel_input.cuda()
            mel_lengths = mel_lengths.cuda()
            linear_input = linear_input.cuda()
            stop_targets = stop_targets.cuda()

        # forward pass
        mel_output, linear_output, alignments, stop_tokens = \
            model.forward(text_input, mel_input)

        # loss computation: linear loss mixes a full-band term with a
        # low-frequency "priority" term at equal weight
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
                      + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                        linear_input[:, :, :n_priority_freq],
                                        mel_lengths)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm for spec losses; retain_graph so
        # stop_loss.backward() below can reuse the same graph
        loss.backward(retain_graph=True)
        grad_norm, skip_flag = check_update(model, 0.5, 100)
        if skip_flag:
            optimizer.zero_grad()
            print(" | > Iteration skipped!!")
            continue
        optimizer.step()

        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
        if skip_flag:
            optimizer_st.zero_grad()
            print(" | > Iteration skipped for stopnet!!")
            continue
        optimizer_st.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update progress bar with the iteration losses
        progbar.update(num_iter + 1, values=[('total_loss', loss.item()),
                                             ('linear_loss', linear_loss.item()),
                                             ('mel_loss', mel_loss.item()),
                                             ('stop_loss', stop_loss.item()),
                                             ('grad_norm', grad_norm.item()),
                                             ('grad_norm_st', grad_norm_st.item())])
        avg_linear_loss += linear_loss.item()
        avg_mel_loss += mel_loss.item()
        avg_stop_loss += stop_loss.item()

        # Plot Training Iter Stats
        tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
        tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
                      current_step)
        tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
        tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                      current_step)
        tb.add_scalar('Params/GradNorm', grad_norm, current_step)
        tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
        tb.add_scalar('Time/StepTime', step_time, current_step)

        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, linear_loss.item(),
                                OUT_PATH, current_step, epoch)

            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
            gt_spec = linear_input[0].data.cpu().numpy()

            const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
            gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
            tb.add_image('Visual/Reconstruction', const_spec, current_step)
            tb.add_image('Visual/GroundTruth', gt_spec, current_step)

            align_img = alignments[0].data.cpu().numpy()
            align_img = plot_alignment(align_img)
            tb.add_image('Visual/Alignment', align_img, current_step)

            # Sample audio via Griffin-Lim reconstruction
            audio_signal = linear_output[0].data.cpu().numpy()
            data_loader.dataset.ap.griffin_lim_iters = 60
            audio_signal = data_loader.dataset.ap.inv_spectrogram(
                audio_signal.T)
            # best-effort: a degenerate signal can make the TB writer raise,
            # which must not abort training
            try:
                tb.add_audio('SampleAudio', audio_signal, current_step,
                             sample_rate=c.sample_rate)
            except Exception:
                pass

    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss

    # Plot Training Epoch Stats
    tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
    tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
    epoch_time = 0

    return avg_linear_loss, current_step
Ejemplo n.º 2
0
def train(model,
          criterion,
          criterion_st,
          optimizer,
          optimizer_st,
          scheduler,
          ap,
          epoch,
          use_half=False):
    """Run one training epoch with optional fp16 and multi-GPU support.

    Decoder+postnet losses are optimized with ``optimizer``; when
    ``c.separate_stopnet`` is set, the stopnet loss gets its own backward
    pass through ``optimizer_st``. Stats are logged through the
    module-level ``tb_logger`` (rank 0 only).

    Args:
        model: Tacotron/Tacotron2-style model.
        criterion: spectrogram loss; takes ``(out, target[, lengths])``
            depending on ``c.loss_masking``.
        criterion_st: stop-token loss.
        optimizer: main-model optimizer.
        optimizer_st: stopnet optimizer (may be falsy when unused).
        scheduler: LR scheduler, stepped each iteration if ``c.lr_decay``.
        ap: audio processor used for spectrogram inversion / plots.
        epoch: current epoch index.
        use_half: if True, casts inputs to fp16 and scales losses.

    Returns:
        tuple: ``(avg_postnet_loss, current_step)``.
    """
    data_loader = setup_loader(is_val=False,
                               verbose=(epoch == 0),
                               use_half=use_half)

    model.train()
    epoch_time = 0
    avg_postnet_loss = 0
    avg_decoder_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # iterations per epoch across all GPUs
    batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    start_time = time.time()
    for num_iter, data in enumerate(data_loader):
        # setup input data; linear spectrograms exist only for "Tacotron"
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2] if c.model == "Tacotron" else None
        mel_input = data[3] if not use_half else data[3].type(torch.half)
        mel_lengths = data[4] if not use_half else data[4].type(torch.half)
        stop_targets = data[5]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())
        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        # a frame group is "stop" if any of its r frames is a stop frame
        stop_targets = (stop_targets.sum(2) >
                        0.0).unsqueeze(2).float().squeeze(2)
        stop_targets = stop_targets if not use_half else stop_targets.type(
            torch.half)

        # global step, continuous across restarts via args.restore_step
        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st: optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(
                non_blocking=True) if c.model == "Tacotron" else None
            stop_targets = stop_targets.cuda(non_blocking=True)
        # forward pass
        decoder_output, postnet_output, alignments, stop_tokens = model(
            text_input, text_lengths, mel_input)

        # loss computation; a zero placeholder keeps later .item() calls
        # valid when the stopnet is disabled
        stop_loss = criterion_st(stop_tokens,
                                 stop_targets) if c.stopnet else torch.zeros(1)

        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model == "Tacotron":
                # Tacotron's postnet predicts linear spectrograms
                postnet_loss = criterion(postnet_output, linear_input,
                                         mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input,
                                         mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model == "Tacotron":
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        # static loss scaling for fp16 to avoid gradient underflow
        USE_HALF_LOSS_SCALE = 10.0
        if use_half:
            postnet_loss = postnet_loss * USE_HALF_LOSS_SCALE
            decoder_loss = decoder_loss * USE_HALF_LOSS_SCALE
        loss = decoder_loss + postnet_loss

        # fold stop loss into the main loss when not trained separately
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        loss.backward()
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            # NOTE(review): scale is 1 regardless of use_half — confirm
            # whether fp16 stop-loss scaling was intended here
            USE_HALF_STOP_LOSS_SCALE = 1
            stop_loss = stop_loss * USE_HALF_STOP_LOSS_SCALE
            stop_loss.backward()
            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        start_time = time.time()
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  LR:{:.6f}"
                .format(num_iter, batch_n_iter, current_step, loss.item(),
                        postnet_loss.item(), decoder_loss.item(),
                        stop_loss.item(), grad_norm, grad_norm_st,
                        avg_text_length, avg_spec_length, step_time,
                        current_lr),
                flush=True)
        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data,
                                      num_gpus) if c.stopnet else stop_loss
        # only rank 0 accumulates stats and writes logs/checkpoints
        if args.rank == 0:
            avg_postnet_loss += float(postnet_loss.item())
            avg_decoder_loss += float(decoder_loss.item())
            avg_stop_loss += stop_loss if type(stop_loss) is float else float(
                stop_loss.item())
            avg_step_time += step_time
            # Plot Training Iter Stats
            iter_stats = {
                "loss_posnet": postnet_loss.item(),
                "loss_decoder": decoder_loss.item(),
                "lr": current_lr,
                "grad_norm": grad_norm,
                "grad_norm_st": grad_norm_st,
                "step_time": step_time
            }
            tb_logger.tb_train_iter_stats(current_step, iter_stats)
            if current_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH,
                                    current_step, epoch)

                # Diagnostic visualizations (cast back to fp32 for numpy)
                const_spec = postnet_output[0].data.cpu().type(
                    torch.float).numpy()
                gt_spec = linear_input[0].data.cpu().type(torch.float).numpy(
                ) if c.model == "Tacotron" else mel_input[0].data.cpu().type(
                    torch.float).numpy()
                align_img = alignments[0].data.cpu().type(torch.float).numpy()
                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(current_step, figures)
                # Sample audio: invert linear or mel spectrogram per model
                if c.model == "Tacotron":
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(current_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])

    avg_postnet_loss /= (num_iter + 1)
    avg_decoder_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print("   | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
          "AvgPostnetLoss:{:.5f}  AvgDecoderLoss:{:.5f}  "
          "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                      avg_postnet_loss, avg_decoder_loss,
                                      avg_stop_loss, epoch_time,
                                      avg_step_time),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": avg_postnet_loss,
            "loss_decoder": avg_decoder_loss,
            "stop_loss": avg_stop_loss,
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, current_step)
    return avg_postnet_loss, current_step
Ejemplo n.º 3
0
Archivo: train.py Proyecto: wurde/TTS
def train(model, criterion, criterion_st, optimizer, optimizer_st,
          scheduler, ap, epoch):
    """Run one training epoch.

    Mel+linear losses are optimized with ``optimizer`` and the stopnet
    loss with ``optimizer_st`` in two separate backward passes. Both
    optimizers get a manual (custom) weight-decay update before stepping.
    Per-iteration and per-epoch stats are written to the module-level
    tensorboard writer ``tb``.

    Args:
        model: Tacotron-style model exposing ``decoder.stopnet``.
        criterion: masked spectrogram loss ``(output, target, lengths)``.
        criterion_st: stop-token loss.
        optimizer: optimizer for the full model.
        optimizer_st: optimizer for the stopnet only.
        scheduler: LR scheduler, stepped each iteration if ``c.lr_decay``.
        ap: audio processor used for spectrogram inversion / plots.
        epoch: current epoch index.

    Returns:
        tuple: ``(avg_linear_loss, current_step)``.
    """
    data_loader = setup_loader(is_val=False)
    model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # number of low-frequency bins (below 3 kHz) given an extra priority loss
    n_priority_freq = int(
        3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
    batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]
        stop_targets = data[5]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        # global step, continuous across restarts via args.restore_step
        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True)
            stop_targets = stop_targets.cuda(non_blocking=True)

        # compute mask for padding
        mask = sequence_mask(text_lengths)

        # forward pass
        if use_cuda:
            mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
                model, (text_input, mel_input, mask))
        else:
            mel_output, linear_output, alignments, stop_tokens = model(
                text_input, mel_input, mask)

        # loss computation: linear loss mixes a full-band term with a
        # low-frequency "priority" term at equal weight
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
            + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                              linear_input[:, :, :n_priority_freq],
                              mel_lengths)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm for spec losses; retain_graph so
        # stop_loss.backward() below can reuse the same graph
        loss.backward(retain_graph=True)
        # custom weight decay
        # NOTE(review): uses the legacy ``tensor.add(scalar, tensor)``
        # overload (param - wd*lr*param); newer torch spells this
        # ``add(tensor, alpha=scalar)`` — confirm against the pinned version
        for group in optimizer.param_groups:
            for param in group['params']:
                current_lr = group['lr']
                param.data = param.data.add(-c.wd * group['lr'], param.data)
        grad_norm, skip_flag = check_update(model, 1)
        if skip_flag:
            optimizer.zero_grad()
            print("   | > Iteration skipped!!", flush=True)
            continue
        optimizer.step()

        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        # custom weight decay
        for group in optimizer_st.param_groups:
            for param in group['params']:
                param.data = param.data.add(-c.wd * group['lr'], param.data)
        grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
        if skip_flag:
            optimizer_st.zero_grad()
            print("   | > Iteration skipped for stopnet!!", flush=True)
            continue
        optimizer_st.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  LinearLoss:{:.5f}  "
                "MelLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, current_step, loss.item(),
                    linear_loss.item(), mel_loss.item(), stop_loss.item(),
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
                flush=True)

        avg_linear_loss += linear_loss.item()
        avg_mel_loss += mel_loss.item()
        avg_stop_loss += stop_loss.item()
        avg_step_time += step_time

        # Plot Training Iter Stats
        tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
        tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
                      current_step)
        tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
        tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                      current_step)
        tb.add_scalar('Params/GradNorm', grad_norm, current_step)
        tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
        tb.add_scalar('Time/StepTime', step_time, current_step)

        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, optimizer_st,
                                linear_loss.item(), OUT_PATH, current_step,
                                epoch)

            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
            gt_spec = linear_input[0].data.cpu().numpy()

            const_spec = plot_spectrogram(const_spec, ap)
            gt_spec = plot_spectrogram(gt_spec, ap)
            tb.add_figure('Visual/Reconstruction', const_spec, current_step)
            tb.add_figure('Visual/GroundTruth', gt_spec, current_step)

            align_img = alignments[0].data.cpu().numpy()
            align_img = plot_alignment(align_img)
            tb.add_figure('Visual/Alignment', align_img, current_step)

            # Sample audio via Griffin-Lim reconstruction
            audio_signal = linear_output[0].data.cpu().numpy()
            ap.griffin_lim_iters = 60
            audio_signal = ap.inv_spectrogram(audio_signal.T)
            # best-effort: a degenerate signal can make the TB writer raise,
            # which must not abort training
            try:
                tb.add_audio(
                    'SampleAudio',
                    audio_signal,
                    current_step,
                    sample_rate=c.audio['sample_rate'])
            except Exception:
                pass

    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print(
        "   | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
        "AvgLinearLoss:{:.5f}  AvgMelLoss:{:.5f}  "
        "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                    avg_linear_loss, avg_mel_loss,
                                    avg_stop_loss, epoch_time, avg_step_time),
        flush=True)

    # Plot Training Epoch Stats
    tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
    tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
    epoch_time = 0
    return avg_linear_loss, current_step
Ejemplo n.º 4
0
Archivo: train.py Proyecto: geneing/TTS
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch, criterion_gst=None, optimizer_gst=None):
    """Run one training epoch for a (possibly GST / multi-speaker) model.

    Decoder+postnet losses are optimized with ``optimizer``; the stopnet
    loss optionally gets its own backward pass (``c.separate_stopnet``),
    and a text-GST loss gets its own backward pass when ``c.text_gst`` and
    the GST criterion/optimizer are provided. Rank-0 stats are logged via
    the module-level ``tb_logger``.

    Args:
        model: model; forward returns
            (decoder_output, postnet_output, alignments, stop_tokens, text_gst).
        criterion: spectrogram loss; takes ``(out, target[, lengths])``
            depending on ``c.loss_masking``.
        criterion_st: stop-token loss.
        optimizer: main-model optimizer.
        optimizer_st: stopnet optimizer (may be None when unused).
        scheduler: LR scheduler, stepped each iteration if ``c.lr_decay``.
        ap: audio processor used for spectrogram inversion / plots.
        global_step: running global step, incremented here each iteration.
        epoch: current epoch index.
        criterion_gst: optional loss for the text-GST head.
        optimizer_gst: optional optimizer for the text-GST head.

    Returns:
        tuple: ``(avg_postnet_loss, global_step)``.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    # speaker name -> integer id mapping persisted alongside the checkpoints
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.train()
    epoch_time = 0
    avg_postnet_loss = 0
    avg_decoder_loss = 0
    avg_stop_loss = 0
    avg_gst_loss = 0
    avg_step_time = 0
    avg_loader_time = 0

    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data; linear spectrograms exist only for Tacotron(GST)
        text_input = data[0]
        text_lengths = data[1]
        speaker_names = data[2]
        linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
        mel_input = data[4]
        mel_lengths = data[5]
        stop_targets = data[6]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())
        # time spent waiting on the data loader since the last iteration
        loader_time = time.time() - end_time

        if c.use_speaker_embedding:
            speaker_ids = [speaker_mapping[speaker_name]
                           for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
        else:
            speaker_ids = None

        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        # a frame group is "stop" if any of its r frames is a stop frame
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

        global_step += 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_gst:
            optimizer_gst.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
            stop_targets = stop_targets.cuda(non_blocking=True)
            if speaker_ids is not None:
                speaker_ids = speaker_ids.cuda(non_blocking=True)

        # forward pass model
        decoder_output, postnet_output, alignments, stop_tokens, text_gst = model(
            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation; zero placeholders keep later .item() calls valid
        # when the corresponding heads are disabled
        stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
        gst_loss = torch.zeros(1)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model in ["Tacotron", "TacotronGST"]:
                # these models' postnet predicts linear spectrograms
                postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss
        # train the text-GST head to match the (detached) mel-derived GST
        # embedding; this backward/step runs before the main one
        if c.text_gst and criterion_gst and optimizer_gst:
            mel_gst, _ = model.gst(mel_input)
            gst_loss = criterion_gst(text_gst, mel_gst.squeeze().detach())
            gst_loss.backward()
            optimizer_gst.step()

        loss.backward()
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            stop_loss.backward()
            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0
        
        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f} GSTLoss:{:.5f} GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  "
                "LoaderTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, loss.item(),
                    postnet_loss.item(), decoder_loss.item(), stop_loss.item(), gst_loss.item(),
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
                    loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            gst_loss = reduce_tensor(gst_loss.data, num_gpus) if c.text_gst else gst_loss
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data, num_gpus) if c.stopnet else stop_loss

        # only rank 0 accumulates stats and writes logs/checkpoints
        if args.rank == 0:
            avg_postnet_loss += float(postnet_loss.item())
            avg_decoder_loss += float(decoder_loss.item())
            avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
            avg_gst_loss += float(gst_loss.item())
            avg_step_time += step_time
            avg_loader_time += loader_time

            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {"loss_posnet": postnet_loss.item(),
                              "loss_decoder": decoder_loss.item(),
                              "gst_loss" : gst_loss.item(),
                              "lr": current_lr,
                              "grad_norm": grad_norm,
                              "grad_norm_st": grad_norm_st,
                              "step_time": step_time}
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st, optimizer_gst,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else  mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio: invert linear or mel spectrogram per model
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    avg_postnet_loss /= (num_iter + 1)
    avg_decoder_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_gst_loss /= (num_iter + 1)
    # NOTE: gst loss is intentionally(?) excluded from the total — confirm
    avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)
    avg_loader_time /= (num_iter + 1)

    # print epoch stats
    print(
        "   | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
        "AvgPostnetLoss:{:.5f}  AvgDecoderLoss:{:.5f}  AvgGSTLoss:{:.5f} "
        "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
        "AvgStepTime:{:.2f}  AvgLoaderTime:{:.2f}".format(global_step, avg_total_loss,
                                                          avg_postnet_loss, avg_decoder_loss, avg_gst_loss,
                                                          avg_stop_loss, epoch_time, avg_step_time,
                                                          avg_loader_time),
        flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {"loss_postnet": avg_postnet_loss,
                       "loss_decoder": avg_decoder_loss,
                       "stop_loss": avg_stop_loss,
                       "gst_loss" : avg_gst_loss,
                       "epoch_time": epoch_time}
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return avg_postnet_loss, global_step
Ejemplo n.º 5
0
def train(model, optimizer, criterion, scheduler, epochs, batch_size, step, lr, args):
    """Train `model` for `epochs` epochs over the training loader.

    Steps the LR `scheduler` every batch, prints progress every
    ``CONFIG.print_step`` steps, checkpoints every ``CONFIG.checkpoint_step``
    steps (rank 0 only), and keeps the best model by validation loss after
    each epoch.

    Args:
        model: network to train; ``model.mode`` being an int selects the
            discrete-output (logits) path below.
        optimizer: optimizer whose param groups are re-seeded with `lr`.
        criterion: loss applied to ``(y_hat, y)``.
        scheduler: per-step learning-rate scheduler.
        epochs (int): number of epochs to run.
        batch_size (int): forwarded to `evaluate` only.
        step (int): global step counter to resume from.
        lr (float): base learning rate written into every param group.
        args: CLI namespace; ``args.rank`` gates checkpointing to rank 0.
    """
    import math  # local import: used only for the NaN guard below

    global CONFIG
    global train_ids
    # create train loader
    train_loader = setup_loader(False)

    # re-seed both "lr" and "initial_lr" so the scheduler restarts from `lr`
    for p in optimizer.param_groups:
        p["initial_lr"] = lr
        p["lr"] = lr

    best_loss = float('inf')
    for e in range(epochs):
        running_loss = 0.0
        # FIX: reset per epoch (was cumulative across epochs, which corrupted
        # the running-average denominator from the 2nd epoch on) and make sure
        # `avg_loss` exists even if the very first step is skipped.
        skipped_steps = 0
        avg_loss = 0.0
        start = time.time()
        iters = len(train_loader)
        # train loop
        print(" > Training", flush=True)
        model.train()
        for i, (x, m, y) in enumerate(train_loader):
            if use_cuda:
                x, m, y = x.cuda(), m.cuda(), y.cuda()
            optimizer.zero_grad()
            y_hat = model(x, m)
            if type(model.mode) == int:
                # discrete output: reshape logits for the criterion
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            else:
                y = y.float()
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)
            # FIX: `loss.item() is None` was always False (`.item()` returns a
            # float); the intent is to abort when the loss turns NaN.
            if math.isnan(loss.item()):
                raise RuntimeError(" [!] NaN loss. Exiting ...")
            loss.backward()
            grad_norm, skip_flag = check_update(model, CONFIG.grad_clip)
            if not skip_flag:
                optimizer.step()
                # Compute avg loss
                if num_gpus > 1:
                    loss = reduce_tensor(loss.data, num_gpus)
                running_loss += loss.item()
                # FIX: guard the denominator; it was 0 when every step so far
                # had been skipped.
                avg_loss = running_loss / max(1, i + 1 - skipped_steps)
            else:
                print(" [!] Skipping the step...")
                skipped_steps += 1
            scheduler.step()
            speed = (i + 1) / (time.time() - start)
            step += 1
            cur_lr = optimizer.param_groups[0]["lr"]

            if step % CONFIG.print_step == 0:
                print(
                    " | > Epoch: {}/{} -- Batch: {}/{} -- Loss: {:.3f}"
                    " -- Speed: {:.2f} steps/sec -- Step: {} -- lr: {} -- GradNorm: {}".format(
                        e + 1, epochs, i + 1, iters, avg_loss, speed, step, cur_lr, grad_norm
                    ), flush=True
                )
            if step % CONFIG.checkpoint_step == 0 and args.rank == 0:
                save_checkpoint(model, optimizer, avg_loss, MODEL_PATH, step, e)
                print(" > checkpoint saved", flush=True)
        # validation loop
        avg_val_loss = evaluate(model, criterion, batch_size)
        if args.rank == 0:
            best_loss = save_best_model(model, optimizer, avg_val_loss, best_loss, MODEL_PATH, step, e)
Ejemplo n.º 6
0
def train(model, criterion, data_loader, optimizer, epoch):
    """Run one training epoch of an early Tacotron setup (legacy PyTorch API).

    Per batch: applies warmup lr decay, forwards text+mel through the model,
    combines a mel loss with a priority-frequency weighted linear loss,
    clips/checks gradients, and logs iteration stats (plus periodic
    checkpoints, spectrogram figures and a Griffin-Lim audio sample) to
    TensorBoard.

    Args:
        model: Tacotron model returning (mel, linear, alignments).
        criterion: masked loss taking (prediction, target, lengths).
        data_loader: yields (text, text_lengths, linear, mel, mel_lengths).
        optimizer: optimizer whose lr is overwritten every step.
        epoch: current epoch index, used for the global step offset.

    Returns:
        tuple: (avg_linear_loss, current_step) after the epoch.
    """
    model = model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0

    print(" | > Epoch {}/{}".format(epoch, c.epochs))
    progbar = Progbar(len(data_loader.dataset) / c.batch_size)
    # number of low-frequency bins (up to ~3 kHz) that get extra loss weight
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]

        # global step resumes from args.restore_step and spans epochs
        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
        for params_group in optimizer.param_groups:
            params_group['lr'] = current_lr

        optimizer.zero_grad()

        # convert inputs to variables (pre-0.4 PyTorch autograd API)
        text_input_var = Variable(text_input)
        mel_spec_var = Variable(mel_input)
        mel_lengths_var = Variable(mel_lengths)
        # NOTE(review): `volatile=True` on a tensor that feeds the linear loss
        # disables graph construction in pre-0.4 PyTorch -- confirm this does
        # not break loss.backward() on the combined loss.
        linear_spec_var = Variable(linear_input, volatile=True)

        # dispatch data to GPU
        if use_cuda:
            text_input_var = text_input_var.cuda()
            mel_spec_var = mel_spec_var.cuda()
            mel_lengths_var = mel_lengths_var.cuda()
            linear_spec_var = linear_spec_var.cuda()

        # forward pass
        mel_output, linear_output, alignments =\
            model.forward(text_input_var, mel_spec_var)

        # loss computation: full-band linear loss plus extra weight on the
        # low-frequency (priority) bins
        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
            + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                              linear_spec_var[:, :, :n_priority_freq],
                              mel_lengths_var)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm; skip the update on exploded grads
        loss.backward()
        grad_norm, skip_flag = check_update(model, 0.5, 100)
        if skip_flag:
            optimizer.zero_grad()
            print(" | > Iteration skipped!!")
            continue
        optimizer.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update progress bar (`.data[0]` is the pre-0.4 spelling of `.item()`)
        progbar.update(num_iter + 1,
                       values=[('total_loss', loss.data[0]),
                               ('linear_loss', linear_loss.data[0]),
                               ('mel_loss', mel_loss.data[0]),
                               ('grad_norm', grad_norm)])
        avg_linear_loss += linear_loss.data[0]
        avg_mel_loss += mel_loss.data[0]

        # Plot Training Iter Stats
        tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step)
        tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
                      current_step)
        tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
        tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                      current_step)
        tb.add_scalar('Params/GradNorm', grad_norm, current_step)
        tb.add_scalar('Time/StepTime', step_time, current_step)

        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, linear_loss.data[0],
                                OUT_PATH, current_step, epoch)

            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
            gt_spec = linear_spec_var[0].data.cpu().numpy()

            const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
            gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
            tb.add_image('Visual/Reconstruction', const_spec, current_step)
            tb.add_image('Visual/GroundTruth', gt_spec, current_step)

            align_img = alignments[0].data.cpu().numpy()
            align_img = plot_alignment(align_img)
            tb.add_image('Visual/Alignment', align_img, current_step)

            # Sample audio: Griffin-Lim reconstruction of the predicted linear
            # spectrogram (best effort -- logging failures are swallowed)
            audio_signal = linear_output[0].data.cpu().numpy()
            data_loader.dataset.ap.griffin_lim_iters = 60
            audio_signal = data_loader.dataset.ap.inv_spectrogram(
                audio_signal.T)
            try:
                tb.add_audio('SampleAudio',
                             audio_signal,
                             current_step,
                             sample_rate=c.sample_rate)
            except:
                # print("\n > Error at audio signal on TB!!")
                # print(audio_signal.max())
                # print(audio_signal.min())
                pass

    # epoch averages over the number of iterations actually run
    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss

    # Plot Training Epoch Stats
    tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
    tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
    epoch_time = 0

    return avg_linear_loss, current_step
Ejemplo n.º 7
0
def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
    """Run one training epoch of a Tacotron model with stop-token prediction.

    Per batch: applies warmup lr decay, forwards text+mel, combines mel,
    linear and stop-token losses (optionally a priority-frequency weighted
    linear loss and a decayed guided-attention loss), clips/checks gradients,
    and logs stats plus periodic checkpoints/figures/audio to TensorBoard.

    Args:
        model: returns (mel, linear, alignments, stop_tokens).
        criterion: masked spectrogram loss (pred, target, lengths).
        criterion_st: stop-token loss (tokens, targets).
        data_loader: yields (text, text_lengths, linear, mel, mel_lengths,
            stop_targets).
        optimizer: optimizer whose lr is overwritten every step.
        epoch (int): current epoch index.

    Returns:
        tuple: (avg_linear_loss, current_step).
    """
    model = model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    avg_attn_loss = 0

    print(" | > Epoch {}/{}".format(epoch, c.epochs))
    progbar = Progbar(len(data_loader.dataset) / c.batch_size)
    progbar_display = {}
    # FIX: `n_priority_freq` was undefined in this function and raised a
    # NameError as soon as `c.priority_freq` was enabled; the definition
    # mirrors the sibling train() implementations (bins below ~3 kHz).
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_spec = data[2]
        mel_spec = data[3]
        mel_lengths = data[4]
        stop_target = data[5]

        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
        for params_group in optimizer.param_groups:
            params_group['lr'] = current_lr

        optimizer.zero_grad()

        # one stop target per r decoder frames
        stop_target = stop_target.view(text_input.shape[0], stop_target.size(1) // c.r, -1)
        stop_target = (stop_target.sum(2) > 0.0).float()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda()
            mel_spec = mel_spec.cuda()
            mel_lengths = mel_lengths.cuda()
            linear_spec = linear_spec.cuda()
            stop_target = stop_target.cuda()

        # create attention mask for the guided-attention loss
        if c.mk > 0.0:
            N = text_input.shape[1]
            T = mel_spec.shape[1] // c.r
            M = create_attn_mask(N, T, 0.03)
            # FIX: keep the mask on the same device as the alignments it is
            # compared against below (a no-op if create_attn_mask already
            # returns a CUDA tensor).
            if use_cuda:
                M = M.cuda()
            mk = mk_decay(c.mk, c.epochs, epoch)

        # forward pass
        mel_output, linear_output, alignments, stop_tokens =\
            model.forward(text_input, mel_spec)

        # loss computation
        mel_loss = criterion(mel_output, mel_spec, mel_lengths)
        linear_loss = criterion(linear_output, linear_spec, mel_lengths)
        stop_loss = criterion_st(stop_tokens, stop_target)
        if c.priority_freq:
            # extra weight on the low-frequency (priority) bins
            linear_loss = 0.5 * linear_loss\
                + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                  linear_spec[:, :, :n_priority_freq],
                                  mel_lengths)
        loss = mel_loss + linear_loss + stop_loss
        if c.mk > 0.0:
            attention_loss = criterion(alignments, M, mel_lengths)
            loss += mk * attention_loss
            avg_attn_loss += attention_loss.item()
            progbar_display['attn_loss'] = attention_loss.item()

        # backpass and check the grad norm; skip the update on exploded grads
        loss.backward()
        grad_norm, skip_flag = check_update(model, 0.5, 100)
        if skip_flag:
            optimizer.zero_grad()
            print(" | > Iteration skipped!!")
            continue
        optimizer.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        progbar_display['total_loss'] = loss.item()
        progbar_display['linear_loss'] = linear_loss.item()
        progbar_display['mel_loss'] = mel_loss.item()
        progbar_display['stop_loss'] = stop_loss.item()
        progbar_display['grad_norm'] = grad_norm.item()

        # update
        progbar.update(num_iter + 1, values=list(progbar_display.items()))
        avg_linear_loss += linear_loss.item()
        avg_mel_loss += mel_loss.item()
        avg_stop_loss += stop_loss.item()

        # Plot Training Iter Stats
        tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
        tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
                      current_step)
        tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
        tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                      current_step)
        tb.add_scalar('Params/GradNorm', grad_norm, current_step)
        tb.add_scalar('Time/StepTime', step_time, current_step)

        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, linear_loss.item(),
                                OUT_PATH, current_step, epoch)

            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
            gt_spec = linear_spec[0].data.cpu().numpy()

            const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
            gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
            tb.add_image('Visual/Reconstruction', const_spec, current_step)
            tb.add_image('Visual/GroundTruth', gt_spec, current_step)

            align_img = alignments[0].data.cpu().numpy()
            align_img = plot_alignment(align_img)
            tb.add_image('Visual/Alignment', align_img, current_step)

            # Sample audio: Griffin-Lim reconstruction of the predicted linear
            # spectrogram
            audio_signal = linear_output[0].data.cpu().numpy()
            data_loader.dataset.ap.griffin_lim_iters = 60
            audio_signal = data_loader.dataset.ap.inv_spectrogram(
                audio_signal.T)
            try:
                tb.add_audio('SampleAudio', audio_signal, current_step,
                             sample_rate=c.sample_rate)
            except Exception:
                # best-effort logging: a bad/clipped signal must not kill
                # training
                pass

    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss

    # Plot Training Epoch Stats
    tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
    tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
    if c.mk > 0:
        avg_attn_loss /= (num_iter + 1)
        tb.add_scalar('TrainEpochLoss/AttnLoss', avg_attn_loss, current_step)
    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
    epoch_time = 0

    return avg_linear_loss, current_step
Ejemplo n.º 8
0
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, epoch):
    """Run one (possibly multi-GPU) training epoch of a Tacotron model.

    Uses two optimizers: `optimizer` for the spec losses (backward with
    retain_graph) and `optimizer_st` for the stop-token loss on the stopnet.
    Logs iteration/epoch stats, figures and audio via `tb_logger` (rank 0).

    Args:
        model: returns (mel, linear, alignments, stop_tokens) given
            (text, mel, mask).
        criterion: masked spectrogram loss (pred, target, lengths).
        criterion_st: stop-token loss.
        optimizer: optimizer for the main model losses.
        optimizer_st: optimizer for the stopnet only.
        scheduler: lr scheduler, stepped per batch when `c.lr_decay`.
        ap: audio processor (spectrogram inversion for samples).
        epoch (int): current epoch index.

    Returns:
        tuple: (avg_linear_loss, current_step).
    """
    data_loader = setup_loader(is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # number of low-frequency bins (up to ~3 kHz) that get extra loss weight
    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) *
                          c.audio['num_freq'])
    # iterations per epoch (per-process batch count under multi-GPU)
    if num_gpus > 0:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]
        stop_targets = data[5]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        # global step resumes from args.restore_step and spans epochs
        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True)
            stop_targets = stop_targets.cuda(non_blocking=True)

        # compute mask for padding
        mask = sequence_mask(text_lengths)

        # forward pass
        mel_output, linear_output, alignments, stop_tokens = model(
            text_input, mel_input, mask)

        # loss computation: linear loss mixes full-band and priority-band
        # terms weighted by c.loss_weight
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = (1 - c.loss_weight) * criterion(linear_output, linear_input, mel_lengths)\
            + c.loss_weight * criterion(linear_output[:, :, :n_priority_freq],
                              linear_input[:, :, :n_priority_freq],
                              mel_lengths)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm for spec losses
        # (retain_graph so stop_loss can backprop through the same graph)
        loss.backward(retain_graph=True)
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, 1.0)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
        optimizer_st.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(
                " | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  LinearLoss:{:.5f}  "
                "MelLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  LR:{:.6f}"
                .format(num_iter, batch_n_iter, current_step, loss.item(),
                        linear_loss.item(), mel_loss.item(), stop_loss.item(),
                        grad_norm, grad_norm_st, avg_text_length,
                        avg_spec_length, step_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            linear_loss = reduce_tensor(linear_loss.data, num_gpus)
            mel_loss = reduce_tensor(mel_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data, num_gpus)

        # logging, checkpointing and visualizations happen on rank 0 only
        if args.rank == 0:
            avg_linear_loss += float(linear_loss.item())
            avg_mel_loss += float(mel_loss.item())
            avg_stop_loss += stop_loss.item()
            avg_step_time += step_time

            # Plot Training Iter Stats
            iter_stats = {
                "loss_posnet": linear_loss.item(),
                "loss_decoder": mel_loss.item(),
                "lr": current_lr,
                "grad_norm": grad_norm,
                "grad_norm_st": grad_norm_st,
                "step_time": step_time
            }
            tb_logger.tb_train_iter_stats(current_step, iter_stats)

            if current_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    linear_loss.item(), OUT_PATH, current_step,
                                    epoch)

                # Diagnostic visualizations
                const_spec = linear_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(current_step, figures)

                # Sample audio
                tb_logger.tb_train_audios(
                    current_step,
                    {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
                    c.audio["sample_rate"])

    # epoch averages over the number of iterations actually run
    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print(" | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
          "AvgLinearLoss:{:.5f}  AvgMelLoss:{:.5f}  "
          "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                      avg_linear_loss, avg_mel_loss,
                                      avg_stop_loss, epoch_time,
                                      avg_step_time),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": avg_linear_loss,
            "loss_decoder": avg_mel_loss,
            "stop_loss": avg_stop_loss,
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, current_step)
    return avg_linear_loss, current_step
Ejemplo n.º 9
0
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch):
    """Run one training epoch of a VAE-augmented Tacotron model.

    Adds an annealed KL term (from the model's `mu`/`logvar` outputs) to the
    decoder + postnet losses, optionally trains a bidirectional decoder and a
    separate stopnet optimizer, and tracks running averages with KeepAverage.
    Logging/checkpointing happens on rank 0 only.

    Args:
        model: returns (decoder_out, postnet_out, alignments, stop_tokens,
            [backward decoder outputs,] mu, logvar, z).
        criterion: spectrogram loss; takes lengths when `c.loss_masking`.
        criterion_st: stop-token loss (used when `c.stopnet`).
        optimizer: main optimizer (adam_weight_decay wrapper).
        optimizer_st: stopnet optimizer, used when `c.separate_stopnet`.
        scheduler: lr scheduler, stepped per batch when `c.noam_schedule`.
        ap: audio processor for spectrogram inversion.
        global_step (int): running global step, incremented per batch.
        epoch (int): current epoch index.

    Returns:
        tuple: (avg_postnet_loss, global_step).
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    # running-average slots managed by KeepAverage
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0,
        'avg_alignment_score': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # iterations per epoch (per-process batch count under multi-GPU)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(
            data)
        # time spent waiting on the data loader since the last batch
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward, mu, logvar, z = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens, mu, logvar, z = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation (zero stop loss when the stopnet is disabled)
        stop_loss = criterion_st(stop_tokens,
                                 stop_targets) if c.stopnet else torch.zeros(1)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            # Tacotron/TacotronGST postnets predict linear specs; others mel
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input,
                                         mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input,
                                         mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        # analytic KL divergence of N(mu, sigma) against N(0, 1), weighted by
        # an annealing schedule from the VAE config
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_weight = kl_anneal_function(c.VAE['anneal_function'],
                                       c.VAE['anneal_lag'], global_step,
                                       c.VAE['anneal_k'], c.VAE['anneal_x0'],
                                       c.VAE["anneal_upper"])

        loss = loss + kl_weight * kl_loss

        # backward decoder: loss on the time-reversed decoder plus a
        # consistency term between the two decoder directions
        if c.bidirectional_decoder:
            if c.loss_masking:
                decoder_backward_loss = criterion(
                    torch.flip(decoder_backward_output, dims=(1, )), mel_input,
                    mel_lengths)
            else:
                decoder_backward_loss = criterion(
                    torch.flip(decoder_backward_output, dims=(1, )), mel_input)
            decoder_c_loss = torch.nn.functional.l1_loss(
                torch.flip(decoder_backward_output, dims=(1, )),
                decoder_output)
            loss += decoder_backward_loss + decoder_c_loss
            keep_avg.update_values({
                'avg_decoder_b_loss':
                decoder_backward_loss.item(),
                'avg_decoder_c_loss':
                decoder_c_loss.item()
            })

        loss.backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, grad_flag = check_update(model,
                                            c.grad_clip,
                                            ignore_stopnet=True)
        optimizer.step()

        # compute alignment score
        align_score = alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_score', align_score)

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            stop_loss.backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f}  AlignScore:{:.4f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  "
                "LoaderTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, postnet_loss.item(),
                    decoder_loss.item(), stop_loss.item(), align_score,
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
                    step_time, loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data,
                                      num_gpus) if c.stopnet else stop_loss

        # logging, checkpointing and visualizations happen on rank 0 only
        if args.rank == 0:
            update_train_values = {
                'avg_postnet_loss':
                float(postnet_loss.item()),
                'avg_decoder_loss':
                float(decoder_loss.item()),
                'avg_stop_loss':
                stop_loss
                if isinstance(stop_loss, float) else float(stop_loss.item()),
                'avg_step_time':
                step_time,
                'avg_loader_time':
                loader_time
            }
            keep_avg.update_values(update_train_values)

            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": postnet_loss.item(),
                    "loss_decoder": decoder_loss.item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                #print('>>>>>>>>>>>>>>100>>>>>>>>>>>>>>>>')
                if c.checkpoint:
                    # save model
                    #print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>in check point')
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy())

                #tb_logger.tb_train_figures(global_step, figures)

                # Sample audio (figure/audio TB logging is currently disabled)
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                #tb_logger.tb_train_audios(global_step,{'TrainAudio': train_audio}, c.audio["sample_rate"])

        end_time = time.time()

    # print epoch stats
    print("   | > EPOCH END -- GlobalStep:{}  "
          "AvgPostnetLoss:{:.5f}  AvgDecoderLoss:{:.5f}  "
          "AvgStopLoss:{:.5f}  AvgAlignScore:{:3f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}  AvgLoaderTime:{:.2f}".format(
              global_step, keep_avg['avg_postnet_loss'],
              keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
              keep_avg['avg_align_score'], epoch_time,
              keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
          flush=True)
    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stop_loss": keep_avg['avg_stop_loss'],
            "alignment_score": keep_avg['avg_align_score'],
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg['avg_postnet_loss'], global_step