def test_construct(self):
    """Checks the attributes of a freshly built NumBatchesCounter.

    A default-constructed counter should be named 'num_batches', be stateful,
    use tf.int64, own exactly one variable and start with a total of 0.
    Passing an explicit name should override the default name.
    """
    default_counter = counters.NumBatchesCounter()
    self.assertEqual(default_counter.name, 'num_batches')
    self.assertTrue(default_counter.stateful)
    self.assertEqual(default_counter.dtype, tf.int64)
    self.assertLen(default_counter.variables, 1)
    self.assertEqual(default_counter.total, 0)

    # A caller-supplied name replaces the default one.
    named_counter = counters.NumBatchesCounter('num_batches2')
    self.assertEqual(named_counter.name, 'num_batches2')
def test_update_with_sample_weight(self, batch1, batch2):
    """Checks that each batch increments the counter's total by one.

    The first batch goes through the metric's ``__call__`` and the second
    through ``update_state`` directly. Both paths are expected to return the
    running count (1, then 2), and ``total`` should track it.

    Args:
      batch1: First batch; passed as both positional arguments.
      batch2: Second batch; passed as both positional arguments.
    """
    # NOTE(review): despite the test name, no sample weight is passed here;
    # batch1/batch2 presumably come from a parameterization decorator that
    # is not visible in this chunk - confirm the intent.
    counter = counters.NumBatchesCounter()

    after_first = counter(batch1, batch1)
    self.assertEqual(after_first, 1)
    self.assertEqual(counter.total, 1)

    after_second = counter.update_state(batch2, batch2)
    self.assertEqual(after_second, 2)
    self.assertEqual(counter.total, 2)
def test_reset_to_zero(self):
    """Checks that ``reset_state`` brings a non-zero total back to 0."""
    counter = counters.NumBatchesCounter()

    # Feed one batch of zeros so the total becomes positive.
    first_arg = tf.zeros([10, 1])
    second_arg = tf.zeros([10])
    self.assertGreater(counter(first_arg, second_arg), 0)
    self.assertGreater(counter.total, 0)

    counter.reset_state()
    self.assertEqual(counter.total, 0)
def metrics_fn():
    """Builds a new list of metric objects.

    Fresh instances are constructed on every call, in this order: an
    example counter, a batch counter and a sparse categorical accuracy.

    Returns:
      A list ``[NumExamplesCounter, NumBatchesCounter,
      SparseCategoricalAccuracy]``.
    """
    num_examples = counters.NumExamplesCounter()
    num_batches = counters.NumBatchesCounter()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    return [num_examples, num_batches, accuracy]