def test_single_dim_zero_init(self):
    """Walk dimension 1 with ``zero_init=True`` and check the decoded codes.

    The expected batch is an all-zero tensor (shaped from ``self.codes[0]``
    and repeated three times) whose column 1 takes the values -1, 0 and 1.
    It is compared element-wise, with an absolute tolerance of 1e-5, against
    the first element of the first positional argument of the most recent
    ``mock_model.decode`` call.
    """
    lw = vis.LinSpaceWalker(dims_to_walk=[1], lin_steps=3, zero_init=True)
    lw.data = self.state[tb.X]
    lw.model = self.mock_model
    lw.dev = 'cpu'
    lw.vis(self.state)

    correct_lin_sample = torch.zeros(self.codes[0].shape).repeat(3, 1)
    correct_lin_sample[0, 1] = -1.0
    correct_lin_sample[1, 1] = 0.0
    correct_lin_sample[2, 1] = 1.0

    # call_args[0] holds the positional args of the last decode() call.
    decoded_codes = self.mock_model.decode.call_args[0][0][0]

    # Use the absolute error: a signed difference would accept any value
    # that is arbitrarily far *below* the expectation.
    abs_error = (decoded_codes - correct_lin_sample).abs()
    self.assertTrue((abs_error < 1e-5).all())
def test_single_dim_alt(self):
    """Walk dimension 0 of the first code and check the decoded codes.

    The expected batch is ``self.codes[0]`` repeated three times with
    column 0 overwritten by -1, 0 and 1. Its first three rows are compared
    element-wise, with an absolute tolerance of 1e-5, against the first
    element of the first positional argument of the most recent
    ``mock_model.decode`` call.
    """
    lw = vis.LinSpaceWalker(dims_to_walk=[0], lin_steps=3)
    lw.data = self.state[tb.X]
    lw.model = self.mock_model
    lw.dev = 'cpu'
    lw.vis(self.state)

    correct_lin_sample = self.codes[0].repeat(3, 1)
    correct_lin_sample[0, 0] = -1.0
    correct_lin_sample[1, 0] = 0.0
    correct_lin_sample[2, 0] = 1.0

    # call_args[0] holds the positional args of the last decode() call.
    decoded_codes = self.mock_model.decode.call_args[0][0][0]

    # Use the absolute error: a signed difference would accept any value
    # that is arbitrarily far *below* the expectation.
    abs_error = (decoded_codes - correct_lin_sample[:3]).abs()
    self.assertTrue((abs_error < 1e-5).all())
def test_multi_dim(self):
    """Walk dimensions 0 and 1 and check the decoded codes.

    The expected batch stacks ``self.codes[0]`` repeated three times on top
    of ``self.codes[1]`` repeated three times. Rows 0-2 get -1, 0 and 1 in
    column 0; rows 3-5 get the same values in column 1. It is compared
    element-wise, with an absolute tolerance of 1e-5, against the first
    element of the first positional argument of the most recent
    ``mock_model.decode`` call.
    """
    lw = vis.LinSpaceWalker(dims_to_walk=[0, 1], lin_steps=3)
    lw.data = self.state[tb.X]
    lw.model = self.mock_model
    lw.dev = 'cpu'
    lw.vis(self.state)

    correct_lin_sample = torch.cat(
        [self.codes[0].repeat(3, 1), self.codes[1].repeat(3, 1)])
    # First walked dimension: rows 0-2, column 0.
    correct_lin_sample[0, 0] = -1.0
    correct_lin_sample[1, 0] = 0.0
    correct_lin_sample[2, 0] = 1.0
    # Second walked dimension: rows 3-5, column 1.
    correct_lin_sample[3, 1] = -1.0
    correct_lin_sample[4, 1] = 0.0
    correct_lin_sample[5, 1] = 1.0

    # call_args[0] holds the positional args of the last decode() call.
    decoded_codes = self.mock_model.decode.call_args[0][0][0]

    # Use the absolute error: a signed difference would accept any value
    # that is arbitrarily far *below* the expectation.
    abs_error = (decoded_codes - correct_lin_sample).abs()
    self.assertTrue((abs_error < 1e-5).all())
def test_limits(self):
    """Walk two dimensions with ``lin_start=-2`` and ``same_image=True``.

    The expected batch is ``self.codes[0]`` repeated six times. Rows 0-2
    get -2, -0.5 and 1 in column 0; rows 3-5 get the same values in
    column 1. It is compared element-wise, with an absolute tolerance of
    1e-5, against the first element of the first positional argument of the
    most recent ``mock_model.decode`` call.
    """
    lw = vis.LinSpaceWalker(dims_to_walk=[0, 1], lin_steps=3,
                            lin_start=-2, same_image=True)
    lw.data = self.state[tb.X]
    lw.model = self.mock_model
    lw.dev = 'cpu'
    lw.vis(self.state)

    correct_lin_sample = self.codes[0].repeat(6, 1)
    # First walked dimension: rows 0-2, column 0.
    correct_lin_sample[0, 0] = -2.0
    correct_lin_sample[1, 0] = -0.5
    correct_lin_sample[2, 0] = 1.0
    # Second walked dimension: rows 3-5, column 1.
    correct_lin_sample[3, 1] = -2.0
    correct_lin_sample[4, 1] = -0.5
    correct_lin_sample[5, 1] = 1.0

    # call_args[0] holds the positional args of the last decode() call.
    decoded_codes = self.mock_model.decode.call_args[0][0][0]

    # Use the absolute error: a signed difference would accept any value
    # that is arbitrarily far *below* the expectation.
    abs_error = (decoded_codes - correct_lin_sample).abs()
    self.assertTrue((abs_error < 1e-5).all())