def main(args):
    """Train a DocReader model, saving it whenever the dev metric improves.

    Steps: load train/dev examples, restore the model from a checkpoint or
    build it (pretrained or from scratch), wrap the examples in DataLoaders,
    then run epochs ``start_epoch .. args.num_epochs - 1``. Each epoch trains
    once, runs unofficial validation on train and dev, and optionally runs
    official validation on dev. ``model.save(args.model_file)`` is called once
    before training and again whenever ``result[args.valid_metric]`` beats the
    best value seen so far.

    Args:
        args: parsed option namespace. Fields read here include
            ``train_file``, ``dev_file``, ``dev_json``, ``official_eval``,
            ``checkpoint``, ``model_file``, ``pretrained``, ``from_zero``,
            ``expand_dictionary``, ``full_char``, ``embedding_file``,
            ``char_embedding_file``, ``tune_partial``, ``cuda``, ``parallel``,
            ``sort_by_len``, ``batch_size``, ``test_batch_size``,
            ``data_workers``, ``num_epochs`` and ``valid_metric``.

    Raises:
        NotImplementedError: if dictionary expansion is requested for a
            pretrained model while ``args.full_char`` is set.
    """
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d', len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d', len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    checkpoint_file = args.model_file + '.checkpoint'
    if args.checkpoint and os.path.isfile(checkpoint_file):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args,
                                   from_zero=args.from_zero)
            if args.expand_dictionary:
                if args.full_char:
                    raise NotImplementedError(
                        'expand_dictionary is not supported with full_char')
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added_words = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added_words, args.embedding_file)

                logger.info('Expanding char dictionary for new data...')
                # Add characters in training + dev examples
                chars = utils.load_chars(args, train_exs + dev_exs)
                added_chars = model.expand_char_dictionary(chars)
                # Load pretrained embeddings for added characters
                if args.char_embedding_file:
                    model.load_char_embeddings(added_chars,
                                               args.char_embedding_file)
        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words',
                        args.tune_partial)
            top_words = utils.top_question_words(args, train_exs,
                                                 model.word_dict)
            # Show the first and last five entries as a sanity check.
            # (Previously `[-6:-1]`, which skipped the very last entry.)
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-5:]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.make_batchify_func(args),
        pin_memory=args.cuda,
    )

    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.make_batchify_func(args),
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s', json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    # NOTE(review): best_valid restarts at 0 even when resuming from a
    # checkpoint, so the first resumed epoch always overwrites model_file —
    # confirm this is intended.
    stats = {
        'timer': utils.Timer(),
        'epoch': 0,
        'best_valid': 0,
        'best_valid_epoch': -1,
        'best_valid_updates': -1,
    }
    model.save(args.model_file)
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats,
                                     mode='dev')

        # Validate official; its result replaces the unofficial one.
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save and report best valid
        if result[args.valid_metric] > stats['best_valid']:
            stats['best_valid'] = result[args.valid_metric]
            stats['best_valid_epoch'] = stats['epoch']
            stats['best_valid_updates'] = model.updates
            logger.info('Find better answer %.2f', stats['best_valid'])
            model.save(args.model_file)

    logger.info('Best valid: %s = %.2f (epoch %d, %d updates)',
                args.valid_metric, stats['best_valid'],
                stats['best_valid_epoch'], stats['best_valid_updates'])