def pool(request):
    """Provide a processing pool to the requesting test class.

    The backend is picked by the hard-coded ``multimode`` switch below. The
    resulting pool (or ``None`` when no backend is selected) is stored on
    ``request.cls.pool`` before control is handed to the tests, and
    multi-process / MPI pools are closed again once the tests are done.
    """
    # Backend switch; edit by hand to one of 'Serial', 'Multi' or 'MPI'
    # to run the tests against a real pool.
    multimode = 'None'

    # --- set up ---
    # The schwimmbad imports stay inside their branches so that nothing is
    # imported while no backend is selected.
    worker_pool = None
    if multimode == 'Serial':
        from schwimmbad import SerialPool
        worker_pool = SerialPool()
    elif multimode == 'Multi':
        from schwimmbad import MultiPool
        worker_pool = MultiPool()
    elif multimode == 'MPI':
        from schwimmbad import MPIPool
        worker_pool = MPIPool()
        # Non-master processes wait for work and then exit instead of
        # going on to run the tests.
        if not worker_pool.is_master():
            worker_pool.wait()
            import sys
            sys.exit(0)

    # Expose the pool as a class attribute of the requesting test class.
    request.cls.pool = worker_pool

    yield

    # --- tear down ---
    if multimode in ('Multi', 'MPI'):
        worker_pool.close()
def test_init(case):
    """Check which constructor arguments ``TheJoker`` accepts and rejects."""
    prior, _ = get_prior(case)

    # A prior object alone is enough; anything else as the prior is rejected.
    TheJoker(prior)
    with pytest.raises(TypeError):
        TheJoker('jsdfkj')

    # A real pool is accepted ...
    with SerialPool() as serial_pool:
        TheJoker(prior, pool=serial_pool)

    # ... while an arbitrary object passed as the pool is not.
    with pytest.raises(TypeError):
        TheJoker(prior, pool='sdfks')

    # A numpy Generator is a valid random state.
    rng = np.random.default_rng(42)
    TheJoker(prior, random_state=rng)

    # An arbitrary object passed as the random state is rejected.
    with pytest.raises(TypeError):
        TheJoker(prior, random_state='sdfks')

    # The legacy RandomState must trigger a deprecation warning.
    with pytest.warns(DeprecationWarning):
        legacy_rng = np.random.RandomState(42)
        TheJoker(prior, random_state=legacy_rng)

    # The temporary file location must exist once it has been queried.
    sampler = TheJoker(prior, tempfile_path='/tmp/joker')
    assert os.path.exists(sampler.tempfile_path)
def __init__(self,
             cuda=False,
             exit_on_prompt=False,
             language='en',
             limiting_magnitude=None,
             prefer_fluxes=False,
             offline=False,
             prefer_cache=False,
             open_in_browser=False,
             pool=None,
             quiet=False,
             test=False,
             wrap_length=100,
             **kwargs):
    """Create the pool, printer and fetcher and store the run options.

    When no ``pool`` is given a ``SerialPool`` is created. All options are
    stored on private attributes of the same name; extra keyword arguments
    are accepted and ignored. With ``cuda`` enabled the CUDA linear algebra
    backend is initialised if its packages can be imported.
    """
    if pool is None:
        self._pool = SerialPool()
    else:
        self._pool = pool

    # The printer needs the pool, and the fetcher needs the printer.
    printer = Printer(pool=self._pool,
                      wrap_length=wrap_length,
                      quiet=quiet,
                      fitter=self,
                      language=language,
                      exit_on_prompt=exit_on_prompt)
    self._printer = printer
    self._fetcher = Fetcher(test=test,
                            open_in_browser=open_in_browser,
                            printer=printer)

    # Remember the options for later use by the fitting methods.
    self._cuda = cuda
    self._limiting_magnitude = limiting_magnitude
    self._prefer_fluxes = prefer_fluxes
    self._offline = offline
    self._prefer_cache = prefer_cache
    self._open_in_browser = open_in_browser
    self._quiet = quiet
    self._test = test
    self._wrap_length = wrap_length

    if not self._cuda:
        return

    # CUDA support is best-effort: missing packages are silently ignored.
    try:
        import pycuda.autoinit  # noqa: F401
        import skcuda.linalg as linalg
        linalg.init()
    except ImportError:
        pass
def sample(self, n_samples):
    """Draw ``n_samples`` joint posterior samples.

    ``self.draw_one_joint_posterior_sample_map`` is mapped over
    ``range(n_samples)`` on a processing pool. Each result is indexed as a
    pair: element 0 is used as the redshift and element 1 as one row of the
    returned table.

    The pool is chosen from ``self.pool``:

    * ``None`` (or whenever ``_GPU_ENABLED`` is truthy): a new ``SerialPool``;
    * an ``int``: a new ``MultiPool`` built from that value;
    * a ``SerialPool`` / ``MultiPool`` instance: used as given.

    Pools created here are always closed, even when the mapping raises.
    A pool supplied by the caller is left open so that it can be reused
    for further calls; closing it remains the caller's responsibility.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw.

    Returns
    -------
    pandas.DataFrame
        Built from element 1 of every sample, with element 0 stored in an
        additional ``"redshift"`` column.

    Raises
    ------
    TypeError
        If ``self.pool`` is none of the supported kinds.
    """
    owns_pool = True
    if self.pool is None or _GPU_ENABLED:
        pool = SerialPool()
    elif isinstance(self.pool, int):
        pool = MultiPool(self.pool)
    elif isinstance(self.pool, (SerialPool, MultiPool)):
        pool = self.pool
        owns_pool = False
    else:
        raise TypeError(
            "Does not understand the given multiprocessing pool.")

    try:
        drawn_samples = list(
            pool.map(self.draw_one_joint_posterior_sample_map,
                     range(n_samples)))
    finally:
        # Do not leak worker processes if the mapping fails, and never
        # shut down a pool that belongs to the caller.
        if owns_pool:
            pool.close()

    drawn_zs = [drawn[0] for drawn in drawn_samples]
    drawn_inference_posteriors = [drawn[1] for drawn in drawn_samples]

    drawn_joint_posterior_samples = pd.DataFrame(
        drawn_inference_posteriors)
    drawn_joint_posterior_samples["redshift"] = drawn_zs
    return drawn_joint_posterior_samples
def compute_mean_selection_function(selection_function, N_avg, pool=None):
    """Average ``N_avg`` evaluations of a selection function.

    ``selection_function.evaluate`` is called ``N_avg`` times without
    arguments through ``pool.starmap`` and the results are combined with
    ``np.average``.

    Parameters
    ----------
    selection_function : object
        Object providing an ``evaluate()`` method callable without
        arguments.
    N_avg : int
        Number of evaluations to average over.
    pool : None, int, SerialPool or MultiPool, optional
        ``None`` creates a ``SerialPool``, an ``int`` creates a
        ``MultiPool`` built from that value, and an existing pool instance
        is used as given.

    Returns
    -------
    The value returned by ``np.average`` over all evaluations.

    Raises
    ------
    TypeError
        If ``pool`` is none of the supported kinds.

    Notes
    -----
    Pools created here are always closed, even when an evaluation raises.
    A pool supplied by the caller is left open so that it can be reused;
    closing it remains the caller's responsibility.
    """
    owns_pool = True
    if pool is None:
        pool = SerialPool()
    elif isinstance(pool, int):
        pool = MultiPool(pool)
    elif isinstance(pool, (SerialPool, MultiPool)):
        owns_pool = False
    else:
        raise TypeError("Does not understand the given multiprocessing pool.")

    try:
        # One empty argument tuple per requested evaluation.
        out = list(
            pool.starmap(selection_function.evaluate,
                         [() for _ in range(N_avg)]))
    finally:
        # Do not leak worker processes on failure, and never shut down a
        # pool that belongs to the caller.
        if owns_pool:
            pool.close()

    return np.average(out)
class Fitter(object):
    """Fit transient events with the provided model.

    `fit_events` is the entry point: it gathers the event data, builds a
    `Model` for every event/model/parameter-file combination and hands each
    one to `fit_data`, which runs the sampler and assembles the output.
    """

    # Source added to every entry produced by `fit_data`, in addition to the
    # references reported by the model.
    _DEFAULT_SOURCE = {SOURCE.BIBCODE: '2017arXiv171002145G'}

    def __init__(self,
                 cuda=False,
                 exit_on_prompt=False,
                 language='en',
                 limiting_magnitude=None,
                 prefer_fluxes=False,
                 offline=False,
                 prefer_cache=False,
                 open_in_browser=False,
                 pool=None,
                 quiet=False,
                 test=False,
                 wrap_length=100,
                 **kwargs):
        """Initialize `Fitter` class.

        A `SerialPool` is created when no `pool` is given. The printer is
        built from the pool and the fetcher from the printer; the remaining
        options are stored on private attributes of the same name. Extra
        keyword arguments are accepted and ignored.
        """
        self._pool = SerialPool() if pool is None else pool
        self._printer = Printer(pool=self._pool,
                                wrap_length=wrap_length,
                                quiet=quiet,
                                fitter=self,
                                language=language,
                                exit_on_prompt=exit_on_prompt)
        self._fetcher = Fetcher(test=test,
                                open_in_browser=open_in_browser,
                                printer=self._printer)

        self._cuda = cuda
        self._limiting_magnitude = limiting_magnitude
        self._prefer_fluxes = prefer_fluxes
        self._offline = offline
        self._prefer_cache = prefer_cache
        self._open_in_browser = open_in_browser
        self._quiet = quiet
        self._test = test
        self._wrap_length = wrap_length

        if self._cuda:
            # Best effort: if the CUDA packages are missing the error is
            # swallowed and `self._cuda` keeps the requested value.
            try:
                import pycuda.autoinit  # noqa: F401
                import skcuda.linalg as linalg
                linalg.init()
            except ImportError:
                pass

    def fit_events(self,
                   events=[],
                   models=[],
                   max_time='',
                   time_list=[],
                   time_unit=None,
                   band_list=[],
                   band_systems=[],
                   band_instruments=[],
                   band_bandsets=[],
                   band_sampling_points=17,
                   iterations=10000,
                   num_walkers=None,
                   num_temps=1,
                   parameter_paths=['parameters.json'],
                   fracking=True,
                   frack_step=50,
                   burn=None,
                   post_burn=None,
                   gibbs=False,
                   smooth_times=-1,
                   extrapolate_time=0.0,
                   limit_fitting_mjds=False,
                   exclude_bands=[],
                   exclude_instruments=[],
                   exclude_systems=[],
                   exclude_sources=[],
                   exclude_kinds=[],
                   output_path='',
                   suffix='',
                   upload=False,
                   write=False,
                   upload_token='',
                   check_upload_quality=False,
                   variance_for_each=[],
                   user_fixed_parameters=[],
                   user_released_parameters=[],
                   convergence_type=None,
                   convergence_criteria=None,
                   save_full_chain=False,
                   draw_above_likelihood=False,
                   maximum_walltime=False,
                   start_time=False,
                   print_trees=False,
                   maximum_memory=np.inf,
                   speak=False,
                   return_fits=True,
                   extra_outputs=None,
                   walker_paths=[],
                   catalogs=[],
                   exit_on_prompt=False,
                   download_recommended_data=False,
                   local_data_only=False,
                   method=None,
                   seed=None,
                   **kwargs):
        """Fit a list of events with a list of models.

        For every fetched event, every model in `models` and every file in
        `parameter_paths`, a `Model` is constructed, the event data is
        loaded into it and, if loading succeeded, `fit_data` is called.
        When only models are given, a single placeholder event is used and
        dummy data is generated with `generate_dummy_data`.

        Most keyword arguments are forwarded unchanged to `Model`,
        `Model.load_data` or `fit_data`; a few are stored on the instance
        for later use (`seed`, `start_time`, `maximum_walltime`,
        `maximum_memory`, `speak`, `download_recommended_data`,
        `local_data_only`, `draw_above_likelihood`).

        Returns a tuple `(entries, ps, lnprobs)` of lists indexed as
        `[event][model]`. Slots stay `None` (or the per-event list stays
        empty) for combinations that were skipped, and are only filled
        when `return_fits` is true. Because the parameter-path loop shares
        the model index, a later parameter path overwrites the result of
        an earlier one.

        Raises `RuntimeError` when a path in `walker_paths` does not exist.
        """
        global model
        if start_time is False:
            start_time = time.time()

        self._seed = seed
        if seed is not None:
            np.random.seed(seed)

        self._start_time = start_time
        self._maximum_walltime = maximum_walltime
        self._maximum_memory = maximum_memory
        self._debug = False
        self._speak = speak
        self._download_recommended_data = download_recommended_data
        self._local_data_only = local_data_only

        self._draw_above_likelihood = draw_above_likelihood

        prt = self._printer

        event_list = listify(events)
        model_list = listify(models)

        # Models without events: use one empty placeholder event so that
        # the event loop below still runs once.
        if len(model_list) and not len(event_list):
            event_list = ['']

        # Exclude catalogs not included in catalog list.
        self._fetcher.add_excluded_catalogs(catalogs)

        if not len(event_list) and not len(model_list):
            prt.message('no_events_models', warning=True)

        # If the input is not a JSON file, assume it is either a list of
        # transients or that it is the data from a single transient in tabular
        # form. Try to guess the format first, and if that fails ask the user.
        self._converter = Converter(prt, require_source=upload)
        event_list = self._converter.generate_event_list(event_list)

        # Replace non-breaking hyphens (U+2011) in event names by plain
        # ASCII hyphens.
        event_list = [x.replace('‑', '-') for x in event_list]

        entries = [[] for x in range(len(event_list))]
        ps = [[] for x in range(len(event_list))]
        lnprobs = [[] for x in range(len(event_list))]

        # Load walker data if provided a list of walker paths.
        walker_data = []

        if len(walker_paths):
            # Fall back to a serial pool when MPI cannot be used.
            try:
                pool = MPIPool()
            except (ImportError, ValueError):
                pool = SerialPool()
            if pool.is_master():
                prt.message('walker_file')
                wfi = 0
                for walker_path in walker_paths:
                    if os.path.exists(walker_path):
                        prt.prt(' {}'.format(walker_path))
                        with codecs.open(walker_path, 'r',
                                         encoding='utf-8') as f:
                            all_walker_data = json.load(
                                f, object_pairs_hook=OrderedDict)

                        # Support both the format where all data stored in a
                        # single-item dictionary (the OAC format) and the older
                        # MOSFiT format where the data was stored in the
                        # top-level dictionary.
                        if ENTRY.NAME not in all_walker_data:
                            all_walker_data = all_walker_data[list(
                                all_walker_data.keys())[0]]

                        # NOTE: this rebinds the `models` argument; the
                        # model names were already copied to `model_list`.
                        models = all_walker_data.get(ENTRY.MODELS, [])
                        choice = None
                        if len(models) > 1:
                            # Several models in the file: ask which one to
                            # take the walkers from.
                            model_opts = [
                                '{}-{}-{}'.format(x['code'], x['name'],
                                                  x['date'])
                                for x in models
                            ]
                            choice = prt.prompt('select_model_walkers',
                                                kind='select',
                                                message=True,
                                                options=model_opts)
                            choice = model_opts.index(choice)
                        elif len(models) == 1:
                            choice = 0

                        if choice is not None:
                            # One [file index, parameters, weight] triple
                            # per stored realization.
                            walker_data.extend([[
                                wfi, x[REALIZATION.PARAMETERS],
                                x.get(REALIZATION.WEIGHT)
                            ] for x in models[choice][MODEL.REALIZATIONS]])

                        # Convert the weights that are present to floats.
                        for i in range(len(walker_data)):
                            if walker_data[i][2] is not None:
                                walker_data[i][2] = float(walker_data[i][2])

                        if not len(walker_data):
                            prt.message('no_walker_data')
                    else:
                        prt.message('no_walker_data')
                        if self._offline:
                            prt.message('omit_offline')
                        raise RuntimeError
                    wfi = wfi + 1

                # Send the walker data to every other rank.
                for rank in range(1, pool.size + 1):
                    pool.comm.send(walker_data, dest=rank, tag=3)
            else:
                walker_data = pool.comm.recv(source=0, tag=3)
                pool.wait()

            if pool.is_master():
                pool.close()

        self._event_name = 'Batch'
        self._event_path = ''
        self._event_data = {}

        # Only the master fetches the events; the result is then sent to
        # every other rank.
        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            fetched_events = self._fetcher.fetch(
                event_list,
                offline=self._offline,
                prefer_cache=self._prefer_cache)

            for rank in range(1, pool.size + 1):
                pool.comm.send(fetched_events, dest=rank, tag=0)
            pool.close()
        else:
            fetched_events = pool.comm.recv(source=0, tag=0)
            pool.wait()

        for ei, event in enumerate(fetched_events):
            if event is not None:
                self._event_name = event.get('name', 'Batch')
                self._event_path = event.get('path', '')
                # Skip events without a path or without loadable data.
                if not self._event_path:
                    continue
                self._event_data = self._fetcher.load_data(event)
                if not self._event_data:
                    continue

            if model_list:
                lmodel_list = model_list
            else:
                lmodel_list = ['']

            entries[ei] = [None for y in range(len(lmodel_list))]
            ps[ei] = [None for y in range(len(lmodel_list))]
            lnprobs[ei] = [None for y in range(len(lmodel_list))]

            # A real event whose first (top-level) record has no photometry
            # cannot be fitted.
            if (event is not None and
                (not self._event_data or ENTRY.PHOTOMETRY
                 not in self._event_data[list(self._event_data.keys())[0]])):
                prt.message('no_photometry', [self._event_name])
                continue

            for mi, mod_name in enumerate(lmodel_list):
                for parameter_path in parameter_paths:
                    # A fresh pool is created for every model set-up.
                    try:
                        pool = MPIPool()
                    except (ImportError, ValueError):
                        pool = SerialPool()
                    self._model = Model(model=mod_name,
                                        data=self._event_data,
                                        parameter_path=parameter_path,
                                        output_path=output_path,
                                        wrap_length=self._wrap_length,
                                        test=self._test,
                                        printer=prt,
                                        fitter=self,
                                        pool=pool,
                                        print_trees=print_trees)

                    if not self._model._model_name:
                        prt.message('no_models_avail', [self._event_name],
                                    warning=True)
                        continue

                    if not event:
                        # No real event: fit against generated dummy data
                        # named after the model.
                        prt.message('gen_dummy')
                        self._event_name = mod_name
                        gen_args = {
                            'name': mod_name,
                            'max_time': max_time,
                            'time_list': time_list,
                            'band_list': band_list,
                            'band_systems': band_systems,
                            'band_instruments': band_instruments,
                            'band_bandsets': band_bandsets
                        }
                        self._event_data = self.generate_dummy_data(**gen_args)

                    # Load the data; the loop repeats only after extra data
                    # was merged in (which resets `success` to False).
                    success = False
                    alt_name = None
                    while not success:
                        self._model.reset_unset_recommended_keys()
                        success = self._model.load_data(
                            self._event_data,
                            event_name=self._event_name,
                            smooth_times=smooth_times,
                            extrapolate_time=extrapolate_time,
                            limit_fitting_mjds=limit_fitting_mjds,
                            exclude_bands=exclude_bands,
                            exclude_instruments=exclude_instruments,
                            exclude_systems=exclude_systems,
                            exclude_sources=exclude_sources,
                            exclude_kinds=exclude_kinds,
                            time_list=time_list,
                            time_unit=time_unit,
                            band_list=band_list,
                            band_systems=band_systems,
                            band_instruments=band_instruments,
                            band_bandsets=band_bandsets,
                            band_sampling_points=band_sampling_points,
                            variance_for_each=variance_for_each,
                            user_fixed_parameters=user_fixed_parameters,
                            user_released_parameters=user_released_parameters,
                            pool=pool)

                        if not success:
                            break

                        if self._local_data_only:
                            break

                        # If our data is missing recommended keys, offer the
                        # user option to pull the missing data from online and
                        # merge with existing data.
                        urk = self._model.get_unset_recommended_keys()
                        ptxt = prt.text('acquire_recommended',
                                        [', '.join(list(urk))])
                        # The prompt is only reached when neither an
                        # alternative name nor the download flag is set.
                        while event and len(urk) and (
                                alt_name or self._download_recommended_data
                                or prt.prompt(
                                    ptxt, [', '.join(urk)], kind='bool')):
                            try:
                                pool = MPIPool()
                            except (ImportError, ValueError):
                                pool = SerialPool()
                            if pool.is_master():
                                en = (alt_name
                                      if alt_name else self._event_name)
                                extra_event = self._fetcher.fetch(
                                    en,
                                    offline=self._offline,
                                    prefer_cache=self._prefer_cache)[0]
                                extra_data = self._fetcher.load_data(
                                    extra_event)

                                for rank in range(1, pool.size + 1):
                                    pool.comm.send(extra_data,
                                                   dest=rank,
                                                   tag=4)
                                pool.close()
                            else:
                                extra_data = pool.comm.recv(source=0, tag=4)
                                pool.wait()

                            if extra_data is not None:
                                extra_data = extra_data[list(
                                    extra_data.keys())[0]]

                                # Copy every missing key over, even when the
                                # extra data has no value for it (None).
                                for key in urk:
                                    new_val = extra_data.get(key)
                                    self._event_data[list(
                                        self._event_data.keys())
                                                     [0]][key] = new_val
                                    if new_val is not None and len(new_val):
                                        prt.message('extra_value', [
                                            key,
                                            str(new_val[0].get(QUANTITY.VALUE))
                                        ])
                                # Force the outer loop to reload the data.
                                success = False
                                prt.message('reloading_merged')
                                break
                            else:
                                # Nothing found: ask for another name to
                                # search under, or give up on empty input.
                                text = prt.text('extra_not_found',
                                                [self._event_name])
                                alt_name = prt.prompt(text, kind='string')
                                if not alt_name:
                                    break

                    if success:
                        self._walker_data = walker_data

                        entry, p, lnprob = self.fit_data(
                            event_name=self._event_name,
                            method=method,
                            iterations=iterations,
                            num_walkers=num_walkers,
                            num_temps=num_temps,
                            burn=burn,
                            post_burn=post_burn,
                            fracking=fracking,
                            frack_step=frack_step,
                            gibbs=gibbs,
                            pool=pool,
                            output_path=output_path,
                            suffix=suffix,
                            write=write,
                            upload=upload,
                            upload_token=upload_token,
                            check_upload_quality=check_upload_quality,
                            convergence_type=convergence_type,
                            convergence_criteria=convergence_criteria,
                            save_full_chain=save_full_chain,
                            extra_outputs=extra_outputs)

                        if return_fits:
                            entries[ei][mi] = deepcopy(entry)
                            ps[ei][mi] = deepcopy(p)
                            lnprobs[ei][mi] = deepcopy(lnprob)

                    if pool.is_master():
                        pool.close()

                    # Remove global model variable and garbage collect.
                    try:
                        model
                    except NameError:
                        pass
                    else:
                        del (model)
                    del (self._model)
                    gc.collect()

        return (entries, ps, lnprobs)

    def fit_data(self,
                 event_name='',
                 method=None,
                 iterations=None,
                 frack_step=20,
                 num_walkers=None,
                 num_temps=1,
                 burn=None,
                 post_burn=None,
                 fracking=True,
                 gibbs=False,
                 pool=None,
                 output_path='',
                 suffix='',
                 write=False,
                 upload=False,
                 upload_token='',
                 check_upload_quality=True,
                 convergence_type=None,
                 convergence_criteria=None,
                 save_full_chain=False,
                 extra_outputs=None):
        """Fit the data for a given event.

        Runs the sampler (`Nester` when `method` is 'nester', `Ensembler`
        otherwise) on `self._model`, then builds two entries from the
        samples: the full entry and a separate upload entry that starts
        empty. Model photometry is generated for the selected samples and
        a realization is recorded for every sample.

        With `write` set, JSON products are written to the model's
        products path. With `upload` set, results may be uploaded to
        Dropbox using `upload_token`.

        Relies on `self._model`, `self._walker_data`, `self._event_name`
        and `self._event_path` having been set beforehand, which
        `fit_events` does.

        Returns `(entry, samples, probs)` on the master process and
        `(None, None, None)` on every other process.
        """
        if self._speak:
            speak('Fitting ' + event_name, self._speak)
        from mosfit.__init__ import __version__
        global model
        model = self._model
        prt = self._printer

        upload_model = upload and iterations > 0

        if pool is not None:
            self._pool = pool

        if upload:
            try:
                import dropbox
            except ImportError:
                # NOTE(review): in test mode the missing package is
                # tolerated here; the later uses of `dropbox` sit inside
                # `except Exception` handlers that also pass in test mode.
                if self._test:
                    pass
                else:
                    prt.message('install_db', error=True)
                    raise

        # Non-master processes only serve the pool and produce no output.
        if not self._pool.is_master():
            try:
                self._pool.wait()
            except (KeyboardInterrupt, SystemExit):
                pass
            return (None, None, None)

        self._method = method

        if self._method == 'nester':
            self._sampler = Nester(self, model, iterations, burn, post_burn,
                                   num_walkers, convergence_criteria,
                                   convergence_type, gibbs, fracking,
                                   frack_step)
        else:
            self._sampler = Ensembler(self, model, iterations, burn,
                                      post_burn, num_temps, num_walkers,
                                      convergence_criteria, convergence_type,
                                      gibbs, fracking, frack_step)

        self._sampler.run(self._walker_data)

        prt.message('constructing')

        if write:
            if self._speak:
                speak(prt._strings['saving_output'], self._speak)

        if self._event_path:
            # Start from the event file (models ignored) and drop any
            # photometry that carries a realization key.
            entry = Entry.init_from_file(catalog=None,
                                         name=self._event_name,
                                         path=self._event_path,
                                         merge=False,
                                         pop_schema=False,
                                         ignore_keys=[ENTRY.MODELS],
                                         compare_to_existing=False)
            new_photometry = []
            for photo in entry.get(ENTRY.PHOTOMETRY, []):
                if PHOTOMETRY.REALIZATION not in photo:
                    new_photometry.append(photo)
            if len(new_photometry):
                entry[ENTRY.PHOTOMETRY] = new_photometry
        else:
            entry = Entry(name=self._event_name)

        # The upload entry starts empty and only receives model output.
        uentry = Entry(name=self._event_name)
        data_keys = set()
        for task in model._call_stack:
            if model._call_stack[task]['kind'] == 'data':
                data_keys.update(
                    list(model._call_stack[task].get('keys', {}).keys()))
        entryhash = entry.get_hash(keys=list(sorted(list(data_keys))))

        # Accumulate all the sources and add them to each entry.
        sources = []
        for root in model._references:
            for ref in model._references[root]:
                sources.append(entry.add_source(**ref))
        sources.append(entry.add_source(**self._DEFAULT_SOURCE))
        source = ','.join(sources)

        usources = []
        for root in model._references:
            for ref in model._references[root]:
                usources.append(uentry.add_source(**ref))
        usources.append(uentry.add_source(**self._DEFAULT_SOURCE))
        usource = ','.join(usources)

        # Record the call stack, with the parameter-file settings merged
        # into the parameter tasks.
        model_setup = OrderedDict()
        for ti, task in enumerate(model._call_stack):
            task_copy = deepcopy(model._call_stack[task])
            if (task_copy['kind'] == 'parameter'
                    and task in model._parameter_json):
                task_copy.update(model._parameter_json[task])
            model_setup[task] = task_copy
        modeldict = OrderedDict([(MODEL.NAME, model._model_name),
                                 (MODEL.SETUP, model_setup),
                                 (MODEL.CODE, 'MOSFiT'),
                                 (MODEL.DATE, time.strftime("%Y/%m/%d")),
                                 (MODEL.VERSION, __version__),
                                 (MODEL.SOURCE, source)])

        self._sampler.prepare_output(check_upload_quality, upload)

        self._sampler.append_output(modeldict)

        # The model hash ignores the date and the source aliases.
        umodeldict = deepcopy(modeldict)
        umodeldict[MODEL.SOURCE] = usource
        modelhash = get_model_hash(umodeldict,
                                   ignore_keys=[MODEL.DATE, MODEL.SOURCE])
        umodelnum = uentry.add_model(**umodeldict)

        # The sampler may override the upload decision.
        if self._sampler._upload_model is not None:
            upload_model = self._sampler._upload_model

        modelnum = entry.add_model(**modeldict)

        samples, probs, weights = self._sampler.get_samples()

        extras = OrderedDict()
        samples_to_plot = self._sampler._nwalkers

        if isinstance(self._sampler, Nester):
            # Pick sample indices at random by inverting the cumulative
            # sum of the weights.
            icdf = np.cumsum(np.concatenate(([0.0], weights)))
            draws = np.random.rand(samples_to_plot)
            indices = np.searchsorted(icdf, draws) - 1
        else:
            indices = list(range(samples_to_plot))

        ri = 0
        selected_extra = False
        for xi, x in enumerate(samples):
            ri = ri + 1
            prt.message('outputting_walker', [ri, len(samples)],
                        inline=True,
                        min_time=0.2)
            if xi in indices:
                # Selected samples get the full output stack and model
                # photometry.
                output = model.run_stack(x, root='output')
                if extra_outputs is not None:
                    # An empty list of extra outputs prints the available
                    # keys once.
                    if not extra_outputs and not selected_extra:
                        extra_options = list(output.keys())
                        prt.message('available_keys')
                        for opt in extra_options:
                            prt.prt('- {}'.format(opt))
                        selected_extra = True
                    for key in extra_outputs:
                        new_val = output.get(key, [])
                        new_val = all_to_list(new_val)
                        extras.setdefault(key, []).append(new_val)
                for i in range(len(output['times'])):
                    # Skip points where the model produced no finite value.
                    if not np.isfinite(output['model_observations'][i]):
                        continue
                    photodict = {
                        PHOTOMETRY.TIME:
                        output['times'][i] + output['min_times'],
                        PHOTOMETRY.MODEL: modelnum,
                        PHOTOMETRY.SOURCE: source,
                        PHOTOMETRY.REALIZATION: str(ri)
                    }
                    if output['observation_types'][i] == 'magnitude':
                        photodict[PHOTOMETRY.BAND] = output['bands'][i]
                        photodict[PHOTOMETRY.
                                  MAGNITUDE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.
                                  E_MAGNITUDE] = output['model_variances'][i]
                    elif output['observation_types'][i] == 'magcount':
                        # A zero count rate has no magnitude.
                        if output['model_observations'][i] == 0.0:
                            continue
                        photodict[PHOTOMETRY.BAND] = output['bands'][i]
                        photodict[PHOTOMETRY.
                                  COUNT_RATE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.
                                  E_COUNT_RATE] = output['model_variances'][i]
                        photodict[PHOTOMETRY.MAGNITUDE] = -2.5 * np.log10(
                            output['model_observations']
                            [i]) + output['all_zeropoints'][i]
                        photodict[PHOTOMETRY.E_UPPER_MAGNITUDE] = 2.5 * (
                            np.log10(output['model_observations'][i] +
                                     output['model_variances'][i]) -
                            np.log10(output['model_observations'][i]))
                        # When the variance exceeds the count rate the lower
                        # bound would need the log of a negative number, so
                        # the point is flagged as an upper limit instead.
                        if (output['model_variances'][i] >
                                output['model_observations'][i]):
                            photodict[PHOTOMETRY.UPPER_LIMIT] = True
                        else:
                            photodict[PHOTOMETRY.E_LOWER_MAGNITUDE] = 2.5 * (
                                np.log10(output['model_observations'][i]) -
                                np.log10(output['model_observations'][i] -
                                         output['model_variances'][i]))
                    elif output['observation_types'][i] == 'fluxdensity':
                        photodict[PHOTOMETRY.FREQUENCY] = output[
                            'frequencies'][i] * frequency_unit('GHz')
                        photodict[PHOTOMETRY.FLUX_DENSITY] = output[
                            'model_observations'][i] * flux_density_unit('µJy')
                        # The variance is applied in the exponent, divided
                        # by 2.5, to obtain the lower and upper errors.
                        photodict[PHOTOMETRY.E_LOWER_FLUX_DENSITY] = (
                            photodict[PHOTOMETRY.FLUX_DENSITY] -
                            (10.0**
                             (np.log10(photodict[PHOTOMETRY.FLUX_DENSITY]) -
                              output['model_variances'][i] / 2.5)) *
                            flux_density_unit('µJy'))
                        photodict[PHOTOMETRY.E_UPPER_FLUX_DENSITY] = (
                            10.0**(np.log10(photodict[PHOTOMETRY.FLUX_DENSITY])
                                   + output['model_variances'][i] / 2.5) *
                            flux_density_unit('µJy') -
                            photodict[PHOTOMETRY.FLUX_DENSITY])
                        photodict[PHOTOMETRY.U_FREQUENCY] = 'GHz'
                        photodict[PHOTOMETRY.U_FLUX_DENSITY] = 'µJy'
                    elif output['observation_types'][i] == 'countrate':
                        photodict[PHOTOMETRY.
                                  COUNT_RATE] = output['model_observations'][i]
                        photodict[PHOTOMETRY.E_LOWER_COUNT_RATE] = (
                            photodict[PHOTOMETRY.COUNT_RATE] -
                            (10.0**(np.log10(photodict[PHOTOMETRY.COUNT_RATE])
                                    - output['model_variances'][i] / 2.5)))
                        photodict[PHOTOMETRY.E_UPPER_COUNT_RATE] = (
                            10.0**(np.log10(photodict[PHOTOMETRY.COUNT_RATE]) +
                                   output['model_variances'][i] / 2.5) -
                            photodict[PHOTOMETRY.COUNT_RATE])
                        photodict[PHOTOMETRY.U_COUNT_RATE] = 's^-1'
                    if ('model_upper_limits' in output
                            and output['model_upper_limits'][i]):
                        photodict[PHOTOMETRY.UPPER_LIMIT] = bool(
                            output['model_upper_limits'][i])
                    if self._limiting_magnitude is not None:
                        photodict[PHOTOMETRY.SIMULATED] = True
                    # Optional descriptors are copied only when present and
                    # non-empty.
                    if 'telescopes' in output and output['telescopes'][i]:
                        photodict[
                            PHOTOMETRY.TELESCOPE] = output['telescopes'][i]
                    if 'systems' in output and output['systems'][i]:
                        photodict[PHOTOMETRY.SYSTEM] = output['systems'][i]
                    if 'bandsets' in output and output['bandsets'][i]:
                        photodict[PHOTOMETRY.BAND_SET] = output['bandsets'][i]
                    if 'instruments' in output and output['instruments'][i]:
                        photodict[
                            PHOTOMETRY.INSTRUMENT] = output['instruments'][i]
                    if 'modes' in output and output['modes'][i]:
                        photodict[PHOTOMETRY.MODE] = output['modes'][i]
                    entry.add_photometry(compare_to_existing=False,
                                         check_for_dupes=False,
                                         **photodict)

                    # Same point for the upload entry, whose source field
                    # is set to `umodelnum`.
                    uphotodict = deepcopy(photodict)
                    uphotodict[PHOTOMETRY.SOURCE] = umodelnum
                    uentry.add_photometry(compare_to_existing=False,
                                          check_for_dupes=False,
                                          **uphotodict)
            else:
                # Samples that were not selected only run the objective
                # stack.
                output = model.run_stack(x, root='objective')

            parameters = OrderedDict()
            derived_keys = set()
            # `pi` indexes into the sample vector `x`; it advances only
            # for free parameters.
            pi = 0
            for ti, task in enumerate(model._call_stack):
                # if task not in model._free_parameters:
                #     continue
                if model._call_stack[task]['kind'] != 'parameter':
                    continue
                paramdict = OrderedDict(
                    (('latex', model._modules[task].latex()),
                     ('log', model._modules[task].is_log())))
                if task in model._free_parameters:
                    poutput = model._modules[task].process(
                        **{'fraction': x[pi]})
                    value = list(poutput.values())[0]
                    paramdict['value'] = value
                    paramdict['fraction'] = x[pi]
                    pi = pi + 1
                else:
                    if output.get(task, None) is not None:
                        paramdict['value'] = output[task]
                parameters.update({model._modules[task].name(): paramdict})
                # Dump out any derived parameter keys
                derived_keys.update(model._modules[task].get_derived_keys())

            for key in list(sorted(list(derived_keys))):
                if (output.get(key, None) is not None
                        and key not in parameters):
                    parameters.update({key: {'value': output[key]}})

            realdict = {REALIZATION.PARAMETERS: parameters}
            if probs is not None:
                realdict[REALIZATION.SCORE] = str(probs[xi])
            else:
                realdict[REALIZATION.SCORE] = str(
                    ln_likelihood(x) + ln_prior(x))
            realdict[REALIZATION.ALIAS] = str(ri)
            realdict[REALIZATION.WEIGHT] = str(weights[xi])
            entry[ENTRY.MODELS][0].add_realization(check_for_dupes=False,
                                                   **realdict)
            urealdict = deepcopy(realdict)
            uentry[ENTRY.MODELS][0].add_realization(check_for_dupes=False,
                                                    **urealdict)
        prt.message('all_walkers_written', inline=True)

        entry.sanitize()
        oentry = {self._event_name: entry._ordered(entry)}
        uentry.sanitize()
        ouentry = {self._event_name: uentry._ordered(uentry)}

        uname = '_'.join([self._event_name, entryhash, modelhash])

        # The directories are created even when nothing is written.
        if output_path and not os.path.exists(output_path):
            os.makedirs(output_path)

        if not os.path.exists(model.get_products_path()):
            os.makedirs(model.get_products_path())

        if write:
            # Every product is written twice: once under a fixed file name
            # and once under a name built from the event name (or `uname`)
            # and the optional suffix.
            prt.message('writing_complete')
            with open_atomic(
                    os.path.join(model.get_products_path(), 'walkers.json'),
                    'w') as flast, open_atomic(
                        os.path.join(
                            model.get_products_path(), self._event_name +
                            (('_' + suffix) if suffix else '') + '.json'),
                        'w') as feven:
                entabbed_json_dump(oentry, flast, separators=(',', ':'))
                entabbed_json_dump(oentry, feven, separators=(',', ':'))

            if save_full_chain:
                prt.message('writing_full_chain')
                with open_atomic(
                        os.path.join(model.get_products_path(), 'chain.json'),
                        'w') as flast, open_atomic(
                            os.path.join(
                                model.get_products_path(),
                                self._event_name + '_chain' +
                                (('_' + suffix) if suffix else '') + '.json'),
                            'w') as feven:
                    entabbed_json_dump(self._sampler._all_chain.tolist(),
                                       flast,
                                       separators=(',', ':'))
                    entabbed_json_dump(self._sampler._all_chain.tolist(),
                                       feven,
                                       separators=(',', ':'))

            if extra_outputs is not None:
                prt.message('writing_extras')
                with open_atomic(
                        os.path.join(model.get_products_path(), 'extras.json'),
                        'w') as flast, open_atomic(
                            os.path.join(
                                model.get_products_path(),
                                self._event_name + '_extras' +
                                (('_' + suffix) if suffix else '') + '.json'),
                            'w') as feven:
                    entabbed_json_dump(extras, flast, separators=(',', ':'))
                    entabbed_json_dump(extras, feven, separators=(',', ':'))

            prt.message('writing_model')
            with open_atomic(
                    os.path.join(model.get_products_path(), 'upload.json'),
                    'w') as flast, open_atomic(
                        os.path.join(
                            model.get_products_path(),
                            uname + (('_' + suffix) if suffix else '') +
                            '.json'), 'w') as feven:
                entabbed_json_dump(ouentry, flast, separators=(',', ':'))
                entabbed_json_dump(ouentry, feven, separators=(',', ':'))

        if upload_model:
            prt.message('ul_fit', [entryhash, self._sampler._modelhash])
            upayload = entabbed_json_dumps(ouentry, separators=(',', ':'))
            try:
                dbx = dropbox.Dropbox(upload_token)
                dbx.files_upload(upayload.encode(),
                                 '/' + uname + '.json',
                                 mode=dropbox.files.WriteMode.overwrite)
                prt.message('ul_complete')
            except Exception:
                # Upload failures are tolerated in test mode only.
                if self._test:
                    pass
                else:
                    raise

        if upload:
            # Offer to upload every event that the converter produced from
            # the user's input.
            for ce in self._converter.get_converted():
                dentry = Entry.init_from_file(catalog=None,
                                              name=ce[0],
                                              path=ce[1],
                                              merge=False,
                                              pop_schema=False,
                                              ignore_keys=[ENTRY.MODELS],
                                              compare_to_existing=False)
                dentry.sanitize()
                odentry = {ce[0]: uentry._ordered(dentry)}
                dpayload = entabbed_json_dumps(odentry, separators=(',', ':'))
                text = prt.message('ul_devent', [ce[0]], prt=False)
                ul_devent = prt.prompt(text, kind='bool', message=False)
                if ul_devent:
                    # File name: event name plus the bibcode of the first
                    # source, falling back to its name, then 'NOSOURCE'.
                    dpath = '/' + slugify(
                        ce[0] + '_' + dentry[ENTRY.SOURCES][0].get(
                            SOURCE.BIBCODE, dentry[ENTRY.SOURCES][0].get(
                                SOURCE.NAME, 'NOSOURCE'))) + '.json'
                    try:
                        dbx = dropbox.Dropbox(upload_token)
                        dbx.files_upload(
                            dpayload.encode(),
                            dpath,
                            mode=dropbox.files.WriteMode.overwrite)
                        prt.message('ul_complete')
                    except Exception:
                        if self._test:
                            pass
                        else:
                            raise

        return (entry, samples, probs)

    def nester(self):
        """Use nested sampling to determine posteriors.

        Placeholder: the body is empty and returns `None`.
        """
        pass

    def generate_dummy_data(self,
                            name,
                            max_time=1000.,
                            time_list=[],
                            band_list=[],
                            band_systems=[],
                            band_instruments=[],
                            band_bandsets=[]):
        """Generate simulated data based on priors.

        Builds a photometry list with zero magnitude and zero error for
        every combination of time and band. The times are the two end
        points `0.0` and `max_time` plus everything in `time_list`; the
        bands default to `['V']` when `band_list` is empty.

        `band_systems`, `band_instruments` and `band_bandsets` may each be
        a single string (applied to every band) or a list, which is padded
        to the number of bands by repeating its last value (or an empty
        string when the list is empty).

        Returns `{name: {'photometry': [...]}}`.
        """
        # NOTE(review): `fit_events` forwards its own `max_time`, whose
        # default there is '' rather than a number -- confirm callers
        # always pass a numeric value on this path.
        # Just need 2 plot points for beginning and end.
        plot_points = 2

        times = list(
            sorted(
                set(list(np.linspace(0.0, max_time, plot_points)) +
                    time_list)))
        band_list_all = ['V'] if len(band_list) == 0 else band_list
        # Every time appears once per band.
        times = np.repeat(times, len(band_list_all))

        # Create lists of systems/instruments if not provided.
        if isinstance(band_systems, string_types):
            band_systems = [band_systems for x in range(len(band_list_all))]
        if isinstance(band_instruments, string_types):
            band_instruments = [
                band_instruments for x in range(len(band_list_all))
            ]
        if isinstance(band_bandsets, string_types):
            band_bandsets = [band_bandsets for x in range(len(band_list_all))]
        # Pad short lists by repeating their last value.
        if len(band_systems) < len(band_list_all):
            rep_val = '' if len(band_systems) == 0 else band_systems[-1]
            band_systems = band_systems + [
                rep_val for x in range(len(band_list_all) - len(band_systems))
            ]
        if len(band_instruments) < len(band_list_all):
            rep_val = '' if len(
                band_instruments) == 0 else band_instruments[-1]
            band_instruments = band_instruments + [
                rep_val
                for x in range(len(band_list_all) - len(band_instruments))
            ]
        if len(band_bandsets) < len(band_list_all):
            rep_val = '' if len(band_bandsets) == 0 else band_bandsets[-1]
            band_bandsets = band_bandsets + [
                rep_val
                for x in range(len(band_list_all) - len(band_bandsets))
            ]

        # Cycle through the per-band lists; only the first len(times)
        # items of each flattened list are read below.
        bands = [i for s in [band_list_all for x in times] for i in s]
        systs = [i for s in [band_systems for x in times] for i in s]
        insts = [i for s in [band_instruments for x in times] for i in s]
        bsets = [i for s in [band_bandsets for x in times] for i in s]

        data = {name: {'photometry': []}}
        for ti, tim in enumerate(times):
            band = bands[ti]
            # Bands may be given as dictionaries carrying a 'name' key.
            if isinstance(band, dict):
                band = band['name']

            photodict = {
                'time': tim,
                'band': band,
                'magnitude': 0.0,
                'e_magnitude': 0.0
            }
            if systs[ti]:
                photodict['system'] = systs[ti]
            if insts[ti]:
                photodict['instrument'] = insts[ti]
            if bsets[ti]:
                photodict['bandset'] = bsets[ti]
            data[name]['photometry'].append(photodict)

        return data
def fit_events(self,
               events=[],
               models=[],
               max_time='',
               time_list=[],
               time_unit=None,
               band_list=[],
               band_systems=[],
               band_instruments=[],
               band_bandsets=[],
               band_sampling_points=17,
               iterations=10000,
               num_walkers=None,
               num_temps=1,
               parameter_paths=['parameters.json'],
               fracking=True,
               frack_step=50,
               burn=None,
               post_burn=None,
               gibbs=False,
               smooth_times=-1,
               extrapolate_time=0.0,
               limit_fitting_mjds=False,
               exclude_bands=[],
               exclude_instruments=[],
               exclude_systems=[],
               exclude_sources=[],
               exclude_kinds=[],
               output_path='',
               suffix='',
               upload=False,
               write=False,
               upload_token='',
               check_upload_quality=False,
               variance_for_each=[],
               user_fixed_parameters=[],
               user_released_parameters=[],
               convergence_type=None,
               convergence_criteria=None,
               save_full_chain=False,
               draw_above_likelihood=False,
               maximum_walltime=False,
               start_time=False,
               print_trees=False,
               maximum_memory=np.inf,
               speak=False,
               return_fits=True,
               extra_outputs=None,
               walker_paths=[],
               catalogs=[],
               exit_on_prompt=False,
               download_recommended_data=False,
               local_data_only=False,
               method=None,
               seed=None,
               **kwargs):
    """Fit a list of events with a list of models.

    For every fetched event, every model and every entry of
    `parameter_paths`, a `Model` is built, the event data is loaded into
    it and `fit_data` is run. When no event is given but a model is, dummy
    data is produced with `generate_dummy_data`.

    Most keyword arguments are forwarded unchanged to `Model`,
    `Model.load_data` or `fit_data`. `exit_on_prompt` and `**kwargs` are
    accepted but not used in this method.

    Returns a tuple `(entries, ps, lnprobs)`. Each element is a list with
    one sub-list per event; each sub-list has one slot per model, filled
    with a deep copy of the corresponding `fit_data` result when
    `return_fits` is true and the load succeeded, otherwise left as None.

    Raises `RuntimeError` when a path in `walker_paths` does not exist.
    """
    global model
    # `False` is the "not supplied" sentinel for the start time.
    if start_time is False:
        start_time = time.time()

    self._seed = seed
    if seed is not None:
        # Seeds NumPy's global random state.
        np.random.seed(seed)

    self._start_time = start_time
    self._maximum_walltime = maximum_walltime
    self._maximum_memory = maximum_memory
    self._debug = False
    self._speak = speak
    self._download_recommended_data = download_recommended_data
    self._local_data_only = local_data_only

    self._draw_above_likelihood = draw_above_likelihood

    prt = self._printer

    event_list = listify(events)
    model_list = listify(models)

    # Models without events: use one empty event name so the loops below
    # still run once and dummy data can be generated.
    if len(model_list) and not len(event_list):
        event_list = ['']

    # Exclude catalogs not included in catalog list.
    self._fetcher.add_excluded_catalogs(catalogs)

    if not len(event_list) and not len(model_list):
        prt.message('no_events_models', warning=True)

    # If the input is not a JSON file, assume it is either a list of
    # transients or that it is the data from a single transient in tabular
    # form. Try to guess the format first, and if that fails ask the user.
    self._converter = Converter(prt, require_source=upload)
    event_list = self._converter.generate_event_list(event_list)

    # Normalise non-breaking hyphens (U+2011) in event names to ASCII '-'.
    event_list = [x.replace('‑', '-') for x in event_list]

    # One result slot per event; filled per model further down.
    entries = [[] for x in range(len(event_list))]
    ps = [[] for x in range(len(event_list))]
    lnprobs = [[] for x in range(len(event_list))]

    # Load walker data if provided a list of walker paths.
    walker_data = []

    if len(walker_paths):
        # Fall back to a serial pool when MPI cannot be set up.
        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            prt.message('walker_file')
            wfi = 0
            for walker_path in walker_paths:
                if os.path.exists(walker_path):
                    prt.prt(' {}'.format(walker_path))
                    with codecs.open(walker_path, 'r',
                                     encoding='utf-8') as f:
                        all_walker_data = json.load(
                            f, object_pairs_hook=OrderedDict)

                    # Support both the format where all data stored in a
                    # single-item dictionary (the OAC format) and the older
                    # MOSFiT format where the data was stored in the
                    # top-level dictionary.
                    if ENTRY.NAME not in all_walker_data:
                        all_walker_data = all_walker_data[list(
                            all_walker_data.keys())[0]]

                    # NOTE(review): this rebinds the `models` parameter;
                    # harmless here because `model_list` was already
                    # derived from it above.
                    models = all_walker_data.get(ENTRY.MODELS, [])
                    choice = None
                    if len(models) > 1:
                        model_opts = [
                            '{}-{}-{}'.format(x['code'], x['name'],
                                              x['date']) for x in models
                        ]
                        choice = prt.prompt(
                            'select_model_walkers',
                            kind='select',
                            message=True,
                            options=model_opts)
                        choice = model_opts.index(choice)
                    elif len(models) == 1:
                        choice = 0

                    if choice is not None:
                        # Each walker: [file index, parameters, weight].
                        walker_data.extend([[
                            wfi, x[REALIZATION.PARAMETERS],
                            x.get(REALIZATION.WEIGHT)
                        ] for x in models[choice][MODEL.REALIZATIONS]])

                    # Weights are converted to floats when present.
                    for i in range(len(walker_data)):
                        if walker_data[i][2] is not None:
                            walker_data[i][2] = float(walker_data[i][2])

                    if not len(walker_data):
                        prt.message('no_walker_data')
                else:
                    prt.message('no_walker_data')
                    if self._offline:
                        prt.message('omit_offline')
                    # NOTE(review): raised without a message.
                    raise RuntimeError
                wfi = wfi + 1

            # Ship the walker data to ranks 1..pool.size.
            for rank in range(1, pool.size + 1):
                pool.comm.send(walker_data, dest=rank, tag=3)
        else:
            walker_data = pool.comm.recv(source=0, tag=3)
            pool.wait()

        if pool.is_master():
            pool.close()

    self._event_name = 'Batch'
    self._event_path = ''
    self._event_data = {}

    # Only the master fetches events; the result is sent to the other
    # ranks with tag 0.
    try:
        pool = MPIPool()
    except (ImportError, ValueError):
        pool = SerialPool()
    if pool.is_master():
        fetched_events = self._fetcher.fetch(
            event_list,
            offline=self._offline,
            prefer_cache=self._prefer_cache)

        for rank in range(1, pool.size + 1):
            pool.comm.send(fetched_events, dest=rank, tag=0)
        pool.close()
    else:
        fetched_events = pool.comm.recv(source=0, tag=0)
        pool.wait()

    for ei, event in enumerate(fetched_events):
        if event is not None:
            self._event_name = event.get('name', 'Batch')
            self._event_path = event.get('path', '')
            # Events without a path or without loadable data are skipped.
            if not self._event_path:
                continue
            self._event_data = self._fetcher.load_data(event)
            if not self._event_data:
                continue

        # An empty model name lets `Model` choose or prompt for one.
        if model_list:
            lmodel_list = model_list
        else:
            lmodel_list = ['']

        entries[ei] = [None for y in range(len(lmodel_list))]
        ps[ei] = [None for y in range(len(lmodel_list))]
        lnprobs[ei] = [None for y in range(len(lmodel_list))]

        # A real event must carry photometry under its first key.
        if (event is not None and
            (not self._event_data or ENTRY.PHOTOMETRY not in
             self._event_data[list(self._event_data.keys())[0]])):
            prt.message('no_photometry', [self._event_name])
            continue

        for mi, mod_name in enumerate(lmodel_list):
            for parameter_path in parameter_paths:
                # A fresh pool is created for every model/parameter file.
                try:
                    pool = MPIPool()
                except (ImportError, ValueError):
                    pool = SerialPool()
                self._model = Model(
                    model=mod_name,
                    data=self._event_data,
                    parameter_path=parameter_path,
                    output_path=output_path,
                    wrap_length=self._wrap_length,
                    test=self._test,
                    printer=prt,
                    fitter=self,
                    pool=pool,
                    print_trees=print_trees)

                if not self._model._model_name:
                    prt.message(
                        'no_models_avail', [self._event_name],
                        warning=True)
                    continue

                # No event: fit the model against generated dummy data.
                if not event:
                    prt.message('gen_dummy')
                    self._event_name = mod_name
                    gen_args = {
                        'name': mod_name,
                        'max_time': max_time,
                        'time_list': time_list,
                        'band_list': band_list,
                        'band_systems': band_systems,
                        'band_instruments': band_instruments,
                        'band_bandsets': band_bandsets
                    }
                    self._event_data = self.generate_dummy_data(**gen_args)

                success = False
                alt_name = None

                # Reload until the data loads with nothing more to merge;
                # the inner loop sets `success = False` to force a reload.
                while not success:
                    self._model.reset_unset_recommended_keys()
                    success = self._model.load_data(
                        self._event_data,
                        event_name=self._event_name,
                        smooth_times=smooth_times,
                        extrapolate_time=extrapolate_time,
                        limit_fitting_mjds=limit_fitting_mjds,
                        exclude_bands=exclude_bands,
                        exclude_instruments=exclude_instruments,
                        exclude_systems=exclude_systems,
                        exclude_sources=exclude_sources,
                        exclude_kinds=exclude_kinds,
                        time_list=time_list,
                        time_unit=time_unit,
                        band_list=band_list,
                        band_systems=band_systems,
                        band_instruments=band_instruments,
                        band_bandsets=band_bandsets,
                        band_sampling_points=band_sampling_points,
                        variance_for_each=variance_for_each,
                        user_fixed_parameters=user_fixed_parameters,
                        user_released_parameters=user_released_parameters,
                        pool=pool)

                    if not success:
                        break

                    if self._local_data_only:
                        break

                    # If our data is missing recommended keys, offer the
                    # user option to pull the missing data from online and
                    # merge with existing data.
                    urk = self._model.get_unset_recommended_keys()
                    ptxt = prt.text('acquire_recommended',
                                    [', '.join(list(urk))])
                    while event and len(urk) and (
                            alt_name or self._download_recommended_data or
                            prt.prompt(
                                ptxt, [', '.join(urk)], kind='bool')):
                        try:
                            pool = MPIPool()
                        except (ImportError, ValueError):
                            pool = SerialPool()
                        if pool.is_master():
                            en = (alt_name
                                  if alt_name else self._event_name)
                            extra_event = self._fetcher.fetch(
                                en,
                                offline=self._offline,
                                prefer_cache=self._prefer_cache)[0]
                            extra_data = self._fetcher.load_data(
                                extra_event)

                            for rank in range(1, pool.size + 1):
                                pool.comm.send(
                                    extra_data, dest=rank, tag=4)
                            pool.close()
                        else:
                            extra_data = pool.comm.recv(source=0, tag=4)
                            pool.wait()

                        if extra_data is not None:
                            extra_data = extra_data[list(
                                extra_data.keys())[0]]

                            # Copy each missing key into the event data,
                            # including keys whose fetched value is None.
                            for key in urk:
                                new_val = extra_data.get(key)
                                self._event_data[list(
                                    self._event_data.keys())
                                                 [0]][key] = new_val
                                if new_val is not None and len(new_val):
                                    prt.message('extra_value', [
                                        key,
                                        str(new_val[0].get(QUANTITY.VALUE))
                                    ])
                            success = False
                            prt.message('reloading_merged')
                            break
                        else:
                            # Nothing found: ask for another name to try.
                            text = prt.text('extra_not_found',
                                            [self._event_name])
                            alt_name = prt.prompt(text, kind='string')
                            if not alt_name:
                                break

                if success:
                    self._walker_data = walker_data

                    entry, p, lnprob = self.fit_data(
                        event_name=self._event_name,
                        method=method,
                        iterations=iterations,
                        num_walkers=num_walkers,
                        num_temps=num_temps,
                        burn=burn,
                        post_burn=post_burn,
                        fracking=fracking,
                        frack_step=frack_step,
                        gibbs=gibbs,
                        pool=pool,
                        output_path=output_path,
                        suffix=suffix,
                        write=write,
                        upload=upload,
                        upload_token=upload_token,
                        check_upload_quality=check_upload_quality,
                        convergence_type=convergence_type,
                        convergence_criteria=convergence_criteria,
                        save_full_chain=save_full_chain,
                        extra_outputs=extra_outputs)

                    if return_fits:
                        # With several parameter paths the slot keeps the
                        # result of the last one.
                        entries[ei][mi] = deepcopy(entry)
                        ps[ei][mi] = deepcopy(p)
                        lnprobs[ei][mi] = deepcopy(lnprob)

                if pool.is_master():
                    pool.close()

                # Remove global model variable and garbage collect.
                try:
                    model
                except NameError:
                    pass
                else:
                    del (model)
                del (self._model)
                gc.collect()

    return (entries, ps, lnprobs)
def __init__(self,
             parameter_path='parameters.json',
             model='',
             data={},
             wrap_length=100,
             output_path='',
             pool=None,
             test=False,
             printer=None,
             fitter=None,
             print_trees=False):
    """Initialize `Model` object.

    Resolves the model name (prompting when none is given), loads the
    base, model and parameter JSON files, builds the call stack ordered by
    tree depth and instantiates one module per task.

    When no model name can be determined the constructor returns early,
    leaving `_model_name` falsy and the call stack unset.

    Raises `ValueError` when no parameter file can be found.
    """
    # Imported here rather than at module level.
    from mosfit.fitter import Fitter

    self._model_name = model
    self._parameter_path = parameter_path
    self._output_path = output_path
    self._pool = SerialPool() if pool is None else pool
    # NOTE(review): this reads the `pool` argument, not `self._pool`, so
    # it is False when no pool is passed even though a SerialPool was just
    # created -- confirm this is intended.
    self._is_master = pool.is_master() if pool else False
    self._wrap_length = wrap_length
    self._print_trees = print_trees
    self._inflect = inflect.engine()
    self._test = test
    self._inflections = {}
    self._references = OrderedDict()
    self._free_parameters = []
    self._user_fixed_parameters = []
    self._user_released_parameters = []
    self._kinds_needed = set()
    self._kinds_supported = set()

    self._draw_limit_reached = False

    self._fitter = Fitter() if not fitter else fitter
    self._printer = self._fitter._printer if not printer else printer

    prt = self._printer

    self._dir_path = os.path.dirname(os.path.realpath(__file__))

    # Load suggested model associations for transient types.
    # A `models` folder in the working directory takes precedence over
    # the one next to this file.
    if os.path.isfile(os.path.join('models', 'types.json')):
        types_path = os.path.join('models', 'types.json')
    else:
        types_path = os.path.join(self._dir_path, 'models', 'types.json')
    with open(types_path, 'r') as f:
        model_types = json.load(f, object_pairs_hook=OrderedDict)

    # Create list of all available models.
    # Every first-level sub-directory of a `models` folder counts.
    all_models = set()
    if os.path.isdir('models'):
        all_models |= set(next(os.walk('models'))[1])
    models_path = os.path.join(self._dir_path, 'models')
    if os.path.isdir(models_path):
        all_models |= set(next(os.walk(models_path))[1])
    all_models = list(sorted(list(all_models)))

    if not self._model_name:
        claimed_type = None
        # Use the first claimed type of the first entry in `data`, if any.
        try:
            claimed_type = list(
                data.values())[0]['claimedtype'][0][QUANTITY.VALUE]
        except Exception:
            prt.message('no_model_type', warning=True)

        all_models_txt = prt.text('all_models')
        suggested_models_txt = prt.text('suggested_models', [claimed_type])
        another_model_txt = prt.text('another_model')

        type_options = model_types.get(claimed_type,
                                       []) if claimed_type else []
        if not type_options:
            type_options = all_models
            model_prompt_txt = all_models_txt
        else:
            # Extra entry that switches the prompt to the full model list.
            type_options.append(another_model_txt)
            model_prompt_txt = suggested_models_txt
        if not type_options:
            prt.message('no_model_for_type', warning=True)
        else:
            while not self._model_name:
                if self._test:
                    # Test mode takes the first option without prompting.
                    self._model_name = type_options[0]
                else:
                    sel = self._printer.prompt(
                        model_prompt_txt,
                        kind='option',
                        options=type_options,
                        message=False,
                        default='n',
                        none_string=prt.text('none_above_models'))
                    if sel is not None:
                        # The selection is used as a 1-based index.
                        self._model_name = type_options[int(sel) - 1]
                if not self._model_name:
                    break
                if self._model_name == another_model_txt:
                    type_options = all_models
                    model_prompt_txt = all_models_txt
                    self._model_name = None

    if not self._model_name:
        return

    # Load the basic model file.
    if os.path.isfile(os.path.join('models', 'model.json')):
        basic_model_path = os.path.join('models', 'model.json')
    else:
        basic_model_path = os.path.join(self._dir_path, 'models',
                                        'model.json')

    with open(basic_model_path, 'r') as f:
        self._model = json.load(f, object_pairs_hook=OrderedDict)

    # Load the model file.
    model = self._model_name
    model_dir = self._model_name

    # The name may be given with or without the '.json' extension.
    if '.json' in self._model_name:
        model_dir = self._model_name.split('.json')[0]
    else:
        model = self._model_name + '.json'

    if os.path.isfile(model):
        model_path = model
    else:
        # Look in local hierarchy first
        if os.path.isfile(os.path.join('models', model_dir, model)):
            model_path = os.path.join('models', model_dir, model)
        else:
            model_path = os.path.join(self._dir_path, 'models', model_dir,
                                      model)

    # Entries of the model file override those of the basic model file.
    with open(model_path, 'r') as f:
        self._model.update(json.load(f, object_pairs_hook=OrderedDict))

    # Find @ tags, store them, and prune them from `_model`.
    for tag in list(self._model.keys()):
        if tag.startswith('@'):
            if tag == '@references':
                self._references.setdefault('base',
                                            []).extend(self._model[tag])
            del self._model[tag]

    # with open(os.path.join(
    #         self.get_products_path(),
    #         self._model_name + '.json'), 'w') as f:
    #     json.dump(self._model, f)

    # Load model parameter file.
    model_pp = os.path.join(self._dir_path, 'models', model_dir,
                            'parameters.json')

    pp = ''

    # NOTE(review): the '/' test only recognises POSIX-style separators.
    local_pp = (self._parameter_path
                if '/' in self._parameter_path else os.path.join(
                    'models', model_dir, self._parameter_path))

    if os.path.isfile(local_pp):
        selected_pp = local_pp
    else:
        selected_pp = os.path.join(self._dir_path, 'models', model_dir,
                                   self._parameter_path)

    # First try user-specified path
    if self._parameter_path and os.path.isfile(self._parameter_path):
        pp = self._parameter_path
    # Then try directory we are running from
    elif os.path.isfile('parameters.json'):
        pp = 'parameters.json'
    # Then try the model directory, with the user-specified name
    elif os.path.isfile(selected_pp):
        pp = selected_pp
    # Finally try model folder
    elif os.path.isfile(model_pp):
        pp = model_pp
    else:
        raise ValueError(prt.text('no_parameter_file'))

    if self._is_master:
        prt.message(
            'files', [basic_model_path, model_path, pp], wrapped=False)

    with open(pp, 'r') as f:
        self._parameter_json = json.load(f, object_pairs_hook=OrderedDict)
    self._modules = OrderedDict()
    self._bands = []
    self._instruments = []
    self._telescopes = []

    # Load the call tree for the model. Work our way in reverse from the
    # observables, first constructing a tree for each observable and then
    # combining trees.
    root_kinds = ['output', 'objective']

    self._trees = OrderedDict()
    self._simple_trees = OrderedDict()
    self.construct_trees(
        self._model, self._trees, self._simple_trees, kinds=root_kinds)

    if self._print_trees:
        self._printer.prt('Dependency trees:\n', wrapped=True)
        self._printer.tree(self._simple_trees)

    # Give every task its roots and its depth. Root kinds sit at depth 0;
    # a task found in no tree keeps depth -1 and is left out of the call
    # stack built below.
    unsorted_call_stack = OrderedDict()
    self._max_depth_all = -1
    for tag in self._model:
        model_tag = self._model[tag]
        roots = []
        if model_tag['kind'] in root_kinds:
            max_depth = 0
            roots = [model_tag['kind']]
        else:
            max_depth = -1
            for tag2 in self._trees:
                if self.in_tree(tag, self._trees[tag2]):
                    roots.extend(self._trees[tag2]['roots'])
                    depth = self.get_max_depth(tag, self._trees[tag2],
                                               max_depth)
                    if depth > max_depth:
                        max_depth = depth
                    if depth > self._max_depth_all:
                        self._max_depth_all = depth
        roots = list(sorted(set(roots)))
        new_entry = deepcopy(model_tag)
        new_entry['roots'] = roots
        if 'children' in new_entry:
            del new_entry['children']
        new_entry['depth'] = max_depth
        unsorted_call_stack[tag] = new_entry
    # print(unsorted_call_stack)

    # Currently just have one call stack for all products, can be wasteful
    # if only using some products.
    # Tasks are inserted from the greatest depth down to depth 0.
    self._call_stack = OrderedDict()
    for depth in range(self._max_depth_all, -1, -1):
        for task in unsorted_call_stack:
            if unsorted_call_stack[task]['depth'] == depth:
                self._call_stack[task] = unsorted_call_stack[task]

    # with open(os.path.join(
    #         self.get_products_path(),
    #         self._model_name + '-stack.json'), 'w') as f:
    #     json.dump(self._call_stack, f)

    for task in self._call_stack:
        cur_task = self._call_stack[task]
        mod_name = cur_task.get('class', task)
        # Parameter-file settings override the model file for parameters.
        if cur_task['kind'] == 'parameter' and task in self._parameter_json:
            cur_task.update(self._parameter_json[task])
        self._modules[task] = self._load_task_module(task)
        if mod_name == 'photometry':
            self._telescopes = self._modules[task].telescopes()
            self._instruments = self._modules[task].instruments()
            self._bands = self._modules[task].bands()
        self._modules[task].set_attributes(cur_task)

    # Look forward to see which modules want dense arrays.
    for task in self._call_stack:
        for ftask in self._call_stack:
            if (task != ftask and self._call_stack[ftask]['depth'] <
                    self._call_stack[task]['depth'] and
                    self._modules[ftask]._wants_dense):
                self._modules[ftask]._provide_dense = True

    # Count free parameters.
    self.determine_free_parameters()
class Model(object): """Define a semi-analytical model to fit transients with.""" MODEL_PRODUCTS_DIR = 'products' MIN_WAVE_FRAC_DIFF = 0.1 DRAW_LIMIT = 10 # class outClass(object): # pass def __init__(self, parameter_path='parameters.json', model='', data={}, wrap_length=100, output_path='', pool=None, test=False, printer=None, fitter=None, print_trees=False): """Initialize `Model` object.""" from mosfit.fitter import Fitter self._model_name = model self._parameter_path = parameter_path self._output_path = output_path self._pool = SerialPool() if pool is None else pool self._is_master = pool.is_master() if pool else False self._wrap_length = wrap_length self._print_trees = print_trees self._inflect = inflect.engine() self._test = test self._inflections = {} self._references = OrderedDict() self._free_parameters = [] self._user_fixed_parameters = [] self._user_released_parameters = [] self._kinds_needed = set() self._kinds_supported = set() self._draw_limit_reached = False self._fitter = Fitter() if not fitter else fitter self._printer = self._fitter._printer if not printer else printer prt = self._printer self._dir_path = os.path.dirname(os.path.realpath(__file__)) # Load suggested model associations for transient types. if os.path.isfile(os.path.join('models', 'types.json')): types_path = os.path.join('models', 'types.json') else: types_path = os.path.join(self._dir_path, 'models', 'types.json') with open(types_path, 'r') as f: model_types = json.load(f, object_pairs_hook=OrderedDict) # Create list of all available models. 
all_models = set() if os.path.isdir('models'): all_models |= set(next(os.walk('models'))[1]) models_path = os.path.join(self._dir_path, 'models') if os.path.isdir(models_path): all_models |= set(next(os.walk(models_path))[1]) all_models = list(sorted(list(all_models))) if not self._model_name: claimed_type = None try: claimed_type = list( data.values())[0]['claimedtype'][0][QUANTITY.VALUE] except Exception: prt.message('no_model_type', warning=True) all_models_txt = prt.text('all_models') suggested_models_txt = prt.text('suggested_models', [claimed_type]) another_model_txt = prt.text('another_model') type_options = model_types.get(claimed_type, []) if claimed_type else [] if not type_options: type_options = all_models model_prompt_txt = all_models_txt else: type_options.append(another_model_txt) model_prompt_txt = suggested_models_txt if not type_options: prt.message('no_model_for_type', warning=True) else: while not self._model_name: if self._test: self._model_name = type_options[0] else: sel = self._printer.prompt( model_prompt_txt, kind='option', options=type_options, message=False, default='n', none_string=prt.text('none_above_models')) if sel is not None: self._model_name = type_options[int(sel) - 1] if not self._model_name: break if self._model_name == another_model_txt: type_options = all_models model_prompt_txt = all_models_txt self._model_name = None if not self._model_name: return # Load the basic model file. if os.path.isfile(os.path.join('models', 'model.json')): basic_model_path = os.path.join('models', 'model.json') else: basic_model_path = os.path.join(self._dir_path, 'models', 'model.json') with open(basic_model_path, 'r') as f: self._model = json.load(f, object_pairs_hook=OrderedDict) # Load the model file. 
model = self._model_name model_dir = self._model_name if '.json' in self._model_name: model_dir = self._model_name.split('.json')[0] else: model = self._model_name + '.json' if os.path.isfile(model): model_path = model else: # Look in local hierarchy first if os.path.isfile(os.path.join('models', model_dir, model)): model_path = os.path.join('models', model_dir, model) else: model_path = os.path.join(self._dir_path, 'models', model_dir, model) with open(model_path, 'r') as f: self._model.update(json.load(f, object_pairs_hook=OrderedDict)) # Find @ tags, store them, and prune them from `_model`. for tag in list(self._model.keys()): if tag.startswith('@'): if tag == '@references': self._references.setdefault('base', []).extend(self._model[tag]) del self._model[tag] # with open(os.path.join( # self.get_products_path(), # self._model_name + '.json'), 'w') as f: # json.dump(self._model, f) # Load model parameter file. model_pp = os.path.join(self._dir_path, 'models', model_dir, 'parameters.json') pp = '' local_pp = (self._parameter_path if '/' in self._parameter_path else os.path.join('models', model_dir, self._parameter_path)) if os.path.isfile(local_pp): selected_pp = local_pp else: selected_pp = os.path.join(self._dir_path, 'models', model_dir, self._parameter_path) # First try user-specified path if self._parameter_path and os.path.isfile(self._parameter_path): pp = self._parameter_path # Then try directory we are running from elif os.path.isfile('parameters.json'): pp = 'parameters.json' # Then try the model directory, with the user-specified name elif os.path.isfile(selected_pp): pp = selected_pp # Finally try model folder elif os.path.isfile(model_pp): pp = model_pp else: raise ValueError(prt.text('no_parameter_file')) if self._is_master: prt.message('files', [basic_model_path, model_path, pp], wrapped=False) with open(pp, 'r') as f: self._parameter_json = json.load(f, object_pairs_hook=OrderedDict) self._modules = OrderedDict() self._bands = [] self._instruments 
= [] self._telescopes = [] # Load the call tree for the model. Work our way in reverse from the # observables, first constructing a tree for each observable and then # combining trees. root_kinds = ['output', 'objective'] self._trees = OrderedDict() self._simple_trees = OrderedDict() self.construct_trees(self._model, self._trees, self._simple_trees, kinds=root_kinds) if self._print_trees: self._printer.prt('Dependency trees:\n', wrapped=True) self._printer.tree(self._simple_trees) unsorted_call_stack = OrderedDict() self._max_depth_all = -1 for tag in self._model: model_tag = self._model[tag] roots = [] if model_tag['kind'] in root_kinds: max_depth = 0 roots = [model_tag['kind']] else: max_depth = -1 for tag2 in self._trees: if self.in_tree(tag, self._trees[tag2]): roots.extend(self._trees[tag2]['roots']) depth = self.get_max_depth(tag, self._trees[tag2], max_depth) if depth > max_depth: max_depth = depth if depth > self._max_depth_all: self._max_depth_all = depth roots = list(sorted(set(roots))) new_entry = deepcopy(model_tag) new_entry['roots'] = roots if 'children' in new_entry: del new_entry['children'] new_entry['depth'] = max_depth unsorted_call_stack[tag] = new_entry # print(unsorted_call_stack) # Currently just have one call stack for all products, can be wasteful # if only using some products. 
self._call_stack = OrderedDict() for depth in range(self._max_depth_all, -1, -1): for task in unsorted_call_stack: if unsorted_call_stack[task]['depth'] == depth: self._call_stack[task] = unsorted_call_stack[task] # with open(os.path.join( # self.get_products_path(), # self._model_name + '-stack.json'), 'w') as f: # json.dump(self._call_stack, f) for task in self._call_stack: cur_task = self._call_stack[task] mod_name = cur_task.get('class', task) if cur_task['kind'] == 'parameter' and task in self._parameter_json: cur_task.update(self._parameter_json[task]) self._modules[task] = self._load_task_module(task) if mod_name == 'photometry': self._telescopes = self._modules[task].telescopes() self._instruments = self._modules[task].instruments() self._bands = self._modules[task].bands() self._modules[task].set_attributes(cur_task) # Look forward to see which modules want dense arrays. for task in self._call_stack: for ftask in self._call_stack: if (task != ftask and self._call_stack[ftask]['depth'] < self._call_stack[task]['depth'] and self._modules[ftask]._wants_dense): self._modules[ftask]._provide_dense = True # Count free parameters. self.determine_free_parameters() def get_products_path(self): """Get path to products.""" return os.path.join(self._output_path, self.MODEL_PRODUCTS_DIR) def _load_task_module(self, task, call_stack=None): if not call_stack: call_stack = self._call_stack cur_task = call_stack[task] kinds = self._inflect.plural(cur_task['kind']) mod_name = cur_task.get('class', task).lower() mod_path = os.path.join('modules', kinds, mod_name + '.py') if not os.path.isfile(mod_path): mod_path = os.path.join(self._dir_path, 'modules', kinds, mod_name + '.py') mod_name = 'mosfit.modules.' 
+ kinds + mod_name try: mod = importlib.machinery.SourceFileLoader(mod_name, mod_path).load_module() except AttributeError: import imp mod = imp.load_source(mod_name, mod_path) class_name = [ x[0] for x in inspect.getmembers(mod, inspect.isclass) if issubclass(x[1], Module) and x[1].__module__ == mod.__name__ ][0] mod_class = getattr(mod, class_name) return mod_class(name=task, model=self, fitter=self._fitter, **cur_task) def load_data(self, data, event_name='', smooth_times=-1, extrapolate_time=0.0, limit_fitting_mjds=False, exclude_bands=[], exclude_instruments=[], exclude_systems=[], exclude_sources=[], exclude_kinds=[], time_unit=None, time_list=[], band_list=[], band_systems=[], band_instruments=[], band_bandsets=[], band_sampling_points=25, variance_for_each=[], user_fixed_parameters=[], user_released_parameters=[], pool=None): """Load the data for the specified event.""" if pool is not None: self._pool = pool self._printer._pool = pool prt = self._printer prt.message('loading_data', inline=True) # Fix user-specified parameters. 
fixed_parameters = [] released_parameters = [] for task in self._call_stack: for fi, param in enumerate(user_fixed_parameters): if (task == param or self._call_stack[task].get('class', '') == param): fixed_parameters.append(task) if fi < len(user_fixed_parameters) - 1 and is_number( user_fixed_parameters[fi + 1]): value = float(user_fixed_parameters[fi + 1]) if value not in self._call_stack: self._call_stack[task]['value'] = value if 'min_value' in self._call_stack[task]: del self._call_stack[task]['min_value'] if 'max_value' in self._call_stack[task]: del self._call_stack[task]['max_value'] self._modules[task].fix_value( self._call_stack[task]['value']) for fi, param in enumerate(user_released_parameters): if (task == param or self._call_stack[task].get('class', '') == param): released_parameters.append(task) self.determine_free_parameters(fixed_parameters, released_parameters) for ti, task in enumerate(self._call_stack): cur_task = self._call_stack[task] self._modules[task].set_event_name(event_name) new_per = np.round(100.0 * float(ti) / len(self._call_stack)) prt.message('loading_task', [task, new_per], inline=True) self._kinds_supported |= set(cur_task.get('supports', [])) if cur_task['kind'] == 'data': success = self._modules[task].set_data( data, req_key_values=OrderedDict( (('band', self._bands), ('instrument', self._instruments), ('telescope', self._telescopes))), subtract_minimum_keys=['times'], smooth_times=smooth_times, extrapolate_time=extrapolate_time, limit_fitting_mjds=limit_fitting_mjds, exclude_bands=exclude_bands, exclude_instruments=exclude_instruments, exclude_systems=exclude_systems, exclude_sources=exclude_sources, exclude_kinds=exclude_kinds, time_unit=time_unit, time_list=time_list, band_list=band_list, band_systems=band_systems, band_instruments=band_instruments, band_bandsets=band_bandsets) if not success: return False fixed_parameters.extend( self._modules[task].get_data_determined_parameters()) elif cur_task['kind'] == 'sed': 
self._modules[task].set_data(band_sampling_points) self._kinds_needed |= self._modules[task]._kinds_needed # Find unsupported wavebands and report to user. unsupported_kinds = self._kinds_needed - self._kinds_supported if unsupported_kinds: prt.message('using_unsupported_kinds' if 'none' in exclude_kinds else 'ignoring_unsupported_kinds', [', '.join(sorted(unsupported_kinds))], warning=True) # Determine free parameters again as setting data may have fixed some # more. self.determine_free_parameters(fixed_parameters, released_parameters) self.exchange_requests() prt.message('finding_bands', inline=True) # Run through once to set all inits. for root in ['output', 'objective']: outputs = self.run_stack( [0.0 for x in range(self._num_free_parameters)], root=root) # Create any data-dependent free parameters. self.adjust_fixed_parameters(variance_for_each, outputs) # Determine free parameters again as above may have changed them. self.determine_free_parameters(fixed_parameters, released_parameters) self.determine_number_of_measurements() self.exchange_requests() # Reset modules for task in self._call_stack: self._modules[task].reset_preprocessed(['photometry']) # Run through inits once more. 
for root in ['output', 'objective']: outputs = self.run_stack( [0.0 for x in range(self._num_free_parameters)], root=root) # Collect observed band info if self._pool.is_master() and 'photometry' in self._modules: prt.message('bands_used') bis = list( filter(lambda a: a != -1, sorted(set(outputs['all_band_indices'])))) ois = [] for bi in bis: ois.append( any([ y for x, y in zip(outputs['all_band_indices'], outputs['observed']) if x == bi ])) band_len = max([ len(self._modules['photometry']._unique_bands[bi]['origin']) for bi in bis ]) filts = self._modules['photometry'] ubs = filts._unique_bands filterarr = [ (ubs[bis[i]]['systems'], ubs[bis[i]]['bandsets'], filts._average_wavelengths[bis[i]], filts._band_offsets[bis[i]], filts._band_kinds[bis[i]], filts._band_names[bis[i]], ois[i], bis[i]) for i in range(len(bis)) ] filterrows = [ (' ' + (' ' if s[-2] else '*') + ubs[s[-1]]['origin'].ljust(band_len) + ' [' + ', '.join( list( filter(None, ('Bandset: ' + s[1] if s[1] else '', 'System: ' + s[0] if s[0] else '', 'AB offset: ' + pretty_num(s[3]) if (s[4] == 'magnitude' and s[0] != 'AB') else '')))) + ']').replace(' []', '') for s in list(sorted(filterarr)) ] if not all(ois): filterrows.append(prt.text('not_observed')) prt.prt('\n'.join(filterrows)) single_freq_inst = list( sorted( set( np.array(outputs['instruments'])[np.array( outputs['all_band_indices']) == -1]))) if len(single_freq_inst): prt.message('single_freq') for inst in single_freq_inst: prt.prt(' {}'.format(inst)) if ('unmatched_bands' in outputs and 'unmatched_instruments' in outputs): prt.message('unmatched_obs', warning=True) prt.prt(', '.join([ '{} [{}]'.format(x[0], x[1]) if x[0] and x[1] else x[0] if not x[1] else x[1] for x in list( set( zip(outputs['unmatched_bands'], outputs['unmatched_instruments']))) ]), warning=True, prefix=False, wrapped=True) return True def adjust_fixed_parameters(self, variance_for_each=[], output={}): """Create free parameters that depend on loaded data.""" 
def adjust_fixed_parameters(self, variance_for_each=None, output=None):
    """Create free parameters that depend on loaded data.

    Rebuilds ``self._call_stack`` in two passes:

    1. When ``variance_for_each`` contains ``'band'``, the ``variance``
       task is expanded into one ``variance-band-<band>`` task for every
       band whose average wavelength is far enough (fractionally) from
       the previously anchored band. Every other task is deep-copied
       unchanged.
    2. Walking the new stack in reverse, a parameter with an input whose
       value is already present in ``output`` is marked as fixed (unless
       that input's module is fixed but not by the user), and the task is
       added to its module's ``_derived_keys``.

    Parameters
    ----------
    variance_for_each : optional
        Passed through ``listify``. ``'band'`` requests per-band
        variances; a number directly after ``'band'`` overrides
        ``self.MIN_WAVE_FRAC_DIFF``.
    output : dict, optional
        Outputs of a previous stack run. ``'all_band_indices'`` and the
        keys named by task inputs are read from it.

    Raises
    ------
    RuntimeError
        If per-band variances are requested but the call stack holds no
        ``photometry`` task to take the bands from.
    """
    # `None` defaults replace the former shared mutable `[]` / `{}`.
    if variance_for_each is None:
        variance_for_each = []
    if output is None:
        output = {}

    all_band_indices = output.get('all_band_indices', [])
    unique_band_indices = sorted(set(all_band_indices))
    # Any negative band index keeps the general `variance` task alive next
    # to the per-band ones (presumably observations without a band --
    # confirm against the loader).
    needs_general_variance = any(np.array(all_band_indices) < 0)

    new_call_stack = OrderedDict()
    for task in self._call_stack:
        cur_task = self._call_stack[task]
        # Only the `variance` task can be expanded, so the request list is
        # only normalised for that task.
        vfe = listify(variance_for_each) if task == 'variance' else []
        if 'band' not in vfe:
            new_call_stack[task] = deepcopy(cur_task)
            continue

        # An optional number right after 'band' sets the minimum
        # fractional wavelength separation between anchored bands.
        vfi = vfe.index('band') + 1
        if vfi < len(vfe) and is_number(vfe[vfi]):
            mwfd = float(vfe[vfi])
        else:
            mwfd = self.MIN_WAVE_FRAC_DIFF

        # Previously a missing photometry task surfaced as a NameError on
        # `band_pairs`; fail with an explicit message instead.
        ptask = 'photometry'
        if ptask not in self._call_stack:
            raise RuntimeError(
                'Per-band variances were requested but no `photometry` '
                'module is present in the call stack.')
        awaves = self._modules[ptask].average_wavelengths(
            unique_band_indices)
        abands = self._modules[ptask].bands(unique_band_indices)
        band_pairs = sorted(zip(awaves, abands))

        owav = 0.0
        variance_bands = []
        for awav, band in band_pairs:
            wave_frac_diff = abs(awav - owav) / (awav + owav)
            if wave_frac_diff < mwfd:
                continue
            new_task_name = '-'.join([task, 'band', band])
            if new_task_name in self._call_stack:
                continue
            new_task = deepcopy(cur_task)
            if 'latex' in new_task:
                new_task['latex'] += '_{\\rm ' + band + '}'
            new_call_stack[new_task_name] = new_task
            self._modules[new_task_name] = self._load_task_module(
                new_task_name, call_stack=new_call_stack)
            owav = awav
            variance_bands.append([awav, band])

        if needs_general_variance:
            new_call_stack[task] = deepcopy(cur_task)
        if self._pool.is_master():
            self._printer.message(
                'anchoring_variances',
                [', '.join([x[1] for x in variance_bands])],
                wrapped=True)
        self._modules[ptask].set_variance_bands(variance_bands)

    self._call_stack = new_call_stack

    # Fix parameters whose inputs are already determined by the data.
    for task in reversed(self._call_stack):
        cur_task = self._call_stack[task]
        for inp in cur_task.get('inputs', []):
            other = listify(inp)[0]
            if (cur_task['kind'] == 'parameter' and
                    output.get(other, None) is not None):
                if (not self._modules[other]._fixed or
                        self._modules[other]._fixed_by_user):
                    self._modules[task]._fixed = True
                self._modules[task]._derived_keys = list(
                    set(self._modules[task]._derived_keys + [task]))


def determine_number_of_measurements(self):
    """Estimate the number of measurements.

    Stores in ``self._num_measurements`` the summed length of
    ``_data['times']`` over every task of kind ``'data'``.
    """
    self._num_measurements = 0
    for task, cur_task in self._call_stack.items():
        if cur_task['kind'] == 'data':
            self._num_measurements += len(
                self._modules[task]._data['times'])


def determine_free_parameters(self, extra_fixed_parameters=None,
                              extra_released_parameters=None):
    """Generate `_free_parameters` and `_num_free_parameters`.

    A task is free when it is listed in ``extra_released_parameters``, or
    when it is a parameter with distinct ``min_value``/``max_value``, its
    module is not fixed and it is not listed in
    ``extra_fixed_parameters``. Parameters excluded only through
    ``extra_fixed_parameters`` are recorded in
    ``self._user_fixed_parameters``. ``self._num_variances`` counts the
    free tasks whose ``class`` is ``'variance'``.

    Parameters
    ----------
    extra_fixed_parameters : container of str, optional
        Task names to hold fixed.
    extra_released_parameters : container of str, optional
        Task names to free unconditionally.
    """
    # `None` defaults replace the former shared mutable `[]` defaults.
    if extra_fixed_parameters is None:
        extra_fixed_parameters = []
    if extra_released_parameters is None:
        extra_released_parameters = []

    self._free_parameters = []
    self._user_fixed_parameters = []
    self._num_variances = 0
    for task, cur_task in self._call_stack.items():
        is_parameter = cur_task['kind'] == 'parameter'
        user_fixed = task in extra_fixed_parameters
        is_free = task in extra_released_parameters or (
            not user_fixed and is_parameter and
            'min_value' in cur_task and 'max_value' in cur_task and
            cur_task['min_value'] != cur_task['max_value'] and
            not self._modules[task]._fixed)
        if is_free:
            self._free_parameters.append(task)
            if cur_task.get('class', '') == 'variance':
                self._num_variances += 1
        elif is_parameter and user_fixed:
            self._user_fixed_parameters.append(task)
    self._num_free_parameters = len(self._free_parameters)


def is_parameter_fixed_by_user(self, parameter):
    """Return whether a parameter is fixed by the user."""
    return parameter in self._user_fixed_parameters


def get_num_free_parameters(self):
    """Return number of free parameters."""
    return self._num_free_parameters
def exchange_requests(self):
    """Exchange requests between modules.

    Walks the call stack in reverse order. For every task that declares
    ``requests`` (a mapping of request name to the name of the providing
    module), each provider's ``send_request`` result is collected and the
    collected set is passed as keyword arguments to the requesting
    module's ``receive_requests``.

    Raises
    ------
    RuntimeError
        If a providing module is not among the loaded modules.
    """
    for task in reversed(self._call_stack):
        cur_task = self._call_stack[task]
        if 'requests' not in cur_task:
            continue
        requests = OrderedDict()
        for req, provider in cur_task['requests'].items():
            if provider not in self._modules:
                raise RuntimeError(
                    'Request cannot be satisfied because module '
                    '`{}` could not be found.'.format(provider))
            requests[req] = self._modules[provider].send_request(req)
        self._modules[task].receive_requests(**requests)


def frack(self, arg):
    """Perform fracking upon a single walker.

    Uses a randomly-selected bounded local minimization method
    (``L-BFGS-B``, ``TNC`` or ``SLSQP``) to minimize ``self.fprob``.

    Parameters
    ----------
    arg : sequence
        ``arg[0]`` is the walker position and ``arg[1]`` the seed.

    Returns
    -------
    The result object returned by ``minimize``.

    Notes
    -----
    The global NumPy random state is re-seeded with ``arg[1]``, so the
    method choice is reproducible for a given seed.
    """
    x = np.array(arg[0])
    step = 1.0
    seed = arg[1]
    np.random.seed(seed)
    my_choice = np.random.choice(range(3))
    my_method = ['L-BFGS-B', 'TNC', 'SLSQP'][my_choice]

    # `approx_grad` is a keyword of the legacy `fmin_l_bfgs_b` interface,
    # not a `minimize` option: it was ignored with an "Unknown solver
    # options" warning. Without a `jac`, the gradient is estimated
    # numerically anyway.
    opt_dict = {'disp': False}
    if my_method in ['TNC', 'SLSQP']:
        opt_dict['maxiter'] = 200
    elif my_method == 'L-BFGS-B':
        opt_dict['maxfun'] = 5000
        opt_dict['maxls'] = 50

    # Search window of +/- `step` around the walker, clipped to [0, 1].
    bounds = list(
        zip(np.clip(x - step, 0.0, 1.0), np.clip(x + step, 0.0, 1.0)))

    bh = minimize(self.fprob, x, method=my_method, bounds=bounds,
                  options=opt_dict)

    return bh
def construct_trees(self, d, trees, simple, kinds=None, name='',
                    roots=None, depth=0):
    """Construct call trees for each root.

    Every entry of ``d`` whose ``kind`` is in ``kinds`` or whose tag
    equals ``name`` is deep-copied into ``trees`` (annotated with
    ``depth`` and ``roots``) and given an empty slot in ``simple``. The
    method then recurses into each of the entry's non-conditional inputs
    and attaches the results under ``children``.

    Parameters
    ----------
    d : dict
        Task definitions keyed by tag.
    trees, simple : dict
        Output containers, filled in place.
    kinds : list, optional
        Root kinds to start trees from.
    name : str, optional
        Tag to match; used by the recursion to follow an input.
    roots : list, optional
        Root kinds accumulated so far along this branch.
    depth : int, optional
        Current recursion depth.

    Raises
    ------
    RuntimeError
        If the depth exceeds 100 (taken as a recursive input loop), or if
        an input names a module that is not in ``d``.
    """
    # `None` defaults replace the former shared mutable `[]` defaults.
    kinds = [] if kinds is None else kinds
    roots = [] if roots is None else roots

    leaf = kinds if len(kinds) else name
    if depth > 100:
        raise RuntimeError(
            'Error: Tree depth greater than 100, suggests a recursive '
            'input loop in `{}`.'.format(leaf))
    for tag in d:
        entry = deepcopy(d[tag])
        new_roots = list(roots)
        if not (entry['kind'] in kinds or tag == name):
            continue
        entry['depth'] = depth
        if entry['kind'] in kinds:
            new_roots.append(entry['kind'])
        entry['roots'] = sorted(set(new_roots))
        trees[tag] = entry
        simple[tag] = OrderedDict()
        for inps in listify(entry.get('inputs', [])):
            # An input given as a list ending in "conditional" is only
            # checked for existence. (A list is never a string, so the
            # former extra string-type test was redundant.)
            conditional = (isinstance(inps, list) and
                           inps[-1] == "conditional")
            inp = inps[0] if conditional else inps
            if inp not in d:
                suggests = get_close_matches(inp, d, n=1, cutoff=0.8)
                warn_str = ('Module `{}` for input to `{}` '
                            'not found!'.format(inp, leaf))
                if len(suggests):
                    warn_str += (' Did you perhaps mean `{}`?'.format(
                        suggests[0]))
                raise RuntimeError(warn_str)
            # Conditional inputs don't propagate down the tree.
            if conditional:
                continue
            children = OrderedDict()
            simple_children = OrderedDict()
            self.construct_trees(d, children, simple_children, name=inp,
                                 roots=new_roots, depth=depth + 1)
            trees[tag].setdefault('children', OrderedDict())
            trees[tag]['children'].update(children)
            simple[tag].update(simple_children)


def draw_from_icdf(self, draw):
    """Map each draw through the matching free parameter's ``prior_icdf``.

    ``draw[i]`` is passed to the module of ``self._free_parameters[i]``;
    the results are returned as a list in the same order.
    """
    return [
        self._modules[self._free_parameters[i]].prior_icdf(x)
        for i, x in enumerate(draw)
    ]


def draw_walker(self, test=True, walkers_pool=None, replace=False,
                weights=None):
    """Draw a walker randomly.

    Draw a walker randomly from the full range of all parameters, reject
    walkers that return invalid scores.

    Parameters
    ----------
    test : bool, optional
        When false the first draw is returned unscored (score ``None``).
    walkers_pool : list, optional
        Pre-existing walkers. Non-``None`` entries of the chosen walker
        overwrite the corresponding entries of the random draw.
    replace : bool, optional
        When false the first pool walker is used and then removed from
        ``walkers_pool``; when true a pool walker is picked at random
        using ``weights`` and the pool is left untouched.
    weights : list, optional
        Selection probabilities matching ``walkers_pool``.

    Returns
    -------
    tuple
        ``(p, score)``: the accepted position and its ln-likelihood.

    Notes
    -----
    Without replacement, ``walkers_pool`` and ``weights`` are modified in
    place: the used entry is deleted and the remaining weights are
    renormalised to sum to one.
    """
    p = None
    chosen_one = None
    draw_cnt = 0
    while p is None:
        draw_cnt += 1
        draw = np.random.uniform(
            low=0.0, high=1.0, size=self._num_free_parameters)
        draw = self.draw_from_icdf(draw)
        if walkers_pool:
            if not replace:
                chosen_one = 0
            else:
                chosen_one = np.random.choice(
                    range(len(walkers_pool)), p=weights)
            for e, elem in enumerate(walkers_pool[chosen_one]):
                if elem is not None:
                    draw[e] = elem
        if not test:
            p = draw
            score = None
            break
        score = self.ln_likelihood(draw)
        if draw_cnt >= self.DRAW_LIMIT and not self._draw_limit_reached:
            self._printer.message('draw_limit_reached', warning=True)
            self._draw_limit_reached = True
        # Accept a finite score (above the optional likelihood threshold),
        # or give up and accept anything once the draw limit is hit.
        if ((not isnan(score) and np.isfinite(score) and
             (not isinstance(self._fitter._draw_above_likelihood, float) or
              score > self._fitter._draw_above_likelihood)) or
                draw_cnt >= self.DRAW_LIMIT):
            p = draw

    if not replace and chosen_one is not None:
        del walkers_pool[chosen_one]
        if weights is not None:
            del weights[chosen_one]
            if weights and None not in weights:
                totw = np.sum(weights)
                # Slice assignment so the caller's list is renormalised;
                # the former plain rebinding discarded the result.
                weights[:] = [x / totw for x in weights]
    return (p, score)


def get_max_depth(self, tag, parent, max_depth):
    """Return the maximum depth a given task is found in a tree.

    Searches the ``children`` of ``parent`` recursively and returns the
    larger of ``max_depth`` and the deepest ``depth`` recorded for
    ``tag``.
    """
    children = parent.get('children', [])
    for child in children:
        node = children[child]
        if child == tag:
            new_max = node['depth']
        else:
            new_max = self.get_max_depth(tag, node, max_depth)
        if new_max > max_depth:
            max_depth = new_max
    return max_depth


def in_tree(self, tag, parent):
    """Return whether a given task appears anywhere below `parent`."""
    children = parent.get('children', [])
    for child in children:
        if child == tag or self.in_tree(tag, children[child]):
            return True
    return False


def pool(self):
    """Return processing pool."""
    return self._pool


def run(self, x, root='output'):
    """Run stack with the given root."""
    return self.run_stack(x, root=root)


def printer(self):
    """Return printer."""
    return self._printer
def likelihood(self, x):
    """Return score related to maximum likelihood."""
    return np.exp(self.ln_likelihood(x))


def ln_likelihood(self, x):
    """Return ln(likelihood).

    This is the ``'value'`` output of the stack run with the
    ``'objective'`` root.
    """
    outputs = self.run_stack(x, root='objective')
    return outputs['value']


def ln_likelihood_floored(self, x):
    """Return ln(likelihood), floored to a finite value.

    The result is never lower than ``LOCAL_LIKELIHOOD_FLOOR``.
    """
    outputs = self.run_stack(x, root='objective')
    return max(LOCAL_LIKELIHOOD_FLOOR, outputs['value'])


def free_parameter_names(self, x):
    """Return list of free parameter names.

    The argument ``x`` is accepted but not used.
    """
    return self._free_parameters


def prior(self, x):
    """Return score related to parameter priors."""
    return np.exp(self.ln_prior(x))


def ln_prior(self, x):
    """Return ln(prior).

    Sums ``lnprior_pdf(x[i])`` over the free parameters, in order.
    """
    prior = 0.0
    for pi, par in enumerate(self._free_parameters):
        prior = prior + self._modules[par].lnprior_pdf(x[pi])
    return prior


def boprob(self, **kwargs):
    """Score for `BayesianOptimization`.

    The keyword values are assembled into a position vector and scored
    with ``ln_likelihood + ln_prior``; a non-finite score is replaced by
    ``LOCAL_LIKELIHOOD_FLOOR``.

    Keys are ordered by their trailing number (``x2`` before ``x10``).
    Plain lexicographic ordering, as used previously, permutes the
    parameters as soon as there are more than ten of them.
    """
    def position(key):
        # Split e.g. 'x10' into ('x', 10); keys without digits sort first.
        stem = key.rstrip('0123456789')
        digits = key[len(stem):]
        return (stem, int(digits) if digits else -1)

    x = [kwargs[key] for key in sorted(kwargs, key=position)]

    li = self.ln_likelihood(x) + self.ln_prior(x)
    if not np.isfinite(li):
        return LOCAL_LIKELIHOOD_FLOOR
    return li


def fprob(self, x):
    """Return score for fracking.

    This is the negated ``ln_likelihood + ln_prior`` (a quantity to
    minimize); a non-finite value is replaced by
    ``-LOCAL_LIKELIHOOD_FLOOR``.
    """
    li = -(self.ln_likelihood(x) + self.ln_prior(x))
    if not np.isfinite(li):
        return -LOCAL_LIKELIHOOD_FLOOR
    return li


def plural(self, x):
    """Pluralize and cache model-related keys.

    Falls back to appending ``'s'`` when the inflection engine returns
    the word unchanged.
    """
    if x in self._inflections:
        return self._inflections[x]
    plural = self._inflect.plural(x)
    if plural == x:
        plural = x + 's'
    self._inflections[x] = plural
    return plural


def reset_unset_recommended_keys(self):
    """Null the list of unset recommended keys across all modules."""
    for module in self._modules.values():
        module.reset_unset_recommended_keys()


def get_unset_recommended_keys(self):
    """Collect the set of unset recommended keys across all modules."""
    unset_keys = set()
    for module in self._modules.values():
        unset_keys.update(module.get_unset_recommended_keys())
    return unset_keys
def run_stack(self, x, root='objective'):
    """Run module stack.

    Run a stack of modules as defined in the model definition file. Only
    run functions that match the specified root.

    Parameters
    ----------
    x : sequence
        One value per free parameter, consumed in call-stack order and
        passed to the parameter's module as ``fraction``.
    root : str, optional
        Only tasks listing this root in their ``roots`` are run.

    Returns
    -------
    OrderedDict
        The accumulated module outputs.

    Notes
    -----
    The first run for a given ``root`` also collects the ``_REFERENCES``
    of every executed module into ``self._references[root]``, without
    duplicates and in first-seen order.
    """
    inputs = OrderedDict()
    outputs = OrderedDict()
    pos = 0
    cur_depth = self._max_depth_all

    # If this is the first time running this stack, build the ref arrays.
    build_refs = root not in self._references
    if build_refs:
        self._references[root] = []

    for task in self._call_stack:
        cur_task = self._call_stack[task]
        if root not in cur_task['roots']:
            continue
        # On a change of depth the accumulated outputs become the inputs.
        # This is an alias, not a copy: later output updates are visible
        # through `inputs` as well.
        if cur_task['depth'] != cur_depth:
            inputs = outputs
        inputs.update(OrderedDict([('root', root)]))
        cur_depth = cur_task['depth']
        if task in self._free_parameters:
            inputs.update(OrderedDict([('fraction', x[pos])]))
            inputs.setdefault('fractions', []).append(x[pos])
            pos = pos + 1
        try:
            new_outs = self._modules[task].process(**inputs)
            if not isinstance(new_outs, OrderedDict):
                new_outs = OrderedDict(sorted(new_outs.items()))
        except Exception:
            # Name the failing module, then let the error propagate.
            self._printer.prt(
                "Failed to execute module `{}`\'s process().".format(task),
                wrapped=True)
            raise
        outputs.update(new_outs)

        # Append module references
        if build_refs:
            self._references[root].extend(self._modules[task]._REFERENCES)

        # A module can ask for keys to be dropped from the outputs.
        if '_delete_keys' in outputs:
            for key in list(outputs['_delete_keys'].keys()):
                del outputs[key]
            del outputs['_delete_keys']

    if build_refs:
        # Make sure references are unique. Keying an ordered mapping keeps
        # first-seen order; the former `set` made the order vary from run
        # to run.
        unique_refs = OrderedDict()
        for ref in self._references[root]:
            unique_refs.setdefault(tuple(sorted(ref.items())), None)
        self._references[root] = [dict(items) for items in unique_refs]

    return outputs
label='data: trail')
    # NOTE(review): the lines down to `fig.savefig(...)` are the tail of a
    # definition that starts before this chunk; tokens are unchanged and
    # the indentation is assumed -- confirm against the full file.
    axes[2].set_xlim(0, 17)
    axes[2].set_ylim(-1, 1)
    axes[2].set_xlabel(r'$\Delta \phi_1$ [deg]')
    axes[2].set_ylabel('$\Delta \phi_2$ [deg]')
    axes[2].legend(loc='best', fontsize=15)

    fig.tight_layout()
    fig.savefig(
        path.join(
            plot_path, 'BarModels_RL{:d}_Mb{:.0e}_Om{:.1f}.png'.format(
                release_every, m_b.value, omega.value)))


def worker(task):
    """Run `width_track` for the single value carried by `task`.

    `task` is a one-element tuple. Its value is multiplied by
    ``u.km / u.s / u.kpc`` and passed as the first argument, together with
    ``m_b=1e10 * u.Msun``, ``release_every=1`` and ``n_steps=6000``.
    """
    # Trailing comma: unpack exactly one element.
    omega, = task
    width_track(omega * u.km / u.s / u.kpc, m_b=1e10 * u.Msun, release_every=1, n_steps=6000)


# One single-element task per value from 28.0 to 60.0 in steps of 0.5; the
# small offset on the stop value keeps 60.0 inside the half-open range.
tasks = [(om, ) for om in np.arange(28.0, 60 + 1e-3, 0.5)]

with SerialPool() as pool:
#with MultiPool() as pool:
    print(pool.size)
    # The results are discarded; the loop only consumes what `map` yields.
    for r in pool.map(worker, tasks):
        pass