コード例 #1
0
    def test_summary_writer(self):
        with TemporaryDirectory() as tempdir:
            # generate the metric summary
            with contextlib.closing(tf.summary.FileWriter(tempdir)) as sw:
                logger = MetricLogger(sw,
                                      summary_skip_pattern=r'.*(time|timer)$',
                                      summary_commit_freqs={'every_two': 2})
                step = 0
                for epoch in range(1, 3):
                    for data in range(10):
                        step += 1
                        logger.collect_metrics({'acc': step * 100 + data},
                                               step)
                        logger.collect_metrics({'time': epoch}, step)
                        logger.collect_metrics({'every_two': step * 2}, step)

                    with self.test_session(use_gpu=False):
                        logger.collect_metrics({'valid_loss': -epoch},
                                               tf.constant(step))

            # read the metric summary
            acc_steps = []
            acc_values = []
            valid_loss_steps = []
            valid_loss_values = []
            every_two_steps = []
            every_two_values = []
            tags = set()

            event_file_path = os.path.join(tempdir, os.listdir(tempdir)[0])
            for e in tf.train.summary_iterator(event_file_path):
                for v in e.summary.value:
                    tags.add(v.tag)
                    if v.tag == 'acc':
                        acc_steps.append(e.step)
                        acc_values.append(v.simple_value)
                    elif v.tag == 'valid_loss':
                        valid_loss_steps.append(e.step)
                        valid_loss_values.append(v.simple_value)
                    elif v.tag == 'every_two':
                        every_two_steps.append(e.step)
                        every_two_values.append(v.simple_value)

            self.assertEqual(sorted(tags), ['acc', 'every_two', 'valid_loss'])
            np.testing.assert_equal(acc_steps, np.arange(1, 21))
            np.testing.assert_almost_equal(
                acc_values,
                np.arange(1, 21) * 100 +
                np.concatenate([np.arange(10), np.arange(10)]))
            np.testing.assert_equal(every_two_steps, np.arange(1, 21, 2))
            np.testing.assert_almost_equal(every_two_values,
                                           np.arange(1, 21, 2) * 2)
            np.testing.assert_equal(valid_loss_steps, [10, 20])
            np.testing.assert_almost_equal(valid_loss_values, [-1, -2])
コード例 #2
0
 def test_errors(self):
     logger = MetricLogger()
     with self.assertRaisesRegex(TypeError, '`metrics` should be a dict.'):
         logger.add_metrics(metrics=[])
コード例 #3
0
    def test_basic_logging(self):
        logger = MetricLogger()
        self.assertEqual(logger.format_logs(), '')

        logger.add_metrics(loss=1.)
        logger.add_metrics(loss=2., valid_loss=3., valid_timer=0.1)
        logger.add_metrics(loss=4., valid_acc=5., train_time=0.2)
        logger.add_metrics(loss=6., valid_acc=7., train_time=0.3)
        self.assertEqual(
            logger.format_logs(), 'train time: 0.25 sec (±0.05 sec); '
            'valid timer: 0.1 sec; '
            'loss: 3.25 (±1.92029); '
            'valid loss: 3; '
            'valid acc: 6 (±1)')

        logger.clear()
        self.assertEqual(logger.format_logs(), '')

        logger.add_metrics(metrics={'loss': 1.})
        self.assertEqual(logger.format_logs(), 'loss: 1')
コード例 #4
0
ファイル: test_logs.py プロジェクト: Feng37/tfsnippet
    def test_basic_logging(self):
        logger = MetricLogger()
        self.assertEqual(logger.format_logs(), '')

        logger.collect_metrics(dict(loss=SimpleDynamicValue(1.)))
        logger.collect_metrics(dict(loss=2., valid_loss=3., valid_timer=0.1))
        logger.collect_metrics(dict(loss=4., valid_acc=5., train_time=0.2))
        logger.collect_metrics(dict(loss=6., valid_acc=7., train_time=0.3))
        logger.collect_metrics(dict(other_metric=5.))
        self.assertEqual(
            logger.format_logs(), 'train time: 0.25 sec (±0.05 sec); '
            'valid timer: 0.1 sec; '
            'loss: 3.25 (±1.92029); '
            'valid loss: 3; '
            'valid acc: 6 (±1); '
            'other metric: 5')

        logger.clear()
        self.assertEqual(logger.format_logs(), '')

        logger.collect_metrics({'loss': 1.})
        self.assertEqual(logger.format_logs(), 'loss: 1')
コード例 #5
0
ファイル: test_logs.py プロジェクト: mengyuan404/tfsnippet
    def test_basic_logging(self):
        v = ScheduledVariable('v', 1.)

        logger = MetricLogger()
        self.assertEqual(logger.format_logs(), '')

        with self.test_session() as sess:
            ensure_variables_initialized()
            logger.collect_metrics(dict(loss=v))
        logger.collect_metrics(dict(loss=2., valid_loss=3., valid_timer=0.1))
        logger.collect_metrics(dict(loss=4., valid_acc=5., train_time=0.2))
        logger.collect_metrics(dict(loss=6., valid_acc=7., train_time=0.3))
        logger.collect_metrics(dict(other_metric=5.))
        self.assertEqual(
            logger.format_logs(),
            'train time: 0.25s (±0.05s); '
            'valid timer: 0.1s; '
            'other metric: 5; '
            'loss: 3.25 (±1.92029); '
            'valid loss: 3; '
            'valid acc: 6 (±1)'
        )

        logger.clear()
        self.assertEqual(logger.format_logs(), '')

        logger.collect_metrics({'loss': 1.})
        self.assertEqual(logger.format_logs(), 'loss: 1')