예제 #1
0
 def testBasic(self):
     """Identical hyp/ref pairs yield zero errors and the ref word count."""
     cases = [
         (["one"], ["one"], [[0, 1]]),
         (["one two"], ["one two"], [[0, 2]]),
     ]
     with self.session():
         for hyps, refs, expected in cases:
             stats = decoder_utils.ComputeWer(hyps=hyps, refs=refs)
             self.assertAllEqual(stats.eval(), expected)
예제 #2
0
 def testConsecutiveWhiteSpace(self):
     """Runs of spaces (and trailing spaces) do not create extra words."""
     hyps = ["one    two", "one two", "two     pigs"]
     refs = ["one two", "one     two ", "three pink pigs"]
     expected = [[0, 2], [0, 2], [2, 3]]
     with self.session():
         result = decoder_utils.ComputeWer(hyps=hyps, refs=refs)
         self.assertAllEqual(result.shape, [3, 2])
         self.assertAllEqual(result.eval(), expected)
예제 #3
0
    def _ComputeNormalizedWER(self, hyps, refs):
        """Computes WER statistics after stripping '<epsilon>' from hyps.

        Args:
          hyps: Hypothesis strings; each contiguous run of '<epsilon>'
            substrings is replaced by a single space before scoring.
          refs: Reference strings, passed to ComputeWer unchanged.

        Returns:
          A pair (errors, ref_words), each reshaped to
          [-1, num_hyps_per_beam], where num_hyps_per_beam comes from
          self.params.decoder.beam_search.
        """
        cleaned_hyps = tf.regex_replace(hyps, '(<epsilon>)+', ' ')
        # One row per hypothesis: column 0 = errors, column 1 = ref words.
        wer_stats = decoder_utils.ComputeWer(cleaned_hyps, refs)
        errors_col, words_col = tf.split(wer_stats, [1, 1], 1)
        beam_shape = [-1, self.params.decoder.beam_search.num_hyps_per_beam]
        return (tf.reshape(errors_col, beam_shape),
                tf.reshape(words_col, beam_shape))
예제 #4
0
    def _CalculateErrorRates(self, dec_outs_dict, input_batch):
        """Computes batch-level character and word error rates.

        Args:
          dec_outs_dict: Decoder outputs; `dec_outs_dict.transcripts` holds the
            decoded hypothesis strings.
          input_batch: Input batch; `input_batch.tgt` supplies `ids`, `labels`,
            `weights` and `paddings`, and `input_batch.sample_ids` is read
            when not running on TPU.

        Returns:
          A dict with the target tensors, the ground-truth and decoded
          transcripts, the batch-level 'wer' and 'cer', and the counts they
          are derived from ('num_wrong_words', 'num_ref_words',
          'num_wrong_chars', 'num_ref_chars'). 'utt_id' is added only when
          not running on TPU.
        """
        # Presumably the number of non-padded target positions per example
        # (mask reduced along axis 1) -- confirm against py_utils.
        gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
        # Presumably routed through RunOnTpuHost to keep the id->string
        # conversion off the TPU device.
        gt_transcripts = py_utils.RunOnTpuHost(
            self.input_generator.IdsToStrings, input_batch.tgt.labels,
            gt_seq_lens)

        # Token error rate.
        # NOTE: with sep='' tf.string_split splits into single bytes, and
        # tf.strings.length counts bytes by default, so for non-ASCII text
        # this is a byte-level rate rather than a per-character one.
        char_dist = tf.edit_distance(tf.string_split(dec_outs_dict.transcripts,
                                                     sep=''),
                                     tf.string_split(gt_transcripts, sep=''),
                                     normalize=False)

        ref_chars = tf.strings.length(gt_transcripts)
        num_wrong_chars = tf.reduce_sum(char_dist)
        num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
        # NOTE(review): yields inf/nan when the batch has no reference chars.
        cer = num_wrong_chars / num_ref_chars

        # Word error rate: one [num_wrong_words, num_ref_words] row per example.
        word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                             gt_transcripts)  # (B, 2)
        num_wrong_words = tf.reduce_sum(word_dist[:, 0])
        num_ref_words = tf.reduce_sum(word_dist[:, 1])
        # NOTE(review): yields inf/nan when the batch has no reference words.
        wer = num_wrong_words / num_ref_words

        ret_dict = {
            'target_ids': input_batch.tgt.ids,
            'target_labels': input_batch.tgt.labels,
            'target_weights': input_batch.tgt.weights,
            'target_paddings': input_batch.tgt.paddings,
            'target_transcripts': gt_transcripts,
            'decoded_transcripts': dec_outs_dict.transcripts,
            'wer': wer,
            'cer': cer,
            'num_wrong_words': num_wrong_words,
            'num_ref_words': num_ref_words,
            'num_wrong_chars': num_wrong_chars,
            'num_ref_chars': num_ref_chars
        }
        if not py_utils.use_tpu():
            ret_dict['utt_id'] = input_batch.sample_ids

        return ret_dict
예제 #5
0
 def testDifferencesInCaseAreCountedAsErrors(self):
     """Word matching is case-sensitive: 'ONE' vs 'one' counts as an error."""
     with self.session():
         stats = decoder_utils.ComputeWer(
             hyps=["ONE two", "one two"], refs=["one two", "ONE two"])
         expected_shape, expected_values = [2, 2], [[1, 2], [1, 2]]
         self.assertAllEqual(stats.shape, expected_shape)
         self.assertAllEqual(stats.eval(), expected_values)
예제 #6
0
 def testEmptyRefsAndHyps(self):
     """Empty strings contain zero words on either side of the comparison."""
     hyps = ["", "one two", ""]
     refs = ["", "", "three four five"]
     with self.session():
         stats = decoder_utils.ComputeWer(hyps=hyps, refs=refs)
         self.assertAllEqual(stats.shape, [3, 2])
         # Empty vs empty: no errors; extra hyp words against an empty ref are
         # all errors; an empty hyp misses every ref word.
         self.assertAllEqual(stats.eval(), [[0, 0], [2, 0], [3, 3]])
예제 #7
0
 def testMultiples(self):
     """A batch of hyp/ref pairs yields one [errors, ref_words] row per pair."""
     hyps = ["one", "two pigs"]
     refs = ["one", "three pink pigs"]
     with self.session():
         stats = decoder_utils.ComputeWer(hyps=hyps, refs=refs)
         self.assertAllEqual(stats.shape, [2, 2])
         self.assertAllEqual(stats.eval(), [[0, 1], [2, 3]])
예제 #8
0
 def testInvalidInputsWrongRank(self):
     """Nested (rank-2) hyps/refs make ComputeWer raise."""
     nested_hyps = [["one"], ["two"]]
     nested_refs = [["one"], ["two"]]
     with self.session(), self.assertRaises(Exception):
         decoder_utils.ComputeWer(hyps=nested_hyps, refs=nested_refs).eval()
예제 #9
0
 def testInvalidInputsExtraRefs(self):
     """More refs than hyps makes ComputeWer raise."""
     hyps = ["one"]
     refs = ["one", "two"]
     with self.session(), self.assertRaises(Exception):
         decoder_utils.ComputeWer(hyps=hyps, refs=refs).eval()