Example #1
def make_sim_WFs(N_WFs, WF_library, key, sigma, noise_RMS, amp_scale):
    """
    Generate simulated waveforms for a specified grain size (K0)

    Inputs:
        N_WFs: number of pulses in each waveform object
        WF_library: a dict of waveform catalogs (keys specify different channels)
        key: entry within the WF_library from which to generate the waveforms
        sigma: pulse spreading time
        noise_RMS: dict giving the RMS noise values for each channel
        amp_scale: scale factor by which each channel's amplitude is multiplied
    Outputs:
        WFs: dict giving waveform data for each channel.
    """

    WFs = {}
    for ch in WF_library.keys():
        if sigma > 0:
            BW = broadened_WF(WF_library[ch][key], sigma)
        else:
            BW = waveform(WF_library[ch][key].t, WF_library[ch][key].p)
        #BW.normalize()
        WFs[ch]=waveform(BW.t, np.tile(amp_scale[ch]*BW.p, [1, N_WFs])+\
           noise_RMS[ch]*np.random.randn(BW.p.size*N_WFs).reshape(BW.p.size, N_WFs))
        WFs[ch].noise_RMS = noise_RMS[ch] + np.zeros(WFs[ch].size)
        WFs[ch].shots = np.arange(WFs[ch].size)
    return WFs
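# --- Hedged illustration (not from the source): the simulation step in
# make_sim_WFs tiles a scaled template into N_WFs columns, then adds
# independent Gaussian noise to every sample of every pulse.  A
# self-contained NumPy sketch with a toy pulse; all values illustrative:
import numpy as np
n, N_WFs = 160, 4
p = np.exp(-0.5 * ((np.arange(n) - 80.) / 5.)**2).reshape(n, 1)  # toy pulse
amp, noise_RMS = 100., 1.0
sim = np.tile(amp * p, [1, N_WFs]) \
    + noise_RMS * np.random.randn(n * N_WFs).reshape(n, N_WFs)
print(sim.shape)  # (160, 4): one noisy copy of the pulse per column
Example #2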
def make_rx_scat_catalog(TX, h5_file=None, reduce_res=False):
    """
    make a dictionary of waveform templates by convolving the transmit pulse with
    subsurface-scattering SRFs
    """
    if h5_file is None:
        h5_file = '/Users/ben/Dropbox/ATM_red_green/subsurface_srf_no_BC.h5'
    with h5py.File(h5_file, 'r') as h5f:
        t0 = np.array(h5f['t']) * 1.e9
        z = np.zeros_like(t0)
        z[np.argmin(abs(t0))] = 1
        # zero non-finite samples before convolving, so NaNs in the transmit
        # pulse don't propagate through the convolution
        TX.p[~np.isfinite(TX.p)] = 0.
        TXc = np.convolve(TX.p.ravel(), z, 'full')
        t_full = np.arange(TXc.size) * 0.25
        t_full -= waveform(t_full, TXc).nSigmaMean()[0]
        RX = dict()
        r_vals = np.array(h5f['r_eff'])
        if reduce_res:
            r_vals = np.concatenate([
                r_vals[r_vals < 1e-4][::10],
                r_vals[(r_vals >= 1e-4) & (r_vals < 5e-3)][::2],
                r_vals[r_vals >= 5e-3]
            ])
        for row, r_val in enumerate(h5f['r_eff']):
            if r_val not in r_vals:
                continue
            rx0 = h5f['p'][row, :]
            temp = np.convolve(TX.p.ravel(), rx0, 'full') * 0.25e-9
            RX[r_val] = waveform(
                TX.t,
                np.interp(TX.t.ravel(), t_full.ravel(),
                          temp).reshape(TX.t.shape))
            RX[r_val].t0 = 0.
            RX[r_val].tc = RX[r_val].nSigmaMean()[0]
    return RX
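# --- Hedged illustration (not from the source): the 0.25e-9 factor in the
# convolution above is the sample interval in seconds; multiplying
# np.convolve by dt makes the discrete convolution approximate the
# continuous integral, so the areas of the inputs multiply.  Toy check:
import numpy as np
dt = 0.25
t = np.arange(0, 20, dt)
f = np.exp(-(t - 5.)**2 / 2.)        # toy transmit pulse
g = np.exp(-t / 3.) / 3.             # toy scattering response
fg = np.convolve(f, g, 'full') * dt  # ~ integral of f(tau) g(t - tau) dtau
print(np.trapz(f, dx=dt) * np.trapz(g, dx=dt), fg.sum() * dt)  # nearly equal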
Example #3
def wf_misfit(delta_t, sigma, WF, catalog, M, key_top,  G=None, return_data_est=False, fit_BG=False):
    """
        Find the misfit between a scaled and shifted template and a waveform
    """
    if G is None and fit_BG:
        G=np.ones((WF.p.size, 2))
    this_key=key_top+[sigma]+[delta_t]
    if (this_key in M) and (return_data_est is False):
        return M[this_key]['R']
    else:
        # check if the broadened but unshifted version of this key is in the catalog
        broadened_key=key_top+[sigma]
        if broadened_key in catalog:
            broadened_p=catalog[broadened_key].p
        else:
            # make a broadened version of the catalog WF
            if sigma==0:
                 broadened_p = catalog[key_top].p
            else:
                broadened_p = broaden_p( catalog[key_top], sigma)
            catalog[broadened_key]=waveform(catalog[key_top].t, broadened_p, t0=catalog[key_top].t0, tc=catalog[key_top].tc)
        # check if the shifted version of the broadened waveform is in the catalog
        if this_key not in catalog:
            # if not, make it.
            M[this_key]=listDict()
            temp_p = np.interp(WF.t.ravel(), (catalog[key_top].t-catalog[key_top].tc+delta_t).ravel(), broadened_p.ravel(), left=np.NaN, right=np.NaN)

            # Note that argmax on a binary array returns the first nonzero index (faster than where)
            ii=np.argmax(temp_p>0.01*np.nanmax(temp_p))
            mask=np.ones_like(temp_p, dtype=bool)
            mask[0:np.maximum(0, ii-4)] = False
            catalog[this_key] = waveform(catalog[broadened_key].t, temp_p,
                   tc=catalog[broadened_key].tc, t0=catalog[broadened_key].t0)
            catalog[this_key].params['mask']=mask
        this_entry=catalog[this_key]
        if fit_BG:
            # solve for the background and the amplitude
            R, m, Ginv, good = lin_fit_misfit(catalog[this_key].p, WF.p, G=G,\
                Ginv=this_entry.params['Ginv'], good_old=this_entry.params['good'])
            #catalog[this_key].params['Ginv']=Ginv
            this_entry.params['good']=good
            M[this_key] = {'K0':key_top[0], 'R':R, 'A':np.float64(m[0]), 'B':np.float64(m[1]), 'delta_t':delta_t, 'sigma':sigma}
        else:
            # solve for the amplitude only
            G=this_entry.p
            #try:
            #    ii=np.where(G>0.002*np.nanmax(G))[0][0]
            #except IndexError:
            #    print("The G>002 problem happened!")
            good=np.isfinite(G).ravel() & np.isfinite(WF.p).ravel() & this_entry.params['mask']
            #good[0:np.maximum(0,ii-4)]=0
            if this_entry.p_squared is None:
                this_entry.p_squared=this_entry.p**2
            #R, m, good = amp_misfit(G, WF.p, els=good, x_squared=catalog[this_key].p_squared)
            m, R = corr_no_mean(G.ravel(), WF.p.ravel(), this_entry.p_squared.ravel(), good.astype(np.int32).ravel(), G.size)
            M[this_key] = {'K0':key_top[0], 'R':R, 'A':m, 'B':0., 'delta_t':delta_t, 'sigma':sigma}
        if return_data_est:
            return R, G.dot(m)
        else:
            return R
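# --- Hedged sketch (an assumption, not the source implementation):
# corr_no_mean is called above but not defined in this file.  From its
# arguments (template G, data p, precomputed G**2, an integer mask, and a
# length) it appears to solve for a single amplitude with no mean offset.
# A plain NumPy stand-in; the real routine's residual normalization may
# differ:
import numpy as np
def corr_no_mean_sketch(G, p, G2, good, N):
    # N (the array length) is kept only to mirror the original signature
    g = good.astype(bool)
    m = np.dot(G[g], p[g]) / np.sum(G2[g])       # least-squares amplitude
    R = np.sqrt(np.mean((p[g] - m * G[g])**2))   # RMS residual
    return m, R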
Example #4
def res_test(impulse_file):

    with h5py.File(impulse_file,'r') as h5f:
        TX=waveform(np.array(h5f['/TX/t'])*1.e9,np.array(h5f['/TX/p']))
    dt_vals=[0.25, 0.1, 0.05, 0.025, 0.0125, 0.0125/2, 0.0125/4]
    ctr_vals=np.zeros_like(dt_vals)+np.NaN
    med_vals=np.zeros_like(dt_vals)+np.NaN
    wctr_vals=np.zeros_like(dt_vals)+np.NaN
    wmed_vals=np.zeros_like(dt_vals)+np.NaN
    sigma_vals=np.zeros_like(dt_vals)+np.NaN
    for ii, dt in enumerate(dt_vals):
        ti=np.arange(TX.t[0], TX.t[-1]+dt, dt)
        ti=ti.reshape([ti.size, 1])
        #TXi=waveform(ti, np.interp(ti, TX.t.ravel(), TX.p.ravel()).reshape(ti.shape))
        TXi=resample_wf(TX, dt)
        ctr_vals[ii], med_vals[ii] = wf_med_bar(TXi, 20., t_tol=0.001)
        wctr_vals[ii], wmed_vals[ii], _, sigma_vals[ii] = composite_stats(TXi, 20)
        plt.title(dt)

    plt.figure(11); plt.clf()
    plt.semilogx(dt_vals, ctr_vals, marker='o', label='ctr from int')
    plt.semilogx(dt_vals, med_vals, marker='x', label='med from int')
    plt.semilogx(dt_vals, wctr_vals, marker='o', label='ctr from wf')
    plt.semilogx(dt_vals, wmed_vals, marker='x', label='med from wf')
    plt.legend()
Example #5
def broadened_WF(WF, sigma):
    """
    Generate a version of WF broadened by a Gaussian of width sigma
    """
    nK = int(3 * np.ceil(sigma / WF.dt))
    tK = np.arange(-nK, nK + 1) * WF.dt
    K = gaussian(tK, 0, sigma)
    K = K / np.sum(K)
    return waveform(WF.t, np.convolve(WF.p.ravel(), K, 'same'))
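# --- Hedged sketch (assumption): the gaussian helper used above is not
# defined in this file; any standard Gaussian bump works, because
# broadened_WF renormalizes the kernel K to unit sum afterwards:
import numpy as np
def gaussian_sketch(t, mu, sigma):
    return np.exp(-(t - mu)**2 / (2. * sigma**2))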
Example #6
def make_composite_wf(k0s, catalog):
    """
    Build a composite waveform from the catalog entries closest to each value in k0s
    """
    catalog_k0=np.array([key for key in catalog])
    WF_list=[]
    for k0 in k0s:
        ii=np.argmin(np.abs(catalog_k0-k0))
        WF_list.append(catalog[catalog_k0[ii]])
    WFs=waveform(catalog[0].t, np.concatenate([WFi.p for WFi in WF_list], axis=1))
    WF_mean=WFs.calc_mean(normalize=False, threshold=5000)
    return WF_mean, WFs
Example #7
def broadened_misfit(delta_ts, sigma, WF, catalog, M, key_top,  t_tol=None, refine_parabolic=True):
    """
    Calculate the misfit between a broadened template and a waveform (searching over a range of shifts)
    """
    this_key=key_top+[sigma]
    if (this_key in M) and ('best' in M[this_key]):
        return M[this_key]['best']['R']
    else:
        M[this_key]=listDict()
        if this_key not in catalog:
            # if we haven't already broadened the WF to sigma, try it now:
            if sigma==0:
                catalog[this_key]=waveform(catalog[key_top].t, catalog[key_top].p, t0=catalog[key_top].t0, tc=catalog[key_top].tc)
            else:
                #nK=np.minimum(np.floor(catalog[key_top].p.size/2)-1,3*np.ceil(sigma/WF.dt))
                #tK=np.arange(-nK, nK+1)*WF.dt
                #K=gaussian(tK, 0, sigma)
                #K=K/np.sum(K)
                try:
                    catalog[this_key]=waveform(catalog[key_top].t, broaden_p(catalog[key_top], sigma))
                    #catalog[this_key]=waveform(catalog[key_top].t, np.convolve(catalog[key_top].p.ravel(), K,'same'))
                except ValueError:
                    print("Convolution failed")
        return fit_shifted(delta_ts, sigma, catalog, WF,  M, key_top, t_tol=t_tol, refine_parabolic=refine_parabolic)
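# --- Hedged sketch (assumption): broaden_p is called above but not defined
# in this file.  It presumably applies the same Gaussian smoothing as
# broadened_WF while returning only the power array:
import numpy as np
def broaden_p_sketch(WF, sigma):
    nK = int(3 * np.ceil(sigma / WF.dt))
    tK = np.arange(-nK, nK + 1) * WF.dt
    K = np.exp(-tK**2 / (2. * sigma**2))
    K /= K.sum()  # unit-sum kernel preserves the waveform's area
    return np.convolve(WF.p.ravel(), K, 'same').reshape(WF.p.shape)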
Example #8
def setup(scat_file, impulse_file):

    with h5py.File(impulse_file,'r') as h5f:
        TX=waveform(np.array(h5f['/TX/t']),np.array(h5f['/TX/p']))

    TX.t *= 1e9
    TX.t -= TX.nSigmaMean()[0]
    TX.tc = 0
    TR=np.round((np.max(TX.t)-np.min(TX.t))/0.25)*0.25
    t_i=np.arange(-TR/2, TR/2+0.25, 0.25)
    TX.p=np.interp(t_i, TX.t.ravel(), TX.p.ravel())
    TX.t=t_i
    TX.p.shape=[TX.p.size,1]
    TX.normalize()

    # make the library of templates

    catalog = dict()
    catalog.update({0.:TX})
    catalog.update(make_rx_scat_catalog(TX, h5_file=scat_file))
    return catalog
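# --- Hedged usage sketch: the file names below are placeholders, not
# paths from the source.
catalog = setup('subsurface_srf.h5', 'impulse_response.h5')
print('%d templates, r_eff keys from %g to %g' %
      (len(catalog), min(catalog), max(catalog)))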
Example #9
def read_ATM_file(fname,
                  getCountAndReturn=False,
                  shot0=0,
                  nShots=np.Inf,
                  readTX=True,
                  readRX=True):
    """
    Read data from an ATM file
    """
    with h5py.File(fname, 'r') as h5f:

        # figure out what shots to read
        #shotMax=h5f['/waveforms/twv/shot/gate_start'].size
        shotMax = h5f['/laser/calrng'].size
        if getCountAndReturn:
            return shotMax

        nShots = np.minimum(shotMax - shot0, nShots)
        shotN = int(shot0 + nShots)
        shot0 = int(shot0)
        # read in some of the data fields
        D_in = dict()

        # read the waveform starts, stops, and lengths for all shots in the file (inefficient, but hard to avoid)
        for key in ('/waveforms/twv/gate/wvfm_start', '/waveforms/twv/gate/wvfm_length', '/waveforms/twv/gate/position',\
                    '/waveforms/twv/gate/pulse/count'):
            D_in[key] = np.array(h5f[key], dtype=int)
        # read in the gate info for the shots we want to read
        for key in ('/waveforms/twv/shot/gate_start',
                    '/waveforms/twv/shot/gate_count', '/laser/gate_xmt',
                    '/laser/gate_rcv', '/laser/calrng'):
            D_in[key] = np.array(h5f[key][shot0:shotN], dtype=int)

        #read in the geolocation and time
        try:
            for key in ('/footprint/latitude', '/footprint/longitude',
                        '/footprint/elevation', '/laser/scan_azimuth'):
                D_in[key] = np.array(h5f[key][shot0:shotN])
        except KeyError:
            pass
        D_in['/waveforms/twv/shot/seconds_of_day'] = np.array(
            h5f['/waveforms/twv/shot/seconds_of_day'][shot0:shotN])
        # read the sampling interval
        dt = np.float64(h5f['/waveforms/twv/ancillary_data/sample_interval'])

        # figure out what samples to read from the 'amplitude' dataset
        gate0 = D_in['/waveforms/twv/shot/gate_start'][0] - 1 + D_in[
            '/laser/gate_xmt'][0] - 1
        sample_start = D_in['/waveforms/twv/gate/wvfm_start'][gate0] - 1
        gateN = D_in['/waveforms/twv/shot/gate_start'][-1] - 1 + D_in[
            '/laser/gate_rcv'][-1] - 1
        sample_end = D_in['/waveforms/twv/gate/wvfm_start'][gateN] + D_in[
            '/waveforms/twv/gate/wvfm_length'][gateN]
        # ... and read the amplitude.  The sample_start variable will be subtracted
        # from subsequent indexes into the amplitude array
        key = '/waveforms/twv/wvfm/amplitude'
        D_in[key] = np.array(h5f[key][sample_start:sample_end + 1], dtype=int)

        TX = list()
        RX = list()
        tx_samp0 = list()
        rx_samp0 = list()
        nPeaks = list()
        rxBuffer = np.zeros(192) + np.NaN
        for shot in range(int(nShots)):
            wfd = read_wf(D_in,
                          shot,
                          starting_sample=sample_start,
                          read_tx=readTX,
                          read_rx=readRX)
            if readTX:
                TX.append(wfd['tx']['P'][0:160])
                tx_samp0.append(wfd['tx']['pos'])
            if readRX:
                nRX = np.minimum(190, wfd['rx']['P'].size)
                rxBuffer[0:nRX] = wfd['rx']['P'][0:nRX]
                rxBuffer[nRX:] = np.NaN  # blank the rest so previous-shot samples don't persist
                RX.append(rxBuffer.copy())
                nPeaks.append(wfd['rx']['count'])
                rx_samp0.append(wfd['rx']['pos'])
        shots = np.arange(shot0, shotN, dtype=int)
        try:
            result = {
                'az': D_in['/laser/scan_azimuth'],
                'dt': dt,
                'elevation': D_in['/footprint/elevation'],
                'latitude': D_in['/footprint/latitude'],
                'longitude': D_in['/footprint/longitude']
            }
        except KeyError:
            result = {}
        result['calrng'] = D_in['/laser/calrng']
        result['seconds_of_day'] = D_in['/waveforms/twv/shot/seconds_of_day']

        if readTX:
            TX = np.c_[TX].transpose()
            result['TX'] = waveform(
                np.arange(TX.shape[0]) * dt,
                TX,
                shots=shots,
                t0=tx_samp0 * dt,
                seconds_of_day=D_in['/waveforms/twv/shot/seconds_of_day'])

        L_TX = 30
        if readRX:
            RX = np.c_[RX].transpose()
            nPeaks = np.c_[nPeaks].ravel()
            result['RX'] = waveform(
                np.arange(RX.shape[0]) * dt,
                RX,
                shots=shots,
                nPeaks=nPeaks,
                t0=rx_samp0 * dt,
                seconds_of_day=D_in['/waveforms/twv/shot/seconds_of_day'])
            result['rx_samp0'] = rx_samp0
            result['RX'].error_flag[
                np.abs(result['calrng'] -
                       (result['RX'].t0 * .15 - L_TX)) > 55] = 1
        result['shots'] = shots
    return result
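# --- Hedged usage sketch (placeholder file name): read the shot count,
# then read the waveforms in blocks.
fname = 'atm_waveforms.h5'
n_total = read_ATM_file(fname, getCountAndReturn=True)
for shot0 in range(0, n_total, 1000):
    D = read_ATM_file(fname, shot0=shot0, nShots=1000)
    print(shot0, D['RX'].p.shape)  # 192 samples x up to 1000 shots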
Example #10
def errors_for_one_scat_file(scat_file, TX_file, out_file=None):

    N_WFs = 256
    sigma_vals = [0, 0.25, 0.5, 1, 2]
    A_vals = [25., 50., 100., 200.]

    with h5py.File(TX_file, 'r') as h5f:
        TX = waveform(np.array(h5f['/TX/t']), np.array(h5f['/TX/p']))
    TX.t -= TX.nSigmaMean()[0]
    TX.tc = 0
    WF_library = dict()
    WF_library.update({0.: TX})
    if scat_file is not None:
        WF_library.update(make_rx_scat_catalog(TX, h5_file=scat_file))
    K0_vals = np.sort(list(WF_library))[::5]
    #K0_vals=np.sort(list(WF_library))[-2:]

    N_out = len(sigma_vals) * len(A_vals) * len(K0_vals)
    Dstats = {
        field: np.zeros(N_out) + np.NaN
        for field in {
            'K16', 'K84', 'sigma16', 'sigma84', 'sigma', 'A', 'K0', 'Kmed',
            'Ksigma', 'Ksigma_est'
        }
    }

    catalogBuffer = None
    ii = 0

    for key in K0_vals:
        for sigma in sigma_vals:
            for A in A_vals:
                if sigma > 0:
                    BW = broadened_WF(WF_library[key], sigma)
                else:
                    BW = waveform(WF_library[key].t, WF_library[key].p)
                BW.normalize()
                WFs = waveform(
                    TX.t,
                    np.tile(A * BW.p, [1, N_WFs]) + np.random.randn(
                        BW.p.size * N_WFs).reshape(BW.p.size, N_WFs))
                WFs.shots = np.arange(WFs.size)
                D_out, rxData, D, catalogBuffer = proc_RX(
                    None,
                    np.arange(N_WFs),
                    rxData=WFs,
                    sigmas=np.array([0., 1.]),
                    deltas=np.arange(-0.5, 1, 0.5),
                    TX=TX,
                    WF_library=WF_library,
                    catalogBuffer=catalogBuffer)
                sR = sps.scoreatpercentile(D_out['sigma'], [16, 84])
                Dstats['A'][ii] = A
                Dstats['sigma'][ii] = sigma
                Dstats['K0'][ii] = key
                Dstats['sigma16'][ii] = sR[0]
                Dstats['sigma84'][ii] = sR[1]
                KR = sps.scoreatpercentile(D_out['K0'], [16, 84])
                Dstats['K16'][ii] = KR[0]
                Dstats['K84'][ii] = KR[1]
                Dstats['Kmed'][ii] = np.nanmedian(D_out['K0'])
                Dstats['Ksigma_est'][ii] = np.nanmedian(D_out['Kmax'] -
                                                        D_out['Kmin'])
                Dstats['Ksigma'][ii] = np.nanstd(D_out['K0'])
                print([key, sigma, A, KR - key, Dstats['Ksigma'][ii]])
                ii += 1

    print("yep")
    if out_file is not None:
        if os.path.isfile(out_file):
            os.remove(out_file)
        out_h5 = h5py.File(out_file, 'w')
        for kk in Dstats:
            out_h5.create_dataset(kk, data=Dstats[kk])
        out_h5.close()
    return Dstats
Example #11
def broadened_WF(TX, sigma):
    """
    Generate a version of TX broadened by a Gaussian of width sigma
    """
    nK = int(3 * np.ceil(sigma / TX.dt))
    tK = np.arange(-nK, nK + 1) * TX.dt
    K = gaussian(tK, 0, sigma)
    K = K / np.sum(K)
    return waveform(TX.t, np.convolve(TX.p.ravel(), K, 'same'))
Example #12
def main(args):
    # main method : open the input files, create output files, process waveforms

    input_files={}
    for ii, ch in enumerate(args.ch_names):
        if ch != 'None':
            input_files[ch]=args.input_files[ii]
    channels = list(input_files.keys())

    # get the waveform count from the output file
    shots= choose_shots(input_files, args.reduce_by)
    lastShot=np.minimum(args.startShot+args.nShots, len(shots[channels[0]]))
    # lastShot is exclusive (the arange over blocks stops there), so no +1
    nWFs = lastShot-args.startShot

    # make the output file
    if os.path.isfile(args.output_file):
        os.remove(args.output_file)
    # define the output datasets
    outDS={}
    outDS['ch']=['R', 'A', 'B', 'delta_t', 't0','tc', 'noise_RMS','shot','Amax','seconds_of_day','nPeaks']
    outDS['both']=['R', 'K0', 'Kmin', 'Kmax', 'sigma']
    outDS['location']=['latitude', 'longitude', 'elevation']
    out_h5 = h5py.File(args.output_file,'w')
    for grp in ['both','location']:
        out_h5.create_group('/'+grp)
        for DS in outDS[grp]:
            out_h5.create_dataset('/'+grp+'/'+DS, (nWFs,), dtype='f8')
    for ch in channels:
        out_h5.create_group('/'+ch)
        for DS in outDS['ch']:
            out_h5.create_dataset('/'+ch+'/'+DS, (nWFs,), dtype='f8')

    # make groups in the file for transmit data
    for ch in channels:
        for field in ['t0','A','R','shot','sigma']:
            out_h5.create_dataset('/TX/%s/%s' % (ch, field), (nWFs,))

    if args.waveforms:
        for ch in channels:
            out_h5.create_dataset('RX/%s/p' % ch, (192, nWFs))
            out_h5.create_dataset('RX/%s/p_fit' % ch, (192, nWFs))
            out_h5.create_dataset('RX/%s/t_shift' % ch, (nWFs,))
    TX={}
    # get the transmit pulse
    for ind, ch in enumerate(channels):
        with h5py.File(args.TXfiles[ind],'r') as fh:
            TX[ch]=waveform(np.array(fh['/TX/t']), np.array(fh['/TX/p']) )
        TX[ch].t -= TX[ch].nSigmaMean()[0]
        TX[ch].tc = 0
        TX[ch].normalize()
    # write the transmit pulse to the file
    for ch in channels:
        out_h5.create_dataset("/TX/%s/t" % ch, data=TX[ch].t.ravel())
        out_h5.create_dataset("/TX/%s/p" % ch, data=TX[ch].p.ravel())

    # initialize the library of templates for the transmit waveforms
    TX_library={}
    for ind, ch in enumerate(channels):
        TX_library[ch] = listDict()
        TX_library[ch].update({0.:TX[ch]})

    # initialize the library of templates for the received waveforms
    WF_library=dict()
    for ind, ch in enumerate(channels):
        WF_library[ch] = dict()
        WF_library[ch].update({0.:TX[ch]})
        WF_library[ch].update(make_rx_scat_catalog(TX[ch], h5_file=args.scat_files[ind]))


    print("Returns:")
    # loop over start vals (one block at a time...)
    # choose how to divide the output
    blocksize=1000
    start_vals=args.startShot+np.arange(0, nWFs, blocksize, dtype=int)

    catalog_buffers={ch:listDict() for ch in channels}
    TX_catalog_buffers={ch:listDict() for ch in channels}
    time_old=time()

    sigmas=np.arange(0, 5, 0.25)
    # choose a set of delta t values
    delta_ts=np.arange(-1., 1.5, 0.5)

    D={}
    for shot0 in start_vals:
        outShot0=shot0-args.startShot
        these_shots=np.arange(shot0, np.minimum(shot0+blocksize, lastShot), dtype=int)
        if len(these_shots) < 1:
            continue
        #tic=time()
        wf_data={}
        for ch in channels:
            ch_shots=shots[ch][these_shots]
            # make the return waveform structure
            D=read_ATM_file(input_files[ch], shot0=ch_shots[0], nShots=ch_shots[-1]-ch_shots[0]+1)
            
            # fit the transmit data for this channel and these pulses
            D['TX']=D['TX'][np.in1d(D['TX'].shots, ch_shots)]
            # set t0 to the center of the waveform
            t_wf_ctr = np.nanmean(D['TX'].t)
            D['TX'].t0 += t_wf_ctr
            D['TX'].t -= t_wf_ctr
            # subtract the background noise
            D['TX'].subBG(t50_minus=3)
            # calculate tc (centroid time relative to t0)
            D['TX'].tc = D['TX'].threshold_centroid(fraction=0.38)
            #D_out_TX, catalog_buffers= fit_catalogs({ch:D['TX']}, TX_library, sigmas, delta_ts, \
            #                            t_tol=0.25, sigma_tol=0.25,  \
            #                            return_catalogs=True,  catalogs=catalog_buffers)
            D_out_TX = fit_catalogs({ch:D['TX']}, TX_library, sigmas, delta_ts, \
                                        t_tol=0.25, sigma_tol=0.25,  \
                                        return_catalogs=False,  catalogs=TX_catalog_buffers, params=outDS)
            N_out=len(D_out_TX[ch]['A'])
            for field in ['t0','A','R','shot']:
                out_h5['/TX/%s/%s' % (ch, field)][outShot0:outShot0+N_out]=D_out_TX[ch][field].ravel()
            out_h5['/TX/%s/%s' % (ch, 'sigma')][outShot0:outShot0+N_out]=D_out_TX['both']['sigma'].ravel()

            wf_data[ch]=D['RX']
            wf_data[ch]=wf_data[ch][np.in1d(wf_data[ch].shots, ch_shots)]
            # identify the samples that have clipped amplitudes:
            clipped=wf_data[ch].p >= 255
            t_wf_ctr = np.nanmean(wf_data[ch].t)
            wf_data[ch].t -= t_wf_ctr
            wf_data[ch].t0 += t_wf_ctr
            wf_data[ch].subBG(t50_minus=3)

            if 'latitude' in D:
                # only one channel has geolocation information; copy it and use the 'shot' field to match it to the output data
                loc_info={ff:D[ff] for ff in outDS['location']}
                loc_info['channel']=ch
                loc_info['shot']=D['shots']

            wf_data[ch].tc = wf_data[ch].threshold_centroid(fraction=0.38)
            wf_data[ch].p[clipped]=np.NaN
        # now fit the returns with the waveform model
        tic=time()
        D_out, catalog_buffers= fit_catalogs(wf_data, WF_library, sigmas, delta_ts, \
                                            t_tol=0.25, sigma_tol=0.25, return_data_est=args.waveforms, \
                                            return_catalogs=True,  catalogs=catalog_buffers, params=outDS)

        delta_time=time()-tic

        # write out the fit information
        N_out=D_out['both']['R'].size
        for ch in channels:
            for key in outDS['ch']:
                try:
                    out_h5[ch][key][outShot0:outShot0+N_out]=D_out[ch][key].ravel()
                except OSError:
                    print("OSError for channel %s,  key=%s, outshot0=%d, outshotN=%d, nDS=%d"% (ch, key, outShot0, outShot0+N_out, out_h5[key].size))
        for key in outDS['both']:
            try:
                out_h5['both'][key][outShot0:outShot0+N_out]=D_out['both'][key].ravel()
            except OSError:
                print("OSError for both channels, key=%s, outshot0=%d, outshotN=%d, nDS=%d"% (key, outShot0, outShot0+N_out, out_h5[key].size))

        # write out the location info
        loc_ind=np.flatnonzero(np.in1d(loc_info['shot'], D_out[loc_info['channel']]['shot']))
        for field in outDS['location']:
            out_h5['location'][field][outShot0:outShot0+N_out]=loc_info[field][loc_ind]

        # write out the waveforms
        if args.waveforms:
            for ch in channels:
                out_h5['RX/'+ch+'/p_fit'][:, outShot0:outShot0+N_out] = np.squeeze(D_out[ch]['wf_est']).T
                out_h5['RX/'+ch+'/p'][:, outShot0:outShot0+N_out] = wf_data[ch].p
                out_h5['RX/'+ch+'/t_shift'][outShot0:outShot0+N_out] = D_out[ch]['t_shift'].ravel()

        print("  shot=%d out of %d, N_keys=%d, dt=%5.1f" % (shot0+blocksize, start_vals[-1]+blocksize, len(catalog_buffers['G'].keys()), delta_time))
    print("   time to fit RX=%3.2f" % (time()-time_old))

    if args.waveforms:
        for ch in channels:
            out_h5.create_dataset('RX/'+ch+'/t', data=wf_data[ch].t.ravel())

    out_h5.close()
Example #13
def fit_catalog(WFs, catalog_in, sigmas, delta_ts, t_tol=None, sigma_tol=None, return_data_est=False, return_catalog=False, catalog=None):
    """
    Search a library of waveforms for the best match between the broadened, shifted library waveform
    and the target waveforms

    Inputs:
        WFs: a waveform object, whose fields include:
            't': the waveform's time vector
            'p': the power samples of the waveform
            'tc': a center time relative to which the waveform's time is shifted
        catalog_in: A dictionary containing waveform objects that will be broadened and
                    shifted to match the waveforms in 'WFs'
        sigmas: a list of spread values that will be searched for each template and waveform
                The search over sigmas terminates when a minimum is found
        delta_ts: a list of time-shift values that will be searched for each template and
                waveform.  All of these will be searched, then the results will be refined
                to a tolerance of t_tol
        keyword arguments:
            return_data_est:  set to True if the algorithm should return the best-matching
                shifted and broadened template for each input
            t_tol: tolerance for the time search, defaults to WFs.dt/10
    Outputs:
        WFp: a set of best-fitting waveform parameters that give:
            delta_t: the time-shift required to align the template and measured waveforms
            sigma: the broadening applied to the measured waveform
            k0: the key into the waveform catalog for the best-fitting waveform

    """
    # set a sensible tolerance for delta_t if none is specified
    if t_tol is None:
        t_tol=WFs.dt*0.1
    if sigma_tol is None:
        sigma_tol=0.25
    # make an empty output_dictionary
    WFp_empty={f:np.NaN for f in ['K0','R','A','B','delta_t','sigma','t0','Kmin','Kmax','shot']}
    if return_data_est:
        WFp_empty['wf_est']=np.zeros_like(WFs.t)+np.NaN

    # make an empty container where we will keep waveforms we've tried already
    if catalog is None:
        catalog=listDict()
    keys=np.sort(list(catalog_in))

    # loop over the library of templates
    for ii, kk in enumerate(keys):
        # check if we've searched this template before, otherwise copy it into
        # the library of checked templates
        if [kk] not in catalog:
            # make a copy of the current template
            temp=catalog_in[kk]
            catalog[[kk]]=waveform(temp.t, temp.p, t0=temp.t0, tc=temp.tc)

    W_catalog=np.zeros(keys.shape)
    for ind, key in enumerate(keys):
        W_catalog[ind]=catalog_in[key].fwhm()[0]

    fit_params=[WFp_empty.copy() for ii in range(WFs.size)]
    sigma_last=None
    t_center=WFs.t.mean()
    # loop over input waveforms
    for WF_count in range(WFs.size):
        WF=WFs[WF_count]
        if WF.nPeaks > 1:
            continue
        # shift the waveform to put its tc at the center of the time vector
        delta_samp=np.round((WF.tc-t_center)/WF.dt)
        WF.p=integer_shift(WF.p, -delta_samp)
        WF.t0=-delta_samp*WF.dt

        # set up a matching dictionary (contains keys of waveforms and their misfits)
        M=listDict()
        # this is the bulk of the work, and it's where problems happen.  Wrap it in a try:
        # and write out errors to be examined later
        try:
            if len(keys)>1:
                 # Search over input keys to find the best misfit between this template and the waveform
                fB=lambda ind: fit_broadened(delta_ts, None,  WF, catalog, M, [keys[ind]], sigma_tol=sigma_tol, t_tol=t_tol, sigma_last=sigma_last)
                W_match_ind=np.flatnonzero(W_catalog >= WF.fwhm()[0])
                if len(W_match_ind) >0:
                    ind=np.array(sorted(set([0, W_match_ind[0]-2, W_match_ind[0]+2])))
                    ind=ind[(ind >= 0) & (ind<len(keys))]
                else:
                    ind=[2, 4]
                iBest, Rbest = golden_section_search(fB, ind, delta_x=2, bnds=[0, len(keys)-1], integer_steps=True, tol=1)
                iBest=int(iBest)
            else:
                # only one key in input, return its misfit
                Rbest=fit_broadened(delta_ts, None, WF, catalog, M, [keys[0]], sigma_tol=sigma_tol, t_tol=t_tol, sigma_last=sigma_last)
                iBest=0
            this_key=[keys[iBest]]
            M['best']={'key':this_key, 'R':Rbest}
            searched_keys = np.array([this_key for this_key in keys if [this_key] in M])
            R=np.array([M[[ki]]['best']['R'] for ki in searched_keys])

            # recursively traverse the M dict for the best match.  The lowest-level match
            # will not have a 'best' entry
            while 'best' in M[this_key]:
                this_key=M[this_key]['best']['key']
            # write out the best model information
            fit_params[WF_count].update(M[this_key])
            fit_params[WF_count]['delta_t'] -= WF.t0[0]
            fit_params[WF_count]['shot'] = WF.shots[0]
            sigma_last=M[this_key]['sigma']
            R_max=fit_params[WF_count]['R']*(1.+1./np.sqrt(WF.t.size))
            if np.sum(searched_keys>0)>=3:
                these=np.flatnonzero(searched_keys>0)
                if len(these) > 3:
                     ind_keys=np.argsort(R[these])
                     these=these[ind_keys[0:4]]
                E_roots=np.polynomial.polynomial.Polynomial.fit(np.log10(searched_keys[these]), R[these]-R_max, 2).roots()
                if np.any(np.imag(E_roots)!=0):
                    fit_params[WF_count]['Kmax']=10**np.minimum(3,np.polynomial.polynomial.Polynomial.fit(np.log10(searched_keys[these]), R[these]-R_max, 1).roots()[0])
                    fit_params[WF_count]['Kmin']=np.min(searched_keys[R<R_max])
                else:
                    fit_params[WF_count]['Kmin']=10**np.min(E_roots)
                    fit_params[WF_count]['Kmax']=10**np.max(E_roots)
            if (0. in searched_keys) and R[searched_keys==0]<R_max:
                fit_params[WF_count]['Kmin']=0.

            #print(this_key+[R[iR][0]])
            if return_data_est or DOPLOT:
                #             wf_misfit(delta_t, sigma, WF, catalog, M, key_top, G=None, return_data_est=False):
                WF.t=WF.t-WF.t0
                R0, wf_est=wf_misfit(fit_params[WF_count]['delta_t'], fit_params[WF_count]['sigma'], WFs[WF_count], catalog, M, [this_key[0]], return_data_est=True)
                fit_params[WF_count]['wf_est']=wf_est#integer_shift(wf_est, -delta_samp)
            if DOPLOT:
                plt.figure();
                plt.plot(WF.t, integer_shift(WF.p, delta_samp),'k.')
                plt.plot(WF.t, wf_est,'r')
                plt.title('K=%f, dt=%f, sigma=%f, R=%f' % (this_key[0], fit_params[WF_count]['delta_t'], fit_params[WF_count]['sigma'], fit_params[WF_count]['R']))
                print(WF_count)
        except KeyboardInterrupt:
            sys.exit()
        except Exception as e:
            print("Exception thrown for shot %d" % WF.shots)
            print(e)
            pass
        if np.mod(WF_count, 1000)==0 and WF_count > 0:
            print('    N=%d, N_keys=%d' % (WF_count, len(list(catalog))))

    result=dict()
    for key in WFp_empty:
        if key in ['wf_est']:
            result[key]=np.concatenate( [ ii['wf_est'] for ii in fit_params ], axis=1 )
        else:
            result[key]=np.array([ii[key] for ii in fit_params]).ravel()

    if return_catalog:
        return result, catalog
    else:
        return result
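# --- Hedged sketch (assumption): golden_section_search is not defined in
# this file; its call signature suggests a 1-D minimizer over integer
# template indices.  The stand-in below is a plain cached neighbor-descent
# (not a true golden-section search) returning (x_best, f_best):
import numpy as np
def integer_min_search_sketch(f, x0, bnds, tol=1):
    # tol is kept only to mirror the original signature
    cache = {}
    def fc(x):
        if x not in cache:
            cache[x] = f(int(x))
        return cache[x]
    lo, hi = int(bnds[0]), int(bnds[1])
    for x in x0:                      # seed with the starting indices
        fc(int(np.clip(x, lo, hi)))
    while True:
        x_best = min(cache, key=cache.get)
        new = [x for x in (x_best - 1, x_best + 1)
               if lo <= x <= hi and x not in cache]
        if not new:                   # both neighbors known: local minimum
            return x_best, cache[x_best]
        for x in new:
            fc(x)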
Example #14
def fit_catalogs(WFs,
                 catalogs_in,
                 sigmas,
                 delta_ts,
                 t_tol=None,
                 sigma_tol=None,
                 return_data_est=False,
                 return_catalogs=False,
                 catalogs=None,
                 params=None,
                 M_list=None):
    """
    Search a library of waveforms for the best match between the broadened, shifted library waveform
    and the target waveforms

    Inputs:
        WFs: a dict of waveform objects, whose entries include:
            't': the waveform's time vector
            'p': the power samples of the waveform
            'tc': a center time relative to which the waveform's time is shifted
        catalogs_in: A dictionary of per-channel catalogs of waveform objects that will
                    be broadened and shifted to match the waveforms in 'WFs'
        sigmas: a list of spread values that will be searched for each template and waveform
                The search over sigmas terminates when a minimum is found
        delta_ts: a list of time-shift values that will be searched for each template and
                waveform.  All of these will be searched, then the results will be refined
                to a tolerance of t_tol
        'params' : a dict giving a list of parameters to return in D_out
        keyword arguments:
            return_data_est:  set to True if the algorithm should return the best-matching
                shifted and broadened template for each input
            t_tol: tolerance for the time search, defaults to one tenth of the waveform sampling interval
    Outputs:
        WFp: a set of best-fitting waveform parameters that give:
            delta_t: the time-shift required to align the template and measured waveforms
            sigma: the broadening applied to the measured waveform
            k0: the key into the waveform catalog for the best-fitting waveform

    """
    # set a sensible tolerance for delta_t if none is specified
    if t_tol is None:
        # WFs is a dict of channels: use the first channel's sampling interval
        t_tol = WFs[next(iter(WFs))].dt * 0.1
    if sigma_tol is None:
        sigma_tol = 0.25

    channels = list(WFs.keys())

    N_shots = WFs[channels[0]].size
    N_samps = WFs[channels[0]].t.size

    # make an empty output_dictionary
    if params is None:
        params = {
            'ch': [
                'R', 'A', 'B', 'noise_RMS', 'delta_t', 'shot', 't0', 'tc',
                't_shift'
            ],
            'both': ['K0', 'R', 'sigma', 'Kmin', 'Kmax']
        }
    WFp_empty = {}
    for ch in channels:
        WFp_empty[ch] = {f: np.NaN for f in params['ch']}
    WFp_empty['both'] = {f: np.NaN for f in params['both']}
    if return_data_est:
        for ch in channels:
            WFp_empty[ch]['wf_est'] = np.zeros_like(WFs[ch].t) + np.NaN
            WFp_empty[ch]['t_shift'] = np.NaN

    if catalogs is None:
        catalogs = {ch: listDict() for ch in channels}

    # make a container for the pulse widths for the catalogs, and copy the input catalogs into the buffer catalogs
    W_catalogs = {}
    for ch in channels:
        k_vals = np.sort(list(catalogs_in[ch]))
        W_catalogs[ch] = np.zeros(k_vals.shape)
        # loop over the library of templates
        for ii, kk in enumerate(k_vals):
            # record the width of the waveform
            W_catalogs[ch][ii] = catalogs_in[ch][kk].fwhm()[0]
            # check if we've searched this template before, otherwise copy it into
            # the library of checked templates
            if [kk] not in catalogs[ch]:
                # make a copy of the current template
                temp = catalogs_in[ch][kk]
                catalogs[ch][[kk]] = waveform(temp.t,
                                              temp.p,
                                              t0=temp.t0,
                                              tc=temp.tc)

    fit_param_list = []
    sigma_last = None
    t_center = {ch: WFs[ch].t.mean() for ch in channels}
    last_keys = {ch: [] for ch in channels}

    # loop over input waveforms
    for WF_count in range(N_shots):
        fit_params = deepcopy(WFp_empty)
        WF = {ch: WFs[ch][WF_count] for ch in channels}

        # skip waveforms for which we detected an error (so far just waveforms that have t0 far from the calrng value)
        if np.any([WF[ch].error_flag for ch in WF.keys()]):
            continue
        if np.any([np.isnan(WF[ch].tc) for ch in WF.keys()]):
            continue
        # skip multi-peak returns:
        #n_peaks=np.array([WF[ch].nPeaks for ch in channels])
        #if np.any(n_peaks > 1):
        #    continue
        # shift the waveforms to put their tcs at the center of the time vector
        # doing this means that we have to store fewer shifted catalogs
        for ch in channels:
            delta_samp = np.round((WF[ch].tc - t_center[ch]) / WF[ch].dt)
            WF[ch].p = integer_shift(WF[ch].p, -delta_samp)
            WF[ch].t0 += delta_samp * WF[ch].dt
            WF[ch].tc -= delta_samp * WF[ch].dt
            WF[ch].t_shift = delta_samp * WF[ch].dt

        # set up a matching dictionary (contains keys of waveforms and their misfits)
        Ms = {ch: listDict() for ch in channels}
        # this is the bulk of the work, and it's where problems happen.  Wrap it in a try:
        # and write out errors to be examined later
        try:
            if len(k_vals) > 1:
                # find the best misfit between this template and the waveform
                fB = lambda ind: fit_broadened(delta_ts,
                                               None,
                                               WF,
                                               catalogs,
                                               Ms, [k_vals[ind]],
                                               sigma_tol=sigma_tol,
                                               t_tol=t_tol,
                                               sigma_last=sigma_last,
                                               refine_sigma=True)
                W_broad_ind = 0
                # find the first catalog entry that's broader than the waveform (check both channels, pick the broader one)
                for ch in channels:
                    this_broad_ind = np.flatnonzero(
                        W_catalogs[ch] >= WF[ch].fwhm()[0])
                    if len(this_broad_ind) == 0:
                        this_broad_ind = len(W_catalogs[ch])
                    else:
                        this_broad_ind = this_broad_ind[0]
                    W_broad_ind = np.maximum(W_broad_ind, this_broad_ind)
                # search two steps on either side of the broadness-matched waveform, as well as zero (all broadening due to roughness)
                key_search_ind = np.array(
                    sorted(tuple(set([0, W_broad_ind - 2, W_broad_ind + 2]))))
                key_search_ind = key_search_ind[(key_search_ind >= 0) &
                                                (key_search_ind < len(k_vals))]
                search_hist = {}
                iBest, Rbest = golden_section_search(fB,
                                                     key_search_ind,
                                                     delta_x=2,
                                                     bnds=[0,
                                                           len(k_vals) - 1],
                                                     integer_steps=True,
                                                     tol=1,
                                                     refine_parabolic=False,
                                                     search_hist=search_hist)
                iBest = int(iBest)
            else:
                _ = fit_broadened(delta_ts,
                                  None,
                                  WF,
                                  catalogs,
                                  Ms, [k_vals[0]],
                                  sigma_tol=sigma_tol,
                                  t_tol=t_tol,
                                  sigma_last=sigma_last)
                iBest = 0
                Rbest = Ms[ch][[k_vals[0]]]['best']['R']
            this_kval = [k_vals[iBest]]
            fit_params['both']['R'] = Rbest
            fit_params['both']['K0'] = this_kval
            R_dict = {}
            sigma_last = 0
            for ch in channels:
                M = Ms[ch]
                M['best'] = {'key': this_kval, 'R': M[this_kval]['best']['R']}
                for ki in [this_key for this_key in k_vals if [this_key] in M]:
                    if ki in R_dict:
                        R_dict[ki] += M[[ki]]['best']['R'] / WF[ch].noise_RMS
                    else:
                        R_dict[ki] = M[[ki]]['best']['R'] / WF[ch].noise_RMS

                # recursively traverse the M dict for the best match.  The lowest-level match
                # will not have a 'best' entry
                this_key = this_kval
                while 'best' in M[this_key]:
                    this_key = M[this_key]['best']['key']
                # write out the best model information
                fit_params[ch].update(M[this_key])
                fit_params[ch]['noise_RMS'] = WF[ch].noise_RMS[0]
                fit_params[ch]['tc'] = WF[ch].tc[0]
                fit_params[ch]['Amax'] = np.nanmax(WF[ch].p)
                fit_params[ch]['seconds_of_day'] = WF[ch].seconds_of_day
                #fit_params[ch][WF_count]['delta_t'] -= WF[ch].t0
                fit_params['both']['shot'] = WF[ch].shots[0]
                if 'nPeaks' in params['ch']:
                    fit_params[ch]['nPeaks'] = WF[ch].nPeaks
                sigma_last = np.maximum(sigma_last, M[this_key]['sigma'])
            fit_params['both']['sigma'] = M[this_key]['sigma']
            searched_k_vals = np.array(sorted(R_dict.keys()))
            R = np.array([R_dict[ki] for ki in searched_k_vals]).ravel()

            R_max = fit_params['both']['R'] * (1. + 1. / np.sqrt(N_samps))

            these = np.flatnonzero(searched_k_vals > 0)
            if np.sum(R[these] > 0) > 2:
                if len(these) > 3:
                    ind_k_vals = np.argsort(R[these])
                    these = these[ind_k_vals[0:4]]
                E_roots = np.polynomial.polynomial.Polynomial.fit(
                    np.log10(searched_k_vals[these]), R[these] - R_max,
                    2).roots()
                if (len(E_roots) == 0) or np.any(np.imag(E_roots) != 0):
                    fit_params['both']['Kmax']=10**np.minimum(3,\
                              np.polynomial.polynomial.Polynomial.fit(\
                              np.log10(searched_k_vals[these]), R[these]-R_max, 1).roots()[0])
                    fit_params['both']['Kmin'] = np.min(
                        searched_k_vals[R < R_max])
                else:
                    fit_params['both']['Kmin'] = 10**np.min(E_roots)
                    fit_params['both']['Kmax'] = 10**np.max(E_roots)
            if (0. in searched_k_vals) and R[searched_k_vals == 0] < R_max:
                fit_params['both']['Kmin'] = 0.
            #copy remaining waveform parameters to the output data structure
            for ch in channels:
                fit_params[ch]['shot'] = WF[ch].shots[0]
                fit_params[ch]['t0'] = WF[ch].t0[0]
                fit_params[ch]['t_shift'] = WF[ch].t_shift[0]
                fit_params[ch]['noise_RMS'] = WF[ch].noise_RMS[0]
            #print(this_key+[R[iR][0]])
            if return_data_est or DOPLOT:
                # call WF_misfit for each channel
                wf_est = {}
                for ch in WF.keys():
                    # store per-channel estimates in the wf_est dict (assigning
                    # to bare wf_est would clobber the dict created above)
                    R0, wf_est[ch] = wf_misfit(fit_params[ch]['delta_t'],
                                               fit_params[ch]['sigma'],
                                               WF[ch],
                                               catalogs[ch],
                                               Ms[ch], [this_key[0]],
                                               return_data_est=True)
                    fit_params[ch]['wf_est'] = wf_est[ch]
                    fit_params[ch]['t_shift'] = WF[ch].t_shift

            if KEY_SEARCH_PLOT:
                import matplotlib.pyplot as plt
                ch_keys = {}
                new_keys = {}
                fig = plt.gcf()
                fig.clf()
                for ind, ch in enumerate(channels):
                    fig.add_subplot(2, 1, ind + 1)
                    ch_keys[ch] = [[key[0:2]] for key in catalogs[ch].keys()
                                   if len(key) > 1]
                    new_keys[ch] = [
                        key for key in ch_keys[ch] if key not in last_keys[ch]
                    ]
                    kxy = np.concatenate(ch_keys[ch], axis=0)
                    plt.plot(np.log10(kxy[:, 0]), kxy[:, 1], 'k.')
                    if len(new_keys[ch]) > 0:
                        kxy_new = np.concatenate(new_keys[ch], axis=0)
                        plt.plot(np.log10(kxy_new[:, 0]), kxy_new[:, 1], 'ro')
                last_keys = {ch: ch_keys[ch].copy() for ch in ch_keys.keys()}
            # report
            fit_param_list += [fit_params]
            if DOPLOT:
                import matplotlib.pyplot as plt
                plt.figure()
                colors = {'IR': 'r', 'G': 'g'}
                this_title = ''
                for ch in channels:
                    #plt.plot(WF[ch].t, integer_shift(WF[ch].p, delta_samp),'.', color=colors[ch])
                    plt.plot(WF[ch].t, WF[ch].p, 'x', color=colors[ch])
                    plt.plot(WF[ch].t, wf_est[ch], color=colors[ch])
                    this_title += '%s: K=%3.2g, dt=%3.2g, $\\sigma$=%3.2g, R=%3.2f\n' % (
                        ch, this_key[0], fit_params[ch]['delta_t'],
                        fit_params[ch]['sigma'], fit_params[ch]['R'])
                plt.title(this_title[0:-2])
                print(WF_count)
            if M_list is not None:
                M_list += [Ms]
        except KeyboardInterrupt:
            sys.exit()
        except Exception as e:
            print("Exception thrown for shot %d" % WF[channels[0]].shots)
            print(e)
            pass
        if np.mod(WF_count, 1000) == 0 and WF_count > 0:
            print('    N=%d, N_keys=%d, %d' %
                  (WF_count, len(list(catalogs[channels[0]])),
                   len(list(catalogs[channels[1]]))))

    result = {}
    for ch in channels + ['both']:
        result[ch] = {}
        for field in WFp_empty[ch].keys():
            try:
                result[ch][field] = np.array(
                    [ii[ch][field] for ii in fit_param_list])
            except ValueError:
                print("problem with channel %s, field %s" % (ch, field))
    if return_catalogs:
        return result, catalogs
    else:
        return result
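# --- Hedged sketch (assumption): listDict is used throughout but not
# defined in this file.  Its usage (indexing with list-valued keys such as
# key_top+[sigma]+[delta_t]) suggests a dict that converts list keys to
# hashable tuples:
class listDict_sketch(dict):
    @staticmethod
    def _key(k):
        return tuple(k) if isinstance(k, list) else k
    def __setitem__(self, k, v):
        super().__setitem__(self._key(k), v)
    def __getitem__(self, k):
        return super().__getitem__(self._key(k))
    def __contains__(self, k):
        return super().__contains__(self._key(k))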
Example #15
def TX_corr(TX, T_win, sigma, TX_sigma=None):
    """
    Broaden TX so its total spread matches sigma (in quadrature with its
    native spread TX_sigma) and return its windowed centroid and median
    """
    if TX_sigma is None:
        TX_sigma=TX.robust_spread()
    sigma_extra=np.sqrt(np.max([0, sigma**2-TX_sigma**2]))
    TXb=waveform(TX.t, TX.p.copy()).broaden(sigma_extra)
    return wf_med_bar(TXb, T_win, t_tol=0.01*0.25)
Example #16
def resample_wf(WF, dt):
    """
    Resample WF to the time interval dt using cubic interpolation
    """
    ti=np.arange(WF.t[0], WF.t[-1]+dt, dt)
    sz=[ti.size, 1]
    fi=interp1d(WF.t.ravel(), WF.p.ravel(), kind='cubic', fill_value=0, bounds_error=False)
    return waveform(ti.reshape(sz), fi(ti).reshape(sz))
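# --- Toy check (illustrative values, not from the source): cubic
# resampling error for a smooth waveform is small, and samples outside
# the original time span are filled with zero.
import numpy as np
from scipy.interpolate import interp1d
t = np.arange(0., 10., 0.25)
fi = interp1d(t, np.sin(t), kind='cubic', fill_value=0, bounds_error=False)
ti = np.arange(0., 10., 0.05)
print(np.max(np.abs(fi(ti) - np.sin(ti))))  # small interpolation error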
Example #17
def errors_for_one_scat_file(scat_files, TX_files, channels, out_file=None):

    sigmas = np.arange(0, 5, 0.25)
    # choose a set of delta t values
    delta_ts = np.arange(-1., 1.5, 0.5)

    N_WFs = 256
    TX = {}
    # get the transmit pulse
    for ind, ch in enumerate(channels):
        with h5py.File(TX_files[ind], 'r') as fh:
            TX[ch] = waveform(np.array(fh['/TX/t']), np.array(fh['/TX/p']))
        TX[ch].t -= TX[ch].nSigmaMean()[0]
        TX[ch].tc = 0
        TX[ch].normalize()

    # initialize the library of templates for the transmit waveforms
    TX_library = {}
    for ind, ch in enumerate(channels):
        TX_library[ch] = listDict()
        TX_library[ch].update({0.: TX[ch]})

    # initialize the library of templates for the received waveforms
    WF_library = dict()
    for ind, ch in enumerate(channels):
        WF_library[ch] = dict()
        WF_library[ch].update({0.: TX[ch]})
        WF_library[ch].update(
            make_rx_scat_catalog(TX[ch], h5_file=scat_files[ind]))

    out_fields = [
        'K16', 'K84', 'sigma16', 'sigma84', 'sigma', 'A_scale', 'K0', 'Kmed',
        'Ksigma', 'Ksigma_est', 'N', 'fitting_time'
    ]
    for ch in channels:
        out_fields.append('A_' + ch)

    noise_RMS = {'G': 1, 'IR': 0.25}
    unit_amp_target = {'G': 175, 'IR': 140}
    sigma_vals = [0, 0.5, 1, 2]
    A_scale = [0.5, 1, 1.25]
    #A_vals=[50., 100., 200.]
    K0_vals = np.array(list(WF_library['G'].keys()))[::4]
    N_out = len(sigma_vals) * len(A_scale) * len(K0_vals)
    Dstats = {field: np.zeros(N_out) + np.NaN for field in out_fields}
    # calculate waveforms with no scaling applied
    WF_unscaled = make_sim_WFs(1, WF_library,
                               list(WF_library['G'].keys())[1], 0, {
                                   'G': 0,
                                   'IR': 0
                               }, {
                                   'G': 1,
                                   'IR': 1
                               })
    ii = 0
    for key in K0_vals:
        for sigma in sigma_vals:
            for A in A_scale:
                # calculate scaling to apply to the waveforms to achieve the
                # target peak amplitude (could be done outside the loop)
                amp_scale = {
                    ch: A * unit_amp_target[ch] / WF_unscaled[ch].p.max()
                    for ch in channels
                }

                #Calculate the noise-free waveforms
                WF_expected=make_sim_WFs(1, WF_library, key, sigma,\
                                         {ch:0. for ch in channels}, amp_scale)
                if np.min([
                        WF_expected[ch].p.max() / noise_RMS[ch]
                        for ch in channels
                ]) < 3:
                    continue
                # calculate scaled and broadened waveforms
                WFs = make_sim_WFs(N_WFs, WF_library, key, sigma, noise_RMS,
                                   amp_scale)
                # fit the waveforms
                tic = time.time()
                D_out= fit_catalogs(WFs, WF_library, sigmas, delta_ts, \
                                            t_tol=0.25, sigma_tol=0.25)
                Dstats['fitting_time'][ii] = time.time() - tic
                for ch in channels:
                    Dstats['A_' + ch][ii] = WF_expected[ch].p.max()
                sR = sps.scoreatpercentile(D_out['both']['sigma'], [16, 84])
                Dstats['A_scale'][ii] = A
                Dstats['sigma'][ii] = sigma
                Dstats['K0'][ii] = key
                Dstats['sigma16'][ii] = sR[0]
                Dstats['sigma84'][ii] = sR[1]
                KR = sps.scoreatpercentile(D_out['both']['K0'], [16, 84])
                Dstats['K16'][ii] = KR[0]
                Dstats['K84'][ii] = KR[1]
                Dstats['Kmed'][ii] = np.nanmedian(D_out['both']['K0'])
                Dstats['Ksigma_est'][ii] = np.nanmedian(D_out['both']['Kmax'] -
                                                        D_out['both']['Kmin'])
                Dstats['Ksigma'][ii] = np.nanstd(D_out['both']['K0'])
                Dstats['N'][ii] = np.sum(np.isfinite(D_out['both']['K0']))
                print(
                    'K0=%2.2g, sigma=%2.2f, A=%2.2f, ER=[%2.2g, %2.2g], E=%2.2g'
                    % (key, sigma, A, KR[0] - key, KR[1] - key,
                       Dstats['Ksigma'][ii]))
                ii += 1

    print("yep")
    if out_file is not None:
        if os.path.isfile(out_file):
            os.remove(out_file)
        out_h5 = h5py.File(out_file, 'w')
        for kk in Dstats:
            out_h5.create_dataset(kk, data=Dstats[kk])
        out_h5.close()
    return Dstats