def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader.

    Args:
        args: command-line options; uses display_iter, checkpoint, model_file.
        data_loader: iterable of training batches.
        model: model wrapper exposing update() and checkpoint().
        global_stats: dict with 'epoch' (int) and 'timer' (running utils.Timer).
    """
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))

        if idx % args.display_iter == 0:
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info('train: Epoch = %d | iter = %d/%d | '
                        'loss = %.2f | elapsed time = %.2f (s)',
                        global_stats['epoch'], idx, len(data_loader),
                        train_loss.avg, global_stats['timer'].time())
            # Reset so each report shows the loss since the previous report.
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)',
                global_stats['epoch'], epoch_time.time())

    # Checkpoint so a crash can resume from the *next* epoch.
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.

    Unofficial = doesn't use SQuAD script.

    NOTE(review): this function is re-defined later in this file; the later
    definition shadows this one at call time.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()

    # Make predictions
    examples = 0
    # BUG FIX: 'attacked' was assigned (not accumulated) inside the loop, so
    # the value printed after the loop covered only the final batch, and was
    # unbound if the loader was empty. Initialize here and use '+='.
    attacked = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex)
        target = ex[-2:-1]

        if args.global_mode == "test":
            preds = np.array([p[0] for p in pred])
            # Index of the last sentence per example; presumably the appended
            # adversarial sentence — TODO confirm against data pipeline.
            sent_lengths = (ex[3].sum(1) - 1).long().data.numpy()
            attacked += (preds == sent_lengths).sum()

        # We get metrics for independent start/end and joint start/end
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) +
                'examples = %d | ' % (examples) +
                'valid time = %.2f (s)' % eval_time.time())
    if args.global_mode == "test":
        print(attacked)
        print(examples)
    return {'accuracy': acc.avg}
def validate_official(args, data_loader, model, global_stats,
                      offsets, texts, answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for batch in data_loader:
        qids = batch[-1]
        n_in_batch = batch[0].size(0)
        pred_s, pred_e, _ = model.predict(batch)

        for idx in range(n_in_batch):
            qid = qids[idx]
            # Map predicted token span back to character offsets in the raw text.
            token_offsets = offsets[qid]
            start_char = token_offsets[pred_s[idx][0]][0]
            end_char = token_offsets[pred_e[idx][0]][1]
            span_text = texts[qid][start_char:end_char]

            # Compute metrics against every accepted answer; keep the best.
            truths = answers[qid]
            em = utils.metric_max_over_ground_truths(
                utils.exact_match_score, span_text, truths)
            f1_val = utils.metric_max_over_ground_truths(
                utils.f1_score, span_text, truths)
            exact_match.update(em)
            f1.update(f1_val)

        examples += n_in_batch

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
def main(args):
    """Top-level driver: load data, build/restore the model, construct the
    data loaders, then run the train/validate loop (or a one-shot evaluation
    when args.global_mode == "test")."""
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = []
    # args.train_file is a list: multiple training files are concatenated.
    for t_file in args.train_file:
        train_exs += utils.load_data(args, t_file, skip_no_answer=True)
    # Shuffle training examples
    np.random.shuffle(train_exs)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = SentenceSelector.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = SentenceSelector.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)
        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings: only the top-N most frequent
        # question words keep trainable embeddings; the rest stay fixed.
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            # Log a small sample of the tuned vocabulary for sanity checking.
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer (fresh models only; a checkpoint restores its own).
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.SentenceSelectorDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        # Batches of similar-length examples minimize padding work.
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.SentenceSelectorDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    ## allow toggle mode that will let you evaluate on whatever dev set you give it; preload model
    if args.global_mode == "test":
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')
        print(result[args.valid_metric])
        exit(0)

    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')

        # Validate official
        # NOTE(review): when official_eval is on, the official result replaces
        # the unofficial one for model selection below.
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
        # Periodic unconditional snapshot every 5 epochs.
        if epoch % 5 == 0:
            model.save(args.model_file + ".dummy")
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation (adversarial-analysis variant).

    Unofficial = doesn't use SQuAD script.

    NOTE(review): this redefines validate_unofficial from earlier in the file;
    being defined later, this is the version that actually runs.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()

    # Make predictions
    examples = 0
    attacked = 0          # top-n predictions that picked the last sentence
    attacked_correct = 0  # ... while the top-n also contained a gold sentence
    correct = 0           # top-n predictions that hit a gold sentence
    non_adv = 0           # examples skipped: id lacks "high" marker
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex, top_n=3)
        target = ex[-2:-1]

        if args.global_mode == "test":
            # Sentence count per example; the adversarial sentence is
            # presumably appended last (index sent_lengths - 1) — TODO confirm.
            sent_lengths = (ex[3].sum(1)).long().data.numpy()
            for enum_, p in enumerate(pred):
                true_flag = False
                # Only ids containing "high" are treated as adversarial here.
                if "high" not in ex[-1][enum_]:
                    non_adv += 1
                    continue
                # First pass: count gold hits and decide true_flag for the
                # whole top-n list before the second pass uses it.
                for q in p:
                    if q in target[0][enum_]:
                        correct += 1
                        true_flag = True
                for q in p:
                    if q == sent_lengths[enum_] - 1:
                        attacked += 1
                        if true_flag:
                            attacked_correct += 1

        # We get metrics for independent start/end and joint start/end
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if examples % 1000 == 0:
            print("%d examples completed" % examples)
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) +
                'examples = %d | ' % (examples) +
                'valid time = %.2f (s)' % eval_time.time())

    if args.global_mode == "test":
        # Fixed typos in the report strings ("succesfully", garbled non-adv line).
        print("Number of examples attacked successfully: %d" % attacked)
        print("Total number of correct adversarial examples: %d" % correct)
        print("Number of non-adversarial examples: %d" % non_adv)
        print("Number of correct examples attacked successfully: %d" % attacked_correct)
        print("Number of examples: %d" % examples)
    return {'accuracy': acc.avg}