def test_sparse_variational_optimize_with_defaults() -> None:
    """Optimizing a SparseVariational wrapper should lower its training loss.

    Uses a 20-iteration Adam optimizer with otherwise default settings, on 100
    evenly spaced points in [0, 100] whose targets come from
    ``_3x_plus_gaussian_noise``.
    """
    query_points = np.linspace(0, 100, 100).reshape((-1, 1))
    observations = _3x_plus_gaussian_noise(query_points)
    training_data = (query_points, observations)
    training_dataset = Dataset(*training_data)

    adam = create_optimizer(tf.optimizers.Adam(), {"max_iter": 20})
    sparse_model = SparseVariational(
        svgp_model(query_points, observations), optimizer=adam
    )

    # Record the loss before optimization so it can be compared afterwards.
    loss_before = sparse_model.model.training_loss(training_data)
    sparse_model.optimize(training_dataset)
    loss_after = sparse_model.model.training_loss(training_data)

    assert loss_after < loss_before
def test_sparse_variational_optimize(
    batcher: DatasetTransformer, compile: bool
) -> None:
    """Minibatched optimization of a SparseVariational model should lower its loss.

    ``batcher`` is passed to the optimizer as its ``dataset_builder`` and
    ``compile`` is forwarded unchanged. Both are presumably supplied by pytest
    parametrization or fixtures defined outside this view (TODO confirm).
    """
    query_points = np.linspace(0, 100, 100).reshape((-1, 1))
    observations = _3x_plus_gaussian_noise(query_points)
    training_data = (query_points, observations)
    training_dataset = Dataset(*training_data)

    # 10 Adam iterations over batches of 10 points.
    optimizer_settings = {
        "max_iter": 10,
        "batch_size": 10,
        "dataset_builder": batcher,
        "compile": compile,
    }
    adam = create_optimizer(tf.optimizers.Adam(), optimizer_settings)
    sparse_model = SparseVariational(
        svgp_model(query_points, observations), optimizer=adam
    )

    loss_before = sparse_model.model.training_loss(training_data)
    sparse_model.optimize(training_dataset)
    loss_after = sparse_model.model.training_loss(training_data)

    assert loss_after < loss_before
def test_sparse_variational_update_updates_num_data() -> None:
    """``update`` should set the wrapped model's ``num_data`` to the new dataset size.

    The wrapper starts from an SVGP built on a single 4-dimensional point and
    is then updated with a dataset of five points.
    """
    initial_svgp = svgp_model(tf.zeros([1, 4]), tf.zeros([1, 1]))
    wrapper = SparseVariational(initial_svgp)

    five_points = tf.zeros([5, 4])
    five_observations = tf.zeros([5, 1])
    wrapper.update(Dataset(five_points, five_observations))

    assert wrapper.model.num_data == 5
def test_sparse_variational_update_raises_for_invalid_shapes(
    new_data: Dataset,
) -> None:
    """``update`` should raise ``ValueError`` when given ``new_data``.

    The wrapped SVGP is built on one 4-dimensional point with one output.
    Going by the test name, ``new_data`` is expected to have shapes
    incompatible with that model. It is presumably supplied by a parametrize
    decorator outside this view (TODO confirm).
    """
    wrapper = SparseVariational(svgp_model(tf.zeros([1, 4]), tf.zeros([1, 1])))

    with pytest.raises(ValueError):
        wrapper.update(new_data)
def test_sparse_variational_model_attribute() -> None:
    """``SparseVariational.model`` should be the very SVGP object passed in, not a copy."""
    wrapped_svgp = svgp_model(*mock_data())
    assert SparseVariational(wrapped_svgp).model is wrapped_svgp