Ejemplo n.º 1
0
    def time_infinite(self):
        """ Time in mjd where burst is at infinite frequency.

        ``self.time_top`` is the burst time referenced to the top of the band
        (in mjd). The dispersion delay between the top of the band and an
        effectively infinite frequency is computed in integrations with
        util.calc_delay(freq, freqref, dm, inttime), converted to days, and
        used to shift time_top earlier.
        """

        dm = self.dmarr[self.loc[2]]
        inttime = self.state.inttime
        # NOTE(review): assumes time_top is referenced to the max of
        # self.state.freq (the searched band) -- confirm against time_top.
        # calc_delay returns the delay of `freq` relative to `freqref` in
        # integrations; 1e5 (same units as state.freq) stands in for
        # infinite frequency. freq is passed as a one-element list, as done
        # elsewhere for single-frequency delays.
        delay_ints = util.calc_delay([self.state.freq.max()], 1e5, dm,
                                     inttime)[0]
        # integrations -> seconds -> days, so it can be subtracted from mjd
        delay_days = delay_ints*inttime/(24*3600)
        return self.time_top - delay_days
Ejemplo n.º 2
0
def pipeline_seg2(st, segment, cfile=None, vys_timeout=vys_timeout_default):
    """ Process a single segment serially, with no multi-threading or scheduling.

    Reads and prepares the segment, then for every (dm, dt) trial
    dedisperses, resamples, searches and saves the candidate features.
    """

    # FFT plan is built once and shared by every search call below
    fft_wisdom = search.set_wisdom(st.npixx, st.npixy)

    raw = source.read_segment(st, segment, timeout=vys_timeout, cfile=cfile)
    prepped = source.data_prep(st, raw)

    for dmind, dm in enumerate(st.dmarr):
        dm_delay = util.calc_delay(st.freq, st.freq.max(), dm, st.inttime)
        dedispersed = search.dedisperse(prepped, dm_delay)

        for dtind, dt in enumerate(st.dtarr):
            resampled = search.resample(dedispersed, dt)
            found = search.search_thresh(st, resampled, segment, dmind, dtind,
                                         wisdom=fft_wisdom)
            feats = candidates.calc_features(found)
            search.save_cands(st, feats, found)
Ejemplo n.º 3
0
def pipeline_datacorrect(st, candloc, data_prep=None):
    """ Prepare and correct for dm and dt sampling of a given candloc.

    Can optionally pass in prepared (flagged, calibrated) data, if available;
    otherwise it is built with pipeline_dataprep(st, candloc).

    candloc unpacks to (segment, candint, dmind, dtind, beamnum). Each entry
    is cast to int so that float-typed locations (e.g. a float numpy array)
    can still index st.dtarr / st.dmarr; tuples of ints work unchanged.

    Returns the dedispersed and resampled data.
    """

    from rfpipe import util
    import rfpipe.search

    if data_prep is None:
        data_prep = pipeline_dataprep(st, candloc)

    segment, candint, dmind, dtind, beamnum = [int(v) for v in candloc]
    dt = st.dtarr[dtind]
    dm = st.dmarr[dmind]

    # rtpipe <= 1.54 used a different dispersion constant; reproduce it
    scale = None
    if hasattr(st, "rtpipe_version"):
        scale = 4.2e-3 if st.rtpipe_version <= 1.54 else None
    delay = util.calc_delay(st.freq,
                            st.freq.max(),
                            dm,
                            st.inttime,
                            scale=scale)

    data_dmdt = rfpipe.search.dedisperseresample(
        data_prep,
        delay,
        dt,
        parallel=st.prefs.nthread > 1,
        resamplefirst=st.fftmode == 'cuda')

    return data_dmdt
Ejemplo n.º 4
0
def pipeline_datacorrect(st, candloc, data_prep=None):
    """ Dedisperse and resample data to the dm and dt of one candidate.

    data_prep: optional already prepared (flagged, calibrated) data; when it
    is None the data are produced with pipeline_dataprep(st, candloc).
    candloc must support ``.astype(int)`` (e.g. a numpy array) and unpack to
    (segment, candint, dmind, dtind, beamnum).
    """

    prepared = pipeline_dataprep(st, candloc) if data_prep is None else data_prep

    _, _, dmind, dtind, _ = candloc.astype(int)
    dt_factor = st.dtarr[dtind]
    dm_value = st.dmarr[dmind]

    # rtpipe <= 1.54 used a slightly different dispersion constant
    legacy = hasattr(st, "rtpipe_version") and st.rtpipe_version <= 1.54
    scale = 4.2e-3 if legacy else None
    chan_delay = util.calc_delay(st.freq, st.freq.max(), dm_value,
                                 st.inttime, scale=scale)

    dedispersed = search.dedisperse(prepared, chan_delay)
    return search.resample(dedispersed, dt_factor)
Ejemplo n.º 5
0
def calc_cand_integration(cd):
    """
    Calculate the integration of the candidate using the canddata.

    The integration is calculated with respect to the start of the SDM,
    referenced to the max frequency of the original band
    (``state.metadata.freq_orig``).

    Parameters
    ----------
    cd : canddata object
        Must provide ``loc`` = (segment, candint, dmind, dtind, beamnum)
        and ``state``.

    Returns
    -------
    tuple
        (i, name, dm): integration index, SDM file name with a trailing
        '/', and the candidate DM.
    """
    segment, candint, dmind, dtind, beamnum = cd.loc
    st = cd.state
    dm = st.dmarr[dmind]
    name = st.metadata.filename.split('/')[-1] + '/'

    # The candidate is located relative to the max of the searched band;
    # shift it to the max of the original band. calc_delay returns the delay
    # (in integrations) of `freq` relative to `freqref`.
    orig_max_freq = np.max(st.metadata.freq_orig)
    obs_max_freq = np.max(st.freq)
    delay_ints = calc_delay(freq=[obs_max_freq],
                            freqref=orig_max_freq,
                            dm=dm,
                            inttime=st.inttime)[0]

    # candint is in resampled units, so scale by dt to get raw integrations.
    # NOTE(review): the origin of the trailing -1 offset is not visible here.
    i = segment * st.readints + candint * st.dtarr[dtind] - delay_ints - 1
    return i, name, dm
Ejemplo n.º 6
0
    def dmshifts(self):
        """ Max dispersion delay, in integrations, for every trial DM.

        Computed on the first call and cached on the instance as
        ``_dmshifts``.
        """

        if hasattr(self, '_dmshifts'):
            return self._dmshifts

        shifts = []
        for trial_dm in self.dmarr:
            chan_delays = util.calc_delay(self.freq, self.freq.max(),
                                          trial_dm, self.inttime)
            shifts.append(chan_delays.max())
        self._dmshifts = shifts
        return self._dmshifts
Ejemplo n.º 7
0
    def dmshifts(self):
        """ Per trial DM, the largest channel delay (in integrations)
        relative to the band maximum. Cached in ``self._dmshifts`` after the
        first call.
        """

        if not hasattr(self, '_dmshifts'):
            def _largest_shift(trial_dm):
                return util.calc_delay(self.freq, self.freq.max(), trial_dm,
                                       self.inttime).max()

            self._dmshifts = list(map(_largest_shift, self.dmarr))
        return self._dmshifts
Ejemplo n.º 8
0
    def t_overlap(self):
        """ Largest DM delay in seconds, quantized to a whole number of
        integrations. Cached in ``self._t_overlap`` after the first call. """

        if hasattr(self, '_t_overlap'):
            return self._t_overlap

        from rfpipe import util
        # delay of the first channel relative to the band max, at the
        # largest trial DM, in integrations
        delay_ints = util.calc_delay(self.freq.take([0]),
                                     self.freq.max(),
                                     max(self.dmarr),
                                     self.inttime)
        self._t_overlap = delay_ints[0]*self.inttime
        return self._t_overlap
Ejemplo n.º 9
0
    def dmshifts(self):
        """ Delay (in integrations) of the first channel relative to the
        band maximum, one entry per trial DM. Cached in ``self._dmshifts``.
        """

        if hasattr(self, '_dmshifts'):
            return self._dmshifts

        from rfpipe import util
        shifts = []
        for trial_dm in self.dmarr:
            first_chan_delay = util.calc_delay(self.freq.take([0]),
                                               self.freq.max(),
                                               trial_dm, self.inttime)
            shifts.append(first_chan_delay[0])
        self._dmshifts = shifts
        return self._dmshifts
Ejemplo n.º 10
0
def prepare_data(sdmfile,
                 gainfile,
                 delta_l,
                 delta_m,
                 segment=0,
                 dm=0,
                 dt=1,
                 spws=None):
    """
    Read one segment of an SDM and prepare it for a candidate location.

    Applies calibration, flagging, dedispersion and other data preparation
    steps from rfpipe, then phase shifts the data to the location of the
    candidate.

    Parameters
    ----------
    sdmfile, gainfile : str
        SDM to read and telcal gain file to apply (scan 1 is used).
    delta_l, delta_m : float
        Offsets passed to phase_shift as dl/dm.
    segment : int
        Segment to read.
    dm, dt : float, int
        Dispersion measure and resampling factor to correct for.
    spws : optional
        If truthy, overrides st.prefs.spw.

    Returns
    -------
    (data_dmdt, st)
        The prepared data (phase shifted by phase_shift, whose return value
        is ignored -- presumably in place) and the State used.
    """
    st = state.State(sdmfile=sdmfile,
                     sdmscan=1,
                     inprefs={
                         'gainfile': gainfile,
                         'workdir': '.',
                         'maxdm': 0,
                         'flaglist': []
                     },
                     showsummary=False)
    if spws:
        st.prefs.spw = spws

    data = source.read_segment(st, segment)

    # Select the pols, channels and baselines of this state from the full
    # set read from the SDM.
    takepol = [st.metadata.pols_orig.index(pol) for pol in st.pols]
    # Convert the original baseline array once, not once per baseline.
    blarr_orig = st.metadata.blarr_orig.tolist()
    takebls = [blarr_orig.index(list(bl)) for bl in st.blarr]
    datap = np.require(data, requirements='W').take(takepol, axis=3).take(
        st.chans, axis=2).take(takebls, axis=1)
    datap = source.prep_standard(st, segment, datap)
    datap = calibration.apply_telcal(st, datap)
    datap = flagging.flag_data(st, datap)

    delay = calc_delay(st.freq, st.freq.max(), dm, st.inttime)
    data_dmdt = dedisperseresample(datap, delay, dt)

    print(f'shape of data_dmdt is {data_dmdt.shape}')

    uvw = get_uvw_segment(st, segment)
    phase_shift(data_dmdt, uvw=uvw, dl=delta_l, dm=delta_m)

    return data_dmdt, st
Ejemplo n.º 11
0
def pipeline_datacorrect(st, candloc, data_prep=None):
    """ Return data dedispersed and resampled at a candidate's dm and dt.

    data_prep may be supplied as already prepared (flagged, calibrated)
    data; otherwise pipeline_dataprep(st, candloc) builds it.
    candloc unpacks to (segment, candint, dmind, dtind, beamnum).
    """

    if data_prep is None:
        data_prep = pipeline_dataprep(st, candloc)

    _, _, dmind, dtind, _ = candloc
    resample_factor = st.dtarr[dtind]
    trial_dm = st.dmarr[dmind]

    # rtpipe <= 1.54 used a different dispersion constant
    scale = None
    if hasattr(st, "rtpipe_version") and st.rtpipe_version <= 1.54:
        scale = 4.2e-3
    chan_delay = util.calc_delay(st.freq, st.freq.max(), trial_dm,
                                 st.inttime, scale=scale)

    run_parallel = st.prefs.nthread > 1
    resample_first = st.fftmode == 'cuda'
    return rfpipe.search.dedisperseresample(data_prep, chan_delay,
                                            resample_factor,
                                            parallel=run_parallel,
                                            resamplefirst=resample_first)
Ejemplo n.º 12
0
def dedisperse_search_cuda(st, segment, data, devicenum=None):
    """ Run dedispersion, resample for all dm and dt.
    Grid and image on GPU.
    rfgpu is built from separate repo.
    Uses state to define integrations to image based on segment, dm, and dt.
    devicenum can force the gpu to use, but can be inferred via distributed.

    Parameters
    ----------
    st : State
        Search state (dm/dt grids, image size, prefs, metadata).
    segment : int
        Segment index being searched.
    data : numpy array
        Visibilities for the segment. The mean over axis 3 is rolled into
        rfgpu's (nbl, nchan, readints) buffer, so data is presumably
        (readints, nbl, nchan, npol) -- confirm against caller.
    devicenum : int, optional
        GPU to use. If None, parsed from the distributed worker name (text
        after the first 'g'), falling back to 0.

    Returns
    -------
    CandCollection
        Empty if data is all zeros or the 'imagek' kalman prep yields a zero
        in sig_ts; otherwise the (optionally clustered and reproduced)
        candidates, which are also saved with candidates.save_cands.
    """

    assert st.dtarr[0] == 1, "st.dtarr[0] assumed to be 1"
    assert all([st.dtarr[dtind]*2 == st.dtarr[dtind+1]
                for dtind in range(len(st.dtarr)-1)]), ("dtarr must increase "
                                                        "by factors of 2")

    if not np.any(data):
        logger.info("Data is all zeros. Skipping search.")
        return candidates.CandCollection(prefs=st.prefs,
                                         metadata=st.metadata)

    if devicenum is None:
        # assume first gpu, but try to infer from worker name
        devicenum = 0
        try:
            from distributed import get_worker
            name = get_worker().name
            devicenum = int(name.split('g')[1])
            logger.debug("Using name {0} to set GPU devicenum to {1}"
                         .format(name, devicenum))
        except IndexError:
            logger.warning("Could not parse worker name {0}. Using default GPU devicenum {1}"
                           .format(name, devicenum))
        except ValueError:
            # NOTE(review): also reached if the worker name has a
            # non-integer after 'g', not only when no worker exists.
            logger.warning("No worker found. Using default GPU devicenum {0}"
                           .format(devicenum))
        except ImportError:
            logger.warning("distributed not available. Using default GPU devicenum {0}"
                           .format(devicenum))

    rfgpu.cudaSetDevice(devicenum)

    beamnum = 0
    uvw = util.get_uvw_segment(st, segment)

    upix = st.npixx
    vpix = st.npixy//2 + 1

    grid = rfgpu.Grid(st.nbl, st.nchan, st.readints, upix, vpix)
    image = rfgpu.Image(st.npixx, st.npixy)
    image.add_stat('rms')
    image.add_stat('pix')

    # Data buffers on GPU
    vis_raw = rfgpu.GPUArrayComplex((st.nbl, st.nchan, st.readints))
    vis_grid = rfgpu.GPUArrayComplex((upix, vpix))
    img_grid = rfgpu.GPUArrayReal((st.npixx, st.npixy))

    # Convert uv from lambda to us
    u, v, w = uvw
    u_us = 1e6*u[:, 0]/(1e9*st.freq[0])
    v_us = 1e6*v[:, 0]/(1e9*st.freq[0])

    # Q: set input units to be uv (lambda), freq in GHz?
    grid.set_uv(u_us, v_us)  # u, v in us
    grid.set_freq(st.freq*1e3)  # freq in MHz
    grid.set_cell(st.uvres)  # uv cell size in wavelengths (== 1/FoV(radians))

    # Compute gridding transform
    grid.compute()

    # move Stokes I data in (assumes dual pol data)
    vis_raw.data[:] = np.rollaxis(data.mean(axis=3), 0, 3)
    vis_raw.h2d()  # Send it to GPU memory

    grid.conjugate(vis_raw)

    # some prep if kalman filter is to be applied
    if st.prefs.searchtype in ['imagek']:
        # TODO: check that this is ok if pointing at bright source
        spec_std = data.real.mean(axis=1).mean(axis=2).std(axis=0)
        sig_ts, kalman_coeffs = kalman_prepare_coeffs(spec_std)
        # NOTE(review): this skips if *any* sig_ts entry is zero, while the
        # message says "all zeros" -- confirm which is intended.
        if not np.all(sig_ts):
            logger.info("sig_ts all zeros. Skipping search.")
            return candidates.CandCollection(prefs=st.prefs,
                                             metadata=st.metadata)

    # place to hold intermediate result lists
    canddict = {}
    canddict['candloc'] = []
    for feat in st.features:
        canddict[feat] = []

    for dtind in range(len(st.dtarr)):
        # vis_raw is downsampled in place, so each dt builds on the previous
        # (dtarr doubles at each step, asserted above)
        if dtind > 0:
            grid.downsample(vis_raw)

        for dmind in range(len(st.dmarr)):
            delay = util.calc_delay(st.freq, st.freq.max(), st.dmarr[dmind],
                                    st.inttime)

            grid.set_shift(delay >> dtind)  # dispersion shift per chan in samples

            integrations = st.get_search_ints(segment, dmind, dtind)
            if len(integrations) == 0:
                continue
            minint = min(integrations)
            maxint = max(integrations)

            logger.info('Imaging {0} ints ({1}-{2}) in seg {3} at DM/dt {4:.1f}/{5}'
                        ' with image {6}x{7} (uvres {8}) with gpu {9}'
                        .format(len(integrations), minint, maxint, segment,
                                st.dmarr[dmind], st.dtarr[dtind], st.npixx,
                                st.npixy, st.uvres, devicenum))

            # CPU-side dedispersed/resampled copy of data for 'imagek';
            # built at most once per (dm, dt), only if a candidate needs it
            data_corr = None

            for i in integrations:
                # grid and FFT
                grid.operate(vis_raw, vis_grid, i)
                image.operate(vis_grid, img_grid)

                # calc snr
                stats = image.stats(img_grid)
                if stats['rms'] != 0.:
                    snr1 = stats['max']/stats['rms']
                else:
                    snr1 = 0.
                    logger.warning("rfgpu rms is 0 in int {0}. Skipping.".format(i))

                # threshold image
                if snr1 > st.prefs.sigma_image1:
                    candloc = (segment, i, dmind, dtind, beamnum)

                    xpeak = stats['xpeak']
                    ypeak = stats['ypeak']
                    l1, m1 = st.pixtolm((xpeak+st.npixx//2, ypeak+st.npixy//2))

                    if st.prefs.searchtype == 'image':
                        logger.info("Got one! SNR1 {0:.1f} candidate at {1} and (l, m) = ({2},{3})"
                                    .format(snr1, candloc, l1, m1))
                        canddict['candloc'].append(candloc)
                        canddict['l1'].append(l1)
                        canddict['m1'].append(m1)
                        canddict['snr1'].append(snr1)
                        canddict['immax1'].append(stats['max'])

                    elif st.prefs.searchtype == 'imagek':
                        # TODO: implement phasing on GPU
                        if data_corr is None:
                            data_corr = dedisperseresample(data, delay,
                                                           st.dtarr[dtind],
                                                           parallel=st.prefs.nthread > 1,
                                                           resamplefirst=st.fftmode=='cuda')
                        # take() copies, so phase shifting spec leaves
                        # data_corr untouched for later integrations
                        spec = data_corr.take([i], axis=0)
                        util.phase_shift(spec, uvw, l1, m1)
                        spec = spec[0].real.mean(axis=2).mean(axis=0)

                        # TODO: this significance can be biased low if averaging in long baselines that are not phased well
                        # TODO: spec should be calculated from baselines used to measure l,m?
                        significance_kalman = kalman_significance(spec,
                                                                  spec_std,
                                                                  sig_ts=sig_ts,
                                                                  coeffs=kalman_coeffs)
                        snrk = (2*significance_kalman)**0.5
                        snrtot = (snrk**2 + snr1**2)**0.5
                        if snrtot > (st.prefs.sigma_kalman**2 + st.prefs.sigma_image1**2)**0.5:
                            logger.info("Got one! SNR1 {0:.1f} and SNRk {1:.1f} candidate at {2} and (l,m) = ({3},{4})"
                                        .format(snr1, snrk, candloc, l1, m1))
                            canddict['candloc'].append(candloc)
                            canddict['l1'].append(l1)
                            canddict['m1'].append(m1)
                            canddict['snr1'].append(snr1)
                            canddict['immax1'].append(stats['max'])
                            canddict['snrk'].append(snrk)
                    elif st.prefs.searchtype == 'armkimage':
                        raise NotImplementedError
                    elif st.prefs.searchtype == 'armk':
                        raise NotImplementedError
                    else:
                        logger.warning("searchtype {0} not recognized"
                                       .format(st.prefs.searchtype))

    cc = candidates.make_candcollection(st, **canddict)
    logger.info("First pass found {0} candidates in seg {1}."
                .format(len(cc), segment))

    if st.prefs.clustercands is not None:
        cc = candidates.cluster_candidates(cc)

    if st.prefs.savecands or st.prefs.saveplots:
        # triggers optional plotting and saving
        cc = reproduce_candcollection(cc, data)

    candidates.save_cands(st, candcollection=cc)

    return cc
Ejemplo n.º 13
0
def dedisperse_search_fftw(st, segment, data, wisdom=None):
    """ Fuse the dedisperse, resample, search, threshold functions.

    For every (dt, dm) trial the data are dedispersed and resampled, then
    searched according to st.prefs.searchtype ('image', 'imagek', 'armk' or
    'armkimage'). Integrations to search come from
    st.get_search_ints(segment, dmind, dtind), which by default excludes
    integrations searched in neighboring segments.

    Returns a CandCollection (empty if data is all zeros). Candidates are
    optionally clustered and reproduced, and are saved via
    candidates.save_cands.

    Raises NotImplementedError for any other searchtype, once a dm/dt trial
    with integrations to search is reached.

    ** only supports threshold > image max (no min)
    ** dmind, dtind, beamnum assumed to represent current state of data
    """

    if not np.any(data):
        logger.info("Data is all zeros. Skipping search.")
        return candidates.CandCollection(prefs=st.prefs,
                                         metadata=st.metadata)

    # some prep if kalman filter is to be applied
    if st.prefs.searchtype in ['imagek', 'armk', 'armkimage']:
        # TODO: check that this is ok if pointing at bright source
        spec_std = data.real.mean(axis=1).mean(axis=2).std(axis=0)
        sig_ts, kalman_coeffs = kalman_prepare_coeffs(spec_std)

    beamnum = 0
    uvw = util.get_uvw_segment(st, segment)

    # place to hold intermediate result lists
    canddict = {}
    canddict['candloc'] = []
    for feat in st.features:
        canddict[feat] = []

    for dtind in range(len(st.dtarr)):
        for dmind in range(len(st.dmarr)):
            # set search integrations
            integrations = st.get_search_ints(segment, dmind, dtind)
            if len(integrations) == 0:
                continue
            minint = min(integrations)
            maxint = max(integrations)

            logger.info('{0} search of {1} ints ({2}-{3}) in seg {4} at DM/dt '
                        '{5:.1f}/{6} with image {7}x{8} (uvres {9}) with fftw'
                        .format(st.prefs.searchtype, len(integrations), minint,
                                maxint, segment, st.dmarr[dmind],
                                st.dtarr[dtind], st.npixx,
                                st.npixy, st.uvres))

            # correct data
            delay = util.calc_delay(st.freq, st.freq.max(), st.dmarr[dmind],
                                    st.inttime)
            data_corr = dedisperseresample(data, delay, st.dtarr[dtind],
                                           parallel=st.prefs.nthread > 1,
                                           resamplefirst=st.fftmode=='cuda')

            # run search
            if st.prefs.searchtype in ['image', 'imagek']:
                images = grid_image(data_corr, uvw, st.npixx, st.npixy, st.uvres,
                                    'fftw', st.prefs.nthread, wisdom=wisdom,
                                    integrations=integrations)

                for i, image in enumerate(images):
                    immax1 = image.max()
                    imstd = image.std()
                    # a flat image would give inf/nan snr; skip it as the
                    # cuda search does for rms == 0
                    if imstd != 0.:
                        snr1 = immax1/imstd
                    else:
                        snr1 = 0.
                        logger.warning("image std is 0 in int {0}. Skipping."
                                       .format(integrations[i]))
                    if snr1 > st.prefs.sigma_image1:
                        candloc = (segment, integrations[i], dmind, dtind, beamnum)
                        l1, m1 = st.pixtolm(np.where(image == immax1))

                        # if set, use sigma_kalman as second stage filter
                        if st.prefs.searchtype == 'imagek':
                            spec = data_corr.take([integrations[i]], axis=0)
                            util.phase_shift(spec, uvw, l1, m1)
                            spec = spec[0].real.mean(axis=2).mean(axis=0)
                            # TODO: this significance can be biased low if averaging in long baselines that are not phased well
                            # TODO: spec should be calculated from baselines used to measure l,m?
                            significance_kalman = kalman_significance(spec,
                                                                      spec_std,
                                                                      sig_ts=sig_ts,
                                                                      coeffs=kalman_coeffs)
                            snrk = (2*significance_kalman)**0.5
                            snrtot = (snrk**2 + snr1**2)**0.5
                            if snrtot > (st.prefs.sigma_kalman**2 + st.prefs.sigma_image1**2)**0.5:
                                logger.info("Got one! SNR1 {0:.1f} and SNRk {1:.1f} candidate at {2} and (l,m) = ({3},{4})"
                                            .format(snr1, snrk, candloc, l1, m1))
                                canddict['candloc'].append(candloc)
                                canddict['l1'].append(l1)
                                canddict['m1'].append(m1)
                                canddict['snr1'].append(snr1)
                                canddict['immax1'].append(immax1)
                                canddict['snrk'].append(snrk)
                        elif st.prefs.searchtype == 'image':
                            logger.info("Got one! SNR1 {0:.1f} candidate at {1} and (l, m) = ({2},{3})"
                                        .format(snr1, candloc, l1, m1))
                            canddict['candloc'].append(candloc)
                            canddict['l1'].append(l1)
                            canddict['m1'].append(m1)
                            canddict['snr1'].append(snr1)
                            canddict['immax1'].append(immax1)

            elif st.prefs.searchtype in ['armkimage', 'armk']:
                armk_candidates = search_thresh_armk(st, data_corr, uvw,
                                                     integrations=integrations,
                                                     spec_std=spec_std,
                                                     sig_ts=sig_ts,
                                                     coeffs=kalman_coeffs)

                for candind, snrarms, snrk, armloc, peakxy, lm in armk_candidates:
                    candloc = (segment, candind, dmind, dtind, beamnum)

                    # if set, use sigma_kalman as second stage filter
                    if st.prefs.searchtype == 'armkimage':
                        image = grid_image(data_corr, uvw, st.npixx_full,
                                           st.npixy_full, st.uvres, 'fftw',
                                           st.prefs.nthread,
                                           wisdom=wisdom, integrations=candind)
                        peakx, peaky = np.where(image[0] == image[0].max())
                        l1, m1 = st.calclm(st.npixx_full, st.npixy_full,
                                           st.uvres, peakx[0], peaky[0])
                        immax1 = image.max()
                        imstd = image.std()
                        snr1 = immax1/imstd if imstd != 0. else 0.
                        if snr1 > st.prefs.sigma_image1:
                            logger.info("Got one! SNRarms {0:.1f} and SNRk "
                                        "{1:.1f} and SNR1 {2:.1f} candidate at"
                                        " {3} and (l,m) = ({4},{5})"
                                        .format(snrarms, snrk, snr1,
                                                candloc, l1, m1))
                            canddict['candloc'].append(candloc)
                            canddict['l1'].append(l1)
                            canddict['m1'].append(m1)
                            canddict['snrarms'].append(snrarms)
                            canddict['snrk'].append(snrk)
                            canddict['snr1'].append(snr1)
                            canddict['immax1'].append(immax1)

                    elif st.prefs.searchtype == 'armk':
                        l1, m1 = lm
                        logger.info("Got one! SNRarms {0:.1f} and SNRk {1:.1f} "
                                    "candidate at {2} and (l,m) = ({3},{4})"
                                    .format(snrarms, snrk, candloc, l1, m1))
                        canddict['candloc'].append(candloc)
                        canddict['l1'].append(l1)
                        canddict['m1'].append(m1)
                        canddict['snrarms'].append(snrarms)
                        canddict['snrk'].append(snrk)
            else:
                # NotImplemented is a constant, not an exception; calling it
                # raised TypeError instead of the intended error.
                raise NotImplementedError("only searchtype=image, imagek, armk, armkimage implemented")

    cc = candidates.make_candcollection(st, **canddict)
    logger.info("First pass found {0} candidates in seg {1}."
                .format(len(cc), segment))

    if st.prefs.clustercands is not None:
        cc = candidates.cluster_candidates(cc)

    if st.prefs.savecands or st.prefs.saveplots:
        # triggers optional plotting and saving
        cc = reproduce_candcollection(cc, data)

    candidates.save_cands(st, candcollection=cc)

    return cc
Ejemplo n.º 14
0
def cd_refined_plot(cd, devicenum, nsubbands=4, mode='CPU', frbprob=None):
    """ Use canddata object to create refinement plot of subbanded SNR and dm-time plot.
    """

    import rfpipe.search
    from rfpipe import util
    from matplotlib import gridspec
    import pylab as plt
    import matplotlib

    params = {
        'axes.labelsize': 14,
        'font.size': 9,
        'legend.fontsize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'text.usetex': False,
        'figure.figsize': [12, 10]
    }
    matplotlib.rcParams.update(params)

    segment, candint, dmind, dtind, beamnum = cd.loc
    st = cd.state
    scanid = cd.state.metadata.scanId
    width_m = st.dtarr[dtind]
    timewindow = st.prefs.timewindow
    tsamp = st.inttime * width_m
    dm = st.dmarr[dmind]
    ft_dedisp = np.flip((cd.data.real.sum(axis=2).T), axis=0)
    chan_freqs = np.flip(st.freq * 1000, axis=0)  # from high to low, MHz
    nf, nt = np.shape(ft_dedisp)

    candloc = cd.loc

    logger.debug('Size of the FT array is ({0}, {1})'.format(nf, nt))

    try:
        assert nt > 0
    except AssertionError as err:
        logger.exception("Number of time bins is equal to 0")
        raise err

    try:
        assert nf > 0
    except AssertionError as err:
        logger.exception("Number of frequency bins is equal to 0")
        raise err

    roll_to_center = nt // 2 - cd.integration_rel
    ft_dedisp = np.roll(ft_dedisp, shift=roll_to_center, axis=1)

    # If timewindow is not set during search, set it equal to the number of time bins of candidate
    if nt != timewindow:
        logger.info('Setting timewindow equal to nt = {0}'.format(nt))
        timewindow = nt
    else:
        logger.info('Timewindow length is {0}'.format(timewindow))

    try:
        assert nf == len(chan_freqs)
    except AssertionError as err:
        logger.exception(
            "Number of frequency channel in data should match the frequency list"
        )
        raise err

    if dm is not 0:
        dm_start = 0
        dm_end = 2 * dm
    else:
        dm_start = -10
        dm_end = 10

    logger.info(
        'Generating DM-time for candid {0} in DM range {1:.2f}--{2:.2f} pc/cm3'
        .format(cd.candid, dm_start, dm_end))

    logger.info("Using gpu devicenum: {0}".format(devicenum))
    os.environ['CUDA_VISIBLE_DEVICES'] = str(devicenum)

    dmt = rfpipe.search.make_dmt(ft_dedisp,
                                 dm_start - dm,
                                 dm_end - dm,
                                 256,
                                 chan_freqs / 1000,
                                 tsamp,
                                 mode=mode,
                                 devicenum=int(devicenum))

    delay = util.calc_delay(chan_freqs / 1000,
                            chan_freqs.max() / 1000, -1 * dm, tsamp)
    dispersed = rfpipe.search.dedisperse_roll(ft_dedisp, delay)
    #    dispersed = disperse(ft_dedisp, -1*dm, chan_freqs/1000, tsamp)

    im = cd.image
    imstd = im.std()  # consistent with rfgpu
    snrim = np.round(im.max() / imstd, 2)
    snrk = np.round(cd.snrk, 2)
    l1, m1 = st.pixtolm(np.where(im == im.max()))

    subsnrs, subts, bands = calc_subband_info(ft_dedisp, chan_freqs, nsubbands)
    logging.info(f'Generating time series of full band')
    ts_full = ft_dedisp.sum(0)
    logging.info(f'Calculating SNR of full band')
    snr_full = calc_snr(ts_full)

    to_print = []
    logging.info(f'{scanid}')
    to_print.append(f"{'.'.join(scanid.split('.')[:3])}. \n")
    to_print.append(f"{'.'.join(scanid.split('.')[3:])}\n")
    logging.info(f'candloc: {candloc}, DM: {dm:.2f}')
    to_print.append(f'candloc: {candloc}, DM: {dm:.2f}\n')
    logging.info(f'Source: {st.metadata.source}')
    to_print.append(f'Source: {st.metadata.source}\n')
    logging.info(f'Subbanded SNRs are:')
    to_print.append(f'Subbanded SNRs are:\n')
    for i in range(nsubbands):
        logging.info(
            f'Band: {chan_freqs[bands[i][0]]:.2f}-{chan_freqs[bands[i][1]-1]:.2f}, SNR: {subsnrs[i]:.2f}'
        )
        to_print.append(
            f'Band: {chan_freqs[bands[i][0]]:.2f}-{chan_freqs[bands[i][1]-1]:.2f}, SNR: {subsnrs[i]:.2f}\n'
        )
    logging.info(f'SNR of full band is: {snr_full:.2f}')
    to_print.append(f'SNR of full band is: {snr_full:.2f}\n')
    logging.info(f'SNR (im/k): {snrim}/{snrk}')
    to_print.append(f'SNR (im/k): {snrim}/{snrk}\n')
    logging.info(f'Clustersize: {cd.clustersize}')
    to_print.append(f'Clustersize: {cd.clustersize}\n')
    if frbprob is not None:
        logging.info(f'frbprob: {frbprob}')
        to_print.append(f'frbprob: {np.round(frbprob, 4)}\n')
    str_print = ''.join(to_print)

    fov = np.degrees(1. / st.uvres) * 60.
    l1arcm = np.degrees(l1) * 60
    m1arcm = np.degrees(m1) * 60

    ts = np.arange(timewindow) * tsamp

    gs = gridspec.GridSpec(4,
                           3,
                           width_ratios=[3.5, 0.1, 3],
                           height_ratios=[1, 1, 1, 1],
                           wspace=0.05,
                           hspace=0.20)
    ax1 = plt.subplot(gs[0, 0])
    ax2 = plt.subplot(gs[1, 0])
    ax3 = plt.subplot(gs[2, 0])
    ax4 = plt.subplot(gs[3, 0])
    ax11 = plt.subplot(gs[0, 1])
    ax22 = plt.subplot(gs[1, 1])
    ax33 = plt.subplot(gs[2, 1])
    ax44 = plt.subplot(gs[3, 1])
    ax5 = plt.subplot(gs[0, 2:3])
    ax6 = plt.subplot(gs[2:4, 2])
    ax7 = plt.subplot(gs[1, 2])

    x_loc = 0.1
    y_loc = 0.5

    for i in range(nsubbands):
        ax1.plot(
            ts,
            subts[i] - subts[i].mean(),
            label=
            f'Band: {chan_freqs[bands[i][0]]:.0f}-{chan_freqs[bands[i][1]-1]:.0f}'
        )
    ax1.plot(ts, subts.sum(0) - subts.sum(0).mean(), 'k.', label='Full Band')
    ax1.legend(loc='upper center',
               bbox_to_anchor=(0.5, 1.45),
               ncol=3,
               fancybox=True,
               shadow=True,
               fontsize=11)
    ax1.set_ylabel('Flux (Arb. units)')
    ax1.set_xlim(np.min(ts), np.max(ts))
    ax11.text(x_loc,
              y_loc,
              'Time Series',
              fontsize=14,
              ha='center',
              va='center',
              wrap=True,
              rotation=-90)
    ax11.axis('off')

    ax2.imshow(ft_dedisp,
               aspect='auto',
               extent=[ts[0], ts[-1],
                       np.min(chan_freqs),
                       np.max(chan_freqs)])
    ax2.set_ylabel('Freq')
    ax22.text(x_loc,
              y_loc,
              'Dedispersed FT',
              fontsize=14,
              ha='center',
              va='center',
              wrap=True,
              rotation=-90)
    ax22.axis('off')

    ax3.imshow(dispersed,
               aspect='auto',
               extent=[ts[0], ts[-1],
                       np.min(chan_freqs),
                       np.max(chan_freqs)])
    ax3.set_ylabel('Freq')
    ax33.text(x_loc,
              y_loc,
              'Original dispersed FT',
              fontsize=14,
              ha='center',
              va='center',
              wrap=True,
              rotation=-90)
    ax33.axis('off')

    ax4.imshow(np.flip(dmt, axis=0),
               aspect='auto',
               extent=[ts[0], ts[-1], dm_start, dm_end])
    ax4.set_xlabel('Time (s)')
    ax4.set_ylabel('DM')
    ax44.text(x_loc,
              y_loc,
              'DM-Time',
              fontsize=14,
              ha='center',
              va='center',
              wrap=True,
              rotation=-90)
    ax44.axis('off')

    # ax5.text(0.02, 0.8, str_print, fontsize=14, ha='left', va='top', wrap=True)
    ax5.text(0.02,
             1.4,
             str_print,
             fontsize=11.5,
             ha='left',
             va='top',
             wrap=True)
    ax5.axis('off')

    _ = ax6.imshow(im.transpose(),
                   aspect='equal',
                   origin='upper',
                   interpolation='nearest',
                   extent=[fov / 2, -fov / 2, -fov / 2, fov / 2],
                   cmap=plt.get_cmap('viridis'),
                   vmin=0,
                   vmax=0.5 * im.max())
    ax6.set_xlabel('RA Offset (arcmin)')
    ax6.set_ylabel('Dec Offset (arcmin)', rotation=-90, labelpad=12)
    ax6.yaxis.tick_right()
    ax6.yaxis.set_label_position("right")
    # to set scale when we plot the triangles that label the location
    ax6.autoscale(False)
    # add markers on the axes at measured position of the candidate
    ax6.scatter(x=[l1arcm],
                y=[-fov / 2],
                c='#ffff00',
                s=60,
                marker='^',
                clip_on=False)
    ax6.scatter(x=[fov / 2],
                y=[m1arcm],
                c='#ffff00',
                s=60,
                marker='>',
                clip_on=False)
    # makes it so the axis does not intersect the location triangles
    ax6.set_frame_on(False)

    sbeam = np.mean(st.beamsize_deg) * 60
    # figure out the location to center the zoomed image on
    xratio = len(im[0]) / fov  # pix/arcmin
    yratio = len(im) / fov  # pix/arcmin
    mult = 5  # sets how many times the synthesized beam the zoomed FOV is
    xmin = max(0, int(len(im[0]) // 2 - (m1arcm + sbeam * mult) * xratio))
    xmax = int(len(im[0]) // 2 - (m1arcm - sbeam * mult) * xratio)
    ymin = max(0, int(len(im) // 2 - (l1arcm + sbeam * mult) * yratio))
    ymax = int(len(im) // 2 - (l1arcm - sbeam * mult) * yratio)
    left, width = 0.231, 0.15
    bottom, height = 0.465, 0.15
    # rect_imcrop = [left, bottom, width, height]
    # ax_imcrop = fig.add_axes(rect_imcrop)
    # logger.debug('{0}'.format(im.transpose()[xmin:xmax, ymin:ymax].shape))
    # logger.debug('{0} {1} {2} {3}'.format(xmin, xmax, ymin, ymax))
    _ = ax7.imshow(im.transpose()[xmin:xmax, ymin:ymax],
                   aspect=1,
                   origin='upper',
                   interpolation='nearest',
                   extent=[-1, 1, -1, 1],
                   cmap=plt.get_cmap('viridis'),
                   vmin=0,
                   vmax=0.5 * im.max())
    # setup the axes
    ax7.set_ylabel('Dec (arcmin)')
    ax7.set_xlabel('RA (arcmin)')
    ax7.xaxis.set_label_position('top')
    # ax7.xaxis.tick_top()
    ax7.yaxis.tick_right()
    # ax7.yaxis.set_label_position("right")
    xlabels = [
        str(np.round(l1arcm + sbeam * mult / 2, 1)), '',
        str(np.round(l1arcm, 1)), '',
        str(np.round(l1arcm - sbeam * mult / 2, 1))
    ]
    ylabels = [
        str(np.round(m1arcm - sbeam * mult / 2, 1)), '',
        str(np.round(m1arcm, 1)), '',
        str(np.round(m1arcm + sbeam * mult / 2, 1))
    ]
    ax7.set_xticklabels(xlabels)
    ax7.set_yticklabels(ylabels)
    # change axis label loc of inset to avoid the full picture
    ax7.get_yticklabels()[0].set_verticalalignment('bottom')
    plt.tight_layout()
    plt.savefig(os.path.join(
        cd.state.prefs.workdir,
        'cands_{0}_refined.png'.format(cd.state.metadata.scanId)),
                bbox_inches='tight')
Ejemplo n.º 15
0
def disperse(data, dm, freqs, inttime):
    """Shift each channel of ``data`` by the DM delay computed for ``dm``.

    Per-channel delays are computed with ``rfpipe.util.calc_delay``, relative
    to the highest frequency in ``freqs``. They are then applied to ``data``
    with ``rfpipe.search.dedisperse_roll``. The delays are passed through
    unchanged. So whether the result is "dispersed" or "dedispersed" depends
    on the sign of ``dm`` and on the conventions of those two helpers.
    NOTE(review): confirm the sign of ``dm`` that callers are expected to
    pass.

    Args:
        data: Channelized data handed straight to ``dedisperse_roll``.
        dm: Dispersion measure forwarded to ``calc_delay``.
        freqs: Channel frequencies. They must support ``.max()`` (e.g. a
            numpy array). The maximum is used as the reference frequency.
        inttime: Sample time forwarded to ``calc_delay``. Presumably it is
            in the same units that ``calc_delay`` expects; verify.

    Returns:
        Whatever ``dedisperse_roll`` returns for ``data`` and the delays.
    """
    # rfpipe-local helpers are imported inside the function body.
    from rfpipe.util import calc_delay
    from rfpipe.search import dedisperse_roll as roll_channels

    reference_freq = freqs.max()
    channel_delays = calc_delay(freqs, reference_freq, dm, inttime)
    return roll_channels(data, channel_delays)