cfg.dataset.motion_feat = os.path.join(
    cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name))

test_loader_kwargs = {
    'question_type': cfg.dataset.question_type,
    'question_pt': cfg.dataset.test_question_pt,
    'vocab_json': cfg.dataset.vocab_json,
    'appearance_feat': cfg.dataset.appearance_feat,
    'motion_feat': cfg.dataset.motion_feat,
    'test_num': cfg.test.test_num,
    'batch_size': cfg.train.batch_size,
    'num_workers': cfg.num_workers,
    'shuffle': False
}
test_loader = VideoQADataLoader(**test_loader_kwargs)

model_kwargs.update({'vocab': test_loader.vocab})
model = HCRN.HCRNNetwork(**model_kwargs).to(device)
model.load_state_dict(loaded['state_dict'])

if cfg.test.write_preds:
    acc, preds, gts, v_ids, q_ids = validate(cfg, model, test_loader, device,
                                             cfg.test.write_preds)
    sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format(
        test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold'])))
    sys.stdout.flush()

    # write predictions for visualization purposes
    output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        assert os.path.isdir(output_dir)
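# The loops below unpack batches with a todevice helper imported from the
# repo's utils module but not shown in this excerpt. A minimal sketch of what
# it is assumed to do: move a tensor, or a nested list/tuple of tensors, onto
# the target device. The `_sketch` name is illustrative only and the repo's
# real implementation may differ.
def todevice_sketch(tensor, device):
    if isinstance(tensor, (list, tuple)):
        # Recurse into nested containers of tensors.
        return [todevice_sketch(t, device) for t in tensor]
    if isinstance(tensor, torch.Tensor):
        return tensor.to(device)
    return tensor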
def train(cfg):
    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.train_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'train_num': cfg.train.train_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': True,
        # 'pin_memory': True
    }
    if cfg.bert.flag:
        if 'precomputed' in cfg.bert.model:
            train_loader_kwargs['question_feat'] = cfg.bert.train_question_feat
        train_loader_kwargs['bert_model'] = cfg.bert.model
    train_loader = VideoQADataLoader(**train_loader_kwargs)
    logging.info("number of train instances: {}".format(
        len(train_loader.dataset)))

    if cfg.val.flag:
        val_loader_kwargs = {
            'question_type': cfg.dataset.question_type,
            'question_pt': cfg.dataset.val_question_pt,
            'vocab_json': cfg.dataset.vocab_json,
            'appearance_feat': cfg.dataset.appearance_feat,
            'motion_feat': cfg.dataset.motion_feat,
            'val_num': cfg.val.val_num,
            'batch_size': cfg.train.batch_size,
            'num_workers': cfg.num_workers,
            'shuffle': False,
            # 'pin_memory': True
        }
        if cfg.bert.flag:
            if 'precomputed' in cfg.bert.model:
                val_loader_kwargs['question_feat'] = cfg.bert.val_question_feat
            val_loader_kwargs['bert_model'] = cfg.bert.model
        val_loader = VideoQADataLoader(**val_loader_kwargs)
        logging.info("number of val instances: {}".format(
            len(val_loader.dataset)))

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        'vision_dim': cfg.train.vision_dim,
        'module_dim': cfg.train.module_dim,
        'word_dim': cfg.train.word_dim,
        'k_max_frame_level': cfg.train.k_max_frame_level,
        'k_max_clip_level': cfg.train.k_max_clip_level,
        'spl_resolution': cfg.train.spl_resolution,
        'vocab': train_loader.vocab,
        'question_type': cfg.dataset.question_type,
        'hcrn_model': cfg.train.hcrn_model,
        'subvids': cfg.train.subvids,
    }
    if cfg.bert.flag:
        model_kwargs['bert_model'] = cfg.bert.model
        model_kwargs['word_dim'] = cfg.bert.word_dim
    model_kwargs_tosave = {
        k: v
        for k, v in model_kwargs.items() if k != 'vocab'
    }
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    logging.info('num of params: {}'.format(pytorch_total_params))
    logging.info(model)

    if cfg.train.glove:
        logging.info('load glove vectors')
        train_loader.glove_matrix = torch.FloatTensor(
            train_loader.glove_matrix).to(device)
        with torch.no_grad():
            model.linguistic_input_unit.encoder_embed.weight.set_(
                train_loader.glove_matrix)
    if torch.cuda.device_count() > 1 and cfg.multi_gpus:
        model = model.cuda()
        logging.info("Using {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=None)

    optimizer = optim.Adam(model.parameters(), cfg.train.lr)

    start_epoch = 0
    # For count questions the validation metric is MSE (lower is better);
    # for all other question types it is accuracy (higher is better).
    best_val = 100.0 if cfg.dataset.question_type == 'count' else 0
    if cfg.train.restore:
        print("Restore checkpoint and optimizer...")
        ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt')
        ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        start_epoch = ckpt['epoch'] + 1
        # best_val = ckpt['best_val']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    if cfg.dataset.question_type in ['frameqa', 'none']:
        criterion = nn.CrossEntropyLoss().to(device)
    elif cfg.dataset.question_type == 'count':
        criterion = nn.MSELoss().to(device)

    logging.info("Start training........")
    for epoch in range(start_epoch, cfg.train.max_epochs):
        logging.info('>>>>>> epoch {epoch} <<<<<<'.format(
            epoch=colored("{}".format(epoch), "green", attrs=["bold"])))
        model.train()
        total_acc, count = 0, 0
        batch_mse_sum = 0.0
        total_loss, avg_loss = 0.0, 0.0
        train_accuracy = 0
        for i, batch in enumerate(iter(train_loader)):
            progress = epoch + i / len(train_loader)
            # todevice already moved every batch element to the target device.
            _, _, answers, *batch_input = [todevice(x, device) for x in batch]
            answers = answers.squeeze()
            # train_it (defined elsewhere in the repo) performs one
            # optimization step and returns the step loss plus a metric:
            # per-sample correctness for classification, summed squared error
            # for count. The second train() variant below inlines this logic.
            loss, metric = train_it(batch_input, answers, model, criterion,
                                    optimizer, cfg)
            total_loss += loss
            avg_loss = total_loss / (i + 1)
            if cfg.dataset.question_type == 'count':
                batch_avg_mse = metric / answers.size(0)
                batch_mse_sum += metric
                count += answers.size(0)
                avg_mse = batch_mse_sum / count
                sys.stdout.write(
                    "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_mse = {train_mse} avg_mse = {avg_mse} exp: {exp_name}"
                    .format(progress=colored("{:.3f}".format(progress),
                                             "green", attrs=['bold']),
                            ce_loss=colored("{:.4f}".format(loss), "blue",
                                            attrs=['bold']),
                            avg_loss=colored("{:.4f}".format(avg_loss), "red",
                                             attrs=['bold']),
                            train_mse=colored("{:.4f}".format(batch_avg_mse),
                                              "blue", attrs=['bold']),
                            avg_mse=colored("{:.4f}".format(avg_mse), "red",
                                            attrs=['bold']),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
            else:
                total_acc += metric.sum()
                count += answers.size(0)
                train_accuracy = total_acc / count
                sys.stdout.write(
                    "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_acc = {train_acc} avg_acc = {avg_acc} exp: {exp_name}"
                    .format(progress=colored("{:.3f}".format(progress),
                                             "green", attrs=['bold']),
                            ce_loss=colored("{:.4f}".format(loss), "blue",
                                            attrs=['bold']),
                            avg_loss=colored("{:.4f}".format(avg_loss), "red",
                                             attrs=['bold']),
                            train_acc=colored("{:.4f}".format(metric.mean()),
                                              "blue", attrs=['bold']),
                            avg_acc=colored("{:.4f}".format(train_accuracy),
                                            "red", attrs=['bold']),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
        sys.stdout.write("\n")

        # Decay the learning rate every 5 epochs for count, every 10 otherwise.
        if cfg.dataset.question_type == 'count':
            if (epoch + 1) % 5 == 0:
                optimizer = step_decay(cfg, optimizer)
        else:
            if (epoch + 1) % 10 == 0:
                optimizer = step_decay(cfg, optimizer)
        sys.stdout.flush()
        logging.info("Epoch = %s avg_loss = %.3f avg_acc = %.3f" %
                     (epoch, avg_loss, train_accuracy))

        if cfg.val.flag:
            output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            else:
                assert os.path.isdir(output_dir)
            valid_acc = validate(cfg, model, val_loader, device,
                                 write_preds=False)
            # For count questions an improvement is a *decrease* of the metric.
            if (valid_acc > best_val
                    and cfg.dataset.question_type != 'count') or (
                        valid_acc < best_val
                        and cfg.dataset.question_type == 'count'):
                best_val = valid_acc
                # Save best model
                ckpt_dir = os.path.join(cfg.dataset.save_dir, 'ckpt')
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                else:
                    assert os.path.isdir(ckpt_dir)
                save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                                os.path.join(ckpt_dir, 'model.pt'), best_val)
                sys.stdout.write('\n >>>>>> save to %s <<<<<< \n' % ckpt_dir)
                sys.stdout.flush()

            logging.info('~~~~~~ Valid Accuracy: %.4f ~~~~~~~' % valid_acc)
            sys.stdout.write(
                '~~~~~~ Valid Accuracy: {valid_acc} ~~~~~~~\n'.format(
                    valid_acc=colored("{:.4f}".format(valid_acc), "red",
                                      attrs=['bold'])))
            sys.stdout.flush()
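# Both train() variants call step_decay(cfg, optimizer), but the helper is
# defined elsewhere in the repo. A minimal sketch, assuming it multiplicatively
# scales the learning rate of every parameter group; the `_sketch` name and
# the default 0.5 factor are illustrative, not the repo's actual values.
def step_decay_sketch(cfg, optimizer, decay=0.5):
    for param_group in optimizer.param_groups:
        # Scale the current learning rate in place and return the optimizer,
        # matching the `optimizer = step_decay(cfg, optimizer)` call pattern.
        param_group['lr'] *= decay
    return optimizer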
# A second variant of train(): same setup as above, but with the per-batch
# optimization step inlined instead of delegated to train_it, and without the
# BERT options. If both definitions live in the same module, this one shadows
# the first.
def train(cfg):
    logging.info("Create train_loader and val_loader.........")
    train_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.train_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'train_num': cfg.train.train_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': True
    }
    train_loader = VideoQADataLoader(**train_loader_kwargs)
    logging.info("number of train instances: {}".format(
        len(train_loader.dataset)))

    if cfg.val.flag:
        val_loader_kwargs = {
            'question_type': cfg.dataset.question_type,
            'question_pt': cfg.dataset.val_question_pt,
            'vocab_json': cfg.dataset.vocab_json,
            'appearance_feat': cfg.dataset.appearance_feat,
            'motion_feat': cfg.dataset.motion_feat,
            'val_num': cfg.val.val_num,
            'batch_size': cfg.train.batch_size,
            'num_workers': cfg.num_workers,
            'shuffle': False
        }
        val_loader = VideoQADataLoader(**val_loader_kwargs)
        logging.info("number of val instances: {}".format(
            len(val_loader.dataset)))

    logging.info("Create model.........")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_kwargs = {
        'vision_dim': cfg.train.vision_dim,
        'module_dim': cfg.train.module_dim,
        'word_dim': cfg.train.word_dim,
        'k_max_frame_level': cfg.train.k_max_frame_level,
        'k_max_clip_level': cfg.train.k_max_clip_level,
        'spl_resolution': cfg.train.spl_resolution,
        'vocab': train_loader.vocab,
        'question_type': cfg.dataset.question_type
    }
    model_kwargs_tosave = {
        k: v
        for k, v in model_kwargs.items() if k != 'vocab'
    }
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    logging.info('num of params: {}'.format(pytorch_total_params))
    logging.info(model)

    if cfg.train.glove:
        logging.info('load glove vectors')
        train_loader.glove_matrix = torch.FloatTensor(
            train_loader.glove_matrix).to(device)
        with torch.no_grad():
            model.linguistic_input_unit.encoder_embed.weight.set_(
                train_loader.glove_matrix)
    if torch.cuda.device_count() > 1 and cfg.multi_gpus:
        model = model.cuda()
        logging.info("Using {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=None)

    optimizer = optim.Adam(model.parameters(), cfg.train.lr)

    start_epoch = 0
    best_val = 100.0 if cfg.dataset.question_type == 'count' else 0
    if cfg.train.restore:
        print("Restore checkpoint and optimizer...")
        ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt')
        ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        start_epoch = ckpt['epoch'] + 1
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    if cfg.dataset.question_type in ['frameqa', 'none']:
        criterion = nn.CrossEntropyLoss().to(device)
    elif cfg.dataset.question_type == 'count':
        criterion = nn.MSELoss().to(device)

    logging.info("Start training........")
    for epoch in range(start_epoch, cfg.train.max_epochs):
        logging.info('>>>>>> epoch {epoch} <<<<<<'.format(
            epoch=colored("{}".format(epoch), "green", attrs=["bold"])))
        model.train()
        total_acc, count = 0, 0
        batch_mse_sum = 0.0
        total_loss, avg_loss = 0.0, 0.0
        train_accuracy = 0
        for i, batch in enumerate(iter(train_loader)):
            progress = epoch + i / len(train_loader)
            _, _, answers, *batch_input = [todevice(x, device) for x in batch]
            answers = answers.squeeze()
            batch_size = answers.size(0)
            optimizer.zero_grad()
            logits = model(*batch_input)
            if cfg.dataset.question_type in ['action', 'transition']:
                # Flat offset of each question's block of 5 candidate logits,
                # e.g. batch_size=2 -> [0 0 0 0 0 5 5 5 5 5].
                batch_agg = np.concatenate(
                    np.tile(np.arange(batch_size).reshape([batch_size, 1]),
                            [1, 5])) * 5
                # tile (a repo helper) repeats each ground-truth index 5
                # times along dim 0, one copy per answer candidate.
                answers_agg = tile(answers, 0, 5)
                # Pairwise hinge loss: push the ground-truth candidate's
                # logit at least 1.0 above every candidate's logit.
                loss = torch.max(
                    torch.tensor(0.0).to(device), 1.0 + logits -
                    logits[answers_agg +
                           torch.from_numpy(batch_agg).to(device)])
                loss = loss.sum()
                loss.backward()
                total_loss += loss.detach()
                avg_loss = total_loss / (i + 1)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=12)
                optimizer.step()
                preds = torch.argmax(logits.view(batch_size, 5), dim=1)
                aggreeings = (preds == answers)
            elif cfg.dataset.question_type == 'count':
                answers = answers.unsqueeze(-1)
                loss = criterion(logits, answers.float())
                loss.backward()
                total_loss += loss.detach()
                avg_loss = total_loss / (i + 1)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=12)
                optimizer.step()
                # Round the regression output and clamp to the valid range.
                preds = (logits + 0.5).long().clamp(min=1, max=10)
                batch_mse = (preds - answers)**2
            else:
                loss = criterion(logits, answers)
                loss.backward()
                total_loss += loss.detach()
                avg_loss = total_loss / (i + 1)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=12)
                optimizer.step()
                aggreeings = batch_accuracy(logits, answers)

            if cfg.dataset.question_type == 'count':
                batch_avg_mse = batch_mse.sum().item() / answers.size(0)
                batch_mse_sum += batch_mse.sum().item()
                count += answers.size(0)
                avg_mse = batch_mse_sum / count
                sys.stdout.write(
                    "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_mse = {train_mse} avg_mse = {avg_mse} exp: {exp_name}"
                    .format(progress=colored("{:.3f}".format(progress),
                                             "green", attrs=['bold']),
                            ce_loss=colored("{:.4f}".format(loss.item()),
                                            "blue", attrs=['bold']),
                            avg_loss=colored("{:.4f}".format(avg_loss), "red",
                                             attrs=['bold']),
                            train_mse=colored("{:.4f}".format(batch_avg_mse),
                                              "blue", attrs=['bold']),
                            avg_mse=colored("{:.4f}".format(avg_mse), "red",
                                            attrs=['bold']),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
            else:
                total_acc += aggreeings.sum().item()
                count += answers.size(0)
                train_accuracy = total_acc / count
                sys.stdout.write(
                    "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_acc = {train_acc} avg_acc = {avg_acc} exp: {exp_name}"
                    .format(progress=colored("{:.3f}".format(progress),
                                             "green", attrs=['bold']),
                            ce_loss=colored("{:.4f}".format(loss.item()),
                                            "blue", attrs=['bold']),
                            avg_loss=colored("{:.4f}".format(avg_loss), "red",
                                             attrs=['bold']),
                            train_acc=colored(
                                "{:.4f}".format(
                                    aggreeings.float().mean().cpu().numpy()),
                                "blue", attrs=['bold']),
                            avg_acc=colored("{:.4f}".format(train_accuracy),
                                            "red", attrs=['bold']),
                            exp_name=cfg.exp_name))
                sys.stdout.flush()
        sys.stdout.write("\n")

        # Decay the learning rate every 5 epochs for count, every 10 otherwise.
        if cfg.dataset.question_type == 'count':
            if (epoch + 1) % 5 == 0:
                optimizer = step_decay(cfg, optimizer)
        else:
            if (epoch + 1) % 10 == 0:
                optimizer = step_decay(cfg, optimizer)
        sys.stdout.flush()
        logging.info("Epoch = %s avg_loss = %.3f avg_acc = %.3f" %
                     (epoch, avg_loss, train_accuracy))

        if cfg.val.flag:
            output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            else:
                assert os.path.isdir(output_dir)
            valid_acc = validate(cfg, model, val_loader, device,
                                 write_preds=False)
            if (valid_acc > best_val
                    and cfg.dataset.question_type != 'count') or (
                        valid_acc < best_val
                        and cfg.dataset.question_type == 'count'):
                best_val = valid_acc
                # Save best model
                ckpt_dir = os.path.join(cfg.dataset.save_dir, 'ckpt')
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                else:
                    assert os.path.isdir(ckpt_dir)
                save_checkpoint(epoch, model, optimizer, model_kwargs_tosave,
                                os.path.join(ckpt_dir, 'model.pt'))
                sys.stdout.write('\n >>>>>> save to %s <<<<<< \n' % ckpt_dir)
                sys.stdout.flush()

            logging.info('~~~~~~ Valid Accuracy: %.4f ~~~~~~~' % valid_acc)
            sys.stdout.write(
                '~~~~~~ Valid Accuracy: {valid_acc} ~~~~~~~\n'.format(
                    valid_acc=colored("{:.4f}".format(valid_acc), "red",
                                      attrs=['bold'])))
            sys.stdout.flush()
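# save_checkpoint(...) is called by both train() variants but defined
# elsewhere in the repo. A minimal sketch consistent with how checkpoints are
# read back in this file (keys 'epoch', 'state_dict', 'optimizer',
# 'model_kwargs'); best_val is optional because the first variant passes it
# and the second does not. The `_sketch` name is illustrative only.
def save_checkpoint_sketch(epoch, model, optimizer, model_kwargs, path,
                           best_val=None):
    # Unwrap nn.DataParallel so the saved weights load into a bare model.
    state_dict = (model.module.state_dict()
                  if isinstance(model, nn.DataParallel) else
                  model.state_dict())
    ckpt = {
        'epoch': epoch,
        'state_dict': state_dict,
        'optimizer': optimizer.state_dict(),
        'model_kwargs': model_kwargs,
    }
    if best_val is not None:
        ckpt['best_val'] = best_val
    torch.save(ckpt, path)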
def process_final(cfg):
    assert cfg.dataset.name in [
        'tgif-qa', 'tgif-qa-infer', 'msrvtt-qa', 'msvd-qa'
    ]
    assert cfg.dataset.question_type in [
        'frameqa', 'count', 'transition', 'action', 'none'
    ]
    # Check that the data folder exists.
    assert os.path.exists(cfg.dataset.data_dir)

    if cfg.dataset.name != 'tgif-qa-infer':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = 'cpu'
    # cfg.dataset.save_dir = os.path.join(cfg.dataset.save_dir, cfg.exp_name)
    ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt')
    print('ckpt:: ', ckpt)
    assert os.path.exists(ckpt)
    # Load the pretrained model.
    loaded = torch.load(ckpt, map_location='cpu')
    model_kwargs = loaded['model_kwargs']

    if cfg.dataset.name in ('tgif-qa', 'tgif-qa-infer'):
        cfg.dataset.test_question_pt = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.test_question_pt.format(cfg.dataset.name,
                                                cfg.dataset.question_type))
        cfg.dataset.vocab_json = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.vocab_json.format(cfg.dataset.name,
                                          cfg.dataset.question_type))
        cfg.dataset.appearance_feat = os.path.join(
            cfg.dataset.video_dir,
            cfg.dataset.appearance_feat.format(cfg.dataset.name))
        cfg.dataset.motion_feat = os.path.join(
            cfg.dataset.video_dir,
            cfg.dataset.motion_feat.format(cfg.dataset.name))
    else:
        cfg.dataset.question_type = 'none'
        cfg.dataset.appearance_feat = '{}_appearance_feat.h5'
        cfg.dataset.motion_feat = '{}_motion_feat.h5'
        cfg.dataset.vocab_json = '{}_vocab.json'
        cfg.dataset.test_question_pt = '{}_test_questions.pt'
        cfg.dataset.test_question_pt = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.test_question_pt.format(cfg.dataset.name))
        cfg.dataset.vocab_json = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.vocab_json.format(cfg.dataset.name))
        cfg.dataset.appearance_feat = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.appearance_feat.format(cfg.dataset.name))
        cfg.dataset.motion_feat = os.path.join(
            cfg.dataset.data_dir,
            cfg.dataset.motion_feat.format(cfg.dataset.name))

    test_loader_kwargs = {
        'question_type': cfg.dataset.question_type,
        'question_pt': cfg.dataset.test_question_pt,
        'vocab_json': cfg.dataset.vocab_json,
        'appearance_feat': cfg.dataset.appearance_feat,
        'motion_feat': cfg.dataset.motion_feat,
        'test_num': cfg.test.test_num,
        'batch_size': cfg.train.batch_size,
        'num_workers': cfg.num_workers,
        'shuffle': False
    }
    test_loader = VideoQADataLoader(**test_loader_kwargs)

    model_kwargs.update({'vocab': test_loader.vocab})
    model = HCRN.HCRNNetwork(**model_kwargs).to(device)
    model.load_state_dict(loaded['state_dict'])

    if cfg.test.write_preds:
        acc, preds, gts, v_ids, q_ids, logits = validate(
            cfg, model, test_loader, device, cfg.test.write_preds)
        # print('===Question_type', cfg.dataset.question_type)
        # print('====LOGIT ', logits)
        detail = []
        if cfg.dataset.question_type in ['action', 'transition']:
            sm = torch.nn.Softmax(dim=1)
            print('origin_value::', logits.t())
            probs = sm(logits.t())
            print('>>>> Probs: ', type(probs), probs.size(), probs)
            detail = probs.numpy().tolist()
        elif cfg.dataset.question_type in ['frameqa']:
            sm = torch.nn.Softmax(dim=1)
            probs = sm(logits)
            print('>>>> Probs: ', type(probs), probs.size(), probs)
            answer_vocab = test_loader.vocab['answer_idx_to_token']
            # Keep the 5 most probable answers (and their probabilities)
            # for each question.
            values, idx = torch.topk(probs, 5)
            print('>>>Top5 ', idx)
            top_answer = []
            for i, predict in enumerate(idx):
                print('FRAMEQA-pred.item:: ', list(predict.numpy()),
                      list(values[i].numpy()))
                # print('FRAMEQA-answer_vocab:: ', answer_vocab)
                top_answer.append(
                    ([answer_vocab[ix] for ix in list(predict.numpy())],
                     [float(v) for v in list(values[i].numpy())]))
            print('FRAMEQA-topk:: ', top_answer)
            detail = top_answer

        sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format(
            test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold'])))
        sys.stdout.flush()

        # write predictions for visualization purposes
        output_dir = os.path.join(cfg.dataset.save_dir, 'preds')
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        else:
            assert os.path.isdir(output_dir)
        preds_file = os.path.join(output_dir, "test_preds.json")

        if cfg.dataset.question_type in ['action', 'transition']:
            # Find ground-truth questions and corresponding answer candidates.
            vocab = test_loader.vocab['question_answer_idx_to_token']
            gt_dict = {}
            with open(cfg.dataset.test_question_pt, 'rb') as f:
                obj = pickle.load(f)
                questions = obj['questions']
                org_v_ids = obj['video_ids']
                org_v_names = obj['video_names']
                org_q_ids = obj['question_id']
                ans_candidates = obj['ans_candidates']
            for idx in range(len(org_q_ids)):
                gt_dict[str(org_q_ids[idx])] = [
                    org_v_names[idx], questions[idx], ans_candidates[idx]
                ]
            instances = [{
                'video_id': video_id,
                'question_id': q_id,
                'video_name': gt_dict[str(q_id)][0],
                'question': [
                    vocab[word.item()] for word in gt_dict[str(q_id)][1]
                    if word != 0
                ],
                'answer': answer,
                'prediction': pred,
                'detail': d
            } for video_id, q_id, answer, pred, d in zip(
                np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts,
                preds, detail)]
            # write predictions to json file
            # with open(preds_file, 'w') as f:
            #     json.dump(instances, f)

            sys.stdout.write('Display 10 samples...\n')
            # Display a few samples.
            sample_size = 1 if cfg.dataset.name == 'tgif-qa-infer' else 10
            sample_size = min(sample_size, len(preds))  # guard small test sets
            for idx in range(sample_size):
                print('Video name: {}'.format(
                    gt_dict[str(q_ids[idx].item())][0]))
                cur_question = [
                    vocab[word.item()]
                    for word in gt_dict[str(q_ids[idx].item())][1] if word != 0
                ]
                print('Question: ' + ' '.join(cur_question) + '?')
                all_answer_cands = gt_dict[str(q_ids[idx].item())][2]
                for cand_id in range(len(all_answer_cands)):
                    cur_answer_cands = [
                        vocab[word.item()]
                        for word in all_answer_cands[cand_id] if word != 0
                    ]
                    print('({}): '.format(cand_id) +
                          ' '.join(cur_answer_cands))
                print('Prediction: {}'.format(preds[idx]))
                print('Groundtruth: {}'.format(gts[idx]))
            return instances
        else:
            vocab = test_loader.vocab['question_idx_to_token']
            gt_dict = {}
            with open(cfg.dataset.test_question_pt, 'rb') as f:
                obj = pickle.load(f)
                questions = obj['questions']
                org_v_ids = obj['video_ids']
                org_v_names = obj['video_names']
                org_q_ids = obj['question_id']
            for idx in range(len(org_q_ids)):
                gt_dict[str(org_q_ids[idx])] = [
                    org_v_names[idx], questions[idx]
                ]
            if cfg.dataset.question_type == 'frameqa':
                instances = [{
                    'video_id': video_id,
                    'question_id': q_id,
                    'video_name': str(gt_dict[str(q_id)][0]),
                    'question': [
                        vocab[word.item()] for word in gt_dict[str(q_id)][1]
                        if word != 0
                    ],
                    'answer': answer,
                    'prediction': pred,
                    'detail': d
                } for video_id, q_id, answer, pred, d in zip(
                    np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts,
                    preds, detail)]
            else:
                instances = [{
                    'video_id': video_id,
                    'question_id': q_id,
                    'video_name': str(gt_dict[str(q_id)][0]),
                    'question': [
                        vocab[word.item()] for word in gt_dict[str(q_id)][1]
                        if word != 0
                    ],
                    'answer': answer,
                    'prediction': pred
                } for video_id, q_id, answer, pred in zip(
                    np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts,
                    preds)]
            # write predictions to json file
            # with open(preds_file, 'w') as f:
            #     json.dump(instances, f)

            sys.stdout.write('Display 10 samples...\n')
            # Display a few samples.
            sample_size = 1 if cfg.dataset.name == 'tgif-qa-infer' else 10
            sample_size = min(sample_size, len(preds))  # guard small test sets
            for idx in range(sample_size):
                print('Video name: {}'.format(
                    gt_dict[str(q_ids[idx].item())][0]))
                cur_question = [
                    vocab[word.item()]
                    for word in gt_dict[str(q_ids[idx].item())][1] if word != 0
                ]
                print('Question: ' + ' '.join(cur_question) + '?')
                print('Prediction: {}'.format(preds[idx]))
                print('Groundtruth: {}'.format(gts[idx]))
            return instances
    else:
        acc = validate(cfg, model, test_loader, device, cfg.test.write_preds)
        sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format(
            test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold'])))
        sys.stdout.flush()
        return []
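# A standalone illustration of the multiple-choice index arithmetic used in
# the second train() above (action/transition questions). logits is a flat
# vector of batch_size * 5 candidate scores; batch_agg gives each question's
# offset into that flat vector, so batch_agg + the repeated ground-truth
# indices recover the flat position of every ground-truth candidate. The demo
# function name is illustrative only.
def _demo_batch_agg(batch_size=2):
    batch_agg = np.concatenate(
        np.tile(np.arange(batch_size).reshape([batch_size, 1]), [1, 5])) * 5
    # batch_size=2 -> [0 0 0 0 0 5 5 5 5 5]: five copies of each question's
    # offset, one per answer candidate.
    print(batch_agg)
    answers = np.array([3, 1])  # toy ground-truth candidate indices
    # np.repeat mirrors the repo's tile(answers, 0, 5) helper, assumed to
    # repeat each element 5 times along dim 0.
    answers_agg = np.repeat(answers, 5)
    # Flat positions of the ground-truth logit, repeated once per candidate:
    print(answers_agg + batch_agg)  # [3 3 3 3 3 6 6 6 6 6]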