Ejemplo n.º 1
0
    def _update_pareto_Y(self) -> bool:
        r"""Recompute the cached (negated) Pareto front from the current data.

        The freshly computed front is compared with the ``_neg_pareto_Y``
        buffer; the buffer is only re-registered when it is missing or differs.

        Returns:
            ``True`` if ``_neg_pareto_Y`` was (re-)registered because the front
            changed or did not exist yet, ``False`` otherwise.
        """
        if self._neg_Y.shape[-2] != 0:
            # The padding helper is applied to `self.Y` (the original comment
            # notes it assumes maximization); its output is negated before
            # being stored in the `_neg_*` buffer.
            Y = self.Y
            expanded_ref_point = _expand_ref_point(
                ref_point=self.ref_point, batch_shape=self.batch_shape
            )
            new_front = -_pad_batch_pareto_frontier(
                Y=Y, ref_point=expanded_ref_point
            )
            if self.sort:
                # order the points by their first objective
                if len(self.batch_shape) == 0:
                    order = torch.argsort(new_front[:, 0])
                    new_front = new_front[order]
                else:
                    # argsort on a keep-dim slice so the index can be
                    # expanded to the full shape for `gather`
                    order = torch.argsort(new_front[..., :1], dim=-2)
                    new_front = new_front.gather(
                        dim=-2, index=order.expand(new_front.shape)
                    )
        else:
            # no points at all: the (empty) data tensor is the front
            new_front = self._neg_Y

        unchanged = hasattr(self, "_neg_pareto_Y") and torch.equal(
            new_front, self._neg_pareto_Y
        )
        if unchanged:
            return False
        self.register_buffer("_neg_pareto_Y", new_front)
        return True
Ejemplo n.º 2
0
    def test_pad_batch_pareto_frontier(self):
        """Exercise `_pad_batch_pareto_frontier` on a batch of two 2-objective sets.

        For float and double this checks the default call, a call with a
        feasibility mask, and a call with ``is_pareto=True``. Finally it checks
        that an input with an extra leading batch dimension raises
        ``UnsupportedError``.
        """
        # Raw rows for the two batch elements (dtype/device applied per loop).
        rows_1 = [
            [1.0, 5.0],
            [10.0, 3.0],
            [4.0, 5.0],
            [4.0, 5.0],
            [5.0, 5.0],
            [8.5, 3.5],
            [8.5, 3.5],
            [8.5, 3.0],
            [9.0, 1.0],
            [8.0, 1.0],
        ]
        rows_2 = [
            [1.0, 9.0],
            [10.0, 3.0],
            [4.0, 5.0],
            [4.0, 5.0],
            [5.0, 5.0],
            [8.5, 3.5],
            [8.5, 3.5],
            [8.5, 3.0],
            [9.0, 5.0],
            [9.0, 4.0],
        ]
        for dtype in (torch.float, torch.double):
            tkwargs = {"dtype": dtype, "device": self.device}
            Y = torch.stack(
                [torch.tensor(rows_1, **tkwargs),
                 torch.tensor(rows_2, **tkwargs)],
                dim=0,
            )
            ref_point = torch.full((2, 2), 2.0, **tkwargs)

            # default call: is_pareto=False, no feasibility mask
            result = _pad_batch_pareto_frontier(
                Y=Y, ref_point=ref_point, is_pareto=False
            )
            # second batch element has fewer points, so its last row repeats
            expected = torch.tensor(
                [
                    [[10.0, 3.0], [5.0, 5.0], [8.5, 3.5]],
                    [[10.0, 3.0], [9.0, 5.0], [9.0, 5.0]],
                ],
                **tkwargs,
            )
            self.assertTrue(torch.equal(result, expected))

            # test feasibility mask
            feas = (Y >= 9.0).any(dim=-1)
            expected = torch.tensor(
                [
                    [[10.0, 3.0], [10.0, 3.0]],
                    [[10.0, 3.0], [9.0, 5.0]],
                ],
                **tkwargs,
            )
            result = _pad_batch_pareto_frontier(
                Y=Y,
                ref_point=ref_point,
                feasibility_mask=feas,
                is_pareto=False,
            )
            self.assertTrue(torch.equal(result, expected))

            # test is_pareto=True
            # one row of Y2 should be dropped because it is not better than the
            # reference point
            Y1 = torch.tensor(
                [[10.0, 3.0], [5.0, 5.0], [8.5, 3.5]], **tkwargs
            )
            Y2 = torch.tensor(
                [[1.0, 9.0], [10.0, 3.0], [9.0, 5.0]], **tkwargs
            )
            Y = torch.stack([Y1, Y2], dim=0)
            # drop the first row of Y2 and pad by repeating its last row
            padded_Y2 = torch.cat([Y2[1:], Y2[-1:]], dim=0)
            expected = torch.stack([Y1, padded_Y2], dim=0)
            result = _pad_batch_pareto_frontier(
                Y=Y, ref_point=ref_point, is_pareto=True
            )
            self.assertTrue(torch.equal(result, expected))

        # test multiple batch dims
        # (reuses `Y` and `ref_point` left over from the last loop iteration)
        with self.assertRaises(UnsupportedError):
            _pad_batch_pareto_frontier(
                Y=Y.unsqueeze(0), ref_point=ref_point, is_pareto=False
            )
Ejemplo n.º 3
0
    def _set_cell_bounds(self, num_new_points: int) -> None:
        r"""Build the box decomposition per posterior sample and cache its bounds.

        Objective samples at ``X_baseline`` are drawn (or an empty tensor is
        used when there are no baseline points), the sample/batch dimensions
        are flattened, a partitioning is constructed, and the resulting
        hypercell bounds are registered as the ``cell_lower_bounds`` and
        ``cell_upper_bounds`` buffers.

        Args:
            num_new_points: The number of new points (beyond the points
                in X_baseline) that were used in the previous box decomposition.
                In the first box decomposition, this should be the number of points
                in X_baseline.
        """
        feas = None
        if self.X_baseline.shape[0] == 0:
            # no baseline points: empty objective tensor with zero points
            obj = torch.empty(
                *self.sampler._sample_shape,
                0,
                self.ref_point.shape[-1],
                dtype=self.ref_point.dtype,
                device=self.ref_point.device,
            )
        else:
            with torch.no_grad():
                posterior = self.model.posterior(self.X_baseline)
            # Reset sampler, accounting for possible one-to-many transform.
            self.q_in = -1
            n_w = posterior.event_shape[-2] // self.X_baseline.shape[-2]
            self._set_sampler(q_in=num_new_points * n_w, posterior=posterior)
            # copy the sampler's base samples into base_sampler
            base_samples = self.sampler.base_samples.detach().clone()
            self.base_sampler.register_buffer("base_samples", base_samples)

            samples = self.base_sampler(posterior)
            # cache posterior
            if self._cache_root:
                self._cache_root_decomposition(posterior=posterior)
            obj = self.objective(samples, X=self.X_baseline)
            if self.constraints is not None:
                # a point is feasible only if every constraint value is <= 0
                constraint_ok = [c(samples) <= 0 for c in self.constraints]
                feas = torch.stack(constraint_ok, dim=0).all(dim=0)

        self._batch_sample_shape = obj.shape[:-2]
        # collapse batch dimensions
        # use numel() rather than view(-1) to handle case of no baseline points
        flat_batch = self._batch_sample_shape.numel()
        obj = obj.view(flat_batch, *obj.shape[-2:])
        if self.constraints is not None and feas is not None:
            feas = feas.view(flat_batch, *feas.shape[-1:])

        if self.partitioning is None and not self.incremental_nehvi:
            self._compute_initial_hvs(obj=obj, feas=feas)

        if self.ref_point.shape[-1] <= 2:
            # use batched partitioning
            batch_ref_point = self.ref_point.unsqueeze(0).expand(
                obj.shape[0], self.ref_point.shape[-1]
            )
            obj = _pad_batch_pareto_frontier(
                Y=obj,
                ref_point=batch_ref_point,
                feasibility_mask=feas,
            )
            self.partitioning = self.p_class(
                ref_point=self.ref_point, Y=obj, **self.p_kwargs
            )
        else:
            # the partitioning algorithms run faster on the CPU
            # due to advanced indexing
            ref_point_cpu = self.ref_point.cpu()
            per_sample_Y = obj.cpu()
            if self.constraints is not None and feas is not None:
                # keep only the feasible rows of each flattened sample
                masks_cpu = feas.cpu()
                per_sample_Y = [
                    y[mask] for y, mask in zip(per_sample_Y, masks_cpu)
                ]
            # one partitioning object per flattened sample
            self.partitioning = BoxDecompositionList(
                *[
                    self.p_class(
                        ref_point=ref_point_cpu, Y=y, **self.p_kwargs
                    )
                    for y in per_sample_Y
                ]
            )

        bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
        # restore the original sample/batch dims behind the leading
        # (lower, upper) axis
        bounds = bounds.view(
            2, *self._batch_sample_shape, *bounds.shape[-2:]
        )
        lower, upper = bounds[0], bounds[1]
        self.register_buffer("cell_lower_bounds", lower)
        self.register_buffer("cell_upper_bounds", upper)
Ejemplo n.º 4
0
 def test_compute_hypercell_bounds_2d(self):
     """Check the 2d non-dominated and dominated hypercell-bound routines.

     For each routine and dtype: compare the non-batched cell bounds with
     hard-coded expected values, then run a two-element batch and verify
     the hypervolume implied by the returned cells (0.0 and 49.0).
     """
     inf = float("inf")
     ref_point_raw = torch.zeros(2, device=self.device)
     first_obj = torch.arange(3, 9, device=self.device)
     # points (3, 8), (4, 7), ..., (8, 3): already sorted by first objective
     pareto_Y_raw = torch.stack([first_obj, 11 - first_obj], dim=-1)

     # expected [lower bounds, upper bounds] of the non-dominated cells
     expected_non_dominated = torch.tensor(
         [
             [
                 [0.0, 8.0],
                 [3.0, 7.0],
                 [4.0, 6.0],
                 [5.0, 5.0],
                 [6.0, 4.0],
                 [7.0, 3.0],
                 [8.0, 0.0],
             ],
             [
                 [3.0, inf],
                 [4.0, inf],
                 [5.0, inf],
                 [6.0, inf],
                 [7.0, inf],
                 [8.0, inf],
                 [inf, inf],
             ],
         ],
         device=self.device,
     )
     # expected [lower bounds, upper bounds] of the dominated cells
     expected_dominated = torch.tensor(
         [
             [
                 [0.0, 0.0],
                 [3.0, 0.0],
                 [4.0, 0.0],
                 [5.0, 0.0],
                 [6.0, 0.0],
                 [7.0, 0.0],
             ],
             [
                 [3.0, 8.0],
                 [4.0, 7.0],
                 [5.0, 6.0],
                 [6.0, 5.0],
                 [7.0, 4.0],
                 [8.0, 3.0],
             ],
         ],
         device=self.device,
     )
     cases = (
         (compute_non_dominated_hypercell_bounds_2d,
          expected_non_dominated, True),
         (compute_dominated_hypercell_bounds_2d,
          expected_dominated, False),
     )
     for method, expected_cell_bounds_raw, is_non_dominated in cases:
         for dtype in (torch.float, torch.double):
             pareto_Y = pareto_Y_raw.to(dtype=dtype)
             ref_point = ref_point_raw.to(dtype=dtype)
             expected_cell_bounds = expected_cell_bounds_raw.to(dtype=dtype)

             # test non-batch
             cell_bounds = method(
                 pareto_Y_sorted=pareto_Y,
                 ref_point=ref_point,
             )
             self.assertTrue(torch.equal(cell_bounds, expected_cell_bounds))

             # test batch
             col_max = pareto_Y.max(dim=-2).values
             shifted_Y = pareto_Y + col_max
             pareto_Y_batch = torch.stack([pareto_Y, shifted_Y], dim=0)
             # filter out points that are not better than ref_point
             ref_point = col_max
             pareto_Y_batch = _pad_batch_pareto_frontier(
                 Y=pareto_Y_batch, ref_point=ref_point, is_pareto=True
             )
             # sort pareto_Y_batch by first objective
             order = torch.argsort(pareto_Y_batch[..., :1], dim=-2)
             pareto_Y_batch = pareto_Y_batch.gather(
                 dim=-2, index=order.expand(pareto_Y_batch.shape)
             )
             cell_bounds = method(
                 ref_point=ref_point,
                 pareto_Y_sorted=pareto_Y_batch,
             )

             # check hypervolume
             max_vals = (pareto_Y + pareto_Y).max(dim=-2).values
             if is_non_dominated:
                 # clamp infinite upper bounds to the box
                 # [ref_point, max_vals], then subtract the non-dominated
                 # volume from the volume of that box
                 clamped = torch.min(cell_bounds, max_vals)
                 box_hv = (max_vals - ref_point).prod()
                 side_lengths = clamped[1] - clamped[0]
                 non_dominated_hv = side_lengths.prod(dim=-1).sum(dim=-1)
                 hv = box_hv - non_dominated_hv
             else:
                 side_lengths = cell_bounds[1] - cell_bounds[0]
                 hv = side_lengths.prod(dim=-1).sum(dim=-1)
             self.assertEqual(hv[0].item(), 0.0)
             self.assertEqual(hv[1].item(), 49.0)