Example #1
# The request.cls/yield pattern implies a pytest fixture; the import and
# decorator are added here, and the class scope is an assumption.
import pytest


@pytest.fixture(scope='class')
def pool(request):
    # 'None' is a string sentinel meaning "no pool", not the None object.
    multimode = 'None'
    # multimode = 'Serial'
    # multimode = 'Multi'
    # multimode = 'MPI'

    # setup code
    pool = None
    if multimode == 'Serial':
        from schwimmbad import SerialPool
        pool = SerialPool()
    elif multimode == 'Multi':
        from schwimmbad import MultiPool
        pool = MultiPool()
    elif multimode == 'MPI':
        from schwimmbad import MPIPool
        pool = MPIPool()
        if not pool.is_master():
            pool.wait()
            import sys
            sys.exit(0)

    # inject class variables
    request.cls.pool = pool
    yield

    # tear down
    if multimode in ('Multi', 'MPI'):
        pool.close()
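

# A minimal sketch (not part of the original example) of a test class that
# consumes the fixture above; the class and test names are hypothetical.
@pytest.mark.usefixtures('pool')
class TestWithPool:
    def test_pool_injected(self):
        # self.pool is None in 'None' mode, otherwise a schwimmbad pool.
        if self.pool is not None:
            assert list(self.pool.map(abs, [-1, 2])) == [1, 2]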
Example #2
# Imports assumed by this test (added here; the original excerpt omitted
# them); get_prior is a helper defined elsewhere in the test suite.
import os

import numpy as np
import pytest
from schwimmbad import SerialPool
from thejoker import TheJoker


def test_init(case):
    prior, _ = get_prior(case)

    # Try various initializations
    TheJoker(prior)

    with pytest.raises(TypeError):
        TheJoker('jsdfkj')

    # Pools:
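    # (schwimmbad pools are context managers: __exit__ closes the pool)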
    with SerialPool() as pool:
        TheJoker(prior, pool=pool)

    # fail when pool is invalid:
    with pytest.raises(TypeError):
        TheJoker(prior, pool='sdfks')

    # Random state:
    rnd = np.random.default_rng(42)
    TheJoker(prior, random_state=rnd)

    # fail when random state is invalid:
    with pytest.raises(TypeError):
        TheJoker(prior, random_state='sdfks')

    with pytest.warns(DeprecationWarning):
        rnd = np.random.RandomState(42)
        TheJoker(prior, random_state=rnd)

    # tempfile location:
    joker = TheJoker(prior, tempfile_path='/tmp/joker')
    assert os.path.exists(joker.tempfile_path)
Example #3
    def __init__(self,
                 cuda=False,
                 exit_on_prompt=False,
                 language='en',
                 limiting_magnitude=None,
                 prefer_fluxes=False,
                 offline=False,
                 prefer_cache=False,
                 open_in_browser=False,
                 pool=None,
                 quiet=False,
                 test=False,
                 wrap_length=100,
                 **kwargs):
        """Initialize `Fitter` class."""
        self._pool = SerialPool() if pool is None else pool
        self._printer = Printer(pool=self._pool,
                                wrap_length=wrap_length,
                                quiet=quiet,
                                fitter=self,
                                language=language,
                                exit_on_prompt=exit_on_prompt)
        self._fetcher = Fetcher(test=test,
                                open_in_browser=open_in_browser,
                                printer=self._printer)

        self._cuda = cuda
        self._limiting_magnitude = limiting_magnitude
        self._prefer_fluxes = prefer_fluxes
        self._offline = offline
        self._prefer_cache = prefer_cache
        self._open_in_browser = open_in_browser
        self._quiet = quiet
        self._test = test
        self._wrap_length = wrap_length

        if self._cuda:
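            # Optional GPU setup: silently fall back to the CPU path when
            # pycuda/scikit-cuda are not installed.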
            try:
                import pycuda.autoinit  # noqa: F401
                import skcuda.linalg as linalg
                linalg.init()
            except ImportError:
                pass
Example #4
    def sample(self, n_samples):
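        # Normalize the pool argument: None or the GPU path -> SerialPool,
        # int -> a MultiPool with that many processes, pool instances pass
        # through, anything else is rejected.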
        if self.pool is None or _GPU_ENABLED:
            pool = SerialPool()
        elif isinstance(self.pool, int):
            pool = MultiPool(self.pool)
        elif isinstance(self.pool, (SerialPool, MultiPool)):
            pool = self.pool
        else:
            raise TypeError(
                "Unrecognized multiprocessing pool: expected None, an int, "
                "or a SerialPool/MultiPool instance.")

        drawn_samples = list(
            pool.map(self.draw_one_joint_posterior_sample_map,
                     range(n_samples)))
        # Close only a pool created in this call; a caller-supplied pool is
        # left open for reuse.
        if pool is not self.pool:
            pool.close()

        drawn_zs = [drawn_samples[i][0] for i in range(n_samples)]
        drawn_inference_posteriors = [
            drawn_samples[i][1] for i in range(n_samples)
        ]

        drawn_joint_posterior_samples = pd.DataFrame(
            drawn_inference_posteriors)
        drawn_joint_posterior_samples["redshift"] = drawn_zs

        return drawn_joint_posterior_samples
Example #5
def compute_mean_selection_function(selection_function, N_avg, pool=None):
    # Remember whether the pool is created here so that a caller-supplied
    # pool is left open for reuse.
    own_pool = not isinstance(pool, (SerialPool, MultiPool))
    if pool is None:
        pool = SerialPool()
    elif isinstance(pool, int):
        pool = MultiPool(pool)
    elif not isinstance(pool, (SerialPool, MultiPool)):
        raise TypeError(
            "Unrecognized multiprocessing pool: expected None, an int, or a "
            "SerialPool/MultiPool instance.")

    out = list(
        pool.starmap(selection_function.evaluate, [() for _ in range(N_avg)]))
    avg = np.average(out)
    if own_pool:
        pool.close()
    return avg
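

# The pool-normalization logic above recurs in Examples #4 and #5; below is a
# minimal sketch (not from the original sources) of a shared helper, assuming
# only schwimmbad's SerialPool and MultiPool.
from schwimmbad import MultiPool, SerialPool


def resolve_pool(pool):
    """Return (pool, own_pool); own_pool is True if the pool was created here."""
    if pool is None:
        return SerialPool(), True
    if isinstance(pool, int):
        return MultiPool(pool), True
    if isinstance(pool, (SerialPool, MultiPool)):
        return pool, False
    raise TypeError("Unrecognized multiprocessing pool: expected None, an "
                    "int, or a SerialPool/MultiPool instance.")


# Usage: pool, own_pool = resolve_pool(pool); if own_pool, call pool.close()
# when done.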
Example #6
class Fitter(object):
    """Fit transient events with the provided model."""

    _DEFAULT_SOURCE = {SOURCE.BIBCODE: '2017arXiv171002145G'}

    def __init__(self,
                 cuda=False,
                 exit_on_prompt=False,
                 language='en',
                 limiting_magnitude=None,
                 prefer_fluxes=False,
                 offline=False,
                 prefer_cache=False,
                 open_in_browser=False,
                 pool=None,
                 quiet=False,
                 test=False,
                 wrap_length=100,
                 **kwargs):
        """Initialize `Fitter` class."""
        self._pool = SerialPool() if pool is None else pool
        self._printer = Printer(pool=self._pool,
                                wrap_length=wrap_length,
                                quiet=quiet,
                                fitter=self,
                                language=language,
                                exit_on_prompt=exit_on_prompt)
        self._fetcher = Fetcher(test=test,
                                open_in_browser=open_in_browser,
                                printer=self._printer)

        self._cuda = cuda
        self._limiting_magnitude = limiting_magnitude
        self._prefer_fluxes = prefer_fluxes
        self._offline = offline
        self._prefer_cache = prefer_cache
        self._open_in_browser = open_in_browser
        self._quiet = quiet
        self._test = test
        self._wrap_length = wrap_length

        if self._cuda:
            try:
                import pycuda.autoinit  # noqa: F401
                import skcuda.linalg as linalg
                linalg.init()
            except ImportError:
                pass

    def fit_events(self,
                   events=[],
                   models=[],
                   max_time='',
                   time_list=[],
                   time_unit=None,
                   band_list=[],
                   band_systems=[],
                   band_instruments=[],
                   band_bandsets=[],
                   band_sampling_points=17,
                   iterations=10000,
                   num_walkers=None,
                   num_temps=1,
                   parameter_paths=['parameters.json'],
                   fracking=True,
                   frack_step=50,
                   burn=None,
                   post_burn=None,
                   gibbs=False,
                   smooth_times=-1,
                   extrapolate_time=0.0,
                   limit_fitting_mjds=False,
                   exclude_bands=[],
                   exclude_instruments=[],
                   exclude_systems=[],
                   exclude_sources=[],
                   exclude_kinds=[],
                   output_path='',
                   suffix='',
                   upload=False,
                   write=False,
                   upload_token='',
                   check_upload_quality=False,
                   variance_for_each=[],
                   user_fixed_parameters=[],
                   user_released_parameters=[],
                   convergence_type=None,
                   convergence_criteria=None,
                   save_full_chain=False,
                   draw_above_likelihood=False,
                   maximum_walltime=False,
                   start_time=False,
                   print_trees=False,
                   maximum_memory=np.inf,
                   speak=False,
                   return_fits=True,
                   extra_outputs=None,
                   walker_paths=[],
                   catalogs=[],
                   exit_on_prompt=False,
                   download_recommended_data=False,
                   local_data_only=False,
                   method=None,
                   seed=None,
                   **kwargs):
        """Fit a list of events with a list of models."""
        global model
        if start_time is False:
            start_time = time.time()

        self._seed = seed
        if seed is not None:
            np.random.seed(seed)

        self._start_time = start_time
        self._maximum_walltime = maximum_walltime
        self._maximum_memory = maximum_memory
        self._debug = False
        self._speak = speak
        self._download_recommended_data = download_recommended_data
        self._local_data_only = local_data_only

        self._draw_above_likelihood = draw_above_likelihood

        prt = self._printer

        event_list = listify(events)
        model_list = listify(models)

        if len(model_list) and not len(event_list):
            event_list = ['']

        # Exclude catalogs not included in catalog list.
        self._fetcher.add_excluded_catalogs(catalogs)

        if not len(event_list) and not len(model_list):
            prt.message('no_events_models', warning=True)

        # If the input is not a JSON file, assume it is either a list of
        # transients or that it is the data from a single transient in tabular
        # form. Try to guess the format first, and if that fails ask the user.
        self._converter = Converter(prt, require_source=upload)
        event_list = self._converter.generate_event_list(event_list)

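        # Normalize Unicode hyphen variants (e.g. the non-breaking hyphen) in
        # event names to plain ASCII hyphens.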
        event_list = [x.replace('‑', '-') for x in event_list]

        entries = [[] for x in range(len(event_list))]
        ps = [[] for x in range(len(event_list))]
        lnprobs = [[] for x in range(len(event_list))]

        # Load walker data if provided a list of walker paths.
        walker_data = []

        if len(walker_paths):
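            # schwimmbad's MPIPool raises ImportError without mpi4py and
            # ValueError when fewer than two MPI processes are available;
            # fall back to a serial pool in either case.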
            try:
                pool = MPIPool()
            except (ImportError, ValueError):
                pool = SerialPool()
            if pool.is_master():
                prt.message('walker_file')
                wfi = 0
                for walker_path in walker_paths:
                    if os.path.exists(walker_path):
                        prt.prt('  {}'.format(walker_path))
                        with codecs.open(walker_path, 'r',
                                         encoding='utf-8') as f:
                            all_walker_data = json.load(
                                f, object_pairs_hook=OrderedDict)

                        # Support both the format where all data is stored in a
                        # single-item dictionary (the OAC format) and the older
                        # MOSFiT format where the data was stored in the
                        # top-level dictionary.
                        if ENTRY.NAME not in all_walker_data:
                            all_walker_data = all_walker_data[list(
                                all_walker_data.keys())[0]]

                        models = all_walker_data.get(ENTRY.MODELS, [])
                        choice = None
                        if len(models) > 1:
                            model_opts = [
                                '{}-{}-{}'.format(x['code'], x['name'],
                                                  x['date']) for x in models
                            ]
                            choice = prt.prompt('select_model_walkers',
                                                kind='select',
                                                message=True,
                                                options=model_opts)
                            choice = model_opts.index(choice)
                        elif len(models) == 1:
                            choice = 0

                        if choice is not None:
                            walker_data.extend([[
                                wfi, x[REALIZATION.PARAMETERS],
                                x.get(REALIZATION.WEIGHT)
                            ] for x in models[choice][MODEL.REALIZATIONS]])

                        for i in range(len(walker_data)):
                            if walker_data[i][2] is not None:
                                walker_data[i][2] = float(walker_data[i][2])

                        if not len(walker_data):
                            prt.message('no_walker_data')
                    else:
                        prt.message('no_walker_data')
                        if self._offline:
                            prt.message('omit_offline')
                        raise RuntimeError
                    wfi = wfi + 1

                for rank in range(1, pool.size + 1):
                    pool.comm.send(walker_data, dest=rank, tag=3)
            else:
                walker_data = pool.comm.recv(source=0, tag=3)
                pool.wait()

            if pool.is_master():
                pool.close()

        self._event_name = 'Batch'
        self._event_path = ''
        self._event_data = {}

        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            fetched_events = self._fetcher.fetch(
                event_list,
                offline=self._offline,
                prefer_cache=self._prefer_cache)

            for rank in range(1, pool.size + 1):
                pool.comm.send(fetched_events, dest=rank, tag=0)
            pool.close()
        else:
            fetched_events = pool.comm.recv(source=0, tag=0)
            pool.wait()

        for ei, event in enumerate(fetched_events):
            if event is not None:
                self._event_name = event.get('name', 'Batch')
                self._event_path = event.get('path', '')
                if not self._event_path:
                    continue
                self._event_data = self._fetcher.load_data(event)
                if not self._event_data:
                    continue

            if model_list:
                lmodel_list = model_list
            else:
                lmodel_list = ['']

            entries[ei] = [None for y in range(len(lmodel_list))]
            ps[ei] = [None for y in range(len(lmodel_list))]
            lnprobs[ei] = [None for y in range(len(lmodel_list))]

            if (event is not None and
                (not self._event_data or ENTRY.PHOTOMETRY
                 not in self._event_data[list(self._event_data.keys())[0]])):
                prt.message('no_photometry', [self._event_name])
                continue

            for mi, mod_name in enumerate(lmodel_list):
                for parameter_path in parameter_paths:
                    try:
                        pool = MPIPool()
                    except (ImportError, ValueError):
                        pool = SerialPool()
                    self._model = Model(model=mod_name,
                                        data=self._event_data,
                                        parameter_path=parameter_path,
                                        output_path=output_path,
                                        wrap_length=self._wrap_length,
                                        test=self._test,
                                        printer=prt,
                                        fitter=self,
                                        pool=pool,
                                        print_trees=print_trees)

                    if not self._model._model_name:
                        prt.message('no_models_avail', [self._event_name],
                                    warning=True)
                        continue

                    if not event:
                        prt.message('gen_dummy')
                        self._event_name = mod_name
                        gen_args = {
                            'name': mod_name,
                            'max_time': max_time,
                            'time_list': time_list,
                            'band_list': band_list,
                            'band_systems': band_systems,
                            'band_instruments': band_instruments,
                            'band_bandsets': band_bandsets
                        }
                        self._event_data = self.generate_dummy_data(**gen_args)

                    success = False
                    alt_name = None
                    while not success:
                        self._model.reset_unset_recommended_keys()
                        success = self._model.load_data(
                            self._event_data,
                            event_name=self._event_name,
                            smooth_times=smooth_times,
                            extrapolate_time=extrapolate_time,
                            limit_fitting_mjds=limit_fitting_mjds,
                            exclude_bands=exclude_bands,
                            exclude_instruments=exclude_instruments,
                            exclude_systems=exclude_systems,
                            exclude_sources=exclude_sources,
                            exclude_kinds=exclude_kinds,
                            time_list=time_list,
                            time_unit=time_unit,
                            band_list=band_list,
                            band_systems=band_systems,
                            band_instruments=band_instruments,
                            band_bandsets=band_bandsets,
                            band_sampling_points=band_sampling_points,
                            variance_for_each=variance_for_each,
                            user_fixed_parameters=user_fixed_parameters,
                            user_released_parameters=user_released_parameters,
                            pool=pool)

                        if not success:
                            break

                        if self._local_data_only:
                            break

                        # If our data is missing recommended keys, offer the
                        # user option to pull the missing data from online and
                        # merge with existing data.
                        urk = self._model.get_unset_recommended_keys()
                        ptxt = prt.text('acquire_recommended',
                                        [', '.join(list(urk))])
                        while event and len(urk) and (
                                alt_name or self._download_recommended_data
                                or prt.prompt(ptxt, [', '.join(urk)],
                                              kind='bool')):
                            try:
                                pool = MPIPool()
                            except (ImportError, ValueError):
                                pool = SerialPool()
                            if pool.is_master():
                                en = (alt_name
                                      if alt_name else self._event_name)
                                extra_event = self._fetcher.fetch(
                                    en,
                                    offline=self._offline,
                                    prefer_cache=self._prefer_cache)[0]
                                extra_data = self._fetcher.load_data(
                                    extra_event)

                                for rank in range(1, pool.size + 1):
                                    pool.comm.send(extra_data,
                                                   dest=rank,
                                                   tag=4)
                                pool.close()
                            else:
                                extra_data = pool.comm.recv(source=0, tag=4)
                                pool.wait()

                            if extra_data is not None:
                                extra_data = extra_data[list(
                                    extra_data.keys())[0]]

                                for key in urk:
                                    new_val = extra_data.get(key)
                                    self._event_data[list(
                                        self._event_data.keys())
                                                     [0]][key] = new_val
                                    if new_val is not None and len(new_val):
                                        prt.message('extra_value', [
                                            key,
                                            str(new_val[0].get(QUANTITY.VALUE))
                                        ])
                                success = False
                                prt.message('reloading_merged')
                                break
                            else:
                                text = prt.text('extra_not_found',
                                                [self._event_name])
                                alt_name = prt.prompt(text, kind='string')
                                if not alt_name:
                                    break

                    if success:
                        self._walker_data = walker_data

                        entry, p, lnprob = self.fit_data(
                            event_name=self._event_name,
                            method=method,
                            iterations=iterations,
                            num_walkers=num_walkers,
                            num_temps=num_temps,
                            burn=burn,
                            post_burn=post_burn,
                            fracking=fracking,
                            frack_step=frack_step,
                            gibbs=gibbs,
                            pool=pool,
                            output_path=output_path,
                            suffix=suffix,
                            write=write,
                            upload=upload,
                            upload_token=upload_token,
                            check_upload_quality=check_upload_quality,
                            convergence_type=convergence_type,
                            convergence_criteria=convergence_criteria,
                            save_full_chain=save_full_chain,
                            extra_outputs=extra_outputs)
                        if return_fits:
                            entries[ei][mi] = deepcopy(entry)
                            ps[ei][mi] = deepcopy(p)
                            lnprobs[ei][mi] = deepcopy(lnprob)

                    if pool.is_master():
                        pool.close()

                    # Remove global model variable and garbage collect.
                    try:
                        model
                    except NameError:
                        pass
                    else:
                        del model
                    del self._model
                    gc.collect()

        return (entries, ps, lnprobs)

    def fit_data(self,
                 event_name='',
                 method=None,
                 iterations=None,
                 frack_step=20,
                 num_walkers=None,
                 num_temps=1,
                 burn=None,
                 post_burn=None,
                 fracking=True,
                 gibbs=False,
                 pool=None,
                 output_path='',
                 suffix='',
                 write=False,
                 upload=False,
                 upload_token='',
                 check_upload_quality=True,
                 convergence_type=None,
                 convergence_criteria=None,
                 save_full_chain=False,
                 extra_outputs=None):
        """Fit the data for a given event.

        Fitting performed using a combination of emcee and fracking.
        """
        if self._speak:
            speak('Fitting ' + event_name, self._speak)
        from mosfit.__init__ import __version__
        global model
        model = self._model
        prt = self._printer

        upload_model = upload and iterations > 0

        if pool is not None:
            self._pool = pool

        if upload:
            try:
                import dropbox
            except ImportError:
                if self._test:
                    pass
                else:
                    prt.message('install_db', error=True)
                    raise

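        # Non-master MPI processes block in wait() servicing tasks until the
        # master closes the pool, then return without producing a fit.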
        if not self._pool.is_master():
            try:
                self._pool.wait()
            except (KeyboardInterrupt, SystemExit):
                pass
            return (None, None, None)

        self._method = method

        if self._method == 'nester':
            self._sampler = Nester(self, model, iterations, burn, post_burn,
                                   num_walkers, convergence_criteria,
                                   convergence_type, gibbs, fracking,
                                   frack_step)
        else:
            self._sampler = Ensembler(self, model, iterations, burn, post_burn,
                                      num_temps, num_walkers,
                                      convergence_criteria, convergence_type,
                                      gibbs, fracking, frack_step)

        self._sampler.run(self._walker_data)

        prt.message('constructing')

        if write:
            if self._speak:
                speak(prt._strings['saving_output'], self._speak)

        if self._event_path:
            entry = Entry.init_from_file(catalog=None,
                                         name=self._event_name,
                                         path=self._event_path,
                                         merge=False,
                                         pop_schema=False,
                                         ignore_keys=[ENTRY.MODELS],
                                         compare_to_existing=False)
            new_photometry = []
            for photo in entry.get(ENTRY.PHOTOMETRY, []):
                if PHOTOMETRY.REALIZATION not in photo:
                    new_photometry.append(photo)
            if len(new_photometry):
                entry[ENTRY.PHOTOMETRY] = new_photometry
        else:
            entry = Entry(name=self._event_name)

        uentry = Entry(name=self._event_name)
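        # Hash the event's observational data (keyed on every data-task key);
        # the hash later forms part of the upload filename.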
        data_keys = set()
        for task in model._call_stack:
            if model._call_stack[task]['kind'] == 'data':
                data_keys.update(
                    list(model._call_stack[task].get('keys', {}).keys()))
        entryhash = entry.get_hash(keys=list(sorted(list(data_keys))))

        # Accumulate all the sources and add them to each entry.
        sources = []
        for root in model._references:
            for ref in model._references[root]:
                sources.append(entry.add_source(**ref))
        sources.append(entry.add_source(**self._DEFAULT_SOURCE))
        source = ','.join(sources)

        usources = []
        for root in model._references:
            for ref in model._references[root]:
                usources.append(uentry.add_source(**ref))
        usources.append(uentry.add_source(**self._DEFAULT_SOURCE))
        usource = ','.join(usources)

        model_setup = OrderedDict()
        for ti, task in enumerate(model._call_stack):
            task_copy = deepcopy(model._call_stack[task])
            if (task_copy['kind'] == 'parameter'
                    and task in model._parameter_json):
                task_copy.update(model._parameter_json[task])
            model_setup[task] = task_copy
        modeldict = OrderedDict([(MODEL.NAME, model._model_name),
                                 (MODEL.SETUP, model_setup),
                                 (MODEL.CODE, 'MOSFiT'),
                                 (MODEL.DATE, time.strftime("%Y/%m/%d")),
                                 (MODEL.VERSION, __version__),
                                 (MODEL.SOURCE, source)])

        self._sampler.prepare_output(check_upload_quality, upload)

        self._sampler.append_output(modeldict)

        umodeldict = deepcopy(modeldict)
        umodeldict[MODEL.SOURCE] = usource
        modelhash = get_model_hash(umodeldict,
                                   ignore_keys=[MODEL.DATE, MODEL.SOURCE])
        umodelnum = uentry.add_model(**umodeldict)

        if self._sampler._upload_model is not None:
            upload_model = self._sampler._upload_model

        modelnum = entry.add_model(**modeldict)

        samples, probs, weights = self._sampler.get_samples()

        extras = OrderedDict()
        samples_to_plot = self._sampler._nwalkers

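        # Nested-sampling output is weighted: draw an equally-weighted subset
        # by inverse-CDF sampling of the weights. Ensemble samples are kept
        # as-is.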
        if isinstance(self._sampler, Nester):
            icdf = np.cumsum(np.concatenate(([0.0], weights)))
            draws = np.random.rand(samples_to_plot)
            indices = np.searchsorted(icdf, draws) - 1
        else:
            indices = list(range(samples_to_plot))

        ri = 0
        selected_extra = False
        for xi, x in enumerate(samples):
            ri = ri + 1
            prt.message('outputting_walker', [ri, len(samples)],
                        inline=True,
                        min_time=0.2)
            if xi in indices:
                output = model.run_stack(x, root='output')
                if extra_outputs is not None:
                    if not extra_outputs and not selected_extra:
                        extra_options = list(output.keys())
                        prt.message('available_keys')
                        for opt in extra_options:
                            prt.prt('- {}'.format(opt))
                        selected_extra = True
                    for key in extra_outputs:
                        new_val = output.get(key, [])
                        new_val = all_to_list(new_val)
                        extras.setdefault(key, []).append(new_val)
                for i in range(len(output['times'])):
                    if not np.isfinite(output['model_observations'][i]):
                        continue
                    photodict = {
                        PHOTOMETRY.TIME:
                        output['times'][i] + output['min_times'],
                        PHOTOMETRY.MODEL: modelnum,
                        PHOTOMETRY.SOURCE: source,
                        PHOTOMETRY.REALIZATION: str(ri)
                    }
                    if output['observation_types'][i] == 'magnitude':
                        photodict[PHOTOMETRY.BAND] = output['bands'][i]
                        photodict[PHOTOMETRY.
                                  MAGNITUDE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.
                                  E_MAGNITUDE] = output['model_variances'][i]
                    elif output['observation_types'][i] == 'magcount':
                        if output['model_observations'][i] == 0.0:
                            continue
                        photodict[PHOTOMETRY.BAND] = output['bands'][i]
                        photodict[PHOTOMETRY.
                                  COUNT_RATE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.
                                  E_COUNT_RATE] = output['model_variances'][i]
                        photodict[PHOTOMETRY.MAGNITUDE] = -2.5 * np.log10(
                            output['model_observations']
                            [i]) + output['all_zeropoints'][i]
                        photodict[PHOTOMETRY.E_UPPER_MAGNITUDE] = 2.5 * (
                            np.log10(output['model_observations'][i] +
                                     output['model_variances'][i]) -
                            np.log10(output['model_observations'][i]))
                        if (output['model_variances'][i] >
                                output['model_observations'][i]):
                            photodict[PHOTOMETRY.UPPER_LIMIT] = True
                        else:
                            photodict[PHOTOMETRY.E_LOWER_MAGNITUDE] = 2.5 * (
                                np.log10(output['model_observations'][i]) -
                                np.log10(output['model_observations'][i] -
                                         output['model_variances'][i]))
                    elif output['observation_types'][i] == 'fluxdensity':
                        photodict[PHOTOMETRY.FREQUENCY] = output[
                            'frequencies'][i] * frequency_unit('GHz')
                        photodict[PHOTOMETRY.FLUX_DENSITY] = output[
                            'model_observations'][i] * flux_density_unit('µJy')
                        photodict[PHOTOMETRY.E_LOWER_FLUX_DENSITY] = (
                            photodict[PHOTOMETRY.FLUX_DENSITY] -
                            (10.0**
                             (np.log10(photodict[PHOTOMETRY.FLUX_DENSITY]) -
                              output['model_variances'][i] / 2.5)) *
                            flux_density_unit('µJy'))
                        photodict[PHOTOMETRY.E_UPPER_FLUX_DENSITY] = (
                            10.0**(np.log10(photodict[PHOTOMETRY.FLUX_DENSITY])
                                   + output['model_variances'][i] / 2.5) *
                            flux_density_unit('µJy') -
                            photodict[PHOTOMETRY.FLUX_DENSITY])
                        photodict[PHOTOMETRY.U_FREQUENCY] = 'GHz'
                        photodict[PHOTOMETRY.U_FLUX_DENSITY] = 'µJy'
                    elif output['observation_types'][i] == 'countrate':
                        photodict[PHOTOMETRY.
                                  COUNT_RATE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.E_LOWER_COUNT_RATE] = (
                            photodict[PHOTOMETRY.COUNT_RATE] -
                            (10.0**(np.log10(photodict[PHOTOMETRY.COUNT_RATE])
                                    - output['model_variances'][i] / 2.5)))
                        photodict[PHOTOMETRY.E_UPPER_COUNT_RATE] = (
                            10.0**(np.log10(photodict[PHOTOMETRY.COUNT_RATE]) +
                                   output['model_variances'][i] / 2.5) -
                            photodict[PHOTOMETRY.COUNT_RATE])
                        photodict[PHOTOMETRY.U_COUNT_RATE] = 's^-1'
                    if ('model_upper_limits' in output
                            and output['model_upper_limits'][i]):
                        photodict[PHOTOMETRY.UPPER_LIMIT] = bool(
                            output['model_upper_limits'][i])
                    if self._limiting_magnitude is not None:
                        photodict[PHOTOMETRY.SIMULATED] = True
                    if 'telescopes' in output and output['telescopes'][i]:
                        photodict[
                            PHOTOMETRY.TELESCOPE] = output['telescopes'][i]
                    if 'systems' in output and output['systems'][i]:
                        photodict[PHOTOMETRY.SYSTEM] = output['systems'][i]
                    if 'bandsets' in output and output['bandsets'][i]:
                        photodict[PHOTOMETRY.BAND_SET] = output['bandsets'][i]
                    if 'instruments' in output and output['instruments'][i]:
                        photodict[
                            PHOTOMETRY.INSTRUMENT] = output['instruments'][i]
                    if 'modes' in output and output['modes'][i]:
                        photodict[PHOTOMETRY.MODE] = output['modes'][i]
                    entry.add_photometry(compare_to_existing=False,
                                         check_for_dupes=False,
                                         **photodict)

                    uphotodict = deepcopy(photodict)
                    uphotodict[PHOTOMETRY.SOURCE] = umodelnum
                    uentry.add_photometry(compare_to_existing=False,
                                          check_for_dupes=False,
                                          **uphotodict)
            else:
                output = model.run_stack(x, root='objective')

            parameters = OrderedDict()
            derived_keys = set()
            pi = 0
            for ti, task in enumerate(model._call_stack):
                # if task not in model._free_parameters:
                #     continue
                if model._call_stack[task]['kind'] != 'parameter':
                    continue
                paramdict = OrderedDict(
                    (('latex', model._modules[task].latex()),
                     ('log', model._modules[task].is_log())))
                if task in model._free_parameters:
                    poutput = model._modules[task].process(
                        **{'fraction': x[pi]})
                    value = list(poutput.values())[0]
                    paramdict['value'] = value
                    paramdict['fraction'] = x[pi]
                    pi = pi + 1
                else:
                    if output.get(task, None) is not None:
                        paramdict['value'] = output[task]
                parameters.update({model._modules[task].name(): paramdict})
                # Dump out any derived parameter keys
                derived_keys.update(model._modules[task].get_derived_keys())

            for key in list(sorted(list(derived_keys))):
                if (output.get(key, None) is not None
                        and key not in parameters):
                    parameters.update({key: {'value': output[key]}})

            realdict = {REALIZATION.PARAMETERS: parameters}
            if probs is not None:
                realdict[REALIZATION.SCORE] = str(probs[xi])
            else:
                realdict[REALIZATION.SCORE] = str(
                    ln_likelihood(x) + ln_prior(x))
            realdict[REALIZATION.ALIAS] = str(ri)
            realdict[REALIZATION.WEIGHT] = str(weights[xi])
            entry[ENTRY.MODELS][0].add_realization(check_for_dupes=False,
                                                   **realdict)
            urealdict = deepcopy(realdict)
            uentry[ENTRY.MODELS][0].add_realization(check_for_dupes=False,
                                                    **urealdict)
        prt.message('all_walkers_written', inline=True)

        entry.sanitize()
        oentry = {self._event_name: entry._ordered(entry)}
        uentry.sanitize()
        ouentry = {self._event_name: uentry._ordered(uentry)}

        uname = '_'.join([self._event_name, entryhash, modelhash])

        if output_path and not os.path.exists(output_path):
            os.makedirs(output_path)

        if not os.path.exists(model.get_products_path()):
            os.makedirs(model.get_products_path())

        if write:
            prt.message('writing_complete')
            with open_atomic(
                    os.path.join(model.get_products_path(), 'walkers.json'),
                    'w') as flast, open_atomic(
                        os.path.join(
                            model.get_products_path(), self._event_name +
                            (('_' + suffix) if suffix else '') + '.json'),
                        'w') as feven:
                entabbed_json_dump(oentry, flast, separators=(',', ':'))
                entabbed_json_dump(oentry, feven, separators=(',', ':'))

            if save_full_chain:
                prt.message('writing_full_chain')
                with open_atomic(
                        os.path.join(model.get_products_path(), 'chain.json'),
                        'w') as flast, open_atomic(
                            os.path.join(
                                model.get_products_path(),
                                self._event_name + '_chain' +
                                (('_' + suffix) if suffix else '') + '.json'),
                            'w') as feven:
                    entabbed_json_dump(self._sampler._all_chain.tolist(),
                                       flast,
                                       separators=(',', ':'))
                    entabbed_json_dump(self._sampler._all_chain.tolist(),
                                       feven,
                                       separators=(',', ':'))

            if extra_outputs is not None:
                prt.message('writing_extras')
                with open_atomic(
                        os.path.join(model.get_products_path(), 'extras.json'),
                        'w') as flast, open_atomic(
                            os.path.join(
                                model.get_products_path(),
                                self._event_name + '_extras' +
                                (('_' + suffix) if suffix else '') + '.json'),
                            'w') as feven:
                    entabbed_json_dump(extras, flast, separators=(',', ':'))
                    entabbed_json_dump(extras, feven, separators=(',', ':'))

            prt.message('writing_model')
            with open_atomic(
                    os.path.join(model.get_products_path(), 'upload.json'),
                    'w') as flast, open_atomic(
                        os.path.join(
                            model.get_products_path(), uname +
                            (('_' + suffix) if suffix else '') + '.json'),
                        'w') as feven:
                entabbed_json_dump(ouentry, flast, separators=(',', ':'))
                entabbed_json_dump(ouentry, feven, separators=(',', ':'))

        if upload_model:
            prt.message('ul_fit', [entryhash, self._sampler._modelhash])
            upayload = entabbed_json_dumps(ouentry, separators=(',', ':'))
            try:
                dbx = dropbox.Dropbox(upload_token)
                dbx.files_upload(upayload.encode(),
                                 '/' + uname + '.json',
                                 mode=dropbox.files.WriteMode.overwrite)
                prt.message('ul_complete')
            except Exception:
                if self._test:
                    pass
                else:
                    raise

        if upload:
            for ce in self._converter.get_converted():
                dentry = Entry.init_from_file(catalog=None,
                                              name=ce[0],
                                              path=ce[1],
                                              merge=False,
                                              pop_schema=False,
                                              ignore_keys=[ENTRY.MODELS],
                                              compare_to_existing=False)

                dentry.sanitize()
                odentry = {ce[0]: uentry._ordered(dentry)}
                dpayload = entabbed_json_dumps(odentry, separators=(',', ':'))
                text = prt.message('ul_devent', [ce[0]], prt=False)
                ul_devent = prt.prompt(text, kind='bool', message=False)
                if ul_devent:
                    dpath = '/' + slugify(
                        ce[0] + '_' + dentry[ENTRY.SOURCES][0].get(
                            SOURCE.BIBCODE, dentry[ENTRY.SOURCES][0].get(
                                SOURCE.NAME, 'NOSOURCE'))) + '.json'
                    try:
                        dbx = dropbox.Dropbox(upload_token)
                        dbx.files_upload(
                            dpayload.encode(),
                            dpath,
                            mode=dropbox.files.WriteMode.overwrite)
                        prt.message('ul_complete')
                    except Exception:
                        if self._test:
                            pass
                        else:
                            raise

        return (entry, samples, probs)

    def nester(self):
        """Use nested sampling to determine posteriors."""
        pass

    def generate_dummy_data(self,
                            name,
                            max_time=1000.,
                            time_list=[],
                            band_list=[],
                            band_systems=[],
                            band_instruments=[],
                            band_bandsets=[]):
        """Generate simulated data based on priors."""
        # Just need 2 plot points for beginning and end.
        plot_points = 2

        times = list(
            sorted(
                set(list(np.linspace(0.0, max_time, plot_points)) +
                    time_list)))
        band_list_all = ['V'] if len(band_list) == 0 else band_list
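        # Repeat every epoch once per band so each band is simulated at each
        # time.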
        times = np.repeat(times, len(band_list_all))

        # Create lists of systems/instruments if not provided.
        if isinstance(band_systems, string_types):
            band_systems = [band_systems for x in range(len(band_list_all))]
        if isinstance(band_instruments, string_types):
            band_instruments = [
                band_instruments for x in range(len(band_list_all))
            ]
        if isinstance(band_bandsets, string_types):
            band_bandsets = [band_bandsets for x in range(len(band_list_all))]
        if len(band_systems) < len(band_list_all):
            rep_val = '' if len(band_systems) == 0 else band_systems[-1]
            band_systems = band_systems + [
                rep_val for x in range(len(band_list_all) - len(band_systems))
            ]
        if len(band_instruments) < len(band_list_all):
            rep_val = '' if len(
                band_instruments) == 0 else band_instruments[-1]
            band_instruments = band_instruments + [
                rep_val
                for x in range(len(band_list_all) - len(band_instruments))
            ]
        if len(band_bandsets) < len(band_list_all):
            rep_val = '' if len(band_bandsets) == 0 else band_bandsets[-1]
            band_bandsets = band_bandsets + [
                rep_val
                for x in range(len(band_list_all) - len(band_bandsets))
            ]

        bands = [i for s in [band_list_all for x in times] for i in s]
        systs = [i for s in [band_systems for x in times] for i in s]
        insts = [i for s in [band_instruments for x in times] for i in s]
        bsets = [i for s in [band_bandsets for x in times] for i in s]

        data = {name: {'photometry': []}}
        for ti, tim in enumerate(times):
            band = bands[ti]
            if isinstance(band, dict):
                band = band['name']

            photodict = {
                'time': tim,
                'band': band,
                'magnitude': 0.0,
                'e_magnitude': 0.0
            }
            if systs[ti]:
                photodict['system'] = systs[ti]
            if insts[ti]:
                photodict['instrument'] = insts[ti]
            if bsets[ti]:
                photodict['bandset'] = bsets[ti]
            data[name]['photometry'].append(photodict)

        return data
Example #7
0
    def fit_events(self,
                   events=[],
                   models=[],
                   max_time='',
                   time_list=[],
                   time_unit=None,
                   band_list=[],
                   band_systems=[],
                   band_instruments=[],
                   band_bandsets=[],
                   band_sampling_points=17,
                   iterations=10000,
                   num_walkers=None,
                   num_temps=1,
                   parameter_paths=['parameters.json'],
                   fracking=True,
                   frack_step=50,
                   burn=None,
                   post_burn=None,
                   gibbs=False,
                   smooth_times=-1,
                   extrapolate_time=0.0,
                   limit_fitting_mjds=False,
                   exclude_bands=[],
                   exclude_instruments=[],
                   exclude_systems=[],
                   exclude_sources=[],
                   exclude_kinds=[],
                   output_path='',
                   suffix='',
                   upload=False,
                   write=False,
                   upload_token='',
                   check_upload_quality=False,
                   variance_for_each=[],
                   user_fixed_parameters=[],
                   user_released_parameters=[],
                   convergence_type=None,
                   convergence_criteria=None,
                   save_full_chain=False,
                   draw_above_likelihood=False,
                   maximum_walltime=False,
                   start_time=False,
                   print_trees=False,
                   maximum_memory=np.inf,
                   speak=False,
                   return_fits=True,
                   extra_outputs=None,
                   walker_paths=[],
                   catalogs=[],
                   exit_on_prompt=False,
                   download_recommended_data=False,
                   local_data_only=False,
                   method=None,
                   seed=None,
                   **kwargs):
        """Fit a list of events with a list of models."""
        global model
        if start_time is False:
            start_time = time.time()

        self._seed = seed
        if seed is not None:
            np.random.seed(seed)

        self._start_time = start_time
        self._maximum_walltime = maximum_walltime
        self._maximum_memory = maximum_memory
        self._debug = False
        self._speak = speak
        self._download_recommended_data = download_recommended_data
        self._local_data_only = local_data_only

        self._draw_above_likelihood = draw_above_likelihood

        prt = self._printer

        event_list = listify(events)
        model_list = listify(models)

        if len(model_list) and not len(event_list):
            event_list = ['']

        # Exclude catalogs not included in catalog list.
        self._fetcher.add_excluded_catalogs(catalogs)

        if not len(event_list) and not len(model_list):
            prt.message('no_events_models', warning=True)

        # If the input is not a JSON file, assume it is either a list of
        # transients or that it is the data from a single transient in tabular
        # form. Try to guess the format first, and if that fails ask the user.
        self._converter = Converter(prt, require_source=upload)
        event_list = self._converter.generate_event_list(event_list)

        event_list = [x.replace('‑', '-') for x in event_list]

        entries = [[] for x in range(len(event_list))]
        ps = [[] for x in range(len(event_list))]
        lnprobs = [[] for x in range(len(event_list))]

        # Load walker data if provided a list of walker paths.
        walker_data = []

        if len(walker_paths):
            try:
                pool = MPIPool()
            except (ImportError, ValueError):
                pool = SerialPool()
            if pool.is_master():
                prt.message('walker_file')
                wfi = 0
                for walker_path in walker_paths:
                    if os.path.exists(walker_path):
                        prt.prt('  {}'.format(walker_path))
                        with codecs.open(walker_path, 'r',
                                         encoding='utf-8') as f:
                            all_walker_data = json.load(
                                f, object_pairs_hook=OrderedDict)

                        # Support both the format where all data stored in a
                        # single-item dictionary (the OAC format) and the older
                        # MOSFiT format where the data was stored in the
                        # top-level dictionary.
                        if ENTRY.NAME not in all_walker_data:
                            all_walker_data = all_walker_data[list(
                                all_walker_data.keys())[0]]

                        models = all_walker_data.get(ENTRY.MODELS, [])
                        choice = None
                        if len(models) > 1:
                            model_opts = [
                                '{}-{}-{}'.format(x['code'], x['name'],
                                                  x['date']) for x in models
                            ]
                            choice = prt.prompt('select_model_walkers',
                                                kind='select',
                                                message=True,
                                                options=model_opts)
                            choice = model_opts.index(choice)
                        elif len(models) == 1:
                            choice = 0

                        if choice is not None:
                            walker_data.extend([[
                                wfi, x[REALIZATION.PARAMETERS],
                                x.get(REALIZATION.WEIGHT)
                            ] for x in models[choice][MODEL.REALIZATIONS]])

                        for i in range(len(walker_data)):
                            if walker_data[i][2] is not None:
                                walker_data[i][2] = float(walker_data[i][2])

                        if not len(walker_data):
                            prt.message('no_walker_data')
                    else:
                        prt.message('no_walker_data')
                        if self._offline:
                            prt.message('omit_offline')
                        raise RuntimeError
                    wfi = wfi + 1

                for rank in range(1, pool.size + 1):
                    pool.comm.send(walker_data, dest=rank, tag=3)
            else:
                walker_data = pool.comm.recv(source=0, tag=3)
                pool.wait()

            if pool.is_master():
                pool.close()

        self._event_name = 'Batch'
        self._event_path = ''
        self._event_data = {}

        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            fetched_events = self._fetcher.fetch(
                event_list,
                offline=self._offline,
                prefer_cache=self._prefer_cache)

            for rank in range(1, pool.size + 1):
                pool.comm.send(fetched_events, dest=rank, tag=0)
            pool.close()
        else:
            fetched_events = pool.comm.recv(source=0, tag=0)
            pool.wait()

        for ei, event in enumerate(fetched_events):
            if event is not None:
                self._event_name = event.get('name', 'Batch')
                self._event_path = event.get('path', '')
                if not self._event_path:
                    continue
                self._event_data = self._fetcher.load_data(event)
                if not self._event_data:
                    continue

            if model_list:
                lmodel_list = model_list
            else:
                lmodel_list = ['']

            entries[ei] = [None for y in range(len(lmodel_list))]
            ps[ei] = [None for y in range(len(lmodel_list))]
            lnprobs[ei] = [None for y in range(len(lmodel_list))]

            if (event is not None and
                (not self._event_data or ENTRY.PHOTOMETRY
                 not in self._event_data[list(self._event_data.keys())[0]])):
                prt.message('no_photometry', [self._event_name])
                continue

            for mi, mod_name in enumerate(lmodel_list):
                for parameter_path in parameter_paths:
                    try:
                        pool = MPIPool()
                    except (ImportError, ValueError):
                        pool = SerialPool()
                    self._model = Model(model=mod_name,
                                        data=self._event_data,
                                        parameter_path=parameter_path,
                                        output_path=output_path,
                                        wrap_length=self._wrap_length,
                                        test=self._test,
                                        printer=prt,
                                        fitter=self,
                                        pool=pool,
                                        print_trees=print_trees)

                    if not self._model._model_name:
                        prt.message('no_models_avail', [self._event_name],
                                    warning=True)
                        continue

                    if not event:
                        prt.message('gen_dummy')
                        self._event_name = mod_name
                        gen_args = {
                            'name': mod_name,
                            'max_time': max_time,
                            'time_list': time_list,
                            'band_list': band_list,
                            'band_systems': band_systems,
                            'band_instruments': band_instruments,
                            'band_bandsets': band_bandsets
                        }
                        self._event_data = self.generate_dummy_data(**gen_args)

                    success = False
                    alt_name = None
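                    # Retry until the event loads cleanly: merging newly
                    # fetched recommended data below resets `success`, so
                    # the merged photometry is reloaded from scratch.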
                    while not success:
                        self._model.reset_unset_recommended_keys()
                        success = self._model.load_data(
                            self._event_data,
                            event_name=self._event_name,
                            smooth_times=smooth_times,
                            extrapolate_time=extrapolate_time,
                            limit_fitting_mjds=limit_fitting_mjds,
                            exclude_bands=exclude_bands,
                            exclude_instruments=exclude_instruments,
                            exclude_systems=exclude_systems,
                            exclude_sources=exclude_sources,
                            exclude_kinds=exclude_kinds,
                            time_list=time_list,
                            time_unit=time_unit,
                            band_list=band_list,
                            band_systems=band_systems,
                            band_instruments=band_instruments,
                            band_bandsets=band_bandsets,
                            band_sampling_points=band_sampling_points,
                            variance_for_each=variance_for_each,
                            user_fixed_parameters=user_fixed_parameters,
                            user_released_parameters=user_released_parameters,
                            pool=pool)

                        if not success:
                            break

                        if self._local_data_only:
                            break

                        # If our data is missing recommended keys, offer the
                        # user option to pull the missing data from online and
                        # merge with existing data.
                        urk = self._model.get_unset_recommended_keys()
                        ptxt = prt.text('acquire_recommended',
                                        [', '.join(list(urk))])
                        while event and len(urk) and (
                                alt_name or self._download_recommended_data
                                or prt.prompt(ptxt, [', '.join(urk)],
                                              kind='bool')):
                            try:
                                pool = MPIPool()
                            except (ImportError, ValueError):
                                pool = SerialPool()
                            if pool.is_master():
                                en = (alt_name
                                      if alt_name else self._event_name)
                                extra_event = self._fetcher.fetch(
                                    en,
                                    offline=self._offline,
                                    prefer_cache=self._prefer_cache)[0]
                                extra_data = self._fetcher.load_data(
                                    extra_event)

                                for rank in range(1, pool.size + 1):
                                    pool.comm.send(extra_data,
                                                   dest=rank,
                                                   tag=4)
                                pool.close()
                            else:
                                extra_data = pool.comm.recv(source=0, tag=4)
                                pool.wait()

                            if extra_data is not None:
                                extra_data = extra_data[list(
                                    extra_data.keys())[0]]

                                for key in urk:
                                    new_val = extra_data.get(key)
                                    self._event_data[list(
                                        self._event_data.keys())
                                                     [0]][key] = new_val
                                    if new_val is not None and len(new_val):
                                        prt.message('extra_value', [
                                            key,
                                            str(new_val[0].get(QUANTITY.VALUE))
                                        ])
                                success = False
                                prt.message('reloading_merged')
                                break
                            else:
                                text = prt.text('extra_not_found',
                                                [self._event_name])
                                alt_name = prt.prompt(text, kind='string')
                                if not alt_name:
                                    break

                    if success:
                        self._walker_data = walker_data

                        entry, p, lnprob = self.fit_data(
                            event_name=self._event_name,
                            method=method,
                            iterations=iterations,
                            num_walkers=num_walkers,
                            num_temps=num_temps,
                            burn=burn,
                            post_burn=post_burn,
                            fracking=fracking,
                            frack_step=frack_step,
                            gibbs=gibbs,
                            pool=pool,
                            output_path=output_path,
                            suffix=suffix,
                            write=write,
                            upload=upload,
                            upload_token=upload_token,
                            check_upload_quality=check_upload_quality,
                            convergence_type=convergence_type,
                            convergence_criteria=convergence_criteria,
                            save_full_chain=save_full_chain,
                            extra_outputs=extra_outputs)
                        if return_fits:
                            entries[ei][mi] = deepcopy(entry)
                            ps[ei][mi] = deepcopy(p)
                            lnprobs[ei][mi] = deepcopy(lnprob)

                    if pool.is_master():
                        pool.close()

                    # Remove global model variable and garbage collect.
                    try:
                        model
                    except NameError:
                        pass
                    else:
                        del model
                    del self._model
                    gc.collect()

        return (entries, ps, lnprobs)
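
The master/worker exchanges in this example all follow one idiom: the
master fetches or assembles data, hands a copy to every worker rank over
the raw MPI communicator, and each worker blocks on a matching recv before
parking in `pool.wait()`. Below is a minimal, self-contained sketch of
that idiom; it assumes schwimmbad and mpi4py are installed and is meant
to illustrate the pattern, not to reproduce mosfit's internals.

import sys

from schwimmbad import MPIPool, SerialPool

try:
    pool = MPIPool()
except (ImportError, ValueError):
    pool = SerialPool()  # fall back when MPI is unavailable

if pool.is_master():
    payload = {'walkers': [0.1, 0.2, 0.3]}  # data only rank 0 holds
    if getattr(pool, 'comm', None) is not None:
        # Hand every worker rank its own copy; the tag pairs this send
        # with the recv in the worker branch.
        for rank in range(1, pool.size + 1):
            pool.comm.send(payload, dest=rank, tag=3)
    pool.close()
else:
    payload = pool.comm.recv(source=0, tag=3)
    pool.wait()  # park until the master closes the pool
    sys.exit(0)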
Example #9
class Model(object):
    """Define a semi-analytical model to fit transients with."""

    MODEL_PRODUCTS_DIR = 'products'
    MIN_WAVE_FRAC_DIFF = 0.1
    DRAW_LIMIT = 10

    # class outClass(object):
    #     pass

    def __init__(self,
                 parameter_path='parameters.json',
                 model='',
                 data={},
                 wrap_length=100,
                 output_path='',
                 pool=None,
                 test=False,
                 printer=None,
                 fitter=None,
                 print_trees=False):
        """Initialize `Model` object."""
        from mosfit.fitter import Fitter

        self._model_name = model
        self._parameter_path = parameter_path
        self._output_path = output_path
        self._pool = SerialPool() if pool is None else pool
        self._is_master = pool.is_master() if pool else False
        self._wrap_length = wrap_length
        self._print_trees = print_trees
        self._inflect = inflect.engine()
        self._test = test
        self._inflections = {}
        self._references = OrderedDict()
        self._free_parameters = []
        self._user_fixed_parameters = []
        self._user_released_parameters = []
        self._kinds_needed = set()
        self._kinds_supported = set()

        self._draw_limit_reached = False

        self._fitter = Fitter() if not fitter else fitter
        self._printer = self._fitter._printer if not printer else printer

        prt = self._printer

        self._dir_path = os.path.dirname(os.path.realpath(__file__))

        # Load suggested model associations for transient types.
        if os.path.isfile(os.path.join('models', 'types.json')):
            types_path = os.path.join('models', 'types.json')
        else:
            types_path = os.path.join(self._dir_path, 'models', 'types.json')
        with open(types_path, 'r') as f:
            model_types = json.load(f, object_pairs_hook=OrderedDict)

        # Create list of all available models.
        all_models = set()
        if os.path.isdir('models'):
            all_models |= set(next(os.walk('models'))[1])
        models_path = os.path.join(self._dir_path, 'models')
        if os.path.isdir(models_path):
            all_models |= set(next(os.walk(models_path))[1])
        all_models = sorted(all_models)

        if not self._model_name:
            claimed_type = None
            try:
                claimed_type = list(
                    data.values())[0]['claimedtype'][0][QUANTITY.VALUE]
            except Exception:
                prt.message('no_model_type', warning=True)

            all_models_txt = prt.text('all_models')
            suggested_models_txt = prt.text('suggested_models', [claimed_type])
            another_model_txt = prt.text('another_model')

            type_options = model_types.get(claimed_type,
                                           []) if claimed_type else []
            if not type_options:
                type_options = all_models
                model_prompt_txt = all_models_txt
            else:
                type_options.append(another_model_txt)
                model_prompt_txt = suggested_models_txt
            if not type_options:
                prt.message('no_model_for_type', warning=True)
            else:
                while not self._model_name:
                    if self._test:
                        self._model_name = type_options[0]
                    else:
                        sel = self._printer.prompt(
                            model_prompt_txt,
                            kind='option',
                            options=type_options,
                            message=False,
                            default='n',
                            none_string=prt.text('none_above_models'))
                        if sel is not None:
                            self._model_name = type_options[int(sel) - 1]
                    if not self._model_name:
                        break
                    if self._model_name == another_model_txt:
                        type_options = all_models
                        model_prompt_txt = all_models_txt
                        self._model_name = None

        if not self._model_name:
            return

        # Load the basic model file.
        if os.path.isfile(os.path.join('models', 'model.json')):
            basic_model_path = os.path.join('models', 'model.json')
        else:
            basic_model_path = os.path.join(self._dir_path, 'models',
                                            'model.json')

        with open(basic_model_path, 'r') as f:
            self._model = json.load(f, object_pairs_hook=OrderedDict)

        # Load the model file.
        model = self._model_name
        model_dir = self._model_name

        if '.json' in self._model_name:
            model_dir = self._model_name.split('.json')[0]
        else:
            model = self._model_name + '.json'

        if os.path.isfile(model):
            model_path = model
        else:
            # Look in local hierarchy first
            if os.path.isfile(os.path.join('models', model_dir, model)):
                model_path = os.path.join('models', model_dir, model)
            else:
                model_path = os.path.join(self._dir_path, 'models', model_dir,
                                          model)

        with open(model_path, 'r') as f:
            self._model.update(json.load(f, object_pairs_hook=OrderedDict))

        # Find @ tags, store them, and prune them from `_model`.
        for tag in list(self._model.keys()):
            if tag.startswith('@'):
                if tag == '@references':
                    self._references.setdefault('base',
                                                []).extend(self._model[tag])
                del self._model[tag]

        # with open(os.path.join(
        #         self.get_products_path(),
        #         self._model_name + '.json'), 'w') as f:
        #     json.dump(self._model, f)

        # Load model parameter file.
        model_pp = os.path.join(self._dir_path, 'models', model_dir,
                                'parameters.json')

        pp = ''

        local_pp = (self._parameter_path if '/' in self._parameter_path else
                    os.path.join('models', model_dir, self._parameter_path))

        if os.path.isfile(local_pp):
            selected_pp = local_pp
        else:
            selected_pp = os.path.join(self._dir_path, 'models', model_dir,
                                       self._parameter_path)

        # First try user-specified path
        if self._parameter_path and os.path.isfile(self._parameter_path):
            pp = self._parameter_path
        # Then try directory we are running from
        elif os.path.isfile('parameters.json'):
            pp = 'parameters.json'
        # Then try the model directory, with the user-specified name
        elif os.path.isfile(selected_pp):
            pp = selected_pp
        # Finally try model folder
        elif os.path.isfile(model_pp):
            pp = model_pp
        else:
            raise ValueError(prt.text('no_parameter_file'))

        if self._is_master:
            prt.message('files', [basic_model_path, model_path, pp],
                        wrapped=False)

        with open(pp, 'r') as f:
            self._parameter_json = json.load(f, object_pairs_hook=OrderedDict)
        self._modules = OrderedDict()
        self._bands = []
        self._instruments = []
        self._telescopes = []

        # Load the call tree for the model. Work our way in reverse from the
        # observables, first constructing a tree for each observable and then
        # combining trees.
        root_kinds = ['output', 'objective']

        self._trees = OrderedDict()
        self._simple_trees = OrderedDict()
        self.construct_trees(self._model,
                             self._trees,
                             self._simple_trees,
                             kinds=root_kinds)

        if self._print_trees:
            self._printer.prt('Dependency trees:\n', wrapped=True)
            self._printer.tree(self._simple_trees)

        unsorted_call_stack = OrderedDict()
        self._max_depth_all = -1
        for tag in self._model:
            model_tag = self._model[tag]
            roots = []
            if model_tag['kind'] in root_kinds:
                max_depth = 0
                roots = [model_tag['kind']]
            else:
                max_depth = -1
                for tag2 in self._trees:
                    if self.in_tree(tag, self._trees[tag2]):
                        roots.extend(self._trees[tag2]['roots'])
                    depth = self.get_max_depth(tag, self._trees[tag2],
                                               max_depth)
                    if depth > max_depth:
                        max_depth = depth
                    if depth > self._max_depth_all:
                        self._max_depth_all = depth
            roots = list(sorted(set(roots)))
            new_entry = deepcopy(model_tag)
            new_entry['roots'] = roots
            if 'children' in new_entry:
                del new_entry['children']
            new_entry['depth'] = max_depth
            unsorted_call_stack[tag] = new_entry
        # print(unsorted_call_stack)

        # Currently just have one call stack for all products, can be wasteful
        # if only using some products.
        self._call_stack = OrderedDict()
        for depth in range(self._max_depth_all, -1, -1):
            for task in unsorted_call_stack:
                if unsorted_call_stack[task]['depth'] == depth:
                    self._call_stack[task] = unsorted_call_stack[task]

        # with open(os.path.join(
        #         self.get_products_path(),
        #         self._model_name + '-stack.json'), 'w') as f:
        #     json.dump(self._call_stack, f)

        for task in self._call_stack:
            cur_task = self._call_stack[task]
            mod_name = cur_task.get('class', task)
            if cur_task['kind'] == 'parameter' and task in self._parameter_json:
                cur_task.update(self._parameter_json[task])
            self._modules[task] = self._load_task_module(task)
            if mod_name == 'photometry':
                self._telescopes = self._modules[task].telescopes()
                self._instruments = self._modules[task].instruments()
                self._bands = self._modules[task].bands()
            self._modules[task].set_attributes(cur_task)

        # Look forward to see which modules want dense arrays.
        for task in self._call_stack:
            for ftask in self._call_stack:
                if (task != ftask and self._call_stack[ftask]['depth'] <
                        self._call_stack[task]['depth']
                        and self._modules[ftask]._wants_dense):
                    self._modules[ftask]._provide_dense = True

        # Count free parameters.
        self.determine_free_parameters()

    def get_products_path(self):
        """Get path to products."""
        return os.path.join(self._output_path, self.MODEL_PRODUCTS_DIR)

    def _load_task_module(self, task, call_stack=None):
        if not call_stack:
            call_stack = self._call_stack
        cur_task = call_stack[task]
        kinds = self._inflect.plural(cur_task['kind'])
        mod_name = cur_task.get('class', task).lower()
        mod_path = os.path.join('modules', kinds, mod_name + '.py')
        if not os.path.isfile(mod_path):
            mod_path = os.path.join(self._dir_path, 'modules', kinds,
                                    mod_name + '.py')
        mod_name = 'mosfit.modules.' + kinds + mod_name
        try:
            mod = importlib.machinery.SourceFileLoader(mod_name,
                                                       mod_path).load_module()
        except AttributeError:
            import imp
            mod = imp.load_source(mod_name, mod_path)

        class_name = [
            x[0] for x in inspect.getmembers(mod, inspect.isclass)
            if issubclass(x[1], Module) and x[1].__module__ == mod.__name__
        ][0]
        mod_class = getattr(mod, class_name)
        return mod_class(name=task,
                         model=self,
                         fitter=self._fitter,
                         **cur_task)

    def load_data(self,
                  data,
                  event_name='',
                  smooth_times=-1,
                  extrapolate_time=0.0,
                  limit_fitting_mjds=False,
                  exclude_bands=[],
                  exclude_instruments=[],
                  exclude_systems=[],
                  exclude_sources=[],
                  exclude_kinds=[],
                  time_unit=None,
                  time_list=[],
                  band_list=[],
                  band_systems=[],
                  band_instruments=[],
                  band_bandsets=[],
                  band_sampling_points=25,
                  variance_for_each=[],
                  user_fixed_parameters=[],
                  user_released_parameters=[],
                  pool=None):
        """Load the data for the specified event."""
        if pool is not None:
            self._pool = pool
            self._printer._pool = pool

        prt = self._printer

        prt.message('loading_data', inline=True)

        # Fix user-specified parameters.
        fixed_parameters = []
        released_parameters = []
        for task in self._call_stack:
            for fi, param in enumerate(user_fixed_parameters):
                if (task == param
                        or self._call_stack[task].get('class', '') == param):
                    fixed_parameters.append(task)
                    if fi < len(user_fixed_parameters) - 1 and is_number(
                            user_fixed_parameters[fi + 1]):
                        value = float(user_fixed_parameters[fi + 1])
                        if value not in self._call_stack:
                            self._call_stack[task]['value'] = value
                    if 'min_value' in self._call_stack[task]:
                        del self._call_stack[task]['min_value']
                    if 'max_value' in self._call_stack[task]:
                        del self._call_stack[task]['max_value']
                    self._modules[task].fix_value(
                        self._call_stack[task]['value'])
            for fi, param in enumerate(user_released_parameters):
                if (task == param
                        or self._call_stack[task].get('class', '') == param):
                    released_parameters.append(task)

        self.determine_free_parameters(fixed_parameters, released_parameters)

        for ti, task in enumerate(self._call_stack):
            cur_task = self._call_stack[task]
            self._modules[task].set_event_name(event_name)
            new_per = np.round(100.0 * float(ti) / len(self._call_stack))
            prt.message('loading_task', [task, new_per], inline=True)
            self._kinds_supported |= set(cur_task.get('supports', []))
            if cur_task['kind'] == 'data':
                success = self._modules[task].set_data(
                    data,
                    req_key_values=OrderedDict(
                        (('band', self._bands), ('instrument',
                                                 self._instruments),
                         ('telescope', self._telescopes))),
                    subtract_minimum_keys=['times'],
                    smooth_times=smooth_times,
                    extrapolate_time=extrapolate_time,
                    limit_fitting_mjds=limit_fitting_mjds,
                    exclude_bands=exclude_bands,
                    exclude_instruments=exclude_instruments,
                    exclude_systems=exclude_systems,
                    exclude_sources=exclude_sources,
                    exclude_kinds=exclude_kinds,
                    time_unit=time_unit,
                    time_list=time_list,
                    band_list=band_list,
                    band_systems=band_systems,
                    band_instruments=band_instruments,
                    band_bandsets=band_bandsets)
                if not success:
                    return False
                fixed_parameters.extend(
                    self._modules[task].get_data_determined_parameters())
            elif cur_task['kind'] == 'sed':
                self._modules[task].set_data(band_sampling_points)
            self._kinds_needed |= self._modules[task]._kinds_needed

        # Find unsupported wavebands and report to user.
        unsupported_kinds = self._kinds_needed - self._kinds_supported
        if unsupported_kinds:
            prt.message('using_unsupported_kinds' if 'none' in exclude_kinds
                        else 'ignoring_unsupported_kinds',
                        [', '.join(sorted(unsupported_kinds))],
                        warning=True)

        # Determine free parameters again as setting data may have fixed some
        # more.
        self.determine_free_parameters(fixed_parameters, released_parameters)

        self.exchange_requests()

        prt.message('finding_bands', inline=True)

        # Run through once to set all inits.
        for root in ['output', 'objective']:
            outputs = self.run_stack(
                [0.0 for x in range(self._num_free_parameters)], root=root)

        # Create any data-dependent free parameters.
        self.adjust_fixed_parameters(variance_for_each, outputs)

        # Determine free parameters again as above may have changed them.
        self.determine_free_parameters(fixed_parameters, released_parameters)

        self.determine_number_of_measurements()

        self.exchange_requests()

        # Reset modules
        for task in self._call_stack:
            self._modules[task].reset_preprocessed(['photometry'])

        # Run through inits once more.
        for root in ['output', 'objective']:
            outputs = self.run_stack(
                [0.0 for x in range(self._num_free_parameters)], root=root)

        # Collect observed band info
        if self._pool.is_master() and 'photometry' in self._modules:
            prt.message('bands_used')
            bis = list(
                filter(lambda a: a != -1,
                       sorted(set(outputs['all_band_indices']))))
            ois = []
            for bi in bis:
                ois.append(
                    any([
                        y for x, y in zip(outputs['all_band_indices'],
                                          outputs['observed']) if x == bi
                    ]))
            band_len = max([
                len(self._modules['photometry']._unique_bands[bi]['origin'])
                for bi in bis
            ])
            filts = self._modules['photometry']
            ubs = filts._unique_bands
            filterarr = [
                (ubs[bis[i]]['systems'], ubs[bis[i]]['bandsets'],
                 filts._average_wavelengths[bis[i]],
                 filts._band_offsets[bis[i]], filts._band_kinds[bis[i]],
                 filts._band_names[bis[i]], ois[i], bis[i])
                for i in range(len(bis))
            ]
            filterrows = [
                (' ' + (' ' if s[-2] else '*') +
                 ubs[s[-1]]['origin'].ljust(band_len) + ' [' + ', '.join(
                     list(
                         filter(None, ('Bandset: ' + s[1] if s[1] else '',
                                       'System: ' + s[0] if s[0] else '',
                                       'AB offset: ' + pretty_num(s[3]) if
                                       (s[4] == 'magnitude' and s[0] != 'AB')
                                       else '')))) + ']').replace(' []', '')
                for s in list(sorted(filterarr))
            ]
            if not all(ois):
                filterrows.append(prt.text('not_observed'))
            prt.prt('\n'.join(filterrows))

            single_freq_inst = list(
                sorted(
                    set(
                        np.array(outputs['instruments'])[np.array(
                            outputs['all_band_indices']) == -1])))

            if len(single_freq_inst):
                prt.message('single_freq')
            for inst in single_freq_inst:
                prt.prt('  {}'.format(inst))

            if ('unmatched_bands' in outputs
                    and 'unmatched_instruments' in outputs):
                prt.message('unmatched_obs', warning=True)
                prt.prt(', '.join([
                    '{} [{}]'.format(x[0], x[1])
                    if x[0] and x[1] else x[0] if not x[1] else x[1]
                    for x in list(
                        set(
                            zip(outputs['unmatched_bands'],
                                outputs['unmatched_instruments'])))
                ]),
                        warning=True,
                        prefix=False,
                        wrapped=True)

        return True

    def adjust_fixed_parameters(self, variance_for_each=[], output={}):
        """Create free parameters that depend on loaded data."""
        unique_band_indices = list(
            sorted(set(output.get('all_band_indices', []))))
        needs_general_variance = any(
            np.array(output.get('all_band_indices', [])) < 0)
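        # A band index below zero marks data not matched to any band (the
        # single-frequency instruments reported by `load_data`), which
        # still needs a catch-all variance parameter.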

        new_call_stack = OrderedDict()
        for task in self._call_stack:
            cur_task = self._call_stack[task]
            vfe = listify(variance_for_each)
            if task == 'variance' and 'band' in vfe:
                vfi = vfe.index('band') + 1
                mwfd = float(vfe[vfi]) if (vfi < len(vfe) and is_number(
                    vfe[vfi])) else self.MIN_WAVE_FRAC_DIFF
                # Find photometry in call stack.
                ptask = None
                for ptask in self._call_stack:
                    if ptask == 'photometry':
                        awaves = self._modules[ptask].average_wavelengths(
                            unique_band_indices)
                        abands = self._modules[ptask].bands(
                            unique_band_indices)
                        band_pairs = list(sorted(zip(awaves, abands)))
                        break
                owav = 0.0
                variance_bands = []
                for (awav, band) in band_pairs:
                    wave_frac_diff = abs(awav - owav) / (awav + owav)
                    if wave_frac_diff < mwfd:
                        continue
                    new_task_name = '-'.join([task, 'band', band])
                    if new_task_name in self._call_stack:
                        continue
                    new_task = deepcopy(cur_task)
                    if 'latex' in new_task:
                        new_task['latex'] += '_{\\rm ' + band + '}'
                    new_call_stack[new_task_name] = new_task
                    self._modules[new_task_name] = self._load_task_module(
                        new_task_name, call_stack=new_call_stack)
                    owav = awav
                    variance_bands.append([awav, band])
                if needs_general_variance:
                    new_call_stack[task] = deepcopy(cur_task)
                if self._pool.is_master():
                    self._printer.message(
                        'anchoring_variances',
                        [', '.join([x[1] for x in variance_bands])],
                        wrapped=True)
                self._modules[ptask].set_variance_bands(variance_bands)
            else:
                new_call_stack[task] = deepcopy(cur_task)
            # Fix any variables that should be fixed if any conditional
            # inputs are fixed by the data.
            # if any([listify(x)[-1] == 'conditional'
            #         for x in cur_task.get('inputs', [])]):
        self._call_stack = new_call_stack

        for task in reversed(self._call_stack):
            cur_task = self._call_stack[task]
            for inp in cur_task.get('inputs', []):
                other = listify(inp)[0]
                if (cur_task['kind'] == 'parameter'
                        and output.get(other, None) is not None):
                    if (not self._modules[other]._fixed
                            or self._modules[other]._fixed_by_user):
                        self._modules[task]._fixed = True
                    self._modules[task]._derived_keys = list(
                        set(self._modules[task]._derived_keys + [task]))

    def determine_number_of_measurements(self):
        """Estimate the number of measurements."""
        self._num_measurements = 0
        for task in self._call_stack:
            cur_task = self._call_stack[task]
            if cur_task['kind'] == 'data':
                self._num_measurements += len(
                    self._modules[task]._data['times'])

    def determine_free_parameters(self,
                                  extra_fixed_parameters=[],
                                  extra_released_parameters=[]):
        """Generate `_free_parameters` and `_num_free_parameters`."""
        self._free_parameters = []
        self._user_fixed_parameters = []
        self._num_variances = 0
        for task in self._call_stack:
            cur_task = self._call_stack[task]
            if (task in extra_released_parameters
                    or (task not in extra_fixed_parameters
                        and cur_task['kind'] == 'parameter'
                        and 'min_value' in cur_task and 'max_value' in cur_task
                        and cur_task['min_value'] != cur_task['max_value']
                        and not self._modules[task]._fixed)):
                self._free_parameters.append(task)
                if cur_task.get('class', '') == 'variance':
                    self._num_variances += 1
            elif (cur_task['kind'] == 'parameter'
                  and task in extra_fixed_parameters):
                self._user_fixed_parameters.append(task)
        self._num_free_parameters = len(self._free_parameters)

    def is_parameter_fixed_by_user(self, parameter):
        """Return whether a parameter is fixed by the user."""
        return parameter in self._user_fixed_parameters

    def get_num_free_parameters(self):
        """Return number of free parameters."""
        return self._num_free_parameters

    def exchange_requests(self):
        """Exchange requests between modules."""
        for task in reversed(self._call_stack):
            cur_task = self._call_stack[task]
            if 'requests' in cur_task:
                requests = OrderedDict()
                reqs = cur_task['requests']
                for req in reqs:
                    if reqs[req] not in self._modules:
                        raise RuntimeError(
                            'Request cannot be satisfied because module '
                            '`{}` could not be found.'.format(reqs[req]))
                    requests[req] = self._modules[reqs[req]].send_request(req)
                self._modules[task].receive_requests(**requests)

    def frack(self, arg):
        """Perform fracking upon a single walker.

        Uses a randomly-selected global or local minimization method.
        """
        x = np.array(arg[0])
        step = 1.0
        seed = arg[1]
        np.random.seed(seed)
        my_choice = np.random.choice(range(3))
        # my_choice = 0
        my_method = ['L-BFGS-B', 'TNC', 'SLSQP'][my_choice]
        opt_dict = {'disp': False, 'approx_grad': True}
        if my_method in ['TNC', 'SLSQP']:
            opt_dict['maxiter'] = 200
        elif my_method == 'L-BFGS-B':
            opt_dict['maxfun'] = 5000
            opt_dict['maxls'] = 50
        # bounds = [(0.0, 1.0) for y in range(self._num_free_parameters)]
        bounds = list(
            zip(np.clip(x - step, 0.0, 1.0), np.clip(x + step, 0.0, 1.0)))

        bh = minimize(self.fprob,
                      x,
                      method=my_method,
                      bounds=bounds,
                      options=opt_dict)

        # bounds = list(
        #     zip(np.clip(x - step, 0.0, 1.0), np.clip(x + step, 0.0, 1.0)))
        #
        # bh = differential_evolution(
        #     self.fprob, bounds, disp=True, polish=False)

        # bh = basinhopping(
        #     self.fprob,
        #     x,
        #     disp=True,
        #     niter=10,
        #     minimizer_kwargs={'method': "L-BFGS-B",
        #                       'bounds': bounds})

        # bo = BayesianOptimization(self.boprob, dict(
        #     [('x' + str(i),
        #       (np.clip(x[i] - step, 0.0, 1.0),
        #        np.clip(x[i] + step, 0.0, 1.0)))
        #      for i in range(len(x))]))
        #
        # bo.explore(dict([('x' + str(i), [x[i]]) for i in range(len(x))]))
        #
        # bo.maximize(init_points=0, n_iter=20, acq='ei')
        #
        # bh = self.outClass()
        # bh.x = [x[1] for x in sorted(bo.res['max']['max_params'].items())]
        # bh.fun = -bo.res['max']['max_val']

        # m = Minuit(self.fprob)
        # m.migrad()
        return bh

    def construct_trees(self,
                        d,
                        trees,
                        simple,
                        kinds=[],
                        name='',
                        roots=[],
                        depth=0):
        """Construct call trees for each root."""
        leaf = kinds if len(kinds) else name
        if depth > 100:
            raise RuntimeError(
                'Error: Tree depth greater than 100, suggests a recursive '
                'input loop in `{}`.'.format(leaf))
        for tag in d:
            entry = deepcopy(d[tag])
            new_roots = list(roots)
            if entry['kind'] in kinds or tag == name:
                entry['depth'] = depth
                if entry['kind'] in kinds:
                    new_roots.append(entry['kind'])
                entry['roots'] = list(sorted(set(new_roots)))
                trees[tag] = entry
                simple[tag] = OrderedDict()
                inputs = listify(entry.get('inputs', []))
                for inps in inputs:
                    conditional = False
                    if isinstance(inps, list) and not isinstance(
                            inps, string_types) and inps[-1] == "conditional":
                        inp = inps[0]
                        conditional = True
                    else:
                        inp = inps
                    if inp not in d:
                        suggests = get_close_matches(inp, d, n=1, cutoff=0.8)
                        warn_str = ('Module `{}` for input to `{}` '
                                    'not found!'.format(inp, leaf))
                        if len(suggests):
                            warn_str += (' Did you perhaps mean `{}`?'.format(
                                suggests[0]))
                        raise RuntimeError(warn_str)
                    # Conditional inputs don't propagate down the tree.
                    if conditional:
                        continue
                    children = OrderedDict()
                    simple_children = OrderedDict()
                    self.construct_trees(d,
                                         children,
                                         simple_children,
                                         name=inp,
                                         roots=new_roots,
                                         depth=depth + 1)
                    trees[tag].setdefault('children', OrderedDict())
                    trees[tag]['children'].update(children)
                    simple[tag].update(simple_children)

    def draw_from_icdf(self, draw):
        """Draw parameters into unit interval using parameter inverse CDFs."""
        return [
            self._modules[self._free_parameters[i]].prior_icdf(x)
            for i, x in enumerate(draw)
        ]

    def draw_walker(self,
                    test=True,
                    walkers_pool=[],
                    replace=False,
                    weights=None):
        """Draw a walker randomly.

        Draw a walker randomly from the full range of all parameters,
        rejecting walkers that return invalid scores.
        """
        p = None
        chosen_one = None
        draw_cnt = 0
        while p is None:
            draw_cnt += 1
            draw = np.random.uniform(low=0.0,
                                     high=1.0,
                                     size=self._num_free_parameters)
            draw = self.draw_from_icdf(draw)
            if walkers_pool:
                if not replace:
                    chosen_one = 0
                else:
                    chosen_one = np.random.choice(range(len(walkers_pool)),
                                                  p=weights)
                for e, elem in enumerate(walkers_pool[chosen_one]):
                    if elem is not None:
                        draw[e] = elem
            if not test:
                p = draw
                score = None
                break
            score = self.ln_likelihood(draw)
            if draw_cnt >= self.DRAW_LIMIT and not self._draw_limit_reached:
                self._printer.message('draw_limit_reached', warning=True)
                self._draw_limit_reached = True
            if ((not isnan(score) and np.isfinite(score) and
                 (not isinstance(self._fitter._draw_above_likelihood, float)
                  or score > self._fitter._draw_above_likelihood))
                    or draw_cnt >= self.DRAW_LIMIT):
                p = draw

        if not replace and chosen_one is not None:
            del walkers_pool[chosen_one]
            if weights is not None:
                del weights[chosen_one]
                if weights and None not in weights:
                    totw = np.sum(weights)
                    weights = [x / totw for x in weights]
        return (p, score)

    def get_max_depth(self, tag, parent, max_depth):
        """Return the maximum depth a given task is found in a tree."""
        for child in parent.get('children', []):
            if child == tag:
                new_max = parent['children'][child]['depth']
                if new_max > max_depth:
                    max_depth = new_max
            else:
                new_max = self.get_max_depth(tag, parent['children'][child],
                                             max_depth)
                if new_max > max_depth:
                    max_depth = new_max
        return max_depth

    def in_tree(self, tag, parent):
        """Return the maximum depth a given task is found in a tree."""
        for child in parent.get('children', []):
            if child == tag:
                return True
            else:
                if self.in_tree(tag, parent['children'][child]):
                    return True
        return False

    def pool(self):
        """Return processing pool."""
        return self._pool

    def run(self, x, root='output'):
        """Run stack with the given root."""
        outputs = self.run_stack(x, root=root)
        return outputs

    def printer(self):
        """Return printer."""
        return self._printer

    def likelihood(self, x):
        """Return score related to maximum likelihood."""
        return np.exp(self.ln_likelihood(x))

    def ln_likelihood(self, x):
        """Return ln(likelihood)."""
        outputs = self.run_stack(x, root='objective')
        return outputs['value']

    def ln_likelihood_floored(self, x):
        """Return ln(likelihood), floored to a finite value."""
        outputs = self.run_stack(x, root='objective')
        return max(LOCAL_LIKELIHOOD_FLOOR, outputs['value'])

    def free_parameter_names(self, x):
        """Return list of free parameter names."""
        return self._free_parameters

    def prior(self, x):
        """Return score related to paramater priors."""
        return np.exp(self.ln_prior(x))

    def ln_prior(self, x):
        """Return ln(prior)."""
        prior = 0.0
        for pi, par in enumerate(self._free_parameters):
            lprior = self._modules[par].lnprior_pdf(x[pi])
            prior = prior + lprior
        return prior

    def boprob(self, **kwargs):
        """Score for `BayesianOptimization`."""
        x = []
        for key in sorted(kwargs):
            x.append(kwargs[key])

        li = self.ln_likelihood(x) + self.ln_prior(x)
        if not np.isfinite(li):
            return LOCAL_LIKELIHOOD_FLOOR
        return li

    def fprob(self, x):
        """Return score for fracking."""
        li = -(self.ln_likelihood(x) + self.ln_prior(x))
        if not np.isfinite(li):
            return -LOCAL_LIKELIHOOD_FLOOR
        return li

    def plural(self, x):
        """Pluralize and cache model-related keys."""
        if x not in self._inflections:
            plural = self._inflect.plural(x)
            if plural == x:
                plural = x + 's'
            self._inflections[x] = plural
        else:
            plural = self._inflections[x]
        return plural

    def reset_unset_recommended_keys(self):
        """Null the list of unset recommended keys across all modules."""
        for module in self._modules.values():
            module.reset_unset_recommended_keys()

    def get_unset_recommended_keys(self):
        """Collect list of unset recommended keys across all modules."""
        unset_keys = set()
        for module in self._modules.values():
            unset_keys.update(module.get_unset_recommended_keys())
        return unset_keys

    def run_stack(self, x, root='objective'):
        """Run module stack.

        Run a stack of modules as defined in the model definition file. Only
        run functions that match the specified root.
        """
        inputs = OrderedDict()
        outputs = OrderedDict()
        pos = 0
        cur_depth = self._max_depth_all

        # If this is the first time running this stack, build the ref arrays.
        build_refs = root not in self._references
        if build_refs:
            self._references[root] = []

        for task in self._call_stack:
            cur_task = self._call_stack[task]
            if root not in cur_task['roots']:
                continue
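            # Crossing into a new depth level: the outputs of the deeper
            # level just finished become the inputs of this one.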
            if cur_task['depth'] != cur_depth:
                inputs = outputs
            inputs.update(OrderedDict([('root', root)]))
            cur_depth = cur_task['depth']
            if task in self._free_parameters:
                inputs.update(OrderedDict([('fraction', x[pos])]))
                inputs.setdefault('fractions', []).append(x[pos])
                pos = pos + 1
            try:
                new_outs = self._modules[task].process(**inputs)
                if not isinstance(new_outs, OrderedDict):
                    new_outs = OrderedDict(sorted(new_outs.items()))
            except Exception:
                self._printer.prt(
                    "Failed to execute module `{}`\'s process().".format(task),
                    wrapped=True)
                raise

            outputs.update(new_outs)

            # Append module references
            if build_refs:
                self._references[root].extend(self._modules[task]._REFERENCES)

            if '_delete_keys' in outputs:
                for key in list(outputs['_delete_keys'].keys()):
                    del outputs[key]
                del outputs['_delete_keys']

        if build_refs:
            # Make sure references are unique.
            self._references[root] = list(
                map(
                    dict,
                    set(
                        tuple(sorted(d.items()))
                        for d in self._references[root])))

        return outputs
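
The call-stack assembly in `__init__` above reduces to binning tasks by
dependency depth and executing deepest-first, so every module runs after
the modules it consumes. A small self-contained sketch of that ordering
(task names are illustrative only):

from collections import OrderedDict

unsorted_stack = {
    'objective': {'depth': 0},  # root: consumes everything below it
    'photometry': {'depth': 1},
    'sed': {'depth': 2},
    'engine': {'depth': 3},
}

call_stack = OrderedDict()
max_depth = max(spec['depth'] for spec in unsorted_stack.values())
for depth in range(max_depth, -1, -1):
    for task, spec in unsorted_stack.items():
        if spec['depth'] == depth:
            call_stack[task] = spec

print(list(call_stack))
# ['engine', 'sed', 'photometry', 'objective'] -- dependencies run first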
Example #10
                     label='data: trail')

    axes[2].set_xlim(0, 17)
    axes[2].set_ylim(-1, 1)
    axes[2].set_xlabel(r'$\Delta \phi_1$ [deg]')
    axes[2].set_ylabel(r'$\Delta \phi_2$ [deg]')
    axes[2].legend(loc='best', fontsize=15)

    fig.tight_layout()
    fig.savefig(
        path.join(
            plot_path, 'BarModels_RL{:d}_Mb{:.0e}_Om{:.1f}.png'.format(
                release_every, m_b.value, omega.value)))


def worker(task):
    omega, = task
    width_track(omega * u.km / u.s / u.kpc,
                m_b=1e10 * u.Msun,
                release_every=1,
                n_steps=6000)


tasks = [(om, ) for om in np.arange(28.0, 60 + 1e-3, 0.5)]

with SerialPool() as pool:
    # with MultiPool() as pool:
    print(pool.size)
    for r in pool.map(worker, tasks):
        pass
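
Rather than toggling the commented-out pool line above by hand,
schwimmbad also provides a `choose_pool` helper that selects between the
three pool classes from two flags and parks MPI workers automatically.
A short sketch, assuming a recent schwimmbad release:

from schwimmbad import choose_pool

pool = choose_pool(mpi=False, processes=4)  # MultiPool with 4 processes
try:
    for r in pool.map(worker, tasks):
        pass
finally:
    pool.close()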