def sample(self, n_samples):
    if self.pool is None or _GPU_ENABLED:
        pool = SerialPool()
    else:
        if isinstance(self.pool, int):
            pool = MultiPool(self.pool)
        elif isinstance(self.pool, (SerialPool, MultiPool)):
            pool = self.pool
        else:
            raise TypeError(
                "Does not understand the given multiprocessing pool.")

    drawn_samples = list(
        pool.map(self.draw_one_joint_posterior_sample_map,
                 range(n_samples)))
    pool.close()

    drawn_zs = [drawn_samples[i][0] for i in range(n_samples)]
    drawn_inference_posteriors = [
        drawn_samples[i][1] for i in range(n_samples)
    ]

    drawn_joint_posterior_samples = pd.DataFrame(
        drawn_inference_posteriors)
    drawn_joint_posterior_samples["redshift"] = drawn_zs

    return drawn_joint_posterior_samples
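The None/int/pool dispatch at the top of `sample` recurs throughout this section, so a minimal standalone sketch of it may help. This assumes schwimmbad is installed; `resolve_pool` is a hypothetical helper name, not part of the original class.

from schwimmbad import MultiPool, SerialPool


def resolve_pool(pool):
    """Map a user-supplied `pool` argument onto a schwimmbad pool.

    Accepts None (serial fallback), an integer worker count, or an
    already-constructed pool instance, mirroring the branches above.
    """
    if pool is None:
        return SerialPool()
    if isinstance(pool, int):
        return MultiPool(pool)
    if isinstance(pool, (SerialPool, MultiPool)):
        return pool
    raise TypeError("Does not understand the given multiprocessing pool.")


if __name__ == '__main__':
    pool = resolve_pool(None)
    print(list(pool.map(abs, range(-3, 3))))
    pool.close()

Note that `sample` above closes whatever pool it ended up using, including a caller-supplied one, so a pool instance cannot be reused across calls.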
def pool(request):
    multimode = 'None'
    # multimode = 'Serial'
    # multimode = 'Multi'
    # multimode = 'MPI'

    # setup code
    pool = None
    if multimode == 'Serial':
        from schwimmbad import SerialPool
        pool = SerialPool()
    elif multimode == 'Multi':
        from schwimmbad import MultiPool
        pool = MultiPool()
    elif multimode == 'MPI':
        from schwimmbad import MPIPool
        pool = MPIPool()
        if not pool.is_master():
            pool.wait()
            import sys
            sys.exit(0)

    # inject class variables
    request.cls.pool = pool

    yield

    # tear down
    if multimode == 'Multi' or multimode == 'MPI':
        pool.close()
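Because this fixture injects the pool through `request.cls`, it only takes effect when applied to a test class. A minimal hookup sketch, assuming the fixture above is registered as a pytest fixture (e.g. in a `conftest.py`); `TestWithPool` and `square` are hypothetical names.

import pytest


def square(x):
    return x * x


@pytest.mark.usefixtures('pool')
class TestWithPool(object):
    def test_map(self):
        # self.pool was injected by the fixture via request.cls; it is
        # None unless multimode was changed from its default.
        if self.pool is not None:
            assert list(self.pool.map(square, [1, 2, 3])) == [1, 4, 9]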
def test_init(case):
    prior, _ = get_prior(case)

    # Try various initializations
    TheJoker(prior)

    with pytest.raises(TypeError):
        TheJoker('jsdfkj')

    # Pools:
    with SerialPool() as pool:
        TheJoker(prior, pool=pool)

    # fail when pool is invalid:
    with pytest.raises(TypeError):
        TheJoker(prior, pool='sdfks')

    # Random state:
    rnd = np.random.default_rng(42)
    TheJoker(prior, random_state=rnd)

    # fail when random state is invalid:
    with pytest.raises(TypeError):
        TheJoker(prior, random_state='sdfks')

    with pytest.warns(DeprecationWarning):
        rnd = np.random.RandomState(42)
        TheJoker(prior, random_state=rnd)

    # tempfile location:
    joker = TheJoker(prior, tempfile_path='/tmp/joker')
    assert os.path.exists(joker.tempfile_path)
def compute_mean_selection_function(selection_function, N_avg, pool=None):
    if pool is None:
        pool = SerialPool()
    elif isinstance(pool, int):
        pool = MultiPool(pool)
    elif isinstance(pool, (SerialPool, MultiPool)):
        pass
    else:
        raise TypeError("Does not understand the given multiprocessing pool.")

    out = list(
        pool.starmap(selection_function.evaluate,
                     [() for _ in range(N_avg)]))
    avg = np.average(out)
    pool.close()

    return avg
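A usage sketch for the function above, under the assumption that the selection function exposes a no-argument `evaluate()` returning one stochastic estimate per call; `MonteCarloSelection` is a hypothetical stand-in. Passing `pool=4` exercises the `MultiPool` branch, whose `starmap` comes from `multiprocessing.pool.Pool`.

import numpy as np


class MonteCarloSelection(object):
    """Hypothetical selection function with a stochastic evaluate()."""

    def evaluate(self):
        # Each call returns one noisy estimate of the detection fraction.
        return np.random.uniform(0.4, 0.6)


if __name__ == '__main__':
    sel = MonteCarloSelection()
    # Averages N_avg independent draws of sel.evaluate() across 4 workers.
    print(compute_mean_selection_function(sel, N_avg=1000, pool=4))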
def __init__(self,
             cuda=False,
             exit_on_prompt=False,
             language='en',
             limiting_magnitude=None,
             prefer_fluxes=False,
             offline=False,
             prefer_cache=False,
             open_in_browser=False,
             pool=None,
             quiet=False,
             test=False,
             wrap_length=100,
             **kwargs):
    """Initialize `Fitter` class."""
    self._pool = SerialPool() if pool is None else pool
    self._printer = Printer(pool=self._pool,
                            wrap_length=wrap_length,
                            quiet=quiet,
                            fitter=self,
                            language=language,
                            exit_on_prompt=exit_on_prompt)
    self._fetcher = Fetcher(test=test,
                            open_in_browser=open_in_browser,
                            printer=self._printer)

    self._cuda = cuda
    self._limiting_magnitude = limiting_magnitude
    self._prefer_fluxes = prefer_fluxes
    self._offline = offline
    self._prefer_cache = prefer_cache
    self._open_in_browser = open_in_browser
    self._quiet = quiet
    self._test = test
    self._wrap_length = wrap_length

    if self._cuda:
        try:
            import pycuda.autoinit  # noqa: F401
            import skcuda.linalg as linalg
            linalg.init()
        except ImportError:
            pass
def fit_events(self,
               events=[],
               models=[],
               max_time='',
               time_list=[],
               time_unit=None,
               band_list=[],
               band_systems=[],
               band_instruments=[],
               band_bandsets=[],
               band_sampling_points=17,
               iterations=10000,
               num_walkers=None,
               num_temps=1,
               parameter_paths=['parameters.json'],
               fracking=True,
               frack_step=50,
               burn=None,
               post_burn=None,
               gibbs=False,
               smooth_times=-1,
               extrapolate_time=0.0,
               limit_fitting_mjds=False,
               exclude_bands=[],
               exclude_instruments=[],
               exclude_systems=[],
               exclude_sources=[],
               exclude_kinds=[],
               output_path='',
               suffix='',
               upload=False,
               write=False,
               upload_token='',
               check_upload_quality=False,
               variance_for_each=[],
               user_fixed_parameters=[],
               user_released_parameters=[],
               convergence_type=None,
               convergence_criteria=None,
               save_full_chain=False,
               draw_above_likelihood=False,
               maximum_walltime=False,
               start_time=False,
               print_trees=False,
               maximum_memory=np.inf,
               speak=False,
               return_fits=True,
               extra_outputs=None,
               walker_paths=[],
               catalogs=[],
               exit_on_prompt=False,
               download_recommended_data=False,
               local_data_only=False,
               method=None,
               seed=None,
               **kwargs):
    """Fit a list of events with a list of models."""
    global model

    if start_time is False:
        start_time = time.time()

    self._seed = seed
    if seed is not None:
        np.random.seed(seed)

    self._start_time = start_time
    self._maximum_walltime = maximum_walltime
    self._maximum_memory = maximum_memory
    self._debug = False
    self._speak = speak
    self._download_recommended_data = download_recommended_data
    self._local_data_only = local_data_only
    self._draw_above_likelihood = draw_above_likelihood

    prt = self._printer

    event_list = listify(events)
    model_list = listify(models)

    if len(model_list) and not len(event_list):
        event_list = ['']

    # Exclude catalogs not included in catalog list.
    self._fetcher.add_excluded_catalogs(catalogs)

    if not len(event_list) and not len(model_list):
        prt.message('no_events_models', warning=True)

    # If the input is not a JSON file, assume it is either a list of
    # transients or that it is the data from a single transient in tabular
    # form. Try to guess the format first, and if that fails ask the user.
    self._converter = Converter(prt, require_source=upload)
    event_list = self._converter.generate_event_list(event_list)

    event_list = [x.replace('‑', '-') for x in event_list]

    entries = [[] for x in range(len(event_list))]
    ps = [[] for x in range(len(event_list))]
    lnprobs = [[] for x in range(len(event_list))]

    # Load walker data if provided a list of walker paths.
    walker_data = []

    if len(walker_paths):
        try:
            pool = MPIPool()
        except (ImportError, ValueError):
            pool = SerialPool()
        if pool.is_master():
            prt.message('walker_file')
            wfi = 0
            for walker_path in walker_paths:
                if os.path.exists(walker_path):
                    prt.prt(' {}'.format(walker_path))
                    with codecs.open(walker_path, 'r',
                                     encoding='utf-8') as f:
                        all_walker_data = json.load(
                            f, object_pairs_hook=OrderedDict)

                    # Support both the format where all data stored in a
                    # single-item dictionary (the OAC format) and the older
                    # MOSFiT format where the data was stored in the
                    # top-level dictionary.
                    if ENTRY.NAME not in all_walker_data:
                        all_walker_data = all_walker_data[list(
                            all_walker_data.keys())[0]]

                    models = all_walker_data.get(ENTRY.MODELS, [])
                    choice = None
                    if len(models) > 1:
                        model_opts = [
                            '{}-{}-{}'.format(x['code'], x['name'],
                                              x['date']) for x in models
                        ]
                        choice = prt.prompt('select_model_walkers',
                                            kind='select',
                                            message=True,
                                            options=model_opts)
                        choice = model_opts.index(choice)
                    elif len(models) == 1:
                        choice = 0

                    if choice is not None:
                        walker_data.extend([[
                            wfi, x[REALIZATION.PARAMETERS],
                            x.get(REALIZATION.WEIGHT)
                        ] for x in models[choice][MODEL.REALIZATIONS]])

                    for i in range(len(walker_data)):
                        if walker_data[i][2] is not None:
                            walker_data[i][2] = float(walker_data[i][2])

                    if not len(walker_data):
                        prt.message('no_walker_data')
                else:
                    prt.message('no_walker_data')
                    if self._offline:
                        prt.message('omit_offline')
                    raise RuntimeError

                wfi = wfi + 1

            for rank in range(1, pool.size + 1):
                pool.comm.send(walker_data, dest=rank, tag=3)
        else:
            walker_data = pool.comm.recv(source=0, tag=3)
            pool.wait()

        if pool.is_master():
            pool.close()

    self._event_name = 'Batch'
    self._event_path = ''
    self._event_data = {}

    try:
        pool = MPIPool()
    except (ImportError, ValueError):
        pool = SerialPool()
    if pool.is_master():
        fetched_events = self._fetcher.fetch(
            event_list,
            offline=self._offline,
            prefer_cache=self._prefer_cache)
        for rank in range(1, pool.size + 1):
            pool.comm.send(fetched_events, dest=rank, tag=0)
        pool.close()
    else:
        fetched_events = pool.comm.recv(source=0, tag=0)
        pool.wait()

    for ei, event in enumerate(fetched_events):
        if event is not None:
            self._event_name = event.get('name', 'Batch')
            self._event_path = event.get('path', '')
            if not self._event_path:
                continue

            self._event_data = self._fetcher.load_data(event)
            if not self._event_data:
                continue

        if model_list:
            lmodel_list = model_list
        else:
            lmodel_list = ['']

        entries[ei] = [None for y in range(len(lmodel_list))]
        ps[ei] = [None for y in range(len(lmodel_list))]
        lnprobs[ei] = [None for y in range(len(lmodel_list))]

        if (event is not None and
                (not self._event_data or ENTRY.PHOTOMETRY not in
                 self._event_data[list(self._event_data.keys())[0]])):
            prt.message('no_photometry', [self._event_name])
            continue

        for mi, mod_name in enumerate(lmodel_list):
            for parameter_path in parameter_paths:
                try:
                    pool = MPIPool()
                except (ImportError, ValueError):
                    pool = SerialPool()
                self._model = Model(model=mod_name,
                                    data=self._event_data,
                                    parameter_path=parameter_path,
                                    output_path=output_path,
                                    wrap_length=self._wrap_length,
                                    test=self._test,
                                    printer=prt,
                                    fitter=self,
                                    pool=pool,
                                    print_trees=print_trees)

                if not self._model._model_name:
                    prt.message('no_models_avail', [self._event_name],
                                warning=True)
                    continue

                if not event:
                    prt.message('gen_dummy')
                    self._event_name = mod_name
                    gen_args = {
                        'name': mod_name,
                        'max_time': max_time,
                        'time_list': time_list,
                        'band_list': band_list,
                        'band_systems': band_systems,
                        'band_instruments': band_instruments,
                        'band_bandsets': band_bandsets
                    }
                    self._event_data = self.generate_dummy_data(**gen_args)

                success = False
                alt_name = None
                while not success:
                    self._model.reset_unset_recommended_keys()
                    success = self._model.load_data(
                        self._event_data,
                        event_name=self._event_name,
                        smooth_times=smooth_times,
                        extrapolate_time=extrapolate_time,
                        limit_fitting_mjds=limit_fitting_mjds,
                        exclude_bands=exclude_bands,
                        exclude_instruments=exclude_instruments,
                        exclude_systems=exclude_systems,
                        exclude_sources=exclude_sources,
                        exclude_kinds=exclude_kinds,
                        time_list=time_list,
                        time_unit=time_unit,
                        band_list=band_list,
                        band_systems=band_systems,
                        band_instruments=band_instruments,
                        band_bandsets=band_bandsets,
                        band_sampling_points=band_sampling_points,
                        variance_for_each=variance_for_each,
                        user_fixed_parameters=user_fixed_parameters,
                        user_released_parameters=user_released_parameters,
                        pool=pool)
                    if not success:
                        break

                    if self._local_data_only:
                        break

                    # If our data is missing recommended keys, offer the
                    # user option to pull the missing data from online and
                    # merge with existing data.
                    urk = self._model.get_unset_recommended_keys()
                    ptxt = prt.text('acquire_recommended',
                                    [', '.join(list(urk))])
                    while event and len(urk) and (
                            alt_name or self._download_recommended_data
                            or prt.prompt(ptxt, [', '.join(urk)],
                                          kind='bool')):
                        try:
                            pool = MPIPool()
                        except (ImportError, ValueError):
                            pool = SerialPool()
                        if pool.is_master():
                            en = (alt_name
                                  if alt_name else self._event_name)
                            extra_event = self._fetcher.fetch(
                                en,
                                offline=self._offline,
                                prefer_cache=self._prefer_cache)[0]
                            extra_data = self._fetcher.load_data(
                                extra_event)
                            for rank in range(1, pool.size + 1):
                                pool.comm.send(extra_data, dest=rank,
                                               tag=4)
                            pool.close()
                        else:
                            extra_data = pool.comm.recv(source=0, tag=4)
                            pool.wait()

                        if extra_data is not None:
                            extra_data = extra_data[list(
                                extra_data.keys())[0]]
                            for key in urk:
                                new_val = extra_data.get(key)
                                self._event_data[list(
                                    self._event_data.keys())
                                    [0]][key] = new_val
                                if new_val is not None and len(new_val):
                                    prt.message('extra_value', [
                                        key,
                                        str(new_val[0].get(QUANTITY.VALUE))
                                    ])
                            success = False
                            prt.message('reloading_merged')
                            break
                        else:
                            text = prt.text('extra_not_found',
                                            [self._event_name])
                            alt_name = prt.prompt(text, kind='string')
                            if not alt_name:
                                break

                if success:
                    self._walker_data = walker_data
                    entry, p, lnprob = self.fit_data(
                        event_name=self._event_name,
                        method=method,
                        iterations=iterations,
                        num_walkers=num_walkers,
                        num_temps=num_temps,
                        burn=burn,
                        post_burn=post_burn,
                        fracking=fracking,
                        frack_step=frack_step,
                        gibbs=gibbs,
                        pool=pool,
                        output_path=output_path,
                        suffix=suffix,
                        write=write,
                        upload=upload,
                        upload_token=upload_token,
                        check_upload_quality=check_upload_quality,
                        convergence_type=convergence_type,
                        convergence_criteria=convergence_criteria,
                        save_full_chain=save_full_chain,
                        extra_outputs=extra_outputs)
                    if return_fits:
                        entries[ei][mi] = deepcopy(entry)
                        ps[ei][mi] = deepcopy(p)
                        lnprobs[ei][mi] = deepcopy(lnprob)

                if pool.is_master():
                    pool.close()

                # Remove global model variable and garbage collect.
                try:
                    model
                except NameError:
                    pass
                else:
                    del model

                del self._model
                gc.collect()

    return (entries, ps, lnprobs)
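`fit_events` repeatedly uses the same idiom: try to construct an `MPIPool`, fall back to `SerialPool` when MPI is unavailable, then branch on `is_master()` while worker ranks `wait()` for tasks. A minimal self-contained sketch of that idiom, assuming mpi4py and schwimmbad are available; `get_pool` and `work` are hypothetical names.

import sys

from schwimmbad import MPIPool, SerialPool


def get_pool():
    """Fall back to SerialPool when not launched under MPI."""
    try:
        return MPIPool()
    except (ImportError, ValueError):
        return SerialPool()


def work(x):
    return x ** 2


if __name__ == '__main__':
    pool = get_pool()
    # Under MPI, non-master ranks block here serving map() requests
    # until the master closes the pool.
    if not pool.is_master():
        pool.wait()
        sys.exit(0)
    results = list(pool.map(work, range(10)))
    pool.close()
    print(results)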
def __init__(self,
             parameter_path='parameters.json',
             model='',
             data={},
             wrap_length=100,
             output_path='',
             pool=None,
             test=False,
             printer=None,
             fitter=None,
             print_trees=False):
    """Initialize `Model` object."""
    from mosfit.fitter import Fitter

    self._model_name = model
    self._parameter_path = parameter_path
    self._output_path = output_path
    self._pool = SerialPool() if pool is None else pool
    self._is_master = pool.is_master() if pool else False
    self._wrap_length = wrap_length
    self._print_trees = print_trees
    self._inflect = inflect.engine()
    self._test = test
    self._inflections = {}
    self._references = OrderedDict()
    self._free_parameters = []
    self._user_fixed_parameters = []
    self._user_released_parameters = []
    self._kinds_needed = set()
    self._kinds_supported = set()
    self._draw_limit_reached = False

    self._fitter = Fitter() if not fitter else fitter
    self._printer = self._fitter._printer if not printer else printer

    prt = self._printer

    self._dir_path = os.path.dirname(os.path.realpath(__file__))

    # Load suggested model associations for transient types.
    if os.path.isfile(os.path.join('models', 'types.json')):
        types_path = os.path.join('models', 'types.json')
    else:
        types_path = os.path.join(self._dir_path, 'models', 'types.json')
    with open(types_path, 'r') as f:
        model_types = json.load(f, object_pairs_hook=OrderedDict)

    # Create list of all available models.
    all_models = set()
    if os.path.isdir('models'):
        all_models |= set(next(os.walk('models'))[1])
    models_path = os.path.join(self._dir_path, 'models')
    if os.path.isdir(models_path):
        all_models |= set(next(os.walk(models_path))[1])
    all_models = list(sorted(list(all_models)))

    if not self._model_name:
        claimed_type = None
        try:
            claimed_type = list(
                data.values())[0]['claimedtype'][0][QUANTITY.VALUE]
        except Exception:
            prt.message('no_model_type', warning=True)

        all_models_txt = prt.text('all_models')
        suggested_models_txt = prt.text('suggested_models', [claimed_type])
        another_model_txt = prt.text('another_model')

        type_options = model_types.get(claimed_type,
                                       []) if claimed_type else []
        if not type_options:
            type_options = all_models
            model_prompt_txt = all_models_txt
        else:
            type_options.append(another_model_txt)
            model_prompt_txt = suggested_models_txt
        if not type_options:
            prt.message('no_model_for_type', warning=True)
        else:
            while not self._model_name:
                if self._test:
                    self._model_name = type_options[0]
                else:
                    sel = self._printer.prompt(
                        model_prompt_txt,
                        kind='option',
                        options=type_options,
                        message=False,
                        default='n',
                        none_string=prt.text('none_above_models'))
                    if sel is not None:
                        self._model_name = type_options[int(sel) - 1]
                if not self._model_name:
                    break
                if self._model_name == another_model_txt:
                    type_options = all_models
                    model_prompt_txt = all_models_txt
                    self._model_name = None

    if not self._model_name:
        return

    # Load the basic model file.
    if os.path.isfile(os.path.join('models', 'model.json')):
        basic_model_path = os.path.join('models', 'model.json')
    else:
        basic_model_path = os.path.join(self._dir_path, 'models',
                                        'model.json')

    with open(basic_model_path, 'r') as f:
        self._model = json.load(f, object_pairs_hook=OrderedDict)

    # Load the model file.
    model = self._model_name
    model_dir = self._model_name

    if '.json' in self._model_name:
        model_dir = self._model_name.split('.json')[0]
    else:
        model = self._model_name + '.json'

    if os.path.isfile(model):
        model_path = model
    else:
        # Look in local hierarchy first
        if os.path.isfile(os.path.join('models', model_dir, model)):
            model_path = os.path.join('models', model_dir, model)
        else:
            model_path = os.path.join(self._dir_path, 'models', model_dir,
                                      model)

    with open(model_path, 'r') as f:
        self._model.update(json.load(f, object_pairs_hook=OrderedDict))

    # Find @ tags, store them, and prune them from `_model`.
    for tag in list(self._model.keys()):
        if tag.startswith('@'):
            if tag == '@references':
                self._references.setdefault('base',
                                            []).extend(self._model[tag])
            del self._model[tag]

    # with open(os.path.join(
    #         self.get_products_path(),
    #         self._model_name + '.json'), 'w') as f:
    #     json.dump(self._model, f)

    # Load model parameter file.
    model_pp = os.path.join(self._dir_path, 'models', model_dir,
                            'parameters.json')

    pp = ''

    local_pp = (self._parameter_path if '/' in self._parameter_path else
                os.path.join('models', model_dir, self._parameter_path))

    if os.path.isfile(local_pp):
        selected_pp = local_pp
    else:
        selected_pp = os.path.join(self._dir_path, 'models', model_dir,
                                   self._parameter_path)

    # First try user-specified path
    if self._parameter_path and os.path.isfile(self._parameter_path):
        pp = self._parameter_path
    # Then try directory we are running from
    elif os.path.isfile('parameters.json'):
        pp = 'parameters.json'
    # Then try the model directory, with the user-specified name
    elif os.path.isfile(selected_pp):
        pp = selected_pp
    # Finally try model folder
    elif os.path.isfile(model_pp):
        pp = model_pp
    else:
        raise ValueError(prt.text('no_parameter_file'))

    if self._is_master:
        prt.message('files', [basic_model_path, model_path, pp],
                    wrapped=False)

    with open(pp, 'r') as f:
        self._parameter_json = json.load(f, object_pairs_hook=OrderedDict)
    self._modules = OrderedDict()
    self._bands = []
    self._instruments = []
    self._telescopes = []

    # Load the call tree for the model. Work our way in reverse from the
    # observables, first constructing a tree for each observable and then
    # combining trees.
    root_kinds = ['output', 'objective']

    self._trees = OrderedDict()
    self._simple_trees = OrderedDict()
    self.construct_trees(self._model, self._trees, self._simple_trees,
                         kinds=root_kinds)

    if self._print_trees:
        self._printer.prt('Dependency trees:\n', wrapped=True)
        self._printer.tree(self._simple_trees)

    unsorted_call_stack = OrderedDict()
    self._max_depth_all = -1
    for tag in self._model:
        model_tag = self._model[tag]
        roots = []
        if model_tag['kind'] in root_kinds:
            max_depth = 0
            roots = [model_tag['kind']]
        else:
            max_depth = -1
            for tag2 in self._trees:
                if self.in_tree(tag, self._trees[tag2]):
                    roots.extend(self._trees[tag2]['roots'])
                depth = self.get_max_depth(tag, self._trees[tag2],
                                           max_depth)
                if depth > max_depth:
                    max_depth = depth
                if depth > self._max_depth_all:
                    self._max_depth_all = depth
            roots = list(sorted(set(roots)))
        new_entry = deepcopy(model_tag)
        new_entry['roots'] = roots
        if 'children' in new_entry:
            del new_entry['children']
        new_entry['depth'] = max_depth
        unsorted_call_stack[tag] = new_entry
    # print(unsorted_call_stack)

    # Currently just have one call stack for all products, can be wasteful
    # if only using some products.
    self._call_stack = OrderedDict()
    for depth in range(self._max_depth_all, -1, -1):
        for task in unsorted_call_stack:
            if unsorted_call_stack[task]['depth'] == depth:
                self._call_stack[task] = unsorted_call_stack[task]

    # with open(os.path.join(
    #         self.get_products_path(),
    #         self._model_name + '-stack.json'), 'w') as f:
    #     json.dump(self._call_stack, f)

    for task in self._call_stack:
        cur_task = self._call_stack[task]
        mod_name = cur_task.get('class', task)
        if cur_task['kind'] == 'parameter' and task in self._parameter_json:
            cur_task.update(self._parameter_json[task])
        self._modules[task] = self._load_task_module(task)
        if mod_name == 'photometry':
            self._telescopes = self._modules[task].telescopes()
            self._instruments = self._modules[task].instruments()
            self._bands = self._modules[task].bands()
        self._modules[task].set_attributes(cur_task)

    # Look forward to see which modules want dense arrays.
    for task in self._call_stack:
        for ftask in self._call_stack:
            if (task != ftask and self._call_stack[ftask]['depth'] <
                    self._call_stack[task]['depth'] and
                    self._modules[ftask]._wants_dense):
                self._modules[ftask]._provide_dense = True

    # Count free parameters.
    self.determine_free_parameters()
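The depth loop near the end of this method performs a simple bucketed topological sort: tasks are emitted from the deepest dependency level down to the roots, so each module's inputs are placed earlier in the call stack. A toy illustration of the same ordering, with hypothetical task names and depths:

from collections import OrderedDict

# Hypothetical tasks tagged with their dependency depth, as in the
# unsorted_call_stack built above (roots have depth 0).
unsorted_call_stack = OrderedDict([
    ('luminosity', {'depth': 1}),
    ('temperature', {'depth': 2}),
    ('photometry', {'depth': 0}),
])

max_depth_all = max(t['depth'] for t in unsorted_call_stack.values())

call_stack = OrderedDict()
for depth in range(max_depth_all, -1, -1):
    for task in unsorted_call_stack:
        if unsorted_call_stack[task]['depth'] == depth:
            call_stack[task] = unsorted_call_stack[task]

print(list(call_stack))  # ['temperature', 'luminosity', 'photometry']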
    label='data: trail')
axes[2].set_xlim(0, 17)
axes[2].set_ylim(-1, 1)
axes[2].set_xlabel(r'$\Delta \phi_1$ [deg]')
axes[2].set_ylabel(r'$\Delta \phi_2$ [deg]')
axes[2].legend(loc='best', fontsize=15)

fig.tight_layout()
fig.savefig(
    path.join(
        plot_path, 'BarModels_RL{:d}_Mb{:.0e}_Om{:.1f}.png'.format(
            release_every, m_b.value, omega.value)))


def worker(task):
    omega, = task
    width_track(omega * u.km / u.s / u.kpc,
                m_b=1e10 * u.Msun,
                release_every=1,
                n_steps=6000)


tasks = [(om, ) for om in np.arange(28.0, 60 + 1e-3, 0.5)]

with SerialPool() as pool:
    # with MultiPool() as pool:
    print(pool.size)
    for r in pool.map(worker, tasks):
        pass
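Rather than toggling the commented-out `MultiPool` line by hand, schwimmbad's `choose_pool` helper can select the pool type from arguments. A sketch of the same driver loop using it, reusing the `worker` and `tasks` defined above:

from schwimmbad import choose_pool

# processes=1 should yield a SerialPool, processes>1 a MultiPool, and
# mpi=True an MPIPool (choose_pool parks non-master MPI ranks itself).
pool = choose_pool(mpi=False, processes=1)
for r in pool.map(worker, tasks):
    pass
pool.close()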