示例#1
0
    def test_optimize_acqf_joint(self, mock_gen_candidates,
                                 mock_gen_batch_initial_conditions):
        """Test joint (q-batch) optimization through ``optimize_acqf``.

        ``mock_gen_candidates`` and ``mock_gen_batch_initial_conditions`` are
        mocks injected by patch decorators outside this view (presumably
        patching the helpers ``optimize_acqf`` calls — TODO confirm). Their
        return values are set below so the output of ``optimize_acqf`` can be
        compared against known tensors.

        For both ``torch.float`` and ``torch.double`` this checks that:
          * called without ``return_best_only``, the restart with the highest
            mocked acquisition value (restart 0) is returned, and initial
            conditions are generated exactly once;
          * called with ``return_best_only=False`` and explicit
            ``batch_initial_conditions``, all restarts are returned and no
            additional initial conditions are generated.
        Finally it checks that, for a one-shot acquisition function, the
        returned candidates equal ``extract_candidates`` of the best restart.
        """
        q = 3
        num_restarts = 2
        raw_samples = 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        # Expected running total of gen_batch_initial_conditions calls; it is
        # checked after every optimize_acqf call so each call is pinned.
        cnt = 0
        for dtype in (torch.float, torch.double):
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype)
            base_cand = torch.ones(1, q, 3, device=self.device, dtype=dtype)
            # Restart i is a (q x 3) block filled with the value i.
            mock_candidates = torch.cat(
                [i * base_cand for i in range(num_restarts)], dim=0)
            # Strictly decreasing values, so restart 0 is the best one.
            mock_acq_values = num_restarts - torch.arange(
                num_restarts, device=self.device, dtype=dtype)
            mock_gen_candidates.return_value = (mock_candidates,
                                                mock_acq_values)
            bounds = torch.stack([
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ])
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
            )
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
            # No batch_initial_conditions were supplied: expect exactly one
            # generation call for this optimize_acqf invocation.
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # Supplying batch_initial_conditions: every restart is returned.
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                return_best_only=False,
                batch_initial_conditions=torch.zeros(num_restarts,
                                                     q,
                                                     3,
                                                     device=self.device,
                                                     dtype=dtype),
            )
            self.assertTrue(torch.equal(candidates, mock_candidates))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values))
            # The explicit initial conditions must not trigger generation.
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test OneShotAcquisitionFunction
        # NOTE: bounds, mock_candidates and mock_acq_values are carried over
        # from the last loop iteration (torch.double).
        mock_acq_function = MockOneShotAcquisitionFunction()
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(
            torch.equal(
                candidates,
                mock_acq_function.extract_candidates(mock_candidates[0])))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
示例#2
0
    def test_optimize_acqf_joint(
        self, mock_gen_candidates, mock_gen_batch_initial_conditions
    ):
        """Test joint (q-batch) optimization through ``optimize_acqf``.

        ``mock_gen_candidates`` and ``mock_gen_batch_initial_conditions`` are
        mocks injected by patch decorators outside this view (presumably
        patching the helpers ``optimize_acqf`` calls — TODO confirm). Their
        return values are set below so the output of ``optimize_acqf`` can be
        compared against known tensors.

        For both ``torch.float`` and ``torch.double`` this checks that:
          * the restart with the highest mocked acquisition value is returned
            and initial conditions are generated exactly once;
          * with ``return_best_only=False`` and explicit
            ``batch_initial_conditions``, all restarts are returned and no
            additional initial conditions are generated;
          * ``fixed_features`` is forwarded to the candidate generator as a
            keyword argument;
          * when every feature is fixed, the fixed values are returned for
            each of the ``q`` points without generating initial conditions.
        Finally it checks that, for a one-shot acquisition function, the
        returned candidates equal ``extract_candidates`` of the best restart.
        """
        q = 3
        num_restarts = 2
        raw_samples = 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        # Expected running total of gen_batch_initial_conditions calls.
        cnt = 0
        for dtype in (torch.float, torch.double):
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            )
            # Each q-point is [0, 1, 2]; restart i scales it by i.
            base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
            mock_candidates = torch.cat(
                [i * base_cand for i in range(num_restarts)], dim=0
            )
            # Strictly decreasing values, so restart 0 is the best one.
            mock_acq_values = num_restarts - torch.arange(
                num_restarts, device=self.device, dtype=dtype
            )
            mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
            bounds = torch.stack(
                [
                    torch.zeros(3, device=self.device, dtype=dtype),
                    4 * torch.ones(3, device=self.device, dtype=dtype),
                ]
            )
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
            )
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test generation with provided initial conditions
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                return_best_only=False,
                batch_initial_conditions=torch.zeros(
                    num_restarts, q, 3, device=self.device, dtype=dtype
                ),
            )
            self.assertTrue(torch.equal(candidates, mock_candidates))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values))
            # Explicit initial conditions must not trigger generation.
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test fixed features
            fixed_features = {0: 0.1}
            # Feature index 0 lives in the LAST dim of the
            # (num_restarts, q, d) candidate tensor, so fix it across all
            # restarts and q-points with an ellipsis index.
            mock_candidates[..., 0] = 0.1
            mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                fixed_features=fixed_features,
            )
            self.assertEqual(
                mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
            )
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test trivial case when all features are fixed
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
            )
            # Every one of the q points should equal the fixed feature values.
            self.assertTrue(
                torch.equal(
                    candidates,
                    torch.tensor(
                        [0.1, 0.2, 0.3], device=self.device, dtype=dtype
                    ).expand(q, 3),
                )
            )
            # Nothing left to optimize, so no initial conditions generated.
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test OneShotAcquisitionFunction
        # NOTE: bounds, mock_candidates and mock_acq_values are carried over
        # from the last loop iteration (torch.double); mock_gen_candidates
        # still returns that same (mutated) mock_candidates tensor.
        mock_acq_function = MockOneShotAcquisitionFunction()
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(
            torch.equal(
                candidates, mock_acq_function.extract_candidates(mock_candidates[0])
            )
        )
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))