예제 #1
0
def synthesize(t2m, ssrn, data_loader, batch_size=100):
    '''
    DCTTS Architecture
    Text --> Text2Mel --> SSRN --> Wav file

    Runs Text2Mel over every batch, upsamples the predicted mels to
    magnitudes with SSRN, plots one attention image per utterance, and
    writes the vocoded wav files into args.sampledir.
    '''
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('=' * 10, ' Text2Mel ', '=' * 10)
        progress = tqdm(enumerate(data_loader),
                        total=len(data_loader),
                        ncols=70)
        for step, (texts, _, _) in progress:
            texts = texts.to(DEVICE)
            # Zero-filled buffer fed as the autoregressive decoder input.
            prev_mel_hats = torch.zeros(
                [len(texts), args.max_Ty, args.n_mels]).to(DEVICE)
            total_mel_hats, A = t2m.synthesize(texts, prev_mel_hats)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            # Mel --> Mag
            mags = ssrn(total_mel_hats).cpu().detach().numpy()  # (N, Ty, n_mags)
            for idx, mag in enumerate(mags):
                # Global sample index, used for both png and wav names.
                fname = step * batch_size + idx
                text = [idx2char[ch] for ch in visual_texts[idx]]
                utils.plot_att(alignments[idx],
                               text,
                               args.global_step,
                               path=os.path.join(args.sampledir, 'A'),
                               name='{:02d}.png'.format(fname))
                wav = utils.spectrogram2wav(mag)
                write(os.path.join(args.sampledir, '{:02d}.wav'.format(fname)),
                      args.sr, wav)

    return None
예제 #2
0
def evaluate(model, data_loader, criterion, writer, global_step, batch_size=100):
    """Validate one DCTTS sub-network and log results to Tensorboard.

    Dispatches on ``model.name``: 'Text2Mel' is run teacher-forced with the
    ground-truth mels shifted right by one zero GO frame; 'SSRN' maps mels to
    magnitude spectrograms (the targets arrive in ``extras``).  Logs the
    average loss plus sample alignment/spectrogram images taken from the
    final batch, and returns the average validation loss.
    """
    valid_loss = 0.
    A = None 
    with torch.no_grad():
        for step, (texts, mels, extras) in enumerate(data_loader):
            if model.name == 'Text2Mel':
                # Teacher forcing: prepend a zero GO frame and drop the last
                # ground-truth frame to build the decoder input sequence.
                first_frames = torch.zeros([mels.shape[0], 1, args.n_mels]).to(DEVICE) # (N, Ty/r, n_mels)
                texts, mels = texts.to(DEVICE), mels.to(DEVICE)
                prev_mels = torch.cat((first_frames, mels[:, :-1, :]), 1)
                mels_hat, A = model(texts, prev_mels)  # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r)
                loss = criterion(mels_hat, mels)
            elif model.name == 'SSRN':
                # For SSRN the third loader field carries the target mags.
                texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE)
                mags_hat = model(mels)  # Predict
                loss = criterion(mags_hat, mags)
            valid_loss += loss.item()
        avg_loss = valid_loss / (len(data_loader))
        writer.add_scalar('eval/loss', avg_loss, global_step)
        # Everything below visualizes tensors left over from the last batch.
        if model.name == 'Text2Mel':
            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('eval/alignments', att2img(alignment), global_step) # (Tx, Ty)
            text = texts[0].cpu().detach().numpy()
            text = [load_vocab()[-1][ch] for ch in text]
            plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, model.name, 'A'))
            # transpose(1, 2) puts mel channels on the image's vertical axis.
            mel_hat = mels_hat[0:1].transpose(1,2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('eval/mel_hat', mel_hat, global_step)
            writer.add_image('eval/mel', mel, global_step)
        else:
            mag_hat = mags_hat[0:1].transpose(1, 2)
            mag = mags[0:1].transpose(1, 2)
            writer.add_image('eval/mag_hat', mag_hat, global_step)
            writer.add_image('eval/mag', mag, global_step)
    return avg_loss
예제 #3
0
def synthesize(t2m, ssrn, data_loader, batch_size=100):
    '''
    DCTTS Architecture
    Text --> Text2Mel --> SSRN --> Wav file

    Decodes Text2Mel one frame at a time (timing the decode loop),
    converts the predicted mels to magnitudes with SSRN, plots
    per-utterance alignments, and writes the vocoded wavs.

    :param t2m: Text2Mel module, called per decoder step
    :param ssrn: SSRN module (mel -> mag)
    :param data_loader: loader yielding (texts, mel, _) batches
    :param batch_size: Scalar. stride used to index the output buffers
    :return: list(result_tuple) + [total Text2Mel wall time in seconds]
    '''
    text2mel_total_time = 0
    # Fix: result_tuple was unbound if the loader yielded no batches.
    result_tuple = ()

    # Text2Mel
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('='*10, ' Text2Mel ', '='*10)
        is_test = [True, False]
        total_mel_hats = torch.zeros([len(data_loader.dataset), args.max_Ty, args.n_mels]).to(DEVICE)
        mags = torch.zeros([len(data_loader.dataset), args.max_Ty*args.r, args.n_mags]).to(DEVICE)

        for step, (texts, mel, _) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            prev_mel_hats = torch.zeros([len(texts), args.max_Ty, args.n_mels]).to(DEVICE)

            text2mel_start_time = time.time()
            # Autoregressive decode: frame t is fed back to predict t+1.
            for t in tqdm(range(args.max_Ty-1), unit='B', ncols=70):
                if t == args.max_Ty - 2:
                    # Signal the final step to the model on the last frame.
                    is_test[1] = True
                mel_hats, A, result_tuple = t2m(texts, prev_mel_hats, t, is_test) # mel: (N, Ty/r, n_mels)
                prev_mel_hats[:, t+1, :] = mel_hats[:, t, :]
            # Fix: a tab-indented debug print() here was a TabError that
            # prevented the module from parsing at all; removed.

            text2mel_finish_time = time.time()
            text2mel_total_time += (text2mel_finish_time - text2mel_start_time)

            total_mel_hats[step*batch_size:(step+1)*batch_size, :, :] = prev_mel_hats

            print('='*10, ' Alignment ', '='*10)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            for idx in range(len(alignments)):
                text = [idx2char[ch] for ch in visual_texts[idx]]
                utils.plot_att(alignments[idx], text, args.global_step, path=os.path.join(args.sampledir, 'A'), name='{}.png'.format(idx))
            print('='*10, ' SSRN ', '='*10)
            # Mel --> Mag
            mags[step*batch_size:(step+1)*batch_size, :, :] = \
                ssrn(total_mel_hats[step*batch_size:(step+1)*batch_size, :, :]) # mag: (N, Ty, n_mags)
        # Fix: convert to numpy ONCE, after all batches. Converting inside
        # the loop replaced the tensor with an ndarray, so the tensor
        # slice-assignment above crashed on the second batch.
        mags = mags.cpu().detach().numpy()
        print('='*10, ' Vocoder ', '='*10)
        for idx in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[idx])
            write(os.path.join(args.sampledir, '{}.wav'.format(idx+1)), args.sr, wav)

    result = list(result_tuple)
    result.append(text2mel_total_time)

    return result
예제 #4
0
def evaluate(model,
             data_loader,
             criterion,
             writer,
             global_step,
             batch_size=100):
    """Validate the joint Tacotron model and log results to Tensorboard.

    Accumulates teacher-forced mel and mag losses over the whole loader,
    logs the two average losses, one alignment plot, and sample
    mel/mag images from the final batch.  Returns the average mel loss.
    """
    valid_loss_mel, valid_loss_mag = 0., 0.
    A = None
    with torch.no_grad():
        for step, (texts, mels, mags) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            mels = mels.to(DEVICE)
            mags = mags.to(DEVICE)
            # Zero GO frame starts the teacher-forced decoder input.
            GO_frames = torch.zeros([mels.shape[0], 1,
                                     args.n_mels * args.r]).to(DEVICE)  # (N, Ty/r, n_mels)
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            mels_hat, mags_hat, A = model(texts, prev_mels)

            valid_loss_mel += criterion(mels_hat, mels).item()
            valid_loss_mag += criterion(mags_hat, mags).item()

        n_batches = len(data_loader)
        avg_loss_mel = valid_loss_mel / n_batches
        avg_loss_mag = valid_loss_mag / n_batches
        writer.add_scalar('eval/loss_mel', avg_loss_mel, global_step)
        writer.add_scalar('eval/loss_mag', avg_loss_mag, global_step)

        # Visualize the first sample of the last batch.
        alignment = A[0:1].clone().cpu().detach().numpy()
        writer.add_image('eval/alignments', att2img(alignment),
                         global_step)  # (Tx, Ty)
        char_ids = texts[0].cpu().detach().numpy()
        text = [load_vocab()[-1][ch] for ch in char_ids]
        plot_att(alignment[0],
                 text,
                 global_step,
                 path=os.path.join(args.logdir, model.name, 'A'))

        writer.add_image('eval/mel_hat', mels_hat[0:1].transpose(1, 2),
                         global_step)
        writer.add_image('eval/mel', mels[0:1].transpose(1, 2), global_step)
        writer.add_image('eval/mag_hat', mags_hat[0:1].transpose(1, 2),
                         global_step)
        writer.add_image('eval/mag', mags[0:1].transpose(1, 2), global_step)
    return avg_loss_mel
예제 #5
0
def synthesize(model, data_loader, batch_size=100):
    '''
    Tacotron

    Runs the model in synthesis mode over the loader, plots one
    attention image per utterance, and vocodes each predicted
    magnitude spectrogram to a wav under args.sampledir.
    '''
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('*' * 15, ' Synthesize ', '*' * 15)
        n_utts = len(data_loader.dataset)
        # One pre-allocated buffer for the whole dataset's magnitudes.
        mags = torch.zeros([n_utts, args.max_Ty * args.r,
                            args.n_mags]).to(DEVICE)
        for step, (texts, _, _) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            # All-zero GO frame seeds the decoder.
            GO_frames = torch.zeros([texts.shape[0], 1,
                                     args.n_mels * args.r]).to(DEVICE)
            _, mags_hat, A = model(texts, GO_frames, synth=True)

            print('=' * 10, ' Alignment ', '=' * 10)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            for idx, alignment in enumerate(alignments):
                text = [idx2char[ch] for ch in visual_texts[idx]]
                utils.plot_att(alignment,
                               text,
                               args.global_step,
                               path=os.path.join(args.sampledir, 'A'),
                               name='{}.png'.format(idx + step * batch_size))
            lo = step * batch_size
            mags[lo:lo + batch_size, :, :] = mags_hat  # mag: (N, Ty, n_mags)
        print('=' * 10, ' Vocoder ', '=' * 10)
        mags = mags.cpu().detach().numpy()
        for idx in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[idx])
            write(os.path.join(args.sampledir, '{}.wav'.format(idx + 1)),
                  args.sr, wav)
    return None
예제 #6
0
def train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=32,
          ckpt_dir=None,
          writer=None,
          DEVICE=None):
    """
    train function

    Trains a Tacotron-style model (optionally TPGST) until
    args.max_step, periodically logging, evaluating and checkpointing.

    :param model: nn module object
    :param data_loader: data loader for training set
    :param valid_loader: data loader for validation set
    :param optimizer: optimizer
    :param scheduler: for scheduling learning rate
    :param batch_size: Scalar
    :param ckpt_dir: String. checkpoint directory
    :param writer: Tensorboard writer
    :param DEVICE: 'cpu' or 'gpu'

    """
    epochs = 0
    global_step = args.global_step
    criterion = nn.L1Loss()  # default average
    bce_loss = nn.BCELoss()
    # Fix: val_loss must exist before the first checkpoint. Previously,
    # if args.save_term fired before args.eval_term ever did, save_model
    # raised NameError on an unbound val_loss.
    val_loss = float('inf')

    GO_frames = torch.zeros([batch_size, 1, args.n_mels * args.r
                             ]).to(DEVICE)  # (N, Ty/r, n_mels)
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss_mel, epoch_loss_fmel, epoch_loss_ff = 0., 0., 0.
        for step, (texts, mels, ff) in tqdm(enumerate(data_loader),
                                            total=len(data_loader),
                                            unit='B',
                                            ncols=70,
                                            leave=False):
            optimizer.zero_grad()
            texts, mels, ff = texts.to(DEVICE), mels.to(DEVICE), ff.to(DEVICE)
            # Teacher forcing: zero GO frame + ground truth shifted right.
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            refs = mels.view(mels.size(0), -1,
                             args.n_mels).unsqueeze(1)  # (N, 1, Ty, n_mels)
            if type(model).__name__ == 'TPGST':
                mels_hat, fmels_hat, A, style_attentions, ff_hat, se, tpse = model(
                    texts, prev_mels, refs)
                # detach(): the style embedding is a target here, not trained.
                loss_se = criterion(tpse, se.detach())
            else:
                mels_hat, fmels_hat, A, ff_hat = model(texts, prev_mels)

            loss_mel = criterion(mels_hat, mels)
            fmels = mels.view(mels.size(0), -1, args.n_mels)
            # NOTE: loss_fmel is logged below but deliberately not part of
            # the optimized loss.
            loss_fmel = criterion(fmels_hat, fmels)
            loss_ff = bce_loss(ff_hat, ff)

            # The TP-style loss only joins after args.tp_start steps.
            if global_step > args.tp_start and type(model).__name__ == 'TPGST':
                loss = loss_mel + 0.01 * loss_ff + 0.01 * loss_se
            else:
                loss = loss_mel + 0.01 * loss_ff

            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            scheduler.step()

            epoch_loss_mel += loss_mel.item()
            epoch_loss_fmel += loss_fmel.item()
            epoch_loss_ff += loss_ff.item()

            if global_step % args.log_term == 0:
                writer.add_scalar('batch/loss_mel', loss_mel.item(),
                                  global_step)
                if type(model).__name__ == 'TPGST':
                    writer.add_scalar('batch/loss_se', loss_se.item(),
                                      global_step)
                writer.add_scalar('batch/loss_ff', loss_ff.item(), global_step)
                writer.add_scalar('train/lr',
                                  scheduler.get_lr()[0], global_step)

            if global_step % args.eval_term == 0:
                model.eval()  #
                val_loss = evaluate(model,
                                    valid_loader,
                                    criterion,
                                    writer,
                                    global_step,
                                    DEVICE=DEVICE)
                model.train()

            if global_step % args.save_term == 0:
                save_model(model, optimizer, scheduler, val_loss, global_step,
                           ckpt_dir)  # save best 5 models
            global_step += 1

        if args.log_mode:
            # Summary
            avg_loss_mel = epoch_loss_mel / (len(data_loader))
            avg_loss_fmel = epoch_loss_fmel / (len(data_loader))
            avg_loss_ff = epoch_loss_ff / (len(data_loader))

            writer.add_scalar('train/loss_mel', avg_loss_mel, global_step)
            writer.add_scalar('train/loss_fmel', avg_loss_fmel, global_step)
            writer.add_scalar('train/loss_ff', avg_loss_ff, global_step)
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)

            # Visualize the first sample of the epoch's final batch.
            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('train/alignments', att2img(alignment),
                             global_step)  # (Tx, Ty)
            text = texts[0].cpu().detach().numpy()
            text = [idx2char[ch] for ch in text]
            plot_att(alignment[0],
                     text,
                     global_step,
                     path=os.path.join(args.logdir,
                                       type(model).__name__, 'A', 'train'))

            mel_hat = mels_hat[0:1].transpose(1, 2)
            fmel_hat = fmels_hat[0:1].transpose(1, 2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('train/mel_hat', mel_hat, global_step)
            writer.add_image('train/fmel_hat', fmel_hat, global_step)
            writer.add_image('train/mel', mel, global_step)

            if type(model).__name__ == 'TPGST':
                styleA = style_attentions.unsqueeze(0) * 255.
                writer.add_image('train/styleA', styleA, global_step)
            # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
예제 #7
0
def evaluate(model, data_loader, criterion, writer, global_step, DEVICE=None):
    """
    To evaluate with validation set

    Runs the model teacher-forced over the validation loader, accumulating
    mel / fine-mel / frame-flag (and, for TPGST, style-embedding) losses,
    then logs average losses plus sample alignment and spectrogram images
    from the final batch.  Returns the average mel loss.

    :param model: nn module object
    :param data_loader: data loader
    :param criterion: criterion for spectorgrams
    :param writer: Tensorboard writer
    :param global_step: Scalar. global step
    :param DEVICE: 'cpu' or 'gpu'

    """
    bce_loss = nn.BCELoss()
    xe_loss = nn.CrossEntropyLoss()
    valid_loss_mel, valid_loss_fmel, valid_loss_ff, valid_loss_se = 0., 0., 0., 0.
    A = None
    with torch.no_grad():
        for step, (texts, mels, ff) in enumerate(data_loader):
            texts, mels, ff = texts.to(DEVICE), mels.to(DEVICE), ff.to(DEVICE)
            # Teacher forcing: zero GO frame + ground truth shifted right.
            GO_frames = torch.zeros([mels.shape[0], 1, args.n_mels * args.r
                                     ]).to(DEVICE)  # (N, Ty/r, n_mels)
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            refs = mels.view(mels.size(0), -1,
                             args.n_mels).unsqueeze(1)  # (N, 1, Ty, n_mels)
            if type(model).__name__ == 'TPGST':
                mels_hat, fmels_hat, A, style_attentions, ff_hat, se, tpse = model(
                    texts, prev_mels, refs)
                # No detach() needed under no_grad.
                loss_se = criterion(tpse, se)
                valid_loss_se += loss_se.item()
            else:
                mels_hat, fmels_hat, A, ff_hat = model(texts, prev_mels)

            loss_mel = criterion(mels_hat, mels)
            fmels = mels.view(mels.size(0), -1, args.n_mels)
            loss_fmel = criterion(fmels_hat, fmels)
            loss_ff = bce_loss(ff_hat, ff)

            valid_loss_mel += loss_mel.item()
            valid_loss_fmel += loss_fmel.item()
            valid_loss_ff += loss_ff.item()
        avg_loss_mel = valid_loss_mel / (len(data_loader))
        avg_loss_fmel = valid_loss_fmel / (len(data_loader))
        avg_loss_ff = valid_loss_ff / (len(data_loader))

        writer.add_scalar('eval/loss_mel', avg_loss_mel, global_step)
        writer.add_scalar('eval/loss_fmel', avg_loss_fmel, global_step)
        writer.add_scalar('eval/loss_ff', avg_loss_ff, global_step)

        # Visualization below uses tensors left over from the last batch.
        alignment = A[0:1].clone().cpu().detach().numpy()
        writer.add_image('eval/alignments', att2img(alignment),
                         global_step)  # (Tx, Ty)
        text = texts[0].cpu().detach().numpy()
        text = [load_vocab()[-1][ch] for ch in text]
        plot_att(alignment[0],
                 text,
                 global_step,
                 path=os.path.join(args.logdir,
                                   type(model).__name__, 'A'))

        # transpose(1, 2) puts spectrogram channels on the vertical axis.
        mel_hat = mels_hat[0:1].transpose(1, 2)
        fmel_hat = fmels_hat[0:1].transpose(1, 2)
        mel = mels[0:1].transpose(1, 2)

        writer.add_image('eval/mel_hat', mel_hat, global_step)
        writer.add_image('eval/fmel_hat', fmel_hat, global_step)
        writer.add_image('eval/mel', mel, global_step)
        if type(model).__name__ == 'TPGST':
            avg_loss_se = valid_loss_se / (len(data_loader))
            writer.add_scalar('eval/loss_se', avg_loss_se, global_step)
            # Scale attention weights to image range for Tensorboard.
            styleA = style_attentions.view(1, mels.size(0),
                                           args.n_tokens) * 255.
            writer.add_image('eval/styleA', styleA, global_step)

    return avg_loss_mel
예제 #8
0
def train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=32,
          ckpt_dir=None,
          writer=None,
          mode='1'):
    """Train one DCTTS sub-network (Text2Mel or SSRN) until args.max_step.

    Text2Mel is teacher-forced (optionally with guided-attention loss when
    args.ga_mode); SSRN learns the mel-to-mag mapping.  Periodically
    evaluates, checkpoints the best five models, and logs summaries.

    :param model: nn module with a ``.name`` of 'Text2Mel' or 'SSRN'
    :param data_loader: training-set loader yielding (texts, mels, extras)
    :param valid_loader: validation-set loader
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param batch_size: Scalar
    :param ckpt_dir: String. checkpoint directory
    :param writer: Tensorboard writer
    :param mode: unused here; kept for interface compatibility
    """
    epochs = 0
    global_step = args.global_step
    l1_criterion = nn.L1Loss().to(DEVICE)  # default average
    bd_criterion = nn.BCELoss().to(DEVICE)
    model_infos = [('None', 10000.)] * 5
    first_frames = torch.zeros([batch_size, 1,
                                args.n_mels]).to(DEVICE)  # (N, Ty/r, n_mels)
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss = 0
        # Fix: removed `iter(data_loader).next()` here — `.next()` is the
        # Python-2-only iterator API (AttributeError on Python 3), and the
        # fetched batch was never used anyway.

        for step, (texts, mels, extras) in tqdm(enumerate(data_loader),
                                                total=len(data_loader),
                                                unit='B',
                                                ncols=70,
                                                leave=False):
            optimizer.zero_grad()
            if model.name == 'Text2Mel':
                if args.ga_mode:
                    # `extras` carries guided-attention weight matrices.
                    texts, mels, gas = texts.to(DEVICE), mels.to(
                        DEVICE), extras.to(DEVICE)
                else:
                    texts, mels = texts.to(DEVICE), mels.to(DEVICE)
                # Teacher forcing: zero GO frame + ground truth shifted right.
                prev_mels = torch.cat((first_frames, mels[:, :-1, :]), 1)
                mels_hat, A, _ = model(
                    texts, prev_mels,
                    0)  # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r)
                l1_loss = l1_criterion(mels_hat, mels)
                bd_loss = bd_criterion(mels_hat, mels)
                if args.ga_mode:
                    # Guided attention: penalize off-diagonal attention mass.
                    att_loss = torch.mean(A * gas)
                    loss = l1_loss + bd_loss + att_loss
                else:
                    loss = l1_loss + bd_loss
            elif model.name == 'SSRN':
                texts, mels, mags = texts.to(DEVICE), mels.to(
                    DEVICE), extras.to(DEVICE)
                mags_hat = model(mels)  # mags_hat: (N, Ty, n_mags)
                l1_loss = l1_criterion(mags_hat, mags)
                bd_loss = bd_criterion(mags_hat, mags)
                loss = l1_loss + bd_loss
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            # Fix: optimizer.step() must precede scheduler.step()
            # (PyTorch >= 1.1); the old order skipped the first LR value.
            optimizer.step()
            scheduler.step()
            epoch_loss += l1_loss.item()
            global_step += 1
            if global_step % args.save_term == 0:
                model.eval()
                val_loss = evaluate(model, valid_loader, l1_criterion, writer,
                                    global_step, args.test_batch)
                model_infos = save_model(model, model_infos, optimizer,
                                         scheduler, val_loss, global_step,
                                         ckpt_dir)  # save best 5 models
                model.train()
        if args.log_mode:
            # Summary
            avg_loss = epoch_loss / (len(data_loader))
            writer.add_scalar('train/loss', avg_loss, global_step)
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            if model.name == 'Text2Mel':
                alignment = A[0:1].clone().cpu().detach().numpy()
                writer.add_image('train/alignments', att2img(alignment),
                                 global_step)  # (Tx, Ty)
                if args.ga_mode:
                    writer.add_scalar('train/loss_att', att_loss, global_step)
                text = texts[0].cpu().detach().numpy()
                text = [idx2char[ch] for ch in text]
                plot_att(alignment[0],
                         text,
                         global_step,
                         path=os.path.join(args.logdir, model.name, 'A',
                                           'train'))
                mel_hat = mels_hat[0:1].transpose(1, 2)
                mel = mels[0:1].transpose(1, 2)
                writer.add_image('train/mel_hat', mel_hat, global_step)
                writer.add_image('train/mel', mel, global_step)
            else:
                mag_hat = mags_hat[0:1].transpose(1, 2)
                mag = mags[0:1].transpose(1, 2)
                writer.add_image('train/mag_hat', mag_hat, global_step)
                writer.add_image('train/mag', mag, global_step)
            # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
예제 #9
0
def train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=32,
          ckpt_dir=None,
          writer=None,
          mode='1'):
    """Train the joint Tacotron model (mel + mag) until args.max_step.

    Teacher-forces the decoder with a zero GO frame plus the ground-truth
    mels shifted right by one frame, optimizes the summed L1 mel and mag
    losses, periodically evaluates and checkpoints the best five models,
    and logs per-epoch summaries to Tensorboard.

    :param model: nn module returning (mels_hat, mags_hat, A)
    :param data_loader: training-set loader yielding (texts, mels, mags)
    :param valid_loader: validation-set loader
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param batch_size: Scalar
    :param ckpt_dir: String. checkpoint directory
    :param writer: Tensorboard writer
    :param mode: unused here; kept for interface compatibility
    """
    epochs = 0
    global_step = args.global_step
    criterion = nn.L1Loss().to(DEVICE)  # default average
    model_infos = [('None', 10000.)] * 5
    GO_frames = torch.zeros([batch_size, 1, args.n_mels * args.r
                             ]).to(DEVICE)  # (N, Ty/r, n_mels)
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss_mel = 0
        epoch_loss_mag = 0
        for step, (texts, mels, mags) in tqdm(enumerate(data_loader),
                                              total=len(data_loader),
                                              unit='B',
                                              ncols=70,
                                              leave=False):
            optimizer.zero_grad()

            texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), mags.to(
                DEVICE)
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            mels_hat, mags_hat, A = model(texts, prev_mels)

            loss_mel = criterion(mels_hat, mels)
            loss_mag = criterion(mags_hat, mags)
            loss = loss_mel + loss_mag
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            # Fix: optimizer.step() must precede scheduler.step()
            # (PyTorch >= 1.1); the old order skipped the first LR value.
            optimizer.step()
            scheduler.step()

            epoch_loss_mel += loss_mel.item()
            epoch_loss_mag += loss_mag.item()
            global_step += 1
            if global_step % args.save_term == 0:
                model.eval()  #
                val_loss = evaluate(model, valid_loader, criterion, writer,
                                    global_step, args.test_batch)
                model_infos = save_model(model, model_infos, optimizer,
                                         scheduler, val_loss, global_step,
                                         ckpt_dir)  # save best 5 models
                model.train()
        if args.log_mode:
            # Summary
            avg_loss_mel = epoch_loss_mel / (len(data_loader))
            avg_loss_mag = epoch_loss_mag / (len(data_loader))
            writer.add_scalar('train/loss_mel', avg_loss_mel, global_step)
            writer.add_scalar('train/loss_mag', avg_loss_mag, global_step)
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)

            # Visualize the first sample of the epoch's final batch.
            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('train/alignments', att2img(alignment),
                             global_step)  # (Tx, Ty)
            text = texts[0].cpu().detach().numpy()
            text = [idx2char[ch] for ch in text]
            plot_att(alignment[0],
                     text,
                     global_step,
                     path=os.path.join(args.logdir, model.name, 'A', 'train'))

            mel_hat = mels_hat[0:1].transpose(1, 2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('train/mel_hat', mel_hat, global_step)
            writer.add_image('train/mel', mel, global_step)

            mag_hat = mags_hat[0:1].transpose(1, 2)
            mag = mags[0:1].transpose(1, 2)
            writer.add_image('train/mag_hat', mag_hat, global_step)
            writer.add_image('train/mag', mag, global_step)
            # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
예제 #10
0
def visualize_loop(args, val_loader):
    """Restore a trained VQA model and plot its attention per sample.

    Builds the architecture selected by ``args.model_type`` (SAN / MCB /
    MFB / MLB / MUTAN / DAN), loads its checkpoint, then iterates the
    validation loader and calls ``plot_att`` for every sample with the
    model-specific slice of the attention weights.  Plots are written
    under ``args.save_dir``.

    NOTE(review): if ``args.model_type`` matches none of the branches
    below, ``model`` is never bound and ``load_weights`` raises NameError.
    """

    image_feature_size = 512
    lidar_feature_size = 1024

    # Each branch instantiates the selected architecture; MLB and MUTAN
    # override the question/image feature sizes set above.
    if args.model_type == 'SAN':
        question_feat_size = 512
        model = SAN(args,
                    question_feat_size,
                    image_feature_size,
                    lidar_feature_size,
                    num_classes=34,
                    qa=None,
                    encoder=args.encoder_type,
                    method='hierarchical')
    if args.model_type == 'MCB':
        question_feat_size = 512
        model = MCB(args,
                    question_feat_size,
                    image_feature_size,
                    lidar_feature_size,
                    num_classes=34,
                    qa=None,
                    encoder=args.encoder_type,
                    method='hierarchical')
    if args.model_type == 'MFB':
        question_feat_size = 512
        # image_feature_size=512
        model = MFB(args,
                    question_feat_size,
                    image_feature_size,
                    lidar_feature_size,
                    num_classes=34,
                    qa=None,
                    encoder=args.encoder_type,
                    method='hierarchical')
    if args.model_type == 'MLB':
        question_feat_size = 1024
        image_feature_size = 512
        model = MLB(args,
                    question_feat_size,
                    image_feature_size,
                    lidar_feature_size,
                    num_classes=34,
                    qa=None,
                    encoder=args.encoder_type,
                    method='hierarchical')
    if args.model_type == 'MUTAN':
        question_feat_size = 1024
        image_feature_size = 512
        model = MUTAN(args,
                      question_feat_size,
                      image_feature_size,
                      lidar_feature_size,
                      num_classes=34,
                      qa=None,
                      encoder=args.encoder_type,
                      method='hierarchical')
    if args.model_type == 'DAN':
        question_feat_size = 512
        model = DAN(args,
                    question_feat_size,
                    image_feature_size,
                    lidar_feature_size,
                    num_classes=34,
                    qa=None,
                    encoder=args.encoder_type,
                    method='hierarchical')

    # load_weights returns a list on success, or just the model on failure.
    data = load_weights(args, model, optimizer=None)
    if type(data) == list:
        model, optimizer, start_epoch, loss, accuracy = data
        print("Loaded  weights")
        print("Epoch: %d, loss: %.3f, Accuracy: %.4f " %
              (start_epoch, loss, accuracy),
              flush=True)
    else:
        # Visualization without trained weights is pointless; bail out.
        print(" error occured while loading model training freshly")
        model = data
        return

    ###########################################################################multiple GPU use#
    # if torch.cuda.device_count() > 1:
    #     print("Using ", torch.cuda.device_count(), "GPUs!")
    #     model = nn.DataParallel(model)

    model.to(device=args.device)
    model.eval()

    # Deferred imports: argoverse is only needed for visualization.
    import argoverse
    from argoverse.data_loading.argoverse_tracking_loader import ArgoverseTrackingLoader
    from argoverse.utils.json_utils import read_json_file
    from argoverse.map_representation.map_api import ArgoverseMap

    vocab = load_vocab(os.path.join(args.input_base, args.vocab))
    argoverse_loader = ArgoverseTrackingLoader(
        '../../../Data/train/argoverse-tracking')

    # k is a running output index across all batches (1-based).
    k = 1
    with torch.no_grad():
        for data in tqdm(val_loader):
            question, image_feature, ques_lengths, point_set, answer, image_name = data
            question = question.to(device=args.device)
            ques_lengths = ques_lengths.to(device=args.device)
            image_feature = image_feature.to(device=args.device)
            point_set = point_set.to(device=args.device)

            pred, wgt, energies = model(question, image_feature, ques_lengths,
                                        point_set)

            question = question.cpu().data.numpy()
            answer = answer.cpu().data.numpy()
            pred = F.softmax(pred, dim=1)
            pred = torch.argmax(pred, dim=1)
            pred = np.asarray(pred.cpu().data)
            wgt = wgt.cpu().data.numpy()
            energies = energies.squeeze(1).cpu().data.numpy()
            ques_lengths = ques_lengths.cpu().data.numpy()
            # Image names look like '<log_id>@<frame_idx>' — TODO confirm.
            pat = re.compile(r'(.*)@(.*)')
            # `keep` (indices of correct predictions) is computed but unused.
            _, keep = np.where([answer == pred])
            temp_batch_size = question.shape[0]
            for b in range(temp_batch_size):
                q = get_ques(question[b], ques_lengths[b], vocab)
                ans = get_ans(answer[b])
                pred_ans = get_ans(pred[b])
                # print(q,ans)
                c = list(re.findall(pat, image_name[b]))[0]
                log_id = c[0]
                idx = int(c[1])
                print(k)
                argoverse_data = argoverse_loader.get(log_id)
                # Each architecture stores attention with a different layout,
                # so a different slice of `wgt` is plotted per model type.
                if args.model_type == 'SAN':
                    plot_att(argoverse_data, idx, wgt[b, :, 1, :], energies[b],
                             q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MCB':
                    plot_att(argoverse_data, idx, wgt[b], energies[b], q, ans,
                             args.save_dir, k, pred_ans)
                if args.model_type == 'MFB':
                    plot_att(argoverse_data, idx, wgt[b, :, :, 1], energies[b],
                             q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MLB':
                    plot_att(argoverse_data, idx, wgt[b, :, 3, :], energies[b],
                             q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MUTAN':  #only two glimpses
                    plot_att(argoverse_data, idx, wgt[b, :, 1, :], energies[b],
                             q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'DAN':  #only two memory
                    plot_att(argoverse_data, idx, wgt[b, :, 1, :], energies[b],
                             q, ans, args.save_dir, k, pred_ans)

                k = k + 1