def testCorpusBleuMetric(self):
  """Checks CorpusBleuMetric on a corpus where both Update arguments match.

  Two sentence pairs are fed in, each with identical strings for both
  arguments. The test expects the corpus BLEU value to be exactly 1.0 and
  expects the summary emitted under the 'corpus_bleu' tag to carry that value.
  """
  bleu = metrics.CorpusBleuMetric()
  for sentence in ('a b c d', 'a b c'):
    # Both arguments identical, so every n-gram matches.
    bleu.Update(sentence, sentence)
  self.assertEqual(1.0, bleu.value)

  summary_name = 'corpus_bleu'
  expected_summary = tf.Summary(
      value=[tf.Summary.Value(tag=summary_name, simple_value=1.0)])
  self.assertEqual(expected_summary, bleu.Summary(summary_name))
def CreateDecoderMetrics(self):
  """Builds the name -> metric mapping used when decoding.

  Base entries are plain running averages plus a corpus-level BLEU metric.
  Entries returned by self.CreateAdditionalDecoderMetrics() are merged in
  last, so they override a base entry that has the same name.

  Returns:
    A dict mapping metric names to freshly constructed metric objects.
  """
  decoder_metrics = {}
  # num_samples_in_batch, word error rate, normalized word error rate,
  # sentence accuracy and token error rate are all averaged.
  for key in ('num_samples_in_batch', 'wer', 'norm_wer', 'sacc', 'ter'):
    decoder_metrics[key] = metrics.AverageMetric()
  decoder_metrics['corpus_bleu'] = metrics.CorpusBleuMetric()
  decoder_metrics['oracle_norm_wer'] = metrics.AverageMetric()
  # Hook for subclasses/configs to contribute extra metrics.
  decoder_metrics.update(self.CreateAdditionalDecoderMetrics())
  return decoder_metrics
def CreateMetrics(self):
  """Builds the name -> metric mapping for this model.

  The core metrics are always present. When
  self.params.include_auxiliary_metrics is truthy, four extra averaged
  metrics are appended: word error rate, sentence accuracy, token error
  rate and oracle normalized WER.

  Returns:
    A dict mapping metric names to freshly constructed metric objects.
  """
  result = {
      'num_samples_in_batch': metrics.AverageMetric(),
      'norm_wer': metrics.AverageMetric(),  # Normalized word error rate.
      'corpus_bleu': metrics.CorpusBleuMetric(),
  }
  if not self.params.include_auxiliary_metrics:
    return result

  # wer: word error rate; sacc: sentence accuracy; ter: token error rate.
  for key in ('wer', 'sacc', 'ter', 'oracle_norm_wer'):
    result[key] = metrics.AverageMetric()
  return result
def CreateMetrics(self):
  """Builds the name -> metric mapping for this model.

  The core metrics are always present. When
  self.params.include_auxiliary_metrics is truthy, a set of averaged
  auxiliary metrics is appended. These are word error rate, its
  insertion/substitution/deletion breakdown (case-sensitive and
  case-insensitive), sentence accuracy, token error rate and the oracle
  variants.

  Returns:
    A dict mapping metric names to freshly constructed metric objects.
  """
  metric_map = {
      'num_samples_in_batch': metrics.AverageMetric(),
      'norm_wer': metrics.AverageMetric(),  # Normalized word error rate.
      'corpus_bleu': metrics.CorpusBleuMetric(),
  }
  if not self.params.include_auxiliary_metrics:
    return metric_map

  # Every auxiliary metric is a plain running average. Names are registered
  # in the order listed here.
  auxiliary_names = (
      # TODO(xingwu): fully replace 'wer' with 'error_rates/wer'.
      'wer',  # Word error rate.
      'error_rates/ins',  # Insertion error rate.
      'error_rates/sub',  # Substitution error rate.
      'error_rates/del',  # Deletion error rate.
      'error_rates/wer',  # Word error rate.
      'case_insensitive_error_rates/ins',  # Case-insensitive insertion rate.
      'case_insensitive_error_rates/sub',  # Case-insensitive substitution rate.
      'case_insensitive_error_rates/del',  # Case-insensitive deletion rate.
      'case_insensitive_error_rates/wer',  # Case-insensitive word error rate.
      'sacc',  # Sentence accuracy.
      'ter',  # Token error rate.
      'oracle_norm_wer',
      'oracle/ins',
      'oracle/sub',
      'oracle/del',
  )
  for metric_name in auxiliary_names:
    metric_map[metric_name] = metrics.AverageMetric()
  return metric_map
def CreateDecoderMetrics(self):
  """Builds the name -> metric mapping used when decoding.

  Returns:
    A dict with 'num_samples_in_batch' (an averaged metric) and
    'corpus_bleu' (a CorpusBleuMetric constructed with
    separator_type='wpm').
  """
  return dict(
      num_samples_in_batch=metrics.AverageMetric(),
      corpus_bleu=metrics.CorpusBleuMetric(separator_type='wpm'),
  )