コード例 #1
0
ファイル: keras_metrics_test.py プロジェクト: sls33/federated
 def test_counts_total_examples_with_zero_mask_with_sample_weight(self):
   metric = keras_metrics.NumTokensCounter(masked_tokens=[0])
   metric.update_state(
       y_true=[[1, 2, 3, 0], [1, 0, 0, 0]],
       y_pred=[0],
       sample_weight=[[1, 2, 3, 4], [1, 1, 1, 1]])
   self.assertEqual(self.evaluate(metric.result()), 7)
コード例 #2
0
ファイル: keras_metrics_test.py プロジェクト: sls33/federated
 def test_counts_total_examples_without_zero_mask_no_sample_weight(self):
   metric = keras_metrics.NumTokensCounter()
   metric.update_state(
       y_true=[[1, 2, 3, 4], [0, 0, 0, 0]],
       y_pred=[
           0
           # y_pred is thrown away
       ])
   self.assertEqual(self.evaluate(metric.result()), 8)
コード例 #3
0
def metrics_builder():
    """Returns a `list` of `tf.keras.metric.Metric` objects."""
    pad_token, _, _, _ = shakespeare_dataset.get_special_tokens()

    return [
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumExamplesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token]),
    ]
コード例 #4
0
ファイル: run_federated.py プロジェクト: sls33/federated
 def metrics_builder():
     return [
         keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                 masked_tokens=[pad_token]),
         keras_metrics.MaskedCategoricalAccuracy(
             name='accuracy_no_oov', masked_tokens=[pad_token, oov_token]),
         # Notice BOS never appears in ground truth.
         keras_metrics.MaskedCategoricalAccuracy(
             name='accuracy_no_oov_or_eos',
             masked_tokens=[pad_token, oov_token, eos_token]),
         keras_metrics.NumBatchesCounter(),
         keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
     ]
コード例 #5
0
ファイル: keras_metrics_test.py プロジェクト: sls33/federated
 def test_constructor_no_masked_token(self):
   metric_name = 'my_test_metric'
   metric = keras_metrics.NumTokensCounter(name=metric_name)
   self.assertIsInstance(metric, tf.keras.metrics.Metric)
   self.assertEqual(metric.name, metric_name)
   self.assertEqual(self.evaluate(metric.result()), 0)