Esempio n. 1
0
	def plotPattern(self, type, stats):
		"""Plot the access pattern of self.fileName for reads or writes.

		Each record in self.data becomes one line segment from
		(off, start) to (off + size, start + dur): file offset on the x axis,
		time on the y axis.

		:param type: either 'read' or 'write'; passed to self.filenames().
		:param stats: bool, whether to also print statistics.
		"""
		names = self.filenames(type)  # unique names of used files
		if self.fileName not in names:
			# Single pre-formatted argument: prints the same text on Python 2 and 3.
			print("%s is not in our data set" % self.fileName)
			return

		# Identity test on purpose: self.data is indexed by field name below,
		# and for array-like data `== None` compares elementwise, which makes
		# this `if` raise instead of testing "not loaded yet".
		if self.data is None:
			self.__prepareData()

		self.axes.clear()

		data = self.data
		graphdata = np.column_stack((
			data['off'],
			data['start'],
			data['off'] + data['size'],
			data['start'] + data['dur'],
		))

		# Each row holds two (x, y) points, i.e. exactly one segment.
		lineSegments = LineCollection(graphdata.reshape(-1, 2, 2), linewidths=4)
		lineSegments.set_picker(True)
		self.lineCol = self.axes.add_collection(lineSegments)

		maxEnd = graphdata[:, 2].max()
		maxTime = graphdata[:, 3].max()

		if stats:
			self.__printStats()

		self.axes.xaxis.set_major_formatter(FuncFormatter(self.__xFormater))
		self.axes.grid(color='grey', linewidth=0.5)
		self.axes.set_xlabel("file offset (kiB)", fontsize=16)
		self.axes.set_ylabel("time (ms)", fontsize=16)
		self.axes.set_xlim(0, maxEnd)
		self.axes.set_ylim(self.startTime, maxTime)

		self.fig.suptitle('%s' % self.__elideText(self.fileName), fontsize=9)
		self.fig.autofmt_xdate()
Esempio n. 2
0
def _add_plot(fig, ax, plot_data, color_data, pkeys, lw=1, cmap='_auto',
              alpha=1.0):
  """Draw plot_data on ax and show the picked element's key as the title.

  If plot_data.scatter is true, the single array in plot_data.trajs is drawn
  with ax.scatter; otherwise each trajectory becomes one line of a
  LineCollection. Picking an element sets the axes title to the matching
  entry of pkeys.

  fig, ax: matplotlib figure and axes to draw into.
  plot_data: object with .scatter (bool) and .trajs.
  color_data: object with .color and .is_categorical.
  pkeys: labels looked up by the position of the picked element
    (presumably one per scatter point / trajectory -- confirm with callers).
  lw: line width; scatter marker size is lw*20.
  cmap: colormap; '_auto' means no colormap, with categorical colors mapped
    through COLOR_CYCLE.
  alpha: artist transparency.

  Returns the created artist.
  """
  colors = color_data.color
  if cmap == '_auto':
    cmap = None
    if color_data.is_categorical:
      colors = np.take(COLOR_CYCLE, colors, mode='wrap')

  # Original pkeys position of each drawn element. None means the drawn
  # elements map one-to-one onto pkeys.
  pick_index = None

  if plot_data.scatter:
    data, = plot_data.trajs
    artist = ax.scatter(*data.T, marker='o', c=colors, edgecolor='none',
                        s=lw*20, cmap=cmap, alpha=alpha, picker=5)
  else:
    # trajectory plot
    if hasattr(colors, 'dtype') and np.issubdtype(colors.dtype, np.number):
      # delete lines with NaN colors
      mask = np.isfinite(colors)
      if mask.all():
        trajs = plot_data.trajs
      else:
        trajs = [t for i, t in enumerate(plot_data.trajs) if mask[i]]
        colors = colors[mask]
        # The collection now holds only the kept lines, so a picked index
        # must be translated back before it can be used on pkeys.
        pick_index = np.flatnonzero(mask)
      artist = LineCollection(trajs, linewidths=lw, cmap=cmap)
      artist.set_array(colors)
    else:
      artist = LineCollection(plot_data.trajs, linewidths=lw, cmap=cmap)
      artist.set_color(colors)
    artist.set_alpha(alpha)
    artist.set_picker(True)
    artist.set_pickradius(5)
    ax.add_collection(artist, autolim=True)
    ax.autoscale_view()
    # Force ymin -> 0
    ax.set_ylim((0, ax.get_ylim()[1]))

  def on_pick(event):
    if event.artist is not artist:
      return
    idx = event.ind[0]
    if pick_index is not None:
      idx = pick_index[idx]
    ax.set_title(pkeys[idx])
    fig.canvas.draw_idle()

  # XXX: hack, make this more official. Only one pick handler per figure:
  # drop the previous one before connecting the new one.
  if hasattr(fig, '_superman_cb'):
    fig.canvas.mpl_disconnect(fig._superman_cb[0])
  cb_id = fig.canvas.mpl_connect('pick_event', on_pick)
  fig._superman_cb = (cb_id, on_pick)
  return artist
Esempio n. 3
0
class SkeletonBuilder:
    """Interactive editor for the project skeleton (pairs of bodyparts).

    Displays one labeled frame with its keypoints. Bones are added by
    lasso-selecting keypoints and removed with a right click. "Clear" removes
    all bones, and "Export" writes them to the config file as cfg["skeleton"].
    """

    def __init__(self, config_path):
        self.config_path = config_path
        self.cfg = read_config(config_path)
        # Find uncropped labeled data; stop at the first frame where the
        # chosen animal has every bodypart labeled.
        self.df = None
        found = False
        root = os.path.join(self.cfg["project_path"], "labeled-data")
        for dir_ in os.listdir(root):
            folder = os.path.join(root, dir_)
            if os.path.isdir(folder) and not folder.endswith(
                    ("cropped", "labeled")):
                self.df = pd.read_hdf(
                    os.path.join(folder,
                                 f'CollectedData_{self.cfg["scorer"]}.h5'))
                row, col = self.pick_labeled_frame()
                if "individuals" in self.df.columns.names:
                    self.df = self.df.xs(col, axis=1, level="individuals")
                # One (x, y) pair per bodypart.
                self.xy = self.df.loc[row].values.reshape((-1, 2))
                missing = np.flatnonzero(np.isnan(self.xy).all(axis=1))
                if not missing.size:
                    found = True
                    break
        if self.df is None:
            raise IOError("No labeled data were found.")

        self.bpts = self.df.columns.get_level_values("bodyparts").unique()
        if not found:
            warnings.warn(
                f"A fully labeled animal could not be found. "
                f"{', '.join(self.bpts[missing])} will need to be manually connected in the config.yaml."
            )
        self.tree = KDTree(self.xy)
        # Handle image previously annotated on a different platform
        sep = "/" if "/" in row else "\\"
        if sep != os.path.sep:
            row = row.replace(sep, os.path.sep)
        self.image = io.imread(os.path.join(self.cfg["project_path"], row))
        # inds: sorted (i, j) bodypart index pairs; segs: the matching
        # ((x_i, y_i), (x_j, y_j)) segments shown in self.lines.
        self.inds = set()
        self.segs = set()
        # Draw the skeleton if already existent
        if self.cfg["skeleton"]:
            for bone in self.cfg["skeleton"]:
                pair = np.flatnonzero(self.bpts.isin(bone))
                if len(pair) != 2:
                    continue
                pair_sorted = tuple(sorted(pair))
                self.inds.add(pair_sorted)
                self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :])))
        self.lines = LineCollection(self.segs,
                                    colors=mcolors.to_rgba(
                                        self.cfg["skeleton_color"]))
        self.lines.set_picker(True)
        self.show()

    def pick_labeled_frame(self):
        """Return (row, individual) of a randomly chosen, most-labeled animal.

        Falls back to counting per row (column key 0) when the data has no
        "individuals" column level.
        """
        # Find the most 'complete' animal
        try:
            count = self.df.groupby(level="individuals", axis=1).count()
            if "single" in count:
                count.drop("single", axis=1, inplace=True)
        except KeyError:
            count = self.df.count(axis=1).to_frame()
        mask = count.where(count == count.values.max())
        kept = mask.stack().index.to_list()
        np.random.shuffle(kept)
        row, col = kept.pop()
        return row, col

    def show(self):
        """Build the figure (image, keypoints, bones, buttons) and show it."""
        self.fig = plt.figure()
        ax = self.fig.add_subplot(111)
        ax.axis("off")
        # Zoom onto the keypoints with a 30% margin.
        lo = np.nanmin(self.xy, axis=0)
        hi = np.nanmax(self.xy, axis=0)
        center = (hi + lo) / 2
        w, h = hi - lo
        ampl = 1.3
        w *= ampl
        h *= ampl
        ax.set_xlim(center[0] - w / 2, center[0] + w / 2)
        ax.set_ylim(center[1] - h / 2, center[1] + h / 2)
        ax.imshow(self.image)
        ax.scatter(*self.xy.T, s=self.cfg["dotsize"]**2)
        ax.add_collection(self.lines)
        ax.invert_yaxis()

        self.lasso = LassoSelector(ax, onselect=self.on_select)
        ax_clear = self.fig.add_axes([0.85, 0.55, 0.1, 0.1])
        ax_export = self.fig.add_axes([0.85, 0.45, 0.1, 0.1])
        self.clear_button = Button(ax_clear, "Clear")
        self.clear_button.on_clicked(self.clear)
        self.export_button = Button(ax_export, "Export")
        self.export_button.on_clicked(self.export)
        self.fig.canvas.mpl_connect("pick_event", self.on_pick)
        plt.show()

    def clear(self, *args):
        """Remove every bone and refresh the display."""
        self.inds.clear()
        self.segs.clear()
        self.lines.set_segments(self.segs)
        # Button callbacks do not trigger a redraw on their own.
        self.fig.canvas.draw_idle()

    def export(self, *args):
        """Write the current bones to the config as cfg["skeleton"].

        Warns about bodyparts that are not part of any bone.
        """
        inds_flat = {ind for pair in self.inds for ind in pair}
        unconnected = [i for i in range(len(self.xy)) if i not in inds_flat]
        if unconnected:
            warnings.warn(
                f'Unconnected {", ".join(self.bpts[unconnected])}. '
                f"It is desirable that all bodyparts be connected for multi-animal projects."
            )
        self.cfg["skeleton"] = [
            tuple(self.bpts[list(pair)]) for pair in self.inds
        ]
        write_config(self.config_path, self.cfg)

    def on_pick(self, event):
        """Remove the picked bone on a right click (mouse button 3)."""
        if event.artist is not self.lines or event.mouseevent.button != 3:
            return
        # get_segments() returns a fresh list, so it is only used to identify
        # the picked segment; the artist itself is updated further below.
        removed = event.artist.get_segments()[event.ind[0]]
        self.segs.remove(tuple(map(tuple, removed)))
        # Stored pairs are sorted; sort the KD-tree hits to match.
        self.inds.remove(tuple(sorted(self.tree.query(removed)[1])))
        # Push the new segment set to the collection and redraw; without this
        # the removed bone stayed on screen.
        self.lines.set_segments(self.segs)
        self.fig.canvas.draw_idle()

    def on_select(self, verts):
        """Connect, in lasso order, the keypoints the lasso passed near.

        Keypoints within 5 data units of a lasso vertex are hit. Each
        consecutive pair of distinct hits becomes a bone.
        """
        self.path = Path(verts)
        self.verts = verts
        inds = self.tree.query_ball_point(verts, 5)
        inds_unique = []
        for lst in inds:
            if len(lst) and lst[0] not in inds_unique:
                inds_unique.append(lst[0])
        for pair in zip(inds_unique, inds_unique[1:]):
            pair_sorted = tuple(sorted(pair))
            self.inds.add(pair_sorted)
            self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :])))
        self.lines.set_segments(self.segs)
        self.fig.canvas.draw_idle()
Esempio n. 4
0
class PlotCtrl(PlotView):
    """Plot controller that draws geometries / time series on self.axes.

    Also tracks which vertices, polygons and lines the user highlighted by
    picking. self.axes, self.canvas, self.toolbar, redraw() and set_title()
    are not defined here and are presumably provided by PlotView.
    """

    def __init__(self, panel):
        PlotView.__init__(self, panel)

        # stores the plot objects
        self.plots = []
        self.__plot_count = 0
        self.highlighted_vertices = []  # indices of highlighted scatter vertices (kept sorted)
        self.marker = None  # matplotlib Line2D that draws all highlighted vertices
        self.x_scatter_data, self.y_scatter_data = None, None  # scatter data for highlighting
        self.poly_list = None  # data of the plotted polygons
        self.highlight_color = "y"  # Yellow is used when highlighting
        self.color = "#0DACFF"  # The standard color for objects that are not highlighted
        self._color_converter = ColorConverter()
        self.line_collection = None  # the plotted LineCollection, if any
        self.highlighted_lines = []  # indices of highlighted lines (not segments)
        self.selected_lines = []  # per-segment 0/1 flags indexing the line color table
        # line index -> list of segment indices, and segment index -> line
        # index; filled by plot_linestring(). Initialised here so that lookups
        # before the first linestring plot don't raise AttributeError.
        self.line_segments = {}
        self.segment_line = {}

        # stores the axis objects
        self.__axis = []

        # matplotlib color cycle used to ensure primary and secondary axis are not displayed with the same color
        self.__color_cycle = color_cycle()

        # Used to be able to deactivate the canvas events later
        self._cid_press = None
        self._cid_release = None
        self._cid_scroll = None

    def activate_panning(self):
        self._cid_press = self.canvas.mpl_connect('button_press_event', self.on_canvas_clicked)
        self._cid_release = self.canvas.mpl_connect('button_release_event', self.on_canvas_released)

    def activate_zooming(self):
        self._cid_scroll = self.canvas.mpl_connect('scroll_event', self.on_canvas_mouse_scroll)

    def deactivate_panning(self):
        self.canvas.mpl_disconnect(self._cid_press)
        self.canvas.mpl_disconnect(self._cid_release)

    def deactivate_zooming(self):
        self.canvas.mpl_disconnect(self._cid_scroll)

    def clear_plot(self):
        """Clear all axes, forget the plot objects and redraw."""
        # clear axis
        self.axes.clear()
        for ax in self.__axis:
            ax.cla()

        # reset the axis container
        self.__axis = []

        self.axes.grid()
        self.axes.margins(0)

        # clear the plot objects
        self.__plot_count = 0
        self.plots = []

        self.redraw()

    def getNextColor(self):
        return next(self.__color_cycle)

    def get_highlighted_vertices(self):
        """
        Returns the index and the coordinates of the highlighted vertices in a dictionary.
        Key is index, value are coordinates
        :return: type(dict)
        """
        return {vertex: [self.x_scatter_data[vertex], self.y_scatter_data[vertex]]
                for vertex in self.highlighted_vertices}

    def get_highlighted_polygons(self):
        """
        Returns all the highlighted polygons.
        Key is index, value are object
        :return: type(dict)
        """
        if not self.poly_list:  # Check if polygon has been plotted
            return {}  # No polygons have been plotted.

        # NOTE(review): `.all()` collapses each color array to one bool, so
        # this does not compare the colors themselves. Kept identical to the
        # test in highlight_polygon(); confirm the intended check.
        data = {}
        for i, collection in enumerate(self.axes.collections):
            if collection.get_facecolor().all() == collection.get_edgecolor().all():
                data[i] = collection
        return data

    def get_highlighted_lines(self):
        """
        Returns the highlighted lines index
        Key is index, value are indexes of segments
        :return: type(dict)
        """
        if self.line_collection is None:
            return {}  # No lines have been plotted

        # Each highlighted line maps to its own segments (previously every
        # line was mapped to the segments of line 0).
        return {line: self.line_segments[line] for line in self.highlighted_lines}

    def highlight_line(self, event):
        """
        Toggle the highlight of the line owning the picked segment.
        Highlighted segments have value 1 in self.selected_lines; the
        collection is recolored accordingly.
        :param event: matplotlib pick event on the line collection
        :return:
        """
        ind = event.ind[0]
        line_idx = self.segment_line[ind]

        if line_idx in self.highlighted_lines:
            self.highlighted_lines.remove(line_idx)
        else:
            self.highlighted_lines.append(line_idx)

        # Flip the 0/1 flag of every segment that belongs to the picked line.
        highlight_idx = self.line_segments[line_idx]
        self.selected_lines[highlight_idx] = 1 - self.selected_lines[highlight_idx]
        # Recolor the picked collection directly; unpacking axes.collections
        # failed as soon as the axes held more than one collection.
        event.artist.set_color(self.__line_colors[self.selected_lines])
        event.canvas.draw_idle()

    def highlight_polygon(self, pick_event):
        """Toggle the highlight of the picked polygon and redraw."""
        artist = pick_event.artist
        # NOTE(review): `.all()` collapses each RGBA row to one bool, so this
        # does not really compare face and edge colors; kept unchanged —
        # consider comparing the RGBA values with np.array_equal().
        if artist.get_facecolor()[0].all() == artist.get_edgecolor()[0].all():
            artist.set_facecolor(self.color)
            artist.set_edgecolor(None)
        else:
            artist.set_color(self.highlight_color)
        artist.axes.figure.canvas.draw()

    def highlight_vertex(self, pick_event):
        """
        Toggle the highlight of the picked scatter vertex.
        Only one marker (self.marker) is used to draw all highlighted vertices.
        :param pick_event: matplotlib mouse pick event
        :return:
        """
        if self.marker is None:
            self.marker, = pick_event.artist.axes.plot([], [], "o")  # Create a plot
            self.marker.set_color(self.highlight_color)  # Set color to yellow

        vertex = pick_event.ind[0]
        if vertex in self.highlighted_vertices:
            self.highlighted_vertices.remove(vertex)  # Remove highlight
        else:
            self.highlighted_vertices.append(vertex)  # Add highlight
        self.highlighted_vertices.sort()

        x = self._get_vertices_data_points(self.x_scatter_data, vertex)
        y = self._get_vertices_data_points(self.y_scatter_data, vertex)

        # Highlight only those in self.highlighted_vertices
        self.marker.set_data(x, y)
        pick_event.artist.axes.figure.canvas.draw()

    def _get_vertices_data_points(self, data, index):
        """
        :param data: x_data or y_data, type(tuple or list)
        :param index: index of the selected vertex, type(int) (unused; kept for callers)
        :return: type(list) the values of data at every highlighted vertex
        """
        return [data[i] for i in self.highlighted_vertices]

    def plot_dates(self, data, name, noDataValue, ylabel=""):
        """
        :param data: type([datetime, floats])
        :param name: label of the plotted series
        :param noDataValue: value marking missing samples
        :param ylabel:
        :return:
        """
        if len(data) == 0:
            return

        dates, values = zip(*data)
        # Builtin float instead of numpy.float, which NumPy 1.24 removed.
        nvals = numpy.array(values, dtype=float)
        # Replace the no-data marker with NaN.
        nvals[nvals == noDataValue] = numpy.nan

        p = self.axes.plot_date(dates, nvals, label=name, linestyle="None", marker=".")
        self.axes.set_ylabel(ylabel)

        # save each of the plots
        self.plots.extend(p)

        self.redraw()

    def plot_polygon(self, data, color):
        """Plot each polygon as its own pickable PolyCollection."""
        poly_list = []
        for item in data:
            points = numpy.array(item.GetGeometryRef(0).GetPoints())
            poly_list.append(tuple(map(tuple, points[:, 0:2])))  # x, y only

        self.poly_list = poly_list

        # Plot multiple polygons and add them to collection as individual polygons
        for poly in self.poly_list:
            p_coll = PolyCollection([poly], closed=True, facecolor=color, alpha=0.5, edgecolor=None, linewidths=(2,))
            p_coll.set_picker(True)  # Enable pick event
            self.axes.add_collection(p_coll, autolim=True)

    def plot_point(self, data, color):  # Rename to plot scatter
        """Scatter-plot the points and remember their coordinates for highlighting."""
        x, y = zip(*[(g.GetX(), g.GetY()) for g in data])
        self.x_scatter_data, self.y_scatter_data = x, y
        collection = self.axes.scatter(x, y, marker="o", color=color, picker=True)
        return collection

    def plot_linestring(self, data, color):
        """
        Plot every linestring as consecutive segments (point i -> point i + 1)
        in one pickable LineCollection.
        :param data: geometry objects
        :param color:  # Hexadecimal
        :return:
        """
        segments = []
        self.line_segments = {}
        self.segment_line = {}
        last_segment = 0  # global index of the first segment of the current line
        for index, geo_object in enumerate(data):
            # Points are collected per line. Accumulating them across
            # geometries duplicated earlier segments and joined separate lines.
            # [:2] keeps x, y for both 2-D and 3-D points.
            points = [point[:2] for point in geo_object.GetPoints()]
            n_segments = max(len(points) - 1, 0)
            self.line_segments[index] = list(range(last_segment, last_segment + n_segments))
            for i in range(n_segments):
                segments.append([points[i], points[i + 1]])
                self.segment_line[last_segment + i] = index
            last_segment += n_segments

        # Row 0: normal color, row 1: highlight color.
        self.__line_colors = np.array(
            [self._color_converter.to_rgba(color), self._color_converter.to_rgba(self.highlight_color)])
        # One 0/1 flag per segment; indexing the color table with it gives per-segment colors.
        self.selected_lines = np.zeros(len(segments), dtype=int)
        # New segment numbering invalidates previously highlighted lines.
        self.highlighted_lines = []
        colors = self.__line_colors[self.selected_lines]
        self.line_collection = LineCollection(segments, pickradius=10, linewidths=2, colors=colors)
        self.line_collection.set_picker(True)
        self.axes.add_collection(self.line_collection)

    def plot_geometry(self, geometry_object, title, color=None):
        """
        A general plot method that will plot the respective type
        Must call redraw afterwards to have an effect
        :param geometry_object: sequence of geometries, dispatched on the type of the first one
        :param title: title for the plot
        :param color: # Hexadecimal
        :return:
        """
        if not color:
            color = self.color

        geometry_name = geometry_object[0].GetGeometryName().upper()
        if geometry_name == "POLYGON":
            self.plot_polygon(geometry_object, color)
        elif geometry_name == "POINT":
            self.plot_point(geometry_object, color)
        elif geometry_name == "LINESTRING":
            self.plot_linestring(geometry_object, color)
        else:
            raise Exception("plot_geometry() failed. Geometries must be POLYGON, POINT OR LINESTRING")

        self.set_title(title)
        self.axes.grid(True)

        # If margin is 0 the graph will fill the plot
        self.axes.margins(0.1)

    def reset_highlighter(self):
        """
        Resets the variables needed to highlight
        :return:
        """
        self.marker = None
        self.highlighted_vertices = []
        self.x_scatter_data, self.y_scatter_data = None, None
        self.poly_list = None
        self.line_collection = None
        self.highlighted_lines = []
        self.selected_lines = []

    def set_line_width(self, width):
        """
        Sets the width of the lines plotted
        Does not work with scatter plot
        :param width: real number
        :return:
        """
        if not self.plots:
            return  # Nothing has been plotted

        for line in self.plots:
            line.set_linewidth(width)
        self.redraw()

    def on_canvas_clicked(self, event):
        self.toolbar.press_pan(event)

    def on_canvas_mouse_scroll(self, event):
        """Zoom in (scroll up) or out (scroll down) around the cursor by a factor of 2."""
        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location
        if xdata is None or ydata is None:
            return  # cursor is outside the axes; nothing to zoom around

        base_scale = 2.0
        cur_xlim = self.axes.get_xlim()
        cur_ylim = self.axes.get_ylim()
        cur_xrange = (cur_xlim[1] - cur_xlim[0]) * .5
        cur_yrange = (cur_ylim[1] - cur_ylim[0]) * .5
        if event.button == 'up':
            # deal with zoom in
            scale_factor = 1 / base_scale
        elif event.button == 'down':
            # deal with zoom out
            scale_factor = base_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        # set new limits
        self.axes.set_xlim([xdata - cur_xrange * scale_factor,
                            xdata + cur_xrange * scale_factor])
        self.axes.set_ylim([ydata - cur_yrange * scale_factor,
                            ydata + cur_yrange * scale_factor])
        self.redraw()

    def on_canvas_released(self, event):
        self.toolbar.release_pan(event)
Esempio n. 5
0
class PlotCtrl(PlotView):
    """Plot controller that draws geometries / time series on self.axes.

    Also tracks which vertices, polygons and lines the user highlighted by
    picking. self.axes, self.canvas, self.toolbar, redraw() and set_title()
    are not defined here and are presumably provided by PlotView.
    """

    def __init__(self, panel):
        PlotView.__init__(self, panel)

        # stores the plot objects
        self.plots = []
        self.__plot_count = 0
        self.highlighted_vertices = []  # indices of highlighted scatter vertices (kept sorted)
        self.marker = None  # matplotlib Line2D that draws all highlighted vertices
        self.x_scatter_data, self.y_scatter_data = None, None  # scatter data for highlighting
        self.poly_list = None  # data of the plotted polygons
        self.highlight_color = "y"  # Yellow is used when highlighting
        self.color = "#0DACFF"  # The standard color for objects that are not highlighted
        self._color_converter = ColorConverter()
        self.line_collection = None  # the plotted LineCollection, if any
        self.highlighted_lines = []  # indices of highlighted lines (not segments)
        self.selected_lines = []  # per-segment 0/1 flags indexing the line color table
        # line index -> list of segment indices, and segment index -> line
        # index; filled by plot_linestring(). Initialised here so that lookups
        # before the first linestring plot don't raise AttributeError.
        self.line_segments = {}
        self.segment_line = {}

        # stores the axis objects
        self.__axis = []

        # matplotlib color cycle used to ensure primary and secondary axis are not displayed with the same color
        self.__color_cycle = color_cycle()

        # Used to be able to deactivate the canvas events later
        self._cid_press = None
        self._cid_release = None
        self._cid_scroll = None

    def activate_panning(self):
        self._cid_press = self.canvas.mpl_connect('button_press_event',
                                                  self.on_canvas_clicked)
        self._cid_release = self.canvas.mpl_connect('button_release_event',
                                                    self.on_canvas_released)

    def activate_zooming(self):
        self._cid_scroll = self.canvas.mpl_connect('scroll_event',
                                                   self.on_canvas_mouse_scroll)

    def deactivate_panning(self):
        self.canvas.mpl_disconnect(self._cid_press)
        self.canvas.mpl_disconnect(self._cid_release)

    def deactivate_zooming(self):
        self.canvas.mpl_disconnect(self._cid_scroll)

    def clear_plot(self):
        """Clear all axes, forget the plot objects and redraw."""
        # clear axis
        self.axes.clear()
        for ax in self.__axis:
            ax.cla()

        # reset the axis container
        self.__axis = []

        self.axes.grid()
        self.axes.margins(0)

        # clear the plot objects
        self.__plot_count = 0
        self.plots = []

        self.redraw()

    def getNextColor(self):
        return next(self.__color_cycle)

    def get_highlighted_vertices(self):
        """
        Returns the index and the coordinates of the highlighted vertices in a dictionary.
        Key is index, value are coordinates
        :return: type(dict)
        """
        return {
            vertex: [self.x_scatter_data[vertex], self.y_scatter_data[vertex]]
            for vertex in self.highlighted_vertices
        }

    def get_highlighted_polygons(self):
        """
        Returns all the highlighted polygons.
        Key is index, value are object
        :return: type(dict)
        """
        if not self.poly_list:  # Check if polygon has been plotted
            return {}  # No polygons have been plotted.

        # NOTE(review): `.all()` collapses each color array to one bool, so
        # this does not compare the colors themselves. Kept identical to the
        # test in highlight_polygon(); confirm the intended check.
        data = {}
        for i, collection in enumerate(self.axes.collections):
            if collection.get_facecolor().all() == collection.get_edgecolor(
            ).all():
                data[i] = collection
        return data

    def get_highlighted_lines(self):
        """
        Returns the highlighted lines index
        Key is index, value are indexes of segments
        :return: type(dict)
        """
        if self.line_collection is None:
            return {}  # No lines have been plotted

        # Each highlighted line maps to its own segments (previously every
        # line was mapped to the segments of line 0).
        return {
            line: self.line_segments[line]
            for line in self.highlighted_lines
        }

    def highlight_line(self, event):
        """
        Toggle the highlight of the line owning the picked segment.
        Highlighted segments have value 1 in self.selected_lines; the
        collection is recolored accordingly.
        :param event: matplotlib pick event on the line collection
        :return:
        """
        ind = event.ind[0]
        line_idx = self.segment_line[ind]

        if line_idx in self.highlighted_lines:
            self.highlighted_lines.remove(line_idx)
        else:
            self.highlighted_lines.append(line_idx)

        # Flip the 0/1 flag of every segment that belongs to the picked line.
        highlight_idx = self.line_segments[line_idx]
        self.selected_lines[
            highlight_idx] = 1 - self.selected_lines[highlight_idx]
        # Recolor the picked collection directly; unpacking axes.collections
        # failed as soon as the axes held more than one collection.
        event.artist.set_color(self.__line_colors[self.selected_lines])
        event.canvas.draw_idle()

    def highlight_polygon(self, pick_event):
        """Toggle the highlight of the picked polygon and redraw."""
        artist = pick_event.artist
        # NOTE(review): `.all()` collapses each RGBA row to one bool, so this
        # does not really compare face and edge colors; kept unchanged —
        # consider comparing the RGBA values with np.array_equal().
        if artist.get_facecolor()[0].all() == artist.get_edgecolor()[0].all():
            artist.set_facecolor(self.color)
            artist.set_edgecolor(None)
        else:
            artist.set_color(self.highlight_color)
        artist.axes.figure.canvas.draw()

    def highlight_vertex(self, pick_event):
        """
        Toggle the highlight of the picked scatter vertex.
        Only one marker (self.marker) is used to draw all highlighted vertices.
        :param pick_event: matplotlib mouse pick event
        :return:
        """
        if self.marker is None:
            self.marker, = pick_event.artist.axes.plot([], [],
                                                       "o")  # Create a plot
            self.marker.set_color(self.highlight_color)  # Set color to yellow

        vertex = pick_event.ind[0]
        if vertex in self.highlighted_vertices:
            self.highlighted_vertices.remove(vertex)  # Remove highlight
        else:
            self.highlighted_vertices.append(vertex)  # Add highlight
        self.highlighted_vertices.sort()

        x = self._get_vertices_data_points(self.x_scatter_data, vertex)
        y = self._get_vertices_data_points(self.y_scatter_data, vertex)

        # Highlight only those in self.highlighted_vertices
        self.marker.set_data(x, y)
        pick_event.artist.axes.figure.canvas.draw()

    def _get_vertices_data_points(self, data, index):
        """
        :param data: x_data or y_data, type(tuple or list)
        :param index: index of the selected vertex, type(int) (unused; kept for callers)
        :return: type(list) the values of data at every highlighted vertex
        """
        return [data[i] for i in self.highlighted_vertices]

    def plot_dates(self, data, name, noDataValue, ylabel=""):
        """
        :param data: type([datetime, floats])
        :param name: label of the plotted series
        :param noDataValue: value marking missing samples
        :param ylabel:
        :return:
        """
        if len(data) == 0:
            return

        dates, values = zip(*data)
        # Builtin float instead of numpy.float, which NumPy 1.24 removed.
        nvals = numpy.array(values, dtype=float)
        # Replace the no-data marker with NaN.
        nvals[nvals == noDataValue] = numpy.nan

        p = self.axes.plot_date(dates,
                                nvals,
                                label=name,
                                linestyle="None",
                                marker=".")
        self.axes.set_ylabel(ylabel)

        # save each of the plots
        self.plots.extend(p)

        self.redraw()

    def plot_polygon(self, data, color):
        """Plot each polygon as its own pickable PolyCollection."""
        poly_list = []
        for item in data:
            points = numpy.array(item.GetGeometryRef(0).GetPoints())
            poly_list.append(tuple(map(tuple, points[:, 0:2])))  # x, y only

        self.poly_list = poly_list

        # Plot multiple polygons and add them to collection as individual polygons
        for poly in self.poly_list:
            p_coll = PolyCollection([poly],
                                    closed=True,
                                    facecolor=color,
                                    alpha=0.5,
                                    edgecolor=None,
                                    linewidths=(2, ))
            p_coll.set_picker(True)  # Enable pick event
            self.axes.add_collection(p_coll, autolim=True)

    def plot_point(self, data, color):  # Rename to plot scatter
        """Scatter-plot the points and remember their coordinates for highlighting."""
        x, y = zip(*[(g.GetX(), g.GetY()) for g in data])
        self.x_scatter_data, self.y_scatter_data = x, y
        collection = self.axes.scatter(x,
                                       y,
                                       marker="o",
                                       color=color,
                                       picker=True)
        return collection

    def plot_linestring(self, data, color):
        """
        Plot every linestring as consecutive segments (point i -> point i + 1)
        in one pickable LineCollection.
        :param data: geometry objects
        :param color:  # Hexadecimal
        :return:
        """
        segments = []
        self.line_segments = {}
        self.segment_line = {}
        last_segment = 0  # global index of the first segment of the current line
        for index, geo_object in enumerate(data):
            # Points are collected per line. Accumulating them across
            # geometries duplicated earlier segments and joined separate lines.
            # [:2] keeps x, y for both 2-D and 3-D points.
            points = [point[:2] for point in geo_object.GetPoints()]
            n_segments = max(len(points) - 1, 0)
            self.line_segments[index] = list(
                range(last_segment, last_segment + n_segments))
            for i in range(n_segments):
                segments.append([points[i], points[i + 1]])
                self.segment_line[last_segment + i] = index
            last_segment += n_segments

        # Row 0: normal color, row 1: highlight color.
        self.__line_colors = np.array([
            self._color_converter.to_rgba(color),
            self._color_converter.to_rgba(self.highlight_color)
        ])
        # One 0/1 flag per segment; indexing the color table with it gives per-segment colors.
        self.selected_lines = np.zeros(len(segments), dtype=int)
        # New segment numbering invalidates previously highlighted lines.
        self.highlighted_lines = []
        colors = self.__line_colors[self.selected_lines]
        self.line_collection = LineCollection(segments,
                                              pickradius=10,
                                              linewidths=2,
                                              colors=colors)
        self.line_collection.set_picker(True)
        self.axes.add_collection(self.line_collection)

    def plot_geometry(self, geometry_object, title, color=None):
        """
        A general plot method that will plot the respective type
        Must call redraw afterwards to have an effect
        :param geometry_object: sequence of geometries, dispatched on the type of the first one
        :param title: title for the plot
        :param color: # Hexadecimal
        :return:
        """
        if not color:
            color = self.color

        geometry_name = geometry_object[0].GetGeometryName().upper()
        if geometry_name == "POLYGON":
            self.plot_polygon(geometry_object, color)
        elif geometry_name == "POINT":
            self.plot_point(geometry_object, color)
        elif geometry_name == "LINESTRING":
            self.plot_linestring(geometry_object, color)
        else:
            raise Exception(
                "plot_geometry() failed. Geometries must be POLYGON, POINT OR LINESTRING")

        self.set_title(title)
        self.axes.grid(True)

        # If margin is 0 the graph will fill the plot
        self.axes.margins(0.1)

    def reset_highlighter(self):
        """
        Resets the variables needed to highlight
        :return:
        """
        self.marker = None
        self.highlighted_vertices = []
        self.x_scatter_data, self.y_scatter_data = None, None
        self.poly_list = None
        self.line_collection = None
        self.highlighted_lines = []
        self.selected_lines = []

    def set_line_width(self, width):
        """
        Sets the width of the lines plotted
        Does not work with scatter plot
        :param width: real number
        :return:
        """
        if not self.plots:
            return  # Nothing has been plotted

        for line in self.plots:
            line.set_linewidth(width)
        self.redraw()

    def on_canvas_clicked(self, event):
        self.toolbar.press_pan(event)

    def on_canvas_mouse_scroll(self, event):
        """Zoom in (scroll up) or out (scroll down) around the cursor by a factor of 2."""
        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location
        if xdata is None or ydata is None:
            return  # cursor is outside the axes; nothing to zoom around

        base_scale = 2.0
        cur_xlim = self.axes.get_xlim()
        cur_ylim = self.axes.get_ylim()
        cur_xrange = (cur_xlim[1] - cur_xlim[0]) * .5
        cur_yrange = (cur_ylim[1] - cur_ylim[0]) * .5
        if event.button == 'up':
            # deal with zoom in
            scale_factor = 1 / base_scale
        elif event.button == 'down':
            # deal with zoom out
            scale_factor = base_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        # set new limits
        self.axes.set_xlim([
            xdata - cur_xrange * scale_factor,
            xdata + cur_xrange * scale_factor
        ])
        self.axes.set_ylim([
            ydata - cur_yrange * scale_factor,
            ydata + cur_yrange * scale_factor
        ])
        self.redraw()

    def on_canvas_released(self, event):
        self.toolbar.release_pan(event)
Esempio n. 6
0
def plotConnectionData(cells,
                       p,
                       fig=None,
                       ax=None,
                       zdata=None,
                       clrmap=None,
                       colorbar=None,
                       pickable=None):
    """
    Colour the connections between each cell and its nearest neighbours and
    draw them as a matplotlib LineCollection.

    Parameters
    ----------
    cells : object
        Cell cluster. Reads ``cells.gap_jun_i`` (used to index
        ``cells.cell_centres`` into line segments; presumably one row of
        cell indices per connection -- confirm), ``cells.cell_centres`` and
        the ``xmin``/``xmax``/``ymin``/``ymax`` bounds.
    p : object
        Parameters object; ``p.um`` scales all coordinates.
    fig : matplotlib Figure, optional
        Figure to draw into. If omitted and ``ax`` is given, ``ax.figure`` is
        used; otherwise a new figure is created.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, a single subplot is added to ``fig``.
    zdata : array-like or 'random', optional
        One scalar per entry of ``cells.gap_jun_i``. If not specified, every
        connection gets z = 1. If 'random', random values in [0, 1) are
        generated.
    clrmap : matplotlib colormap, optional
        The colormap to use, specified as cm.mapname. A list of available
        mapnames is at
        http://matplotlib.org/examples/color/colormaps_reference.html
        Default is cm.bone_r. Good options are cm.coolwarm, cm.Blues, cm.jet.
    colorbar : optional
        If equal to 1 (e.g. 1 or True) and ``zdata`` was supplied, a
        colorbar for the connections is added.
    pickable : optional
        Passed to ``LineCollection.set_picker``.

    Returns
    -------
    fig, ax, ax_cb
        Matplotlib figure and axes instances for the plot, and the colorbar
        (or None when no colorbar was added).

    Notes
    -----
    The colour limits are fixed to [0, 1], so z values outside that range
    saturate the colormap.
    Uses matplotlib.collections LineCollection, matplotlib.cm,
    matplotlib.pyplot and numpy arrays.
    """
    if fig is None:
        # Reuse the figure that owns a caller-supplied axes instead of
        # opening an unrelated, empty figure.
        fig = ax.figure if ax is not None else plt.figure()
    if ax is None:
        # Add the axes to `fig` explicitly. plt.subplot() would target the
        # current pyplot figure, which need not be `fig`.
        ax = fig.add_subplot(111)

    n_connections = len(cells.gap_jun_i)

    if zdata is None:
        z = np.ones(n_connections)
    # TODO: a dedicated boolean parameter (e.g. "is_zdata_random") would be
    # cleaner than overloading "zdata" with the string 'random'.
    #
    # "zdata" must be confirmed to be a string *before* comparing it with
    # 'random'. Comparing a numpy array with a string triggers an expensive
    # elementwise comparison and a FutureWarning ("elementwise comparison
    # failed; returning scalar instead").
    elif isinstance(zdata, str) and zdata == 'random':
        z = np.random.random(n_connections)
    else:
        # LineCollection expects an ndarray for its colour data.
        z = np.asarray(zdata)

    if clrmap is None:
        clrmap = cm.bone_r  # default colormap

    # Make a line collection and add it to the plot.
    con_segs = cells.cell_centres[cells.gap_jun_i]
    connects = p.um * np.asarray(con_segs)

    coll = LineCollection(connects,
                          array=z,
                          cmap=clrmap,
                          linewidths=4.0,
                          zorder=0)
    coll.set_clim(vmin=0.0, vmax=1.0)
    coll.set_picker(pickable)
    ax.add_collection(coll)

    ax.axis('equal')

    # Add a colorbar for the Line Collection
    if zdata is not None and colorbar == 1:
        ax_cb = fig.colorbar(coll, ax=ax)
    else:
        ax_cb = None

    xmin = cells.xmin * p.um
    xmax = cells.xmax * p.um
    ymin = cells.ymin * p.um
    ymax = cells.ymax * p.um

    ax.axis([xmin, xmax, ymin, ymax])

    return fig, ax, ax_cb
Esempio n. 7
0
	def plot(self, form, width='uniform', sort=True, gap=0.05, color_nodes=None):
		"""
		Create a level set tree plot in Matplotlib.

		Parameters
		----------
		form : {'lambda', 'alpha', 'kappa', 'old'}
			Determines main form of the plot. 'lambda' is the traditional plot
			where the vertical scale is density levels, but plot improvements
			such as mass sorting of the nodes and colored nodes are allowed and
			the secondary 'alpha' scale is visible (but not controlling). The
			'old' form uses density levels for vertical scale but does not allow
			plot tweaks and does not show the secondary 'alpha' scale. The
			'alpha' setting makes the upper level set mass the primary vertical
			scale, leaving the 'lambda' scale in place for reference. 'kappa'
			makes node mass the vertical scale, so that each node's vertical
			height is proportional to its mass excluding the mass of the node's
			children.

		width : {'uniform', 'mass'}, optional
			Determines how much horizontal space each level set tree node is
			given. The default of 'uniform' gives each child node an equal
			fraction of the parent node's horizontal space. If set to 'mass',
			then horizontal space is allocated proportional to the mass (i.e.
			fraction of points) of a node relative to its siblings. Forced to
			'uniform' when form is 'old'.

		sort : bool, optional
			If True, sort sibling nodes from most to least points and draw left
			to right. Also sorts root nodes in the same way. Forced to False
			when form is 'old'.

		gap : float, optional
			Fraction of vertical space to leave at the bottom. Default is 5%,
			and 0% also works well. Higher values are used for interactive tools
			to make room for buttons and messages.

		color_nodes : list, optional
			Each entry should be a valid index in the level set tree that will
			be colored uniquely. Ignored when form is 'old'.

		Returns
		-------
		fig : matplotlib figure
			Use fig.show() to view, fig.savefig() to save, etc.

		segments : dict
			A dictionary with values that contain the coordinates of vertical
			line segment endpoints. This is only useful to the interactive
			analysis tools.

		segmap : list
			Indicates the order of the vertical line segments as returned by the
			recursive coordinate mapping function, so they can be picked by the
			user in the interactive tools.

		splits : dict
			Dictionary values contain the coordinates of horizontal line
			segments (i.e. node splits).

		splitmap : list
			Indicates the order of horizontal line segments returned by
			recursive coordinate mapping function, for use with interactive
			tools.

		Raises
		------
		ValueError
			If `form` is not one of 'lambda', 'alpha', 'kappa' or 'old'. The
			check is done before any work is done or any figure is created.
		"""

		## Validate input. Rejecting an unknown form up front avoids building
		## branch maps with it and avoids leaving an orphaned figure open.
		if form not in ('lambda', 'alpha', 'kappa', 'old'):
			raise ValueError('Plot form not understood')

		if form == 'old':
			sort = False
			color_nodes = None
			width = 'uniform'

		## Initialize the plot containers
		segments = {}
		splits = {}
		segmap = []
		splitmap = []

		## Find the root connected components and corresponding plot intervals
		ix_root = np.array([k for k, v in self.nodes.items()
			if v.parent is None])
		n_root = len(ix_root)
		census = np.array([len(self.nodes[x].members) for x in ix_root],
			dtype=float)
		n = census.sum()

		if sort is True:
			seniority = np.argsort(census)[::-1]
			ix_root = ix_root[seniority]
			census = census[seniority]

		if width == 'mass':
			weights = census / n
			intervals = np.cumsum(weights)
			intervals = np.insert(intervals, 0, 0.0)
		else:
			intervals = np.linspace(0.0, 1.0, n_root+1)

		## Do a depth-first search on each root to get segments for each branch
		for i, ix in enumerate(ix_root):
			bounds = (intervals[i], intervals[i+1])
			if form == 'kappa':
				branch = self.constructMassMap(ix, 0.0, bounds, width)
			elif form == 'old':
				branch = self.constructBranchMap(ix, bounds, 'lambda', width,
					sort)
			else:
				branch = self.constructBranchMap(ix, bounds, form, width, sort)

			branch_segs, branch_splits, branch_segmap, branch_splitmap = branch
			# update() keeps the "later entries win" merge semantics without
			# rebuilding the whole dict for every root.
			segments.update(branch_segs)
			splits.update(branch_splits)
			segmap += branch_segmap
			splitmap += branch_splitmap

		## Get the vertical line segments in order of the segment map (segmap)
		verts = [segments[k] for k in segmap]
		lats = [splits[k] for k in splitmap]

		## Find the fraction of nodes in each segment (to use as linewidths)
		thickness = [max(1.0, 12.0 * len(self.nodes[x].members)/n)
			for x in segmap]

		## Get the relevant vertical ticks (the y-coordinate of both endpoints
		## of every vertical segment)
		primary_ticks = [(x[0][1], x[1][1]) for x in segments.values()]
		primary_ticks = np.unique(np.array(primary_ticks).flatten())
		primary_labels = [str(round(tick, 2)) for tick in primary_ticks]

		## Set up the plot framework
		fig, ax = plt.subplots()
		ax.set_position([0.11, 0.05, 0.78, 0.93])
		ax.set_xlim((-0.04, 1.04))
		ax.set_xticks([])
		ax.set_xticklabels([])
		ax.yaxis.grid(color='gray')
		ax.set_yticks(primary_ticks)
		ax.set_yticklabels(primary_labels)

		node_list = list(self.nodes.values())

		## Form-specific details
		if form == 'kappa':
			kappa_max = max(primary_ticks)
			ax.set_ylim((-1.0 * gap * kappa_max, 1.04*kappa_max))
			ax.set_ylabel("mass")

		elif form in ('old', 'lambda'):
			ax.set_ylabel("lambda")
			ymin = min(v.start_level for v in node_list)
			ymax = max(v.end_level for v in node_list)
			rng = ymax - ymin
			ax.set_ylim(ymin - gap*rng, ymax + 0.05*rng)

			if form == 'lambda':
				ax2 = ax.twinx()
				ax2.set_position([0.11, 0.05, 0.78, 0.93])
				ax2.set_ylabel("alpha", rotation=270)

				alpha_ticks = np.sort(list(
					set(v.start_mass for v in node_list) |
					set(v.end_mass for v in node_list)))
				alpha_labels = [str(round(m, 2)) for m in alpha_ticks]

				# NOTE(review): assumes alpha_labels line up one-to-one with
				# primary_ticks (one mass per density level) -- confirm.
				ax2.set_yticks(primary_ticks)
				ax2.set_yticklabels(alpha_labels)
				ax2.set_ylim(ax.get_ylim())

		else:  # form == 'alpha'
			ax.set_ylabel("alpha")
			ymin = min(v.start_mass for v in node_list)
			ymax = max(v.end_mass for v in node_list)
			rng = ymax - ymin
			# Top margin is 5% of the range, as for the other forms.
			ax.set_ylim(ymin - gap*rng, ymax + 0.05*rng)

			ax2 = ax.twinx()
			ax2.set_position([0.11, 0.05, 0.78, 0.93])
			ax2.set_ylabel("lambda", rotation=270)

			lambda_ticks = np.sort(list(
				set(v.start_level for v in node_list) |
				set(v.end_level for v in node_list)))
			lambda_labels = [str(round(lvl, 2)) for lvl in lambda_ticks]

			# NOTE(review): assumes lambda_labels line up one-to-one with
			# primary_ticks -- confirm.
			ax2.set_ylim(ax.get_ylim())
			ax2.set_yticks(primary_ticks)
			ax2.set_yticklabels(lambda_labels)

		## Add the line segments (default colour: black RGB)
		segclr = np.zeros((len(segmap), 3))
		splitclr = np.zeros((len(splitmap), 3))

		if color_nodes is not None:
			palette = utl.Palette()
			n_clr = len(palette.colorset)
			for i, ix in enumerate(color_nodes):
				# Cycle through the palette if there are more nodes than colours.
				c = palette.colorset[i % n_clr, :]
				subtree = self.makeSubtree(ix)
				# Materialize the keys: a dict view is not array-like for numpy.
				subtree_ix = list(subtree.nodes.keys())

				## set vertical colors
				ix_replace = np.isin(segmap, subtree_ix)
				segclr[ix_replace] = c

				## set horizontal colors
				if splitmap:
					ix_replace = np.isin(splitmap, subtree_ix)
					splitclr[ix_replace] = c

		linecol = LineCollection(verts, linewidths=thickness, colors=segclr)
		ax.add_collection(linecol)
		linecol.set_picker(20)

		splitcol = LineCollection(lats, colors=splitclr)
		ax.add_collection(splitcol)

		return fig, segments, segmap, splits, splitmap