Esempio n. 1
0
    def test_csv_closed(self, mock_open):
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', write_header=False)
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        self.assertTrue(mock_open.return_value.close.called)
Esempio n. 2
0
    def test_write_on_epoch(self, mock_open, mock_write):
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log')
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        self.assertEqual(mock_write.call_count, 1)
Esempio n. 3
0
    def test_batch_granularity(self, mock_open, mock_write):
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', batch_granularity=True)
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        self.assertTrue(mock_write.call_count == 3)
Esempio n. 4
0
    def test_write_header(self, mock_open):
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log')
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        header = mock_open.mock_calls[1][1][0]
        self.assertTrue('epoch' in header)
        self.assertTrue('test_metric_1' in header)
        self.assertTrue('test_metric_2' in header)
Esempio n. 5
0
    def test_append(self, mock_open):
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', append=True)
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        import sys
        if sys.version_info[0] < 3:
            self.assertTrue(mock_open.call_args[0][1] == 'ab')
        else:
            self.assertTrue(mock_open.call_args[0][1] == 'a')