Exemplo n.º 1
0
 def test_counts_total_examples_without_zero_mask_no_sample_weight(self):
     """Checks the counter when no tokens are masked and no weights are given.

     The labels hold eight entries, four of them zeros; the test expects all
     eight to be counted.
     """
     counter = keras_metrics.NumTokensCounter()
     labels = [[1, 2, 3, 4], [0, 0, 0, 0]]
     # y_pred is thrown away, so a placeholder value is enough.
     ignored_predictions = [0]
     counter.update_state(y_true=labels, y_pred=ignored_predictions)
     total = self.evaluate(counter.result())
     self.assertEqual(total, 8)
Exemplo n.º 2
0
 def model_fn() -> model.Model:
   """Wraps a recurrent char-prediction Keras model via `keras_utils`.

   Uses `VOCAB_LENGTH`, `sequence_length`, `pad_token` and `task_datasets`,
   none of which are defined in this function (they come from an outer scope).

   Returns:
     The result of `keras_utils.from_keras_model` for the recurrent model,
     with a from-logits sparse categorical crossentropy loss and two metrics
     that mask `pad_token`.
   """
   recurrent_model = char_prediction_models.create_recurrent_model(
       vocab_size=VOCAB_LENGTH, sequence_length=sequence_length)
   crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
       from_logits=True)
   element_spec = task_datasets.element_type_structure
   # Both metrics mask the padding token.
   pad_masked_metrics = [
       keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
       keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token]),
   ]
   return keras_utils.from_keras_model(
       keras_model=recurrent_model,
       loss=crossentropy,
       input_spec=element_spec,
       metrics=pad_masked_metrics)
Exemplo n.º 3
0
 def metrics_builder():
     """Builds the metric list: one token counter and three masked accuracies.

     The accuracies differ only in their name and in which tokens they mask.
     `pad_token`, `eos_token` and `oov_tokens` are not defined in this
     function (they come from an outer scope).

     Returns:
       A list of four freshly constructed metric objects.
     """
     token_counter = keras_metrics.NumTokensCounter(masked_tokens=[pad_token])

     accuracy = keras_metrics.MaskedCategoricalAccuracy(
         name='accuracy', masked_tokens=[pad_token])

     pad_and_oov = [pad_token] + oov_tokens
     in_vocab_accuracy = keras_metrics.MaskedCategoricalAccuracy(
         name='accuracy_without_out_of_vocab', masked_tokens=pad_and_oov)

     # Notice that the beginning of sentence token never appears in the
     # ground truth label.
     pad_eos_and_oov = [pad_token, eos_token] + oov_tokens
     in_vocab_non_eos_accuracy = keras_metrics.MaskedCategoricalAccuracy(
         name='accuracy_without_out_of_vocab_or_end_of_sentence',
         masked_tokens=pad_eos_and_oov)

     return [
         token_counter,
         accuracy,
         in_vocab_accuracy,
         in_vocab_non_eos_accuracy,
     ]
Exemplo n.º 4
0
 def test_counts_total_examples_with_zero_mask_with_sample_weight(self):
     """Checks the counter when token 0 is masked and weights are supplied.

     The expected value 7 equals 1 + 2 + 3 + 1, i.e. the sample weights at
     the positions whose label is not the masked token 0.
     """
     counter = keras_metrics.NumTokensCounter(masked_tokens=[0])
     labels = [[1, 2, 3, 0], [1, 0, 0, 0]]
     weights = [[1, 2, 3, 4], [1, 1, 1, 1]]
     counter.update_state(y_true=labels, y_pred=[0], sample_weight=weights)
     total = self.evaluate(counter.result())
     self.assertEqual(total, 7)
Exemplo n.º 5
0
 def test_constructor_no_masked_token(self):
     """Checks a counter built with only a name, before any update.

     Verifies that it is a Keras metric, keeps the given name, and that its
     result evaluates to 0.
     """
     expected_name = 'my_test_metric'
     counter = keras_metrics.NumTokensCounter(name=expected_name)
     self.assertIsInstance(counter, tf.keras.metrics.Metric)
     self.assertEqual(counter.name, expected_name)
     initial_total = self.evaluate(counter.result())
     self.assertEqual(initial_total, 0)