import math

import torch

from fairseq import utils
from fairseq.logging import metrics


def reduce_metrics(logging_outputs) -> None:
    """Aggregate logging outputs from data parallel training."""
    loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
    ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
    sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

    # we divide by log(2) to convert the loss from base e to base 2
    metrics.log_scalar(
        "loss", loss_sum / sample_size / math.log(2), sample_size, round=7
    )
    if sample_size != ntokens:
        # loss is normalized by sample_size, so log a separate per-token
        # nll_loss and derive perplexity from that instead
        metrics.log_scalar(
            "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=7
        )
        metrics.log_derived(
            "ppl",
            lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4),
        )
    else:
        metrics.log_derived(
            "ppl",
            lambda meters: utils.get_perplexity(meters["loss"].avg, round=4),
        )
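# Why the divisions by math.log(2) above: the criterion sums NLL in nats
# (natural log), but the logged loss is reported in bits so that perplexity
# is simply 2 ** loss. A minimal standalone check (made-up numbers, no
# fairseq dependency):
loss_nats = 4 * math.log(2)                    # summed NLL over 4 tokens, in nats
ntokens = 4
loss_bits = loss_nats / ntokens / math.log(2)  # per-token loss in bits -> 1.0
ppl = 2 ** loss_bits                           # perplexity -> 2.0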
def reduce_metrics(cls, logging_outputs) -> None:
    """Aggregate logging outputs from data parallel training."""
    loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
    nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
    ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
    sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

    # we divide by log(2) to convert the loss from base e to base 2
    metrics.log_scalar(
        "loss", loss_sum / sample_size / math.log(2), sample_size, round=7
    )
    metrics.log_scalar(
        "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=7
    )
    metrics.log_derived(
        "ppl",
        lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4),
    )
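# A sketch of how this classmethod is typically driven (assumptions: fairseq
# is installed, and this method lives on a criterion class such as fairseq's
# LabelSmoothedCrossEntropyCriterion; the logging outputs below are made up):
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
)

fake_logs = [
    {"loss": 420.0, "nll_loss": 390.0, "ntokens": 100, "sample_size": 100},
    {"loss": 380.0, "nll_loss": 350.0, "ntokens": 120, "sample_size": 120},
]
with metrics.aggregate(new_root=True) as agg:
    LabelSmoothedCrossEntropyCriterion.reduce_metrics(fake_logs)
    print(agg.get_smoothed_values())  # loss, nll_loss, ppl

# In training, each entry comes from one data-parallel worker; summing the
# raw counts before dividing keeps the weighted averages exact.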
def reduce_metrics(self, logging_outputs, criterion):
    super().reduce_metrics(logging_outputs, criterion)

    zero = torch.scalar_tensor(0.0)
    num_char_errors = sum(
        log.get("_num_char_errors", zero) for log in logging_outputs
    )
    num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
    num_word_errors = sum(
        log.get("_num_word_errors", zero) for log in logging_outputs
    )
    num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
    num_pred_chars = sum(
        log.get("_num_pred_chars", zero) for log in logging_outputs
    )
    lm_score_sum = sum(log.get("_lm_score_sum", zero) for log in logging_outputs)
    vocab_seen = (
        sum(log.get("_vocab_seen", zero) for log in logging_outputs)
        .bool()
        .sum()
        .item()
    )
    kaldi_score_sum = sum(
        log.get("_kaldi_score_sum", zero) for log in logging_outputs
    )
    word_lm_sum = sum(log.get("_word_lm_sum", zero) for log in logging_outputs)

    metrics.log_scalar_sum("_num_char_errors", num_char_errors)
    metrics.log_scalar_sum("_num_chars", num_chars)
    metrics.log_scalar_sum("_num_word_errors", num_word_errors)
    metrics.log_scalar_sum("_num_words", num_words)
    metrics.log_scalar_sum("lm_score_sum", lm_score_sum)
    metrics.log_scalar_sum("num_pred_chars", num_pred_chars)

    if self.cfg.word_kenlm_path is not None:
        metrics.log_scalar_sum("kaldi_score_sum", kaldi_score_sum)
        metrics.log_scalar_sum("word_lm_sum", word_lm_sum)

    if num_chars > 0:
        metrics.log_derived(
            "uer",
            lambda meters: meters["_num_char_errors"].sum
            * 100.0
            / meters["_num_chars"].sum
            if meters["_num_chars"].sum > 0
            else float("nan"),
        )

        if lm_score_sum < 0 and vocab_seen > 0:
            metrics.log_scalar("vocab_seen_pct", vocab_seen / self.num_symbols)

            metrics.log_derived(
                "weighted_lm_ppl",
                lambda meters: math.pow(
                    10,
                    -meters["lm_score_sum"].sum
                    / (
                        meters["num_pred_chars"].sum + meters["nsentences"].sum
                    ),  # account for </s>
                )
                / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power,
            )

            metrics.log_derived(
                "lm_ppl",
                lambda meters: math.pow(
                    10,
                    -meters["lm_score_sum"].sum
                    / (
                        meters["num_pred_chars"].sum + meters["nsentences"].sum
                    ),  # account for </s>
                ),
            )
        else:
            metrics.log_derived("weighted_lm_ppl", lambda meters: float("inf"))

    if num_words > 0:
        if word_lm_sum != 0:
            metrics.log_derived(
                "word_lm_ppl",
                lambda meters: math.pow(
                    10,
                    -meters["word_lm_sum"].sum
                    / (
                        meters["_num_words"].sum + meters["nsentences"].sum
                    ),  # account for </s>
                ),
            )
            metrics.log_derived(
                "weighted_word_lm_ppl",
                lambda meters: math.pow(
                    10,
                    -meters["word_lm_sum"].sum
                    / (
                        meters["_num_words"].sum + meters["nsentences"].sum
                    ),  # account for </s>
                )
                / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power,
            )

        if self.cfg.word_kenlm_path is not None:
            metrics.log_derived(
                "kaldi_score",
                lambda meters: meters["kaldi_score_sum"].sum
                / meters["nsentences"].sum,
            )
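# Where counts like _num_char_errors come from upstream (a sketch; it assumes
# the per-batch validation step scores hypotheses against references with
# character-level edit distance, using the `editdistance` PyPI package for
# illustration):
import editdistance

ref = "the cat sat"
hyp = "the cat sad"
num_char_errors = editdistance.eval(hyp, ref)  # character-level edits -> 1
num_chars = len(ref)                           # 11
uer = num_char_errors * 100.0 / num_chars      # matches the derived "uer" above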