Example #1
0
 def test_construct(self):
     """Check the attributes of a freshly built NumBatchesCounter.

     Covers the default-constructed counter (name, statefulness, dtype,
     number of variables, initial total) and then a counter built with an
     explicit name as its first positional argument.
     """
     default_counter = counters.NumBatchesCounter()
     self.assertEqual(default_counter.name, 'num_batches')
     self.assertTrue(default_counter.stateful)
     self.assertEqual(default_counter.dtype, tf.int64)
     self.assertLen(default_counter.variables, 1)
     # Nothing has been counted yet.
     self.assertEqual(default_counter.total, 0)

     # A name passed positionally replaces the default one.
     renamed_counter = counters.NumBatchesCounter('num_batches2')
     self.assertEqual(renamed_counter.name, 'num_batches2')
Example #2
0
 def test_update_with_sample_weight(self, batch1, batch2):
     """Check that each update raises the counter's total by one.

     The counter is updated once by calling it directly and once through
     `update_state`; after each update both the returned value and
     `total` are compared against the expected running count.

     Args:
       batch1: Value passed as both arguments of the first update.
       batch2: Value passed as both arguments of the second update.
     """
     # NOTE(review): the test name mentions a sample weight, but only two
     # positional arguments are passed to each update -- confirm whether a
     # third argument was intended.
     counter = counters.NumBatchesCounter()

     first_result = counter(batch1, batch1)
     self.assertEqual(first_result, 1)
     self.assertEqual(counter.total, 1)

     second_result = counter.update_state(batch2, batch2)
     self.assertEqual(second_result, 2)
     self.assertEqual(counter.total, 2)
Example #3
0
 def test_reset_to_zero(self):
     """Check that `reset_state` brings the counter's total back to zero.

     The counter is first updated once with zero-filled tensors so that
     its total is positive, then reset, and the total is verified to be 0.
     """
     counter = counters.NumBatchesCounter()

     # One update is enough to move the total above zero.
     update_result = counter(tf.zeros([10, 1]), tf.zeros([10]))
     self.assertGreater(update_result, 0)
     self.assertGreater(counter.total, 0)

     counter.reset_state()
     self.assertEqual(counter.total, 0)
 def metrics_fn():
     """Build a fresh list of metric objects.

     Returns:
       A new list holding, in order, a `NumExamplesCounter`, a
       `NumBatchesCounter` and a Keras `SparseCategoricalAccuracy`; new
       instances are created on every call.
     """
     metrics = [
         counters.NumExamplesCounter(),
         counters.NumBatchesCounter(),
         tf.keras.metrics.SparseCategoricalAccuracy(),
     ]
     return metrics