Example #1
import torch
import torch.nn as nn
from tqdm import tqdm

# Relies on project-level globals defined elsewhere in the source repo:
# args, DEVICE, att2img, save_model, evaluate, load_vocab_tool.
def train(model, data_loader, valid_loader, optimizer, scheduler, batch_size=32, ckpt_dir=None, writer=None, mode='1'):
    epochs = 0
    global_step = args.global_step
    l1_criterion = nn.L1Loss()  # reduction='mean' by default
    bd_criterion = nn.BCELoss()  # binary divergence term (DCTTS-style); assumes mels are normalized to [0, 1]
    GO_frames = torch.zeros([batch_size, 1, args.n_mels]).to(DEVICE)  # (N, 1, n_mels) all-zero GO frame
    idx2char = load_vocab_tool(args.lang)[-1]
    while global_step < args.max_step:
        epoch_loss = 0
        for step, (texts, mels, pmels, gas, f0) in tqdm(enumerate(data_loader), total=len(data_loader), unit='B', ncols=70, leave=False):
            optimizer.zero_grad()
            texts, mels, pmels, gas, f0 = texts.to(DEVICE), mels.to(DEVICE), pmels.to(DEVICE), gas.to(DEVICE), f0.to(DEVICE)
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            mels_hat, pmels_hat, A = model(texts, prev_mels, f0)  # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r)
            
            mel_loss = l1_criterion(mels_hat, mels)
            bd_loss = bd_criterion(mels_hat, mels)
            pmel_loss = l1_criterion(pmels_hat, pmels)
            att_loss = torch.mean(A*gas)  # guided attention: attention mass weighted by the off-diagonal mask
            loss = mel_loss + bd_loss + att_loss + pmel_loss
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            if args.lr_decay:
                scheduler.step()
            if global_step % args.save_term == 0:
                model.eval()
                val_loss = evaluate(model, valid_loader, writer, global_step, args.test_batch)
                save_model(model, optimizer, scheduler, val_loss, global_step, ckpt_dir) # save best 5 models
                model.train()
            global_step += 1
        if args.log_mode:
            # Summary: logged once per epoch, from the last batch of the epoch
            writer.add_scalar('train/mel_loss', mel_loss.item(), global_step)
            writer.add_scalar('train/pmel_loss', pmel_loss.item(), global_step)
            writer.add_scalar('train/lr', scheduler.get_last_lr()[0], global_step)
            alignment = A[0:1].clone().cpu().detach().numpy()
            guided_att = gas[0:1].clone().cpu().detach().numpy()
            writer.add_image('train/alignments', att2img(alignment), global_step) # (Tx, Ty)
            writer.add_image('train/guided_att', att2img(guided_att), global_step) # (Tx, Ty)
            writer.add_scalar('train/ga_loss', att_loss.item(), global_step)
            # text = texts[0].cpu().detach().numpy()
            # text = [idx2char[ch] for ch in text]
            # plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, type(model).__name__, 'A', 'train'))
            mel_hat = mels_hat[0:1].transpose(1,2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('train/mel_hat', mel_hat, global_step)
            writer.add_image('train/mel', mel, global_step)
            # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
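Note: the `att_loss = torch.mean(A * gas)` term above is the guided-attention penalty from DCTTS (Tachibana et al., 2017): `gas` holds precomputed weights that grow with distance from the diagonal of the text/audio alignment, so off-diagonal attention is penalized. The snippets never show how that mask is built; a minimal sketch of the usual formulation, with the width `g=0.2` as an assumption, would be:

import torch

def guided_attention_mask(Tx, Ty, g=0.2):
    """DCTTS-style weights W[n, t] = 1 - exp(-((n/Tx - t/Ty)^2) / (2 g^2)).

    Values are near 0 on the diagonal and approach 1 far from it, so
    mean(A * W) penalizes attention that strays from a monotonic path.
    """
    n = torch.arange(Tx).float().unsqueeze(1) / Tx  # (Tx, 1)
    t = torch.arange(Ty).float().unsqueeze(0) / Ty  # (1, Ty)
    return 1.0 - torch.exp(-((n - t) ** 2) / (2 * g ** 2))  # (Tx, Ty)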
Example #2
# Same imports as Example #1; also relies on globals args, DEVICE, att2img.
def evaluate(model, data_loader, writer, global_step, batch_size=100):
    A = None  # keeps the last batch's alignment for the logging below
    l1_loss = nn.L1Loss()
    with torch.no_grad():
        mel_sum_loss = 0.
        pmel_sum_loss = 0.
        for step, (texts, mels, pmels, gas, f0) in enumerate(data_loader):
            texts, mels, pmels, gas, f0 = texts.to(DEVICE), mels.to(DEVICE), pmels.to(DEVICE), gas.to(DEVICE), f0.to(DEVICE)
            GO_frames = torch.zeros([mels.shape[0], 1, args.n_mels]).to(DEVICE)  # (N, 1, n_mels) all-zero GO frame
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)

            mels_hat, pmels_hat, A = model(texts, prev_mels, f0)  # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r)
            mel_loss = l1_loss(mels_hat, mels)
            pmel_loss = l1_loss(pmels_hat, pmels)
            att_loss = torch.mean(A*gas)  # guided-attention loss; only the last batch's value is logged below
            mel_sum_loss += mel_loss.item()
            pmel_sum_loss += pmel_loss.item()

        mel_avg_loss = mel_sum_loss / (len(data_loader))
        pmel_avg_loss = pmel_sum_loss / (len(data_loader))
        writer.add_scalar('eval/mel_loss', mel_avg_loss, global_step)
        writer.add_scalar('eval/pmel_loss', pmel_avg_loss, global_step)
        writer.add_scalar('eval/ga_loss', att_loss.item(), global_step)
        alignment = A[0:1].clone().cpu().detach().numpy()
        guided_att = gas[0:1].clone().cpu().detach().numpy()
        writer.add_image('eval/alignments', att2img(alignment), global_step) # (Tx, Ty)
        writer.add_image('eval/guided_att', att2img(guided_att), global_step) # (Tx, Ty)
        # text = texts[0].cpu().detach().numpy()
        # text = [load_vocab_tool(args.lang)[-1][ch] for ch in text]
        # plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, args.model_name, 'A'))
        writer.add_image('eval/mel_hat', mels_hat[0:1].transpose(1,2), global_step)
        writer.add_image('eval/mel', mels[0:1].transpose(1,2), global_step)
        writer.add_image('eval/pmel_hat', pmels_hat[0:1].transpose(1,2), global_step)
        writer.add_image('eval/pmel', pmels[0:1].transpose(1,2), global_step)
    return mel_avg_loss
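Both `train` and `evaluate` feed the decoder with teacher forcing: `prev_mels` is the ground-truth mel sequence shifted right by one frame, with an all-zero GO frame prepended, so the model predicts frame t conditioned on true frames up to t-1. A self-contained shape check (the sizes here are made up):

import torch

N, Ty, n_mels = 2, 5, 80
mels = torch.randn(N, Ty, n_mels)                    # ground-truth mels
go = torch.zeros(N, 1, n_mels)                       # all-zero GO frame
prev_mels = torch.cat((go, mels[:, :-1, :]), dim=1)  # shifted right by one frame
assert prev_mels.shape == mels.shape
# prev_mels[:, 0] is the GO frame; prev_mels[:, t] == mels[:, t-1] for t >= 1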
Example #3
# Assumes: import os, torch, torch.nn as nn; relies on project globals
# args, DEVICE, att2img, plot_att, load_vocab.
def evaluate(model,
             data_loader,
             criterion,
             writer,
             global_step,
             batch_size=100):
    valid_loss_mel = 0.
    valid_loss_mag = 0.
    A = None
    with torch.no_grad():
        for step, (texts, mels, mags) in enumerate(data_loader):
            texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), mags.to(
                DEVICE)
            GO_frames = torch.zeros([mels.shape[0], 1, args.n_mels * args.r
                                     ]).to(DEVICE)  # (N, 1, n_mels*r) all-zero GO frame
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            mels_hat, mags_hat, A = model(texts, prev_mels)

            loss_mel = criterion(mels_hat, mels)
            loss_mag = criterion(mags_hat, mags)
            valid_loss_mel += loss_mel.item()
            valid_loss_mag += loss_mag.item()
        avg_loss_mel = valid_loss_mel / (len(data_loader))
        avg_loss_mag = valid_loss_mag / (len(data_loader))
        writer.add_scalar('eval/loss_mel', avg_loss_mel, global_step)
        writer.add_scalar('eval/loss_mag', avg_loss_mag, global_step)

        alignment = A[0:1].clone().cpu().detach().numpy()
        writer.add_image('eval/alignments', att2img(alignment),
                         global_step)  # (Tx, Ty)
        text = texts[0].cpu().detach().numpy()
        text = [load_vocab()[-1][ch] for ch in text]
        plot_att(alignment[0],
                 text,
                 global_step,
                 path=os.path.join(args.logdir, model.name, 'A'))

        mel_hat = mels_hat[0:1].transpose(1, 2)
        mel = mels[0:1].transpose(1, 2)
        writer.add_image('eval/mel_hat', mel_hat, global_step)
        writer.add_image('eval/mel', mel, global_step)

        mag_hat = mags_hat[0:1].transpose(1, 2)
        mag = mags[0:1].transpose(1, 2)
        writer.add_image('eval/mag_hat', mag_hat, global_step)
        writer.add_image('eval/mag', mag, global_step)
    return avg_loss_mel
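Each example passes alignments through a project helper `att2img` before `writer.add_image`. Its body is not shown here; presumably it turns the `(1, Tx, Ty)` attention matrix into a single-channel image in `[0, 1]`, matching the default `dataformats='CHW'` of `add_image`. A hypothetical reconstruction under that assumption:

import numpy as np

def att2img(alignment):
    """Hypothetical sketch: min-max scale a (1, Tx, Ty) attention matrix to
    [0, 1] so writer.add_image renders it as a grayscale CHW image."""
    a = alignment.astype(np.float32)
    a = a - a.min()
    denom = a.max() if a.max() > 0 else 1.0  # avoid division by zero on empty maps
    return a / denom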
Example #4
# Assumes: import os, torch, torch.nn as nn; from tqdm import tqdm; relies on
# project globals args, DEVICE, att2img, plot_att, save_model, evaluate, load_vocab.
def train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=32,
          ckpt_dir=None,
          writer=None,
          mode='1'):
    epochs = 0
    global_step = args.global_step
    criterion = nn.L1Loss().to(DEVICE)  # reduction='mean' by default
    model_infos = [('None', 10000.)] * 5  # (ckpt_name, val_loss) pairs for the 5 best checkpoints
    GO_frames = torch.zeros([batch_size, 1, args.n_mels * args.r
                             ]).to(DEVICE)  # (N, 1, n_mels*r) all-zero GO frame
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss_mel = 0
        epoch_loss_mag = 0
        for step, (texts, mels, mags) in tqdm(enumerate(data_loader),
                                              total=len(data_loader),
                                              unit='B',
                                              ncols=70,
                                              leave=False):
            optimizer.zero_grad()

            texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), mags.to(
                DEVICE)
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            mels_hat, mags_hat, A = model(texts, prev_mels)

            loss_mel = criterion(mels_hat, mels)
            loss_mag = criterion(mags_hat, mags)
            loss = loss_mel + loss_mag
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            scheduler.step()  # step the scheduler after the optimizer (required since PyTorch 1.1)

            epoch_loss_mel += loss_mel.item()
            epoch_loss_mag += loss_mag.item()
            global_step += 1
            if global_step % args.save_term == 0:
                model.eval()
                val_loss = evaluate(model, valid_loader, criterion, writer,
                                    global_step, args.test_batch)
                model_infos = save_model(model, model_infos, optimizer,
                                         scheduler, val_loss, global_step,
                                         ckpt_dir)  # save best 5 models
                model.train()
        if args.log_mode:
            # Summary
            avg_loss_mel = epoch_loss_mel / (len(data_loader))
            avg_loss_mag = epoch_loss_mag / (len(data_loader))
            writer.add_scalar('train/loss_mel', avg_loss_mel, global_step)
            writer.add_scalar('train/loss_mag', avg_loss_mag, global_step)
            writer.add_scalar('train/lr', scheduler.get_last_lr()[0], global_step)

            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('train/alignments', att2img(alignment),
                             global_step)  # (Tx, Ty)
            text = texts[0].cpu().detach().numpy()
            text = [idx2char[ch] for ch in text]
            plot_att(alignment[0],
                     text,
                     global_step,
                     path=os.path.join(args.logdir, model.name, 'A', 'train'))

            mel_hat = mels_hat[0:1].transpose(1, 2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('train/mel_hat', mel_hat, global_step)
            writer.add_image('train/mel', mel, global_step)

            mag_hat = mags_hat[0:1].transpose(1, 2)
            mag = mags[0:1].transpose(1, 2)
            writer.add_image('train/mag_hat', mag_hat, global_step)
            writer.add_image('train/mag', mag, global_step)
            # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
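Example #4 threads `model_infos`, a list of five `(ckpt_name, val_loss)` pairs initialized to `('None', 10000.)`, through `save_model`, which the inline comment says keeps the best five models. The helper itself is not included in the snippet; one hedged sketch of that contract, where the file naming and checkpoint fields are assumptions, would be:

import os
import torch

def save_model(model, model_infos, optimizer, scheduler, val_loss, global_step, ckpt_dir):
    # Hypothetical sketch: replace the worst of the 5 tracked checkpoints
    # when the new validation loss beats it; otherwise leave the list alone.
    worst = max(range(len(model_infos)), key=lambda i: model_infos[i][1])
    if val_loss < model_infos[worst][1]:
        old_name, _ = model_infos[worst]
        if old_name != 'None':
            os.remove(os.path.join(ckpt_dir, old_name))  # evict the displaced checkpoint
        name = 'model-{}.pt'.format(global_step)
        torch.save({'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'global_step': global_step},
                   os.path.join(ckpt_dir, name))
        model_infos[worst] = (name, val_loss)
    return model_infos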