def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = []
    for t_file in args.train_file:
        train_exs += utils.load_data(args, t_file, skip_no_answer=True)

    # Shuffle training examples
    np.random.shuffle(train_exs)
    logger.info('Num train examples = %d' % len(train_exs))

    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = SentenceSelector.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = SentenceSelector.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)
        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()
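    # A minimal sketch of the partial-tuning idea above, assuming the tuned
    # (most frequent) words occupy rows 0..k-1 of the embedding matrix. The
    # names below are illustrative, not this repo's API:
    #
    #   import torch.nn as nn
    #
    #   def freeze_all_but_top_k(embedding: nn.Embedding, k: int):
    #       def zero_frozen_rows(grad):
    #           grad = grad.clone()
    #           grad[k:] = 0          # rows >= k receive no update
    #           return grad
    #       embedding.weight.register_hook(zero_frozen_rows)
    #
    # DrQA-style code may instead reset the frozen rows to their pretrained
    # values after each optimizer step; either way, rare-word embeddings stay
    # fixed while frequent question words are fine-tuned.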
    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.SentenceSelectorDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.SentenceSelectorDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    # Test mode: evaluate a preloaded model on whatever dev set was given,
    # then exit without training.
    if args.global_mode == "test":
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')
        print(result[args.valid_metric])
        exit(0)

    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')

        # Validate official
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]

        if epoch % 5 == 0:
            model.save(args.model_file + ".dummy")
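# For the official eval above, predicted token spans must be mapped back to
# answer text via character offsets (utils.load_text / load_answers supply the
# real inputs). A self-contained sketch of that mapping; the function name and
# the example data below are illustrative, not part of the repo:

def span_to_text(text, offsets, start_tok, end_tok):
    """Recover the answer string for an inclusive token span."""
    start_char = offsets[start_tok][0]
    end_char = offsets[end_tok][1]
    return text[start_char:end_char]

# Example usage:
#   doc = 'The Denver Broncos won Super Bowl 50.'
#   offsets = [(0, 3), (4, 10), (11, 18), (19, 22),
#              (23, 28), (29, 33), (34, 36), (36, 37)]
#   span_to_text(doc, offsets, 1, 2)  # -> 'Denver Broncos'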
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = []
    for t_file in args.train_file:
        train_exs += utils.load_data(args, t_file, skip_no_answer=True)
    np.random.shuffle(train_exs)
    logger.info('Num train examples = %d' % len(train_exs))

    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)
        # NOTE: offsets come from the gold sentence; the predicted sentence's
        # offsets should be maintained and passed on to the official evaluation.

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)
        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()
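    # Checkpoints here bundle the model state with the epoch to resume from.
    # A minimal sketch of that save/resume contract, assuming plain torch
    # serialization (DocReader.checkpoint / load_checkpoint wrap this in the
    # repo; `network` and `optimizer` below are illustrative names):
    #
    #   torch.save({'state_dict': network.state_dict(),
    #               'optimizer': optimizer.state_dict(),
    #               'epoch': epoch + 1}, args.model_file + '.checkpoint')
    #
    #   saved = torch.load(args.model_file + '.checkpoint')
    #   network.load_state_dict(saved['state_dict'])
    #   optimizer.load_state_dict(saved['optimizer'])
    #   start_epoch = saved['epoch']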
    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    # Sentence selection objective: run the sentence selector as a submodule.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = reader_data.ReaderDataset(train_exs, model, single_answer=True)
    # Filter out None examples in the training dataset (where sentence selection fails)
    # train_dataset.examples = [t for t in train_dataset.examples if t is not None]
    if args.sort_by_len:
        train_sampler = reader_data.SortedBatchSampler(train_dataset.lengths(),
                                                       args.batch_size,
                                                       shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    if args.use_sentence_selector:
        train_batcher = reader_vector.sentence_batchifier(model, single_answer=True)
        # batching_function = train_batcher.batchify
        batching_function = reader_vector.batchify
    else:
        batching_function = reader_vector.batchify
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=batching_function,
        pin_memory=args.cuda,
    )
    dev_dataset = reader_data.ReaderDataset(dev_exs, model, single_answer=False)
    # dev_dataset.examples = [t for t in dev_dataset.examples if t is not None]
    if args.sort_by_len:
        dev_sampler = reader_data.SortedBatchSampler(dev_dataset.lengths(),
                                                     args.test_batch_size,
                                                     shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    if args.use_sentence_selector:
        dev_batcher = reader_vector.sentence_batchifier(model, single_answer=False)
        # batching_function = dev_batcher.batchify
        batching_function = reader_vector.batchify
    else:
        batching_function = reader_vector.batchify
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=batching_function,
        pin_memory=args.cuda,
    )

    # Dev dataset for measuring performance of the trained sentence selector
    if args.use_sentence_selector:
        dev_dataset1 = selector_data.SentenceSelectorDataset(
            dev_exs, model.sentence_selector, single_answer=False)
        # dev_dataset1.examples = [t for t in dev_dataset1.examples if t is not None]
        if args.sort_by_len:
            dev_sampler1 = selector_data.SortedBatchSampler(dev_dataset1.lengths(),
                                                            args.test_batch_size,
                                                            shuffle=False)
        else:
            dev_sampler1 = torch.utils.data.sampler.SequentialSampler(dev_dataset1)
        dev_loader1 = torch.utils.data.DataLoader(
            dev_dataset1,
            # batch_size=args.test_batch_size,
            # sampler=dev_sampler1,
            batch_sampler=dev_sampler1,
            num_workers=args.data_workers,
            collate_fn=selector_vector.batchify,
            pin_memory=args.cuda,
        )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    # --------------------------------------------------------------------------
    # QUICKLY VALIDATE ON PRETRAINED MODEL
    if args.global_mode == "test":
        result1 = validate_unofficial(args, dev_loader, model, stats, mode='dev')
        result2 = validate_official(args, dev_loader, model, stats,
                                    dev_offsets, dev_texts, dev_answers)
        print(result2[args.valid_metric])
        print(result1["exact_match"])

        if args.use_sentence_selector:
            sent_stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
            # sent_selector_results = validate_selector(model.sentence_selector.args,
            #                                           dev_loader1,
            #                                           model.sentence_selector,
            #                                           sent_stats, mode="dev")
            # print("Sentence Selector model achieves:")
            # print(sent_selector_results["accuracy"])

        if len(args.adv_dev_json) > 0:
            validate_adversarial(args, model, stats, mode="dev")
        exit(0)
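    # The collate_fn passed to the loaders above (reader_vector.batchify /
    # selector_vector.batchify) turns a list of vectorized examples into padded
    # batch tensors. A minimal sketch of that pad-and-stack contract, assuming
    # each example is a 1-D LongTensor of token ids (the real batchify also
    # carries features, masks, and answer targets; the name is illustrative):
    #
    #   def batchify_sketch(examples):
    #       max_len = max(ex.size(0) for ex in examples)
    #       tokens = torch.LongTensor(len(examples), max_len).zero_()
    #       mask = torch.ByteTensor(len(examples), max_len).fill_(1)
    #       for i, ex in enumerate(examples):
    #           tokens[i, :ex.size(0)].copy_(ex)
    #           mask[i, :ex.size(0)].fill_(0)  # 0 = real token, 1 = padding
    #       return tokens, mask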
mode="dev") #print("Sentence Selector model acheives:") #print(sent_selector_results["accuracy"]) if len(args.adv_dev_json) > 0: validate_adversarial(args, model, stats, mode="dev") exit(0) valid_history = [] bad_counter = 0 for epoch in range(start_epoch, args.num_epochs): stats['epoch'] = epoch # Train train(args, train_loader, model, stats) # Validate unofficial (train) validate_unofficial(args, train_loader, model, stats, mode='train') # Validate unofficial (dev) result = validate_unofficial(args, dev_loader, model, stats, mode='dev') # Validate official if args.official_eval: result = validate_official(args, dev_loader, model, stats, dev_offsets, dev_texts, dev_answers) # Save best valid if result[args.valid_metric] >= stats['best_valid']: logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' % (args.valid_metric, result[args.valid_metric], stats['epoch'], model.updates)) model.save(args.model_file) stats['best_valid'] = result[args.valid_metric] bad_counter = 0 else: bad_counter += 1 if bad_counter > args.patience: logger.info("Early Stopping at epoch: %d" % epoch) exit(0)