# NOTE: module-level imports used throughout this listing; helper routines
# (waveform, listDict, choose_shots, read_ATM_file, make_rx_scat_catalog,
# golden_section_search, wf_misfit, broadened_misfit, integer_shift,
# make_sim_WFs) are assumed to be provided by the surrounding package.
import os
import sys
from copy import deepcopy
from time import time

import h5py
import numpy as np
import scipy.stats as sps

# debug-plot flags referenced below (assumed module-level defaults)
DOPLOT = False
KEY_SEARCH_PLOT = False

def main(args):
    # main function: open the input files, create the output file, process the waveforms

    input_files={}
    for ii, ch in enumerate(args.ch_names):
        if ch != 'None':
            input_files[ch]=args.input_files[ii]
    channels = list(input_files.keys())

    # get the waveform count from the input files
    shots = choose_shots(input_files, args.reduce_by)
    lastShot = np.minimum(args.startShot + args.nShots, len(shots[channels[0]]))
    nWFs = lastShot - args.startShot + 1

    # make the output file
    if os.path.isfile(args.output_file):
        os.remove(args.output_file)
    # define the output datasets
    outDS={}
    outDS['ch']=['R', 'A', 'B', 'delta_t', 't0','tc', 'noise_RMS','shot','Amax','seconds_of_day','nPeaks']
    outDS['both']=['R', 'K0', 'Kmin', 'Kmax', 'sigma']
    outDS['location']=['latitude', 'longitude', 'elevation']
    out_h5 = h5py.File(args.output_file,'w')
    for grp in ['both','location']:
        out_h5.create_group('/'+grp)
        for DS in outDS[grp]:
            out_h5.create_dataset('/'+grp+'/'+DS, (nWFs,), dtype='f8')
    for ch in channels:
        out_h5.create_group('/'+ch)
        for DS in outDS['ch']:
            out_h5.create_dataset('/'+ch+'/'+DS, (nWFs,), dtype='f8')

    # make groups in the file for transmit data
    for ch in channels:
        for field in ['t0','A','R','shot','sigma']:
            out_h5.create_dataset('/TX/%s/%s' % (ch, field), (nWFs,))

    if args.waveforms:
        for ch in channels:
            out_h5.create_dataset('RX/%s/p' % ch, (192, nWFs))
            out_h5.create_dataset('RX/%s/p_fit' % ch, (192, nWFs))
            out_h5.create_dataset('RX/%s/t_shift' % ch, (nWFs,))
    TX={}
    # get the transmit pulse
    for ind, ch in enumerate(channels):
        with h5py.File(args.TXfiles[ind],'r') as fh:
            TX[ch]=waveform(np.array(fh['/TX/t']), np.array(fh['/TX/p']) )
        TX[ch].t -= TX[ch].nSigmaMean()[0]
        TX[ch].tc = 0
        TX[ch].normalize()
    # write the transmit pulse to the file
    for ch in channels:
        out_h5.create_dataset("/TX/%s/t" % ch, data=TX[ch].t.ravel())
        out_h5.create_dataset("/TX/%s/p" % ch, data=TX[ch].p.ravel())

    # initialize the library of templates for the transmit waveforms
    TX_library={}
    for ind, ch in enumerate(channels):
        TX_library[ch] = listDict()
        TX_library[ch].update({0.:TX[ch]})

    # initialize the library of templates for the received waveforms
    WF_library=dict()
    for ind, ch in enumerate(channels):
        WF_library[ch] = dict()
        WF_library[ch].update({0.:TX[ch]})
        WF_library[ch].update(make_rx_scat_catalog(TX[ch], h5_file=args.scat_files[ind]))


    print("Returns:")
    # loop over start vals (one block at a time...)
    # choose how to divide the output
    blocksize=1000
    start_vals=args.startShot+np.arange(0, nWFs, blocksize, dtype=int)

    catalog_buffers={ch:listDict() for ch in channels}
    TX_catalog_buffers={ch:listDict() for ch in channels}
    time_old=time()

    sigmas=np.arange(0, 5, 0.25)
    # choose a set of delta t values
    delta_ts=np.arange(-1., 1.5, 0.5)

    D={}
    for shot0 in start_vals:
        outShot0=shot0-args.startShot
        these_shots=np.arange(shot0, np.minimum(shot0+blocksize, lastShot), dtype=int)
        if len(these_shots) < 1:
            continue
        #tic=time()
        wf_data={}
        for ch in channels:
            ch_shots=shots[ch][these_shots]
            # make the return waveform structure
            D=read_ATM_file(input_files[ch], shot0=ch_shots[0], nShots=ch_shots[-1]-ch_shots[0]+1)
            
            # fit the transmit data for this channel and these pulses
            D['TX']=D['TX'][np.in1d(D['TX'].shots, ch_shots)]
            # set t0 to the center of the waveform
            t_wf_ctr = np.nanmean(D['TX'].t)
            D['TX'].t0 += t_wf_ctr
            D['TX'].t -= t_wf_ctr
            # subtract the background noise
            D['TX'].subBG(t50_minus=3)
            # calculate tc (centroid time relative to t0)
            D['TX'].tc = D['TX'].threshold_centroid(fraction=0.38)
            #D_out_TX, catalog_buffers= fit_catalogs({ch:D['TX']}, TX_library, sigmas, delta_ts, \
            #                            t_tol=0.25, sigma_tol=0.25,  \
            #                            return_catalogs=True,  catalogs=catalog_buffers)
            D_out_TX = fit_catalogs({ch:D['TX']}, TX_library, sigmas, delta_ts, \
                                        t_tol=0.25, sigma_tol=0.25,  \
                                        return_catalogs=False,  catalogs=TX_catalog_buffers, params=outDS)
            N_out=len(D_out_TX[ch]['A'])
            for field in ['t0','A','R','shot']:
                out_h5['/TX/%s/%s' % (ch, field)][outShot0:outShot0+N_out]=D_out_TX[ch][field].ravel()
            out_h5['/TX/%s/%s' % (ch, 'sigma')][outShot0:outShot0+N_out]=D_out_TX['both']['sigma'].ravel()

            wf_data[ch]=D['RX']
            wf_data[ch]=wf_data[ch][np.in1d(wf_data[ch].shots, ch_shots)]
            # identify the samples that have clipped amplitudes:
            clipped=wf_data[ch].p >= 255
            t_wf_ctr = np.nanmean(wf_data[ch].t)
            wf_data[ch].t -= t_wf_ctr
            wf_data[ch].t0 += t_wf_ctr
            wf_data[ch].subBG(t50_minus=3)

            if 'latitude' in D:
                # only one channel has geolocation information; copy it and use
                # the 'shot' field to match it to the output data
                loc_info={ff:D[ff] for ff in outDS['location']}
                loc_info['channel']=ch
                loc_info['shot']=D['shots']

            wf_data[ch].tc = wf_data[ch].threshold_centroid(fraction=0.38)
            wf_data[ch].p[clipped]=np.NaN
        # now fit the returns with the waveform model
        tic=time()
        D_out, catalog_buffers= fit_catalogs(wf_data, WF_library, sigmas, delta_ts, \
                                            t_tol=0.25, sigma_tol=0.25, return_data_est=args.waveforms, \
                                            return_catalogs=True,  catalogs=catalog_buffers, params=outDS)

        delta_time=time()-tic

        # write out the fit information
        N_out=D_out['both']['R'].size
        for ch in channels:
            for key in outDS['ch']:
                try:
                    out_h5[ch][key][outShot0:outShot0+N_out]=D_out[ch][key].ravel()
                except OSError:
                    print("OSError for channel %s,  key=%s, outshot0=%d, outshotN=%d, nDS=%d"% (ch, key, outShot0, outShot0+N_out, out_h5[key].size))
        for key in outDS['both']:
            try:
                out_h5['both'][key][outShot0:outShot0+N_out]=D_out['both'][key].ravel()
            except OSError:
                print("OSError for both channels, key=%s, outshot0=%d, outshotN=%d, nDS=%d"% (key, outShot0, outShot0+N_out, out_h5[key].size))

        # write out the location info
        loc_ind=np.flatnonzero(np.in1d(loc_info['shot'], D_out[loc_info['channel']]['shot']))
        for field in outDS['location']:
            out_h5['location'][field][outShot0:outShot0+N_out]=loc_info[field][loc_ind]

        # write out the waveforms
        if args.waveforms:
            for ch in channels:
                out_h5['RX/'+ch+'/p_fit'][:, outShot0:outShot0+N_out] = np.squeeze(D_out[ch]['wf_est']).T
                out_h5['RX/'+ch+'/p'][:, outShot0:outShot0+N_out] = wf_data[ch].p
                out_h5['RX/'+ch+'/t_shift'][outShot0:outShot0+N_out] = D_out[ch]['t_shift'].ravel()

        print("  shot=%d out of %d, N_keys=%d, dt=%5.1f" % (shot0+blocksize, start_vals[-1]+blocksize, len(catalog_buffers['G'].keys()), delta_time))
    print("   time to fit RX=%3.2f" % (time()-time_old))

    if args.waveforms:
        for ch in channels:
            out_h5.create_dataset('RX/'+ch+'/t', data=wf_data[ch].t.ravel())

    out_h5.close()
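
# --- Hypothetical command-line wiring for main() ---
# main(args) expects an argparse-style namespace; the attribute names below
# (input_files, ch_names, TXfiles, scat_files, output_file, startShot, nShots,
# reduce_by, waveforms) are exactly those referenced in main(), but the flag
# spellings, defaults, and help strings are illustrative assumptions, not the
# original CLI.
def parse_args(argv=None):
    import argparse
    parser = argparse.ArgumentParser(
        description="fit ATM waveforms with shifted, broadened templates")
    parser.add_argument('--input_files', nargs='+', required=True)
    parser.add_argument('--ch_names', nargs='+', required=True,
                        help="one channel name per input file; 'None' skips a file")
    parser.add_argument('--TXfiles', nargs='+', required=True)
    parser.add_argument('--scat_files', nargs='+', required=True)
    parser.add_argument('--output_file', required=True)
    parser.add_argument('--startShot', type=int, default=0)
    parser.add_argument('--nShots', type=int, default=2**31)
    parser.add_argument('--reduce_by', type=int, default=1)
    parser.add_argument('--waveforms', action='store_true',
                        help='also store measured and fitted waveforms')
    return parser.parse_args(argv)
# usage (at the bottom of the module, after all functions are defined):
#     if __name__ == '__main__':
#         main(parse_args())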
def fit_catalogs(WFs,
                 catalogs_in,
                 sigmas,
                 delta_ts,
                 t_tol=None,
                 sigma_tol=None,
                 return_data_est=False,
                 return_catalogs=False,
                 catalogs=None,
                 params=None,
                 M_list=None):
    """
    Search a library of waveforms for the best match between the broadened, shifted library waveform
    and the target waveforms

    Inputs:
        WFs: a dict of waveform objects, whose entries include:
            't': the waveform's time vector
            'p': the power samples of the waveform
            'tc': a center time relative to which the waveform's time is shifted
        catalog_in: A dictionary containing waveform objects that will be broadened and
                    shifted to match the waveforms in 'WFs'
        sigmas: a list of spread values that will be searched for each template and waveform
                The search over sigmas terminates when a minimum is found
        delta_ts: a list of time-shift values that will be searched for each template and
                waveform.  All of these will be searched, then the results will be refined
                to a tolerance of t_tol
        'params' : a dict giving a list of parameters to return in D_out
        keyword arguments:
            return_data_est:  set to 'true' if the algorithm should return the best-matching
                shifted and broadened template for each input
            t_tol: tolerance for the time search, defaults to WF.t_samp/10
    Outputs:
        WFp: a set of best-fitting waveform parameters that give:
            delta_t: the time-shift required to align the template and measured waveforms
            sigma: the broadening applied to the measured waveform
            k0: the key into the waveform catalog for the best-fitting waveform

    """
    channels = list(WFs.keys())

    # set a sensible tolerance for delta_t if none is specified
    if t_tol is None:
        t_tol = WFs[channels[0]].dt
    if sigma_tol is None:
        sigma_tol = 0.25

    N_shots = WFs[channels[0]].size
    N_samps = WFs[channels[0]].t.size

    # make an empty output_dictionary
    if params is None:
        params = {
            'ch': [
                'R', 'A', 'B', 'noise_RMS', 'delta_t', 'shot', 't0', 'tc',
                't_shift'
            ],
            'both': ['K0', 'R', 'sigma', 'Kmin', 'Kmax']
        }
    WFp_empty = {}
    for ch in channels:
        WFp_empty[ch] = {f: np.NaN for f in params['ch']}
    WFp_empty['both'] = {f: np.NaN for f in params['both']}
    if return_data_est:
        for ch in channels:
            WFp_empty[ch]['wf_est'] = np.zeros_like(WFs[ch].t) + np.NaN
            WFp_empty[ch]['t_shift'] = np.NaN

    if catalogs is None:
        catalogs = {ch: listDict() for ch in channels}

    # make a container for the pulse widths for the catalogs, and copy the input catalogs into the buffer catalogs
    W_catalogs = {}
    for ch in channels:
        k_vals = np.sort(list(catalogs_in[ch]))
        W_catalogs[ch] = np.zeros(k_vals.shape)
        # loop over the library of templates
        for ii, kk in enumerate(k_vals):
            # record the width of the waveform
            W_catalogs[ch][ii] = catalogs_in[ch][kk].fwhm()[0]
            # check if we've searched this template before, otherwise copy it into
            # the library of checked templates
            if [kk] not in catalogs[ch]:
                # make a copy of the current template
                temp = catalogs_in[ch][kk]
                catalogs[ch][[kk]] = waveform(temp.t,
                                              temp.p,
                                              t0=temp.t0,
                                              tc=temp.tc)

    fit_param_list = []
    sigma_last = None
    t_center = {ch: WFs[ch].t.mean() for ch in channels}
    last_keys = {ch: [] for ch in channels}

    # loop over input waveforms
    for WF_count in range(N_shots):
        fit_params = deepcopy(WFp_empty)
        WF = {ch: WFs[ch][WF_count] for ch in channels}

        # skip waveforms for which we detected an error (so far just waveforms that have t0 far from the calrng value)
        if np.any([WF[ch].error_flag for ch in WF.keys()]):
            continue
        if np.any([np.isnan(WF[ch].tc) for ch in WF.keys()]):
            continue
        # skip multi-peak returns:
        #n_peaks=np.array([WF[ch].nPeaks for ch in channels])
        #if np.any(n_peaks > 1):
        #    continue
        # shift the waveforms to put their tcs at the center of the time vector
        # doing this means that we have to store fewer shifted catalogs
        for ch in channels:
            delta_samp = np.round((WF[ch].tc - t_center[ch]) / WF[ch].dt)
            WF[ch].p = integer_shift(WF[ch].p, -delta_samp)
            WF[ch].t0 += delta_samp * WF[ch].dt
            WF[ch].tc -= delta_samp * WF[ch].dt
            WF[ch].t_shift = delta_samp * WF[ch].dt

        # set up a matching dictionary (contains keys of waveforms and their misfits)
        Ms = {ch: listDict() for ch in channels}
        # this is the bulk of the work, and it's where problems happen.  Wrap it in a try:
        # and write out errors to be examined later
        try:
            # note: k_vals was set in the catalog-copying loop above; the
            # per-channel catalogs are assumed to share the same set of keys
            if len(k_vals) > 1:
                # find the best misfit between this template and the waveform
                fB = lambda ind: fit_broadened(delta_ts,
                                               None,
                                               WF,
                                               catalogs,
                                               Ms, [k_vals[ind]],
                                               sigma_tol=sigma_tol,
                                               t_tol=t_tol,
                                               sigma_last=sigma_last,
                                               refine_sigma=True)
                W_broad_ind = 0
                # find the first catalog entry that's broader than the waveform
                # (check both channels, pick the broader one)
                for ch in channels:
                    this_broad_ind = np.flatnonzero(
                        W_catalogs[ch] >= WF[ch].fwhm()[0])
                    if len(this_broad_ind) == 0:
                        this_broad_ind = len(W_catalogs[ch])
                    else:
                        this_broad_ind = this_broad_ind[0]
                    W_broad_ind = np.maximum(W_broad_ind, this_broad_ind)
                # search two steps on either side of the broadness-matched waveform, as well as zero (all broadening due to roughness)
                key_search_ind = np.array(
                    sorted(tuple(set([0, W_broad_ind - 2, W_broad_ind + 2]))))
                key_search_ind = key_search_ind[(key_search_ind >= 0) &
                                                (key_search_ind < len(k_vals))]
                search_hist = {}
                iBest, Rbest = golden_section_search(fB,
                                                     key_search_ind,
                                                     delta_x=2,
                                                     bnds=[0,
                                                           len(k_vals) - 1],
                                                     integer_steps=True,
                                                     tol=1,
                                                     refine_parabolic=False,
                                                     search_hist=search_hist)
                iBest = int(iBest)
            else:
                # only one template: evaluate it directly; fit_broadened
                # returns the summed channel-normalized misfit, matching the
                # value that golden_section_search returns in the branch above
                Rbest = fit_broadened(delta_ts,
                                      None,
                                      WF,
                                      catalogs,
                                      Ms, [k_vals[0]],
                                      sigma_tol=sigma_tol,
                                      t_tol=t_tol,
                                      sigma_last=sigma_last)
                iBest = 0
            this_kval = [k_vals[iBest]]
            fit_params['both']['R'] = Rbest
            fit_params['both']['K0'] = this_kval
            R_dict = {}
            sigma_last = 0
            for ch in channels:
                M = Ms[ch]
                M['best'] = {'key': this_kval, 'R': M[this_kval]['best']['R']}
                for ki in [this_key for this_key in k_vals if [this_key] in M]:
                    if ki in R_dict:
                        R_dict[ki] += M[[ki]]['best']['R'] / WF[ch].noise_RMS
                    else:
                        R_dict[ki] = M[[ki]]['best']['R'] / WF[ch].noise_RMS

                # recursively traverse the M dict for the best match.  The lowest-level match
                # will not have a 'best' entry
                this_key = this_kval
                while 'best' in M[this_key]:
                    this_key = M[this_key]['best']['key']
                # write out the best model information
                fit_params[ch].update(M[this_key])
                fit_params[ch]['noise_RMS'] = WF[ch].noise_RMS[0]
                fit_params[ch]['tc'] = WF[ch].tc[0]
                fit_params[ch]['Amax'] = np.nanmax(WF[ch].p)
                fit_params[ch]['seconds_of_day'] = WF[ch].seconds_of_day
                #fit_params[ch][WF_count]['delta_t'] -= WF[ch].t0
                fit_params['both']['shot'] = WF[ch].shots[0]
                if 'nPeaks' in params['ch']:
                    fit_params[ch]['nPeaks'] = WF[ch].nPeaks
                sigma_last = np.maximum(sigma_last, M[this_key]['sigma'])
            fit_params['both']['sigma'] = M[this_key]['sigma']
            searched_k_vals = np.array(sorted(R_dict.keys()))
            R = np.array([R_dict[ki] for ki in searched_k_vals]).ravel()

            # the misfit threshold: R values below one normalized standard
            # error above the best fit are considered consistent with it
            R_max = fit_params['both']['R'] * (1. + 1. / np.sqrt(N_samps))

            these = np.flatnonzero(searched_k_vals > 0)
            if np.sum(R[these] > 0) > 2:
                if len(these) > 3:
                    # keep the four best-fitting nonzero k values
                    ind_k_vals = np.argsort(R[these])
                    these = these[ind_k_vals[0:4]]
                # fit a quadratic to R as a function of log10(k); its roots
                # bracket the k values whose misfit falls below R_max
                E_roots = np.polynomial.polynomial.Polynomial.fit(
                    np.log10(searched_k_vals[these]), R[these] - R_max,
                    2).roots()
                if (len(E_roots) == 0) or np.any(np.imag(E_roots) != 0):
                    # no real roots: fall back on a linear fit for Kmax, and
                    # take Kmin from the smallest k with misfit below R_max
                    fit_params['both']['Kmax'] = 10**np.minimum(
                        3,
                        np.polynomial.polynomial.Polynomial.fit(
                            np.log10(searched_k_vals[these]),
                            R[these] - R_max, 1).roots()[0])
                    fit_params['both']['Kmin'] = np.min(
                        searched_k_vals[R < R_max])
                else:
                    fit_params['both']['Kmin'] = 10**np.min(E_roots)
                    fit_params['both']['Kmax'] = 10**np.max(E_roots)
            if (0. in searched_k_vals) and R[searched_k_vals == 0] < R_max:
                fit_params['both']['Kmin'] = 0.
            #copy remaining waveform parameters to the output data structure
            for ch in channels:
                fit_params[ch]['shot'] = WF[ch].shots[0]
                fit_params[ch]['t0'] = WF[ch].t0[0]
                fit_params[ch]['t_shift'] = WF[ch].t_shift[0]
                fit_params[ch]['noise_RMS'] = WF[ch].noise_RMS[0]
            #print(this_key+[R[iR][0]])
            if return_data_est or DOPLOT:
                # call WF_misfit for each channel
                wf_est = {}
                for ch, WFi in WF.items():
                    # wf_misfit returns the misfit and the estimated waveform;
                    # keep the estimate per channel so the DOPLOT block below
                    # can index it by channel
                    R0, wf_est[ch] = wf_misfit(fit_params[ch]['delta_t'],
                                               fit_params[ch]['sigma'],
                                               WF[ch],
                                               catalogs[ch],
                                               Ms[ch], [this_key[0]],
                                               return_data_est=True)
                    fit_params[ch]['wf_est'] = wf_est[ch]
                    fit_params[ch]['t_shift'] = WF[ch].t_shift

            if KEY_SEARCH_PLOT:
                import matplotlib.pyplot as plt
                ch_keys = {}
                new_keys = {}
                fig = plt.gcf()
                fig.clf()
                for ind, ch in enumerate(channels):
                    fig.add_subplot(2, 1, ind + 1)
                    ch_keys[ch] = [[key[0:2]] for key in catalogs[ch].keys()
                                   if len(key) > 1]
                    new_keys[ch] = [
                        key for key in ch_keys[ch] if key not in last_keys[ch]
                    ]
                    kxy = np.concatenate(ch_keys[ch], axis=0)
                    plt.plot(np.log10(kxy[:, 0]), kxy[:, 1], 'k.')
                    if len(new_keys[ch]) > 0:
                        kxy_new = np.concatenate(new_keys[ch], axis=0)
                        plt.plot(np.log10(kxy_new[:, 0]), kxy_new[:, 1], 'ro')
                last_keys = {ch: ch_keys[ch].copy() for ch in ch_keys.keys()}
            # report
            fit_param_list += [fit_params]
            if DOPLOT:
                import matplotlib.pyplot as plt
                plt.figure()
                colors = {'IR': 'r', 'G': 'g'}
                this_title = ''
                for ch in channels:
                    #plt.plot(WF[ch].t, integer_shift(WF[ch].p, delta_samp),'.', color=colors[ch])
                    plt.plot(WF[ch].t, WF[ch].p, 'x', color=colors[ch])
                    plt.plot(WF[ch].t, wf_est[ch], color=colors[ch])
                    this_title += '%s: K=%3.2g, dt=%3.2g, $\\sigma$=%3.2g, R=%3.2f\n' % (
                        ch, this_key[0], fit_params[ch]['delta_t'],
                        fit_params[ch]['sigma'], fit_params[ch]['R'])
                plt.title(this_title[:-1])
                print(WF_count)
            if M_list is not None:
                M_list += [Ms]
        except KeyboardInterrupt:
            sys.exit()
        except Exception as e:
            print("Exception thrown for shot %d" % WF[channels[0]].shots[0])
            print(e)
        if np.mod(WF_count, 1000) == 0 and WF_count > 0:
            key_counts = ', '.join(str(len(list(catalogs[ch]))) for ch in channels)
            print('    N=%d, N_keys=%s' % (WF_count, key_counts))

    result = {}
    for ch in channels + ['both']:
        result[ch] = {}
        for field in WFp_empty[ch].keys():
            try:
                result[ch][field] = np.array(
                    [ii[ch][field] for ii in fit_param_list])
            except ValueError:
                print("problem with channel %s, field %s" % (ch, field))
    if return_catalogs:
        return result, catalogs
    else:
        return result
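
# --- Standalone sketch of the Kmin/Kmax bracketing used in fit_catalogs ---
# fit_catalogs fits a quadratic to the misfit R as a function of log10(K) and
# takes the roots of R - R_max (R_max being the misfit one normalized standard
# error above the minimum) as the confidence bounds on K.  The toy misfit
# curve below is made up for illustration; k_bracket() is not part of the
# original code.
def k_bracket(k_vals, R, N_samps):
    """Return (Kmin, Kmax) where a quadratic fit to R(log10 K) crosses R_max."""
    R_max = R.min() * (1. + 1. / np.sqrt(N_samps))
    good = k_vals > 0
    E_roots = np.polynomial.polynomial.Polynomial.fit(
        np.log10(k_vals[good]), R[good] - R_max, 2).roots()
    return 10**np.min(E_roots), 10**np.max(E_roots)
# toy example: misfit quadratic in log10(K), minimized near K=3e-3
#     k_toy = np.array([1e-4, 3e-4, 1e-3, 3e-3, 1e-2])
#     R_toy = 1. + (np.log10(k_toy) + 2.5)**2
#     k_bracket(k_toy, R_toy, N_samps=192)   # ~ (1.7e-3, 5.9e-3)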
def fit_broadened(delta_ts,
                  sigmas,
                  WFs,
                  catalogs,
                  Ms,
                  key_top,
                  sigma_tol=0.125,
                  sigma_max=5.,
                  t_tol=None,
                  refine_sigma=False,
                  sigma_last=None):
    """
    Find the best broadening value that minimizes the misfit between a template and a waveform
    """
    channels = list(WFs.keys())
    for _, M in Ms.items():
        #print(key_top)
        if key_top not in M:
            M[key_top] = listDict()
    fSigma = lambda sigma: np.sum([
        broadened_misfit(delta_ts,
                         sigma,
                         WFs[ch],
                         catalogs[ch],
                         Ms[ch],
                         key_top,
                         t_tol=t_tol,
                         refine_parabolic=True) / WFs[ch].noise_RMS
        for ch in channels
    ])
    sigma_step = 2 * sigma_tol
    FWHM2sigma = 2.355  # the FWHM of a Gaussian is 2.355 times its sigma

    if sigmas is None:
        # Choose a sigma range to search.  Both channels would have been broadened from
        # the template by some unknown roughness.  Solving for that roughness for each
        # channel gives the maximum that might be needed for either (most conservative range)
        sigma0 = 0
        for ch in channels:
            sigma_template = catalogs[ch][key_top].fwhm()[0] / FWHM2sigma
            sigma_WF = WFs[ch].fwhm()[0] / FWHM2sigma
            # if fwhm() doesn't work, just use an arbitrary value for sigma_WF
            if ~np.isfinite(sigma_WF):
                sigma_WF = sigma_max / 2
            #estimate broadening from template WF to measured WF
            sigma_extra = np.sqrt(
                np.maximum(0, sigma_WF**2 - sigma_template**2))
            sigma0 = np.maximum(sigma0,
                                sigma_step * np.ceil(sigma_extra / sigma_step))
        dSigma = np.maximum(sigma_step, np.ceil(sigma0 / 4.))
        if sigma0 + dSigma > sigma_max or ~np.isfinite(sigma0):
            dSigma = 0.25 * sigma_max
            sigma0 = 0.75 * sigma_max
        sigmas = np.unique([
            0.,
            np.maximum(sigma_step, sigma0 - dSigma), sigma0,
            np.maximum(sigma_step, sigma0 + dSigma)
        ])
    else:
        dSigma = np.max(sigmas) / 4.
    if sigma_last is not None:
        i1 = np.maximum(1, np.argmin(np.abs(sigmas - sigma_last)))
    else:
        i1 = len(sigmas) - 1
    sigma_list = [sigmas[0], sigmas[i1]]
    if np.any(~np.isfinite(sigmas)):
        print("NaN in sigma for shot %d " % WFs[channels[0]].shots)
        if np.nanmax(sigmas) > np.nanmin(sigmas):
            sigma_list = [np.nanmax([0, np.nanmin(sigmas)]),
                          np.nanmin([sigma_max, np.nanmax(sigmas)])]
        else:
            sigma_list = [0, sigma_max]
        dSigma = (np.max(sigma_list) - np.min(sigma_list)) / 4.
    search_hist = {}
    sigma_best, R_best = golden_section_search(fSigma,
                                               sigma_list,
                                               dSigma,
                                               bnds=[0, sigma_max],
                                               tol=sigma_tol,
                                               max_count=20,
                                               refine_parabolic=refine_sigma,
                                               search_hist=search_hist)

    if refine_sigma:
        # calculate the misfit at this exact sigma value
        R_best = 0
        for ch in channels:
            # pass in a temporary catalog so that the top-level catalogs don't
            # get populated with entries that won't get reused
            this_catalog = listDict()
            this_catalog[key_top] = catalogs[ch][key_top]
            R_best += broadened_misfit(
                delta_ts,
                sigma_best,
                WFs[ch],
                this_catalog,
                Ms[ch],
                key_top,
                t_tol=t_tol,
                refine_parabolic=True) / WFs[ch].noise_RMS
    return R_best
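
# --- Numeric check of the broadening model assumed in fit_broadened ---
# Broadening a Gaussian template of spread sigma_template by a Gaussian of
# spread sigma_extra adds the variances (sigma_WF**2 = sigma_template**2 +
# sigma_extra**2), and a Gaussian's FWHM is 2.355 times its sigma; that is why
# the sigma search above is seeded with sqrt(max(0, sigma_WF**2 -
# sigma_template**2)).  This helper is illustrative and uses synthetic pulses;
# it is not part of the original code.
def check_broadening_model(sigma_template=1.0, sigma_extra=1.5, dt=0.02):
    t = np.arange(-30., 30., dt)
    template = np.exp(-t**2 / (2 * sigma_template**2))
    kernel = np.exp(-t**2 / (2 * sigma_extra**2))
    broadened = np.convolve(template, kernel, mode='same')
    # measure the full width at half maximum of the broadened pulse
    above = t[broadened >= 0.5 * broadened.max()]
    measured_fwhm = above[-1] - above[0]
    predicted_fwhm = 2.355 * np.hypot(sigma_template, sigma_extra)
    return measured_fwhm, predicted_fwhm  # both ~4.25 for the defaults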
def errors_for_one_scat_file(scat_files, TX_files, channels, out_file=None):

    sigmas = np.arange(0, 5, 0.25)
    # choose a set of delta t values
    delta_ts = np.arange(-1., 1.5, 0.5)

    N_WFs = 256
    TX = {}
    # get the transmit pulse
    for ind, ch in enumerate(channels):
        with h5py.File(TX_files[ind], 'r') as fh:
            TX[ch] = waveform(np.array(fh['/TX/t']), np.array(fh['/TX/p']))
        TX[ch].t -= TX[ch].nSigmaMean()[0]
        TX[ch].tc = 0
        TX[ch].normalize()

    # initialize the library of templates for the transmit waveforms
    TX_library = {}
    for ind, ch in enumerate(channels):
        TX_library[ch] = listDict()
        TX_library[ch].update({0.: TX[ch]})

    # initialize the library of templates for the received waveforms
    WF_library = dict()
    for ind, ch in enumerate(channels):
        WF_library[ch] = dict()
        WF_library[ch].update({0.: TX[ch]})
        WF_library[ch].update(
            make_rx_scat_catalog(TX[ch], h5_file=scat_files[ind]))

    out_fields = [
        'K16', 'K84', 'sigma16', 'sigma84', 'sigma', 'A_scale', 'K0', 'Kmed',
        'Ksigma', 'Ksigma_est', 'N', 'fitting_time'
    ]
    for ch in channels:
        out_fields.append('A_' + ch)

    noise_RMS = {'G': 1, 'IR': 0.25}
    unit_amp_target = {'G': 175, 'IR': 140}
    sigma_vals = [0, 0.5, 1, 2]
    A_scale = [0.5, 1, 1.25]
    #A_vals=[50., 100., 200.]
    K0_vals = np.array(list(WF_library['G'].keys()))[::4]
    N_out = len(sigma_vals) * len(A_scale) * len(K0_vals)
    Dstats = {field: np.zeros(N_out) + np.NaN for field in out_fields}
    # calculate waveforms with no scaling applied
    WF_unscaled = make_sim_WFs(1, WF_library, list(WF_library['G'].keys())[1],
                               0, {'G': 0, 'IR': 0}, {'G': 1, 'IR': 1})
    ii = 0
    for key in K0_vals:
        for sigma in sigma_vals:
            for A in A_scale:
                # calculate scaling to apply to the waveforms to achieve the
                # target peak amplitude (could be done outside the loop)
                amp_scale = {
                    ch: A * unit_amp_target[ch] / WF_unscaled[ch].p.max()
                    for ch in channels
                }

                #Calculate the noise-free waveforms
                WF_expected=make_sim_WFs(1, WF_library, key, sigma,\
                                         {ch:0. for ch in channels}, amp_scale)
                if np.min([
                        WF_expected[ch].p.max() / noise_RMS[ch]
                        for ch in channels
                ]) < 3:
                    continue
                # calculate scaled and broadened waveforms
                WFs = make_sim_WFs(N_WFs, WF_library, key, sigma, noise_RMS,
                                   amp_scale)
                # fit the waveforms
                tic = time()
                D_out= fit_catalogs(WFs, WF_library, sigmas, delta_ts, \
                                            t_tol=0.25, sigma_tol=0.25)
                Dstats['fitting_time'][ii] = time() - tic
                for ch in channels:
                    Dstats['A_' + ch][ii] = WF_expected[ch].p.max()
                sR = sps.scoreatpercentile(D_out['both']['sigma'], [16, 84])
                Dstats['A_scale'][ii] = A
                Dstats['sigma'][ii] = sigma
                Dstats['K0'][ii] = key
                Dstats['sigma16'][ii] = sR[0]
                Dstats['sigma84'][ii] = sR[1]
                KR = sps.scoreatpercentile(D_out['both']['K0'], [16, 84])
                Dstats['K16'][ii] = KR[0]
                Dstats['K84'][ii] = KR[1]
                Dstats['Kmed'][ii] = np.nanmedian(D_out['both']['K0'])
                Dstats['Ksigma_est'][ii] = np.nanmedian(D_out['both']['Kmax'] -
                                                        D_out['both']['Kmin'])
                Dstats['Ksigma'][ii] = np.nanstd(D_out['both']['K0'])
                Dstats['N'][ii] = np.sum(np.isfinite(D_out['both']['K0']))
                print(
                    'K0=%2.2g, sigma=%2.2f, A=%2.2f, ER=[%2.2g, %2.2g], E=%2.2g'
                    % (key, sigma, A, KR[0] - key, KR[1] - key,
                       Dstats['Ksigma'][ii]))
                ii += 1

    print("yep")
    if out_file is not None:
        if os.path.isfile(out_file):
            os.remove(out_file)
        out_h5 = h5py.File(out_file, 'w')
        for kk in Dstats:
            out_h5.create_dataset(kk, data=Dstats[kk])
        out_h5.close()
    return Dstats
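
# --- Reading back the error-summary file written above (usage sketch) ---
# 'scat_errors.h5' is a placeholder name; pass the same path given as
# out_file to errors_for_one_scat_file.  This helper is not part of the
# original code.
def read_error_stats(stats_file='scat_errors.h5'):
    with h5py.File(stats_file, 'r') as fh:
        Dstats = {key: np.array(fh[key]) for key in fh}
    # e.g., compare the estimated K spread with the observed scatter:
    #     Dstats['Ksigma_est'] / Dstats['Ksigma']
    return Dstats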