예제 #1
0
  def testCorpusBleuMetric(self):
    """Checks CorpusBleuMetric on identical string pairs and its Summary().

    Both sentences are fed to Update() with the same string for both
    arguments; the accumulated value must be exactly 1.0, and Summary(tag)
    must equal a tf.Summary holding one simple_value entry with that tag.
    """
    bleu = metrics.CorpusBleuMetric()
    # Same string passed as both arguments of every Update() call.
    for sentence in ('a b c d', 'a b c'):
      bleu.Update(sentence, sentence)

    self.assertEqual(1.0, bleu.value)

    tag = 'corpus_bleu'
    expected_summary = tf.Summary(
        value=[tf.Summary.Value(tag=tag, simple_value=1.0)])
    self.assertEqual(expected_summary, bleu.Summary(tag))
예제 #2
0
    def CreateDecoderMetrics(self):
        """Builds the metrics tracked during decoding.

        Returns:
          A dict mapping metric name to a newly constructed metric object.
          It always holds 'num_samples_in_batch', 'wer', 'norm_wer', 'sacc',
          'ter', 'corpus_bleu' and 'oracle_norm_wer'. Entries returned by
          `self.CreateAdditionalDecoderMetrics()` are merged in afterwards.
          On a name collision, the additional entry replaces the standard one.
        """
        decoder_metrics = {
            'num_samples_in_batch': metrics.AverageMetric(),
            # Word error rate.
            'wer': metrics.AverageMetric(),
            # Normalized word error rate.
            'norm_wer': metrics.AverageMetric(),
            # Sentence accuracy.
            'sacc': metrics.AverageMetric(),
            # Token error rate.
            'ter': metrics.AverageMetric(),
            'corpus_bleu': metrics.CorpusBleuMetric(),
            'oracle_norm_wer': metrics.AverageMetric(),
        }

        # Extra metrics are applied last so that they can override defaults.
        extra_metrics = self.CreateAdditionalDecoderMetrics()
        decoder_metrics.update(extra_metrics)
        return decoder_metrics
예제 #3
0
    def CreateMetrics(self):
        """Builds the metrics reported for this model.

        Returns:
          A dict mapping metric name to a newly constructed metric object.
          'num_samples_in_batch', 'norm_wer' and 'corpus_bleu' are always
          present. 'wer', 'sacc', 'ter' and 'oracle_norm_wer' are added only
          when `self.params.include_auxiliary_metrics` is truthy.
        """
        metric_map = {
            'num_samples_in_batch': metrics.AverageMetric(),
            # Normalized word error rate.
            'norm_wer': metrics.AverageMetric(),
            'corpus_bleu': metrics.CorpusBleuMetric(),
        }

        # Without auxiliary metrics, only the core set is reported.
        if not self.params.include_auxiliary_metrics:
            return metric_map

        metric_map.update({
            # Word error rate.
            'wer': metrics.AverageMetric(),
            # Sentence accuracy.
            'sacc': metrics.AverageMetric(),
            # Token error rate.
            'ter': metrics.AverageMetric(),
            'oracle_norm_wer': metrics.AverageMetric(),
        })
        return metric_map
예제 #4
0
    def CreateMetrics(self):
        """Builds the metrics reported for this model.

        Returns:
          A dict mapping metric name to a newly constructed metric object.
          'num_samples_in_batch', 'norm_wer' and 'corpus_bleu' are always
          present. When `self.params.include_auxiliary_metrics` is truthy, the
          dict additionally holds:
            * 'wer';
            * insertion/substitution/deletion/word error rates under the
              'error_rates/' and 'case_insensitive_error_rates/' prefixes;
            * 'sacc' (sentence accuracy) and 'ter' (token error rate);
            * 'oracle_norm_wer' and the 'oracle/' ins/sub/del entries.
        """
        metric_map = {
            'num_samples_in_batch': metrics.AverageMetric(),
            # Normalized word error rate.
            'norm_wer': metrics.AverageMetric(),
            'corpus_bleu': metrics.CorpusBleuMetric(),
        }

        # Without auxiliary metrics, only the core set is reported.
        if not self.params.include_auxiliary_metrics:
            return metric_map

        metric_map.update({
            # TODO(xingwu): fully replace 'wer' with 'error_rates/wer'.
            # Word error rate.
            'wer': metrics.AverageMetric(),
            # Insertion, substitution, deletion and word error rates.
            'error_rates/ins': metrics.AverageMetric(),
            'error_rates/sub': metrics.AverageMetric(),
            'error_rates/del': metrics.AverageMetric(),
            'error_rates/wer': metrics.AverageMetric(),
            # Case-insensitive insertion, substitution, deletion and word
            # error rates.
            'case_insensitive_error_rates/ins': metrics.AverageMetric(),
            'case_insensitive_error_rates/sub': metrics.AverageMetric(),
            'case_insensitive_error_rates/del': metrics.AverageMetric(),
            'case_insensitive_error_rates/wer': metrics.AverageMetric(),
            # Sentence accuracy.
            'sacc': metrics.AverageMetric(),
            # Token error rate.
            'ter': metrics.AverageMetric(),
            'oracle_norm_wer': metrics.AverageMetric(),
            'oracle/ins': metrics.AverageMetric(),
            'oracle/sub': metrics.AverageMetric(),
            'oracle/del': metrics.AverageMetric(),
        })
        return metric_map
예제 #5
0
파일: model.py 프로젝트: wgfi110/lingvo
 def CreateDecoderMetrics(self):
   """Builds the metrics tracked during decoding.

   Returns:
     A dict with two newly constructed metrics: 'num_samples_in_batch'
     (an AverageMetric) and 'corpus_bleu' (a CorpusBleuMetric created with
     separator_type='wpm').
   """
   return {
       'num_samples_in_batch': metrics.AverageMetric(),
       'corpus_bleu': metrics.CorpusBleuMetric(separator_type='wpm'),
   }