Example #1
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    from sklearn.metrics import roc_auc_score, f1_score
    eval_time = utils.Timer()
    trigger_acc = utils.AverageMeter()
    start_acc = utils.AverageMeter()
    end_acc = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Make predictions
    all_pred = []
    all_pred_label = []
    all_gt = []
    examples = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)

        pred_score, pred_label, pred_s, pred_e = model.predict(ex)

        target_s, target_e = ex[-4:-2]

        accuracies = eval_accuracies_rc(pred_s, target_s, pred_e, target_e)
        start_acc.update(accuracies[0], batch_size)
        end_acc.update(accuracies[1], batch_size)
        exact_match.update(accuracies[2], batch_size)

        gt_label = ex[-1]
        all_pred.extend([x[1] for x in pred_score])
        all_gt.extend(gt_label)
        all_pred_label.extend(pred_label)
        # Trigger classification accuracy
        accuracies = eval_accuracies(pred_label, gt_label)
        trigger_acc.update(accuracies, batch_size)

        # If computing train accuracies, only evaluate the first ~10k examples
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break
    auc_score = roc_auc_score(all_gt, all_pred)
    f1_scores = f1_score(all_gt, all_pred_label, average=None)

    logger.info('%s valid unofficial: Epoch = %d | ' %
                (mode, global_stats['epoch'], ) +
                'neg_f1 = %.2f | pos_f1 = %.2f |  trigger_auc = %.2f | trigger_acc = %.2f | examples = %d | ' %
                (f1_scores[0], f1_scores[1], auc_score, trigger_acc.avg, examples) +
                'valid time = %.2f (s)' % eval_time.time())

    return {'auc': auc_score, 'trigger_acc': trigger_acc.avg}
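
All of these examples lean on two small helpers, utils.Timer and utils.AverageMeter, that are not shown on this page. A minimal sketch of what they likely look like, modeled on the DrQA-style utilities these snippets appear to follow (an assumption, not the project's actual code):

import time


class Timer(object):
    """Wall-clock timer; time() returns the seconds elapsed since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start


class AverageMeter(object):
    """Maintains a running average of a scalar metric (loss, accuracy, ...)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
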
Example #2
def main(args):
    dev_exs = utils.load_data(args, args.dev_file)
    dev_texts = utils.load_text(args.dev_json)
    dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
    dev_answers = utils.load_answers(args.dev_json)
    model = DocReader.load(args.pretrained, args)
    model.init_optimizer()
    if args.cuda:
        model.cuda()
    if args.parallel:
        model.parallelize()
    dev_dataset = reader_data.ReaderDataset(dev_exs,
                                            model,
                                            single_answer=False)
    if args.sort_by_len:
        dev_sampler = reader_data.SortedBatchSampler(dev_dataset.lengths(),
                                                     args.test_batch_size,
                                                     shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=reader_vector.batchify,
        pin_memory=args.cuda,
    )
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    scores, pred_objs = compute_expected_metric(args, dev_loader, model, stats,
                                                dev_offsets, dev_texts,
                                                dev_answers)
    return scores, pred_objs
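
A hypothetical entry point for this evaluation-only script; the flag names are assumptions and simply cover every attribute that main() reads above:

import argparse

import torch

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Evaluate a pretrained DocReader on a dev set')
    parser.add_argument('--dev-file', required=True)
    parser.add_argument('--dev-json', required=True)
    parser.add_argument('--pretrained', required=True)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--parallel', action='store_true')
    parser.add_argument('--sort-by-len', action='store_true')
    parser.add_argument('--test-batch-size', type=int, default=32)
    parser.add_argument('--data-workers', type=int, default=2)
    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()
    scores, pred_objs = main(args)
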
Example #3
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

        if args.indexcheckpoint != -1 and idx != 0 and idx % args.indexcheckpoint == 0:
            # Integer division so the checkpoint index is not rendered as a float
            checkpoint_name = (args.model_file +
                               str(idx // args.indexcheckpoint) + ':' +
                               str(global_stats['epoch']) + '.checkpoint')
            model.checkpoint(checkpoint_name, global_stats['epoch'])
            logger.info('new checkpoint at: %s' % checkpoint_name)

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
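
model.checkpoint(...) is assumed throughout these examples but never shown. A sketch of what such a save routine typically persists in DrQA-style readers; the attribute names below (e.g. feature_dict) are assumptions, not confirmed by this page:

import logging

import torch

logger = logging.getLogger(__name__)


def checkpoint(model, filename, epoch):
    """Save a resumable snapshot: weights, dictionaries, args, epoch, optimizer."""
    params = {
        'state_dict': model.network.state_dict(),
        'word_dict': model.word_dict,
        'feature_dict': model.feature_dict,
        'args': model.args,
        'epoch': epoch,
        'optimizer': model.optimizer.state_dict(),
    }
    try:
        torch.save(params, filename)
    except BaseException:
        logger.warning('WARN: Saving failed... continuing anyway.')
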
Example #4
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    eval_time = utils.Timer()
    start_acc = utils.AverageMeter()
    end_acc = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Make predictions
    examples = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)
        target_s, target_e = ex[-3:-1]

        # We get metrics for independent start/end and joint start/end
        accuracies = eval_accuracies(pred_s, target_s, pred_e, target_e)
        start_acc.update(accuracies[0], batch_size)
        end_acc.update(accuracies[1], batch_size)
        exact_match.update(accuracies[2], batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial: Epoch = %d | start = %.2f | ' %
                (mode, global_stats['epoch'], start_acc.avg) +
                'end = %.2f | exact = %.2f | examples = %d | ' %
                (end_acc.avg, exact_match.avg, examples) +
                'valid time = %.2f (s)' % eval_time.time())

    return {'exact_match': exact_match.avg}
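
eval_accuracies is called here and in Example #1 but not defined on this page. A self-contained sketch consistent with how it is used, assuming pred_s/pred_e hold ranked candidate positions per example and the targets arrive either as tensors or as lists of acceptable positions:

import torch


def eval_accuracies(pred_s, target_s, pred_e, target_e):
    """Return batch-level start, end, and joint exact-match accuracy (percent)."""
    # Targets may be 1-D tensors (single gold span) or lists of acceptable spans.
    if torch.is_tensor(target_s):
        target_s = [[t] for t in target_s.tolist()]
        target_e = [[t] for t in target_e.tolist()]
    batch_size = len(pred_s)
    start = end = em = 0
    for i in range(batch_size):
        # pred_s[i] / pred_e[i] are ranked candidates; index 0 is the best guess.
        if pred_s[i][0] in target_s[i]:
            start += 1
        if pred_e[i][0] in target_e[i]:
            end += 1
        # Exact match requires start and end to hit the same gold span.
        if any(pred_s[i][0] == s and pred_e[i][0] == e
               for s, e in zip(target_s[i], target_e[i])):
            em += 1
    return (100.0 * start / batch_size,
            100.0 * end / batch_size,
            100.0 * em / batch_size)
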
Example #5
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        '''
        print('ex:::: ' + str(ex))
        sys.exit()
        '''
        train_loss.update(*model.update(ex))

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
Example #6
def validate_official(args, data_loader, model, global_stats, offsets, texts,
                      answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    clean_id_file = open(os.path.join(DATA_DIR, "clean_qids.txt"), "w+")
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    bad_examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        chosen_offset = ex[-2]
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            if pred_s[i][0] >= len(offsets[ex_id[i]]) or pred_e[i][0] >= len(
                    offsets[ex_id[i]]):
                bad_examples += 1
                continue
            if args.use_sentence_selector:
                s_offset = chosen_offset[i][pred_s[i][0]][0]
                e_offset = chosen_offset[i][pred_e[i][0]][1]
            else:
                s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
                e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1_example = utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths)
            f1.update(f1_example)

            # Record question ids where the prediction gets at least partial credit
            if f1_example != 0:
                clean_id_file.write(ex_id[i] + "\n")

        examples += batch_size

    clean_id_file.close()
    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))
    logger.info('Bad Offset Examples during official eval: %d' % bad_examples)
    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
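
The helpers utils.exact_match_score, utils.f1_score, and utils.metric_max_over_ground_truths follow the official SQuAD evaluation script; a self-contained version for reference:

import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every accepted answer and keep the best."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)
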
Example #7
def train(args, data_loader, data_loader_source, data_loader_target,
          train_loader_source_Q, train_loader_target_Q, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader_source):

        # Calculate n_critic
        epoch = global_stats['epoch']
        n_critic = args.n_critic
        if n_critic > 0 and ((epoch == 0 and idx <= 25) or (idx % 500 == 0)):
            n_critic = 10

        train_loss.update(*model.update(ex, n_critic, epoch))

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader_source)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
Example #8
def validate_official(args,
                      data_loader,
                      model,
                      global_stats,
                      offsets,
                      texts,
                      answers,
                      mode="dev"):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))

        examples += batch_size

    logger.info(mode + ' valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
Example #9
def train(args, data_loader, model, global_stats, dev_loader):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))
        global_stats['Loss_Train'] = float(train_loss.avg)
        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    # TODO: factor this dev-loss pass into a separate evaluation function
    if args.show_dev_loss:
        dev_loss = utils.AverageMeter()
        for idx, ex in enumerate(dev_loader):
            if args.cuda:
                inputs = [
                    e if e is None else e.cuda(non_blocking=True)
                    for e in ex[:5]
                ]
                target_s = ex[5].cuda(non_blocking=True)
                target_e = ex[6].cuda(non_blocking=True)
            else:
                inputs = list(ex[:5])
                target_s = ex[5]
                target_e = ex[6]
            score_s, score_e = model.network(*inputs)
            loss = F.nll_loss(score_s, target_s) + F.nll_loss(
                score_e, target_e)
            dev_loss.update(loss.item(), ex[0].size(0))
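        # Assumption: the original snippet stops here; report the accumulated dev
        # loss the same way Loss_Train is tracked above so the caller's stats stay
        # in sync.
        global_stats['Loss_Dev'] = float(dev_loss.avg)
        logger.info('dev loss = %.2f (epoch %d)' %
                    (dev_loss.avg, global_stats['epoch']))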
Example #10
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))

        if idx % args.display_iter == 0:
            logger.info("train: Epoch = %d | iter = %d/%d | " %
                        (global_stats["epoch"], idx, len(data_loader)) +
                        "loss = %.2f | elapsed time = %.2f (s)" %
                        (train_loss.avg, global_stats["timer"].time()))
            train_loss.reset()

    logger.info("train: Epoch %d done. Time for epoch = %.2f (s)" %
                (global_stats["epoch"], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + ".checkpoint",
                         global_stats["epoch"] + 1)
Example #11
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args,
                                args.train_file,
                                skip_no_answer=True,
                                trainset=True)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file, trainset=False)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)
            # COMMENTED OUT QUANTIZATION/TT SUPPORT
            # if args.embed_type != 'plain':
            #     # Jian: replace embeddings if specified by args
            #     replace_embeddings(model.network, args, logger)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(args, train_exs,
                                                 model.word_dict)
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    bert_tokenizer = (None if not args.use_bert_embeddings else
                      BertTokenizer.from_pretrained(
                          args.bert_model_name,
                          do_lower_case=('uncased' in args.bert_model_name)))
    train_dataset = data.ReaderDataset(train_exs,
                                       model,
                                       single_answer=True,
                                       bert_tokenizer=bert_tokenizer)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs,
                                     model,
                                     single_answer=False,
                                     bert_tokenizer=bert_tokenizer)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    f1_scores = []
    exact_match_scores = []

    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # COMMENTED OUT QUANTIZATION/TT SUPPORT
        # Log model parameter status
        # log_param_list(model.network, logger)

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args,
                                     dev_loader,
                                     model,
                                     stats,
                                     mode='dev')

        # Validate official
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]

        f1_scores.append(result['f1'])
        exact_match_scores.append(result['exact_match'])

    return f1_scores, exact_match_scores
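
data.SortedBatchSampler appears in most of the main() functions here but is not shown. A minimal sketch matching the (lengths, batch_size, shuffle) constructor used above, loosely modeled on the DrQA sampler; the Example #12 variant that also takes labels() is not covered:

import random

from torch.utils.data.sampler import Sampler


class SortedBatchSampler(Sampler):
    """Groups examples of similar length into the same mini-batch to reduce padding."""

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Sort indices by length (random tie-break), slice into batches, then
        # optionally shuffle the batches so epochs differ while padding stays low.
        keys = [(length, random.random()) for length in self.lengths]
        indices = sorted(range(len(keys)), key=lambda i: keys[i])
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)
        return iter(i for batch in batches for i in batch)

    def __len__(self):
        return len(self.lengths)
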
Example #12
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        if args.standard:
            dev_texts = utils.load_text_standard(args.dev_json)
        else:
            dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        if args.standard:
            dev_answers = utils.load_answers_standard(args.dev_json)
        else:
            dev_answers = utils.load_answers(args.dev_json)
    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                train_dataset.labels(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              dev_dataset.labels(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')

        # Validate official
        #if args.official_eval:
        #    result = validate_official(args, dev_loader, model, stats,
        #                               dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result['auc'] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        ('auc_score', result['auc'],
                         stats['epoch'], model.updates))
            logger.info('save model %s' % args.model_file)
            model.save(args.model_file)
            stats['best_valid'] = result['auc']
        model.save(os.path.join(os.path.dirname(args.model_file),
                                str(epoch) + '.' + os.path.basename(args.model_file)))
Example #13
def validate_official(args, data_loader, model, global_stats, offsets, texts,
                      questions, answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    em_false = {}  # cid -> {'text': context, 'qa': [{qid, question, answers, prediction}, ...]}
    predictions = {}  # qid -> prediction
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][1][s_offset:e_offset]
            cid = texts[ex_id[i]][0]
            predictions[ex_id[i]] = prediction

            # Compute metrics
            ground_truths = answers[ex_id[i]]

            em_score = utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths)
            if em_score < 1:
                if cid not in em_false:
                    em_false[cid] = {
                        'text':
                        texts[ex_id[i]][1],
                        'qa': [{
                            'qid': ex_id[i],
                            'question': questions[ex_id[i]],
                            'answers': answers[ex_id[i]],
                            'prediction': prediction
                        }]
                    }
                else:
                    em_false[cid]['qa'].append({
                        'qid': ex_id[i],
                        'question': questions[ex_id[i]],
                        'answers': answers[ex_id[i]],
                        'prediction': prediction
                    })

            exact_match.update(em_score)
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {
        'exact_match': exact_match.avg * 100,
        'f1': f1.avg * 100
    }, em_false, predictions
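
A small follow-up sketch showing how Example #13's extra outputs could be written out for error analysis; the file names are illustrative, not part of the original project:

import json
import os


def dump_error_analysis(result, em_false, predictions, out_dir='.'):
    """Persist per-question predictions and the contexts with exact-match failures."""
    with open(os.path.join(out_dir, 'metrics.json'), 'w') as f:
        json.dump(result, f, indent=2)
    with open(os.path.join(out_dir, 'predictions.json'), 'w') as f:
        json.dump(predictions, f, indent=2)
    with open(os.path.join(out_dir, 'em_false.json'), 'w') as f:
        json.dump(em_false, f, indent=2)
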
Example #14
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file, skip_no_answer=True)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)
        train_texts = utils.load_text(args.train_json)
        train_offsets = {ex['id']: ex['offsets'] for ex in train_exs}
        train_answers = utils.load_answers(args.train_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(args, train_exs,
                                                 model.word_dict)
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=True)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {
        'timer': utils.Timer(),
        'epoch': 0,
        'best_valid': 0,
        'epoch_best': 0,
        'F1_Dev_best': 0,
        'EM_Dev_best': 0,
        'S_Dev_best': 0,
        'E_Dev_best': 0,
        'Exact_Dev_best': 0,
        'F1_Dev': 0,
        'EM_Dev': 0,
        'S_Dev': 0,
        'E_Dev': 0,
        'Exact_Dev': 0,
        "Loss_Dev": 0,
        'F1_Train': 0,
        'EM_Train': 0,
        'S_Train': 0,
        'E_Train': 0,
        'Exact_Train': 0,
        "Loss_Train": 0
    }
    loss_file = open(args.loss_file, 'w')
    header = [
        'epoch', 'F1_Dev', 'EM_Dev', 'S_Dev', 'E_Dev', 'Exact_Dev', "Loss_Dev",
        'F1_Train', 'EM_Train', 'S_Train', 'E_Train', 'Exact_Train',
        "Loss_Train"
    ]
    loss_file.write(",".join(header) + "\n")
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats, dev_loader)

        # Validate unofficial (train)
        Train_result = validate_unofficial(args,
                                           train_loader,
                                           model,
                                           stats,
                                           mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args,
                                     dev_loader,
                                     model,
                                     stats,
                                     mode='dev')

        # Validate official
        if args.official_eval:
            result = validate_official(args,
                                       dev_loader,
                                       model,
                                       stats,
                                       dev_offsets,
                                       dev_texts,
                                       dev_answers,
                                       mode='dev')
            stats['F1_Dev'] = result["f1"]
            stats['EM_Dev'] = result["exact_match"]
        if args.test_official_train:
            result_train_official = validate_official(args,
                                                      train_loader,
                                                      model,
                                                      stats,
                                                      train_offsets,
                                                      train_texts,
                                                      train_answers,
                                                      mode='train')

            stats['F1_Train'] = result_train_official["f1"]
            stats['EM_Train'] = result_train_official["exact_match"]
        # Append this epoch's stats to the loss CSV
        toWrite = []
        for key, value in stats.items():
            if (key != 'timer' and key != 'best_valid' and key[-4:] != 'best'):
                toWrite.append(str(round(value, 3)))
        toWrite = ",".join(toWrite) + "\n"
        loss_file.write(toWrite)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
            stats['F1_Dev_best'] = result["f1"]
            stats['EM_Dev_best'] = result["exact_match"]
            stats['S_Dev_best'] = stats['S_Dev']
            stats['E_Dev_best'] = stats['E_Dev']
            stats['Exact_Dev_best'] = stats['Exact_Dev']
            stats['epoch_best'] = stats['epoch']
    loss_file.close()
    with open(args.best_loss_file, 'w+') as logFile:
        #head = ['epoch_best','F1_Dev_best','EM_Dev_best','S_Dev_best','E_Dev_best','Exact_Dev_best','F1_Train','EM_train','S_Train','E_Train','Exact_Train','Loss_Train']
        toWrite = []
        for key, value in stats.items():
            if key[-4:] == 'best' or key[-5:] == "Train":
                toWrite.append(str(round(value, 3)))
        toWrite = ",".join(toWrite)
        logFile.write(toWrite)
Example #15
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)
    else:
        dev_texts = None
        dev_offsets = None
        dev_answers = None

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added_words = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added_words, args.embedding_file)

                logger.info('Expanding char dictionary for new data...')
                # Add words in training + dev examples
                chars = utils.load_chars(args, train_exs + dev_exs)
                added_chars = model.expand_char_dictionary(chars)
                # Load pretrained embeddings for added words
                if args.char_embedding_file:
                    model.load_char_embeddings(added_chars, args.char_embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    if args.use_ema:
        ema = EMA(args.decay)
        model.ema = ema
        for name, param in model.network.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    model_prefix = os.path.join(args.model_dir, args.model_name)

    kept_models = []
    best_model_path = ''
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        logger.info('eval: train split unofficially...')
        validate_unofficial(args, train_loader, model, stats, mode='train')

        if args.official_eval:
            # Validate official (dev)
            logger.info('eval: dev split officially...')
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)
        else:
            # Validate unofficial (dev)
            logger.info('eval: dev split unofficially...')
            result = validate_unofficial(args, dev_loader, model, stats, mode='dev')

        em = result['exact_match']
        f1 = result['f1']
        suffix = 'em_{:4.2f}-f1_{:4.2f}.mdl'.format(em, f1)
        # Save best valid
        model_file = '{}-epoch_{}-{}'.format(model_prefix, epoch, suffix)
        if args.valid_metric:
            if result[args.valid_metric] > stats['best_valid']:
                for f in glob.glob('{}-best*'.format(model_prefix)):
                    os.remove(f)
                logger.info('eval: dev best %s = %.2f (epoch %d, %d updates)' %
                            (args.valid_metric, result[args.valid_metric],
                             stats['epoch'], model.updates))
                model_file = '{}-best-epoch_{}-{}'.format(model_prefix, epoch, suffix)
                best_model_path = model_file
                model.save(model_file)
                stats['best_valid'] = result[args.valid_metric]
                # for f in kept_models:
                #     os.remove(f)
                kept_models.clear()
            else:
                # model.save(model_file)
                kept_models.append(model_file)
                if len(kept_models) >= args.early_stop:
                    logger.info('Finished training due to %s not improved for %d epochs, best model is at: %s' %
                                (args.valid_metric, args.early_stop, best_model_path))
                    return
        else:
            # just save model every epoch since no validation metric is given
            model.save(model_file)
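
The EMA class used by Example #15 (args.use_ema) is not defined on this page. A minimal exponential-moving-average tracker compatible with the register(name, value) call above is sketched here as an assumption; the training loop would call it after each optimizer step to smooth the parameters:

class EMA(object):
    """Tracks an exponential moving average of model parameters."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, value):
        # Remember the initial value of each trainable parameter.
        self.shadow[name] = value.clone()

    def __call__(self, name, value):
        # new_average = (1 - decay) * current + decay * old_average
        assert name in self.shadow
        new_average = (1.0 - self.decay) * value + self.decay * self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average
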
Example #16
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info("-" * 100)
    logger.info("Load data files")
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info("Num train examples = %d" % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info("Num dev examples = %d" % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex["id"]: ex["offsets"] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info("-" * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + ".checkpoint"):
        # Just resume training, no modifications.
        logger.info("Found a checkpoint...")
        checkpoint_file = args.model_file + ".checkpoint"
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info("Using pretrained model...")
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info("Expanding dictionary for new data...")
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info("Training model from scratch...")
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info("-" * 100)
            logger.info("Counting %d most frequent question words" %
                        args.tune_partial)
            top_words = utils.top_question_words(args, train_exs,
                                                 model.word_dict)
            for word in top_words[:5]:
                logger.info(word)
            logger.info("...")
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info("-" * 100)
    logger.info("Make data loaders")
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info("-" * 100)
    logger.info("CONFIG:\n%s" %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info("-" * 100)
    logger.info("Starting training...")
    stats = {"timer": utils.Timer(), "epoch": 0, "best_valid": 0}
    for epoch in range(start_epoch, args.num_epochs):
        stats["epoch"] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        # validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate official
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)
        else:
            # Validate unofficial (dev)
            result = validate_unofficial(args,
                                         dev_loader,
                                         model,
                                         stats,
                                         mode="dev")

        # Save best valid
        if result[args.valid_metric] > stats["best_valid"]:
            logger.info("Best valid: %s = %.2f (epoch %d, %d updates)" % (
                args.valid_metric,
                result[args.valid_metric],
                stats["epoch"],
                model.updates,
            ))
            model.save(args.model_file)
            stats["best_valid"] = result[args.valid_metric]
Example #17
def main(args):

    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = []
    for t_file in args.train_file:
        train_exs += utils.load_data(args, t_file, skip_no_answer=True)
    np.random.shuffle(train_exs)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    ## OFFSET comes from the gold sentence; the predicted sentence value should be maintained and passed on to the official validation
    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(args, train_exs,
                                                 model.word_dict)
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    # Sentence selection objective : run the sentence selector as a submodule
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = reader_data.ReaderDataset(train_exs,
                                              model,
                                              single_answer=True)
    # Filter out None examples in the training dataset (where sentence selection
    # fails); currently disabled:
    #train_dataset.examples = [t for t in train_dataset.examples if t is not None]
    if args.sort_by_len:
        train_sampler = reader_data.SortedBatchSampler(train_dataset.lengths(),
                                                       args.batch_size,
                                                       shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
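    # NOTE: when the sentence selector is enabled, a sentence_batchifier is
    # constructed, but its batchify method is left commented out below, so the
    # default reader_vector.batchify collate function is used either way.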
    if args.use_sentence_selector:
        train_batcher = reader_vector.sentence_batchifier(model,
                                                          single_answer=True)
        # batching_function = train_batcher.batchify
        batching_function = reader_vector.batchify
    else:
        batching_function = reader_vector.batchify
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=batching_function,
        pin_memory=args.cuda,
    )
    dev_dataset = reader_data.ReaderDataset(dev_exs,
                                            model,
                                            single_answer=False)
    #dev_dataset.examples = [t for t in dev_dataset.examples if t is not None]
    if args.sort_by_len:
        dev_sampler = reader_data.SortedBatchSampler(dev_dataset.lengths(),
                                                     args.test_batch_size,
                                                     shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)

    if args.use_sentence_selector:
        dev_batcher = reader_vector.sentence_batchifier(model,
                                                        single_answer=False)
        # batching_function = dev_batcher.batchify
        batching_function = reader_vector.batchify
    else:
        batching_function = reader_vector.batchify
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=batching_function,
        pin_memory=args.cuda,
    )

    ## Dev dataset for measuring performance of the trained sentence selector
    if args.use_sentence_selector:
        dev_dataset1 = selector_data.SentenceSelectorDataset(
            dev_exs, model.sentence_selector, single_answer=False)
        #dev_dataset1.examples = [t for t in dev_dataset.examples if t is not None]
        if args.sort_by_len:
            dev_sampler1 = selector_data.SortedBatchSampler(
                dev_dataset1.lengths(), args.test_batch_size, shuffle=False)
        else:
            dev_sampler1 = torch.utils.data.sampler.SequentialSampler(
                dev_dataset1)
        dev_loader1 = torch.utils.data.DataLoader(
            dev_dataset1,
            #batch_size=args.test_batch_size,
            #sampler=dev_sampler1,
            batch_sampler=dev_sampler1,
            num_workers=args.data_workers,
            collate_fn=selector_vector.batchify,
            pin_memory=args.cuda,
        )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    # --------------------------------------------------------------------------
    # QUICKLY VALIDATE ON PRETRAINED MODEL
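    # In "test" mode, skip training entirely: run unofficial and official
    # validation (plus adversarial validation if adversarial dev files are
    # given) and exit.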

    if args.global_mode == "test":
        result1 = validate_unofficial(args,
                                      dev_loader,
                                      model,
                                      stats,
                                      mode='dev')
        result2 = validate_official(args, dev_loader, model, stats,
                                    dev_offsets, dev_texts, dev_answers)
        print(result2[args.valid_metric])
        print(result1["exact_match"])
        if args.use_sentence_selector:
            sent_stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
            #sent_selector_results = validate_selector(model.sentence_selector.args, dev_loader1, model.sentence_selector, sent_stats, mode="dev")
            #print("Sentence Selector model acheives:")
            #print(sent_selector_results["accuracy"])

        if len(args.adv_dev_json) > 0:
            validate_adversarial(args, model, stats, mode="dev")
        exit(0)

    valid_history = []
    bad_counter = 0
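    # Early stopping: save on improvement; stop once the validation metric
    # fails to improve for more than args.patience consecutive epochs.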
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args,
                                     dev_loader,
                                     model,
                                     stats,
                                     mode='dev')

        # Validate official
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result[args.valid_metric] >= stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
            bad_counter = 0
        else:
            bad_counter += 1
        if bad_counter > args.patience:
            logger.info("Early Stopping at epoch: %d" % epoch)
            exit(0)
Example #18
0
def validate_adversarial(args, model, global_stats, mode="dev"):
    # Create a dataloader for each adversarial dev set, load its JSON, and run the evaluation.

    for idx, dataset_file in enumerate(args.adv_dev_json):

        predictions = {}

        logger.info("Validating Adversarial Dataset %s" % dataset_file)
        exs = utils.load_data(args, args.adv_dev_file[idx])
        logger.info('Num dev examples = %d' % len(exs))
        ## Create dataloader
        dev_dataset = reader_data.ReaderDataset(exs,
                                                model,
                                                single_answer=False)
        if args.sort_by_len:
            dev_sampler = reader_data.SortedBatchSampler(dev_dataset.lengths(),
                                                         args.test_batch_size,
                                                         shuffle=False)
        else:
            dev_sampler = torch.utils.data.sampler.SequentialSampler(
                dev_dataset)
        if args.use_sentence_selector:
            dev_batcher = reader_vector.sentence_batchifier(
                model, single_answer=False)
            #batching_function = dev_batcher.batchify
            batching_function = reader_vector.batchify
        else:
            batching_function = reader_vector.batchify
        dev_loader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=args.test_batch_size,
            sampler=dev_sampler,
            num_workers=args.data_workers,
            collate_fn=batching_function,
            pin_memory=args.cuda,
        )

        texts = utils.load_text(dataset_file)
        offsets = {ex['id']: ex['offsets'] for ex in exs}
        answers = utils.load_answers(dataset_file)
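        # The original passage text, token-to-character offsets, and gold
        # answers are needed to map predicted spans back to strings and score them.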

        eval_time = utils.Timer()
        f1 = utils.AverageMeter()
        exact_match = utils.AverageMeter()

        examples = 0
        bad_examples = 0
        for ex in dev_loader:
            ex_id, batch_size = ex[-1], ex[0].size(0)
            chosen_offset = ex[-2]
            pred_s, pred_e, _ = model.predict(ex)
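            # pred_s / pred_e hold the predicted start/end token indices
            # (top candidate first) for each example in the batch.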

            for i in range(batch_size):
                if pred_s[i][0] >= len(
                        offsets[ex_id[i]]) or pred_e[i][0] >= len(
                            offsets[ex_id[i]]):
                    bad_examples += 1
                    continue
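                # Map the predicted token span back to character offsets in the
                # passage (sentence-level offsets when the selector is used).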
                if args.use_sentence_selector:
                    s_offset = chosen_offset[i][pred_s[i][0]][0]
                    e_offset = chosen_offset[i][pred_e[i][0]][1]
                else:
                    s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
                    e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
                prediction = texts[ex_id[i]][s_offset:e_offset]

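                # With select_k > 1, several sentences may have been selected;
                # rebuild the prediction by concatenating every offset span in
                # the predicted range.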
                if args.select_k > 1:
                    prediction = ""
                    offset_subset = chosen_offset[i][pred_s[i][0]:pred_e[i][0]]
                    for enum_, o in enumerate(offset_subset):
                        prediction += texts[ex_id[i]][o[0]:o[1]] + " "
                    prediction = prediction.strip()

                predictions[ex_id[i]] = prediction

                ground_truths = answers[ex_id[i]]
                exact_match.update(
                    utils.metric_max_over_ground_truths(
                        utils.exact_match_score, prediction, ground_truths))
                f1.update(
                    utils.metric_max_over_ground_truths(
                        utils.f1_score, prediction, ground_truths))

            examples += batch_size

        logger.info(
            'dev valid official for dev file %s : Epoch = %d | EM = %.2f | ' %
            (dataset_file, global_stats['epoch'], exact_match.avg * 100) +
            'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
            (f1.avg * 100, examples, eval_time.time()))

        orig_f1_score = 0.0
        orig_exact_match_score = 0.0
        adv_f1_scores = {}  # Map from original ID to F1 score
        adv_exact_match_scores = {}  # Map from original ID to exact match score
        adv_ids = {}
        all_ids = set()  # Set of all original IDs
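        # For each original question, keep the score of its worst-scoring
        # adversarial variant so the numbers reflect worst-case robustness.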
        f1 = exact_match = 0
        dataset = json.load(open(dataset_file))['data']
        for article in dataset:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    orig_id = qa['id'].split('-')[0]
                    all_ids.add(orig_id)
                    if qa['id'] not in predictions:
                        message = 'Unanswered question ' + qa[
                            'id'] + ' will receive score 0.'
                        # logger.info(message)
                        continue
                    ground_truths = list(
                        map(lambda x: x['text'], qa['answers']))
                    prediction = predictions[qa['id']]
                    cur_exact_match = utils.metric_max_over_ground_truths(
                        utils.exact_match_score, prediction, ground_truths)
                    cur_f1 = utils.metric_max_over_ground_truths(
                        utils.f1_score, prediction, ground_truths)
                    if orig_id == qa['id']:
                        # This is an original example
                        orig_f1_score += cur_f1
                        orig_exact_match_score += cur_exact_match
                        if orig_id not in adv_f1_scores:
                            # Haven't seen adversarial example yet, so use original for adversary
                            adv_ids[orig_id] = orig_id
                            adv_f1_scores[orig_id] = cur_f1
                            adv_exact_match_scores[orig_id] = cur_exact_match
                    else:
                        # This is an adversarial example
                        if (orig_id not in adv_f1_scores
                                or adv_ids[orig_id] == orig_id
                                or adv_f1_scores[orig_id] > cur_f1):
                            # Override if the slot still holds the original example or this adversarial variant scores lower (keep the worst case)
                            adv_ids[orig_id] = qa['id']
                            adv_f1_scores[orig_id] = cur_f1
                            adv_exact_match_scores[orig_id] = cur_exact_match
        orig_f1 = 100.0 * orig_f1_score / len(all_ids)
        orig_exact_match = 100.0 * orig_exact_match_score / len(all_ids)
        adv_exact_match = 100.0 * sum(
            adv_exact_match_scores.values()) / len(all_ids)
        adv_f1 = 100.0 * sum(adv_f1_scores.values()) / len(all_ids)
        logger.info(
            "For the file %s Original Exact Match : %.4f ; Original F1 : %.4f | "
            % (dataset_file, orig_exact_match, orig_f1) +
            "Adversarial Exact Match : %.4f ; Adversarial F1 : %.4f " %
            (adv_exact_match, adv_f1))
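
A minimal, self-contained sketch of the worst-case aggregation performed above, run on a toy list of (id, prediction, answers) triples; the simplified exact_match helper and the "-turk0" id suffix are assumptions standing in for utils.metric_max_over_ground_truths and the real adversarial question ids:

def exact_match(prediction, ground_truths):
    # Simplified stand-in for
    # utils.metric_max_over_ground_truths(utils.exact_match_score, ...).
    return float(prediction.strip().lower() in
                 [g.strip().lower() for g in ground_truths])

def worst_case_em(examples):
    # examples: list of (qa_id, prediction, ground_truths). Adversarial ids
    # are assumed to look like "<orig_id>-turk0"; originals are "<orig_id>".
    adv_scores, adv_ids, all_ids = {}, {}, set()
    for qa_id, prediction, ground_truths in examples:
        orig_id = qa_id.split('-')[0]
        all_ids.add(orig_id)
        score = exact_match(prediction, ground_truths)
        if orig_id == qa_id:
            # Original example: use it only until an adversarial variant shows up.
            if orig_id not in adv_scores:
                adv_ids[orig_id], adv_scores[orig_id] = orig_id, score
        elif (orig_id not in adv_scores or adv_ids[orig_id] == orig_id
              or adv_scores[orig_id] > score):
            # Keep the lowest-scoring adversarial variant (worst case).
            adv_ids[orig_id], adv_scores[orig_id] = qa_id, score
    return 100.0 * sum(adv_scores.values()) / len(all_ids)

print(worst_case_em([
    ('q1', 'Paris', ['Paris']),        # original answered correctly
    ('q1-turk0', 'Lyon', ['Paris']),   # adversarial variant answered wrongly
    ('q2', 'blue', ['blue']),          # original only
]))  # -> 50.0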