def test_sum(self):
    """Sum of two RBF kernels must agree with scikit-learn's kernel sum."""
    short_scale = 5.
    long_scale = 10.
    inputs = np.random.normal(size=(10, 20))
    # Build the same composite kernel with both implementations.
    reference = sklearn_RBF(short_scale) + sklearn_RBF(long_scale)
    candidate = RBF(short_scale) + RBF(long_scale)
    expected = reference(inputs)
    actual = candidate(inputs)
    assert np.allclose(expected, actual)
def test_exponentiation(self):
    """An RBF kernel raised to a power must agree with scikit-learn."""
    base_scale, power = 5., 2.
    inputs = np.random.normal(size=(10, 20))
    reference = sklearn_RBF(base_scale) ** power
    candidate = RBF(base_scale) ** power
    expected = reference(inputs)
    assert np.allclose(expected, candidate(inputs))
def test_value(self, save_memory):
    """Check that RBF kernel values match scikit-learn's RBF.

    ``save_memory`` (presumably supplied by a pytest parametrization/fixture
    -- confirm) is written to the global ``config.SAVE_MEMORY`` flag only for
    the duration of this test; the previous value is restored afterwards.
    """
    previous_save_memory = config.SAVE_MEMORY
    config.SAVE_MEMORY = save_memory
    try:
        lengthscale = 15.
        X = np.random.normal(size=(10, 20))
        sk_rbf = sklearn_RBF(lengthscale)
        rbf = RBF(lengthscale)
        assert np.allclose(sk_rbf(X), rbf(X))
    finally:
        # Restore the global flag so this parametrization does not leak into
        # later tests that never set it (which would make them order-dependent).
        config.SAVE_MEMORY = previous_save_memory
def test_gradient(self, save_memory):
    """Check RBF kernel values and hyperparameter gradients against scikit-learn.

    Both implementations are called with ``eval_gradient=True``; the kernel
    matrix and the gradient they return must agree. ``save_memory``
    (presumably a pytest parametrization -- confirm) is applied to the global
    ``config.SAVE_MEMORY`` flag for this test only and restored afterwards.
    """
    previous_save_memory = config.SAVE_MEMORY
    config.SAVE_MEMORY = save_memory
    try:
        lengthscale = 1.
        X = np.random.normal(size=(5, 2))
        sk_rbf = sklearn_RBF(lengthscale)
        sk_K, sk_grad = sk_rbf(X, eval_gradient=True)
        rbf = RBF(lengthscale)
        K, grad = rbf(X, eval_gradient=True)
        # The kernel matrix from the gradient code path must match as well.
        assert np.allclose(sk_K, K)
        assert np.allclose(sk_grad, grad)
    finally:
        # Restore the global flag so it does not leak into later tests.
        config.SAVE_MEMORY = previous_save_memory
# Compare GaussianProcessClassifier.fit with scikit-learn's RBF kernel against
# the sklearn_jax_kernels implementation on the digits dataset.
# NOTE(review): fit_with_* look like entry points for a timing harness -- confirm.
import jax.numpy as jnp
import numpy as np  # np.array is used below; the script did not import numpy
from sklearn import datasets
from sklearn.gaussian_process import GaussianProcessClassifier as sk_GPC
from sklearn.gaussian_process.kernels import RBF as sklearn_RBF

from sklearn_jax_kernels import RBF as jax_RBF
from sklearn_jax_kernels import GaussianProcessClassifier as jax_GPC

# Load the digits dataset; both classifiers are fit on the same data, with a
# JAX-array copy prepared for the JAX-backed classifier.
digits = datasets.load_digits()
X = digits.data
y = np.array(digits.target, dtype=int)
X_jax = jnp.asarray(X)
y_jax = jnp.asarray(y)

# Same kernel structure for both backends: a constant times an RBF whose
# length scale is initialised to 1.0.
sk_kernel = 1.0 * sklearn_RBF([1.0])
jax_kernel = 1.0 * jax_RBF([1.0])
sk_clf = sk_GPC(kernel=sk_kernel, copy_X_train=False)
jax_clf = jax_GPC(kernel=jax_kernel, copy_X_train=False)

# Fit each classifier once at import time -- presumably a warm-up so one-off
# costs (e.g. JAX compilation) are not included in later timings; confirm.
sk_clf.fit(X, y)
jax_clf.fit(X_jax, y_jax)


def fit_with_sklearn_kernel():
    """Refit the scikit-learn-kernel classifier on the digits data."""
    sk_clf.fit(X, y)


def fit_with_jax_kernel():
    """Refit the JAX-kernel classifier on the digits data."""
    jax_clf.fit(X_jax, y_jax)