Example #1
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get('loss', 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get('ntokens', 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get('nsentences', 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar('ntokens', ntokens)
        metrics.log_scalar('nsentences', nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        correct = sum(log.get("correct_ali", 0) for log in logging_outputs)
        metrics.log_scalar("_correct_ali", correct)

        total_ali = sum(log.get("count_ali", 0) for log in logging_outputs)
        metrics.log_scalar("_total_ali", total_ali)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum * 1.0 / meters["_total"].sum, 5)
                if meters["_total"].sum > 0 else float("nan"),
            )

        if total_ali > 0:
            metrics.log_derived(
                "accuracy_ali",
                lambda meters: safe_round(
                    meters["_correct_ali"].sum * 1.0 / meters["_total_ali"].sum, 5
                ) if meters["_total_ali"].sum > 0 else float("nan"),
            )

        builtin_keys = {
            'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count',
            'correct_ali', 'count_ali'
        }
        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0)
                          for log in logging_outputs) / len(logging_outputs)
                if k.startswith('loss'):
                    metrics.log_scalar(k, val / sample_size / math.log(2),
                                       sample_size)
                else:
                    metrics.log_scalar(k, val, round=3)
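
These reduce_metrics implementations are meant to run inside fairseq's metrics aggregation context. Below is a minimal usage sketch, assuming a fairseq installation; the logging-output values are invented and the calls simply mirror the pattern above: underscore-prefixed scalars stay hidden, and the derived "accuracy" is computed lazily from the meters when the smoothed values are read.

import math

from fairseq.logging import metrics
from fairseq.logging.meters import safe_round

# Invented logging outputs from two data-parallel workers.
logging_outputs = [
    {"loss": 2.0, "sample_size": 10, "correct": 7, "count": 10},
    {"loss": 4.0, "sample_size": 20, "correct": 15, "count": 20},
]

with metrics.aggregate("valid") as agg:
    sample_size = sum(log["sample_size"] for log in logging_outputs)
    loss_sum = sum(log["loss"] for log in logging_outputs)
    correct = sum(log["correct"] for log in logging_outputs)
    total = sum(log["count"] for log in logging_outputs)

    metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
    metrics.log_scalar("_correct", correct)  # underscore-prefixed keys are not reported directly
    metrics.log_scalar("_total", total)
    metrics.log_derived(
        "accuracy",
        lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5)
        if meters["_total"].sum > 0 else float("nan"),
    )

    print(agg.get_smoothed_values())  # includes "loss" and the derived "accuracy"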
Example #2
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
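
The _c_errors / _w_errors / _wv_errors / _w_total scalars consumed here are plain integer counts accumulated in each batch's logging output. A hedged sketch of how such counts can be produced with the editdistance package (the helper name and key set are illustrative, not necessarily the original criterion's code):

import editdistance

def error_counts(hyp: str, ref: str) -> dict:
    """Character- and word-level edit distances plus reference lengths."""
    hyp_words, ref_words = hyp.split(), ref.split()
    return {
        "c_errors": editdistance.eval(list(hyp), list(ref)),
        "c_total": len(ref),
        "w_errors": editdistance.eval(hyp_words, ref_words),
        "w_total": len(ref_words),
    }

# Accumulated per batch, then summed across workers by reduce_metrics above.
logging_output = {"c_errors": 0, "c_total": 0, "w_errors": 0, "w_total": 0}
for hyp, ref in [("a cat sat", "the cat sat"), ("hello word", "hello world")]:
    for k, v in error_counts(hyp, ref).items():
        logging_output[k] += v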
Example #3
def get_perplexity(loss, round=2, base=2):
    if loss is None:
        return 0.0
    try:
        return safe_round(base**loss, round)
    except OverflowError:
        return float("inf")
Example #4
def get_perplexity(loss, round=2, base=2):
    from fairseq.logging.meters import safe_round

    if loss is None:
        return 0.0
    try:
        return safe_round(base**loss, round)
    except OverflowError:
        return float("inf")
Example #5
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))

        metrics.log_scalar("loss",
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5)
                if meters["_total"].sum > 0 else float("nan"),
            )

        builtin_keys = {
            "loss",
            "ntokens",
            "nsentences",
            "sample_size",
            "correct",
            "count",
        }

        handled_keys = reduce_probe_metrics(logging_outputs, metrics)
        builtin_keys.update(handled_keys)

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs)
                if k.startswith("loss"):
                    metrics.log_scalar(k,
                                       val / sample_size / math.log(2),
                                       sample_size,
                                       round=3)
                else:
                    metrics.log_scalar(k, val / len(logging_outputs), round=3)
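
Example #5 hands part of the work to an external reduce_probe_metrics helper and then skips whatever keys that helper reports as handled. The helper is not shown on this page; the following is a purely hypothetical sketch of the expected contract (log some probe-specific scalars, return the set of keys consumed), with made-up key names:

def reduce_probe_metrics(logging_outputs, metrics):
    """Hypothetical helper: aggregate probe-specific keys and report which ones it consumed."""
    handled = set()
    for key in ("probe_loss", "probe_correct", "probe_count"):  # assumed key names
        if key in logging_outputs[0]:
            total = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(key, total / len(logging_outputs), round=3)
            handled.add(key)
    return handled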
Example #6
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ctc_loss = utils.item(
            sum(log.get("ctc_loss", 0) for log in logging_outputs))
        ce_loss = utils.item(
            sum(log.get("cif_loss", 0) for log in logging_outputs))
        qua_loss = utils.item(
            sum(log.get("qua_loss", 0) for log in logging_outputs))
        emb_loss = utils.item(
            sum(log.get("emb_loss", 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs))

        if sample_size > 0:  # training
            metrics.log_scalar("loss",
                               loss_sum / sample_size / math.log(2),
                               sample_size,
                               round=3)
            metrics.log_scalar("ctc_loss",
                               ctc_loss / ntokens / math.log(2),
                               ntokens,
                               round=3)
            metrics.log_scalar("cif_loss",
                               ce_loss / ntokens / math.log(2),
                               ntokens,
                               round=3)
            metrics.log_scalar("qua_loss",
                               qua_loss / nsentences / math.log(2),
                               nsentences,
                               round=3)
            metrics.log_scalar("emb_loss",
                               emb_loss / ntokens / math.log(2),
                               ntokens,
                               round=3)

        else:
            ctc_c_errors = sum(log['ctc'].get("c_errors", 0)
                               for log in logging_outputs)
            metrics.log_scalar("ctc_c_errors", ctc_c_errors)
            ctc_c_total = sum(log['ctc'].get("c_total", 0)
                              for log in logging_outputs)
            metrics.log_scalar("ctc_c_total", ctc_c_total)
            if ctc_c_total > 0:
                metrics.log_derived(
                    "ctc_uer",
                    lambda meters: safe_round(
                        meters["ctc_c_errors"].sum * 100.0 / meters["ctc_c_total"].sum, 3
                    ) if meters["ctc_c_total"].sum > 0 else float("nan"),
                )

            c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_c_errors", c_errors)
            c_total = sum(log.get("c_total", 0) for log in logging_outputs)
            metrics.log_scalar("_c_total", c_total)
            w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_w_errors", w_errors)
            w_total = sum(log.get("w_total", 0) for log in logging_outputs)
            metrics.log_scalar("_w_total", w_total)
            if c_total > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: safe_round(
                        meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                    ) if meters["_c_total"].sum > 0 else float("nan"),
                )
            if w_total > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: safe_round(
                        meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                    ) if meters["_w_total"].sum > 0 else float("nan"),
                )
Example #7
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))

        ctc_loss_sum = utils.item(
            sum(log['ctc'].get('loss', 0) for log in logging_outputs))
        ctc_sample_size = utils.item(
            sum(log['ctc'].get('sample_size', 0) for log in logging_outputs))
        ctc_ntokens = utils.item(
            sum(log['ctc'].get('ntokens', 0) for log in logging_outputs))
        ctc_nsentences = utils.item(
            sum(log['ctc'].get('nsentences', 0) for log in logging_outputs))

        ctras_loss_sum = utils.item(
            sum(log['infonce'].get('loss', 0) for log in logging_outputs))
        ctras_sample_size = utils.item(
            sum(log['infonce'].get('sample_size', 0)
                for log in logging_outputs))
        ctras_ntokens = utils.item(
            sum(log['infonce'].get('ntokens', 0) for log in logging_outputs))
        ctras_nsentences = utils.item(
            sum(log['infonce'].get('nsentences', 0)
                for log in logging_outputs))

        metrics.log_scalar("loss", loss_sum, 1, round=3)
        metrics.log_scalar("contrastive_loss",
                           ctras_loss_sum / ctras_sample_size / math.log(2),
                           ctras_sample_size,
                           round=3)

        if ctc_sample_size == 0:
            metrics.log_scalar("ctc_loss", 0, ctc_sample_size, round=3)
        else:
            metrics.log_scalar("ctc_loss",
                               ctc_loss_sum / ctc_sample_size / math.log(2),
                               ctc_sample_size,
                               round=3)

            if ctc_sample_size != ctc_ntokens:
                metrics.log_scalar("nll_loss",
                                   ctc_loss_sum / ctc_ntokens / math.log(2),
                                   ctc_ntokens,
                                   round=3)
        c_errors = sum(log['ctc'].get("c_errors", 0)
                       for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log['ctc'].get("w_errors", 0)
                       for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log['ctc'].get("wv_errors", 0)
                        for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                ) if meters["_c_total"].sum > 0 else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                ) if meters["_w_total"].sum > 0 else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum,
                    3) if meters["_w_total"].sum > 0 else float("nan"),
            )

        metrics.log_scalar("nsentences", ctras_nsentences)
        metrics.log_scalar("ctc_sample_size", ctc_sample_size)
        metrics.log_scalar("contrastive_sample_size", ctras_sample_size)

        correct = sum(log['infonce'].get("correct", 0)
                      for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log['infonce'].get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5)
                if meters["_total"].sum > 0 else float("nan"),
            )

        builtin_keys = {
            'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'
        }
        for k in logging_outputs[0]['infonce']:
            if k not in builtin_keys:
                val = sum(log['infonce'].get(k, 0)
                          for log in logging_outputs) / len(logging_outputs)
                if k.startswith('loss'):
                    metrics.log_scalar(k,
                                       val / ctras_sample_size / math.log(2),
                                       ctras_sample_size)
                else:
                    metrics.log_scalar(k, val, round=3)
Example #8
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        def _get_mode(logging_outputs):
            mds = [
                SpeechTextPreTrainCompoundCriterion.value2mode(log["mode"])
                for log in logging_outputs
            ]
            if sum([1 if l != mds[0] else 0 for l in mds]) > 0:
                raise ValueError(
                    "mode in one mini-batch is expected to be the same!")
            return mds[0]

        log_mode = _get_mode(logging_outputs)
        if log_mode == "xent":
            return SpeechTextPreTrainCrossEntCriterion.reduce_metrics(
                logging_outputs)

        # ctc loss
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))

        metrics.log_scalar("ctc_loss",
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar("ctc_ntokens", ntokens)
        metrics.log_scalar("ctc_nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar("ctc_nll_loss",
                               loss_sum / ntokens / math.log(2),
                               ntokens,
                               round=3)

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                ) if meters["_c_total"].sum > 0 else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                ) if meters["_w_total"].sum > 0 else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum,
                    3) if meters["_w_total"].sum > 0 else float("nan"),
            )
Example #9
def get_perplexity(loss, round=2, base=2):
    if loss is None:
        return 0.0
    return safe_round(base**loss, round)

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if self.args.eval_bleu:

            def sum_logs(key):
                return sum(log.get(key, 0) for log in logging_outputs)

            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect
                    import sacrebleu

                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = sacrebleu.compute_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth)
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)

            # wer
            w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_w_errors", w_errors)
            wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_wv_errors", wv_errors)
            w_total = sum(log.get("w_total", 0) for log in logging_outputs)
            metrics.log_scalar("_w_total", w_total)

            if w_total > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: safe_round(
                        meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                    ) if meters["_w_total"].sum > 0 else float("nan"),
                )
                metrics.log_derived(
                    "raw_wer",
                    lambda meters: safe_round(
                        meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                    ) if meters["_w_total"].sum > 0 else float("nan"),
                )
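
The _bleu_counts_* and _bleu_totals_* keys summed by sum_logs are normally filled once per validation step from a sacrebleu result object. A hedged sketch of that producing side, mirroring the pattern fairseq's translation task uses (hyps and refs are placeholder lists of strings):

import sacrebleu

EVAL_BLEU_ORDER = 4  # number of n-gram orders tracked

def bleu_logging_output(hyps, refs):
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    out = {"_bleu_sys_len": bleu.sys_len, "_bleu_ref_len": bleu.ref_len}
    # counts/totals hold matched and total n-grams for orders 1..EVAL_BLEU_ORDER
    for i in range(EVAL_BLEU_ORDER):
        out["_bleu_counts_" + str(i)] = bleu.counts[i]
        out["_bleu_totals_" + str(i)] = bleu.totals[i]
    return out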
Example #11
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val
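
This final snippet is a method of a meter object and reads its avg and round attributes. For context, a simplified stand-in for such a meter (not fairseq's actual AverageMeter) could look like this:

from fairseq.logging.meters import safe_round

class SimpleAverageMeter:
    """Running average with optional rounding for display."""

    def __init__(self, round: int = None):
        self.round = round
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else None

    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val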