def test_sparse_variational_update_raises_for_invalid_shapes(new_data: Dataset) -> None:
    """Check that ``SparseVariational.update`` raises ``ValueError`` on badly-shaped data.

    The wrapped SVGP is built from zero-valued inputs of shape ``[1, 4]`` and
    outputs of shape ``[1, 1]``. ``new_data`` comes from a fixture defined
    elsewhere. It is presumably a parametrized set of datasets whose shapes are
    incompatible with that model (TODO: confirm against the fixture definition).
    """
    inner_svgp = svgp_model(tf.zeros([1, 4]), tf.zeros([1, 1]))
    wrapper = SparseVariational(inner_svgp)

    with pytest.raises(ValueError):
        wrapper.update(new_data)
def test_sparse_variational_update_updates_num_data() -> None:
    """Check that ``update`` resets the wrapped model's ``num_data`` to the new dataset size.

    The model starts from a single point with 4-dimensional inputs and 1-dimensional
    outputs. It is then updated with 5 zero-valued points of matching
    dimensionality, after which ``num_data`` must equal 5.
    """
    wrapper = SparseVariational(svgp_model(tf.zeros([1, 4]), tf.zeros([1, 1])))

    point_count = 5
    replacement = Dataset(tf.zeros([point_count, 4]), tf.zeros([point_count, 1]))
    wrapper.update(replacement)

    assert wrapper.model.num_data == point_count