def test_monitor_found(self):
        """Check that no warnings are emitted when the monitored metric exists.

        The state's metrics dict contains the key ``'test'``, which is the
        key passed as ``monitor``. After ``on_start``, each of the
        ``on_start_training``, ``on_start_validation`` and ``on_end_epoch``
        hooks is invoked and must not emit any warning.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.OPTIMIZER: 'optimizer',
            torchbearer.METRICS: {
                'test': 1.
            }
        }
        # Calling the mock returns the mock itself, so the same object can
        # stand in for both the scheduler factory and the scheduler it builds
        # (presumably TorchScheduler calls it with the optimizer -- confirm).
        mock_scheduler = Mock()
        mock_scheduler.return_value = mock_scheduler

        torch_scheduler = TorchScheduler(mock_scheduler,
                                         monitor='test',
                                         step_on_batch=False)
        torch_scheduler.on_start(state)

        # Hooks are exercised in the same order as a training epoch would
        # trigger them; each one gets its own fresh warning recorder.
        hook_names = ('on_start_training',
                      'on_start_validation',
                      'on_end_epoch')
        for hook_name in hook_names:
            with warnings.catch_warnings(record=True) as w:
                # Without an explicit "always" filter a warning can be
                # swallowed by the default filters or by the once-per-location
                # registry, which would make the emptiness check vacuous.
                warnings.simplefilter('always')
                getattr(torch_scheduler, hook_name)(state)
            self.assertEqual(
                len(w), 0,
                '{} emitted unexpected warnings: {}'.format(
                    hook_name, [str(caught.message) for caught in w]))
예제 #2
0
    def test_monitor_not_found(self):
        """Check that a missing monitored metric triggers a warning.

        The state's metrics dict only contains ``'not_test'`` while the
        scheduler is asked to monitor ``'test'``. ``on_start_validation`` must
        stay silent, and ``on_end_epoch`` must emit a warning whose message
        contains ``Failed to retrieve key `test```.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.OPTIMIZER: 'optimizer',
            torchbearer.METRICS: {
                'not_test': 1.
            },
            torchbearer.MODEL: Mock()
        }
        # Calling the mock returns the mock itself, so the same object can
        # stand in for both the scheduler factory and the scheduler it builds
        # (presumably TorchScheduler calls it with the optimizer -- confirm).
        mock_scheduler = Mock()
        mock_scheduler.return_value = mock_scheduler

        torch_scheduler = TorchScheduler(mock_scheduler,
                                         monitor='test',
                                         step_on_batch=False)
        torch_scheduler.on_start(state)

        with warnings.catch_warnings(record=True) as w:
            # Force every warning to be recorded; otherwise the default
            # filters / once-per-location registry may hide it and the
            # emptiness check below would prove nothing.
            warnings.simplefilter('always')
            torch_scheduler.on_start_validation(state)
        self.assertEqual(
            len(w), 0,
            'on_start_validation emitted unexpected warnings: {}'.format(
                [str(caught.message) for caught in w]))

        with warnings.catch_warnings(record=True) as w:
            # Same reasoning: a suppressed warning would leave ``w`` empty and
            # turn the expected failure mode into an opaque IndexError.
            warnings.simplefilter('always')
            torch_scheduler.on_end_epoch(state)
        messages = [str(caught.message) for caught in w]
        self.assertTrue(
            messages,
            'on_end_epoch emitted no warning for the missing monitor key')
        # Match against every recorded warning so that an unrelated warning
        # emitted first does not mask the one under test.
        self.assertTrue(
            any('Failed to retrieve key `test`' in text for text in messages),
            'expected missing-key warning not found in: {}'.format(messages))