Beispiel #1
0
def test_initialise(n_basis):
    """Initialised SpectralRBF parameters have the expected names and are float64.

    ``n_basis`` is injected by the test harness (presumably a pytest
    parametrisation or fixture defined outside this view -- confirm).
    """
    rng = jr.PRNGKey(123)
    rbf = SpectralRBF(num_basis=n_basis)
    hyperparams = initialise(rng, rbf)
    # Names are compared as an ordered list, so insertion order matters.
    expected_names = ["basis_fns", "lengthscale", "variance"]
    assert list(hyperparams) == expected_names
    # Every parameter must be double precision (requires JAX x64 mode).
    assert all(value.dtype == jnp.float64 for value in hyperparams.values())
Beispiel #2
0
def test_call():
    """Evaluating SpectralRBF on one pair of single-feature points yields a 1x1 JAX array."""
    key = jr.PRNGKey(123)
    kernel = SpectralRBF(num_basis=10)
    params = initialise(key, kernel)
    x, y = jnp.array([[1.0]]), jnp.array([[0.5]])
    point_corr = kernel(x, y, params)
    # ``jnp.DeviceArray`` no longer exists in newer JAX releases (superseded by
    # ``jax.Array``); ``jnp.ndarray`` is available across versions and matches
    # concrete JAX arrays, so the type check stays portable.
    assert isinstance(point_corr, jnp.ndarray)
    assert point_corr.shape == (1, 1)
Beispiel #3
0
def test_sample_frequencies(n_freqs):
    """Sampling via the kernel and via its spectral density gives identical float64 draws.

    Both calls share one PRNG key, so the test expects the draws to coincide
    exactly. ``n_freqs`` is injected by the test harness (presumably a pytest
    parametrisation or fixture defined outside this view -- confirm).
    """
    rng = jr.PRNGKey(123)
    rbf = SpectralRBF(num_basis=n_freqs)
    density = spectral_density(rbf)
    from_kernel = sample_frequencies(rng, rbf, n_freqs, 1)
    from_density = sample_frequencies(rng, density, n_freqs, 1)
    assert (from_density == from_kernel).all()
    for draws in (from_density, from_kernel):
        assert draws.dtype == jnp.float64
Beispiel #4
0
def test_spectral_density():
    """The spectral density derived from a SpectralRBF kernel is a ``tfd.Normal``."""
    density = spectral_density(SpectralRBF(num_basis=10))
    assert isinstance(density, tfd.Normal)