Esempio n. 1
0
    def test_single_dim_zero_init(self):
        """Walking dim 1 from zero-initialised codes yields [-1, 0, 1] in that dim.

        With ``zero_init=True`` the expected decoded batch is all zeros except
        column 1 (the walked dimension), which takes the values -1, 0 and 1
        across the ``lin_steps=3`` rows.
        """
        lw = vis.LinSpaceWalker(dims_to_walk=[1], lin_steps=3, zero_init=True)
        lw.data = self.state[tb.X]
        lw.model = self.mock_model
        lw.dev = 'cpu'
        lw.vis(self.state)

        # One zero row per step; only the walked column varies.
        correct_lin_sample = torch.zeros(self.codes[0].shape).repeat(3, 1)
        correct_lin_sample[:, 1] = correct_lin_sample.new_tensor([-1.0, 0.0, 1.0])

        # First element of the first positional argument of the last decode() call.
        decoded = self.mock_model.decode.call_args[0][0][0]
        # Use the absolute difference: a signed ``diff < tol`` check would
        # accept any decoded value that is smaller than expected.
        self.assertTrue(((decoded - correct_lin_sample).abs() < 1e-5).all())
Esempio n. 2
0
    def test_single_dim_alt(self):
        """Walking dim 0 of the first sample's code yields [-1, 0, 1] in that dim.

        The expected decoded batch is the first sample's code repeated for each
        of the ``lin_steps=3`` steps, with column 0 replaced by -1, 0 and 1.
        """
        lw = vis.LinSpaceWalker(dims_to_walk=[0], lin_steps=3)
        lw.data = self.state[tb.X]
        lw.model = self.mock_model
        lw.dev = 'cpu'
        lw.vis(self.state)

        # repeat() returns a copy, so self.codes is not mutated below.
        correct_lin_sample = self.codes[0].repeat(3, 1)
        correct_lin_sample[:, 0] = correct_lin_sample.new_tensor([-1.0, 0.0, 1.0])

        # First element of the first positional argument of the last decode() call.
        decoded = self.mock_model.decode.call_args[0][0][0]
        # Use the absolute difference: a signed ``diff < tol`` check would
        # accept any decoded value that is smaller than expected.
        self.assertTrue(((decoded - correct_lin_sample).abs() < 1e-5).all())
Esempio n. 3
0
    def test_multi_dim(self):
        """Walking dims 0 and 1 uses a different sample's code for each dim.

        Rows 0-2 are the first sample's code with column 0 stepped through
        -1, 0, 1. Rows 3-5 are the second sample's code with column 1 stepped
        through the same values.
        """
        lw = vis.LinSpaceWalker(dims_to_walk=[0, 1], lin_steps=3)
        lw.data = self.state[tb.X]
        lw.model = self.mock_model
        lw.dev = 'cpu'
        lw.vis(self.state)

        correct_lin_sample = torch.cat(
            [self.codes[0].repeat(3, 1), self.codes[1].repeat(3, 1)])
        steps = correct_lin_sample.new_tensor([-1.0, 0.0, 1.0])
        correct_lin_sample[0:3, 0] = steps  # walk of dim 0 on sample 0
        correct_lin_sample[3:6, 1] = steps  # walk of dim 1 on sample 1

        # First element of the first positional argument of the last decode() call.
        decoded = self.mock_model.decode.call_args[0][0][0]
        # Use the absolute difference: a signed ``diff < tol`` check would
        # accept any decoded value that is smaller than expected.
        self.assertTrue(((decoded - correct_lin_sample).abs() < 1e-5).all())
Esempio n. 4
0
    def test_limits(self):
        """A custom ``lin_start`` shifts the walk; ``same_image`` reuses one code.

        With ``lin_start=-2`` and ``lin_steps=3``, the expected values below are
        -2, -0.5 and 1. Because ``same_image=True``, both dimensions are walked
        on the first sample's code: rows 0-2 for dim 0 and rows 3-5 for dim 1.
        """
        lw = vis.LinSpaceWalker(dims_to_walk=[0, 1],
                                lin_steps=3,
                                lin_start=-2,
                                same_image=True)
        lw.data = self.state[tb.X]
        lw.model = self.mock_model
        lw.dev = 'cpu'
        lw.vis(self.state)

        correct_lin_sample = self.codes[0].repeat(6, 1)
        steps = correct_lin_sample.new_tensor([-2.0, -0.5, 1.0])
        correct_lin_sample[0:3, 0] = steps  # walk of dim 0
        correct_lin_sample[3:6, 1] = steps  # walk of dim 1

        # First element of the first positional argument of the last decode() call.
        decoded = self.mock_model.decode.call_args[0][0][0]
        # Use the absolute difference: a signed ``diff < tol`` check would
        # accept any decoded value that is smaller than expected.
        self.assertTrue(((decoded - correct_lin_sample).abs() < 1e-5).all())