def test_shape(dim):
    """Zero mean evaluated on ``dim``-column inputs yields one row per input and one column."""
    column = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1)
    # Replicate the single input column side by side only when more than one is requested.
    inputs = jnp.hstack([column] * dim) if dim > 1 else column
    mean_values = Zero()(inputs)
    n_rows, n_cols = mean_values.shape[0], mean_values.shape[1]
    assert n_rows == inputs.shape[0]
    assert n_cols == 1
def __init__(
    self,
    kernel: Kernel,
    likelihood: Likelihood,
    mean_function: Optional[MeanFunction] = None,
    num_latent_gps: Optional[int] = None,
    jitter: float = 1e-6,
):
    """Store the components of the GP model on the instance.

    Args:
        kernel: Kernel object, stored as ``self.kernel``.
        likelihood: Likelihood object, stored as ``self.likelihood``.
        mean_function: Mean function stored as ``self.mean_function``;
            a fresh ``Zero()`` is used when ``None`` is given.
        num_latent_gps: Number of latent GPs; required despite the
            ``None`` default (the default only exists for keyword ordering).
        jitter: Stored as ``self.jitter``. Presumably a small diagonal
            term for numerical stability -- NOTE(review): confirm at use site.

    Raises:
        AssertionError: If ``num_latent_gps`` is ``None``.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`,
    # while keeping the AssertionError type existing callers may catch.
    if num_latent_gps is None:
        raise AssertionError("GP requires specification of num_latent_gps")
    self.num_latent_gps = num_latent_gps
    self.kernel = kernel
    self.likelihood = likelihood
    self.mean_function = Zero() if mean_function is None else mean_function
    self.jitter = jitter
def test_initialisers():
    """Initialising a Zero mean function should give an empty (falsy) parameter collection."""
    initial = initialise(Zero())
    assert not initial
def test_hyperparametr_initialise():
    """RBF kernel + Zero mean hyperparameters must expose exactly these keys, in sorted order."""
    hyperparams = _initialise_hyperparams(RBF(), Zero())
    expected_keys = sorted(["lengthscale", "variance"])
    assert [*hyperparams.keys()] == expected_keys
class Datum: input_dim, output_dim = 3, 2 N, Ntest, M = 20, 30, 10 # Constant(c=jax.random.normal(key, shape=(Datum.output_dim,))), class Data: x1 = jnp.linspace(0, 10, 20).reshape(10, 2) _mean_functions = [ Zero(), # Linear( # A=rng.randn(Datum.input_dim, Datum.output_dim), # b=rng.randn(Datum.output_dim, 1).reshape(-1), # ), Constant(c=jax.random.normal(key, shape=(Datum.output_dim,))), ] @pytest.mark.parametrize("mean_function_1", _mean_functions) @pytest.mark.parametrize("mean_function_2", _mean_functions) @pytest.mark.parametrize("operation", ["+", "*"]) def test_mean_functions_output_shape( mean_function_1, mean_function_2, operation ): """