Example #1

import time
from os.path import join

import torch

# ValueWindow, log, plot_alignment, hparams and the module-level
# global_step / global_epoch counters are assumed to come from the
# surrounding project; they are not defined in this snippet.


def train(train_loader, model, device, mels_criterion, stop_criterion,
          optimizer, scheduler, writer, train_dir):
    batch_time = ValueWindow()
    data_time = ValueWindow()
    losses = ValueWindow()

    # switch to train mode
    model.train()

    end = time.time()
    global global_epoch
    global global_step
    for i, (txts, mels, stop_tokens, txt_lengths,
            mels_lengths) in enumerate(train_loader):
        scheduler.adjust_learning_rate(optimizer, global_step)
        # measure data loading time
        data_time.update(time.time() - end)

        # move the batch to the GPU when a CUDA device index is given
        if device > -1:
            txts = txts.cuda(device)
            mels = mels.cuda(device)
            stop_tokens = stop_tokens.cuda(device)
            txt_lengths = txt_lengths.cuda(device)
            mels_lengths = mels_lengths.cuda(device)

        # forward pass: the model returns the output mel frames, the decoder
        # mel frames, stop-token predictions and the attention alignment
        frames, decoder_frames, stop_tokens_predict, alignment = model(
            txts, txt_lengths, mels)
        decoder_frames_loss = mels_criterion(decoder_frames,
                                             mels,
                                             lengths=mels_lengths)
        frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
        stop_token_loss = stop_criterion(stop_tokens_predict,
                                         stop_tokens,
                                         lengths=mels_lengths)
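        # total objective: decoder frames loss + output frames loss
        # + stop-token loss (matching the scalar tags logged below)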
        loss = decoder_frames_loss + frames_loss + stop_token_loss

        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
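        # optionally clip the global gradient norm before the optimizer step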
        if hparams.clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.get_trainable_parameters(), hparams.clip_thresh)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % hparams.print_freq == 0:
            log('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    global_epoch,
                    i,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses))

        # Logs
        writer.add_scalar("loss", float(loss.item()), global_step)
        writer.add_scalar(
            "avg_loss in {} window".format(losses.get_window_size),
            float(losses.avg), global_step)
        writer.add_scalar("stop_token_loss", float(stop_token_loss.item()),
                          global_step)
        writer.add_scalar("decoder_frames_loss",
                          float(decoder_frames_loss.item()), global_step)
        writer.add_scalar("output_frames_loss", float(frames_loss.item()),
                          global_step)

        if hparams.clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, global_step)
        writer.add_scalar("learning rate", optimizer.param_groups[0]['lr'],
                          global_step)
        global_step += 1

    # after the epoch, plot the attention alignment for the first sample of
    # the last batch
    dst_alignment_path = join(train_dir,
                              "{}_alignment.png".format(global_step))
    alignment = alignment.cpu().detach().numpy()
    plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                   dst_alignment_path,
                   info="{}, {}".format(hparams.builder, global_step))
Example #2

# Assumes the same imports and project-level helpers as Example #1.
def validate(val_loader, model, device, mels_criterion, stop_criterion, writer,
             val_dir):
    batch_time = ValueWindow()
    losses = ValueWindow()

    # switch to evaluate mode
    model.eval()

    global global_epoch
    global global_step
    with torch.no_grad():
        end = time.time()
        for i, (txts, mels, stop_tokens, txt_lengths,
                mels_lengths) in enumerate(val_loader):

            if device > -1:
                txts = txts.cuda(device)
                mels = mels.cuda(device)
                stop_tokens = stop_tokens.cuda(device)
                txt_lengths = txt_lengths.cuda(device)
                mels_lengths = mels_lengths.cuda(device)

            # compute output
            frames, decoder_frames, stop_tokens_predict, alignment = model(
                txts, txt_lengths, mels)
            decoder_frames_loss = mels_criterion(decoder_frames,
                                                 mels,
                                                 lengths=mels_lengths)
            frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
            stop_token_loss = stop_criterion(stop_tokens_predict,
                                             stop_tokens,
                                             lengths=mels_lengths)
            loss = decoder_frames_loss + frames_loss + stop_token_loss

            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % hparams.print_freq == 0:
                log('Epoch: [{0}]\t'
                    'Test: [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        global_epoch,
                        i,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses))
            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar(
                "avg_loss in {} window".format(losses.get_window_size),
                float(losses.avg), global_step)
            writer.add_scalar("stop_token_loss", float(stop_token_loss.item()),
                              global_step)
            writer.add_scalar("decoder_frames_loss",
                              float(decoder_frames_loss.item()), global_step)
            writer.add_scalar("output_frames_loss", float(frames_loss.item()),
                              global_step)

        dst_alignment_path = join(val_dir,
                                  "{}_alignment.png".format(global_step))
        alignment = alignment.cpu().detach().numpy()
        plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                       dst_alignment_path,
                       info="{}, {}".format(hparams.builder, global_step))

    return losses.avg
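
For context, here is a hypothetical driver loop showing how the two functions
might be wired together. The SummaryWriter setup, the num_epochs argument and
the checkpointing policy are illustrative assumptions, not part of the
original examples.

from torch.utils.tensorboard import SummaryWriter


def main(train_loader, val_loader, model, device, mels_criterion,
         stop_criterion, optimizer, scheduler, num_epochs, train_dir, val_dir):
    global global_epoch  # module-level counter assumed by train()/validate()
    writer = SummaryWriter()  # assumed TensorBoard logging setup
    best_loss = float('inf')
    for _ in range(num_epochs):
        train(train_loader, model, device, mels_criterion, stop_criterion,
              optimizer, scheduler, writer, train_dir)
        val_loss = validate(val_loader, model, device, mels_criterion,
                            stop_criterion, writer, val_dir)
        # keep the checkpoint with the best validation loss (illustrative policy)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), join(train_dir, 'best_model.pth'))
        global_epoch += 1
    writer.close()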