コード例 #1
0
def test_shape(dim):
    """Check that ``Zero`` returns one row per input point and a single column.

    A 10-point column on [-1, 1] is built; when ``dim`` exceeds one, that
    column is tiled horizontally ``dim`` times to form the input matrix.
    """
    column = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1)
    inputs = jnp.hstack([column] * dim) if dim > 1 else column
    mean_values = Zero()(inputs)
    # Row count must track the number of inputs ...
    assert mean_values.shape[0] == inputs.shape[0]
    # ... and the mean is a single output column.
    assert mean_values.shape[1] == 1
コード例 #2
0
ファイル: model.py プロジェクト: aidanscannell/GPJax
 def __init__(
     self,
     kernel: Kernel,
     likelihood: Likelihood,
     mean_function: Optional[MeanFunction] = None,
     num_latent_gps: Optional[int] = None,
     jitter: float = 1e-6,
 ):
     """Store the model's kernel, likelihood, mean function and settings.

     Args:
         kernel: Kernel of the GP.
         likelihood: Likelihood of the GP model.
         mean_function: Mean function; a fresh ``Zero()`` is used when
             omitted (``None``).
         num_latent_gps: Number of latent GPs. Required despite the
             ``None`` default.
         jitter: Value stored as ``self.jitter``; presumably added to
             covariance diagonals for numerical stability (confirm at use).

     Raises:
         AssertionError: If ``num_latent_gps`` is ``None``.
     """
     # Checked explicitly instead of via ``assert`` so validation still runs
     # under ``python -O``; AssertionError is kept for existing callers.
     if num_latent_gps is None:
         raise AssertionError("GP requires specification of num_latent_gps")
     self.num_latent_gps = num_latent_gps
     self.kernel = kernel
     self.likelihood = likelihood
     if mean_function is None:
         mean_function = Zero()
     self.mean_function = mean_function
     self.jitter = jitter
コード例 #3
0
def test_initialisers():
    """Check that initialising a ``Zero`` mean function yields no parameters."""
    assert not initialise(Zero())
コード例 #4
0
def test_hyperparametr_initialise():
    """Check the hyperparameter keys produced for an RBF kernel and Zero mean.

    The keys must be exactly ``lengthscale`` and ``variance``, in sorted order.
    """
    hyperparams = _initialise_hyperparams(RBF(), Zero())
    expected_keys = sorted(("lengthscale", "variance"))
    assert list(hyperparams.keys()) == expected_keys
コード例 #5
0

class Datum:
    """Shared size constants for the mean-function tests."""

    # Input and output dimensionality.
    input_dim = 3
    output_dim = 2
    # Sample counts; M is presumably an inducing-point count (TODO confirm).
    N = 20
    Ntest = 30
    M = 10


# Constant(c=jax.random.normal(key, shape=(Datum.output_dim,))),


class Data:
    """Fixed test input: 20 evenly spaced values on [0, 10] as a 10x2 array."""

    # Row-major reshape, so consecutive linspace values fill each row in turn.
    x1 = jnp.reshape(jnp.linspace(0, 10, num=20), (10, 2))


# Mean-function instances fed to the parametrised tests below.
# TODO(review): a ``Linear`` entry (A of shape (Datum.input_dim, Datum.output_dim)
# and b of length Datum.output_dim, both drawn via ``rng.randn``) was disabled
# here; re-enable it or remove this note.
# ``key`` is assumed to be a JAX PRNG key defined elsewhere in the module.
_mean_functions = [
    Zero(),
    Constant(
        c=jax.random.normal(key, shape=(Datum.output_dim,)),
    ),
]


@pytest.mark.parametrize("mean_function_1", _mean_functions)
@pytest.mark.parametrize("mean_function_2", _mean_functions)
@pytest.mark.parametrize("operation", ["+", "*"])
def test_mean_functions_output_shape(
    mean_function_1, mean_function_2, operation
):
    """