Example #1
    def _ComputeDecoderMetrics(self, decoder_outs, input_batch):
        """Computes metrics on output from decoder.

    Args:
      decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
        decode results.
      input_batch:  A `NestedMap` of tensors representing the source, target,
        and other components of the input batch.

    Returns:
      A dict of Tensors containing decoder output and metrics.
    """
        p = self.params
        topk = self._GetTopK(decoder_outs)
        tgt = self._GetTargetForDecoderMetrics(input_batch)
        transcripts = self.input_generator.IdsToStrings(
            tgt.labels,
            tf.cast(tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1) - 1.0),
                    tf.int32))

        # Filter out all isolated '<noise>' tokens.
        noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
        filtered_refs = tf.regex_replace(transcripts, noise_pattern, ' ')
        filtered_hyps = tf.regex_replace(topk.decoded, noise_pattern, ' ')
        # Compute translation quality scores for all hyps.
        filtered_refs = tf.tile(tf.reshape(filtered_refs, [-1, 1]),
                                [1, p.decoder.beam_search.num_hyps_per_beam])
        filtered_hyps = tf.reshape(filtered_hyps, [-1])
        filtered_refs = tf.reshape(filtered_refs, [-1])
        norm_wer_errors, norm_wer_words = self._ComputeNormalizedWER(
            filtered_hyps, filtered_refs)

        ret_dict = {
            'target_ids': tgt.ids,
            'target_labels': tgt.labels,
            'target_weights': tgt.weights,
            'target_paddings': tgt.paddings,
            'transcripts': transcripts,
            'topk_decoded': topk.decoded,
            'topk_ids': topk.ids,
            'topk_lens': topk.lens,
            'topk_scores': topk.scores,
            'norm_wer_errors': norm_wer_errors,
            'norm_wer_words': norm_wer_words,
        }

        if not py_utils.use_tpu():
            ret_dict['utt_id'] = input_batch.sample_ids

        ret_dict.update(
            self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                                    filtered_refs, input_batch,
                                                    decoder_outs))
        return ret_dict
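A minimal standalone sketch of the isolated-'<noise>' filtering step above. The transcripts below are made-up inputs, and tf.strings.regex_replace is used as the non-deprecated spelling of tf.regex_replace:

import tensorflow as tf

noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
transcripts = tf.constant([
    '<noise>',              # isolated token: replaced by a single space
    '<noise> hello world',  # leading token: stripped, leaving ' hello world'
    'hello <noise> world',  # interior token: collapsed to 'hello world'
    'some<noise>word',      # not isolated by spaces: left untouched
])
filtered = tf.strings.regex_replace(transcripts, noise_pattern, ' ')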
Example #2
    def _ComputeNormalizedWER(self, hyps, refs):
        # Filter out all '<epsilon>' tokens for norm_wer computation.
        hyps_no_epsilon = tf.regex_replace(hyps, '(<epsilon>)+', ' ')
        # norm_wer is size [num_transcripts * hyps_per_beam, 2]
        norm_wer = decoder_utils.ComputeWer(hyps_no_epsilon, refs)
        # Split into two tensors of size [num_transcripts * hyps_per_beam, 1]
        norm_wer_errors, norm_wer_words = tf.split(norm_wer, [1, 1], 1)
        shape = [-1, self.params.decoder.beam_search.num_hyps_per_beam]
        norm_wer_errors = tf.reshape(norm_wer_errors, shape)
        norm_wer_words = tf.reshape(norm_wer_words, shape)

        return norm_wer_errors, norm_wer_words
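A toy sketch of the split-and-reshape step at the end, assuming 2 transcripts and 3 hyps per beam (hypothetical values; in the snippet they come from params.decoder.beam_search):

import tensorflow as tf

hyps_per_beam = 3
# Stand-in for the ComputeWer output of shape [2 * 3, 2]:
# column 0 = word errors, column 1 = reference word counts.
norm_wer = tf.reshape(tf.range(12, dtype=tf.int64), [6, 2])
errors, words = tf.split(norm_wer, [1, 1], 1)     # two [6, 1] tensors
errors = tf.reshape(errors, [-1, hyps_per_beam])  # -> shape [2, 3]
words = tf.reshape(words, [-1, hyps_per_beam])    # -> shape [2, 3]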
Example #3
def ExtractRunIds(run_segments):
    """Extract the RunIds from the run_segments feature field.

  Args:
    run_segments: a string Tensor of shape [batch, 1] containing a text proto.

      See `SummaryTest.testExtractRunIds` for an example.

  Returns:
    A string Tensor of shape [batch], containing the extracted run id.
  """
    run_segments = tf.convert_to_tensor(run_segments)[:, 0]
    return tf.regex_replace(run_segments,
                            r'[^:]+: "(.+)"\n[^:]+: (\d+)(.|\n)*', r'\1_\2')
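A hedged illustration of what the replacement does: the regex captures the first quoted string and the first integer field of the text proto and joins them with an underscore. The field names below are invented for the example (the real format is shown in SummaryTest.testExtractRunIds), and a TF1-era environment is assumed since the snippet calls tf.regex_replace:

import tensorflow as tf

run_segments = tf.constant([['run: "20190607_1122"\nsegment: 7\nfoo: bar\n']])
run_ids = ExtractRunIds(run_segments)  # -> ['20190607_1122_7']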
Example #4
def ComputeWer(hyps, refs):
    """Computes word errors in hypotheses relative to reference transcripts.

  Args:
    hyps: Hypotheses, represented as string tensors of shape [N].
    refs: References, represented as string tensors of shape [N].

  Returns:
    An int64 tensor, word_errs, of size [N, 2] where word_errs[i, 0] corresponds
    to the number of word errors in hyps[i] relative to refs[i]; word_errs[i, 1]
    corresponds to the number of words in refs[i].
  """
    def _NormalizeWhitespace(s):
        return tf.regex_replace(tf.strings.strip(s), r'\s+', ' ')

    hyps = _NormalizeWhitespace(hyps)
    refs = _NormalizeWhitespace(refs)

    hyps = py_utils.HasRank(hyps, 1)
    refs = py_utils.HasRank(refs, 1)
    hyps = py_utils.HasShape(hyps, tf.shape(refs))

    word_errors = tf.to_int64(
        tf.edit_distance(tf.string_split(hyps),
                         tf.string_split(refs),
                         normalize=False))

    # Count number of spaces in reference, and increment by 1 to get total number
    # of words.
    ref_words = tf.to_int64(
        tf.strings.length(tf.regex_replace(refs, '[^ ]', '')) + 1)
    # Set number of words to 0 if the reference was empty.
    ref_words = tf.where(tf.equal(refs, ''),
                         tf.zeros_like(ref_words, tf.int64), ref_words)

    return tf.concat(
        [tf.expand_dims(word_errors, -1),
         tf.expand_dims(ref_words, -1)],
        axis=1)
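A small usage sketch with made-up hyp/ref pairs, assuming the same TF1-era environment as the snippet (tf.to_int64, tf.string_split, and lingvo's py_utils are available):

import tensorflow as tf

hyps = tf.constant(['the cat sat', 'hello word'])
refs = tf.constant(['the cat sat on the mat', 'hello world'])
word_errs = ComputeWer(hyps, refs)
# word_errs[0] == [3, 6]: three insertions against a 6-word reference.
# word_errs[1] == [1, 2]: one substitution against a 2-word reference.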
Example #5
def _NormalizeWhitespace(s):
    return tf.regex_replace(tf.strings.strip(s), r'\s+', ' ')
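For illustration, a quick check of what this helper does (ends stripped, runs of whitespace collapsed); tf.strings.regex_replace is the current alias if tf.regex_replace is unavailable:

import tensorflow as tf

s = tf.constant(['  hello \t  world  '])
out = tf.strings.regex_replace(tf.strings.strip(s), r'\s+', ' ')
# out == ['hello world']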