def test_torch_scheduler_on_epoch_no_monitor(self):
    """Check per-epoch stepping without a monitor.

    The scheduler builder is a one-argument lambda, so the test expects the
    builder mock to have been called with the optimizer only. The single
    expected ``step`` call is ``step(epoch=1)`` from ``on_start_training``;
    ``on_sample``, ``on_step_training`` and ``on_end_epoch`` must not step.
    """
    state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {}}

    # Calling the mock returns the mock itself, so the same object serves as
    # both the builder and the scheduler instance it produces.
    mock_scheduler = Mock()
    mock_scheduler.return_value = mock_scheduler

    torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor=None, step_on_batch=False)

    torch_scheduler.on_start(state)
    mock_scheduler.assert_called_once_with('optimizer')
    mock_scheduler.reset_mock()

    torch_scheduler.on_start_training(state)
    mock_scheduler.step.assert_called_once_with(epoch=1)
    mock_scheduler.reset_mock()

    # ``mock_scheduler.assert_not_called()`` only looks at calls to the
    # builder mock itself; calls to the child ``step`` mock are invisible to
    # it. Assert on ``.step`` as well so an unexpected step is detected.
    torch_scheduler.on_sample(state)
    mock_scheduler.assert_not_called()
    mock_scheduler.step.assert_not_called()
    mock_scheduler.reset_mock()

    torch_scheduler.on_step_training(state)
    mock_scheduler.assert_not_called()
    mock_scheduler.step.assert_not_called()
    mock_scheduler.reset_mock()

    torch_scheduler.on_end_epoch(state)
    mock_scheduler.assert_not_called()
    mock_scheduler.step.assert_not_called()
    mock_scheduler.reset_mock()
def test_torch_scheduler_on_batch_no_monitor_oldstyle(self):
    """Check per-batch stepping without a monitor, with ``_newstyle`` forced off.

    The builder is expected to be called as ``builder('optimizer', last_epoch=0)``
    and, of the four hooks exercised afterwards, only ``on_sample`` is
    expected to call ``step``.
    """
    # The mock returns itself when called, doubling as builder and scheduler.
    scheduler = Mock()
    scheduler.return_value = scheduler
    state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer'}

    callback = TorchScheduler(scheduler, monitor=None, step_on_batch=True)
    callback._newstyle = False

    callback.on_start(state)
    scheduler.assert_called_once_with('optimizer', last_epoch=0)
    scheduler.reset_mock()

    # (hook, whether that hook is expected to step the scheduler)
    expectations = (
        (callback.on_start_training, False),
        (callback.on_sample, True),
        (callback.on_step_training, False),
        (callback.on_end_epoch, False),
    )
    for hook, should_step in expectations:
        hook(state)
        if should_step:
            scheduler.step.assert_called_once()
        else:
            scheduler.step.assert_not_called()
        scheduler.reset_mock()
def test_batch_monitor_found(self):
    """Check that no warning is raised when the monitored key is in the metrics.

    The state's metrics contain ``'test'``, which is also the monitor name,
    so ``on_step_training`` is expected to complete without emitting any
    warning.
    """
    state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'test': 1.}}

    # Calling the mock returns the mock itself, so it is both builder and scheduler.
    mock_scheduler = Mock()
    mock_scheduler.return_value = mock_scheduler

    torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True)
    torch_scheduler.on_start(state)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the ambient filters; otherwise an
        # "ignore" filter would make the emptiness check below vacuous.
        warnings.simplefilter('always')
        torch_scheduler.on_step_training(state)

    # Compare the messages rather than a bare length so a failure shows
    # which warnings were unexpectedly raised.
    self.assertEqual([str(item.message) for item in caught], [])
def test_batch_monitor_not_found(self):
    """Check that a warning is raised when the monitored key is missing.

    The state's metrics only contain ``'not_test'`` while the monitor is
    ``'test'``, so ``on_step_training`` is expected to warn with a message
    containing ``Failed to retrieve key `test```.
    """
    state = {torchbearer.EPOCH: 1, torchbearer.OPTIMIZER: 'optimizer', torchbearer.METRICS: {'not_test': 1.}}

    # Calling the mock returns the mock itself, so it is both builder and scheduler.
    mock_scheduler = Mock()
    mock_scheduler.return_value = mock_scheduler

    torch_scheduler = TorchScheduler(lambda opt: mock_scheduler(opt), monitor='test', step_on_batch=True)
    torch_scheduler.on_start(state)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the ambient filters so the
        # expected warning cannot be silently suppressed.
        warnings.simplefilter('always')
        torch_scheduler.on_step_training(state)

    # Fail with a readable message instead of an IndexError when nothing was recorded.
    self.assertTrue(caught, 'expected a warning about the missing `test` metric, but none was recorded')
    self.assertIn('Failed to retrieve key `test`', str(caught[0].message))
def test_torch_scheduler_on_batch_with_monitor_oldstyle(self):
    """Check per-batch stepping with a monitor, with ``_newstyle`` forced off.

    ``on_start`` is expected to emit exactly one ``UserWarning`` and to call
    the builder as ``builder('optimizer', last_epoch=0)``. Of the remaining
    hooks, only ``on_step_training`` is expected to step, passing the
    monitored metric value ``101``.
    """
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.METRICS: {'test': 101},
        torchbearer.OPTIMIZER: 'optimizer',
        torchbearer.MODEL: Mock(),
    }

    # Calling the mock returns the mock itself, so it is both builder and scheduler.
    mock_scheduler = Mock()
    mock_scheduler.return_value = mock_scheduler

    torch_scheduler = TorchScheduler(mock_scheduler, monitor='test', step_on_batch=True)
    torch_scheduler._newstyle = False

    import warnings
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the ambient filters so the
        # count below does not depend on how the test run was configured.
        warnings.simplefilter('always')
        torch_scheduler.on_start(state)

    # Include the captured messages so a wrong count is easy to diagnose.
    self.assertEqual(len(caught), 1, [str(item.message) for item in caught])
    self.assertTrue(issubclass(caught[-1].category, UserWarning))
    mock_scheduler.assert_called_once_with('optimizer', last_epoch=0)
    mock_scheduler.reset_mock()

    torch_scheduler.on_start_training(state)
    mock_scheduler.step.assert_not_called()
    mock_scheduler.reset_mock()

    torch_scheduler.on_sample(state)
    mock_scheduler.step.assert_not_called()
    mock_scheduler.reset_mock()

    torch_scheduler.on_step_training(state)
    mock_scheduler.step.assert_called_once_with(101)
    mock_scheduler.reset_mock()

    torch_scheduler.on_end_epoch(state)
    mock_scheduler.step.assert_not_called()
    mock_scheduler.reset_mock()
def test_torch_scheduler_on_epoch_with_monitor_oldstyle(self):
    """Check per-epoch stepping with a monitor, with ``_newstyle`` forced off.

    The builder is expected to be called as ``builder('optimizer', last_epoch=0)``.
    ``on_start_training``, ``on_sample`` and ``on_step_training`` must not
    step; ``on_end_epoch`` is expected to call ``step(101, epoch=1)`` with
    the monitored metric value and the current epoch.
    """
    # The mock returns itself when called, doubling as builder and scheduler.
    scheduler = Mock()
    scheduler.return_value = scheduler
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.METRICS: {'test': 101},
        torchbearer.OPTIMIZER: 'optimizer',
        torchbearer.DATA: None,
        torchbearer.MODEL: Mock(),
    }

    callback = TorchScheduler(scheduler, monitor='test', step_on_batch=False)
    callback._newstyle = False

    callback.on_start(state)
    scheduler.assert_called_once_with('optimizer', last_epoch=0)
    scheduler.reset_mock()

    # None of these hooks may step the scheduler in this configuration.
    for quiet_hook in (callback.on_start_training, callback.on_sample, callback.on_step_training):
        quiet_hook(state)
        scheduler.step.assert_not_called()
        scheduler.reset_mock()

    callback.on_end_epoch(state)
    scheduler.step.assert_called_once_with(101, epoch=1)
    scheduler.reset_mock()