Ejemplo n.º 1
0
 def test_batch_decoupled_sampler(self):
     """Check the output shape of ``decoupled_sampler`` on a fantasy model.

     A ``SingleTaskGP`` is fantasized at a batch of points, a decoupled
     sampler is built on the resulting fantasy model, and the samples it
     returns for a set of test points must have shape
     ``sample_shape + [num_fantasies, num_fant_X, num_test, 1]``.
     """
     n_train, d = 6, 2
     n_fantasies, n_fantasy_points = 3, 4
     n_basis = 16
     n_test = 8
     sample_shape = [5]

     train_inputs = torch.rand(n_train, d)
     train_targets = torch.rand(n_train, 1)
     gp = SingleTaskGP(train_inputs, train_targets, outcome_transform=None)

     # Fantasy inputs have shape (n_fantasy_points, 1, d).
     fantasy_points = torch.rand(n_fantasy_points, 1, d)
     fantasized_gp = gp.fantasize(fantasy_points,
                                  IIDNormalSampler(n_fantasies))

     draw = decoupled_sampler(model=fantasized_gp,
                              sample_shape=sample_shape,
                              num_basis=n_basis)
     samples = draw(torch.rand(n_test, d))

     expected_shape = [*sample_shape,
                       n_fantasies, n_fantasy_points, n_test, 1]
     self.assertEqual(list(samples.shape), expected_shape)
Ejemplo n.º 2
0
    def test_gen_value_function_initial_conditions(self):
        """Exercise ``gen_value_function_initial_conditions``.

        The acquisition function is a ``PosteriorMean`` over a fantasized
        ``SingleTaskGP``. The test checks that:

        * passing ``options={"frac_random": 2.0}`` raises ``ValueError``;
        * with default options and ``torch.float`` inputs, the returned
          initial conditions have shape
          ``(num_restarts, num_fantasies, num_solutions, 1, dim)``, lie
          within ``bounds`` and keep the input dtype;
        * a minimal ``torch.double`` case (one restart, one raw sample,
          ``frac_random=0.99``), intended to cover the case where all raw
          samples are random, satisfies the same shape/bounds/dtype checks.
        """
        n_fantasies = 2

        def assert_valid(ics, expected_shape, box, expected_dtype):
            # Shape, containment in [box[0], box[1]], and dtype of the output.
            self.assertEqual(ics.shape, expected_shape)
            self.assertTrue(torch.all(ics >= box[0]))
            self.assertTrue(torch.all(ics <= box[1]))
            self.assertEqual(expected_dtype, ics.dtype)

        # Thorough checks in single precision.
        fp32 = {"device": self.device, "dtype": torch.float}
        n_solutions, n_restarts, n_raw = 3, 4, 5
        n_train, d = 6, 2
        gp = SingleTaskGP(torch.rand(n_train, d, **fp32),
                          torch.rand(n_train, 1, **fp32))
        fantasy_gp = gp.fantasize(torch.rand(n_solutions, 1, d, **fp32),
                                  IIDNormalSampler(n_fantasies))
        box = torch.tensor([[0, 0], [1, 1]], **fp32)
        objective = PosteriorMean(fantasy_gp)

        # An invalid frac_random option must be rejected.
        with self.assertRaises(ValueError):
            gen_value_function_initial_conditions(
                acq_function=objective,
                bounds=box,
                num_restarts=n_restarts,
                raw_samples=n_raw,
                current_model=gp,
                options={"frac_random": 2.0},
            )

        ics = gen_value_function_initial_conditions(
            acq_function=objective,
            bounds=box,
            num_restarts=n_restarts,
            raw_samples=n_raw,
            current_model=gp,
        )
        assert_valid(ics,
                     torch.Size([n_restarts, n_fantasies, n_solutions, 1, d]),
                     box,
                     torch.float)

        # Minimal case in double precision with frac_random close to 1.
        fp64 = {"device": self.device, "dtype": torch.double}
        n_train, d, n_solutions = 2, 1, 1
        gp = SingleTaskGP(torch.rand(n_train, d, **fp64),
                          torch.rand(n_train, 1, **fp64))
        fantasy_gp = gp.fantasize(torch.rand(n_solutions, 1, d, **fp64),
                                  IIDNormalSampler(n_fantasies))
        box = torch.tensor([[0], [1]], **fp64)
        ics = gen_value_function_initial_conditions(
            acq_function=PosteriorMean(fantasy_gp),
            bounds=box,
            num_restarts=1,
            raw_samples=1,
            current_model=gp,
            options={"frac_random": 0.99},
        )
        assert_valid(ics,
                     torch.Size([1, n_fantasies, n_solutions, 1, d]),
                     box,
                     torch.double)