class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train, bs=args.batch_size,
                                              shuffle=True, drop_last=True, folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid, bs=128,
                                                  shuffle=False, drop_last=False, folder=folder)
            else:
                self.valid_tuple = None

        # Model
        # self.model = VQAModel(self.train_tuple.dataset.num_answers)
        self.model = VQAModel(3129, fn_type=args.fn_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Load IndexList of Answer to Type Map (each answer index maps to a question type 0/1/2)
        self.indexlist = json.load(open("data/vqa/indexlist.json"))
        indextensor = torch.cuda.LongTensor(self.indexlist)
        self.mask0 = torch.eq(indextensor, 0).float()
        self.mask1 = torch.eq(indextensor, 1).float()
        self.mask2 = torch.eq(indextensor, 2).float()
        self.mask_cache = {}

        # Mask selecting the "yes" and "no" entries of the 3129-answer vocabulary
        self.yes_index = 425
        self.no_index = 1403
        self.mask_yes = torch.zeros(len(self.indexlist)).cuda()
        self.mask_yes[self.yes_index] = 1.0
        self.mask_yes[self.no_index] = 1.0

        # Loss and Optimizer
        self.logsoftmax = nn.LogSoftmax()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax()
        self.bceloss = nn.BCELoss()
        self.nllloss = nn.NLLLoss()
        self.mseloss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr, warmup=0.1, t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            epoch_loss = 0
            type_loss = 0
            yn_loss = 0
            const_loss = 0
            total_items = 0
            for i, (ques_id, feats, boxes, ques, op, q1, q2,
                    typetarget, q1typetarget, q2typetarget,
                    yesnotypetargets, q1yntypetargets, q2yntypetargets,
                    target, q1_target, q2_target) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target, yntypetarget, typetarget = \
                    feats.cuda(), boxes.cuda(), target.cuda(), yesnotypetargets.cuda(), typetarget.cuda()
                op, q1typetarget, q2typetarget, q1yntypetargets, q2yntypetargets, q1_target, q2_target = \
                    op.cuda(), q1typetarget.cuda(), q2typetarget.cuda(), \
                    q1yntypetargets.cuda(), q2yntypetargets.cuda(), q1_target.cuda(), q2_target.cuda()

                # The actual (composed) question
                logit, type_logit = self.model_forward(feats, boxes, ques, target, yntypetarget)
                # print(logit, target)
                loss = self.bceloss(logit, target)
                loss = loss * logit.size(1)
                loss_type = self.nllloss(type_logit, typetarget)
                loss_type = loss_type * type_logit.size(1)
                # loss_yn = self.bce_loss(yn_type_logit, yntypetarget)
                # loss_yn = loss_yn * yn_type_logit.size(1)

                # Q1 and Q2 predictions: no supervised loss for these,
                # only the consistency (constraint) loss below uses them.
                q1_logit, q1_type_logit = self.model_forward(feats, boxes, q1, q1_target, q1yntypetargets)
                q2_logit, q2_type_logit = self.model_forward(feats, boxes, q2, q2_target, q2yntypetargets)

                constraint_loss = self.constraintloss(logit, q1_logit, q2_logit, op)

                # Final loss
                # print(loss, loss_type, loss_yn, constraint_loss)
                epoch_loss += loss.item()
                type_loss += loss_type.item()
                # yn_loss += loss_yn.item()
                const_loss += constraint_loss.item()
                total_items += 1

                loss = 0.5 * loss + 0.25 * loss_type + 0.25 * constraint_loss
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) + \
                           "Epoch Losses: QA : %0.5f, YN : %0.5f, Type : %0.5f, Const : %0.5f\n" % (
                               epoch_loss / total_items, yn_loss / total_items,
                               type_loss / total_items, const_loss / total_items)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
        return best_valid

    def model_forward(self, feats, boxes, sent, target, yntypetarget):
        logit, type_logit = self.model(feats, boxes, sent)
        assert logit.dim() == target.dim() == 2
        logit = self.sigmoid(logit)
        type_logit_soft = self.softmax(type_logit)
        type_logit = self.logsoftmax(type_logit)
        # Gate the sigmoid answer scores by the predicted question-type distribution
        logit = self.calculatelogits(logit, type_logit_soft, "train")
        return logit, type_logit

    def get_masks(self, mode, batch):
        # key = mode + str(batch)
        # if key in self.mask_cache:
        #     return self.mask_cache[key]
        mask0 = self.mask0.repeat([batch, 1])
        mask1 = self.mask1.repeat([batch, 1])
        mask2 = self.mask2.repeat([batch, 1])
        return [mask0, mask1, mask2]
        # self.mask_cache[key] = [mask0, mask1, mask2]
        # return self.mask_cache[key]

    def calculatelogits(self, anspreds, typepreds, mode):
        # Zero out answers whose type does not match each type mask and weight the
        # remaining scores by the predicted probability of that type, then sum.
        batch = anspreds.size()[0]
        mask0, mask1, mask2 = self.get_masks(mode, batch)
        anspreds0 = anspreds * mask0 * typepreds.select(dim=1, index=0).reshape([batch, 1]).repeat([1, 3129])
        anspreds1 = anspreds * mask1 * typepreds.select(dim=1, index=1).reshape([batch, 1]).repeat([1, 3129])
        anspreds2 = anspreds * mask2 * typepreds.select(dim=1, index=2).reshape([batch, 1]).repeat([1, 3129])
        nanspreds = anspreds0 + anspreds1 + anspreds2
        return nanspreds

    def rangeloss(self, x, lower, upper, lamb=4):
        # Penalty that is zero at the centre of [lower, upper] and approaches 1 as x moves away from it.
        mean = (lower + upper) / 2
        sigma = (upper - lower + 0.00001) / lamb
        loss = 1 - torch.exp(-0.5 * torch.pow(torch.div(x - mean, sigma), 2))
        return loss.sum()

    def select_yesnoprobs(self, logit, x, op):
        # For the batch items whose operator id equals x, renormalise the scores of the
        # "yes" and "no" answers and return the probability of "yes".
        op_mask = torch.eq(op, x)
        logit = logit[op_mask].view(-1, 3129)
        logit_m = logit * self.mask_yes
        m = logit_m == 0
        logit_m = logit_m[~m].view(-1, 2)
        logit_m = torch.softmax(logit_m, 1)
        return logit_m.select(dim=1, index=0).view(-1, 1)

    def constraintloss(self, logit, q1_logit, q2_logit, op):
        # For each operator id, push the composed question's P(yes) toward the value
        # implied by op_map applied to the sub-questions' P(yes).
        total_loss = torch.zeros([1]).cuda()
        for x in range(1, 11):
            logit_m = self.select_yesnoprobs(logit, x, op)
            q1_logit_m = self.select_yesnoprobs(q1_logit, x, op)
            q2_logit_m = self.select_yesnoprobs(q2_logit, x, op)
            if logit_m.nelement() == 0:
                continue
            ideal_logit_m = op_map[x](q1_logit_m, q2_logit_m)
            rangeloss = self.mseloss(logit_m, ideal_logit_m)
            total_loss += rangeloss
        return total_loss

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        type_accuracy = 0.0
        yn_type_accuracy = 0.0
        num_batches = 0
        for i, datum_tuple in tqdm(enumerate(loader), ascii=True, desc="Evaluating"):
            # ques_id, feats, boxes, sent, typed_target, yesnotypetargets, target = datum_tuple  # Avoid seeing ground truth
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, \
                yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit, typelogit = self.model(feats, boxes, ques)
                logit = self.sigmoid(logit)
                type_logit_soft = self.softmax(typelogit)
                logit = self.calculatelogits(logit, type_logit_soft, "predict")
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
                type_accuracy += torch.mean(
                    torch.eq(typelogit.argmax(dim=1).cuda(), typetarget.cuda()).float().cuda()).cpu().item()
                # yn_type_accuracy += torch.mean(torch.eq((yn_type_logit > 0.5).float().cuda(),
                #                                         yesnotypetargets.cuda().float()).float().cuda()).cpu().item()
                num_batches += 1
        print("Type Accuracy:", type_accuracy / num_batches)
        # The yes/no accuracy update above is disabled, so this currently prints 0.0
        print("YN Accuracy:", yn_type_accuracy / num_batches)
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, \
                yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train, bs=args.batch_size,
                                              shuffle=True, drop_last=True, folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid, bs=128,
                                                  shuffle=False, drop_last=False, folder=folder)
            else:
                self.valid_tuple = None

        # Model
        # self.model = VQAModel(self.train_tuple.dataset.num_answers)
        self.model = VQAModel(3129, fn_type=args.fn_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Load IndexList of Answer to Type Map
        self.indexlist = json.load(open("data/vqa/indexlist.json"))
        indextensor = torch.cuda.LongTensor(self.indexlist)
        self.mask0 = torch.eq(indextensor, 0).float()
        self.mask1 = torch.eq(indextensor, 1).float()
        self.mask2 = torch.eq(indextensor, 2).float()
        self.mask_cache = {}

        # Loss and Optimizer
        self.logsoftmax = nn.LogSoftmax()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax()
        self.bceloss = nn.BCELoss()
        self.nllloss = nn.NLLLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr, warmup=0.1, t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, typetarget, target) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target, typetarget = feats.cuda(), boxes.cuda(), target.cuda(), typetarget.cuda()
                logit, type_logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                logit = self.sigmoid(logit)
                type_logit_soft = self.softmax(type_logit)
                type_logit = self.logsoftmax(type_logit)
                logit = self.calculatelogits(logit, type_logit_soft, "train")

                loss = self.bceloss(logit, target)
                loss = loss * logit.size(1)
                loss_type = self.nllloss(type_logit, typetarget)
                loss_type = loss_type * type_logit.size(1)
                loss = 0.9 * loss + 0.1 * loss_type

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
        return best_valid

    def get_masks(self, mode, batch):
        key = mode + str(batch)
        if key in self.mask_cache:
            return self.mask_cache[key]
        mask0 = self.mask0.repeat([batch, 1])
        mask1 = self.mask1.repeat([batch, 1])
        mask2 = self.mask2.repeat([batch, 1])
        self.mask_cache[key] = [mask0, mask1, mask2]
        return self.mask_cache[key]

    def calculatelogits(self, anspreds, typepreds, mode):
        batch = anspreds.size()[0]
        mask0, mask1, mask2 = self.get_masks(mode, batch)
        anspreds0 = anspreds * mask0 * typepreds.select(dim=1, index=0).reshape([batch, 1]).repeat([1, 3129])
        anspreds1 = anspreds * mask1 * typepreds.select(dim=1, index=1).reshape([batch, 1]).repeat([1, 3129])
        anspreds2 = anspreds * mask2 * typepreds.select(dim=1, index=2).reshape([batch, 1]).repeat([1, 3129])
        nanspreds = anspreds0 + anspreds1 + anspreds2
        return nanspreds

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        type_accuracy = 0.0
        num_batches = 0
        for i, datum_tuple in tqdm(enumerate(loader), ascii=True, desc="Evaluating"):
            ques_id, feats, boxes, sent, typed_target, target = datum_tuple  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit, typelogit = self.model(feats, boxes, sent)
                logit = self.sigmoid(logit)
                type_logit_soft = self.softmax(typelogit)
                logit = self.calculatelogits(logit, type_logit_soft, "predict")
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
                # type_accuracy += torch.mean(torch.eq((typelogit > 0.5).int().cuda(),
                #                                      typed_target.cuda()).float().cuda()).cpu().item()
                type_accuracy += torch.mean(
                    torch.eq(typelogit.argmax(dim=1).cuda(), typed_target.cuda()).float().cuda()).cpu().item()
                num_batches += 1
        print("Type Accuracy:", type_accuracy / num_batches)
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, typetarget, target = datum_tuple
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
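# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the trainers above). Both are standalone,
# CPU-only toys added for exposition; the 6-answer vocabulary, the tensor
# values, and the AND-style composition rule below are assumptions for
# illustration only and do not come from this repository.
# ---------------------------------------------------------------------------
import torch

# 1) How `calculatelogits` gates answer scores by question type: each answer's
#    sigmoid score is kept only under its own type's mask and rescaled by the
#    predicted probability of that type. The real model uses 3129 answers and
#    the type map loaded from data/vqa/indexlist.json.
indexlist = [0, 0, 1, 1, 2, 2]                     # answer index -> question type
indextensor = torch.tensor(indexlist)
masks = [torch.eq(indextensor, t).float() for t in range(3)]

anspreds = torch.tensor([[0.9, 0.1, 0.7, 0.2, 0.3, 0.4]])   # sigmoid answer scores, [batch, 6]
typepreds = torch.tensor([[0.8, 0.1, 0.1]])                 # softmax type probabilities, [batch, 3]

nanspreds = sum(anspreds * masks[t] * typepreds[:, t:t + 1] for t in range(3))
print(nanspreds)   # answers of the dominant type keep most of their score

# 2) The idea behind `constraintloss`: the composed question's P(yes) is pulled
#    toward a function (op_map[x]) of the sub-questions' P(yes) via an MSE
#    penalty. op_map is defined elsewhere in the repository; the product used
#    here for an AND-style composition is only an assumed stand-in.
p_q1 = torch.tensor([[0.9]])         # P(yes) for sub-question 1
p_q2 = torch.tensor([[0.6]])         # P(yes) for sub-question 2
p_composed = torch.tensor([[0.4]])   # model's P(yes) for the composed question
ideal = p_q1 * p_q2                  # assumed composition rule, not the repo's op_map
print(torch.nn.functional.mse_loss(p_composed, ideal))   # penalty pulling 0.4 toward 0.54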