예제 #1
0
def test_filter_finite() -> None:
    """Check that ``filter_finite`` drops the row whose observation is NaN.

    The first query point is the origin, where ``_nan_at_origin`` presumably
    yields a non-finite observation. Every remaining point, and its
    observation, is expected to survive the filter unchanged.
    """
    finite_points = [[-1., 0.], [1., 0.], [0., 2.], [1., 3.]]
    origin = [[0., 0.]]
    all_points = tf.constant(origin + finite_points)

    result = filter_finite(all_points, _nan_at_origin(all_points))

    expected_observations = [[-1.], [1.], [2.], [4.]]
    npt.assert_array_almost_equal(result.query_points, finite_points)
    npt.assert_array_almost_equal(result.observations, expected_observations)
예제 #2
0
def test_filter_finite_raises_for_invalid_shapes(qp_shape: ShapeLike,
                                                 obs_shape: ShapeLike) -> None:
    """Check that ``filter_finite`` rejects incompatible argument shapes.

    ``qp_shape`` and ``obs_shape`` are presumably supplied by a pytest
    parametrization outside this view. Each combination is expected to make
    the call raise ``ValueError``.
    """
    with pytest.raises(ValueError):
        # The tensors are built inside the context on purpose, so any
        # ValueError raised while constructing them is also caught.
        query_points = tf.ones(qp_shape)
        observations = tf.ones(obs_shape)
        filter_finite(query_points, observations)
예제 #3
0
def test_filter_finite(query_points: tf.Tensor, expected: Dataset) -> None:
    """Check ``filter_finite`` against a precomputed expected ``Dataset``.

    Observations are generated with ``_sum_with_nan_at_origin``. The filtered
    dataset must match ``expected`` according to ``assert_datasets_allclose``.
    """
    filtered = filter_finite(query_points,
                             _sum_with_nan_at_origin(query_points))
    assert_datasets_allclose(filtered, expected)