Code example #1
File: input.py  Project: jduerholt/botorch
    def _transform(self, X: Tensor) -> Tensor:
        r"""Warp the inputs through the Kumaraswamy CDF.

        Args:
            X: An `input_batch_shape x (batch_shape) x n x d`-dim tensor of inputs.
                `batch_shape` here can either be `self.batch_shape` or 1's, such
                that it is broadcastable with `self.batch_shape` if
                `self.batch_shape` is set.

        Returns:
            An `input_batch_shape x (batch_shape) x n x d`-dim tensor of
                transformed inputs.
        """
        X_tf = expand_and_copy_tensor(X=X, batch_shape=self.batch_shape)
        k = Kumaraswamy(
            concentration1=self.concentration1, concentration0=self.concentration0
        )
        # map values from [0, 1] into [eps, 1 - eps] and clamp before
        # evaluating the CDF, for numerical stability near the boundaries
        X_tf[..., self.indices] = k.cdf(
            torch.clamp(
                X_tf[..., self.indices] * self._X_range + self._X_min,
                self._X_min,
                1.0 - self._X_min,
            )
        )
        return X_tf
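
Usage note: a minimal sketch of how this forward warp is applied, assuming
botorch's public `Warp` input transform (whose `_transform` method is shown
above); the exact constructor signature may differ across botorch versions.

    import torch
    from botorch.models.transforms.input import Warp

    # warp only columns 0 and 2 of a 3-dimensional input
    warp = Warp(indices=[0, 2])
    X = torch.rand(4, 3)  # inputs assumed already normalized to [0, 1]
    X_warped = warp(X)    # in train mode, this routes through _transform above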
Code example #2
File: input.py  Project: jduerholt/botorch
    def _untransform(self, X: Tensor) -> Tensor:
        r"""Warp the inputs through the Kumaraswamy inverse CDF.

        Args:
            X: An `input_batch_shape x batch_shape x n x d`-dim tensor of inputs.

        Returns:
            An `input_batch_shape x batch_shape x n x d`-dim tensor of
                transformed inputs.
        """
        if len(self.batch_shape) > 0:
            if self.batch_shape != X.shape[-2 - len(self.batch_shape) : -2]:
                raise BotorchTensorDimensionError(
                    "The right most batch dims of X must match self.batch_shape: "
                    f"({self.batch_shape})."
                )
        X_tf = X.clone()
        k = Kumaraswamy(
            concentration1=self.concentration1, concentration0=self.concentration0
        )
        # map back from [eps, 1 - eps] to [0, 1] after applying the inverse CDF
        X_tf[..., self.indices] = (
            (k.icdf(X_tf[..., self.indices]) - self._X_min) / self._X_range
        ).clamp(0.0, 1.0)
        return X_tf
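
Round-trip note: `_untransform` applies the closed-form Kumaraswamy inverse
CDF (the CDF is F(x) = 1 - (1 - x^a)^b, with a = concentration1 and
b = concentration0), so untransforming a transformed tensor should recover
the input up to the eps clamping. A minimal sketch, again assuming the
public `Warp` transform:

    import torch
    from botorch.models.transforms.input import Warp

    warp = Warp(indices=[0, 2])
    # keep values away from 0 and 1 so the eps clamp does not dominate
    X = 0.025 + 0.95 * torch.rand(4, 3)
    X_roundtrip = warp.untransform(warp(X))
    assert torch.allclose(X_roundtrip, X, rtol=1e-3, atol=1e-3)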
Code example #3
    def test_warp_transform(self):
        for dtype, batch_shape, warp_batch_shape in itertools.product(
            (torch.float, torch.double),
            (torch.Size(), torch.Size([3])),
            (torch.Size(), torch.Size([2])),
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            eps = 1e-6 if dtype == torch.double else 1e-5

            # basic init
            indices = [0, 2]
            warp_tf = get_test_warp(
                indices, batch_shape=warp_batch_shape, eps=eps
            ).to(**tkwargs)
            self.assertTrue(warp_tf.training)

            k = Kumaraswamy(warp_tf.concentration1, warp_tf.concentration0)

            self.assertEqual(warp_tf.indices.tolist(), indices)

            # We don't want these data points to end up all the way near zero, since
            # this would cause numerical issues and thus result in a flaky test.
            X = 0.025 + 0.95 * torch.rand(*batch_shape, 4, 3, **tkwargs)
            X = X.unsqueeze(-3) if len(warp_batch_shape) > 0 else X
            with torch.no_grad():
                warp_tf = get_test_warp(
                    indices=indices, batch_shape=warp_batch_shape, eps=eps
                ).to(**tkwargs)
                X_tf = warp_tf(X)
                expected_X_tf = expand_and_copy_tensor(
                    X, batch_shape=warp_tf.batch_shape
                )
                expected_X_tf[..., indices] = k.cdf(
                    expected_X_tf[..., indices] * warp_tf._X_range + warp_tf._X_min
                )

                self.assertTrue(torch.equal(expected_X_tf, X_tf))

                # test untransform
                untransformed_X = warp_tf.untransform(X_tf)
                self.assertTrue(
                    torch.allclose(
                        untransformed_X,
                        expand_and_copy_tensor(X, batch_shape=warp_tf.batch_shape),
                        rtol=1e-3,
                        atol=1e-3 if self.device == torch.device("cpu") else 1e-2,
                    )
                )
                if len(warp_batch_shape) > 0:
                    with self.assertRaises(BotorchTensorDimensionError):
                        warp_tf.untransform(X_tf.unsqueeze(-3))

                # test no transform on eval
                warp_tf = get_test_warp(
                    indices,
                    transform_on_eval=False,
                    batch_shape=warp_batch_shape,
                    eps=eps,
                ).to(**tkwargs)
                X_tf = warp_tf(X)
                self.assertFalse(torch.equal(X, X_tf))
                warp_tf.eval()
                X_tf = warp_tf(X)
                self.assertTrue(torch.equal(X, X_tf))

                # test no transform on train
                warp_tf = get_test_warp(
                    indices=indices,
                    transform_on_train=False,
                    batch_shape=warp_batch_shape,
                    eps=eps,
                ).to(**tkwargs)
                X_tf = warp_tf(X)
                self.assertTrue(torch.equal(X, X_tf))
                warp_tf.eval()
                X_tf = warp_tf(X)
                self.assertFalse(torch.equal(X, X_tf))

                # test equals
                warp_tf2 = get_test_warp(
                    indices=indices,
                    transform_on_train=False,
                    batch_shape=warp_batch_shape,
                    eps=eps,
                ).to(**tkwargs)
                self.assertTrue(warp_tf.equals(warp_tf2))
                # test different transform_on_train
                warp_tf2 = get_test_warp(
                    indices=indices, batch_shape=warp_batch_shape, eps=eps
                )
                self.assertFalse(warp_tf.equals(warp_tf2))
                # test different indices
                warp_tf2 = get_test_warp(
                    indices=[0, 1],
                    transform_on_train=False,
                    batch_shape=warp_batch_shape,
                    eps=eps,
                ).to(**tkwargs)
                self.assertFalse(warp_tf.equals(warp_tf2))

                # test preprocess_transform
                warp_tf.transform_on_train = False
                self.assertTrue(torch.equal(warp_tf.preprocess_transform(X), X))
                warp_tf.transform_on_train = True
                self.assertTrue(torch.equal(warp_tf.preprocess_transform(X), X_tf))

                # test _set_concentration
                warp_tf._set_concentration(0, warp_tf.concentration0)
                warp_tf._set_concentration(1, warp_tf.concentration1)

                # test concentration prior
                prior0 = LogNormalPrior(0.0, 0.75).to(**tkwargs)
                prior1 = LogNormalPrior(0.0, 0.5).to(**tkwargs)
                warp_tf = get_test_warp(
                    indices=[0, 1],
                    concentration0_prior=prior0,
                    concentration1_prior=prior1,
                    batch_shape=warp_batch_shape,
                    eps=eps,
                )
                for i, (name, _, p, _, _) in enumerate(warp_tf.named_priors()):
                    self.assertEqual(name, f"concentration{i}_prior")
                    self.assertIsInstance(p, LogNormalPrior)
                    self.assertEqual(p.base_dist.scale, 0.75 if i == 0 else 0.5)

            # test gradients
            X = 1 + 5 * torch.rand(*batch_shape, 4, 3, **tkwargs)
            X = X.unsqueeze(-3) if len(warp_batch_shape) > 0 else X
            warp_tf = get_test_warp(
                indices=indices, batch_shape=warp_batch_shape, eps=eps
            ).to(**tkwargs)
            X_tf = warp_tf(X)
            X_tf.sum().backward()
            for grad in (warp_tf.concentration0.grad, warp_tf.concentration1.grad):
                self.assertIsNotNone(grad)
                self.assertFalse(torch.isnan(grad).any())
                self.assertFalse(torch.isinf(grad).any())
                self.assertFalse((grad == 0).all())

            # test set with scalar
            warp_tf._set_concentration(i=0, value=2.0)
            self.assertTrue((warp_tf.concentration0 == 2.0).all())
            warp_tf._set_concentration(i=1, value=3.0)
            self.assertTrue((warp_tf.concentration1 == 3.0).all())
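
Note: the `get_test_warp` helper is not shown in this listing. A plausible
reconstruction, based only on how it is called above (this is an assumption,
not the actual botorch test helper):

    from botorch.models.transforms.input import Warp

    def get_test_warp(indices, **kwargs):
        # hypothetical: forward straight to Warp; the real helper may do
        # more, e.g. adjust _X_min/_X_range beyond what eps already sets
        return Warp(indices=indices, **kwargs)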