def randomize_objective_weights(
    objective_weights: Tensor, **acquisition_function_kwargs: Any
) -> Tensor:
    """Generate a random weighting based on acquisition function settings.

    Args:
        objective_weights: Base weights to multiply by random values.
        **acquisition_function_kwargs: Kwargs containing weight generation
            algorithm options. Only ``random_scalarization_distribution`` is
            read; it defaults to ``SIMPLEX`` and may also be ``HYPERSPHERE``.

    Returns:
        ``objective_weights`` multiplied elementwise by a single randomly
        sampled weight vector (same dtype/device as ``objective_weights``).

    Raises:
        ValueError: If ``random_scalarization_distribution`` is neither
            ``SIMPLEX`` nor ``HYPERSPHERE``.
    """
    # Set distribution and sample weights.
    distribution = acquisition_function_kwargs.get(
        "random_scalarization_distribution", SIMPLEX
    )
    num_objectives = len(objective_weights)
    dtype = objective_weights.dtype
    device = objective_weights.device
    if distribution == SIMPLEX:
        random_weights = sample_simplex(
            num_objectives, dtype=dtype, device=device
        ).squeeze()
    elif distribution == HYPERSPHERE:
        # abs() folds the sphere sample so every multiplier is non-negative,
        # i.e. the random weighting never flips the sign of a base weight.
        random_weights = torch.abs(
            sample_hypersphere(num_objectives, dtype=dtype, device=device).squeeze()
        )
    else:
        # Previously fell through and crashed with UnboundLocalError on
        # `random_weights`; fail loudly with an actionable message instead.
        raise ValueError(
            f"Unknown random_scalarization_distribution: {distribution!r}. "
            f"Expected {SIMPLEX!r} or {HYPERSPHERE!r}."
        )
    return torch.mul(objective_weights, random_weights)
def test_sample_hypersphere(self):
    """Check shape, unit norm, device and dtype of hypersphere samples.

    Every combination of dimension, sample count, QMC flag, seed and
    floating-point dtype is exercised.
    """
    dims = (1, 2, 3)
    sample_counts = (2, 5)
    qmc_flags = (False, True)
    seeds = (None, 1234)
    dtypes = (torch.float, torch.double)
    grid = itertools.product(dims, sample_counts, qmc_flags, seeds, dtypes)
    for dim, count, use_qmc, rng_seed, want_dtype in grid:
        points = sample_hypersphere(
            d=dim,
            n=count,
            qmc=use_qmc,
            seed=rng_seed,
            device=self.device,
            dtype=want_dtype,
        )
        self.assertEqual(points.shape, torch.Size([count, dim]))
        # Every row must have squared Euclidean norm within 1e-5 of one.
        squared_norms = points.pow(2).sum(dim=-1)
        worst_error = (squared_norms - 1).abs().max()
        self.assertTrue(worst_error < 1e-5)
        self.assertEqual(points.device.type, self.device.type)
        self.assertEqual(points.dtype, want_dtype)
def gen_pareto_front(self, n: int) -> Tensor:
    r"""Generate `n` pareto optimal points.

    Points are drawn with QMC (``qmc=True``) from the hypersphere in
    ``self.num_objectives`` dimensions and folded into its positive section
    via ``abs()``. When ``self.negate`` is set, every coordinate is negated
    so the points lie in the non-positive section instead. The result uses
    the dtype and device of ``self.ref_point``.
    """
    ref = self.ref_point
    front = sample_hypersphere(
        n=n,
        d=self.num_objectives,
        dtype=ref.dtype,
        device=ref.device,
        qmc=True,
    ).abs()
    return -front if self.negate else front
def sample_hypersphere_positive_quadrant(dim: int) -> Tensor:
    """Sample uniformly from the positive orthant of the unit sphere in R^dim.

    A single point is drawn with ``botorch_sampling.sample_hypersphere`` in
    double precision, and its coordinates are made non-negative with ``abs()``.

    Args:
        dim: Number of coordinates of the sampled point.

    Returns:
        A double-precision tensor of shape ``(dim,)`` with non-negative entries.
    """
    # NOTE(review): relies on sample_hypersphere's default n=1, giving a
    # (1, dim) tensor (per botorch's documented `n x d` return) — confirm.
    sample = botorch_sampling.sample_hypersphere(dim, dtype=torch.double)
    # Squeeze only the leading sample axis: a bare squeeze() would also drop
    # the coordinate axis when dim == 1 and return a 0-d tensor.
    return torch.abs(sample.squeeze(0))