def test_cython_vs_theano(f, c11):
    """Compare the Cython and Theano thermal models for HD 189733 in IRAC 1.

    For the given ``f`` and ``c11`` the Cython ``Model`` and the Theano
    ``thermal_phase_curve`` are both evaluated. Their phase curves must agree
    within an absolute tolerance of 5, and their temperature maps within 10.
    """
    from ..theano.theano import thermal_phase_curve

    planet = Planet.from_name('HD 189733')
    filt = Filter.from_name('IRAC 1')

    coefficients = [[0], [0, c11, 0]]
    model = Model(-0.8, 0.575, 4.5, 0, coefficients, 1,
                  planet=planet, filt=filt)
    xi = np.linspace(-np.pi, np.pi, 100)

    # Grid of (theta, phi) points on the sphere for the Theano implementation
    n_theta, n_phi = 10, 100
    theta = np.linspace(0, np.pi, n_theta)
    phi = np.linspace(-2 * np.pi, 2 * np.pi, n_phi)
    theta2d, phi2d = np.meshgrid(theta, phi)

    # Cython reference values
    cython_phase_curve = model.thermal_phase_curve(xi, f=f).flux
    cython_temp_map, _, _ = model.temperature_map(n_theta, n_phi, f=f)

    # Theano values, evaluated inside a fresh pymc3 model context
    wavelength_m = filt.wavelength.to(u.m).value
    with pm.Model():
        thermal_pc, temperature = thermal_phase_curve(
            xi, -0.8, 4.5, 0.575, c11, planet.T_s, planet.a, planet.rp_a, 0,
            theta2d, phi2d, wavelength_m, filt.transmittance, f,
        )
        # Scaled by 1e6 to compare against the Cython flux
        theano_phase_curve = 1e6 * pmx.eval_in_model(thermal_pc)
        theano_map = pmx.eval_in_model(temperature)[..., 0, 0].T

    np.testing.assert_allclose(cython_phase_curve, theano_phase_curve, atol=5)
    np.testing.assert_allclose(cython_temp_map, theano_map, atol=10)
def get_val(x):
    """Evaluate ``x`` if it is symbolic; return it unchanged otherwise.

    Objects with an ``eval`` method are evaluated directly. If that fails with
    ``MissingInputError``, the value is evaluated at the test point of a pymc3
    model instead, and a warning is issued. The model is ``kwargs_model`` if it
    is set, otherwise the current model context. Values without an ``eval``
    method, such as plain numbers or numpy arrays, are returned as-is.

    Raises:
        ValueError: if inputs are missing and no pymc3 model can be found.
            The original ``MissingInputError`` is attached as ``__cause__``.
    """
    # Non-symbolic values need no evaluation; previously these raised
    # AttributeError on ``x.eval()``.
    if not hasattr(x, "eval"):
        return x
    try:
        # Try to directly evaluate it
        return x.eval()
    except MissingInputError as e:
        # That didn't work. Perhaps we are in a pymc3 model
        # context, but the user didn't provide a point?
        import pymc3 as pm
        import pymc3_ext as pmx

        try:
            # NOTE(review): ``kwargs_model`` is a free variable here; it is
            # expected to be bound by an enclosing scope (cf. ``evaluator``)
            # -- confirm, otherwise this path raises NameError.
            model = kwargs_model
            if model is None:
                model = pm.Model.get_context()
        except TypeError:
            # ``get_context`` raises TypeError when no model is on the stack
            raise ValueError(
                "Missing input for variable {}, and no pymc3 model found.".format(
                    x
                )
            ) from e
        # Warn the user that we're using the test point
        warnings.warn(
            "Detected pymc3 model context, but no point provided. Evaluating at test_point.",
            stacklevel=2,
        )
        return pmx.eval_in_model(x, model=model, point=model.test_point)
def test_jax_vs_theano_thermal():
    """Check that the JAX and Theano thermal phase curves agree to 1e-6.

    Uses HD 189733 in the IRAC 1 band, with the filter binned down by 5.
    """
    # These parameters have been chi-by-eye "fit" to the Spitzer/3.6 um PC
    f = 0.68
    planet = Planet.from_name('HD 189733')
    filt = Filter.from_name('IRAC 1')
    filt.bin_down(5)
    xi = np.linspace(-np.pi, np.pi, 100)

    # Grid of (theta, phi) points on the sphere
    n_phi, n_theta = 100, 10
    phi = np.linspace(-2 * np.pi, 2 * np.pi, n_phi)
    theta = np.linspace(0, np.pi, n_theta)
    theta2d, phi2d = np.meshgrid(theta, phi)

    # Identical positional arguments for both backends
    shared_args = (
        xi, -0.8, 4.5, 0.575, 0.18, planet.T_s, planet.a, planet.rp_a, 0,
        theta2d, phi2d, filt.wavelength.to(u.m).value, filt.transmittance, f,
    )

    # Only the phase curve is compared; the temperature output is discarded
    thermal_pc_jax, _ = therm_jax(*shared_args)
    with pm.Model():
        symbolic_pc, _ = therm_theano(*shared_args)
        thermal_pc_theano = pmx.eval_in_model(symbolic_pc)

    np.testing.assert_allclose(thermal_pc_jax, thermal_pc_theano, atol=1e-6)
def test_albedo():
    """Check that the albedo A_B recovered from the Theano map is near zero.

    The temperature map is computed with ``c11 = 0`` and ``f = 2**-0.5``.
    A_B is derived from a 2-D trapezoidal integral of T**4 * sin(theta),
    restricted to grid points with phi > 0, and must lie within 0.05 of zero.
    """
    from ..theano.theano import thermal_phase_curve

    f = 2 ** -0.5
    p = Planet.from_name('HD 189733')
    filt = Filter.from_name('IRAC 1')

    # Grid of (theta, phi) points on the sphere
    n_phi, n_theta = 100, 10
    phi = np.linspace(-2 * np.pi, 2 * np.pi, n_phi)
    theta = np.linspace(0, np.pi, n_theta)
    theta2d, phi2d = np.meshgrid(theta, phi)

    xi = np.linspace(-np.pi, np.pi, 100)

    with pm.Model():
        # Only the temperature map is needed here
        _, temperature = thermal_phase_curve(
            xi, -0.8, 4.5, 0.575, 0, p.T_s, p.a, p.rp_a, 0, theta2d, phi2d,
            filt.wavelength.to(u.m).value, filt.transmittance, f,
        )
        theano_map = pmx.eval_in_model(temperature)[..., 0, 0]

    # T^4 * sin(theta), masked to phi > 0
    integrand = (
        theano_map[..., None] ** 4
        * np.sin(theta2d[..., None])
        * (phi2d[..., None] > 0)
    )
    normalized = trapz2d(integrand, phi, theta) / (np.pi * p.T_s ** 4)
    A_B = 1 - p.a ** 2 * normalized[0]

    assert abs(A_B) < 5e-2
def test_jax_vs_theano_reflected(omega, delta_phi):
    """Check that the JAX and Theano reflected-light curves agree.

    Both backends receive 100 phases in [0, 1], ``omega``, ``delta_phi`` and
    100. The first returned element of each is compared with the default
    ``assert_allclose`` tolerances.
    """
    phases = np.linspace(0, 1, 100)
    call_args = (phases, omega, delta_phi, 100)

    reference = refl_jax(*call_args)[0]
    with pm.Model():
        candidate = pmx.eval_in_model(refl_theano(*call_args)[0])

    np.testing.assert_allclose(reference, candidate)
def evaluator(**kwargs):
    """
    Return a function to evaluate theano tensors.

    Works inside a `pymc3` model if a `point` is provided. Lazily imports
    `pymc3` to minimize overhead.

    Keyword Args:
        point: If given, symbolic values are evaluated at this point with
            ``pmx.eval_in_model``.
        model: The pymc3 model to evaluate in. Defaults to the current model
            context, which is looked up when needed.

    Other keyword arguments are ignored.

    Returns:
        A callable ``get_val(x)``. It returns values without an ``eval``
        method (e.g. plain numbers or numpy arrays) unchanged and evaluates
        symbolic ones. Without a ``point``, it falls back to the model's
        ``test_point`` (with a warning) when direct evaluation is missing
        inputs. It raises ``ValueError`` if no model can be found.
    """
    # Store the kwargs
    kwargs_point = kwargs.pop("point", None)
    kwargs_model = kwargs.pop("model", None)

    if kwargs_point is not None:
        # User provided a point
        import pymc3 as pm
        import pymc3_ext as pmx

        point = kwargs_point
        model = kwargs_model
        if model is None:
            model = pm.Model.get_context()

        def get_val(x):
            # Non-symbolic values need no evaluation
            if not hasattr(x, "eval"):
                return x
            return pmx.eval_in_model(x, model=model, point=point)

    else:
        # No point provided
        def get_val(x):
            # Non-symbolic values need no evaluation; previously these
            # raised AttributeError on ``x.eval()``.
            if not hasattr(x, "eval"):
                return x
            try:
                # Try to directly evaluate it
                return x.eval()
            except MissingInputError as e:
                # That didn't work. Perhaps we are in a pymc3 model
                # context, but the user didn't provide a point?
                import pymc3 as pm
                import pymc3_ext as pmx

                try:
                    model = kwargs_model
                    if model is None:
                        model = pm.Model.get_context()
                except TypeError:
                    # ``get_context`` raises TypeError when no model is on
                    # the context stack; keep the original error as cause.
                    raise ValueError(
                        "Missing input for variable {}, and no pymc3 model found.".format(
                            x
                        )
                    ) from e
                # Warn the user that we're using the test point
                warnings.warn(
                    "Detected pymc3 model context, but no point provided. Evaluating at test_point.",
                    stacklevel=2,
                )
                return pmx.eval_in_model(
                    x, model=model, point=model.test_point
                )

    return get_val
def get_val(x):
    """Evaluate tensor ``x`` at ``point`` in ``model``; pass other values through.

    ``model``, ``point``, ``pmx`` and ``is_tensor`` are not defined in this
    function. They are expected to come from an enclosing scope.
    """
    # Guard clause: non-tensors are returned untouched
    if not is_tensor(x):
        return x
    return pmx.eval_in_model(x, model=model, point=point)