def _transform(self, X: Tensor) -> Tensor:
    r"""Warp the inputs through the Kumaraswamy CDF.

    Args:
        X: A `input_batch_shape x (batch_shape) x n x d`-dim tensor of inputs.
            batch_shape here can either be self.batch_shape or 1's such that
            it is broadcastable with self.batch_shape if self.batch_shape is set.

    Returns:
        A `input_batch_shape x (batch_shape) x n x d`-dim tensor of transformed
            inputs.
    """
    X_tf = expand_and_copy_tensor(X=X, batch_shape=self.batch_shape)
    k = Kumaraswamy(
        concentration1=self.concentration1, concentration0=self.concentration0
    )
    # normalize to [eps, 1-eps]
    X_tf[..., self.indices] = k.cdf(
        torch.clamp(
            X_tf[..., self.indices] * self._X_range + self._X_min,
            self._X_min,
            1.0 - self._X_min,
        )
    )
    return X_tf
def _untransform(self, X: Tensor) -> Tensor:
    r"""Warp the inputs through the Kumaraswamy inverse CDF.

    Args:
        X: A `input_batch_shape x batch_shape x n x d`-dim tensor of inputs.

    Returns:
        A `input_batch_shape x batch_shape x n x d`-dim tensor of transformed
            inputs.
    """
    if len(self.batch_shape) > 0:
        if self.batch_shape != X.shape[-2 - len(self.batch_shape) : -2]:
            raise BotorchTensorDimensionError(
                "The right most batch dims of X must match self.batch_shape: "
                f"({self.batch_shape})."
            )
    X_tf = X.clone()
    k = Kumaraswamy(
        concentration1=self.concentration1, concentration0=self.concentration0
    )
    # unnormalize from [eps, 1-eps] to [0,1]
    X_tf[..., self.indices] = (
        (k.icdf(X_tf[..., self.indices]) - self._X_min) / self._X_range
    ).clamp(0.0, 1.0)
    return X_tf
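# ---------------------------------------------------------------------------
# A minimal, standalone sketch (not part of the BoTorch source) of the round
# trip that `_transform` / `_untransform` rely on: map normalized inputs
# through the Kumaraswamy CDF and recover them with the inverse CDF. The
# concentration values, `eps`, and tensor shapes below are illustrative
# assumptions only.
# ---------------------------------------------------------------------------
import torch
from torch.distributions import Kumaraswamy

eps = 1e-6
k = Kumaraswamy(
    concentration1=torch.tensor(2.0, dtype=torch.double),
    concentration0=torch.tensor(0.5, dtype=torch.double),
)

X = torch.rand(5, 3, dtype=torch.double)  # inputs assumed to lie in the unit cube
# clamp to [eps, 1 - eps] before warping, mirroring `_transform`
X_warped = k.cdf(X.clamp(eps, 1.0 - eps))
# invert the warp, mirroring `_untransform` (up to the eps clamping)
X_recovered = k.icdf(X_warped)
assert torch.allclose(X, X_recovered, atol=1e-4)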
def test_warp_transform(self):
    for dtype, batch_shape, warp_batch_shape in itertools.product(
        (torch.float, torch.double),
        (torch.Size(), torch.Size([3])),
        (torch.Size(), torch.Size([2])),
    ):
        tkwargs = {"device": self.device, "dtype": dtype}
        eps = 1e-6 if dtype == torch.double else 1e-5

        # basic init
        indices = [0, 2]
        warp_tf = get_test_warp(indices, batch_shape=warp_batch_shape, eps=eps).to(
            **tkwargs
        )
        self.assertTrue(warp_tf.training)

        k = Kumaraswamy(warp_tf.concentration1, warp_tf.concentration0)

        self.assertEqual(warp_tf.indices.tolist(), indices)

        # We don't want these data points to end up all the way near zero, since
        # this would cause numerical issues and thus result in a flaky test.
        X = 0.025 + 0.95 * torch.rand(*batch_shape, 4, 3, **tkwargs)
        X = X.unsqueeze(-3) if len(warp_batch_shape) > 0 else X
        with torch.no_grad():
            warp_tf = get_test_warp(
                indices=indices, batch_shape=warp_batch_shape, eps=eps
            ).to(**tkwargs)
            X_tf = warp_tf(X)
            expected_X_tf = expand_and_copy_tensor(
                X, batch_shape=warp_tf.batch_shape
            )
            expected_X_tf[..., indices] = k.cdf(
                expected_X_tf[..., indices] * warp_tf._X_range + warp_tf._X_min
            )

            self.assertTrue(torch.equal(expected_X_tf, X_tf))

            # test untransform
            untransformed_X = warp_tf.untransform(X_tf)
            self.assertTrue(
                torch.allclose(
                    untransformed_X,
                    expand_and_copy_tensor(X, batch_shape=warp_tf.batch_shape),
                    rtol=1e-3,
                    atol=1e-3 if self.device == torch.device("cpu") else 1e-2,
                )
            )
            if len(warp_batch_shape) > 0:
                with self.assertRaises(BotorchTensorDimensionError):
                    warp_tf.untransform(X_tf.unsqueeze(-3))

            # test no transform on eval
            warp_tf = get_test_warp(
                indices,
                transform_on_eval=False,
                batch_shape=warp_batch_shape,
                eps=eps,
            ).to(**tkwargs)
            X_tf = warp_tf(X)
            self.assertFalse(torch.equal(X, X_tf))
            warp_tf.eval()
            X_tf = warp_tf(X)
            self.assertTrue(torch.equal(X, X_tf))

            # test no transform on train
            warp_tf = get_test_warp(
                indices=indices,
                transform_on_train=False,
                batch_shape=warp_batch_shape,
                eps=eps,
            ).to(**tkwargs)
            X_tf = warp_tf(X)
            self.assertTrue(torch.equal(X, X_tf))
            warp_tf.eval()
            X_tf = warp_tf(X)
            self.assertFalse(torch.equal(X, X_tf))

            # test equals
            warp_tf2 = get_test_warp(
                indices=indices,
                transform_on_train=False,
                batch_shape=warp_batch_shape,
                eps=eps,
            ).to(**tkwargs)
            self.assertTrue(warp_tf.equals(warp_tf2))
            # test different transform_on_train
            warp_tf2 = get_test_warp(
                indices=indices, batch_shape=warp_batch_shape, eps=eps
            )
            self.assertFalse(warp_tf.equals(warp_tf2))
            # test different indices
            warp_tf2 = get_test_warp(
                indices=[0, 1],
                transform_on_train=False,
                batch_shape=warp_batch_shape,
                eps=eps,
            ).to(**tkwargs)
            self.assertFalse(warp_tf.equals(warp_tf2))

            # test preprocess_transform
            warp_tf.transform_on_train = False
            self.assertTrue(torch.equal(warp_tf.preprocess_transform(X), X))
            warp_tf.transform_on_train = True
            self.assertTrue(torch.equal(warp_tf.preprocess_transform(X), X_tf))

            # test _set_concentration
            warp_tf._set_concentration(0, warp_tf.concentration0)
            warp_tf._set_concentration(1, warp_tf.concentration1)

            # test concentration prior
            prior0 = LogNormalPrior(0.0, 0.75).to(**tkwargs)
            prior1 = LogNormalPrior(0.0, 0.5).to(**tkwargs)
            warp_tf = get_test_warp(
                indices=[0, 1],
                concentration0_prior=prior0,
                concentration1_prior=prior1,
                batch_shape=warp_batch_shape,
                eps=eps,
            )
            for i, (name, _, p, _, _) in enumerate(warp_tf.named_priors()):
                self.assertEqual(name, f"concentration{i}_prior")
                self.assertIsInstance(p, LogNormalPrior)
                self.assertEqual(p.base_dist.scale, 0.75 if i == 0 else 0.5)

        # test gradients
        X = 1 + 5 * torch.rand(*batch_shape, 4, 3, **tkwargs)
        X = X.unsqueeze(-3) if len(warp_batch_shape) > 0 else X
        warp_tf = get_test_warp(
            indices=indices, batch_shape=warp_batch_shape, eps=eps
        ).to(**tkwargs)
        X_tf = warp_tf(X)
        X_tf.sum().backward()
        for grad in (warp_tf.concentration0.grad, warp_tf.concentration1.grad):
            self.assertIsNotNone(grad)
            self.assertFalse(torch.isnan(grad).any())
            self.assertFalse(torch.isinf(grad).any())
            self.assertFalse((grad == 0).all())

        # test set with scalar
        warp_tf._set_concentration(i=0, value=2.0)
        self.assertTrue((warp_tf.concentration0 == 2.0).all())
        warp_tf._set_concentration(i=1, value=3.0)
        self.assertTrue((warp_tf.concentration1 == 3.0).all())
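# ---------------------------------------------------------------------------
# A possible shape for the `get_test_warp` helper used in the test above. This
# is a hypothetical reconstruction, not the actual helper from the BoTorch test
# suite: it forwards its arguments to the `Warp` input transform and sets
# non-default concentrations so the warp is not a near-identity map.
# ---------------------------------------------------------------------------
from botorch.models.transforms.input import Warp


def get_test_warp(indices, **kwargs):
    warp_tf = Warp(indices=indices, **kwargs)
    # illustrative concentration values; the real helper may use different ones
    warp_tf._set_concentration(i=0, value=0.5)
    warp_tf._set_concentration(i=1, value=1.5)
    return warp_tf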