def test_simple_methods(self, mock_to_file, mock_to_pyplot,
                        mock_to_tensorboard, mock_to_visdom):
    """Check the base ``on_batch`` contract and each ``to_*`` helper.

    The ``mock_*`` arguments are presumably injected by ``patch`` decorators
    that are outside this view -- TODO confirm which targets they replace.
    """
    # The un-subclassed callback is expected to leave ``on_batch`` abstract.
    callback = imaging.ImagingCallback()
    with self.assertRaises(NotImplementedError):
        callback.on_batch('test')

    imaging.ImagingCallback().to_file('test')
    self.assertEqual(mock_to_file.call_count, 1)

    imaging.ImagingCallback().to_pyplot()
    self.assertEqual(mock_to_pyplot.call_count, 1)

    # The first registered ``to_state`` handler is expected to store the very
    # object it receives under the requested key. Identity is checked against
    # a shared object rather than a string literal, because ``is`` with a
    # literal depends on interning. Both a string and an integer key are
    # exercised.
    image = 'image'
    for key in ('test', 0):
        callback = imaging.ImagingCallback().to_state(key)
        state = {}
        callback._handlers[0][0](image, 0, state)
        self.assertIn(key, state)
        self.assertIs(state[key], image)

    imaging.ImagingCallback().to_tensorboard()
    self.assertEqual(mock_to_tensorboard.call_count, 1)

    imaging.ImagingCallback().to_visdom()
    self.assertEqual(mock_to_visdom.call_count, 1)
def test_on_train(self):
    """After ``on_train``, a training step should run the old hook and ``process``.

    Both the hook that was installed before ``on_train`` and ``process`` are
    expected to be called exactly once with the state passed to the step.
    """
    previous_hook = MagicMock()
    process_spy = MagicMock()

    imager = imaging.ImagingCallback()
    imager.on_step_training = previous_hook
    imager.process = process_spy

    wrapped = imager.on_train()
    wrapped.on_step_training('state')

    previous_hook.assert_called_once_with('state')
    wrapped.process.assert_called_once_with('state')
def test_dims(self, mock_to_file):
    """A single 3-d image should reach the handler unchanged, at index 0.

    ``mock_to_file`` is presumably injected by a ``patch`` decorator outside
    this view -- TODO confirm.
    """
    callback = imaging.ImagingCallback().to_file('test')
    image = torch.rand(1, 1, 1)
    callback.on_batch = lambda _: image

    callback.process('state')

    # The mock's return value stands in for the handler (presumably what
    # ``to_file`` registers); bind it once instead of re-calling the mock.
    handler = mock_to_file()
    handler.assert_called_once_with(ANY, 0, 'state')

    received = handler.call_args_list[0][0][0]
    self.assertTrue((received == image).all())
    # assertEqual reports the actual rank on failure, unlike assertTrue(a == b).
    self.assertEqual(received.dim(), 3)
def test_process(self):
    """``process`` should transform once and hand the result to the handler.

    NOTE(review): a second ``test_process`` is defined right after this one.
    If both are methods of the same class, the later definition replaces this
    one and this test never runs -- confirm and rename one of them. The
    handler call shape asserted here (image, state) also differs from the
    one asserted there (image, index, state).
    """
    callback = imaging.ImagingCallback()
    callback.on_batch = lambda _: 'test'
    handler = MagicMock()
    callback = callback.with_handler(handler)
    callback.transform = MagicMock(return_value='test')

    callback.process('state')

    handler.assert_called_once_with('test', 'state')
    # assertEqual reports the actual count on failure.
    self.assertEqual(callback.transform.call_count, 1)
def test_process(self):
    """``process`` should transform once, check ``dim`` once, and call the handler.

    NOTE(review): another ``test_process`` is defined immediately before this
    one. If both are methods of the same class, this definition shadows the
    earlier one -- confirm and rename one of them.
    """
    callback = imaging.ImagingCallback()
    callback.on_batch = MagicMock()
    handler = MagicMock()
    callback = callback.with_handler(handler)
    callback.transform = MagicMock()

    callback.process('state')

    # Inspect the mock through ``return_value`` rather than calling it again,
    # so the assertions do not themselves add calls to the mock under test.
    transformed = callback.transform.return_value
    self.assertEqual(callback.transform.call_count, 1)
    handler.assert_called_once_with(transformed[None], 0, 'state')
    self.assertEqual(transformed.dim.call_count, 1)
def test_index(self, mock_to_file):
    """``index`` should select which batch entries are passed to the handler.

    Covers a single integer index and a list of indices. ``mock_to_file`` is
    presumably injected by a ``patch`` decorator outside this view -- TODO
    confirm.
    """
    # --- single integer index -------------------------------------------
    callback = imaging.ImagingCallback().to_file('test', index=10)
    image = torch.zeros(11, 1, 1, 1)
    callback.on_batch = lambda _: image
    callback.process('state')

    # The mock's return value stands in for the registered handler; bind it
    # once instead of re-calling the mock for every assertion.
    handler = mock_to_file()
    handler.assert_called_once_with(ANY, 10, 'state')
    # NOTE(review): the batch is all zeros, so this comparison cannot tell
    # image[10] from any other slice; the index is verified only through the
    # call arguments above.
    self.assertTrue((handler.call_args_list[0][0][0] == image[10]).all())

    handler.reset_mock()

    # --- list of indices ------------------------------------------------
    expected_indices = (2, 5, 10)
    callback = imaging.ImagingCallback().to_file('test', index=[2, 5, 10])
    batch = torch.rand(11, 1, 1, 1)
    callback.on_batch = lambda _: batch
    callback.process('state')

    # Check the count first so the zip below cannot silently truncate.
    self.assertEqual(handler.call_count, 3)
    for call, idx in zip(handler.call_args_list, expected_indices):
        self.assertTrue((call[0][0] == batch[idx]).all())
def test_make_grid(self, mock_grid):
    """``make_grid`` should route the ``on_batch`` output through the grid helper.

    The six positional arguments given to ``make_grid`` are expected to show
    up as the keyword arguments listed below. ``mock_grid`` is presumably
    injected by a ``patch`` decorator outside this view -- TODO confirm.
    """
    expected_kwargs = dict(
        nrow=1,
        padding=2,
        normalize=True,
        range=4,
        scale_each=True,
        pad_value=6,
    )

    def constant_batch(_):
        return 10

    imager = imaging.ImagingCallback()
    imager.on_batch = constant_batch
    imager.make_grid(1, 2, True, 4, True, 6)

    # ``on_batch`` is looked up again after ``make_grid`` has run.
    imager.on_batch({})

    mock_grid.assert_called_once_with(10, **expected_kwargs)
def test_on_test(self):
    """After ``on_test``, ``process`` should run only for test data.

    The previously installed validation-step hook is expected to be called in
    both cases. ``process`` is expected to be skipped when the state is
    tagged with ``VALIDATION_DATA`` and called when tagged with ``TEST_DATA``.
    """
    callback = imaging.ImagingCallback()
    previous_hook = MagicMock()
    callback.on_step_validation = previous_hook
    callback.process = MagicMock()
    callback = callback.on_test()

    # Validation data: hook runs, process does not.
    state = {torchbearer.DATA: torchbearer.VALIDATION_DATA}
    callback.on_step_validation(state)
    previous_hook.assert_called_once_with(state)
    # assertEqual reports the actual count on failure.
    self.assertEqual(callback.process.call_count, 0)

    previous_hook.reset_mock()
    callback.process.reset_mock()

    # Test data: both the hook and process run.
    state = {torchbearer.DATA: torchbearer.TEST_DATA}
    callback.on_step_validation(state)
    previous_hook.assert_called_once_with(state)
    callback.process.assert_called_once_with(state)
def test_transform(self):
    """The default transform should pass through; a custom one should be used.

    Identity is checked with ``assertIs`` against a shared object instead of
    ``is`` with a string literal, which depends on interning and raises a
    SyntaxWarning on Python 3.8+.
    """
    image = 'test'

    # No transform supplied: the same object is expected back.
    callback = imaging.ImagingCallback()
    self.assertIs(callback.transform(image), image)

    # Custom transform supplied: its result is expected regardless of input.
    callback = imaging.ImagingCallback(transform=lambda _: image)
    self.assertIs(callback.transform('something else'), image)
def test_cache(self, mock_cache_images):
    """``cache`` should forward its argument to the patched caching helper.

    ``mock_cache_images`` is presumably injected by a ``patch`` decorator
    outside this view -- TODO confirm.
    """
    requested = 10
    imager = imaging.ImagingCallback()
    imager.cache(requested)
    mock_cache_images.assert_called_once_with(requested)