Example #1
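# Shared imports these examples rely on; project-specific helpers such as
# AverageMeter, Timer, Bleu, Rouge, normalize_string, and the ranking/EM/F1
# metrics are assumed to come from the repository's own utility modules.
import json
import logging
from collections import OrderedDict

import numpy as np
import torch
from tqdm import tqdm

logger = logging.getLogger(__name__)
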
def train(args, data_loader, model, global_stats):
    """Run through one epoch of model training with the provided data loader."""
    # Initialize meters + timers
    ml_loss = AverageMeter()
    epoch_time = Timer()
    # Decay the learning rate once at the start of every epoch
    model.optimizer.param_groups[0]['lr'] *= args.lr_decay

    pbar = tqdm(data_loader)
    pbar.set_description("%s" % 'Epoch = %d [ml_loss = x.xx]' %
                         global_stats['epoch'])

    # Run one epoch
    for idx, ex in enumerate(pbar):
        bsz = ex['batch_size']
        net_loss = model.update(ex)
        ml_loss.update(net_loss.item(), bsz)

        log_info = 'Epoch = %d [ml_loss = %.2f]' % \
                   (global_stats['epoch'], ml_loss.avg)

        pbar.set_description("%s" % log_info)
        torch.cuda.empty_cache()

    logger.info('train: Epoch %d done. Time for epoch = %.2f (s)' %
                (global_stats['epoch'], epoch_time.time()))

    # Checkpoint
    if args.checkpoint:
        model.checkpoint(args.model_file + '.checkpoint',
                         global_stats['epoch'] + 1)
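

# A minimal sketch of the meter/timer utilities the training loop above
# assumes (an illustration, not necessarily the repository's exact
# implementation):
import time


class AverageMeter(object):
    """Tracks a running sum and count and exposes the running average."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # `val` is a per-batch value, `n` the number of items it covers.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


class Timer(object):
    """Measures wall-clock time elapsed since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start
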
def eval_accuracies(hypotheses,
                    references,
                    copy_info,
                    sources=None,
                    filename=None,
                    print_copy_info=False):
    """An unofficial evalutation helper.
    Arguments:
        hypotheses: A mapping from instance id to predicted sequences.
        references: A mapping from instance id to ground truth sequences.
        copy_info: Map of id --> copy information.
        sources: Map of id --> input text sequence.
        filename: If provided, write per-instance predictions, references,
            and scores to this file as JSON lines.
        print_copy_info: If True, annotate each predicted token with its
            copy information in the logged predictions.
    """
    assert (sorted(references.keys()) == sorted(hypotheses.keys()))

    # Compute BLEU scores
    bleu_scorer = Bleu(n=4)
    bleu, ind_bleu = bleu_scorer.compute_score(references,
                                               hypotheses,
                                               verbose=0)

    # Compute ROUGE scores
    rouge_calculator = Rouge()
    rouge_l, ind_rouge = rouge_calculator.compute_score(references, hypotheses)

    f1 = AverageMeter()
    exact_match = AverageMeter()
    fw = open(filename, 'w') if filename else None
    for key in references.keys():
        exact_match.update(
            metric_max_over_ground_truths(exact_match_score,
                                          hypotheses[key][0], references[key]))
        f1.update(
            metric_max_over_ground_truths(f1_score, hypotheses[key][0],
                                          references[key]))
        if fw:
            if copy_info is not None and print_copy_info:
                prediction = hypotheses[key][0].split()
                pred_i = [
                    word + ' [' + str(copy_info[key][j]) + ']'
                    for j, word in enumerate(prediction)
                ]
                pred_i = [' '.join(pred_i)]
            else:
                pred_i = hypotheses[key]

            logobj = OrderedDict()
            logobj['session_id'] = key
            if sources is not None:
                logobj['previous_queries'] = sources[key]
            logobj['predictions'] = pred_i
            logobj['references'] = references[key][0] if args.print_one_target \
                else references[key]
            logobj['bleu'] = ind_bleu[key]
            fw.write(json.dumps(logobj) + '\n')

    if fw:
        fw.close()
    return bleu, rouge_l * 100, exact_match.avg * 100, f1.avg * 100
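

# The EM/F1 helpers used above follow the standard SQuAD-style recipe; a
# minimal sketch for reference (the repository may apply additional text
# normalization):
from collections import Counter


def exact_match_score(prediction, ground_truth):
    # 1.0 if the prediction matches the reference exactly (case-insensitive).
    return float(prediction.strip().lower() == ground_truth.strip().lower())


def f1_score(prediction, ground_truth):
    # Token-level F1 between the prediction and a single reference.
    pred_tokens = prediction.lower().split()
    gt_tokens = ground_truth.lower().split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # Score the prediction against every reference and keep the best value.
    return max(metric_fn(prediction, gt) for gt in ground_truths)
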
Example #3
def validate_official(args, data_loader, model, global_stats=None):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.
    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = Timer()
    # Run through examples
    examples = 0
    map = AverageMeter()
    mrr = AverageMeter()
    prec_1 = AverageMeter()
    prec_3 = AverageMeter()
    prec_5 = AverageMeter()
    sources, hypotheses, references = dict(), dict(), dict()
    with torch.no_grad():
        pbar = tqdm(data_loader)
        for ex in pbar:
            batch_size = ex['batch_size'] * ex['session_len']
            outputs = model.predict(ex)

            scores = outputs['click_scores'].view(batch_size, -1).contiguous()
            labels = ex['document_labels'].view(batch_size,
                                                -1).contiguous().numpy()
            predictions = np.argsort(
                -scores.cpu().numpy())  # sort in descending order

            map.update(MAP(predictions, labels))
            mrr.update(MRR(predictions, labels))
            prec_1.update(precision_at_k(predictions, labels, 1))
            prec_3.update(precision_at_k(predictions, labels, 3))
            prec_5.update(precision_at_k(predictions, labels, 5))

            ex_ids = outputs['ex_ids']
            predictions = outputs['predictions']
            targets = outputs['targets']
            src_sequences = outputs['src_sequences']
            examples += batch_size

            for key, src, pred, tgt in zip(ex_ids, src_sequences, predictions,
                                           targets):
                hypotheses[key] = [normalize_string(p) for p in pred] \
                    if isinstance(pred, list) else [normalize_string(pred)]
                references[key] = [normalize_string(t) for t in tgt]
                sources[key] = src

            if global_stats is not None:
                pbar.set_description('Epoch = %d [validating ... ]' %
                                     global_stats['epoch'])
            else:
                pbar.set_description('[evaluating ... ]')

    bleu, rouge, exact_match, f1 = eval_accuracies(
        hypotheses,
        references,
        None,
        sources=sources,
        filename=args.pred_file,
        print_copy_info=args.print_copy_info)

    bleu = [b * 100 for b in bleu] \
        if isinstance(bleu, list) else bleu
    result = dict()
    result['rouge'] = rouge
    result['bleu'] = sum(bleu) / len(bleu) \
        if isinstance(bleu, list) else bleu
    result['em'] = exact_match
    result['f1'] = f1

    result['map'] = map.avg
    result['mrr'] = mrr.avg
    result['prec@1'] = prec_1.avg
    result['prec@3'] = prec_3.avg
    result['prec@5'] = prec_5.avg

    if global_stats is None:
        logger.info(
            'test results: MAP = %.2f | MRR = %.2f | Prec@1 = %.2f | ' %
            (result['map'], result['mrr'], result['prec@1']) +
            'Prec@3 = %.2f | Prec@5 = %.2f | ' %
            (result['prec@3'], result['prec@5']) +
            'rouge_l = %.2f | bleu = [%s] | ' %
            (rouge, ", ".join(format(b, ".2f") for b in bleu)) +
            'EM = %.2f | F1 = %.2f | examples = %d | ' %
            (exact_match, f1, examples) +
            'test time = %.2f (s)' % eval_time.time())
    else:
        logger.info('dev results: MAP = %.2f | MRR = %.2f | Prec@1 = %.2f | ' %
                    (result['map'], result['mrr'], result['prec@1']) +
                    'Prec@3 = %.2f | Prec@5 = %.2f | ' %
                    (result['prec@3'], result['prec@5']) +
                    'rouge_l = %.2f | bleu = [%s] | ' %
                    (rouge, ", ".join(format(b, ".2f") for b in bleu)) +
                    'EM = %.2f | F1 = %.2f | examples = %d | ' %
                    (exact_match, f1, examples) +
                    'valid time = %.2f (s)' % eval_time.time())

    return result
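

# The ranking metrics used above operate on batched rankings: each row of
# `predictions` holds document indices sorted by descending score, and each
# row of `labels` holds 0/1 relevance judgments. A minimal sketch of such
# metrics (an assumption, not necessarily the repository's implementation):
import numpy as np


def MAP(predictions, labels):
    # Mean average precision over the batch.
    scores = []
    for ranked, rel in zip(predictions, labels):
        hits, precisions = 0, []
        for rank, doc in enumerate(ranked, start=1):
            if rel[doc]:
                hits += 1
                precisions.append(hits / rank)
        scores.append(np.mean(precisions) if precisions else 0.0)
    return float(np.mean(scores))


def MRR(predictions, labels):
    # Mean reciprocal rank of the first relevant document.
    scores = []
    for ranked, rel in zip(predictions, labels):
        rr = 0.0
        for rank, doc in enumerate(ranked, start=1):
            if rel[doc]:
                rr = 1.0 / rank
                break
        scores.append(rr)
    return float(np.mean(scores))


def precision_at_k(predictions, labels, k):
    # Fraction of relevant documents among the top-k ranked ones.
    scores = [float(np.mean([rel[doc] for doc in ranked[:k]]))
              for ranked, rel in zip(predictions, labels)]
    return float(np.mean(scores))
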
Example #4
def validate_official(args, data_loader, model, global_stats=None):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.
    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = Timer()
    # Run through examples
    examples = 0
    map = AverageMeter()
    mrr = AverageMeter()
    prec_1 = AverageMeter()
    prec_3 = AverageMeter()
    prec_5 = AverageMeter()
    with torch.no_grad():
        pbar = tqdm(data_loader)
        for ex in pbar:
            ids, batch_size = ex['ids'], ex['batch_size']
            scores = model.predict(ex)
            predictions = np.argsort(
                -scores.cpu().numpy())  # sort in descending order
            labels = ex['label'].numpy()

            map.update(MAP(predictions, labels))
            mrr.update(MRR(predictions, labels))
            prec_1.update(precision_at_k(predictions, labels, 1))
            prec_3.update(precision_at_k(predictions, labels, 3))
            prec_5.update(precision_at_k(predictions, labels, 5))

            if global_stats is None:
                pbar.set_description('[testing ... ]')
            else:
                pbar.set_description('Epoch = %d [validating... ]' %
                                     global_stats['epoch'])

            examples += batch_size

    result = dict()
    result['map'] = map.avg
    result['mrr'] = mrr.avg
    result['prec@1'] = prec_1.avg
    result['prec@3'] = prec_3.avg
    result['prec@5'] = prec_5.avg

    if global_stats is None:
        logger.info(
            'test results: MAP = %.2f | MRR = %.2f | Prec@1 = %.2f | ' %
            (result['map'], result['mrr'], result['prec@1']) +
            'Prec@3 = %.2f | Prec@5 = %.2f | examples = %d | ' %
            (result['prec@3'], result['prec@5'], examples) +
            'time elapsed = %.2f (s)' % (eval_time.time()))
    else:
        logger.info('valid official: Epoch = %d | MAP = %.2f | ' %
                    (global_stats['epoch'], result['map']) +
                    'MRR = %.2f | Prec@1 = %.2f | Prec@3 = %.2f | ' %
                    (result['mrr'], result['prec@1'], result['prec@3']) +
                    'Prec@5 = %.2f | examples = %d | valid time = %.2f (s)' %
                    (result['prec@5'], examples, eval_time.time()))

    return result
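

# A typical driver loop these functions are written for. The loader names,
# `args.num_epochs`, and `model.save` are illustrative placeholders, not the
# repository's actual API:
def run_training(args, model, train_loader, dev_loader):
    stats = {'epoch': 0, 'best_valid': 0.0}
    for epoch in range(args.num_epochs):
        stats['epoch'] = epoch
        train(args, train_loader, model, stats)
        result = validate_official(args, dev_loader, model, global_stats=stats)
        if result['map'] > stats['best_valid']:
            stats['best_valid'] = result['map']
            model.save(args.model_file)  # placeholder for the best-model save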