コード例 #1
0
    def test_sum(self):
        """The sum of two RBF kernels must agree with scikit-learn's kernel sum."""
        data = np.random.normal(size=(10, 20))
        scale_a, scale_b = 5., 10.

        reference = sklearn_RBF(scale_a) + sklearn_RBF(scale_b)
        candidate = RBF(scale_a) + RBF(scale_b)
        assert np.allclose(reference(data), candidate(data))
コード例 #2
0
    def test_exponentiation(self):
        """Raising an RBF kernel to a power must agree with scikit-learn."""
        data = np.random.normal(size=(10, 20))
        power, scale = 2., 5.

        reference = sklearn_RBF(scale) ** power
        candidate = RBF(scale) ** power
        assert np.allclose(reference(data), candidate(data))
コード例 #3
0
    def test_value(self, save_memory):
        """RBF kernel values must match scikit-learn's for the given memory mode.

        The global ``config.SAVE_MEMORY`` switch is set from ``save_memory``
        (presumably supplied by a parametrized fixture) and restored
        afterwards, so this test does not leak its setting into other tests.
        """
        previous = config.SAVE_MEMORY
        config.SAVE_MEMORY = save_memory
        try:
            lengthscale = 15.
            # Seeded generator so a failing input can be reproduced exactly.
            X = np.random.default_rng(0).normal(size=(10, 20))

            sk_rbf = sklearn_RBF(lengthscale)
            rbf = RBF(lengthscale)
            assert np.allclose(sk_rbf(X), rbf(X))
        finally:
            config.SAVE_MEMORY = previous
コード例 #4
0
    def test_gradient(self, save_memory):
        """Output of ``eval_gradient=True`` must match scikit-learn's RBF.

        The global ``config.SAVE_MEMORY`` switch is set from ``save_memory``
        (presumably supplied by a parametrized fixture) and restored
        afterwards, so this test does not leak its setting into other tests.
        """
        previous = config.SAVE_MEMORY
        config.SAVE_MEMORY = save_memory
        try:
            lengthscale = 1.
            # Seeded generator so a failing input can be reproduced exactly.
            X = np.random.default_rng(0).normal(size=(5, 2))

            sk_rbf = sklearn_RBF(lengthscale)
            _, sk_grad = sk_rbf(X, eval_gradient=True)
            rbf = RBF(lengthscale)
            _, grad = rbf(X, eval_gradient=True)
            assert np.allclose(sk_grad, grad)
        finally:
            config.SAVE_MEMORY = previous
コード例 #5
0
import jax.numpy as jnp
from sklearn import datasets
from sklearn.gaussian_process import GaussianProcessClassifier as sk_GPC
from sklearn.gaussian_process.kernels import RBF as sklearn_RBF
from sklearn_jax_kernels import RBF as jax_RBF
from sklearn_jax_kernels import GaussianProcessClassifier as jax_GPC

# Load scikit-learn's handwritten-digits dataset as the workload.
digits = datasets.load_digits()
X = digits.data
# ``astype`` makes the integer copy of the labels without relying on a
# top-level ``numpy`` import, which this script does not have.
y = digits.target.astype(int)

# JAX copies of the same data for the JAX-backed classifier.
X_jax = jnp.asarray(X)
y_jax = jnp.asarray(y)

# Matching kernels on both sides: 1.0 * RBF with a single length scale of 1.0.
sk_kernel = 1.0 * sklearn_RBF([1.0])
jax_kernel = 1.0 * jax_RBF([1.0])

# copy_X_train=False: per scikit-learn's docs, the classifier keeps a
# reference to the training data instead of a copy (assumed to mean the same
# for the JAX port -- confirm).
sk_clf = sk_GPC(kernel=sk_kernel, copy_X_train=False)
jax_clf = jax_GPC(kernel=jax_kernel, copy_X_train=False)

# Fit both once at import time. This is presumably a warm-up so that one-off
# costs (e.g. JAX compilation) are not counted when the fit_* functions
# below are timed.
sk_clf.fit(X, y)
jax_clf.fit(X_jax, y_jax)


def fit_with_sklearn_kernel():
    """Refit the scikit-learn classifier ``sk_clf`` on ``X`` / ``y``.

    Presumably the body timed by a benchmark harness; the result of ``fit``
    is discarded and the function returns ``None``.
    """
    classifier = sk_clf
    classifier.fit(X, y)


def fit_with_jax_kernel():
    """Refit the JAX-backed classifier ``jax_clf`` on ``X_jax`` / ``y_jax``.

    Presumably the body timed by a benchmark harness; the result of ``fit``
    is discarded and the function returns ``None``.
    """
    classifier = jax_clf
    classifier.fit(X_jax, y_jax)