def testEditDistanceSkipsEmptyTokens(self):
  ref = "a b c d e f g   h"
  hyp = "a b c d e f g h"
  self.assertEqual((0, 0, 0, 0), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d e f g h"
  hyp = "a b c d e f g   h"
  self.assertEqual((0, 0, 0, 0), decoder_utils.EditDistance(ref, hyp))
def testEditDistance1(self):
  ref = "a b c d e f g h"
  hyp = "a b c d e f g h"
  self.assertEqual((0, 0, 0, 0), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d e f g h"
  hyp = "a b d e f g h"
  self.assertEqual((0, 0, 1, 1), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d e f g h"
  hyp = "a b c i d e f g h"
  self.assertEqual((1, 0, 0, 1), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d e f g h"
  hyp = "a b c i e f g h"
  self.assertEqual((0, 1, 0, 1), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d e f g j h"
  hyp = "a b c i d e f g h"
  self.assertEqual((1, 0, 1, 2), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d e f g j h"
  hyp = "a b c i e f g h k"
  self.assertEqual((1, 1, 1, 3), decoder_utils.EditDistance(ref, hyp))

  ref = ""
  hyp = ""
  self.assertEqual((0, 0, 0, 0), decoder_utils.EditDistance(ref, hyp))

  ref = ""
  hyp = "a b c"
  self.assertEqual((3, 0, 0, 3), decoder_utils.EditDistance(ref, hyp))

  ref = "a b c d"
  hyp = ""
  self.assertEqual((0, 0, 4, 4), decoder_utils.EditDistance(ref, hyp))
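# The expectations above fully pin down the (ins, subs, dels, total) tuple
# returned by decoder_utils.EditDistance. The helper below is a minimal
# reference sketch of a word-level Levenshtein distance that reproduces those
# totals; it is an illustration written for this text, not the actual
# implementation, and tie-breaking between equal-cost alignments may differ.
# Note that str.split() with no argument drops empty tokens, which matches
# testEditDistanceSkipsEmptyTokens.
def _ReferenceEditDistance(ref, hyp):
  """Returns (ins, subs, dels, total) for a word-level edit distance."""
  ref_tokens = ref.split()
  hyp_tokens = hyp.split()
  # costs[i][j] holds (ins, subs, dels) for aligning ref_tokens[:i] with
  # hyp_tokens[:j]; the first row/column are pure insertions/deletions.
  costs = [[None] * (len(hyp_tokens) + 1) for _ in range(len(ref_tokens) + 1)]
  costs[0] = [(j, 0, 0) for j in range(len(hyp_tokens) + 1)]
  for i in range(1, len(ref_tokens) + 1):
    costs[i][0] = (0, 0, i)
  for i in range(1, len(ref_tokens) + 1):
    for j in range(1, len(hyp_tokens) + 1):
      if ref_tokens[i - 1] == hyp_tokens[j - 1]:
        costs[i][j] = costs[i - 1][j - 1]
        continue
      ins, subs, dels = costs[i][j - 1]
      insert = (ins + 1, subs, dels)      # extra token in hyp
      ins, subs, dels = costs[i - 1][j - 1]
      substitute = (ins, subs + 1, dels)  # token replaced
      ins, subs, dels = costs[i - 1][j]
      delete = (ins, subs, dels + 1)      # token missing from hyp
      costs[i][j] = min(insert, substitute, delete, key=sum)
  ins, subs, dels = costs[-1][-1]
  return ins, subs, dels, ins + subs + dels


# For example, dropping "c" from the hypothesis is a single deletion:
assert _ReferenceEditDistance("a b c d", "a b d") == (0, 0, 1, 1)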
def PostProcessDecodeOut(self, dec_out_dict, dec_metrics_dict):
  p = self.params
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  utt_id = dec_out_dict['utt_id']
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(utt_id) == len(transcripts)
  assert (len(topk_ids) ==
          p.decoder.beam_search.num_hyps_per_beam * len(transcripts))
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)
  dec_metrics_dict['num_samples_in_batch'].Update(len(transcripts))

  def GetRefIds(ref_ids, ref_paddinds):
    assert len(ref_ids) == len(ref_paddinds)
    return_ids = []
    for i in range(len(ref_ids)):
      if ref_paddinds[i] == 0:
        return_ids.append(ref_ids[i])
    return return_ids

  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_norm_wer_errs = 0
  total_norm_wer_words = 0
  total_accurate_sentences = 0
  key_value_pairs = []
  for i in range(len(transcripts)):
    ref_str = transcripts[i]
    tf.logging.info('utt_id: %s', utt_id[i])
    tf.logging.info('  ref_str: %s', ref_str)
    hyps = topk_decoded[i]
    ref_ids = GetRefIds(target_labels[i], target_paddings[i])
    hyp_index = i * p.decoder.beam_search.num_hyps_per_beam
    top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
    total_ref_tokens += len(ref_ids)
    _, _, _, token_errs = decoder_utils.EditDistanceInIds(
        ref_ids, top_hyp_ids)
    total_token_errs += token_errs
    assert p.decoder.beam_search.num_hyps_per_beam == len(hyps)
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    oracle_errs = norm_wer_errors[i][0]
    for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
      tf.logging.info('  %f: %s', score, hyp_str)
      filtered_hyp = decoder_utils.FilterNoise(hyp_str)
      filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
      ins, subs, dels, errs = decoder_utils.EditDistance(
          filtered_ref, filtered_hyp)
      # Note that these numbers are not consistent with what is used to
      # compute normalized WER. In particular, these numbers will be inflated
      # when the transcript contains punctuation.
      tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d', ins, subs,
                      dels, errs)
      hyp_norm_wer_errors = norm_wer_errors[i][n]
      hyp_norm_wer_words = norm_wer_words[i][n]
      # Only aggregate scores of the top hypothesis.
      if n == 0:
        total_errs += errs
        total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
        total_norm_wer_errs += hyp_norm_wer_errors
        if hyp_norm_wer_errors == 0:
          total_accurate_sentences += 1
        total_norm_wer_words += hyp_norm_wer_words
        dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)
      if hyp_norm_wer_errors < oracle_errs:
        oracle_errs = hyp_norm_wer_errors
    total_oracle_errs += oracle_errs

  dec_metrics_dict['wer'].Update(total_errs / total_ref_words,
                                 total_ref_words)
  dec_metrics_dict['oracle_norm_wer'].Update(
      total_oracle_errs / total_ref_words, total_ref_words)
  dec_metrics_dict['sacc'].Update(total_accurate_sentences / len(transcripts),
                                  len(transcripts))
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)
  dec_metrics_dict['ter'].Update(total_token_errs / total_ref_tokens,
                                 total_ref_tokens)
  # Update any additional metrics.
  dec_metrics_dict = self.UpdateAdditionalMetrics(dec_out_dict,
                                                  dec_metrics_dict)
  return key_value_pairs
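# The metric objects stored in dec_metrics_dict are assumed above to implement
# a weighted running mean: Update(value, weight) folds a per-batch ratio and
# its weight into a corpus-level rate. The class below is a hypothetical
# minimal sketch of that contract (not the actual metric implementation); it
# shows why PostProcessDecodeOut passes total_errs / total_ref_words together
# with total_ref_words as the weight.
class _RunningWeightedMean:
  """Accumulates sum(value * weight) / sum(weight) across Update() calls."""

  def __init__(self):
    self._weighted_sum = 0.0
    self._total_weight = 0.0

  def Update(self, value, weight=1.0):
    self._weighted_sum += value * weight
    self._total_weight += weight

  @property
  def value(self):
    return self._weighted_sum / max(1.0, self._total_weight)


# Usage: a batch with 8 errors over 40 reference words followed by a batch
# with 3 errors over 10 reference words gives a corpus WER of
# (8 + 3) / (40 + 10) = 0.22, not the mean of the two per-batch WERs.
wer_metric = _RunningWeightedMean()
wer_metric.Update(8 / 40, 40)
wer_metric.Update(3 / 10, 10)
assert abs(wer_metric.value - 0.22) < 1e-9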
def PostProcess(self, dec_out_dict, dec_metrics_dict):
  p = self.params
  assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  if not py_utils.use_tpu():
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == len(transcripts)
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)
  num_samples_in_batch = len(transcripts)
  dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

  def GetRefIds(ref_ids, ref_paddinds):
    assert len(ref_ids) == len(ref_paddinds)
    return_ids = []
    for i in range(len(ref_ids)):
      if ref_paddinds[i] == 0:
        return_ids.append(ref_ids[i])
    return return_ids

  total_norm_wer_errs = norm_wer_errors[:, 0].sum()
  total_norm_wer_words = norm_wer_words[:, 0].sum()
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)

  for ref_str, hyps in zip(transcripts, topk_decoded):
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    filtered_hyp = decoder_utils.FilterNoise(hyps[0])
    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
    dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0
  key_value_pairs = []

  if p.include_auxiliary_metrics:
    for i in range(len(transcripts)):
      ref_str = transcripts[i]
      if not py_utils.use_tpu():
        tf.logging.info('utt_id: %s', utt_id[i])
      if self.cluster.add_summary:
        tf.logging.info('  ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
      hyps = topk_decoded[i]
      num_hyps_per_beam = len(hyps)
      ref_ids = GetRefIds(target_labels[i], target_paddings[i])
      hyp_index = i * num_hyps_per_beam
      top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
      if self.cluster.add_summary:
        tf.logging.info('  ref_ids: %s', ref_ids)
        tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
      total_ref_tokens += len(ref_ids)
      _, _, _, token_errs = decoder_utils.EditDistanceInIds(
          ref_ids, top_hyp_ids)
      total_token_errs += token_errs
      filtered_ref = decoder_utils.FilterNoise(ref_str)
      filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
      oracle_errs = norm_wer_errors[i][0]
      for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
        if self.cluster.add_summary:
          tf.logging.info('  %f: %s', score,
                          hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
        filtered_hyp = decoder_utils.FilterNoise(hyp_str)
        filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
        ins, subs, dels, errs = decoder_utils.EditDistance(
            filtered_ref, filtered_hyp)
        # Note that these numbers are not consistent with what is used to
        # compute normalized WER. In particular, these numbers will be
        # inflated when the transcript contains punctuation.
        tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d', ins, subs,
                        dels, errs)
        # Only aggregate scores of the top hypothesis.
        if n == 0:
          total_errs += errs
          total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
          if norm_wer_errors[i, n] == 0:
            total_accurate_sentences += 1
        oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
      total_oracle_errs += oracle_errs

    dec_metrics_dict['wer'].Update(total_errs / max(1., total_ref_words),
                                   total_ref_words)
    dec_metrics_dict['oracle_norm_wer'].Update(
        total_oracle_errs / max(1., total_ref_words), total_ref_words)
    dec_metrics_dict['sacc'].Update(
        total_accurate_sentences / len(transcripts), len(transcripts))
    dec_metrics_dict['ter'].Update(
        total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

  return key_value_pairs
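# A toy numeric illustration (hypothetical values) of the oracle bookkeeping
# above: norm_wer accumulates errors of the top hypothesis (column 0), while
# oracle_norm_wer takes the minimum over all beam hypotheses per utterance,
# so it can only be lower or equal.
import numpy as np

toy_norm_wer_errors = np.array([[2, 1, 3],   # utterance 0: second hyp is best
                                [0, 4, 4]])  # utterance 1: top hyp already best
top_hyp_errs = toy_norm_wer_errors[:, 0].sum()           # 2 + 0 = 2
oracle_hyp_errs = toy_norm_wer_errors.min(axis=1).sum()  # 1 + 0 = 1
assert oracle_hyp_errs <= top_hyp_errs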
def PostProcess(self, dec_out_dict, dec_metrics_dict):
  p = self.params
  assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  if not py_utils.use_tpu():
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == len(transcripts)
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  if 'example_weights' in dec_out_dict:
    example_weights = dec_out_dict['example_weights']
  else:
    example_weights = np.ones([len(transcripts)], np.float32)
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)
  num_samples_in_batch = example_weights.sum()
  dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

  def GetRefIds(ref_ids, ref_paddinds):
    assert len(ref_ids) == len(ref_paddinds)
    return_ids = []
    for i in range(len(ref_ids)):
      if ref_paddinds[i] == 0:
        return_ids.append(ref_ids[i])
    return return_ids

  total_norm_wer_errs = (norm_wer_errors[:, 0] * example_weights).sum()
  total_norm_wer_words = (norm_wer_words[:, 0] * example_weights).sum()
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)

  filtered_transcripts = []
  filtered_top_hyps = []
  for ref_str, hyps in zip(transcripts, topk_decoded):
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    filtered_transcripts.append(filtered_ref)
    filtered_hyp = decoder_utils.FilterNoise(hyps[0])
    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
    filtered_top_hyps.append(filtered_hyp)
    dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0
  key_value_pairs = []

  if p.include_auxiliary_metrics:
    for i in range(len(transcripts)):
      ref_str = transcripts[i]
      if not py_utils.use_tpu():
        tf.logging.info('utt_id: %s', utt_id[i])
      if self.cluster.add_summary:
        tf.logging.info('  ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
      hyps = topk_decoded[i]
      num_hyps_per_beam = len(hyps)
      ref_ids = GetRefIds(target_labels[i], target_paddings[i])
      hyp_index = i * num_hyps_per_beam
      top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
      if self.cluster.add_summary:
        tf.logging.info('  ref_ids: %s', ref_ids)
        tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
      total_ref_tokens += len(ref_ids)
      _, _, _, token_errs = decoder_utils.EditDistanceInIds(
          ref_ids, top_hyp_ids)
      total_token_errs += token_errs
      filtered_ref = filtered_transcripts[i]
      oracle_errs = norm_wer_errors[i][0]
      for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
        oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
        if self.cluster.add_summary:
          tf.logging.info('  %f: %s', score,
                          hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
        # Only aggregate scores of the top hypothesis.
        if n != 0:
          continue
        filtered_hyp = filtered_top_hyps[i]
        _, _, _, errs = decoder_utils.EditDistance(filtered_ref, filtered_hyp)
        total_errs += errs
        total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
        if norm_wer_errors[i, n] == 0:
          total_accurate_sentences += 1
      total_oracle_errs += oracle_errs

    dec_metrics_dict['wer'].Update(total_errs / max(1., total_ref_words),
                                   total_ref_words)
    dec_metrics_dict['oracle_norm_wer'].Update(
        total_oracle_errs / max(1., total_ref_words), total_ref_words)
    dec_metrics_dict['sacc'].Update(
        total_accurate_sentences / len(transcripts), len(transcripts))
    dec_metrics_dict['ter'].Update(
        total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

  return key_value_pairs
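# Toy sketch (hypothetical values) of the example_weights path above: examples
# added as padding by the input pipeline presumably arrive with weight 0, so
# they contribute neither to num_samples_in_batch nor to the normalized WER
# numerator and denominator.
import numpy as np

toy_norm_wer_errors = np.array([[1.], [2.], [5.]])
toy_norm_wer_words = np.array([[4.], [4.], [4.]])
toy_example_weights = np.array([1., 1., 0.])  # third example is padding

weighted_errs = (toy_norm_wer_errors[:, 0] * toy_example_weights).sum()   # 3.0
weighted_words = (toy_norm_wer_words[:, 0] * toy_example_weights).sum()   # 8.0
assert weighted_errs / weighted_words == 0.375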
def CalculateMetrics(
    postprocess_inputs: PostProcessInputs,
    dec_metrics_dict: Dict[str, Any],
    add_summary: bool,
    use_tpu: bool,
    log_utf8: bool,
):
  """Calculate and update metrics.

  Args:
    postprocess_inputs: namedtuple of Postprocess input objects/tensors.
    dec_metrics_dict: A dictionary of metric names to metrics.
    add_summary: Whether to add detailed summary logging for processing each
      utterance.
    use_tpu: Whether TPU is used (for decoding).
    log_utf8: DecoderMetrics param. If True, decode reference and hypotheses
      bytes to UTF-8 for logging.
  """
  (transcripts, topk_decoded, filtered_transcripts, filtered_top_hyps,
   topk_scores, utt_id, norm_wer_errors, target_labels, target_paddings,
   topk_ids, topk_lens) = postprocess_inputs

  # Case sensitive WERs.
  total_ins, total_subs, total_dels, total_errs = 0, 0, 0, 0
  # Case insensitive WERs.
  ci_total_ins, ci_total_subs, ci_total_dels, ci_total_errs = 0, 0, 0, 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0

  for i in range(len(transcripts)):
    ref_str = transcripts[i]
    if not use_tpu:
      tf.logging.info('utt_id: %s', utt_id[i])
    if add_summary:
      tf.logging.info('  ref_str: %s',
                      ref_str.decode('utf-8') if log_utf8 else ref_str)
    hyps = topk_decoded[i]
    num_hyps_per_beam = len(hyps)
    ref_ids = GetRefIds(target_labels[i], target_paddings[i])
    hyp_index = i * num_hyps_per_beam
    top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
    if add_summary:
      tf.logging.info('  ref_ids: %s', ref_ids)
      tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
    total_ref_tokens += len(ref_ids)
    _, _, _, token_errs = decoder_utils.EditDistanceInIds(
        ref_ids, top_hyp_ids)
    total_token_errs += token_errs
    filtered_ref = filtered_transcripts[i]
    oracle_errs = norm_wer_errors[i][0]
    for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
      oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
      if add_summary:
        tf.logging.info('  %f: %s', score,
                        hyp_str.decode('utf-8') if log_utf8 else hyp_str)
      # Only aggregate scores of the top hypothesis.
      if n != 0:
        continue
      filtered_hyp = filtered_top_hyps[i]
      ins, subs, dels, errs = decoder_utils.EditDistance(
          filtered_ref, filtered_hyp)
      total_ins += ins
      total_subs += subs
      total_dels += dels
      total_errs += errs
      # Calculating case_insensitive WERs.
      ci_ins, ci_subs, ci_dels, ci_errs = decoder_utils.EditDistance(
          filtered_ref.lower(), filtered_hyp.lower())
      ci_total_ins += ci_ins
      ci_total_subs += ci_subs
      ci_total_dels += ci_dels
      ci_total_errs += ci_errs
      ref_words = len(decoder_utils.Tokenize(filtered_ref))
      total_ref_words += ref_words
      if norm_wer_errors[i, n] == 0:
        total_accurate_sentences += 1
      tf.logging.info(
          '  ins: %d, subs: %d, del: %d, total: %d, ref_words: %d, wer: %f',
          ins, subs, dels, errs, ref_words, errs / max(1, ref_words))
      tf.logging.info(
          '  ci_ins: %d, ci_subs: %d, ci_del: %d, ci_total: %d, '
          'ref_words: %d, ci_wer: %f', ci_ins, ci_subs, ci_dels, ci_errs,
          ref_words, ci_errs / max(1, ref_words))
    total_oracle_errs += oracle_errs

  non_zero_total_ref_words = max(1., total_ref_words)
  dec_metrics_dict['wer'].Update(total_errs / non_zero_total_ref_words,
                                 total_ref_words)
  dec_metrics_dict['error_rates/ins'].Update(
      total_ins / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['error_rates/sub'].Update(
      total_subs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['error_rates/del'].Update(
      total_dels / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['error_rates/wer'].Update(
      total_errs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/ins'].Update(
      ci_total_ins / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/sub'].Update(
      ci_total_subs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/del'].Update(
      ci_total_dels / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['case_insensitive_error_rates/wer'].Update(
      ci_total_errs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['oracle_norm_wer'].Update(
      total_oracle_errs / non_zero_total_ref_words, total_ref_words)
  dec_metrics_dict['sacc'].Update(total_accurate_sentences / len(transcripts),
                                  len(transcripts))
  dec_metrics_dict['ter'].Update(total_token_errs / max(1., total_ref_tokens),
                                 total_ref_tokens)
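# Illustration of the case-insensitive rates computed above, reusing the
# _ReferenceEditDistance sketch defined earlier (an illustrative stand-in,
# not the real decoder_utils.EditDistance): lowercasing both reference and
# hypothesis before the edit distance removes errors that differ only in
# capitalization, while genuine word errors still count.
ref = "Play Hello by Adele"
hyp = "play hello by Adele"
assert _ReferenceEditDistance(ref, hyp) == (0, 2, 0, 2)
assert _ReferenceEditDistance(ref.lower(), hyp.lower()) == (0, 0, 0, 0)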