def _update_pareto_Y(self) -> bool:
    r"""Recompute the negated Pareto frontier and cache it when it changes.

    The frontier is stored in the ``_neg_pareto_Y`` buffer as the negation of
    what ``_pad_batch_pareto_frontier`` returns for ``self.Y``. When
    ``self.sort`` is set, points are ordered by ascending first column of that
    negated frontier.

    Returns:
        A boolean indicating whether the Pareto frontier has changed, i.e. True
        when the ``_neg_pareto_Y`` buffer did not exist yet or its contents
        differ from the newly computed frontier (the buffer is then
        re-registered), and False otherwise.
    """
    if self._neg_Y.shape[-2] == 0:
        # No points at all: the (empty) negated outcomes serve as the frontier.
        new_front = self._neg_Y
    else:
        # The padding helper assumes maximization, so it receives `self.Y`
        # and its result is negated back (NOTE(review): presumably
        # `self.Y == -self._neg_Y`; confirm against the class).
        new_front = -_pad_batch_pareto_frontier(
            Y=self.Y,
            ref_point=_expand_ref_point(
                ref_point=self.ref_point, batch_shape=self.batch_shape
            ),
        )
        if self.sort:
            # Order points by their first objective (ascending).
            if len(self.batch_shape) == 0:
                new_front = new_front[torch.argsort(new_front[:, 0])]
            else:
                order = torch.argsort(new_front[..., :1], dim=-2)
                new_front = new_front.gather(
                    dim=-2, index=order.expand(new_front.shape)
                )

    # Only re-register the buffer when it is missing or its contents differ.
    unchanged = hasattr(self, "_neg_pareto_Y") and torch.equal(
        new_front, self._neg_pareto_Y
    )
    if unchanged:
        return False
    self.register_buffer("_neg_pareto_Y", new_front)
    return True
def test_pad_batch_pareto_frontier(self):
    r"""Exercise `_pad_batch_pareto_frontier` on a batch of two outcome sets.

    Covers the default path (non-dominated filtering, dropping points not
    better than the reference point, padding with the last kept point), a
    feasibility mask, pre-computed frontiers (`is_pareto=True`), and the
    error raised for more than one batch dimension.
    """
    for dtype in (torch.float, torch.double):
        tkwargs = {"dtype": dtype, "device": self.device}
        # Rows 1-7 are identical in both batch elements.
        shared_rows = [
            [10.0, 3.0],
            [4.0, 5.0],
            [4.0, 5.0],
            [5.0, 5.0],
            [8.5, 3.5],
            [8.5, 3.5],
            [8.5, 3.0],
        ]
        first = torch.tensor(
            [[1.0, 5.0], *shared_rows, [9.0, 1.0], [8.0, 1.0]], **tkwargs
        )
        second = torch.tensor(
            [[1.0, 9.0], *shared_rows, [9.0, 5.0], [9.0, 4.0]], **tkwargs
        )
        Y = torch.stack([first, second], dim=0)
        ref_point = torch.full((2, 2), 2.0, **tkwargs)

        # Default path: the second element keeps two points and is padded
        # with its last frontier point.
        result = _pad_batch_pareto_frontier(
            Y=Y, ref_point=ref_point, is_pareto=False
        )
        expected = torch.tensor(
            [
                [[10.0, 3.0], [5.0, 5.0], [8.5, 3.5]],
                [[10.0, 3.0], [9.0, 5.0], [9.0, 5.0]],
            ],
            **tkwargs,
        )
        self.assertTrue(torch.equal(result, expected))

        # Feasibility mask: only rows with some objective >= 9 are feasible.
        feasible = (Y >= 9.0).any(dim=-1)
        expected = torch.tensor(
            [
                [[10.0, 3.0], [10.0, 3.0]],
                [[10.0, 3.0], [9.0, 5.0]],
            ],
            **tkwargs,
        )
        result = _pad_batch_pareto_frontier(
            Y=Y,
            ref_point=ref_point,
            feasibility_mask=feasible,
            is_pareto=False,
        )
        self.assertTrue(torch.equal(result, expected))

        # is_pareto=True: [1.0, 9.0] in the second element is dropped because
        # it is not better than the reference point.
        first = torch.tensor(
            [[10.0, 3.0], [5.0, 5.0], [8.5, 3.5]], **tkwargs
        )
        second = torch.tensor(
            [[1.0, 9.0], [10.0, 3.0], [9.0, 5.0]], **tkwargs
        )
        Y = torch.stack([first, second], dim=0)
        padded_second = torch.cat([second[1:], second[-1:]], dim=0)
        expected = torch.stack([first, padded_second], dim=0)
        result = _pad_batch_pareto_frontier(
            Y=Y, ref_point=ref_point, is_pareto=True
        )
        self.assertTrue(torch.equal(result, expected))

        # More than one batch dimension is rejected.
        with self.assertRaises(UnsupportedError):
            _pad_batch_pareto_frontier(
                Y=Y.unsqueeze(0), ref_point=ref_point, is_pareto=False
            )
def _set_cell_bounds(self, num_new_points: int) -> None:
    r"""Compute the box decomposition under each posterior sample.

    Draws posterior samples at ``self.X_baseline``, maps them through
    ``self.objective`` and, when ``self.constraints`` is set, marks a sample as
    feasible only if every constraint evaluates to ``<= 0``. The outcome space
    of each sample is then partitioned with ``self.p_class``:

    - with more than two objectives, one partitioning per sample is built on the
      CPU and wrapped in a ``BoxDecompositionList``;
    - otherwise a single batched partitioning is built from the padded Pareto
      frontiers.

    The hypercell bounds are registered as the ``cell_lower_bounds`` and
    ``cell_upper_bounds`` buffers. Their leading dims are
    ``self._batch_sample_shape``, followed by the trailing two dims returned by
    ``get_hypercell_bounds()``.

    Args:
        num_new_points: The number of new points (beyond the points
            in X_baseline) that were used in the previous box decomposition.
            In the first box decomposition, this should be the number of points
            in X_baseline.
    """
    # Per-sample feasibility mask; only ever set when constraints are present.
    feas = None
    if self.X_baseline.shape[0] > 0:
        with torch.no_grad():
            posterior = self.model.posterior(self.X_baseline)
        # Reset sampler, accounting for possible one-to-many transform.
        self.q_in = -1
        # Posterior rows per baseline point (> 1 under a one-to-many transform).
        n_w = posterior.event_shape[-2] // self.X_baseline.shape[-2]
        self._set_sampler(q_in=num_new_points * n_w, posterior=posterior)
        # Set base_sampler to a detached copy of the sampler's base samples
        # (presumably so the baseline draws below reuse them — confirm).
        self.base_sampler.register_buffer(
            "base_samples", self.sampler.base_samples.detach().clone()
        )
        samples = self.base_sampler(posterior)
        # cache posterior
        if self._cache_root:
            self._cache_root_decomposition(posterior=posterior)
        obj = self.objective(samples, X=self.X_baseline)
        if self.constraints is not None:
            # A sample is feasible only if all constraint values are <= 0.
            feas = torch.stack(
                [c(samples) <= 0 for c in self.constraints], dim=0
            ).all(dim=0)
    else:
        # No baseline points: an empty objective tensor with the sampler's
        # sample shape, zero points and one column per objective.
        obj = torch.empty(
            *self.sampler._sample_shape,
            0,
            self.ref_point.shape[-1],
            dtype=self.ref_point.dtype,
            device=self.ref_point.device,
        )
    self._batch_sample_shape = obj.shape[:-2]
    # collapse batch dimensions
    # use numel() rather than view(-1) to handle case of no baseline points
    new_batch_shape = self._batch_sample_shape.numel()
    obj = obj.view(new_batch_shape, *obj.shape[-2:])
    if feas is not None:
        feas = feas.view(new_batch_shape, *feas.shape[-1:])

    if self.partitioning is None and not self.incremental_nehvi:
        self._compute_initial_hvs(obj=obj, feas=feas)
    if self.ref_point.shape[-1] > 2:
        # the partitioning algorithms run faster on the CPU
        # due to advanced indexing
        ref_point_cpu = self.ref_point.cpu()
        obj_cpu = obj.cpu()
        if feas is not None:
            # Keep only feasible points of each sample. Counts can differ
            # between samples, so this becomes a list of per-sample tensors.
            obj_cpu = [
                sample_obj[sample_feas]
                for sample_obj, sample_feas in zip(obj_cpu, feas.cpu())
            ]
        partitionings = [
            self.p_class(ref_point=ref_point_cpu, Y=sample, **self.p_kwargs)
            for sample in obj_cpu
        ]
        self.partitioning = BoxDecompositionList(*partitionings)
    else:
        # use batched partitioning
        obj = _pad_batch_pareto_frontier(
            Y=obj,
            ref_point=self.ref_point.unsqueeze(0).expand(
                obj.shape[0], self.ref_point.shape[-1]
            ),
            feasibility_mask=feas,
        )
        self.partitioning = self.p_class(
            ref_point=self.ref_point, Y=obj, **self.p_kwargs
        )
    cell_bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
    # Restore the sample batch dims; index 0 is lower bounds, 1 is upper bounds.
    cell_bounds = cell_bounds.view(
        2, *self._batch_sample_shape, *cell_bounds.shape[-2:]
    )
    self.register_buffer("cell_lower_bounds", cell_bounds[0])
    self.register_buffer("cell_upper_bounds", cell_bounds[1])
def test_compute_hypercell_bounds_2d(self):
    r"""Check the 2-D (non-)dominated hypercell bound computations.

    For each method, the non-batch bounds of a six-point front are compared
    with hard-coded expectations. Then a batch of two fronts is checked via
    hypervolume: the first front is filtered entirely by the reference point
    (hypervolume 0), and the second, shifted front has hypervolume 49.
    """
    inf = float("inf")
    ref_point_raw = torch.zeros(2, device=self.device)
    first_coord = torch.arange(3, 9, device=self.device)
    pareto_Y_raw = torch.stack([first_coord, 11 - first_coord], dim=-1)
    # Expected [lower, upper] bounds for each method, on the non-batch front.
    expected_by_method = {
        compute_non_dominated_hypercell_bounds_2d: [
            [[0.0, 8.0], [3.0, 7.0], [4.0, 6.0], [5.0, 5.0],
             [6.0, 4.0], [7.0, 3.0], [8.0, 0.0]],
            [[3.0, inf], [4.0, inf], [5.0, inf], [6.0, inf],
             [7.0, inf], [8.0, inf], [inf, inf]],
        ],
        compute_dominated_hypercell_bounds_2d: [
            [[0.0, 0.0], [3.0, 0.0], [4.0, 0.0],
             [5.0, 0.0], [6.0, 0.0], [7.0, 0.0]],
            [[3.0, 8.0], [4.0, 7.0], [5.0, 6.0],
             [6.0, 5.0], [7.0, 4.0], [8.0, 3.0]],
        ],
    }
    for method, expected_list in expected_by_method.items():
        expected_raw = torch.tensor(expected_list, device=self.device)
        computes_non_dominated = (
            method is compute_non_dominated_hypercell_bounds_2d
        )
        for dtype in (torch.float, torch.double):
            pareto_Y = pareto_Y_raw.to(dtype=dtype)
            expected_cell_bounds = expected_raw.to(dtype=dtype)

            # non-batch
            cell_bounds = method(
                pareto_Y_sorted=pareto_Y,
                ref_point=ref_point_raw.to(dtype=dtype),
            )
            self.assertTrue(torch.equal(cell_bounds, expected_cell_bounds))

            # batch: the second front is the first shifted by its column max
            column_max = pareto_Y.max(dim=-2).values
            pareto_Y_batch = torch.stack(
                [pareto_Y, pareto_Y + column_max], dim=0
            )
            # Using the column max as reference point filters out every point
            # of the first front (none is strictly better than it).
            ref_point = column_max
            pareto_Y_batch = _pad_batch_pareto_frontier(
                Y=pareto_Y_batch, ref_point=ref_point, is_pareto=True
            )
            # sort each front by its first objective
            order = torch.argsort(pareto_Y_batch[..., :1], dim=-2)
            pareto_Y_batch = pareto_Y_batch.gather(
                index=order.expand(pareto_Y_batch.shape), dim=-2
            )
            cell_bounds = method(
                ref_point=ref_point, pareto_Y_sorted=pareto_Y_batch
            )

            # check hypervolume; upper corner = column max of the shifted front
            max_vals = (2 * pareto_Y).max(dim=-2).values
            if computes_non_dominated:
                clamped = torch.min(cell_bounds, max_vals)
                total_hv = (max_vals - ref_point).prod()
                nondom_hv = (clamped[1] - clamped[0]).prod(dim=-1).sum(dim=-1)
                hv = total_hv - nondom_hv
            else:
                hv = (cell_bounds[1] - cell_bounds[0]).prod(dim=-1).sum(dim=-1)
            self.assertEqual(hv[0].item(), 0.0)
            self.assertEqual(hv[1].item(), 49.0)