def __init__(self, params, pool=None, random_state=None, n_batches=None):

    # set the processing pool
    if pool is None:
        import schwimmbad
        pool = schwimmbad.SerialPool()

    elif not hasattr(pool, 'map') or not hasattr(pool, 'close'):
        raise TypeError("Input pool object must have .map() and .close() "
                        "methods. We recommend using `schwimmbad` pools.")

    self.pool = pool

    # Set the parent random state - child processes get different states
    # based on the parent
    if random_state is None:
        self._rnd_passed = False
        random_state = np.random.RandomState()

    elif not isinstance(random_state, np.random.RandomState):
        raise TypeError("Random state object must be a numpy RandomState "
                        "instance, not '{0}'".format(type(random_state)))

    else:
        self._rnd_passed = True

    self.random_state = random_state

    # check if a JokerParams instance was passed in to specify the state
    if not isinstance(params, JokerParams):
        raise TypeError("Parameter specification must be a JokerParams "
                        "instance, not a '{0}'".format(type(params)))
    self.params = params

    self.n_batches = n_batches
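# --- Usage sketch (added; not from the original source) ---
# A minimal sketch of constructing the object above, assuming this
# __init__ belongs to thejoker v0's TheJoker class and that JokerParams
# takes a period range (import path may differ between versions).
import astropy.units as u
import schwimmbad
from thejoker.sampler import JokerParams, TheJoker

params = JokerParams(P_min=8 * u.day, P_max=512 * u.day)
joker = TheJoker(params)  # no pool given: falls back to SerialPool
# joker = TheJoker(params, pool=schwimmbad.MultiPool())  # parallel variant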
def read_array(ftype, directory, tag, dataset, numThreads=1, noH=False,
               physicalUnits=False, CGS=False, verbose=True):
    """
    Args:
        ftype (str)
        directory (str)
        tag (str)
        dataset (str)
        numThreads (int)
        noH (bool)
        physicalUnits (bool)
        CGS (bool)
        verbose (bool)
    """
    start = timeit.default_timer()

    files = get_files(ftype, directory, tag)

    if numThreads == 1:
        pool = schwimmbad.SerialPool()
    elif numThreads == -1:
        pool = schwimmbad.MultiPool()
    else:
        pool = schwimmbad.MultiPool(processes=numThreads)

    lg = partial(read_hdf5, dataset=dataset)
    dat = np.concatenate(list(pool.map(lg, files)), axis=0)
    pool.close()

    stop = timeit.default_timer()

    print("Reading in '{}' for z = {} using {} thread(s) took {}s".format(
        dataset,
        np.round(read_header(ftype, directory, tag, dataset='Redshift'), 3),
        numThreads,
        np.round(stop - start, 6)))

    if noH:
        dat = apply_hfreeUnits_conversion(files[0], dataset, dat,
                                          verbose=verbose)

    if physicalUnits:
        dat = apply_physicalUnits_conversion(files[0], dataset, dat,
                                             verbose=verbose)

    if CGS:
        dat = apply_CGSUnits_conversion(files[0], dataset, dat,
                                        verbose=verbose)
    return dat
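# --- Usage sketch (added; not from the original source) ---
# A hypothetical call of read_array(); the directory and tag strings are
# placeholders in the EAGLE-style naming this reader appears to expect.
mstar = read_array('SUBFIND', './data/', '010_z005p000',
                   'Subhalo/Stars/Mass', numThreads=4, noH=True,
                   physicalUnits=True)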
def main(db_path, run_name, overwrite=False, pool=None):
    if pool is None:
        pool = schwimmbad.SerialPool()

    # connect to the database
    engine = db_connect(db_path)
    # engine.echo = True
    logger.debug("Connected to database at '{}'".format(db_path))

    # create a new session for interacting with the database
    session = Session()

    root_path, _ = path.split(db_path)
    plot_path = path.join(root_path, 'plots', run_name)
    if not path.exists(plot_path):
        os.makedirs(plot_path, exist_ok=True)

    # get object to correct the observed RV's
    rv_corr = RVCorrector(session, run_name)

    observations = session.query(Observation).join(Run)\
                          .filter(Run.name == run_name).all()

    for obs in observations:
        q = session.query(RVMeasurement).join(Observation)\
                   .filter(Observation.id == obs.id)

        if q.count() > 0 and not overwrite:
            logger.debug('RV measurement already complete for object '
                         '{0} in file {1}'.format(obs.object,
                                                  obs.filename_raw))
            continue

        elif q.count() > 1:
            raise RuntimeError(
                'Multiple RV measurements found for object {0}'.format(obs))

        elif len(obs.measurements) == 0:
            logger.debug(
                'Observation {0} has no line measurements.'.format(obs))
            continue

        corrected_rv, err, flag = rv_corr.get_corrected_rv(obs)

        # remove previous RV measurements
        if q.count() > 0:
            session.delete(q.one())
            session.commit()

        rv_meas = RVMeasurement(rv=corrected_rv, err=err, flag=flag)
        rv_meas.observation = obs
        session.add(rv_meas)
        session.commit()

    pool.close()
def get_age(arr, z, numThreads=4):
    if numThreads == 1:
        pool = schwimmbad.SerialPool()
    elif numThreads == -1:
        pool = schwimmbad.MultiPool()
    else:
        pool = schwimmbad.MultiPool(processes=numThreads)

    calc = partial(get_SFT, redshift=z)
    Age = np.array(list(pool.map(calc, arr)))
    pool.close()

    return Age
def get_age(self, arr, z, numThreads=4):
    if numThreads == 1:
        pool = schwimmbad.SerialPool()
    elif numThreads == -1:
        pool = schwimmbad.MultiPool()
    else:
        pool = schwimmbad.MultiPool(processes=numThreads)

    Age = self.cosmo.age(z).value - np.array(
        list(pool.map(self.get_star_formation_time, arr)))
    pool.close()

    return Age
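# --- Refactoring sketch (added; not from the original source) ---
# The numThreads -> pool mapping recurs in the readers above; a
# hypothetical helper like this would centralize the choice.
import schwimmbad

def _choose_pool(numThreads):
    """Serial pool for 1 thread, all cores for -1, else a MultiPool
    with the requested number of processes."""
    if numThreads == 1:
        return schwimmbad.SerialPool()
    if numThreads == -1:
        return schwimmbad.MultiPool()
    return schwimmbad.MultiPool(processes=numThreads)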
def run(self, batchsize=1, batches=1, threads=1):
    if threads == 1:
        pool = schwimmbad.SerialPool()
    else:
        pool = multiprocessing.Pool(processes=threads)

    for i in range(batches):
        par_list = self.get_parameters(size=batchsize)
        indices = [uuid.uuid4().hex for _ in range(batchsize)]
        df = pd.DataFrame(par_list, index=indices)
        self.store.store_df('parameters', df, append=True)

        for key in self.funcdic:
            worker = partial(_run_single, scan=self, key=key)
            results = list(pool.map(worker, par_list))
            df = pd.DataFrame(results, index=indices)
            self.store.store_df(key, df, append=True)

    pool.close()
def __init__(self, prior, pool=None, random_state=None, tempfile_path=None):

    # set the processing pool
    if pool is None:
        import schwimmbad
        pool = schwimmbad.SerialPool()

    elif not hasattr(pool, 'map') or not hasattr(pool, 'close'):
        raise TypeError("Input pool object must have .map() and .close() "
                        "methods. We recommend using `schwimmbad` pools.")

    self.pool = pool

    # Set the parent random state - child processes get different states
    # based on the parent
    if random_state is None:
        random_state = np.random.default_rng()

    elif isinstance(random_state, np.random.RandomState):
        warnings.warn(
            "With thejoker>=v1.2, use numpy.random.Generator objects "
            "instead of RandomState objects to control random numbers.",
            DeprecationWarning)
        tmp = np.random.Generator(np.random.MT19937())
        tmp.bit_generator.state = random_state.get_state()
        random_state = tmp

    elif not isinstance(random_state, np.random.Generator):
        raise TypeError("Random state object must be a "
                        "numpy.random.Generator instance, not "
                        f"'{type(random_state)}'")

    self.random_state = random_state

    # check if a JokerPrior instance was passed in to specify the prior
    if not isinstance(prior, JokerPrior):
        raise TypeError("The input prior must be a JokerPrior instance.")
    self.prior = prior

    if tempfile_path is None:
        self._tempfile_path = os.path.expanduser('~/.thejoker/')
    else:
        self._tempfile_path = os.path.abspath(
            os.path.expanduser(tempfile_path))
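# --- Illustration (added; not from the original source) ---
# What the deprecation shim above relies on: numpy's MT19937 bit
# generator accepts a legacy RandomState state tuple, so a Generator can
# be seeded to share the legacy object's underlying bit stream.
import numpy as np

rs = np.random.RandomState(42)            # legacy API
gen = np.random.Generator(np.random.MT19937())
gen.bit_generator.state = rs.get_state()  # accepts the legacy state tuple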
def test_multiproc_helpers(self, tmpdir):
    prior_samples_file = str(tmpdir.join('prior-samples.h5'))
    pool = schwimmbad.SerialPool()

    data = self.data['circ_binary']
    joker_params = self.joker_params['circ_binary']
    truth = self.truths['circ_binary']
    nlp = self.truths_to_nlp(truth)

    # write some nonsense out to the prior file
    n = 8192
    P = np.random.uniform(nlp[0] - 2., nlp[0] + 2., n)
    M0 = np.random.uniform(0, 2 * np.pi, n)
    ecc = np.zeros(n)
    omega = np.zeros(n)
    jitter = np.zeros(n)
    samples = np.vstack((P, M0, ecc, omega, jitter)).T

    # TODO: use save_prior_samples here
    with h5py.File(prior_samples_file, 'w') as f:
        f['samples'] = samples

    lls = compute_likelihoods(n, prior_samples_file, 0, data,
                              joker_params, pool)
    idx = get_good_sample_indices(lls)
    assert len(idx) >= 1

    lls = compute_likelihoods(n, prior_samples_file, 0, data,
                              joker_params, pool, n_batches=13)
    idx = get_good_sample_indices(lls)
    assert len(idx) >= 1

    full_samples = sample_indices_to_full_samples(
        idx, prior_samples_file, data, joker_params, pool)
    print(full_samples)
def main(db_path, run_name, data_root_path=None, filename=None,
         overwrite=False, pool=None):
    if pool is None:
        pool = schwimmbad.SerialPool()

    # connect to the database
    engine = db_connect(db_path)
    # engine.echo = True
    logger.debug("Connected to database at '{}'".format(db_path))

    # create a new session for interacting with the database
    session = Session()

    root_path, _ = path.split(db_path)
    if data_root_path is None:
        data_root_path = root_path

    plot_path = path.join(root_path, 'plots', run_name)
    if not path.exists(plot_path):
        os.makedirs(plot_path, exist_ok=True)

    # TODO: there might be some bugs here...
    n_lines = session.query(SpectralLineInfo).count()
    Halpha = session.query(SpectralLineInfo)\
                    .filter(SpectralLineInfo.name == 'Halpha').one()
    OI_lines = session.query(SpectralLineInfo)\
                      .filter(SpectralLineInfo.name.contains('[OI]')).all()

    if filename is None:  # grab all unfinished sources
        observations = session.query(Observation).join(Run)\
                              .filter(Run.name == run_name).all()

    else:  # only process the observation corresponding to this filename
        observations = session.query(Observation).join(Run)\
                              .filter(Run.name == run_name)\
                              .filter(Observation.filename_raw == filename)\
                              .all()

    for obs in observations:
        measurements = session.query(SpectralLineMeasurement)\
                              .join(Observation)\
                              .filter(Observation.id == obs.id).all()

        if len(measurements) == n_lines and not overwrite:
            logger.debug('All line measurements already complete for object '
                         '{0} in file {1}'.format(obs.object,
                                                  obs.filename_raw))
            continue

        # Read the spectrum data and get wavelength solution
        filebase, _ = path.splitext(obs.filename_1d)
        filename_1d = obs.path_1d(data_root_path)
        spec = Table.read(filename_1d)
        logger.debug('Loaded 1D spectrum for object {0} from file {1}'
                     .format(obs.object, filename_1d))

        # Extract region around Halpha
        x, (flux, ivar) = extract_region(
            spec['wavelength'],
            center=Halpha.wavelength.value,
            width=100,
            arrs=[spec['source_flux'], spec['source_ivar']])

        # We start by doing maximum likelihood estimation to fit the line,
        # then use the best-fit parameters to initialize an MCMC run.
        # TODO: need to figure out if it's emission or absorption...
        # for now, just assume absorption
        absorp_emiss = -1.
        lf = VoigtLineFitter(x, flux, ivar, absorp_emiss=absorp_emiss)
        lf.fit()
        fit_pars = lf.get_gp_mean_pars()

        if (not lf.success or
                abs(fit_pars['x0'] - Halpha.wavelength.value) > 16. or  # 16 Å = ~700 km/s
                abs(fit_pars['amp']) < 10):  # minimum amplitude - MAGIC NUMBER
            # TODO: should try again with emission line
            logger.error('absorption line has tiny amplitude! did '
                         'auto-determination of absorption/emission fail?')
            # TODO: what now?
            continue

        fig = lf.plot_fit()
        fig.savefig(path.join(plot_path, '{}_maxlike.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)

        # --------------------------------------------------------------
        # Run `emcee` instead to sample over GP model parameters:
        if fit_pars['std_G'] < 1E-2:
            lf.gp.freeze_parameter('mean:ln_std_G')

        initial = np.array(lf.gp.get_parameter_vector())
        if initial[4] < -10:  # TODO: ???
            initial[4] = -8.
        if initial[5] < -10:  # TODO: ???
            initial[5] = -8.
        ndim, nwalkers = len(initial), 64

        sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability,
                                        pool=pool, args=(lf.gp, flux))

        logger.debug("Running burn-in...")
        p0 = initial + 1e-6 * np.random.randn(nwalkers, ndim)
        p0, lp, _ = sampler.run_mcmc(p0, 128)

        logger.debug("Running 2nd burn-in...")
        sampler.reset()
        p0 = p0[lp.argmax()] + 1e-3 * np.random.randn(nwalkers, ndim)
        p0, lp, _ = sampler.run_mcmc(p0, 512)

        logger.debug("Running production...")
        sampler.reset()
        pos, lp, _ = sampler.run_mcmc(p0, 1024)

        fit_kw = dict()
        for i, par_name in enumerate(lf.gp.get_parameter_names()):
            if 'kernel' in par_name:
                continue

            # remove 'mean:'
            par_name = par_name[5:]

            # skip bg
            if par_name.startswith('bg'):
                continue

            samples = sampler.flatchain[:, i]

            if par_name.startswith('ln_'):
                par_name = par_name[3:]
                samples = np.exp(samples)

            MAD = np.median(np.abs(samples - np.median(samples)))
            fit_kw[par_name] = np.median(samples)
            fit_kw[par_name + '_error'] = 1.5 * MAD  # convert to ~stddev

        # remove all previous line measurements
        q = session.query(SpectralLineMeasurement).join(Observation)\
                   .filter(Observation.id == obs.id)
        if q.count() > 0:
            for meas in q.all():
                session.delete(meas)
            session.commit()

        slm = SpectralLineMeasurement(**fit_kw)
        slm.info = Halpha
        slm.observation = obs
        session.add(slm)
        session.commit()

        # --------------------------------------------------------------
        # plot MCMC traces
        fig, axes = plt.subplots(2, 4, figsize=(18, 6))
        for i in range(sampler.dim):
            for walker in sampler.chain[..., i]:
                axes.flat[i].plot(walker, marker='',
                                  drawstyle='steps-mid', alpha=0.2)
            axes.flat[i].set_title(lf.gp.get_parameter_names()[i],
                                   fontsize=12)
        fig.tight_layout()
        fig.savefig(path.join(plot_path,
                              '{}_mcmc_trace.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)
        # --------------------------------------------------------------

        # --------------------------------------------------------------
        # plot samples
        fig, axes = plt.subplots(3, 1, figsize=(10, 10), sharex=True)

        samples = sampler.flatchain
        for s in samples[np.random.randint(len(samples), size=32)]:
            lf.gp.set_parameter_vector(s)
            lf.plot_fit(axes=axes, fit_alpha=0.2)

        fig.tight_layout()
        fig.savefig(path.join(plot_path,
                              '{}_mcmc_fits.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)
        # --------------------------------------------------------------

        # --------------------------------------------------------------
        # corner plot
        fig = corner.corner(
            sampler.flatchain[::10, :],
            labels=[x.split(':')[1] for x in lf.gp.get_parameter_names()])
        fig.savefig(path.join(plot_path, '{}_corner.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)
        # --------------------------------------------------------------

        # compute centroids for sky lines
        sky_centroids = []
        for j, sky_line in enumerate(OI_lines):
            wvln = sky_line.wavelength.value
            x, (flux, ivar) = extract_region(
                spec['wavelength'],
                center=wvln,
                width=32.,  # angstroms
                arrs=[spec['background_flux'], spec['background_ivar']])

            lf = GaussianLineFitter(x, flux, ivar,
                                    absorp_emiss=1.)  # all emission lines

            try:
                lf.fit()
                fit_pars = lf.get_gp_mean_pars()

            except Exception as e:
                logger.warning("Failed to fit sky line {0}:\n{1}".format(
                    sky_line, e))
                lf.success = False
                fit_pars = lf.get_init()
                # HACK: fake the parameters of the failed fit
                fit_pars['amp'] = 0.
                fit_pars['bg_coef'] = None
                fit_pars['x0'] = 0.
            # HACK: hackish signal-to-noise
            max_ = fit_pars['amp'] / np.sqrt(2 * np.pi * fit_pars['std']**2)
            SNR = max_ / np.median(1 / np.sqrt(ivar))

            if (not lf.success or abs(fit_pars['x0'] - wvln) > 4 or
                    fit_pars['amp'] < 10 or fit_pars['std'] > 4 or
                    SNR < 2.5):
                # failed
                x0 = np.nan * u.angstrom
                title = 'f****d'
                fit_pars['amp'] = 0.

            else:
                x0 = fit_pars['x0'] * u.angstrom
                title = '{:.2f}'.format(fit_pars['amp'])

            if lf.success:
                fig = lf.plot_fit()
                fig.suptitle(title, y=0.95)
                fig.subplots_adjust(top=0.8)
                fig.savefig(path.join(
                    plot_path,
                    '{}_maxlike_sky_{:.0f}.png'.format(filebase, wvln)),
                    dpi=256)
                plt.close(fig)

            # store the sky line measurements
            fit_pars['std_G'] = fit_pars.pop('std')  # HACK
            fit_pars.pop('bg_coef')  # HACK
            slm = SpectralLineMeasurement(**fit_pars)
            slm.info = sky_line
            slm.observation = obs
            session.add(slm)
            session.commit()

            sky_centroids.append(x0)

        sky_centroids = u.Quantity(sky_centroids)

        logger.info('{} [{}]: x0={x0:.3f} σ={err:.3f}\n--------'.format(
            obs.object, filebase, x0=fit_kw['x0'], err=fit_kw['x0_error']))

        session.commit()

    pool.close()
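# --- Note (added; not from the original source) ---
# The 1.5 * MAD error bar used above approximates the normal-distribution
# conversion MAD -> sigma, whose exact factor is 1/Phi^{-1}(3/4) ~ 1.4826.
import numpy as np
from scipy.stats import norm

draws = np.random.normal(0., 2., 100_000)
mad = np.median(np.abs(draws - np.median(draws)))
print(mad / norm.ppf(0.75), 1.5 * mad)  # both close to the true sigma = 2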
parser.add_argument('--sim', action='store_true', default=False,
                    dest='simulated_data')

parser.add_argument('--name', required=True, dest='name',
                    help='Name of the data - can be "apw" or "rave"')

args = parser.parse_args()

if args.mpi:
    pool = MPIPool()
    if not pool.is_master():
        pool.wait()
        sys.exit(0)
else:
    pool = schwimmbad.SerialPool()

if args.simulated_data:
    print("Loading simulated data")

    # Load simulated data
    _tbl1 = fits.getdata('../notebooks/data1.fits')
    data1 = TGASData(_tbl1, rv=_tbl1['RV'] * u.km/u.s,
                     rv_err=_tbl1['RV_err'] * u.km/u.s)
    _tbl2 = fits.getdata('../notebooks/data2.fits')
    data2 = TGASData(_tbl2, rv=_tbl2['RV'] * u.km/u.s,
                     rv_err=_tbl2['RV_err'] * u.km/u.s)

else:
    print("Loading real data")
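# --- Note (added; not from the original source) ---
# The MPIPool branch above is schwimmbad's standard master/worker pattern:
# worker ranks block in pool.wait() and exit, while the master rank
# continues. A launch would look something like (script name is a
# placeholder):
#
#     mpiexec -n 4 python fit_data.py --mpi --name apw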
import schwimmbad
import numpy as np


def func(i):
    '''A useless function that prints and returns its input.'''
    print(str(i + 1))
    return i


# Use MultiPool - same interface as multiprocessing.Pool
with schwimmbad.MultiPool() as pool:
    inputs = [i for i in np.arange(0, 10, 2)]
    out1 = list(pool.map(func, inputs))

# Use SerialPool - same interface, but maps in the current process
with schwimmbad.SerialPool() as pool:
    inputs = [i for i in np.arange(10, 20, 2)]
    out2 = list(pool.map(func, inputs))

print(out1, out2)
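# --- Alternative sketch (added; not from the original source) ---
# schwimmbad also ships choose_pool(), which picks a pool class from
# flags; the same kind of map with it, assuming 4 local processes:
import schwimmbad

def square(i):
    '''Square the input.'''
    return i * i

pool = schwimmbad.choose_pool(mpi=False, processes=4)  # MultiPool here
out = list(pool.map(square, range(8)))
pool.close()
print(out)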
def main(argv=None):
    args = get_options(argv=argv)
    np.random.seed(seed=42)

    # setup time-ranges
    ligo_run_start = Time('2022-06-01T00:00:00.0')
    ligo_run_end = Time('2023-06-01T00:00:00.0')
    hst_cyc_start = Time('2021-10-01T00:00:00.0')
    hst_cyc_end = Time('2023-09-30T00:00:00.0')
    eng_time = 2. * u.week

    Range = namedtuple('Range', ['start', 'end'])
    ligo_run = Range(start=ligo_run_start, end=ligo_run_end)
    hst_cycle = Range(start=hst_cyc_start, end=hst_cyc_end)
    latest_start = max(ligo_run.start, hst_cycle.start)
    earliest_end = min(ligo_run.end, hst_cycle.end)
    td = (earliest_end - latest_start) + eng_time
    fractional_duration = (td / (1. * u.year)).decompose().value

    box_size = args.box_size
    volume = box_size**3

    # create the mass distribution of the merging neutron stars
    mass_distrib = args.mass_distrib
    # the truncated normal distribution looks to be from:
    # https://arxiv.org/pdf/1309.6635.pdf
    mean_mass = args.masskey1
    sig_mass = args.masskey2
    min_mass = args.masskey1
    max_mass = args.masskey2

    # the two LIGO detectors have strongly correlated duty cycles;
    # they are both not very correlated with Virgo
    lvc_cor_matrix = np.array([[1., 0.8, 0.5, 0.2],
                               [0.8, 1., 0.5, 0.2],
                               [0.5, 0.5, 1., 0.2],
                               [0.2, 0.2, 0.2, 1.]])
    upper_chol = cholesky(lvc_cor_matrix)

    # setup duty cycles
    h_duty = args.hdutycycle
    l_duty = args.ldutycycle
    v_duty = args.vdutycycle
    k_duty = args.kdutycycle

    # setup event rates
    mean_lograte = args.mean_lograte
    sig_lograte = args.sig_lograte
    n_try = args.ntry

    temp = at.Table.read('kilonova_phottable_40Mpc.txt', format='ascii')
    phase = temp['ofphase']
    temphmag = temp['f160w']
    tempf200w = temp['f218w']
    temprmag = temp['f625w']

    # define ranges
    ligo_range = get_range('ligo')
    virgo_range = get_range('virgo')
    kagra_range = get_range('kagra')

    def dotry(n):
        rate = 10.**(np.random.normal(mean_lograte, sig_lograte))
        n_events = np.around(rate * volume * fractional_duration).astype('int')
        if n_events == 0:
            # FIXME: fix to prevent unpacking error
            return tuple(0 for _ in range(15))
        print(f"### Num trial = {n}; Num events = {n_events}")
        if mass_distrib == 'mw':
            # FIXME: Unbound local error
            mass1 = spstat.truncnorm.rvs(0, np.inf, 1.4, 0.09, n_events)
            mass2 = spstat.truncnorm.rvs(0, np.inf, 1.4, 0.09, n_events)
        elif mass_distrib == 'msp':
            print("MSP population chosen, overriding mean_mass and "
                  "sig_mass if supplied.")
            # numbers from https://arxiv.org/pdf/1605.01665.pdf
            # two modes, choose a random one each time
            mean_mass, sig_mass = random.choice([(1.393, 0.064),
                                                 (1.807, 0.177)])
            mass1 = spstat.truncnorm.rvs(0, np.inf, mean_mass, sig_mass,
                                         n_events)
            mass2 = spstat.truncnorm.rvs(0, np.inf, mean_mass, sig_mass,
                                         n_events)
        else:
            print("Flat population chosen.")
            mass1 = np.random.uniform(min_mass, max_mass, n_events)
            mass2 = np.random.uniform(min_mass, max_mass, n_events)

        bns_range_ligo = np.array(
            [ligo_range(m1=m1, m2=m2) for m1, m2 in zip(mass1, mass2)]
        ) * u.Mpc
        bns_range_virgo = np.array(
            [virgo_range(m1=m1, m2=m2) for m1, m2 in zip(mass1, mass2)]
        ) * u.Mpc
        bns_range_kagra = np.array(
            [kagra_range(m1=m1, m2=m2) for m1, m2 in zip(mass1, mass2)]
        ) * u.Mpc
        tot_mass = mass1 + mass2

        delay = np.random.uniform(0, 365.25, n_events)
        delay[delay > 90] = 0
        av = np.random.exponential(1, n_events) * 0.4
        ah = av / 6.1
        ar = av / 1.33  # ref: table 2 of https://arxiv.org/abs/astro-ph/9809387

        sss17a = -16.9    # H-band
        sss17a_r = -15.8  # R-band
        minmag = -14.7
        maxmag = sss17a - 2.
        rmag = temprmag - min(temprmag)
        rmag[phase < 2.5] = 0
        magindex = [(phase - x).argmin() for x in delay]
        magindex = np.array(magindex)
        default_value = [0, ]
        if n_events == 0:
            return (default_value, default_value, default_value,
                    default_value, default_value, default_value, 0, 0)
        absm = (np.random.uniform(0, 1, n_events) * abs(maxmag - minmag)
                + sss17a_r + rmag[magindex] + ar)
        absm = np.array(absm)

        # simulate coordinates
        x = np.random.uniform(-box_size/2., box_size/2., n_events) * u.megaparsec
        y = np.random.uniform(-box_size/2., box_size/2., n_events) * u.megaparsec
        z = np.random.uniform(-box_size/2., box_size/2., n_events) * u.megaparsec
        dist = (x**2. + y**2. + z**2. + (0.05 * u.megaparsec)**2.)**0.5

        h_on, l_on, v_on, k_on = get_sim_dutycycles(n_events, upper_chol,
                                                    h_duty, l_duty, v_duty,
                                                    k_duty)
        n_detectors_on = np.array(
            [sum(_) for _ in np.vstack((h_on, l_on, v_on, k_on)).T]
        )

        # which detectors observed
        dist_ligo_bool = dist <= bns_range_ligo
        dist_virgo_bool = dist <= bns_range_virgo
        dist_kagra_bool = dist <= bns_range_kagra

        h_on_and_observed = h_on * dist_ligo_bool
        l_on_and_observed = l_on * dist_ligo_bool
        v_on_and_observed = v_on * dist_virgo_bool
        k_on_and_observed = k_on * dist_kagra_bool

        n_detectors_on_and_obs = np.sum(np.vstack(
            (h_on_and_observed, l_on_and_observed,
             v_on_and_observed, k_on_and_observed)).T, axis=1)

        two_det_obs = n_detectors_on_and_obs == 2
        three_det_obs = n_detectors_on_and_obs == 3
        four_det_obs = n_detectors_on_and_obs == 4

        # decide whether there is a kilonova based on remnant matter
        has_ejecta_bool = [
            has_ejecta_mass(m1, m2) for m1, m2 in zip(mass1, mass2)
        ]

        distmod = Distance(dist)
        obsmag = absm + distmod.distmod.value
        em_bool = obsmag < 22.

        # whether this event was not affected by the sun
        detected_events = np.where(em_bool)
        sun_bool = np.random.random(len(detected_events[0])) >= args.sun_loss
        em_bool[detected_events] = sun_bool

        n2_gw_only = np.where(two_det_obs)[0]
        n2_gw = len(n2_gw_only)
        n2_good = np.where(two_det_obs & em_bool & has_ejecta_bool)[0]
        n2 = len(n2_good)
        # sanity check
        assert n2_gw >= n2, \
            "GW events ({}) less than EM follow events ({})".format(n2_gw, n2)

        n3_gw_only = np.where(three_det_obs)[0]
        n3_gw = len(n3_gw_only)
        n3_good = np.where(three_det_obs & em_bool & has_ejecta_bool)[0]
        n3 = len(n3_good)
        # sanity check
        assert n3_gw >= n3, \
            "GW events ({}) less than EM follow events ({})".format(n3_gw, n3)

        n4_gw_only = np.where(four_det_obs)[0]
        n4_gw = len(n4_gw_only)
        n4_good = np.where(four_det_obs & em_bool & has_ejecta_bool)[0]
        n4 = len(n4_good)
        # sanity check
        assert n4_gw >= n4, \
            "GW events ({}) less than EM follow events ({})".format(n4_gw, n4)

        return dist[n2_good].value.tolist(), tot_mass[n2_good].tolist(), \
            dist[n3_good].value.tolist(), tot_mass[n3_good].tolist(), \
            dist[n4_good].value.tolist(), tot_mass[n4_good].tolist(), \
            obsmag[n2_good].tolist(), obsmag[n3_good].tolist(), \
            obsmag[n4_good].tolist(), n2, n3, n4

    with schwimmbad.SerialPool() as pool:
        values = list(pool.map(dotry, range(n_try)))
    print("Finished computation, plotting...")

    data_dump = dict()
    n_detect2 = []
    n_detect3 = []
    n_detect4 = []
    dist_detect2 = []
    mass_detect2 = []
    dist_detect3 = []
    mass_detect3 = []
    dist_detect4 = []
    mass_detect4 = []
    rmah_detect2 = []
    rmah_detect3 = []
    rmah_detect4 = []
    for idx, val in enumerate(values):
        d2, m2, d3, m3, d4, m4, h2, h3, h4, n2, n3, n4, *_ = val
        if n2 >= 0:
            n_detect2.append(n2)
            if n2 > 0:
                dist_detect2 += d2
                mass_detect2 += m2
                rmah_detect2 += h2
        if n3 >= 0:
            n_detect3.append(n3)
            if n3 > 0:
                dist_detect3 += d3
                mass_detect3 += m3
                rmah_detect3 += h3
        if n4 >= 0:
            n_detect4.append(n4)
            if n4 > 0:
                dist_detect4 += d4
                mass_detect4 += m4
                rmah_detect4 += h4
        data_dump[f"{idx}"] = {"d2": d2, "m2": m2, "d3": d3, "m3": m3,
                               "d4": d4, "m4": m4, "h2": h2, "h3": h3,
                               "h4": h4, "n2": n2, "n3": n3, "n4": n4}

    with open(f"hst/data-dump-hst-29-30-31-{args.mass_distrib}.pickle",
              "wb") as f:
        pickle.dump(data_dump, f)

    n_detect2 = np.array(n_detect2)
    n_detect3 = np.array(n_detect3)
    n_detect4 = np.array(n_detect4)
    #print(f"2 det: {n_detect2};\n3 det: {n_detect3};\n4 det: {n_detect4}")
    #print(f"2 det mean: {np.mean(n_detect2)};\n3 det mean: {np.mean(n_detect3)};\n4 det mean: {np.mean(n_detect4)}")

    fig_kw = {'figsize': (9.5/0.7, 3.5)}
    fig, axes = plt.subplots(nrows=1, ncols=3, **fig_kw)

    ebins = np.arange(32)
    norm = np.sum(n_detect3)/np.sum(n_detect2)

    # two-detector events
    vals, _, _ = axes[0].hist(n_detect2, histtype='stepfilled',
                              bins=ebins, color='C0', alpha=0.3,
                              density=True, zorder=0)
    axes[0].hist(n_detect2, histtype='step',
                 bins=ebins, color='C0', lw=3, density=True, zorder=3)
    bin_centers = (ebins[0:-1] + ebins[1:])/2.
    mean_nevents = np.mean(n_detect2)
    five_percent, ninetyfive_percent = (np.percentile(n_detect2, 5),
                                        np.percentile(n_detect2, 95))
    axes[0].axvline(round(mean_nevents), color='C0', linestyle='--', lw=2,
                    label=r'$\langle N\rangle = %d ;~ N_{95} = %d$' % (
                        round(mean_nevents), ninetyfive_percent))
    axes[0].axvline(ninetyfive_percent, color='C0', linestyle='dotted', lw=1)

    # three-detector events
    mean_nevents = np.mean(n_detect3)
    axes[0].hist(n_detect3, density=True, histtype='stepfilled',
                 color='C1', alpha=0.5, bins=ebins, zorder=1)
    axes[0].hist(n_detect3, density=True, histtype='step',
                 color='C1', lw=3, bins=ebins, zorder=2)
    five_percent, ninetyfive_percent = (np.percentile(n_detect3, 5),
                                        np.percentile(n_detect3, 95))
    axes[0].axvline(round(mean_nevents), color='C1', linestyle='--', lw=2,
                    label=r'$\langle N\rangle = %d ;~ N_{95} = %d$' % (
                        round(mean_nevents), ninetyfive_percent))
    axes[0].axvline(ninetyfive_percent, color='C1', linestyle='dotted', lw=1)

    # four-detector events
    mean_nevents = np.mean(n_detect4)
    axes[0].hist(n_detect4, density=True, histtype='stepfilled',
                 color='C2', alpha=0.5, bins=ebins, zorder=1)
    axes[0].hist(n_detect4, density=True, histtype='step',
                 color='C2', lw=3, bins=ebins, zorder=2)
    five_percent, ninetyfive_percent = (np.percentile(n_detect4, 5),
                                        np.percentile(n_detect4, 95))
    axes[0].axvline(round(mean_nevents), color='C2', linestyle='--', lw=2,
                    label=r'$\langle N \rangle = %d ;~ N_{95} = %d$' % (
                        round(mean_nevents), ninetyfive_percent))
    axes[0].axvline(ninetyfive_percent, color='C2', linestyle='dotted', lw=1)

    axes[0].legend(frameon=False, fontsize='small', loc='upper right')
    axes[0].set_yscale('log')
    axes[0].set_xlim((0., 31))

    #######################################################
    ### print out probabilities of greater than 1 event ###
    #######################################################
    print("P(N > 1 event detected)")
    print("For two detector", np.sum(n_detect2 > 1)/len(n_detect2))
print("For three detector", np.sum(n_detect3 > 1)/len(n_detect2)) print("For four detector", np.sum(n_detect4 > 1)/len(n_detect2)) # save number of detections with open(f'hst/n-events-hst-29-30-31-{args.mass_distrib}.pickle', 'wb') as f: res = dict(n_detect2=n_detect2, n_detect3=n_detect3, n_detect4=n_detect4, dist_detect2=dist_detect2, dist_detect3=dist_detect3, dist_detect4=dist_detect4, mass_detect2=mass_detect2, mass_detect3=mass_detect3, mass_detect4=mass_detect4, rmah_detect2=rmah_detect2, rmah_detect3=rmah_detect3, rmah_detect4=rmah_detect4) pickle.dump(res, f) dist_range = np.arange(0, 400., 0.1) patches = list() legend_text = list() try: kde = spstat.gaussian_kde(dist_detect2, bw_method='scott') pdist = kde(dist_range) axes[1].plot(dist_range, pdist, color='C0', linestyle='-', lw=3, zorder=4) patch1 = axes[1].fill_between(dist_range, np.zeros(len(dist_range)), pdist, color='C0', alpha=0.3, zorder=0) patches.append(patch1) legend_text.append('2 Detector Events') mean_dist = np.mean(dist_detect2) axes[1].axvline(mean_dist, color='C0', linestyle='--', lw=1.5, zorder=6, label=r'$\langle D \rangle = {:.0f}$ Mpc'.format(mean_dist)) ind0_40 = dist_range <= 40. ind40_80 = (dist_range <= 100.) & (dist_range > 40.) ind80_160 = (dist_range <= 160.) & (dist_range > 100.) p0_40 = scinteg.trapz(pdist[ind0_40], dist_range[ind0_40]) p40_80 = scinteg.trapz(pdist[ind40_80], dist_range[ind40_80]) p80_160 = scinteg.trapz(pdist[ind80_160], dist_range[ind80_160]) print(p0_40*5, p40_80*5, p80_160*5) except ValueError: print("Could not create KDE since no 2-det detection") try: kde = spstat.gaussian_kde(dist_detect3, bw_method='scott') pdist = kde(dist_range) axes[1].plot(dist_range, pdist, color='C1', linestyle='-', lw=3, zorder=2) patch2 = axes[1].fill_between(dist_range, np.zeros(len(dist_range)), pdist, color='C1', alpha=0.5, zorder=1) patches.append(patch2) legend_text.append('3 Detector Events') mean_dist = np.mean(dist_detect3) axes[1].axvline(mean_dist, color='C1', linestyle='--', lw=1.5, zorder=6, label=r'$\langle D \rangle = {:.0f}$ Mpc'.format(mean_dist)) axes[1].legend(frameon=False, fontsize='small') except ValueError: print("Could not create KDE since no 3-det detection") try: kde = spstat.gaussian_kde(dist_detect4, bw_method='scott') pdist = kde(dist_range) mean_dist = np.mean(dist_detect4) axes[1].plot(dist_range, pdist, color='C2', linestyle='-', lw=3, zorder=2) axes[1].axvline(mean_dist, color='C2', linestyle='--', lw=1.5, zorder=6, label=r'$\langle D \rangle = {:.0f}$ Mpc'.format(mean_dist)) patch3 = axes[1].fill_between(dist_range, np.zeros(len(dist_range)), pdist, color='C2', alpha=0.5, zorder=1) patches.append(patch3) legend_text.append('4 Detector Events') axes[1].legend(frameon=False, fontsize='small') except ValueError: print("Could not create KDE since no 4-det detection") h_range = np.arange(15, 23, 0.1) kde = spstat.gaussian_kde(rmah_detect2, bw_method='scott') ph = kde(h_range) axes[2].plot(h_range, ph, color='C0', linestyle='-', lw=3, zorder=4) axes[2].fill_between(h_range, np.zeros(len(h_range)), ph, color='C0', alpha=0.3, zorder=0) mean_h = np.mean(rmah_detect2) axes[2].axvline(mean_h, color='C0', linestyle='--', lw=1.5, zorder=6, label=r'$\langle H \rangle = {:.1f}$ mag'.format(mean_h)) kde = spstat.gaussian_kde(rmah_detect3, bw_method='scott') ph = kde(h_range) axes[2].plot(h_range, ph, color='C1', linestyle='-', lw=3, zorder=2) axes[2].fill_between(h_range, np.zeros(len(h_range)), ph, color='C1', alpha=0.5, zorder=1) mean_h = np.mean(rmah_detect3) 
    axes[2].axvline(mean_h, color='C1', linestyle='--', lw=1.5, zorder=6,
                    label=r'$\langle H \rangle = {:.1f}$ mag'.format(mean_h))
    axes[2].legend(frameon=False, fontsize='small')

    try:
        kde = spstat.gaussian_kde(rmah_detect4, bw_method='scott')
        ph = kde(h_range)
        axes[2].plot(h_range, ph, color='C2', linestyle='-', lw=3, zorder=2)
        axes[2].fill_between(h_range, np.zeros(len(h_range)), ph,
                             color='C2', alpha=0.5, zorder=1)
        mean_h = np.mean(rmah_detect4)
        axes[2].axvline(mean_h, color='C2', linestyle='--', lw=1.5, zorder=6,
                        label=r'$\langle H \rangle = {:.1f}$ mag'.format(
                            mean_h))
        axes[2].legend(frameon=False, fontsize='small')
    except ValueError:
        print("Could not create KDE for h-mag since no 4 detector "
              "events found")

    axes[1].set_xlabel('Distance ($D$, Mpc)', fontsize='large')
    axes[1].set_ylabel('$P(D)$', fontsize='large')
    axes[0].set_xlabel('Number of Events ($N$)', fontsize='large')
    axes[0].set_ylabel('$P(N)$', fontsize='large')
    axes[2].set_xlabel('Apparent F475W ($g$, AB mag)', fontsize='large')
    axes[2].set_ylabel('$P(H)$', fontsize='large')
    axes[0].set_xlim(0, ebins.max())
    ymin, ymax = axes[1].get_ylim()
    axes[1].set_ylim(0, ymax)
    ymin, ymax = axes[2].get_ylim()
    axes[2].set_ylim(0, ymax)
    fig.legend(patches, legend_text, 'upper center', frameon=False,
               ncol=3, fontsize='medium')
    fig.tight_layout(rect=[0, 0, 1, 0.97], pad=1.05)
    fig.savefig(f'hst/hst_gw_detect_hst_29_30_31_{args.mass_distrib}.pdf')
    plt.show()
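# --- Sketch (added; not from the original source) ---
# get_sim_dutycycles() is not shown above; given the (upper) Cholesky
# factor of the detector correlation matrix, a Gaussian-copula draw of
# correlated on/off states might look like this (names and values are
# illustrative, not the original implementation):
import numpy as np
from scipy.linalg import cholesky
from scipy.stats import norm

def sim_dutycycles(n, upper_chol, duties):
    """Correlated Bernoulli draws: correlate standard normals via the
    Cholesky factor, then threshold each column at its duty cycle."""
    z = np.random.randn(n, len(duties)) @ upper_chol
    u = norm.cdf(z)  # uniform marginals, correlated across detectors
    return [u[:, i] <= d for i, d in enumerate(duties)]

corr = np.array([[1., 0.8], [0.8, 1.]])
h_on, l_on = sim_dutycycles(1000, cholesky(corr), [0.7, 0.7])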