def testFilterNoise(self):
  """FilterNoise removes <noise> tokens and normalizes the spacing."""
  # A string without noise tokens must pass through unchanged.
  passthrough = "no noise"
  self.assertEqual(passthrough, decoder_utils.FilterNoise(passthrough))
  # Noise markers are stripped wherever they occur (start, middle, end).
  noisy = "<noise> noise tokens are <noise> removed <noise>"
  self.assertEqual("noise tokens are removed", decoder_utils.FilterNoise(noisy))
def PreparePostProcess(self, dec_out_dict, dec_metrics_dict) -> PostProcessInputs:
  """Prepare the objects for PostProcess metrics calculations."""
  assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
  # Drop dummy examples that were added only to pad out the batch.
  if 'is_real' in dec_out_dict:
    self.FilterRealExamples(dec_out_dict)

  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']

  # utt_id is only emitted on non-TPU jobs.
  utt_id = None
  if not py_utils.use_tpu():
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == len(transcripts)

  # Default to uniform weights when the input pipeline supplies none.
  example_weights = (
      dec_out_dict['example_weights'] if 'example_weights' in dec_out_dict
      else np.ones([len(transcripts)], np.float32))

  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)

  dec_metrics_dict['num_samples_in_batch'].Update(example_weights.sum())

  # Weighted normalized WER over the top (column 0) hypothesis of each beam.
  weighted_errs = (norm_wer_errors[:, 0] * example_weights).sum()
  weighted_words = (norm_wer_words[:, 0] * example_weights).sum()
  dec_metrics_dict['norm_wer'].Update(weighted_errs / weighted_words,
                                      weighted_words)

  def _Strip(text):
    # Remove <noise> and epsilon tokens before scoring.
    return decoder_utils.FilterEpsilon(decoder_utils.FilterNoise(text))

  filtered_transcripts = []
  filtered_top_hyps = []
  for ref_str, hyps in zip(transcripts, topk_decoded):
    clean_ref = _Strip(ref_str)
    clean_hyp = _Strip(hyps[0])
    filtered_transcripts.append(clean_ref)
    filtered_top_hyps.append(clean_hyp)
    dec_metrics_dict['corpus_bleu'].Update(clean_ref, clean_hyp)

  return PostProcessInputs(transcripts, topk_decoded, filtered_transcripts,
                           filtered_top_hyps, topk_scores, utt_id,
                           norm_wer_errors, target_labels, target_paddings,
                           topk_ids, topk_lens)
def PostProcessDecodeOut(self, dec_out_dict, dec_metrics_dict):
  """Aggregates ASR quality metrics from one decoded batch.

  Updates 'num_samples_in_batch', 'wer', 'oracle_norm_wer', 'sacc',
  'norm_wer', 'ter' and 'corpus_bleu' in `dec_metrics_dict`, then lets
  `self.UpdateAdditionalMetrics` fold in any task-specific extras.

  Args:
    dec_out_dict: dict of per-batch decoder outputs. Expected keys:
      'topk_scores', 'topk_decoded', 'transcripts', 'utt_id',
      'norm_wer_errors', 'norm_wer_words', 'target_labels',
      'target_paddings', 'topk_ids', 'topk_lens'.
      norm_wer_errors/norm_wer_words are indexed [example][hyp]
      (presumably shape [batch, num_hyps_per_beam] — confirm upstream).
    dec_metrics_dict: dict of metric accumulators supporting
      `.Update(value, weight)`.

  Returns:
    A list of key/value pairs to log with the decode output (currently
    always empty).
  """
  p = self.params
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  utt_id = dec_out_dict['utt_id']
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(utt_id) == len(transcripts)
  assert (len(topk_ids) ==
          p.decoder.beam_search.num_hyps_per_beam * len(transcripts))
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)
  dec_metrics_dict['num_samples_in_batch'].Update(len(transcripts))

  def GetRefIds(ref_ids, ref_paddinds):
    # Keep only the label ids whose padding flag is 0 (real tokens).
    assert len(ref_ids) == len(ref_paddinds)
    return [rid for rid, pad in zip(ref_ids, ref_paddinds) if pad == 0]

  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_norm_wer_errs = 0
  total_norm_wer_words = 0
  total_accurate_sentences = 0
  key_value_pairs = []
  for i in range(len(transcripts)):
    ref_str = transcripts[i]
    tf.logging.info('utt_id: %s', utt_id[i])
    tf.logging.info('  ref_str: %s', ref_str)
    hyps = topk_decoded[i]
    ref_ids = GetRefIds(target_labels[i], target_paddings[i])
    # topk_ids is flat: hypotheses of example i start at i * num_hyps.
    hyp_index = i * p.decoder.beam_search.num_hyps_per_beam
    top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
    total_ref_tokens += len(ref_ids)
    _, _, _, token_errs = decoder_utils.EditDistanceInIds(
        ref_ids, top_hyp_ids)
    total_token_errs += token_errs
    assert p.decoder.beam_search.num_hyps_per_beam == len(hyps)
    # Strip <noise>/epsilon tokens before word-level scoring.
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    oracle_errs = norm_wer_errors[i][0]
    for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
      tf.logging.info('  %f: %s', score, hyp_str)
      filtered_hyp = decoder_utils.FilterNoise(hyp_str)
      filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
      ins, subs, dels, errs = decoder_utils.EditDistance(
          filtered_ref, filtered_hyp)
      # Note that these numbers are not consistent with what is used to
      # compute normalized WER. In particular, these numbers will be inflated
      # when the transcript contains punctuation.
      tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d', ins, subs,
                      dels, errs)
      hyp_norm_wer_errors = norm_wer_errors[i][n]
      hyp_norm_wer_words = norm_wer_words[i][n]
      # Only aggregate scores of the top hypothesis.
      if n == 0:
        total_errs += errs
        total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
        total_norm_wer_errs += hyp_norm_wer_errors
        if hyp_norm_wer_errors == 0:
          total_accurate_sentences += 1
        total_norm_wer_words += hyp_norm_wer_words
        dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)
      # Oracle = lowest normalized-WER error count anywhere in the beam.
      oracle_errs = min(oracle_errs, hyp_norm_wer_errors)
    total_oracle_errs += oracle_errs
  # max(1., ...) guards against ZeroDivisionError on degenerate batches
  # (e.g. all-empty references), matching the other PostProcess variants.
  dec_metrics_dict['wer'].Update(total_errs / max(1., total_ref_words),
                                 total_ref_words)
  dec_metrics_dict['oracle_norm_wer'].Update(
      total_oracle_errs / max(1., total_ref_words), total_ref_words)
  dec_metrics_dict['sacc'].Update(total_accurate_sentences / len(transcripts),
                                  len(transcripts))
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / max(1., total_norm_wer_words),
      total_norm_wer_words)
  dec_metrics_dict['ter'].Update(total_token_errs / max(1., total_ref_tokens),
                                 total_ref_tokens)
  # Update any additional metrics.
  dec_metrics_dict = self.UpdateAdditionalMetrics(dec_out_dict,
                                                  dec_metrics_dict)
  return key_value_pairs
def PostProcess(self, dec_out_dict, dec_metrics_dict):
  """Computes per-batch ASR metrics from decoder outputs.

  Always updates 'num_samples_in_batch', 'norm_wer' and 'corpus_bleu' in
  `dec_metrics_dict`; when p.include_auxiliary_metrics is set, also
  updates 'wer', 'oracle_norm_wer', 'sacc' and 'ter'.

  Args:
    dec_out_dict: dict of decoded outputs for one batch. Expected keys:
      'topk_scores', 'topk_decoded', 'transcripts', 'norm_wer_errors',
      'norm_wer_words', 'target_labels', 'target_paddings', 'topk_ids',
      'topk_lens', plus 'utt_id' when not running on TPU.
      norm_wer_errors/norm_wer_words are indexed [i, n] (presumably
      [batch, num_hyps_per_beam] — confirm against the decode graph).
    dec_metrics_dict: dict of metric accumulators supporting
      `.Update(value, weight)`.

  Returns:
    A list of key/value pairs to log with the decode output (currently
    always empty).
  """
  p = self.params
  assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  # utt_id is only present (and only logged) off-TPU.
  if not py_utils.use_tpu():
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == len(transcripts)
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)
  num_samples_in_batch = len(transcripts)
  dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

  def GetRefIds(ref_ids, ref_paddinds):
    # Keeps only label ids whose padding flag is 0 (real tokens).
    assert len(ref_ids) == len(ref_paddinds)
    return_ids = []
    for i in range(len(ref_ids)):
      if ref_paddinds[i] == 0:
        return_ids.append(ref_ids[i])
    return return_ids

  # Normalized WER over the top (column 0) hypothesis of each beam.
  total_norm_wer_errs = norm_wer_errors[:, 0].sum()
  total_norm_wer_words = norm_wer_words[:, 0].sum()
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)
  # Corpus BLEU over noise/epsilon-filtered reference vs. top hypothesis.
  for ref_str, hyps in zip(transcripts, topk_decoded):
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    filtered_hyp = decoder_utils.FilterNoise(hyps[0])
    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
    dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)
  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0
  key_value_pairs = []
  if p.include_auxiliary_metrics:
    for i in range(len(transcripts)):
      ref_str = transcripts[i]
      if not py_utils.use_tpu():
        tf.logging.info('utt_id: %s', utt_id[i])
      if self.cluster.add_summary:
        tf.logging.info(
            '  ref_str: %s',
            ref_str.decode('utf-8') if p.log_utf8 else ref_str)
      hyps = topk_decoded[i]
      num_hyps_per_beam = len(hyps)
      ref_ids = GetRefIds(target_labels[i], target_paddings[i])
      # topk_ids is flat: hypotheses of example i start at i * num_hyps.
      hyp_index = i * num_hyps_per_beam
      top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
      if self.cluster.add_summary:
        tf.logging.info('  ref_ids: %s', ref_ids)
        tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
      # Token error rate (TER) is computed on ids of the top hypothesis.
      total_ref_tokens += len(ref_ids)
      _, _, _, token_errs = decoder_utils.EditDistanceInIds(
          ref_ids, top_hyp_ids)
      total_token_errs += token_errs
      filtered_ref = decoder_utils.FilterNoise(ref_str)
      filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
      # Oracle starts at the top hypothesis' error count and is lowered by
      # any better hypothesis in the beam.
      oracle_errs = norm_wer_errors[i][0]
      for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
        if self.cluster.add_summary:
          tf.logging.info(
              '  %f: %s', score,
              hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
        filtered_hyp = decoder_utils.FilterNoise(hyp_str)
        filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
        ins, subs, dels, errs = decoder_utils.EditDistance(
            filtered_ref, filtered_hyp)
        # Note that these numbers are not consistent with what is used to
        # compute normalized WER. In particular, these numbers will be
        # inflated when the transcript contains punctuation.
        tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d', ins,
                        subs, dels, errs)
        # Only aggregate scores of the top hypothesis.
        if n == 0:
          total_errs += errs
          total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
          if norm_wer_errors[i, n] == 0:
            total_accurate_sentences += 1
        oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
      total_oracle_errs += oracle_errs
    # max(1., ...) avoids division by zero on degenerate batches.
    dec_metrics_dict['wer'].Update(
        total_errs / max(1., total_ref_words), total_ref_words)
    dec_metrics_dict['oracle_norm_wer'].Update(
        total_oracle_errs / max(1., total_ref_words), total_ref_words)
    dec_metrics_dict['sacc'].Update(
        total_accurate_sentences / len(transcripts), len(transcripts))
    dec_metrics_dict['ter'].Update(
        total_token_errs / max(1., total_ref_tokens), total_ref_tokens)
  return key_value_pairs
def PostProcess(self, dec_out_dict, dec_metrics_dict):
  """Computes per-batch ASR metrics from decoder outputs (weighted variant).

  Like the unweighted PostProcess, but scales the batch-level counts by
  per-example weights ('example_weights', defaulting to all-ones) and
  reuses pre-filtered transcripts/hypotheses for the auxiliary metrics.
  Always updates 'num_samples_in_batch', 'norm_wer' and 'corpus_bleu';
  when p.include_auxiliary_metrics is set, also updates 'wer',
  'oracle_norm_wer', 'sacc' and 'ter'.

  Args:
    dec_out_dict: dict of decoded outputs for one batch. Expected keys:
      'topk_scores', 'topk_decoded', 'transcripts', 'norm_wer_errors',
      'norm_wer_words', 'target_labels', 'target_paddings', 'topk_ids',
      'topk_lens'; optional 'example_weights'; 'utt_id' when not on TPU.
      norm_wer_errors/norm_wer_words are indexed [i, n] (presumably
      [batch, num_hyps_per_beam] — confirm against the decode graph).
    dec_metrics_dict: dict of metric accumulators supporting
      `.Update(value, weight)`.

  Returns:
    A list of key/value pairs to log with the decode output (currently
    always empty).
  """
  p = self.params
  assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  # utt_id is only present (and only logged) off-TPU.
  if not py_utils.use_tpu():
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == len(transcripts)
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  # Default to uniform weights when the input pipeline supplies none.
  if 'example_weights' in dec_out_dict:
    example_weights = dec_out_dict['example_weights']
  else:
    example_weights = np.ones([len(transcripts)], np.float32)
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)
  num_samples_in_batch = example_weights.sum()
  dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

  def GetRefIds(ref_ids, ref_paddinds):
    # Keeps only label ids whose padding flag is 0 (real tokens).
    assert len(ref_ids) == len(ref_paddinds)
    return_ids = []
    for i in range(len(ref_ids)):
      if ref_paddinds[i] == 0:
        return_ids.append(ref_ids[i])
    return return_ids

  # Weighted normalized WER over the top (column 0) hypothesis of each beam.
  total_norm_wer_errs = (norm_wer_errors[:, 0] * example_weights).sum()
  total_norm_wer_words = (norm_wer_words[:, 0] * example_weights).sum()
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)
  # Filter noise/epsilon once here; the auxiliary-metrics loop below
  # reuses these lists instead of re-filtering.
  filtered_transcripts = []
  filtered_top_hyps = []
  for ref_str, hyps in zip(transcripts, topk_decoded):
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    filtered_transcripts.append(filtered_ref)
    filtered_hyp = decoder_utils.FilterNoise(hyps[0])
    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
    filtered_top_hyps.append(filtered_hyp)
    dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)
  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0
  key_value_pairs = []
  if p.include_auxiliary_metrics:
    for i in range(len(transcripts)):
      ref_str = transcripts[i]
      if not py_utils.use_tpu():
        tf.logging.info('utt_id: %s', utt_id[i])
      if self.cluster.add_summary:
        tf.logging.info(
            '  ref_str: %s',
            ref_str.decode('utf-8') if p.log_utf8 else ref_str)
      hyps = topk_decoded[i]
      num_hyps_per_beam = len(hyps)
      ref_ids = GetRefIds(target_labels[i], target_paddings[i])
      # topk_ids is flat: hypotheses of example i start at i * num_hyps.
      hyp_index = i * num_hyps_per_beam
      top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
      if self.cluster.add_summary:
        tf.logging.info('  ref_ids: %s', ref_ids)
        tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
      # Token error rate (TER) is computed on ids of the top hypothesis.
      total_ref_tokens += len(ref_ids)
      _, _, _, token_errs = decoder_utils.EditDistanceInIds(
          ref_ids, top_hyp_ids)
      total_token_errs += token_errs
      filtered_ref = filtered_transcripts[i]
      # Oracle starts at the top hypothesis' error count and is lowered by
      # any better hypothesis in the beam.
      oracle_errs = norm_wer_errors[i][0]
      for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
        oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
        if self.cluster.add_summary:
          tf.logging.info(
              '  %f: %s', score,
              hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
        # Only aggregate scores of the top hypothesis.
        if n != 0:
          continue
        filtered_hyp = filtered_top_hyps[i]
        _, _, _, errs = decoder_utils.EditDistance(filtered_ref, filtered_hyp)
        total_errs += errs
        total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
        if norm_wer_errors[i, n] == 0:
          total_accurate_sentences += 1
      total_oracle_errs += oracle_errs
    # max(1., ...) avoids division by zero on degenerate batches.
    dec_metrics_dict['wer'].Update(
        total_errs / max(1., total_ref_words), total_ref_words)
    dec_metrics_dict['oracle_norm_wer'].Update(
        total_oracle_errs / max(1., total_ref_words), total_ref_words)
    dec_metrics_dict['sacc'].Update(
        total_accurate_sentences / len(transcripts), len(transcripts))
    dec_metrics_dict['ter'].Update(
        total_token_errs / max(1., total_ref_tokens), total_ref_tokens)
  return key_value_pairs