def reduce_metrics(self, logging_outputs, criterion):
    """Aggregate logging outputs from data parallel training."""
    # backward compatibility for tasks that override aggregate_logging_outputs
    base_func = FairseqTask.aggregate_logging_outputs
    self_func = getattr(self, "aggregate_logging_outputs").__func__
    if self_func is not base_func:
        utils.deprecation_warning(
            "Tasks should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = self.aggregate_logging_outputs(logging_outputs, criterion)
        for k, v in agg_logging_outputs.items():
            metrics.log_scalar(k, v)
        return

    if not any("ntokens" in log for log in logging_outputs):
        warnings.warn(
            "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
        )
    else:
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        metrics.log_scalar("wpb", ntokens, priority=180, round=1)
        metrics.log_speed("wps", ntokens, priority=90, round=1)

    if not any("nsentences" in log for log in logging_outputs):
        warnings.warn(
            "nsentences not found in Criterion logging outputs, cannot log bsz"
        )
    else:
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        metrics.log_scalar("bsz", nsentences, priority=190, round=1)

    criterion.__class__.reduce_metrics(logging_outputs)
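# Usage sketch (hypothetical, not in the original source): a task subclass
# adopting the reduce_metrics API. "constraint_violations" is an assumed
# key emitted by the criterion's logging outputs.
class ExampleTask(FairseqTask):
    def reduce_metrics(self, logging_outputs, criterion):
        # keep the standard wpb/wps/bsz bookkeeping from the base class
        super().reduce_metrics(logging_outputs, criterion)
        # fold a custom counter across data parallel workers into one scalar
        violations = sum(log.get("constraint_violations", 0) for log in logging_outputs)
        metrics.log_scalar("constraint_violations", violations, priority=200)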
@staticmethod
def aggregate_logging_outputs(
    logging_outputs: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Aggregate logging outputs from data parallel training."""
    utils.deprecation_warning(
        "The aggregate_logging_outputs API is deprecated. "
        "Please use the reduce_metrics API instead."
    )
    raise NotImplementedError
def aggregate_logging_outputs(self, logging_outputs, criterion):
    """[deprecated] Aggregate logging outputs from data parallel training."""
    utils.deprecation_warning(
        "The aggregate_logging_outputs API is deprecated. "
        "Please use the reduce_metrics API instead."
    )
    with metrics.aggregate() as agg:
        self.reduce_metrics(logging_outputs, criterion)
        return agg.get_smoothed_values()
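# Sketch of the aggregation pattern the shim above relies on (assumes the
# same metrics module): scalars logged inside the aggregate() context are
# captured by `agg` and can be read back as a plain dict afterwards.
with metrics.aggregate() as agg:
    metrics.log_scalar("loss", 2.5)
    metrics.log_scalar("ntokens", 4096)
smoothed = agg.get_smoothed_values()  # e.g. {"loss": 2.5, "ntokens": 4096}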
def __init__(self, args, task):
    super().__init__(task=task)
    self.args = args
    utils.deprecation_warning(
        "Criterions should take explicit arguments instead of an "
        "argparse.Namespace object, please update your criterion by "
        "extending FairseqCriterion instead of LegacyFairseqCriterion."
    )
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
    """Aggregate logging outputs from data parallel training."""
    utils.deprecation_warning(
        "Criterions should implement the reduce_metrics API. "
        "Falling back to deprecated aggregate_logging_outputs API."
    )
    agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
    for k, v in agg_logging_outputs.items():
        if k in {"nsentences", "ntokens", "sample_size"}:
            continue
        metrics.log_scalar(k, v)
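# Hypothetical criterion implementing the preferred reduce_metrics API
# directly. The "loss" / "sample_size" keys mirror the ones special-cased
# above; dividing by math.log(2) reports the loss in bits rather than nats.
import math

class ExampleCriterion(FairseqCriterion):
    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar(
            "loss", loss_sum / max(sample_size, 1) / math.log(2), sample_size, round=3
        )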
def build_scorer(args, tgt_dict):
    from tools import utils

    if args.sacrebleu:
        # --sacrebleu predates the generic --scoring flag; map it forward
        utils.deprecation_warning(
            "--sacrebleu is deprecated. Please use --scoring sacrebleu instead."
        )
        args.scoring = "sacrebleu"
    if args.scoring == "bleu":
        from scoring import bleu

        return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    else:
        return _build_scoring(args)
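# Usage sketch (hypothetical argparse wiring; tgt_dict is assumed to be a
# loaded target-side Dictionary exposing pad()/eos()/unk() accessors).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--scoring", default="bleu")
parser.add_argument("--sacrebleu", action="store_true")
args = parser.parse_args([])           # defaults: plain BLEU scoring
scorer = build_scorer(args, tgt_dict)  # returns bleu.Scorer here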
def get_meter(self, name):
    """[deprecated] Get a specific meter by name."""
    from loggings import meters

    if "get_meter" not in self._warn_once:
        self._warn_once.add("get_meter")
        utils.deprecation_warning(
            "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
        )

    train_meters = metrics.get_meters("train")
    if train_meters is None:
        train_meters = {}

    if name == "train_loss" and "loss" in train_meters:
        return train_meters["loss"]
    elif name == "train_nll_loss":
        # support for legacy train.py, which assumed this meter is
        # always initialized
        m = train_meters.get("nll_loss", None)
        return m or meters.AverageMeter()
    elif name == "wall":
        # support for legacy train.py, which assumed this meter is
        # always initialized
        m = metrics.get_meter("default", "wall")
        return m or meters.TimeMeter()
    elif name == "wps":
        m = metrics.get_meter("train", "wps")
        return m or meters.TimeMeter()
    elif name in {"valid_loss", "valid_nll_loss"}:
        # support for legacy train.py, which assumed these meters
        # are always initialized
        k = name[len("valid_"):]
        m = metrics.get_meter("valid", k)
        return m or meters.AverageMeter()
    elif name == "oom":
        return meters.AverageMeter()
    elif name in train_meters:
        return train_meters[name]
    return None
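# Migration sketch (hypothetical trainer instance; assumes the metrics
# module also exposes get_smoothed_values, mirroring get_meters above).
def read_train_loss(trainer):
    # deprecated path, kept for comparison
    meter = trainer.get_meter("train_loss")
    legacy = meter.avg if meter is not None else None
    # preferred path: query the aggregated values directly
    smoothed = metrics.get_smoothed_values("train")
    return smoothed.get("loss", legacy)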