Ejemplo n.º 1
0
    def test_emcee(self):
        """Run a short emcee retrieval with and without condensation.

        Both runs must return an ``emcee`` ensemble sampler whose chain
        has shape ``(nwalkers, nsteps, n_fit_params)`` and whose
        log-probability array has shape ``(nwalkers, nsteps)``.
        """
        self.initialize(True)
        nwalkers = 20
        nsteps = 20
        n_params = len(self.fit_info.fit_param_names)

        # First run: condensation disabled.
        retriever = Retriever()
        result = retriever.run_emcee(self.wavelength_bins,
                                     self.depths,
                                     self.errors,
                                     self.fit_info,
                                     nsteps=nsteps,
                                     nwalkers=nwalkers,
                                     include_condensation=False)
        self.assertIsInstance(result, emcee.ensemble.EnsembleSampler)
        # assertEqual, not assertTrue: assertTrue(a, b) treats ``b`` as the
        # failure message and passes for any non-empty shape tuple.
        self.assertEqual(result.chain.shape, (nwalkers, nsteps, n_params))
        self.assertEqual(result.lnprobability.shape, (nwalkers, nsteps))

        # Second run: condensation enabled, also exercising plot_best.
        retriever = Retriever()
        result = retriever.run_emcee(self.wavelength_bins,
                                     self.depths,
                                     self.errors,
                                     self.fit_info,
                                     nsteps=nsteps,
                                     nwalkers=nwalkers,
                                     include_condensation=True,
                                     plot_best=True)
        self.assertIsInstance(result, emcee.ensemble.EnsembleSampler)
        self.assertEqual(result.chain.shape, (nwalkers, nsteps, n_params))
        self.assertEqual(result.lnprobability.shape, (nwalkers, nsteps))
# Uniform priors for the remaining fit parameters.
# Arguments are presumably (name, lower bound, upper bound) -- confirm
# against the fit_info API.
fit_info.add_uniform_fit_param("log_cloudtop_P", -0.99, 5)
fit_info.add_uniform_fit_param("error_multiple", 0.5, 5)

# Run the retrieval with the sampler selected by `bayesian_model`
# ('multinest' or 'emcee') and dump the results to a time-stamped file.
# Alternative executors, left disabled:
# with ThreadPoolExecutor() as executor:
# with ProcessPoolExecutor() as executor:
# UTC time stamp, embedded in the output file names below.
time_stamp = datetime.utcnow().strftime("%Y%m%d%H%M%S")

if bayesian_model == 'multinest':
    # One worker per CPU; the pool is handed to the sampler through the
    # 'pool' entry of nestle_kwargs.
    with Pool(cpu_count()) as executor:
        result = retriever.run_multinest(wave_bins, depths, errors, fit_info, nestle_kwargs={'pool':executor})#, 'bootstrap':0 # bootstrap for `dynesty`

    # Keep only the samples, their weights and the log-likelihoods.
    result_dict = {'samples':result.samples, 'weights':result.weights, 'logl':result.logl}
    joblib.dump(result_dict, 'multinest_results_{}.joblib.save'.format(time_stamp))
elif bayesian_model == 'emcee':
    result = retriever.run_emcee(wave_bins, depths, errors, fit_info, nwalkers=nwalkers, nsteps=nsteps)

    # Store both the flattened and the per-walker chains and log-probabilities.
    result_dict = {'flatchain':result.flatchain, 'flatlnprob':result.flatlnprobability, 'chain':result.chain, 'lnprob':result.lnprobability}
    joblib.dump(result_dict, 'emcee_results_{}walkers_{}steps_{}.joblib.save'.format(nwalkers, nsteps, time_stamp))
else:
    raise ValueError("Options for `bayesian_model` (-bm, --bayesianmodel) must be either 'multinest' or 'emcee'")

# Establish the range in wavelength to plot high resolution figures
wave_min = wave_bins.min()
wave_max = wave_bins.max()

# Evenly spaced theory grid spanning the observed wavelength range.
n_theory_pts = 500
wavelengths_theory = np.linspace(wave_min, wave_max, n_theory_pts)
# Half of the median spacing between neighbouring theory wavelengths.
half_diff_lam = 0.5*np.median(np.diff(wavelengths_theory))

# Setup calculator to use the theoretical wavelengths
Ejemplo n.º 3
0
    log_scatt_factor=0, scatt_slope=4, error_multiple=1, T_star=6091)

# Add fitting parameters - this specifies which parameters you want to fit.
# Any parameter that is NOT added below is presumably held fixed at the value
# given when fit_info was constructed (above) -- confirm against the API.

# Gaussian priors; the second argument is presumably the prior width
# (standard deviation) -- TODO confirm.
fit_info.add_gaussian_fit_param('Rs', 0.02*R_sun)
fit_info.add_gaussian_fit_param('Mp', 0.04*M_jup)

# Here, emcee is initialized with walkers where R is between 0.9*R_guess and
# 1.1*R_guess.  However, the hard limit on R is from 0 to infinity.
fit_info.add_uniform_fit_param('R', 0, np.inf, 0.9*R_guess, 1.1*R_guess)

# Same pattern: hard limits first, then (optionally) the range used to
# initialize the walkers.
fit_info.add_uniform_fit_param('T', 300, 3000, 0.5*T_guess, 1.5*T_guess)
fit_info.add_uniform_fit_param("log_scatt_factor", 0, 5, 0, 1)
fit_info.add_uniform_fit_param("logZ", -1, 3)
fit_info.add_uniform_fit_param("log_cloudtop_P", -0.99, 5)
fit_info.add_uniform_fit_param("error_multiple", 0, np.inf, 0.5, 5)

# Use emcee (MCMC) to do the fitting; plot_best=True is expected to draw the
# best-fit figure that is saved on the next line.
result = retriever.run_emcee(bins, depths, errors, fit_info, plot_best=True)
plt.savefig("best_fit.png")

# Persist the per-walker chain and the matching log-probabilities.
np.save("chain.npy", result.chain)
np.save("logl.npy", result.lnprobability)

# Corner plot of the flattened chain, one axis per fitted parameter.
# A float entry in `range` presumably asks corner to span that fraction
# (99%) of the samples on the axis -- confirm against corner's documentation.
fig = corner.corner(result.flatchain,
                    range=[0.99] * result.flatchain.shape[1],
                    labels=fit_info.fit_param_names)
fig.savefig("emcee_corner.png")

Ejemplo n.º 4
0
class PlatonWrapper():
    """Class object for running the platon atmospheric retrieval
    software.

    Typical use is ``set_parameters`` -> (optionally ``use_aws``) ->
    ``retrieve`` -> ``save_results`` / ``make_plot``.

    ``retrieve`` reads ``self.bins``, ``self.depths`` and ``self.errors``;
    no method of this class assigns them, so the caller must set them on
    the instance before calling ``retrieve``.
    """

    def __init__(self):
        """Initialize the class object and start logging."""

        self.ec2_id = ''
        self.output_results = 'results.dat'
        self.output_plot = 'corner.png'
        self.retriever = Retriever()
        self.ssh_file = ''
        self.aws = False
        # Defined up front so ``retrieve`` never hits a missing attribute
        # when ``aws`` is switched on without going through ``use_aws``.
        self.build_required = False
        self._configure_logging()

    def _configure_logging(self):
        """Creates a log file that logs the execution of the script.

        Log files are written to a ``logs/`` subdirectory within the
        current working directory.  Existing root logging handlers are
        removed first so that the new configuration takes effect.

        The start time of the script execution is stored on the
        instance as ``self.start_time``; nothing is returned.
        """

        # Imported locally: only this method needs it.
        import getpass

        # Define save location
        log_file = 'logs/{}.log'.format(
            datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))

        # Create the subdirectory if necessary (no error if it already exists)
        os.makedirs('logs/', exist_ok=True)

        # Make sure no other root handlers exist before configuring the logger
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        # Create the log file
        logging.basicConfig(filename=log_file,
                            format='%(asctime)s %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S %p',
                            level=logging.INFO)
        print('Log file initialized to {}'.format(log_file))

        # Log environment information
        logging.info('User: ' + getpass.getuser())
        logging.info('System: ' + socket.gethostname())
        logging.info('Python Version: ' + sys.version.replace('\n', ''))
        logging.info('Python Executable Path: ' + sys.executable)

        self.start_time = time.time()

    def make_plot(self):
        """Create a corner plot that shows the results of the retrieval.

        Uses ``self.result`` and ``self.method`` as set by ``retrieve``.
        The figure is written to ``<method>_corner.png`` and that path
        is stored in ``self.output_plot``.
        """

        print('Creating corner plot')
        logging.info('Creating corner plot')

        matplotlib.rcParams['text.usetex'] = False

        if self.method == 'emcee':
            fig = corner.corner(self.result.flatchain,
                                range=[0.99] * self.result.flatchain.shape[1],
                                labels=self.fit_info.fit_param_names)

        elif self.method == 'multinest':
            # Nested-sampling samples are weighted, so pass the weights on.
            fig = corner.corner(self.result.samples,
                                weights=self.result.weights,
                                range=[0.99] * self.result.samples.shape[1],
                                labels=self.fit_info.fit_param_names)

        # Save the results
        self.output_plot = '{}_corner.png'.format(self.method)
        fig.savefig(self.output_plot)
        print('Corner plot saved to {}'.format(self.output_plot))
        logging.info('Corner plot saved to {}'.format(self.output_plot))

    def retrieve(self, method):
        """Perform the atmospheric retrieval via the given method.

        When ``self.aws`` is true the retrieval is run remotely on an
        EC2 instance and the output products are copied back; otherwise
        it is run locally and the sampler output is stored in
        ``self.result``.

        Parameters
        ----------
        method : str
            The method by which to perform atmospheric retrievals.  Can
            either be ``emcee`` or ``multinest``.

        Raises
        ------
        AssertionError
            If ``method`` is not one of the supported values.
        """

        print('Performing atmopsheric retrievals via {}'.format(method))
        logging.info('Performing atmopsheric retrievals via {}'.format(method))

        # Ensure that the method parameter is valid
        assert method in ['multinest',
                          'emcee'], 'Unrecognized method: {}'.format(method)
        self.method = method

        # For processing on AWS
        if self.aws:

            # Start or create an EC2 instance
            instance, key, client = start_ec2(self.ssh_file, self.ec2_id)

            try:
                # Build the environment on EC2 instance if necessary
                if self.build_required:
                    build_environment(instance, key, client)

                # Transfer object file to EC2
                transfer_to_ec2(instance, key, client, 'pw.obj')

                # Connect to the EC2 instance and run commands
                command = './exoctk/exoctk/atmospheric_retrievals/exoctk-env-init.sh python exoctk/exoctk/atmospheric_retrievals/platon_wrapper.py {}'.format(
                    self.method)
                # NOTE(review): the username literal below looks redacted in
                # this copy of the file -- confirm the real login user.
                client.connect(hostname=instance.public_dns_name,
                               username='******',
                               pkey=key)
                stdin, stdout, stderr = client.exec_command(command)
                output = stdout.read()
                errors = stderr.read()
                log_output(output)
                log_output(errors)

                # Transfer output products from EC2 to user
                if self.method == 'emcee':
                    transfer_from_ec2(instance, key, client,
                                      'emcee_results.obj')
                    transfer_from_ec2(instance, key, client,
                                      'emcee_corner.png')
                elif self.method == 'multinest':
                    transfer_from_ec2(instance, key, client,
                                      'multinest_results.dat')
                    transfer_from_ec2(instance, key, client,
                                      'multinest_corner.png')
            finally:
                # Terminate or stop the EC2 instance even when a remote step
                # failed, so a running instance is never left behind.
                stop_ec2(self.ec2_id, instance)

        # For processing locally
        else:
            if self.method == 'emcee':
                self.result = self.retriever.run_emcee(self.bins, self.depths,
                                                       self.errors,
                                                       self.fit_info)
            elif self.method == 'multinest':
                self.result = self.retriever.run_multinest(self.bins,
                                                           self.depths,
                                                           self.errors,
                                                           self.fit_info,
                                                           plot_best=False)

        _log_execution_time(self.start_time)

    def save_results(self):
        """Save the results of the retrieval to an output file.

        ``multinest`` results are written as text to
        ``multinest_results.dat``; ``emcee`` results are pickled to
        ``emcee_results.obj``.  The chosen path is stored in
        ``self.output_results``.
        """

        print('Saving results')
        logging.info('Saving results')

        # Save the results
        if self.method == 'multinest':
            self.output_results = 'multinest_results.dat'
            with open(self.output_results, 'w') as f:
                f.write(str(self.result))
        elif self.method == 'emcee':
            self.output_results = 'emcee_results.obj'
            with open(self.output_results, 'wb') as f:
                pickle.dump(self.result, f)

        print('Results file saved to {}'.format(self.output_results))
        logging.info('Results file saved to {}'.format(self.output_results))

    def set_parameters(self, params):
        """Set necessary parameters to perform the retrieval.

        Required parameters include ``Rs``, ``Mp``, ``Rp``, and ``T``.
        Optional parameters include ``logZ``, ``CO_ratio``,
        ``log_cloudtop_P``, ``log_scatt_factor``, ``scatt_slope``,
        ``error_multiple``, and ``T_star``.

        The validated parameters are stored in ``self.params`` and used
        to build ``self.fit_info``.

        Parameters
        ----------
        params : str or dict
            Either a path to a params file to use, or a dictionary of
            parameters and their values for running the software.
            See "Use" documentation for further details.
        """

        print('Setting parameters: {}'.format(params))
        logging.info('Setting parameters: {}'.format(params))

        _validate_parameters(params)
        _apply_factors(params)
        self.params = params
        self.fit_info = self.retriever.get_default_fit_info(**self.params)

    def use_aws(self, ssh_file, ec2_id):
        """Sets appropriate parameters in order to perform processing
        using an AWS EC2 instance.

        The current object is also pickled to ``pw.obj`` in the working
        directory; ``retrieve`` uploads that file to the instance.

        Parameters
        ----------
        ssh_file : str
            The path to a public SSH key used to connect to the EC2
            instance.
        ec2_id : str
            A template id that points to a pre-built EC2 instance.
        """

        print('Using AWS for processing')
        logging.info('Using AWS for processing')

        self.ssh_file = ssh_file
        self.ec2_id = ec2_id

        # If the ec2_id is a template ID ("lt-..."), then building the
        # instance is required
        self.build_required = ec2_id.split('-')[0] == 'lt'

        # Write out object to file.  This happens BEFORE ``aws`` is set to
        # True, so the pickled copy still has ``aws`` False -- presumably so
        # that the copy loaded on the EC2 instance runs the retrieval locally
        # there.  Keep this ordering unless that is confirmed otherwise.
        with open('pw.obj', 'wb') as f:
            pickle.dump(self, f)
        print('Saved PlatonWrapper object to pw.obj')
        logging.info('Saved PlatonWrapper object to pw.obj')

        self.aws = True
Ejemplo n.º 5
0
class PlatonWrapper():
    """Convenience wrapper around the platon ``Retriever``.

    Call ``set_parameters`` first, then one of the ``retrieve_*``
    methods, then ``save_results`` and/or ``make_plot``.  The
    ``retrieve_*`` methods read ``self.bins``, ``self.depths`` and
    ``self.errors``, which no method of this class assigns.
    """

    def __init__(self):
        """Create the underlying retriever and default output names."""

        self.retriever = Retriever()
        self.output_results, self.output_plot = 'results.dat', 'corner.png'

    def make_plot(self):
        """Draw and save a corner plot of the retrieval results.

        The figure is written to ``<method>_corner.png``; that path is
        kept in ``self.output_plot``.
        """

        param_labels_owner = self.fit_info

        if self.method == 'emcee':
            chain = self.result.flatchain
            n_dims = self.result.flatchain.shape[1]
            fig = corner.corner(chain,
                                range=[0.99] * n_dims,
                                labels=param_labels_owner.fit_param_names)

        elif self.method == 'multinest':
            # Nested-sampling output is weighted; forward the weights.
            draws = self.result.samples
            draw_weights = self.result.weights
            n_dims = self.result.samples.shape[1]
            fig = corner.corner(draws,
                                weights=draw_weights,
                                range=[0.99] * n_dims,
                                labels=param_labels_owner.fit_param_names)

        # Record the destination, then write the figure there
        self.output_plot = '{}_corner.png'.format(self.method)
        fig.savefig(self.output_plot)
        print('Corner plot saved to {}'.format(self.output_plot))

    def retrieve_emcee(self):
        """Run the atmospheric retrieval with emcee."""

        self.method = 'emcee'
        run = self.retriever.run_emcee
        self.result = run(self.bins, self.depths, self.errors, self.fit_info)

    def retrieve_multinest(self):
        """Run the atmospheric retrieval with multinested sampling."""

        self.method = 'multinest'
        run = self.retriever.run_multinest
        self.result = run(self.bins, self.depths, self.errors, self.fit_info,
                          plot_best=False)

    def save_results(self):
        """Write ``str(self.result)`` to ``<method>_results.dat``.

        The path is kept in ``self.output_results``.
        """

        self.output_results = '{}_results.dat'.format(self.method)
        text = str(self.result)
        with open(self.output_results, 'w') as handle:
            handle.write(text)
        print('Results file saved to {}'.format(self.output_results))

    def set_parameters(self, params):
        """Validate and store the retrieval parameters.

        Required parameters include ``Rs``, ``Mp``, ``Rp``, and ``T``.
        Optional parameters include ``logZ``, ``CO_ratio``,
        ``log_cloudtop_P``, ``log_scatt_factor``, ``scatt_slope``,
        ``error_multiple``, and ``T_star``.

        Parameters
        ----------
        params : dict
            A dictionary of parameters and their values for running the
            software.  See "Use" documentation for further details.
        """

        # Validate first, then apply the factors, before storing anything.
        for step in (_validate_parameters, _apply_factors):
            step(params)
        self.params = params
        self.fit_info = self.retriever.get_default_fit_info(**self.params)