Example #1
0
def test_cython_vs_theano(f, c11):
    """
    Check that the Cython and Theano thermal phase curve implementations agree.

    The same HD 189733 / IRAC 1 setup is evaluated with the Cython ``Model``
    and with the Theano ``thermal_phase_curve``. The phase curves must agree
    within an absolute tolerance of 5 and the temperature maps within 10.

    Parameters
    ----------
    f : float
        Passed as ``f`` to both implementations.
    c11 : float
        Placed in the ``C_ml`` coefficient list given to the Cython ``Model``
        and passed positionally to the Theano ``thermal_phase_curve``.
    """
    from ..theano.theano import thermal_phase_curve

    planet = Planet.from_name('HD 189733')
    filt = Filter.from_name('IRAC 1')
    harmonic_coeffs = [[0], [0, c11, 0]]
    model = Model(-0.8, 0.575, 4.5, 0, harmonic_coeffs, 1,
                  planet=planet, filt=filt)
    phases = np.linspace(-np.pi, np.pi, 100)

    # Resolution of the grid of points on the sphere used by the Theano model
    n_theta, n_phi = 10, 100
    theta2d, phi2d = np.meshgrid(
        np.linspace(0, np.pi, n_theta),
        np.linspace(-2 * np.pi, 2 * np.pi, n_phi),
    )

    cython_flux = model.thermal_phase_curve(phases, f=f).flux
    cython_map, _unused_a, _unused_b = model.temperature_map(n_theta, n_phi,
                                                             f=f)

    with pm.Model():
        pc_tensor, temp_tensor = thermal_phase_curve(
            phases, -0.8, 4.5, 0.575, c11,
            planet.T_s, planet.a, planet.rp_a,
            0, theta2d, phi2d,
            filt.wavelength.to(u.m).value,
            filt.transmittance, f,
        )

        # Scale by 1e6 before comparing with the Cython flux
        theano_flux = 1e6 * pmx.eval_in_model(pc_tensor)
        # Take index 0 of the two trailing axes, then transpose for
        # comparison with the Cython map
        theano_map = pmx.eval_in_model(temp_tensor)[..., 0, 0].T

    np.testing.assert_allclose(cython_flux, theano_flux, atol=5)

    np.testing.assert_allclose(cython_map, theano_map, atol=10)
Example #2
0
        def get_val(x):
            """
            Evaluate the theano tensor ``x``.

            First tries ``x.eval()``. If that raises ``MissingInputError``,
            falls back to ``pmx.eval_in_model`` at the test point of
            ``kwargs_model`` or, when that is ``None``, of the pymc3 model on
            the current context stack. A warning is issued in that case.

            Raises
            ------
            ValueError
                If ``x`` has missing inputs and no pymc3 model is available.
                The originating ``MissingInputError`` is chained as the cause.
            """
            try:

                # Try to directly evaluate it

                return x.eval()

            except MissingInputError as missing_err:

                # That didn't work. Perhaps we are in a pymc3 model
                # context, but the user didn't provide a point?

                import pymc3 as pm
                import pymc3_ext as pmx

                model = kwargs_model
                if model is None:
                    try:
                        # get_context() raises TypeError when no model
                        # context is active
                        model = pm.Model.get_context()
                    except TypeError:
                        raise ValueError(
                            "Missing input for variable {}, and no pymc3 model found.".format(
                                x
                            )
                        ) from missing_err

                # Warn the user that we're using the test point; stacklevel=2
                # attributes the warning to the caller of get_val
                warnings.warn(
                    "Detected pymc3 model context, but no point provided. Evaluating at test_point.",
                    stacklevel=2,
                )

                return pmx.eval_in_model(
                    x, model=model, point=model.test_point
                )
Example #3
0
def test_jax_vs_theano_thermal():
    """
    Check that the JAX and Theano thermal phase curves agree within 1e-6.

    Both back ends receive exactly the same positional arguments for
    HD 189733 observed through the IRAC 1 filter (after
    ``filt.bin_down(5)``). Only the phase curves are compared; the second
    return value of each call is discarded.
    """
    # These parameters have been chi-by-eye "fit" to the Spitzer/3.6 um PC
    f = 0.68
    planet = Planet.from_name('HD 189733')
    filt = Filter.from_name('IRAC 1')
    filt.bin_down(5)

    xi = np.linspace(-np.pi, np.pi, 100)
    # Set resolution of grid points on sphere:
    n_phi = 100
    n_theta = 10
    phi = np.linspace(-2 * np.pi, 2 * np.pi, n_phi)
    theta = np.linspace(0, np.pi, n_theta)
    theta2d, phi2d = np.meshgrid(theta, phi)

    # Build the argument list once so the two back ends cannot drift apart
    args = (xi, -0.8, 4.5, 0.575, 0.18, planet.T_s,
            planet.a, planet.rp_a, 0, theta2d, phi2d,
            filt.wavelength.to(u.m).value,
            filt.transmittance, f)

    thermal_pc_jax, _ = therm_jax(*args)

    with pm.Model():
        therm, _ = therm_theano(*args)

        thermal_pc_theano = pmx.eval_in_model(therm)

    np.testing.assert_allclose(thermal_pc_jax, thermal_pc_theano, atol=1e-6)
Example #4
0
def test_albedo():
    """
    Check that the albedo ``A_B`` recovered from the Theano temperature map is ~0.

    The temperature map is raised to the fourth power, weighted by
    ``sin(theta)`` and masked to ``phi > 0``, then integrated with
    ``trapz2d``. The integral is divided by ``pi * T_s**4`` and scaled by
    ``a**2``. The resulting ``A_B`` must lie within 0.05 of zero.
    """
    from ..theano.theano import thermal_phase_curve

    f = 2**-0.5
    planet = Planet.from_name('HD 189733')
    filt = Filter.from_name('IRAC 1')

    # Resolution of the grid of points on the sphere
    n_phi, n_theta = 100, 10
    phi = np.linspace(-2 * np.pi, 2 * np.pi, n_phi)
    theta = np.linspace(0, np.pi, n_theta)
    theta2d, phi2d = np.meshgrid(theta, phi)
    xi = np.linspace(-np.pi, np.pi, 100)

    with pm.Model():
        _phase_curve, temp_tensor = thermal_phase_curve(
            xi, -0.8, 4.5, 0.575, 0, planet.T_s,
            planet.a, planet.rp_a, 0, theta2d, phi2d,
            filt.wavelength.to(u.m).value,
            filt.transmittance, f,
        )

        temp_map = pmx.eval_in_model(temp_tensor)[..., 0, 0]

    integrand = (temp_map[..., None]**4 * np.sin(theta2d[..., None]) *
                 (phi2d[..., None] > 0))
    normalized_integral = (trapz2d(integrand, phi, theta) /
                           (np.pi * planet.T_s**4))
    albedo = 1 - planet.a**2 * normalized_integral[0]

    expected_albedo = 0
    assert abs(albedo - expected_albedo) < 5e-2
Example #5
0
def test_jax_vs_theano_reflected(omega, delta_phi):
    """
    Check that the JAX and Theano reflected-light phase curves agree.

    Both back ends receive the same positional arguments:
    ``np.linspace(0, 1, 100)``, ``omega``, ``delta_phi`` and ``100``.
    Only the first element of each result is compared.
    """
    xi = np.linspace(0, 1, 100)
    shared_args = (xi, omega, delta_phi, 100)

    jax_curve = refl_jax(*shared_args)[0]

    with pm.Model():
        theano_tensor = refl_theano(*shared_args)[0]
        theano_curve = pmx.eval_in_model(theano_tensor)

    np.testing.assert_allclose(jax_curve, theano_curve)
Example #6
0
def evaluator(**kwargs):
    """
    Return a function to evaluate theano tensors.

    Works inside a `pymc3` model if a `point` is provided.
    Lazily imports `pymc3` to minimize overhead.

    Parameters
    ----------
    point : optional
        Point passed to ``pymc3_ext.eval_in_model``. When given, every tensor
        is evaluated in ``model`` (or the model on the current context
        stack) at this point.
    model : pymc3.Model, optional
        Model to evaluate in. Defaults to the model on the current context
        stack.

    Any other keyword arguments are ignored.

    Returns
    -------
    callable
        ``get_val(x)`` returning the evaluated value of tensor ``x``. Without
        a ``point`` it tries ``x.eval()`` first. On ``MissingInputError`` it
        falls back to the model's test point, with a warning.
    """
    # Store the kwargs
    kwargs_point = kwargs.pop("point", None)
    kwargs_model = kwargs.pop("model", None)

    if kwargs_point is not None:

        # User provided a point

        import pymc3 as pm
        import pymc3_ext as pmx

        point = kwargs_point
        model = kwargs_model
        if model is None:
            model = pm.Model.get_context()

        def get_val(x):
            return pmx.eval_in_model(x, model=model, point=point)

    else:

        # No point provided

        def get_val(x):

            try:

                # Try to directly evaluate it

                return x.eval()

            except MissingInputError as missing_err:

                # That didn't work. Perhaps we are in a pymc3 model
                # context, but the user didn't provide a point?

                import pymc3 as pm
                import pymc3_ext as pmx

                model = kwargs_model
                if model is None:
                    try:
                        # get_context() raises TypeError when no model
                        # context is active
                        model = pm.Model.get_context()
                    except TypeError:
                        raise ValueError(
                            "Missing input for variable {}, and no pymc3 model found.".format(
                                x
                            )
                        ) from missing_err

                # Warn the user that we're using the test point; stacklevel=2
                # attributes the warning to the caller of get_val
                warnings.warn(
                    "Detected pymc3 model context, but no point provided. Evaluating at test_point.",
                    stacklevel=2,
                )

                return pmx.eval_in_model(
                    x, model=model, point=model.test_point
                )

    return get_val
Example #7
0
 def get_val(x):
     """
     Evaluate ``x`` in the enclosing ``model`` at ``point`` if it is a tensor.

     Values for which ``is_tensor`` is false are returned unchanged.
     """
     if not is_tensor(x):
         return x
     return pmx.eval_in_model(x, model=model, point=point)