Exemplo n.º 1
0
def test_vgp_raises_for_invalid_init() -> None:
    """Invalid constructor argument combinations are rejected with ``ValueError``.

    Two cases are exercised:

    * ``natgrad_gamma`` supplied without ``use_natgrads=True``;
    * ``use_natgrads=True`` combined with a plain ``Optimizer`` wrapping Scipy
      (presumably rejected because natgrads need a batch optimizer — confirm
      against the wrapper's validation).
    """
    x_np = np.arange(5, dtype=np.float64).reshape(-1, 1)
    x = tf.convert_to_tensor(x_np, x_np.dtype)
    y = fnc_3x_plus_10(x)

    # Fixtures are built outside ``pytest.raises`` so that a ValueError raised
    # during setup cannot make the test pass for the wrong reason. Each case
    # gets its own freshly built gpflow model.
    gamma_only_model = vgp_model(x, y)
    with pytest.raises(ValueError):
        VariationalGaussianProcess(gamma_only_model, natgrad_gamma=1)

    scipy_model = vgp_model(x, y)
    optimizer = Optimizer(gpflow.optimizers.Scipy())
    with pytest.raises(ValueError):
        VariationalGaussianProcess(scipy_model, optimizer=optimizer, use_natgrads=True)
Exemplo n.º 2
0
def test_vgp_update_updates_num_data() -> None:
    """Updating with two extra points increases the model's ``num_data`` by two."""
    x_np = np.arange(5, dtype=np.float64).reshape(-1, 1)
    x = tf.convert_to_tensor(x_np, x_np.dtype)
    y = fnc_3x_plus_10(x)
    m = VariationalGaussianProcess(vgp_model(x, y))
    num_data = m.model.num_data

    # Build the appended points with x's own dtype explicitly, rather than
    # relying on tf.concat to infer a dtype for a nested Python list.
    extra_points = tf.constant([[10.0], [11.0]], dtype=x.dtype)
    x_new = tf.concat([x, extra_points], 0)
    y_new = fnc_3x_plus_10(x_new)
    m.update(Dataset(x_new, y_new))
    new_num_data = m.model.num_data
    assert new_num_data - num_data == 2
Exemplo n.º 3
0
def test_vgp_update_q_mu_sqrt_unchanged() -> None:
    """Updating with the data the model already holds leaves ``q_mu`` and ``q_sqrt`` as-is."""
    query_points = tf.constant(np.arange(10).reshape((-1, 1)), dtype=gpflow.default_float())
    observations = fnc_2sin_x_over_3(query_points)
    vgp = VariationalGaussianProcess(vgp_matern_model(query_points, observations))

    before = (vgp.model.q_mu.numpy(), vgp.model.q_sqrt.numpy())
    vgp.update(Dataset(query_points, observations))
    after = (vgp.model.q_mu.numpy(), vgp.model.q_sqrt.numpy())

    # Compare q_mu with q_mu, then q_sqrt with q_sqrt.
    for old, new in zip(before, after):
        npt.assert_allclose(old, new, atol=1e-5)
Exemplo n.º 4
0
def test_vgp_optimize_with_and_without_natgrads(
    batcher: DatasetTransformer, compile: bool, use_natgrads: bool
) -> None:
    """Optimizing reduces the training loss, with or without natural gradients.

    The model is built on the first ten points and then optimized on all of them.
    """
    query_points = np.linspace(0, 100, 100).reshape((-1, 1))
    observations = _3x_plus_gaussian_noise(query_points)
    full_dataset = Dataset(query_points, observations)

    optimizer_args = {
        "max_iter": 10,
        "batch_size": 10,
        "dataset_builder": batcher,
        "compile": compile,
    }
    optimizer = create_optimizer(tf.optimizers.Adam(), optimizer_args)

    initial_points, initial_observations = query_points[:10], observations[:10]
    vgp = VariationalGaussianProcess(
        vgp_model(initial_points, initial_observations),
        optimizer=optimizer,
        use_natgrads=use_natgrads,
    )

    loss_before = vgp.model.training_loss()
    vgp.optimize(full_dataset)
    loss_after = vgp.model.training_loss()
    assert loss_after < loss_before
Exemplo n.º 5
0
def test_vgp_update() -> None:
    """Before and after an update, variational params match a freshly built VGP on the same data."""

    def assert_matches_fresh_model(wrapper, dataset) -> None:
        # Build a new reference VGP on ``dataset`` and compare its variational params.
        reference = vgp_model(dataset.query_points, dataset.observations)
        npt.assert_allclose(wrapper.model.q_mu, reference.q_mu, atol=1e-5)
        npt.assert_allclose(wrapper.model.q_sqrt, reference.q_sqrt, atol=1e-5)

    float_type = gpflow.default_float()
    x = tf.constant(np.arange(5).reshape(-1, 1), dtype=float_type)
    initial = Dataset(x, fnc_3x_plus_10(x))
    vgp = VariationalGaussianProcess(vgp_model(initial.query_points, initial.observations))
    assert_matches_fresh_model(vgp, initial)

    extra_points = tf.constant([[10.0], [11.0]], dtype=float_type)
    x_extended = tf.concat([x, extra_points], 0)
    extended = Dataset(x_extended, fnc_3x_plus_10(x_extended))
    vgp.update(extended)
    assert_matches_fresh_model(vgp, extended)
Exemplo n.º 6
0
def test_variational_gaussian_process_predict() -> None:
    """The wrapper's predictions match those of an identically trained raw gpflow VGP.

    Also checks that ``predict_y`` variance equals ``predict`` variance plus the
    observation noise.
    """
    x_observed = tf.constant(np.arange(100).reshape((-1, 1)), dtype=gpflow.default_float())
    y_observed = _3x_plus_gaussian_noise(x_observed)
    model = VariationalGaussianProcess(vgp_model(x_observed, y_observed))

    def train(vgp) -> None:
        # Fit all trainable variables of a gpflow model with Scipy.
        gpflow.optimizers.Scipy().minimize(vgp.training_loss_closure(), vgp.trainable_variables)

    train(model.model)
    x_predict = tf.constant([[50.5]], gpflow.default_float())
    mean, variance = model.predict(x_predict)
    mean_y, variance_y = model.predict_y(x_predict)

    # Store the reference model's data in non-trainable variables with an
    # unconstrained leading dimension (presumably mirroring how the wrapper
    # stores its data — confirm against the wrapper).
    reference_model = vgp_model(x_observed, y_observed)
    reference_model.data = tuple(
        tf.Variable(tensor, trainable=False, shape=[None, *tensor.shape[1:]])
        for tensor in reference_model.data
    )
    train(reference_model)
    reference_mean, reference_variance = reference_model.predict_f(x_predict)

    npt.assert_allclose(mean, reference_mean)
    npt.assert_allclose(variance, reference_variance, atol=1e-3)
    latent_variance_from_y = variance_y - model.get_observation_noise()
    npt.assert_allclose(latent_variance_from_y, variance, atol=5e-5)
Exemplo n.º 7
0
def test_vgp_optimize_natgrads_only_updates_variational_params(compile: bool) -> None:
    """With a no-op batch optimizer and natgrads on, only ``q_mu`` and ``q_sqrt`` change.

    The kernel parameters and the number of trainable variables must remain unchanged.
    """
    x_observed = np.linspace(0, 100, 10).reshape((-1, 1))
    y_observed = _3x_plus_gaussian_noise(x_observed)
    dataset = Dataset(x_observed, y_observed)

    class DummyBatchOptimizer(BatchOptimizer):
        # ``optimize`` does nothing, so any change to the variational parameters
        # presumably comes from the natural-gradient step.
        def optimize(self, model: tf.Module, dataset: Dataset) -> None:
            pass

    noop_optimizer = DummyBatchOptimizer(tf.optimizers.Adam(), compile=compile, max_iter=10)
    model = VariationalGaussianProcess(
        vgp_matern_model(x_observed[:10], y_observed[:10]),
        optimizer=noop_optimizer,
        use_natgrads=True,
    )

    def snapshot():
        # (number of trainable variables, first kernel parameter, q_mu, q_sqrt)
        return (
            len(model.trainable_variables),
            model.get_kernel().parameters[0].numpy(),
            model.model.q_mu.numpy(),
            model.model.q_sqrt.numpy(),
        )

    old_count, old_kernel, old_q_mu, old_q_sqrt = snapshot()
    model.optimize(dataset)
    new_count, new_kernel, new_q_mu, new_q_sqrt = snapshot()

    npt.assert_allclose(old_kernel, new_kernel, atol=1e-3)
    npt.assert_equal(old_count, new_count)
    npt.assert_raises(AssertionError, npt.assert_allclose, old_q_mu, new_q_mu)
    npt.assert_raises(AssertionError, npt.assert_allclose, old_q_sqrt, new_q_sqrt)