def test_monitor_found(self):
        """When the monitor key is present in METRICS, no hook should warn."""
        scheduler_mock = Mock()
        scheduler_mock.return_value = scheduler_mock

        state = {torchbearer.EPOCH: 1,
                 torchbearer.OPTIMIZER: 'optimizer',
                 torchbearer.METRICS: {'test': 1.}}

        callback = TorchScheduler(scheduler_mock, monitor='test', step_on_batch=False)
        callback.on_start(state)

        # Each epoch-level hook is exercised in turn; the metric lookup
        # succeeds every time, so nothing should be emitted.
        for hook in (callback.on_start_training,
                     callback.on_start_validation,
                     callback.on_end_epoch):
            with warnings.catch_warnings(record=True) as caught:
                hook(state)
                self.assertTrue(len(caught) == 0)
    # (removed web-extraction artifact: "Exemplo n.º 2" page marker)
    def test_monitor_not_found(self):
        """A missing monitor key should warn only when the scheduler actually steps."""
        scheduler_mock = Mock()
        scheduler_mock.return_value = scheduler_mock

        state = {torchbearer.EPOCH: 1,
                 torchbearer.OPTIMIZER: 'optimizer',
                 torchbearer.METRICS: {'not_test': 1.},
                 torchbearer.MODEL: Mock()}

        callback = TorchScheduler(scheduler_mock, monitor='test', step_on_batch=False)
        callback.on_start(state)

        # Validation start performs no metric lookup, so no warning yet.
        with warnings.catch_warnings(record=True) as caught:
            callback.on_start_validation(state)
            self.assertTrue(len(caught) == 0)

        # The epoch-end step fails to find 'test' and should warn about it.
        with warnings.catch_warnings(record=True) as caught:
            callback.on_end_epoch(state)
            self.assertTrue(
                'Failed to retrieve key `test`' in str(caught[0].message))
    # (removed web-extraction artifact: "Exemplo n.º 3" page marker)
    def test_torch_scheduler_on_epoch_no_monitor(self):
        """With no monitor, the scheduler is built on start and never re-called directly."""
        scheduler_mock = Mock()
        scheduler_mock.return_value = scheduler_mock

        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {}}

        callback = TorchScheduler(lambda opt: scheduler_mock(opt), monitor=None, step_on_batch=False)

        # Construction happens once on start, with the optimizer from state.
        callback.on_start(state)
        scheduler_mock.assert_called_once_with('optimizer')
        scheduler_mock.reset_mock()

        # None of the remaining hooks should invoke the factory again.
        for hook in (callback.on_sample, callback.on_step_training, callback.on_end_epoch):
            hook(state)
            scheduler_mock.assert_not_called()
            scheduler_mock.reset_mock()
    def test_batch_monitor_not_found(self):
        """With step_on_batch, a missing monitor key warns on the training step."""
        scheduler_mock = Mock()
        scheduler_mock.return_value = scheduler_mock

        state = {torchbearer.EPOCH: 1,
                 torchbearer.OPTIMIZER: 'optimizer',
                 torchbearer.METRICS: {'not_test': 1.}}

        callback = TorchScheduler(lambda opt: scheduler_mock(opt), monitor='test', step_on_batch=True)
        callback.on_start(state)

        # The per-batch step cannot find 'test' in METRICS and should warn.
        with warnings.catch_warnings(record=True) as caught:
            callback.on_step_training(state)
            self.assertTrue('Failed to retrieve key `test`' in str(caught[0].message))
    def test_torch_scheduler_on_epoch_with_monitor(self):
        """A new-style monitored scheduler steps exactly once, at epoch end, with the metric."""
        scheduler_mock = Mock()
        scheduler_mock.return_value = scheduler_mock

        state = {torchbearer.EPOCH: 1,
                 torchbearer.METRICS: {'test': 101},
                 torchbearer.OPTIMIZER: 'optimizer',
                 torchbearer.DATA: None}

        callback = TorchScheduler(scheduler_mock, monitor='test', step_on_batch=False)
        callback._newstyle = True

        # New-style construction forwards last_epoch alongside the optimizer.
        callback.on_start(state)
        scheduler_mock.assert_called_once_with('optimizer', last_epoch=0)
        scheduler_mock.reset_mock()

        # No batch-level hook should trigger a step.
        for hook in (callback.on_start_training, callback.on_sample, callback.on_step_training):
            hook(state)
            scheduler_mock.step.assert_not_called()
            scheduler_mock.reset_mock()

        # The single step occurs at epoch end, carrying the monitored value.
        callback.on_end_epoch(state)
        scheduler_mock.step.assert_called_once_with(101)
        scheduler_mock.reset_mock()
    def test_torch_scheduler_on_batch_with_monitor_oldstyle(self):
        """Old-style monitored scheduler with step_on_batch steps per training step.

        Constructing an old-style scheduler (``_newstyle = False``) with
        ``step_on_batch=True`` should raise a UserWarning on start, and the
        monitored metric should be forwarded to ``step`` on each training
        step rather than at the end of the epoch.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.METRICS: {
                'test': 101
            },
            torchbearer.OPTIMIZER: 'optimizer'
        }
        mock_scheduler = Mock()
        mock_scheduler.return_value = mock_scheduler

        torch_scheduler = TorchScheduler(mock_scheduler,
                                         monitor='test',
                                         step_on_batch=True)
        torch_scheduler._newstyle = False

        # NOTE: the redundant function-local `import warnings` was removed;
        # `warnings` is in scope at module level (the other tests use it
        # without a local import).
        with warnings.catch_warnings(record=True) as w:
            torch_scheduler.on_start(state)
            self.assertTrue(len(w) == 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
        mock_scheduler.assert_called_once_with('optimizer', last_epoch=0)
        mock_scheduler.reset_mock()

        torch_scheduler.on_start_training(state)
        mock_scheduler.step.assert_not_called()
        mock_scheduler.reset_mock()

        torch_scheduler.on_sample(state)
        mock_scheduler.step.assert_not_called()
        mock_scheduler.reset_mock()

        # The monitored value is forwarded on each training step.
        torch_scheduler.on_step_training(state)
        mock_scheduler.step.assert_called_once_with(101)
        mock_scheduler.reset_mock()

        # Epoch end must not step again when stepping per batch.
        torch_scheduler.on_end_epoch(state)
        mock_scheduler.step.assert_not_called()
        mock_scheduler.reset_mock()
    def test_torch_scheduler_on_batch_no_monitor_oldstyle(self):
        """An old-style unmonitored batch scheduler steps on sample only."""
        scheduler_mock = Mock()
        scheduler_mock.return_value = scheduler_mock

        state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer'}

        callback = TorchScheduler(scheduler_mock, monitor=None, step_on_batch=True)
        callback._newstyle = False

        # Construction forwards the optimizer and last_epoch once on start.
        callback.on_start(state)
        scheduler_mock.assert_called_once_with('optimizer', last_epoch=0)
        scheduler_mock.reset_mock()

        callback.on_start_training(state)
        scheduler_mock.step.assert_not_called()
        scheduler_mock.reset_mock()

        # Old-style batch stepping fires when a sample is drawn.
        callback.on_sample(state)
        scheduler_mock.step.assert_called_once()
        scheduler_mock.reset_mock()

        # Neither the training step nor epoch end should step again.
        for hook in (callback.on_step_training, callback.on_end_epoch):
            hook(state)
            scheduler_mock.step.assert_not_called()
            scheduler_mock.reset_mock()