class Tracks:
    """Draw storm-cell tracks on a matplotlib axes as a single LineCollection.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the track collection is added to.
    stormcells : numpy structured array
        Must provide the ``track_id``, ``frame_index``, ``xcent`` and
        ``ycent`` fields (the only fields read here).
    tails : int or None
        When truthy, only points with ``frame_index`` strictly greater than
        ``current_frame - tails`` are drawn (at most ``tails`` frames).
    """
    def __init__(self, ax, stormcells, tails=None):
        self.tracks = None
        self.tails = tails
        self.update_trackmap(ax, stormcells)

    def update_trackmap(self, ax, stormcells):
        """(Re)create the line collection on *ax* and rebuild ``self.trackmap``.

        ``self.trackmap[k]`` holds the row indexes of *stormcells* that belong
        to track id ``k``, sorted chronologically by ``frame_index``.
        """
        if self.tracks is not None:
            self.tracks.remove()
        self.tracks = LineCollection([])
        ax.add_collection(self.tracks)

        self.trackmap = []
        track_ids = stormcells['track_id']
        # np.max raises on an empty array; no cells simply means no tracks.
        n_tracks = np.max(track_ids) + 1 if np.size(track_ids) else 0
        for trackid in range(n_tracks):
            indexes = np.where(track_ids == trackid)[0]
            # Makes sure the track segments are in chronological order
            indexes = indexes[np.argsort(stormcells['frame_index'][indexes])]
            self.trackmap.append(indexes)

    def update_frame(self, frame_index, stormcells):
        """Show every track up to and including *frame_index*."""
        segments = []
        for indexes in self.trackmap:
            trackdata = stormcells[indexes]
            trackdata = trackdata[trackdata['frame_index'] <= frame_index]
            if self.tails:
                mask = trackdata['frame_index'] > (frame_index - self.tails)
                trackdata = trackdata[mask]
            # On Python 3 zip() is a lazy, always-truthy iterator, so it must be
            # materialised: otherwise the NaN fallback never triggers and the
            # segment is not a sequence of points. Every track keeps one
            # (possibly NaN) segment so segment k stays aligned with track k.
            points = list(zip(trackdata['xcent'], trackdata['ycent']))
            segments.append(points or [(np.nan, np.nan)])
        self.tracks.set_segments(segments)
class Tracks:
    """Draw storm-cell tracks on a matplotlib axes as a single LineCollection.

    *stormcells* is expected to be a numpy structured array with the
    ``track_id``, ``frame_index``, ``xcent`` and ``ycent`` fields.
    """
    def __init__(self, ax, stormcells):
        self.tracks = None
        self.update_trackmap(ax, stormcells)

    def update_trackmap(self, ax, stormcells):
        """(Re)create the line collection on *ax* and rebuild ``self.trackmap``.

        ``self.trackmap[k]`` holds the row indexes of *stormcells* for track
        id ``k``, sorted chronologically by ``frame_index``.
        """
        if self.tracks is not None:
            self.tracks.remove()
        self.tracks = LineCollection([])
        ax.add_collection(self.tracks)

        self.trackmap = []
        track_ids = stormcells['track_id']
        # np.max raises on an empty array; no cells simply means no tracks.
        n_tracks = np.max(track_ids) + 1 if np.size(track_ids) else 0
        for trackid in range(n_tracks):
            indexes = np.where(track_ids == trackid)[0]
            # Makes sure the track segments are in chronological order
            indexes = indexes[np.argsort(stormcells['frame_index'][indexes])]
            self.trackmap.append(indexes)

    def update_frame(self, frame_index, stormcells):
        """Show every track up to and including *frame_index*."""
        segments = []
        for indexes in self.trackmap:
            trackdata = stormcells[indexes]
            trackdata = trackdata[trackdata['frame_index'] <= frame_index]
            # zip() is a lazy, always-truthy iterator on Python 3; materialise
            # it so the NaN placeholder is used for tracks with no points yet
            # and every track keeps exactly one segment.
            points = list(zip(trackdata['xcent'], trackdata['ycent']))
            segments.append(points or [(np.nan, np.nan)])
        self.tracks.set_segments(segments)
Beispiel #3
0
class TrailingLine:
    '''
    Plot a tapered polyline: segment widths grow linearly from zero along the
    path towards max_width (the ramp's final value is dropped, so the last
    segment is slightly thinner than max_width).
    '''
    def __init__(self, x, y, ax, max_width, **kwargs):
        self.ax = ax
        self.max_width = max_width
        widths = self._compute_linewidths(len(x))
        segs = self._compute_segments(x, y)
        self.lc = LineCollection(segs, linewidths=widths, **kwargs)
        ax.add_collection(self.lc)

    def set_data(self, x, y):
        """Replace the path with new points and recompute the width ramp."""
        new_segments = self._compute_segments(x, y)
        self.lc.set_segments(new_segments)
        self.lc.set_linewidth(self._compute_linewidths(len(x)))

    @classmethod
    def _compute_segments(cls, x, y):
        """Turn n points into an (n - 1, 2, 2) array of consecutive point pairs."""
        pts = np.column_stack([x, y])
        return np.stack([pts[:-1], pts[1:]], axis=1)

    def _compute_linewidths(self, npts):
        """One width per segment: a 0..max_width ramp without its final value."""
        ramp = np.linspace(0, self.max_width, npts)
        return ramp[:-1]
Beispiel #4
0
class BatchLineCollection(object):
    """Draw many polylines through one reusable matplotlib LineCollection.

    The collection is created on the first ``draw`` call; later calls only
    replace its segments and styling.
    """

    def __init__(self, ax):
        self._ax = ax
        self._lc = None

    @property
    def artists(self):
        """List holding the managed collection (``[None]`` before any draw)."""
        return [self._lc]

    def draw(self, x, y, **kwargs):
        """Replace the drawn polylines.

        *x* and *y* are parallel sequences; polyline i runs through the points
        of ``x[i]``/``y[i]``. Every polyline is split into two-point segments
        and all segments are concatenated. Optional keyword arguments:
        ``color`` (reshaped to one row per segment) and ``linewidth``.
        """
        def to_segments(xs, ys):
            pts = np.stack([xs, ys], axis=1).reshape(-1, 1, 2)
            return np.concatenate([pts[:-1], pts[1:]], axis=1)

        segments = np.concatenate(
            [to_segments(xs, ys) for xs, ys in zip(x, y)], axis=0)

        if self._lc is not None:
            self._lc.set_segments(segments)
        else:
            self._lc = PltLineCollection(segments)
            self._ax.add_collection(self._lc)

        if 'color' in kwargs:
            per_segment = np.reshape(kwargs['color'], [len(segments), -1])
            self._lc.set_color(per_segment)
        if 'linewidth' in kwargs:
            self._lc.set_linewidth(kwargs['linewidth'])
        self._lc.set_joinstyle('round')
        self._lc.set_capstyle('round')
Beispiel #5
0
class ArrowedLineCollection(LineCollection):
    '''LineCollection that can also draw an arrow head near the end of every
    edge; call add_arrows() to enable the arrows.'''

    def __init__(self, *kl, **kw):
        '''The same arguments as in matplotlib.collections.LineCollection.
        To initiate arrows, additionally use 'add_arrows'.
        '''
        LineCollection.__init__(self, *kl, **kw)
        self.kw = kw  # reused in add_arrows so the arrow heads share the style
        # draw() needs self.arrows, self.arrow_offsets, ... which only
        # add_arrows() creates; enabling arrows here made draw() raise
        # AttributeError whenever add_arrows() had not been called.
        self.draw_arrows = False

    def add_arrows(self, arrow_offsets=10, arrow_length=10, arrow_width=5):
        '''Draw arrows at the end of every edge.

        Parameters
        ----------
        arrow_offsets:  a list (or tuple) with one entry per edge, giving the
                        offset by which each arrow is moved back to make space
                        for a node. A single value (the same for all edges) is
                        also accepted.
        arrow_length:   the length of the arrow head.
        arrow_width:    the width of the arrow head.

        All three are applied in the display coordinates produced by
        ``self.get_transform()`` (presumably pixels).

        Raises
        ------
        ValueError
            If a sequence of offsets does not have one entry per edge.
        '''
        n_edges = len(self.get_paths())
        if isinstance(arrow_offsets, (list, tuple)):
            if len(arrow_offsets) != n_edges:
                raise ValueError('arrow_offsets does not match the number of edges.')
            offsets = list(arrow_offsets)
        else:
            offsets = [arrow_offsets] * n_edges

        self.arrows = LineCollection([], **self.kw)
        self.arrow_length = arrow_length
        self.arrow_width = arrow_width
        self.arrow_offsets = offsets
        # Enable drawing only once every attribute draw() needs exists, so a
        # rejected call (ValueError above) leaves the collection drawable.
        self.draw_arrows = True

    def draw(self, renderer):
        ''' Overrides the matplotlib.collections.LineCollection.draw(). Adds arrows.

        Each arrow head is built from the first two vertices of its edge in
        display coordinates: it points along the edge, its tip is moved back
        from the second vertex by ``scale * arrow_offsets[i]``, and ``scale``
        shrinks below 1 when the edge is shorter than four times that offset.
        Edges with fewer than two vertices get no arrow.
        '''
        LineCollection.draw(self, renderer)
        if not self.draw_arrows:
            return

        transform = self.get_transform()
        segments = []
        for i, path in enumerate(self.get_paths()):
            L = [transform.transform(numpy.transpose(vertices))
                 for vertices, _code in path.iter_segments()]
            if len(L) < 2:
                continue  # a single point has no direction for an arrow
            V = L[1] - L[0]
            offset = self.arrow_offsets[i]
            if offset == 0:
                # No node to make room for; also avoids dividing by zero.
                scale = 1.0
            else:
                # make the arrows smaller if their size is comparable with edge length
                scale = min(1.0, vector_length(V) / (4 * offset))
            direction = norm_vector(V)
            node_shift = scale * offset * direction
            head_length_vector = scale * self.arrow_length * direction
            head_width_vector = scale * self.arrow_width * perpendicular_vector(direction)
            tip = L[1] - node_shift
            segments.append((tip - head_length_vector + head_width_vector,
                             tip,
                             tip - head_length_vector - head_width_vector))

        self.arrows.set_segments(segments)
        self.arrows.draw(renderer)
Beispiel #6
0
 def do_3d_projection(self, renderer):
     '''
     Project the points according to renderer matrix.

     Each stored 3D segment is transformed with ``renderer.M``; the x/y parts
     become the collection's 2D segments and the smallest projected z is
     returned (1e9 when there are no segments).
     '''
     xyslist = [
         proj3d.proj_trans_points(points, renderer.M) for points in
         self._segments3d]
     # On Python 3 zip() is a lazy iterator rather than a sequence of points,
     # so it must be materialised before being handed to set_segments.
     segments_2d = [list(zip(xs, ys)) for (xs, ys, zs) in xyslist]
     LineCollection.set_segments(self, segments_2d)
     minz = 1e9
     for (xs, ys, zs) in xyslist:
         minz = min(minz, min(zs))
     return minz
Beispiel #7
0
    def do_3d_projection(self, renderer):
        '''
        Project the points according to renderer matrix.

        Each stored 3D segment is transformed with ``renderer.M``; the x/y
        parts become the collection's 2D segments and the smallest projected
        z is returned (1e9 when there are no segments).
        '''
        xyslist = [
            proj3d.proj_trans_points(points, renderer.M) for points in
            self._segments3d]
        # On Python 3 zip() is a lazy iterator rather than a sequence of
        # points, so it must be materialised before set_segments sees it.
        segments_2d = [list(zip(xs, ys)) for (xs, ys, zs) in xyslist]
        LineCollection.set_segments(self, segments_2d)

        minz = 1e9
        for (xs, ys, zs) in xyslist:
            minz = min(minz, min(zs))
        return minz
    def do_3d_projection(self, renderer=None):
        """
        Project the stored 3D segments with ``self.axes.M``, store the 2D
        result on the LineCollection and return the smallest projected z
        (1e9 when there are no segments). *renderer* is accepted but unused.
        """
        projected = [proj3d.proj_trans_points(seg, self.axes.M)
                     for seg in self._segments3d]
        flat = [np.column_stack([px, py]) for px, py, _ in projected]
        LineCollection.set_segments(self, flat)

        # FIXME: kept from the original; the concern is not documented here.
        return min([1e9] + [min(pz) for _, _, pz in projected])
Beispiel #9
0
    def do_3d_projection(self, renderer):
        """
        Project the stored 3D segments with ``renderer.M``, store the 2D
        result on the LineCollection and return the smallest projected z
        (1e9 when there are no segments).
        """
        matrix = renderer.M
        projected = [proj3d.proj_trans_points(seg, matrix)
                     for seg in self._segments3d]
        flat = [np.column_stack([px, py]) for px, py, _ in projected]
        LineCollection.set_segments(self, flat)

        # FIXME: kept from the original; the concern is not documented here.
        return min([1e9] + [min(pz) for _, _, pz in projected])
Beispiel #10
0
class MyCell(matplotlib.table.CustomCell):
    """ Extending matplotlib tables.

        A table cell whose outline is drawn by a separate LineCollection with
        one segment per edge, so each edge can be coloured individually:
        edges found in ``self._visible_edges`` get the cell's edge colour,
        the others are fully transparent.

        Adapted from https://stackoverflow.com/a/53573651/505698
    """
    def __init__(self, *args, visible_edges, **kwargs):
        super().__init__(*args, visible_edges=visible_edges, **kwargs)
        # Closed unit square split into 4 segments (bottom, right, top, left);
        # the real geometry is filled in by update_segments() at draw time.
        seg = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0],
                        [0.0, 0.0]]).reshape(-1, 1, 2)
        segments = np.concatenate([seg[:-1], seg[1:]], axis=1)
        self.edgelines = LineCollection(segments,
                                        edgecolor=kwargs.get("edgecolor"))
        # Keep the text above the cell body.
        self._text.set_zorder(2)
        self.set_zorder(1)

    def set_transform(self, trans):
        """Apply *trans* to both the edge lines and the cell itself."""
        self.edgelines.set_transform(trans)
        super().set_transform(trans)

    def draw(self, renderer):
        """Draw the cell with its own outline hidden, then the edge lines."""
        # edgelines is a separate artist with its own visibility flag, so a
        # hidden cell would otherwise still get its edges drawn.
        if not self.get_visible():
            return
        c = self.get_edgecolor()
        self.set_edgecolor((1, 1, 1, 0))
        try:
            super().draw(renderer)
            self.update_segments(c)
            self.edgelines.draw(renderer)
        finally:
            # Restore the real edge colour even if drawing raised.
            self.set_edgecolor(c)

    def update_segments(self, color):
        """Move the four edge segments onto the cell rectangle and colour them.

        Edges listed in ``self._visible_edges`` get *color*; the others are
        made fully transparent.
        """
        x, y = self.get_xy()
        w, h = self.get_width(), self.get_height()
        seg = np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h],
                        [x, y]]).reshape(-1, 1, 2)
        segments = np.concatenate([seg[:-1], seg[1:]], axis=1)
        self.edgelines.set_segments(segments)
        self.edgelines.set_linewidth(self.get_linewidth())
        colors = [
            color if edge in self._visible_edges else (1, 1, 1, 0)
            for edge in self._edges
        ]
        self.edgelines.set_edgecolor(colors)

    def get_path(self):
        """Return the closed unit-square path of the cell."""
        codes = [Path.MOVETO] + [Path.LINETO] * 3 + [Path.CLOSEPOLY]
        return Path(
            [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 0.0]],
            codes,
            readonly=True,
        )
Beispiel #11
0
class Tracks(object):
    """Draw storm-cell tracks as a single matplotlib LineCollection.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the collection is added to.
    tails : int or None
        When truthy, only points with ``frame_index >= current - tails`` are
        drawn.
    """
    def __init__(self, ax, tails=None):
        self.tracks = None
        self.tails = tails
        self.initialize_lines(ax)

    @staticmethod
    def create_trackmap(stormdata):
        """Return, for each track id 0..max(track_id), the row indexes of its
        cells in *stormdata*, sorted chronologically by ``frame_index``."""
        track_ids = stormdata['track_id']
        # np.max raises on an empty array; no cells simply means no tracks.
        if np.size(track_ids) == 0:
            return []
        trackmap = []
        for trackid in range(np.max(track_ids) + 1):
            indexes = np.where(track_ids == trackid)[0]
            # Makes sure the track segments are in chronological order
            indexes = indexes[np.argsort(stormdata['frame_index'][indexes])]
            trackmap.append(indexes)
        return trackmap

    def remove_lines(self):
        """Remove the collection from its axes, if one exists."""
        if self.tracks is not None:
            self.tracks.remove()
            self.tracks = None

    def initialize_lines(self, ax):
        """Replace any existing collection with a fresh, empty one on *ax*."""
        self.remove_lines()
        self.tracks = LineCollection([])
        ax.add_collection(self.tracks)

    def update_lines(self, frame_index, stormdata):
        """Show every track up to and including *frame_index*."""
        segments = []
        for indexes in self.create_trackmap(stormdata):
            trackdata = stormdata[indexes]
            trackdata = trackdata[trackdata['frame_index'] <= frame_index]
            if self.tails:
                mask = trackdata['frame_index'] >= (frame_index - self.tails)
                trackdata = trackdata[mask]
            # There must always be something in a track, even if it is NaNs,
            # so segment k stays aligned with track k. zip() is a lazy,
            # always-truthy iterator on Python 3 and must be materialised for
            # the NaN fallback to work.
            points = list(zip(trackdata['xcent'], trackdata['ycent']))
            segments.append(points or [(np.nan, np.nan)])
        self.tracks.set_segments(segments)

    def lolite_line(self, indx):
        """Reset track *indx* to linewidth 1."""
        self.hilite_line(indx, 1)

    def hilite_line(self, indx, lw=4):
        """Set the linewidth of track *indx* to *lw*; no-op when *indx* is None."""
        if indx is None:
            return
        lws = np.asarray(self.tracks.get_linewidths(), dtype=float)
        n_lines = len(self.tracks.get_paths())
        if lws.size != n_lines:
            # The collection may hold fewer widths than lines (e.g. one shared
            # default); repeat them cyclically to one per line before indexing.
            lws = np.resize(lws, n_lines)
        else:
            lws = lws.copy()
        lws[indx] = lw
        self.tracks.set_linewidths(lws)
Beispiel #12
0
    def do_3d_projection(self, renderer=None):
        """
        Project the stored 3D segments with ``self.axes.M``, store the 2D
        result on the LineCollection and return the smallest projected z
        (1e9 when there are no segments). *renderer* is accepted but unused.
        """
        # see _update_scalarmappable docstring for why this must be here
        _update_scalarmappable(self)
        projected = [proj3d.proj_trans_points(seg, self.axes.M)
                     for seg in self._segments3d]
        flat = [np.column_stack([px, py]) for px, py, _ in projected]
        LineCollection.set_segments(self, flat)

        # FIXME: kept from the original; the concern is not documented here.
        return min([1e9] + [min(pz) for _, _, pz in projected])
class Tracks(object):
    """Draw storm-cell tracks as a single matplotlib LineCollection.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the collection is added to.
    tails : int or None
        When truthy, only points with ``frame_index >= current - tails`` are
        drawn.
    """
    def __init__(self, ax, tails=None):
        self.tracks = None
        self.tails = tails
        self.initialize_lines(ax)

    @staticmethod
    def create_trackmap(stormdata):
        """Return, for each track id 0..max(track_id), the row indexes of its
        cells in *stormdata*, sorted chronologically by ``frame_index``."""
        track_ids = stormdata['track_id']
        # np.max raises on an empty array; no cells simply means no tracks.
        if np.size(track_ids) == 0:
            return []
        trackmap = []
        for trackid in range(np.max(track_ids) + 1):
            indexes = np.where(track_ids == trackid)[0]
            # Makes sure the track segments are in chronological order
            indexes = indexes[np.argsort(stormdata['frame_index'][indexes])]
            trackmap.append(indexes)
        return trackmap

    def remove_lines(self):
        """Remove the collection from its axes, if one exists."""
        if self.tracks is not None:
            self.tracks.remove()
            self.tracks = None

    def initialize_lines(self, ax):
        """Replace any existing collection with a fresh, empty one on *ax*."""
        self.remove_lines()
        self.tracks = LineCollection([])
        ax.add_collection(self.tracks)

    def update_lines(self, frame_index, stormdata):
        """Show every track up to and including *frame_index*."""
        segments = []
        for indexes in self.create_trackmap(stormdata):
            trackdata = stormdata[indexes]
            trackdata = trackdata[trackdata['frame_index'] <= frame_index]
            if self.tails:
                mask = trackdata['frame_index'] >= (frame_index - self.tails)
                trackdata = trackdata[mask]
            # There must always be something in a track, even if it is NaNs,
            # so segment k stays aligned with track k. zip() is a lazy,
            # always-truthy iterator on Python 3 and must be materialised for
            # the NaN fallback to work.
            points = list(zip(trackdata['xcent'], trackdata['ycent']))
            segments.append(points or [(np.nan, np.nan)])
        self.tracks.set_segments(segments)

    def lolite_line(self, indx):
        """Reset track *indx* to linewidth 1."""
        self.hilite_line(indx, 1)

    def hilite_line(self, indx, lw=4):
        """Set the linewidth of track *indx* to *lw*; no-op when *indx* is None."""
        if indx is None:
            return
        lws = np.asarray(self.tracks.get_linewidths(), dtype=float)
        n_lines = len(self.tracks.get_paths())
        if lws.size != n_lines:
            # The collection may hold fewer widths than lines (e.g. one shared
            # default); repeat them cyclically to one per line before indexing.
            lws = np.resize(lws, n_lines)
        else:
            lws = lws.copy()
        lws[indx] = lw
        self.tracks.set_linewidths(lws)
Beispiel #14
0
class LineCollection(object):
    """Draw one polyline through a lazily created matplotlib collection.

    The matplotlib class is referenced here as ``PltLineCollection``
    (presumably an import alias, since this wrapper reuses the name).
    """
    def __init__(self, ax):
        self._ax = ax
        self._lc = None

    @property
    def artists(self):
        """List holding the managed collection (``[None]`` before any draw)."""
        return [self._lc]

    def draw(self, x, y, **kwargs):
        """Draw or update the polyline through the points of *x* and *y*.

        The n points become n - 1 two-point segments. Optional keyword
        arguments: ``color`` and ``linewidth``; anything else is ignored.
        """
        xy = np.stack([x, y], axis=1)
        xy = xy.reshape(-1, 1, 2)
        segments = np.hstack([xy[:-1], xy[1:]])

        if self._lc is None:
            self._lc = PltLineCollection(segments)
            self._ax.add_collection(self._lc)
        else:
            self._lc.set_segments(segments)
        if 'color' in kwargs:
            self._lc.set_color(kwargs['color'])
        # Previously silently ignored; honour it like the batch variant does.
        if 'linewidth' in kwargs:
            self._lc.set_linewidth(kwargs['linewidth'])
Beispiel #15
0
class ScatterPlot(object):
    """Animated scatter plot (with optional edge lines) driven by *feeder*.

    ``iter(feeder)`` must yield ``(x, y, c)`` triples; the feeder must also
    provide ``time_intervals``, ``frame``, ``get_limits()``, ``prev()`` and
    ``get_current_edges()``.
    """
    def __init__(self, feeder, marker='o'):
        self.marker = marker
        self.feeder = feeder
        self.stream = iter(self.feeder)

    def get_frames_len(self):
        """Number of frames: one less than the number of time intervals."""
        return len(self.feeder.time_intervals) - 1

    def get_frame(self):
        return self.feeder.frame

    def get_limits(self):
        return self.feeder.get_limits()

    def setup_plot(self):
        """Create the scatter artist from the next (x, y, c) triple."""
        x, y, c = next(self.stream)
        ax = plt.gca()
        self.scat = ax.scatter(x, y, c=c, marker=self.marker, s=25)
        return self.scat

    def setup_plot_edges(self):
        """Add the feeder's current edges as a thin LineCollection."""
        lines = self.feeder.get_current_edges()
        ax = plt.gca()
        self.lines = LineCollection(lines, linewidths=0.1)
        ax.add_collection(self.lines)

    def update_edges(self):
        lines = self.feeder.get_current_edges()
        self.lines.set_segments(lines)

    def prev_state(self):
        """Step the feeder back and restart iteration from its new state."""
        self.feeder.prev()
        self.stream = iter(self.feeder)

    def update_plot(self):
        """Move the points to the next positions; returns the scatter artist."""
        x, y, _ = next(self.stream)
        # np.array(zip(x, y)) is a 0-d object array on Python 3 (zip is a
        # lazy iterator); column_stack builds the required (N, 2) offsets.
        new_data = np.column_stack([x, y])
        self.scat.set_offsets(new_data)
        return self.scat
Beispiel #16
0
class SURFDemo(ImageProcessDemo):
    """Interactive SURF key-point matching demo.

    The displayed image is the grey input (left half) next to a
    perspective-warped copy of it (right half). The warp is defined by a
    draggable polygon. SURF features are matched between both halves, the
    best matches are drawn as lines, and a homography is estimated from them.
    """
    TITLE = "SURF Demo"
    DEFAULT_IMAGE = "lena.jpg"
    SETTINGS = ["m_perspective", "hessian_threshold", "n_octaves"]
    # np.float was only an alias of the builtin float (float64) and was
    # removed in NumPy 1.24; np.float64 keeps the same dtype everywhere.
    m_perspective = Array(np.float64, (3, 3))
    m_perspective2 = Array(np.float64, (3, 3))

    hessian_threshold = Int(2000)
    n_octaves = Int(2)

    poly = Instance(PolygonWidget)

    def control_panel(self):
        """Return the traits UI group for both matrices and the SURF settings."""
        # Both matrix labels read "transformation matrix".
        return VGroup(
            Item("m_perspective",
                 label="变换矩阵",
                 editor=ArrayEditor(format_str="%g")),
            Item("m_perspective2",
                 label="变换矩阵",
                 editor=ArrayEditor(format_str="%g")),
            Item("hessian_threshold", label="hessianThreshold"),
            Item("n_octaves", label="nOctaves"))

    def __init__(self, **kwargs):
        super(SURFDemo, self).__init__(**kwargs)
        self.poly = None
        self.init_points = None
        # Red lines joining matched key points; segments are set in draw().
        self.lines = LineCollection([], linewidths=1, alpha=0.6, color="red")
        self.axe.add_collection(self.lines)
        self.connect_dirty("poly.changed,hessian_threshold,n_octaves")

    def init_poly(self):
        """Put the polygon on the corners of the right half of the display."""
        if self.poly is None:
            return
        h, w, _ = self.img_color.shape
        self.init_points = np.array([(w, 0), (2 * w, 0), (2 * w, h), (w, h)],
                                    np.float32)
        self.poly.set_points(self.init_points)
        self.poly.update()

    def init_draw(self):
        """Create the draggable polygon widget and position it."""
        style = {"marker": "o"}
        self.poly = PolygonWidget(axe=self.axe,
                                  points=np.zeros((3, 2)),
                                  style=style)
        self.init_poly()

    @on_trait_change("hessian_threshold, n_octaves")
    def calc_surf1(self):
        """Recompute SURF key points and descriptors of the original image."""
        self.surf = cv2.SURF(self.hessian_threshold, self.n_octaves)
        self.key_points1, self.features1 = self.surf.detectAndCompute(
            self.img_gray, None)
        self.key_positions1 = np.array([kp.pt for kp in self.key_points1])

    def _img_changed(self):
        """Build grey/colour copies, the side-by-side canvas and the matcher."""
        self.img_gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)
        self.img_color = cv2.cvtColor(self.img_gray, cv2.COLOR_GRAY2RGB)
        # Two copies side by side: original on the left, warped on the right.
        self.img_show = np.concatenate([self.img_color, self.img_color],
                                       axis=1)
        self.size = self.img_color.shape[1], self.img_color.shape[0]  # (w, h)
        self.calc_surf1()

        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=100)

        self.matcher = cv2.FlannBasedMatcher(index_params, search_params)

        self.init_poly()

    def settings_loaded(self):
        """Move the polygon to where the loaded m_perspective maps its corners."""
        src = self.init_points.copy()
        w, h = self.size
        # Polygon points live in the right half; shift to image coordinates.
        src[:, 0] -= w
        dst = cv2.perspectiveTransform(src[None, :, :], self.m_perspective)
        dst = dst.squeeze()
        dst[:, 0] += w
        self.poly.set_points(dst)
        self.poly.update()

    def draw(self):
        """Warp the image by the polygon, match SURF features and redraw."""
        if self.poly is None:
            return
        w, h = self.size
        src = self.init_points.copy()
        dst = self.poly.points.copy().astype(np.float32)
        src[:, 0] -= w
        dst[:, 0] -= w
        # Named distinctly from the loop variable used for matches below.
        warp = cv2.getPerspectiveTransform(src, dst)
        self.m_perspective = warp
        img2 = cv2.warpPerspective(self.img_gray,
                                   warp,
                                   self.size,
                                   borderValue=[255] * 4)
        self.img_show[:, w:, :] = img2[:, :, None]
        key_points2, features2 = self.surf.detectAndCompute(img2, None)

        key_positions2 = np.array([kp.pt for kp in key_points2])

        # k=1: nearest warped-image feature for every original feature.
        match_list = self.matcher.knnMatch(self.features1, features2, k=1)
        index1 = np.array([m[0].queryIdx for m in match_list])
        index2 = np.array([m[0].trainIdx for m in match_list])

        distances = np.array([m[0].distance for m in match_list])

        # Keep at most the 50 closest matches.
        n = min(50, len(distances))
        best_index = np.argsort(distances)[:n]
        matched_positions1 = self.key_positions1[index1[best_index]]
        matched_positions2 = key_positions2[index2[best_index]]

        self.m_perspective2, mask = cv2.findHomography(matched_positions1,
                                                       matched_positions2,
                                                       cv2.RANSAC)

        # One segment per match: (x1, y1) -> (x2 + w, y2) on the display.
        lines = np.concatenate([matched_positions1, matched_positions2],
                               axis=1)
        lines[:, 2] += w
        # mask presumably flags RANSAC inliers and picks each line's colour
        # from COLORS — confirm against COLORS' definition.
        line_colors = COLORS[mask.ravel()]
        self.lines.set_segments(lines.reshape(-1, 2, 2))
        self.lines.set_color(line_colors)
        self.draw_image(self.img_show)
Beispiel #17
0
class Visualisation(FigureCanvasWxAgg):
    def __init__(self, gui):
        """Build the wx canvas, its navigation toolbar and the main axes, and
        initialise the per-variable zone bookkeeping.

        gui: parent GUI object. It is used as the wx parent of the canvas and
        must provide ``core`` (whose ``modelList`` is iterated) and
        ``linesDic`` (per model, a mapping whose values are iterables of
        variable names).
        """

        self.fig = pl.figure()  # figsize/dpi (9,8), 90 were once passed here
        FigureCanvasWxAgg.__init__(self, gui, -1, self.fig)
        self.xlim, self.ylim, self.dataon = (), (), False

        # the GUI and the model
        self.gui, self.core = gui, gui.core
        # interaction polygon on a zone (so that the zone can be edited)
        self.polyInteract = None

        # list of variable names, used by the GUI and the variable combobox
        self.listeNomVar = []
        for mod in self.core.modelList:
            for k in gui.linesDic[mod].keys():
                self.listeNomVar.extend(gui.linesDic[mod][k])
        self.curVar, self.curContour = None, 'Charge'  # currently selected variable
        self.curMedia, self.curOri, self.curPlan, self.curGroupe = 0, 'Z', 0, None
        # type of the zone currently being drawn (initialised to -1)
        self.typeZone = -1

        # coordinates and graphic object of the zone being created
        self.curZone = None  # graphic object (line, point, rect..)
        self.x1, self.y1 = [], []
        self.tempZoneVal = []  # list of values for polyV
        self.calcE = 0
        self.calcT = 0
        self.calcR = 0
        # per the original comment: flags telling whether a calculation was done

        # dictionaries keyed by the model (aquifer) variables:
        # each variable is associated with a list of zones
        self.listeZone, self.listeZoneText, self.listeZmedia = {}, {}, {}
        for i in range(len(self.listeNomVar)):
            self.listeZone[self.listeNomVar[i]] = []
            self.listeZoneText[self.listeNomVar[i]] = []
            self.listeZmedia[self.listeNomVar[i]] = []

        # navigation toolbar of the canvas (NavigationToolbar2Wx)
        self.toolbar = NavigationToolbar2Wx(self)
        self.toolbar.Realize()
        # add the axes to the figure
        self.cnv = self.fig.add_axes([.05, .05, .9,
                                      .88])  # left, bottom, width, height
        self.toolbar.update()
        self.pos = self.mpl_connect('motion_notify_event', self.onPosition)

        # create the major objects:
        self.Contour, self.ContourF, self.ContourLabel, self.Vector = None, None, None, None
        self.Grid, self.Particles, self.Image, self.Map = None, None, None, None

    #####################################################################
    #                     Divers accesseur/mutateurs
    #####################################################################

    def GetToolBar(self):
        """Return the navigation toolbar created for this canvas."""
        toolbar = self.toolbar
        return toolbar

    def getcurVisu(self):
        """Return ``[curGroupe, curNom, curObj]`` describing the current display.

        ``curNom`` and ``curObj`` are not initialised in ``__init__``; until
        some other code assigns them they are reported as None instead of
        raising AttributeError.
        """
        return [self.curGroupe,
                getattr(self, 'curNom', None),
                getattr(self, 'curObj', None)]

    def onPosition(self, evt):
        """Forward a mouse-motion event's data coordinates to the GUI.

        Each coordinate is shown as its string form cut to 6 characters.
        """
        x_txt = str(evt.xdata)[:6]
        y_txt = str(evt.ydata)[:6]
        self.gui.onPosition(' x: ' + x_txt + ' y: ' + y_txt)

    def delAllObjects(self):
        """Forget every zone, clear the axes and redraw the canvas."""
        for v in self.listeZone:
            self.listeZone[v] = []
            self.listeZoneText[v] = []
        # Axes.cla() already empties lines, collections, artists and images.
        # Assigning those attributes directly was redundant, and newer
        # matplotlib releases expose them as read-only properties.
        self.cnv.cla()
        self.draw()

    def setVisu(self, core):
        """Rebuild the display for a freshly loaded model.

        Clears every graphic object, recreates the zones of each model through
        setAllZones, initialises the domain (axes limits, grid, vectors) and
        redraws the canvas.

        NOTE(review): the loop iterates ``self.core.modelList`` but reads the
        zones from the *core* argument; presumably the same object — confirm.
        """
        self.delAllObjects()
        models = self.core.modelList
        for mod in models:
            zone_dict = core.diczone[mod].dic
            self.setAllZones(zone_dict)
        self.initDomain()
        self.draw()

    def setDataOn(self, bool):
        """Store whether data are displayed together with contours.

        *bool* is saved as-is on ``self.dataon``.
        """
        self.dataon = bool

    def redraw(self):
        """Redraw the canvas (axes limits are left untouched)."""
        self.draw()

#    def changeTitre(self,titre):
#        s='';ori=self.curOri
#        if ori in ['X','Y','Z']:
#            plan=self.curPlan;
#            x1,y1 = self.model.Aquifere.getXYticks()
#            zl = self.model.Aquifere.getParm('zList')
#        if ori=='Z': s=' Z = '+ str(zl[plan])
#        if ori=='X': s=' X = '+ str(x1[plan])
#        if ori=='Y': s=' Y = '+ str(y1[plan])
#        pl.title(self.traduit(str(titre))+s[:9],fontsize=20)

    def createAndShowObject(self, dataM, dataV, opt, value=None, color=None):
        """Create or hide the contour and the vector field.

        dataM: contour data, forwarded to createContour (which unpacks it as
            X, Y, Z), or None to hide the contour.
        dataV: vector data forwarded to createVector, or None to hide vectors.
        opt: unused; kept for backward compatibility.
        value, color: contour settings forwarded to createContour.
        """
        # ``is None`` rather than ``== None``: if the data are numpy arrays,
        # ``== None`` is element-wise and has no single truth value.
        if dataM is None:
            self.drawContour(False)
        else:
            self.createContour(dataM, value, color)
        if dataV is None:
            self.drawVector(False)
        else:
            self.createVector(dataV)

    def drawObject(self, typObj, bool):
        """Show or hide an object by calling ``self.draw<typObj>(bool)``.

        Hiding the map before it has ever been created is a no-op.
        """
        if typObj == 'Map' and self.Map is None and bool == False:
            return
        # Look the method up directly instead of exec()-ing a built-up string,
        # which would run arbitrary code if typObj were not a plain name.
        getattr(self, 'draw' + typObj)(bool)

    def changeObject(self, groupe, name, value, color):
        """Apply new display settings to the object called *name*.

        'Grid' only takes a colour; 'Particles' and 'Veloc-vect' take value
        and colour; any other name is treated as a contour. *groupe* is unused.
        """
        if name == 'Grid':
            self.changeGrid(color)
            return
        if name == 'Particles':
            self.changeParticles(value=value, color=color)
            return
        if name == 'Veloc-vect':
            self.changeVector(value, color)
            return
        self.changeContour(value, color)

    #####################################################################
    #             Gestion de l'affichage de la grid/map
    #####################################################################
    # methode qui change la taille du domaine d'etude (les values de l'axe
    # de la figure matplotlib en fait) et la taille des cellules d'etude
    def initDomain(self):
        """Set the axes limits from the model's full grid, build the grid lines
        and add an initially hidden, zero-length vector LineCollection."""
        # change value of the axes
        grd = self.core.addin.getFullGrid()
        self.xlim = (grd['x0'], grd['x1'])
        self.ylim = (grd['y0'], grd['y1'])
        # A hidden dummy line is plotted only to grab its transform, which is
        # stored on self.transform and reused for the grid and vector lines.
        # NOTE(review): pl.plot draws on pyplot's current axes, assumed to be
        # self.cnv — confirm.
        p, = pl.plot([0, 1], 'b')
        p.set_visible(False)
        self.transform = p.get_transform()
        self.cnv.set_xlim(self.xlim)
        self.cnv.set_ylim(self.ylim)
        self.createGrid()
        # add basic vector as a linecollection: rand(2, 2) * 0. is a 2x2 zero
        # array, so both placeholder segments are degenerate
        dep = rand(2, 2) * 0.
        arr = dep * 1.
        self.Vector = LineCollection(zip(dep, arr))
        self.Vector.set_transform(self.transform)
        self.Vector.set_visible(False)
        # NOTE(review): appending to Axes.collections directly bypasses
        # add_collection and is not supported by newer matplotlib releases.
        self.cnv.collections.append(self.Vector)
        self.Vector.data = [0, 0, None, None]


#    def changeDomain(self):
#        self.changeAxesOri('Z',0)
#

    def changeAxesOri(self, ori):
        """Switch the view orientation and redraw.

        'Z' shows the (x, y) plane, 'X' the (y, z) plane and 'Y' the (x, z)
        plane; the z range spans min..max of ``self.core.Zblock``. Any other
        value leaves the limits unchanged and only redraws.
        """
        zb = self.core.Zblock
        zlim = (amin(zb), amax(zb))
        axes_limits = None
        if ori == 'Z':
            axes_limits = (self.xlim, self.ylim)
        elif ori == 'X':
            axes_limits = (self.ylim, zlim)
        elif ori == 'Y':
            axes_limits = (self.xlim, zlim)
        if axes_limits is not None:
            new_x, new_y = axes_limits
            self.cnv.set_xlim(new_x)
            self.cnv.set_ylim(new_y)
        self.draw()

    def createGrid(self, col=None):
        """(Re)build the grid as two LineCollections and redraw.

        ``self.Grid`` is ``[vertical_lines, horizontal_lines, colour]``; the
        two collections are also stored in ``self.cnv.collections[0:2]``.
        col: optional line colour; when None the stored colour is reused.
        NOTE(review): on the very first call (self.Grid is None) any *col*
        argument is overridden by the default grey (.6, .6, .6).
        """
        if self.Grid == None:
            col = (.6, .6, .6)
            self.Grid = [0, 0, col]
        else:
            # hide the previous grid collections before replacing them
            for i in range(2):
                self.Grid[i].set_visible(False)
        if col == None: col = self.Grid[2]
        else: self.Grid[2] = col
        nx, ny, xt, yt = getXYvects(self.core)
        # reserve the first two collection slots for the grid
        if len(self.cnv.collections) < 2: self.cnv.collections = [0, 0]
        # one vertical line per xt value, from min(yt) to max(yt)
        l = len(ravel(xt))
        dep = concatenate([xt.reshape((l, 1)), ones((l, 1)) * min(yt)], axis=1)
        arr = concatenate([xt.reshape((l, 1)), ones((l, 1)) * max(yt)], axis=1)
        self.Grid[0] = LineCollection(zip(dep, arr))
        self.cnv.collections[0] = self.Grid[0]
        # one horizontal line per yt value, from min(xt) to max(xt)
        l = len(ravel(yt))
        dep = concatenate([ones((l, 1)) * min(xt), yt.reshape((l, 1))], axis=1)
        arr = concatenate([ones((l, 1)) * max(xt), yt.reshape((l, 1))], axis=1)
        self.Grid[1] = LineCollection(zip(dep, arr))
        self.cnv.collections[1] = self.Grid[1]
        for i in [0, 1]:
            self.Grid[i].set_transform(self.transform)
            self.Grid[i].set_color(col)
        self.redraw()

    def drawGrid(self, bool):
        """Show or hide both grid line collections and redraw.

        The stored colour ``self.Grid[2]`` is re-applied. This only toggles
        existing collections; it cannot recreate a removed grid.
        """
        color = self.Grid[2]
        for lines in self.Grid[:2]:
            lines.set_visible(bool)
            lines.set_color(color)
        self.redraw()

    def changeGrid(self, color):
        """Recolour both grid line collections and redraw.

        color: object whose ``Get()`` returns 0-255 integer channels, red,
        green and blue first (e.g. a wx colour).
        """
        a = color.Get()
        # Divide by a float so the channels map to 0-1 fractions even under
        # Python 2, where int / int truncates to 0 or 1.
        col = tuple(channel / 255. for channel in a[:3])
        for i in [0, 1]:
            self.Grid[i].set_color(col)
        self.Grid[2] = col
        self.redraw()

    #####################################################################
    #             Affichage d'une variable sous forme d'image
    #####################################################################
    # l'image se met en position 1 dans la liste des images
    def createMap(self):
        """Load the background image ``self.gui.map``, stretch it over the
        current x/y limits as the axes' only image, show it and redraw."""
        img = Im.imread(self.gui.map)
        extent = (self.xlim[0], self.xlim[1], self.ylim[0], self.ylim[1])
        self.Map = pl.imshow(img,
                             origin='upper',
                             extent=extent,
                             aspect='auto',
                             interpolation='nearest')
        self.cnv.images = [self.Map]
        self.cnv.images[0].set_visible(True)
        self.redraw()

    def drawMap(self, bool):
        """Show or hide the background map, creating it on first use."""
        if self.Map is None:
            self.createMap()
        self.cnv.images = [self.Map]
        self.cnv.images[0].set_visible(bool)
        self.redraw()

    def createImage(self, data):
        """Draw *data* = (X, Y, Z) with pl.pcolormesh, register the result as
        the axes' only image, and redraw.

        NOTE(review): pcolormesh returns a mesh collection, yet it is stored
        in ``cnv.images`` — confirm this is intended.
        """
        X, Y, Z = data
        self.cnv.images = [pl.pcolormesh(X, Y, Z)]
        self.redraw()

    def drawImage(self, bool):
        """Toggle visibility of the first axes image, if any, then redraw."""
        images = self.cnv.images
        if len(images) == 0:
            return
        images[0].set_visible(bool)
        self.redraw()

    #####################################################################
    #             Gestion de l'affichage des contours
    #####################################################################

    def createContour(self, data, value=None, col=None):
        """Compute and draw filled and line contours of ``data`` = (X, Y, Z).

        value: optional level description
            [0] min, [1] max, [2] step (linear) or step in decades (log),
            [3] number of decimals of the contour labels,
            [4] 'lin', 'log' or 'fix' (anything else: 11 levels spread
                between the data min and max),
            [5] when [4] is 'fix', the explicit sequence of levels.
            Default: [Zmin, Zmax, (Zmax - Zmin) / 10., 2, 'auto', []].
        col: optional [rgb0, rgb1, rgb2, alpha] -- three 0-255 RGB triples
            blended into the colormap, and the fill opacity in percent.
            Without it the default colormap and an opacity of 10 % are used.

        Z values above 1e5 are masked (treated as missing).
        """
        X, Y, Z = data
        # drop previous contours and labels; the first three collections are
        # kept (the vector layer lives at index 2, see createVector)
        self.cnv.collections = self.cnv.collections[:3]
        self.cnv.artists = []
        Zmin = amin(Z)
        Zmax = amax(Z * (Z < 1e5))  # values >= 1e5 do not count for the max
        if Zmax == Zmin:  # min == max -> nothing to contour
            self.gui.onMessage(' values all equal to ' + str(Zmin))
            return
        if value is None:
            value = [Zmin, Zmax, (Zmax - Zmin) / 10., 2, 'auto', []]
        vmin, vmax, step = [float(a) for a in value[:3]]
        mode = value[4]
        if mode == 'log':
            lo = log10(max(vmin, 1e-4))  # guard against log10(<= 0)
            n = int((log10(vmax) - lo) / step) + 1
            V = logspace(lo, log10(vmax), n)
        elif mode == 'fix' and value[5] is not None:
            V = list(value[5])  # copy: the caller's levels are not modified
            # one extra level above the last one, presumably so that the top
            # user level still bounds a filled band
            V.append(V[-1] * 2.)
        elif mode == 'lin':
            n = int((vmax - vmin) / step) + 1
            V = linspace(vmin, vmax, n)
        else:  # automatic levels
            V = linspace(Zmin, Zmax, 11)
        # a single row cannot be contoured: duplicate it into a thin band
        if shape(X)[0] == 1:
            X = concatenate([X, X])
            Y = concatenate([Y - Y * .45, Y + Y * .45])
            Z = concatenate([Z, Z])
        Z2 = ma.masked_where(Z > 1e5, Z.copy())
        Xa, Ya = pl.array(X), pl.array(Y)
        if col is None:
            cf = pl.contourf(Xa, Ya, Z2, V)
            c = pl.contour(Xa, Ya, Z2, V)
            col = [(0, 0, 255), (0, 255, 0), (255, 0, 0), 10]
        else:
            # (position, weight of col[0], weight of col[1], weight of col[2])
            lim = ((0., 1., 0., 0.), (.1, 1., 0., 0.), (.25, .8, 0., 0.),
                   (.35, 0., .8, 0.), (.45, 0., 1., 0.), (.55, 0., 1., 0.),
                   (.65, 0., .8, 0.), (.75, 0., 0., .8), (.9, 0., 0., 1.),
                   (1., 0., 0., 1.))
            cdict = {'red': [], 'green': [], 'blue': []}
            for pos, w0, w1, w2 in lim:
                for channel, key in enumerate(('red', 'green', 'blue')):
                    level = (w0 * col[0][channel] / 255. +
                             w1 * col[1][channel] / 255. +
                             w2 * col[2][channel] / 255.)
                    cdict[key].append((pos, level, level))
            my_cmap = mpl.colors.LinearSegmentedColormap(
                'my_colormap', cdict, 256)
            cf = pl.contourf(Xa, Ya, Z2, V, cmap=my_cmap)
            c = pl.contour(Xa, Ya, Z2, V, cmap=my_cmap)
        for band in cf.collections:
            band.set_alpha(int(col[3]) / 100.)
        # value is never None at this point, so value[3] always sets the format
        fmt = '%1.' + str(value[3]) + 'f'
        # clabel's keyword is ``colors``; the former ``color`` is not a clabel
        # option (ignored or rejected depending on the matplotlib version)
        cl = pl.clabel(c, colors='black', fontsize=9, fmt=fmt)
        self.Contour = c
        self.ContourF = cf
        self.ContourLabel = cl
        self.Contour.data = data
        self.redraw()

    def changeContour(self, value, col):
        """Redraw the existing contours with new level settings and colours."""
        self.drawContour(False)
        contour_data = self.Contour.data
        self.createContour(contour_data, value, col)

    def drawContour(self, bool):
        """Remove the current contours and labels from the canvas.

        Only the first three collections are kept and all artists are
        dropped.  NOTE(review): ``bool`` is ignored -- contours are removed
        even when True is passed.
        """
        canvas = self.cnv
        canvas.collections = canvas.collections[:3]
        canvas.artists = []
        self.draw()

    #####################################################################
    #             Gestion de l'affichage de vectors
    #####################################################################
    """vector has been created as the first item of lincollection list
    during domain intialization"""

    def createVector(self, data):
        """Create (or replace) the velocity-vector layer from (X, Y, U, V).

        Each vector goes from (x, y) to (x + u*ech, y + v*ech).  The scale
        ``ech`` and colour are reused from the previous vector layer when it
        has a colour, otherwise 1 and blue.  The layer is stored at index 2 of
        ``self.cnv.collections`` (appended if the list is shorter) and its
        geometry in ``self.Vector.data`` as [starts, ends, scale, colour].
        """
        X, Y, U, V = data
        if self.Vector.data[3] is None:  # initial placeholder: no colour yet
            ech = 1.
            col = (0, 0, 1)
        else:
            ech, col = self.Vector.data[2], self.Vector.data[3]
            self.drawVector(False)
        n = len(ravel(X))
        dep = concatenate([X.reshape((n, 1)), Y.reshape((n, 1))], axis=1)
        end_x = X + U * ech
        end_y = Y + V * ech
        arr = concatenate([end_x.reshape((n, 1)), end_y.reshape((n, 1))],
                          axis=1)
        # list(): one (start, end) segment per point, also under Python 3
        self.Vector = LineCollection(list(zip(dep, arr)))
        self.Vector.set_transform(self.transform)
        self.Vector.set_color(col)
        if len(self.cnv.collections) > 2:
            self.cnv.collections[2] = self.Vector
        else:
            self.cnv.collections.append(self.Vector)
        self.Vector.set_visible(True)
        self.Vector.data = [dep, arr, ech, col]
        self.redraw()

    def drawVector(self, bool):
        """Show (``bool`` true) or hide the velocity vectors, then redraw."""
        vectors = self.Vector
        vectors.set_visible(bool)
        self.redraw()

    def changeVector(self, ech, col=wx.Color(0, 0, 255)):
        """Rescale and recolour the existing velocity vectors.

        ech: new scale factor (anything ``float`` accepts).  The segments are
            rescaled relative to the previous scale.  NOTE(review): an old
            scale of 0 cannot be rescaled (division by zero).
        col: wx colour whose 0-255 RGB components are converted to the 0-1
            floats matplotlib expects.
        """
        ech = float(ech)
        dep, arr_old, ech_old = self.Vector.data[:3]
        arr = dep + (arr_old - dep) * ech / ech_old
        self.Vector.set_segments(list(zip(dep, arr)))
        a = col.Get()
        # float division: under Python 2, ``a[0] / 255`` truncated to 0 or 1
        rgb = (a[0] / 255., a[1] / 255., a[2] / 255.)
        self.Vector.set_color(rgb)
        self.Vector.set_visible(True)
        self.Vector.data = [dep, arr, ech, rgb]
        self.redraw()

    #####################################################################
    #             Gestion de l'affichage de particules
    #####################################################################
    def startParticles(self):
        """Start an interactive particle-tracking session.

        Hides the previous particle paths (if any), resets the particle store,
        detaches the toolbar's mouse handlers and routes mouse clicks to
        ``mouseParticles``.
        """
        if self.Particles is not None:
            self.partVisible(False)
        self.Particles = dict(line=[], txt=[], data=[],
                              color=wx.Color(255, 0, 0))
        for attr in ('_idPress', '_idRelease', '_idDrag'):
            self.mpl_disconnect(getattr(self.toolbar, attr))
        self.m3 = self.mpl_connect('button_press_event', self.mouseParticles)
        self.stop = False

    def mouseParticles(self, evt):
        """Mouse handler of the particle-tracking session.

        Left click: compute a particle path from the clicked point with
        ``self.core.addin.calcParticle`` and draw it.  Right click: stop the
        session and send 'zoneEnd' to the GUI.  Clicks outside the axes, or
        after the session stopped, are ignored.
        """
        if self.stop or evt.inaxes is None:
            return
        button = evt.button
        if button == 1:
            xp, yp, tp = self.core.addin.calcParticle(evt.xdata, evt.ydata)
            self.updateParticles(xp, yp, tp)
        elif button == 3:
            self.mpl_disconnect(self.m3)
            self.stop = True
            self.gui.actions('zoneEnd')

    def updateParticles(self, X, Y, T, freq=10):
        """Add one particle path to the particle group and draw it in red.

        X, Y, T: positions and times along the path.
        freq: label every ``freq``-th point with its time; 0 or a negative
            value draws no labels (this used to raise NameError).
        """
        ligne, = pl.plot(pl.array(X), pl.array(Y), 'r')
        txt = []
        if freq > 0:
            for x, y, t in zip(X[0::freq], Y[0::freq], T[0::freq]):
                label = str(t)
                # keep at least 4 characters and never cut the integer part
                ndigits = max(4, len(label.split('.')[0]))
                txt.append(pl.text(x, y, label[:ndigits], fontsize='8'))
        self.Particles['line'].append(ligne)
        self.Particles['txt'].append(txt)
        self.Particles['data'].append((X, Y, T))
        self.gui_repaint()  # workaround: matplotlib bug with direct draw
        self.draw()

    def drawParticles(self, bool, value=None):
        """Show or hide every particle path; no-op if tracking never started.

        ``value`` is accepted for interface compatibility and is unused.
        """
        if self.Particles is None:
            return
        self.partVisible(bool)
        self.gui_repaint()
        self.draw()

    def changeParticles(self, value=None, color=wx.Color(255, 0, 0)):
        """Relabel every particle path with times spaced ``value`` apart.

        value: time interval between labels, converted with ``float`` (so the
            default None raises TypeError as soon as there is a path).
        color: wx colour applied to the paths.
        The previous labels are hidden and replaced.
        """
        self.partVisible(False)
        self.Particles['color'] = color
        self.Particles['txt'] = []
        for X, Y, T in self.Particles['data']:
            tx, ty, tt = self.ptsPartic(X, Y, T, float(value))
            labels = []
            for k in range(len(tx)):
                stamp = str(tt[k])
                width = max(4, len(stamp.split('.')[0]))
                labels.append(pl.text(tx[k], ty[k], stamp[:width],
                                      fontsize='8'))
            self.Particles['txt'].append(labels)
        self.partVisible(True)
        self.gui_repaint()
        self.draw()

    def partVisible(self, bool):
        """Set the visibility of all particle lines and labels.

        The lines are also recoloured with ``self.Particles['color']``, a wx
        colour with 0-255 components converted to 0-1 floats.
        """
        a = self.Particles['color'].Get()
        # float division: under Python 2, ``a[0] / 255`` truncated to 0 or 1
        color = (a[0] / 255., a[1] / 255., a[2] / 255.)
        for line in self.Particles['line']:
            line.set_visible(bool)
            line.set_color(color)
        for labels in self.Particles['txt']:
            for txt in labels:
                txt.set_visible(bool)

    def ptsPartic(self, X, Y, T, dt):
        """Interpolate a particle path at times spaced ``dt`` apart.

        Returns (xn, yn, tn) with tn = tmin, tmin + dt, tmin + 2*dt, ... up to
        max(T), and xn / yn linearly interpolated from X / Y at tn.
        """
        import numpy as np
        tmin = amin(T)
        tmax = amax(T)
        # number of whole dt steps in [tmin, tmax]; linspace with
        # int(span / dt) points (as before) gave one point too few, so the
        # spacing was not dt.  The epsilon absorbs round-off such as
        # 0.3 / 0.1 == 2.9999999999999996.
        nsteps = int((tmax - tmin) / dt + 1e-9)
        # clip: round-off must not push the last time past tmax, which
        # interp1d rejects as out of range
        t1 = np.minimum(tmin + dt * np.arange(nsteps + 1), tmax)
        xn = interp1d(T, X)(t1)
        yn = interp1d(T, Y)(t1)
        return xn, yn, t1

    #####################################################################
    #                   Gestion des zones de la visu
    #####################################################################
    # affichage de toutes les zones d'une variable
    def showVar(self, var, media):
        """Display only the zones of variable ``var`` in layer ``media``.

        ``media == -1`` shows every zone of the variable.  All other zones
        are hidden first.
        """
        self.setUnvisibleZones()
        self.curVar, self.curMedia = var, media
        for idx, zone in enumerate(self.listeZone[var]):
            if media in self.listeZmedia[var][idx] or media == -1:
                zone.set_visible(True)
                self.visibleText(self.listeZoneText[var][idx], True)
        self.redraw()

    def showData(self, liForage, liData):
        """Show the 'Forages' (borehole) zones labelled with data values.

        liForage: zone names; liData: values aligned with liForage.  Every
        zone whose name is listed gets a "name\\nvalue" label at its centre;
        the labels are registered as one 'zoneText' graphic object.
        """
        self.setUnvisibleZones()
        self.curVar = 'data'
        self.listeZoneText['data'] = []
        for zone in self.listeZone['Forages']:
            zone.set_visible(True)
        labels = []
        for z in self.model.Aquifere.getZoneList('Forages'):
            xs, ys = zip(*z.getXy())
            name = z.getNom()
            if name not in liForage:
                continue
            value = liData[liForage.index(name)]
            labels.append(pl.text(mean(xs), mean(ys),
                                  name + '\n' + str(value)))
        self.addGraphicObject(GraphicObject('zoneText', labels, True, None))
        self.redraw()

    def changeData(self, taille, col):
        """Apply font size ``taille`` and colour ``col`` to the data labels."""
        for label in self.listeZoneText['data'][0].getObject():
            label.set_size(taille)
            label.set_color(col)

    def getcurZone(self):
        """Return the zone being drawn or moved (None when there is none)."""
        zone = self.curZone
        return zone

    def setcurZone(self, zone):
        """Make ``zone`` the zone currently being drawn."""
        setattr(self, 'curZone', zone)

    # methode qui efface toutes les zones de toutes les variables
    def setUnvisibleZones(self):
        """Hide every zone and every zone label of every variable.

        A label entry is either one text or a list of texts.
        """
        for var in self.listeZone:
            for zone in self.listeZone[var]:
                zone.set_visible(False)
            for label in self.listeZoneText[var]:
                parts = label if type(label) is list else [label]
                for part in parts:
                    part.set_visible(False)

    # methode appelee par la GUI lorsqu'on veut creer une nouvelle zone
    def setZoneReady(self, typeZone, curVar):
        """Prepare interactive creation of a zone for variable ``curVar``.

        typeZone: 'POINT', 'LINE', 'RECT', 'POLY' or 'POLYV' (see mouse_clic).
        The toolbar's mouse handlers are detached so that clicks go to
        ``mouse_clic``.
        """
        self.typeZone, self.curVar = typeZone, curVar
        self.tempZoneVal = []
        for attr in ('_idPress', '_idRelease', '_idDrag'):
            self.mpl_disconnect(getattr(self.toolbar, attr))
        self.m1 = self.mpl_connect('button_press_event', self.mouse_clic)

    def setZoneEnd(self, evt):
        """Finish interactive zone creation and give the vertices to the GUI.

        Vertices are (x, y) pairs, or (x, y, value) triples when more than one
        POLYV value was entered.  The temporary zone is hidden and the drawing
        state reset.  ``evt`` is unused.
        """
        zone = self.getcurZone()
        xv, yv = zone.get_xdata(), zone.get_ydata()
        values = self.tempZoneVal
        columns = (xv, yv, values) if len(values) > 1 else (xv, yv)
        xy = zip(*columns)
        # hide the temporary zone (e.g. in case of cancel) and reset state
        self.curZone.set_visible(False)
        self.curZone = None
        self.x1, self.y1 = [], []
        self.gui.addBox.onZoneCreate(self.typeZone, xy)

    def addZone(self, media, name, val, coords, visible=True):
        """Add a zone of the current variable and its label(s) to the view.

        media: layer index, or list of layer indices, of the zone.
        name, val: shown in the label as "name\\nval" (val cut to 16 chars).
        coords: list of (x, y) or (x, y, z) vertices; with z, every vertex
            also gets a label showing its z value.
        visible: initial visibility of the zone and of all its labels.
        Empty ``coords`` adds nothing.

        Raises ValueError when vertices are neither 2- nor 3-tuples.
        """
        columns = list(zip(*coords))  # list(): len() must work on Python 3
        if not columns:
            return
        if len(columns) not in (2, 3):
            raise ValueError('zone vertices must be (x, y) or (x, y, z)')
        x, y = columns[0], columns[1]
        z = columns[2] if len(columns) == 3 else None
        if len(x) == 1:  # a single vertex is drawn as a marker
            zone = Line2D(x, y, marker='+', markersize=10, markeredgecolor='r')
        else:
            zone = Line2D(x, y)
        zone.verts = coords
        zone.set_visible(visible)
        if type(media) is not list:
            media = [int(media)]
        self.curMedia = media
        self.cnv.add_line(zone)
        label = name + '\n' + str(val)[:16]
        label_x = mean(x) * .1 + x[0] * .9
        label_y = mean(y) * .1 + y[0] * .9
        # per-vertex labels only when z values exist; relying on typeZone ==
        # "POLYV" crashed (NameError on z) for 2-column coords
        if z is not None:
            txt = [pl.text(label_x, label_y, label)]
            for xi, yi, zi in zip(x, y, z):
                txt.append(pl.text(xi, yi, str(zi)))
            for t in txt:
                t.set_visible(visible)
        else:
            txt = pl.text(label_x, label_y, label)
            txt.set_visible(visible)
        self.listeZone[self.curVar].append(zone)
        self.listeZmedia[self.curVar].append(media)
        self.listeZoneText[self.curVar].append(txt)
        if visible:
            self.redraw()

    def delZone(self, Variable, ind):
        """Remove zone number ``ind`` of ``Variable``, with label and media.

        Unknown variables and out-of-range indices (including negative ones)
        are ignored.
        """
        if Variable not in self.listeZone:  # dict.has_key() is Python 2 only
            return
        if not 0 <= ind < len(self.listeZone[Variable]):
            # a negative ind used to drop only the last media entry, leaving
            # the three parallel lists out of step
            return
        self.listeZone[Variable][ind].set_visible(False)
        self.visibleText(self.listeZoneText[Variable][ind], False)
        del self.listeZone[Variable][ind]
        del self.listeZoneText[Variable][ind]
        del self.listeZmedia[Variable][ind]
        self.redraw()

    def visibleText(self, text, bool):
        """Set the visibility of a zone label or of every label in a list."""
        labels = text if type(text) is list else (text,)
        for label in labels:
            label.set_visible(bool)

    def delAllZones(self, Variable):
        """Hide and forget every zone, label and media entry of ``Variable``."""
        for zone in self.listeZone[Variable]:
            # set_visible: the former setVisible() does not exist on Line2D
            zone.set_visible(False)
        for label in self.listeZoneText[Variable]:
            # labels of zones with vertex values are lists of texts
            self.visibleText(label, False)
        self.listeZone[Variable] = []
        self.listeZmedia[Variable] = []
        self.listeZoneText[Variable] = []
        self.redraw()

    def modifValZone(self, nameVar, ind, val, xy):
        """Intended to update the value (or list of values) of zone ``ind`` of
        ``nameVar``, whose label holds its name and value.

        NOTE(review): not implemented -- this method does nothing and returns
        None; modifZoneAttr performs the actual label/value update.
        """
        return None

    def modifZoneAttr(self, nameVar, ind, val, media, xy):
        """Update the vertices, media and label of zone ``ind`` of ``nameVar``.

        xy: list of (x, y) or (x, y, z) vertices; z values, when present, are
            written to the per-vertex labels of list-style labels.
        val: new value shown after the zone name (cut to 16 characters).
        media: layer index or list of layer indices.
        """
        zone = self.listeZone[nameVar][ind]
        columns = list(zip(*xy))
        x, y = columns[0], columns[1]
        # without z values there is nothing to write to per-vertex labels
        # (this used to raise NameError on ``z``)
        z = columns[2] if len(columns) == 3 else ()
        zone.set_data(x, y)
        if type(media) is not list:
            media = [int(media)]
        self.listeZmedia[nameVar][ind] = media
        textObj = self.listeZoneText[nameVar][ind]
        if type(textObj) is list:
            head = textObj[0]
            name = pl.getp(head, 'text').split('\n')[0]
            pl.setp(head, text=name + '\n' + str(val)[:16])
            # textObj[k + 1] is the label of vertex k
            for label, zi in zip(textObj[1:], z):
                pl.setp(label, text=str(zi))
        else:
            name = pl.getp(textObj, 'text').split('\n')[0]
            pl.setp(textObj, text=name + '\n' + str(val)[:16])
        self.redraw()

    def modifZone(self, nameVar, ind):
        """Start interactive editing of the vertices of zone ``ind`` of
        ``nameVar``: the zone is hidden and a PolygonInteractor line drawn in
        its place."""
        target = self.listeZone[nameVar][ind]
        self.polyInteract = PolygonInteractor(self, target, nameVar, ind)
        target.set_visible(False)
        self.cnv.add_line(self.polyInteract.line)
        self.draw()

    def finModifZone(self):
        """End the vertex editing started by ``modifZone``.

        Sends the new coordinates to the GUI, updates and shows the stored
        zone and moves its labels.  Does nothing when no edit is in progress.
        """
        editor = self.polyInteract
        if editor is None:
            return
        editor.set_visible(False)
        editor.disable()
        var, ind = editor.typeVariable, editor.ind
        x, y = editor.lx, editor.ly
        self.gui.modifBox.onModifZoneCoord(var, ind, zip(x, y))
        zone = self.listeZone[var][ind]
        zone.set_data(x, y)
        zone.set_visible(True)
        txt = self.listeZoneText[var][ind]
        if type(txt) is list:
            txt[0].set_position((x[0], y[0]))
            # txt[k + 1] labels vertex k (txt[0] carries the zone name)
            for k, label in enumerate(txt[1:]):
                label.set_position((x[k], y[k]))
        else:
            txt.set_position((mean(x) * .1 + x[0] * .9,
                              mean(y) * .1 + y[0] * .9))
        self.draw()

    def setAllZones(self, dicZone):
        """Replace the zones of every variable in ``dicZone`` (file import).

        dicZone maps a variable name to a dict of parallel lists 'name',
        'media', 'value' and 'coords'; entries with an empty name are
        skipped.  All zones are hidden afterwards.
        """
        for var in dicZone:
            # hide the old artists and reset all three parallel lists, not
            # only listeZone, so that zone, label and media indices stay aligned
            for zone in self.listeZone.get(var, []):
                zone.set_visible(False)
            for label in self.listeZoneText.get(var, []):
                self.visibleText(label, False)
            self.listeZone[var] = []
            self.listeZmedia[var] = []
            self.listeZoneText[var] = []
            self.curVar = var
            lz = dicZone[var]
            for i in range(len(lz['name'])):
                if lz['name'][i] == '':
                    continue
                self.addZone(lz['media'][i], lz['name'][i], lz['value'][i],
                             lz['coords'][i])
        self.setUnvisibleZones()

    #####################################################################
    #             Gestion de l'interaction de la souris
    #             pour la creation des zones
    #####################################################################

    #methode executee lors d'un clic de souris dans le canvas
    def mouse_clic(self, evt):
        """Mouse handler for interactive zone creation.

        The first click starts the zone (and ends it at once for 'POINT');
        following clicks add vertices.  'LINE' and 'RECT' end on the second
        click, 'POLY' and 'POLYV' on a right click.  For 'POLYV' a value is
        requested for every vertex.  Clicks outside the axes are ignored.
        """
        if evt.inaxes is None:
            return

        def short(v):
            # keep coordinates readable with 6 significant digits; the former
            # float(str(v)[:6]) cut the text and corrupted large or
            # exponent-formatted values (6543210.5 -> 654321.0,
            # 1.2345e-05 -> 1.2345)
            return float('%.6g' % v)

        if self.curZone is None:  # first vertex
            self.x1 = [short(evt.xdata)]
            self.y1 = [short(evt.ydata)]
            self.setcurZone(Line2D(self.x1, self.y1))
            self.cnv.add_line(self.curZone)
            self.m2 = self.mpl_connect('motion_notify_event',
                                       self.mouse_motion)
            if self.typeZone == "POLYV":
                self.polyVdialog()
            if self.typeZone == "POINT":
                self.deconnecte()
                self.setZoneEnd(evt)
            return
        # following vertices
        if self.typeZone == "POLYV":
            if evt.button == 3:
                self.deconnecte()
            if not self.polyVdialog():  # value of this vertex
                return
        self.x1.append(short(evt.xdata))
        self.y1.append(short(evt.ydata))
        if self.typeZone in ("LINE", "RECT"):
            self.deconnecte()  # two vertices are enough
            self.setZoneEnd(evt)
        if self.typeZone in ("POLY", "POLYV") and evt.button == 3:
            self.deconnecte()  # right click closes the polygon
            self.setZoneEnd(evt)

    #methode executee lors du deplacement de la souris dans le canvas suite a un mouse_clic
    def mouse_motion(self, evt):
        """Rubber-band preview of the zone being drawn while the mouse moves.

        'RECT' zones show the rectangle spanned by the first vertex and the
        cursor; other zones show the entered vertices plus the cursor.  The
        0.1 s sleep presumably throttles redraws (it blocks the handler).
        """
        time.sleep(0.1)
        if evt.inaxes is None:
            return
        if self.typeZone != "RECT":
            xs = list(self.x1) + [evt.xdata]
            ys = list(self.y1) + [evt.ydata]
            self.curZone.set_data(xs, ys)
        else:
            corners = self.creeRectangle(self.x1[0], self.y1[0],
                                         evt.xdata, evt.ydata)
            self.curZone.set_data(corners[0], corners[1])
        self.draw()

    def polyVdialog(self):
        """Ask the user for the value of the current POLYV vertex.

        On success the value (as float) is appended to ``self.tempZoneVal``
        and True is returned; False when the dialog returns no values.
        """
        dialog = config.dialogs.genericDialog(self.gui, 'value',
                                              [('Value', 'Text', 0)])
        answers = dialog.getValues()
        if answers is None:
            return False
        self.tempZoneVal.append(float(answers[0]))
        return True

    def creeRectangle(self, x1, y1, x2, y2):
        """Return [xs, ys], the closed outline of the axis-aligned rectangle
        with opposite corners (x1, y1) and (x2, y2)."""
        return [[x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1]]

    def deconnecte(self):
        """Detach the zone-creation click (m1) and motion (m2) handlers."""
        for attr in ('m1', 'm2'):
            self.mpl_disconnect(getattr(self, attr))

    ###################################################################
    #   deplacer une zone ##############################

    def startMoveZone(self, nameVar, ind):
        """Begin moving zone ``ind`` of ``nameVar`` with the mouse.

        A red dot marks the zone's first vertex; a click close to it (see
        zoneM_clic) starts the drag.
        """
        self.nameVar, self.ind = nameVar, ind
        self.curZone = zone = self.listeZone[nameVar][ind]
        self.lx, self.ly = zone.get_xdata(), zone.get_ydata()
        # "* 1." takes a float copy of the first vertex
        self.xstart = self.lx[0] * 1.
        self.ystart = self.ly[0] * 1.
        self.ptstart = Line2D([self.xstart], [self.ystart], marker='o',
                              markersize=7, markerfacecolor='r')
        self.cnv.add_line(self.ptstart)
        self.m1 = self.mpl_connect('button_press_event', self.zoneM_clic)
        self.draw()

    def zoneM_clic(self, evt):
        """First click of a zone move.

        The drag starts only when the click lies within 1/100 of the view
        diagonal (from ``xlim`` / ``ylim``) of the red start point; motion and
        release handlers then replace this click handler.
        """
        if evt.inaxes is None:
            return
        xmn, xmx = self.xlim
        ymn, ymx = self.ylim
        tolerance = sqrt((xmx - xmn)**2 + (ymx - ymn)**2) / 100
        dist = sqrt((evt.xdata - self.xstart)**2 + (evt.ydata - self.ystart)**2)
        if dist > tolerance:
            return
        self.m2 = self.mpl_connect('motion_notify_event', self.zone_motion)
        self.m3 = self.mpl_connect('button_release_event', self.finMoveZone)
        self.mpl_disconnect(self.m1)

    def zone_motion(self, evt):
        """Drag handler: translate the zone and its red start point by the
        cursor's offset from the start point."""
        time.sleep(0.1)  # presumably throttles redraws; blocks the handler
        if evt.inaxes is None:
            return
        lx = [a + evt.xdata - self.xstart for a in self.lx]
        ly = [a + evt.ydata - self.ystart for a in self.ly]
        # Line2D.set_data needs sequences: scalars are deprecated since
        # matplotlib 3.7 and rejected by later versions
        self.ptstart.set_data([lx[0]], [ly[0]])
        self.curZone.set_data(lx, ly)
        self.draw()

    def finMoveZone(self, evt):
        """End of a zone move (mouse release).

        Stops the drag handlers, stores the new vertices, sends them to the
        GUI, moves the labels and removes the red start marker.
        """
        for attr in ('m2', 'm3'):
            self.mpl_disconnect(getattr(self, attr))
        lx, ly = self.curZone.get_xdata(), self.curZone.get_ydata()
        self.listeZone[self.nameVar][self.ind].set_data(lx, ly)
        self.gui.modifBox.onModifZoneCoord(self.nameVar, self.ind, zip(lx, ly))
        txt = self.listeZoneText[self.nameVar][self.ind]
        if type(txt) is list:
            txt[0].set_position((lx[0], ly[0]))
            # txt[k + 1] labels vertex k (txt[0] carries the zone name)
            for k, label in enumerate(txt[1:]):
                label.set_position((lx[k], ly[k]))
        else:
            txt.set_position((mean(lx) * .1 + lx[0] * .9,
                              mean(ly) * .1 + ly[0] * .9))
        self.ptstart.set_visible(False)
        self.ptstart = None
        self.curZone = None
        self.draw()
Beispiel #18
0
class SpikeBrowserUI(object):
    """Matplotlib browser for multichannel signals with spike overlays.

    ``window`` provides the canvas (``get_canvas``) and the scrollbar
    (``set_scroll_handler``, ``set_scroll_pos``, ``set_scroll_max``).
    Channels are stacked vertically.  Spike times (ms) are drawn as snippets
    from ``sp_win[0]`` to ``sp_win[1]`` ms around each spike.
    NOTE(review): ``winsz`` (window width in ms) is read by ``set_data`` but
    never initialised here -- callers must set it first.
    """

    def __init__(self, window):
        self.window = window
        self.sp_win = [-0.8, 1]
        self.spike_collection = None
        # no spikes until set_spiketimes() (OnScrollEvt reads these)
        self.spt = None
        self.labels = None
        self.fig = Figure((5, 4), 75)
        self.canvas = window.get_canvas(self.fig)
        self._mpl_init()
        self.canvas.mpl_connect('key_press_event', self._on_key)
        self.window.set_scroll_handler(self.OnScrollEvt)

    def _mpl_init(self):
        """(Re)build the figure: signal axes plus Prev/Next spike buttons."""
        self.fig.clf()
        self.axes = self.fig.add_axes([0.05, 0.1, 0.95, 0.9])
        self.ax_prev = self.fig.add_axes([0.8, 0.0, 0.1, 0.05])
        self.ax_next = self.fig.add_axes([0.9, 0.0, 0.1, 0.05])
        self.b_next = Button(self.ax_next, 'Next')
        self.b_prev = Button(self.ax_prev, "Prev")
        self.b_next.on_clicked(self._next_spike)
        self.b_prev.on_clicked(self._prev_spike)
        self.i_spike = 0
        self.i_start = 0
        self.line_collection = None
        # any earlier spike overlay belonged to the axes cleared above
        self.spike_collection = None

    def _goto_spike(self, i_spike):
        """Centre the view on spike ``i_spike`` (clamped to the data), redraw."""
        self.i_spike = i_spike
        t_spk = self.spt[i_spike]
        i_start = int(np.ceil(t_spk / 1000. * self.FS - self.i_window / 2.))
        i_start = np.maximum(self.i_min, i_start)
        i_start = np.minimum(self.i_max, i_start)
        self.i_start = i_start
        self.i_end = self.i_start + self.i_window
        self.window.set_scroll_pos(self.i_start)
        self.draw_plot()

    def _next_spike(self, event):
        """'Next' button: centre on the following spike (or stay on the last)."""
        if self.spt is None or len(self.spt) == 0:
            return
        self._goto_spike(min(self.i_spike + 1, len(self.spt) - 1))

    def _prev_spike(self, event):
        """'Prev' button: centre on the preceding spike (or stay on the first)."""
        if self.spt is None or len(self.spt) == 0:
            return
        self._goto_spike(min(max(self.i_spike - 1, 0), len(self.spt) - 1))

    def _on_key(self, event):
        """'+' / '=' halve the per-channel y range, '-' doubles it."""
        if event.key in ('+', '='):
            self.ylims /= 2.
        elif event.key == '-':
            self.ylims *= 2.
        else:
            return
        offset = self.ylims[1] - self.ylims[0]
        self.offsets = np.arange(self.n_chans) * offset
        self.draw_plot()

    def set_spiketimes(self, spk_idx, labels=None, all_labels=None):
        """Set the spike times to overlay (``spk_idx['data']``, in ms).

        labels: optional per-spike labels used for colouring.
        all_labels: full label set for the colour map (default: the unique
            values of ``labels``).
        A falsy ``spk_idx`` clears the spikes and hides the Prev/Next buttons.
        """
        if spk_idx:
            self.spt = spk_idx['data']
            if labels is not None:
                self.labels = labels
                if all_labels is None:
                    self.color_func = label_color(np.unique(labels))
                else:
                    self.color_func = label_color(all_labels)
            else:
                self.labels = None
            self.ax_next.set_visible(True)
            self.ax_prev.set_visible(True)
        else:
            self.spt = None
            self.ax_next.set_visible(False)
            self.ax_prev.set_visible(False)

    def set_data(self, data):
        """Display a new recording and clear the spike times.

        data['data']: (n_chans, n_pts) array; data['FS']: sampling rate (Hz).
        """
        self.x = data['data']
        self.FS = data['FS']
        n_chans, n_pts = self.x.shape
        # reset spike times / hide buttons, and drop the old spike overlay
        self.set_spiketimes(None)
        self._clear_spikes()
        self.i_window = int(self.winsz / 1000. * self.FS)
        # extents of the data sequence
        self.i_min = 0
        self.i_max = n_pts - self.i_window
        self.n_chans = n_chans
        self.window.set_scroll_max(self.i_max, self.i_window)
        # keep the previous scroll position only where it fits the new data
        self.i_start = int(min(max(self.i_start, self.i_min),
                               max(self.i_max, self.i_min)))
        self.i_end = self.i_start + self.i_window
        self.time = np.arange(self.i_start, self.i_end) * 1. / self.FS
        self.segs = np.empty((n_chans, self.i_window, 2))
        self.segs[:, :, 0] = self.time[np.newaxis, :]
        self.segs[:, :, 1] = self.x[:, self.i_start:self.i_end]
        ylims = (self.segs[:, :, 1].min(), self.segs[:, :, 1].max())
        offset = ylims[1] - ylims[0]
        self.offsets = np.arange(n_chans) * offset
        self.segs[:, :, 1] += self.offsets[:, np.newaxis]
        self.ylims = np.array(ylims)
        if self.line_collection is not None:
            self.line_collection.remove()
        self.line_collection = LineCollection(self.segs,
                                              offsets=None,
                                              transform=self.axes.transData,
                                              color='k')
        self.axes.add_collection(self.line_collection)
        self._set_limits()
        self.canvas.draw()

    def _set_limits(self):
        """Fit the axes to the current time window and stacked channels."""
        self.axes.set_xlim((self.time[0], self.time[-1]))
        self.axes.set_ylim((self.ylims[0] + self.offsets.min(),
                            self.ylims[1] + self.offsets.max()))

    def _clear_spikes(self):
        """Remove the spike overlay from the axes, if there is one."""
        if self.spike_collection is not None:
            self.spike_collection.remove()
            self.spike_collection = None

    def draw_plot(self):
        """Redraw the samples [i_start, i_end) and the spikes inside them."""
        self.time = np.arange(self.i_start, self.i_end) * 1. / self.FS
        self.segs[:, :, 0] = self.time[np.newaxis, :]
        self.segs[:, :, 1] = (self.x[:, self.i_start:self.i_end] +
                              self.offsets[:, np.newaxis])
        self.line_collection.set_segments(self.segs)
        self._set_limits()
        if self.spt is not None:
            self.draw_spikes()
        else:
            self._clear_spikes()  # spikes were cleared since the last draw
        self.canvas.draw()

    def draw_spikes(self):
        """Overlay the spikes whose whole snippet lies in the visible window.

        Each spike is drawn on every channel, coloured by ``color_func`` of
        its label when labels were given, red otherwise.
        """
        self._clear_spikes()
        sp_win = self.sp_win
        t_ms = self.segs[0, :, 0] * 1000.
        t_min, t_max = t_ms[0] - sp_win[0], t_ms[-1] - sp_win[1]
        in_view = (self.spt > t_min) & (self.spt < t_max)
        spt = self.spt[in_view]
        if len(spt) == 0:
            return
        n_pts = int((sp_win[1] - sp_win[0]) / 1000. * self.FS)
        sp_segs = np.empty((len(spt), self.n_chans, n_pts, 2))
        complete = np.ones(len(spt), dtype=bool)
        for i, t_spk in enumerate(spt):
            # first sample at or after the snippet start (t_ms is increasing)
            start = int(np.searchsorted(t_ms, t_spk + sp_win[0]))
            stop = start + n_pts
            if stop > len(t_ms):
                # rounding can leave a spike at the right edge without a full
                # snippet; skip it instead of drawing uninitialised values
                complete[i] = False
                continue
            sp_segs[i, :, :, 0] = t_ms[np.newaxis, start:stop] / 1000.
            sp_segs[i, :, :, 1] = self.segs[:, start:stop, 1]
        if not complete.any():
            return
        sp_segs = sp_segs[complete].reshape(-1, n_pts, 2)
        if self.labels is not None:
            labs = self.labels[in_view][complete]
            colors = np.repeat(self.color_func(labs), self.n_chans, 0)
        else:
            colors = 'r'
        self.spike_collection = LineCollection(sp_segs,
                                               offsets=None,
                                               color=colors,
                                               transform=self.axes.transData)
        self.axes.add_collection(self.spike_collection)

    def OnScrollEvt(self, pos):
        """Scrollbar handler: show the window starting ``pos`` samples after
        ``i_min`` and select the last spike before the window centre."""
        self.i_start = self.i_min + pos
        self.i_end = self.i_min + self.i_window + pos
        if self.spt is not None:
            t_center = (self.i_start + self.i_window / 2.) * 1000. / self.FS
            idx, = np.where(self.spt < t_center)
            self.i_spike = idx[-1] if len(idx) > 0 else 0
        else:
            # no spike times: comparing None used to raise TypeError
            self.i_spike = 0
        self.draw_plot()
Beispiel #19
0
 def set_segments(self, segments):
     """Set 3D segments.

     The 3D data is stored in ``_segments3d`` and the inherited 2D segment
     list is emptied (NOTE(review): presumably refilled from the 3D data by
     a projection step elsewhere).
     """
     self._segments3d = segments
     no_2d_segments = []
     LineCollection.set_segments(self, no_2d_segments)
Beispiel #20
0
class SkeletonBuilder:
    """Interactive editor for the skeleton (bones between bodyparts) of a project.

    On construction, one labeled frame is loaded from the project's
    ``labeled-data`` folders and shown together with its keypoints.
    Lasso-selecting keypoints connects consecutive ones with bones,
    right-clicking a bone deletes it, "Clear" removes all bones and "Export"
    writes the skeleton (pairs of bodypart names) back to the project
    configuration.
    """

    def __init__(self, config_path):
        """Load a labeled frame and open the skeleton editor window.

        Parameters
        ----------
        config_path : str
            Path to the project configuration file. It is read with
            ``read_config`` here and written back by :meth:`export`.

        Raises
        ------
        IOError
            If ``labeled-data`` contains no folder other than ones ending
            in "cropped" or "labeled".
        """
        self.config_path = config_path
        self.cfg = read_config(config_path)
        # Find uncropped labeled data
        self.df = None
        found = False
        root = os.path.join(self.cfg["project_path"], "labeled-data")
        for dir_ in os.listdir(root):
            folder = os.path.join(root, dir_)
            # Skip derived folders (cropped copies, rendered labeled frames).
            if os.path.isdir(folder) and not folder.endswith(
                    ("cropped", "labeled")):
                self.df = pd.read_hdf(
                    os.path.join(folder,
                                 f'CollectedData_{self.cfg["scorer"]}.h5'))
                row, col = self.pick_labeled_frame()
                if "individuals" in self.df.columns.names:
                    self.df = self.df.xs(col, axis=1, level="individuals")
                # One (x, y) row per bodypart.
                self.xy = self.df.loc[row].values.reshape((-1, 2))
                missing = np.flatnonzero(np.isnan(self.xy).all(axis=1))
                if not missing.size:
                    # No bodypart is entirely unlabeled: use this animal.
                    found = True
                    break
        if self.df is None:
            raise IOError("No labeled data were found.")

        self.bpts = self.df.columns.get_level_values("bodyparts").unique()
        if not found:
            warnings.warn(
                f"A fully labeled animal could not be found. "
                f"{', '.join(self.bpts[missing])} will need to be manually connected in the config.yaml."
            )
        # Nearest-neighbour lookup from image coordinates to bodypart index.
        self.tree = KDTree(self.xy)
        # Handle image previously annotated on a different platform
        sep = "/" if "/" in row else "\\"
        if sep != os.path.sep:
            row = row.replace(sep, os.path.sep)
        self.image = io.imread(os.path.join(self.cfg["project_path"], row))
        # inds: sorted bodypart-index pairs; segs: the matching coordinate
        # pairs, in the same (sorted) endpoint order.
        self.inds = set()
        self.segs = set()
        # Draw the skeleton if already existent
        if self.cfg["skeleton"]:
            for bone in self.cfg["skeleton"]:
                pair = np.flatnonzero(self.bpts.isin(bone))
                if len(pair) != 2:
                    continue
                pair_sorted = tuple(sorted(pair))
                self.inds.add(pair_sorted)
                self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :])))
        self.lines = LineCollection(self.segs,
                                    colors=mcolors.to_rgba(
                                        self.cfg["skeleton_color"]))
        self.lines.set_picker(True)
        self.show()

    def pick_labeled_frame(self):
        """Return ``(row, col)`` of a random frame/individual with the most labels.

        For multi-animal data *col* is an individual name ("single" is
        excluded). For single-animal data (no "individuals" level) *col* is 0,
        the name of the single count column.
        """
        # Find the most 'complete' animal
        try:
            count = self.df.groupby(level="individuals", axis=1).count()
            if "single" in count:
                count.drop("single", axis=1, inplace=True)
        except KeyError:
            count = self.df.count(axis=1).to_frame()
        mask = count.where(count == count.values.max())
        kept = mask.stack().index.to_list()
        np.random.shuffle(kept)
        row, col = kept.pop()
        return row, col

    def show(self):
        """Create the figure, widgets and event handlers, then block in ``plt.show``."""
        self.fig = plt.figure()
        ax = self.fig.add_subplot(111)
        ax.axis("off")
        # Zoom on the keypoints, with a 30% margin around their bounding box.
        lo = np.nanmin(self.xy, axis=0)
        hi = np.nanmax(self.xy, axis=0)
        center = (hi + lo) / 2
        w, h = hi - lo
        ampl = 1.3
        w *= ampl
        h *= ampl
        ax.set_xlim(center[0] - w / 2, center[0] + w / 2)
        ax.set_ylim(center[1] - h / 2, center[1] + h / 2)
        ax.imshow(self.image)
        ax.scatter(*self.xy.T, s=self.cfg["dotsize"]**2)
        ax.add_collection(self.lines)
        ax.invert_yaxis()

        self.lasso = LassoSelector(ax, onselect=self.on_select)
        ax_clear = self.fig.add_axes([0.85, 0.55, 0.1, 0.1])
        ax_export = self.fig.add_axes([0.85, 0.45, 0.1, 0.1])
        self.clear_button = Button(ax_clear, "Clear")
        self.clear_button.on_clicked(self.clear)
        self.export_button = Button(ax_export, "Export")
        self.export_button.on_clicked(self.export)
        self.fig.canvas.mpl_connect("pick_event", self.on_pick)
        plt.show()

    def clear(self, *args):
        """Remove every bone and refresh the canvas ("Clear" button callback)."""
        self.inds.clear()
        self.segs.clear()
        self.lines.set_segments(self.segs)
        # Request a redraw; a button click does not trigger one by itself.
        self.fig.canvas.draw_idle()

    def export(self, *args):
        """Write the current bones to ``cfg["skeleton"]`` and save the config.

        Warns when some bodyparts are not part of any bone.
        """
        inds_flat = set(ind for pair in self.inds for ind in pair)
        unconnected = [i for i in range(len(self.xy)) if i not in inds_flat]
        if len(unconnected):
            warnings.warn(
                f'Unconnected {", ".join(self.bpts[unconnected])}. '
                f"It is desirable that all bodyparts be connected for multi-animal projects."
            )
        self.cfg["skeleton"] = [
            tuple(self.bpts[list(pair)]) for pair in self.inds
        ]
        write_config(self.config_path, self.cfg)

    def on_pick(self, event):
        """Delete the bone that was right-clicked and refresh the drawing."""
        if event.mouseevent.button != 3:
            return
        removed = event.artist.get_segments().pop(event.ind[0])
        self.segs.remove(tuple(map(tuple, removed)))
        # The endpoints are exact keypoint coordinates stored in sorted index
        # order, so a nearest-neighbour query recovers the sorted index pair.
        self.inds.remove(tuple(self.tree.query(removed)[1]))
        # Push the change to the artist; otherwise the deleted bone stays
        # visible until the next edit.
        self.lines.set_segments(self.segs)
        self.fig.canvas.draw_idle()

    def on_select(self, verts):
        """Connect, in stroke order, the keypoints traversed by a lasso stroke.

        Each lasso vertex picks the first keypoint found within a radius of 5
        (data units) of it. Consecutive distinct keypoints are joined by a bone.
        """
        self.path = Path(verts)
        self.verts = verts
        inds = self.tree.query_ball_point(verts, 5)
        inds_unique = []
        for lst in inds:
            if len(lst) and lst[0] not in inds_unique:
                inds_unique.append(lst[0])
        for pair in zip(inds_unique, inds_unique[1:]):
            pair_sorted = tuple(sorted(pair))
            self.inds.add(pair_sorted)
            self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :])))
        self.lines.set_segments(self.segs)
        self.fig.canvas.draw_idle()
def draw_prediction_on_image(image,
                             keypoints_with_scores,
                             crop_region=None,
                             close_figure=False,
                             output_image_height=None):
    """Draws the keypoint predictions on image.

    Args:
        image: A numpy array with shape [height, width, channel] representing
            the pixel values of the input image.
        keypoints_with_scores: A numpy array with shape [1, 1, 17, 3]
            representing the keypoint coordinates and scores returned from
            the MoveNet model.
        crop_region: A dictionary that defines the coordinates of the bounding
            box of the crop region in normalized coordinates (keys 'x_min',
            'y_min', 'x_max', 'y_max'). If provided, this function will also
            draw the bounding box on the image.
        close_figure: Unused. The figure is always closed before returning;
            the parameter is kept for backward compatibility.
        output_image_height: An integer indicating the height of the output
            image. Note that the image aspect ratio will be the same as the
            input image.

    Returns:
        A numpy array with shape [out_height, out_width, 3] representing the
        image overlaid with keypoint predictions.
    """
    height, width, _ = image.shape
    aspect_ratio = float(width) / height
    fig, ax = plt.subplots(figsize=(12 * aspect_ratio, 12))
    # To remove the huge white borders
    fig.tight_layout(pad=0)
    ax.margins(0)
    # Turn off tick labels and the axis frame.
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.axis('off')

    ax.imshow(image)
    line_segments = LineCollection([], linewidths=(4), linestyle='solid')
    ax.add_collection(line_segments)
    scat = ax.scatter([], [], s=60, color='#FF1493', zorder=3)

    (keypoint_locs, keypoint_edges,
     edge_colors) = _keypoints_and_edges_for_display(keypoints_with_scores,
                                                     height, width)

    # Only update the artists when there is something to draw.
    if keypoint_edges.shape[0]:
        line_segments.set_segments(keypoint_edges)
        line_segments.set_color(edge_colors)
    if keypoint_locs.shape[0]:
        scat.set_offsets(keypoint_locs)

    if crop_region is not None:
        xmin = max(crop_region['x_min'] * width, 0.0)
        ymin = max(crop_region['y_min'] * height, 0.0)
        rec_width = min(crop_region['x_max'], 0.99) * width - xmin
        rec_height = min(crop_region['y_max'], 0.99) * height - ymin
        rect = patches.Rectangle((xmin, ymin),
                                 rec_width,
                                 rec_height,
                                 linewidth=1,
                                 edgecolor='b',
                                 facecolor='none')
        ax.add_patch(rect)

    fig.canvas.draw()
    # ``FigureCanvas.tostring_rgb`` was deprecated in Matplotlib 3.8 and
    # removed in 3.10. Read the RGBA buffer when available and keep the
    # legacy path for old Matplotlib versions.
    rgba = None
    if hasattr(fig.canvas, 'buffer_rgba'):
        rgba = np.asarray(fig.canvas.buffer_rgba())
    if rgba is not None and rgba.ndim == 3:
        # Copy the RGB channels out of the canvas buffer before closing it.
        image_from_plot = np.ascontiguousarray(rgba[..., :3])
    else:
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(),
                                        dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(
            fig.canvas.get_width_height()[::-1] + (3, ))
    plt.close(fig)
    if output_image_height is not None:
        output_image_width = int(output_image_height / height * width)
        image_from_plot = cv2.resize(image_from_plot,
                                     dsize=(output_image_width,
                                            output_image_height),
                                     interpolation=cv2.INTER_CUBIC)
    return image_from_plot
def make_labeled_images_from_dataframe(
    df,
    cfg,
    destfolder="",
    scale=1.0,
    dpi=100,
    keypoint="+",
    draw_skeleton=True,
    color_by="bodypart",
):
    """
    Write labeled frames to disk from a DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the labeled data. Typically, the DataFrame is obtained
        through pandas.read_csv() or pandas.read_hdf().
    cfg : dict
        Project configuration.
    destfolder : string, optional
        Destination folder into which images will be stored. By default, same location as the labeled data.
        Note that the folder will be created if it does not exist.
    scale : float, optional
        Up/downscale the output dimensions.
        By default, outputs are of the same dimensions as the original images.
    dpi : int, optional
        Output resolution. 100 dpi by default.
    keypoint : str, optional
        Keypoint appearance. By default, keypoints are marked by a + sign.
        Refer to https://matplotlib.org/3.2.1/api/markers_api.html for a list of all possible options.
    draw_skeleton : bool, optional
        Whether to draw the animal skeleton as defined in *cfg*. True by default.
    color_by : str, optional
        Color scheme of the keypoints. Must be either 'bodypart' or 'individual'.
        By default, keypoints are colored relative to the bodypart they represent.

    Raises
    ------
    ValueError
        If *color_by* is neither 'bodypart' nor 'individual'.
    Exception
        If *color_by* is 'individual' but *df* has no "individuals" column
        level (single-animal data).
    """

    bodyparts = df.columns.get_level_values("bodyparts")
    bodypart_names = bodyparts.unique()
    nbodyparts = len(bodypart_names)
    # Columns come in (x, y) pairs (see the reshape below): one label per keypoint.
    bodyparts = bodyparts[::2]
    draw_skeleton = (draw_skeleton
                     and cfg["skeleton"])  # Only draw if a skeleton is defined

    if color_by == "bodypart":
        map_ = bodyparts.map(dict(zip(bodypart_names, range(nbodyparts))))
        cmap = get_cmap(nbodyparts, cfg["colormap"])
        colors = cmap(map_)
    elif color_by == "individual":
        try:
            individuals = df.columns.get_level_values("individuals")
            individual_names = individuals.unique().to_list()
            nindividuals = len(individual_names)
            individuals = individuals[::2]
            map_ = individuals.map(
                dict(zip(individual_names, range(nindividuals))))
            cmap = get_cmap(nindividuals, cfg["colormap"])
            colors = cmap(map_)
        except KeyError as e:
            raise Exception(
                "Coloring by individuals is only valid for multi-animal data"
            ) from e
    else:
        raise ValueError(
            "`color_by` must be either `bodypart` or `individual`.")

    # Pair the k-th occurrence of bp1 with the k-th occurrence of bp2, i.e.
    # (presumably) the same bodyparts of the same individual.
    bones = []
    if draw_skeleton:
        for bp1, bp2 in cfg["skeleton"]:
            match1, match2 = [], []
            for j, bp in enumerate(bodyparts):
                if bp == bp1:
                    match1.append(j)
                elif bp == bp2:
                    match2.append(j)
            bones.extend(zip(match1, match2))
    ind_bones = tuple(zip(*bones))

    images_list = [
        os.path.join(cfg["project_path"], *tuple_)
        for tuple_ in df.index.tolist()
    ]
    if not destfolder:
        destfolder = os.path.dirname(images_list[0])
    tmpfolder = destfolder + "_labeled"
    attempttomakefolder(tmpfolder)
    ic = io.imread_collection(images_list)

    # The image collection's file order may differ from df's row order, so
    # map each path back to its first row (O(1) instead of list.index).
    row_of = {}
    for row, path in enumerate(images_list):
        row_of.setdefault(path, row)

    def _row_index(path):
        """Return the df row of *path*; unknown paths raise ValueError as before."""
        try:
            return row_of[path]
        except KeyError:
            return images_list.index(path)

    h, w = ic[0].shape[:2]
    all_same_shape = True
    for array in ic[1:]:
        if array.shape[:2] != (h, w):
            all_same_shape = False
            break

    xy = df.values.reshape((df.shape[0], -1, 2))
    # With bones present: (n_frames, n_bones, 2 endpoints, 2 coords).
    segs = xy[:, ind_bones].swapaxes(1, 2)

    s = cfg["dotsize"]
    alpha = cfg["alphavalue"]
    if all_same_shape:  # Very efficient, avoid re-drawing the whole plot
        fig, ax = prepare_figure_axes(w, h, scale, dpi)
        # Loop-invariant: the single figure keeps the same layout throughout.
        fig.subplots_adjust(left=0,
                            bottom=0,
                            right=1,
                            top=1,
                            wspace=0,
                            hspace=0)
        im = ax.imshow(np.zeros((h, w)), "gray")
        pts = [
            ax.plot([], [], keypoint, ms=s, alpha=alpha, color=c)[0]
            for c in colors
        ]
        coll = LineCollection([], colors=cfg["skeleton_color"], alpha=alpha)
        ax.add_collection(coll)
        for i in trange(len(ic)):
            filename = ic.files[i]
            ind = _row_index(filename)
            coords = xy[ind]
            img = ic[i]
            if img.ndim == 2 or img.shape[-1] == 1:
                img = color.gray2rgb(img)
            im.set_data(img)
            for pt, coord in zip(pts, coords):
                # Line2D.set_data needs sequences; bare scalars are deprecated
                # since Matplotlib 3.7 and rejected by later versions.
                pt.set_data([coord[0]], [coord[1]])
            if ind_bones:
                coll.set_segments(segs[ind])
            imagename = os.path.basename(filename)
            fig.savefig(
                os.path.join(tmpfolder,
                             imagename.replace(".png", f"_{color_by}.png")),
                dpi=dpi,
            )
        plt.close(fig)

    else:  # Good old inelegant way
        for i in trange(len(ic)):
            filename = ic.files[i]
            ind = _row_index(filename)
            coords = xy[ind]
            image = ic[i]
            h, w = image.shape[:2]
            fig, ax = prepare_figure_axes(w, h, scale, dpi)
            ax.imshow(image)
            for coord, c in zip(coords, colors):
                ax.plot(*coord, keypoint, ms=s, alpha=alpha, color=c)
            if ind_bones:
                coll = LineCollection(segs[ind],
                                      colors=cfg["skeleton_color"],
                                      alpha=alpha)
                ax.add_collection(coll)
            imagename = os.path.basename(filename)
            fig.subplots_adjust(left=0,
                                bottom=0,
                                right=1,
                                top=1,
                                wspace=0,
                                hspace=0)
            fig.savefig(
                os.path.join(tmpfolder,
                             imagename.replace(".png", f"_{color_by}.png")),
                dpi=dpi,
            )
            plt.close(fig)
Beispiel #23
0
 def set_segments(self, segments):
     """
     Set 3D segments.

     The segments are stored on ``self._segments3d`` and the 2D segments of
     the ``LineCollection`` base are cleared.

     Segments that all have the same number of points are stored as one
     ndarray, as before. Ragged input (segments of different lengths) cannot
     form a regular array; NumPy >= 1.24 raises ``ValueError`` for it. In
     that case it is stored as a list with one array per segment.
     """
     try:
         self._segments3d = np.asanyarray(segments)
     except ValueError:
         # Ragged segments: keep one array per segment.
         self._segments3d = [np.asanyarray(seg) for seg in segments]
     LineCollection.set_segments(self, [])
Beispiel #24
0
 def set_segments(self, segments):
     """
     Set 3D segments.

     The segments are stored on ``self._segments3d`` and the 2D segments of
     the ``LineCollection`` base are cleared.

     Segments that all have the same number of points are stored as one
     ndarray, as before. Ragged input (segments of different lengths) cannot
     form a regular array; NumPy >= 1.24 raises ``ValueError`` for it. In
     that case it is stored as a list with one array per segment.
     """
     try:
         self._segments3d = np.asanyarray(segments)
     except ValueError:
         # Ragged segments: keep one array per segment.
         self._segments3d = [np.asanyarray(seg) for seg in segments]
     LineCollection.set_segments(self, [])
Beispiel #25
0
 def set_segments(self, segments):
     """Keep *segments* as the 3D data and reset the 2D segment list.

     The argument is stored unchanged on ``self._segments3d``.
     """
     self._segments3d = segments
     # The inherited LineCollection segments are emptied; presumably a
     # projection step elsewhere in the class refills them (not visible here).
     LineCollection.set_segments(self, [])
Beispiel #26
0
class ScatterLayerArtist(MatplotlibLayerArtist):
    """Matplotlib layer artist drawing one data layer of a scatter viewer.

    Depending on the layer state, the layer is rendered with plain markers
    (``plot_artist``), color/size-mapped markers (``scatter_artist``), a
    density map (``density_artist``), a connecting line
    (``line_collection``), vectors (``vector_artist``) and/or error bars
    (``errorbar_artist``).
    """

    _layer_state_cls = ScatterLayerState

    def __init__(self, axes, viewer_state, layer_state=None, layer=None):

        super(ScatterLayerArtist, self).__init__(axes, viewer_state,
                                                 layer_state=layer_state, layer=layer)

        # Watch for changes in the viewer state which would require the
        # layers to be redrawn
        self._viewer_state.add_global_callback(self._update_scatter)
        self.state.add_global_callback(self._update_scatter)

        # Scatter
        self.scatter_artist = self.axes.scatter([], [])
        self.plot_artist = self.axes.plot([], [], 'o', mec='none')[0]
        self.errorbar_artist = self.axes.errorbar([], [], fmt='none')
        self.vector_artist = None
        self.line_collection = LineCollection(np.zeros((0, 2, 2)))
        self.axes.add_collection(self.line_collection)

        # Scatter density
        self.density_auto_limits = DensityMapLimits()
        self.density_artist = ScatterDensityArtist(self.axes, [], [], color='white',
                                                   vmin=self.density_auto_limits.min,
                                                   vmax=self.density_auto_limits.max)
        self.axes.add_artist(self.density_artist)

        self.mpl_artists = [self.scatter_artist, self.plot_artist,
                            self.errorbar_artist, self.vector_artist,
                            self.line_collection, self.density_artist]
        # Slots in mpl_artists of the artists that are re-created on update.
        self.errorbar_index = 2
        self.vector_index = 3

        self.reset_cache()

    def reset_cache(self):
        """Forget the last seen viewer/layer states.

        After this, every state property whose value is not None is reported
        as changed by the next ``_update_scatter`` call.
        """
        self._last_viewer_state = {}
        self._last_layer_state = {}

    @defer_draw
    def _update_data(self, changed):
        """Push the layer's x/y data into the artists.

        Vector and error-bar artists are removed and re-created here. The
        ``changed`` argument is not used by this method.
        """

        # Layer artist has been cleared already
        if len(self.mpl_artists) == 0:
            return

        try:
            x = self.layer[self._viewer_state.x_att].ravel()
        except (IncompatibleAttribute, IndexError):
            # The following includes a call to self.clear()
            self.disable_invalid_attributes(self._viewer_state.x_att)
            return
        else:
            self.enable()

        try:
            y = self.layer[self._viewer_state.y_att].ravel()
        except (IncompatibleAttribute, IndexError):
            # The following includes a call to self.clear()
            self.disable_invalid_attributes(self._viewer_state.y_att)
            return
        else:
            self.enable()

        if self.state.markers_visible:
            if self.state.density_map:
                self.density_artist.set_xy(x, y)
                self.plot_artist.set_data([], [])
                self.scatter_artist.set_offsets(np.zeros((0, 2)))
            else:
                if self.state.cmap_mode == 'Fixed' and self.state.size_mode == 'Fixed':
                    # In this case we use Matplotlib's plot function because it has much
                    # better performance than scatter.
                    self.plot_artist.set_data(x, y)
                    self.scatter_artist.set_offsets(np.zeros((0, 2)))
                    self.density_artist.set_xy([], [])
                else:
                    self.plot_artist.set_data([], [])
                    offsets = np.vstack((x, y)).transpose()
                    self.scatter_artist.set_offsets(offsets)
                    self.density_artist.set_xy([], [])
        else:
            self.plot_artist.set_data([], [])
            self.scatter_artist.set_offsets(np.zeros((0, 2)))
            self.density_artist.set_xy([], [])

        if self.state.line_visible:
            if self.state.cmap_mode == 'Fixed':
                points = np.array([x, y]).transpose()
                self.line_collection.set_segments([points])
            elif len(x) == 0:
                # Empty layer: nothing to oversample (np.zeros(-1) below
                # would raise ValueError).
                self.line_collection.set_segments(np.zeros((0, 2, 2)))
            else:
                # In the case where we want to color the line, we need to over
                # sample the line by a factor of two so that we can assign the
                # correct colors to segments - if we didn't do this, then
                # segments on one side of a point would be a different color
                # from the other side. With oversampling, we can have half a
                # segment on either side of a point be the same color as a
                # point
                x_fine = np.zeros(len(x) * 2 - 1, dtype=float)
                y_fine = np.zeros(len(y) * 2 - 1, dtype=float)
                x_fine[::2] = x
                x_fine[1::2] = 0.5 * (x[1:] + x[:-1])
                y_fine[::2] = y
                y_fine[1::2] = 0.5 * (y[1:] + y[:-1])
                points = np.array([x_fine, y_fine]).transpose().reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                self.line_collection.set_segments(segments)
        else:
            self.line_collection.set_segments(np.zeros((0, 2, 2)))

        for eartist in list(self.errorbar_artist[2]):
            if eartist is not None:
                try:
                    eartist.remove()
                except ValueError:
                    pass
                except AttributeError:  # Matplotlib < 1.5
                    pass

        if self.vector_artist is not None:
            self.vector_artist.remove()
            self.vector_artist = None
            # Keep mpl_artists in sync so it never holds a removed quiver.
            self.mpl_artists[self.vector_index] = None

        # Vectors need both components; without them there is nothing to
        # draw (np.hypot(None, None) would raise TypeError).
        if (self.state.vector_visible and
                self.state.vx_att is not None and
                self.state.vy_att is not None):

            vx = self.layer[self.state.vx_att].ravel()
            vy = self.layer[self.state.vy_att].ravel()

            if self.state.vector_mode == 'Polar':
                ang = vx
                length = vy
                # assume ang is anti clockwise from the x axis
                vx = length * np.cos(np.radians(ang))
                vy = length * np.sin(np.radians(ang))

            if self.state.vector_arrowhead:
                hw = 3
                hl = 5
            else:
                hw = 1
                hl = 0

            v = np.hypot(vx, vy)
            # Normalise by the longest vector; np.nanmax raises on empty input.
            if v.size > 0:
                vmax = np.nanmax(v)
                vx = vx / vmax
                vy = vy / vmax

            self.vector_artist = self.axes.quiver(x, y, vx, vy, units='width',
                                                  pivot=self.state.vector_origin,
                                                  headwidth=hw, headlength=hl,
                                                  scale_units='width',
                                                  scale=10 / self.state.vector_scaling)
            self.mpl_artists[self.vector_index] = self.vector_artist

        if self.state.xerr_visible or self.state.yerr_visible:

            if self.state.xerr_visible and self.state.xerr_att is not None:
                xerr = self.layer[self.state.xerr_att].ravel()
            else:
                xerr = None

            if self.state.yerr_visible and self.state.yerr_att is not None:
                yerr = self.layer[self.state.yerr_att].ravel()
            else:
                yerr = None

            self.errorbar_artist = self.axes.errorbar(x, y, fmt='none',
                                                      xerr=xerr, yerr=yerr)
            self.mpl_artists[self.errorbar_index] = self.errorbar_artist

    @defer_draw
    def _update_visual_attributes(self, changed, force=False):
        """Apply color, size, alpha, zorder, visibility and line style.

        Only properties named in ``changed`` are updated unless ``force`` is
        True. Does nothing while the artist is disabled.
        """

        if not self.enabled:
            return

        if self.state.markers_visible:

            if self.state.density_map:

                if self.state.cmap_mode == 'Fixed':
                    if force or 'color' in changed or 'cmap_mode' in changed:
                        self.density_artist.set_color(self.state.color)
                        self.density_artist.set_c(None)
                        self.density_artist.set_clim(self.density_auto_limits.min,
                                                     self.density_auto_limits.max)
                elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                    c = self.layer[self.state.cmap_att].ravel()
                    set_mpl_artist_cmap(self.density_artist, c, self.state)

                if force or 'stretch' in changed:
                    self.density_artist.set_norm(ImageNormalize(stretch=STRETCHES[self.state.stretch]()))

                if force or 'dpi' in changed:
                    self.density_artist.set_dpi(self._viewer_state.dpi)

                if force or 'density_contrast' in changed:
                    self.density_auto_limits.contrast = self.state.density_contrast
                    self.density_artist.stale = True

            else:

                if self.state.cmap_mode == 'Fixed' and self.state.size_mode == 'Fixed':

                    if force or 'color' in changed:
                        self.plot_artist.set_color(self.state.color)

                    if force or 'size' in changed or 'size_scaling' in changed:
                        self.plot_artist.set_markersize(self.state.size *
                                                        self.state.size_scaling)

                else:

                    # TEMPORARY: Matplotlib has a bug that causes set_alpha to
                    # change the colors back: https://github.com/matplotlib/matplotlib/issues/8953
                    if 'alpha' in changed:
                        force = True

                    if self.state.cmap_mode == 'Fixed':
                        if force or 'color' in changed or 'cmap_mode' in changed:
                            self.scatter_artist.set_facecolors(self.state.color)
                            self.scatter_artist.set_edgecolor('none')
                    elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                        c = self.layer[self.state.cmap_att].ravel()
                        set_mpl_artist_cmap(self.scatter_artist, c, self.state)
                        self.scatter_artist.set_edgecolor('none')

                    if force or any(prop in changed for prop in MARKER_PROPERTIES):

                        if self.state.size_mode == 'Fixed':
                            s = self.state.size * self.state.size_scaling
                            s = broadcast_to(s, self.scatter_artist.get_sizes().shape)
                        else:
                            s = self.layer[self.state.size_att].ravel()
                            s = ((s - self.state.size_vmin) /
                                 (self.state.size_vmax - self.state.size_vmin)) * 30
                            s *= self.state.size_scaling

                        # Note, we need to square here because for scatter, s is actually
                        # proportional to the marker area, not radius.
                        self.scatter_artist.set_sizes(s ** 2)

        if self.state.line_visible:

            if self.state.cmap_mode == 'Fixed':
                if force or 'color' in changed or 'cmap_mode' in changed:
                    self.line_collection.set_array(None)
                    self.line_collection.set_color(self.state.color)
            elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                # Higher up we oversampled the points in the line so that
                # half a segment on either side of each point has the right
                # color, so we need to also oversample the color here.
                c = self.layer[self.state.cmap_att].ravel()
                # max(..., 0): an empty layer has no segments (and a negative
                # size would make np.zeros raise).
                cnew = np.zeros(max(len(c) - 1, 0) * 2)
                cnew[::2] = c[:-1]
                cnew[1::2] = c[1:]
                set_mpl_artist_cmap(self.line_collection, cnew, self.state)

            if force or 'linewidth' in changed:
                self.line_collection.set_linewidth(self.state.linewidth)

            if force or 'linestyle' in changed:
                self.line_collection.set_linestyle(self.state.linestyle)

        if self.state.vector_visible and self.vector_artist is not None:

            if self.state.cmap_mode == 'Fixed':
                if force or 'color' in changed or 'cmap_mode' in changed:
                    self.vector_artist.set_array(None)
                    self.vector_artist.set_color(self.state.color)
            elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                c = self.layer[self.state.cmap_att].ravel()
                set_mpl_artist_cmap(self.vector_artist, c, self.state)

        if self.state.xerr_visible or self.state.yerr_visible:

            for eartist in list(self.errorbar_artist[2]):

                if eartist is None:
                    continue

                if self.state.cmap_mode == 'Fixed':
                    if force or 'color' in changed or 'cmap_mode' in changed:
                        eartist.set_color(self.state.color)
                elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                    c = self.layer[self.state.cmap_att].ravel()
                    set_mpl_artist_cmap(eartist, c, self.state)

                if force or 'alpha' in changed:
                    eartist.set_alpha(self.state.alpha)

                if force or 'visible' in changed:
                    eartist.set_visible(self.state.visible)

                if force or 'zorder' in changed:
                    eartist.set_zorder(self.state.zorder)

        for artist in [self.scatter_artist, self.plot_artist,
                       self.vector_artist, self.line_collection,
                       self.density_artist]:

            if artist is None:
                continue

            if force or 'alpha' in changed:
                artist.set_alpha(self.state.alpha)

            if force or 'zorder' in changed:
                artist.set_zorder(self.state.zorder)

            if force or 'visible' in changed:
                artist.set_visible(self.state.visible)

        self.redraw()

    @defer_draw
    def _update_scatter(self, force=False, **kwargs):
        """State-change callback: refresh data and/or visuals as needed.

        Does nothing until the x/y attributes and the layer are all set.
        """

        if (self._viewer_state.x_att is None or
            self._viewer_state.y_att is None or
                self.state.layer is None):
            return

        # Figure out which attributes are different from before. Ideally we shouldn't
        # need this but currently this method is called multiple times if an
        # attribute is changed due to x_att changing then hist_x_min, hist_x_max, etc.
        # If we can solve this so that _update_scatter is really only called once
        # then we could consider simplifying this. Until then, we manually keep track
        # of which properties have changed.

        changed = set()

        if not force:

            for key, value in self._viewer_state.as_dict().items():
                if value != self._last_viewer_state.get(key, None):
                    changed.add(key)

            for key, value in self.state.as_dict().items():
                if value != self._last_layer_state.get(key, None):
                    changed.add(key)

        self._last_viewer_state.update(self._viewer_state.as_dict())
        self._last_layer_state.update(self.state.as_dict())

        if force or len(changed & DATA_PROPERTIES) > 0:
            self._update_data(changed)
            # Freshly created artists need every visual property applied.
            force = True

        if force or len(changed & VISUAL_PROPERTIES) > 0:
            self._update_visual_attributes(changed, force=force)

    def get_layer_color(self):
        """Return the fixed color, or the colormap when colors are mapped."""
        if self.state.cmap_mode == 'Fixed':
            return self.state.color
        else:
            return self.state.cmap

    @defer_draw
    def update(self):
        """Force a full refresh of data and visuals, then redraw."""
        self._update_scatter(force=True)
        self.redraw()
Beispiel #27
0
def make_labeled_images_from_dataframe(
    df,
    cfg,
    destfolder="",
    scale=1.0,
    dpi=100,
    keypoint="+",
    draw_skeleton=True,
    color_by="bodypart",
):
    """
    Write labeled frames to disk from a DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the labeled data. Typically, the DataFrame is obtained
        through pandas.read_csv() or pandas.read_hdf(). Its columns need a
        "bodyparts" level (plus an "individuals" level when
        ``color_by="individual"``) and must alternate x, y coordinates per
        keypoint. Its index holds image paths relative to ``cfg["project_path"]``.

    cfg : dict
        Project configuration. The "colormap", "skeleton", "project_path",
        "dotsize", "alphavalue" and "skeleton_color" entries are read.

    destfolder : string, optional
        Base name of the output folder. Images are written to
        ``destfolder + "_labeled"``, which is created if it does not exist.
        By default, the folder containing the first labeled image is used as
        the base name.

    scale : float, optional
        Up/downscale the output dimensions.
        By default, outputs are of the same dimensions as the original images.

    dpi : int, optional
        Output resolution. 100 dpi by default.

    keypoint : str, optional
        Keypoint appearance. By default, keypoints are marked by a + sign.
        Refer to https://matplotlib.org/3.2.1/api/markers_api.html for a list of all possible options.

    draw_skeleton : bool, optional
        Whether to draw the animal skeleton as defined in *cfg*. True by default.

    color_by : str, optional
        Color scheme of the keypoints. Must be either 'bodypart' or 'individual'.
        By default, keypoints are colored relative to the bodypart they represent.

    Raises
    ------
    ValueError
        If ``color_by`` is neither 'bodypart' nor 'individual', or if coloring
        by individual is requested for data without an "individuals" level.

    Notes
    -----
    Output file names are the input names with ".png" replaced by
    "_{color_by}.png"; names that do not contain ".png" are kept as-is.
    """

    bodyparts = df.columns.get_level_values("bodyparts")
    bodypart_names = bodyparts.unique()
    nbodyparts = len(bodypart_names)
    # Columns come in (x, y) pairs: keep one label per keypoint.
    bodyparts = bodyparts[::2]

    if color_by == "bodypart":
        map_ = bodyparts.map(dict(zip(bodypart_names, range(nbodyparts))))
        cmap = get_cmap(nbodyparts, cfg["colormap"])
        colors = cmap(map_)
    elif color_by == "individual":
        # Only the level lookup can legitimately fail for single-animal data;
        # keep other errors (e.g. a missing cfg key) unmasked.
        try:
            individuals = df.columns.get_level_values("individuals")
        except KeyError as e:
            raise ValueError(
                "Coloring by individuals is only valid for multi-animal data"
            ) from e
        individual_names = individuals.unique().to_list()
        nindividuals = len(individual_names)
        individuals = individuals[::2]
        map_ = individuals.map(
            dict(zip(individual_names, range(nindividuals))))
        cmap = get_cmap(nindividuals, cfg["colormap"])
        colors = cmap(map_)
    else:
        raise ValueError(
            "`color_by` must be either `bodypart` or `individual`.")

    # Pair the k-th occurrence of bp1 with the k-th occurrence of bp2 (one
    # bone per individual in multi-animal data).
    bones = []
    if draw_skeleton:
        for bp1, bp2 in cfg["skeleton"]:
            match1, match2 = [], []
            for j, bp in enumerate(bodyparts):
                if bp == bp1:
                    match1.append(j)
                elif bp == bp2:
                    match2.append(j)
            bones.extend(zip(match1, match2))
    ind_bones = tuple(zip(*bones))

    # The index may use either separator; normalise to the OS separator.
    sep = "/" if "/" in df.index[0] else "\\"
    images = cfg["project_path"] + sep + df.index
    if sep != os.path.sep:
        images = images.str.replace(sep, os.path.sep)
    if not destfolder:
        destfolder = os.path.dirname(images[0])
    tmpfolder = destfolder + "_labeled"
    attempttomakefolder(tmpfolder)
    ic = io.imread_collection(images.to_list())

    h, w = ic[0].shape[:2]
    fig, ax = prepare_figure_axes(w, h, scale, dpi)
    try:
        im = ax.imshow(np.zeros((h, w)), "gray")
        scat = ax.scatter([], [],
                          s=cfg["dotsize"],
                          alpha=cfg["alphavalue"],
                          marker=keypoint)
        scat.set_color(colors)
        # (n_frames, n_keypoints, 2) array of x, y coordinates.
        xy = df.values.reshape((df.shape[0], -1, 2))
        # Build per-frame bone segments only when there are bones, instead of
        # indexing xy with an empty tuple.
        segs = xy[:, ind_bones].swapaxes(1, 2) if ind_bones else None
        coll = LineCollection([], colors=cfg["skeleton_color"])
        ax.add_collection(coll)
        for i in trange(len(ic)):
            coords = xy[i]
            im.set_array(ic[i])
            scat.set_offsets(coords)
            if segs is not None:
                coll.set_segments(segs[i])
            imagename = os.path.basename(ic.files[i])
            fig.savefig(
                os.path.join(tmpfolder,
                             imagename.replace(".png", f"_{color_by}.png")))
    finally:
        # Release the figure even if reading or saving a frame fails.
        plt.close(fig)
Beispiel #28
0
 def set_segments(self, segments):
     """
     Set 3D segments.

     The segments are stored, converted with ``np.asanyarray``, in
     ``self._segments3d``. The 2D segments of the parent ``LineCollection``
     are reset to an empty list. Presumably the 2D segments are regenerated
     from ``_segments3d`` by a projection step elsewhere in the class (TODO
     confirm).

     NOTE(review): with NumPy >= 1.24, ``np.asanyarray`` raises ValueError
     for ragged input (segments with different numbers of points); confirm
     callers only pass equal-length segments.
     """
     segments_as_array = np.asanyarray(segments)
     self._segments3d = segments_as_array
     LineCollection.set_segments(self, [])
Beispiel #29
0
class SURFDemo(ImageProcessDemo):
    """Interactive SURF keypoint-matching demo.

    The display shows the grayscale image twice, side by side. Dragging the
    polygon over the right half defines a perspective transform. The
    transformed image is placed in the right half, and SURF matches between
    the two halves are drawn as lines.
    """
    TITLE = "SURF Demo"
    DEFAULT_IMAGE = "lena.jpg"
    SETTINGS = ["m_perspective", "hessian_threshold", "n_octaves"]
    # `np.float` was only an alias of the builtin `float` and was removed in
    # NumPy 1.24; the builtin gives the same float64 dtype.
    # Perspective matrix derived from the polygon.
    m_perspective = Array(float, (3, 3))
    # Homography re-estimated from the SURF matches.
    m_perspective2 = Array(float, (3, 3))

    hessian_threshold = Int(2000)
    n_octaves = Int(2)

    poly = Instance(PolygonWidget)

    def control_panel(self):
        # The first two labels read "transformation matrix".
        return VGroup(
            Item("m_perspective", label=u"变换矩阵", editor=ArrayEditor(format_str="%g")),
            Item("m_perspective2", label=u"变换矩阵", editor=ArrayEditor(format_str="%g")),
            Item("hessian_threshold", label=u"hessianThreshold"),
            Item("n_octaves", label=u"nOctaves")
        )

    def __init__(self, **kwargs):
        super(SURFDemo, self).__init__(**kwargs)
        self.poly = None
        self.init_points = None
        # Lines connecting matched keypoints across the two halves.
        self.lines = LineCollection([], linewidths=1, alpha=0.6, color="red")
        self.axe.add_collection(self.lines)
        # Presumably schedules a redraw when these traits change (TODO confirm
        # against ImageProcessDemo).
        self.connect_dirty("poly.changed,hessian_threshold,n_octaves")

    def init_poly(self):
        """Reset the polygon to the corners of the right-hand image copy."""
        if self.poly is None:
            return
        h, w, _ = self.img_color.shape
        self.init_points = np.array([(w, 0), (2*w, 0), (2*w, h), (w, h)], np.float32)
        self.poly.set_points(self.init_points)
        self.poly.update()

    def init_draw(self):
        style = {"marker": "o"}
        self.poly = PolygonWidget(axe=self.axe, points=np.zeros((3, 2)), style=style)
        self.init_poly()

    @on_trait_change("hessian_threshold, n_octaves")
    def calc_surf1(self):
        """Recompute SURF keypoints and descriptors of the original image."""
        self.surf = cv2.SURF(self.hessian_threshold, self.n_octaves)
        self.key_points1, self.features1 = self.surf.detectAndCompute(self.img_gray, None)
        self.key_positions1 = np.array([kp.pt for kp in self.key_points1])

    def _img_changed(self):
        self.img_gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)
        self.img_color = cv2.cvtColor(self.img_gray, cv2.COLOR_GRAY2RGB)
        # Side-by-side canvas: original on the left, warped copy on the right.
        self.img_show = np.concatenate([self.img_color, self.img_color], axis=1)
        self.size = self.img_color.shape[1], self.img_color.shape[0]
        self.calc_surf1()

        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=100)

        self.matcher = cv2.FlannBasedMatcher(index_params, search_params)

        self.init_poly()

    def settings_loaded(self):
        """Move the polygon to match a loaded ``m_perspective``."""
        src = self.init_points.copy()
        w = self.size[0]
        src[:, 0] -= w
        dst = cv2.perspectiveTransform(src[None, :, :], self.m_perspective)
        dst = dst.squeeze()
        dst[:, 0] += w
        self.poly.set_points(dst)
        self.poly.update()

    def draw(self):
        """Warp the image by the polygon's transform and draw SURF matches.

        Up to 50 lowest-distance matches are drawn. Each line is colored by
        indexing ``COLORS`` with the RANSAC mask from ``cv2.findHomography``.
        """
        if self.poly is None:
            return
        w, h = self.size
        src = self.init_points.copy()
        dst = self.poly.points.copy().astype(np.float32)
        # The polygon lives in the right half of the canvas; shift both quads
        # back into single-image coordinates.
        src[:, 0] -= w
        dst[:, 0] -= w
        m = cv2.getPerspectiveTransform(src, dst)
        self.m_perspective = m
        img2 = cv2.warpPerspective(self.img_gray, m, self.size, borderValue=[255]*4)
        self.img_show[:, w:, :] = img2[:, :, None]
        key_points2, features2 = self.surf.detectAndCompute(img2, None)

        key_positions2 = np.array([kp.pt for kp in key_points2])

        match_list = self.matcher.knnMatch(self.features1, features2, k=1)
        index1 = np.array([m[0].queryIdx for m in match_list])
        index2 = np.array([m[0].trainIdx for m in match_list])

        distances = np.array([m[0].distance for m in match_list])

        n = min(50, len(distances))
        best_index = np.argsort(distances)[:n]
        matched_positions1 = self.key_positions1[index1[best_index]]
        matched_positions2 = key_positions2[index2[best_index]]

        homography, mask = None, None
        if n >= 4:  # a homography needs at least four correspondences
            homography, mask = cv2.findHomography(matched_positions1, matched_positions2, cv2.RANSAC)
        # findHomography may give back None when no model can be estimated;
        # keep the previous matrix instead of assigning None to the Array trait.
        if homography is not None:
            self.m_perspective2 = homography
        if mask is None:
            # No inlier information: color every match with COLORS[0].
            mask = np.zeros((n, 1), np.uint8)

        # Each row is (x1, y1, x2, y2); shift the second point into the
        # right half of the canvas.
        lines = np.concatenate([matched_positions1, matched_positions2], axis=1)
        lines[:, 2] += w
        line_colors = COLORS[mask.ravel()]
        self.lines.set_segments(lines.reshape(-1, 2, 2))
        self.lines.set_color(line_colors)
        self.draw_image(self.img_show)
class AnimationPlot:
    """Black three-panel figure animating an agent's whiskers over a texture.

    Panels, top to bottom:

    * ``animation_axis``: the texture profile and the whisker lines;
    * ``whisker_deflection_axis``: one vertical bar per whisker, colored by
      its deflection relative to its length;
    * ``sensory_cell_activity_axis``: one vertical bar per sensory cell,
      colored by its activity relative to the most active cell.

    ``create_animation_plot`` must be called before ``create_texture`` and
    ``create_whiskers``, because those draw on ``animation_axis``. All
    ``create_*`` methods must be called before ``animate``.
    """

    def __init__(self, x_min, y_min, x_max, y_max):
        # Figure width in inches is the x/y extent ratio of the animated
        # region; height is fixed at 4 inches.
        self.figure = plt.figure(figsize=((x_max - x_min)/(y_max - y_min), 4), facecolor='black')

    @staticmethod
    def _vertical_lines(count, y_start=0.1, y_end=0.9):
        """Return ``count`` vertical segments for a bar panel.

        Item ``n`` sits at ``x = count - 1 - n``, so the first item is
        rightmost. Every bar lies inside the panel x-limits
        ``(-0.5, count - 0.5)``.
        """
        return [[(count - 1 - n, y_start), (count - 1 - n, y_end)]
                for n in range(count)]

    def create_texture(self, texture, texture_length):
        """Create the (initially empty) white texture line.

        ``texture`` is not used here. The texture is sampled in
        ``update_plot`` from the object returned by ``update_func``.
        """
        # Sample the texture at 100 points per unit length.
        self.texture_x = np.linspace(0, texture_length, 100*texture_length)

        self.texture_line, = self.animation_axis.plot([], [], color='white', lw=1)

    def create_whiskers(self, n_whiskers):
        """Create one placeholder line per whisker on the animation axis."""
        # Independent placeholder segments (not `[...] * n`, which would share
        # one inner list between all whiskers).
        self.whisker_lines = [[(0, 0), (0, 0)] for _ in range(n_whiskers)]
        # Colors span (0, 1) initially so the color limits of the collection
        # are set to that range for the rest of the animation.
        self.whisker_colors = np.linspace(0, 1, n_whiskers)

        self.whisker_line_collection = LineCollection(self.whisker_lines, array=self.whisker_colors, cmap=cm.rainbow, lw=0.5)
        self.animation_axis.add_collection(self.whisker_line_collection)

    def create_animation_plot(self, x_min, y_min, x_max, y_max):
        """Create the top panel showing the texture and the whiskers."""
        self.animation_axis = plt.Axes(self.figure, [0, 0.66, 1, 0.33])
        self.figure.add_axes(self.animation_axis)
        self.animation_axis.set_axis_off()
        self.animation_axis.set_xlim(x_min, x_max)
        # One unit of vertical padding around the y-range.
        self.animation_axis.set_ylim(y_min - 1, y_max + 1)

    def create_whisker_deflection_plot(self, n_whiskers):
        """Create the middle panel with one deflection bar per whisker."""
        self.whisker_deflection_axis = plt.Axes(self.figure, [0, 0.33, 1, 0.33])
        self.figure.add_axes(self.whisker_deflection_axis)

        self.whisker_deflection_axis.set_axis_off()
        self.whisker_deflection_axis.set_xlim(-0.5, n_whiskers - 0.5)
        self.whisker_deflection_axis.set_ylim(0, 1)

        self.whisker_deflection_lines = self._vertical_lines(n_whiskers)

        # Colors span (0, 1) initially to fix the color limits.
        self.whisker_deflection_colors = np.linspace(0, 1, n_whiskers)

        self.whisker_deflection_line_collection = LineCollection(self.whisker_deflection_lines, array=self.whisker_deflection_colors, cmap=cm.rainbow, lw=5)
        self.whisker_deflection_axis.add_collection(self.whisker_deflection_line_collection)

    def create_sensory_cell_activity_plot(self, n_sensory_cells):
        """Create the bottom panel with one activity bar per sensory cell."""
        self.sensory_cell_activity_axis = plt.Axes(self.figure, [0, 0, 1, 0.33])
        self.figure.add_axes(self.sensory_cell_activity_axis)

        self.sensory_cell_activity_axis.set_axis_off()
        self.sensory_cell_activity_axis.set_xlim(-0.5, n_sensory_cells - 0.5)
        self.sensory_cell_activity_axis.set_ylim(0, 1)

        self.sensory_cell_activity_lines = self._vertical_lines(n_sensory_cells)

        # Colors span (0, 1) initially to fix the color limits.
        self.sensory_cell_activity_colors = np.linspace(0, 1, n_sensory_cells)

        self.sensory_cell_activity_line_collection = LineCollection(self.sensory_cell_activity_lines, array=self.sensory_cell_activity_colors, cmap=cm.rainbow, lw=5)
        self.sensory_cell_activity_axis.add_collection(self.sensory_cell_activity_line_collection)

    def update_plot(self, i):
        """Advance to frame ``i`` and return the artists to redraw (blitting)."""
        texture, agent = self.update_func(i)

        texture_y = texture.value(self.texture_x)

        whiskers = agent.whiskers
        if agent.n_whiskers > 0:
            lengths = whiskers.whisker_lengths
            # Every whisker starts at half the first whisker's length from the agent.
            base_radius = 0.5*lengths[0]
            max_length = np.amax(lengths)
            for n in range(agent.n_whiskers):
                angle = whiskers.whisker_angles[n]
                deflection = whiskers.deflections[n]
                cos_a, sin_a = np.cos(angle), np.sin(angle)

                x_start = agent.x + base_radius*cos_a
                y_start = base_radius*sin_a

                # A deflected whisker is drawn shorter by its deflection.
                visible_length = lengths[n] - deflection
                x_end = agent.x + visible_length*cos_a
                y_end = visible_length*sin_a

                self.whisker_colors[n] = deflection/max_length
                self.whisker_lines[n] = [(x_start, y_start), (x_end, y_end)]
                self.whisker_deflection_colors[n] = deflection/lengths[n]

        if agent.n_sensory_cells > 0:
            activity = agent.sensory_cells.activity
            max_activity = np.amax(activity)
            for n in range(agent.n_sensory_cells):
                # Normalise by the most active cell; all zero when nothing is active.
                if max_activity > 0:
                    self.sensory_cell_activity_colors[n] = activity[n]/max_activity
                else:
                    self.sensory_cell_activity_colors[n] = 0

        self.whisker_line_collection.set_array(self.whisker_colors)
        self.whisker_line_collection.set_segments(self.whisker_lines)

        self.whisker_deflection_line_collection.set_array(self.whisker_deflection_colors)

        self.sensory_cell_activity_line_collection.set_array(self.sensory_cell_activity_colors)

        self.texture_line.set_data(self.texture_x, texture_y)

        return self.texture_line, self.whisker_line_collection, self.whisker_deflection_line_collection, self.sensory_cell_activity_line_collection

    def animate(self, update_func):
        """Run the animation; ``update_func(i)`` must return ``(texture, agent)``."""
        self.update_func = update_func

        anim = animation.FuncAnimation(self.figure, self.update_plot, init_func=lambda:[self.texture_line, self.whisker_line_collection, self.whisker_deflection_line_collection, self.sensory_cell_activity_line_collection], interval=1, blit=True)

        plt.show()
Beispiel #31
0
# NOTE(review): assumes `ax` is a Matplotlib Axes created earlier in the
# original script (not visible here) — TODO confirm.
# Square 10 x 10 data area with a grid.
ax.set_aspect('equal')
ax.grid()

ax.set_xlim(0, 10)
ax.set_ylim(0, 10)

#=======================================================
# Red collection that receives the three segments defined below.
segmentsCollection = LineCollection([], linestyle='solid', color='r')
ax.add_collection(segmentsCollection)

# Blue collection, created empty; nothing in this block adds segments to it.
peaksCollection = LineCollection([], linestyle='solid', color='b')
ax.add_collection(peaksCollection)

# Each segment is [[x0, y0], [x1, y1]].
segs2 = [[[2, 5], [6, 8]], [[1, 2], [6, 4]], [[2, 2], [8, 2]]]
segmentsCollection.set_segments(segs2)

# Open question from the original author: how to record, for each entry in
# peaksCollection, which segment of segmentsCollection it belongs to. A
# LineCollection has no slot for per-segment metadata, so this parallel list
# is a placeholder for that mapping (it is not used in this block).
peaksInSegment = []


#=======================================================
def draw_tick(segment, x, y, length=0.25):
    """Compute offset geometry for a tick mark at (x, y) on *segment*.

    NOTE(review): this snippet appears truncated. It computes `a` and `right`
    but never uses `right` and returns None. Presumably the intended tick runs
    from `a` to the matching point on `right` — TODO confirm and complete.

    Parameters
    ----------
    segment : sequence of (x, y) points
        Passed to shapely's ``LineString``.
    x, y : float
        Location of the tick; projected onto the segment.
    length : float, optional
        Offset distance on each side of the segment (default 0.25).
    """
    line = LineString(segment)
    # Copies of the segment offset by `length` on each side.
    left = line.parallel_offset(length, 'left')
    right0 = line.parallel_offset(length, 'right')
    # Rebuild the right offset from its two endpoints in reverse order so it
    # runs in the same direction as `line`.
    # NOTE(review): confirm against the installed Shapely version; Shapely 2.0
    # added `offset_curve`, which handles right-side direction differently.
    right = LineString([right0.boundary.geoms[1], right0.boundary.geoms[0]
                        ])  # flip because 'right' orientation
    point = Point(x, y)
    # Point on the left offset at the same distance along the line as the
    # projection of (x, y) onto `line`.
    a = left.interpolate(line.project(point))
Beispiel #32
0
class ScatterLayerArtist(MatplotlibLayerArtist):
    """Matplotlib layer artist that draws one data layer in a scatter viewer.

    Depending on the layer state it shows markers (through ``plot`` or
    ``scatter``), a density map, a connecting line, error bars and/or vector
    arrows. All artists are refreshed from ``_update_scatter``, which is
    registered as a global callback on both the viewer and the layer state.
    """

    _layer_state_cls = ScatterLayerState

    def __init__(self, axes, viewer_state, layer_state=None, layer=None):

        super(ScatterLayerArtist, self).__init__(axes,
                                                 viewer_state,
                                                 layer_state=layer_state,
                                                 layer=layer)

        # Watch for changes in the viewer state which would require the
        # layers to be redrawn
        self._viewer_state.add_global_callback(self._update_scatter)
        self.state.add_global_callback(self._update_scatter)

        # Scatter: `plot` is used when color and size are fixed (faster),
        # `scatter` when either varies per point.
        self.scatter_artist = self.axes.scatter([], [])
        self.plot_artist = self.axes.plot([], [], 'o', mec='none')[0]
        self.errorbar_artist = self.axes.errorbar([], [], fmt='none')
        self.vector_artist = None
        self.line_collection = LineCollection(np.zeros((0, 2, 2)))
        self.axes.add_collection(self.line_collection)

        # Scatter density
        self.density_auto_limits = DensityMapLimits()
        self.density_artist = ScatterDensityArtist(
            self.axes, [], [],
            color='white',
            vmin=self.density_auto_limits.min,
            vmax=self.density_auto_limits.max)
        self.axes.add_artist(self.density_artist)

        self.mpl_artists = [
            self.scatter_artist, self.plot_artist, self.errorbar_artist,
            self.vector_artist, self.line_collection, self.density_artist
        ]
        # Error bars and vectors are recreated in _update_data; these are their
        # slots in mpl_artists so the new objects can be swapped in.
        self.errorbar_index = 2
        self.vector_index = 3

        self.reset_cache()

    def reset_cache(self):
        """Forget the last-seen states so the next update sees every property as changed."""
        self._last_viewer_state = {}
        self._last_layer_state = {}

    @defer_draw
    def _update_data(self, changed):
        """Push the layer's x/y (and vector/error) data into the artists."""

        # Layer artist has been cleared already
        if len(self.mpl_artists) == 0:
            return

        try:
            x = self.layer[self._viewer_state.x_att].ravel()
        except (IncompatibleAttribute, IndexError):
            # The following includes a call to self.clear()
            self.disable_invalid_attributes(self._viewer_state.x_att)
            return
        else:
            self.enable()

        try:
            y = self.layer[self._viewer_state.y_att].ravel()
        except (IncompatibleAttribute, IndexError):
            # The following includes a call to self.clear()
            self.disable_invalid_attributes(self._viewer_state.y_att)
            return
        else:
            self.enable()

        if self.state.markers_visible:
            if self.state.density_map:
                self.density_artist.set_xy(x, y)
                self.plot_artist.set_data([], [])
                self.scatter_artist.set_offsets(np.zeros((0, 2)))
            else:
                if self.state.cmap_mode == 'Fixed' and self.state.size_mode == 'Fixed':
                    # In this case we use Matplotlib's plot function because it has much
                    # better performance than scatter.
                    self.plot_artist.set_data(x, y)
                    self.scatter_artist.set_offsets(np.zeros((0, 2)))
                    self.density_artist.set_xy([], [])
                else:
                    self.plot_artist.set_data([], [])
                    offsets = np.vstack((x, y)).transpose()
                    self.scatter_artist.set_offsets(offsets)
                    self.density_artist.set_xy([], [])
        else:
            self.plot_artist.set_data([], [])
            self.scatter_artist.set_offsets(np.zeros((0, 2)))
            self.density_artist.set_xy([], [])

        if self.state.line_visible:
            if self.state.cmap_mode == 'Fixed':
                points = np.array([x, y]).transpose()
                self.line_collection.set_segments([points])
            elif len(x) == 0:
                # Empty layer (e.g. an empty subset): nothing to draw, and the
                # oversampling below would request np.zeros(-1).
                self.line_collection.set_segments(np.zeros((0, 2, 2)))
            else:
                # In the case where we want to color the line, we need to over
                # sample the line by a factor of two so that we can assign the
                # correct colors to segments - if we didn't do this, then
                # segments on one side of a point would be a different color
                # from the other side. With oversampling, we can have half a
                # segment on either side of a point be the same color as a
                # point
                x_fine = np.zeros(len(x) * 2 - 1, dtype=float)
                y_fine = np.zeros(len(y) * 2 - 1, dtype=float)
                x_fine[::2] = x
                x_fine[1::2] = 0.5 * (x[1:] + x[:-1])
                y_fine[::2] = y
                y_fine[1::2] = 0.5 * (y[1:] + y[:-1])
                points = np.array([x_fine,
                                   y_fine]).transpose().reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                self.line_collection.set_segments(segments)
        else:
            self.line_collection.set_segments(np.zeros((0, 2, 2)))

        # Drop the previous error-bar line collections; they are recreated below.
        for eartist in list(self.errorbar_artist[2]):
            if eartist is not None:
                try:
                    eartist.remove()
                except ValueError:
                    pass
                except AttributeError:  # Matplotlib < 1.5
                    pass

        if self.vector_artist is not None:
            self.vector_artist.remove()
            self.vector_artist = None

        # Vectors need both components; without them there is nothing to draw
        # (previously np.hypot(None, None) raised TypeError here).
        if (self.state.vector_visible
                and self.state.vx_att is not None
                and self.state.vy_att is not None):

            vx = self.layer[self.state.vx_att].ravel()
            vy = self.layer[self.state.vy_att].ravel()

            if self.state.vector_mode == 'Polar':
                ang = vx
                length = vy
                # assume ang is anti clockwise from the x axis
                vx = length * np.cos(np.radians(ang))
                vy = length * np.sin(np.radians(ang))

            if self.state.vector_arrowhead:
                hw = 3
                hl = 5
            else:
                hw = 1
                hl = 0

            # Scale so the longest vector has unit length. Skip when there is
            # nothing to scale by: np.nanmax raises on an empty layer, and a
            # zero or NaN maximum would turn every vector into NaN.
            v = np.hypot(vx, vy)
            vmax = np.nanmax(v) if v.size > 0 else 0
            if vmax > 0:
                vx = vx / vmax
                vy = vy / vmax

            self.vector_artist = self.axes.quiver(
                x,
                y,
                vx,
                vy,
                units='width',
                pivot=self.state.vector_origin,
                headwidth=hw,
                headlength=hl,
                scale_units='width',
                scale=10 / self.state.vector_scaling)
            self.mpl_artists[self.vector_index] = self.vector_artist

        if self.state.xerr_visible or self.state.yerr_visible:

            if self.state.xerr_visible and self.state.xerr_att is not None:
                xerr = self.layer[self.state.xerr_att].ravel()
            else:
                xerr = None

            if self.state.yerr_visible and self.state.yerr_att is not None:
                yerr = self.layer[self.state.yerr_att].ravel()
            else:
                yerr = None

            self.errorbar_artist = self.axes.errorbar(x,
                                                      y,
                                                      fmt='none',
                                                      xerr=xerr,
                                                      yerr=yerr)
            self.mpl_artists[self.errorbar_index] = self.errorbar_artist

    @defer_draw
    def _update_visual_attributes(self, changed, force=False):
        """Apply color, size, style, alpha, zorder and visibility to the artists.

        Only properties listed in ``changed`` are re-applied unless ``force``
        is True.
        """

        if not self.enabled:
            return

        if self.state.markers_visible:

            if self.state.density_map:

                if self.state.cmap_mode == 'Fixed':
                    if force or 'color' in changed or 'cmap_mode' in changed:
                        self.density_artist.set_color(self.state.color)
                        self.density_artist.set_c(None)
                        self.density_artist.set_clim(
                            self.density_auto_limits.min,
                            self.density_auto_limits.max)
                elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                    c = self.layer[self.state.cmap_att].ravel()
                    set_mpl_artist_cmap(self.density_artist, c, self.state)

                if force or 'stretch' in changed:
                    self.density_artist.set_norm(
                        ImageNormalize(
                            stretch=STRETCHES[self.state.stretch]()))

                if force or 'dpi' in changed:
                    self.density_artist.set_dpi(self._viewer_state.dpi)

                if force or 'density_contrast' in changed:
                    self.density_auto_limits.contrast = self.state.density_contrast
                    self.density_artist.stale = True

            else:

                if self.state.cmap_mode == 'Fixed' and self.state.size_mode == 'Fixed':

                    if force or 'color' in changed:
                        self.plot_artist.set_color(self.state.color)

                    if force or 'size' in changed or 'size_scaling' in changed:
                        self.plot_artist.set_markersize(
                            self.state.size * self.state.size_scaling)

                else:

                    # TEMPORARY: Matplotlib has a bug that causes set_alpha to
                    # change the colors back: https://github.com/matplotlib/matplotlib/issues/8953
                    if 'alpha' in changed:
                        force = True

                    if self.state.cmap_mode == 'Fixed':
                        if force or 'color' in changed or 'cmap_mode' in changed:
                            self.scatter_artist.set_facecolors(
                                self.state.color)
                            self.scatter_artist.set_edgecolor('none')
                    elif force or any(prop in changed
                                      for prop in CMAP_PROPERTIES):
                        c = self.layer[self.state.cmap_att].ravel()
                        set_mpl_artist_cmap(self.scatter_artist, c, self.state)
                        self.scatter_artist.set_edgecolor('none')

                    if force or any(prop in changed
                                    for prop in MARKER_PROPERTIES):

                        if self.state.size_mode == 'Fixed':
                            s = self.state.size * self.state.size_scaling
                            s = broadcast_to(
                                s,
                                self.scatter_artist.get_sizes().shape)
                        else:
                            # Map size_vmin..size_vmax linearly onto 0..30.
                            s = self.layer[self.state.size_att].ravel()
                            s = ((s - self.state.size_vmin) /
                                 (self.state.size_vmax -
                                  self.state.size_vmin)) * 30
                            s *= self.state.size_scaling

                        # Note, we need to square here because for scatter, s is actually
                        # proportional to the marker area, not radius.
                        self.scatter_artist.set_sizes(s**2)

        if self.state.line_visible:

            if self.state.cmap_mode == 'Fixed':
                if force or 'color' in changed or 'cmap_mode' in changed:
                    self.line_collection.set_array(None)
                    self.line_collection.set_color(self.state.color)
            elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                # Higher up we oversampled the points in the line so that
                # half a segment on either side of each point has the right
                # color, so we need to also oversample the color here.
                # max(..., 0) keeps an empty layer from requesting a
                # negative-length array.
                c = self.layer[self.state.cmap_att].ravel()
                cnew = np.zeros(max(len(c) - 1, 0) * 2)
                cnew[::2] = c[:-1]
                cnew[1::2] = c[1:]
                set_mpl_artist_cmap(self.line_collection, cnew, self.state)

            if force or 'linewidth' in changed:
                self.line_collection.set_linewidth(self.state.linewidth)

            if force or 'linestyle' in changed:
                self.line_collection.set_linestyle(self.state.linestyle)

        if self.state.vector_visible and self.vector_artist is not None:

            if self.state.cmap_mode == 'Fixed':
                if force or 'color' in changed or 'cmap_mode' in changed:
                    self.vector_artist.set_array(None)
                    self.vector_artist.set_color(self.state.color)
            elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                c = self.layer[self.state.cmap_att].ravel()
                set_mpl_artist_cmap(self.vector_artist, c, self.state)

        if self.state.xerr_visible or self.state.yerr_visible:

            for eartist in list(self.errorbar_artist[2]):

                if eartist is None:
                    continue

                if self.state.cmap_mode == 'Fixed':
                    if force or 'color' in changed or 'cmap_mode' in changed:
                        eartist.set_color(self.state.color)
                elif force or any(prop in changed for prop in CMAP_PROPERTIES):
                    c = self.layer[self.state.cmap_att].ravel()
                    set_mpl_artist_cmap(eartist, c, self.state)

                if force or 'alpha' in changed:
                    eartist.set_alpha(self.state.alpha)

                if force or 'visible' in changed:
                    eartist.set_visible(self.state.visible)

                if force or 'zorder' in changed:
                    eartist.set_zorder(self.state.zorder)

        for artist in [
                self.scatter_artist, self.plot_artist, self.vector_artist,
                self.line_collection, self.density_artist
        ]:

            if artist is None:
                continue

            if force or 'alpha' in changed:
                artist.set_alpha(self.state.alpha)

            if force or 'zorder' in changed:
                artist.set_zorder(self.state.zorder)

            if force or 'visible' in changed:
                artist.set_visible(self.state.visible)

        self.redraw()

    @defer_draw
    def _update_scatter(self, force=False, **kwargs):
        """State-change callback: redo data and/or visuals as needed.

        ``force=True`` skips change detection and redoes everything; extra
        keyword arguments from the callback machinery are ignored.
        """

        if (self._viewer_state.x_att is None
                or self._viewer_state.y_att is None
                or self.state.layer is None):
            return

        # Figure out which attributes are different from before. Ideally we shouldn't
        # need this but currently this method is called multiple times if an
        # attribute is changed due to x_att changing then hist_x_min, hist_x_max, etc.
        # If we can solve this so that _update_scatter is really only called once
        # then we could consider simplifying this. Until then, we manually keep track
        # of which properties have changed.

        changed = set()

        if not force:

            for key, value in self._viewer_state.as_dict().items():
                if value != self._last_viewer_state.get(key, None):
                    changed.add(key)

            for key, value in self.state.as_dict().items():
                if value != self._last_layer_state.get(key, None):
                    changed.add(key)

        self._last_viewer_state.update(self._viewer_state.as_dict())
        self._last_layer_state.update(self.state.as_dict())

        if force or len(changed & DATA_PROPERTIES) > 0:
            self._update_data(changed)
            # New data invalidates every visual attribute.
            force = True

        if force or len(changed & VISUAL_PROPERTIES) > 0:
            self._update_visual_attributes(changed, force=force)

    def get_layer_color(self):
        """Return the fixed color, or the colormap when colors follow an attribute."""
        if self.state.cmap_mode == 'Fixed':
            return self.state.color
        else:
            return self.state.cmap

    @defer_draw
    def update(self):
        """Force a full data and visual refresh, then redraw."""
        self._update_scatter(force=True)
        self.redraw()
Beispiel #33
0
class HoughDemo(ImageProcessDemo):
    """Interactive demo of Hough line and circle detection.

    The grayscale image is Gaussian-blurred and run through Canny edge
    detection. Probabilistic Hough line detection runs on the edge map and
    Hough circle detection runs on the blurred image. Results are drawn as a
    ``LineCollection`` and an ``EllipseCollection`` over the displayed image.
    """
    TITLE = u"Hough Demo"
    DEFAULT_IMAGE = "stuff.jpg"
    SETTINGS = ["th2", "show_canny", "rho", "theta", "hough_th",
                "minlen", "maxgap", "dp", "mindist", "param2",
                "min_radius", "max_radius", "blur_sigma",
                "linewidth", "alpha", "check_line", "check_circle"]

    check_line = Bool(True)
    check_circle = Bool(True)

    # Gaussian blur parameters
    blur_sigma = Range(0.1, 5.0, 2.0)
    show_blur = Bool(False)

    # Canny parameters (the low threshold is th2 * 0.5, see draw())
    th2 = Range(0.0, 255.0, 200.0)
    show_canny = Bool(False)

    # HoughLinesP parameters (theta is in degrees, converted in draw())
    rho = Range(1.0, 10.0, 1.0)
    theta = Range(0.1, 5.0, 1.0)
    hough_th = Range(1, 100, 40)
    minlen = Range(0, 100, 10)
    maxgap = Range(0, 20, 10)

    # HoughCircles parameters
    dp = Range(1.0, 5.0, 1.9)
    mindist = Range(1.0, 100.0, 50.0)
    param2 = Range(5, 100, 50)
    min_radius = Range(5, 100, 20)
    max_radius = Range(10, 100, 70)

    # Drawing parameters
    linewidth = Range(1.0, 3.0, 1.0)
    alpha = Range(0.0, 1.0, 0.6)

    def control_panel(self):
        """Build the TraitsUI layout of the parameter controls."""
        return VGroup(
            # Gaussian blur: standard deviation, show result
            Group(
                Item("blur_sigma", label=u"标准方差"),
                Item("show_blur", label=u"显示结果"),
                label=u"高斯模糊参数"
            ),
            # Edge detection: threshold 2, show result
            Group(
                Item("th2", label=u"阈值2"),
                Item("show_canny", label=u"显示结果"),
                label=u"边缘检测参数"
            ),
            # Line detection: rho step (px), theta step (deg), threshold,
            # minimum length, maximum gap
            Group(
                Item("rho", label=u"偏移分辨率(像素)"),
                Item("theta", label=u"角度分辨率(角度)"),
                Item("hough_th", label=u"阈值"),
                Item("minlen", label=u"最小长度"),
                Item("maxgap", label=u"最大空隙"),
                label=u"直线检测"
            ),
            # Circle detection: resolution (px), min centre distance (px),
            # centre threshold, min radius, max radius
            Group(
                Item("dp", label=u"分辨率(像素)"),
                Item("mindist", label=u"圆心最小距离(像素)"),
                Item("param2", label=u"圆心检查阈值"),
                Item("min_radius", label=u"最小半径"),
                Item("max_radius", label=u"最大半径"),
                label=u"圆检测"
            ),
            # Drawing: line width, alpha, toggles for lines / circles
            Group(
                Item("linewidth", label=u"线宽"),
                Item("alpha", label=u"alpha"),
                HGroup(
                    Item("check_line", label=u"直线"),
                    Item("check_circle", label=u"圆"),
                ),
                label=u"绘图参数"
            )
        )

    def __init__(self, **kwargs):
        super(HoughDemo, self).__init__(**kwargs)
        # Any change to these traits marks the view dirty so draw() reruns.
        self.connect_dirty("th2, show_canny, show_blur, rho, theta, hough_th,"
                           "min_radius, max_radius, blur_sigma,"
                           "minlen, maxgap, dp, mindist, param2, "
                           "linewidth, alpha, check_line, check_circle")
        self.lines = LineCollection([], linewidths=2, alpha=0.6)
        self.axe.add_collection(self.lines)

        self.circles = EllipseCollection(
            [], [], [],
            units="xy",
            facecolors="none",
            edgecolors="red",
            linewidths=2,
            alpha=0.6,
            transOffset=self.axe.transData)

        self.axe.add_collection(self.circles)

    def _img_changed(self):
        # Cache the grayscale version used by every detection step.
        self.img_gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)

    def draw(self):
        """Run blur -> Canny -> Hough detection and refresh the overlays."""
        img_smooth = cv2.GaussianBlur(self.img_gray, (0, 0),
                                      self.blur_sigma, self.blur_sigma)
        img_edge = cv2.Canny(img_smooth, self.th2 * 0.5, self.th2)

        # The intermediate images are single-channel grayscale, so they are
        # expanded with GRAY2BGR; a Bayer conversion would demosaic them
        # and introduce false colours.
        if self.show_blur and self.show_canny:
            show_img = cv2.cvtColor(np.maximum(img_smooth, img_edge),
                                    cv2.COLOR_GRAY2BGR)
        elif self.show_blur:
            show_img = cv2.cvtColor(img_smooth, cv2.COLOR_GRAY2BGR)
        elif self.show_canny:
            show_img = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2BGR)
        else:
            show_img = self.img

        lines = None
        if self.check_line:
            theta = self.theta / 180.0 * np.pi
            lines = cv2.HoughLinesP(img_edge,
                                    self.rho, theta, self.hough_th,
                                    minLineLength=self.minlen,
                                    maxLineGap=self.maxgap)
        if lines is not None:
            # OpenCV 2.x returns shape (1, N, 4) and 3.x+ returns (N, 1, 4);
            # a reshape handles both, giving one [(x1, y1), (x2, y2)] per line.
            self.lines.set_segments(lines.reshape(-1, 2, 2))
            self.lines.set_visible(True)
        else:
            self.lines.set_visible(False)

        circles = None
        if self.check_circle:
            # 3 is HOUGH_GRADIENT (cv2.HOUGH_GRADIENT in OpenCV 3+,
            # cv2.cv.CV_HOUGH_GRADIENT in 2.x); the literal works on both.
            circles = cv2.HoughCircles(img_smooth, 3,
                                       self.dp, self.mindist,
                                       param1=self.th2,
                                       param2=self.param2,
                                       minRadius=self.min_radius,
                                       maxRadius=self.max_radius)
        if circles is not None:
            circles = circles.reshape(-1, 3)  # rows of (x, y, radius)
            # Relies on EllipseCollection's private attributes; they hold
            # half-axis lengths, so the radius is assigned directly.
            self.circles._heights = self.circles._widths = circles[:, 2]
            self.circles.set_offsets(circles[:, :2])
            self.circles._angles = np.zeros(len(circles))
            self.circles._transOffset = self.axe.transData
            self.circles.set_visible(True)
        else:
            self.circles.set_visible(False)

        self.lines.set_linewidths(self.linewidth)
        self.circles.set_linewidths(self.linewidth)
        self.lines.set_alpha(self.alpha)
        self.circles.set_alpha(self.alpha)

        self.draw_image(show_img)
Beispiel #34
0
class SpikeBrowserUI(object):
    """Scrollable matplotlib view of multichannel signals with a spike overlay.

    *window* is the hosting GUI wrapper. This class calls its ``get_canvas``,
    ``set_scroll_handler``, ``set_scroll_pos`` and ``set_scroll_max``
    methods. Spike times (``spt``), the spike window ``sp_win`` and
    ``winsz`` are converted to samples as ``t / 1000. * FS``. They are
    therefore in milliseconds when ``FS`` is in Hz.

    NOTE(review): ``self.winsz`` is read by :meth:`set_data` but never
    assigned in this class; callers must set it before calling set_data.
    """

    def __init__(self, window):
        self.window = window
        # Window (start, end) around each spike time used for highlighting.
        self.sp_win = [-0.8, 1]
        self.spike_collection = None

        self.fig = Figure((5, 4), 75)

        self.canvas = window.get_canvas(self.fig)

        self._mpl_init()
        self.canvas.mpl_connect('key_press_event', self._on_key)
        self.window.set_scroll_handler(self.OnScrollEvt)

    def _mpl_init(self):
        """(Re)create the axes and the Prev/Next spike buttons."""
        self.fig.clf()
        self.axes = self.fig.add_axes([0.05, 0.1, 0.95, 0.9])
        self.ax_prev = self.fig.add_axes([0.8, 0.0, 0.1, 0.05])
        self.ax_next = self.fig.add_axes([0.9, 0.0, 0.1, 0.05])

        self.b_next = Button(self.ax_next, 'Next')
        self.b_prev = Button(self.ax_prev, "Prev")

        self.b_next.on_clicked(self._next_spike)
        self.b_prev.on_clicked(self._prev_spike)
        self.i_spike = 0
        self.i_start = 0
        self.line_collection = None

    def _goto_spike(self, i_spike):
        """Centre the view on spike *i_spike* and redraw.

        Raises IndexError if *i_spike* is out of range for ``self.spt``.
        """
        t_spk = self.spt[i_spike]
        i_start = int(np.ceil(t_spk / 1000. * self.FS - self.i_window / 2.))
        # Clamp so the displayed window stays inside the recorded data.
        i_start = np.maximum(self.i_min, i_start)
        i_start = np.minimum(self.i_max, i_start)
        self.i_start = i_start
        self.i_end = self.i_start + self.i_window
        self.window.set_scroll_pos(self.i_start)
        self.draw_plot()

    def _next_spike(self, event):
        """Button callback: advance to the next spike (stops at the last)."""
        try:
            if self.i_spike < len(self.spt) - 1:
                self.i_spike += 1
            self._goto_spike(self.i_spike)
        except IndexError:
            # e.g. ``spt`` is empty: there is nothing to jump to.
            pass

    def _prev_spike(self, event):
        """Button callback: go back to the previous spike (stops at 0)."""
        try:
            if self.i_spike > 0:
                self.i_spike -= 1
            self._goto_spike(self.i_spike)
        except IndexError:
            # e.g. ``spt`` is empty: there is nothing to jump to.
            pass

    def _on_key(self, event):
        """'+' or '=' halves the y-range per channel, '-' doubles it."""
        if event.key == '+' or event.key == '=':
            self.ylims /= 2.
        elif event.key == '-':
            self.ylims *= 2.
        else:
            return
        offset = self.ylims[1] - self.ylims[0]
        self.offsets = np.arange(self.n_chans) * offset
        self.draw_plot()

    def set_spiketimes(self, spk_idx, labels=None, all_labels=None):
        """Set spike times from ``spk_idx['data']``; a falsy *spk_idx* clears them.

        *labels* (one per spike) select colours through ``label_color``,
        built from *all_labels* if given, else from the unique *labels*.
        The Prev/Next buttons are shown only while spike times are set.
        """
        if spk_idx:
            self.spt = spk_idx['data']
            if labels is not None:
                self.labels = labels
                if all_labels is None:
                    self.color_func = label_color(np.unique(labels))
                else:
                    self.color_func = label_color(all_labels)
            else:
                self.labels = None

            self.ax_next.set_visible(True)
            self.ax_prev.set_visible(True)

        else:
            self.spt = None
            self.ax_next.set_visible(False)
            self.ax_prev.set_visible(False)

    def set_data(self, data):
        """Load signals ``data['data']`` (n_chans x n_pts) sampled at ``data['FS']``.

        Resets spike times and rebuilds the line collection, stacking the
        channels vertically with a spacing equal to the initial data range.
        """
        self.x = data['data']
        self.FS = data['FS']
        n_chans, n_pts = self.x.shape

        # Reset spike-time data and hide the Prev/Next buttons.
        self.set_spiketimes(None)

        self.i_window = int(self.winsz / 1000. * self.FS)
        # Extents of the data sequence:
        self.i_min = 0
        self.i_max = n_pts - self.i_window
        self.n_chans = n_chans

        self.window.set_scroll_max(self.i_max, self.i_window)

        # Indices of the data interval to be plotted:
        self.i_end = self.i_start + self.i_window

        self.time = np.arange(self.i_start, self.i_end) * 1. / self.FS

        self.segs = np.empty((n_chans, self.i_window, 2))
        self.segs[:, :, 0] = self.time[np.newaxis, :]
        self.segs[:, :, 1] = self.x[:, self.i_start:self.i_end]

        ylims = (self.segs[:, :, 1].min(), self.segs[:, :, 1].max())
        offset = ylims[1] - ylims[0]
        self.offsets = np.arange(n_chans) * offset
        self.segs[:, :, 1] += self.offsets[:, np.newaxis]

        self.ylims = np.array(ylims)

        if self.line_collection:
            self.line_collection.remove()

        self.line_collection = LineCollection(self.segs,
                                              offsets=None,
                                              transform=self.axes.transData,
                                              color='k')

        self.axes.add_collection(self.line_collection)
        self.axes.set_xlim((self.time[0], self.time[-1]))
        self.axes.set_ylim((self.ylims[0] + self.offsets.min(),
                            self.ylims[1] + self.offsets.max()))

        self.canvas.draw()

    def draw_plot(self):
        """Redraw the traces for the window ``[i_start, i_end)`` plus spikes."""
        self.time = np.arange(self.i_start, self.i_end) * 1. / self.FS
        self.segs[:, :, 0] = self.time[np.newaxis, :]
        self.segs[:, :, 1] = (self.x[:, self.i_start:self.i_end]
                              + self.offsets[:, np.newaxis])
        self.line_collection.set_segments(self.segs)

        # Adjust plot limits:
        self.axes.set_xlim((self.time[0], self.time[-1]))
        self.axes.set_ylim((self.ylims[0] + self.offsets.min(),
                            self.ylims[1] + self.offsets.max()))

        if self.spt is not None:
            self.draw_spikes()
        # Redraw:
        self.canvas.draw()

    def draw_spikes(self):
        """Overlay every spike whose full ``sp_win`` fits in the current view."""
        if self.spike_collection is not None:
            self.spike_collection.remove()
            self.spike_collection = None
        sp_win = self.sp_win
        time = self.segs[0, :, 0] * 1000.
        t_min, t_max = time[0] - sp_win[0], time[-1] - sp_win[1]
        in_view = (self.spt > t_min) & (self.spt < t_max)
        spt = self.spt[in_view]
        if len(spt) > 0:
            n_pts = int((sp_win[1] - sp_win[0]) / 1000. * self.FS)
            sp_segs = np.empty((len(spt), self.n_chans, n_pts, 2))
            for i, t_spike in enumerate(spt):
                start, = np.nonzero(time >= (t_spike + sp_win[0]))
                start = start[0]
                stop = start + n_pts
                sp_segs[i, :, :, 0] = (time[np.newaxis, start:stop] / 1000.)
                sp_segs[i, :, :, 1] = self.segs[:, start:stop, 1]
            # One segment per (spike, channel) pair.
            sp_segs = sp_segs.reshape(-1, n_pts, 2)
            if self.labels is not None:
                labs = self.labels[in_view]
                colors = np.repeat(self.color_func(labs), self.n_chans, 0)
            else:
                colors = 'r'
            self.spike_collection = LineCollection(
                sp_segs,
                offsets=None,
                color=colors,
                transform=self.axes.transData)
            self.axes.add_collection(self.spike_collection)

    def OnScrollEvt(self, pos):
        """Scroll handler: move the window to *pos* and redraw.

        Also points ``i_spike`` at the last spike before the window centre
        (or 0). Works when no spike times are loaded (``spt is None``).
        """
        # Update the indices of the plot:
        self.i_start = self.i_min + pos
        self.i_end = self.i_min + self.i_window + pos
        self.i_spike = 0
        if self.spt is not None:
            t_center = (self.i_start + self.i_window / 2.) * 1000. / self.FS
            before, = np.where(self.spt < t_center)
            if len(before) > 0:
                self.i_spike = before[-1]
        self.draw_plot()
Beispiel #35
0
class MemmapViewer(object):
    """Step through memory-map snapshots in *df* with a slider or arrow keys.

    Each row of *df* must have a ``MemMap`` field, which is drawn through
    ``draw_on_lc``. The resulting frame is expected to provide ``Name``,
    ``NOrigin`` and ``NEnd`` columns. *label* is a ``str.format`` template
    filled with ``index`` and the row's fields.
    """

    def __init__(self, df, label=None):
        self.df = df
        self.label = label
        if self.label is None:
            self.label = 'id = {index} size = {Size} timestamp = {timestamp}'

        self.fig = plt.gcf()
        self.fig.set_size_inches([25, 15], forward=True)
        gs = gridspec.GridSpec(2, 1, height_ratios=[20, 1])

        self._ax = self.fig.add_subplot(gs[0, 0])
        self._ax.set_ylim([-0.2, 0.2])

        # Create artists. The initial map spans 0..2**34 (16 GiB); it is
        # replaced by the first real step drawn below.
        self._lc = draw_memmap('0x0, 17179869184, 0;', ax=self._ax)
        self._text = self._ax.text(0.5,
                                   0.8,
                                   "",
                                   ha='center',
                                   va='baseline',
                                   fontsize=25,
                                   transform=self._ax.transAxes)
        # Line collection and text labels showing each allocator's range.
        self._lcalloc = LineCollection([[(0, 0.1), (17179869184, 0.1)]],
                                       linewidths=10)
        self._ax.add_collection(self._lcalloc)
        self._textalloc = []

        pu.cleanup_axis_bytes(self._ax.xaxis)

        init_frame = 0
        self._step(init_frame)

        # Add a slider. closedmax=False makes len(df) itself unreachable, so
        # valid positions are 0 .. len(df) - 1.
        axstep = self.fig.add_subplot(gs[1, 0],
                                      facecolor='lightgoldenrodyellow')
        self._stepSlider = Slider(axstep,
                                  'ID',
                                  0,
                                  len(df),
                                  closedmax=False,
                                  valinit=init_frame,
                                  valfmt='%d',
                                  valstep=1,
                                  dragging=True)

        def update(val):
            self._step(int(val))
            self.fig.canvas.draw_idle()

        self._stepSlider.on_changed(update)

        # Listen for key presses. A lambda (rather than the bound method) is
        # registered: matplotlib's callback registry only holds weak
        # references to bound methods, and the lambda keeps this viewer alive.
        self.fig.canvas.mpl_connect('key_press_event',
                                    lambda evt: self._keydown(evt))

    def _step(self, i):
        """Draw step *i*: the memory map, the title text and allocator ranges."""
        row = self.df.iloc[i]
        mapdf, _ = draw_on_lc(row.MemMap, self._lc)
        self._text.set_text(self.label.format(index=i, **dict(row.items())))

        # Make sure there is one text label per allocator...
        nAlloc = len(mapdf.Name.unique())
        while len(self._textalloc) < nAlloc:
            self._textalloc.append(
                self._ax.text(0, 0, "", ha='center', va='baseline'))
        # ... and blank any labels left over from a step with more
        # allocators, so stale names do not stay on screen.
        for t in self._textalloc[nAlloc:]:
            t.set_text('')

        allocSeg = []
        for (name, grp), t in zip(mapdf.groupby('Name'), self._textalloc):
            rgmin, rgmax = grp.NOrigin.min(), grp.NEnd.max()
            yval = -.1
            allocSeg.append([[rgmin, yval], [rgmax, yval]])

            t.set_text(name)
            t.set_position([(rgmin + rgmax) / 2, yval * 1.2])
        self._lcalloc.set_segments(allocSeg)

        self._ax.set_xlim([mapdf.NOrigin.min(), mapdf.NEnd.max()])

    def _keydown(self, evt):
        """Left/right arrows move the slider by one, clamped to its range."""
        if evt.key == 'left':
            newval = max(self._stepSlider.val - 1, self._stepSlider.valmin)
            self._stepSlider.set_val(newval)
        elif evt.key == 'right':
            # valmax itself is excluded (closedmax=False).
            newval = min(self._stepSlider.val + 1, self._stepSlider.valmax - 1)
            self._stepSlider.set_val(newval)
        self.fig.canvas.draw_idle()
Beispiel #36
0
class MemmapViewer(object):
    """Step through memory-map snapshots with a slider, arrow keys or autoplay.

    Parameters
    ----------
    df : memmap rows (DataFrame or sequence), or, when *hdf* is given, a
        sequence of timestamps whose ``str()`` is used as the HDF key.
    colormap : mapping of session name -> colour; also drives the legend.
    ptr2sess : passed through to the row preprocessing helpers.
    label : ``str.format`` template filled with ``index`` and ``timestamp``.
    doPreprocess : if true, every row is preprocessed up front with
        ``preprocess_memmap2``; otherwise rows are processed on demand.
    hdf : optional HDF path/store to load each step's frame from.

    Keys: left/right step by one; space toggles autoplay (100 ms timer).
    """

    def __init__(self,
                 df,
                 colormap,
                 ptr2sess,
                 label=None,
                 doPreprocess=False,
                 hdf=None):
        self.label = label
        if self.label is None:
            self.label = 'id = {index} timestamp = {timestamp}'

        self.colormap = colormap
        self.ptr2sess = ptr2sess
        self._do_preprocess = doPreprocess
        self._hdf = hdf
        if doPreprocess:
            self.data = preprocess_memmap2(df, self.colormap, self.ptr2sess)
        else:
            self.data = df

        self.fig = plt.gcf()
        self.fig.set_size_inches([25, 15], forward=True)
        gs = gridspec.GridSpec(2, 1, height_ratios=[20, 1])

        self._ax = self.fig.add_subplot(gs[0, 0])
        self._ax.set_ylim([-0.2, 0.2])

        # Create artists. The initial map spans 0..2**34 (16 GiB); it is
        # replaced by the first real step drawn below.
        self._lc = draw_memmap('0x0, 17179869184, 0, 0;', ax=self._ax)
        self._text = self._ax.text(0.5,
                                   0.8,
                                   "",
                                   ha='center',
                                   va='baseline',
                                   fontsize=25,
                                   transform=self._ax.transAxes)
        # Line collection and text labels showing each bin's range.
        self._lcalloc = LineCollection([[(0, 0.1), (17179869184, 0.1)]],
                                       linewidths=10)
        self._ax.add_collection(self._lcalloc)
        self._textalloc = []

        pu.cleanup_axis_bytes(self._ax.xaxis)

        init_frame = 0
        self._step(init_frame)

        # Session legend
        self._ax.legend(handles=[
            mpatches.Patch(color=c, label=sess)
            for sess, c in self.colormap.items()
        ])

        # Add a slider covering indices 0 .. len(data) - 1.
        axstep = self.fig.add_subplot(gs[1, 0],
                                      facecolor='lightgoldenrodyellow')
        self._stepSlider = Slider(axstep,
                                  'ID',
                                  0,
                                  len(self.data) - 1,
                                  valinit=init_frame,
                                  valfmt='%d',
                                  valstep=1,
                                  dragging=True)

        def update(val):
            self._step(int(val))
            self.fig.canvas.draw_idle()

        self._stepSlider.on_changed(update)

        # Autoplay timer; None while not playing.
        self._timer = None

        # Listen for key presses. A lambda (rather than the bound method) is
        # registered: matplotlib's callback registry only holds weak
        # references to bound methods, and the lambda keeps this viewer alive.
        self.fig.canvas.mpl_connect('key_press_event',
                                    lambda evt: self._keydown(evt))

    def _data(self, i):
        """Return ``(timestamp, mapdf)`` for step *i*."""
        if self._do_preprocess:
            return self.data[i]
        if self._hdf is not None:
            ts = self.data[i]
            # pandas has no ``from_hdf``; ``read_hdf(path_or_buf, key)`` is
            # the reader for frames stored with ``DataFrame.to_hdf``.
            return ts, pd.read_hdf(self._hdf, str(ts))
        try:
            row = self.data.iloc[i]
        except AttributeError:
            # Plain sequence rather than a DataFrame.
            row = self.data[i]
        return _preprocess_memmap_row((i, row), self.ptr2sess,
                                      self.colormap)

    def _step(self, i):
        """Draw step *i*; prints a message and skips it if parsing fails."""
        ts, mapdf = self._data(i)
        try:
            mapdf, _ = draw_on_lc(mapdf, self._lc)
        except ValueError:
            print('Error when parsing memmap for index at {}'.format(i))
            return

        self._text.set_text(self.label.format(index=i, timestamp=ts))

        # Make sure there is one text label per bin...
        nAlloc = len(mapdf.Bin.unique())
        while len(self._textalloc) < nAlloc:
            self._textalloc.append(
                self._ax.text(0,
                              0,
                              "",
                              rotation='vertical',
                              ha='center',
                              va='baseline'))
        # ... and clear any extra ones left from earlier steps.
        for t in self._textalloc[nAlloc:]:
            t.set_text('')

        allocSeg = []
        for (bsize, grp), t in zip(mapdf.groupby('Bin'), self._textalloc):
            rgmin, rgmax = grp.NOrigin.min(), grp.NEnd.max()
            yval = -.1
            allocSeg.append([[rgmin, yval], [rgmax, yval]])

            t.set_text('Bin({})'.format(pu.bytes2human(bsize)))
            t.set_position([(rgmin + rgmax) / 2, yval * 1.2])
        self._lcalloc.set_segments(allocSeg)

        self._ax.set_xlim([mapdf.NOrigin.min(), mapdf.NEnd.max()])

    def _keydown(self, evt):
        """Left/right step; space starts/stops autoplay; other keys are printed."""
        if evt.key == 'left':
            self.prev()
        elif evt.key == 'right':
            self.next()
        elif evt.key == ' ':
            if self._timer is None:
                self._timer = self.fig.canvas.new_timer(interval=100)

                def autonext():
                    self.next()
                    if self._stepSlider.val == self._stepSlider.valmax:
                        # Reached the last step: stop autoplay.
                        self._timer.stop()
                        self._timer = None

                self._timer.add_callback(autonext)
                self._timer.start()
            else:
                self._timer.stop()
                self._timer = None
        else:
            print(evt.key)

    def next(self):
        """Advance one step, clamped to the slider maximum."""
        newval = min(self._stepSlider.val + 1, self._stepSlider.valmax)
        self._stepSlider.set_val(newval)
        self.fig.canvas.draw_idle()

    def prev(self):
        """Go back one step, clamped to the slider minimum."""
        newval = max(self._stepSlider.val - 1, self._stepSlider.valmin)
        self._stepSlider.set_val(newval)
        self.fig.canvas.draw_idle()

    def goto(self, val):
        """Jump directly to step *val*."""
        self._stepSlider.set_val(val)
        self.fig.canvas.draw_idle()
Beispiel #37
0
class HoughDemo(ImageProcessDemo):
    """Interactive demo of Hough line and circle detection.

    The grayscale image is Gaussian-blurred and run through Canny edge
    detection. Probabilistic Hough line detection runs on the edge map and
    Hough circle detection runs on the blurred image. Results are drawn as a
    ``LineCollection`` and an ``EllipseCollection`` over the displayed image.
    """
    TITLE = u"Hough Demo"
    DEFAULT_IMAGE = "stuff.jpg"
    SETTINGS = ["th2", "show_canny", "rho", "theta", "hough_th",
                "minlen", "maxgap", "dp", "mindist", "param2",
                "min_radius", "max_radius", "blur_sigma",
                "linewidth", "alpha", "check_line", "check_circle"]

    check_line = Bool(True)
    check_circle = Bool(True)

    # Gaussian blur parameters
    blur_sigma = Range(0.1, 5.0, 2.0)
    show_blur = Bool(False)

    # Canny parameters (the low threshold is th2 * 0.5, see draw())
    th2 = Range(0.0, 255.0, 200.0)
    show_canny = Bool(False)

    # HoughLinesP parameters (theta is in degrees, converted in draw())
    rho = Range(1.0, 10.0, 1.0)
    theta = Range(0.1, 5.0, 1.0)
    hough_th = Range(1, 100, 40)
    minlen = Range(0, 100, 10)
    maxgap = Range(0, 20, 10)

    # HoughCircles parameters
    dp = Range(1.0, 5.0, 1.9)
    mindist = Range(1.0, 100.0, 50.0)
    param2 = Range(5, 100, 50)
    min_radius = Range(5, 100, 20)
    max_radius = Range(10, 100, 70)

    # Drawing parameters
    linewidth = Range(1.0, 3.0, 1.0)
    alpha = Range(0.0, 1.0, 0.6)

    def control_panel(self):
        """Build the TraitsUI layout of the parameter controls."""
        return VGroup(
            # Gaussian blur: standard deviation, show result
            Group(
                Item("blur_sigma", label=u"标准方差"),
                Item("show_blur", label=u"显示结果"),
                label=u"高斯模糊参数"
            ),
            # Edge detection: threshold 2, show result
            Group(
                Item("th2", label=u"阈值2"),
                Item("show_canny", label=u"显示结果"),
                label=u"边缘检测参数"
            ),
            # Line detection: rho step (px), theta step (deg), threshold,
            # minimum length, maximum gap
            Group(
                Item("rho", label=u"偏移分辨率(像素)"),
                Item("theta", label=u"角度分辨率(角度)"),
                Item("hough_th", label=u"阈值"),
                Item("minlen", label=u"最小长度"),
                Item("maxgap", label=u"最大空隙"),
                label=u"直线检测"
            ),
            # Circle detection: resolution (px), min centre distance (px),
            # centre threshold, min radius, max radius
            Group(
                Item("dp", label=u"分辨率(像素)"),
                Item("mindist", label=u"圆心最小距离(像素)"),
                Item("param2", label=u"圆心检查阈值"),
                Item("min_radius", label=u"最小半径"),
                Item("max_radius", label=u"最大半径"),
                label=u"圆检测"
            ),
            # Drawing: line width, alpha, toggles for lines / circles
            Group(
                Item("linewidth", label=u"线宽"),
                Item("alpha", label=u"alpha"),
                HGroup(
                    Item("check_line", label=u"直线"),
                    Item("check_circle", label=u"圆"),
                ),
                label=u"绘图参数"
            )
        )

    def __init__(self, **kwargs):
        super(HoughDemo, self).__init__(**kwargs)
        # Any change to these traits marks the view dirty so draw() reruns.
        self.connect_dirty("th2, show_canny, show_blur, rho, theta, hough_th,"
                           "min_radius, max_radius, blur_sigma,"
                           "minlen, maxgap, dp, mindist, param2, "
                           "linewidth, alpha, check_line, check_circle")
        self.lines = LineCollection([], linewidths=2, alpha=0.6)
        self.axe.add_collection(self.lines)

        self.circles = EllipseCollection(
            [], [], [],
            units="xy",
            facecolors="none",
            edgecolors="red",
            linewidths=2,
            alpha=0.6,
            transOffset=self.axe.transData)

        self.axe.add_collection(self.circles)

    def _img_changed(self):
        # Cache the grayscale version used by every detection step.
        self.img_gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)

    def draw(self):
        """Run blur -> Canny -> Hough detection and refresh the overlays."""
        img_smooth = cv2.GaussianBlur(self.img_gray, (0, 0),
                                      self.blur_sigma, self.blur_sigma)
        img_edge = cv2.Canny(img_smooth, self.th2 * 0.5, self.th2)

        # The intermediate images are single-channel grayscale, so they are
        # expanded with GRAY2BGR; a Bayer conversion would demosaic them
        # and introduce false colours.
        if self.show_blur and self.show_canny:
            show_img = cv2.cvtColor(np.maximum(img_smooth, img_edge),
                                    cv2.COLOR_GRAY2BGR)
        elif self.show_blur:
            show_img = cv2.cvtColor(img_smooth, cv2.COLOR_GRAY2BGR)
        elif self.show_canny:
            show_img = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2BGR)
        else:
            show_img = self.img

        lines = None
        if self.check_line:
            theta = self.theta / 180.0 * np.pi
            lines = cv2.HoughLinesP(img_edge,
                                    self.rho, theta, self.hough_th,
                                    minLineLength=self.minlen,
                                    maxLineGap=self.maxgap)
        if lines is not None:
            # OpenCV 2.x returns shape (1, N, 4) and 3.x+ returns (N, 1, 4);
            # a reshape handles both, giving one [(x1, y1), (x2, y2)] per line.
            self.lines.set_segments(lines.reshape(-1, 2, 2))
            self.lines.set_visible(True)
        else:
            self.lines.set_visible(False)

        circles = None
        if self.check_circle:
            # 3 is HOUGH_GRADIENT (cv2.HOUGH_GRADIENT in OpenCV 3+,
            # cv2.cv.CV_HOUGH_GRADIENT in 2.x); the literal works on both.
            circles = cv2.HoughCircles(img_smooth, 3,
                                       self.dp, self.mindist,
                                       param1=self.th2,
                                       param2=self.param2,
                                       minRadius=self.min_radius,
                                       maxRadius=self.max_radius)
        if circles is not None:
            circles = circles.reshape(-1, 3)  # rows of (x, y, radius)
            # Relies on EllipseCollection's private attributes; they hold
            # half-axis lengths, so the radius is assigned directly.
            self.circles._heights = self.circles._widths = circles[:, 2]
            self.circles.set_offsets(circles[:, :2])
            self.circles._angles = np.zeros(len(circles))
            self.circles._transOffset = self.axe.transData
            self.circles.set_visible(True)
        else:
            self.circles.set_visible(False)

        self.lines.set_linewidths(self.linewidth)
        self.circles.set_linewidths(self.linewidth)
        self.lines.set_alpha(self.alpha)
        self.circles.set_alpha(self.alpha)

        self.draw_image(show_img)
Beispiel #38
0
 def set_segments(self, segments):
     """Remember *segments* as the 3D line data.

     The segment list held by the ``LineCollection`` base class is reset
     to an empty list.
     """
     self._segments3d = segments
     LineCollection.set_segments(self, [])
Beispiel #39
0
class Visualization(object):
    """Matplotlib helpers to draw a grid map, particles, poses and LIDAR scans.

    Artists for particles, poses and the LIDAR scan are created on the first
    call and then updated in place on later calls.
    """

    def __init__(self):
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.set_aspect('equal')

        # Interactive mode so that canvas.draw() refreshes the open window.
        plt.ion()
        plt.axis('off')
        plt.tight_layout()

        self.particle_handle = None
        self.true_pose_handle = None
        self.mean_pose_handle = None
        self.lidar_handle = None

    def drawGridmap(self, gridmap):
        """Draw every occupied cell of *gridmap* as a filled rectangle.

        *gridmap* is an instance of the Gridmap class giving an occupancy
        grid (1: occupied, 0: free). It must provide ``getShape()`` ->
        (rows, cols), cell sizes ``xres``/``yres`` and
        ``inCollision(col, row, True)``.
        """
        (m, n) = gridmap.getShape()

        # Set the axis size (emit=True, auto=False).
        self.ax.set_xlim(0, n * gridmap.xres, True, False)
        self.ax.set_ylim(0, m * gridmap.yres, True, False)

        for j in range(n):
            for i in range(m):
                # Lower-left corner of cell (row i, column j).
                x0 = j * gridmap.xres
                y0 = i * gridmap.yres

                if gridmap.inCollision(j, i, True):
                    self.ax.add_patch(
                        patches.Rectangle((x0, y0), gridmap.xres,
                                          gridmap.yres))

        self.fig.canvas.draw()

    def drawParticles(self, particles, weights=None):
        """Draw particles as black dots.

        particles: 2-D array with one particle per column; row 0 holds the
                   x coordinates and row 1 the y coordinates.
        weights:   accepted for API compatibility; currently unused.
        """
        if self.particle_handle is None:
            self.particle_handle, = self.ax.plot(particles[0, :],
                                                 particles[1, :], 'k.')
        else:
            self.particle_handle.set_xdata(particles[0, :])
            self.particle_handle.set_ydata(particles[1, :])

        self.fig.canvas.draw()

    def drawGroundTruthPose(self, x, y, theta):
        """Draw the ground-truth position as a red dot (*theta* is unused)."""
        if self.true_pose_handle is None:
            self.true_pose_handle, = self.ax.plot(x, y, 'r.')
        else:
            self.true_pose_handle.set_xdata(x)
            self.true_pose_handle.set_ydata(y)

        self.fig.canvas.draw()

    def drawMeanPose(self, x, y, theta):
        """Draw the mean estimated position as a green circle (*theta* unused)."""
        if self.mean_pose_handle is None:
            self.mean_pose_handle, = self.ax.plot(x, y, 'go')
        else:
            self.mean_pose_handle.set_xdata(x)
            self.mean_pose_handle.set_ydata(y)

        self.fig.canvas.draw()

    def drawLidar(self, range, bearing, x, y, theta):
        """Draw a LIDAR scan as red segments from the robot to each hit point.

        range:       array of measured ranges, one per beam
        bearing:     1-D array of beam bearings in the LIDAR frame (radians)
        x, y, theta: pose from which the scan was acquired

        NOTE: ``range`` shadows the builtin; it is kept so keyword callers
        continue to work.
        """
        # XY points of each return in the LIDAR frame, shape (2, K).
        cos_sin = np.vstack((np.cos(bearing[:]), np.sin(bearing[:])))
        xy_lidar = np.tile(range.transpose(), (2, 1)) * cos_sin

        # Rotation from the LIDAR frame to the world frame.
        R = np.array([[np.cos(theta), -np.sin(theta)],
                      [np.sin(theta), np.cos(theta)]])

        xy_robot = np.tile(np.array([[x], [y]]), (1, bearing.shape[0]))
        xy_world = R.dot(xy_lidar) + xy_robot

        # One segment per beam, [hit point, robot position], shape (K, 2, 2),
        # as expected by LineCollection.
        lines = np.stack((xy_world.T, xy_robot.T), axis=1)

        if self.lidar_handle is None:
            self.lidar_handle = LineCollection(lines,
                                               cmap=plt.cm.gist_ncar,
                                               linewidths=0.5,
                                               color='red')
            self.ax.add_collection(self.lidar_handle)
        else:
            self.lidar_handle.set_segments(lines)

        self.fig.canvas.draw()
def create_video_with_keypoints_only(
    df,
    output_name,
    ind_links=None,
    pcutoff=0.6,
    dotsize=8,
    alpha=0.7,
    background_color="k",
    skeleton_color="navy",
    color_by="bodypart",
    colormap="viridis",
    fps=25,
    dpi=200,
    codec="h264",
):
    """Render keypoints (and optionally a skeleton) on a plain background video.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per frame. Columns have a "bodyparts" level and a "coords"
        level with "x", "y" and a third value per keypoint; multi-animal data
        additionally has an "individuals" level.
    output_name : str
        Output file, written with matplotlib's ``FFMpegWriter``.
    ind_links : sequence of (int, int), optional
        Pairs of keypoint indices (positions in the per-frame keypoint list)
        to join with skeleton lines. Nothing is drawn when None or empty.
    pcutoff : float
        Keypoints whose third value is below this are hidden in that frame
        (presumably a detection likelihood).
    dotsize, alpha : float
        Marker size (diameter in points) and opacity of markers and skeleton.
    background_color, skeleton_color : color
        Background fill colour and skeleton line colour.
    color_by : {"bodypart", "individual"}
        What the marker colours distinguish.
    colormap : str
        Matplotlib colormap name used for the marker colours.
    fps, dpi, codec :
        Frame rate, resolution and codec of the written video. The canvas
        size is the largest x/y coordinate in ``df`` divided by ``dpi``.

    Raises
    ------
    ValueError
        If ``color_by`` is not "bodypart" or "individual".
    Exception
        If ``color_by="individual"`` but ``df`` has no "individuals" level.
    """
    bodyparts = df.columns.get_level_values("bodyparts")[::3]
    bodypart_names = bodyparts.unique()
    n_bodyparts = len(bodypart_names)
    nx = int(np.nanmax(df.xs("x", axis=1, level="coords")))
    ny = int(np.nanmax(df.xs("y", axis=1, level="coords")))

    n_frames = df.shape[0]
    # np.array copies: the masking below writes NaN into this array, and
    # ``df.values`` may be a view of the caller's DataFrame.
    xyp = np.array(df.values, dtype=float).reshape((n_frames, -1, 3))

    # (n_links, 2) index pairs. Indexing an (n_keypoints, 2) coordinate array
    # with it gives (n_links, 2, 2): one [start, end] pair of (x, y) points
    # per link, the layout LineCollection expects.
    links = None
    if ind_links is not None and len(ind_links) > 0:
        links = np.asarray(ind_links, dtype=int).reshape(-1, 2)

    def _frame_coords(index):
        """Return frame *index*'s (x, y) keypoints, low-confidence ones as NaN."""
        coords = xyp[index, :, :2]
        coords[xyp[index, :, 2] < pcutoff] = np.nan
        return coords

    def _segments(coords):
        """Return the skeleton segments for *coords* (empty without links)."""
        return coords[links] if links is not None else []

    if color_by == "bodypart":
        map_ = bodyparts.map(dict(zip(bodypart_names, range(n_bodyparts))))
        cmap = plt.get_cmap(colormap, n_bodyparts)
    elif color_by == "individual":
        try:
            individuals = df.columns.get_level_values("individuals")[::3]
            individual_names = individuals.unique().to_list()
            n_individuals = len(individual_names)
            map_ = individuals.map(
                dict(zip(individual_names, range(n_individuals))))
            cmap = plt.get_cmap(colormap, n_individuals)
        except KeyError as e:
            raise Exception(
                "Coloring by individuals is only valid for multi-animal data"
            ) from e
    else:
        raise ValueError(f"Invalid color_by={color_by}")

    prev_backend = plt.get_backend()
    plt.switch_backend("agg")
    fig = None
    try:
        fig = plt.figure(frameon=False, figsize=(nx / dpi, ny / dpi))
        ax = fig.add_subplot(111)
        scat = ax.scatter([], [], s=dotsize**2, alpha=alpha)
        coords = _frame_coords(0)
        scat.set_offsets(coords)
        scat.set_color(cmap(map_))
        coll = LineCollection(_segments(coords), colors=skeleton_color, alpha=alpha)
        ax.add_collection(coll)
        ax.set_xlim(0, nx)
        ax.set_ylim(0, ny)
        ax.axis("off")
        # A rectangle spanning the whole axes, drawn below everything else,
        # provides the background colour.
        ax.add_patch(
            plt.Rectangle((0, 0),
                          1,
                          1,
                          facecolor=background_color,
                          transform=ax.transAxes,
                          zorder=-1))
        # Flip so that y increases downwards.
        ax.invert_yaxis()
        plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

        writer = FFMpegWriter(fps=fps, codec=codec)
        with writer.saving(fig, output_name, dpi=dpi):
            writer.grab_frame()
            for index in trange(1, n_frames):
                coords = _frame_coords(index)
                scat.set_offsets(coords)
                coll.set_segments(_segments(coords))
                writer.grab_frame()
    finally:
        # Always release the figure and restore the caller's backend, even
        # when rendering or encoding fails.
        if fig is not None:
            plt.close(fig)
        plt.switch_backend(prev_backend)
Beispiel #41
0
class TutorialLayerArtist(MatplotlibLayerArtist):
    """Draw a layer as a dendrogram made of line segments.

    The layer attributes chosen as ``x_att`` and ``y_att`` on the viewer
    state (labelled "parent" and "height" in this code) are laid out with
    ``dendro_layout`` and drawn through a single ``LineCollection``.
    """

    _layer_state_cls = TutorialLayerState

    def __init__(self, axes, *args, **kwargs):

        super(TutorialLayerArtist, self).__init__(axes, *args, **kwargs)

        # One collection holds every segment. add_collection returns the
        # collection it was given, so self.artist and self.lc are one artist.
        self.lc = LineCollection([], color='k', linestyle='solid')
        self.artist = self.axes.add_collection(self.lc)
        self.mpl_artists.append(self.artist)

        for prop in ('visible', 'zorder', 'color', 'alpha', 'linewidth'):
            self.state.add_callback(prop, self._on_visual_change)

        for prop in ('x_att', 'y_att', 'orientation'):
            self._viewer_state.add_callback(prop, self._on_attribute_change)

    def _on_visual_change(self, value=None):
        """Apply visibility, z-order, colour, line width and alpha, then redraw.

        ``value`` is supplied by the state callback and not used.
        """
        self.artist.set_visible(self.state.visible)
        self.artist.set_zorder(self.state.zorder)
        # NOTE(review): this single colour replaces the per-segment colours
        # set by _on_attribute_change, and update() runs this method second;
        # confirm which colouring is intended.
        self.lc.set_color(self.state.color)
        self.lc.set_linewidth(self.state.linewidth)
        self.artist.set_alpha(self.state.alpha)

        self.redraw()

    def _on_attribute_change(self, value=None):
        """Recompute the dendrogram segments, their colours and the axis limits.

        Does nothing until both ``x_att`` and ``y_att`` are set on the viewer
        state. ``value`` is supplied by the state callback and not used.

        Raises:
            ValueError: if the (currently hard-coded) colour mode is unknown.
        """
        if self._viewer_state.x_att is None or self._viewer_state.y_att is None:
            return

        # x: parent, y: height
        x = self.state.layer[self._viewer_state.x_att]
        y = self.state.layer[self._viewer_state.y_att]

        orientation = self._viewer_state.orientation

        # Sort both arrays by height; pass None as the sort array instead to
        # keep the original order.
        sortby_array = y
        x, y = sort1Darrays(x, y, sortby_array)

        verts, verts_horiz = dendro_layout(x, y, orientation=orientation)
        nleaf = calculate_nleaf(x)

        # TODO: expose the colour mode and colormap as layer-state options.
        color_code = 'linear'
        color_code_cmap = cm.Reds
        black = [0., 0., 0., 1.]

        verts_final = np.concatenate([verts, verts_horiz])
        if color_code == 'fixed':
            # Every segment black, the collection's initial colour.
            colors_final = np.tile(black, (len(verts_final), 1))
        elif color_code == 'linear':
            normalize = mplcolors.Normalize(np.nanmin(y), np.nanmax(y))
            # Built as (n, 4) arrays so concatenation also works when either
            # part is empty (np.concatenate rejects a bare [] next to (n, 4)).
            colors = np.asarray(
                [color_code_cmap(normalize(yi)) for yi in y],
                dtype=float).reshape(-1, 4)
            colors_horiz = np.tile(black, (len(verts_horiz), 1))
            colors_final = np.concatenate([colors, colors_horiz])
        else:
            raise ValueError("Unknown color_code: {!r}".format(color_code))

        self.lc.set_segments(verts_final)
        self.lc.set_color(colors_final)

        # parent axis: a margin around the leaves
        xmin = (-.5)
        xmax = nleaf + 1.5
        # height axis: 5% padding on both sides of the data range
        ymin = np.nanmin(y) - .05 * (np.nanmax(y) - np.nanmin(y))
        ymax = np.nanmax(y) + .05 * (np.nanmax(y) - np.nanmin(y))

        if orientation == 'bottom-up':
            self.axes.set_xlim(xmin, xmax)
            self.axes.set_ylim(ymin, ymax)
        elif orientation == 'top-down':
            self.axes.set_xlim(xmin, xmax)
            self.axes.set_ylim(ymax, ymin)
        elif orientation == 'left-right':
            self.axes.set_ylim(xmin, xmax)
            self.axes.set_xlim(ymin, ymax)
        elif orientation == 'right-left':
            self.axes.set_ylim(xmin, xmax)
            self.axes.set_xlim(ymax, ymin)

        self.redraw()

    def update(self):
        """Recompute the geometry, then re-apply the visual properties."""
        self._on_attribute_change()
        self._on_visual_change()
Beispiel #42
0
 def set_segments(self, segments):
     '''
     Set 3D segments.

     The segments are stored in ``self._segments3d`` and the 2-D segments of
     the underlying LineCollection are cleared. Ragged input (segments with
     different numbers of points) is stored as a 1-D object array of
     per-segment arrays.
     '''
     try:
         self._segments3d = np.asanyarray(segments)
     except ValueError:
         # NumPy >= 1.24 no longer builds object arrays implicitly from
         # ragged nested sequences, so build one explicitly. Element-wise
         # assignment stops NumPy from trying to broadcast the pieces.
         ragged = np.empty(len(segments), dtype=object)
         for i, seg in enumerate(segments):
             ragged[i] = np.asanyarray(seg)
         self._segments3d = ragged
     LineCollection.set_segments(self, [])
Beispiel #43
0
class SpikeBrowserUI(object):
    def __init__(self, window):
        '''
        Build the browser figure inside *window* and hook up its events.

        ``window`` is the host GUI object. Here it must provide
        ``get_canvas(fig)`` and ``set_scroll_handler(callback)``. Other methods
        also call its ``set_scroll_pos`` and ``set_scroll_max``.
        '''
        self.window = window
        # Spike cut-out around each spike time, in ms (see draw_spikes).
        self.sp_win = [-0.8, 1]
        self.spike_collection = None

        self.fig = Figure((9, 5), 75)
        self.canvas = window.get_canvas(self.fig)
        self._mpl_init()

        for key_handler in (self._zoom_key_handler,
                            self._browse_spikes_key_handler):
            self.canvas.mpl_connect('key_press_event', key_handler)
        window.set_scroll_handler(self.OnScrollEvt)

    def _mpl_init(self):
        '''
        Clear the figure and rebuild its content: the trace axes, the
        FancyYAxis helper and the Prev/Next spike buttons.

        Also resets the spike index and scroll start to 0. It forgets the
        trace LineCollection, because the old one belonged to the axes
        removed by ``clf()``.
        '''
        self.fig.clf()
        # Trace axes; the strip at the bottom of the figure is left for the
        # buttons created below.
        self.axes = self.fig.add_axes([0.02, 0.1, 0.96, 0.85])
        self.fancyyaxis = FancyYAxis(self, 0.05)
        self.ax_prev = self.fig.add_axes([0.8, 0.0, 0.1, 0.05])
        self.ax_next = self.fig.add_axes([0.9, 0.0, 0.1, 0.05])

        self.b_next = Button(self.ax_next, 'Next')
        self.b_prev = Button(self.ax_prev, "Prev")

        self.b_next.on_clicked(self._next_spike)
        self.b_prev.on_clicked(self._prev_spike)
        self.i_spike = 0
        self.i_start = 0
        # Created by set_data once data is loaded.
        self.line_collection = None

    def _next_spike(self, event):
        try:
            if self.i_spike < len(self.spt) - 1:
                self.i_spike += 1
            t_spk = self.spt[self.i_spike]
            i_start = int(
                np.ceil(t_spk / 1000. * self.FS - self.i_window / 2.))
            i_start = np.maximum(self.i_min, i_start)
            i_start = np.minimum(self.i_max, i_start)
            self.i_start = i_start
            self.i_end = self.i_start + self.i_window
            self.window.set_scroll_pos(self.i_start)
            self.draw_plot()
        except IndexError:
            pass

    def _prev_spike(self, event):
        try:
            if self.i_spike > 0:
                self.i_spike -= 1
            t_spk = self.spt[self.i_spike]
            i_start = int(
                np.ceil(t_spk / 1000. * self.FS - self.i_window / 2.))
            i_start = np.maximum(self.i_min, i_start)
            i_start = np.minimum(self.i_max, i_start)
            self.i_start = i_start
            self.i_end = self.i_start + self.i_window
            self.window.set_scroll_pos(self.i_start)
            self.draw_plot()
        except IndexError:
            pass

    def _zoom_key_handler(self, event):
        if event.key == '+' or event.key == '=':
            self.scale_y(0.5)
        elif event.key == '-':
            self.scale_y(2)
        elif event.key == 'ctrl++' or event.key == 'ctrl+=':
            self.scale_x(0.5)
        elif event.key == 'ctrl+-':
            self.scale_x(2)

        self.draw_plot()

    def _browse_spikes_key_handler(self, event):
        if event.key == 'right':
            self._next_spike(event)
        elif event.key == 'left':
            self._prev_spike(event)
        else:
            return

    def set_spiketimes(self, spk_idx, labels=None, all_labels=None):
        if spk_idx:
            self.spt = spk_idx['data']
            if labels is not None:
                self.labels = labels
                if all_labels is None:
                    self.color_func = label_color(np.unique(labels))
                else:
                    self.color_func = label_color(all_labels)
            else:
                self.labels = None

            self.ax_next.set_visible(True)
            self.ax_prev.set_visible(True)

            self.update_i_sipke()

        else:
            self.spt = None
            self.ax_next.set_visible(False)
            self.ax_prev.set_visible(False)

    def set_data(self, data):
        '''
        Load a new recording, reset the spike/view state and redraw.

        ``data['data']`` is a 2-D array indexed ``[channel, sample]`` and
        ``data['FS']`` is its sampling rate. ``self.winsz`` (window width,
        presumably in ms given the ``/ 1000. * FS`` conversion) must already
        be set. NOTE(review): nothing visible in this class sets it; confirm
        where it comes from.
        '''
        self.x = data['data']
        self.FS = data['FS']
        n_chans, n_pts = self.x.shape

        # reset spike times data/hide buttons
        self.set_spiketimes(None)
        # set_spiketimes(None) only hides the buttons, and draw_plot skips
        # draw_spikes while spt is None. Without this, the previous data
        # set's spike overlay would stay on the axes.
        if self.spike_collection is not None:
            self.spike_collection.remove()
            self.spike_collection = None

        self.i_window = int(self.winsz / 1000. * self.FS)
        # Extents of data sequence:
        self.i_min = 0
        self.i_max = n_pts - self.i_window
        self.n_chans = n_chans

        self.window.set_scroll_max(self.i_max, self.i_window)

        # Keep the previous scroll position only while it still fits the new
        # data; otherwise the slice below would be shorter than the window.
        if self.i_max >= self.i_min:
            self.i_start = min(max(self.i_start, self.i_min), self.i_max)

        # Indices of data interval to be plotted:
        self.i_end = self.i_start + self.i_window
        curr_slice = self.x[:, self.i_start:self.i_end]
        ylims = (curr_slice.min(), curr_slice.max())
        offset = ylims[1] - ylims[0]

        self.ylims = np.array(ylims)
        # Channels are stacked vertically, one data range apart.
        self.offsets = np.arange(n_chans) * offset

        # will be filled in draw_plot
        self.segs = np.empty((n_chans, self.i_window, 2))

        if self.line_collection is not None:
            self.line_collection.remove()

        self.line_collection = LineCollection(self.segs,
                                              offsets=None,
                                              transform=self.axes.transData,
                                              color='k')

        self.axes.add_collection(self.line_collection)
        self.fancyyaxis.reset()

        self.draw_plot()

    def draw_plot(self):
        self.time = np.arange(self.i_start, self.i_end) * 1. / self.FS
        self.segs[:, :, 0] = self.time[np.newaxis, :]
        y_signal = self.x[:, self.i_start:self.i_end]
        y_signal = y_signal - np.mean(y_signal, 1)[:, None]
        self.segs[:, :, 1] = y_signal + self.offsets[:, np.newaxis]
        self.line_collection.set_segments(self.segs)

        # Adjust plot limits:
        self.axes.set_xlim((self.time[0], self.time[-1]))

        ygap = np.max(np.abs(self.ylims))
        self.axes.set_ylim((- ygap + self.offsets.min(),
                            ygap + self.offsets.max()))
        self.fancyyaxis.update()

        if self.spt is not None:
            self.draw_spikes()
        # Redraw:
        self.canvas.draw()


    def draw_spikes(self):
        '''
        Overlay the spikes that fall inside the current window.

        Every spike time in ``self.spt`` (ms) whose whole ``sp_win`` cut-out
        lies inside the window gets one segment per channel, copied from the
        traces in ``self.segs``. With labels set, colours come from
        ``self.color_func`` applied to the spike labels; otherwise every
        spike is red. Any previous overlay is removed first.
        '''
        if self.spike_collection is not None:
            self.spike_collection.remove()
            self.spike_collection = None
        sp_win = self.sp_win
        # Window time axis in ms (segs stores seconds).
        time = self.segs[0, :, 0] * 1000.
        # Keep only spikes whose [spt + sp_win[0], spt + sp_win[1]] cut-out
        # fits in the window.
        t_min, t_max = time[0] - sp_win[0], time[-1] - sp_win[1]
        spt = self.spt[(self.spt > t_min) & (self.spt < t_max)]
        if len(spt) > 0:
            # Samples per cut-out.
            n_pts = int((sp_win[1] - sp_win[0]) / 1000. * self.FS)
            # Indexed [spike, channel, sample, (x, y)].
            sp_segs = np.empty((len(spt), self.n_chans, n_pts, 2))
            for i in range(len(spt)):
                # First sample at or after the start of this spike's cut-out.
                start, = np.nonzero(time >= (spt[i] + sp_win[0]))
                start = start[0]
                stop = start + n_pts
                # x back in seconds to match the trace segments.
                sp_segs[i, :, :, 0] = (time[np.newaxis, start:stop] / 1000.)
                sp_segs[i, :, :, 1] = self.segs[:, start:stop, 1]
            # One segment per (spike, channel), spike-major.
            sp_segs = sp_segs.reshape(-1, n_pts, 2)
            if self.labels is not None:
                # Same mask as for spt, applied to the labels; assumes
                # self.labels is an ndarray aligned with self.spt and that
                # color_func returns one colour row per label (confirm).
                labs = self.labels[(self.spt > t_min) & (self.spt < t_max)]
                # Repeat each spike's colour once per channel to match the
                # spike-major segment order above.
                colors = np.repeat(self.color_func(labs), self.n_chans, 0)
            else:
                colors = 'r'
            self.spike_collection = LineCollection(sp_segs,
                                                   offsets=None,
                                                   color=colors,
                                                   transform=self.axes.transData)
            self.axes.add_collection(self.spike_collection)

    def scale_y(self, factor):
        self.ylims *= factor
        offset = self.ylims[1] - self.ylims[0]
        self.offsets = np.arange(self.n_chans) * offset

    def scale_x(self, factor):
        i_center = self.i_start + self.i_window / 2
        self.i_window = int(self.i_window * factor)
        self.i_start = i_center - self.i_window / 2
        self.i_start = self.i_start >= 0 and self.i_start or 0
        self.i_end = self.i_start + self.i_window

        self.segs = np.empty((self.n_chans, self.i_window, 2))


    def OnScrollEvt(self, pos):

        # Update the indices of the plot:
        self.i_start = self.i_min + pos
        self.i_end = self.i_min + self.i_window + pos

        self.update_i_sipke()
        self.draw_plot()

    def update_i_sipke(self):
        '''
        Point ``i_spike`` at the spike nearest the current view.

        Sets ``i_spike`` to the index of the last spike time strictly before
        the centre of the current window, or to 0 when there is none.
        ``self.spt`` must not be None (the comparison below would raise a
        TypeError). The misspelled name is kept because other methods call
        it.
        '''

        # Window centre converted from samples to ms, the unit of self.spt.
        t_center = (self.i_start + self.i_window / 2.) * 1000. / self.FS
        idx, = np.where(self.spt < t_center)
        if len(idx) > 0:
            self.i_spike = idx[-1]
        else:
            self.i_spike = 0