示例#1
0
class TrajSpace(QWidget):
    """Window relating infrasound station records to trajectory height.

    For every station with a ``*DF`` channel, ``hypfunc`` is fitted to the
    station's fragmentation arrival times (``stn.times.fragmentation``)
    versus fragmentation height, and ``invhypfunc`` maps recorded times back
    to heights. Per station the window plots a window-gain spectrum (ax2),
    the bandpassed trace (ax3), the per-window half peak-to-peak amplitude
    (ax5) and the per-window dominant period (ax6), against time or height.
    Height-space window values can be turned into relaxation radii with
    ``presSearch``/``periodSearch`` and exported to CSV.
    """

    def __init__(self, bam):
        """Build the widgets and draw the initial plots.

        Args:
            bam: project state object; this class reads ``bam.setup``,
                ``bam.stn_list`` and ``bam.atmos`` from it.
        """

        QWidget.__init__(self)

        self.height_points = []
        self.bam = bam
        self.buildGUI()
        self.calculate()

    def clearax(self):
        """Forget stored height points and clear the four diagnostic axes."""
        self.height_points = []
        self.pvh_graph.ax2.clear()
        self.pvh_graph.ax3.clear()
        self.pvh_graph.ax5.clear()
        self.pvh_graph.ax6.clear()

    def binify(self):
        """Sum the amplitudes in ``self.height_points`` into height bins.

        ``self.height_points`` holds ``(height [m], amplitude)`` pairs
        (nothing in this class currently fills it). Bin edges run from the
        minimum height in steps of ``Size of Bins`` metres; a point of height
        ``h`` is added to the bin just below the first edge ``>= h``. Points
        at or below the lowest edge, or above the highest edge, are skipped
        (previously the former wrapped into the last bin and the latter
        raised ``IndexError``).

        Returns:
            numpy.ndarray: summed amplitude per edge index (empty when the
            bin size is not positive).
        """
        bin_size = float(self.bin_edits.text())
        h_min = float(self.min_height_edits.text()) * 1000
        h_max = float(self.max_height_edits.text()) * 1000

        if bin_size <= 0:
            return np.zeros(0)

        bins = np.arange(h_min, h_max + bin_size, bin_size)
        bin_content = np.zeros(len(bins))

        for pt in self.height_points:
            h, amplitude = pt[0], pt[1]
            # Index of the first edge >= h (same rule as the old list scan).
            edge = int(np.searchsorted(bins, h, side='left'))
            if edge == 0 or edge == len(bins):
                continue
            bin_content[edge - 1] += amplitude

        return bin_content

    def calculate(self):
        """Recompute and redraw every diagnostic plot.

        For each station that has a hyperbola fit:

        1. The raw trace is cut into ``Number of Windows`` windows that
           overlap by ``Percent Overlap``. Each window's ``genFFT`` amplitude
           spectrum divided by the first window's is a gain curve. The gain
           of the window with the largest summed spectrum is plotted on ax2.
        2. The bandpass is the station's own (``Use Station Bandpass``) or
           the first..last frequency where that gain is at least the
           runner-up window's gain (``Auto Gain Limits``) or the fixed
           ``Gain Cutoff``.
        3. The bandpassed trace is plotted on ax3. For every window, half the
           peak-to-peak amplitude (ax5) and the ``findDominantPeriodPSD``
           period (ax6) are plotted against time, or against height when
           ``Height Space`` is on.

        In height space only windows whose centre sample lies on the selected
        branch are kept. Those windows are also stored in
        ``self.geminus_heights``/``_p``/``_t``/``_stat`` for ``exportraw``
        and ``geminusify``.
        """

        h_min = float(self.min_height_edits.text()) * 1000
        h_max = float(self.max_height_edits.text()) * 1000

        self.geminus_heights = []
        self.geminus_p = []
        self.geminus_t = []
        self.geminus_stat = []

        self.clearax()

        infra_list = self.getInfraStats()
        popt_list = self.genHyperbola(infra_list)
        trace_list, resp_list = self.getInfraTraces(infra_list)

        N = int(self.N_edits.text())
        l = float(self.l_edits.text())

        use_h_space = self.h_space_tog.isChecked()
        use_branch_2 = self.branchselector.isChecked()

        # The axis decorations after the loop use the last plotted station;
        # these defaults cover the case where no station gets plotted.
        last_t = None
        last_filtered_freq = None
        t_range_1 = (np.nan, np.nan)
        t_range_2 = (np.nan, np.nan)

        for trace, resp, popt, infra in zip(trace_list, resp_list, popt_list,
                                            infra_list):

            if popt is None:
                continue

            stn_name = "{:}-{:}".format(infra.metadata.network,
                                        infra.metadata.code)
            sampling_rate = trace.stats.sampling_rate

            # --- 1. Windowed gain spectrum of the unfiltered trace ---
            a, t = self._procFlat(trace, resp, None)

            heights = invhypfunc(t, *popt)
            branch_1, branch_2 = self._splitBranches(heights, h_min, h_max)
            branch = branch_2 if use_branch_2 else branch_1

            L = len(a)
            spacer, shifter = self._windowParams(L, N, l)
            print("{:} Window Length = {:.2f} s".format(
                stn_name, spacer / sampling_rate))
            print("{:} Window Shift = {:.2f} s".format(
                stn_name, shifter / sampling_rate))

            n_windows = (L - spacer) // shifter + 1
            if n_windows < 1:
                print("{:}: trace too short for the requested windows, "
                      "skipping".format(stn_name))
                continue

            freq, best_gain, second_gain = self._windowGains(
                a, spacer, shifter, n_windows, sampling_rate)
            self.pvh_graph.ax2.plot(freq, best_gain, label=stn_name)

            # --- 2. Choose the bandpass ---
            if self.auto_gain.isChecked():
                j = np.where(best_gain >= second_gain)
            else:
                j = np.where(best_gain >= float(self.gain_edits.text()))

            filtered_freq = freq[j]
            last_filtered_freq = filtered_freq

            if self.stat_bandpass.isChecked():
                bnps = infra.bandpass
            elif len(filtered_freq) > 0:
                bnps = [filtered_freq[0], filtered_freq[-1]]
            else:
                # Previously an IndexError; fall back to no bandpass.
                print("{:}: no frequencies pass the gain limit, "
                      "not bandpassing".format(stn_name))
                bnps = None

            if bnps is not None:
                print("Optimal Frequency Range {:.2f} - {:.2f} Hz".format(
                    bnps[0], bnps[-1]))

            # --- 3. Bandpassed trace and per-window metrics ---
            a, t = self._procFlat(trace, resp, bnps)
            last_t = t

            if use_h_space:
                h = invhypfunc(t, *popt)
                self.pvh_graph.ax3.plot(h[branch],
                                        a[branch],
                                        alpha=0.3,
                                        label="{:}".format(stn_name))
            else:
                self.pvh_graph.ax3.plot(t,
                                        a,
                                        alpha=0.3,
                                        label="{:}".format(stn_name))

            # Branch membership is tested on the window's centre *index*;
            # the old test compared float heights for exact equality.
            on_branch = set(branch.tolist())
            win_x = []
            win_amp = []
            win_period = []
            for i in range(n_windows):
                start = i * shifter
                centre = start + spacer // 2
                if use_h_space and centre not in on_branch:
                    continue

                window = a[start:start + spacer]
                half_p2p = (np.nanmax(window) - np.nanmin(window)) / 2
                p, _, _ = findDominantPeriodPSD(window,
                                                sampling_rate,
                                                normalize=False)

                if use_h_space:
                    x = invhypfunc(t[centre], *popt)
                    self.geminus_heights.append(x)
                    self.geminus_p.append(half_p2p)
                    self.geminus_stat.append(infra)
                    self.geminus_t.append(p)
                else:
                    x = t[centre]

                win_x.append(x)
                win_amp.append(half_p2p)
                win_period.append(p)

            # Shortest period representable at this sampling rate (one sample);
            # drawn once per station instead of once per window.
            self.pvh_graph.ax6.axhline(y=1 / sampling_rate, linestyle='-')
            self.pvh_graph.ax5.scatter(win_x, win_amp, label=stn_name)
            self.pvh_graph.ax6.scatter(win_x, win_period, label=stn_name)

            t_range_1 = self._timeSpan(t, branch_1)
            t_range_2 = self._timeSpan(t, branch_2)

        if last_t is None:
            print("No infrasound stations could be plotted!")

        try:
            t_min = float(self.min_time_edits.text())
            t_max = float(self.max_time_edits.text())
        except ValueError:
            if last_t is not None and len(last_t) > 0:
                t_min, t_max = last_t[0], last_t[-1]
            else:
                # None leaves that matplotlib limit unchanged.
                t_min = t_max = None

        ax2 = self.pvh_graph.ax2
        ax2.set_xlabel("Frequency [Hz]")
        ax2.set_ylabel("Gain")
        ax2.set_xscale('log')
        ax2.set_yscale('log')
        if last_filtered_freq is not None and len(last_filtered_freq) > 0:
            ax2.axvline(x=last_filtered_freq[0], linestyle='-')
            ax2.axvline(x=last_filtered_freq[-1], linestyle='-')

        t_range = t_range_2 if use_branch_2 else t_range_1
        for ax, ylabel in ((self.pvh_graph.ax3, "Overpressure [Pa]"),
                           (self.pvh_graph.ax5, "Overpressure [Pa]"),
                           (self.pvh_graph.ax6, "Dominant Period [s]")):
            self._formatXAxis(ax, use_h_space, (h_min, h_max), (t_min, t_max),
                              t_range)
            ax.set_ylabel(ylabel)

        self.pvh_graph.ax2.legend()
        self.pvh_graph.show()

    def _procFlat(self, trace, resp, bandpass):
        """Run ``procTrace`` on *trace* and flatten its output.

        ``procTrace`` appears to return per-segment sequences of amplitudes
        and times; they are concatenated into two flat numpy arrays.
        """
        a, t = procTrace(trace,
                         ref_datetime=self.bam.setup.fireball_datetime,
                         resp=resp,
                         bandpass=bandpass,
                         backup=False)
        a = np.array([item for segment in a for item in segment])
        t = np.array([item for segment in t for item in segment])
        return a, t

    @staticmethod
    def _splitBranches(heights, h_min, h_max):
        """Split indices with ``h_min <= height <= h_max`` into two branches.

        The in-range indices are cut at the last discontinuity (a step other
        than +1 between consecutive indices). Indices before the cut form
        branch 1 and the rest form branch 2. With no discontinuity, branch 1
        is empty and branch 2 holds every in-range index; the final index is
        no longer dropped from branch 2.
        """
        idx = np.where(np.logical_and(heights >= h_min, heights <= h_max))[0]
        gaps = np.nonzero(np.diff(idx) != 1)[0]
        divide = int(gaps[-1]) + 1 if len(gaps) > 0 else 0
        return idx[:divide], idx[divide:]

    @staticmethod
    def _windowParams(L, N, l):
        """Return ``(spacer, shifter)``: window length and step in samples.

        ``spacer`` is chosen so that ``N`` windows overlapping by the fraction
        ``l`` span ``L`` samples, i.e. ``L = spacer * (l * (1 - N) + N)``, and
        ``shifter = spacer * (1 - l)``. The divisor is clamped to at least 1
        (it is >= 1 for any ``N >= 1`` and ``l <= 1``). Both results are
        clamped to at least one sample, so ``l == 1`` no longer divides by
        zero.
        """
        divisor = max(1.0, l * (1 - N) + N)
        spacer = max(1, int(L / divisor))
        shifter = max(1, int(spacer * (1 - l)))
        return spacer, shifter

    @staticmethod
    def _windowGains(a, spacer, shifter, n_windows, sampling_rate):
        """Return ``(freq, best_gain, second_gain)`` for the sliding windows.

        Each window's ``genFFT`` amplitude spectrum is divided by the first
        window's spectrum. ``best_gain`` belongs to the window with the
        largest summed spectrum and ``second_gain`` to the runner-up. With a
        single window, both are that window. ``freq`` comes from the last
        window.
        """
        totals = []
        gains = []
        freq = None
        reference = None
        for i in range(n_windows):
            start = i * shifter
            freq, FAS = genFFT(a[start:start + spacer], sampling_rate)
            if reference is None:
                reference = FAS
            totals.append(np.sum(FAS))
            gains.append(FAS / reference)

        totals = np.array(totals)
        best = int(np.argmax(totals))
        totals[best] = 0
        second = int(np.argmax(totals))
        return freq, gains[best], gains[second]

    @staticmethod
    def _timeSpan(t, idx):
        """Return the first and last value of ``t[idx]``, or NaNs if empty."""
        if len(idx) == 0:
            return np.nan, np.nan
        selected = t[idx]
        return selected[0], selected[-1]

    @staticmethod
    def _formatXAxis(ax, use_h_space, h_lim, t_lim, t_range):
        """Label and limit the x axis of *ax* for height or time space.

        In time space the start and end times of the selected branch
        (``t_range``) are marked with vertical lines; NaN entries from an
        empty branch are skipped.
        """
        if use_h_space:
            ax.set_xlabel("Height [m]")
            ax.set_xlim(list(h_lim))
        else:
            ax.set_xlabel("Time [s]")
            ax.set_xlim(list(t_lim))
            for x in t_range:
                if np.isfinite(x):
                    ax.axvline(x=x, linestyle='-')

    def buildGUI(self):
        """Create the plot canvases, input fields, toggles and buttons."""
        self.setWindowTitle('Traj Space')
        app_icon = QtGui.QIcon()
        app_icon.addFile(
            os.path.join('supra', 'GUI', 'Images', 'BAM_no_wave.png'),
            QtCore.QSize(16, 16))
        self.setWindowIcon(app_icon)

        p = self.palette()
        p.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(p)

        theme(self)

        layout = QGridLayout()
        self.setLayout(layout)

        self.pvh_graph = MatplotlibPyQT()
        self.pvh_graph.ax2 = self.pvh_graph.figure.add_subplot(411)
        self.pvh_graph.ax3 = self.pvh_graph.figure.add_subplot(412)
        self.pvh_graph.ax5 = self.pvh_graph.figure.add_subplot(413)
        self.pvh_graph.ax6 = self.pvh_graph.figure.add_subplot(414)
        layout.addWidget(self.pvh_graph, 1, 1, 100, 1)
        export_raw = createButton("Export Overpressures and Periods",
                                  layout,
                                  101,
                                  1,
                                  self.exportraw,
                                  args=[])

        run_button = createButton("Make Plots",
                                  layout,
                                  1,
                                  3,
                                  self.calculate,
                                  args=[])
        _, self.N_edits = createLabelEditObj("Number of Windows",
                                             layout,
                                             2,
                                             width=1,
                                             h_shift=1,
                                             tool_tip='',
                                             validate='int',
                                             default_txt='100')
        _, self.l_edits = createLabelEditObj("Percent Overlap",
                                             layout,
                                             3,
                                             width=1,
                                             h_shift=1,
                                             tool_tip='',
                                             validate='float',
                                             default_txt='0.5')
        self.h_space_tog = createToggle("Height Space",
                                        layout,
                                        4,
                                        width=1,
                                        h_shift=2,
                                        tool_tip='')
        self.h_space_tog.setChecked(True)
        self.auto_gain = createToggle("Auto Gain Limits",
                                      layout,
                                      5,
                                      width=1,
                                      h_shift=2,
                                      tool_tip='')
        _, self.gain_edits = createLabelEditObj("Gain Cutoff",
                                                layout,
                                                6,
                                                width=1,
                                                h_shift=1,
                                                tool_tip='',
                                                validate='float',
                                                default_txt='5')
        self.auto_gain.setChecked(True)
        ro_button = createButton("Calculate Relaxation Radii",
                                 layout,
                                 7,
                                 3,
                                 self.geminusify,
                                 args=[])
        _, self.min_height_edits = createLabelEditObj("Minimum Height [km]",
                                                      layout,
                                                      8,
                                                      width=1,
                                                      h_shift=1,
                                                      tool_tip='',
                                                      validate='float',
                                                      default_txt='17')
        _, self.max_height_edits = createLabelEditObj("Maximum Height [km]",
                                                      layout,
                                                      9,
                                                      width=1,
                                                      h_shift=1,
                                                      tool_tip='',
                                                      validate='float',
                                                      default_txt='40')
        self.branchselector = createToggle("Use Branch 2",
                                           layout,
                                           10,
                                           width=1,
                                           h_shift=2,
                                           tool_tip='')
        self.branchselector.setChecked(True)
        _, self.min_time_edits = createLabelEditObj("Minimum Time [s]",
                                                    layout,
                                                    11,
                                                    width=1,
                                                    h_shift=1,
                                                    tool_tip='',
                                                    validate='float',
                                                    default_txt='300')
        _, self.max_time_edits = createLabelEditObj("Maximum Time [s]",
                                                    layout,
                                                    12,
                                                    width=1,
                                                    h_shift=1,
                                                    tool_tip='',
                                                    validate='float',
                                                    default_txt='600')
        self.stat_bandpass = createToggle("Use Station Bandpass",
                                          layout,
                                          13,
                                          width=1,
                                          h_shift=2,
                                          tool_tip='')
        _, self.bin_edits = createLabelEditObj("Size of Bins [m]",
                                               layout,
                                               14,
                                               width=1,
                                               h_shift=1,
                                               tool_tip='',
                                               validate='float',
                                               default_txt='100')
        self.bin_edits.editingFinished.connect(self.binify)

        self.ro_graph = MatplotlibPyQT()
        self.ro_graph.ax = self.ro_graph.figure.add_subplot(111)
        layout.addWidget(self.ro_graph, 15, 2, 90, 2)
        export_ro = createButton("Export Relaxation Radii Curve",
                                 layout,
                                 105,
                                 3,
                                 self.exportro,
                                 args=[])

    def exportraw(self):
        """Write the stored height-space window values to a CSV file.

        Does nothing if the save dialog yields no file name (presumably a
        cancelled dialog — confirm what ``saveFile`` returns then).
        """

        file_name = saveFile('.csv')
        if not file_name:
            return

        with open(file_name, 'w') as f:
            f.write(
                'Station, Height [m], Overpressure [Pa], Dominant Period [s] \n'
            )
            for stat, height, pres, period in zip(self.geminus_stat,
                                                  self.geminus_heights,
                                                  self.geminus_p,
                                                  self.geminus_t):
                f.write('{:}, {:}, {:}, {:} \n'.format(
                    stat.metadata.code, height, pres, period))

        errorMessage('Raw Data Saved!',
                     0,
                     info='Saved to File {:}'.format(file_name))

    def exportro(self):
        """Write stored window values plus relaxation radii to a CSV file.

        Requires ``geminusify`` to have run for the current window values.
        Otherwise (or if the save dialog yields no file name) nothing is
        written; previously this raised ``AttributeError``/``IndexError``.
        """

        ro_data = getattr(self, 'ro_data', None)
        if ro_data is None or len(ro_data) != len(self.geminus_heights):
            print("Relaxation radii are missing or out of date - "
                  "run 'Calculate Relaxation Radii' first")
            return

        file_name = saveFile('.csv')
        if not file_name:
            return

        with open(file_name, 'w') as f:
            f.write(
                'Station, Height [m], Overpressure [Pa], Dominant Period [s], Ro (Weak-Shock, Overpressure) [m], Ro (Linear, Overpressure) [m], Ro (Weak-Shock, Dominant Period) [m], Ro (Linear, Dominant Period) [m]\n'
            )
            for i in range(len(self.geminus_heights)):
                f.write('{:}, {:}, {:}, {:}, {:}, {:}, {:}, {:} \n'.format(
                    self.geminus_stat[i].metadata.code,
                    self.geminus_heights[i], self.geminus_p[i],
                    self.geminus_t[i], ro_data[i, 1], ro_data[i, 2],
                    ro_data[i, 3], ro_data[i, 4]))

        errorMessage('Data saved to CSV!',
                     0,
                     info='Saved to File {:}'.format(file_name))

    def geminusify(self):
        """Compute and plot relaxation radii for every stored window.

        For each stored height, the trajectory point at that height
        (``findGeo``), the recording station and a sounding between them are
        passed to:

        - ``presSearch``, which uses the window overpressure;
        - ``periodSearch``, which uses the window dominant period.

        Rows ``[height, Ro_ws_p, Ro_lin_p, Ro_ws_t, Ro_lin_t]`` are stored in
        ``self.ro_data``, the same column order ``exportro`` writes. The plot
        is cleared first so it always matches ``self.ro_data``.
        """
        traj = self.bam.setup.trajectory
        v = traj.v / 1000
        theta = 90 - traj.zenith.deg
        # Reset for compatibility; not read anywhere in this class.
        self.sounding_pres = None

        data = []
        max_steps = len(self.geminus_heights)
        for hh, h in enumerate(self.geminus_heights):

            source = traj.findGeo(h)
            source_list = [source.lat, source.lon, source.elev / 1000]

            stat_pos = self.geminus_stat[hh].metadata.position
            stat = [stat_pos.lat, stat_pos.lon, stat_pos.elev / 1000]

            dphi = np.degrees(
                np.arctan2(stat_pos.lon - source.lon,
                           stat_pos.lat - source.lat)) - traj.azimuth.deg

            # Switch 3 doesn't do anything in this version of overpressure.py
            sw = [1, 0, 1]

            sounding, _ = self.bam.atmos.getSounding(
                lat=[source.lat, stat_pos.lat],
                lon=[source.lon, stat_pos.lon],
                heights=[source.elev, stat_pos.elev],
                ref_time=self.bam.setup.fireball_datetime)

            # Append an exponential-atmosphere pressure column (1013.25 *
            # exp(-0.00012 * col0); presumably hPa with col0 in m) and
            # subtract 273.15 from column 1 (presumably K -> deg C), then
            # reverse the row order.
            pres = 10 * 101.325 * np.exp(-0.00012 * sounding[:, 0])
            sounding_pres = np.zeros(
                (sounding.shape[0], sounding.shape[1] + 1))
            sounding_pres[:, :-1] = sounding
            sounding_pres[:, -1] = pres
            sounding_pres[:, 1] -= 273.15
            sounding_pres = np.flip(sounding_pres, axis=0)

            gem_inputs = [
                source_list, stat, v, theta, dphi, sounding_pres, sw, True,
                True
            ]
            Ro_ws_p, Ro_lin_p = presSearch(self.geminus_p[hh],
                                           gem_inputs,
                                           paths=False)
            Ro_ws_t, Ro_lin_t = periodSearch(self.geminus_t[hh],
                                             gem_inputs,
                                             paths=False)

            data.append([h, Ro_ws_p, Ro_lin_p, Ro_ws_t, Ro_lin_t])
            print("Complete Step {:} of {:}".format(hh + 1, max_steps))

        self.ro_data = np.array(data)

        ax = self.ro_graph.ax
        ax.clear()
        if data:
            # Labels follow the column order: *_p radii come from the
            # overpressure, *_t radii from the dominant period (the old
            # labels had these swapped). One scatter per series so the
            # legend has four entries.
            for col, colour, label in ((1, 'm', "Weak-Shock Overpressure"),
                                       (2, 'c', "Linear Overpressure"),
                                       (3, 'y', "Weak-Shock Period"),
                                       (4, 'g', "Linear Period")):
                ax.scatter(self.ro_data[:, 0],
                           self.ro_data[:, col],
                           c=colour,
                           label=label)
            ax.legend()
        ax.set_xlabel("Height [m]")
        ax.set_ylabel("Relaxation Radius [m]")
        self.ro_graph.show()

    def getInfraStats(self):
        """Return the stations in ``self.bam.stn_list`` with a ``*DF`` channel."""
        return [
            stn for stn in self.bam.stn_list
            if len(stn.stream.select(channel="*DF")) > 0
        ]

    def genHyperbola(self, infra_list):
        """Fit ``hypfunc`` (arrival time vs. fragmentation height) per station.

        For each station, the fragmentation-point elevations and the
        station's fragmentation arrival times are collected. Points with a
        NaN time are dropped before the fit.

        Returns:
            list: the ``curve_fit`` parameters per station, or ``None`` where
            the fit failed (``TypeError``, ``RuntimeError`` or
            ``ValueError``).
        """
        from scipy.optimize import curve_fit

        frag_points = self.bam.setup.fragmentation_point
        popt_list = []
        for stn in infra_list:
            points_x = []
            points_y = []
            for i, frag in enumerate(frag_points):
                f_time = stn.times.fragmentation[i][0][0]
                points_x.append(frag.position.elev)
                points_y.append(f_time)

            try:
                kept = [(x, y) for x, y in zip(points_x, points_y)
                        if not np.isnan(y)]
                x_vals = np.array([x for x, _ in kept])
                y_vals = np.array([y for _, y in kept])

                popt, _ = curve_fit(hypfunc, x_vals, y_vals)
                popt_list.append(popt)
            except TypeError:
                print("Could not generate hyperbola fit!")
                popt_list.append(None)
            except RuntimeError:
                print("Optimal hyperbola not found!")
                popt_list.append(None)
            except ValueError:
                print("No Arrivals!")
                popt_list.append(None)
        return popt_list

    def getInfraTraces(self, infra_list):
        """Return ``(trace_list, resp_list)`` for the given stations.

        Each station is processed with ``procStream``; its ``*DF`` channel
        (selected with ``findChn``) and the returned response are collected.
        """
        resp_list = []
        trace_list = []
        for stn in infra_list:
            st, resp, _ = procStream(
                stn, ref_time=self.bam.setup.fireball_datetime)
            resp_list.append(resp)
            trace_list.append(findChn(st, "*DF"))
        return trace_list, resp_list

    def convertTimes(self, wave, time, popt):
        """Map sample times to heights and keep samples inside the height limits.

        Returns:
            ``([pressures], [heights])``: each is a one-element list wrapping
            the in-range values. Returns ``([], [])`` when *popt* is None.
        """

        if popt is None:
            return [], []

        height_peaks = invhypfunc(time, *popt)

        h_min = float(self.min_height_edits.text()) * 1000
        h_max = float(self.max_height_edits.text()) * 1000
        in_range = np.where(
            np.logical_and(height_peaks >= h_min, height_peaks <= h_max))[0]

        new_heights = [height_peaks[hh] for hh in in_range]
        new_press = [wave[hh] for hh in in_range]

        return [new_press], [new_heights]

    def genPFFAS(self, wave_list, time_list):
        """Return ``(p, freq, FAS)`` from ``findDominantPeriodPSD``.

        The sampling rate is derived from the first time step of
        *time_list*, and the PSD is normalised.
        """
        sampling_rate = 1 / (time_list[1] - time_list[0])

        p, freq, FAS = findDominantPeriodPSD(wave_list,
                                             sampling_rate,
                                             normalize=True)

        return p, freq, FAS

    def getOptBandpass(self, trace_list, infra_list):
        """Debug plot of the spectra of 100 consecutive slices of each trace.

        Each station's first processed segment is cut into 100 equal slices.
        Every slice's normalised spectrum is drawn on a semilog-x plot, both
        as-is and interpolated onto a log-spaced grid from 1e-3 to 9.
        NOTE(review): the entries of ``b_list`` are never applied; traces are
        processed with ``bandpass=None``.
        """

        b_list = [[1e-3, 9.9]]
        n_slices = 100
        xnew = np.logspace(-3, np.log10(9))

        for _bandpass in b_list:
            for trace, stn in zip(trace_list, infra_list):
                waveform_data, time_data = procTrace(
                    trace,
                    ref_datetime=self.bam.setup.fireball_datetime,
                    resp=stn.response,
                    bandpass=None)

                L = len(waveform_data[0]) // n_slices
                if L < 2:
                    # genPFFAS needs at least two time samples per slice.
                    continue

                for ii in range(n_slices):
                    _, freq, FAS = self.genPFFAS(
                        waveform_data[0][ii * L:(ii + 1) * L],
                        time_data[0][ii * L:(ii + 1) * L])

                    # Grid points outside this slice's frequency range become
                    # NaN instead of raising ValueError.
                    f2 = interp1d(freq, FAS, bounds_error=False)
                    plt.semilogx(xnew, f2(xnew))
                    plt.semilogx(freq, FAS)

        plt.show()
示例#2
0
class BandpassWindow(QWidget):
    """Window for comparing signal and noise spectra of one station channel.

    The waveform is plotted with the response removed and a 2-8 Hz
    bandpass applied. Two draggable regions sit on the plot: red for noise
    and green for signal.

    * "Bandpass" plots the Fourier amplitude spectra of both regions and
      their ratio, and marks the zero-crossing periods of the signal.
    * "Save" stores the low/high bandpass typed into the text fields on the
      station and saves the project.
    """
    def __init__(self, bam, stn, channel, t_arrival=0):
        """
        Arguments:
            bam: project object. ``bam.setup.fireball_datetime`` is the time
                reference of all plots.
            stn: station object providing ``stream`` and ``response``. Its
                ``offset`` and ``bandpass`` attributes are written.
            channel: channel code used to select the trace from ``stn.stream``.
            t_arrival: time [s after the fireball] the waveform plot is
                centred on.
        """

        self.bam = bam
        self.stn = stn
        self.channel = channel
        self.t_arrival = t_arrival

        QWidget.__init__(self)
        self.buildGUI()

    def buildGUI(self):
        """Create the widgets and plot the filtered waveform of the channel."""

        self.setWindowTitle('Bandpass Optimizer')

        app_icon = QtGui.QIcon()
        app_icon.addFile(
            os.path.join('supra', 'GUI', 'Images', 'BAM_no_wave.png'),
            QtCore.QSize(16, 16))
        self.setWindowIcon(app_icon)

        p = self.palette()
        p.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(p)

        theme(self)

        layout = QGridLayout()
        self.setLayout(layout)

        self.bandpass_view = pg.GraphicsLayoutWidget()
        self.bandpass_canvas = self.bandpass_view.addPlot()

        self.bp_button = createButton('Bandpass', layout, 4, 2, self.bandpass)
        self.save_button = createButton('Save', layout, 4, 3,
                                        self.bandpassSave)

        layout.addWidget(self.bandpass_view, 1, 1, 1, 2)

        _, self.low_bandpass_edits = createLabelEditObj('Low Bandpass',
                                                        layout,
                                                        2,
                                                        width=1,
                                                        h_shift=0,
                                                        tool_tip='',
                                                        validate='float',
                                                        default_txt='2')
        _, self.high_bandpass_edits = createLabelEditObj('High Bandpass',
                                                         layout,
                                                         3,
                                                         width=1,
                                                         h_shift=0,
                                                         tool_tip='',
                                                         validate='float',
                                                         default_txt='8')

        self.stream = self.stn.stream.select(
            channel="{:}".format(self.channel))

        # Work on a copy so the station's stream is left untouched
        self.orig_trace = self.stream[0].copy()
        stn = self.stn

        delta = self.orig_trace.stats.delta
        start_datetime = self.orig_trace.stats.starttime.datetime

        # Seconds from the fireball to the first sample of the record
        # (negative when the record starts before the fireball)
        stn.offset = (start_datetime -
                      self.bam.setup.fireball_datetime).total_seconds()

        self.current_waveform_delta = delta
        self.current_waveform_time = np.arange(
            0, self.orig_trace.stats.npts / self.orig_trace.stats.sampling_rate,
            delta)

        time_data = np.copy(self.current_waveform_time)

        self.orig_trace.detrend()

        resp = stn.response
        if resp is not None:
            self.orig_trace = self.orig_trace.remove_response(inventory=resp,
                                                              output="DISP")

        waveform_data = self.orig_trace.data

        self.orig_data = np.copy(waveform_data)

        # np.arange can differ from npts by a sample because of floating-point
        # rounding, so trim both arrays to a common length. Then shift the
        # time axis to be relative to the fireball.
        waveform_data = waveform_data[:len(time_data)]
        time_data = time_data[:len(waveform_data)] + stn.offset

        self.current_waveform_processed = waveform_data

        # Init the butterworth bandpass filter
        butter_b, butter_a = butterworthBandpassFilter(
            2, 8, 1.0 / self.current_waveform_delta, order=6)

        # Filter the data
        waveform_data = scipy.signal.filtfilt(butter_b, butter_a,
                                              np.copy(waveform_data))

        self.current_station_waveform = pg.PlotDataItem(x=time_data,
                                                        y=waveform_data,
                                                        pen='w')
        self.bandpass_canvas.addItem(self.current_station_waveform)
        self.bandpass_canvas.setXRange(self.t_arrival - 100,
                                       self.t_arrival + 100,
                                       padding=1)
        self.bandpass_canvas.setLabel(
            'bottom',
            "Time after {:} s".format(self.bam.setup.fireball_datetime))
        self.bandpass_canvas.setLabel('left', "Signal Response")

        self.bandpass_canvas.plot(x=[-10000, 10000],
                                  y=[0, 0],
                                  pen=pg.mkPen(color=(100, 100, 100)))

        self.noise_selector = pg.LinearRegionItem(values=[0, 10],
                                                  brush=(255, 0, 0, 100))

        self.signal_selector = pg.LinearRegionItem(values=[200, 210],
                                                   brush=(0, 255, 0, 100))

        self.bandpass_canvas.addItem(self.noise_selector)
        self.bandpass_canvas.addItem(self.signal_selector)

        self.bandpass_graph = MatplotlibPyQT()
        self.bandpass_graph.ax1 = self.bandpass_graph.figure.add_subplot(211)
        self.bandpass_graph.ax2 = self.bandpass_graph.figure.add_subplot(212)
        layout.addWidget(self.bandpass_graph, 1, 4, 1, 2)

    def bandpassSave(self):
        """Store the [low, high] bandpass from the text fields and save the project."""
        bandpass = [
            float(self.low_bandpass_edits.text()),
            float(self.high_bandpass_edits.text())
        ]
        self.stn.bandpass = bandpass
        save(self.bam, file_check=False)

    def determineROIidx(self, roi):
        """Convert a region on the waveform plot into sample indices.

        The plotted time axis is ``sample_time + stn.offset``. A plot time x
        therefore corresponds to sample ``(x - stn.offset) * sampling_rate``.

        Arguments:
            roi: (start, end) of the region, in seconds after the fireball.

        Returns:
            (pt_0, pt_1): start and end sample indices. Both are clipped at 0
            so that a region beginning before the record does not wrap
            around when used as a slice.
        """
        st = self.stn.stream.select(channel="{:}".format(self.channel))[0]

        number_of_pts_per_s = st.stats.sampling_rate
        num_of_pts_in_roi = (roi[1] - roi[0]) * number_of_pts_per_s

        # Samples from the record start to t = 0. The sign matters: for a
        # record starting after the fireball this is negative.
        num_of_pts_in_offset = -self.stn.offset * number_of_pts_per_s

        num_of_pts_to_roi = roi[0] * number_of_pts_per_s

        pt_0 = int(num_of_pts_in_offset + num_of_pts_to_roi)
        pt_1 = int(pt_0 + num_of_pts_in_roi)

        return max(pt_0, 0), max(pt_1, 0)

    def bandpass(self):
        """Plot signal/noise amplitude spectra and mark zero-crossing periods.

        The selected noise and signal regions are cut from the processed
        trace. The zero-crossing periods of the signal window are printed and
        marked at their frequency (1/period). The Fourier amplitude spectra of
        both windows, cut to a common length, are drawn together with their
        ratio.
        """

        self.bandpass_graph.ax1.clear()
        self.bandpass_graph.ax2.clear()

        noise_a, noise_b = self.determineROIidx(self.noise_selector.getRegion())
        signal_a, signal_b = self.determineROIidx(
            self.signal_selector.getRegion())

        # NOTE(review): orig_trace already had the instrument response removed
        # in buildGUI. Confirm that procTrace does not remove it a second time.
        waveform_data, t = procTrace(self.orig_trace,
                                     ref_datetime=self.bam.setup.fireball_datetime,
                                     resp=self.stn.response,
                                     bandpass=None,
                                     backup=False)

        S = waveform_data[0][signal_a:signal_b]
        N = waveform_data[0][noise_a:noise_b]

        # TODO: detrend or bandpass before looking for zero crossings
        zero_cross_p = findDominantPeriod(S,
                                          t[0][signal_a:signal_b],
                                          return_all=True)

        # The spectral ratio needs windows of equal length, so cut both to
        # the shorter one
        win_len = min(len(S), len(N))
        S = S[:win_len]
        N = N[:win_len]

        sampling_rate = self.orig_trace.stats.sampling_rate
        freq, FAS_S = genFFT(S, sampling_rate, interp=False)
        freq, FAS_N = genFFT(N, sampling_rate, interp=False)

        S_N_FAS = FAS_S / FAS_N

        self.bandpass_graph.ax2.loglog(freq, FAS_S, label="Signal")
        self.bandpass_graph.ax2.loglog(freq, FAS_N, label="Noise")
        self.bandpass_graph.ax1.loglog(freq, S_N_FAS, label="Signal/Noise")
        self.bandpass_graph.ax1.axhline(y=1)

        self.bandpass_graph.ax1.set_xlabel("Frequency [Hz]")
        self.bandpass_graph.ax2.set_xlabel("Frequency [Hz]")

        if len(zero_cross_p) == 0:
            print("No Zero-Crossings!")
        else:
            print("Zero-Crossing Periods:")

        labelled = False
        for p in zero_cross_p:
            # Both axes are frequency axes, so a period p [s] goes at 1/p [Hz].
            # Non-positive (or NaN) periods have no position and are only printed.
            if p > 0:
                line_kwargs = {} if labelled else {
                    "label": "Zero-Crossing Dominant Period"
                }
                self.bandpass_graph.ax1.axvline(x=1 / p, **line_kwargs)
                self.bandpass_graph.ax2.axvline(x=1 / p, **line_kwargs)
                labelled = True
            print("{:.2f} s".format(p))

        self.bandpass_graph.ax1.legend()
        self.bandpass_graph.ax2.legend()
        self.bandpass_graph.show()
示例#3
0
class lumEffDialog(QWidget):
    """Dialog comparing energy measurements against the observed light curve.

    Each entry of ``bam.energy_measurements`` gets a ``TauEx`` widget in a
    scrollable list, where its luminous efficiency tau (in percent) and a
    height range can be edited.

    * The left plot shows the light curve read from
      ``setup.light_curve_file``, plus the magnitudes implied by each
      measurement.
    * The right plot shows tau against source height, plus an optional
      curve loaded from a CSV file.
    """
    def __init__(self, bam):
        """
        Arguments:
            bam: project object. Its ``setup`` (trajectory velocity,
                light-curve file) and ``energy_measurements`` are read.
        """

        QWidget.__init__(self)

        self.bam = bam
        self.setup = self.bam.setup

        self.v = self.bam.setup.trajectory.v
        # One tau [%] and one [min, max] height pair per energy measurement,
        # indexed like bam.energy_measurements
        self.tau = [5.00] * len(self.bam.energy_measurements)
        self.height_list = [[None, None]
                            for _ in range(len(self.bam.energy_measurements))]

        self.buildGUI()
        self.processEnergy()

    def buildGUI(self):
        """Create the plots, the scrollable source list and the control buttons."""
        self.setWindowTitle('Luminous Efficiency')
        app_icon = QtGui.QIcon()
        app_icon.addFile(
            os.path.join('supra', 'GUI', 'Images', 'BAM_no_wave.png'),
            QtCore.QSize(16, 16))
        self.setWindowIcon(app_icon)

        p = self.palette()
        p.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(p)

        theme(self)

        main_layout = QGridLayout()
        self.light_curve = MatplotlibPyQT()
        self.light_curve.ax = self.light_curve.figure.add_subplot(111)
        main_layout.addWidget(self.light_curve, 1, 101, 1, 100)

        self.lum_curve = MatplotlibPyQT()
        self.lum_curve.ax = self.lum_curve.figure.add_subplot(111)
        main_layout.addWidget(self.lum_curve, 1, 202, 1, 100)

        self.lightCurve()

        self.sources_table = QScrollArea()
        self.sources_table.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.sources_table.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

        self.sources_table.setWidgetResizable(True)

        container = QWidget()
        container.setStyleSheet("""
            QWidget {
                background-color: rgb(0, 0, 0);
                }
            """)
        self.sources_table.setWidget(container)
        self.sources_table_layout = QVBoxLayout(container)
        self.sources_table_layout.setSpacing(10)

        main_layout.addWidget(self.sources_table, 1, 1, 1, 100)
        _, self.tau_edits, tau_button = createFileSearchObj("Tau Curve File",
                                                            main_layout,
                                                            3,
                                                            width=1,
                                                            h_shift=0,
                                                            tool_tip='')
        tau_button.clicked.connect(
            partial(fileSearch, ['CSV (*.csv)'], self.tau_edits))

        self.redraw_button = createButton("Plot",
                                          main_layout,
                                          2,
                                          1,
                                          self.redraw,
                                          args=[])
        self.cla_button = createButton("Clear All",
                                       main_layout,
                                       2,
                                       2,
                                       self.clearEnergy,
                                       args=[])

        self.setLayout(main_layout)

    def clearEnergy(self):
        """Drop all energy measurements and redraw."""

        self.bam.energy_measurements = []
        self.redraw()

    def plotTaus(self):
        """Plot the comma-separated file named in the "Tau Curve File" field.

        Column 0 is plotted on x and column 1 on y of the luminous-efficiency
        axes. Given those axes, presumably column 0 is tau [%] and column 1 is
        height [km] (TODO confirm). Blank lines are ignored; an empty file
        plots nothing.
        """
        light_curve = []
        # Read-only access is enough (the file is never written)
        with open(self.tau_edits.text(), 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                light_curve.append(
                    [float(item.strip()) for item in line.split(',')])

        if not light_curve:
            return

        light_curve = np.array(light_curve)

        h = light_curve[:, 1]
        t = light_curve[:, 0]

        self.lum_curve.ax.plot(t, h)
        self.lum_curve.show()

    def redraw(self):
        """Read tau and heights back from the source widgets, then rebuild everything."""

        self.height_list = []

        # Get taus; a None widget (deleted measurement) keeps a placeholder so
        # that indices stay aligned with bam.energy_measurements
        for ii, widget in enumerate(self.source_widget):
            if widget is None:
                self.height_list.append([None, None])
                continue

            self.tau[ii] = widget.getTau()

            min_h, max_h = widget.getHeights()
            self.height_list.append([min_h, max_h])

        # Remove all Widgets
        for i in reversed(range(self.sources_table_layout.count())):
            self.sources_table_layout.itemAt(i).widget().setParent(None)

        # Clear Graph
        self.light_curve.ax.clear()
        self.lum_curve.ax.clear()

        # Light Curve
        self.lightCurve()

        # Plot tau
        if self.tau_edits.text() != "":
            self.plotTaus()

        # Data
        self.processEnergy()

    def lightCurve(self):
        """Plot the light curves from ``setup.light_curve_file``, if one is set.

        Also (re)sets ``self.light_curve_list``, which is an empty list when
        no light-curve file is configured.
        """
        self.light_curve_list = []

        light_curve_file = getattr(self.setup, "light_curve_file", None)
        if not light_curve_file:
            return

        light_curve = readLightCurve(light_curve_file)

        self.light_curve_list = processLightCurve(light_curve)

        for L in self.light_curve_list:
            h, M = L.interpCurve(dh=10000)

            self.light_curve.ax.plot(h, M, label=L.station)

        # Magnitudes: brighter is smaller, so flip the axis
        self.light_curve.ax.invert_yaxis()
        self.light_curve.ax.legend()

    def processEnergy(self):
        """Create one source widget per measurement and plot its implied magnitude."""
        def ballEnergy2Mag(ball_E, v, tau):
            return -2.5 * np.log10((ball_E * v * tau) / 1500)

        def fragEnergy2Mag(frag_E, h):
            # Returns None when there is no light curve. Only the area for the
            # last light curve is kept.
            area = None
            for L in self.light_curve_list:
                # NOTE(review): lightCurve unpacks interpCurve as (h, M); the
                # names here are swapped. Confirm the argument order findArea
                # expects.
                M_list, h_list = L.interpCurve(dh=10000)
                area = findArea(h / 1000, frag_E, M_list, h_list)

            return area

        n_sources = len(self.bam.energy_measurements)

        # Measurements added after construction get default tau/heights
        self.tau.extend([5.00] * (n_sources - len(self.tau)))
        self.height_list.extend(
            [[None, None] for _ in range(n_sources - len(self.height_list))])

        self.source_widget = [None] * n_sources

        for ee, energy in enumerate(self.bam.energy_measurements):

            # Deleted measurements remain as None placeholders
            if energy is None:
                continue

            self.source_widget[ee] = TauEx(energy, self.tau[ee],
                                           self.height_list[ee])

            if self.source_widget[ee].mark_for_deletion:
                self.bam.energy_measurements[ee] = None
                continue

            self.sources_table_layout.addWidget(self.source_widget[ee])

            # tau is stored in percent
            tau = self.tau[ee] / 100
            h = energy.height

            ### BALLISTIC
            if energy.source_type.lower() == "ballistic":

                lin_e = energy.linear_E
                ws_e = energy.ws_E

                lin_mag = ballEnergy2Mag(lin_e, self.v, tau)
                print("Magnitude {:.2f}".format(lin_mag))
                self.light_curve.ax.scatter(
                    h / 1000, lin_mag, label="Ballistic Measurement - Linear")

                ws_mag = ballEnergy2Mag(ws_e, self.v, tau)
                print("Magnitude {:.2f}".format(ws_mag))
                self.light_curve.ax.scatter(
                    h / 1000,
                    ws_mag,
                    label="Ballistic Measurement - Weak Shock")

            ### FRAGMENTATION
            elif energy.source_type.lower() == "fragmentation":

                chem_pres_yield = energy.chem_pres

                energy_area = fragEnergy2Mag(chem_pres_yield * tau * self.v, h)

                if energy_area is not None:
                    self.light_curve.ax.fill_between(
                        energy_area[:, 1],
                        energy_area[:, 0],
                        color="w",
                        alpha=0.3,
                        label="Fragmentation: {:.1f} km".format(h / 1000))

                h_min = float(self.source_widget[ee].min_h_edits.text())
                h_max = float(self.source_widget[ee].max_h_edits.text())

                self.light_curve.ax.axvline(x=h_min,
                                            color='w',
                                            linestyle='--',
                                            alpha=0.3)
                self.light_curve.ax.axvline(x=h_max,
                                            color='w',
                                            linestyle='--',
                                            alpha=0.3)

            self.lum_curve.ax.scatter(tau * 100,
                                      h / 1000,
                                      label=energy.source_type)
            self.lum_curve.ax.set_xlabel("Luminous Efficiency [%]")
            self.lum_curve.ax.set_ylabel("Height [km]")
            self.lum_curve.ax.legend()
            self.lum_curve.show()

        self.light_curve.ax.set_xlabel("Height [km]")
        self.light_curve.ax.set_ylabel("Magnitude")
        self.light_curve.ax.legend()
        self.light_curve.show()
示例#4
0
class FragmentationStaff(QWidget):
    """ A visualization of arrival times vs. heights. Makes it easier to see what height along
    the trajectory (fragmentation or ballistic) the arrival is from.
    """
    def __init__(self, setup, pack, bam):

        QWidget.__init__(self)

        self.buildGUI()

        # Take important values from main window class
        stn, self.current_station, self.pick_list, self.channel = pack

        # The first pick created is the one all times are based around (reference time)
        nom_pick = self.pick_list[0]

        main_layout = QGridLayout()

        # Pass setup value
        self.setup = setup
        self.bam = bam

        ###########
        # Build GUI
        ###########

        self.height_plot = MatplotlibPyQT()

        main_layout.addWidget(self.height_plot, 1, 1, 1, 100)

        export_button = QPushButton('Export')
        main_layout.addWidget(export_button, 3, 1, 1, 25)
        export_button.clicked.connect(self.export)

        X = []
        Y = []
        Y_M = []

        # Create a local coordinate system with the bottom of the trajectory as the reference
        A = self.setup.trajectory.pos_i
        B = self.setup.trajectory.pos_f

        A.pos_loc(B)
        B.pos_loc(B)

        # All points are placed here
        self.dots_x = []
        self.dots_y = []

        #########################
        # Light Curve Plot
        #########################
        if len(self.setup.light_curve_file) > 0 and hasattr(
                self.setup, "light_curve_file"):
            self.height_plot.ax1 = self.height_plot.figure.add_subplot(211)
            self.height_plot.ax2 = self.height_plot.figure.add_subplot(
                212, sharex=self.height_plot.ax1)
            # lc_plot = MatplotlibPyQT()
            # main_layout.addWidget(lc_plot, 1, 1, 1, 100)
            # self.light_curve_view = pg.GraphicsLayoutWidget()
            # self.light_curve_canvas = self.light_curve_view.addPlot()

            light_curve = readLightCurve(self.setup.light_curve_file)

            light_curve_list = processLightCurve(light_curve)

            for L in light_curve_list:
                self.height_plot.ax1.scatter(L.h, L.I, label=L.station)
                # light_curve_curve = pg.ScatterPlotItem(x=L.M, y=L.t)
                # self.light_curve_canvas.addItem(light_curve_curve)

            self.height_plot.ax1.legend()
            # plt.gca().invert_yaxis()

            # main_layout.addWidget(self.light_curve_view, 1, 101, 1, 10)

            # blank_spacer = QWidget()
            # main_layout.addWidget(blank_spacer, 2, 101, 2, 10)

            self.height_plot.ax1.set_xlim((-10, 100))
            self.height_plot.ax1.set_xlabel("Height [km]")
            self.height_plot.ax1.set_ylabel("Intensity")
        else:
            self.height_plot.ax2 = self.height_plot.figure.add_subplot(111)

        #########################
        # Station Plot
        #########################

        try:

            # Bandpass the waveform from 2 - 8 Hz (not optimal, but does a good job
            # in showing arrivals clearly for most cases)
            st, resp, gap_times = procStream(
                stn, ref_time=self.setup.fireball_datetime)
            st = findChn(st, self.channel)
            waveform_data, time_data = procTrace(st, ref_datetime=self.setup.fireball_datetime,\
                    resp=resp, bandpass=[2, 8])

            # Scale the data so that the maximum is brought out to SCALE_LEN

            SCALE_LEN = 10  # km
            max_val = 0
            for ii in range(len(waveform_data)):
                wave = waveform_data[ii]
                for point in wave:
                    if abs(point) > max_val:
                        max_val = abs(point)

            scaling = SCALE_LEN / max_val

            # Plot all waveform data segments (for gaps in data)
            for ii in range(len(waveform_data)):
                self.height_plot.ax2.plot(waveform_data[ii] * scaling,
                                          time_data[ii] - nom_pick.time)

        except ValueError:
            print("Could not filter waveform!")

        #########################
        # Light Curve Plot
        #########################

        if len(self.setup.light_curve_file) > 0 or not hasattr(
                self.setup, "light_curve_file"):

            light_curve = readLightCurve(self.setup.light_curve_file)

            light_curve_list = processLightCurve(light_curve)

            for L in light_curve_list:
                self.height_plot.ax1.scatter(L.h, L.I, label=L.station)
                # light_curve_curve = pg.ScatterPlotItem(x=L.M, y=L.t)
                # self.light_curve_canvas.addItem(light_curve_curve)

            self.height_plot.ax1.legend()

            # plt.gca().invert_yaxis()

            # main_layout.addWidget(self.light_curve_view, 1, 101, 1, 10)

            # blank_spacer = QWidget()
            # main_layout.addWidget(blank_spacer, 2, 101, 2, 10)

        ########################
        # Generate Hyperbola
        ########################

        # D_0 = A

        # stn.metadata.position.pos_loc(B)

        # theta = self.setup.trajectory.zenith.rad
        # h_0 = A.z

        # h = np.arange(0, 100000)
        # v = self.setup.trajectory.v
        # k = stn.metadata.position - D_0
        # n = Position(0, 0, 0)
        # n.x, n.y, n.z = self.setup.trajectory.vector.x, self.setup.trajectory.vector.y, self.setup.trajectory.vector.z
        # n.pos_geo(B)
        # c = 350

        # T = (h - h_0)/(-v*np.cos(theta)) + (k - n*((h - h_0)/(-np.cos(theta)))).mag()/c - nom_pick.time

        # # estimate_plot = pg.PlotDataItem(x=h, y=T)
        # # self.height_canvas.addItem(estimate_plot, update=True)

        # self.height_plot.ax2.scatter(h/1000, T)

        #######################
        # Plot nominal points
        #######################
        # base_points = pg.ScatterPlotItem()
        angle_off = []
        no_wind_points = []
        u = np.array([
            self.setup.trajectory.vector.x, self.setup.trajectory.vector.y,
            self.setup.trajectory.vector.z
        ])
        for i in range(len(self.setup.fragmentation_point)):

            f_time = stn.times.fragmentation[i][0][0]

            X = self.setup.fragmentation_point[i].position.elev
            Y = f_time - nom_pick.time

            self.dots_x.append(X)
            self.dots_y.append(Y)

            travel_dis = self.setup.fragmentation_point[
                i].position.pos_distance(stn.metadata.position)

            travel_time = travel_dis / 330

            no_wind_points.append([self.setup.fragmentation_point[i].position.elev, \
                         travel_time - nom_pick.time + self.setup.fragmentation_point[i].time])

            az = stn.times.fragmentation[i][0][1]
            tf = stn.times.fragmentation[i][0][2]

            az = np.radians(az)
            tf = np.radians(180 - tf)
            v = np.array([
                np.sin(az) * np.sin(tf),
                np.cos(az) * np.sin(tf), -np.cos(tf)
            ])

            angle_off.append(
                np.degrees(
                    np.arccos(
                        np.dot(u / np.sqrt(u.dot(u)), v / np.sqrt(v.dot(v))))))

            print(
                "Points", X, Y,
                np.degrees(
                    np.arccos(
                        np.dot(u / np.sqrt(u.dot(u)), v / np.sqrt(v.dot(v))))))

            # base_points.addPoints(x=[X], y=[Y], pen=(255, 0, 238), brush=(255, 0, 238), symbol='o')

        ptb_colors = [(0, 255, 26, 150), (3, 252, 176, 150), (252, 3, 3, 150),
                      (176, 252, 3, 150), (255, 133, 3, 150),
                      (149, 0, 255, 150), (76, 128, 4, 150), (82, 27, 27, 150),
                      (101, 128, 125, 150), (5, 176, 249, 150)]

        # base_points.setZValue(1)
        # self.height_canvas.addItem(base_points, update=True)

        #########################
        # Plot Precursor Points
        #########################

        # pre_points = pg.ScatterPlotItem()

        for i in range(len(self.setup.fragmentation_point)):

            X = self.setup.fragmentation_point[i].position.elev

            # Distance between frag point and the ground below it
            v_dist = X - stn.metadata.position.elev

            # Horizontal distance betweent the new ground point and the stn
            h_dist = self.setup.fragmentation_point[
                i].position.ground_distance(stn.metadata.position)

            # Speed of wave in air
            v_time = v_dist / 310

            # Speed of wave in ground
            h_time = h_dist / 3100

            # Total travel time
            Y = v_time + h_time - nom_pick.time

            self.height_plot.ax2.scatter(np.array(X) / 1000,
                                         np.array(Y),
                                         c='y')
            # pre_points.addPoints(x=[X], y=[Y], pen=(210, 235, 52), brush=(210, 235, 52), symbol='o')

        # self.height_canvas.addItem(pre_points, update=True)

        #########################
        # Perturbation points
        #########################
        # prt_points = pg.ScatterPlotItem()
        for i in range(len(self.setup.fragmentation_point)):
            data, remove = self.obtainPerts(stn.times.fragmentation, i)
            azdata, remove = self.obtainPerts(stn.times.fragmentation, i, pt=1)
            tfdata, remove = self.obtainPerts(stn.times.fragmentation, i, pt=2)
            Y = []
            X = self.setup.fragmentation_point[i].position.elev
            for pt, az, tf in zip(data, azdata, tfdata):

                Y = (pt - nom_pick.time)

                self.dots_x.append(X)
                self.dots_y.append(Y)

                az = np.radians(az)
                tf = np.radians(180 - tf)
                v = np.array([
                    np.sin(az) * np.sin(tf),
                    np.cos(az) * np.sin(tf), -np.cos(tf)
                ])

                angle_off.append(
                    np.degrees(
                        np.arccos(
                            np.dot(u / np.sqrt(u.dot(u)),
                                   v / np.sqrt(v.dot(v))))))

        colour_angle = abs(90 - np.array(angle_off))
        sc = self.height_plot.ax2.scatter(np.array(self.dots_x) / 1000,
                                          np.array(self.dots_y),
                                          c=colour_angle,
                                          cmap='viridis_r')
        cbar = self.height_plot.figure.colorbar(sc,
                                                orientation="horizontal",
                                                pad=0.2)
        cbar.ax.set_xlabel("Difference from 90 deg [deg]")
        no_wind_points = np.array(no_wind_points)
        self.height_plot.ax2.scatter(no_wind_points[:, 0] / 1000,
                                     no_wind_points[:, 1],
                                     c='w')

        for pick in self.pick_list:
            if pick.group == 0:
                self.height_plot.ax2.axhline(pick.time - nom_pick.time, c='g')
                # self.height_canvas.addItem(pg.InfiniteLine(pos=(0, pick.time - nom_pick.time), angle=0, pen=QColor(0, 255, 0)))
            else:
                self.height_plot.ax2.axhline(pick.time - nom_pick.time, c='b')
                # self.height_canvas.addItem(pg.InfiniteLine(pos=(0, pick.time - nom_pick.time), angle=0, pen=QColor(0, 0, 255)))

        self.dots = np.array([self.dots_x, self.dots_y])

        #####################
        # Angle Calculation
        #####################

        u = np.array([
            self.setup.trajectory.vector.x, self.setup.trajectory.vector.y,
            self.setup.trajectory.vector.z
        ])

        angle_off = []
        X = []
        for i in range(len(self.setup.fragmentation_point)):
            az = stn.times.fragmentation[i][0][1]
            tf = stn.times.fragmentation[i][0][2]

            az = np.radians(az)
            tf = np.radians(180 - tf)
            v = np.array([
                np.sin(az) * np.sin(tf),
                np.cos(az) * np.sin(tf), -np.cos(tf)
            ])

            angle_off.append(
                np.degrees(
                    np.arccos(
                        np.dot(u / np.sqrt(u.dot(u)), v / np.sqrt(v.dot(v))))))
            X.append(self.setup.fragmentation_point[i].position.elev)
        angle_off = np.array(angle_off)

        ###############################
        # Find optimal ballistic angle
        ###############################
        try:
            best_indx = np.nanargmin(abs(angle_off - 90))
            print(
                "Optimal Ballistic Height {:.2f} km with angle of {:.2f} deg".
                format(X[best_indx] / 1000, angle_off[best_indx]))

            self.height_plot.ax2.axvline(X[best_indx] / 1000, c='b')
            # self.angle_canvas.addItem(pg.InfiniteLine(pos=(X[best_indx], 0), angle=90, pen=QColor(0, 0, 255)))
            # self.height_canvas.addItem(pg.InfiniteLine(pos=(X[best_indx], 0), angle=90, pen=QColor(0, 0, 255)))
            # self.angle_canvas.scatterPlot(x=X, y=angle_off, pen=(255, 255, 255), symbol='o', brush=(255, 255, 255))

            best_arr = []
            angle_arr = []

        except ValueError:
            best_indx = None

        angle_off = 0
        height = None
        for i in range(len(self.setup.fragmentation_point)):
            for j in range(len(stn.times.fragmentation[i][1])):
                az = stn.times.fragmentation[i][1][j][1]
                tf = stn.times.fragmentation[i][1][j][2]
                az = np.radians(az)
                tf = np.radians(180 - tf)
                v = np.array([
                    np.sin(az) * np.sin(tf),
                    np.cos(az) * np.sin(tf), -np.cos(tf)
                ])

                angle_off_new = np.degrees(
                    np.arccos(
                        np.dot(u / np.sqrt(u.dot(u)), v / np.sqrt(v.dot(v)))))

                # self.angle_canvas.scatterPlot(x=[self.setup.fragmentation_point[i].position.elev], y=[angle_off_new], symbol='o')

                if abs(angle_off_new - 90) < abs(
                        angle_off - 90) and not np.isnan(angle_off_new):
                    angle_off = angle_off_new

                    height = self.setup.fragmentation_point[i].position.elev

        if height is not None:
            # self.angle_canvas.addItem(pg.InfiniteLine(pos=(height, 0), angle=90, pen=QColor(0, 0, 255)))
            self.height_plot.ax2.axvline(height / 1000, c='b')
            # self.height_canvas.addItem(pg.InfiniteLine(pos=(height, 0), angle=90, pen=QColor(0, 0, 255)))

        # self.angle_canvas.addItem(pg.InfiniteLine(pos=(0, 90), angle=0, pen=QColor(255, 0, 0)))

        ########################
        # Fit Hyperbola
        ########################

        try:
            x_vals = []
            y_vals = []

            for pp, point in enumerate(self.dots_y):
                if not np.isnan(point):
                    y_vals.append(point)
                    x_vals.append(self.dots_x[pp])

            x_vals, y_vals = np.array(x_vals), np.array(y_vals)

            def hypfunc(x, a, b, h, k):
                return b * np.sqrt(1 + ((x - h) / a)**2) + k

            def invhypfunc(x, a, b, h, k):
                return a * np.sqrt(((x - k) / b)**2 - 1) + h

            # Hyperbola in the form y = kx (since we invert the x values)
            from scipy.optimize import curve_fit

            popt, pcov = curve_fit(hypfunc, x_vals, y_vals)

            h = np.linspace(0, 100000, 1000)

            self.height_plot.ax2.plot(h / 1000, hypfunc(h, *popt), c='m')
        except TypeError:
            print("Could not generate hyperbola fit!")
        except RuntimeError:
            print("Optimal hyperbola not found!")
        except ValueError:
            print("No Arrivals!")

        ########################
        # Overpressure vs Time
        ########################
        try:
            st, resp, gap_times = procStream(
                stn, ref_time=self.setup.fireball_datetime)
            st = findChn(st, self.channel)
            waveform_data, time_data = procTrace(st, ref_datetime=self.setup.fireball_datetime,\
                    resp=resp, bandpass=[2, 8])

        except ValueError:
            print("Could not filter waveform!")

        # waveform_peaks = []
        # time_peaks = []
        # N = 10

        # for wave, time_pack in zip(waveform_data, time_data):
        #     for ii in range(0, len(wave), N):
        #         peak = np.nanmax(np.abs(wave[ii:ii+N]))

        #         time_peak = np.mean(time_pack[ii:ii+N] - nom_pick.time)

        #         waveform_peaks.append(peak)
        #         time_peaks.append(time_peak)

        # ##################
        # # Relate Time to Height
        # ##################

        # height_peaks = invhypfunc(time_peaks, *popt)

        # ##################
        # # Get Bounds for Heights
        # ##################
        # min_height = 17000
        # max_height = 40000
        # h_indicies = np.where(np.logical_and(height_peaks>=min_height, height_peaks<=max_height))

        # new_heights = []
        # new_press = []
        # for hh in h_indicies[0]:

        #     new_heights.append(height_peaks[hh])
        #     new_press.append(waveform_peaks[hh])

        # plt.subplot(3, 1, 1)

        # plt.plot(new_heights, new_press)
        # plt.xlabel("Height [m]")
        # plt.ylabel("Overpressure [Pa]")

        # spacer = 5*60
        # shifter = 10
        # periods = []
        # period_times = []

        # for wave, time_pack in zip(waveform_data, time_data):
        #     for ii in range(len(wave)//shifter):

        #         temp_waveform = wave[shifter*ii:shifter*ii+spacer]
        #         temp_time = time_pack[shifter*ii:shifter*ii+spacer]

        #         period = findDominantPeriodPSD(temp_waveform, st.stats.sampling_rate)

        #         periods.append(period)
        #         period_times.append(time_pack[shifter*ii])

        # plt.subplot(3, 1, 2)
        # height_periods = invhypfunc(period_times, *popt)
        # h_indicies = np.where(np.logical_and(height_periods>=min_height, height_periods<=max_height))

        # new_heights = []
        # new_period = []
        # for hh in h_indicies[0]:

        #     new_heights.append(height_periods[hh])
        #     new_period.append(periods[hh])

        # plt.plot(new_heights, new_period)
        # plt.xlabel("Height [m]")
        # plt.ylabel("Dominant Period [s]")

        # plt.subplot(3, 1, 3)
        # ws_list = []
        # lin_list = []
        # h_list = []
        # N_times=1
        # ws_list_p = []
        # lin_list_p = []
        # h_list_p = []
        # for ii in range(len(new_heights)//N_times):
        #     ws, lin = self.periodSearch(new_heights[ii*N_times], stn, new_period[ii*N_times])
        #     print("Phase 1: {:}/{:}".format(ii+1, len(new_heights)//N_times))
        #     ws_list_p.append(ws)
        #     lin_list_p.append(lin)
        #     h_list_p.append(new_heights[ii*N_times])
        # plt.scatter(h_list_p, ws_list_p, label='Period Weak-Shock', alpha=0.4)
        # plt.scatter(h_list_p,lin_list_p, label='Period Linear', alpha=0.4)
        # # for ii in range(len(new_heights)//N_times):
        # #     ws, lin = self.presSearch(new_heights[ii*N_times], stn, new_press[ii*N_times])
        # #     print("Phase 2: {:}/{:}".format(ii+1, len(new_heights)//N_times))
        # #     ws_list.append(ws)
        # #     lin_list.append(lin)
        # #     h_list.append(new_heights[ii*N_times])
        # # plt.scatter(h_list, ws_list, label='Pressure Weak-Shock', alpha=0.4)
        # # plt.scatter(h_list,lin_list, label='Pressure Linear', alpha=0.4)

        # plt.xlabel("Height [m]")
        # plt.ylabel("Relaxation Radius [m]")
        # plt.legend()
        # plt.show()

        # fit_hyper = pg.PlotDataItem(x=h, y=hypfunc(h, *popt))
        # self.height_canvas.addItem(fit_hyper, update=True, pen='m')

        #####################
        # Build plot window
        #####################

        # 25 deg tolerance window
        # phigh = pg.PlotCurveItem([np.nanmin(X), np.nanmax(X)], [65, 65], pen = 'g')
        # plow = pg.PlotCurveItem([np.nanmin(X), np.nanmax(X)], [115, 115], pen = 'g')
        # pfill = pg.FillBetweenItem(phigh, plow, brush = (0, 0, 255, 100))
        # self.angle_canvas.addItem(phigh)
        # self.angle_canvas.addItem(plow)
        # self.angle_canvas.addItem(pfill)

        # Build axes
        # self.height_canvas.setTitle('Fragmentation Height Prediction of Given Pick', color=(0, 0, 0))
        # self.angle_canvas.setTitle('Angles of Initial Acoustic Wave Path', color=(0, 0, 0))
        # self.height_canvas.setLabel('left', 'Difference in Time from {:.2f}'.format(nom_pick.time), units='s')
        # self.angle_canvas.setLabel('left', 'Angle Away from Trajectory of Initial Acoustic Wave Path', units='deg')
        # self.height_canvas.setLabel('bottom', 'Height of Solution', units='m')
        # self.angle_canvas.setLabel('bottom', 'Height of Solution', units='m')
        tol = 5  #seconds

        self.height_plot.ax2.set_xlim((-10, 100))
        self.height_plot.ax2.set_ylim(
            (np.min([0, np.nanmin(self.dots_y)]) - tol,
             tol + np.max([0, np.nanmax(self.dots_y)])))
        # print(len(X), len(Y))
        # self.height_plot.ax2.scatter(np.array(X)/1000, np.array(Y), c='m')

        self.height_plot.ax2.set_xlabel("Height [km]")
        self.height_plot.ax2.set_ylabel("Time after {:.2f} [s]".format(
            nom_pick.time))

        # self.height_plot.ax1.show()
        # self.height_plot.ax2.show()
        self.height_plot.figure.tight_layout()
        self.height_plot.figure.subplots_adjust(hspace=0)
        self.height_plot.show()
        # self.height_canvas.setLimits(xMin=B.elev, xMax=A.elev, yMin=-40, yMax=100, minXRange=1000, maxXRange=33000, minYRange=2, maxYRange=140)

        # Fonts
        # font= QFont()
        # font.setPixelSize(20)
        # self.height_canvas.getAxis("bottom").tickFont = font
        # self.height_canvas.getAxis("left").tickFont = font
        # self.height_canvas.getAxis('bottom').setPen((255, 255, 255))
        # self.height_canvas.getAxis('left').setPen((255, 255, 255))
        # self.angle_canvas.getAxis("bottom").tickFont = font
        # self.angle_canvas.getAxis("left").tickFont = font
        # self.angle_canvas.getAxis('bottom').setPen((255, 255, 255))
        # self.angle_canvas.getAxis('left').setPen((255, 255, 255))
        self.setLayout(main_layout)

    def buildGUI(self):
        """Prepare the window chrome: title, 16x16 icon, black background, theme."""
        self.setWindowTitle('Fragmentation Staff')

        icon_file = os.path.join('supra', 'GUI', 'Images', 'BAM_no_wave.png')
        window_icon = QIcon()
        window_icon.addFile(icon_file, QSize(16, 16))
        self.setWindowIcon(window_icon)

        # Paint the widget's background role black before the theme is applied.
        palette = self.palette()
        palette.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(palette)

        theme(self)

    def export(self):
        """Export ``self.plot_view``'s scene to an SVG file chosen by the user.

        Does nothing if the save dialog is cancelled. ``.svg`` is appended
        unless the chosen name already ends with it (case-insensitive).
        """
        # getSaveFileName returns a (file_name, selected_filter) pair.
        file_name = QFileDialog.getSaveFileName(self, 'Save File')[0]

        # A cancelled dialog yields an empty name; previously this exported a
        # stray '.svg' file into the current working directory.
        if not file_name:
            return

        if not file_name.lower().endswith('.svg'):
            file_name = file_name + '.svg'

        exporter = pg.exporters.SVGExporter(self.plot_view.scene())
        exporter.export(file_name)

    def obtainPerts(self, data, frag, pt=0):
        """Gather field *pt* of every record under ``data[frag][1]`` and filter it.

        The gathered values (in record order) are passed to ``chauvenet`` and
        its two results are returned as a tuple — presumably (kept values,
        rejected values); confirm against ``chauvenet``.
        """
        records = data[frag][1]
        column = [records[idx][pt] for idx in range(len(records))]

        kept, removed = chauvenet(column)
        return kept, removed

    def linkSeis(self):
        """Set the seismogram canvas' y-axis link according to ``self.linked_seis``.

        A truthy ``linked_seis`` removes the link (``setYLink(None)``); otherwise
        the seismogram y-axis is linked to ``self.height_canvas``.
        NOTE(review): ``linked_seis`` itself is not updated here — confirm the
        caller toggles it.
        """
        link_target = None if self.linked_seis else self.height_canvas
        self.seis_canvas.setYLink(link_target)

    def _geminusInputs(self, h, stn):
        """Build the Geminus input list for a source at trajectory height *h*.

        Args:
            h: height along ``self.setup.trajectory`` passed to ``findGeo``
               (the source's ``elev`` is divided by 1000 below, so presumably
               metres).
            stn: station object exposing ``metadata.position`` (lat, lon, elev).

        Returns:
            ``[source_list, stat, v, theta, dphi, sounding_pres, sw, True, True]``
            where ``source_list``/``stat`` are ``[lat, lon, elev / 1000]``,
            ``v`` is ``traj.v / 1000``, ``theta`` is ``90 - zenith`` [deg] and
            ``dphi`` is the station bearing minus the trajectory azimuth [deg].
            The two trailing ``True`` flags sit where ``Geminus.overpressure``
            passes its wind / Doppler-shift switches.
        """
        traj = self.setup.trajectory

        source = traj.findGeo(h)

        # Preserved side effect of the original methods: this attribute is
        # reset on every call and never filled in here.
        self.sounding_pres = None
        source_list = [source.lat, source.lon, source.elev / 1000]

        stat_pos = stn.metadata.position
        stat = [stat_pos.lat, stat_pos.lon, stat_pos.elev / 1000]

        v = traj.v / 1000

        theta = 90 - traj.zenith.deg

        dphi = np.degrees(
            np.arctan2(stat_pos.lon - source.lon,
                       stat_pos.lat - source.lat)) - traj.azimuth.deg

        # Switch 3 doesn't do anything in this version of overpressure.py
        sw = [1, 0, 1]

        lat = [source.lat, stat_pos.lat]
        lon = [source.lon, stat_pos.lon]
        elev = [source.elev, stat_pos.elev]

        # NOTE(review): unlike Geminus.overpressure, no ref_time is passed
        # here — confirm getSounding's default is the intended epoch.
        sounding, _ = self.bam.atmos.getSounding(lat=lat,
                                                 lon=lon,
                                                 heights=elev)

        # Extra column: exponential pressure model of column 0 (presumably
        # height in m; 1013.25 looks like sea-level pressure in hPa).
        pres = 10 * 101.325 * np.exp(-0.00012 * sounding[:, 0])

        sounding_pres = np.zeros((sounding.shape[0], sounding.shape[1] + 1))
        sounding_pres[:, :-1] = sounding
        sounding_pres[:, -1] = pres
        # Column 1 shifted by -273.15: presumably temperature K -> deg C.
        sounding_pres[:, 1] -= 273.15

        # Reverse the row order of the sounding.
        sounding_pres = np.flip(sounding_pres, axis=0)

        return [
            source_list, stat, v, theta, dphi, sounding_pres, sw, True, True
        ]

    def presSearch(self, h, stn, op):
        """Search for the blast radius matching overpressure *op* at height *h*.

        Builds the Geminus inputs for station *stn* and delegates to the
        module-level ``presSearch`` function (not this method — bare names
        inside a method resolve to module globals).

        Returns:
            ``(Ro_ws, Ro_lin)`` as returned by the module-level search,
            presumably the weak-shock and linear blast radii.
        """
        gem_inputs = self._geminusInputs(h, stn)
        Ro_ws, Ro_lin = presSearch(op, gem_inputs, paths=False)

        return Ro_ws, Ro_lin

    def periodSearch(self, h, stn, op):
        """Search for the blast radius matching dominant period *op* at height *h*.

        Builds the Geminus inputs for station *stn* and delegates to the
        module-level ``periodSearch`` function (not this method).

        Returns:
            ``(Ro_ws, Ro_lin)`` as returned by the module-level search,
            presumably the weak-shock and linear blast radii.
        """
        gem_inputs = self._geminusInputs(h, stn)
        Ro_ws, Ro_lin = periodSearch(op, gem_inputs, paths=False)

        return Ro_ws, Ro_lin
示例#5
0
class ArrayStacker(QMainWindow):
    """Window for reading, time-shifting and stacking per-element mSeed files.

    Each of the ``N_ELEMENTS`` rows holds an mSeed file path and a time shift
    in seconds. "Read" plots every selected element, "Stack" stacks the
    shifted/trimmed streams with ObsPy's ``Stream.stack`` and "Raw Stack"
    sums the unshifted, trimmed streams. Stacked results are written to a
    file chosen through ``saveFile``.
    """

    def __init__(self):
        super().__init__()

        self.setWindowTitle('mSeed Reader (not working)')
        app_icon = QtGui.QIcon()
        app_icon.addFile(os.path.join('images', 'bam.png'),
                         QtCore.QSize(16, 16))
        self.setWindowIcon(app_icon)

        theme(self)

        self.buildGUI()

    def buildGUI(self):
        """Create the per-element inputs, the two plot areas and the buttons."""
        self._main = QWidget()
        self.setCentralWidget(self._main)
        layout = QGridLayout(self._main)

        self.raw_traces = [None] * N_ELEMENTS

        self.stat_graph = MatplotlibPyQT()
        self.sum_graph = MatplotlibPyQT()
        self.mseed_browser_label = [None] * N_ELEMENTS
        self.mseed_browser_edits = [None] * N_ELEMENTS
        self.mseed_browser_buton = [None] * N_ELEMENTS
        self.stat_shifter_label = [None] * N_ELEMENTS
        self.stat_shifter_edits = [None] * N_ELEMENTS

        # One subplot per element, all sharing the first subplot's x-axis.
        self.stat_graph.ax = []
        for i in range(N_ELEMENTS):
            if i == 0:
                self.stat_graph.ax.append(
                    self.stat_graph.figure.add_subplot(N_ELEMENTS, 1, i + 1))
            else:
                self.stat_graph.ax.append(
                    self.stat_graph.figure.add_subplot(
                        N_ELEMENTS, 1, i + 1, sharex=self.stat_graph.ax[0]))

        layout.addWidget(self.stat_graph, 0, 4, N_ELEMENTS * 4, 1)

        self.sum_graph.ax = self.sum_graph.figure.add_subplot(
            1, 1, 1, sharex=self.stat_graph.ax[0])
        layout.addWidget(self.sum_graph, N_ELEMENTS * 4 + 1, 4)

        for ii in range(N_ELEMENTS):
            self.mseed_browser_label[ii], self.mseed_browser_edits[
                ii], self.mseed_browser_buton[ii] = createFileSearchObj(
                    'mSeed File: {:}'.format(ii + 1),
                    layout,
                    2 * ii,
                    width=1,
                    h_shift=0)
            _, self.stat_shifter_edits[ii] = createLabelEditObj(
                'Shift [s]',
                layout,
                2 * ii + 1,
                width=1,
                h_shift=0,
                tool_tip='',
                validate='float',
                default_txt='0')

            self.mseed_browser_buton[ii].clicked.connect(
                partial(fileSearch, ['mSeed File (*.mseed)'],
                        self.mseed_browser_edits[ii]))

        self.stackRaw_button = createButton("Raw Stack", layout,
                                            2 * N_ELEMENTS + 2, 2,
                                            self.stackRaw)
        self.stack_button = createButton("Stack", layout, 2 * N_ELEMENTS + 2,
                                         1, self.stack)
        # Stored as read_button: assigning the button to self.read (as before)
        # replaced the read() method on the instance with the button object.
        self.read_button = createButton("Read", layout, 2 * N_ELEMENTS + 2, 0,
                                        self.read)

    def _selectedFiles(self):
        """Yield ``(index, path)`` for every element whose file box is non-empty."""
        for ii in range(N_ELEMENTS):
            path = self.mseed_browser_edits[ii].text()
            if path:
                yield ii, path

    def _traceBounds(self):
        """Return ``(starttime, endtime)`` of the first trace of each selected file."""
        bounds = []
        for _, path in self._selectedFiles():
            stats = obspy.read(path)[0].stats
            bounds.append((stats.starttime, stats.endtime))
        return bounds

    def resetGraphs(self):
        """Clear every per-element axis and the stack axis."""
        for ii in range(N_ELEMENTS):
            self.stat_graph.ax[ii].clear()
        self.sum_graph.ax.clear()

    def read(self):
        """Load, shift, band-pass (3-9.5 Hz) and plot every selected element.

        Times are referenced to the first loaded trace's start time, with
        each element's shift entry [s] applied. The first sample time of
        element index 3 is stored in ``self.startpoint`` and later used as the
        x-offset of the stacked plots.
        """
        self.resetGraphs()

        # Reference time comes from the first *loaded* element; previously it
        # was only set for element 0, raising NameError when that was empty.
        ref_time = None

        for ii, path in self._selectedFiles():
            self.raw_traces[ii] = obspy.read(path)

            tr = self.raw_traces[ii].copy().select(channel=CHN)[0]

            if ref_time is None:
                ref_time = tr.stats.starttime.datetime

            shift = float(self.stat_shifter_edits[ii].text())
            y, x = procTrace(tr,
                             ref_datetime=ref_time -
                             datetime.timedelta(seconds=shift),
                             resp=None,
                             bandpass=[3, 9.5],
                             backup=False)
            x = x[0]
            y = y[0]

            # NOTE(review): hard-coded element index; stack()/stackRaw() need
            # self.startpoint and fail if element 4 was never read.
            if ii == 3:
                self.startpoint = x[0]

            self.stat_graph.ax[ii].plot(x, y)

        self.stat_graph.show()

    def findCommonEnds(self):
        """Return ``(latest start, earliest end)`` over the selected files.

        This is the window covered by every file; ``(None, None)`` when no
        file is selected.
        """
        bounds = self._traceBounds()
        if not bounds:
            return None, None
        return (max(start for start, _ in bounds),
                min(end for _, end in bounds))

    def findTotalRange(self):
        """Return ``(earliest start, latest end)`` over the selected files.

        ``(None, None)`` when no file is selected.
        """
        bounds = self._traceBounds()
        if not bounds:
            return None, None
        return (min(start for start, _ in bounds),
                max(end for _, end in bounds))

    def stack(self):
        """Shift, trim and stack the selected elements, then plot and save.

        Each element's first trace is moved earlier by its shift entry [s]
        and trimmed to the files' common window. The first loaded stream keeps
        all its traces; later ones contribute their ``CHN`` trace. The result
        of ``Stream.stack(npts_tol=100)`` is band-passed (2-8 Hz), plotted
        offset by ``self.startpoint`` and written to a file from ``saveFile``.
        """
        # The common window depends only on the files on disk, so compute it
        # once instead of re-reading every file on each iteration.
        common_start, common_end = self.findCommonEnds()

        master_st = None
        for ii, path in self._selectedFiles():
            st = obspy.read(path)

            shift = float(self.stat_shifter_edits[ii].text())
            st[0].stats.starttime = st[0].stats.starttime - datetime.timedelta(
                seconds=shift)

            if master_st is None:
                master_st = st
                master_st[0].trim(starttime=common_start, endtime=common_end)
            else:
                tr = st.select(channel=CHN)[0]
                tr.trim(starttime=common_start, endtime=common_end)
                master_st.append(tr)

        if master_st is None:
            print("No mSeed files selected - nothing to stack!")
            return None

        stacked_st = master_st.stack(npts_tol=100)
        stacked_st[0].stats.starttime = common_start
        tr = stacked_st[0]

        y, x = procTrace(tr,
                         ref_datetime=tr.stats.starttime.datetime,
                         resp=None,
                         bandpass=[2, 8],
                         backup=False)
        x = x[0]
        y = y[0]

        self.sum_graph.ax.plot(x + self.startpoint, y)
        self.sum_graph.show()
        file_name = saveFile("mseed", note="")
        stacked_st.write(file_name)

    def stackRaw(self):
        """Trim the unshifted selected elements to their common window and sum them.

        The summed result is band-passed (2-8 Hz), plotted offset by
        ``self.startpoint`` and written to a file from ``saveFile``.
        """
        common_start, common_end = self.findCommonEnds()
        print(common_start, common_end)

        master_st = None
        for _, path in self._selectedFiles():
            st = obspy.read(path)

            if master_st is None:
                master_st = st
                master_st[0].trim(starttime=common_start, endtime=common_end)
            else:
                tr = st.select(channel=CHN)[0]
                tr.trim(starttime=common_start, endtime=common_end)
                master_st.append(tr)

        if master_st is None:
            print("No mSeed files selected - nothing to stack!")
            return None

        # NOTE(review): np.sum over a Stream is unlikely to yield an object
        # with ``[0].stats``/``.write`` (the window title marks this as "not
        # working"); a sample-wise sum of ``tr.data`` wrapped in a new
        # Stream was probably intended — confirm before relying on it.
        stacked_st = np.sum(master_st, axis=0)
        # NOTE(review): uses the last element's start time even if that box
        # is empty.
        stacked_st[0].stats.starttime = obspy.read(
            self.mseed_browser_edits[-1].text())[0].stats.starttime

        tr = stacked_st[0]

        y, x = procTrace(tr,
                         ref_datetime=tr.stats.starttime.datetime,
                         resp=None,
                         bandpass=[2, 8],
                         backup=False)
        x = x[0]
        y = y[0]

        self.sum_graph.ax.plot(x + self.startpoint, y)
        self.sum_graph.show()
        file_name = saveFile("mseed", note="")
        stacked_st.write(file_name)
示例#6
0
class Geminus(QWidget):
    def __init__(self, bam, prefs):
        """Create the Geminus window for project *bam* with preferences *prefs*."""
        QWidget.__init__(self)

        self.bam, self.prefs = bam, prefs

        # Ensure the project carries a list for infrasound-curve points.
        try:
            bam.infra_curve
        except AttributeError:
            bam.infra_curve = []

        self.buildGUI()

        self.current_Ro = None
        self.current_height = None

    def buildGUI(self):
        """Build the Geminus window.

        Left: a matplotlib plot (``self.overpressure_plot``). Right: a grid of
        inputs (source height, station, blast radius, dominant period,
        overpressure) and a control grid holding three toggles (checked by
        default) and the action buttons. Most buttons call
        ``self.overpressure`` with a mode string ("normal", "period", "pres",
        "pro", "proE"); the rest call ``infraCurve``, ``clearInfra`` and
        ``saveInfra``.
        """
        self.setWindowTitle('Geminus (Silber 2014)')

        app_icon = QIcon()
        app_icon.addFile(
            os.path.join('supra', 'GUI', 'Images', 'BAM_no_wave.png'),
            QSize(16, 16))
        self.setWindowIcon(app_icon)

        p = self.palette()
        p.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(p)

        # Top level: graph column on the left, input/output panels on the right.
        layout = QHBoxLayout()
        self.setLayout(layout)

        graph_layout = QVBoxLayout()
        layout.addLayout(graph_layout)

        right_panels = QVBoxLayout()
        layout.addLayout(right_panels)

        input_panels = QGridLayout()
        right_panels.addLayout(input_panels)

        # NOTE(review): output_panels is added to the layout but nothing is
        # placed in it within this method.
        output_panels = QGridLayout()
        right_panels.addLayout(output_panels)

        # Combo-box labels "NETWORK-CODE", in the same order as bam.stn_list
        # (other methods index stn_list with the combo's currentIndex()).
        stn_name_list = []
        for stn in self.bam.stn_list:
            stn_name_list.append("{:}-{:}".format(stn.metadata.network,
                                                  stn.metadata.code))

        # Control grid sits below the five input rows (row 6, spanning 3x3).
        control_panel = QGridLayout()
        input_panels.addLayout(control_panel, 6, 1, 3, 3)

        _, self.source_height = createLabelEditObj(
            'Source Height Along Trajectory [m]',
            input_panels,
            1,
            width=1,
            h_shift=0,
            tool_tip='')
        _, self.station_combo = createComboBoxObj('Station',
                                                  input_panels,
                                                  2,
                                                  items=stn_name_list,
                                                  width=1,
                                                  h_shift=0,
                                                  tool_tip='')

        _, self.blast_radius = createLabelEditObj('Blast Radius [m]',
                                                  input_panels,
                                                  3,
                                                  width=1,
                                                  h_shift=0,
                                                  tool_tip='')
        _, self.dom_period = createLabelEditObj('Dominant Period [s]',
                                                input_panels,
                                                4,
                                                width=1,
                                                h_shift=0,
                                                tool_tip='')
        _, self.over_pres = createLabelEditObj('Overpressure [Pa]',
                                               input_panels,
                                               5,
                                               width=1,
                                               h_shift=0,
                                               tool_tip='')
        # Toggles read by overpressure(): period variation switch, winds,
        # Doppler shift.
        self.vary_period = createToggle('Vary Period',
                                        control_panel,
                                        1,
                                        width=1,
                                        h_shift=1,
                                        tool_tip='')
        self.add_winds = createToggle('Include Winds',
                                      control_panel,
                                      2,
                                      width=1,
                                      h_shift=1,
                                      tool_tip='')
        self.doppler = createToggle('Doppler Shift',
                                    control_panel,
                                    3,
                                    width=1,
                                    h_shift=1,
                                    tool_tip='')
        # Each simulation button runs overpressure() in a different mode.
        self.overpressure_run = createButton("Run Blast Radius Simulation",
                                             control_panel,
                                             4,
                                             1,
                                             self.overpressure,
                                             args=["normal"])
        self.overpressure_period_finder = createButton("Run Period Search",
                                                       control_panel,
                                                       4,
                                                       2,
                                                       self.overpressure,
                                                       args=["period"])
        self.overpressure_pres_finder = createButton("Run Overpressure Search",
                                                     control_panel,
                                                     4,
                                                     3,
                                                     self.overpressure,
                                                     args=["pres"])
        self.pro_sim = createButton("Period vs. Blast Radius",
                                    control_panel,
                                    5,
                                    1,
                                    self.overpressure,
                                    args=["pro"])
        self.proE_sim = createButton("Period vs. Energy",
                                     control_panel,
                                     5,
                                     2,
                                     self.overpressure,
                                     args=["proE"])
        # NOTE: self.infra_curve is this button; the curve data itself lives
        # in self.bam.infra_curve.
        self.infra_curve = createButton("Infrasound Curve", control_panel, 5,
                                        3, self.infraCurve)
        self.clear_infra = createButton("Clear Curve", control_panel, 6, 1,
                                        self.clearInfra)
        self.save_energy = createButton("Save Energy", control_panel, 6, 2,
                                        self.saveInfra)

        self.vary_period.setChecked(True)
        self.add_winds.setChecked(True)
        self.doppler.setChecked(True)

        self.overpressure_plot = MatplotlibPyQT()
        self.overpressure_plot.ax = self.overpressure_plot.figure.add_subplot(
            111)
        graph_layout.addWidget(self.overpressure_plot)

        theme(self)

    def saveInfra(self):
        """Record the current search result as a ballistic ``EnergyObj``.

        Appends to ``self.bam.energy_measurements`` an ``EnergyObj`` holding:
        the selected station, the current height, the current linear and
        weak-shock Ro, and their ``Efunction`` energies. It also records the
        dominant-period and overpressure entries when they parse as floats
        (otherwise None).

        Shows an error dialog and returns None when no Ro search result is
        available yet.
        """
        # current_lin_Ro / current_ws_Ro are not created in __init__, so
        # saving before any search used to raise AttributeError. Guard the
        # same way infraCurve() does.
        if not hasattr(self, "current_lin_Ro") or not hasattr(
                self, "current_ws_Ro"):
            errorMessage(
                "No energy to save!",
                1,
                detail=
                'Please use one of the searches to find Ro for both weak-shock and linear!'
            )
            return None

        a = EnergyObj()
        a.source_type = "Ballistic"
        a.height = self.current_height
        stat_idx = self.station_combo.currentIndex()
        a.station = self.bam.stn_list[stat_idx]
        a.linear_Ro = self.current_lin_Ro
        a.ws_Ro = self.current_ws_Ro
        a.linear_E = Efunction(self.current_lin_Ro, self.current_height)
        a.ws_E = Efunction(self.current_ws_Ro, self.current_height)

        # Optional fields: blank or non-numeric text is stored as None.
        try:
            a.period = float(self.dom_period.text())
        except ValueError:
            a.period = None

        try:
            a.overpress = float(self.over_pres.text())
        except ValueError:
            a.overpress = None

        self.bam.energy_measurements.append(a)

    def clearInfra(self):
        """Discard every stored infrasound-curve point and report it."""
        status_tag = printMessage("status")
        self.bam.infra_curve = list()
        print(status_tag, "Cleared infrasound curve data!")

    def infraCurve(self):
        """Plot infrasound energy per unit length vs. height above the light curve.

        If the current search result (linear Ro, weak-shock Ro, height) is
        complete, it is first appended to ``self.bam.infra_curve``.

        The top panel shows ``Efunction`` of every stored point's linear and
        weak-shock Ro against its height. Heights are divided by 1000 for the
        "Height [km]" axis. The bottom panel shows the light curve(s) read
        from ``self.bam.setup.light_curve_file``, with the magnitude axis
        inverted.

        Returns None. Shows an error dialog and returns early if no search
        has been run or no points are stored.
        """
        if not hasattr(self, "current_lin_Ro"):
            if self.prefs.debug:
                print(
                    printMessage("warning"),
                    " Run a search for overpressure or period first before plotting a point!"
                )
            errorMessage(
                "No points to plot!",
                1,
                detail=
                'Please use one of the searches to find Ro for both weak-shock and linear!'
            )
            return None

        if (self.current_lin_Ro is not None and self.current_ws_Ro is not None
                and self.current_height is not None):
            self.bam.infra_curve.append(
                [self.current_lin_Ro, self.current_ws_Ro, self.current_height])

        if not self.bam.infra_curve:
            errorMessage("No points to plot!",
                         1,
                         detail='Please use one of the searches to find Ro!')
            return None

        ax1 = plt.subplot(2, 1, 1)
        heights = []
        E_lin = []
        E_ws = []

        for point in self.bam.infra_curve:
            h = point[2]
            heights.append(h)
            E_lin.append(Efunction(point[0], h))
            E_ws.append(Efunction(point[1], h))

        # Previously only the last point's height was plotted, as a scalar x
        # against all energies; that fails as soon as there is more than one
        # point.
        heights_km = np.array(heights) / 1000

        ax1.scatter(heights_km, E_lin, label='Linear')
        ax1.scatter(heights_km, E_ws, label="Weak Shock")

        ax1.set_xlabel("Height [km]")
        ax1.set_ylabel("Energy per Unit Length [J/m]")
        # plt.legend() below only targets the light-curve axes.
        ax1.legend()

        ax2 = plt.subplot(2, 1, 2, sharex=ax1)

        light_curve = readLightCurve(self.bam.setup.light_curve_file)

        light_curve_list = processLightCurve(light_curve)

        for L in light_curve_list:
            ax2.scatter(L.h, L.M, label=L.station)

        ax2.set_xlabel("Height [km]")
        ax2.set_ylabel("Absolute Magnitude")
        # Magnitudes: brighter is more negative, so flip the y-axis.
        plt.gca().invert_yaxis()
        plt.legend()

        plt.show()

    def overpressure(self, mode):
        """Run Geminus for the current source height / station and plot it.

        Parameters
        ----------
        mode : str
            ``"normal"`` -- one ``overpressureihmod_Ro`` run using the blast
            radius typed in the GUI; plots period vs. height and prints a
            summary.
            ``"period"`` -- ``periodSearch`` for the blast radius matching the
            dominant period typed in the GUI.
            ``"pres"`` -- ``presSearch`` for the blast radius matching the
            overpressure typed in the GUI.
            ``"pro"`` / ``"proE"`` -- sweep the blast radius over 0.01-100
            (100 samples) and plot the final period against the radius
            (``"pro"``) or against ``Efunction`` of it (``"proE"``).

        Returns
        -------
        None.  Returns early (after an error dialog) when a GUI field cannot
        be parsed or Geminus raises ``TypeError``.  Otherwise
        ``self.current_height`` is set to the source height used; the search
        modes also store ``current_lin_Ro`` / ``current_ws_Ro`` and every mode
        stores the prepared profile in ``self.sounding_pres``.
        """
        wind = self.add_winds.isChecked()
        dopplershift = self.doppler.isChecked()

        if self.prefs.debug:
            print(printMessage("debug"),
                  " Running Geminus on mode '{:}'".format(mode))

        ax = self.overpressure_plot.ax
        ax.clear()

        traj = self.bam.setup.trajectory

        try:
            source_height = float(self.source_height.text())
            source = traj.findGeo(source_height)
        except ValueError as e:
            if self.prefs.debug:
                print(printMessage("Error"), " No source height given!")
            errorMessage("Cannot read source height!",
                         2,
                         detail='{:}'.format(e))

            return None

        # Positions (and the speed below) are handed to Geminus divided by
        # 1000 -- presumably m -> km.
        source_list = [source.lat, source.lon, source.elev / 1000]

        stat_idx = self.station_combo.currentIndex()
        stat_pos = self.bam.stn_list[stat_idx].metadata.position
        stat = [stat_pos.lat, stat_pos.lon, stat_pos.elev / 1000]

        # The blast radius is optional: only "normal" mode uses it; the other
        # modes search for or sweep it themselves.
        try:
            Ro = float(self.blast_radius.text())
        except (TypeError, ValueError):
            Ro = None

        v = traj.v / 1000

        theta = 90 - traj.zenith.deg

        # Direction source -> station (arctan2 of d-lon over d-lat) relative to
        # the trajectory azimuth.
        dphi = np.degrees(
            np.arctan2(stat_pos.lon - source.lon,
                       stat_pos.lat - source.lat)) - traj.azimuth.deg

        # Switch 3 doesn't do anything in this version of overpressure.py
        sw = [self.vary_period.isChecked(), 0, 1]

        sounding, _ = self.bam.atmos.getSounding(
            lat=[source.lat, stat_pos.lat],
            lon=[source.lon, stat_pos.lon],
            heights=[source.elev, stat_pos.elev],
            ref_time=self.bam.setup.fireball_datetime)

        # Append a pressure column from an exponential model of column 0
        pres = 10 * 101.325 * np.exp(-0.00012 * sounding[:, 0])

        sounding_pres = np.zeros((sounding.shape[0], sounding.shape[1] + 1))
        sounding_pres[:, :-1] = sounding
        sounding_pres[:, -1] = pres
        # Column 1 is shifted by -273.15 (presumably K -> degC for Geminus)
        sounding_pres[:, 1] -= 273.15

        # Reverse the row order of the profile
        sounding_pres = np.flip(sounding_pres, axis=0)
        self.sounding_pres = sounding_pres

        gem_inputs = [
            source_list, stat, v, theta, dphi, sounding_pres, sw, wind,
            dopplershift
        ]

        if mode == "normal":
            try:
                tau, tauws, Z, sR, inc, talt, dpws, dp, it = \
                            overpressureihmod_Ro(source_list, stat, Ro, v, theta, dphi, sounding_pres, sw, wind=wind, dopplershift=dopplershift)
            except TypeError as e:
                errorMessage("Error in running Geminus!",
                             2,
                             detail='{:}'.format(e))
                return None

            # ``it`` marks the transition from the weak-shock to the stable
            # regime; the curves share the point at index it - 1.
            ax.plot(tau[0:it], Z[0:it], 'r-', label="Weak Shock Period Change")
            ax.plot(tau[it - 1:], Z[it - 1:], 'b-', label="Stable Period")
            ax.plot(tauws[it - 1:],
                    Z[it - 1:],
                    'm-',
                    label="Weak Shock: No Transition")

            ax.scatter([tau[it - 1]], [Z[it - 1]])

            ax.set_xlabel("Signal Period [s]")
            ax.set_ylabel("Geopotential Height [km]")
            ax.legend()
            self.overpressure_plot.show()

            print('Geminus Output')
            print('=========================================================')
            print('Period (weak shock):     {:3.4f} s'.format(tauws[-1]))
            print('  Frequency (weak shock):   {:3.3f} Hz'.format(1 /
                                                                  tauws[-1]))
            print('Period (linear):         {:3.4f} s'.format(tau[-1]))
            print('  Frequency (linear):       {:3.3f} Hz'.format(1 / tau[-1]))
            print('Slant range:             {:5.2f} km'.format(sR))
            print('Arrival (inclination):   {:3.4f} deg'.format(
                np.degrees(inc) % 360))
            print('Transition height:       {:3.3f} km'.format(talt))
            print('Overpressure (weak shock):     {:3.4f} Pa'.format(dpws[-1]))
            print('Overpressure (linear):         {:3.4f} Pa'.format(dp[-1]))

        elif mode in ("period", "pres"):

            # Both searches share the same inputs and outputs; only the
            # target field and the search routine differ.
            if mode == "period":
                target_edit, search, target_name = (self.dom_period,
                                                    periodSearch,
                                                    "dominant period")
            else:
                target_edit, search, target_name = (self.over_pres,
                                                    presSearch,
                                                    "overpressure")

            try:
                p = float(target_edit.text())
            except ValueError as e:
                errorMessage("Cannot read {:}!".format(target_name),
                             2,
                             detail='{:}'.format(e))
                return None

            Ro_ws, Ro_lin, weak_path, lin_path, tau, Z, it = search(
                p, gem_inputs, paths=True)

            ax.plot(weak_path, Z, 'r-', label="Weak Shock")
            ax.plot(lin_path, Z, 'b-', label="Linear")

            ax.scatter([tau[it - 1]], [Z[it - 1]])

            ax.set_xlabel("Signal Period [s]")
            ax.set_ylabel("Geopotential Height [km]")
            ax.legend()
            self.overpressure_plot.show()

            print('Geminus Output')
            print('=========================================================')
            print("Blast Radius (Weak-Shock): {:.2f} m".format(Ro_ws))
            print("Blast Radius (Linear): {:.2f} m".format(Ro_lin))

            self.current_lin_Ro = Ro_lin
            self.current_ws_Ro = Ro_ws

        elif mode in ("pro", "proE"):

            Ro_range = np.linspace(0.01, 100.00, 100)

            tau_list = []
            tau_ws_list = []

            for R in Ro_range:

                try:
                    tau, tauws, Z, sR, inc, talt, dpws, dp, it = \
                            overpressureihmod_Ro(source_list, stat, R, v, theta, dphi, sounding_pres, sw, wind=wind, dopplershift=dopplershift)
                except TypeError as e:
                    errorMessage("Error in running Geminus!",
                                 2,
                                 detail='{:}'.format(e))
                    return None

                # Only the period at the end of the path is kept
                tau_list.append(tau[-1])
                tau_ws_list.append(tauws[-1])

            if mode == "pro":
                x_vals = Ro_range
                x_label = "Blast Radius [m]"
            else:
                x_vals = Efunction(Ro_range, source.elev)
                x_label = "Energy per Unit Length [J/m]"

            ax.plot(x_vals, np.array(tau_list), 'b-', label="Linear")
            ax.plot(x_vals, np.array(tau_ws_list), 'r-', label="Weak Shock")

            ax.set_xlabel(x_label)
            ax.set_ylabel("Period [s]")

            ax.legend()
            self.overpressure_plot.show()

        self.current_height = source_height
示例#7
0
class rtvWindowDialog(QWidget):
    def __init__(self, bam, prefs):
        """Create the ray-trace viewer window.

        Parameters
        ----------
        bam : object
            Project state; ``bam.setup``, ``bam.stn_list`` and ``bam.atmos``
            are read by the window's methods.
        prefs : object
            Preferences; ``prefs.debug`` toggles extra console output.
        """
        QWidget.__init__(self)

        self.bam, self.prefs = bam, prefs

        # Viewer state.  buildGUI does not read any of these, so they can be
        # set up front.
        self.current_eigen = None       # last traced ray (used by exportRay)
        self.current_loaded_rays = []   # [positions, times] per loaded ray
        self.plot_ba_data = []

        self.buildGUI()

    def buildGUI(self):
        """Create the window's widgets, lay them out and wire up the buttons.

        Grid layout:
          * column 1: the 3-D ray-trace view ``rtv_graph`` (rows 1-15) above
            the height-vs-time plot ``hvt_graph`` (rows 16-30);
          * columns 2-4: input fields, toggles and action buttons (rows 1-23);
          * column 5: ``pol_graph`` with two stacked axes ``ax1``/``ax2``
            (``drawBeam`` marks the beam azimuth on ``ax1``).
        """
        self.setWindowTitle('Ray-Trace Viewer')
        app_icon = QtGui.QIcon()
        app_icon.addFile(
            os.path.join('supra', 'GUI', 'Images', 'BAM_no_wave.png'),
            QtCore.QSize(16, 16))
        self.setWindowIcon(app_icon)

        p = self.palette()
        p.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(p)

        theme(self)

        layout = QGridLayout()
        self.setLayout(layout)

        # --- Plots -----------------------------------------------------------
        self.rtv_graph = MatplotlibPyQT()

        self.rtv_graph.ax = self.rtv_graph.figure.add_subplot(111,
                                                              projection='3d')

        layout.addWidget(self.rtv_graph, 1, 1, 15, 1)

        self.hvt_graph = MatplotlibPyQT()
        self.hvt_graph.ax = self.hvt_graph.figure.add_subplot(111)
        layout.addWidget(self.hvt_graph, 16, 1, 15, 1)

        # --- Source / station selection --------------------------------------
        stn_name_list = []
        for stn in self.bam.stn_list:
            stn_name_list.append("{:}-{:}".format(stn.metadata.network,
                                                  stn.metadata.code))

        _, self.source_height = createLabelEditObj(
            'Source Height Along Trajectory [m]',
            layout,
            1,
            width=1,
            h_shift=1,
            tool_tip='',
            validate='float')
        _, self.station_combo = createComboBoxObj('Station',
                                                  layout,
                                                  2,
                                                  items=stn_name_list,
                                                  width=1,
                                                  h_shift=1,
                                                  tool_tip='')
        self.trajmode = createToggle("Plot Trajectory?",
                                     layout,
                                     3,
                                     width=1,
                                     h_shift=2,
                                     tool_tip='')
        self.netmode = createToggle("Run Ray Net?",
                                    layout,
                                    9,
                                    width=1,
                                    h_shift=2,
                                    tool_tip='')

        # --- Ray tracing -----------------------------------------------------
        self.run_trace_button = createButton("Run", layout, 4, 3,
                                             self.runRayTrace)
        self.clear_trace_button = createButton("Clear", layout, 5, 3,
                                               self.clearRayTrace)
        # _, self.ray_frac = createLabelEditObj('Fraction of Rays to Show', layout, 5, width=1, h_shift=1, tool_tip='', validate='int', default_txt='50')

        _, self.horizontal_tol = createLabelEditObj('Horizontal Tolerance',
                                                    layout,
                                                    6,
                                                    width=1,
                                                    h_shift=1,
                                                    tool_tip='',
                                                    validate='float',
                                                    default_txt='330')
        _, self.vertical_tol = createLabelEditObj('Vertical Tolerance',
                                                  layout,
                                                  7,
                                                  width=1,
                                                  h_shift=1,
                                                  tool_tip='',
                                                  validate='float',
                                                  default_txt='3000')

        self.pertstog = createToggle("Use Perturbations",
                                     layout,
                                     8,
                                     width=1,
                                     h_shift=2,
                                     tool_tip='')
        # Optional explicit source position (used by drawSrc when both set)
        _, self.source_lat = createLabelEditObj('Source Latitude',
                                                layout,
                                                10,
                                                width=1,
                                                h_shift=1,
                                                tool_tip='',
                                                validate='float')
        _, self.source_lon = createLabelEditObj('Source Longitude',
                                                layout,
                                                11,
                                                width=1,
                                                h_shift=1,
                                                tool_tip='',
                                                validate='float')

        # --- Ray import / export ---------------------------------------------
        self.save_ray = createButton("Export Ray", layout, 12, 3,
                                     self.exportRay)

        self.load_ray_label, self.load_ray_edits, self.load_ray_buton = createFileSearchObj(
            'Load Ray Trace: ', layout, 13, width=1, h_shift=1)
        # Qt runs slots in connection order: the file dialog fills the edit
        # before plotLoadedRay reads it.
        self.load_ray_buton.clicked.connect(
            partial(fileSearch, ['DAT (*.dat)'], self.load_ray_edits))
        self.load_ray_buton.clicked.connect(self.plotLoadedRay)

        # --- Drawing helpers -------------------------------------------------
        self.draw_stat = createButton("Draw Station", layout, 14, 3,
                                      self.drawStat)
        self.draw_src = createButton("Draw Source", layout, 15, 3,
                                     self.drawSrc)
        self.draw_traj = createButton("Draw Trajectory", layout, 16, 3,
                                      self.drawTraj)

        _, self.draw_beam = createLabelEditObj('Beam Azimuth',
                                               layout,
                                               17,
                                               width=1,
                                               h_shift=1,
                                               tool_tip='',
                                               validate='float')
        self.draw_beam_button = createButton("Draw", layout, 17, 4,
                                             self.drawBeam)
        _, self.draw_exp_time = createLabelEditObj('Expected Arrival Time [s]',
                                                   layout,
                                                   18,
                                                   width=1,
                                                   h_shift=1,
                                                   tool_tip='',
                                                   validate='float')
        self.trace_rev_button = createButton("Trace Reverse", layout, 18, 4,
                                             self.traceRev)

        self.hvt_graph.ax.set_xlabel("Time after Source [s]")
        self.hvt_graph.ax.set_ylabel("Height [km]")

        # --- GLM data --------------------------------------------------------
        self.load_glm_label, self.load_glm_edits, self.load_glm_buton = createFileSearchObj(
            'Load GLM: ', layout, 19, width=1, h_shift=1)
        self.load_glm_buton.clicked.connect(
            partial(fileSearch, ['CSV (*.csv)'], self.load_glm_edits))
        self.load_glm_buton.clicked.connect(partial(self.procGLM, True))

        self.fireball_datetime_label, self.fireball_datetime_edits = createLabelDateEditObj(
            "GLM Initial Datetime", layout, 20, h_shift=1)
        self.glm2lc = createButton("GLM to Light Curve", layout, 21, 4,
                                   self.glm2LC)

        # --- Polarization / back-azimuth panel -------------------------------
        self.pol_graph = MatplotlibPyQT()
        self.pol_graph.ax1 = self.pol_graph.figure.add_subplot(211)
        self.pol_graph.ax2 = self.pol_graph.figure.add_subplot(212)
        layout.addWidget(self.pol_graph, 1, 5, 25, 1)

        self.load_baz_label, self.load_baz_edits, self.load_baz_buton = createFileSearchObj(
            'Load Backazimuth: ', layout, 22, width=1, h_shift=1)
        self.load_baz_buton.clicked.connect(
            partial(fileSearch, ['CSV (*.csv)'], self.load_baz_edits))
        self.load_baz_buton.clicked.connect(self.loadAngleCSV)
        self.height_unc_button = createButton("Height Uncertainty", layout, 23,
                                              4, self.heightUnc)

    def traceRev(self):
        """Trace rays backwards from the selected station along the beam azimuth.

        For 25 zenith angles evenly spaced between 1 and 89 deg,
        ``anglescanrev`` is started at the station with azimuth
        ``(beam azimuth + 180) % 360``.  Each returned path is converted from
        the local frame (origin at the trajectory point at the entered source
        height) back to lat/lon/elev, drawn in green on the 3-D view
        (x = lon, y = lat, z = elev), and its last point is printed.

        Returns None early, after an error dialog, when the source height or
        beam azimuth field cannot be parsed.
        """
        try:
            source = self.bam.setup.trajectory.findGeo(
                float(self.source_height.text()))
            az = float(self.draw_beam.text())
        except ValueError as e:
            errorMessage("Cannot read source height or beam azimuth!",
                         2,
                         detail='{:}'.format(e))
            return None

        stat_idx = self.station_combo.currentIndex()
        stat_pos = self.bam.stn_list[stat_idx].metadata.position

        sounding, _ = self.bam.atmos.getSounding(
            lat=[source.lat, stat_pos.lat],
            lon=[source.lon, stat_pos.lon],
            heights=[source.elev, stat_pos.elev],
            spline=100,
            ref_time=self.bam.setup.fireball_datetime)

        # Express both points in the local frame centred on the source.
        # NOTE(review): pos_loc presumably updates x/y/z in place (its return
        # value is not used) -- this also modifies the station's stored
        # position object.
        stat_pos.pos_loc(source)
        source.pos_loc(source)

        source.z = source.elev
        stat_pos.z = stat_pos.elev

        # 25 zenith angles between 1 and 89 deg
        ze_list = np.linspace(1, 89, 25)

        # TODO: to do this more right, calculate D with bad winds, and then use
        # D to find new winds and then recalc D.
        # Plus 180 for backazimuth
        back_az = (az + 180) % 360
        D = [
            anglescanrev(stat_pos.xyz,
                         back_az,
                         zenith,
                         sounding,
                         wind=True,
                         trace=True,
                         fix_phi=False) for zenith in ze_list
        ]

        for ze, trace in zip(ze_list, D):

            tr = []
            for x, y, z, _T in trace:
                A = Position(0, 0, 0)
                A.x, A.y, A.z = x, y, z
                A.pos_geo(source)
                tr.append([A.lat, A.lon, A.elev])
            tr = np.array(tr)

            # Nothing to draw for an empty path
            if tr.size == 0:
                continue

            # Rows are [lat, lon, elev]; the 3-D view plots lon, lat, elev
            self.rtv_graph.ax.plot(tr[:, 1], tr[:, 0], tr[:, 2], c="g")
            print(
                "Ray Trace at ze={:.2f} deg: {:.4f}N {:.4f}E {:.3f}km".format(
                    ze, tr[-1, 0], tr[-1, 1], tr[-1, 2] / 1000))

    def getSource(self):
        """Return the trajectory position at the height typed in the source-height field.

        Raises ValueError if the field does not hold a number.
        """
        trajectory = self.bam.setup.trajectory
        return trajectory.findGeo(float(self.source_height.text()))

    def getStat(self):
        """Return the position object of the station chosen in the combo box.

        The returned object is the station's own ``metadata.position``
        (not a copy).
        """
        chosen_station = self.bam.stn_list[self.station_combo.currentIndex()]
        return chosen_station.metadata.position

    def getATM(self, source, stat_pos, perturbations=None):
        """Fetch the atmospheric sounding between ``source`` and ``stat_pos``.

        Parameters
        ----------
        source, stat_pos : position objects
            End points; their ``lat``, ``lon`` and ``elev`` are used.
        perturbations : optional
            Forwarded unchanged to ``bam.atmos.getSounding``.

        Returns
        -------
        tuple
            ``(sounding, perturbations)`` as returned by ``getSounding``
            (spline=100, referenced to the fireball datetime).
        """
        end_points = (source, stat_pos)

        result = self.bam.atmos.getSounding(
            lat=[pt.lat for pt in end_points],
            lon=[pt.lon for pt in end_points],
            heights=[pt.elev for pt in end_points],
            spline=100,
            ref_time=self.bam.setup.fireball_datetime,
            perturbations=perturbations)

        sounding, perturbations = result
        return sounding, perturbations

    def heightUnc(self):
        """Estimate the source height from the expected arrival time.

        Planned workflow:
          * use the given height;
          * find the nominal height within a large tolerance around it
            (to prevent double solutions);
          * run the perturbed atmospheres within this tolerance and collect
            their heights;
          * show a bar graph of the solutions (not implemented yet -- results
            are only printed).

        For the nominal sounding and for each perturbation, a particle-swarm
        search (``pso`` minimising ``heightErr``) looks within +/-4000 (same
        units as ``source.elev``) of the current source elevation for the
        height that matches the "Expected Arrival Time" field.
        """
        source = self.getSource()
        stat_pos = self.getStat()

        sounding, perturbations = self.getATM(source,
                                              stat_pos,
                                              perturbations=0)

        # Express both points in the local frame centred on the source
        stat_pos.pos_loc(source)
        source.pos_loc(source)

        source.z = source.elev
        stat_pos.z = stat_pos.elev

        given_height = source.elev
        given_time = float(self.draw_exp_time.text())

        traj = self.bam.setup.trajectory

        def findHeightFromTime(sounding, t, approx_height):
            """Return the height (PSO optimum) whose arrival time best matches ``t``."""
            BOUNDS = 4000

            SWARM_SIZE = 100
            MAXITER = 25
            PHIP = 0.5
            PHIG = 0.5
            OMEGA = 0.5
            MINFUNC = 1e-3
            MINSTEP = 1e-2

            search_min = [approx_height - BOUNDS]
            search_max = [approx_height + BOUNDS]

            # Leave one core free, but never request fewer than one process
            n_processes = max(1, multiprocessing.cpu_count() - 1)

            # NOTE(review): the best height is taken from the FIRST returned
            # value, i.e. this assumes pso returns (best_position, best_value)
            # as pyswarm does -- confirm against the pso in use.
            x_opt, f_opt = pso(heightErr, search_min, search_max, \
                               args=[t, traj, stat_pos, sounding], processes=n_processes, particle_output=False, swarmsize=SWARM_SIZE,\
                               maxiter=MAXITER, phip=PHIP, phig=PHIG, \
                               debug=False, omega=OMEGA, minfunc=MINFUNC, minstep=MINSTEP)

            best_height = x_opt[0]

            print("Best Height = {:.5f} km".format(best_height / 1000))

            return best_height

        print("Nominal")
        nominal_height = findHeightFromTime(sounding, given_time, given_height)

        # Heights from the perturbed atmospheres (kept for the planned plot)
        pert_height = []
        if perturbations is not None:
            for pp, pert in enumerate(perturbations):
                print("Perturbation {:}".format(pp + 1))
                pert_height.append(
                    findHeightFromTime(pert, given_time, given_height))

        print("FINAL RESULTS")
        print("Nominal Height {:} km".format(nominal_height / 1000))

    def glm2LC(self):
        """Convert the loaded GLM data into a light-curve text file.

        Calls ``procGLM(plot=False)`` and writes one ``t, h, M`` line per
        detection to a file chosen in a save dialog.  Nothing is written if
        the dialog is cancelled.
        """
        t, h, M = self.procGLM(plot=False)

        # getSaveFileName returns (path, selected_filter); path is empty when
        # the user cancels.
        file_name = QFileDialog.getSaveFileName(self, 'Save File')[0]
        if not file_name:
            return None

        with open(file_name, 'w+') as f:
            for t_i, h_i, M_i in zip(t, h, M):
                f.write("{:}, {:}, {:}\n".format(t_i, h_i, M_i))

        print("Wrote file to {:}".format(file_name))

    def procGLM(self, plot=True):
        """Load the GLM CSV and map each detection onto the trajectory.

        The file named in the "Load GLM" field is read line by line.  Only
        lines whose first four comma-separated fields all parse as floats are
        used, read as (time, lon, lat, energy).  Times are offset from the
        first row, divided by 1000 (presumably ms -> s), anchored at the
        "GLM Initial Datetime" field, and then made relative to the fireball
        datetime.  Each detection is placed at
        ``traj.findGeo(traj.approxHeight(lat, lon, rel_time))``.

        Parameters
        ----------
        plot : bool
            True: scatter the detections on the 3-D view (marker size and
            colour scale with energy), add the trajectory line, return None.
            False: draw nothing and return the light-curve arrays.

        Returns
        -------
        None when ``plot`` is True.  Otherwise ``(t, h, M)``:
          * relative times [s] (list);
          * trajectory heights divided by 1000 (array);
          * ``-2.5 * log10(energy)`` (array).
        All three are empty if no usable rows were found.
        """
        MIN_SIZE = 10
        MAX_SIZE = 400

        file_name = self.load_glm_edits.text()
        print("Loaded: {:}".format(file_name))

        glm_time = []
        glm_lon = []
        glm_lat = []
        glm_energy = []
        with open(file_name, 'r') as f:
            for line in f:
                fields = line.strip().split(',')

                # Parse the whole row before storing anything, so a bad or
                # missing column cannot leave the lists with unequal lengths.
                try:
                    t_val, lon_val, lat_val, e_val = (float(x)
                                                      for x in fields[:4])
                except ValueError:
                    # headers, short rows and non-numeric rows are skipped
                    continue

                glm_time.append(t_val)
                glm_lon.append(lon_val)
                glm_lat.append(lat_val)
                glm_energy.append(e_val)

        if not glm_time:
            errorMessage("No GLM data could be read!",
                         2,
                         detail='{:}'.format(file_name))
            if plot:
                return None
            return [], np.array([]), np.array([])

        initial_time = self.fireball_datetime_edits.dateTime().toPyDateTime()

        rel_time = []
        for t in glm_time:
            utc_time = initial_time + timedelta(seconds=(t - glm_time[0]) /
                                                1000)
            rel_time.append(
                (utc_time - self.bam.setup.fireball_datetime).total_seconds())

        traj = self.bam.setup.trajectory

        lon_list = []
        lat_list = []
        height_list = []
        for lat_i, lon_i, t_rel in zip(glm_lat, glm_lon, rel_time):
            traj_pos = traj.findGeo(traj.approxHeight(lat_i, lon_i, t_rel))

            lon_list.append(traj_pos.lon)
            lat_list.append(traj_pos.lat)
            height_list.append(traj_pos.elev)

        # Intensity (GLM energy) rather than a magnitude drives the markers
        mag_list = np.asarray(glm_energy, dtype=float)

        if not plot:
            t = rel_time
            h = np.array(height_list) / 1000
            M = -2.5 * np.log10(mag_list)
            return t, h, M

        min_mag = np.nanmin(mag_list)
        max_mag = np.nanmax(mag_list)

        # Linear map of energy onto [MIN_SIZE, MAX_SIZE]; a constant (or
        # all-NaN) series would divide by zero, so use the minimum size then.
        if max_mag > min_mag:
            size_scal = (mag_list - min_mag) * (MAX_SIZE - MIN_SIZE) / (
                max_mag - min_mag) + MIN_SIZE
        else:
            size_scal = np.full(mag_list.shape, float(MIN_SIZE))

        self.rtv_graph.ax.scatter(lon_list,
                                  lat_list,
                                  height_list,
                                  s=size_scal,
                                  c=mag_list,
                                  cmap='magma')

        self.rtv_graph.ax.plot([traj.pos_i.lon, traj.pos_f.lon],
                               [traj.pos_i.lat, traj.pos_f.lat],
                               [traj.pos_i.elev, traj.pos_f.elev])

    def plothvt(self):
        """Redraw the height-vs-time panel for all loaded rays.

        Each loaded ray is plotted as height (elev / 1000, labelled km)
        against its travel time.  When at least one ray is loaded, a sounding
        between the end points of the first ray is fetched.  Its
        ``column1 + column2`` ("maximum effective speed") is shown as a
        filled contour spanning time 0 .. latest ray time.  A colour bar is
        added only when exactly one ray is loaded, so repeated redraws do not
        stack colour bars.
        """
        ax = self.hvt_graph.ax
        ax.clear()

        # Latest time over all rays, never below 0
        max_t = 0
        for ray_pos, ray_time in self.current_loaded_rays:
            ax.plot(ray_time, ray_pos[:, 2] / 1000)
            max_t = max([max_t, *ray_time])

        ### Wind contour
        if self.current_loaded_rays:
            first_pos = self.current_loaded_rays[0][0]
            pos_i = Position(first_pos[0, 0], first_pos[0, 1], first_pos[0, 2])
            pos_f = Position(first_pos[-1, 0], first_pos[-1, 1],
                             first_pos[-1, 2])

            sounding, _ = self.bam.atmos.getSounding(
                lat=[pos_f.lat, pos_i.lat],
                lon=[pos_f.lon, pos_i.lon],
                heights=[pos_f.elev, pos_i.elev],
                spline=100,
                ref_time=self.bam.setup.fireball_datetime)

            levels = sounding[:, 0] / 1000

            # maximum possible speed
            speed = sounding[:, 1] + sounding[:, 2]

            # The profile does not depend on time, so duplicate it at t = 0
            # and t = max_t to let tricontourf fill the whole time range.
            n_levels = len(levels)
            h = np.concatenate([levels, levels])
            t = np.concatenate([np.zeros(n_levels), np.full(n_levels, max_t)])
            z = np.concatenate([speed, speed])

            sc = ax.tricontourf(t, h, z, alpha=0.3, cmap='viridis')
            if len(self.current_loaded_rays) == 1:

                cbar = self.hvt_graph.figure.colorbar(sc, pad=0.2)
                cbar.ax.set_xlabel("Maximum Effective Speed [m/s]")

        ax.set_xlabel("Time after Source [s]")
        ax.set_ylabel("Height [km]")

        self.hvt_graph.show()

    def drawBeam(self):
        """Draw the entered beam azimuth from the selected station.

        A ``SCALE``-degree line is drawn on the z = 0 plane of the 3-D view,
        starting at the station and pointing along the azimuth, measured
        clockwise from north (0 = +lat, 90 = +lon).  This is the same
        compass convention ``traceRev`` uses for this field.  The azimuth
        is also marked as a horizontal line on ``pol_graph.ax1``.

        Returns None early, after an error dialog, when the azimuth field
        cannot be parsed.
        """
        try:
            az = float(self.draw_beam.text())
        except ValueError as e:
            errorMessage("Cannot read beam azimuth!",
                         2,
                         detail='{:}'.format(e))
            return None

        stat_idx = self.station_combo.currentIndex()
        stat_pos = self.bam.stn_list[stat_idx].metadata.position

        # deg lat/lon to extend beam to
        SCALE = 1

        # For short distances a flat lat/lon step is an ok approximation.
        # Using sin/cos (rather than tan) keeps the correct quadrant, so beams
        # can point west and south as well.
        az_rad = np.radians(az)
        end_lon = stat_pos.lon + SCALE * np.sin(az_rad)
        end_lat = stat_pos.lat + SCALE * np.cos(az_rad)

        self.rtv_graph.ax.plot([stat_pos.lon, end_lon],
                               [stat_pos.lat, end_lat], [0, 0])
        self.pol_graph.ax1.axhline(y=az, linestyle='-')

    def drawStat(self):
        """Mark the station selected in the combo box on the 3-D view.

        Draws a blue triangle (size 200) at the station's lon, lat, elev.
        """
        chosen = self.bam.stn_list[self.station_combo.currentIndex()]
        where = chosen.metadata.position

        station_style = {"marker": "^", "c": "b", "s": 200}
        self.rtv_graph.ax.scatter(where.lon, where.lat, where.elev,
                                  **station_style)

    def drawSrc(self):
        """Mark the source on the 3-D view with a red star.

        If both "Source Latitude" and "Source Longitude" are filled in, the
        source is built from them and the source-height field.  Otherwise it
        is the trajectory point at the entered source height.  Returns None
        early, after an error dialog, if any required field cannot be parsed.
        """
        lat_txt = self.source_lat.text()
        lon_txt = self.source_lon.text()

        if lat_txt and lon_txt:
            try:
                source = Position(float(lat_txt), float(lon_txt),
                                  float(self.source_height.text()))
            except ValueError as e:
                errorMessage("Cannot read source position!",
                             2,
                             detail='{:}'.format(e))
                return None
        else:
            traj = self.bam.setup.trajectory
            try:
                source = traj.findGeo(float(self.source_height.text()))

            except ValueError as e:
                if self.prefs.debug:
                    print(printMessage("Error"), " No source height given!")
                errorMessage("Cannot read source height!",
                             2,
                             detail='{:}'.format(e))

                return None

        self.rtv_graph.ax.scatter(source.lon,
                                  source.lat,
                                  source.elev,
                                  marker="*",
                                  c="r",
                                  s=200)

    def drawTraj(self):
        """Draw the trajectory line, plus the light curve if one is configured.

        When ``bam.setup.light_curve_file`` exists and is non-empty, the
        file is read with ``readLightCurve`` and ``processLightCurve``.  Each
        point (height ``L.h`` times 1000, intensity ``L.I``) is then placed
        at ``traj.findGeo`` of that height and scattered on the 3-D view,
        with marker size and colour scaled by intensity.  The straight line
        between ``traj.pos_i`` and ``traj.pos_f`` is always drawn.
        """
        traj = self.bam.setup.trajectory

        MIN_SIZE = 10
        MAX_SIZE = 400

        # Missing attribute, None and "" all mean "no light curve"
        light_curve_file = getattr(self.bam.setup, "light_curve_file", None)

        if light_curve_file:

            light_curve = readLightCurve(light_curve_file)

            light_curve_list = processLightCurve(light_curve)

            lon_list = []
            lat_list = []
            height_list = []
            # I changed my mind and did intensity instead
            mag_list = []
            for L in light_curve_list:
                for h_km, intensity in zip(L.h, L.I):
                    traj_pos = traj.findGeo(h_km * 1000)

                    lon_list.append(traj_pos.lon)
                    lat_list.append(traj_pos.lat)
                    height_list.append(h_km * 1000)
                    mag_list.append(intensity)

            if mag_list:
                mag_arr = np.asarray(mag_list, dtype=float)
                min_mag = np.nanmin(mag_arr)
                max_mag = np.nanmax(mag_arr)

                # Linear map onto [MIN_SIZE, MAX_SIZE]; avoid dividing by
                # zero for a constant (or all-NaN) curve.
                if max_mag > min_mag:
                    size_scal = (mag_arr - min_mag) * (MAX_SIZE - MIN_SIZE) / (
                        max_mag - min_mag) + MIN_SIZE
                else:
                    size_scal = np.full(mag_arr.shape, float(MIN_SIZE))

                self.rtv_graph.ax.scatter(lon_list,
                                          lat_list,
                                          height_list,
                                          s=size_scal,
                                          c=mag_arr,
                                          cmap='magma')

        self.rtv_graph.ax.plot([traj.pos_i.lon, traj.pos_f.lon],
                               [traj.pos_i.lat, traj.pos_f.lat],
                               [traj.pos_i.elev, traj.pos_f.elev])

    def exportRay(self):
        """Save the last traced ray (``self.current_eigen``) to a .dat file.

        ``current_eigen[0]`` is the array of ray positions and
        ``current_eigen[1]`` the matching times.  Each output row is
        ``col1, col0, col2/1000, -999, -999, time``; the -999 values fill the
        attenuation/absorption columns named in the header.  Shows an error
        if no ray has been traced yet; does nothing if the dialog is
        cancelled.
        """
        if self.current_eigen is None:
            errorMessage("No ray to export!",
                         2,
                         detail='Run a ray trace before exporting.')
            return None

        filename = QFileDialog.getSaveFileName(self, 'Save File', '',
                                               'Dat File (*.dat)')
        # getSaveFileName returns (path, selected_filter); path is empty when
        # the user cancels.
        file_name = str(filename[0])
        if not file_name:
            return None

        position = self.current_eigen[0]
        ray_time = self.current_eigen[1]

        with open(file_name, 'w+') as f:

            f.write(
                "# lat [deg]    lon [deg]   z [km]  geo. atten. [dB]    absorption [dB] time [s]\n"
            )

            for i in range(len(position)):
                f.write("{:}, {:}, {:}, {:}, {:}, {:}\n".format(position[i, 1], \
                            position[i, 0], position[i, 2]/1000, -999, -999, ray_time[i]))

        errorMessage("Ray Exported!",
                     0,
                     title="Everybody Loves Ray-Tracing",
                     detail='File saved to {:}'.format(file_name))

    def plotLoadedRay(self):
        """Load a ray file, draw it on the 3-D view and refresh the h-vs-t panel.

        Each data line is comma-separated, or tab-separated if it contains no
        commas.  Columns 0, 1, 2 and 5 are read as lat, lon, z (multiplied by
        1000) and time.  Blank lines and lines starting with '#' are skipped.
        The ray is appended to ``self.current_loaded_rays`` as
        ``[positions, times]``.  Nothing happens if no file is selected; an
        error is shown if the file contains no data rows.
        """
        file_name = self.load_ray_edits.text()
        if not file_name:
            return None

        ray_pos = []
        ray_time = []
        with open(file_name, 'r') as f:
            for line in f:
                stripped = line.strip()

                # Skip blank lines and comment/header lines
                if not stripped or stripped.startswith("#"):
                    continue

                fields = stripped.split(',')
                if len(fields) == 1:
                    fields = stripped.split("\t")

                lat = float(fields[0])
                lon = float(fields[1])
                elev = float(fields[2]) * 1000
                t = float(fields[5])

                ray_pos.append([lat, lon, elev])
                ray_time.append(t)

        if not ray_pos:
            errorMessage("No ray points found!",
                         2,
                         detail='{:}'.format(file_name))
            return None

        ray_pos = np.array(ray_pos)

        self.rtv_graph.ax.plot(ray_pos[:, 1], ray_pos[:, 0], ray_pos[:, 2])
        self.current_loaded_rays.append([ray_pos, ray_time])
        self.plothvt()

    def clearRayTrace(self):
        """Wipe the 3-D ray-trace axes and refresh the canvas.

        Only the 3-D view is cleared; ``current_loaded_rays`` and the
        height-vs-time panel are left as they are.
        """
        canvas = self.rtv_graph
        canvas.ax.clear()
        canvas.show()

    def rayTraceFromSource(self, source, clean_mode=False, debug=False):
        """Trace the eigen-ray from ``source`` to the selected station and plot it.

        Arguments:
            source: [Position] source location. It is converted in place to a
                local frame centred on itself (``pos_loc``) and its ``z`` is set
                to its elevation. The selected station's position object is
                converted to the same frame, which mutates the station metadata.
            clean_mode: [bool] if True, suppress console output, the ray-point
                scatter and the particle fan (used when sweeping a trajectory).
            debug: [bool] if True, print the final vertical/horizontal miss of
                the traced ray.

        Side effects:
            Plots on ``self.rtv_graph``. Sets ``self.current_eigen``, appends to
            ``self.current_loaded_rays`` and calls ``self.plothvt()`` when the
            eigen-ray can be drawn. Appends ``[height km, back azimuth, second
            value of determineBackAz]`` rows to ``self.plot_ba_data`` while that
            attribute is a list. With perturbations enabled, prints a summary and
            writes a CSV to the path returned by ``saveFile``.
        """
        traj = self.bam.setup.trajectory

        ### Set up parameters of source

        stat_idx = self.station_combo.currentIndex()
        stat = self.bam.stn_list[stat_idx]
        stat_pos = stat.metadata.position

        lat = [source.lat, stat_pos.lat]
        lon = [source.lon, stat_pos.lon]
        elev = [source.elev, stat_pos.elev]

        if not clean_mode:
            print("Source Location")
            print(source)
            print("Station Location")
            print(stat_pos)

        sounding, perturbations = self.bam.atmos.getSounding(
            lat=lat,
            lon=lon,
            heights=elev,
            spline=1000,
            ref_time=self.bam.setup.fireball_datetime)

        # Express both end points in a local frame centred on the source
        stat_pos.pos_loc(source)
        source.pos_loc(source)

        source.z = source.elev
        stat_pos.z = stat_pos.elev

        h_tol = float(self.horizontal_tol.text())
        v_tol = float(self.vertical_tol.text())

        ### Ray Trace

        r, tr, f_particle = cyscan(np.array([source.x, source.y, source.z]),
                                   np.array([stat_pos.x, stat_pos.y, stat_pos.z]),
                                   sounding, trace=True, plot=False,
                                   particle_output=True, debug=False, wind=True,
                                   h_tol=h_tol, v_tol=v_tol, print_times=True,
                                   processes=1)
        # Keep the nominal solution (time, azimuth, takeoff, error); the
        # perturbation summary below reports it as entry 0.
        nominal = r

        if not clean_mode:
            angle_off = self._rayAngleOffTrajectory(traj, r[1], r[2])
            if not self.pertstog.isChecked():
                dx = np.abs(stat_pos.x - source.x)
                dy = np.abs(stat_pos.y - source.y)
                dz = np.abs(stat_pos.z - source.z)
                dist = np.sqrt(dx**2 + dy**2 + dz**2)
                time_along_trajectory = traj.findTime(source.elev)

                print("###### RESULTS ######")
                print("Time Along Trajectory (wrt Reference): {:.4f} s".format(
                    time_along_trajectory))
                print("Acoustic Path Time: {:.4f} s".format(r[0]))
                print("Total Time from Reference: {:.4f} s".format(
                    r[0] + time_along_trajectory))
                print("Launch Angle {:.2f} deg".format(angle_off))
                print("###### EXTRAS #######")
                print("Time: {:.4f} s".format(r[0]))
                print("Azimuth: {:.2f} deg from North".format(r[1]))
                print("Takeoff: {:.2f} deg from up".format(r[2]))
                print("Error in Solution {:.2f} m".format(r[3]))
                print("Distance in x: {:.2f} m".format(dx))
                print("Distance in y: {:.2f} m".format(dy))
                print("Distance in z: {:.2f} m".format(dz))
                print("Horizontal Distance: {:.2f} m".format(
                    np.sqrt(dx**2 + dy**2)))
                print("Total Distance: {:.2f} m".format(dist))
                print("No Winds Time: {:.2f} s".format(dist / 330))
                print("Time Difference: {:.2f} s".format(r[0] - dist / 330))
                print("Time Along Trajectory: {:.4f} s".format(
                    time_along_trajectory))
                print("Total Time from Reference: {:.4f} s".format(
                    r[0] + time_along_trajectory))

        ### Back azimuth at the station from the final ray segment(s)

        n_back_az_segments = 1
        for i in range(n_back_az_segments):
            try:
                ba, inc = determineBackAz(tr[-(i + 2), :], tr[-1, :],
                                          sounding[0, 2],
                                          np.degrees(sounding[0, 3]))
            except (IndexError, TypeError):
                # No usable trace (e.g. the eigen-ray search failed)
                continue
            self._appendBackAzimuthRow([source.elev / 1000, ba, inc])

        ### Plot begin and end points of ray-trace

        self.rtv_graph.ax.scatter(source.lon,
                                  source.lat,
                                  source.elev,
                                  c='r',
                                  marker='*',
                                  s=200)
        self.rtv_graph.ax.scatter(stat_pos.lon,
                                  stat_pos.lat,
                                  stat_pos.elev,
                                  c='b',
                                  marker='^',
                                  s=200)

        ### Plot trace of eigen-ray

        try:
            positions = []
            path_len = 0
            # The ray starts at the source, which sits at (x, y, elev) in this
            # frame -- not at the frame origin.
            prev = np.array([source.x, source.y, source.z])
            for i in range(len(tr[:, 0])):
                point = tr[i, :3]
                path_len += np.sqrt(np.sum((point - prev)**2))
                prev = point

                A = Position(0, 0, 0)
                A.x, A.y, A.z = tr[i, 0], tr[i, 1], tr[i, 2]
                A.pos_geo(source)
                positions.append([A.lon, A.lat, A.elev])

            if not clean_mode:
                print("Total Path Length: {:.2f} m".format(path_len))
                print("Approximate Ray Time: {:.2f} s".format(path_len / 330))
            positions = np.array(positions)
            if not clean_mode:
                self.rtv_graph.ax.scatter(positions[:, 0],
                                          positions[:, 1],
                                          positions[:, 2],
                                          c='b',
                                          alpha=0.5)
            self.rtv_graph.ax.plot(positions[:, 0],
                                   positions[:, 1],
                                   positions[:, 2],
                                   c='k')

            self.current_eigen = [positions, tr[:, -1]]

            if debug:
                h_err = np.sqrt((stat_pos.x - tr[-1, 0])**2 +
                                (stat_pos.y - tr[-1, 1])**2)
                v_err = np.sqrt((stat_pos.z - tr[-1, 2])**2)
                print(
                    "Source Height: {:.2f} km - Final Error: {:.2f} m (v) {:.2f} m (h)"
                    .format(source.elev / 1000, v_err, h_err))

            self.rtv_graph.ax.scatter(positions[-1, 0],
                                      positions[-1, 1],
                                      positions[-1, 2],
                                      c='g')

            self.current_loaded_rays.append([positions, tr[:, -1]])
            self.plothvt()
        except Exception as e:
            # Best effort: a failed ray search leaves nothing to draw, but do
            # not hide the reason in interactive (non-clean) runs.
            if not clean_mode:
                print("Could not plot the eigen-ray: {:}".format(e))

        ### Fan of candidate rays: mark where each lands (green within 1 km)

        if not clean_mode:
            for sol in f_particle:
                scan = anglescan(np.array([source.x, source.y, source.z]),
                                 sol[0],
                                 sol[1],
                                 sounding,
                                 trace=True,
                                 debug=False,
                                 wind=True)
                fan_tr = np.array(scan[1])

                end = Position(0, 0, 0)
                end.x, end.y, end.z = fan_tr[-1, 0], fan_tr[-1, 1], fan_tr[-1, 2]
                end.pos_geo(source)

                miss = np.sqrt((stat_pos.x - fan_tr[-1, 0])**2 +
                               (stat_pos.y - fan_tr[-1, 1])**2 +
                               (stat_pos.z - fan_tr[-1, 2])**2)
                self.rtv_graph.ax.scatter(end.lon,
                                          end.lat,
                                          end.elev,
                                          c='g' if miss <= 1000 else 'r')

        ### Perturbations

        if (self.pertstog.isChecked() and perturbations is not None
                and len(perturbations) > 0):
            # Entry 0 of every array is the nominal (unperturbed) solution
            t_array = [nominal[0]]
            az_array = [nominal[1]]
            tk_array = [nominal[2]]
            err_array = [nominal[3]]
            angle_off_array = [
                self._rayAngleOffTrajectory(traj, nominal[1], nominal[2])
            ]

            for pert_idx, pert in enumerate(perturbations):
                sys.stdout.write("\r Working on Perturbation {:}/{:}".format(
                    pert_idx + 1, len(perturbations)))
                sys.stdout.flush()
                r, tr, f_particle = cyscan(np.array([source.x, source.y, source.z]),
                                           np.array([stat_pos.x, stat_pos.y, stat_pos.z]),
                                           pert, trace=True, plot=False,
                                           particle_output=True, debug=False,
                                           wind=True, h_tol=h_tol, v_tol=v_tol,
                                           print_times=True)

                t_array.append(r[0])
                az_array.append(r[1])
                tk_array.append(r[2])
                err_array.append(r[3])
                angle_off_array.append(
                    self._rayAngleOffTrajectory(traj, r[1], r[2]))

                try:
                    # determineBackAz returns a pair; store it like the
                    # nominal rows so plot_ba_data stays rectangular.
                    ba, inc = determineBackAz(tr[-2, :], tr[-1, :], pert[0, 2],
                                              np.degrees(pert[0, 3]))
                except (IndexError, TypeError):
                    continue
                self._appendBackAzimuthRow([source.elev / 1000, ba, inc])

            print()
            print("###### NOMINAL RESULTS ######")
            print("Time: {:.4f} s".format(t_array[0]))
            print("Azimuth: {:.2f} deg from North".format(az_array[0]))
            print("Takeoff: {:.2f} deg from up".format(tk_array[0]))
            print("Error in Solution {:.2f} m".format(err_array[0]))
            print("### UNCERTAINTIES ###")
            print("Time: {:.4f} - {:.4f} s ({:.4f} s)".format(
                np.nanmin(t_array), np.nanmax(t_array),
                np.nanmax(t_array) - np.nanmin(t_array)))
            print(
                "Azimuth: {:.2f} - {:.2f} deg from North ({:.2f} deg)".format(
                    np.nanmin(az_array), np.nanmax(az_array),
                    np.nanmax(az_array) - np.nanmin(az_array)))
            print("Takeoff: {:.2f} - {:.2f} deg from up ({:.2f} deg)".format(
                np.nanmin(tk_array), np.nanmax(tk_array),
                np.nanmax(tk_array) - np.nanmin(tk_array)))
            print("Error in Solution {:.2f} - {:.2f} m ({:.2f} m)".format(
                np.nanmin(err_array), np.nanmax(err_array),
                np.nanmax(err_array) - np.nanmin(err_array)))
            print("Angle Off {:.2f} - {:.2f} deg ({:.2f} deg)".format(
                np.nanmin(angle_off_array), np.nanmax(angle_off_array),
                np.nanmax(angle_off_array) - np.nanmin(angle_off_array)))

            print("Saving CSV of Perturbations")
            file_name = saveFile("csv", note="Perturbations")

            with open(file_name, "w+") as f:
                f.write(
                    "Height [km], Time [s], Azimuth [deg from North], Takeoff [deg from Up], Error [m], Angle Off [deg]\n"
                )
                for row in zip(t_array, az_array, tk_array, err_array,
                               angle_off_array):
                    f.write("{:}, {:}, {:}, {:}, {:}, {:}\n".format(
                        source.elev / 1000, *row))

    @staticmethod
    def _rayAngleOffTrajectory(traj, azimuth, takeoff):
        """Return ``|angle(traj.vector, launch direction) - 90|`` in degrees.

        ``azimuth`` and ``takeoff`` are the ray launch angles in degrees, as
        found in ``r[1]`` and ``r[2]`` of a ``cyscan`` result. A value of 0
        means the ray leaves perpendicular to the trajectory vector.
        """
        az = np.radians(azimuth)
        tf = np.radians(180 - takeoff)
        u = np.array([traj.vector.x, traj.vector.y, traj.vector.z])
        v = np.array([
            np.sin(az) * np.sin(tf),
            np.cos(az) * np.sin(tf), -np.cos(tf)
        ])
        cos_angle = np.dot(u / np.sqrt(u.dot(u)), v / np.sqrt(v.dot(v)))
        return abs(np.degrees(np.arccos(cos_angle)) - 90)

    def _appendBackAzimuthRow(self, row):
        """Append ``row`` to ``self.plot_ba_data`` while it is collected as a list.

        The trajectory sweep in ``runRayTrace`` converts ``plot_ba_data`` to a
        numpy array afterwards; arrays have no ``append``, so rows are only
        collected while the attribute exists and is a list.
        """
        ba_data = getattr(self, "plot_ba_data", None)
        if isinstance(ba_data, list):
            ba_data.append(row)

    def loadAngleCSV(self):
        """Scatter back azimuth and inclination against height from a CSV file.

        Rows must be comma-separated with at least five columns. Column 1 is
        height in m (plotted in km), column 3 is inclination in degrees and
        column 4 is back azimuth in degrees. Columns 0 and 2 are not used.
        Blank lines and lines starting with ``#`` are skipped.
        """
        heights = []
        incls = []
        bazs = []

        # Read-only access is all that is needed ('r+' would also require
        # write permission on the file).
        with open(self.load_baz_edits.text(), 'r') as f:
            for line in f:
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                cols = stripped.split(',')

                heights.append(float(cols[1]))
                incls.append(float(cols[3]))
                bazs.append(float(cols[4]))

        heights_km = np.array(heights) / 1000

        self.pol_graph.ax1.scatter(heights_km, np.array(bazs))
        self.pol_graph.ax1.set_xlabel("Height [km]")
        self.pol_graph.ax1.set_ylabel("Backazimuth [deg]")

        self.pol_graph.ax2.scatter(heights_km, np.array(incls))
        self.pol_graph.ax2.set_xlabel("Height [km]")
        self.pol_graph.ax2.set_ylabel("Inclination [deg]")

    def runRayTrace(self):
        """Run the ray-trace selected by the GUI mode toggles.

        * ``trajmode`` checked: trace eigen-rays from points along the whole
          trajectory and plot back azimuth / inclination against height.
        * ``netmode`` checked (``trajmode`` unchecked): grid-search launch
          angles from one source height and report rays that hit the station.
        * neither: trace the eigen-ray from a single source.
        """
        traj = self.bam.setup.trajectory

        if self.trajmode.isChecked():
            self._runTrajectoryRayTrace(traj)
        elif self.netmode.isChecked():
            self._runNetRayTrace(traj)
        else:
            self._runSingleSourceRayTrace(traj)

    def _runNetRayTrace(self, traj):
        """Scan a fixed azimuth/takeoff grid and report rays landing at the station.

        Azimuths from 105 to 160 deg and takeoffs between 90 and 115 deg, both
        in 0.05 deg steps, are launched with ``anglescan`` from the trajectory
        point at ``source_height``. Rays ending within the horizontal/vertical
        tolerances of the selected station are collected, and their
        arrival-time range is printed.
        """
        daz = 0.05
        # az_net = np.arange(0, 360 - daz, daz)
        az_net = np.arange(105, 160, daz)

        dtf = 0.05
        tf_net = np.arange(90 + dtf, 115 - dtf, dtf)

        stat_idx = self.station_combo.currentIndex()
        stat = self.bam.stn_list[stat_idx]
        stat_pos = stat.metadata.position
        source = traj.findGeo(float(self.source_height.text()))

        h_tol = float(self.horizontal_tol.text())
        v_tol = float(self.vertical_tol.text())

        # Atmosphere sampled along the vertical column from the source down to 0
        sounding, perturbations = self.bam.atmos.getSounding(
            lat=[source.lat, source.lat],
            lon=[source.lon, source.lon],
            heights=[source.elev, 0],
            spline=1000,
            ref_time=self.bam.setup.fireball_datetime)

        source.pos_loc(source)

        source.z = source.elev
        stat_pos.pos_loc(source)
        stat_pos.z = stat_pos.elev

        az_list = []
        tf_list = []
        T_list = []

        for az in az_net:
            for tf in tf_net:
                x, y, z, T = anglescan(np.array([source.x, source.y, source.z]),
                                       az,
                                       tf,
                                       sounding,
                                       trace=False,
                                       debug=False,
                                       wind=True)

                h_err = np.sqrt((x - stat_pos.x)**2 + (y - stat_pos.y)**2)
                v_err = np.abs(z - stat_pos.z)

                print(
                    "Ray Trace - Azimuth: {:.2f} Takeoff: {:.2f} Time: {:.2f} s H_err: {:.2f} V_err: {:.2f}"
                    .format(az, tf, T, h_err, v_err))

                if h_err <= h_tol and v_err <= v_tol:
                    az_list.append(az)
                    tf_list.append(tf)
                    T_list.append(T)

        print("Done")

        if not T_list:
            # nanmin/nanargmin raise on empty input
            print("No eigen-rays found within the horizontal/vertical tolerances")
            return

        print("Arrival Range: {:.2f}-{:.2f} s".format(np.nanmin(T_list),
                                                     np.nanmax(T_list)))
        print("Azimuths: {:.2f}-{:.2f}".format(az_list[np.nanargmin(T_list)],
                                               az_list[np.nanargmax(T_list)]))
        print("Takeoffs: {:.2f}-{:.2f}".format(tf_list[np.nanargmin(T_list)],
                                               tf_list[np.nanargmax(T_list)]))

    def _runSingleSourceRayTrace(self, traj):
        """Trace the eigen-ray from a single source taken from the GUI fields.

        Uses the explicit source lat/lon/height when lat and lon are filled
        in; otherwise uses the trajectory point at ``source_height``.
        """
        if self.source_lat.text() and self.source_lon.text():
            source = Position(float(self.source_lat.text()),
                              float(self.source_lon.text()),
                              float(self.source_height.text()))
        else:
            try:
                source = traj.findGeo(float(self.source_height.text()))

            except ValueError as e:
                if self.prefs.debug:
                    print(printMessage("Error"), " No source height given!")
                errorMessage("Cannot read source height!",
                             2,
                             detail='{:}'.format(e))

                return None

        self.rayTraceFromSource(source, debug=True)

    def _runTrajectoryRayTrace(self, traj):
        """Trace eigen-rays from 50 points along the trajectory and plot the results.

        ``self.plot_ba_data`` collects one row per traced point and is left
        as a numpy array afterwards.
        """
        self.plot_ba_data = []
        # define line bottom boundary
        max_height = traj.pos_i.elev
        min_height = traj.pos_f.elev

        points = traj.trajInterp2(div=50, min_p=min_height, max_p=max_height)

        loadingBar('Calculating Station Times: ', 0, len(points))
        for pp, pt in enumerate(points):
            loadingBar('Calculating Station Times: ', pp, len(points))
            source = Position(pt[0], pt[1], pt[2])
            self.rayTraceFromSource(source, clean_mode=True, debug=True)
            loadingBar('Calculating Station Times: ', pp + 1, len(points))

        self.drawTraj()

        self.rtv_graph.show()

        self.plot_ba_data = np.array(self.plot_ba_data)

        # Empty (or non-rectangular) data cannot be column-indexed below
        if self.plot_ba_data.ndim != 2 or len(self.plot_ba_data) == 0:
            print("No back azimuths could be computed along the trajectory")
            return

        self.pol_graph.ax1.scatter(self.plot_ba_data[:, 0],
                                   self.plot_ba_data[:, 1])
        self.pol_graph.ax1.set_xlabel("Height [km]")
        self.pol_graph.ax1.set_ylabel("Backazimuth [deg]")

        self.pol_graph.ax2.scatter(self.plot_ba_data[:, 0],
                                   self.plot_ba_data[:, 2])
        self.pol_graph.ax2.set_xlabel("Height [km]")
        self.pol_graph.ax2.set_ylabel("Inclination [deg]")

        print("Back Azimuths and Zeniths:")
        for row in self.plot_ba_data:
            print("{:.4f} km, {:.2f} deg, {:.2f} deg".format(row[0], row[1], row[2]))
示例#8
0
class Yield(QWidget):
    ''' Dialog to estimate yields from overpressures
    '''
    def __init__(self, bam, prefs, current_station):
        """Build the yield dialog for station index ``current_station``.

        The left-hand grid holds the input/result fields and the Integrate /
        Calculate Yield buttons. The right-hand column holds a matplotlib canvas.
        """

        # ----- Window appearance -----
        QWidget.__init__(self)
        self.setWindowTitle('Yields')
        palette = self.palette()
        palette.setColor(self.backgroundRole(), Qt.black)
        self.setPalette(palette)

        # ----- State used by the calculation methods -----
        self.prefs = prefs
        self.bam = bam
        self.setup = bam.setup
        self.stn_list = bam.stn_list
        self.current_station = current_station
        self.iterator = 0

        theme(self)

        self.count = 0
        root = QHBoxLayout()
        self.setLayout(root)

        form = QGridLayout()
        root.addLayout(form)

        plot_column = QVBoxLayout()
        root.addLayout(plot_column)

        # ----- Header row -----
        station_code = self.stn_list[self.current_station].metadata.code
        self.station_label = QLabel('Station: {:}'.format(station_code))
        form.addWidget(self.station_label, 0, 1, 1, 1)

        self.station1_label = QLabel('Nominal')
        form.addWidget(self.station1_label, 0, 2, 1, 1)

        # ----- Input / result fields (creation order = grid row order) -----
        self.height_label, self.height_edits = createLabelEditObj(
            'Height', form, 1)
        self.range_label, self.range_edits = createLabelEditObj(
            'Range', form, 2)
        self.pressure_label, self.pressure_edits = createLabelEditObj(
            'Explosion Height Pressure', form, 3)
        self.overpressure_label, self.overpressure_edits = createLabelEditObj(
            'Overpressure', form, 4)
        self.afi_label, self.afi_edits = createLabelEditObj(
            'Attenuation Integration Factor', form, 5)
        self.geo_label, self.geo_edits = createLabelEditObj(
            'Geometric Factor', form, 6)
        self.p_a_label, self.p_a_edits = createLabelEditObj(
            'Ambient Pressure', form, 7)
        self.Jd_label, self.Jd_edits = createLabelEditObj(
            'Positive Phase Length [ms]', form, 8)
        self.fd_label, self.fd_edits = createLabelEditObj(
            'Transmission Factor', form, 9)
        _, self.freq_edits = createLabelEditObj("Dominant Period", form, 10)
        _, self.rfangle_edits = createLabelEditObj("Refraction Angle", form,
                                                   11)

        # ----- Action buttons -----
        self.yield_button = QPushButton('Calculate Yield')
        form.addWidget(self.yield_button, 14, 1, 1, 4)
        self.yield_button.clicked.connect(self.yieldCalc)

        self.integrate_button = QPushButton('Integrate')
        form.addWidget(self.integrate_button, 13, 1, 1, 4)
        self.integrate_button.clicked.connect(self.intCalc)

        # ----- Reference constants (Reed 1972) -----
        self.W_0 = 4.184e6  # reference explosion yield: 1 kg of chemical TNT [J]
        self.P_0 = 101325  # standard pressure
        self.b = 1 / consts.SCALE_HEIGHT
        self.k = consts.ABSORPTION_COEFF
        self.J_m = 0.375  # average positive period of the reference explosion
        self.c_m = 347  # sound speed of the reference explosion

        # ----- Plot canvas -----
        self.blastline_plot = MatplotlibPyQT()
        self.blastline_plot.ax = self.blastline_plot.figure.add_subplot(111)
        plot_column.addWidget(self.blastline_plot)

    def freqYield(self):
        """Frequency-based yield estimate; currently disabled and only prints "Broken".

        The disabled implementation computed
        ``W = W_0 * P_a / P_0 * (c * T_d / 2 / c_m / J_m)**3``, with ``T_d``
        read from ``freq_edits``.
        """
        status = "Broken"
        print(status)

    def integration_full(self, k, v, b, I, P_a):
        """Return ``exp(-I * (k * v**2 / b) * N_LAYERS)``.

        ``P_a`` is accepted but not used. ``N_LAYERS`` is a global name
        resolved outside this class.
        """
        exponent = I * (k * v**2 / b) * N_LAYERS
        return np.exp(-exponent)

    def inverse_gunc(self, p_ans):
        """Return ``10**x`` for the ``x`` in [1, 15] that ``pso`` finds minimises ``gunc_error``.

        ``gunc_error`` receives ``p_ans`` followed by the dialog's current
        model parameters (read from attributes set by ``yieldCalc``/``integrate``).
        Presumably ``x`` is log10 of the yield -- confirm against ``gunc_error``.
        """
        gunc_args = [
            p_ans, self.J_m, self.W_0, self.P, self.P_0, self.Jd, self.c_m,
            self.f_d, self.R, self.P_a, self.k, self.b, self.I, self.cf,
            self.horizontal_range, self.v
        ]
        best_x, _best_value = pso(gunc_error, [1], [15],
                                  args=gunc_args,
                                  processes=1,
                                  swarmsize=1000,
                                  maxiter=1000)

        return 10**(best_x[0])

    def integrate(self, height, D_ANGLE=1.5, tf=1, az=1):
        """Run the attenuation integration from the source at ``height`` to the station.

        For the nominal sounding and then each perturbation, the launch angles
        are found with ``cyscan``. ``intscan`` is then run at frequency
        ``1 / T_d`` (``T_d`` read from ``freq_edits``), and ``refractiveFactor``
        is evaluated with ``D_ANGLE``. ``tf`` and ``az`` are unused.

        Returns:
            (trans, ints, ts, ps, rfs, path_length, pdr): ``trans``, ``ints``,
            ``ts`` and ``rfs`` are per-profile lists of intscan's f, g, T and the
            refractive factor. ``ps``, ``path_length`` and ``pdr`` come from the
            last profile only.

        Side effects: sets ``self.T_d``, ``self.v`` and ``self.reed_attenuation``,
        and converts the source point and station position to the local frame of
        the setup's lat/lon centre.
        """
        centre = Position(self.setup.lat_centre, self.setup.lon_centre, 0)
        try:
            point = self.setup.trajectory.findGeo(height)
        except AttributeError:
            print("STATUS: No trajectory given, assuming lat/lon center")
            point = Position(self.setup.lat_centre, self.setup.lon_centre,
                             height)
        point.pos_loc(centre)

        stn_pos = self.stn_list[self.current_station].metadata.position
        stn_pos.pos_loc(centre)

        # make the spline lower to save time here
        sounding, perturbations = self.bam.atmos.getSounding(
            [point.lat, stn_pos.lat],
            [point.lon, stn_pos.lon],
            [point.elev, stn_pos.elev],
            spline=N_LAYERS,
            ref_time=self.setup.fireball_datetime)

        # Profile 0 is the nominal sounding, followed by any perturbations
        profiles = [sounding]
        if perturbations is not None:
            profiles.extend(perturbations)

        trans = []
        ints = []
        ts = []
        ps = []
        rfs = []

        for profile in profiles:
            S = np.array([point.x, point.y, point.z])
            D = np.array([stn_pos.x, stn_pos.y, stn_pos.z])

            _, az_n, tf_n, _ = cyscan(S, D, profile, wind=True,
                                      h_tol=30, v_tol=30)

            self.T_d = tryFloat(self.freq_edits.text())
            self.v = 1 / self.T_d

            f, g, T, P, path_length, pdr, reed_attenuation = intscan(
                S, az_n, tf_n, profile, self.v, wind=True)
            self.reed_attenuation = reed_attenuation

            rf = refractiveFactor(point, stn_pos, profile, D_ANGLE=D_ANGLE)

            trans.append(f)
            ints.append(g)
            ts.append(T)
            ps = P  # only the last profile's pressures are returned
            rfs.append(rf)

        return trans, ints, ts, ps, rfs, path_length, pdr

    def intCalc(self):
        """Fill the integration-derived fields for the height in ``height_edits``.

        Runs ``integrate`` at that height, with the refraction angle taken from
        ``rfangle_edits``, and writes the following into the GUI:

        * the mean transmission factor, attenuation integration factor and
          geometric factor;
        * the two pressures returned by the integration;
        * the slant range from the source point to the station.

        Also stores ``self.path_length`` (read by ``yieldCalc``) and ``self.pdr``.
        Does nothing when ``tryFloat`` cannot read the height.
        """
        height = tryFloat(self.height_edits.text())
        if height is None:
            return

        stn = self.stn_list[self.current_station]
        self.rfang = tryFloat(self.rfangle_edits.text())
        trans, ints, _ts, ps, rfs, path_length, pdr = self.integrate(
            height, D_ANGLE=self.rfang)

        self.fd_edits.setText('{:.4f}'.format(np.nanmean(trans)))
        self.afi_edits.setText('{:.4f}'.format(np.nanmean(ints)))
        self.pressure_edits.setText('{:.4f}'.format(ps[0]))
        self.p_a_edits.setText('{:.4f}'.format(ps[1]))
        self.path_length = path_length
        self.pdr = pdr

        try:
            frag_pos = self.setup.trajectory.findGeo(height)
        except AttributeError:
            print("STATUS: No trajectory given, assuming lat/lon center")
            frag_pos = Position(self.setup.lat_centre, self.setup.lon_centre,
                                height)

        self.geo_edits.setText('{:.4f}'.format(np.nanmean(rfs)))
        dist = stn.metadata.position.pos_distance(frag_pos)
        self.range_edits.setText('{:.4f}'.format(dist))

    def yieldCalc(self):
        """Estimate a fragmentation's yield from the values in the GUI fields.

        Reads range, receiver/source pressure, geometric and attenuation
        factors, positive-phase duration (ms), ``f_d`` and overpressure from
        the line edits (storing most of them on ``self``). It then:

        * prints a report of the geometry, pressure corrections and yield
          estimates (ANSI / Reed empirical fits and chemical / nuclear
          scaled-distance fits);
        * adds yield-vs-range curves and markers to ``self.blastline_plot``;
        * asks the user whether to keep the result and, if so, appends an
          ``EnergyObj`` to ``self.bam.energy_measurements``.

        Does nothing when the height field is empty or unparsable. If any other
        required field cannot be parsed, a status message is printed and the
        method returns. Previously it would fail partway through the report.
        """
        height = tryFloat(self.height_edits.text())
        if height is None:
            return

        self.R = tryFloat(self.range_edits.text())
        self.P_a = tryFloat(self.p_a_edits.text())
        self.cf = tryFloat(self.geo_edits.text())
        self.I = tryFloat(self.afi_edits.text())
        self.P = tryFloat(self.pressure_edits.text())
        self.Jd = tryFloat(self.Jd_edits.text())
        self.f_d = tryFloat(self.fd_edits.text())
        del_P = tryFloat(self.overpressure_edits.text())

        # tryFloat yields None for unparsable text; formatting/arithmetic on
        # None below would raise TypeError part-way through the report.
        required = {
            "range": self.R,
            "receiver pressure": self.P_a,
            "geometric factor": self.cf,
            "attenuation factor": self.I,
            "source pressure": self.P,
            "positive phase duration": self.Jd,
            "f_d": self.f_d,
            "overpressure": del_P,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            print("STATUS: Cannot calculate yield, missing or invalid: {:}".
                  format(", ".join(missing)))
            return

        # Kinney and Graham 1985 uses horizontal distance (see Problem 7.1 in their book)
        frag_pos = self.setup.trajectory.findGeo(height)
        stn = self.stn_list[self.current_station]
        stn_pos = stn.metadata.position
        self.horizontal_range = stn_pos.ground_distance(frag_pos)
        true_range = stn_pos.pos_distance(frag_pos)

        # ---- Geometry / signal summary -------------------------------------
        print("Height:          {:.2f} km".format(height / 1000))
        print("Ground Distance: {:.2f} km".format(self.horizontal_range /
                                                  1000))
        print("Slant Range:     {:.2f} km".format(true_range / 1000))
        print("Path Length:     {:.2f} km".format(self.path_length / 1000))
        print("")
        print("Positive Phase Duration  {:.0f} ms".format(self.Jd))
        print("Overpressure:            {:.2f} Pa".format(del_P))
        print("Dominant Period:         {:.2f} s".format(self.T_d))

        # Jd is in ms; twice the positive phase approximates a full period.
        perc_diff = percDiff(self.T_d, self.Jd / 1000 * 2)
        print("Dominant Period and Positive Phase are different by: {:.2f}%".
              format(perc_diff))
        print("")

        # Reed 1972 enhancement for airborne bursts (1/6 power of the ratio
        # of receiver to source ambient pressure).
        pres_fact = (self.P_a / self.P)**(1 / 6)

        # ---- Pressure corrections ------------------------------------------
        print("### Pressure Changes")
        print("\tAttenuation Factor: {:.2f}".format(self.I))
        print("\tGeometric Factor:   {:.2f}".format(self.cf))
        print("\tCombined Factor:    {:.2f}".format(self.I * self.cf))
        print("")
        print("### Atmospheric Pressure")
        print("\tAt Source      {:8.2f} Pa".format(self.P))
        print("\tAt Reciever    {:8.2f} Pa".format(self.P_a))
        # Dimensionless ratio: no pressure unit.
        print("\tPressure Ratio {:8.2f}".format(self.P / self.P_a))
        print("")
        print("Overpressure (Reciever)                      : {:.2E} Pa".
              format(del_P))
        print("Overpressure (Reciever at Source Pressure)   : {:.2E} Pa".
              format(del_P * pres_fact))
        print("Overpressure Ratio (Source)                  : {:.2E}".format(
            del_P / self.P))
        print("Overpressure Ratio (Reciever)                : {:.2E}".format(
            del_P / self.P_a))
        print("")

        #needham blast waves
        # brode
        # jones 1968

        pres_fact_85 = 1 / self.f_d

        # f_t depends on sound speed - roughly a factor of 1
        f_t = self.f_d * (330 / 330)

        p_p0 = np.exp(self.reed_attenuation)

        print("Pressure Factor (Reed 1972) Attenuation: {:.2f}".format(p_p0))
        print("Pressure Factor (Reed 1972) 1/6 Power  : {:.2f}".format(
            pres_fact))
        print("Pressure Factor (KG   1985)            : {:.2f}".format(
            pres_fact_85))

        # Undo path attenuation and geometric spreading, then apply the
        # height (ambient pressure) enhancement.
        new_pres = del_P * pres_fact / self.I / self.cf
        phase_duration = self.Jd * (1 / f_t)
        source_dP = new_pres / pres_fact

        reed_yield = ReedYield(self.R, source_dP)
        ansi_yield = ANSIYield(self.R, self.P, source_dP, self.P_a)

        print("Height Adjusted Pressure: {:.2f} Pa".format(new_pres))
        print("Height Adjusted Time    : {:.2f} ms".format(phase_duration))

        # ---- Scaled distances (Kinney & Graham fits) -----------------------
        Z_chem = findScaledDistance([(new_pres) / (self.P_a)],
                                    chemFuncMinimizer)
        Z_nuc = findScaledDistance([(new_pres) / (self.P_a)],
                                   nucFuncMinimizer)
        Z_chem_d = findScaledDistance([phase_duration, self.R, self.f_d],
                                      chemFuncDurationMinimizer)
        Z_nuc_d = findScaledDistance([phase_duration, self.R, self.f_d],
                                     nucFuncDurationMinimizer)

        print("Scaled Distance from Overpressure (Chemical):    {:10.2f} km".
              format(Z_chem / 1000))
        print("Scaled Distance from Duration (Chemical):        {:10.2f} km".
              format(Z_chem_d / 1000))
        print("Scaled Distance from Overpressure  (Nuclear):    {:10.2f} km".
              format(Z_nuc / 1000))
        print("Scaled Distance from Duration  (Nuclear):        {:10.2f} km".
              format(Z_nuc_d / 1000))

        # Cube-root scaling: W = (f_d * R / Z)^3 * W_0 (nuclear fits in kT).
        Yield_chem = (self.f_d * self.R / Z_chem)**3 * self.W_0
        Yield_nuc = (self.f_d * self.R / Z_nuc)**3 * self.W_0 * 1e6
        Yield_chem_d = (self.f_d * self.R / Z_chem_d)**3 * self.W_0
        Yield_nuc_d = (self.f_d * self.R / Z_nuc_d)**3 * self.W_0 * 1e6

        print("\tExpected Yield (ANSI 1977 Empirical):    {:.2E} J ({:.2f} kg TNT)"
              .format(ansi_yield, ansi_yield / self.W_0))
        print("\tExpected Yield (Reed 1977 Empirical):    {:.2E} J ({:.2f} kg TNT)"
              .format(reed_yield, reed_yield / self.W_0))
        print("\tExpected Yield (Chemical Overpressure):  {:.2E} J ({:.2f} kg TNT)"
              .format(Yield_chem, Yield_chem / self.W_0))
        print("\tExpected Yield (Chemical Duration):      {:.2E} J ({:.2f} kg TNT)"
              .format(Yield_chem_d, Yield_chem_d / self.W_0))
        print("\tExpected Yield (Nuclear Overpressure):   {:.2E} J ({:.2f} kT TNT)"
              .format(Yield_nuc, Yield_nuc / self.W_0 / 1e6))
        print("\tExpected Yield (Nuclear Duration):       {:.2E} J ({:.2f} kT TNT)"
              .format(Yield_nuc_d, Yield_nuc_d / self.W_0 / 1e6))

        # ---- Plotting ------------------------------------------------------
        sample_R = np.linspace(0, 100000)
        sample_dP = source_dP  # Pa
        sample_P = self.P

        sample_W = ReedYield(sample_R, sample_dP)
        sample_W_ansi = ANSIYield(sample_R, sample_P, sample_dP, self.P_a)

        # Cycle through the palette so repeated calls never run off the end
        # (indexing directly raised IndexError on the fourth call).
        palette = ["w", "r", "m"]
        colour = palette[self.iterator % len(palette)]

        ax = self.blastline_plot.ax
        ax.plot(sample_R / 1000,
                sample_W_ansi / self.W_0,
                c=colour,
                linestyle="--",
                label="(ANSI 1983) Overpressure = {:.2f} Pa".format(sample_dP))
        ax.plot(sample_R / 1000,
                sample_W / self.W_0,
                c=colour,
                label="(Reed 1972) Overpressure = {:.2f} Pa".format(sample_dP))
        ax.scatter(self.R / 1000,
                   Yield_chem / self.W_0,
                   c=colour,
                   marker="<",
                   label="Chemical Overpressure")
        ax.scatter(self.R / 1000,
                   Yield_chem_d / self.W_0,
                   c=colour,
                   marker=">",
                   label="Chemical Duration")
        ax.legend()

        self.iterator += 1

        ax.semilogy()
        ax.set_xlabel("Range [km]")
        ax.set_ylabel("Yield [kg TNT HE]")

        self.blastline_plot.show()

        # ---- Optionally keep the result ------------------------------------
        qm = QMessageBox()
        ret = qm.question(self, '', "Save this yield?", qm.Yes | qm.No)

        if ret == qm.Yes:
            a = EnergyObj()
            a.source_type = "Fragmentation"
            a.height = height
            a.range = self.R
            a.station = stn
            a.ansi_yield = ansi_yield
            a.reed_yield = reed_yield
            a.chem_pres = Yield_chem
            a.chem_dur = Yield_chem_d
            a.nuc_pres = Yield_nuc
            a.nuc_dur = Yield_nuc_d
            a.adj_pres = new_pres
            a.adj_dur = phase_duration

            self.bam.energy_measurements.append(a)

    def yieldPlot(self, p_ratio, W, unc='none'):
        """Plot an overpressure-vs-yield curve and mark the yield estimate.

        Draws on ``self.blastline_canvas`` (y axis in log mode) and prints the
        estimate obtained from ``self.inverse_gunc`` at ``self.p_rat``.

        Args:
            p_ratio: x values of the curve, shown on the "Overpressure" (Pa)
                axis.
            W: y values matching ``p_ratio``; plotted as ``np.log10(W)`` on
                the "Yield" (J) axis.
            unc: ``'none'`` draws the nominal curve, increments ``self.count``
                and reads ``self.p_rat`` from ``overpressure_edits``. The
                ``'min'``/``'max'`` handling is commented out, so any other
                value reuses whatever ``self.p_rat`` was last set to.
        """
        if unc == 'none':
            # One nominal curve per call; count selects colour and label.
            self.count += 1

        # NOTE(review): only 10 colours -- colour[self.count - 1] below raises
        # IndexError once more than 10 nominal curves have been plotted.
        colour = [(0, 255, 26), (3, 252, 176), (252, 3, 3), (176, 252, 3),
                  (255, 133, 3), (149, 0, 255), (76, 128, 4), (82, 27, 27),
                  (101, 128, 125), (5, 176, 249)]
        # Semi-transparent variants; only referenced by the commented-out
        # min/max uncertainty code below.
        ptb_colour = [(0, 255, 26, 150), (3, 252, 176, 150), (252, 3, 3, 150),
                      (176, 252, 3, 150), (255, 133, 3, 150),
                      (149, 0, 255, 150), (76, 128, 4, 150), (82, 27, 27, 150),
                      (101, 128, 125, 150), (5, 176, 249, 150)]

        # Log y axis; curve data is passed pre-logged (np.log10(W)) below.
        self.blastline_canvas.setLogMode(False, True)
        if unc == 'none':
            self.nominal = pg.PlotCurveItem(x=p_ratio,
                                            y=np.log10(W),
                                            pen=colour[self.count - 1],
                                            name='Fragmentation {:}'.format(
                                                self.count))
            self.blastline_canvas.addItem(self.nominal)
            # Overpressure at which the yield estimate is marked.
            self.p_rat = tryFloat(self.overpressure_edits.text())
        # if unc == 'min':
        #     self.min = pg.PlotCurveItem(x=p_ratio, y=np.log10(W), pen=ptb_colour[self.count-1], name='Fragmentation {:}'.format(self.count))
        #     self.blastline_canvas.addItem(self.min)
        #     self.p_rat = tryFloat(self.overpressure_min_edits.text())
        # if unc == 'max':
        #     self.max = pg.PlotCurveItem(x=p_ratio, y=np.log10(W), pen=ptb_colour[self.count-1], name='Fragmentation {:}'.format(self.count))
        #     self.blastline_canvas.addItem(self.max)
        #     self.p_rat = tryFloat(self.overpressure_max_edits.text())
        # try:
        #     pfill = pg.FillBetweenItem(self.min, self.max, brush=ptb_colour[self.count-1])
        #     self.blastline_canvas.addItem(pfill)
        #     self.min = None
        #     self.max = None
        # except:
        #     pass

        self.blastline_canvas.setLabel('bottom', "Overpressure", units='Pa')
        self.blastline_canvas.setLabel('left', "Yield", units='J')
        # self.blastline_canvas.addItem(pg.InfiniteLine(pos=(self.p_rat, 0), angle=90))

        # Yield estimate at the chosen overpressure.
        # NOTE(review): if self.p_rat is None (unparsable field) this depends
        # on how inverse_gunc handles None -- confirm.
        W_final = self.inverse_gunc(self.p_rat)
        # NOTE(review): the curve and the text label use log10 values, but this
        # marker passes raw W_final. Confirm the scatter item is log-transformed
        # by the canvas; otherwise it is drawn at the wrong height.
        self.blastline_canvas.scatterPlot(x=[self.p_rat],
                                          y=[W_final],
                                          pen=(255, 255, 255),
                                          brush=(255, 255, 255),
                                          size=10,
                                          pxMode=True)

        print('Blast Estimate     {:.2E} J'.format(W_final))

        if unc == 'none':
            # Label the estimate next to the marker (y given in log10 units).
            txt = pg.TextItem("Frag {:}: {:.2E} J".format(self.count, W_final))
            txt.setPos(self.p_rat, np.log10(W_final))
            self.blastline_canvas.addItem(txt)
        self.blastline_canvas.setTitle(
            'Fragmentation Yield Curves for a Given Overpressure')