def test_basic_logging(self):
    """Metrics are aggregated per name, formatted, and can be cleared."""
    logger = MetricLogger()
    # A freshly constructed logger has nothing to report.
    self.assertEqual(logger.format_logs(), '')

    # Four batches of keyword metrics. Not every batch reports every metric;
    # per the expected string below, each metric is summarised only over the
    # batches that reported it (mean, plus population std when >1 value).
    batches = (
        dict(loss=1.),
        dict(loss=2., valid_loss=3., valid_timer=0.1),
        dict(loss=4., valid_acc=5., train_time=0.2),
        dict(loss=6., valid_acc=7., train_time=0.3),
    )
    for batch in batches:
        logger.add_metrics(**batch)

    expected = ('train time: 0.25 sec (±0.05 sec); '
                'valid timer: 0.1 sec; '
                'loss: 3.25 (±1.92029); '
                'valid loss: 3; '
                'valid acc: 6 (±1)')
    self.assertEqual(logger.format_logs(), expected)

    # clear() returns the logger to its empty state.
    logger.clear()
    self.assertEqual(logger.format_logs(), '')

    # Metrics may also be passed as a single dict through `metrics=`.
    logger.add_metrics(metrics={'loss': 1.})
    self.assertEqual(logger.format_logs(), 'loss: 1')
def test_errors(self):
    """Passing a non-dict `metrics` argument must raise TypeError."""
    logger = MetricLogger()
    # Callable form of assertRaisesRegex: it invokes
    # logger.add_metrics(metrics=[]) and searches the TypeError's message
    # for the pattern (note the trailing '.' is a regex wildcard).
    self.assertRaisesRegex(
        TypeError,
        '`metrics` should be a dict.',
        logger.add_metrics,
        metrics=[],
    )