def eval_accuracies(pred_s, target_s, pred_e, target_e):
    """An unofficial evaluation helper.

    Compute exact start / end / complete-match accuracies for a batch.

    Args:
        pred_s: predicted start index for each example of the batch.
        target_s: gold start indices -- either a 1D tensor (one gold index per
            example) or a sequence of sequences (several acceptable gold
            indices per example).
        pred_e: predicted end index for each example of the batch.
        target_e: gold end indices, same layout as ``target_s``.

    Returns:
        ``(start_acc, end_acc, exact_match_acc)``: the ``avg`` of each
        ``utils.AverageMeter`` (presumably the batch mean) times 100.
    """
    # Convert 1D tensors to lists of lists (compatibility). Each target is
    # checked on its own so a tensor/list mix is converted correctly.
    if torch.is_tensor(target_s):
        target_s = [[e] for e in target_s]
    if torch.is_tensor(target_e):
        target_e = [[e] for e in target_e]

    start = utils.AverageMeter()
    end = utils.AverageMeter()
    em = utils.AverageMeter()
    for i in range(len(pred_s)):
        ps, pe = pred_s[i], pred_e[i]
        # Start and end matches are judged independently of each other.
        start.update(1 if ps in target_s[i] else 0)
        end.update(1 if pe in target_e[i] else 0)
        # A complete match needs start AND end taken from the same gold span.
        span_hit = any(s == ps and e == pe
                       for s, e in zip(target_s[i], target_e[i]))
        em.update(1 if span_hit else 0)
    return start.avg * 100, end.avg * 100, em.avg * 100
def pretrain_selector(args, data_loader, model, global_stats, exs_with_doc,
                      docs_by_question):
    """Run through one epoch of selector pre-training with the provided data loader.

    NOTE: a later ``def pretrain_selector`` in this module rebinds the name,
    so this definition is shadowed once the module has been imported.

    For every batch, whether each example's answer occurs in each document
    slot is looked up in (or computed into) the module-level
    ``HasAnswer_Map``, keyed by batch index. The document slots are then
    visited in a random order and the found/not-found flags are passed to
    ``model.pretrain_selector``.
    """
    global HasAnswer_Map
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    n_docs = int(vector.num_docs)

    tot_ans = 0  # (doc, example) pairs whose document contains the answer
    tot_num = 0  # all (doc, example) pairs seen
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # The cache is shared with the other training routines, which read
        # both the flag and the answer positions from each entry. Store the
        # full has_answer() result so an entry written here stays usable there.
        # NOTE(review): keying by batch index assumes the loader yields the
        # same batches in the same order every epoch -- confirm.
        if idx not in HasAnswer_Map:
            HasAnswer_list = []
            for idx_doc in range(n_docs):
                HasAnswer = []
                for i in range(batch_size):
                    docs = docs_by_question[ex_id[i]]
                    HasAnswer.append(has_answer(
                        args, exs_with_doc[ex_id[i]]['answer'],
                        docs[idx_doc % len(docs)]["document"]))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        # The selector only needs the found/not-found flag of each entry.
        flags = [[int(HasAnswer_list[idx_doc][i][0]) for i in range(batch_size)]
                 for idx_doc in range(n_docs)]
        tot_ans += sum(sum(row) for row in flags)
        tot_num += n_docs * batch_size

        # Uniform weights, sampled without replacement (torch default): a
        # random permutation of the document slots.
        idx_random = torch.multinomial(torch.ones(n_docs), n_docs)
        ex_with_doc_sample = [ex_with_doc[idx_doc] for idx_doc in idx_random]
        HasAnswer_list_sample = torch.LongTensor(
            [flags[idx_doc] for idx_doc in idx_random])

        train_loss.update(*model.pretrain_selector(ex_with_doc_sample,
                                                   HasAnswer_list_sample))
        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            logger.info("tot_ans:\t%d\t%d\t%f", tot_ans, tot_num,
                        tot_ans * 1.0 / tot_num if tot_num else 0.0)
            train_loss.reset()
    logger.info("tot_ans:\t%d\t%d", tot_ans, tot_num)
    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))
def pretrain_reader(args, data_loader, model, global_stats, exs_with_doc,
                    docs_by_question):
    """Run through one epoch of reader pre-training with the provided data loader.

    NOTE: a later ``def pretrain_reader`` in this module rebinds the name, so
    this definition is shadowed once the module has been imported.

    For each document slot, the reader is updated toward the cached answer
    positions when the document contains the answer, and toward its own top-1
    predicted span otherwise.
    """
    global HasAnswer_Map
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    logger.info("pretrain_reader")

    count_ans = 0  # answer positions found in answer-bearing documents
    count_tot = 0  # (doc, example) pairs whose document contains the answer
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # Cache entries are the full has_answer() result: [0] is the found
        # flag, [1] the answer positions (read below).
        if idx not in HasAnswer_Map:
            HasAnswer_list = []
            for idx_doc in range(0, vector.num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    docs = docs_by_question[ex_id[i]]
                    HasAnswer.append(has_answer(
                        args, exs_with_doc[ex_id[i]]['answer'],
                        docs[idx_doc % len(docs)]["document"]))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        for idx_doc in range(0, vector.num_docs):
            l_list = []
            r_list = []  # always passed empty to model.update
            pred_s, pred_e, _ = model.predict(ex_with_doc[idx_doc], top_n=1)
            for i in range(batch_size):
                entry = HasAnswer_list[idx_doc][i]
                if entry[0]:
                    # Presumably a list of (start, end) pairs -- its length is
                    # the number of answer occurrences.
                    count_ans += len(entry[1])
                    count_tot += 1
                    l_list.append(entry[1])
                else:
                    # No gold span: use the reader's own best guess.
                    l_list.append([(int(pred_s[i][0]), int(pred_e[i][0]))])
            train_loss.update(*model.update(ex_with_doc[idx_doc], l_list,
                                            r_list, HasAnswer_list[idx_doc]))

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()
            # Average answer occurrences per answer-bearing document (no +1
            # fudge in the denominator; 0 before any answer was seen).
            logger.info("%d\t%d\t%f", count_ans, count_tot,
                        count_ans / count_tot if count_tot else 0.0)
    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))
def validate_with_doc(args, data_loader, model, global_stats, exs_with_doc,
                      docs_by_question, mode):
    '''Run one full unofficial validation with docs.

    Unofficial = doesn't use the SQuAD evaluation script.

    Candidate answers for a question are pooled over all document slots. Each
    candidate's reader score is multiplied by the selector's score for its
    document, and the products are summed per lower-cased answer string. The
    best-scoring string is compared against the ground truths with EM and F1.
    In ``'train'`` mode the loop stops once at least 1000 examples were seen.

    For the top ``min(display_num, num_docs)`` selector-ranked documents, the
    share that contains the answer is also logged.

    Args:
        args: run configuration (``dataset`` is read here; the object is also
            forwarded to ``has_answer``).
        data_loader: yields ``ex_with_doc``, indexed by document slot.
        model: provides ``predict_with_doc`` and ``predict``.
        global_stats: dict holding at least ``'epoch'``.
        exs_with_doc: maps example id -> example dict with an ``'answer'`` key.
        docs_by_question: maps example id -> list of dicts whose
            ``'document'`` entry is a token sequence.
        mode: label used in the log lines (e.g. ``'train'``, ``'dev'``).

    Returns:
        dict with ``'exact_match'`` and ``'f1'``, both in percent.
    '''
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()
    logger.info('validate_with_doc')

    # display_num may exceed the number of document slots; ranks beyond
    # num_docs do not exist, so cap the ranking depth.
    n_top = min(display_num, num_docs)

    # Initialize counters
    examples = 0
    aa = [0.0 for _ in range(num_docs)]  # per rank: docs containing the answer
    bb = [0.0 for _ in range(num_docs)]  # per rank: docs inspected
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # -----------------------------------------------------------------
        # Document Selector
        # -----------------------------------------------------------------
        # NOTE(review): each ex_with_doc[d] appears to be a batch laid out as
        #   x1       document word indices          [batch * len_d]
        #   x1_f     document word feature indices  [batch * len_d * nfeat]
        #   x1_mask  document padding mask          [batch * len_d]
        #   x2       question word indices          [batch * len_q]
        #   x2_mask  question padding mask          [batch * len_q]
        #   indices  example ids                    [batch]
        # Only [0].size(0) and [-1] are used here -- confirm with the loader.
        scores_doc_num = model.predict_with_doc(ex_with_doc)
        scores = [{} for _ in range(batch_size)]

        # -----------------------------------------------------------------
        # Document Reader
        # -----------------------------------------------------------------
        for idx_doc in range(0, num_docs):
            pred_s, pred_e, pred_score = model.predict(
                ex_with_doc[idx_doc], top_n=display_num)
            for i in range(batch_size):
                docs = docs_by_question[ex_id[i]]
                doc_text = docs[idx_doc % len(docs)]['document']
                # Read the display_num best predicted spans. The reader may
                # return fewer candidates, or a span running past the end of
                # the document; both raise IndexError and are skipped.
                for k in range(display_num):
                    try:
                        tokens = [doc_text[j] for j in
                                  range(pred_s[i][k], pred_e[i][k] + 1)]
                        weight = pred_score[i][k] * scores_doc_num[i][idx_doc]
                    except IndexError:
                        continue
                    prediction = ' '.join(tokens).lower()
                    scores[i][prediction] = scores[i].get(prediction, 0) + weight

        # Rank the documents by selector score and record, per rank, whether
        # the document at that rank contains the answer.
        for i in range(batch_size):
            _, indices = scores_doc_num[i].sort(0, descending=True)
            docs = docs_by_question[ex_id[i]]
            ex_answer = exs_with_doc[ex_id[i]]['answer']
            for rank, idx_doc in enumerate(indices[:n_top]):
                doc_text = docs[idx_doc % len(docs)]['document']
                if has_answer(args, ex_answer, doc_text)[0]:
                    aa[rank] += 1
                bb[rank] += 1

        # Update performance metrics
        for i in range(batch_size):
            # Highest pooled score wins; '' if no candidate scored above 0.
            best_score = 0
            prediction = ''
            for key, value in scores[i].items():
                if value > best_score:
                    best_score = value
                    prediction = key

            ex_answer = exs_with_doc[ex_id[i]]['answer']
            if args.dataset == 'CuratedTrec':
                # Answers are used as given (not token lists).
                ground_truths = ex_answer
            else:
                ground_truths = [' '.join(a) for a in ex_answer]

            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size
        if mode == 'train' and examples >= 1000:
            break

    # Make the per-rank counts cumulative: aa[j] / bb[j] is then the share of
    # documents in the top (j + 1) ranks that contain the answer.
    for j in range(1, n_top):
        aa[j] += aa[j - 1]
        bb[j] += bb[j - 1]
    if n_top and bb[0]:
        rates = ' '.join('{:.4f}'.format(aa[j] / bb[j]) for j in range(n_top))
        logger.info('{} answer rate in top-k docs (k = 1..{}): {}'.format(
            mode, n_top, rates))

    txt = '{} valid official with doc: Epoch = {} | EM = {:.2f} | '
    txt += 'F1 = {:.2f} | examples = {} | valid time = {:.2f} (s)'
    logger.info(txt.format(
        mode, global_stats['epoch'], exact_match.avg * 100,
        f1.avg * 100, examples, eval_time.time()))
    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
def validate_unofficial_with_doc(args, data_loader, model, global_stats,
                                 exs_with_doc, docs_by_question, mode):
    """Run one full unofficial validation with docs.

    Unofficial = doesn't use the SQuAD evaluation script.

    Candidate answers for a question are pooled over all document slots. Each
    of the reader's 10 best spans per document is scored as
    ``reader score * selector document score``, and the scores are summed per
    lower-cased answer string. The best string is compared against the ground
    truths with EM and F1. For k = 1..min(10, num_docs), the share of
    selector-ranked top-k documents containing the answer is logged. In
    ``'train'`` mode the loop stops once at least 1000 examples were seen.

    Returns:
        dict with ``'exact_match'`` and ``'f1'``, both in percent.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()
    logger.info("validate_unofficial_with_doc")

    n_docs = vector.num_docs
    display_num = 10
    # Ranks beyond the number of document slots do not exist.
    n_top = min(display_num, n_docs)

    # Run through examples
    examples = 0
    aa = [0.0 for _ in range(n_docs)]  # per rank: docs containing the answer
    bb = [0.0 for _ in range(n_docs)]  # per rank: docs inspected
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]
        scores_doc_num = model.predict_with_doc(ex_with_doc)
        scores = [{} for _ in range(batch_size)]

        for idx_doc in range(0, n_docs):
            pred_s, pred_e, pred_score = model.predict(
                ex_with_doc[idx_doc], top_n=display_num)
            for i in range(batch_size):
                docs = docs_by_question[ex_id[i]]
                doc_text = docs[idx_doc % len(docs)]["document"]
                # The reader may return fewer spans, or a span past the end of
                # the document; both raise IndexError and are skipped.
                for k in range(display_num):
                    try:
                        tokens = [doc_text[j] for j in
                                  range(pred_s[i][k], pred_e[i][k] + 1)]
                        weight = pred_score[i][k] * scores_doc_num[i][idx_doc]
                    except IndexError:
                        continue
                    prediction = " ".join(tokens).lower()
                    scores[i][prediction] = scores[i].get(prediction, 0) + weight

        # Per selector rank, count documents that contain the answer.
        for i in range(batch_size):
            _, indices = scores_doc_num[i].sort(0, descending=True)
            docs = docs_by_question[ex_id[i]]
            answer = exs_with_doc[ex_id[i]]['answer']
            for rank, idx_doc in enumerate(indices[:n_top]):
                doc_text = docs[idx_doc % len(docs)]["document"]
                if has_answer(args, answer, doc_text)[0]:
                    aa[rank] += 1
                bb[rank] += 1

        for i in range(batch_size):
            # Highest pooled score wins; "" if no candidate scored above 0.
            best_score = 0
            prediction = ""
            for key, value in scores[i].items():
                if value > best_score:
                    best_score = value
                    prediction = key

            # Compute metrics
            answer = exs_with_doc[ex_id[i]]['answer']
            if args.dataset == "CuratedTrec":
                ground_truths = answer
            else:
                ground_truths = [" ".join(a) for a in answer]
            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size
        if mode == "train" and examples >= 1000:
            break

    # Cumulative counts: aa[j] / bb[j] is the share of the top (j + 1)
    # selector-ranked documents that contain the answer.
    for j in range(n_top):
        if j > 0:
            aa[j] += aa[j - 1]
            bb[j] += bb[j - 1]
        if bb[j]:
            logger.info(aa[j] / bb[j])

    if mode == 'dev' or mode == 'train':
        # NOTE(review): ``g`` is not assigned in this function, so it is
        # looked up as a module-level name. It used to be closed here, which
        # makes any later call (e.g. the periodic 'train' validation) write to
        # a closed file unless the caller reopens it. Flush instead of closing.
        g.write("*" * 50 + "\n")
        g.flush()
    logger.info('%s valid official with doc: Epoch = %d | EM = %.2f | ' %
                (mode, global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))
    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
def train(args, data_loader, model, global_stats, exs_with_doc, docs_by_question):
    """Train selector and reader jointly for one epoch over ``data_loader``.

    NOTE: a later ``def train`` in this module rebinds the name, so this
    definition is shadowed once the module has been imported.

    Per batch: fetch (or compute and cache in ``HasAnswer_Map``) the
    ``has_answer`` result for every (document slot, example) pair. Then
    permute the document slots, collect the reader's top-1 span per slot, and
    call ``model.update_with_doc``. Every 200 batches an unofficial
    validation is run on the training loader. A checkpoint is written at the
    end of the epoch when ``args.checkpoint`` is set.
    """
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    n_docs = vector.num_docs
    update_step = 0

    for idx, ex_with_doc in enumerate(data_loader):
        head = ex_with_doc[0]
        batch_size, ex_id = head[0].size(0), head[-1]

        if idx in HasAnswer_Map:
            HasAnswer_list = HasAnswer_Map[idx]
        else:
            HasAnswer_list = []
            for idx_doc in range(n_docs):
                row = []
                for i in range(batch_size):
                    candidates = docs_by_question[ex_id[i]]
                    document = candidates[idx_doc % len(candidates)]["document"]
                    row.append(has_answer(args, exs_with_doc[ex_id[i]]['answer'],
                                          document))
                HasAnswer_list.append(row)
            HasAnswer_Map[idx] = HasAnswer_list

        # Uniform weights drawn without replacement -> a shuffled order of
        # the document slots.
        order = torch.multinomial(torch.ones(int(n_docs)), int(n_docs))

        HasAnswer_list_sample = [HasAnswer_list[d] for d in order]
        ex_with_doc_sample = [ex_with_doc[d] for d in order]
        # Gold spans where the document holds the answer, (-1, -1) otherwise.
        l_list_doc = [
            [HasAnswer_list[d][i][1] if HasAnswer_list[d][i][0] else (-1, -1)
             for i in range(batch_size)]
            for d in order
        ]
        r_list_doc = [[] for _ in order]

        top_n = 1
        pred_s_list_doc = []
        pred_e_list_doc = []
        for d in order:
            pred_s, pred_e, _ = model.predict(ex_with_doc[d], top_n=top_n)
            pred_s_list_doc.append(
                torch.LongTensor([pred_s[i].tolist() for i in range(batch_size)]))
            pred_e_list_doc.append(
                torch.LongTensor([pred_e[i].tolist() for i in range(batch_size)]))

        train_loss.update(*model.update_with_doc(
            update_step, ex_with_doc_sample, pred_s_list_doc, pred_e_list_doc,
            top_n, l_list_doc, r_list_doc, HasAnswer_list_sample))
        update_step = (update_step + 1) % 4

        if idx % args.display_iter == 0:
            logger.info(
                'train: Epoch = %d | iter = %d/%d | loss = %.2f | elapsed time = %.2f (s)'
                % (global_stats['epoch'], idx, len(data_loader),
                   train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()
        if idx % 200 == 199:
            validate_unofficial_with_doc(args, data_loader, model, global_stats,
                                         exs_with_doc, docs_by_question, 'train')

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)'
                % (global_stats['epoch'], epoch_time.time()))
    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
def update_evidence(args, data_loader, model, global_stats, exs_with_doc,
                    docs_by_question):
    """Score one epoch of examples and promote the most confident ones to evidence labels.

    Every batch is fed to ``model.update_with_doc(..., return_prob=True)``
    with the document slots in their natural order (no shuffling). That call
    returns per-example probabilities and a pair of per-example attention
    results (value, index). After the epoch, examples whose attention index
    is not -1 are ranked by attention value, highest first. Up to
    ``args.top_k`` of them that are still unlabelled (-1 in the module-level
    ``Evidence_Label``) receive their attention index as their label.

    Raises:
        ValueError: if a (batch, example) key is scored twice.
    """
    Top_k = args.top_k
    logger.info('Top k is set to %d' % (Top_k))
    Probability = {}       # "batch|example" -> model probability
    Attention_Weight = {}  # "batch|example" -> (max_value, max_index)
    # Initialize meters + timers
    train_prob = utils.AverageMeter()
    train_attention = utils.AverageMeter()
    epoch_time = utils.Timer()
    # Run one epoch
    update_step = 0
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]
        if idx not in HasAnswer_Map:
            HasAnswer_list = []
            for idx_doc in range(0, vector.num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    docs = docs_by_question[ex_id[i]]
                    HasAnswer.append(has_answer(
                        args, exs_with_doc[ex_id[i]]['answer'],
                        docs[idx_doc % len(docs)]["document"]))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        if idx not in Evidence_Label:
            Evidence_Label[idx] = [-1] * batch_size  # -1 = not labelled yet

        # Don't shuffle when updating evidence.
        idx_random = range(vector.num_docs)
        HasAnswer_list_sample = [HasAnswer_list[idx_doc] for idx_doc in idx_random]
        ex_with_doc_sample = [ex_with_doc[idx_doc] for idx_doc in idx_random]
        # Gold spans where the document holds the answer, (-1, -1) otherwise.
        l_list_doc = [
            [HasAnswer_list[idx_doc][i][1] if HasAnswer_list[idx_doc][i][0]
             else (-1, -1) for i in range(batch_size)]
            for idx_doc in idx_random
        ]
        r_list_doc = [[] for _ in idx_random]

        tmp_top_n = 1
        pred_s_list_doc = []
        pred_e_list_doc = []
        for idx_doc in idx_random:
            pred_s, pred_e, _ = model.predict(ex_with_doc[idx_doc],
                                              top_n=tmp_top_n)
            pred_s_list_doc.append(torch.LongTensor(
                [pred_s[i].tolist() for i in range(batch_size)]))
            pred_e_list_doc.append(torch.LongTensor(
                [pred_e[i].tolist() for i in range(batch_size)]))

        probs, attentions = model.update_with_doc(
            update_step, ex_with_doc_sample, pred_s_list_doc, pred_e_list_doc,
            tmp_top_n, l_list_doc, r_list_doc, HasAnswer_list_sample,
            return_prob=True)
        train_prob.update(np.mean(probs), batch_size)
        train_attention.update(np.mean(attentions[0]), batch_size)
        update_step = (update_step + 1) % 4

        if idx % args.display_iter == 0:
            logger.info('Update Evidence: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'Average prob = %f | Average attention = %f | elapsed time = %.2f (s)' %
                        (train_prob.avg, train_attention.avg,
                         global_stats['timer'].time()))

        for i in range(batch_size):
            key = "%d|%d" % (idx, i)
            if key in Probability or key in Attention_Weight:
                raise ValueError("%s exists in Probability or Attention_Weight" % (key))
            # Add threshold here
            Probability[key] = probs[i]
            Attention_Weight[key] = (attentions[0][i], attentions[1][i])  # max_value, max_index

    # Rank examples that got a valid attention index (-1 means none) by
    # attention value, highest first.
    evidence_scores = sorted(
        ((key, attention) for key, attention in Attention_Weight.items()
         if attention[1] != -1),
        key=lambda item: item[1][0], reverse=True)

    count = 0
    label_prob = []
    label_attention = []
    for key, value in evidence_scores:
        # Check the budget before labelling so that top_k <= 0 labels nothing.
        if count >= Top_k:
            break
        batch_idx, example_idx = (int(part) for part in key.split('|'))
        if Evidence_Label[batch_idx][example_idx] != -1:
            continue  # already labelled previously
        count += 1
        Evidence_Label[batch_idx][example_idx] = value[1]
        label_prob.append(Probability[key])
        label_attention.append(value[0])

    # np.mean([]) would warn and return nan; log nan explicitly instead.
    mean_prob = np.mean(label_prob) if label_prob else float('nan')
    mean_attention = np.mean(label_attention) if label_attention else float('nan')
    logger.info('Update Evidence: Epoch %d done. Time for epoch = %.2f (s). Average prob = %f. Average attention = %f.' %
                (global_stats['epoch'], epoch_time.time(), train_prob.avg,
                 train_attention.avg))
    logger.info('Update Evidence: Label %d examples. Average prob = %f. Average attention = %f.' %
                (count, mean_prob, mean_attention))
def train(args, data_loader, model, global_stats, exs_with_doc, docs_by_question):
    '''Train selector and reader jointly for one epoch over ``data_loader``.

    Per batch: fetch (or compute and cache in ``HasAnswer_Map``) whether each
    document slot contains each example's answer, together with the answer
    positions. Then visit the document slots in a random order, collect the
    reader's top-1 span per slot, and hand everything to
    ``model.update_with_doc``. ``update_step`` cycles through 0..3 and is
    forwarded to the model.

    When ``args.show_cuda_stats`` is set, CUDA memory is reported on the last
    iteration of every ``args.display_stats`` window. At those same
    iterations, a validation pass is also run on the training loader. A
    checkpoint is written at the end of the epoch when ``args.checkpoint`` is
    set.
    '''
    global HasAnswer_Map
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    update_step = 0

    for idx, ex_with_doc in enumerate(data_loader):
        head = ex_with_doc[0]
        batch_size, ex_id = head[0].size(0), head[-1]

        # GPU usage statistics on the last iteration of each window.
        show_stats = (args.show_cuda_stats and
                      (idx % args.display_stats == args.display_stats - 1))

        if idx in HasAnswer_Map:
            HasAnswer_list = HasAnswer_Map[idx]
        else:
            HasAnswer_list = []
            for idx_doc in range(num_docs):
                row = []
                for i in range(batch_size):
                    candidates = docs_by_question[ex_id[i]]
                    document = candidates[idx_doc % len(candidates)]['document']
                    # Presence of the answer AND its positions are both kept.
                    row.append(has_answer(args, exs_with_doc[ex_id[i]]['answer'],
                                          document))
                HasAnswer_list.append(row)
            HasAnswer_Map[idx] = HasAnswer_list

        # Uniform weights drawn without replacement -> shuffled slot order.
        order = torch.multinomial(torch.ones(int(num_docs)), int(num_docs))

        HasAnswer_list_sample = [HasAnswer_list[d] for d in order]
        ex_with_doc_sample = [ex_with_doc[d] for d in order]
        # Gold spans where the document holds the answer, (-1, -1) otherwise.
        l_list_doc = [
            [HasAnswer_list[d][i][1] if HasAnswer_list[d][i][0] else (-1, -1)
             for i in range(batch_size)]
            for d in order
        ]
        r_list_doc = [[] for _ in order]

        top_n = 1
        txt_cuda(show_stats, 'before forward pass')
        pred_s_list_doc = []
        pred_e_list_doc = []
        for d in order:
            pred_s, pred_e, _ = model.predict(ex_with_doc[d], top_n=top_n)
            pred_s_list_doc.append(torch.tensor(
                [pred_s[i].tolist() for i in range(batch_size)], dtype=torch.long))
            pred_e_list_doc.append(torch.tensor(
                [pred_e[i].tolist() for i in range(batch_size)], dtype=torch.long))
        txt_cuda(show_stats, 'before backpropagation')

        # One optimisation step on this batch.
        train_loss.update(*model.update_with_doc(
            update_step, ex_with_doc_sample, pred_s_list_doc, pred_e_list_doc,
            top_n, l_list_doc, r_list_doc, HasAnswer_list_sample))
        update_step = (update_step + 1) % 4

        txt_cuda(show_stats, 'after backpropagation')
        if show_stats:
            gpu_usage()

        if idx % args.display_iter == 0:
            logger.info(
                'train: Epoch = {} | iter = {}/{} | loss = {:.2f} | elapsed time = {:.2f} (s)'
                .format(global_stats['epoch'], idx, len(data_loader),
                        train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

        if show_stats:
            with torch.no_grad():
                validate_with_doc(args, data_loader, model, global_stats,
                                  exs_with_doc, docs_by_question, mode='train')

    separator = '-' * 100
    logger.info(separator)
    logger.info('train: Epoch {} done. Time for epoch = {:.2f} (s)'.format(
        global_stats['epoch'], epoch_time.time()))
    logger.info(separator)
    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
def pretrain_reader(args, data_loader, model, global_stats, exs_with_doc,
                    docs_by_question):
    '''Run through one epoch of reader pre-training with the provided data loader.

    For each document slot, the reader is updated toward the cached answer
    positions when the document contains the answer, and toward its own top-1
    predicted span otherwise. The log reports how many answer occurrences
    were found per answer-bearing (document, example) pair.
    '''
    global HasAnswer_Map
    # Initialize meters and timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    logger.info('pretrain_reader')

    count_ans = 0  # answer occurrences found in answer-bearing documents
    count_tot = 0  # (doc, example) pairs whose document contains the answer
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # Cache entries are the full has_answer() result: [0] is the found
        # flag, [1] the answer positions. Positions matter at this stage.
        if idx not in HasAnswer_Map:
            HasAnswer_list = []
            for idx_doc in range(0, num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                    answer = exs_with_doc[ex_id[i]]['answer']
                    document = docs_by_question[ex_id[i]][idx_doc_i]['document']
                    HasAnswer.append(has_answer(args, answer, document))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        for idx_doc in range(0, num_docs):
            l_list = []
            r_list = []  # always passed empty to model.update
            # Forward pass for the batch...
            pred_s, pred_e, _ = model.predict(ex_with_doc[idx_doc], top_n=1)
            for i in range(batch_size):
                entry = HasAnswer_list[idx_doc][i]
                if entry[0]:
                    # Count answer occurrences, not the (always-true) flag;
                    # counting the flag made count_ans == count_tot.
                    count_ans += len(entry[1])
                    count_tot += 1
                    # Store recorded answers' positions
                    l_list.append(entry[1])
                else:
                    # Store the most likely predicted position
                    l_list.append([(int(pred_s[i][0]), int(pred_e[i][0]))])
            # Model update: adjust weights to reduce the mismatch between
            # predicted and target answer positions.
            train_loss.update(*model.update(ex_with_doc[idx_doc], l_list,
                                            r_list, HasAnswer_list[idx_doc]))

        if idx % args.display_iter == 0:
            txt = 'train: Epoch = {} | iter = {}/{} | loss = {:.2f} | '
            txt += 'elapsed time = {:.2f} (s)'
            logger.info(
                txt.format(global_stats['epoch'], idx, len(data_loader),
                           train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()
            txt = ('count_ans: {} | count_tot: {} | '
                   'answer occurrences per answer doc: {:.2f}')
            logger.info(txt.format(count_ans, count_tot,
                                   count_ans / count_tot if count_tot else 0.0))

    logger.info('-' * 100)
    txt = 'train: Epoch {} done. Time for epoch = {:.2f} (s)'
    logger.info(txt.format(global_stats['epoch'], epoch_time.time()))
    logger.info('-' * 100)
def pretrain_selector(args, data_loader, model, global_stats, exs_with_doc,
                      docs_by_question):
    '''Run through one epoch of selector pre-training with the provided data loader.

    Each document slot is labelled 1/0 according to whether it contains the
    example's answer. The slots are presented to ``model.pretrain_selector``
    in a random order. The running share of answer-bearing
    (document, example) pairs is logged.
    '''
    global HasAnswer_Map
    # Initialize meters and timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    tot_ans = 0  # (doc, example) pairs whose document contains the answer
    tot_num = 0  # all (doc, example) pairs seen
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # HasAnswer_Map is shared with the reader / joint-training routines,
        # which read the answer positions ([1]) of each entry. Cache the full
        # has_answer() result: the selector itself only uses the flag ([0]),
        # but a truncated (flag,) entry would make those readers fail with
        # IndexError when they hit an answer-bearing document.
        if idx not in HasAnswer_Map:
            HasAnswer_list = []
            for idx_doc in range(0, num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                    answer = exs_with_doc[ex_id[i]]['answer']
                    document = docs_by_question[ex_id[i]][idx_doc_i]['document']
                    HasAnswer.append(has_answer(args, answer, document))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        # Update counters
        for idx_doc in range(0, num_docs):
            for i in range(batch_size):
                tot_ans += int(HasAnswer_list[idx_doc][i][0])
                tot_num += 1

        # Uniform weights drawn without replacement (torch default): every
        # document slot is used once, in a random order.
        idx_random = torch.multinomial(torch.ones(int(num_docs)), int(num_docs))
        HasAnswer_list_sample = []
        ex_with_doc_sample = []
        for idx_doc in idx_random:
            HasAnswer_list_sample.append(
                [HasAnswer_list[idx_doc][i][0] for i in range(batch_size)])
            ex_with_doc_sample.append(ex_with_doc[idx_doc])
        HasAnswer_list_sample = torch.tensor(HasAnswer_list_sample,
                                             dtype=torch.long)

        train_loss.update(*model.pretrain_selector(ex_with_doc_sample,
                                                   HasAnswer_list_sample))

        if idx % args.display_iter == 0:
            txt = 'train: Epoch = {} | iter = {}/{} | loss = {:.2f} | '
            txt += 'elapsed time = {:.2f} (s)'
            logger.info(
                txt.format(global_stats['epoch'], idx, len(data_loader),
                           train_loss.avg, global_stats['timer'].time()))
            txt = 'tot_ans: {} | tot_num: {} | tot_ans/tot_num: {:.1f} (%)'
            logger.info(txt.format(tot_ans, tot_num,
                                   tot_ans * 100.0 / tot_num if tot_num else 0.0))
            train_loss.reset()

    logger.info('-' * 100)
    txt = 'tot_ans: {} | tot_num: {}'
    logger.info(txt.format(tot_ans, tot_num))
    txt = 'train: Epoch {} done. Time for epoch = {:.2f} (s)'
    logger.info(txt.format(global_stats['epoch'], epoch_time.time()))
    logger.info('-' * 100)