Example #1
    def test_flip_sub_unique(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            # flipped x is [0.21, 0.86, 0.21, 0.69, 0.75, 0.69], so the distinct
            # values in flipped order are 0.21, 0.86, 0.69, 0.75
            x = torch.tensor([0.69, 0.75, 0.69, 0.21, 0.86, 0.21], **tkwargs)
            y = _flip_sub_unique(x=x, k=1)
            y_exp = torch.tensor([0.21], **tkwargs)
            self.assertTrue(torch.allclose(y, y_exp))
            y = _flip_sub_unique(x=x, k=3)
            y_exp = torch.tensor([0.21, 0.86, 0.69], **tkwargs)
            self.assertTrue(torch.allclose(y, y_exp))
            # k exceeding the number of distinct values returns all of them
            y = _flip_sub_unique(x=x, k=10)
            y_exp = torch.tensor([0.21, 0.86, 0.69, 0.75], **tkwargs)
            self.assertTrue(torch.allclose(y, y_exp))
        # long dtype (tkwargs still holds the device set in the loop above)
        tkwargs["dtype"] = torch.long
        # flipped x is [3, 6, 3, 4, 6, 1], so the distinct values are 3, 6, 4, 1
        x = torch.tensor([1, 6, 4, 3, 6, 3], **tkwargs)
        y = _flip_sub_unique(x=x, k=1)
        y_exp = torch.tensor([3], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=3)
        y_exp = torch.tensor([3, 6, 4], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=4)
        y_exp = torch.tensor([3, 6, 4, 1], **tkwargs)
        self.assertTrue(torch.allclose(y, y_exp))
        y = _flip_sub_unique(x=x, k=10)
        self.assertTrue(torch.allclose(y, y_exp))
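
The helper under test isn't shown in this snippet; for reference, here is a minimal sketch consistent with the assertions above (an illustrative re-implementation, not necessarily the library's actual `_flip_sub_unique`): traverse `x` back-to-front and keep the first `k` distinct values in order of first appearance.

import torch
from torch import Tensor

def _flip_sub_unique(x: Tensor, k: int) -> Tensor:
    # walk the flipped tensor, collecting the first k distinct values
    seen, out = set(), []
    for xi in x.flip(0).tolist():
        if xi not in seen:
            seen.add(xi)
            out.append(xi)
            if len(out) == k:
                break
    return torch.tensor(out, dtype=x.dtype, device=x.device)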
Example #2
    def forward(self,
                X: Tensor,
                num_samples: int = 1,
                observation_noise: bool = False) -> Tensor:
        r"""Sample from the model posterior.

        Args:
            X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N`
                dimension) according to the maximum posterior value under the objective.
            num_samples: The number of samples to draw.
            observation_noise: If True, sample with observation noise.

        Returns:
            A `batch_shape x num_samples x d`-dim Tensor of samples drawn from `X`,
            where the `[..., i, :]` slice of the output is the `i`-th sample.
        """
        posterior = self.model.posterior(X,
                                         observation_noise=observation_noise)
        if isinstance(self.objective, ScalarizedObjective):
            posterior = self.objective(posterior)

        # num_samples x batch_shape x N x m
        samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
        if isinstance(self.objective, ScalarizedObjective):
            obj = samples.squeeze(-1)  # num_samples x batch_shape x N
        else:
            obj = self.objective(samples, X=X)  # num_samples x batch_shape x N
        if self.replacement:
            # if we allow replacement then things are simple(r)
            idcs = torch.argmax(obj, dim=-1)
        else:
            # if we need to deduplicate we have to do some tensor acrobatics
            # first we get the indices of the num_samples best candidates under
            # each of the num_samples posterior draws
            _, idcs_full = torch.topk(obj, num_samples, dim=-1)
            # generate some indices to smartly index into the lower triangle of
            # idcs_full (broadcasting across batch dimensions)
            ridx, cindx = torch.tril_indices(num_samples, num_samples)
            # pick the unique indices in order - since we look at the lower triangle
            # of the index matrix and we don't sort, this achieves deduplication
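            # e.g. for num_samples=3 this gives ridx = [0, 1, 1, 2, 2, 2] and
            # cindx = [0, 0, 1, 0, 1, 2], so sub_idcs collects the top-1 index
            # under draw 0, the top-2 under draw 1, and the top-3 under draw 2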
            sub_idcs = idcs_full[ridx, ..., cindx]
            if sub_idcs.ndim == 1:
                idcs = _flip_sub_unique(sub_idcs, num_samples)
            elif sub_idcs.ndim == 2:
                # TODO: Find a better way to do this
                n_b = sub_idcs.size(-1)
                idcs = torch.stack(
                    [
                        _flip_sub_unique(sub_idcs[:, i], num_samples)
                        for i in range(n_b)
                    ],
                    dim=-1,
                )
            else:
                # TODO: Find a general way to do this efficiently.
                raise NotImplementedError(
                    "MaxPosteriorSampling without replacement for more than a single "
                    "batch dimension is not yet implemented.")
        # idcs is num_samples x batch_shape; to index into X we need to permute
        # it to have shape batch_shape x num_samples
        if idcs.ndim > 1:
            idcs = idcs.permute(*range(1, idcs.ndim), 0)
        # in order to use gather, we need to expand the index tensor along a new
        # trailing dimension of size d
        idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1))
        # now if the model is batched batch_shape will not necessarily be the
        # batch_shape of X, so we expand X to the proper shape
        Xe = X.expand(*obj.shape[1:], X.size(-1))
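        # e.g. with no batch dims: obj is num_samples x N, so Xe is N x d and
        # idcs is num_samples x d after the expand above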
        # finally we can gather along the N dimension
        return torch.gather(Xe, -2, idcs)
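
A minimal usage sketch, assuming this `forward` belongs to botorch's `MaxPosteriorSampling` strategy (the constructor arguments mirror the `self.model` and `self.replacement` attributes used above, and the `SingleTaskGP` surrogate and training data are made up for illustration):

import torch
from botorch.models import SingleTaskGP

train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

# draw 5 distinct points from a candidate set via posterior (Thompson) sampling
sampler = MaxPosteriorSampling(model=model, replacement=False)
candidates = torch.rand(50, 3, dtype=torch.double)  # N x d
X_next = sampler(candidates, num_samples=5)  # 5 x 3, rows drawn from candidates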