def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    ml_loss = AverageMeter()
    epoch_time = Timer()

    # Decay the learning rate once per epoch
    model.optimizer.param_groups[0]['lr'] = \
        model.optimizer.param_groups[0]['lr'] * args.lr_decay

    pbar = tqdm(data_loader)
    pbar.set_description('Epoch = %d [ml_loss = x.xx]' % global_stats['epoch'])

    # Run one epoch
    for idx, ex in enumerate(pbar):
        bsz = ex['batch_size']
        net_loss = model.update(ex)
        ml_loss.update(net_loss.item(), bsz)
        log_info = 'Epoch = %d [ml_loss = %.2f]' % \
                   (global_stats['epoch'], ml_loss.avg)
        pbar.set_description(log_info)

    torch.cuda.empty_cache()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
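
# NOTE: `AverageMeter` and `Timer` are utilities defined elsewhere in this
# repository. A minimal sketch consistent with how they are used above
# (`update(value, n)` / `.avg`, and `.time()` returning elapsed seconds) could
# look like the following; the repository's own implementations may differ.
import time


class AverageMeter(object):
    """Tracks a running average; update(val, n) adds n observations of val."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


class Timer(object):
    """Measures wall-clock time (in seconds) since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start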
def eval_accuracies(hypotheses, references, copy_info, sources=None,
                    filename=None, print_copy_info=False):
    """An unofficial evaluation helper.

    Arguments:
        hypotheses: A mapping from instance id to predicted sequences.
        references: A mapping from instance id to ground truth sequences.
        copy_info: Map of id --> copy information.
        sources: Map of id --> input text sequence.
        filename: If provided, predictions are logged to this file, one JSON
            object per line.
        print_copy_info: If True, annotate each predicted token with its copy
            information before logging.
    """
    assert sorted(references.keys()) == sorted(hypotheses.keys())

    # Compute BLEU scores
    bleu_scorer = Bleu(n=4)
    bleu, ind_bleu = bleu_scorer.compute_score(references, hypotheses,
                                               verbose=0)

    # Compute ROUGE scores
    rouge_calculator = Rouge()
    rouge_l, ind_rouge = rouge_calculator.compute_score(references, hypotheses)

    f1 = AverageMeter()
    exact_match = AverageMeter()
    fw = open(filename, 'w') if filename else None
    for key in references.keys():
        exact_match.update(
            metric_max_over_ground_truths(exact_match_score,
                                          hypotheses[key][0],
                                          references[key]))
        f1.update(
            metric_max_over_ground_truths(f1_score,
                                          hypotheses[key][0],
                                          references[key]))
        if fw:
            if copy_info is not None and print_copy_info:
                prediction = hypotheses[key][0].split()
                pred_i = [word + ' [' + str(copy_info[key][j]) + ']'
                          for j, word in enumerate(prediction)]
                pred_i = [' '.join(pred_i)]
            else:
                pred_i = hypotheses[key]

            logobj = OrderedDict()
            logobj['session_id'] = key
            if sources is not None:
                logobj['previous_queries'] = sources[key]
            logobj['predictions'] = pred_i
            # NOTE: `args` here is the module-level namespace, not a parameter
            # of this function.
            logobj['references'] = references[key][0] if args.print_one_target \
                else references[key]
            logobj['bleu'] = ind_bleu[key]
            fw.write(json.dumps(logobj) + '\n')

    if fw:
        fw.close()
    return bleu, rouge_l * 100, exact_match.avg * 100, f1.avg * 100
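
# NOTE: A minimal, hypothetical usage sketch of `eval_accuracies` on toy data
# (the keys and strings below are made up, not from the original script).
def _example_eval_accuracies():
    """Illustrative only; each value is a list of strings keyed by instance id."""
    hypotheses = {'q1': ['how to reset password'],
                  'q2': ['weather in boston today']}
    references = {'q1': ['how do i reset my password'],
                  'q2': ['boston weather today']}
    # With filename=None nothing is written to disk, so the module-level
    # `args` dependency is never touched. BLEU comes back unscaled (callers
    # multiply it by 100); ROUGE-L, EM and F1 are already percentages.
    bleu, rouge_l, em, f1 = eval_accuracies(hypotheses, references,
                                            copy_info=None, filename=None)
    print(bleu, rouge_l, em, f1)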
def validate_official(args, data_loader, model, global_stats=None):
    """Run one full validation pass over the data loader.

    Computes document ranking metrics (MAP, MRR, Prec@1/3/5) from the model's
    click scores, and generation metrics (BLEU, ROUGE-L, exact match, F1) from
    its predicted sequences. Exact match and F1 follow the same computation as
    in the SQuAD evaluation script.
    """
    eval_time = Timer()

    # Run through examples
    examples = 0
    map = AverageMeter()
    mrr = AverageMeter()
    prec_1 = AverageMeter()
    prec_3 = AverageMeter()
    prec_5 = AverageMeter()
    sources, hypotheses, references = dict(), dict(), dict()

    with torch.no_grad():
        pbar = tqdm(data_loader)
        for ex in pbar:
            batch_size = ex['batch_size'] * ex['session_len']
            outputs = model.predict(ex)

            # Document ranking metrics
            scores = outputs['click_scores'].view(batch_size, -1).contiguous()
            labels = ex['document_labels'].view(batch_size, -1).contiguous().numpy()
            predictions = np.argsort(
                -scores.cpu().numpy())  # sort in descending order
            map.update(MAP(predictions, labels))
            mrr.update(MRR(predictions, labels))
            prec_1.update(precision_at_k(predictions, labels, 1))
            prec_3.update(precision_at_k(predictions, labels, 3))
            prec_5.update(precision_at_k(predictions, labels, 5))

            # Collect generated sequences for BLEU/ROUGE/EM/F1
            ex_ids = outputs['ex_ids']
            predictions = outputs['predictions']
            targets = outputs['targets']
            src_sequences = outputs['src_sequences']
            examples += batch_size
            for key, src, pred, tgt in zip(ex_ids, src_sequences,
                                           predictions, targets):
                hypotheses[key] = [normalize_string(p) for p in pred] \
                    if isinstance(pred, list) else [normalize_string(pred)]
                references[key] = [normalize_string(t) for t in tgt]
                sources[key] = src

            if global_stats is not None:
                pbar.set_description('Epoch = %d [validating ... ]' %
                                     global_stats['epoch'])
            else:
                pbar.set_description('[evaluating ... ]')

    bleu, rouge, exact_match, f1 = eval_accuracies(
        hypotheses, references, None, sources=sources,
        filename=args.pred_file, print_copy_info=args.print_copy_info)
    bleu = [b * 100 for b in bleu] if isinstance(bleu, list) else bleu

    result = dict()
    result['rouge'] = rouge
    result['bleu'] = sum(bleu) / len(bleu) if isinstance(bleu, list) else bleu
    result['em'] = exact_match
    result['f1'] = f1
    result['map'] = map.avg
    result['mrr'] = mrr.avg
    result['prec@1'] = prec_1.avg
    result['prec@3'] = prec_3.avg
    result['prec@5'] = prec_5.avg

    if global_stats is None:
        logger.info(
            'test results: MAP = %.2f | MRR = %.2f | Prec@1 = %.2f | ' %
            (result['map'], result['mrr'], result['prec@1']) +
            'Prec@3 = %.2f | Prec@5 = %.2f | ' %
            (result['prec@3'], result['prec@5']) +
            'rouge_l = %.2f | bleu = [%s] | ' %
            (rouge, ", ".join(format(b, ".2f") for b in bleu)) +
            'EM = %.2f | F1 = %.2f | examples = %d | ' %
            (exact_match, f1, examples) +
            'test time = %.2f (s)' % eval_time.time())
    else:
        logger.info(
            'dev results: MAP = %.2f | MRR = %.2f | Prec@1 = %.2f | ' %
            (result['map'], result['mrr'], result['prec@1']) +
            'Prec@3 = %.2f | Prec@5 = %.2f | ' %
            (result['prec@3'], result['prec@5']) +
            'rouge_l = %.2f | bleu = [%s] | ' %
            (rouge, ", ".join(format(b, ".2f") for b in bleu)) +
            'EM = %.2f | F1 = %.2f | examples = %d | ' %
            (exact_match, f1, examples) +
            'valid time = %.2f (s)' % eval_time.time())

    return result
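
# NOTE: `MAP`, `MRR` and `precision_at_k` are imported from elsewhere in the
# repository. A minimal sketch consistent with how they are called here, where
# `predictions` holds per-row document indices sorted by descending score
# (output of np.argsort) and `labels` holds binary relevance, could look like
# the following; it is an illustration, not the repository's implementation.
import numpy as np


def precision_at_k(predictions, labels, k):
    """Mean precision@k over a batch of ranked lists."""
    scores = []
    for ranked, relevance in zip(predictions, labels):
        top_k = ranked[:k]
        scores.append(relevance[top_k].sum() / float(k))
    return float(np.mean(scores))


def MRR(predictions, labels):
    """Mean reciprocal rank of the first relevant document per ranked list."""
    scores = []
    for ranked, relevance in zip(predictions, labels):
        reciprocal = 0.0
        for rank, doc in enumerate(ranked, start=1):
            if relevance[doc]:
                reciprocal = 1.0 / rank
                break
        scores.append(reciprocal)
    return float(np.mean(scores))


def MAP(predictions, labels):
    """Mean average precision over a batch of ranked lists."""
    scores = []
    for ranked, relevance in zip(predictions, labels):
        hits, precisions = 0, []
        for rank, doc in enumerate(ranked, start=1):
            if relevance[doc]:
                hits += 1
                precisions.append(hits / float(rank))
        scores.append(np.mean(precisions) if precisions else 0.0)
    return float(np.mean(scores))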
def validate_official(args, data_loader, model, global_stats=None):
    """Run one full validation pass over the data loader.

    Computes document ranking metrics (MAP, MRR, Prec@1/3/5) from the model's
    relevance scores against the ground-truth labels.
    """
    eval_time = Timer()

    # Run through examples
    examples = 0
    map = AverageMeter()
    mrr = AverageMeter()
    prec_1 = AverageMeter()
    prec_3 = AverageMeter()
    prec_5 = AverageMeter()

    with torch.no_grad():
        pbar = tqdm(data_loader)
        for ex in pbar:
            ids, batch_size = ex['ids'], ex['batch_size']
            scores = model.predict(ex)
            predictions = np.argsort(
                -scores.cpu().numpy())  # sort in descending order
            labels = ex['label'].numpy()

            map.update(MAP(predictions, labels))
            mrr.update(MRR(predictions, labels))
            prec_1.update(precision_at_k(predictions, labels, 1))
            prec_3.update(precision_at_k(predictions, labels, 3))
            prec_5.update(precision_at_k(predictions, labels, 5))

            if global_stats is None:
                pbar.set_description('[testing ... ]')
            else:
                pbar.set_description('Epoch = %d [validating ... ]' %
                                     global_stats['epoch'])
            examples += batch_size

    result = dict()
    result['map'] = map.avg
    result['mrr'] = mrr.avg
    result['prec@1'] = prec_1.avg
    result['prec@3'] = prec_3.avg
    result['prec@5'] = prec_5.avg

    if global_stats is None:
        logger.info(
            'test results: MAP = %.2f | MRR = %.2f | Prec@1 = %.2f | ' %
            (result['map'], result['mrr'], result['prec@1']) +
            'Prec@3 = %.2f | Prec@5 = %.2f | examples = %d | ' %
            (result['prec@3'], result['prec@5'], examples) +
            'time elapsed = %.2f (s)' % eval_time.time())
    else:
        logger.info(
            'valid official: Epoch = %d | MAP = %.2f | ' %
            (global_stats['epoch'], result['map']) +
            'MRR = %.2f | Prec@1 = %.2f | Prec@3 = %.2f | ' %
            (result['mrr'], result['prec@1'], result['prec@3']) +
            'Prec@5 = %.2f | examples = %d | valid time = %.2f (s)' %
            (result['prec@5'], examples, eval_time.time()))

    return result
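
# NOTE: A hypothetical driver loop showing how the functions above fit
# together. The names `train_loader`, `dev_loader`, `args.num_epochs` and
# `model.save(...)` are assumptions about the surrounding training script,
# not part of the original code.
def _example_training_loop(args, model, train_loader, dev_loader, start_epoch=0):
    stats = {'epoch': start_epoch, 'best_valid': 0.0}
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch
        train(args, train_loader, model, stats)
        result = validate_official(args, dev_loader, model, global_stats=stats)
        if result['map'] > stats['best_valid']:
            stats['best_valid'] = result['map']
            model.save(args.model_file)  # assumed persistence helper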