Example #1
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    train_loss = utils.AverageMeter()
    epoch_time = utils.Timer()

    # Run one epoch
    for idx, ex in enumerate(data_loader):
        train_loss.update(*model.update(ex))

        # writer.add_scalar("loss", train_loss.avg, idx)
        # for name, param in model.network.named_parameters():
        #     writer.add_histogram(name, param.clone().cpu().data.numpy(), idx)

        if idx % args.display_iter == 0:

            logger.info('train: Epoch = %d | iter = %d/%d | ' %
                        (global_stats['epoch'], idx, len(data_loader)) +
                        'loss = %.2f | elapsed time = %.2f (s)' %
                        (train_loss.avg, global_stats['timer'].time()))
            train_loss.reset()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
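
The loop above leans on two small helpers from the project's `utils` module, `AverageMeter` and `Timer`, whose implementations are not shown here. A minimal sketch consistent with how they are used (a running average with `update(value, n)` and an elapsed-time reader) might look like this; the real project's versions may differ:

import time


class AverageMeter(object):
    """Track a running average, e.g. of the training loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Timer(object):
    """Report wall-clock seconds elapsed since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start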
Example #2
def validate_official(args, data_loader, model, global_stats, offsets, texts,
                      answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
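
The official validation relies on `utils.exact_match_score`, `utils.f1_score`, and `utils.metric_max_over_ground_truths`. These mirror the metric functions in the official SQuAD evaluation script; a condensed sketch of that logic is shown below (assumed here, since the project's `utils` module is not included and may differ in details such as normalization order):

import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop articles and punctuation, and collapse whitespace."""
    s = s.lower()
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    s = ''.join(ch for ch in s if ch not in string.punctuation)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every accepted answer and keep the best."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)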
Example #3
def eval_accuracies(pred, target, mode="dev"):
    """An unofficial evalutation helper.
    Compute exact start/end/complete match accuracies for a batch.
    """
    # Convert 1D tensors to lists of lists (compatibility)
    if torch.is_tensor(target):
        target = [[e] for e in target]
    elif torch.is_tensor(target[0]):
        target = [[e.item()] for e in target[0]]
    else:
        target = target[0]

        # target_e = [[e] for e in target_e]

    # TODO: adjust behavior according to `mode` (currently unused)

    # Compute accuracies from targets
    batch_size = len(pred)
    accuracy = utils.AverageMeter()
    # end = utils.AverageMeter()
    # em = utils.AverageMeter()
    for i in range(batch_size):
        # Start matches
        flag = False
        for j in pred[i]:
            if j in target[i]:
                flag = True
                break
        if flag:
            accuracy.update(1)
        else:
            accuracy.update(0)

        # End matches
        # if pred_e[i] in target_e[i]:
        #     end.update(1)
        # else:
        #     end.update(0)
        #
        # # Both start and end match
        # if any([1 for _s, _e in zip(target_s[i], target_e[i])
        #         if _s == pred_s[i] and _e == pred_e[i]]):
        #     em.update(1)
        # else:
        #     em.update(0)
    return accuracy.avg * 100
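
As a quick illustration of the interface: each entry of `pred` is a list of candidate indices, `target` arrives wrapped in an outer list, and an example counts as correct if any candidate appears in its target list. The values below are toy data, and the snippet assumes `eval_accuracies` above is in scope together with the `utils` helpers it uses:

pred = [[7, 3], [2, 5]]            # top candidates for two examples
target = [[[7, 9], [4]]]           # accepted indices, wrapped as eval_accuracies expects

# Example 0 hits (7 is in [7, 9]); example 1 misses -> 50% accuracy.
print(eval_accuracies(pred, target))  # 50.0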
Example #4
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()
    # end_acc = utils.AverageMeter()
    # exact_match = utils.AverageMeter()

    # Make predictions
    examples = 0
    attacked = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex)
        target = ex[-2:-1]

        if args.global_mode == "test":
            preds = np.array([p[0] for p in pred])
            sent_lengths = (ex[3].sum(1) - 1).long().data.numpy()
            attacked += (preds == sent_lengths).sum()

        # We get metrics for independent start/end and joint start/end
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)
        # end_acc.update(accuracies[1], batch_size)
        # exact_match.update(accuracies[2], batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) +
                'examples = %d | ' %
                (examples) +
                'valid time = %.2f (s)' % eval_time.time())

    if args.global_mode == "test":
        print(attacked)
        print(examples)
    return {'accuracy': acc.avg}
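
Note that `target = ex[-2:-1]` deliberately slices instead of indexing, so the targets stay wrapped in a one-element list and `eval_accuracies` can unwrap them via `target[0]`. A tiny standalone illustration of the difference (the batch layout here is made up for demonstration):

batch = ['tokens', 'features', 'mask', [[3], [8]], ['id0', 'id1']]

sliced = batch[-2:-1]    # [[[3], [8]]] -- targets still wrapped in a list
indexed = batch[-2]      # [[3], [8]]   -- the target lists themselves

print(sliced[0] == indexed)  # True: eval_accuracies reads target[0]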
Example #5
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    eval_time = utils.Timer()
    acc = utils.AverageMeter()
    # end_acc = utils.AverageMeter()
    # exact_match = utils.AverageMeter()
    # fout = open(os.path.join(DATA_DIR,DUMP_FILE), "w+")

    # Make predictions
    examples = 0
    attacked = 0
    attacked_correct = 0
    correct = 0
    non_adv = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        pred = model.predict(ex, top_n=3)
        target = ex[-2:-1]

        if args.global_mode == "test":
            preds = np.array([p[0] for p in pred])
            sent_lengths = (ex[3].sum(1)).long().data.numpy()
            #attacked += (pred == sent_lengths - 1).sum()
            for enum_, p in enumerate(pred):
                # fout.write("%s\t%d\t%d\t%d\n"%(ex[-1][enum_], p[0], p[1], p[2]))
                true_flag = False
                if "high" not in ex[-1][enum_]:
                    non_adv += 1
                    continue
                for q in p:
                    if q in target[0][enum_]:
                        correct += 1
                        true_flag = True
                for q in p:
                    if q == sent_lengths[enum_] - 1:
                        attacked += 1
                    if q == sent_lengths[enum_] - 1 and true_flag:
                        attacked_correct += 1
            #attacked += (pred == sent_lengths - 1).astype(int).sum()
        # We get metrics for independent start/end and joint start/end
        accuracy = eval_accuracies(pred, target, mode)
        acc.update(accuracy, batch_size)
        # end_acc.update(accuracies[1], batch_size)
        # exact_match.update(accuracies[2], batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if examples % 1000 == 0:
            print("%d examples completed" % examples)
        if mode == 'train' and examples >= 1e4:
            break
    # fout.close()
    logger.info('%s valid unofficial: Epoch = %d | accuracy = %.2f | ' %
                (mode, global_stats['epoch'], acc.avg) + 'examples = %d | ' %
                (examples) + 'valid time = %.2f (s)' % eval_time.time())

    if args.global_mode == "test":
        print("Number of examples attacked succesfully: %d" % attacked)
        print("Total number of correct adversarial examples: %d" % correct)
        print("Number of examples adversarial examples: %d" % non_adv)
        print("Number of correct examples attacked succesfully: %d" %
              attacked_correct)
        print("Number of examples: %d" % examples)
    return {'accuracy': acc.avg}
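
The test-mode counters above are printed as raw counts. If percentages are more convenient, they can be derived from the same counters; a hypothetical helper (not part of the original script) that assumes every example not counted in `non_adv` is an adversarial one:

def summarize_attack_stats(attacked, attacked_correct, correct, non_adv, examples):
    """Convert the raw counters from validate_unofficial into percentages.

    Hypothetical helper: the denominator assumes examples - non_adv is the
    number of adversarial examples seen during validation.
    """
    adversarial = examples - non_adv
    if adversarial <= 0:
        return {}
    return {
        'attack_success_rate': 100.0 * attacked / adversarial,
        'adversarial_accuracy': 100.0 * correct / adversarial,
        'attacked_and_correct_rate': 100.0 * attacked_correct / adversarial,
    }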