Example #1
    def __init__(self, params, pool=None, random_state=None, n_batches=None):

        # set the processing pool
        if pool is None:
            import schwimmbad
            pool = schwimmbad.SerialPool()

        elif not hasattr(pool, 'map') or not hasattr(pool, 'close'):
            raise TypeError("Input pool object must have .map() and .close() "
                            "methods. We recommend using `schwimmbad` pools.")

        self.pool = pool

        # Set the parent random state - child processes get different states
        # based on the parent
        if random_state is None:
            self._rnd_passed = False
            random_state = np.random.RandomState()

        elif not isinstance(random_state, np.random.RandomState):
            raise TypeError("Random state object must be a numpy RandomState "
                            "instance, not '{0}'".format(type(random_state)))

        else:
            self._rnd_passed = True

        self.random_state = random_state

        # check if a JokerParams instance was passed in to specify the state
        if not isinstance(params, JokerParams):
            raise TypeError("Parameter specification must be a JokerParams "
                            "instance, not a '{0}'".format(type(params)))
        self.params = params

        self.n_batches = n_batches
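The hand-rolled pool checks above can also be delegated to schwimmbad itself. A minimal sketch, assuming schwimmbad's choose_pool helper (which returns a SerialPool, MultiPool, or MPIPool depending on the flags); the make_pool wrapper is hypothetical:

import schwimmbad

def make_pool(mpi=False, n_cores=1):
    # mpi=True -> MPIPool, n_cores > 1 -> MultiPool, else SerialPool
    return schwimmbad.choose_pool(mpi=mpi, processes=n_cores)

pool = make_pool()
print(type(pool).__name__)  # SerialPool
pool.close()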
Example #2
def read_array(ftype,
               directory,
               tag,
               dataset,
               numThreads=1,
               noH=False,
               physicalUnits=False,
               CGS=False,
               verbose=True):
    """
   
    Args:
        ftype (str)
        directory (str)
        tag (str)
        dataset (str)
        numThreads (int)
        noH (bool)
        physicalUnits (bool)
    """
    start = timeit.default_timer()

    files = get_files(ftype, directory, tag)

    if numThreads == 1:
        pool = schwimmbad.SerialPool()
    elif numThreads == -1:
        pool = schwimmbad.MultiPool()
    else:
        pool = schwimmbad.MultiPool(processes=numThreads)

    lg = partial(read_hdf5, dataset=dataset)
    dat = np.concatenate(list(pool.map(lg, files)), axis=0)
    pool.close()

    stop = timeit.default_timer()

    print("Reading in '{}' for z = {} using {} thread(s) took {}s".format(
        dataset,
        np.round(read_header(ftype, directory, tag, dataset='Redshift'), 3),
        numThreads, np.round(stop - start, 6)))

    if noH:
        dat = apply_hfreeUnits_conversion(files[0],
                                          dataset,
                                          dat,
                                          verbose=verbose)

    if physicalUnits:
        dat = apply_physicalUnits_conversion(files[0],
                                             dataset,
                                             dat,
                                             verbose=verbose)
    if CGS:
        dat = apply_CGSUnits_conversion(files[0],
                                        dataset,
                                        dat,
                                        verbose=verbose)
    return dat
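The numThreads branching above recurs in several of these examples. A sketch of factoring it into a reusable helper (the name _select_pool is hypothetical):

import schwimmbad

def _select_pool(num_threads):
    # convention used above: 1 -> SerialPool, -1 -> MultiPool over all
    # cores, any other n -> MultiPool with n processes
    if num_threads == 1:
        return schwimmbad.SerialPool()
    if num_threads == -1:
        return schwimmbad.MultiPool()
    return schwimmbad.MultiPool(processes=num_threads)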
Example #3
def main(db_path, run_name, overwrite=False, pool=None):

    if pool is None:
        pool = schwimmbad.SerialPool()

    # connect to the database
    engine = db_connect(db_path)
    # engine.echo = True
    logger.debug("Connected to database at '{}'".format(db_path))

    # create a new session for interacting with the database
    session = Session()

    root_path, _ = path.split(db_path)
    plot_path = path.join(root_path, 'plots', run_name)
    if not path.exists(plot_path):
        os.makedirs(plot_path, exist_ok=True)

    # get object to correct the observed RV's
    rv_corr = RVCorrector(session, run_name)

    observations = session.query(Observation).join(Run)\
                          .filter(Run.name == run_name).all()

    for obs in observations:
        q = session.query(RVMeasurement).join(Observation)\
                   .filter(Observation.id == obs.id)

        if q.count() > 0 and not overwrite:
            logger.debug('RV measurement already complete for object '
                         '{0} in file {1}'.format(obs.object,
                                                  obs.filename_raw))
            continue

        elif q.count() > 1:
            raise RuntimeError(
                'Multiple RV measurements found for object {0}'.format(obs))

        elif len(obs.measurements) == 0:
            logger.debug(
                'Observation {0} has no line measurements.'.format(obs))
            continue

        corrected_rv, err, flag = rv_corr.get_corrected_rv(obs)

        # remove previous RV measurements
        if q.count() > 0:
            session.delete(q.one())
            session.commit()

        rv_meas = RVMeasurement(rv=corrected_rv, err=err, flag=flag)
        rv_meas.observation = obs
        session.add(rv_meas)
        session.commit()

    pool.close()
Example #4
def get_age(arr, z, numThreads=4):

    if numThreads == 1:
        pool = schwimmbad.SerialPool()
    elif numThreads == -1:
        pool = schwimmbad.MultiPool()
    else:
        pool = schwimmbad.MultiPool(processes=numThreads)

    calc = partial(get_SFT, redshift=z)
    Age = np.array(list(pool.map(calc, arr)))
    pool.close()

    return Age
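The same computation with the pool managed as a context manager, so it is closed even if a worker raises. A sketch under the same numThreads convention; get_SFT here is a toy stand-in for the real helper:

from functools import partial

import numpy as np
import schwimmbad

def get_SFT(scale_factor, redshift):
    # hypothetical stand-in for the real get_SFT
    return scale_factor / (1.0 + redshift)

def get_age_ctx(arr, z, numThreads=4):
    calc = partial(get_SFT, redshift=z)
    pool_cls = schwimmbad.SerialPool if numThreads == 1 else schwimmbad.MultiPool
    kwargs = {} if numThreads in (1, -1) else {'processes': numThreads}
    with pool_cls(**kwargs) as pool:  # __exit__ closes the pool
        return np.array(list(pool.map(calc, arr)))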
Example #5
    def get_age(self, arr, z, numThreads=4):

        if numThreads == 1:
            pool = schwimmbad.SerialPool()
        elif numThreads == -1:
            pool = schwimmbad.MultiPool()
        else:
            pool = schwimmbad.MultiPool(processes=numThreads)

        Age = self.cosmo.age(z).value - np.array(
            list(pool.map(self.get_star_formation_time, arr)))
        pool.close()

        return Age
Example #6
    def run(self, batchsize=1, batches=1, threads=1):
        if threads == 1:
            pool = schwimmbad.SerialPool()
        else:
            # honor the requested worker count instead of the default
            pool = multiprocessing.Pool(processes=threads)
        for i in range(batches):
            par_list = self.get_parameters(size=batchsize)
            indices = [uuid.uuid4().hex for _ in range(batchsize)]
            df = pd.DataFrame(par_list, index=indices)
            self.store.store_df('parameters', df, append=True)
            for key in self.funcdic:
                worker = partial(_run_single, scan=self, key=key)
                results = list(pool.map(worker, par_list))
                df = pd.DataFrame(results, index=indices)
                self.store.store_df(key, df, append=True)
        pool.close()
Example #7
    def __init__(self,
                 prior,
                 pool=None,
                 random_state=None,
                 tempfile_path=None):

        # set the processing pool
        if pool is None:
            import schwimmbad
            pool = schwimmbad.SerialPool()
        elif not hasattr(pool, 'map') or not hasattr(pool, 'close'):
            raise TypeError("Input pool object must have .map() and .close() "
                            "methods. We recommend using `schwimmbad` pools.")
        self.pool = pool

        # Set the parent random state - child processes get different states
        # based on the parent
        if random_state is None:
            random_state = np.random.default_rng()
        elif isinstance(random_state, np.random.RandomState):
            warnings.warn(
                "With thejoker>=v1.2, use numpy.random.Generator "
                "objects instead of RandomState objects to control "
                "random numbers.", DeprecationWarning)
            tmp = np.random.Generator(np.random.MT19937())
            tmp.bit_generator.state = random_state.get_state()
            random_state = tmp
        elif not isinstance(random_state, np.random.Generator):
            raise TypeError("Random state object must be a "
                            "numpy.random.Generator instance, not "
                            f"'{type(random_state)}'")
        self.random_state = random_state

        # check that a JokerPrior instance was passed in to specify the model
        if not isinstance(prior, JokerPrior):
            raise TypeError("The input prior must be a JokerPrior instance.")
        self.prior = prior

        if tempfile_path is None:
            self._tempfile_path = os.path.expanduser('~/.thejoker/')
        else:
            self._tempfile_path = os.path.abspath(
                os.path.expanduser(tempfile_path))
Example #8
    def test_multiproc_helpers(self, tmpdir):
        prior_samples_file = str(tmpdir.join('prior-samples.h5'))
        pool = schwimmbad.SerialPool()

        data = self.data['circ_binary']
        joker_params = self.joker_params['circ_binary']
        truth = self.truths['circ_binary']
        nlp = self.truths_to_nlp(truth)

        # write some nonsense out to the prior file
        n = 8192
        P = np.random.uniform(nlp[0]-2., nlp[0]+2., n)
        M0 = np.random.uniform(0, 2*np.pi, n)
        ecc = np.zeros(n)
        omega = np.zeros(n)
        jitter = np.zeros(n)
        samples = np.vstack((P, M0, ecc, omega, jitter)).T

        # TODO: use save_prior_samples here

        with h5py.File(prior_samples_file, 'w') as f:
            f['samples'] = samples

        lls = compute_likelihoods(n, prior_samples_file, 0, data,
                                  joker_params, pool)
        idx = get_good_sample_indices(lls)
        assert len(idx) >= 1

        lls = compute_likelihoods(n, prior_samples_file, 0, data,
                                  joker_params, pool, n_batches=13)
        idx = get_good_sample_indices(lls)
        assert len(idx) >= 1

        full_samples = sample_indices_to_full_samples(idx, prior_samples_file,
                                                      data, joker_params, pool)
        print(full_samples)
Example #9
def main(db_path,
         run_name,
         data_root_path=None,
         filename=None,
         overwrite=False,
         pool=None):

    if pool is None:
        pool = schwimmbad.SerialPool()

    # connect to the database
    engine = db_connect(db_path)
    # engine.echo = True
    logger.debug("Connected to database at '{}'".format(db_path))

    # create a new session for interacting with the database
    session = Session()

    root_path, _ = path.split(db_path)
    if data_root_path is None:
        data_root_path = root_path

    plot_path = path.join(root_path, 'plots', run_name)
    if not path.exists(plot_path):
        os.makedirs(plot_path, exist_ok=True)

    # TODO: there might be some bugs here...
    n_lines = session.query(SpectralLineInfo).count()
    Halpha = session.query(SpectralLineInfo)\
                    .filter(SpectralLineInfo.name == 'Halpha').one()
    OI_lines = session.query(SpectralLineInfo)\
                      .filter(SpectralLineInfo.name.contains('[OI]')).all()

    if filename is None:  # grab all unfinished sources
        observations = session.query(Observation).join(Run)\
                              .filter(Run.name == run_name).all()

    else:  # only process the observation corresponding to this filename
        observations = session.query(Observation).join(Run)\
                              .filter(Run.name == run_name)\
                              .filter(Observation.filename_raw == filename).all()

    for obs in observations:
        measurements = session.query(SpectralLineMeasurement)\
                              .join(Observation)\
                              .filter(Observation.id == obs.id).all()

        if len(measurements) == n_lines and not overwrite:
            logger.debug('All line measurements already complete for object '
                         '{0} in file {1}'.format(obs.object,
                                                  obs.filename_raw))
            continue

        # Read the spectrum data and get wavelength solution
        filebase, _ = path.splitext(obs.filename_1d)
        filename_1d = obs.path_1d(data_root_path)
        spec = Table.read(filename_1d)
        logger.debug('Loaded 1D spectrum for object {0} from file {1}'.format(
            obs.object, filename_1d))

        # Extract region around Halpha
        x, (flux, ivar) = extract_region(
            spec['wavelength'],
            center=Halpha.wavelength.value,
            width=100,
            arrs=[spec['source_flux'], spec['source_ivar']])

        # We start by doing maximum likelihood estimation to fit the line, then
        # use the best-fit parameters to initialize an MCMC run.
        # TODO: need to figure out if it's emission or absorption...for now just
        #   assume absorption
        absorp_emiss = -1.
        lf = VoigtLineFitter(x, flux, ivar, absorp_emiss=absorp_emiss)
        lf.fit()
        fit_pars = lf.get_gp_mean_pars()

        # 16 Å = ~700 km/s; 10 = minimum amplitude (MAGIC NUMBER)
        if (not lf.success
                or abs(fit_pars['x0'] - Halpha.wavelength.value) > 16.
                or abs(fit_pars['amp']) < 10):
            # TODO: should try again with emission line
            logger.error('absorption line has tiny amplitude! did '
                         'auto-determination of absorption/emission fail?')
            # TODO: what now?
            continue

        fig = lf.plot_fit()
        fig.savefig(path.join(plot_path, '{}_maxlike.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)

        # ----------------------------------------------------------------------

        # Run `emcee` instead to sample over GP model parameters:
        if fit_pars['std_G'] < 1E-2:
            lf.gp.freeze_parameter('mean:ln_std_G')

        initial = np.array(lf.gp.get_parameter_vector())
        if initial[4] < -10:  # TODO: ???
            initial[4] = -8.
        if initial[5] < -10:  # TODO: ???
            initial[5] = -8.
        ndim, nwalkers = len(initial), 64

        sampler = emcee.EnsembleSampler(nwalkers,
                                        ndim,
                                        log_probability,
                                        pool=pool,
                                        args=(lf.gp, flux))

        logger.debug("Running burn-in...")
        p0 = initial + 1e-6 * np.random.randn(nwalkers, ndim)
        p0, lp, _ = sampler.run_mcmc(p0, 128)

        logger.debug("Running 2nd burn-in...")
        sampler.reset()
        p0 = p0[lp.argmax()] + 1e-3 * np.random.randn(nwalkers, ndim)
        p0, lp, _ = sampler.run_mcmc(p0, 512)

        logger.debug("Running production...")
        sampler.reset()
        pos, lp, _ = sampler.run_mcmc(p0, 1024)

        fit_kw = dict()
        for i, par_name in enumerate(lf.gp.get_parameter_names()):
            if 'kernel' in par_name: continue

            # remove 'mean:'
            par_name = par_name[5:]

            # skip bg
            if par_name.startswith('bg'): continue

            samples = sampler.flatchain[:, i]

            if par_name.startswith('ln_'):
                par_name = par_name[3:]
                samples = np.exp(samples)

            MAD = np.median(np.abs(samples - np.median(samples)))
            fit_kw[par_name] = np.median(samples)
            fit_kw[par_name + '_error'] = 1.5 * MAD  # convert to ~stddev

        # remove all previous line measurements
        q = session.query(SpectralLineMeasurement).join(Observation)\
                   .filter(Observation.id == obs.id)
        if q.count() > 0:
            for meas in q.all():
                session.delete(meas)
            session.commit()

        slm = SpectralLineMeasurement(**fit_kw)
        slm.info = Halpha
        slm.observation = obs
        session.add(slm)
        session.commit()

        # --------------------------------------------------------------------
        # plot MCMC traces
        fig, axes = plt.subplots(2, 4, figsize=(18, 6))
        for i in range(sampler.dim):
            for walker in sampler.chain[..., i]:
                axes.flat[i].plot(walker,
                                  marker='',
                                  drawstyle='steps-mid',
                                  alpha=0.2)
            axes.flat[i].set_title(lf.gp.get_parameter_names()[i], fontsize=12)
        fig.tight_layout()
        fig.savefig(path.join(plot_path, '{}_mcmc_trace.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)
        # --------------------------------------------------------------------

        # --------------------------------------------------------------------
        # plot samples
        fig, axes = plt.subplots(3, 1, figsize=(10, 10), sharex=True)

        samples = sampler.flatchain
        for s in samples[np.random.randint(len(samples), size=32)]:
            lf.gp.set_parameter_vector(s)
            lf.plot_fit(axes=axes, fit_alpha=0.2)

        fig.tight_layout()
        fig.savefig(path.join(plot_path, '{}_mcmc_fits.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)
        # --------------------------------------------------------------------

        # --------------------------------------------------------------------
        # corner plot
        fig = corner.corner(
            sampler.flatchain[::10, :],
            labels=[x.split(':')[1] for x in lf.gp.get_parameter_names()])
        fig.savefig(path.join(plot_path, '{}_corner.png'.format(filebase)),
                    dpi=256)
        plt.close(fig)
        # --------------------------------------------------------------------

        # compute centroids for sky lines
        sky_centroids = []
        for j, sky_line in enumerate(OI_lines):
            wvln = sky_line.wavelength.value
            x, (flux, ivar) = extract_region(
                spec['wavelength'],
                center=wvln,
                width=32.,  # angstroms
                arrs=[spec['background_flux'], spec['background_ivar']])

            lf = GaussianLineFitter(x, flux, ivar,
                                    absorp_emiss=1.)  # all emission lines

            try:
                lf.fit()
                fit_pars = lf.get_gp_mean_pars()

            except Exception as e:
                logger.warn("Failed to fit sky line {0}:\n{1}".format(
                    sky_line, e))
                lf.success = False
                fit_pars = lf.get_init()
                # OMG this is the biggest effing hack
                fit_pars['amp'] = 0.
                fit_pars['bg_coef'] = None
                fit_pars['x0'] = 0.

            # HACK: hackish signal-to-noise
            max_ = fit_pars['amp'] / np.sqrt(2 * np.pi * fit_pars['std']**2)
            SNR = max_ / np.median(1 / np.sqrt(ivar))

            if (not lf.success or abs(fit_pars['x0'] - wvln) > 4
                    or fit_pars['amp'] < 10 or fit_pars['std'] > 4
                    or SNR < 2.5):
                # failed
                x0 = np.nan * u.angstrom
                title = 'f****d'
                fit_pars['amp'] = 0.

            else:
                x0 = fit_pars['x0'] * u.angstrom
                title = '{:.2f}'.format(fit_pars['amp'])

            if lf.success:
                fig = lf.plot_fit()
                fig.suptitle(title, y=0.95)
                fig.subplots_adjust(top=0.8)
                fig.savefig(path.join(
                    plot_path,
                    '{}_maxlike_sky_{:.0f}.png'.format(filebase, wvln)),
                            dpi=256)
                plt.close(fig)

            # store the sky line measurements
            fit_pars['std_G'] = fit_pars.pop('std')  # HACK
            fit_pars.pop('bg_coef')  # HACK
            slm = SpectralLineMeasurement(**fit_pars)
            slm.info = sky_line
            slm.observation = obs
            session.add(slm)
            session.commit()

            sky_centroids.append(x0)
        sky_centroids = u.Quantity(sky_centroids)

        logger.info('{} [{}]: x0={x0:.3f} σ={err:.3f}\n--------'.format(
            obs.object, filebase, x0=fit_kw['x0'], err=fit_kw['x0_error']))

        session.commit()

    pool.close()
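For reference, a self-contained sketch of handing a schwimmbad pool to emcee the way the sampler above does, with a toy Gaussian log-probability (assumes emcee >= 3):

import numpy as np
import emcee
import schwimmbad

def log_prob(theta):
    # toy target: standard normal in every dimension
    return -0.5 * np.sum(theta ** 2)

if __name__ == '__main__':
    ndim, nwalkers = 2, 16
    p0 = np.random.randn(nwalkers, ndim)
    with schwimmbad.MultiPool() as pool:
        sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)
        sampler.run_mcmc(p0, 200)
    print(sampler.flatchain.shape)  # (nwalkers * 200, ndim)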
Example #10
    parser.add_argument('--sim', action='store_true', default=False,
                        dest='simulated_data')
    parser.add_argument('--name', required=True, dest='name',
                        help='Name of the data - can be "apw" or "rave"')

    args = parser.parse_args()

    if args.mpi:
        pool = MPIPool()

        if not pool.is_master():
            pool.wait()
            sys.exit(0)

    else:
        pool = schwimmbad.SerialPool()

    if args.simulated_data:
        print("Loading simulated data")

        # Load simulated data
        _tbl1 = fits.getdata('../notebooks/data1.fits')
        data1 = TGASData(_tbl1, rv=_tbl1['RV']*u.km/u.s,
                         rv_err=_tbl1['RV_err']*u.km/u.s)

        _tbl2 = fits.getdata('../notebooks/data2.fits')
        data2 = TGASData(_tbl2, rv=_tbl2['RV']*u.km/u.s,
                         rv_err=_tbl2['RV_err']*u.km/u.s)

    else:
        print("Loading real data")
Example #11
import schwimmbad
import numpy as np


def func(i):
    '''
    A trivial worker: print i + 1 and return i.
    '''
    print(i + 1)
    return i


# Use multipool - same as multiprocessing
with schwimmbad.MultiPool() as pool:
    inputs = [i for i in np.arange(0, 10, 2)]
    out1 = list(pool.map(func, inputs))

# Use serial pool
with schwimmbad.SerialPool() as pool:
    inputs = [i for i in np.arange(10, 20, 2)]
    out2 = list(pool.map(func, inputs))

print(out1, out2)
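One caveat this example glosses over: MultiPool wraps multiprocessing.Pool, so under spawn-based start methods (Windows, macOS on recent Pythons) the mapped function must be importable and the pool belongs under a __main__ guard. A sketch of the safer layout:

import schwimmbad

def square(i):
    # must be a top-level function so child processes can import it
    return i * i

if __name__ == '__main__':
    with schwimmbad.MultiPool() as pool:
        print(list(pool.map(square, range(8))))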
Example #12
def main(argv=None):

    args = get_options(argv=argv)
    np.random.seed(seed=42)

    # setup time-ranges
    ligo_run_start = Time('2022-06-01T00:00:00.0')
    ligo_run_end   = Time('2023-06-01T00:00:00.0')
    hst_cyc_start  = Time('2021-10-01T00:00:00.0')
    hst_cyc_end    = Time('2023-09-30T00:00:00.0')
    eng_time       = 2.*u.week
    Range = namedtuple('Range', ['start', 'end'])
    ligo_run  = Range(start=ligo_run_start, end=ligo_run_end)
    hst_cycle = Range(start=hst_cyc_start,  end=hst_cyc_end)
    latest_start = max(ligo_run.start, hst_cycle.start)
    earliest_end = min(ligo_run.end, hst_cycle.end)
    td = (earliest_end - latest_start) + eng_time
    fractional_duration = (td/(1.*u.year)).decompose().value

    box_size = args.box_size
    volume = box_size**3
    # create the mass distribution of the merging neutron star
    mass_distrib = args.mass_distrib
    # the truncated normal distribution looks to be from:
    # https://arxiv.org/pdf/1309.6635.pdf
    # masskey1/masskey2 double as mean/sigma for the Gaussian populations
    # and as min/max for the flat population
    mean_mass = args.masskey1
    sig_mass = args.masskey2

    min_mass = args.masskey1
    max_mass = args.masskey2

    # the two LIGO detectors have strongly correlated duty cycles;
    # neither is strongly correlated with Virgo
    lvc_cor_matrix = np.array([[1., 0.8, 0.5, 0.2],
                               [0.8, 1., 0.5, 0.2],
                               [0.5, 0.5, 1., 0.2],
                               [0.2, 0.2, 0.2, 1.]])
    upper_chol = cholesky(lvc_cor_matrix)

    # setup duty cycles
    h_duty = args.hdutycycle
    l_duty = args.ldutycycle
    v_duty = args.vdutycycle
    k_duty = args.kdutycycle

    # setup event rates
    mean_lograte = args.mean_lograte
    sig_lograte  = args.sig_lograte
    n_try = args.ntry

    temp = at.Table.read('kilonova_phottable_40Mpc.txt', format='ascii')
    phase = temp['ofphase']
    temphmag  = temp['f160w']
    tempf200w = temp['f218w']
    temprmag  = temp['f625w']

    # define ranges
    ligo_range = get_range('ligo')
    virgo_range = get_range('virgo')
    kagra_range = get_range('kagra')
    
    def dotry(n):
        rate = 10.**(np.random.normal(mean_lograte, sig_lograte))
        n_events = np.around(rate*volume*fractional_duration).astype('int')
        if n_events == 0:
            # return an empty result matching the 12-element structure of
            # the normal return so the unpacking loop below doesn't break
            return [0], [0], [0], [0], [0], [0], [0], [0], [0], 0, 0, 0
        print(f"### Num trial = {n}; Num events = {n_events}")
        if mass_distrib == 'mw':
            # literal MW values; mean_mass/sig_mass are local to dotry
            # (they are assigned in the 'msp' branch), so reading them here
            # would raise UnboundLocalError
            mass1 = spstat.truncnorm.rvs(0, np.inf, 1.4, 0.09, n_events)
            mass2 = spstat.truncnorm.rvs(0, np.inf, 1.4, 0.09, n_events)
        elif mass_distrib == 'msp':
            print("MSP population chosen, overriding mean_mass and sig_mass if supplied.")
            # numbers from https://arxiv.org/pdf/1605.01665.pdf
            # two modes, choose a random one each time
            mean_mass, sig_mass = random.choice([(1.393, 0.064), (1.807, 0.177)])
            mass1 = spstat.truncnorm.rvs(0, np.inf, mean_mass, sig_mass, n_events)
            mass2 = spstat.truncnorm.rvs(0, np.inf, mean_mass, sig_mass, n_events)
        else:
            print("Flat population chosen.")
            mass1 = np.random.uniform(min_mass, max_mass, n_events)
            mass2 = np.random.uniform(min_mass, max_mass, n_events)
        bns_range_ligo = np.array(
            [ligo_range(m1=m1, m2=m2) for m1, m2 in zip(mass1, mass2)]
        ) * u.Mpc
        bns_range_virgo = np.array(
            [virgo_range(m1=m1, m2=m2) for m1, m2 in zip(mass1, mass2)]
        ) * u.Mpc
        bns_range_kagra = np.array(
            [kagra_range(m1=m1, m2=m2) for m1, m2 in zip(mass1, mass2)]
        ) * u.Mpc
        tot_mass = mass1 + mass2

        delay = np.random.uniform(0, 365.25, n_events)
        delay[delay > 90] = 0

        av = np.random.exponential(1, n_events)*0.4
        ah = av/6.1
        ar = av/1.33  # ref: table 2 of https://arxiv.org/abs/astro-ph/9809387

        sss17a = -16.9 #H-band
        sss17a_r = -15.8 #Rband
        minmag = -14.7
        maxmag = sss17a - 2.

        rmag = temprmag - min(temprmag)
        rmag[phase < 2.5] = 0

        magindex = [(phase - x).argmin() for x in delay]
        magindex = np.array(magindex)

        absm = np.random.uniform(0, 1, n_events)*abs(maxmag-minmag) + sss17a_r + rmag[magindex] + ar
        absm = np.array(absm)

        # simulate coordinates
        x = np.random.uniform(-box_size/2., box_size/2., n_events)*u.megaparsec
        y = np.random.uniform(-box_size/2., box_size/2., n_events)*u.megaparsec
        z = np.random.uniform(-box_size/2., box_size/2., n_events)*u.megaparsec
        dist = (x**2. + y**2. + z**2. + (0.05*u.megaparsec)**2.)**0.5

        h_on, l_on, v_on, k_on = get_sim_dutycycles(n_events, upper_chol,
                                                    h_duty, l_duty, v_duty, k_duty)
        n_detectors_on = np.array(
            [sum(_) for _ in np.vstack((h_on, l_on, v_on, k_on)).T]
        )
        # which detectors observed
        dist_ligo_bool  = dist <= bns_range_ligo
        dist_virgo_bool = dist <= bns_range_virgo
        dist_kagra_bool = dist <= bns_range_kagra

        h_on_and_observed = h_on * dist_ligo_bool
        l_on_and_observed = l_on * dist_ligo_bool
        v_on_and_observed = v_on * dist_virgo_bool
        k_on_and_observed = k_on * dist_kagra_bool

        n_detectors_on_and_obs = np.sum(np.vstack(
            (h_on_and_observed, l_on_and_observed, v_on_and_observed,
             k_on_and_observed)).T,
            axis=1
        )

        two_det_obs = n_detectors_on_and_obs == 2
        three_det_obs = n_detectors_on_and_obs == 3
        four_det_obs = n_detectors_on_and_obs == 4

        # decide whether there is a kilnova based on remnant matter
        has_ejecta_bool = [
            has_ejecta_mass(m1, m2) for m1, m2 in zip(mass1, mass2)
        ]

        distmod = Distance(dist)
        obsmag = absm + distmod.distmod.value
        em_bool = obsmag < 22.

        # whether this event was not affected by then sun
        detected_events = np.where(em_bool)
        sun_bool = np.random.random(len(detected_events[0])) >= args.sun_loss
        em_bool[detected_events] = sun_bool

        n2_gw_only = np.where(two_det_obs)[0]
        n2_gw = len(n2_gw_only)
        n2_good = np.where(two_det_obs & em_bool & has_ejecta_bool)[0]
        n2 = len(n2_good)
        # sanity check
        assert n2_gw >= n2, "GW events ({}) less than EM follow events ({})".format(n2_gw, n2)
        n3_gw_only = np.where(three_det_obs)[0]
        n3_gw = len(n3_gw_only)
        n3_good = np.where(three_det_obs & em_bool & has_ejecta_bool)[0]
        n3 = len(n3_good)
        # sanity check
        assert n3_gw >= n3, "GW events ({}) less than EM follow events ({})".format(n3_gw, n3)
        n4_gw_only = np.where(four_det_obs)[0]
        n4_gw = len(n4_gw_only)
        n4_good = np.where(four_det_obs & em_bool & has_ejecta_bool)[0]
        n4 = len(n4_good)
        # sanity check
        assert n4_gw >= n4, "GW events ({}) less than EM follow events ({})".format(n4_gw, n4)
        return dist[n2_good].value.tolist(), tot_mass[n2_good].tolist(),\
            dist[n3_good].value.tolist(), tot_mass[n3_good].tolist(),\
            dist[n4_good].value.tolist(), tot_mass[n4_good].tolist(),\
            obsmag[n2_good].tolist(), obsmag[n3_good].tolist(),\
            obsmag[n4_good].tolist(), n2, n3, n4

    with schwimmbad.SerialPool() as pool:
        values = list(pool.map(dotry, range(n_try)))
    print("Finshed computation, plotting...")
    data_dump = dict()
    n_detect2 = []
    n_detect3 = []
    n_detect4 = []
    dist_detect2 = []
    mass_detect2 = []
    dist_detect3 = []
    mass_detect3 = []
    dist_detect4 = []
    mass_detect4 = []
    rmah_detect2 = []
    rmah_detect3 = []
    rmah_detect4 = []
    for idx, val in enumerate(values):
        d2, m2, d3, m3, d4, m4, h2, h3, h4, n2, n3, n4, *_ = val
        if n2 >= 0:
            n_detect2.append(n2)
            if n2 > 0:
                dist_detect2 += d2
                mass_detect2 += m2
                rmah_detect2 += h2
        if n3 >= 0:
            n_detect3.append(n3)
            if n3 > 0:
                dist_detect3 += d3
                mass_detect3 += m3
                rmah_detect3 += h3
        if n4 >= 0:
            n_detect4.append(n4)
            if n4 > 0:
                dist_detect4 += d4
                mass_detect4 += m4
                rmah_detect4 += h4
        data_dump[f"{idx}"] = {"d2": d2, "m2": m2, "d3": d3,
                               "m3": m3, "d4": d4, "m4": m4,
                               "h2": h2, "h3": h3, "h4": h4,
                               "n2": n2, "n3": n3, "n4": n4}
    with open(f"hst/data-dump-hst-29-30-31-{args.mass_distrib}.pickle", "wb") as f:
        pickle.dump(data_dump, f)

    n_detect2 = np.array(n_detect2)
    n_detect3 = np.array(n_detect3)
    n_detect4 = np.array(n_detect4)

    #print(f"2 det: {n_detect2};\n3 det: {n_detect3};\n4 det: {n_detect4}")
    #print(f"2 det mean: {np.mean(n_detect2)};\n3 det mean: {np.mean(n_detect3)};\n4 det mean: {np.mean(n_detect4)}")
    fig_kw = {'figsize':(9.5/0.7, 3.5)}
    fig, axes = plt.subplots(nrows=1, ncols=3, **fig_kw)

    #ebins = np.logspace(0, 1.53, 10)
    #ebins = np.insert(ebins, 0, 0)
    ebins = np.arange(32)
    norm = np.sum(n_detect3)/np.sum(n_detect2)
    vals, _, _ = axes[0].hist(n_detect2, histtype='stepfilled',
                              bins=ebins, color='C0', alpha=0.3,
                              density=True, zorder=0)

    axes[0].hist(n_detect2, histtype='step',
                 bins=ebins, color='C0', lw=3, density=True, zorder=3)
    bin_centers = (ebins[0:-1] + ebins[1:])/2.
    mean_nevents = np.mean(n_detect2)
    five_percent, ninetyfive_percent = np.percentile(n_detect2, 5), np.percentile(n_detect2, 95)
    axes[0].axvline(round(mean_nevents), color='C0', linestyle='--', lw=2,
                    label=r'$\langle N\rangle = %d ;~ N_{95} = %d$' % (round(mean_nevents), ninetyfive_percent))
    axes[0].axvline(ninetyfive_percent, color='C0',
                    linestyle='dotted', lw=1)

    #vals, bins = np.histogram(n_detect3, bins=ebins, density=True)
    mean_nevents = np.mean(n_detect3)
    #vals*=norm
    #test = dict(zip(ebins, vals))
    #print(ebins, vals)
    #print("Test")
    #print(test)
    axes[0].hist(n_detect3, density=True, histtype='stepfilled', color='C1', alpha=0.5, bins=ebins, zorder=1)
    axes[0].hist(n_detect3, density=True, histtype='step', color='C1', lw=3, bins=ebins, zorder=2)
    #axes[0].hist(list(test.keys()), weights=list(test.values()), histtype='stepfilled', color='C1', alpha=0.5, bins=ebins, zorder=1)
    #axes[0].hist(list(test.keys()), weights=list(test.values()), histtype='step', color='C1', lw=3, bins=ebins, zorder=2)
    five_percent, ninetyfive_percent = np.percentile(n_detect3, 5), np.percentile(n_detect3, 95)
    axes[0].axvline(round(mean_nevents), color='C1', linestyle='--', lw=2,
                    label=r'$\langle N\rangle = %d ;~ N_{95} = %d$' % (round(mean_nevents), ninetyfive_percent))
    axes[0].axvline(ninetyfive_percent, color='C1',
                    linestyle='dotted', lw=1)
    #vals, bins = np.histogram(n_detect4, bins=ebins, density=True)
    mean_nevents = np.mean(n_detect4)
    #vals*=norm
    #test = dict(zip(ebins, vals))
    axes[0].hist(n_detect4, density=True, histtype='stepfilled', color='C2', alpha=0.5, bins=ebins, zorder=1)
    axes[0].hist(n_detect4, density=True, histtype='step', color='C2', lw=3, bins=ebins, zorder=2)
    five_percent, ninetyfive_percent = np.percentile(n_detect4, 5), np.percentile(n_detect4, 95)
    axes[0].axvline(round(mean_nevents), color='C2', linestyle='--', lw=2,
                    label=r'$\langle N \rangle = %d ;~ N_{95} = %d$' % (round(mean_nevents), ninetyfive_percent))
    axes[0].axvline(ninetyfive_percent, color='C2',
                    linestyle='dotted', lw=1)
    axes[0].legend(frameon=False, fontsize='small', loc='upper right')
    #axes[0].set_xscale('log')
    axes[0].set_yscale('log')
    axes[0].set_xlim((0., 31))
    #axes[0].set_ylim((1e-2, 1))
    #######################################################
    ### print out probabilities of greater than 1 event ###
    #######################################################
    print("P(N > 1 event detected)")
    print("For two detector", np.sum(n_detect2 > 1)/len(n_detect2))
    print("For three detector", np.sum(n_detect3 > 1)/len(n_detect2))
    print("For four detector", np.sum(n_detect4 > 1)/len(n_detect2))
    # save number of detections
    with open(f'hst/n-events-hst-29-30-31-{args.mass_distrib}.pickle', 'wb') as f:
        res = dict(n_detect2=n_detect2, n_detect3=n_detect3, n_detect4=n_detect4,
                   dist_detect2=dist_detect2, dist_detect3=dist_detect3, dist_detect4=dist_detect4,
                   mass_detect2=mass_detect2, mass_detect3=mass_detect3, mass_detect4=mass_detect4,
                   rmah_detect2=rmah_detect2, rmah_detect3=rmah_detect3, rmah_detect4=rmah_detect4)
        pickle.dump(res, f)
    dist_range = np.arange(0, 400., 0.1)
    patches = list()
    legend_text = list()
    try:
        kde = spstat.gaussian_kde(dist_detect2, bw_method='scott')
        pdist = kde(dist_range)
        axes[1].plot(dist_range, pdist, color='C0', linestyle='-', lw=3, zorder=4)
        patch1 = axes[1].fill_between(dist_range, np.zeros(len(dist_range)), pdist, color='C0', alpha=0.3, zorder=0)
        patches.append(patch1)
        legend_text.append('2 Detector Events')
        mean_dist = np.mean(dist_detect2)
        axes[1].axvline(mean_dist, color='C0', linestyle='--', lw=1.5, zorder=6, label=r'$\langle D \rangle = {:.0f}$ Mpc'.format(mean_dist))
        ind0_40 = dist_range <= 40.
        ind40_80 = (dist_range <= 100.) & (dist_range > 40.)
        ind80_160 = (dist_range <= 160.) & (dist_range > 100.)
        p0_40 = scinteg.trapz(pdist[ind0_40], dist_range[ind0_40])
        p40_80 = scinteg.trapz(pdist[ind40_80], dist_range[ind40_80])
        p80_160 = scinteg.trapz(pdist[ind80_160], dist_range[ind80_160])
        print(p0_40*5, p40_80*5, p80_160*5)
    except ValueError:
        print("Could not create KDE since no 2-det detection")

    try:
        kde = spstat.gaussian_kde(dist_detect3, bw_method='scott')
        pdist = kde(dist_range)
        axes[1].plot(dist_range, pdist, color='C1', linestyle='-', lw=3, zorder=2)
        patch2 = axes[1].fill_between(dist_range, np.zeros(len(dist_range)), pdist, color='C1', alpha=0.5, zorder=1)
        patches.append(patch2)
        legend_text.append('3 Detector Events')
        mean_dist = np.mean(dist_detect3)
        axes[1].axvline(mean_dist, color='C1', linestyle='--', lw=1.5, zorder=6, label=r'$\langle D \rangle = {:.0f}$ Mpc'.format(mean_dist))
        axes[1].legend(frameon=False, fontsize='small')
    except ValueError:
        print("Could not create KDE since no 3-det detection")

    try:
        kde = spstat.gaussian_kde(dist_detect4, bw_method='scott')
        pdist = kde(dist_range)
        mean_dist = np.mean(dist_detect4)
        axes[1].plot(dist_range, pdist, color='C2', linestyle='-', lw=3, zorder=2)
        axes[1].axvline(mean_dist, color='C2', linestyle='--', lw=1.5, zorder=6, label=r'$\langle D \rangle = {:.0f}$ Mpc'.format(mean_dist))
        patch3 = axes[1].fill_between(dist_range, np.zeros(len(dist_range)), pdist, color='C2', alpha=0.5, zorder=1)
        patches.append(patch3)
        legend_text.append('4 Detector Events')
        axes[1].legend(frameon=False, fontsize='small')
    except ValueError:
        print("Could not create KDE since no 4-det detection")

    h_range = np.arange(15, 23, 0.1)
    kde = spstat.gaussian_kde(rmah_detect2, bw_method='scott')
    ph = kde(h_range)
    axes[2].plot(h_range, ph, color='C0', linestyle='-', lw=3, zorder=4)
    axes[2].fill_between(h_range, np.zeros(len(h_range)), ph, color='C0', alpha=0.3, zorder=0)
    mean_h = np.mean(rmah_detect2)
    axes[2].axvline(mean_h, color='C0', linestyle='--', lw=1.5, zorder=6, label=r'$\langle H \rangle = {:.1f}$ mag'.format(mean_h))

    kde = spstat.gaussian_kde(rmah_detect3, bw_method='scott')
    ph = kde(h_range)
    axes[2].plot(h_range, ph, color='C1', linestyle='-', lw=3, zorder=2)
    axes[2].fill_between(h_range, np.zeros(len(h_range)), ph, color='C1', alpha=0.5, zorder=1)
    mean_h = np.mean(rmah_detect3)
    axes[2].axvline(mean_h, color='C1', linestyle='--', lw=1.5, zorder=6, label=r'$\langle H \rangle = {:.1f}$ mag'.format(mean_h))
    axes[2].legend(frameon=False, fontsize='small')

    try:
        kde = spstat.gaussian_kde(rmah_detect4, bw_method='scott')
        ph = kde(h_range)
        axes[2].plot(h_range, ph, color='C2', linestyle='-', lw=3, zorder=2)
        axes[2].fill_between(h_range, np.zeros(len(h_range)), ph, color='C2', alpha=0.5, zorder=1)
        mean_h = np.mean(rmah_detect4)
        axes[2].axvline(mean_h, color='C2', linestyle='--', lw=1.5, zorder=6, label=r'$\langle H \rangle = {:.1f}$ mag'.format(mean_h))
        axes[2].legend(frameon=False, fontsize='small')
    except ValueError:
        print("Could not create KDE for h-mag since no 4 detector events found")

    axes[1].set_xlabel('Distance ($D$, Mpc)', fontsize='large')
    axes[1].set_ylabel('$P(D)$', fontsize='large')
    axes[0].set_xlabel('Number of Events ($N$)', fontsize='large')
    axes[0].set_ylabel('$P(N)$', fontsize='large')

    axes[2].set_xlabel('Apparent F475W ($g$, AB mag)', fontsize='large')
    axes[2].set_ylabel('$P(H)$', fontsize='large')
    axes[0].set_xlim(0, ebins.max())

    ymin, ymax = axes[1].get_ylim()
    axes[1].set_ylim(0, ymax)
    ymin, ymax = axes[2].get_ylim()
    axes[2].set_ylim(0, ymax)

    fig.legend(patches, legend_text,
               'upper center', frameon=False, ncol=3, fontsize='medium')
    fig.tight_layout(rect=[0, 0, 1, 0.97], pad=1.05)
    fig.savefig(f'hst/hst_gw_detect_hst_29_30_31_{args.mass_distrib}.pdf')
    plt.show()
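A closing note on why SerialPool is the right choice above: dotry closes over main()'s locals, and closures cannot be pickled, so MultiPool would fail. Parallelizing would require a top-level worker that takes its context explicitly; a sketch, where dotry_worker and its config dict are hypothetical stand-ins:

from functools import partial

import schwimmbad

def dotry_worker(n, config):
    # stand-in body; the real worker would draw n_events etc. from config
    return n * config['rate']

if __name__ == '__main__':
    worker = partial(dotry_worker, config={'rate': 2.0})
    with schwimmbad.MultiPool() as pool:
        print(list(pool.map(worker, range(8))))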