Esempio n. 1
0
def colorbar_legend(ax, values, cmap, vis=True):
    """
    Add a vertical colorbar legend to a plot.

    The bar is placed near the upper-left corner of ``ax``: 5% of the x range
    in from the left limit, spanning from 25% to 5% of the y range below the
    top limit. It is drawn as a stack of short, thick line segments coloured
    with ``cmap``, and the minimum and maximum of ``values`` are written next
    to its bottom and top ends.

    Args:
        ax: Matplotlib axes to draw into.
        values: Non-empty sequence of numbers; only its min and max are used.
        cmap: Colormap used to colour the bar.
        vis (bool): Initial visibility of the bar and its two labels.
    """
    xlo, xhi = ax.get_xlim()
    ylo, yhi = ax.get_ylim()
    x_range = xhi - xlo
    y_range = yhi - ylo

    x = [xlo + x_range*0.05]
    y = [yhi - (y_range*0.25), yhi - (y_range*0.05)]

    # Evaluate the data range once instead of on every loop iteration.
    vmin = min(values)
    vmax = max(values)

    nseg = 256
    step = (y[1] - y[0]) / float(nseg)
    # Segment k spans [y0 + k*step, y0 + (k+1)*step], so the bar covers
    # y[0]..y[1] exactly and lines up with the min/max labels.
    segs = [((x[0], y[0] + step*k), (x[0], y[0] + step*(k + 1)))
            for k in range(nseg)]
    # Values run from vmin (bottom) to vmax (top), so the autoscaled colour
    # mapping of the bar matches a plot normalised over ``values``.
    vals = [vmin + (vmax - vmin) * k / float(nseg - 1) for k in range(nseg)]

    lcbar = LineCollection(segs, cmap=cmap, lw=15)
    lcbar.set_visible(vis)
    lcbar.set_array(np.array(vals))
    ax.add_collection(lcbar)
    lcbar.set_zorder(1)

    # Format with significant digits rather than slicing str(): slicing
    # mangles large numbers (1234567.0 -> "123456") and scientific notation
    # (1.5e-07 -> "1.5e-0").
    minlab = "{:.4g}".format(vmin)
    maxlab = "{:.4g}".format(vmax)

    ax.text(x[0] + x_range*.02, y[0], minlab, verticalalignment="bottom", visible=vis)
    ax.text(x[0] + x_range*.02, y[1], maxlab, verticalalignment="top", visible=vis)
Esempio n. 2
0
def colorbar_legend(ax, values, cmap, vis=True):
    """
    Add a vertical colorbar legend to a plot.

    The bar is placed near the upper-left corner of ``ax``: 5% of the x range
    in from the left limit, spanning from 25% to 5% of the y range below the
    top limit. It is drawn as a stack of short, thick line segments coloured
    with ``cmap``, and the minimum and maximum of ``values`` are written next
    to its bottom and top ends.

    Args:
        ax: Matplotlib axes to draw into.
        values: Non-empty sequence of numbers; only its min and max are used.
        cmap: Colormap used to colour the bar.
        vis (bool): Initial visibility of the bar and its two labels.
    """
    xlo, xhi = ax.get_xlim()
    ylo, yhi = ax.get_ylim()
    x_range = xhi - xlo
    y_range = yhi - ylo

    x = [xlo + x_range*0.05]
    y = [yhi - (y_range*0.25), yhi - (y_range*0.05)]

    # Evaluate the data range once instead of on every loop iteration.
    vmin = min(values)
    vmax = max(values)

    nseg = 256
    step = (y[1] - y[0]) / float(nseg)
    # Segment k spans [y0 + k*step, y0 + (k+1)*step], so the bar covers
    # y[0]..y[1] exactly and lines up with the min/max labels.
    segs = [((x[0], y[0] + step*k), (x[0], y[0] + step*(k + 1)))
            for k in range(nseg)]
    # Values run from vmin (bottom) to vmax (top), so the autoscaled colour
    # mapping of the bar matches a plot normalised over ``values``.
    vals = [vmin + (vmax - vmin) * k / float(nseg - 1) for k in range(nseg)]

    lcbar = LineCollection(segs, cmap=cmap, lw=15)
    lcbar.set_visible(vis)
    lcbar.set_array(np.array(vals))
    ax.add_collection(lcbar)
    lcbar.set_zorder(1)

    # Format with significant digits rather than slicing str(): slicing
    # mangles large numbers (1234567.0 -> "123456") and scientific notation
    # (1.5e-07 -> "1.5e-0").
    minlab = "{:.4g}".format(vmin)
    maxlab = "{:.4g}".format(vmax)

    ax.text(x[0] + x_range*.02, y[0], minlab, verticalalignment="bottom", visible=vis)
    ax.text(x[0] + x_range*.02, y[1], maxlab, verticalalignment="top", visible=vis)
Esempio n. 3
0
def add_hrm_hiddenstate_recon(treeplot, liks, nregime, vis=True, width=2, colors=None):
    """
    Color branches based on likelihood of being in hidden states.

    Every non-root branch is drawn as a horizontal colour gradient between
    the parent's and the node's colours, plus a vertical connector in the
    parent's colour. A node's colour blends ``colors[0]`` and ``colors[1]``
    (via ``color_map``) according to its summed likelihood of being in the
    first regime.

    Args:
        treeplot: Tree plot providing ``root``, ``n2c``, ``add_collection``
          and ``figure``.
        liks (np.array): The output of anc_recon run as a hidden-rates
          reconstruction. The last two columns are indices; the remaining
          ``nchar`` columns are state likelihoods grouped by regime.
        nregime (int): Number of hidden-rate regimes.
        vis (bool): Initial visibility of the drawn branches.
        width (float): Line width of the branches.
        colors (list): Regime colours; only the first two are used for the
          blend. Defaults to the next ``nregime`` values drawn from ``_tango``.
    """
    root = treeplot.root
    nbranch = len(root) - 1  # one slot per descendant of the root
    horz_seg_collections = [None] * nbranch
    horz_seg_colors = [None] * nbranch
    vert_seg_collections = [None] * nbranch
    vert_seg_colors = [None] * nbranch

    # liks has nchar + 2 columns (the last two columns are indices).
    nchar = liks.shape[1] - 2
    # Number of likelihood columns per regime; regime 1 is the first block.
    # (Previously hard-coded as nchar // 2, which is only right for 2 regimes.)
    nobschar = nchar // nregime

    if colors is None:
        c = _tango
        colors = [next(c) for _ in range(nregime)]

    for i, n in enumerate(root.descendants()):
        # Add 1 to the index because the loop is skipping the root.
        # NOTE(review): assumes descendants() yields nodes in row order of liks.
        n_lik = liks[i + 1]
        par_lik = liks[n.parent.ni]
        n_r1 = sum(n_lik[:nobschar])  # Likelihood of node being in regime 1
        p_r1 = sum(par_lik[:nobschar])  # Likelihood of parent being in regime 1
        n_col = color_map(n_r1, colors[0], colors[1])
        par_col = color_map(p_r1, colors[0], colors[1])

        n_coords = treeplot.n2c[n]
        par_coords = treeplot.n2c[n.parent]

        # Horizontal branch: from the node back to the parent's x position.
        p1 = (n_coords.x, n_coords.y)
        p2 = (par_coords.x, n_coords.y)

        hsegs, hcols = gradient_segment_horz(p1, p2, n_col, par_col)

        horz_seg_collections[i] = hsegs
        horz_seg_colors[i] = hcols

        # Vertical connector at the parent's x, in the parent's colour.
        vert_seg_collections[i] = [(par_coords.x, par_coords.y),
                                   (par_coords.x, n_coords.y)]
        vert_seg_colors[i] = par_col

    # Flatten the per-branch gradient pieces into one list each.
    horz_seg_collections = [seg for group in horz_seg_collections for seg in group]
    horz_seg_colors = [col for group in horz_seg_colors for col in group]
    lc = LineCollection(horz_seg_collections + vert_seg_collections,
                        colors=horz_seg_colors + vert_seg_colors,
                        lw=width)
    lc.set_visible(vis)
    treeplot.add_collection(lc)

    treeplot.figure.canvas.draw_idle()
Esempio n. 4
0
def add_ancrecon_hrm(treeplot, liks, vis=True, width=2):
    """
    Color branches on tree based on likelihood of being in a state
    based on ancestral state reconstruction of a two-character, two-regime
    hrm model.

    Every non-root branch is drawn as a horizontal colour gradient between
    the parent's and the node's colours (from ``twoS_twoR_colormaker``) plus
    a vertical connector in the parent's colour. A small 2x2 legend image of
    the four pure-state colours is added to the figure.

    Args:
        treeplot: Tree plot providing ``root``, ``n2c``, ``add_collection``
          and ``figure``.
        liks: Likelihood array whose last two columns are not likelihoods;
          row ``i + 1`` is used for the i-th descendant and row
          ``parent.ni`` for its parent.
        vis (bool): Initial visibility of the branches and of the legend.
        width (float): Line width of the branches.
    """
    root = treeplot.root
    nbranch = len(root) - 1  # one slot per descendant of the root
    horz_seg_collections = [None] * nbranch
    horz_seg_colors = [None] * nbranch
    vert_seg_collections = [None] * nbranch
    vert_seg_colors = [None] * nbranch

    nchar = liks.shape[1] - 2  # the last two columns are not likelihoods

    for i, n in enumerate(root.descendants()):
        n_lik = liks[i + 1]  # +1 skips the root's row
        par_lik = liks[n.parent.ni]
        n_col = twoS_twoR_colormaker(n_lik[:nchar])
        par_col = twoS_twoR_colormaker(par_lik[:nchar])

        n_coords = treeplot.n2c[n]
        par_coords = treeplot.n2c[n.parent]

        p1 = (n_coords.x, n_coords.y)
        p2 = (par_coords.x, n_coords.y)

        hsegs, hcols = gradient_segment_horz(p1, p2, n_col.rgb, par_col.rgb)

        horz_seg_collections[i] = hsegs
        horz_seg_colors[i] = hcols

        vert_seg_collections[i] = [(par_coords.x, par_coords.y),
                                   (par_coords.x, n_coords.y)]
        vert_seg_colors[i] = par_col.rgb
    horz_seg_collections = [seg for group in horz_seg_collections for seg in group]
    horz_seg_colors = [col for group in horz_seg_colors for col in group]
    lc = LineCollection(horz_seg_collections + vert_seg_collections,
                        colors=horz_seg_colors + vert_seg_colors,
                        lw=width)
    lc.set_visible(vis)
    treeplot.add_collection(lc)

    leg_ax = treeplot.figure.add_axes([0.3, 0.8, 0.1, 0.1])
    # Hide all ticks and tick labels on the legend axes. Booleans are the
    # documented argument type; the old "on"/"off" strings are deprecated.
    leg_ax.tick_params(which="both",
                       bottom=False,
                       labelbottom=False,
                       top=False,
                       left=False,
                       labelleft=False,
                       right=False)
    # Keep the legend's visibility in step with the branches it describes.
    leg_ax.set_visible(vis)

    # Pure colour of each of the four states, laid out as a 2x2 image.
    c1 = twoS_twoR_colormaker([1, 0, 0, 0])
    c2 = twoS_twoR_colormaker([0, 1, 0, 0])
    c3 = twoS_twoR_colormaker([0, 0, 1, 0])
    c4 = twoS_twoR_colormaker([0, 0, 0, 1])

    grid = np.array([[c1.rgb, c2.rgb], [c3.rgb, c4.rgb]])
    leg_ax.imshow(grid, interpolation="bicubic")
    treeplot.figure.canvas.draw_idle()
Esempio n. 5
0
def add_mkmr_heatmap(treeplot, locations, vis=True, seglen=0.02):
    """
    Heatmap that shows which portions of the tree are most likely
    to contain a switchpoint

    To be used with the output from mk_multi_bayes.

    Every segment of the tree is coloured (RdYlBu) by the fraction of the
    samples in ``locations`` that placed the switchpoint on that segment.

    Args:
        locations (list): List of lists containing node and distance.
          The output from the switchpoint stochastic of mk_multi_bayes.
        vis (bool): Initial visibility of the heatmap.
        seglen (float): The size of segments to break the tree into.
          MUST BE the same as the seglen used in mk_multi_bayes.
    """
    treelen = treeplot.root.max_tippath()
    seglen_px = seglen*treelen
    # Map sampled nodes onto this plot's tree by node index, and round the
    # distances identically on both sides so the (node, distance) keys match.
    locations = [(treeplot.root[x[0].ni], round(x[1], 7)) for x in locations]
    # All possible segments to plot
    segmap = [(x[0], round(x[1], 7))
              for x in ivy.chars.mk_mr.tree_map(treeplot.root, seglen=seglen)]

    # Heatmap density of a segment = share of samples landing on it. Counter
    # yields 0 for segments never sampled; float() keeps the division exact
    # under Python 2, and max(..., 1) avoids dividing by zero on empty input.
    counts = Counter(locations)
    nrep = float(max(len(locations), 1))

    segments = []
    values = []
    # TODO: radial plot type

    for s in segmap:
        node, dist = s
        density = counts[s] / nrep
        c = xy(treeplot, node)
        cp = xy(treeplot, node.parent)

        x0 = c[0] - dist  # Start of segment
        x1 = max(x0 - seglen_px, cp[0])  # End of segment, never past the parent
        y0 = y1 = c[1]
        segments.append(((x0, y0), (x1, y1)))
        values.append(density)

        if dist == 0.0 and not node.isleaf:
            # The segment touching an internal node also gets the vertical
            # line spanning that node's first to last child. (This append used
            # to sit outside the if, duplicating every horizontal segment.)
            x0 = x1 = c[0]
            y0 = xy(treeplot, node.children[0])[1]
            y1 = xy(treeplot, node.children[-1])[1]
            segments.append(((x0, y0), (x1, y1)))
            values.append(density)

    lc = LineCollection(segments, cmap=RdYlBu, lw=2)
    lc.set_array(np.array(values))
    treeplot.add_collection(lc)
    lc.set_zorder(1)
    lc.set_visible(vis)
    treeplot.figure.canvas.draw_idle()
Esempio n. 6
0
def add_ancrecon(treeplot, liks, vis=True, width=2):
    """
    Plot ancestor reconstruction for a binary mk model.

    Each non-root node's branch becomes a horizontal colour gradient running
    from the node's colour to its parent's colour, plus a vertical connector
    drawn in the parent's colour. Colours come from ``twoS_colormaker``
    applied to every column of ``liks`` except the last two.

    Args:
        treeplot: Tree plot providing ``root``, ``n2c``, ``add_collection``
          and ``figure``.
        liks: Likelihood array; row ``idx + 1`` is read for the idx-th
          descendant and row ``parent.ni`` for its parent.
        vis (bool): Initial visibility of the branch collection.
        width (float): Line width of the branches.
    """
    tree_root = treeplot.root
    n_branches = len(tree_root) - 1
    horiz_groups = [None] * n_branches
    horiz_color_groups = [None] * n_branches
    vert_lines = [None] * n_branches
    vert_colors = [None] * n_branches

    n_states = liks.shape[1] - 2
    for idx, node in enumerate(tree_root.descendants()):
        parent = node.parent
        node_color = twoS_colormaker(liks[idx + 1][:n_states])
        parent_color = twoS_colormaker(liks[parent.ni][:n_states])

        node_xy = treeplot.n2c[node]
        parent_xy = treeplot.n2c[parent]

        seg_list, color_list = gradient_segment_horz(
            (node_xy.x, node_xy.y),
            (parent_xy.x, node_xy.y),
            node_color.rgb,
            parent_color.rgb,
        )
        horiz_groups[idx] = seg_list
        horiz_color_groups[idx] = color_list
        vert_lines[idx] = [(parent_xy.x, parent_xy.y), (parent_xy.x, node_xy.y)]
        vert_colors[idx] = parent_color.rgb

    # Flatten the per-branch gradient pieces.
    flat_segments = []
    for group in horiz_groups:
        flat_segments.extend(group)
    flat_colors = []
    for group in horiz_color_groups:
        flat_colors.extend(group)

    collection = LineCollection(flat_segments + vert_lines,
                                colors=flat_colors + vert_colors,
                                lw=width)
    collection.set_visible(vis)
    treeplot.add_collection(collection)

    treeplot.figure.canvas.draw_idle()
Esempio n. 7
0
class Anim:
    """Blitted, keyboard-controlled animation of eddy contours over time.

    At each time step the contours whose ``time`` equals the current step are
    added to a trail covering the last ``nb_step`` steps; older contours are
    drawn thinner. Keys (see :meth:`keyboard`): escape quits, space toggles
    pause, ``+``/``-`` change speed, right/left step while paused.
    """

    def __init__(
        self, eddy, intern=False, sleep_event=0.1, graphic_information=False, **kwargs
    ):
        """
        Args:
            eddy: Eddy observations providing ``intern``, ``time``, ``period``
              and the ``lon``, ``lat`` and ``track`` fields.
            intern (bool): Passed to ``eddy.intern`` to pick the contour
              coordinate fields that are drawn.
            sleep_event (float): Target time between frames, in seconds.
            graphic_information (bool): Append the frame rate to the date text.
            **kwargs: Forwarded to :meth:`setup`.
        """
        self.eddy = eddy
        x_name, y_name = eddy.intern(intern)
        self.t, self.x, self.y = eddy.time, eddy[x_name], eddy[y_name]
        self.x_core, self.y_core, self.track = eddy["lon"], eddy["lat"], eddy["track"]
        self.graphic_informations = graphic_information
        self.pause = False
        self.period = self.eddy.period
        self.sleep_event = sleep_event
        self.mappables = list()
        self.field_color = None
        self.field_txt = None
        self.time_field = False
        self.setup(**kwargs)

    def setup(
        self,
        cmap="jet",
        lut=None,
        field_color="time",
        field_txt="track",
        range_color=(None, None),
        nb_step=25,
        figsize=(8, 6),
        **kwargs,
    ):
        """Prepare colours, the figure and the (initially empty) artists.

        Args:
            cmap (str): Colormap name passed to ``pyplot.get_cmap``.
            lut (int): Number of colormap entries, passed to ``get_cmap``.
            field_color (str): Eddy field used to colour contours. With
              ``"time"`` and no ``range_color`` bound, contours are coloured
              by their age within the trail.
            field_txt (str): Eddy field written at each eddy centre.
            range_color (tuple): ``(min, max)`` used to rescale
              ``field_color`` to [0, 1]; a ``None`` bound defaults to the
              field's min/max.
            nb_step (int): Length of the trail, in time steps.
            figsize (tuple): Figure size passed to ``pyplot.figure``.
            **kwargs: Extra arguments for ``pyplot.figure``.
        """
        self.field_color = self.eddy[field_color].astype("f4")
        self.field_txt = self.eddy[field_txt]
        rg = range_color
        if rg[0] is None and rg[1] is None and field_color == "time":
            # Colour by age; the values are rescaled per frame in update().
            self.time_field = True
        else:
            rg = (
                self.field_color.min() if rg[0] is None else rg[0],
                self.field_color.max() if rg[1] is None else rg[1],
            )
            self.field_color = (self.field_color - rg[0]) / (rg[1] - rg[0])

        self.colors = pyplot.get_cmap(cmap, lut=lut)
        self.nb_step = nb_step

        # Map extent: eddy centres plus a 2-unit margin on every side.
        x_min, x_max = self.x_core.min() - 2, self.x_core.max() + 2
        d_x = x_max - x_min
        y_min, y_max = self.y_core.min() - 2, self.y_core.max() + 2
        d_y = y_max - y_min
        # plot
        self.fig = pyplot.figure(figsize=figsize, **kwargs)
        t0, t1 = self.period
        self.fig.suptitle(f"{t0} -> {t1}")
        self.ax = self.fig.add_axes((0.05, 0.05, 0.9, 0.9), projection="full_axes")
        self.ax.set_xlim(x_min, x_max), self.ax.set_ylim(y_min, y_max)
        self.ax.set_aspect("equal")
        self.ax.grid()
        # init mappable
        self.txt = self.ax.text(x_min + 0.05 * d_x, y_min + 0.05 * d_y, "", zorder=10)
        # Trail buffers: one entry per drawn time step (see update()).
        self.segs = list()
        self.t_segs = list()
        self.c_segs = list()
        self.contour = LineCollection([], zorder=1)
        self.ax.add_collection(self.contour)

        self.fig.canvas.draw()
        self.fig.canvas.mpl_connect("key_press_event", self.keyboard)
        self.fig.canvas.mpl_connect("resize_event", self.reset_bliting)

    def reset_bliting(self, event):
        """Re-capture the background without animated artists (on resize)."""
        self.contour.set_visible(False)
        self.txt.set_visible(False)
        for m in self.mappables:
            m.set_visible(False)
        self.fig.canvas.draw()
        self.bg_cache = self.fig.canvas.copy_from_bbox(self.ax.bbox)
        self.contour.set_visible(True)
        self.txt.set_visible(True)
        for m in self.mappables:
            m.set_visible(True)

    def show(self, infinity_loop=False):
        """Display the figure and play the animation from ``period[0]``.

        Args:
            infinity_loop (bool): Restart from the first step after passing
              the last one, waiting 0.5 s between runs.
        """
        pyplot.show(block=False)
        # save background for future bliting
        self.fig.canvas.draw()
        self.bg_cache = self.fig.canvas.copy_from_bbox(self.ax.bbox)
        loop = True
        t0, t1 = self.period
        while loop:
            self.now = t0
            while True:
                dt = self.sleep_event
                if not self.pause:
                    d0 = datetime.now()
                    self.next()
                    # Subtract the drawing time to keep a steady frame rate.
                    dt_draw = (datetime.now() - d0).total_seconds()
                    dt = self.sleep_event - dt_draw
                    if dt < 0:
                        # self.sleep_event = dt_draw * 1.01
                        dt = 1e-10
                # A 0 timeout would make start_event_loop wait indefinitely.
                if dt == 0:
                    dt = 1e-10
                self.fig.canvas.start_event_loop(dt)
                if self.now > t1:
                    break
            if infinity_loop:
                self.fig.canvas.start_event_loop(0.5)
            else:
                loop = False

    def next(self):
        """Advance one time step and redraw."""
        self.now += 1
        return self.draw_contour()

    def prev(self):
        """Go back one time step and redraw."""
        self.now -= 1
        return self.draw_contour()

    def func_animation(self, frame):
        """Draw time step ``frame`` and return the artists to blit.

        Presumably intended as a ``matplotlib.animation.FuncAnimation``
        callback — confirm against callers.
        """
        while self.mappables:
            self.mappables.pop().remove()
        self.now = frame
        self.update()
        artists = [self.contour, self.txt]
        artists.extend(self.mappables)
        return artists

    def update(self):
        """Append time step ``self.now`` to the trail and draw the artists.

        One entry is appended to ``segs``/``c_segs``/``t_segs`` on every call,
        even when no contour exists at this step, so the trail length and the
        one-step rewind in :meth:`keyboard` are counted in time steps. (Skipping
        empty steps let stale contours outlive ``nb_step`` steps, giving them
        negative line widths and out-of-range age colours.)
        """
        m = self.t == self.now
        idx = where(m)[0]
        segs = list()
        t = list()
        c = list()
        for i in idx:
            segs.append(create_vertice(self.x[i], self.y[i]))
            c.append(self.field_color[i])
            t.append(self.now)
        self.segs.append(segs)
        self.c_segs.append(c)
        self.t_segs.append(t)
        self.contour.set_paths(chain(*self.segs))
        if self.time_field:
            # Age colour: contours from the current step map to 1, and the
            # value drops by 1 / nb_step per step of age.
            self.contour.set_color(
                self.colors(
                    [
                        (self.nb_step - self.now + i) / self.nb_step
                        for i in chain(*self.c_segs)
                    ]
                )
            )
        else:
            self.contour.set_color(self.colors(list(chain(*self.c_segs))))
        # Line width shrinks linearly with age; future steps get 0.
        self.contour.set_lw(
            [
                (1 - (self.now - i) / self.nb_step) * 2.5 if i <= self.now else 0
                for i in chain(*self.t_segs)
            ]
        )
        # Update date txt and info (time is counted in days since 1950-01-01)
        txt = f"{(timedelta(int(self.now)) + datetime(1950,1,1)).strftime('%Y/%m/%d')}"
        if self.graphic_informations:
            txt += f"- {1/self.sleep_event:.0f} frame/s"
        self.txt.set_text(txt)
        # Update id txt
        for i in idx:
            mappable = self.ax.text(
                self.x_core[i],
                self.y_core[i],
                self.field_txt[i],
                fontsize=12,
                fontweight="demibold",
            )
            self.mappables.append(mappable)
            self.ax.draw_artist(mappable)
        self.ax.draw_artist(self.contour)
        self.ax.draw_artist(self.txt)
        # Remove first segment to keep only T contour
        if len(self.segs) > self.nb_step:
            self.segs.pop(0)
            self.t_segs.pop(0)
            self.c_segs.pop(0)

    def draw_contour(self):
        """Restore the cached background, draw the current step and blit."""
        # select contour for this time step
        while self.mappables:
            self.mappables.pop().remove()
        self.ax.figure.canvas.restore_region(self.bg_cache)
        self.update()
        # paint updated artist
        self.ax.figure.canvas.blit(self.ax.bbox)

    def keyboard(self, event):
        """Handle key presses.

        escape: exit; space: toggle pause; ``+``/``-``: shorten/lengthen the
        frame delay by 10%; right/left (only while paused): step forward/back.
        """
        if event.key == "escape":
            exit()
        elif event.key == " ":
            self.pause = not self.pause
        elif event.key == "+":
            self.sleep_event *= 0.9
        elif event.key == "-":
            self.sleep_event *= 1.1
        elif event.key == "right" and self.pause:
            self.next()
        elif event.key == "left" and self.pause:
            # we remove 2 step to add 1 so we rewind of only one
            # (prev() re-appends the previous step). Nothing to rewind to
            # with fewer than two recorded steps.
            if len(self.segs) < 2:
                return
            self.segs.pop(-1)
            self.segs.pop(-1)
            self.t_segs.pop(-1)
            self.t_segs.pop(-1)
            self.c_segs.pop(-1)
            self.c_segs.pop(-1)
            self.prev()
Esempio n. 8
0
def add_phylorate(treeplot, rates, nodeidx, vis=True):
    """
    Add phylorate plot generated from data analyzed with BAMM
    (http://bamm-project.org/introduction.html)

    Each branch is split into one segment per rate and coloured with RdYlBu
    on a scale spanning the smallest to the largest rate; in radial plots the
    connecting arcs use the same scale. A colorbar legend is also added.

    Args:
        rates (array): Array of rates along branches
          created by r_funcs.phylorate
        nodeidx (array): Array of node indices matching rates (also created
          by r_funcs.phylorate)
        vis (bool): Initial visibility of the rates and legend.

    WARNING:
        Ladderizing the tree can cause incorrect assignment of Ape node index
        numbers. To prevent this, call this function or root.ape_node_idx()
        before ladderizing the tree to assign correct Ape node index numbers.
    """
    if not treeplot.root.apeidx:
        treeplot.root.ape_node_idx()
    segments = []
    values = []
    radial = treeplot.plottype == "radial"
    # Radial plots only: (vertices, codes, rate) of the arc joining each
    # branch to its parent. Arcs are coloured once the full rate range is known.
    arcs = []

    for n in treeplot.root.descendants():
        n.rates = rates[nodeidx==n.apeidx]
        c = treeplot.n2c[n]
        if radial:
            path = treeplot._path_to_parent(n)  # computed once per node
            verts, codes = path[0], path[1]
            pc = verts[1]
            xseg = (c.x - pc[0])/len(n.rates)
            yseg = (c.y - pc[1])/len(n.rates)
            for i, rate in enumerate(n.rates):
                x0 = pc[0] + i*xseg
                y0 = pc[1] + i*yseg
                segments.append(((x0, y0), (x0 + xseg, y0 + yseg)))
                values.append(rate)
            arcs.append((verts[2:], codes[2:], n.rates[0]))
        else:
            pc = treeplot.n2c[n.parent]
            seglen = (c.x-pc.x)/len(n.rates)
            for i, rate in enumerate(n.rates):
                x0 = pc.x + i*seglen
                segments.append(((x0, c.y), (x0 + seglen, c.y)))
                values.append(rate)
            # Vertical connector at the parent's x, coloured by the first rate.
            segments.append(((pc.x, pc.y), (pc.x, c.y)))
            values.append(n.rates[0])

    lc = LineCollection(segments, cmap=RdYlBu, lw=2)
    lc.set_array(np.array(values))
    # Fix the colour limits to the rate range now (the same limits drawing
    # would pick) so the arcs below can be coloured on the segments' scale.
    lc.autoscale_None()
    treeplot.add_collection(lc)
    lc.set_zorder(1)
    if radial:
        # Colour each arc through lc's norm and colormap; feeding a raw rate
        # straight into RdYlBu ignored the normalisation of the segments.
        radpatches = [PathPatch(Path(verts, codes), lw=2,
                                edgecolor=lc.to_rgba(rate), fill=False)
                      for verts, codes, rate in arcs]
        arccol = matplotlib.collections.PatchCollection(radpatches,
                                                        match_original=True)
        treeplot.add_collection(arccol)
        arccol.set_visible(vis)
        arccol.set_zorder(1)
    lc.set_visible(vis)
    colorbar_legend(treeplot, values, RdYlBu, vis=vis)

    treeplot.figure.canvas.draw_idle()
Esempio n. 9
0
class HoughDemo(ImageProcessDemo):
    """Interactive demo of Hough line and circle detection with OpenCV.

    The grey image is Gaussian-blurred and Canny edges are extracted; lines
    found by ``cv2.HoughLinesP`` on the edges and circles found by
    ``cv2.HoughCircles`` on the blurred image are overlaid on the display.
    """
    TITLE = u"Hough Demo"
    DEFAULT_IMAGE = "stuff.jpg"
    SETTINGS = ["th2", "show_canny", "rho", "theta", "hough_th",
                "minlen", "maxgap", "dp", "mindist", "param2",
                "min_radius", "max_radius", "blur_sigma",
                "linewidth", "alpha", "check_line", "check_circle"]

    # Toggle the line / circle overlays.
    check_line = Bool(True)
    check_circle = Bool(True)

    # Gaussian blur parameters
    blur_sigma = Range(0.1, 5.0, 2.0)
    show_blur = Bool(False)

    # Canny parameters: th2 is the upper threshold and th2 / 2 the lower one
    # (th2 is also passed to HoughCircles as param1).
    th2 = Range(0.0, 255.0, 200.0)
    show_canny = Bool(False)

    # HoughLinesP parameters (theta is in degrees; draw() converts it).
    rho = Range(1.0, 10.0, 1.0)
    theta = Range(0.1, 5.0, 1.0)
    hough_th = Range(1, 100, 40)
    minlen = Range(0, 100, 10)
    maxgap = Range(0, 20, 10)

    # HoughCircles parameters
    dp = Range(1.0, 5.0, 1.9)
    mindist = Range(1.0, 100.0, 50.0)
    param2 = Range(5, 100, 50)
    min_radius = Range(5, 100, 20)
    max_radius = Range(10, 100, 70)

    # Drawing parameters
    linewidth = Range(1.0, 3.0, 1.0)
    alpha = Range(0.0, 1.0, 0.6)

    def control_panel(self):
        """Build the settings panel (labels are in Chinese)."""
        return VGroup(
            # Gaussian blur: standard deviation, show result
            Group(
                Item("blur_sigma", label=u"标准方差"),
                Item("show_blur", label=u"显示结果"),
                label=u"高斯模糊参数"
            ),
            # Edge detection: threshold 2, show result
            Group(
                Item("th2", label=u"阈值2"),
                Item("show_canny", label=u"显示结果"),
                label=u"边缘检测参数"
            ),
            # Line detection: offset resolution (px), angle resolution (deg),
            # threshold, minimum length, maximum gap
            Group(
                Item("rho", label=u"偏移分辨率(像素)"),
                Item("theta", label=u"角度分辨率(角度)"),
                Item("hough_th", label=u"阈值"),
                Item("minlen", label=u"最小长度"),
                Item("maxgap", label=u"最大空隙"),
                label=u"直线检测"
            ),
            # Circle detection: resolution (px), minimum centre distance (px),
            # centre detection threshold, minimum radius, maximum radius
            Group(
                Item("dp", label=u"分辨率(像素)"),
                Item("mindist", label=u"圆心最小距离(像素)"),
                Item("param2", label=u"圆心检查阈值"),
                Item("min_radius", label=u"最小半径"),
                Item("max_radius", label=u"最大半径"),
                label=u"圆检测"
            ),
            # Drawing: line width, alpha, show lines / circles
            Group(
                Item("linewidth", label=u"线宽"),
                Item("alpha", label=u"alpha"),
                HGroup(
                    Item("check_line", label=u"直线"),
                    Item("check_circle", label=u"圆"),
                ),
                label=u"绘图参数"
            )
        )

    def __init__(self, **kwargs):
        super(HoughDemo, self).__init__(**kwargs)
        # Presumably schedules a redraw when any of these traits changes.
        self.connect_dirty("th2, show_canny, show_blur, rho, theta, hough_th,"
                            "min_radius, max_radius, blur_sigma,"
                           "minlen, maxgap, dp, mindist, param2, "
                           "linewidth, alpha, check_line, check_circle")
        self.lines = LineCollection([], linewidths=2, alpha=0.6)
        self.axe.add_collection(self.lines)

        self.circles = EllipseCollection(
            [], [], [],
            units="xy",
            facecolors="none",
            edgecolors="red",
            linewidths=2,
            alpha=0.6,
            transOffset=self.axe.transData)

        self.axe.add_collection(self.circles)

    def _img_changed(self):
        """Keep a grayscale copy of ``img`` for the detectors."""
        self.img_gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)

    def draw(self):
        """Run blur, Canny and the Hough transforms, update overlays, redraw."""
        img_smooth = cv2.GaussianBlur(self.img_gray, (0, 0), self.blur_sigma, self.blur_sigma)
        img_edge = cv2.Canny(img_smooth, self.th2 * 0.5, self.th2)

        # The blurred and edge images are single-channel grey images, so they
        # are expanded with COLOR_GRAY2BGR; a Bayer conversion would demosaic
        # them into a false-colour pattern.
        if self.show_blur and self.show_canny:
            show_img = cv2.cvtColor(np.maximum(img_smooth, img_edge), cv2.COLOR_GRAY2BGR)
        elif self.show_blur:
            show_img = cv2.cvtColor(img_smooth, cv2.COLOR_GRAY2BGR)
        elif self.show_canny:
            show_img = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2BGR)
        else:
            show_img = self.img

        if self.check_line:
            theta = self.theta / 180.0 * np.pi
            lines = cv2.HoughLinesP(img_edge,
                                    self.rho, theta, self.hough_th,
                                    minLineLength=self.minlen,
                                    maxLineGap=self.maxgap)

            if lines is not None:
                # OpenCV 2.x returns shape (1, N, 4) and 3.x+ returns
                # (N, 1, 4); reshaping to (N, 2, 2) keeps every line in both
                # cases (taking lines[0] kept only one line on OpenCV 3+).
                self.lines.set_segments(lines.reshape(-1, 2, 2))
                self.lines.set_visible(True)
            else:
                self.lines.set_visible(False)
        else:
            self.lines.set_visible(False)

        if self.check_circle:
            # 3 is the value of the Hough gradient method constant.
            hough_gradient = getattr(cv2, "HOUGH_GRADIENT", 3)
            circles = cv2.HoughCircles(img_smooth, hough_gradient,
                                       self.dp, self.mindist,
                                       param1=self.th2,
                                       param2=self.param2,
                                       minRadius=self.min_radius,
                                       maxRadius=self.max_radius)

            if circles is not None:
                # circles has shape (1, N, 3): rows of (x, y, radius).
                circles = circles[0]
                # Private attributes store half-axes, i.e. the radius here.
                self.circles._heights = self.circles._widths = circles[:, 2]
                self.circles.set_offsets(circles[:, :2])
                self.circles._angles = np.zeros(len(circles))
                self.circles._transOffset = self.axe.transData
                self.circles.set_visible(True)
            else:
                self.circles.set_visible(False)
        else:
            self.circles.set_visible(False)

        self.lines.set_linewidths(self.linewidth)
        self.circles.set_linewidths(self.linewidth)
        self.lines.set_alpha(self.alpha)
        self.circles.set_alpha(self.alpha)

        self.draw_image(show_img)
Esempio n. 10
0
class Anim:
    """Blitted, keyboard-controlled animation of eddy contours over time.

    At each time step the contours whose ``time`` equals the current step are
    added to a trail covering the last ``nb_step`` steps; older steps are
    drawn thinner and with earlier colormap colours. Keys (see
    :meth:`keyboard`): escape quits, space toggles pause, ``+``/``-`` change
    speed, right/left step while paused.
    """

    def __init__(self,
                 eddy,
                 intern=False,
                 sleep_event=0.1,
                 graphic_information=False,
                 **kwargs):
        """
        Args:
            eddy: Eddy observations providing ``intern``, ``time``, ``period``
              and the ``lon``, ``lat`` and ``track`` fields.
            intern (bool): Passed to ``eddy.intern`` to pick the contour
              coordinate fields that are drawn.
            sleep_event (float): Target time between frames, in seconds.
            graphic_information (bool): Append the frame rate to the step text.
            **kwargs: Forwarded to :meth:`setup`.
        """
        self.eddy = eddy
        x_name, y_name = eddy.intern(intern)
        self.t, self.x, self.y = eddy.time, eddy[x_name], eddy[y_name]
        self.x_core, self.y_core, self.track = eddy["lon"], eddy["lat"], eddy[
            "track"]
        self.graphic_informations = graphic_information
        self.pause = False
        self.period = self.eddy.period
        self.sleep_event = sleep_event
        self.mappables = list()
        self.setup(**kwargs)

    def setup(self, cmap="jet", nb_step=25, figsize=(8, 6), **kwargs):
        """Prepare colours, the figure and the (initially empty) artists.

        Args:
            cmap (str): Colormap name passed to ``pyplot.get_cmap``.
            nb_step (int): Length of the trail, in time steps.
            figsize (tuple): Figure size passed to ``pyplot.figure``.
            **kwargs: Extra arguments for ``pyplot.figure``.
        """
        cmap = pyplot.get_cmap(cmap)
        # nb_step + 1 evenly spaced colours; the newest trail step uses the last.
        self.colors = cmap(arange(nb_step + 1) / nb_step)
        self.nb_step = nb_step

        # Map extent: eddy centres plus a 2-unit margin on every side.
        x_min, x_max = self.x_core.min() - 2, self.x_core.max() + 2
        d_x = x_max - x_min
        y_min, y_max = self.y_core.min() - 2, self.y_core.max() + 2
        d_y = y_max - y_min
        # plot
        self.fig = pyplot.figure(figsize=figsize, **kwargs)
        t0, t1 = self.period
        self.fig.suptitle(f"{t0} -> {t1}")
        self.ax = self.fig.add_axes((0.05, 0.05, 0.9, 0.9))
        self.ax.set_xlim(x_min, x_max), self.ax.set_ylim(y_min, y_max)
        self.ax.set_aspect("equal")
        self.ax.grid()
        # init mappable
        self.txt = self.ax.text(x_min + 0.05 * d_x,
                                y_min + 0.05 * d_y,
                                "",
                                zorder=10)
        # Trail buffer: one vertex array per drawn time step (see update()).
        self.segs = list()
        self.contour = LineCollection([], zorder=1)
        self.ax.add_collection(self.contour)

        self.fig.canvas.draw()
        self.fig.canvas.mpl_connect("key_press_event", self.keyboard)
        self.fig.canvas.mpl_connect("resize_event", self.reset_bliting)

    def reset_bliting(self, event):
        """Re-capture the background without animated artists (on resize)."""
        self.contour.set_visible(False)
        self.txt.set_visible(False)
        for m in self.mappables:
            m.set_visible(False)
        self.fig.canvas.draw()
        self.bg_cache = self.fig.canvas.copy_from_bbox(self.ax.bbox)
        self.contour.set_visible(True)
        self.txt.set_visible(True)
        for m in self.mappables:
            m.set_visible(True)

    def show(self, infinity_loop=False):
        """Display the figure and play the animation from ``period[0]``.

        Args:
            infinity_loop (bool): Restart from the first step after passing
              the last one, waiting 0.5 s between runs.
        """
        pyplot.show(block=False)
        # save background for future bliting
        self.fig.canvas.draw()
        self.bg_cache = self.fig.canvas.copy_from_bbox(self.ax.bbox)
        loop = True
        t0, t1 = self.period
        while loop:
            self.now = t0
            while True:
                dt = self.sleep_event
                if not self.pause:
                    d0 = datetime.now()
                    self.next()
                    # Subtract the drawing time to keep a steady frame rate.
                    dt_draw = (datetime.now() - d0).total_seconds()
                    dt = self.sleep_event - dt_draw
                # start_event_loop treats a timeout <= 0 as "never time out",
                # which would freeze the animation; use a tiny positive value.
                if dt <= 0:
                    dt = 1e-10
                self.fig.canvas.start_event_loop(dt)

                if self.now > t1:
                    break
            if infinity_loop:
                self.fig.canvas.start_event_loop(0.5)
            else:
                loop = False

    def next(self):
        """Advance one time step and redraw."""
        self.now += 1
        return self.draw_contour()

    def prev(self):
        """Go back one time step and redraw."""
        self.now -= 1
        return self.draw_contour()

    def func_animation(self, frame):
        """Draw time step ``frame`` and return the artists to blit.

        Presumably intended as a ``matplotlib.animation.FuncAnimation``
        callback — confirm against callers.
        """
        while self.mappables:
            self.mappables.pop().remove()
        self.now = frame
        self.update()
        artists = [self.contour, self.txt]
        artists.extend(self.mappables)
        return artists

    def update(self):
        """Append time step ``self.now`` to the trail and draw the artists.

        An entry is appended on every call (an empty (0, 2) array when no
        contour exists at this step), so the trail is counted in time steps.
        """
        m = self.t == self.now
        if m.sum():
            self.segs.append(
                create_vertice(flatten_line_matrix(self.x[m]),
                               flatten_line_matrix(self.y[m])))
        else:
            self.segs.append(empty((0, 2)))
        self.contour.set_paths(self.segs)
        # Newest step gets the last colour; width grows from 0 (oldest)
        # toward 2.5 (newest).
        self.contour.set_color(self.colors[-len(self.segs):])
        self.contour.set_lw(arange(len(self.segs)) / len(self.segs) * 2.5)
        txt = f"{self.now}"
        if self.graphic_informations:
            txt += f"- {1/self.sleep_event:.0f} frame/s"
        self.txt.set_text(txt)
        for i in where(m)[0]:
            mappable = self.ax.text(self.x_core[i],
                                    self.y_core[i],
                                    self.track[i],
                                    fontsize=8)
            self.mappables.append(mappable)
            self.ax.draw_artist(mappable)
        self.ax.draw_artist(self.contour)
        self.ax.draw_artist(self.txt)
        # Remove first segment to keep only T contour
        if len(self.segs) > self.nb_step:
            self.segs.pop(0)

    def draw_contour(self):
        """Restore the cached background, draw the current step and blit."""
        # select contour for this time step
        while self.mappables:
            self.mappables.pop().remove()
        self.ax.figure.canvas.restore_region(self.bg_cache)
        self.update()
        # paint updated artist
        self.ax.figure.canvas.blit(self.ax.bbox)

    def keyboard(self, event):
        """Handle key presses.

        escape: exit; space: toggle pause; ``+``/``-``: shorten/lengthen the
        frame delay by 10%; right/left (only while paused): step forward/back.
        """
        if event.key == "escape":
            exit()
        elif event.key == " ":
            self.pause = not self.pause
        elif event.key == "+":
            self.sleep_event *= 0.9
        elif event.key == "-":
            self.sleep_event *= 1.1
        elif event.key == "right" and self.pause:
            self.next()
        elif event.key == "left" and self.pause:
            # Drop the current and previous steps; prev() re-appends the
            # previous one, so this rewinds by a single step. Nothing to
            # rewind to with fewer than two recorded steps.
            if len(self.segs) < 2:
                return
            self.segs.pop(-1)
            self.segs.pop(-1)
            self.prev()
Esempio n. 11
0
class HoughDemo(ImageProcessDemo):
    """Interactive demo of OpenCV Hough line and circle detection.

    The pipeline is implemented in ``draw``:

    1. the grayscale image is smoothed with a Gaussian blur;
    2. Canny edge detection runs on the smoothed image;
    3. probabilistic Hough lines are searched in the edge map, and Hough
       circles in the smoothed image;
    4. the results are overlaid on the displayed image as a LineCollection
       and an EllipseCollection.

    The traits listed in ``connect_dirty`` are watched for redraws
    (presumably handled by ImageProcessDemo -- confirm).
    """
    TITLE = u"Hough Demo"
    DEFAULT_IMAGE = "stuff.jpg"
    SETTINGS = ["th2", "show_canny", "rho", "theta", "hough_th",
                "minlen", "maxgap", "dp", "mindist", "param2",
                "min_radius", "max_radius", "blur_sigma",
                "linewidth", "alpha", "check_line", "check_circle"]

    # which detections are drawn
    check_line = Bool(True)
    check_circle = Bool(True)

    # Gaussian blur parameters
    blur_sigma = Range(0.1, 5.0, 2.0)
    show_blur = Bool(False)

    # Canny parameters: th2 is the high threshold, th2 / 2 the low one
    th2 = Range(0.0, 255.0, 200.0)
    show_canny = Bool(False)

    # HoughLinesP parameters (theta is given in degrees, see draw)
    rho = Range(1.0, 10.0, 1.0)
    theta = Range(0.1, 5.0, 1.0)
    hough_th = Range(1, 100, 40)
    minlen = Range(0, 100, 10)
    maxgap = Range(0, 20, 10)

    # HoughCircles parameters
    dp = Range(1.0, 5.0, 1.9)
    mindist = Range(1.0, 100.0, 50.0)
    param2 = Range(5, 100, 50)
    min_radius = Range(5, 100, 20)
    max_radius = Range(10, 100, 70)

    # drawing parameters
    linewidth = Range(1.0, 3.0, 1.0)
    alpha = Range(0.0, 1.0, 0.6)

    def control_panel(self):
        """Build the settings panel; the UI labels are in Chinese."""
        return VGroup(
            # Gaussian blur: standard deviation / show result
            Group(
                Item("blur_sigma", label=u"标准方差"),
                Item("show_blur", label=u"显示结果"),
                label=u"高斯模糊参数"
            ),
            # edge detection: threshold 2 / show result
            Group(
                Item("th2", label=u"阈值2"),
                Item("show_canny", label=u"显示结果"),
                label=u"边缘检测参数"
            ),
            # line detection: offset resolution (px), angle resolution (deg),
            # threshold, min length, max gap
            Group(
                Item("rho", label=u"偏移分辨率(像素)"),
                Item("theta", label=u"角度分辨率(角度)"),
                Item("hough_th", label=u"阈值"),
                Item("minlen", label=u"最小长度"),
                Item("maxgap", label=u"最大空隙"),
                label=u"直线检测"
            ),
            # circle detection: resolution (px), min centre distance (px),
            # centre threshold, min / max radius
            Group(
                Item("dp", label=u"分辨率(像素)"),
                Item("mindist", label=u"圆心最小距离(像素)"),
                Item("param2", label=u"圆心检查阈值"),
                Item("min_radius", label=u"最小半径"),
                Item("max_radius", label=u"最大半径"),
                label=u"圆检测"
            ),
            # drawing: line width, alpha, draw lines / circles
            Group(
                Item("linewidth", label=u"线宽"),
                Item("alpha", label=u"alpha"),
                HGroup(
                    Item("check_line", label=u"直线"),
                    Item("check_circle", label=u"圆"),
                ),
                label=u"绘图参数"
            )
        )

    def __init__(self, **kwargs):
        super(HoughDemo, self).__init__(**kwargs)
        self.connect_dirty("th2, show_canny, show_blur, rho, theta, hough_th,"
                            "min_radius, max_radius, blur_sigma,"
                           "minlen, maxgap, dp, mindist, param2, "
                           "linewidth, alpha, check_line, check_circle")
        # detected line segments, in image (data) coordinates
        self.lines = LineCollection([], linewidths=2, alpha=0.6)
        self.axe.add_collection(self.lines)

        # detected circles: unfilled red ellipses positioned in data coordinates
        self.circles = EllipseCollection(
            [], [], [],
            units="xy",
            facecolors="none",
            edgecolors="red",
            linewidths=2,
            alpha=0.6,
            transOffset=self.axe.transData)

        self.axe.add_collection(self.circles)

    def _img_changed(self):
        # keep a grayscale copy of every newly loaded (BGR) image
        self.img_gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)

    def draw(self):
        """Run blur, Canny and the Hough detectors, then refresh the overlays."""
        img_smooth = cv2.GaussianBlur(self.img_gray, (0, 0), self.blur_sigma, self.blur_sigma)
        img_edge = cv2.Canny(img_smooth, self.th2 * 0.5, self.th2)

        # The intermediate images are single-channel grayscale, so they need
        # COLOR_GRAY2BGR. COLOR_BAYER_BG2BGR would demosaic them as raw
        # Bayer sensor data and display false colours.
        if self.show_blur and self.show_canny:
            show_img = cv2.cvtColor(np.maximum(img_smooth, img_edge), cv2.COLOR_GRAY2BGR)
        elif self.show_blur:
            show_img = cv2.cvtColor(img_smooth, cv2.COLOR_GRAY2BGR)
        elif self.show_canny:
            show_img = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2BGR)
        else:
            show_img = self.img

        if self.check_line:
            theta = self.theta / 180.0 * np.pi  # degrees -> radians
            lines = cv2.HoughLinesP(img_edge,
                                    self.rho, theta, self.hough_th,
                                    minLineLength=self.minlen,
                                    maxLineGap=self.maxgap)

            if lines is not None:
                # OpenCV 2.x returns shape (1, N, 4), OpenCV >= 3 (N, 1, 4).
                # Reshaping the whole array gives N segments
                # ((x1, y1), (x2, y2)) in both cases; the former lines[0]
                # kept only the first segment on OpenCV >= 3.
                self.lines.set_segments(lines.reshape(-1, 2, 2))
                self.lines.set_visible(True)
            else:
                self.lines.set_visible(False)
        else:
            self.lines.set_visible(False)

        if self.check_circle:
            # method 3 == cv2.HOUGH_GRADIENT; param1 is the Canny high threshold
            circles = cv2.HoughCircles(img_smooth, 3,
                                       self.dp, self.mindist,
                                       param1=self.th2,
                                       param2=self.param2,
                                       minRadius=self.min_radius,
                                       maxRadius=self.max_radius)

            if circles is not None:
                circles = circles[0]  # (N, 3) rows of (x, y, radius)
                # EllipseCollection stores half-axes internally, so the radius
                # is assigned directly to the private width/height arrays.
                self.circles._heights = self.circles._widths = circles[:, 2]
                self.circles.set_offsets(circles[:, :2])
                self.circles._angles = np.zeros(len(circles))
                self.circles._transOffset = self.axe.transData
                self.circles.set_visible(True)
            else:
                self.circles.set_visible(False)
        else:
            self.circles.set_visible(False)

        self.lines.set_linewidths(self.linewidth)
        self.circles.set_linewidths(self.linewidth)
        self.lines.set_alpha(self.alpha)
        self.circles.set_alpha(self.alpha)

        self.draw_image(show_img)
Esempio n. 12
0
class Visualisation(FigureCanvasWxAgg):
    def __init__(self, gui):
        """Create the figure, wx canvas, navigation toolbar and axes.

        :param gui: parent GUI; must expose ``core`` (the model, whose
            ``modelList`` is iterated) and ``linesDic[model][key]`` (lists of
            variable names).
        """

        self.fig = pl.figure()
        FigureCanvasWxAgg.__init__(self, gui, -1, self.fig)
        self.xlim, self.ylim, self.dataon = (), (), False

        # the GUI and the model
        self.gui, self.core = gui, gui.core
        # interaction polygon used to edit a zone (see modifZone)
        self.polyInteract = None

        # list of variable names, used by the GUI and its variable combobox
        self.listeNomVar = []
        for mod in self.core.modelList:
            for k in gui.linesDic[mod].keys():
                self.listeNomVar.extend(gui.linesDic[mod][k])
        self.curVar, self.curContour = None, 'Charge'  # currently selected variable
        self.curMedia, self.curOri, self.curPlan, self.curGroupe = 0, 'Z', 0, None
        # type of the zone being drawn (-1: no zone being drawn)
        self.typeZone = -1

        # coordinates and graphic object of the zone currently being created
        self.curZone = None  # graphic object (line, point, rectangle...)
        self.x1, self.y1 = [], []
        self.tempZoneVal = []  # per-vertex values for POLYV zones
        # presumably flags telling whether each calculation has been run
        # (original comment: "says whether the calculation was done or not")
        self.calcE = 0
        self.calcT = 0
        self.calcR = 0

        # one entry per variable: its zones (Line2D), their text labels and
        # their media lists, kept index-aligned
        self.listeZone, self.listeZoneText, self.listeZmedia = {}, {}, {}
        for i in range(len(self.listeNomVar)):
            self.listeZone[self.listeNomVar[i]] = []
            self.listeZoneText[self.listeNomVar[i]] = []
            self.listeZmedia[self.listeNomVar[i]] = []

        # navigation toolbar of the canvas (NavigationToolbar2Wx)
        self.toolbar = NavigationToolbar2Wx(self)
        self.toolbar.Realize()
        # add the plotting axes to the figure
        self.cnv = self.fig.add_axes([.05, .05, .9,
                                      .88])  # [left, bottom, width, height]
        self.toolbar.update()
        # report the cursor position to the GUI on mouse motion
        self.pos = self.mpl_connect('motion_notify_event', self.onPosition)

        # the main graphic objects, created on demand
        self.Contour, self.ContourF, self.ContourLabel, self.Vector = None, None, None, None
        self.Grid, self.Particles, self.Image, self.Map = None, None, None, None

    #####################################################################
    #                     Divers accesseur/mutateurs
    #####################################################################

    def GetToolBar(self):
        """Return the NavigationToolbar2Wx created in ``__init__``."""
        toolbar = self.toolbar
        return toolbar

    def getcurVisu(self):
        """Return ``[curGroupe, curNom, curObj]`` describing the current display.

        NOTE(review): ``curNom`` and ``curObj`` are not initialised in
        ``__init__``; presumably they are set elsewhere before this is called.
        """
        current = (self.curGroupe, self.curNom, self.curObj)
        return list(current)

    def onPosition(self, evt):
        """Send the cursor's data coordinates (truncated to 6 characters) to the GUI."""
        x_txt = str(evt.xdata)[:6]
        y_txt = str(evt.ydata)[:6]
        self.gui.onPosition(' x: %s y: %s' % (x_txt, y_txt))

    def delAllObjects(self):
        """Forget every zone, remove all artists from the axes and redraw.

        The zone, label and media lists of every variable are emptied
        together. showVar and delZone index all three lists by the same
        position, so stale media entries would otherwise be paired with
        zones added later.
        """
        for var in self.listeZone:
            self.listeZone[var] = []
            self.listeZoneText[var] = []
            self.listeZmedia[var] = []
        self.cnv.lines = []
        self.cnv.collections = []
        self.cnv.artists = []
        self.cnv.images = []
        self.cnv.cla()
        self.draw()

    def setVisu(self, core):
        """Rebuild the display from an imported model without recomputing it.

        Clears every object, recreates the zones of each model listed in
        ``self.core.modelList`` from ``core.diczone[model].dic``, reinitialises
        the domain (axes limits, grid, empty vector layer) and redraws.
        """
        self.delAllObjects()
        for model_name in self.core.modelList:
            zone_dict = core.diczone[model_name].dic
            self.setAllZones(zone_dict)
        self.initDomain()
        self.draw()

    def setDataOn(self, bool):
        """Record whether data values should be displayed with the contours."""
        show_data = bool
        self.dataon = show_data

    def redraw(self):
        """Repaint the canvas; the axis limits are left untouched."""
        self.draw()

#    def changeTitre(self,titre):
#        s='';ori=self.curOri
#        if ori in ['X','Y','Z']:
#            plan=self.curPlan;
#            x1,y1 = self.model.Aquifere.getXYticks()
#            zl = self.model.Aquifere.getParm('zList')
#        if ori=='Z': s=' Z = '+ str(zl[plan])
#        if ori=='X': s=' X = '+ str(x1[plan])
#        if ori=='Y': s=' Y = '+ str(y1[plan])
#        pl.title(self.traduit(str(titre))+s[:9],fontsize=20)

    def createAndShowObject(self, dataM, dataV, opt, value=None, color=None):
        """Create or hide the contour and vector layers.

        :param dataM: ``(X, Y, Z)`` for createContour, or None to hide the contours.
        :param dataV: ``(X, Y, U, V)`` for createVector, or None to hide the vectors.
        :param opt: unused; kept for backward compatibility.
        :param value: contour settings forwarded to createContour.
        :param color: contour colours forwarded to createContour.
        """
        # `is None`: `==` against numpy data is element-wise and cannot be
        # used as an if-condition.
        if dataM is None:
            self.drawContour(False)
        else:
            self.createContour(dataM, value, color)
        if dataV is None:
            self.drawVector(False)
        else:
            self.createVector(dataV)

    def drawObject(self, typObj, bool):
        """Show or hide an object by calling ``self.draw<typObj>(bool)``.

        A map that was never created is not loaded just to hide it.
        """
        if typObj == 'Map' and self.Map is None and bool == False:  # noqa: E712
            return
        # Direct method lookup instead of exec-ing a built source string:
        # same dispatch, and no code injection through typObj.
        getattr(self, 'draw' + typObj)(bool)

    def changeObject(self, groupe, name, value, color):
        """Apply new display settings to the object called *name*.

        'Grid' takes only a colour; 'Particles' and 'Veloc-vect' take a value
        and a colour; any other name is treated as a contour. *groupe* is unused.
        """
        if name == 'Grid':
            self.changeGrid(color)
            return
        if name == 'Particles':
            self.changeParticles(value=value, color=color)
            return
        if name == 'Veloc-vect':
            self.changeVector(value, color)
            return
        self.changeContour(value, color)

    #####################################################################
    #             Gestion de l'affichage de la grid/map
    #####################################################################
    # Set the size of the study domain (the matplotlib axes limits) and build
    # the grid of cells.
    def initDomain(self):
        """Initialise the axes for the model domain.

        Reads the extent from ``self.core.addin.getFullGrid()`` (keys 'x0',
        'x1', 'y0', 'y1'), stores it in ``self.xlim`` / ``self.ylim`` and
        applies it to the axes. Then builds the grid (createGrid) and adds a
        hidden placeholder LineCollection as ``self.Vector``, with
        ``data = [0, 0, None, None]`` meaning "no vectors drawn yet".
        """
        # change value of the axes
        grd = self.core.addin.getFullGrid()
        self.xlim = (grd['x0'], grd['x1'])
        self.ylim = (grd['y0'], grd['y1'])
        # invisible dummy line plotted only to borrow its transform for the
        # collections created below (presumably the axes' data transform)
        p, = pl.plot([0, 1], 'b')
        p.set_visible(False)
        self.transform = p.get_transform()
        self.cnv.set_xlim(self.xlim)
        self.cnv.set_ylim(self.ylim)
        self.createGrid()
        # add basic vector as a linecollection: rand(2, 2) * 0. is a 2x2 array
        # of zeros, i.e. two zero-length segments at the origin
        dep = rand(2, 2) * 0.
        arr = dep * 1.
        self.Vector = LineCollection(zip(dep, arr))
        self.Vector.set_transform(self.transform)
        self.Vector.set_visible(False)
        self.cnv.collections.append(self.Vector)
        # [start points, end points, scale, colour]; createVector treats a
        # None colour as "first vectors"
        self.Vector.data = [0, 0, None, None]


#    def changeDomain(self):
#        self.changeAxesOri('Z',0)
#

    def changeAxesOri(self, ori):
        """Set the axes limits for the view orientation *ori*, then redraw.

        * 'Z': x against y (plan view);
        * 'X': y against the [min, max] range of ``self.core.Zblock``;
        * 'Y': x against that same range.

        Any other value leaves the limits unchanged.
        """
        blocks = self.core.Zblock
        z_limits = (amin(blocks), amax(blocks))
        view_limits = {
            'Z': (self.xlim, self.ylim),
            'X': (self.ylim, z_limits),
            'Y': (self.xlim, z_limits),
        }
        if ori in view_limits:
            horizontal, vertical = view_limits[ori]
            self.cnv.set_xlim(horizontal)
            self.cnv.set_ylim(vertical)
        self.draw()

    def createGrid(self, col=None):
        """(Re)build the grid as two LineCollections of vertical and horizontal lines.

        Line positions come from ``getXYvects(self.core)``. The collections
        are stored in ``self.Grid[0]`` / ``self.Grid[1]`` and in slots 0 and 1
        of ``self.cnv.collections``; ``self.Grid[2]`` keeps the current colour.

        :param col: RGB colour of the lines. None reuses the stored colour,
            which is light grey ``(.6, .6, .6)`` on first creation.
        """
        if self.Grid is None:
            # First call: store the default colour. An explicitly passed
            # colour is honoured below; previously it was overwritten here.
            self.Grid = [0, 0, (.6, .6, .6)]
        else:
            # hide the previous grid before it is replaced
            for old_lines in self.Grid[:2]:
                old_lines.set_visible(False)
        if col is None:
            col = self.Grid[2]
        else:
            self.Grid[2] = col
        _, _, xt, yt = getXYvects(self.core)
        # the grid occupies collection slots 0 and 1
        if len(self.cnv.collections) < 2:
            self.cnv.collections = [0, 0]
        # vertical lines: one segment per x position, from min(yt) to max(yt)
        n = len(ravel(xt))
        dep = concatenate([xt.reshape((n, 1)), ones((n, 1)) * min(yt)], axis=1)
        arr = concatenate([xt.reshape((n, 1)), ones((n, 1)) * max(yt)], axis=1)
        self.Grid[0] = LineCollection(zip(dep, arr))
        self.cnv.collections[0] = self.Grid[0]
        # horizontal lines: one segment per y position, from min(xt) to max(xt)
        n = len(ravel(yt))
        dep = concatenate([ones((n, 1)) * min(xt), yt.reshape((n, 1))], axis=1)
        arr = concatenate([ones((n, 1)) * max(xt), yt.reshape((n, 1))], axis=1)
        self.Grid[1] = LineCollection(zip(dep, arr))
        self.cnv.collections[1] = self.Grid[1]
        for lines in self.Grid[:2]:
            lines.set_transform(self.transform)
            lines.set_color(col)
        self.redraw()

    def drawGrid(self, bool):
        """Show or hide the grid, re-applying its stored colour.

        Only toggles visibility of the existing collections; it never
        recreates a grid.
        """
        colour = self.Grid[2]
        for lines in self.Grid[:2]:
            lines.set_visible(bool)
            lines.set_color(colour)
        self.redraw()

    def changeGrid(self, color):
        """Recolour the grid lines.

        :param color: wx colour object whose ``Get()`` returns 0-255 channels.
        """
        channels = color.Get()
        # Float division: with Python 2 integer division every channel
        # below 255 collapsed to 0.
        col = (channels[0] / 255., channels[1] / 255., channels[2] / 255.)
        for lines in self.Grid[:2]:
            lines.set_color(col)
        self.Grid[2] = col
        self.redraw()

    #####################################################################
    #             Affichage d'une variable sous forme d'image
    #####################################################################
    # The background map becomes the (only) entry of the axes' image list.
    def createMap(self):
        """Load the image file ``self.gui.map`` and stretch it over the domain.

        The image covers the current x/y limits (``aspect='auto'``) with its
        first row at the top (``origin='upper'``) and no interpolation.
        """
        picture = Im.imread(self.gui.map)
        extent = (self.xlim[0], self.xlim[1], self.ylim[0], self.ylim[1])
        self.Map = pl.imshow(picture,
                             origin='upper',
                             extent=extent,
                             aspect='auto',
                             interpolation='nearest')
        self.cnv.images = [self.Map]
        self.cnv.images[0].set_visible(True)
        self.redraw()

    def drawMap(self, bool):
        """Show or hide the background map, loading it first if needed."""
        if self.Map is None:
            self.createMap()
        self.cnv.images = [self.Map]
        self.cnv.images[0].set_visible(bool)
        self.redraw()

    def createImage(self, data):
        """Draw ``data = (X, Y, Z)`` with pcolormesh and store it as the axes' image."""
        X, Y, Z = data
        mesh = pl.pcolormesh(X, Y, Z)
        self.cnv.images = [mesh]
        self.redraw()

    def drawImage(self, bool):
        """Toggle the first image of the axes and redraw; no-op when there is none."""
        if len(self.cnv.images) == 0:
            return
        self.cnv.images[0].set_visible(bool)
        self.redraw()

    #####################################################################
    #             Gestion de l'affichage des contours
    #####################################################################

    def createContour(self, data, value=None, col=None):
        """Draw filled and line contours of ``data = (X, Y, Z)`` with labels.

        :param data: ``(X, Y, Z)`` with 2-D X. Values of Z above 1e5 are
            treated as no-data and masked. A single-row grid (1-D model) is
            duplicated into a thin band so it can be contoured.
        :param value: contour settings, or None for 11 automatic levels:

            * ``value[0]`` -- min;
            * ``value[1]`` -- max;
            * ``value[2]`` -- step between levels (in decades for 'log');
            * ``value[3]`` -- number of decimals in the labels;
            * ``value[4]`` -- 'lin', 'log' or 'fix' (anything else means
              automatic);
            * ``value[5]`` -- for 'fix', the list of levels.
        :param col: None for matplotlib's default colormap, otherwise
            ``[rgb_a, rgb_b, rgb_c, alpha_percent]`` with 0-255 RGB triples
            blended into a custom colormap.

        Shows a message and draws nothing when min == max.
        """
        X, Y, Z = data
        # keep the grid (slots 0-1) and vector (slot 2) collections only,
        # and drop the extra artists of a previous contour
        self.cnv.collections = self.cnv.collections[:3]
        self.cnv.artists = []
        V = 11  # overwritten by every branch below
        Zmin = amin(amin(Z))
        # values >= 1e5 (no-data) are zeroed before taking the max
        Zmax = amax(amax(Z * (Z < 1e5)))
        if Zmax == Zmin:  # min == max: nothing to contour
            self.gui.onMessage(' values all equal to ' + str(Zmin))
            return
        if value == None:
            # default settings: automatic levels, 2 decimals
            value = [Zmin, Zmax, (Zmax - Zmin) / 10., 2, 'auto', []]
        # compute the number and values of the contour levels
        val2 = [float(a) for a in value[:3]]
        if value[4] == 'log':  # log scale; min clamped to 1e-4, step in decades
            n = int((log10(val2[1]) - log10(max(val2[0], 1e-4))) / val2[2]) + 1
            V = logspace(log10(max(val2[0], 1e-4)), log10(val2[1]), n)
        elif (value[4]
              == 'fix') and (value[5] != None):  # levels fixed by the user
            # copy the list (assumes value[5] is a list: `* 1` copies it) and
            # append an extra top level at twice the last one
            V = value[5] * 1
            V.append(V[-1] * 2.)
            n = len(V)
        elif value[4] == 'lin':  # linear scale with step val2[2]
            n = int((val2[1] - val2[0]) / val2[2]) + 1
            V = linspace(val2[0], val2[1], n)
        else:  # automatic: 11 levels spanning the data range
            n = 11
            V = linspace(Zmin, Zmax, n)
        # ONE DIMENSIONAL: duplicate a single row into a band of +/-45 % of Y
        # (c is overwritten below by the contour set)
        r, c = shape(X)
        if r == 1:
            X = concatenate([X, X])
            Y = concatenate([Y - Y * .45, Y + Y * .45])
            Z = concatenate([Z, Z])
        # mask the no-data values (> 1e5)
        Z2 = ma.masked_where(Z.copy() > 1e5, Z.copy())
        # choose the contour colours
        if col == None:  # default matplotlib colormap, 10 % fill opacity
            cf = pl.contourf(pl.array(X), pl.array(Y), Z2, V)
            c = pl.contour(pl.array(X), pl.array(Y), Z2, V)
            col = [(0, 0, 255), (0, 255, 0), (255, 0, 0), 10]
        else:
            r, g, b = [], [], []
            # rows: (position in [0, 1], weight of col[0], weight of col[1],
            # weight of col[2]); each channel at that position is the
            # weighted sum of the three user colours
            lim=((0.,1.,0.,0.),(.1,1.,0.,0.),(.25,.8,0.,0.),(.35,0.,.8,0.),(.45,0.,1.,0.),\
                 (.55,0.,1.,0.),(.65,0.,.8,0.),(.75,0.,0.,.8),(.9,0.,0.,1.),(1.,0.,0.,1.))
            for i in range(len(lim)):
                c1 = lim[i][1] * col[0][0] / 255. + lim[i][2] * col[1][
                    0] / 255. + lim[i][3] * col[2][0] / 255.
                r.append((lim[i][0], c1, c1))
                c2 = lim[i][1] * col[0][1] / 255. + lim[i][2] * col[1][
                    1] / 255. + lim[i][3] * col[2][1] / 255.
                g.append((lim[i][0], c2, c2))
                c3 = lim[i][1] * col[0][2] / 255. + lim[i][2] * col[1][
                    2] / 255. + lim[i][3] * col[2][2] / 255.
                b.append((lim[i][0], c3, c3))
            cdict = {'red': r, 'green': g, 'blue': b}
            my_cmap = mpl.colors.LinearSegmentedColormap(
                'my_colormap', cdict, 256)
            cf = pl.contourf(pl.array(X), pl.array(Y), Z2, V, cmap=my_cmap)
            c = pl.contour(pl.array(X), pl.array(Y), Z2, V, cmap=my_cmap)
        # col[3] is the opacity of the filled contours, in percent
        for c0 in cf.collections:
            c0.set_alpha(int(col[3]) / 100.)
        # label format with value[3] decimals (value is never None here, so
        # the '%1.3f' branch is not reached)
        if value == None: fmt = '%1.3f'
        else: fmt = '%1.' + str(value[3]) + 'f'
        cl = pl.clabel(c, color='black', fontsize=9, fmt=fmt)
        self.Contour = c
        self.ContourF = cf
        self.ContourLabel = cl
        # keep the original data so changeContour can redraw with new settings
        self.Contour.data = data
        self.redraw()

    def changeContour(self, value, col):
        """Redraw the current contour's data with new settings and colours."""
        self.drawContour(False)
        previous_data = self.Contour.data
        self.createContour(previous_data, value, col)

    def drawContour(self, bool):
        """Remove the contour layer and redraw.

        Keeps only the first three collections (grid lines and vectors) and
        empties ``cnv.artists``. *bool* is ignored: contours are always
        removed here and re-created by createContour.
        """
        kept = self.cnv.collections[:3]
        self.cnv.collections = kept
        self.cnv.artists = []
        self.draw()

    #####################################################################
    #             Gestion de l'affichage de vectors
    #####################################################################
    # The vector placeholder is created during domain initialisation
    # (initDomain); createVector replaces it.

    def createVector(self, data):
        """Build the velocity vectors from ``data = (X, Y, U, V)``.

        Each segment runs from (X, Y) to (X + U*ech, Y + V*ech). The scale
        ``ech`` and the colour come from the previous vectors, or are 1.0 and
        blue for the first set. The LineCollection goes to slot 2 of
        ``self.cnv.collections``. ``self.Vector.data`` records
        ``[dep, arr, ech, col]`` for later rescaling by changeVector.
        """
        X, Y, U, V = data
        if self.Vector.data[3] is None:  # first vectors: no colour chosen yet
            ech = 1.
            col = (0, 0, 1)
        else:
            _, _, ech, col = self.Vector.data
            self.drawVector(False)
        n = len(ravel(X))
        dep = concatenate([X.reshape((n, 1)), Y.reshape((n, 1))], axis=1)
        x_end = X + U * ech
        y_end = Y + V * ech
        arr = concatenate([x_end.reshape((n, 1)), y_end.reshape((n, 1))], axis=1)
        self.Vector = LineCollection(zip(dep, arr))
        self.Vector.set_transform(self.transform)
        self.Vector.set_color(col)
        if len(self.cnv.collections) > 2:
            self.cnv.collections[2] = self.Vector
        else:
            self.cnv.collections.append(self.Vector)
        self.Vector.set_visible(True)
        self.Vector.data = [dep, arr, ech, col]
        self.redraw()

    def drawVector(self, bool):
        """Show (true) or hide (false) the velocity vectors, then redraw."""
        vectors = self.Vector
        vectors.set_visible(bool)
        self.redraw()

    def changeVector(self, ech, col=wx.Color(0, 0, 255)):
        """Rescale and recolour the existing velocity vectors.

        :param ech: new scale (anything ``float()`` accepts). Segment lengths
            are multiplied by ``ech / previous_ech``.
        :param col: wx colour object whose ``Get()`` returns 0-255 channels.
        """
        ech = float(ech)
        dep, arr_old, ech_old, _ = self.Vector.data
        arr = dep + (arr_old - dep) * ech / ech_old
        self.Vector.set_segments(zip(dep, arr))
        channels = col.Get()
        # Float division: Python 2 integer division truncated channels to 0.
        rgb = (channels[0] / 255., channels[1] / 255., channels[2] / 255.)
        self.Vector.set_color(rgb)
        self.Vector.set_visible(True)
        self.Vector.data = [dep, arr, ech, rgb]
        self.redraw()

    #####################################################################
    #             Gestion de l'affichage de particules
    #####################################################################
    def startParticles(self):
        """Enter particle-tracking mode.

        Hides the previous particles, resets the particle store, disconnects
        the toolbar's mouse handlers and routes button presses to
        mouseParticles.
        """
        if self.Particles is not None:
            self.partVisible(False)
        self.Particles = {
            'line': [],
            'txt': [],
            'data': [],
            'color': wx.Color(255, 0, 0)
        }
        toolbar = self.toolbar
        for handler_id in (toolbar._idPress, toolbar._idRelease, toolbar._idDrag):
            self.mpl_disconnect(handler_id)
        self.m3 = self.mpl_connect('button_press_event', self.mouseParticles)
        self.stop = False

    def mouseParticles(self, evt):
        """Handle a click while in particle mode.

        Clicks are ignored once stopped or outside the axes. A left click
        (button 1) computes and draws the track starting at the clicked
        point. A right click (button 3) leaves particle mode and sends
        'zoneEnd' to the GUI.
        """
        if self.stop or evt.inaxes is None:
            return
        button = evt.button
        if button == 1:
            xp, yp, tp = self.core.addin.calcParticle(evt.xdata, evt.ydata)
            self.updateParticles(xp, yp, tp)
        elif button == 3:
            self.mpl_disconnect(self.m3)
            self.stop = True
            self.gui.actions('zoneEnd')

    def updateParticles(self, X, Y, T, freq=10):
        """Add one particle track (red line plus time labels) to the group.

        :param X: x coordinates along the track.
        :param Y: y coordinates along the track.
        :param T: time at each point.
        :param freq: label every *freq*-th point with its time; ``freq <= 0``
            means no labels.
        """
        ligne, = pl.plot(pl.array(X), pl.array(Y), 'r')
        # Always defined: with freq <= 0 the label list used to be unbound,
        # which raised NameError below.
        txt = []
        if freq > 0:
            tx, ty, tt = X[0::freq], Y[0::freq], T[0::freq]
            for xi, yi, ti in zip(tx, ty, tt):
                stamp = str(ti)
                # keep at least 4 characters, or the whole integer part
                width = max(4, len(stamp.split('.')[0]))
                txt.append(pl.text(xi, yi, stamp[:width], fontsize='8'))
        self.Particles['line'].append(ligne)
        self.Particles['txt'].append(txt)
        self.Particles['data'].append((X, Y, T))
        self.gui_repaint()  # original workaround: "bug matplotlib v2.6 for direct draw"
        self.draw()

    def drawParticles(self, bool, value=None):
        """Show or hide all particle tracks; no-op before any particles exist.

        *value* is unused.
        """
        if self.Particles is None:
            return
        self.partVisible(bool)
        self.gui_repaint()
        self.draw()

    def changeParticles(self, value=None, color=wx.Color(255, 0, 0)):
        """Recolour the tracks and relabel each one every *value* time units.

        :param value: time interval between labels (passed to ``float``).
        :param color: wx colour object applied to the track lines.
        """
        self.partVisible(False)
        self.Particles['color'] = color
        self.Particles['txt'] = []
        for X, Y, T in self.Particles['data']:
            tx, ty, tt = self.ptsPartic(X, Y, T, float(value))
            labels = []
            for k in range(len(tx)):
                stamp = str(tt[k])
                # at least 4 characters, or the whole integer part
                width = max(4, len(stamp.split('.')[0]))
                labels.append(pl.text(tx[k], ty[k], stamp[:width], fontsize='8'))
            self.Particles['txt'].append(labels)
        self.partVisible(True)
        self.gui_repaint()
        self.draw()

    def partVisible(self, bool):
        """Show or hide every particle track and time label.

        The track lines are also recoloured with the stored wx colour.
        """
        channels = self.Particles['color'].Get()
        # Float division: Python 2 integer division turned channels < 255 into 0.
        color = (channels[0] / 255., channels[1] / 255., channels[2] / 255.)
        for line in self.Particles['line']:
            line.set_visible(bool)
            line.set_color(color)
        for points in self.Particles['txt']:
            for txt in points:
                txt.set_visible(bool)

    def ptsPartic(self, X, Y, T, dt):
        """Resample a track at regular times.

        Returns ``(xn, yn, t1)``, where ``t1`` holds
        ``int((max(T) - min(T)) / dt)`` evenly spaced times from min(T) to
        max(T), and xn / yn are X / Y interpolated at those times with
        scipy's interp1d (linear by default).
        """
        t_start, t_end = amin(T), amax(T)
        times = linspace(t_start, t_end, int((t_end - t_start) / dt))
        x_at = interp1d(T, X)(times)
        y_at = interp1d(T, Y)(times)
        return x_at, y_at, times

    #####################################################################
    #                   Gestion des zones de la visu
    #####################################################################
    # Display all zones of one variable.
    def showVar(self, var, media):
        """Hide every zone, then show the zones of *var* that belong to *media*.

        A zone is shown when *media* is in its media list, or always when
        *media* is -1. Updates curVar / curMedia and redraws.
        """
        self.setUnvisibleZones()
        self.curVar, self.curMedia = var, media
        for idx, zone in enumerate(self.listeZone[var]):
            if (media in self.listeZmedia[var][idx]) or (media == -1):
                zone.set_visible(True)
                self.visibleText(self.listeZoneText[var][idx], True)
        self.redraw()

    def showData(self, liForage, liData):
        """Show the 'Forages' (borehole) zones labelled with data values.

        Every zone whose name is in *liForage* gets a label at the mean of
        its coordinates. The label holds the name and, on a second line, the
        *liData* entry at the same index.

        NOTE(review): relies on ``self.model``, ``GraphicObject`` and
        ``self.addGraphicObject``, none of which are set or defined in this
        class as visible here -- looks like legacy code; verify before use.
        """
        self.setUnvisibleZones()
        self.curVar = 'data'
        self.listeZoneText['data'] = []
        for well_zone in self.listeZone['Forages']:
            well_zone.set_visible(True)
        labels = []
        for zone in self.model.Aquifere.getZoneList('Forages'):
            xs, ys = zip(*zone.getXy())
            zone_name = zone.getNom()
            if zone_name not in liForage:
                continue
            pos = liForage.index(zone_name)
            labels.append(
                pl.text(mean(xs), mean(ys), zone_name + '\n' + str(liData[pos])))
        self.addGraphicObject(GraphicObject('zoneText', labels, True, None))
        self.redraw()

    def changeData(self, taille, col):
        """Set font size *taille* and colour *col* on every data label."""
        labels = self.listeZoneText['data'][0].getObject()
        for label in labels:
            label.set_size(taille)
            label.set_color(col)

    def getcurZone(self):
        """Return the graphic object of the zone being drawn (or None)."""
        zone = self.curZone
        return zone

    def setcurZone(self, zone):
        """Set the graphic object of the zone being drawn."""
        self.curZone = zone

    # Hide every zone of every variable.
    def setUnvisibleZones(self):
        """Hide all zone lines and their labels (a Text or a list of Texts)."""
        for var, zones in self.listeZone.items():
            for zone in zones:
                zone.set_visible(False)
            for label in self.listeZoneText[var]:
                parts = label if type(label) is list else [label]
                for part in parts:
                    part.set_visible(False)

    # Called by the GUI when the user wants to create a new zone.
    def setZoneReady(self, typeZone, curVar):
        """Prepare interactive creation of a *typeZone* zone for *curVar*.

        Disconnects the toolbar's mouse handlers so clicks build the zone,
        and routes button presses to ``self.mouse_clic``.
        """
        self.typeZone = typeZone
        self.curVar = curVar
        self.tempZoneVal = []
        toolbar = self.toolbar
        for handler_id in (toolbar._idPress, toolbar._idRelease, toolbar._idDrag):
            self.mpl_disconnect(handler_id)
        self.m1 = self.mpl_connect('button_press_event', self.mouse_clic)

    def setZoneEnd(self, evt):
        """Finish the zone being drawn and hand its coordinates to the GUI.

        The vertices are zipped with the values in ``tempZoneVal`` when
        there is more than one value. The temporary line is hidden and the
        drawing state reset before ``gui.addBox.onZoneCreate`` is called.
        *evt* is unused.
        """
        zone_line = self.getcurZone()
        xv = zone_line.get_xdata()
        yv = zone_line.get_ydata()
        if len(self.tempZoneVal) > 1:
            xy = zip(xv, yv, self.tempZoneVal)
        else:
            xy = zip(xv, yv)
        # clear the temporary zone (also covers a cancelled creation)
        self.curZone.set_visible(False)
        self.curZone = None
        self.x1, self.y1 = [], []
        self.gui.addBox.onZoneCreate(self.typeZone, xy)

    def addZone(self, media, name, val, coords, visible=True):
        """Add a zone line and its text label(s) to the current variable.

        :param media: media index (int-convertible) or list of indices.
        :param name: zone name; the label reads name, newline, ``str(val)[:16]``.
        :param val: zone value shown in the label.
        :param coords: sequence of ``(x, y)`` or ``(x, y, z)`` vertices. With
            z values, each vertex also gets its own z label. A single vertex
            is drawn as a red '+' marker. Empty coords add nothing.
        :param visible: initial visibility of the zone and of all its labels.
        """
        # list(): zip() is lazy on Python 3 and has no len()
        columns = list(zip(*coords))
        if len(columns) == 0:
            return
        x, y = columns[0], columns[1]
        z = columns[2] if len(columns) == 3 else None
        if len(x) == 1:
            zone = Line2D(x, y, marker='+', markersize=10, markeredgecolor='r')
        else:
            zone = Line2D(x, y)
        zone.verts = coords
        zone.set_visible(visible)
        if not isinstance(media, list):
            media = [int(media)]
        self.curMedia = media
        self.cnv.add_line(zone)
        # name label 10 % of the way from the first vertex towards the mean
        # point; it now honours `visible` like the zone itself
        label = pl.text(
            mean(x) * .1 + x[0] * .9,
            mean(y) * .1 + y[0] * .9, name + '\n' + str(val)[:16])
        label.set_visible(visible)
        if self.typeZone == "POLYV" or len(coords[0]) == 3:
            txt = [label]
            # per-vertex z labels; skipped for a POLYV zone given 2-D
            # coordinates, which previously raised NameError on z
            if z is not None:
                for xi, yi, zi in zip(x, y, z):
                    t = pl.text(xi, yi, str(zi))
                    t.set_visible(visible)
                    txt.append(t)
        else:
            txt = label
        self.listeZone[self.curVar].append(zone)
        self.listeZmedia[self.curVar].append(media)
        self.listeZoneText[self.curVar].append(txt)
        if visible:
            self.redraw()

    def delZone(self, Variable, ind):
        """Delete zone *ind* of *Variable*: hide it and drop its label and media.

        Unknown variables and indices past the end are ignored.
        """
        # `in` instead of dict.has_key(), which no longer exists in Python 3
        if Variable not in self.listeZone:
            return
        if len(self.listeZone[Variable]) > ind:
            self.listeZone[Variable][ind].set_visible(False)
            self.visibleText(self.listeZoneText[Variable][ind], False)
            # `del` removes the same element from all three lists (including
            # for negative ind, where the old slice deletion removed nothing)
            del self.listeZone[Variable][ind]
            del self.listeZoneText[Variable][ind]
            self.listeZmedia[Variable].pop(ind)
            self.redraw()

    def visibleText(self, text, bool):
        """Set the visibility of a zone label: a single Text or a list of Texts."""
        parts = text if type(text) is list else [text]
        for part in parts:
            part.set_visible(bool)

    def delAllZones(self, Variable):
        """Hide and forget every zone (line, labels and media) of *Variable*."""
        labels = self.listeZoneText[Variable]
        for i, zone in enumerate(self.listeZone[Variable]):
            # Line2D's method is set_visible (setVisible raised
            # AttributeError), and a label may be a list of Texts.
            zone.set_visible(False)
            self.visibleText(labels[i], False)
        self.listeZone[Variable] = []
        self.listeZmedia[Variable] = []
        self.listeZoneText[Variable] = []
        self.redraw()

    def modifValZone(self, nameVar, ind, val, xy):
        """Placeholder for updating the value shown for zone *ind* of *nameVar*.

        Currently does nothing and returns None; modifZoneAttr performs the
        actual coordinate, media and label update.
        """
        return None

    def modifZoneAttr(self, nameVar, ind, val, media, xy):
        """Update coordinates, media and label of zone *ind* of *nameVar*, then redraw.

        :param val: new value shown in the zone label (``str(val)[:16]``).
        :param media: media index (int-convertible) or list of indices.
        :param xy: new ``(x, y)`` or ``(x, y, z)`` vertices. With z values
            the per-vertex labels are refreshed too.
        """
        # modify coordinates
        zone = self.listeZone[nameVar][ind]
        if len(xy[0]) == 3:
            x, y, z = zip(*xy)
        else:
            x, y = zip(*xy)
            z = ()  # no per-vertex values (previously a NameError below)
        zone.set_data(x, y)
        # modify media
        if not isinstance(media, list):
            media = [int(media)]
        self.listeZmedia[nameVar][ind] = media
        # modify label text, keeping the zone name (first line)
        textObj = self.listeZoneText[nameVar][ind]
        if isinstance(textObj, list):
            name = pl.getp(textObj[0], 'text').split('\n')[0]
            pl.setp(textObj[0], text=name + '\n' + str(val)[:16])
            # zip: vertices beyond the existing labels are left unlabelled
            # instead of raising IndexError
            for vertex_label, zi in zip(textObj[1:], z):
                pl.setp(vertex_label, text=str(zi))
        else:
            name = pl.getp(textObj, 'text').split('\n')[0]
            pl.setp(textObj, text=name + '\n' + str(val)[:16])
        self.redraw()

    def modifZone(self, nameVar, ind):
        """Start interactive editing of the vertices of zone ``ind`` of ``nameVar``.

        The stored zone line is hidden and the line of a new
        ``PolygonInteractor`` (kept in ``self.polyInteract``) is added to the
        canvas in its place; ``finModifZone`` ends the edit.
        """
        target = self.listeZone[nameVar][ind]
        interactor = PolygonInteractor(self, target, nameVar, ind)
        self.polyInteract = interactor
        target.set_visible(False)
        self.cnv.add_line(interactor.line)
        self.draw()

    def finModifZone(self):
        """End the interactive edit of the current zone started by modifZone.

        Hides and disables ``self.polyInteract``, reports the edited vertices
        to ``self.gui.modifBox.onModifZoneCoord``, copies them into the stored
        zone line, shows it again, moves its label(s) and redraws. Does
        nothing when ``self.polyInteract`` is None.
        """
        if self.polyInteract is None:
            return
        self.polyInteract.set_visible(False)
        self.polyInteract.disable()
        # send the new coordinates to the GUI
        var, ind = self.polyInteract.typeVariable, self.polyInteract.ind
        x, y = self.polyInteract.lx, self.polyInteract.ly
        # a real list rather than a one-shot zip iterator (Python 3), so the
        # receiver can index it or iterate it more than once
        xy = list(zip(x, y))
        self.gui.modifBox.onModifZoneCoord(var, ind, xy)
        zone = self.listeZone[var][ind]
        zone.set_data(x, y)
        zone.set_visible(True)
        # move the label(s) along with the zone
        txt = self.listeZoneText[var][ind]
        if isinstance(txt, list):
            # txt[0] is the name/value label; txt[i] labels vertex i - 1
            txt[0].set_position((x[0], y[0]))
            for label, px, py in zip(txt[1:], x, y):
                label.set_position((px, py))
        else:
            txt.set_position(
                (mean(x) * .1 + x[0] * .9, mean(y) * .1 + y[0] * .9))
        self.draw()

    def setAllZones(self, dicZone):
        """Rebuild all zones from imported data, then call setUnvisibleZones.

        Args:
            dicZone: mapping variable name -> dict of parallel lists under the
                keys 'name', 'media', 'value' and 'coords'. Entries whose
                name is '' are skipped.
        """
        for var, lz in dicZone.items():
            # addZone appends to the zone, media and label lists of curVar,
            # so all three are reset together to keep zone indices aligned
            self.listeZone[var] = []
            self.listeZmedia[var] = []
            self.listeZoneText[var] = []
            self.curVar = var
            for i, name in enumerate(lz['name']):
                if name == '':
                    continue
                self.addZone(lz['media'][i], name, lz['value'][i],
                             lz['coords'][i])
        self.setUnvisibleZones()

    #####################################################################
    #             Gestion de l'interaction de la souris
    #             pour la creation des zones
    #####################################################################

    #methode executee lors d'un clic de souris dans le canvas
    def mouse_clic(self, evt):
        """Mouse-press handler used while drawing a new zone.

        The first click creates ``self.curZone`` (a ``Line2D``) and connects
        ``mouse_motion``; later clicks append vertices to ``self.x1`` /
        ``self.y1``. POINT zones end on the first click, LINE and RECT zones
        on the second, POLY/POLYV zones on a right click (button 3). For POLYV
        every click also asks for a value through ``polyVdialog``.
        """
        import math

        def short(v, digits=6):
            # Keep about `digits` significant digits so stored coordinates
            # stay short, without ever dropping integer digits. Truncating
            # str(v) to 6 characters would turn 4812345.6 into 481234.0 and
            # 1.2345e-05 into 1.2345.
            v = float(v)
            if v == 0 or math.isinf(v) or math.isnan(v):
                return v
            magnitude = int(math.floor(math.log10(abs(v))))
            return round(v, max(0, digits - 1 - magnitude))

        if evt.inaxes is None:
            return
        if self.curZone is None:  # first vertex
            self.x1 = [short(evt.xdata)]
            self.y1 = [short(evt.ydata)]
            self.setcurZone(Line2D(self.x1, self.y1))
            self.cnv.add_line(self.curZone)
            self.m2 = self.mpl_connect('motion_notify_event',
                                       self.mouse_motion)
            if self.typeZone == "POLYV":
                self.polyVdialog()
            if self.typeZone == "POINT":
                self.deconnecte()
                self.setZoneEnd(evt)
        else:  # following vertices
            if self.typeZone == "POLYV":
                if evt.button == 3:
                    self.deconnecte()
                # ask for the value of this vertex; abandon the click on cancel
                if not self.polyVdialog():
                    return
            self.x1.append(short(evt.xdata))
            self.y1.append(short(evt.ydata))
            if self.typeZone in ("LINE", "RECT"):
                self.deconnecte()  # these zones end with the second point
                self.setZoneEnd(evt)
            if self.typeZone in ("POLY", "POLYV") and evt.button == 3:
                self.deconnecte()  # right click closes the polygon
                self.setZoneEnd(evt)

    #methode executee lors du deplacement de la souris dans le canvas suite a un mouse_clic
    def mouse_motion(self, evt):
        """Rubber-band preview of the zone being drawn (motion handler).

        For RECT zones the outline spanned by the first click and the cursor
        is shown; otherwise the clicked vertices followed by the cursor
        position. The clicked vertex lists themselves are not modified.
        """
        # slows the handler down, presumably to throttle redraws
        time.sleep(0.1)
        if evt.inaxes is None:
            return
        if self.typeZone != "RECT":
            xs = list(self.x1) + [evt.xdata]
            ys = list(self.y1) + [evt.ydata]
            self.curZone.set_data(xs, ys)
        else:
            rx, ry = self.creeRectangle(self.x1[0], self.y1[0],
                                        evt.xdata, evt.ydata)
            self.curZone.set_data(rx, ry)
        self.draw()

    def polyVdialog(self):
        """Ask the user for the value attached to the next POLYV vertex.

        Opens ``config.dialogs.genericDialog`` with a single 'Value' field.
        When the dialog returns values, the first one is converted with
        ``float``, appended to ``self.tempZoneVal`` and True is returned;
        when it returns None (presumably cancelled), False is returned.
        """
        fields = [('Value', 'Text', 0)]
        answer = config.dialogs.genericDialog(self.gui, 'value', fields).getValues()
        if answer is None:
            return False
        self.tempZoneVal.append(float(answer[0]))
        return True

    def creeRectangle(self, x1, y1, x2, y2):
        """Return the closed outline of the axis-aligned rectangle with corners (x1, y1), (x2, y2).

        The result is ``[xs, ys]``: five vertices starting and ending at
        (x1, y1) and passing through (x2, y1), (x2, y2) and (x1, y2).
        """
        corners = [(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)]
        xs = [cx for cx, _ in corners]
        ys = [cy for _, cy in corners]
        return [xs, ys]

    def deconnecte(self):
        """Disconnect the mouse callbacks stored in ``self.m1`` then ``self.m2``."""
        for attr in ('m1', 'm2'):
            self.mpl_disconnect(getattr(self, attr))

    ###################################################################
    #   deplacer une zone ##############################

    def startMoveZone(self, nameVar, ind):
        """Prepare the mouse-driven move of zone ``ind`` of ``nameVar``.

        Stores the zone and its vertices, marks its first vertex with a red
        dot (``self.ptstart``) and connects ``zoneM_clic`` to button presses.
        """
        self.nameVar, self.ind = nameVar, ind
        zone = self.listeZone[nameVar][ind]
        self.curZone = zone
        self.lx = zone.get_xdata()
        self.ly = zone.get_ydata()
        # anchor = first vertex, multiplied by 1. to get a float copy
        self.xstart = self.lx[0] * 1.
        self.ystart = self.ly[0] * 1.
        marker = Line2D([self.xstart], [self.ystart], marker='o',
                        markersize=7, markerfacecolor='r')
        self.ptstart = marker
        self.cnv.add_line(marker)
        self.m1 = self.mpl_connect('button_press_event', self.zoneM_clic)
        self.draw()

    def zoneM_clic(self, evt):
        """Press handler of a pending zone move.

        The drag starts only if the click lies within 1/100 of the diagonal
        of ``self.xlim`` x ``self.ylim`` from the anchor point
        (``self.xstart``, ``self.ystart``). Then ``zone_motion`` and
        ``finMoveZone`` are connected and the press handler ``self.m1`` is
        disconnected.
        """
        if evt.inaxes is None:
            return
        gap = sqrt((evt.xdata - self.xstart)**2 + (evt.ydata - self.ystart)**2)
        x_lo, x_hi = self.xlim
        y_lo, y_hi = self.ylim
        reach = sqrt((x_hi - x_lo)**2 + (y_hi - y_lo)**2) / 100
        if gap > reach:
            return
        self.m2 = self.mpl_connect('motion_notify_event', self.zone_motion)
        self.m3 = self.mpl_connect('button_release_event', self.finMoveZone)
        self.mpl_disconnect(self.m1)

    def zone_motion(self, evt):
        """Translate the zone so that its anchor vertex follows the cursor.

        The new vertices are ``self.lx``/``self.ly`` shifted by the cursor's
        offset from (``self.xstart``, ``self.ystart``). The red marker and
        ``self.curZone`` are updated and the canvas redrawn.
        """
        # slows the handler down, presumably to throttle redraws
        time.sleep(0.1)
        if evt.inaxes is None:
            return
        dx = evt.xdata - self.xstart
        dy = evt.ydata - self.ystart
        lx = [a + dx for a in self.lx]
        ly = [a + dy for a in self.ly]
        # Line2D.set_data expects sequences; scalar data is deprecated since
        # matplotlib 3.7 and rejected by later releases
        self.ptstart.set_data([lx[0]], [ly[0]])
        self.curZone.set_data(lx, ly)
        self.draw()

    def finMoveZone(self, evt):
        """Finish a zone move (``button_release_event`` handler).

        Disconnects the motion/release callbacks, stores the zone's new
        coordinates, reports them to ``self.gui.modifBox.onModifZoneCoord``,
        moves the label(s), hides the red start marker and redraws.
        """
        # the button is released: stop following the mouse
        self.mpl_disconnect(self.m2)
        self.mpl_disconnect(self.m3)
        lx, ly = self.curZone.get_xdata(), self.curZone.get_ydata()
        self.listeZone[self.nameVar][self.ind].set_data(lx, ly)
        # a real list rather than a one-shot zip iterator (Python 3), so the
        # receiver can index it or iterate it more than once
        xy = list(zip(lx, ly))
        self.gui.modifBox.onModifZoneCoord(self.nameVar, self.ind, xy)
        # move the label(s) along with the zone
        txt = self.listeZoneText[self.nameVar][self.ind]
        if isinstance(txt, list):
            # txt[0] is the name/value label; txt[i] labels vertex i - 1
            txt[0].set_position((lx[0], ly[0]))
            for label, px, py in zip(txt[1:], lx, ly):
                label.set_position((px, py))
        else:
            txt.set_position(
                (mean(lx) * .1 + lx[0] * .9, mean(ly) * .1 + ly[0] * .9))
        self.ptstart.set_visible(False)
        self.ptstart = None
        self.curZone = None
        self.draw()
Esempio n. 13
0
def add_phylorate(treeplot, rates, nodeidx, vis=True):
    """
    Add phylorate plot generated from data analyzed with BAMM
    (http://bamm-project.org/introduction.html)

    Every branch is cut into one segment per rate value and the segments are
    drawn as a LineCollection colored with the RdYlBu colormap. In radial
    plots the arc joining a branch to its parent is also drawn, colored by
    the branch's first rate with the same value-to-color mapping as the
    segments. A colorbar legend is added and a redraw is requested.

    Args:
        treeplot: tree plot axes; its ``root``, ``n2c``, ``plottype`` and
          ``_path_to_parent`` are used.
        rates (array): Array of rates along branches
          created by r_funcs.phylorate
        nodeidx (array): Array of node indices matching rates (also created
          by r_funcs.phylorate)
        vis (bool): visibility of the added artists (default True).

    WARNING:
        Ladderizing the tree can cause incorrect assignment of Ape node index
        numbers. To prevent this, call this function or root.ape_node_idx()
        before ladderizing the tree to assign correct Ape node index numbers.
    """
    if not treeplot.root.apeidx:
        treeplot.root.ape_node_idx()
    segments = []
    values = []
    # Radial plots only: (path, first rate of the branch) for each arc. The
    # arc colors are computed once all rates are known, so they can use the
    # same normalization as the segments.
    arcs = []

    if treeplot.plottype == "radial":
        for n in treeplot.root.descendants():
            # boolean mask selects the rates recorded for this node's branch
            n.rates = rates[nodeidx==n.apeidx]
            c = treeplot.n2c[n]
            path = treeplot._path_to_parent(n)  # computed once per node
            verts, codes = path[0], path[1]
            pc = verts[1]
            xseg = (c.x - pc[0])/len(n.rates)
            yseg = (c.y - pc[1])/len(n.rates)
            for i, rate in enumerate(n.rates):
                x0 = pc[0] + i*xseg
                y0 = pc[1] + i*yseg
                segments.append(((x0, y0), (x0 + xseg, y0 + yseg)))
                values.append(rate)
            arcs.append((Path(verts[2:], codes[2:]), n.rates[0]))
    else:
        for n in treeplot.root.descendants():
            n.rates = rates[nodeidx==n.apeidx]
            c = treeplot.n2c[n]
            pc = treeplot.n2c[n.parent]
            seglen = (c.x-pc.x)/len(n.rates)
            for i, rate in enumerate(n.rates):
                x0 = pc.x + i*seglen
                segments.append(((x0, c.y), (x0 + seglen, c.y)))
                values.append(rate)
            # vertical connector at the parent's x, colored by the first rate
            segments.append(((pc.x, pc.y), (pc.x, c.y)))
            values.append(n.rates[0])

    lc = LineCollection(segments, cmap=RdYlBu, lw=2)
    lc.set_array(np.array(values))
    # Fix the color limits to min/max(values) now so lc.norm can also color
    # the arcs. Calling the colormap on raw rates would treat them as [0, 1]
    # fractions, so arc colors would not match the branch segments.
    lc.autoscale()
    treeplot.add_collection(lc)
    lc.set_zorder(1)
    if treeplot.plottype == "radial":
        radpatches = [PathPatch(arcpath, lw=2,
                                edgecolor=lc.cmap(lc.norm(rate)),
                                fill=False)
                      for arcpath, rate in arcs]
        arccol = matplotlib.collections.PatchCollection(radpatches,
                                                        match_original=True)
        treeplot.add_collection(arccol)
        arccol.set_visible(vis)
        arccol.set_zorder(1)
    lc.set_visible(vis)
    colorbar_legend(treeplot, values, RdYlBu, vis=vis)

    treeplot.figure.canvas.draw_idle()