def test_monitor_found(self):
    """No warnings should be emitted when the monitored key exists in METRICS."""
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.OPTIMIZER: 'optimizer',
        torchbearer.METRICS: {'test': 1.},
    }
    scheduler_mock = Mock()
    scheduler_mock.return_value = scheduler_mock

    callback = TorchScheduler(scheduler_mock, monitor='test', step_on_batch=False)
    callback.on_start(state)

    # Every epoch-level hook is expected to run silently while the
    # monitored metric ('test') is present in state.
    for hook in (callback.on_start_training,
                 callback.on_start_validation,
                 callback.on_end_epoch):
        with warnings.catch_warnings(record=True) as caught:
            hook(state)
        self.assertTrue(len(caught) == 0)
def test_monitor_not_found(self):
    """A retrieval-failure warning should be raised at epoch end when the
    monitored key is absent from METRICS."""
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.OPTIMIZER: 'optimizer',
        torchbearer.METRICS: {'not_test': 1.},
        torchbearer.MODEL: Mock(),
    }
    scheduler_mock = Mock()
    scheduler_mock.return_value = scheduler_mock

    callback = TorchScheduler(scheduler_mock, monitor='test', step_on_batch=False)
    callback.on_start(state)

    # Validation start is expected to stay silent even though the
    # monitored key is missing.
    with warnings.catch_warnings(record=True) as caught:
        callback.on_start_validation(state)
    self.assertTrue(len(caught) == 0)

    # Epoch end looks up the monitored metric, so the missing key must
    # produce a warning naming it.
    with warnings.catch_warnings(record=True) as caught:
        callback.on_end_epoch(state)
    self.assertTrue('Failed to retrieve key `test`' in str(caught[0].message))