def __init__(self):
    """Build the loaders, the online and siamese GQA models, the losses and the optimizer.

    ``self.siam_model`` is a deep copy of the freshly constructed model. It
    receives the same pretrained checkpoints and device placement as
    ``self.model``, but only ``self.model``'s parameters are given to the
    optimizer. ``self.momentum`` is only stored here; presumably it is the
    moving-average coefficient for the siamese copy (confirm in the training
    loop).
    """
    # Data: the training loader is always built; validation only when requested.
    self.train_tuple = get_tuple(args.train, bs=args.batch_size,
                                 shuffle=True, drop_last=True)
    if args.valid == "":
        self.valid_tuple = None
    else:
        eval_bs = 2048 if args.multiGPU else 512
        self.valid_tuple = get_tuple(args.valid, bs=eval_bs,
                                     shuffle=False, drop_last=False)

    # The siamese copy is cloned before any weights are loaded, so every
    # checkpoint below is loaded into both networks explicitly.
    self.model = GQAModel(self.train_tuple.dataset.num_answers)
    self.momentum = 0.99997
    self.siam_model = copy.deepcopy(self.model)
    twins = (self.model, self.siam_model)

    # Load pre-trained weights
    if args.load_lxmert is not None:
        for net in twins:
            net.lxrt_encoder.load(args.load_lxmert)
    if args.load_lxmert_qa is not None:
        answer_vocab = self.train_tuple.dataset.label2ans
        for net in twins:
            load_lxmert_qa(args.load_lxmert_qa, net, label2ans=answer_vocab)

    # GPU options
    self.model = self.model.cuda()
    self.siam_model = self.siam_model.cuda()
    if args.multiGPU:
        for net in (self.model, self.siam_model):
            net.lxrt_encoder.multi_gpu()

    # Losses and optimizer (only the online model is optimized)
    self.bce_loss = nn.BCEWithLogitsLoss()
    self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
    trainable = list(self.model.parameters())
    if 'bert' not in args.optim:
        self.optim = args.optimizer(trainable, args.lr)
    else:
        n_steps = int(len(self.train_tuple.loader) * args.epochs)
        print("Total Iters: %d" % n_steps)
        from lxrt.optimization import BertAdam
        self.optim = BertAdam(trainable, lr=args.lr, warmup=0.1, t_total=n_steps)

    self.output = args.output
    os.makedirs(self.output, exist_ok=True)
class GQA:
    """GQA trainer with optional QFPM auxiliary pretraining (NSP and/or MLM heads).

    All configuration is read from the module-level ``args`` object.
    """

    def __init__(self):
        self.train_tuple = get_tuple(args.train, bs=args.batch_size, shuffle=True, drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            # The validation loader is built with skip_semantics=True, presumably so the
            # dataset skips the QFPM semantic annotations (confirm in get_tuple).
            self.valid_tuple = get_tuple(args.valid, bs=valid_bsize, shuffle=False,
                                         drop_last=False, skip_semantics=True)
        else:
            self.valid_tuple = None
        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        # train() steps a separate optimizer during QFPM pretraining.
        needs_task_optim = args.task_nsp_qfpm or args.task_mlm_qfpm
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()), lr=args.lr,
                                  warmup=0.1, t_total=t_total)
            if needs_task_optim:
                self.task_optim = BertAdam(list(self.model.parameters()), lr=args.lr,
                                           warmup=0.1, t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)
            if needs_task_optim:
                # Previously missing: train() would fail with AttributeError on
                # self.task_optim whenever a non-BERT optimizer was selected.
                self.task_optim = args.optimizer(list(self.model.parameters()), args.lr)
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Optionally pretrain the QFPM heads, then fine-tune on GQA answers.

        Each phase runs for ``args.epochs`` epochs. After pretraining, the QFPM
        flags on the global ``args`` are switched off, so later model calls
        return only the answer logits. The best (by validation score) and last
        weights are saved to ``self.output``.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        # Pre-train the NSP / MLM heads
        if args.task_nsp_qfpm or args.task_mlm_qfpm:
            print(f"********* Pretraining for {args.epochs} epochs *********")
            for epoch in range(args.epochs):
                epoch_pretrain_loss = 0
                epoch_nsp_loss = 0
                epoch_mlm_loss = 0
                epoch_nsp_avg = []
                for i, (ques_id, feats, boxes, sent, sem_query, sem_matched, target) in iter_wrapper(enumerate(loader)):
                    self.model.train()
                    self.task_optim.zero_grad()
                    feats, boxes, sem_matched, target = feats.cuda(), boxes.cuda(), sem_matched.cuda(), target.cuda()
                    _, logit_nsp_qfpm, logit_mlm_qfpm, masked_lang_labels = self.model(
                        feats, boxes, sent, sem_query)
                    loss = 0
                    if args.task_nsp_qfpm:
                        nsp_qfpm_loss = self.mce_loss(logit_nsp_qfpm, sem_matched)
                        loss += 100 * nsp_qfpm_loss  # multiply to equally weight the MLM with NSP loss
                        epoch_nsp_loss += nsp_qfpm_loss.detach()
                        _, idx = torch.max(logit_nsp_qfpm, 1)
                        # Fraction of mismatched predictions in the batch; an error rate
                        # if sem_matched holds 0/1 labels (assumption, confirm).
                        diff = torch.abs(sem_matched - idx)
                        epoch_nsp_avg.append(torch.sum(diff).item() / sem_matched.shape[0])
                    if args.task_mlm_qfpm:
                        # masked_lang_labels: [batch, max_length], logit_mlm_qfpm: [batch, max_length, vocab_size]
                        vocab_size = logit_mlm_qfpm.shape[2]
                        masked_lm_loss = self.mce_loss(
                            logit_mlm_qfpm.view(-1, vocab_size), masked_lang_labels.view(-1))
                        loss += masked_lm_loss
                        epoch_mlm_loss += masked_lm_loss.detach()
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                    self.task_optim.step()
                    epoch_pretrain_loss += loss.detach()
                print(
                    f"Total loss for epoch = {epoch_pretrain_loss} ... NSP = {epoch_nsp_loss} ... MLM = {epoch_mlm_loss}"
                )
                # Only NSP fills epoch_nsp_avg; with MLM alone the list is empty
                # (previously a ZeroDivisionError).
                if epoch_nsp_avg:
                    print(
                        f"NSP: average of average ....... {sum(epoch_nsp_avg)/len(epoch_nsp_avg)}"
                    )

        best_valid = 0.
        # From here on the model must return only the answer logits.
        args.task_nsp_qfpm = False
        args.task_mlm_qfpm = False
        if args.no_fp_train is True:
            loader.dataset.skip_semantics = True
        print(f"********* Finetuning for {args.epochs} epochs *********")
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, sem_query, _, target) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()
                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit_qa = self.model(feats, boxes, sent, sem_query)
                assert logit_qa.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit_qa, target) * logit_qa.size(1)
                else:
                    loss = self.bce_loss(logit_qa, target)
                    loss = loss * logit_qa.size(1)
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit_qa.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")
                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
            print(log_str, end='')
            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()
        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """Return {question_id: predicted answer}; optionally dump the results via the evaluator."""
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, sem_query = datum_tuple[:5]  # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit_qa = self.model(feats, boxes, sent, sem_query)
                score, label = logit_qa.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Predict on eval_tuple and return the evaluator's score."""
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Score the argmax of the ground-truth targets (an upper bound for the answer vocabulary)."""
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, sem_query, sem_matched, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model state_dict to <output>/<name>.pth."""
        torch.save(self.model.state_dict(), os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load <path>.pth non-strictly, removing '.module' (DataParallel) segments from keys."""
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
def __init__(self, args, train_loader=None, val_loader=None, logger=None, num_answers=0, train=True):
    """Set up the trainer: model, optional checkpoint, GPU placement, optimizer and output dir.

    Args:
        args: run configuration (gpu, distributed, multiGPU, load, load_lxmert_qa,
            max_text_length, output, ...).
        train_loader: training loader. Its answer vocabulary is read when
            ``args.load_lxmert_qa`` is set and ``args.load`` is not.
        val_loader: validation loader, stored as-is.
        logger: stored as-is.
        num_answers: size of the answer head passed to the model.
        train: when true, the optimizer and LR scheduler are created.
    """
    self.args = args
    self.max_text_length = args.max_text_length
    self.train_loader, self.val_loader = train_loader, val_loader
    self.num_answers = num_answers
    self.logger = logger

    # Model
    self.model = GQAModel.from_pretrained("bert-base-uncased", args=args, num_answers=self.num_answers)

    # In a distributed run only GPU 0 is verbose.
    self.verbose = not (self.args.distributed and self.args.gpu != 0)

    # Load Checkpoint (a full checkpoint takes precedence over an LXMERT QA one)
    self.start_epoch = None
    if args.load is not None:
        self.load(args.load + '.pth', verbose=self.verbose)
    elif args.load_lxmert_qa is not None:
        qa_ckpt = args.load_lxmert_qa + '_LXRT.pth'
        load_lxmert_qa(args, qa_ckpt, self.model,
                       label2ans=self.train_loader.dataset.raw_dataset.label2ans,
                       verbose=self.verbose)

    # GPU Options
    print(f'Model Launching at GPU {self.args.gpu}')
    from time import time
    launch_t0 = time()
    self.model.cuda(args.gpu)

    # Optimizer
    if train:
        self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler()
    self.bce_loss = nn.BCEWithLogitsLoss()

    # Multi-GPU is only supported through DistributedDataParallel.
    if args.multiGPU:
        assert args.distributed
        self.model = DDP(self.model, device_ids=[args.gpu], find_unused_parameters=True)
    if args.gpu == 0:
        print(f'It took {time() - launch_t0:.1f}s')

    # Output Directory
    self.output = args.output
    os.makedirs(self.output, exist_ok=True)
def __init__(self):
    """Build the loaders, model, losses, optimizer, TensorBoard writer and GloVe answer table.

    When ``args.answer_loss == 'glove'``, each answer string is embedded as the
    mean of its words' GloVe vectors. Out-of-vocabulary words contribute a zero
    vector and mark the answer invalid in ``self.valid_ans_embed``.
    """
    self.train_tuple = get_tuple(args.train, bs=args.batch_size, shuffle=True, drop_last=True)
    if args.valid != "":
        valid_bsize = 2048 if args.multiGPU else 512
        self.valid_tuple = get_tuple(args.valid, bs=valid_bsize, shuffle=False, drop_last=False)
    else:
        self.valid_tuple = None
    self.model = GQAModel(self.train_tuple.dataset.num_answers)

    # Load pre-trained weights
    if args.load_lxmert is not None:
        self.model.lxrt_encoder.load(args.load_lxmert)
    if args.load_lxmert_qa is not None:
        self.new_ans_label = load_lxmert_qa(
            args.load_lxmert_qa, self.model,
            label2ans=self.train_tuple.dataset.label2ans)

    # GPU options
    self.model = self.model.cuda()
    if args.multiGPU:
        self.model.lxrt_encoder.multi_gpu()

    # Losses and optimizer
    self.bce_loss = nn.BCEWithLogitsLoss()
    self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
    if 'bert' in args.optim:
        batch_per_epoch = len(self.train_tuple.loader)
        t_total = int(batch_per_epoch * args.epochs)
        print("Total Iters: %d" % t_total)
        from lxrt.optimization import BertAdam
        self.optim = BertAdam(list(self.model.parameters()), lr=args.lr, warmup=0.1, t_total=t_total)
    else:
        self.optim = args.optimizer(list(self.model.parameters()), args.lr)
    self.output = args.output
    os.makedirs(self.output, exist_ok=True)

    # Tensorboard (exist_ok avoids the exists()/makedirs() race)
    self.boards_dir = os.path.join('boards', self.output)
    os.makedirs(self.boards_dir, exist_ok=True)
    self.writerTbrd = SummaryWriter(self.boards_dir)

    # GloVe projection for all answers
    if args.answer_loss == 'glove':
        path_glove = './data/GloVe/GloVeDict.pkl'
        # NOTE: pickle.load can execute arbitrary code; only point this at a trusted local file.
        with open(path_glove, 'rb') as f:
            glove_dic = pickle.load(f)
        glove_dim = glove_dic['the'].shape[-1]
        print("Loading Glove%d answer's vector" % glove_dim)
        label2ans = self.train_tuple.dataset.label2ans
        self.labelans2glove = []
        # 1 if every word of the answer has a GloVe vector, else 0.
        self.valid_ans_embed = [1] * len(label2ans)
        for label, ans in enumerate(label2ans):
            glove_ans = []
            for w in ans.split(' '):
                try:
                    glove_ans.append(glove_dic[w])
                except KeyError:
                    # Out-of-vocabulary word: flag the answer and use a zero vector.
                    self.valid_ans_embed[label] = 0
                    glove_ans.append(np.zeros(glove_dim))
            # Mean of the word vectors. Stacking into one ndarray avoids torch.tensor()
            # on a list of arrays, and re-wrapping a tensor in torch.tensor()
            # (PyTorch's copy-construct warning).
            mean_vec = np.stack(glove_ans).mean(axis=-2)
            self.labelans2glove.append(torch.as_tensor(mean_vec, dtype=torch.float32))
        print('Ratio of valid ans embedding: %f' %
              (float(sum(self.valid_ans_embed)) / len(self.valid_ans_embed)))
        # Presumably [num_answers, glove_dim] if GloVe entries are 1-D vectors.
        self.labelans2glove = torch.stack(self.labelans2glove).float().cuda()
    # Used by the GloVe answer loss; cheap, so always defined.
    self.cosineSim = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
class GQA:
    """GQA trainer with word-object "pointer" supervision, plus debugging and analysis tools.

    All configuration is read from the module-level ``args`` object.
    """

    def __init__(self):
        """Build the loaders, model, losses, optimizer, TensorBoard writer and GloVe answer table."""
        self.train_tuple = get_tuple(args.train, bs=args.batch_size, shuffle=True, drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid, bs=valid_bsize, shuffle=False, drop_last=False)
        else:
            self.valid_tuple = None
        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            self.new_ans_label = load_lxmert_qa(
                args.load_lxmert_qa, self.model,
                label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()), lr=args.lr, warmup=0.1, t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # Tensorboard
        self.boards_dir = os.path.join('boards', self.output)
        os.makedirs(self.boards_dir, exist_ok=True)
        self.writerTbrd = SummaryWriter(self.boards_dir)

        # GloVe projection for all answers
        if args.answer_loss == 'glove':
            path_glove = './data/GloVe/GloVeDict.pkl'
            # NOTE: pickle.load can execute arbitrary code; only point this at a trusted local file.
            with open(path_glove, 'rb') as f:
                glove_dic = pickle.load(f)
            glove_dim = glove_dic['the'].shape[-1]
            print("Loading Glove%d answer's vector" % glove_dim)
            label2ans = self.train_tuple.dataset.label2ans
            self.labelans2glove = []
            # 1 if every word of the answer has a GloVe vector, else 0.
            self.valid_ans_embed = [1] * len(label2ans)
            for label, ans in enumerate(label2ans):
                glove_ans = []
                for w in ans.split(' '):
                    try:
                        glove_ans.append(glove_dic[w])
                    except KeyError:
                        # Out-of-vocabulary word: flag the answer and use a zero vector.
                        self.valid_ans_embed[label] = 0
                        glove_ans.append(np.zeros(glove_dim))
                # Mean of the word vectors (stack once; no tensor re-wrapping).
                mean_vec = np.stack(glove_ans).mean(axis=-2)
                self.labelans2glove.append(torch.as_tensor(mean_vec, dtype=torch.float32))
            print('Ratio of valid ans embedding: %f' %
                  (float(sum(self.valid_ans_embed)) / len(self.valid_ans_embed)))
            self.labelans2glove = torch.stack(self.labelans2glove).float().cuda()
        self.cosineSim = torch.nn.CosineSimilarity(dim=1, eps=1e-08)

    # ---- Debug helpers (Corentin) ------------------------------------------------
    def check_pointer_manually(self, train_tuple):
        """Save one figure per sample showing detected boxes weighted by pointer IoU, plus pointer boxes."""
        IMAGE_PATH = 'data/gqa/images'
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        def iou_preprocess(iou):
            """Softmax over the top-5 IoUs; all zeros if their positive mass is below the threshold."""
            TRESHOLD = 0.1
            TOPK = 5
            sorted_idx = np.argsort(iou)[::-1]
            # Work on a copy: the previous version aliased and overwrote the caller's array.
            iou_topk = iou.copy()
            iou_topk[sorted_idx[TOPK:]] = -1e9
            f_iou = np.exp(iou_topk) / np.sum(np.exp(iou_topk), axis=0)
            f_iou = f_iou * (iou_topk.clip(min=0).sum() >= TRESHOLD)
            return f_iou

        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer) in iter_wrapper(enumerate(loader)):
            for batch_index in range(len(ques_id)):
                datum = dset.id2datum[ques_id[batch_index]]
                # Load image
                im_id = datum['image_id']
                im_path = os.path.join(IMAGE_PATH, '%s.jpg' % im_id)
                image_pil = Image.open(im_path)
                im = np.array(image_pil, dtype=np.uint8)
                height = image_pil.height
                width = image_pil.width
                # Load annotations
                question = datum['sent']
                question_pointer = datum['pointer']['question']
                answer = list(datum['label'])[0]
                answer_pointer = datum['pointer']['answer']

                # Create plot
                fig = plt.figure()
                plt.suptitle(question)
                plt.title(answer)
                ax = fig.add_subplot(1, 1, 1)
                ax.imshow(im)

                # Detected boxes; opacity = accumulated processed IoU over all pointers.
                detected_bboxe = boxes[batch_index] * torch.tensor([width, height, width, height]).float()
                total_iou_per_object = np.zeros((boxes.size(1)))
                for pointer in question_pointer.values():
                    total_iou_per_object += iou_preprocess(np.array(pointer['iou']))
                for pointer in answer_pointer.values():
                    total_iou_per_object += iou_preprocess(np.array(pointer['iou']))
                intensity = total_iou_per_object.clip(min=0, max=1)
                c = [np.array([0, 0, 1]) for j in range(boxes.size(1))]
                draw_bboxes(detected_bboxe.numpy(), ax, color=c, alpha=intensity)

                # Pointer boxes (normalized coordinates scaled to pixels)
                for word_id, pointer in question_pointer.items():
                    bboxe = pointer['boxe']
                    bboxe = [bboxe[0] * width, bboxe[1] * height, bboxe[2] * width, bboxe[3] * height]
                    c = [np.array([1, 0, 0])]
                    draw_bboxes([bboxe], ax, color=c, label=['q_%s' % word_id])
                for word_id, pointer in answer_pointer.items():
                    bboxe = pointer['boxe']
                    bboxe = [bboxe[0] * width, bboxe[1] * height, bboxe[2] * width, bboxe[3] * height]
                    c = [np.array([0, 1, 0])]
                    draw_bboxes([bboxe], ax, color=c, label=['a_%s' % word_id])
                plt.savefig('check_pointer_%s_sftmx.jpg' % im_id)
                plt.close()

    def pointer_stats(self, train_tuple):
        """Collect pointer/word/IoU statistics over the loader and pickle them to stats_pointer.pickle."""
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
        # Stat watchers
        pointer_per_question = []
        nb_words_per_question = []
        pointed_words = {}
        max_iou = []
        top5_cumiou = []
        top10_cumiou = []
        total_cumiou = []
        strip_punct = str.maketrans('', '', string.punctuation)
        start = time.time()

        def add2dic(pointer, parsed_sent):
            """Count each pointed word; keys are 'k' or a 'start:end' word range (end exclusive)."""
            for w_idx in pointer:
                if ':' in w_idx:
                    indexes = w_idx.split(':')
                    positions = range(int(indexes[0]), int(indexes[1]))
                else:
                    positions = [int(w_idx)]
                for j in positions:
                    word = parsed_sent[j]
                    pointed_words[word] = pointed_words.get(word, 0) + 1

        def positive(values):
            """Flatten a tensor and keep only its strictly positive entries as a list."""
            return values[values > 0].flatten().tolist()

        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer) in iter_wrapper(enumerate(loader)):
            for batch_index in range(len(ques_id)):
                datum = dset.id2datum[ques_id[batch_index]]
                # Load annotations
                question = datum['sent']
                question_pointer = datum['pointer']['question']
                answer = list(datum['label'])[0]
                answer_pointer = datum['pointer']['answer']
                # parse question and answer
                parsed_question = question.translate(strip_punct).split(' ')
                parsed_answer = answer.translate(strip_punct).split(' ')
                # Stats words
                pointer_per_question.append(len(question_pointer) + len(answer_pointer))
                nb_words_per_question.append(len(parsed_question) + len(parsed_answer))
                add2dic(question_pointer, parsed_question)
                add2dic(answer_pointer, parsed_answer)

            # Stats IoU: these tensors cover the whole batch, so accumulate once per
            # batch (not once per sample, which would repeat each batch len(ques_id) times).
            max_iou += positive(iou_question.max(-1)[0])
            max_iou += positive(iou_answer.max(-1)[0])
            top5_cumiou += positive(iou_question.topk(5)[0].sum(-1))
            top5_cumiou += positive(iou_answer.topk(5)[0].sum(-1))
            top10_cumiou += positive(iou_question.topk(10)[0].sum(-1))
            top10_cumiou += positive(iou_answer.topk(10)[0].sum(-1))
            total_cumiou += positive(iou_question.sum(-1))
            total_cumiou += positive(iou_answer.sum(-1))

        elapsed_time = time.time() - start
        print("Time: %.1fmin" % (elapsed_time / 60))
        # Save into pickle dic
        dic = {
            'pointer_per_question': pointer_per_question,
            'nb_words_per_question': nb_words_per_question,
            'pointed_words': pointed_words,
            'max_iou': max_iou,
            'top5_cumiou': top5_cumiou,
            'top10_cumiou': top10_cumiou,
            'total_cumiou': total_cumiou,
        }
        with open('stats_pointer.pickle', 'wb') as handle:
            pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # ---- Manual evaluation (for matching) -----------------------------------------
    def eval_matching_manual(self, eval_tuple):
        """Save one figure per (sample, predicted IoU row) and wait for ENTER after each sample.

        NOTE(review): this calls the model with 5 inputs while train()/predict()
        pass 8; it may be stale with respect to the current model signature.
        """
        IMAGE_PATH = 'data/gqa/images'
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, _, iou_question, iou_answer = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                logit, iou_target, iou_pred = self.model(feats, boxes, sent, iou_question, iou_answer)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
            for batch_index in range(len(ques_id)):
                # Retrieve info + prediction
                qid = ques_id[batch_index]
                datum = dset.id2datum[qid]
                question = datum['sent']
                answer_pred = quesid2ans[qid]
                # Load image
                im_id = datum['image_id']
                im_path = os.path.join(IMAGE_PATH, '%s.jpg' % im_id)
                image_pil = Image.open(im_path)
                im = np.array(image_pil, dtype=np.uint8)
                height = image_pil.height
                width = image_pil.width
                detected_bboxe = boxes[batch_index].cpu() * torch.tensor([width, height, width, height]).float()
                # Display iou prediction
                for w in range(iou_pred[batch_index].size(0)):
                    fig = plt.figure()
                    plt.suptitle('Q:%s __ Idx:%d' % (question, w))
                    plt.title(answer_pred)
                    # all predicted bboxes
                    ax = fig.add_subplot(2, 1, 1)
                    ax.imshow(im)
                    c = [np.array([0, 0, 1]) for j in range(boxes.size(1))]
                    draw_bboxes(detected_bboxe.numpy(), ax, color=c)
                    # matched bboxes; alpha moved to host memory as a numpy array,
                    # consistent with check_pointer_manually (a CUDA tensor cannot be plotted).
                    iou = iou_pred[batch_index, w]
                    iou_norm = (iou / (iou.sum() + 1e-9)).cpu().numpy()
                    ax = fig.add_subplot(2, 1, 2)
                    ax.imshow(im)
                    c = [np.array([0, 0, 1]) for j in range(boxes.size(1))]
                    draw_bboxes(detected_bboxe.numpy(), ax, color=c, alpha=iou_norm)
                    plt.savefig('t0.4_pred_pointer_%s_%d.jpg' % (im_id, w))
                    plt.close()
                input('Press ENTER for next image')

    def gqa_analysis(self, train_tuple, eval_tuple):
        """Log per-sample language features to TensorBoard as embeddings labelled by their tokens."""
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer,
                sem_question_words, sem_answer_words, bboxes_words) in iter_wrapper(enumerate(loader)):
            with torch.no_grad():
                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = \
                    sem_question_words.cuda(), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score, lang_feat, vis_feat, tkn_sent = self.model(
                    feats, boxes, sent, iou_question, iou_answer, sem_question_words,
                    sem_answer_words, bboxes_words, verbose=True)
            # Separate index name so the batch counter `i` is not shadowed.
            for b in range(lang_feat.size(0)):
                len_sent = len(tkn_sent[b])
                self.writerTbrd.add_embedding(lang_feat[b, :len_sent], metadata=tkn_sent[b])

    def extract_maps(self, eval_tuple):
        """Average, over the dataset, the per-head maximum attention weight of every layer and plot it."""
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        timer = time.time()
        att_maps = None
        map_names = []
        for i, datum_tuple in enumerate(loader):
            (ques_id, feats, boxes, sent, target, iou_question, iou_answer,
             sem_question_words, sem_answer_words, bboxes_words) = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = \
                    sem_question_words.cuda(), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score, activations = self.model(
                    feats, boxes, sent, iou_question, iou_answer, sem_question_words,
                    sem_answer_words, bboxes_words)
                score, label = logit.max(1)
                # Cross-attention maps arrive nested per layer; flatten them.
                activations['cross'] = [item for sublist in activations['cross'] for item in sublist]
                # Initialise the accumulators from the model's architecture.
                if att_maps is None:
                    print('map shape', activations['lang'][0].shape)
                    n_head = activations['lang'][0].shape[1]
                    att_maps = {
                        key: [torch.zeros(n_head) for _ in range(len(activations[key]))]
                        for key in ('lang', 'vis', 'cross')
                    }
                map_names = []
                for maptype in ['lang', 'vis', 'cross']:
                    for idx, maps in enumerate(activations[maptype]):
                        if maptype == 'cross':
                            sub_maptype = ['xvl', 'xlv', 'xl', 'xv']
                            name = '%s%d' % (sub_maptype[idx % 4], idx)
                        else:
                            name = '%s%d' % (maptype[0], idx)
                        map_names.append(name)
                        d_b, d_h, d_1, d_2 = maps.shape  # [batch, head, d1, d2]
                        # reshape (not view) also handles non-contiguous attention tensors.
                        head_max = torch.max(maps.reshape(d_b, d_h, d_1 * d_2), dim=-1).values  # [batch, heads]
                        head_max = head_max.sum(0)  # sum over batch: [heads]
                        att_maps[maptype][idx] += head_max.cpu().data
        if att_maps is None:
            print('extract_maps: the loader yielded no batches')
            return
        all_max_maps = torch.cat([
            torch.stack(att_maps['lang']),
            torch.stack(att_maps['vis']),
            torch.stack(att_maps['cross'])
        ]).numpy()  # [layer, heads]
        all_max_maps = all_max_maps / len(dset)
        print("MAX ATT divided", all_max_maps)
        # Draw histogram
        FIG_PATH = self.output
        draw_histogram(all_max_maps, path=FIG_PATH, labels=map_names)
        input('_')
        print('Processes set in %ds' % (time.time() - timer))

    def train(self, train_tuple, eval_tuple):
        """Train for args.epochs epochs with the answer loss plus optional GloVe and pointer losses.

        Saves BEST (by validation score) and LAST weights into self.output and
        logs to TensorBoard every 1000 optimizer steps.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
        best_valid = 0.
        optim_steps = 0
        ALPHA = args.alpha_pointer
        loss_pointer_fct = KLDivLoss(reduction='none') if args.task_pointer == 'KLDiv' else None

        def iou_preprocess(iou, obj_conf=None):
            """Turn raw IoU targets into a top-k softmax target and a per-word validity mask.

            Words whose top-k IoU mass is below the threshold get mask 0; this also
            covers unsupervised datums. The 'Triplet' branch additionally returns
            negative-object scores; train() currently calls this only for 'KLDiv'.
            """
            TRESHOLD = 0.1
            TOPK = 3
            sorted_values = torch.sort(iou, descending=True, dim=-1)[0]
            t_top = sorted_values[:, :, TOPK - 1]
            iou_topk = iou.masked_fill(iou < t_top.unsqueeze(-1), -1e9)
            f_iou = torch.softmax(iou_topk, dim=-1)
            treshold_mask = (iou_topk.clamp(min=.0).sum(-1) >= TRESHOLD).float()
            if args.task_pointer == 'KLDiv':
                return f_iou, treshold_mask
            elif args.task_pointer == 'Triplet':
                # Remove top10 most similar objects
                t_bot = sorted_values[:, :, 10]
                iou_botk = (iou < t_bot.unsqueeze(-1)).float()
                # Take topk most confident objects
                conf_top = torch.sort(obj_conf.unsqueeze(1) * iou_botk, descending=True,
                                      dim=-1)[0][:, :, TOPK - 1]
                conf_mask = obj_conf.unsqueeze(1).expand(-1, iou.size(1), -1) >= conf_top.unsqueeze(-1)
                neg_score = iou_botk * conf_mask.float()
                return f_iou, treshold_mask, neg_score

        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer,
                    sem_question_words, sem_answer_words, bboxes_words) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()
                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = \
                    sem_question_words.cuda(), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score = self.model(
                    feats, boxes, sent, iou_question, iou_answer, sem_question_words,
                    sem_answer_words, bboxes_words)
                assert logit.dim() == target.dim() == 2

                # Keep the soft answer distribution: the MCE branch replaces `target`
                # with class indices, which the GloVe loss below cannot broadcast.
                soft_target = target
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                if args.answer_loss == 'glove':
                    # Expected GloVe vector under the gold and predicted answer distributions.
                    gold_glove = (self.labelans2glove.unsqueeze(0) * soft_target.unsqueeze(-1)).sum(1)
                    pred_glove = (self.labelans2glove.unsqueeze(0) *
                                  torch.softmax(logit, dim=1).unsqueeze(-1)).sum(1)
                    sim_answer = self.cosineSim(gold_glove, pred_glove).mean()
                    loss += -10 * sim_answer

                if optim_steps % 1000 == 0:
                    self.writerTbrd.add_scalar('vqa_loss_train', loss.item(), optim_steps)

                if args.task_pointer == 'KLDiv':
                    iou_target_preprocess, treshold_mask = iou_preprocess(iou_target)
                    iou_pred = torch.log_softmax(iou_score, dim=-1)  # KLDivLoss expects log-probs
                    matching_loss = loss_pointer_fct(input=iou_pred, target=iou_target_preprocess)
                    # Average over words that passed the IoU threshold only.
                    matching_loss = ALPHA * (matching_loss.sum(-1) * treshold_mask).sum() / (
                        treshold_mask.sum() + 1e-9)
                    if optim_steps % 1000 == 0:
                        self.writerTbrd.add_scalar('pointer_loss_train', matching_loss.item(), optim_steps)
                    loss += matching_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                optim_steps += 1

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                self.writerTbrd.add_scalar('vqa_acc_valid', valid_score, epoch)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")
                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
            print(log_str, end='')
            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()
        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None, iou=False):
        """Return {qid: answer}, or ({qid: answer}, {qid: None}) when iou is True."""
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        quesid2iou = {}
        timer = time.time()
        for i, datum_tuple in enumerate(loader):
            (ques_id, feats, boxes, sent, target, iou_question, iou_answer,
             sem_question_words, sem_answer_words, bboxes_words) = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = \
                    sem_question_words.cuda(), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score = self.model(
                    feats, boxes, sent, iou_question, iou_answer, sem_question_words,
                    sem_answer_words, bboxes_words)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
                    quesid2iou[qid] = None  # IoU predictions are not collected yet
        print('Processes set in %ds' % (time.time() - timer))
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        if iou is True:
            return quesid2ans, quesid2iou
        else:
            return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Predict on eval_tuple and return the evaluator's score."""
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Score the argmax of the ground-truth targets."""
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer,
                sem_question_words, sem_answer_words, bboxes_words) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model state_dict to <output>/<name>.pth."""
        torch.save(self.model.state_dict(), os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load <path>.pth non-strictly, removing '.module' (DataParallel) segments from keys."""
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)

    def finetune(self, train_tuple, eval_tuple):
        """Two-stage fine-tuning with separate output and TensorBoard directories.

        Step 1 trains only the answer head (parameters whose name contains
        'logit_fc') for up to 4 epochs at args.lr. Step 2 unfreezes everything
        for the remaining epochs at args.lr / 10.
        """
        from lxrt.optimization import BertAdam
        # log
        output_1 = os.path.join(self.output, 'finetune_1')
        os.makedirs(output_1, exist_ok=True)
        output_2 = os.path.join(self.output, 'finetune_2')
        os.makedirs(output_2, exist_ok=True)
        # Tensorboard
        boards_dir_1 = os.path.join(self.boards_dir, 'finetune_1')
        os.makedirs(boards_dir_1, exist_ok=True)
        boards_dir_2 = os.path.join(self.boards_dir, 'finetune_2')
        os.makedirs(boards_dir_2, exist_ok=True)
        # Params
        lr_1 = args.lr
        lr_2 = args.lr / 10
        epochs_1 = 4
        # Clamp at 0: for args.epochs < 4 this used to go negative (negative t_total).
        epochs_2 = max(args.epochs - epochs_1, 0)

        # Step 0: evaluate pretraining
        if self.valid_tuple is not None:
            valid_score = self.evaluate(eval_tuple)
            print("Before finetune: Valid %0.2f\n" % (valid_score * 100.))

        # Step 1: only the answer head is trainable
        args.epochs = epochs_1
        batch_per_epoch = len(self.train_tuple.loader)
        t_total = int(batch_per_epoch * epochs_1)
        print("Total Iters: %d" % t_total)
        self.optim = BertAdam(list(self.model.parameters()), lr=lr_1, warmup=0.0, t_total=-1)
        self.writerTbrd = SummaryWriter(boards_dir_1)
        self.output = output_1
        for name, p in self.model.named_parameters():
            p.requires_grad = "logit_fc" in name
        print('### Start finetuning step 1...')
        self.train(train_tuple, eval_tuple)

        # Step 2: fine-tune all parameters
        for p in self.model.parameters():
            p.requires_grad = True
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * epochs_2)
            print("Total Iters: %d" % t_total)
            self.optim = BertAdam(list(self.model.parameters()), lr=lr_2, warmup=0.1,
                                  t_total=t_total, lr_min=1e-7)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), lr_2)
        args.epochs = epochs_2
        self.writerTbrd = SummaryWriter(boards_dir_2)
        self.output = output_2
        print('### Start finetuning step 2...')
        self.train(train_tuple, eval_tuple)
class GQA:
    """Training / evaluation driver for a ``GQAModel``.

    All configuration is read from the module-level ``args`` object.
    """

    def __init__(self):
        """Build data tuples, the model (optionally with pre-trained weights),
        the losses and the optimizer, and create ``args.output``."""
        self.train_tuple = get_tuple(args.train, bs=args.batch_size,
                                     shuffle=True, drop_last=True)
        if args.valid != "":
            # Same validation batch size with or without multiGPU.
            valid_bsize = 512
            self.valid_tuple = get_tuple(args.valid, bs=valid_bsize,
                                         shuffle=False, drop_last=False)
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Train for ``args.epochs`` epochs.

        After every epoch the training accuracy is computed from the arg-max
        predictions; when ``self.valid_tuple`` is set the model is also
        validated on *eval_tuple* and ``BEST.pth`` is saved on improvement.
        Per-epoch scores are printed and appended to ``<self.output>/log.log``;
        ``LAST.pth`` is saved at the end.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                # Both losses are scaled by the number of answer columns
                # (logit.size(1)); for BCE with the default 'mean' reduction
                # this turns the element-wise mean into a per-example sum.
                if args.mce_loss:
                    # CrossEntropyLoss needs a class index: use each target row's arg-max.
                    _, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                # Record arg-max predictions for the epoch's training accuracy.
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    quesid2ans[qid] = dset.label2ans[l]

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            # Closing the file at the end of the `with` flushes it.
            with open(os.path.join(self.output, "log.log"), 'a') as f:
                f.write(log_str)

        self.save("LAST")

    def predict(self, eval_tuple: "DataTuple", dump=None):
        """Return ``{question_id: predicted answer}`` for every example.

        Only the first four fields of each batch are used (any target is
        ignored). If *dump* is not None the predictions are also passed to
        ``evaluator.dump_result(quesid2ans, dump)``. Leaves the model in
        eval mode.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            if i % 100 == 0:
                print(i)  # progress indicator every 100 batches
            ques_id, feats, boxes, sent = datum_tuple[:4]  # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    quesid2ans[qid] = dset.label2ans[l]
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: "DataTuple", dump=None):
        """Predict on *eval_tuple* and return the evaluator's score for those predictions."""
        _, _, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Score answers taken as the arg-max of each target row (an oracle baseline)."""
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                quesid2ans[qid] = dset.label2ans[l]
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model's ``state_dict`` to ``<self.output>/<name>.pth``."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load ``<path>.pth`` into the model non-strictly.

        Every occurrence of ``'.module'`` is stripped from the keys first
        (presumably left over from a wrapped multi-GPU submodule -- confirm).
        """
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)