def test_code_shape(self):
    """Check the code batch a default RandomWalker hands to ``decode``.

    Builds a ``vis.RandomWalker()`` with no constructor arguments, runs
    ``vis`` against ``self.state`` on CPU, and asserts that the code
    passed to ``self.mock_model.decode`` has shape ``[32, 2]``.
    """
    rw = vis.RandomWalker()
    rw.data = self.state[tb.X]
    rw.model = self.mock_model
    rw.dev = 'cpu'

    rw.vis(self.state)

    # call_args[0] is the tuple of positional args of the last decode call;
    # [0] is its first positional arg and [0] again that arg's first element.
    # NOTE(review): presumably that element is the code tensor, and 32 is
    # RandomWalker's default num_images — confirm against vis.RandomWalker.
    code = self.mock_model.decode.call_args[0][0][0]
    self.assertEqual(list(code.shape), [32, 2])
def test_uniform(self, mock_rand, mock_randn):
    """Check which random generator RandomWalker uses for each ``uniform`` setting.

    With the default settings, ``vis`` is expected to call ``mock_randn``
    once and ``mock_rand`` not at all. With ``uniform=True``, ``vis`` is
    expected to call ``mock_rand`` once.

    NOTE(review): ``mock_rand`` / ``mock_randn`` are presumably patches of
    ``torch.rand`` / ``torch.randn`` applied by decorators outside this
    view. Confirm.
    """
    mock_rand.return_value = torch.ones(2, 2)
    mock_randn.return_value = torch.zeros(2, 2)

    # Default walker: only the randn mock should be hit.
    rw = vis.RandomWalker(num_images=10)
    rw.data = self.state[tb.X]
    rw.model = self.mock_model
    rw.dev = 'cpu'
    rw.vis(self.state)
    self.assertEqual(mock_rand.call_count, 0)
    self.assertEqual(mock_randn.call_count, 1)

    # Uniform walker. The mocks are not reset between the two walkers, so
    # the counts are cumulative: rand goes 0 -> 1 and randn stays at 1.
    rw = vis.RandomWalker(num_images=10, uniform=True)
    rw.data = self.state[tb.X]
    rw.dev = 'cpu'
    rw.model = self.mock_model
    rw.vis(self.state)
    self.assertEqual(mock_rand.call_count, 1)
    self.assertEqual(mock_randn.call_count, 1)
def test_code_shape_alt(self):
    """Check the code batch passed to ``decode`` when ``num_images=10``.

    Makes ``self.mock_model.decode`` return a ``torch.rand(10, 1, 2, 2)``
    tensor. It then runs ``vis`` on a ``RandomWalker(num_images=10)`` and
    asserts that the code passed to ``decode`` has shape ``[10, 2]``.
    """
    # NOTE(review): the (10, 1, 2, 2) return value presumably stands in for
    # a batch of 10 decoded 1-channel 2x2 images — confirm against vis.
    self.mock_model.decode.return_value = torch.rand(10, 1, 2, 2)

    rw = vis.RandomWalker(num_images=10)
    rw.data = self.state[tb.X]
    rw.model = self.mock_model
    rw.dev = 'cpu'

    rw.vis(self.state)

    # call_args[0] is the tuple of positional args of the last decode call;
    # [0] is its first positional arg and [0] again that arg's first element.
    code = self.mock_model.decode.call_args[0][0][0]
    self.assertEqual(list(code.shape), [10, 2])