Exemple #1
0
    def add_collection3d(self, col, zs=0, zdir='z'):
        '''
        Add a collection to the plot, lifting 2D collections into 3D.

        A 2D collection is converted in place by the matching art3d
        helper, which receives *zs* and *zdir* unchanged. Conversion only
        happens when ``col`` is exactly one of these types (subclasses
        are not matched):

            - PolyCollection
            - LineCollection
            - PatchCollection

        Anything else is passed to ``Axes.add_collection`` untouched.
        '''
        zvals = np.atleast_1d(zs)
        # FIXME: falling back to 0 for an empty *zs* is fairly arbitrary.
        zsortval = min(zvals) if len(zvals) > 0 else 0

        kind = type(col)
        if kind is collections.PolyCollection:
            convert = art3d.poly_collection_2d_to_3d
        elif kind is collections.LineCollection:
            convert = art3d.line_collection_2d_to_3d
        elif kind is collections.PatchCollection:
            convert = art3d.patch_collection_2d_to_3d
        else:
            convert = None

        if convert is not None:
            convert(col, zs=zs, zdir=zdir)
            # The smallest z value is used as the sort position.
            col.set_sort_zpos(zsortval)

        Axes.add_collection(self, col)
Exemple #2
0
    def draw(self, renderer):
        """Draw the axes after syncing with the parent axes, if there is one.

        When ``self._parent`` is set, its view limits and position are
        copied onto this axes before delegating to ``Axes.draw``.
        """
        parent = self._parent
        if parent is not None:
            self.axes.viewLim.set(parent.viewLim)
            self.set_position(parent.get_position())

        Axes.draw(self, renderer)
Exemple #3
0
class TelescopeEventView(tk.Frame, object):
    """A frame showing the camera view of a single telescope.

    The frame embeds a matplotlib figure containing one axes that fills
    the whole figure (axis decorations switched off) and a ``CameraPlot``
    drawn on that axes.

    Parameters
    ----------
    root
        Parent Tk widget.
    telescope
        Telescope description; stored on the instance and forwarded to
        ``CameraPlot``.
    data
        Initial data forwarded to ``CameraPlot`` (may be ``None``).
    *args, **kwargs
        Extra arguments forwarded to ``CameraPlot``.
    """

    def __init__(self, root, telescope, data=None, *args, **kwargs):
        self.telescope = telescope
        super(TelescopeEventView, self).__init__(root)
        self.figure = Figure(figsize=(5, 5), facecolor='none')
        self.ax = Axes(self.figure, [0, 0, 1, 1], aspect=1)
        self.ax.set_axis_off()
        self.figure.add_axes(self.ax)
        self.camera_plot = CameraPlot(telescope, self.ax, data, *args, **kwargs)
        self.canvas = FigureCanvasTkAgg(self.figure, master=self)
        # ``show()`` was only an alias of ``draw()`` on the Tk canvas and has
        # been deprecated/removed in newer matplotlib; ``draw()`` works in all
        # versions.
        self.canvas.draw()
        # The canvas exposes a single Tk widget; the original packed it twice
        # (the second time through the private ``_tkcanvas`` attribute), so
        # only the last pack options were effective. Pack once with them.
        widget = self.canvas.get_tk_widget()
        widget.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True)
        widget.config(highlightthickness=0)

    @property
    def data(self):
        """Data currently shown by the camera plot."""
        return self.camera_plot.data

    @data.setter
    def data(self, value):
        # Update the plot and redraw the canvas right away.
        self.camera_plot.data = value
        self.canvas.draw()
Exemple #4
0
    def _set_lim_and_transforms(self):
        """
        Set up the data, text and grid transforms with a 30 degree x-skew.

        This is called once when the plot is created.
        """
        rot = 30

        # Begin with the standard transform setup of the Axes base class.
        Axes._set_lim_and_transforms(self)

        # The skew is inserted after scale and limits but before transAxes,
        # so it is applied in Axes coordinates and therefore around the
        # proper origin. The pre-transAxes transform is kept as an attribute
        # for other users (e.g. the spines, to find bounds).
        unskewed = self.transScale + self.transLimits
        self.transDataToAxes = (
            unskewed + transforms.Affine2D().skew_deg(rot, 0))

        # Full transform from data to pixels.
        self.transData = self.transDataToAxes + self.transAxes

        # The blended x-axis transform needs the same skew, again applied in
        # axes coordinates.
        blended = transforms.blended_transform_factory(
            self.transScale + self.transLimits,
            transforms.IdentityTransform())
        skewed_blend = blended + transforms.Affine2D().skew_deg(rot, 0)
        self._xaxis_transform = skewed_blend + self.transAxes
    def plot_metrics_evolution(cls, ax: axes.Axes, y, x=None, **plot_config):
        """Plot one or several metric series on ``ax`` and return the lines.

        Parameters
        ----------
        ax
            Axes to draw on.
        y
            Either a sequence of numbers (one series), or a sequence whose
            elements are lists/tuples; in that case the i-th entries of all
            elements form the i-th series.
        x
            X values shared by all series. Defaults to ``0 .. len(y) - 1``.
        **plot_config
            Plot options, merged over the default ``xlabel='time'`` and
            passed on to ``cls.apply_plot_config``. ``colour`` is applied to
            the line only when a single series is plotted.

        Returns
        -------
        list
            The line objects created, in plotting order.
        """
        # Missing options read as None.
        config = defaultdict(lambda: None)
        config['xlabel'] = 'time'
        config.update(plot_config)

        if x is None:
            x = np.arange(len(y))

        # Look at the first element of *y* to decide between one series and
        # several. An empty *y* has no first element; it is treated as a
        # single (empty) series instead of failing on ``y[0]``.
        multi_series = len(y) > 0 and isinstance(y[0], (list, tuple))

        lines = []
        if multi_series:
            for y_i in zip(*y):
                line, = ax.plot(x, y_i)
                lines.append(line)
        else:
            line, = ax.plot(x, y)
            if config['colour'] is not None:
                line.set_color(config['colour'])
            lines.append(line)

        cls.apply_plot_config(ax, config)

        return lines
    def plot_evolution(cls, ax: axes.Axes, y, x=None, **plot_config):
        """Plot one or several series on ``ax`` and return the lines.

        Parameters
        ----------
        ax
            Axes to draw on.
        y
            Either a sequence of numbers (one series), or a sequence whose
            elements are lists/tuples; in that case the i-th entries of all
            elements form the i-th series.
        x
            X values shared by all series. Defaults to ``0 .. len(y) - 1``.
        **plot_config
            Plot options, merged over the default ``xlabel='time'`` and
            passed on to ``cls.apply_plot_config``.

        Returns
        -------
        list
            The line objects created, in plotting order.
        """
        # Missing options read as None.
        config = defaultdict(lambda: None)
        config['xlabel'] = 'time'
        config.update(plot_config)

        if x is None:
            x = np.arange(len(y))

        # Look at the first element of *y* to decide between one series and
        # several. An empty *y* has no first element; it is treated as a
        # single (empty) series instead of failing on ``y[0]``.
        multi_series = len(y) > 0 and isinstance(y[0], (list, tuple))

        lines = []
        if multi_series:
            for y_i in zip(*y):
                line, = ax.plot(x, y_i)
                lines.append(line)
        else:
            line, = ax.plot(x, y)
            lines.append(line)

        cls.apply_plot_config(ax, config)
        plt.tight_layout()

        return lines
Exemple #7
0
 def set_top_view(self):
     """Set fixed x and y view limits scaled by ``self.dist``.

     Both axes run from ``-0.95 / dist`` to ``0.9 / dist``.
     """
     lower = -(0.95 / self.dist)
     upper = 0.9 / self.dist
     Axes.set_xlim(self, lower, upper)
     Axes.set_ylim(self, lower, upper)
Exemple #8
0
def center_histogram_2d(par1, par2):
    """Draw a 3D wireframe of the 2D histogram of two sampled parameters.

    ``par1`` and ``par2`` are keys into the module-level ``l_dict`` (extent
    of each parameter) and ``ss_dict`` (samples; the first entry of each
    value is histogrammed). Bin edges are centred on zero and span each
    parameter's extent. The wireframe is drawn on figure 22 at the bin
    centres.

    Returns 0.
    """
    fig22 = figure(22)
    ax22 = Axes(fig22, [.1, .1, .8, .8])
    fig22.add_axes(ax22)
    div = 10.
    ll1 = l_dict[par1]
    ll2 = l_dict[par2]
    ss1 = ss_dict[par1]
    ss2 = ss_dict[par2]

    # linspace requires an integer number of points; current NumPy rejects
    # floats, so truncate explicitly (what older versions did implicitly).
    n_edges1 = int(div + 1)
    n_edges2 = int(div * ll1 / ll2 + 1)
    bins = [linspace(-ll1 / 2., ll1 / 2., n_edges1),
            linspace(-ll2 / 2., ll2 / 2., n_edges2)]
    # The ``normed`` keyword no longer exists in NumPy; raw counts are the
    # default, which is what ``normed=0`` asked for.
    H, xedges, yedges = histogram2d(ss1[0], ss2[0], bins=bins)

    ax22 = Axes3D(fig22, [.1, .1, .8, .8])
    # Bin centres: midpoints of consecutive edges.
    x = (xedges[:-1] + xedges[1:]) / 2.
    y = (yedges[:-1] + yedges[1:]) / 2.
    xx = outer(x, ones(len(y)))
    yy = outer(ones(len(x)), y)
    ax22.plot_wireframe(xx, yy, H)

    draw()
    return 0
Exemple #9
0
    def draw(self, renderer):
        """Draw the 3D axes: background, panes, axes, then the artists.

        The projection matrix and view data are attached to ``renderer``
        so artists can project themselves. Collections and patches are
        each given a ``zorder`` equal to their rank when sorted by
        projected depth, largest first.
        """
        # draw the background patch
        self.axesPatch.draw(renderer)
        self._frameon = False

        # add the projection matrix to the renderer
        self.M = self.get_proj()
        renderer.M = self.M
        renderer.vvec = self.vvec
        renderer.eye = self.eye
        renderer.get_axis_position = self.get_axis_position

        # Sort on the projected depth only. Sorting the raw (depth, artist)
        # tuples falls back to comparing the artists when depths tie, which
        # raises TypeError on Python 3.
        def depth_of(pair):
            return pair[0]

        # Calculate projection of collections and zorder them
        projected = [(col.do_3d_projection(renderer), col)
                     for col in self.collections]
        projected.sort(key=depth_of, reverse=True)
        for i, (_, col) in enumerate(projected):
            col.zorder = i

        # Calculate projection of patches and zorder them
        projected = [(patch.do_3d_projection(renderer), patch)
                     for patch in self.patches]
        projected.sort(key=depth_of, reverse=True)
        for i, (_, patch) in enumerate(projected):
            patch.zorder = i

        # All panes are drawn before any of the axes themselves.
        axes = (self.w_xaxis, self.w_yaxis, self.w_zaxis)
        for ax in axes:
            ax.draw_pane(renderer)
        for ax in axes:
            ax.draw(renderer)

        Axes.draw(self, renderer)
Exemple #10
0
    def cla(self):
        """Clear the axes and apply this projection's default styling.
        """
        # Start from the base-class reset, with the grid switched on.
        Axes.cla(self)
        self.grid(True)

        # Only the x-axis is shown, but there are 3 axes once all of the
        # projections are included.
        self.yaxis.set_visible(False)

        # Five evenly spaced ticks from 0 to self.total, and no minor ticks.
        tick_positions = np.linspace(0, self.total, 5)
        self.set_xticks(tick_positions)
        self.xaxis.set_minor_locator(NullLocator())

        # Put the title a little higher than normal.
        self.title.set_y(1.02)

        # Padding between the tick labels and the axis label, in display units.
        self.xaxis.labelpad = 10

        # Spacing from the vertices (tips) to the tip labels, in data
        # coordinates, as a fraction of self.total.
        self.tipoffset = 0.14
Exemple #11
0
    def __init__(self, fig, rect=None, *args, **kwargs):
        """Create the 3D axes on ``fig`` and add it to the figure.

        ``rect`` defaults to the whole figure. ``azim`` (default -60) and
        ``elev`` (default 30) are taken out of ``kwargs`` and passed to
        ``view_init``; everything else goes to ``Axes.__init__``.
        """
        self.fig = fig
        self.cids = []
        if rect is None:
            rect = [0.0, 0.0, 1.0, 1.0]

        azim = kwargs.pop("azim", -60)
        elev = kwargs.pop("elev", 30)

        self.xy_viewLim = unit_bbox()
        self.zz_viewLim = unit_bbox()
        self.xy_dataLim = unit_bbox()
        self.zz_dataLim = unit_bbox()

        # Inhibit autoscale_view until the axes are defined; they cannot be
        # defined until Axes.__init__ has been called, so ``_ready`` stays 0
        # for the duration of the base-class constructor.
        self.view_init(elev, azim)
        self._ready = 0
        Axes.__init__(self, self.fig, rect, *args,
                      frameon=True, xticks=[], yticks=[], **kwargs)

        self.M = None

        self._ready = 1
        self.mouse_init()
        self.create_axes()
        self.set_top_view()

        self.axesPatch.set_linewidth(0)
        self.fig.add_axes(self)
Exemple #12
0
 def cla(self):
     """Clear the axes, dropping every canvas callback recorded in ``cids``."""
     # Disconnect the various events we set, then forget their ids.
     for callback_id in self.cids:
         self.figure.canvas.mpl_disconnect(callback_id)
     self.cids = []

     Axes.cla(self)
     # Grid visibility comes from the 'axes3d.grid' rc setting.
     self.grid(rcParams['axes3d.grid'])
Exemple #13
0
    def draw(self, renderer):
        """Draw the 3D axes: background, the three axes, then the artists.

        The projection matrix and view data are attached to ``renderer``
        so artists can project themselves. Collections and patches are
        ranked by projected depth, largest first, and that rank becomes
        their ``zorder``; a collection carrying ``_force_zorder`` uses that
        value instead.
        """
        # draw the background patch
        self.axesPatch.draw(renderer)
        self._frameon = False

        # add the projection matrix to the renderer
        self.M = self.get_proj()
        renderer.M = self.M
        renderer.vvec = self.vvec
        renderer.eye = self.eye
        renderer.get_axis_position = self.get_axis_position

        # Sort on the projected depth only. Sorting the raw (depth, artist)
        # tuples falls back to comparing the artists when depths tie, which
        # raises TypeError on Python 3.
        def depth_of(pair):
            return pair[0]

        # Calculate projection of collections and zorder them
        projected = [(col.do_3d_projection(renderer), col)
                     for col in self.collections]
        projected.sort(key=depth_of, reverse=True)
        for i, (_, col) in enumerate(projected):
            col.zorder = getattr(col, '_force_zorder', i)

        # Calculate projection of patches and zorder them
        projected = [(patch.do_3d_projection(renderer), patch)
                     for patch in self.patches]
        projected.sort(key=depth_of, reverse=True)
        for i, (_, patch) in enumerate(projected):
            patch.zorder = i

        self.w_xaxis.draw(renderer)
        self.w_yaxis.draw(renderer)
        self.w_zaxis.draw(renderer)
        Axes.draw(self, renderer)
    def cla(self):
        """
        Clear the axes and install this projection's defaults.
        """
        # The base class does the actual clearing.
        Axes.cla(self)

        both_axes = (self.xaxis, self.yaxis)

        # No minor ticks at all.
        for axis in both_axes:
            axis.set_minor_locator(NullLocator())

        # Major ticks come from MaxNLocator(5) with both ends pruned.
        for axis in both_axes:
            axis.set_major_locator(MaxNLocator(5, prune='both'))

        # Hide the tick marks themselves -- only gridlines and text are wanted.
        for axis in both_axes:
            axis.set_ticks_position('none')

        self.set_center(None, None)

        # FIXME: autoscale_view probably needs to be overridden to properly
        # handle the wrapping introduced by the margin and to properly wrap
        # data. It doesn't make sense to have xwidth > 360.
        self._tight = True
def plot_mods(ax: Axes, stats: Stats, most_popular_at_bottom=False, percentage=False):
    """Draw a stacked area plot of player counts per mod over time.

    Parameters
    ----------
    ax
        Axes to draw on.
    stats
        Source of ``players_by_mod`` (per-mod series), ``all_time_players_by_mod``
        (per-mod totals used for ordering and the legend average) and
        ``dates`` (x values).
    most_popular_at_bottom
        When true, mods are stacked from most to least popular; otherwise
        from least to most popular, and the colour list and legend order are
        reversed to match.
    percentage
        When true, each date's values are rescaled to sum to 100 and the
        y-axis is fixed to 0..100.
    """
    # FIXME different colors when using percentage with 10 mods
    mods_by_popularity = sorted(stats.players_by_mod.keys(),
                                key=lambda mod: stats.all_time_players_by_mod[mod],
                                reverse=most_popular_at_bottom)
    labels = ["{} - avg. {:.2f} players".format(mod, stats.all_time_players_by_mod[mod] / len(stats.dates))
              for mod in mods_by_popularity]

    colors = ['red', 'green', 'blue', 'yellow', 'purple', 'lime', 'gray', 'cyan', 'orange', 'deeppink', 'black']
    colors = colors[0:len(mods_by_popularity)]
    if not most_popular_at_bottom:
        # A real list rather than a one-shot ``reversed`` iterator.
        colors = colors[::-1]
    ax.set_prop_cycle(color=colors)

    # np.vstack is the supported spelling; np.row_stack is deprecated in NumPy 2.0.
    all_mods_values = np.vstack([stats.players_by_mod[mod] for mod in mods_by_popularity])
    if percentage:
        # Dates with no players at all give 0/0; silence that warning.
        with np.errstate(invalid='ignore'):
            all_mods_values = all_mods_values / all_mods_values.sum(axis=0) * 100
    ax.stackplot(stats.dates, all_mods_values, labels=labels, linewidth=0.1)

    if percentage:
        ax.set_ylim(bottom=0, top=100)
    else:
        ax.set_ylim(bottom=0)
    decorate_axes(ax)

    handles, labels = ax.get_legend_handles_labels()
    if not most_popular_at_bottom:
        handles = list(reversed(handles))
        labels = list(reversed(labels))
    leg = ax.legend(handles, labels, loc='upper left', prop={'size': 10})
    leg.get_frame().set_alpha(0.5)
Exemple #16
0
    def __init__(self, fig, rect=None, **kwargs):
        """Create the axes on ``fig`` and set up ticks and limits.

        ``rect`` defaults to ``[0.1, 0.1, 0.8, 0.8]``. ``season`` (default 1)
        and ``ra_direction`` (default 1) are taken out of ``kwargs``; the
        remaining keyword arguments go to ``Axes.__init__``. The initial
        limits span +/-11 around ``center[0]`` and +/-9 around ``center[1]``.
        """
        # Identity test: ``rect == None`` is elementwise (and ambiguous in a
        # boolean context) when an array is passed.
        if rect is None:
            rect = [0.1, 0.1, 0.8, 0.8]
        self.season = kwargs.pop('season', 1)
        self.ra_direction = kwargs.pop('ra_direction', 1)
        # center coordinate in (RA, Dec)
        self.center = get_kepler_center(season=self.season)
        # self.center must be set before Axes.__init__, because Axes.__init__
        # calls axes.cla(), which in turn calls set_ylim(), which is
        # overridden in this class.

        self.ref_dec1, self.ref_dec2 = 20., 50.

        Axes.__init__(self, fig, rect, **kwargs)

        # Prepare for drawing the ticks on the top & right axis.
        # NOTE: these assignments replace the bound methods on this instance
        # with the twin axes they return.
        self.twiny = self.twiny()
        self.twinx = self.twinx()

        self.set_ra_ticks(5)
        self.set_dec_ticks(5)

        self.set_lim(self.center[0] - 11, self.center[0] + 11,
                     self.center[1] - 9, self.center[1] + 9)
Exemple #17
0
 def __init__(self, fig, rect, *args, **kwargs):
     """Create the axes for an alignment.

     Keyword arguments consumed here: ``aln`` (required, the alignment),
     ``app`` (default None) and ``showy`` (default True). Everything else
     is forwarded to ``Axes.__init__``.
     """
     self.aln = kwargs.pop("aln")
     ncols = self.aln.get_alignment_length()
     self.alnidx = numpy.arange(ncols)
     self.app = kwargs.pop("app", None)
     self.showy = kwargs.pop('showy', True)
     Axes.__init__(self, fig, rect, *args, **kwargs)

     # Character -> RGB colour; any character not listed maps to gray.
     rgb = mpl_colors.colorConverter.to_rgb
     gray = rgb('gray')
     cmap = defaultdict(lambda: gray)
     for base, colour in (("A", "red"), ("C", "blue"),
                          ("G", "green"), ("T", "yellow")):
         cmap[base] = cmap[base.lower()] = rgb(colour)
     self.cmap = cmap

     self.selector = RectangleSelector(
         self, self.select_rectangle, useblit=True
         )

     # Ignore every event that is not from button 1; for button 1 defer to
     # the selector's standard ``ignore`` logic.
     def ignore_other_buttons(event):
         if event.button != 1:
             return True
         return RectangleSelector.ignore(self.selector, event)
     self.selector.ignore = ignore_other_buttons

     self.selected_rectangle = Rectangle(
         [0, 0], 0, 0, facecolor='white', edgecolor='cyan', alpha=0.3
         )
     self.add_patch(self.selected_rectangle)
     self.highlight_find_collection = None
Exemple #18
0
 def cla(self):
     """
     Clear the axes and remove every tick, minor and major, on both axes.
     """
     Axes.cla(self)
     both_axes = (self.xaxis, self.yaxis)
     for axis in both_axes:
         axis.set_minor_locator(NullLocator())
     for axis in both_axes:
         axis.set_major_locator(NullLocator())
Exemple #19
0
    def grid(self, b=None, which='major', axis='both', **kwargs):
        """Configure the grid and record the choice on the plot manager.

        All arguments are passed to ``Axes.grid``. Afterwards, if the active
        figure has a plot handler and ``is_gui()`` is false, ``b`` is stored
        on the handler's manager as ``_xgrid`` (``axis == 'x'``) or
        ``_ygrid`` (``axis == 'y'``). Any other ``axis`` value, including the
        default ``'both'``, stores nothing.
        """
        Axes.grid(self, b, which, axis, **kwargs)

        plot_handler = GlobalFigureManager.get_active_figure().plot_handler
        if plot_handler is None or is_gui():
            return

        if axis == 'x':
            plot_handler.manager._xgrid = b
        elif axis == 'y':
            plot_handler.manager._ygrid = b
    def __init__(self, *args, **kwargs):
        """Initialise the axes with all projection parameters unset.

        ``ra_0``, ``dec_0``, ``dec_1`` and ``dec_2`` start as ``None`` before
        the base-class constructor runs; the axes is then cleared.
        """
        self.ra_0 = self.dec_0 = self.dec_1 = self.dec_2 = None

        Axes.__init__(self, *args, **kwargs)

        self.cla()
def display_top_down(arrX, arrY):
    """Draw the points ``(arrX, arrY)`` on figure 1.

    The points are added as a line to a new axes on the figure and then
    plotted again through pyplot as red circles.
    """
    fig = plt.figure(1)
    ax = Axes(fig, [.1, .1, .8, .8])
    fig.add_axes(ax)
    ax.add_line(Line2D(arrX, arrY))

    plt.plot(arrX, arrY, 'ro')
Exemple #22
0
 def draw(self, *args):
     '''
     Refresh the data transform, then draw through Axes.draw().

     Updating the transform on every draw lets resizes be handled
     properly without registering callbacks; the amount of work done
     before the base-class call is kept to a minimum.
     '''
     self._update_data_transform()
     Axes.draw(self, *args)
Exemple #23
0
    def __init__(self, *args, **kwargs):
        """
        Create a new Polar Axes for a polar plot.

        All arguments go to ``Axes.__init__``. The aspect is then fixed to
        'equal' (adjustable 'box', anchored at 'C') and the axes cleared.
        """
        # Must exist before Axes.__init__ runs.
        self._rpad = 0.05

        Axes.__init__(self, *args, **kwargs)
        self.set_aspect('equal', adjustable='box', anchor='C')
        self.cla()
Exemple #24
0
 def cla(self):
     """Clear the axes and restore the polar defaults."""
     Axes.cla(self)
     self.title.set_y(1.05)

     # Theta formatter, with theta grid lines every 45 degrees from 0.
     self.xaxis.set_major_formatter(self.ThetaFormatter())
     self.set_thetagrids(npy.arange(0.0, 360.0, 45.0))

     # Wrap whatever major locator the y-axis currently has.
     base_locator = self.yaxis.get_major_locator()
     self.yaxis.set_major_locator(self.RadialLocator(base_locator))

     self.grid(rcParams['polaraxes.grid'])

     # No tick marks on either axis.
     for axis in (self.xaxis, self.yaxis):
         axis.set_ticks_position('none')
Exemple #25
0
    def set_top_view(self):
        """Set fixed x and y view limits scaled by ``self.dist``.

        Both axes run from ``-0.95 / dist`` to ``0.9 / dist``; ``auto=None``
        is passed to both limit setters.
        """
        # This happens to be the right view for the viewing coordinates,
        # moved up and to the left slightly to fit labels and axes.
        lower = -(0.95 / self.dist)
        upper = 0.9 / self.dist

        Axes.set_xlim(self, lower, upper, auto=None)
        Axes.set_ylim(self, lower, upper, auto=None)
Exemple #26
0
    def __init__(self, *args, **kwargs):
        """
        Create a new Polar Axes for a polar plot.

        ``resolution`` is taken out of the keyword arguments (defaulting to
        the class-level ``RESOLUTION``); everything else goes to
        ``Axes.__init__``. The aspect is then fixed to 'equal' (adjustable
        'box', anchored at 'C') and the axes cleared.
        """
        # Both attributes must exist before Axes.__init__ runs.
        self._rpad = 0.05
        self.resolution = kwargs.pop('resolution', self.RESOLUTION)

        Axes.__init__(self, *args, **kwargs)
        self.set_aspect('equal', adjustable='box', anchor='C')
        self.cla()
Exemple #27
0
    def plot(self, *args, **kwargs):
        """Plot data, handling hfarrays, projections, complex data and units.

        Up to the first two positional arguments are treated as data; a
        string in second position is treated as a further plot argument
        instead. ``projection`` may be passed as a keyword and defaults to
        ``self.name``. Remaining arguments are passed on to the underlying
        plot call.

        Raises ``Exception`` when no data argument is given.
        """
        projection = (kwargs.pop("projection") if "projection" in kwargs
                      else self.name)
        data, args = args[:2], args[2:]

        # A string in second position is not data: move it to the front of
        # the pass-through arguments.
        if len(data) == 2 and isinstance(data[1], (str, unicode)):
            args = (data[1],) + args
            data = data[:1]

        # A lone hfarray with at least one dimension supplies its own x
        # values from its first dimension.
        if (len(data) == 1 and isinstance(data[0], hfarray)
                and len(data[0].dims) >= 1):
            y = data[0]
            x = hfarray(y.dims[0])
            data = (x, y)
            if self.HFTOOLS_default_x_name is None:
                # First such plot: remember the dimension name and let the
                # x-axis formatter pick it up if it supports templates.
                self.HFTOOLS_default_x_name = y.dims[0].name
                formatter = self.axes.xaxis.get_major_formatter()
                if hasattr(formatter, "update_template"):
                    formatter.default_label = self.HFTOOLS_default_x_name
                    formatter.update_template()

        if not data:
            raise Exception("Missing plot data")

        if len(data) == 1:
            y = data[0]
            if projection in _projfun:
                # NOTE(review): only the projected y is plotted here; the
                # projected x is discarded -- confirm this is intended.
                _, y = _projfun[projection](None, y)
                return Axes.plot(self, y, *args, **kwargs)
            if np.iscomplexobj(y):
                return Axes.plot(self, y.real, y.imag, *args, **kwargs)
            return Axes.plot(self, y, *args, **kwargs)

        # Two data arguments.
        x, y = data
        xunit = getattr(x, "unit", None)
        yunit = getattr(y, "unit", None)

        if projection in _projfun:
            plot_x, plot_y = _projfun[projection](x, y)
        elif np.iscomplexobj(y):
            # Real and imaginary parts share y's unit.
            xunit = yunit
            plot_x, plot_y = y.real, y.imag
        else:
            plot_x, plot_y = x, y
        lines = self._plot_helper(plot_x, plot_y, *args, **kwargs)

        if xunit:
            self.set_xlabel_unit(xunit)
        if yunit:
            self.set_ylabel_unit(yunit)
        return lines
Exemple #28
0
    def __init__(self, *args, **kwargs):
        """
        Create a new Polar Axes for a polar plot.

        ``theta_offset`` (default 0), ``theta_direction`` (default 1) and
        ``rlabel_position`` (default 22.5) are taken out of the keyword
        arguments and stored as the ``_default_*`` attributes; everything
        else goes to ``Axes.__init__``. The aspect is then fixed to 'equal'
        (adjustable 'box', anchored at 'C') and the axes cleared.
        """
        pop = kwargs.pop
        self._default_theta_offset = pop('theta_offset', 0)
        self._default_theta_direction = pop('theta_direction', 1)
        self._default_rlabel_position = pop('rlabel_position', 22.5)

        Axes.__init__(self, *args, **kwargs)
        self.set_aspect('equal', adjustable='box', anchor='C')
        self.cla()
Exemple #29
0
 def __init__(self, *args, **kwargs):
     """Create the axes from buffer-related keyword arguments.

     Required keywords: ``memmap``, ``sample_rate``, ``num_samples`` and
     ``limits`` (unpacked into ``_get_buffer_bounds``). Optional keywords:
     ``scale`` and ``pixel_density``, defaulting to the ``default_scale``
     and ``default_pixel_density`` attributes. Everything else goes to
     ``Axes.__init__``, after which the buffer is reloaded.
     """
     # Required settings; a missing one raises KeyError.
     for required in ('memmap', 'sample_rate', 'num_samples'):
         setattr(self, '_' + required, kwargs.pop(required))
     self._limits = self._get_buffer_bounds(*kwargs.pop('limits'))

     # Optional settings fall back to the defaults found on the instance.
     self._scale = kwargs.pop('scale', self.default_scale)
     self._pixel_density = kwargs.pop('pixel_density',
                                      self.default_pixel_density)

     Axes.__init__(self, *args, **kwargs)
     self._reload_buffer()
Exemple #30
0
 def draw(self, renderer, *args, **kwargs):
     """Draw ``self.axes``, reusing the stored capture when possible.

     With no capture yet, or when disabled, the axes is drawn through
     ``Axes.draw`` and a new ``RenderCapture`` is stored. Otherwise only
     the background patch, the stored capture, both axis objects and the
     spines are drawn.
     """
     needs_full_draw = self._capture is None or not self._enabled
     if needs_full_draw:
         Axes.draw(self.axes, renderer, *args, **kwargs)
         self._capture = RenderCapture(self.axes, renderer)
         return

     axes = self.axes
     axes.axesPatch.draw(renderer, *args, **kwargs)
     self._capture.draw(renderer, *args, **kwargs)
     axes.xaxis.draw(renderer, *args, **kwargs)
     axes.yaxis.draw(renderer, *args, **kwargs)
     for spine in axes.spines.values():
         spine.draw(renderer, *args, **kwargs)
Exemple #31
0
def plot_observation(ax: Axes, field: Field, Z: dict, x: int, y: int):
    """Draw an observation map and the robot position on ``ax``.

    Parameters
    ----------
    ax
        Axes to draw on.
    field
        Provides the grid size as ``field.M`` (x extent) and ``field.N``
        (y extent).
    Z
        Maps ``(i, j)`` grid cells to an observed value; cells missing from
        the mapping are shown as 0.5.
    x, y
        Grid cell of the robot; it is marked at the cell centre.
    """
    # Observation image, indexed [i, j]; unobserved cells default to 0.5.
    Z_img = np.array(
        [[Z.get((i, j), 0.5) for j in range(field.N)]
         for i in range(field.M)],
        dtype=float,
    ).reshape(field.M, field.N)

    # plot the observation (transposed so that i runs along x)
    ax.imshow(Z_img.T,
              extent=(0, field.M, 0, field.N),
              origin="lower",
              alpha=0.5,
              cmap="Reds")

    # for making the entire grid show up each time:
    ax.set_xlim(0, field.M)
    # The y-limit was previously set through a second set_xlim call, which
    # left y unset and overwrote the x-limit with N.
    ax.set_ylim(0, field.N)
    ax.grid(True)
    ax.set_xticks(range(0, field.M + 1))
    ax.set_yticks(range(0, field.N + 1))
    ax.set_title("Observation")
    # plot the limits
    limits_x = [0, 0, field.M, field.M]
    limits_y = [0, field.N, 0, field.N]
    ax.scatter(limits_x, limits_y, s=10, c="k", marker="+")
    # plot the robots location with black 'X'
    ax.scatter([x + 0.5], [y + 0.5], s=20, c="k", marker="x")
Exemple #32
0
 def set_rticks(self, *args, **kwargs):
     """Forward all arguments to ``Axes.set_yticks`` and return its result."""
     ticks = Axes.set_yticks(self, *args, **kwargs)
     return ticks
Exemple #33
0
def plot_field(ax: Axes, field: Field, light: bool = True):
    """Draw the field on ``ax``: lighting, objects, lights, obstacles, aliasing.

    Parameters
    ----------
    ax
        Axes to draw on.
    field
        Provides the grid size (``M`` by ``N``), ``light_grid``, ``objects``,
        ``lights`` and ``obstacles`` (either of the last two may be ``None``)
        and ``alias``, a mapping from grid cell to direction vector.
    light
        When true, the lighting image and the light sources are drawn too.
    """
    # plot the lighting
    if light:
        ax.imshow(field.light_grid.T,
                  extent=(0, field.M, 0, field.N),
                  origin="lower",
                  alpha=0.5,
                  cmap="Greys_r")

    # for making the entire grid show up each time:
    ax.set_xlim(0, field.M)
    # The y-limit was previously set through a second set_xlim call, which
    # left y unset and overwrote the x-limit with N.
    ax.set_ylim(0, field.N)
    ax.grid(True)
    ax.set_xticks(range(0, field.M + 1))
    ax.set_yticks(range(0, field.N + 1))

    # plot the limits
    limits_x = [0, 0, field.M, field.M]
    limits_y = [0, field.N, 0, field.N]
    ax.scatter(limits_x, limits_y, s=10, c="k", marker="+")

    # plot the objects locations with black 'X' (at cell centres)
    objects_x = [(obj[0] + 0.5) for obj in field.objects]
    objects_y = [(obj[1] + 0.5) for obj in field.objects]
    ax.scatter(objects_x, objects_y, s=20, c="k", marker="x")

    # plot the light sources
    if light and field.lights is not None:
        lights_x = [(source[0] + 0.5) for source in field.lights]
        lights_y = [(source[1] + 0.5) for source in field.lights]
        ax.scatter(lights_x, lights_y, s=20, c="y", marker="d")

    # plot the blocking objects locations
    if field.obstacles is not None:
        block_x = [(block[0] + 0.5) for block in field.obstacles]
        block_y = [(block[1] + 0.5) for block in field.obstacles]
        ax.scatter(block_x, block_y, s=30, c="k", marker="s")

    # plot the aliasing and the aliasing directions
    alias_poses = list(field.alias.keys())
    if alias_poses:
        # Guarded: with no aliases the direction array is one-dimensional
        # and the column indexing below would raise IndexError.
        alias_x = [k[0] + 0.5 for k in alias_poses]
        alias_y = [k[1] + 0.5 for k in alias_poses]
        directions = np.array(list(field.alias.values()))
        ax.scatter(alias_x, alias_y, s=20, c="r", marker="*")
        ax.quiver(alias_x,
                  alias_y,
                  directions[:, 0],
                  directions[:, 1],
                  scale=0.5,
                  scale_units='xy',
                  angles='xy',
                  color='r',
                  alpha=0.5)
Exemple #34
0
 def set_xlim(self, *args, **kwargs):
     """Pin both view limits; the arguments are ignored.

     x is fixed to [-pi, pi] and y to [-pi/2, pi/2].
     """
     half_turn = np.pi
     Axes.set_xlim(self, -half_turn, half_turn)
     Axes.set_ylim(self, -half_turn / 2.0, half_turn / 2.0)
Exemple #35
0
 def setup_axes(self, ax_PR: Axes, ax_delay_P: Axes, ax_delay_R: Axes):
     """Configure limits, tick formatting, tick sides and labels of the three axes.

     All limits span ``zoom_from - lim_offset`` to ``1 + lim_offset``, and
     every axis uses the ``fraction`` formatter.
     """
     lims = (self.zoom_from - self.lim_offset, 1 + self.lim_offset)

     # Main panel: ticks on the top and right.
     ax_PR.set_xlim(lims)
     ax_PR.set_ylim(lims)
     # ax_PR.set_aspect("equal") is deliberately not used: it unsynchs the
     # axes widths. Keep the aspect ratio approximately equal via figsize.
     for axis in (ax_PR.xaxis, ax_PR.yaxis):
         axis.set_major_formatter(fraction)
     ax_PR.xaxis.tick_top()
     ax_PR.yaxis.tick_right()

     # Precision panel: x label and ticks on top.
     for axis in (ax_delay_P.yaxis, ax_delay_P.xaxis):
         axis.set_major_formatter(fraction)
     ax_delay_P.xaxis.set_label_position("top")
     ax_delay_P.xaxis.tick_top()
     ax_delay_P.set_ylim(lims)
     ax_delay_P.set_ylabel("Precision")
     ax_delay_P.set_xlabel("Detection latency")

     # Recall panel: y label and ticks on the right.
     for axis in (ax_delay_R.xaxis, ax_delay_R.yaxis):
         axis.set_major_formatter(fraction)
     ax_delay_R.yaxis.set_label_position("right")
     ax_delay_R.yaxis.tick_right()
     ax_delay_R.set_xlim(lims)
     ax_delay_R.set_xlabel("Recall")
     ax_delay_R.set_ylabel("Detection latency")
Exemple #36
0
    def plot_group(keys, values, ax: Axes, **kwds):
        """Draw one boxplot per entry of ``values`` on ``ax``, labelled by ``keys``.

        ``xlabel`` and ``ylabel`` are taken out of ``kwds`` and applied to
        the axes; the remaining keywords go to ``ax.boxplot``. What is
        returned depends on the enclosing ``return_type``: the boxplot dict
        for "dict", a ``BoxPlot.BP`` for "both", otherwise the axes.
        """
        # GH 45465: xlabel/ylabel need to be popped out before plotting happens
        xlabel = kwds.pop("xlabel", None)
        ylabel = kwds.pop("ylabel", None)
        if xlabel:
            ax.set_xlabel(pprint_thing(xlabel))
        if ylabel:
            ax.set_ylabel(pprint_thing(ylabel))

        keys = [pprint_thing(key) for key in keys]
        values = [
            np.asarray(remove_na_arraylike(group), dtype=object)
            for group in values
        ]
        bp = ax.boxplot(values, **kwds)
        if fontsize is not None:
            ax.tick_params(axis="both", labelsize=fontsize)

        # GH 45465: x/y are flipped when "vert" changes
        is_vertical = kwds.get("vert", True)
        ticks = ax.get_xticks() if is_vertical else ax.get_yticks()
        if len(ticks) != len(keys):
            # The tick count must be a whole multiple of the label count;
            # the labels are repeated to cover every tick.
            repeats, remainder = divmod(len(ticks), len(keys))
            assert remainder == 0, remainder
            keys = keys * repeats

        set_ticklabels = ax.set_xticklabels if is_vertical else ax.set_yticklabels
        set_ticklabels(keys, rotation=rot)
        maybe_color_bp(bp, **kwds)

        # Return axes in multiplot case, maybe revisit later # 985
        if return_type == "dict":
            return bp
        if return_type == "both":
            return BoxPlot.BP(ax=ax, lines=bp)
        return ax
Exemple #37
0
def plot_fit(vis_data: ndarray,
             ax: Axes,
             xy_kwargs: dict = None,
             xycalc_kwargs: dict = None,
             xydiff_kwargs: dict = None,
             xyzero_kwargs: dict = None,
             fill_kwargs: dict = None,
             yzero: ndarray = None):
    """Visualize the fit.

    Draws four curves on ``ax``: the data (open circles), the fit (solid),
    the zero line of the residuals (dashed) and the residuals shifted onto
    that zero line (solid).  Optionally fills the area between the shifted
    residuals and the zero line.

    Parameters
    ----------
    vis_data : ndarray
        The data to visualize. Must be 2D with at least three rows; the first
        three rows are independent variable, dependent variable, and fitting.

    ax : Axes
        The axes to plot on.

    xy_kwargs : dict
        The kwargs for plotting the y v.s. x curve.

    xycalc_kwargs : dict
        The kwargs for plotting the ycalc v.s. x curve.

    xydiff_kwargs : dict
        The kwargs for plotting the ydiff v.s. x curve.

    xyzero_kwargs : dict
        The kwargs for plotting the yzero v.s. x curve.

    fill_kwargs : dict
        The kwargs for filling in the area between ydiff and yzero. The
        special key 'fill' (default True) switches the filling on or off and
        is not forwarded to ``fill_between``. The dict passed in is never
        modified.

    yzero : ndarray
        The base line corresponding to the zero value in ydiff. Computed by
        ``get_yzero`` when omitted.

    Returns
    -------
    Axes
        The axes that was passed in.

    Raises
    ------
    ValueError
        If ``vis_data`` is not 2D or has fewer than three rows.
    """
    # use the default value if None
    xyzero_kwargs = {} if xyzero_kwargs is None else xyzero_kwargs
    xydiff_kwargs = {} if xydiff_kwargs is None else xydiff_kwargs
    xycalc_kwargs = {} if xycalc_kwargs is None else xycalc_kwargs
    xy_kwargs = {} if xy_kwargs is None else xy_kwargs
    # 'fill' is popped below, so work on a copy: the caller's dict must stay
    # intact, otherwise a second call with the same dict loses the flag
    fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
    # split data
    if len(vis_data.shape) != 2:
        raise ValueError('Invalid data shape: {}. Need 2D data array.'.format(
            vis_data.shape))
    if vis_data.shape[0] < 3:
        raise ValueError('Invalid data dimension: {}. Need dim >= 3'.format(
            vis_data.shape[0]))
    x, y, ycalc = vis_data[:3]
    ydiff = y - ycalc
    # shift ydiff
    if yzero is None:
        yzero = get_yzero(y, ycalc, ydiff)
    # out-of-place add: an in-place += fails when yzero has a wider dtype
    # than the data (e.g. integer data with a float base line)
    ydiff = ydiff + yzero
    # circle data curve
    _xy_kwargs = {'fillstyle': 'none', 'label': 'data'}
    _xy_kwargs.update(xy_kwargs)
    data_line, = ax.plot(x, y, 'o', **_xy_kwargs)
    # solid calculation curve
    _xycalc_kwargs = {
        'label': 'fit',
        'color': complimentary(data_line.get_color())
    }
    _xycalc_kwargs.update(xycalc_kwargs)
    ax.plot(x, ycalc, '-', **_xycalc_kwargs)
    # dash zero difference curve
    _xyzero_kwargs = {'color': 'grey'}
    _xyzero_kwargs.update(xyzero_kwargs)
    ax.plot(x, yzero, '--', **_xyzero_kwargs)
    # solid shifted difference curve
    _xydiff_kwargs = {'label': 'residuals', 'color': data_line.get_color()}
    _xydiff_kwargs.update(xydiff_kwargs)
    diff_line, = ax.plot(x, ydiff, '-', **_xydiff_kwargs)
    # fill in area between curves
    if fill_kwargs.pop('fill', True):
        _fill_kwargs = {'alpha': 0.4, 'color': diff_line.get_color()}
        _fill_kwargs.update(fill_kwargs)
        ax.fill_between(x, ydiff, yzero, **_fill_kwargs)
    return ax
Exemple #38
0
    def _plot_totals(self, total_barplot_ax: Axes,
                     orientation: Literal['top', 'right']):
        """
        Draw the per-category totals as an annotated bar plot.

        Bars are vertical for ``orientation='top'`` and horizontal for
        ``orientation='right'``.  Every bar gets a count label; counts of
        1000 or more are shown in thousands with a ``k`` suffix.  The grid
        and the axis decorations are switched off at the end.
        """
        params = self.plot_group_extra
        counts_df = params['counts_df']
        if self.categories_order is not None:
            counts_df = counts_df.loc[self.categories_order]

        # an explicit color wins; otherwise use the groupby palette stored
        # in adata.uns when there is one, else a fixed fallback
        color = params['color']
        if color is None:
            palette_key = f'{self.groupby}_colors'
            if palette_key in self.adata.uns:
                color = self.adata.uns[palette_key]
            else:
                color = 'salmon'

        def short_count(amount):
            # abbreviate thousands, otherwise just round to one decimal
            if amount >= 1000:
                return f'{np.round(amount / 1000, decimals=1)}k'
            return np.round(amount, decimals=1)

        if orientation == 'top':
            counts_df.plot(
                kind="bar",
                color=color,
                position=0.5,
                ax=total_barplot_ax,
                edgecolor="black",
                width=0.65,
            )
            bars = total_barplot_ax.patches
            max_y = max(bar.get_height() for bar in bars)

            # add numbers to the top of the bars
            for bar in bars:
                bar.set_x(bar.get_x() + 0.5)
                height = bar.get_height()
                anchor = (bar.get_x() + bar.get_width() / 2.0,
                          height + max_y * 0.05)
                total_barplot_ax.annotate(
                    short_count(height),
                    anchor,
                    ha="center",
                    va="top",
                    xytext=(0, 10),
                    fontsize="x-small",
                    textcoords="offset points",
                )
            total_barplot_ax.set_ylim(0, max_y * 1.4)

        elif orientation == 'right':
            counts_df.plot(
                kind="barh",
                color=color,
                position=-0.3,
                ax=total_barplot_ax,
                edgecolor="black",
                width=0.65,
            )
            bars = total_barplot_ax.patches
            max_x = max(bar.get_width() for bar in bars)

            # add numbers to the right of the bars
            for bar in bars:
                length = bar.get_width()
                anchor = (length, bar.get_y() + bar.get_height())
                total_barplot_ax.annotate(
                    short_count(length),
                    anchor,
                    ha="center",
                    va="top",
                    xytext=(10, 10),
                    fontsize="x-small",
                    textcoords="offset points",
                )
            total_barplot_ax.set_xlim(0, max_x * 1.4)

        total_barplot_ax.grid(False)
        total_barplot_ax.axis("off")
Exemple #39
0
 def __optionCallback(self, key, value, options):
     """Apply one changed configuration option to the matplotlib axes.

     Dispatches on ``key``: axis limits, real size and margins, 3D view
     angles, grid, widget size, title, colors, decorations, symbols,
     labels, axis labels and tick counts.  Unknown keys are ignored.

     key     -- name of the option that changed
     value   -- new value of that option
     options -- not read here; only handed on when a limit change
                re-applies the matching tick option
     """
     if key in [
             "minx", "maxx", "miny", "maxy", "minz", "maxz", "realwidth",
             "realheight", "azimuth", "elevation", "left_margin",
             "right_margin", "top_margin", "bottom_margin"
     ]:
         # all of these move things around, so labels must be redrawn
         self.redrawlabels = 1
         if key[:3] in ["min", "max"]:
             # key[3] is the axis letter: 'x', 'y' or 'z'
             minc = self.cget("min" + key[3])
             maxc = self.cget("max" + key[3])
             if minc < maxc:
                 func = None
                 if self.ax is self.ax3d:
                     func = getattr(self.ax, "set_" + key[3] + "lim3d")
                     self._cur_lims = (Axes.get_xlim(self.ax),
                                       Axes.get_ylim(self.ax))
                 elif key[3] != 'z':
                     func = getattr(self.ax, "set_" + key[3] + "lim")
                 if func is not None:
                     func(minc, maxc)
                     # fixed tick positions depend on the limits, so
                     # recompute them when a tick count is configured
                     tickskey = key[3] + "ticks"
                     ticksval = self.cget(tickskey)
                     if ticksval is not None:
                         self.__optionCallback(tickskey, ticksval, options)
         elif key == "realwidth":
             lm = float(self.cget("left_margin"))
             rm = float(self.cget("right_margin"))
             self.ax.get_figure().subplots_adjust(left=lm / value,
                                                  right=1 - rm / value)
         elif key == "realheight":
             tm = float(self.cget("top_margin"))
             bm = float(self.cget("bottom_margin"))
             self.ax.get_figure().subplots_adjust(top=1 - tm / value,
                                                  bottom=bm / value)
         elif key == "left_margin":
             fig = self.ax.get_figure()
             width = fig.get_figwidth() * fig.get_dpi()
             fig.subplots_adjust(left=value / width)
         elif key == "right_margin":
             fig = self.ax.get_figure()
             width = fig.get_figwidth() * fig.get_dpi()
             fig.subplots_adjust(right=1 - value / width)
         elif key == "top_margin":
             fig = self.ax.get_figure()
             height = fig.get_figheight() * fig.get_dpi()
             fig.subplots_adjust(top=1 - value / height)
         elif key == "bottom_margin":
             fig = self.ax.get_figure()
             height = fig.get_figheight() * fig.get_dpi()
             fig.subplots_adjust(bottom=value / height)
         elif self.ax is self.ax3d:
             # remaining keys here are "azimuth" and "elevation"
             elev = self.cget("elevation")
             azim = self.cget("azimuth")
             if elev is not None or azim is not None:
                 self.ax.view_init(elev, azim)
     elif key == "grid":
         if value in ["yes", True]:
             self.ax.grid(color=self.cget("foreground"))
         else:
             self.ax.grid(False)
     elif key in ["width", "height"]:
         if isinstance(self.canvas, FigureCanvasTkAggRedraw):
             self.canvas.get_tk_widget()[key] = value
         else:
             fig = self.ax.get_figure()
             if key == "width":
                 fig.set_figwidth(float(value) / fig.get_dpi())
                 self._configNoDraw(realwidth=value)
             else:
                 fig.set_figheight(float(value) / fig.get_dpi())
                 self._configNoDraw(realheight=value)
     elif key == "top_title":
         fontsize = self.cget("top_title_fontsize")
         if fontsize is None:
             self.ax.set_title(value)
         else:
             self.ax.set_title(value, fontsize=fontsize)
     elif key == "top_title_fontsize":
         title = self.cget("top_title")
         if title is not None:
             self.ax.set_title(title, fontsize=value)
     elif key in ["background", "bg"]:
         # set_axis_bgcolor was removed in Matplotlib 2.2; use its
         # replacement when available and fall back for old releases
         if hasattr(self.ax, "set_facecolor"):
             self.ax.set_facecolor(value)
         else:
             self.ax.set_axis_bgcolor(value)
     elif key in ["foreground", "fg"]:
         matplotlib.rcParams["axes.edgecolor"] = self.cget("foreground")
         self.redrawlabels = 1
         if self.cget("grid") in ["yes", True]:
             self.ax.grid(color=value)
     elif key == "color_list":
         color_list = value.split()
         i = 0
         for d in self.data:
             # the running index advances for every curve that starts a
             # new section (or does not say)
             if d["newsect"] is None or d["newsect"]:
                 i = i + 1
             if d["color"] is None:
                 color = i
             else:
                 color = d["color"]
             d["mpline"].set_color(color_list[color % len(color_list)])
     elif key == "decorations":
         if value:
             self.ax.set_axis_on()
         else:
             self.ax.set_axis_off()
     elif key == "use_symbols":
         self.plotsymbols()
     elif key == "use_labels":
         self.plotlabels()
     elif key in ["xlabel", "ylabel", "zlabel"]:
         if value is None:
             value = ""
         fontsize = self.cget(key + "_fontsize")
         # a 2D axes has no set_zlabel, hence the hasattr guard
         if hasattr(self.ax, "set_" + key):
             func = getattr(self.ax, "set_" + key)
             if fontsize is None:
                 func(value)
             else:
                 func(value, fontsize=fontsize)
     elif key in ["xlabel_fontsize", "ylabel_fontsize", "zlabel_fontsize"]:
         # key[:6] is the matching label option, e.g. "xlabel"
         label = self.cget(key[:6])
         if hasattr(self.ax, "set_" + key[:6]):
             func = getattr(self.ax, "set_" + key[:6])
             if value is None:
                 func(label)
             else:
                 func(label, fontsize=value)
     elif key in ["xticks", "yticks", "zticks"]:
         if value is None:
             # no tick count configured: go back to automatic ticks
             if self.ax is self.ax3d:
                 axis = getattr(self.ax, "w_" + key[0] + "axis")
                 axis.set_major_locator(AutoLocator())
             elif key == "xticks":
                 self.ax.set_xscale('linear')
             else:
                 self.ax.set_yscale('linear')
         else:
             lower = self.cget("min" + key[0])
             upper = self.cget("max" + key[0])
             # `value` evenly spaced ticks from lower to upper.  A single
             # tick has no interval to divide by; it is placed at the
             # lower limit instead of raising ZeroDivisionError.
             intervals = float(value - 1) if value > 1 else 1.0
             ticks = [
                 lower + ((upper - lower) * i) / intervals
                 for i in range(value)
             ]
             if self.ax is self.ax3d:
                 axis = getattr(self.ax, "w_" + key[0] + "axis")
                 axis.set_major_locator(FixedLocator(ticks))
             elif key == "xticks":
                 self.ax.set_xticks(ticks)
             elif key == "yticks":
                 self.ax.set_yticks(ticks)
Exemple #40
0
    def plot(self, ax: axes.Axes) -> None:
        """Draw every quote as a candlestick (body plus shadow) on ``ax``.

        Each row of ``self._quotes`` becomes two rectangles centred on its
        row index: a body spanning open/close and a shadow spanning
        low/high.  Rectangles thinner than ``self._minimum_height`` are
        stretched around their midpoint so they stay visible.  Bodies and
        shadows are added to ``ax`` as two patch collections.
        """

        def widen(top, bottom):
            # stretch a too-thin span symmetrically to the minimum height
            if abs(top - bottom) < self._minimum_height:
                centre = (top + bottom) / 2.0
                half = self._minimum_height / 2.0
                return centre + half, centre - half
            return top, bottom

        count = len(self._quotes)
        bodies = np.ndarray(shape=count, dtype=object)
        shadows = np.ndarray(shape=count, dtype=object)

        for index, quote in enumerate(self._quotes.itertuples()):
            p_open, p_close = quote.open, quote.close

            if p_open > p_close:
                body_span = (p_open, p_close)
            else:
                body_span = (p_close, p_open)
            body_top, body_bottom = widen(*body_span)
            shadow_top, shadow_bottom = widen(quote.high, quote.low)

            if p_close > p_open:
                colour = self._color_up
            elif p_close < p_open:
                colour = self._color_down
            else:
                colour = self._color_unchanged

            shadows[index] = patches.Rectangle(
                xy=(index - (self._shadow_width / 2.0), shadow_bottom),
                width=self._shadow_width,
                height=shadow_top - shadow_bottom,
                facecolor=colour,
                edgecolor=colour,
            )
            bodies[index] = patches.Rectangle(
                xy=(index - (self._body_width / 2.0), body_bottom),
                width=self._body_width,
                height=body_top - body_bottom,
                facecolor=colour,
                edgecolor=colour,
            )

        for rectangles in (bodies, shadows):
            ax.add_collection(
                PatchCollection(rectangles,
                                match_original=True,
                                zorder=self._zorder))
Exemple #41
0
 def _post_plot_logic(self, ax: Axes, data):
     """Put the "Frequency" label on x for horizontal plots, else on y."""
     horizontal = self.orientation == "horizontal"
     set_label = ax.set_xlabel if horizontal else ax.set_ylabel
     set_label("Frequency")
Exemple #42
0
 def _in_axes(self, mouseevent):
     """Report whether ``mouseevent`` counts as inside these axes.

     While the instance carries a ``_pan_trans`` attribute every event is
     treated as inside; otherwise the decision is left to
     ``Axes._in_axes``.
     """
     # hasattr needs the object AND the attribute name; the one-argument
     # form raised TypeError on every call
     if hasattr(self, '_pan_trans'):
         return True
     return Axes._in_axes(self, mouseevent)
Exemple #43
0
 def set_ylim(self, *args, **kwargs):
     """Set the y limits through ``Axes.set_ylim``, then update the affine."""
     # the base-class result is intentionally not returned
     Axes.set_ylim(self, *args, **kwargs)
     self._update_affine()
Exemple #44
0
 def set_yscale(self, *args, **kwargs):
     """Accept only the 'linear' scale and pass it on to ``Axes.set_yscale``.

     Raises ``NotImplementedError`` when the first positional argument is
     anything other than ``'linear'``.
     """
     requested = args[0]
     if requested != 'linear':
         raise NotImplementedError
     Axes.set_yscale(self, *args, **kwargs)
Exemple #45
0
def plot_origin(ax: Axes, origin: np.array, **kwargs):
    """Plot the origin point on ``ax`` and return what ``ax.plot`` returns.

    The module-level ``origin_properties`` provide the default plot
    keywords; ``kwargs`` override them.
    """
    style = origin_properties.copy()
    style.update(kwargs)

    # indexing with None keeps one extra axis on each coordinate
    x_coord = origin[0, None]
    y_coord = origin[1, None]
    return ax.plot(x_coord, y_coord, **style)
Exemple #46
0
 def _plot(self, time: Sequence[datetime.datetime], axis: Axes):
     """Plot the expectation with its lower and upper bounds on ``axis``.

     The bounds are the expectation minus ``dev_dn`` and plus ``dev_up``;
     the first ``self.initialization`` time stamps are left out.
     """
     window = time[self.initialization:]
     lower = [mean - spread for mean, spread in zip(self.expect, self.dev_dn)]
     upper = [mean + spread for mean, spread in zip(self.expect, self.dev_up)]
     curves = (
         (self.expect, "expectation"),
         (lower, "lower bound"),
         (upper, "upper bound"),
     )
     for series, name in curves:
         axis.plot(window, series, label=name)
Exemple #47
0
def plot_axis(ax: Axes, points: np.array, **kwargs):
    """Plot ``points`` on ``ax`` and return what ``ax.plot`` returns.

    The first column of ``points`` is plotted against the remaining
    columns.  The module-level ``axis_properties`` provide the default plot
    keywords; ``kwargs`` override them.
    """
    style = axis_properties.copy()
    style.update(kwargs)

    first_column = points[:, :1]
    other_columns = points[:, 1:]
    return ax.plot(first_column, other_columns, **style)
Exemple #48
0
    def plotlabels(self):
        """Remove the label artists drawn earlier and draw all labels anew.

        Does nothing while the "realwidth" or "realheight" option equals 1.
        Otherwise clears ``self.redrawlabels``, removes the stored leader
        lines and texts from the axes and, when the "use_labels" option is
        set, places every non-empty label whose anchor lies inside the
        current axes limits: a thin leader line plus an annotation, both
        remembered on the label dict under "mpline" and "mptext".
        """
        # NOTE(review): a value of 1 presumably means the widget has no real
        # size yet -- confirm against where realwidth/realheight are set
        if self.cget("realwidth") == 1 or self.cget("realheight") == 1:
            return
        self.redrawlabels = 0

        # drop the artists created by a previous call
        for labels in self.labels:
            for label in labels:
                if "mpline" in label:
                    self.ax.lines.remove(label["mpline"])
                    del label["mpline"]
                if "mptext" in label:
                    self.ax.texts.remove(label["mptext"])
                    del label["mptext"]

        if not self.cget("use_labels"):
            return

        # Two transform APIs are supported, selected by whether transData
        # has a `transform` method.  Either way we end up with:
        #   trans     -- transData applied to a single point
        #   transseq  -- transData applied to whole coordinate sequences
        #                (3D input is projected to 2D first)
        #   inv_trans -- the inverse of trans
        if hasattr(self.ax.transData, 'transform'):
            trans = self.ax.transData.transform
            if self.ax is self.ax2d:

                def transseq(x, y):
                    return trans(transpose([x, y]))
            else:
                # this matrix needs to be initialized or matplotlib only does
                # that when it draws
                self.ax.M = self.ax.get_proj()

                def transseq(x, y, z):
                    return trans(
                        transpose(proj_transform(x, y, z, self.ax.M)[:2]))

            inv_trans = self.ax.transData.inverted().transform
        else:
            # NOTE(review): xy_tup / numerix_x_y look like a much older
            # matplotlib transform interface -- confirm it is still needed
            trans = self.ax.transData.xy_tup
            if self.ax is self.ax2d:

                def transseq(x, y):
                    return transpose(self.ax.transData.numerix_x_y(x, y))
            else:
                self.ax.M = self.ax.get_proj()

                def transseq(x, y, z):
                    return transpose(
                        self.ax.transData.numerix_x_y(
                            *(proj_transform(x, y, z, self.ax.M)[:2])))

            inv_trans = self.ax.transData.inverse_xy_tup
        if self.cget("smart_label"):
            # Register every curve in the map returned by inarrs(); findsp()
            # receives that map below.  Transformed coordinates are shifted
            # by the left/bottom margin and divided by 5 before map_curve.
            mp = self.inarrs()
            sp1 = self.cget("left_margin")
            sp2 = 5  #fontsize
            sp3 = self.cget("bottom_margin")
            sp4 = 5
            for d in self.data:
                if self.ax is self.ax2d:
                    seq = transseq(d["x"], d["y"])
                else:
                    seq = transseq(d["x"], d["y"], d["z"])
                seq[:, 0] = (seq[:, 0] - sp1) / sp2
                seq[:, 1] = (seq[:, 1] - sp3) / sp4
                self.map_curve(mp, seq)

        for i, labels in enumerate(self.labels):
            for label in labels:
                if len(label["text"]) == 0:
                    continue
                [x, y] = label["xy"][:2]
                if self.ax is self.ax3d:
                    # transform 3d to 2d coordinates and compare
                    # against the 2D limits
                    z = label["xy"][2]
                    [x, y, z] = proj_transform(x, y, z, self.ax.M)
                # skip labels whose anchor is outside the visible limits
                lim = Axes.get_xlim(self.ax)
                if x < lim[0] or x > lim[1]:
                    continue
                lim = Axes.get_ylim(self.ax)
                if y < lim[0] or y > lim[1]:
                    continue
                data = trans((x, y))
                if data is not None:
                    # from here on x, y are the transformed coordinates
                    [x, y] = data
                    # offsets of the leader line's two ends (d1, d2) and of
                    # the text (t), plus a position code for the alignment
                    if self.cget("smart_label"):
                        [xoffd1, yoffd1, xoffd2, yoffd2, xofft, yofft,
                         pos] = self.findsp(x, y, mp)
                    else:
                        [xoffd1, yoffd1, xoffd2, yoffd2, xofft, yofft,
                         pos] = self.dumblabel(i, label["j"], x, y)
                    [ha, va] = self.getpos(pos)
                    # apply the offsets, then map back through inv_trans
                    [xd1, yd1] = inv_trans((x + xoffd1, y + yoffd1))
                    [xd2, yd2] = inv_trans((x + xoffd2, y + yoffd2))
                    [xt, yt] = inv_trans((x + xofft, y + yofft))
                    # using the low-level line routines to avoid
                    # rescaling in 3D (even with auto scale off,
                    # the view is reset after each plot() by
                    # set_top_view in mplot3d)
                    line = Line2D([xd1, xd2], [yd1, yd2],
                                  linewidth=0.5,
                                  color=self.cget("foreground"))
                    self.ax.add_line(line)
                    self.ax.annotate(label["text"], (xt, yt),
                                     ha=ha,
                                     va=va,
                                     color=self.cget("foreground"),
                                     clip_on=True)
                    # remember the two artists just added so the next call
                    # can remove them again
                    label["mpline"] = self.ax.lines[-1]
                    label["mptext"] = self.ax.texts[-1]
Exemple #49
0
 def _plot(self, time: Sequence[datetime.datetime], axis: Axes):
     """Plot the lower and upper bound on ``axis``.

     The first ``self.initialization`` time stamps are left out.
     """
     window = time[self.initialization:]
     for series, name in ((self.lower, "lower bound"),
                          (self.upper, "upper bound")):
         axis.plot(window, series, label=name)
Exemple #50
0
 def __init__(self, *args, **kwargs):
     """Initialise the axes, consuming this class's own keyword options.

     Recognised keywords (removed before ``Axes.__init__`` is called):
       filename     -- default ''
       dataset      -- default '/img'
       frame_number -- default 0
       scaled       -- default False
     Everything else is forwarded to ``Axes.__init__`` unchanged.
     """
     # Take our options out BEFORE delegating: the base class would
     # otherwise receive keywords it does not know.  Popping them after
     # the call, as before, made them unusable.
     filename = kwargs.pop('filename', '')
     dataset = kwargs.pop('dataset', '/img')
     frame_number = kwargs.pop('frame_number', 0)
     scaled = kwargs.pop('scaled', False)
     Axes.__init__(self, *args, **kwargs)
     # attributes are still assigned after base initialisation, as before
     self.filename = filename
     self.dataset = dataset
     self.frame_number = frame_number
     self.scaled = scaled
Exemple #51
0
def draw_rigraph(graph: NxGraph, plot: Axes, **kwargs):
    """
    Draw the Region intersection graph, showing which Regions overlap or
    intersect and how they form a network of intersecting Regions.

    Args:
      graph:    The Region intersection graph to draw.
      plot:     The matplotlib plot to draw on.
      kwargs:   Additional arguments and options.

    Keyword Args:
      forced:
        Boolean flag whether or not to force apart
        the unconnected clusters (nodes and edges)
        within the graph.
      colored:
        Boolean flag for colored output or greyscale.

        If True:
          Nodes are colored by the Region's stored 'color'
          data property. An edge uses its own Region's
          'color' when set, otherwise the color shared by
          its two node Regions; if those differ (or are
          unset) the edge is black.
        If False:
          All nodes and edges are colored black.
    """
    G = graph.G
    black = (0, 0, 0)

    forced, colored = (kwargs.get(flag, False)
                       for flag in ('forced', 'colored'))

    def force(G: nx.Graph):
        # all-pairs shortest path lengths as a nested dict; pairs with no
        # path get the largest distance found anywhere in the graph
        distances = DataFrame(index=G.nodes(), columns=G.nodes())
        for source, lengths in nx.shortest_path_length(G):
            for target, hops in lengths.items():
                distances.loc[source, target] = hops
        return distances.fillna(distances.max().max()).to_dict()

    def get_pos(G: nx.Graph):
        layout_options = {'dist': force(G)} if forced else {}
        return nx.kamada_kawai_layout(G, **layout_options)

    def get_edge_color(r: Region):
        first, second = r['intersect']
        own = r.getdata('color')
        first_color = first.getdata('color')
        second_color = second.getdata('color')

        if own:
            return own
        if first_color and first_color == second_color:
            return first_color
        return black

    if not colored:
        node_color = [black] * len(G)
        edge_color = [black] * nx.number_of_edges(G)
    else:
        node_color = [
            region.getdata('color', black) for _, region, _ in graph.regions
        ]
        edge_color = [
            get_edge_color(region) for _, _, region, _ in graph.overlaps
        ]

    plot.set_axis_off()
    nx.draw_networkx(G,
                     ax=plot,
                     pos=get_pos(G),
                     with_labels=False,
                     node_color=node_color,
                     node_size=1,
                     edge_color=edge_color,
                     width=1)
Exemple #52
0
 def set_yscale(self, *args, **kwargs):
     """Set the y scale, then wrap the y major locator in a RadialLocator."""
     Axes.set_yscale(self, *args, **kwargs)
     # wrap whatever major locator the y axis has after the scale change
     current_locator = self.yaxis.get_major_locator()
     self.yaxis.set_major_locator(self.RadialLocator(current_locator))
Exemple #53
0
 def __init__(self, *args, **kwargs):
     """Initialise the base axes, fix the aspect ratio and clear the axes."""
     Axes.__init__(self, *args, **kwargs)
     # aspect 0.5, adjusted through the box and anchored at 'C'
     aspect_options = {'adjustable': 'box', 'anchor': 'C'}
     self.set_aspect(0.5, **aspect_options)
     self.cla()
Exemple #54
0
 def set_rscale(self, *args, **kwargs):
     """Forward all arguments to ``Axes.set_yscale`` and return its result."""
     outcome = Axes.set_yscale(self, *args, **kwargs)
     return outcome
Exemple #55
0
    def _plot_var_groups_brackets(
        gene_groups_ax: Axes,
        group_positions: Iterable[Tuple[int, int]],
        group_labels: Sequence[str],
        left_adjustment: float = -0.3,
        right_adjustment: float = 0.3,
        rotation: Optional[float] = None,
        orientation: Literal['top', 'right'] = 'top',
    ):
        """\
        Draws brackets that represent groups of genes on the given axis.
        For best results, this axis is located on top of an image whose
        x axis contains gene names.

        The gene_groups_ax should share the x axis with the main ax.

        Eg: gene_groups_ax = fig.add_subplot(axs[0, 0], sharex=dot_ax)

        Parameters
        ----------
        gene_groups_ax
            In this axis the gene marks are drawn
        group_positions
            Each item in the list, should contain the start and end position that the
            bracket should cover.
            Eg. [(0, 4), (5, 8)] means that there are two brackets, one for the var_names (eg genes)
            in positions 0-4 and other for positions 5-8
        group_labels
            List of group labels. It is only read: labels that get shortened to fit
            are shortened on a local copy.
        left_adjustment
            adjustment to plot the bracket start slightly before or after the first gene position.
            If the value is negative the start is moved before.
        right_adjustment
            adjustment to plot the bracket end slightly before or after the last gene position
            If the value is negative the end is moved before.
        rotation
            rotation degrees for the labels. If not given, the labels are rotated 90 degrees
            when any of them is longer than 4 characters and left unrotated otherwise.
            Only used for the `top` orientation.
        orientation
            location of the brackets. Either `top` or `right`
        Returns
        -------
        None
        """
        import matplotlib.patches as patches
        from matplotlib.path import Path

        # The positions are walked twice below; materialize them so a
        # one-shot iterable (e.g. a generator) is not exhausted after the
        # first pass.
        group_positions = list(group_positions)

        # get the 'brackets' coordinates as lists of start and end positions
        left = [pos[0] + left_adjustment for pos in group_positions]
        right = [pos[1] + right_adjustment for pos in group_positions]

        # verts and codes are used by PathPatch to make the brackets;
        # every bracket is one open path of three line segments
        verts = []
        codes = []
        bracket_codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO]
        if orientation == 'top':
            # rotate labels if any of them is longer than 4 characters
            if rotation is None and len(group_labels) > 0:
                longest = max(len(label) for label in group_labels)
                rotation = 90 if longest > 4 else 0
            for idx, (left_coor, right_coor) in enumerate(zip(left, right)):
                verts.extend([
                    (left_coor, 0),  # lower-left
                    (left_coor, 0.6),  # upper-left
                    (right_coor, 0.6),  # upper-right
                    (right_coor, 0),  # lower-right
                ])
                codes.extend(bracket_codes)

                group_x_center = left_coor + float(right_coor - left_coor) / 2
                gene_groups_ax.text(
                    group_x_center,
                    1.1,
                    group_labels[idx],
                    ha='center',
                    va='bottom',
                    rotation=rotation,
                )
        else:
            # for the 'right' orientation the same coordinates run along y
            top = left
            bottom = right
            for idx, (top_coor, bottom_coor) in enumerate(zip(top, bottom)):
                verts.extend([
                    (0, top_coor),  # upper-left
                    (0.4, top_coor),  # upper-right
                    (0.4, bottom_coor),  # lower-right
                    (0, bottom_coor),  # lower-left
                ])
                codes.extend(bracket_codes)

                diff = bottom_coor - top_coor
                group_y_center = top_coor + float(diff) / 2
                label = group_labels[idx]
                if diff * 2 < len(label):
                    # cut label to fit available space; done on the local
                    # name so the caller's sequence is left untouched
                    label = label[:int(diff * 2)] + "."
                gene_groups_ax.text(
                    1.1,
                    group_y_center,
                    label,
                    ha='right',
                    va='center',
                    rotation=270,
                    fontsize='small',
                )

        path = Path(verts, codes)

        patch = patches.PathPatch(path, facecolor='none', lw=1.5)

        gene_groups_ax.add_patch(patch)
        gene_groups_ax.grid(False)
        gene_groups_ax.axis('off')
        # remove y ticks
        gene_groups_ax.tick_params(axis='y', left=False, labelleft=False)
        # remove x ticks and labels
        gene_groups_ax.tick_params(axis='x',
                                   bottom=False,
                                   labelbottom=False,
                                   labeltop=False)
Exemple #56
0
 def plot_PR_curves(self, ax: Axes):
     """Plot precision against recall for each threshold sweep.

     Sweeps are paired with ``self.colors`` in order; ``self.line_kwargs``
     is passed to every ``ax.plot`` call.
     """
     colored_sweeps = zip(self.threshold_sweeps, self.colors)
     for sweep, color in colored_sweeps:
         recall, precision = sweep.recall, sweep.precision
         ax.plot(recall, precision, c=color, **self.line_kwargs)
 def __init__(self, *args, myclass=None, **kwargs):
     """Accept ``myclass`` without using it; forward the rest to ``Axes``."""
     # `myclass` is captured by the signature so it never reaches the base
     outcome = Axes.__init__(self, *args, **kwargs)
     return outcome
Exemple #58
0
 def __init__(self, *args, **kwargs):
     """Discard any ``myclass`` keyword, then defer to ``Axes.__init__``."""
     if 'myclass' in kwargs:
         del kwargs['myclass']
     return Axes.__init__(self, *args, **kwargs)
Exemple #59
0
 def _set_ticklabels(self, ax: Axes, labels):
     """Put ``labels`` on the x ticks for vertical plots, else on the y ticks."""
     vertical = self.orientation == "vertical"
     apply_labels = ax.set_xticklabels if vertical else ax.set_yticklabels
     apply_labels(labels)
Exemple #60
0
    def _plot_size_legend(self, size_legend_ax: Axes):
        """Draw the dot-size legend: a row of gray dots with percent labels.

        The legend values run between ``self.dot_min`` and ``self.dot_max``
        in steps chosen from their difference.  Each value is converted to a
        marker size with the exponent and the smallest/largest dot sizes
        stored on the instance, drawn on ``size_legend_ax`` and labelled
        with the value multiplied by 100.
        """
        # for the dot size legend, use step between dot_max and dot_min
        # based on how different they are.
        diff = self.dot_max - self.dot_min
        if 0.3 < diff <= 0.6:
            step = 0.1
        elif diff <= 0.3:
            step = 0.05
        else:
            step = 0.2
        # a descending range that is afterwards inverted is used
        # to guarantee that dot_max is in the legend.
        size_range = np.arange(self.dot_max, self.dot_min, step * -1)[::-1]
        # express the values relative to the dot_min..dot_max span, unless
        # that span already is exactly 0..1
        if self.dot_min != 0 or self.dot_max != 1:
            dot_range = self.dot_max - self.dot_min
            size_values = (size_range - self.dot_min) / dot_range
        else:
            size_values = size_range

        # apply the size exponent, then scale into
        # [smallest_dot, largest_dot]
        size = size_values**self.size_exponent
        size = size * (self.largest_dot -
                       self.smallest_dot) + self.smallest_dot

        # plot size bar
        # (one dot per legend value, all at y = 0 and centred on x = i + 0.5)
        size_legend_ax.scatter(
            np.arange(len(size)) + 0.5,
            np.repeat(0, len(size)),
            s=size,
            color='gray',
            edgecolor='black',
            linewidth=self.dot_edge_lw,
            zorder=100,
        )
        size_legend_ax.set_xticks(np.arange(len(size)) + 0.5)
        # tick labels show the legend values as whole-number percentages
        labels = [
            "{}".format(np.round((x * 100), decimals=0).astype(int))
            for x in size_range
        ]
        size_legend_ax.set_xticklabels(labels, fontsize='small')

        # remove y ticks and labels
        size_legend_ax.tick_params(axis='y',
                                   left=False,
                                   labelleft=False,
                                   labelright=False)

        # remove surrounding lines
        size_legend_ax.spines['right'].set_visible(False)
        size_legend_ax.spines['top'].set_visible(False)
        size_legend_ax.spines['left'].set_visible(False)
        size_legend_ax.spines['bottom'].set_visible(False)
        size_legend_ax.grid(False)

        # ymax is read BEFORE the y limits are changed; the title is then
        # positioned relative to that earlier upper limit
        ymax = size_legend_ax.get_ylim()[1]
        size_legend_ax.set_ylim(-1.05 - self.largest_dot * 0.003, 4)
        size_legend_ax.set_title(self.size_title, y=ymax + 0.45, size='small')

        # widen the x limits a little on both sides
        xmin, xmax = size_legend_ax.get_xlim()
        size_legend_ax.set_xlim(xmin - 0.15, xmax + 0.5)