def validate(opt, dset, model, mode="valid"):
    dset.set_mode(mode)
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=False, collate_fn=pad_collate)

    valid_qids = []
    valid_loss = []
    valid_corrects = []
    for _, batch in enumerate(valid_loader):
        model_inputs, targets, qids = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l,
                                                        opt.max_vid_l, device=opt.device)
        outputs = model(*model_inputs)
        loss = criterion(outputs, targets)  # relies on a module-level `criterion` bound by the caller

        # measure accuracy and record loss
        valid_qids += [int(x) for x in qids]
        valid_loss.append(loss.item())
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()
        if opt.debug:
            break

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    valid_loss = sum(valid_loss) / float(len(valid_loss))  # average over batches, not over examples
    return valid_acc, valid_loss
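# Usage sketch (hypothetical, not part of the original script): `validate` reads a
# module-level `criterion`, so a driver is assumed to bind one before the first call.
# It also assumes the seven-argument `train` defined just below; `opt`, `dset`,
# `model`, and `optimizer` come from the setup code elsewhere in this file, and
# `example_epoch_loop` itself is an illustrative name.
def example_epoch_loop(opt, dset, model, optimizer):
    import torch.nn as nn
    global criterion
    criterion = nn.CrossEntropyLoss().to(opt.device)  # assumed loss; validate() picks this up
    best_acc = 0.0
    for epoch in range(opt.n_epoch):
        best_acc = train(opt, dset, model, criterion, optimizer, epoch, best_acc)
        valid_acc, valid_loss = validate(opt, dset, model, mode="valid")
        print("epoch %d: val acc %.4f, val loss %.4f (best %.4f)" % (epoch, valid_acc, valid_loss, best_acc))
    return best_acc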
def train(opt, dset, model, criterion, optimizer, epoch, previous_best_acc):
    dset.set_mode("train")
    model.train()
    train_loader = DataLoader(dset, batch_size=opt.bsz, shuffle=True, collate_fn=pad_collate)

    train_loss = []
    valid_acc_log = ["batch_idx\tacc"]
    train_corrects = []
    torch.set_grad_enabled(True)
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        model_inputs, targets, _ = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l,
                                                     opt.max_vid_l, device=opt.device)
        outputs = model(*model_inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        train_loss.append(loss.item())
        pred_ids = outputs.data.max(1)[1]
        train_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        if batch_idx % opt.log_freq == 0:
            niter = epoch * len(train_loader) + batch_idx
            train_acc = sum(train_corrects) / float(len(train_corrects))
            train_loss = sum(train_loss) / float(len(train_loss))
            opt.writer.add_scalar("Train/Acc", train_acc, niter)
            opt.writer.add_scalar("Train/Loss", train_loss, niter)

            # Test
            valid_acc, valid_loss = validate(opt, dset, model, mode="valid")
            opt.writer.add_scalar("Valid/Loss", valid_loss, niter)
            valid_log_str = "%02d\t%.4f" % (batch_idx, valid_acc)
            valid_acc_log.append(valid_log_str)
            if valid_acc > previous_best_acc:
                previous_best_acc = valid_acc
                torch.save(model.state_dict(), os.path.join(opt.results_dir, "best_valid.pth"))
            print(" Train Epoch %d loss %.4f acc %.4f Val loss %.4f acc %.4f"
                  % (epoch, train_loss, train_acc, valid_loss, valid_acc))

            # reset to train
            torch.set_grad_enabled(True)
            model.train()
            dset.set_mode("train")
            train_corrects = []
            train_loss = []

        if opt.debug:
            break

    # additional log
    with open(os.path.join(opt.results_dir, "valid_acc.log"), "a") as f:
        f.write("\n".join(valid_acc_log) + "\n")

    return previous_best_acc
def test(opt, dset, model):
    dset.set_mode(opt.mode)
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=False, collate_fn=pad_collate)

    qid2preds = {}
    qid2targets = {}
    for valid_idx, batch in tqdm(enumerate(valid_loader)):
        model_inputs, targets, qids = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l,
                                                        opt.max_vid_l, device=opt.device)
        outputs = model(*model_inputs)
        pred_ids = outputs.data.max(1)[1].cpu().numpy().tolist()
        cur_qid2preds = {qid: pred for qid, pred in zip(qids, pred_ids)}
        qid2preds = merge_two_dicts(qid2preds, cur_qid2preds)
        cur_qid2targets = {qid: target for qid, target in zip(qids, targets)}
        qid2targets = merge_two_dicts(qid2targets, cur_qid2targets)
    return qid2preds, qid2targets
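# `test` returns raw qid->prediction and qid->target maps rather than a score. A
# small sketch (hypothetical helper) of turning those into an accuracy number and
# a predictions file; the output filename is illustrative.
def example_score_and_dump(qid2preds, qid2targets, out_path="predictions.json"):
    import json
    n_correct = sum(int(qid2preds[qid] == int(qid2targets[qid])) for qid in qid2preds)
    acc = n_correct / float(len(qid2preds))
    with open(out_path, "w") as f:
        json.dump({str(qid): int(pred) for qid, pred in qid2preds.items()}, f)
    return acc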
def main(): """Main script.""" torch.manual_seed(2667) ################################################################################ parser = argparse.ArgumentParser() parser.add_argument('--mode', default='train', help='train/test') parser.add_argument('--save_path', type=str, default=os.path.expanduser('~/kable_management/blp_paper/.results/hme_tvqa/saved_models'), help='path for saving trained models') parser.add_argument('--test', type=int, default=0, help='0 | 1') parser.add_argument('--jobname', type=str , default='hme_tvqa-default', help='jobname to direct to plotter') parser.add_argument("--pool_type", type=str, default="default", choices=["default", "LinearSum", "ConcatMLP", "MCB", "MFH", "MFB", "MLB", "Block"], help="Which pooling technique to use") parser.add_argument("--pool_in_dims", type=int, nargs='+', default=[300,300], help="Input dimensions to pooling layers") parser.add_argument("--pool_out_dim", type=int, default=600, help="Output dimension to pooling layers") parser.add_argument("--pool_hidden_dim", type=int, default=1500, help="Some pooling types come with a hidden internal dimension") parser.add_argument("--results_dir_base", type=str, default=os.path.expanduser("~/kable_management/.results/test")) parser.add_argument("--log_freq", type=int, default=400, help="print, save training info") parser.add_argument("--lr", type=float, default=3e-4, help="learning rate") parser.add_argument("--wd", type=float, default=1e-5, help="weight decay") parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run") parser.add_argument("--max_es_cnt", type=int, default=3, help="number of epochs to early stop") parser.add_argument("--bsz", type=int, default=32, help="mini-batch size") parser.add_argument("--test_bsz", type=int, default=100, help="mini-batch size for testing") parser.add_argument("--device", type=int, default=0, help="gpu ordinal, -1 indicates cpu") parser.add_argument("--no_core_driver", action="store_true", help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`") parser.add_argument("--word_count_threshold", type=int, default=2, help="word vocabulary threshold") parser.add_argument("--no_glove", action="store_true", help="not use glove vectors") parser.add_argument("--no_ts", action="store_true", help="no timestep annotation, use full length feature") parser.add_argument("--input_streams", type=str, nargs="+", choices=["vcpt", "sub", "imagenet", "regional"], #added regional support here help="input streams for the model") ################ My stuff parser.add_argument("--modelname", type=str, default="tvqa_abc", help="name of the model ot use") parser.add_argument("--lrtype", type=str, choices=["adam", "cyclic", "radam", "lrelu"], default="adam", help="Kind of learning rate") parser.add_argument("--poolnonlin", type=str, choices=["tanh", "relu", "sigmoid", "None", "lrelu"], default="None", help="add nonlinearities to pooling layer") parser.add_argument("--pool_dropout", type=float, default=0.0, help="Dropout value for the projections") parser.add_argument("--testrun", type=bool, default=False, help="set True to stop writing and visdom") parser.add_argument("--topk", type=int, default=1, help="To use instead of max pooling") parser.add_argument("--nosub", type=bool, default=False, help="Ignore the sub stream") parser.add_argument("--noimg", type=bool, default=False, help="Ignore the imgnet stream") parser.add_argument("--bert", type=str, choices=["default", "mine", "multi_choice", "qa"], default=None, help="What kind 
of BERT model to use") parser.add_argument("--reg_feat_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/regional_features/100p.h5"), help="regional features") parser.add_argument("--my_vcpt", type=bool, default=False, help="Use my extracted visual concepts") parser.add_argument("--lanecheck", type=bool, default=False, help="Validation lane checks") parser.add_argument("--lanecheck_path", type=str, help="Validation lane check path") parser.add_argument("--best_path", type=str, help="Path to best model") parser.add_argument("--disable_streams", type=str, default=None, nargs="+", choices=["vcpt", "sub", "imagenet", "regional"], #added regional support here help="disable the input stream from voting in the softmax of model outputs") parser.add_argument("--dset", choices=["valid", "test", "train"], default="valid", type=str, help="The dataset to use") ######################## parser.add_argument("--n_layers_cls", type=int, default=1, help="number of layers in classifier") parser.add_argument("--hsz1", type=int, default=150, help="hidden size for the first lstm") parser.add_argument("--hsz2", type=int, default=300, help="hidden size for the second lstm") parser.add_argument("--embedding_size", type=int, default=300, help="word embedding dim") parser.add_argument("--max_sub_l", type=int, default=300, help="max length for subtitle") parser.add_argument("--max_vcpt_l", type=int, default=300, help="max length for visual concepts") parser.add_argument("--max_vid_l", type=int, default=480, help="max length for video feature") parser.add_argument("--vocab_size", type=int, default=0, help="vocabulary size") parser.add_argument("--no_normalize_v", action="store_true", help="do not normalize video featrue") # Data paths parser.add_argument("--train_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_train_processed.json"), help="train set path") parser.add_argument("--valid_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_val_processed.json"), help="valid set path") parser.add_argument("--test_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/tvqa_test_public_processed.json"), help="test set path") parser.add_argument("--glove_path", type=str, default=os.path.expanduser("~/kable_management/data/word_embeddings/glove.6B.300d.txt"), help="GloVe pretrained vector path") parser.add_argument("--vcpt_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/vcpt_features/py2_det_visual_concepts_hq.pickle"), help="visual concepts feature path") parser.add_argument("--vid_feat_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/imagenet_features/tvqa_imagenet_pool5_hq.h5"), help="imagenet feature path") parser.add_argument("--vid_feat_size", type=int, default=2048, help="visual feature dimension") parser.add_argument("--word2idx_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_word2idx.pickle"), help="word2idx cache path") parser.add_argument("--idx2word_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_idx2word.pickle"), help="idx2word cache path") parser.add_argument("--vocab_embedding_path", type=str, default=os.path.expanduser("~/kable_management/data/tvqa/cache/py2_vocab_embedding.pickle"), help="vocab_embedding cache path") parser.add_argument("--regional_topk", type=int, default=-1, help="Pick top-k scoring regional features across all frames") args = parser.parse_args() plotter = 
VisdomLinePlotter(env_name="TEST")#args.jobname) args.plotter = plotter args.normalize_v = not args.no_normalize_v args.device = torch.device("cuda:%d" % args.device if args.device >= 0 else "cpu") args.with_ts = not args.no_ts args.input_streams = [] if args.input_streams is None else args.input_streams args.vid_feat_flag = True if "imagenet" in args.input_streams else False args.h5driver = None if args.no_core_driver else "core" #args.dataset = 'msvd_qa' args.dataset = 'tvqa' args.word_dim = 300 args.vocab_num = 4000 args.pretrained_embedding = os.path.expanduser("~/kable_management/data/msvd_qa/word_embedding.npy") args.video_feature_dim = 4096 args.video_feature_num = 480#20 args.answer_num = 1000 args.memory_dim = 256 args.batch_size = 32 args.reg_coeff = 1e-5 args.learning_rate = 0.001 args.preprocess_dir = os.path.expanduser("~/kable_management/data/hme_tvqa/") args.log = os.path.expanduser('~/kable_management/blp_paper/.results/msvd/logs') args.hidden_size = 512 args.image_feature_net = 'concat' args.layer = 'fc' ################################################################################ #model.load_embedding(dset.vocab_embedding) #opt.vocab_size = len(dset.word2idx) ############################################################################## ############################################################################## ############################################################################## ############################################################################## ############################################################################## ############################################################################## ############################################################################## ############################################################################## ### add arguments ### args.max_sequence_length = 480 args.memory_type = '_mrm2s' args.image_feature_net = 'concat' args.layer = 'fc' #args.save_model_path = args.save_path + '%s_%s_%s%s' % (args.task,args.image_feature_net,args.layer,args.memory_type) ## Create model directory # if not os.path.exists(args.save_model_path): # os.makedirs(args.save_model_path) ############################# # Parameters ############################# SEQUENCE_LENGTH = args.max_sequence_length VOCABULARY_SIZE = train_dataset.n_words assert VOCABULARY_SIZE == voc_len FEAT_DIM = train_dataset.get_video_feature_dimension()[1:] train_iter = train_dataset.batch_iter(args.num_epochs, args.batch_size) # Create model directory if not os.path.exists(args.save_path): os.makedirs(args.save_path) from embed_loss import MultipleChoiceLoss criterion = MultipleChoiceLoss(num_option=5, margin=1, size_average=True).cuda() if args.memory_type=='_mrm2s': rnn = TGIFAttentionTwoStream(args, 'Trans', feat_channel, feat_dim, text_embed_size, args.hidden_size, voc_len, args.num_layers, word_matrix, answer_vocab_size = answer_vocab_size, max_len=args.max_sequence_length) else: assert 1==2 rnn = rnn.cuda() ############################################################################## ############################################################################################################################################################ ############################################################################################################################################################ 
    # Dataset and loader
    # dataset = dt.MSVDQA(args.batch_size, args.preprocess_dir)
    dataset = tvqa_dataset.TVQADataset(args, mode="train")
    train_loader = DataLoader(dataset, batch_size=args.bsz, shuffle=True,
                              collate_fn=tvqa_dataset.pad_collate)

    args.memory_type = '_mrm2s'
    args.save_model_path = args.save_path + 'model_%s_%s%s' % (args.image_feature_net, args.layer, args.memory_type)
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    #############################
    # get video feature dimension
    #############################
    feat_channel = args.video_feature_dim
    feat_dim = 1
    text_embed_size = args.word_dim
    answer_vocab_size = args.answer_num
    voc_len = args.vocab_num
    num_layers = 2
    max_sequence_length = args.video_feature_num
    word_matrix = np.load(args.pretrained_embedding)
    answerset = pd.read_csv(os.path.expanduser('~/kable_management/data/msvd_qa/answer_set.txt'),
                            header=None)[0]

    rnn = TGIFAttentionTwoStream(args, feat_channel, feat_dim, text_embed_size, args.hidden_size,
                                 voc_len, num_layers, word_matrix,
                                 answer_vocab_size=answer_vocab_size,
                                 max_len=max_sequence_length)
    # rnn = rnn.cuda()
    rnn = rnn.to(args.device)
    if args.test == 1:
        rnn.load_state_dict(torch.load(os.path.join(args.save_model_path, 'rnn-3800-vl_0.316.pkl')))

    # loss function
    criterion = nn.CrossEntropyLoss(size_average=True).cuda()
    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)

    iter = 0
    ######################################################################
    best_val_acc = 0.0
    best_test_acc = 0.0
    train_loss = []
    train_acc = []
    ######################################################################
    for epoch in range(0, 10000):
        # dataset.reset_train()
        for batch_idx, batch in enumerate(train_loader):
            tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                batch, args.max_sub_l, args.max_vcpt_l, args.max_vid_l, device='cpu')  # args.device
            assert tvqa_data[18].shape[2] == 4096

            data_dict = {}
            data_dict['num_mult_choices'] = 5
            data_dict['answers'] = targets  # int32 torch tensor of correct answer ids
            data_dict['video_features'] = torch.cat([tvqa_data[18], tvqa_data[16]],
                                                    dim=2).unsqueeze(2).unsqueeze(2)  # (bsz, max_v_len, 1, 1, 6144)
            data_dict['video_lengths'] = tvqa_data[17].tolist()  # list, (bsz), number of frames per clip
            data_dict['candidates'] = tvqa_data[0]  # torch tensor (bsz, 5, max_q_len)
            data_dict['candidate_lengths'] = tvqa_data[1]  # list of lists, length of each question+answer

            outputs, predictions = rnn(data_dict)
            targets = data_dict['answers']
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = rnn.accuracy(predictions, targets)
            print('Train iter %d, loss %.3f, acc %.2f' % (iter, loss.data, acc.item()))
            train_loss.append(loss.item())
            train_acc.append(acc.item())
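# The `data_dict` construction above is repeated verbatim in several places in this
# file. A sketch of a helper that factors it out (hypothetical refactor; the index
# positions 0/1/16/17/18 follow the tuple layout that tvqa_dataset.preprocess_inputs
# is assumed to return, per the inline comments above).
def build_hme_data_dict(tvqa_data, targets):
    data_dict = {}
    data_dict['num_mult_choices'] = 5
    data_dict['answers'] = targets  # correct answer ids, (bsz,)
    # concat the 4096-d and 2048-d frame features -> (bsz, max_v_len, 1, 1, 6144)
    data_dict['video_features'] = torch.cat([tvqa_data[18], tvqa_data[16]],
                                            dim=2).unsqueeze(2).unsqueeze(2)
    data_dict['video_lengths'] = tvqa_data[17].tolist()  # frames per clip
    data_dict['candidates'] = tvqa_data[0]  # (bsz, 5, max_q_len)
    data_dict['candidate_lengths'] = tvqa_data[1]  # per-candidate lengths
    return data_dict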
def train(opt, dset, model, criterion, optimizer, epoch, previous_best_acc, scheduler):
    dset.set_mode("train")
    if opt.rubi:
        # The model and the question/subtitle-only RUBi-style model will be packed together; unpack here
        model, rubi_model = model
        rubi_model.train()
    model.train()
    train_loader = DataLoader(dset, batch_size=opt.bsz, shuffle=True, collate_fn=pad_collate)

    train_loss = []
    valid_acc_log = ["batch_idx\tacc"]
    train_corrects = []
    torch.set_grad_enabled(True)
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        # Process inputs
        if opt.lrtype == "cyclic":
            scheduler.batch_step()
        model_inputs, targets, _ = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l,
                                                     opt.max_vid_l, device=opt.device)

        # Model output
        if opt.dual_stream:
            outputs = model(*model_inputs)
            if opt.lanecheck:
                raise Exception("Lane checking is not implemented for dual stream")
        else:
            outputs = model(*model_inputs)
            if opt.rubi:
                rubi_outputs = rubi_model(*model_inputs)

        # Loss
        if not opt.rubi:
            if opt.lanecheck:
                loss = criterion(outputs[-1], targets)
            else:
                loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # This may be confusing: the naming scheme here conflicts with Remi's
            if opt.lanecheck:
                rubi_in = {'logits': None, 'logits_q': rubi_outputs[-1], 'logits_rubi': outputs[-1]}
            else:
                rubi_in = {'logits': None, 'logits_q': rubi_outputs, 'logits_rubi': outputs}
            rubi_targets = {'class_id': targets}
            losses = criterion(rubi_in, rubi_targets)
            loss, q_loss = losses['loss_mm_q'], losses['loss_q']  # loss is the fused RUBi loss
            optimizer.zero_grad()
            loss.backward()
            q_loss.backward()  # push the question-only loss backwards too
            optimizer.step()

        opt.plotter.text_plot(opt.jobname + " epoch", opt.jobname + " " + str(epoch))

        # measure accuracy and record loss
        train_loss.append(loss.item())
        if opt.lanecheck:
            pred_ids = outputs[-1].data.max(1)[1]
        else:
            pred_ids = outputs.data.max(1)[1]
        train_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        if batch_idx % opt.log_freq == 2:  # (opt.log_freq - 1)
            niter = epoch * len(train_loader) + batch_idx
            train_acc = sum(train_corrects) / float(len(train_corrects))
            train_loss = sum(train_loss) / float(len(train_loss))  # averaged over batches, not train_corrects

            # Plotting
            if not opt.testrun:
                opt.writer.add_scalar("Train/Acc", train_acc, niter)
                opt.writer.add_scalar("Train/Loss", train_loss, niter)
                opt.plotter.plot("accuracy", "train", opt.jobname, niter, train_acc)
                opt.plotter.plot("loss", "train", opt.jobname, niter, train_loss)

            # Validation
            if opt.dual_stream or opt.deep_cca:
                valid_acc, _ = validate(opt, dset, model, mode="valid")
            else:
                valid_acc, val_lanecheck_dict = validate_lanecheck(opt, dset, model, mode="valid")
            if not opt.testrun:
                # opt.writer.add_scalar("Valid/Loss", valid_loss, niter)
                opt.plotter.plot("accuracy", "val", opt.jobname, niter, valid_acc)
                # opt.plotter.plot("loss", "val", opt.jobname, niter, valid_loss)
            valid_log_str = "%02d\t%.4f" % (batch_idx, valid_acc)
            valid_acc_log.append(valid_log_str)

            # If this is the best run yet
            if valid_acc > previous_best_acc:
                previous_best_acc = valid_acc
                # Plot best accuracy so far in a text box
                if not opt.testrun:
                    opt.plotter.text_plot(opt.jobname + " val",
                                          opt.jobname + " val " + str(round(previous_best_acc, 4)))
                # Save the validation predictions and the state dictionary of the model
                # _, train_lanecheck_dict = validate_lanecheck(opt, dset, model, mode="train")
                if (not opt.dual_stream) and (not opt.deep_cca):
                    save_pickle(val_lanecheck_dict, opt.lanecheck_path + '_valid')
                torch.save(model.state_dict(), os.path.join(opt.results_dir, "best_valid.pth"))

            # reset to train
            torch.set_grad_enabled(True)
            model.train()
            dset.set_mode("train")
            train_corrects = []
            train_loss = []

        if opt.debug:
            break

    # additional log
    with open(os.path.join(opt.results_dir, "valid_acc.log"), "a") as f:
        f.write("\n".join(valid_acc_log) + "\n")

    return previous_best_acc
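# The RUBi branch above expects `criterion` to take dicts and return a dict of
# losses keyed 'loss_mm_q' and 'loss_q'. A minimal sketch of such a criterion,
# reconstructed from the keys used above and the RUBi formulation (Cadene et al.,
# 2019); the criterion actually used in this repo may differ, so treat this as an
# assumption about its internals, not the real implementation.
import torch.nn as nn

class RubiCriterionSketch(nn.Module):
    def __init__(self):
        super(RubiCriterionSketch, self).__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, net_out, targets):
        class_id = targets['class_id']
        return {
            'loss_mm_q': self.ce(net_out['logits_rubi'], class_id),  # loss on the fused logits
            'loss_q': self.ce(net_out['logits_q'], class_id),        # question-only branch loss
        }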
def validate_lanecheck(opt, dset, model, mode="valid"):
    dset.set_mode(mode)
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=False, collate_fn=pad_collate)

    lanecheck_dict = {}
    valid_corrects = []
    if opt.disable_streams is not None:
        for d_stream in opt.disable_streams:
            if d_stream in opt.input_streams:
                opt.input_streams.remove(d_stream)
    else:
        opt.disable_streams = []

    for batch_idx, batch in enumerate(valid_loader):
        # Accuracy
        model_inputs, targets, qids = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l,
                                                        opt.max_vid_l, device=opt.device)
        sub_out, vcpt_out, vid_out, reg_out, regtopk_out, outputs = model(*model_inputs)
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        # Append the ground truth, and then the predicted answer id, to each stream's output
        if 'sub' in opt.input_streams and not opt.dual_stream:
            sub_out = torch.cat((sub_out.cpu().squeeze(),
                                 targets.cpu().float().unsqueeze(1),
                                 pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if 'vcpt' in opt.input_streams and not opt.dual_stream:
            vcpt_out = torch.cat((vcpt_out.cpu().squeeze(),
                                  targets.cpu().float().unsqueeze(1),
                                  pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if 'imagenet' in opt.input_streams:
            vid_out = torch.cat((vid_out.cpu().squeeze(),
                                 targets.cpu().float().unsqueeze(1),
                                 pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if ('regional' in opt.input_streams) and opt.regional_topk == -1:
            reg_out = torch.cat((reg_out.cpu().squeeze(),
                                 targets.cpu().float().unsqueeze(1),
                                 pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if opt.regional_topk != -1:
            regtopk_out = torch.cat((regtopk_out.cpu().squeeze(),
                                     targets.cpu().float().unsqueeze(1),
                                     pred_ids.cpu().float().unsqueeze(1)), dim=1)

        # Add them to the lanecheck dictionary
        for id_idx in range(len(qids)):
            lanecheck_dict[qids[id_idx]] = {}
            if 'sub' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['sub_out'] = sub_out[id_idx]
            if 'vcpt' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vcpt_out'] = vcpt_out[id_idx]
            if 'imagenet' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vid_out'] = vid_out[id_idx]
            if ('regional' in opt.input_streams) and opt.regional_topk == -1:
                lanecheck_dict[qids[id_idx]]['reg_out'] = reg_out[id_idx]
            if opt.regional_topk != -1:
                lanecheck_dict[qids[id_idx]]['regtopk_out'] = regtopk_out[id_idx]

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    lanecheck_dict['acc'] = valid_acc
    return valid_acc, lanecheck_dict
def validate(opt, dset, model, mode="valid"):
    dset.set_mode(opt.dset)
    torch.set_grad_enabled(False)
    model.eval()
    valid_loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=True, collate_fn=pad_collate)

    if opt.bert is not None:
        bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
    else:
        # GloVe word-embedding lookup
        from utils import load_pickle
        idx2word = load_pickle(opt.idx2word_path)

    lanecheck_dict = {}
    valid_corrects = []
    if opt.disable_streams is not None:
        for d_stream in opt.disable_streams:
            if d_stream in opt.input_streams:
                opt.input_streams.remove(d_stream)
    else:
        opt.disable_streams = []

    for batch_idx, batch in enumerate(valid_loader):
        print(round(batch_idx / len(valid_loader) * 100, 2), "percent complete")
        model_inputs, targets, qids = preprocess_inputs(batch, opt.max_sub_l, opt.max_vcpt_l,
                                                        opt.max_vid_l, device=opt.device)
        if opt.lanecheck:
            sub_out, vcpt_out, vid_out, reg_out, regtopk_out, outputs = model(*model_inputs)
        pred_ids = outputs.data.max(1)[1]
        valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist()

        batch_q, _ = getattr(batch, 'q')
        batch_a0, _ = getattr(batch, 'a0')
        batch_a1, _ = getattr(batch, 'a1')
        batch_a2, _ = getattr(batch, 'a2')
        batch_a3, _ = getattr(batch, 'a3')
        batch_a4, _ = getattr(batch, 'a4')
        if 'sub' in opt.input_streams:
            batch_sub, _ = getattr(batch, 'sub')
        if 'vcpt' in opt.input_streams:
            batch_vcpt, _ = getattr(batch, 'vcpt')

        # Append the ground truth, and then the predicted answer id, to each stream's output
        if 'sub' in opt.input_streams:
            sub_out = torch.cat((sub_out.cpu().squeeze(),
                                 targets.cpu().float().unsqueeze(1),
                                 pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if 'vcpt' in opt.input_streams:
            vcpt_out = torch.cat((vcpt_out.cpu().squeeze(),
                                  targets.cpu().float().unsqueeze(1),
                                  pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if 'imagenet' in opt.input_streams:
            vid_out = torch.cat((vid_out.cpu().squeeze(),
                                 targets.cpu().float().unsqueeze(1),
                                 pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if ('regional' in opt.input_streams) and opt.regional_topk == -1:
            reg_out = torch.cat((reg_out.cpu().squeeze(),
                                 targets.cpu().float().unsqueeze(1),
                                 pred_ids.cpu().float().unsqueeze(1)), dim=1)
        if opt.regional_topk != -1:
            regtopk_out = torch.cat((regtopk_out.cpu().squeeze(),
                                     targets.cpu().float().unsqueeze(1),
                                     pred_ids.cpu().float().unsqueeze(1)), dim=1)

        # Add them to the lanecheck dictionary
        for id_idx in range(len(qids)):
            lanecheck_dict[qids[id_idx]] = {}
            if 'sub' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['sub_out'] = sub_out[id_idx]
            if 'vcpt' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vcpt_out'] = vcpt_out[id_idx]
            if 'imagenet' in opt.input_streams:
                lanecheck_dict[qids[id_idx]]['vid_out'] = vid_out[id_idx]
            if ('regional' in opt.input_streams) and opt.regional_topk == -1:
                lanecheck_dict[qids[id_idx]]['reg_out'] = reg_out[id_idx]
            if opt.regional_topk != -1:
                lanecheck_dict[qids[id_idx]]['regtopk_out'] = regtopk_out[id_idx]

            # Decode the question, answers, and streams: from BERT tokens ...
            if opt.bert is not None:
                lanecheck_dict[qids[id_idx]]['q'] = bert_tok.decode(batch_q[id_idx])
                lanecheck_dict[qids[id_idx]]['a0'] = bert_tok.decode(batch_a0[id_idx])
                lanecheck_dict[qids[id_idx]]['a1'] = bert_tok.decode(batch_a1[id_idx])
                lanecheck_dict[qids[id_idx]]['a2'] = bert_tok.decode(batch_a2[id_idx])
                lanecheck_dict[qids[id_idx]]['a3'] = bert_tok.decode(batch_a3[id_idx])
                lanecheck_dict[qids[id_idx]]['a4'] = bert_tok.decode(batch_a4[id_idx])
                if 'sub' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['sub'] = bert_tok.decode(batch_sub[id_idx])
                if 'vcpt' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['vcpt'] = bert_tok.decode(batch_vcpt[id_idx])
            # ... or from the GloVe vocabulary indices
            else:
                lanecheck_dict[qids[id_idx]]['q'] = [idx2word[int(word)] for word in batch_q[id_idx]]
                lanecheck_dict[qids[id_idx]]['a0'] = [idx2word[int(word)] for word in batch_a0[id_idx]]
                lanecheck_dict[qids[id_idx]]['a1'] = [idx2word[int(word)] for word in batch_a1[id_idx]]
                lanecheck_dict[qids[id_idx]]['a2'] = [idx2word[int(word)] for word in batch_a2[id_idx]]
                lanecheck_dict[qids[id_idx]]['a3'] = [idx2word[int(word)] for word in batch_a3[id_idx]]
                lanecheck_dict[qids[id_idx]]['a4'] = [idx2word[int(word)] for word in batch_a4[id_idx]]
                if 'sub' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['sub'] = [idx2word[int(word)] for word in batch_sub[id_idx]]
                if 'vcpt' in opt.input_streams:
                    lanecheck_dict[qids[id_idx]]['vcpt'] = [idx2word[int(word)] for word in batch_vcpt[id_idx]]

    valid_acc = sum(valid_corrects) / float(len(valid_corrects))
    lanecheck_dict['valid_acc'] = valid_acc
    print('valid acc', valid_acc)
    from utils import save_pickle
    save_pickle(lanecheck_dict, opt.lanecheck_path + '_' + opt.dset)
    return valid_acc
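# The lanecheck pickle written above maps each qid to per-stream score vectors with
# the ground truth and the predicted answer id appended as the last two entries. A
# sketch (hypothetical helper) of inspecting one entry offline; the path is whatever
# `opt.lanecheck_path + '_' + opt.dset` resolved to during the run.
def example_inspect_lanecheck(path, qid):
    from utils import load_pickle
    d = load_pickle(path)
    print("overall valid acc:", d['valid_acc'])
    entry = d[qid]
    for key in ('sub_out', 'vcpt_out', 'vid_out', 'reg_out', 'regtopk_out'):
        if key in entry:
            scores = entry[key]
            print(key, "answer scores:", scores[:-2], "target:", scores[-2], "pred:", scores[-1])
    print("question:", entry['q'])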
def main(args):
    torch.manual_seed(2667)

    ### add arguments ###
    args.vc_dir = os.path.expanduser("~/kable_management/data/tgif_qa/data/Vocabulary")
    args.df_dir = os.path.expanduser("~/kable_management/data/tgif_qa/data/dataset")

    # Dataset
    dataset = tvqa_dataset.TVQADataset(args, mode="train")

    args.max_sequence_length = 50  # 35
    args.model_name = args.task + args.feat_type
    args.memory_type = '_mrm2s'
    args.image_feature_net = 'concat'
    args.layer = 'fc'
    args.save_model_path = args.save_path + '%s_%s_%s%s' % (
        args.task, args.image_feature_net, args.layer, args.memory_type)

    # Create model directory
    if not os.path.exists(args.save_model_path):
        os.makedirs(args.save_model_path)

    #############################
    # get video feature dimension
    #############################
    video_feature_dimension = (480, 1, 1, 6144)
    feat_channel = video_feature_dimension[3]
    feat_dim = video_feature_dimension[2]
    text_embed_size = 300
    answer_vocab_size = None

    #############################
    # get word vector dimension
    #############################
    word_matrix = dataset.vocab_embedding
    voc_len = len(dataset.word2idx)
    assert text_embed_size == word_matrix.shape[1]

    #############################
    # Parameters
    #############################
    SEQUENCE_LENGTH = args.max_sequence_length
    VOCABULARY_SIZE = voc_len
    FEAT_DIM = (1, 1, 6144)

    # Create model directory
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    if args.task == 'Count':
        # add L2 loss
        criterion = nn.MSELoss(size_average=True).to(args.device)
    elif args.task in ['Action', 'Trans']:
        from embed_loss import MultipleChoiceLoss
        criterion = MultipleChoiceLoss(num_option=5, margin=1, size_average=True).to(args.device)
    elif args.task == 'FrameQA':
        # add classification loss; NOTE: assumes the dataset exposes ans2idx
        # (the original referenced an undefined `train_dataset` here)
        answer_vocab_size = len(dataset.ans2idx)
        print('Vocabulary size', answer_vocab_size, VOCABULARY_SIZE)
        criterion = nn.CrossEntropyLoss(size_average=True).to(args.device)

    if args.memory_type == '_mrm2s':
        rnn = TGIFAttentionTwoStream(args, args.task, feat_channel, feat_dim, text_embed_size,
                                     args.hidden_size, voc_len, args.num_layers, word_matrix,
                                     answer_vocab_size=answer_vocab_size,
                                     max_len=args.max_sequence_length)
    else:
        raise ValueError('Unsupported memory type: %s' % args.memory_type)
    rnn = rnn.to(args.device)

    # To test directly, load a pre-trained model; replace with your own checkpoint to test your model
    if args.test == 1:
        if args.task == 'Count':
            rnn.load_state_dict(torch.load(
                './saved_models/Count_concat_fc_mrm2s/rnn-1300-l3.257-a27.942.pkl'))
        elif args.task == 'Action':
            rnn.load_state_dict(torch.load(
                './saved_models/Action_concat_fc_mrm2s/rnn-0800-l0.137-a84.663.pkl'))
        elif args.task == 'Trans':
            rnn.load_state_dict(torch.load(
                './saved_models/Trans_concat_fc_mrm2s/rnn-1500-l0.246-a78.068.pkl'))
        elif args.task == 'FrameQA':
            rnn.load_state_dict(torch.load(
                './saved_models/FrameQA_concat_fc_mrm2s/rnn-4200-l1.233-a69.361.pkl'))
        else:
            raise ValueError('Invalid task')

    optimizer = torch.optim.Adam(rnn.parameters(), lr=args.learning_rate)
    iter = 0

    if args.task in ['Action', 'Trans']:
        best_val_acc = 0.0
        best_val_iter = 0.0
        best_test_acc = 0.0
        train_acc = []
        train_loss = []

        for epoch in range(0, args.num_epochs):
            dataset.set_mode("train")
            train_loader = DataLoader(dataset, batch_size=args.bsz, shuffle=True,
                                      collate_fn=tvqa_dataset.pad_collate,
                                      num_workers=args.num_workers)
            for batch_idx, batch in enumerate(train_loader):
                tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                    batch, args.max_sub_l, args.max_vcpt_l, args.max_vid_l, device=args.device)
                assert tvqa_data[18].shape[2] == 4096

                data_dict = {}
                data_dict['num_mult_choices'] = 5
                data_dict['answers'] = targets  # int32 torch tensor of correct answer ids
                data_dict['video_features'] = torch.cat(
                    [tvqa_data[18], tvqa_data[16]],
                    dim=2).unsqueeze(2).unsqueeze(2)  # (bsz, max_v_len, 1, 1, 6144)
                data_dict['video_lengths'] = tvqa_data[17].tolist()  # list, (bsz), frames per clip
                data_dict['candidates'] = tvqa_data[0]  # torch tensor (bsz, 5, max_q_len)
                data_dict['candidate_lengths'] = tvqa_data[1]  # lengths of each question+answer

                # if hard restrict, truncate everything to max_sequence_length
                if args.hard_restrict:
                    data_dict['video_features'] = data_dict['video_features'][:, :args.max_sequence_length]
                    data_dict['video_lengths'] = [
                        i if i <= args.max_sequence_length else args.max_sequence_length
                        for i in data_dict['video_lengths']]
                    data_dict['candidates'] = data_dict['candidates'][:, :, :args.max_sequence_length]
                    for xx in range(data_dict['candidate_lengths'].shape[0]):
                        if data_dict['video_lengths'][xx] == 0:
                            data_dict['video_lengths'][xx] = 1
                        for yy in range(data_dict['candidate_lengths'].shape[1]):
                            if data_dict['candidate_lengths'][xx][yy] > args.max_sequence_length:
                                data_dict['candidate_lengths'][xx][yy] = args.max_sequence_length

                outputs, targets, predictions = rnn(data_dict, args.task)
                loss = criterion(outputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                acc = rnn.accuracy(predictions, targets.long())
                print('Train %s iter %d, loss %.3f, acc %.2f' % (args.task, iter, loss.data, acc.item()))
                train_acc.append(acc.item())
                train_loss.append(loss.item())

                if (batch_idx % args.log_freq) == 0:  # (args.log_freq - 1)
                    dataset.set_mode("valid")
                    valid_loader = DataLoader(dataset, batch_size=args.test_bsz, shuffle=True,
                                              collate_fn=tvqa_dataset.pad_collate,
                                              num_workers=args.num_workers)
                    rnn.eval()

                    train_loss = sum(train_loss) / float(len(train_loss))
                    train_acc = sum(train_acc) / float(len(train_acc))
                    args.plotter.plot("accuracy", "train", args.jobname + " " + args.task, iter, train_acc)
                    args.plotter.plot("loss", "train", args.jobname + " " + args.task, iter, train_loss)
                    train_loss = []
                    train_acc = []

                    with torch.no_grad():
                        if args.test == 0:
                            n_iter = len(valid_loader)
                            losses = []
                            accuracy = []
                            # inner loop renamed so it does not shadow the training loop variables
                            for val_batch_idx, val_batch in enumerate(valid_loader):
                                tvqa_data, targets, _ = tvqa_dataset.preprocess_inputs(
                                    val_batch, args.max_sub_l, args.max_vcpt_l, args.max_vid_l,
                                    device=args.device)
                                assert tvqa_data[18].shape[2] == 4096

                                data_dict = {}
                                data_dict['num_mult_choices'] = 5
                                data_dict['answers'] = targets
                                data_dict['video_features'] = torch.cat(
                                    [tvqa_data[18], tvqa_data[16]],
                                    dim=2).unsqueeze(2).unsqueeze(2)
                                data_dict['video_lengths'] = tvqa_data[17].tolist()
                                data_dict['candidates'] = tvqa_data[0]
                                data_dict['candidate_lengths'] = tvqa_data[1]

                                # if hard restrict
                                if args.hard_restrict:
                                    data_dict['video_features'] = data_dict['video_features'][:, :args.max_sequence_length]
                                    data_dict['video_lengths'] = [
                                        i if i <= args.max_sequence_length else args.max_sequence_length
                                        for i in data_dict['video_lengths']]
                                    data_dict['candidates'] = data_dict['candidates'][:, :, :args.max_sequence_length]
                                    for xx in range(data_dict['candidate_lengths'].shape[0]):
                                        if data_dict['video_lengths'][xx] == 0:
                                            data_dict['video_lengths'][xx] = 1
                                        for yy in range(data_dict['candidate_lengths'].shape[1]):
                                            if data_dict['candidate_lengths'][xx][yy] > args.max_sequence_length:
                                                data_dict['candidate_lengths'][xx][yy] = args.max_sequence_length

                                outputs, targets, predictions = rnn(data_dict, args.task)
                                loss = criterion(outputs, targets)
                                acc = rnn.accuracy(predictions, targets.long())
                                losses.append(loss.item())
                                accuracy.append(acc.item())

                            losses = sum(losses) / n_iter
                            accuracy = sum(accuracy) / n_iter

                            if best_val_acc < accuracy:
                                best_val_iter = iter
                                best_val_acc = accuracy
                                args.plotter.text_plot(
                                    args.task + args.jobname + " val",
                                    args.task + args.jobname + " val " +
                                    str(round(best_val_acc, 4)) + " " + str(iter))
                            args.plotter.plot("accuracy", "val",
                                              args.jobname + " " + args.task, iter, accuracy)
                            # print('[Val] iter %d, loss %.3f, acc %.2f, best acc %.3f at iter %d' % (
                            #     iter, losses, accuracy, best_val_acc, best_val_iter))
                            # torch.save(rnn.state_dict(), os.path.join(args.save_model_path,
                            #     'rnn-%04d-l%.3f-a%.3f.pkl' % (iter, losses, accuracy)))

                    rnn.train()
                    dataset.set_mode("train")

                iter += 1
    elif args.task == 'Count':
        pass
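# The `hard_restrict` clamping above appears twice (training and validation loops).
# A sketch of a helper that applies the same truncation in one place (hypothetical
# refactor; it mutates `data_dict` exactly the way the inline code does, including
# the zero-length video fix).
def apply_hard_restrict(data_dict, max_len):
    data_dict['video_features'] = data_dict['video_features'][:, :max_len]
    data_dict['video_lengths'] = [min(max(i, 1), max_len) for i in data_dict['video_lengths']]
    data_dict['candidates'] = data_dict['candidates'][:, :, :max_len]
    lengths = data_dict['candidate_lengths']
    for xx in range(lengths.shape[0]):
        for yy in range(lengths.shape[1]):
            if lengths[xx][yy] > max_len:
                lengths[xx][yy] = max_len
    return data_dict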