def testBasic(self):
  """Identical hyp/ref pairs give zero errors and the reference word count.

  Each row of the ComputeWer result is [word_errors, reference_word_count].
  """
  cases = [
      (["one"], ["one"], [[0, 1]]),
      (["one two"], ["one two"], [[0, 2]]),
  ]
  with self.session():
    # Build and evaluate each case in turn, in the same order as listed.
    for hyps, refs, expected in cases:
      self.assertAllEqual(
          decoder_utils.ComputeWer(hyps=hyps, refs=refs).eval(), expected)
def testConsecutiveWhiteSpace(self):
  """Runs of whitespace must act as a single word separator.

  The inputs deliberately contain multiple consecutive spaces (and a trailing
  space) so that any whitespace-splitting bug inflates the word/error counts.
  Each result row is [word_errors, reference_word_count].
  """
  with self.session():
    wer = decoder_utils.ComputeWer(
        hyps=["one    two", "one two", "two     pigs"],
        refs=["one two", "one     two ", "three pink pigs"])
    self.assertAllEqual(wer.shape, [3, 2])
    self.assertAllEqual(wer.eval(), [[0, 2], [0, 2], [2, 3]])
def _ComputeNormalizedWER(self, hyps, refs):
  """Computes word errors after stripping '<epsilon>' tokens from hyps.

  Args:
    hyps: String tensor of hypotheses. Presumably laid out as
      [num_transcripts * num_hyps_per_beam] (TODO confirm), since the result
      is reshaped to [-1, num_hyps_per_beam] below.
    refs: String tensor of reference transcripts, aligned with `hyps`.

  Returns:
    A tuple (norm_wer_errors, norm_wer_words), each reshaped to
    [-1, num_hyps_per_beam]: the per-hypothesis word-error count and the
    corresponding reference word count.
  """
  # Filter out all '<epsilon>' tokens for norm_wer computation. Each run of
  # them becomes a single space so neighbouring words stay separated.
  hyps_no_epsilon = tf.strings.regex_replace(hyps, '(<epsilon>)+', ' ')
  # norm_wer is size [num_transcripts * hyps_per_beam, 2]
  norm_wer = decoder_utils.ComputeWer(hyps_no_epsilon, refs)
  # Split into two tensors of size [num_transcripts * hyps_per_beam, 1]
  norm_wer_errors, norm_wer_words = tf.split(norm_wer, [1, 1], 1)
  shape = [-1, self.params.decoder.beam_search.num_hyps_per_beam]
  norm_wer_errors = tf.reshape(norm_wer_errors, shape)
  norm_wer_words = tf.reshape(norm_wer_words, shape)
  return norm_wer_errors, norm_wer_words
def _CalculateErrorRates(self, dec_outs_dict, input_batch):
  """Computes batch-level character and word error rates.

  Args:
    dec_outs_dict: Decoder outputs; its `transcripts` field holds the decoded
      hypothesis strings (presumably one per utterance — TODO confirm).
    input_batch: Input batch providing `tgt.ids`, `tgt.labels`,
      `tgt.weights`, `tgt.paddings` and, when not on TPU, `sample_ids`.

  Returns:
    A dict with the target tensors, reference and decoded transcripts, the
    aggregate 'wer' and 'cer', and their numerators/denominators
    ('num_wrong_words', 'num_ref_words', 'num_wrong_chars', 'num_ref_chars').
    'utt_id' is included only when not running on TPU.
  """
  # Reference lengths come from the target padding mask; the label ids are
  # then converted back to reference strings.
  gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
  gt_transcripts = py_utils.RunOnTpuHost(
      self.input_generator.IdsToStrings, input_batch.tgt.labels, gt_seq_lens)

  # token error rate
  # NOTE(review): splitting with sep='' and tf.strings.length both work on
  # bytes by default, so multibyte UTF-8 characters count as several units —
  # confirm this is intended.
  char_dist = tf.edit_distance(
      tf.string_split(dec_outs_dict.transcripts, sep=''),
      tf.string_split(gt_transcripts, sep=''),
      normalize=False)
  ref_chars = tf.strings.length(gt_transcripts)
  num_wrong_chars = tf.reduce_sum(char_dist)
  num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
  # NOTE(review): yields inf/NaN if the batch has no reference characters.
  cer = num_wrong_chars / num_ref_chars

  # word error rate
  # word_dist is (B, 2): column 0 = word errors, column 1 = reference words.
  word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                       gt_transcripts)
  num_wrong_words = tf.reduce_sum(word_dist[:, 0])
  num_ref_words = tf.reduce_sum(word_dist[:, 1])
  # NOTE(review): the dtype of this ratio follows ComputeWer's output dtype
  # and may differ from `cer` (float32) — confirm downstream consumers.
  wer = num_wrong_words / num_ref_words

  ret_dict = {
      'target_ids': input_batch.tgt.ids,
      'target_labels': input_batch.tgt.labels,
      'target_weights': input_batch.tgt.weights,
      'target_paddings': input_batch.tgt.paddings,
      'target_transcripts': gt_transcripts,
      'decoded_transcripts': dec_outs_dict.transcripts,
      'wer': wer,
      'cer': cer,
      'num_wrong_words': num_wrong_words,
      'num_ref_words': num_ref_words,
      'num_wrong_chars': num_wrong_chars,
      'num_ref_chars': num_ref_chars,
  }
  if not py_utils.use_tpu():
    ret_dict['utt_id'] = input_batch.sample_ids
  return ret_dict
def testDifferencesInCaseAreCountedAsErrors(self):
  """Word comparison is case-sensitive: 'ONE' vs 'one' counts as one error."""
  mixed_case_hyps = ["ONE two", "one two"]
  mixed_case_refs = ["one two", "ONE two"]
  with self.session():
    errs = decoder_utils.ComputeWer(hyps=mixed_case_hyps, refs=mixed_case_refs)
    self.assertAllEqual(errs.shape, [2, 2])
    self.assertAllEqual(errs.eval(), [[1, 2], [1, 2]])
def testEmptyRefsAndHyps(self):
  """Empty strings contribute zero words on whichever side they appear.

  Rows are [word_errors, reference_word_count]; an empty reference gives a
  word count of 0 while the hypothesis words are still counted as errors.
  """
  with self.session():
    counts = decoder_utils.ComputeWer(
        hyps=["", "one two", ""], refs=["", "", "three four five"])
    expected_counts = [[0, 0], [2, 0], [3, 3]]
    self.assertAllEqual(counts.shape, [3, 2])
    self.assertAllEqual(counts.eval(), expected_counts)
def testMultiples(self):
  """Each hyp/ref pair in a batch gets its own [errors, ref_words] row."""
  with self.session():
    batch_wer = decoder_utils.ComputeWer(
        hyps=["one", "two pigs"], refs=["one", "three pink pigs"])
    self.assertAllEqual(batch_wer.shape, [2, 2])
    self.assertAllEqual(batch_wer.eval(), [[0, 1], [2, 3]])
def testInvalidInputsWrongRank(self):
  """Nested-list (rank-2) inputs must raise when evaluated or built."""
  nested_hyps = [["one"], ["two"]]
  nested_refs = [["one"], ["two"]]
  with self.session(), self.assertRaises(Exception):
    decoder_utils.ComputeWer(hyps=nested_hyps, refs=nested_refs).eval()
def testInvalidInputsExtraRefs(self):
  """Supplying more references than hypotheses must raise."""
  with self.session(), self.assertRaises(Exception):
    mismatched = decoder_utils.ComputeWer(hyps=["one"], refs=["one", "two"])
    mismatched.eval()