예제 #1
0
    def test_write_on_epoch(self, mock_open, mock_write):
        """A single training step followed by an epoch end should write exactly once.

        ``mock_open`` / ``mock_write`` are injected by patch decorators outside
        this view (presumably patching ``open`` and the CSV row writer — confirm).
        """
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log')

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # Check call_count explicitly: Mock.assert_called_once only exists on
        # newer mock releases (Python 3.6+), and this form reports the actual
        # count on failure.
        self.assertEqual(mock_write.call_count, 1)
예제 #2
0
    def test_csv_closed(self, mock_open):
        """After ``on_end`` the file handle returned by the patched ``open`` is closed."""
        metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
        state = {torchbearer.EPOCH: 0, torchbearer.BATCH: 1, torchbearer.METRICS: metrics}

        logger = CSVLogger('test_file.log', write_header=False)
        # Drive the logger through one step, one epoch end, and the run end.
        for hook in (logger.on_step_training, logger.on_end_epoch, logger.on_end):
            hook(state)

        file_handle = mock_open.return_value
        self.assertTrue(file_handle.close.called)
예제 #3
0
    def test_append(self, mock_open):
        """With ``append=True`` the log file should be opened in mode ``'a+'``."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', append=True)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # call_args[0] is the positional-args tuple of the last open() call;
        # index 1 is the mode. assertEqual shows the actual mode on failure.
        self.assertEqual(mock_open.call_args[0][1], 'a+')
예제 #4
0
    def test_append(self, mock_open):
        """With ``append=True`` the log file should be opened in mode ``'a+'``."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', append=True)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # call_args[0] is the positional-args tuple of the last open() call;
        # index 1 is the mode. assertEqual shows the actual mode on failure.
        self.assertEqual(mock_open.call_args[0][1], 'a+')
예제 #5
0
    def test_write_no_header(self, mock_open):
        """With ``write_header=False`` the first write must not contain column names."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', write_header=False)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # mock_calls[1] is the first call after open(); [1][0] is its first
        # positional argument, i.e. the first text written to the file.
        first_write = mock_open.mock_calls[1][1][0]
        # Checking column names (rather than one exact header string) also
        # catches a header emitted with a different line terminator.
        for column in ('epoch', 'test_metric_1', 'test_metric_2'):
            self.assertNotIn(column, first_write)
예제 #6
0
    def test_csv_closed(self, mock_open):
        """Finishing a run must close the file obtained from the patched ``open``."""
        epoch_metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: epoch_metrics,
        }

        csv_logger = CSVLogger('test_file.log', write_header=False)
        csv_logger.on_step_training(state)
        csv_logger.on_end_epoch(state)
        csv_logger.on_end(state)

        was_closed = mock_open.return_value.close.called
        self.assertTrue(was_closed)
예제 #7
0
    def test_write_on_epoch(self, mock_open, mock_write):
        """One training step plus an epoch end should trigger exactly one write."""
        metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
        state = {torchbearer.EPOCH: 0, torchbearer.BATCH: 1, torchbearer.METRICS: metrics}

        csv_logger = CSVLogger('test_file.log')
        csv_logger.on_step_training(state)
        csv_logger.on_end_epoch(state)
        csv_logger.on_end(state)

        expected_writes = 1
        self.assertEqual(mock_write.call_count, expected_writes)
예제 #8
0
    def test_batch_granularity(self, mock_open, mock_write):
        """With ``batch_granularity=True``, two steps plus an epoch end give 3 writes."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', batch_granularity=True)

        logger.on_step_training(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # assertEqual reports the actual call count when it differs.
        self.assertEqual(mock_write.call_count, 3)
예제 #9
0
    def test_batch_granularity(self, mock_open, mock_write):
        """With ``batch_granularity=True``, two steps plus an epoch end give 3 writes."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', batch_granularity=True)

        logger.on_step_training(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # assertEqual reports the actual call count when it differs.
        self.assertEqual(mock_write.call_count, 3)
예제 #10
0
    def test_write_no_header(self, mock_open):
        """With ``write_header=False`` the first write must not contain column names."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', write_header=False)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # First positional argument of the first call after open(): the first
        # text written to the file (which must NOT be a header here).
        first_write = mock_open.mock_calls[1][1][0]
        self.assertNotIn('epoch', first_write)
        self.assertNotIn('test_metric_1', first_write)
        self.assertNotIn('test_metric_2', first_write)
예제 #11
0
    def test_write_header(self, mock_open):
        """By default the first write is the CSV header row terminated by ``\\r\\n``."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log')

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # mock_calls[1] is the first call after open(); [1][0] is the text written.
        # assertEqual shows both strings on failure, unlike assertTrue(a == b).
        self.assertEqual(mock_open.mock_calls[1][1][0],
                         'epoch,test_metric_1,test_metric_2\r\n')
예제 #12
0
    def test_append(self, mock_open):
        """With ``append=True`` the file is opened in append mode.

        The expected mode is ``'ab'`` on Python 2 and ``'a'`` on Python 3.
        """
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', append=True)
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        import sys
        expected_mode = 'ab' if sys.version_info[0] < 3 else 'a'
        # call_args[0][1] is the mode passed positionally to the last open() call.
        self.assertEqual(mock_open.call_args[0][1], expected_mode)