Example 1
    # requires: import math; from fairseq import metrics, utils
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        match = sum(log.get('match', 0) for log in logging_outputs)
        total = sum(log.get('total', 0) for log in logging_outputs)

        metrics.log_scalar('match', value=match, round=0)
        metrics.log_scalar('total', value=total, round=0)
        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        # sample_size == ntokens means the loss is already per-token; otherwise
        # (e.g. sentence-level averaging) also log a per-token nll_loss and
        # derive ppl from it
        if sample_size != ntokens:
            metrics.log_scalar('nll_loss',
                               loss_sum / ntokens / math.log(2),
                               ntokens,
                               round=3)
            metrics.log_derived(
                'ppl',
                lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
        else:
            metrics.log_derived(
                'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
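Since 'match' and 'total' are logged as raw sums, a derived accuracy metric could be computed from the aggregated meters with the same log_derived pattern. A minimal sketch (the 'accuracy' key and its placement are assumptions, not part of the example above):

metrics.log_derived(
    'accuracy',
    # percent accuracy from the summed counts; guard against empty batches
    lambda meters: round(meters['match'].sum * 100.0 / meters['total'].sum, 3)
    if meters['total'].sum > 0 else float('nan'))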
Example 2
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get('loss', 0) for log in logging_outputs))
        nll_loss_sum = utils.item(
            sum(log.get('nll_loss', 0) for log in logging_outputs))
        alignment_loss_sum = utils.item(
            sum(log.get('alignment_loss', 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get('ntokens', 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar('nll_loss',
                           nll_loss_sum / ntokens / math.log(2),
                           ntokens,
                           round=3)
        metrics.log_scalar('alignment_loss',
                           alignment_loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
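Here 'ppl' is derived from the nll_loss meter's running average, which log_scalar weights by ntokens. A minimal sketch of that weighted aggregation, assuming fairseq's metrics.aggregate() context manager and two data-parallel workers:

from fairseq import metrics

with metrics.aggregate() as agg:
    metrics.log_scalar('nll_loss', 2.0, weight=100)  # worker A: 100 tokens
    metrics.log_scalar('nll_loss', 4.0, weight=300)  # worker B: 300 tokens
    # token-weighted mean: (2.0 * 100 + 4.0 * 300) / 400 = 3.5
    print(agg.get_smoothed_value('nll_loss'))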
Example 3
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        # original base-2 version:
        # metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('loss',
                           loss_sum / sample_size,
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
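The commented-out variant divides by math.log(2) because utils.get_perplexity(loss) computes 2 ** loss, so a loss in bits yields the usual natural-base perplexity: 2 ** (x / ln 2) == e ** x. A quick standalone check:

import math

loss_nats = 2.5                      # cross-entropy in nats
loss_bits = loss_nats / math.log(2)  # nats -> bits
assert abs(2 ** loss_bits - math.exp(loss_nats)) < 1e-9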
Example 4
# requires: from fairseq import checkpoint_utils, utils
def get_valid_stats(args, trainer, stats):
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
        key = 'best_{0}'.format(args['checkpoint']['best_checkpoint_metric'])
        best_function = max if args['checkpoint']['maximize_best_checkpoint_metric'] else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[args['checkpoint']['best_checkpoint_metric']],
        )
    return stats
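The hasattr check works because fairseq keeps the running best value as an attribute on the save_checkpoint function itself, set only after the first checkpoint is saved. A standalone illustration of the pattern (placeholder names, not fairseq code):

def save_checkpoint():
    pass  # stand-in for the real checkpoint routine

save_checkpoint.best = 4.2    # recorded after the first save
best_function = min           # minimizing, e.g. validation loss
print(best_function(save_checkpoint.best, 3.9))  # -> 3.9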
Example 5
def get_training_stats(stats):
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0)
    return stats
Example 6
def get_valid_stats(args, trainer, stats):
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()
    return stats
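All six examples funnel into utils.get_perplexity, which maps an averaged base-2 loss to a perplexity value. Conceptually it behaves like this sketch (an approximation for illustration, not fairseq's exact source):

def get_perplexity(loss, round_digits=2, base=2):
    # perplexity = base ** loss, guarded against overflow for huge losses
    if loss is None:
        return 0.0
    try:
        return round(base ** loss, round_digits)
    except OverflowError:
        return float('inf')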