def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))
        # writer.add_scalar("loss", train_loss.avg, idx)
        # for name, param in model.network.named_parameters():
        #     writer.add_histogram(name, param.clone().cpu().data.numpy(), idx)

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
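# For reference, a minimal sketch of the meter/timer helpers used above.
# These are assumptions about the project's `utils` module, not its actual
# implementation (the real Timer, for instance, may support pause/resume).

import time


class AverageMeter(object):
    """Tracks a running sum and count, exposing the mean as `.avg`."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # `n` lets a batch of size n count as n observations
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Timer(object):
    """Measures wall-clock seconds elapsed since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start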
def validate_official(args, data_loader, model, global_stats,
                      offsets, texts, answers):
    """Run one full official validation.

    Uses exact spans and the same exact match/F1 score computation as in
    the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of the example's context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
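# The metric helpers above follow the official SQuAD v1.1 evaluation script.
# A minimal sketch of what `utils` is assumed to provide:

import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every accepted answer; keep the best."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)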
def eval_accuracies(pred, target, mode="dev"):
    """An unofficial evaluation helper.

    Computes the fraction of examples whose gold target appears among the
    top-n predictions. (The original start/end/exact-match variants are
    kept below, commented out.)
    """
    # Convert 1D tensors to lists of lists (compatibility)
    if torch.is_tensor(target):
        target = [[e.item()] for e in target]
    elif torch.is_tensor(target[0]):
        target = [[e.item()] for e in target[0]]
    else:
        target = target[0]

    # Compute accuracies from targets
    batch_size = len(pred)
    accuracy = utils.AverageMeter()
    # end = utils.AverageMeter()
    # em = utils.AverageMeter()
    for i in range(batch_size):
        # Count a hit if any of the top-n predictions matches a gold target
        if any(j in target[i] for j in pred[i]):
            accuracy.update(1)
        else:
            accuracy.update(0)

        # End matches
        # if pred_e[i] in target_e[i]:
        #     end.update(1)
        # else:
        #     end.update(0)

        # Both start and end match
        # if any([1 for _s, _e in zip(target_s[i], target_e[i])
        #         if _s == pred_s[i] and _e == pred_e[i]]):
        #     em.update(1)
        # else:
        #     em.update(0)

    return accuracy.avg * 100
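# A quick usage sketch with hypothetical values -- `pred` holds the top-n
# predicted indices per example, `target` is wrapped the way `ex[-2:-1]` is:
#
#     preds = [[0, 3, 5], [2, 1, 4]]      # top-3 predictions, batch of 2
#     targets = [torch.tensor([3, 2])]    # gold index per example
#     eval_accuracies(preds, targets)     # -> 100.0 (both hit within top-3)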
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.

    Unofficial = doesn't use the SQuAD script.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()
    # end_acc = utils.AverageMeter()
    # exact_match = utils.AverageMeter()

    # Make predictions
    examples = 0
    attacked = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex)
        target = ex[-2:-1]

        if args.global_mode == "test":
            # Count top-1 predictions that land on the appended adversarial
            # sentence (the last sentence in the context), accumulated
            # across batches.
            preds = np.array([p[0] for p in pred])
            sent_lengths = (ex[3].sum(1) - 1).long().data.numpy()
            attacked += int((preds == sent_lengths).sum())

        # Top-n sentence selection accuracy
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)
        # end_acc.update(accuracies[1], batch_size)
        # exact_match.update(accuracies[2], batch_size)

        # If getting train accuracies, sample at most 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) +
                'examples = %d | valid time = %.2f (s)' %
                (examples, eval_time.time()))

    if args.global_mode == "test":
        print(attacked)
        print(examples)

    return {'accuracy': acc.avg}
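# Sketch of the `ex[3]` assumption behind the attack check above: a binary
# sentence mask of shape (batch, max_sentences) whose row sums give each
# example's sentence count, so the appended adversarial sentence sits at
# index `sum - 1`. Hypothetical values:
#
#     mask = torch.tensor([[1, 1, 1, 0],   # 3 sentences -> adversarial idx 2
#                          [1, 1, 1, 1]])  # 4 sentences -> adversarial idx 3
#     (mask.sum(1) - 1).long()             # tensor([2, 3])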
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation (adversarial variant).

    Unofficial = doesn't use the SQuAD script. NOTE: this definition shadows
    the one above; it additionally tracks how often the model is fooled into
    selecting the appended adversarial sentence.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()
    # end_acc = utils.AverageMeter()
    # exact_match = utils.AverageMeter()
    # fout = open(os.path.join(DATA_DIR, DUMP_FILE), "w+")

    # Make predictions
    examples = 0
    attacked = 0
    attacked_correct = 0
    correct = 0
    non_adv = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex, top_n=3)
        target = ex[-2:-1]

        if args.global_mode == "test":
            sent_lengths = ex[3].sum(1).long().data.numpy()
            for enum_, p in enumerate(pred):
                # fout.write("%s\t%d\t%d\t%d\n" % (ex[-1][enum_], p[0], p[1], p[2]))
                true_flag = False
                # Only ids containing "high" carry an adversarial sentence
                if "high" not in ex[-1][enum_]:
                    non_adv += 1
                    continue
                for q in p:
                    if q in target[0][enum_]:
                        correct += 1
                        true_flag = True
                        break  # count each example at most once
                for q in p:
                    # The adversarial sentence is the last one in the context
                    if q == sent_lengths[enum_] - 1:
                        attacked += 1
                        if true_flag:
                            attacked_correct += 1

        # Top-n sentence selection accuracy
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)

        # If getting train accuracies, sample at most 10k
        examples += batch_size
        if examples % 1000 == 0:
            print("%d examples completed" % examples)
        if mode == 'train' and examples >= 1e4:
            break

    # fout.close()
    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) +
                'examples = %d | valid time = %.2f (s)' %
                (examples, eval_time.time()))

    if args.global_mode == "test":
        print("Number of examples attacked successfully: %d" % attacked)
        print("Total number of correct adversarial examples: %d" % correct)
        print("Number of non-adversarial examples (skipped): %d" % non_adv)
        print("Number of correct examples attacked successfully: %d" % attacked_correct)
        print("Number of examples: %d" % examples)

    return {'accuracy': acc.avg}
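# A toy walk-through of the attack bookkeeping above (hypothetical values).
# Suppose an adversarial example has 5 sentences, so the appended
# adversarial sentence has index 4 (= sent_length - 1):
#
#     top3 = [4, 1, 0]                 # model's top-3 sentence predictions
#     gold = [1]                       # gold answer sentence indices
#     any(q in gold for q in top3)     # True -> correct += 1
#     4 in top3                        # True -> attacked += 1
#     # both hold here                 #      -> attacked_correct += 1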