Example #1
def run_sampling(pm_model,
                 output_dir,
                 ncores=1,
                 nchains=2,
                 max_attempts=2,
                 filename="trace"):
    # Log file output
    logging.basicConfig(
        filename=output_dir + "/sampling.log",
        filemode="w",
        format="%(name)s - %(levelname)s - %(message)s",
    )

    # Sample the model
    divperc = 20

    with pm_model:
        # Run initial chain
        try:
            trace = pm.sample(
                tune=1000,
                draws=4000,
                cores=ncores,
                chains=nchains,
                step=xo.get_dense_nuts_step(),
            )

        except pm.exceptions.SamplingError:
            logging.error("Sampling failed, model misspecified")
            return None

        # Check for divergences, restart sampling if necessary
        divergent = trace["diverging"]
        divperc = divergent.nonzero()[0].size / len(trace) * 100

        n_attempts = 1
        while divperc > 15.0 and n_attempts <= max_attempts:
            # Rerun sampling with longer tuning and more draws
            trace = pm.sample(
                tune=2000,
                draws=n_attempts * 10000,
                cores=ncores,
                chains=nchains,
                step=xo.get_dense_nuts_step(target_accept=0.9),
            )

            # Recompute the fraction of divergent samples before retrying
            divergent = trace["diverging"]
            divperc = divergent.nonzero()[0].size / len(trace) * 100

            n_attempts += 1

        if divperc > 15:
            logging.warning(f"{divperc} of samples are diverging.")
            df = pm.trace_to_dataframe(trace, include_transformed=True)
            df.to_csv(output_dir + f"/{filename}.csv")
            return None

        else:
            df = pm.trace_to_dataframe(trace, include_transformed=True)
            df.to_csv(output_dir + f"/{filename}.csv")

    return trace
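
A minimal usage sketch (added here, not part of the scraped example), assuming `pm_model` is an already-built PyMC3 model and the output directory exists:

trace = run_sampling(pm_model, "./results", ncores=2, nchains=4)
if trace is None:
    print("Sampling failed or kept diverging; see ./results/sampling.log")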
Example #2
def fit_model(samples, samples_logp, output_dir):
    # Convert samples to theano.tensor
    samples_tensor = T.as_tensor_variable(samples)  # for performance reasons
    samples_logp_tensor = T.as_tensor_variable(samples_logp)

    model = HierarchicalModel(samples_tensor, samples_logp_tensor)

    with model:
        # Print the model's starting test point
        print(model.test_point)

        # Run sampling
        trace = pm.sample(tune=100,
                          draws=1000,
                          cores=4,
                          step=xo.get_dense_nuts_step())

    # Save the samples to disk
    samples_hyper = np.stack((
        trace["lam_ln_A0"],
        trace["mu_ln_delta_t0"],
        trace["sig_ln_delta_t0"],
        trace["mu_ln_tE"],
        trace["sig_ln_tE"],
        trace["alpha_f"],
        trace["beta_f"],
    )).T

    np.save(output_dir + "samples_hyper.npy", samples_hyper)
Example #3
def run_pymc3_model(pos, pos_err, proper, proper_err, mean, cov):

    M = get_tangent_basis(pos[0] * 2 * np.pi / 360, pos[1] * 2 * np.pi / 360)
    # mean, cov = get_prior()

    with pm.Model() as model:

        vxyzD = pm.MvNormal("vxyzD", mu=mean, cov=cov, shape=4)
        vxyz = pm.Deterministic("vxyz", vxyzD[:3])
        log_D = pm.Deterministic("log_D", vxyzD[3])
        D = pm.Deterministic("D", tt.exp(log_D))

        xyz = pm.Deterministic("xyz", tt_eqtogal(pos[0], pos[1], D)[:, 0])

        pm_from_v, rv_from_v = tt_get_icrs_from_galactocentric(
            xyz, vxyz, pos[0], pos[1], D, M)

        pm.Normal("proper_motion",
                  mu=pm_from_v,
                  sigma=np.array(proper_err),
                  observed=np.array(proper))
        pm.Normal("parallax", mu=1. / D, sigma=pos_err[2], observed=pos[2])

        map_soln = xo.optimize()
        trace = pm.sample(tune=1500,
                          draws=1000,
                          start=map_soln,
                          step=xo.get_dense_nuts_step(target_accept=0.9))

    return trace
Example #4
    def predict(self, model_type=None):
        """
        Predict the period of stellar variability via Gaussian Process fitting
        
        Returns all samples of paramters after mcmc fitting
        """
        if model_type == None:
            if self.lctype == "rotation":
                model, map_soln = self.rotation_model()
            elif self.lctype == "granulation":
                model, map_soln = self.granulation_model()
            elif self.lctype == "hybrid":
                model, map_soln = self.hybrid_model()
        else:
            if model_type == "rotation":
                model, map_soln = self.rotation_model()
            elif model_type == "granulation":
                model, map_soln = self.granulation_model()
            elif model_type == "hybrid":
                model, map_soln = self.hybrid_model()

        np.random.seed(42)
        with model:
            trace = pm.sample(tune=2000,
                              draws=2000,
                              start=map_soln,
                              step=xo.get_dense_nuts_step(target_accept=0.99),
                              progressbar=True)

        self._trace = trace
        return trace
Example #5
    def fit_ttv(self, n, run_MCMC=False, ttv_start=None, verbose=True):
        """
        Fit a single transit with a transit timing variation, using the shape given by best-fit orbital parameters.
        """
        if verbose: print("Fitting ttv for transit number", n)

        # Get the transit lightcurve
        transit = self.lightcurve.get_transit(n, self.p_ref, self.t0_ref)
        t = transit.time * self.p_ref
        y = transit.flux
        sd = transit.flux_err

        if ttv_start is None:
            ttv_start = np.median(self.pars['ttvs'])

        with pm.Model() as model:
            ttv = pm.Normal("ttv", mu=ttv_start, sd=0.025)  # sd = 36 minutes

            orbit = xo.orbits.KeplerianOrbit(period=self.p_ref,
                                             t0=ttv,
                                             b=self.pars['b'])

            light_curves = xo.LimbDarkLightCurve(
                self.pars['u']).get_light_curve(orbit=orbit,
                                                r=self.pars['r'],
                                                t=t)
            light_curve = pm.math.sum(light_curves, axis=-1) + 1
            pm.Deterministic("transit_" + str(n), light_curve)

            pm.Normal("obs", mu=light_curve, sd=sd, observed=y)

            map_soln = xo.optimize(start=model.test_point,
                                   verbose=False,
                                   progress_bar=False)

        self.pars['ttvs'][n] = float(map_soln['ttv'])
        if verbose: print(f"\t ttv {n} = {self.pars['ttvs'][n]}")

        if run_MCMC:
            np.random.seed(42)
            with model:
                trace = pm.sample(
                    tune=500,
                    draws=500,
                    start=map_soln,
                    cores=1,
                    chains=2,
                    step=xo.get_dense_nuts_step(target_accept=0.9),
                )

            self.pars['ttvs'][n] = np.median(trace['ttv'])
            self.pars['e_ttvs'][n] = self.pars['ttvs'][n] - np.percentile(
                trace['ttv'], 16, axis=0)
            self.pars['E_ttvs'][n] = -self.pars['ttvs'][n] + np.percentile(
                trace['ttv'], 84, axis=0)

            if verbose:
                print(
                    f"\t ttv {n} = {self.pars['ttvs'][n]} /+ {self.pars['E_ttvs'][n]} /- {self.pars['e_ttvs'][n]}"
                )
Example #6
def sample_from_model(model,
                      map_soln,
                      tune=500,
                      draws=200,
                      chains=5,
                      cores=None,
                      step=None):
    """
    Sample from the transit light curve model.

    Parameters
    ----------
    model : `~pymc3.model`
        A model object.

    map_soln : dict
        A dictionary with the maximum a posteriori estimates of the variables.

    tune : int, optional
        The number of iterations to tune.

    draws : int, optional
        The number of samples to draw.

    chains : int, optional
        The number of chains to sample.

    cores : int, optional
        The number of cores to run in parallel.

    step : function, optional
        A step function.

    Returns
    -------
    trace : `~pymc3.backends.base.MultiTrace`
        A ``MultiTrace`` object that contains the samples.
    """
    # Use 1 CPU thread per chain, unless specified otherwise
    if cores is None:
        cores = min(chains, mp.cpu_count())

    # Ignore FutureWarnings
    warnings.simplefilter('ignore', FutureWarning)

    with model:
        if step is None:
            step = xo.get_dense_nuts_step(target_accept=0.95)

        trace = pm.sample(tune=tune,
                          draws=draws,
                          start=map_soln,
                          chains=chains,
                          cores=cores,
                          step=step)

    # Reset warnings
    warnings.resetwarnings()

    return trace
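
A short usage sketch (assumed, not from the original source); here `model` and `map_soln` stand for a PyMC3 model and its MAP solution obtained elsewhere, for example via `xo.optimize`:

trace = sample_from_model(model, map_soln, tune=1000, draws=500, chains=4)
print(pm.summary(trace))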
Example #7
def run_mcmc_1d(t, data, logS0_init, 
                logw0_init, logQ_init, 
                logsig_init, t0_init, 
                r_init, d_init, tin_init):
    
    with pm.Model() as model:
        #logsig = pm.Uniform("logsig", lower=-20.0, upper=0.0, testval=logsig_init)

        # The parameters of the SHOTerm kernel
        #logS0 = pm.Uniform("logS0", lower=-50.0, upper=0.0, testval=logS0_init)
        #logQ = pm.Uniform("logQ", lower=-50.0, upper=20.0, testval=logQ_init)
        #logw0 = pm.Uniform("logw0", lower=-50.0, upper=20.0, testval=logw0_init)
        
        # The parameters for the transit mean function
        t0 = pm.Uniform("t0", lower=t[0], upper=t[-1], testval=t0_init)
        r = pm.Uniform("r", lower=0.0, upper=1.0, testval=r_init)
        d = pm.Uniform("d", lower=0.0, upper=10.0, testval=d_init)
        tin = pm.Uniform("tin", lower=0.0, upper=10.0, testval=tin_init)
            
        # Deterministics
        # mean = pm.Deterministic("mean", utils.transit(t, t0, r, d, tin))
        transit = utils.theano_transit(t, t0, r, d, tin)

        # Set up the Gaussian Process model
        kernel = xo.gp.terms.SHOTerm(
            log_S0=logS0_init,
            log_w0=logw0_init,
            log_Q=logQ_init,
        )
    
        diag = np.exp(2*logsig_init)*tt.ones((1, len(t)))
        gp = GP(kernel, t, diag, J=2)

        # Compute the Gaussian Process likelihood and add it into the
        # the PyMC3 model as a "potential"
        pm.Potential("loglike", gp.log_likelihood(data - transit))

        # Compute the mean model prediction for plotting purposes
        #pm.Deterministic("mu", gp.predict())
        map_soln = xo.optimize(start=model.test_point, verbose=False)
        
    with model:
        map_soln = xo.optimize(start=model.test_point)
        
    with model:
        trace = pm.sample(
            tune=500,
            draws=500,
            start=map_soln,
            cores=2,
            chains=2,
            step=xo.get_dense_nuts_step(target_accept=0.9),
        )
    return trace
Example #8
    def fit(self, map_soln, model):
        with model:
            trace = pm.sample(
                tune=2000,
                draws=2000,
                start=map_soln,
                chains=4,
                step=xo.get_dense_nuts_step(target_accept=0.9),
            )
        trace_summary = pm.summary(
            trace, round_to='none'
        )  # Not a typo. PyMC3 wants 'none' as a string here.
        epoch = round(
            trace_summary['mean']['Transit epoch (BTJD)'],
            3)  # Round the epoch differently, as BTJD needs more digits.
        trace_summary['mean'] = self.round_series_to_significant_figures(
            trace_summary['mean'], 5)
        trace_summary['mean']['Transit epoch (BTJD)'] = epoch
        self.bokeh_document.add_next_tick_callback(
            partial(self.update_parameters_table, trace_summary))
        with pd.option_context('display.max_columns', None,
                               'display.max_rows', None):
            print(trace_summary)
            print(f'Star radius: {self.star_radius}')
Example #9
def build_model(mask=None, start=None):

    with pm.Model() as model:

        # The baseline flux
        mean = pm.Normal("mean", mu=0.0, sd=0.00001)

        # The time of a reference transit for each planet
        t0 = pm.Normal("t0", mu=t0s, sd=1.0, shape=1)

        # The log period; also tracking the period itself
        logP = pm.Normal("logP", mu=np.log(periods), sd=0.01, shape=1)

        rho_star = pm.Normal("rho_star", mu=0.14, sd=0.01, shape=1)
        r_star = pm.Normal("r_star", mu=2.7, sd=0.01, shape=1)

        period = pm.Deterministic("period", pm.math.exp(logP))

        # The Kipping (2013) parameterization for quadratic limb-darkening parameters
        u = xo.distributions.QuadLimbDark("u", testval=np.array([0.3, 0.2]))

        r = pm.Uniform(
            "r", lower=0.01, upper=0.3, shape=1, testval=0.15)

        b = xo.distributions.ImpactParameter(
            "b", ror=r, shape=1, testval=0.5)

        # Transit jitter & GP parameters
        logs2 = pm.Normal("logs2", mu=np.log(np.var(y)), sd=10)
        logw0 = pm.Normal("logw0", mu=0, sd=10)
        logSw4 = pm.Normal("logSw4", mu=np.log(np.var(y)), sd=10)

        # Set up a Keplerian orbit for the planets
        orbit = xo.orbits.KeplerianOrbit(
            period=period, t0=t0, b=b, rho_star=rho_star, r_star=r_star)

        # Compute the model light curve using starry
        light_curves = xo.LimbDarkLightCurve(u).get_light_curve(
            orbit=orbit, r=r, t=t
        )
        light_curve = pm.math.sum(light_curves, axis=-1) + mean

        # Here we track the value of the model light curve for plotting
        # purposes
        pm.Deterministic("light_curves", light_curves)

        kernel = xo.gp.terms.SHOTerm(log_Sw4=logSw4, log_w0=logw0, Q=1 / np.sqrt(2))
        gp = xo.gp.GP(kernel, t, tt.exp(logs2) + tt.zeros(len(t)), mean=light_curve)
        gp.marginal("gp", observed=y)
        pm.Deterministic("gp_pred", gp.predict())

        # The likelihood function assuming known Gaussian uncertainty
        pm.Normal("obs", mu=light_curve, sd=yerr, observed=y)

        # Fit for the maximum a posteriori parameters given the simulated
        # dataset
        map_soln = xo.optimize(start=model.test_point)

    return model, map_soln
	
model, map_soln = build_model()

gp_mod = map_soln["gp_pred"] + map_soln["mean"]
plt.clf()
plt.plot(t, y, ".k", ms=4, label="data")
plt.plot(t, gp_mod, lw=1,label="gp model")
plt.plot(t, map_soln["light_curves"], lw=1,label="transit model")
plt.xlim(t.min(), t.max())
plt.ylabel("relative flux")
plt.xlabel("time [days]")
plt.legend(fontsize=10)
_ = plt.title("map model")

np.random.seed(42)
with model:
    trace = pm.sample(
        tune=3000,
        draws=3000,
        start=map_soln,
        cores=2,
        chains=2,
        step=xo.get_dense_nuts_step(target_accept=0.9),
    )
    
    
pm.summary(trace, varnames=["period", "t0", "r", "b", "u", "mean", "rho_star","logw0","logSw4","logs2"])


import corner

samples = pm.trace_to_dataframe(trace, varnames=["period", "r"])
truth = np.concatenate(
    xo.eval_in_model([period, r], model.test_point, model=model)
)
_ = corner.corner(
    samples,
    truths=truth,
    labels=["period 1", "radius 1"],
)


# Compute the GP prediction
gp_mod = np.median(trace["gp_pred"] + trace["mean"][:, None], axis=0)

# Get the posterior median orbital parameters
p = np.median(trace["period"])
t0 = np.median(trace["t0"])

# Plot the folded data
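# Phase-fold the time axis: x_fold is each time's offset from the nearest
# transit (in days), wrapped into the range [-p/2, p/2).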
x_fold = (t - t0 + 0.5 * p) % p - 0.5 * p
plt.plot(x_fold, y - gp_mod, ".k", label="data", zorder=-1000)

# Overplot the phase binned light curve
bins = np.linspace(-0.41, 0.41, 50)
denom, _ = np.histogram(x_fold, bins)
num, _ = np.histogram(x_fold, bins, weights=y)
denom[num == 0] = 1.0
plt.plot(0.5 * (bins[1:] + bins[:-1]), num / denom, "o", color="C2", label="binned")

# Plot the folded model
inds = np.argsort(x_fold)
inds = inds[np.abs(x_fold)[inds] < 0.3]
pred = trace["light_curves"][:, inds, 0]
pred = np.percentile(pred, [16, 50, 84], axis=0)
plt.plot(x_fold[inds], pred[1], color="C1", label="model")
art = plt.fill_between(
    x_fold[inds], pred[0], pred[2], color="C1", alpha=0.5, zorder=1000
)
art.set_edgecolor("none")

# Annotate the plot with the planet's period
txt = "period = {0:.5f} +/- {1:.5f} d".format(
    np.mean(trace["period"]), np.std(trace["period"])
)
plt.annotate(
    txt,
    (0, 0),
    xycoords="axes fraction",
    xytext=(5, 5),
    textcoords="offset points",
    ha="left",
    va="bottom",
    fontsize=12,
)

plt.legend(fontsize=10, loc=4)
plt.xlim(-0.5 * p, 0.5 * p)
plt.xlabel("time since transit [days]")
plt.ylabel("de-trended flux")
plt.xlim(-0.3, 0.3);
Example #10
    def run_onetransit_inference(self, prior_d, pklpath, make_threadsafe=True):
        """
        Similar to "run_transit_inference", but with more restrictive priors on
        ephemeris. Also, it simultaneously fits for quadratic trend.
        """

        # if the model has already been run, pull the result from the
        # pickle. otherwise, run it.
        if os.path.exists(pklpath):
            d = pickle.load(open(pklpath, 'rb'))
            self.model = d['model']
            self.trace = d['trace']
            self.map_estimate = d['map_estimate']
            return 1

        with pm.Model() as model:

            assert len(self.data.keys()) == 1

            name = list(self.data.keys())[0]
            x_obs = list(self.data.values())[0][0]
            y_obs = list(self.data.values())[0][1]
            y_err = list(self.data.values())[0][2]
            t_exp = list(self.data.values())[0][3]

            # Fixed data errors.
            sigma = y_err

            # Define priors and PyMC3 random variables to sample over.

            # Stellar parameters. (Following tess.world notebooks).
            logg_star = pm.Normal("logg_star", mu=LOGG, sd=LOGG_STDEV)
            r_star = pm.Bound(pm.Normal, lower=0.0)("r_star",
                                                    mu=RSTAR,
                                                    sd=RSTAR_STDEV)
            rho_star = pm.Deterministic("rho_star",
                                        factor * 10**logg_star / r_star)

            # Transit parameters.
            t0 = pm.Normal("t0",
                           mu=prior_d['t0'],
                           sd=1e-3,
                           testval=prior_d['t0'])
            period = pm.Normal('period',
                               mu=prior_d['period'],
                               sd=3e-4,
                               testval=prior_d['period'])

            # NOTE: might want to implement kwarg for flexibility
            # u = xo.distributions.QuadLimbDark(
            #     "u", testval=prior_d['u']
            # )

            u0 = pm.Uniform('u[0]',
                            lower=prior_d['u[0]'] - 0.15,
                            upper=prior_d['u[0]'] + 0.15,
                            testval=prior_d['u[0]'])
            u1 = pm.Uniform('u[1]',
                            lower=prior_d['u[1]'] - 0.15,
                            upper=prior_d['u[1]'] + 0.15,
                            testval=prior_d['u[1]'])
            u = [u0, u1]

            # # The Espinoza (2018) parameterization for the joint radius ratio and
            # # impact parameter distribution
            # r, b = xo.distributions.get_joint_radius_impact(
            #     min_radius=0.001, max_radius=1.0,
            #     testval_r=prior_d['r'],
            #     testval_b=prior_d['b']
            # )
            # # NOTE: apparently, it's been deprecated. DFM's manuscript notes
            # # that it leads to Rp/Rs values biased high.

            log_r = pm.Uniform('log_r',
                               lower=np.log(1e-2),
                               upper=np.log(1),
                               testval=prior_d['log_r'])
            r = pm.Deterministic('r', tt.exp(log_r))

            b = xo.distributions.ImpactParameter("b",
                                                 ror=r,
                                                 testval=prior_d['b'])

            # the transit
            orbit = xo.orbits.KeplerianOrbit(period=period,
                                             t0=t0,
                                             b=b,
                                             rho_star=rho_star)

            transit_lc = pm.Deterministic(
                'transit_lc',
                xo.LimbDarkLightCurve(u).get_light_curve(
                    orbit=orbit, r=r, t=x_obs, texp=t_exp).T.flatten())

            # quadratic trend parameters
            mean = pm.Normal(f"{name}_mean",
                             mu=prior_d[f'{name}_mean'],
                             sd=1e-2,
                             testval=prior_d[f'{name}_mean'])
            a1 = pm.Normal(f"{name}_a1",
                           mu=prior_d[f'{name}_a1'],
                           sd=1,
                           testval=prior_d[f'{name}_a1'])
            a2 = pm.Normal(f"{name}_a2",
                           mu=prior_d[f'{name}_a2'],
                           sd=1,
                           testval=prior_d[f'{name}_a2'])

            _tmid = np.nanmedian(x_obs)
            lc_model = pm.Deterministic(
                'mu_transit', mean + a1 * (x_obs - _tmid) + a2 *
                (x_obs - _tmid)**2 + transit_lc)

            roughdepth = pm.Deterministic('roughdepth',
                                          pm.math.abs_(transit_lc).max())

            #
            # Derived parameters
            #

            # planet radius in jupiter radii
            r_planet = pm.Deterministic(
                "r_planet",
                (r * r_star) * (1 * units.Rsun / (1 * units.Rjup)).cgs.value)

            #
            # eq 30 of winn+2010, ignoring planet density.
            #
            a_Rs = pm.Deterministic("a_Rs", (rho_star * period**2)**(1 / 3) *
                                    (((1 * units.gram / (1 * units.cm)**3) *
                                      (1 * units.day**2) * const.G /
                                      (3 * np.pi))**(1 / 3)).cgs.value)

            #
            # cosi. assumes e=0 (e.g., Winn+2010 eq 7)
            #
            cosi = pm.Deterministic("cosi", b / a_Rs)

            # safer than tt.arccos(cosi)
            sini = pm.Deterministic("sini", pm.math.sqrt(1 - cosi**2))

            #
            # transit durations (T_14, T_13) for circular orbits. Winn+2010 Eq 14, 15.
            # units: hours.
            #
            T_14 = pm.Deterministic('T_14', (period / np.pi) * tt.arcsin(
                (1 / a_Rs) * pm.math.sqrt((1 + r)**2 - b**2) * (1 / sini)) *
                                    24)

            T_13 = pm.Deterministic('T_13', (period / np.pi) * tt.arcsin(
                (1 / a_Rs) * pm.math.sqrt((1 - r)**2 - b**2) * (1 / sini)) *
                                    24)

            #
            # mean model and likelihood
            #

            # mean_model = mu_transit + mean
            # mu_model = pm.Deterministic('mu_model', lc_model)

            likelihood = pm.Normal('obs',
                                   mu=lc_model,
                                   sigma=sigma,
                                   observed=y_obs)

            # Optimizing
            map_estimate = pm.find_MAP(model=model)

            # start = model.test_point
            # if 'transit' in self.modelcomponents:
            #     map_estimate = xo.optimize(start=start,
            #                                vars=[r, b, period, t0])
            # map_estimate = xo.optimize(start=map_estimate)

            if make_threadsafe:
                pass
            else:
                # as described in
                # https://github.com/matplotlib/matplotlib/issues/15410
                # matplotlib is not threadsafe, so do not make plots before
                # sampling: a child process may try to close a cached file
                # and crash the sampler.
                print(map_estimate)

            # sample from the posterior defined by this model.
            trace = pm.sample(
                tune=self.N_samples,
                draws=self.N_samples,
                start=map_estimate,
                cores=self.N_cores,
                chains=self.N_chains,
                step=xo.get_dense_nuts_step(target_accept=0.8),
            )

        with open(pklpath, 'wb') as buff:
            pickle.dump(
                {
                    'model': model,
                    'trace': trace,
                    'map_estimate': map_estimate
                }, buff)

        self.model = model
        self.trace = trace
        self.map_estimate = map_estimate
Example #11
    def run_alltransit_inference(self, prior_d, pklpath, make_threadsafe=True):

        # if the model has already been run, pull the result from the
        # pickle. otherwise, run it.
        if os.path.exists(pklpath):
            d = pickle.load(open(pklpath, 'rb'))
            self.model = d['model']
            self.trace = d['trace']
            self.map_estimate = d['map_estimate']
            return 1

        with pm.Model() as model:

            # Shared parameters

            # Stellar parameters. (Following tess.world notebooks).
            logg_star = pm.Normal("logg_star", mu=LOGG, sd=LOGG_STDEV)
            r_star = pm.Bound(pm.Normal, lower=0.0)("r_star",
                                                    mu=RSTAR,
                                                    sd=RSTAR_STDEV)
            rho_star = pm.Deterministic("rho_star",
                                        factor * 10**logg_star / r_star)

            # fix Rp/Rs across bandpasses, b/c you're assuming it's a planet
            if 'quaddepthvar' not in self.modelid:
                log_r = pm.Uniform('log_r',
                                   lower=np.log(1e-2),
                                   upper=np.log(1),
                                   testval=prior_d['log_r'])
                r = pm.Deterministic('r', tt.exp(log_r))
            else:

                log_r_Tband = pm.Uniform('log_r_Tband',
                                         lower=np.log(1e-2),
                                         upper=np.log(1),
                                         testval=prior_d['log_r_Tband'])
                r_Tband = pm.Deterministic('r_Tband', tt.exp(log_r_Tband))

                log_r_Rband = pm.Uniform('log_r_Rband',
                                         lower=np.log(1e-2),
                                         upper=np.log(1),
                                         testval=prior_d['log_r_Rband'])
                r_Rband = pm.Deterministic('r_Rband', tt.exp(log_r_Rband))

                log_r_Bband = pm.Uniform('log_r_Bband',
                                         lower=np.log(1e-2),
                                         upper=np.log(1),
                                         testval=prior_d['log_r_Bband'])
                r_Bband = pm.Deterministic('r_Bband', tt.exp(log_r_Bband))

                r = r_Tband

            # Some orbital parameters
            t0 = pm.Normal("t0",
                           mu=prior_d['t0'],
                           sd=5e-3,
                           testval=prior_d['t0'])
            period = pm.Normal('period',
                               mu=prior_d['period'],
                               sd=5e-3,
                               testval=prior_d['period'])
            b = xo.distributions.ImpactParameter("b",
                                                 ror=r,
                                                 testval=prior_d['b'])
            orbit = xo.orbits.KeplerianOrbit(period=period,
                                             t0=t0,
                                             b=b,
                                             rho_star=rho_star)

            # NOTE: limb-darkening should be bandpass specific, but we don't
            # have the SNR to justify that, so go with TESS-dominated
            u0 = pm.Uniform('u[0]',
                            lower=prior_d['u[0]'] - 0.15,
                            upper=prior_d['u[0]'] + 0.15,
                            testval=prior_d['u[0]'])
            u1 = pm.Uniform('u[1]',
                            lower=prior_d['u[1]'] - 0.15,
                            upper=prior_d['u[1]'] + 0.15,
                            testval=prior_d['u[1]'])
            u = [u0, u1]

            star = xo.LimbDarkLightCurve(u)

            # Loop over "instruments" (TESS, then each ground-based lightcurve)
            parameters = dict()
            lc_models = dict()
            roughdepths = dict()

            for n, (name, (x, y, yerr, texp)) in enumerate(self.data.items()):

                # Define per-instrument parameters in a submodel, to not need
                # to prefix the names. Yields e.g., "TESS_mean",
                # "elsauce_0_mean", "elsauce_2_a2"
                with pm.Model(name=name, model=model):

                    # Transit parameters.
                    mean = pm.Normal("mean",
                                     mu=prior_d[f'{name}_mean'],
                                     sd=1e-2,
                                     testval=prior_d[f'{name}_mean'])

                    if 'quad' in self.modelid:

                        if name != 'tess':

                            # units: rel flux per day.
                            a1 = pm.Normal("a1",
                                           mu=prior_d[f'{name}_a1'],
                                           sd=1,
                                           testval=prior_d[f'{name}_a1'])
                            # units: rel flux per day^2.
                            a2 = pm.Normal("a2",
                                           mu=prior_d[f'{name}_a2'],
                                           sd=1,
                                           testval=prior_d[f'{name}_a2'])

                if self.modelid == 'alltransit':
                    lc_models[name] = pm.Deterministic(
                        f'{name}_mu_transit', mean + star.get_light_curve(
                            orbit=orbit, r=r, t=x, texp=texp).T.flatten())

                elif self.modelid == 'alltransit_quad':

                    if name != 'tess':
                        # midpoint for this definition of the quadratic trend
                        _tmid = np.nanmedian(x)

                        lc_models[name] = pm.Deterministic(
                            f'{name}_mu_transit', mean + a1 * (x - _tmid) +
                            a2 * (x - _tmid)**2 + star.get_light_curve(
                                orbit=orbit, r=r, t=x, texp=texp).T.flatten())
                    elif name == 'tess':

                        lc_models[name] = pm.Deterministic(
                            f'{name}_mu_transit', mean + star.get_light_curve(
                                orbit=orbit, r=r, t=x, texp=texp).T.flatten())

                elif self.modelid == 'alltransit_quaddepthvar':

                    if name != 'tess':
                        # midpoint for this definition of the quadratic trend
                        _tmid = np.nanmedian(x)

                        # do custom depth-to-
                        if (name == 'elsauce_20200401'
                                or name == 'elsauce_20200426'):
                            r = r_Rband
                        elif name == 'elsauce_20200521':
                            r = r_Tband
                        elif name == 'elsauce_20200614':
                            r = r_Bband

                        transit_lc = star.get_light_curve(
                            orbit=orbit, r=r, t=x, texp=texp).T.flatten()

                        lc_models[name] = pm.Deterministic(
                            f'{name}_mu_transit', mean + a1 * (x - _tmid) +
                            a2 * (x - _tmid)**2 + transit_lc)

                        roughdepths[name] = pm.Deterministic(
                            f'{name}_roughdepth',
                            pm.math.abs_(transit_lc).max())

                    elif name == 'tess':

                        r = r_Tband

                        transit_lc = star.get_light_curve(
                            orbit=orbit, r=r, t=x, texp=texp).T.flatten()

                        lc_models[name] = pm.Deterministic(
                            f'{name}_mu_transit', mean + transit_lc)

                        roughdepths[name] = pm.Deterministic(
                            f'{name}_roughdepth',
                            pm.math.abs_(transit_lc).max())

                # TODO: add error bar fudge
                likelihood = pm.Normal(f'{name}_obs',
                                       mu=lc_models[name],
                                       sigma=yerr,
                                       observed=y)

            #
            # Derived parameters
            #
            if self.modelid == 'alltransit_quaddepthvar':
                r = r_Tband

            # planet radius in jupiter radii
            r_planet = pm.Deterministic(
                "r_planet",
                (r * r_star) * (1 * units.Rsun / (1 * units.Rjup)).cgs.value)

            #
            # eq 30 of winn+2010, ignoring planet density.
            #
            a_Rs = pm.Deterministic("a_Rs", (rho_star * period**2)**(1 / 3) *
                                    (((1 * units.gram / (1 * units.cm)**3) *
                                      (1 * units.day**2) * const.G /
                                      (3 * np.pi))**(1 / 3)).cgs.value)

            #
            # cosi. assumes e=0 (e.g., Winn+2010 eq 7)
            #
            cosi = pm.Deterministic("cosi", b / a_Rs)

            # probably safer than tt.arccos(cosi)
            sini = pm.Deterministic("sini", pm.math.sqrt(1 - cosi**2))

            #
            # transit durations (T_14, T_13) for circular orbits. Winn+2010 Eq 14, 15.
            # units: hours.
            #
            T_14 = pm.Deterministic('T_14', (period / np.pi) * tt.arcsin(
                (1 / a_Rs) * pm.math.sqrt((1 + r)**2 - b**2) * (1 / sini)) *
                                    24)

            T_13 = pm.Deterministic('T_13', (period / np.pi) * tt.arcsin(
                (1 / a_Rs) * pm.math.sqrt((1 - r)**2 - b**2) * (1 / sini)) *
                                    24)

            # Optimizing
            map_estimate = pm.find_MAP(model=model)

            # start = model.test_point
            # if 'transit' in self.modelcomponents:
            #     map_estimate = xo.optimize(start=start,
            #                                vars=[r, b, period, t0])
            # map_estimate = xo.optimize(start=map_estimate)

            if make_threadsafe:
                pass
            else:
                # NOTE: would usually plot MAP estimate here, but really
                # there's not a huge need.
                print(map_estimate)
                pass

            # sample from the posterior defined by this model.
            trace = pm.sample(
                tune=self.N_samples,
                draws=self.N_samples,
                start=map_estimate,
                cores=self.N_cores,
                chains=self.N_chains,
                step=xo.get_dense_nuts_step(target_accept=0.8),
            )

        with open(pklpath, 'wb') as buff:
            pickle.dump(
                {
                    'model': model,
                    'trace': trace,
                    'map_estimate': map_estimate
                }, buff)

        self.model = model
        self.trace = trace
        self.map_estimate = map_estimate
Example #12
ax[3].set_xlim(t_fine[0], t_fine[-1])
_ = ax[0].set_title("map orbit")

# %% [markdown]
# Now let's sample the posterior.

# %%
np.random.seed(1234)
with model:
    trace = pm.sample(
        tune=5000,
        draws=4000,
        start=map_soln,
        cores=2,
        chains=2,
        step=xo.get_dense_nuts_step(target_accept=0.9, adaptation_window=201),
    )

# %% [markdown]
# First we can check the convergence for some of the key parameters.

# %%
pm.summary(
    trace, varnames=["P", "tperi", "a_ang", "omega", "Omega", "incl", "ecc"]
)

# %% [markdown]
# That looks pretty good.
# Now here's a corner plot showing the covariances between parameters.

# %%
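# A hedged sketch (added here; the original example is cut off at this cell)
# of what the corner plot might look like, using parameter names taken from
# the pm.summary call above. The variable list in the original may differ.
import corner

samples = pm.trace_to_dataframe(trace, varnames=["P", "ecc", "omega", "incl"])
_ = corner.corner(samples)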
Example #13
    plt.xlabel("t")
    plt.ylabel("y")
    plt.xlim(0, 10)
    _ = plt.ylim(-2.5, 2.5)
    plt.savefig('../results/test_results/gp_model/prediction_uncert.png',
                dpi=200)
    plt.close('all')

    with model:
        trace = pm.sample(
            tune=2000,
            draws=2000,
            start=map_soln,
            cores=2,
            chains=2,
            step=xo.get_dense_nuts_step(target_accept=0.9),
        )

    print(pm.summary(trace))

    # truth = np.concatenate(
    #     xo.eval_in_model([period, r], model.test_point, model=model)
    # )
    fig = corner.corner(
        samples,
        # truths=truth,
        labels=["S1", "S2", "w1", "w2", "log_Q"],
    )
    fig.savefig('../results/test_results/gp_model/test_gp_corner.png')
    plt.close('all')
Example #14
def worker(task):
    (i1, i2), data, model_kw, basename = task

    g = GaiaData(data)

    cache_filename = os.path.abspath(f'../cache/tmp-{basename}_{i1}-{i2}.fits')
    if os.path.exists(cache_filename):
        print(f"({pid}) cache filename exists for index range: "
              f"{cache_filename}")
        return cache_filename

    print(f"({pid}) setting up model")
    helper = ComovingHelper(g)

    niter = 0
    while niter < 10:
        try:
            model = helper.get_model(**model_kw)
            break
        except OSError:
            print(f"{pid} failed to compile - trying again in 2sec...")
            time.sleep(5)
            niter += 1
            continue
    else:
        print(f"{pid} never successfully compiled. aborting")
        import socket
        print(socket.gethostname(), socket.getfqdn(),
              os.path.exists("/cm/shared/sw/pkg/devel/gcc/7.4.0/bin/g++"))
        return ''

    print(f"({pid}) done init model - running {len(g)} stars")

    probs = np.full(helper.N, np.nan)
    for n in range(helper.N):
        with model:
            pm.set_data({
                'y': helper.ys[n],
                'Cinv': helper.Cinvs[n],
                'M': helper.Ms[n]
            })

            test_pt = {
                'vxyz': helper.test_vxyz[n],
                'r': helper.test_r[n],
                'w': np.array([0.5, 0.5])
            }
            try:
                print("starting optimize")
                res = xo.optimize(start=test_pt,
                                  progress_bar=False,
                                  verbose=False)

                print("done optimize - starting sample")
                trace = pm.sample(
                    start=res,
                    tune=2000,
                    draws=1000,
                    cores=1,
                    chains=1,
                    step=xo.get_dense_nuts_step(target_accept=0.95),
                    progressbar=False)
            except Exception as e:
                print(str(e))
                continue

            # print("done sample - computing prob")
            ll_fg = trace.get_values(model.group_logp)
            ll_bg = trace.get_values(model.field_logp)
            post_prob = np.exp(ll_fg - np.logaddexp(ll_fg, ll_bg))
            probs[n] = post_prob.sum() / len(post_prob)

    # write probs to cache filename
    tbl = at.Table()
    tbl['source_id'] = g.source_id
    tbl['prob'] = probs
    tbl.write(cache_filename)

    return cache_filename
Example #15
    def fit_shape(self,
                  run_MCMC=False,
                  r_start=None,
                  b_start=None,
                  verbose=True):
        """
        Fit the orbital parameters to the shape of the folded lightcurve data.
        
        Parameters:
            run_MCMC -- Boolean to run Markov chain Monte Carlo sampling. This takes longer but yields
                        better results and gives uncertainties.
            r_start -- Starting estimate for the relative radius.
            b_start -- Starting estimate for the impact parameter.
        """

        if verbose: print("Optimising the shape of the orbital model:")

        folded_lc = self.lightcurve.fold(self.p_ref,
                                         self.t0_ref,
                                         ttvs=self.pars['ttvs'])
        t = folded_lc.time * self.p_ref
        y = folded_lc.flux
        sd = folded_lc.flux_err

        if r_start is None: r_start = 0.055
        if b_start is None: b_start = 0.5

        with pm.Model() as model:
            mean = pm.Normal("mean", mu=1.0, sd=0.1)  # Baseline flux
            t0 = pm.Normal("t0", mu=0, sd=0.025)

            u = xo.distributions.QuadLimbDark(
                "u")  # Quadratic limb-darkening parameters
            r = pm.Uniform("r", lower=0.01, upper=0.1,
                           testval=r_start)  # radius ratio
            b = xo.distributions.ImpactParameter(
                "b", ror=r, testval=b_start)  # Impact parameter

            orbit = xo.orbits.KeplerianOrbit(period=self.p_ref, t0=t0, b=b)

            # Compute the model light curve
            lc = xo.LimbDarkLightCurve(u).get_light_curve(orbit=orbit,
                                                          r=r,
                                                          t=t)
            light_curve = pm.math.sum(lc, axis=-1) + mean

            pm.Deterministic(
                "light_curve", light_curve
            )  # track the value of the model light curve for plotting purposes

            # The likelihood function
            pm.Normal("obs", mu=light_curve, sd=sd, observed=y)

            map_soln = xo.optimize(start=model.test_point,
                                   verbose=verbose,
                                   progress_bar=False)

        for k in ['mean', 't0', 'u', 'r', 'b']:
            self.pars[k] = map_soln[k]
            if verbose: print('\t', k, '=', self.pars[k])

        if run_MCMC:
            np.random.seed(42)
            with model:
                trace = pm.sample(
                    tune=3000,
                    draws=3000,
                    start=map_soln,
                    cores=1,
                    chains=2,
                    step=xo.get_dense_nuts_step(target_accept=0.9),
                )

            for k in ['mean', 't0', 'u', 'r', 'b']:
                self.pars[k] = np.median(trace[k], axis=0)
                self.pars['e_' + k] = self.pars[k] - np.percentile(
                    trace[k], 16, axis=0)
                self.pars['E_' + k] = -self.pars[k] + np.percentile(
                    trace[k], 84, axis=0)

                if verbose:
                    print(
                        f"\t{k} = {self.pars[k]} /+ {self.pars['E_'+k]} /- {self.pars['e_'+k]}"
                    )
Example #16
def main(c, prior, metadata_row, overwrite=False):
    mcmc_cache_path = os.path.join(c.run_path, 'mcmc')
    os.makedirs(mcmc_cache_path, exist_ok=True)

    apogee_id = metadata_row['APOGEE_ID']

    this_cache_path = os.path.join(mcmc_cache_path, apogee_id)
    if os.path.exists(this_cache_path) and not overwrite:
        logger.info(f"{apogee_id} already done!")
        # Assume it's already done
        return

    # Set up The Joker:
    joker = tj.TheJoker(prior)

    # Load the data:
    logger.debug(f"{apogee_id}: Loading all data")
    allstar, allvisit = c.load_alldata()
    allstar = allstar[np.isin(allstar['APOGEE_ID'].astype(str), apogee_id)]
    allvisit = allvisit[np.isin(allvisit['APOGEE_ID'].astype(str),
                                allstar['APOGEE_ID'].astype(str))]
    visits = allvisit[allvisit['APOGEE_ID'] == apogee_id]
    data = get_rvdata(visits)

    t0 = time.time()

    # Read MAP sample:
    MAP_sample = extract_MAP_sample(metadata_row)
    logger.log(1, f"{apogee_id}: MAP sample loaded")

    # Run MCMC:
    with joker.prior.model as model:
        logger.log(1, f"{apogee_id}: Setting up MCMC...")
        mcmc_init = joker.setup_mcmc(data, MAP_sample)
        logger.log(1, f"{apogee_id}: ...setup complete")

        if 'ln_prior' not in model.named_vars:
            ln_prior_var = None
            for k in joker.prior._nonlinear_equiv_units:
                var = model.named_vars[k]
                try:
                    if ln_prior_var is None:
                        ln_prior_var = var.distribution.logp(var)
                    else:
                        ln_prior_var = ln_prior_var + var.distribution.logp(
                            var)
                except Exception as e:
                    logger.warning("Cannot auto-compute log-prior value for "
                                   f"parameter {var}.")
                    print(e)
                    continue

            pm.Deterministic('ln_prior', ln_prior_var)
            logger.log(1, f"{apogee_id}: setting up ln_prior in pymc3 model")

        if 'logp' not in model.named_vars:
            pm.Deterministic('logp', model.logpt)
            logger.log(1, f"{apogee_id}: setting up logp in pymc3 model")

        logger.debug(f"{apogee_id}: Starting MCMC sampling")
        trace = pm.sample(start=mcmc_init,
                          chains=4,
                          cores=1,
                          step=xo.get_dense_nuts_step(target_accept=0.95),
                          tune=c.tune,
                          draws=c.draws)

    pm.save_trace(trace, directory=this_cache_path, overwrite=True)
    logger.debug(
        "{apogee_id}: Finished MCMC sampling ({time:.2f} seconds)".format(
            apogee_id=apogee_id, time=time.time() - t0))
Example #17
        def run_fitting():
            with pm.Model() as model:
                # Stellar parameters
                mean = pm.Normal("mean", mu=0.0, sigma=10.0 * 1e-3)
                u = xo.distributions.QuadLimbDark("u")
                star_params = [mean, u]

                # Gaussian process noise model
                sigma = pm.InverseGamma("sigma",
                                        alpha=3.0,
                                        beta=2 *
                                        np.median(self_.relative_flux_errors))
                log_Sw4 = pm.Normal("log_Sw4", mu=0.0, sigma=10.0)
                log_w0 = pm.Normal("log_w0",
                                   mu=np.log(2 * np.pi / 10.0),
                                   sigma=10.0)
                kernel = xo.gp.terms.SHOTerm(log_Sw4=log_Sw4,
                                             log_w0=log_w0,
                                             Q=1.0 / 3)
                noise_params = [sigma, log_Sw4, log_w0]

                # Planet parameters
                log_ror = pm.Normal("log_ror",
                                    mu=0.5 * np.log(self_.depth),
                                    sigma=10.0 * 1e-3)
                ror = pm.Deterministic("ror", tt.exp(log_ror))
                depth = pm.Deterministic("depth", tt.square(ror))

                # Orbital parameters
                log_period = pm.Normal("log_period",
                                       mu=np.log(self_.period),
                                       sigma=1.0)
                t0 = pm.Normal("t0", mu=self_.transit_epoch, sigma=1.0)
                log_dur = pm.Normal("log_dur", mu=np.log(0.1), sigma=10.0)
                b = xo.distributions.ImpactParameter("b", ror=ror)

                period = pm.Deterministic("period", tt.exp(log_period))
                dur = pm.Deterministic("dur", tt.exp(log_dur))

                # Set up the orbit
                orbit = xo.orbits.KeplerianOrbit(period=period,
                                                 duration=dur,
                                                 t0=t0,
                                                 b=b,
                                                 r_star=self.star_radius)

                # We're going to track the implied density for reasons that will become clear later
                pm.Deterministic("rho_circ", orbit.rho_star)

                # Set up the mean transit model
                star = xo.LimbDarkLightCurve(u)

                def lc_model(t):
                    return mean + tt.sum(star.get_light_curve(
                        orbit=orbit, r=ror * self.star_radius, t=t),
                                         axis=-1)

                # Finally the GP observation model
                gp = xo.gp.GP(kernel,
                              self_.times,
                              (self_.relative_flux_errors**2) + (sigma**2),
                              mean=lc_model)
                gp.marginal("obs", observed=self_.relative_fluxes)

                # Double check that everything looks good - we shouldn't see any NaNs!
                print(model.check_test_point())

                # Optimize the model
                map_soln = model.test_point
                map_soln = xo.optimize(map_soln, [sigma])
                map_soln = xo.optimize(map_soln, [log_ror, b, log_dur])
                map_soln = xo.optimize(map_soln, noise_params)
                map_soln = xo.optimize(map_soln, star_params)
                map_soln = xo.optimize(map_soln)

            with model:
                gp_pred, lc_pred = xo.eval_in_model(
                    [gp.predict(), lc_model(self_.times)], map_soln)

            x_fold = (self_.times - map_soln["t0"] + 0.5 * map_soln["period"]
                      ) % map_soln["period"] - 0.5 * map_soln["period"]
            inds = np.argsort(x_fold)
            initial_fit_data_source.data['Folded time (days)'] = x_fold
            initial_fit_data_source.data[
                'Relative flux'] = self_.relative_fluxes - gp_pred - map_soln[
                    "mean"]
            initial_fit_data_source.data[
                'Fit'] = lc_pred[inds] - map_soln["mean"]
            initial_fit_data_source.data['Fit time'] = x_fold[
                inds]  # TODO: This is terrible, you should be able to line them up *afterward* to not make a duplicate time column

            with model:
                trace = pm.sample(
                    tune=2000,
                    draws=2000,
                    start=map_soln,
                    chains=4,
                    step=xo.get_dense_nuts_step(target_accept=0.9),
                )

            trace_summary = pm.summary(
                trace, round_to='none'
            )  # Not a typo. PyMC3 wants 'none' as a string here.
            epoch = round(
                trace_summary['mean']['t0'],
                3)  # Round the epoch differently, as BTJD needs more digits.
            trace_summary['mean'] = self_.round_series_to_significant_figures(
                trace_summary['mean'], 5)
            trace_summary['mean']['t0'] = epoch
            parameters_table_data_source.data = trace_summary
            parameters_table_data_source.data[
                'parameter'] = trace_summary.index
            with pd.option_context('display.max_columns', None,
                                   'display.max_rows', None):
                print(trace_summary)
                print(f'Star radius: {self.star_radius}')
Example #18
cov = np.dot(L, L.T)

# %% [markdown]
# And then we can sample this using PyMC3 and :func:`exoplanet.get_dense_nuts_step`:

# %%
import pymc3 as pm
import exoplanet as xo

with pm.Model() as model:
    pm.MvNormal("x", mu=np.zeros(ndim), chol=L, shape=(ndim, ))
    trace = pm.sample(tune=2000,
                      draws=2000,
                      chains=2,
                      cores=2,
                      step=xo.get_dense_nuts_step())

# %% [markdown]
# This is a little more verbose than the standard use of PyMC3, but the performance is several orders of magnitude better than you would get without the mass matrix tuning.
# As you can see from the `pymc3.summary`, the autocorrelation time of this chain is about 1 as we would expect for a simple problem like this.

# %%
pm.summary(trace)

# %% [markdown]
# ## Evaluating model components for specific samples
#
# I find that when I'm debugging a PyMC3 model, I often want to inspect the value of some part of the model for a given set of parameters.
# As far as I can tell, there isn't a simple way to do this in PyMC3, so *exoplanet* comes with a hack for doing this: :func:`exoplanet.eval_in_model`.
# This function handles the mapping between named PyMC3 variables and the input required by the Theano function that can evaluate the requested variable or tensor.
#
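# %% [markdown]
# A minimal sketch (added here; the excerpt ends before the corresponding code
# cell): evaluating the "x" variable at the model's test point with
# :func:`exoplanet.eval_in_model`.

# %%
with model:
    x_val = xo.eval_in_model(model.named_vars["x"], model.test_point)
print(x_val)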
Example #19
    def run_inference(self, prior_d, pklpath, make_threadsafe=True):

        # if the model has already been run, pull the result from the
        # pickle. otherwise, run it.
        if os.path.exists(pklpath):
            d = pickle.load(open(pklpath, 'rb'))
            self.model = d['model']
            self.trace = d['trace']
            self.map_estimate = d['map_estimate']
            return 1

        with pm.Model() as model:

            # Fixed data errors.
            sigma = self.y_err

            # Define priors and PyMC3 random variables to sample over.
            A_d, B_d, omega_d, phi_d = {}, {}, {}, {}
            _A_d, _B_d = {}, {}
            for modelcomponent in self.modelcomponents:

                if 'transit' in modelcomponent:

                    BoundedNormal = pm.Bound(pm.Normal, lower=0, upper=3)
                    m_star = BoundedNormal("m_star",
                                           mu=MSTAR_VANEYKEN,
                                           sd=MSTAR_STDEV)
                    r_star = BoundedNormal("r_star",
                                           mu=RSTAR_VANEYKEN,
                                           sd=RSTAR_STDEV)

                    # mean = pm.Normal(
                    #     "mean", mu=prior_d['mean'], sd=0.02, testval=prior_d['mean']
                    # )
                    mean = pm.Uniform("mean",
                                      lower=prior_d['mean'] - 1e-2,
                                      upper=prior_d['mean'] + 1e-2,
                                      testval=prior_d['mean'])

                    t0 = pm.Normal("t0",
                                   mu=prior_d['t0'],
                                   sd=0.002,
                                   testval=prior_d['t0'])

                    # logP = pm.Normal(
                    #     "logP", mu=np.log(prior_d['period']),
                    #     sd=0.001*np.abs(np.log(prior_d['period'])),
                    #     testval=np.log(prior_d['period'])
                    # )
                    # period = pm.Deterministic("period", pm.math.exp(logP))
                    period = pm.Normal('period',
                                       mu=prior_d['period'],
                                       sd=1e-3,
                                       testval=prior_d['period'])

                    u = xo.distributions.QuadLimbDark("u",
                                                      testval=prior_d['u'])

                    r = pm.Normal("r",
                                  mu=prior_d['r'],
                                  sd=0.10 * prior_d['r'],
                                  testval=prior_d['r'])
                    # r = pm.Uniform(
                    #     "r", lower=prior_d['r']-1e-2,
                    #     upper=prior_d['r']+1e-2, testval=prior_d['r']
                    # )

                    b = xo.distributions.ImpactParameter("b",
                                                         ror=r,
                                                         testval=prior_d['b'])

                    orbit = xo.orbits.KeplerianOrbit(period=period,
                                                     t0=t0,
                                                     b=b,
                                                     mstar=m_star,
                                                     rstar=r_star)
                    light_curve = (
                        mean + xo.LimbDarkLightCurve(u).get_light_curve(
                            orbit=orbit, r=r, t=self.x_obs, texp=self.t_exp))

                    #
                    # derived quantities
                    #
                    # stellar density in cgs
                    rhostar = pm.Deterministic(
                        "rhostar",
                        ((m_star / ((4 * np.pi / 3) * r_star**3)) *
                         (1 * units.Msun / ((1 * units.Rsun)**3)).cgs.value))

                    # planet radius in jupiter radii
                    r_planet = pm.Deterministic("r_planet", (r * r_star) *
                                                (1 * units.Rsun /
                                                 (1 * units.Rjup)).cgs.value)

                    #
                    # eq 30 of winn+2010, ignoring planet density.
                    #
                    a_Rs = pm.Deterministic(
                        "a_Rs", (rhostar * period**2)**(1 / 3) *
                        (((1 * units.gram /
                           (1 * units.cm)**3) * (1 * units.day**2) * const.G /
                          (3 * np.pi))**(1 / 3)).cgs.value)
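                    # For reference: eq 30 of Winn (2010) is
                    # a/R* = (G P^2 rho_*/(3 pi))^(1/3); the astropy factor
                    # above just turns that expression into a pure number when
                    # rho_* is in g/cm^3 and the period is in days (it
                    # evaluates to roughly 3.75).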

                if 'sincos' in modelcomponent:
                    if 'Porb' in modelcomponent:
                        k = 'orb'
                    elif 'Prot' in modelcomponent:
                        k = 'rot'
                    else:
                        msg = 'expected Porb or Prot for freq specification'
                        raise NotImplementedError(msg)

                    omegakey = 'omega{}'.format(k)
                    if k == 'rot':
                        omega_d[omegakey] = pm.Normal(
                            omegakey,
                            mu=prior_d[omegakey],
                            sd=0.01 * prior_d[omegakey],
                            testval=prior_d[omegakey])
                        P_rot = pm.Deterministic(
                            'P_rot',
                            pm.math.dot(1 / omega_d[omegakey], 2 * np.pi))

                        #omega_d[omegakey] = pm.Uniform(omegakey,
                        #                               lower=prior_d[omegakey]-1e-2,
                        #                               upper=prior_d[omegakey]+1e-2,
                        #                               testval=prior_d[omegakey])
                    elif k == 'orb':
                        # For orbital frequency, no need to declare new
                        # random variable!
                        omega_d[omegakey] = pm.Deterministic(
                            omegakey, pm.math.dot(1 / period, 2 * np.pi))

                    # sin and cosine terms are highly degenerate...
                    phikey = 'phi{}'.format(k)
                    if k == 'rot':
                        phi_d[phikey] = pm.Uniform(
                            phikey,
                            lower=prior_d[phikey] - np.pi / 8,
                            upper=prior_d[phikey] + np.pi / 8,
                            testval=prior_d[phikey])
                        #phi_d[phikey] = pm.Uniform(phikey,
                        #                           lower=0,
                        #                           upper=np.pi,
                        #                           testval=prior_d[phikey])
                    elif k == 'orb':
                        # For orbital phase, no need to declare new
                        # random variable!
                        phi_d[phikey] = pm.Deterministic(
                            phikey, pm.math.dot(t0 / period, 2 * np.pi))

                    N_harmonics = int(modelcomponent[0])
                    for ix in range(N_harmonics):

                        if LINEAR_AMPLITUDES:
                            Akey = 'A{}{}'.format(k, ix)
                            Bkey = 'B{}{}'.format(k, ix)

                            A_d[Akey] = pm.Uniform(
                                Akey,
                                lower=-2 * np.abs(prior_d[Akey]),
                                upper=2 * np.abs(prior_d[Akey]),
                                testval=np.abs(prior_d[Akey]))

                            B_d[Bkey] = pm.Uniform(
                                Bkey,
                                lower=-2 * np.abs(prior_d[Bkey]),
                                upper=2 * np.abs(prior_d[Bkey]),
                                testval=np.abs(prior_d[Bkey]))

                        if LOG_AMPLITUDES:
                            Akey = 'A{}{}'.format(k, ix)
                            Bkey = 'B{}{}'.format(k, ix)
                            logAkey = 'logA{}{}'.format(k, ix)
                            logBkey = 'logB{}{}'.format(k, ix)

                            if k == 'rot':
                                mfact = 3
                            elif k == 'orb':
                                mfact = 10
                            _A_d[logAkey] = pm.Uniform(
                                logAkey,
                                lower=np.log(prior_d[Akey] / mfact),
                                upper=np.log(mfact * prior_d[Akey]),
                                testval=np.log(prior_d[Akey]))
                            A_d[Akey] = pm.Deterministic(
                                Akey, pm.math.exp(_A_d[logAkey]))

                            _B_d[logBkey] = pm.Uniform(
                                logBkey,
                                lower=np.log(prior_d[Bkey] / mfact),
                                upper=np.log(mfact * prior_d[Bkey]),
                                testval=np.log(prior_d[Bkey]))
                            B_d[Bkey] = pm.Deterministic(
                                Bkey, pm.math.exp(_B_d[logBkey]))
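                            # The log-space parameterization keeps the
                            # amplitudes strictly positive and within a factor
                            # of mfact of the prior guess; the
                            # LINEAR_AMPLITUDES branch above instead allows
                            # sign flips about zero.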

            harmonic_d = {**A_d, **B_d, **omega_d, **phi_d}

            # Build the likelihood

            if 'transit' not in self.modelcomponents:
                # NOTE: hacky implementation detail: I didn't know how else
                # to initialize an "empty" pymc3 random variable, so I assume
                # here that "transit" is always a model component, and the
                # likelihood variable is initialized from its light curve.
                msg = 'Expected transit to be a model component.'
                raise NotImplementedError(msg)

            for modelcomponent in self.modelcomponents:

                if 'transit' in modelcomponent:
                    mu_model = light_curve.flatten()
                    pm.Deterministic("mu_transit", light_curve.flatten())

                if 'sincos' in modelcomponent:
                    if 'Porb' in modelcomponent:
                        k = 'orb'
                    elif 'Prot' in modelcomponent:
                        k = 'rot'

                    N_harmonics = int(modelcomponent[0])
                    for ix in range(N_harmonics):

                        spnames = [
                            'A{}{}'.format(k, ix), 'omega{}'.format(k),
                            'phi{}'.format(k)
                        ]
                        cpnames = [
                            'B{}{}'.format(k, ix), 'omega{}'.format(k),
                            'phi{}'.format(k)
                        ]
                        sin_params = [harmonic_d[k] for k in spnames]
                        cos_params = [harmonic_d[k] for k in cpnames]

                        # harmonic multiplier
                        mult = ix + 1
                        sin_params[1] = pm.math.dot(sin_params[1], mult)
                        cos_params[1] = pm.math.dot(cos_params[1], mult)

                        s_mod = sin_model(sin_params, self.x_obs)
                        c_mod = cos_model(cos_params, self.x_obs)

                        mu_model += s_mod
                        mu_model += c_mod

                        # save model components (rot and orb) for plotting
                        pm.Deterministic("mu_{}sin{}".format(k, ix), s_mod)
                        pm.Deterministic("mu_{}cos{}".format(k, ix), c_mod)

            # track the total model to plot it
            pm.Deterministic("mu_model", mu_model)

            likelihood = pm.Normal('obs',
                                   mu=mu_model,
                                   sigma=sigma,
                                   observed=self.y_obs)

            # Get MAP estimate from model.
            map_estimate = pm.find_MAP(model=model)

            # Plot the simulated data and the maximum a posteriori model to
            # make sure that our initialization looks ok.
            self.y_MAP = map_estimate['mu_model'].flatten()

            if make_threadsafe:
                pass
            else:
                # As described in
                # https://github.com/matplotlib/matplotlib/issues/15410,
                # matplotlib is not threadsafe, so do not make plots before
                # sampling: a child process may try to close a cached file
                # and crash the sampler.
                if self.PLOTDIR is None:
                    raise NotImplementedError
                outpath = os.path.join(self.PLOTDIR,
                                       'test_{}_MAP.png'.format(self.modelid))
                plot_MAP_data(self.x_obs, self.y_obs, self.y_MAP, outpath)

            # sample from the posterior defined by this model.
            trace = pm.sample(
                tune=self.N_samples,
                draws=self.N_samples,
                start=map_estimate,
                cores=self.N_cores,
                chains=self.N_chains,
                step=xo.get_dense_nuts_step(target_accept=0.9),
            )

        with open(pklpath, 'wb') as buff:
            pickle.dump(
                {
                    'model': model,
                    'trace': trace,
                    'map_estimate': map_estimate
                }, buff)

        self.model = model
        self.trace = trace
        self.map_estimate = map_estimate
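
# sin_model and cos_model are referenced in the harmonic sum above but are
# defined elsewhere in this module (not shown). They are presumably simple
# harmonics, A*sin(omega*t + phi) and B*cos(omega*t + phi), written with
# theano ops so they remain differentiable inside the pymc3 model. A minimal
# sketch under that assumption:

import theano.tensor as tt


def sin_model(params, t):
    # params = [amplitude, angular frequency, phase]
    A, omega, phi = params
    return A * tt.sin(omega * t + phi)


def cos_model(params, t):
    # params = [amplitude, angular frequency, phase]
    B, omega, phi = params
    return B * tt.cos(omega * t + phi)
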
Ejemplo n.º 20
0
    def run_allindivtransit_inference(self,
                                      prior_d,
                                      pklpath,
                                      make_threadsafe=True,
                                      target_accept=0.8):

        # if the model has already been run, pull the result from the
        # pickle. otherwise, run it.
        if os.path.exists(pklpath):
            d = pickle.load(open(pklpath, 'rb'))
            self.model = d['model']
            self.trace = d['trace']
            self.map_estimate = d['map_estimate']
            return 1

        with pm.Model() as model:

            # Shared parameters

            # Stellar parameters. (Following tess.world notebooks).
            logg_star = pm.Normal("logg_star", mu=LOGG, sd=LOGG_STDEV)
            r_star = pm.Bound(pm.Normal, lower=0.0)("r_star",
                                                    mu=RSTAR,
                                                    sd=RSTAR_STDEV)
            rho_star = pm.Deterministic("rho_star",
                                        factor * 10**logg_star / r_star)
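            # `factor` is defined elsewhere in this module; it is presumably
            # 3/(4*pi*G*Rsun) in cgs (~5.14e-5), so that
            # rho_star = 3*g/(4*pi*G*R) comes out in g/cm^3 when
            # g = 10**logg_star is in cm/s^2 and r_star is in solar radii.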

            # Fix Rp/Rs across bandpasses, since the signal is assumed to be
            # an achromatic (planetary) transit.
            log_r = pm.Uniform('log_r',
                               lower=np.log(1e-2),
                               upper=np.log(1),
                               testval=prior_d['log_r'])

            r = pm.Deterministic('r', tt.exp(log_r))

            # Some orbital parameters
            t0 = pm.Normal("t0",
                           mu=prior_d['t0'],
                           sd=1e-1,
                           testval=prior_d['t0'])
            period = pm.Normal('period',
                               mu=prior_d['period'],
                               sd=1e-1,
                               testval=prior_d['period'])

            b = xo.distributions.ImpactParameter("b",
                                                 ror=r,
                                                 testval=prior_d['b'])
            orbit = xo.orbits.KeplerianOrbit(period=period,
                                             t0=t0,
                                             b=b,
                                             rho_star=rho_star)

            # NOTE: limb-darkening should be bandpass specific, but we don't
            # have the SNR to justify that, so go with TESS-dominated
            # u = xo.QuadLimbDark("u")

            # NOTE: deprecated(?)
            delta_u = 0.15
            u0 = pm.Uniform('u[0]',
                            lower=prior_d['u[0]'] - delta_u,
                            upper=prior_d['u[0]'] + delta_u,
                            testval=prior_d['u[0]'])
            u1 = pm.Uniform('u[1]',
                            lower=prior_d['u[1]'] - delta_u,
                            upper=prior_d['u[1]'] + delta_u,
                            testval=prior_d['u[1]'])
            u = [u0, u1]

            star = xo.LimbDarkLightCurve(u)

            # Loop over "instruments" (TESS, then each ground-based lightcurve)
            parameters = dict()
            lc_models = dict()
            roughdepths = dict()

            for n, (name, (x, y, yerr, texp)) in enumerate(self.data.items()):

                if 'tess' in name:
                    delta_trend = 0.100
                else:
                    delta_trend = 0.050

                # Define per-instrument parameters, prefixing each variable
                # name with the instrument. Yields e.g., "TESS_0_mean",
                # "elsauce_0_mean", "elsauce_2_a2".
                mean = pm.Normal(f'{name}_mean',
                                 mu=prior_d[f'{name}_mean'],
                                 sd=1e-2,
                                 testval=prior_d[f'{name}_mean'])
                a1 = pm.Uniform(f'{name}_a1',
                                lower=-delta_trend,
                                upper=delta_trend,
                                testval=prior_d[f'{name}_a1'])
                a2 = pm.Uniform(f'{name}_a2',
                                lower=-delta_trend,
                                upper=delta_trend,
                                testval=prior_d[f'{name}_a2'])

                # midpoint for this definition of the quadratic trend
                _tmid = np.nanmedian(x)

                transit_lc = star.get_light_curve(orbit=orbit,
                                                  r=r,
                                                  t=x,
                                                  texp=texp).T.flatten()

                lc_models[name] = pm.Deterministic(
                    f'{name}_mu_transit',
                    mean + a1 * (x - _tmid) + a2 * (x - _tmid)**2 + transit_lc)

                roughdepths[name] = pm.Deterministic(
                    f'{name}_roughdepth',
                    pm.math.abs_(transit_lc).max())
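                # The "rough depth" is the largest depression of the
                # transit-only light curve, i.e. approximately (Rp/R*)^2 up to
                # limb darkening.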

                # NOTE: might want error bar fudge.
                likelihood = pm.Normal(f'{name}_obs',
                                       mu=lc_models[name],
                                       sigma=yerr,
                                       observed=y)

            #
            # Derived parameters
            #

            # planet radius in jupiter radii
            r_planet = pm.Deterministic(
                "r_planet",
                (r * r_star) * (1 * units.Rsun / (1 * units.Rjup)).cgs.value)

            #
            # eq 30 of winn+2010, ignoring planet density.
            #
            a_Rs = pm.Deterministic("a_Rs", (rho_star * period**2)**(1 / 3) *
                                    (((1 * units.gram / (1 * units.cm)**3) *
                                      (1 * units.day**2) * const.G /
                                      (3 * np.pi))**(1 / 3)).cgs.value)

            #
            # cosi. assumes e=0 (e.g., Winn+2010 eq 7)
            #
            cosi = pm.Deterministic("cosi", b / a_Rs)

            # probably safer than tt.arccos(cosi)
            sini = pm.Deterministic("sini", pm.math.sqrt(1 - cosi**2))

            #
            # transit durations (T_14, T_13) for circular orbits. Winn+2010 Eq 14, 15.
            # units: hours.
            #
            T_14 = pm.Deterministic('T_14', (period / np.pi) * tt.arcsin(
                (1 / a_Rs) * pm.math.sqrt((1 + r)**2 - b**2) * (1 / sini)) *
                                    24)

            T_13 = pm.Deterministic('T_13', (period / np.pi) * tt.arcsin(
                (1 / a_Rs) * pm.math.sqrt((1 - r)**2 - b**2) * (1 / sini)) *
                                    24)

            map_estimate = pm.find_MAP(model=model)

            # if make_threadsafe:
            #     pass
            # else:
            #     # NOTE: would usually plot MAP estimate here, but really
            #     # there's not a huge need.
            #     print(map_estimate)
            #     for k,v in map_estimate.items():
            #         if 'transit' not in k:
            #             print(k, v)

            # NOTE: could start at map_estimate, which currently is not being
            # used for anything.
            start = model.test_point

            trace = pm.sample(
                tune=self.N_samples,
                draws=self.N_samples,
                start=start,
                cores=self.N_cores,
                chains=self.N_chains,
                step=xo.get_dense_nuts_step(target_accept=target_accept),
            )

        with open(pklpath, 'wb') as buff:
            pickle.dump(
                {
                    'model': model,
                    'trace': trace,
                    'map_estimate': map_estimate
                }, buff)

        self.model = model
        self.trace = trace
        self.map_estimate = map_estimate
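
# The derived quantities above (a_Rs, cosi, T_14, T_13) follow Winn (2010);
# see the equation numbers in the in-line comments. A plain-numpy version of
# the a/R* and T_14 expressions can serve as a sanity check on the pymc3
# Deterministics (the example inputs at the bottom are illustrative
# assumptions, not fitted values):

import numpy as np
from astropy import constants as const
from astropy import units


def derived_transit_quantities(rho_star, period, r, b):
    """rho_star in g/cm^3, period in days, r = Rp/R*, b = impact parameter."""
    # a/R* from eq 30 of Winn (2010), ignoring the planet-density term
    a_Rs = (rho_star * period**2)**(1 / 3) * (
        ((1 * units.gram / (1 * units.cm)**3) * (1 * units.day**2) * const.G /
         (3 * np.pi))**(1 / 3)).cgs.value
    cosi = b / a_Rs                 # assumes e = 0
    sini = np.sqrt(1 - cosi**2)
    # total (first-to-fourth contact) transit duration, in hours
    T_14 = (period / np.pi) * np.arcsin(
        np.sqrt((1 + r)**2 - b**2) / (a_Rs * sini)) * 24
    return a_Rs, T_14


# A roughly Sun-like star (rho_star ~ 1.4 g/cm^3) with P = 8.3 d, r = 0.05,
# b = 0.5 gives a/R* of about 17 and T_14 of about 3.4 hours:
# print(derived_transit_quantities(1.4, 8.3, 0.05, 0.5))
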
Ejemplo n.º 21
0
    def run_rv_inference(self, prior_d, pklpath, make_threadsafe=True):

        # if the model has already been run, pull the result from the
        # pickle. otherwise, run it.
        if os.path.exists(pklpath):
            d = pickle.load(open(pklpath, 'rb'))
            self.model = d['model']
            self.trace = d['trace']
            self.map_estimate = d['map_estimate']
            return 1

        with pm.Model() as model:

            # Fixed data errors.
            sigma = self.y_err

            # Define priors and PyMC3 random variables to sample over.

            # Stellar parameters. (Following tess.world notebooks).
            logg_star = pm.Normal("logg_star",
                                  mu=prior_d['logg_star'][0],
                                  sd=prior_d['logg_star'][1])

            r_star = pm.Bound(pm.Normal, lower=0.0)("r_star",
                                                    mu=prior_d['r_star'][0],
                                                    sd=prior_d['r_star'][1])
            rho_star = pm.Deterministic("rho_star",
                                        factor * 10**logg_star / r_star)

            # RV parameters.

            # Chen & Kipping predicted M: 49.631 Mearth, based on Rp of 8Re. It
            # could be bigger, e.g., 94m/s if 1 Mjup.
            # Predicted K: 14.26 m/s

            #K = pm.Lognormal("K", mu=np.log(prior_d['K'][0]),
            #                 sigma=prior_d['K'][1])
            log_K = pm.Uniform('log_K',
                               lower=prior_d['log_K'][0],
                               upper=prior_d['log_K'][1])
            K = pm.Deterministic('K', tt.exp(log_K))

            period = pm.Normal("period",
                               mu=prior_d['period'][0],
                               sigma=prior_d['period'][1])

            ecs = xo.UnitDisk("ecs", testval=np.array([0.7, -0.3]))
            ecc = pm.Deterministic("ecc", tt.sum(ecs**2))

            omega = pm.Deterministic("omega", tt.arctan2(ecs[1], ecs[0]))
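            # ecs is a point in the unit disk: its squared radius is the
            # eccentricity and its polar angle the argument of periastron
            # (i.e. ecs ~ sqrt(e) * [cos(omega), sin(omega)]), which keeps
            # e < 1 and avoids the e-omega degeneracy at low eccentricity.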

            phase = xo.UnitUniform("phase")

            # use time of transit, rather than time of periastron. we do, after
            # all, know it.
            t0 = pm.Normal("t0",
                           mu=prior_d['t0'][0],
                           sd=prior_d['t0'][1],
                           testval=prior_d['t0'][0])

            orbit = xo.orbits.KeplerianOrbit(period=period,
                                             t0=t0,
                                             rho_star=rho_star,
                                             ecc=ecc,
                                             omega=omega)

            # Noise-model (GP) hyperparameters: S_tot sets the overall power
            # of the SHO trend kernel defined below, and ell its
            # characteristic timescale.
            S_tot = pm.Lognormal("S_tot",
                                 mu=np.log(prior_d['S_tot'][0]),
                                 sigma=prior_d['S_tot'][1])
            ell = pm.Lognormal("ell",
                               mu=np.log(prior_d['ell'][0]),
                               sigma=prior_d['ell'][1])

            # per instrument parameters
            means = pm.Normal(
                "means",
                mu=np.array([
                    np.median(self.y_obs[self.telvec == u])
                    for u in self.uniqueinstrs
                ]),
                sigma=500,
                shape=self.num_inst,
            )

            # different instruments have different intrinsic jitters. assign
            # those based on the reported error bars. (NOTE: might inflate or
            # overwrite these, for say, CHIRON)
            sigmas = pm.HalfNormal("sigmas",
                                   sigma=np.array([
                                       np.median(self.y_err[self.telvec == u])
                                       for u in self.uniqueinstrs
                                   ]),
                                   shape=self.num_inst)

            # Compute the RV offset and jitter for each data point depending on
            # its instrument
            mean = tt.zeros(len(self.x_obs))
            diag = tt.zeros(len(self.x_obs))
            for i, u in enumerate(self.uniqueinstrs):
                mean += means[i] * (self.telvec == u)
                diag += (self.y_err**2 + sigmas[i]**2) * (self.telvec == u)
            pm.Deterministic("mean", mean)
            pm.Deterministic("diag", diag)

            # NOTE: local function definition is jank
            def rv_model(x):
                return orbit.get_radial_velocity(x, K=K)

            kernel = xo.gp.terms.SHOTerm(S_tot=S_tot,
                                         w0=2 * np.pi / ell,
                                         Q=1.0 / 3)
            # NOTE temp
            gp = xo.gp.GP(kernel, self.x_obs, diag, mean=rv_model)
            # gp = xo.gp.GP(kernel, self.x_obs, diag,
            #               mean=orbit.get_radial_velocity(self.x_obs, K=K))
            # the actual "conditioning" step, i.e. the likelihood definition
            gp.marginal("obs", observed=self.y_obs - mean)
            pm.Deterministic("gp_pred", gp.predict())
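            # With Q = 1/3 the SHO term is overdamped (non-oscillatory), so
            # the kernel acts as a smooth activity/trend model. gp.marginal
            # adds the GP marginal likelihood of the offset-subtracted RVs to
            # the model (this is the likelihood term), and "gp_pred" stores
            # the GP's prediction of that residual trend for later plotting.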

            map_estimate = model.test_point
            map_estimate = xo.optimize(map_estimate, [means])
            map_estimate = xo.optimize(map_estimate, [means, phase])
            map_estimate = xo.optimize(map_estimate, [means, phase, log_K])
            map_estimate = xo.optimize(map_estimate,
                                       [means, t0, log_K, period, ecs])
            map_estimate = xo.optimize(map_estimate, [sigmas, S_tot, ell])
            map_estimate = xo.optimize(map_estimate)
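            # The staged optimization above deliberately optimizes small,
            # well-conditioned parameter subsets first (instrument means, then
            # phase and semi-amplitude, then the full orbit, then the noise
            # terms) before a final joint optimization; this is typically more
            # robust than a single cold-start MAP solve.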

            #
            # Derived parameters
            #

            #TODO
            # # planet radius in jupiter radii
            # r_planet = pm.Deterministic(
            #     "r_planet", (r*r_star)*( 1*units.Rsun/(1*units.Rjup) ).cgs.value
            # )

        # Plot the simulated data and the maximum a posteriori model to
        # make sure that our initialization looks ok.

        # I.e., "detrended": the RV data are y_obs minus the per-instrument
        # means, and the "trend" is the GP prediction. FIXME: AFAIK, the GP
        # does not do much as implemented.
        self.y_MAP = (self.y_obs - map_estimate["mean"] -
                      map_estimate["gp_pred"])

        t_pred = np.linspace(self.x_obs.min() - 10,
                             self.x_obs.max() + 10, 10000)

        with model:
            # NOTE temp
            y_pred_MAP = xo.eval_in_model(rv_model(t_pred), map_estimate)
            # # NOTE temp
            # y_pred_MAP = xo.eval_in_model(
            #     orbit.get_radial_velocity(t_pred, K=K), map_estimate
            # )

        self.x_pred = t_pred
        self.y_pred_MAP = y_pred_MAP

        if make_threadsafe:
            pass
        else:
            # As described in
            # https://github.com/matplotlib/matplotlib/issues/15410,
            # matplotlib is not threadsafe, so do not make plots before
            # sampling: a child process may try to close a cached file
            # and crash the sampler.

            print(map_estimate)

            if self.PLOTDIR is None:
                raise NotImplementedError
            outpath = os.path.join(self.PLOTDIR,
                                   'test_{}_MAP.png'.format(self.modelid))

            plot_MAP_rv(self.x_obs, self.y_obs, self.y_MAP, self.y_err,
                        self.telcolors, self.x_pred, self.y_pred_MAP,
                        map_estimate, outpath)

        with model:
            # sample from the posterior defined by this model.
            trace = pm.sample(
                tune=self.N_samples,
                draws=self.N_samples,
                start=map_estimate,
                cores=self.N_cores,
                chains=self.N_chains,
                step=xo.get_dense_nuts_step(target_accept=0.8),
            )

        # with open(pklpath, 'wb') as buff:
        #     pickle.dump({'model': model, 'trace': trace,
        #                  'map_estimate': map_estimate}, buff)

        self.model = model
        self.trace = trace
        self.map_estimate = map_estimate
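
# The per-instrument offset and jitter construction above works because the
# boolean masks (self.telvec == u) act as 0/1 indicator vectors. A plain-numpy
# illustration of the same bookkeeping (the instrument names, offsets, and
# jitters here are made-up values for demonstration only):

import numpy as np

telvec = np.array(["HARPS", "HARPS", "CHIRON", "HARPS", "CHIRON"])
y_err = np.array([1.0, 1.2, 3.0, 0.9, 2.8])       # reported errors [m/s]
offsets = {"HARPS": 5.0, "CHIRON": -12.0}         # per-instrument offsets
jitters = {"HARPS": 0.5, "CHIRON": 2.0}           # per-instrument jitters

mean = np.zeros(len(telvec))
diag = np.zeros(len(telvec))
for u in np.unique(telvec):
    mask = (telvec == u)          # 1 for points taken with this instrument
    mean += offsets[u] * mask
    diag += (y_err**2 + jitters[u]**2) * mask

# `mean` now holds the per-point RV offset and `diag` the per-point variance
# (reported error and instrument jitter added in quadrature), exactly the
# quantities fed to the GP above.
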
Ejemplo n.º 22
0
    def run_inference(self, prior_d, pklpath, make_threadsafe=True):

        # if the model has already been run, pull the result from the
        # pickle. otherwise, run it.
        if os.path.exists(pklpath):
            d = pickle.load(open(pklpath, 'rb'))
            self.model = d['model']
            self.trace = d['trace']
            self.map_estimate = d['map_estimate']
            return 1

        with pm.Model() as model:

            # Fixed data errors.
            sigma = self.y_err

            # Define priors and PyMC3 random variables to sample over.
            # Start with the transit parameters.
            mean = pm.Normal("mean",
                             mu=prior_d['mean'],
                             sd=1e-2,
                             testval=prior_d['mean'])

            t0 = pm.Normal("t0",
                           mu=prior_d['t0'],
                           sd=5e-3,
                           testval=prior_d['t0'])

            period = pm.Normal('period',
                               mu=prior_d['period'],
                               sd=5e-3,
                               testval=prior_d['period'])

            u = xo.distributions.QuadLimbDark("u", testval=prior_d['u'])

            r = pm.Normal("r",
                          mu=prior_d['r'],
                          sd=0.20 * prior_d['r'],
                          testval=prior_d['r'])

            b = xo.distributions.ImpactParameter("b",
                                                 ror=r,
                                                 testval=prior_d['b'])

            orbit = xo.orbits.KeplerianOrbit(period=period,
                                             t0=t0,
                                             b=b,
                                             mstar=self.mstar,
                                             rstar=self.rstar)

            mu_transit = pm.Deterministic(
                'mu_transit',
                xo.LimbDarkLightCurve(u).get_light_curve(
                    orbit=orbit, r=r, t=self.x_obs,
                    texp=self.t_exp).T.flatten())

            mean_model = mu_transit + mean

            if self.modelcomponents == ['transit']:

                mu_model = pm.Deterministic('mu_model', mean_model)

                likelihood = pm.Normal('obs',
                                       mu=mu_model,
                                       sigma=sigma,
                                       observed=self.y_obs)

            if 'gprot' in self.modelcomponents:

                # Instantiate the GP parameters.

                P_rot = pm.Normal("P_rot",
                                  mu=prior_d['P_rot'],
                                  sigma=1.0,
                                  testval=prior_d['P_rot'])

                amp = pm.Uniform('amp', lower=5, upper=40, testval=10)

                mix = xo.distributions.UnitUniform("mix")

                log_Q0 = pm.Normal("log_Q0",
                                   mu=1.0,
                                   sd=10.0,
                                   testval=prior_d['log_Q0'])

                log_deltaQ = pm.Normal("log_deltaQ",
                                       mu=2.0,
                                       sd=10.0,
                                       testval=prior_d['log_deltaQ'])

                kernel = terms.RotationTerm(
                    period=P_rot,
                    amp=amp,
                    mix=mix,
                    log_Q0=log_Q0,
                    log_deltaQ=log_deltaQ,
                )
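                # RotationTerm is (roughly) a mixture of two damped,
                # stochastically driven oscillators at P_rot and P_rot/2:
                # `amp` sets the variance, `mix` the relative strength of the
                # half-period mode, and log_Q0 / log_deltaQ control the
                # quality factors (damping) of the two modes.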

                gp = GP(kernel, self.x_obs, sigma**2, mean=mean_model)

                # Condition the GP on the observations and add the marginal likelihood
                # to the model. Needed before calling "gp.predict()".
                # NOTE: this is where the likelihood is defined: gp.marginal
                # adds N(y | mean_model, K_rot + diag(sigma^2)) to the model
                # log-probability, so no separate pm.Normal('obs', ...) is
                # needed on this branch.
                gp.marginal("transit_obs", observed=self.y_obs)

                # Compute the mean model prediction for plotting purposes
                mu_gprot = pm.Deterministic("mu_gprot", gp.predict())
                mu_model = pm.Deterministic("mu_model", mu_gprot + mean_model)

            # Optimizing
            start = model.test_point
            if 'transit' in self.modelcomponents:
                map_estimate = xo.optimize(start=start,
                                           vars=[r, b, period, t0])
            if 'gprot' in self.modelcomponents:
                map_estimate = xo.optimize(
                    start=map_estimate,
                    vars=[P_rot, amp, mix, log_Q0, log_deltaQ])
            map_estimate = xo.optimize(start=map_estimate)
            # map_estimate = pm.find_MAP(model=model)

            # Plot the simulated data and the maximum a posteriori model to
            # make sure that our initialization looks ok.
            self.y_MAP = (map_estimate['mean'] + map_estimate['mu_transit'])
            if 'gprot' in self.modelcomponents:
                self.y_MAP += map_estimate['mu_gprot']

            if make_threadsafe:
                pass
            else:
                # As described in
                # https://github.com/matplotlib/matplotlib/issues/15410,
                # matplotlib is not threadsafe, so do not make plots before
                # sampling: a child process may try to close a cached file
                # and crash the sampler.

                print(map_estimate)

                if self.PLOTDIR is None:
                    raise NotImplementedError
                outpath = os.path.join(self.PLOTDIR,
                                       'test_{}_MAP.png'.format(self.modelid))
                plot_MAP_data(self.x_obs, self.y_obs, self.y_MAP, outpath)

            # sample from the posterior defined by this model.
            trace = pm.sample(
                tune=self.N_samples,
                draws=self.N_samples,
                start=map_estimate,
                cores=self.N_cores,
                chains=self.N_chains,
                step=xo.get_dense_nuts_step(target_accept=0.9),
            )

        with open(pklpath, 'wb') as buff:
            pickle.dump(
                {
                    'model': model,
                    'trace': trace,
                    'map_estimate': map_estimate
                }, buff)

        self.model = model
        self.trace = trace
        self.map_estimate = map_estimate
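
# The inference methods above cache their results to `pklpath` as a dict with
# keys 'model', 'trace', and 'map_estimate', and short-circuit by re-loading
# that pickle on subsequent calls (the RV method currently has its dump
# commented out). A minimal sketch of re-opening such a cache for
# post-processing; the path below is a made-up example:

import pickle

pklpath = "results/transit_model.pkl"   # hypothetical cache from a prior run
with open(pklpath, "rb") as f:
    d = pickle.load(f)

model, trace, map_estimate = d["model"], d["trace"], d["map_estimate"]
print(sorted(map_estimate.keys()))
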