Example #1
    def _CalculateErrorRates(self, dec_outs_dict, input_batch):
        gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
        gt_transcripts = py_utils.RunOnTpuHost(
            self.input_generator.IdsToStrings, input_batch.tgt.labels,
            gt_seq_lens)

        # Character error rate (CER).
        char_dist = tf.edit_distance(tf.string_split(dec_outs_dict.transcripts,
                                                     sep=''),
                                     tf.string_split(gt_transcripts, sep=''),
                                     normalize=False)

        ref_chars = tf.strings.length(gt_transcripts)
        num_wrong_chars = tf.reduce_sum(char_dist)
        num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
        cer = num_wrong_chars / num_ref_chars

        # Word error rate (WER).
        word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                             gt_transcripts)  # (B, 2)
        num_wrong_words = tf.reduce_sum(word_dist[:, 0])
        num_ref_words = tf.reduce_sum(word_dist[:, 1])
        wer = num_wrong_words / num_ref_words

        ret_dict = {
            'target_ids': input_batch.tgt.ids,
            'target_labels': input_batch.tgt.labels,
            'target_weights': input_batch.tgt.weights,
            'target_paddings': input_batch.tgt.paddings,
            'target_transcripts': gt_transcripts,
            'decoded_transcripts': dec_outs_dict.transcripts,
            'wer': wer,
            'cer': cer,
            'num_wrong_words': num_wrong_words,
            'num_ref_words': num_ref_words,
            'num_wrong_chars': num_wrong_chars,
            'num_ref_chars': num_ref_chars
        }
        if not py_utils.use_tpu():
            ret_dict['utt_id'] = input_batch.sample_ids

        return ret_dict
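
A minimal, self-contained sketch of the CER computation above, with illustrative tensor values (note that tf.string_split is the TF1 name; in TF2 it lives under tf.compat.v1):

import tensorflow as tf

hyps = tf.constant(['kitten'])
refs = tf.constant(['sitting'])
# Character-level edit distance between hypothesis and reference.
char_dist = tf.edit_distance(
    tf.compat.v1.string_split(hyps, sep=''),
    tf.compat.v1.string_split(refs, sep=''),
    normalize=False)
# 'kitten' -> 'sitting' takes 3 edits; the reference has 7 characters,
# so CER = 3 / 7.
cer = tf.reduce_sum(char_dist) / tf.cast(
    tf.reduce_sum(tf.strings.length(refs)), tf.float32)
print(cer.numpy())  # ~0.4286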
Example #2
def ComputeWer(hyps, refs):
    """Computes word errors in hypotheses relative to reference transcripts.

    Args:
      hyps: Hypotheses, represented as string tensors of shape [N].
      refs: References, represented as string tensors of shape [N].

    Returns:
      An int64 tensor, word_errs, of size [N, 2], where word_errs[i, 0] is the
      number of word errors in hyps[i] relative to refs[i], and
      word_errs[i, 1] is the number of words in refs[i].
    """
    def _NormalizeWhitespace(s):
        return tf.strings.regex_replace(tf.strings.strip(s), r'\s+', ' ')

    hyps = _NormalizeWhitespace(hyps)
    refs = _NormalizeWhitespace(refs)

    hyps = py_utils.HasRank(hyps, 1)
    refs = py_utils.HasRank(refs, 1)
    hyps = py_utils.HasShape(hyps, tf.shape(refs))

    word_errors = tf.cast(
        tf.edit_distance(tf.string_split(hyps),
                         tf.string_split(refs),
                         normalize=False), tf.int64)

    # Count number of spaces in reference, and increment by 1 to get total number
    # of words.
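    # E.g. 'the cat sat' contains two spaces, giving 2 + 1 = 3 words.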
    ref_words = tf.cast(
        tf.strings.length(tf.strings.regex_replace(refs, '[^ ]', '')) + 1,
        tf.int64)
    # Set number of words to 0 if the reference was empty.
    ref_words = tf.where(tf.equal(refs, ''),
                         tf.zeros_like(ref_words, tf.int64), ref_words)

    return tf.concat(
        [tf.expand_dims(word_errors, -1),
         tf.expand_dims(ref_words, -1)],
        axis=1)
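
A hedged usage sketch of ComputeWer (it assumes lingvo's py_utils is importable, since the function calls py_utils.HasRank and py_utils.HasShape; the example strings are illustrative):

import tensorflow as tf

hyps = tf.constant(['the cat sat', 'hello world'])
refs = tf.constant(['the cat sat on the mat', 'hello word'])
word_errs = ComputeWer(hyps, refs)
# word_errs == [[3, 6], [1, 2]]: three insertions turn the first
# hypothesis into its 6-word reference; one substitution fixes the
# second against its 2-word reference.
wer = tf.reduce_sum(word_errs[:, 0]) / tf.reduce_sum(word_errs[:, 1])
# Corpus-level WER = (3 + 1) / (6 + 2) = 0.5.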