Ejemplo n.º 1
0
class button_manager:
    ''' Handles some missing features of matplotlib check buttons
    on init:
        creates button, links to button_click routine,
        calls call_on_click with active index and firsttime=True
    on click:
        maintains single button on state, calls call_on_click
    '''

    #@output.capture()  # debug
    def __init__(self, fig, dim, labels, init, call_on_click):
        '''
        fig:            figure the buttons belong to (stored as self.fig)
        dim: (list)     [leftbottom_x,bottom_y,width,height]
        labels: (list)  for example ['1','2','3','4','5','6']
        init: (list)    for example [True, False, False, False, False, False]
                        at least one entry must be True (list.index raises
                        ValueError otherwise); the first True is the active one
        call_on_click:  callable receiving the active index; it is also
                        called once here with the keyword firsttime=True
        '''
        self.fig = fig
        self.ax = plt.axes(dim)  #lx,by,w,h
        self.init_state = init
        self.call_on_click = call_on_click
        self.button = CheckButtons(self.ax, labels, init)
        self.button.on_clicked(self.button_click)
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True), firsttime=True)

    #@output.capture()  # debug
    def reinit(self):
        '''Restore the initial single-on selection.

        Toggling the initially-active button fires button_click, which turns
        off the button recorded as active in self.status, refreshes
        self.status from the widget and calls call_on_click.
        '''
        # self.status must still describe the *current* selection at this
        # point: button_click uses it to decide which button to switch off.
        # (Overwriting it with init_state first made button_click switch the
        # freshly enabled button straight back off.)
        self.button.set_active(self.init_state.index(True))

    #@output.capture()  # debug
    def button_click(self, event):
        ''' maintains one-on state. If on-button is clicked, will process correctly

        event: value passed by CheckButtons.on_clicked (unused)
        '''
        # Toggle the previously-active button without re-entering this
        # callback: when another button was clicked this turns the old one
        # off; when the active button itself was clicked (switching it off)
        # this turns it back on.
        self.button.eventson = False
        try:
            self.button.set_active(self.status.index(True))
        finally:
            # never leave the widget with its callbacks disabled
            self.button.eventson = True
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True))
Ejemplo n.º 2
0
class DefacingInterface(BaseReviewInterface):
    """Custom interface to rate the quality of defacing in an MRI scan"""

    def __init__(self,
                 fig,
                 axes,
                 issue_list=cfg.defacing_default_issue_list,
                 next_button_callback=None,
                 quit_button_callback=None,
                 processing_choice_callback=None,
                 map_key_to_callback=None):
        """Constructor

        Parameters
        ----------
        fig, axes
            Forwarded to BaseReviewInterface.__init__ together with the
            next/quit callbacks.
        issue_list : sequence of str
            Labels of the rating checkboxes.
        next_button_callback, quit_button_callback : callable or None
            Invoked without arguments by the keyboard shortcuts in
            on_keyboard.
        processing_choice_callback : callable or None
            Registered with on_clicked of the visualization-type radio buttons.
        map_key_to_callback : dict or None
            Maps a lower-cased key name to a zero-argument callable.

        Raises
        ------
        ValueError
            If map_key_to_callback is neither None nor a dict.
        """

        super().__init__(fig, axes, next_button_callback, quit_button_callback)

        self.issue_list = issue_list

        # zoom state maintained by on_mouse
        self.prev_axis = None
        self.prev_ax_pos = None
        self.zoomed_in = False
        self.next_button_callback = next_button_callback
        self.quit_button_callback = quit_button_callback
        self.processing_choice_callback = processing_choice_callback
        if map_key_to_callback is None:
            self.map_key_to_callback = {}  # empty
        elif isinstance(map_key_to_callback, dict):
            self.map_key_to_callback = map_key_to_callback
        else:
            raise ValueError('map_key_to_callback must be a dict')

        self.add_checkboxes()
        self.add_process_options()
        # Include all the non-data axes here (so they won't be zoomed in).
        # on_mouse compares these against event.inaxes, so every entry must be
        # an Axes: use the radio widget's .ax, not the widget object itself.
        self.unzoomable_axes = [
            self.checkbox.ax, self.text_box.ax, self.bt_next.ax,
            self.bt_quit.ax, self.radio_bt_vis_type.ax
        ]

        # list of data artists, populated later, so they can all be
        # removed together by clear_data
        self.data_handles = list()

    def add_checkboxes(self):
        """
        Checkboxes offer the ability to select multiple tags such as Motion,
        Ghosting Aliasing etc, instead of one from a list of mutual exclusive
        rating options (such as Good, Bad, Error etc).

        Also records in self._index_pass the position of
        cfg.defacing_pass_indicator among the checkbox labels (None when the
        issue list has no such entry).
        """

        ax_checkbox = plt.axes(cfg.position_checkbox_t1_mri,
                               facecolor=cfg.color_rating_axis)
        # initially de-activating all
        check_box_status = [False] * len(self.issue_list)
        self.checkbox = CheckButtons(ax_checkbox,
                                     labels=self.issue_list,
                                     actives=check_box_status)
        self.checkbox.on_clicked(self.save_issues)
        for txt_lbl in self.checkbox.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for rect in self.checkbox.rectangles:
            rect.set_width(cfg.checkbox_rect_width)
            rect.set_height(cfg.checkbox_rect_height)

        # lines is a list of n crosses, each cross (x) defined by a tuple of lines
        for x_line1, x_line2 in self.checkbox.lines:
            x_line1.set_color(cfg.checkbox_cross_color)
            x_line2.set_color(cfg.checkbox_cross_color)

        # Look the pass option up in the labels actually displayed (a custom
        # issue_list may differ from the cfg default list).
        if cfg.defacing_pass_indicator in self.issue_list:
            self._index_pass = list(self.issue_list).index(
                cfg.defacing_pass_indicator)
        else:
            self._index_pass = None

    def add_process_options(self):
        """Add radio buttons (none selected initially) offering the
        visualization choices in cfg.vis_choices_defacing."""

        ax_radio = plt.axes(cfg.position_radio_bt_t1_mri,
                            facecolor=cfg.color_rating_axis)
        self.radio_bt_vis_type = RadioButtons(ax_radio,
                                              cfg.vis_choices_defacing,
                                              active=None,
                                              activecolor='orange')
        self.radio_bt_vis_type.on_clicked(self.processing_choice_callback)
        for txt_lbl in self.radio_bt_vis_type.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for circ in self.radio_bt_vis_type.circles:
            circ.set(radius=0.06)

    def save_issues(self, label):
        """
        Update the rating

        This function is called whenever set_active() happens on any label,
            if checkbox.eventson is True.

        Checking the pass option clears every other checkbox; checking any
        other option clears the pass option.
        """

        pass_label = (None if self._index_pass is None else
                      self.issue_list[self._index_pass])
        if pass_label is not None and label == pass_label:
            self.clear_checkboxes(except_pass=True)
        else:
            self.clear_pass_only_if_on()

        self.fig.canvas.draw_idle()

    def clear_checkboxes(self, except_pass=False):
        """Clears all checkboxes.

        if except_pass=True,
            does not clear the checkbox at self._index_pass (the one labelled
            cfg.defacing_pass_indicator)
        """

        cbox_statuses = self.checkbox.get_status()
        for index, this_cbox_active in enumerate(cbox_statuses):
            if except_pass and index == self._index_pass:
                continue
            # if it was selected already, toggle it.
            if this_cbox_active:
                # not calling checkbox.set_active() as it calls the callback
                #   self.save_issues() each time, if eventson is True
                self._toggle_visibility_checkbox(index)

    def clear_pass_only_if_on(self):
        """Clear pass checkbox only (no-op when there is no pass option)"""

        if self._index_pass is None:
            return
        cbox_statuses = self.checkbox.get_status()
        if cbox_statuses[self._index_pass]:
            self._toggle_visibility_checkbox(self._index_pass)

    def _toggle_visibility_checkbox(self, index):
        """toggles the visibility of the cross drawn in a given checkbox"""

        l1, l2 = self.checkbox.lines[index]
        l1.set_visible(not l1.get_visible())
        l2.set_visible(not l2.get_visible())

    def get_ratings(self):
        """Returns the labels of the checked checkboxes, in display order"""

        cbox_statuses = self.checkbox.get_status()
        user_ratings = [
            self.checkbox.labels[idx].get_text()
            for idx, this_cbox_active in enumerate(cbox_statuses)
            if this_cbox_active
        ]

        return user_ratings

    def allowed_to_advance(self):
        """
        Method to ensure work is done for current iteration,
        before allowing the user to advance to next subject.

        Returns True only if at least one checkbox is checked.
        """

        return any(self.checkbox.get_status())

    def reset_figure(self):
        "Resets the figure to prepare it for display of next subject."

        self.clear_data()
        self.clear_checkboxes()
        self.clear_radio_buttons()
        self.clear_notes_annot()

    def clear_data(self):
        """clearing all data/image handles"""

        if self.data_handles:
            for artist in self.data_handles:
                artist.remove()
            # resetting it
            self.data_handles = list()

    def clear_notes_annot(self):
        """clearing notes and annotations"""

        self.text_box.set_val(cfg.textbox_initial_text)
        # text is matplotlib artist
        self.annot_text.remove()

    def clear_radio_buttons(self):
        """Clears the radio button (un-fills the selected circle, no value)"""

        # enabling default rating encourages lazy advancing without review
        # self.radio_bt_rating.set_active(cfg.index_freesurfer_default_rating)
        for index, label in enumerate(self.radio_bt_vis_type.labels):
            if label.get_text() == self.radio_bt_vis_type.value_selected:
                self.radio_bt_vis_type.circles[index].set_facecolor(
                    cfg.color_rating_axis)
                break
        self.radio_bt_vis_type.value_selected = None

    def on_mouse(self, event):
        """Callback for mouse events.

        Any click outside the widget axes restores a previously zoomed axis;
        a right or double click on a data axis zooms it to
        cfg.zoomed_position.
        """

        if self.prev_axis is not None:
            if event.inaxes not in self.unzoomable_axes:
                self.prev_axis.set_position(self.prev_ax_pos)
                self.prev_axis.set_zorder(0)
                self.prev_axis.patch.set_alpha(0.5)
                self.zoomed_in = False

        # right or double click to zoom in to any axis
        if (event.button in [3] or event.dblclick) and \
            (event.inaxes is not None) and \
            event.inaxes not in self.unzoomable_axes:
            self.prev_ax_pos = event.inaxes.get_position()
            event.inaxes.set_position(cfg.zoomed_position)
            event.inaxes.set_zorder(1)  # bring forth
            event.inaxes.set_facecolor('black')  # black
            event.inaxes.patch.set_alpha(1.0)  # opaque
            self.zoomed_in = True
            self.prev_axis = event.inaxes

        self.fig.canvas.draw_idle()

    def on_keyboard(self, key_in):
        """Callback to handle keyboard shortcuts to rate and advance."""

        # ignore keyboard key_in when mouse within Notes textbox
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        if key_pressed in ['right', ' ', 'space']:
            self.next_button_callback()
        elif key_pressed in ['ctrl+q', 'q+ctrl']:
            self.quit_button_callback()
        elif key_pressed in self.map_key_to_callback:
            # notice parentheses at the end
            self.map_key_to_callback[key_pressed]()
        else:
            # NOTE(review): this is the T1-MRI abbreviation table from cfg; a
            # defacing-specific table may be intended — confirm.
            abbreviations = cfg.abbreviation_t1_mri_default_issue_list
            if key_pressed in abbreviations:
                checked_label = abbreviations[key_pressed]
                # The checkboxes were built from self.issue_list, so the index
                # must come from that list (not from the T1-MRI issue list).
                if checked_label in self.issue_list:
                    self.checkbox.set_active(
                        list(self.issue_list).index(checked_label))

        self.fig.canvas.draw_idle()
Ejemplo n.º 3
0
class CorrViewer(DataViewer):
    """Plots raw correlation data. You need to hold reference to this object, 
    otherwise it will not work in interactive mode.

    Parameters
    ----------
    semilogx : bool
        Whether plot data with semilogx or not.
    shape : tuple of ints, optional
        Original frame shape. For non-rectangular you must provide this so
        to define k step.
    size : int, optional
        If specified, perform log_averaging of data with provided size parameter.
        If not given, no averaging is performed.
    norm : int, optional
        Normalization constant used in normalization
    scale : bool, optional
        Scale constant used in normalization.
    mask : ndarray, optional
        A boolean array indicating which data elements were computed.
    """
    background = None
    variance = None

    def __init__(self,
                 semilogx=True,
                 shape=None,
                 size=None,
                 norm=None,
                 scale=False,
                 mask=None):
        self.norm = norm
        self.scale = scale
        self.semilogx = semilogx
        self.shape = shape
        self.size = size
        self.computed_mask = mask
        # with an explicit mask the k-grid size comes from it; otherwise it
        # is taken from the data in set_data
        if mask is not None:
            self.kisize, self.kjsize = mask.shape

    def set_norm(self, value):
        """Sets norm parameter

        The value is resolved by _default_norm_from_data for the method
        detected from the current data. Any exception raised there
        propagates and leaves self.norm unchanged.
        """
        method = _method_from_data(self.data)
        self.norm = _default_norm_from_data(self.data, method, value)

    def _init_fig(self):
        """Extend the base figure with check boxes selecting the norm flags."""
        super()._init_fig()

        self.set_norm(self.norm)

        self.cax = plt.axes([0.44, 0.72, 0.2, 0.15])

        labels = ("structured", "subtracted", "weighted", "compensated")
        norm = self.norm
        self.active = [
            bool(norm & NORM_STRUCTURED),
            bool(norm & NORM_SUBTRACTED),
            # The mask must be parenthesized: `==` binds tighter than `&` in
            # Python, so `norm & NORM_WEIGHTED == NORM_WEIGHTED` would
            # evaluate `norm & True`.
            bool((norm & NORM_WEIGHTED) == NORM_WEIGHTED),
            bool((norm & NORM_COMPENSATED) == NORM_COMPENSATED)
        ]

        self.check = CheckButtons(self.cax, labels, self.active)

        def update(label):
            """Rebuild the norm from the check boxes and replot."""
            index = labels.index(label)
            status = self.check.get_status()

            norm = NORM_STRUCTURED if status[0] else NORM_STANDARD

            if status[1]:
                norm = norm | NORM_SUBTRACTED
            if status[2]:
                norm = norm | NORM_WEIGHTED
            if status[3]:
                norm = norm | NORM_COMPENSATED
            try:
                self.set_norm(norm)
            except ValueError:
                # Unsupported combination: revert the click. set_active fires
                # this callback again (eventson is left enabled), which
                # re-applies the previous flags.
                self.check.set_active(index)
            self.set_mask(int(round(self.kindex.val)), self.angleindex.val,
                          self.sectorindex.val, self.kstep)
            self.plot()

        self.check.on_clicked(update)

    def set_data(self, data, background=None, variance=None):
        """Sets correlation data.
        
        Parameters
        ----------
        data : tuple
            A data tuple (as computed by ccorr, cdiff, adiff, acorr functions)
        background : tuple or ndarray
            Background data for normalization. For adiff, acorr functions this
            is ndarray, for cdiff,ccorr, it is a tuple of ndarrays.
        variance : tuple or ndarray
            Variance data for normalization. For adiff, acorr functions this
            is ndarray, for cdiff,ccorr, it is a tuple of ndarrays.
        """
        self.data = data
        self.background = background
        self.variance = variance
        # all axes of data[0] except the last one form the k-grid
        self.kshape = data[0].shape[:-1]
        if self.computed_mask is None:
            self.kisize, self.kjsize = self.kshape

    def _get_avg_data(self):
        """Return (t, y): normalized data (log-averaged when self.size is
        set) reduced with nanmean over axis -2, and its time coordinates."""
        data = normalize(self.data,
                         self.background,
                         self.variance,
                         norm=self.norm,
                         scale=self.scale,
                         mask=self.mask)

        if self.size is not None:
            t, data = log_average(data, self.size)
        else:
            t = np.arange(data.shape[-1])

        return t, np.nanmean(data, axis=-2)
Ejemplo n.º 4
0
class TrackletVisualizer:
    def __init__(self, manager, videoname, trail_len=50):
        """Open the video and start the background player.

        manager: tracklet manager; must expose at least ``cfg`` (with a
            "colormap" entry), ``tracklet2id`` and ``nframes``.
        videoname: path handed to ``cv2.VideoCapture``.
        trail_len: number of frames spanned by the trails around the current
            frame.

        Raises IOError if the video cannot be opened.
        """
        self.manager = manager
        # pyplot.get_cmap(name, lut) works across matplotlib versions, whereas
        # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in 3.9.
        self.cmap = plt.get_cmap(manager.cfg["colormap"],
                                 len(set(manager.tracklet2id)))
        self.videoname = videoname
        self.video = cv2.VideoCapture(videoname)
        if not self.video.isOpened():
            self.video.release()
            raise IOError("Video could not be opened.")
        self.nframes = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Take into consideration imprecise OpenCV estimation of total number of frames
        if abs(self.nframes - manager.nframes) >= 0.05 * manager.nframes:
            print(
                "Video duration and data length do not match. Continuing nonetheless..."
            )
        self.trail_len = trail_len
        self.help_text = ""
        self.draggable = False
        # _curr_frame remembers the frame being edited in drag mode
        self._curr_frame = 0
        self.curr_frame = 0

        self.picked = []
        self.picked_pair = []
        self.cuts = []

        self.player = BackgroundPlayer(self)
        self.thread_player = Thread(target=self.player.run, daemon=True)
        self.thread_player.start()

        self.dps = []

    def _prepare_canvas(self, manager, fig):
        """Build the interactive figure, then call plt.show().

        A 13x8 figure is created when *fig* is None; otherwise *fig* is used.
        ax1 shows the video frame with the keypoint scatter and trails,
        ax2/ax3 show the x/y traces over time. Below them are the frame and
        marker-size sliders, the Drag/Lasso check boxes and the
        Flag/Save/Help buttons.
        """
        # Drop these keys from matplotlib's default keymaps so the shortcuts
        # handled in on_press ("s", "left", "right", "l") do not also trigger
        # the default key handler.
        params = {
            "keymap.save": "s",
            "keymap.back": "left",
            "keymap.forward": "right",
            "keymap.yscale": "l",
        }
        for k, v in params.items():
            if v in plt.rcParams[k]:
                plt.rcParams[k].remove(v)

        self.dotsize = manager.cfg["dotsize"]
        self.alpha = manager.cfg["alphavalue"]

        if fig is None:
            self.fig = plt.figure(figsize=(13, 8))
        else:
            self.fig = fig
        gs = self.fig.add_gridspec(2, 2)
        self.ax1 = self.fig.add_subplot(gs[:, 0])
        self.ax2 = self.fig.add_subplot(gs[0, 1])
        self.ax3 = self.fig.add_subplot(gs[1, 1], sharex=self.ax2)
        # Leave room for the widgets. Call it on self.fig: plt.subplots_adjust
        # acts on pyplot's *current* figure, which need not be a supplied fig.
        self.fig.subplots_adjust(bottom=0.2)
        for ax in self.ax1, self.ax2, self.ax3:
            ax.axis("off")

        # one RGBA row per tracklet, with the alpha column overridden
        self.colors = self.cmap(manager.tracklet2id)
        self.colors[:, -1] = self.alpha

        img = self._read_frame()
        self.im = self.ax1.imshow(img)
        self.scat = self.ax1.scatter([], [], s=self.dotsize**2, picker=True)
        self.scat.set_offsets(manager.xy[:, 0])
        self.scat.set_color(self.colors)
        self.trails = sum(
            [self.ax1.plot([], [], "-", lw=2, c=c) for c in self.colors], [])
        self.lines_x = sum([
            self.ax2.plot([], [], "-", lw=1, c=c, picker=5)
            for c in self.colors
        ], [])
        self.lines_y = sum([
            self.ax3.plot([], [], "-", lw=1, c=c, picker=5)
            for c in self.colors
        ], [])
        self.vline_x = self.ax2.axvline(0, 0, 1, c="k", ls=":")
        self.vline_y = self.ax3.axvline(0, 0, 1, c="k", ls=":")
        custom_lines = [
            plt.Line2D([0], [0], color=self.cmap(i), lw=4)
            for i in range(len(manager.individuals))
        ]
        self.leg = self.fig.legend(
            custom_lines,
            manager.individuals,
            frameon=False,
            fancybox=None,
            ncol=len(manager.individuals),
            fontsize="small",
            bbox_to_anchor=(0, 0.9, 1, 0.1),
            loc="center",
        )
        # legend entries are pickable (handled in on_pick)
        for line in self.leg.get_lines():
            line.set_picker(5)

        self.ax_slider = self.fig.add_axes([0.1, 0.1, 0.5, 0.03],
                                           facecolor="lightgray")
        self.ax_slider2 = self.fig.add_axes([0.1, 0.05, 0.3, 0.03],
                                            facecolor="darkorange")
        self.slider = Slider(
            self.ax_slider,
            "# Frame",
            self.curr_frame,
            manager.nframes - 1,
            valinit=0,
            valstep=1,
            valfmt="%i",
        )
        self.slider.on_changed(self.on_change)
        self.slider2 = Slider(
            self.ax_slider2,
            "Marker size",
            1,
            30,
            valinit=self.dotsize,
            valstep=1,
            valfmt="%i",
        )
        self.slider2.on_changed(self.update_dotsize)
        self.ax_drag = self.fig.add_axes([0.65, 0.1, 0.05, 0.03])
        self.ax_lasso = self.fig.add_axes([0.7, 0.1, 0.05, 0.03])
        self.ax_flag = self.fig.add_axes([0.75, 0.1, 0.05, 0.03])
        self.ax_save = self.fig.add_axes([0.80, 0.1, 0.05, 0.03])
        self.ax_help = self.fig.add_axes([0.85, 0.1, 0.05, 0.03])
        self.save_button = Button(self.ax_save, "Save", color="darkorange")
        self.save_button.on_clicked(self.save)
        self.help_button = Button(self.ax_help, "Help")
        self.help_button.on_clicked(self.display_help)
        self.drag_toggle = CheckButtons(self.ax_drag, ["Drag"])
        self.drag_toggle.on_clicked(self.toggle_draggable_points)
        self.flag_button = Button(self.ax_flag, "Flag")
        self.flag_button.on_clicked(self.flag_frame)

        self.fig.canvas.mpl_connect("pick_event", self.on_pick)
        self.fig.canvas.mpl_connect("key_press_event", self.on_press)
        self.fig.canvas.mpl_connect("button_press_event", self.on_click)
        self.fig.canvas.mpl_connect("close_event", self.player.terminate)

        self.selector = PointSelector(self, self.ax1, self.scat, self.alpha)
        self.lasso_toggle = CheckButtons(self.ax_lasso, ["Lasso"])
        self.lasso_toggle.on_clicked(self.selector.toggle)
        self.display_traces(only_picked=False)
        self.ax1_background = self.fig.canvas.copy_from_bbox(self.ax1.bbox)
        plt.show()

    def show(self, fig=None):
        self._prepare_canvas(self.manager, fig)

    def _read_frame(self):
        frame = self.video.read()[1]
        if frame is None:
            return
        return frame[:, :, ::-1]

    def fill_shaded_areas(self):
        self.clean_collections()
        if self.picked_pair:
            mask = self.manager.get_nonoverlapping_segments(*self.picked_pair)
            for ax in self.ax2, self.ax3:
                ax.fill_between(
                    self.manager.times,
                    *ax.dataLim.intervaly,
                    mask,
                    facecolor="darkgray",
                    alpha=0.2,
                )
            trans = mtransforms.blended_transform_factory(
                self.ax_slider.transData, self.ax_slider.transAxes)
            self.ax_slider.vlines(np.flatnonzero(mask),
                                  0,
                                  0.5,
                                  color="darkorange",
                                  transform=trans)

    def toggle_draggable_points(self, *args):
        self.draggable = not self.draggable
        if self.draggable:
            self._curr_frame = self.curr_frame
            self.scat.set_offsets([])
            self.add_draggable_points()
        else:
            self.save_coords()
            self.clean_points()
            self.display_points(self._curr_frame)
        self.fig.canvas.draw_idle()

    def add_point(self, center, animal, bodypart, **kwargs):
        """Put a draggable circle for (animal, bodypart) at *center* on ax1.

        kwargs are forwarded to patches.Circle (e.g. radius, fc, alpha). The
        DraggablePoint wrapper is connected and stored in self.dps.
        """
        marker = patches.Circle(center, **kwargs)
        self.ax1.add_patch(marker)
        drag_module = generate_training_dataset.auxfun_drag_label_multiple_individuals
        handle = drag_module.DraggablePoint(marker, animal, bodypart)
        handle.connect()
        self.dps.append(handle)

    def clean_points(self):
        for dp in self.dps:
            dp.annot.set_visible(False)
            dp.disconnect()
        self.dps = []
        for patch in self.ax1.patches[::-1]:
            patch.remove()

    def add_draggable_points(self):
        self.clean_points()
        xy, _, inds = self.manager.get_non_nan_elements(self.curr_frame)
        for i, (animal, bodypart) in enumerate(self.manager._label_pairs):
            if i in inds:
                coords = xy[inds == i].squeeze()
                self.add_point(
                    coords,
                    animal,
                    bodypart,
                    radius=self.dotsize,
                    fc=self.colors[i],
                    alpha=self.alpha,
                )

    def save_coords(self):
        coords, nonempty, inds = self.manager.get_non_nan_elements(
            self._curr_frame)
        prob = self.manager.prob[:, self._curr_frame]
        for dp in self.dps:
            label = dp.individual_names, dp.bodyParts
            ind = self.manager._label_pairs.index(label)
            nrow = np.flatnonzero(inds == ind)[0]
            if not np.array_equal(
                    coords[nrow],
                    dp.point.center):  # Keypoint has been displaced
                coords[nrow] = dp.point.center
                prob[ind] = 1
        self.manager.xy[nonempty, self._curr_frame] = coords

    def flag_frame(self, *args):
        self.cuts.append(self.curr_frame)
        self.ax_slider.axvline(self.curr_frame, color="r")
        if len(self.cuts) == 2:
            self.cuts.sort()
            mask = np.zeros_like(self.manager.times, dtype=bool)
            mask[self.cuts[0]:self.cuts[1] + 1] = True
            for ax in self.ax2, self.ax3:
                ax.fill_between(
                    self.manager.times,
                    *ax.dataLim.intervaly,
                    mask,
                    facecolor="darkgray",
                    alpha=0.2,
                )
            trans = mtransforms.blended_transform_factory(
                self.ax_slider.transData, self.ax_slider.transAxes)
            self.ax_slider.vlines(np.flatnonzero(mask),
                                  0,
                                  0.5,
                                  color="darkorange",
                                  transform=trans)
        self.fig.canvas.draw_idle()

    def on_scroll(self, event):
        cur_xlim = self.ax1.get_xlim()
        cur_ylim = self.ax1.get_ylim()
        xdata = event.xdata
        ydata = event.ydata
        if event.button == "up":
            scale_factor = 0.5
        elif event.button == "down":
            scale_factor = 2
        else:  # This should never happen anyway
            scale_factor = 1

        self.ax1.set_xlim([
            xdata - (xdata - cur_xlim[0]) / scale_factor,
            xdata + (cur_xlim[1] - xdata) / scale_factor,
        ])
        self.ax1.set_ylim([
            ydata - (ydata - cur_ylim[0]) / scale_factor,
            ydata + (cur_ylim[1] - ydata) / scale_factor,
        ])
        self.fig.canvas.draw()

    def on_press(self, event):
        if event.key == "right":
            self.move_forward()
        elif event.key == "left":
            self.move_backward()
        elif event.key == "s":
            self.swap()
        elif event.key == "i":
            self.invert()
        elif event.key == "x":
            self.flag_frame()
            if len(self.cuts) > 1:
                self.cuts.sort()
                if self.picked_pair:
                    self.manager.tracklet_swaps[self.picked_pair][
                        self.cuts] = ~self.manager.tracklet_swaps[
                            self.picked_pair][self.cuts]
                    self.fill_shaded_areas()
                    self.cuts = []
                    self.ax_slider.lines = []
        elif event.key == "backspace":
            if not self.dps:  # Last flag deletion
                try:
                    self.cuts.pop()
                    self.ax_slider.lines.pop()
                    if not len(self.cuts) == 2:
                        self.clean_collections()
                except IndexError:
                    pass
            else:  # Smart point removal
                i = np.nanargmin([
                    self.calc_distance(*dp.point.center, event.xdata,
                                       event.ydata) for dp in self.dps
                ])
                closest_dp = self.dps[i]
                label = closest_dp.individual_names, closest_dp.bodyParts
                closest_dp.disconnect()
                closest_dp.point.remove()
                self.dps.remove(closest_dp)
                ind = self.manager._label_pairs.index(label)
                self.manager.xy[ind, self._curr_frame] = np.nan
                self.manager.prob[ind, self._curr_frame] = np.nan
            self.fig.canvas.draw_idle()
        elif event.key == "l":
            self.lasso_toggle.set_active(not self.lasso_toggle.get_active)
        elif event.key == "d":
            self.drag_toggle.set_active(not self.drag_toggle.get_active)
        elif event.key == "alt+right":
            self.player.forward()
        elif event.key == "alt+left":
            self.player.rewind()
        elif event.key == "tab":
            self.player.toggle()

    def move_forward(self):
        if self.curr_frame < self.manager.nframes - 1:
            self.curr_frame += 1
            self.slider.set_val(self.curr_frame)

    def move_backward(self):
        if self.curr_frame > 0:
            self.curr_frame -= 1
            self.slider.set_val(self.curr_frame)

    def swap(self):
        if self.picked_pair:
            swap_inds = self.manager.get_swap_indices(*self.picked_pair)
            inds = np.insert(swap_inds, [0, len(swap_inds)],
                             [0, self.manager.nframes - 1])
            if len(inds):
                ind = np.argmax(inds > self.curr_frame)
                self.manager.swap_tracklets(
                    *self.picked_pair, range(inds[ind - 1], inds[ind] + 1))
                self.display_traces()
                self.slider.set_val(self.curr_frame)

    def invert(self):
        if not self.picked_pair and len(self.picked) == 2:
            self.picked_pair = self.picked
        if self.picked_pair:
            self.manager.swap_tracklets(*self.picked_pair, [self.curr_frame])
            self.display_traces()
            self.slider.set_val(self.curr_frame)

    def on_pick(self, event):
        artist = event.artist
        if artist.axes == self.ax1:
            self.picked = list(event.ind)
        elif artist.axes == self.ax2:
            if isinstance(artist, plt.Line2D):
                self.picked = [self.lines_x.index(artist)]
        elif artist.axes == self.ax3:
            if isinstance(artist, plt.Line2D):
                self.picked = [self.lines_y.index(artist)]
        else:  # Click on the legend lines
            if self.picked:
                num_individual = self.leg.get_lines().index(artist)
                nrow = self.manager.tracklet2id.index(num_individual)
                inds = [
                    nrow + self.manager.to_num_bodypart(pick)
                    for pick in self.picked
                ]
                xy = self.manager.xy[self.picked]
                p = self.manager.prob[self.picked]
                mask = np.zeros(xy.shape[1], dtype=bool)
                if len(self.cuts) > 1:
                    mask[self.cuts[-2]:self.cuts[-1] + 1] = True
                    self.cuts = []
                    self.ax_slider.lines = []
                    self.clean_collections()
                else:
                    return
                sl_inds = np.ix_(inds, mask)
                sl_picks = np.ix_(self.picked, mask)
                old_xy = self.manager.xy[sl_inds].copy()
                old_prob = self.manager.prob[sl_inds].copy()
                self.manager.xy[sl_inds] = xy[:, mask]
                self.manager.prob[sl_inds] = p[:, mask]
                self.manager.xy[sl_picks] = old_xy
                self.manager.prob[sl_picks] = old_prob
        self.picked_pair = []
        if len(self.picked) == 1:
            for pair in self.manager.swapping_pairs:
                if self.picked[0] in pair:
                    self.picked_pair = pair
                    break
        self.clean_collections()
        self.display_traces()
        if self.picked_pair:
            self.fill_shaded_areas()
        self.slider.set_val(self.curr_frame)

    def on_click(self, event):
        """Handle mouse clicks on the trace axes and on the image axes.

        A left click on ax2/ax3 that does not hit any trace line moves the
        frame cursors and the slider to the clicked x position, clamped to
        ``[0, manager.nframes - 1]``. A click on ax1 that misses the scatter
        redraws the traces of all swapping bodyparts and removes the
        collections from the trace/slider axes.
        """
        free_left_click_on_traces = (
            event.inaxes in (self.ax2, self.ax3)
            and event.button == 1
            and not any(line.contains(event)[0]
                        for line in self.lines_x + self.lines_y))
        if free_left_click_on_traces:
            upper_clamped = min(event.xdata, self.manager.nframes - 1)
            frame = max(0, upper_clamped)
            self.update_vlines(frame)
            self.slider.set_val(frame)
            return
        if event.inaxes == self.ax1 and not self.scat.contains(event)[0]:
            self.display_traces(only_picked=False)
            self.clean_collections()

    def clean_collections(self):
        """Remove every collection currently drawn on ax2, ax3 and ax_slider."""
        doomed = []
        for axis in (self.ax2, self.ax3, self.ax_slider):
            # Snapshot first: removing mutates the axes' collection list.
            doomed.extend(axis.collections)
        for collection in doomed:
            collection.remove()

    def display_points(self, val):
        """Move the scatter markers to every keypoint's position in frame ``val``."""
        self.scat.set_offsets(self.manager.xy[:, val])

    def display_trails(self, val):
        """Draw the trajectory around frame ``val`` for each picked keypoint.

        The trail covers frames ``val - trail_len // 2`` (clamped at 0) up to
        ``val + trail_len // 2`` (exclusive). The clamp matters: a negative
        slice start is interpreted by NumPy as an offset from the end of the
        array, which would make the trail empty or wrong for the first
        frames of the video. Trails of keypoints that are not picked are
        cleared.
        """
        half = self.trail_len // 2
        sl = slice(max(0, val - half), val + half)
        for n, trail in enumerate(self.trails):
            if n in self.picked:
                xy = self.manager.xy[n, sl]
                trail.set_data(*xy.T)
            else:
                trail.set_data([], [])

    def display_traces(self, only_picked=True):
        """Show x/y time traces for a subset of keypoints and rescale the axes.

        Parameters
        ----------
        only_picked : bool
            If True, show the picked keypoints plus their swapping partner
            (``picked_pair``); otherwise show ``manager.swapping_bodyparts``.
            Traces of all other keypoints are emptied.
        """
        shown = (self.picked + list(self.picked_pair) if only_picked
                 else self.manager.swapping_bodyparts)
        for idx, (trace_x, trace_y) in enumerate(zip(self.lines_x,
                                                       self.lines_y)):
            if idx not in shown:
                trace_x.set_data([], [])
                trace_y.set_data([], [])
                continue
            trace_x.set_data(self.manager.times, self.manager.xy[idx, :, 0])
            trace_y.set_data(self.manager.times, self.manager.xy[idx, :, 1])
        for axis in (self.ax2, self.ax3):
            axis.relim()
            axis.autoscale_view()

    def display_help(self, event):
        if not self.help_text:
            self.help_text = """
            Key D: activate "drag" so you can adjust bodyparts in that particular frame
            Key I: invert the position of a pair of bodyparts
            Key L: toggle the lasso selector
            Key S: swap two tracklets
            Key X: cut swapping tracklets
            Left/Right arrow: navigate through the video
            Tab: play/pause the video
            Alt+Right/Left: fast forward/rewind
            """
            self.text = self.fig.text(
                0.5,
                0.5,
                self.help_text,
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=12,
                color="red",
            )
        else:
            self.help_text = ""
            self.text.remove()

    def update_vlines(self, val):
        """Place the vertical frame cursor of both trace plots at x = ``val``."""
        for cursor in (self.vline_x, self.vline_y):
            cursor.set_xdata([val, val])

    def on_change(self, val):
        """Slider callback: seek the video to frame ``int(val)`` and redraw.

        When the frame can be read, the draggable-points mode is switched off
        (if active) and the image, points, trails and frame cursors are
        updated. Nothing is redrawn when the frame cannot be read.
        """
        self.curr_frame = int(val)
        self.video.set(cv2.CAP_PROP_POS_FRAMES, self.curr_frame)
        img = self._read_frame()
        if img is None:
            return
        # Leaving the frame automatically disables the draggable points.
        if self.draggable:
            self.drag_toggle.set_active(False)
        self.im.set_array(img)
        frame = self.curr_frame
        self.display_points(frame)
        self.display_trails(frame)
        self.update_vlines(frame)

    def update_dotsize(self, val):
        """Store ``val`` as the dot size and apply ``val ** 2`` as the marker size."""
        self.dotsize = val
        marker_area = self.dotsize**2
        self.scat.set_sizes([marker_area])

    @staticmethod
    def calc_distance(x1, y1, x2, y2):
        return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)

    def save(self, *args):
        """Persist edits: store the current coordinates, then call the manager's save.

        Extra positional arguments (e.g. a widget event) are accepted and ignored.
        """
        self.save_coords()
        self.manager.save()

    def export_to_training_data(self, pcutoff=0.1):
        """Export manually edited frames as additional labeled training data.

        For every frame index returned by ``self.manager.find_edited_frames()``
        the video frame is written, if not already present, to
        ``<project_path>/labeled-data/<video name>/img<frame>.png``. It is
        cropped to ``[y1:y2, x1:x2]`` when ``cfg["cropping"]`` is set. The
        refined keypoints of those frames are then:

        * appended to ``machinelabels-iter<iteration>.h5`` and
          ``machinelabels.csv``, where already stored rows win on duplicate
          index;
        * merged, without the likelihood columns, into
          ``CollectedData_<scorer>.h5`` and ``.csv``, where the refined rows
          win on duplicate index.

        Frames whose image cannot be read are skipped entirely, so no label
        row points to a missing image.

        Parameters
        ----------
        pcutoff : float
            Keypoints whose likelihood is below this value are set to NaN.
        """
        import os

        inds = self.manager.find_edited_frames()
        if not len(inds):
            print("No frames have been manually edited.")
            return

        from skimage import io

        # Save additional frames to the labeled-data directory
        strwidth = int(np.ceil(np.log10(self.nframes)))
        vname = os.path.splitext(os.path.basename(self.videoname))[0]
        tmpfolder = os.path.join(self.manager.cfg["project_path"],
                                 "labeled-data", vname)
        if os.path.isdir(tmpfolder):
            print("Frames from video", vname,
                  " already extracted (more will be added)!")
        else:
            attempttomakefolder(tmpfolder)
        kept_inds = []
        index = []
        for ind in inds:
            imagename = os.path.join(tmpfolder,
                                     "img" + str(ind).zfill(strwidth) + ".png")
            if not os.path.isfile(imagename):
                self.video.set(cv2.CAP_PROP_POS_FRAMES, ind)
                frame = self._read_frame()
                if frame is None:
                    print("Frame could not be read. Skipping...")
                    continue
                frame = frame.astype(np.ubyte)
                if self.manager.cfg["cropping"]:
                    x1, x2, y1, y2 = [
                        int(self.manager.cfg[key])
                        for key in ("x1", "x2", "y1", "y2")
                    ]
                    frame = frame[y1:y2, x1:x2]
                io.imsave(imagename, frame)
            kept_inds.append(ind)
            # Rows are indexed by the last three path components:
            # labeled-data/<video>/img<frame>.png
            index.append(os.path.join(*imagename.rsplit(os.path.sep, 3)[-3:]))
        if not kept_inds:
            print("None of the edited frames could be read; nothing exported.")
            return

        # Store the newly-refined data; uncertain keypoints are ignored.
        data = self.manager.format_data()
        df = self._mask_low_likelihood(data.iloc[kept_inds], pcutoff)
        df.index = index
        machinefile = os.path.join(
            tmpfolder,
            "machinelabels-iter" + str(self.manager.cfg["iteration"]) + ".h5")
        if os.path.isfile(machinefile):
            df_old = pd.read_hdf(machinefile, "df_with_missing")
            df_joint = pd.concat([df_old, df])
            df_joint = df_joint[~df_joint.index.duplicated(keep="first")]
            df_joint.to_hdf(machinefile, key="df_with_missing", mode="w")
            df_joint.to_csv(os.path.join(tmpfolder, "machinelabels.csv"))
        else:
            df.to_hdf(machinefile, key="df_with_missing", mode="w")
            df.to_csv(os.path.join(tmpfolder, "machinelabels.csv"))

        # Merge with the already existing annotated data
        scorer = self.manager.cfg["scorer"]
        # MultiIndex.set_levels no longer accepts inplace= (removed in pandas 2.0)
        df.columns = df.columns.set_levels([scorer], level="scorer")
        df = df.drop("likelihood", level="coords", axis=1)
        output_path = os.path.join(tmpfolder, f"CollectedData_{scorer}.h5")
        # Swap only the extension; str.replace("h5", ...) would also rewrite
        # any "h5" occurring elsewhere in the path.
        csv_path = os.path.splitext(output_path)[0] + ".csv"
        if os.path.isfile(output_path):
            print(
                "A training dataset file is already found for this video. The refined machine labels are merged to this data!"
            )
            df_orig = pd.read_hdf(output_path, "df_with_missing")
            df_joint = pd.concat([df, df_orig])
            # Now drop redundant ones keeping the first one [this will make sure that the refined machine file gets preference]
            df_joint = df_joint[~df_joint.index.duplicated(keep="first")]
            df_joint = df_joint.sort_index()
            df_joint.to_hdf(output_path, key="df_with_missing", mode="w")
            df_joint.to_csv(csv_path)
        else:
            df = df.sort_index()
            df.to_hdf(output_path, key="df_with_missing", mode="w")
            df.to_csv(csv_path)

    @staticmethod
    def _mask_low_likelihood(df, pcutoff):
        """Return a copy of ``df`` in which every column of a low-confidence keypoint is NaN.

        A keypoint is identified by all column levels except 'coords'. Where
        its 'likelihood' is below ``pcutoff``, all of its 'coords' columns
        (including the likelihood itself) are set to NaN. This replaces a
        ``groupby(axis=1).apply`` that was deprecated in recent pandas and
        adds group keys as an extra column level there.
        """
        low = df.xs("likelihood", level="coords", axis=1) < pcutoff
        # Broadcast each keypoint's flag onto all of its coordinate columns.
        low = low.reindex(columns=df.columns.droplevel("coords"),
                          fill_value=False)
        return df.mask(low.to_numpy(dtype=bool))
Ejemplo n.º 5
0
class Curator:
    """
    Interactive matplotlib viewer for curating extracted neuron threads.

    For the currently selected thread the figure shows the full image with
    the thread position in red, a zoomed window around that position, and
    the thread's min/max-normalized time series. Widgets allow scrolling
    through time, adjusting the display range, stepping between threads,
    filtering which threads are visited, and marking threads as 'keep' or
    'trash'. The curation is written to
    ``e.root + 'extractor-objects/curate.json'`` at interpreter exit.

    Parameters
    ----------
    e : extractor
        Extractor object. Its ``spool``, ``timeseries``, ``im``, ``root``
        and ``t`` attributes are read.
    window : int, optional
        Size of the zoomed window passed to ``subaxis`` (default 100).

    Attributes
    ----------
    ind : int
        Index of the thread currently displayed.

    min : number
        Lower bound of the display range (initially the image minimum).

    max : number
        Upper bound of the display range (initially the image maximum).

    curate : dict
        Maps ``str(thread index)`` to 'seen', 'keep' or 'trash'. The key
        'last' holds the index displayed when the curation was last saved.
    """
    def __init__(self, e, window=100):
        # get info from extractors
        self.s = e.spool
        self.timeseries = e.timeseries
        self.tf = e.im
        self.tf.t = 0
        self.window = window
        # number of threads that can be browsed
        self.numneurons = len(self.s.threads)

        # Curation state persisted between sessions
        self.path = e.root + 'extractor-objects/curate.json'
        self.curate, self.ind = self._load_curate()

        # Display state (see update_pointstate / show / update_mipstate):
        #   pointstate    0 = current thread only, 1 = same Z, 2 = all threads
        #   show_settings 0 = all, 1 = unlabelled, 2 = kept, 3 = trashed
        #   showmip       0 = single Z plane, 1 = max over Z
        self.pointstate = 0
        self.show_settings = 0
        self.showmip = 0

        # index of the time point to display
        self.t = 0

        # First frame of the current thread
        self.update_im()

        # Initial display range taken from the first image
        self.min = np.min(self.im)
        self.max = np.max(self.im)

        # upper bound for the time slider (its range is 0 .. tmax - 1)
        self.tmax = e.t

        self.restart()
        atexit.register(self.log_curate)

    def _load_curate(self):
        """Return ``(curation dict, starting thread index)`` read from self.path.

        A missing, unreadable or malformed file starts a fresh curation in
        which thread 0 is marked 'seen'. A readable file is kept even when
        its 'last' entry is missing or invalid; the index then defaults to 0.
        The stored index is wrapped into the valid thread range.
        """
        try:
            with open(self.path) as f:
                curate = json.load(f)
        except (OSError, ValueError):
            # ValueError also covers json.JSONDecodeError (corrupt file)
            curate = None
        if not isinstance(curate, dict):
            return {'0': 'seen'}, 0
        try:
            ind = int(curate.get('last', 0))
        except (TypeError, ValueError):
            ind = 0
        if self.numneurons:
            ind %= self.numneurons
        return curate, ind

    def _position(self):
        """Position of the current thread at the current time point.

        In this class component [0] selects the Z plane, and components
        [2] and [1] are used as the x and y plot coordinates.
        """
        return self.s.threads[self.ind].get_position_t(self.t)

    def _other_positions(self, pos):
        """Positions of the other threads to overlay (pointstate 1 or 2)."""
        if self.pointstate == 1:
            return self.s.get_positions_t_z(self.t, pos[0])
        return self.s.get_positions_t(self.t)

    def _draw_ax1_points(self):
        """Create the blue overlay (if any) and the red current-thread dot on ax1."""
        pos = self._position()
        self.point1 = None
        if self.pointstate in (1, 2):
            others = self._other_positions(pos)
            self.point1 = self.ax1.scatter(others[:, 2],
                                           others[:, 1],
                                           c='b',
                                           s=10)
        self.thispoint = self.ax1.scatter(pos[2], pos[1], c='r', s=10)

    def _normalized_timeseries(self):
        """Current thread's time series rescaled to [0, 1].

        A constant series is returned as zeros instead of dividing by zero.
        """
        trace = self.timeseries[:, self.ind]
        lo = np.min(trace)
        span = np.max(trace) - lo
        if span == 0:
            return np.zeros(np.shape(trace), dtype=float)
        return (trace - lo) / span

    def restart(self):
        """Build the figure with its three plots and all widgets, then show it."""
        self.fig = plt.figure()

        # grid object for complicated subplot handling
        self.grid = plt.GridSpec(4, 2, wspace=0.1, hspace=0.2)
        pos = self._position()

        ### First subplot: whole image with the current thread in red
        self.ax1 = plt.subplot(self.grid[:3, 0])
        plt.subplots_adjust(bottom=0.4)
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        self._draw_ax1_points()
        self.ax1.axis('off')

        ### Second subplot: some window around the ROI
        plt.subplot(self.grid[:3, 1])
        plt.subplots_adjust(bottom=0.4)

        self.subim, self.offset = subaxis(self.im, pos, self.window)

        self.img2 = plt.imshow(self.get_subim_display(),
                               cmap='gray',
                               vmin=0,
                               vmax=1)
        self.point2 = plt.scatter(self.window / 2 + self.offset[0],
                                  self.window / 2 + self.offset[1],
                                  c='r',
                                  s=40)

        self.title = self.fig.suptitle('Series=' + str(self.ind) + ', Z=' +
                                       str(int(pos[0])))
        plt.axis("off")

        ### Third subplot: plotting the timeseries
        self.timeax = plt.subplot(self.grid[3, :])
        plt.subplots_adjust(bottom=0.4)
        self.timeplot, = self.timeax.plot(self._normalized_timeseries())
        plt.axis("off")

        ### Axis for scrolling through t
        self.tr = plt.axes([0.2, 0.15, 0.3, 0.03],
                           facecolor='lightgoldenrodyellow')
        self.s_tr = Slider(self.tr,
                           'Timepoint',
                           0,
                           self.tmax - 1,
                           valinit=0,
                           valstep=1)
        self.s_tr.on_changed(self.update_t)

        ### Axis for setting min/max range
        self.minr = plt.axes([0.2, 0.2, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.sminr = Slider(self.minr,
                            'R Min',
                            0,
                            np.max(self.im),
                            valinit=self.min,
                            valstep=1)
        self.maxr = plt.axes([0.2, 0.25, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.smaxr = Slider(self.maxr,
                            'R Max',
                            0,
                            np.max(self.im) * 4,
                            valinit=self.max,
                            valstep=1)
        self.sminr.on_changed(self.update_mm)
        self.smaxr.on_changed(self.update_mm)

        ### Buttons for next/previous thread
        self.axprev = plt.axes([0.62, 0.20, 0.1, 0.075])
        self.axnext = plt.axes([0.75, 0.20, 0.1, 0.075])
        self.bnext = Button(self.axnext, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(self.axprev, 'Previous')
        self.bprev.on_clicked(self.prev)

        #### Which other threads to overlay on the left image
        self.pointsax = plt.axes([0.75, 0.10, 0.1, 0.075])
        self.pointsbutton = RadioButtons(self.pointsax,
                                         ('Single', 'Same Z', 'All'))
        self.pointsbutton.set_active(self.pointstate)
        self.pointsbutton.on_clicked(self.update_pointstate)

        #### Whether to display a single plane or the max over Z on the left
        self.mipax = plt.axes([0.62, 0.10, 0.1, 0.075])
        self.mipbutton = RadioButtons(self.mipax, ('Single Z', 'MIP'))
        self.mipbutton.set_active(self.showmip)
        self.mipbutton.on_clicked(self.update_mipstate)

        ### Keep / Trash labels
        self.keepax = plt.axes([0.87, 0.20, 0.075, 0.075])
        self.keep_button = CheckButtons(self.keepax, ['Keep', 'Trash'],
                                        [False, False])
        self.keep_button.on_clicked(self.keep)

        ### Filter for which threads Next/Previous visit
        self.showax = plt.axes([0.87, 0.10, 0.075, 0.075])
        self.showbutton = RadioButtons(
            self.showax, ('All', 'Unlabelled', 'Kept', 'Trashed'))
        self.showbutton.set_active(self.show_settings)
        self.showbutton.on_clicked(self.show)

        plt.show()

    def __del__(self):
        # Best-effort save on deletion. Skip it if __init__ failed before the
        # curation state existed.
        if hasattr(self, 'path') and hasattr(self, 'curate'):
            self.log_curate()

    def update_im(self):
        """Load the image for time self.t: max over Z (MIP) or the thread's plane."""
        if self.showmip:
            self.im = np.max(self.tf.get_t(self.t), axis=0)
        else:
            self.im = self.tf.get_tbyf(self.t, int(self._position()[0]))

    def get_im_display(self):
        """Full image rescaled by the current display range."""
        return (self.im - self.min) / (self.max - self.min)

    def get_subim_display(self):
        """Zoomed window rescaled by the current display range."""
        return (self.subim - self.min) / (self.max - self.min)

    def update_figures(self):
        """Refresh both images, all markers and the title for the current state."""
        pos = self._position()
        self.subim, self.offset = subaxis(self.im, pos, self.window)
        self.img1.set_data(self.get_im_display())

        if self.pointstate in (1, 2):
            others = self._other_positions(pos)
            self.point1.set_offsets(np.array([others[:, 2], others[:, 1]]).T)
        self.thispoint.set_offsets([pos[2], pos[1]])
        # Target ax1 explicitly: plt.axis would act on the last-created
        # (widget) axes instead.
        self.ax1.axis('off')

        self.img2.set_data(self.get_subim_display())
        self.point2.set_offsets([
            self.window / 2 + self.offset[0], self.window / 2 + self.offset[1]
        ])
        self.title.set_text('Series=' + str(self.ind) + ', Z=' +
                            str(int(pos[0])))
        plt.draw()

    def update_timeseries(self):
        """Plot the normalized time series of the current thread."""
        self.timeplot.set_ydata(self._normalized_timeseries())
        plt.draw()

    def update_t(self, val):
        """Time-slider callback: show time point ``int(val)``."""
        # The slider has valstep=1, so int() only normalizes the numeric type.
        self.t = int(val)
        self.update_im()
        self.update_figures()

    def update_mm(self, val):
        """Range-slider callback: apply the R Min / R Max values."""
        self.min = self.sminr.val
        self.max = self.smaxr.val
        self.update_figures()

    def next(self, event):
        """Go to the next thread allowed by the current filter."""
        self.set_index_next()
        self.update_im()
        self.update_figures()
        self.update_timeseries()
        self.update_buttons()
        self.update_curate()

    def prev(self, event):
        """Go to the previous thread allowed by the current filter."""
        self.set_index_prev()
        self.update_im()
        self.update_figures()
        self.update_timeseries()
        self.update_buttons()
        self.update_curate()

    def log_curate(self):
        """Write the curation, including the current index as 'last', to self.path."""
        self.curate['last'] = self.ind
        with open(self.path, 'w') as fp:
            json.dump(self.curate, fp)

    def keep(self, event):
        """Keep/Trash check-button callback.

        If exactly one box is checked, store 'keep' or 'trash' for the
        current thread. Otherwise toggle every checked box via set_active.
        With matplotlib's default eventson, set_active re-invokes this
        callback.
        """
        status = self.keep_button.get_status()
        if np.sum(status) != 1:
            for i in range(len(status)):
                if status[i] != False:
                    self.keep_button.set_active(i)

        else:
            if status[0]:
                self.curate[str(self.ind)] = 'keep'
            elif status[1]:
                self.curate[str(self.ind)] = 'trash'

    def update_buttons(self):
        """Make the Keep/Trash boxes match the stored label of the current thread."""
        curr = self.keep_button.get_status()
        future = [False for i in range(len(curr))]
        label = self.curate.get(str(self.ind))
        if label == 'keep':
            future[0] = True
        elif label == 'trash':
            future[1] = True

        for i in range(len(curr)):
            if curr[i] != future[i]:
                self.keep_button.set_active(i)

    def show(self, label):
        """Filter radio callback: choose which threads Next/Previous visit."""
        d = {'All': 0, 'Unlabelled': 1, 'Kept': 2, 'Trashed': 3}
        self.show_settings = d[label]

    def _step_index(self, step):
        """Move self.ind by ``step`` (+1 or -1), wrapping around the thread count.

        With show_settings 1, 2 or 3 the move continues until a thread is
        reached whose label is neither 'keep' nor 'trash', is 'keep', or is
        'trash', respectively. Each candidate index is wrapped before its
        label is looked up, so the search also continues past either end.
        If no thread matches, the full cycle ends back at the current index.
        """
        n = self.numneurons
        if not n:
            return
        if self.show_settings == 0:
            self.ind = (self.ind + step) % n
            return
        if self.show_settings == 1:
            def wanted(label):
                return label not in ('keep', 'trash')
        elif self.show_settings == 2:
            def wanted(label):
                return label == 'keep'
        else:
            def wanted(label):
                return label == 'trash'
        ind = self.ind
        for _ in range(n):
            ind = (ind + step) % n
            if wanted(self.curate.get(str(ind))):
                break
        self.ind = ind

    def set_index_prev(self):
        """Step to the previous thread that matches the current filter."""
        self._step_index(-1)

    def set_index_next(self):
        """Step to the next thread that matches the current filter."""
        self._step_index(1)

    def update_curate(self):
        """Mark the current thread 'seen' unless it already has a label."""
        if self.curate.get(str(self.ind)) not in ['keep', 'seen', 'trash']:
            self.curate[str(self.ind)] = 'seen'

    def update_pointstate(self, label):
        """Overlay radio callback: choose which other threads are drawn."""
        d = {
            'Single': 0,
            'Same Z': 1,
            'All': 2,
        }
        self.pointstate = d[label]
        self.update_point1()
        self.update_figures()

    def update_point1(self):
        """Redraw ax1 from scratch with the overlay for the current pointstate."""
        self.ax1.clear()
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        self._draw_ax1_points()
        # clear() re-enables ax1's axis, so switch it off on ax1 itself.
        self.ax1.axis('off')

    def update_mipstate(self, label):
        """MIP radio callback: toggle between a single plane and the max over Z."""
        d = {
            'Single Z': 0,
            'MIP': 1,
        }
        self.showmip = d[label]

        self.update_im()
        self.update_figures()
Ejemplo n.º 6
0
class ScanData:
    '''
    Collect raw data for energy scan experiments and provides methods to
    average them.

    Attributes
    ----------
    label : list (str)
        Labels for graphs.

    idx : list (str)
        Scan indexes

    raw_imp : pandas DataFrame
        Collect imported raw data.

    energy : array
        common energy scale for average of scans.

    lines : list (2d line objects)
        Collect lines object of raw scan for plotting.

    blins : list (2d line objects)
        Collect lines object of raw scan for plotting.

    avg_ln : 2d line objects
        Lines object of average of scans.

    plab : list (str)
        Collect the list of labels plotted in graphs for choosing scans.

    checkbx : CheckButtons obj
        Widget for choosing plots.

    aver : array
        Data average of selected scans from raw_imp.
        If ScanData is a reference one, aver contain normalized data by
        reference.

    dtype : str
        Identifies the data collected, used for graph labelling:
        sigma+, sigma- for XMCD
        CR, CL for XNCD
        H-, H+ for XNXD
        LH, LV fot XNLD

    chsn_scns : list (str)
        Labels of chosen scans for the analysis.

    pe_av : float
        Value of spectra at pre-edge energy. It is obtained from
        averaging data in an defined energy range centered at pre-edge
        energy.

    pe_av_int : float
        Pre-edge value obtained from linear interpolation considering
        pre-edge and post-edge energies.

    bsl : Univariate spline object
        spline interpolation of ArpLS baseline 

    norm : array
        Averaged data normalized by value at pre-edge energy.

    norm_int : array
        Averaged data normalized by interpolated pre-edge value.

    ej : float
        Edge-jump value.

    ej_norm : float
            edge-jump value normalized by value at pre-edge energy.

    ej_int : float
        edge-jump value computed with interpolated pre-edge value.

    ej_norm_int : float
        edge-jump computed and normalized by interpolated pre-edge
        value.

    Methods
    -------
    man_aver_e_scans(guiobj, enrg)
        Manage the choice of scans to be averaged and return the average
        of selected scans.

    aver_e_scans(enrg, chsn, guiobj)
        Performe the average of data scans.

    check_but(label)
        When check buttons are checked switch the visibility of
        corresponding line.

    averbut(event)
        When average button is pressed calls aver_e_scans to compute
        average on selected scans.

    reset(event)
        Reset graph to starting conditions.

    finish_but(self, event)
        Close the figure on pressing the button Finish.

    edge_norm(guiobj, enrg, e_edge, e_pe, pe_rng, pe_int)
        Normalize energy scan data by value at pre-edge energy and
        compute edge-jump.
    '''
    def __init__(self):
        '''
        Initialize attributes label, idx, and raw_imp.
        '''
        self.label = []
        self.idx = []
        self.raw_imp = pd.DataFrame()

    def man_aver_e_scans(self, guiobj, enrg):
        '''
        Manage the choice of scans to be averaged and return the average
        of selected scans.

        Parameters
        ----------
        guiobj : GUI object
            Provides GUI dialogs.

        enrg : array
            Energy values at which average is calculated.

        Return
        ------
        Set class attributes:
        aver : array
            Average values of the chosen scans.

        chsn_scns : list
            Labels of chosen scans for the analysis (for log purpose).
        '''
        self.energy = enrg
        self.chsn_scns = []
        self.aver = 0

        if guiobj.interactive:  # Interactive choose of scans
            fig, ax = plt.subplots(figsize=(10, 6))
            fig.subplots_adjust(right=0.75)

            ax.set_xlabel('E (eV)')
            ax.set_ylabel(self.dtype)

            if guiobj.infile_ref:
                fig.suptitle('Choose reference sample scans')
            else:
                fig.suptitle('Choose sample scans')
            # Initialize list which will contains line obj of scans
            # lines contain colored lines for choose
            # blines contain dark lines to be showed with average
            self.lines = []
            self.blines = []
            # Populate list with line objs
            for i in self.idx:
                e_col = 'E' + i
                # Show lines and not blines
                self.lines.append(
                    ax.plot(self.raw_imp[e_col],
                            self.raw_imp[i],
                            label=self.label[self.idx.index(i)])[0])
                self.blines.append(
                    ax.plot(self.raw_imp[e_col],
                            self.raw_imp[i],
                            color='dimgrey',
                            visible=False)[0])
            # Initialize chsn_scs and average line with all scans and
            # set it invisible
            for line in self.lines:
                if line.get_visible():
                    self.chsn_scns.append(line.get_label())
            self.aver_e_scans()
            self.avg_ln, = ax.plot(self.energy,
                                   self.aver,
                                   color='red',
                                   lw=2,
                                   visible=False)
            # Create box for checkbutton
            chax = fig.add_axes([0.755, 0.32, 0.24, 0.55], facecolor='0.95')
            self.plab = [str(line.get_label()) for line in self.lines]
            visibility = [line.get_visible() for line in self.lines]
            self.checkbx = CheckButtons(chax, self.plab, visibility)
            # Customizations of checkbuttons
            rxy = []
            bxh = 0.05
            for r in self.checkbx.rectangles:
                r.set(height=bxh)
                r.set(width=bxh)
                rxy.append(r.get_xy())
            for i in range(len(rxy)):
                self.checkbx.lines[i][0].set_xdata(
                    [rxy[i][0], rxy[i][0] + bxh])
                self.checkbx.lines[i][0].set_ydata(
                    [rxy[i][1], rxy[i][1] + bxh])
                self.checkbx.lines[i][1].set_xdata(
                    [rxy[i][0] + bxh, rxy[i][0]])
                self.checkbx.lines[i][1].set_ydata(
                    [rxy[i][1], rxy[i][1] + bxh])

            for l in self.checkbx.labels:
                l.set(fontsize='medium')
                l.set_verticalalignment('center')
                l.set_horizontalalignment('left')

            self.checkbx.on_clicked(self.check_but)

            # Create box for average reset and finish buttons
            averbox = fig.add_axes([0.77, 0.2, 0.08, 0.08])
            bnaver = Button(averbox, 'Average')
            bnaver.on_clicked(self.averbut)
            rstbox = fig.add_axes([0.89, 0.2, 0.08, 0.08])
            bnrst = Button(rstbox, 'Reset')
            bnrst.on_clicked(self.reset)
            finbox = fig.add_axes([0.82, 0.07, 0.12, 0.08])
            bnfinish = Button(finbox, 'Finish')
            bnfinish.on_clicked(self.finish_but)

            ax.legend()
            plt.show()

            # If average is not pressed automatically compute average on
            # selected scans
            if self.chsn_scns == []:
                for line in self.lines:
                    if line.get_visible():
                        self.chsn_scns.append(line.get_label())
                self.aver_e_scans()
        else:
            # Not-interactive mode: all scans except 'Dummy Scans' are
            # evaluated
            for lbl in self.label:
                # Check it is not a 'Dummy Scan' and append
                # corresponding scan number in chosen scan list
                if not ('Dummy' in lbl):
                    self.chsn_scns.append(self.idx[self.label.index(lbl)])

            self.aver_e_scans()

    def check_but(self, label):
        '''
        Callback for the scan check buttons.

        Flip the visibility of the scan line that belongs to the clicked
        check button, then rebuild self.chsn_scns as the labels of every
        scan line that is now visible (in plotting order) and redraw.

        Parameters
        ----------
        label : str
            Text of the clicked check button; one entry of self.plab.
        '''
        clicked = self.lines[self.plab.index(label)]
        clicked.set_visible(not clicked.get_visible())
        # The chosen scans are exactly the currently visible lines
        self.chsn_scns = [ln.get_label() for ln in self.lines
                          if ln.get_visible()]
        plt.draw()

    def averbut(self, event):
        '''
        Callback for the "Average" button.

        Every visible scan is moved to the background: its coloured line is
        hidden, its grey twin in self.blines is shown, and its label is
        recorded in self.chsn_scns. Then self.aver is recomputed with
        aver_e_scans and the red average line is updated and shown.

        If no scan line is visible the call does nothing. This happens when
        nothing is ticked, or when Average was already pressed and the
        coloured lines are hidden. Without the guard, the previous selection
        would be wiped and the average of an empty set (NaN) would be drawn.

        Parameters
        ----------
        event : matplotlib event
            Unused; required by the Button.on_clicked callback signature.
        '''
        # Indices of the scans currently shown, i.e. chosen for averaging
        chosen = [i for i, line in enumerate(self.lines) if line.get_visible()]
        if not chosen:
            # Nothing to average: keep the previous selection and average
            return

        self.chsn_scns = []
        for i in chosen:
            self.lines[i].set(visible=False)
            self.chsn_scns.append(self.lines[i].get_label())
            self.blines[i].set(visible=True)

        self.aver_e_scans()

        # Update average line and make it visible
        self.avg_ln.set_ydata(self.aver)
        self.avg_ln.set(visible=True)

        plt.draw()

    def aver_e_scans(self):
        '''
        Average the chosen energy scans on the common energy scale.

        Each scan whose label is in self.chsn_scns is interpolated with a
        linear spline and evaluated on self.energy. The interpolated curves
        are then averaged point by point.

        Reads
        -----
        self.idx : list of str
            Scan identifiers. Scan ``i`` has its data in column ``i`` of
            self.raw_imp and its energies in column ``'E' + i``.
        self.label : list
            Scan labels, aligned position-by-position with self.idx.
        self.chsn_scns : list
            Labels of the scans to be averaged.
        self.raw_imp : mapping of column name -> array-like
            Imported raw data. The first element of every column is skipped
            (NOTE(review): presumably a header/placeholder row — confirm).
        self.energy : array
            Common energy scale on which the average is computed.

        Sets
        ----
        self.aver : array
            Point-wise average of the interpolated chosen scans, one value
            per entry of self.energy. If no scan is chosen, np.average of an
            empty list yields NaN rather than raising.

        Notes
        -----
        Interpolation uses itp.UnivariateSpline with k=1 and s=0, i.e. a
        linear spline passing through every data point.
        '''
        intrp = []

        for i in self.idx:
            if self.label[self.idx.index(i)] not in self.chsn_scns:
                continue
            e_col = 'E' + i
            # Chosen scan: drop the first element of each column
            x = self.raw_imp[e_col][1:]
            y = self.raw_imp[i][1:]

            # Linear interpolating spline, evaluated on the common scale
            y_int = itp.UnivariateSpline(x, y, k=1, s=0)
            intrp.append(y_int(self.energy))

        # Average all interpolated scans
        self.aver = np.average(intrp, axis=0)

    def reset(self, event):
        '''
        Callback for the "Reset" button.

        Bring the scan-selection plot back to its starting look:
        - hide the red average line;
        - tick every check button that is currently unticked (each tick
          fires the check_but callback);
        - show every coloured scan line and hide every grey background line.

        Parameters
        ----------
        event : matplotlib event
            Unused; required by the Button.on_clicked callback signature.
        '''
        self.avg_ln.set(visible=False)

        # Status is sampled once, before any box is toggled
        unticked = [pos for pos, ticked in enumerate(self.checkbx.get_status())
                    if not ticked]
        for pos in unticked:
            self.checkbx.set_active(pos)

        # Every spectrum in colour, none in the background
        for pos, line in enumerate(self.lines):
            line.set(visible=True)
            self.blines[pos].set(visible=False)

        plt.draw()

    def finish_but(self, event):
        '''
        Callback for the "Finish" button: close the current figure.

        Parameters
        ----------
        event : matplotlib event
            Unused; required by the Button.on_clicked callback signature.
        '''
        plt.close()

    def edge_norm(self, guiobj, enrg, e_edge, e_pe, e_poste, pe_rng):
        '''
        Normalize the averaged scan (self.aver) by its pre-edge value and
        compute the edge jump, both without and with a baseline.

        Without baseline, the edge jump is the spline-interpolated value at
        the edge energy minus the averaged pre-edge value.
        With a baseline, the reference value at the edge energy is taken
        from one of two sources:
        - if guiobj.bsl_int is true, the baseline callable self.bsl
          (asymmetrically reweighted penalized least squares, arPLS),
          evaluated at e_edge;
        - otherwise, the straight line through the interpolated values at
          the pre-edge and post-edge energies, evaluated at e_edge.

        Parameters
        ----------
        guiobj : GUI object
            Only ``guiobj.bsl_int`` is read: it selects the arPLS baseline
            (truthy) or the linear baseline (falsy).
        enrg : array
            Energy values of the scan, same length as self.aver. Must be
            increasing, as required by itp.UnivariateSpline.
        e_edge : float
            Edge energy value.
        e_pe : float
            Pre-edge energy value.
        e_poste : float
            Post-edge energy value (used only by the linear baseline).
        pe_rng : int
            Number of points on each side of the point nearest to e_pe that
            are included in the pre-edge average.

        Sets
        ----
        pe_av : float
            Averaged value at pre-edge energy.
        norm : array
            self.aver normalized by pe_av.
        ej : float
            Edge-jump value (no baseline).
        ej_norm : float
            ej normalized by pe_av.
        pe_av_int : float
            Baseline value at edge energy.
        norm_int : array
            self.aver normalized by pe_av_int.
        ej_int : float
            Edge jump measured from the baseline.
        ej_norm_int : float
            ej_int normalized by pe_av_int.

        Notes
        -----
        To reduce noise, the pre-edge value is the mean of up to
        2 * pe_rng + 1 points centred on the point nearest to e_pe. The
        window is clipped at the start and end of the data.
        The value at edge energy is obtained by cubic spline interpolation
        of the data (itp.UnivariateSpline with k=3 and s=0).
        '''
        # Index of the data point nearest to the pre-edge energy
        pe_idx = np.argmin(np.abs(enrg - e_pe))
        # Averaging window [lpe_idx, rpe_idx). The left edge is clamped at 0:
        # a negative slice start would wrap around to the end of the array,
        # giving an empty or wrong window when e_pe is near the first point.
        lpe_idx = max(int(pe_idx - pe_rng), 0)
        rpe_idx = int(pe_idx + pe_rng + 1)

        # Average of values for computation of pre-edge
        self.pe_av = np.average(self.aver[lpe_idx:rpe_idx])

        # Cubic spline interpolation of energy scan
        y_int = itp.UnivariateSpline(enrg, self.aver, k=3, s=0)
        # Value at edge energy from interpolation
        y_edg = y_int(e_edge)

        # Edge-jump computations - no baseline
        self.ej = y_edg - self.pe_av
        self.ej_norm = self.ej / self.pe_av
        # Normalization by pre-edge value
        self.norm = self.aver / self.pe_av

        # Edge-jump computations - with baseline
        if guiobj.bsl_int:
            # arPLS baseline evaluated at the edge energy
            self.pe_av_int = self.bsl(e_edge)
        else:
            # Linear baseline through pre- and post-edge points, evaluated
            # at the edge energy
            x = [e_pe, e_poste]
            y = [y_int(e_pe), y_int(e_poste)]
            self.pe_av_int = lin_interpolate(x, y, e_edge)

        # Normalization by baseline value
        self.norm_int = self.aver / self.pe_av_int

        self.ej_int = y_edg - self.pe_av_int
        self.ej_norm_int = self.ej_int / self.pe_av_int
Ejemplo n.º 7
0
class FunctionalMRIInterface(T1MriInterface):
    """Interface for the review of fMRI images.

    On top of T1MriInterface this adds:
    - a set of issue check boxes (several may be ticked at once);
    - keyboard shortcuts;
    - two levels of zoom driven by mouse clicks.
    """
    def __init__(self,
                 fig,
                 axes,
                 issue_list=cfg.func_mri_default_issue_list,
                 next_button_callback=None,
                 quit_button_callback=None,
                 right_arrow_callback=None,
                 left_arrow_callback=None,
                 zoom_in_callback=None,
                 zoom_out_callback=None,
                 right_click_callback=None,
                 show_stdev_callback=None,
                 axes_to_zoom=None,
                 total_num_layers=5):
        """Constructor.

        Parameters
        ----------
        fig, axes
            Figure and axes, passed on to T1MriInterface.
        issue_list : list of str
            Labels of the issue check boxes. It must contain
            cfg.func_mri_pass_indicator (looked up in add_checkboxes).
        *_callback : callable or None
            Handlers invoked from on_mouse / on_keyboard / reset_figure.
        axes_to_zoom : collection of Axes or None
            Axes that may be maximized (nested zoom) while already zoomed
            in. None means no axis supports the nested zoom.
        total_num_layers : int
            A maximized axis is given z-order total_num_layers + 1 so it is
            drawn above the other layers.
        """

        super().__init__(fig, axes, issue_list, next_button_callback,
                         quit_button_callback)
        self.issue_list = issue_list

        self.prev_axis = None
        self.prev_ax_pos = None
        self.prev_ax_zorder = None
        self.prev_ax_alpha = None
        self.prev_visible = False
        self.zoomed_in = False
        self.nested_zoomed_in = False
        self.total_num_layers = total_num_layers
        self.axes_to_zoom = axes_to_zoom

        self.next_button_callback = next_button_callback
        self.quit_button_callback = quit_button_callback
        self.zoom_in_callback = zoom_in_callback
        self.zoom_out_callback = zoom_out_callback
        self.right_arrow_callback = right_arrow_callback
        self.left_arrow_callback = left_arrow_callback
        self.right_click_callback = right_click_callback
        self.show_stdev_callback = show_stdev_callback

        self.add_checkboxes()

        # this list of artists to be populated later
        # makes it handy to clean them all
        self.data_handles = list()

    def add_checkboxes(self):
        """
        Checkboxes offer the ability to select multiple tags such as Motion,
        Ghosting, Aliasing etc., instead of one from a list of mutually
        exclusive rating options (such as Good, Bad, Error etc.).

        All boxes start unticked. Ticking one calls self.save_issues.
        Raises ValueError if issue_list lacks cfg.func_mri_pass_indicator.
        """

        ax_checkbox = plt.axes(cfg.position_checkbox,
                               facecolor=cfg.color_rating_axis)
        # initially de-activating all
        actives = [False] * len(self.issue_list)
        self.checkbox = CheckButtons(ax_checkbox,
                                     labels=self.issue_list,
                                     actives=actives)
        self.checkbox.on_clicked(self.save_issues)
        for txt_lbl in self.checkbox.labels:
            txt_lbl.set(**cfg.checkbox_font_properties)

        # NOTE(review): CheckButtons.rectangles / .lines are deprecated in
        # newer Matplotlib releases (3.7+); confirm the pinned version.
        for rect in self.checkbox.rectangles:
            rect.set_width(cfg.checkbox_rect_width)
            rect.set_height(cfg.checkbox_rect_height)

        # lines is a list of n crosses, each cross (x) defined by a tuple of lines
        for x_line1, x_line2 in self.checkbox.lines:
            x_line1.set_color(cfg.checkbox_cross_color)
            x_line2.set_color(cfg.checkbox_cross_color)

        self._index_pass = self.issue_list.index(cfg.func_mri_pass_indicator)

    def maximize_axis(self, ax):
        """Zoom ``ax`` to cfg.zoomed_position_level2 (nested zoom).

        Saves the axis position, z-order and background-patch alpha so that
        restore_axis can undo the change. Does nothing if an axis is
        already maximized.
        """

        if self.nested_zoomed_in:
            return
        self.prev_ax_pos = ax.get_position()
        self.prev_ax_zorder = ax.get_zorder()
        # Save the alpha of the background patch: that is what is changed
        # below, so that is what restore_axis must put back.
        self.prev_ax_alpha = ax.patch.get_alpha()
        ax.set_position(cfg.zoomed_position_level2)
        ax.set_zorder(self.total_num_layers + 1)  # bring forth
        ax.patch.set_alpha(1.0)  # opaque
        self.nested_zoomed_in = True
        self.prev_axis = ax

    def restore_axis(self):
        """Undo maximize_axis, if a nested zoom is currently active."""

        if self.nested_zoomed_in:
            self.prev_axis.set(position=self.prev_ax_pos,
                               zorder=self.prev_ax_zorder)
            self.prev_axis.patch.set_alpha(self.prev_ax_alpha)
            self.nested_zoomed_in = False

    def on_mouse(self, event):
        """Callback for mouse events.

        Clicks on the check boxes, notes box or buttons are ignored.
        When not zoomed in:
        - a right click calls right_click_callback;
        - a double click on an axis calls zoom_in_callback.
        When zoomed in:
        - a double or right click on an axis in axes_to_zoom maximizes it;
        - a double or right click anywhere else calls zoom_out_callback;
        - any other click undoes the nested zoom if one is active, and
          otherwise calls zoom_out_callback.
        """

        # if event occurs in non-data areas, do nothing
        if event.inaxes in [
                self.checkbox.ax, self.text_box.ax, self.bt_next.ax,
                self.bt_quit.ax
        ]:
            return

        if self.zoomed_in:
            if event.dblclick or event.button == 3:
                # axes_to_zoom defaults to None: no axis can be maximized
                if (self.axes_to_zoom is not None
                        and event.inaxes in self.axes_to_zoom):
                    self.maximize_axis(event.inaxes)
                else:
                    self.zoom_out_callback(event)
            elif self.nested_zoomed_in:
                self.restore_axis()
            else:
                self.zoom_out_callback(event)
        elif event.button == 3:
            self.right_click_callback(event)
        elif event.dblclick and event.inaxes is not None:
            self.zoom_in_callback(event)

        # redraw the figure - important
        self.fig.canvas.draw_idle()

    def on_keyboard(self, key_in):
        """Callback to handle keyboard shortcuts to rate and advance.

        Key bindings:
        - right/up and left/down call the arrow callbacks;
        - space calls next_button_callback;
        - ctrl+q calls quit_button_callback;
        - alt+s calls show_stdev_callback;
        - a key listed in cfg.abbreviation_func_mri_default_issue_list
          toggles the matching issue check box.
        Keys pressed while the mouse is over the notes text box are ignored.
        """

        # ignore keyboard key_in when mouse within Notes textbox
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        if key_pressed in ['right', 'up']:
            self.right_arrow_callback()
        elif key_pressed in ['left', 'down']:
            self.left_arrow_callback()
        elif key_pressed in [' ', 'space']:
            self.next_button_callback()
        elif key_pressed in ['ctrl+q', 'q+ctrl']:
            self.quit_button_callback()
        elif key_pressed in ['alt+s', 's+alt']:
            self.show_stdev_callback()
        elif key_pressed in cfg.abbreviation_func_mri_default_issue_list:
            checked_label = cfg.abbreviation_func_mri_default_issue_list[
                key_pressed]
            # Shortcuts map to the *default* issue names. With a custom
            # issue_list the label may be absent: ignore the key instead of
            # letting list.index raise ValueError.
            if checked_label in self.issue_list:
                self.checkbox.set_active(self.issue_list.index(checked_label))

        self.fig.canvas.draw_idle()

    def reset_figure(self):
        """Resets the figure to prepare it for display of next subject."""

        self.zoom_out_callback(None)
        self.restore_axis()
        self.clear_data()
        self.clear_checkboxes()
        self.clear_notes_annot()
Ejemplo n.º 8
0
def plot_all(filename):
    """Plot every trend in ``varNames`` as step lines on one shared axis.

    A column of check buttons on the left lets the user toggle each trend.
    Every box is toggled once after creation, so trends start hidden —
    assuming ControlCheckFunc hides a line when its box is unticked. The
    window is maximized when the GUI backend supports it. The call blocks in
    plt.show() until the figure is closed.

    Parameters
    ----------
    filename : str
        Shown in the figure title.

    Side effects
    ------------
    Rebinds the module globals trendsSubplots, labels, ax and fig, which
    ControlCheckFunc reads. Appends one line per trend to trendsSubplots,
    which presumably was emptied by ClearPlotGlobals().
    """
    # Globals modified here; ControlCheckFunc uses them
    global trendsSubplots, labels, ax, fig

    ClearPlotGlobals()  # clear data of the previous plot

    # One shared plot for all trends (hard-coded design choice)
    fig, ax = plt.subplots()

    print('')  # empty line
    print('Drawing new trend')

    for trendNumber, varName1 in enumerate(varNames):
        # step() returns a list holding one line; unpack it
        plotTmp, = ax.step(dateTimes[varName1], varValues[varName1],
                           visible=True, lw=2,
                           color=getPlotColor(trendNumber),
                           label=TranslateTagName(varName1))
        trendsSubplots.append(plotTmp)
        print('Drawing subtrend ' + str(trendNumber + 1) + ': ' + varName1 +
              ' alias ' + TranslateTagName(varName1))

    # Start the plot further right to leave room for the check-button legend
    plt.subplots_adjust(left=0.28)
    plt.suptitle('Trends for ' + filename)
    plt.xlabel('Time')
    plt.ylabel('Values')

    # Major ticks show date and time, minor ticks only the time
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%d.%m %H:%M:%S"))
    ax.xaxis.set_minor_formatter(mdates.DateFormatter("%H:%M:%S"))

    plt.xticks(rotation=45)

    # Check buttons for all plotted lines, with matching visibility
    legendHeight = 0.015 * len(varNames) + 0.05  # estimated legend height
    # Legend axes: left offset, bottom offset, width, height
    rax = plt.axes([0.01, 0.3, 0.20, legendHeight])
    labels = [str(line.get_label()) for line in trendsSubplots]
    visibility = [line.get_visible() for line in trendsSubplots]
    check = CheckButtons(rax, labels, visibility)
    check.on_clicked(ControlCheckFunc)

    # Colour each label like its trend. Label i belongs to trend i; use the
    # position directly so duplicate labels are still handled one by one.
    for i, text in enumerate(check.labels):
        text.set_color(getPlotColor(i))
        text.set_fontsize(8)
        x, _ = text.get_position()
        x_new = x - 0.1  # shift left, but stay inside the axes
        if x_new <= 0:
            x_new = 0.01
        text.set_x(x_new)
        # Untick by default; the user ticks only the trends needed
        check.set_active(i)

    # Maximize the window, see
    # https://stackoverflow.com/questions/12439588/how-to-maximize-a-plt-show-window-using-python
    # window.state('zoomed') exists only on Tk-based backends, and Tk accepts
    # 'zoomed' only on some platforms. Maximizing is cosmetic, so keep the
    # default window size rather than aborting the plot.
    mng = plt.get_current_fig_manager()
    try:
        mng.window.state('zoomed')
    except Exception:  # backend-specific: AttributeError, tkinter.TclError
        pass

    # Grid. The on/off flag is positional: its keyword was renamed from
    # 'b' to 'visible' in Matplotlib 3.5 and 'b' has since been removed.
    ax.grid(True, which='both', color='#cccccc', linestyle='--')
    plt.show()
Ejemplo n.º 9
0
class NetworkVisualisation(object):
    def __init__(self,
                 units,
                 data_points,
                 min_range,
                 max_range,
                 quality,
                 dataset,
                 saves_path=None,
                 seed=1):
        """Build the dataset, obtain a trained network and open the GUI.

        Parameters
        ----------
        units
            Network size, passed to the load/train helpers.
        data_points
            Number of dataset points to generate.
        min_range, max_range : float
            Extent of the plotted input space.
        quality
            Resolution of the evaluation grid (stored as self.precision).
        dataset : Dataset
            Dataset.CIRCLE or Dataset.SPIRAL.
        saves_path : optional
            Location from which a trained network is loaded, and to which a
            newly trained one is saved.
        seed : int
            Seed for NumPy's global random generator.

        Raises
        ------
        ValueError
            If ``dataset`` is neither Dataset.CIRCLE nor Dataset.SPIRAL.
            (ValueError subclasses Exception, so existing handlers for the
            former bare Exception still catch it.)
        """
        np.random.seed(seed)

        # Validate once; every branch below can then assume a known dataset
        if dataset is not Dataset.CIRCLE and dataset is not Dataset.SPIRAL:
            raise ValueError("Invalid dataset type")

        self.precision = quality
        self.min_range = min_range
        self.max_range = max_range
        self.dataset_type = dataset
        if dataset is Dataset.CIRCLE:
            self.dataset = get_circle_dataset(points=data_points,
                                              min_range=min_range,
                                              max_range=max_range,
                                              radius=0.8)
        else:
            self.dataset = get_spiral_dataset(data_points, classes=2)
        # Last column holds the class label, the others the coordinates
        self.data_points = self.dataset[:, :-1]
        self.data_labels = self.dataset[:, -1]
        self.data_space, self.dim_data = setup_data_space(
            min_range, max_range, quality)

        # Network creation: reuse a saved network when available, else train
        if saves_path and check_saved_network(units, dataset, saves_path):
            self.network = load_network(units, dataset, saves_path)
        else:
            if dataset is Dataset.CIRCLE:
                self.network = train_network_sigmoid(self.dataset,
                                                     units=units,
                                                     learning_rate=5e-3,
                                                     window_size=1000)
            else:
                self.network = train_network_softmax(self.dataset,
                                                     units=units,
                                                     learning_rate=1,
                                                     window_size=1000)
            # NOTE(review): also called when saves_path is None; confirm
            # that save_network handles that case.
            save_network(self.network, dataset, saves_path)
        # Untouched copies of the parameters; the sliders mark these values
        self.default_network = {
            name: layer.copy()
            for name, layer in self.network.items()
        }

        # GUI Visualisation
        self.perceptron1 = 0
        self.is_relu = True
        self.perceptron2 = 0
        self.connection = 0
        self.is_pre_add = False
        self.is_sig = True
        self.all_p1_enabled = set(range(self.network["W1"].shape[1]))
        self.ignore_update = False

        fig = plt.figure(figsize=(13, 6.5))
        self.plot_network(fig)
        self.plot_controls(fig)

    def plot_network(self, fig):
        """Create the 2D and 3D plots of both network layers in ``fig``.

        Layout on a (4, 4) grid:
        - row 0: 2D plot of the first layer-1 output;
        - row 1: 3D plot of the layer-1 output;
        - row 2: 2D plot of the first layer-2 output;
        - row 3: 3D plot of the layer-2 output.
        Both 2D plots get the dataset overlaid: points labelled 1 in green,
        points labelled 0 in red.
        """
        _, layer1_out, _, layer2_out = forward(self.data_space,
                                               self.network,
                                               self.dataset_type,
                                               precision=self.precision)
        point_groups = ((self.data_points[self.data_labels == 1], "g"),
                        (self.data_points[self.data_labels == 0], "r"))

        def overlay_dataset(plot):
            # Scatter every dataset point on top of a 2D plot
            for pts, colour in point_groups:
                plot.ax.scatter(pts[:, 0], pts[:, 1], s=3, c=colour,
                                alpha=0.5)

        self.layer1_plot = Plot(fig, (4, 4), (0, 0), (1, 3),
                                layer1_out[:, :, 0], self.min_range,
                                self.max_range)
        overlay_dataset(self.layer1_plot)
        self.layer1_3d_plot = Plot3D(fig, (4, 4), (1, 0), (1, 3),
                                     self.precision, layer1_out)

        self.layer2_plot = Plot(fig, (4, 4), (2, 0), (1, 3),
                                layer2_out[:, :, 0], self.min_range,
                                self.max_range)
        overlay_dataset(self.layer2_plot)
        self.layer2_3d_plot = Plot3D(fig, (4, 4), (3, 0), (1, 3),
                                     self.precision, layer2_out)

    def plot_controls(self, fig):
        """Create the widgets that edit the network interactively.

        Layer-1 widgets:
        - sliders for the x weight, y weight and bias of the selected
          perceptron;
        - a slider that selects the perceptron;
        - "ReLU?" / "Enabled?" check buttons.
        Layer-2 widgets:
        - sliders for the weight of the selected connection, the connection
          index and the output bias;
        - "Pre-add?" / "Transform?" check buttons.

        A parameter slider spans the parameter's current [min, max] range,
        widened on each side by half that range plus a fixed padding. The
        layer-2 bias range is centred on its value instead.
        """
        step_size = 0.01
        padding = 5

        def widened(values):
            # (min - margin, max + margin) with margin = span / 2 + padding
            low, high = values.min(), values.max()
            margin = (high - low) / 2 + padding
            return low - margin, high + margin

        W1 = self.network["W1"]
        b1 = self.network["b1"]
        W2 = self.network["W2"]
        b2 = self.network["b2"]

        # ---- Layer 1 controls ----
        x_low, x_high = widened(W1[0])
        y_low, y_high = widened(W1[1])
        b_low, b_high = widened(b1)

        self.p1x_slid = Slider(plot_to_grid(fig, (2, 16), (0, 12), (1, 1)),
                               'P1 x',
                               valmin=x_low,
                               valmax=x_high,
                               valinit=W1[0, 0],
                               valstep=step_size)
        self.p1x_slid.on_changed(self.p1x_changed)

        self.p1y_slid = Slider(plot_to_grid(fig, (2, 16), (0, 13), (1, 1)),
                               'P1 y',
                               valmin=y_low,
                               valmax=y_high,
                               valinit=W1[1, 0],
                               valstep=step_size)
        self.p1y_slid.on_changed(self.p1y_changed)

        self.p1b_slid = Slider(plot_to_grid(fig, (24, 16), (0, 14), (7, 1)),
                               'P1 b',
                               valmin=b_low,
                               valmax=b_high,
                               valinit=b1[0, 0],
                               valstep=step_size)
        self.p1b_slid.on_changed(self.p1b_changed)

        self.p1_slid = Slider(plot_to_grid(fig, (24, 16), (0, 15), (7, 1)),
                              'P1',
                              valmin=0,
                              valmax=W1.shape[1] - 1,
                              valinit=self.perceptron1,
                              valstep=1)
        self.p1_slid.on_changed(self.p1_changed)

        self.p1_opt_buttons = CheckButtons(
            plot_to_grid(fig, (24, 16), (8, 14), (3, 2)),
            ["ReLU?", "Enabled?"],
            [self.is_relu, True])
        self.p1_opt_buttons.on_clicked(self.p1_options_update)

        # ---- Layer 2 controls ----
        w_low, w_high = widened(W2)
        b2_centre = b2[0, 0]
        b2_margin = np.abs(b2_centre) + padding

        self.p2_dim_val_slid = Slider(
            plot_to_grid(fig, (2, 16), (1, 12), (1, 1)),
            'p2 w',
            valmin=w_low,
            valmax=w_high,
            valinit=W2[0, 0],
            valstep=step_size)
        self.p2_dim_val_slid.on_changed(self.p2_weight_changed)

        self.p2_connection_dim_slid = Slider(
            plot_to_grid(fig, (2, 16), (1, 13), (1, 1)),
            'p2 c',
            valmin=0,
            valmax=W2.shape[0] - 1,
            valinit=0,
            valstep=1)
        self.p2_connection_dim_slid.on_changed(self.p2_connection_dim_changed)

        self.p2b_slid = Slider(plot_to_grid(fig, (24, 16), (13, 14), (7, 1)),
                               'p2 b',
                               valmin=b2_centre - b2_margin,
                               valmax=b2_centre + b2_margin,
                               valinit=b2[0, 0],
                               valstep=step_size)
        self.p2b_slid.on_changed(self.p2b_changed)

        self.p2_opt_buttons = CheckButtons(
            plot_to_grid(fig, (24, 16), (21, 14), (4, 2)),
            ["Pre-add?", "Transform?"],
            [self.is_pre_add, self.is_sig])
        self.p2_opt_buttons.on_clicked(self.p2_options_update)

    def p1_changed(self, val):
        self.perceptron1 = int(val)
        self.ignore_update = True
        self.update_widgets()
        self.ignore_update = False

        self.update_just_plot1()

    def p1x_changed(self, val):
        self.network["W1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1y_changed(self, val):
        self.network["W1"][1, self.perceptron1] = val
        self.update_visuals()

    def p1b_changed(self, val):
        self.network["b1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1_options_update(self, label):
        if label == "ReLU?":
            self.is_relu = not self.is_relu
            self.update_just_plot1()
        elif label == "Enabled?":
            is_enabled = self.p1_opt_buttons.get_status()[1]
            if is_enabled and self.perceptron1 not in self.all_p1_enabled:
                self.all_p1_enabled.add(self.perceptron1)
            elif not is_enabled and self.perceptron1 in self.all_p1_enabled:
                layer1_out = sorted(list(self.all_p1_enabled)).index(
                    self.perceptron1)
                self.layer1_3d_plot.remove_plot(layer1_out)
                self.all_p1_enabled.remove(self.perceptron1)

            self.update_visuals()

    def p2_weight_changed(self, val):
        self.network["W2"][self.connection, 0] = val

        self.update_just_plot2()

    def p2_connection_dim_changed(self, val):
        self.connection = int(val)
        self.ignore_update = True
        self.p2_dim_val_slid.set_val(self.network["W2"][self.connection, 0])
        self.p2_dim_val_slid.vline.set_xdata(
            self.default_network["W2"][self.connection, 0])
        self.ignore_update = False

    def p2b_changed(self, val):
        self.network["b2"][0, 0] = val
        self.update_just_plot2()

    def p2_options_update(self, label):
        if label == "Transform?":
            self.is_sig = not self.is_sig
        elif label == "Pre-add?":
            self.is_pre_add = not self.is_pre_add

        self.update_just_plot2()

    def show(self):
        """Display the visualisation window (delegates to plt.show())."""
        plt.show()

    def update_plot1(self, out1, out2):
        if self.perceptron1 in self.all_p1_enabled:
            self.layer1_plot.set_visible(True)
            layer1_out = sorted(list(self.all_p1_enabled)).index(
                self.perceptron1)
            if not self.is_relu:
                layer1_data = out1[:, :, layer1_out]
            else:
                layer1_data = out2[:, :, layer1_out]
            self.layer1_plot.update(layer1_data)
        else:
            self.layer1_plot.set_visible(False)

    def update_3d_plot1(self, out1, out2):
        if self.perceptron1 in self.all_p1_enabled:
            if not self.is_relu:
                self.layer1_3d_plot.update_all(out1)
            else:
                self.layer1_3d_plot.update_all(out2)

    def update_plot2(self, out2, out3, out4):
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
            layer2_data = np.sum(layer2_data, axis=2)
        elif not self.is_sig:
            layer2_data = out3[:, :, 0]
        else:
            layer2_data = out4[:, :, 0]
        self.layer2_plot.update(layer2_data)

    def update_3d_plot2(self, out2, out3, out4):
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
        elif not self.is_sig:
            layer2_data = out3
        else:
            layer2_data = out4
        self.layer2_3d_plot.update_all(layer2_data)

    def update_visuals(self):
        if not self.ignore_update:
            out1, out2, out3, out4 = forward(self.data_space, self.network,
                                             self.dataset_type,
                                             self.all_p1_enabled,
                                             self.precision)
            self.update_plot1_visuals(out1, out2)
            self.update_plot2_visuals(out2, out3, out4)
            plt.draw()

    def update_just_plot1(self):
        if not self.ignore_update:
            out1, out2, out3, out4 = forward(self.data_space, self.network,
                                             self.dataset_type,
                                             self.all_p1_enabled,
                                             self.precision)
            self.update_plot1_visuals(out1, out2)
            plt.draw()

    def update_plot1_visuals(self, out1, out2):
        self.update_plot1(out1, out2)
        self.update_3d_plot1(out1, out2)

    def update_just_plot2(self):
        if not self.ignore_update:
            out1, out2, out3, out4 = forward(self.data_space, self.network,
                                             self.dataset_type,
                                             self.all_p1_enabled,
                                             self.precision)
            self.update_plot2_visuals(out2, out3, out4)
            plt.draw()

    def update_plot2_visuals(self, out2, out3, out4):
        self.update_plot2(out2, out3, out4)
        self.update_3d_plot2(out2, out3, out4)

    def update_widgets(self):
        self.p1b_slid.set_val(self.network["b1"][0, self.perceptron1])
        self.p1x_slid.set_val(self.network["W1"][0, self.perceptron1])
        self.p1y_slid.set_val(self.network["W1"][1, self.perceptron1])

        self.p1b_slid.vline.set_xdata(
            self.default_network["b1"][0, self.perceptron1])
        self.p1x_slid.vline.set_xdata(
            self.default_network["W1"][0, self.perceptron1])
        self.p1y_slid.vline.set_xdata(
            self.default_network["W1"][1, self.perceptron1])

        if (self.perceptron1 in self.all_p1_enabled and not self.p1_opt_buttons.get_status()[1]) or \
           (self.perceptron1 not in self.all_p1_enabled and self.p1_opt_buttons.get_status()[1]):
            self.p1_opt_buttons.set_active(1)