def test_torch_scheduler_on_epoch_no_monitor_oldstyle(self):
        """Oldstyle TorchScheduler with no monitor and per-epoch stepping.

        ``on_start`` must invoke the scheduler builder exactly once with the
        optimizer from the state and ``last_epoch=0`` (the state's epoch is 1).
        Afterwards ``step`` is expected exactly once during ``on_start_training``
        and during ``on_end_epoch``. It is not expected at all during
        ``on_sample`` or ``on_step_training``.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.OPTIMIZER: 'optimizer',
            torchbearer.METRICS: {},
        }

        # One mock plays both the builder and the scheduler it returns, so
        # `step` calls on the built scheduler are observable via `builder.step`.
        builder = Mock()
        builder.return_value = builder

        callback = TorchScheduler(builder, monitor=None, step_on_batch=False)
        callback._newstyle = False  # exercise the oldstyle code path

        callback.on_start(state)
        builder.assert_called_once_with('optimizer', last_epoch=0)
        builder.reset_mock()

        # (hook name, should scheduler.step fire during that hook?), in call order.
        expected_steps = (
            ('on_start_training', True),
            ('on_sample', False),
            ('on_step_training', False),
            ('on_end_epoch', True),
        )
        for hook_name, should_step in expected_steps:
            getattr(callback, hook_name)(state)
            if should_step:
                builder.step.assert_called_once()
            else:
                builder.step.assert_not_called()
            builder.reset_mock()
Example #2
0
    def test_torch_scheduler_on_batch_with_monitor_oldstyle(self):
        """Oldstyle TorchScheduler with a monitor and per-batch stepping.

        With ``monitor='test'`` and ``step_on_batch=True``, ``on_start`` is
        expected to emit exactly one UserWarning. It must also build the
        scheduler with the optimizer and ``last_epoch=0``. After that, only
        ``on_step_training`` should call ``step``, passing the monitored
        metric value (101).
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.METRICS: {
                'test': 101
            },
            torchbearer.OPTIMIZER: 'optimizer',
            torchbearer.MODEL: Mock()
        }
        # One mock plays both the builder and the scheduler it returns, so
        # `step` calls on the built scheduler are observable via `mock_scheduler.step`.
        mock_scheduler = Mock()
        mock_scheduler.return_value = mock_scheduler

        torch_scheduler = TorchScheduler(mock_scheduler,
                                         monitor='test',
                                         step_on_batch=True)
        # Force the oldstyle code path under test.
        torch_scheduler._newstyle = False

        import warnings
        with warnings.catch_warnings(record=True) as w:
            # Make sure the warning is recorded even if the runner or the
            # environment configured warnings to be ignored or raised as errors.
            warnings.simplefilter('always')
            torch_scheduler.on_start(state)
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
        mock_scheduler.assert_called_once_with('optimizer', last_epoch=0)
        mock_scheduler.reset_mock()

        torch_scheduler.on_start_training(state)
        mock_scheduler.step.assert_not_called()
        mock_scheduler.reset_mock()

        torch_scheduler.on_sample(state)
        mock_scheduler.step.assert_not_called()
        mock_scheduler.reset_mock()

        # The monitored metric value is what gets passed to step().
        torch_scheduler.on_step_training(state)
        mock_scheduler.step.assert_called_once_with(101)
        mock_scheduler.reset_mock()

        torch_scheduler.on_end_epoch(state)
        mock_scheduler.step.assert_not_called()
        mock_scheduler.reset_mock()