Beispiel #1
0
    def test_simple_methods(self, mock_to_file, mock_to_pyplot,
                            mock_to_tensorboard, mock_to_visdom):
        """Check the base ``on_batch`` contract and every fluent ``to_*`` builder.

        The bare base class must raise ``NotImplementedError`` from
        ``on_batch``. Each mocked factory (``mock_to_file``, ``mock_to_pyplot``,
        ``mock_to_tensorboard``, ``mock_to_visdom``) must be invoked exactly
        once by its builder. The handler registered by ``to_state`` must store
        the image in ``state`` under the requested key.
        """
        callback = imaging.ImagingCallback()
        self.assertRaises(NotImplementedError,
                          lambda: callback.on_batch('test'))

        imaging.ImagingCallback().to_file('test')
        self.assertEqual(mock_to_file.call_count, 1)

        imaging.ImagingCallback().to_pyplot()
        self.assertEqual(mock_to_pyplot.call_count, 1)

        # Compare identity against one shared object. ``x is 'literal'``
        # depends on string interning and is a SyntaxWarning on 3.8+.
        image = 'image'
        for key in ('test', 0):
            with self.subTest(key=key):
                callback = imaging.ImagingCallback().to_state(key)
                state = {}
                # Handlers are stored as (handler, index) pairs and are called
                # with (image, index, state).
                callback._handlers[0][0](image, 0, state)
                self.assertIn(key, state)
                self.assertIs(state[key], image)

        imaging.ImagingCallback().to_tensorboard()
        self.assertEqual(mock_to_tensorboard.call_count, 1)

        imaging.ImagingCallback().to_visdom()
        self.assertEqual(mock_to_visdom.call_count, 1)
Beispiel #2
0
 def test_on_train(self):
     """``on_train()`` should route ``on_step_training`` through ``process()``.

     The previously installed ``on_step_training`` must still be invoked, and
     ``process()`` must receive the very same state argument.
     """
     original_step = MagicMock()
     cb = imaging.ImagingCallback()
     cb.on_step_training = original_step
     cb.process = MagicMock()

     cb = cb.on_train()
     cb.on_step_training('state')

     original_step.assert_called_once_with('state')
     cb.process.assert_called_once_with('state')
Beispiel #3
0
 def test_dims(self, mock_to_file):
     """A 3-D image produced by ``on_batch`` reaches the file handler intact.

     The handler returned by the mocked ``to_file`` factory must be called
     once with index 0 and the state. The tensor it receives must equal the
     original image and still be 3-D.
     """
     callback = imaging.ImagingCallback().to_file('test')
     image = torch.rand(1, 1, 1)
     callback.on_batch = lambda _: image
     callback.process('state')

     # The handler is whatever the mocked factory returns.
     handler = mock_to_file.return_value
     handler.assert_called_once_with(ANY, 0, 'state')
     received = handler.call_args[0][0]
     self.assertTrue((received == image).all())
     self.assertEqual(received.dim(), 3)
Beispiel #4
0
 def test_process(self):
     """``process()`` feeds the transformed batch output to each handler.

     With ``on_batch`` and ``transform`` both yielding ``'test'``, the
     registered handler must be called once with ``('test', state)``, and
     ``transform`` must run exactly once.
     """
     callback = imaging.ImagingCallback()
     callback.on_batch = lambda _: 'test'
     handler = MagicMock()
     callback = callback.with_handler(handler)
     callback.transform = MagicMock(return_value='test')
     state = 'state'
     callback.process(state)
     handler.assert_called_once_with('test', state)
     self.assertEqual(callback.transform.call_count, 1)
Beispiel #5
0
 def test_process(self):
     """``process()`` transforms the batch output once and dispatches it.

     ``transform`` must run exactly once. The handler must be called once
     with the transform result indexed by ``[None]``, index 0 and the state.
     ``dim()`` must be queried exactly once on the transform result.
     """
     callback = imaging.ImagingCallback()
     callback.on_batch = MagicMock()
     handler = MagicMock()
     callback = callback.with_handler(handler)
     callback.transform = MagicMock()
     state = 'state'
     callback.process(state)

     self.assertEqual(callback.transform.call_count, 1)
     # Read the result via return_value instead of calling transform() again.
     # Each extra call would be recorded on the mock under test.
     transformed = callback.transform.return_value
     handler.assert_called_once_with(transformed[None], 0, state)
     self.assertEqual(transformed.dim.call_count, 1)
Beispiel #6
0
    def test_index(self, mock_to_file):
        """The ``index`` argument of ``to_file`` selects which entries are saved.

        A single int forwards exactly that batch entry, with that index, to
        the handler. A list forwards each listed entry, in the given order.
        """
        # The handler is whatever the mocked factory returns.
        handler = mock_to_file.return_value

        callback = imaging.ImagingCallback().to_file('test', index=10)
        # Random rather than zeros: with an all-zero batch every slice equals
        # image[10], so the comparison below could not catch a wrong slice.
        image = torch.rand(11, 1, 1, 1)
        callback.on_batch = lambda _: image
        callback.process('state')
        handler.assert_called_once_with(ANY, 10, 'state')
        self.assertTrue((handler.call_args[0][0] == image[10]).all())
        handler.reset_mock()

        indices = [2, 5, 10]
        callback = imaging.ImagingCallback().to_file('test', index=indices)
        image = torch.rand(11, 1, 1, 1)
        callback.on_batch = lambda _: image
        callback.process('state')
        self.assertEqual(handler.call_count, len(indices))
        for call, index in zip(handler.call_args_list, indices):
            self.assertTrue((call[0][0] == image[index]).all())
Beispiel #7
0
 def test_make_grid(self, mock_grid):
     """``make_grid`` wraps ``on_batch`` so its output goes through the grid mock.

     The positional arguments given to ``make_grid`` must arrive at
     ``mock_grid`` as the corresponding keyword arguments, alongside the
     original ``on_batch`` result.
     """
     expected_kwargs = {
         'nrow': 1,
         'padding': 2,
         'normalize': True,
         'range': 4,
         'scale_each': True,
         'pad_value': 6,
     }

     callback = imaging.ImagingCallback()
     callback.on_batch = lambda _: 10
     callback.make_grid(1, 2, True, 4, True, 6)
     callback.on_batch({})

     mock_grid.assert_called_once_with(10, **expected_kwargs)
Beispiel #8
0
    def test_on_test(self):
        """``on_test()`` processes only steps whose data key is ``TEST_DATA``.

        The previously installed ``on_step_validation`` is always forwarded
        the state. ``process()`` must be skipped for ``VALIDATION_DATA`` and
        run exactly once for ``TEST_DATA``.
        """
        callback = imaging.ImagingCallback()
        # Avoid naming this ``mock``, which would shadow a module of that name.
        step_validation = MagicMock()
        callback.on_step_validation = step_validation
        callback.process = MagicMock()
        callback = callback.on_test()

        state = {torchbearer.DATA: torchbearer.VALIDATION_DATA}
        callback.on_step_validation(state)
        step_validation.assert_called_once_with(state)
        self.assertEqual(callback.process.call_count, 0)

        step_validation.reset_mock()
        callback.process.reset_mock()

        state = {torchbearer.DATA: torchbearer.TEST_DATA}
        callback.on_step_validation(state)
        step_validation.assert_called_once_with(state)
        callback.process.assert_called_once_with(state)
Beispiel #9
0
    def test_transform(self):
        """``transform`` is the identity by default and can be overridden.

        The default must return the very same object it is given. A transform
        passed to the constructor must be used instead.
        """
        value = 'test'
        callback = imaging.ImagingCallback()
        # Identity is checked against a shared object, not a string literal.
        # ``is 'test'`` relies on interning and warns on Python 3.8+.
        self.assertIs(callback.transform(value), value)

        callback = imaging.ImagingCallback(transform=lambda _: 'test')
        self.assertEqual(callback.transform('something else'), 'test')
Beispiel #10
0
 def test_cache(self, mock_cache_images):
     """``cache(n)`` should delegate once to ``mock_cache_images`` with the same ``n``."""
     size = 10
     imaging.ImagingCallback().cache(size)
     mock_cache_images.assert_called_once_with(size)