class numeric_field:
    """Text box that only keeps edits passing a length and pattern check.

    Every text change is routed to :meth:`tc_func`. An accepted edit is
    remembered together with the cursor position; a rejected edit restores
    the last remembered text and cursor position.

    Args:
        inp_ax: Axes in which the ``TextBox`` is created.
        label: Label passed to the ``TextBox``.
        regexp (str): Pattern compiled with ``re.compile``; candidate text is
            tested with ``pattern.match`` (anchored at the start only, unless
            the pattern itself anchors the end).
        maxlen (int): Maximum number of characters accepted.
        decpoint: When truthy, the text may only reach ``maxlen`` characters
            once it contains a ``'.'``; until then one slot stays reserved.
    """

    def __init__(self, inp_ax, label, regexp, maxlen, decpoint=None):
        # last accepted text and cursor position; restored on rejection
        self.val = ''
        self.curpos = 0
        self.pattern = re.compile(regexp)
        self.maxlen = maxlen
        self.decpoint = decpoint
        self.tb = TextBox(inp_ax, label)
        self.tb.on_text_change(self.tc_func)

    def tc_func(self, inp):
        """Validate new text box contents ``inp`` and accept or roll back."""
        limit = self.maxlen
        rejected = (
            len(inp) > limit
            # keep one character free for the decimal point until typed
            or (self.decpoint and '.' not in inp and len(inp) > limit - 1)
            or not self.pattern.match(inp)
        )
        if not rejected:
            self.val = inp
            self.curpos = self.tb.cursor_index
            return
        # roll back: cursor first, then the text itself
        self.tb.cursor_index = self.curpos
        self.tb.set_val(self.val)
Esempio n. 2
0
class AtlasEditor(plot_support.ImageSyncMixin):
    """Graphical interface to view an atlas in multiple orthogonal
    dimensions and edit atlas labels.

    :attr:`plot_eds` are dictionaries of keys specified by one of
    :const:`magmap.config.PLANE` plane orientations to Plot Editors.

    Attributes:
        image5d: Numpy image array in t,z,y,x,[c] format.
        labels_img: Numpy image array in z,y,x format.
        channel: Channel of the image to display.
        offset: Index of plane at which to start viewing in x,y,z (user)
            order.
        fn_close_listener: Handle figure close events; called with the
            close event and this object.
        borders_img: Numpy image array in z,y,x,[c] format to show label
            borders, such as that generated during label smoothing.
            Defaults to None. If this image has a different number of
            labels than that of ``labels_img``, a new colormap will
            be generated.
        fn_show_label_3d: Function to call to show a label in a
            3D viewer. Defaults to None.
        title (str): Window title; defaults to None.
        fn_refresh_atlas_eds (func): Callback for refreshing other
            Atlas Editors to synchronize them; defaults to None.
            Typically takes one argument, this ``AtlasEditor`` object,
            i.e. the editor initiating the refresh.
        fig: Figure in which to build the editor; defaults to None, in
            which case :meth:`show_atlas` creates one.
        alpha_slider: Matplotlib alpha slider control.
        alpha_reset_btn: Matplotlib button for resetting alpha transparency.
        alpha_last: Float specifying the previous alpha value.
        interp_planes: Current :class:`InterpolatePlanes` object.
        interp_btn: Matplotlib button to initiate plane interpolation.
        save_btn: Matplotlib button to save the atlas.
        edit_btn: Matplotlib button to toggle editing mode.
        color_picker_box: Matplotlib text box whose value is passed to
            each Plot Editor's ``intensity_spec``.
        fn_status_bar (func): Function to call during status bar updates
            in :class:`pixel_display.PixelDisplay`; defaults to None.
        fn_update_coords (func): Handler for coordinate updates, which
            takes coordinates in z-plane orientation; defaults to None.
    """

    # edit button labels passed to ``toggle_btn``, presumably (off, on)
    _EDIT_BTN_LBLS = ("Edit", "Editing")

    def __init__(self,
                 image5d,
                 labels_img,
                 channel,
                 offset,
                 fn_close_listener,
                 borders_img=None,
                 fn_show_label_3d=None,
                 title=None,
                 fn_refresh_atlas_eds=None,
                 fig=None,
                 fn_status_bar=None):
        """Initialize the Atlas Editor without building its figure.

        Widgets and Plot Editors are created later by :meth:`show_atlas`.
        Arguments are stored as the same-named attributes documented on
        the class.
        """
        super().__init__()
        self.image5d = image5d
        self.labels_img = labels_img
        self.channel = channel
        self.offset = offset
        self.fn_close_listener = fn_close_listener
        self.borders_img = borders_img
        self.fn_show_label_3d = fn_show_label_3d
        self.title = title
        self.fn_refresh_atlas_eds = fn_refresh_atlas_eds
        self.fig = fig
        self.fn_status_bar = fn_status_bar

        # widgets and helpers populated by show_atlas
        self.alpha_slider = None
        self.alpha_reset_btn = None
        self.alpha_last = None
        self.interp_planes = None
        self.interp_btn = None
        self.save_btn = None
        self.edit_btn = None
        self.color_picker_box = None
        self.fn_update_coords = None

        self._labels_img_sitk = None  # for saving labels image

    def show_atlas(self):
        """Set up the atlas display with multiple orthogonal views.

        Creates the figure if none was given, the control widgets, and one
        Plot Editor per plane in :const:`magmap.config.PLANE`, then attaches
        event listeners and shows all planes at :attr:`offset`.
        """
        # set up the figure
        if self.fig is None:
            fig = figure.Figure(self.title)
            self.fig = fig
        else:
            fig = self.fig
        fig.clear()
        # main grid: viewers on top, a thin row of controls below
        gs = gridspec.GridSpec(2,
                               1,
                               wspace=0.1,
                               hspace=0.1,
                               height_ratios=(20, 1),
                               figure=fig,
                               left=0.06,
                               right=0.94,
                               bottom=0.02,
                               top=0.98)
        gs_viewers = gridspec.GridSpecFromSubplotSpec(2,
                                                      2,
                                                      subplot_spec=gs[0, 0])

        # set up a colormap for the borders image if present
        cmap_borders = colormaps.get_borders_colormap(self.borders_img,
                                                      self.labels_img,
                                                      config.cmap_labels)
        # offset is in x,y,z order, while update_coords takes z,y,x
        coord = list(self.offset[::-1])

        # editor controls, split into a slider sub-spec to allow greater
        # spacing for labels on either side and a separate sub-spec for
        # buttons and other fields
        gs_controls = gridspec.GridSpecFromSubplotSpec(1,
                                                       2,
                                                       subplot_spec=gs[1, 0],
                                                       width_ratios=(1, 1),
                                                       wspace=0.15)
        self.alpha_slider = Slider(
            fig.add_subplot(gs_controls[0, 0]),
            "Opacity",
            0.0,
            1.0,
            valinit=plot_editor.PlotEditor.ALPHA_DEFAULT)
        gs_controls_btns = gridspec.GridSpecFromSubplotSpec(
            1, 5, subplot_spec=gs_controls[0, 1], wspace=0.1)
        self.alpha_reset_btn = Button(fig.add_subplot(gs_controls_btns[0, 0]),
                                      "Reset")
        self.interp_btn = Button(fig.add_subplot(gs_controls_btns[0, 1]),
                                 "Fill Label")
        self.interp_planes = InterpolatePlanes(self.interp_btn)
        self.interp_planes.update_btn()
        self.save_btn = Button(fig.add_subplot(gs_controls_btns[0, 2]), "Save")
        self.edit_btn = Button(fig.add_subplot(gs_controls_btns[0, 3]), "Edit")
        self.color_picker_box = TextBox(
            fig.add_subplot(gs_controls_btns[0, 4]), None)

        # adjust button colors based on theme and enabled status; note
        # that colors do not appear to refresh until fig mouseover
        for btn in (self.alpha_reset_btn, self.edit_btn):
            enable_btn(btn)
        enable_btn(self.save_btn, False)
        enable_btn(self.color_picker_box, color=config.widget_color + 0.1)

        # maximum sizes for downsampling, read by setup_plot_ed below
        max_sizes = plot_support.get_downsample_max_sizes()

        def setup_plot_ed(axis, gs_spec):
            """Set up a PlotEditor for the given axis index and subplot spec."""
            # subplot grid, with larger height preference for plot for
            # each increased row to make sliders of approx equal size and
            # align top borders of top images
            if hasattr(gs_spec, "rowspan"):
                # Matplotlib 3.2+; get_rows_columns was deprecated in 3.3
                # and later removed
                extra_rows = len(gs_spec.rowspan) - 1
            else:
                # older Matplotlib: (nrows, ncols, row_start, row_stop, ...)
                # with an inclusive row_stop
                rows_cols = gs_spec.get_rows_columns()
                extra_rows = rows_cols[3] - rows_cols[2]
            gs_plot = gridspec.GridSpecFromSubplotSpec(
                2,
                1,
                subplot_spec=gs_spec,
                height_ratios=(1, 10 + 14 * extra_rows),
                hspace=0.1 / (extra_rows * 1.4 + 1))

            # transform arrays to the given orthogonal direction
            ax = fig.add_subplot(gs_plot[1, 0])
            plot_support.hide_axes(ax)
            plane = config.PLANE[axis]
            arrs_3d, aspect, origin, scaling = \
                plot_support.setup_images_for_plane(
                    plane,
                    (self.image5d[0], self.labels_img, self.borders_img))
            img3d_tr, labels_img_tr, borders_img_tr = arrs_3d

            # slider through image planes
            ax_scroll = fig.add_subplot(gs_plot[0, 0])
            plane_slider = Slider(ax_scroll,
                                  plot_support.get_plane_axis(plane),
                                  0,
                                  len(img3d_tr) - 1,
                                  valfmt="%d",
                                  valinit=0,
                                  valstep=1)

            # plot editor
            max_size = max_sizes[axis] if max_sizes else None
            plot_ed = plot_editor.PlotEditor(
                ax,
                img3d_tr,
                labels_img_tr,
                config.cmap_labels,
                plane,
                aspect,
                origin,
                self.update_coords,
                self.refresh_images,
                scaling,
                plane_slider,
                img3d_borders=borders_img_tr,
                cmap_borders=cmap_borders,
                fn_show_label_3d=self.fn_show_label_3d,
                interp_planes=self.interp_planes,
                fn_update_intensity=self.update_color_picker,
                max_size=max_size,
                fn_status_bar=self.fn_status_bar)
            return plot_ed

        # setup plot editors for all 3 orthogonal directions; the first
        # viewer spans both rows of the left column
        for i, gs_viewer in enumerate(
                (gs_viewers[:2, 0], gs_viewers[0, 1], gs_viewers[1, 1])):
            self.plot_eds[config.PLANE[i]] = setup_plot_ed(i, gs_viewer)
        self.set_show_crosslines(True)

        # attach listeners
        fig.canvas.mpl_connect("scroll_event", self.scroll_overview)
        fig.canvas.mpl_connect("key_press_event", self.on_key_press)
        fig.canvas.mpl_connect("close_event", self._close)
        fig.canvas.mpl_connect("axes_leave_event", self.axes_exit)

        self.alpha_slider.on_changed(self.alpha_update)
        self.alpha_reset_btn.on_clicked(self.alpha_reset)
        self.interp_btn.on_clicked(self.interpolate)
        self.save_btn.on_clicked(self.save_atlas)
        self.edit_btn.on_clicked(self.toggle_edit_mode)
        self.color_picker_box.on_text_change(self.color_picker_changed)

        # initialize and show planes in all plot editors
        if self._max_intens_proj is not None:
            self.update_max_intens_proj(self._max_intens_proj)
        self.update_coords(coord, config.PLANE[0])

        plt.ion()  # avoid the need for draw calls

    def _close(self, evt):
        """Handle figure close events by calling :attr:`fn_close_listener`
        with this object.

        Args:
            evt (:obj:`matplotlib.backend_bases.CloseEvent`): Close event.

        """
        self.fn_close_listener(evt, self)

    def on_key_press(self, event):
        """Respond to key press events.

        Keys: ``a`` toggles between the current and 0 opacity, ``A`` halves
        opacity, up/down scroll planes, ``w`` toggles editing mode, and
        ``ctrl+s``/``cmd+s`` save the figure.

        Args:
            event: Key press event.
        """
        if event.key == "a":
            # toggle between current and 0 opacity
            if self.alpha_slider.val == 0:
                # return to saved alpha if available and reset
                if self.alpha_last is not None:
                    self.alpha_slider.set_val(self.alpha_last)
                self.alpha_last = None
            else:
                # make translucent, saving alpha if not already saved
                # during a halve-opacity event
                if self.alpha_last is None:
                    self.alpha_last = self.alpha_slider.val
                self.alpha_slider.set_val(0)
        elif event.key == "A":
            # halve opacity, only saving alpha on first halving to allow
            # further halving or manual movements while still returning to
            # originally saved alpha
            if self.alpha_last is None:
                self.alpha_last = self.alpha_slider.val
            self.alpha_slider.set_val(self.alpha_slider.val / 2)
        elif event.key in ("up", "down"):
            # up/down arrow for scrolling planes
            self.scroll_overview(event)
        elif event.key == "w":
            # shortcut to toggle editing mode
            self.toggle_edit_mode(event)
        elif event.key in ("ctrl+s", "cmd+s"):
            # support default save shortcuts on multiple platforms;
            # ctrl-s will bring up save dialog from fig, but cmd/win-S
            # will bypass
            self.save_fig(self.get_save_path())

    def update_coords(self, coord, plane_src=config.PLANE[0]):
        """Update all plot editors with given coordinates.

        Args:
            coord: Coordinate at which to center images, in z,y,x order.
            plane_src: One of :const:`magmap.config.PLANE` to specify the
                orientation from which the coordinates were given; defaults
                to the first element of :const:`magmap.config.PLANE`.
        """
        coord_rev = libmag.transpose_1d_rev(list(coord), plane_src)
        for i, plane in enumerate(config.PLANE):
            coord_transposed = libmag.transpose_1d(list(coord_rev), plane)
            if i == 0:
                # store the offset in x,y,z order from the first plane
                self.offset = coord_transposed[::-1]
                if self.fn_update_coords:
                    # update offset based on the xy plane without
                    # centering, so planes stay centered on the offset as-is
                    self.fn_update_coords(coord_transposed, False)
            self.plot_eds[plane].update_coord(coord_transposed)

    def view_subimg(self, offset, shape):
        """Zoom all Plot Editors to the given sub-image.

        Args:
            offset: Sub-image coordinates in ``z,y,x`` order.
            shape: Sub-image shape in ``z,y,x`` order.

        """
        for plane in config.PLANE:
            offset_tr = libmag.transpose_1d(list(offset), plane)
            shape_tr = libmag.transpose_1d(list(shape), plane)
            # drop the first (plane) dimension for the 2D view
            self.plot_eds[plane].view_subimg(offset_tr[1:], shape_tr[1:])
        self.fig.canvas.draw_idle()

    def refresh_images(self, plot_ed=None, update_atlas_eds=False):
        """Refresh images in a plot editor, such as after editing one
        editor and updating the displayed image in the other editors.

        Args:
            plot_ed (:obj:`magmap.plot_editor.PlotEditor`): Editor that
                does not need updating, typically the editor that originally
                changed. Defaults to None.
            update_atlas_eds (bool): True to update other ``AtlasEditor``s;
                defaults to False.
        """
        for key in self.plot_eds:
            ed = self.plot_eds[key]
            if ed != plot_ed:
                ed.refresh_img3d_labels()
            if ed.edited:
                # display save button as enabled if any editor has been edited
                enable_btn(self.save_btn)
        if update_atlas_eds and self.fn_refresh_atlas_eds is not None:
            # callback to synchronize other Atlas Editors
            self.fn_refresh_atlas_eds(self)

    def scroll_overview(self, event):
        """Scroll images and crosshairs in all plot editors.

        Args:
            event: Scroll event.
        """
        for key in self.plot_eds:
            self.plot_eds[key].scroll_overview(event)

    def alpha_update(self, event):
        """Update the alpha transparency in all plot editors.

        Args:
            event: Slider event.
        """
        for key in self.plot_eds:
            self.plot_eds[key].alpha_updater(event)

    def alpha_reset(self, event):
        """Reset the alpha transparency in all plot editors.

        Resetting the slider triggers :meth:`alpha_update` through the
        slider's change listener.

        Args:
            event: Button event, currently ignored.
        """
        self.alpha_slider.reset()

    def axes_exit(self, event):
        """Trigger axes exit for all plot editors.

        Args:
            event: Axes exit event.
        """
        for key in self.plot_eds:
            self.plot_eds[key].on_axes_exit(event)

    def interpolate(self, event):
        """Interpolate planes using :attr:`interp_planes`.

        A ``ValueError`` from the interpolation is printed rather than
        raised so the GUI stays responsive.

        Args:
            event: Button event, currently ignored.
        """
        try:
            self.interp_planes.interpolate(self.labels_img)
            # flag Plot Editors as edited so labels can be saved
            for ed in self.plot_eds.values():
                ed.edited = True
            self.refresh_images(None, True)
        except ValueError as e:
            print(e)

    def save_atlas(self, event):
        """Save atlas labels using the registered image suffix given by
        :attr:`config.reg_suffixes[config.RegSuffixes.ANNOTATION]`.

        Args:
            event: Button event, currently not used.

        """
        # only save if at least one editor has been edited
        if not any(ed.edited for ed in self.plot_eds.values()):
            return

        # save to the labels reg suffix; use sitk Image if loaded and store
        # any Image loaded during saving
        reg_name = config.reg_suffixes[config.RegSuffixes.ANNOTATION]
        if self._labels_img_sitk is None:
            self._labels_img_sitk = config.labels_img_sitk
        self._labels_img_sitk = sitk_io.write_registered_image(
            self.labels_img,
            config.filename,
            reg_name,
            self._labels_img_sitk,
            overwrite=True)

        # reset edited flag in all editors and show save button as disabled
        for ed in self.plot_eds.values():
            ed.edited = False
        enable_btn(self.save_btn, False)
        print("Saved labels image at {}".format(datetime.datetime.now()))

    def get_save_path(self):
        """Get figure save path based on filename, ROI, and overview plane
        shown.

        Returns:
            str: Figure save path.

        """
        ext = config.savefig if config.savefig else config.DEFAULT_SAVEFIG
        # NOTE(review): os.path.basename raises TypeError when title is None
        # (its default); confirm callers always set a title before saving.
        return "{}.{}".format(
            naming.get_roi_path(os.path.basename(self.title), self.offset),
            ext)

    def toggle_edit_mode(self, event):
        """Toggle editing mode, determining the current state from the
        first :class:`magmap.plot_editor.PlotEditor` and switching to the
        opposite value for all plot editors.

        Args:
            event: Button event, currently not used.
        """
        edit_mode = False
        for i, ed in enumerate(self.plot_eds.values()):
            if i == 0:
                # change edit mode based on current mode in first plot editor
                edit_mode = not ed.edit_mode
                toggle_btn(self.edit_btn, edit_mode, text=self._EDIT_BTN_LBLS)
            ed.edit_mode = edit_mode
        if not edit_mode:
            # reset the color picker text box when turning off editing
            self.color_picker_box.set_val("")

    def update_color_picker(self, val):
        """Update the color picker :class:`TextBox` with the given value.

        Args:
            val (str): Color value. If None, only :meth:`color_picker_changed`
                will be triggered.
        """
        if val is None:
            # update picked color directly
            self.color_picker_changed(val)
        else:
            # update text box, which triggers color_picker_changed
            self.color_picker_box.set_val(val)

    def color_picker_changed(self, text):
        """Respond to color picker :class:`TextBox` changes by updating
        the specified intensity value in all plot editors.

        Args:
            text (str): String of text box value. Converted to an int if
                non-empty; non-numeric text is ignored. Empty or None
                values are passed through unchanged.
        """
        intensity = text
        if text:
            if not libmag.is_number(intensity):
                return
            intensity = int(intensity)
        print("updating specified color to", intensity)
        for ed in self.plot_eds.values():
            ed.intensity_spec = intensity
Esempio n. 3
0
class Visualization:
    """Visualize embedding with its infrastructure and overlay.

    Lays out an infrastructure axes, an overlay axes, and a larger embedding
    axes, plus a text box and button for entering and taking an action on
    the embedding.
    """
    def __init__(self, embedding: PartialEmbedding):
        self.embedding = embedding

        shape = (2, 3)
        self.infra_ax = plt.subplot2grid(shape=shape, loc=(0, 0))
        self.overlay_ax = plt.subplot2grid(shape=shape, loc=(1, 0))
        self.embedding_ax = plt.subplot2grid(shape=shape,
                                             loc=(0, 1),
                                             rowspan=2,
                                             colspan=2)
        # leave room at the bottom of the figure for the input widgets
        plt.subplots_adjust(bottom=0.2)
        # axes rectangles are [left, bottom, width, height]
        input_text_ax = plt.axes([0.1, 0.05, 0.6, 0.075])
        input_btn_ax = plt.axes([0.7, 0.05, 0.2, 0.075])

        self.update_infra()
        self.update_overlay()
        self.update_embedding()

        pa = mpatches.Patch
        plt.gcf().legend(handles=[
            pa(color=COLORS["sources_color"], label="source"),
            pa(color=COLORS["sink_color"], label="sink"),
            pa(color=COLORS["intermediates_color"], label="intermediate"),
        ])

        # pre-fill the text box with a random action
        random = get_random_action(self.embedding, rand=np.random)
        self.text_box_val = str(random)
        self.text_box = TextBox(input_text_ax,
                                "Action",
                                initial=self.text_box_val)

        def _update_textbox_val(new_val):
            # track the current text so the button handler can parse it
            self.text_box_val = new_val

        self.text_box.on_text_change(_update_textbox_val)

        self.submit_btn = Button(input_btn_ax, "Take action")

        def _on_clicked(_):
            self._take_action(self._parse_textbox())

        self.submit_btn.on_clicked(_on_clicked)

    def _parse_textbox(self):
        """Return the possibility whose ``str()`` equals the text box
        contents, or None if no possibility matches."""
        action = self.text_box_val
        possibilities = self.embedding.possibilities()
        for possibility in possibilities:
            if str(possibility) == action:
                return possibility
        return None

    def _update_textbox(self):
        """Fill the text box with a new random action, or clear it if
        ``get_random_action`` returns None."""
        next_random = get_random_action(self.embedding, rand=np.random)
        self.text_box.set_val(
            str(next_random) if next_random is not None else "")

    def _take_action(self, action):
        """Apply ``action`` to the embedding and redraw.

        If the action is invalid, the available possibilities are printed.
        """
        if action is None:
            print("Action could not be parsed")
            return
        print(f"Taking action: {action}")
        success = self.embedding.take_action(*action)
        if not success:
            print("Action is not valid. The possibilities are:")
            # printed in the same str() form _parse_textbox matches, so any
            # listed entry can be typed back into the text box
            for possibility in self.embedding.possibilities():
                print(possibility)
        self.update_embedding()
        self._update_textbox()
        plt.draw()

    def update_infra(self):
        """Redraws the infrastructure"""
        plt.sca(self.infra_ax)
        plt.cla()
        self.infra_ax.set_title("Infrastructure")
        draw_infra(self.embedding.infra, **COLORS)

    def update_overlay(self):
        """Redraws the overlay"""
        plt.sca(self.overlay_ax)
        plt.cla()
        self.overlay_ax.set_title("Overlay")
        draw_overlay(self.embedding.overlay, **COLORS)

    def update_embedding(self):
        """Redraws the embedding"""
        plt.sca(self.embedding_ax)
        plt.cla()
        self.embedding_ax.set_title("Embedding")
        draw_embedding(self.embedding, **COLORS)
Esempio n. 4
0
    verticalpad = int(val)
    draw()


def updatesourceblur(val):
    """Slider callback: store ``val`` as the integer source blur, then redraw."""
    global sourceblur
    blur_amount = int(val)
    sourceblur = blur_amount
    draw()


def updaterandomfrequency(val):
    """Slider callback: store ``val`` as the random frequency, then redraw."""
    global randomfrequency
    # no int() conversion here (unlike updatesourceblur), so fractional
    # slider values are kept as-is
    randomfrequency = val
    draw()


# Register widget callbacks. The handlers are module-level functions; the
# visible ones (e.g. updatesourceblur) store the new value in a global and
# call draw().
savefilenamebox.on_text_change(updatesavefilename)
savebutton.on_clicked(save)
# one change handler per parameter slider
thicknessslider.on_changed(updatethickness)
linesslider.on_changed(updatelines)
noiseslider.on_changed(updatenoise)
offsetscaleslider.on_changed(updateoffset)
hsizeslider.on_changed(updatehsize)
vsizeslider.on_changed(updatevsize)
hpadslider.on_changed(updatehpad)
vpadslider.on_changed(updatevpad)
sourceblurslider.on_changed(updatesourceblur)
randomfrequencyslider.on_changed(updaterandomfrequency)

# display the figure with its widgets
plt.show()
Esempio n. 5
0
def createGUI(fileName, variable, outFile, refFile, applyFile, nogui,
              overwrite):

    if not outFile:
        outFile = join(dirname(fileName), 'edit_' + basename(fileName))
    editsFile = splitext(outFile)[0] + '.txt'

    if fileName == outFile:
        error('Output filename must differ from input filename "{}". Exiting.'.
              format(fileName))

    if not overwrite:
        if os.path.exists(outFile) or os.path.exists(editsFile):
            error(
                '"{}" or "{}" already exists. To overwrite, use the --overwrite option.'
                .format(outFile, editsFile))

    # Open NetCDF files
    try:
        rg = Dataset(fileName, 'r')
    except:
        error('There was a problem opening input NetCDF file "' + fileName +
              '".')

    rgVar = rg.variables[variable]  # handle to the variable
    dims = rgVar.dimensions  # tuple of dimensions
    depth = rgVar[:]  # Read the data
    #depth = depth[0:600,0:600]
    (nj, ni) = depth.shape
    print('Range of input depths: min=', np.amin(depth), 'max=',
          np.amax(depth))

    ref = None
    if refFile:
        try:
            ref = Dataset(refFile, 'r').variables[variable][:]
        except:
            error('There was a problem opening reference NetCDF file "' +
                  refFile + '".')

    try:
        sg = Dataset('supergrid.nc', 'r')
        lon = sg.variables['x'][:]
        lon = lon[0:2 * nj + 1:2, 0:2 * ni + 1:2]
        lat = sg.variables['y'][:]
        lat = lat[0:2 * nj + 1:2, 0:2 * ni + 1:2]
    except:
        lon, lat = np.meshgrid(np.arange(ni + 1), np.arange(nj + 1))
    fullData = Topography(lon, lat, depth, ref)

    class Container:
        def __init__(self):
            self.view = None
            self.edits = None
            self.data = None
            self.quadMesh = None
            self.cbar = None
            self.ax = None
            self.syms = None
            self.useref = False
            self.textbox = None
            cdict = {
                'red': ((0.0, 0.0, 0.0), (0.5, 0.7, 0.0), (1.0, 0.9, 0.0)),
                'green': ((0.0, 0.0, 0.0), (0.5, 0.7, 0.2), (1.0, 1.0, 0.0)),
                'blue': ((0.0, 0.0, 0.2), (0.5, 1.0, 0.0), (1.0, 0.9, 0.0))
            }
            cdict_r = {
                'red': ((0.0, 0.0, 0.0), (0.497, 0.7, 0.0), (1.0, 0.9, 0.0)),
                'green': ((0.0, 0.0, 0.0), (0.497, 0.7, 0.2), (1.0, 1.0, 0.0)),
                'blue': ((0.0, 0.0, 0.2), (0.497, 1.0, 0.0), (1.0, 0.9, 0.0))
            }
            self.cmap1 = LinearSegmentedColormap('my_colormap', cdict, 256)
            self.cmap2 = LinearSegmentedColormap('my_colormap', cdict_r,
                                                 256).reversed()
            self.cmap3 = plt.get_cmap('seismic')
            self.cmap = self.cmap1
            self.prevcmap = self.cmap
            self.clim = 6000
            self.plotdiff = False
            self.fieldname = None

    All = Container()
    All.view = View(ni, nj)
    All.edits = Edits()

    # Read edit data, if it exists
    if 'iEdit' in rg.variables:
        jEdit = rg.variables['iEdit'][:]
        iEdit = rg.variables['jEdit'][:]
        zEdit = rg.variables['zEdit'][:]  # Original value of edited data
        for l, i in enumerate(iEdit):
            All.edits.setVal(fullData.height[iEdit[l], jEdit[l]])
            fullData.height[iEdit[l], jEdit[l]] = zEdit[l]  # Restore data
            All.edits.add(iEdit[l], jEdit[l])
    if applyFile:
        try:  # first try opening as a NetCDF
            apply = Dataset(applyFile, 'r')
            if 'iEdit' in apply.variables:
                jEdit = apply.variables['iEdit'][:]
                iEdit = apply.variables['jEdit'][:]
                zNew = apply.variables[variable]
                for l, i in enumerate(iEdit):
                    All.edits.add(iEdit[l], jEdit[l], zNew[iEdit[l], jEdit[l]])
            apply.close()
        except:
            try:  # if that fails, try opening as a text file
                with open(applyFile, 'rt') as edFile:
                    edCount = 0
                    line = edFile.readline()
                    version = None
                    if line.startswith('editTopo.py'):
                        version = line.strip().split()[-1]
                    if version is None:
                        # assume this is in comma-delimited format ii, jj, zNew # comment
                        # where ii, jj may be integers or start:end inclusive integer ranges,
                        # indexed counting from 1
                        while line:
                            linedata = line.strip().split('#')[0].strip()
                            if linedata:
                                jEdits, iEdits, zNew = linedata.split(
                                    ',')  # swaps meaning of i & j
                                iEdits = [
                                    int(x) for x in iEdits.strip().split(':')
                                ]
                                jEdits = [
                                    int(x) for x in jEdits.strip().split(':')
                                ]
                                zNew = float(zNew.strip())
                                for ed in [iEdits, jEdits]:
                                    if len(ed) == 1:
                                        ed.append(ed[0] + 1)
                                    elif len(ed) == 2:
                                        ed[1] += 1
                                    else:
                                        raise ValueError
                                for i in range(*iEdits):
                                    for j in range(*jEdits):
                                        All.edits.add(
                                            i - 1, j - 1, zNew
                                        )  # -1 because ii, jj count from 1
                                        edCount += 1
                            line = edFile.readline()
                    elif version == '1':
                        # whitespace-delimited format jEdit iEdit zOld zNew # comment
                        # where ii, jj are integer indices counting from 0
                        while line:
                            line = edFile.readline()
                            linedata = line.strip().split('#')[0].strip()
                            if linedata:
                                jEdit, iEdit, _, zNew = linedata.split(
                                )  # swap meaning of i & j; ignore zOld
                                iEdit = int(iEdit)
                                jEdit = int(jEdit)
                                zNew = float(zNew)
                                All.edits.add(iEdit, jEdit, zNew)
                                edCount += 1
                    else:
                        error('Unsupported version "{}" in "{}".'.format(
                            version, applyFile))
                    print('Applied {} cell edits from "{}".'.format(
                        edCount, applyFile))
            except:
                error('There was a problem applying edits from "' + applyFile +
                      '".')

    All.data = fullData.cloneWindow((All.view.i0, All.view.j0),
                                    (All.view.iw, All.view.jw))
    All.fieldname = All.data.fieldnames[0]
    if All.edits.ijz:
        All.data.applyEdits(fullData, All.edits.ijz)

    # A mask based solely on value of depth
    # notLand = np.where( depth<0, 1, 0)
    # wet = ice9it(600,270,depth)

    # plt.rcParams['toolbar'] = 'None'  # don't use - also disables statusbar

    def replot(All):
        if All.cbar is not None:
            All.cbar.remove()
        h = plt.pcolormesh(All.data.longitude,
                           All.data.latitude,
                           All.data.plotfield,
                           cmap=All.cmap,
                           vmin=-All.clim,
                           vmax=All.clim)
        hc = plt.colorbar()
        return (h, hc)

    All.quadMesh, All.cbar = replot(All)
    All.syms = All.edits.plot(fullData)
    dir(All.syms)
    All.ax = plt.gca()
    All.ax.set_xlim(All.data.xlim)
    All.ax.set_ylim(All.data.ylim)

    if fullData.haveref:

        def setsource(label):
            All.fieldname = label
            All.data.plotfield = All.data.fields[All.fieldname]
            All.plotdiff = All.fieldname == All.data.fieldnames[2]
            if All.plotdiff and All.cmap != All.cmap3:
                All.prevcmap = All.cmap
                All.cmap = All.cmap3
            else:
                All.cmap = All.prevcmap
            All.quadMesh.set_cmap(All.cmap)
            All.cbar.mappable.set_cmap(All.cmap)
            All.quadMesh.set_array(All.data.plotfield.ravel())
            plt.draw()

        sourcebuttons = RadioButtons(plt.axes([.88, .4, 0.12, 0.15]),
                                     All.data.fieldnames)
        sourcebuttons.on_clicked(setsource)

    def setDepth(str):
        try:
            All.edits.setVal(float(str))
        except:
            pass

    tbax = plt.axes([0.12, 0.01, 0.3, 0.05])
    textbox = TextBox(tbax, 'set depth', '0')
    textbox.on_submit(setDepth)
    textbox.on_text_change(setDepth)

    def nothing(x, y):
        return ''

    tbax.format_coord = nothing  # stop status bar displaying coords in textbox
    All.textbox = textbox
    if fullData.haveref:
        All.useref = True
        userefcheck = CheckButtons(plt.axes([0.42, 0.01, 0.11, 0.05]),
                                   ['use ref'], [All.useref])

        def setuseref(_):
            All.useref = userefcheck.get_status()[0]
            if not All.useref:
                All.edits.setVal(0.0)
                All.textbox.set_val(repr(All.edits.newDepth))

        userefcheck.on_clicked(setuseref)
    else:
        All.useref = False

    lowerButtons = Buttons(left=.9)

    def undoLast(event):
        All.edits.pop()
        All.data = fullData.cloneWindow((All.view.i0, All.view.j0),
                                        (All.view.iw, All.view.jw),
                                        fieldname=All.fieldname)
        All.data.applyEdits(fullData, All.edits.ijz)
        All.quadMesh.set_array(All.data.plotfield.ravel())
        All.edits.updatePlot(fullData, All.syms)
        plt.draw()

    lowerButtons.add('Undo', undoLast)

    upperButtons = Buttons(bottom=1 - .0615)

    def colorScale(event):
        Levs = [50, 100, 200, 500, 1000, 2000, 3000, 4000, 5000, 6000]
        i = Levs.index(All.clim)
        if event == ' + ':
            i = min(i + 1, len(Levs) - 1)
        elif event == ' - ':
            i = max(i - 1, 0)
        elif event == 'Flip' and not All.plotdiff:
            if All.cmap == All.cmap1:
                All.cmap = All.cmap2
            else:
                All.cmap = All.cmap1
        All.clim = Levs[i]
        All.quadMesh.set_clim(vmin=-All.clim, vmax=All.clim)
        All.quadMesh.set_cmap(All.cmap)
        All.cbar.mappable.set_clim(vmin=-All.clim, vmax=All.clim)
        All.cbar.mappable.set_cmap(All.cmap)
        plt.draw()

    def moveVisData(di, dj):
        All.view.move(di, dj)
        All.data = fullData.cloneWindow((All.view.i0, All.view.j0),
                                        (All.view.iw, All.view.jw),
                                        fieldname=All.fieldname)
        All.data.applyEdits(fullData, All.edits.ijz)
        plt.sca(All.ax)
        plt.cla()
        All.quadMesh, All.cbar = replot(All)
        All.ax.set_xlim(All.data.xlim)
        All.ax.set_ylim(All.data.ylim)
        All.syms = All.edits.plot(fullData)
        plt.draw()

    def moveWindowLeft(event):
        moveVisData(-1, 0)

    upperButtons.add('West', moveWindowLeft)

    def moveWindowRight(event):
        moveVisData(1, 0)

    upperButtons.add('East', moveWindowRight)

    def moveWindowDown(event):
        moveVisData(0, -1)

    upperButtons.add('South', moveWindowDown)

    def moveWindowUp(event):
        moveVisData(0, 1)

    upperButtons.add('North', moveWindowUp)
    climButtons = Buttons(bottom=1 - .0615, left=0.75)

    def incrCScale(event):
        colorScale(' + ')

    climButtons.add(' + ', incrCScale)

    def decrCScale(event):
        colorScale(' - ')

    climButtons.add(' - ', decrCScale)

    def revcmap(event):
        colorScale('Flip')

    climButtons.add('Flip', revcmap)
    plt.sca(All.ax)

    def onClick(event):  # Mouse button click
        if event.inaxes == All.ax and event.button == 1 and event.xdata:
            # left click: edit point
            (i, j) = findPointInMesh(fullData.longitude, fullData.latitude,
                                     event.xdata, event.ydata)
            if i is not None:
                (I, J) = findPointInMesh(All.data.longitude, All.data.latitude,
                                         event.xdata, event.ydata)
                if event.dblclick:
                    nVal = -99999
                    if All.data.height[I + 1, J] < 0:
                        nVal = max(nVal, All.data.height[I + 1, J])
                    if All.data.height[I - 1, J] < 0:
                        nVal = max(nVal, All.data.height[I - 1, J])
                    if All.data.height[I, J + 1] < 0:
                        nVal = max(nVal, All.data.height[I, J + 1])
                    if All.data.height[I, J - 1] < 0:
                        nVal = max(nVal, All.data.height[I, J - 1])
                    if nVal == -99999:
                        return
                    All.edits.add(i, j, nVal)
                    All.data.height[I, J] = nVal
                else:
                    All.edits.add(i, j)
                    All.data.height[I, J] = All.edits.get()
                if All.data.haveref:
                    All.data.diff[I,
                                  J] = All.data.height[I, J] - All.data.ref[I,
                                                                            J]
                All.quadMesh.set_array(All.data.plotfield.ravel())
                All.edits.updatePlot(fullData, All.syms)
                plt.draw()
        elif event.inaxes == All.ax and event.button == 3 and event.xdata:
            # right click: undo edit
            (i, j) = findPointInMesh(fullData.longitude, fullData.latitude,
                                     event.xdata, event.ydata)
            if i is not None:
                All.edits.delete(i, j)
                All.data = fullData.cloneWindow((All.view.i0, All.view.j0),
                                                (All.view.iw, All.view.jw),
                                                fieldname=All.fieldname)
                All.data.applyEdits(fullData, All.edits.ijz)
                All.quadMesh.set_array(All.data.plotfield.ravel())
                All.edits.updatePlot(fullData, All.syms)
                plt.draw()
        elif event.inaxes == All.ax and event.button == 2 and event.xdata:
            zoom(event)  # Re-center

    plt.gcf().canvas.mpl_connect('button_press_event', onClick)

    def zoom(event):  # Scroll wheel up/down
        if event.button == 'up':
            scale_factor = 1 / 1.5  # deal with zoom in
        elif event.button == 'down':
            scale_factor = 1.5  # deal with zoom out
        else:
            scale_factor = 1.0
        new_xlim, new_ylim = newLims(All.ax.get_xlim(), All.ax.get_ylim(),
                                     (event.xdata, event.ydata), All.data.xlim,
                                     All.data.ylim, All.view.ni, All.view.nj,
                                     scale_factor)
        if new_xlim is None:
            return  # No change in limits
        All.view.seti(new_xlim)
        All.view.setj(new_ylim)
        All.data = fullData.cloneWindow((All.view.i0, All.view.j0),
                                        (All.view.iw, All.view.jw),
                                        fieldname=All.fieldname)
        All.data.applyEdits(fullData, All.edits.ijz)
        plt.sca(All.ax)
        plt.cla()
        All.quadMesh, All.cbar = replot(All)
        # All.ax.set_xlim(All.data.xlim)
        # All.ax.set_ylim(All.data.ylim)
        All.syms = All.edits.plot(fullData)
        All.ax.set_xlim(new_xlim)
        All.ax.set_ylim(new_ylim)
        # All.cbar.mappable.set_clim(vmin=-All.clim, vmax=All.clim)
        # All.cbar.mappable.set_cmap(All.cmap)
        plt.draw()  # force re-draw

    plt.gcf().canvas.mpl_connect('scroll_event', zoom)

    def statusMesg(x, y):
        j, i = findPointInMesh(fullData.longitude, fullData.latitude, x, y)
        if All.useref:
            All.textbox.set_val(repr(
                fullData.ref[j, i]))  # callback calls All.edits.setVal
        if i is not None:
            height = fullData.height[j, i]
            newval = All.edits.getEdit(j, i)
            if newval is not None:
                return 'depth(%i,%i) = %g (was %g)      depth - set depth = %g' % \
                        (i, j, newval, height, newval - All.edits.newDepth)
            else:
                return 'depth(%i,%i) = %g      depth - set depth = %g' % \
                        (i, j, height, height - All.edits.newDepth)
        else:
            return 'new depth = %g' % \
                    (All.edits.newDepth)

    All.ax.format_coord = statusMesg

    if not nogui:
        print("""
Ignore all the controls in the toolbar at the top of the window.
Zoom in and out with the scroll wheel.
Pan the view with the North, South, East and West buttons.
Use +, -, Flip buttons to modify the colormap.
Set the prescribed depth with the textbox at the bottom.
Left click on a cell to apply the prescribed depth value.
Right click on a cell to reset to the original value.
Double left click to assign the highest of the 4 nearest points with depth<0.
Close the window to write the edits to the output NetCDF file,
and also to a .txt file.
""")
        plt.show()


# The following is executed after GUI window is closed
# All.edits.list()
    if not outFile == ' ':
        print('Made %i edits.' % (len(All.edits.ijz)))
        print('Writing edited topography to "' + outFile + '".')
        # Create new netcdf file
        if not fileName == outFile:
            sh.copyfile(fileName, outFile)
        try:
            rg = Dataset(outFile, 'r+')
        except:
            error('There was a problem opening "' + outFile + '".')
        rgVar = rg.variables[variable]  # handle to the variable
        dims = rgVar.dimensions  # tuple of dimensions
        rgVar[:] = fullData.height[:, :]  # Write the data
        if All.edits.ijz:
            # print('Applying %i edits' % (len(All.edits.ijz)))
            if 'nEdits' in rg.dimensions:
                numEdits = rg.dimensions['nEdits']
            else:
                numEdits = rg.createDimension('nEdits',
                                              0)  # len(All.edits.ijz))
            if 'iEdit' in rg.variables:
                iEd = rg.variables['iEdit']
            else:
                iEd = rg.createVariable('iEdit', 'i4', ('nEdits', ))
                iEd.long_name = 'i-index of edited data'
            if 'jEdit' in rg.variables:
                jEd = rg.variables['jEdit']
            else:
                jEd = rg.createVariable('jEdit', 'i4', ('nEdits', ))
                jEd.long_name = 'j-index of edited data'
            if 'zEdit' in rg.variables:
                zEd = rg.variables['zEdit']
            else:
                zEd = rg.createVariable('zEdit', 'f4', ('nEdits', ))
                zEd.long_name = 'Original value of edited data'
                try:
                    zEd.units = rgVar.units
                except AttributeError:
                    zEd.units = 'm'
            hist_str = 'made %i changes (i, j, old, new): ' % len(
                All.edits.ijz)
            for l, (i, j, z) in enumerate(All.edits.ijz):
                if l > 0:
                    hist_str += ', '
                iEd[l] = j
                jEd[l] = i
                zEd[l] = rgVar[i, j]
                rgVar[i, j] = z
                hist_str += repr((j, i, zEd[l].item(), rgVar[i, j].item()))
            print(hist_str.replace(': ', ':\n').replace('), ', ')\n'))
            hist_str = time.ctime(time.time()) + ' ' \
                + ' '.join(sys.argv) \
                + ' ' + hist_str
            if 'history' not in rg.ncattrs():
                rg.history = hist_str
            else:
                rg.history = rg.history + ' | ' + hist_str
        # write editsFile even if no edits, so editsFile will match outFile
        print('Writing list of edits to text file "' + editsFile +
              '" (this can be used with --apply).')
        try:
            with open(editsFile, 'wt') as edfile:
                edfile.write('editTopo.py edits file version 1\n')
                edfile.write(
                    '#\n# This file can be used as an argument for editTopo.py --apply\n#\n'
                )
                edfile.write('# created: ' + time.ctime(time.time()) + '\n')
                edfile.write('# by: ' + pwd.getpwuid(os.getuid()).pw_name +
                             '\n')
                edfile.write('# via: ' + ' '.join(sys.argv) + '\n#\n')
                if All.edits.ijz:
                    ii, jj, _ = zip(*All.edits.ijz)
                    news = [rgVar[i, j].item() for (i, j, _) in All.edits.ijz]
                    olds = [
                        fullData.height[i, j].item()
                        for (i, j, _) in All.edits.ijz
                    ]
                    iiwidth = max([len(repr(x)) for x in ii], default=0) + 2
                    jjwidth = max([len(repr(x)) for x in jj], default=0) + 2
                    oldwidth = max([len(repr(x)) for x in olds], default=0) + 2
                    edfile.write('# ' + \
                                 'i'.rjust(jjwidth-2) +  # swaps meaning of i & j
                                 'j'.rjust(iiwidth) +    # ditto
                                 '  ' +
                                 'old'.ljust(oldwidth) +
                                 'new' + '\n')
                    for (i, j, old, new) in zip(ii, jj, olds, news):
                        edfile.write(
                            repr(j).rjust(jjwidth) +  # swaps meaning of i & j
                            repr(i).rjust(iiwidth) +  # ditto
                            '  ' + repr(old).ljust(oldwidth) + repr(new) +
                            '\n')
                else:
                    edfile.write('#    i    j    old    new\n')
        except:
            error('There was a problem creating "' + editsFile + '".')
        rg.close()
Esempio n. 6
0
def main():
    """Build the curve-editor figure, wire up all of its widgets and run the GUI.

    Side effects: stores the parsed command line in ``gc_args``, (re)creates the
    module-level ``gc_fig``/``gc_ax`` and the initial flat curve
    ``g_curve_x``/``g_curve_y``, and replaces every entry of ``gc_sl`` with an
    ``(index, VertSlider)`` pair.  Blocks in ``plt.show()`` until the window
    is closed.
    """
    global gc_args, g_curve_x, g_curve_y, gc_fig, gc_ax

    gc_args = retrieve_args()

    # Initial curve: gc_sl_n evenly spaced samples, all at the slider start value.
    g_curve_x = np.linspace(g_left_x, g_right_x, gc_sl_n)
    g_curve_y = np.array([g_sl_valinit] * gc_sl_n)

    gc_fig = plt.figure(figsize=(12, 12))
    gc_ax = gc_fig.add_subplot(111)

    # Layout constants, in figure-fraction coordinates.
    face, hover = 'lightgoldenrodyellow', '0.975'
    slider_dx, slider_left, slider_bottom, slider_width, slider_height = (
        0.163, 0.23, 0.05, 0.02, 0.4)
    button_height, button_top, button_dy = 0.04, 0.62, 0.06
    panel_left, panel_width = 0.025, 0.15
    text_top, text_width, text_height, text_dx = 0.91, 0.279, 0.029, 0.298
    label_color = 'brown'

    def labelled_box(rect, label, initial, on_change=None):
        """Add a TextBox in new axes at *rect*, hook *on_change*, colour its label."""
        box = TextBox(gc_fig.add_axes(rect), label, initial=initial)
        if on_change is not None:
            box.on_text_change(on_change)
        box.label.set_color(label_color)
        return box

    def panel_button(bottom, caption, on_click):
        """Add a Button to the left-hand panel at height *bottom* and hook *on_click*."""
        rect = [panel_left, bottom, panel_width, button_height]
        button = Button(gc_fig.add_axes(rect), caption,
                        color=face, hovercolor=hover)
        button.on_clicked(on_click)
        return button

    # Leave room under / left of the plot for the sliders and the control panel.
    gc_fig.subplots_adjust(left=slider_left, bottom=0.50)

    for idx, _ in enumerate(gc_sl):
        rect = [slider_left + idx * slider_dx, slider_bottom,
                slider_width, slider_height]
        slider_ax = gc_fig.add_axes(rect, facecolor=face)
        gc_sl[idx] = idx, VertSlider(slider_ax, f"S{idx}:", g_bottom_y,
                                     g_top_y, valinit=g_sl_valinit)

    plot_curve()

    for _, vslider in gc_sl:
        vslider.on_changed(sliders_on_changed)

    # Path bar across the top, showing the current directory (no callback).
    tp = labelled_box([panel_left, text_top + 0.035, 0.875, text_height],
                      '', Path.cwd())
    # Include target directory.
    ti = labelled_box([panel_left, text_top, text_width, text_height],
                      'I', g_inc, inc_textbox_on_text_change)
    # Source target directory.
    ts = labelled_box([panel_left + text_dx, text_top, text_width, text_height],
                      'S', g_src, src_textbox_on_text_change)
    # File and function name stem.
    tn = labelled_box(
        [panel_left + text_dx * 2, text_top, text_width, text_height],
        'N', g_name, name_textbox_on_text_change)
    tn.text_disp.set_color('black')

    if gc_args.function_name:
        tn.set_val(Path(gc_args.function_name).resolve().stem)

    # Function range.
    tr = labelled_box([panel_left, 0.852, panel_width, text_height],
                      'R', str(g_range), range_textbox_on_text_change)

    br = panel_button(button_top - button_dy * 2, 'Reset',
                      reset_button_on_clicked)
    be = panel_button(button_top, 'Export', export_button_on_clicked)
    bi = panel_button(button_top - button_dy, 'Import',
                      import_button_on_clicked)

    rbd = RadioButtons(
        gc_fig.add_axes([panel_left, 0.68, panel_width, 0.15], facecolor=face),
        ('128', '256', '512', '1024'),
        active=0)
    rbd.on_clicked(domain_radios_on_clicked)

    # Widgets above stay referenced by these locals while the GUI runs.
    plt.show()
Esempio n. 7
0
class SPPEditor:
    """Interactive matplotlib editor for marking points on an (x, y) curve.

    Over the plot axes, press ``i`` to insert a marker at the cursor and ``d``
    to delete the marker under the cursor (within :attr:`epsilon` pixels).
    A text box below the plot (labelled "Delay [fs]") takes a single numeric
    delay value, stored in ``self.delay`` once it parses.
    """

    epsilon = 7  # max absolute pixel distance to count as a hit

    # Everything except digits, '.', ',' and '-' is stripped from the delay
    # text before float() (e.g. trailing units).
    _DELAY_JUNK = re.compile(r"[^0-9\.,\-]")

    def __init__(self, x, y, info=None, x_pos=None, y_pos=None):
        """Open the editor figure.

        Args:
            x, y: Curve to display.
            info: Optional figure title; non-strings are converted with ``str``.
            x_pos, y_pos: Optional initial marker coordinates; if either is
                None, the editor starts with no markers.
        """
        plt.ion()
        self.fig, self.ax = plt.subplots()
        if info is not None:
            if not isinstance(info, str):
                info = str(info)
            self.fig.suptitle(info, fontsize=16)
        self.fig.set_figheight(6)
        self.fig.set_figwidth(10)
        plt.grid()
        plt.subplots_adjust(bottom=0.2)  # room for the text box
        self.x = x
        self.y = y
        self._ind = None  # the active point index
        (self.basedata, ) = self.ax.plot(self.x, self.y)
        if (x_pos is None) or (y_pos is None):
            x_pos = []
            y_pos = []
        self._setup_positions(x_pos, y_pos)
        self.fig.canvas.mpl_connect("key_press_event", self.key_press_callback)
        self.fig.canvas.mpl_connect("button_release_event",
                                    self.button_release_callback)
        self.axbox = plt.axes([0.1, 0.05, 0.8, 0.1])
        self.text_box = TextBox(self.axbox,
                                "Delay [fs]",
                                initial="",
                                color="silver",
                                hovercolor="whitesmoke")
        self.text_box.on_submit(self.submit)
        self.text_box.on_text_change(self.text_change)

    def _setup_positions(self, x_pos, y_pos):
        """Store the marker coordinates and draw them as black dots."""
        self.x_pos, self.y_pos = x_pos, y_pos
        (self.points, ) = self.ax.plot(self.x_pos, self.y_pos, "ko")

    def _show(self):
        plt.show(block=True)

    def _get_textbox(self):
        return self.text_box

    def _get_ax(self):
        return self.ax

    def submit(self, delay):
        """Parse the text-box string *delay* into ``self.delay``.

        Non-numeric characters are stripped first. If the rest still does not
        parse as a float (empty, or a comma decimal separator), the previous
        value is kept.
        """
        try:
            self.delay = float(self._DELAY_JUNK.sub("", delay))
        except ValueError:
            pass

    def text_change(self, delay):
        """Live-update callback; parses exactly like :meth:`submit`."""
        self.submit(delay)

    # TODO: Config should decide how to treat missing values.
    def get_data(self):
        """Return ``(delays, positions)`` for the current markers.

        ``positions`` holds the markers' x-coordinates as an array. ``delays``
        holds the entered delay repeated once per marker. With no markers,
        the plain scalar delay is returned. If no delay has been entered yet,
        two empty arrays are returned.
        """
        positions, _ = self.points.get_data()
        # Line2D.get_data() hands back whatever was passed to set_data, which
        # is a plain list for the initial (possibly empty) marker set.
        positions = np.asarray(positions)
        if not hasattr(self, "delay"):
            return np.array([]), np.array([])
        if positions.size != 0:
            # Build the per-marker array locally. Overwriting self.delay with
            # it broke later calls once the number of markers changed.
            return np.ones_like(positions) * self.delay, positions
        return self.delay, positions

    def button_release_callback(self, event):
        """whenever a mouse button is released"""
        if event.button != 1:
            return
        self._ind = None

    def get_ind_under_point(self, event):
        """
        Get the index of the selected point within the given epsilon tolerance.

        Returns None if the event is outside the axes, there are no markers,
        or the nearest marker is at least ``epsilon`` pixels away.
        """

        # We use the pixel coordinates, because the axes are usually really
        # differently scaled.
        if event.inaxes is None:
            return
        if event.inaxes in [self.ax]:
            try:
                x, y = self.points.get_data()
                xy_pixels = self.ax.transData.transform(np.vstack([x, y]).T)
                xpix, ypix = xy_pixels.T

                # return the index of the point iff within epsilon distance.
                d = np.hypot(xpix - event.x, ypix - event.y)
                (indseq, ) = np.nonzero(d == d.min())
                ind = indseq[0]

                if d[ind] >= self.epsilon:
                    ind = None
            except ValueError:
                # d.min() on an empty array (no markers yet)
                return
            return ind

    def key_press_callback(self, event):
        """Handle 'd' (delete marker under cursor) and 'i' (insert marker)."""
        if not event.inaxes:
            return
        if event.key == "d":
            if event.inaxes in [self.ax]:
                ind = self.get_ind_under_point(event)
            else:
                ind = None
            if ind is not None:
                self.x_pos = np.delete(self.x_pos, ind)
                self.y_pos = np.delete(self.y_pos, ind)
                self.points.set_data(self.x_pos, self.y_pos)

        elif event.key == "i":
            if event.inaxes in [self.ax]:
                self.x_pos = np.append(self.x_pos, event.xdata)
                self.y_pos = np.append(self.y_pos, event.ydata)
                self.points.set_data(self.x_pos, self.y_pos)

        if self.points.stale:
            self.fig.canvas.draw_idle()
Esempio n. 8
0
class Table(object):
    """Large 2-D integer table generated from a one-dimensional symbol sequence.

    A symbol sequence (``self.genome``) is generated from ``code`` according
    to ``stype`` (see :meth:`generate`). The C helper ``fill`` from
    ``./fill.so`` then writes an index into that sequence for every cell of a
    ``VD x HD`` table, and the table is finally mapped through the sequence.
    The original author described it as "every row is being RL of the
    previous one". Presumably each row is a run-length transform of the row
    above; TODO confirm against the C source.
    """
    # NOTE(review): the shared library is loaded when the class is defined,
    # i.e. at import time, relative to the current working directory.
    cfill = LL('./fill.so').fill
    cfill.argtypes = [
        ctypes.c_int,                                          # genome size
        ndpointer(ctypes.c_int, ndim=1),                       # genome
        ctypes.c_int,                                          # VD (rows)
        ctypes.c_int,                                          # HD (columns)
        ndpointer(ctypes.c_int, ndim=2, flags='CONTIGUOUS'),   # output table
    ]
    # The return value is never used. The attribute is ``restype``; the
    # previous ``restypes`` spelling was silently ignored by ctypes.
    cfill.restype = None

    VD = 5000  # number of rows of the table
    HD = 2000  # number of columns of the table

    def __init__(self, code=None, stype=0.):
        """Generate the symbol sequence and fill the table.

        Args:
            code: Symbols the sequence is built from; defaults to ``[1, 2]``.
            stype: ``'a'`` or ``'d'`` selects one of the two rules in
                :meth:`_symbol`; any other value selects the fallback rule.
        """
        super().__init__()
        self.code = [1, 2] if code is None else code
        self.stype = stype
        self.field = np.zeros((self.VD, self.HD), dtype='int32')
        self.generate()
        # cfill writes into self.field in place; its entries are then used as
        # indices into the genome to obtain the final symbols.
        self.cfill(self.genome.size, self.genome, self.VD, self.HD, self.field)
        self.field = self.genome[self.field]

    def _symbol(self, i, n_sym):
        """Return the symbol at 0-based position *i* for the rule in ``self.stype``."""
        if self.stype == 'a':
            j = i + 1
            # Strip all factors of (n_sym + 1); the last base-(n_sym+1)
            # digit that remains (1..n_sym) picks the symbol.
            while j % (n_sym + 1) == 0:
                j //= n_sym + 1
            return self.code[j % (n_sym + 1) - 1]
        if self.stype == 'd':
            j = i + 1
            # Remove all factors of two, then use the remaining odd part.
            while j % 2 == 0:
                j //= 2
            return self.code[((j - 1) // 2) % n_sym]
        # NOTE(review): ``p`` is not defined in this class; it must be a
        # module-level global -- confirm. If int(p * i) % n_sym never changes
        # value, generate() cannot terminate.
        return self.code[int(p * i) % n_sym]

    def generate(self):
        """Build ``self.genome``, an int32 array of symbols from ``self.code``.

        Symbols are appended until the sequence has changed value
        ``VD + HD`` times. The leading ``-1`` sentinel makes the first
        appended symbol count as a change, and it is dropped at the end.
        """
        n_sym = len(self.code)
        target = self.VD + self.HD
        genome = [-1]
        changes = 0
        i = -1
        while changes < target:
            i += 1
            genome.append(self._symbol(i, n_sym))
            if genome[-1] != genome[-2]:
                changes += 1
        self.genome = np.array(genome[1:], dtype='int32')

    def __str__(self):
        """Return the top-left 10 x 50 corner of the table as space-separated text."""
        rows = []
        for i in range(10):
            cells = ''.join(str(self.field[i, j]) + ' ' for j in range(50))
            rows.append(cells + '\n')
        return ''.join(rows)

    def plot(self, x=0, y=0):
        """Build the interactive viewer with its top-left cell at column *x*, row *y*.

        Creates a scale slider, the image axes with a colorbar, a step text box
        and four arrow buttons for panning, then draws the first view.
        """
        self.hoffset = x
        self.voffset = y
        self.step = 1

        axscale = plt.axes([0.25, 0.1, .55, 0.03], facecolor="green")
        # The visible window is int(10**scale) cells wide.
        self.slscale = Slider(axscale,
                              'Scale',
                              1,
                              2,
                              valinit=1.2,
                              valstep=0.01)
        self.slscale.on_changed(self.update)
        d = int(10**self.slscale.val)
        self.imsize = d

        self.ax = plt.axes([0.25, 0.2, .65, 0.7])
        self.im = self.ax.imshow([[.0]],
                                 cmap=LinearSegmentedColormap.from_list(
                                     "thiscmap",
                                     colorlist[:max(self.code)],
                                     N=25))
        self.cbar = self.ax.figure.colorbar(self.im, ax=self.ax)
        self.cbar.set_ticks(self.code)

        axright = plt.axes([0.14, 0.35, 0.04, 0.04])
        axleft = plt.axes([0.06, 0.35, 0.04, 0.04])
        axup = plt.axes([0.1, 0.39, 0.04, 0.04])
        axdown = plt.axes([0.1, 0.31, 0.04, 0.04])
        axstep = plt.axes([0.1, 0.35, 0.04, 0.04])
        self.tbstep = TextBox(axstep, ' ', initial=' 1')
        self.tbstep.on_text_change(self.upstep)
        self.bright = Button(axright,
                             r'$\rightarrow$',
                             color="red",
                             hovercolor='green')
        self.bleft = Button(axleft,
                            r'$\leftarrow$',
                            color="red",
                            hovercolor='green')
        self.bup = Button(axup, r'$\uparrow$', color="red", hovercolor='green')
        self.bdown = Button(axdown,
                            r'$\downarrow$',
                            color="red",
                            hovercolor='green')
        self.bright.on_clicked(self.fright)
        self.bleft.on_clicked(self.fleft)
        self.bup.on_clicked(self.fup)
        self.bdown.on_clicked(self.fdown)

        self.update(0)

    def upstep(self, txt):
        """Text-box callback: set the panning step from the typed text.

        Incomplete or invalid input (e.g. an empty box while typing) is
        ignored and the previous step kept. This used to ``eval`` the text,
        which executed arbitrary expressions and raised on partial input.
        """
        try:
            self.step = int(float(txt))
        except (ValueError, OverflowError):
            pass

    def fleft(self, val):
        """Pan left by ``self.step`` columns, clamped at column 0."""
        self.hoffset = max(self.hoffset - self.step, 0)
        self.update(val)

    def fright(self, val):
        """Pan right by ``self.step`` columns, keeping the window inside the table."""
        self.hoffset = min(self.HD - self.imsize, self.hoffset + self.step)
        self.update(val)

    def fup(self, val):
        """Pan up by ``self.step`` rows, clamped at row 0."""
        self.voffset = max(self.voffset - self.step, 0)
        self.update(val)

    def fdown(self, val):
        """Pan down by ``self.step`` rows, keeping the window inside the table."""
        self.voffset = min(self.VD - self.imsize, self.voffset + self.step)
        self.update(val)

    def update(self, val):
        """Redraw the d x d window (d = int(10**scale)) at (voffset, hoffset).

        *val* is the widget callback value and is unused.
        """
        d = int(10**self.slscale.val)
        self.imsize = d
        x = self.hoffset
        y = self.voffset
        # Ticks are spread over a unit-wide span; presumably this matches the
        # extent fixed by the initial 1x1 imshow (set_data keeps the extent).
        sideticks = [.5 / d - .5 + i / d for i in range(d)]

        self.ax.set_xticks(sideticks)
        self.ax.set_yticks(sideticks)
        # Only the first and last ticks are labelled.
        self.ax.set_xticklabels([str(x)] + [' '] * (d - 2) + [str(x + d)])
        self.ax.set_yticklabels([str(y)] + [' '] * (d - 2) + [str(y + d)])
        self.im.set_data(self.field[y:y + d, x:x + d])
        self.im.autoscale()

        plt.draw()

    def calibrate(self):
        """Interactively calibrate an image: pick 4 corners, then enter its size.

        NOTE(review): this method uses ``self.img`` and ``self.anounce``, which
        ``Table`` does not define. It appears to belong to another class;
        confirm.

        Sets ``self.corners`` (4 rounded [x, y] points, ordered top-left,
        top-right, bottom-left, bottom-right) and ``self.width`` /
        ``self.height`` (metres, or None until a valid number is typed).
        """

        plt.clf()
        # NOTE(review): converts self.img in place, so a second call would
        # swap the channels back.
        self.img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)
        plt.imshow(self.img)

        self.anounce(
            'You will define a rectangle in the image \n by selecting its corners, click to begin'
        )
        plt.waitforbuttonpress()

        while True:
            pts = []
            # Re-prompt until all four corners are clicked. The fill below
            # indexes rows 2 and 3, so fewer points would raise IndexError.
            while len(pts) < 4:
                self.anounce(
                    'Select 4 corners of the rectangle in the following order:\n top left, top right, bottom left, bottom right'
                )
                pts = np.asarray(plt.ginput(4, timeout=-1))
                if len(pts) < 4:
                    self.anounce('Too few points, starting over')

            self.anounce('Happy? Key click for yes, mouse click for no')

            # Reorder to TL, TR, BR, BL so the filled polygon does not self-cross.
            fill_pts = np.copy(pts)
            fill_pts[[2, 3]] = fill_pts[[3, 2]]
            print(pts)
            print(fill_pts)
            ph = plt.fill(fill_pts[:, 0], fill_pts[:, 1], 'r', lw=2)

            # waitforbuttonpress() is True for a key press, False for a click.
            if plt.waitforbuttonpress():
                break

            for p in ph:
                p.remove()

        pts = pts.tolist()
        for pt in pts:
            pt[0] = round(pt[0])
            pt[1] = round(pt[1])
        self.corners = pts

        # No valid size has been entered yet (both boxes start empty).
        self.width = None
        self.height = None

        def _parse_metres(text):
            """Return *text* as a float, or None if it is not (yet) a number."""
            try:
                return float(text)
            except ValueError:
                return None

        def submitWidth(text):
            self.width = _parse_metres(text)

        def submitHeight(text):
            self.height = _parse_metres(text)

        def submit(event):
            if self.width is None or self.height is None:
                self.anounce('Please enter a numeric width and height before submitting')
                return
            print("Corners: {}".format(self.corners))
            print("Width is {}m".format(self.width))
            print("Height is {}m".format(self.height))
            print("Calibration is complete")
            plt.close('all')

        self.anounce(
            "Enter the width and height (in metres) in the fields below \n Width: top left corner to top right corner \n Height: top left corner to bottom left corner"
        )
        width_box = TextBox(plt.axes([0.2, 0.02, 0.2, 0.05]),
                            'Width:',
                            initial='')
        width_box.on_text_change(submitWidth)
        height_box = TextBox(plt.axes([0.6, 0.02, 0.2, 0.05]),
                             'Height:',
                             initial='')
        height_box.on_text_change(submitHeight)

        submit_button = Button(plt.axes([0.85, 0.02, 0.1, 0.05]), 'Submit')
        submit_button.on_clicked(submit)

        plt.show()
Esempio n. 10
0
def play_ani(id):
    """Play-button callback: resume the module-level ``animation``.

    The button event passed as *id* is ignored.
    """
    source = animation.event_source
    source.start()


def pause_ani(id):
    """Pause-button callback: halt the module-level ``animation``.

    The button event passed as *id* is ignored.
    """
    source = animation.event_source
    source.stop()


# Hook each slider, text box and play/pause button to its update callback.
for _connect, _handler in (
        (s_E.on_changed, update_e),
        (s_L.on_changed, update_l),
        (s_rp.on_changed, update_rp),
        (s_e.on_changed, update_ecc),
        (text_bot_E.on_text_change, submit_E),
        (text_bot_L.on_text_change, submit_L),
        (pause.on_clicked, pause_ani),
        (play.on_clicked, play_ani),
):
    _connect(_handler)
del _connect, _handler


# Method defining clear plots used for blit mapping in animation
def init_func():
    a2.set_data([], [])
    orbit_dot.set_data([], [])
    orbit.set_data([], [])
    h_xx.set_data([], [])
    h_xx_dot.set_data([], [])
    h_xy_dot.set_data([], [])
    h_xy.set_data([], [])
    return a2, orbit, orbit_dot, \
Esempio n. 11
0
class DelaunayTris:
    """Interactive 3D Voronoi viewer / STL exporter built on matplotlib widgets.

    Holds a set of seed points, computes their Voronoi diagram with
    ``scipy.spatial.Voronoi``, mirrors every bounded region into the
    project's ``halfedge`` cell/face/halfedge structures and draws both
    representations. Buttons save/load the points (pickle) and export the
    bounded Voronoi ridges as an ASCII STL file.

    Attributes:
        points: Seed points (sequence of 3-element coordinates).
        cells: ``halfedge.cell`` objects, one per bounded Voronoi region.
        target_point: Index of the point being dragged, or -1 for none.
        voronoi: Last ``scipy.spatial.Voronoi`` result (set by plotVoronoi).
        points_file: Path used by the Save/Open buttons.
    """

    def __init__(self, points=None):
        # ``None`` default so separate instances never share one list object.
        self.points = [] if points is None else points
        self.cells = []  # Faces for halfedge data structure

        self.target_point = -1

        self.fig = plt.figure()
        # ``Figure.gca(projection=...)`` was removed in matplotlib 3.6;
        # add_subplot is the supported way to create the 3D axes.
        self.ax = self.fig.add_subplot(111, projection="3d")
        plt.subplots_adjust(bottom=0.2)
        # Mouse-drag editing (on_press/on_release/on_motion) is currently
        # disabled; uncomment these to enable it.
        # self.cid_press = self.ax.figure.canvas.mpl_connect('button_press_event', self.on_press)
        # self.cid_release = self.ax.figure.canvas.mpl_connect('button_release_event', self.on_release)
        # self.cid_motion = self.ax.figure.canvas.mpl_connect('motion_notify_event', self.on_motion)

        axprepare = plt.axes([0.7, 0.05, 0.15, 0.075])
        axsave = plt.axes([0.5, 0.05, 0.15, 0.075])
        axopen = plt.axes([0.3, 0.05, 0.15, 0.075])
        axfile_name = plt.axes([0.05, 0.05, 0.15, 0.075])

        self.bprepare = Button(axprepare, 'Prepare stl')
        self.bprepare.on_clicked(self.prepare)

        self.bsave = Button(axsave, 'Save Points')
        self.bsave.on_clicked(self.save_points)

        self.bopen = Button(axopen, 'Open Points')
        self.bopen.on_clicked(self.open_points)

        self.points_file = "points.p"
        self.textbox_file_name = TextBox(axfile_name, "", initial="points.p")
        self.textbox_file_name.on_text_change(self.update_file_name)

        self.triangulate_vis()

    def load_points(self, points):
        """Function to load and display an array of points"""
        if len(points) == 0:
            print("No points provided!")
            return
        print("Loading " + str(len(points)) + " " + str(len(points[0])) + " dimensional points")
        self.points = points
        self.triangulate_vis()

    def open_points(self, event):
        """Load points from ``self.points_file`` and redraw.

        NOTE(security): ``pickle.load`` can execute arbitrary code; only open
        files you created yourself.
        """
        with open(self.points_file, "rb") as handle:
            file_points = pickle.load(handle)
        self.load_points(file_points)

    def save_points(self, event):
        """Pickle the current points to ``self.points_file`` (overwriting it)."""
        with open(self.points_file, "wb") as handle:
            pickle.dump(self.points, handle)

    def update_file_name(self, text):
        """Text-box callback: remember the file name typed by the user."""
        self.points_file = text

    def prepare(self, event, filePath="out.stl"):
        """Write the bounded Voronoi ridges of the current diagram as ASCII STL.

        Every ridge whose vertices are all finite is fanned into triangles
        around its centroid (see ``_fan_triangulate``). Each triangle is
        written as one STL facet carrying the ridge's unit normal.

        :param event: matplotlib button event (unused)
        :param filePath: destination STL file path
        :return: None
        """
        print("Preparing")
        # Ridges containing -1 reach infinity and cannot be printed.
        faces = [
            np.asarray([self.voronoi.vertices[index] for index in indexList])
            for indexList in self.voronoi.ridge_vertices
            if -1 not in indexList
        ]

        with open(filePath, "w") as output:
            output.write("solid Voronoi\n")
            for face in faces:
                normal = self._unit_normal(face)
                for triangle in self._fan_triangulate(face):
                    # One complete facet block per triangle.
                    output.write("facet normal {} {} {}\n".format(normal[0], normal[1], normal[2]))
                    output.write("outer loop\n")
                    for vertex in triangle:
                        output.write("vertex {} {} {}\n".format(vertex[0], vertex[1], vertex[2]))
                    output.write("endloop\nendfacet\n")
            output.write("endsolid Voronoi\n")

    @staticmethod
    def _unit_normal(face):
        """Unit normal of a planar polygon from its first three vertices.

        Returns the raw (zero) vector if those vertices are collinear.
        """
        normal = np.cross(face[1] - face[0], face[2] - face[1])
        length = np.linalg.norm(normal)
        return normal / length if length > 0 else normal

    def _fan_triangulate(self, points):
        """Split a convex planar polygon into triangles fanned around its centroid.

        Assumes ``points`` lists the polygon's vertices in cyclic order.
        TODO confirm that this holds for scipy's 3D ``ridge_vertices``.

        :param points: (n, 3) polygon vertices
        :return: (n, 3, 3) array; triangle i is (points[i], points[i+1], centroid)
        """
        points = np.asarray(points, dtype=float)
        if len(points) == 0:
            return np.empty((0, 3, 3))
        centroid = np.broadcast_to(points.mean(axis=0), points.shape)
        following = np.roll(points, -1, axis=0)  # wraps last vertex back to first
        return np.stack([points, following, centroid], axis=1)

    def triangulate(self, points):
        """Return a planar facet's vertices with their centroid appended.

        Despite its name this does not emit triangles (``_fan_triangulate``
        does); it is kept unchanged for backward compatibility.

        :param points: (n, 3) vertex coordinates of a planar facet
        :return: (n + 1, 3) array whose last row is the centroid
        """
        points = np.asarray(points)
        centroid = points.mean(axis=0, keepdims=True)
        return np.append(points, centroid, axis=0)

    def subdivideOnce(self, points):
        """Delaunay-triangulate 2D points and return the triangles' vertices.

        :param points: (n, 2) numpy array of points in the XY plane; not modified
        :return: list of vertex coordinates, three consecutive entries per triangle
        """
        from scipy.spatial import Delaunay

        triangulation = Delaunay(points)
        return [points[index] for simplex in triangulation.simplices for index in simplex]

    def chopOffThirdDimension(self, npArrayOf3DPoints):
        """Drop the z column of an (n, 3) array, returning (n, 2)."""
        return np.delete(npArrayOf3DPoints, 2, 1)

    def addEmptyThirdDimension(self, npArrayOf2DPoints):
        """Append a zero z column to an (n, 2) array, returning (n, 3)."""
        return np.insert(npArrayOf2DPoints, 2, values=0, axis=1)

    def rotateToPlane(self, points, normalVectorOriginal, normalVectorNew, isAtOrigin=True, offset=np.array([0, 0, 0])):
        """
        Rotates a shape defined by its vertices so that normalVectorOriginal maps onto normalVectorNew.
        Useful for putting a planar shape located in 3D into a coordinate plane or restoring it to its
        original location in 3D space.

        :param points:                  list of points to rotate about an axis
        :param normalVectorOriginal:    vector (as numpy array) which is normal to the original plane
        :param normalVectorNew:         vector (as numpy array) which is normal to the desired plane
        :param isAtOrigin:              True if the shape defined by the given points is located at the origin
        :param offset:                  vector (as numpy array) subtracted from the points before rotating when
                                            isAtOrigin is False, or added to the rotated points when it is True
        :return: list of rotated points (numpy arrays)
        """
        from math import sqrt

        if not isAtOrigin:
            # translate points by the offset, typically moving the shape to the origin
            points = points - offset
        M = np.asarray(normalVectorOriginal, dtype=float)
        N = np.asarray(normalVectorNew, dtype=float)
        norm_m = np.linalg.norm(M)
        # cos(theta) from the dot product; rounding can push it just outside
        # [-1, 1], which would make sqrt(1 - c*c) raise.
        c = float(np.dot(M, N) / (norm_m * np.linalg.norm(N)))
        c = min(1.0, max(-1.0, c))
        s = sqrt(1.0 - c * c)
        C = 1.0 - c

        mncross = np.cross(M, N)
        cross_len = np.linalg.norm(mncross)
        if cross_len <= 1e-12 * norm_m * np.linalg.norm(N):
            # Parallel or anti-parallel normals: the cross product gives no axis.
            # Any axis perpendicular to M works. With c == 1 the matrix below is
            # the identity; with c == -1 it is a half-turn about that axis.
            mncross = np.cross(M, [1.0, 0.0, 0.0])
            if np.linalg.norm(mncross) <= 1e-12 * norm_m:
                mncross = np.cross(M, [0.0, 1.0, 0.0])
            cross_len = np.linalg.norm(mncross)
        x, y, z = mncross / cross_len

        # rotation matrix via https://en.wikipedia.org/wiki/Rotation_matrix#Axis_and_angle
        rmat = np.array([[x * x * C + c, x * y * C - z * s, x * z * C + y * s],
                         [y * x * C + z * s, y * y * C + c, y * z * C - x * s],
                         [z * x * C - y * s, z * y * C + x * s, z * z * C + c]])

        if isAtOrigin:
            # rotate all of the points and then move the shape back to its original location
            return [np.dot(rmat, point) + offset for point in points]
        # rotate all of the points; will only work correctly if the shape is at the origin
        return [np.dot(rmat, point) for point in points]

    def triangulate_vis(self):
        """Recompute and redraw the Voronoi diagram of ``self.points``."""
        self.plotVoronoi(self.points)

    def subdivideFace(self, face_index):
        """Fan-triangulate one halfedge face around its centroid.

        Faces are numbered in the order they appear across ``self.cells``
        (cell by cell, then face by face).

        :param face_index: index into that flattened face list
        :return: list of ``[start, end, centroid]`` triangles, one per halfedge, where
                 start/end are the locations of the halfedge's vertex and of the next halfedge's vertex
        """
        all_faces = [face for cell in self.cells for face in cell.faces]
        face = all_faces[face_index]
        center = np.mean([edge.vertex.location for edge in face.halfedges], axis=0)
        return [
            [edge.vertex.location, edge.next.vertex.location, center]
            for edge in face.halfedges
        ]

    def plotVoronoi(self, points):
        """
        Compute the 3D Voronoi diagram of ``points``, rebuild ``self.cells`` and draw it.

        Bounded regions are converted to halfedge cells (one face per convex-hull
        triangle). Both the scipy ridges (blue) and the halfedge faces (green) are plotted.

        Adapted from: https://stackoverflow.com/a/24952758
        :param points: seed points of the diagram
        :return: unused
        """

        self.cells = []
        self.voronoi = Voronoi(points)

        vertices = [halfedge.vertex(location=[v[0], v[1], v[2]]) for v in self.voronoi.vertices]

        for region in self.voronoi.regions:
            # Skip unbounded regions (index -1) and any out-of-range index.
            if any(index < 0 or index >= len(vertices) for index in region):
                continue
            # A 3D convex hull needs at least 4 non-coplanar points.
            if len(region) < 4:
                continue

            region_vertices = [vertices[index] for index in region]
            hull = ConvexHull([v.location for v in region_vertices])

            cell = halfedge.cell()
            faces = []
            for simplex in hull.simplices:
                face = halfedge.face()

                edges = []
                for i, local_index in enumerate(simplex):
                    # hull.simplices index into this region's own point list,
                    # not into the global Voronoi vertex array.
                    edges.append(halfedge.halfedge(vertex=region_vertices[local_index], face=face))
                    if i > 0:
                        edges[i].previous = edges[i - 1]  # Previous edge is edge before this one in the list
                        edges[i - 1].next = edges[i]      # This edge is the next edge for the one before
                        edges[i - 1].vertex.halfedge = edges[i]  # This edge is the outgoing edge for the last vertex
                # Close the loop: the last edge links back to the first.
                edges[0].previous = edges[-1]
                edges[-1].next = edges[0]
                edges[-1].vertex.halfedge = edges[0]
                face.halfedges = edges
                faces.append(face)
            cell.faces = faces
            self.cells.append(cell)
        # TODO: halfedge ``opposite`` links are not populated yet.

        self.ax.clear()

        # PLOTTING FROM SCIPY VORONOI DATA STRUCTURE
        seeds = self.voronoi.points
        self.voronoi_points = self.ax.scatter(seeds[:, 0], seeds[:, 1], seeds[:, 2])
        corners = self.voronoi.vertices
        # Colour must be passed as ``c=``: the 4th positional argument of
        # Axes3D.scatter is ``zdir``, so a bare 'r' is not a colour there.
        self.ax.scatter(corners[:, 0], corners[:, 1], corners[:, 2], c='r')

        for ridge in self.voronoi.ridge_vertices:
            finite = [index for index in ridge if index >= 0]
            if len(finite) > 1:
                segment = corners[finite]
                self.ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], 'b')

        # PLOTTING FROM HALFEDGE DATA STRUCTURE
        for cell in self.cells:
            for face in cell.faces:
                if not face.halfedges:
                    continue
                start_edge = face.halfedges[-1]
                cur_edge = start_edge.next
                while True:
                    locations = np.array([cur_edge.previous.vertex.location, cur_edge.vertex.location])
                    self.ax.plot(locations[:, 0], locations[:, 1], locations[:, 2], 'go-')

                    if cur_edge == start_edge:
                        break
                    cur_edge = cur_edge.next

        plt.show()

    def on_press(self, event):
        """Mouse-press handler: grab the point nearest the click (within 0.1 in x/y)."""
        if event.inaxes != self.voronoi_points.axes:
            print("Not in axis!")
            return
        min_dist_squared = 1e10
        close_point = -1
        for i, point in enumerate(self.points):
            dx = event.xdata - point[0]
            dy = event.ydata - point[1]
            dist = dx * dx + dy * dy
            if dist < min_dist_squared:
                min_dist_squared = dist
                close_point = i
        # Compare squared distances so no module-level ``sqrt`` import is needed.
        self.target_point = close_point if min_dist_squared < 0.1 ** 2 else -1

    def on_release(self, event):
        """Mouse-release handler: drop the grabbed point, or add a point on a plain click."""
        if event.inaxes != self.voronoi_points.axes:
            print("Not in axis!")
            return
        if self.target_point == -1:
            self.add_point(event)
        else:
            self.points[self.target_point] = [event.xdata, event.ydata, 0]
            self.target_point = -1

            self.triangulate_vis()

    def on_motion(self, event):
        """Mouse-motion handler: drag the grabbed point (z is fixed at 0) and redraw."""
        if event.inaxes != self.voronoi_points.axes:
            return
        if self.target_point >= 0:
            self.points[self.target_point] = [event.xdata, event.ydata, 0]
            self.triangulate_vis()

    def add_point(self, event):
        """
        Add a point (z = 0) at the clicked x/y location and redraw.

        :param event: matplotlib mouse event providing ``xdata``/``ydata``
        :return: None
        """
        # reshape keeps an empty point list appendable as a (0, 3) array.
        existing = np.asarray(self.points, dtype=float).reshape(-1, 3)
        self.points = np.append(existing, [[event.xdata, event.ydata, 0]], axis=0)

        self.triangulate_vis()
Esempio n. 12
0
    plt.subplots_adjust(bottom=0.2)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    
    axStep = plt.axes([0.15, 0.025, 0.5, 0.05])
    sStep = Slider(axStep,"intervals",1,999,valinit=2,valfmt="%.3d")
    sStep.on_changed(update)
    WIDTH = 0.2
    axboxN = plt.axes([0.15, 0.075, WIDTH, 0.035])
    text_boxN = TextBox(axboxN, 'Point #')
    axboxX = plt.axes([0.15+0.1+WIDTH, 0.075, WIDTH, 0.035])
    text_boxX = TextBox(axboxX, 'Point X')
    axboxY = plt.axes([0.15+(0.1+WIDTH)*2, 0.075, WIDTH, 0.035])
    text_boxY = TextBox(axboxY, 'Point Y')

    text_boxN.on_text_change(textNVerify)
    text_boxX.on_text_change(textXVerify)
    text_boxY.on_text_change(textYVerify)


    addax = plt.axes([0.85, 0.025, 0.1, 0.04])
    button = Button(addax, 'Add')
    button.on_clicked(newPoint)
    moveax = plt.axes([0.75, 0.025, 0.1, 0.04])
    buttonMove = Button(moveax, 'Set')
    buttonMove.on_clicked(setPoint)

    patches = []
    selectedControl = None
    controlPoints = [(1,1),(2,3),(4,3),(3,1)]
    minX,minY = maxX,maxY = controlPoints[0]