コード例 #1
0
ファイル: test_retrieve.py プロジェクト: soubkiou/platon
    def test_multinest(self):
        """Run ``run_multinest`` twice and validate each returned result.

        The first run uses ``maxcall=200`` with condensation disabled and
        ``plot_best=True``; the second uses ``maxiter=20`` with condensation
        enabled.  For each run the result must be a ``nestle.Result`` whose
        sample array has one column per fitted parameter.
        """
        self.initialize(False)
        n_fit_params = len(self.fit_info.fit_param_names)

        retriever = Retriever()
        result = retriever.run_multinest(self.wavelength_bins,
                                         self.depths,
                                         self.errors,
                                         self.fit_info,
                                         maxcall=200,
                                         include_condensation=False,
                                         plot_best=True)

        self.assertIsInstance(result, nestle.Result)
        self.assertEqual(result.samples.shape[1], n_fit_params)

        # Capture the second run's return value: without the assignment the
        # assertions below would silently re-check the first result.
        retriever = Retriever()
        result = retriever.run_multinest(self.wavelength_bins,
                                         self.depths,
                                         self.errors,
                                         self.fit_info,
                                         maxiter=20,
                                         include_condensation=True)

        self.assertIsInstance(result, nestle.Result)
        self.assertEqual(result.samples.shape[1], n_fit_params)
コード例 #2
0
# Register the parameters to fit on fit_info: two through
# add_gaussian_fit_param and six through add_uniform_fit_param.
# NOTE(review): the second argument of add_gaussian_fit_param is presumably
# the width (1-sigma) of the prior -- confirm against the FitInfo API.
fit_info.add_gaussian_fit_param('Rs', 0.02*R_sun)
fit_info.add_gaussian_fit_param('Mp', 0.04*M_jup)

# NOTE(review): the two numeric arguments are presumably the lower and upper
# bounds of the uniform prior -- confirm against the FitInfo API.
fit_info.add_uniform_fit_param('Rp', 0.9*R_guess, 1.1*R_guess)
fit_info.add_uniform_fit_param('T', 0.5*T_guess, 1.5*T_guess)
fit_info.add_uniform_fit_param("log_scatt_factor", 0, 1)
fit_info.add_uniform_fit_param("logZ", -1, 3)
fit_info.add_uniform_fit_param("log_cloudtop_P", -0.99, 5)
fit_info.add_uniform_fit_param("error_multiple", 0.5, 5)

# Use nested sampling to do the fitting.  The sampler is handed a worker pool
# sized by cpu_count() through nestle_kwargs; the pool is closed when the
# `with` block exits.  The two commented-out lines below are alternative
# executors that were tried in place of Pool.
# with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
# with ProcessPoolExecutor() as executor:
with Pool(cpu_count()) as executor:
    result = retriever.run_multinest(wave_bins, depths, errors, fit_info, nestle_kwargs={'pool':executor, 'bootstrap':0})

# Persist the samples, weights and logl arrays of the result in a joblib file
# whose name carries the current UTC time (YYYYMMDDHHMMSS), so repeated runs
# do not overwrite each other.
time_stamp = datetime.utcnow().strftime("%Y%m%d%H%M%S")
result_dict = {'samples':result.samples, 'weights':result.weights, 'logl':result.logl}
joblib.dump(result_dict, 'multinest_results_{}.joblib.save'.format(time_stamp))


# Establish the range in wavelength to plot high-resolution figures: it spans
# the smallest to the largest value found in wave_bins.
wave_min = wave_bins.min()
wave_max = wave_bins.max()

# Evenly spaced theory grid of n_theory_pts wavelengths over that range.
n_theory_pts = 500
wavelengths_theory = np.linspace(wave_min, wave_max, n_theory_pts)
# Half of the median spacing between neighbouring theory wavelengths.
half_diff_lam = 0.5*np.median(np.diff(wavelengths_theory))

# Set up the calculator to use the theoretical wavelengths
コード例 #3
0
# Register the free parameters of the retrieval.  Only the parameters added
# here are fitted; per the original note, anything left out stays fixed at
# the value specified in the constructor.

# Parameters registered through add_gaussian_fit_param: (name, second argument)
for _name, _width in (('Rs', 0.02 * R_sun),
                      ('Mp', 0.04 * M_jup)):
    fit_info.add_gaussian_fit_param(_name, _width)

# Parameters registered through add_uniform_fit_param, in the same order as
# before: (name, first bound, second bound)
for _name, _low, _high in (('Rp', 0.9 * R_guess, 1.1 * R_guess),
                           ('T', 0.5 * T_guess, 1.5 * T_guess),
                           ("log_scatt_factor", 0, 1),
                           ("logZ", -1, 3),
                           ("log_cloudtop_P", -0.99, 5),
                           ("error_multiple", 0.5, 5)):
    fit_info.add_uniform_fit_param(_name, _low, _high)

# Fit with nested sampling, then save whatever figure plot_best=True left
# as the current matplotlib figure.
result = retriever.run_multinest(
    bins, depths, errors, fit_info, plot_best=True)
plt.savefig("best_fit.png")

# Dump the raw sampler output, one .npy file per array.
for _path, _array in (("samples.npy", result.samples),
                      ("weights.npy", result.weights),
                      ("logl.npy", result.logl)):
    np.save(_path, _array)

# Weighted corner plot with one panel row/column per fitted parameter; every
# axis uses the same 0.99 range entry.
_n_params = result.samples.shape[1]
fig = corner.corner(
    result.samples,
    weights=result.weights,
    range=[0.99] * _n_params,
    labels=fit_info.fit_param_names,
)
fig.savefig("multinest_corner.png")
コード例 #4
0
# Forward-model a synthetic transit spectrum to retrieve against.
wavelengths, transit_depths = depth_calculator.compute_depths(
    Rp, temperature, logZ=logZ, CO_ratio=CO, cloudtop_pressure=1e3)

retriever = Retriever()

# Start the fit slightly away from the values used to generate the spectrum
# (0.99 * Rp, 0.9 * temperature, logZ=2, CO_ratio=1).
fit_info = retriever.get_default_fit_info(Rs,
                                          g,
                                          0.99 * Rp,
                                          0.9 * temperature,
                                          logZ=2,
                                          CO_ratio=1,
                                          add_fit_params=True)

# Add Gaussian noise to the synthetic depths.  The uncertainties handed to the
# retrieval are the noise level (1-sigma), NOT the individual noise draws:
# passing the draws would give points whose draw happened to be close to zero
# an almost-zero error bar and let them dominate the likelihood.
noise_sigma = 50e-6
noise = np.random.normal(scale=noise_sigma, size=len(transit_depths))
transit_depths += noise
errors = np.full(len(transit_depths), noise_sigma)

result = retriever.run_multinest(wavelength_bins, transit_depths, errors,
                                 fit_info)
np.save("samples.npy", result.samples)
np.save("weights.npy", result.weights)
np.save("logl.npy", result.logl)

# Parenthesised form prints the same text on Python 2 and is valid Python 3.
print(fit_info.fit_param_names)
fig = corner.corner(result.samples,
                    weights=result.weights,
                    range=[0.99] * result.samples.shape[1],
                    labels=fit_info.fit_param_names)
fig.savefig("multinest_corner.png")
コード例 #5
0
class PlatonWrapper():
    """Class object for running the platon atmospheric retrieval
    software.

    Typical use: call ``set_parameters``, assign the ``bins``, ``depths``
    and ``errors`` attributes (they are read by ``retrieve`` but never set
    inside this class), optionally call ``use_aws``, then call
    ``retrieve``, ``save_results`` and ``make_plot``.
    """

    def __init__(self):
        """Initialize the class object and start logging."""

        self.ec2_id = ''
        self.output_results = 'results.dat'
        self.output_plot = 'corner.png'
        self.retriever = Retriever()
        self.ssh_file = ''
        self.aws = False
        self._configure_logging()

    def _configure_logging(self):
        """Create a log file that logs the execution of the script.

        Log files are written to a ``logs/`` subdirectory within the
        current working directory.  Any handlers already attached to the
        root logger are removed first.  The time at which logging was
        configured is stored in ``self.start_time``; nothing is returned.
        """

        # Only needed here, so imported locally rather than at module level.
        import getpass

        # Define save location
        log_file = 'logs/{}.log'.format(
            datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))

        # Create the subdirectory if necessary (no error if it exists)
        os.makedirs('logs/', exist_ok=True)

        # Make sure no other root handlers exist before configuring the logger
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        # Create the log file
        logging.basicConfig(filename=log_file,
                            format='%(asctime)s %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S %p',
                            level=logging.INFO)
        print('Log file initialized to {}'.format(log_file))

        # Log environment information.  Looking up the user name can fail
        # when the process runs under a uid without an account entry; that
        # must not prevent the wrapper from being constructed.
        try:
            user = getpass.getuser()
        except (KeyError, OSError):
            user = 'unknown'
        logging.info('User: ' + user)
        logging.info('System: ' + socket.gethostname())
        logging.info('Python Version: ' + sys.version.replace('\n', ''))
        logging.info('Python Executable Path: ' + sys.executable)

        self.start_time = time.time()

    def make_plot(self):
        """Create a corner plot that shows the results of the retrieval.

        Reads ``self.method``, ``self.result`` and ``self.fit_info``; the
        figure is written to ``<method>_corner.png`` and that path is
        stored in ``self.output_plot``.
        """

        print('Creating corner plot')
        logging.info('Creating corner plot')

        matplotlib.rcParams['text.usetex'] = False

        if self.method == 'emcee':
            fig = corner.corner(self.result.flatchain,
                                range=[0.99] * self.result.flatchain.shape[1],
                                labels=self.fit_info.fit_param_names)

        elif self.method == 'multinest':
            fig = corner.corner(self.result.samples,
                                weights=self.result.weights,
                                range=[0.99] * self.result.samples.shape[1],
                                labels=self.fit_info.fit_param_names)

        # Save the results
        self.output_plot = '{}_corner.png'.format(self.method)
        fig.savefig(self.output_plot)
        print('Corner plot saved to {}'.format(self.output_plot))
        logging.info('Corner plot saved to {}'.format(self.output_plot))

    def retrieve(self, method):
        """Perform the atmospheric retrieval via the given method.

        When ``self.aws`` is set, the retrieval is run on an EC2 instance
        and the output products are copied back; otherwise it is run
        locally and the outcome is stored in ``self.result``.

        Parameters
        ----------
        method : str
            The method by which to perform atmospheric retrievals.  Can
            either be ``emcee`` or ``multinest``.

        Raises
        ------
        AssertionError
            If ``method`` is not one of the recognized methods.
        """

        print('Performing atmopsheric retrievals via {}'.format(method))
        logging.info('Performing atmopsheric retrievals via {}'.format(method))

        # Ensure that the method parameter is valid.  Raised explicitly
        # (rather than via ``assert``) so the check survives ``python -O``.
        if method not in ('multinest', 'emcee'):
            raise AssertionError('Unrecognized method: {}'.format(method))
        self.method = method

        # For processing on AWS
        if self.aws:

            # Start or create an EC2 instance
            instance, key, client = start_ec2(self.ssh_file, self.ec2_id)

            # Build the environment on EC2 instance if necessary
            if self.build_required:
                build_environment(instance, key, client)

            # Transfer object file to EC2
            transfer_to_ec2(instance, key, client, 'pw.obj')

            # Connect to the EC2 instance and run commands
            command = './exoctk/exoctk/atmospheric_retrievals/exoctk-env-init.sh python exoctk/exoctk/atmospheric_retrievals/platon_wrapper.py {}'.format(
                self.method)
            # NOTE(review): the username literal looks redacted in this
            # copy of the file -- confirm the real login name.
            client.connect(hostname=instance.public_dns_name,
                           username='******',
                           pkey=key)
            stdin, stdout, stderr = client.exec_command(command)
            output = stdout.read()
            errors = stderr.read()
            log_output(output)
            log_output(errors)

            # Transfer output products from EC2 to user
            if self.method == 'emcee':
                transfer_from_ec2(instance, key, client, 'emcee_results.obj')
                transfer_from_ec2(instance, key, client, 'emcee_corner.png')
            elif self.method == 'multinest':
                transfer_from_ec2(instance, key, client,
                                  'multinest_results.dat')
                transfer_from_ec2(instance, key, client,
                                  'multinest_corner.png')

            # Terminate or stop the EC2 instance
            stop_ec2(self.ec2_id, instance)

        # For processing locally
        else:
            if self.method == 'emcee':
                self.result = self.retriever.run_emcee(self.bins, self.depths,
                                                       self.errors,
                                                       self.fit_info)
            elif self.method == 'multinest':
                self.result = self.retriever.run_multinest(self.bins,
                                                           self.depths,
                                                           self.errors,
                                                           self.fit_info,
                                                           plot_best=False)

        _log_execution_time(self.start_time)

    def save_results(self):
        """Save the results of the retrieval to an output file.

        ``multinest`` results are written as text (``str(self.result)``)
        to ``multinest_results.dat``; ``emcee`` results are pickled to
        ``emcee_results.obj``.  The path is stored in
        ``self.output_results``.
        """

        print('Saving results')
        logging.info('Saving results')

        # Save the results
        if self.method == 'multinest':
            self.output_results = 'multinest_results.dat'
            with open(self.output_results, 'w') as f:
                f.write(str(self.result))
        elif self.method == 'emcee':
            self.output_results = 'emcee_results.obj'
            with open(self.output_results, 'wb') as f:
                pickle.dump(self.result, f)

        print('Results file saved to {}'.format(self.output_results))
        logging.info('Results file saved to {}'.format(self.output_results))

    def set_parameters(self, params):
        """Set necessary parameters to perform the retrieval.

        Required parameters include ``Rs``, ``Mp``, ``Rp``, and ``T``.
        Optional parameters include ``logZ``, ``CO_ratio``,
        ``log_cloudtop_P``, ``log_scatt_factor``, ``scatt_slope``,
        ``error_multiple``, and ``T_star``.

        The parameters are validated, passed through ``_apply_factors``,
        stored in ``self.params`` and used to build ``self.fit_info``.

        Parameters
        ----------
        params : str or dict
            Either a path to a params file to use, or a dictionary of
            parameters and their values for running the software.
            See "Use" documentation for further details.
        """

        print('Setting parameters: {}'.format(params))
        logging.info('Setting parameters: {}'.format(params))

        _validate_parameters(params)
        _apply_factors(params)
        self.params = params
        self.fit_info = self.retriever.get_default_fit_info(**self.params)

    def use_aws(self, ssh_file, ec2_id):
        """Set appropriate parameters in order to perform processing
        using an AWS EC2 instance.

        Also pickles this object to ``pw.obj`` so that it can be
        transferred to the instance by ``retrieve``.

        Parameters
        ----------
        ssh_file : str
            The path to a public SSH key used to connect to the EC2
            instance.
        ec2_id : str
            A template id that points to a pre-built EC2 instance.
        """

        print('Using AWS for processing')
        logging.info('Using AWS for processing')

        self.ssh_file = ssh_file
        self.ec2_id = ec2_id

        # If the ec2_id is a template ID, then building the instance is required
        self.build_required = ec2_id.split('-')[0] == 'lt'

        # Write out object to file.  This happens before ``self.aws`` is
        # switched on, so the pickled copy still has ``aws`` False --
        # presumably so that the copy loaded on EC2 takes the local
        # processing branch of ``retrieve``.
        with open('pw.obj', 'wb') as f:
            pickle.dump(self, f)
        print('Saved PlatonWrapper object to pw.obj')
        logging.info('Saved PlatonWrapper object to pw.obj')

        self.aws = True
コード例 #6
0
class PlatonWrapper():
    """Thin convenience layer around the platon ``Retriever``.

    Bundles running a retrieval (emcee or multinest), writing the result
    to disk and producing a corner plot.  The ``bins``, ``depths`` and
    ``errors`` attributes are read by the retrieve methods but are not
    set anywhere in this class, so they must be assigned by the caller.
    """

    def __init__(self):
        """Create the retriever and set the default output file names."""

        self.output_plot = 'corner.png'
        self.output_results = 'results.dat'
        self.retriever = Retriever()

    def make_plot(self):
        """Draw and save a corner plot of the retrieval result.

        The figure is written to ``<method>_corner.png``; that path is
        also stored in ``self.output_plot``.
        """

        sampler = self.method

        if sampler == 'emcee':
            chain = self.result.flatchain
            fig = corner.corner(
                chain,
                range=[0.99] * chain.shape[1],
                labels=self.fit_info.fit_param_names,
            )
        elif sampler == 'multinest':
            draws = self.result.samples
            fig = corner.corner(
                draws,
                weights=self.result.weights,
                range=[0.99] * draws.shape[1],
                labels=self.fit_info.fit_param_names,
            )

        # Record the destination first, then write the figure there
        plot_path = '{}_corner.png'.format(self.method)
        self.output_plot = plot_path
        fig.savefig(plot_path)
        print('Corner plot saved to {}'.format(plot_path))

    def retrieve_emcee(self):
        """Run the retrieval with emcee and keep the outcome in
        ``self.result``."""

        self.method = 'emcee'
        observations = (self.bins, self.depths, self.errors)
        self.result = self.retriever.run_emcee(*observations, self.fit_info)

    def retrieve_multinest(self):
        """Run the retrieval with multinested sampling (best-fit plotting
        disabled) and keep the outcome in ``self.result``."""

        self.method = 'multinest'
        observations = (self.bins, self.depths, self.errors)
        self.result = self.retriever.run_multinest(
            *observations, self.fit_info, plot_best=False)

    def save_results(self):
        """Write ``str(self.result)`` to ``<method>_results.dat``.

        The path is also stored in ``self.output_results``.
        """

        results_path = '{}_results.dat'.format(self.method)
        self.output_results = results_path
        with open(results_path, 'w') as handle:
            handle.write(str(self.result))
        print('Results file saved to {}'.format(results_path))

    def set_parameters(self, params):
        """Validate the retrieval parameters and build ``self.fit_info``.

        Required parameters include ``Rs``, ``Mp``, ``Rp``, and ``T``.
        Optional parameters include ``logZ``, ``CO_ratio``,
        ``log_cloudtop_P``, ``log_scatt_factor``, ``scatt_slope``,
        ``error_multiple``, and ``T_star``.

        Parameters
        ----------
        params : dict
            A dictionary of parameters and their values for running the
            software.  See "Use" documentation for further details.
        """

        for prepare in (_validate_parameters, _apply_factors):
            prepare(params)

        self.params = params
        self.fit_info = self.retriever.get_default_fit_info(**params)