Exemplo n.º 1
0
    def testEditDistanceSkipsEmptyTokens(self):
        """Runs of extra whitespace must not be counted as edit errors."""
        # Each pair differs only by consecutive spaces; distance must be zero
        # regardless of which side carries the extra whitespace.
        pairs = [
            ("a b c d e   f g h", "a b c d e f g h"),
            ("a b c d e f g h", "a b c d e   f g h"),
        ]
        for ref, hyp in pairs:
            self.assertEqual((0, 0, 0, 0),
                             decoder_utils.EditDistance(ref, hyp))
Exemplo n.º 2
0
    def testEditDistance1(self):
        """Checks EditDistance counts over a table of (ref, hyp) cases."""
        # Each entry is (ref, hyp, (insertions, substitutions, deletions,
        # total errors)).
        cases = [
            ("a b c d e f g h", "a b c d e f g h", (0, 0, 0, 0)),
            ("a b c d e f g h", "a b d e f g h", (0, 0, 1, 1)),
            ("a b c d e f g h", "a b c i d e f g h", (1, 0, 0, 1)),
            ("a b c d e f g h", "a b c i e f g h", (0, 1, 0, 1)),
            ("a b c d e f g j h", "a b c i d e f g h", (1, 0, 1, 2)),
            ("a b c d e f g j h", "a b c i e f g h k", (1, 1, 1, 3)),
            # Degenerate inputs: empty reference and/or hypothesis.
            ("", "", (0, 0, 0, 0)),
            ("", "a b c", (3, 0, 0, 3)),
            ("a b c d", "", (0, 0, 4, 4)),
        ]
        for ref, hyp, expected in cases:
            self.assertEqual(expected, decoder_utils.EditDistance(ref, hyp))
Exemplo n.º 3
0
    def PostProcessDecodeOut(self, dec_out_dict, dec_metrics_dict):
        """Aggregates one batch of decoder outputs into the decode metrics.

        Updates WER, normalized WER, oracle WER, sentence accuracy, token
        error rate, corpus BLEU and the sample counter from the decoded
        hypotheses of a single batch.

        Args:
            dec_out_dict: A dict of per-batch decoder outputs. Expected keys:
                'topk_scores', 'topk_decoded', 'transcripts', 'utt_id',
                'norm_wer_errors', 'norm_wer_words', 'target_labels',
                'target_paddings', 'topk_ids', 'topk_lens'.
            dec_metrics_dict: A dict mapping metric names to metric objects
                exposing an `Update(value, weight)` method.

        Returns:
            A list of key/value pairs (currently always empty).
        """
        p = self.params
        topk_scores = dec_out_dict['topk_scores']
        topk_decoded = dec_out_dict['topk_decoded']
        transcripts = dec_out_dict['transcripts']
        utt_id = dec_out_dict['utt_id']
        norm_wer_errors = dec_out_dict['norm_wer_errors']
        norm_wer_words = dec_out_dict['norm_wer_words']
        target_labels = dec_out_dict['target_labels']
        target_paddings = dec_out_dict['target_paddings']
        topk_ids = dec_out_dict['topk_ids']
        topk_lens = dec_out_dict['topk_lens']
        # Sanity-check that all per-utterance arrays are aligned.
        assert len(transcripts) == len(target_labels)
        assert len(transcripts) == len(target_paddings)
        assert len(transcripts) == len(topk_decoded)
        assert len(utt_id) == len(transcripts)
        assert (len(topk_ids) == p.decoder.beam_search.num_hyps_per_beam *
                len(transcripts))
        assert len(norm_wer_errors) == len(transcripts)
        assert len(norm_wer_words) == len(transcripts)

        dec_metrics_dict['num_samples_in_batch'].Update(len(transcripts))

        def GetRefIds(ref_ids, ref_paddinds):
            """Returns reference ids with padded positions removed."""
            assert len(ref_ids) == len(ref_paddinds)
            return_ids = []
            for i in range(len(ref_ids)):
                if ref_paddinds[i] == 0:
                    return_ids.append(ref_ids[i])
            return return_ids

        total_errs = 0
        total_oracle_errs = 0
        total_ref_words = 0
        total_token_errs = 0
        total_ref_tokens = 0
        total_norm_wer_errs = 0
        total_norm_wer_words = 0
        total_accurate_sentences = 0
        key_value_pairs = []
        for i in range(len(transcripts)):
            ref_str = transcripts[i]
            tf.logging.info('utt_id: %s', utt_id[i])
            tf.logging.info('  ref_str: %s', ref_str)
            hyps = topk_decoded[i]
            ref_ids = GetRefIds(target_labels[i], target_paddings[i])
            # topk_ids is flattened across beams; index of this utterance's
            # top hypothesis.
            hyp_index = i * p.decoder.beam_search.num_hyps_per_beam
            top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
            total_ref_tokens += len(ref_ids)
            _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                ref_ids, top_hyp_ids)
            total_token_errs += token_errs

            assert p.decoder.beam_search.num_hyps_per_beam == len(hyps)
            filtered_ref = decoder_utils.FilterNoise(ref_str)
            filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
            oracle_errs = norm_wer_errors[i][0]
            for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
                tf.logging.info('  %f: %s', score, hyp_str)
                filtered_hyp = decoder_utils.FilterNoise(hyp_str)
                filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
                ins, subs, dels, errs = decoder_utils.EditDistance(
                    filtered_ref, filtered_hyp)
                # Note that these numbers are not consistent with what is used to
                # compute normalized WER.  In particular, these numbers will be inflated
                # when the transcript contains punctuation.
                tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d', ins,
                                subs, dels, errs)
                hyp_norm_wer_errors = norm_wer_errors[i][n]
                hyp_norm_wer_words = norm_wer_words[i][n]
                # Only aggregate scores of the top hypothesis.
                if n == 0:
                    total_errs += errs
                    total_ref_words += len(
                        decoder_utils.Tokenize(filtered_ref))
                    total_norm_wer_errs += hyp_norm_wer_errors
                    if hyp_norm_wer_errors == 0:
                        total_accurate_sentences += 1
                    total_norm_wer_words += hyp_norm_wer_words
                    dec_metrics_dict['corpus_bleu'].Update(
                        filtered_ref, filtered_hyp)
                # Oracle: best normalized WER across all beam hypotheses.
                if hyp_norm_wer_errors < oracle_errs:
                    oracle_errs = hyp_norm_wer_errors
            total_oracle_errs += oracle_errs

        # Guard denominators with max(1., ...) so an empty batch / empty
        # references do not raise ZeroDivisionError; matches the guards used
        # by the other PostProcess implementations.
        dec_metrics_dict['wer'].Update(
            total_errs / max(1., total_ref_words), total_ref_words)
        dec_metrics_dict['oracle_norm_wer'].Update(
            total_oracle_errs / max(1., total_ref_words), total_ref_words)
        dec_metrics_dict['sacc'].Update(
            total_accurate_sentences / max(1, len(transcripts)),
            len(transcripts))
        dec_metrics_dict['norm_wer'].Update(
            total_norm_wer_errs / max(1., total_norm_wer_words),
            total_norm_wer_words)
        dec_metrics_dict['ter'].Update(
            total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

        # Update any additional metrics.
        dec_metrics_dict = self.UpdateAdditionalMetrics(
            dec_out_dict, dec_metrics_dict)
        return key_value_pairs
Exemplo n.º 4
0
    def PostProcess(self, dec_out_dict, dec_metrics_dict):
        """Updates decode metrics from one batch of decoder outputs.

        Always updates the sample counter, normalized WER and corpus BLEU;
        the remaining (auxiliary) metrics are computed only when
        `p.include_auxiliary_metrics` is set.

        Args:
            dec_out_dict: A dict of per-batch decoder outputs. Expected keys:
                'topk_scores', 'topk_decoded', 'transcripts',
                'norm_wer_errors', 'norm_wer_words', 'target_labels',
                'target_paddings', 'topk_ids', 'topk_lens', and 'utt_id'
                when not running on TPU.
            dec_metrics_dict: A dict mapping metric names to metric objects
                exposing an `Update(value, weight)` method.

        Returns:
            A list of key/value pairs (currently always empty).
        """
        p = self.params
        assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
        topk_scores = dec_out_dict['topk_scores']
        topk_decoded = dec_out_dict['topk_decoded']
        transcripts = dec_out_dict['transcripts']
        if not py_utils.use_tpu():
            # utt_id is only emitted off-TPU.
            utt_id = dec_out_dict['utt_id']
            assert len(utt_id) == len(transcripts)
        norm_wer_errors = dec_out_dict['norm_wer_errors']
        norm_wer_words = dec_out_dict['norm_wer_words']
        target_labels = dec_out_dict['target_labels']
        target_paddings = dec_out_dict['target_paddings']
        topk_ids = dec_out_dict['topk_ids']
        topk_lens = dec_out_dict['topk_lens']
        # Sanity-check that all per-utterance arrays are aligned.
        assert len(transcripts) == len(target_labels)
        assert len(transcripts) == len(target_paddings)
        assert len(transcripts) == len(topk_decoded)
        assert len(norm_wer_errors) == len(transcripts)
        assert len(norm_wer_words) == len(transcripts)

        num_samples_in_batch = len(transcripts)
        dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

        def GetRefIds(ref_ids, ref_paddinds):
            """Returns reference ids with padded positions removed."""
            assert len(ref_ids) == len(ref_paddinds)
            return_ids = []
            for i in range(len(ref_ids)):
                if ref_paddinds[i] == 0:
                    return_ids.append(ref_ids[i])
            return return_ids

        # Normalized WER is computed from the top hypothesis (column 0) only.
        total_norm_wer_errs = norm_wer_errors[:, 0].sum()
        total_norm_wer_words = norm_wer_words[:, 0].sum()

        dec_metrics_dict['norm_wer'].Update(
            total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)

        # Corpus BLEU over (filtered reference, filtered top hypothesis).
        for ref_str, hyps in zip(transcripts, topk_decoded):
            filtered_ref = decoder_utils.FilterNoise(ref_str)
            filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
            filtered_hyp = decoder_utils.FilterNoise(hyps[0])
            filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
            dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

        total_errs = 0
        total_oracle_errs = 0
        total_ref_words = 0
        total_token_errs = 0
        total_ref_tokens = 0
        total_accurate_sentences = 0
        key_value_pairs = []

        if p.include_auxiliary_metrics:
            for i in range(len(transcripts)):
                ref_str = transcripts[i]
                if not py_utils.use_tpu():
                    tf.logging.info('utt_id: %s', utt_id[i])
                if self.cluster.add_summary:
                    tf.logging.info(
                        '  ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
                hyps = topk_decoded[i]
                num_hyps_per_beam = len(hyps)
                ref_ids = GetRefIds(target_labels[i], target_paddings[i])
                # topk_ids is flattened across beams; index of this
                # utterance's top hypothesis.
                hyp_index = i * num_hyps_per_beam
                top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
                if self.cluster.add_summary:
                    tf.logging.info('  ref_ids: %s', ref_ids)
                    tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
                total_ref_tokens += len(ref_ids)
                _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                    ref_ids, top_hyp_ids)
                total_token_errs += token_errs

                filtered_ref = decoder_utils.FilterNoise(ref_str)
                filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
                oracle_errs = norm_wer_errors[i][0]
                for n, (score, hyp_str) in enumerate(zip(topk_scores[i],
                                                         hyps)):
                    if self.cluster.add_summary:
                        tf.logging.info(
                            '  %f: %s', score,
                            hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
                    filtered_hyp = decoder_utils.FilterNoise(hyp_str)
                    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
                    ins, subs, dels, errs = decoder_utils.EditDistance(
                        filtered_ref, filtered_hyp)
                    # Note that these numbers are not consistent with what is used to
                    # compute normalized WER.  In particular, these numbers will be
                    # inflated when the transcript contains punctuation.
                    tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d',
                                    ins, subs, dels, errs)
                    # Only aggregate scores of the top hypothesis.
                    if n == 0:
                        total_errs += errs
                        total_ref_words += len(
                            decoder_utils.Tokenize(filtered_ref))
                        if norm_wer_errors[i, n] == 0:
                            total_accurate_sentences += 1
                    # Oracle: best normalized WER over all beam hypotheses.
                    oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
                total_oracle_errs += oracle_errs

            # max(1., ...) guards against empty references / empty batches.
            dec_metrics_dict['wer'].Update(
                total_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['oracle_norm_wer'].Update(
                total_oracle_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['sacc'].Update(
                total_accurate_sentences / len(transcripts), len(transcripts))
            dec_metrics_dict['ter'].Update(
                total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

        return key_value_pairs
Exemplo n.º 5
0
    def PostProcess(self, dec_out_dict, dec_metrics_dict):
        """Updates decode metrics from one batch, with per-example weights.

        Like the unweighted variant, but 'example_weights' (when present)
        scales the normalized-WER accumulation and the sample counter so
        padded/invalid examples can be excluded.

        Args:
            dec_out_dict: A dict of per-batch decoder outputs. Expected keys:
                'topk_scores', 'topk_decoded', 'transcripts',
                'norm_wer_errors', 'norm_wer_words', 'target_labels',
                'target_paddings', 'topk_ids', 'topk_lens', optionally
                'example_weights', and 'utt_id' when not running on TPU.
            dec_metrics_dict: A dict mapping metric names to metric objects
                exposing an `Update(value, weight)` method.

        Returns:
            A list of key/value pairs (currently always empty).
        """
        p = self.params
        assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
        topk_scores = dec_out_dict['topk_scores']
        topk_decoded = dec_out_dict['topk_decoded']
        transcripts = dec_out_dict['transcripts']
        if not py_utils.use_tpu():
            # utt_id is only emitted off-TPU.
            utt_id = dec_out_dict['utt_id']
            assert len(utt_id) == len(transcripts)
        norm_wer_errors = dec_out_dict['norm_wer_errors']
        norm_wer_words = dec_out_dict['norm_wer_words']
        target_labels = dec_out_dict['target_labels']
        target_paddings = dec_out_dict['target_paddings']
        topk_ids = dec_out_dict['topk_ids']
        topk_lens = dec_out_dict['topk_lens']
        if 'example_weights' in dec_out_dict:
            example_weights = dec_out_dict['example_weights']
        else:
            # Default: every example counts with weight 1.
            example_weights = np.ones([len(transcripts)], np.float32)
        # Sanity-check that all per-utterance arrays are aligned.
        assert len(transcripts) == len(target_labels)
        assert len(transcripts) == len(target_paddings)
        assert len(transcripts) == len(topk_decoded)
        assert len(norm_wer_errors) == len(transcripts)
        assert len(norm_wer_words) == len(transcripts)

        num_samples_in_batch = example_weights.sum()
        dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

        def GetRefIds(ref_ids, ref_paddinds):
            """Returns reference ids with padded positions removed."""
            assert len(ref_ids) == len(ref_paddinds)
            return_ids = []
            for i in range(len(ref_ids)):
                if ref_paddinds[i] == 0:
                    return_ids.append(ref_ids[i])
            return return_ids

        # Weighted normalized WER over the top hypothesis (column 0) only.
        total_norm_wer_errs = (norm_wer_errors[:, 0] * example_weights).sum()
        total_norm_wer_words = (norm_wer_words[:, 0] * example_weights).sum()

        dec_metrics_dict['norm_wer'].Update(
            total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)

        # Pre-filter references and top hypotheses once; reused below.
        filtered_transcripts = []
        filtered_top_hyps = []
        for ref_str, hyps in zip(transcripts, topk_decoded):
            filtered_ref = decoder_utils.FilterNoise(ref_str)
            filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
            filtered_transcripts.append(filtered_ref)
            filtered_hyp = decoder_utils.FilterNoise(hyps[0])
            filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
            filtered_top_hyps.append(filtered_hyp)
            dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

        total_errs = 0
        total_oracle_errs = 0
        total_ref_words = 0
        total_token_errs = 0
        total_ref_tokens = 0
        total_accurate_sentences = 0
        key_value_pairs = []

        if p.include_auxiliary_metrics:
            for i in range(len(transcripts)):
                ref_str = transcripts[i]
                if not py_utils.use_tpu():
                    tf.logging.info('utt_id: %s', utt_id[i])
                if self.cluster.add_summary:
                    tf.logging.info(
                        '  ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
                hyps = topk_decoded[i]
                num_hyps_per_beam = len(hyps)
                ref_ids = GetRefIds(target_labels[i], target_paddings[i])
                # topk_ids is flattened across beams; index of this
                # utterance's top hypothesis.
                hyp_index = i * num_hyps_per_beam
                top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
                if self.cluster.add_summary:
                    tf.logging.info('  ref_ids: %s', ref_ids)
                    tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
                total_ref_tokens += len(ref_ids)
                _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                    ref_ids, top_hyp_ids)
                total_token_errs += token_errs

                filtered_ref = filtered_transcripts[i]
                oracle_errs = norm_wer_errors[i][0]
                for n, (score, hyp_str) in enumerate(zip(topk_scores[i],
                                                         hyps)):
                    # Oracle: best normalized WER over all beam hypotheses.
                    oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
                    if self.cluster.add_summary:
                        tf.logging.info(
                            '  %f: %s', score,
                            hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
                    # Only aggregate scores of the top hypothesis.
                    if n != 0:
                        continue
                    filtered_hyp = filtered_top_hyps[i]
                    _, _, _, errs = decoder_utils.EditDistance(
                        filtered_ref, filtered_hyp)
                    total_errs += errs
                    total_ref_words += len(
                        decoder_utils.Tokenize(filtered_ref))
                    if norm_wer_errors[i, n] == 0:
                        total_accurate_sentences += 1

                total_oracle_errs += oracle_errs

            # max(1., ...) guards against empty references / empty batches.
            dec_metrics_dict['wer'].Update(
                total_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['oracle_norm_wer'].Update(
                total_oracle_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['sacc'].Update(
                total_accurate_sentences / len(transcripts), len(transcripts))
            dec_metrics_dict['ter'].Update(
                total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

        return key_value_pairs
Exemplo n.º 6
0
def CalculateMetrics(
    postprocess_inputs: PostProcessInputs,
    dec_metrics_dict: Dict[str, Any],
    add_summary: bool,
    use_tpu: bool,
    log_utf8: bool,
):
    """Calculate and update metrics.

  Args:
    postprocess_inputs: namedtuple of Postprocess input objects/tensors.
    dec_metrics_dict: A dictionary of metric names to metrics.
    add_summary: Whether to add detailed summary logging for processing each
      utterance.
    use_tpu: Whether TPU is used (for decoding).
    log_utf8: DecoderMetrics param. If True, decode reference and hypotheses
      bytes to UTF-8 for logging.
  """
    (transcripts, topk_decoded, filtered_transcripts, filtered_top_hyps,
     topk_scores, utt_id, norm_wer_errors, target_labels, target_paddings,
     topk_ids, topk_lens) = postprocess_inputs

    # Case sensitive WERs.
    total_ins, total_subs, total_dels, total_errs = 0, 0, 0, 0
    # Case insensitive WERs.
    ci_total_ins, ci_total_subs, ci_total_dels, ci_total_errs = 0, 0, 0, 0
    total_oracle_errs = 0
    total_ref_words = 0
    total_token_errs = 0
    total_ref_tokens = 0
    total_accurate_sentences = 0

    for i in range(len(transcripts)):
        ref_str = transcripts[i]
        if not use_tpu:
            # utt_id is only available off-TPU.
            tf.logging.info('utt_id: %s', utt_id[i])
        if add_summary:
            tf.logging.info('  ref_str: %s',
                            ref_str.decode('utf-8') if log_utf8 else ref_str)
        hyps = topk_decoded[i]
        num_hyps_per_beam = len(hyps)
        ref_ids = GetRefIds(target_labels[i], target_paddings[i])
        # topk_ids is flattened across beams; index of this utterance's
        # top hypothesis.
        hyp_index = i * num_hyps_per_beam
        top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
        if add_summary:
            tf.logging.info('  ref_ids: %s', ref_ids)
            tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
        # Token error rate is computed on the top hypothesis only.
        total_ref_tokens += len(ref_ids)
        _, _, _, token_errs = decoder_utils.EditDistanceInIds(
            ref_ids, top_hyp_ids)
        total_token_errs += token_errs

        filtered_ref = filtered_transcripts[i]
        oracle_errs = norm_wer_errors[i][0]
        for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
            # Oracle: best normalized WER over all beam hypotheses.
            oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
            if add_summary:
                tf.logging.info(
                    '  %f: %s', score,
                    hyp_str.decode('utf-8') if log_utf8 else hyp_str)
            # Only aggregate scores of the top hypothesis.
            if n != 0:
                continue
            filtered_hyp = filtered_top_hyps[i]
            ins, subs, dels, errs = decoder_utils.EditDistance(
                filtered_ref, filtered_hyp)

            total_ins += ins
            total_subs += subs
            total_dels += dels
            total_errs += errs

            # Calculating case_insensitive WERs
            ci_ins, ci_subs, ci_dels, ci_errs = decoder_utils.EditDistance(
                filtered_ref.lower(), filtered_hyp.lower())

            ci_total_ins += ci_ins
            ci_total_subs += ci_subs
            ci_total_dels += ci_dels
            ci_total_errs += ci_errs

            ref_words = len(decoder_utils.Tokenize(filtered_ref))
            total_ref_words += ref_words
            if norm_wer_errors[i, n] == 0:
                total_accurate_sentences += 1
            # max(1, ...) avoids division by zero for empty references.
            tf.logging.info(
                '  ins: %d, subs: %d, del: %d, total: %d, ref_words: %d, wer: %f',
                ins, subs, dels, errs, ref_words, errs / max(1, ref_words))

            tf.logging.info(
                '  ci_ins: %d, ci_subs: %d, ci_del: %d, ci_total: %d, '
                'ref_words: %d, ci_wer: %f', ci_ins, ci_subs, ci_dels, ci_errs,
                ref_words, ci_errs / max(1, ref_words))

        total_oracle_errs += oracle_errs

    # max(1., ...) guards against an empty batch / empty references.
    non_zero_total_ref_words = max(1., total_ref_words)
    dec_metrics_dict['wer'].Update(total_errs / non_zero_total_ref_words,
                                   total_ref_words)
    dec_metrics_dict['error_rates/ins'].Update(
        total_ins / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['error_rates/sub'].Update(
        total_subs / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['error_rates/del'].Update(
        total_dels / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['error_rates/wer'].Update(
        total_errs / non_zero_total_ref_words, total_ref_words)

    dec_metrics_dict['case_insensitive_error_rates/ins'].Update(
        ci_total_ins / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['case_insensitive_error_rates/sub'].Update(
        ci_total_subs / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['case_insensitive_error_rates/del'].Update(
        ci_total_dels / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['case_insensitive_error_rates/wer'].Update(
        ci_total_errs / non_zero_total_ref_words, total_ref_words)

    dec_metrics_dict['oracle_norm_wer'].Update(
        total_oracle_errs / non_zero_total_ref_words, total_ref_words)
    dec_metrics_dict['sacc'].Update(
        total_accurate_sentences / len(transcripts), len(transcripts))
    dec_metrics_dict['ter'].Update(
        total_token_errs / max(1., total_ref_tokens), total_ref_tokens)