def _ComputeDecoderMetrics(self, decoder_outs, input_batch):
  """Computes metrics on output from decoder.

  Args:
    decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
      decode results.
    input_batch: A `NestedMap` of tensors representing the source, target,
      and other components of the input batch.

  Returns:
    A dict of Tensors containing decoder output and metrics.
  """
  p = self.params
  topk = self._GetTopK(decoder_outs)
  tgt = self._GetTargetForDecoderMetrics(input_batch)
  transcripts = self.input_generator.IdsToStrings(
      tgt.labels,
      tf.cast(
          tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1) - 1.0), tf.int32))

  # Filter out all isolated '<noise>' tokens.
  noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
  filtered_refs = tf.regex_replace(transcripts, noise_pattern, ' ')
  filtered_hyps = tf.regex_replace(topk.decoded, noise_pattern, ' ')
  # Compute translation quality scores for all hyps.
  filtered_refs = tf.tile(
      tf.reshape(filtered_refs, [-1, 1]),
      [1, p.decoder.beam_search.num_hyps_per_beam])
  filtered_hyps = tf.reshape(filtered_hyps, [-1])
  filtered_refs = tf.reshape(filtered_refs, [-1])

  norm_wer_errors, norm_wer_words = self._ComputeNormalizedWER(
      filtered_hyps, filtered_refs)

  ret_dict = {
      'target_ids': tgt.ids,
      'target_labels': tgt.labels,
      'target_weights': tgt.weights,
      'target_paddings': tgt.paddings,
      'transcripts': transcripts,
      'topk_decoded': topk.decoded,
      'topk_ids': topk.ids,
      'topk_lens': topk.lens,
      'topk_scores': topk.scores,
      'norm_wer_errors': norm_wer_errors,
      'norm_wer_words': norm_wer_words,
  }

  if not py_utils.use_tpu():
    ret_dict['utt_id'] = input_batch.sample_ids

  ret_dict.update(
      self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                              filtered_refs, input_batch,
                                              decoder_outs))
  return ret_dict
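# For intuition, a minimal eager-mode sketch of the noise filtering and
# reference tiling above. The transcripts, decoded hyps, and
# num_hyps_per_beam below are made up for illustration; in the method they
# come from the input generator and beam search. tf.strings.regex_replace
# is the TF2 spelling of tf.regex_replace.
import tensorflow as tf

num_hyps_per_beam = 2
transcripts = tf.constant(['<noise> hello world', 'good <noise> morning'])
decoded = tf.constant([['hello world', 'hello word'],
                       ['good morning', 'good mourning']])

noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
filtered_refs = tf.strings.regex_replace(transcripts, noise_pattern, ' ')
filtered_hyps = tf.strings.regex_replace(
    tf.reshape(decoded, [-1]), noise_pattern, ' ')

# Tile each reference num_hyps_per_beam times so refs and hyps align 1:1.
filtered_refs = tf.reshape(
    tf.tile(tf.reshape(filtered_refs, [-1, 1]), [1, num_hyps_per_beam]), [-1])
# filtered_refs is now [' hello world', ' hello world',
#                       'good morning', 'good morning']; the stray leading
# space left by the replacement is harmless, since the WER computation
# normalizes whitespace before comparing.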
def _ComputeNormalizedWER(self, hyps, refs):
  """Computes WER stats per hyp, reshaped to [batch, num_hyps_per_beam]."""
  # Filter out all '<epsilon>' tokens for norm_wer computation.
  hyps_no_epsilon = tf.regex_replace(hyps, '(<epsilon>)+', ' ')
  # norm_wer is size [num_transcripts * hyps_per_beam, 2].
  norm_wer = decoder_utils.ComputeWer(hyps_no_epsilon, refs)
  # Split into two tensors of size [num_transcripts * hyps_per_beam, 1].
  norm_wer_errors, norm_wer_words = tf.split(norm_wer, [1, 1], 1)
  shape = [-1, self.params.decoder.beam_search.num_hyps_per_beam]
  norm_wer_errors = tf.reshape(norm_wer_errors, shape)
  norm_wer_words = tf.reshape(norm_wer_words, shape)
  return norm_wer_errors, norm_wer_words
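# Shape walkthrough with made-up numbers: 2 transcripts with 3 hyps per
# beam, so ComputeWer returns 6 rows of (errors, ref_words). The values
# below are invented for illustration.
import tensorflow as tf

num_hyps_per_beam = 3
norm_wer = tf.constant([[0, 4], [1, 4], [2, 4],
                        [1, 5], [1, 5], [3, 5]], dtype=tf.int64)

errors, words = tf.split(norm_wer, [1, 1], axis=1)
shape = [-1, num_hyps_per_beam]
errors = tf.reshape(errors, shape)  # [[0, 1, 2], [1, 1, 3]]
words = tf.reshape(words, shape)    # [[4, 4, 4], [5, 5, 5]]
# Row i now holds the stats for all beam hypotheses of transcript i; e.g.
# the top-hypothesis WER per utterance would be errors[:, 0] / words[:, 0].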
def ExtractRunIds(run_segments):
  """Extract the RunIds from the run_segments feature field.

  Args:
    run_segments: a string Tensor of shape [batch, 1] containing a text
      proto. See `SummaryTest.testExtractRunIds` for an example.

  Returns:
    A string Tensor of shape [batch], containing the extracted run id.
  """
  run_segments = tf.convert_to_tensor(run_segments)[:, 0]
  return tf.regex_replace(run_segments,
                          r'[^:]+: "(.+)"\n[^:]+: (\d+)(.|\n)*', r'\1_\2')
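# A quick eager-mode check of the rewrite rule on a made-up text proto.
# The field names below are invented; only the shape of the text matters:
# a quoted string field on the first line, then an integer field.
import tensorflow as tf

run_segments = tf.constant(
    [['run_id: "exp17"\nstart_step: 2500\nnote: "extra"\n']])
extracted = tf.strings.regex_replace(
    tf.convert_to_tensor(run_segments)[:, 0],
    r'[^:]+: "(.+)"\n[^:]+: (\d+)(.|\n)*', r'\1_\2')
# extracted == [b'exp17_2500']: group 1 captures the quoted string, group 2
# the integer, and the trailing (.|\n)* swallows the rest of the proto.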
def ComputeWer(hyps, refs):
  """Computes word errors in hypotheses relative to reference transcripts.

  Args:
    hyps: Hypotheses, represented as string tensors of shape [N].
    refs: References, represented as string tensors of shape [N].

  Returns:
    An int64 tensor, word_errs, of size [N, 2], where word_errs[i, 0]
    corresponds to the number of word errors in hyps[i] relative to refs[i];
    word_errs[i, 1] corresponds to the number of words in refs[i].
  """

  def _NormalizeWhitespace(s):
    return tf.regex_replace(tf.strings.strip(s), r'\s+', ' ')

  hyps = _NormalizeWhitespace(hyps)
  refs = _NormalizeWhitespace(refs)

  hyps = py_utils.HasRank(hyps, 1)
  refs = py_utils.HasRank(refs, 1)
  hyps = py_utils.HasShape(hyps, tf.shape(refs))

  word_errors = tf.to_int64(
      tf.edit_distance(
          tf.string_split(hyps), tf.string_split(refs), normalize=False))

  # Count number of spaces in reference, and increment by 1 to get total
  # number of words.
  ref_words = tf.to_int64(
      tf.strings.length(tf.regex_replace(refs, '[^ ]', '')) + 1)
  # Set number of words to 0 if the reference was empty.
  ref_words = tf.where(
      tf.equal(refs, ''), tf.zeros_like(ref_words, tf.int64), ref_words)

  return tf.concat(
      [tf.expand_dims(word_errors, -1),
       tf.expand_dims(ref_words, -1)], axis=1)
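# A minimal eager-mode sketch of the same computation using TF2 string ops
# (tf.string_split and tf.to_int64 above are TF1-era). _wer_stats is an
# invented name for illustration, not part of the library.
import tensorflow as tf

def _wer_stats(hyps, refs):
  hyps = tf.strings.regex_replace(tf.strings.strip(hyps), r'\s+', ' ')
  refs = tf.strings.regex_replace(tf.strings.strip(refs), r'\s+', ' ')
  # Token-level Levenshtein distance between hyp and ref word sequences.
  errs = tf.cast(
      tf.edit_distance(
          tf.strings.split(hyps).to_sparse(),
          tf.strings.split(refs).to_sparse(),
          normalize=False), tf.int64)
  # Number of spaces + 1 == number of words; empty refs count as 0 words.
  words = tf.cast(
      tf.strings.length(tf.strings.regex_replace(refs, '[^ ]', '')) + 1,
      tf.int64)
  words = tf.where(tf.equal(refs, ''), tf.zeros_like(words), words)
  return tf.stack([errs, words], axis=1)

# _wer_stats(tf.constant(['the cat sat', 'hello there', '']),
#            tf.constant(['the cat sat down', 'hello here', '']))
# -> [[1 4]   one deletion ('down') against 4 reference words
#     [1 2]   one substitution ('there' -> 'here') against 2 words
#     [0 0]]  empty reference counts as zero words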