def test_quadrature(white, mean):
    """Check ``uncertain_conditional`` against ``mvnquad`` estimates (graph mode).

    Quadrature estimates of the predictive mean and variance at an uncertain
    input are built as graph ops, evaluated in a session, and compared with
    the analytic moments returned by ``uncertain_conditional``.

    :param white: forwarded to ``DataQuadrature.tensors`` and to
        ``uncertain_conditional``.
    :param mean: forwarded to ``DataQuadrature.tensors``; presumably selects
        the mean function under test -- confirm against the fixture.
    """
    with session_context() as sess:
        data = DataQuadrature
        tensors = data.tensors(white, mean)
        quad_args = (tensors.Xmu, tensors.Xvar, data.H, data.D_in, (data.D_out,))

        # One quadrature op per fixture integrand, in the order mean, var, mean_sq.
        quad_ops = [
            mvnquad(integrand, *quad_args)
            for integrand in (tensors.mean_fn, tensors.var_fn, tensors.mean_sq_fn)
        ]
        analytic_mean_op, analytic_var_op = uncertain_conditional(
            tensors.Xmu, tensors.Xvar, tensors.feat, tensors.kern,
            tensors.q_mu, tensors.q_sqrt,
            mean_function=tensors.mean_function,
            full_cov_output=False,
            white=white,
        )

        e_mean, e_var, e_mean_sq = sess.run(quad_ops, feed_dict=tensors.feed_dict)
        # Combine as E[var] + (E[mean^2] - E[mean]^2), i.e. the law of total variance.
        total_var = e_var + (e_mean_sq - e_mean**2)
        analytic_mean, analytic_var = sess.run(
            [analytic_mean_op, analytic_var_op], feed_dict=tensors.feed_dict)

        assert_almost_equal(e_mean, analytic_mean, decimal=6)
        assert_almost_equal(total_var, analytic_var, decimal=6)
# Example #2
0
def test_quadrature(white, mean):
    """Compare ``uncertain_conditional`` with moments obtained by ``mvnquad``.

    The predictive mean and variance at the uncertain input described by
    ``DataQuad.Xmu`` / ``DataQuad.Xvar`` are computed twice: once in closed
    form via ``uncertain_conditional`` and once by numerically integrating
    ``conditional`` with ``mvnquad``.

    :param white: forwarded as ``white`` to both ``conditional`` and
        ``uncertain_conditional``.
    :param mean: selector handed to ``mean_function_factory``; when the
        factory returns a falsy value the quadrature integrand adds a zero mean.
    """
    kern = gpflow.kernels.SquaredExponential()
    feature = gpflow.inducing_variables.InducingPoints(DataQuad.Z)
    mean_function = mean_function_factory(mean, DataQuad.D_in, DataQuad.D_out)
    # conditional() does not include a mean function, so add it back here.
    extra_mean = mean_function or (lambda X: 0.0)

    def predict(X):
        return conditional(
            X, feature, kern, DataQuad.q_mu, q_sqrt=DataQuad.q_sqrt, white=white
        )

    def mean_fn(X):
        return predict(X)[0] + extra_mean(X)

    def var_fn(X):
        return predict(X)[1]

    def mean_sq_fn(X):
        return mean_fn(X) ** 2

    quad_args = (
        DataQuad.Xmu, DataQuad.Xvar, DataQuad.H, DataQuad.D_in, (DataQuad.D_out,)
    )
    mean_quad, expected_var, expected_mean_sq = (
        mvnquad(integrand, *quad_args) for integrand in (mean_fn, var_fn, mean_sq_fn)
    )
    # Law of total variance: Var = E[var] + (E[mean^2] - E[mean]^2).
    var_quad = expected_var + (expected_mean_sq - mean_quad**2)

    mean_analytic, var_analytic = uncertain_conditional(
        DataQuad.Xmu, DataQuad.Xvar, feature, kern, DataQuad.q_mu, DataQuad.q_sqrt,
        mean_function=mean_function,
        full_output_cov=False,
        white=white,
    )

    assert_allclose(mean_quad, mean_analytic, rtol=1e-6)
    assert_allclose(var_quad, var_analytic, rtol=1e-6)
def test_quadrature(white, mean):
    """Compare analytic uncertain-input moments with ``mvnquad`` estimates.

    Graph-mode (session-based) check that the mean and variance produced by
    ``uncertain_conditional`` match numerical quadrature of the fixture's
    integrands.

    :param white: forwarded to ``DataQuadrature.tensors`` and to
        ``uncertain_conditional``.
    :param mean: forwarded to ``DataQuadrature.tensors``; presumably chooses
        the mean function under test -- confirm against the fixture.
    """
    with session_context() as session:
        fixture = DataQuadrature
        graph = fixture.tensors(white, mean)

        def quadrature(integrand):
            return mvnquad(integrand, graph.Xmu, graph.Xvar,
                           fixture.H, fixture.D_in, (fixture.D_out,))

        quad_mean_op = quadrature(graph.mean_fn)
        quad_var_op = quadrature(graph.var_fn)
        quad_mean_sq_op = quadrature(graph.mean_sq_fn)
        exact_mean_op, exact_var_op = uncertain_conditional(
            graph.Xmu,
            graph.Xvar,
            graph.feat,
            graph.kern,
            graph.q_mu,
            graph.q_sqrt,
            mean_function=graph.mean_function,
            full_cov_output=False,
            white=white,
        )

        quad_mean, quad_var, quad_mean_sq = session.run(
            [quad_mean_op, quad_var_op, quad_mean_sq_op], feed_dict=graph.feed_dict)
        # Total variance = E[var] + (E[mean^2] - E[mean]^2).
        quad_total_var = quad_var + (quad_mean_sq - quad_mean**2)
        exact_mean, exact_var = session.run(
            [exact_mean_op, exact_var_op], feed_dict=graph.feed_dict)

        assert_almost_equal(quad_mean, exact_mean, decimal=6)
        assert_almost_equal(quad_total_var, exact_var, decimal=6)