Ejemplo n.º 1
0
def aliasing_filter(det):
    """Show how much an extra one-pole low-pass stage changes a simulated waveform.

    The question being explored: there is a 1-pole antialiasing filter at
    roughly ``1 / (2 * 49.9 * 33E-12)`` Hz -- how big a difference does it make?

    A waveform is simulated with a two-pole low-pass and an RC high-pass
    attached to ``det``, then again after a further one-pole low-pass has been
    added.  Both waveforms and their difference are plotted, after which the
    interpreter exits.

    Args:
        det: detector object providing ``AddDigitalFilter`` and
            ``MakeSimWaveform``.  The filters added here stay attached to it.
    """

    def simulate():
        # Copy the result so a later simulation cannot overwrite this one.
        return np.copy(
            det.MakeSimWaveform(25, 0, 25, 1, 125, 0.95, wf_length,
                                smoothing=20))

    shaping = DigitalFilter(2)
    shaping.num = [1, 2, 1]
    shaping.set_poles(0.975, 0.007)

    decay = DigitalFilter(1)
    decay.num, decay.den = rc_decay(72, 1E9)

    for stage in (shaping, decay):
        det.AddDigitalFilter(stage)

    reference = simulate()

    _, (top, bottom) = plt.subplots(2, 1, figsize=(15, 8))
    top.plot(reference, color="r")

    # Extra one-pole low-pass.  The 1E9 divisor is presumably the sampling
    # frequency in Hz -- TODO confirm.
    time_constant = 2 * 49.9 * 33E-12
    antialias = DigitalFilter(1)
    antialias.num = [1, 1]
    antialias.den = [1, -np.exp(-1. / time_constant / 1E9)]
    det.AddDigitalFilter(antialias)

    filtered = simulate()

    top.plot(filtered, color="g")
    bottom.plot(reference - filtered, color="g")

    plt.show()
    exit()
Ejemplo n.º 2
0
def skew():
    """Plot the effect of several overshoot fractions on a simulated waveform.

    Uses the module-level ``det``.  A reference waveform is simulated with a
    two-pole low-pass and a two-pole high-pass attached; a
    ``GretinaOvershootFilter`` is then added and the simulation is repeated
    for five overshoot fractions between 0.01 and 0.05.  The top panel shows
    the waveforms, the bottom panel each waveform minus the reference.
    """
    sim_args = (25, 0, 25, 1, 125, 0.95, 1000)

    def simulate():
        return np.copy(det.MakeSimWaveform(*sim_args, smoothing=25))

    lo = DigitalFilter(2)
    lo.num = [1, 2, 1]
    lo.set_poles(0.975, 0.007)

    hi = DigitalFilter(2)
    hi.num = [1, -2, 1]
    hi.set_poles(1. - 10.**-7, np.pi**-13.3)

    for stage in (lo, hi):
        det.AddDigitalFilter(stage)

    reference = simulate()

    _, (top, bottom) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
    top.plot(reference, color="r")

    overshoot = GretinaOvershootFilter(1)
    overshoot.num, overshoot.den = rc_decay(1.7, 1E9)
    det.AddDigitalFilter(overshoot)

    for fraction in np.linspace(0.01, 0.05, 5):
        # The filter is already attached to det, so changing the attribute
        # is what alters the next simulation.
        overshoot.overshoot_frac = fraction
        trial = simulate()

        top.plot(trial)
        bottom.plot(trial - reference)

    plt.show()
Ejemplo n.º 3
0
    def min_func(x):
        """Return the squared deviation of the corrected tail from a flat line.

        ``x`` holds four values: the RC decay parameter (multiplied by 10
        before use), the two overshoot parameters, and the flat-top level
        (multiplied by 1000 before use).  ``wf_data`` and ``max_idx`` come
        from the enclosing scope.
        """
        rc_tenths, os_decay, os_pole_rel, level = x

        # The coefficient pairs are handed to lfilter in swapped order
        # (second returned value first), as in the rest of this file.
        rc_first, rc_second = filt.rc_decay(rc_tenths * 10)
        after_rc = signal.lfilter(rc_second, rc_first, wf_data)

        os_first, os_second = filt.gretina_overshoot(os_decay, os_pole_rel)
        corrected = signal.lfilter(os_second, os_first, after_rc)

        # Only the samples from max_idx onwards enter the cost.
        residual = corrected[max_idx:] - level * 1000
        return np.sum(residual**2)
Ejemplo n.º 4
0
    def min_func(x):
        """Cost function: sum of squared residuals of the tail about a constant.

        The four entries of ``x`` are, in order, the RC decay parameter
        (scaled by 10 here), the overshoot decay, the overshoot pole
        parameter and the target level (scaled by 1000 here).  Uses
        ``wf_data`` and ``max_idx`` from the enclosing scope.
        """
        rc_param, decay_param, pole_param, target = x

        def apply(coeffs, samples):
            # Coefficients are passed to lfilter in reversed order relative
            # to how the filt helpers return them.
            first, second = coeffs
            return signal.lfilter(second, first, samples)

        stage_one = apply(filt.rc_decay(rc_param * 10), wf_data)
        stage_two = apply(filt.gretina_overshoot(decay_param, pole_param),
                          stage_one)

        tail = stage_two[max_idx:]
        return np.sum((tail - target * 1000)**2)
Ejemplo n.º 5
0
def fit_tail(wf_data):
    """Fit the tail-correction parameters that best flatten the waveform top.

    A Powell minimisation searches over four values (RC decay / 10, two
    overshoot parameters, flat-top level / 1000) so that the filtered
    waveform, from the index of the input maximum onwards, is as close as
    possible to a constant.

    Args:
        wf_data: waveform samples (array with a ``max`` method).

    Returns:
        Tuple of (corrected waveform, fitted flat-top level, final cost).
    """
    peak_idx = np.argmax(wf_data)

    def corrected(rc_const, os_decay, os_pole_rel):
        # Both stages pass the coefficients to lfilter in swapped order
        # relative to how the filt helpers return them.
        rc_first, rc_second = filt.rc_decay(rc_const)
        after_rc = signal.lfilter(rc_second, rc_first, wf_data)

        os_first, os_second = filt.gretina_overshoot(os_decay, os_pole_rel)
        return signal.lfilter(os_second, os_first, after_rc)

    def cost(params):
        rc_tenths, os_decay, os_pole_rel, level = params
        tail = corrected(rc_tenths * 10, os_decay, os_pole_rel)[peak_idx:]
        return np.sum((tail - level * 1000)**2)

    # The level is fitted in units of 1000 so the starting values are of
    # comparable size.
    start = [7.2, 2, -4, wf_data.max() / 1000]
    fit = optimize.minimize(cost, start, method="Powell")

    rc_tenths, os_decay, os_pole_rel, level = fit["x"]
    best = corrected(rc_tenths * 10, os_decay, os_pole_rel)

    return best, (level * 1000), fit["fun"]
Ejemplo n.º 6
0
def two_rc(det):
    """Compare a one-pole RC high-pass against a two-pole high-pass.

    Original question: what does the waveform look like with a 72 us and a
    2 ms decay?  (The code below uses ``rc_decay(82, 1E9)`` for the
    single-pole stage.)

    A reference waveform is simulated with a two-pole low-pass plus the
    single-pole high-pass attached to ``det``.  The high-pass is then removed
    and replaced by a two-pole one, the simulation is repeated, and both
    waveforms and their difference are plotted before the interpreter exits.

    Args:
        det: detector object providing ``AddDigitalFilter``,
            ``RemoveDigitalFilter`` and ``MakeSimWaveform``.
    """

    def simulate():
        return np.copy(
            det.MakeSimWaveform(25, 0, 25, 1, 125, 0.95, wf_length,
                                smoothing=20))

    lo = DigitalFilter(2)
    lo.num = [1, 2, 1]
    lo.set_poles(0.975, 0.007)

    single_pole = DigitalFilter(1)
    single_pole.num, single_pole.den = rc_decay(82, 1E9)

    for stage in (lo, single_pole):
        det.AddDigitalFilter(stage)

    reference = simulate()

    _, (top, bottom) = plt.subplots(2, 1, figsize=(15, 8))
    top.plot(reference, color="r")

    # Swap the single-pole high-pass for a two-pole one.
    det.RemoveDigitalFilter(single_pole)
    double_pole = DigitalFilter(2)
    double_pole.num = [1, -2, 1]
    double_pole.set_poles(1. - 10.**-5.22, np.pi**-13.3)
    det.AddDigitalFilter(double_pole)

    replaced = simulate()

    top.plot(replaced, color="g")
    bottom.plot(reference - replaced, color="g")

    plt.show()
    exit()
Ejemplo n.º 7
0
def fit_tail(wf_data):
    """Find the tail parameters that best flatten the top of ``wf_data``.

    The waveform is passed through two filter stages built from
    ``filt.rc_decay`` and ``filt.gretina_overshoot``; a Powell search adjusts
    their parameters and a constant level so that the samples from the input
    maximum onwards match that constant in the least-squares sense.

    Args:
        wf_data: waveform samples (array with a ``max`` method).

    Returns:
        ``(wf_proc, level, cost)``: the waveform filtered with the best-fit
        parameters, the fitted constant level, and the minimised cost.
    """
    start_of_tail = np.argmax(wf_data)

    def run_chain(rc_value, decay_value, pole_value):
        # lfilter receives each coefficient pair in swapped order relative
        # to how the filt helpers return it.
        num_a, den_a = filt.rc_decay(rc_value)
        intermediate = signal.lfilter(den_a, num_a, wf_data)
        num_b, den_b = filt.gretina_overshoot(decay_value, pole_value)
        return signal.lfilter(den_b, num_b, intermediate)

    def objective(x):
        rc_scaled, decay_value, pole_value, level_scaled = x
        filtered = run_chain(rc_scaled * 10, decay_value, pole_value)
        deviation = filtered[start_of_tail:] - level_scaled * 1000
        return np.sum(deviation**2)

    initial_level = wf_data.max() / 1000
    result = optimize.minimize(objective, [7.2, 2, -4, initial_level],
                               method="Powell")

    rc_scaled, decay_value, pole_value, level_scaled = result["x"]

    # Undo the factor-of-10 scaling used inside the optimiser.
    rc_scaled *= 10
    wf_proc = run_chain(rc_scaled, decay_value, pole_value)

    return wf_proc, (level_scaled * 1000), result["fun"]
Ejemplo n.º 8
0
def main():
    """Interactively overlay raw and tail-corrected waveforms.

    For run 848, channel 49, every waveform after the first ten that passes
    the amplitude cuts is baseline-subtracted, handed to ``fit_tail`` and
    plotted together with the corrected waveform and the fitted model.  After
    each waveform the user is prompted; typing ``q`` exits, anything else
    moves on to the next waveform.
    """
    runList = [848]
    plt.ion()
    plt.figure(figsize=(12, 9))

    for runNumber in runList:
        t1_file = "t1_run{}.h5".format(runNumber)
        df = pd.read_hdf(t1_file, key="ORGretina4MWaveformDecoder")
        g4 = dl.Gretina4MDecoder(t1_file)

        chanList = [49]

        for chan in chanList:
            df_chan = df[df.channel == chan]

            for i, (_, row) in enumerate(df_chan.iterrows()):
                # Skip the first ten events of the channel.
                if i < 10:
                    continue

                wf = g4.parse_event_data(row)
                # Baseline is the mean of the first 200 samples.
                wf_dat = wf.data - np.mean(wf.data[:200])

                # Reject small pulses and pulses that spend fewer than ten
                # samples above half of their maximum.
                if np.amax(wf_dat) < 200:
                    continue
                if np.count_nonzero(wf_dat > 0.5 * wf_dat.max()) < 10:
                    continue

                # First sample above the fixed threshold of 20.
                max_idx = np.argmax(wf_dat > 20)

                # NOTE(review): this expects a fit_tail(wf, idx) variant that
                # returns (corrected waveform, model, t0) -- confirm which
                # fit_tail is in scope.
                wf_corr, model, t0 = fit_tail(wf_dat, max_idx)
                plt.plot(wf_dat)
                plt.plot(wf_corr)
                plt.plot(np.arange(max_idx + t0, max_idx + t0 + len(model)),
                         model,
                         c="r")
                plt.axvline(max_idx, c="r", ls=":")

                # The prompt used to say the opposite of what the code does.
                inp = input("q to quit, else to continue")
                if inp == "q":
                    # Equivalent to exit(), without relying on the site builtin.
                    raise SystemExit
Ejemplo n.º 9
0
def main():
    """Average tail-corrected flat tops and baselines per channel and plot them.

    For each run in the hard-coded run list the tier-1 HDF5 file is loaded.
    For each selected channel, events above 0.2E9 in energy and in the top
    10% of the remaining energy range are tail-corrected with ``fit_tail``.
    The mean baseline and the mean flat-top residual are accumulated, a
    four-panel figure (aligned waveforms, corrected waveforms, periodograms,
    flat-top residuals) is written to ``fft_plots/``, and finally the summed
    squared flat-top residual is scattered against detector resolution.
    """
    runList = np.arange(11515, 11516)
    proc = DataProcessor()

    plt.ion()
    plt.figure(figsize=(12, 9))

    ds_inf = pd.read_csv("ds1_run_info.csv")

    # savefig does not create missing directories.
    os.makedirs("fft_plots", exist_ok=True)

    bl_idx = 800  # number of leading samples treated as baseline
    n_flat = 800  # number of flat-top samples kept after alignment

    for runNumber in runList:
        t1_file = os.path.join(proc.t1_data_dir,
                               "t1_run{}.h5".format(runNumber))
        df = pd.read_hdf(t1_file, key="ORGretina4MWaveformDecoder")
        g4 = dl.Gretina4MDecoder(t1_file)

        # Only this channel is processed.
        chanList = [580]

        pz_fun = []
        masses = []

        for chan in chanList:
            print("Channel {}".format(chan))

            ds_inf_det = ds_inf[(ds_inf.LG == chan) | (ds_inf.HG == chan)]
            if ds_inf_det.empty:
                # Previously det_mass/det_res were left undefined (or stale
                # from the previous channel) here; use NaN placeholders so
                # the plots are still produced.
                print("...couldn't find channel info")
                det_mass = det_res = float("nan")
            else:
                det_info = ds_inf_det.iloc[0]
                det_mass = det_info.Mass
                det_res = det_info.Resolution

            if chan % 2 == 1:
                continue
            if not is_mj(chan):
                continue

            plt.clf()
            ax1 = plt.subplot(2, 2, 2)
            ax2 = plt.subplot(2, 2, 4)
            ax3 = plt.subplot(2, 2, 3)
            ax4 = plt.subplot(2, 2, 1)

            df_chan = df[(df.channel == chan) & (df.energy > 0.2E9)]

            # Keep only the top 10% of the energy range.
            e_min = df_chan.energy.min()
            e_max = df_chan.energy.max()
            e_cut = 0.9 * (e_max - e_min) + e_min
            df_cut = df_chan[df_chan.energy > e_cut]

            baseline = np.zeros(bl_idx)
            flat_top = np.zeros(n_flat)

            num_wfs = 0
            for _, row in df_cut.iterrows():
                wf = g4.parse_event_data(row)
                wf_dat = wf.data - np.mean(wf.data[:bl_idx])

                try:
                    wf_corr, energy, _ = fit_tail(wf_dat)

                    # First sample within 0.1% of the fitted level, plus a
                    # 20-sample margin.
                    align_idx = np.argmax(wf_corr / energy > 0.999) + 20

                    flat_top_wf = wf_corr[align_idx:align_idx + n_flat] - energy
                    baseline_wf = wf_dat[:bl_idx]

                    # Validate both pieces before touching either
                    # accumulator, so a rejected waveform cannot leak into
                    # the baseline average.
                    if len(flat_top_wf) != n_flat or len(baseline_wf) != bl_idx:
                        print("...waveform too short to align, skipping")
                        continue
                except ValueError as e:
                    print(e)
                    continue

                baseline += baseline_wf
                flat_top += flat_top_wf
                num_wfs += 1

                ax1.plot(wf_corr[:800], c="b", alpha=0.1)
                ax2.plot(flat_top_wf, c="b", alpha=0.1)
                # Clamp the start so an early alignment point does not wrap
                # around to the end of the array.
                ax4.plot(wf_corr[max(align_idx - 400, 0):align_idx + 805] / energy,
                         c="b",
                         alpha=0.1)

            if num_wfs == 0:
                # Nothing to average; avoid dividing by zero.
                print("...no usable waveforms, skipping channel")
                continue

            flat_top /= num_wfs
            baseline /= num_wfs

            pz_fun.append(np.sum(flat_top**2))
            # NOTE(review): the list is called "masses" but stores the
            # resolution -- confirm which quantity the scatter should use.
            masses.append(det_res)

            ax1.set_title("Baseline")
            ax1.plot(baseline, c="r")

            ax2.set_title("Decay (PZ corrected)")
            ax2.plot(flat_top, c="r")

            plt.title("Channel {} (mass {}, res {})".format(
                chan, det_mass, det_res))

            xf, power = signal.periodogram(baseline,
                                           fs=1E8,
                                           detrend="linear",
                                           scaling="spectrum")
            # Plot only frequencies above 0.2E7.
            x_idx = np.argmax(xf > 0.2E7)
            ax3.semilogx(xf[x_idx:], power[x_idx:], label="baseline")

            xf, power = signal.periodogram(flat_top,
                                           fs=1E8,
                                           detrend="constant",
                                           scaling="spectrum")
            x_idx = np.argmax(xf > 0.2E7)
            ax3.semilogx(xf[x_idx:], power[x_idx:], label="flat top")

            ax3.legend()

            plt.savefig("fft_plots/channel{}_fft.png".format(chan))

        plt.figure()
        plt.scatter(masses, pz_fun)
        # The prompt used to say the opposite of what the code does.
        inp = input("q to quit, else to continue")
        if inp == "q":
            # Equivalent to exit(), without relying on the site builtin.
            raise SystemExit
Ejemplo n.º 10
0
def overshoot(det):
    """Scan pole positions of a one-zero/one-pole filter for a decaying overshoot.

    Original question: how do I get a decaying overshoot?

    A reference waveform is simulated with only a two-pole low-pass attached
    to ``det``.  A first-order filter is then added whose real zero sits at
    ``1 - 5E-4`` and whose real pole is placed just below the zero, at 100
    log-spaced offsets between 1E-7 and 1E-5.  Waveforms that overshoot (value
    at sample 155 of at least 1.00001) without exceeding 1.1 are plotted
    together with their difference from the reference, the filter's frequency
    response, and the pole/zero positions.

    Args:
        det: detector object providing ``AddDigitalFilter`` and
            ``MakeSimWaveform``.  The filters added here stay attached to it.
    """
    lowpass = DigitalFilter(2)
    lowpass.num = [1, 2, 1]
    lowpass.set_poles(0.975, 0.007)
    det.AddDigitalFilter(lowpass)

    # NOTE(review): the reference uses wf_length samples while the scan below
    # uses 1000; the subtraction only works if the two agree -- confirm.
    wf_compare = np.copy(
        det.MakeSimWaveform(25, 0, 25, 1, 125, 0.95, wf_length, smoothing=20))

    _, ax = plt.subplots(2, 2, figsize=(15, 8))
    ax[0, 0].plot(wf_compare, color="r")

    new_filt = DigitalFilter(1)
    det.AddDigitalFilter(new_filt)

    for zero_offset in [5E-4]:
        # Computed once per zero.  This used to be recomputed in place inside
        # the inner loop, which flipped the zero between 1 - offset and
        # offset on alternate iterations.
        zero_mag = 1 - zero_offset

        for pole_offset in np.logspace(-7, -5, 100, base=10):
            pole_mag = zero_mag - pole_offset

            # An offset below floating-point resolution would put the pole
            # exactly on the zero.
            if zero_mag == pole_mag:
                continue

            new_filt.set_zeros(zero_mag, 0)
            new_filt.set_poles(pole_mag, 0)

            wf_proc = np.copy(
                det.MakeSimWaveform(25, 0, 25, 1, 125, 0.95, 1000,
                                    smoothing=20))

            try:
                if wf_proc[155] < 1.00001:
                    continue
                if np.amax(wf_proc) > 1.1:
                    continue
                print(zero_mag, pole_mag)
                line = ax[0, 0].plot(wf_proc)
                # Reuse the automatically chosen line colour on every panel.
                color = line[0].get_color()
                ax[1, 0].plot(wf_proc - wf_compare, color=color)
            except (TypeError, IndexError):
                # Presumably the simulation did not return a usable array.
                continue

            w, h2 = get_freq_resp(new_filt,
                                  w=np.logspace(-15, 0, 500, base=np.pi))
            ax[0, 1].loglog(w, h2, color=color)

            ax[1, 1].scatter(zero_mag, 0, color=color)
            ax[1, 1].scatter(pole_mag, 0, color=color, marker="x")

    # Unit circle for the pole/zero panel, drawn once instead of per iteration.
    an = np.linspace(0, np.pi, 200)
    ax[1, 1].plot(np.cos(an), np.sin(an), c="k")
    ax[1, 1].plot(np.cos(an), -np.sin(an), c="k")
    ax[1, 1].axis("equal")

    plt.show()
Ejemplo n.º 11
0
def oscillation(det):
    """Plot the response of a two-pole filter for three zero configurations.

    Original question: how do I get a decaying oscillation?

    A reference waveform is simulated with only a two-pole low-pass attached
    to ``det``.  A second-order filter with a fixed complex pole
    (magnitude 0.995, phase ``0.5 * pi**-3``) is then added and simulated
    three times: with its initial numerator, with zeros at (0.99, 0.01), and
    with zeros at (0.995, 0.001).  Each pass plots the waveform, its
    difference from the reference, the frequency response and the pole
    position.

    Args:
        det: detector object providing ``AddDigitalFilter`` and
            ``MakeSimWaveform``.  The filters added here stay attached to it.
    """
    shaping = DigitalFilter(2)
    shaping.num = [1, 2, 1]
    shaping.set_poles(0.975, 0.007)

    # Built but never attached to det.
    hipass = DigitalFilter(1)
    hipass.num, hipass.den = rc_decay(82, 1E9)

    det.AddDigitalFilter(shaping)

    reference = np.copy(
        det.MakeSimWaveform(25, 0, 25, 1, 125, 0.95, wf_length, smoothing=20))

    _, panels = plt.subplots(2, 2, figsize=(15, 8))
    wf_ax, resp_ax = panels[0, 0], panels[0, 1]
    diff_ax, pz_ax = panels[1, 0], panels[1, 1]

    wf_ax.plot(reference, color="r")
    cmap = cm.get_cmap('viridis')

    resonator = DigitalFilter(2)
    det.AddDigitalFilter(resonator)
    resonator.num = [1, 2, 1]

    p_mags = 1 - np.logspace(-3, -2, 20, base=10)

    pole_phi = 0.5 * np.pi**-3
    pole_mag = 0.995

    # (line colour, zeros to set before simulating); None keeps the numerator.
    variants = (
        ("k", None),
        ("b", (0.99, 0.01)),
        ("g", (0.995, 0.001)),
    )

    for line_color, zeros in variants:
        # The colour-map lookup is always overridden by the fixed colour
        # above; the call is kept so the sequence of calls is unchanged.
        get_color(cmap, pole_mag, p_mags)

        if zeros is not None:
            resonator.set_zeros(*zeros)
        resonator.set_poles(pole_mag, pole_phi)

        trial = np.copy(
            det.MakeSimWaveform(25, 0, 25, 1, 125, 0.95, 1000, smoothing=20))

        try:
            wf_ax.plot(trial, color=line_color)
            diff_ax.plot(trial - reference, color=line_color)
        except (TypeError, IndexError):
            continue

        freqs, response = get_freq_resp(
            resonator, w=np.logspace(-15, 0, 500, base=np.pi))
        resp_ax.loglog(freqs, response, color=line_color)
        resp_ax.axvline(pole_phi / (np.pi / nyq_freq))

        pz_ax.scatter(pole_mag * np.cos(pole_phi),
                      pole_mag * np.sin(pole_phi),
                      color=line_color,
                      marker="x")

        # Upper and lower halves of the unit circle.
        angles = np.linspace(0, np.pi, 200)
        pz_ax.plot(np.cos(angles), np.sin(angles), c="k")
        pz_ax.plot(np.cos(angles), -np.sin(angles), c="k")
        pz_ax.axis("equal")

    plt.show()
Ejemplo n.º 12
0
def main():
    """Average tail-corrected flat tops and baselines per channel and plot them.

    For each run in the hard-coded run list the tier-1 HDF5 file is loaded.
    For each selected channel, events above 0.2E9 in energy and in the top
    10% of the remaining energy range are tail-corrected with ``fit_tail``.
    The mean baseline and the mean flat-top residual are accumulated, a
    four-panel figure (aligned waveforms, corrected waveforms, periodograms,
    flat-top residuals) is written to ``fft_plots/``, and finally the summed
    squared flat-top residual is scattered against detector resolution.
    """
    runList = np.arange(11515, 11516)
    proc = DataProcessor()

    plt.ion()
    plt.figure(figsize=(12, 9))

    ds_inf = pd.read_csv("ds1_run_info.csv")

    # savefig does not create missing directories.
    os.makedirs("fft_plots", exist_ok=True)

    bl_idx = 800  # number of leading samples treated as baseline
    n_flat = 800  # number of flat-top samples kept after alignment

    for runNumber in runList:
        t1_file = os.path.join(proc.t1_data_dir,
                               "t1_run{}.h5".format(runNumber))
        df = pd.read_hdf(t1_file, key="ORGretina4MWaveformDecoder")
        g4 = dl.Gretina4MDecoder(t1_file)

        # Only this channel is processed.
        chanList = [580]

        pz_fun = []
        masses = []

        for chan in chanList:
            print("Channel {}".format(chan))

            ds_inf_det = ds_inf[(ds_inf.LG == chan) | (ds_inf.HG == chan)]
            if ds_inf_det.empty:
                # Previously det_mass/det_res were left undefined (or stale
                # from the previous channel) here; use NaN placeholders so
                # the plots are still produced.
                print("...couldn't find channel info")
                det_mass = det_res = float("nan")
            else:
                det_info = ds_inf_det.iloc[0]
                det_mass = det_info.Mass
                det_res = det_info.Resolution

            if chan % 2 == 1:
                continue
            if not is_mj(chan):
                continue

            plt.clf()
            ax1 = plt.subplot(2, 2, 2)
            ax2 = plt.subplot(2, 2, 4)
            ax3 = plt.subplot(2, 2, 3)
            ax4 = plt.subplot(2, 2, 1)

            df_chan = df[(df.channel == chan) & (df.energy > 0.2E9)]

            # Keep only the top 10% of the energy range.
            e_min = df_chan.energy.min()
            e_max = df_chan.energy.max()
            e_cut = 0.9 * (e_max - e_min) + e_min
            df_cut = df_chan[df_chan.energy > e_cut]

            baseline = np.zeros(bl_idx)
            flat_top = np.zeros(n_flat)

            num_wfs = 0
            for _, row in df_cut.iterrows():
                wf = g4.parse_event_data(row)
                wf_dat = wf.data - np.mean(wf.data[:bl_idx])

                try:
                    wf_corr, energy, _ = fit_tail(wf_dat)

                    # First sample within 0.1% of the fitted level, plus a
                    # 20-sample margin.
                    align_idx = np.argmax(wf_corr / energy > 0.999) + 20

                    flat_top_wf = wf_corr[align_idx:align_idx + n_flat] - energy
                    baseline_wf = wf_dat[:bl_idx]

                    # Validate both pieces before touching either
                    # accumulator, so a rejected waveform cannot leak into
                    # the baseline average.
                    if len(flat_top_wf) != n_flat or len(baseline_wf) != bl_idx:
                        print("...waveform too short to align, skipping")
                        continue
                except ValueError as e:
                    print(e)
                    continue

                baseline += baseline_wf
                flat_top += flat_top_wf
                num_wfs += 1

                ax1.plot(wf_corr[:800], c="b", alpha=0.1)
                ax2.plot(flat_top_wf, c="b", alpha=0.1)
                # Clamp the start so an early alignment point does not wrap
                # around to the end of the array.
                ax4.plot(wf_corr[max(align_idx - 400, 0):align_idx + 805] / energy,
                         c="b",
                         alpha=0.1)

            if num_wfs == 0:
                # Nothing to average; avoid dividing by zero.
                print("...no usable waveforms, skipping channel")
                continue

            flat_top /= num_wfs
            baseline /= num_wfs

            pz_fun.append(np.sum(flat_top**2))
            # NOTE(review): the list is called "masses" but stores the
            # resolution -- confirm which quantity the scatter should use.
            masses.append(det_res)

            ax1.set_title("Baseline")
            ax1.plot(baseline, c="r")

            ax2.set_title("Decay (PZ corrected)")
            ax2.plot(flat_top, c="r")

            plt.title("Channel {} (mass {}, res {})".format(
                chan, det_mass, det_res))

            xf, power = signal.periodogram(baseline,
                                           fs=1E8,
                                           detrend="linear",
                                           scaling="spectrum")
            # Plot only frequencies above 0.2E7.
            x_idx = np.argmax(xf > 0.2E7)
            ax3.semilogx(xf[x_idx:], power[x_idx:], label="baseline")

            xf, power = signal.periodogram(flat_top,
                                           fs=1E8,
                                           detrend="constant",
                                           scaling="spectrum")
            x_idx = np.argmax(xf > 0.2E7)
            ax3.semilogx(xf[x_idx:], power[x_idx:], label="flat top")

            ax3.legend()

            plt.savefig("fft_plots/channel{}_fft.png".format(chan))

        plt.figure()
        plt.scatter(masses, pz_fun)
        # The prompt used to say the opposite of what the code does.
        inp = input("q to quit, else to continue")
        if inp == "q":
            # Equivalent to exit(), without relying on the site builtin.
            raise SystemExit