Example 1
import warnings

from astropy import units

from gwpy.utils.mp import multiprocess_with_queues


def resample_timeseries_dict(tsd, nproc=1, **sampling_dict):
    """Resample a `TimeSeriesDict`

    Parameters
    ----------
    tsd : `~gwpy.timeseries.TimeSeriesDict`
        the input dict to resample

    nproc : `int`, optional
        the number of parallel processes to use

    **sampling_dict
        ``<name>=<sampling frequency>`` pairs defining new
        sampling frequencies for keys of ``tsd``

    Returns
    -------
    resampled : `~gwpy.timeseries.TimeSeriesDict`
        a new dict with the same keys as ``tsd``; each value is resampled
        if its key appears in ``sampling_dict``, otherwise the original
        series is kept unchanged
    """

    # define resample function (must take single argument)
    def _resample(args):
        ts, fs = args
        if fs and units.Quantity(fs, "Hz") == ts.sample_rate:
            warnings.warn(
                "requested resample rate for {0} matches native rate ({1}), "
                "please update configuration".format(ts.name, ts.sample_rate),
                UserWarning,
            )
        elif fs:
            return ts.resample(fs, ftype='fir', window='hamming')
        return ts

    # group timeseries with new sampling frequencies
    inputs = ((ts, sampling_dict.get(name)) for name, ts in tsd.items())

    # apply resampling
    resampled = multiprocess_with_queues(nproc, _resample, inputs)

    # map back to original dict keys
    return dict(zip(list(tsd), resampled))
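
A minimal usage sketch for the function above; the channel names and sample rates here are hypothetical:

import numpy
from gwpy.timeseries import TimeSeries, TimeSeriesDict

# build a dict of two hypothetical channels sampled at 256 Hz
tsd = TimeSeriesDict()
for name in ('X1:TEST-A', 'X1:TEST-B'):
    tsd[name] = TimeSeries(numpy.random.randn(1024),
                           sample_rate=256, name=name)

# downsample only X1:TEST-A; X1:TEST-B is passed through unchanged
out = resample_timeseries_dict(tsd, nproc=2, **{'X1:TEST-A': 128})
print(out['X1:TEST-A'].sample_rate)  # 128.0 Hz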
Example 2
    def classify(self, path_to_cnn, **kwargs):
        """Classify triggers in this table

        Parameters:

            path_to_cnn:
                file name of the CNN you would like to use

            **kwargs:
                nproc : number of event times to process in parallel

        Returns:
            `Events` table
        """
        if 'event_time' not in self.keys():
            raise ValueError("This method only works if you have defined "
                             "a column event_time for your "
                             "Event Trigger Generator.")

        # Parse keyword arguments
        config = kwargs.pop('config', utils.GravitySpyConfigFile())
        plot_directory = kwargs.pop('plot_directory', 'plots')
        timeseries = kwargs.pop('timeseries', None)
        source = kwargs.pop('source', None)
        channel_name = kwargs.pop('channel_name', None)
        frametype = kwargs.pop('frametype', None)
        verbose = kwargs.pop('verbose', False)
        # maximum number of parallel processes
        nproc = kwargs.pop('nproc', 1)

        # make a list of event times
        inputs = zip(self['event_time'], self['ifo'],
                     self['gravityspy_id'])

        inputs = ((etime, ifo, gid, config, plot_directory,
                   timeseries, source, channel_name, frametype, nproc, verbose)
                  for etime, ifo, gid in inputs)

        # make q_scans
        output = mp_utils.multiprocess_with_queues(nproc,
                                                   _make_single_qscan,
                                                   inputs)

        qvalues = []
        # raise exceptions (from multiprocessing, single process raises inline)
        for f, x in output:
            if isinstance(x, Exception):
                x.args = ('Failed to make q scan at time %s: %s' % (f,
                                                                    str(x)),)
                raise x
            else:
                qvalues.append(x)

        self['q_value'] = qvalues

        results = utils.label_q_scans(plot_directory=plot_directory,
                                      path_to_cnn=path_to_cnn,
                                      verbose=verbose,
                                      **kwargs)

        results = results.to_pandas()
        # prepend the plot directory to each of the image-filename columns
        for column in ['Filename1', 'Filename2', 'Filename3', 'Filename4']:
            results[column] = results[column].apply(
                lambda fname: os.path.join(plot_directory, fname))

        results = Events.from_pandas(results.merge(self.to_pandas(),
                                                   on=['gravityspy_id']))
        return results
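
A hedged usage sketch for classify; the variable name and model path are hypothetical, and `events` is assumed to be an instance of this table class with 'event_time', 'ifo' and 'gravityspy_id' columns already populated:

results = events.classify('models/multiview_classifier.h5',
                          plot_directory='plots',
                          nproc=4,
                          verbose=True)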
Example 3
    def evolve(cls, initialbinarytable, BSEDict, **kwargs):
        """After setting a number of initial conditions
        we evolve the system.

        Parameters
        ----------
        initialbinarytable : DataFrame
            table of initial conditions for each binary
        BSEDict : dict
            BSE flags and parameters applied to every binary
        nproc : `int`, optional, default: 1
            number of CPUs to use to evolve systems in parallel
        idx : `int`, optional, default: 0
            initial index of the bcm/bpp arrays

        Returns
        -------
        output_bpp : DataFrame
            Evolutionary history of each binary
        output_bcm : DataFrame
            Final state of each binary
        initialbinarytable : DataFrame
            Initial conditions for each binary
        """
        idx = kwargs.pop('idx', 0)
        nproc = min(kwargs.pop('nproc', 1), len(initialbinarytable))

        # copy every BSE flag from BSEDict into the initial binary table
        # (the order of assignment fixes the column order used below)
        for flag in ('neta', 'bwind', 'hewind', 'alpha1', 'lambdaf',
                     'ceflag', 'tflag', 'ifflag', 'wdflag', 'bhflag',
                     'nsflag', 'mxns', 'pts1', 'pts2', 'pts3', 'sigma',
                     'beta', 'xi', 'acc2', 'epsnov', 'eddfac', 'gamma',
                     'bconst', 'CK', 'merger', 'windflag'):
            initialbinarytable[flag] = BSEDict[flag]
        initialbinarytable['dtp'] = initialbinarytable['tphysf']
        initialbinarytable['bin_num'] = np.arange(
            idx, idx + len(initialbinarytable))
        initialbinarytable['randomseed'] = np.random.randint(
            1, 1000000, size=len(initialbinarytable))

        initial_conditions = np.array(initialbinarytable)

        # define multiprocessing method
        def _evolve_single_system(f):
            try:
                # inputs: kstar, mass, orbital period (days), eccentricity,
                # metallicity, evolution time (Myr), plus the BSE flags
                [tmp1,
                 tmp2] = _evolvebin.evolv2(f[0], f[1], f[2], f[3], f[4], f[5],
                                           f[6], f[7], f[8], f[9], f[10],
                                           f[11], f[12], f[13], f[14], f[15],
                                           f[16], f[17], f[18], f[19], f[20],
                                           f[21], f[22], f[23], f[24], f[25],
                                           f[26], f[27], f[28], f[29], f[30],
                                           f[31], f[32], f[33], f[34], f[36])

                bpp_tmp = tmp1[np.argwhere(tmp1[:, 0] > 0), :].squeeze(1)
                bcm_tmp = tmp2[np.argwhere(tmp2[:, 0] > 1), :].squeeze(1)

                bpp_tmp = pd.DataFrame(bpp_tmp,
                                       columns=bpp_columns,
                                       index=[int(f[35])] * len(bpp_tmp))
                bpp_tmp['bin_num'] = int(f[35])
                # set_index returns a new frame, so keep the result
                bpp_tmp = bpp_tmp.set_index('bin_num', drop=False)

                bcm_tmp = pd.DataFrame(bcm_tmp,
                                       columns=bcm_columns,
                                       index=[int(f[35])] * len(bcm_tmp))
                bcm_tmp['bin_num'] = int(f[35])
                bcm_tmp = bcm_tmp.set_index('bin_num', drop=False)

                return f, bpp_tmp, bcm_tmp

            except Exception as e:
                if nproc == 1:
                    raise
                else:
                    # return the exception in both output slots so the
                    # three-way unpacking below still works
                    return f, e, e

        # evolve systems
        output = mp_utils.multiprocess_with_queues(nproc,
                                                   _evolve_single_system,
                                                   initial_conditions,
                                                   raise_exceptions=False)

        # raise exceptions (from multiprocessing, single process raises inline)
        for f, x, y in output:
            if isinstance(x, Exception):
                x.args = ('Failed to evolve %s: %s' % (f, str(x)), )
                raise x
            if isinstance(y, Exception):
                y.args = ('Failed to evolve %s: %s' % (f, str(y)), )
                raise y

        # concatenate the per-binary results into single bpp/bcm tables
        output_bpp = pd.concat([x for f, x, y in output])
        output_bcm = pd.concat([y for f, x, y in output])

        initialbinarytable = initialbinarytable.set_index('bin_num',
                                                          drop=False)
        return output_bpp, output_bcm, initialbinarytable
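
A hedged usage sketch; the values are hypothetical, the eight input columns are taken from the explicit column ordering used by the later version of this method, and the enclosing class is assumed to expose evolve as a classmethod (named here, hypothetically, Evolve):

import pandas as pd

initialbinarytable = pd.DataFrame([{
    'kstar_1': 1, 'kstar_2': 1,
    'mass1_binary': 10.0, 'mass2_binary': 8.0,
    'porb': 1000.0, 'ecc': 0.1,
    'metallicity': 0.02, 'tphysf': 13700.0,
}])

# BSEDict must supply every flag copied into the table above
# ('neta', 'bwind', ..., 'windflag')
bpp, bcm, initialbinarytable = Evolve.evolve(initialbinarytable, BSEDict,
                                             nproc=1)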
Example 4
File: data.py (project: gwpy/gwsumm)
    def process_state(self, state, nds=None, nproc=1,
                      config=GWSummConfigParser(), datacache=None,
                      trigcache=None, segmentcache=None, segdb_error='raise',
                      datafind_error='raise'):
        """Process data for this tab in a given state

        Parameters
        ----------
        state : `~gwsumm.state.SummaryState`
            the state to process; pass `None` to process ALLSTATE with
            no plots, which is useful for loading all data for other states
        nds : `bool`, optional
            `True` to use NDS to read data, otherwise read from frames;
            use `None` to read from frames if possible, falling back
            to NDS.
        nproc : `int`, optional
            number of parallel cores to use when reading data and making
            plots, default: ``1``
        config : `ConfigParser`, optional
            configuration for this analysis
        datacache : `~glue.lal.Cache`, optional
            `Cache` of files from which to read time-series data
        trigcache : `~glue.lal.Cache`, optional
            `Cache` of files from which to read event triggers
        segmentcache : `~glue.lal.Cache`, optional
            `Cache` of files from which to read segments
        segdb_error : `str`, optional
            if ``'raise'``, raise exceptions when the segment database
            reports errors; if ``'warn'``, print warnings but continue;
            if ``'ignore'``, discard them completely and carry on
        datafind_error : `str`, optional
            how to handle datafind errors; same options as ``segdb_error``
        """
        if state:
            all_data = False
        else:
            all_data = True
            state = get_state(ALLSTATE)

        # flag those plots that were already written by this process
        for p in self.plots + self.subplots:
            if p.outputfile in globalv.WRITTEN_PLOTS:
                p.new = False

        # --------------------------------------------------------------------
        # process time-series

        # find channels that need a TimeSeries
        tschannels = self.get_channels('timeseries',
                                       all_data=all_data, read=True)
        if len(tschannels):
            vprint("    %d channels identified for TimeSeries\n"
                   % len(tschannels))
            get_timeseries_dict(tschannels, state, config=config, nds=nds,
                                nproc=nproc, cache=datacache,
                                datafind_error=datafind_error, return_=False)
            vprint("    All time-series data loaded\n")

        # find channels that need a StateVector
        svchannels = set(self.get_channels('statevector', all_data=all_data,
                                           read=True))
        odcchannels = self.get_channels('odc', all_data=all_data, read=True)
        svchannels.update(odcchannels)
        svchannels = list(svchannels)
        if len(svchannels):
            vprint("    %d channels identified as StateVectors\n"
                   % (len(svchannels) - len(odcchannels)))
            get_timeseries_dict(svchannels, state, config=config, nds=nds,
                                nproc=nproc, statevector=True,
                                cache=datacache, return_=False,
                                datafind_error=datafind_error, dtype='uint32')
            vprint("    All state-vector data loaded\n")

        # --------------------------------------------------------------------
        # process spectrograms

        # find FFT parameters
        try:
            fftparams = dict(config.nditems('fft'))
        except NoSectionError:
            fftparams = {}
        # evaluate python-literal values (numbers, tuples), leaving plain
        # strings untouched
        for key, val in fftparams.items():
            try:
                fftparams[key] = eval(val)
            except (NameError, SyntaxError):
                pass

        sgchannels = self.get_channels('spectrogram', 'spectrum',
                                       all_data=all_data, read=True)
        raychannels = self.get_channels('rayleigh-spectrogram',
                                        'rayleigh-spectrum',
                                        all_data=all_data, read=True)
        # for coherence spectrograms, we need all pairs of channels,
        # not just the unique ones
        csgchannels = self.get_channels('coherence-spectrogram',
                                        all_data=all_data, read=True,
                                        unique=False, state=state)

        # pad spectrogram segments to include final time bin
        specsegs = SegmentList(state.active)
        specchannels = set.union(sgchannels, raychannels, csgchannels)
        if specchannels and specsegs and specsegs[-1][1] == self.end:
            stride = max(filter(
                lambda x: x is not None,
                (get_fftparams(c, **fftparams).stride for c in specchannels),
            ))
            specsegs[-1] = Segment(specsegs[-1][0], self.end+stride)

        if len(sgchannels):
            vprint("    %d channels identified for Spectrogram\n"
                   % len(sgchannels))

            get_spectrograms(sgchannels, specsegs, config=config, nds=nds,
                             nproc=nproc, return_=False,
                             cache=datacache, datafind_error=datafind_error,
                             **fftparams)

        if len(raychannels):
            fp2 = fftparams.copy()
            fp2['method'] = fp2['format'] = 'rayleigh'
            get_spectrograms(raychannels, specsegs, config=config,
                             return_=False, nproc=nproc, **fp2)

        if len(csgchannels):
            if (len(csgchannels) % 2 != 0):
                raise ValueError("Error processing coherence spectrograms: "
                                 "you must supply exactly 2 channels for "
                                 "each spectrogram.")
            vprint("    %d channel pairs identified for Coherence "
                   "Spectrogram\n" % (len(csgchannels)/2))
            fp2 = fftparams.copy()
            fp2['method'] = 'welch'
            get_coherence_spectrograms(
                csgchannels, specsegs, config=config, nds=nds,
                nproc=nproc, return_=False, cache=datacache,
                datafind_error=datafind_error, **fp2)

        # --------------------------------------------------------------------
        # process spectra

        for channel in self.get_channels('spectrum', all_data=all_data,
                                         read=True):
            get_spectrum(channel, state, config=config, return_=False,
                         query=False, **fftparams)

        for channel in self.get_channels(
                'rayleigh-spectrum', all_data=all_data, read=True):
            fp2 = fftparams.copy()
            fp2['method'] = fp2['format'] = 'rayleigh'
            get_spectrum(channel, state, config=config, return_=False, **fp2)

        # --------------------------------------------------------------------
        # process segments

        # find flags that need a DataQualityFlag
        dqflags = set(self.get_flags('segments', all_data=all_data))
        dqflags.update(self.get_flags('timeseries', all_data=all_data,
                                      type='time-volume'))
        dqflags.update(self.get_flags('spectrogram', all_data=all_data,
                                      type='strain-time-volume'))
        if len(dqflags):
            vprint("    %d data-quality flags identified for segments\n"
                   % len(dqflags))
            get_segments(dqflags, state, config=config,
                         segdb_error=segdb_error, cache=segmentcache)

        # --------------------------------------------------------------------
        # process triggers

        for etg, channel in self.get_triggers('triggers',
                                              'trigger-timeseries',
                                              'trigger-rate',
                                              'trigger-histogram',
                                              all_data=all_data):
            get_triggers(channel, etg, state.active, config=config,
                         cache=trigcache, nproc=nproc, return_=False)

        # --------------------------------------------------------------------
        # make plots

        if all_data or self.noplots:
            vprint("    Done.\n")
            return

        # filter out plots that aren't for this state
        new_plots = [p for p in self.plots + self.subplots if p.new and
                     (p.state is None or p.state.name == state.name)]

        # separate plots into serial and parallel groups
        if int(nproc) <= 1 or not rcParams['text.usetex']:
            serial = new_plots
            parallel = []
        else:
            serial = [p for p in new_plots if not p._threadsafe]
            parallel = [p for p in new_plots if p._threadsafe]

        # process serial plots
        if serial:
            vprint("    Executing %d plots in serial:\n" % len(serial))
            multiprocess_with_queues(1, lambda p: p.process(), serial)

        # process parallel plots
        if parallel:
            nproc = min(len(parallel), nproc)
            vprint("    Executing %d plots in %d processes:\n"
                   % (len(parallel), nproc))
            multiprocess_with_queues(nproc, lambda p: p.process(), parallel)

        # record that we have written all of these plots
        globalv.WRITTEN_PLOTS.extend(p.outputfile for p in serial + parallel)

        vprint('Done.\n')
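
A hedged calling sketch for process_state; both `tab` and `state` are assumed to be created by the surrounding gwsumm pipeline and are not constructed here:

# process one state with four parallel workers, downgrading segment
# database and datafind failures to warnings
tab.process_state(state,
                  nproc=4,
                  segdb_error='warn',
                  datafind_error='warn')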
Example 5
    def evolve(cls, initialbinarytable, BSEDict, **kwargs):
        """After setting a number of initial conditions
        we evolve the system.

        Parameters
        ----------
        initialbinarytable : DataFrame
            table of initial conditions for each binary
        BSEDict : dict
            BSE flags and parameters applied to any binary that does
            not already define them
        nproc : `int`, optional, default: 1
            number of CPUs to use to evolve systems in parallel
        idx : `int`, optional, default: 0
            initial index of the bcm/bpp arrays
        dtp : `float`, optional, default: tphysf
            timestep size in Myr for bcm output where tphysf
            is total evolution time in Myr

        Returns
        -------
        output_bpp : DataFrame
            Evolutionary history of each binary
        output_bcm : DataFrame
            Final state of each binary
        initialbinarytable : DataFrame
            Initial conditions for each binary
        """
        idx = kwargs.pop('idx', 0)
        nproc = min(kwargs.pop('nproc', 1), len(initialbinarytable))

        # fill in any BSE flag not already present in the initial binary
        # table, so that pre-set columns act as per-binary overrides
        for flag in ('neta', 'bwind', 'hewind', 'alpha1', 'lambdaf',
                     'cekickflag', 'cemergeflag', 'cehestarflag', 'ceflag',
                     'tflag', 'ifflag', 'wdflag', 'ppsn', 'bhflag', 'nsflag',
                     'mxns', 'pts1', 'pts2', 'pts3', 'ecsnp', 'ecsn_mlow',
                     'aic', 'sigma', 'sigmadiv', 'bhsigmafrac',
                     'polar_kick_angle', 'beta', 'xi', 'acc2', 'epsnov',
                     'eddfac', 'gamma', 'bconst', 'ck', 'merger', 'windflag'):
            if flag not in initialbinarytable.keys():
                initialbinarytable[flag] = BSEDict[flag]
        if 'dtp' not in initialbinarytable.keys():
            initialbinarytable['dtp'] = kwargs.pop(
                'dtp', initialbinarytable['tphysf'])
        if 'randomseed' not in initialbinarytable.keys():
            initialbinarytable['randomseed'] = np.random.randint(
                1, 1000000, size=len(initialbinarytable))
        if 'bin_num' not in initialbinarytable.keys():
            initialbinarytable['bin_num'] = np.arange(
                idx, idx + len(initialbinarytable))

        natal_kick_columns = [
            'SNkick_1', 'SNkick_2', 'phi_1', 'phi_2', 'theta_1', 'theta_2'
        ]
        if pd.Series(natal_kick_columns).isin(initialbinarytable.keys()).all(
        ) and 'natal_kick_array' not in initialbinarytable.keys():
            initialbinarytable['natal_kick_array'] = initialbinarytable[
                natal_kick_columns].values.tolist()
        if 'natal_kick_array' not in initialbinarytable.keys():
            initialbinarytable['natal_kick_array'] = [
                BSEDict['natal_kick_array']
            ] * len(initialbinarytable)
            # use a fresh loop variable: `idx` already holds the bin_num
            # offset popped from the keyword arguments above
            for i, column_name in enumerate(natal_kick_columns):
                initialbinarytable.loc[:, column_name] = pd.Series(
                    [BSEDict['natal_kick_array'][i]] *
                    len(initialbinarytable),
                    index=initialbinarytable.index,
                    name=column_name)

        qcrit_columns = ['qcrit_{0}'.format(kstar) for kstar in range(0, 16)]
        if pd.Series(qcrit_columns).isin(initialbinarytable.keys()).all(
        ) and 'qcrit_array' not in initialbinarytable.keys():
            initialbinarytable['qcrit_array'] = initialbinarytable[
                qcrit_columns].values.tolist()

        if 'qcrit_array' not in initialbinarytable.keys():
            initialbinarytable['qcrit_array'] = [BSEDict['qcrit_array']
                                                 ] * len(initialbinarytable)
            for kstar in range(0, 16):
                initialbinarytable.loc[:,
                                       'qcrit_{0}'.format(kstar)] = pd.Series(
                                           [BSEDict['qcrit_array'][kstar]] *
                                           len(initialbinarytable),
                                           index=initialbinarytable.index,
                                           name='qcrit_{0}'.format(kstar))

        # need to ensure that the order of variables is correct
        initial_conditions = initialbinarytable[[
            'kstar_1', 'kstar_2', 'mass1_binary', 'mass2_binary', 'porb',
            'ecc', 'metallicity', 'tphysf', 'neta', 'bwind', 'hewind',
            'alpha1', 'lambdaf', 'ceflag', 'tflag', 'ifflag', 'wdflag', 'ppsn',
            'bhflag', 'nsflag', 'cekickflag', 'cemergeflag', 'cehestarflag',
            'mxns', 'pts1', 'pts2', 'pts3', 'ecsnp', 'ecsn_mlow', 'aic',
            'sigma', 'sigmadiv', 'bhsigmafrac', 'polar_kick_angle',
            'natal_kick_array', 'qcrit_array', 'beta', 'xi', 'acc2', 'epsnov',
            'eddfac', 'gamma', 'bconst', 'ck', 'merger', 'windflag', 'dtp',
            'randomseed', 'bin_num'
        ]].values

        initial_binary_table_column_names = [
            'kstar_1', 'kstar_2', 'mass1_binary', 'mass2_binary', 'porb',
            'ecc', 'metallicity', 'tphysf', 'neta', 'bwind', 'hewind',
            'alpha1', 'lambdaf', 'ceflag', 'tflag', 'ifflag', 'wdflag', 'ppsn',
            'bhflag', 'nsflag', 'cekickflag', 'cemergeflag', 'cehestarflag',
            'mxns', 'pts1', 'pts2', 'pts3', 'ecsnp', 'ecsn_mlow', 'aic',
            'sigma', 'sigmadiv', 'bhsigmafrac', 'polar_kick_angle', 'beta',
            'xi', 'acc2', 'epsnov', 'eddfac', 'gamma', 'bconst', 'ck',
            'merger', 'windflag', 'dtp', 'randomseed', 'bin_num'
        ]

        initial_binary_table_column_names.extend(natal_kick_columns)
        initial_binary_table_column_names.extend(qcrit_columns)

        initialbinarytable = initialbinarytable[
            initial_binary_table_column_names]

        # define multiprocessing method
        def _evolve_single_system(f):
            try:
                # inputs: kstar, mass, orbital period (days), eccentricity,
                # metallicity, evolution time (Myr), plus flags and arrays
                [bpp, bcm] = _evolvebin.evolv2(
                    f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9],
                    f[10], f[11], f[12], f[13], f[14], f[15], f[16], f[17],
                    f[18], f[19], f[20], f[21], f[22], f[23], f[24], f[25],
                    f[26], f[27], f[28], f[29], f[30], f[31], f[32], f[33],
                    f[34], f[35], f[36], f[37], f[38], f[39], f[40], f[41],
                    f[42], f[43], f[44], f[45], f[46], f[47])

                try:
                    bpp = bpp[:np.argwhere(bpp[:, 0] == -1)[0][0]]
                    bcm = bcm[:np.argwhere(bcm[:, 0] == -1)[0][0]]
                except IndexError:
                    bpp = bpp[:np.argwhere(bpp[:, 0] > 0)[0][0]]
                    bcm = bcm[:np.argwhere(bcm[:, 0] > 0)[0][0]]
                    raise Warning(
                        'bpp overload: mass1 = {0}, mass2 = {1}, '
                        'porb = {2}, ecc = {3}, tphysf = {4}, '
                        'metallicity = {5}'.format(
                            f[2], f[3], f[4], f[5], f[7], f[6]))

                bpp_bin_numbers = np.atleast_2d(np.array([f[48]] * len(bpp))).T
                bcm_bin_numbers = np.atleast_2d(np.array([f[48]] * len(bcm))).T

                bpp = np.hstack((bpp, bpp_bin_numbers))
                bcm = np.hstack((bcm, bcm_bin_numbers))

                return f, bpp, bcm

            except Exception:
                # propagate the failure; with raise_exceptions=False the
                # multiprocessing wrapper returns it instead of raising
                raise

        # evolve systems
        output = mp_utils.multiprocess_with_queues(nproc,
                                                   _evolve_single_system,
                                                   initial_conditions,
                                                   raise_exceptions=False)

        # dtype=object: each row is an (input, bpp, bcm) tuple of
        # differently-shaped arrays
        output = np.array(output, dtype=object)
        bpp_arrays = np.vstack(output[:, 1])
        bcm_arrays = np.vstack(output[:, 2])

        bpp = pd.DataFrame(bpp_arrays,
                           columns=bpp_columns,
                           index=bpp_arrays[:, -1].astype(int))

        bcm = pd.DataFrame(bcm_arrays,
                           columns=bcm_columns,
                           index=bcm_arrays[:, -1].astype(int))

        bcm.merger_type = bcm.merger_type.astype(int).astype(str).apply(
            lambda x: x.zfill(4))
        bcm.bin_state = bcm.bin_state.astype(int)
        bpp.bin_num = bpp.bin_num.astype(int)
        bcm.bin_num = bcm.bin_num.astype(int)

        return bpp, bcm, initialbinarytable
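
Because this version only fills in flags that are missing from the table, a column set beforehand acts as a per-binary override. A hedged sketch, with hypothetical values and the same assumed enclosing class as above:

# override the common-envelope efficiency for these binaries only;
# evolve() will leave the existing column untouched
initialbinarytable['alpha1'] = 5.0
bpp, bcm, initialbinarytable = Evolve.evolve(initialbinarytable, BSEDict,
                                             nproc=2, dtp=1.0)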
Example 6
    def process_state(self,
                      state,
                      nds=None,
                      nproc=1,
                      config=GWSummConfigParser(),
                      datacache=None,
                      trigcache=None,
                      segmentcache=None,
                      segdb_error='raise',
                      datafind_error='raise'):
        """Process data for this tab in a given state

        Parameters
        ----------
        state : `~gwsumm.state.SummaryState`
            the state to process; pass `None` to process ALLSTATE with
            no plots, which is useful for loading all data for other states
        nds : `bool`, optional
            `True` to use NDS to read data, otherwise read from frames;
            use `None` to read from frames if possible, falling back
            to NDS.
        nproc : `int`, optional
            number of parallel cores to use when reading data and making
            plots, default: ``1``
        config : `ConfigParser`, optional
            configuration for this analysis
        datacache : `~glue.lal.Cache`, optional
            `Cache` of files from which to read time-series data
        trigcache : `~glue.lal.Cache`, optional
            `Cache` of files from which to read event triggers
        segmentcache : `~glue.lal.Cache`, optional
            `Cache` of files from which to read segments
        segdb_error : `str`, optional
            if ``'raise'``, raise exceptions when the segment database
            reports errors; if ``'warn'``, print warnings but continue;
            if ``'ignore'``, discard them completely and carry on
        datafind_error : `str`, optional
            how to handle datafind errors; same options as ``segdb_error``
        """
        if state:
            all_data = False
        else:
            all_data = True
            state = get_state(ALLSTATE)

        # flag those plots that were already written by this process
        for p in self.plots + self.subplots:
            if p.outputfile in globalv.WRITTEN_PLOTS:
                p.new = False

        # --------------------------------------------------------------------
        # process time-series

        # find channels that need a TimeSeries
        tschannels = self.get_channels('timeseries',
                                       all_data=all_data,
                                       read=True)
        if len(tschannels):
            vprint("    %d channels identified for TimeSeries\n" %
                   len(tschannels))
            get_timeseries_dict(tschannels,
                                state,
                                config=config,
                                nds=nds,
                                nproc=nproc,
                                cache=datacache,
                                datafind_error=datafind_error,
                                return_=False)
            vprint("    All time-series data loaded\n")

        # find channels that need a StateVector
        svchannels = set(
            self.get_channels('statevector', all_data=all_data, read=True))
        odcchannels = self.get_channels('odc', all_data=all_data, read=True)
        svchannels.update(odcchannels)
        svchannels = list(svchannels)
        if len(svchannels):
            vprint("    %d channels identified as StateVectors\n" %
                   (len(svchannels) - len(odcchannels)))
            get_timeseries_dict(svchannels,
                                state,
                                config=config,
                                nds=nds,
                                nproc=nproc,
                                statevector=True,
                                cache=datacache,
                                return_=False,
                                datafind_error=datafind_error,
                                dtype='uint32')
            vprint("    All state-vector data loaded\n")

        # --------------------------------------------------------------------
        # process spectrograms

        # find FFT parameters
        try:
            fftparams = dict(config.nditems('fft'))
        except NoSectionError:
            fftparams = {}
        # evaluate python-literal values (numbers, tuples), leaving plain
        # strings untouched
        for key, val in fftparams.items():
            try:
                fftparams[key] = eval(val)
            except (NameError, SyntaxError):
                pass

        sgchannels = self.get_channels('spectrogram',
                                       'spectrum',
                                       all_data=all_data,
                                       read=True)
        raychannels = self.get_channels('rayleigh-spectrogram',
                                        'rayleigh-spectrum',
                                        all_data=all_data,
                                        read=True)
        # for coherence spectrograms, we need all pairs of channels,
        # not just the unique ones
        csgchannels = self.get_channels('coherence-spectrogram',
                                        all_data=all_data,
                                        read=True,
                                        unique=False,
                                        state=state)

        # pad spectrogram segments to include final time bin
        specsegs = SegmentList(state.active)
        specchannels = set.union(sgchannels, raychannels, csgchannels)
        if specchannels and specsegs and specsegs[-1][1] == self.end:
            stride = max(
                filter(
                    lambda x: x is not None,
                    (get_fftparams(c, **fftparams).stride
                     for c in specchannels),
                ))
            specsegs[-1] = Segment(specsegs[-1][0], self.end + stride)

        if len(sgchannels):
            vprint("    %d channels identified for Spectrogram\n" %
                   len(sgchannels))

            get_spectrograms(sgchannels,
                             specsegs,
                             config=config,
                             nds=nds,
                             nproc=nproc,
                             return_=False,
                             cache=datacache,
                             datafind_error=datafind_error,
                             **fftparams)

        if len(raychannels):
            fp2 = fftparams.copy()
            fp2['method'] = fp2['format'] = 'rayleigh'
            get_spectrograms(raychannels,
                             specsegs,
                             config=config,
                             return_=False,
                             nproc=nproc,
                             **fp2)

        if len(csgchannels):
            if (len(csgchannels) % 2 != 0):
                raise ValueError("Error processing coherence spectrograms: "
                                 "you must supply exactly 2 channels for "
                                 "each spectrogram.")
            vprint("    %d channel pairs identified for Coherence "
                   "Spectrogram\n" % (len(csgchannels) / 2))
            fp2 = fftparams.copy()
            fp2['method'] = 'welch'
            get_coherence_spectrograms(csgchannels,
                                       specsegs,
                                       config=config,
                                       nds=nds,
                                       nproc=nproc,
                                       return_=False,
                                       cache=datacache,
                                       datafind_error=datafind_error,
                                       **fp2)

        # --------------------------------------------------------------------
        # process spectra

        for channel in self.get_channels('spectrum',
                                         all_data=all_data,
                                         read=True):
            get_spectrum(channel,
                         state,
                         config=config,
                         return_=False,
                         query=False,
                         **fftparams)

        for channel in self.get_channels('rayleigh-spectrum',
                                         all_data=all_data,
                                         read=True):
            fp2 = fftparams.copy()
            fp2['method'] = fp2['format'] = 'rayleigh'
            get_spectrum(channel, state, config=config, return_=False, **fp2)

        # --------------------------------------------------------------------
        # process segments

        # find flags that need a DataQualityFlag
        dqflags = set(self.get_flags('segments', all_data=all_data))
        dqflags.update(
            self.get_flags('timeseries', all_data=all_data,
                           type='time-volume'))
        dqflags.update(
            self.get_flags('spectrogram',
                           all_data=all_data,
                           type='strain-time-volume'))
        if len(dqflags):
            vprint("    %d data-quality flags identified for segments\n" %
                   len(dqflags))
            get_segments(dqflags,
                         state,
                         config=config,
                         segdb_error=segdb_error,
                         cache=segmentcache)

        # --------------------------------------------------------------------
        # process triggers

        for etg, channel in self.get_triggers('triggers',
                                              'trigger-timeseries',
                                              'trigger-rate',
                                              'trigger-histogram',
                                              all_data=all_data):
            get_triggers(channel,
                         etg,
                         state.active,
                         config=config,
                         cache=trigcache,
                         nproc=nproc,
                         return_=False)

        # --------------------------------------------------------------------
        # make plots

        if all_data or self.noplots:
            vprint("    Done.\n")
            return

        # filter out plots that aren't for this state
        new_plots = [
            p for p in self.plots + self.subplots
            if p.new and (p.state is None or p.state.name == state.name)
        ]

        # separate plots into serial and parallel groups
        if int(nproc) <= 1 or not rcParams['text.usetex']:
            serial = new_plots
            parallel = []
        else:
            serial = [p for p in new_plots if not p._threadsafe]
            parallel = [p for p in new_plots if p._threadsafe]

        # process serial plots
        if serial:
            vprint("    Executing %d plots in serial:\n" % len(serial))
            multiprocess_with_queues(1, lambda p: p.process(), serial)

        # process parallel plots
        if parallel:
            nproc = min(len(parallel), nproc)
            vprint("    Executing %d plots in %d processes:\n" %
                   (len(parallel), nproc))
            multiprocess_with_queues(nproc, lambda p: p.process(), parallel)

        # record that we have written all of these plots
        globalv.WRITTEN_PLOTS.extend(p.outputfile for p in serial + parallel)

        vprint('Done.\n')
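
Every example above funnels its work through multiprocess_with_queues and relies on getting one result per input, in input order (the zip back onto dict keys in Example 1 depends on exactly that). A minimal sketch of that contract, assuming the gwpy.utils.mp import path used in Example 1:

from gwpy.utils.mp import multiprocess_with_queues

def square(x):
    return x * x

# one result per input, returned in input order
print(multiprocess_with_queues(2, square, range(5)))
# [0, 1, 4, 9, 16]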