Ejemplo n.º 1
0
class PlotWidget(QtWidgets.QWidget):
    '''
    Widget that stacks a NavigationToolbar2QT (as the layout's menu bar)
    on top of a FigureCanvas showing a single Figure.
    '''
    __figure__ = None
    __canvas__ = None

    def __init__(self):
        '''
        Create the vertical layout, the figure/canvas pair and the toolbar.
        '''
        super().__init__()
        box = QtWidgets.QVBoxLayout(self)
        self.setLayout(box)

        figure = Figure()
        canvas = FigureCanvas(figure)
        self.__figure__ = figure
        self.__canvas__ = canvas

        # Redraw whenever the figure notifies its axes observers. This covers
        # the creation of Axes, but not a later imshow on an existing axis
        # (hence the explicit redraw in imshow below).
        figure.add_axobserver(lambda _fig: self.__canvas__.draw_idle())

        box.addWidget(canvas)
        box.setMenuBar(NavigationToolbar2QT(canvas, self))

    def get_figure(self):
        '''Return the figure created in the constructor.'''
        return self.__figure__

    def get_canvas(self):
        '''Return the canvas created in the constructor.'''
        return self.__canvas__

    def imshow(self, *args, **kwargs):
        '''Forward all arguments to imshow on the figure's current axes, then schedule a redraw.'''
        current_axes = self.__figure__.gca()
        current_axes.imshow(*args, **kwargs)
        self.__canvas__.draw_idle()
Ejemplo n.º 2
0
class Window(QtWidgets.QMainWindow):
    '''Main window: a record toggle button above two side-by-side line plots.'''
    def __init__(self, points: int, toggle_callback):
        '''
        points: number of samples drawn on each line.
        toggle_callback: connected to the button's clicked signal.
        '''
        super().__init__()
        central = QtWidgets.QWidget()
        self.setCentralWidget(central)
        vbox = QtWidgets.QVBoxLayout(central)

        self.record_button = QtWidgets.QPushButton("Stop recording")
        self.record_button.clicked.connect(toggle_callback)
        vbox.addWidget(self.record_button)

        figure, axes_pair = plt.subplots(1, 2,
                                         sharex=True,
                                         sharey=True,
                                         figsize=(5, 2))
        self.axs_inp, self.axs_eng = axes_pair

        self.canvas = FigureCanvas(figure)
        vbox.addWidget(self.canvas)

        self.points = points
        titled_axes = (
            (self.axs_inp, "Input (keyboard/mouse)"),
            (self.axs_eng, "Engagement"),
        )
        for axis, title in titled_axes:
            axis.set_title(title, fontsize=10)
            self.set_axis_ticks(axis)

        # Both lines start flat at zero; only the first artist returned by
        # plot() is kept.
        self.line_inp = self.axs_inp.plot(range(self.points),
                                          np.zeros(self.points))[0]
        self.line_eng = self.axs_eng.plot(range(self.points),
                                          np.zeros(self.points))[0]

    @staticmethod
    def set_axis_ticks(axis: Axes):
        '''Hide the x ticks and label y = 0 / 1 as "absent" / "present".'''
        axis.set_ylim(-0.2, 1.2)
        y_positions = [0, 1]
        y_labels = ["absent", "present"]
        axis.set_yticks(y_positions)
        axis.set_yticklabels(y_labels)
        axis.set_xticks([])

    def plot(self, ys):
        '''Replace both lines' data: element 0 of each sample feeds the input
        line, element 1 the engagement line; then schedule a redraw.'''
        input_values = [sample[0] for sample in ys]
        self.line_inp.set_data(range(self.points), input_values)
        engagement_values = [sample[1] for sample in ys]
        self.line_eng.set_data(range(self.points), engagement_values)
        self.canvas.draw_idle()
Ejemplo n.º 3
0
class ControlWidget(QWidget):
    """Widget that plots estimated, predicted and ground-truth control history.

    It subscribes through *node* to ``control_estimate`` (``ControlPList``)
    and ``control_data`` (``ControlA``). Every estimate message triggers a
    redraw and an SVG snapshot written into ``FIGURE_DIR``.
    """

    # Directory receiving one SVG snapshot per redraw; override on the class
    # or instance to save elsewhere.
    FIGURE_DIR = ("/mnt/hgfs/michaelkapteyn/digitaltwin_ws/src/digitaltwin/"
                  "outputfiles/figures")

    def __init__(self, node):
        """Build the canvas and axes and register both subscriptions on *node*."""
        super(ControlWidget, self).__init__()
        layout = QVBoxLayout(self)
        self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.static_canvas)
        self.control_truth = []
        self.control_estimate = []
        self.control_estimate_argmax = []
        self.control_prediction = []
        self.control_prediction_argmax = []
        self.estimate_types = []
        # plot() reads self.types and plot_ellipses() reads self.joints;
        # define both here so neither method can hit an AttributeError
        # before the first message arrives.
        self.types = []
        self.joints = []
        self.ax = self.static_canvas.figure.subplots()
        self._decorate_axes()

        self.topic_name = 'graph_string'
        self._node = node
        self._node.create_subscription(ControlPList, 'control_estimate',
                                       self.ReceivedEstimateCallback, 10)
        self._node.create_subscription(ControlA, 'control_data',
                                       self.ReceivedTruthCallback, 10)

    @staticmethod
    def _count_estimates(types):
        """Return how many entries of *types* equal 1.0 (counted as estimates)."""
        return sum(1 for t in types if t == 1.)

    def _decorate_axes(self):
        """Apply the fixed limits, title, tick labels, axis labels and grid."""
        self.ax.set_xlim(0, 50)
        self.ax.set_ylim(-0.1, 1.1)
        self.ax.set_title('Control History')
        # legend() emits a warning when nothing is labelled yet, which is the
        # case when this runs from the constructor.
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()
        self.ax.set_yticks([0, 1])
        self.ax.set_yticklabels(['2g', '3g'])
        self.ax.set_xlabel('Time')
        self.ax.set_ylabel('Control')
        self.ax.grid()

    def _save_snapshot(self, n_estimates):
        """Write the figure as SVG into FIGURE_DIR.

        A missing or unwritable directory is logged instead of raised so the
        on-screen plot still refreshes.
        """
        import logging

        path = "{}/control_plot_{}.svg".format(self.FIGURE_DIR, n_estimates)
        try:
            self.static_canvas.figure.savefig(path,
                                              format='svg',
                                              transparent=True)
        except OSError as err:
            logging.getLogger(__name__).warning(
                "could not save control plot to %s: %s", path, err)

    def plot(self):
        """Redraw the predicted, estimated and ground-truth series."""
        n_estimates = self._count_estimates(self.types)
        # The prediction starts at the last estimate so both lines connect.
        # Clamp at 0: with no estimates, a start of -1 would keep only the
        # final sample.
        prediction_start = max(n_estimates - 1, 0)
        t = range(len(self.types))
        self.ax.clear()
        self.ax.plot(t[prediction_start:],
                     self.control_prediction_argmax,
                     'r--',
                     linewidth=2,
                     label='Predicted')
        self.ax.plot(t[:n_estimates],
                     self.control_estimate_argmax,
                     'b-',
                     linewidth=2,
                     label='Estimated')
        self.ax.plot(range(len(self.control_truth)),
                     self.control_truth,
                     'k--',
                     linewidth=2,
                     label='Ground Truth')
        # self.plot_ellipses()
        self._decorate_axes()

        self._save_snapshot(n_estimates)
        self.static_canvas.draw()
        self.static_canvas.draw_idle()

    def ReceivedEstimateCallback(self, msg):
        """Split ``msg.controls`` into estimated and predicted parts and replot.

        Entries whose ``type`` equals 1.0 are counted as estimates; the
        slicing assumes they come first in ``msg.controls``.
        """
        self.types = [s.type for s in msg.controls]
        n_estimates = self._count_estimates(self.types)
        # Same clamp as in plot(): never let the start index go negative.
        prediction_start = max(n_estimates - 1, 0)
        self.control_estimate = [m.control for m in msg.controls[:n_estimates]]
        self.control_prediction = [
            m.control for m in msg.controls[prediction_start:]
        ]
        self.control_estimate_argmax = [
            np.argmax(m) for m in self.control_estimate
        ]
        self.control_prediction_argmax = [
            np.argmax(m) for m in self.control_prediction
        ]
        print(self.control_prediction_argmax)
        self.plot()

    def ReceivedTruthCallback(self, msg):
        """Append the first element of ``msg.data`` to the ground-truth history."""
        self.control_truth.append(msg.data[0])

    def _handle_refresh_clicked(self, checked):
        '''
        called when the refresh button is clicked
        '''
        pass

    # Qt methods
    def shutdown_plugin(self):
        pass

    def save_settings(self, plugin_settings, instance_settings):
        pass

    def restore_settings(self, plugin_settings, instance_settings):
        pass

    def plot_ellipses(self):
        """Draw one ellipse per joint table, newest first, skipping index 0.

        Tables whose type equals 1 are drawn solid blue, all others dashed red.
        """
        for j, t in zip(self.joints[-1:0:-1], self.types[-1:0:-1]):
            if t == 1:
                self.confidence_ellipse(j, n_std=1.0, edgecolor='b')
            else:
                self.confidence_ellipse(j,
                                        n_std=1.0,
                                        edgecolor='r',
                                        linestyle='--')

    def confidence_ellipse(self,
                           j,
                           n_std=1.0,
                           edgecolor='k',
                           linestyle='-',
                           **kwargs):
        """Add the covariance confidence ellipse of a joint table to ``self.ax``.

        Parameters
        ----------
        j : 2-D array
            Joint probability table indexed as ``j[x, y]``; any shape is
            accepted (the values of x and y are taken to be 0..n-1).
        n_std : float
            Number of standard deviations that determines the ellipse radii.
        edgecolor, linestyle, **kwargs
            Forwarded to ``Ellipse``.

        Returns
        -------
        The added ``Ellipse``, or ``None`` when either variance is below
        1e-3 and nothing is drawn.
        """
        j = np.asarray(j)
        n_x, n_y = j.shape
        x_values = np.arange(n_x)
        y_values = np.arange(n_y)
        p_x = np.sum(j, 1)  # marginal over y
        p_y = np.sum(j, 0)  # marginal over x

        EX = np.dot(p_x, x_values)
        EX2 = np.dot(p_x, x_values**2)
        EY = np.dot(p_y, y_values)
        EY2 = np.dot(p_y, y_values**2)
        EXY = np.dot(x_values, np.dot(j, y_values))

        cov = np.zeros((2, 2))
        cov[0, 0] = EX2 - pow(EX, 2)
        cov[1, 1] = EY2 - pow(EY, 2)
        cov[0, 1] = EXY - EX * EY
        cov[1, 0] = cov[0, 1]
        if cov[0, 0] < 1e-3 or cov[1, 1] < 1e-3:
            return None
        pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
        # Rounding can push the coefficient marginally outside [-1, 1],
        # which would make one of the square roots below NaN.
        pearson = np.clip(pearson, -1.0, 1.0)
        # Using a special case to obtain the eigenvalues of this
        # two-dimensional dataset.
        ell_radius_x = np.sqrt(1 + pearson)
        ell_radius_y = np.sqrt(1 - pearson)
        ellipse = Ellipse((0, 0),
                          width=ell_radius_x * 2,
                          height=ell_radius_y * 2,
                          facecolor='none',
                          edgecolor=edgecolor,
                          linestyle=linestyle,
                          **kwargs)

        # Standard deviation along each axis, times the requested number of
        # standard deviations, gives the scale; the means give the centre.
        scale_x = np.sqrt(cov[0, 0]) * n_std
        mean_x = EX
        scale_y = np.sqrt(cov[1, 1]) * n_std
        mean_y = EY

        transf = transforms.Affine2D() \
            .rotate_deg(45) \
            .scale(scale_x, scale_y) \
            .translate(mean_x, mean_y)

        ellipse.set_transform(transf + self.ax.transData)
        return self.ax.add_patch(ellipse)
Ejemplo n.º 4
0
class RewardWidget(QWidget):
    """Widget that plots reward components with +/- 2 sigma bands.

    It subscribes through *node* to ``reward_estimate`` (``Reward``). Every
    message is appended to the history lists, the figure is redrawn from the
    newest message, and an SVG snapshot is written into ``FIGURE_DIR``.
    """

    # Directory receiving one SVG snapshot per redraw; override on the class
    # or instance to save elsewhere.
    FIGURE_DIR = ("/mnt/hgfs/michaelkapteyn/digitaltwin_ws/src/digitaltwin/"
                  "outputfiles/figures")

    def __init__(self, node):
        """Build the canvas and axes and register the subscription on *node*."""
        super(RewardWidget, self).__init__()
        layout = QVBoxLayout(self)
        self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.static_canvas)

        self.ax = self.static_canvas.figure.subplots()
        self._decorate_axes()

        self.reward_total = []
        self.reward_total_var = []

        self.reward_state = []
        self.reward_state_var = []

        self.reward_control = []
        self.reward_control_var = []

        self.reward_policy = []
        self.reward_policy_var = []

        self.reward_outputerror = []
        self.reward_outputerror_var = []

        self.reward_types = []

        self.topic_name = 'reward_estimate'
        self._node = node
        self._node.create_subscription(Reward, self.topic_name,
                                       self.ReceivedCallback, 10)

    def _decorate_axes(self):
        """Apply the fixed x range, title, axis labels and grid."""
        self.ax.set_xlim(0, 50)
        self.ax.set_title('Reward Functions')
        self.ax.set_xlabel('Time')
        self.ax.set_ylabel('Reward')
        # legend() emits a warning when nothing is labelled yet, which is the
        # case when this runs from the constructor.
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()
        self.ax.grid()

    def _plot_band(self, xs, values, variances, window, fmt, color, label):
        """Plot ``values[window]`` as a line with a +/- 2*sqrt(variance) band.

        *window* is a slice applied identically to *xs*, *values* and
        *variances*. An empty window draws nothing.
        """
        xs = xs[window]
        if len(xs) == 0:
            return
        mean = np.array(values[window])
        ci = 2.0 * np.sqrt(variances[window])
        self.ax.plot(xs, mean, fmt, linewidth=2, label=label)
        self.ax.fill_between(xs, mean - ci, mean + ci, color=color, alpha=.1)

    def _save_snapshot(self, n_estimates):
        """Write the figure as SVG into FIGURE_DIR.

        A missing or unwritable directory is logged instead of raised so the
        on-screen plot still refreshes.
        """
        import logging

        path = "{}/reward_plot_{}.svg".format(self.FIGURE_DIR, n_estimates)
        try:
            self.static_canvas.figure.savefig(path,
                                              format='svg',
                                              transparent=True)
        except OSError as err:
            logging.getLogger(__name__).warning(
                "could not save reward plot to %s: %s", path, err)

    def ReceivedCallback(self, msg):
        """Record *msg* and redraw the figure from it.

        Entries of ``msg.type`` equal to 1.0 are counted as estimates. The
        first ``n_estimates`` samples are drawn solid, and the remainder
        (starting at the last estimate so the lines connect) dashed. The
        total and policy rewards are stored but not plotted, and the output
        error is plotted for the estimated part only.
        """
        self.reward_total.append(msg.total)
        self.reward_total_var.append(msg.total_var)

        self.reward_state.append(msg.state)
        self.reward_state_var.append(msg.state_var)

        self.reward_control.append(msg.control)
        self.reward_control_var.append(msg.control_var)

        self.reward_policy.append(msg.policy)
        self.reward_policy_var.append(msg.policy_var)

        self.reward_outputerror.append(msg.outputerror)
        self.reward_outputerror_var.append(msg.outputerror_var)

        self.reward_types.append(msg.type)
        n_estimates = sum(1 for t in msg.type if t == 1.)

        xx = range(len(msg.total))
        estimated = slice(0, n_estimates)
        # Clamp at 0: with no estimates, a start of -1 would keep only the
        # final sample of the prediction.
        predicted = slice(max(n_estimates - 1, 0), None)

        self.ax.clear()
        self._plot_band(xx, msg.state, msg.state_var, estimated,
                        'b-', 'b', 'State')
        self._plot_band(xx, msg.control, msg.control_var, estimated,
                        'g-', 'g', 'Control')
        self._plot_band(xx, msg.outputerror, msg.outputerror_var, estimated,
                        'r-', 'r', 'Error')
        self._plot_band(xx, msg.state, msg.state_var, predicted,
                        'b--', 'b', 'State')
        self._plot_band(xx, msg.control, msg.control_var, predicted,
                        'g--', 'g', 'Control')

        self._decorate_axes()

        self._save_snapshot(n_estimates)
        self.static_canvas.draw_idle()

    def _handle_refresh_clicked(self, checked):
        '''
        called when the refresh button is clicked
        '''
        pass

    # Qt methods
    def shutdown_plugin(self):
        pass

    def save_settings(self, plugin_settings, instance_settings):
        pass

    def restore_settings(self, plugin_settings, instance_settings):
        pass
Ejemplo n.º 5
0
class Network(QWidget):
    """Widget that renders a graph description received as a string message.

    It subscribes through *node* to ``graph_string`` (``String``). Each
    message is laid out with pygraphviz, rendered to PNG and shown on the
    axes.
    """

    def __init__(self, node):
        """Build the canvas and axes and register the subscription on *node*."""
        super(Network, self).__init__()
        layout = QVBoxLayout(self)
        self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.static_canvas)

        self.ax = self.static_canvas.figure.subplots()
        self.ax.axis('off')

        # Last graph received. It stays None until the first message arrives,
        # so shutdown_plugin can tell whether there is anything to save.
        self.graph = None

        self.topic_name = 'graph_string'
        self._node = node
        self._node.create_subscription(String, self.topic_name,
                                       self.ReceivedCallback, 10)

    def ReceivedCallback(self, msg):
        """Lay out the graph in ``msg.data``, render it and show the image."""
        G = pygraphviz.AGraph(msg.data)
        G.layout(prog='dot')  # one of: neato, dot, twopi, circo, fdp, nop
        with tempfile.NamedTemporaryFile() as tf:
            G.draw(tf.name, format='png')
            # The temporary name has no extension, so state the format
            # explicitly rather than relying on imread's fallback.
            img = matplotlib.image.imread(tf.name, format='png')
        self.graph = G

        # Drop the previous image first; otherwise every message stacks one
        # more image on the axes. clear() re-enables the axis decorations,
        # so switch them off again.
        self.ax.clear()
        self.ax.axis('off')
        self.ax.imshow(img)

        self.static_canvas.draw_idle()

    def _handle_refresh_clicked(self, checked):
        '''
        called when the refresh button is clicked
        '''
        pass

    # Qt methods
    def shutdown_plugin(self):
        """Save the last received graph as PNG; do nothing if none arrived."""
        if self.graph is None:
            return
        self.graph.draw('src/digitaltwin/outputfiles/graph.png', format='png')

    def save_settings(self, plugin_settings, instance_settings):
        pass

    def restore_settings(self, plugin_settings, instance_settings):
        pass
Ejemplo n.º 6
0
class Base_Plot(QtCore.QObject):
    def __init__(self, parent, widget, mpl_layout):
        """Create the figure and canvas, register the custom scales and hook up events.

        parent:     passed to the base-class constructor and stored as self.parent.
        widget:     stored as self.widget and handed to self.NavigationToolbar.
        mpl_layout: layout that create_canvas() adds the canvas to.

        NOTE(review): create_canvas() iterates self.ax, which is not assigned
        in this method - presumably a subclass sets it before calling this
        constructor; confirm against the subclasses.
        """
        super().__init__(parent)
        # NOTE(review): if the base class is Qt's QObject, this attribute
        # shadows its parent() method on the instance - confirm it is intended.
        self.parent = parent

        self.widget = widget
        self.mpl_layout = mpl_layout
        self.fig = mplfigure.Figure()
        # Runs on every instantiation, so the two scale classes are
        # re-registered each time an instance is created.
        mpl.scale.register_scale(AbsoluteLogScale)
        mpl.scale.register_scale(BiSymmetricLogScale)
        
        # Set plot variables
        self.x_zoom_constraint = False
        self.y_zoom_constraint = False
        
        # create_canvas() must run first: it assigns self.canvas, which the
        # toolbar call and the signal connections below use.
        self.create_canvas()        
        self.NavigationToolbar(self.canvas, self.widget, coordinates=True)
        
        # AutoScale
        # Index 0 is read by set_xlim(), index 1 by set_ylim(); a False entry
        # makes the corresponding method return without touching the limits.
        self.autoScale = [True, True]
        
        # Connect Signals
        # The connection id of the draw handler is kept in
        # self._draw_event_signal; the other two ids are discarded.
        self._draw_event_signal = self.canvas.mpl_connect('draw_event', self._draw_event)
        self.canvas.mpl_connect('button_press_event', lambda event: self.click(event))
        self.canvas.mpl_connect('key_press_event', lambda event: self.key_press(event))
        # self.canvas.mpl_connect('key_release_event', lambda event: self.key_release(event))
        
        # Run the draw handler once by hand, without an event argument.
        self._draw_event()
    
    def create_canvas(self):
        """Create the canvas for self.fig and initialise per-axes state.

        The canvas is added to self.mpl_layout, given strong keyboard focus
        and drawn once. Every axes in self.ax then receives a ``scale``
        attribute (which scale type is active for 'x' and for 'y') and a
        ``background`` attribute (a copy of the canvas region under the axes).
        """
        self.canvas = FigureCanvas(self.fig)
        self.mpl_layout.addWidget(self.canvas)
        self.canvas.setFocusPolicy(QtCore.Qt.StrongFocus)
        self.canvas.draw()

        # Set scales
        scales = {'linear': True, 'log': 0, 'abslog': 0, 'bisymlog': 0}
        for ax in self.ax:
            # Copy the template for both coordinates. Previously only 'y' was
            # copied, so every axes shared one and the same 'x' dict and an
            # in-place change on one axes would have shown up on all others.
            ax.scale = {'x': deepcopy(scales), 'y': deepcopy(scales)}
            ax.ticklabel_format(scilimits=(-4, 4), useMathText=True)

        # Get background
        # Kept as a separate pass so every axes is configured before any
        # background is captured.
        for ax in self.ax:
            ax.background = self.canvas.copy_from_bbox(ax.bbox)
    
    def _find_calling_axes(self, event):
        for axes in self.ax:    # identify calling axis
            if axes == event or (hasattr(event, 'inaxes') and event.inaxes == axes):
                return axes
    
    def set_xlim(self, axes, x):
        """Fit the x limits of `axes` to the data `x`, if x autoscaling is on.

        Linear scale: limits are the data minimum and maximum.
        'log' / 'abslog' / 'bisymlog': limits are one decade below and above
        the decades spanned by the non-zero absolute values.

        The limits are left untouched when autoscaling is off, the scale is
        none of the above, there is no usable data, the result contains
        NaN/Inf, or the result equals the current limits.
        """
        if not self.autoScale[0]: return    # obey autoscale right click option

        if np.size(x) == 0:     # nothing to fit to
            return

        scale = axes.get_xscale()
        if scale == 'linear':
            xlim = [np.min(x), np.max(x)]
        elif scale in ('log', 'abslog', 'bisymlog'):
            abs_x = np.abs(x)
            abs_x = abs_x[np.nonzero(abs_x)]    # exclude 0's
            if abs_x.size == 0:     # all-zero data has no decade to show
                return

            min_data = np.ceil(np.log10(np.min(abs_x)))
            max_data = np.floor(np.log10(np.max(abs_x)))

            xlim = [10**(min_data-1), 10**(max_data+1)]
        else:
            # Any other scale (e.g. 'symlog', 'logit') previously fell through
            # with xlim unassigned and raised UnboundLocalError below.
            return

        if np.isnan(xlim).any() or np.isinf(xlim).any():
            return

        # get_xlim() returns a tuple and a list never compares equal to a
        # tuple, so compare like with like to really skip unchanged limits.
        if tuple(xlim) != tuple(axes.get_xlim()):   # if xlim changes
            axes.set_xlim(xlim)
    
    def set_ylim(self, axes, y):
        """Fit the y limits of `axes` to the data `y`, if y autoscaling is on.

        Linear scale: finite data range padded by 10 % on both sides.
        'log' / 'abslog': one decade below and above the decades spanned by
        the finite, non-zero absolute values ([1e-7, 1e-1] if there are none).
        'bisymlog': signed decade limits derived from the finite min and max.

        The limits are left untouched when autoscaling is off, `y` has no
        finite value, the scale is none of the above, the result contains
        NaN/Inf, or the result equals the current limits.
        """
        if not self.autoScale[1]: return    # obey autoscale right click option

        y = np.asarray(y)
        finite_y = y[np.isfinite(y)]
        if finite_y.size == 0:      # nothing finite to fit to
            return

        min_data = finite_y.min()
        max_data = finite_y.max()

        if min_data == max_data:    # flat data: open a small window around it
            min_data -= 10**-1
            max_data += 10**-1

        scale = axes.get_yscale()
        if scale == 'linear':
            span = np.abs(max_data - min_data)
            ylim = [min_data - span*0.1, max_data + span*0.1]

        elif scale in ['log', 'abslog']:
            abs_y = np.abs(y)
            abs_y = abs_y[np.nonzero(abs_y)]    # exclude 0's
            abs_y = abs_y[np.isfinite(abs_y)]    # exclude nan, inf

            if abs_y.size == 0:             # if no data, assign
                ylim = [10**-7, 10**-1]
            else:
                min_data = np.ceil(np.log10(np.min(abs_y)))
                max_data = np.floor(np.log10(np.max(abs_y)))

                ylim = [10**(min_data-1), 10**(max_data+1)]

        elif scale == 'bisymlog':
            min_sign = np.sign(min_data)
            max_sign = np.sign(max_data)

            if min_sign > 0:
                min_data = np.ceil(np.log10(np.abs(min_data)))
            elif min_data == 0 or max_data == 0:
                pass
            else:
                min_data = np.floor(np.log10(np.abs(min_data)))

            if max_sign > 0:
                max_data = np.floor(np.log10(np.abs(max_data)))
            elif min_data == 0 or max_data == 0:
                pass
            else:
                max_data = np.ceil(np.log10(np.abs(max_data)))

            # TODO: ylim could be incorrect for neg/neg, checked for pos/pos, pos/neg
            # NOTE(review): when one bound is exactly 0 the other one skips the
            # log10 step above and enters the exponent raw - confirm intended.
            ylim = [min_sign*10**(min_data-min_sign), max_sign*10**(max_data+max_sign)]

        else:
            # Unknown scale: ylim was previously left unassigned here and the
            # comparison below raised UnboundLocalError.
            return

        # Same guard as set_xlim: never hand NaN/Inf limits to the axes.
        if np.isnan(ylim).any() or np.isinf(ylim).any():
            return

        # get_ylim() returns a tuple and a list never compares equal to a
        # tuple, so compare like with like to really skip unchanged limits.
        if tuple(ylim) != tuple(axes.get_ylim()):   # if ylim changes, update
            axes.set_ylim(ylim)
    
    def update_xylim(self, axes, xlim=[], ylim=[]):
        """Update both axis limits of `axes`, then force a redraw.

        xlim, ylim: explicit limits handed straight to axes.set_xlim /
            axes.set_ylim. An empty sequence (the default) means "fit to the
            data" through self.set_xlim / self.set_ylim. The list defaults
            are only read, never mutated.

        Nothing happens when self._get_data(axes) reports fewer than two x
        or y values.
        """
        data = self._get_data(axes)

        # on creation, there is no data, don't update
        if np.shape(data['x'])[0] < 2 or np.shape(data['y'])[0] < 2:
            return

        for (axis, lim) in zip(['x', 'y'], [xlim, ylim]):
            # Set Limits
            # Resolve the per-axis methods by name with getattr rather than
            # building source strings for eval.
            if len(lim) == 0:
                getattr(self, 'set_{}lim'.format(axis))(axes, data[axis])
            else:
                getattr(axes, 'set_{}lim'.format(axis))(lim)

            # If bisymlog, also update scaling, C
            if getattr(axes, 'get_{}scale'.format(axis))() == 'bisymlog':
                self._set_scale(axis, 'bisymlog', axes)

            # TODO: Do this some day, probably need to create
            #       annotation during canvas creation
            # Move exponent:
            #   exp_loc = {'x': (.89, .01), 'y': (.01, .96)}
            #   hide the axis' offset text, compute
            #   oom = floor(log10(max(axis ticks))) and annotate the axes
            #   with r'$\times10^{oom}$' at exp_loc[axis] ('axes fraction').

        self._draw_event()  # force a draw
    
    def _get_data(self, axes):      # NOT Generic
        """Collect the x/y data used to fit the limits of `axes`.

        Reads the artists stored in the dict ``axes.item``:
          * 'exp_data': its offsets; x is extended by column 0 of
            'sim_data'.raw_data when that exists and is non-empty.
          * 'weight':   its x and y data.
          * 'density' (preferred) or 'qq_data': for each coordinate the
            overall [min, max] of the finite values of all items.

        Returns {'x': ..., 'y': ...}; an entry stays an empty list when no
        usable data was found for it.
        """
        # get experimental data for axes
        data = {'x': [], 'y': []}
        if 'exp_data' in axes.item:
            data_plot = axes.item['exp_data'].get_offsets().T
            if np.shape(data_plot)[1] > 1:
                data['x'] = data_plot[0,:]
                data['y'] = data_plot[1,:]

            # append sim_x if it exists
            if 'sim_data' in axes.item and hasattr(axes.item['sim_data'], 'raw_data'):
                if axes.item['sim_data'].raw_data.size > 0:
                    data['x'] = np.append(data['x'], axes.item['sim_data'].raw_data[:,0])

        elif 'weight' in axes.item:
            data['x'] = axes.item['weight'].get_xdata()
            data['y'] = axes.item['weight'].get_ydata()

        elif any(key in axes.item for key in ['density', 'qq_data', 'sim_data']):
            # 'density' wins over 'qq_data' when both are present.
            name = next((key for key in ('density', 'qq_data') if key in axes.item), None)
            if name is None:
                # Only 'sim_data' is present: there is nothing to measure
                # (this case used to raise IndexError).
                return data

            for n, coord in enumerate(['x', 'y']):
                bounds = []     # one (min, max) pair per item with finite data
                for item in axes.item[name]:
                    if name == 'qq_data':
                        coordData = item.get_offsets()
                        if coordData.size == 0:
                            continue
                        coordData = coordData[:,n]
                    else:   # 'density'
                        coordData = getattr(item, 'get_{}data'.format(coord))()

                    coordData = np.array(coordData)[np.isfinite(coordData)]
                    if coordData.size == 0:
                        continue

                    bounds.append((coordData.min(), coordData.max()))

                # With no finite data at all keep the empty default instead
                # of calling np.min on an empty array, which raises.
                if bounds:
                    xyrange = np.array(bounds)
                    data[coord] = [np.min(xyrange[:,0]), np.max(xyrange[:,1])]

        return data
    
    def _set_scale(self, coord, type, event, update_xylim=False):
        def RoundToSigFigs(x, p):
            x = np.asarray(x)
            x_positive = np.where(np.isfinite(x) & (x != 0), np.abs(x), 10**(p-1))
            mags = 10 ** (p - 1 - np.floor(np.log10(x_positive)))
            return np.round(x * mags) / mags
    
        # find correct axes
        axes = self._find_calling_axes(event)
        # for axes in self.ax:
            # if axes == event or (hasattr(event, 'inaxes') and event.inaxes == axes):
                # break
        
        # Set scale menu boolean
        if coord == 'x':
            shared_axes = axes.get_shared_x_axes().get_siblings(axes)               
        else:
            shared_axes = axes.get_shared_y_axes().get_siblings(axes)
        
        for shared in shared_axes:
            shared.scale[coord] = dict.fromkeys(shared.scale[coord], False) # sets all types: False
            shared.scale[coord][type] = True                                # set selected type: True

        # Apply selected scale
        if type == 'linear':
            str = 'axes.set_{:s}scale("{:s}")'.format(coord, 'linear')
        elif type == 'log':
            str = 'axes.set_{0:s}scale("{1:s}", nonpos{0:s}="mask")'.format(coord, 'log')
        elif type == 'abslog':
            str = 'axes.set_{:s}scale("{:s}")'.format(coord, 'abslog')
        elif type == 'bisymlog':
            # default string to evaluate 
            str = 'axes.set_{0:s}scale("{1:s}")'.format(coord, 'bisymlog')
            
            data = self._get_data(axes)[coord]
            if len(data) != 0:
                finite_data = np.array(data)[np.isfinite(data)] # ignore nan and inf
                min_data = finite_data.min()  
                max_data = finite_data.max()
                
                if min_data != max_data:
                    # if zero is within total range, find largest pos or neg range
                    if np.sign(max_data) != np.sign(min_data):  
                        pos_data = finite_data[finite_data>=0]
                        pos_range = pos_data.max() - pos_data.min()
                        neg_data = finite_data[finite_data<=0]
                        neg_range = neg_data.max() - neg_data.min()
                        C = np.max([pos_range, neg_range])
                    else:
                        C = np.abs(max_data-min_data)
                    C /= 1E3                  # scaling factor TODO: debating between 100, 500 and 1000
                    C = RoundToSigFigs(C, 1)  # round to 1 significant figure
                    str = 'axes.set_{0:s}scale("{1:s}", C={2:e})'.format(coord, 'bisymlog', C)
        
        eval(str)
        if type == 'linear' and coord == 'x':
            formatter = MathTextSciSIFormatter(useOffset=False, useMathText=True)
            axes.xaxis.set_major_formatter(formatter)
            
        elif type == 'linear' and coord == 'y':
            formatter = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
            formatter.set_powerlimits([-3, 4])
            axes.yaxis.set_major_formatter(formatter)
            
        if update_xylim:
            self.update_xylim(axes)
 
    def _animate_items(self, bool=True):
        for axis in self.ax:
            if axis.get_legend() is not None:
                axis.get_legend().set_animated(bool)
            
            for item in axis.item.values():
                if isinstance(item, list):
                    for subItem in item:
                        if isinstance(subItem, dict):
                            subItem['line'].set_animated(bool)
                        else:
                            subItem.set_animated(bool)
                else:
                    item.set_animated(bool)
    
    def _draw_items_artist(self):   
        for axis in self.ax:     # restore background first (needed for twinned plots)
            self.canvas.restore_region(axis.background)   
        
        for axis in self.ax:
            for item in axis.item.values():
                if isinstance(item, list):
                    for subItem in item:
                        if isinstance(subItem, dict):
                            axis.draw_artist(subItem['line'])
                        else:
                            axis.draw_artist(subItem) 
                else:
                    axis.draw_artist(item)
           
            if axis.get_legend() is not None:
                axis.draw_artist(axis.get_legend())
            
        self.canvas.update()
        # self.canvas.flush_events()    # unnecessary?
    
    def set_background(self):
        """Request a redraw and cache the figure region as each background.

        Every axes in ``self.ax`` receives its own copy of the region covered
        by ``self.fig.bbox`` (the whole figure, not just the axes' own bbox)
        in its ``background`` attribute.
        """
        canvas = self.canvas
        canvas.draw_idle()  # for when shock changes
        for axes in self.ax:
            axes.background = canvas.copy_from_bbox(self.fig.bbox)
    
    def _draw_event(self, event=None):   # After redraw (new/resizing window), obtain new background
        self._animate_items(True)
        self.set_background()
        self._draw_items_artist()     
        # self.canvas.draw_idle()   # unnecessary?
    
    def clear_plot(self, ignore=(), draw=True):
        """Blank out the legends, scatter points, lines and texts of all axes.

        Only top-level entries of ``axis.item`` are inspected; an entry is
        classified by the setter methods it exposes.

        Parameters
        ----------
        ignore : collection of str, optional
            Kinds to leave untouched: any of ``'scatter'`` (items with
            ``set_offsets``), ``'line'`` (items with ``set_xdata`` and
            ``set_ydata``) and ``'text'`` (items with ``set_text``).
        draw : bool, optional
            When true, ``self._draw_event()`` is called afterwards.
        """
        for axis in self.ax:
            legend = axis.get_legend()
            if legend is not None:
                legend.remove()

            for item in axis.item.values():
                if hasattr(item, 'set_offsets'):    # clears all data points
                    if 'scatter' not in ignore:
                        item.set_offsets(([np.nan, np.nan]))
                elif hasattr(item, 'set_xdata') and hasattr(item, 'set_ydata'):
                    if 'line' not in ignore:
                        item.set_xdata([np.nan, np.nan])  # clears all lines
                        item.set_ydata([np.nan, np.nan])
                elif hasattr(item, 'set_text'):  # clears all text boxes
                    if 'text' not in ignore:
                        item.set_text('')
        if draw:
            self._draw_event()

    def click(self, event):
        """Handle a mouse click: only right clicks (button 3) do anything.

        With no toolbar tool active the context menu is shown; otherwise a
        double right click resets the view through ``self.toolbar.home()``.

        Parameters
        ----------
        event : object
            Mouse event providing ``button`` and ``dblclick``.
        """
        if event.button != 3:   # not a right click
            return

        if self.toolbar._active is None:
            self._popup_menu(event)
            return

        if event.dblclick:      # double right click: go to default view
            self.toolbar.home()

    def key_press(self, event):
        """Handle a key press on the canvas.

        * ``escape`` toggles off the active zoom or pan tool.
        * ``x`` / ``y`` flip ``self.x_zoom_constraint`` /
          ``self.y_zoom_constraint``.
        * ``s``, ``l``, ``L`` and ``k`` are swallowed so they never reach the
          fallback handler.
        * Every other key is forwarded to ``key_press_handler``.

        Parameters
        ----------
        event : object
            Key event providing ``key``.
        """
        key = event.key
        if key == 'escape':
            # Compare by value: identity checks against string literals only
            # work by accident of interning
            active = self.toolbar._active
            if active == 'ZOOM':    # if zoom is on, turn off
                self.toolbar.zoom()
            elif active == 'PAN':
                self.toolbar.pan()
        elif key == 'x':  # Does nothing, would like to make sticky constraint zoom/pan
            self.x_zoom_constraint = not self.x_zoom_constraint
        elif key == 'y':  # Does nothing, would like to make sticky constraint zoom/pan
            self.y_zoom_constraint = not self.y_zoom_constraint
        elif key in ('s', 'l', 'L', 'k'):
            pass
        else:
            key_press_handler(event, self.canvas, self.toolbar)
    
    # def key_release(self, event):
        # print(event.key, 'released')
    
    def NavigationToolbar(self, *args, **kwargs):
        """Create the custom navigation toolbar and add it to the plot layout.

        The toolbar is built for ``self.canvas`` with ``self.widget`` as its
        second argument and coordinates enabled, stored on ``self.toolbar``
        and appended to ``self.mpl_layout``. Positional and keyword arguments
        are accepted but not used.
        """
        toolbar = CustomNavigationToolbar(self.canvas, self.widget, coordinates=True)
        self.toolbar = toolbar
        self.mpl_layout.addWidget(toolbar)

    def _popup_menu(self, event):
        """Show the right-click context menu for the axes that was clicked.

        The menu offers an ``x-scale`` and a ``y-scale`` submenu listing the
        keys of ``axes.scale[coord]`` (the active one is checked and
        disabled), followed by checkable AutoScale X / Y / All entries.
        Nothing is shown when no axes can be found for ``event``.

        Parameters
        ----------
        event : object
            Forwarded to ``self._find_calling_axes``.
        """
        axes = self._find_calling_axes(event)   # find axes calling right click
        if axes is None:
            return

        pos = self.parent.mapFromGlobal(QtGui.QCursor().pos())

        popup_menu = QMenu(self.parent)
        scale_menus = {
            'x': popup_menu.addMenu('x-scale'),
            'y': popup_menu.addMenu('y-scale'),
        }

        for coord in ('x', 'y'):
            menu = scale_menus[coord]
            for scale_type, is_active in axes.scale[coord].items():
                action = QAction(scale_type, menu, checkable=True)
                # The scale already in use cannot be selected again
                action.setEnabled(not is_active)
                menu.addAction(action)
                action.setChecked(is_active)
                # Bind the loop variables as defaults to avoid late binding
                fcn = lambda event, coord=coord, scale_type=scale_type: \
                    self._set_scale(coord, scale_type, axes, True)
                action.triggered.connect(fcn)

        # Create menu for AutoScale options X Y All
        popup_menu.addSeparator()
        autoscale_options = ('AutoScale X', 'AutoScale Y', 'AutoScale All')
        for n, text in enumerate(autoscale_options):
            action = QAction(text, popup_menu, checkable=True)
            if n < len(self.autoScale):
                action.setChecked(self.autoScale[n])
            else:
                # The trailing "All" entry reflects the combined state
                action.setChecked(all(self.autoScale))
            popup_menu.addAction(action)
            action.toggled.connect(lambda event, n=n: self._setAutoScale(n, event, axes))

        popup_menu.exec_(self.parent.mapToGlobal(pos))
    
    def _setAutoScale(self, choice, event, axes):
        if choice == len(self.autoScale):
            for n in range(len(self.autoScale)):
                self.autoScale[n] = event
        else:
            self.autoScale[choice] = event
        
        if event:   # if something toggled true, update limits
            self.update_xylim(axes)
Ejemplo n.º 7
0
class MainWindow(QMainWindow):
    """Main window of the "Heuristic Generator" application.

    The left column holds the Matplotlib canvas, the "Apply and Update"
    button and the plot-attribute selection; the right column holds the rule
    statistics table and the rule-part editors.
    """

    # Upper bound on the number of rule parts that can be added
    MAX_RULE_PARTS = 7

    # TODO: Specify in config file
    scatter_cols = [
        'sepal length (cm)', 'sepal width (cm)', 'petal length (cm)',
        'petal width (cm)'
    ]

    def plot_scatter(self, cols=None, ax=None):
        """Draw the rule scatter plot through ``self.plotter``.

        Parameters
        ----------
        cols : list of str, optional
            Columns to plot; ``scatter_cols`` is used when omitted.
        ax : object, optional
            Axes forwarded to ``self.plotter.scatter_rule``.
        """
        plt_cols = self.scatter_cols if cols is None else cols
        self.plotter.scatter_rule(plt_cols, self.rules, ax=ax)

    def __init__(self):
        """Load the data, build the UI and draw the initial plot."""
        QMainWindow.__init__(self)

        self.setWindowTitle("Heuristic Generator")

        self.df = load_data()
        self.ranges = calc_ranges(self.df, self.scatter_cols)

        # Menu
        self.menu = self.menuBar()
        self.file_menu = self.menu.addMenu("File")

        self.rule_parts = []

        # Exit QAction
        exit_action = QAction("Exit", self)
        exit_action.setShortcut(QKeySequence.Quit)
        exit_action.triggered.connect(self.close)

        self.file_menu.addAction(exit_action)

        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        # NOTE(review): this attribute shadows the inherited layout() method
        self.layout = QtWidgets.QHBoxLayout(self._main)

        self.setup_ui()

        self.plotter = ScatterPlotter(self.df)
        self.update_plot()

        # Window dimensions; resize() takes integers, so truncate the
        # scaled sizes explicitly
        geometry = qApp.desktop().availableGeometry(self)
        self.resize(int(geometry.width() * 0.7), int(geometry.height() * 0.8))

    def setup_ui(self):
        """Build both columns and add them to the main layout (2:1 stretch)."""
        left_col = self.setup_left_col()
        right_col = self.setup_right_col()

        self.layout.addWidget(left_col, stretch=2)
        self.layout.addLayout(right_col, stretch=1)

    def setup_left_col(self):
        """Create the visualization group box.

        Returns
        -------
        QtWidgets.QGroupBox
            Box containing the canvas, the update button and the
            plot-attribute selection widget.
        """
        col_box = QtWidgets.QGroupBox('Rule and data visualization')
        left_col_layout = QtWidgets.QVBoxLayout()

        # Create Figure canvas
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.fig.clear()
        self.canvas.updateGeometry()

        # Button for applying the rules and updating the plot and stats
        add_rule_btn = QtWidgets.QPushButton('Apply and Update')
        add_rule_btn.clicked.connect(self.update_plot)
        update_btn_layout = QtWidgets.QHBoxLayout()
        update_btn_layout.addStretch(stretch=1)
        update_btn_layout.addWidget(add_rule_btn)
        update_btn_layout.addStretch(stretch=1)

        self.plot_attr_selection = PlotAttrSelectionWidget(
            parent=self, attributes=self.scatter_cols)

        left_col_layout.addWidget(self.canvas, stretch=2)
        left_col_layout.addLayout(update_btn_layout)
        left_col_layout.addWidget(self.plot_attr_selection, stretch=1)
        col_box.setLayout(left_col_layout)

        return col_box

    def update_plot(self):
        """Re-read the rules, redraw the scatter plot and refresh the stats.

        The plot is only drawn when more than one plot attribute is selected;
        the statistics table is refreshed in either case.
        """
        plt_cols = self.plot_attr_selection.get_plot_attr()
        self.rules = self.get_rules()
        print(self.rules)
        self.fig.clear()
        if len(plt_cols) > 1:
            ax = self.fig.add_subplot()
            # NOTE(review): self.canvas.figure is self.fig (see
            # setup_left_col), so this clears the figure again after ``ax``
            # was added to it -- confirm this is intended.
            self.canvas.figure.clear()
            self.plot_scatter(cols=plt_cols, ax=ax)

        # Update stat table
        stats_df = self.plotter.get_rule_stats()
        print(stats_df)
        self.stat_table.update_table(stats_df)

        self.canvas.draw_idle()

    def get_rules(self):
        """Collect the current rule parts as a list of plain dicts.

        Returns
        -------
        list of dict
            One dict per rule part with the keys ``rule_id``, ``rule_attr``,
            ``attr_min`` and ``attr_max``.
        """
        rule_list = []
        for rp in self.rule_parts:
            rule = rp.get_rule()
            rule_list.append({
                'rule_id': rule['rule_id'],
                'rule_attr': rule['feature'],
                'attr_min': rule['range'][0],
                'attr_max': rule['range'][1],
            })
        return rule_list

    def add_rule_part(self):
        """Append a new rule-part widget unless the maximum is reached."""
        n_rule_parts = len(self.rule_parts)
        if n_rule_parts >= self.MAX_RULE_PARTS:
            return

        rule_part = RulePartWidget(feature_ranges=self.ranges,
                                   rule_number=n_rule_parts + 1)
        self.rule_parts.append(rule_part)
        self.rule_part_layout.insertWidget(n_rule_parts, rule_part)

    def initial_stat_table(self):
        """Return an empty statistics DataFrame with the table's headers."""
        # stat df: ['rule_id','confidence','support', 'lift', 'recall', 'tp', 'tn', 'fp', 'fn']
        header = [
            'Rulepart ID', 'Confidence', 'Support', 'Lift', 'Recall', 'TP',
            'TN', 'FP', 'FN'
        ]
        data = [
            #   ['Rulepart 1', 0.766, 0.427, 2.297, 'X'],
            #   ['Total Ruleset', 0.766, 0.427, 2.297, 'X']
        ]
        return pd.DataFrame(columns=header, data=data)

    def setup_right_col(self):
        """Create the statistics and rule-definition column.

        Returns
        -------
        QtWidgets.QVBoxLayout
            Layout holding the statistics box above the rule-definition box.
        """
        right_col_layout = QtWidgets.QVBoxLayout()
        statistics_box = QtWidgets.QGroupBox('Rule statistics')
        statistics_layout = QtWidgets.QVBoxLayout()

        # rule statistics table
        table_df = self.initial_stat_table()
        self.stat_table = RuleTableWidget(table_df)

        statistics_layout.addWidget(self.stat_table)
        statistics_box.setLayout(statistics_layout)

        rule_part_box = QtWidgets.QGroupBox('Rule definition')
        self.rule_part_layout = QtWidgets.QVBoxLayout()
        add_rule_btn = QtWidgets.QPushButton('Add Rule part')
        add_rule_btn.clicked.connect(self.add_rule_part)
        btn_layout = QtWidgets.QHBoxLayout()
        btn_layout.addStretch(stretch=1)
        btn_layout.addWidget(add_rule_btn)
        btn_layout.addStretch(stretch=1)
        self.rule_part_layout.addLayout(btn_layout)
        self.rule_part_layout.addStretch(stretch=1)
        rule_part_box.setLayout(self.rule_part_layout)

        right_col_layout.addWidget(statistics_box, stretch=1)
        right_col_layout.addWidget(rule_part_box, stretch=1)

        return right_col_layout
Ejemplo n.º 8
0
class SensorWidget(QWidget):
    """Widget plotting received sensor readings against reference values.

    Readings arrive on the ``sensor_data`` topic and references on the
    ``sensor_ref`` topic; the plot is redrawn for every reading once a
    reference has been received.
    """

    # Where each redraw is exported as SVG; ``{}`` receives the number of
    # estimates. This location is machine-specific.
    FIGURE_PATH_TEMPLATE = (
        "/mnt/hgfs/michaelkapteyn/digitaltwin_ws/src/digitaltwin/outputfiles/figures/sensor_plot_{}.svg"
    )

    def __init__(self, node):
        """Build the canvas and subscribe to the sensor topics on ``node``."""
        super(SensorWidget, self).__init__()
        layout = QVBoxLayout(self)
        self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.static_canvas)

        self.ax = self.static_canvas.figure.subplots()
        self._decorate_axes()

        self.sensor_data = []
        self.sensor_ref = None
        self._node = node
        self._node.create_subscription(Sensor, 'sensor_data',
                                       self.dataCallback, 10)
        self._node.create_subscription(SensorList, 'sensor_ref',
                                       self.refCallback, 10)

    def _decorate_axes(self):
        """Apply the fixed limits, title, labels, legend and grid."""
        self.ax.set_xlim(0, 50)
        self.ax.set_ylim(500, 1500)

        self.ax.set_title('Sensor Data')
        self.ax.set_xlabel('Time')
        self.ax.set_ylabel('Microstrain')
        self.ax.legend()
        self.ax.grid()

    def dataCallback(self, msg):
        """Store ``msg.data`` and redraw if a reference is available."""
        self.sensor_data.append(msg.data)
        if self.sensor_ref is not None:
            self.plot()

    def refCallback(self, msg):
        """Store ``msg.datas`` as the current reference."""
        self.sensor_ref = msg.datas

    def plot(self):
        """Redraw the readings and reference curves, then export an SVG.

        For each plotted index the readings are drawn as a scatter, the
        reference entries whose ``type`` equals 1 as a solid line and the
        entries from the last such one onwards as a dashed line, each with a
        band of two standard deviations. A failed export is reported as a
        warning and does not interrupt the redraw.
        """
        import warnings

        self.ax.clear()
        idxstoplot = [1, 6, 16]
        colors = ['r', 'g', 'b', 'm', 'c']

        self.types = [s.type for s in self.sensor_ref]
        n_estimates = sum(1 if t == 1. else 0 for t in self.types)
        xx = range(0, len(self.sensor_data))
        xxref = range(0, len(self.sensor_ref))

        # (reference slice, line style, band color) for the estimated and the
        # predicted part; the predicted part starts at the last estimate
        segments = (
            (slice(None, n_estimates), '-', 'b'),
            (slice(n_estimates - 1, None), '--', 'r'),
        )

        for i, idx in enumerate(idxstoplot):
            self.ax.scatter(xx, [s[idx] for s in self.sensor_data],
                            s=20,
                            c=colors[i],
                            label='epsilonhat {}'.format(idx))

            for part, line_style, band_color in segments:
                refs = self.sensor_ref[part]
                mean = np.array([s.data[idx] for s in refs])
                ci = 2 * np.sqrt([s.vars[idx] for s in refs])  # 2 stddevs

                self.ax.plot(xxref[part],
                             mean,
                             '{}{}'.format(colors[i], line_style),
                             linewidth=2,
                             label='epsilon {}'.format(idx))
                self.ax.fill_between(xxref[part],
                                     mean - ci,
                                     mean + ci,
                                     color=band_color,
                                     alpha=.1)

        self._decorate_axes()

        figure_path = self.FIGURE_PATH_TEMPLATE.format(n_estimates)
        try:
            self.static_canvas.figure.savefig(figure_path,
                                              format='svg',
                                              transparent=True)
        except OSError as err:
            # The export directory only exists on the original machine; a
            # failed export must not stop the live plot from updating
            warnings.warn('Could not save sensor plot to {}: {}'.format(
                figure_path, err))
        self.static_canvas.draw_idle()

    def _handle_refresh_clicked(self, checked):
        '''
        called when the refresh button is clicked
        '''
        pass

    # Qt methods
    def shutdown_plugin(self):
        """Plugin shutdown hook; nothing to do."""
        pass

    def save_settings(self, plugin_settings, instance_settings):
        """Settings save hook; nothing is stored."""
        pass

    def restore_settings(self, plugin_settings, instance_settings):
        """Settings restore hook; nothing is restored."""
        pass
Ejemplo n.º 9
0
class BrightnessContrastEditor(QObject):
    """Editor for a display range expressed as min/max or brightness/contrast.

    ``data_range`` is the range the sliders span; ``ui_min``/``ui_max`` are
    the currently selected limits inside it. ``edited`` is emitted with the
    selected limits whenever they change, ``reset`` when the reset button is
    pressed.
    """

    # Emitted as (ui_min, ui_max)
    edited = Signal(float, float)

    # Emitted after the reset button restored the data range
    reset = Signal()

    def __init__(self, parent=None):
        """Load the UI, create the histogram plot and connect the widgets."""
        super().__init__(parent)

        self._data_range = (0, 1)
        self._ui_min, self._ui_max = self._data_range
        self._data = None
        self.histogram = None
        self.histogram_artist = None
        self.line_artist = None

        self.default_auto_threshold = 5000
        self.current_auto_threshold = self.default_auto_threshold

        loader = UiLoader()
        self.ui = loader.load_file('brightness_contrast_editor.ui', parent)

        self.setup_plot()

        self.ui.minimum.setMaximum(NUM_INCREMENTS)
        self.ui.maximum.setMaximum(NUM_INCREMENTS)
        self.ui.brightness.setMaximum(NUM_INCREMENTS)
        self.ui.contrast.setMaximum(NUM_INCREMENTS)

        self.setup_connections()

    def setup_connections(self):
        """Connect the value widgets and buttons to their handlers."""
        self.ui.minimum.valueChanged.connect(self.minimum_edited)
        self.ui.maximum.valueChanged.connect(self.maximum_edited)
        self.ui.brightness.valueChanged.connect(self.brightness_edited)
        self.ui.contrast.valueChanged.connect(self.contrast_edited)

        self.ui.set_data_range.pressed.connect(self.select_data_range)
        self.ui.reset.pressed.connect(self.reset_pressed)
        self.ui.auto_button.pressed.connect(self.auto_pressed)

    @property
    def data_range(self):
        """The (min, max) range spanned by the widgets."""
        return self._data_range

    @data_range.setter
    def data_range(self, v):
        self._data_range = v
        self.clip_ui_range()
        self.ensure_min_max_space('max')
        self.update_gui()

    @property
    def data(self):
        """The data being edited: one array, a sequence or dict of arrays, or None."""
        return self._data

    @data.setter
    def data(self, v):
        self._data = v
        self.reset_data_range()

    @property
    def data_list(self):
        """The data as a list of arrays (empty when there is no data)."""
        if self.data is None:
            return []
        elif isinstance(self.data, (tuple, list)):
            return list(self.data)
        elif isinstance(self.data, dict):
            return list(self.data.values())
        else:
            return [self.data]

    @property
    def data_bounds(self):
        """Overall (min, max) of all data arrays, or (0, 1) without data."""
        if self.data is None:
            return (0, 1)

        data = self.data_list
        mins = [x.min() for x in data]
        maxes = [x.max() for x in data]
        return (min(mins), max(maxes))

    def reset_data_range(self):
        """Set the data range to the bounds of the data."""
        self.data_range = self.data_bounds

    def update_gui(self):
        """Refresh brightness, contrast, histogram, labels and range line."""
        self.update_brightness()
        self.update_contrast()
        self.update_histogram()
        self.update_range_labels()
        self.update_line()

    @property
    def data_min(self):
        """Lower end of the data range."""
        return self.data_range[0]

    @property
    def data_max(self):
        """Upper end of the data range."""
        return self.data_range[1]

    @property
    def data_mean(self):
        """Midpoint of the data range."""
        return np.mean(self.data_range)

    @property
    def data_width(self):
        """Width of the data range."""
        return self.data_range[1] - self.data_range[0]

    @property
    def ui_min(self):
        """Currently selected lower limit."""
        return self._ui_min

    @ui_min.setter
    def ui_min(self, v):
        self._ui_min = v
        # NOTE(review): slider_v is a float -- confirm the widget accepts
        # non-integer values on the Qt binding in use
        slider_v = np.interp(v, self.data_range, (0, NUM_INCREMENTS))
        self.ui.minimum.setValue(slider_v)
        self.update_range_labels()
        self.update_line()
        self.modified()

    @property
    def ui_max(self):
        """Currently selected upper limit."""
        return self._ui_max

    @ui_max.setter
    def ui_max(self, v):
        self._ui_max = v
        slider_v = np.interp(v, self.data_range, (0, NUM_INCREMENTS))
        self.ui.maximum.setValue(slider_v)
        self.update_range_labels()
        self.update_line()
        self.modified()

    def clip_ui_range(self):
        """Clip the ui min and max to be in the data range."""
        if self.ui_min < self.data_min:
            self.ui_min = self.data_min

        if self.ui_max > self.data_max:
            self.ui_max = self.data_max

    @property
    def ui_mean(self):
        """Midpoint of the selected limits; setting it shifts both limits."""
        return np.mean((self.ui_min, self.ui_max))

    @ui_mean.setter
    def ui_mean(self, v):
        offset = v - self.ui_mean
        self.ui_range = (self.ui_min + offset, self.ui_max + offset)

    @property
    def ui_width(self):
        """Distance between the selected limits; setting it grows or shrinks
        the selection symmetrically about its midpoint."""
        return self.ui_max - self.ui_min

    @ui_width.setter
    def ui_width(self, v):
        offset = (v - self.ui_width) / 2
        self.ui_range = (self.ui_min - offset, self.ui_max + offset)

    @property
    def ui_range(self):
        """The selected limits as a (min, max) tuple."""
        return (self.ui_min, self.ui_max)

    @ui_range.setter
    def ui_range(self, v):
        # Signals are blocked while both limits change so that ``edited`` is
        # emitted once, below, for the pair
        with block_signals(self, self.ui.minimum, self.ui.maximum):
            self.ui_min = v[0]
            self.ui_max = v[1]

        self.modified()

    @property
    def ui_brightness(self):
        """Brightness widget value rescaled to 0-100."""
        return self.ui.brightness.value() / NUM_INCREMENTS * 100

    @ui_brightness.setter
    def ui_brightness(self, v):
        self.ui.brightness.setValue(v / 100 * NUM_INCREMENTS)

    @property
    def ui_contrast(self):
        """Contrast widget value rescaled to 0-100."""
        return self.ui.contrast.value() / NUM_INCREMENTS * 100

    @ui_contrast.setter
    def ui_contrast(self, v):
        self.ui.contrast.setValue(v / 100 * NUM_INCREMENTS)

    @property
    def contrast(self):
        """Contrast (0-100) derived from the selected width.

        The relative width difference is mapped through ``arctan`` and
        interpolated from [-pi/4, pi/4] onto [0, 100], then inverted so a
        narrower selection gives a higher contrast.
        """
        angle = np.arctan((self.ui_width - self.data_width) / self.data_width)
        return 100 - np.interp(angle, (-np.pi / 4, np.pi / 4), (0, 100))

    @contrast.setter
    def contrast(self, v):
        angle = np.interp(100 - v, (0, 100), (-np.pi / 4, np.pi / 4))
        self.ui_width = np.tan(angle) * self.data_width + self.data_width

    @property
    def brightness(self):
        """Brightness (0-100): inverted position of the selected midpoint
        within the data range."""
        return 100 - np.interp(self.ui_mean, self.data_range, (0, 100))

    @brightness.setter
    def brightness(self, v):
        self.ui_mean = np.interp(100 - v, (0, 100), self.data_range)

    def ensure_min_max_space(self, one_to_change):
        """Keep the maximum at least one increment ahead of the minimum.

        Parameters
        ----------
        one_to_change : str
            ``'max'`` moves the maximum widget; any other value moves the
            minimum widget.
        """
        if self.ui.maximum.value() > self.ui.minimum.value():
            return

        if one_to_change == 'max':
            w = self.ui.maximum
            v = self.ui.minimum.value() + 1
            a = '_ui_max'
        else:
            w = self.ui.minimum
            v = self.ui.maximum.value() - 1
            a = '_ui_min'

        with block_signals(w):
            w.setValue(v)

        # Keep the stored limit in sync with the moved widget
        interpolated = np.interp(v, (0, NUM_INCREMENTS), self.data_range)
        setattr(self, a, interpolated)

    def minimum_edited(self):
        """React to the minimum widget changing."""
        v = self.ui.minimum.value()
        self._ui_min = np.interp(v, (0, NUM_INCREMENTS), self.data_range)
        self.clip_ui_range()
        self.ensure_min_max_space('max')

        self.update_brightness()
        self.update_contrast()
        self.update_range_labels()
        self.update_line()
        self.modified()

    def maximum_edited(self):
        """React to the maximum widget changing."""
        v = self.ui.maximum.value()
        self._ui_max = np.interp(v, (0, NUM_INCREMENTS), self.data_range)
        self.clip_ui_range()
        self.ensure_min_max_space('min')

        self.update_brightness()
        self.update_contrast()
        self.update_range_labels()
        self.update_line()
        self.modified()

    def update_brightness(self):
        """Show the current brightness in its widget without emitting signals."""
        with block_signals(self, self.ui.brightness):
            self.ui_brightness = self.brightness

    def update_contrast(self):
        """Show the current contrast in its widget without emitting signals."""
        with block_signals(self, self.ui.contrast):
            self.ui_contrast = self.contrast

    def brightness_edited(self, v):
        """Apply the brightness widget value, then refresh the contrast."""
        self.brightness = self.ui_brightness
        self.update_contrast()

    def contrast_edited(self, v):
        """Apply the contrast widget value, then refresh the brightness."""
        self.contrast = self.ui_contrast
        self.update_brightness()

    def modified(self):
        """Emit ``edited`` with the selected limits."""
        self.edited.emit(self.ui_min, self.ui_max)

    def setup_plot(self):
        """Create the figure, canvas and axis and add the canvas to the UI."""
        self.figure = Figure()
        self.canvas = FigureCanvas(self.figure)
        self.axis = self.figure.add_subplot(111)

        # Turn off ticks
        self.axis.axis('off')

        self.figure.tight_layout()

        self.ui.plot_layout.addWidget(self.canvas)

    def clear_plot(self):
        """Clear the axis and forget the histogram and line artists."""
        self.axis.clear()
        self.histogram_artist = None
        self.line_artist = None

    def update_histogram(self):
        """Recompute the summed histogram of all data arrays and redraw it."""
        # Clear the plot so everything will be re-drawn from scratch
        self.clear_plot()

        data = self.data_list
        if not data:
            # Drop the histogram of the previous data; otherwise
            # auto_pressed() would keep working on it and fail when it asks
            # for the pixel count of data that no longer exists
            self.histogram = None
            return

        histograms = []
        for datum in data:
            kwargs = {
                'a': datum,
                'bins': HISTOGRAM_NUM_BINS,
                'range': self.data_range,
            }
            hist, _ = np.histogram(**kwargs)
            histograms.append(hist)

        self.histogram = sum(histograms)
        kwargs = {
            'x': self.histogram,
            'bins': HISTOGRAM_NUM_BINS,
            'color': 'black',
        }
        self.histogram_artist = self.axis.hist(**kwargs)[2]

        self.canvas.draw()

    def update_range_labels(self):
        """Write the selected limits, with two decimals, into the labels."""
        labels = (self.ui.min_label, self.ui.max_label)
        texts = [f'{x:.2f}' for x in self.ui_range]
        for label, text in zip(labels, texts):
            label.setText(text)

    def create_line(self):
        """Create the line artist marking the selected range."""
        xs = (self.ui_min, self.ui_max)
        ys = self.axis.get_ylim()
        kwargs = {
            'scalex': False,
            'scaley': False,
            'color': 'black',
        }
        self.line_artist, = self.axis.plot(xs, ys, **kwargs)

    def update_line(self):
        """Move the range line to the selected limits and schedule a redraw."""
        if self.line_artist is None:
            self.create_line()

        xs = (self.ui_min, self.ui_max)
        ys = self.axis.get_ylim()

        xlim = self.axis.get_xlim()

        # Rescale the xs to be in the plot scaling
        interp = interp1d(self.data_range, xlim, fill_value='extrapolate')

        self.line_artist.set_data(interp(xs), ys)
        self.canvas.draw_idle()

    @property
    def max_num_pixels(self):
        """Element count of the largest data array."""
        return max(np.prod(x.shape) for x in self.data_list)

    def select_data_range(self):
        """Ask for a new data range in a dialog and apply it if valid."""
        dialog = QDialog(self.ui)
        layout = QVBoxLayout()
        dialog.setLayout(layout)

        range_widget = RangeWidget(dialog)
        range_widget.bounds = self.data_bounds
        range_widget.min = self.data_range[0]
        range_widget.max = self.data_range[1]
        layout.addWidget(range_widget.ui)

        buttons = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
        button_box = QDialogButtonBox(buttons, dialog)
        button_box.accepted.connect(dialog.accept)
        button_box.rejected.connect(dialog.reject)
        layout.addWidget(button_box)

        if not dialog.exec_():
            # User canceled
            return

        data_range = range_widget.range
        if data_range[0] >= data_range[1]:
            message = 'Min cannot be greater than or equal to the max'
            QMessageBox.critical(self.ui, 'Validation Error', message)
            return

        if self.data_range == data_range:
            # Nothing changed...
            return

        self.data_range = data_range
        self.modified()

    def reset_pressed(self):
        """Restore the data range and auto threshold, then emit ``reset``."""
        self.reset_data_range()
        self.reset_auto_threshold()
        self.reset.emit()

    def reset_auto_threshold(self):
        """Restore the auto threshold to its default."""
        self.current_auto_threshold = self.default_auto_threshold

    def auto_pressed(self):
        """Pick the selected range automatically from the histogram.

        Each press halves the auto threshold (restarting from the default
        once it drops below 10). The selection becomes the span between the
        first and the last bin whose count lies in
        ``(pixel_count / threshold, pixel_count / 10]``.
        """
        data_range = self.data_range
        hist = self.histogram

        if hist is None:
            return

        # FIXME: should we do something other than max_num_pixels?
        pixel_count = self.max_num_pixels
        num_bins = len(hist)
        hist_start = data_range[0]
        bin_size = self.data_width / num_bins
        auto_threshold = self.current_auto_threshold

        # Perform the operation as ImageJ does it
        if auto_threshold < 10:
            auto_threshold = self.default_auto_threshold
        else:
            auto_threshold /= 2

        self.current_auto_threshold = auto_threshold

        limit = pixel_count / 10
        threshold = pixel_count / auto_threshold
        # When no bin qualifies, the loop variable is left at the far end
        for i, count in enumerate(hist):
            if threshold < count <= limit:
                break

        h_min = i

        for i, count in reversed_enumerate(hist):
            if threshold < count <= limit:
                break

        h_max = i

        if h_max < h_min:
            # Reset the range
            self.reset_auto_threshold()
            self.ui_range = self.data_range
        else:
            vmin = hist_start + h_min * bin_size
            vmax = hist_start + h_max * bin_size
            if vmin == vmax:
                vmin, vmax = data_range

            self.ui_range = vmin, vmax

        self.update_brightness()
        self.update_contrast()
Ejemplo n.º 10
0
class MisfitLog(QtWidgets.QWidget):
    """Widget plotting misfit, regularization and lambda per iteration.

    The values come from the ``inv_log.txt`` file of the current inversion
    and are drawn on an embedded matplotlib canvas.
    """

    def __init__(self, main_window, parent=None):
        """Create the canvas with twin y-axes and show the current log.

        Args:
            main_window: Object whose ``genie`` attribute holds the
                application state read by ``update_log``.
            parent: Optional parent widget.
        """
        super().__init__(parent)

        self.genie = main_window.genie

        # One (misfit, regularization, lambda) tuple per logged iteration.
        self.misfit_log = []

        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)
        self.canvas = FigureCanvas(Figure())
        self.canvas.figure.subplots_adjust(left=0.1,
                                           right=0.9,
                                           top=0.95,
                                           bottom=0.1)
        layout.addWidget(self.canvas)

        # Left axis: misfit and regularization; right (twin) axis: lambda.
        self._static_ax = self.canvas.figure.subplots()
        self._static_ax_lam = self._static_ax.twinx()

        self.update_log()

    def update_log(self):
        """Re-read the log of the current inversion (if any) and redraw.

        When no project is open or the log file does not exist, the
        previously parsed values are kept and redrawn.
        """
        if self.genie.project_cfg is not None:
            file_name = os.path.join(
                self.genie.cfg.current_project_dir, "inversions",
                self.genie.project_cfg.curren_inversion_name, "inv_log.txt")
            if os.path.isfile(file_name):
                self.misfit_log = self.parse(file_name)

        self.show_log()

    def parse(self, file_name):
        """Extract ``(misfit, regularization, lambda)`` tuples from a log file.

        Only lines containing ``"Phi ="`` are used; they are expected to look
        like ``<iteration>: ... = <misfit> + <reg> * <lambda>``. When the same
        iteration number appears on consecutive lines, the later line
        replaces the earlier one.

        Parsing stops at the first malformed ``Phi`` line and the entries
        read up to that point are returned. Only ``ValueError`` and
        ``IndexError`` from the line parser are treated as "malformed"; any
        other exception propagates instead of being silently swallowed.

        Args:
            file_name: Path of the log file to read.

        Returns:
            List of ``(misfit, regularization, lambda)`` tuples.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        misfit_log = []

        with open(file_name) as fd:
            lines = fd.readlines()

        last_iteration = -1
        for line in lines:
            if "Phi =" not in line:
                continue
            try:
                iteration, entry = self._parse_phi_line(line)
            except (ValueError, IndexError):
                # Keep the historical behaviour: give up at the first bad
                # line and return what has been collected so far.
                break

            if misfit_log and iteration == last_iteration:
                misfit_log[-1] = entry
            else:
                misfit_log.append(entry)
            last_iteration = iteration

        return misfit_log

    @staticmethod
    def _parse_phi_line(line):
        """Split one ``Phi =`` line into ``(iteration, (misfit, reg, lam))``.

        Raises:
            ValueError: If a numeric field cannot be converted.
            IndexError: If the expression after ``=`` is empty.
        """
        iteration = int(line.split(":", 1)[0])

        expression = line.split("=")[1]
        terms = expression.split("*")
        lam = float(terms[-1])

        summands = terms[0]
        plus = summands.find("+")
        # A "+" directly after "e" belongs to an exponent (e.g. 1.2e+03);
        # the separator between the two summands is then the next "+".
        if summands[plus - 1] == "e":
            plus = summands.find("+", plus + 1)

        misfit = float(summands[:plus])
        reg = float(summands[plus + 1:]) * lam

        return iteration, (misfit, reg, lam)

    def show_log(self):
        """Clear both axes and plot the currently stored log values."""
        self._static_ax.cla()
        self._static_ax.set_yscale('log')
        self._static_ax.set_xlabel('Iteration')
        self._static_ax.set_ylabel('Misfit, Regularization')

        self._static_ax_lam.cla()
        self._static_ax_lam.set_ylabel('Lambda')

        self._static_ax.xaxis.set_major_locator(
            ticker.MaxNLocator(integer=True))

        if self.misfit_log:
            self._static_ax.plot([item[0] for item in self.misfit_log],
                                 color="tab:blue",
                                 label="Misfit")
            self._static_ax.plot([item[1] for item in self.misfit_log],
                                 color="tab:orange",
                                 label="Regularization")
            self._static_ax_lam.plot([item[2] for item in self.misfit_log],
                                     color="tab:green")

            # Lambda is drawn on the twin axis; this empty line only adds a
            # "Lambda" entry to the legend of the main axis.
            self._static_ax.plot([], color="tab:green", label="Lambda")
            self._static_ax.legend()

        self.canvas.draw_idle()
Ejemplo n.º 11
0
class Calibration(QWidget):
    """Calibration page: standard selection, equation controls and two plots."""

    def __init__(self, parent):
        """Build the page inside ``self.mainWidget``.

        Parameters
        ----------
        parent : QWidget
            Owning widget; its ``Data`` attribute is used by the slots.
        """
        super().__init__(parent)
        self.parent = parent
        self.mainWidget = QWidget()
        self.mainLayout = QVBoxLayout(self.mainWidget)
        self.mainLayout.setAlignment(Qt.AlignCenter)

        self._build_settings_row()
        self._build_plot_area()

    def _build_settings_row(self):
        """Create the standard selector and the form with the controls."""
        self.settings = QWidget()
        self.settingsLayout = QHBoxLayout(self.settings)
        self.stdSelect = StandardSelect(self)
        self.stdSelect.setFixedHeight(200)
        self.settingsLayout.addWidget(self.stdSelect)

        self.formGroupBox = QGroupBox("")
        form = QFormLayout()

        def add_element_combo(caption):
            # Changing the selected element redraws the graphs.
            combo = QComboBox()
            combo.currentTextChanged.connect(self.graph)
            form.addRow(QLabel(caption), combo)
            return combo

        self.equationBtn = QPushButton('Calculate')
        form.addRow(QLabel('Cal. Equations:'), self.equationBtn)
        self.equationBtn.clicked.connect(self.equations)

        self.exportBtn = QPushButton('Export')
        form.addRow(QLabel('Export Equations:'), self.exportBtn)

        self.elem1 = add_element_combo("Graph 1:")
        self.elem2 = add_element_combo("Graph 2:")
        self.formGroupBox.setLayout(form)

        self.plotBtn = QPushButton('Plot')
        form.addRow(QLabel('Show graphs:'), self.plotBtn)
        self.plotBtn.clicked.connect(self.graph)

        self.settingsLayout.addWidget(self.formGroupBox)
        self.settingsLayout.addStretch(1)
        self.mainLayout.addWidget(self.settings)

    def _build_plot_area(self):
        """Create the two-axes figure, its canvas and the toolbar."""
        self.fig = Figure()
        self.ax = self.fig.subplots(1, 2)
        self.canvas = FigureCanvas(self.fig)
        self.toolbar = NavigationToolbar(self.canvas, self)
        self.toolbar.setStyleSheet(
            "QWidget {border: None; background-color: white; color: black}")
        self.mainLayout.addWidget(self.toolbar)
        self.mainLayout.addWidget(self.canvas)

    def get_page(self):
        """
        Return the widget that holds this page.

        Returns
        -------
        mainWidget : QWidget
            The container the settings row, toolbar and canvas were added to.
        """
        return self.mainWidget

    def equations(self):
        """Compute regression values for the selected standards, then the
        calibration equations."""
        selected_standards = self.stdSelect.return_all()
        self.parent.Data.get_regression_values('intensity', selected_standards)
        self.parent.Data.calibration_equations()

    def graph(self):
        """Draw the calibration graph of each selected element on its axes.

        Does nothing while no regression equations are available.
        """
        if not self.parent.Data.regression_equations:
            return

        selections = (self.elem1.currentText(), self.elem2.currentText())
        for element, axes in zip(selections, self.ax):
            self.parent.Data.calibration_graph(element, ax=axes)
        self.canvas.draw_idle()
Ejemplo n.º 12
0
class InteractiveMPLGraph:
    def __init__(self, parent):
        """Draw the initial analysis graph and index every artist created.

        Every node and every edge is drawn separately so that the artist
        returned by networkx can be stored (in ``self._node`` and
        ``self._adj``) and later moved, recoloured or hidden on its own.
        One hidden label per available quantity is created for each edge.
        Finally the mouse callbacks are connected to the canvas.

        Args:
            parent: Owning window; ``analysis``, ``_search_parameter`` and
                ``_analysis_type`` are read from it here.
        """
        self.fig = Figure(dpi=170)
        self.canvas = FigureCanvas(self.fig)
        self.ax = self.fig.add_subplot(111)
        self.cax = None
        self.selected_nodes = []
        self._node = {}
        self._adj = {}
        self._current_node = None
        self._current_node_handle = None
        self._parent = parent
        self._analysis = parent.analysis
        self._canvas_toolbar = None
        self._default_size = {}

        def keyed_by_pair(labels):
            # Re-key 'first:second' strings as (first, second) tuples.
            pairs = {}
            for key, value in labels.items():
                parts = key.split(':')
                pairs[(parts[0], parts[1])] = value
            return pairs

        node_positions = self._analysis.get_current_node_positions()
        if node_positions is None:
            node_positions = self._analysis.get_node_positions_2d()
        node_labels = self._analysis.get_short_node_labels()

        # Label texts per quantity, in the order their artists are created.
        label_sources = {
            'occupancy':
            keyed_by_pair(self._analysis.get_occupancies(as_labels=True))
        }
        frame_time, frame_unit = self._parent._search_parameter['frame_time']
        label_sources['endurance'] = keyed_by_pair(
            self._analysis.get_endurance_times(as_labels=True,
                                               frame_time=frame_time,
                                               frame_unit=frame_unit))
        if self._parent._analysis_type == 'ww':
            label_sources['nb_water'] = keyed_by_pair(
                self._analysis.get_nb_waters(as_labels=True))

        graph = self._analysis.initial_graph

        # Font size by label length; anything not listed gets the smallest.
        scale = 0.55
        points_by_length = {
            2: 11,
            3: 11,
            4: 11,
            5: 11,
            6: 10,
            7: 9,
            8: 8,
            9: 7,
            10: 7
        }
        len2fontsize = defaultdict(
            lambda: 6 * scale, {
                length: points * scale
                for length, points in points_by_length.items()
            })

        for node in graph.nodes:
            text = node_labels[node]
            label_length = len(text)
            if not self._analysis.residuewise:
                label_length -= 1
            font_size = len2fontsize[label_length]

            single = graph.subgraph([node])
            label = nx.draw_networkx_labels(single,
                                            node_positions,
                                            labels={node: text},
                                            font_weight='bold',
                                            font_size=font_size,
                                            ax=self.ax)[node]
            handle = nx.draw_networkx_nodes(single,
                                            node_positions,
                                            node_color=default_colors[0],
                                            alpha=0.5,
                                            ax=self.ax)

            self._node[node] = {
                'handle': handle,
                'label': label,
                'active': True,
                'color': default_colors[0]
            }
            # Overwritten on every pass; the values of the last node remain.
            self._default_size['node'] = (handle.get_sizes()[0], font_size)
            self._adj[node] = {}

        for u, v in graph.edges:
            pair = graph.subgraph([u, v])
            direction = list(pair.edges)[0]
            handle = nx.draw_networkx_edges(pair,
                                            node_positions,
                                            width=1.0,
                                            alpha=0.5,
                                            ax=self.ax)
            self._default_size['edge'] = handle.get_linewidth()
            segments = handle.get_segments()
            ta = trans_angle(segments[0][0], segments[0][1], self.ax)

            all_labels = {}
            for label_type, texts in label_sources.items():
                # The texts may be keyed in either node order.
                try:
                    text = texts[(u, v)]
                except KeyError:
                    text = texts[(v, u)]
                artist = nx.draw_networkx_edge_labels(
                    pair,
                    node_positions,
                    edge_labels={direction: text},
                    font_weight='bold',
                    font_size=len2fontsize[6],
                    ax=self.ax)[direction]
                artist.set_visible(False)
                artist.set_rotation(ta)
                all_labels[label_type] = artist
            self._default_size['label'] = all_labels['occupancy'].get_fontsize(
            )

            # Both directions share one dict, so a change shows up in both.
            edge_data = {
                'handle': handle,
                'direction': direction,
                'active': True,
                'color': 'black',
                'all_labels': all_labels
            }
            self._adj[u][v] = edge_data
            self._adj[v][u] = edge_data

        self.fig.tight_layout()
        self.cax = self.fig.add_axes([0.73, 0.1, 0.2, 0.01])
        self.cax.axis('off')

        connect = self.fig.canvas.mpl_connect
        self._cidpress = connect('button_press_event', self.on_press)
        self._cidrelease = connect('button_release_event', self.on_release)
        self._cidmotion = connect('motion_notify_event', self.on_motion)

    def on_press(self, event):
        if event.inaxes != self.ax: return
        if (self._canvas_toolbar is None) or (self._canvas_toolbar.mode != ""):
            return
        if self._current_node is not None: return

        clicked_on_node = False
        for node in self._node:
            if not self._node[node]['active']: continue
            node_handle = self._node[node]['handle']
            contains, attrd = node_handle.contains(event)
            if contains:
                clicked_on_node = True
                break
        if not clicked_on_node: return

        self.moved = False
        x0, y0 = node_handle.get_offsets()[0]
        self.press = x0, y0, event.xdata, event.ydata
        self._current_node_handle = node_handle
        self._current_node = node

        node_handle.set_animated(True)
        label = self._node[node]['label']
        label.set_animated(True)
        for other_node in self._adj[node]:
            other_node_handle = self._node[other_node]['handle']
            edge = self._adj[node][other_node]['handle']
            edge.set_animated(True)
            for label_type in self._adj[node][other_node]['all_labels']:
                label = self._adj[node][other_node]['all_labels'][label_type]
                label.set_animated(True)
            other_node_handle.set_animated(True)

        self.canvas.draw()
        self.background = self.fig.canvas.copy_from_bbox(self.ax.bbox)

        for other_node in self._adj[node]:
            other_node_handle = self._node[other_node]['handle']
            edge = self._adj[node][other_node]['handle']
            self.ax.draw_artist(edge)
            self.ax.draw_artist(other_node_handle)
        self.ax.draw_artist(node_handle)
        self.ax.draw_artist(label)

        self.canvas.blit(self.ax.bbox)

    def on_motion(self, event):
        """Drag the pressed node, its edges and their labels with the mouse.

        Ignored while no node is pressed or while the cursor is outside the
        graph axes. The stored background is restored and only the moving
        artists are redrawn and blitted.
        """
        if self._current_node is None:
            return
        if event.inaxes != self.ax:
            return

        node = self._current_node
        node_handle = self._current_node_handle
        self.moved = True

        x0, y0, xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        new_x, new_y = x0 + dx, y0 + dy

        node_handle.set_offsets([new_x, new_y])

        neighbours = self._adj[node]
        for edge_data in neighbours.values():
            edge = edge_data['handle']
            segments = edge.get_segments()
            # Replace the end point of the segment that belongs to the node.
            end_index = edge_data['direction'].index(node)
            segments[0][end_index] = [new_x, new_y]
            edge.set_segments(segments)

            label_pos = np.array(segments[0]).sum(axis=0) / 2
            ta = trans_angle(segments[0][0], segments[0][1], self.ax)
            for edge_label in edge_data['all_labels'].values():
                edge_label.set_position(label_pos)
                edge_label.set_rotation(ta)

        label = self._node[node]['label']
        label.set_position([new_x, new_y])

        self.canvas.restore_region(self.background)

        for other_node, edge_data in neighbours.items():
            other_node_handle = self._node[other_node]['handle']
            for edge_label in edge_data['all_labels'].values():
                self.ax.draw_artist(edge_label)
            self.ax.draw_artist(edge_data['handle'])
            self.ax.draw_artist(other_node_handle)
        self.ax.draw_artist(node_handle)
        self.ax.draw_artist(label)

        self.canvas.blit(self.ax.bbox)

    def on_release(self, event):
        if self._current_node is None: return
        node_handle = self._current_node_handle
        node = self._current_node
        label = self._node[node]['label']

        if not self.moved:
            if node in self.selected_nodes:
                self.selected_nodes.remove(node)
                node_handle.set_linewidth(1.0)
                node_handle.set_edgecolor(self._node[node]['color'])
            else:
                self.selected_nodes.append(node)
                self._parent.statusbar.showMessage(
                    self.get_current_node_info())
            self.process_selected_nodes()
            self.ax.draw_artist(node_handle)
            self.canvas.blit(self.ax.bbox)

        node_handle.set_animated(False)
        label.set_animated(False)
        for other_node in self._adj[node]:
            other_node_handle = self._node[other_node]['handle']
            edge = self._adj[node][other_node]['handle']
            for label_type in self._adj[node][other_node]['all_labels']:
                edge_label = self._adj[node][other_node]['all_labels'][
                    label_type]
                edge_label.set_animated(False)
            edge.set_animated(False)
            other_node_handle.set_animated(False)

        self.background = None
        self.canvas.draw_idle()
        self.press = None
        self._current_node = None

    def edges(self):
        seen = {}
        for node, neighbours in self._adj.items():
            for neighbor, data in neighbours.items():
                if neighbor not in seen:
                    yield (node, neighbor, data)
            seen[node] = True
        del seen

    def nodes(self):
        for node in self._node:
            yield node

    def get_current_node_info(self):
        """Return the node currently pressed/dragged, or ``None`` if none is.

        ``on_release`` shows this value as the status-bar message when a node
        gets selected.
        """
        return self._current_node

    def reset_selected_nodes(self):
        self.selected_nodes = []
        for node in self._node:
            node_handle = self._node[node]['handle']
            node_handle.set_linewidth(1.0)
            node_handle.set_edgecolor(self._node[node]['color'])
        self.canvas.draw_idle()

    def process_selected_nodes(self):
        focus_widget = self._parent.focusWidget()

        if focus_widget is self._parent.line_bonds_connected_root:
            self._parent.line_bonds_connected_root.setText(self._current_node)
        elif focus_widget is self._parent.line_bonds_path_root:
            self._parent.line_bonds_path_root.setText(self._current_node)
        elif focus_widget is self._parent.line_bonds_path_goal:
            self._parent.line_bonds_path_goal.setText(self._current_node)
        elif focus_widget is self._parent.lineEdit_specific_path:
            self._parent.lineEdit_specific_path.setText(', '.join(
                self.selected_nodes))

        current_plugin = self._parent.comboBox_plugins.currentText()
        plugin_ui = self._parent._plugins[current_plugin].ui
        plugin_lineEdits = [
            getattr(plugin_ui, lineEdit) for lineEdit in
            ['lineEdit_node_picker' + str(i) for i in range(1, 4)]
            if hasattr(plugin_ui, lineEdit)
        ]
        for lineEdit in plugin_lineEdits:
            if focus_widget is lineEdit:
                lineEdit.setText(self._current_node)

        rem_nodes = []
        for node in self.selected_nodes:
            node_handle = self._node[node]['handle']
            if node not in self._analysis.filtered_graph.nodes:
                rem_nodes.append(node)
                node_handle.set_linewidth(1.0)
                node_handle.set_edgecolor(self._node[node]['color'])
            node_handle.set_linewidth(2.0)
            node_handle.set_edgecolor('black')
        for node in rem_nodes:
            self.selected_nodes.remove(node)

    def set_edge_color(self):
        for node, other_node, edge_data in self.edges():
            edge_handle = edge_data['handle']
            edge_handle.set_facecolor(edge_data['color'])
            edge_handle.set_edgecolor(edge_data['color'])
        self.canvas.draw_idle()

    def set_subgraph(self, **kwargs):
        subgraph = self._analysis.filtered_graph
        node_labels_active = self._parent.checkBox_bonds_graph_labels.isChecked(
        )
        for node in self._node:
            if node in subgraph.nodes:
                self._node[node]['active'] = True
                node_handle = self._node[node]['handle']
                node_handle.set_visible(True)
                label = self._node[node]['label']
                if node_labels_active: label.set_visible(True)
                else: label.set_visible(False)
            else:
                self._node[node]['active'] = False
                node_handle = self._node[node]['handle']
                node_handle.set_visible(False)
                label = self._node[node]['label']
                label.set_visible(False)
        for node, other_node, edge_data in self.edges():
            edge_handle = edge_data['handle']
            for edge_label_type, edge_label in edge_data['all_labels'].items():
                if (node, other_node) in subgraph.edges:
                    edge_handle.set_visible(True)
                    edge_data['active'] = True
                else:
                    edge_handle.set_visible(False)
                    edge_label.set_visible(False)
                    edge_data['active'] = False
        self.set_edge_labels(draw=False)
        if ('draw' not in kwargs) or kwargs['draw']: self.canvas.draw_idle()

    def set_colors(self, **kwargs):
        """Recolour all nodes according to the colouring mode chosen in the UI.

        Modes, selected through the parent's radio buttons:

        * single colour taken from the colour combo box (empty text falls
          back to the first default colour),
        * one colour per segment name, with an optional legend,
        * degree or betweenness centrality mapped onto the ``jet`` colour
          map, with an optional horizontal colour bar in ``self.cax``.

        An invalid colour definition is reported through ``Error`` and the
        method returns without redrawing. Centralities are scaled by their
        maximum; when that maximum is zero (or there are no centralities)
        every value maps to the start of the colour map instead of raising
        a division error.

        Keyword Args:
            draw: When given and falsy, no redraw is requested.
        """
        if self.ax.get_legend() is not None:
            self.ax.get_legend().remove()
        if self.cax is not None:
            self.cax.clear()
            self.cax.axis('off')

        def report_invalid(color):
            Error(
                'Color Error' + ' ' * 30,
                "Did not understand color definition '{}'. You can use strings like green or shorthands like g or RGB codes like #15b01a."
                .format(color))

        if self._parent.radioButton_color.isChecked():
            color = self._parent.comboBox_single_color.currentText()
            if color == '':
                color = default_colors[0]
            if not is_color_like(color):
                report_invalid(color)
                return
            for node in self._node:
                self._node[node]['color'] = color

        elif self._parent.radioButton_colors.isChecked():
            for node in self._node:
                segname = node.split('-')[0]
                color = self._parent._segname_colors[segname]
                if not is_color_like(color):
                    report_invalid(color)
                    return
                self._node[node]['color'] = color
            if self._parent.checkBox_segnames_legend.isChecked():
                custom_lines = [
                    Line2D([0], [0],
                           marker='o',
                           color='w',
                           markerfacecolor=color,
                           alpha=0.6,
                           markersize=12,
                           lw=4)
                    for color in self._parent._segname_colors.values()
                ]
                segnames = list(self._parent._segname_colors)
                self.ax.legend(custom_lines, segnames)

        elif self._parent.radioButton_degree.isChecked(
        ) or self._parent.radioButton_betweenness.isChecked():
            avg_type = self._parent.checkBox_centralities_avg.isChecked()
            norm_type = self._parent.checkBox_centralities_norm.isChecked()
            by_degree = self._parent.radioButton_degree.isChecked()
            kind = 'degree' if by_degree else 'betweenness'
            centralities = self._analysis.centralities[kind][avg_type][
                norm_type]

            max_centrality = max(centralities.values(), default=0)
            if by_degree and (not (avg_type or norm_type)):
                # Raw degrees are whole numbers; show them as such.
                max_centrality = round(max_centrality)
            cmap = plt.get_cmap('jet')
            for node, centrality_value in centralities.items():
                # A graph whose centralities are all zero must not divide
                # by zero; map everything to the low end instead.
                if max_centrality:
                    fraction = centrality_value / max_centrality
                else:
                    fraction = 0.0
                self._node[node]['color'] = rgb_to_string(cmap(fraction))
            if self._parent.checkBox_color_legend.isChecked():
                sm = plt.cm.ScalarMappable(cmap=cmap)
                sm.set_array([0.0, max_centrality])
                # Use the embedded figure directly: ``plt.colorbar`` goes
                # through ``plt.gcf()`` and may create an unrelated pyplot
                # figure as a side effect.
                self.fig.colorbar(sm,
                                  cax=self.cax,
                                  ticks=[0, max_centrality],
                                  orientation='horizontal')
                self.cax.set_xticklabels([str(0), str(max_centrality)])
                self.cax.axis('on')

        fill_white = self._parent.checkBox_white.isChecked()
        for node in self._node:
            node_handle = self._node[node]['handle']
            color = self._node[node]['color']
            node_handle.set_facecolor('white' if fill_white else color)
            node_handle.set_edgecolor(color)
        if ('draw' not in kwargs) or kwargs['draw']:
            self.canvas.draw_idle()

    def set_node_positions(self, **kwargs):
        projection = 'PCA'
        if self._parent.radioButton_rotation_xy.isChecked(): projection = 'XY'
        elif self._parent.radioButton_rotation_zy.isChecked():
            projection = 'ZY'
        adjust_water = False
        frame = int(self._parent.label_frame.text())
        positions = self._parent.analysis.get_node_positions_2d(
            projection=projection,
            in_frame=frame,
            adjust_water_positions=adjust_water)

        all_pos = np.array([positions[key] for key in positions])
        minx, maxx, miny, maxy = np.min(all_pos[:, 0]), np.max(
            all_pos[:, 0]), np.min(all_pos[:, 1]), np.max(all_pos[:, 1])
        xmargin = (maxx - minx) / 20
        ymargin = (maxy - miny) / 20
        for node in self._node:
            node_handle = self._node[node]['handle']
            node_handle.set_offsets(positions[node])
            node_label = self._node[node]['label']
            node_label.set_position(positions[node])

        for node, other_node, edge_data in self.edges():
            edge = edge_data['handle']
            direction = edge_data['direction']
            index = direction.index(node)
            other_index = direction.index(other_node)
            edge_positions = np.array([positions[node], positions[other_node]
                                       ])[[index, other_index]]
            edge.set_segments([edge_positions])
            label_pos = np.array(edge_positions).sum(axis=0) / 2
            ta = trans_angle(edge_positions[0], edge_positions[1], self.ax)
            for label_type in edge_data['all_labels']:
                edge_label = edge_data['all_labels'][label_type]
                edge_label.set_position(label_pos)
                edge_label.set_rotation(ta)

        self.ax.set_xlim(minx - xmargin, maxx + xmargin)
        self.ax.set_ylim(miny - ymargin, maxy + ymargin)
        if ('draw' not in kwargs): self.canvas.draw_idle()

    def set_nodesize(self, **kwargs):
        for node in self._node:
            offset = self._default_size['node'][0] / 2
            size = offset + 2 / (self._default_size['node'][0]) * (
                self._parent.horizontalSlider_nodes.value() / 100 *
                self._default_size['node'][0])**2
            node_handle = self._node[node]['handle']
            node_handle.set_sizes([size])
            offset = self._default_size['node'][1] / 2
            size = offset + self._parent.horizontalSlider_nodes.value(
            ) / 100 * self._default_size['node'][1]
            label_handle = self._node[node]['label']
            label_handle.set_size(size)
        if ('draw' not in kwargs) or kwargs['draw']: self.canvas.draw_idle()

    def set_edgesize(self, **kwargs):
        for node, other_node, edge_data in self.edges():
            offset = self._default_size['edge'] / 2
            size = offset + self._default_size['edge'] * (
                self._parent.horizontalSlider_edges.value() / 100)
            edge_data['handle'].set_linewidth(size)
        if ('draw' not in kwargs) or kwargs['draw']: self.canvas.draw_idle()

    def set_labelsize(self, **kwargs):
        for node, other_node, edge_data in self.edges():
            offset = self._default_size['label'] / 2
            size = offset + self._default_size['label'] * (
                self._parent.horizontalSlider_labels.value() / 100)
            for typ, label_handle in edge_data['all_labels'].items():
                label_handle.set_fontsize(size)
        if ('draw' not in kwargs) or kwargs['draw']: self.canvas.draw_idle()

    def set_node_labels(self, **kwargs):
        labels_active = self._parent.checkBox_bonds_graph_labels.isChecked()
        for node in self._node:
            show_label = self._node[node]['active']
            label = self._node[node]['label']
            if labels_active and show_label: label.set_visible(True)
            else: label.set_visible(False)
        if ('draw' not in kwargs) or kwargs['draw']: self.canvas.draw_idle()

    def set_edge_labels(self, **kwargs):
        if self._parent.checkBox_bonds_occupancy.isChecked():
            active_label_type = 'occupancy'
        elif self._parent.checkBox_bonds_endurance.isChecked():
            active_label_type = 'endurance'
        elif self._parent.checkBox_nb_water.isChecked():
            active_label_type = 'nb_water'
        else:
            active_label_type = None
        for node, other_node, edge_data in self.edges():
            show_label = edge_data['active']
            labels = edge_data['all_labels']
            for label_type in labels:
                label = labels[label_type]
                if (label_type == active_label_type) and show_label:
                    label.set_visible(True)
                else:
                    label.set_visible(False)
        if ('draw' not in kwargs) or kwargs['draw']: self.canvas.draw_idle()

    def get_active_nodes(self):
        return [node for node in self._node if self._node[node]['active']]

    def set_current_pos(self):
        node_pos = {}
        for node in self._node:
            node_handle = self._node[node]['handle']
            x, y = node_handle.get_offsets()[0]
            node_pos[node] = (x, y)
        self._analysis._current_node_positions = node_pos

    def add_toolbar(self):
        """Create a navigation toolbar for the canvas and add it to the parent.

        The toolbar's ``home`` attribute is replaced by
        ``set_node_positions``.
        """
        toolbar = NavigationToolbar(self.canvas, self._parent)
        self._canvas_toolbar = toolbar
        toolbar.home = self.set_node_positions
        self._parent.addToolBar(toolbar)

    def remove_toolbar(self):
        """Remove the navigation toolbar from the parent window.

        NOTE(review): ``self._canvas_toolbar`` is ``None`` until
        ``add_toolbar`` has run and is not reset here — confirm callers only
        invoke this after ``add_toolbar``.
        """
        self._parent.removeToolBar(self._canvas_toolbar)

    def _disconnect(self):
        'disconnect all the stored connection ids'
        self.canvas.mpl_disconnect(self.cidpress)
        self.canvas.mpl_disconnect(self.cidrelease)
        self.canvas.mpl_disconnect(self.cidmotion)