示例#1
0
    def test_code_shape(self):
        """A default ``RandomWalker`` hands ``decode`` a code of shape [32, 2].

        Builds a walker with default arguments, runs ``vis`` once, then checks
        the shape of the first element of the first positional argument that
        was passed to ``self.mock_model.decode``.

        NOTE(review): 32 is presumably ``RandomWalker``'s default
        ``num_images`` (a walker built with ``num_images=10`` is expected to
        produce [10, 2]) — confirm against the implementation.
        """
        rw = vis.RandomWalker()
        rw.data = self.state[tb.X]
        rw.model = self.mock_model
        rw.dev = 'cpu'
        rw.vis(self.state)

        # If decode was never called, call_args is None and indexing it would
        # raise TypeError. Fail with an explicit assertion message instead.
        self.mock_model.decode.assert_called()
        # assertEqual shows both shapes on failure; assertTrue(a == b) would
        # only report "False is not true".
        self.assertEqual(
            list(self.mock_model.decode.call_args[0][0][0].shape), [32, 2])
示例#2
0
    def test_uniform(self, mock_rand, mock_randn):
        """``uniform=True`` switches code sampling from ``randn`` to ``rand``.

        ``mock_rand`` / ``mock_randn`` are injected by decorators not shown
        here. NOTE(review): they presumably patch ``torch.rand`` /
        ``torch.randn`` as used inside ``vis`` — confirm.

        The mocks' call counts accumulate across both walkers in this test.
        A default walker must call ``randn`` once and ``rand`` not at all.
        A ``uniform=True`` walker must then call ``rand`` once and leave the
        ``randn`` count unchanged.
        """
        mock_rand.return_value = torch.ones(2, 2)
        mock_randn.return_value = torch.zeros(2, 2)

        rw = vis.RandomWalker(num_images=10)
        rw.data = self.state[tb.X]
        rw.model = self.mock_model
        rw.dev = 'cpu'
        rw.vis(self.state)

        self.assertEqual(mock_rand.call_count, 0)
        self.assertEqual(mock_randn.call_count, 1)

        rw = vis.RandomWalker(num_images=10, uniform=True)
        rw.data = self.state[tb.X]
        rw.model = self.mock_model
        rw.dev = 'cpu'
        rw.vis(self.state)

        # Counts are cumulative: randn keeps its count of 1 from the first walker.
        self.assertEqual(mock_rand.call_count, 1)
        self.assertEqual(mock_randn.call_count, 1)
示例#3
0
    def test_code_shape_alt(self):
        """With ``num_images=10``, ``decode`` receives a code of shape [10, 2].

        Stubs ``decode`` to return a tensor of shape (10, 1, 2, 2) before
        running ``vis``. NOTE(review): the stub presumably just keeps whatever
        ``vis`` does with the decoded output working — confirm.

        The test then checks the shape of the first element of the first
        positional argument passed to ``decode``.
        """
        self.mock_model.decode.return_value = torch.rand(10, 1, 2, 2)

        rw = vis.RandomWalker(num_images=10)
        rw.data = self.state[tb.X]
        rw.model = self.mock_model
        rw.dev = 'cpu'
        rw.vis(self.state)

        # If decode was never called, call_args is None and indexing it would
        # raise TypeError. Fail with an explicit assertion message instead.
        self.mock_model.decode.assert_called()
        # assertEqual reports the actual shape on failure, unlike assertTrue.
        self.assertEqual(
            list(self.mock_model.decode.call_args[0][0][0].shape), [10, 2])