def test_single_channel(self, mock_board, mock_grid):
    """A single-channel batch gets a channel axis inserted before gridding.

    ``state['x']`` is an (18, 10, 10) tensor. The test expects the patched
    grid function to receive an (18, 1, 10, 10) tensor together with every
    grid option unchanged. The callback keyword ``norm_range`` is expected to
    arrive as ``range``. The grid result (10) should be written once via
    ``add_image('test', 10, 1)``, where 1 presumably comes from
    ``state[torchbearer.EPOCH]``.
    """
    mock_board.return_value = Mock()
    mock_board.return_value.add_image = Mock()
    mock_grid.return_value = 10

    state = {
        'x': torch.ones(18, 10, 10),
        torchbearer.EPOCH: 1,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
    }
    # The callback's keyword is ``norm_range`` (as in the sibling tests);
    # ``range`` is only the name the grid function receives.
    tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True,
                               num_images=18, nrow=9, padding=3,
                               normalize=True, norm_range='tmp',
                               scale_each=True, pad_value=1)

    tboard.on_start(state)
    tboard.on_step_validation(state)

    mock_grid.assert_called_once_with(ANY, nrow=9, padding=3, normalize=True,
                                      range='tmp', scale_each=True,
                                      pad_value=1)
    mock_board.return_value.add_image.assert_called_once_with('test', 10, 1)
    # assertEqual shows both sizes on failure, unlike assertTrue(a == b).
    self.assertEqual(mock_grid.call_args[0][0].size(),
                     torch.ones(18, 1, 10, 10).size())
def test_single_channel_visdom(self, mock_visdom, mock_writer, _, mock_grid):
    """Visdom variant: a single-channel batch is gridded and written once.

    With ``visdom=True``, an (18, 10, 10) input is expected to reach the
    patched grid function as (18, 1, 10, 10), with ``norm_range`` forwarded
    as ``range``. The result should be written through the patched writer
    as ``add_image('test1', 10, 1)``.
    """
    mock_writer.return_value = Mock()
    mock_writer.return_value.add_image = Mock()
    mock_grid.return_value = 10

    state = {
        'x': torch.ones(18, 10, 10),
        torchbearer.EPOCH: 1,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
    }
    tboard = TensorBoardImages(visdom=True, name='test', key='x',
                               write_each_epoch=True, num_images=18, nrow=9,
                               padding=3, normalize=True, norm_range='tmp',
                               scale_each=True, pad_value=1)

    tboard.on_start(state)
    try:
        tboard.on_step_validation(state)

        mock_grid.assert_called_once_with(ANY, nrow=9, padding=3,
                                          normalize=True, range='tmp',
                                          scale_each=True, pad_value=1)
        # In visdom mode the expected image name is 'test1', not 'test'.
        mock_writer.return_value.add_image.assert_called_once_with(
            'test1', 10, 1)
        self.assertEqual(mock_grid.call_args[0][0].size(),
                         torch.ones(18, 1, 10, 10).size())
    finally:
        # Always pair on_start with on_end so the callback's teardown runs
        # even when an assertion above fails.
        tboard.on_end({})
def test_multi_batch(self, mock_board, mock_grid):
    """Images accumulate across batches when ``write_each_epoch=False``.

    Two validation steps of 18 images each are expected to fill
    ``num_images=36``. The patched grid function should then be called
    exactly once with a (36, 3, 10, 10) tensor, and the result written once
    via ``add_image('test', 10, 1)``.
    """
    mock_board.return_value = Mock()
    mock_board.return_value.add_image = Mock()
    mock_grid.return_value = 10

    state = {
        'x': torch.ones(18, 3, 10, 10),
        torchbearer.EPOCH: 1,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
    }
    tboard = TensorBoardImages(name='test', key='x', write_each_epoch=False,
                               num_images=36, nrow=9, padding=3,
                               normalize=True, norm_range='tmp',
                               scale_each=True, pad_value=1)

    tboard.on_start(state)
    # Two batches of 18 are needed to reach num_images=36.
    tboard.on_step_validation(state)
    tboard.on_step_validation(state)

    mock_grid.assert_called_once_with(ANY, nrow=9, padding=3, normalize=True,
                                      range='tmp', scale_each=True,
                                      pad_value=1)
    mock_board.return_value.add_image.assert_called_once_with('test', 10, 1)
    # assertEqual shows both sizes on failure, unlike assertTrue(a == b).
    self.assertEqual(mock_grid.call_args[0][0].size(),
                     torch.ones(36, 3, 10, 10).size())