Example #1
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))

        # writer.add_scalar("loss", train_loss.avg, idx)
        # for name, param in model.network.named_parameters():
        #     writer.add_histogram(name, param.clone().cpu().data.numpy(), idx)

        if idx % args.display_iter == 0:
            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
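A minimal sketch of the utils.AverageMeter and utils.Timer helpers the loop above relies on (running mean of the loss, wall-clock timing); the project's actual implementations may differ in detail.

import time


class AverageMeter(object):
    """Keeps a running sum and count so that .avg is the mean value so far."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


class Timer(object):
    """Measures wall-clock time elapsed since construction or the last reset."""

    def __init__(self):
        self.start = time.time()

    def reset(self):
        self.start = time.time()
        return self

    def time(self):
        return time.time() - self.start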
Example #2
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()
    # end_acc = utils.AverageMeter()
    # exact_match = utils.AverageMeter()

    # Make predictions
    examples = 0
    attacked = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex)
        target = ex[-2:-1]

        if args.global_mode == "test":
            # Accumulate (rather than overwrite) the number of predictions that
            # select the final sentence, i.e. the adversarially appended one.
            preds = np.array([p[0] for p in pred])
            sent_lengths = (ex[3].sum(1) - 1).long().data.numpy()
            attacked += (preds == sent_lengths).sum()

        # We get metrics for independent start/end and joint start/end
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)
        # end_acc.update(accuracies[1], batch_size)
        # exact_match.update(accuracies[2], batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) +
                'examples = %d | ' %
                (examples) +
                'valid time = %.2f (s)' % eval_time.time())

    if args.global_mode == "test":
        print(attacked)
        print(examples)
    return {'accuracy': acc.avg}
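The eval_accuracies helper is not shown in these examples; the sketch below is only an assumption about what it computes here (the percentage of examples whose gold sentence index appears among the predicted indices), and the real helper may score things differently.

def eval_accuracies(pred, target, mode):
    """Hypothetical re-implementation: pred is a list of per-example predicted
    sentence indices, target is a one-element list holding the gold indices per
    example; mode is only used for logging by the caller."""
    gold = target[0]
    correct = 0
    for i, candidates in enumerate(pred):
        # An example counts as correct if any predicted index is a gold index.
        if any(c in gold[i] for c in candidates):
            correct += 1
    return 100.0 * correct / max(len(pred), 1)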
Example #3
def validate_official(args, data_loader, model, global_stats, offsets, texts,
                      answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of each example's context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
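For reference, utils.exact_match_score, utils.f1_score, and utils.metric_max_over_ground_truths follow the official SQuAD evaluation script; the sketch below mirrors that public script, though the project's own helpers may normalize text slightly differently.

import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(metric_fn(prediction, gt) for gt in ground_truths)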
Example #4
def main(args):
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = []
    for t_file in args.train_file:
        train_exs += utils.load_data(args, t_file, skip_no_answer=True)
    # Shuffle training examples
    np.random.shuffle(train_exs)
    logger.info('Num train examples = %d' % len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d' % len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = SentenceSelector.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = SentenceSelector.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)

        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words' %
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-6:-1]:
                logger.info(word)
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    train_dataset = data.SentenceSelectorDataset(train_exs, model, single_answer=True)
    if args.sort_by_len:
        train_sampler = data.SortedBatchSampler(train_dataset.lengths(),
                                                args.batch_size,
                                                shuffle=True)
    else:
        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )
    dev_dataset = data.SentenceSelectorDataset(dev_exs, model, single_answer=False)
    if args.sort_by_len:
        dev_sampler = data.SortedBatchSampler(dev_dataset.lengths(),
                                              args.test_batch_size,
                                              shuffle=False)
    else:
        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=args.test_batch_size,
        sampler=dev_sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        pin_memory=args.cuda,
    )

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s' %
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}

    # Test mode: skip training and just evaluate the (preloaded) model on the
    # given dev set.
    if args.global_mode == "test":
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')
        print(result[args.valid_metric])
        exit(0)

    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats, mode='dev')

        # Validate official
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)' %
                        (args.valid_metric, result[args.valid_metric],
                         stats['epoch'], model.updates))
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]
        if epoch % 5 == 0:
            model.save(args.model_file + ".dummy")
Example #5
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()
    # end_acc = utils.AverageMeter()
    # exact_match = utils.AverageMeter()
    # fout = open(os.path.join(DATA_DIR,DUMP_FILE), "w+")

    # Make predictions
    examples = 0
    attacked = 0
    attacked_correct = 0
    correct = 0
    non_adv = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex, top_n=3)
        target = ex[-2:-1]

        if args.global_mode == "test":
            preds = np.array([p[0] for p in pred])
            sent_lengths = (ex[3].sum(1)).long().data.numpy()
            #attacked += (pred == sent_lengths - 1).sum()
            for enum_, p in enumerate(pred):
                # fout.write("%s\t%d\t%d\t%d\n"%(ex[-1][enum_], p[0], p[1], p[2]))
                true_flag = False
                if "high" not in ex[-1][enum_]:
                    non_adv += 1
                    continue
                for q in p:
                    if q in target[0][enum_]:
                        correct += 1
                        true_flag = True
                for q in p:
                    if q == sent_lengths[enum_] - 1:
                        attacked += 1
                    if q == sent_lengths[enum_] - 1 and true_flag:
                        attacked_correct += 1
            #attacked += (pred == sent_lengths - 1).astype(int).sum()
        # We get metrics for independent start/end and joint start/end
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)
        # end_acc.update(accuracies[1], batch_size)
        # exact_match.update(accuracies[2], batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        # Print progress roughly every 1000 examples (batch sizes rarely divide
        # 1000 exactly, so an equality check would almost never fire).
        if examples % 1000 < batch_size:
            print("%d examples completed" % examples)
        if mode == 'train' and examples >= 1e4:
            break
    # fout.close()
    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) + 'examples = %d | ' %
                (examples) + 'valid time = %.2f (s)' % eval_time.time())

    if args.global_mode == "test":
        print("Number of examples attacked succesfully: %d" % attacked)
        print("Total number of correct adversarial examples: %d" % correct)
        print("Number of examples adversarial examples: %d" % non_adv)
        print("Number of correct examples attacked succesfully: %d" %
              attacked_correct)
        print("Number of examples: %d" % examples)
    return {'accuracy': acc.avg}
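Purely illustrative and not part of the project: the raw counters printed above are easier to compare as rates over the adversarial subset (ids containing "high"), for example with a hypothetical helper like this one.

def summarize_attack_stats(attacked, correct, attacked_correct, examples, non_adv):
    # Adversarial examples are exactly those not skipped as non_adv above.
    num_adv = max(examples - non_adv, 1)
    return {
        'attack_success_rate': 100.0 * attacked / num_adv,
        'top3_accuracy_adv': 100.0 * correct / num_adv,
        'attacked_and_correct': 100.0 * attacked_correct / num_adv,
    }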