예제 #1
0
            yield ' '.join(curr)
            curr = []
            curr_len = 0
        curr.append(split)
        curr_len += len(split)
    if len(curr) > 0:
        yield ' '.join(curr)


if __name__ == '__main__':
    # Parse cmdline args and setup environment
    parser = argparse.ArgumentParser(
        'DrQA Document Reader',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_train_args(parser)
    config.add_model_args(parser)
    args = parser.parse_args()
    set_defaults(args)

    # Set cuda
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.cuda.set_device(args.gpu)

    # Set random state
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed)

    # Set logging
예제 #2
0
def load_train_evaluate_save(mode):
    
    # -------------------------------------------------------------------------   
    # PARSER
    # -------------------------------------------------------------------------   
    
    # Parse cmdline args and setup environment
    parser = argparse.ArgumentParser(
        'OpenQA Question Answering Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add_main_args(parser, mode)
    config.add_model_args(parser)
    args = parser.parse_args()
    set_defaults(args)
        
    
    # -------------------------------------------------------------------------   
    # INITIALIZATIONS
    # -------------------------------------------------------------------------   
    
    # CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    assert(args.cuda)
    if args.cuda:
        torch.cuda.set_device(args.gpu) # no-op if args.gpu is negative
        torch.cuda.empty_cache()
    
    # Set random state
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed)
    
    if args.log_file:
        if args.checkpoint:
            logfile = logging.FileHandler(args.log_file, 'a')
        else:
            logfile = logging.FileHandler(args.log_file, 'w')
    
        logfile.setFormatter(txtfmt)
        logger.addHandler(logfile)
    
    logger.info('COMMAND: {}'.format(' '.join(sys.argv)))
    
    # GPU cleaning
    gc.collect()
    for obj in gc.get_objects():
        del obj
    torch.cuda.empty_cache()
    
    
    # --------------------------------------------------------------------------
    # DATASET
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    logger.info('Load data files')
    
    dataset = args.dataset # == 'searchqa', 'quasart' or 'unftriviaqa'
    
    filename_train_docs = sys_dir+'/data/datasets/'+dataset+'/train.json' 
    filename_dev_docs = sys_dir+'/data/datasets/'+dataset+'/dev.json' 
    filename_test_docs = sys_dir+'/data/datasets/'+dataset+'/test.json' 
    filename_train = sys_dir+'/data/datasets/'+dataset+'/train.txt' 
    filename_dev = sys_dir+'/data/datasets/'+dataset+'/dev.txt' 
    filename_test = sys_dir+'/data/datasets/'+dataset+'/test.txt'
    
    train_docs, train_questions, train_len = utils.load_data_with_doc(
            args, filename_train_docs)
    logger.info(len(train_docs))
    logger.info(len(train_questions))
    
    train_exs_with_doc = read_data(filename_train, train_questions, train_len)
    logger.info('Num train examples = {}'.format(str(len(train_exs_with_doc))))
    
    dev_docs, dev_questions, _ = utils.load_data_with_doc(
            args, filename_dev_docs)
    logger.info(len(dev_docs))
    logger.info(len(dev_questions))
    
    dev_exs_with_doc = read_data(filename_dev, dev_questions)
    logger.info('Num dev examples = {}'.format(str(len(dev_exs_with_doc))))
    
    test_docs, test_questions, _ = utils.load_data_with_doc(
            args, filename_test_docs)
    logger.info(len(test_docs))
    logger.info(len(test_questions))
    
    test_exs_with_doc = read_data(filename_test, test_questions)
    logger.info('Num test examples = {}'.format(str(len(test_exs_with_doc))))


    # --------------------------------------------------------------------------
    # MODEL SETUP
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    start_epoch = 0
    
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file)
        start_epoch = 0
    
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                
                # Add words in training and dev examples
                #words = utils.load_words(args, train_exs + dev_exs)
                words = utils.load_words(
                        args, train_exs_with_doc + dev_exs_with_doc)
                added = model.expand_dictionary(words)
                
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_docs)

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()
    
    # GPU usage
    if args.show_cuda_stats:
        gpu_usage()

    
    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # -------------------------------------------------------------------------   
    
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    
    # best practices for memory management are available here:
    # https://pytorch.org/docs/stable/notes/cuda.html#best-practices
    
    train_dataset_with_doc = data.ReaderDataset_with_Doc(
            train_exs_with_doc, model, train_docs, single_answer=True)
    train_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
            train_dataset_with_doc)
    train_loader_with_doc = torch.utils.data.DataLoader(
            train_dataset_with_doc,
            batch_size=args.batch_size, # batch_size of 128 samples
            sampler=train_sampler_with_doc,
            num_workers=args.data_workers, # num_workers increased to 12
            collate_fn=vector.batchify_with_docs,
            pin_memory=args.cuda, # pin_memory = True by default
            )

    dev_dataset_with_doc = data.ReaderDataset_with_Doc(
            dev_exs_with_doc, model, dev_docs, single_answer=False)
    dev_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
            dev_dataset_with_doc)
    dev_loader_with_doc = torch.utils.data.DataLoader(
            dev_dataset_with_doc,
            batch_size=args.test_batch_size,
            sampler=dev_sampler_with_doc,
            num_workers=args.data_workers,
            collate_fn=vector.batchify_with_docs,
            pin_memory=args.cuda,
            )

    test_dataset_with_doc = data.ReaderDataset_with_Doc(
            test_exs_with_doc, model, test_docs, single_answer=False)
    test_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
            test_dataset_with_doc)
    test_loader_with_doc = torch.utils.data.DataLoader(
           test_dataset_with_doc,
           batch_size=args.test_batch_size,
           sampler=test_sampler_with_doc,
           num_workers=args.data_workers,
           collate_fn=vector.batchify_with_docs,
           pin_memory=args.cuda,
           )


    # -------------------------------------------------------------------------
    # PRINT CONFIG 
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    logger.info('CONFIG:')
    print(json.dumps(vars(args), indent=4, sort_keys=True))


    # --------------------------------------------------------------------------
    # TRAIN/VALIDATION LOOP
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
          
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        logger.info('-' * 100)
        logger.info('Mode: ' + args.mode)
        
        if (args.mode == 'all'):
            train(args, 
                    train_loader_with_doc, model, stats, 
                    train_exs_with_doc, train_docs)
        if (args.mode == 'reader'):
            pretrain_reader(args, 
                    train_loader_with_doc, model, stats, 
                    train_exs_with_doc, train_docs)
        if (args.mode == 'selector'):
            pretrain_selector(args, 
                    train_loader_with_doc, model, stats, 
                    train_exs_with_doc, train_docs)
        
        # ---------------------------------------------------------------------
        with torch.no_grad():
            # -----------------------------------------------------------------
            result = validate_with_doc(args, 
                    dev_loader_with_doc, model, stats, dev_exs_with_doc, 
                    dev_docs, 'dev')
            
            validate_with_doc(args, 
                    train_loader_with_doc, model, stats, train_exs_with_doc, 
                    train_docs, 'train')
            
            if (dataset=='webquestions' or dataset=='CuratedTrec'): # not applicable
                result = validate_with_doc(args, 
                        test_loader_with_doc, model, stats, 
                        test_exs_with_doc, test_docs, 'test')
            else: # dataset == 'searchqa' by default, 'squad', 'quasart' or 'unftriviaqa'
                validate_with_doc(args, 
                        test_loader_with_doc, model, stats, 
                        test_exs_with_doc, test_docs, 'test')
        # ---------------------------------------------------------------------
        
        # Save model with improved evaluation results
        if result[args.valid_metric] > stats['best_valid']:
            
            txt = 'Best valid: {} = {:.2f} (epoch {}, {} updates)'
            logger.info(txt.format(
                    args.valid_metric, result[args.valid_metric],
                    stats['epoch'], model.updates))
            
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
        
        # Clean the gpu before running a new iteration
        if args.cuda:                 
            
            gc.collect() # force garbage collection
            for obj in gc.get_objects(): 
                if torch.is_tensor(obj): 
                    del obj
            
            torch.cuda.synchronize(device=model.device) # wait for the gpu
            torch.cuda.empty_cache() # force garbage removal
        
        # CUDA memory
        txt_cuda(show=True, txt='after garbage collection')
예제 #3
0
def user_interface(mode):

    # Model arguments
    parser = argparse.ArgumentParser(
        'OpenQA Document Reader',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    add_interface_args(parser, mode)
    config.add_model_args(parser)
    args = parser.parse_args()

    set_defaults(args)

    # CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.cuda:
        torch.cuda.set_device(args.gpu)
        logger.info('CUDA enabled (GPU {})'.format(args.gpu))
    else:
        logger.info('Running on CPU only.')

    logger.info('Initializing pipeline...')

    # Load the tf-idf ranker
    logger.info('Initializing document ranker...')
    ranker = get_class('tfidf')()

    # Load a trained reader model
    logger.info('Loading selected model...')
    model = DocReader.load(args.model_file, new_args=args)

    model.network.eval()
    model.selector.eval()

    def ask(question=None,
            candidates=None,
            top_n=1,
            n_docs=5,
            return_context=True):

        if question is None:
            question = 'What is question answering?'

        return process(args,
                       question,
                       ranker,
                       model,
                       candidates=None,
                       top_n=1,
                       n_docs=5,
                       return_context=True)

    # Interactive mode
    logger.info('-' * 100)
    logger.info('Entering interactice mode...')

    banner = '''
    Interactive OpenQA 
    >> ask(question, top_n=1, n_docs=5, return_context=True)
    >> ask() will yield the answer to 'What is question answering?'
    press Ctrl+D to leave
    '''

    code.interact(banner=banner, local=locals())