Ejemplo n.º 1
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training.

        Tasks that still override the deprecated ``aggregate_logging_outputs``
        are routed through it. A deprecation warning is emitted, and every
        key/value pair it returns is logged as a scalar.

        Otherwise the summed ``ntokens`` is logged as ``wpb`` (scalar) and
        ``wps`` (speed), and the summed ``nsentences`` as ``bsz``. A warning is
        issued whenever a key is absent from every logging output. Finally,
        criterion-specific reduction is delegated to the criterion's class.
        """
        # Legacy hook detection: compare the function bound on this instance
        # against the base-class implementation.
        default_impl = FairseqTask.aggregate_logging_outputs
        bound_impl = getattr(self, "aggregate_logging_outputs").__func__
        if bound_impl is not default_impl:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API.")
            legacy_values = self.aggregate_logging_outputs(
                logging_outputs, criterion)
            for metric_name, metric_value in legacy_values.items():
                metrics.log_scalar(metric_name, metric_value)
            return

        has_token_counts = any("ntokens" in out for out in logging_outputs)
        if has_token_counts:
            token_total = sum(out.get("ntokens", 0) for out in logging_outputs)
            metrics.log_scalar("wpb", token_total, priority=180, round=1)
            metrics.log_speed("wps", token_total, priority=90, round=1)
        else:
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )

        has_sentence_counts = any(
            "nsentences" in out for out in logging_outputs)
        if has_sentence_counts:
            sentence_total = sum(
                out.get("nsentences", 0) for out in logging_outputs)
            metrics.log_scalar("bsz", sentence_total, priority=190, round=1)
        else:
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )

        # Criterion-level metrics are reduced by the criterion class itself.
        criterion.__class__.reduce_metrics(logging_outputs)
Ejemplo n.º 2
0
 def aggregate_logging_outputs(
         logging_outputs: List[Dict[str, Any]],
 ) -> Dict[str, Any]:
     """[deprecated] Aggregate logging outputs from data parallel training.

     This is a stub: it always emits a deprecation warning pointing at the
     ``reduce_metrics`` API and then raises ``NotImplementedError``.
     Implementations are expected to override ``reduce_metrics`` instead.
     """
     notice = ('The aggregate_logging_outputs API is deprecated. '
               'Please use the reduce_metrics API instead.')
     utils.deprecation_warning(notice)
     raise NotImplementedError
Ejemplo n.º 3
0
 def aggregate_logging_outputs(self, logging_outputs, criterion):
     """[deprecated] Aggregate logging outputs from data parallel training.

     Emits a deprecation warning, then runs ``self.reduce_metrics`` inside a
     fresh ``metrics.aggregate()`` context. Returns that aggregator's
     smoothed values.
     """
     deprecation_notice = (
         "The aggregate_logging_outputs API is deprecated. "
         "Please use the reduce_metrics API instead."
     )
     utils.deprecation_warning(deprecation_notice)
     # Read the smoothed values while the aggregation context is still open.
     with metrics.aggregate() as aggregator:
         self.reduce_metrics(logging_outputs, criterion)
         return aggregator.get_smoothed_values()
Ejemplo n.º 4
0
    def __init__(self, args, task):
        """Initialize a legacy criterion from an ``args`` object.

        Forwards ``task`` to the parent initializer and stores ``args`` on the
        instance. It then emits a deprecation warning asking users to extend
        FairseqCriterion instead.
        """
        super().__init__(task=task)
        self.args = args

        migration_hint = (
            'Criterions should take explicit arguments instead of an '
            'argparse.Namespace object, please update your criterion by '
            'extending FairseqCriterion instead of LegacyFairseqCriterion.'
        )
        utils.deprecation_warning(migration_hint)
Ejemplo n.º 5
0
 def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
     """Aggregate logging outputs from data parallel training (legacy path).

     Emits a deprecation warning and delegates to
     ``cls.aggregate_logging_outputs``. Every returned value is then logged
     as a scalar, except ``nsentences``, ``ntokens`` and ``sample_size``,
     which are skipped.
     """
     utils.deprecation_warning(
         'Criterions should implement the reduce_metrics API. '
         'Falling back to deprecated aggregate_logging_outputs API.')
     # Counters that are not logged as scalars on this path.
     excluded_keys = {'nsentences', 'ntokens', 'sample_size'}
     aggregated = cls.aggregate_logging_outputs(logging_outputs)
     for key, value in aggregated.items():
         if key not in excluded_keys:
             metrics.log_scalar(key, value)
Ejemplo n.º 6
0
def build_scorer(args, tgt_dict):
    """Build the scorer selected by ``args.scoring``.

    The deprecated ``sacrebleu`` flag is still honoured. When it is truthy, a
    deprecation warning is emitted and ``args.scoring`` is overwritten with
    ``"sacrebleu"``, which mutates ``args``.

    Args:
        args: configuration object exposing ``scoring`` and, optionally, the
            legacy ``sacrebleu`` attribute. A missing ``sacrebleu`` attribute
            is treated as "not set".
        tgt_dict: target dictionary providing ``pad()``, ``eos()`` and
            ``unk()`` indices (used only for the ``"bleu"`` scorer).

    Returns:
        A ``bleu.Scorer`` when ``args.scoring == "bleu"``; otherwise whatever
        ``_build_scoring(args)`` returns.
    """
    from tools import utils

    # ``sacrebleu`` is a deprecated flag that newer configs may not define.
    # Reading it with getattr avoids an AttributeError in that case.
    if getattr(args, "sacrebleu", False):
        utils.deprecation_warning(
            "--sacrebleu is deprecated. Please use --scoring sacrebleu instead."
        )
        args.scoring = "sacrebleu"
    if args.scoring == "bleu":
        from scoring import bleu
        return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    return _build_scoring(args)
Ejemplo n.º 7
0
    def get_meter(self, name):
        """[deprecated] Get a specific meter by name.

        The first call on an instance records ``'get_meter'`` in
        ``self._warn_once`` and emits a deprecation warning. Legacy names are
        then resolved as follows:

        * ``train_loss``: the ``loss`` meter of the ``train`` group, if present.
        * ``train_nll_loss``, ``valid_loss``, ``valid_nll_loss``: the matching
          meter, or a new ``AverageMeter`` if it is missing.
        * ``wall`` (``default`` group) and ``wps`` (``train`` group): the
          meter, or a new ``TimeMeter`` if it is missing.
        * ``oom``: always a new ``AverageMeter``.
        * any other name present in the ``train`` group: that meter.

        Returns None when nothing matches.
        """
        from loggings import meters

        warn_key = 'get_meter'
        if warn_key not in self._warn_once:
            self._warn_once.add(warn_key)
            utils.deprecation_warning(
                'Trainer.get_meter is deprecated. Please use fairseq.metrics instead.'
            )

        found = metrics.get_meters("train")
        train_meters = {} if found is None else found

        if name == "train_loss" and "loss" in train_meters:
            return train_meters["loss"]
        if name == "train_nll_loss":
            # Legacy train.py expects this meter to always exist.
            return train_meters.get("nll_loss", None) or meters.AverageMeter()
        if name == "wall":
            # Legacy train.py expects this meter to always exist.
            return metrics.get_meter("default", "wall") or meters.TimeMeter()
        if name == "wps":
            return metrics.get_meter("train", "wps") or meters.TimeMeter()
        if name in {"valid_loss", "valid_nll_loss"}:
            # Legacy train.py expects these meters to always exist; strip the
            # "valid_" prefix to get the key inside the "valid" group.
            valid_key = name[len("valid_"):]
            return metrics.get_meter("valid", valid_key) or meters.AverageMeter()
        if name == "oom":
            return meters.AverageMeter()
        if name in train_meters:
            return train_meters[name]
        return None