Example #1
    def add_to_data(self, delta_t, freq, data, scintillate=True, bandwidth=300.):
        """ Method to add already-dedispersed pulse 
        to background noise data. Includes frequency-dependent 
        width (smearing, scattering, etc.) and amplitude 
        (scintillation, spectral index). 
        """

        NFREQ = data.shape[0]
        NTIME = data.shape[1]
        tmid = NTIME//2
        freq_mid = freq[len(freq)//2]

        scint_amp = self.scintillation(freq)
        self._fluence /= np.sqrt(NFREQ)
#        stds_perchan = np.std(data)#*np.sqrt(NFREQ)
        SNRTools = tools.SNR_Tools()
        stds_perchan, med = SNRTools.sigma_from_mad(data.flatten())

        # Fallback in case the central channel holds no data
        width_eff = self._width

        for ii, f in enumerate(freq):
            # Do not add FRB to missing channels
            if data[ii].sum()==0:
                continue

            width_ = self.calc_width(self._dm, self._f_ref*1e-3, 
                                            bw=bandwidth, NFREQ=NFREQ,
                                            ti=self._width, tsamp=delta_t, tau=0)

            index_width = max(1, (np.round((width_/ delta_t))).astype(int))
            tpix = int(self.arrival_time(f) / delta_t)

            if abs(tpix) >= tmid:
                # ensure that edges of data are not crossed
                continue

            pp = self.pulse_profile(NTIME, index_width, f, 
                                    tau=self._scat_factor, t0=tpix)

            val = pp.copy()
            val /= val.max()
            val *= self._fluence

#            val *= (100.0 / stds_perchan) # hack
            val /= (width_ / delta_t) 
            val = val * (f / self._f_ref) ** self._spec_ind 

            if scintillate is True:
                val = (0.1 + scint_amp[ii]) * val 
                
            data[ii] += val

#            data[ii, tpix] += 5*np.std(data[ii])
            
            if f == freq_mid:
                width_eff = width_

        return width_eff
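The per-channel amplitude above is scaled by the fluence and then divided by the width in samples (val /= (width_ / delta_t)), so the time-integrated fluence in each channel stays fixed as the pulse broadens. A minimal standalone sketch of that bookkeeping, with an assumed fluence and widths (not values taken from the class above):

import numpy as np

fluence = 10.0                             # assumed fluence, in amplitude * samples
for w_samp in (1, 4, 16):
    profile = np.zeros(64)
    profile[32:32 + w_samp] = 1.0          # unit-height boxcar, w_samp samples wide
    profile *= fluence / w_samp            # same role as val /= (width_ / delta_t)
    print(w_samp, profile.sum())           # integrated area stays equal to fluence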
Example #2
def proc_trigger(fn_fil, dm0, t0, sig_cut,
                 ndm=50, mk_plot=False, downsamp=1,
                 beamno='', fn_mask=None, nfreq_plot=32,
                 ntime_plot=250,
                 cmap='RdBu', cand_no=1, multiproc=False,
                 rficlean=False, snr_comparison=-1,
                 outdir='./', sig_thresh_local=5.0,
                 subtract_zerodm=False,
                 threshold_time=3.25, threshold_frequency=2.75, 
                 bin_size=32, n_iter_time=3, 
                 n_iter_frequency=3, clean_type='time', 
                 freq=1370.0, sb_generator=None, sb=None, 
                 dumb_mask=True, save_sb_fil=False):
    """ Locate data within filterbank file (fn_fi)
    at some time t0, and dedisperse to dm0, generating
    plots

    Parameters:
    ----------
    fn_fil     : str
        name of filterbank file
    dm0        : float
        trigger dm found by single pulse search software
    t0         : float
        time in seconds where trigger was found
    sig_cut    : np.float
        sigma of detected trigger at (t0, dm0)
    ndm        : int
        number of DMs to use in DM transform
    mk_plot    : bool
        make three-panel plots
    downsamp   : int
        factor by which to downsample in time. comes from searchsoft.
    beamno     : str
        beam number, for fig names
    nfreq_plot : int
        number of frequency channels to plot
    freq       : int
        central frequency used to find zapped channels file
    sb_generator: SBGenerator object
        synthesized beam mapper from DARC (None for TAB/IAB)
    sb         : int
        synthesized beam to generate (None for TAB/IAB)

    Returns:
    -------
    full_dm_arr_downsamp : np.array
        data array with downsampled dm-transformed intensities
    full_freq_arr_downsamp : np.array
        data array with downsampled freq-time intensities
    """

    if dumb_mask:
        try:
            fndmask='/home/arts/.controller/amber_conf/zapped_channels_{:.0f}.conf'.format(int(freq))
            rfimask = np.loadtxt(fndmask)
            rfimask = rfimask.astype(int)
        except Exception:
            rfimask = np.array([])
            logging.warning("Could not load dumb RFIMask")
    else:
        rfimask = np.array([])

    SNRtools = tools.SNR_Tools()
    downsamp = min(4096, downsamp)

    if downsamp >= 100:
        wideclean = int(downsamp)
    else:
        wideclean = None

    # store path to filterbanks
    if sb is not None:
        prefix_fil = fn_fil
        # get first file
        fn_fil = prefix_fil + '_00.fil'
        if not os.path.exists(fn_fil):
            fn_fil = prefix_fil + '00.fil'
    rawdatafile = filterbank.filterbank(fn_fil)
    dfreq_MHz = rawdatafile.header['foff']

    dt = rawdatafile.header['tsamp']
    freq_up = rawdatafile.header['fch1']
    nfreq = rawdatafile.header['nchans']
    # fix RFI mask order
    rfimask = nfreq - rfimask
    freq_low = freq_up + nfreq*rawdatafile.header['foff']
    ntime_fil = (os.path.getsize(fn_fil) - rawdatafile.header_size)/nfreq
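    # intra-channel DM smearing time in seconds, evaluated at the bottom of the band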
    tdm = np.abs(8.3*1e-6*dm0*dfreq_MHz*(freq_low/1000.)**-3)
    dm_min = max(0, dm0-40)
    dm_max = dm0 + 40
    dms = np.linspace(dm_min, dm_max, ndm, endpoint=True)

    # make sure dm0 is in the array
    dm_max_jj = np.argmin(abs(dms-dm0))
    dms += (dm0-dms[dm_max_jj])
    dms[0] = max(0, dms[0])

    global t_min, t_max
    # if smearing timescale is < 4*pulse width,
    # downsample before dedispersion for speed
    downsamp_smear = max(1, int(downsamp*dt/tdm/2.))
    # ensure that it's not larger than pulse width
    downsamp_smear = int(min(downsamp, downsamp_smear))
    downsamp_res = int(downsamp//downsamp_smear)
    downsamp = int(downsamp_res*downsamp_smear)
    time_res = dt * downsamp
    logging.info("Width_full:%d  Width_smear:%d  Width_res: %d" %
                 (downsamp, downsamp_smear, downsamp_res))

    start_bin = int(t0/dt - ntime_plot*downsamp//2)
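    # dispersion delay across the full band in seconds (frequencies in MHz)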
    width = abs(4.148e3 * dm0 * (freq_up**-2 - freq_low**-2))
    chunksize = int(width/dt + ntime_plot*downsamp)

    t_min, t_max = 0, ntime_plot*downsamp

    if start_bin < 0:
        extra = start_bin
        start_bin = 0
        t_min += extra
        t_max += extra

    t_min, t_max = int(t_min), int(t_max)

    snr_max = 0

    # Account for the pre-downsampling to speed up dedispersion
    t_min //= downsamp_smear
    t_max //= downsamp_smear
    ntime = t_max-t_min

    if ntime_fil < (start_bin+chunksize):
        logging.info("Trigger at end of file, skipping")
        return [], [], [], []

    # get data of all files (SB) or one file (TAB/IAB)
    if sb is not None:
        ntab = 12
        data = np.zeros((ntab, nfreq, chunksize))
        # get list of unique TABs in required SB
        sb_map = list(set(sb_generator.get_map(sb)))
        logging.info("SB {} consists of TABs {}".format(sb, sb_map))
        threads = []
        for tab in range(ntab):
            # skip if we do not need this TAB
            if tab not in sb_map:
                continue
            fname = prefix_fil + '_{:02d}.fil'.format(tab)
            if not os.path.exists(fname):
                fname = prefix_fil + '{:02d}.fil'.format(tab)
            load_tab_data(fname, start_bin, chunksize, out=data, tab=tab)
        for thread in threads:
            logging.info("Waiting for loading of {}".format(thread.name))
            thread.join()
        # generate sb
        logging.info("Synthesizing beam {}".format(sb))

        data = sb_generator.synthesize_beam(data, sb=sb)
        # convert to a spectra object, mimicking filterbank.get_spectra
        data = spectra.Spectra(rawdatafile.frequencies, 
                               rawdatafile.tsamp, data,
                               starttime=start_bin*rawdatafile.tsamp, 
                               dm=0)
    else:
        data = rawdatafile.get_spectra(start_bin, chunksize)

    if save_sb_fil:
        fn_sbfil_out = '%s/data/CB%s_snr%d_dm%d_t0%d_sb%d.fil' % \
                     (outdir, beamno, sig_cut, dm0, t0, sb)
        reader.create_new_filterbank(fn_sbfil_out, telescope='Apertif')
        reader.write_to_fil(data.data.transpose(), rawdatafile.header, fn_sbfil_out)
#        fil_obj = reader.filterbank_.FilterbankFile(fn_sbfil_out, mode='readwrite')
#        fil_obj.append_spectra(data.data.transpose())
#        np.save(fn_fig_out, data.data)

    rawdatafile.close()

    # apply dumb mask
    if len(rfimask) > 0:
        data.data[rfimask] = 0.

    if rficlean is True:
        data.data = cleandata(data.data, threshold_time, 
                              threshold_frequency, bin_size, 
                              n_iter_time, n_iter_frequency, 
                              clean_type, wideclean=wideclean)

    if subtract_zerodm:
        data.data -= np.mean(data.data, axis=0, keepdims=True)

    freq_ref = 0.5*(freq_up+freq_low)
    # Downsample before dedispersion up to 1/4th
    # DM smearing limit
    data.downsample(downsamp_smear)
    data.data -= np.median(data.data, axis=-1)[:, None]
#    full_arr = np.empty([int(ndm), int(ntime)])
    if fn_mask is not None:
        pass
        # rfimask = rfifind.rfifind(fn_mask)
        # mask = get_mask(rfimask, start_bin, chunksize)
        # data = data.masked(mask, maskval='median-mid80')

    if multiproc is True:
        pass
    else:
        logging.info("\nDedispersing Serially\n")
        data_copy = copy.deepcopy(data)
        data_copy.dedisperse(dm0)

        data_dm_max = data_copy.data[:, max(0, t_min):t_max]
        snr_max = SNRtools.calc_snr_matchedfilter(data_dm_max.mean(0), widths=[downsamp_res])[0]

        if t_min<0:
            Z = np.zeros([nfreq, np.abs(t_min)])
            data_dm_max = np.concatenate([Z, data_dm_max], axis=1)

        # scale max DM by pulse width, 5 units for each ms 
        dm_max_trans = 10. + 5*time_res/0.001 + 10*dm0/1000.
        dm_min_trans = -10. - 5*time_res/0.001 - 10*dm0/1000.

        if dm0+dm_min_trans<=0:
            dm_min_trans = 0.
            dm_max_trans = 2*dm0
            dm_center = dm0
        else:
            dm_center = 0.

        full_arr, dms, times = RTproc.dm_transform(data_dm_max, 
                                                  (freq_up, freq_low), 
                                                  dt=dt*downsamp_smear, 
                                                  dm_max=dm_max_trans, 
                                                  dm_min=dm_min_trans, 
                                                  freq_ref=freq_ref,ndm=ndm, 
                                                  dm0=dm_center)
        dms += dm0

    # bin down to nfreq_plot freq channels
    full_freq_arr_downsamp = data_dm_max[:nfreq//nfreq_plot*nfreq_plot, :]
    full_freq_arr_downsamp = full_freq_arr_downsamp.reshape(\
                                   nfreq_plot, -1, ntime).mean(1)

    # bin down in time by factor of downsamp
    full_freq_arr_downsamp = full_freq_arr_downsamp[:,:ntime//downsamp_res*downsamp_res]
    full_freq_arr_downsamp = full_freq_arr_downsamp.reshape(-1, ntime//downsamp_res, 
                                                    downsamp_res).mean(-1)

    if snr_max < sig_thresh_local:
        logging.info("\nSkipping trigger below local threshold %.2f:" % sig_thresh_local)
        logging.info("snr_local=%.2f  snr_trigger=%.2f\n" % (snr_max, sig_cut))
        return [], [], [], []

    times = np.linspace(0,ntime_plot*downsamp*dt,len(full_freq_arr_downsamp[0]))

    full_dm_arr_downsamp = full_arr[:, :ntime//downsamp_res*downsamp_res]
    full_dm_arr_downsamp = full_dm_arr_downsamp.reshape(-1,
                             ntime//downsamp_res, downsamp_res).mean(-1)

    full_freq_arr_downsamp /= np.std(full_freq_arr_downsamp)
    full_dm_arr_downsamp /= np.std(full_dm_arr_downsamp)

    suptitle = " CB:%s  S/N$_{pipe}$:%.1f  S/N$_{presto}$:%.1f\
                 S/N$_{compare}$:%.1f \nDM:%d  t:%.1fs  width:%d" %\
                 (beamno, sig_cut, snr_max, snr_comparison,
                    dm0, t0, downsamp)

    if not os.path.isdir('%s/plots' % outdir):
        os.system('mkdir -p %s/plots' % outdir)

    if sb is None:
        sbname = -1
    else:
        sbname = sb

    fn_fig_out = '%s/plots/CB%s_snr%d_dm%d_t0%d_sb%d.pdf' % \
                     (outdir, beamno, sig_cut, dm0, t0, sbname)

    params = sig_cut, dms[dm_max_jj], downsamp, t0, dt
    tmed = np.median(full_freq_arr_downsamp, axis=-1, keepdims=True)
    full_freq_arr_downsamp -= tmed

    if mk_plot is True:
        logging.info(fn_fig_out)

        if ndm == 1:
            plotter.plot_two_panel(full_freq_arr_downsamp, params, 
                                   prob=None,
                                   freq_low=freq_low, freq_up=freq_up,
                                   cand_no=cand_no, times=times, 
                                   suptitle=suptitle,
                                   fnout=fn_fig_out)
        else:
            plotter.plot_three_panel(full_freq_arr_downsamp,
                                     full_dm_arr_downsamp, params, dms,
                                     times=times, freq_low=freq_low,
                                     freq_up=freq_up,
                                     suptitle=suptitle, fnout=fn_fig_out,
                                     cand_no=cand_no)

    return full_dm_arr_downsamp, full_freq_arr_downsamp, time_res, params
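For reference, a minimal standalone sketch of the read-window arithmetic used above (start_bin and chunksize), with assumed Apertif-like numbers rather than values read from a real filterbank header:

import numpy as np

dm0, t0, dt = 350.0, 120.0, 8.192e-5        # assumed trigger DM (pc cm^-3), time (s), sample time (s)
freq_up, freq_low = 1520.0, 1220.0          # assumed band edges in MHz
ntime_plot, downsamp = 250, 8               # plot window (samples) and assumed downsampling factor

sweep_s = abs(4.148e3 * dm0 * (freq_up**-2 - freq_low**-2))   # dispersion delay across the band (s)
start_bin = int(t0 / dt - ntime_plot * downsamp // 2)         # first sample to read
chunksize = int(sweep_s / dt + ntime_plot * downsamp)         # samples covering the sweep plus the window
print(sweep_s, start_bin, chunksize)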
Example #3
    # TAB: filterbank filename is specified
    # SB: filterbank prefix is specified (i.e. without _<TABno>.fil)
    fn_fil = args[0]
    fn_sp = args[1]

    if options.save_data == 'concat':
        data_dm_time_full = []
        data_freq_time_full = []
        params_full = []
        if options.sb:
            data_sb_full = []

    if options.multiproc is True:
        import multiprocessing

    SNRTools = tools.SNR_Tools()

    if options.compare_trig is not None:
        res = SNRTools.compare_snr(fn_sp, options.compare_trig,
                                        dm_min=options.dm_min,
                                        dm_max=options.dm_max,
                                        save_data=False,
                                        sig_thresh=options.sig_thresh,
                                        max_rows=None,
                                        t_window=0.25)

        par_1, par_2, par_match_arr, ind_missed, ind_matched = res

        snr_1, snr_2 = par_1[0], par_2[0]
        snr_comparison_arr = np.zeros_like(snr_1)
        ind_missed = np.array(ind_missed)
Example #4
def proc_trigger(fn_fil,
                 dm0,
                 t0,
                 sig_cut,
                 ndm=50,
                 mk_plot=False,
                 downsamp=1,
                 beamno='',
                 fn_mask=None,
                 nfreq_plot=32,
                 ntime_plot=250,
                 cmap='RdBu',
                 cand_no=1,
                 multiproc=False):
    """ Locate data within filterbank file (fn_fi)
    at some time t0, and dedisperse to dm0, generating 
    plots 

    Parameters:
    ----------
    fn_fil     : str 
        name of filterbank file
    dm0        : float 
        trigger dm found by single pulse search software
    t0         : float 
        time in seconds where trigger was found 
    sig_cut    : np.float 
        sigma of detected trigger at (t0, dm0)
    ndm        : int 
        number of DMs to use in DM transform 
    mk_plot    : bool 
        make three-panel plots 
    downsamp   : int 
        factor by which to downsample in time. comes from searchsoft. 
    beamno     : str 
        beam number, for fig names 
    nfreq_plot : int 
        number of frequency channels to plot

    Returns:
    -------
    full_dm_arr_downsamp : np.array
        data array with downsampled dm-transformed intensities
    full_freq_arr_downsamp : np.array
        data array with downsampled freq-time intensities 
    """
    SNRtools = tools.SNR_Tools()

    rawdatafile = filterbank.filterbank(fn_fil)

    mask = [
    ]  #np.array([ 5,   6,   9,  32,  35,  49,  75,  76,  78,  82,  83,  87,  92,
    #         93,  97,  98, 108, 110, 111, 112, 114, 118, 122, 123, 124, 157,
    #         160, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 660,
    #         661])

    dt = rawdatafile.header['tsamp']
    freq_up = rawdatafile.header['fch1']
    nfreq = rawdatafile.header['nchans']
    freq_low = freq_up + nfreq * rawdatafile.header['foff']
    time_res = dt * downsamp
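    # 467 is assumed to be the hard-coded filterbank header size in bytes
    # (Example #2 uses rawdatafile.header_size in the same expression)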
    ntime_fil = (os.path.getsize(fn_fil) - 467.) / nfreq

    dm_min = max(0, dm0 - 20)
    dm_max = dm0 + 20
    dms = np.linspace(dm_min, dm_max, ndm, endpoint=True)

    # make sure dm0 is in the array
    dm_max_jj = np.argmin(abs(dms - dm0))
    dms += (dm0 - dms[dm_max_jj])
    dms[0] = max(0, dms[0])

    # Read in 5 disp delays
    width = 2 * abs(4.14e3 * dm0 * (freq_up**-2 - freq_low**-2))

    tdisp = width / dt
    tplot = ntime_plot * downsamp

    global t_min, t_max

    if tdisp > tplot:
        # Need to read in more data than you'll plot
        # because of large dispersion time
        chunksize = int(tdisp)
        t_min = chunksize // 2 - (ntime_plot * downsamp) // 2
        t_max = chunksize // 2 + (ntime_plot * downsamp) // 2
    else:
        # Only need to read in enough to plot
        chunksize = int(tplot)
        t_min, t_max = 0, chunksize

    start_bin = int(t0 / dt - chunksize / 2.)

    if start_bin < 0:
        extra = start_bin
        start_bin = 0
        t_min += extra
        t_max += extra

    t_min, t_max = int(t_min), int(t_max)
    ntime = t_max - t_min

    snr_max = 0

    if ntime_fil < (start_bin + chunksize):
        print("Trigger at end of file, skipping")
        return 0, 0, 0

    data = rawdatafile.get_spectra(start_bin, chunksize)
    data.data -= np.median(data.data, axis=-1)[:, None]
    data.data[mask] = 0.
    full_arr = np.empty([int(ndm), int(ntime)])

    if fn_mask is not None:
        rfimask = rfifind.rfifind(fn_mask)
        mask = get_mask(rfimask, start_bin, chunksize)
        data = data.masked(mask, maskval='median-mid80')

    if multiproc is True:
        print("\nDedispersing in Parallel\n")
        tbeg = time.time()  # do not shadow the trigger time t0, which is used later
        global datacopy

        ndm_ = min(10, ndm)

        for kk in range(ndm // ndm_):
            dms_ = dms[10 * kk:10 * (kk + 1)]
            datacopy = copy.deepcopy(data)
            pool = multiprocessing.Pool(processes=ndm_)
            data_tuple = pool.map(multiproc_dedisp, [i for i in dms_])
            pool.close()

            data_tuple = np.concatenate(data_tuple)
            ddm = np.concatenate(data_tuple[0::2]).reshape(ndm_, -1)
            df = np.concatenate(data_tuple[1::2]).reshape(ndm_, nfreq, -1)

            print(time.time() - tbeg)
            full_arr[10 * kk:10 * (kk + 1)] = ddm[:, t_min:t_max]

            ind_kk = range(10 * kk, 10 * (kk + 1))

            if dm_max_jj in ind_kk:
                print(dms_[ind_kk.index(dm_max_jj)])
                data_dm_max = df[ind_kk.index(dm_max_jj)]  #dm_max_jj]hack

            del ddm, df

    else:
        print("\nDedispersing Serially\n")

        for jj, dm_ in enumerate(dms):
            print("Dedispersing to dm=%0.1f at t=%0.1f sec with width=%.2f" %
                  (dm_, start_bin * dt, downsamp))
            data_copy = copy.deepcopy(data)

            data_copy.dedisperse(dm_)
            dm_arr = data_copy.data[:, t_min:t_max].mean(0)

            snr_ = SNRtools.calc_snr(dm_arr)

            full_arr[jj] = copy.copy(dm_arr)

            if jj == dm_max_jj:
                data_dm_max = data_copy.data[:, t_min:t_max]

    downsamp = int(downsamp)

    # bin down to nfreq_plot freq channels
    full_freq_arr_downsamp = data_dm_max[:nfreq//nfreq_plot*nfreq_plot, :].reshape(\
                                   nfreq_plot, -1, ntime).mean(1)
    # bin down in time by factor of downsamp
    full_freq_arr_downsamp = full_freq_arr_downsamp[:, :ntime//downsamp*downsamp\
                                   ].reshape(-1, ntime//downsamp, downsamp).mean(-1)

    times = np.linspace(0, ntime * dt, len(full_freq_arr_downsamp[0]))

    full_dm_arr_downsamp = full_arr[:, :ntime // downsamp * downsamp]
    full_dm_arr_downsamp = full_dm_arr_downsamp.reshape(
        -1, ntime // downsamp, downsamp).mean(-1)

    full_freq_arr_downsamp /= np.std(full_freq_arr_downsamp)
    full_dm_arr_downsamp /= np.std(full_dm_arr_downsamp)

    suptitle = "beam%s snr%d dm%d t0%d width%d" %\
                 (beamno, sig_cut, dms[dm_max_jj], t0, downsamp)

    fn_fig_out = './plots/train_data_beam%s_snr%d_dm%d_t0%d.pdf' % \
                     (beamno, sig_cut, dms[dm_max_jj], t0)

    if mk_plot is True:
        print(fn_fig_out)
        if ndm == 1:
            params = snr_, dm_, downsamp, t0
            plotter.plot_two_panel(full_freq_arr_downsamp,
                                   params,
                                   prob=None,
                                   freq_low=1250.09765625,
                                   freq_up=1549.90234375,
                                   cand_no=cand_no)
        else:
            print(fn_fig_out)
            plotter.plot_three_panel(full_freq_arr_downsamp,
                                     full_dm_arr_downsamp,
                                     times,
                                     dms,
                                     freq_low=freq_low,
                                     freq_up=freq_up,
                                     suptitle=suptitle,
                                     fnout=fn_fig_out,
                                     cand_no=cand_no)

    return full_dm_arr_downsamp, full_freq_arr_downsamp, time_res
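The time and frequency binning above uses the reshape/mean idiom: trim the array to a multiple of the binning factor, reshape so the last axis holds one group of samples, then average that axis away. A minimal standalone sketch:

import numpy as np

arr = np.arange(10, dtype=float)                   # 10 time samples
downsamp = 4
trimmed = arr[:len(arr) // downsamp * downsamp]    # keep 8 samples (a multiple of 4)
binned = trimmed.reshape(-1, downsamp).mean(-1)    # -> array([1.5, 5.5])
print(binned)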
Example #5
def proc_trigger(fn_fil, dm0, t0, sig_cut,
                 ndm=50, mk_plot=False, downsamp=1,
                 beamno='', fn_mask=None, nfreq_plot=32,
                 ntime_plot=250,
                 cmap='RdBu', cand_no=1, multiproc=False,
                 rficlean=False, snr_comparison=-1,
                 outdir='./', sig_thresh_local=5.0,
                 subtract_zerodm=False,
                 threshold_time=3.25, threshold_frequency=2.75, bin_size=32,
                 n_iter_time=3, n_iter_frequency=3, clean_type='time'):
    """ Locate data within filterbank file (fn_fi)
    at some time t0, and dedisperse to dm0, generating
    plots

    Parameters:
    ----------
    fn_fil     : str
        name of filterbank file
    dm0        : float
        trigger dm found by single pulse search software
    t0         : float
        time in seconds where trigger was found
    sig_cut    : np.float
        sigma of detected trigger at (t0, dm0)
    ndm        : int
        number of DMs to use in DM transform
    mk_plot    : bool
        make three-panel plots
    downsamp   : int
        factor by which to downsample in time. comes from searchsoft.
    beamno     : str
        beam number, for fig names
    nfreq_plot : int
        number of frequency channels to plot

    Returns:
    -------
    full_dm_arr_downsamp : np.array
        data array with downsampled dm-transformed intensities
    full_freq_arr_downsamp : np.array
        data array with downsampled freq-time intensities
    """

    try:
        rfimask = np.loadtxt('/home/arts/ARTS-obs/amber_conf/zapped_channels_1400.conf')
        rfimask = rfimask.astype(int)
    except Exception:
        rfimask = []
        logging.warning("Could not load dumb RFIMask")

    SNRtools = tools.SNR_Tools()
    downsamp = min(4096, downsamp)
    rawdatafile = filterbank.filterbank(fn_fil)
    dfreq_MHz = rawdatafile.header['foff']
    mask = []

    dt = rawdatafile.header['tsamp']
    freq_up = rawdatafile.header['fch1']
    nfreq = rawdatafile.header['nchans']
    # fix RFI mask order
    rfimask = nfreq - rfimask
    freq_low = freq_up + nfreq*rawdatafile.header['foff']
    ntime_fil = (os.path.getsize(fn_fil) - 467.)/nfreq
    tdm = np.abs(8.3*1e-6*dm0*dfreq_MHz*(freq_low/1000.)**-3)

    dm_min = max(0, dm0-40)
    dm_max = dm0 + 40
    dms = np.linspace(dm_min, dm_max, ndm, endpoint=True)

    # make sure dm0 is in the array
    dm_max_jj = np.argmin(abs(dms-dm0))
    dms += (dm0-dms[dm_max_jj])
    dms[0] = max(0, dms[0])

    global t_min, t_max
    # if smearing timescale is < 4*pulse width,
    # downsample before dedispersion for speed
    downsamp_smear = max(1, int(downsamp*dt/tdm/2.))
    # ensure that it's not larger than pulse width
    downsamp_smear = int(min(downsamp, downsamp_smear))
    downsamp_res = int(downsamp//downsamp_smear)
    downsamp = int(downsamp_res*downsamp_smear)
    time_res = dt * downsamp
    tplot = ntime_plot * downsamp
    logging.info("Width_full:%d  Width_smear:%d  Width_res: %d" %
                 (downsamp, downsamp_smear, downsamp_res))
#    print("Width_full:%d  Width_smear:%d  Width_res: %d" %
#        (downsamp, downsamp_smear, downsamp_res))

    start_bin = int(t0/dt - ntime_plot*downsamp//2)
    width = abs(4.148e3 * dm0 * (freq_up**-2 - freq_low**-2))
    chunksize = int(width/dt + ntime_plot*downsamp)

    t_min, t_max = 0, ntime_plot*downsamp

    if start_bin < 0:
        extra = start_bin
        start_bin = 0
        t_min += extra
        t_max += extra

    t_min, t_max = int(t_min), int(t_max)

    snr_max = 0

    # Account for the pre-downsampling to speed up dedispersion
    t_min //= downsamp_smear
    t_max //= downsamp_smear
    ntime = t_max-t_min

    if ntime_fil < (start_bin+chunksize):
        logging.info("Trigger at end of file, skipping")
#        print("Trigger at end of file, skipping")
        return [],[],[],[]

    data = rawdatafile.get_spectra(start_bin, chunksize)
    # apply dumb mask
    data.data[rfimask] = 0.

    if rficlean is True:
        data = cleandata(data, threshold_time, threshold_frequency, bin_size, \
                         n_iter_time, n_iter_frequency, clean_type)

    if subtract_zerodm:
        data.data -= np.mean(data.data, axis=0)[None]

    # Downsample before dedispersion up to 1/4th
    # DM smearing limit
    data.downsample(downsamp_smear)
    data.data -= np.median(data.data, axis=-1)[:, None]
    full_arr = np.empty([int(ndm), int(ntime)])
    if fn_mask is not None:
        pass
        # rfimask = rfifind.rfifind(fn_mask)
        # mask = get_mask(rfimask, start_bin, chunksize)
        # data = data.masked(mask, maskval='median-mid80')

    if multiproc is True:
        tbeg=time.time()
        global datacopy

        size_arr = sys.getsizeof(data.data)
        nproc = int(32.0e9/size_arr)

        ndm_ = min(min(nproc, ndm), 10)

        for kk in range(ndm//ndm_):
            dms_ = dms[ndm_*kk:ndm_*(kk+1)]
            datacopy = copy.deepcopy(data)
            pool = multiprocessing.Pool(processes=ndm_)
            data_tuple = pool.map(multiproc_dedisp, [i for i in dms_])
            pool.close()

            data_tuple = np.concatenate(data_tuple)
            ddm = np.concatenate(data_tuple[0::2]).reshape(ndm_, -1)
            df = np.concatenate(data_tuple[1::2]).reshape(ndm_, nfreq, -1)

            print(time.time()-tbeg)
            full_arr[ndm_*kk:ndm_*(kk+1)] = ddm[:, t_min:t_max]

            ind_kk = range(ndm_*kk, ndm_*(kk+1))

            if dm_max_jj in ind_kk:
                data_dm_max = df[ind_kk.index(dm_max_jj)]

            del ddm, df
    else:
        logging.info("\nDedispersing Serially\n")
        #print("\nDedispersing Serially\n")
        for jj, dm_ in enumerate(dms):
            tcopy = time.time()
            data_copy = copy.deepcopy(data)

            t0_dm = time.time()
            data_copy.dedisperse(dm_)
            dm_arr = data_copy.data[:, max(0, t_min):t_max].mean(0)

            full_arr[jj, np.abs(min(0, t_min)):] = copy.copy(dm_arr)

            logging.info("Dedispersing to dm=%0.1f at t=%0.1fsec with width=%.1f S/N=%.1f" %
                         (dm_, t0, downsamp, sig_cut))
#            print("Dedispersing to dm=%0.1f at t=%0.1fsec with width=%.1f S/N=%.1f" %
#                        (dm_, t0, downsamp, sig_cut))

            if jj==dm_max_jj:
                data_dm_max = data_copy.data[:, max(0, t_min):t_max]
                snr_max = SNRtools.calc_snr_matchedfilter(data_dm_max.mean(0), widths=[downsamp_res])[0]
                if t_min<0:
                    Z = np.zeros([nfreq, np.abs(t_min)])
                    data_dm_max = np.concatenate([Z, data_dm_max], axis=1)

    # bin down to nfreq_plot freq channels
    full_freq_arr_downsamp = data_dm_max[:nfreq//nfreq_plot*nfreq_plot, :].reshape(\
                                   nfreq_plot, -1, ntime).mean(1)

    # bin down in time by factor of downsamp
    full_freq_arr_downsamp = full_freq_arr_downsamp[:, :ntime//downsamp_res*downsamp_res\
                                   ].reshape(-1, ntime//downsamp_res, downsamp_res).mean(-1)

#    snr_max = SNRtools.calc_snr_mad(full_freq_arr_downsamp.mean(0))

    if snr_max < sig_thresh_local:
        logging.info("\nSkipping trigger below local threshold %.2f:" % sig_thresh_local)
        logging.info("snr_local=%.2f  snr_trigger=%.2f\n" % (snr_max, sig_cut))
        return [],[],[],[]

    times = np.linspace(0, ntime_plot*downsamp*dt, len(full_freq_arr_downsamp[0]))

    full_dm_arr_downsamp = full_arr[:, :ntime//downsamp_res*downsamp_res]
    full_dm_arr_downsamp = full_dm_arr_downsamp.reshape(-1,
                             ntime//downsamp_res, downsamp_res).mean(-1)

    full_freq_arr_downsamp /= np.std(full_freq_arr_downsamp)
    full_dm_arr_downsamp /= np.std(full_dm_arr_downsamp)

    suptitle = " CB:%s  S/N$_{pipe}$:%.1f  S/N$_{presto}$:%.1f\
                 S/N$_{compare}$:%.1f \nDM:%d  t:%.1fs  width:%d" %\
                 (beamno, sig_cut, snr_max, snr_comparison, \
                    dms[dm_max_jj], t0, downsamp)

    if not os.path.isdir('%s/plots' % outdir):
        os.system('mkdir -p %s/plots' % outdir)

    fn_fig_out = '%s/plots/CB%s_snr%d_dm%d_t0%d.pdf' % \
                     (outdir, beamno, sig_cut, dms[dm_max_jj], t0)

    params = sig_cut, dms[dm_max_jj], downsamp, t0, dt
    tmed = np.median(full_freq_arr_downsamp, axis=-1, keepdims=True)
    full_freq_arr_downsamp -= tmed

    if mk_plot is True:
        logging.info(fn_fig_out)

        if ndm==1:
            plotter.plot_two_panel(full_freq_arr_downsamp, params, prob=None,
                                   freq_low=freq_low, freq_up=freq_up,
                                   cand_no=cand_no, times=times, suptitle=suptitle,
                                   fnout=fn_fig_out)
        else:
            plotter.plot_three_panel(full_freq_arr_downsamp,
                                     full_dm_arr_downsamp, params, dms,
                                     times=times, freq_low=freq_low,
                                     freq_up=freq_up,
                                     suptitle=suptitle, fnout=fn_fig_out,
                                     cand_no=cand_no)

    return full_dm_arr_downsamp, full_freq_arr_downsamp, time_res, params
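The pre-dedispersion downsampling chosen above is set by the intra-channel DM smearing time. A minimal sketch of that arithmetic, with assumed channel width, band edge, and trigger parameters (not values read from a file):

import numpy as np

dm0, dt, downsamp = 500.0, 8.192e-5, 256        # assumed trigger DM, sample time (s), boxcar width
dfreq_MHz, freq_low = -0.1953125, 1220.0        # assumed channel width (MHz) and bottom of the band (MHz)

# intra-channel smearing time in seconds at the bottom of the band
tdm = np.abs(8.3e-6 * dm0 * dfreq_MHz * (freq_low / 1000.)**-3)
downsamp_smear = max(1, int(downsamp * dt / tdm / 2.))     # bin before dedispersion for speed...
downsamp_smear = int(min(downsamp, downsamp_smear))        # ...but never beyond the pulse width
downsamp_res = int(downsamp // downsamp_smear)             # residual binning applied after dedispersion
print(tdm, downsamp_smear, downsamp_res)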
Example #6
def fil_trigger(fn_fil, dm0, t0, sig_cut,
                 ndm=50, mk_plot=False, downsamp=1,
                 beamno='', fn_mask=None, nfreq_plot=32,
                 ntime_plot=250,
                 cmap='RdBu', cand_no=1, multiproc=False,
                 rficlean=False, snr_comparison=-1,
                 outdir='./', sig_thresh_local=5.0,
                 threshold_time=3.25, threshold_frequency=2.75, bin_size=32,
                 n_iter_time=3, n_iter_frequency=3, clean_type='time'):
    try:
        rfimask = np.loadtxt('/home/arts/ARTS-obs/amber_conf/zapped_channels_1400.conf')
        rfimask = rfimask.astype(int)
    except Exception:
        rfimask = []
        logging.info("Could not load dumb RFIMask")

    SNRtools = tools.SNR_Tools()
    downsamp = min(4096, downsamp)
    rawdatafile = filterbank.filterbank(fn_fil)
    dfreq_MHz = rawdatafile.header['foff']
    mask = []

    dt = rawdatafile.header['tsamp']
    freq_up = rawdatafile.header['fch1']
    nfreq = rawdatafile.header['nchans']
    freq_low = freq_up + nfreq*rawdatafile.header['foff']
    ntime_fil = (os.path.getsize(fn_fil) - 467.)/nfreq
    tdm = np.abs(8.3*1e-6*dm0*dfreq_MHz*(freq_low/1000.)**-3)

    dm_min = max(0, dm0-40)
    dm_max = dm0 + 40
    dms = np.linspace(dm_min, dm_max, ndm, endpoint=True)

    # make sure dm0 is in the array
    dm_max_jj = np.argmin(abs(dms-dm0))
    dms += (dm0-dms[dm_max_jj])
    dms[0] = max(0, dms[0])

    global t_min, t_max
    # if smearing timescale is < 4*pulse width,
    # downsample before dedispersion for speed
    downsamp_smear = max(1, int(downsamp*dt/tdm/2.))
    # ensure that it's not larger than pulse width
    downsamp_smear = int(min(downsamp, downsamp_smear))
    downsamp_res = int(downsamp//downsamp_smear)
    downsamp = int(downsamp_res*downsamp_smear)
    time_res = dt * downsamp
    tplot = ntime_plot * downsamp
#    print("Width_full:%d  Width_smear:%d  Width_res: %d" %
#        (downsamp, downsamp_smear, downsamp_res))

    start_bin = int(t0/dt - ntime_plot*downsamp//2)
    width = abs(4.148e3 * dm0 * (freq_up**-2 - freq_low**-2))
    chunksize = int(width/dt + ntime_plot*downsamp)

    t_min, t_max = 0, ntime_plot*downsamp

    if start_bin < 0:
        extra = start_bin
        start_bin = 0
        t_min += extra
        t_max += extra

    t_min, t_max = int(t_min), int(t_max)

    snr_max = 0

    # Account for the pre-downsampling to speed up dedispersion
    t_min //= downsamp_smear
    t_max //= downsamp_smear
    ntime = t_max-t_min

    data = rawdatafile.get_spectra(start_bin, chunksize)

    if rficlean is True:
        data = cleandata(data, threshold_time, threshold_frequency, \
                         bin_size, n_iter_time, n_iter_frequency, clean_type)

    return data, downsamp, downsamp_smear
Example #7
def inject_in_filterbank(fn_fil,
                         fn_out_dir,
                         N_FRB=1,
                         NFREQ=1536,
                         NTIME=2**15,
                         rfi_clean=False,
                         dm=250.0,
                         freq=(1550, 1250),
                         dt=0.00004096,
                         chunksize=5e4,
                         calc_snr=True,
                         start=0,
                         freq_ref=1400.,
                         subtract_zero=False,
                         clipping=None):
    """ Inject an FRB in each chunk of data 
        at random times. Default params are for Apertif data.

    Parameters:
    -----------

    fn_fil : str
        name of filterbank file 
    fn_out_dir : str 
        directory for output files 
    N_FRB : int 
        number of FRBs to inject 
    NTIME : int 
        number of time samples per data chunk 
    rfi_clean : bool 
        apply rfi filters 
    dm : float / tuple 
        dispersion measure(s) to inject FRB with 
    freq : tuple 
        (freq_bottom, freq_top) 
    dt : float 
        time resolution 
    chunksize : int 
        size of data in samples to read in 
    calc_snr : bool 
        calculates S/N of injected pulse 
    start : int 
        start sample 
    freq_ref : float 
        reference frequency for injection code 
    subtract_zero : bool 
        subtract zero DM timestream from data 
    clipping : float or None
        zero out bright events in zero-DM timestream 

    Returns:
    --------
    None 
    """

    if type(dm) is not tuple:
        max_dm = dm
    else:
        max_dm = max(dm)

    t_delay_max = abs(4.14e3 * max_dm * (freq[0]**-2 - freq[1]**-2))
    t_delay_max_pix = int(t_delay_max / dt)

    # ensure that dispersion sweep is not too large
    # for chunksize
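    # the dispersion sweep may occupy at most this fraction of a data chunk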
    f_edge = 0.3
    while chunksize <= t_delay_max_pix / f_edge:
        chunksize *= 2
        NTIME *= 2
        print('Increasing to NTIME:%d, chunksize:%d' % (NTIME, chunksize))

    ii = 0
    params_full_arr = []

    ttot = int(N_FRB * chunksize * dt)

    timestr = time.strftime("%Y%m%d-%H%M")
    fn_fil_out = '%s/dm%s_nfrb%d_%s_sec_%s.fil' % (fn_out_dir, dm, N_FRB, ttot,
                                                   timestr)
    fn_params_out = fn_fil_out[:-len('.fil')] + '.txt'

    f_params_out = open(fn_params_out, 'w+')
    f_params_out.write(
        '# DM      Sigma      Time (s)     Sample    Downfact\n')
    f_params_out.close()

    for ii in range(N_FRB):
        # drop FRB in random location in data chunk
        offset = int(
            np.random.uniform(0.1 * chunksize, (1 - f_edge) * chunksize))

        data_filobj, freq_arr, delta_t, header = reader.read_fil_data(
            fn_fil, start=start + chunksize * ii, stop=chunksize)

        if ii == 0:
            fn_rfi_clean = reader.write_to_fil(np.zeros([NFREQ, 0]), header,
                                               fn_fil_out)

        data = data_filobj.data
        # injected pulse time in seconds since start of file

        t0_ind = offset + NTIME // 2 + chunksize * ii
        t0_ind = start + chunksize * ii + offset  # hack because needs to agree with presto
        t0 = t0_ind * delta_t

        if len(data) == 0:
            break

        data_event = (data[:, offset:offset + NTIME]).astype(float)

        data_event, params = simulate_frb.gen_simulated_frb(
            NFREQ=NFREQ,
            NTIME=NTIME,
            sim=True,
            #                                               fluence=3000*(1+0.1*ii),
            fluence=(200, 500),
            spec_ind=0,
            width=(delta_t, delta_t * 100),
            dm=dm + 10 * ii,
            scat_factor=(-4, -3.5),
            background_noise=data_event,
            delta_t=delta_t,
            plot_burst=False,
            freq=(freq_arr[0], freq_arr[-1]),
            FREQ_REF=freq_ref,
            scintillate=False)

        dm_ = params[0]
        params.append(offset)

        print("%d/%d Injecting with DM:%d width: %.2f offset: %d" %
              (ii, N_FRB, dm_, params[2], offset))

        data[:, offset:offset + NTIME] = data_event

        #params_full_arr.append(params)
        width = params[2]
        downsamp = max(1, int(width / delta_t))
        t_delay_mid = 4.15e3 * dm_ * (freq_ref**-2 - freq_arr[0]**-2)

        # this is an empirical hack. I do not know why
        # the PRESTO arrival times are different from t0
        # by the dispersion delay between the reference and
        # upper frequency
        t0 -= t_delay_mid

        if rfi_clean is True:
            data = rfi_test.apply_rfi_filters(data.astype(np.float32), delta_t)

        if subtract_zero is True:
            print("Subtracting zero DM")
            data_ts_zerodm = data.mean(0)
            data -= data_ts_zerodm[None]

        if clipping is not None:
            # Find tsamples > 8sigma and replace them with median
            assert type(clipping) in (float,
                                      int), 'clipping must be int or float'

            data_ts_zerodm = data.mean(0)
            stds, med = sigma_from_mad(data_ts_zerodm)
            ind = np.where(np.absolute(data_ts_zerodm - med) > 8.0 * stds)[0]
            data[:, ind] = np.median(data, axis=-1, keepdims=True)

        if ii < 0:
            fn_rfi_clean = reader.write_to_fil(data.transpose(), header,
                                               fn_fil_out)
        elif ii >= 0:
            fil_obj = reader.filterbank.FilterbankFile(fn_fil_out,
                                                       mode='readwrite')
            fil_obj.append_spectra(data.transpose())

        if calc_snr is True:
            data_filobj.data = data
            data_filobj.dedisperse(dm_)
            end_t = abs(4.15e3 * dm_ * (freq[0]**-2 - freq[1]**-2))
            end_pix = int(end_t / dt)
            end_pix_ds = int(end_t / dt / downsamp)

            data_rb = (data_filobj.data).copy()
            data_rb = data_rb[:, :-end_pix].mean(0)
            data_rb -= np.median(data_rb)

            SNRTools = tools.SNR_Tools()
            snr_max, width_max = SNRTools.calc_snr_widths(data_rb,
                                                          widths=range(100))

            #            snr_max2, width_max2 = tools.calc_snr_widths(data_rb,
            #                                         )
            print("S/N: %.2f width_used: %.3f width_tru: %.3f DM: %.1f" %
                  (snr_max, width_max, width / delta_t, dm_))

        else:
            snr_max = 10.0
            width_max = int(width / dt)

        f_params_out = open(fn_params_out, 'a+')
        f_params_out.write('%2f   %2f   %5f   %7d   %d\n' %
                           (params[0], snr_max, t0, t0_ind, width_max))

        f_params_out.close()
        del data, data_event

    params_full_arr = np.array(params_full_arr)
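The clipping branch above replaces time samples whose zero-DM power deviates from the median by more than 8 sigma, with sigma estimated from the median absolute deviation. A minimal standalone sketch of that filter, using a locally defined MAD estimator as a stand-in for sigma_from_mad:

import numpy as np

def mad_sigma(ts):
    """Robust sigma and median of a 1-D timestream via the MAD."""
    med = np.median(ts)
    mad = np.median(np.abs(ts - med))
    return 1.4826 * mad, med                   # MAD -> Gaussian-equivalent sigma

data = np.random.normal(0, 1, (16, 4096))
data[:, 100] += 50.0                           # a bright zero-DM spike
ts_zerodm = data.mean(0)                       # zero-DM timestream
std, med = mad_sigma(ts_zerodm)
bad = np.where(np.abs(ts_zerodm - med) > 8.0 * std)[0]
data[:, bad] = np.median(data, axis=-1, keepdims=True)   # replace with the per-channel median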