def synthesize(t2m, ssrn, data_loader, batch_size=100):
    """DCTTS Architecture: Text --> Text2Mel --> SSRN --> Wav file.

    Runs batched Text2Mel synthesis, upsamples the predicted mels to
    magnitude spectrograms with SSRN, then vocodes every sample and writes
    one wav (plus an attention plot) per utterance into args.sampledir.
    """
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('=' * 10, ' Text2Mel ', '=' * 10)
        for step, (texts, _, _) in tqdm(enumerate(data_loader),
                                        total=len(data_loader), ncols=70):
            texts = texts.to(DEVICE)
            # All-zero frames serve as the initial (GO) mel input.
            seed_mels = torch.zeros([len(texts), args.max_Ty, args.n_mels]).to(DEVICE)
            mel_preds, attn = t2m.synthesize(texts, seed_mels)
            attn_np = attn.cpu().detach().numpy()
            text_np = texts.cpu().detach().numpy()
            # Mel --> Mag
            mag_preds = ssrn(mel_preds).cpu().detach().numpy()  # (N, Ty, n_mags)
            for k in range(len(mag_preds)):
                fname = step * batch_size + k  # globally unique sample index
                chars = [idx2char[ch] for ch in text_np[k]]
                utils.plot_att(attn_np[k], chars, args.global_step,
                               path=os.path.join(args.sampledir, 'A'),
                               name='{:02d}.png'.format(fname))
                wav = utils.spectrogram2wav(mag_preds[k])
                write(os.path.join(args.sampledir, '{:02d}.wav'.format(fname)),
                      args.sr, wav)
    return None
def evaluate(model, data_loader, criterion, writer, global_step, batch_size=100):
    """Run one validation pass for a DCTTS sub-network and log to TensorBoard.

    Supports both 'Text2Mel' (teacher-forced mel prediction; `extras` unused)
    and 'SSRN' (mel -> mag upsampling; `extras` holds the target mags).
    Returns the average validation loss over the loader.
    """
    running = 0.
    A = None
    with torch.no_grad():
        for step, (texts, mels, extras) in enumerate(data_loader):
            if model.name == 'Text2Mel':
                texts, mels = texts.to(DEVICE), mels.to(DEVICE)
                # An all-zero GO frame shifts the target mels right by one step.
                go = torch.zeros([mels.shape[0], 1, args.n_mels]).to(DEVICE)
                shifted = torch.cat((go, mels[:, :-1, :]), 1)
                mels_hat, A = model(texts, shifted)  # A: (N, Tx, Ty/r)
                batch_loss = criterion(mels_hat, mels)
            elif model.name == 'SSRN':
                texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE)
                mags_hat = model(mels)
                batch_loss = criterion(mags_hat, mags)
            running += batch_loss.item()
        avg_loss = running / len(data_loader)
        writer.add_scalar('eval/loss', avg_loss, global_step)
        # Visualize the first sample of the last batch.
        if model.name == 'Text2Mel':
            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('eval/alignments', att2img(alignment), global_step)  # (Tx, Ty)
            chars = [load_vocab()[-1][ch] for ch in texts[0].cpu().detach().numpy()]
            plot_att(alignment[0], chars, global_step,
                     path=os.path.join(args.logdir, model.name, 'A'))
            writer.add_image('eval/mel_hat', mels_hat[0:1].transpose(1, 2), global_step)
            writer.add_image('eval/mel', mels[0:1].transpose(1, 2), global_step)
        else:
            writer.add_image('eval/mag_hat', mags_hat[0:1].transpose(1, 2), global_step)
            writer.add_image('eval/mag', mags[0:1].transpose(1, 2), global_step)
    return avg_loss
def synthesize(t2m, ssrn, data_loader, batch_size=100):
    """DCTTS Architecture: Text --> Text2Mel --> SSRN --> Wav file.

    Autoregressively decodes mel frames one step at a time with Text2Mel,
    accumulates predictions for the whole dataset, converts them to magnitude
    spectrograms with SSRN, then vocodes every sample to a wav file.

    Returns ``list(result_tuple)`` from the last Text2Mel call with the total
    Text2Mel wall-clock time appended.
    """
    text2mel_total_time = 0
    # Text2Mel
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('='*10, ' Text2Mel ', '='*10)
        is_test = [True, False]
        n_samples = len(data_loader.dataset)
        total_mel_hats = torch.zeros([n_samples, args.max_Ty, args.n_mels]).to(DEVICE)
        mags = torch.zeros([n_samples, args.max_Ty*args.r, args.n_mags]).to(DEVICE)
        for step, (texts, mel, _) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            prev_mel_hats = torch.zeros([len(texts), args.max_Ty, args.n_mels]).to(DEVICE)
            # Reset per batch so only the final decoding step of *this* batch is
            # flagged; the original set it once and left it True for all later
            # batches.
            is_test[1] = False
            text2mel_start_time = time.time()
            for t in tqdm(range(args.max_Ty-1), unit='B', ncols=70):
                if t == args.max_Ty - 2:
                    is_test[1] = True
                mel_hats, A, result_tuple = t2m(texts, prev_mel_hats, t, is_test)  # mel: (N, Ty/r, n_mels)
                # Feed this step's prediction back in as the next step's input.
                prev_mel_hats[:, t+1, :] = mel_hats[:, t, :]
            print(mel_hats.sum(), mel.sum())
            text2mel_total_time += time.time() - text2mel_start_time
            total_mel_hats[step*batch_size:(step+1)*batch_size, :, :] = prev_mel_hats
            print('='*10, ' Alignment ', '='*10)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            for idx in range(len(alignments)):
                text = [idx2char[ch] for ch in visual_texts[idx]]
                # Bug fix: offset by the batch index so plots from later batches
                # do not overwrite those from earlier ones (matches the other
                # synthesize variants in this project).
                utils.plot_att(alignments[idx], text, args.global_step,
                               path=os.path.join(args.sampledir, 'A'),
                               name='{}.png'.format(idx + step*batch_size))
            print('='*10, ' SSRN ', '='*10)
            # Mel --> Mag
            mags[step*batch_size:(step+1)*batch_size, :, :] = \
                ssrn(total_mel_hats[step*batch_size:(step+1)*batch_size, :, :])  # (N, Ty, n_mags)
        # Bug fix: the original converted `mags` to NumPy *inside* the batch
        # loop, so from the second batch on the tensor slice-assignment above
        # failed and the vocoder re-rendered every wav per batch. Convert once,
        # after all batches are accumulated.
        mags = mags.cpu().detach().numpy()
        print('='*10, ' Vocoder ', '='*10)
        for idx in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[idx])
            write(os.path.join(args.sampledir, '{}.wav'.format(idx+1)), args.sr, wav)
    result = list(result_tuple)
    result.append(text2mel_total_time)
    return result
def evaluate(model, data_loader, criterion, writer, global_step, batch_size=100):
    """Validate a Tacotron-style model that jointly predicts mels and mags.

    Logs average mel/mag losses plus alignment and spectrogram images for the
    first sample of the last batch, and returns the average mel loss.
    """
    totals = {'mel': 0., 'mag': 0.}
    A = None
    with torch.no_grad():
        for step, (texts, mels, mags) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            mels = mels.to(DEVICE)
            mags = mags.to(DEVICE)
            # Shift targets right by one all-zero GO frame for teacher forcing.
            go = torch.zeros([mels.shape[0], 1, args.n_mels * args.r]).to(DEVICE)
            shifted = torch.cat((go, mels[:, :-1, :]), 1)
            mels_hat, mags_hat, A = model(texts, shifted)
            totals['mel'] += criterion(mels_hat, mels).item()
            totals['mag'] += criterion(mags_hat, mags).item()
        n_batches = len(data_loader)
        avg_loss_mel = totals['mel'] / n_batches
        avg_loss_mag = totals['mag'] / n_batches
        writer.add_scalar('eval/loss_mel', avg_loss_mel, global_step)
        writer.add_scalar('eval/loss_mag', avg_loss_mag, global_step)
        alignment = A[0:1].clone().cpu().detach().numpy()
        writer.add_image('eval/alignments', att2img(alignment), global_step)  # (Tx, Ty)
        chars = [load_vocab()[-1][ch] for ch in texts[0].cpu().detach().numpy()]
        plot_att(alignment[0], chars, global_step,
                 path=os.path.join(args.logdir, model.name, 'A'))
        writer.add_image('eval/mel_hat', mels_hat[0:1].transpose(1, 2), global_step)
        writer.add_image('eval/mel', mels[0:1].transpose(1, 2), global_step)
        writer.add_image('eval/mag_hat', mags_hat[0:1].transpose(1, 2), global_step)
        writer.add_image('eval/mag', mags[0:1].transpose(1, 2), global_step)
    return avg_loss_mel
def synthesize(model, data_loader, batch_size=100):
    """Tacotron inference: text -> (mel, mag, alignment) -> wav files.

    Synthesizes every utterance in the loader, saves an attention plot per
    sample, then vocodes all magnitude spectrograms into args.sampledir.
    """
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('*' * 15, ' Synthesize ', '*' * 15)
        n_samples = len(data_loader.dataset)
        mags = torch.zeros([n_samples, args.max_Ty * args.r, args.n_mags]).to(DEVICE)
        for step, (texts, _, _) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            # Decoding starts from a single all-zero GO frame per utterance.
            go = torch.zeros([texts.shape[0], 1, args.n_mels * args.r]).to(DEVICE)
            _, mags_hat, A = model(texts, go, synth=True)
            print('=' * 10, ' Alignment ', '=' * 10)
            attn = A.cpu().detach().numpy()
            seqs = texts.cpu().detach().numpy()
            for j in range(len(attn)):
                chars = [idx2char[ch] for ch in seqs[j]]
                utils.plot_att(attn[j], chars, args.global_step,
                               path=os.path.join(args.sampledir, 'A'),
                               name='{}.png'.format(j + step * batch_size))
            mags[step * batch_size:(step + 1) * batch_size, :, :] = mags_hat  # (N, Ty, n_mags)
        print('=' * 10, ' Vocoder ', '=' * 10)
        mags = mags.cpu().detach().numpy()
        for j in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[j])
            write(os.path.join(args.sampledir, '{}.wav'.format(j + 1)), args.sr, wav)
    return None
def train(model, data_loader, valid_loader, optimizer, scheduler, batch_size=32, ckpt_dir=None, writer=None, DEVICE=None):
    """ train function
    :param model: nn module object (TPGST or a plain Tacotron-style model)
    :param data_loader: data loader for training set
    :param valid_loader: data loader for validation set
    :param optimizer: optimizer
    :param scheduler: for scheduling learning rate
    :param batch_size: Scalar
    :param ckpt_dir: String. checkpoint directory
    :param writer: Tensorboard writer
    :param DEVICE: 'cpu' or 'gpu'
    """
    epochs = 0
    global_step = args.global_step
    criterion = nn.L1Loss()  # default average
    bce_loss = nn.BCELoss()
    xe_loss = nn.CrossEntropyLoss()
    # Defensive init: lets the first save_model call succeed even if a save
    # step happens to come before the first evaluation step.
    val_loss = float('inf')
    GO_frames = torch.zeros([batch_size, 1, args.n_mels * args.r]).to(DEVICE)  # (N, 1, n_mels*r) all-zero GO frame
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss_mel, epoch_loss_fmel, epoch_loss_ff = 0., 0., 0.
        for step, (texts, mels, ff) in tqdm(enumerate(data_loader), total=len(data_loader), unit='B', ncols=70, leave=False):
            optimizer.zero_grad()
            texts, mels, ff = texts.to(DEVICE), mels.to(DEVICE), ff.to(DEVICE)
            # Teacher forcing: shift targets right by one GO frame.
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            refs = mels.view(mels.size(0), -1, args.n_mels).unsqueeze(1)  # (N, 1, Ty, n_mels) reference for style encoder
            if type(model).__name__ == 'TPGST':
                mels_hat, fmels_hat, A, style_attentions, ff_hat, se, tpse = model(texts, prev_mels, refs)
                # se is detached so the text-predicted style embedding chases
                # the reference embedding without back-propagating into it.
                loss_se = criterion(tpse, se.detach())
            else:
                mels_hat, fmels_hat, A, ff_hat = model(texts, prev_mels)
            loss_mel = criterion(mels_hat, mels)
            fmels = mels.view(mels.size(0), -1, args.n_mels)
            loss_fmel = criterion(fmels_hat, fmels)  # tracked for logging only; not in the objective
            loss_ff = bce_loss(ff_hat, ff)
            # The TPSE loss joins the objective only after args.tp_start steps.
            if global_step > args.tp_start and type(model).__name__ == 'TPGST':
                loss = loss_mel + 0.01 * loss_ff + 0.01 * loss_se
            else:
                loss = loss_mel + 0.01 * loss_ff
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            epoch_loss_mel += loss_mel.item()
            epoch_loss_fmel += loss_fmel.item()
            epoch_loss_ff += loss_ff.item()
            if global_step % args.log_term == 0:
                writer.add_scalar('batch/loss_mel', loss_mel.item(), global_step)
                if type(model).__name__ == 'TPGST':
                    writer.add_scalar('batch/loss_se', loss_se.item(), global_step)
                writer.add_scalar('batch/loss_ff', loss_ff.item(), global_step)
                writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            if global_step % args.eval_term == 0:
                model.eval()
                # Bug fix: this call was commented out, leaving val_loss
                # undefined and crashing save_model below with a NameError.
                val_loss = evaluate(model, valid_loader, criterion, writer, global_step, DEVICE=DEVICE)
                model.train()
            if global_step % args.save_term == 0:
                save_model(model, optimizer, scheduler, val_loss, global_step, ckpt_dir)  # save best 5 models
            global_step += 1
        if args.log_mode:
            # Summary: epoch-level averages and images from the last batch.
            avg_loss_mel = epoch_loss_mel / (len(data_loader))
            avg_loss_fmel = epoch_loss_fmel / (len(data_loader))
            avg_loss_ff = epoch_loss_ff / (len(data_loader))
            writer.add_scalar('train/loss_mel', avg_loss_mel, global_step)
            writer.add_scalar('train/loss_fmel', avg_loss_fmel, global_step)
            writer.add_scalar('train/loss_ff', avg_loss_ff, global_step)
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('train/alignments', att2img(alignment), global_step)  # (Tx, Ty)
            text = texts[0].cpu().detach().numpy()
            text = [idx2char[ch] for ch in text]
            plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, type(model).__name__, 'A', 'train'))
            mel_hat = mels_hat[0:1].transpose(1, 2)
            fmel_hat = fmels_hat[0:1].transpose(1, 2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('train/mel_hat', mel_hat, global_step)
            writer.add_image('train/fmel_hat', fmel_hat, global_step)
            writer.add_image('train/mel', mel, global_step)
            if type(model).__name__ == 'TPGST':
                styleA = style_attentions.unsqueeze(0) * 255.
                writer.add_image('train/styleA', styleA, global_step)
        # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
def evaluate(model, data_loader, criterion, writer, global_step, DEVICE=None):
    """
    To evaluate with validation set
    :param model: nn module object
    :param data_loader: data loader
    :param criterion: criterion for spectorgrams
    :param writer: Tensorboard writer
    :param global_step: Scalar. global step
    :param DEVICE: 'cpu' or 'gpu'
    """
    bce_loss = nn.BCELoss()
    # NOTE(review): xe_loss is created but never used in this function.
    xe_loss = nn.CrossEntropyLoss()
    valid_loss_mel, valid_loss_fmel, valid_loss_ff, valid_loss_se = 0., 0., 0., 0.
    A = None
    with torch.no_grad():
        for step, (texts, mels, ff) in enumerate(data_loader):
            texts, mels, ff = texts.to(DEVICE), mels.to(DEVICE), ff.to(DEVICE)
            GO_frames = torch.zeros([mels.shape[0], 1, args.n_mels * args.r
                                     ]).to(DEVICE)  # (N, Ty/r, n_mels)
            # Teacher forcing: targets shifted right by one all-zero GO frame.
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            refs = mels.view(mels.size(0), -1,
                             args.n_mels).unsqueeze(1)  # (N, 1, Ty, n_mels)
            if type(model).__name__ == 'TPGST':
                mels_hat, fmels_hat, A, style_attentions, ff_hat, se, tpse = model(
                    texts, prev_mels, refs)
                # Style-embedding prediction loss (uses the same criterion as
                # the spectrogram losses).
                loss_se = criterion(tpse, se)
                valid_loss_se += loss_se.item()
            else:
                mels_hat, fmels_hat, A, ff_hat = model(texts, prev_mels)
            loss_mel = criterion(mels_hat, mels)
            fmels = mels.view(mels.size(0), -1, args.n_mels)
            loss_fmel = criterion(fmels_hat, fmels)
            # Final-frame (stop-token) prediction loss.
            loss_ff = bce_loss(ff_hat, ff)
            valid_loss_mel += loss_mel.item()
            valid_loss_fmel += loss_fmel.item()
            valid_loss_ff += loss_ff.item()
        # Per-batch losses averaged over the number of batches.
        avg_loss_mel = valid_loss_mel / (len(data_loader))
        avg_loss_fmel = valid_loss_fmel / (len(data_loader))
        avg_loss_ff = valid_loss_ff / (len(data_loader))
        writer.add_scalar('eval/loss_mel', avg_loss_mel, global_step)
        writer.add_scalar('eval/loss_fmel', avg_loss_fmel, global_step)
        writer.add_scalar('eval/loss_ff', avg_loss_ff, global_step)
        # Visualize the first sample of the last batch.
        alignment = A[0:1].clone().cpu().detach().numpy()
        writer.add_image('eval/alignments', att2img(alignment),
                         global_step)  # (Tx, Ty)
        text = texts[0].cpu().detach().numpy()
        text = [load_vocab()[-1][ch] for ch in text]
        plot_att(alignment[0], text, global_step,
                 path=os.path.join(args.logdir,
                                   type(model).__name__, 'A'))
        mel_hat = mels_hat[0:1].transpose(1, 2)
        fmel_hat = fmels_hat[0:1].transpose(1, 2)
        mel = mels[0:1].transpose(1, 2)
        writer.add_image('eval/mel_hat', mel_hat, global_step)
        writer.add_image('eval/fmel_hat', fmel_hat, global_step)
        writer.add_image('eval/mel', mel, global_step)
        if type(model).__name__ == 'TPGST':
            avg_loss_se = valid_loss_se / (len(data_loader))
            writer.add_scalar('eval/loss_se', avg_loss_se, global_step)
            # Style-token attention map scaled to image range.
            styleA = style_attentions.view(1, mels.size(0), args.n_tokens) * 255.
            writer.add_image('eval/styleA', styleA, global_step)
    return avg_loss_mel
def train(model, data_loader, valid_loader, optimizer, scheduler, batch_size=32, ckpt_dir=None, writer=None, mode='1'):
    """Train one DCTTS sub-network ('Text2Mel' or 'SSRN').

    :param model: nn module object whose ``.name`` is 'Text2Mel' or 'SSRN'
    :param data_loader: data loader for training set
    :param valid_loader: data loader for validation set
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param batch_size: Scalar
    :param ckpt_dir: String. checkpoint directory
    :param writer: Tensorboard writer
    :param mode: unused here; kept for interface compatibility
    """
    epochs = 0
    global_step = args.global_step
    l1_criterion = nn.L1Loss().to(DEVICE)  # default average
    bd_criterion = nn.BCELoss().to(DEVICE)
    model_infos = [('None', 10000.)] * 5  # (ckpt name, val loss) of best 5 models
    first_frames = torch.zeros([batch_size, 1, args.n_mels]).to(DEVICE)  # all-zero GO frame
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss = 0
        # Bug fix: removed `train_iter = iter(data_loader); data = train_iter.next()`.
        # `.next()` is Python-2-only (AttributeError on Python 3) and the
        # fetched batch was discarded, silently skipping data every epoch.
        for step, (texts, mels, extras) in tqdm(enumerate(data_loader), total=len(data_loader), unit='B', ncols=70, leave=False):
            optimizer.zero_grad()
            if model.name == 'Text2Mel':
                if args.ga_mode:
                    # In guided-attention mode `extras` holds the GA weight masks.
                    texts, mels, gas = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE)
                else:
                    texts, mels = texts.to(DEVICE), mels.to(DEVICE)
                # Teacher forcing: shift targets right by one GO frame.
                prev_mels = torch.cat((first_frames, mels[:, :-1, :]), 1)
                mels_hat, A, _ = model(texts, prev_mels, 0)  # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r)
                l1_loss = l1_criterion(mels_hat, mels)
                bd_loss = bd_criterion(mels_hat, mels)
                if args.ga_mode:
                    att_loss = torch.mean(A * gas)  # guided-attention penalty
                    loss = l1_loss + bd_loss + att_loss
                else:
                    loss = l1_loss + bd_loss
            elif model.name == 'SSRN':
                texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE)
                mags_hat = model(mels)  # mags_hat: (N, Ty, n_mags)
                l1_loss = l1_criterion(mags_hat, mags)
                bd_loss = bd_criterion(mags_hat, mags)
                loss = l1_loss + bd_loss
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            # Bug fix: optimizer.step() must precede scheduler.step()
            # (required since PyTorch 1.1), otherwise the first value of the
            # LR schedule is skipped.
            optimizer.step()
            scheduler.step()
            epoch_loss += l1_loss.item()
            global_step += 1
            if global_step % args.save_term == 0:
                model.eval()
                val_loss = evaluate(model, valid_loader, l1_criterion, writer, global_step, args.test_batch)
                model_infos = save_model(model, model_infos, optimizer, scheduler, val_loss, global_step, ckpt_dir)  # save best 5 models
                model.train()
        if args.log_mode:
            # Summary: epoch average and images from the last batch.
            avg_loss = epoch_loss / (len(data_loader))
            writer.add_scalar('train/loss', avg_loss, global_step)
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            if model.name == 'Text2Mel':
                alignment = A[0:1].clone().cpu().detach().numpy()
                writer.add_image('train/alignments', att2img(alignment), global_step)  # (Tx, Ty)
                if args.ga_mode:
                    writer.add_scalar('train/loss_att', att_loss, global_step)
                text = texts[0].cpu().detach().numpy()
                text = [idx2char[ch] for ch in text]
                plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, model.name, 'A', 'train'))
                writer.add_image('train/mel_hat', mels_hat[0:1].transpose(1, 2), global_step)
                writer.add_image('train/mel', mels[0:1].transpose(1, 2), global_step)
            else:
                writer.add_image('train/mag_hat', mags_hat[0:1].transpose(1, 2), global_step)
                writer.add_image('train/mag', mags[0:1].transpose(1, 2), global_step)
        # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
def train(model, data_loader, valid_loader, optimizer, scheduler, batch_size=32, ckpt_dir=None, writer=None, mode='1'):
    """Train a Tacotron-style model that predicts mels and mags jointly.

    :param model: nn module object
    :param data_loader: data loader for training set
    :param valid_loader: data loader for validation set
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param batch_size: Scalar
    :param ckpt_dir: String. checkpoint directory
    :param writer: Tensorboard writer
    :param mode: unused here; kept for interface compatibility
    """
    epochs = 0
    global_step = args.global_step
    criterion = nn.L1Loss().to(DEVICE)  # default average
    model_infos = [('None', 10000.)] * 5  # (ckpt name, val loss) of best 5 models
    GO_frames = torch.zeros([batch_size, 1, args.n_mels * args.r]).to(DEVICE)  # all-zero GO frame
    idx2char = load_vocab()[-1]
    while global_step < args.max_step:
        epoch_loss_mel = 0
        epoch_loss_mag = 0
        for step, (texts, mels, mags) in tqdm(enumerate(data_loader), total=len(data_loader), unit='B', ncols=70, leave=False):
            optimizer.zero_grad()
            texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), mags.to(DEVICE)
            # Teacher forcing: shift targets right by one GO frame.
            prev_mels = torch.cat((GO_frames, mels[:, :-1, :]), 1)
            mels_hat, mags_hat, A = model(texts, prev_mels)
            loss_mel = criterion(mels_hat, mels)
            loss_mag = criterion(mags_hat, mags)
            loss = loss_mel + loss_mag
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            # Bug fix: optimizer.step() must precede scheduler.step()
            # (required since PyTorch 1.1), otherwise the first value of the
            # LR schedule is skipped.
            optimizer.step()
            scheduler.step()
            epoch_loss_mel += loss_mel.item()
            epoch_loss_mag += loss_mag.item()
            global_step += 1
            if global_step % args.save_term == 0:
                model.eval()
                # Bug fix: this call was commented out, leaving val_loss
                # undefined and crashing save_model below with a NameError.
                val_loss = evaluate(model, valid_loader, criterion, writer, global_step, args.test_batch)
                model_infos = save_model(model, model_infos, optimizer, scheduler, val_loss, global_step, ckpt_dir)  # save best 5 models
                model.train()
        if args.log_mode:
            # Summary: epoch averages and images from the last batch.
            avg_loss_mel = epoch_loss_mel / (len(data_loader))
            avg_loss_mag = epoch_loss_mag / (len(data_loader))
            writer.add_scalar('train/loss_mel', avg_loss_mel, global_step)
            writer.add_scalar('train/loss_mag', avg_loss_mag, global_step)
            writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
            alignment = A[0:1].clone().cpu().detach().numpy()
            writer.add_image('train/alignments', att2img(alignment), global_step)  # (Tx, Ty)
            text = texts[0].cpu().detach().numpy()
            text = [idx2char[ch] for ch in text]
            plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, model.name, 'A', 'train'))
            mel_hat = mels_hat[0:1].transpose(1, 2)
            mel = mels[0:1].transpose(1, 2)
            writer.add_image('train/mel_hat', mel_hat, global_step)
            writer.add_image('train/mel', mel, global_step)
            mag_hat = mags_hat[0:1].transpose(1, 2)
            mag = mags[0:1].transpose(1, 2)
            writer.add_image('train/mag_hat', mag_hat, global_step)
            writer.add_image('train/mag', mag, global_step)
        # print('Training Loss: {}'.format(avg_loss))
        epochs += 1
    print('Training complete')
def visualize_loop(args, val_loader):
    """Build the selected VQA model, load its checkpoint, and render attention
    visualizations for every sample in the validation loader.

    :param args: parsed CLI namespace (model_type, encoder_type, device,
                 input_base, vocab, save_dir, ...)
    :param val_loader: validation DataLoader yielding
                       (question, image_feature, ques_lengths, point_set,
                       answer, image_name)
    """
    image_feature_size = 512
    lidar_feature_size = 1024
    # Instantiate the requested fusion model; feature sizes vary per type.
    if args.model_type == 'SAN':
        question_feat_size = 512
        model = SAN(args, question_feat_size, image_feature_size, lidar_feature_size, num_classes=34, qa=None, encoder=args.encoder_type, method='hierarchical')
    if args.model_type == 'MCB':
        question_feat_size = 512
        model = MCB(args, question_feat_size, image_feature_size, lidar_feature_size, num_classes=34, qa=None, encoder=args.encoder_type, method='hierarchical')
    if args.model_type == 'MFB':
        question_feat_size = 512
        # image_feature_size=512
        model = MFB(args, question_feat_size, image_feature_size, lidar_feature_size, num_classes=34, qa=None, encoder=args.encoder_type, method='hierarchical')
    if args.model_type == 'MLB':
        question_feat_size = 1024
        image_feature_size = 512
        model = MLB(args, question_feat_size, image_feature_size, lidar_feature_size, num_classes=34, qa=None, encoder=args.encoder_type, method='hierarchical')
    if args.model_type == 'MUTAN':
        question_feat_size = 1024
        image_feature_size = 512
        model = MUTAN(args, question_feat_size, image_feature_size, lidar_feature_size, num_classes=34, qa=None, encoder=args.encoder_type, method='hierarchical')
    if args.model_type == 'DAN':
        question_feat_size = 512
        model = DAN(args, question_feat_size, image_feature_size, lidar_feature_size, num_classes=34, qa=None, encoder=args.encoder_type, method='hierarchical')
    # load_weights returns a list on success, or a bare model on failure.
    data = load_weights(args, model, optimizer=None)
    if type(data) == list:
        model, optimizer, start_epoch, loss, accuracy = data
        print("Loaded weights")
        print("Epoch: %d, loss: %.3f, Accuracy: %.4f " % (start_epoch, loss, accuracy), flush=True)
    else:
        # Checkpoint load failed: nothing to visualize, so bail out.
        print(" error occured while loading model training freshly")
        model = data
        return
    ###########################################################################multiple GPU use#
    # if torch.cuda.device_count() > 1:
    #     print("Using ", torch.cuda.device_count(), "GPUs!")
    #     model = nn.DataParallel(model)
    model.to(device=args.device)
    model.eval()
    # Imported lazily so the argoverse dependency is only needed for this path.
    import argoverse
    from argoverse.data_loading.argoverse_tracking_loader import ArgoverseTrackingLoader
    from argoverse.utils.json_utils import read_json_file
    from argoverse.map_representation.map_api import ArgoverseMap
    vocab = load_vocab(os.path.join(args.input_base, args.vocab))
    # NOTE(review): hard-coded relative dataset path — breaks if run from
    # another working directory; consider making it configurable.
    argoverse_loader = ArgoverseTrackingLoader(
        '../../../Data/train/argoverse-tracking')
    k = 1  # running sample counter used to name the saved visualizations
    with torch.no_grad():
        for data in tqdm(val_loader):
            question, image_feature, ques_lengths, point_set, answer, image_name = data
            question = question.to(device=args.device)
            ques_lengths = ques_lengths.to(device=args.device)
            image_feature = image_feature.to(device=args.device)
            point_set = point_set.to(device=args.device)
            pred, wgt, energies = model(question, image_feature, ques_lengths, point_set)
            question = question.cpu().data.numpy()
            answer = answer.cpu().data.numpy()
            pred = F.softmax(pred, dim=1)
            pred = torch.argmax(pred, dim=1)
            pred = np.asarray(pred.cpu().data)
            wgt = wgt.cpu().data.numpy()
            energies = energies.squeeze(1).cpu().data.numpy()
            ques_lengths = ques_lengths.cpu().data.numpy()
            # image_name is formatted "<log_id>@<frame_idx>".
            pat = re.compile(r'(.*)@(.*)')
            # NOTE(review): `keep` (indices of correct predictions) is computed
            # but never used below.
            _, keep = np.where([answer == pred])
            temp_batch_size = question.shape[0]
            for b in range(temp_batch_size):
                q = get_ques(question[b], ques_lengths[b], vocab)
                ans = get_ans(answer[b])
                pred_ans = get_ans(pred[b])
                # print(q,ans)
                c = list(re.findall(pat, image_name[b]))[0]
                log_id = c[0]
                idx = int(c[1])
                print(k)
                argoverse_data = argoverse_loader.get(log_id)
                # Each model type exposes its attention weights with a
                # different layout, hence the per-type slicing of `wgt`.
                if args.model_type == 'SAN':
                    plot_att(argoverse_data, idx, wgt[b, :, 1, :], energies[b], q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MCB':
                    plot_att(argoverse_data, idx, wgt[b], energies[b], q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MFB':
                    plot_att(argoverse_data, idx, wgt[b, :, :, 1], energies[b], q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MLB':
                    plot_att(argoverse_data, idx, wgt[b, :, 3, :], energies[b], q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'MUTAN':  # only two glimpses
                    plot_att(argoverse_data, idx, wgt[b, :, 1, :], energies[b], q, ans, args.save_dir, k, pred_ans)
                if args.model_type == 'DAN':  # only two memory
                    plot_att(argoverse_data, idx, wgt[b, :, 1, :], energies[b], q, ans, args.save_dir, k, pred_ans)
                k = k + 1