def test_dataset_getters() -> None:
    """A ``Dataset`` exposes its inputs unchanged via its two accessors."""
    expected_qp = tf.constant([[0.0]])
    expected_obs = tf.constant([[1.0]])
    ds = Dataset(expected_qp, expected_obs)

    # (accessor value, tensor passed to the constructor)
    pairs = (
        (ds.query_points, expected_qp),
        (ds.observations, expected_obs),
    )

    # Check dtypes first, then shapes, then element values, for both pairs.
    for got, want in pairs:
        assert got.dtype == want.dtype
    for got, want in pairs:
        assert shapes_equal(got, want)
    for got, want in pairs:
        assert tf.reduce_all(got == want)
def assert_datasets_allclose(this: Dataset, that: Dataset) -> None:
    """
    Assert that two datasets agree on shape, dtype and (approximately) value.

    Both :attr:`query_points` and :attr:`observations` are compared. Shapes are
    checked first, then dtypes, then element-wise closeness.

    :param this: A dataset.
    :param that: A dataset.
    :raise AssertionError: If any of the following are true:

        - shapes are not equal
        - dtypes are not equal
        - elements are not approximately equal.
    """
    pairs = (
        (this.query_points, that.query_points),
        (this.observations, that.observations),
    )

    for lhs, rhs in pairs:
        assert shapes_equal(lhs, rhs)
    for lhs, rhs in pairs:
        assert lhs.dtype == rhs.dtype
    for lhs, rhs in pairs:
        # Uses numpy.testing's default rtol/atol.
        npt.assert_allclose(lhs, rhs)