Пример #1
0
    def reduce_metrics(logging_outputs) -> None:
        """Combine per-worker logging outputs and record loss and perplexity.

        The "loss", "ntokens" and "sample_size" entries of every dict in
        ``logging_outputs`` are summed (a missing entry counts as 0). The
        summed loss is divided by the sample size and by ``log(2)`` and
        logged as "loss". If the sample size differs from the token count,
        a per-token "nll_loss" is logged as well and "ppl" is derived from
        it; otherwise "ppl" is derived from "loss".
        """

        def _total(key):
            # Sum one field over all logging dicts, treating absences as 0.
            return sum(entry.get(key, 0) for entry in logging_outputs)

        total_loss = _total("loss")
        token_count = _total("ntokens")
        n_samples = _total("sample_size")

        metrics.log_scalar(
            "loss", total_loss / n_samples / math.log(2), n_samples, round=7
        )

        # Pick the meter that perplexity is computed from.
        ppl_source = "loss"
        if n_samples != token_count:
            metrics.log_scalar(
                "nll_loss", total_loss / token_count / math.log(2), token_count, round=7
            )
            ppl_source = "nll_loss"

        metrics.log_derived(
            "ppl",
            lambda meters: utils.get_perplexity(meters[ppl_source].avg, round=4),
        )
Пример #2
0
    def reduce_metrics(cls, logging_outputs) -> None:
        """Combine per-worker logging outputs into loss, nll_loss and ppl.

        Sums the "loss", "nll_loss", "ntokens" and "sample_size" entries of
        every dict in ``logging_outputs`` (a missing entry counts as 0), logs
        the loss normalised by the sample size and the nll_loss normalised by
        the token count, and registers "ppl" as a metric derived from the
        "nll_loss" meter average.
        """

        def _total(key):
            # Sum one field over all logging dicts, treating absences as 0.
            return sum(entry.get(key, 0) for entry in logging_outputs)

        total_loss = _total("loss")
        total_nll = _total("nll_loss")
        token_count = _total("ntokens")
        n_samples = _total("sample_size")

        # Dividing by log(2) converts the loss from base e to base 2.
        loss_base2 = total_loss / n_samples / math.log(2)
        metrics.log_scalar("loss", loss_base2, n_samples, round=7)

        nll_base2 = total_nll / token_count / math.log(2)
        metrics.log_scalar("nll_loss", nll_base2, token_count, round=7)

        def _perplexity(meters):
            return utils.get_perplexity(meters["nll_loss"].avg, round=4)

        metrics.log_derived("ppl", _perplexity)
Пример #3
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Accumulate error counters and LM scores, then register derived metrics.

        After delegating to the parent implementation, this sums the
        underscore-prefixed counters found in ``logging_outputs`` (a missing
        entry counts as a zero scalar tensor), logs them as running sums and
        registers the derived metrics "uer", "weighted_lm_ppl", "lm_ppl",
        "word_lm_ppl", "weighted_word_lm_ppl" and "kaldi_score" depending on
        which counters are non-empty and on ``self.cfg.word_kenlm_path``.
        """
        super().reduce_metrics(logging_outputs, criterion)

        zero = torch.scalar_tensor(0.0)

        def _total(key):
            # Sum one field over all logging dicts; absences add a zero tensor.
            return sum(entry.get(key, zero) for entry in logging_outputs)

        char_errors = _total("_num_char_errors")
        char_count = _total("_num_chars")
        word_errors = _total("_num_word_errors")
        word_count = _total("_num_words")
        pred_char_count = _total("_num_pred_chars")

        lm_score_total = _total("_lm_score_sum")
        # Number of non-zero entries in the summed "_vocab_seen" value.
        vocab_seen = _total("_vocab_seen").bool().sum().item()
        kaldi_score_total = _total("_kaldi_score_sum")
        word_lm_total = _total("_word_lm_sum")

        running_sums = (
            ("_num_char_errors", char_errors),
            ("_num_chars", char_count),
            ("_num_word_errors", word_errors),
            ("_num_words", word_count),
            ("lm_score_sum", lm_score_total),
            ("num_pred_chars", pred_char_count),
        )
        for meter_name, value in running_sums:
            metrics.log_scalar_sum(meter_name, value)

        if self.cfg.word_kenlm_path is not None:
            metrics.log_scalar_sum("kaldi_score_sum", kaldi_score_total)
            metrics.log_scalar_sum("word_lm_sum", word_lm_total)

        def _uer(meters):
            # Character error rate in percent; NaN when no characters were counted.
            if meters["_num_chars"].sum > 0:
                return (
                    meters["_num_char_errors"].sum * 100.0 / meters["_num_chars"].sum
                )
            return float("nan")

        def _char_lm_ppl(meters):
            score = -meters["lm_score_sum"].sum
            # account for </s>
            length = meters["num_pred_chars"].sum + meters["nsentences"].sum
            return math.pow(10, score / length)

        def _weighted_char_lm_ppl(meters):
            return (
                _char_lm_ppl(meters)
                / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power
            )

        def _word_lm_ppl(meters):
            score = -meters["word_lm_sum"].sum
            # account for </s>
            length = meters["_num_words"].sum + meters["nsentences"].sum
            return math.pow(10, score / length)

        def _weighted_word_lm_ppl(meters):
            return (
                _word_lm_ppl(meters)
                / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power
            )

        def _kaldi_score(meters):
            return meters["kaldi_score_sum"].sum / meters["nsentences"].sum

        if char_count > 0:
            metrics.log_derived("uer", _uer)

            if lm_score_total < 0 and vocab_seen > 0:
                metrics.log_scalar("vocab_seen_pct", vocab_seen / self.num_symbols)
                metrics.log_derived("weighted_lm_ppl", _weighted_char_lm_ppl)
                metrics.log_derived("lm_ppl", _char_lm_ppl)
            else:
                metrics.log_derived("weighted_lm_ppl", lambda meters: float("inf"))

        if word_count > 0:
            if word_lm_total != 0:
                metrics.log_derived("word_lm_ppl", _word_lm_ppl)
                metrics.log_derived("weighted_word_lm_ppl", _weighted_word_lm_ppl)

            if self.cfg.word_kenlm_path is not None:
                metrics.log_derived("kaldi_score", _kaldi_score)