コード例 #1
0
def test_prior_sample(n, n_sample):
    """Sampling an RBF prior at ``n`` inputs returns ``n_sample`` rows of ``n`` values."""
    rng = jr.PRNGKey(123)
    gp_prior = Prior(kernel=RBF())
    inputs = jnp.linspace(-1.0, 1.0, num=n)[:, None]
    hyperparams = initialise(RBF())
    draws = sample(rng, gp_prior, hyperparams, inputs, n_samples=n_sample)
    expected_shape = (n_sample, inputs.shape[0])
    assert draws.shape == expected_shape
コード例 #2
0
ファイル: test_sampling.py プロジェクト: thomaspinder/GPJax
def test_prior_random_variable(n):
    """The prior random variable over an input-only dataset is a full-covariance MVN."""
    gp_prior = Prior(kernel=RBF())
    inputs = jnp.linspace(-1.0, 1.0, num=n)[:, None]
    dataset = Dataset(X=inputs)
    prior_rv = random_variable(gp_prior, initialise(RBF()), dataset)
    assert isinstance(prior_rv, tfd.MultivariateNormalFullCovariance)
コード例 #3
0
ファイル: test_base.py プロジェクト: thomaspinder/GPJax
def test_call():
    """Evaluating the RBF kernel on two single points yields a 0-d device array."""
    rbf = RBF()
    hyperparams = initialise(rbf)
    lhs = jnp.array([[1.]])
    rhs = jnp.array([[0.5]])
    correlation = rbf(lhs, rhs, hyperparams)
    assert isinstance(correlation, jnp.DeviceArray)
    assert correlation.shape == ()
コード例 #4
0
def test_prior_mll():
    """
    Test that the MLL evaluation works with priors attached to the parameter values.

    The marginal log-likelihood of an RBF-prior x Gaussian-likelihood posterior
    is evaluated on 100 noisy samples of sin(pi x) / (pi x), once without and
    once with Gamma priors on the hyperparameters. Each result is compared
    against a hard-coded reference value.
    """
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)),
                 axis=0)

    def f(x):
        return jnp.sin(jnp.pi * x) / (jnp.pi * x)

    # NOTE: the same key is reused for the noise draw. The hard-coded reference
    # values below presumably depend on this exact setup; regenerate them if
    # the key handling is changed.
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1
    posterior = Prior(kernel=RBF()) * Gaussian()

    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    # Parameters are held in unconstrained space; ``constrainer`` is handed to
    # the MLL as its ``transform``.
    params = unconstrainer(params)

    mll = marginal_ll(posterior, transform=constrainer)

    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
コード例 #5
0
def _get_conjugate_posterior_params() -> tuple:
    """Build an RBF-prior x Gaussian-likelihood posterior and its initial parameters.

    Returns:
        A ``(params, posterior)`` tuple: the result of ``initialise(posterior)``
        followed by the posterior object itself. (The previous ``-> dict``
        annotation did not match this tuple return.)
    """
    kernel = RBF()
    prior = Prior(kernel=kernel)
    lik = Gaussian()
    posterior = prior * lik
    params = initialise(posterior)
    return params, posterior
コード例 #6
0
def test_spectral():
    """A spectral-RBF conjugate posterior initialises basis functions plus kernel/noise params."""
    rng = jr.PRNGKey(123)
    num_basis = 10
    spectral_kernel = to_spectral(RBF(), num_basis)
    conj_posterior = Prior(kernel=spectral_kernel) * Gaussian()
    hyperparams = initialise(rng, conj_posterior)
    expected_keys = sorted(
        ["basis_fns", "obs_noise", "lengthscale", "variance"])
    assert list(hyperparams.keys()) == expected_keys
    assert hyperparams["basis_fns"].shape == (num_basis, 1)
コード例 #7
0
def test_posterior_random_variable(n):
    """The conjugate posterior random variable at ``n`` test inputs is a full-covariance MVN."""
    gp_posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10)[:, None]
    train_y = jnp.sin(train_x)
    test_x = jnp.linspace(-1.0, 1.0, num=n)[:, None]
    hyperparams = initialise(gp_posterior)
    post_rv = random_variable(gp_posterior, hyperparams, test_x, train_x, train_y)
    assert isinstance(post_rv, tfd.MultivariateNormalFullCovariance)
コード例 #8
0
ファイル: test_transforms.py プロジェクト: thomaspinder/GPJax
def test_constrain(likelihood):
    """Constraining parameters keeps every parameter name and every dtype."""
    gp_posterior = Prior(kernel=RBF()) * likelihood()
    hyperparams = initialise(gp_posterior, 10)
    constrain_fn = build_constrain(hyperparams.keys(), get_defaults())
    constrained = constrain_fn(hyperparams)
    assert constrained.keys() == hyperparams.keys()
    dtype_pairs = zip(constrained.values(), hyperparams.values())
    assert all(new.dtype == old.dtype for new, old in dtype_pairs)
コード例 #9
0
ファイル: test_base.py プロジェクト: thomaspinder/GPJax
def test_gram(dim):
    """The RBF Gram matrix of ``dim``-column inputs is square with one row per input."""
    column = jnp.linspace(-1.0, 1.0, num=10)[:, None]
    # Replicate the single column across ``dim`` input dimensions.
    inputs = jnp.hstack([column] * dim) if dim > 1 else column
    rbf = RBF()
    K = gram(rbf, inputs, initialise(rbf))
    assert K.shape[0] == inputs.shape[0]
    assert K.shape[0] == K.shape[1]
コード例 #10
0
def test_check_needless():
    """A prior dict that already covers every parameter comes back unchanged."""
    full_priors = dict(
        lengthscale=tfd.Gamma(1.0, 1.0),
        variance=tfd.Gamma(2.0, 2.0),
        obs_noise=tfd.Gamma(3.0, 3.0),
        latent=tfd.Normal(loc=0.0, scale=1.0),
    )
    bernoulli_posterior = Prior(kernel=RBF()) * Bernoulli()
    checked = prior_checks(bernoulli_posterior, full_priors)
    assert checked == full_priors
コード例 #11
0
def test_posterior_sample(n, n_sample):
    """Sampling the conjugate posterior random variable gives ``(n_sample, n)`` draws."""
    rng = jr.PRNGKey(123)
    gp_posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10)[:, None]
    train_y = jnp.sin(train_x)
    test_x = jnp.linspace(-1.0, 1.0, num=n)[:, None]
    hyperparams = initialise(gp_posterior)
    post_rv = random_variable(gp_posterior, hyperparams, test_x, train_x, train_y)
    draws = sample(rng, post_rv, n_samples=n_sample)
    assert draws.shape == (n_sample, test_x.shape[0])
コード例 #12
0
ファイル: test_sampling.py プロジェクト: thomaspinder/GPJax
def test_posterior_random_variable(n):
    """The conjugate posterior random variable is a callable that yields a full-covariance MVN."""
    gp_posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10)[:, None]
    data = Dataset(X=train_x, y=jnp.sin(train_x))
    test_x = jnp.linspace(-1.0, 1.0, num=n)[:, None]
    hyperparams = initialise(gp_posterior)
    predictive = random_variable(gp_posterior, hyperparams, data)
    assert isinstance(predictive, Callable)
    predictive_dist = predictive(test_x)
    assert isinstance(predictive_dist, tfd.MultivariateNormalFullCovariance)
コード例 #13
0
def test_conjugate_variance():
    """The predictive covariance over the test inputs is a square (n_test, n_test) matrix."""
    train_x = jr.uniform(jr.PRNGKey(123), shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)

    gp_posterior = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(gp_posterior)

    test_x = jnp.linspace(-3.0, 3.0, 30)[:, None]
    cov = variance(gp_posterior, hyperparams, test_x, train_x, train_y)
    n_test = test_x.shape[0]
    assert cov.shape == (n_test, n_test)
コード例 #14
0
def test_non_conjugate_mean():
    """The Bernoulli-posterior predictive mean returns one value per test input."""
    rng = jr.PRNGKey(123)
    unsorted_x = jr.uniform(rng, shape=(10, 1), minval=-1.0, maxval=1.0)
    train_x = jnp.sort(unsorted_x, axis=0)
    jitter = jr.normal(rng, shape=train_x.shape) * 0.05
    # Shift/scale the sign of a noisy cosine from [-1, 1] into [0, 1].
    train_y = 0.5 * jnp.sign(jnp.cos(3 * train_x + jitter)) + 0.5
    test_x = jnp.linspace(-1.05, 1.05, 50)[:, None]

    bernoulli_posterior = Prior(kernel=RBF()) * Bernoulli()
    hyperparams = initialise(bernoulli_posterior, train_x.shape[0])

    mu = mean(bernoulli_posterior, hyperparams, test_x, train_x, train_y)
    assert mu.shape == (test_x.shape[0], )
コード例 #15
0
ファイル: test_predict.py プロジェクト: thomaspinder/GPJax
def test_conjugate_mean():
    """The conjugate predictive mean has shape (n_test, output_dim)."""
    train_x = jr.uniform(jr.PRNGKey(123), shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)

    gp_posterior = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(gp_posterior)

    test_x = jnp.linspace(-3.0, 3.0, 30)[:, None]
    mean_fn = mean(gp_posterior, hyperparams, data)
    predicted = mean_fn(test_x)
    assert predicted.shape == (test_x.shape[0], train_y.shape[1])
コード例 #16
0
ファイル: test_predict.py プロジェクト: thomaspinder/GPJax
def test_non_conjugate_variance():
    """The Bernoulli-posterior predictive variance returns one value per test input."""
    rng = jr.PRNGKey(123)
    unsorted_x = jr.uniform(rng, shape=(10, 1), minval=-1.0, maxval=1.0)
    train_x = jnp.sort(unsorted_x, axis=0)
    jitter = jr.normal(rng, shape=train_x.shape) * 0.05
    # Shift/scale the sign of a noisy cosine from [-1, 1] into [0, 1].
    train_y = 0.5 * jnp.sign(jnp.cos(3 * train_x + jitter)) + 0.5
    data = Dataset(X=train_x, y=train_y)
    test_x = jnp.linspace(-1.05, 1.05, 50)[:, None]

    bernoulli_posterior = Prior(kernel=RBF()) * Bernoulli()
    hyperparams = initialise(bernoulli_posterior, train_x.shape[0])

    variance_fn = variance(bernoulli_posterior, hyperparams, data)
    sigma = variance_fn(test_x)
    assert sigma.shape == (test_x.shape[0],)
コード例 #17
0
ファイル: test_base.py プロジェクト: thomaspinder/GPJax
def test_pos_def(dim, ell, sigma):
    """Check that the jittered RBF Gram matrix is positive definite.

    Args:
        dim: number of input dimensions; the 1-D grid is replicated across
            ``dim`` columns.
        ell: lengthscale value.
        sigma: variance value.
    """
    n = 30
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)
    if dim > 1:
        # Replicate the column to shape (n, dim). ``jnp.hstack((x) * dim)``
        # would instead scale x by dim and flatten it to shape (n,).
        x = jnp.hstack([x] * dim)
    kern = RBF()
    params = {"lengthscale": jnp.array([ell]), "variance": jnp.array(sigma)}

    gram_matrix = gram(kern, x, params)
    jitter_matrix = I(n) * 1e-6
    gram_matrix += jitter_matrix
    # The Gram matrix is symmetric, so use eigvalsh, which returns real
    # eigenvalues. eigvals returns a complex array, which has no meaningful
    # ordering for min()/> 0.
    min_eig = jnp.linalg.eigvalsh(gram_matrix).min()
    assert min_eig > 0
コード例 #18
0
def test_non_conjugate():
    """For a Bernoulli (non-conjugate) posterior, the negative MLL equals minus the MLL."""
    posterior = Prior(kernel=RBF()) * Bernoulli()
    n = 20
    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
    y = jnp.sin(x)
    # Use the training-set size rather than repeating the literal 20.
    params = initialise(posterior, n)
    config = get_defaults()
    # build_all_transforms returns (constrainer, unconstrainer), in that order.
    # This matches test_build_all_transforms, where the first element is
    # checked against build_constrain.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)
コード例 #19
0
ファイル: test_mlls.py プロジェクト: thomaspinder/GPJax
def test_conjugate():
    """For a conjugate (Gaussian) posterior, the negative MLL equals minus the MLL."""
    posterior = Prior(kernel=RBF()) * Gaussian()

    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    D = Dataset(X=x, y=y)
    params = initialise(posterior)
    config = get_defaults()
    # build_all_transforms returns (constrainer, unconstrainer), in that order;
    # its first element is the one checked against build_constrain.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, D) == jnp.array(-1.0) * mll(params, D)
コード例 #20
0
ファイル: test_sampling.py プロジェクト: thomaspinder/GPJax
def test_non_conjugate_rv(n):
    """The Bernoulli-posterior random variable is a callable yielding a ProbitBernoulli."""
    rng = jr.PRNGKey(123)
    posterior = Prior(kernel=RBF()) * Bernoulli()
    raw_x = jr.uniform(rng, shape=(n, 1), minval=-1.0, maxval=1.0)
    x = jnp.sort(raw_x, axis=0)
    noise = jr.normal(rng, shape=x.shape) * 0.05
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    data = Dataset(X=x, y=y)

    test_x = jnp.linspace(-1.0, 1.0, num=n)[:, None]

    kernel_params = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}
    full_params = complete(kernel_params, posterior, x.shape[0])
    predictive = random_variable(posterior, full_params, data)
    assert isinstance(predictive, Callable)
    assert isinstance(predictive(test_x), tfd.ProbitBernoulli)
コード例 #21
0
ファイル: test_sampling.py プロジェクト: thomaspinder/GPJax
def test_spectral_sample():
    """The spectral conjugate posterior predictive is a full-covariance MVN distribution."""
    rng = jr.PRNGKey(123)
    num_basis = 10
    train_x = jnp.linspace(-1.0, 1.0, 20)[:, None]
    data = Dataset(X=train_x, y=jnp.sin(train_x))
    test_x = jnp.linspace(-1.0, 1.0, num=50)[:, None]
    spectral_post = Prior(kernel=to_spectral(RBF(), num_basis)) * Gaussian()
    trainable = initialise(rng, spectral_post)
    # Move the basis functions out of the main dict and pass them separately
    # through ``static_params``.
    static = {"basis_fns": trainable.pop("basis_fns")}
    predictive = random_variable(spectral_post, trainable, data, static_params=static)
    post_rv = predictive(test_x)
    assert isinstance(post_rv, tfd.Distribution)
    assert isinstance(post_rv, tfd.MultivariateNormalFullCovariance)
コード例 #22
0
ファイル: test_spectral.py プロジェクト: jejjohnson/GPJax-1
def test_conjugate():
    """For a spectral conjugate posterior, the negative MLL is minus the MLL and is a scalar."""
    key = jr.PRNGKey(123)
    kern = to_spectral(RBF(), 10)
    posterior = Prior(kernel=kern) * Gaussian()
    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    params = initialise(key, posterior)
    config = get_defaults()
    # build_all_transforms returns (constrainer, unconstrainer), in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    # Evaluate the negative MLL once and reuse it for both checks.
    nmll = neg_mll(params, x, y)
    assert nmll == jnp.array(-1.0) * mll(params, x, y)
    assert nmll.shape == ()
コード例 #23
0
ファイル: test_transforms.py プロジェクト: thomaspinder/GPJax
def test_build_all_transforms(likelihood):
    """``build_all_transforms`` agrees with the individual transform builders.

    Its first transform must match ``build_constrain`` and its second must match
    ``build_unconstrain``: same keys, equal values and identical dtypes.
    """
    posterior = Prior(kernel=RBF()) * likelihood()
    params = initialise(posterior, 10)
    config = get_defaults()
    t1, t2 = build_all_transforms(params.keys(), config)

    constrainer = build_constrain(params.keys(), config)
    constrained = t1(params)
    constrained2 = constrainer(params)
    # Compare the two different results (the check previously compared
    # constrained2 with itself and could never fail).
    assert constrained.keys() == constrained2.keys()
    for u, v in zip(constrained.values(), constrained2.values()):
        assert_array_equal(u, v)
        assert u.dtype == v.dtype

    unconstrained = t2(params)
    unconstrainer = build_unconstrain(params.keys(), config)
    unconstrained2 = unconstrainer(params)
    assert unconstrained.keys() == unconstrained2.keys()
    for u, v in zip(unconstrained.values(), unconstrained2.values()):
        assert_array_equal(u, v)
        assert u.dtype == v.dtype
コード例 #24
0
def mogpe_checkpoint_to_numpy(config_file, ckpt_dir, data_file, expert_num=0):
    """Extract one mogpe gating function's sparse-GP parameters as NumPy arrays.

    Args:
        config_file: model config passed to ``load_model_from_config_and_checkpoint``.
        ckpt_dir: checkpoint directory passed to the same loader.
        data_file: archive read with ``np.load``; it must contain an ``"x"``
            entry holding the inputs used to configure the model.
        expert_num: index into ``model.gating_network.gating_function_list``.

    Returns:
        ``(kernel, inducing_variable, mean_function, q_mu, q_sqrt, whiten)``.
        ``kernel`` is an ``RBF`` built from the variance and lengthscales of the
        gating kernel's first sub-kernel. ``mean_function`` is always ``0.0``.
    """
    # Load the data set. The context manager closes the underlying .npz file
    # handle once the inputs have been read into memory.
    with np.load(data_file) as data:
        X = data["x"]

    # configure mogpe model from checkpoint
    model = load_model_from_config_and_checkpoint(config_file, ckpt_dir, X=X)

    # select the gating function to use
    gating_func = model.gating_network.gating_function_list[expert_num]
    mean_function = 0.0  # mogpe gating functions have zero mean function
    whiten = gating_func.whiten

    # sparse GP parameters
    q_mu = gating_func.q_mu.numpy()
    q_sqrt = gating_func.q_sqrt.numpy()
    inducing_variable = (
        gating_func.inducing_variable.inducing_variable.Z.numpy())

    # kernel parameters (taken from the first sub-kernel)
    variance = gating_func.kernel.kernels[0].variance.numpy()
    lengthscales = gating_func.kernel.kernels[0].lengthscales.numpy()
    kernel = RBF(variance=variance, lengthscales=lengthscales)
    return kernel, inducing_variable, mean_function, q_mu, q_sqrt, whiten
コード例 #25
0
ファイル: test_gp.py プロジェクト: jejjohnson/GPJax-1
def test_spectral():
    """A spectral-kernel prior times a Gaussian likelihood yields a SpectralPosterior."""
    spectral_prior = Prior(kernel=to_spectral(RBF(), 10))
    assert isinstance(spectral_prior * Gaussian(), SpectralPosterior)
コード例 #26
0
ファイル: test_transforms.py プロジェクト: thomaspinder/GPJax
def test_output(transformation, likelihood):
    """Each transform builder returns a callable for the posterior's parameter names."""
    gp_posterior = Prior(kernel=RBF()) * likelihood()
    hyperparams = initialise(gp_posterior, 10)
    mapper = transformation(hyperparams.keys(), get_defaults())
    assert isinstance(mapper, Callable)
コード例 #27
0
def test_checks():
    """prior_checks adds a "latent" entry to an incomplete prior dict but does not add "variance"."""
    partial_priors = {"lengthscale": jnp.array([1.0])}
    bernoulli_posterior = Prior(kernel=RBF()) * Bernoulli()
    checked = prior_checks(bernoulli_posterior, partial_priors)
    checked_names = checked.keys()
    assert "latent" in checked_names
    assert "variance" not in checked_names
コード例 #28
0
ファイル: test_gp.py プロジェクト: jejjohnson/GPJax-1
def test_conjugate_posterior():
    """An RBF prior multiplied by a Gaussian likelihood is a ConjugatePosterior."""
    combined = Prior(kernel=RBF()) * Gaussian()
    assert isinstance(combined, ConjugatePosterior)
コード例 #29
0
ファイル: test_gp.py プロジェクト: jejjohnson/GPJax-1
def test_non_conjugate_poster(likelihood):
    """An RBF prior times the parametrised ``likelihood`` (presumably non-Gaussian) is non-conjugate."""
    gp_prior = Prior(kernel=RBF())
    combined = gp_prior * likelihood()
    assert isinstance(combined, NonConjugatePosterior)
コード例 #30
0
ファイル: test_spectral.py プロジェクト: jejjohnson/GPJax-1
def test_to_spectral(n_basis):
    """to_spectral turns an RBF into a stationary SpectralRBF with ``n_basis`` basis functions."""
    spectral_kernel = to_spectral(RBF(), n_basis)
    assert isinstance(spectral_kernel, SpectralRBF)
    assert spectral_kernel.num_basis == n_basis
    assert spectral_kernel.stationary