Example #1
import logging
import traceback

from tqdm import tqdm

# `utils`, `generator`, and `Decode_Candidate_Number` are project-local.


def bert2tag_decoder(args, data_loader, dataset, model, test_input_refactor,
                     pred_arranger, mode, stem_flag=False):
    logging.info('Start Generating Keyphrases for %s ...\n' % mode)
    test_time = utils.Timer()
    if args.dataset_class == "kp20k":
        stem_flag = True

    tot_examples = 0
    tot_predictions = []
    for step, batch in enumerate(tqdm(data_loader)):
        inputs, indices, lengths = test_input_refactor(batch, model.args.device)
        try:
            logit_lists = model.test_bert2tag(inputs, lengths)
        except Exception:
            logging.error(str(traceback.format_exc()))
            continue
            
        # decode logits to phrases per batch
        params = {'examples': dataset.examples,
                  'logit_lists': logit_lists,
                  'indices': indices,
                  'max_phrase_words': args.max_phrase_words,
                  'pooling': args.tag_pooling,
                  'return_num': Decode_Candidate_Number[args.dataset_class],
                  'stem_flag': stem_flag}
            
        batch_predictions = generator.tag2phrase(**params)
        tot_predictions.extend(batch_predictions)
    
    candidate = pred_arranger(tot_predictions)
    return candidate
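
The real decoding work happens inside the project's `generator.tag2phrase`, which is not shown here. As a rough illustration of the general idea only, here is a minimal, hypothetical sketch that collapses a BIO-style tag sequence into phrases; the tag scheme and the function name are assumptions for illustration, not this project's actual implementation.

from typing import List


def bio_tags_to_phrases(tokens: List[str], tags: List[str]) -> List[str]:
    """Collect contiguous B/I spans into phrases (hypothetical BIO scheme)."""
    phrases, current = [], []
    for token, tag in zip(tokens, tags):
        if tag == 'B':                    # a new phrase starts here
            if current:
                phrases.append(' '.join(current))
            current = [token]
        elif tag == 'I' and current:      # continue the open phrase
            current.append(token)
        else:                             # 'O' (or a stray 'I') closes it
            if current:
                phrases.append(' '.join(current))
            current = []
    if current:
        phrases.append(' '.join(current))
    return phrases


# bio_tags_to_phrases(['deep', 'learning', 'is', 'fun'], ['B', 'I', 'O', 'O'])
# -> ['deep learning']
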
Example #2
import logging
import traceback

from tqdm import tqdm

logger = logging.getLogger()  # module-level logger (assumed setup)


def train(args, data_loader, model, train_input_refactor, stats, writer):
    logger.info("start training %s on %s (%d epoch) || local_rank = %d..." %
                (args.model_class, args.dataset_class, stats['epoch'],
                 args.local_rank))

    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    epoch_loss = 0
    epoch_step = 0

    epoch_iterator = tqdm(data_loader,
                          desc="Train_Iteration",
                          disable=args.local_rank not in [-1, 0])
    for step, batch in enumerate(epoch_iterator):
        inputs, indices = train_input_refactor(batch, model.args.device)
        try:
            loss = model.update(step, inputs)
        except Exception:
            logger.error(str(traceback.format_exc()))
            continue

        train_loss.update(loss)
        epoch_loss += loss
        epoch_step += 1

        if args.local_rank in [-1, 0] and step % args.display_iter == 0:
            if args.use_viso:
                writer.add_scalar('train/loss', train_loss.avg, model.updates)
                writer.add_scalar('train/lr',
                                  model.scheduler.get_lr()[0], model.updates)

            logger.info(
                'Local Rank = %d | train: Epoch = %d | iter = %d/%d | ' %
                (args.local_rank, stats['epoch'], step, len(data_loader)) +
                'loss = %.4f | lr = %f | %d updates | elapsed time = %.2f (s)\n'
                % (train_loss.avg, model.scheduler.get_lr()[0], model.updates,
                   stats['timer'].time()))
            train_loss.reset()

    logger.info(
        'Local Rank = %d | Epoch Mean Loss = %.8f (Epoch = %d) | '
        'Time for epoch = %.2f (s)\n'
        % (args.local_rank, epoch_loss / epoch_step, stats['epoch'],
           epoch_time.time()))
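
`utils.AverageMeter` and `utils.Timer` are small project helpers that are not shown in this listing. An `AverageMeter` in this style of training script conventionally looks like the sketch below (an assumption about the helper, not its verbatim source):

class AverageMeter(object):
    """Keeps a running average of a scalar, e.g. the training loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
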
Example #3
def validate_official(args, data_loader, model, global_stats, 
                      offsets, texts, answers):
    """ 
    Uses exact spans and same exact match/F1 score computation as in SQuAD script. 
    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matchs offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id = ex[-1]
        batch_size = ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            # Map the predicted token span back to character offsets,
            # then slice the raw text to recover the answer string.
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
        (global_stats['epoch'], exact_match.avg * 100) +
        'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
        (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
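
The metric helpers defer to the official SQuAD evaluation logic, as the docstring says. For reference, the standard SQuAD definitions that `utils.metric_max_over_ground_truths` and `utils.exact_match_score` presumably mirror look like this (a sketch; the token-overlap `f1_score` is defined analogously):

import re
import string


def normalize_answer(s):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # Score the prediction against every accepted answer and keep the best.
    return max(metric_fn(prediction, gt) for gt in ground_truths)
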
Example #4
def train(args, data_loader, model, global_stats):
    """ Run througn one epoch of model training with the provided data loader """
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))

        if idx > 0 and idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                (global_stats['epoch'], idx, len(data_loader)) +
                'loss = %.2f | elapsed time = %.2f (s)' %
                (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' % 
        (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_name + '.checkpoint', 
            global_stats['epoch'] + 1)
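
The `utils.Timer` used for `epoch_time` and `global_stats['timer']` is the other small helper these loops rely on. A minimal wall-clock timer with the `time()` method seen above might look like this sketch (an assumed interface, not the project's verbatim code):

import time


class Timer(object):
    """Measures elapsed wall-clock time, with stop/resume support."""

    def __init__(self):
        self.running = True
        self.total = 0
        self.start = time.time()

    def resume(self):
        if not self.running:
            self.running = True
            self.start = time.time()
        return self

    def stop(self):
        if self.running:
            self.running = False
            self.total += time.time() - self.start
        return self

    def time(self):
        if self.running:
            return self.total + time.time() - self.start
        return self.total
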
Example #5
import json
import os

import torch

# `utils`, `data`, `vector`, and `logger` come from the surrounding project.


def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, args.skip_no_answer)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file, args.skip_no_answer)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # For official evals:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the text answers for each question
    dev_texts = utils.load_text(args.dev_json)
    dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
    dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # Model
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_name + '.checkpoint'):
        # (resuming from the checkpoint is elided in this excerpt)
        pass
    else:
        logger.info('Training model from scratch ...')
        model = init_from_scratch(args, train_exs, dev_exs)
        model.init_optimizer()

    # if args.tune_partial:
    #     pass
    if args.cuda:
        model.cuda()

    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. Sorting by length is faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size, shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.dev_batch_size, shuffle=True)
    else:
        dev_sampler = torch.utils.data.sampler.RandomSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.dev_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # --------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' % json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN & VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training... ')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate official (dev)
        result = validate_official(args, dev_loader, model, stats,
                                   dev_offsets, dev_texts, dev_answers)
        if args.lrshrink > 0:
            _lr = model.lr_step(result[args.valid_metric])
            logger.info('learning rate is %f' % _lr)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' % 
                (args.valid_metric, result[args.valid_metric],
                    stats['epoch'], model.updates))
            model_save_name = os.path.join(
                args.save_dir, args.model_name + str(stats['epoch']) + '.pt')
            model.save(model_save_name)
            stats['best_valid'] = result[args.valid_metric]
        logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' % 
                   (args.valid_metric, stats['best_valid'],
                    stats['epoch'], model.updates))
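
`data.SortedBatchSampler` is what makes the `sort_by_len` path faster: grouping examples of similar length into the same batch reduces padding. One common way to implement such a sampler is sketched below (an illustration assuming `lengths` holds sortable per-example lengths; the project's real version may differ in details):

import random

from torch.utils.data.sampler import Sampler


class SortedBatchSampler(Sampler):
    """Yields example indices so that each batch groups similar lengths."""

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Sort indices by length (random tie-break), chunk into batches,
        # then shuffle the batch order so that epochs still differ.
        order = sorted(range(len(self.lengths)),
                       key=lambda i: (self.lengths[i], random.random()))
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)
        return iter(idx for batch in batches for idx in batch)

    def __len__(self):
        return len(self.lengths)

Because it yields a flat stream of indices rather than whole batches, it plugs into the DataLoader's `sampler=` argument while `batch_size=` still does the actual grouping, exactly as in `main()` above.
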
Example #6
        " *********************************************************************** "
    )

    # -------------------------------------------------------------------------------------------
    # Method Select
    candidate_decoder = test.select_decoder(args.model_class)
    evaluate_script, main_metric_name = utils.select_eval_script(
        args.dataset_class)
    train_input_refactor, test_input_refactor = utils.select_input_refactor(
        args.model_class)

    # -------------------------------------------------------------------------------------------
    # start training
    # -------------------------------------------------------------------------------------------
    model.zero_grad()
    stats = {'timer': utils.Timer(), 'epoch': 0, main_metric_name: 0}
    for epoch in range(1, (args.max_train_epochs + 1)):
        stats['epoch'] = epoch

        # train
        train(args, train_data_loader, model, train_input_refactor, stats,
              tb_writer)

        # previous metric score
        prev_metric_score = stats[main_metric_name]

        # decode candidate phrases
        dev_candidate = candidate_decoder(args, dev_data_loader, dev_dataset,
                                          model, test_input_refactor,
                                          pred_arranger, 'dev')
        stats = evaluate_script(args,