# Imports assumed for this excerpt (Lingvo-style modules):
from lingvo.core import py_utils
from lingvo.tasks.asr import decoder_utils
import tensorflow as tf


def _CalculateErrorRates(self, dec_outs_dict, input_batch):
  """Computes word and character error rates for a decoded batch.

  Args:
    dec_outs_dict: Decoder outputs; `dec_outs_dict.transcripts` holds the
      decoded hypothesis strings of shape [B].
    input_batch: The input batch; `input_batch.tgt` holds the reference ids,
      labels, weights, and paddings.

  Returns:
    A dict of scalar metrics (wer, cer, error and reference counts) and
    per-utterance target/transcript tensors.
  """
  gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
  gt_transcripts = py_utils.RunOnTpuHost(self.input_generator.IdsToStrings,
                                         input_batch.tgt.labels, gt_seq_lens)

  # Character (token) error rate: edit distance over per-character splits.
  char_dist = tf.edit_distance(
      tf.string_split(dec_outs_dict.transcripts, sep=''),
      tf.string_split(gt_transcripts, sep=''),
      normalize=False)
  ref_chars = tf.strings.length(gt_transcripts)
  num_wrong_chars = tf.reduce_sum(char_dist)
  num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
  cer = num_wrong_chars / num_ref_chars

  # Word error rate: per-utterance [errors, ref_words] pairs, shape (B, 2).
  word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                       gt_transcripts)
  num_wrong_words = tf.reduce_sum(word_dist[:, 0])
  num_ref_words = tf.reduce_sum(word_dist[:, 1])
  wer = num_wrong_words / num_ref_words

  ret_dict = {
      'target_ids': input_batch.tgt.ids,
      'target_labels': input_batch.tgt.labels,
      'target_weights': input_batch.tgt.weights,
      'target_paddings': input_batch.tgt.paddings,
      'target_transcripts': gt_transcripts,
      'decoded_transcripts': dec_outs_dict.transcripts,
      'wer': wer,
      'cer': cer,
      'num_wrong_words': num_wrong_words,
      'num_ref_words': num_ref_words,
      'num_wrong_chars': num_wrong_chars,
      'num_ref_chars': num_ref_chars,
  }
  # String utterance ids are not supported on TPU.
  if not py_utils.use_tpu():
    ret_dict['utt_id'] = input_batch.sample_ids
  return ret_dict
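# A minimal worked sketch of the CER computation above, with hypothetical
# inputs (assumes the same TF1-style tf.string_split / tf.edit_distance APIs
# used in this file). Splitting with sep='' yields one sparse entry per
# character, so the edit distance is counted in characters:
#
#   hyps = tf.constant(['hallo world'])   # one substitution vs. the reference
#   refs = tf.constant(['hello world'])
#   char_dist = tf.edit_distance(
#       tf.string_split(hyps, sep=''),
#       tf.string_split(refs, sep=''),
#       normalize=False)                                      # -> [1.0]
#   num_ref_chars = tf.cast(
#       tf.reduce_sum(tf.strings.length(refs)), tf.float32)   # -> 11.0
#   cer = tf.reduce_sum(char_dist) / num_ref_chars            # -> ~0.091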
def ComputeWer(hyps, refs):
  """Computes word errors in hypotheses relative to reference transcripts.

  Args:
    hyps: Hypotheses, represented as string tensors of shape [N].
    refs: References, represented as string tensors of shape [N].

  Returns:
    An int64 tensor, word_errs, of size [N, 2] where word_errs[i, 0]
    corresponds to the number of word errors in hyps[i] relative to refs[i];
    word_errs[i, 1] corresponds to the number of words in refs[i].
  """

  def _NormalizeWhitespace(s):
    return tf.strings.regex_replace(tf.strings.strip(s), r'\s+', ' ')

  hyps = _NormalizeWhitespace(hyps)
  refs = _NormalizeWhitespace(refs)

  hyps = py_utils.HasRank(hyps, 1)
  refs = py_utils.HasRank(refs, 1)
  hyps = py_utils.HasShape(hyps, tf.shape(refs))

  word_errors = tf.cast(
      tf.edit_distance(
          tf.string_split(hyps), tf.string_split(refs), normalize=False),
      tf.int64)

  # Count number of spaces in reference, and increment by 1 to get total number
  # of words.
  ref_words = tf.cast(
      tf.strings.length(tf.strings.regex_replace(refs, '[^ ]', '')) + 1,
      tf.int64)
  # Set number of words to 0 if the reference was empty.
  ref_words = tf.where(
      tf.equal(refs, ''), tf.zeros_like(ref_words, tf.int64), ref_words)

  return tf.concat(
      [tf.expand_dims(word_errors, -1),
       tf.expand_dims(ref_words, -1)], axis=1)
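# Hedged usage sketch for ComputeWer, with hypothetical inputs. Whitespace is
# normalized before splitting, so stray spaces in a hypothesis do not count
# as errors:
#
#   hyps = tf.constant(['the cat sat', 'hello  world '])
#   refs = tf.constant(['the cat sat on the mat', 'hello world'])
#   word_errs = ComputeWer(hyps, refs)
#   # word_errs == [[3, 6],   # 3 missing words out of 6 reference words
#   #               [0, 2]]   # exact match after normalization, 2 ref words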