class VQA:
    """VQA-style trainer for query/product matching.

    Trains ``VQAModel`` with a binary matched/mismatched objective built from
    in-batch negative sampling, ranks products per query at evaluation time,
    and writes predictions/checkpoints under fixed relative paths.
    """

    def __init__(self):
        # Model
        self.model = VQAModel()

        # Load pre-trained weights (encoder-only vs. encoder + QA head).
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss: binary cross-entropy on the raw matched/unmatched logit.
        self.bce_loss = nn.BCEWithLogitsLoss()

        # Data splits. Eval/test use a large fixed batch size with no shuffle.
        self.train_tuple = get_data_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True)
        self.valid_tuple = get_data_tuple(
            args.valid, bs=1024, shuffle=False, drop_last=False)
        self.test_tuple = get_data_tuple(
            args.test, bs=1024, shuffle=False, drop_last=False)

        # Optimizer: BertAdam with linear warmup when requested, otherwise
        # whatever optimizer constructor ``args.optimizer`` supplies.
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple, test_tuple):
        """Run the full training loop with in-batch negative sampling.

        Each sentence in a batch is additionally paired with the visual
        features of a randomly chosen *other* example of the same batch to
        form a negative pair, doubling the effective batch size.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm \
            else (lambda x: x)

        start_epoch = -1
        best_valid = 0.
        if args.RESUME:
            # Resume model / optimizer / epoch counter from a checkpoint,
            # e.g. ../../model/checkpoint/lxr_best_%s.pth
            path_checkpoint = args.path_checkpoint
            checkpoint = torch.load(path_checkpoint)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']

        for epoch in range(start_epoch + 1, args.epochs):
            for i, (feats, boxes, sent, _, _) in iter_wrapper(enumerate(loader)):
                # Construct negative examples: pair sentence j with a random
                # other index of the batch (never itself).
                bs = len(sent)
                index_list = list(range(bs))
                sent_negative = []
                for j in range(bs):
                    choice = random.choice(list(set(index_list) - {j}))
                    sent_negative.append(sent[choice])
                sent = sent + sent_negative
                feats = torch.cat([feats, feats])
                boxes = torch.cat([boxes, boxes])

                # First half of the doubled batch is positive, second half
                # negative.
                target = torch.ones(bs, 1)
                target_negative = torch.zeros(bs, 1)
                target = torch.cat([target, target_negative])

                self.model.train()
                self.optim.zero_grad()
                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                # Rescale the mean BCE loss by the number of output logits.
                loss = loss * logit.size(1)
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                batch_score = accuracy(logit, target)

                if i % 500 == 0:
                    print('epoch {}, Step {}/{}, Loss: {}'.format(
                        epoch, i, len(loader), loss.item()))

            # Loss/score of the *last* batch is what gets logged per epoch.
            log_str = "\nEpoch %d: Loss: %0.4f Train %0.2f\n" % (
                epoch, loss.item(), batch_score)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple, epoch)
                self.save_checkpoint(epoch)
                self.save(epoch)
                self.test_output(test_tuple, epoch)
                print('output done!')
                if valid_score > best_valid:
                    best_valid = valid_score
                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid)

            print(log_str, end='')
            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, epoch=0):
        """Predict matching scores for a split and score the resulting ranking.

        Dumps per-query raw predictions and a top-5 ``submission.csv``, then
        shells out to ``eval.py``, which is expected to write ``score.json``.

        :param eval_tuple: The data tuple to be evaluated.
        :param epoch: Epoch tag used in the dumped file name.
        :return: The score read back from ``score.json``.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        preds = []
        query_ids = []
        product_ids = []
        with torch.no_grad():
            for i, datum_tuple in enumerate(loader):
                # Avoid seeing ground truth: only the first five fields used.
                feats, boxes, sent, ques_id, produce_id = datum_tuple
                query_ids.extend(ques_id)
                product_ids.extend(produce_id)
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                logit = torch.sigmoid(logit)
                preds.append(logit.cpu().numpy())
        preds = np.concatenate(preds)

        # Group [score, product_id] pairs per query id.
        query2product = collections.defaultdict(list)
        for i, query_id in enumerate(query_ids):
            pred = preds[i]
            product_id = product_ids[i]
            query2product[str(query_id.item())].append(
                [pred.tolist()[0], str(product_id.item())])

        # Dump raw per-query candidate lists.  (The explicit f.close() of the
        # original was redundant inside the ``with`` block and is removed.)
        with open('../../user_data/lxmert_model/result/val/val_%s.txt' % epoch,
                  'w') as f:
            for q, p in query2product.items():
                q = str(q)
                p = str(p)
                f.write(q + ',' + p + '\n')

        with open('submission.csv', 'w') as f:
            f.write('query-id,product1,product2,product3,product4,product5\n')
            for q, p in query2product.items():
                # Rank candidates by descending score and keep the top 5.
                # NOTE(review): assumes every query has >= 5 candidates,
                # otherwise this raises IndexError — confirm with the dataset.
                p = sorted(p, key=lambda x: x[0], reverse=True)
                o = ','.join([p[i][1] for i in range(5)])
                f.write(q + ',' + o + '\n')

        os.system('python eval.py submission.csv')
        score = json.load(open('score.json'))["score"]
        return score

    def test_output(self, eval_tuple: DataTuple, epoch=0):
        """Dump raw per-query predictions for the test split (no scoring)."""
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        preds = []
        query_ids = []
        product_ids = []
        with torch.no_grad():
            for i, datum_tuple in enumerate(loader):
                # Avoid seeing ground truth: only the first five fields used.
                feats, boxes, sent, ques_id, produce_id = datum_tuple
                query_ids.extend(ques_id)
                product_ids.extend(produce_id)
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                logit = torch.sigmoid(logit)
                preds.append(logit.cpu().numpy())
        preds = np.concatenate(preds)

        # Group [score, product_id] pairs per query id.
        query2product = collections.defaultdict(list)
        for i, query_id in enumerate(query_ids):
            pred = preds[i]
            product_id = product_ids[i]
            query2product[str(query_id.item())].append(
                [pred.tolist()[0], str(product_id.item())])

        print(os.getcwd())
        with open(
                '../../user_data/lxmert_model/result/test/test_%s.txt' % epoch,
                'w') as f:
            for q, p in query2product.items():
                q = str(q)
                p = str(p)
                f.write(q + ',' + p + '\n')

    def evaluate(self, eval_tuple: DataTuple, epoch=0):
        """Evaluate all data in data_tuple via :meth:`predict`."""
        score = self.predict(eval_tuple, epoch)
        return score

    @staticmethod
    def oracle_score(data_tuple):
        """Upper-bound accuracy if the best labelled answer were always chosen."""
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, epoch):
        """Save model weights only, named after the epoch (or a tag string)."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % (str(epoch))))

    def save_checkpoint(self, epoch):
        """Save a full resume checkpoint (model + optimizer + epoch)."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
        }
        # makedirs(exist_ok=True) also creates missing parents, unlike the
        # original isdir/mkdir pair which failed when the parent was absent.
        os.makedirs('../../user_data/lxmert_model/checkpoint', exist_ok=True)
        torch.save(checkpoint,
                   '../../user_data/lxmert_model/checkpoint/lxr_best.pth')

    def load(self, path):
        """Load model weights from ``<path>.pth``."""
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
    """Pre-training loop: BertAdam with warmup, optional fp16 and resume.

    Fix vs. original: ``torch.optim`` optimizers' ``load_state_dict`` takes
    no ``strict`` keyword (that belongs to ``nn.Module.load_state_dict``),
    so the resume path passed ``strict=False`` and raised ``TypeError``.

    :param train_tuple: Training data tuple (its ``loader`` drives epochs).
    :param eval_tuple: Validation data tuple passed to ``evaluate_epoch``.
    """
    train_ld = train_tuple.loader

    # Optimizer
    from lxrt.optimization import BertAdam
    batch_per_epoch = len(train_ld)
    t_total = int(batch_per_epoch * args.epochs)
    warmup_ratio = 0.05
    warmup_iters = int(t_total * warmup_ratio)
    print("Batch per epoch: %d" % batch_per_epoch)
    print("Total Iters: %d" % t_total)
    print("Warm up Iters: %d" % warmup_iters)
    optim = BertAdam(self.model.parameters(),
                     lr=args.lr,
                     warmup=warmup_ratio,
                     t_total=t_total)
    start_epoch = 0

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        self.model, optim = amp.initialize(self.model, optim, opt_level='O1')

    # GPU Options
    if args.multiGPU:
        self.model = nn.DataParallel(self.model)

    if args.start_from > 0 and args.pretraining_index == 0:
        # Resume both model and optimizer; continue from args.start_from.
        start_path = os.path.join(
            args.output,
            "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
        print('start training from {0}'.format(start_path))
        state = torch.load(start_path)
        self.model.load_state_dict(state['state_dict'])
        # BUGFIX: Optimizer.load_state_dict has no ``strict`` parameter.
        optim.load_state_dict(state['optimizer'])
        start_epoch = args.start_from
        del state
        torch.cuda.empty_cache()
    elif args.start_from > 0 and args.pretraining_index == 1:
        # Second pre-training stage: load weights only (non-strict), and
        # restart the epoch counter / optimizer schedule from scratch.
        start_path = os.path.join(
            args.output,
            "Epoch%s_LXRT.pth" % format(int(args.start_from), '02'))
        print('start training from {0}'.format(start_path))
        state = torch.load(start_path)
        self.model.load_state_dict(state['state_dict'], strict=False)
        del state
        torch.cuda.empty_cache()

    # Train
    best_eval_loss = 9595.  # sentinel "infinity" for best-so-far tracking
    for epoch in range(start_epoch, args.epochs):
        # Train
        self.model.train()
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        for batch in tqdm(train_ld, total=len(train_ld)):
            loss, losses, logit = self.train_batch(optim, batch)
            total_loss += loss
            total_losses += losses
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans

        print("The training loss for Epoch %d is %0.4f" %
              (epoch, total_loss / batch_per_epoch))
        log_str = "\nThe training loss for Epoch %d is %0.4f" % (
            epoch, total_loss / batch_per_epoch)
        losses_str = "\nThe losses are "
        log_str += "\nThe losses are "
        # ``loss_sum`` avoids shadowing the batch-loop ``loss`` local.
        for name, loss_sum in zip(LOSSES_NAME, total_losses):
            losses_str += "%s: %0.4f " % (name, loss_sum / batch_per_epoch)
            log_str += "\n %s: %0.4f " % (name, loss_sum / batch_per_epoch)
        print(losses_str)
        with open(self.output + "/log.log", 'a') as f:
            f.write(log_str)
            f.flush()
        if args.task_qa:
            train_tuple.evaluator.evaluate(uid2ans, pprint=True)

        # Eval
        avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)
        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': optim.state_dict(),
        }

        # Save
        if avg_eval_loss < best_eval_loss:
            best_eval_loss = avg_eval_loss
            self.save("BEST_EVAL_LOSS", state)
        if args.pretraining_index == 0:
            self.save("Epoch%02d" % (epoch + 1), state)
        elif args.pretraining_index == 1:
            self.save("Epoch%02d" % (epoch + 1 + args.start_from), state)
class LXMERT:
    """LXMERT pre-training driver: builds the model with the selected
    pre-training heads, runs gradient-accumulated training/evaluation, and
    handles checkpoint save/load (including multi-GPU prefix stripping)."""

    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True)

        # Build model with the pre-training tasks selected in args.
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            args=args,
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            # NOTE(review): falls back to a module-level ``train_tuple`` when
            # args.num_answers is unset — confirm that global exists here.
            num_answers=args.num_answers if args.get("num_answers", None)
            else train_tuple.dataset.answer_table.num_answers
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)
        if args.get("use_tag_symbolic_embedding", False):
            self.model.bert.embeddings.initialize_symbolic_embeddings(
                symbolic_vocab.get_symbolic_list(self.tokenizer))
            self.model.special_initialize_pretraining_head()
        if args.get("hybrid_embedding", False):
            self.model.bert.embeddings.initialize_visual_position_type_embeddings()
        if args.load_lxmert is not None:
            # Load lxmert would not load the answer head.
            self.load_lxmert(args.load_lxmert)

        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)
        self.global_step = 0

    def forward(self, examples):
        """Move a pre-collated example tuple to GPU and run the model.

        ``examples`` is mutated in place: tensors are replaced by their CUDA
        copies, and dict entries (pairs of tensors) are moved element-wise.
        Returns ``(loss, detached_cpu_losses, ans_logit, losses_dict)``.
        """
        for index, item in enumerate(examples):
            if item is not None:
                if isinstance(item, dict):
                    for key in item:
                        item[key] = (item[key][0].cuda(), item[key][1].cuda())
                else:
                    examples[index] = item.cuda()
        (input_ids, segment_ids, input_mask, lm_labels, feats, pos,
         obj_labels, matched_labels, ans, visual_feats_seg_ids, visual_tags,
         visual_tags_mask, visual_tags_box, visual_tags_objective,
         visual_tags_mismatch, visual_tags_segment_ids) = examples
        loss, losses, ans_logit, losses_dict = self.model(
            input_ids, segment_ids, input_mask, lm_labels, feats, pos,
            obj_labels, matched_labels, ans,
            visual_feats_seg_ids=visual_feats_seg_ids,
            visual_tags=visual_tags,
            visual_tags_mask=visual_tags_mask,
            visual_tags_box=visual_tags_box,
            visual_tags_objective=visual_tags_objective,
            visual_tags_mismatch=visual_tags_mismatch,
            visual_tags_segment_ids=visual_tags_segment_ids
        )
        return loss, losses.detach().cpu(), ans_logit, losses_dict

    def train_batch(self, optim, batch):
        """One training step with gradient accumulation.

        Gradients are zeroed before, and the optimizer stepped after, every
        ``gradient_accumulation_steps``-th call (keyed on ``global_step``,
        which the caller increments).
        """
        gradient_accumulation_steps = args.get("gradient_accumulation_steps", 1)
        if (self.global_step + 1) % gradient_accumulation_steps == 0:
            optim.zero_grad()

        loss, losses, ans_logit, losses_dict = self.forward(batch)
        if args.multiGPU:
            # DataParallel returns one value per GPU; average them.
            loss = loss.mean()
            losses = losses.mean(0)
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()

        if (self.global_step + 1) % gradient_accumulation_steps == 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
            optim.step()
        return loss.item(), losses.cpu().numpy(), ans_logit, losses_dict

    def valid_batch(self, batch):
        """One no-grad evaluation step; mirrors ``train_batch`` reductions."""
        with torch.no_grad():
            loss, losses, ans_logit, losses_dict = self.forward(batch)
            if args.multiGPU:
                loss = loss.mean()
                losses = losses.mean(0)
        return loss.item(), losses.cpu().numpy(), ans_logit, losses_dict

    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        """Main pre-training loop with periodic reporting and checkpointing."""
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = args.get("warmup_ratio", 0.05)
        print("Total Iters: %d" % t_total)
        if args.get("t_total", None):
            t_total = args.t_total
            print("!! Changing to specified t_toal in args: {}".format(t_total))
        self.t_total = t_total
        warmup_iters = int(t_total * warmup_ratio)
        print("Batch per epoch: %d" % batch_per_epoch)
        print("Warm up Iters: %d" % warmup_iters)
        self.optim = BertAdam(self.model.parameters(),
                              lr=args.lr,
                              warmup=warmup_ratio,
                              t_total=t_total)
        if args.load is not None:
            self.load(args.load, t_total=t_total)

        gradient_accumulation_steps = args.get("gradient_accumulation_steps", 1)

        # Train
        best_eval_loss = 9595.  # sentinel "infinity"
        report_every = args.get("report_every", 100)
        custom_train_meter = TrainingMeter()
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}
            for batch_id, batch in enumerate(tqdm(train_ld, total=len(train_ld))):
                if args.get("skip_training", False):
                    break
                loss, losses, logit, losses_dict = self.train_batch(self.optim, batch)
                total_loss += loss
                try:
                    total_losses += losses
                except Exception:
                    # Best-effort accumulation; shapes may disagree early on.
                    pass

                if args.task_qa and batch[0].sent is not None:
                    assert (0)  # Not used in our experiment
                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans

                # Reduce each loss tensor to a scalar for the meter.
                for key, value in losses_dict.items():
                    losses_dict[key] = value.mean().item()
                if "Masked LM" in losses_dict and losses_dict["Masked LM"] == 0:
                    del losses_dict["Masked LM"]
                custom_train_meter.update(losses_dict)

                if batch_id % report_every == 0 and batch_id > 0:
                    print("Folder: {} \n Epoch {} Iter: {}/{}".format(
                        args.output, epoch, batch_id, len(train_ld)))
                    custom_train_meter.report()
                    custom_train_meter.clean()
                    print()

                if args.get("save_step", -1) != -1 and self.global_step != 0 and \
                        (self.global_step // gradient_accumulation_steps) % args.save_step == 0:
                    self.save("Step{}".format(self.global_step))
                self.global_step += 1

            print("The training loss for Epoch %d is %0.4f" %
                  (epoch, total_loss / batch_per_epoch))
            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)
            if args.get("eval_on_train", False):
                print("On train set")
                self.evaluate_epoch(train_tuple, iters=-1)

            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch + 1))

    def evaluate_epoch(self, eval_tuple: DataTuple, iters: int = -1):
        """Evaluate up to ``iters`` batches (-1 = all); return mean loss."""
        self.model.eval()
        eval_ld = eval_tuple.loader
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        eval_meter = TrainingMeter()
        for i, batch in enumerate(tqdm(eval_ld)):
            loss, losses, logit, losses_dict = self.valid_batch(batch)
            total_loss += loss
            try:
                total_losses += losses
            except Exception:
                pass
            for key, value in losses_dict.items():
                losses_dict[key] = value.mean().item()
            eval_meter.update(losses_dict)
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    # BUGFIX: the original referenced an undefined
                    # ``train_tuple`` here (NameError when task_qa is on);
                    # the answer table of the evaluated tuple is the same
                    # id -> answer mapping.
                    ans = eval_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break
        print("Evaluation:")
        eval_meter.report()
        print("\n\n\n\n\n\n\n\n")
        if args.task_qa:
            eval_tuple.evaluator.evaluate(uid2ans, pprint=True)
        return total_loss / len(eval_ld)

    def evaluate_epoch_text(self, eval_tuple: DataTuple, iters: int = -1):
        """Text-only variant of :meth:`evaluate_epoch` (uses ``textonly`` loader)."""
        self.model.eval()
        eval_ld = eval_tuple.textonly
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        eval_meter = TrainingMeter()
        for i, batch in enumerate(tqdm(eval_ld)):
            loss, losses, logit, losses_dict = self.valid_batch(batch)
            total_loss += loss
            total_losses += losses
            for key, value in losses_dict.items():
                losses_dict[key] = value.mean().item()
            eval_meter.update(losses_dict)
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    # BUGFIX: same undefined-``train_tuple`` issue as in
                    # evaluate_epoch; use the evaluated tuple's answer table.
                    ans = eval_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break
        print("Evaluation text only:")
        eval_meter.report()
        print("\n\n\n\n\n\n\n\n")
        return total_loss / len(eval_ld)

    def save(self, name):
        """Save model weights (and optionally optimizer state) under args.output."""
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))
        # Optimizer state is skipped for "Step..." snapshots to save disk.
        if args.get("save_optimizer", False) and "Step" not in name:
            torch.save(self.optim.state_dict(),
                       os.path.join(args.output, "%s_LXRT_optimizer.pth" % name))

    def load(self, path, t_total):
        """Load model (flexibly) and, if present, optimizer state from ``path``.

        With ``reset_schedule`` the learning-rate schedule of the loaded
        optimizer is re-initialized to the current run's hyper-parameters.
        """
        print("Load model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        from qa_answer_table import load_state_dict_flexible
        load_state_dict_flexible(self.model, state_dict)

        optimizer_path = "{}_LXRT_optimizer.pth".format(path)
        if os.path.exists(optimizer_path) and args.get("load_optimizer", True):
            print("Load optimizer from {}".format(optimizer_path))
            loaded_optim = torch.load(optimizer_path)
            if args.get("reset_schedule", False):
                for group in loaded_optim["param_groups"]:
                    group['lr'] = args.lr
                    group['warmup'] = args.warmup_ratio
                    group["t_total"] = t_total
                    for p in group['params']:
                        # Restart warmup/decay from step 0.  (A dead no-op
                        # expression statement here was removed.)
                        loaded_optim["state"][p]["step"] = 0
            self.optim.load_state_dict(loaded_optim)

    def load_lxmert(self, path):
        """Load LXMERT weights from ``<path>_LXRT.pth``, skipping answer heads.

        Also strips a DataParallel ``module.`` prefix when present so that
        multi-GPU checkpoints load into a single-GPU model.
        """
        print("Load LXMERT model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)

        # Do not load any answer head
        for key in list(state_dict.keys()):
            if 'answer' in key:
                state_dict.pop(key)

        # Change Multi GPU to single GPU.
        # BUGFIX: the original kept ONLY keys starting with "module.", so a
        # single-GPU checkpoint lost every weight; unprefixed keys are now
        # carried over unchanged.
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict

        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Keys in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Keys in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()
        self.model.load_state_dict(state_dict, strict=False)