def test_flip_sub_unique(self):
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        x = torch.tensor([0.69, 0.75, 0.69, 0.21, 0.86, 0.21], **tkwargs)
        y = _flip_sub_unique(x=x, k=1)
        y_exp = torch.tensor([0.21], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=3)
        y_exp = torch.tensor([0.21, 0.86, 0.69], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=10)
        y_exp = torch.tensor([0.21, 0.86, 0.69, 0.75], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        # long dtype
        tkwargs["dtype"] = torch.long
        x = torch.tensor([1, 6, 4, 3, 6, 3], **tkwargs)
        y = _flip_sub_unique(x=x, k=1)
        y_exp = torch.tensor([3], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=3)
        y_exp = torch.tensor([3, 6, 4], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=4)
        y_exp = torch.tensor([3, 6, 4, 1], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        # k exceeds the number of unique values (4), so the result is the
        # same as for k=4
        y = _flip_sub_unique(x=x, k=10)
        self.assertTrue(torch.allclose(y, y_exp))
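# For reference, a minimal sketch of the semantics the tests above exercise
# (an assumption about the intent, not necessarily the library's actual
# implementation): `_flip_sub_unique(x, k)` returns the first `k` unique
# values of the 1-dim tensor `x` when traversing it from the back, in that
# reverse order. If `x` has fewer than `k` unique values, all of them are
# returned. The name `_flip_sub_unique_sketch` is hypothetical.
import torch
from torch import Tensor


def _flip_sub_unique_sketch(x: Tensor, k: int) -> Tensor:
    seen = set()
    out = []
    # walk back-to-front, keeping each value at its first (reversed) occurrence
    for xi in x.flip(0).tolist():
        if xi not in seen:
            seen.add(xi)
            out.append(xi)
            if len(out) == k:
                break
    return torch.tensor(out, dtype=x.dtype, device=x.device)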
def forward(
    self, X: Tensor, num_samples: int = 1, observation_noise: bool = False
) -> Tensor:
    r"""Sample from the model posterior.

    Args:
        X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N`
            dimension) according to the maximum posterior value under the
            objective.
        num_samples: The number of samples to draw.
        observation_noise: If True, sample with observation noise.

    Returns:
        A `batch_shape x num_samples x d`-dim Tensor of rows selected from
        `X`; element `[..., i, :]` of the output is the `i`-th sample.
    """
    posterior = self.model.posterior(X, observation_noise=observation_noise)
    if isinstance(self.objective, ScalarizedObjective):
        posterior = self.objective(posterior)
    # num_samples x batch_shape x N x m
    samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
    if isinstance(self.objective, ScalarizedObjective):
        obj = samples.squeeze(-1)  # num_samples x batch_shape x N
    else:
        obj = self.objective(samples, X=X)  # num_samples x batch_shape x N
    if self.replacement:
        # if we allow replacement then things are simple(r)
        idcs = torch.argmax(obj, dim=-1)
    else:
        # if we need to deduplicate we have to do some tensor acrobatics
        # first we get the indices associated with the num_samples top samples
        _, idcs_full = torch.topk(obj, num_samples, dim=-1)
        # generate some indices to smartly index into the lower triangle of
        # idcs_full (broadcasting across batch dimensions)
        ridx, cindx = torch.tril_indices(num_samples, num_samples)
        # pick the unique indices in order - since we look at the lower triangle
        # of the index matrix and we don't sort, this achieves deduplication
        sub_idcs = idcs_full[ridx, ..., cindx]
        if sub_idcs.ndim == 1:
            idcs = _flip_sub_unique(sub_idcs, num_samples)
        elif sub_idcs.ndim == 2:
            # TODO: Find a better way to do this
            n_b = sub_idcs.size(-1)
            idcs = torch.stack(
                [_flip_sub_unique(sub_idcs[:, i], num_samples) for i in range(n_b)],
                dim=-1,
            )
        else:
            # TODO: Find a general way to do this efficiently.
            raise NotImplementedError(
                "MaxPosteriorSampling without replacement for more than a single "
                "batch dimension is not yet implemented."
            )
    # idcs is num_samples x batch_shape, to index into X we need to permute it
    # to have shape batch_shape x num_samples
    if idcs.ndim > 1:
        idcs = idcs.permute(*range(1, idcs.ndim), 0)
    # in order to use gather, we need to repeat the index tensor d times
    idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1))
    # now if the model is batched, batch_shape will not necessarily be the
    # batch_shape of X, so we expand X to the proper shape
    Xe = X.expand(*obj.shape[1:], X.size(-1))
    # finally we can gather along the N dimension
    return torch.gather(Xe, -2, idcs)
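# A hedged usage sketch for the `forward` method above: it assumes the
# enclosing class is BoTorch's `MaxPosteriorSampling` and that `SingleTaskGP`
# is available under the import paths below (both assumptions may vary across
# library versions). Sampling without replacement exercises the
# topk / tril_indices / _flip_sub_unique deduplication branch.
import torch
from botorch.models import SingleTaskGP
from botorch.generation.sampling import MaxPosteriorSampling

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

X_cand = torch.rand(512, 2, dtype=torch.double)  # N x d candidate set
sampler = MaxPosteriorSampling(model=model, replacement=False)
X_next = sampler(X_cand, num_samples=5)  # batch_shape x 5 x 2 -> here 5 x 2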