Example #1
0
class Photon_Counter(QMainWindow):
    """Main window for a pulsed delay-scan acquisition ("T1") on an NI-DAQmx card.

    The left column holds the timing parameters.  "Start" builds a digital
    output schedule (two DO lines), arms the DAQ tasks and starts a
    zero-interval QTimer.  On each tick the timer reads one full analog
    acquisition and updates a running average of the signal versus delay
    time, which is plotted in the centre.
    """

    def __init__(self):
        super().__init__()

        # Sample-clock frequency (Hz) and longest delay of the scan (s)
        self.f_acq = 1E5
        self.t_acq = 15E-3

        # Counter pulse-train frequency (Hz), pulse duration (s), readout window (s)
        self.f_pulse = 5E6
        self.t_pulse = 1E-3
        self.t_lect = 4E-4

        # Number of delay points in the scan
        self.n_points = 301

        self.f_uw = 2800  # MHz
        self.level_uw = 10  # dB
        self.t_uw = 10  # us here; update_value() converts it to seconds

        # Slice of the delay axis that is displayed (and therefore fitted).
        # NOTE(review): n_lect_max = -1 excludes the last delay point.
        self.n_lect_min = 0
        self.n_lect_max = -1

        # The plot is redrawn at most once every `refresh_rate` seconds
        self.time_last_refresh = time.time()
        self.refresh_rate = 0.1

        self.setWindowTitle("T1")

        ## Build the graphical interface ##

        self.main = QWidget()
        self.setCentralWidget(self.main)

        layout = QHBoxLayout()
        vbox_center = QVBoxLayout()
        vbox_left = QVBoxLayout()
        vbox_right = QVBoxLayout()

        layout.addLayout(vbox_left)
        layout.addLayout(vbox_center)
        layout.addLayout(vbox_right)
        self.main.setLayout(layout)

        def add_field(text, value):
            """Append a label + line-edit pair to the left column; return both."""
            label = QLabel(text)
            edit = QLineEdit(value)
            vbox_left.addWidget(label)
            vbox_left.addWidget(edit)
            return label, edit

        # Input fields on the left
        self.labelf_acq, self.lectf_acq = add_field(
            "f_acq(max 500 KHz)", '%3.2E' % self.f_acq)
        self.labelt_acq, self.lectt_acq = add_field("t_acq", str(self.t_acq))
        vbox_left.addStretch(1)

        self.labelf_pulse, self.lectf_pulse = add_field(
            "f_pulse (max 25 MHz)", '%3.2E' % self.f_pulse)
        self.labelt_pulse, self.lectt_pulse = add_field(
            "t_pulse", str(self.t_pulse))
        self.labelt_lect, self.lectt_lect = add_field(
            "t_lect", str(self.t_lect))
        vbox_left.addStretch(1)

        self.labeln_points, self.lectn_points = add_field(
            "n_points", str(self.n_points))
        vbox_left.addStretch(1)

        self.labelf_uw, self.lectf_uw = add_field(
            "frequency uw (MHz)", str(self.f_uw))
        self.labellevel_uw, self.lectlevel_uw = add_field(
            "level uw (dB)", str(self.level_uw))
        self.labelt_uw, self.lectt_uw = add_field("t_uw (us)", str(self.t_uw))
        vbox_left.addStretch(1)

        # Buttons on the right
        self.stop = QPushButton('Stop')
        self.start = QPushButton('Start')
        self.keep_button = QPushButton('Keep trace')
        self.clear_button = QPushButton('Clear Last Trace')
        self.fit_button = QPushButton('Fit')
        self.normalize_cb = QCheckBox('Normalize')

        self.labelIter = QLabel("iter # 0")

        vbox_right.addWidget(self.normalize_cb)
        vbox_right.addStretch(1)
        vbox_right.addWidget(self.start)
        vbox_right.addWidget(self.stop)
        vbox_right.addStretch(1)
        vbox_right.addWidget(self.keep_button)
        vbox_right.addWidget(self.clear_button)
        vbox_right.addStretch(1)
        vbox_right.addWidget(self.fit_button)
        vbox_right.addStretch(1)
        vbox_right.addWidget(self.labelIter)

        self.stop.setEnabled(False)

        # Plot in the middle
        self.dynamic_canvas = FigureCanvas(Figure(figsize=(30, 10)))
        vbox_center.addStretch(1)
        vbox_center.addWidget(self.dynamic_canvas)
        self.addToolBar(Qt.BottomToolBarArea,
                        MyToolbar(self.dynamic_canvas, self))

        ## Matplotlib setup ##

        self.dynamic_ax = self.dynamic_canvas.figure.subplots()

        # Placeholder data; update_value() replaces it with the delay axis
        self.t = np.linspace(0, 100, 100)
        self.y = np.zeros(100)
        self.dynamic_line, = self.dynamic_ax.plot(self.t, self.y)

        self.dynamic_ax.set_xlabel('time(s)')
        self.dynamic_ax.set_ylabel('PL(counts/s)')

        # Button actions
        self.start.clicked.connect(self.start_measure)
        self.stop.clicked.connect(self.stop_measure)
        self.keep_button.clicked.connect(self.keep_trace)
        self.clear_button.clicked.connect(self.clear_trace)
        self.fit_button.clicked.connect(self.auto_fit)

        ## Timer setup ##

        self.update_value()
        # interval=0: update_canvas runs whenever the event loop is idle
        self.timer = QTimer(self, interval=0)
        self.timer.timeout.connect(self.update_canvas)

    def keep_trace(self):
        """Freeze a copy of the live trace on the axes and redraw."""
        self.dynamic_ax.plot(np.array(self.dynamic_line.get_xdata()),
                             np.array(self.dynamic_line.get_ydata()))
        # Without a redraw the kept trace is invisible while no measurement runs
        self.dynamic_canvas.draw()

    def clear_trace(self):
        """Remove the most recently added line, never the live trace."""
        lines = self.dynamic_ax.get_lines()
        line = lines[-1]
        if line != self.dynamic_line:
            line.remove()
        self.dynamic_ax.figure.canvas.draw()

    def auto_fit(self):
        """Fit ``Amp * exp(-x / tau) + ss`` to the displayed trace and overlay it.

        The line data is already restricted to ``[n_lect_min:n_lect_max]``
        (see update_value/update_canvas), so it is used as-is rather than
        being sliced a second time.
        """
        from scipy.optimize import curve_fit

        x = np.asarray(self.dynamic_line.get_xdata(), dtype=float)
        y = np.asarray(self.dynamic_line.get_ydata(), dtype=float)
        if len(x) < 3:
            # curve_fit needs at least as many points as fitted parameters
            print('Not enough points to fit')
            return

        def decay(x, Amp, ss, tau):
            return Amp * np.exp(-x / tau) + ss

        # Initial guesses: peak-to-peak amplitude, last value as the asymptote,
        # span of the first tenth of the trace as the time constant.
        amp0 = max(y) - min(y)
        ss0 = y[-1]
        tau0 = x[len(x) // 10] - x[0]
        if tau0 == 0:
            # Short traces give a zero span, i.e. a division by zero in decay()
            tau0 = x[-1] - x[0]

        try:
            popt, _ = curve_fit(decay, x, y, p0=[amp0, ss0, tau0])
        except (RuntimeError, ValueError) as err:
            # No convergence / non-finite data: report instead of killing the slot
            print('Fit failed:', err)
            return

        self.dynamic_ax.plot(x, decay(x, *popt), label='tau=%4.3e' % popt[2])
        self.dynamic_ax.legend()
        self.dynamic_canvas.draw()

    def update_value(self):
        """Read the input fields and rebuild the acquisition schedule.

        Sets the sample counts, the delay axis ``self.t``, the two digital
        output sequences (``gate_signal``, ``switch_signal``), the readout
        windows ``bornes_lect``, resets the averages and the live line.

        The acquisition is played twice: the first half never asserts the
        switch line, the second half asserts it for ``n_uw`` samples right
        before every readout window.

        Raises:
            ValueError: if a field cannot be parsed, if the scan is empty, if
                the readout window is shorter than one sample, or if the
                microwave pulse is longer than the first delay step.
        """
        self.t_acq = float(self.lectt_acq.text())
        self.f_acq = float(self.lectf_acq.text())
        self.t_pulse = float(self.lectt_pulse.text())
        self.f_pulse = float(self.lectf_pulse.text())
        self.t_lect = float(self.lectt_lect.text())
        self.n_points = int(self.lectn_points.text())
        self.f_uw = float(self.lectf_uw.text())
        self.level_uw = float(self.lectlevel_uw.text())
        self.t_uw = float(self.lectt_uw.text()) * 1e-6  # field is in us

        if self.n_points < 1:
            raise ValueError("n_points must be at least 1")

        # Durations expressed in samples
        self.n_pulse = int(self.f_pulse * self.t_pulse)  # counter pulses per trigger
        self.n_pulse_acq = int(self.f_acq * self.t_pulse)
        self.n_lect = int(self.f_acq * self.t_lect)
        self.n_uw = max(int(self.t_uw * self.f_acq), 1)

        if self.n_lect < 1:
            raise ValueError("t_lect is shorter than one sample period")

        # Delay axis, excluding zero delay
        self.t = np.linspace(0, self.t_acq, self.n_points + 1)[1:]
        self.n_step_time = [int(t * self.f_acq) for t in self.t]
        self.n_acq = sum(self.n_step_time) + self.n_pulse_acq * self.n_points

        # Sample index where readout i starts: i pulses plus the first i+1
        # delays.  A running sum avoids re-summing the prefix for every i.
        starts = []
        elapsed = 0
        for i, step in enumerate(self.n_step_time):
            elapsed += step
            starts.append(i * self.n_pulse_acq + elapsed)

        gate = [False] * self.n_acq
        for start in starts:
            gate[start] = True
        self.gate_signal = gate * 2

        self.bornes_lect = [[start, start + self.n_lect] for start in starts]

        print(self.n_step_time[0], self.n_uw)
        if self.n_step_time[0] < self.n_uw:
            print("Pulse micro onde trop longue")
            raise ValueError("microwave pulse is longer than the first delay step")

        # Switch line high for the n_uw samples that end at each readout start
        switch = [False] * self.n_acq
        for start in starts:
            switch[start - self.n_uw:start] = [True] * self.n_uw
        self.switch_signal = [False] * self.n_acq + switch

        # Tried workaround: take the first point at the end of the polarization
        # pulse rather than at its start.  It did not work well.
        # self.bornes_lect[0]=[self.n_pulse_acq-self.n_lect,self.n_pulse_acq]

        self.y_1 = np.zeros(self.n_points)
        self.y_2 = np.zeros(self.n_points)

        self.dynamic_line.set_data(self.t[self.n_lect_min:self.n_lect_max],
                                   self.y_1[self.n_lect_min:self.n_lect_max])

        self.set_lim('x')
        self.dynamic_canvas.draw()

    def start_measure(self):
        """Handle "Start": read the fields, arm the DAQ tasks, start the timer."""
        try:
            self.update_value()
        except ValueError as err:
            # Invalid input: keep the window alive and the Start button usable
            print(err)
            return

        self.start.setEnabled(False)
        self.stop.setEnabled(True)

        try:
            self._create_tasks()
        except Exception as err:
            # Release whatever was created and restore the buttons
            print('Could not start acquisition:', err)
            self.stop_measure()
            return

        self.repeat = 1
        self.timer.start()

    def _create_tasks(self):
        """Create and configure the counter, digital-output and analog-input tasks."""
        # Retriggerable finite pulse train on ctr0, started by edges on PFI9
        self.laser_trigg = nidaqmx.Task()
        self.laser_trigg.co_channels.add_co_pulse_chan_freq('Dev1/ctr0',
                                                            freq=self.f_pulse)
        self.laser_trigg.timing.cfg_implicit_timing(
            sample_mode=nidaqmx.constants.AcquisitionType.FINITE,
            samps_per_chan=self.n_pulse)
        self.laser_trigg.triggers.start_trigger.cfg_dig_edge_start_trig(
            '/Dev1/PFI9')
        self.laser_trigg.triggers.start_trigger.retriggerable = True
        self.laser_trigg.start()

        # Two digital lines clocked with the AI sample clock rate and started by
        # the AI start trigger, so both sequences stay sample-aligned with reads
        self.logic_out = nidaqmx.Task()
        self.logic_out.do_channels.add_do_chan('Dev1/port0/line7')
        self.logic_out.do_channels.add_do_chan('Dev1/port0/line3')
        self.logic_out.timing.cfg_samp_clk_timing(
            self.f_acq,
            sample_mode=nidaqmx.constants.AcquisitionType.FINITE,
            samps_per_chan=self.n_acq * 2)
        self.logic_out.triggers.start_trigger.cfg_dig_edge_start_trig(
            '/Dev1/ai/StartTrigger')
        self.logic_out.triggers.start_trigger.retriggerable = True
        # One list per channel, in the order the channels were added
        self.logic_out.write([self.gate_signal, self.switch_signal])
        self.logic_out.start()

        # Not started explicitly: read() starts the finite task on each call
        self.read = nidaqmx.Task()
        self.read.ai_channels.add_ai_voltage_chan("Dev1/ai11")
        self.read.timing.cfg_samp_clk_timing(
            self.f_acq,
            sample_mode=nidaqmx.constants.AcquisitionType.FINITE,
            samps_per_chan=self.n_acq * 2)

    def set_lim(self, axes='both', line=None, ax=None):
        """Fit the limits of ``ax`` to the data of ``line`` with a 1 % margin.

        Args:
            axes: 'x', 'y' or 'both'.
            line: Line2D to fit; defaults to the live trace.
            ax: Axes to adjust; defaults to the main axes.
        """
        if line is None:
            line = self.dynamic_line
        if ax is None:
            ax = self.dynamic_ax

        x = line.get_xdata()
        y = line.get_ydata()
        if len(x) == 0 or len(y) == 0:
            return  # nothing to fit yet; min()/max() would raise

        xmin, xmax = min(x), max(x)
        ymin, ymax = min(y), max(y)
        dx = 0.01 * (xmax - xmin)
        dy = 0.01 * (ymax - ymin)
        if axes in ('both', 'x'):
            ax.set_xlim([xmin - dx, xmax + dx])
        if axes in ('both', 'y'):
            ax.set_ylim([ymin - dy, ymax + dy])

    def _window_means(self, samples):
        """Return the mean of ``samples`` over each readout window."""
        return np.array([sum(samples[lo:hi]) / self.n_lect
                         for lo, hi in self.bornes_lect])

    def update_canvas(self):
        """Timer slot: read one acquisition, update averages, maybe redraw."""
        lecture = np.array(self.read.read(self.n_acq * 2))

        # First half: switch line never asserted; second half: asserted
        PL_1 = self._window_means(lecture[:self.n_acq])
        PL_2 = self._window_means(lecture[self.n_acq:])

        # Incremental mean over the `repeat` acquisitions read so far
        weight = 1 / self.repeat
        self.y_1 = self.y_1 * (1 - weight) + PL_1 * weight
        self.y_2 = self.y_2 * (1 - weight) + PL_2 * weight

        self.repeat += 1

        self.y = self.y_1

        if time.time() - self.time_last_refresh > self.refresh_rate:
            self.time_last_refresh = time.time()

            peak = max(self.y)
            if self.normalize_cb.isChecked() and peak != 0:
                ytoplot = self.y / peak
            else:
                ytoplot = self.y

            self.dynamic_line.set_ydata(
                ytoplot[self.n_lect_min:self.n_lect_max])
            self.set_lim()

            self.dynamic_canvas.draw()

            # repeat was already incremented: repeat - 1 acquisitions averaged
            self.labelIter.setText("iter # %i" % (self.repeat - 1))

    def stop_measure(self):
        """Stop the timer, close every DAQ task that exists, restore buttons."""
        self.timer.stop()
        for name in ('read', 'laser_trigg', 'logic_out'):
            task = getattr(self, name, None)
            if task is None:
                continue  # never created (or already closed)
            try:
                task.close()
            except Exception as err:
                # Best-effort cleanup: report and keep closing the others
                print('Could not close %s: %s' % (name, err))
            setattr(self, name, None)

        self.stop.setEnabled(False)
        self.start.setEnabled(True)
Example #2
0
class FDTD(QMainWindow, Ui_MainWindow):
    """
    FDTD object: a Qt window that runs a 2-D scattered-field FDTD simulation
    (TE or TM) on a geometry read from a cell file, surrounded by PML layers,
    and plots the total electric-field magnitude after every time step.
    """
    def __init__(self):

        # Plain super(): super(self.__class__, self) recurses forever in subclasses
        super().__init__()

        self.setupUi(self)

        # Connect to the button
        self.start_fdtd.clicked.connect(self._start_fdtd)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.fig = fig
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Colour bar of the latest plot; None until the first plot is drawn
        self.cbar = None

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # TE or TM
        self.mode = None

        # Angle of the incident field (degrees)
        self.incident_angle = None

        # Number of time steps
        self.number_of_time_steps = None

        # Width of the Gaussian pulse in time steps
        self.gaussian_pulse_width = None

        # Amplitude of the Gaussian pulse
        self.gaussian_pulse_amplitude = None

        # Number of PML layers
        self.number_of_pml = None

        # Grid size (cells, PML included) and cell size
        self.nx = None
        self.ny = None

        self.dx = None
        self.dy = None

        # Time step (set from the Courant condition in initialize)
        self.dt = None

        # Per-cell relative permeability, relative permittivity, conductivity
        self.mu_r = None
        self.eps_r = None
        self.sigma = None

        # TE fields: incident E (+ time derivative), scattered E and Hz,
        # and the incident dHz/dt
        self.exi = None
        self.eyi = None
        self.dexi = None
        self.deyi = None
        self.exs = None
        self.eys = None
        self.hzs = None
        self.dhzi = None

        # TM fields: scattered H and Ez, incident dH/dt, incident Ez (+ dEz/dt)
        self.hxs = None
        self.hys = None
        self.dhxi = None
        self.dhyi = None
        self.ezs = None
        self.ezi = None
        self.dezi = None

        # Per-cell update coefficients (computed in initialize)
        self.esctc = None
        self.eincc = None
        self.edevcn = None
        self.ecrlx = None
        self.ecrly = None
        self.dtmdx = None
        self.dtmdy = None
        self.hdhvcn = None

        # Incident E-field components (TE)
        self.amplitude_y = None
        self.amplitude_x = None

        # Geometry file, resolved relative to the directory of this source file
        self.geometry_file = '../Libs/rcs/fdtd.cell'

    def _start_fdtd(self):
        """
        Read the simulation parameters from the form and run the simulation.
        :return:
        """
        # Get the parameters from the form
        self.mode = self._mode.currentText()
        self.incident_angle = float(self._incident_angle.text())
        self.number_of_time_steps = int(self._number_of_time_steps.text())
        self.gaussian_pulse_width = int(self._gaussian_pulse_width.text())
        self.gaussian_pulse_amplitude = float(
            self._gaussian_pulse_amplitude.text())
        self.number_of_pml = int(self._number_of_pml.text())

        # Read the geometry, compute the coefficients and run the time loop
        self.initialize()

    def initialize(self):
        """
        Allocate the field arrays, compute the update coefficients and run the
        selected mode ('TE', anything else runs TM).
        :return:
        """
        # Read the geometry file (sets nx, ny, dx, dy, mu_r, eps_r, sigma)
        self.read_geometry_file()

        shape = [self.nx, self.ny]

        # Initialize the fields
        if self.mode == 'TE':
            self.exi = zeros(shape)
            self.eyi = zeros(shape)
            self.dexi = zeros(shape)
            self.deyi = zeros(shape)
            self.exs = zeros(shape)
            self.eys = zeros(shape)
            self.hzs = zeros(shape)
            self.dhzi = zeros(shape)
        else:
            self.hxs = zeros(shape)
            self.hys = zeros(shape)
            self.dhxi = zeros(shape)
            self.dhyi = zeros(shape)
            self.ezs = zeros(shape)
            self.ezi = zeros(shape)
            self.dezi = zeros(shape)

        # Initialize the coefficient arrays
        self.esctc = zeros(shape)
        self.eincc = zeros(shape)
        self.edevcn = zeros(shape)
        self.ecrlx = zeros(shape)
        self.ecrly = zeros(shape)
        self.dtmdx = zeros(shape)
        self.dtmdy = zeros(shape)
        self.hdhvcn = zeros(shape)

        # Maximum time step allowed by the 2-D Courant stability condition
        self.dt = 1.0 / (c * (sqrt(1.0 / (self.dx**2) + 1.0 / (self.dy**2))))

        for i in range(self.nx):
            for j in range(self.ny):
                eps = epsilon_0 * self.eps_r[i][j]
                mu = mu_0 * self.mu_r[i][j]
                # Denominator shared by every electric-field coefficient
                e_denom = eps + self.sigma[i][j] * self.dt
                self.esctc[i][j] = eps / e_denom
                self.eincc[i][j] = self.sigma[i][j] * self.dt / e_denom
                self.edevcn[i][j] = self.dt * (eps - epsilon_0) / e_denom
                self.ecrlx[i][j] = self.dt / (e_denom * self.dx)
                self.ecrly[i][j] = self.dt / (e_denom * self.dy)
                self.dtmdx[i][j] = self.dt / (mu * self.dx)
                self.dtmdy[i][j] = self.dt / (mu * self.dy)
                self.hdhvcn[i][j] = self.dt * (mu - mu_0) / mu

        # Amplitude of incident field components
        self.amplitude_x = -self.gaussian_pulse_amplitude * sin(
            radians(self.incident_angle))
        self.amplitude_y = self.gaussian_pulse_amplitude * cos(
            radians(self.incident_angle))

        # Run the selected mode
        if self.mode == 'TE':
            self.te()
        else:
            self.tm()

    def read_geometry_file(self):
        """
        Read the FDTD geometry file and build the material grids with PML.

        File layout: comment line, "nx ny", comment line, "dx dy", then one
        "mu_r eps_r sigma" line per interior cell (x outer loop, y inner loop).
        :return:
        """
        # Get the base path for the file
        base_path = Path(__file__).parent

        with open((base_path / self.geometry_file).resolve(), 'r') as file:

            # Header Line 1: Comment
            _ = file.readline()

            # Header Line 2: nx ny
            line_list = file.readline().split()
            self.nx = int(line_list[0])
            self.ny = int(line_list[1])

            # Header Line 3: Comment
            _ = file.readline()

            # Header Line 4: dx dy
            line_list = file.readline().split()
            self.dx = float(line_list[0])
            self.dy = float(line_list[1])

            # Grow the grid by the PML layers on every side
            npml = self.number_of_pml
            self.nx += 2 * npml
            self.ny += 2 * npml

            self.mu_r = zeros([self.nx, self.ny])
            self.eps_r = zeros([self.nx, self.ny])
            self.sigma = zeros([self.nx, self.ny])

            # Set up the maximum conductivities
            sigma_max_x = -3.0 * epsilon_0 * c * log(1e-5) / (
                2.0 * self.dx * npml)
            sigma_max_y = -3.0 * epsilon_0 * c * log(1e-5) / (
                2.0 * self.dy * npml)

            # Quadratic conductivity profile, increasing towards the outer edge
            sigma_v = [((m + 0.5) / (npml + 0.5))**2 for m in range(npml)]

            # Back region
            for i in range(self.nx):
                for j in range(npml):
                    self.mu_r[i][j] = 1.0
                    self.eps_r[i][j] = 1.0
                    self.sigma[i][j] = sigma_max_y * sigma_v[npml - 1 - j]

            # Front region
            for i in range(self.nx):
                for j in range(self.ny - npml, self.ny):
                    self.mu_r[i][j] = 1.0
                    self.eps_r[i][j] = 1.0
                    self.sigma[i][j] = sigma_max_y * sigma_v[
                        j - (self.ny - npml)]

            # Left region (+= so corners combine the x and y profiles)
            for i in range(npml):
                for j in range(self.ny):
                    self.mu_r[i][j] = 1.0
                    self.eps_r[i][j] = 1.0
                    self.sigma[i][j] += sigma_max_x * sigma_v[npml - 1 - i]

            # Right region
            for i in range(self.nx - npml, self.nx):
                for j in range(self.ny):
                    self.mu_r[i][j] = 1.0
                    self.eps_r[i][j] = 1.0
                    self.sigma[i][j] += sigma_max_x * sigma_v[
                        i - (self.nx - npml)]

            # Read the interior geometry: mu_r, eps_r, sigma per cell
            for i in range(npml, self.nx - npml):
                for j in range(npml, self.ny - npml):
                    line_list = file.readline().split()
                    self.mu_r[i][j] = float(line_list[0])
                    self.eps_r[i][j] = float(line_list[1])
                    self.sigma[i][j] = float(line_list[2])
        # The `with` block closes the file; no explicit close() is needed

    def _time_march(self, update_e, update_h):
        """
        Leapfrog time loop shared by TE and TM.
        :param update_e: Callable(t) updating the scattered electric field.
        :param update_h: Callable(t) updating the scattered magnetic field.
        :return:
        """
        # Start at time = 0
        t = 0.0

        for n in range(self.number_of_time_steps):
            # E at integer steps, H half a step later
            update_e(t)
            t += 0.5 * self.dt
            update_h(t)
            t += 0.5 * self.dt

            self._update_canvas()

            # Progress
            print('{} of {} time steps'.format(n + 1,
                                               self.number_of_time_steps))

    def te(self):
        """
        TE mode.
        :return:
        """
        self._time_march(self.escattered_te, self.hscattered_te)

    def tm(self):
        """
        TM mode.
        :return:
        """
        self._time_march(self.escattered_tm, self.hscattered_tm)

    def _incident_pulse(self, t):
        """
        Yield the truncated Gaussian incident pulse and its time derivative at
        every non-PML cell.
        :param t: Time (s).
        :return: Generator of (i, j, a, a_prime); both values are 0 outside
                 0 <= tau <= 2 * gaussian_pulse_width * dt.
        """
        # Decay rate determined by the Gaussian pulse width
        alpha = (1.0 / (self.dt * self.gaussian_pulse_width / 4.0))**2

        # Truncation window of the pulse and its centre
        period = 2.0 * self.dt * self.gaussian_pulse_width
        centre = self.gaussian_pulse_width * self.dt

        # Spatial delay of each cell along the propagation direction
        x_disp = -cos(radians(self.incident_angle))
        y_disp = -sin(radians(self.incident_angle))

        # Offset that keeps the travelled distance non-negative over the grid
        delay = 0.0
        if x_disp < 0:
            delay -= x_disp * (self.nx - 2.0) * self.dx
        if y_disp < 0:
            delay -= y_disp * (self.ny - 2.0) * self.dy

        npml = self.number_of_pml
        for i in range(npml, self.nx - npml):
            for j in range(npml, self.ny - npml):
                distance = i * self.dx * x_disp + j * self.dy * y_disp + delay
                tau = t - distance / c

                a = 0.0
                a_prime = 0.0
                if 0 <= tau <= period:
                    a = exp(-alpha * (tau - centre)**2)
                    a_prime = a * (-2.0 * alpha * (tau - centre))

                yield i, j, a, a_prime

    def eincident_te(self, t):
        """
        Calculate the incident electric field (and its time derivative) for TE mode.
        :param t: Time (s).
        :return:
        """
        for i, j, a, a_prime in self._incident_pulse(t):
            self.exi[i][j] = self.amplitude_x * a
            self.dexi[i][j] = self.amplitude_x * a_prime
            self.eyi[i][j] = self.amplitude_y * a
            self.deyi[i][j] = self.amplitude_y * a_prime

    def escattered_te(self, t):
        """
        Calculate the scattered electric field for TE mode.
        :param t: Time (s).
        :return:
        """
        # Calculate the incident electric field
        self.eincident_te(t)

        # Update the x-component electric scattered field.  j starts at 1: the
        # update reads hzs[i][j - 1], and j = 0 would wrap to the last column.
        for i in range(self.nx - 1):
            for j in range(1, self.ny - 1):
                self.exs[i][j] = (self.exs[i][j] * self.esctc[i][j]
                                  - self.eincc[i][j] * self.exi[i][j]
                                  - self.edevcn[i][j] * self.dexi[i][j]
                                  + (self.hzs[i][j] - self.hzs[i][j - 1])
                                  * self.ecrly[i][j])

        # Update the y-component electric scattered field (i starts at 1 for
        # the same reason: hzs[i - 1][j])
        for i in range(1, self.nx - 1):
            for j in range(self.ny - 1):
                self.eys[i][j] = (self.eys[i][j] * self.esctc[i][j]
                                  - self.eincc[i][j] * self.eyi[i][j]
                                  - self.edevcn[i][j] * self.deyi[i][j]
                                  - (self.hzs[i][j] - self.hzs[i - 1][j])
                                  * self.ecrlx[i][j])

    def hincident_te(self, t):
        """
        Calculate the time derivative of the incident magnetic field for TE mode.
        :param t: Time (s).
        :return:
        """
        eta = sqrt(mu_0 / epsilon_0)
        for i, j, _, a_prime in self._incident_pulse(t):
            self.dhzi[i][j] = self.gaussian_pulse_amplitude * a_prime / eta

    def hscattered_te(self, t):
        """
        Calculate the scattered magnetic field for TE mode.
        :param t: Time (s).
        :return:
        """
        # Calculate the incident magnetic field
        self.hincident_te(t)

        # Update the scattered magnetic field
        for i in range(self.nx - 1):
            for j in range(self.ny - 1):
                self.hzs[i][j] = (self.hzs[i][j]
                                  - (self.eys[i + 1][j] - self.eys[i][j])
                                  * self.dtmdx[i][j]
                                  + (self.exs[i][j + 1] - self.exs[i][j])
                                  * self.dtmdy[i][j]
                                  - self.hdhvcn[i][j] * self.dhzi[i][j])

    def eincident_tm(self, t):
        """
        Calculate the incident electric field (and its time derivative) for TM mode.
        :param t: Time (s).
        :return:
        """
        for i, j, a, a_prime in self._incident_pulse(t):
            self.ezi[i][j] = self.gaussian_pulse_amplitude * a
            self.dezi[i][j] = self.gaussian_pulse_amplitude * a_prime

    def escattered_tm(self, t):
        """
        Calculate the scattered electric field for TM mode.
        :param t: Time (s).
        :return:
        """
        # Calculate the incident electric field
        self.eincident_tm(t)

        # Update the z-component electric scattered field
        for i in range(1, self.nx - 1):
            for j in range(1, self.ny - 1):
                self.ezs[i][j] = (self.ezs[i][j] * self.esctc[i][j]
                                  - self.eincc[i][j] * self.ezi[i][j]
                                  - self.edevcn[i][j] * self.dezi[i][j]
                                  + (self.hys[i][j] - self.hys[i - 1][j])
                                  * self.ecrlx[i][j]
                                  - (self.hxs[i][j] - self.hxs[i][j - 1])
                                  * self.ecrly[i][j])

    def hincident_tm(self, t):
        """
        Calculate the time derivative of the incident magnetic field for TM mode.
        :param t: Time (s).
        :return:
        """
        eta = sqrt(mu_0 / epsilon_0)
        # NOTE(review): both components get the same value, with no
        # incidence-angle factor; confirm against the reference formulation.
        for i, j, _, a_prime in self._incident_pulse(t):
            self.dhxi[i][j] = self.gaussian_pulse_amplitude * a_prime / eta
            self.dhyi[i][j] = self.gaussian_pulse_amplitude * a_prime / eta

    def hscattered_tm(self, t):
        """
        Calculate the scattered magnetic field for TM mode.
        :param t: Time (s).
        :return:
        """
        # Calculate the incident magnetic field
        self.hincident_tm(t)

        # NOTE(review): unlike hscattered_te, the incident term
        # (- hdhvcn * dhxi / dhyi) is not applied here; it only matters for
        # cells with mu_r != 1.  Confirm whether that is intended.

        # Hx is driven by the y-difference of Ez, hence dt / (mu * dy)
        for i in range(1, self.nx - 1):
            for j in range(self.ny - 1):
                self.hxs[i][j] = self.hxs[i][j] - (
                    self.ezs[i][j + 1] - self.ezs[i][j]) * self.dtmdy[i][j]

        # Hy is driven by the x-difference of Ez, hence dt / (mu * dx)
        for i in range(self.nx - 1):
            for j in range(1, self.ny - 1):
                self.hys[i][j] = self.hys[i][j] + (
                    self.ezs[i + 1][j] - self.ezs[i][j]) * self.dtmdx[i][j]

    def _update_canvas(self):
        """
        Redraw the magnitude of the total electric field for the current step.
        :return:
        """
        # Remove the previous colour bar so they do not accumulate
        if self.cbar is not None:
            self.cbar.remove()
            self.cbar = None

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Total (incident + scattered) field, indexed [i (x), j (y)].
        # Same mode test as initialize(): 'TE', anything else is TM.
        etotal = zeros([self.nx, self.ny])

        if self.mode == 'TE':
            for i in range(self.nx):
                for j in range(self.ny):
                    extotal = self.exi[i][j] + self.exs[i][j]
                    eytotal = self.eyi[i][j] + self.eys[i][j]
                    etotal[i][j] = sqrt(extotal**2 + eytotal**2)
        else:
            for i in range(self.nx):
                for j in range(self.ny):
                    etotal[i][j] = self.ezi[i][j] + self.ezs[i][j]

        # x and y grid for plotting; meshgrid returns arrays of shape (ny, nx)
        x = linspace(0, self.nx * self.dx, self.nx)
        y = linspace(0, self.ny * self.dy, self.ny)

        x_grid, y_grid = meshgrid(x, y)

        # Transpose so the field is (ny, nx) like the grids (rows = y)
        im = self.axes1.pcolor(x_grid,
                               y_grid,
                               abs(etotal).T,
                               cmap="jet",
                               vmin=0,
                               vmax=self.gaussian_pulse_amplitude)
        self.cbar = self.fig.colorbar(im,
                                      ax=self.axes1,
                                      orientation='vertical')
        self.cbar.set_label('Electric Field (V/m)', size=10)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Update the canvas
        self.my_canvas.draw()
        self.my_canvas.flush_events()
Example #3
0
class Shnidman(QMainWindow, Ui_MainWindow):
    """Window plotting the error of Shnidman's signal-to-noise approximation.

    The error (dB) is ``10*log10(single_pulse_snr(...)) - signal_to_noise(...)``
    evaluated over a range of detection probabilities. The form fields are:
    probability of detection as "low,high", probability of false alarm,
    number of pulses and target type. The plot refreshes whenever one of
    them is edited.
    """

    def __init__(self):
        # Zero-argument super() binds to Shnidman itself; the previous
        # super(self.__class__, self) recursed forever in any subclass.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes, when the user presses enter the form updates
        self.probability_of_detection.returnPressed.connect(
            self._update_canvas)
        self.probability_of_false_alarm.returnPressed.connect(
            self._update_canvas)
        self.number_of_pulses.returnPressed.connect(self._update_canvas)
        self.target_type.currentIndexChanged.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Evaluates the approximation error at 200 detection probabilities
        spanning the "low,high" range entered on the form.
        :return:
        """
        # Get the parameters from the form ("low,high" probability range)
        pd = self.probability_of_detection.text().split(',')
        pd_all = linspace(float(pd[0]), float(pd[1]), 200)
        pfa = float(self.probability_of_false_alarm.text())
        number_of_pulses = int(self.number_of_pulses.text())

        # Get the selected target type from the form
        target_type = self.target_type.currentText()

        # Calculate the error in the Shnidman approximation of signal to noise
        error = [
            10.0 *
            log10(single_pulse_snr(p, pfa, number_of_pulses, target_type)) -
            signal_to_noise(p, pfa, number_of_pulses, target_type)
            for p in pd_all
        ]

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Display the results
        self.axes1.plot(pd_all, error, '')

        # Set the plot title and labels
        self.axes1.set_title('Shnidman\'s Approximation', size=14)
        self.axes1.set_xlabel('Probability of Detection', size=12)
        self.axes1.set_ylabel('Signal to Noise Error (dB)', size=12)
        self.axes1.set_ylim(-1, 1)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
Exemple #4
0
class ReflectionTransmission(QMainWindow, Ui_MainWindow):
    """Window showing plane-wave reflection/transmission between two media.

    The form gives the frequency and, for each medium, its relative
    permittivity, relative permeability and conductivity. The window shows
    the critical and Brewster angles and plots |reflection| (left axis) and
    |transmission| (right axis) coefficients for TE and TM polarisation
    against the incident angle.
    """

    def __init__(self):
        # Zero-argument super() binds to ReflectionTransmission itself; the
        # previous super(self.__class__, self) recursed forever in subclasses.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes, when the user presses enter the form updates
        self.frequency.returnPressed.connect(self._update_results)
        self.relative_permittivity_1.returnPressed.connect(self._update_results)
        self.relative_permeability_1.returnPressed.connect(self._update_results)
        self.conductivity_1.returnPressed.connect(self._update_results)
        self.relative_permittivity_2.returnPressed.connect(self._update_results)
        self.relative_permeability_2.returnPressed.connect(self._update_results)
        self.conductivity_2.returnPressed.connect(self._update_results)

        # Set up a figure for the plotting canvas (second y axis shares x)
        fig = Figure()
        self.axes1 = fig.add_subplot(111)
        self.axes2 = self.axes1.twinx()
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea, NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_results()

    def _update_results(self):
        """
        Update the results when the user changes an input value.

        Stores the two-medium material arrays on ``self`` (they are reused by
        _update_canvas), fills in the critical and Brewster angles, then
        redraws the figure.
        :return:
        """
        # Get the parameters from the form: index 0 is medium 1, index 1 medium 2
        self.relative_permittivity = array([float(self.relative_permittivity_1.text()),
                                             float(self.relative_permittivity_2.text())])

        self.relative_permeability = array([float(self.relative_permeability_1.text()),
                                            float(self.relative_permeability_2.text())])

        self.conductivity = array([float(self.conductivity_1.text()), float(self.conductivity_2.text())])

        # Set up the key word args for the inputs
        kwargs = {'frequency':              float(self.frequency.text()),
                  'relative_permittivity':  self.relative_permittivity,
                  'relative_permeability':  self.relative_permeability,
                  'conductivity':           self.conductivity}

        # Calculate the critical and Brewster angles
        critical_angle = plane_waves.critical_angle(**kwargs)
        brewster_angle = plane_waves.brewster_angle(**kwargs)

        # Update the form with the results
        self.critical_angle.setText('{:.1f}'.format(critical_angle))
        self.brewster_angle.setText('{:.1f}'.format(brewster_angle))

        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Uses the material arrays stored by _update_results and 1000 incident
        angles from 0 to 90 degrees.
        :return:
        """
        # Set up the incident angles (radians)
        incident_angle = linspace(0., 0.5 * pi, 1000)

        # Set up the keyword args
        kwargs = {'frequency': float(self.frequency.text()),
                  'incident_angle': incident_angle,
                  'relative_permittivity': self.relative_permittivity,
                  'relative_permeability': self.relative_permeability,
                  'conductivity': self.conductivity}

        # Calculate the reflection and transmission coefficients
        reflection_coefficient_te, transmission_coefficient_te, reflection_coefficient_tm, \
        transmission_coefficient_tm = plane_waves.reflection_transmission(**kwargs)

        # Clear the axes for the updated plot
        self.axes1.clear()
        self.axes2.clear()

        # Display the reflection coefficients (raw strings: '\G' is not a
        # valid escape sequence and warns in a normal string literal)
        self.axes1.plot(degrees(incident_angle), abs(reflection_coefficient_te), 'b', label=r'|$\Gamma_{TE}$|')
        self.axes1.plot(degrees(incident_angle), abs(reflection_coefficient_tm), 'b--', label=r'|$\Gamma_{TM}$|')

        # Display the transmission coefficients
        self.axes2.plot(degrees(incident_angle), abs(transmission_coefficient_te), 'r', label='|$T_{TE}$|')
        self.axes2.plot(degrees(incident_angle), abs(transmission_coefficient_tm), 'r--', label='|$T_{TM}$|')

        # Set the plot title and labels
        self.axes1.set_title('Plane Wave Reflection and Transmission', size=14)
        self.axes1.set_xlabel('Incident Angle (degrees)', size=12)
        self.axes1.set_ylabel('|Reflection Coefficient|', size=12)
        self.axes2.set_ylabel('|Transmission Coefficient|', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)
        self.axes2.tick_params(labelsize=12)

        # Set the legend
        self.axes1.legend(loc='upper right', prop={'size': 10})
        self.axes2.legend(loc='upper left', prop={'size': 10})

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
class Calibration_Window(QMainWindow):
    '''Open a calibration window that allows the user to select the calibration
    style. Allows the user to click on location in graph to select the
    channels to calibrate and gets the required energy at that point.

    The upper plot shows the raw spectrum against channel number. A double
    right-click on it asks for the energy at that channel and records a
    calibration point. The lower plot previews the calibration built from
    the points entered so far. Double-clicking an entry in the
    "Calibration Values" dock removes that point.
    '''
    # Class-level defaults kept for backward compatibility. __init__ gives each
    # window its own containers: these are mutable and were previously shared
    # (and appended to in place) across every instance.
    counts = []
    channels = []
    calibration_lines = []
    energies = {}

    def __init__(self):
        super().__init__()
        self.counts = []
        self.channels = []
        self.calibration_lines = []  # marked channels, rounded to 0.01
        self.energies = {}  # '{:.2f}' channel string -> '{:.4f}' MeV string
        # Matplotlib callback ids, so each event handler is connected once.
        self._motion_cid = None
        self._click_cid = None
        self.font = QFont()
        self.font.setPointSize(12)
        self.size_policy = QSizePolicy.Expanding
        self.setWindowTitle('Energy Calibration')
        self.menu()
        self.geometry()
        self.mouse_tracking()
        self.showMaximized()
        self.show()

    @staticmethod
    def _channel_key(channel):
        '''Return the ``energies`` dictionary key for a channel position.'''
        return '{:.2f}'.format(channel)

    def _refresh_calibration_list(self):
        '''Rebuild the dock list from ``calibration_lines`` and ``energies``.'''
        self.loaded.removeRows(0, self.loaded.rowCount())
        for channel in self.calibration_lines:
            self.loaded.appendRow(
                QStandardItem('Ch: {:.2f}->Energy: {} MeV'.format(
                    channel, self.energies[self._channel_key(channel)])))

    def menu(self):
        '''Create the File menu actions (load, rebin, calibrate, save).'''
        self.menuFile = self.menuBar().addMenu('&File')
        self.load_new = QAction('&Load New Spectrum')
        self.load_new.triggered.connect(self.new_spectrum)
        self.load_new.setShortcut('Ctrl+O')
        self.load_new.setToolTip('Load a raw spectrum')

        self.rebin_action = QAction('&Rebin Data')
        self.rebin_action.triggered.connect(self.rebin)
        self.rebin_action.setEnabled(False)
        self.rebin_action.setShortcut('Ctrl+B')

        self.rebin_action_save = QAction('&Save Rebin Count')
        self.rebin_action_save.triggered.connect(self.save_rebinned)
        self.rebin_action_save.setEnabled(False)
        self.rebin_action_save.setShortcut('Ctrl+Shift+S')

        self.calibrateAction = QAction('&Calibrate')
        self.calibrateAction.triggered.connect(self.calibration)
        self.calibrateAction.setShortcut('Ctrl+C')
        self.calibrateAction.setDisabled(True)

        self.save = QAction('&Save Calibration')
        self.save.triggered.connect(self.save_)
        self.save.setShortcut('Ctrl+S')
        self.save.setDisabled(True)
        self.menuFile.addActions([
            self.load_new, self.rebin_action, self.calibrateAction, self.save,
            self.rebin_action_save
        ])

    def geometry(self):
        '''Setup the geometry: the calibration-values dock and the two plots.

        NOTE(review): this name shadows QWidget.geometry(); kept because it
        is part of this class's existing interface.
        '''
        self.added_values = QListView(self)
        self.added_values.setFont(self.font)
        self.added_values.setSizePolicy(self.size_policy, self.size_policy)
        self.added_values.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.added_ = QDockWidget('Calibration Values')
        self.added_.setWidget(self.added_values)
        self.addDockWidget(Qt.LeftDockWidgetArea, self.added_)

        self.loaded = QStandardItemModel()
        self.added_values.setModel(self.loaded)
        self.added_values.doubleClicked[QModelIndex].connect(self.update)

        self.calibrate = QPushButton('Calibrate')
        self.calibrate.setFont(self.font)
        self.calibrate.setSizePolicy(self.size_policy, self.size_policy)
        self.calibrate.clicked.connect(self.calibration)

        self.plot = QWidget()
        layout = QVBoxLayout()
        self.figure = Figure()
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas)
        self.plot.setLayout(layout)
        self.ax = self.canvas.figure.subplots()
        self.ax.set_yscale('log')
        self.ax.set_xlabel('Channel')
        self.ax.set_ylabel('Counts')
        self.figure.tight_layout()

        self.calibrated_plot = QWidget()
        layout1 = QVBoxLayout()
        self.figure1 = Figure()
        self.canvas1 = FigureCanvas(self.figure1)
        self.toolbar1 = NavigationToolbar(self.canvas1, self)
        layout1.addWidget(self.toolbar1)
        layout1.addWidget(self.canvas1)
        self.calibrated_plot.setLayout(layout1)
        self.ax1 = self.canvas1.figure.subplots()
        self.ax1.set_yscale('log')
        self.ax1.set_xlabel('Energy (MeV)')
        self.ax1.set_ylabel('Counts')
        self.ax1.set_title('Current Linear Calibration')
        self.figure1.tight_layout()

        main = QWidget()
        main_lay = QVBoxLayout()
        main_lay.addWidget(self.plot)
        main_lay.addWidget(self.calibrated_plot)
        main.setLayout(main_lay)
        self.setCentralWidget(main)

    def mouse_tracking(self):
        '''(Re)create the cursor line and channel readout on the channel plot.

        Called again after every ``ax.clear()`` (which discards these
        artists). The motion handler is connected only the first time;
        reconnecting on every redraw stacked duplicate handlers.
        '''
        self.lx = self.ax.axvline(color='k', linestyle='--')
        self.txt = self.ax.text(0.8, 0.9, "", transform=self.ax.transAxes)
        if getattr(self, '_motion_cid', None) is None:
            self._motion_cid = self.figure.canvas.mpl_connect(
                'motion_notify_event', self.mouse_move)

    def mouse_move(self, event):
        '''Move the cursor line and channel readout with the mouse.'''
        if not event.inaxes:
            return
        x = event.xdata
        # Line2D data must be a sequence (scalars are deprecated); an axvline
        # is two points sharing the same x.
        self.lx.set_xdata([x, x])
        self.txt.set_text('Channel: {:.2f}'.format(x))
        self.canvas.draw_idle()

    def mouse_click(self, event):
        '''Handle clicks on the channel plot.

        A double right-click (button 3) asks for the energy at the clicked
        channel and records it as a calibration point. Any double-click
        redraws the plots.
        '''
        if not event.inaxes:
            return
        if not event.dblclick:
            return
        if event.button == 3:
            e, ok = QInputDialog.getDouble(self, 'Energy', 'Energy:(MeV)',
                                           10, 0, 15, 10)
            if ok:
                channel = round(event.xdata, 2)
                key = self._channel_key(channel)
                # Re-marking a channel updates its energy rather than adding a
                # duplicate line that has no energy entry of its own.
                if key not in self.energies:
                    self.calibration_lines.append(channel)
                self.energies[key] = '{:.4f}'.format(e)
                self._refresh_calibration_list()
                if len(self.calibration_lines) >= 2:
                    self.calibrateAction.setEnabled(True)
        self.replot()

    def new_spectrum(self):
        '''Load a raw spectrum from a text/CSV file or an IAEA .spe file.

        Existing calibration points are discarded. For text/CSV files whose
        first line has more than one field, the user picks the column holding
        the counts. Rows whose value cannot be parsed (e.g. the header) are
        skipped, and a row's channel number is its line index in the file.
        '''
        fileName, ok = QFileDialog.getOpenFileName(
            self, 'Raw Spectrum', '',
            'Text File (*.txt);;Comma Seperate File (*.csv);;IAEA(*.spe)')
        if not (ok and fileName != ''):
            return
        # clear the list view to start calibrating again
        self.loaded.removeRows(0, self.loaded.rowCount())
        self.energies = {}
        self.calibration_lines = []
        self.counts = []
        self.channels = []
        # With the points cleared there is nothing to calibrate yet.
        self.calibrateAction.setEnabled(False)
        if '.spe' not in fileName.lower():
            with open(fileName, 'r') as f:
                f_data = f.readlines()

            first_line = f_data[0] if f_data else ''
            headers = first_line.split(sep=',')
            h_space = first_line.split(sep=' ')
            column = None
            if len(headers) > 1 or len(h_space) > 1:
                headers[-1] = headers[-1].split(sep='\n')[0]
                item, picked = QInputDialog.getItem(self, 'Select Data Header',
                                                    'Header:', headers, 0,
                                                    False)
                if picked:
                    column = headers.index(item)
            else:
                column = 0
            if column is not None:
                for i, line in enumerate(f_data):
                    try:
                        value = float(line.split(sep=',')[column])
                    except (ValueError, IndexError):
                        continue  # header or malformed row
                    self.counts.append(value)
                    self.channels.append(i)
        else:
            self.counts, self.channels = self.load_spe(fileName)
        self.rebin_action.setEnabled(True)
        self.rebin_action_save.setEnabled(True)
        self.replot(left_lim=0,
                    right_lim=len(self.channels) +
                    0.01 * len(self.channels))

    def find_peaks(self, x, width, distance):
        '''Locate peaks in *x* with ``scipy.signal.find_peaks``.

        Returns ``(peaks, e_res, left_sig, right_sig)``:
        - peaks: peak centres, taken as the midpoint of each peak's FWHM
          interpolation points;
        - e_res: the FWHM as a percentage of that centre;
        - left_sig / right_sig: centre minus / plus 4 sigma, where sigma is
          the Gaussian standard deviation derived from the FWHM.
        '''
        peaks, properties = signal.find_peaks(x,
                                              width=width,
                                              distance=distance)
        widths = properties['widths']  # the fwhm of the peak
        left = properties['left_ips']  # left point of the fwhm
        right = properties['right_ips']  # right point of the fwhm
        # Recalculate the peak location as the average of the left and right
        # fwhm points. Kept as floats: writing them back into scipy's integer
        # index array used to truncate them.
        centres = (left + right) / 2
        sigma = widths / (2 * np.sqrt(2 * np.log(2)))  # standard deviation
        left_sig = list(centres - 4 * sigma)
        right_sig = list(centres + 4 * sigma)
        e_res = list(widths / centres * 100)

        return centres, e_res, left_sig, right_sig

    def replot(self, left_lim=None, right_lim=None):
        '''Redraw the uncalibrated spectrum, then the calibrated preview.

        When both limits are given they set the x range. Otherwise the current
        x range (e.g. after zooming) is kept.
        '''
        l, r = self.ax.get_xlim()
        self.ax.clear()
        # clear() removed the cursor artists, so recreate them.
        self.mouse_tracking()
        if getattr(self, '_click_cid', None) is None:
            self._click_cid = self.figure.canvas.mpl_connect(
                'button_press_event', self.mouse_click)
        self.ax.set_yscale('log')
        self.ax.set_xlabel('Channel')
        self.ax.set_ylabel('Counts')

        self.ax.plot(self.channels, self.counts)
        if left_lim is not None and right_lim is not None:
            self.ax.set_xlim(left_lim, right_lim)
        else:
            self.ax.set_xlim(l, r)
        # let an algorithm find the peaks (red) and show user marks (black)
        peaks = self.find_peaks(self.counts, 3, 2)[0]
        for i in peaks:
            self.ax.axvline(x=i, color='r', linestyle='--', linewidth=0.5)
        for i in self.calibration_lines:
            self.ax.axvline(x=i, color='k', linestyle='--')

        self.figure.tight_layout()
        self.canvas.draw()
        self.replot_calibration()

    def replot_calibration(self, left_lim=None, right_lim=None):
        '''Replot the calibrated spectrum:
            If 0 data points are entered, set intercept to 0 and
            set maximum energy to be 3MeV
            If 1 data point is entered, set intercept to 0 and
            find slope to be E/C (energy entered divided by bin num)
            If 2 or more data points are entered, call the linear calibration
            functions
        ``left_lim``/``right_lim`` are accepted but currently unused.
        '''
        # calibration channel numbers: self.calibration_lines -> list
        # calibration energies: self.energies -> dict keyed by '{:.2f}' channel
        self.ax1.clear()
        self.ax1.set_yscale('log')
        self.ax1.set_xlabel('Energy (MeV)')
        self.ax1.set_ylabel('Counts')
        self.ax1.set_title('Current Linear Calibration')
        if not self.channels:
            # Nothing loaded or parsed: avoid dividing by zero below.
            self.canvas1.draw()
            return
        if len(self.calibration_lines) == 0:
            # take the channels and scale them from 0-3MeV
            slope = 3 / len(self.channels)
            calibrated = [i * slope for i in self.channels]
        elif len(self.calibration_lines) == 1:
            # get the energy and channel number
            ch = self.calibration_lines[0]
            en = float(self.energies[self._channel_key(ch)])
            # the slope is those divided
            slope = en / ch
            calibrated = [i * slope for i in self.channels]
        else:
            channels = self.calibration_lines
            # Pair each channel with its own energy explicitly.
            energies = [float(self.energies[self._channel_key(c)])
                        for c in channels]
            calibrated, m, b = cali(self.channels).linear_least_squares_fit(
                channels, energies, live_plotter=True)
            x = max(calibrated) - 0.12 * max(calibrated)
            y = max(self.counts) / 5
            self.ax1.annotate(
                'Slope: {:.3f} keV/ch\nIntercept: {:.3f}keV'.format(
                    m * 1000, b * 1000),
                xy=(x, y))
        self.ax1.set_xlim(0, max(calibrated))
        self.ax1.plot(calibrated, self.counts)
        for energy in self.energies.values():
            self.ax1.axvline(x=float(energy), color='k', linestyle='--')
        self.figure1.tight_layout()
        self.canvas1.draw()

    def calibration(self):
        '''Ask for a calibration method, apply it and plot the result.

        The result is stored in ``self.cal_values`` and saving is enabled.
        Nothing changes if a dialog is cancelled or the external
        slope/intercept entry is malformed.
        '''
        channels = [float(i) for i in self.energies.keys()]
        energies = [float(j) for j in self.energies.values()]

        # get the calibration method desired
        items = ('Linear', 'Deviation Pairs', 'External Calibration',
                 'Segemented Linear')
        item, ok = QInputDialog.getItem(self, 'Calibration Type',
                                        'Calibration:', items, 0, False)
        if not (ok and item):
            return
        values = None
        if item == items[0]:
            values = cali(self.channels).linear_least_squares_fit(
                channels, energies)
        elif item == items[1]:
            values = cali(self.channels).deviation_pairs(channels, energies)
        elif item == items[2]:
            text, text_ok = QInputDialog.getText(
                self, 'Slope and Intercept',
                'Slope [MeV/Ch],Intercept[MeV]:', QLineEdit.Normal, "")
            vals = text.split(sep=',')
            if text_ok and len(vals) == 2:
                try:
                    slope, intercept = float(vals[0]), float(vals[1])
                except ValueError:
                    return
                values = cali(self.channels).external_calibration(
                    slope, intercept)
        elif item == items[3]:
            values = cali(self.channels).segmented_linear_least_squares(
                channels, energies)

        if values is None:
            # Previously a stale (or missing) cal_values was plotted here.
            return
        self.cal_values = values

        plt.figure(1, figsize=(5, 5))
        plt.plot(self.channels, self.cal_values)
        plt.figure(2, figsize=(5, 5))
        plt.plot(self.cal_values, self.counts)
        plt.xlabel('Energy [MeV]')
        plt.ylabel('Counts')
        plt.title('Energy Calibrated Spectrum')
        plt.yscale('log')
        plt.xlim(0, 14)
        plt.show()
        self.save.setEnabled(True)

    def save_(self):
        '''Write the calibrated energies, one per line, to a chosen file.'''
        name, _ = QFileDialog.getSaveFileName(
            self, 'Calibration Data', '',
            'Text File (*.txt);; Comma Seperated File (*.csv)')
        if not name:
            return
        with open(name, 'w') as f:
            f.writelines('{:.8f}\n'.format(i) for i in self.cal_values)

    def save_rebinned(self):
        '''Write the current (possibly rebinned) counts, one per line.'''
        name, _ = QFileDialog.getSaveFileName(
            self, 'Calibration Data', '',
            'Text File (*.txt);; Comma Seperated File (*.csv)')
        if not name:
            return
        with open(name, 'w') as f:
            f.writelines('{:.8f}\n'.format(i) for i in self.counts)

    def update(self, index):
        '''Remove the calibration point whose dock entry was double-clicked.

        NOTE(review): this name shadows QWidget.update(); kept because it is
        the slot connected in geometry().
        '''
        item = self.loaded.itemFromIndex(index)
        # Entries look like 'Ch: 123.45->Energy: 1.2345 MeV'
        key = item.text().split(sep='->')[0].split(sep=': ')[1]
        self.energies.pop(key, None)
        self.calibration_lines = [c for c in self.calibration_lines
                                  if self._channel_key(c) != key]
        self._refresh_calibration_list()
        self.calibrateAction.setEnabled(len(self.calibration_lines) >= 2)
        self.replot()

    def load_spe(self, file_path):
        '''Load a spectrum file type using the SPE file format.

        Counts are read from the lines after the ``$DATA:`` tag and its
        channel-range line, up to (not including) the next ``$`` section tag
        or the end of the file. Blank lines are skipped. Returns
        ``(counts, channels)`` with channels numbered from 0.
        '''
        with open(file_path, 'r') as f:
            data = f.readlines()

        # first find the index for $DATA (last occurrence wins); +2 skips the
        # tag and the channel-range line that follows it
        s_index = 0
        for i, line in enumerate(data):
            if '$DATA:' in line:
                s_index = i + 2
        # The block ends right before the next section tag. The old end index
        # stopped one line short (dropping the last channel), and a file with
        # no following tag yielded no data at all.
        e_index = len(data)
        for i in range(s_index, len(data)):
            if '$' in data[i]:
                e_index = i
                break

        counts = [float(line) for line in data[s_index:e_index]
                  if line.strip()]
        channels = list(range(len(counts)))
        return counts, channels

    def rebin(self):
        '''Rebin the data and then replot it.

        The marked calibration channels are divided by the chosen factor so
        they keep pointing at the same spectral feature.
        '''
        values = ['2', '4']
        selected, ok = QInputDialog.getItem(self, 'Select Rebin',
                                            'Bins to combine', values, 0,
                                            False)
        if not ok:
            return
        s = int(selected)
        self.counts = Rebins(self.counts, s).rebinner()
        self.channels = list(range(len(self.counts)))
        # Redo the calibration lines with the chosen factor (this used to be
        # hard-coded to 2, misplacing every line when combining 4 bins).
        new_lines = []
        new_energies = {}
        for channel in self.calibration_lines:
            new_channel = round(channel / s, 2)
            key = self._channel_key(new_channel)
            if key in new_energies:
                # Two marks collapsed onto one rebinned channel; keep the first
                # so the channel list and energy dict stay in step.
                continue
            new_lines.append(new_channel)
            new_energies[key] = '{:.4f}'.format(
                float(self.energies[self._channel_key(channel)]))
        self.calibration_lines = new_lines
        self.energies = new_energies
        # next take care of the values shown in the dock
        self._refresh_calibration_list()
        self.calibrateAction.setEnabled(len(self.calibration_lines) >= 2)

        self.replot(left_lim=0,
                    right_lim=len(self.channels) + 0.01 * len(self.channels))
class RabbitWin(QWidget, ie.Imp_Exp_Mixin, Raf.Rabbit_functions_mixin):
    """ Rabbit window

        For the names of the children widgets, I tried to put suffixes that indicate clearly their types:
        *_btn -> QPushButton,
        *_le -> QLineEdit,
        *_lb -> QLabel,
        *layout -> QHBoxLayout, QVBoxLayout or QGridLayout,
        *_box -> QGroupBox,
        *_cb -> QCheckBox,
        *_rb -> QRadioButton

        The functions that are connected to a widget's event have the suffix _lr (for 'listener'). For example,
        a button named test_btn will be connected to a function test_lr.
        Some functions may be connected to widgets but without the suffix _lr in their names. It means that they
        are not only called when interacting with the widget.
        """
    def __init__(self, parent=None):
        """Build the RABBIT window.

        The top-level layout (mainlayout) holds two columns:
            - graphlayout: left column, with all the figures
            - commandLayout: right column, with the buttons, fields, checkboxes...
        Each column is split into sub-layouts built by dedicated
        'init_*layout' methods. They run below, after init_var has set the
        instance attributes, in the order the sections appear on screen."""
        super().__init__(parent=parent)
        self.setWindowTitle("RABBIT")

        self.mainlayout = QHBoxLayout()
        self.graphlayout = QVBoxLayout()
        self.commandLayout = QVBoxLayout()
        self.commandLayout.setSpacing(10)

        initializers = (
            self.init_var,
            self.init_importlayout,
            self.init_envectlayout,
            self.init_sigtreatmentlayout,
            self.init_rabbitlayout,
            self.init_exportlayout,
            self.init_plotbtnlayout,
            self.init_graphlayout,
        )
        for initialize in initializers:
            initialize()

        # Figures on the left, commands on the right
        for column in (self.graphlayout, self.commandLayout):
            self.mainlayout.addLayout(column)
        self.setLayout(self.mainlayout)
        self.show()

    def init_var(self):
        ''' Initialization of instance attributes'''
        # Loading / selection state flags (immutable, so chaining is safe)
        self.dataloaded = self.xuvonlyloaded = self.bandselected = False

        # Integer sizes and indices
        self.toflength = self.elength = self.SBi = self.fpeak_index = 0

        # Data containers: each attribute gets its own, independent empty list
        list_attributes = (
            # time-of-flight data
            "data_tof", "tof",
            # standard RABBIT results
            "ampl", "ang", "fpeak", "peak", "peak_phase", "freqnorm", "pa",
            # rainbow RABBIT results
            "ampl_rainbow", "ampl_rainbow2", "ang_rainbow", "fpeak_rainbow",
            "peak_rainbow", "peak_phase_rainbow", "energy_rainbow",
            "energy_rainbow2",
        )
        for attribute in list_attributes:
            setattr(self, attribute, [])

    def init_importlayout(self):
        ''' In commandLayout - Initialization of the "Import" section'''
        grid = QGridLayout()
        grid.setSpacing(10)

        box = QGroupBox(self)
        box.setTitle("Import")
        box.setFixedSize(300, 100)

        # (attribute name, caption, listener, grid row, grid column)
        button_specs = (
            ("importcalib_btn", "calib", self.importcalib_lr, 0, 0),
            ("importdata_btn", "data", self.importdata_lr, 1, 0),
            ("importXUV_btn", "XUV only", self.importXUV_lr, 0, 1),
            ("importrabparam_btn", "RABBIT param", self.importrabparam_lr, 0, 2),
            ("importrab_btn", "RABBIT", self.importrab_lr, 1, 2),
        )
        for attribute, caption, listener, row, col in button_specs:
            button = QPushButton(caption, self)
            button.clicked.connect(listener)
            setattr(self, attribute, button)
            grid.addWidget(button, row, col)

        box.setLayout(grid)
        self.commandLayout.addWidget(box)

        # All import buttons start disabled, except the two that need no prior data
        for child in box.children():
            if isinstance(child, QPushButton):
                child.setSizePolicy(0, 0)
                child.setEnabled(False)
        for starter in (self.importcalib_btn, self.importrabparam_btn):
            starter.setEnabled(True)

    def init_envectlayout(self):
        ''' In commandLayout - Initialization of the energy vector section, with elow, ehigh and dE'''
        param_row = QHBoxLayout()

        # --- Energy vector group: three captioned fields side by side ---
        envect_grid = QGridLayout()
        envect_grid.setSpacing(10)
        envect_box = QGroupBox()
        envect_box.setTitle("Energy vector parameters")
        envect_box.setSizePolicy(0, 0)

        self.elow_le = QLineEdit("{:.2f}".format(cts.elow), self)
        self.ehigh_le = QLineEdit("{:.2f}".format(cts.ehigh), self)
        self.dE_le = QLineEdit(str(cts.dE), self)

        fields = (self.elow_le, self.ehigh_le, self.dE_le)
        for field in fields:
            field.returnPressed.connect(self.update_envect_fn)

        # Caption on row 0, editable field right below it on row 1
        captions = ("E low (eV)", "E high (eV)", "dE (eV)")
        for col, (caption, field) in enumerate(zip(captions, fields)):
            envect_grid.addWidget(QLabel(caption), 0, col)
            envect_grid.addWidget(field, 1, col)

        envect_box.setLayout(envect_grid)

        for child in envect_box.children():
            if isinstance(child, (QLabel, QLineEdit)):
                child.setSizePolicy(0, 0)
                child.setFixedSize(55, 20)

        param_row.addWidget(envect_box)

        # --- Scan step field, stacked under its caption ---
        scan_column = QVBoxLayout()

        self.scanparam_le = QLineEdit(str(cts.scanstep_nm))
        self.scanparam_le.setSizePolicy(0, 0)
        self.scanparam_le.setFixedSize(55, 20)
        self.scanparam_le.returnPressed.connect(self.update_scanparam)
        scan_caption = QLabel("scan steps (nm)")
        scan_caption.setSizePolicy(0, 0)
        scan_caption.setFixedSize(80, 20)

        scan_column.addWidget(scan_caption)
        scan_column.addWidget(self.scanparam_le)

        param_row.addLayout(scan_column)

        self.commandLayout.addLayout(param_row)

    def init_sigtreatmentlayout(self):
        ''' In commandLayout - Initialization of the "Signal Treatment" section'''
        box = QGroupBox("Signal Treatment", self)
        grid = QGridLayout()

        box.setFixedSize(300, 100)

        self.smooth_btn = QPushButton("smooth", self)
        self.smooth_le = QLineEdit("2", self)
        self.normalize_btn = QPushButton("normalize", self)
        self.subXUV_btn = QPushButton("substract XUV", self)
        self.selectbands_btn = QPushButton("select bands")
        self.bandsnb_le = QLineEdit(str(cts.bandsnb), self)

        # signal -> listener wiring (smooth_le has no listener of its own)
        wiring = (
            (self.smooth_btn.clicked, self.smoothrab_lr),
            (self.normalize_btn.clicked, self.normalizerab_lr),
            (self.selectbands_btn.clicked, self.selectbands_lr),
            (self.bandsnb_le.returnPressed, self.bandsnb_lr),
            (self.subXUV_btn.clicked, self.subXUV_lr),
        )
        for signal, listener in wiring:
            signal.connect(listener)

        # Column 0: smoothing, column 1: normalize / XUV subtraction, column 2: bands
        placement = (
            (self.smooth_btn, 0, 0),
            (self.smooth_le, 1, 0),
            (self.normalize_btn, 0, 1),
            (self.subXUV_btn, 1, 1),
            (self.selectbands_btn, 0, 2),
            (self.bandsnb_le, 1, 2),
        )
        for widget, row, col in placement:
            grid.addWidget(widget, row, col)

        box.setLayout(grid)

        # Buttons start disabled; line edits get a compact fixed size
        for child in box.children():
            if isinstance(child, QPushButton):
                child.setSizePolicy(0, 0)
                child.setEnabled(False)
            elif isinstance(child, QLineEdit):
                child.setFixedSize(55, 20)

        self.commandLayout.addWidget(box)

    def init_rabbitlayout(self):
        ''' In commandLayout - Initialization of the "RABBIT" section.

        Builds a 300x100 group box with the Normal, FT/Contrast, Rainbow and
        Clear buttons (first row of three, Clear alone on the second row).
        All of them get a (0, 0) size policy and start disabled.
        '''
        grid = QGridLayout()
        box = QGroupBox("RABBIT", self)
        box.setFixedSize(300, 100)

        self.normalrab_btn = QPushButton("Normal", self)
        self.FTcontrast_btn = QPushButton("FT/Contrast", self)
        self.rainbowrab_btn = QPushButton("Rainbow", self)
        self.clear_btn = QPushButton("Clear", self)

        # (button, slot, row, column)
        entries = (
            (self.normalrab_btn, self.normalrab_lr, 0, 0),
            (self.FTcontrast_btn, self.FTcontrast_lr, 0, 1),
            (self.rainbowrab_btn, self.rainbowrab_lr, 0, 2),
            (self.clear_btn, self.clear_lr, 1, 0),
        )
        for button, slot, _, _ in entries:
            button.clicked.connect(slot)
        for button, _, row, col in entries:
            grid.addWidget(button, row, col)

        box.setLayout(grid)

        for child in box.children():
            if not isinstance(child, QPushButton):
                continue
            child.setSizePolicy(0, 0)
            child.setEnabled(False)

        self.commandLayout.addWidget(box)

    def init_exportlayout(self):
        ''' In commandLayout - Initialization of the "Export" section.

        Builds a 300x60 group box containing a single, initially disabled,
        "RABBIT" export button wired to exportrab_lr.
        '''
        grid = QGridLayout()
        box = QGroupBox("Export", self)
        box.setFixedSize(300, 60)

        export_btn = QPushButton("RABBIT", self)
        self.exportrab_btn = export_btn
        export_btn.clicked.connect(self.exportrab_lr)

        grid.addWidget(export_btn)
        box.setLayout(grid)

        push_buttons = [c for c in box.children() if isinstance(c, QPushButton)]
        for button in push_buttons:
            button.setSizePolicy(0, 0)
            button.setEnabled(False)
        self.commandLayout.addWidget(box)

    def init_plotbtnlayout(self):
        ''' In commandLayout - Initialization of the "Plot" section.

        Builds a 300x60 group box with the "SB vs delay" and "Pulse in time"
        buttons side by side; both start disabled.
        '''
        grid = QGridLayout()
        box = QGroupBox("Plot", self)
        box.setFixedSize(300, 60)

        self.plotSBvsdelay_btn = QPushButton("SB vs delay", self)
        self.plotPulseInTime_btn = QPushButton("Pulse in time", self)

        self.plotSBvsdelay_btn.clicked.connect(self.plotSBvsdelay_lr)
        self.plotPulseInTime_btn.clicked.connect(self.plotPulseInTime_lr)

        for col, button in enumerate((self.plotSBvsdelay_btn,
                                      self.plotPulseInTime_btn)):
            grid.addWidget(button, 0, col)

        box.setLayout(grid)

        for child in box.children():
            if isinstance(child, QPushButton):
                child.setSizePolicy(0, 0)
                child.setEnabled(False)

        self.commandLayout.addWidget(box)

    def init_graphlayout(self):
        ''' In graphLayout - Initialization of the 3 figures.

        Adds the 3D RABBIT widget (E in eV vs t in fs), then a horizontal row
        with two small Matplotlib panels - phase on the left, FT on the
        right - each followed by its own borderless navigation toolbar.
        '''
        # new class defined in other_widgets.py
        self.rab_widget = ow.plot3DWidget(self)
        self.rab_widget.xlabel = "E (eV)"
        self.rab_widget.ylabel = "t (fs)"

        self.graphlayout.addWidget(self.rab_widget)

        self.phaseFTlayout = QHBoxLayout()

        def build_panel(column):
            ''' Put a 2x2 in, 100 dpi canvas plus its toolbar into
            ``column``, draw it once and return (canvas, axes).'''
            canvas = FigureCanvas(Figure(figsize=(2, 2), dpi=100))
            canvas.setSizePolicy(1, 0)
            axes = canvas.figure.add_subplot(111)
            axes.tick_params(labelsize=8)
            toolbar = NavigationToolbar2QT(canvas, self)
            toolbar.setStyleSheet("QToolBar { border: 0px }")
            column.addWidget(canvas)
            column.addWidget(toolbar)
            canvas.draw()
            return canvas, axes

        self.phaselayout = QVBoxLayout()
        self.phase_fc, self.phase_ax = build_panel(self.phaselayout)

        self.FTlayout = QVBoxLayout()
        self.FT_fc, self.FT_ax = build_panel(self.FTlayout)

        for column in (self.phaselayout, self.FTlayout):
            self.phaseFTlayout.addLayout(column)

        self.graphlayout.addLayout(self.phaseFTlayout)

    def selectbands_lr(self):
        ''' "select bands" button listener.

        Reads the number of bands from ``bandsnb_le``, stores it in
        ``cts.bandsnb``, allocates a zeroed ``(bandsnb, 2)`` array in
        ``cts.bands_vect``, disables the buttons that depend on the band
        selection, refreshes the global variables and opens the band
        selection window.

        If the entry is not a positive integer a status-bar message is shown
        and nothing in ``cts`` is modified.
        '''
        # Keep the try limited to the parse: previously a ValueError raised by
        # np.zeros (negative size) or by the window itself was reported as
        # "Number of bands must be an integer".
        try:
            bandsnb = int(self.bandsnb_le.text())
        except ValueError:
            self.window().statusBar().showMessage(
                "Number of bands must be an integer")
            return
        if bandsnb < 1:
            # np.zeros rejects negative sizes and zero bands leaves nothing
            # to select.
            self.window().statusBar().showMessage(
                "Number of bands must be a positive integer")
            return

        cts.bandsnb = bandsnb
        cts.bands_vect = np.zeros([cts.bandsnb, 2])
        for button in (self.subXUV_btn, self.normalrab_btn,
                       self.rainbowrab_btn, self.FTcontrast_btn,
                       self.plotSBvsdelay_btn):
            button.setEnabled(False)
        self.window().updateglobvar_fn()
        # new class defined in Rabbit_select_bands.py
        Rsb.selectBandsWin(self)

    def bandsnb_lr(self):
        ''' Called when pressing enter in the bandsnb_le object.

        Stores the entered number of bands in ``cts.bandsnb`` and refreshes
        the global variables. A non-integer or non-positive entry is reported
        in the status bar and ``cts`` is left unchanged.
        '''
        try:
            bandsnb = int(self.bandsnb_le.text())
        except ValueError:
            self.window().statusBar().showMessage(
                "Number of bands must be an integer")
            return
        if bandsnb < 1:
            self.window().statusBar().showMessage(
                "Number of bands must be a positive integer")
            return

        cts.bandsnb = bandsnb
        self.window().updateglobvar_fn()

    def subXUV_lr(self):
        ''' "substract XUV" button listener. Opens a new window.

        Resets the ``cts.xuvsubstracted`` flag to False, then builds a
        ``subXUVWin`` with this widget as its argument.
        '''
        cts.xuvsubstracted = False
        subXUVWin(self)

    def FTcontrast_lr(self):
        ''' "FT/contrast" button listener. Opens a new window.

        Any exception raised while building the window is reported as a
        formatted traceback on stderr instead of escaping the Qt slot.
        '''
        try:
            ftcw.FTContrastWin(self)  # class defined in the ftcw module
        except Exception:
            # print_exc prints a readable traceback; printing the list from
            # format_exception showed its repr with escaped newlines.
            traceback.print_exc()

    def plotSBvsdelay_lr(self):
        ''' "[plot] SB vs delay" button listener. Opens a new window.

        Any exception raised while building the window is reported as a
        formatted traceback on stderr instead of escaping the Qt slot.
        '''
        try:
            sbdw.SBvsDelayWin(self)  # class defined in the sbdw module
        except Exception:
            # print_exc prints a readable traceback; printing the list from
            # format_exception showed its repr with escaped newlines.
            traceback.print_exc()

    def plotPulseInTime_lr(self):
        ''' "[plot] Pulse in time" button listener. Opens a new window.

        Any exception raised while building the window is reported as a
        formatted traceback on stderr instead of escaping the Qt slot.
        '''
        try:
            pitw.PulseInTimeWin(self)  # class defined in the pitw module
        except Exception:
            # print_exc prints a readable traceback; printing the list from
            # format_exception showed its repr with escaped newlines.
            traceback.print_exc()

    def update_scanparam(self):
        ''' Updates the values of the scan steps, in nm and fs.

        Parses ``scanparam_le`` once and stores the value in
        ``cts.scanstep_nm``. It also stores ``2 * step / (cts.C * 1e-6)`` in
        ``cts.scanstep_fs``, then refreshes the global variables.

        A non-numeric entry is reported in the status bar, like the other
        listeners of this widget, and ``cts`` is left unchanged.
        '''
        try:
            step_nm = float(self.scanparam_le.text())
        except ValueError:
            self.window().statusBar().showMessage('Scan step must be a number')
            return

        cts.scanstep_nm = step_nm
        # NOTE(review): the factor 2 presumably accounts for a round-trip
        # delay line; the unit of cts.C is not visible here - confirm.
        cts.scanstep_fs = step_nm * 2 / (cts.C * 1e-6)
        self.window().updateglobvar_fn()

    def update_envect_fn(self):
        ''' Updates the energy vector parameters.

        Reads elow, ehigh and dE from their line edits and stores them in
        ``cts``. It then rewrites the elow/ehigh fields with two decimals and
        refreshes the global variables.

        Raises:
            ValueError: if any of the three fields is not a number. All
                fields are parsed before anything is stored, so ``cts`` is
                never left half-updated.
        '''
        elow = float(self.elow_le.text())
        ehigh = float(self.ehigh_le.text())
        dE = float(self.dE_le.text())

        cts.elow = elow
        cts.ehigh = ehigh
        cts.dE = dE
        self.elow_le.setText("{:.2f}".format(cts.elow))
        self.ehigh_le.setText("{:.2f}".format(cts.ehigh))
        self.window().updateglobvar_fn()

    def reset_btn(self):
        ''' "Reset" button listener. Resets the widgets, not the variables.

        Disables every action button of the import, signal treatment, RABBIT,
        export and plot sections. It also disables the colorauto/logscale
        checkboxes of the 3D plot widget. Nothing in ``cts`` is modified.
        '''
        buttons = (
            self.importXUV_btn,
            self.importdata_btn,
            self.importrab_btn,
            self.smooth_btn,
            self.normalize_btn,
            self.subXUV_btn,
            self.selectbands_btn,
            self.normalrab_btn,
            self.FTcontrast_btn,
            self.rainbowrab_btn,
            # clear_btn and plotPulseInTime_btn are created disabled in
            # init_rabbitlayout / init_plotbtnlayout but were never reset.
            self.clear_btn,
            self.exportrab_btn,
            self.plotSBvsdelay_btn,
            self.plotPulseInTime_btn,
        )
        for button in buttons:
            button.setEnabled(False)

        self.rab_widget.colorauto_cb.setEnabled(False)
        self.rab_widget.logscale_cb.setEnabled(False)
class BinaryIntegration(QMainWindow, Ui_MainWindow):
    '''Main window plotting the M-of-N binary-integration detection curve.

    The input line edits (``signal_to_noise``, ``probability_of_false_alarm``,
    ``m`` and ``n``) are created by ``Ui_MainWindow.setupUi``. Pressing Enter
    in any of them redraws the curve.
    '''

    def __init__(self):
        # Zero-argument super() follows the MRO correctly, whereas
        # super(self.__class__, self) recurses forever once subclassed.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes, when the user presses enter the form updates
        for field in (self.signal_to_noise, self.probability_of_false_alarm,
                      self.m, self.n):
            field.returnPressed.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        ``signal_to_noise`` holds "start,stop" in dB; 200 points are spread
        between them. Input that cannot be parsed leaves the current plot
        unchanged and is reported in the status bar, instead of raising
        inside the Qt slot.
        :return:
        """
        # Get the parameters from the form
        try:
            snr_db = self.signal_to_noise.text().split(',')
            snr_start, snr_stop = float(snr_db[0]), float(snr_db[1])
            pfa = float(self.probability_of_false_alarm.text())
            m = int(self.m.text())
            n = int(self.n.text())
        except (ValueError, IndexError):
            self.statusBar().showMessage(
                'Invalid input: SNR must be "start,stop" in dB, Pfa a number '
                'and M, N integers', 5000)
            return

        snr = 10.0**(linspace(snr_start, snr_stop, 200) / 10.0)

        # Calculate the probability of detection
        p_detect = [
            probability_of_detection(m, n, pd_rayleigh(isnr, pfa))
            for isnr in snr
        ]

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Display the results
        self.axes1.plot(10.0 * log10(snr), p_detect, '')

        # Set the plot title and labels
        self.axes1.set_title('Binary Integration (M of N)', size=14)
        self.axes1.set_xlabel('Signal to Noise (dB)', size=12)
        self.axes1.set_ylabel('Probability of Detection', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
Exemple #8
0
class ADC(QMainWindow, Ui_MainWindow):
    '''Main window illustrating analog-to-digital conversion of an AM chirp.

    The input line edits (``number_of_bits``, ``sampling_frequency``,
    ``start_frequency``, ``end_frequency``, ``am_amplitude``,
    ``am_frequency``) are created by ``Ui_MainWindow.setupUi``. Pressing
    Enter in any of them redraws both plots.
    '''

    def __init__(self):
        # Zero-argument super() follows the MRO correctly, whereas
        # super(self.__class__, self) recurses forever once subclassed.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes, when the user presses enter the form updates
        for field in (self.number_of_bits, self.sampling_frequency,
                      self.start_frequency, self.end_frequency,
                      self.am_amplitude, self.am_frequency):
            field.returnPressed.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(211)
        self.axes2 = fig.add_subplot(212)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Top axes: a 4196-point AM chirp over 0-1 s together with its
        quantized samples. Bottom axes: the quantization error returned by
        ``quantization.quantize``. Invalid input leaves the plots unchanged
        and is reported in the status bar instead of raising inside the Qt
        slot.
        :return:
        """
        # Get the parameters from the form
        try:
            number_of_bits = int(self.number_of_bits.text())
            sampling_frequency = float(self.sampling_frequency.text())
            start_frequency = float(self.start_frequency.text())
            end_frequency = float(self.end_frequency.text())
            am_amplitude = float(self.am_amplitude.text())
            am_frequency = float(self.am_frequency.text())
        except ValueError:
            self.statusBar().showMessage(
                'Invalid input: number of bits must be an integer and the '
                'other fields numbers', 5000)
            return
        if sampling_frequency <= 0:
            # A non-positive rate gives an empty or divide-by-zero time axis.
            self.statusBar().showMessage(
                'Sampling frequency must be positive', 5000)
            return

        # Analog signal for plotting
        t = linspace(0.0, 1.0, 4196)
        a_signal = chirp(t, start_frequency, t[-1], end_frequency)
        a_signal *= (1.0 + am_amplitude * sin(2.0 * pi * am_frequency * t))

        # Sampled waveform on [0, 1] s (named to avoid shadowing `time`)
        sample_time = arange(sampling_frequency + 1) / sampling_frequency
        if_signal = chirp(sample_time, start_frequency, sample_time[-1],
                          end_frequency)
        if_signal *= (1.0 +
                      am_amplitude * sin(2.0 * pi * am_frequency * sample_time))

        # Quantize the sampled signal
        quantized_signal, error_signal = quantization.quantize(
            if_signal, number_of_bits)

        # Clear the axes for the updated plot
        self.axes1.clear()
        self.axes2.clear()

        # Display the results
        self.axes1.plot(t, a_signal, '', label='Analog Signal')
        self.axes1.plot(sample_time, quantized_signal, '-.',
                        label='Digital Signal')
        self.axes2.plot(sample_time, error_signal, '',
                        label='Quantization Error')

        # Set the plot title and labels
        self.axes1.set_title('Analog to Digital Conversion', size=14)
        self.axes2.set_xlabel('Time (s)', size=12)
        self.axes1.set_ylabel('Amplitude (V)', size=12)
        self.axes2.set_ylabel('Error (V)', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)
        self.axes2.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)
        self.axes2.grid(linestyle=':', linewidth=0.5)

        # Show the legend
        self.axes1.legend(loc='lower left', prop={'size': 10})

        # Update the canvas
        self.my_canvas.draw()
Exemple #9
0
class Points_Input(QWidget):
    ''' Widget embedding a Matplotlib canvas that shows ``net`` as a grid of
    colored unit squares on a dark (#323232) background.

    NOTE(review): every drawing call goes through pyplot figure number 2, so
    all instances of this widget share one Figure - confirm only one is ever
    created.
    '''

    def __init__(self, parent):
        QWidget.__init__(self, parent)
        layout = QVBoxLayout()
        self.layout = layout
        self.setLayout(layout)
        layout.setContentsMargins(0, 0, 0, 0)

        # Creating the graph: pyplot figure 2 backs this widget's canvas
        figure = plt.figure(2)
        self.fig = figure
        self.ax = plt.subplot()
        self.canvas = FigureCanvas(figure)
        self.canvas.setFocus()
        layout.addWidget(self.canvas)

        self.init_graph()
        self.canvas.draw()

    def init_graph(self):
        ''' Apply the default look to figure 2's current axes: 0-5 limits with
        unit ticks, grid drawn below the data, hidden spines and light-grey
        tick labels.'''
        background = '#323232'
        tick_color = '#b1b1b1'

        plt.figure(2)
        plt.tight_layout()
        ax = plt.gca()
        self.ax = ax
        self.fig.set_facecolor(background)
        ax.grid(zorder=0)
        ax.set_axisbelow(True)
        ax.set_xlim([0, 5])
        ax.set_ylim([0, 5])
        ax.set_xticks(range(0, 6))
        ax.set_yticks(range(0, 6))
        ax.axhline(y=0, color=background)
        ax.axvline(x=0, color=background)
        for side in ('right', 'top', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        for axis in ('x', 'y'):
            ax.tick_params(axis=axis, colors=tick_color)

    def clearPlot(self):
        ''' Wipe figure 2 (clf, then cla on a fresh current axes) and redraw
        the now-empty canvas.'''
        self.fig = plt.figure(2)
        self.fig.clf()
        self.ax = plt.gca()
        self.ax.cla()
        self.canvas.draw()

    def plot_lines(self, net):
        ''' Draw one unit square per cell of ``net`` after clearing the plot.

        Cell ``(i, j)`` is centred on ``(i + 1, j + 1)``. Its face color is
        ``[sum(ch 0-2) / 3, sum(ch 3-5) / 3, sum(ch 6+) / 4]`` taken along
        the last axis.
        NOTE(review): assumes ``net`` is 3-D with channels last and values
        that produce a valid RGB triple - confirm against the caller.
        '''
        self.clearPlot()
        self.fig = plt.figure(2)
        ax = plt.gca()
        self.ax = ax
        self.fig.set_facecolor('#323232')
        n_x, n_y = net.shape[0], net.shape[1]
        ax.set_xlim((0, n_x + 1))
        ax.set_ylim((0, n_y + 1))
        for axis in ('x', 'y'):
            ax.tick_params(axis=axis, colors='#b1b1b1')

        for ix in range(n_x):
            for iy in range(n_y):
                cell = net[ix, iy, :]
                rgb = [sum(cell[:3]) / 3, sum(cell[3:6]) / 3, sum(cell[6:]) / 4]
                square = patches.Rectangle((ix + 0.5, iy + 0.5), 1, 1,
                                           facecolor=rgb, edgecolor='none')
                ax.add_patch(square)
        self.canvas.draw()
Exemple #10
0
class scattering_window(lattice_window):
    ''' Window for interactive scattering plots.

    Builds on ``lattice_window`` and adds these options: an incoming wave
    vector (k_in), a selector for highlighting one set of Miller indices, a
    "Show all" toggle and per-atom form factors. Plots are produced by
    ``Scattering`` rather than ``Lattice``.
    '''

    def __init__(self):
        super().__init__()

    def create_variables(self):
        ''' Set up lattice names, colors and the plotting configuration.'''
        import copy

        self.lattice_names = [
            'cubic with a basis', 'simple cubic', 'conventional fcc',
            'conventional bcc'
        ]
        self.lattices = [
            'cubic with a basis', 'simple cubic', 'conventional fcc',
            'conventional bcc'
        ]
        self.colors = [
            'xkcd:cement', 'red', 'blue', 'green', 'cyan', 'magenta', 'black',
            'yellow'
        ]
        # A dictionary of the default values for lattice plotting
        self.default_config = {
            'a1': d[0],
            'a2': d[1],
            'a3': d[2],
            'k_in': np.array([0, 0, -1.5]),
            'indices':
            [None, [-1, -1, 2], [-1, 1, 2], [0, 0, 3], [1, -1, 2], [1, 1, 2]],
            'highlight': None,
            'show_all': True,
            'preset_basis': d[3],
            'user_colors': ['xkcd:cement'] * 5,
            'form_factors': [1] * 5,
            'enabled_form_factors': [1],
            'enabled_user_colors': ['xkcd:cement'],
            'preset_colors': ['xkcd:cement'] * 4,
            'sizes': d[5],
            'enabled_user_basis': np.zeros((1, 3)),
            'user_basis': np.zeros((5, 3)),
            'lattice': 'cubic with a basis',
            'max_preset_basis': 4
        }
        # Number of basis atoms of each preset that comes with a basis
        self.presets_with_basis = {
            'simple cubic': 1,
            'conventional fcc': 4,
            'conventional bcc': 2
        }
        # Deep copy: lattice_config holds mutable values (the k_in array,
        # basis arrays, color and form-factor lists) that are edited in place,
        # e.g. by update_k. A shallow copy would silently alter
        # default_config too.
        self.lattice_config = copy.deepcopy(self.default_config)
        self.current = 'user'

    def create_plot_window(self):
        # Create the default plot and return the figure and axis objects for
        # it. Then create the FigureCanvas, add them all to the layout and add
        # a toolbar. Lastly enable mouse support for Axes3D
        self.static_fig, self.static_ax, self.static_ax2, _ = Scattering(
            returns=True, return_indices=True, plots=False)
        self.static_canvas = FigureCanvas(self.static_fig)
        self.addToolBar(NavigationToolbar(self.static_canvas, self))
        self.static_ax.mouse_init()

    def create_options(self):
        ''' Build the options panel: lattice chooser, k_in fields, highlight
        selector, "Show all" checkbox, explanatory note, the user basis and
        its form factors.'''
        self.layout_options = QW.QVBoxLayout()
        self.layout_options_form = QW.QFormLayout()
        # Create the lattice chooser dropdown
        self.lattice_chooser = QW.QComboBox(self)
        self.lattice_chooser.addItems(self.lattices)
        # Separate the basis-less entries from the presets with a basis
        sep_index = len(self.lattices) - len(self.presets_with_basis)
        self.lattice_chooser.insertSeparator(sep_index)
        self.lattice_chooser.activated[str].connect(self.update_lattice_name)

        # Create the k_in fields
        self.layout_k_in = QW.QHBoxLayout()
        k_in_label = QW.QLabel('k_in, [2pi/a]')
        self.k_in_fields = []
        for i in range(3):
            el = QW.QLineEdit()
            el.setText(str(self.lattice_config['k_in'][i]))
            el.setValidator(QG.QDoubleValidator(decimals=2))
            # Default arguments bind the current i/el (avoids late binding)
            el.editingFinished.connect(
                lambda i=i, el=el: self.update_k(i, el.text()))
            self.k_in_fields.append(el)
            self.layout_k_in.addWidget(el)

        # Highlighting stuff
        highlight_label = QW.QLabel('Highlight indices')
        self.highlight_combo = QW.QComboBox()
        str_indices = [str(i) for i in self.lattice_config['indices']]
        self.highlight_combo.addItems(str_indices)
        self.highlight_combo.activated[int].connect(self.update_highlight)

        # The show all checkbox
        show_all_label = QW.QLabel('Show all')
        self.show_all_checkbox = QW.QCheckBox()
        self.show_all_checkbox.setChecked(True)
        self.show_all_checkbox.stateChanged.connect(self.show_all)

        # Note on k_in
        str_ = ('Notes:\n\n'
                'k_in is specified in units of 2pi/a, '
                'and that the z-component will always be '
                'passed as a negative value. So -|k_in,z|. \n\n'
                'Highlighting a set of Miller indices shows the following:\n'
                '- The outgoing wave vector in red\n'
                '- The reciprocal lattice vector, which gave rise to the '
                'scattering event, in green\n'
                '- The family of lattice planes for the reciprocal lattice '
                'vector.')
        note_label = QW.QLabel(str_)
        note_label.setWordWrap(True)

        # Add stuff to the layout
        self.layout_options.addLayout(self.layout_options_form)
        self.layout_options_form.addRow(self.lattice_chooser)
        self.layout_options_form.addRow(k_in_label, self.layout_k_in)
        self.layout_options_form.addRow(highlight_label, self.highlight_combo)
        self.layout_options_form.addRow(show_all_label, self.show_all_checkbox)
        self.layout_options_form.addRow(note_label)
        self.create_user_basis()
        self.add_form_factors()

    def update_lattice_name(self, text):
        ''' Lattice chooser slot: rebuild the basis widgets for ``text`` and
        redraw.'''
        # Delete current basis layout.
        self.delete_layout(self.current_basis_layout)
        if text in self.presets_with_basis:
            # We have a preset with a basis, so we delete the user basis and
            # load the preset basis
            self.create_preset_basis(self.presets_with_basis[text])
        else:
            self.create_user_basis()
        self.lattice_config['lattice'] = text
        self.add_form_factors()

        # And then we update the lattice
        self.update_lattice()

    def update_lattice(self):
        ''' Fetch the basis of the current lattice and redraw.'''
        # Grab a new lattice based on the parameters in lattice_config
        a = 1
        name = self.lattice_config['lattice']
        _, basis, _ = lattices.chooser(lattice_name=name, a=a)

        # Update primitive lattice vectors and (preset) basis.
        self.lattice_config['preset_basis'] = basis
        if name in self.presets_with_basis:
            self.update_preset_basis_widgets()
        self.plot_lattice()

    def add_form_factors(self):
        # This method runs whenever a basis has been created, to add the form
        # factors. I do this because then I can reuse as much code as possible
        # First we find out whether we're using a preset or user basis.
        # Grid column of the form-factor fields; the user-basis checkboxes go
        # one column further right.
        place = 4
        if self.lattice_config['lattice'] in self.presets_with_basis:
            lattice_name = self.lattice_config['lattice']
            n_basis = self.presets_with_basis[lattice_name]
            self.current_basis_grid = self.layout_preset_basis_grid
            self.current_basis_layout = self.layout_preset_basis
            move_checkboxes = False
        else:
            n_basis = 5
            self.current_basis_grid = self.layout_basis_grid
            self.current_basis_layout = self.layout_basis
            move_checkboxes = True

        # Every (re)built basis starts with all form factors equal to 1
        self.lattice_config['form_factors'] = [1] * n_basis

        self.form_factor_fields = []
        label = QW.QLabel('Form Factors')
        label.setAlignment(QC.Qt.AlignCenter)
        self.current_basis_grid.addWidget(label, 0, place)
        for i in range(n_basis):
            el = QW.QLineEdit()
            el.setText('1')
            # Only the first user-basis row starts enabled
            if i and move_checkboxes:
                el.setEnabled(False)
            el.setValidator(QG.QDoubleValidator(decimals=2))
            el.editingFinished.connect(
                lambda i=i, el=el: self.update_form_factor(i, el.text()))
            self.form_factor_fields.append(el)
            self.current_basis_grid.addWidget(el, i + 1, place)
            if move_checkboxes:
                el = self.basis_check_widgets[i]
                self.current_basis_grid.addWidget(el, i + 1, place + 1)

    def update_form_factor(self, i, text):
        ''' Store the form factor of basis atom ``i`` and redraw.'''
        self.lattice_config['form_factors'][i] = float(text)
        if self.lattice_config['lattice'] in self.presets_with_basis:
            # Presets plot straight from 'form_factors'. update_basis only
            # rebuilds the enabled *user* basis and would overwrite
            # 'enabled_form_factors' with the preset's values.
            self.plot_lattice()
        else:
            self.update_basis()

    def hide_basis_widgets(self, basis_no):
        # enable or disable basis coord widgets and update the basis
        checkbox = self.basis_check_widgets[basis_no]
        for le in self.basis_coord_widgets[basis_no]:
            le.setEnabled(checkbox.isChecked())
        self.form_factor_fields[basis_no].setEnabled(checkbox.isChecked())
        self.update_basis()

    def update_basis(self):
        ''' Rebuild the enabled user basis, colors and form factors from the
        checkbox states and redraw.'''
        # We get a list of basis atoms that are enabled
        enabled_basis_atoms = []
        for i in self.basis_check_widgets:
            enabled_basis_atoms.append(i.isChecked())
        # update the enabled_user_basis config and plot the lattice with the
        # new basis (a list of bools acts as a boolean row mask)
        new_basis = self.lattice_config['user_basis'][enabled_basis_atoms]
        new_colors = self.lattice_config['user_colors']
        new_colors = list(compress(new_colors, enabled_basis_atoms))
        form_factors = self.lattice_config['form_factors']
        form_factors = list(compress(form_factors, enabled_basis_atoms))
        self.lattice_config['enabled_user_basis'] = new_basis
        self.lattice_config['enabled_user_colors'] = new_colors
        self.lattice_config['enabled_form_factors'] = form_factors
        self.plot_lattice()

    def update_k(self, coord_no, text):
        ''' Store component ``coord_no`` of k_in (z forced negative) and
        redraw.'''
        if coord_no == 2:
            # The z-coordinate
            num = -abs(float(text))
        else:
            num = float(text)
        self.lattice_config['k_in'][coord_no] = num
        self.plot_lattice()

    def update_indices(self, indices):
        ''' Replace the highlight choices with ``[None] + indices``.'''
        self.lattice_config['indices'] = [None] + indices.tolist()
        self.highlight_combo.clear()
        str_list = [str(i) for i in self.lattice_config['indices']]
        self.highlight_combo.addItems(str_list)

    def update_highlight(self, i):
        ''' Highlight combo slot: select entry ``i`` and redraw.'''
        highlight = self.lattice_config['indices'][i]
        self.lattice_config['highlight'] = highlight
        self.plot_lattice(no_change=True)

    def show_all(self):
        ''' "Show all" checkbox slot.'''
        # get the state of the checkbox
        show_all = self.show_all_checkbox.isChecked()
        self.lattice_config['show_all'] = show_all
        self.plot_lattice(no_change=True)

    def plot_lattice(self, no_change=False):
        # This function takes the values from lattice_config and uses them to
        # update the plot. no_change is a flag, set if the basis/form factors
        # aren't changed

        # Get the viewing angle of the axes (so we can remember it)
        azim = self.static_ax.azim
        elev = self.static_ax.elev

        # Clear the axes
        self.static_ax.clear()
        self.static_ax2.clear()

        # Grab the basis and colors
        if self.lattice_config['lattice'] in self.presets_with_basis:
            # We are dealing with a preset with basis
            basis = self.lattice_config['preset_basis']
            n_basis = np.atleast_2d(basis).shape[0]
            colors = self.lattice_config['preset_colors']
            colors = colors[:n_basis]
            form_factors = self.lattice_config['form_factors']
        else:
            colors = self.lattice_config['enabled_user_colors']
            basis = self.lattice_config['enabled_user_basis']
            form_factors = self.lattice_config['enabled_form_factors']
        k_in = self.lattice_config['k_in']
        highlight = self.lattice_config['highlight']
        show_all = self.lattice_config['show_all']

        # Plot the new lattice
        self.static_fig, self.static_ax, self.static_ax2, indices = Scattering(
            basis=basis,
            k_in=k_in,
            colors=colors,
            form_factor=form_factors,
            highlight=highlight,
            fig=self.static_fig,
            axes=(self.static_ax, self.static_ax2),
            show_all=show_all,
            returns=True,
            return_indices=True,
            plots=False)
        self.static_ax.view_init(elev, azim)

        if not no_change:
            # If we don't only highlight stuff (ie we've changed the basis or
            # form factors), we also update the list of highlights
            self.update_indices(indices)

        # Remember to have the canvas draw it!
        self.static_canvas.draw()
Exemple #11
0
class lattice_window(QW.QMainWindow):
    """Qt main window for choosing, parametrising and plotting a lattice.

    The window is a horizontal layout with the options panel (lattice
    dropdown, lattice-parameter fields and a basis-atom table) on the left
    and a matplotlib canvas on the right. All user-editable state lives in
    the ``lattice_config`` dict; ``plot_lattice`` redraws from it.
    """

    def __init__(self):
        super().__init__()
        self._main = QW.QWidget()
        self.setCentralWidget(self._main)
        self.layout_main = QW.QHBoxLayout(self._main)
        self.create_variables()
        # We create the options and add them to our main layout (this also
        # creates the basis fields).
        self.create_options()
        self.layout_main.addLayout(self.layout_options)
        self.create_plot_window()
        self.layout_main.addWidget(self.static_canvas)

    def _fresh_config(self):
        """Return an independent deep copy of ``default_config``.

        ``default_config`` holds mutable values (the ``user_basis`` array and
        the colour lists) that are edited in place by ``update_basis_val``
        and ``update_basis_color``. A shallow ``dict.copy()`` would share
        those objects, so edits would leak into the defaults and a later
        "reset" would silently restore the edited values.
        """
        import copy
        return copy.deepcopy(self.default_config)

    def create_plot_window(self):
        """Create the default plot, its Qt canvas and a navigation toolbar.

        ``Lattice`` returns the figure and axes of the default lattice; mouse
        support is then enabled on the (Axes3D) axes via ``mouse_init``.
        """
        self.static_fig, self.static_ax = Lattice(returns=True, plots=False)
        self.static_canvas = FigureCanvas(self.static_fig)
        self.addToolBar(NavigationToolbar(self.static_canvas, self))
        self.static_ax.mouse_init()

    def create_variables(self):
        """Define the preset names, colours, default config and lookup tables."""
        # A list of names for available lattice presets
        self.lattices = [
            'simple cubic', 'primitive bcc', 'primitive fcc', 'tetragonal',
            'tetragonal body centred', 'tetragonal face centred',
            'orthorhombic', 'orthorhombic body centred',
            'orthorhombic face centred', 'orthorhombic base centred',
            'simple monoclinic', 'base centred monoclinic', 'hexagonal',
            'triclinic', 'rhombohedral', 'diamond', 'wurtzite', 'zincblende',
            'conventional fcc', 'conventional bcc'
        ]
        self.colors = [
            'xkcd:cement', 'red', 'blue', 'green', 'cyan', 'magenta', 'black',
            'yellow'
        ]
        # A dictionary of the default values for lattice plotting
        self.default_config = {
            'a1': d[0],
            'a2': d[1],
            'a3': d[2],
            'preset_basis': d[3],
            'user_colors': ['xkcd:cement'] * 5,
            'enabled_user_colors': ['xkcd:cement'],
            'preset_colors': ['xkcd:cement'] * 4,
            'sizes': d[5],
            'enabled_user_basis': np.zeros((1, 3)),
            'user_basis': np.zeros((5, 3)),
            'lim_type': d[6],
            'grid_type': None,
            'max_': d[8],
            'a': 1,
            'b': 1.2,
            'c': 1.5,
            'alpha': 80,
            'beta': 70,
            'gamma': 60,
            'lattice': 'simple cubic',
            'max_preset_basis': 4
        }
        # Needed parameters for each lattice, as indices into
        # (a, b, c, alpha, beta, gamma)
        self.needed_params = {
            'simple cubic': [0],
            'primitive bcc': [0],
            'primitive fcc': [0],
            'tetragonal': [0, 1],
            'tetragonal body centred': [0, 1],
            'tetragonal face centred': [0, 1],
            'orthorhombic': [0, 1, 2],
            'orthorhombic body centred': [0, 1, 2],
            'orthorhombic face centred': [0, 1, 2],
            'orthorhombic base centred': [0, 1, 2],
            'simple monoclinic': [0, 1, 2, 3],
            'base centred monoclinic': [0, 1, 2, 3],
            'hexagonal': [0],
            'triclinic': [0, 1, 2, 3, 4, 5],
            'rhombohedral': [0],
            'diamond': [0],
            'wurtzite': [0, 1],
            'zincblende': [0],
            'conventional fcc': [0],
            'conventional bcc': [0]
        }
        # Presets that come with their own basis -> number of basis atoms
        self.presets_with_basis = {
            'wurtzite': 4,
            'diamond': 2,
            'zincblende': 2,
            'conventional fcc': 4,
            'conventional bcc': 2
        }
        # Independent copy of the default config; this is what the user edits.
        self.lattice_config = self._fresh_config()

    def create_options(self):
        """Build the options panel into ``self.layout_options``.

        Contains the "Update plot" button, the lattice dropdown, one field per
        lattice parameter (only those the current lattice needs are enabled)
        and, via ``create_user_basis``, the basis table.
        """
        self.parameter_names = ['a', 'b', 'c', 'alpha', 'beta', 'gamma']
        self.parameter_text = [
            'a', 'b', 'c', 'alpha (degrees)', 'beta (degrees)',
            'gamma (degrees)'
        ]
        self.parameter_tooltips = [
            '', '', '', 'Angle between side 1 and 3',
            'Angle between side 1 and 2', 'Angle between side 2 and 3'
        ]
        self.param_labels = []
        self.param_fields = []
        self.layout_options = QW.QVBoxLayout()
        # Create the "show plot" button
        self.button_show = QW.QPushButton("Update plot", self)
        self.button_show.clicked.connect(self.update_lattice)

        # Create the lattice chooser dropdown. The presets with a basis are the
        # last entries of self.lattices, so a separator goes just before them.
        self.lattice_chooser = QW.QComboBox(self)
        self.lattice_chooser.addItems(self.lattices)
        sep_index = len(self.lattices) - len(self.presets_with_basis)
        self.lattice_chooser.insertSeparator(sep_index)
        self.lattice_chooser.activated[str].connect(self.update_lattice_name)

        # Create the parameter layout
        self.layout_parameters = QW.QFormLayout()
        self.layout_parameters.addRow(self.button_show, self.lattice_chooser)

        for n, name in enumerate(self.parameter_names):
            # Create all the parameter labels and fields.
            label = QW.QLabel(self.parameter_text[n], self)
            label.setToolTip(self.parameter_tooltips[n])
            self.param_labels.append(label)
            field = QW.QLineEdit()
            field.setToolTip(self.parameter_tooltips[n])

            # Only allow floats with 2 decimals as input
            field.setValidator(QG.QDoubleValidator(decimals=2))

            # Populate with default values; disabled until needed
            field.setText(str(self.lattice_config[name]))
            field.setEnabled(False)

            # returnPressed carries no arguments, so the defaults bind this
            # parameter's name and field.
            field.returnPressed.connect(
                lambda name=name, el=field: self.update_config_parameter(
                    name, el.text()))

            self.param_fields.append(field)

        # When everything has been created we add the parameters to the
        # layout, along with the labels
        for label, field in zip(self.param_labels, self.param_fields):
            self.layout_parameters.addRow(label, field)

        self.layout_options.addLayout(self.layout_parameters)
        # Enable only the needed parameter fields.
        for n in self.needed_params[self.lattice_config['lattice']]:
            self.param_fields[n].setEnabled(True)
        self.create_user_basis()

    def create_preset_basis(self, n_basis):
        """Create a read-only basis table for a preset with ``n_basis`` atoms.

        Coordinates are filled in later by ``update_preset_basis_widgets``;
        only the colour dropdowns are interactive.
        """
        # So far the largest number of atoms in a preset basis is 4.
        self.layout_preset_basis = QW.QVBoxLayout()
        self.basis_title = QW.QLabel('Basis coordinates')
        self.basis_title.setAlignment(QC.Qt.AlignCenter)

        font = QG.QFont()
        font.setBold(True)
        self.basis_title.setFont(font)
        self.layout_preset_basis.addWidget(self.basis_title)

        self.layout_preset_basis_grid = QW.QGridLayout()
        names = ['x', 'y', 'z', 'color']
        for n, name in enumerate(names):
            label = QW.QLabel(name)
            label.setAlignment(QC.Qt.AlignCenter)
            self.layout_preset_basis_grid.addWidget(label, 0, n)
        n_coords = 3
        self.preset_basis_coord_widgets = np.empty((n_basis, n_coords),
                                                   dtype=object)
        self.preset_basis_color_widgets = []
        for i in range(n_basis):
            for j in range(n_coords):
                el = QW.QLineEdit()
                el.setEnabled(False)
                if i == 0:
                    el.setText('0')
                self.preset_basis_coord_widgets[i, j] = el
                self.layout_preset_basis_grid.addWidget(el, i + 1, j)
            el = QW.QComboBox()
            el.addItems(self.colors)
            # activated[str] passes the chosen text as the first argument, so
            # take it explicitly; the default binds this combo box's row.
            el.activated[str].connect(
                lambda text, i=i: self.update_basis_color('preset', i, text))
            self.preset_basis_color_widgets.append(el)
            self.layout_preset_basis_grid.addWidget(el, i + 1, n_coords)
        self.layout_preset_basis.addLayout(self.layout_preset_basis_grid)
        self.current_basis_layout = self.layout_preset_basis
        self.layout_options.addLayout(self.layout_preset_basis)

    def create_user_basis(self):
        """Create the editable basis table (5 atoms: x, y, z, colour, enable).

        Only the first atom starts enabled. Editing a coordinate, choosing a
        colour or toggling a checkbox updates ``lattice_config`` and replots.
        Also resets ``lattice_config`` to a fresh copy of the defaults.
        """
        # Basis-stuff
        font = QG.QFont()
        font.setBold(True)
        self.layout_basis = QW.QVBoxLayout()
        self.basis_title = QW.QLabel('Basis coordinates')
        self.basis_title.setAlignment(QC.Qt.AlignCenter)
        self.basis_title.setFont(font)
        self.layout_basis.addWidget(self.basis_title)
        self.layout_basis_grid = QW.QGridLayout()
        n_basis = 5
        n_coords = 3
        names = ['x', 'y', 'z', 'color']
        for n, name in enumerate(names):
            label = QW.QLabel(name)
            label.setAlignment(QC.Qt.AlignCenter)
            self.layout_basis_grid.addWidget(label, 0, n)
        self.basis_coord_widgets = np.empty((n_basis, n_coords), dtype=object)
        self.basis_color_widgets = []
        self.basis_check_widgets = []
        for i in range(n_basis):
            for j in range(n_coords):
                # A QLineEdit for each of the basis atom's coordinates
                el = QW.QLineEdit()
                el.setText(str(0))
                # Only the first atom is enabled initially
                el.setEnabled(i == 0)
                el.setValidator(QG.QDoubleValidator(decimals=2))
                # editingFinished carries no arguments, so the defaults bind
                # this widget's row, column and the widget itself.
                el.editingFinished.connect(lambda i=i, j=j, el=el: self.
                                           update_basis_val(i, j, el.text()))
                self.basis_coord_widgets[i, j] = el
                self.layout_basis_grid.addWidget(el, i + 1, j)

            # A colour dropdown for each basis atom. activated[str] passes the
            # chosen text as the first argument (it would overwrite a
            # defaulted first parameter), so take it explicitly as `text`.
            el = QW.QComboBox()
            el.addItems(self.colors)
            el.activated[str].connect(
                lambda text, i=i: self.update_basis_color('user', i, text))
            self.basis_color_widgets.append(el)
            self.layout_basis_grid.addWidget(el, i + 1, n_coords)

            # A checkbox enabling/disabling each basis atom. The initial state
            # is set before connecting so construction does not trigger a
            # replot. stateChanged passes the new state, hence `_state`.
            check = QW.QCheckBox()
            check.setChecked(i == 0)
            check.stateChanged.connect(
                lambda _state, i=i: self.hide_basis_widgets(i))
            self.basis_check_widgets.append(check)
            self.layout_basis_grid.addWidget(check, i + 1, n_coords + 2)

        # We also reset the basis and colors:
        self.lattice_config = self._fresh_config()

        self.current_basis_layout = self.layout_basis
        self.layout_basis.addLayout(self.layout_basis_grid)
        self.layout_options.addLayout(self.layout_basis)

    def update_lattice(self):
        """Recompute the lattice vectors (and preset basis) and replot.

        Angles in ``lattice_config`` are in degrees and converted to radians
        before being passed to ``lattices.chooser``.
        """
        a = self.lattice_config['a']
        b = self.lattice_config['b']
        c = self.lattice_config['c']
        alpha = self.lattice_config['alpha'] * np.pi / 180
        beta = self.lattice_config['beta'] * np.pi / 180
        gamma = self.lattice_config['gamma'] * np.pi / 180
        name = self.lattice_config['lattice']
        (a1, a2, a3), basis, _ = lattices.chooser(lattice_name=name,
                                                  a=a,
                                                  b=b,
                                                  c=c,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  gamma=gamma)

        # Update primitive lattice vectors and (preset) basis.
        self.lattice_config.update(
            dict(zip(('a1', 'a2', 'a3', 'preset_basis'), (a1, a2, a3, basis))))
        if name in self.presets_with_basis:
            self.update_preset_basis_widgets()
        self.plot_lattice()

    def update_lattice_name(self, text):
        """Switch to the lattice preset named ``text``.

        Swaps the basis table (preset vs. user), enables only the needed
        parameter fields, restores default parameter values and replots.
        """
        # Delete current basis layout.
        self.delete_layout(self.current_basis_layout)
        if text in self.presets_with_basis:
            # We have a preset with a basis, so we delete the user basis and
            # load the preset basis
            self.create_preset_basis(self.presets_with_basis[text])
            self.current_basis_layout = self.layout_preset_basis
        else:
            self.create_user_basis()
            self.current_basis_layout = self.layout_basis

        self.lattice_config['lattice'] = text

        # Disable all fields and enable only those needed. It's the easiest
        # way, if maybe a bit redundant.
        for le in self.param_fields:
            le.setEnabled(False)
        for n in self.needed_params[text]:
            self.param_fields[n].setEnabled(True)

        # We should also load the default values for the parameters
        for n, param in enumerate(self.parameter_names):
            self.lattice_config[param] = self.default_config[param]
            self.param_fields[n].setText(str(self.lattice_config[param]))

        # And then we update the lattice
        self.update_lattice()

    def update_preset_basis_widgets(self):
        """Show the preset basis coordinates (3 decimals) in the table."""
        basis = np.atleast_2d(self.lattice_config['preset_basis'])
        for n_atom, atom in enumerate(basis):
            for n_coord, coord in enumerate(atom):
                el = self.preset_basis_coord_widgets[n_atom, n_coord]
                el.setText("{0:.3f}".format(coord))

    def update_config_parameter(self, param, text):
        """Store ``float(text)`` as lattice parameter ``param`` and replot.

        Text that is not a float leaves the config unchanged (the lattice is
        still recomputed).
        """
        try:
            self.lattice_config[param] = float(text)
        except ValueError:
            pass
        self.update_lattice()

    def update_basis_color(self, type_, num, text):
        """Set colour ``num`` of the ``'{type_}_colors'`` list and replot.

        ``type_`` is ``'user'`` or ``'preset'``; ``text`` is lower-cased.
        """
        colors = self.lattice_config['{}_colors'.format(type_)]
        text = text.lower()
        colors[num] = text
        self.update_basis()

    def plot_lattice(self):
        """Redraw the lattice on the canvas from ``lattice_config``."""
        # Clear the axes
        self.static_ax.clear()

        # Grab lattice vectors from lattice_config
        a1 = self.lattice_config['a1']
        a2 = self.lattice_config['a2']
        a3 = self.lattice_config['a3']

        # Grab the basis and colors
        if self.lattice_config['lattice'] in self.presets_with_basis:
            # We are dealing with a preset with basis
            basis = self.lattice_config['preset_basis']
            n_basis = np.atleast_2d(basis).shape[0]
            colors = self.lattice_config['preset_colors']
            colors = colors[:n_basis]
        else:
            colors = self.lattice_config['enabled_user_colors']
            basis = self.lattice_config['enabled_user_basis']

        # Plot the new lattice
        self.static_fig, self.static_ax = Lattice(a1=a1,
                                                  a2=a2,
                                                  a3=a3,
                                                  basis=basis,
                                                  colors=colors,
                                                  fig=self.static_fig,
                                                  ax=self.static_ax,
                                                  returns=True,
                                                  plots=False,
                                                  checks=False)

        # Remember to have the canvas draw it!
        self.static_canvas.draw()

    def update_basis_val(self, basis_no, coord_no, val):
        """Store coordinate ``coord_no`` of user atom ``basis_no`` and replot.

        Text that ``float()`` rejects is ignored, as in
        ``update_config_parameter``. NOTE(review): the QDoubleValidator is
        locale-aware and may accept e.g. a decimal comma — confirm.
        """
        try:
            value = float(val)
        except ValueError:
            return
        self.lattice_config['user_basis'][basis_no, coord_no] = value
        self.update_basis()

    def hide_basis_widgets(self, basis_no):
        """Enable/disable atom ``basis_no``'s coordinate fields per its checkbox."""
        checkbox = self.basis_check_widgets[basis_no]
        for le in self.basis_coord_widgets[basis_no]:
            le.setEnabled(checkbox.isChecked())
        self.update_basis()

    def update_basis(self):
        """Rebuild the enabled user basis/colours from the checkboxes and replot."""
        # One bool per basis atom: is its checkbox ticked?
        enabled_basis_atoms = [i.isChecked() for i in self.basis_check_widgets]
        # Boolean-mask the rows of user_basis and the matching colours
        new_basis = self.lattice_config['user_basis'][enabled_basis_atoms]
        self.lattice_config['enabled_user_basis'] = new_basis
        new_colors = self.lattice_config['user_colors']
        new_colors = list(compress(new_colors, enabled_basis_atoms))
        self.lattice_config['enabled_user_colors'] = new_colors
        self.plot_lattice()

    def delete_layout(self, layout):
        """Recursively remove all widgets and sub-layouts from ``layout``."""
        if layout is not None:
            while layout.count():
                item = layout.takeAt(0)
                widget = item.widget()
                if widget is not None:
                    widget.setParent(None)
                else:
                    self.delete_layout(item.layout())
class PowerAperture(QMainWindow, Ui_MainWindow):
    """Main window plotting the search-radar power-aperture product vs range.

    The input line edits are created by ``setupUi`` (from ``Ui_MainWindow``);
    pressing Enter in any of them recomputes the curve with
    ``search_radar_range.power_aperture`` and redraws the embedded canvas.
    """

    def __init__(self):
        # Zero-argument super() follows the real MRO; super(self.__class__,
        # self) recurses infinitely as soon as this class is subclassed.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes: when the user presses Enter the plot
        # updates.
        input_fields = (
            self.target_min_range,
            self.target_max_range,
            self.system_temperature,
            self.search_volume,
            self.noise_figure,
            self.losses,
            self.signal_to_noise,
            self.scan_time,
            self.target_rcs,
        )
        for field in input_fields:
            field.returnPressed.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout, plus a navigation toolbar
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Reads every input field, converts the dB inputs (noise figure,
        losses, target RCS, signal-to-noise) to linear units, evaluates the
        power-aperture product at 2000 ranges between the minimum and maximum
        target range and plots it in dB against range in km. If any field
        does not contain a number, the current plot is left unchanged.
        :return: None
        """
        try:
            # Get the range from the form
            target_min_range = float(self.target_min_range.text())
            target_max_range = float(self.target_max_range.text())

            # Convert the noise figure (dB) to noise factor (linear)
            noise_figure = float(self.noise_figure.text())
            noise_factor = 10.0**(noise_figure / 10.0)

            # Convert the losses, snr and target rcs from dB to linear units
            losses = 10.0**(float(self.losses.text()) / 10.0)
            target_rcs = 10.0**(float(self.target_rcs.text()) / 10.0)
            signal_to_noise = 10.0**(float(self.signal_to_noise.text()) / 10.0)

            system_temperature = float(self.system_temperature.text())
            search_volume = float(self.search_volume.text())
            scan_time = float(self.scan_time.text())
        except ValueError:
            # Empty or partial input: keep the current plot rather than let
            # the exception escape the Qt slot (which can terminate a PyQt5
            # application).
            return

        # Set up the range array
        target_range = linspace(target_min_range, target_max_range, 2000)

        # Set up the input args
        kwargs = {
            'target_range': target_range,
            'system_temperature': system_temperature,
            'search_volume': search_volume,
            'noise_factor': noise_factor,
            'losses': losses,
            'signal_to_noise': signal_to_noise,
            'scan_time': scan_time,
            'target_rcs': target_rcs
        }

        # Calculate the power aperture product
        power_aperture = search_radar_range.power_aperture(**kwargs)

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Display the results (range in km, power aperture in dB)
        self.axes1.plot(target_range / 1.0e3, 10.0 * log10(power_aperture), '')

        # Set the plot title and labels
        self.axes1.set_title('Power Aperture Product', size=14)
        self.axes1.set_xlabel('Target Range (km)', size=14)
        self.axes1.set_ylabel('Power Aperture (dB)', size=14)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
Exemple #13
0
class Base_Plot(QtCore.QObject):
    def __init__(self, parent, widget, mpl_layout):
        """Set up the matplotlib figure, canvas, toolbar and event hooks.

        Args:
            parent: Qt parent passed to ``QObject``; also stored as
                ``self.parent`` (NOTE(review): this instance attribute shadows
                ``QObject.parent()`` — confirm nothing calls the method).
            widget: widget handed to ``self.NavigationToolbar`` with the canvas.
            mpl_layout: Qt layout that ``create_canvas`` adds the canvas to.

        NOTE(review): ``create_canvas`` iterates ``self.ax``, which is not set
        here — presumably a subclass assigns it before calling this
        ``__init__``; confirm.
        """
        super().__init__(parent)
        self.parent = parent

        self.widget = widget
        self.mpl_layout = mpl_layout
        self.fig = mplfigure.Figure()
        # Register the custom scale classes with matplotlib (presumably the
        # 'abslog' / 'bisymlog' names used by _set_scale — verify).
        mpl.scale.register_scale(AbsoluteLogScale)
        mpl.scale.register_scale(BiSymmetricLogScale)
        
        # Set plot variables
        self.x_zoom_constraint = False
        self.y_zoom_constraint = False
        
        # Creates self.canvas (needed by everything below)
        self.create_canvas()        
        # NOTE(review): NavigationToolbar is looked up on self; it is not
        # defined in the visible part of this class — presumably provided
        # elsewhere in the class or by a subclass.
        self.NavigationToolbar(self.canvas, self.widget, coordinates=True)
        
        # AutoScale flags for [x, y]; set_xlim / set_ylim return early when
        # the corresponding flag is False.
        self.autoScale = [True, True]

        self.i = 0
        
        # Connect Signals (the draw_event connection id is kept)
        self._draw_event_signal = self.canvas.mpl_connect('draw_event', self._draw_event)
        self.canvas.mpl_connect('button_press_event', lambda event: self.click(event))
        self.canvas.mpl_connect('key_press_event', lambda event: self.key_press(event))
        # self.canvas.mpl_connect('key_release_event', lambda event: self.key_release(event))

        # Initial draw
        self._draw_event()
    
    def create_canvas(self):
        """Create the Qt canvas for ``self.fig`` and initialise per-axes state.

        Adds the canvas to ``self.mpl_layout``, gives it strong keyboard
        focus and draws it once. Every axes in ``self.ax`` then gets its own
        scale-menu flags (``ax.scale['x']`` / ``ax.scale['y']``, with only
        'linear' set) and scientific tick-label formatting. Finally the
        rendered region of the last axes is cached as ``background_data``.
        """
        self.canvas = FigureCanvas(self.fig)
        self.mpl_layout.addWidget(self.canvas)
        self.canvas.setFocusPolicy(QtCore.Qt.StrongFocus)
        self.canvas.draw()

        # Set scales. Each axes and each of its x/y entries gets an
        # independent dict: previously every axes' 'x' entry aliased the same
        # dict, so an in-place toggle on one axes would leak into all others.
        scales = {'linear': True, 'log': False, 'abslog': False, 'bisymlog': False}
        for ax in self.ax:
            ax.scale = {'x': deepcopy(scales), 'y': deepcopy(scales)}
            ax.ticklabel_format(scilimits=(-4, 4), useMathText=True)

        # Get background (bbox of the last axes, as before — made explicit
        # instead of relying on the loop variable leaking out of the loop).
        self.background_data = self.canvas.copy_from_bbox(self.ax[-1].bbox)
    
    def _find_calling_axes(self, event):
        """Return the axes in ``self.ax`` that *event* refers to, or None.

        *event* may be one of the axes itself, or an object carrying an
        ``inaxes`` attribute (such as a matplotlib mouse/key event). The first
        match in ``self.ax`` order wins.
        """
        matches = (
            candidate for candidate in self.ax
            if candidate == event
            or (hasattr(event, 'inaxes') and event.inaxes == candidate)
        )
        return next(matches, None)
    
    def set_xlim(self, axes, x):
        """Autoscale the x-limits of *axes* to the data *x*.

        Does nothing when x-autoscaling is off (``self.autoScale[0]``).
        A 'linear' axis spans exactly ``[min(x), max(x)]``. A 'log', 'abslog'
        or 'bisymlog' axis spans the whole decades of the nonzero magnitudes
        of *x*, widened by one decade on each side. The axis is left
        untouched when the limits would be NaN/inf, when a log-type axis has
        no nonzero data, when the scale is none of the above, or when the
        limits are already current.
        """
        if not self.autoScale[0]:   # obey autoscale right click option
            return

        scale = axes.get_xscale()
        if scale == 'linear':
            xlim = [np.min(x), np.max(x)]
        elif scale in ('log', 'abslog', 'bisymlog'):
            abs_x = np.abs(x)
            abs_x = abs_x[np.nonzero(abs_x)]    # exclude 0's
            if abs_x.size == 0:                 # nothing to take a log of
                return
            min_data = np.ceil(np.log10(np.min(abs_x)))
            max_data = np.floor(np.log10(np.max(abs_x)))
            xlim = [10**(min_data-1), 10**(max_data+1)]
        else:
            # Unrecognised scale: leave the limits alone (previously xlim was
            # left unbound here).
            return

        if not np.all(np.isfinite(xlim)):
            return
        # get_xlim() returns a tuple; comparing it against a list is always
        # unequal, so compare as tuples to only update on a real change.
        if tuple(xlim) != tuple(axes.get_xlim()):
            axes.set_xlim(xlim)
    
    def set_ylim(self, axes, y):
        """Autoscale the y-limits of *axes* to the data *y*.

        Does nothing when y-autoscaling is off (``self.autoScale[1]``) or
        when *y* has no finite values. If all finite values are equal the
        range is widened by 0.1 each way.

        * 'linear': finite data range padded by 10% on each side.
        * 'log' / 'abslog': whole decades of the nonzero finite magnitudes,
          widened by one decade each side; ``[1e-7, 1e-1]`` if there are none.
        * 'bisymlog': each end is moved out to the next power of ten on its
          own side of zero; an end that is exactly 0 stays at 0.

        Any other scale, or NaN/inf limits, leaves the axis untouched; the
        limits are only applied when they differ from the current ones.
        """
        if not self.autoScale[1]:   # obey autoscale right click option
            return

        finite_y = np.array(y)[np.isfinite(y)]
        if finite_y.size == 0:      # nothing finite to scale to
            return
        min_data = finite_y.min()
        max_data = finite_y.max()

        if min_data == max_data:
            min_data -= 10**-1
            max_data += 10**-1

        scale = axes.get_yscale()
        if scale == 'linear':
            span = np.abs(max_data - min_data)
            ylim = [min_data - span*0.1, max_data + span*0.1]

        elif scale in ('log', 'abslog'):
            abs_y = np.abs(y)
            abs_y = abs_y[np.nonzero(abs_y)]    # exclude 0's
            abs_y = abs_y[np.isfinite(abs_y)]   # exclude nan, inf

            if abs_y.size == 0:                 # if no data, assign default
                ylim = [10**-7, 10**-1]
            else:
                min_exp = np.ceil(np.log10(np.min(abs_y)))
                max_exp = np.floor(np.log10(np.max(abs_y)))
                ylim = [10**(min_exp-1), 10**(max_exp+1)]

        elif scale == 'bisymlog':
            min_sign = np.sign(min_data)
            max_sign = np.sign(max_data)

            # Decide each end's exponent from that end's own sign only. The
            # old shared "min == 0 or max == 0" test skipped the log for a
            # nonzero end (e.g. [-50, 0] gave a lower limit of -1e-49).
            if min_sign > 0:
                min_exp = np.ceil(np.log10(min_data))
            elif min_sign < 0:
                min_exp = np.floor(np.log10(-min_data))
            else:
                min_exp = 0     # zero end: multiplied by sign 0 below

            if max_sign > 0:
                max_exp = np.floor(np.log10(max_data))
            elif max_sign < 0:
                max_exp = np.ceil(np.log10(-max_data))
            else:
                max_exp = 0     # zero end: multiplied by sign 0 below

            ylim = [min_sign*10**(min_exp - min_sign),
                    max_sign*10**(max_exp + max_sign)]

        else:
            # Unrecognised scale: leave the limits alone (previously ylim was
            # left unbound here).
            return

        if not np.all(np.isfinite(ylim)):
            return
        # get_ylim() returns a tuple; compare as tuples so the axis is only
        # updated when the limits actually change.
        if tuple(ylim) != tuple(axes.get_ylim()):
            axes.set_ylim(ylim)
    
    def update_xylim(self, axes, xlim=None, ylim=None, force_redraw=True):
        """Update the x and y limits of *axes*, then optionally redraw.

        Args:
            axes: the axes to update.
            xlim, ylim: explicit ``[min, max]`` limits. ``None`` or an empty
                sequence (the old default) means "autoscale from the data"
                via ``set_xlim`` / ``set_ylim``.
            force_redraw: if True, call ``self._draw_event()`` afterwards.

        Nothing happens while the axes holds fewer than two x or y values
        (e.g. right after creation). Axes on a 'bisymlog' scale also have
        that scale re-applied through ``_set_scale``.
        """
        data = self._get_data(axes)

        # on creation, there is no data, don't update
        if np.shape(data['x'])[0] < 2 or np.shape(data['y'])[0] < 2:
            return

        for axis, lim in (('x', xlim), ('y', ylim)):
            # Set limits; look the methods up by name instead of eval().
            if lim is None or len(lim) == 0:
                getattr(self, f'set_{axis}lim')(axes, data[axis])
            else:
                getattr(axes, f'set_{axis}lim')(lim)

            # If bisymlog, also update scaling, C
            if getattr(axes, f'get_{axis}scale')() == 'bisymlog':
                self._set_scale(axis, 'bisymlog', axes)

            # TODO: Do this some day, probably need to create the annotation
            # during canvas creation. Sketch (move the exponent):
            #   exp_loc = {'x': (.89, .01), 'y': (.01, .96)}
            #   getattr(axes, f'get_{axis}axis')().get_offset_text().set_visible(False)
            #   ax_max = max(getattr(axes, f'get_{axis}ticks')())
            #   oom = np.floor(np.log10(ax_max)).astype(int)
            #   axes.annotate(fr'$\times10^{oom}$', xy=exp_loc[axis],
            #                 xycoords='axes fraction')

        if force_redraw:
            self._draw_event()  # force a draw
    
    def _get_data(self, axes):      # NOT Generic
        """Collect the data extents plotted on *axes* as ``{'x': ..., 'y': ...}``.

        Looks at ``axes.item`` in priority order:

        * ``'exp_data'``: x/y are the scatter offsets (x is extended with the
          first column of ``sim_data.raw_data`` when present and non-empty).
        * ``'weight_unc_fcn'``: x/y are that line's data.
        * ``'density'`` / ``'qq_data'`` (density preferred): each coordinate
          becomes ``[overall min, overall max]`` of the finite values of all
          artists in that list.

        Coordinates with no usable data stay as empty lists (callers such as
        ``update_xylim`` treat that as "nothing to scale to").
        """
        data = {'x': [], 'y': []}
        if 'exp_data' in axes.item:
            data_plot = axes.item['exp_data'].get_offsets().T
            if np.shape(data_plot)[1] > 1:
                data['x'] = data_plot[0, :]
                data['y'] = data_plot[1, :]

            # append sim_x if it exists
            if 'sim_data' in axes.item and hasattr(axes.item['sim_data'], 'raw_data'):
                if axes.item['sim_data'].raw_data.size > 0:
                    data['x'] = np.append(data['x'], axes.item['sim_data'].raw_data[:, 0])

        elif 'weight_unc_fcn' in axes.item:
            data['x'] = axes.item['weight_unc_fcn'].get_xdata()
            data['y'] = axes.item['weight_unc_fcn'].get_ydata()

        elif any(key in axes.item for key in ['density', 'qq_data', 'sim_data']):
            # Same preference as the old sorted np.intersect1d: density first.
            names = [key for key in ('density', 'qq_data') if key in axes.item]
            if not names:
                # Only 'sim_data' is present; this branch has no reader for it
                # (indexing the empty intersection used to raise IndexError).
                return data
            name = names[0]

            for n, coord in enumerate(['x', 'y']):
                xyrange = np.array([])
                for item in axes.item[name]:
                    if name == 'qq_data':
                        coordData = item.get_offsets()
                        if coordData.size == 0:
                            continue
                        coordData = coordData[:, n]
                    else:   # 'density'
                        coordData = getattr(item, f'get_{coord}data')()

                    coordData = np.array(coordData)[np.isfinite(coordData)]
                    if coordData.size == 0:
                        continue

                    xyrange = np.append(xyrange, [coordData.min(), coordData.max()])

                if xyrange.size == 0:
                    # No finite values for this coordinate: leave it empty
                    # instead of taking the min/max of an empty array.
                    continue
                xyrange = np.reshape(xyrange, (-1, 2))
                data[coord] = [np.min(xyrange[:, 0]), np.max(xyrange[:, 1])]

        return data
    
    def _set_scale(self, coord, type, event, update_xylim=False):
        """Switch the *coord* axis of the calling axes to scale *type*.

        Args:
            coord: 'x' or 'y'.
            type: scale name such as 'linear', 'log', 'abslog' or 'bisymlog'
                (the name shadows the builtin but is kept for keyword callers).
            event: one of ``self.ax`` or an event whose ``inaxes`` is one,
                resolved with ``_find_calling_axes``; if nothing matches, the
                call does nothing.
            update_xylim: if True, re-run ``update_xylim`` on the axes after.

        The scale-menu flags ``ax.scale[coord]`` of every axes sharing this
        axis are set so only *type* is True. 'log' masks non-positive values.
        'bisymlog' derives its ``C`` parameter from the finite data range
        when that range is nonzero, otherwise uses the scale's default.
        'linear' also installs a scientific tick formatter.
        """
        def RoundToSigFigs(x, p):
            x = np.asarray(x)
            x_positive = np.where(np.isfinite(x) & (x != 0), np.abs(x), 10**(p-1))
            mags = 10 ** (p - 1 - np.floor(np.log10(x_positive)))
            return np.round(x * mags) / mags

        # find correct axes
        axes = self._find_calling_axes(event)
        if axes is None:
            return

        # Set scale menu boolean on every axes sharing this axis
        if coord == 'x':
            shared_axes = axes.get_shared_x_axes().get_siblings(axes)
        else:
            shared_axes = axes.get_shared_y_axes().get_siblings(axes)

        for shared in shared_axes:
            shared.scale[coord] = dict.fromkeys(shared.scale[coord], False)  # sets all types: False
            shared.scale[coord][type] = True                                 # set selected type: True

        # Apply selected scale; look the setter up by name instead of eval().
        set_scale = getattr(axes, f'set_{coord}scale')
        if type == 'log':
            set_scale('log', **{f'nonpos{coord}': 'mask'})
        elif type == 'bisymlog':
            scale_kwargs = {}
            data = self._get_data(axes)[coord]
            if len(data) != 0:
                finite_data = np.array(data)[np.isfinite(data)]   # ignore nan and inf
            else:
                finite_data = np.array([])

            if finite_data.size != 0 and finite_data.min() != finite_data.max():
                min_data = finite_data.min()
                max_data = finite_data.max()
                # if zero is within total range, find largest pos or neg range
                if np.sign(max_data) != np.sign(min_data):
                    C = 0
                    for part in (finite_data[finite_data >= 0],
                                 finite_data[finite_data <= 0]):
                        span = np.abs(part.max() - part.min())
                        if span > C:
                            C = span
                            max_data = part.max()
                else:
                    C = np.abs(max_data - min_data)
                C *= 10**(OoM(max_data) + 2)  # scaling factor TODO: + 1 looks loglike, + 2 linear like
                C = RoundToSigFigs(C, 1)      # round to 1 significant figure
                scale_kwargs['C'] = float(C)
            set_scale('bisymlog', **scale_kwargs)
        else:
            # 'linear', 'abslog' or any other registered scale name
            set_scale(type)

        if type == 'linear' and coord == 'x':
            formatter = MathTextSciSIFormatter(useOffset=False, useMathText=True)
            axes.xaxis.set_major_formatter(formatter)

        elif type == 'linear' and coord == 'y':
            formatter = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
            formatter.set_powerlimits([-3, 4])
            axes.yaxis.set_major_formatter(formatter)

        if update_xylim:
            self.update_xylim(axes)
 
    def _animate_items(self, bool=True):
        """Set the ``animated`` flag on every artist owned by this plot.

        For each axes in ``self.ax`` this flags the x and y axis artists, the
        legend (when one exists) and every entry of ``axes.item``. An entry is
        either a single artist or a list whose members are artists or dicts
        holding the artist under the ``'line'`` key.

        Args:
            bool: value passed to ``set_animated`` (the name shadows the
                builtin but is part of the existing signature).
        """
        flag = bool
        for axes in self.ax:
            for axis_artist in (axes.xaxis, axes.yaxis):
                axis_artist.set_animated(flag)

            legend = axes.get_legend()
            if legend is not None:
                legend.set_animated(flag)

            for entry in axes.item.values():
                if not isinstance(entry, list):
                    entry.set_animated(flag)
                    continue
                for member in entry:
                    target = member['line'] if isinstance(member, dict) else member
                    target.set_animated(flag)
    
    def _draw_items_artist(self):
        """Blit all managed artists on top of the cached background.

        Restores ``self.background_data`` onto the canvas. For every axes it
        then redraws the x/y axis, each entry of ``axes.item`` and the legend
        (when present). An entry is a single artist or a list of artists or
        dicts holding the artist under ``'line'``. Finally it asks the canvas
        to update.
        """
        canvas = self.canvas
        canvas.restore_region(self.background_data)
        for axes in self.ax:
            draw = axes.draw_artist
            draw(axes.xaxis)
            draw(axes.yaxis)
            for entry in axes.item.values():
                if not isinstance(entry, list):
                    draw(entry)
                    continue
                for member in entry:
                    draw(member['line'] if isinstance(member, dict) else member)

            legend = axes.get_legend()
            if legend is not None:
                draw(legend)

        canvas.update()
    
    def set_background(self):
        """Redraw the whole figure and cache it as the blitting background.

        The ``draw_event`` handler (``self._draw_event``) calls this method.
        It is therefore disconnected around the explicit ``canvas.draw()``,
        which would otherwise recurse forever. Afterwards the handler is
        reconnected and the figure bbox is copied into ``background_data``.
        """
        canvas = self.canvas
        canvas.mpl_disconnect(self._draw_event_signal)
        canvas.draw()  # needed e.g. when the shock changes
        self._draw_event_signal = canvas.mpl_connect('draw_event', self._draw_event)
        self.background_data = canvas.copy_from_bbox(self.fig.bbox)
    
    def _draw_event(self, event=None):
        """Rebuild the blitting background after a full canvas redraw.

        Connected to the canvas ``draw_event`` (new window, resize, ...). It
        marks every artist animated, re-captures the background through
        ``set_background`` and blits the artists back on top.

        Args:
            event: the matplotlib draw event; not used.
        """
        self._animate_items(True)
        self.set_background()
        self._draw_items_artist()
    
    def clear_plot(self, ignore=(), draw=True):
        """Blank the data of every managed artist and remove all legends.

        Args:
            ignore: container of artist categories to leave untouched. Valid
                categories are ``'scatter'`` (objects with ``set_offsets``),
                ``'line'`` (objects with ``set_xdata``/``set_ydata``) and
                ``'text'`` (objects with ``set_text``). Defaults to an
                immutable empty tuple rather than a shared mutable list.
            draw: when true, refresh the blitted canvas afterwards.
        """
        clear_scatter = 'scatter' not in ignore
        clear_lines = 'line' not in ignore
        clear_text = 'text' not in ignore

        for axis in self.ax:
            legend = axis.get_legend()
            if legend is not None:
                legend.remove()

            # NOTE(review): list-valued entries of axis.item have none of the
            # setters checked below, so they are never cleared; confirm intended.
            for item in axis.item.values():
                # The first interface that matches decides the category, even
                # when that category is ignored.
                if hasattr(item, 'set_offsets'):        # scatter points
                    if clear_scatter:
                        item.set_offsets([np.nan, np.nan])
                elif hasattr(item, 'set_xdata') and hasattr(item, 'set_ydata'):
                    if clear_lines:                      # lines
                        item.set_xdata([np.nan, np.nan])
                        item.set_ydata([np.nan, np.nan])
                elif hasattr(item, 'set_text'):          # text boxes
                    if clear_text:
                        item.set_text('')

        if draw:
            self._draw_event()

    def click(self, event):
        """Handle mouse presses on the canvas.

        Only right clicks (button 3) are handled. With no toolbar mode (zoom
        or pan) active, the context menu opens. Otherwise a right double-click
        sends the toolbar back to its home view.
        """
        if event.button != 3:
            return

        if not self.toolbar.mode:
            self._popup_menu(event)
        elif event.dblclick:
            self.toolbar.home()

    def key_press(self, event):
        """Handle key presses on the canvas.

        - ``escape`` leaves zoom-rect or pan/zoom mode, whichever is active.
        - ``x`` / ``y`` toggle ``x_zoom_constraint`` / ``y_zoom_constraint``.
          These flags are not yet used to constrain zoom/pan.
        - ``s``, ``l``, ``L`` and ``k`` are swallowed, presumably to block
          matplotlib's default key bindings for them.
        - Any other key goes to matplotlib's ``key_press_handler``.
        """
        key = event.key
        toolbar = self.toolbar

        if key == 'escape':
            if toolbar.mode == 'zoom rect':
                toolbar.zoom()      # zoom() toggles zoom mode off
            elif toolbar.mode == 'pan/zoom':
                toolbar.pan()       # pan() toggles pan mode off
        elif key == 'x':
            self.x_zoom_constraint = not self.x_zoom_constraint
        elif key == 'y':
            self.y_zoom_constraint = not self.y_zoom_constraint
        elif key in ('s', 'l', 'L', 'k'):
            pass
        else:
            key_press_handler(event, self.canvas, toolbar)
    
    # def key_release(self, event):
        # print(event.key, 'released')
    
    def NavigationToolbar(self, *args, **kwargs):
        """Create the custom navigation toolbar and add it to ``mpl_layout``.

        Positional and keyword arguments are accepted but ignored. The
        toolbar is built for ``self.canvas`` with ``self.widget`` as its
        second argument (presumably the parent) and is stored on
        ``self.toolbar``.
        """
        toolbar = CustomNavigationToolbar(self.canvas, self.widget, coordinates=True)
        self.toolbar = toolbar
        self.mpl_layout.addWidget(toolbar)

    def _popup_menu(self, event):
        """Show the right-click context menu for the axes under the cursor.

        The menu has an 'x-scale' and a 'y-scale' submenu. Each holds one
        checkable action per scale type found in ``axes.scale[coord]``; the
        active type is checked and disabled. Selecting a type calls
        ``self._set_scale``. Below a separator come the 'AutoScale X/Y/All'
        toggles backed by ``self.autoScale``.

        Args:
            event: the mouse event that triggered the menu.
        """
        axes = self._find_calling_axes(event)   # axes under the right click
        if axes is None:
            return

        pos = self.parent.mapFromGlobal(QtGui.QCursor().pos())

        popup_menu = QMenu(self.parent)
        # Look the submenus up by coordinate instead of eval()'ing a variable name.
        scale_menus = {
            'x': popup_menu.addMenu('x-scale'),
            'y': popup_menu.addMenu('y-scale'),
        }

        for coord, menu in scale_menus.items():
            for scale_type, is_active in axes.scale[coord].items():
                action = QAction(scale_type, menu, checkable=True)
                action.setEnabled(not is_active)   # the active scale cannot be re-selected
                menu.addAction(action)
                action.setChecked(is_active)
                # Bind coord/scale_type now; `checked` is the signal's bool argument.
                action.triggered.connect(
                    lambda checked, coord=coord, scale_type=scale_type:
                        self._set_scale(coord, scale_type, axes, True))

        # Create menu for AutoScale options X Y All
        popup_menu.addSeparator()
        autoscale_options = ['AutoScale X', 'AutoScale Y', 'AutoScale All']
        for n, text in enumerate(autoscale_options):
            # Parent each action to the menu that actually shows it.
            action = QAction(text, popup_menu, checkable=True)
            if n < len(self.autoScale):
                action.setChecked(self.autoScale[n])
            else:   # 'AutoScale All' is checked only when every flag is on
                action.setChecked(all(self.autoScale))
            popup_menu.addAction(action)
            action.toggled.connect(
                lambda checked, n=n: self._setAutoScale(n, checked, axes))

        popup_menu.exec_(self.parent.mapToGlobal(pos))
    
    def _setAutoScale(self, choice, event, axes):
        """Store the new state of an auto-scale toggle from the context menu.

        Args:
            choice: index into ``self.autoScale``. The value
                ``len(self.autoScale)`` stands for 'AutoScale All' and sets
                every flag.
            event: the new checked state.
            axes: the axes whose limits are refreshed when a flag turns on.
        """
        flags = self.autoScale
        targets = range(len(flags)) if choice == len(flags) else (choice,)
        for idx in targets:
            flags[idx] = event

        if event:   # a flag was switched on, so recompute the limits now
            self.update_xylim(axes)
class BackProjection(QMainWindow, Ui_MainWindow):
    """Interactive back-projection imaging window.

    A synthetic frequency/azimuth signal is generated for the point targets
    entered in the form and tapered by the selected window. It is then
    reconstructed with ``backprojection.reconstruct``, and the normalised
    image is shown in dB on an embedded matplotlib canvas.
    """

    def __init__(self):
        # Zero-argument super() follows the MRO correctly; the explicit
        # super(self.__class__, self) form recurses forever in a subclass.
        super().__init__()

        self.setupUi(self)

        # No colour bar exists until the first image has been drawn.
        self.cbar = None

        # Connect to the input boxes, when the user presses enter the form updates
        for line_edit in (self.range_center, self.x_target, self.y_target,
                          self.rcs, self.x_span, self.y_span, self.nx_ny,
                          self.start_frequency, self.bandwidth,
                          self.az_start_end):
            line_edit.returnPressed.connect(self._update_canvas)
        # The dynamic range only changes the display, not the reconstruction.
        self.dynamic_range.returnPressed.connect(self._update_image_only)
        self.window_type.currentIndexChanged.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.fig = fig
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    @staticmethod
    def _window_coefficients(window_type, nf, na):
        """Return the ``(nf, na)`` taper for the window named *window_type*.

        'Hanning' and 'Hamming' use the square root of the outer product of
        the frequency and azimuth windows; 'Rectangular' is all ones.

        Raises:
            ValueError: if *window_type* is not one of the three names above.
        """
        if window_type == 'Hanning':
            return sqrt(outer(hanning(nf, True), hanning(na, True)))
        if window_type == 'Hamming':
            return sqrt(outer(hamming(nf, True), hamming(na, True)))
        if window_type == 'Rectangular':
            return ones([nf, na])
        raise ValueError('Unknown window type: {!r}'.format(window_type))

    def _update_canvas(self):
        """
        Recompute the back-projected image from the form and redraw it.

        Stores the image grid in ``self.xi``/``self.yi`` and the reconstructed
        image in ``self.bp_image``, then calls ``_update_image_only``.

        :raises ValueError: for an unknown window type; malformed form fields
            also raise from the float/int conversions.
        """
        # Get the parameters from the form
        range_center = float(self.range_center.text())
        x_target = self.x_target.text().split(',')
        y_target = self.y_target.text().split(',')
        rcs = self.rcs.text().split(',')
        # zip() stops at the shortest list, so unmatched entries are ignored.
        xt, yt, rt = [], [], []
        for x, y, amp in zip(x_target, y_target, rcs):
            xt.append(float(x))
            yt.append(float(y))
            rt.append(float(amp))

        x_span = float(self.x_span.text())
        y_span = float(self.y_span.text())

        nx_ny = self.nx_ny.text().split(',')
        nx = int(nx_ny[0])
        ny = int(nx_ny[1])

        start_frequency = float(self.start_frequency.text())
        bandwidth = float(self.bandwidth.text())

        az_start_end = self.az_start_end.text().split(',')
        az_start = float(az_start_end[0])
        az_end = float(az_start_end[1])

        # Set up the azimuth space
        image_diagonal = sqrt(x_span**2 + y_span**2)
        da = c / (2.0 * image_diagonal * start_frequency)
        # NOTE(review): da appears to be an angle in radians, while az_start and
        # az_end are degrees (they go through radians() below). na therefore
        # comes out ~57x larger than needed (oversampling); confirm whether
        # radians(az_end - az_start) was intended.
        na = int((az_end - az_start) / da)
        az = linspace(az_start, az_end, na)

        # Set up the frequency space
        df = c / (2.0 * image_diagonal)
        nf = int(bandwidth / df)
        frequency = linspace(start_frequency, start_frequency + bandwidth, nf)

        # Set the length of the FFT
        fft_length = next_fast_len(4 * nf)

        # Set up the aperture positions
        az_rad = radians(az)
        cos_az = cos(az_rad)
        sin_az = sin(az_rad)
        sensor_x = range_center * cos_az
        sensor_y = range_center * sin_az
        sensor_z = zeros_like(sensor_x)

        # Set up the image space
        self.xi = linspace(-0.5 * x_span, 0.5 * x_span, nx)
        self.yi = linspace(-0.5 * y_span, 0.5 * y_span, ny)
        x_image, y_image = meshgrid(self.xi, self.yi)
        z_image = zeros_like(x_image)

        # Calculate the signal (k space). Each target adds
        # rcs * exp(-j 4 pi f / c * r_target) at every (frequency, azimuth)
        # sample; the azimuth loop is replaced by an outer product.
        phase_rate = -1j * 4.0 * pi * frequency / c        # shape (nf,)
        signal = zeros([nf, na], dtype=complex)
        for x, y, amp in zip(xt, yt, rt):
            r_target = -(cos_az * x + sin_az * y)          # shape (na,)
            signal += amp * exp(outer(phase_rate, r_target))

        # Apply the selected window
        signal *= self._window_coefficients(self.window_type.currentText(), nf, na)

        # Reconstruct the image
        self.bp_image = backprojection.reconstruct(signal, sensor_x, sensor_y,
                                                   sensor_z, range_center,
                                                   x_image, y_image, z_image,
                                                   frequency, fft_length)

        # Update the image
        self._update_image_only()

    def _update_image_only(self):
        """Redraw ``self.bp_image`` in dB without recomputing it.

        The magnitude is normalised to its peak and displayed from
        ``-dynamic_range`` dB to 0 dB with a fresh colour bar.
        """
        dynamic_range = float(self.dynamic_range.text())

        # Remove the colour bar from the previous draw, if there is one
        if getattr(self, 'cbar', None) is not None:
            self.cbar.remove()
            self.cbar = None

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Display the results, normalised to the peak magnitude
        magnitude = abs(self.bp_image)
        bpi = magnitude / amax(magnitude)
        im = self.axes1.pcolor(self.xi,
                               self.yi,
                               20.0 * log10(bpi),
                               cmap='jet',
                               vmin=-dynamic_range,
                               vmax=0)
        self.cbar = self.fig.colorbar(im,
                                      ax=self.axes1,
                                      orientation='vertical')
        self.cbar.set_label("(dB)", size=10)

        # Set the plot title and labels
        self.axes1.set_title('Back Projection', size=14)
        self.axes1.set_xlabel('Range (m)', size=12)
        self.axes1.set_ylabel('Cross Range (m)', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Update the canvas
        self.my_canvas.draw()
Exemple #15
0
 def add_plot(self, plot, event_source):
     """Embed ``plot.fig`` in the central widget and start its animation.

     Args:
         plot: object exposing a matplotlib ``fig``; it is also handed to
             ``DataAnimation``.
         event_source: presumably the timer/event source driving the
             animation; passed straight to ``DataAnimation``.

     Returns:
         The created animation, which is also kept in ``self.animations``.
         ``DataAnimation`` is presumably a matplotlib ``Animation``; those
         stop once garbage-collected, so a reference must outlive this call.
     """
     canvas = FigureCanvas(plot.fig)
     # A widget can own only one layout: reuse the existing one if present,
     # otherwise the canvas would end up in a layout Qt refuses to install.
     layout = self.centralwidget.layout()
     if layout is None:
         layout = QtWidgets.QHBoxLayout(self.centralwidget)
     layout.addWidget(canvas)
     animation = DataAnimation(plot, event_source, blit=True)
     if not hasattr(self, 'animations'):
         self.animations = []
     self.animations.append(animation)   # keep the animation alive
     canvas.draw()
     return animation
class RainAttenuation(QMainWindow, Ui_MainWindow):
    """Window plotting the specific attenuation due to rain versus frequency.

    The frequency range (GHz), rain rate, elevation angle and polarization
    tilt angle are read from the form. The angles are entered in degrees and
    passed to ``rain.attenuation`` in radians.
    """

    # Number of frequency samples between the start and end frequency.
    n_frequency_points = 2000

    def __init__(self):
        # Zero-argument super() follows the MRO correctly; the explicit
        # super(self.__class__, self) form recurses forever in a subclass.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes, when the user presses enter the form updates
        for line_edit in (self.frequencyStartGHz, self.frequencyEndGHz,
                          self.rain_rate, self.elevation_angle,
                          self.polarization_tilt_angle):
            line_edit.returnPressed.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea, NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Samples ``n_frequency_points`` frequencies between the start and end
        values, evaluates ``rain.attenuation`` and redraws the curve.
        :return:
        """
        # Get the frequencies from the form
        frequency_start = float(self.frequencyStartGHz.text())
        frequency_end = float(self.frequencyEndGHz.text())

        # Set up the frequency array
        frequency = linspace(frequency_start, frequency_end, self.n_frequency_points)

        # Set up the input args (angles converted from degrees to radians)
        kwargs = {'frequency': frequency,
                  'rain_rate': float(self.rain_rate.text()),
                  'elevation_angle': radians(float(self.elevation_angle.text())),
                  'polarization_tilt_angle': radians(float(self.polarization_tilt_angle.text()))}

        # Calculate the rain attenuation
        gamma = rain.attenuation(**kwargs)

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Display the results
        self.axes1.plot(frequency, gamma)

        # Set the plot title and labels
        self.axes1.set_title('Rain Attenuation', size=14)
        self.axes1.set_xlabel('Frequency (GHz)', size=12)
        self.axes1.set_ylabel('Specific Attenuation (dB/km)', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
Exemple #17
0
class Preceptron(QWidget):
    def __init__(self):
        """Create the perceptron demo widget and build its whole UI."""
        super().__init__()
        # Private (name-mangled) builder that lays out every group box.
        self.__initUI()

    def __initUI(self):
        """Assemble the window: controls on the left, figure/results right.

        Each ``__set_*_box_UI`` builder stores its group box on ``self``; the
        boxes are then stacked into two columns with the given stretches.
        """
        self.__window_layout = QHBoxLayout()
        self.__status_layout = QVBoxLayout()
        self.__result_layout = QVBoxLayout()

        for build in (self.__set_file_box_UI, self.__set_dataInfo_box_UI,
                      self.__set_setting_box_UI, self.__set_control_box_UI,
                      self.__set_figure_box_UI, self.__set_result_box_UI):
            build()

        # Left column: file, data information, settings, control panel.
        for box, stretch in ((self.__file_box, 1), (self.__dataInfo_box, 2),
                             (self.__setting_box, 2), (self.__control_box, 2)):
            self.__status_layout.addWidget(box, stretch)

        # Right column: the figure on top, numeric results below.
        for box, stretch in ((self.__figure_box, 3), (self.__result_box, 1)):
            self.__result_layout.addWidget(box, stretch)

        window = self.__window_layout
        window.addLayout(self.__status_layout, 1)
        window.addLayout(self.__result_layout, 1)
        window.setContentsMargins(5, 5, 5, 5)
        self.setLayout(window)

    def __set_file_box_UI(self):
        """Build the 'File' group: a dataset picker and a 'Load File' button.

        The combo box lists ``support.find_all_dataset()``; the button
        triggers ``load_file``.
        """
        self.__file_box = QGroupBox('File')

        caption = QLabel("Select File :")
        caption.setAlignment(Qt.AlignCenter)

        combo = QComboBox()
        combo.addItems(support.find_all_dataset())
        combo.setStatusTip("Please select a file as dataset")
        self.file_cb = combo

        load_btn = QPushButton('Load File', self)
        load_btn.setStatusTip("Load file and update file information")
        load_btn.clicked.connect(self.load_file)
        self.file_search_btn = load_btn

        row = QHBoxLayout()
        for widget, stretch in ((caption, 1), (combo, 3), (load_btn, 1)):
            row.addWidget(widget, stretch)
        self.__file_box.setLayout(row)

    def __set_dataInfo_box_UI(self):
        """Build the 'Data Information' group of caption/value rows.

        Every value label starts as " -- " and is filled in by
        ``update_file_info`` after a dataset is loaded.
        """
        def make_row(caption):
            # Centred caption (stretch 1) next to a centred value (stretch 2).
            caption_label = QLabel(caption)
            caption_label.setAlignment(Qt.AlignCenter)
            value_label = QLabel(" -- ")
            value_label.setAlignment(Qt.AlignCenter)
            row = QHBoxLayout()
            row.addWidget(caption_label, 1)
            row.addWidget(value_label, 2)
            return row, value_label

        self.__dataInfo_box = QGroupBox('Data Information')
        column = QVBoxLayout()

        name_row, self.name_text = make_row("DataSet Name :")
        instances_row, self.number_of_instances_text = make_row("Number of Instances :")
        feature_row, self.number_of_feature_text = make_row("Number of Features :")
        class_row, self.number_of_class_text = make_row("Number of Classes :")

        for row in (name_row, instances_row, feature_row, class_row):
            column.addLayout(row)
        self.__dataInfo_box.setLayout(column)

    def __set_setting_box_UI(self):
        """Build the 'Setting' group.

        Rows:
        - the randomly drawn initial weight with a 'Reset weight' button
          (disabled until a file is loaded);
        - learning rate and maximum training times;
        - percentage of the data held out for testing.

        The line edits start empty; ``load_file`` fills in default values.
        """
        self.__setting_box = QGroupBox('Setting')
        vbox = QVBoxLayout()
        weight_box = QHBoxLayout()
        training_params_box = QHBoxLayout()
        split_box = QHBoxLayout()

        # User-visible spelling fixed ("Initalize" -> "Initialize").
        weight_label = QLabel("Initialize the weight with Value :")
        weight_label.setAlignment(Qt.AlignCenter)
        self.initial_weight = QLabel("--")
        self.initial_weight.setAlignment(Qt.AlignCenter)
        self.reset_weight_btn = QPushButton("Reset weight")
        self.reset_weight_btn.setEnabled(False)
        self.reset_weight_btn.setStatusTip(
            'Reset initial weight with value from -1 to 1')
        self.reset_weight_btn.clicked.connect(self.update_initial_weight)
        weight_box.addWidget(weight_label, 3)
        weight_box.addWidget(self.initial_weight, 3)
        weight_box.addWidget(self.reset_weight_btn, 1)

        learning_rate_label = QLabel("Learning Rate :")
        learning_rate_label.setAlignment(Qt.AlignCenter)
        self.learning_rate_text = QLineEdit()
        self.learning_rate_text.setStatusTip('The learning rate of training')
        training_params_box.addWidget(learning_rate_label, 2)
        training_params_box.addWidget(self.learning_rate_text, 2)
        training_params_box.addStretch(1)

        training_times_label = QLabel("Maximum Training Times :")
        training_times_label.setAlignment(Qt.AlignCenter)
        self.training_times_text = QLineEdit()
        self.training_times_text.setStatusTip(
            'Using training times as converge condition')
        training_params_box.addWidget(training_times_label, 2)
        training_params_box.addWidget(self.training_times_text, 2)

        # User-visible spelling fixed ("Propotion" -> "Proportion"); the
        # attribute name is kept because other methods use it.
        propotion_of_test_label = QLabel("Proportion of Testing Data (%) :")
        propotion_of_test_label.setAlignment(Qt.AlignCenter)
        self.propotion_of_test_text = QLineEdit()
        self.propotion_of_test_text.setStatusTip('testing_data / all_data = ?')
        split_box.addWidget(propotion_of_test_label, 3)
        split_box.addWidget(self.propotion_of_test_text, 3)
        split_box.addStretch(1)

        vbox.addLayout(weight_box)
        vbox.addLayout(training_params_box)
        vbox.addLayout(split_box)
        self.__setting_box.setLayout(vbox)

    def __set_control_box_UI(self):
        """Build the 'Control Panel' group.

        Top row: Confirm, Load training data, Load testing data.
        Bottom row: Redo Training, Start Training, Start Testing.

        Every button starts disabled; the slots enable them step by step as
        the workflow advances.
        """
        def make_button(text, tip, slot):
            # Disabled button with a status tip, wired to *slot*.
            button = QPushButton(text)
            button.setStatusTip(tip)
            button.setEnabled(False)
            button.clicked.connect(slot)
            return button

        self.__control_box = QGroupBox('Control Panel')

        # Status tip spelling fixed ("propotion" -> "proportion", double space).
        self.confirm_btn = make_button(
            "Confirm",
            'Confirm to use : initial weight, learning rate, maximum training times, proportion of testing data',
            self.confirm_data)
        self.load_training_data_btn = make_button(
            "Load training data",
            'Randomly Split dataset into training part and testing part. Load training data and draw it.',
            self.load_training)
        self.load_testing_data_btn = make_button(
            "Load testing data",
            'Randomly Split dataset into training part and testing part. Load testing data and draw it.',
            self.load_testing)

        self.redo_btn = make_button(
            "Redo Training",
            'Reset all parameter and redo again for better result.',
            self.redo)
        self.start_training_btn = make_button(
            "Start Training",
            'Start training data with parameter above.',
            self.start_training)
        self.start_testing_btn = make_button(
            "Start Testing",
            'Testing remaining dataset with weight which get from training.',
            self.start_testing)

        control_box = QHBoxLayout()
        for button in (self.confirm_btn, self.load_training_data_btn,
                       self.load_testing_data_btn):
            control_box.addWidget(button, 2)

        start_box = QHBoxLayout()
        for button in (self.redo_btn, self.start_training_btn,
                       self.start_testing_btn):
            start_box.addWidget(button, 2)

        vbox = QVBoxLayout()
        vbox.addLayout(control_box)
        vbox.addLayout(start_box)
        self.__control_box.setLayout(vbox)

    def __set_figure_box_UI(self):
        """Build the 'Figure' group holding a matplotlib canvas with one axes."""
        figure = Figure()
        canvas = FigureCanvas(figure)
        self.__figure_box = QGroupBox('Figure')
        self.figure = figure
        self.canvas = canvas
        self.ax = figure.add_subplot(111)

        layout = QHBoxLayout()
        layout.addWidget(canvas)
        self.__figure_box.setLayout(layout)

    def __set_result_box_UI(self):
        """Build the 'Result' group of caption/value rows.

        Shows the weight, the training times and the recognition rates of
        training and testing. Every value starts as " -- ".
        """
        def make_row(caption):
            # Centred caption (stretch 3) next to a centred value (stretch 2).
            caption_label = QLabel(caption)
            caption_label.setAlignment(Qt.AlignCenter)
            value_label = QLabel(" -- ")
            value_label.setAlignment(Qt.AlignCenter)
            row = QHBoxLayout()
            row.addWidget(caption_label, 3)
            row.addWidget(value_label, 2)
            return row, value_label

        self.__result_box = QGroupBox('Result')
        column = QVBoxLayout()

        weight_row, self.weight_result_text = make_row("Weight : ")
        times_row, self.training_times_result_text = make_row("Training times : ")
        train_row, self.training_result_text = make_row(
            "Recognition rate of training (%) : ")
        test_row, self.testing_result_text = make_row(
            "Recognition rate of testing (%) : ")

        for row in (weight_row, times_row, train_row, test_row):
            column.addLayout(row)
        self.__result_box.setLayout(column)

    @pyqtSlot()
    def load_file(self):
        """Load the selected dataset, show its statistics and reset the GUI.

        Warns when the file has more than two features or more than two
        classes; in the latter case the labels are passed through
        ``support.handle_as_noise``. Afterwards it:
        - draws a fresh random initial weight;
        - plots all points;
        - enables only 'Reset weight' and 'Confirm';
        - fills the settings fields with their defaults (0.8, 100, 33);
        - clears the result labels.
        """
        self.file_name = str(self.file_cb.currentText())
        self.feature, self.label = support.load_file_info(self.file_name)
        self.individual_label = support.get_individual_label(self.label)
        if len(self.feature) > 2:
            QMessageBox.about(
                self, "Warning",
                "Exist more than 2 features, automatically ignore them.")
        if len(self.individual_label) > 2:
            # The check is "> 2", so the message must say "more than 2".
            QMessageBox.about(
                self, "Warning",
                "Exist more than 2 classes, treat them as class 2.")
            # NOTE(review): individual_label is not recomputed after this
            # relabelling, so the class count shown and the labels passed to
            # training are the original ones; confirm that is intended.
            self.label = support.handle_as_noise(self.label,
                                                 self.individual_label)
        self.update_file_info()
        self.update_initial_weight()
        self.draw_points("all")
        # set GUI
        self.reset_weight_btn.setEnabled(True)
        self.confirm_btn.setEnabled(True)
        self.load_training_data_btn.setEnabled(False)
        self.load_testing_data_btn.setEnabled(False)
        self.start_training_btn.setEnabled(False)
        self.start_testing_btn.setEnabled(False)
        self.redo_btn.setEnabled(False)
        self.learning_rate_text.setText("0.8")
        self.training_times_text.setText("100")
        self.propotion_of_test_text.setText("33")
        self.confirm_btn.setText("Confirm")
        self.reset_result_text()

    @pyqtSlot()
    def update_initial_weight(self):
        """Draw a new random initial weight vector and display it.

        The vector is a leading -1 (presumably the threshold input) followed
        by one value from ``support.get_random_weight()`` per feature
        (``self.dimension``), each rounded to 3 decimals.
        """
        random_part = [round(float(support.get_random_weight()), 3)
                       for _ in range(self.dimension)]
        self.weight = [-1] + random_part
        self.initial_weight.setText(str(self.weight))

    @pyqtSlot()
    def confirm_data(self):
        """Read the training settings, split the dataset and unlock loading.

        Empty fields fall back to the defaults: learning rate 0.8, 100
        training times and 33 % testing data. The testing percentage may be
        fractional (e.g. "12.5"). Entries that cannot be parsed are reported
        in a warning box and leave the GUI state unchanged, instead of
        raising inside the slot.
        """
        for field, default in ((self.learning_rate_text, "0.8"),
                               (self.training_times_text, "100"),
                               (self.propotion_of_test_text, "33")):
            if field.text() == "":
                field.setText(default)

        try:
            learning_rate = float(self.learning_rate_text.text())
            training_times = int(self.training_times_text.text())
            pro_of_test = float(self.propotion_of_test_text.text()) / 100
        except ValueError:
            QMessageBox.about(
                self, "Warning",
                "Learning rate and proportion of testing data must be numbers, "
                "and maximum training times must be an integer.")
            return

        self.learning_rate = learning_rate
        self.training_times = training_times
        self.pro_of_test = pro_of_test
        self.split_train_test_data()
        # set GUI
        self.load_training_data_btn.setEnabled(True)
        self.confirm_btn.setEnabled(False)
        self.reset_weight_btn.setEnabled(False)
        self.reset_result_text()

    @pyqtSlot()
    def load_training(self):
        """Plot the training split and make training the next available step."""
        self.draw_points("training")
        # Loading is one-shot: disable it and unlock 'Start Training'.
        for button, enabled in ((self.load_training_data_btn, False),
                                (self.start_training_btn, True)):
            button.setEnabled(enabled)

    @pyqtSlot()
    def load_testing(self):
        """Plot the testing split and make testing the next available step.

        The 'Load testing data' button stays disabled afterwards, matching
        how ``load_training`` treats its own button. The former body disabled
        it and then re-enabled it a few lines later.
        """
        self.draw_points("testing")
        # set GUI
        self.load_testing_data_btn.setEnabled(False)
        self.start_training_btn.setEnabled(False)
        self.redo_btn.setEnabled(False)
        self.start_testing_btn.setEnabled(True)

    @pyqtSlot()
    def start_training(self):
        """Train the perceptron and start the ``Drawer`` that animates it.

        ``support.do_training`` returns ``(weight_result,
        training_times_result, proc_weight)``; ``proc_weight`` is handed to
        ``Drawer``. ``update_training_result`` runs when the drawer emits
        ``finish``.
        """
        self.weight_result, self.training_times_result, proc_weight = support.do_training(
            self.feature_train, self.label_train, self.individual_label,
            self.weight, self.learning_rate, self.training_times)
        self.drawer = Drawer(self.canvas, self.ax, self.feature, proc_weight)
        # Connect before start(): Drawer is presumably a QThread, and if it
        # finished before the connection existed, ``finish`` would be lost.
        self.drawer.finish.connect(self.update_training_result)
        self.drawer.start()
        self.start_training_btn.setEnabled(False)

    @pyqtSlot()
    def update_training_result(self):
        """Show the training outcome and unlock testing / redo.

        Displays the trained weight (``self.weight_result``, rounded to 3
        decimals), the training recognition rate computed with that same
        weight, and the number of training times.
        """
        self.recog_train = support.get_recognition(self.feature_train,
                                                   self.label_train,
                                                   self.weight_result,
                                                   self.individual_label)
        # Show the trained weight, not the initial self.weight, so the
        # displayed value matches the one used for recognition.
        weight_res = [round(w, 3) for w in self.weight_result]
        # set GUI
        self.weight_result_text.setText(str(weight_res))
        self.training_result_text.setText(str(self.recog_train))
        self.training_times_result_text.setText(str(
            self.training_times_result))
        self.load_testing_data_btn.setEnabled(True)
        self.redo_btn.setEnabled(True)

    @pyqtSlot()
    def start_testing(self):
        """Draw the learned decision line and report the testing accuracy.

        The line uses the trained weights (``self.weight_result``), the same
        ones used for both recognition rates. It spans the range of the first
        feature padded by 0.5 on each side. Afterwards only 'Reset weight'
        and the Confirm button (relabelled 'Redo Again') remain enabled.
        """
        w0, w1, w2 = self.weight_result[0], self.weight_result[1], self.weight_result[2]
        min_x1 = min(self.feature[0]) - 0.5
        max_x1 = max(self.feature[0]) + 0.5
        self.ax.plot([min_x1, max_x1],
                     [support.find_x2(w0, w1, w2, min_x1),
                      support.find_x2(w0, w1, w2, max_x1)],
                     color='orange',
                     linewidth=2)
        self.canvas.draw()
        # get recognition
        self.recog_test = support.get_recognition(self.feature_test,
                                                  self.label_test,
                                                  self.weight_result,
                                                  self.individual_label)
        self.testing_result_text.setText(str(self.recog_test))
        # set GUI
        self.reset_weight_btn.setEnabled(True)
        self.confirm_btn.setEnabled(True)
        self.confirm_btn.setText("Redo Again")
        self.load_training_data_btn.setEnabled(False)
        self.load_testing_data_btn.setEnabled(False)
        self.start_training_btn.setEnabled(False)
        self.start_testing_btn.setEnabled(False)
        self.redo_btn.setEnabled(False)

    @pyqtSlot()
    def redo(self):
        """Reset buttons, result labels and input fields for another run.

        Only the confirm and reset-weight buttons remain enabled. The result
        labels are set to " -- ", and the learning-rate, training-times and
        test-proportion inputs are cleared.
        """
        button_states = (
            (self.confirm_btn, True),
            (self.load_training_data_btn, False),
            (self.start_training_btn, False),
            (self.load_testing_data_btn, False),
            (self.start_testing_btn, False),
            (self.redo_btn, False),
            (self.reset_weight_btn, True),
        )
        for button, enabled in button_states:
            button.setEnabled(enabled)

        for result_label in (self.training_result_text,
                             self.training_times_result_text,
                             self.weight_result_text):
            result_label.setText(" -- ")

        for input_field in (self.learning_rate_text,
                            self.training_times_text,
                            self.propotion_of_test_text):
            input_field.clear()

    def update_file_info(self):
        """Store the feature dimension and display summary info of the file.

        Shows the file name, the number of features (len(self.feature)), the
        number of distinct labels and the number of instances.
        """
        self.dimension = len(self.feature)
        summary = (
            (self.name_text, self.file_name),
            (self.number_of_feature_text, str(self.dimension)),
            (self.number_of_class_text, str(len(self.individual_label))),
            (self.number_of_instances_text, str(len(self.label))),
        )
        for widget, text in summary:
            widget.setText(text)

    def split_train_test_data(self):
        """Split (feature, label) into train/test parts using pro_of_test."""
        parts = support.split_train_test_data(self.feature, self.label,
                                              self.pro_of_test)
        (self.feature_train, self.label_train,
         self.feature_test, self.label_test) = parts

    def draw_points(self, mode):
        """Clear the axes and scatter-plot one data subset, split by class.

        ``mode`` selects the subset: "all" uses (feature, label), "training"
        uses (feature_train, label_train) and "testing" uses
        (feature_test, label_test). Axis limits are the min/max of feature
        rows 0 and 1, padded by 0.5. The two groups returned by
        support.get_seperate_points are drawn in blue and green. For any
        other mode the axes are only cleared and the canvas is not redrawn.
        """
        self.ax.clear()
        # Attribute names are looked up lazily so that e.g. "all" works
        # before the train/test split attributes exist.
        subset_table = {
            "all": ("feature", "label", "All data"),
            "training": ("feature_train", "label_train", "Training data"),
            "testing": ("feature_test", "label_test", "Testing data"),
        }
        selection = subset_table.get(mode)
        if selection is None:
            return

        feature_attr, label_attr, title = selection
        points = getattr(self, feature_attr)
        labels = getattr(self, label_attr)

        first_x1, first_x2, second_x1, second_x2 = support.get_seperate_points(
            points, labels, self.individual_label)
        self.ax.set_title(title)
        self.ax.set_xlim(min(points[0]) - 0.5, max(points[0]) + 0.5)
        self.ax.set_ylim(min(points[1]) - 0.5, max(points[1]) + 0.5)
        self.ax.scatter(first_x1, first_x2, c='blue', s=10)
        self.ax.scatter(second_x1, second_x2, c='green', s=10)
        self.canvas.draw()

    def reset_result_text(self):
        """Put the " -- " placeholder into every result label."""
        for result_label in (self.weight_result_text,
                             self.training_result_text,
                             self.testing_result_text,
                             self.training_times_result_text):
            result_label.setText(" -- ")
Exemple #18
0
class Interface(QMainWindow):
    """Live Plotter: plot comma-separated data typed into a side editor.

    Each editor line is one row. Its comma-separated cells are converted
    according to ``data_types`` for the chart type selected in the
    "Chart Type" menu and drawn on a single matplotlib Axes.
    """

    current_graph = 'Scatter Plot'
    plot_type = [
        'Scatter Plot', 'Line Plot', 'Histogram', 'Pie Chart', 'Bar Chart',
        'Double Bar Chart'
    ]
    # Column layout shown above the editor for each chart type.
    header_labels = {
        'Scatter Plot': 'X Data, Y Data',
        'Line Plot': 'X Data, Y Data',
        'Histogram': 'X Data, Y Data',
        'Pie Chart': 'Percent',
        'Bar Chart': 'X Label, Height',
        'Double Bar Chart': 'X Label, Height1, Height2'
    }
    # Type of each column; every entry is a key of ``converters``.
    data_types = {
        'Scatter Plot': ['float', 'float'],
        'Line Plot': ['float', 'float'],
        'Histogram': ['float', 'float'],
        'Pie Chart': ['float'],
        'Bar Chart': ['str', 'float'],
        'Double Bar Chart': ['str', 'float', 'float']
    }
    # Axes method that draws each chart type. 'double_bar' is not an Axes
    # method; update_graph dispatches it to Interface._double_bar_chart.
    function_calls = {
        'Scatter Plot': 'scatter',
        'Line Plot': 'plot',
        'Histogram': 'hist',
        'Pie Chart': 'pie',
        'Bar Chart': 'bar',
        'Double Bar Chart': 'double_bar'
    }
    # Extra keyword arguments for the plotting call (None: no extras).
    function_kwargs = {
        'Scatter Plot': None,
        'Line Plot': None,
        'Histogram': None,
        'Pie Chart': {'autopct': '%1.1f%%'},
        'Bar Chart': None,
        'Double Bar Chart': None
    }
    # Parsers for the ``data_types`` entries. Used instead of eval() so that
    # text typed into the editor is never executed as Python code.
    converters = {'float': float, 'str': str}

    def __init__(self):
        super().__init__()
        self.setWindowTitle('Live Plotter')
        self.size_policy = QSizePolicy.Expanding
        # NOTE(review): ``self.font`` shadows QWidget.font() and
        # ``self.actions`` (set in menu) shadows QWidget.actions(); the names
        # are kept because other code may read these attributes.
        self.font = QFont()
        self.font.setPointSize(12)
        # Deliberately not called ``geometry``: overriding QWidget.geometry()
        # (which must return the window's QRect) with a UI builder meant any
        # call to window.geometry() rebuilt the UI and returned None.
        self._build_geometry()
        self.menu()
        # showMaximized() already makes the window visible.
        self.showMaximized()

    def menu(self):
        """Create the "Chart Type" menu with one checkable action per type."""
        self.menuType = self.menuBar().addMenu('&Chart Type')
        self.actions = {}
        for label in self.plot_type:
            self.actions[label] = self.action_definer(label)
            self.menuType.addAction(self.actions[label])

    def action_definer(self, label):
        """Return a checkable QAction for chart type ``label``.

        Only the first entry of ``plot_type`` starts out checked.
        """
        action = QAction('&{}'.format(label), self, checkable=True)
        action.setFont(self.font)
        # ``label`` is a parameter of this call, so every lambda keeps its own.
        action.triggered.connect(lambda: self.action_changed(label))
        if label == self.plot_type[0]:
            action.setChecked(True)
        return action

    def action_changed(self, label):
        """Switch to chart type ``label``: header text, check marks, plot."""
        self.text_format.setText(self.header_labels[label])
        for name, action in self.actions.items():
            action.setChecked(name == label)
        self.current_graph = label
        self.update_graph()

    def _build_geometry(self):
        """Build the central plot area and the data-entry dock widget."""
        # Central widget: navigation toolbar above the matplotlib canvas.
        self.plot_window = QWidget()
        layout = QVBoxLayout()
        self.figure = Figure()
        self._canvas = FigureCanvas(self.figure)
        self.toolbar = CustomToolbar(self._canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self._canvas)
        self.plot_window.setLayout(layout)
        self.setCentralWidget(self.plot_window)
        self.axis = self._canvas.figure.subplots()
        self.figure.tight_layout()

        # Dock with the data editor (placed in the LEFT dock area despite
        # the variable names).
        right_dock = QDockWidget('Data')
        right_widget = QWidget()
        self.text_format = QLabel(self.header_labels['Scatter Plot'])
        self.text_format.setFont(self.font)

        self.data_editor = Custom_Text_Editor(self)
        self.data_editor.setSizePolicy(self.size_policy, self.size_policy)
        self.data_editor.setFont(self.font)
        # Key presses in the editor are seen by eventFilter below.
        self.data_editor.installEventFilter(self)

        # When checked, the axes are cleared before each redraw.
        self.hold_plot_on = QCheckBox('Reset Graph')
        self.hold_plot_on.setChecked(True)
        self.hold_plot_on.setFont(self.font)

        self.process = QPushButton('Update')
        self.process.setFont(self.font)
        self.process.clicked.connect(self.update_graph)

        layout = QVBoxLayout()
        layout.addWidget(self.text_format)
        layout.addWidget(self.data_editor)
        layout.addWidget(self.hold_plot_on)
        layout.addWidget(self.process)
        right_widget.setLayout(layout)
        right_dock.setWidget(right_widget)
        self.addDockWidget(Qt.LeftDockWidgetArea, right_dock)

    def eventFilter(self, obj, event):
        """Redraw the graph when Return/Enter is pressed in the data editor."""
        # Parenthesised key test: previously ``a or b and c`` applied the
        # focus check to Key_Enter only.
        if (obj is self.data_editor and event.type() == QEvent.KeyPress
                and event.key() in (Qt.Key_Return, Qt.Key_Enter)
                and self.data_editor.hasFocus()):
            self.update_graph()
        return super().eventFilter(obj, event)

    def update_graph(self):
        """Re-plot the editor contents with the current chart type.

        Plotting errors (e.g. mismatched column lengths) are printed rather
        than raised so the window stays usable.
        """
        if self.hold_plot_on.isChecked():
            self.axis.clear()
        data = self.data_import(self.data_editor.toPlainText())
        columns = [data[i] for i in range(len(data))]
        kwargs = self.function_kwargs[self.current_graph] or {}
        method_name = self.function_calls[self.current_graph]
        if method_name == 'double_bar':
            plotter = self._double_bar_chart
        else:
            plotter = getattr(self.axis, method_name)
        try:
            plotter(*columns, **kwargs)
        except Exception as e:  # GUI boundary: report and keep running
            print(e)
        # Lay out before drawing so the new layout is visible immediately.
        self.figure.tight_layout()
        self._canvas.draw()

    def _double_bar_chart(self, labels, heights1, heights2, width=0.4):
        """Draw two bar series side by side, one group per x label."""
        positions = list(range(len(labels)))
        self.axis.bar([p - width / 2 for p in positions], heights1, width)
        self.axis.bar([p + width / 2 for p in positions], heights2, width)
        self.axis.set_xticks(positions)
        self.axis.set_xticklabels(labels)

    def data_import(self, text):
        """Parse comma-separated editor text into one list per column.

        Returns a dict mapping column index (0..n-1, n from ``data_types``
        of the current chart) to the converted values. Empty cells are
        skipped. Cells that cannot be converted, and cells beyond the
        expected number of columns, are reported with print and skipped.
        """
        d_type = self.data_types[self.current_graph]
        data = {i: [] for i in range(len(d_type))}
        for row in text.split(sep='\n'):
            for j, value in enumerate(row.split(sep=',')):
                if value == '':
                    continue
                try:
                    data[j].append(self.converters[d_type[j]](value))
                except (ValueError, IndexError) as e:
                    print(e)
        return data
Exemple #19
0
class Crossover(QMainWindow, Ui_MainWindow):
    """Window plotting crossover range against jammer ERP."""

    def __init__(self):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.setupUi(self)

        # Redraw when the user presses Enter in any input box or selects a
        # different jammer type.
        for field in (self.radar_transmit_power, self.radar_antenna_gain,
                      self.radar_antenna_sidelobe, self.radar_bandwidth,
                      self.radar_losses, self.target_rcs,
                      self.jammer_bandwidth, self.jammer_erp,
                      self.jammer_range):
            field.returnPressed.connect(self._update_canvas)
        self.jammer_type.currentIndexChanged.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.fig = fig
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    @staticmethod
    def _db_to_linear(text):
        """Convert a dB value given as text to a linear ratio."""
        return 10**(float(text) / 10.0)

    def _read_inputs(self, jammer_type):
        """Parse the form for ``jammer_type``.

        Returns (jammer ERP vector in dBW, keyword arguments for the
        countermeasures crossover function). Raises ValueError when a field
        is not a number or the ERP field is not "start,stop".
        """
        erp_bounds = self.jammer_erp.text().split(',')
        if len(erp_bounds) < 2:
            raise ValueError('jammer ERP must be given as "start,stop"')
        jammer_erp_vector = linspace(float(erp_bounds[0]),
                                     float(erp_bounds[1]), 1000)

        # Arguments shared by both jammer types (dB inputs -> linear).
        kwargs = {
            'peak_power': float(self.radar_transmit_power.text()),
            'antenna_gain': self._db_to_linear(self.radar_antenna_gain.text()),
            'target_rcs': self._db_to_linear(self.target_rcs.text()),
            'jammer_bandwidth': float(self.jammer_bandwidth.text()),
            'effective_radiated_power': 10**(jammer_erp_vector / 10.0),
            'radar_bandwidth': float(self.radar_bandwidth.text()),
            'losses': self._db_to_linear(self.radar_losses.text())
        }
        if jammer_type == 'Escort':
            # Jammer range is entered in km.
            kwargs['jammer_range'] = float(self.jammer_range.text()) * 1e3
            kwargs['antenna_gain_jammer_direction'] = self._db_to_linear(
                self.radar_antenna_sidelobe.text())
        return jammer_erp_vector, kwargs

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        An unknown jammer type or unparsable input leaves the current plot
        unchanged (the parse error is printed). Raising inside a Qt slot
        aborts the application under PyQt5 >= 5.5.
        :return:
        """
        jammer_type = self.jammer_type.currentText()
        if jammer_type == 'Self Screening':
            crossover_function = countermeasures.crossover_range_selfscreen
        elif jammer_type == 'Escort':
            crossover_function = countermeasures.crossover_range_escort
        else:
            return

        try:
            jammer_erp_vector, kwargs = self._read_inputs(jammer_type)
        except ValueError as error:
            print('Invalid input: {}'.format(error))
            return

        crossover_range = crossover_function(**kwargs)

        # Clear the axes and display the results
        self.axes1.clear()
        self.axes1.plot(jammer_erp_vector, crossover_range, '')

        # Set the plot title and labels
        self.axes1.set_title('Crossover Range', size=14)
        self.axes1.set_xlabel('Jammer ERP (dBW)', size=12)
        self.axes1.set_ylabel('Crossover Range (m)', size=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Update the canvas
        self.my_canvas.draw()
class LinearWire(QMainWindow, Ui_MainWindow):
    """Window showing the far-field pattern and parameters of a wire antenna."""

    def __init__(self):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.setupUi(self)

        # Redraw when the user presses Enter in an input box or selects a
        # different antenna type.
        for field in (self.frequency, self.current, self.length):
            field.returnPressed.connect(self._update_canvas)
        self.antenna_type.currentIndexChanged.connect(self._update_canvas)

        # Set up a polar figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(111, projection='polar')
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea, NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Antenna type index 0 selects infinitesimal_dipole, 1 selects
        small_dipole, anything else selects finite_length_dipole. Unparsable
        input leaves the form and plot unchanged (the error is printed);
        raising inside a Qt slot aborts the application under PyQt5 >= 5.5.
        :return:
        """
        try:
            frequency = float(self.frequency.text())
            length = float(self.length.text())
            current = float(self.current.text())
        except ValueError as error:
            print('Invalid input: {}'.format(error))
            return

        antenna_type = self.antenna_type.currentIndex()

        # Observation range and angles; starting at eps avoids theta == 0.
        r = 1.0e9
        theta = linspace(finfo(float).eps, 2.0 * pi, 1000)

        if antenna_type == 0:
            total_power_radiated = infinitesimal_dipole.radiated_power(frequency, length, current)
            radiation_resistance = infinitesimal_dipole.radiation_resistance(frequency, length)
            beamwidth = infinitesimal_dipole.beamwidth()
            directivity = infinitesimal_dipole.directivity()
            maximum_effective_aperture = infinitesimal_dipole.maximum_effective_aperture(frequency)
            _, et, _, _, _, _ = infinitesimal_dipole.far_field(frequency, length, current, r, theta)
        elif antenna_type == 1:
            total_power_radiated = small_dipole.radiated_power(frequency, length, current)
            radiation_resistance = small_dipole.radiation_resistance(frequency, length)
            beamwidth = small_dipole.beamwidth()
            directivity = small_dipole.directivity()
            maximum_effective_aperture = small_dipole.maximum_effective_aperture(frequency)
            _, et, _, _, _, _ = small_dipole.far_field(frequency, length, current, r, theta)
        else:
            total_power_radiated = finite_length_dipole.radiated_power(frequency, length, current)
            radiation_resistance = finite_length_dipole.radiation_resistance(frequency, length)
            beamwidth = finite_length_dipole.beamwidth(frequency, length)
            directivity = finite_length_dipole.directivity(frequency, length, current)
            maximum_effective_aperture = finite_length_dipole.maximum_effective_aperture(frequency, length, current)
            _, et, _, _, _, _ = finite_length_dipole.far_field(frequency, length, current, r, theta)

        # Populate the form with the computed values
        for field, value in ((self.total_radiated_power, total_power_radiated),
                             (self.radiation_resistance, radiation_resistance),
                             (self.beamwidth, beamwidth),
                             (self.directivity, directivity),
                             (self.maximum_effective_aperture, maximum_effective_aperture)):
            field.setText('{:.2f}'.format(value))

        # Clear the axes and display the pattern magnitude
        self.axes1.clear()
        self.axes1.plot(theta, abs(et), '')

        # Set the plot title and labels
        self.axes1.set_title('Linear Wire Antenna Pattern', size=14)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
Exemple #21
0
class Ducting(QMainWindow, Ui_MainWindow):
    """Window plotting the critical ducting angle vs refractivity gradient."""

    def __init__(self):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.setupUi(self)

        # Connect to the input box, when the user presses enter the form updates
        self.duct_thickness.returnPressed.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea,
                        NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        The duct-thickness field holds comma-separated values (m); one curve
        is drawn per value, and blank entries are ignored. Unparsable input
        leaves the plot unchanged (the error is printed); raising inside a Qt
        slot aborts the application under PyQt5 >= 5.5.
        :return:
        """
        # Parse before clearing so bad input does not wipe the plot.
        try:
            dts = [float(x) for x in self.duct_thickness.text().split(',')
                   if x.strip()]
        except ValueError as error:
            print('Invalid input: {}'.format(error))
            return

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Set up the refractivity gradient
        refractivity_gradient = linspace(-500., -150., 1000)

        # Styles are cycled, so any number of thicknesses can be plotted.
        line_style = ['-', '--', '-.', ':']

        # Calculate and display the critical angle (mrad) for each thickness
        for index, dt in enumerate(dts):
            critical_angle = ducting.critical_angle(refractivity_gradient, dt)
            self.axes1.plot(refractivity_gradient,
                            critical_angle * 1e3,
                            line_style[index % len(line_style)],
                            label="Thickness {} (m)".format(dt))

        # Set the plot title and labels
        self.axes1.set_title('Ducting over a Spherical Earth', size=14)
        self.axes1.set_xlabel('Refractivity Gradient (N/km)', size=12)
        self.axes1.set_ylabel('Critical Angle (mrad)', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Set the legend (only when there is something to label)
        if dts:
            self.axes1.legend(loc='best', prop={'size': 10})

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
Exemple #22
0
class ACGui(QWidget):
    def __init__(self):
        """Initialise the QWidget base, then build and show the interface."""
        super().__init__()
        self.initUI()

    def initUI(self):
        """Build the widgets, layouts and signal connections, then show.

        Also resets the data containers that the load, plot, fit and
        simulation handlers fill in later. The window is three side-by-side
        columns: load controls, the relaxation-time plot, and the
        fit/simulation controls.
        """

        self.setWindowTitle('AC processing')
        self.setWindowIcon(QIcon('double_well_potential_R6p_icon.ico'))
        # Bold font shared by the section headlines below.
        self.headline_font = QFont()
        self.headline_font.setBold(True)

        #self.taskbar_btn = QWinTaskbarButton()
        #self.taskbar_btn.setWindow(self.container)

        # Top-level layout; the three column layouts are added at the end.
        self.main_layout = QHBoxLayout()

        # Data containers
        # data_file_origin / data_T / data_tau are set by load_t_tau_data.
        self.data_file_origin = None
        self.data_T = None
        self.data_tau = None

        # Line handles of the plotted data; plot_t_tau_on_axes replaces the
        # used/not-used lines on every call.
        self.data_used_pointer = None
        self.data_not_used_pointer = None
        self.plot_of_fit_pointer = None

        self.used_T = None
        self.not_used_T = None

        self.used_tau = None
        self.not_used_tau = None

        self.fitted_parameters = None
        self.used_indices = None

        self.simulation_items = []
        """ Adding load controls """
        self.load_layout = QVBoxLayout()
        # Stretch first so the buttons sit at the bottom of the column.
        self.load_layout.addStretch(1)

        self.load_btn = QPushButton('Load')
        self.load_btn.clicked.connect(self.load_t_tau_data)
        self.load_layout.addWidget(self.load_btn)

        self.reset_axes_btn = QPushButton('Reset axes')
        self.reset_axes_btn.clicked.connect(self.see_all_on_axes)
        self.load_layout.addWidget(self.reset_axes_btn)
        """ Adding plotting controls """
        self.plot_layout = QVBoxLayout()

        # Adding the canvas for plotting relaxation times
        self.tau_t_fig = Figure(figsize=(5, 3))
        self.tau_t_canvas = FigureCanvas(self.tau_t_fig)
        self.tau_t_tools = NavigationToolbar(self.tau_t_canvas, self)
        self.tau_t_ax = self.tau_t_fig.add_subplot(111)

        # Plotting in the axes
        # The x axis holds 1/T (plot_t_tau_on_axes plots 1 / T), hence the
        # K^-1 unit in the label.
        self.tau_t_ax.set_xlabel('Temperature [$K^{-1}$]')
        self.tau_t_ax.set_ylabel(r'$\ln{\tau}$ [$\ln{s}$]')

        self.plot_layout.addWidget(self.tau_t_canvas)
        self.plot_layout.addWidget(self.tau_t_tools)

        # Adding fit controls
        self.fit_layout = QVBoxLayout()

        self.cb_headline = QLabel('Fit types to consider')
        self.cb_headline.setFont(self.headline_font)
        self.fit_layout.addWidget(self.cb_headline)

        # Each fit-type checkbox triggers read_fit_type_cbs when toggled.
        self.orbach_cb = QCheckBox('Orbach')
        self.orbach_cb.stateChanged.connect(self.read_fit_type_cbs)
        self.fit_layout.addWidget(self.orbach_cb)

        self.qt_cb = QCheckBox('QT')
        self.qt_cb.stateChanged.connect(self.read_fit_type_cbs)
        self.fit_layout.addWidget(self.qt_cb)

        self.raman_cb = QCheckBox('Raman')
        self.raman_cb.stateChanged.connect(self.read_fit_type_cbs)
        self.fit_layout.addWidget(self.raman_cb)

        # Adding temperature controls
        self.temp_headline = QLabel('Temperature interval')
        self.temp_headline.setFont(self.headline_font)
        self.fit_layout.addWidget(self.temp_headline)

        # Renders as "( <min> , <max> )"; indices 1 and 3 are the spin boxes.
        self.temp_horizontal_layout = QHBoxLayout()
        self.temp_line = [
            QLabel('('),
            QDoubleSpinBox(),
            QLabel(','),
            QDoubleSpinBox(),
            QLabel(')')
        ]

        # The lower box is capped by the upper box's value and vice versa.
        # NOTE(review): temp_line[3].value() is presumably still the spin box
        # default (0.0) here, so the lower box starts with range (0, 0) until
        # set_new_temp_ranges updates it - confirm.
        self.temp_line[1].setRange(0, self.temp_line[3].value())
        self.temp_line[1].setSingleStep(0.1)
        self.temp_line[3].setRange(self.temp_line[1].value(), 1000)
        self.temp_line[3].setSingleStep(0.1)

        self.temp_line[1].editingFinished.connect(self.set_new_temp_ranges)
        self.temp_line[3].editingFinished.connect(self.set_new_temp_ranges)
        for w in self.temp_line:
            self.temp_horizontal_layout.addWidget(w)

        self.fit_layout.addLayout(self.temp_horizontal_layout)

        # Adding a button to run a fit
        self.run_fit_btn = QPushButton('Run fit!')
        self.run_fit_btn.clicked.connect(self.make_the_fit)
        self.fit_layout.addWidget(self.run_fit_btn)

        # Adding a list to hold information about simulations
        self.simulations_headline = QLabel('Simulations')
        self.simulations_headline.setFont(self.headline_font)
        self.fit_layout.addWidget(self.simulations_headline)

        # Double-clicking an entry edits the currently selected simulation.
        self.list_of_simulations = QListWidget()
        self.list_of_simulations.doubleClicked.connect(self.edit_a_simulation)
        self.fit_layout.addWidget(self.list_of_simulations)

        # Adding buttons to control simulation list
        self.sim_btn_layout = QHBoxLayout()

        self.delete_sim_btn = QPushButton('Delete')
        self.delete_sim_btn.clicked.connect(self.delete_sim)
        self.sim_btn_layout.addWidget(self.delete_sim_btn)

        self.edit_sim_btn = QPushButton('Edit')
        self.edit_sim_btn.clicked.connect(self.edit_a_simulation)
        self.sim_btn_layout.addWidget(self.edit_sim_btn)

        self.new_sim_btn = QPushButton('New')
        self.new_sim_btn.clicked.connect(self.add_new_simulation)
        self.sim_btn_layout.addWidget(self.new_sim_btn)

        self.fit_layout.addLayout(self.sim_btn_layout)

        # Finalizing layout and showing the GUI
        self.main_layout.addLayout(self.load_layout)
        self.main_layout.addLayout(self.plot_layout)
        self.main_layout.addLayout(self.fit_layout)
        self.setLayout(self.main_layout)
        self.show()

    def see_all_on_axes(self):
        """Rescale tau_t_ax so that every plotted data point is visible.

        Limits span the min/max of the finite x and y values of all lines on
        the axes, padded by 10 % of the range on each side. If no line holds
        finite data, the axes are left unchanged.
        """
        x_parts = []
        y_parts = []
        for line in self.tau_t_ax.lines:
            # np.asarray: line data may be a plain list (no .min()).
            x = np.asarray(line.get_xdata(), dtype=float)
            y = np.asarray(line.get_ydata(), dtype=float)
            # Drop NaN/inf (e.g. log of non-positive tau), which would make
            # the limits invalid.
            x = x[np.isfinite(x)]
            y = y[np.isfinite(y)]
            # Every non-empty line counts, including single-point ones.
            if x.size > 0 and y.size > 0:
                x_parts.append(x)
                y_parts.append(y)

        # Previously an all-empty axes made the search loop index past the
        # end of the line list.
        if not x_parts:
            return

        all_x = np.concatenate(x_parts)
        all_y = np.concatenate(y_parts)
        x_min, x_max = all_x.min(), all_x.max()
        y_min, y_max = all_y.min(), all_y.max()
        x_pad = 0.1 * (x_max - x_min)
        y_pad = 0.1 * (y_max - y_min)

        self.tau_t_ax.set_xlim(x_min - x_pad, x_max + x_pad)
        self.tau_t_ax.set_ylim(y_min - y_pad, y_max + y_pad)
        self.tau_t_canvas.draw()

    def add_new_simulation(self):
        """Let the user define a new simulation and plot it.

        Opens a SimulationDialog with default parameters. When the dialog is
        accepted with at least one plot type chosen, the model is drawn with
        addPartialModel. A list entry is then added: its text summarises the
        parameters, and its item data under role 32 stores plot types,
        parameters, temperature range and the line handle.
        """
        default_parameters = {
            'tQT': 0.01,
            'Cr': 0.00,
            'n': 0.00,
            't0': 0.00,
            'Ueff': 0.00
        }
        dialog = SimulationDialog(fitted_parameters=self.fitted_parameters,
                                  plot_type_list=[],
                                  plot_parameters=default_parameters,
                                  min_and_max_temps=[0, 0])

        if not dialog.exec_():
            return  # dialog cancelled

        chosen_types = dialog.plot_type_list
        if len(chosen_types) < 1:
            return  # nothing selected to plot

        params = dialog.plot_parameters
        temps = dialog.min_and_max_temps

        model_name = ''.join(chosen_types)
        summary = '{}, ({},{}), tQT: {}, Cr: {}, n: {}, t0: {}, Ueff: {}'.format(
            chosen_types, temps[0], temps[1], params['tQT'], params['Cr'],
            params['n'], params['t0'], params['Ueff'])

        entry = QListWidgetItem()

        drawn_line = addPartialModel(self.tau_t_fig,
                                     temps[0],
                                     temps[1],
                                     self.prepare_sim_dict_for_plotting(params),
                                     plotType=model_name)

        record = {
            'plot_type': chosen_types,
            'p_fit': params,
            'T_vals': temps,
            'line': drawn_line
        }

        self.list_of_simulations.addItem(entry)
        entry.setText(summary)
        # Role 32 is where edit_a_simulation and delete_sim read the record.
        entry.setData(32, record)

        self.tau_t_canvas.draw()

    def edit_a_simulation(self):
        """Re-open the selected simulation's dialog and re-plot it.

        Does nothing when no simulation is selected, the dialog is
        cancelled, or no plot type is chosen. Otherwise the old line is
        removed, the new model is drawn, and the item's text and role-32
        record are replaced.
        """
        selected = self.list_of_simulations.selectedItems()
        if not selected:
            return
        sim_item = selected[0]

        # Record stored under item data role 32 when the simulation was made.
        old_data = sim_item.data(32)

        sim_dialog = SimulationDialog(
            fitted_parameters=self.fitted_parameters,
            plot_type_list=old_data['plot_type'],
            plot_parameters=old_data['p_fit'],
            min_and_max_temps=old_data['T_vals'])

        if not sim_dialog.exec_():
            return

        new_plot_type = sim_dialog.plot_type_list
        if len(new_plot_type) < 1:
            return

        new_p_fit = sim_dialog.plot_parameters
        new_T_vals = sim_dialog.min_and_max_temps

        plot_to_make = ''.join(new_plot_type)
        new_item_text = '{}, ({},{}), tQT: {}, Cr: {}, n: {}, t0: {}, Ueff: {}'.format(
            new_plot_type, new_T_vals[0], new_T_vals[1],
            new_p_fit['tQT'], new_p_fit['Cr'], new_p_fit['n'],
            new_p_fit['t0'], new_p_fit['Ueff'])

        # Artist.remove() detaches the line from its axes. Mutating
        # ``ax.lines`` directly is deprecated since matplotlib 3.5 and
        # unsupported since 3.7.
        old_data['line'].remove()

        new_line = addPartialModel(
            self.tau_t_fig,
            new_T_vals[0],
            new_T_vals[1],
            self.prepare_sim_dict_for_plotting(new_p_fit),
            plotType=plot_to_make)

        list_item_data = {
            'plot_type': new_plot_type,
            'p_fit': new_p_fit,
            'T_vals': new_T_vals,
            'line': new_line
        }

        self.tau_t_canvas.draw()

        sim_item.setData(32, list_item_data)
        sim_item.setText(new_item_text)

    def delete_sim(self):
        """Remove the selected simulation from the plot and from the list.

        Does nothing when no simulation is selected.
        """
        selected = self.list_of_simulations.selectedItems()
        if not selected:
            return
        sim_item = selected[0]

        # Artist.remove() replaces ``ax.lines.remove(line)``, which is
        # deprecated since matplotlib 3.5 and unsupported since 3.7.
        sim_item.data(32)['line'].remove()
        self.tau_t_canvas.draw()

        # takeItem hands the item back to Python; dropping the returned
        # reference lets it be deleted.
        self.list_of_simulations.takeItem(self.list_of_simulations.row(sim_item))

    def plot_t_tau_on_axes(self):
        """Redraw the data points as ln(tau) versus 1/T.

        Points inside the selected temperature window are drawn as blue
        circles, the others as red circles. Previously drawn markers are
        removed first. The new line handles are kept in
        ``data_used_pointer`` / ``data_not_used_pointer``.
        """
        # Artist.remove() works on every Matplotlib version, whereas
        # ``ax.lines.remove(...)`` no longer exists on Matplotlib >= 3.7.
        for old_line in (self.data_used_pointer, self.data_not_used_pointer):
            if old_line is not None:
                old_line.remove()

        self.data_used_pointer, = self.tau_t_ax.plot(
            1 / self.used_T, np.log(self.used_tau), 'bo')
        self.data_not_used_pointer, = self.tau_t_ax.plot(
            1 / self.not_used_T, np.log(self.not_used_tau), 'ro')

        self.tau_t_canvas.draw()

    def load_t_tau_data(self):
        """Let the user pick a (T, tau) text file, then store and plot it.

        The file is read with ``np.loadtxt`` after skipping one header line.
        Column 0 is taken as temperature and column 1 as tau. On success the
        path is stored in ``data_file_origin``, the data in ``data_T`` /
        ``data_tau``, and the plot is refreshed. A cancelled dialog or an
        unreadable file leaves the previously loaded data untouched.
        """
        starting_directory = os.getcwd()
        filename = QFileDialog.getOpenFileName(self, 'Open file',
                                               starting_directory)[0]
        if not filename:
            # Dialog cancelled: keep the current file and data.
            return

        try:
            # ndmin=2 keeps a single-row file two-dimensional so the column
            # slicing below still works.
            D = np.loadtxt(filename, skiprows=1, ndmin=2)
            data_T = D[:, 0]
            data_tau = D[:, 1]
        except ValueError:
            print('Encountered value error... check your input file!')
        except IndexError:
            print('Input file needs two columns (T and tau)!')
        except OSError:
            print('File was not found! Empty file name?')
        else:
            # Only remember the path once the data were read successfully.
            # The fit uses data_file_origin, so it must match data_T/data_tau.
            self.data_file_origin = filename
            self.data_T = data_T
            self.data_tau = data_tau
            self.read_indices_for_used_temps()
            self.plot_t_tau_on_axes()

    def prepare_sim_dict_for_plotting(self, p_fit_gui_struct):
        """Convert GUI simulation parameters to the plotting-script layout.

        :param p_fit_gui_struct: mapping of parameter name to value. It must
            contain a ``'Ueff'`` key; otherwise ``ValueError`` is raised.
        :return: dict with ``'params'`` (values in the mapping's order, with
            Ueff multiplied by the Boltzmann constant), ``'quantities'``
            (the matching names) and ``'sigmas'`` (one zero per parameter).
            The input mapping is not modified.
        """
        quantities = list(p_fit_gui_struct.keys())
        params = list(p_fit_gui_struct.values())

        # NOTE(review): presumably converts Ueff from kelvin to joule for the
        # plotting routine; only the multiplication is visible here.
        ueff_pos = quantities.index('Ueff')
        params[ueff_pos] = params[ueff_pos] * scicon.Boltzmann

        # One (zero) uncertainty per parameter, instead of a hard-coded five,
        # so params/quantities/sigmas always have matching lengths.
        sigmas = [0] * len(params)

        return {
            'params': params,
            'quantities': quantities,
            'sigmas': sigmas
        }

    def read_fit_type_cbs(self):
        """Build the fit-type code from the relaxation checkboxes.

        Every checked box contributes its code in this fixed order:
        ``qt_cb`` -> 'QT', ``raman_cb`` -> 'R', ``orbach_cb`` -> 'O'.
        An empty string means that no box is checked.
        """
        box_codes = (
            (self.qt_cb, 'QT'),
            (self.raman_cb, 'R'),
            (self.orbach_cb, 'O'),
        )
        return ''.join(code for box, code in box_codes if box.isChecked())

    def read_indices_for_used_temps(self):
        """Split the loaded (T, tau) data by the selected temperature window.

        A point is used when
        ``temp_line[1].value() <= T <= temp_line[3].value()``.
        The row indices of the used points go to ``used_indices``, the used
        points to ``used_T`` / ``used_tau``, and all remaining points to
        ``not_used_T`` / ``not_used_tau``. If no data has been loaded yet,
        a message is printed and these attributes are left untouched.
        """
        min_t = self.temp_line[1].value()
        max_t = self.temp_line[3].value()

        try:
            # Boolean mask rather than list.index(t): index() returns only
            # the first match, so repeated temperatures were mapped to the
            # same row and the duplicates ended up in the "not used" set.
            in_window = (self.data_T >= min_t) & (self.data_T <= max_t)
            used_indices = np.flatnonzero(in_window).tolist()

            used_T = self.data_T[used_indices]
            used_tau = self.data_tau[used_indices]
            not_used_T = np.delete(self.data_T, used_indices)
            not_used_tau = np.delete(self.data_tau, used_indices)

        except (AttributeError, TypeError):
            # data_T is None (TypeError) or not defined (AttributeError).
            print('No data have been selected yet!')
        else:
            self.used_indices = used_indices
            self.used_T = used_T
            self.used_tau = used_tau
            self.not_used_T = not_used_T
            self.not_used_tau = not_used_tau

    def set_new_temp_ranges(self):
        """Stop the min/max temperature boxes from crossing each other.

        The lower box (``temp_line[1]``) is limited to [0, current maximum].
        The upper box (``temp_line[3]``) is limited to [current minimum, 1000].
        The used/unused split is then recomputed, and the plot is redrawn
        if data is loaded.
        """
        low_box = self.temp_line[1]
        high_box = self.temp_line[3]

        # Read both limits before changing either range.
        ceiling_for_low = high_box.value()
        floor_for_high = low_box.value()
        low_box.setRange(0, ceiling_for_low)
        high_box.setRange(floor_for_high, 1000)

        self.read_indices_for_used_temps()
        if self.data_T is None:
            return
        self.plot_t_tau_on_axes()

    def make_the_fit(self):

        try:
            Tmin = self.temp_line[1].value()
            Tmax = self.temp_line[3].value()
            perform_this_fit = self.read_fit_type_cbs()
            assert Tmin != Tmax
            assert perform_this_fit != ''

        except AssertionError:
            msg = QMessageBox()
            msg.setIcon(QMessageBox.Warning)
            msg.setWindowTitle('Fit aborted')
            msg.setText('Check your temperature and fit settings')
            msg.setDetailedText("""Possible errors:
 - min and max temperatures are the same
 - no fit options have been selected""")
            msg.exec_()

        else:
            fig3, p_fit = fitRelaxation(self.data_file_origin, (Tmin, Tmax),
                                        fitType=perform_this_fit)
            self.fitted_parameters = p_fit
Exemple #23
0
class DataViz(QWidget):
    """Viewer for area-detector frames stored in an XML file.

    Each ``Frame`` element must contain two consecutive ``Data`` children.
    Their text is a ``;``-separated list of values, and their ``x``/``y``
    attributes give the image width and height. The two images are shown
    side by side ("U" and "D"). The frame, colour map and colour limits are
    chosen with the widgets.
    """

    def __init__(self):

        super().__init__()

        self.initUI()

    def initUI(self):
        """Create the widgets, the Matplotlib canvas and their connections."""

        # Opening a QWidget with a box layout
        LO = QVBoxLayout()

        # Adding the layout for file-related stuff
        fileLO = QHBoxLayout()

        openFilebtn = QPushButton('&Open')
        openFilebtn.clicked.connect(self.openFile)
        fileLO.addWidget(openFilebtn)

        self.currentFilelbl = QLabel()
        fileLO.addWidget(self.currentFilelbl)

        fileLO.addStretch()
        LO.addLayout(fileLO)

        # Adding the layout for frame-related stuff
        frameLO = QHBoxLayout()

        self.showFrameInfobtn = QPushButton('&Frame info')
        self.showFrameInfobtn.clicked.connect(self.displayFrameInfo)
        frameLO.addWidget(self.showFrameInfobtn)

        self.frameNumberBox = QSpinBox()
        self.frameNumberBox.setMinimum(1)
        self.frameNumberBox.valueChanged.connect(self.plotImage)
        frameLO.addWidget(self.frameNumberBox)

        self.noOfFrameslbl = QLabel()
        frameLO.addWidget(self.noOfFrameslbl)

        frameLO.addStretch()
        LO.addLayout(frameLO)

        # Adding the layout for image-related stuff

        sliderLO = QHBoxLayout()
        self.minvalueSlider = QSlider(Qt.Horizontal)
        self.minvalueSlider.sliderMoved.connect(self.plotImage)
        sliderLO.addWidget(self.minvalueSlider)

        self.maxvalueSlider = QSlider(Qt.Horizontal)
        self.maxvalueSlider.sliderMoved.connect(self.plotImage)
        sliderLO.addWidget(self.maxvalueSlider)

        self.areaDetectorImages = Figure()
        self.dynamic_canvas = FigureCanvas(self.areaDetectorImages)
        self.toolbar = NavigationToolbar(self.dynamic_canvas, self)

        self.axU = self.areaDetectorImages.add_subplot(121)
        self.axD = self.areaDetectorImages.add_subplot(122)

        # Containers for the objects that are returned by the imshow command
        self.imageUobject = None
        self.imageDobject = None

        LO.addWidget(self.dynamic_canvas)

        # Adding controls for the plotting
        controlsLO = QHBoxLayout()
        self.colormapBox = QComboBox()
        self.colormapBox.addItems(
            ['hot', 'viridis', 'inferno', 'plasma', 'Greys'])
        self.colormapBox.currentIndexChanged.connect(self.plotImage)

        controlsLO.addWidget(self.toolbar)
        controlsLO.addWidget(self.colormapBox)

        controlsLO.addStretch()
        LO.addLayout(controlsLO)

        LO.addLayout(sliderLO)
        LO.addStretch()

        # File state; all None until a file has been opened successfully.
        self.path = None
        self.tree = None
        self.root = None
        self.xmlInfo = None
        self.dataUArray = None
        self.dataDArray = None

        # Setting layout and opening the widget
        self.setLayout(LO)
        self.show()

    def displayFrameInfo(self):
        """Show the attributes and tag texts of the current frame.

        The ``Data`` elements are skipped. Does nothing while no file is
        loaded.
        """
        if self.root is None:
            return

        frameNo = self.frameNumberBox.value()
        frame = self.root[self.xmlInfo['framestartIndex'] + frameNo - 1]

        # Work on a copy: updating frame.attrib directly would write the
        # children's attributes and texts back into the parsed XML tree on
        # every call.
        info = dict(frame.attrib)
        for elem in frame.iter():
            if elem.tag != 'Data':
                info.update(elem.attrib)
                info[elem.tag] = elem.text

        msg = QMessageBox()
        msg.setStandardButtons(QMessageBox.Ok)
        msg.setWindowTitle('Frame info for #{}'.format(frameNo))
        msg.setText(''.join('{}: {}\n'.format(key, value)
                            for key, value in info.items()))
        msg.exec_()

    def openFile(self):
        """Ask for an XML file, load it and display its first frame.

        A cancelled dialog changes nothing. A file that cannot be read or
        has an unexpected layout produces an error dialog, and the
        previously loaded file (if any) is kept.
        """
        path = QFileDialog.getOpenFileName()[0]
        if not path:
            # Dialog cancelled: keep the current file.
            return

        try:
            # Parse into locals first so a failure cannot leave self.root and
            # self.xmlInfo describing different files.
            tree = ET.parse(path)
            root = tree.getroot()
            xmlInfo = self.readXMLinfo(root)

        except (OSError, ET.ParseError, IndexError, KeyError,
                ValueError) as e:

            msg = QMessageBox()
            msg.setText('Something happened! Try again')
            msg.setStandardButtons(QMessageBox.Ok)
            if isinstance(e, ET.ParseError):
                details = 'Selected file could not be parsed as an XML-file'
            elif isinstance(e, FileNotFoundError):
                details = 'The selected file could not be found'
            elif isinstance(e, OSError):
                details = 'The selected file could not be read'
            else:
                # readXMLinfo failed: no Frame/Data elements or no x/y sizes.
                details = ('The file does not contain Frame/Data elements '
                           'with x and y attributes')
            msg.setDetailedText(details)
            msg.exec_()

        else:
            self.path = path
            self.tree = tree
            self.root = root
            self.xmlInfo = xmlInfo

            nFrames = len(root) - xmlInfo['framestartIndex']
            self.frameNumberBox.setMaximum(nFrames)
            self.maxvalueSlider.setValue(self.maxvalueSlider.maximum())

            self.currentFilelbl.setText(path)
            self.noOfFrameslbl.setText('/{}'.format(nFrames))

        finally:
            self.frameNumberBox.setValue(1)
            self.plotImage()

    def plotImage(self):
        """Redraw both detector images of the currently selected frame.

        While no file is loaded, the axes are only cleared.
        """

        self.axU.clear()
        self.axD.clear()

        try:
            frameNo = self.frameNumberBox.value()
            frame = self.root[self.xmlInfo['framestartIndex'] + frameNo - 1]
            dataStart = self.xmlInfo['datastartIndex']
            shape = (self.xmlInfo['y'], self.xmlInfo['x'])

            self.dataUArray = np.fromstring(frame[dataStart].text,
                                            sep=';').reshape(shape)
            self.dataDArray = np.fromstring(frame[dataStart + 1].text,
                                            sep=';').reshape(shape)

            # QSlider limits are ints. Passing the float maximum can raise
            # TypeError (e.g. with Python >= 3.10), which the handler below
            # would swallow, leaving the canvas empty.
            peak = max(np.amax(self.dataDArray), np.amax(self.dataUArray))
            self.maxvalueSlider.setMaximum(int(np.ceil(peak)))
            self.minvalueSlider.setMaximum(self.maxvalueSlider.value())

            self.imageUobject = self.axU.imshow(
                self.dataUArray,
                extent=[0, self.xmlInfo['x'], 0, self.xmlInfo['y']],
                aspect=self.xmlInfo['x'] / self.xmlInfo['y'],
                vmin=self.minvalueSlider.value(),
                vmax=self.maxvalueSlider.value(),
                cmap=self.colormapBox.currentText())

            self.imageDobject = self.axD.imshow(
                self.dataDArray,
                extent=[0, self.xmlInfo['x'], 0, self.xmlInfo['y']],
                aspect=self.xmlInfo['x'] / self.xmlInfo['y'],
                vmin=self.minvalueSlider.value(),
                vmax=self.maxvalueSlider.value(),
                cmap=self.colormapBox.currentText())

            self.dynamic_canvas.draw()

        except TypeError:
            # root/xmlInfo are still None: no file has been loaded yet.
            pass

    def readXMLinfo(self, root):
        """Locate the first Frame/Data elements and read the image size.

        :param root: root element of the parsed XML file.
        :return: dict with ``framestartIndex`` (index of the first
            ``Frame`` child of ``root``), ``datastartIndex`` (index of the
            first ``Data`` child of that frame) and the integer ``x``/``y``
            attributes of that ``Data`` element.
        :raises IndexError: if there is no ``Frame`` or ``Data`` element.
        :raises KeyError: if the ``Data`` element lacks ``x`` or ``y``.
        """

        xmlInfo = {}
        xmlInfo['framestartIndex'] = 0
        while root[xmlInfo['framestartIndex']].tag != 'Frame':
            xmlInfo['framestartIndex'] += 1

        xmlInfo['datastartIndex'] = 0
        while root[xmlInfo['framestartIndex']][
                xmlInfo['datastartIndex']].tag != 'Data':
            xmlInfo['datastartIndex'] += 1

        xmlInfo['x'] = int(root[xmlInfo['framestartIndex']][
            xmlInfo['datastartIndex']].attrib['x'])
        xmlInfo['y'] = int(root[xmlInfo['framestartIndex']][
            xmlInfo['datastartIndex']].attrib['y'])

        return xmlInfo

    def get_minvalueSlider(self):
        """Return the current position of the lower colour-limit slider."""

        return self.minvalueSlider.value()

    def get_maxvalueSlider(self):
        """Return the current position of the upper colour-limit slider."""

        return self.maxvalueSlider.value()

    def keyPressEvent(self, event):
        """Step back (O) or forward (P) one frame; pass other keys on."""

        key = event.key()
        if key == QtCore.Qt.Key_O:
            self.frameNumberBox.setValue(self.frameNumberBox.value() - 1)
            event.accept()
        elif key == QtCore.Qt.Key_P:
            self.frameNumberBox.setValue(self.frameNumberBox.value() + 1)
            event.accept()
        else:
            # Let the base class handle (and propagate) unhandled keys.
            super().keyPressEvent(event)
class AlphaBetaGamma(QMainWindow, Ui_MainWindow):
    """Interactive demo of an alpha-beta-gamma tracking filter.

    The form supplies the time span ("start,end,step"), the initial
    position/velocity/acceleration, the noise variance and the three filter
    gains. The quantity selected in ``plot_type`` is drawn on an embedded
    Matplotlib canvas. The plot is redrawn when an input is confirmed with
    Enter or when the plot type changes.
    """

    def __init__(self):

        # Plain super(): super(self.__class__, self) recurses forever as soon
        # as this class is subclassed.
        super().__init__()

        self.setupUi(self)

        # Connect to the input boxes, when the user presses enter the form updates
        self.time.returnPressed.connect(self._update_canvas)
        self.initial_position.returnPressed.connect(self._update_canvas)
        self.initial_velocity.returnPressed.connect(self._update_canvas)
        self.initial_acceleration.returnPressed.connect(self._update_canvas)
        self.noise.returnPressed.connect(self._update_canvas)
        self.alpha.returnPressed.connect(self._update_canvas)
        self.beta.returnPressed.connect(self._update_canvas)
        self.gamma.returnPressed.connect(self._update_canvas)
        self.plot_type.currentIndexChanged.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.fig = fig
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea, NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    @staticmethod
    def _run_filter(measurements, dt, alpha, beta, gamma):
        """Run the alpha-beta-gamma filter over ``measurements``.

        The filter starts from a zero state. For every measurement, the
        state is first predicted one step ``dt`` ahead assuming constant
        acceleration. It is then corrected with the residual weighted by
        ``alpha``, ``beta / dt`` and ``2 * gamma / dt**2``.

        :return: lists of filtered position, velocity and acceleration, and
                 the residuals, with one entry per measurement.
        """
        xk_1 = 0.0
        vk_1 = 0.0
        ak_1 = 0.0

        x_filt = []
        v_filt = []
        a_filt = []
        r_filt = []

        for zk in measurements:
            # Predict the next state
            xk = xk_1 + vk_1 * dt + 0.5 * ak_1 * dt ** 2
            vk = vk_1 + ak_1 * dt
            ak = ak_1

            # Calculate the residual
            rk = zk - xk

            # Correct the predicted state
            xk += alpha * rk
            vk += beta / dt * rk
            ak += 2.0 * gamma / dt ** 2 * rk

            # Set the current state as previous
            xk_1, vk_1, ak_1 = xk, vk, ak

            x_filt.append(xk)
            v_filt.append(vk)
            a_filt.append(ak)
            r_filt.append(rk)

        return x_filt, v_filt, a_filt, r_filt

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Malformed or incomplete entries (non-numbers, fewer than three time
        fields, a zero step) leave the current plot unchanged.
        :return:
        """
        # Get the parameters from the form. An exception escaping a Qt slot
        # may abort the application (PyQt5 >= 5.5 calls qFatal() on
        # unhandled exceptions), so invalid input is ignored until corrected.
        try:
            time_fields = self.time.text().split(',')
            start = float(time_fields[0])
            end = float(time_fields[1])
            step = float(time_fields[2])
            number_of_updates = round((end - start) / step) + 1
            t, dt = linspace(start, end, number_of_updates, retstep=True)

            initial_position = float(self.initial_position.text())
            initial_velocity = float(self.initial_velocity.text())
            initial_acceleration = float(self.initial_acceleration.text())
            noise_variance = float(self.noise.text())
            alpha = float(self.alpha.text())
            beta = float(self.beta.text())
            gamma = float(self.gamma.text())
        except (ValueError, IndexError, ZeroDivisionError, OverflowError):
            return

        # True position and velocity
        v_true = initial_velocity + initial_acceleration * t
        x_true = initial_position + initial_velocity * t + 0.5 * initial_acceleration * t ** 2

        # Measurements (add noise)
        z = x_true + sqrt(noise_variance) * (random.rand(number_of_updates) - 0.5)

        # Filter all measurements
        x_filt, v_filt, a_filt, r_filt = self._run_filter(z, dt, alpha, beta, gamma)

        # Clear the axes for the updated plot
        self.axes1.clear()

        # Get the selected plot from the form
        plot_type = self.plot_type.currentText()

        # Display the results
        if plot_type == 'Position':
            self.axes1.plot(t, x_true, '', label='True')
            self.axes1.plot(t, z, ':', label='Measurement')
            self.axes1.plot(t, x_filt, '--', label='Filtered')
            self.axes1.set_ylabel('Position (m)', size=12)
            self.axes1.legend(loc='best', prop={'size': 10})
        elif plot_type == 'Velocity':
            self.axes1.plot(t, v_true, '', label='True')
            self.axes1.plot(t, v_filt, '--', label='Filtered')
            self.axes1.set_ylabel('Velocity (m/s)', size=12)
            self.axes1.legend(loc='best', prop={'size': 10})
        elif plot_type == 'Acceleration':
            self.axes1.plot(t, initial_acceleration * ones_like(t), '', label='True')
            self.axes1.plot(t, a_filt, '--', label='Filtered')
            self.axes1.set_ylabel('Acceleration (m/s/s)', size=12)
            self.axes1.legend(loc='best', prop={'size': 10})
        elif plot_type == 'Residual':
            self.axes1.plot(t, r_filt, '')
            self.axes1.set_ylabel('Residual (m)', size=12)

        # Set the plot title and labels
        self.axes1.set_title('Alpha-Beta-Gamma Filter', size=14)
        self.axes1.set_xlabel('Time (s)', size=12)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Turn on the grid
        self.axes1.grid(linestyle=':', linewidth=0.5)

        # Update the canvas
        self.my_canvas.draw()
class BeamsizeMeasurement(QtWidgets.QMainWindow):
    """Main window for estimating beam size from a SOFT Green's function.

    A Green's function in 'rij' format is loaded from file. An image is
    built as the sum of its radial slices, each weighted by a user-defined
    radial profile that is cut off at the beam radius chosen with a slider.
    The image is shown in a separate ``PlotWindow``, optionally with an
    intensity contour and a semi-transparent overlay picture. Both the image
    and the radial profile can be saved to disk.
    """

    def __init__(self, argv):
        QtWidgets.QMainWindow.__init__(self)

        self.ui = BeamsizeMeasurement_design.Ui_BeamsizeMeasurement()
        self.ui.setupUi(self)

        # At most one optional argument: the Green's function file to open.
        filename = None
        if len(argv) > 1:
            raise Exception("Too many input arguments given to 'beamsize'.")
        elif len(argv) == 1:
            filename = argv[0]

        self.plotWindow = PlotWindow()

        self.GF = None
        self.beamHandle = None
        self.radius = None

        self.imageContoursHandle = None
        self.imageHandle = None
        self.overlay = None
        self.overlayHandle = None

        self.toggleEnabled(False)
        self.bindEvents()

        # loadFile() sets up the image once a valid Green's function is
        # available. Calling setupImage() unconditionally here crashed when
        # no file (or an invalid one) was given, since it needs GF and radius.
        if filename is not None and os.path.isfile(filename):
            self.loadFile(filename)

    def bindEvents(self):
        """Connect the widgets' signals to their handlers."""
        self.ui.btnBrowse.clicked.connect(self.openFile)
        self.ui.btnBrowseOverlay.clicked.connect(self.openOverlay)
        self.ui.btnSaveImage.clicked.connect(self.saveImage)
        self.ui.btnSaveBoth.clicked.connect(self.saveBoth)
        self.ui.btnSaveProfile.clicked.connect(self.saveProfile)

        self.ui.sliderBeamsize.valueChanged.connect(self.sliderBeamsizeChanged)
        self.ui.sliderIntensity.valueChanged.connect(
            self.sliderIntensityChanged)
        self.ui.tbRadialProfile.textChanged.connect(self.radialProfileChanged)

        self.ui.gbRadialProfile.toggled.connect(self.fullProfileToggled)

        self.ui.sliderIntensity.valueChanged.connect(self.intensityChanged)
        self.ui.sliderOverlay.valueChanged.connect(self.updateOverlay)

        self.ui.cbContour.toggled.connect(self.contourToggled)

        self.ui.actionSave.triggered.connect(self.saveImage)
        self.ui.actionExit.triggered.connect(self.exit)

    def closeEvent(self, event):
        self.exit()

    def contourToggled(self):
        """Enable the intensity controls only while contours are shown."""
        enabled = self.ui.cbContour.isChecked()

        self.ui.sliderIntensity.setEnabled(enabled)
        self.ui.lblIntensity_desc.setEnabled(enabled)
        self.ui.lblIntensity.setEnabled(enabled)
        self.ui.lblIntensity0.setEnabled(enabled)
        self.ui.lblIntensity20.setEnabled(enabled)
        self.ui.lblIntensity40.setEnabled(enabled)
        self.ui.lblIntensity60.setEnabled(enabled)
        self.ui.lblIntensity80.setEnabled(enabled)
        self.ui.lblIntensity100.setEnabled(enabled)

        self.updateImage()

    def exit(self):
        """Close the plot window and this window."""
        self.plotWindow.close()
        self.close()

    def toggleEnabled(self, enabled):
        """Enable or disable every control that needs a loaded file."""
        self.ui.lblBeamRadius.setEnabled(enabled)

        self.ui.sliderBeamsize.setEnabled(enabled)
        self.ui.lblBeamsize.setEnabled(enabled)
        self.ui.lblBeamsize_desc.setEnabled(enabled)
        self.ui.lblBeamsize0.setEnabled(enabled)
        self.ui.lblBeamsize20.setEnabled(enabled)
        self.ui.lblBeamsize40.setEnabled(enabled)
        self.ui.lblBeamsize60.setEnabled(enabled)
        self.ui.lblBeamsize80.setEnabled(enabled)
        self.ui.lblBeamsize100.setEnabled(enabled)

        self.ui.sliderIntensity.setEnabled(enabled)
        self.ui.lblIntensity.setEnabled(enabled)
        self.ui.lblIntensity_desc.setEnabled(enabled)
        self.ui.lblIntensity0.setEnabled(enabled)
        self.ui.lblIntensity20.setEnabled(enabled)
        self.ui.lblIntensity40.setEnabled(enabled)
        self.ui.lblIntensity60.setEnabled(enabled)
        self.ui.lblIntensity80.setEnabled(enabled)
        self.ui.lblIntensity100.setEnabled(enabled)

        self.ui.sliderOverlay.setEnabled(enabled)
        self.ui.lblOverlay_desc.setEnabled(enabled)
        self.ui.tbOverlay.setEnabled(enabled)
        self.ui.btnBrowseOverlay.setEnabled(enabled)
        self.ui.lblOverlay0.setEnabled(enabled)
        self.ui.lblOverlay20.setEnabled(enabled)
        self.ui.lblOverlay40.setEnabled(enabled)
        self.ui.lblOverlay60.setEnabled(enabled)
        self.ui.lblOverlay80.setEnabled(enabled)
        self.ui.lblOverlay100.setEnabled(enabled)
        self.ui.gbRadialProfile.setEnabled(enabled)

    def loadFile(self, filename):
        """Load the Green's function in ``filename`` and display it.

        A function that is not in 'rij' format is rejected with an error
        dialog. In that case the previously loaded function (if any) stays
        active.
        """
        gf = Green(filename)

        # Validate before replacing self.GF so a rejected file cannot leave
        # the enabled controls working on an unusable Green's function.
        if not self.validateGreensFunction(gf):
            return

        self.ui.tbGreensFunction.setText(filename)
        self.GF = gf

        # Store radii in centimeters
        self.radius = (self.GF._r - self.GF._r[0]) * 100.0

        self.toggleEnabled(True)
        self.ui.sliderBeamsize.setMaximum(self.GF.nr - 1)
        self.ui.sliderBeamsize.setSliderPosition(self.GF.nr - 1)

        self.updateBeamRadiusLabel()
        self.setupRadialProfile()

        if self.imageHandle is None:
            self.setupImage()
        else:
            self.updateImage()

    def openFile(self):
        """Ask for a Green's function file and load it."""
        filename, _ = QFileDialog.getOpenFileName(
            parent=self,
            caption="Open SOFT Green's function file",
            filter="SOFT Green's function (*.mat *.h5 *.hdf5);;All files (*.*)"
        )

        if filename:
            self.loadFile(filename)

    def loadOverlay(self, filename):
        """Draw the image in ``filename`` semi-transparently over the image."""
        self.ui.tbOverlay.setText(filename)
        self.overlay = mpimg.imread(filename)

        if self.overlayHandle is not None:
            self.overlayHandle.remove()

        a = (self.ui.sliderOverlay.value()) / 100.0
        self.overlayHandle = self.imageAx.imshow(self.overlay,
                                                 alpha=a,
                                                 extent=[-1, 1, -1, 1])
        self.plotWindow.drawSafe()

    def openOverlay(self):
        """Ask for a PNG overlay image and load it."""
        filename, _ = QFileDialog.getOpenFileName(
            parent=self,
            caption="Open overlay image",
            filter="Portable Network Graphics (*.png)")

        if filename:
            self.loadOverlay(filename)

    def updateBeamRadiusLabel(self):
        """Show the selected beam radius in cm and as a percentage."""
        v = self.ui.sliderBeamsize.value()
        r = self.radius[v]
        # The slider runs from 0 to radius.size - 1, so the last position is
        # 100 % (dividing by radius.size never reached 100 %).
        p = int(np.round((v / max(self.radius.size - 1, 1)) * 100.0))

        self.ui.lblBeamRadius.setText('{0:.1f} cm'.format(r))
        self.ui.lblBeamsize.setText('{0}%'.format(p))

    def validateGreensFunction(self, gf):
        """Return True if ``gf`` is in 'rij' format; otherwise show an error."""
        if gf.getFormat() != 'rij':
            QMessageBox.critical(
                self, 'Invalid input file',
                "The specified Green's function is not of the appropriate format. Expected 'rij', got {0}."
                .format(gf.getFormat()))
            return False

        return True

    def sliderBeamsizeChanged(self):
        self.updateBeamRadiusLabel()

        if self.imageHandle is not None:
            self.updateRadialProfile()

    def sliderIntensityChanged(self):
        v = self.ui.sliderIntensity.value()
        self.ui.lblIntensity.setText('{0}%'.format(v))

    def setupRadialProfile(self):
        """Create the canvas that plots the radial profile."""
        self.radialProfileLayout = QtWidgets.QVBoxLayout(
            self.ui.widgetRadialProfile)

        self.radialProfileCanvas = FigureCanvas(Figure())
        self.radialProfileLayout.addWidget(self.radialProfileCanvas)

        f = self.getRadialProfile()
        self.radialProfileAx = self.radialProfileCanvas.figure.subplots()
        self.radialProfileHandle, = self.radialProfileAx.plot(self.radius, f)

        self.radialProfileAx.set_xlim([0, self.radius[-1]])
        self.radialProfileAx.set_ylim([0, 1.2])

        self.radialProfileAx.set_xlabel(r'$r$ (cm)')
        self.radialProfileAx.set_ylabel(r'Radial density')

        self.radialProfileAx.figure.tight_layout(pad=4.5)

    def updateRadialProfile(self):
        """Replot the radial profile and refresh the image."""
        f = self.getRadialProfile()
        self.radialProfileHandle.set_ydata(f)

        maxf = np.amax(f)
        if maxf == 0:
            self.radialProfileAx.set_ylim([0, 1])
        else:
            self.radialProfileAx.set_ylim([0, maxf * 1.2])

        # Also refresh for an all-zero profile. _computeImage avoids the 0/0
        # normalisation, so the image no longer stays stale in that case.
        self.updateImage()

        self.radialProfileCanvas.draw()

    def getRadialProfile(self):
        """Return the radial profile sampled on the radius grid.

        The profile is the expression in ``tbRadialProfile`` evaluated at
        the normalised radius ``x``, with the normalised beam radius
        available as ``a``. It is 1 everywhere when the box is empty.
        Negative values are set to 0, and the profile is 0 for ``x >= a``.
        An expression that fails to evaluate gives an all-zero profile.
        """
        s = self.ui.tbRadialProfile.toPlainText().strip()
        x = self.radius / self.radius[-1]
        f0 = np.zeros(self.radius.shape)
        a = self.radius[self.ui.sliderBeamsize.value()] / self.radius[-1]

        if not s:
            f = np.ones(self.radius.shape)
        else:
            # NOTE(review): evaluateExpression is defined elsewhere. If it is
            # built on eval(), this text box runs arbitrary Python. That is
            # acceptable for a local tool, but never feed it untrusted input.
            lcls = {'a': a}
            try:
                f = evaluateExpression(s, x, lcls=lcls)
            except Exception:
                # Incomplete or invalid expression (e.g. while typing).
                return np.zeros(self.radius.shape)

            # Set negative values to zero
            f = np.where(f < 0, f0, f)

        # Apply step function
        return np.where(x < a, f, f0)

    def radialProfileChanged(self):
        self.updateRadialProfile()

    def setupImage(self):
        """Create the image axes in the plot window and draw the image."""
        self.imageAx = self.plotWindow.figure.add_subplot(111)

        dummy = np.zeros(self.GF._pixels)
        self.imageHandle = self.imageAx.imshow(dummy,
                                               cmap='GeriMap',
                                               interpolation=None,
                                               clim=(0, 1),
                                               extent=[-1, 1, -1, 1])
        self.imageAx.get_xaxis().set_visible(False)
        self.imageAx.get_yaxis().set_visible(False)

        if not self.plotWindow.isVisible():
            self.plotWindow.show()

        self.updateImage()

    def _computeImage(self):
        """Return the transposed image for the current profile, max-scaled.

        The image is the sum over the radial grid of ``GF[i, :, :]``
        weighted by the radial profile, divided by its maximum. An all-zero
        image is returned unscaled instead of producing NaNs from 0/0.
        """
        f = self.getRadialProfile()
        img = np.zeros(self.GF._pixels)

        for i in range(len(self.radius)):
            img += self.GF[i, :, :] * f[i]

        peak = np.amax(img)
        if peak != 0:
            img = img / peak

        return img.T

    def updateImage(self):
        """Recompute the image and contour, then redraw the plot window."""
        img = self._computeImage()

        if self.ui.gbRadialProfile.isChecked():
            self.imageHandle.set_data(img)
        else:
            self.imageHandle.set_data(np.zeros(self.GF._pixels))

        if self.imageContoursHandle is not None:
            self.imageContoursHandle.remove()
            self.imageContoursHandle = None

        if self.ui.cbContour.isChecked():
            threshold = self.ui.sliderIntensity.value() / 100.0
            contours = skimage.measure.find_contours(img.T, threshold)

            # find_contours returns an empty list when the threshold lies
            # outside the image values; indexing [0] then raised IndexError.
            if contours:
                cntr = contours[0]

                # Map pixel coordinates to the [-1, 1] extent of the image.
                ipix, jpix = img.shape
                i, j = cntr[:, 0], cntr[:, 1]
                i = (i - ipix / 2) / ipix * 2
                j = (-j + jpix / 2) / jpix * 2

                self.imageContoursHandle, = self.imageAx.plot(i, j, 'w--')

        self.plotWindow.drawSafe()

    def fullProfileToggled(self):
        self.updateImage()

    def intensityChanged(self):
        self.updateImage()

    def updateOverlay(self):
        """Apply the overlay slider value as the overlay's transparency."""
        if self.overlayHandle is not None:
            a = float(self.ui.sliderOverlay.value()) / 100.0
            self.overlayHandle.set_alpha(a)
            self.plotWindow.drawSafe()

    def saveImagePNG(self, filename=False):
        """
        Save the current image to a PNG file, without axes.

        If ``filename`` is False, the user is asked for a file name.
        """
        if filename is False:
            filename, _ = QFileDialog.getSaveFileName(
                self,
                caption='Save PNG image',
                filter='Portable Network Graphics (*.png)')

        if filename:
            img = self._computeImage()

            cmap = plt.get_cmap('GeriMap')
            im = Image.fromarray(np.uint8(cmap(img) * 255))
            im.save(filename)

    def saveImage(self, filename=False):
        """
        Save the currently displayed image.

        '.png' names are written pixel-exact by ``saveImagePNG``. Any other
        name is rendered from the figure (axes hidden) at 300 dpi. If
        ``filename`` is False, the user is asked for a file name.
        """
        if filename is False:
            filename, _ = QFileDialog.getSaveFileName(
                self,
                caption='Save image',
                filter=
                'Portable Document Format (*.pdf);;Portable Network Graphics (*.png)'
            )

        if not filename: return

        if filename.endswith('.png'):
            self.saveImagePNG(filename=filename)
            return

        self.imageAx.set_axis_off()
        self.plotWindow.figure.subplots_adjust(top=1,
                                               bottom=0,
                                               right=1,
                                               left=0,
                                               hspace=0,
                                               wspace=0)

        self.imageAx.get_xaxis().set_major_locator(
            matplotlib.ticker.NullLocator())
        self.imageAx.get_yaxis().set_major_locator(
            matplotlib.ticker.NullLocator())

        self.plotWindow.canvas.print_figure(filename,
                                            bbox_inches='tight',
                                            pad_inches=0,
                                            dpi=300)

    def saveProfile(self, filename=False):
        """
        Save the current radial profile plot.

        If ``filename`` is False, the user is asked for a file name.
        """
        if filename is False:
            filename, _ = QFileDialog.getSaveFileName(
                self,
                caption='Save image',
                filter='Portable Network Graphics (*.png)')

        if not filename:
            return

        self.radialProfileCanvas.figure.canvas.print_figure(
            filename, bbox_inches='tight')

    def saveBoth(self):
        """Save the image and the profile as ``<name>_image.<ext>`` and
        ``<name>_profile.<ext>``.

        If the chosen name already ends in one of these suffixes, the suffix
        is stripped first. If no extension is typed, the one of the selected
        filter is used.
        """
        filename, selectedFilter = QFileDialog.getSaveFileName(
            self,
            caption='Save both figures',
            filter=
            'Portable Document Form (*.pdf);;Portable Network Graphics (*.png);;Encapsulated Post-Script (*.eps);;Scalable Vector Graphics (*.svg)'
        )

        if not filename:
            return

        # splitext only looks at the last path component, unlike splitting
        # the whole path on '.', which broke on dotted directory names.
        base, ext = os.path.splitext(filename)
        if not ext:
            # e.g. 'Portable Network Graphics (*.png)' -> '.png'
            pattern = selectedFilter.partition('*')[2].rstrip(')')
            ext = pattern if pattern.startswith('.') else '.pdf'

        for suffix in ('_image', '_profile', '_super'):
            if base.endswith(suffix):
                base = base[:-len(suffix)]
                break

        imgname = base + '_image' + ext
        supname = base + '_profile' + ext

        self.saveImage(filename=imgname)
        self.saveProfile(filename=supname)
Exemple #26
0
class CircularAperture(QMainWindow, Ui_MainWindow):
    """
    Main window plotting the far-field pattern of a circular aperture antenna.

    The widgets created by ``Ui_MainWindow.setupUi`` supply the frequency and
    radius (text inputs) plus the antenna type and plot type (index-based
    selectors); the figure is recomputed whenever one of them changes.
    """

    def __init__(self):

        # Zero-argument super() resolves against this class; the former
        # super(self.__class__, self) recursed forever in any subclass.
        super().__init__()

        self.setupUi(self)

        # Colorbar of the current plot, or None when the current plot has none
        self.cbar = None

        # Connect to the input boxes, when the user presses enter the form updates
        self.frequency.returnPressed.connect(self._update_canvas)
        self.radius.returnPressed.connect(self._update_canvas)
        self.antenna_type.currentIndexChanged.connect(self._update_canvas)
        self.plot_type.currentIndexChanged.connect(self._update_canvas)

        # Set up a figure for the plotting canvas
        fig = Figure()
        self.fig = fig
        self.axes1 = fig.add_subplot(111)
        self.my_canvas = FigureCanvas(fig)

        # Add the canvas to the vertical layout
        self.verticalLayout.addWidget(self.my_canvas)
        self.addToolBar(QtCore.Qt.TopToolBarArea, NavigationToolbar(self.my_canvas, self))

        # Update the canvas for the first display
        self._update_canvas()

    def _update_canvas(self):
        """
        Update the figure when the user changes an input value.

        Reads frequency and radius from the form, evaluates the selected
        antenna model (index 0: uniform aperture, otherwise TE11), fills the
        side-lobe-level and directivity boxes and draws the selected plot type
        (0: color map, 1: contour, otherwise E/H-plane cuts in dB).
        Non-numeric frequency/radius text leaves the current plot unchanged
        instead of raising inside the Qt slot.
        """
        # Get the parameters from the form
        try:
            frequency = float(self.frequency.text())
            radius = float(self.radius.text())
        except ValueError:
            return

        # Get the selected antenna from the form
        antenna_index = self.antenna_type.currentIndex()

        # Set the range and angular span
        r = 1.0e9

        # Set up the theta and phi arrays
        n = 200
        m = n // 4  # phi row used for the E-plane cut
        theta, phi = meshgrid(linspace(finfo(float).eps, 0.5 * pi, n), linspace(finfo(float).eps, 2.0 * pi, n))

        # Get the antenna parameters and antenna pattern for the selected antenna
        if antenna_index == 0:
            half_power_eplane, half_power_hplane, first_null_eplane, first_null_hplane = \
                circular_uniform_ground_plane.beamwidth(radius, frequency)
            directivity = circular_uniform_ground_plane.directivity(radius, frequency)
            sidelobe_level_eplane = circular_uniform_ground_plane.side_lobe_level()
            sidelobe_level_hplane = sidelobe_level_eplane
            _, et, ep, _, _, _ = circular_uniform_ground_plane.far_fields(radius, frequency, r, theta, phi)
        else:
            half_power_eplane, half_power_hplane, first_null_eplane, first_null_hplane = \
                circular_te11_ground_plane.beamwidth(radius, frequency)
            directivity = circular_te11_ground_plane.directivity(radius, frequency)
            sidelobe_level_eplane, sidelobe_level_hplane = circular_te11_ground_plane.side_lobe_level()
            _, et, ep, _, _, _ = circular_te11_ground_plane.far_fields(radius, frequency, r, theta, phi)

        # Set the text boxes for the side lobe levels and directivity
        self.sll_eplane.setText('{:.2f}'.format(sidelobe_level_eplane))
        self.sll_hplane.setText('{:.2f}'.format(sidelobe_level_hplane))
        self.directivity.setText('{:.2f}'.format(directivity))

        # Remove the color bar of the previous plot, if it had one.  Resetting
        # to None guarantees a colorbar is never removed twice.
        if getattr(self, 'cbar', None) is not None:
            self.cbar.remove()
            self.cbar = None

        # Clear the axes for the updated plot
        self.axes1.clear()

        # U-V coordinates for plotting the antenna pattern
        uu = sin(theta) * cos(phi)
        vv = sin(theta) * sin(phi)

        # Normalized electric field magnitude for plotting
        e_mag = sqrt(abs(et * et + ep * ep))
        e_mag /= amax(e_mag)

        plot_index = self.plot_type.currentIndex()

        if plot_index == 0:

            # Display the results
            im = self.axes1.pcolor(uu, vv, e_mag, cmap="jet")
            self.cbar = self.fig.colorbar(im, ax=self.axes1, orientation='vertical')
            self.cbar.set_label("Normalized Electric Field (V/m)", size=10)

            # Set the x- and y-axis labels
            self.axes1.set_xlabel("U (sines)", size=12)
            self.axes1.set_ylabel("V (sines)", size=12)

        elif plot_index == 1:

            # Display the results
            self.axes1.contour(uu, vv, e_mag, 20, cmap="jet", vmin=-0.2, vmax=1.0)

            # Turn on the grid
            self.axes1.grid(linestyle=':', linewidth=0.5)

            # Set the x- and y-axis labels
            self.axes1.set_xlabel("U (sines)", size=12)
            self.axes1.set_ylabel("V (sines)", size=12)

        else:

            # Create the line plot
            self.axes1.plot(degrees(theta[0]), 20.0 * log10(e_mag[m]), '', label='E Plane')
            self.axes1.plot(degrees(theta[0]), 20.0 * log10(e_mag[0]), '--', label='H Plane')

            # Set the y axis limit
            self.axes1.set_ylim(-60, 5)

            # Set the x and y axis labels
            self.axes1.set_xlabel("Theta (degrees)", size=12)
            self.axes1.set_ylabel("Normalized |E| (dB)", size=12)

            # Turn on the grid
            self.axes1.grid(linestyle=':', linewidth=0.5)

            # Place the legend
            self.axes1.legend(loc='upper right', prop={'size': 10})

        # Set the plot title and labels
        self.axes1.set_title('Circular Aperture Antenna Pattern', size=14)

        # Set the tick label size
        self.axes1.tick_params(labelsize=12)

        # Update the canvas
        self.my_canvas.draw()
class ApplicationWindow(QMainWindow):
    """
    Example based on the example by 'scikit-image' gallery:
    "Immunohistochemical staining colors separation"
    https://scikit-image.org/docs/stable/auto_examples/color_exposure/plot_ihc_color_separation.html

    The left canvas shows the original sample image; the right canvas shows
    the stain channel (or the combined result) picked with the buttons below.
    """

    def __init__(self, parent=None):
        QMainWindow.__init__(self, parent)
        self._main = QWidget()
        self.setCentralWidget(self._main)

        # Main menu bar
        self.menu = self.menuBar()
        self.menu_file = self.menu.addMenu("File")
        exit_action = QAction("Exit", self, triggered=qApp.quit)
        self.menu_file.addAction(exit_action)

        self.menu_about = self.menu.addMenu("&About")
        about = QAction("About Qt", self, shortcut=QKeySequence(QKeySequence.HelpContents),
                        triggered=qApp.aboutQt)
        self.menu_about.addAction(about)

        # Sample image and its stain decomposition; channels 0/1/2 are shown
        # as hematoxylin, eosin and DAB by the plot_* slots below.
        self.ihc_rgb = data.immunohistochemistry()
        self.ihc_hed = rgb2hed(self.ihc_rgb)

        main_layout = QVBoxLayout(self._main)
        plot_layout = QHBoxLayout()
        button_layout = QHBoxLayout()
        label_layout = QHBoxLayout()

        self.canvas1 = FigureCanvas(Figure(figsize=(5, 5)))
        self.canvas2 = FigureCanvas(Figure(figsize=(5, 5)))

        self._ax1 = self.canvas1.figure.subplots()
        self._ax2 = self.canvas2.figure.subplots()

        self._ax1.imshow(self.ihc_rgb)

        plot_layout.addWidget(self.canvas1)
        plot_layout.addWidget(self.canvas2)

        self.button1 = QPushButton("Hematoxylin")
        self.button2 = QPushButton("Eosin")
        self.button3 = QPushButton("DAB")
        self.button4 = QPushButton("Fluorescence")

        slots = (self.plot_hematoxylin, self.plot_eosin, self.plot_dab, self.plot_final)
        for button, slot in zip(self._buttons(), slots):
            button.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
            button.clicked.connect(slot)
            button_layout.addWidget(button)

        self.label1 = QLabel("Original", alignment=Qt.AlignCenter)
        self.label2 = QLabel("", alignment=Qt.AlignCenter)

        font = self.label1.font()
        font.setPointSize(16)
        self.label1.setFont(font)
        self.label2.setFont(font)

        label_layout.addWidget(self.label1)
        label_layout.addWidget(self.label2)

        main_layout.addLayout(label_layout, 2)
        main_layout.addLayout(plot_layout, 88)
        main_layout.addLayout(button_layout, 10)

        # Default image
        self.plot_hematoxylin()

    def _buttons(self):
        """Return the four selector buttons in display order."""
        return (self.button1, self.button2, self.button3, self.button4)

    def set_buttons_state(self, states):
        """Enable/disable the four buttons from a sequence of four booleans."""
        for button, enabled in zip(self._buttons(), states):
            button.setEnabled(enabled)

    def _show_result(self, image, label, current, cmap=None):
        """
        Replace the right-hand image with *image*, set its caption to *label*
        and disable the button at index *current* (enabling the others).
        """
        # Drop the previously displayed image; without this every click added
        # another AxesImage on top of the old ones.
        for previous in list(self._ax2.images):
            previous.remove()
        self._ax2.imshow(image, cmap=cmap)
        self.canvas2.draw()
        self.label2.setText(label)
        self.set_buttons_state(tuple(index != current for index in range(4)))

    @Slot()
    def plot_hematoxylin(self):
        cmap_hema = LinearSegmentedColormap.from_list("mycmap", ["white", "navy"])
        self._show_result(self.ihc_hed[:, :, 0], "Hematoxylin", 0, cmap=cmap_hema)

    @Slot()
    def plot_eosin(self):
        cmap_eosin = LinearSegmentedColormap.from_list("mycmap", ["darkviolet", "white"])
        self._show_result(self.ihc_hed[:, :, 1], "Eosin", 1, cmap=cmap_eosin)

    @Slot()
    def plot_dab(self):
        cmap_dab = LinearSegmentedColormap.from_list("mycmap", ["white", "saddlebrown"])
        self._show_result(self.ihc_hed[:, :, 2], "DAB", 2, cmap=cmap_dab)

    @Slot()
    def plot_final(self):
        # Combine rescaled DAB (green) and hematoxylin (blue) into an RGB image
        h = rescale_intensity(self.ihc_hed[:, :, 0], out_range=(0, 1))
        d = rescale_intensity(self.ihc_hed[:, :, 2], out_range=(0, 1))
        zdh = np.dstack((np.zeros_like(h), d, h))
        self._show_result(zdh, "Stain separated image", 3)
Exemple #28
0
class Error_Graph(QWidget):
    """
    Widget showing an error curve drawn on pyplot figure number 1.

    The title reports the latest error ("Error mínimo alcanzado") and the
    number of recorded points ("Épocas alcanzadas").
    """
    def __init__(self, parent):
        QWidget.__init__(self, parent)
        self.layout = QVBoxLayout()
        self.setLayout(self.layout)
        # self.layout.setContentsMargins(0,0,0,0)

        self.TITLE_STYLE = {'size': 12, 'color': "#b1b1b1"}

        # Creating the graph
        self.figure = plt.figure(1)
        self.ax = plt.subplot()
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setFocus()
        self.layout.addWidget(self.canvas)

        # Error values plotted so far; always kept as a list so add_error()
        # can append to it.
        self.error_points = []
        # NOTE(review): these hold the type ``int`` itself, not a number;
        # kept as-is for callers that may rely on it — confirm intent.
        self.min_error_reached = int
        self.max_ephocs_reached = int

        self.init_graph()
        self.canvas.draw()

    def init_graph(self, x_max=2000, y_max=15):
        """Apply the dark styling and the initial [0, x_max] x [0, y_max] limits."""
        plt.figure(1)
        plt.tight_layout()
        self.ax = plt.gca()
        self.figure.set_facecolor('#323232')
        self.ax.grid(zorder=0)
        self.ax.set_axisbelow(True)
        self.ax.set_xlim([0, x_max])
        self.ax.set_ylim([0, y_max])
        self.ax.set_yticks((0, (y_max) / 2, y_max))
        self.ax.axhline(y=0, color='#323232')
        self.ax.axvline(x=0, color='#323232')
        for side in ('right', 'top', 'bottom', 'left'):
            self.ax.spines[side].set_visible(False)
        self.ax.tick_params(axis='x', colors='#b1b1b1')
        self.ax.tick_params(axis='y', colors='#b1b1b1')

    def clear_plot(self):
        """Clear the current axes and restore the initial styling."""
        plt.figure(1)
        self.ax = plt.gca()
        self.ax.cla()
        self.init_graph()
        self.canvas.draw()

    def set_title(self, min_error='', max_ephocs=''):
        """
        Set the title from the latest error and the epoch count.

        Numeric *min_error* values are shown with two decimals; a string
        (including the empty default) is shown verbatim.  Previously the
        default raised ValueError because '{:.2f}' rejects str.
        """
        plt.figure(1)
        self.ax = plt.gca()
        if isinstance(min_error, str):
            error_text = min_error
        else:
            error_text = '{:.2f}'.format(min_error)
        self.ax.set_title(
            'Error mínimo alcanzado: {}     Épocas alcanzadas: {}'.format(
                error_text, max_ephocs),
            fontdict=self.TITLE_STYLE)

    def add_error(self, error):
        """Append *error* to the curve and redraw it."""
        plt.figure(1)
        self.ax = plt.gca()
        self.ax.cla()
        self.ax.grid(zorder=0)
        self.error_points.append(error)
        self.ax.plot(self.error_points)
        self.set_title(error, len(self.error_points))
        self.canvas.draw()

    def clear_graph(self):
        """Clear the whole figure and rebuild the styled axes."""
        plt.figure(1)
        plt.clf()
        self.init_graph()
        self.canvas.draw()

    def graph_errors(self, errors):
        """
        Replace the curve with a copy of *errors* (any iterable) and redraw.

        The copy is stored as a list: an ndarray copy would make a later
        add_error() fail, since arrays have no ``append``.
        """
        self.error_points = list(errors)
        plt.figure(1)
        plt.tight_layout()
        self.ax = plt.gca()
        self.ax.cla()
        self.ax.grid(zorder=0)
        self.ax.plot(self.error_points, c='red')
        last_error = self.error_points[-1] if self.error_points else 0
        self.set_title(last_error, len(self.error_points))

        self.canvas.draw()
Exemple #29
0
class spotWindow(QDialog):
    """
    Dialog for tuning spot-detection parameters channel by channel.

    ``params[k][channel]`` holds, for k = 0..3: enhancement, number of
    threshold classes, selected threshold index, and (min, max) spot size.
    Missing entries are replaced with defaults; the dialog updates
    ``self.params`` in place as the user applies values.
    """

    def __init__(self, input_folder, params, parent=None):
        super(spotWindow, self).__init__(parent)
        self.input_folder = input_folder
        self.params = params
        # Load the mask and input image recorded in the segmentation results
        _, cond = os.path.split(input_folder)
        save_folder = os.path.join(input_folder, 'result_segmentation')
        props = utils_postprocessing.load_morpho_params(save_folder, cond)
        props = {key: props[key][0] for key in props}
        mask_file = props['mask_file']
        path_to_mask = os.path.join(input_folder, mask_file)
        # np.float was removed in NumPy 1.24; the builtin float is the same dtype
        self.mask = imread(path_to_mask)[props['slice']].astype(float)
        input_file = props['input_file']
        path_to_file = os.path.join(input_folder, input_file)
        self.img = imread(path_to_file).astype(float)
        if len(self.img.shape) == 2:
            # Single-channel image: add a leading channel axis
            self.img = np.expand_dims(self.img, 0)
        self.img = np.array([img[props['slice']] for img in self.img])
        self.n_channels = self.img.shape[0]

        # Defaults used for any parameter left as None
        params_default = [
            0.8, 2, 0, (2, self.img.shape[1] * self.img.shape[2])
        ]
        for i, p in enumerate(self.params):
            if p is None:
                # No per-channel values at all: one slot per channel, each of
                # which is filled with the default below.  (The old code
                # iterated the int n_channels and then looped over None.)
                p = [None for _ in range(self.n_channels)]
                self.params[i] = p
            # Replace every unset channel value with the default
            for ch in range(len(p)):
                if p[ch] is None or p[ch] == (None, None):
                    p[ch] = params_default[i]

        # create window
        self.initUI()
        self.updateParamsAndFigure()

    def initUI(self):
        self.figure = Figure(figsize=(10, 2.5), dpi=100)
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        self.figure.clear()
        axs = self.figure.subplots(nrows=1, ncols=4)
        self.figure.subplots_adjust(top=0.95,
                                    right=0.95,
                                    left=0.2,
                                    bottom=0.25)
        for i in [0, 1, 3]:
            axs[i].axis('off')
        axs[2].set_xlabel('Fluo')
        axs[2].ticklabel_format(axis="x", style="sci", scilimits=(2, 2))
        axs[2].set_ylabel('Counts')
        axs[2].ticklabel_format(axis="y", style="sci", scilimits=(0, 2))
        self.canvas.draw()

        self.channel = QSpinBox()
        self.channel.setMaximum(self.n_channels - 1)
        self.channel.valueChanged.connect(self.updateChannel)
        self.channel.setAlignment(Qt.AlignRight)

        self.enhancement = QDoubleSpinBox()
        self.enhancement.setMinimum(0)
        self.enhancement.setMaximum(1)
        self.enhancement.setSingleStep(0.05)
        self.enhancement.setValue(self.params[0][self.channel.value()])
        self.enhancement.setAlignment(Qt.AlignRight)

        self.nClasses = QSpinBox()
        self.nClasses.setMinimum(2)
        self.nClasses.setValue(self.params[1][self.channel.value()])
        self.nClasses.valueChanged.connect(self.updatenThrChoice)
        self.nClasses.setAlignment(Qt.AlignRight)

        self.nThr = QSpinBox()
        self.nThr.setValue(self.params[2][self.channel.value()])
        self.nThr.setAlignment(Qt.AlignRight)

        self.minSize = QSpinBox()
        self.minSize.setMaximum(self.img.shape[1] * self.img.shape[2])
        self.minSize.setValue(self.params[3][self.channel.value()][0])
        self.minSize.setAlignment(Qt.AlignRight)

        self.maxSize = QSpinBox()
        self.maxSize.setMaximum(self.img.shape[1] * self.img.shape[2])
        # Start from the stored maximum, as updateChannel does; previously the
        # stored value was ignored here and then overwritten by the image size.
        max_size = self.params[3][self.channel.value()][1]
        if max_size is None:
            max_size = self.img.shape[1] * self.img.shape[2]
        self.maxSize.setValue(max_size)
        self.maxSize.setAlignment(Qt.AlignRight)

        applyButton = QPushButton('Apply params')
        applyButton.clicked.connect(self.updateParamsAndFigure)

        endButton = QPushButton('UPDATE AND RETURN PARAMS')
        endButton.clicked.connect(self.on_clicked)

        lay = QGridLayout(self)
        lay.addWidget(NavigationToolbar(self.canvas, self), 0, 0, 1, 2)
        lay.addWidget(self.canvas, 1, 0, 1, 2)
        lay.addWidget(QLabel('Current channel'), 2, 0, 1, 1)
        lay.addWidget(self.channel, 2, 1, 1, 1)
        lay.addWidget(QLabel('Enhancement'), 3, 0, 1, 1)
        lay.addWidget(self.enhancement, 3, 1, 1, 1)
        lay.addWidget(QLabel('Expected classes for thresholding'), 4, 0, 1, 1)
        lay.addWidget(self.nClasses, 4, 1, 1, 1)
        lay.addWidget(QLabel('Selected threshold'), 5, 0, 1, 1)
        lay.addWidget(self.nThr, 5, 1, 1, 1)
        lay.addWidget(QLabel('Minimum spot size'), 6, 0, 1, 1)
        lay.addWidget(self.minSize, 6, 1, 1, 1)
        lay.addWidget(QLabel('Maximum spot size'), 7, 0, 1, 1)
        lay.addWidget(self.maxSize, 7, 1, 1, 1)
        lay.addWidget(applyButton, 8, 0, 1, 2)
        lay.addWidget(endButton, 9, 0, 1, 2)

        self.setWindowTitle(self.input_folder)
        QApplication.setStyle('Macintosh')

    def updatenThrChoice(self):
        # nClasses classes give nClasses - 1 thresholds, indexed 0..nClasses-2
        self.nThr.setMaximum(self.nClasses.value() - 2)

    def updateChannel(self):
        """Load the stored parameters of the newly selected channel and redraw."""
        ch = self.channel.value()

        self.enhancement.setValue(self.params[0][ch])
        self.nClasses.setValue(self.params[1][ch])
        self.nThr.setValue(self.params[2][ch])
        self.minSize.setValue(self.params[3][ch][0])
        self.maxSize.setValue(self.params[3][ch][1])

        self.updateParamsAndFigure()

    def updateParamsAndFigure(self):
        """Run spot detection with the widget values, store them and redraw."""
        from matplotlib import rc
        rc('font', size=8)
        rc('font', family='Arial')
        rc('pdf', fonttype=42)
        self.nThr.setMaximum(self.nClasses.value() - 2)

        ch = self.channel.value()
        enhancement = self.enhancement.value()
        nclasses = self.nClasses.value()
        nThr = self.nThr.value()
        sizelims = (self.minSize.value(), self.maxSize.value())
        dict_, enhanced, thrs, objects = utils_image.detect_peaks(
            self.img[ch],
            self.mask,
            enhancement=enhancement,
            nclasses=nclasses,
            nThr=nThr,
            sizelims=sizelims)

        ### update the values
        self.params[0][ch] = enhancement
        self.params[1][ch] = nclasses
        self.params[2][ch] = nThr
        self.params[3][ch] = sizelims

        ### update the plot
        self.figure.clear()
        axs = self.figure.subplots(nrows=1, ncols=4)
        self.figure.subplots_adjust(top=0.9, right=1., left=0.,
                                    bottom=0.2)
        for i in [0, 1, 3]:
            axs[i].axis('off')
        axs[2].set_xlabel('Fluo')
        axs[2].ticklabel_format(axis="x", style="sci", scilimits=(2, 2))
        axs[2].set_ylabel('Counts')
        axs[2].ticklabel_format(axis="y", style="sci", scilimits=(0, 2))
        axs[0].set_title('Input image')
        axs[1].set_title('Enhanced image')
        axs[2].set_title('Histogram')
        axs[3].set_title('Segmented spots: %d' % len(dict_['centroid']))
        axs[2].set_yscale('log')

        axs[0].imshow(self.img[ch],
                      cmap='magma',
                      vmin=np.percentile(self.img[ch], 0.3),
                      vmax=np.percentile(self.img[ch], 99.7))
        axs[1].imshow(enhanced,
                      cmap='magma',
                      vmin=np.percentile(enhanced, 0.3),
                      vmax=np.percentile(enhanced, 99.7))
        # Histogram of the enhanced values inside the mask, with all thresholds
        # as red lines and the selected one marked by a star
        n, _, _ = axs[2].hist(enhanced[self.mask > 0], bins=100)
        for thr in thrs:
            axs[2].plot([thr, thr], [0, np.max(n)], '-r')
        axs[2].plot([thrs[nThr]], [np.max(n)], '*r', ms=10)
        axs[3].imshow(objects, cmap='gray')
        for coords, area in zip(dict_['centroid'], dict_['area']):
            # draw a circle with the same area as each segmented spot
            circle = mpatches.Circle((coords[1], coords[0]),
                                     radius=np.sqrt(area / np.pi),
                                     fc=(1, 0, 0, 0.5),
                                     ec=(1, 0, 0, 1),
                                     linewidth=2)
            axs[3].add_patch(circle)
        self.canvas.draw()

    @QtCore.pyqtSlot()
    def on_clicked(self):
        self.accept()
Exemple #30
0
class Photon_Counter(QMainWindow):
    """
    "Scan EM" window: sweeps an analog output voltage and plots two inputs.

    Dev1/ao0 plays the voltage sequence built in update_value (Vm -> V_min ->
    V_max -> Vm) while Dev1/ai11 and Dev1/ai13 are sampled on the same clock,
    triggered by the output start.  Each scan is averaged into ``y_X`` /
    ``y_Y`` and shown on two stacked axes.
    """
    def __init__(self):
        super().__init__()

        self.n_points = 301
        self.t_scan = 1.5
        self.V_min = -5
        self.V_max = +5

        self.n_glissant = 10

        self.n_lect_min = 0
        self.n_lect_max = -1

        self.time_last_refresh = time.time()
        self.refresh_rate = 0.1

        self.setWindowTitle("Scan EM")

        ##Creation of the graphical interface##

        self.main = QWidget()
        self.setCentralWidget(self.main)

        layout = QHBoxLayout()
        Vbox = QVBoxLayout()
        Vbox_gauche = QVBoxLayout()
        Vbox_droite = QVBoxLayout()

        layout.addLayout(Vbox_gauche)
        layout.addLayout(Vbox)
        layout.addLayout(Vbox_droite)
        self.main.setLayout(layout)

        #Fields on the left

        self.labelV_min = QLabel("V_min")
        self.lectV_min = QLineEdit(str(self.V_min))
        Vbox_gauche.addWidget(self.labelV_min)
        Vbox_gauche.addWidget(self.lectV_min)

        self.labelV_max = QLabel("V_max")
        self.lectV_max = QLineEdit(str(self.V_max))
        Vbox_gauche.addWidget(self.labelV_max)
        Vbox_gauche.addWidget(self.lectV_max)
        Vbox_gauche.addStretch(1)

        self.labelt_scan = QLabel("t_scan")
        self.lectt_scan = QLineEdit(str(self.t_scan))
        Vbox_gauche.addWidget(self.labelt_scan)
        Vbox_gauche.addWidget(self.lectt_scan)
        Vbox_gauche.addStretch(1)

        self.labeln_points = QLabel("n_points")
        self.lectn_points = QLineEdit(str(self.n_points))
        Vbox_gauche.addWidget(self.labeln_points)
        Vbox_gauche.addWidget(self.lectn_points)
        Vbox_gauche.addStretch(1)

        #Buttons on the right
        self.stop = QPushButton('Stop')
        self.start = QPushButton('Start')
        self.keep_button = QPushButton('Keep trace')
        self.clear_button = QPushButton('Clear Last Trace')
        self.fit_button = QPushButton('Fit')
        self.normalize_cb = QCheckBox('Normalize')

        self.labelIter = QLabel("iter # 0")

        Vbox_droite.addWidget(self.normalize_cb)
        Vbox_droite.addStretch(1)
        Vbox_droite.addWidget(self.start)
        Vbox_droite.addWidget(self.stop)
        Vbox_droite.addStretch(1)
        Vbox_droite.addWidget(self.keep_button)
        Vbox_droite.addWidget(self.clear_button)
        Vbox_droite.addStretch(1)
        Vbox_droite.addWidget(self.fit_button)
        Vbox_droite.addStretch(1)
        Vbox_droite.addWidget(self.labelIter)

        self.stop.setEnabled(False)

        #Plot in the middle
        self.dynamic_canvas = FigureCanvas(Figure(figsize=(30, 10)))
        Vbox.addStretch(1)
        Vbox.addWidget(self.dynamic_canvas)
        self.addToolBar(Qt.BottomToolBarArea,
                        MyToolbar(self.dynamic_canvas, self))

        ## Matplotlib Setup ##

        self.dynamic_ax_X, self.dynamic_ax_Y = self.dynamic_canvas.figure.subplots(
            2)

        self.t = np.linspace(0, 100, 100)
        self.y = np.zeros(100)
        self.dynamic_line_X, = self.dynamic_ax_X.plot(self.t, self.y)
        self.dynamic_ax_X.set_xlabel('Voltage (V)')
        self.dynamic_ax_X.set_ylabel('PL(counts/s)')

        self.t = np.linspace(0, 100, 100)
        self.y = np.zeros(100)
        self.dynamic_line_Y, = self.dynamic_ax_Y.plot(self.t, self.y)
        self.dynamic_ax_Y.set_xlabel('Voltage (V)')
        self.dynamic_ax_Y.set_ylabel('PL(counts/s)')

        #Define the buttons' action

        self.start.clicked.connect(self.start_measure)
        self.stop.clicked.connect(self.stop_measure)
        self.keep_button.clicked.connect(self.keep_trace)
        self.clear_button.clicked.connect(self.clear_trace)
        self.fit_button.clicked.connect(self.auto_fit)

        ## Timer Setup ##

        self.timer = QTimer(self, interval=0)
        self.timer.timeout.connect(self.update_canvas)

    def _axes_and_lines(self):
        """Return the (axes, live line) pairs of the X and Y plots."""
        return ((self.dynamic_ax_X, self.dynamic_line_X),
                (self.dynamic_ax_Y, self.dynamic_line_Y))

    def update_value(self):
        """
        Read the scan parameters from the form and rebuild the output ramp.

        Raises ValueError when a field does not contain a number.
        """
        # np.int / np.float were removed in NumPy 1.24; they were the builtins
        self.n_points = int(self.lectn_points.text())
        self.t_scan = float(self.lectt_scan.text())
        self.V_min = float(self.lectV_min.text())
        self.V_max = float(self.lectV_max.text())

        self.t_tot = 2 * self.t_scan
        self.n_tot = 2 * self.n_points * self.n_glissant

        self.n_lect_min = 0
        self.n_lect_max = self.n_points

        self.f_acq = self.n_tot / self.t_tot
        # Output sequence: mid-point -> V_min, V_min -> V_max, V_max -> mid-point
        Vm = (self.V_min + self.V_max) / 2
        Nm = self.n_points * self.n_glissant // 2
        self.V_list = list(np.linspace(Vm, self.V_min, Nm)) + list(
            np.linspace(self.V_min, self.V_max, 2 * Nm)) + list(
                np.linspace(self.V_max, Vm, self.n_tot - 3 * Nm))

        self.x = np.linspace(self.V_min, self.V_max, self.n_points)

        self.y_X = np.zeros(self.n_points)
        self.y_Y = np.zeros(self.n_points)
        window = slice(self.n_lect_min, self.n_lect_max)
        for (ax, line), y in zip(self._axes_and_lines(), (self.y_X, self.y_Y)):
            line.set_data(self.x[window], y[window])
            self.set_lim(x=self.x[window], y=y[window], ax=ax, line=line)

    def keep_trace(self):
        """Freeze a copy of both live traces on their axes and redraw."""
        for ax, line in self._axes_and_lines():
            ax.plot(line.get_xdata(), line.get_ydata())
        # Without a redraw the kept traces stayed invisible once stopped
        self.dynamic_canvas.draw()

    def clear_trace(self):
        """Remove the most recently added (non-live) line from each axes."""
        # The old code used self.dynamic_ax / self.dynamic_line, which do not
        # exist on this window, so the button always raised AttributeError.
        for ax, live_line in self._axes_and_lines():
            lines = ax.get_lines()
            if lines and lines[-1] is not live_line:
                lines[-1].remove()
        self.dynamic_canvas.draw()

    def auto_fit(self):
        """
        Fit ``Amp * exp(-x / tau) + ss`` to each live trace and overlay it.

        A trace with fewer than three points, or whose fit fails, is skipped.
        """
        from scipy.optimize import curve_fit

        def model(x, amp, ss, tau):
            return amp * np.exp(-x / tau) + ss

        window = slice(self.n_lect_min, self.n_lect_max)
        for ax, line in self._axes_and_lines():
            x = np.asarray(line.get_xdata(), dtype=float)[window]
            y = np.asarray(line.get_ydata(), dtype=float)[window]
            if len(x) < 3:
                # curve_fit needs at least as many points as parameters
                continue
            # Initial guesses: full amplitude, last value as offset, and the
            # span of the first tenth of the x range as time constant
            p0 = [max(y) - min(y), y[-1], x[len(x) // 10] - x[0]]
            try:
                popt, _ = curve_fit(model, x, y, p0)
            except (RuntimeError, ValueError):
                # RuntimeError: no convergence; ValueError: non-finite values
                continue
            ax.plot(x, model(x, *popt), label='tau=%4.3e' % popt[2])
            ax.legend()
        self.dynamic_canvas.draw()

    def set_lim(self, axes='both', x=None, y=None, line=None, ax=None):
        """
        Fit the limits of *ax* to the data with a 1 % margin.

        :param axes: 'both', 'x' or 'y' -- which limits to set.
        :param x, y: data to fit; empty or omitted means the data of *line*.
        :param line, ax: default to the X trace and its axes.
        """
        if line is None:
            line = self.dynamic_line_X
        if ax is None:
            ax = self.dynamic_ax_X
        if x is None or len(x) == 0:
            x = line.get_xdata()
        if y is None or len(y) == 0:
            y = line.get_ydata()
        xmin, xmax = min(x), max(x)
        ymin, ymax = min(y), max(y)
        # Tiny epsilon keeps the limits distinct for constant data
        dx = 0.01 * (xmax - xmin) + 1e-15
        dy = 0.01 * (ymax - ymin) + 1e-15
        if axes in ('both', 'x'):
            ax.set_xlim([xmin - dx, xmax + dx])
        if axes in ('both', 'y'):
            ax.set_ylim([ymin - dy, ymax + dy])

    def _bin_average(self, samples):
        """
        Average *samples* over consecutive groups of n_glissant and keep bins
        n_points//2 .. 3*n_points//2 - 1 (approximately the V_min -> V_max
        part of the output sequence).
        """
        start = self.n_points // 2
        stop = 3 * self.n_points // 2
        size = self.n_glissant
        binned = np.asarray(samples)[start * size:stop * size].reshape(-1, size)
        return binned.mean(axis=1)

    def _normalized(self, y):
        """Return *y* scaled by its maximum when 'Normalize' is checked and the maximum is non-zero."""
        if self.normalize_cb.isChecked():
            peak = max(y)
            # Dividing by a zero maximum produced NaN limits and set_ylim raised
            if peak != 0:
                return y / peak
        return y

    def update_canvas(self):
        """Read one full scan, fold it into the running means and redraw at most every refresh_rate seconds."""
        lecture = np.array(
            self.tension.read(self.n_tot,
                              timeout=nidaqmx.constants.WAIT_INFINITELY))

        # Running mean over all scans since start_measure (repeat starts at 1)
        weight = 1 / self.repeat
        self.y_X = self.y_X * (1 - weight) + self._bin_average(lecture[0]) * weight
        self.y_Y = self.y_Y * (1 - weight) + self._bin_average(lecture[1]) * weight

        self.repeat += 1

        if time.time() - self.time_last_refresh > self.refresh_rate:
            self.time_last_refresh = time.time()

            window = slice(self.n_lect_min, self.n_lect_max)
            for (ax, line), y in zip(self._axes_and_lines(), (self.y_X, self.y_Y)):
                ytoplot = self._normalized(y)
                line.set_ydata(ytoplot[window])
                self.set_lim(x=self.x[window], y=ytoplot[window], ax=ax, line=line)

            self.dynamic_canvas.draw()

            self.labelIter.setText("iter # %i" % self.repeat)

    def start_measure(self):
        """Read the form, configure both DAQ tasks and start the acquisition timer."""
        try:
            self.update_value()
        except ValueError:
            # A field is not a number: stay stopped instead of raising in the slot
            return

        self.start.setEnabled(False)
        self.stop.setEnabled(True)

        try:
            self.tension = nidaqmx.Task()
            self.tension.ai_channels.add_ai_voltage_chan("Dev1/ai11",
                                                         min_val=-10,
                                                         max_val=10)
            self.tension.ai_channels.add_ai_voltage_chan("Dev1/ai13",
                                                         min_val=-10,
                                                         max_val=10)
            self.tension.timing.cfg_samp_clk_timing(
                self.f_acq,
                sample_mode=nidaqmx.constants.AcquisitionType.CONTINUOUS,
                samps_per_chan=self.n_tot)

            self.voltage_out = nidaqmx.Task()
            self.voltage_out.ao_channels.add_ao_voltage_chan('Dev1/ao0')
            self.voltage_out.timing.cfg_samp_clk_timing(
                self.f_acq,
                sample_mode=nidaqmx.constants.AcquisitionType.CONTINUOUS,
                samps_per_chan=self.n_tot)
            self.voltage_out.write(self.V_list)

            # Inputs start together with the output so samples line up with V_list
            self.tension.triggers.start_trigger.cfg_dig_edge_start_trig(
                '/Dev1/ao/StartTrigger')
            self.tension.start()

            self.repeat = 1

            #Start the task, then the timer
            self.voltage_out.start()
            self.timer.start()
        except Exception:
            # Release whichever task was already created and restore buttons
            self.stop_measure()
            raise

    def stop_measure(self):
        """Stop the timer, close both DAQ tasks if they exist and restore the buttons."""
        import sys

        self.timer.stop()
        for name in ('tension', 'voltage_out'):
            task = getattr(self, name, None)
            if task is None:
                continue
            try:
                task.close()
            except Exception as err:
                # Best-effort cleanup: report instead of silently swallowing
                print('Could not close %s task: %s' % (name, err), file=sys.stderr)
            setattr(self, name, None)

        self.stop.setEnabled(False)
        self.start.setEnabled(True)
Exemple #31
0
class ApplicationWindow(QtWidgets.QMainWindow):
    """Main application window.

    Wraps a ``lakeator.Lakeator`` in a Qt GUI: load a multichannel .wav file,
    calculate a source-location heatmap with the selected algorithm, show it on
    an embedded matplotlib canvas, and save it as a PNG or georeferenced TIFF.
    User settings are held in ``self.settings`` and written to disc as JSON
    after every change.
    """

    def __init__(self):
        """Initialise the application - includes loading settings from disc, initialising a lakeator, and setting up the GUI."""
        super().__init__()
        self._load_settings()

        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        layout = QtWidgets.QHBoxLayout(self._main)

        self.setWindowTitle('Locator')
        self.setWindowIcon(QtGui.QIcon("./kiwi.png"))

        # --- File menu: load audio, save PNG, export to GIS ---
        self.loadAction = QtWidgets.QAction("&Load File", self)
        self.loadAction.setShortcut("Ctrl+L")
        self.loadAction.setStatusTip("Load a multichannel .wav file.")
        self.loadAction.triggered.connect(self.file_open)

        self.saveAction = QtWidgets.QAction("&Save Image", self)
        self.saveAction.setShortcut("Ctrl+S")
        self.saveAction.setStatusTip("Save the current display to a PNG file.")
        self.saveAction.triggered.connect(self.save_display)
        self.saveAction.setDisabled(True)

        self.saveGisAction = QtWidgets.QAction("&Save to GIS", self)
        self.saveGisAction.setShortcut("Ctrl+G")
        self.saveGisAction.setStatusTip(
            "Save the heatmap as a QGIS-readable georeferenced TIFF file.")
        self.saveGisAction.triggered.connect(self.exportGIS)
        self.saveGisAction.setDisabled(True)

        self.statusBar()

        mainMenu = self.menuBar()
        fileMenu = mainMenu.addMenu("&File")
        fileMenu.addAction(self.loadAction)
        fileMenu.addAction(self.saveAction)
        fileMenu.addAction(self.saveGisAction)

        # --- Array menu: microphone geometry and GPS position ---
        setArrayDesign = QtWidgets.QAction("&Configure Array Design", self)
        setArrayDesign.setShortcut("Ctrl+A")
        setArrayDesign.setStatusTip(
            "Input relative microphone positions and array bearing.")
        setArrayDesign.triggered.connect(self.get_array_info)

        setGPSCoords = QtWidgets.QAction("&Set GPS Coordinates", self)
        setGPSCoords.setShortcut("Ctrl+C")
        setGPSCoords.setStatusTip(
            "Set the GPS coordinates for the array, and ESPG code for the CRS."
        )
        setGPSCoords.triggered.connect(self.get_GPS_info)

        arrayMenu = mainMenu.addMenu("&Array")
        arrayMenu.addAction(setArrayDesign)
        arrayMenu.addAction(setGPSCoords)

        # --- Heatmap menu: domain, (re)calculation, colour map ---
        setDomain = QtWidgets.QAction("&Set Heatmap Domain", self)
        setDomain.setShortcut("Ctrl+D")
        setDomain.setStatusTip(
            "Configure distances left/right up/down at which to generate the heatmap."
        )
        setDomain.triggered.connect(self.getBoundsInfo)

        self.refreshHeatmap = QtWidgets.QAction("&Calculate", self)
        self.refreshHeatmap.setShortcut("Ctrl+H")
        self.refreshHeatmap.setStatusTip("(Re)calculate heatmap.")
        self.refreshHeatmap.triggered.connect(self.generate_heatmap)
        self.refreshHeatmap.setDisabled(True)

        self.refreshView = QtWidgets.QAction("&Recalculate on View", self)
        self.refreshView.setShortcut("Ctrl+R")
        self.refreshView.setStatusTip(
            "Recalculate heatmap at current zoom level.")
        self.refreshView.triggered.connect(self.recalculateOnView)
        self.refreshView.setDisabled(True)

        heatmapMenu = mainMenu.addMenu("&Heatmap")
        heatmapMenu.addAction(setDomain)

        # Initialise canvas
        self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.static_canvas)

        # Add a navbar
        navbar = NavigationToolbar(self.static_canvas, self)
        self.addToolBar(QtCore.Qt.BottomToolBarArea, navbar)

        # Override the default mpl save functionality to change default filename
        navbar._actions['save_figure'].disconnect()
        navbar._actions['save_figure'].triggered.connect(self.save_display)

        self.img = None

        # Dynamically generate a menu of all available colourmaps, skipping the
        # inverted ("_r") ones - inversion is a separate toggle. The stored
        # cmap may carry an "_r" suffix, so tick the entry matching its base name.
        current_cmap = self.settings["heatmap"]["cmap"]
        base_cmap = current_cmap[:-2] if current_cmap.endswith("_r") else current_cmap
        self.colMenu = heatmapMenu.addMenu("&Choose Colour Map")
        self.colMenu.setDisabled(True)
        colGroup = QtWidgets.QActionGroup(self)
        for colour in sorted(colormaps(), key=str.casefold):
            if colour.endswith("_r"):
                continue
            cm = self.colMenu.addAction(colour)
            cm.setCheckable(True)
            if colour == base_cmap:
                cm.setChecked(True)
            # Default argument binds this iteration's colour (avoids late binding).
            receiver = lambda checked, cmap=colour: self.img.set_cmap(cmap)
            cm.triggered.connect(receiver)
            cm.triggered.connect(self._setcol)
            cm.triggered.connect(self.static_canvas.draw)
            colGroup.addAction(cm)

        self.invert = QtWidgets.QAction("&Invert Colour Map", self)
        self.invert.setShortcut("Ctrl+I")
        self.invert.setStatusTip("Invert the current colourmap.")
        self.invert.triggered.connect(self.invert_heatmap)
        self.invert.setCheckable(True)
        # Start the toggle in the state of the stored colour map.
        self.invert.setChecked(current_cmap.endswith("_r"))
        self.invert.setDisabled(True)
        heatmapMenu.addAction(self.invert)

        heatmapMenu.addSeparator()
        heatmapMenu.addAction(self.refreshHeatmap)
        heatmapMenu.addAction(self.refreshView)

        # --- Algorithm menu: exclusive algorithm choice + settings dialog ---
        algoMenu = mainMenu.addMenu("Algorithm")
        self.algChoice = algoMenu.addMenu("&Change Algorithm")
        algGroup = QtWidgets.QActionGroup(self)
        for alg in sorted(["GCC", "MUSIC", "AF-MUSIC"], key=str.casefold):
            cm = self.algChoice.addAction(alg)
            cm.setCheckable(True)
            if alg == self.settings["algorithm"]["current"]:
                cm.setChecked(True)
            receiver = lambda checked, al=alg: self.setAlg(al)
            cm.triggered.connect(receiver)
            # Own exclusive group: picking an algorithm must not untick the colour map.
            algGroup.addAction(cm)

        self.params = QtWidgets.QAction("&Algorithm Settings", self)
        self.params.setStatusTip("Alter algorithm-specific settings.")
        self.params.triggered.connect(self.getAlgoInfo)
        algoMenu.addAction(self.params)

        # Display a "ready" message
        self.statusBar().showMessage('Ready')

        # Boolean to keep track of whether we have GPS information for the array, and an image
        self._has_GPS = False
        self._has_heatmap = False

        # Connection id of the canvas 'draw_event' handler (set in generate_heatmap)
        self._draw_cid = None

        # Keep track of the currently opened file
        self.open_filename = ""

        self.loc = lakeator.Lakeator(self.settings["array"]["mic_locations"])

    def _show_error(self, title, text):
        """Show a modal critical-error message box with window title `title' and body `text'."""
        msg = QtWidgets.QMessageBox()
        msg.setIcon(QtWidgets.QMessageBox.Critical)
        msg.setText(text)
        msg.setWindowTitle(title)
        msg.setMinimumWidth(200)
        msg.exec_()

    def setAlg(self, alg):
        """Change the current algorithm to `alg', and write settings to disc."""
        self.settings["algorithm"]["current"] = alg
        self._save_settings()

    def ondraw(self, event):
        """Matplotlib 'draw_event' handler: track the current view limits.

        Enables "Recalculate on View" when a heatmap is shown and the axes
        limits differ from the stored heatmap limits (i.e. after a zoom/pan; a
        plain window resize leaves the limits unchanged). The current limits
        are always recorded in ``self.last_zoomed``.
        """
        xlim = self._static_ax.get_xlim()
        ylim = self._static_ax.get_ylim()
        zoomed = (tuple(self.settings["heatmap"]["xlim"]) != tuple(xlim)
                  or tuple(self.settings["heatmap"]["ylim"]) != tuple(ylim))
        if self._has_heatmap and zoomed:
            self.refreshView.setDisabled(False)
        self.last_zoomed = [xlim, ylim]

    def recalculateOnView(self):
        """If the image has been zoomed, calling this method will recalculate the heatmap on the current zoom level."""
        if hasattr(self, "last_zoomed"):
            self.settings["heatmap"]["xlim"] = list(self.last_zoomed[0])
            self.settings["heatmap"]["ylim"] = list(self.last_zoomed[1])
            self._save_settings()
            self.generate_heatmap()

    def invert_heatmap(self):
        """Add or remove _r to the current colourmap before setting it (to invert the colourmap), then redraw the canvas."""
        cmap = self.settings["heatmap"]["cmap"]
        if cmap.endswith("_r"):
            self.settings["heatmap"]["cmap"] = cmap[:-2]
            self.img.set_cmap(cmap[:-2])
            self.static_canvas.draw()
        else:
            try:
                self.img.set_cmap(cmap + "_r")
                self.settings["heatmap"]["cmap"] = cmap + "_r"
                self.static_canvas.draw()
            except ValueError as inst:
                # No reversed variant registered for this colour map.
                print(type(inst), inst)
        self._save_settings()

    def _setcol(self, c):
        """Set the colourmap attribute to the name of the cmap - needed as I'm using strings to set the cmaps rather than cmap objects."""
        self.settings["heatmap"]["cmap"] = self.img.get_cmap().name
        self._save_settings()

    def generate_heatmap(self):
        """Calculate and draw the heatmap.

        Uses the current algorithm and the heatmap limits stored in
        ``self.settings``. Afterwards it enables PNG saving, colour-map
        selection and inversion, and GIS export when GPS info is set.
        """
        # Initialise the axis on the canvas, refresh the screen
        self.static_canvas.figure.clf()
        self._static_ax = self.static_canvas.figure.subplots()

        # clf() does not remove canvas callbacks: drop the previous handler so
        # ondraw is not registered once more on every calculation.
        if self._draw_cid is not None:
            self.static_canvas.mpl_disconnect(self._draw_cid)
        self._draw_cid = self.static_canvas.mpl_connect('draw_event', self.ondraw)

        # The new image spans exactly the stored limits; ondraw re-enables
        # this action once the user zooms away from them.
        self.refreshView.setDisabled(True)

        # Show a loading message while the user waits
        self.statusBar().showMessage('Calculating heatmap...')

        dom = self.loc.estimate_DOA_heatmap(
            self.settings["algorithm"]["current"],
            xrange=self.settings["heatmap"]["xlim"],
            yrange=self.settings["heatmap"]["ylim"],
            no_fig=True,
            freq=self.settings["algorithm"]["MUSIC"]["freq"],
            AF_freqs=(self.settings["algorithm"]["AF-MUSIC"]["f_min"],
                      self.settings["algorithm"]["AF-MUSIC"]["f_max"]),
            f_0=self.settings["algorithm"]["AF-MUSIC"]["f_0"])

        # Show the image and set axis labels & title
        xlim = self.settings["heatmap"]["xlim"]
        ylim = self.settings["heatmap"]["ylim"]
        self.img = self._static_ax.imshow(
            dom,
            cmap=self.settings["heatmap"]["cmap"],
            interpolation='none',
            origin='lower',
            extent=[xlim[0], xlim[1], ylim[0], ylim[1]])
        self._static_ax.set_xlabel("Horiz. Dist. from Center of Array [m]")
        self._static_ax.set_ylabel("Vert. Dist. from Center of Array [m]")
        self._static_ax.set_title("{}-based Source Location Estimate".format(
            self.settings["algorithm"]["current"]))

        # Add a colourbar and redraw the screen
        self.static_canvas.figure.colorbar(self.img)
        self.static_canvas.draw()

        # Once there's an image being displayed, you can save it and change the colours
        self.saveAction.setDisabled(False)
        if self._has_GPS:
            self.saveGisAction.setDisabled(False)
        self.statusBar().showMessage('Ready.')
        self.colMenu.setDisabled(False)
        self.invert.setDisabled(False)
        self._has_heatmap = True

    def file_open(self):
        """Let the user pick a file to open, and then calculate the cross-correlations.

        The file is rejected (with an error dialog) when its track count does
        not match the configured number of microphones; in that case the
        "Calculate" action is disabled because the locator may now hold the
        rejected data.
        """
        self.statusBar().showMessage('Loading...')
        name, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Load .wav file", "./", "Audio *.wav")
        if not name:
            # Dialog cancelled: don't leave "Loading..." on the status bar.
            self.statusBar().showMessage('Ready.')
            return
        try:
            self.loc.load(name,
                          rho=self.settings["algorithm"]["GCC"]["rho"])
        except IndexError:
            self.refreshHeatmap.setDisabled(True)
            self.statusBar().showMessage('Ready.')
            self._show_error(
                "File Error",
                "File Error\nThe number of microphones in the current array configuration ({0}) is greater than the number of tracks in the selected audio file. Please select a {0}-track audio file."
                .format(self.loc.mics.shape[0]))
            return
        if self.loc.mics.shape[0] < self.loc.data.shape[1]:
            self.refreshHeatmap.setDisabled(True)
            self.statusBar().showMessage('Ready.')
            self._show_error(
                "File Error",
                "File Error\nThe number of microphones in the current array configuration ({0}) is less than the number of tracks in the selected audio file ({1}). Please select a {0}-track audio file, or configure the microphone locations to match the current file."
                .format(self.loc.mics.shape[0], self.loc.data.shape[1]))
            return
        self.open_filename = name
        self.refreshHeatmap.setDisabled(False)
        self.statusBar().showMessage('Ready.')

    def save_display(self):
        """Save the heatmap and colourbar with a sensible default filename.

        ".png" is appended only if the chosen name does not already end with it.
        """
        defaultname = self.open_filename[:-4] + "_" + self.settings[
            "algorithm"]["current"] + "_heatmap.png"
        name, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Save image", defaultname, "PNG files *.png;; All Files *")
        if name:
            if not name.lower().endswith(".png"):
                name = name + ".png"
            self.static_canvas.figure.savefig(name)

    def get_GPS_info(self):
        """Create a popup to listen for the GPS info, and connect the listener."""
        self.setGPSInfoDialog = Dialogs.GPSPopUp(
            coords=self.settings["array"]["GPS"]["coordinates"],
            EPSG=self.settings["array"]["GPS"]["EPSG"]["input"],
            pEPSG=self.settings["array"]["GPS"]["EPSG"]["projected"],
            tEPSG=self.settings["array"]["GPS"]["EPSG"]["target"])
        self.setGPSInfoDialog.activate.clicked.connect(self.changeGPSInfo)
        self.setGPSInfoDialog.exec()

    def changeGPSInfo(self):
        """Listener for the change GPS info dialog - writes the new information to disc and enables the ExportToGIS option."""
        try:
            lat, long, EPSG, projEPSG, targetEPSG = self.setGPSInfoDialog.getValues(
            )
        except EPSGError:
            self._show_error(
                "Error with EPSG code",
                "Value Error\nPlease enter EPSG codes for coordinate systems as integers, e.g. 4326 or 2193. To find the EPSG of a given coordinate system, visit https://epsg.io/"
            )
            return
        except GPSError:
            self._show_error(
                "Error with GPS input.",
                "Value Error\nPlease enter only the numerical portion of the coordinates, in the order governed by ISO19111 (see https://proj.org/faq.html#why-is-the-axis-ordering-in-proj-not-consistent)"
            )
            return

        gps = self.settings["array"]["GPS"]
        gps["EPSG"]["input"] = EPSG
        gps["EPSG"]["projected"] = projEPSG
        gps["EPSG"]["target"] = targetEPSG
        gps["coordinates"] = (lat, long)
        self._save_settings()

        self._has_GPS = True
        if self._has_heatmap:
            self.saveGisAction.setDisabled(False)
        self.setGPSInfoDialog.close()

    def get_array_info(self):
        """Create a popup to listen for the mic position info, and connect the listener."""
        self.setMicsInfoDialog = Dialogs.MicPositionPopUp(
            cur_locs=self.settings["array"]["mic_locations"])
        self.setMicsInfoDialog.activate.clicked.connect(self.changeArrayInfo)
        self.setMicsInfoDialog.exec()

    def changeArrayInfo(self):
        """Listener for the change array info dialog - writes the information to disc and re-initialises the locator.

        The new locator is built from the microphone positions only; the audio
        loaded by file_open (and the heatmap calculated from it) belonged to
        the previous instance. Calculation, view recalculation and GIS export
        are therefore disabled until a file is loaded and a heatmap is
        calculated again. The displayed image can still be saved as a PNG.
        """
        try:
            miclocs = self.setMicsInfoDialog.getValues()
            self.settings["array"]["mic_locations"] = miclocs
            self._save_settings()
            self.loc = lakeator.Lakeator(
                self.settings["array"]["mic_locations"])
            self._has_heatmap = False
            self.refreshHeatmap.setDisabled(True)
            self.refreshView.setDisabled(True)
            self.saveGisAction.setDisabled(True)
            self.setMicsInfoDialog.close()
        except ValueError:
            self._show_error(
                "Error with microphone location input",
                "Value Error\nPlease enter microphone coordinates in meters as x,y pairs, one per line; e.g.\n0.0, 0.0\n0.1, 0.0\n0.0, -0.1\n-0.1, 0.0\n0.0, 0.1"
            )

    def getBoundsInfo(self):
        """Create a popup to listen for the change heatmap bounds info, and connect the listener."""
        l, r = self.settings["heatmap"]["xlim"]
        d, u = self.settings["heatmap"]["ylim"]
        self.setBoundsInfoDialog = Dialogs.HeatmapBoundsPopUp(l, r, u, d)
        self.setBoundsInfoDialog.activate.clicked.connect(
            self.changeBoundsInfo)
        self.setBoundsInfoDialog.exec()

    def changeBoundsInfo(self):
        """Listener for the change heatmap bounds dialog - save the new bounds to disc.

        The heatmap is not regenerated here; the new bounds are used the next
        time the heatmap is calculated.
        """
        try:
            l_new, r_new, u_new, d_new = self.setBoundsInfoDialog.getValues()
            self.settings["heatmap"]["xlim"] = [l_new, r_new]
            self.settings["heatmap"]["ylim"] = [d_new, u_new]
            self._save_settings()
            self.setBoundsInfoDialog.close()
        except ValueError:
            self._show_error(
                "Error with heatmap bounds",
                "Value Error\nPlease ensure that all distances are strictly numeric, e.g. enter '5' or '5.0', rather than '5m' or 'five'."
            )
        except NegativeDistanceError:
            self._show_error(
                "Error; impossible region",
                "Value Error\nPlease ensure that Left/West < Right/East, \nand Up/North < Down/South."
            )

    def getAlgoInfo(self):
        """Create a popup to listen for the algorithm settings, and attach the listener."""
        self.setAlgoInfoDialog = Dialogs.AlgorithmSettingsPopUp(
            self.settings["algorithm"])
        self.setAlgoInfoDialog.activate.clicked.connect(self.changeAlgoInfo)
        self.setAlgoInfoDialog.cb.currentIndexChanged.connect(self.procChange)
        self.setAlgoInfoDialog.exec()

    def procChange(self):
        """Store the GCC processor currently selected in the algorithm dialog's combo box, and save."""
        self.settings["algorithm"]["GCC"][
            "processor"] = self.setAlgoInfoDialog.cb.currentText()
        self._save_settings()

    def changeAlgoInfo(self):
        """Listener for the change algorithm settings dialog - saves to disc after obtaining new information."""
        try:
            self.settings["algorithm"] = self.setAlgoInfoDialog.getValues()
            self._save_settings()
            self.setAlgoInfoDialog.close()
        except ValueError:
            self._show_error(
                "Error with frequency input",
                "Value Error\nPlease ensure that all frequencies are strictly numeric, e.g. enter '100' or '100.0', rather than '100 Hz' or 'one hundred'."
            )

    def exportGIS(self):
        """Export the current heatmap to disc as a TIF file, with associated {}.tif.points georeferencing data.

        This is handled by the lakeator - this method is simply a wrapper and
        filepath selector. ".tif" is appended only if the chosen name does not
        already end with it.
        """
        defaultname = self.open_filename[:-4] + "_" + self.settings[
            "algorithm"]["current"] + "_heatmap"
        name, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Save image & GIS Metadata", defaultname,
            "TIF files *.tif;; All Files *")
        if name:
            if not name.lower().endswith(".tif"):
                name = name + ".tif"
            gps = self.settings["array"]["GPS"]
            self.loc.heatmap_to_GIS(
                gps["coordinates"],
                gps["EPSG"]["input"],
                projected_EPSG=gps["EPSG"]["projected"],
                target_EPSG=gps["EPSG"]["target"],
                filepath=name)

    def _load_settings(self, settings_file="./settings.txt"):
        """Load settings (JSON) from disc into ``self.settings``."""
        with open(settings_file, "r") as f:
            self.settings = json.load(f)

    def _save_settings(self, settings_file="./settings.txt"):
        """Save ``self.settings`` to disc as indented, key-sorted JSON."""
        with open(settings_file, "w") as f:
            json.dump(self.settings, f, sort_keys=True, indent=4)
Exemple #32
0
class FramePanel(QtWidgets.QWidget):
    '''GUI panel containing frame display widget
    Can scroll through frames of parent's EMCReader object

    Other parameters:
        compare - Side-by-side view of frames and best guess tomograms from reconstruction
        powder - Show sum of all frames

    Required members of parent class:
        emc_reader - Instance of EMCReader class
        geom - Instance of DetReader class
        output_folder - (Only for compare mode) Folder with output data
        need_scaling - (Only for compare mode) Whether reconstruction was done with scaling
    '''
    def __init__(self, parent, compare=False, powder=False, **kwargs):
        '''Create the panel and build its widgets.

        Parameters:
            parent - Owning window (must provide emc_reader, and for compare
                     mode geom, output_folder and need_scaling)
            compare - Enable side-by-side comparison with reconstruction slices
            powder - Show the powder sum instead of individual frames
            kwargs - Forwarded to QWidget

        Note: this updates the process-wide matplotlib rcParams with a dark
        colour theme, affecting every figure created afterwards.
        '''
        super(FramePanel, self).__init__(**kwargs)

        foreground = '#eff0f1'
        background = '#2a2a2f'
        theme = dict.fromkeys(
            ('text.color', 'xtick.color', 'ytick.color', 'axes.labelcolor'),
            foreground)
        theme['axes.facecolor'] = background
        theme['figure.facecolor'] = background
        matplotlib.rcParams.update(theme)

        self.parent = parent
        self.emc_reader = parent.emc_reader
        self.do_compare = compare
        self.do_powder = powder

        if compare:
            # Generator of best-guess tomograms for the compare view
            self.slices = slices.SliceGenerator(
                parent.geom, 'data/quat.dat',
                folder=parent.output_folder,
                need_scaling=parent.need_scaling)
        if powder:
            self.powder_sum = self.emc_reader.get_powder()

        # Placeholders; _init_ui replaces them with line-edit widgets
        # (numstr stays a plain string in powder mode).
        self.numstr, self.rangestr = '0', '10'

        self._init_ui()

    def _init_ui(self):
        '''Build the panel's widgets and draw the initial plot.

        Layout, top to bottom: navigation toolbar, matplotlib canvas, a row of
        input fields, and a row of buttons. Which widgets exist depends on
        ``self.do_powder`` / ``self.do_compare``. The line edits stored on
        ``self`` here (numstr, rangestr and, in compare mode, slicerange and
        exponent) are read later by plot_frame and friends.
        '''
        vbox = QtWidgets.QVBoxLayout(self)

        # Figure and canvas; clicking the canvas gives this panel keyboard focus
        self.fig = Figure(figsize=(6, 6))
        self.fig.subplots_adjust(left=0.05, right=0.99, top=0.9, bottom=0.05)
        self.canvas = FigureCanvas(self.fig)
        self.navbar = MyNavigationToolbar(self.canvas, self)
        self.canvas.mpl_connect('button_press_event', self._frame_focus)
        vbox.addWidget(self.navbar)
        vbox.addWidget(self.canvas)

        # Input row
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        if not self.do_powder:
            # Frame number field ("N / total"); replaces the '0' placeholder
            label = QtWidgets.QLabel('Frame number: ', self)
            hbox.addWidget(label)
            self.numstr = QtWidgets.QLineEdit('0', self)
            self.numstr.setFixedWidth(64)
            hbox.addWidget(self.numstr)
            label = QtWidgets.QLabel('/%d'%self.emc_reader.num_frames, self)
            hbox.addWidget(label)
        hbox.addStretch(1)
        if not self.do_powder and self.do_compare:
            # Compare toggle plus tomogram display controls: colour-map
            # maximum (slicerange) and the exponent applied to the tomogram
            self.compare_flag = QtWidgets.QCheckBox('Compare', self)
            self.compare_flag.clicked.connect(self._compare_flag_changed)
            self.compare_flag.setChecked(False)
            hbox.addWidget(self.compare_flag)
            label = QtWidgets.QLabel('CMap:', self)
            hbox.addWidget(label)
            self.slicerange = QtWidgets.QLineEdit('10', self)
            self.slicerange.setFixedWidth(30)
            hbox.addWidget(self.slicerange)
            label = QtWidgets.QLabel('^', self)
            hbox.addWidget(label)
            self.exponent = QtWidgets.QLineEdit('1.0', self)
            self.exponent.setFixedWidth(30)
            hbox.addWidget(self.exponent)
            hbox.addStretch(1)
        # Upper limit (vmax) of the frame colour scale; replaces the '10' placeholder
        label = QtWidgets.QLabel('PlotMax:', self)
        hbox.addWidget(label)
        self.rangestr = QtWidgets.QLineEdit('10', self)
        self.rangestr.setFixedWidth(48)
        hbox.addWidget(self.rangestr)

        # Button row: Plot, then Save (powder mode) or frame-scrolling
        # controls, then Quit (closes the parent window)
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        button = QtWidgets.QPushButton('Plot', self)
        button.clicked.connect(self.plot_frame)
        hbox.addWidget(button)
        if self.do_powder:
            button = QtWidgets.QPushButton('Save', self)
            button.clicked.connect(self._save_powder)
            hbox.addWidget(button)
        else:
            # presumably adds prev/next/random frame buttons - TODO confirm in gui_utils
            gui_utils.add_scroll_hbox(self, hbox)
        hbox.addStretch(1)
        button = QtWidgets.QPushButton('Quit', self)
        button.clicked.connect(self.parent.close)
        hbox.addWidget(button)

        self.show()
        #if not self.do_compare:
        self.plot_frame()

    def plot_frame(self, frame=None):
        '''Update canvas according to GUI parameters
        Updated plot depends on mode (for classifier) and whether the GUI is in
        'compare' or 'powder' mode.

        Parameters:
            frame - Optional array to display directly. When omitted, the
                    powder sum (powder mode) or the frame whose number is typed
                    in the GUI is shown. An explicitly supplied frame has no
                    frame number, so the compare view and number-based title
                    annotations are skipped for it.
        '''
        try:
            mode = self.parent.mode_val
        except AttributeError:
            mode = None

        # QPushButton.clicked delivers its boolean `checked` state to slots that
        # accept an argument, so the 'Plot' button calls plot_frame(False).
        if isinstance(frame, bool):
            frame = None

        num = None
        if frame is None:
            if self.do_powder:
                frame = self.powder_sum
            else:
                num = self.get_num()
                if num is None:
                    return
                frame = self.emc_reader.get_frame(num)

        # Remove ROI markers drawn by the embedding panel, if any
        try:
            for point in self.parent.embedding_panel.roi_list:
                point.remove()
        except (ValueError, AttributeError):
            pass

        self.fig.clear()
        if mode == 2:
            subp = self.parent.conversion_panel.plot_converted_frame()
        elif self.do_compare and num is not None and self.compare_flag.isChecked():
            subp = self._plot_slice(num)
        else:
            subp = self.fig.add_subplot(111)
        subp.imshow(frame.T, vmin=0, vmax=float(self.rangestr.text()),
                    interpolation='none', cmap=self.parent.cmap)
        subp.set_title(self._get_plot_title(frame, num, mode))
        self.fig.tight_layout()
        self.canvas.draw()

    def get_num(self):
        '''Return the frame number typed in the GUI, or None if it is unusable.

        Writes a message to stderr and returns None when the text is not an
        integer or lies outside [0, emc_reader.num_frames).
        '''
        text = self.numstr.text()
        try:
            num = int(text)
        except ValueError:
            sys.stderr.write('Frame number must be integer\n')
            return None

        if not 0 <= num < self.emc_reader.num_frames:
            sys.stderr.write('Frame number %d out of range!\n' % num)
            return None

        return num

    def _plot_slice(self, num):
        '''Create axes for frame `num` and, when possible, its best-guess tomogram.

        The current reconstruction iteration is read from the first field of
        the last line of ``self.parent.log_fname``. If it is a positive
        integer, the figure is split into two axes: the right one shows the
        tomogram from ``self.slices`` for (iteration, num), raised to the GUI
        exponent and titled with its mutual information; the left one is
        returned for the frame itself. Otherwise (unparseable or empty log,
        or iteration <= 0) a single full-figure axes is returned.
        '''
        with open(self.parent.log_fname, 'r') as fptr:
            lines = fptr.readlines()
        line = lines[-1] if lines else ''
        try:
            iteration = int(line.split()[0])
        except (IndexError, ValueError):
            sys.stderr.write('Unable to determine iteration number from %s\n' %
                             self.parent.log_fname)
            sys.stderr.write('%s\n' % line)
            iteration = None

        if iteration is None or iteration <= 0:
            return self.fig.add_subplot(111)

        subp = self.fig.add_subplot(121)
        subpc = self.fig.add_subplot(122)
        tomo, info = self.slices.get_slice(iteration, num)
        subpc.imshow(tomo**float(self.exponent.text()), cmap=self.parent.cmap,
                     vmin=0, vmax=float(self.slicerange.text()),
                     interpolation='gaussian')
        subpc.set_title('Mutual Info. = %f'%info)
        return subp

    def _next_frame(self):
        '''Advance the frame-number field by one and replot, if not at the last frame.

        Does nothing in powder mode, where there is no frame-number field, or
        when the field does not contain an integer.
        '''
        if self.do_powder:
            return
        try:
            num = int(self.numstr.text()) + 1
        except ValueError:
            sys.stderr.write('Frame number must be integer\n')
            return
        if num < self.emc_reader.num_frames:
            self.numstr.setText(str(num))
            self.plot_frame()

    def _prev_frame(self):
        '''Step the frame-number field back by one and replot, if not at frame 0.

        Does nothing in powder mode, where there is no frame-number field, or
        when the field does not contain an integer.
        '''
        if self.do_powder:
            return
        try:
            num = int(self.numstr.text()) - 1
        except ValueError:
            sys.stderr.write('Frame number must be integer\n')
            return
        if num > -1:
            self.numstr.setText(str(num))
            self.plot_frame()

    def _rand_frame(self):
        '''Jump to a uniformly random frame number and replot.

        Does nothing in powder mode (no frame-number field) or when the reader
        holds no frames (randint would raise on an empty range).
        '''
        if self.do_powder or self.emc_reader.num_frames < 1:
            return
        num = np.random.randint(0, self.emc_reader.num_frames)
        self.numstr.setText(str(num))
        self.plot_frame()

    def _get_plot_title(self, frame, num, mode):
        '''Build the title for a plotted frame.

        Always starts with the photon count (integer sum of `frame`). When a
        frame number `num` is known, it also appends:
          * the frame's class label from ``parent.classes.clist`` in modes 1 and 3,
          * the MLP prediction in mode 4 (if predictions exist),
          * ' (bad frame)' with no mode when the frame is blacklisted.
        Powder sums and explicitly supplied frames (num is None) get only the
        photon count, since there is nothing to index the labels with.
        '''
        title = '%d photons' % frame.sum()
        if num is None:
            return title
        if mode in (1, 3):
            title += ' (%s)' % self.parent.classes.clist[num]
        if mode == 4 and self.parent.mlp_panel.predictions is not None:
            title += ' [%s]' % self.parent.mlp_panel.predictions[num]
        if (mode is None and
                not self.do_powder and
                self.parent.blacklist is not None and
                self.parent.blacklist[num] == 1):
            title += ' (bad frame)'
        return title

    def _compare_flag_changed(self):
        '''Slot for the 'Compare' checkbox: redraw so the side-by-side view toggles.'''
        return self.plot_frame()

    def _frame_focus(self, event): # pylint: disable=unused-argument
        '''Canvas button-press handler: give this panel keyboard focus.

        Focus lets keyPressEvent receive the frame-navigation shortcuts after
        the user clicks on the plot.
        '''
        del event  # the click position is irrelevant
        self.setFocus()

    def _save_powder(self):
        '''Write the assembled and raw powder sums to the output folder.

        Files written (raw binary via ndarray.tofile):
            <output_folder>/assem_powder.bin - ``self.powder_sum.data``
            <output_folder>/powder.bin       - ``emc_reader.get_powder(raw=True)``
        Each write is announced on stderr with the array shape and path.
        '''
        fname = '%s/assem_powder.bin' % self.parent.output_folder
        sys.stderr.write('Saving assembled powder sum with shape %s to %s\n' %
                         (self.powder_sum.shape, fname))
        self.powder_sum.data.tofile(fname)

        raw_powder = self.emc_reader.get_powder(raw=True)
        fname = '%s/powder.bin' % self.parent.output_folder
        sys.stderr.write('Saving raw powder sum with shape %s to %s\n' %
                         (raw_powder.shape, fname))
        raw_powder.tofile(fname)

    def keyPressEvent(self, event): # pylint: disable=C0103
        '''Override of default keyPress event handler

        Bindings:
            Ctrl+N, Right, Down - next frame
            Ctrl+P, Left, Up    - previous frame
            Ctrl+R              - random frame
            Return, Enter       - replot with the current GUI values
        Any other key is ignored (event.ignore()) so Qt can pass it on to the
        parent widget.
        '''
        key = event.key()
        mod = int(event.modifiers())

        # Modifier flags + key code form one combined value that can be
        # compared against a key sequence such as 'Ctrl+N'.
        if QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+N'):
            self._next_frame()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+P'):
            self._prev_frame()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+R'):
            self._rand_frame()
        elif key == QtCore.Qt.Key_Return or key == QtCore.Qt.Key_Enter:
            self.plot_frame()
        elif key == QtCore.Qt.Key_Right or key == QtCore.Qt.Key_Down:
            self._next_frame()
        elif key == QtCore.Qt.Key_Left or key == QtCore.Qt.Key_Up:
            self._prev_frame()
        else:
            event.ignore()