Example no. 1
def pretrain_selector(args, data_loader, model, global_stats, exs_with_doc,
                      docs_by_question):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    # Run one epoch
    tot_ans = 0
    tot_num = 0
    global HasAnswer_Map
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, question, ex_id = ex[0].size(0), ex[3], ex[-1]
        if (idx not in HasAnswer_Map):
            HasAnswer_list = []
            for idx_doc in range(0, vector.num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    has_a, a_l = has_answer(
                        args, exs_with_doc[ex_id[i]]['answer'],
                        docs_by_question[ex_id[i]][idx_doc % len(
                            docs_by_question[ex_id[i]])]["document"])
                    HasAnswer.append(has_a)
                HasAnswer_list.append(HasAnswer)
            #HasAnswer_list = torch.LongTensor(HasAnswer_list)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]
        for idx_doc in range(0, vector.num_docs):
            for i in range(batch_size):
                tot_ans += HasAnswer_list[idx_doc][i]
                tot_num += 1

        weights = []
        for idx_doc in range(0, vector.num_docs):
            weights.append(1)
        weights = torch.Tensor(weights)
        idx_random = torch.multinomial(weights, int(vector.num_docs))

        HasAnswer_list_sample = []
        ex_with_doc_sample = []
        for idx_doc in idx_random:
            HasAnswer_list_sample.append(HasAnswer_list[idx_doc])
            ex_with_doc_sample.append(ex_with_doc[idx_doc])
        HasAnswer_list_sample = torch.LongTensor(HasAnswer_list_sample)

        train_loss.update(*model.pretrain_selector(ex_with_doc_sample,
                                                   HasAnswer_list_sample))
        #train_loss.update(*model.pretrain_ranker(ex_with_doc, HasAnswer_list))
        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            logger.info("tot_ans:\t%d\t%d\t%f", tot_ans, tot_num,
                        tot_ans * 1.0 / tot_num)
            train_loss.reset()
    logger.info("tot_ans:\t%d\t%d", tot_ans, tot_num)
    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))
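
Note: the loss bookkeeping in these training loops goes through utils.AverageMeter. The snippet below is a minimal sketch of that kind of meter, assuming the usual update(value, n) / avg / reset() interface; the project's actual class may differ in detail.

class AverageMeter(object):
    """Keep a running sum and count and expose the current average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value passed to update()
        self.sum = 0.0   # running sum of val * n
        self.count = 0   # running number of samples
        self.avg = 0.0   # sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
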
Example no. 2
def pretrain_reader(args, data_loader, model, global_stats, exs_with_doc,
                    docs_by_question):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    logger.info("pretrain_reader")
    # Run one epoch
    global HasAnswer_Map
    count_ans = 0
    count_tot = 0
    for idx, ex_with_doc in enumerate(data_loader):
        #logger.info(idx)
        ex = ex_with_doc[0]
        batch_size, question, ex_id = ex[0].size(0), ex[3], ex[-1]
        if (idx not in HasAnswer_Map):
            HasAnswer_list = []
            for idx_doc in range(0, vector.num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    HasAnswer.append(
                        has_answer(
                            args, exs_with_doc[ex_id[i]]['answer'],
                            docs_by_question[ex_id[i]][idx_doc % len(
                                docs_by_question[ex_id[i]])]["document"]))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        for idx_doc in range(0, vector.num_docs):
            l_list = []
            r_list = []
            pred_s, pred_e, pred_score = model.predict(ex_with_doc[idx_doc],
                                                       top_n=1)
            for i in range(batch_size):
                if HasAnswer_list[idx_doc][i][0]:
                    count_ans += len(HasAnswer_list[idx_doc][i][1])
                    count_tot += 1
                    l_list.append(HasAnswer_list[idx_doc][i][1])
                else:
                    l_list.append([(int(pred_s[i][0]), int(pred_e[i][0]))])
            train_loss.update(*model.update(ex_with_doc[idx_doc], l_list,
                                            r_list, HasAnswer_list[idx_doc]))
        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()
            logger.info("%d\t%d\t%f", count_ans, count_tot,
                        1.0 * count_ans / (count_tot + 1))
    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))
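
Note: both pretraining loops rely on the (found, spans) contract of has_answer: element 0 says whether any reference answer occurs in the tokenized document, element 1 lists the matching (start, end) token spans. The sketch below (has_answer_sketch is an illustrative stand-in, not the project's helper) assumes each answer is a plain string and each document a list of tokens, with exact lowercase matching; the real helper presumably normalizes and tokenizes more carefully.

def has_answer_sketch(answers, doc_tokens):
    """Return (found, spans); spans are inclusive (start, end) token indices."""
    doc_lower = [t.lower() for t in doc_tokens]
    spans = []
    for answer in answers:
        ans_tokens = answer.lower().split()
        if not ans_tokens:
            continue
        n = len(ans_tokens)
        for start in range(0, len(doc_lower) - n + 1):
            if doc_lower[start:start + n] == ans_tokens:
                spans.append((start, start + n - 1))
    return (len(spans) > 0, spans)
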
Example no. 3
def load_train_evaluate_save(mode):
    
    # -------------------------------------------------------------------------   
    # PARSER
    # -------------------------------------------------------------------------   
    
    # Parse cmdline args and setup environment
    parser = argparse.ArgumentParser(
        'OpenQA Question Answering Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add_main_args(parser, mode)
    config.add_model_args(parser)
    args = parser.parse_args()
    set_defaults(args)
        
    
    # -------------------------------------------------------------------------   
    # INITIALIZATIONS
    # -------------------------------------------------------------------------   
    
    # CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    assert(args.cuda)
    if args.cuda:
        torch.cuda.set_device(args.gpu) # no-op if args.gpu is negative
        torch.cuda.empty_cache()
    
    # Set random state
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed)
    
    if args.log_file:
        if args.checkpoint:
            logfile = logging.FileHandler(args.log_file, 'a')
        else:
            logfile = logging.FileHandler(args.log_file, 'w')
    
        logfile.setFormatter(txtfmt)
        logger.addHandler(logfile)
    
    logger.info('COMMAND: {}'.format(' '.join(sys.argv)))
    
    # GPU cleaning: gc.collect() and empty_cache() do the real work here;
    # "del obj" only drops the local loop reference, it does not free objects
    gc.collect()
    for obj in gc.get_objects():
        del obj
    torch.cuda.empty_cache()
    
    
    # --------------------------------------------------------------------------
    # DATASET
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    logger.info('Load data files')
    
    dataset = args.dataset # == 'searchqa', 'quasart' or 'unftriviaqa'
    
    filename_train_docs = sys_dir+'/data/datasets/'+dataset+'/train.json' 
    filename_dev_docs = sys_dir+'/data/datasets/'+dataset+'/dev.json' 
    filename_test_docs = sys_dir+'/data/datasets/'+dataset+'/test.json' 
    filename_train = sys_dir+'/data/datasets/'+dataset+'/train.txt' 
    filename_dev = sys_dir+'/data/datasets/'+dataset+'/dev.txt' 
    filename_test = sys_dir+'/data/datasets/'+dataset+'/test.txt'
    
    train_docs, train_questions, train_len = utils.load_data_with_doc(
            args, filename_train_docs)
    logger.info(len(train_docs))
    logger.info(len(train_questions))
    
    train_exs_with_doc = read_data(filename_train, train_questions, train_len)
    logger.info('Num train examples = {}'.format(str(len(train_exs_with_doc))))
    
    dev_docs, dev_questions, _ = utils.load_data_with_doc(
            args, filename_dev_docs)
    logger.info(len(dev_docs))
    logger.info(len(dev_questions))
    
    dev_exs_with_doc = read_data(filename_dev, dev_questions)
    logger.info('Num dev examples = {}'.format(str(len(dev_exs_with_doc))))
    
    test_docs, test_questions, _ = utils.load_data_with_doc(
            args, filename_test_docs)
    logger.info(len(test_docs))
    logger.info(len(test_questions))
    
    test_exs_with_doc = read_data(filename_test, test_questions)
    logger.info('Num test examples = {}'.format(str(len(test_exs_with_doc))))


    # --------------------------------------------------------------------------
    # MODEL SETUP
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    start_epoch = 0
    
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file)
        start_epoch = 0
    
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                
                # Add words in training and dev examples
                #words = utils.load_words(args, train_exs + dev_exs)
                words = utils.load_words(
                        args, train_exs_with_doc + dev_exs_with_doc)
                added = model.expand_dictionary(words)
                
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_docs)

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()
    
    # GPU usage
    if args.show_cuda_stats:
        gpu_usage()

    
    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # -------------------------------------------------------------------------   
    
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    
    # best practices for memory management are available here:
    # https://pytorch.org/docs/stable/notes/cuda.html#best-practices
    
    train_dataset_with_doc = data.ReaderDataset_with_Doc(
            train_exs_with_doc, model, train_docs, single_answer=True)
    train_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
            train_dataset_with_doc)
    train_loader_with_doc = torch.utils.data.DataLoader(
            train_dataset_with_doc,
            batch_size=args.batch_size, # batch_size of 128 samples
            sampler=train_sampler_with_doc,
            num_workers=args.data_workers, # num_workers increased to 12
            collate_fn=vector.batchify_with_docs,
            pin_memory=args.cuda, # pin_memory = True by default
            )

    dev_dataset_with_doc = data.ReaderDataset_with_Doc(
            dev_exs_with_doc, model, dev_docs, single_answer=False)
    dev_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
            dev_dataset_with_doc)
    dev_loader_with_doc = torch.utils.data.DataLoader(
            dev_dataset_with_doc,
            batch_size=args.test_batch_size,
            sampler=dev_sampler_with_doc,
            num_workers=args.data_workers,
            collate_fn=vector.batchify_with_docs,
            pin_memory=args.cuda,
            )

    test_dataset_with_doc = data.ReaderDataset_with_Doc(
            test_exs_with_doc, model, test_docs, single_answer=False)
    test_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
            test_dataset_with_doc)
    test_loader_with_doc = torch.utils.data.DataLoader(
           test_dataset_with_doc,
           batch_size=args.test_batch_size,
           sampler=test_sampler_with_doc,
           num_workers=args.data_workers,
           collate_fn=vector.batchify_with_docs,
           pin_memory=args.cuda,
           )


    # -------------------------------------------------------------------------
    # PRINT CONFIG 
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    logger.info('CONFIG:')
    print(json.dumps(vars(args), indent=4, sort_keys=True))


    # --------------------------------------------------------------------------
    # TRAIN/VALIDATION LOOP
    # -------------------------------------------------------------------------   
    
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
          
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        logger.info('-' * 100)
        logger.info('Mode: ' + args.mode)
        
        if (args.mode == 'all'):
            train(args, 
                    train_loader_with_doc, model, stats, 
                    train_exs_with_doc, train_docs)
        if (args.mode == 'reader'):
            pretrain_reader(args, 
                    train_loader_with_doc, model, stats, 
                    train_exs_with_doc, train_docs)
        if (args.mode == 'selector'):
            pretrain_selector(args, 
                    train_loader_with_doc, model, stats, 
                    train_exs_with_doc, train_docs)
        
        # ---------------------------------------------------------------------
        with torch.no_grad():
            # -----------------------------------------------------------------
            result = validate_with_doc(args, 
                    dev_loader_with_doc, model, stats, dev_exs_with_doc, 
                    dev_docs, 'dev')
            
            validate_with_doc(args, 
                    train_loader_with_doc, model, stats, train_exs_with_doc, 
                    train_docs, 'train')
            
            if (dataset=='webquestions' or dataset=='CuratedTrec'): # not applicable
                result = validate_with_doc(args, 
                        test_loader_with_doc, model, stats, 
                        test_exs_with_doc, test_docs, 'test')
            else: # dataset == 'searchqa' by default, 'squad', 'quasart' or 'unftriviaqa'
                validate_with_doc(args, 
                        test_loader_with_doc, model, stats, 
                        test_exs_with_doc, test_docs, 'test')
        # ---------------------------------------------------------------------
        
        # Save model with improved evaluation results
        if result[args.valid_metric] > stats['best_valid']:
            
            txt = 'Best valid: {} = {:.2f} (epoch {}, {} updates)'
            logger.info(txt.format(
                    args.valid_metric, result[args.valid_metric],
                    stats['epoch'], model.updates))
            
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
        
        # Clean the gpu before running a new iteration
        if args.cuda:                 
            
            gc.collect() # force garbage collection
            for obj in gc.get_objects(): 
                if torch.is_tensor(obj): 
                    del obj
            
            torch.cuda.synchronize(device=model.device) # wait for the gpu
            torch.cuda.empty_cache() # force garbage removal
        
        # CUDA memory
        txt_cuda(show=True, txt='after garbage collection')
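
Note: the per-run and per-epoch timing above (stats['timer'], epoch_time) assumes a small Timer utility whose time() method returns elapsed seconds. A minimal sketch under that assumption follows; the project's utils.Timer may also support pause/resume.

import time


class Timer(object):
    """Measure elapsed wall-clock time in seconds."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start
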
Example no. 4
def validate_with_doc(args, data_loader, model, global_stats, 
                                 exs_with_doc, docs_by_question, mode):
    '''Run one full unofficial validation with docs.
    Unofficial = doesn't use SQuAD script.
    '''
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    logger.info('validate_with_doc')
    
    # Initialize counters
    examples = 0
    
    aa = [0.0 for i in range(num_docs)] # increment only if example has answer
    bb = [0.0 for i in range(num_docs)] # increment regardless
    
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]
        
        # ---------------------------------------------------------------------
        # Document Selector
        # ---------------------------------------------------------------------
        
        '''
        ex_with_doc = 
        [tensor]  x1 = document word indices            [batch * len_d]
        [tensor]  x1_f = document word features indices [batch * len_d * nfeat]
        [tensor]  x1_mask = document padding mask       [batch * len_d]
        [tensor]  x2 = question word indices            [batch * len_q]
        [tensor]  x2_mask = question padding mask       [batch * len_q]
        [list]    indices                               [batch]
        '''
        
        scores_doc_num = model.predict_with_doc(ex_with_doc)
        scores = [{} for i in range(batch_size)]
        
        # ---------------------------------------------------------------------
        # Document Reader
        # ---------------------------------------------------------------------
        for idx_doc in range(0, num_docs):
            ex = ex_with_doc[idx_doc]
            pred_s, pred_e, pred_score = model.predict(
                    ex, top_n=display_num)
            
            for i in range(batch_size):
                idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                doc_text = docs_by_question[ex_id[i]][idx_doc_i]['document']
                
                # try to read the 10 best predicted answers (this may trigger 
                # an 'index out of range' exception)
                for k in range(display_num):
                    
                    try:
                        prediction = [doc_text[j] for j in range(pred_s[i][k], 
                                      pred_e[i][k]+1)]
                        prediction = ' '.join(prediction).lower()
                        
                        # update prediction scores
                        if (prediction not in scores[i]): 
                            scores[i][prediction] = 0
                        scores[i][prediction] += (pred_score[i][k] * 
                              scores_doc_num[i][idx_doc])
                    
                    except IndexError:
                        pass
        
        # Get the 10 most likely answers from the batch and see if the answer 
        # is actually in there           
        for i in range(batch_size):
            _, indices = scores_doc_num[i].sort(0, descending = True)
            
            for j in range(0, display_num):
                idx_doc = indices[j]
                idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                
                doc_text = docs_by_question[ex_id[i]][idx_doc_i]['document']
                ex_answer = exs_with_doc[ex_id[i]]['answer']
                
                # Looking for the answer in the document...
                if (has_answer(args, 
                               ex_answer, 
                               doc_text)[0]):
                    aa[j]= aa[j] + 1
                    
                bb[j]= bb[j]+1

        # Update performance metrics
        for i in range(batch_size):
            
            best_score = 0
            prediction = ''
            for key in scores[i]:
                if (scores[i][key] > best_score):
                    best_score = scores[i][key]
                    prediction = key
            
            ground_truths = []
            ex_answer = exs_with_doc[ex_id[i]]['answer']
            
            # Ground truth answers
            if (args.dataset == 'CuratedTrec'): # not applicable
                ground_truths = ex_answer
            else: 
                for a in ex_answer:
                    ground_truths.append(' '.join([w for w in a]))
                    
            exact_match.update(
                    utils.metric_max_over_ground_truths(
                            utils.exact_match_score, prediction, ground_truths))
            
            f1.update(
                    utils.metric_max_over_ground_truths(
                            utils.f1_score, prediction, ground_truths))
            
        examples += batch_size
        
        if (mode=='train' and examples>=1000):
            break
    
    try:
        for j in range(display_num):
            if (j>0):
                aa[j]= aa[j]+aa[j-1]
                bb[j]= bb[j]+bb[j-1]
    except:
        pass
    
    txt =  '{} valid official with doc: Epoch = {} | EM = {:.2f} | '
    txt += 'F1 = {:.2f} | examples = {} | valid time = {:.2f} (s)'
    logger.info(txt.format(
            mode, global_stats['epoch'], exact_match.avg * 100, 
            f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
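
Note: the EM/F1 updates in validate_with_doc follow the standard SQuAD-style aggregation: each prediction is scored against every ground-truth answer and the maximum score is kept. The sketch below shows that aggregation with a bare-bones exact-match scorer; it is a simplified stand-in assuming only lowercase/whitespace normalization, whereas the project's utils versions very likely also strip punctuation and articles.

def normalize(text):
    return ' '.join(text.lower().split())


def exact_match_score(prediction, ground_truth):
    return float(normalize(prediction) == normalize(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # Score the prediction against each reference answer and keep the best.
    if not ground_truths:
        return 0.0
    return max(metric_fn(prediction, gt) for gt in ground_truths)
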
Example no. 5
def main(args):
    # --------------------------------------------------------------------------
    # TOK
    global PROCESS_TOK
    tok_class = tokenizers.get_class("corenlp")
    tok_opts = {}
    PROCESS_TOK = tok_class(**tok_opts)
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)

    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    dataset = args.dataset  # e.g. 'quasart', 'searchqa', 'unftriviaqa', 'squad'
    filename_train_docs = sys_dir + "/data/datasets/" + dataset + "/train.json"
    filename_dev_docs = sys_dir + "/data/datasets/" + dataset + "/dev.json"
    filename_test_docs = sys_dir + "/data/datasets/" + dataset + "/test.json"
    train_docs, train_questions = utils.load_data_with_doc(
        args, filename_train_docs)
    logger.info(len(train_docs))
    filename_train = sys_dir + "/data/datasets/" + dataset + "/train.txt"
    filename_dev = sys_dir + "/data/datasets/" + dataset + "/dev.txt"
    train_exs_with_doc = read_data(filename_train, train_questions)

    logger.info('Num train examples = %d' % len(train_exs_with_doc))

    dev_docs, dev_questions = utils.load_data_with_doc(args, filename_dev_docs)
    logger.info(len(dev_docs))
    dev_exs_with_doc = read_data(filename_dev, dev_questions)
    logger.info('Num dev examples = %d' % len(dev_exs_with_doc))

    test_docs, test_questions = utils.load_data_with_doc(
        args, filename_test_docs)
    logger.info(len(test_docs))
    test_exs_with_doc = read_data(
        sys_dir + "/data/datasets/" + dataset + "/test.txt", test_questions)
    logger.info('Num test examples = %d' % len(test_exs_with_doc))

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file)
        #model = DocReader.load(checkpoint_file, args)
        start_epoch = 0
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(
                    args, train_exs_with_doc + dev_exs_with_doc)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_docs)  #, train_exs, dev_exs)

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')

    train_dataset_with_doc = data.ReaderDataset_with_Doc(train_exs_with_doc,
                                                         model,
                                                         train_docs,
                                                         single_answer=True)
    train_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
        train_dataset_with_doc)
    train_loader_with_doc = torch.utils.data.DataLoader(
        train_dataset_with_doc,
        batch_size=args.batch_size,
        sampler=train_sampler_with_doc,
        num_workers=args.data_workers,
        collate_fn=vector.batchify_with_docs,
        pin_memory=args.cuda,
    )

    dev_dataset_with_doc = data.ReaderDataset_with_Doc(dev_exs_with_doc,
                                                       model,
                                                       dev_docs,
                                                       single_answer=False)
    dev_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
        dev_dataset_with_doc)
    dev_loader_with_doc = torch.utils.data.DataLoader(
        dev_dataset_with_doc,
        batch_size=args.test_batch_size,
        sampler=dev_sampler_with_doc,
        num_workers=args.data_workers,
        collate_fn=vector.batchify_with_docs,
        pin_memory=args.cuda,
    )

    test_dataset_with_doc = data.ReaderDataset_with_Doc(test_exs_with_doc,
                                                        model,
                                                        test_docs,
                                                        single_answer=False)
    test_sampler_with_doc = torch.utils.data.sampler.SequentialSampler(
        test_dataset_with_doc)
    test_loader_with_doc = torch.utils.data.DataLoader(
        test_dataset_with_doc,
        batch_size=args.test_batch_size,
        sampler=test_sampler_with_doc,
        num_workers=args.data_workers,
        collate_fn=vector.batchify_with_docs,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        if (args.mode == 'all'):
            train(args, train_loader_with_doc, model, stats,
                  train_exs_with_doc, train_docs)
        if (args.mode == 'reader'):
            pretrain_reader(args, train_loader_with_doc, model, stats,
                            train_exs_with_doc, train_docs)
        if (args.mode == 'selector'):
            pretrain_ranker(args, train_loader_with_doc, model, stats,
                            train_exs_with_doc, train_docs)

        result = validate_unofficial_with_doc(args, dev_loader_with_doc, model,
                                              stats, dev_exs_with_doc,
                                              dev_docs, 'dev')
        validate_unofficial_with_doc(args, train_loader_with_doc, model, stats,
                                     train_exs_with_doc, train_docs, 'train')
        if (dataset == 'webquestions' or dataset == 'CuratedTrec'):
            result = validate_unofficial_with_doc(args, test_loader_with_doc,
                                                  model, stats,
                                                  test_exs_with_doc, test_docs,
                                                  'test')
        else:
            validate_unofficial_with_doc(args, test_loader_with_doc, model,
                                         stats, test_exs_with_doc, test_docs,
                                         'test')
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
Example no. 6
def validate_unofficial_with_doc(args, data_loader, model, global_stats,
                                 exs_with_doc, docs_by_question, mode):
    """Run one full unofficial validation with docs.
    Unofficial = doesn't use SQuAD script.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    out_set = set({33, 42, 45, 70, 39})
    logger.info("validate_unofficial_with_doc")
    # Run through examples

    examples = 0
    aa = [0.0 for i in range(vector.num_docs)]
    bb = [0.0 for i in range(vector.num_docs)]
    aa_sum = 0.0
    display_num = 10
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, question, ex_id = ex[0].size(0), ex[3], ex[-1]
        scores_doc_num = model.predict_with_doc(ex_with_doc)
        scores = [{} for i in range(batch_size)]

        tot_sum = [0.0 for i in range(batch_size)]
        tot_sum1 = [0.0 for i in range(batch_size)]
        neg_sum = [0.0 for i in range(batch_size)]
        min_sum = [[] for i in range(batch_size)]
        min_sum1 = [[] for i in range(batch_size)]

        for idx_doc in range(0, vector.num_docs):
            ex = ex_with_doc[idx_doc]
            pred_s, pred_e, pred_score = model.predict(ex, top_n=10)
            for i in range(batch_size):
                doc_text = docs_by_question[ex_id[i]][idx_doc % len(
                    docs_by_question[ex_id[i]])]["document"]
                has_answer_t = has_answer(args,
                                          exs_with_doc[ex_id[i]]['answer'],
                                          doc_text)

                for k in range(10):
                    try:
                        prediction = []
                        for j in range(pred_s[i][k], pred_e[i][k] + 1):
                            prediction.append(doc_text[j])
                        prediction = " ".join(prediction).lower()
                        if (prediction not in scores[i]):
                            scores[i][prediction] = 0
                        scores[i][prediction] += pred_score[i][
                            k] * scores_doc_num[i][idx_doc]
                    except:
                        pass
        for i in range(batch_size):
            _, indices = scores_doc_num[i].sort(0, descending=True)
            for j in range(0, display_num):
                idx_doc = indices[j]
                doc_text = docs_by_question[ex_id[i]][idx_doc % len(
                    docs_by_question[ex_id[i]])]["document"]
                if (has_answer(args, exs_with_doc[ex_id[i]]['answer'],
                               doc_text)[0]):

                    aa[j] = aa[j] + 1
                bb[j] = bb[j] + 1

        for i in range(batch_size):

            best_score = 0
            prediction = ""
            for key in scores[i]:
                if (scores[i][key] > best_score):
                    best_score = scores[i][key]
                    prediction = key

            # Compute metrics
            ground_truths = []
            answer = exs_with_doc[ex_id[i]]['answer']
            if (args.dataset == "CuratedTrec"):
                ground_truths = answer
            else:
                for a in answer:
                    ground_truths.append(" ".join([w for w in a]))
            #logger.info(prediction)
            #logger.info(ground_truths)
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))
            a = sorted(scores[i].items(), key=lambda d: d[1], reverse=True)

        examples += batch_size
        if (mode == "train" and examples >= 1000):
            break
    try:
        for j in range(0, display_num):
            if (j > 0):
                aa[j] = aa[j] + aa[j - 1]
                bb[j] = bb[j] + bb[j - 1]
            logger.info(aa[j] / bb[j])
    except:
        pass
    logger.info(aa_sum)
    if (mode == 'dev' or mode == 'train'):
        g.write("*" * 50 + "\n")
        g.close()
    logger.info('%s valid official with doc: Epoch = %d | EM = %.2f | ' %
                (mode, global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
Example no. 7
def train(args, data_loader, model, global_stats, exs_with_doc,
          docs_by_question):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()
    # Run one epoch
    update_step = 0
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, question, ex_id = ex[0].size(0), ex[3], ex[-1]
        if (idx not in HasAnswer_Map):
            HasAnswer_list = []
            for idx_doc in range(0, vector.num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    HasAnswer.append(
                        has_answer(
                            args, exs_with_doc[ex_id[i]]['answer'],
                            docs_by_question[ex_id[i]][idx_doc % len(
                                docs_by_question[ex_id[i]])]["document"]))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        weights = []
        for idx_doc in range(0, vector.num_docs):
            weights.append(1)
        weights = torch.Tensor(weights)
        idx_random = torch.multinomial(weights, int(vector.num_docs))

        HasAnswer_list_sample = []
        ex_with_doc_sample = []
        for idx_doc in idx_random:
            HasAnswer_list_sample.append(HasAnswer_list[idx_doc])
            ex_with_doc_sample.append(ex_with_doc[idx_doc])

        l_list_doc = []
        r_list_doc = []
        for idx_doc in idx_random:
            l_list = []
            r_list = []
            for i in range(batch_size):
                if HasAnswer_list[idx_doc][i][0]:
                    l_list.append(HasAnswer_list[idx_doc][i][1])
                else:
                    l_list.append((-1, -1))
            l_list_doc.append(l_list)
            r_list_doc.append(r_list)
        pred_s_list_doc = []
        pred_e_list_doc = []
        tmp_top_n = 1
        for idx_doc in idx_random:
            ex = ex_with_doc[idx_doc]
            pred_s, pred_e, pred_score = model.predict(ex, top_n=tmp_top_n)
            pred_s_list = []
            pred_e_list = []
            for i in range(batch_size):
                pred_s_list.append(pred_s[i].tolist())
                pred_e_list.append(pred_e[i].tolist())
            pred_s_list_doc.append(torch.LongTensor(pred_s_list))
            pred_e_list_doc.append(torch.LongTensor(pred_e_list))

        train_loss.update(*model.update_with_doc(
            update_step, ex_with_doc_sample, pred_s_list_doc, pred_e_list_doc,
            tmp_top_n, l_list_doc, r_list_doc, HasAnswer_list_sample))
        update_step = (update_step + 1) % 4
        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()
        if (idx % 200 == 199):
            validate_unofficial_with_doc(args, data_loader, model,
                                         global_stats, exs_with_doc,
                                         docs_by_question, 'train')
    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
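
Note: the weights / torch.multinomial pattern used in this train function (and in pretrain_selector earlier) draws num_docs indices without replacement from a uniform weight vector, which amounts to shuffling the document order for every batch. A small standalone check of that behaviour:

import torch

num_docs = 5
weights = torch.ones(num_docs)
# With replacement=False (the default) and num_samples equal to the number of
# weights, multinomial returns every index exactly once: a random permutation.
idx_random = torch.multinomial(weights, num_docs)
assert sorted(idx_random.tolist()) == list(range(num_docs))
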
Example no. 8
def update_evidence(args, data_loader, model, global_stats, exs_with_doc, docs_by_question):
    Top_k = args.top_k
    logger.info('Top k is set to %d' % (Top_k))

    Probability = {}
    Attention_Weight = {}

    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_prob = utils.AverageMeter()
    train_attention = utils.AverageMeter()
    epoch_time = utils.Timer()
    # Run one epoch
    update_step = 0
    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, question, ex_id = ex[0].size(0), ex[3], ex[-1]
        if (idx not in HasAnswer_Map):
            HasAnswer_list = []
            for idx_doc in range(0, vector.num_docs):
                HasAnswer = []
                for i in range(batch_size):
                    HasAnswer.append(
                        has_answer(
                            args, exs_with_doc[ex_id[i]]['answer'],
                            docs_by_question[ex_id[i]][idx_doc % len(
                                docs_by_question[ex_id[i]])]["document"]))
                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list
        else:
            HasAnswer_list = HasAnswer_Map[idx]

        if (idx not in Evidence_Label):
            Evidence_list = [-1] * batch_size
            Evidence_Label[idx] = Evidence_list

        # Don't shuffle when update evidence
        idx_random = range(vector.num_docs)

        HasAnswer_list_sample = []
        ex_with_doc_sample = []
        for idx_doc in idx_random:
            HasAnswer_list_sample.append(HasAnswer_list[idx_doc])
            ex_with_doc_sample.append(ex_with_doc[idx_doc])

        l_list_doc = []
        r_list_doc = []
        for idx_doc in idx_random:
            l_list = []
            r_list = []
            for i in range(batch_size):
                if HasAnswer_list[idx_doc][i][0]:
                    l_list.append(HasAnswer_list[idx_doc][i][1])
                else:
                    l_list.append((-1,-1))
            l_list_doc.append(l_list)
            r_list_doc.append(r_list)
        pred_s_list_doc = []
        pred_e_list_doc = []
        tmp_top_n = 1
        for idx_doc in idx_random:
            ex = ex_with_doc[idx_doc]
            pred_s, pred_e, pred_score = model.predict(ex, top_n=tmp_top_n)
            pred_s_list = []
            pred_e_list = []
            for i in range(batch_size):
                pred_s_list.append(pred_s[i].tolist())
                pred_e_list.append(pred_e[i].tolist())
            pred_s_list_doc.append(torch.LongTensor(pred_s_list))
            pred_e_list_doc.append(torch.LongTensor(pred_e_list))

        probs, attentions = model.update_with_doc(update_step, ex_with_doc_sample, \
                                        pred_s_list_doc, pred_e_list_doc, tmp_top_n, \
                                        l_list_doc, r_list_doc, HasAnswer_list_sample, \
                                        return_prob=True)
        train_prob.update(np.mean(probs), batch_size)
        train_attention.update(np.mean(attentions[0]), batch_size)
        update_step = (update_step + 1) % 4

        if idx % args.display_iter == 0:
            logger.info('Update Evidence: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'Average prob = %f | Average attention = %f | elapsed time = %.2f (s)' %
                        (train_prob.avg, train_attention.avg, global_stats['timer'].time()))

        for i in range(batch_size):
            key = "%d|%d" % (idx, i)
            if key in Probability or key in Attention_Weight:
                raise ValueError("%s exists in Probability or Attention_Weight" % (key))
            # Add threshold here
            Probability[key] = probs[i]
            Attention_Weight[key] = (attentions[0][i], attentions[1][i]) # max_value, max_index
        # break
    evidence_scores = {key: (attention[0], attention[1]) for key, attention in Attention_Weight.items() if attention[1] != -1}
    evidence_scores = sorted(evidence_scores.items(), key=lambda x: x[1][0], reverse=True)
    count = 0
    label_prob = []
    label_attention = []
    for key, value in evidence_scores:
        idx, i = key.split('|')
        idx = int(idx)
        i = int(i)
        if Evidence_Label[idx][i] != -1:
            continue
        count += 1
        Evidence_Label[idx][i] = value[1]
        label_prob.append(Probability[key])
        label_attention.append(Attention_Weight[key][0])
        if count >= Top_k:
            break

    logger.info('Update Evidence: Epoch %d done. Time for epoch = %.2f (s). Average prob = %f. Average attention = %f.' %
                (global_stats['epoch'], epoch_time.time(), train_prob.avg, train_attention.avg))
    logger.info('Update Evidence: Label %d examples. Average prob = %f. Average attention = %f.' %
                (count, np.mean(label_prob), np.mean(label_attention)))
Example no. 9
def train(args, data_loader, model, global_stats, exs_with_doc,
          docs_by_question):
    '''Run through one epoch of model training with the provided data loader.'''

    # Initialize meters and timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    global HasAnswer_Map

    update_step = 0

    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # Display GPU usage statistics every <display_stats> iterations
        show_stats = (args.show_cuda_stats
                      and (idx % args.display_stats == args.display_stats - 1))

        if (idx not in HasAnswer_Map):
            HasAnswer_list = []

            for idx_doc in range(0, num_docs):

                HasAnswer = []
                for i in range(batch_size):

                    idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                    answer = exs_with_doc[ex_id[i]]['answer']
                    document = docs_by_question[
                        ex_id[i]][idx_doc_i]['document']

                    # ---------------------------------------------------------
                    # Looking for the answer in the document...
                    # ---------------------------------------------------------
                    HasAnswer.append(has_answer(args, answer, document))
                    # ---------------------------------------------------------
                HasAnswer_list.append(HasAnswer)

            HasAnswer_Map[idx] = HasAnswer_list

        else:
            HasAnswer_list = HasAnswer_Map[idx]

        # Initializing weights and sampling indices...
        weights = torch.tensor([1.0 for idx_doc in range(0, num_docs)])
        idx_random = torch.multinomial(weights, int(num_docs))

        HasAnswer_list_sample = []
        ex_with_doc_sample = []

        for idx_doc in idx_random:
            HasAnswer_list_sample.append(HasAnswer_list[idx_doc])
            ex_with_doc_sample.append(ex_with_doc[idx_doc])

        l_list_doc = []
        r_list_doc = []
        for idx_doc in idx_random:

            l_list = []
            r_list = []
            for i in range(batch_size):
                if HasAnswer_list[idx_doc][i][0]:
                    l_list.append(HasAnswer_list[idx_doc][i][1])
                else:
                    l_list.append((-1, -1))

            l_list_doc.append(l_list)
            r_list_doc.append(r_list)

        # Generating predictions...
        pred_s_list_doc = []
        pred_e_list_doc = []
        tmp_top_n = 1

        # CUDA memory before forward pass
        txt_cuda(show_stats, 'before forward pass')

        for idx_doc in idx_random:
            ex = ex_with_doc[idx_doc]
            pred_s, pred_e, pred_score = model.predict(ex, top_n=tmp_top_n)

            pred_s_list = []
            pred_e_list = []
            for i in range(batch_size):
                pred_s_list.append(pred_s[i].tolist())
                pred_e_list.append(pred_e[i].tolist())

            pred_s_list_doc.append(torch.tensor(pred_s_list, dtype=torch.long))
            pred_e_list_doc.append(torch.tensor(pred_e_list, dtype=torch.long))

        # CUDA memory before backpropagation
        txt_cuda(show_stats, 'before backpropagation')

        # ---------------------------------------------------------------------
        # Updating (one epoch)...
        # ---------------------------------------------------------------------
        train_loss.update(*model.update_with_doc(
            update_step, ex_with_doc_sample, pred_s_list_doc, pred_e_list_doc,
            tmp_top_n, l_list_doc, r_list_doc, HasAnswer_list_sample))
        # ---------------------------------------------------------------------
        update_step = (update_step + 1) % 4
        # ---------------------------------------------------------------------

        # CUDA memory after backpropagation
        txt_cuda(show_stats, 'after backpropagation')
        if show_stats: gpu_usage()

        # Resetting...
        if idx % args.display_iter == 0:

            txt = 'train: Epoch = {} | iter = {}/{} | loss = {:.2f} | '
            txt += 'elapsed time = {:.2f} (s)'
            logger.info(
                txt.format(global_stats['epoch'], idx, len(data_loader),
                           train_loss.avg, global_stats['timer'].time()))

            train_loss.reset()

        # Validation...
        if show_stats:
            with torch.no_grad():
                validate_with_doc(args,
                                  data_loader,
                                  model,
                                  global_stats,
                                  exs_with_doc,
                                  docs_by_question,
                                  mode='train')

    logger.info('-' * 100)
    txt = 'train: Epoch {} done. Time for epoch = {:.2f} (s)'
    logger.info(txt.format(global_stats['epoch'], epoch_time.time()))
    logger.info('-' * 100)

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
Example no. 10
def pretrain_reader(args, data_loader, model, global_stats, exs_with_doc,
                    docs_by_question):
    '''Run through one epoch of model training with the provided data loader.'''

    # Initialize meters and timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    logger.info('pretrain_reader')

    # Run one epoch
    global HasAnswer_Map

    count_ans = 0
    count_tot = 0

    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        if (idx not in HasAnswer_Map):

            HasAnswer_list = []
            for idx_doc in range(0, num_docs):

                HasAnswer = []
                for i in range(batch_size):

                    idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                    answer = exs_with_doc[ex_id[i]]['answer']
                    document = docs_by_question[
                        ex_id[i]][idx_doc_i]['document']

                    # Looking for the answer in the document...
                    # ---------------------------------------------------------
                    # Here we do care about the presence/absence of answers
                    # AND their positions in the documents
                    # ---------------------------------------------------------
                    HasAnswer.append(has_answer(args, answer, document))
                    # ---------------------------------------------------------

                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list

        else:
            HasAnswer_list = HasAnswer_Map[idx]

        # Loop over the documents for this batch...
        for idx_doc in range(0, num_docs):

            l_list = []
            r_list = []

            # Forward pass for the batch...
            pred_s, pred_e, pred_score = model.predict(ex_with_doc[idx_doc],
                                                       top_n=1)

            for i in range(batch_size):

                if HasAnswer_list[idx_doc][i][0]:
                    count_ans += int(HasAnswer_list[idx_doc][i][0])
                    count_tot += 1

                    # Store recorded answers' positions in a list
                    l_list.append(HasAnswer_list[idx_doc][i][1])

                else:
                    # Otherwise store the model's predicted answer positions
                    l_list.append([(int(pred_s[i][0]), int(pred_e[i][0]))])

            # -----------------------------------------------------------------
            # Model update: weights are adjusted so as to minimize the loss
            # function / reducing inconsistencies between predicted and actual
            # answer positions
            # -----------------------------------------------------------------
            train_loss.update(*model.update(ex_with_doc[idx_doc], l_list,
                                            r_list, HasAnswer_list[idx_doc]))
            # -----------------------------------------------------------------

        # Resetting train loss...
        if idx % args.display_iter == 0:

            txt = 'train: Epoch = {} | iter = {}/{} | loss = {:.2f} | '
            txt += 'elapsed time = {:.2f} (s)'
            logger.info(
                txt.format(global_stats['epoch'], idx, len(data_loader),
                           train_loss.avg, global_stats['timer'].time()))

            train_loss.reset()

            txt = 'count_ans: {} | count_tot: {} | count_ans/count_tot: {:.2f} (%)'
            logger.info(
                txt.format(count_ans, count_tot,
                           100.0 * count_ans / (count_tot + 1)))

    logger.info('-' * 100)
    txt = 'train: Epoch {} done. Time for epoch = {:.2f} (s)'
    logger.info(txt.format(global_stats['epoch'], epoch_time.time()))
    logger.info('-' * 100)
Example no. 11
def pretrain_selector(args, data_loader, model, global_stats, exs_with_doc,
                      docs_by_question):
    '''Run through one epoch of model training with the provided data loader.'''

    # Initialize meters and timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    global HasAnswer_Map

    tot_ans = 0
    tot_num = 0

    for idx, ex_with_doc in enumerate(data_loader):
        ex = ex_with_doc[0]
        batch_size, ex_id = ex[0].size(0), ex[-1]

        # Update the answer mapping
        # with starting and ending positions if an answer is found
        if (idx not in HasAnswer_Map):

            HasAnswer_list = []
            for idx_doc in range(0, num_docs):

                HasAnswer = []
                for i in range(batch_size):

                    idx_doc_i = idx_doc % len(docs_by_question[ex_id[i]])
                    answer = exs_with_doc[ex_id[i]]['answer']
                    document = docs_by_question[
                        ex_id[i]][idx_doc_i]['document']

                    # ---------------------------------------------------------
                    # Looking for the answer in the document...
                    # [positions are ** ignored ** at this stage]
                    # ---------------------------------------------------------
                    bool_has, _ = has_answer(args, answer, document)
                    HasAnswer.append((bool_has, ))
                    # ---------------------------------------------------------

                HasAnswer_list.append(HasAnswer)
            HasAnswer_Map[idx] = HasAnswer_list

        else:
            HasAnswer_list = HasAnswer_Map[idx]

        # Update counters
        for idx_doc in range(0, num_docs):
            for i in range(batch_size):

                tot_ans += int(HasAnswer_list[idx_doc][i][0])
                tot_num += 1

        # Randomly sample the dataset to fit the model's input size
        weights = torch.tensor([1.0 for idx_doc in range(0, num_docs)])
        idx_random = torch.multinomial(weights, int(num_docs))

        HasAnswer_list_sample = []
        ex_with_doc_sample = []
        for idx_doc in idx_random:
            HasAnswer_list_idx_doc = [
                HasAnswer_list[idx_doc][i][0] for i in range(batch_size)
            ]
            HasAnswer_list_sample.append(HasAnswer_list_idx_doc)
            ex_with_doc_sample.append(ex_with_doc[idx_doc])

        HasAnswer_list_sample = torch.tensor(HasAnswer_list_sample,
                                             dtype=torch.long)

        # ---------------------------------------------------------------------
        # Updating train loss...
        # ---------------------------------------------------------------------
        train_loss.update(*model.pretrain_selector(ex_with_doc_sample,
                                                   HasAnswer_list_sample))
        # ---------------------------------------------------------------------

        # Resetting...
        if idx % args.display_iter == 0:

            txt = 'train: Epoch = {} | iter = {}/{} | loss = {:.2f} | '
            txt += 'elapsed time = {:.2f} (s)'
            logger.info(
                txt.format(global_stats['epoch'], idx, len(data_loader),
                           train_loss.avg, global_stats['timer'].time()))

            txt = 'tot_ans: {} | tot_num: {} | tot_ans/tot_num: {:.1f} (%)'
            logger.info(txt.format(tot_ans, tot_num,
                                   tot_ans * 100.0 / tot_num))

            train_loss.reset()

    logger.info('-' * 100)
    txt = 'tot_ans: {} | tot_num: {}'
    logger.info(txt.format(tot_ans, tot_num))
    txt = 'train: Epoch {} done. Time for epoch = {:.2f} (s)'
    logger.info(txt.format(global_stats['epoch'], epoch_time.time()))
    logger.info('-' * 100)
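
Note: before reaching model.pretrain_selector, the per-document answer flags above are collapsed into a [num_docs, batch_size] LongTensor of 0/1 labels. A quick sketch of that conversion on toy flags (hypothetical values, not a real batch):

import torch

# One (has_answer,) tuple per example, per document (toy values).
HasAnswer_list = [[(True,), (False,), (True,)],    # document 0, batch of 3
                  [(False,), (False,), (True,)]]   # document 1

labels = torch.tensor(
    [[int(flags[0]) for flags in per_doc] for per_doc in HasAnswer_list],
    dtype=torch.long)
print(labels.shape)     # torch.Size([2, 3])
print(labels.tolist())  # [[1, 0, 1], [0, 0, 1]]
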