Esempio n. 1
0
def data_prep(st, data):
    """ Applies calibration, flags, and subtracts time mean for data.

    Calibration (calibration.apply_telcal) is applied only for non-simulated
    data and only when st.gainfile is set and names an existing file.
    Flagging is delegated to util.dataflag; the time mean is removed with
    util.meantsub when st.prefs.timesub == 'mean'.

    Returns the prepared data, or the input object unchanged when it
    contains no nonzero values.
    """

    # ** need to make this portable or at least reimplement in rfpipe
    if np.any(data):
        if st.metadata.datasource != 'sim':
            # os.path.exists(None) raises TypeError instead of returning
            # False, so an unset gainfile must be checked explicitly.
            if st.gainfile is not None and os.path.exists(st.gainfile):
                data = calibration.apply_telcal(st, data)
            else:
                logger.warning(
                    'Telcal file not found. No calibration being applied.')
        else:
            logger.info('Not applying telcal solutions for simulated data')

        # ** dataflag points to rtpipe for now.
        # changes memory in place, so need to force writability
        data = util.dataflag(st, np.require(data, requirements='W'))

        if st.prefs.timesub == 'mean':
            logger.info('Subtracting mean visibility in time.')
            data = util.meantsub(data)
        else:
            logger.info('No visibility subtraction done.')

        if st.prefs.savenoise:
            logger.warning("Saving of noise properties not implemented yet.")

    return data
Esempio n. 2
0
def prepare_data(sdmfile,
                 gainfile,
                 delta_l,
                 delta_m,
                 segment=0,
                 dm=0,
                 dt=1,
                 spws=None):
    """
    Read and prepare one segment of an SDM file around a candidate.

    Applies calibration, flagging, dedispersion and other data preparation
    steps from rfpipe, then phase-shifts the data to the location of the
    candidate.

    Parameters
    ----------
    sdmfile : SDM to read (scan 1 is used).
    gainfile : telcal gain file passed to the rfpipe state.
    delta_l, delta_m : offsets passed as dl/dm to phase_shift (presumably
        radians, as elsewhere in rfpipe -- TODO confirm).
    segment : segment index to read.
    dm : dispersion measure used to compute the dedispersion delay.
    dt : resampling factor passed to dedisperseresample.
    spws : optional spectral-window selection; applied only when truthy.

    Returns
    -------
    (data, st) : the dedispersed, phase-shifted data and the rfpipe state.
    """
    st = state.State(sdmfile=sdmfile,
                     sdmscan=1,
                     inprefs={
                         'gainfile': gainfile,
                         'workdir': '.',
                         'maxdm': 0,
                         'flaglist': []
                     },
                     showsummary=False)
    if spws:
        st.prefs.spw = spws

    data = source.read_segment(st, segment)

    takepol = [st.metadata.pols_orig.index(pol) for pol in st.pols]
    # Convert the original baseline array once, not once per baseline.
    blarr_orig = st.metadata.blarr_orig.tolist()
    takebls = [blarr_orig.index(list(bl)) for bl in st.blarr]
    datap = (np.require(data, requirements='W')
             .take(takepol, axis=3)
             .take(st.chans, axis=2)
             .take(takebls, axis=1))
    datap = source.prep_standard(st, segment, datap)
    datap = calibration.apply_telcal(st, datap)
    datap = flagging.flag_data(st, datap)

    delay = calc_delay(st.freq, st.freq.max(), dm, st.inttime)
    data_dmdt = dedisperseresample(datap, delay, dt)

    print(f'shape of data_dmdt is {data_dmdt.shape}')

    uvw = get_uvw_segment(st, segment)
    # The return value of phase_shift is not used; data_dmdt is presumably
    # modified in place (as util.phase_shift is used elsewhere) -- verify.
    phase_shift(data_dmdt, uvw=uvw, dl=delta_l, dm=delta_m)

    return data_dmdt, st
Esempio n. 3
0
def data_prep(st, segment, data, flagversion="latest"):
    """ Applies calibration, flags, and subtracts time mean for data.
    flagversion can be "latest" or "rtpipe".
    Optionally prepares data with antenna flags, fixing out of order data,
    calibration, downsampling, etc..

    Parameters
    ----------
    st : rfpipe state providing metadata, prefs, pols, chans and gainfile.
    segment : segment index, passed to prep_standard and save_noise.
    data : visibility array; pols are selected on axis 3, chans on axis 2.
    flagversion : "latest" uses flagging.flag_data, "rtpipe" uses
        flagging.flag_data_rtpipe; any other value skips flagging and logs a
        warning.

    Returns the prepared array. Returns early (unchanged input, or the
    partially prepared array) whenever the data becomes all zeros.
    """

    if not np.any(data):
        return data

    # take pols of interest
    takepol = [st.metadata.pols_orig.index(pol) for pol in st.pols]
    logger.debug('Selecting pols {0} and chans {1}'.format(st.pols, st.chans))

    # take() already returns a fresh array, so NaNs can be zeroed in place
    # without a second copy.
    # TODO: check on reusing 'data' to save memory
    datap = np.nan_to_num(data.take(takepol, axis=3).take(st.chans, axis=2),
                          copy=False)
    datap = prep_standard(st, segment, datap)

    if not np.any(datap):
        logger.info("All data zeros after prep_standard")
        return datap

    if st.gainfile is not None:
        datap = calibration.apply_telcal(st, datap)
        if not np.any(datap):
            logger.info("All data zeros after apply_telcal")
            return datap
    else:
        logger.info("No gainfile found, so not applying calibration.")

    # support backwards compatibility for reproducible flagging
    if flagversion == "latest":
        datap = flagging.flag_data(st, datap)
    elif flagversion == "rtpipe":
        datap = flagging.flag_data_rtpipe(st, datap)
    else:
        logger.warning("Unrecognized flagversion {0}; no flagging applied."
                       .format(flagversion))

    if st.prefs.timesub == 'mean':
        logger.info('Subtracting mean visibility in time.')
        datap = util.meantsub(datap, parallel=st.prefs.nthread > 1)
    else:
        logger.info('No visibility subtraction done.')

    if st.prefs.savenoise:
        save_noise(st, segment, datap)

    return datap
Esempio n. 4
0
def make_transient_params(st,
                          ntr=1,
                          segment=None,
                          dmind=None,
                          dtind=None,
                          i=None,
                          amp=None,
                          lm=None,
                          snr=None,
                          data=None):
    """ Given a state, create ntr randomized detectable transients.
    Returns list of ntr tuples of parameters.
    If data provided, it is used to inject transient at fixed SNR.
    selects random value from dmarr and dtarr.
    Mock transient will have either l or m equal to 0.
    Option exists to overload random selection with fixed segment, dmind, etc.

    Each tuple is (segment, i, dm, dt, amp, l, m). dt is in seconds
    (st.inttime times a dtarr factor) and l, m are in radians. Any of
    segment, dmind, dtind, i, amp, lm (a 2-tuple) or snr that is given is
    used for every mock; the others are drawn independently per mock.
    """

    # Caller-supplied overrides; every mock starts from these values.
    fixed = (segment, dmind, dtind, i, amp, lm, snr)

    # Calibrated data, computed at most once and shared by all mocks.
    datacal = None

    mocks = []
    for _ in range(ntr):
        segment, dmind, dtind, i, amp, lm, snr = fixed

        if segment is None:
            segment = random.choice(range(st.nsegment))

        if dmind is not None:
            dm = st.dmarr[dmind]
        else:
            dm = np.random.uniform(min(st.dmarr), max(st.dmarr))  # pc /cc

            # First searched DM trial above dm. If none is above (dm at or
            # beyond the last trial) use the last trial rather than argmax's
            # 0 for an all-False mask.
            dmarr = np.array(calc_dmarr(st))
            above = dmarr > dm
            if above.any():
                dmind = int(np.argmax(above))
            else:
                if dm > dmarr.max():
                    logger.warning(
                        "Dm of injected transient is greater than the max DM searched."
                    )
                dmind = len(dmarr) - 1

        if dtind is not None:
            dt = st.inttime * min(st.dtarr[dtind], 2)  # dt>2 not yet supported
        else:
            dt = st.inttime * np.random.uniform(min(st.dtarr),
                                                max(st.dtarr))  # s
            if dt < st.inttime:
                dtind = 0
            else:
                # log2 mapping presumes dtarr holds powers of two -- TODO confirm.
                dtind = int(np.log2(dt / st.inttime))
                if dtind >= len(st.dtarr):
                    dtind = len(st.dtarr) - 1
                    logger.warning(
                        "Width of transient is greater than max dt searched.")

        # TODO: add support for arb dm/dt

        if i is None:
            i = random.choice(st.get_search_ints(segment, dmind, dtind))

        if amp is None:
            if data is None:
                amp = random.uniform(0.1, 0.5)
                logger.info("Setting mock amp to {0}".format(amp))
            else:
                if snr is None:
                    snr = random.uniform(10, 50)
                    logger.info("Setting mock snr to {0}".format(snr))
                # TODO: support flagged data in size calc and injection
                if datacal is None:
                    datacal = calibration.apply_telcal(st, data)
                sig = madtostd(datacal[i].real) / np.sqrt(
                    datacal[i].size * st.dtarr[dtind])
                amp = snr * sig
                logger.info("Setting mock amp as {0}*{1}={2}".format(
                    snr, sig, amp))

        if lm is None:
            # flip a coin to set either l or m
            half = st.fieldsize_deg / 2
            if random.choice([0, 1]):
                l = math.radians(random.uniform(-half, half))
                m = 0.
            else:
                l = 0.
                m = math.radians(random.uniform(-half, half))
        else:
            assert len(lm) == 2, "lm must be 2-tuple"
            l, m = lm

        mocks.append((segment, i, dm, dt, amp, l, m))

    return mocks
Esempio n. 5
0
def make_transient_params(st,
                          ntr=1,
                          segment=None,
                          dmind=None,
                          dtind=None,
                          i=None,
                          amp=None,
                          lm=None,
                          snr=None,
                          data=None):
    """ Given a state, create ntr randomized detectable transients.
    Returns list of ntr tuples of parameters.
    If data provided, it is used to inject transient at fixed apparent SNR.
    selects random value from dmarr and dtarr.
    Mock transient will have either l or m equal to 0, unless lm=-1, in
    which case both l and m are drawn at random.
    Option exists to overload random selection with fixed segment, dmind, etc.

    Each tuple is (segment, i, dm, dt, amp, l, m). dt is in seconds and
    l, m are in radians. Any of segment, dmind, dtind, i, amp, lm (a 2-tuple
    or -1) or snr that is given is used for every mock.
    If data does not match st.datashape, its pols and chans are selected and
    it is calibrated before the noise estimate (done once; later mocks reuse
    the result).
    """

    # Caller-supplied overrides; every mock starts from these values.
    fixed = (segment, dmind, dtind, i, amp, lm, snr)

    mocks = []
    for _ in range(ntr):
        segment, dmind, dtind, i, amp, lm, snr = fixed

        if segment is None:
            segment = random.choice(range(st.nsegment))

        if dmind is not None:
            dm = st.dmarr[dmind]
        else:
            dm = np.random.uniform(min(st.dmarr), max(st.dmarr))  # pc /cc

            # First searched DM trial above dm. If none is above (dm at or
            # beyond the last trial) use the last trial rather than argmax's
            # 0 for an all-False mask.
            dmarr = np.array(calc_dmarr(st))
            above = dmarr > dm
            if above.any():
                dmind = int(np.argmax(above))
            else:
                if dm > dmarr.max():
                    logger.warning(
                        "Dm of injected transient is greater than the max DM searched."
                    )
                dmind = len(dmarr) - 1

        if dtind is not None:
            dt = st.inttime * st.dtarr[dtind]
        else:
            # Width in seconds, drawn independently of the integration time.
            dt = np.random.uniform(0.001, 0.04)
            if dt < st.inttime:
                dtind = 0
            else:
                boxcar_widths = np.array(st.dtarr) * st.inttime
                if dt > np.max(boxcar_widths):
                    logger.warning(
                        "Width of transient is greater than max dt searched.")
                # nearest searched boxcar width
                dtind = int(np.argmin(np.abs(boxcar_widths - dt)))

        # TODO: add support for arb dm/dt

        if i is None:
            # Search ints at the dtind resolution, scaled to native integrations.
            ints = np.array(st.get_search_ints(segment, dmind,
                                               dtind)) * st.dtarr[dtind]
            # randint excludes its upper bound; +1 keeps the last search int
            # selectable and avoids ValueError when only one int is searched.
            i = int(np.random.randint(ints.min(), ints.max() + 1))

        if amp is None:
            if data is None:
                amp = random.uniform(0.1, 0.5)
                logger.info("Setting mock amp to {0}".format(amp))
            else:
                if snr is None:
                    snr = random.uniform(10, 50)
                # TODO: support flagged data in size calc and injection
                if data.shape != st.datashape:
                    logger.info(
                        "Looks like raw data passed in. Selecting and calibrating."
                    )
                    takepol = [
                        st.metadata.pols_orig.index(pol) for pol in st.pols
                    ]
                    data = calibration.apply_telcal(
                        st,
                        data.take(takepol, axis=3).take(st.chans, axis=2))
                noise = madtostd(data[i].real) / np.sqrt(data[i].size)
                width_factor = (dt / st.inttime) / np.sqrt(st.dtarr[dtind])
                amp = snr * noise * width_factor
                logger.info("Setting mock amp as {0}*{1}*{2}={3}".format(
                    snr, noise, width_factor, amp))

        half = st.fieldsize_deg / 2
        if lm is None:
            # flip a coin to set either l or m
            if random.choice([0, 1]):
                l = math.radians(random.uniform(-half, half))
                m = 0.
            else:
                l = 0.
                m = math.radians(random.uniform(-half, half))
        elif np.isscalar(lm) and lm == -1:
            # Scalar check first: comparing an array-valued lm to -1 would
            # give an array with an ambiguous truth value.
            l = math.radians(random.uniform(-half, half))
            m = math.radians(random.uniform(-half, half))
        else:
            assert len(lm) == 2, "lm must be 2-tuple or -1"
            l, m = lm

        mocks.append((segment, i, dm, dt, amp, l, m))

    return mocks
Esempio n. 6
0
def data_prep(st, segment, data, flagversion="latest", returnsoltime=False):
    """ Applies calibration, flags, and subtracts time mean for data.
    flagversion can be "latest" or "rtpipe".
    Optionally prepares data with antenna flags, fixing out of order data,
    calibration, downsampling, OTF rephasing...

    Parameters
    ----------
    st : rfpipe state providing metadata, prefs, pols, chans and gainfile.
    segment : segment index, passed to prep_standard and save_noise.
    data : visibility array; pols are selected on axis 3, chans on axis 2.
    flagversion : "latest" (flagging.flag_data) or "rtpipe"
        (flagging.flag_data_rtpipe); any other value skips flagging and logs
        a warning.
    returnsoltime : if True, every return path yields (data, soltime);
        soltime is None unless calibration.apply_telcal provided one.

    Returns an empty array when more than 80% of the data is zero after
    flagging.
    """

    from rfpipe import calibration, flagging, util

    soltime = None

    def _result(arr):
        # Every exit honours returnsoltime so callers can always unpack.
        return (arr, soltime) if returnsoltime else arr

    if not np.any(data):
        return _result(data)

    # take pols of interest
    takepol = [st.metadata.pols_orig.index(pol) for pol in st.pols]
    logger.debug('Selecting pols {0} and chans {1}'.format(st.pols, st.chans))

    # take() already returns a fresh array, so NaNs can be zeroed in place.
    # TODO: check on reusing 'data' to save memory
    datap = np.nan_to_num(
        np.require(data, requirements='W').take(takepol, axis=3).take(st.chans,
                                                                      axis=2),
        copy=False)
    datap = prep_standard(st, segment, datap)

    if not np.any(datap):
        logger.info("All data zeros after prep_standard")
        return _result(datap)

    if st.gainfile is not None:
        logger.info("Applying calibration with {0}".format(st.gainfile))
        ret = calibration.apply_telcal(st,
                                       datap,
                                       savesols=st.prefs.savesols,
                                       returnsoltime=returnsoltime)
        if returnsoltime:
            datap, soltime = ret
        else:
            datap = ret

        if not np.any(datap):
            logger.info("All data zeros after apply_telcal")
            return _result(datap)
    else:
        logger.info("No gainfile found, so not applying calibration.")

    # support backwards compatibility for reproducible flagging
    logger.info("Flagging with version: {0}".format(flagversion))
    if flagversion == "latest":
        datap = flagging.flag_data(st, datap)
    elif flagversion == "rtpipe":
        datap = flagging.flag_data_rtpipe(st, datap)
    else:
        logger.warning("Unrecognized flagversion {0}; no flagging applied."
                       .format(flagversion))

    zerofrac = 1 - np.count_nonzero(datap) / datap.size
    if zerofrac > 0.8:
        logger.warning(
            'Flagged {0:.1f}% of data. Zeroing all if greater than 80%.'.
            format(zerofrac * 100))
        return _result(np.array([]))

    if st.prefs.timesub == 'mean':
        logger.info('Subtracting mean visibility in time.')
        datap = util.meantsub(datap, parallel=st.prefs.nthread > 1)
    else:
        logger.info('No visibility subtraction done.')

    if (st.prefs.apply_chweights
            or st.prefs.apply_blweights) and st.readints > 3:
        if st.prefs.apply_chweights:
            # TODO: find better estimator. Currently loses sensitivity to FRB 121102 bursts.
            # Per-(chan, pol) std over time of the baseline-averaged amplitude.
            chvar = np.std(np.abs(datap).mean(axis=1), axis=0)
            chvar_norm = np.mean(1 / chvar**2, axis=0)

        if st.prefs.apply_blweights:
            # Per-(baseline, pol) std over time of the channel-averaged amplitude.
            blvar = np.std(np.abs(datap).mean(axis=2), axis=0)
            blvar_norm = np.mean(1 / blvar**2, axis=0)

        if st.prefs.apply_chweights:
            logger.info('Reweighting data by channel variances')
            datap = (datap / chvar[None, None, :, :]) / chvar_norm[None, None,
                                                                   None, :]

        if st.prefs.apply_blweights:
            logger.info('Reweighting data by baseline variances')
            datap = (datap / blvar[None, :, None, :]) / blvar_norm[None, None,
                                                                   None, :]

    if st.prefs.savenoise:
        save_noise(st, segment, datap)

    return _result(datap)
Esempio n. 7
0
def prep_standard(st, segment, data):
    """ Common first data prep stages, incl
    online flags, resampling, and mock transients.

    Stages, in order:
    1. zero data where the online flags (indexed along axis 1) are False,
       only when st.prefs.applyonlineflags and the datasource is 'vys' or
       'sdm';
    2. average in time by st.prefs.read_tdownsample and/or in frequency by
       st.prefs.read_fdownsample (both may be requested together);
    3. add the simulated transients in st.prefs.simulated_transient whose
       segment matches (an int there is first replaced by that many random
       mocks, mutating st.prefs);
    4. rephase to the segment's first phase center when st.otfcorrections
       is set.

    Returns the prepared array; input with no nonzero values (before or after
    online flagging) is returned as-is. Stages 3-4 modify the array in place.
    """

    from rfpipe import calibration, flagging, util

    if not np.any(data):
        return data

    # read and apply flags for given ant/time range. 0=bad, 1=good
    if st.prefs.applyonlineflags and st.metadata.datasource in ['vys', 'sdm']:
        flags = flagging.getonlineflags(st, segment)
        data = np.where(flags[None, :, None, None], data, 0j)
    else:
        logger.info('Not applying online flags.')

    if not np.any(data):
        return data

    if st.prefs.simulated_transient is not None or st.otfcorrections is not None:
        uvw = util.get_uvw_segment(st, segment)

    # optionally integrate (downsample). Time and frequency are reduced one
    # after the other so both can be requested together; averaging each
    # from the original data would mix full- and reduced-size axes.
    tds = st.prefs.read_tdownsample
    fds = st.prefs.read_fdownsample
    if tds > 1:
        logger.info('Downsampling in time by {0}'.format(tds))
        nint = st.datashape[0]
        data_t = np.zeros((nint,) + data.shape[1:], dtype='complex64')
        for i in range(nint):
            data_t[i] = data[i * tds:(i + 1) * tds].mean(axis=0)
        data = data_t
    if fds > 1:
        logger.info('Downsampling in frequency by {0}'.format(fds))
        nchan = st.datashape[2]
        data_f = np.zeros(data.shape[:2] + (nchan,) + data.shape[3:],
                          dtype='complex64')
        for i in range(nchan):
            data_f[:, :, i, :] = data[:, :, i * fds:(i + 1) * fds].mean(axis=2)
        data = data_f

    # optionally add transients
    if st.prefs.simulated_transient is not None:
        # for an int type, overload prefs.simulated_transient random mocks
        if isinstance(st.prefs.simulated_transient, int):
            logger.info(
                "Filling simulated_transient with {0} random transients".
                format(st.prefs.simulated_transient))
            st.prefs.simulated_transient = util.make_transient_params(
                st,
                segment=segment,
                ntr=st.prefs.simulated_transient,
                data=data)

        assert isinstance(st.prefs.simulated_transient,
                          list), "Simulated transient must be list of tuples."

        for params in st.prefs.simulated_transient:
            assert len(params) == 7 or len(params) == 8, (
                "Transient requires 7 or 8 parameters: "
                "(segment, i0/int, dm/pc/cm3, dt/s, "
                "amp/sys, dl/rad, dm/rad) and optionally "
                "ampslope/sys")
            if segment == params[0]:
                if len(params) == 7:
                    (mock_segment, i0, dm, dt, amp, l, m) = params
                    ampslope = 0

                    logger.info(
                        "Adding transient to segment {0} at int {1}, "
                        "DM {2}, dt {3} with amp {4} and l,m={5},{6}".format(
                            mock_segment, i0, dm, dt, amp, l, m))
                elif len(params) == 8:
                    (mock_segment, i0, dm, dt, amp, l, m, ampslope) = params
                    logger.info("Adding transient to segment {0} at int {1}, "
                                " DM {2}, dt {3} with amp {4}-{5} and "
                                "l,m={6},{7}".format(mock_segment, i0, dm, dt,
                                                     amp, amp + ampslope, l,
                                                     m))
                try:
                    model = np.require(np.broadcast_to(
                        util.make_transient_data(
                            st, amp, i0, dm, dt,
                            ampslope=ampslope).transpose()[:, None, :, None],
                        st.datashape),
                                       requirements='W')
                except IndexError:
                    logger.warning(
                        "IndexError while adding transient. Skipping...")
                    continue

                if st.gainfile is not None:
                    # sign=-1 presumably applies the inverse gains, so the
                    # model matches uncalibrated data -- TODO confirm.
                    model = calibration.apply_telcal(st, model, sign=-1)
                util.phase_shift(model, uvw, -l, -m)
                data += model

    if st.otfcorrections is not None:
        # shift phasecenters to first phasecenter in segment
        if len(st.otfcorrections[segment]) > 1:
            _, ra0, dec0 = st.otfcorrections[segment][
                0]  # new phase center for segment
            logger.info(
                "Correcting {0} phasecenters to first at RA,Dec = {1},{2}".
                format(len(st.otfcorrections[segment]) - 1, ra0, dec0))
            for ints, ra_deg, dec_deg in st.otfcorrections[segment][1:]:
                l0 = np.radians(ra_deg - ra0)
                m0 = np.radians(dec_deg - dec0)
                util.phase_shift(data, uvw, l0, m0, ints=ints)

    return data
Esempio n. 8
0
def data_prep(st, segment, data, flagversion="latest", returnsoltime=False):
    """ Applies calibration, flags, and subtracts time mean for data.
    flagversion can be "latest" or "rtpipe".
    Optionally prepares data with antenna flags, fixing out of order data,
    calibration, downsampling, OTF rephasing...

    Parameters
    ----------
    st : rfpipe state providing metadata, prefs, pols, chans and gainfile.
    segment : segment index, passed to prep_standard and save_noise.
    data : visibility array; pols are selected on axis 3, chans on axis 2.
    flagversion : "latest" (flagging.flag_data) or "rtpipe"
        (flagging.flag_data_rtpipe); any other value skips flagging and logs
        a warning.
    returnsoltime : if True, every return path yields (data, soltime);
        soltime is None unless calibration.apply_telcal provided one.

    Returns an empty array when more than 80% of the data is zero after
    flagging.
    """

    from rfpipe import calibration, flagging

    soltime = None

    def _result(arr):
        # Every exit honours returnsoltime so callers can always unpack.
        return (arr, soltime) if returnsoltime else arr

    if not np.any(data):
        return _result(data)

    # take pols of interest
    takepol = [st.metadata.pols_orig.index(pol) for pol in st.pols]
    logger.debug('Selecting pols {0} and chans {1}'.format(st.pols, st.chans))

    # take() already returns a fresh array, so NaNs can be zeroed in place.
    # TODO: check on reusing 'data' to save memory
    datap = np.nan_to_num(np.require(data, requirements='W').take(takepol, axis=3).take(st.chans, axis=2),
                          copy=False)
    datap = prep_standard(st, segment, datap)

    if not np.any(datap):
        logger.info("All data zeros after prep_standard")
        return _result(datap)

    if st.gainfile is not None:
        logger.info("Applying calibration with {0}".format(st.gainfile))
        ret = calibration.apply_telcal(st, datap, savesols=st.prefs.savesols,
                                       returnsoltime=returnsoltime)
        if returnsoltime:
            datap, soltime = ret
        else:
            datap = ret

        if not np.any(datap):
            logger.info("All data zeros after apply_telcal")
            return _result(datap)
    else:
        logger.info("No gainfile found, so not applying calibration.")

    # support backwards compatibility for reproducible flagging
    logger.info("Flagging with version: {0}".format(flagversion))
    if flagversion == "latest":
        datap = flagging.flag_data(st, datap)
    elif flagversion == "rtpipe":
        datap = flagging.flag_data_rtpipe(st, datap)
    else:
        logger.warning("Unrecognized flagversion {0}; no flagging applied."
                       .format(flagversion))

    zerofrac = 1-np.count_nonzero(datap)/datap.size
    if zerofrac > 0.8:
        logger.warning('Flagged {0:.1f}% of data. Zeroing all if greater than 80%.'.format(zerofrac*100))
        return _result(np.array([]))

    if st.prefs.timesub == 'mean':
        logger.info('Subtracting mean visibility in time.')
        datap = util.meantsub(datap, parallel=st.prefs.nthread > 1)
    else:
        logger.info('No visibility subtraction done.')

    if (st.prefs.apply_chweights or st.prefs.apply_blweights) and st.readints > 3:
        if st.prefs.apply_chweights:
            # TODO: find better estimator. Currently loses sensitivity to FRB 121102 bursts.
            # Per-(chan, pol) std over time of the baseline-averaged amplitude.
            chvar = np.std(np.abs(datap).mean(axis=1), axis=0)
            chvar_norm = np.mean(1/chvar**2, axis=0)

        if st.prefs.apply_blweights:
            # Per-(baseline, pol) std over time of the channel-averaged amplitude.
            blvar = np.std(np.abs(datap).mean(axis=2), axis=0)
            blvar_norm = np.mean(1/blvar**2, axis=0)

        if st.prefs.apply_chweights:
            logger.info('Reweighting data by channel variances')
            datap = (datap/chvar[None, None, :, :])/chvar_norm[None, None, None, :]

        if st.prefs.apply_blweights:
            logger.info('Reweighting data by baseline variances')
            datap = (datap/blvar[None, :, None, :])/blvar_norm[None, None, None, :]

    if st.prefs.savenoise:
        save_noise(st, segment, datap)

    return _result(datap)
Esempio n. 9
0
def prep_standard(st, segment, data):
    """ Common first data prep stages, incl
    online flags, resampling, and mock transients.

    Stages, in order:
    1. zero data where the online flags (indexed along axis 1) are False,
       only when st.prefs.applyonlineflags and the datasource is 'vys' or
       'sdm';
    2. average in time by st.prefs.read_tdownsample and/or in frequency by
       st.prefs.read_fdownsample (both may be requested together);
    3. add the simulated transients in st.prefs.simulated_transient whose
       segment matches (an int there is first replaced by that many random
       mocks, mutating st.prefs);
    4. rephase to the segment's first phase center when st.otfcorrections
       is set.

    Returns the prepared array; input with no nonzero values (before or after
    online flagging) is returned as-is. Stages 3-4 modify the array in place.
    """

    from rfpipe import calibration, flagging

    if not np.any(data):
        return data

    # read and apply flags for given ant/time range. 0=bad, 1=good
    if st.prefs.applyonlineflags and st.metadata.datasource in ['vys', 'sdm']:
        flags = flagging.getonlineflags(st, segment)
        data = np.where(flags[None, :, None, None], data, 0j)
    else:
        logger.info('Not applying online flags.')

    if not np.any(data):
        return data

    if st.prefs.simulated_transient is not None or st.otfcorrections is not None:
        uvw = util.get_uvw_segment(st, segment)

    # optionally integrate (downsample). Time and frequency are reduced one
    # after the other so both can be requested together; averaging each
    # from the original data would mix full- and reduced-size axes.
    tds = st.prefs.read_tdownsample
    fds = st.prefs.read_fdownsample
    if tds > 1:
        logger.info('Downsampling in time by {0}'.format(tds))
        nint = st.datashape[0]
        data_t = np.zeros((nint,) + data.shape[1:], dtype='complex64')
        for i in range(nint):
            data_t[i] = data[i*tds:(i+1)*tds].mean(axis=0)
        data = data_t
    if fds > 1:
        logger.info('Downsampling in frequency by {0}'.format(fds))
        nchan = st.datashape[2]
        data_f = np.zeros(data.shape[:2] + (nchan,) + data.shape[3:],
                          dtype='complex64')
        for i in range(nchan):
            data_f[:, :, i, :] = data[:, :, i*fds:(i+1)*fds].mean(axis=2)
        data = data_f

    # optionally add transients
    if st.prefs.simulated_transient is not None:
        # for an int type, overload prefs.simulated_transient random mocks
        if isinstance(st.prefs.simulated_transient, int):
            logger.info("Filling simulated_transient with {0} random transients"
                        .format(st.prefs.simulated_transient))
            st.prefs.simulated_transient = util.make_transient_params(st, segment=segment,
                                                                      ntr=st.prefs.simulated_transient,
                                                                      data=data)

        assert isinstance(st.prefs.simulated_transient, list), "Simulated transient must be list of tuples."

        for params in st.prefs.simulated_transient:
            assert len(params) == 7 or len(params) == 8, ("Transient requires 7 or 8 parameters: "
                                                          "(segment, i0/int, dm/pc/cm3, dt/s, "
                                                          "amp/sys, dl/rad, dm/rad) and optionally "
                                                          "ampslope/sys")
            if segment == params[0]:
                if len(params) == 7:
                    (mock_segment, i0, dm, dt, amp, l, m) = params
                    ampslope = 0

                    logger.info("Adding transient to segment {0} at int {1}, "
                                "DM {2}, dt {3} with amp {4} and l,m={5},{6}"
                                .format(mock_segment, i0, dm, dt, amp, l, m))
                elif len(params) == 8:
                    (mock_segment, i0, dm, dt, amp, l, m, ampslope) = params
                    logger.info("Adding transient to segment {0} at int {1}, "
                                " DM {2}, dt {3} with amp {4}-{5} and "
                                "l,m={6},{7}"
                                .format(mock_segment, i0, dm, dt, amp,
                                        amp+ampslope, l, m))
                try:
                    model = np.require(np.broadcast_to(util.make_transient_data(st, amp, i0, dm, dt, ampslope=ampslope)
                                                       .transpose()[:, None, :, None],
                                                       st.datashape),
                                       requirements='W')
                except IndexError:
                    logger.warning("IndexError while adding transient. Skipping...")
                    continue

                if st.gainfile is not None:
                    # sign=-1 presumably applies the inverse gains, so the
                    # model matches uncalibrated data -- TODO confirm.
                    model = calibration.apply_telcal(st, model, sign=-1)
                util.phase_shift(model, uvw, -l, -m)
                data += model

    if st.otfcorrections is not None:
        # shift phasecenters to first phasecenter in segment
        if len(st.otfcorrections[segment]) > 1:
            _, ra0, dec0 = st.otfcorrections[segment][0]  # new phase center for segment
            logger.info("Correcting {0} phasecenters to first at RA,Dec = {1},{2}"
                        .format(len(st.otfcorrections[segment])-1, ra0, dec0))
            for ints, ra_deg, dec_deg in st.otfcorrections[segment][1:]:
                l0 = np.radians(ra_deg-ra0)
                m0 = np.radians(dec_deg-dec0)
                util.phase_shift(data, uvw, l0, m0, ints=ints)

    return data
Esempio n. 10
0
def make_transient_params(st, ntr=1, segment=None, dmind=None, dtind=None,
                          i=None, amp=None, lm=None, snr=None, data=None):
    """ Given a state, create ntr randomized detectable transients.
    Returns list of ntr tuples of parameters.
    If data provided, it is used to inject transient at fixed SNR.
    selects random value from dmarr and dtarr.
    Mock transient will have either l or m equal to 0.
    Option exists to overload random selection with fixed segment, dmind, etc.

    Each tuple is (segment, i, dm, dt, amp, l, m). dt is in seconds
    (st.inttime times a dtarr factor) and l, m are in radians. Any of
    segment, dmind, dtind, i, amp, lm (a 2-tuple) or snr that is given is
    used for every mock; the others are drawn independently per mock.
    """

    # Caller-supplied overrides; every mock starts from these values.
    fixed = (segment, dmind, dtind, i, amp, lm, snr)

    # Calibrated data, computed at most once and shared by all mocks.
    datacal = None

    mocks = []
    for _ in range(ntr):
        segment, dmind, dtind, i, amp, lm, snr = fixed

        if segment is None:
            segment = random.choice(range(st.nsegment))

        if dmind is not None:
            dm = st.dmarr[dmind]
        else:
            dm = np.random.uniform(min(st.dmarr), max(st.dmarr)) # pc /cc

            # First searched DM trial above dm. If none is above (dm at or
            # beyond the last trial) use the last trial rather than argmax's
            # 0 for an all-False mask.
            dmarr = np.array(calc_dmarr(st))
            above = dmarr > dm
            if above.any():
                dmind = int(np.argmax(above))
            else:
                if dm > dmarr.max():
                    logger.warning("Dm of injected transient is greater than the max DM searched.")
                dmind = len(dmarr) - 1

        if dtind is not None:
            dt = st.inttime*min(st.dtarr[dtind], 2)  # dt>2 not yet supported
        else:
            dt = st.inttime*np.random.uniform(min(st.dtarr), max(st.dtarr)) # s
            if dt < st.inttime:
                dtind = 0
            else:
                # log2 mapping presumes dtarr holds powers of two -- TODO confirm.
                dtind = int(np.log2(dt/st.inttime))
                if dtind >= len(st.dtarr):
                    dtind = len(st.dtarr) - 1
                    logger.warning("Width of transient is greater than max dt searched.")

        # TODO: add support for arb dm/dt

        if i is None:
            i = random.choice(st.get_search_ints(segment, dmind, dtind))

        if amp is None:
            if data is None:
                amp = random.uniform(0.1, 0.5)
                logger.info("Setting mock amp to {0}".format(amp))
            else:
                if snr is None:
                    snr = random.uniform(10, 50)
                    logger.info("Setting mock snr to {0}".format(snr))
                # TODO: support flagged data in size calc and injection
                if datacal is None:
                    datacal = calibration.apply_telcal(st, data)
                sig = madtostd(datacal[i].real)/np.sqrt(datacal[i].size*st.dtarr[dtind])
                amp = snr*sig
                logger.info("Setting mock amp as {0}*{1}={2}".format(snr, sig, amp))

        if lm is None:
            # flip a coin to set either l or m
            half = st.fieldsize_deg/2
            if random.choice([0, 1]):
                l = math.radians(random.uniform(-half, half))
                m = 0.
            else:
                l = 0.
                m = math.radians(random.uniform(-half, half))
        else:
            assert len(lm) == 2, "lm must be 2-tuple"
            l, m = lm

        mocks.append((segment, i, dm, dt, amp, l, m))

    return mocks