Code example #1
File: theory.py  Project: nikfilippas/LotssCross
    def sample(self, verbosity=0, use_mpi=False):
        import emcee
        if use_mpi:
            from schwimmbad import MPIPool
            pool = MPIPool()
            print("Using MPI")
            pool_use = pool
        else:
            pool = DumPool()
            print("Not using MPI")
            pool_use = None

        if not pool.is_master():
            pool.wait()
            sys.exit(0)

        fname_chain = self.prefix_out + "chain"
        found_file = os.path.isfile(fname_chain + '.txt')

        if (not found_file) or self.rerun:
            pos_ini = (np.array(self.p0)[None, :] +
                       0.001 * np.random.randn(self.nwalkers, self.ndim))
            nsteps_use = self.nsteps
        else:
            print("Restarting from previous run")
            old_chain = np.loadtxt(fname_chain + '.txt')
            if np.ndim(old_chain) == 1:
                old_chain = np.atleast_2d(old_chain).T
            pos_ini = old_chain[-self.nwalkers:, :]
            nsteps_use = max(self.nsteps - len(old_chain) // self.nwalkers, 0)
            print(self.nsteps - len(old_chain) // self.nwalkers)

        chain_file = SampleFileUtil(self.prefix_out + "chain",
                                    rerun=self.rerun)
        sampler = emcee.EnsembleSampler(self.nwalkers,
                                        self.ndim,
                                        self.lnprob,
                                        pool=pool_use)
        counter = 1
        for pos, prob, _ in sampler.sample(pos_ini, iterations=nsteps_use):
            if pool.is_master():
                chain_file.persistSamplingValues(pos, prob)

                if counter % 10 == 0:
                    print(f"Finished sample {counter}")
            counter += 1

        pool.close()

        return sampler
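
The serial fallback DumPool used above (and again in examples #6, #7 and #9) is not defined in any of these snippets. A minimal sketch of what such a stand-in might look like, assuming it only needs the is_master/wait/close interface that the sampling loops call (the real class in these projects may differ):

class DumPool(object):
    """Hypothetical serial stand-in for schwimmbad's MPIPool (assumed interface)."""

    def is_master(self):
        # A single serial process is always the "master".
        return True

    def wait(self):
        # Nothing to wait for without worker processes.
        pass

    def close(self):
        pass
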
Code example #2
def postprocess_run(jobdir,
                    savename,
                    exp_type,
                    fields,
                    save_beta=False,
                    comm=None,
                    return_dframe=True):

    dframe = None  # only the serial path below builds a dataframe

    # Distribute postprocessing across ranks, if desired
    if comm is not None:
        # Collect all .dat data files
        if comm.rank == 0:
            data_files = grab_files(jobdir, '*.dat', exp_type)
            data_files = [(d) for d in data_files]
            print(len(data_files))
        else:
            data_files = None

        rank = comm.rank
        size = comm.size
        # print('Rank %d' % rank)
        master = StreamWorker(savename)
        worker = PostprocessWorker(jobdir, fields, rank, size)
        pool = MPIPool(comm)
        # Worker ranks wait here for tasks and exit once the master closes the
        # pool; the guard must come before map() so workers are listening when
        # tasks are sent.
        if not pool.is_master():
            pool.wait()
            sys.exit(0)
        pool.map(worker,
                 data_files,
                 callback=master.stream,
                 track_results=False)
        pool.close()

        if rank == 0:
            master.close()

    else:
        worker = PostprocessWorker(jobdir, fields, savename)
        # Serial path: collect the data files here, since the MPI branch above
        # was skipped.
        data_files = grab_files(jobdir, '*.dat', exp_type)
        for i, data_file in enumerate(data_files):
            t0 = time.time()
            result = worker(data_file)
            worker.extend(result)
            print(time.time() - t0)
        dframe = worker.save(save_beta)

    if return_dframe:
        return dframe
Code example #3
File: example.py  Project: aphriksee/eb-mcmc
def main_fit():
    pool = MPIPool()
    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    # Create an initial point
    psize = pos_lim_max - pos_lim_min
    p0 = [pos_lim_min + psize * np.random.rand(ndim) for i in range(nwalkers)]

    # Set a sample
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, pool=pool)

    # Run MCMC
    sampler.run_mcmc(p0, burnin)

    # Flat sample and discard 10% sample
    flat_samples = sampler.get_chain(discard=int(0.1 * burnin), flat=True)

    pool.close()
    return flat_samples
Code example #4
File: fitter.py  Project: Hoptune/MOSFiT
    def fit_events(self,
                   events=[],
                   models=[],
                   max_time='',
                   time_list=[],
                   time_unit=None,
                   band_list=[],
                   band_systems=[],
                   band_instruments=[],
                   band_bandsets=[],
                   band_sampling_points=17,
                   iterations=10000,
                   num_walkers=None,
                   num_temps=1,
                   parameter_paths=['parameters.json'],
                   fracking=True,
                   frack_step=50,
                   burn=None,
                   post_burn=None,
                   gibbs=False,
                   smooth_times=-1,
                   extrapolate_time=0.0,
                   limit_fitting_mjds=False,
                   exclude_bands=[],
                   exclude_instruments=[],
                   exclude_systems=[],
                   exclude_sources=[],
                   exclude_kinds=[],
                   output_path='',
                   suffix='',
                   upload=False,
                   write=False,
                   upload_token='',
                   check_upload_quality=False,
                   variance_for_each=[],
                   user_fixed_parameters=[],
                   user_released_parameters=[],
                   convergence_type=None,
                   convergence_criteria=None,
                   save_full_chain=False,
                   draw_above_likelihood=False,
                   maximum_walltime=False,
                   start_time=False,
                   print_trees=False,
                   maximum_memory=np.inf,
                   speak=False,
                   return_fits=True,
                   extra_outputs=None,
                   walker_paths=[],
                   catalogs=[],
                   exit_on_prompt=False,
                   download_recommended_data=False,
                   local_data_only=False,
                   method=None,
                   seed=None,
                   **kwargs):
        """Fit a list of events with a list of models."""
        global model
        if start_time is False:
            start_time = time.time()

        self._seed = seed
        if seed is not None:
            np.random.seed(seed)

        self._start_time = start_time
        self._maximum_walltime = maximum_walltime
        self._maximum_memory = maximum_memory
        self._debug = False
        self._speak = speak
        self._download_recommended_data = download_recommended_data
        self._local_data_only = local_data_only

        self._draw_above_likelihood = draw_above_likelihood

        prt = self._printer

        event_list = listify(events)
        model_list = listify(models)

        if len(model_list) and not len(event_list):
            event_list = ['']

        # Exclude catalogs not included in catalog list.
        self._fetcher.add_excluded_catalogs(catalogs)

        if not len(event_list) and not len(model_list):
            prt.message('no_events_models', warning=True)

        # If the input is not a JSON file, assume it is either a list of
        # transients or that it is the data from a single transient in tabular
        # form. Try to guess the format first, and if that fails ask the user.
        self._converter = Converter(prt, require_source=upload)
        event_list = self._converter.generate_event_list(event_list)

        event_list = [x.replace('‑', '-') for x in event_list]

        entries = [[] for x in range(len(event_list))]
        ps = [[] for x in range(len(event_list))]
        lnprobs = [[] for x in range(len(event_list))]

        # Load walker data if provided a list of walker paths.
        walker_data = []

        if len(walker_paths):
            try:
                pool = MPIPool()
            except (ImportError, ValueError):
                pool = SerialPool()
            if pool.is_master():
                prt.message('walker_file')
                wfi = 0
                for walker_path in walker_paths:
                    if os.path.exists(walker_path):
                        prt.prt('  {}'.format(walker_path))
                        with codecs.open(walker_path, 'r',
                                         encoding='utf-8') as f:
                            all_walker_data = json.load(
                                f, object_pairs_hook=OrderedDict)

                        # Support both the format where all data stored in a
                        # single-item dictionary (the OAC format) and the older
                        # MOSFiT format where the data was stored in the
                        # top-level dictionary.
                        if ENTRY.NAME not in all_walker_data:
                            all_walker_data = all_walker_data[list(
                                all_walker_data.keys())[0]]

                        models = all_walker_data.get(ENTRY.MODELS, [])
                        choice = None
                        if len(models) > 1:
                            model_opts = [
                                '{}-{}-{}'.format(x['code'], x['name'],
                                                  x['date']) for x in models
                            ]
                            choice = prt.prompt('select_model_walkers',
                                                kind='select',
                                                message=True,
                                                options=model_opts)
                            choice = model_opts.index(choice)
                        elif len(models) == 1:
                            choice = 0

                        if choice is not None:
                            walker_data.extend([[
                                wfi, x[REALIZATION.PARAMETERS],
                                x.get(REALIZATION.WEIGHT)
                            ] for x in models[choice][MODEL.REALIZATIONS]])

                        for i in range(len(walker_data)):
                            if walker_data[i][2] is not None:
                                walker_data[i][2] = float(walker_data[i][2])

                        if not len(walker_data):
                            prt.message('no_walker_data')
                    else:
                        prt.message('no_walker_data')
                        if self._offline:
                            prt.message('omit_offline')
                        raise RuntimeError
                    wfi = wfi + 1

                for rank in range(1, pool.size + 1):
                    pool.comm.send(walker_data, dest=rank, tag=3)
            else:
                walker_data = pool.comm.recv(source=0, tag=3)
                pool.wait()

            if pool.is_master():
                pool.close()

        self._event_name = 'Batch'
        self._event_path = ''
        self._event_data = {}

        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            fetched_events = self._fetcher.fetch(
                event_list,
                offline=self._offline,
                prefer_cache=self._prefer_cache)

            for rank in range(1, pool.size + 1):
                pool.comm.send(fetched_events, dest=rank, tag=0)
            pool.close()
        else:
            fetched_events = pool.comm.recv(source=0, tag=0)
            pool.wait()

        for ei, event in enumerate(fetched_events):
            if event is not None:
                self._event_name = event.get('name', 'Batch')
                self._event_path = event.get('path', '')
                if not self._event_path:
                    continue
                self._event_data = self._fetcher.load_data(event)
                if not self._event_data:
                    continue

            if model_list:
                lmodel_list = model_list
            else:
                lmodel_list = ['']

            entries[ei] = [None for y in range(len(lmodel_list))]
            ps[ei] = [None for y in range(len(lmodel_list))]
            lnprobs[ei] = [None for y in range(len(lmodel_list))]

            if (event is not None and
                (not self._event_data or ENTRY.PHOTOMETRY
                 not in self._event_data[list(self._event_data.keys())[0]])):
                prt.message('no_photometry', [self._event_name])
                continue

            for mi, mod_name in enumerate(lmodel_list):
                for parameter_path in parameter_paths:
                    try:
                        pool = MPIPool()
                    except (ImportError, ValueError):
                        pool = SerialPool()
                    self._model = Model(model=mod_name,
                                        data=self._event_data,
                                        parameter_path=parameter_path,
                                        output_path=output_path,
                                        wrap_length=self._wrap_length,
                                        test=self._test,
                                        printer=prt,
                                        fitter=self,
                                        pool=pool,
                                        print_trees=print_trees)

                    if not self._model._model_name:
                        prt.message('no_models_avail', [self._event_name],
                                    warning=True)
                        continue

                    if not event:
                        prt.message('gen_dummy')
                        self._event_name = mod_name
                        gen_args = {
                            'name': mod_name,
                            'max_time': max_time,
                            'time_list': time_list,
                            'band_list': band_list,
                            'band_systems': band_systems,
                            'band_instruments': band_instruments,
                            'band_bandsets': band_bandsets
                        }
                        self._event_data = self.generate_dummy_data(**gen_args)

                    success = False
                    alt_name = None
                    while not success:
                        self._model.reset_unset_recommended_keys()
                        success = self._model.load_data(
                            self._event_data,
                            event_name=self._event_name,
                            smooth_times=smooth_times,
                            extrapolate_time=extrapolate_time,
                            limit_fitting_mjds=limit_fitting_mjds,
                            exclude_bands=exclude_bands,
                            exclude_instruments=exclude_instruments,
                            exclude_systems=exclude_systems,
                            exclude_sources=exclude_sources,
                            exclude_kinds=exclude_kinds,
                            time_list=time_list,
                            time_unit=time_unit,
                            band_list=band_list,
                            band_systems=band_systems,
                            band_instruments=band_instruments,
                            band_bandsets=band_bandsets,
                            band_sampling_points=band_sampling_points,
                            variance_for_each=variance_for_each,
                            user_fixed_parameters=user_fixed_parameters,
                            user_released_parameters=user_released_parameters,
                            pool=pool)

                        if not success:
                            break

                        if self._local_data_only:
                            break

                        # If our data is missing recommended keys, offer the
                        # user option to pull the missing data from online and
                        # merge with existing data.
                        urk = self._model.get_unset_recommended_keys()
                        ptxt = prt.text('acquire_recommended',
                                        [', '.join(list(urk))])
                        while event and len(urk) and (
                                alt_name or self._download_recommended_data
                                or prt.prompt(ptxt, [', '.join(urk)],
                                              kind='bool')):
                            try:
                                pool = MPIPool()
                            except (ImportError, ValueError):
                                pool = SerialPool()
                            if pool.is_master():
                                en = (alt_name
                                      if alt_name else self._event_name)
                                extra_event = self._fetcher.fetch(
                                    en,
                                    offline=self._offline,
                                    prefer_cache=self._prefer_cache)[0]
                                extra_data = self._fetcher.load_data(
                                    extra_event)

                                for rank in range(1, pool.size + 1):
                                    pool.comm.send(extra_data,
                                                   dest=rank,
                                                   tag=4)
                                pool.close()
                            else:
                                extra_data = pool.comm.recv(source=0, tag=4)
                                pool.wait()

                            if extra_data is not None:
                                extra_data = extra_data[list(
                                    extra_data.keys())[0]]

                                for key in urk:
                                    new_val = extra_data.get(key)
                                    self._event_data[list(
                                        self._event_data.keys())
                                                     [0]][key] = new_val
                                    if new_val is not None and len(new_val):
                                        prt.message('extra_value', [
                                            key,
                                            str(new_val[0].get(QUANTITY.VALUE))
                                        ])
                                success = False
                                prt.message('reloading_merged')
                                break
                            else:
                                text = prt.text('extra_not_found',
                                                [self._event_name])
                                alt_name = prt.prompt(text, kind='string')
                                if not alt_name:
                                    break

                    if success:
                        self._walker_data = walker_data

                        entry, p, lnprob = self.fit_data(
                            event_name=self._event_name,
                            method=method,
                            iterations=iterations,
                            num_walkers=num_walkers,
                            num_temps=num_temps,
                            burn=burn,
                            post_burn=post_burn,
                            fracking=fracking,
                            frack_step=frack_step,
                            gibbs=gibbs,
                            pool=pool,
                            output_path=output_path,
                            suffix=suffix,
                            write=write,
                            upload=upload,
                            upload_token=upload_token,
                            check_upload_quality=check_upload_quality,
                            convergence_type=convergence_type,
                            convergence_criteria=convergence_criteria,
                            save_full_chain=save_full_chain,
                            extra_outputs=extra_outputs)
                        if return_fits:
                            entries[ei][mi] = deepcopy(entry)
                            ps[ei][mi] = deepcopy(p)
                            lnprobs[ei][mi] = deepcopy(lnprob)

                    if pool.is_master():
                        pool.close()

                    # Remove global model variable and garbage collect.
                    try:
                        model
                    except NameError:
                        pass
                    else:
                        del (model)
                    del (self._model)
                    gc.collect()

        return (entries, ps, lnprobs)
Code example #5
    cat = fits.open('catalogs/derived/%s_colored.fits' % sample_name)[1].data

    if stack_maps:
        for j in range(int(np.max(cat['colorbin']))):
            colorcat = cat[np.where(cat['colorbin'] == (j + 1))]
            stack_mp(planck_map,
                     colorcat['RA'],
                     colorcat['DEC'],
                     pool,
                     weighting=colorcat['weight'],
                     prob_weights=colorcat['PQSO'],
                     outname=(outname + '%s' % j),
                     imsize=imsize,
                     reso=reso)
        if not stack_noise:
            pool.close()

    if stack_noise:
        noisemaplist = glob.glob('noisemaps/maps/*.fits')
        colorcat = cat[np.where(cat['colorbin'] == 5)]
        for j in range(len(noisemaplist)):
            noisemap = hp.read_map('noisemaps/maps/%s.fits' % j,
                                   dtype=np.single)
            stack_mp(noisemap,
                     colorcat['RA'],
                     colorcat['DEC'],
                     pool,
                     weighting=colorcat['weight'],
                     prob_weights=colorcat['PQSO'],
                     outname='noise_stacks/map%s' % (j),
                     imsize=imsize,
                     reso=reso)
Code example #6
def main(path2config, time_likelihood):

    # Read params from yaml
    config = yaml.load(open(path2config))
    mode = config['modus_operandi']
    default_params = config['default_params']
    power_params = config['power_params']
    template_params = config['template_params']
    cl_params = config['cl_params']
    ch_config_params = config['ch_config_params']
    fit_params = config['fit_params']

    # parameters to fit
    nparams = len(fit_params.keys())
    param_mapping = {}
    params = np.zeros((nparams, 4))
    for key in fit_params.keys():
        param_mapping[key] = fit_params[key][0]
        params[fit_params[key][0], :] = fit_params[key][1:]

    # Cosmology
    if set(COSMO_PARAM_KEYS) == set(default_params.keys()):
        print("We are NOT varying the cosmology")
        #cosmo_dict = {}
        #for key in COSMO_PARAM_KEYS:
        #cosmo_dict[key] = default_params[key]
        cosmo = ccl.Cosmology(**default_params)
    else:
        print("We ARE varying the cosmology")
        cosmo = None

    if power_params['lmax'] == 'kmax':
        lmax = kmax2lmax(power_params['kmax'], power_params['z'], cosmo)
        power_params['lmax'] = lmax
    if power_params['lmin'] == 'kmax':
        lmin = 0.
        power_params['lmin'] = lmin

    # read data parameters
    Data = PowerData(mode, power_params)
    Data.setup()

    # read theory parameters
    Theory = PowerTheory(mode, Data.x, Data.z, template_params, cl_params,
                         power_params, default_params, param_mapping)
    Theory.setup()

    # initialize the Cl templates
    if mode == 'Cl' and cosmo != None:
        Theory.init_cl_ij(cosmo)

    # Make path to output
    if not os.path.isdir(os.path.expanduser(ch_config_params['path2output'])):
        try:
            os.makedirs(os.path.expanduser(ch_config_params['path2output']))
        except:
            pass

    # MPI option
    if ch_config_params['use_mpi']:
        from schwimmbad import MPIPool
        pool = MPIPool()
        print("Using MPI")
        pool_use = pool
    else:
        pool = DumPool()
        print("Not using MPI")
        pool_use = None

    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    # just time the likelihood calculation
    if time_likelihood:
        time_lnprob(params, Data, Theory)
        return

    # emcee parameters
    nwalkers = nparams * ch_config_params['walkersRatio']
    nsteps = ch_config_params['burninIterations'] + ch_config_params[
        'sampleIterations']

    # where to record
    prefix_chain = os.path.join(
        os.path.expanduser(ch_config_params['path2output']),
        ch_config_params['chainsPrefix'])

    # fix initial conditions
    found_file = os.path.isfile(prefix_chain + '.txt')
    if (not found_file) or (not ch_config_params['rerun']):
        p_initial = params[:, 0] + np.random.normal(
            size=(nwalkers, nparams)) * params[:, 3][None, :]
        nsteps_use = nsteps
    else:
        print("Restarting from a previous run")
        old_chain = np.loadtxt(prefix_chain + '.txt')
        p_initial = old_chain[-nwalkers:, :]
        nsteps_use = max(nsteps - len(old_chain) // nwalkers, 0)

    # initializing sampler
    chain_file = SampleFileUtil(prefix_chain,
                                carry_on=ch_config_params['rerun'])
    sampler = emcee.EnsembleSampler(nwalkers,
                                    nparams,
                                    lnprob,
                                    args=(params, Data, Theory),
                                    pool=pool_use)
    start = time.time()
    print("Running %d samples" % nsteps_use)

    # record every iteration
    counter = 1
    for pos, prob, _ in sampler.sample(p_initial, iterations=nsteps_use):
        if pool.is_master():
            print('Iteration done. Persisting.')
            chain_file.persistSamplingValues(pos, prob)

            if counter % 10 == 0:
                print(f"Finished sample {counter}")
        counter += 1

    pool.close()
    end = time.time()
    print("Took ", (end - start), " seconds")
Code example #7
def main(path2config, time_likelihood):

    # load the yaml parameters
    config = yaml.load(open(path2config))
    sim_params = config['sim_params']
    HOD_params = config['HOD_params']
    clustering_params = config['clustering_params']
    data_params = config['data_params']
    ch_config_params = config['ch_config_params']
    fit_params = config['fit_params']

    # create a new abacushod object and load the subsamples
    newBall = AbacusHOD(sim_params, HOD_params, clustering_params)

    # read data parameters
    newData = PowerData(data_params, HOD_params)

    # parameters to fit
    nparams = len(fit_params.keys())
    param_mapping = {}
    param_tracer = {}
    params = np.zeros((nparams, 4))
    for key in fit_params.keys():
        mapping_idx = fit_params[key][0]
        tracer_type = fit_params[key][-1]
        param_mapping[key] = mapping_idx
        param_tracer[key] = tracer_type
        params[mapping_idx, :] = fit_params[key][1:-1]

    # Make path to output
    if not os.path.isdir(os.path.expanduser(ch_config_params['path2output'])):
        try:
            os.makedirs(os.path.expanduser(ch_config_params['path2output']))
        except:
            pass

    # MPI option
    if ch_config_params['use_mpi']:
        from schwimmbad import MPIPool
        pool = MPIPool()
        print("Using MPI")
        pool_use = pool
    else:
        pool = DumPool()
        print("Not using MPI")
        pool_use = None

    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    # just time the likelihood calculation
    if time_likelihood:
        time_lnprob(params, param_mapping, param_tracer, newData, newBall)
        return

    # emcee parameters
    nwalkers = nparams * ch_config_params['walkersRatio']
    nsteps = ch_config_params['burninIterations'] + ch_config_params[
        'sampleIterations']

    # where to record
    prefix_chain = os.path.join(
        os.path.expanduser(ch_config_params['path2output']),
        ch_config_params['chainsPrefix'])

    # fix initial conditions
    found_file = os.path.isfile(prefix_chain + '.txt')
    if (not found_file) or (not ch_config_params['rerun']):
        p_initial = params[:, 0] + np.random.normal(
            size=(nwalkers, nparams)) * params[:, 3][None, :]
        nsteps_use = nsteps
    else:
        print("Restarting from a previous run")
        old_chain = np.loadtxt(prefix_chain + '.txt')
        p_initial = old_chain[-nwalkers:, :]
        nsteps_use = max(nsteps - len(old_chain) // nwalkers, 0)

    # initializing sampler
    chain_file = SampleFileUtil(prefix_chain,
                                carry_on=ch_config_params['rerun'])
    sampler = emcee.EnsembleSampler(nwalkers,
                                    nparams,
                                    lnprob,
                                    args=(params, param_mapping, param_tracer,
                                          newData, newBall),
                                    pool=pool_use)
    start = time.time()
    print("Running %d samples" % nsteps_use)

    # record every iteration
    counter = 1
    for pos, prob, _ in sampler.sample(p_initial, iterations=nsteps_use):
        if pool.is_master():
            print('Iteration done. Persisting.')
            chain_file.persistSamplingValues(pos, prob)

            if counter % 10 == 0:
                print(f"Finished sample {counter}")
        counter += 1

    pool.close()
    end = time.time()
    print("Took ", (end - start), " seconds")
Code example #8
	def MCMC(self, nwalkers=50, nburn=200, nMCMC=1000, use_MPI=False, chain_file='chain.dat', fig_name='./MCMC_corner.png', plot_corner=False, **kwargs):
		# The function to carry out MCMC. 
		# parameters:
		# 	nwalkers: int, optional
		# 		the number of walkers in MCMC, which must be even. 
		# 		default: 50
		# 	nburn: int, optional
		# 		the number of burn-in steps in MCMC.
		# 		default: 200
		# 	nMCMC: int, optional
		# 		the number of final MCMC steps in MCMC.
		# 		default: 1000
		# 	use_MPI: Boolean, optional
		# 		whether to use MPI. 
		# 		default: False
		# returns:
		# 	p_best: array_like
		# 		best fitting parameter set.
		
		# Initialize the walkers with a set of initial points, p0.
		E0 = np.random.normal(0.5, 0.3, size=nwalkers)
		T0 = np.random.normal(0.5, 0.3, size=nwalkers)
		a0 = np.random.normal(0, 0.7, size=nwalkers)
		covEE = truncnorm.rvs(0, 1, loc=0.3, scale=0.1, size=nwalkers)
		covTT = truncnorm.rvs(0, 1, loc=0.3, scale=0.1, size=nwalkers)
		covaa = truncnorm.rvs(0, 1, loc=0.1, scale=0.1, size=nwalkers)
		covEa = truncnorm.rvs(0, 1, loc=0.1, scale=0.1, size=nwalkers)
		p0 = [[E0[i], T0[i], a0[i], covEE[i], covTT[i], covaa[i], covEa[i]] for i in range(nwalkers)]
		print('start MCMC.')


		if not use_MPI:
			sampler = EnsembleSampler(nwalkers, self.ndim, lnprob, **kwargs)	
			# sampler = EnsembleSampler(nwalkers, self.ndim, lnprob, \
			# 		  args=(self.E_grid, self.T_grid, self.a_grid, self.ba_lgSMA_bins, self.bin_obs, self.num_obs), **kwargs)	

		# When using MPI, we differentiate between different processes.
		else:
			pool = MPIPool()
			if not pool.is_master():
				pool.wait()
				sys.exit(0)
			# sampler = EnsembleSampler(nwalkers, self.ndim, lnprob, \
			# 		  args=(self.E_grid, self.T_grid, self.a_grid, self.ba_lgSMA_bins, self.bin_obs, self.num_obs), pool=pool, **kwargs)
			sampler = EnsembleSampler(nwalkers, self.ndim, lnprob,  pool=pool, **kwargs)

		# burn-in phase
		pos, prob, state = sampler.run_mcmc(p0, nburn, chain_file=chain_file)
		sampler.reset()

		# MCMC phase
		sampler.run_mcmc(pos, nMCMC, chain_file=chain_file)

		if use_MPI:
			pool.close()
		
		# If we want to make classic corner plots...
		if plot_corner:
			samples = sampler.chain[:, nMCMC // 2:, :].reshape((-1, self.ndim))
			fig = corner.corner(samples, labels=['E', 'T', 'a', 'covEE', 'covTT', 'covaa', 'covEa'])
			fig.savefig(fig_name)

		# Get the best fitting parameters. We take the median parameter value for the ensemble
		# of steps with log-probabilities within the largest 30% among the whole ensemble as the
		# best parameters.
		samples = sampler.flatchain
		lnp = sampler.flatlnprobability
		crit_lnp = np.percentile(lnp, 70)
		good = np.where(lnp > crit_lnp)
		p_best = [np.median(samples[good, i]) for i in range(self.ndim)]

		return np.array(p_best)
Code example #9
File: sampler.py  Project: nikfilippas/yxgxk
    def sample(self, carry_on=False, verbosity=0, use_mpi=False):
        """
        Sample the posterior distribution

        Args:
            carry_on (bool): if True, the sampler will restart from
                its last iteration.
            verbosity (int): if >0, progress will be reported.
            use_mpi (bool): set to True to parallelize with MPI

        Returns:
            :obj:`emcee.EnsembleSampler`: sampler with chain.
        """
        import emcee
        if use_mpi:
            from schwimmbad import MPIPool
            pool = MPIPool()
            print("Using MPI")
            pool_use = pool
        else:
            pool = DumPool()
            print("Not using MPI")
            pool_use = None

        if not pool.is_master():
            pool.wait()
            sys.exit(0)

        fname_chain = self.prefix_out+"chain"
        found_file = os.path.isfile(fname_chain+'.txt')

        counter = 1
        if (not found_file) or (not carry_on):
            pos_ini = (np.array(self.p0)[None, :] +
                       0.001 * np.random.randn(self.nwalkers, self.ndim))
            nsteps_use = self.nsteps
        else:
            print("Restarting from previous run")
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                old_chain = np.loadtxt(fname_chain+'.txt')
            if old_chain.size != 0:
                pos_ini = old_chain[-self.nwalkers:, :]
                nsteps_use = max(self.nsteps-len(old_chain) // self.nwalkers, 0)
                counter = len(old_chain) // self.nwalkers
                # print(self.nsteps - len(old_chain) // self.nwalkers)
            else:
                pos_ini = (np.array(self.p0)[None, :] +
                           0.001 * np.random.randn(self.nwalkers, self.ndim))
                nsteps_use = self.nsteps

        chain_file = SampleFileUtil(self.prefix_out+"chain", carry_on=carry_on)
        sampler = emcee.EnsembleSampler(self.nwalkers,
                                        self.ndim,
                                        self.lnprob,
                                        pool=pool_use)

        for pos, prob, _ in sampler.sample(pos_ini, iterations=nsteps_use):
            if pool.is_master():
                chain_file.persistSamplingValues(pos, prob)
                if verbosity > 0:
                    print('Iteration done. Persisting.')

                    if (counter % 10) == 0:
                        print(f"Finished sample {counter}")
            counter += 1

        pool.close()

        return sampler
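
SampleFileUtil appears in examples #1, #6, #7 and #9 but is never defined in them. A minimal sketch of a chain writer with the same persistSamplingValues(pos, prob) interface and the carry_on keyword used in examples #6, #7 and #9 (example #1 passes rerun= instead), assuming it simply appends one row per walker to <prefix>.txt and the matching log-probabilities to a companion file; the actual helper in these projects may do more:

class SampleFileUtil(object):
    """Hypothetical chain writer used by the sampling loops above."""

    def __init__(self, prefix, carry_on=False):
        # Append when continuing a previous run, otherwise start fresh.
        mode = 'a' if carry_on else 'w'
        self.samples_file = open(prefix + '.txt', mode)
        self.prob_file = open(prefix + 'prob.txt', mode)

    def persistSamplingValues(self, pos, prob):
        # One row per walker: the parameter vector and, separately, its
        # log-probability.
        for row, lnp in zip(pos, prob):
            self.samples_file.write(' '.join('%.8e' % v for v in row) + '\n')
            self.prob_file.write('%.8e\n' % lnp)
        self.samples_file.flush()
        self.prob_file.flush()

    def close(self):
        self.samples_file.close()
        self.prob_file.close()
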
Code example #10
    #Read in and zip up dataframes
    
    sCM20_df = pd.read_hdf('models.h5','sCM20')
    lCM20_df = pd.read_hdf('models.h5','lCM20')
    aSilM5_df = pd.read_hdf('models.h5','aSilM5')
    wavelength_df = pd.read_hdf('models.h5','wavelength')
    
    pandas_dfs = [sCM20_df,lCM20_df,
                  aSilM5_df,wavelength_df]
    
    if args.mpi:
        
        mpi_pool = MPIPool()
        
        if not mpi_pool.is_master():
            mpi_pool.wait()
            sys.exit(0)
        
        mpi_pool.map( main,
                      range(len(flux_df)) )
        
        mpi_pool.close()
        
    else:
        
        for gal_row in range(len(flux_df)):
             
            main(gal_row)
    
    print('Code complete, took %.2fm' % ( (time.time() - start_time)/60 ))
Code example #11
    def megafit_emcee(self):
        nsamples = 300

        atm = []
        for j in range(len(self.names)):
            prod_name = os.path.join(parameters.PROD_DIRECTORY, parameters.PROD_NAME)
            atmgrid = AtmosphereGrid(
                filename=(prod_name + '/' + self.names[j].split('/')[-1]).replace('sim', 'reduc').replace(
                    'spectrum.txt', 'atmsim.fits'))
            atm.append(atmgrid)

        def matrice_data():
            y = self.data - self.order2
            nb_spectre = len(self.names)
            nb_bin = len(self.data)
            D = np.zeros(nb_bin * nb_spectre)
            for j in range(nb_spectre):
                D[j * nb_bin: (j + 1) * nb_bin] = y[:, j]
            return D

        def Atm(atm, ozone, eau, aerosols):
            nb_spectre = len(self.names)
            nb_bin = len(self.data)
            M = np.zeros((nb_spectre, nb_bin, nb_bin))
            M_p = np.zeros((nb_spectre * nb_bin, nb_bin))
            for j in range(nb_spectre):
                Atmo = np.zeros(len(self.new_lambda))
                Lambdas = np.arange(self.Bin[0],self.Bin[-1],0.2)
                Atm = atm[j].simulate(ozone, eau, aerosols)(Lambdas)
                for i in range(len(self.new_lambda)):
                    Atmo[i] = np.mean(Atm[i*int(self.binwidths/0.2):(i+1)*int(self.binwidths/0.2)])
                a = np.diagflat(Atmo)
                M[j, :, :] = a
                M_p[nb_bin * j:nb_bin * (j+1),:] = a
            return M, M_p

        def log_likelihood(params_fit, atm):
            nb_spectre = len(self.names)
            nb_bin = len(self.data)
            ozone, eau, aerosols = params_fit[-3], params_fit[-2], params_fit[-1]
            D = matrice_data()
            M, M_p = Atm(atm, ozone, eau, aerosols)
            prod = np.zeros((nb_bin, nb_spectre * nb_bin))
            for spec in range(nb_spectre):
                prod[:,spec * nb_bin : (spec+1) * nb_bin] = M[spec] @ self.INVCOV[spec]
            COV = inv(prod @ M_p)
            A = COV @ prod @ D

            chi2 = 0
            for spec in range(nb_spectre):
                mat = D[spec * nb_bin : (spec+1) * nb_bin] - M[spec] @ A
                chi2 += mat @ self.INVCOV[spec] @ mat

            n = np.random.randint(0, 100)
            if n > 97:
                print(chi2 / (nb_spectre * nb_bin))
                print(ozone, eau, aerosols)
            return -0.5 * chi2

        def log_prior(params_fit):
            ozone, eau, aerosols = params_fit[-3], params_fit[-2], params_fit[-1]
            if 100 < ozone < 700 and 0 < eau < 10 and 0 < aerosols < 0.1:
                return 0
            else:
                return -np.inf

        def log_probability(params_fit, atm):
            lp = log_prior(params_fit)
            if not np.isfinite(lp):
                return -np.inf
            return lp + log_likelihood(params_fit, atm)

        if self.sim:
            filename = "sps/" + self.disperseur + "_"+ "sim_"+ parameters.PROD_NUM + "_emcee.h5"
        else:
            filename = "sps/" + self.disperseur + "_" + "reduc_" + parameters.PROD_NUM + "_emcee.h5"

        p_ozone = 300
        p_eau = 5
        p_aerosols = 0.03
        p0 = np.array([p_ozone, p_eau, p_aerosols])
        walker = 10

        init_ozone = p0[0] + p0[0] / 5 * np.random.randn(walker)
        init_eau = p0[1] + p0[1] / 5 * np.random.randn(walker)
        init_aerosols = p0[2] + p0[2] / 5 * np.random.randn(walker)

        p0 = np.array([[init_ozone[i], init_eau[i], init_aerosols[i]] for i in range(walker)])
        nwalkers, ndim = p0.shape

        backend = emcee.backends.HDFBackend(filename)
        try:
            pool = MPIPool()
            if not pool.is_master():
                pool.wait()
                sys.exit(0)
            sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability,
                                            args=(atm,), pool=pool, backend=backend)
            if backend.iteration > 0:
                p0 = backend.get_last_sample()

            if nsamples - backend.iteration > 0:
                sampler.run_mcmc(p0, nsteps=max(0, nsamples - backend.iteration), progress=True)
            pool.close()
        except ValueError:
            sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability,
                                            args=(atm,),
                                            threads=multiprocessing.cpu_count(), backend=backend)
            if backend.iteration > 0:
                p0 = sampler.get_last_sample()
            for _ in sampler.sample(p0, iterations=max(0, nsamples - backend.iteration), progress=True, store=True):
                continue

        flat_samples = sampler.get_chain(discard=100, thin=1, flat=True)

        ozone, d_ozone = np.mean(flat_samples[:, -3]), np.std(flat_samples[:, -3])
        eau, d_eau = np.mean(flat_samples[:, -2]), np.std(flat_samples[:, -2])
        aerosols, d_aerosols = np.mean(flat_samples[:, -1]), np.std(flat_samples[:, -1])
        print(ozone, d_ozone)
        print(eau, d_eau)
        print(aerosols, d_aerosols)
        self.params_atmo = np.array([ozone, eau, aerosols])
        self.err_params_atmo = np.array([d_ozone, d_eau, d_aerosols])

        nb_spectre = len(self.names)
        nb_bin = len(self.data)
        M, M_p = Atm(atm, ozone, eau, aerosols)
        prod = np.zeros((nb_bin, nb_spectre * nb_bin))
        chi2 = 0
        for spec in range(nb_spectre):
            prod[:, spec * nb_bin: (spec + 1) * nb_bin] = M[spec] @ self.INVCOV[spec]

        COV = inv(prod @ M_p)
        D = matrice_data()
        Tinst = COV @ prod @ D
        Tinst_err = np.array([np.sqrt(COV[i,i]) for i in range(len(Tinst))])

        if self.disperseur == 'HoloAmAg' and self.sim == False:
            a, b = np.argmin(abs(self.new_lambda - 537.5)), np.argmin(abs(self.new_lambda - 542.5))
            Tinst_err[a], Tinst_err[b] = Tinst_err[a-1], Tinst_err[b+1]

        for spec in range(nb_spectre):
            mat = D[spec * nb_bin: (spec + 1) * nb_bin] - M[spec] @ Tinst
            chi2 += mat @ self.INVCOV[spec] @ mat
        print(chi2 / (nb_spectre * nb_bin))

        err = np.zeros_like(D)
        for j in range(len(self.names)):
            for i in range(len(self.data)):
                if self.disperseur == 'HoloAmAg' and self.sim == False:
                    if self.new_lambda[i] == 537.5 or self.new_lambda[i] == 542.5:
                        err[j * len(self.data) + i] = 1
                    else:
                        err[j * len(self.data) + i] = np.sqrt(self.cov[j][i, i])
                else:
                    err[j * len(self.data) + i] = np.sqrt(self.cov[j][i, i])

        model = M_p @ Tinst
        Err = (D - model) / err

        if parameters.plot_residuals:
            self.Plot_residuals(model, D, Err, COV)

        return Tinst, Tinst_err
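
For reference, the linear algebra inside log_likelihood (and repeated after sampling) is a generalized least-squares solve for the instrumental transmission. With $D_j$ the data vector of spectrum $j$, $C_j^{-1}$ its inverse covariance (self.INVCOV[j]) and $M_j$ the diagonal atmospheric matrix built by Atm, the code computes

$$\hat T_{\rm inst} = \Big(\sum_j M_j^\top C_j^{-1} M_j\Big)^{-1} \sum_j M_j^\top C_j^{-1} D_j,
\qquad \mathrm{Cov}(\hat T_{\rm inst}) = \Big(\sum_j M_j^\top C_j^{-1} M_j\Big)^{-1},$$

and returns $-\chi^2/2$ with

$$\chi^2 = \sum_j \big(D_j - M_j \hat T_{\rm inst}\big)^\top C_j^{-1} \big(D_j - M_j \hat T_{\rm inst}\big).$$

These correspond to the variables prod, COV, A (or Tinst) and chi2 in the snippet, using $M_j^\top = M_j$ since $M_j$ is diagonal.
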
Code example #12
    print('Running {0} iterations with {1} triangles'.format(
        nsteps_torun, early_iters_num_triangles))
    sampler.run_mcmc(None, nsteps_torun, progress=True)

# If early iterations completed, switch to higher number of triangles and run remaining trials    
nsteps_completed = backend.iteration
if nsteps_completed >= early_iters_cutoff:
    ## Switch to higher number of triangles
    mcmc_fit_obj.set_model_numTriangles(final_iters_num_triangles)
    
    print('Running {0} iterations with {1} triangles'.format(
        trial_samp_len - nsteps_completed, final_iters_num_triangles))
    sampler.run_mcmc(None, trial_samp_len - nsteps_completed, progress=True)

# Close multiprocessing pool
mp_pool.close()

# Print burnin and thin lengths
tau = backend.get_autocorr_time()
burnin = int(2 * np.max(tau))
thin = int(0.5 * np.min(tau))

print("Samples burn-in: {0}".format(burnin))
print("Samples thin: {0}".format(thin))


print(backend.get_chain())
print(backend.get_log_prob())
print(backend.get_log_prior())
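
As a usage note on the burn-in and thin lengths computed above from the autocorrelation time, one common way to apply them when reading the chain back is via emcee's get_chain/get_log_prob keywords (flat_samples and flat_log_prob are just illustrative names):

# Flatten the stored chain, dropping the burn-in and thinning by the
# autocorrelation estimate.
flat_samples = backend.get_chain(discard=burnin, thin=thin, flat=True)
flat_log_prob = backend.get_log_prob(discard=burnin, thin=thin, flat=True)
print(flat_samples.shape, flat_log_prob.shape)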

Code example #13
File: merged.py  Project: faircloth-lab/itero
def main(args, parser, mpi=False):
    if mpi:
        from schwimmbad import MPIPool
        # open up a pool of MPI processes using schwimmbad
        mpi_pool = MPIPool()
        # add this line for MPI compatibility
        if not mpi_pool.is_master():
            mpi_pool.wait()
            sys.exit(0)
    else:
        import multiprocessing
    start_time = time.time()
    # setup logging
    log, my_name = setup_logging(args)
    if mpi:
        # UNIQUE TO MPI CODE - so that processes will die if output directory
        # exists. So, make the output directory or die
        if os.path.exists(args.output):
            log.critical("THE OUTPUT DIRECTORY EXISTS.  QUITTING.")
            mpi_pool.close()
            sys.exit(1)
        else:
            # create the new directory
            os.makedirs(args.output)
    # get seeds from config file
    conf = ConfigParser.ConfigParser(allow_no_value=True)
    conf.optionxform = str
    conf.read(args.config)
    # get the seed file info
    seeds = common.get_seed_file(conf, args.config)
    # get name of all loci in seeds file - only need to do this once
    seed_names = common.get_seed_names(seeds)
    # get the input data
    log.info("Getting input filenames and creating output directories")
    individuals = common.get_input_data(log, args, conf)
    for individual in individuals:
        sample, dir = individual
        # pretty print taxon status
        text = " Processing {} ".format(sample)
        log.info(text.center(65, "-"))
        # make a directory for sample-specific assemblies
        sample_dir = os.path.join(args.output, sample)
        os.makedirs(sample_dir)
        # determine how many files we're dealing with
        try:
            fastq = raw_reads.get_input_files(dir, args.subfolder, log)
        except IOError as err:
            log.critical(
                "THERE WAS A PROBLEM WITH THE FASTQ FILES.  QUITTING.")
            if mpi:
                mpi_pool.close()
            traceback.print_exc()
            sys.exit(1)
        iterations = list(xrange(args.iterations)) + ['final']
        next_to_last_iter = iterations[-2]
        for iteration in iterations:
            text = " Iteration {} ".format(iteration)
            log.info(text.center(45, "-"))
            # One the last few iterations, set some things up differently to deal w/ dupe contigs.
            # First, we'll allow multiple contigs during all but the last few rounds of contig assembly.
            # This is because we could be assembling different parts of a locus that simply have not
            # merged in the middle yet (but will).  We'll turn option to remove multiple contigs
            # back on for last three rounds
            if iteration in iterations[-3:]:
                if args.allow_multiple_contigs is True:
                    allow_multiple_contigs = True
                else:
                    allow_multiple_contigs = False
            else:
                allow_multiple_contigs = True
            sample_dir_iter = os.path.join(sample_dir,
                                           "iter-{}".format(iteration))
            os.makedirs(sample_dir_iter)
            # change to sample_dir_iter
            os.chdir(sample_dir_iter)
            # copy seeds file
            if iteration == 0 and os.path.dirname(seeds) != os.getcwd():
                shutil.copy(seeds, os.getcwd())
                seeds = os.path.join(os.getcwd(), os.path.basename(seeds))
            elif iteration >= 1:
                shutil.copy(new_seeds, os.getcwd())
                seeds = os.path.join(os.getcwd(), os.path.basename(new_seeds))
            # if we are finished with it, cleanup the previous iteration
            if not args.do_not_zip and iteration >= 1:
                # after assembling all loci, zip the iter-#/loci directory; this will be slow if --clean is not turned on.
                prev_iter = common.get_previous_iter(log, sample_dir_iter,
                                                     iterations, iteration)
                zipped = common.zip_assembly_dir(log, sample_dir_iter,
                                                 args.clean, prev_iter)
            #index the seed file
            bwa.bwa_index_seeds(seeds, log)
            # map initial reads to seeds
            bam = bwa.bwa_mem_pe_align(log, sample, sample_dir_iter, seeds,
                                       args.local_cores, fastq.r1, fastq.r2,
                                       iteration)
            # reduce bam to mapping reads
            reduced_bam = samtools.samtools_reduce(log,
                                                   sample,
                                                   sample_dir_iter,
                                                   bam,
                                                   iteration=iteration)
            # remove the un-reduced BAM
            os.remove(bam)
            # if we are not on our last iteration, assembly as usual
            if iteration != 'final':
                log.info("Splitting BAM by locus to SAM")
                header = samtools.get_bam_header(log, reduced_bam, iteration)
                sample_dir_iter_locus_temp = samtools.faster_split_bam(
                    log, reduced_bam, sample_dir_iter, iteration)
                if args.only_single_locus:
                    locus_names = ['locus-1']
                else:
                    # get list of loci in sorted bam
                    locus_names = samtools.samtools_get_locus_names_from_bam(
                        log, reduced_bam, iteration)
                log.info("Reheadering split SAMs")
                samtools.reheader_split_sams(log, sample_dir_iter,
                                             sample_dir_iter_locus_temp,
                                             header, locus_names)
                log.info("Removing temporary SAM files")
                shutil.rmtree(sample_dir_iter_locus_temp)
                log.info("Assembling")
                # MPI-specific bits
                tasks = [(iteration, sample, sample_dir_iter, locus_name,
                          args.clean, args.only_single_locus)
                         for locus_name in locus_names]
                if mpi:
                    results = mpi_pool.map(common.initial_assembly, tasks)
                # multiprocessing specific bits
                else:
                    if not args.only_single_locus and args.local_cores > 1:
                        assert args.local_cores <= multiprocessing.cpu_count(
                        ), "You've specified more cores than you have"
                        pool = multiprocessing.Pool(args.local_cores)
                        pool.map(common.initial_assembly, tasks)
                    elif args.only_single_locus:
                        map(common.initial_assembly, tasks)
                    else:
                        map(common.initial_assembly, tasks)
                # after assembling all loci, get them into a single file
                new_seeds = common.get_fasta(log,
                                             sample,
                                             sample_dir_iter,
                                             locus_names,
                                             allow_multiple_contigs,
                                             iteration=iteration)
                # after assembling all loci, report on deltas of the assembly length
                if iteration != 0:
                    assembly_delta = common.get_deltas(log,
                                                       sample,
                                                       sample_dir_iter,
                                                       iterations,
                                                       iteration=iteration)
            elif iteration == 'final':
                log.info(
                    "Final assemblies and a BAM file with alignments to those assemblies are in {}/iter-{}"
                    .format(os.path.join(args.output, individual[0]),
                            iteration))
                # reset the seeds file to the starting value
                seeds = common.get_seed_file(conf, args.config)
Code example #14
    def megafit_emcee(self):
        nsamples = 300

        atm = []
        for j in range(len(self.names)):
            prod_name = os.path.join(parameters.PROD_DIRECTORY,
                                     parameters.PROD_NAME)
            atmgrid = AtmosphereGrid(filename=(
                prod_name + '/' + self.names[j].split('/')[-1]
            ).replace('sim', 'reduc').replace('spectrum.txt', 'atmsim.fits'))
            atm.append(atmgrid)

        #D = self.matrice_data()
        """
        def f_tinst_atm(Tinst, ozone, eau, aerosols, atm):
            model = np.zeros((len(self.data_mag), len(self.names)))
            for j in range(len(self.names)):
                a = atm[j].simulate(ozone, eau, aerosols)
                model[:, j] = Tinst * a(self.new_lambda)
            return model
        """
        def Atm(atm, ozone, eau, aerosols):
            nb_spectre = len(self.names)
            nb_bin = len(self.data_mag)
            M = np.zeros((nb_spectre, nb_bin, nb_bin))
            M_p = np.zeros((nb_spectre * nb_bin, nb_bin))
            for j in range(nb_spectre):
                Atmo = np.zeros(len(self.new_lambda))
                Lambdas = np.arange(self.Bin[0], self.Bin[-1], 0.2)
                Atm = atm[j].simulate(ozone, eau, aerosols)(Lambdas)
                for i in range(len(self.new_lambda)):
                    #step = int(self.binwidths * 10)
                    #X = np.linspace(self.Bin[i], self.Bin[i + 1] + step, step)
                    Atmo[i] = np.mean(
                        Atm[i * int(self.binwidths / 0.2):(i + 1) *
                            int(self.binwidths / 0.2)])
                a = np.diagflat(Atmo)
                M[j, :, :] = a
                M_p[nb_bin * j:nb_bin * (j + 1), :] = a
            return M, M_p

        def log_likelihood(params_fit, atm):
            nb_spectre = len(self.names)
            nb_bin = len(self.data_mag)
            ozone, eau, aerosols = params_fit[-3], params_fit[-2], params_fit[
                -1]
            A2 = 0
            D = self.matrice_data(A2)
            M, M_p = Atm(atm, ozone, eau, aerosols)
            #A = np.zeros(nb_bin)
            #COV = np.zeros((nb_bin, nb_bin))
            prod = np.zeros((nb_bin, nb_spectre * nb_bin))
            for spec in range(nb_spectre):
                prod[:, spec * nb_bin:(spec + 1) *
                     nb_bin] = M[spec] @ self.INVCOV[spec]
            COV = inv(prod @ M_p)
            A = COV @ prod @ D

            chi2 = 0
            for spec in range(nb_spectre):
                mat = D[spec * nb_bin:(spec + 1) * nb_bin] - M[spec] @ A
                chi2 += mat @ self.INVCOV[spec] @ mat

            n = np.random.randint(0, 100)
            if n > 97:
                print(chi2 / (nb_spectre * nb_bin))
                print(A2, ozone, eau, aerosols)

            return -0.5 * chi2

        def log_prior(params_fit):
            ozone, eau, aerosols = params_fit[-3], params_fit[-2], params_fit[
                -1]
            if 100 < ozone < 700 and 0 < eau < 10 and 0 < aerosols < 0.1:  # and 0<A2<5:
                return 0
            else:
                return -np.inf

        def log_probability(params_fit, atm):
            lp = log_prior(params_fit)
            if not np.isfinite(lp):
                return -np.inf
            return lp + log_likelihood(params_fit, atm)

        if self.sim:
            filename = "sps/" + self.disperseur + "_" + "sim_new_" + parameters.PROD_NUM + "_emcee.h5"  #+ "sim_new_"
        else:
            filename = "sps/" + self.disperseur + "_" + "reduc_new_" + parameters.PROD_NUM + "_emcee.h5"

        if os.path.exists(filename):
            slope, ord, err_slope, err_ord = self.bouguer_line()
        else:
            slope, ord, err_slope, err_ord, A2, A2_err = self.bouguer_line_order2(
            )
        p_ozone = 300
        p_eau = 5
        p_aerosols = 0.03
        #p_A2 = 1
        p0 = np.array([p_ozone, p_eau, p_aerosols])
        walker = 10

        #init_A2 = p0[0] + p0[0] / 5 * np.random.randn(walker)
        init_ozone = p0[0] + p0[0] / 5 * np.random.randn(walker)
        init_eau = p0[1] + p0[1] / 5 * np.random.randn(walker)
        init_aerosols = p0[2] + p0[2] / 5 * np.random.randn(walker)

        p0 = np.array([[init_ozone[i], init_eau[i], init_aerosols[i]]
                       for i in range(walker)])
        nwalkers, ndim = p0.shape
        """
        plt.errorbar(self.new_lambda, (np.exp(self.data_mag) - self.order2)[:,10], yerr=(self.err_mag * np.exp(self.data_mag))[:,10])
        plt.show()
        """
        backend = emcee.backends.HDFBackend(filename)
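        # Prefer an MPI pool when the script is launched under mpirun/mpiexec.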
        try:
            pool = MPIPool()
            if not pool.is_master():
                pool.wait()
                sys.exit(0)
            sampler = emcee.EnsembleSampler(nwalkers,
                                            ndim,
                                            log_probability,
                                            args=(atm, ),
                                            pool=pool,
                                            backend=backend)
            if backend.iteration > 0:
                p0 = backend.get_last_sample()

            if nsamples - backend.iteration > 0:
                sampler.run_mcmc(p0,
                                 nsteps=max(0, nsamples - backend.iteration),
                                 progress=True)
            pool.close()
        except ValueError:
            # MPIPool raises ValueError when the script is not launched under
            # MPI; in that case sample serially (the ``threads`` keyword is
            # deprecated and ignored since emcee 3, so it is dropped here).
            sampler = emcee.EnsembleSampler(nwalkers,
                                            ndim,
                                            log_probability,
                                            args=(atm, ),
                                            backend=backend)
            if backend.iteration > 0:
                p0 = sampler.get_last_sample()
            if nsamples - backend.iteration > 0:
                sampler.run_mcmc(p0,
                                 nsteps=nsamples - backend.iteration,
                                 progress=True)

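        # Posterior summary: flatten the chain after a 100-step burn-in and take
        # the mean and standard deviation of each atmospheric parameter.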
        flat_samples = sampler.get_chain(discard=100, thin=1, flat=True)

        ozone, d_ozone = np.mean(flat_samples[:, -3]), np.std(flat_samples[:, -3])
        eau, d_eau = np.mean(flat_samples[:, -2]), np.std(flat_samples[:, -2])
        aerosols, d_aerosols = np.mean(flat_samples[:, -1]), np.std(flat_samples[:, -1])
        print(ozone, d_ozone)
        print(eau, d_eau)
        print(aerosols, d_aerosols)
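        # Recompute the generalized least-squares throughput solution and its
        # covariance at the best-fit atmospheric parameters, and print the
        # reduced chi2 of the joint fit.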
        nb_spectre = len(self.names)
        nb_bin = len(self.data_mag)
        M, M_p = Atm(atm, ozone, eau, aerosols)
        prod = np.zeros((nb_bin, nb_spectre * nb_bin))
        chi2 = 0
        for spec in range(nb_spectre):
            prod[:, spec * nb_bin:(spec + 1) *
                 nb_bin] = M[spec] @ self.INVCOV[spec]

        COV = inv(prod @ M_p)
        A2 = 0
        D = self.matrice_data(A2)
        Tinst = COV @ prod @ D
        Tinst_err = np.array([np.sqrt(COV[i, i]) for i in range(len(Tinst))])
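        # The two bins closest to 537.5 and 542.5 nm are assigned a negligible
        # uncertainty (presumably to single them out downstream).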
        a, b = np.argmin(abs(self.new_lambda - 537.5)), np.argmin(
            abs(self.new_lambda - 542.5))
        Tinst_err[a], Tinst_err[b] = 1e-16, 1e-16
        A = COV @ prod @ D
        for spec in range(nb_spectre):
            mat = D[spec * nb_bin:(spec + 1) * nb_bin] - M[spec] @ A
            chi2 += mat @ self.INVCOV[spec] @ mat
        print(chi2 / (nb_spectre * nb_bin))

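        # Helper: turn a covariance matrix into a correlation matrix.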
        def compute_correlation_matrix(cov):
            rho = np.zeros_like(cov)
            for i in range(cov.shape[0]):
                for j in range(cov.shape[1]):
                    rho[i, j] = cov[i, j] / np.sqrt(cov[i, i] * cov[j, j])
            return rho

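        # Helper: display a correlation matrix as a colour-coded image with
        # parameter names on both axes.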
        def plot_correlation_matrix_simple(ax, rho, axis_names, ipar=None):
            if ipar is None:
                ipar = np.arange(rho.shape[0]).astype(int)
            im = plt.imshow(rho[ipar[:, None], ipar],
                            interpolation="nearest",
                            cmap='bwr',
                            vmin=-1,
                            vmax=1)
            ax.set_title("Correlation matrix")
            names = [axis_names[ip] for ip in ipar]
            plt.xticks(np.arange(ipar.size),
                       names,
                       rotation='vertical',
                       fontsize=11)
            plt.yticks(np.arange(ipar.size), names, fontsize=11)
            cbar = plt.colorbar(im)
            cbar.ax.tick_params(labelsize=9)
            plt.gcf().tight_layout()

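        # Plot the normalized residual map (data - model) / err: one row per
        # spectrum (sorted by exposure index), one column per wavelength bin.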
        def plot_err(err, ipar=None):
            rho = np.zeros((len(self.names), len(self.data_mag)))
            # Exposure indices extracted from the spectrum file names.
            test = [
                int(self.names[i][-16:-13]) for i in range(len(self.names))
            ]
            print(test)
            # The same indices sorted in ascending order.
            test2 = sorted(test)
            print(test2)
            axis_names_vert = []
            for i in range(rho.shape[0]):
                k = np.argmin(abs(np.array(test) - test2[i]))
                axis_names_vert.append(str(test[k]))
                for j in range(rho.shape[1]):
                    rho[i, j] = err[k * len(self.data_mag) + j]
            if ipar is None:
                vert = np.arange(rho.shape[0]).astype(int)
                hor = np.arange(rho.shape[1]).astype(int)
            """
            gs_kw = dict(height_ratios=[1], width_ratios=[5,1])
            fig, ax = plt.subplots(1, 2, figsize=[15, 15], constrained_layout=True, gridspec_kw=gs_kw)
            ax1, ax2 = ax[0], ax[1]
            """
            fig = plt.figure(figsize=[15, 7])
            ax = fig.add_subplot(111)
            axis_names_hor = [
                str(self.new_lambda[i]) for i in range(len(self.new_lambda))
            ]
            im = plt.imshow(rho[vert[:, None], hor],
                            interpolation="nearest",
                            cmap='bwr',
                            vmin=-5,
                            vmax=5)
            ax.set_title(self.disperseur, fontsize=21)
            # Mean and scatter of the normalized residuals.
            print(np.mean(rho))
            print(np.std(rho))
            names_vert = [axis_names_vert[ip] for ip in vert]
            names_hor = [axis_names_hor[ip] for ip in hor]
            plt.xticks(np.arange(0, hor.size, 3),
                       names_hor[::3],
                       rotation='vertical',
                       fontsize=14)
            plt.yticks(np.arange(0, vert.size, 3),
                       names_vert[::3],
                       fontsize=14)
            plt.xlabel(r'$\lambda$ [nm]', fontsize=17)
            plt.ylabel('Spectrum index', fontsize=17)
            cbar = plt.colorbar(im, orientation='horizontal')
            cbar.set_label(r'Residuals in $\sigma$', fontsize=20)
            cbar.ax.tick_params(labelsize=13)
            plt.gcf().tight_layout()
            fig.tight_layout()
            if self.sim:
                plt.savefig(parameters.OUTPUTS_THROUGHPUT_SIM +
                            'throughput_simb, ' + self.disperseur +
                            ',résidus, version_' + parameters.PROD_NUM +
                            '.pdf')
            else:
                plt.savefig(parameters.OUTPUTS_THROUGHPUT_REDUC +
                            'throughput_reduc, ' + self.disperseur +
                            ',résidus, version_' + parameters.PROD_NUM +
                            '.pdf')

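        # Correlation matrix of the fitted throughput bins.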
        fig = plt.figure(figsize=[15, 10])
        ax = fig.add_subplot(111)
        axis_names = [str(i) for i in range(len(COV))]
        plot_correlation_matrix_simple(ax,
                                       compute_correlation_matrix(COV),
                                       axis_names,
                                       ipar=None)

        err = np.zeros_like(D)
        for j in range(len(self.names)):
            for i in range(len(self.data_mag)):
                # Per-bin uncertainty from the diagonal of each spectrum's covariance.
                err[j * len(self.data_mag) + i] = np.sqrt(self.cov[j][i, i])
        model = M_p @ Tinst
        Err = (D - model) / err
        plot_err(Err)
        fig = plt.figure(figsize=[15, 15])
        ax = fig.add_subplot(111)
        ax.hist(Err, bins=np.arange(-10, 10, 1))
        ax.set_title(
            "Histogram of residuals: (data - model) / err for HoloAmAg data, version "
            + parameters.PROD_NUM)
        plt.xlabel('Deviation from the data', fontsize=14)
        plt.grid(True)
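        # Overlay data and best-fit model for a few selected exposures (Kplot).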
        test = [int(self.names[i][-16:-13]) for i in range(len(self.names))]
        Kplot = [181, 186, 191, 196]
        for i in range(len(Kplot)):
            k = np.argmin(abs(np.array(test) - Kplot[i]))
            fig = plt.figure(figsize=[15, 15])
            ax = fig.add_subplot(111)
            plt.plot(self.new_lambda,
                     model[k * len(self.new_lambda):(k + 1) *
                           len(self.new_lambda)],
                     c='red',
                     label='model')
            plt.plot(self.new_lambda,
                     D[k * len(self.new_lambda):(k + 1) *
                       len(self.new_lambda)],
                     c='blue',
                     label='data')
            plt.title("spectrum: " + str(Kplot[i]))
            plt.grid(True)
            plt.legend()

        plt.show()
        return Tinst, Tinst_err