Esempio n. 1
0
    def get_spine_transform(self):
        """Return the transform used to draw this spine.

        ``self._spine_transform`` is a ``(what, how)`` pair:

        * ``'data'``: the coordinate perpendicular to the spine goes through
          ``transScale + how + transLimits + transAxes``; the other
          coordinate follows ``transData``.
        * ``'identity'``: the grid transform of the relevant axis.
        * ``'post'``: the grid transform followed by ``how``.
        * ``'pre'``: ``how`` followed by the grid transform.

        Raises ValueError for an unknown spine_type or transform kind.
        """
        self._ensure_position_is_set()
        what, how = self._spine_transform
        axes = self.axes

        if what == 'data':
            # Special case: the spine sits at a data location.
            data_xform = axes.transScale + \
                (how + axes.transLimits + axes.transAxes)
            if self.spine_type in ['left', 'right']:
                return mtransforms.blended_transform_factory(
                    data_xform, axes.transData)
            if self.spine_type in ['top', 'bottom']:
                return mtransforms.blended_transform_factory(
                    axes.transData, data_xform)
            raise ValueError('unknown spine spine_type: %s' % self.spine_type)

        if self.spine_type in ['left', 'right']:
            base_transform = axes.get_yaxis_transform(which='grid')
        elif self.spine_type in ['top', 'bottom']:
            base_transform = axes.get_xaxis_transform(which='grid')
        else:
            raise ValueError('unknown spine spine_type: %s' % self.spine_type)

        if what == 'identity':
            return base_transform
        if what == 'post':
            return base_transform + how
        if what == 'pre':
            return how + base_transform
        raise ValueError("unknown spine_transform type: %s" % what)
Esempio n. 2
0
def zoom_effect01(ax1, ax2, xmin, xmax, **kwargs):
    """Connect *ax1* and *ax2*, marking the x-range [*xmin*, *xmax*] in both.

    The same data x-range is boxed in each axes (spanning the full axes
    height, since y is in axes coordinates) and the two boxes are joined
    via ``connect_bbox``.

    Parameters
    ----------
    ax1 : Axes
        The main axes.
    ax2 : Axes
        The zoomed axes.
    xmin, xmax : float
        Limits, in data coordinates, of the colored area in both axes.
    **kwargs
        Passed to ``connect_bbox`` as ``prop_lines``.  A copy with
        ``ec="none"`` and ``alpha=0.2`` forced is passed as
        ``prop_patches``.

    Returns
    -------
    tuple
        ``(c1, c2, bbox_patch1, bbox_patch2, p)`` from ``connect_bbox``.
        ``bbox_patch1`` is added to *ax1*; all other artists are added
        to *ax2*.
    """
    # x in data coordinates, y in axes coordinates for each axes.
    trans1 = blended_transform_factory(ax1.transData, ax1.transAxes)
    trans2 = blended_transform_factory(ax2.transData, ax2.transAxes)

    bbox = Bbox.from_extents(xmin, 0, xmax, 1)
    mybbox1 = TransformedBbox(bbox, trans1)
    mybbox2 = TransformedBbox(bbox, trans2)

    prop_patches = kwargs.copy()
    prop_patches["ec"] = "none"
    prop_patches["alpha"] = 0.2

    c1, c2, bbox_patch1, bbox_patch2, p = connect_bbox(
        mybbox1, mybbox2,
        loc1a=3, loc2a=2, loc1b=4, loc2b=1,
        prop_lines=kwargs, prop_patches=prop_patches)

    ax1.add_patch(bbox_patch1)
    ax2.add_patch(bbox_patch2)
    ax2.add_patch(c1)
    ax2.add_patch(c2)
    ax2.add_patch(p)
    return c1, c2, bbox_patch1, bbox_patch2, p
Esempio n. 3
0
    def make_range_frame(self):
        """Build the component lines of the range frame.

        Returns ``[xline, yline]``: two LineCollections drawn with
        ``self.linewidth``, ``self.color`` and zorder 10.  ``xline`` spans
        the x data limits at axes-y 0; ``yline`` spans the y data limits
        at axes-x 0.
        """
        axes = self.axes

        # x in data coordinates, y in axes coordinates.
        x_blend = transforms.blended_transform_factory(axes.transData, axes.transAxes)
        x_span = interval_as_array(axes.dataLim.intervalx)

        # x in axes coordinates, y in data coordinates.
        y_blend = transforms.blended_transform_factory(axes.transAxes, axes.transData)
        y_span = interval_as_array(axes.dataLim.intervaly)

        def frame_line(start, end, transform):
            # One single-segment collection in the frame's common style.
            return LineCollection(
                segments=[[start, end]],
                linewidths=[self.linewidth],
                colors=[self.color],
                transform=transform,
                zorder=10,
            )

        xline = frame_line((x_span[0], 0), (x_span[1], 0), x_blend)
        yline = frame_line((0, y_span[0]), (0, y_span[1]), y_blend)
        return [xline, yline]
Esempio n. 4
0
    def get_spine_transform(self):
        """Return the transform used to draw this spine.

        ``self._spine_transform`` holds ``(what, how)``.  For ``"data"`` the
        coordinate perpendicular to the spine is mapped through
        ``transScale + how + transLimits + transAxes`` and the other one
        through ``transData``.  Otherwise the axis grid transform is returned
        as-is (``"identity"``), followed by ``how`` (``"post"``) or preceded
        by ``how`` (``"pre"``).

        Raises ValueError for an unknown spine_type or transform kind.
        """
        self._ensure_position_is_set()
        what, how = self._spine_transform
        spine_type = self.spine_type
        is_vertical = spine_type in ["left", "right"]
        is_horizontal = spine_type in ["top", "bottom"]

        if what == "data":
            # Spine placed at a data coordinate.
            data_xform = self.axes.transScale + (how + self.axes.transLimits + self.axes.transAxes)
            if not (is_vertical or is_horizontal):
                raise ValueError("unknown spine spine_type: %s" % spine_type)
            if is_vertical:
                x_part, y_part = data_xform, self.axes.transData
            else:
                x_part, y_part = self.axes.transData, data_xform
            return mtransforms.blended_transform_factory(x_part, y_part)

        if not (is_vertical or is_horizontal):
            raise ValueError("unknown spine spine_type: %s" % spine_type)
        if is_vertical:
            base_transform = self.axes.get_yaxis_transform(which="grid")
        else:
            base_transform = self.axes.get_xaxis_transform(which="grid")

        if what == "identity":
            return base_transform
        if what == "post":
            return base_transform + how
        if what == "pre":
            return how + base_transform
        raise ValueError("unknown spine_transform type: %s" % what)
Esempio n. 5
0
 def set_position(self, position):
     """Set the position of the spine.

     Spine position is specified by a 2 tuple of (position type,
     amount). The position types are:

     * 'outward' : place the spine out from the data area by the
       specified number of points. (Negative values specify placing the
       spine inward.)
     * 'axes' : place the spine at the specified Axes coordinate (from
       0.0-1.0).
     * 'data' : place the spine at the specified data coordinate.

     Additionally, shorthand notations define special positions:

     * 'center' -> ('axes', 0.5)
     * 'zero' -> ('data', 0.0)

     The spine transform is rebuilt from the new position and, if an
     axis is attached, that axis is reset with ``cla()``.

     Raises
     ------
     ValueError
         If *position* is not 'center', 'zero' or a 2-tuple whose first
         item is 'outward', 'axes' or 'data', or if the spine type is not
         one of 'left', 'right', 'top', 'bottom'.
     """
     if position not in ('center', 'zero'):
         # Raise instead of assert so validation survives ``python -O``.
         if len(position) != 2:
             raise ValueError("position should be 'center', 'zero' or 2-tuple")
         if position[0] not in ['outward', 'axes', 'data']:
             raise ValueError("position[0] should be one of 'outward', "
                              "'axes' or 'data'")
     self._position = position
     self._calc_offset_transform()
     t = self.get_spine_transform()
     if self.spine_type in ['left', 'right']:
         t2 = mtransforms.blended_transform_factory(t, self.axes.transAxes)
     elif self.spine_type in ['bottom', 'top']:
         t2 = mtransforms.blended_transform_factory(self.axes.transAxes, t)
     else:
         # Without this, t2 would be unbound below.
         raise ValueError('unknown spine spine_type: %s' % self.spine_type)
     self.set_transform(t2)
     if self.axis is not None:
         self.axis.cla()
Esempio n. 6
0
def zoom_effect(ax1, ax2, xlim, **kwargs):
    """Box the data x-interval *xlim* in both axes and connect the boxes.

    Parameters
    ----------
    ax1, ax2 : Axes
        The two axes to link.  The box patch for *ax1* is added to *ax1*;
        every other artist is added to *ax2*.
    xlim : sequence
        ``(xmin, xmax)`` in data coordinates; the box spans axes-y 0 to 1.
    **kwargs
        Passed to ``connect_bboxes`` as ``prop_lines``; a copy with
        ``ec='none'`` and ``alpha=0.1`` is passed as ``prop_patches``.

    Returns
    -------
    tuple
        ``(c1, c2, bbox_patch1, bbox_patch2, p)`` from ``connect_bboxes``.
    """
    box = Bbox.from_extents(xlim[0], 0, xlim[1], 1)
    # One TransformedBbox per axes: x in data coords, y in axes coords.
    first_box, second_box = [
        TransformedBbox(box, blended_transform_factory(ax.transData, ax.transAxes))
        for ax in (ax1, ax2)
    ]

    patch_props = dict(kwargs, ec='none', alpha=0.1)

    c1, c2, patch1, patch2, connector = connect_bboxes(
        first_box, second_box,
        loc1a=3, loc2a=2, loc1b=4, loc2b=1,
        prop_lines=kwargs, prop_patches=patch_props,
    )

    ax1.add_patch(patch1)
    for artist in (patch2, c1, c2, connector):
        ax2.add_patch(artist)

    return c1, c2, patch1, patch2, connector
Esempio n. 7
0
def _fancy_barh(ax, values, data, val_fmt='', is_legend=False):
    """Draw a fancy-ish horizontal bar plot on *ax*.

    One bar per entry of *data*, top to bottom, colored with
    ``data[i]['plot_color']`` and annotated near the left edge (x in axes
    coordinates) with ``data[i]['name']`` and optionally its value.

    :param ax: matplotlib axes to draw on.
    :param values: bar lengths; must be the same length as *data*.
        Use ``np.nan`` if there is no value for a sample.
    :param data: sequence of dicts with keys ``'name'`` and ``'plot_color'``
        (and ``'plot_marker'`` when *is_legend* is true).
    :param val_fmt: %-format string for the value; ``None`` doesn't show
        the value, ``''`` just uses ``str()``.
    :param is_legend: if True, also draw legend-like line segments to the
        left of the axes.
    :returns: True
    :raises ValueError: if *values* and *data* differ in length.
    """
    if len(values) != len(data):
        raise ValueError('values and data must have the same length')
    names = [d['name'] for d in data]
    width = 0.90
    bar_pos = np.arange(len(names))
    # The ylim below assumes bars start at bar_pos (edge alignment);
    # matplotlib >= 2.0 centers bars by default, so request 'edge'.
    rects = ax.barh(bar_pos, values, width, align='edge')
    ax.set_yticks([])
    ax.set_ylim([min(bar_pos) - (1 - width) / 2.,
                 max(bar_pos) + width + (1 - width) / 2.])
    ax.invert_yaxis()
    xmax = ax.get_xlim()[1]

    # x in axes coordinates, y in data coordinates; shared by all artists.
    blend = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    linex = [-0.12, -0.04]
    box_h = 0.5
    box_w_pad = 0.025

    # set bar colors and annotate with sample name : value
    for name, value, d, rect in zip(names, values, data, rects):
        rect.set_facecolor(d['plot_color'])
        y = rect.get_y() + rect.get_height() / 2.
        if val_fmt is None:
            label = str(name)
        elif np.isnan(value):
            label = '%s : No Info' % (name,)
        elif val_fmt == '':
            label = '%s : %s' % (name, str(value))
        else:
            label = ('%s : ' + val_fmt) % (name, value)
        ax.text(0.01, y, label, va='center', ha='left', transform=blend,
                bbox=dict(boxstyle="round,pad=0.2", alpha=0.65, fc='w', lw=0))
        if is_legend:
            # White background box behind the legend-like segment,
            # drawn outside the axes (clip_on=False).
            ax.add_patch(MPL.patches.FancyBboxPatch(
                (linex[0] - box_w_pad, y - box_h / 2.),
                linex[1] - linex[0] + 2 * box_w_pad, box_h,
                ec='w', fc='w', boxstyle="square,pad=0",
                transform=blend, clip_on=False))
            ax.plot(linex, [y] * 2, '-', color=d['plot_color'],
                    marker=d['plot_marker'], transform=blend, clip_on=False)
    ax.set_xlim((0, xmax))
    # HACK(@TCC): fake ylabel so tight_layout adds spacing.
    ax.set_ylabel(' \n \n ')
    return True
Esempio n. 8
0
    def _set_lim_and_transforms(self):
        """Derive this axes' transforms from the parent axes.

        Axes coordinates are shared with the parent.  Data coordinates are
        ``self.transAux`` followed by the parent's data transform.  The x/y
        axis transforms blend data coordinates along one direction with
        axes coordinates along the other.
        """
        parent = self._parent_axes
        self.transAxes = parent.transAxes
        self.transData = self.transAux + parent.transData

        blend = mtransforms.blended_transform_factory
        self._xaxis_transform = blend(self.transData, self.transAxes)
        self._yaxis_transform = blend(self.transAxes, self.transData)
Esempio n. 9
0
    def __init__(self, transform, fig_transform,
                 sizex=0, sizey=0, labelx=None, labely=None, loc=4,
                 xbar_width=2, ybar_width=2,
                 pad=3, borderpad=0.1, xsep=3, ysep=3, prop=None, textprops=None, **kwargs):
        """
        Draw a horizontal and/or vertical scale bar.

        Bar lengths are in the coordinates of *transform*; bar thicknesses
        are in the units of *fig_transform*, after dividing the given widths
        by 72 (presumably points -> inches, i.e. *fig_transform* is expected
        to be something like ``fig.dpi_scale_trans`` -- TODO confirm).

        - transform : coordinate frame for bar lengths (typically axes.transData)
        - fig_transform : coordinate frame for bar thicknesses (see above)
        - sizex, sizey : length of the x and y bars, in data units; 0 to omit
        - labelx, labely : labels for the x and y bars; None to omit
        - loc : position in containing axes
        - xbar_width, ybar_width : bar thicknesses (divided by 72 before use)
        - pad, borderpad : padding, in fraction of the legend font size (or prop)
        - xsep : separation between the y label and the bars
        - ysep : separation between the stacked y and x bars
        - prop : font properties passed to AnchoredOffsetbox
        - textprops : text properties for the labels (default ``{'size': 10}``)
        - **kwargs : additional arguments passed to base class constructor

        Raises ValueError if both sizex and sizey are zero (there would be
        nothing to draw).
        """
        if not (sizex or sizey):
            raise ValueError("at least one of sizex and sizey must be non-zero")
        if textprops is None:
            # Fresh dict per call instead of a shared mutable default.
            textprops = {'size': 10}

        from matplotlib.patches import Rectangle
        from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea
        import matplotlib.transforms as transforms

        # Blend the two frames so each bar's length follows *transform*
        # while its thickness follows *fig_transform*.
        xtransform = transforms.blended_transform_factory(transform, fig_transform)
        ytransform = transforms.blended_transform_factory(fig_transform, transform)

        ybar_width /= 72.
        xbar_width /= 72.

        if sizey:
            ybar = AuxTransformBox(ytransform)
            ybar.add_artist(Rectangle((0, 0), ybar_width, sizey, fc="Black"))
            bars = ybar
        if sizex:
            xbar = AuxTransformBox(xtransform)
            xbar.add_artist(Rectangle((0, 0), sizex, xbar_width, fc="Black"))
            bars = xbar
        if sizex and sizey:
            bars = VPacker(children=[ybar, xbar], pad=10, sep=ysep)
        if sizex and labelx:
            bars = VPacker(children=[bars, TextArea(labelx,
                                                    minimumdescent=False,
                                                    textprops=textprops)],
                           align="center", pad=0, sep=-3)
        if sizey and labely:
            bars = HPacker(children=[TextArea(labely, textprops=textprops), bars],
                           align="center", pad=0, sep=xsep)

        AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
                                   child=bars, prop=prop, frameon=False, **kwargs)
Esempio n. 10
0
File: plot.py Progetto: nhmc/H2
def dhist(xvals, yvals, xbins=20, ybins=20, ax=None, c='b', fmt='.', ms=1,
          label=None, loc='right,bottom', xhistmax=None, yhistmax=None,
          histlw=1, xtop=0.2, ytop=0.2, chist=None, **kwargs):
    """Plot *yvals* against *xvals* with both marginal histograms on one axes.

    xvals, yvals are the two properties to plot.  xbins, ybins give the
    number of bins or the bin edges (passed to ``np.histogram``).

    Other parameters
    ----------------
    ax : axes to draw on; a new figure is created if None.
    c : colour of the points.  chist : colour of the histograms
        (defaults to *c*).
    fmt, ms, label, **kwargs : passed to ``ax.plot`` for the points.
        If *ms* is None, ``default_marker_size(fmt)`` is used.
    loc : comma-separated string.  If it contains 'top', the x histogram
        hangs from the top edge (otherwise it rises from the bottom); if it
        contains 'right', the y histogram extends from the right edge
        (otherwise from the left).
    xhistmax, yhistmax : bin count mapped to the full histogram height;
        defaults to the largest count.
    xtop, ytop : full histogram height, as a fraction of the axes.
    histlw : line width of the histograms.

    The axes limits are set to the outer bin edges, and ``pl.show()`` is
    called when pylab is interactive.

    Returns
    -------
    ax, dict(x=..., y=..., xbinedges=..., ybinedges=...)
        The axes, and the histogram counts and bin edges.
    """
    import re

    if chist is None:
        chist = c
    if ax is None:
        pl.figure()
        ax = pl.gca()

    loc = [part.strip().lower() for part in loc.split(',')]

    if ms is None:
        ms = default_marker_size(fmt)

    ax.plot(xvals, yvals, fmt, color=c, ms=ms, label=label, **kwargs)

    # Compare the numpy version numerically: as strings, '1.26' < '1.5',
    # which would pass the ``new`` keyword that modern numpy rejects.
    np_version = tuple(int(v) for v in
                       re.match(r'(\d+)\.(\d+)', np.__version__).groups())
    if np_version < (1, 5):
        x, xbins = np.histogram(xvals, bins=xbins, new=True)
        y, ybins = np.histogram(yvals, bins=ybins, new=True)
    else:
        x, xbins = np.histogram(xvals, bins=xbins)
        y, ybins = np.histogram(yvals, bins=ybins)

    # x histogram: x in data coordinates, height in axes coordinates.
    b = np.repeat(xbins, 2)
    X = np.concatenate([[0], np.repeat(x, 2), [0]])
    # Fall back to 1 when every bin is empty, avoiding a 0/0 division.
    Xmax = xhistmax or X.max() or 1
    X = xtop * X / Xmax
    if 'top' in loc:
        X = 1 - X
    trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
    ax.plot(b, X, color=chist, transform=trans, lw=histlw)

    # y histogram: width in axes coordinates, y in data coordinates.
    b = np.repeat(ybins, 2)
    Y = np.concatenate([[0], np.repeat(y, 2), [0]])
    Ymax = yhistmax or Y.max() or 1
    Y = ytop * Y / Ymax
    if 'right' in loc:
        Y = 1 - Y
    trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)
    ax.plot(Y, b, color=chist, transform=trans, lw=histlw)

    ax.set_xlim(xbins[0], xbins[-1])
    ax.set_ylim(ybins[0], ybins[-1])
    if pl.isinteractive():
        pl.show()

    return ax, dict(x=x, y=y, xbinedges=xbins, ybinedges=ybins)
Esempio n. 11
0
def puttext(x, y, text, ax, xcoord='ax', ycoord='ax', **kwargs):
    """Write *text* on *ax* at (x, y), choosing coordinates per axis.

    Parameters
    ----------
    x, y : float
        Position of the text.
    text : object
        Converted with ``str()`` before drawing.
    ax : Axes
        The axes to write on.
    xcoord, ycoord : {'ax', 'data'}
        Whether x (resp. y) is in axes coordinates (0-1) or in data
        coordinates.
    **kwargs
        Passed on to ``ax.text``.

    Returns
    -------
    The object returned by ``ax.text``.

    Raises
    ------
    ValueError
        If *xcoord* or *ycoord* is not 'ax' or 'data'.
    """
    if xcoord == 'ax' and ycoord == 'ax':
        trans = ax.transAxes
    elif xcoord == 'data' and ycoord == 'data':
        trans = ax.transData
    elif xcoord == 'data' and ycoord == 'ax':
        trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
    elif xcoord == 'ax' and ycoord == 'data':
        trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)
    else:
        raise ValueError("Bad keyword combination: %s, %s " % (xcoord, ycoord))
    return ax.text(x, y, str(text), transform=trans, **kwargs)
Esempio n. 12
0
    def plotinit(self):
        """ Set up the figure and do initial plots.

        Creates two axes on ``self.fig``.  ``a0`` holds the spectrum, the
        error and (if set) the continuum, plus an optional scaled template.
        ``a1`` sits above it, shares x, and holds the residual markers with
        reference lines, plus a histogram line and a Gaussian curve drawn
        with x in axes coordinates and y in data coordinates.

        Updates the following attributes:

          self.ax
          self.artists  ('fl', 'contpoints', 'cont', 'resid', 'hist_left')
        """
        wa, fl, er = self.wa, self.fl, self.er
        if self.continuum is not None:
            co = self.continuum

        # axis for spectrum & continuum
        a0 = self.fig.add_axes((0.05, 0.1, 0.9, 0.6))
        self.ax = a0
        a0.set_autoscale_on(0)
        # axis for residuals
        a1 = self.fig.add_axes((0.05, 0.75, 0.9, 0.2), sharex=a0)
        a1.set_autoscale_on(0)
        # reference lines: solid at 0 and +/-1, dashed at +/-2
        a1.axhline(0, color='k', alpha=0.7, zorder=99)
        a1.axhline(1, color='k', alpha=0.7, zorder=99)
        a1.axhline(-1, color='k', alpha=0.7, zorder=99)
        a1.axhline(2, color='k', linestyle='dashed', zorder=99)
        a1.axhline(-2, color='k', linestyle='dashed', zorder=99)
        m0, = a1.plot([0], [0], '.r', marker='.', mec='none', lw=0, mew=0, ms=6, alpha=0.5)
        a1.set_ylim(-4, 4)
        a0.axhline(0, color='0.7')
        if self.continuum is not None:
            a0.plot(wa, co, color='0.7', lw=1, ls='dashed')
        # Step rendering is a drawstyle; passing 'steps-mid' as a linestyle
        # is deprecated/removed in newer matplotlib.
        self.artists['fl'], = a0.plot(wa, fl, 'b', lw=0.5,
                                      drawstyle='steps-mid')
        a0.plot(wa, er, lw=0.5, color='orange')
        m1, = a0.plot([0], [0], 'r', alpha=0.7)
        m2, = a0.plot([0], [0], 'o', mfc='None', mew=1, ms=8, mec='r', picker=5,
                      alpha=0.7)
        a0.set_xlim(min(wa), max(wa))
        good = (er > 0) & ~np.isnan(fl) & ~np.isinf(fl)
        ymax = 2 * np.abs(np.percentile(fl[good], 95))
        a0.set_ylim(-0.1 * ymax, ymax)
        a0.text(0.9, 0.9, 'z=%.2f' % self.redshift, transform=a0.transAxes)

        # for histogram: x in axes coordinates, y in data coordinates
        trans = mtran.blended_transform_factory(a1.transAxes, a1.transData)
        hist, = a1.plot([], [], color='k', transform=trans)
        x = np.linspace(-3, 3)
        a1.plot(Gaussian(x, 0, 1, 0.05), x, color='k', transform=trans, lw=0.5)

        if self.template is not None:
            # template normalised to its maximum, y in axes coordinates
            trans = mtran.blended_transform_factory(a0.transData, a0.transAxes)
            a0.plot(self.wa, self.template / self.template.max(), '-c', lw=2,
                    alpha=0.5, transform=trans)

        self.fig.canvas.draw()
        self.artists.update(contpoints=m2, cont=m1, resid=m0, hist_left=hist)
Esempio n. 13
0
def zoom_effect(ax_zoomed, ax_origin, xlims = None, orientation='below', **kwargs):
    """Connect a zoomed axes to the axes it magnifies.

    Parameters
    ----------
    ax_zoomed : Axes
        The zoomed axes.
    ax_origin : Axes
        The main axes.
    xlims : (xmin, xmax), optional
        Data x-range of the colored area in both axes (full axes height).
        If None, the whole of *ax_zoomed* is connected to the region of
        *ax_origin* spanning ``ax_zoomed.viewLim`` in x.
    orientation : {'below', 'above'}
        Selects which bbox corners ``connect_bbox`` joins.
    **kwargs
        Passed to ``connect_bbox`` as ``prop_lines``; a copy with
        ``ec="none"`` and ``alpha=0.2`` is passed as ``prop_patches``.

    Returns
    -------
    tuple
        ``(c1, c2, bbox_zoomed_patch, bbox_origin_patch, p)``.  The zoomed
        patch is added to *ax_zoomed*, everything else to *ax_origin*.

    Raises
    ------
    ValueError
        If *orientation* is not 'below' or 'above'.
    """
    # Validate before building any transforms.
    if orientation == 'below':
        locs = dict(loc1a=2, loc2a=3, loc1b=1, loc2b=4)
    elif orientation == 'above':
        locs = dict(loc1a=3, loc2a=2, loc1b=4, loc2b=1)
    else:
        raise ValueError("orientation '%s' not recognized" % orientation)

    if xlims is None:
        # x: zoomed view limits in origin data coords; y: zoomed view mapped
        # through its limits into origin axes coords.
        tt = ax_zoomed.transScale + (ax_zoomed.transLimits + ax_origin.transAxes)
        transform = blended_transform_factory(ax_origin.transData, tt)

        bbox_zoomed = ax_zoomed.bbox
        bbox_origin = TransformedBbox(ax_zoomed.viewLim, transform)
    else:
        transform_zoomed = blended_transform_factory(ax_zoomed.transData, ax_zoomed.transAxes)
        transform_origin = blended_transform_factory(ax_origin.transData, ax_origin.transAxes)

        span = Bbox.from_extents(xlims[0], 0, xlims[1], 1)
        bbox_zoomed = TransformedBbox(span, transform_zoomed)
        bbox_origin = TransformedBbox(span, transform_origin)

    prop_patches = kwargs.copy()
    prop_patches["ec"] = "none"
    prop_patches["alpha"] = 0.2

    c1, c2, bbox_zoomed_patch, bbox_origin_patch, p = connect_bbox(
        bbox_zoomed, bbox_origin,
        prop_lines=kwargs, prop_patches=prop_patches, **locs)

    ax_zoomed.add_patch(bbox_zoomed_patch)
    ax_origin.add_patch(bbox_origin_patch)
    ax_origin.add_patch(c1)
    ax_origin.add_patch(c2)
    ax_origin.add_patch(p)

    return c1, c2, bbox_zoomed_patch, bbox_origin_patch, p
Esempio n. 14
0
 def setLabels(self):
     """ Set plot attributes.

     Titles ``self.ppm.axpp`` "Seismograms".  In "pkl" file mode, writes
     ``self.opts.pklfile`` right-aligned just above the top-right corner
     of ``self.axstk``.  Writes a monospace "qual= CCC/SNR/COH" header
     just right of ``axpp``, vertically centred at data y = 0.
     """
     axpp = self.ppm.axpp
     axpp.set_title("Seismograms")
     if self.opts.filemode == "pkl":
         axstk = self.axstk
         # Both coordinates are axes fractions, so transAxes alone suffices.
         axstk.text(1, 1.01, self.opts.pklfile, transform=axstk.transAxes,
                    va="bottom", ha="right", color="k")
     # x in axes fractions, y in data coordinates.
     trans = transforms.blended_transform_factory(axpp.transAxes, axpp.transData)
     font = FontProperties()
     font.set_family("monospace")
     axpp.text(1.025, 0, " " * 8 + "qual= CCC/SNR/COH", transform=trans, va="center", color="k", fontproperties=font)
Esempio n. 15
0
	def setLabels(self):
		""" Set plot attributes.

		Titles ``self.ppm.axpp`` 'Seismograms'.  In 'pkl' file mode, writes
		``self.opts.pklfile`` right-aligned just above the top-right corner
		of ``self.axstk``.  Writes a monospace 'qual= CCC/SNR/COH' header
		just right of ``axpp``, vertically centred at data y = 0.
		"""
		axpp = self.ppm.axpp
		axpp.set_title('Seismograms')
		if self.opts.filemode == 'pkl':
			axstk = self.axstk
			# Both coordinates are axes fractions, so transAxes alone suffices.
			axstk.text(1, 1.01, self.opts.pklfile, transform=axstk.transAxes,
				va='bottom', ha='right', color='k')
		# x in axes fractions, y in data coordinates.
		trans = transforms.blended_transform_factory(axpp.transAxes, axpp.transData)
		font = FontProperties()
		font.set_family('monospace')
		axpp.text(1.025, 0, ' '*8+'qual= CCC/SNR/COH', transform=trans, va='center',
			color='k', fontproperties=font)
Esempio n. 16
0
    def _draw_labels(self):
        """
        Draw the x and y axis titles onto the figure.

        The x title is set on ``self.facet.last_ax`` and the y title on
        ``self.facet.first_ax``.  Their transforms are then blended so
        that the position along the label's own direction is in figure
        coordinates while the other coordinate is in display units.
        Label padding comes from the theme's ``axis_title_x`` (top) and
        ``axis_title_y`` (right) margins in points, defaulting to 5.
        """
        # This is very laboured. Should be changed when MPL
        # finally has a constraint based layout manager.
        figure = self.figure
        get_property = self.theme.themeables.property

        def margin_pad(themeable, side):
            # Fall back to 5 when the theme defines no margin for it.
            try:
                margin = get_property(themeable, 'margin')
            except KeyError:
                return 5
            return margin.get_as(side, 'pt')

        pad_x = margin_pad('axis_title_x', 't')
        pad_y = margin_pad('axis_title_y', 'r')

        # Default or user-specified labels, which the coordinate system
        # may then modify (e.g. flip).
        layout_labels = {
            'x': self.layout.xlabel(self.labels),
            'y': self.layout.ylabel(self.labels),
        }
        labels = self.coordinates.labels(layout_labels)

        # The first axes object is on the left and the last one at the
        # bottom.  Putting the relevant coordinate in figure units lets MPL
        # keep the labels clear of the tick text, both when facetting with
        # scales='fixed' and when not facetting.
        xlabel = self.facet.last_ax.set_xlabel(labels['x'], labelpad=pad_x)
        ylabel = self.facet.first_ax.set_ylabel(labels['y'], labelpad=pad_y)

        identity = mtransforms.IdentityTransform
        blend = mtransforms.blended_transform_factory
        xlabel.set_transform(blend(figure.transFigure, identity()))
        ylabel.set_transform(blend(identity(), figure.transFigure))

        figure._themeable['axis_title_x'] = xlabel
        figure._themeable['axis_title_y'] = ylabel
Esempio n. 17
0
    def _set_lim_and_transforms(self):
        """Build the transform stack of the polar axes.

        Data flows through transScale, the polar projection
        (transProjection), an affine step (transProjectionAffine) and
        finally transAxes.  The tick and tick-label transforms of the
        theta (x) and r (y) axes are derived from these.
        """
        self.transAxes = BboxTransformTo(self.bbox)

        # Per-axis scaling; assumed to possibly have non-linear parts.
        self.transScale = TransformWrapper(IdentityTransform())

        # Polar projection of the (already scaled) data.  The first
        # variant is aware of rmin; the "pure" one is not.
        self.transProjection = self.PolarTransform(self)
        self.transPureProjection = self.PolarTransform(self, use_rmin=False)

        # Affine step on the projected data, generally limiting the range.
        self.transProjectionAffine = self.PolarAffine(self.transScale, self.viewLim)

        # Complete data -> display stack.
        self.transData = (
            self.transScale
            + self.transProjection
            + (self.transProjectionAffine + self.transAxes)
        )

        # Theta-axis ticks: like transData, but r == 1.0 always lands on
        # the edge of the axis circle.
        self._xaxis_transform = (
            self.transPureProjection
            + self.PolarAffine(IdentityTransform(), Bbox.unit())
            + self.transAxes
        )
        # Theta labels moved from radius 0.0 to 1.1 (label1) and by 1/1.1 (label2).
        self._theta_label1_position = Affine2D().translate(0.0, 1.1)
        self._xaxis_text1_transform = self._theta_label1_position + self._xaxis_transform
        self._theta_label2_position = Affine2D().translate(0.0, 1.0 / 1.1)
        self._xaxis_text2_transform = self._theta_label2_position + self._xaxis_transform

        # r-axis ticks: scale theta so gridlines spanning 0..1 span 0..2*pi.
        self._yaxis_transform = Affine2D().scale(np.pi * 2.0, 1.0) + self.transData

        def r_label_position(rpad):
            # Offset of 22.5 in theta and *rpad* in r, the r part scaled
            # through BboxTransformToMaxOnly(self.viewLim).
            return ScaledTranslation(
                22.5, rpad,
                blended_transform_factory(Affine2D(), BboxTransformToMaxOnly(self.viewLim)),
            )

        def r_text_transform(position):
            # 1/360 here plus the 2*pi scaling of _yaxis_transform turns the
            # 22.5 theta offset into an angle of 22.5 degrees.
            return position + Affine2D().scale(1.0 / 360.0, 1.0) + self._yaxis_transform

        # r-axis labels are put at an angle and padded in the r-direction.
        self._r_label1_position = r_label_position(self._rpad)
        self._yaxis_text1_transform = r_text_transform(self._r_label1_position)
        self._r_label2_position = r_label_position(-self._rpad)
        self._yaxis_text2_transform = r_text_transform(self._r_label2_position)
Esempio n. 18
0
def _buildTransform(current_axes):
    """Build the skew-T log-p transforms for *current_axes*.

    Hides both axes' tick axes and sets the view limits of *current_axes*
    to ``(_T_min, _T_max)`` / ``(_p_min, _p_max)``.  It then stores three
    module-level globals:

    * ``_stlp_data_transform``: (T, p) data -> display.  It applies log-p
      in y, a flip, a rescale, then a 45-degree skew done in figure
      coordinates about the 1000 mb level.
    * ``_stlp_xlabel_transform``: x from the skew-T transform, y in axes
      coordinates.
    * ``_stlp_ylabel_transform``: x in axes coordinates, y from the
      skew-T transform (e.g. for pressure labels and wind barbs).
    """
    global _stlp_data_transform, _stlp_xlabel_transform, _stlp_ylabel_transform

    current_figure = current_axes.figure

    current_axes.axes.get_xaxis().set_visible(False)
    current_axes.axes.get_yaxis().set_visible(False)

    data_figure_trans = current_axes.transData + current_figure.transFigure.inverted()

    # Set limits on the axes being configured.  pylab.xlim/ylim would act
    # on pyplot's current axes, which need not be current_axes.
    current_axes.set_xlim((_T_min, _T_max))
    current_axes.set_ylim((_p_min, _p_max))

    # Affine matrix for the skew transform.  This only works in data
    # coordinates; it is moved into figure coordinates below.
    skew_matrix = np.eye(3)
    skew_matrix[0, 1] = np.tan(45 * np.pi / 180)
    skew_transform = transforms.Affine2D(skew_matrix)

    # Logarithmic transform in y only.  LogScale's default base is 10; not
    # passing ``basey`` (deprecated in favour of ``base``) keeps this
    # working across matplotlib versions.
    log_p_transform = transforms.blended_transform_factory(
        transforms.Affine2D(), LogScale(current_axes.yaxis).get_transform())

    # The log transform shrinks everything to log(p) space; compute a scale
    # factor to blow it back up to a reasonable size.
    p_bnd_trans = log_p_transform.transform(np.array([[0, _p_min], [0, _p_max]]))[:, 1]
    scale_factor = (_p_max - _p_min) / (p_bnd_trans[1] - p_bnd_trans[0])

    # Affine transform for the flip, and one for scaling back to reasonable
    # coordinates after the log transform.
    flip_transform = transforms.Affine2D.identity().scale(1, -1)
    preskew_scale_transform = transforms.Affine2D().translate(0, p_bnd_trans[1]).scale(1, scale_factor).translate(0, _p_min)

    # Everything but the skew, to find where the 1000 mb level is so the
    # skew can pivot around that line.
    prelim_data_transform = log_p_transform + flip_transform + preskew_scale_transform + data_figure_trans
    marker = prelim_data_transform.transform(np.array([[_T_min, 1000]]))[0, 1]

    # Add a translation to that marker point into the data-figure transform.
    data_figure_trans += transforms.Affine2D().translate(0, -marker)

    # The skew expressed in figure coordinates.
    figure_skew_transform = data_figure_trans + skew_transform + data_figure_trans.inverted()

    # Skew-T log-p transform: log-p first, then the flip, the scale, the skew.
    _stlp_data_transform = log_p_transform + flip_transform + preskew_scale_transform + figure_skew_transform + current_axes.transData

    # Blended transforms for labels: x from the skew-T transform with y in
    # axes coordinates, and x in axes coordinates with y from the skew-T
    # (log-p) transform.
    _stlp_xlabel_transform = transforms.blended_transform_factory(_stlp_data_transform, current_axes.transAxes)
    _stlp_ylabel_transform = transforms.blended_transform_factory(current_axes.transAxes, _stlp_data_transform)
Esempio n. 19
0
File: fitcont.py Progetto: nhmc/H2
    def plotinit(self):
        """ Set up the figure and do initial plots.

        The spectrum columns ``wa``, ``fl``, ``er`` (and ``co`` if present)
        are taken from ``self.spec`` with a stride of ``self.nbin``.  Two
        axes are created on ``self.fig``.  ``a0`` holds spectrum, error,
        optional continuum and an optional scaled template.  ``a1`` sits
        above it, shares x, and holds residual markers, reference lines
        and a histogram/Gaussian curve drawn with x in axes coordinates
        and y in data coordinates.

        Updates the following attributes:

          self.art_fl
          self.markers  ('contpoints', 'cont', 'resid', 'hist_left')
        """
        # NOTE(review): ``0:-1`` always drops the last sample, even when
        # nbin == 1; ``[::self.nbin]`` may have been intended -- confirm.
        wa, fl, er = [self.spec[k][0:-1:self.nbin] for k in 'wa fl er'.split()]
        if self.spec['co'] is not None:
            co = self.spec['co'][0:-1:self.nbin]
        # axis for spectrum & continuum
        a0 = self.fig.add_axes((0.05, 0.1, 0.9, 0.6))
        a0.set_autoscale_on(0)
        # axis for residuals
        a1 = self.fig.add_axes((0.05, 0.75, 0.9, 0.2), sharex=a0)
        a1.set_autoscale_on(0)
        # reference lines: solid at 0 and +/-1, dashed at +/-2
        a1.axhline(0, color='k', alpha=0.7, zorder=99)
        a1.axhline(1, color='k', alpha=0.7, zorder=99)
        a1.axhline(-1, color='k', alpha=0.7, zorder=99)
        a1.axhline(2, color='k', linestyle='dashed', zorder=99)
        a1.axhline(-2, color='k', linestyle='dashed', zorder=99)
        m0, = a1.plot([0], [0], '.r', ms=6, alpha=0.5)
        a1.set_ylim(-4, 4)
        a0.axhline(0, color='0.7')
        if self.spec['co'] is not None:
            a0.plot(wa, co, color='0.7', lw=1, ls='dashed')
        # Step rendering is a drawstyle; passing 'steps-mid' as a linestyle
        # is deprecated/removed in newer matplotlib.
        self.art_fl, = a0.plot(wa, fl, 'b', lw=0.5, drawstyle='steps-mid')
        a0.plot(wa, er, lw=0.5, color='orange')
        m1, = a0.plot([0], [0], 'r', alpha=0.7)
        m2, = a0.plot([0], [0], 'o', mfc='None', mew=1, ms=8, mec='r', picker=5,
                      alpha=0.7)
        a0.set_xlim(min(wa), max(wa))
        good = (er > 0) & ~np.isnan(fl)
        ymin = -5 * np.median(er[good])
        # twice the 95th-percentile flux among good pixels
        ymax = 2 * np.sort(fl[good])[int(good.sum() * 0.95)]
        a0.set_ylim(ymin, ymax)
        a0.text(0.9, 0.9, 'z=%.2f' % self.redshift, transform=a0.transAxes)

        # for histogram: x in axes coordinates, y in data coordinates
        trans = mtran.blended_transform_factory(a1.transAxes, a1.transData)
        hist, = a1.plot([], [], color='k', transform=trans)
        x = np.linspace(-3, 3)
        a1.plot(Gaussian(x, 0, 1, 0.05), x, color='k', transform=trans, lw=0.5)

        if self.template is not None:
            # template normalised to its maximum, y in axes coordinates
            trans = mtran.blended_transform_factory(a0.transData, a0.transAxes)
            a0.plot(self.spec['wa'], self.template / self.template.max(), '-c', lw=2,
                    alpha=0.5, transform=trans)

        self.fig.canvas.draw()
        self.markers.update(contpoints=m2, cont=m1, resid=m0, hist_left=hist)
Esempio n. 20
0
def setAxes1(axs, opts, provs, weststyle=False):
	"""Label and format the four axes of the P/S delay-time figure.

	*axs* is unpacked as ``(axps, axhp, axhs, axll)``:
	``axps`` gets the P (x) / S (y) delay-time labels, ticks and grid;
	``axhp`` is labelled 'P Histogram' and titled ``opts.figtt``;
	``axhs`` is labelled 'S Histogram' (label on top);
	``axll`` receives the province legend when ``opts.physio`` is set.

	Labels say 'Absolute' if ``opts.absdt`` else 'Relative'.  With
	*weststyle*, y labels/ticks move to the right and the S histogram's
	x axis is reversed.  If ``opts.lab`` is not None, a bold "(lab)"
	panel label is added (on ``axll`` with *weststyle*, else ``axhp``).

	*provs* is a sequence of keys into the module-level ``pidict``, where
	``pidict[key][1]`` is the legend text and ``pidict[key][2]`` the
	marker edge colour.
	"""
	axps, axhp, axhs, axll = axs
	axhp.set_title(opts.figtt)
	axhp.yaxis.set_major_formatter(nullfmt)
	axhs.xaxis.set_major_formatter(nullfmt)
	axhp.set_ylabel('P Histogram')
	axhs.set_xlabel('S Histogram')
	axps.grid()
	kind = 'Absolute' if opts.absdt else 'Relative'
	axps.set_xlabel(kind + ' P Delay Time [s]')
	axps.set_ylabel(kind + ' S Delay Time [s]')
	axps.set_xticks(opts.xticks)
	axps.set_yticks(opts.yticks)
	axhs.xaxis.set_label_position('top')
	if weststyle:
		axps.yaxis.set_label_position('right')
		axps.yaxis.set_ticks_position('right')
		axhs.yaxis.set_ticks_position('right')
		# reverse the x direction of the S histogram
		axhs.set_xlim(axhs.get_xlim()[::-1])
	# make legend: one row per province, top to bottom
	if opts.physio:
		n = len(provs)
		for i, prov in enumerate(provs):
			axll.plot(0.04, -i, color='None', marker=opts.pmarker, mec=pidict[prov][2],
				ms=opts.pms, mew=opts.pmew, ls='None', alpha=1)
			axll.text(0.08, -i, pidict[prov][1], va='center', ha='left', size=14)
		if n:
			axll.axis([0, 1, 0.5 - n, 0.5])
		axll.yaxis.set_major_formatter(nullfmt)
		axll.xaxis.set_major_formatter(nullfmt)
		axll.set_xticks([])
		axll.set_yticks([])
	# label fig number
	if opts.lab is not None:
		lab = '(' + opts.lab + ')'
		if weststyle:
			# axll is referenced directly so this works whether or not
			# the legend was drawn.
			axll.text(-0.06, 1.03, lab, transform=axll.transAxes, va='bottom', ha='left', size=20, fontweight='bold')
		else:
			axhp.text(-0.13, 1.03, lab, transform=axhp.transAxes, va='bottom', ha='left', size=20, fontweight='bold')
Esempio n. 21
0
def line_to_axis(ax,x,y,xlabel=None,ylabel=None,fontsize=12):
    """
    Draw arrows from the point (x, y) to the x and y axes of *ax*.

    xlabel : if not None, draw an arrow straight down to the bottom of the
        current view and write this text just below the axes at x.
    ylabel : if not None, draw an arrow straight left to the left edge of
        the current view and write this text just left of the axes at y.
    fontsize : font size used for both labels.
    """
    limits = ax.axis()  # (xmin, xmax, ymin, ymax)

    if xlabel is not None:
        ax.arrow(x, y, 0, -(y - limits[2]))  # down to the x axis
        # x in data coordinates, y in axes coordinates (just under the axes)
        below_axis = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        ax.text(x, -0.03, xlabel, transform=below_axis, fontsize=fontsize,
                va='center', ha='center')

    if ylabel is not None:
        ax.arrow(x, y, -(x - limits[0]), 0.)  # across to the y axis
        # x in axes coordinates (just left of the axes), y in data coordinates
        left_of_axis = transforms.blended_transform_factory(ax.transAxes, ax.transData)
        ax.text(-0.01, y, ylabel, transform=left_of_axis, fontsize=fontsize,
                va='center', ha='right')
Esempio n. 22
0
def zoom_effect(ax1, ax2, **kwargs):
    u"""
    Visually connect the zoomed axes *ax1* to the matching region of *ax2*.

    ax1 : the zoomed axes
    ax2 : the main axes

    The xmin & xmax will be taken from the ax1.viewLim.  Keyword arguments
    style the connecting lines; the patches use the same keywords with the
    edge colour forced to "none" and alpha forced to 0.2.

    Returns (c1, c2, bbox_patch1, bbox_patch2, p) as built by connect_bbox.
    """
    # ax1 data coordinates -> ax2 axes coordinates (used for the y direction)
    data_to_main_axes = ax1.transScale + (ax1.transLimits + ax2.transAxes)
    trans = blended_transform_factory(ax2.transData, data_to_main_axes)

    zoomed_box = ax1.bbox
    region_box = TransformedBbox(ax1.viewLim, trans)

    patch_props = dict(kwargs, ec="none", alpha=0.2)

    c1, c2, bbox_patch1, bbox_patch2, p = connect_bbox(
        zoomed_box, region_box,
        loc1a=4, loc2a=1, loc1b=3, loc2b=2,
        prop_lines=kwargs, prop_patches=patch_props)

    ax1.add_patch(bbox_patch1)
    for artist in (bbox_patch2, c1, c2, p):
        ax2.add_patch(artist)

    return c1, c2, bbox_patch1, bbox_patch2, p
Esempio n. 23
0
def axvfill(xvals, ax=None, color='k', alpha=0.1, edgecolor='none', **kwargs):
    """ Fill vertical regions defined by a sequence of (left, right)
    positions.

    Each region's x extent is in data coordinates; its y extent is 0 to 1
    in axes coordinates, so it always spans the full height of the axes.

    Parameters
    ----------
    xvals: array_like
      A single (left, right) pair, e.g. (3, 4), or a sequence of pairs,
      e.g. [(0, 1), (3, 4)]. An empty sequence adds an empty collection.
    ax : matplotlib axes instance (default is the current axes)
      The axes to plot regions on.
    color : mpl colour (default 'k')
      Face color of the regions.
    alpha : float (default 0.1)
      Opacity of the regions (1=opaque).
    edgecolor : mpl colour (default 'none')
      Edge color of the regions.

    Other keywords arguments are passed to PolyCollection; ``facecolor``,
    ``edgecolor``, ``transform`` and ``alpha`` are always set from the
    arguments above.

    Returns
    -------
    PolyCollection
      The collection that was added to *ax*.

    Raises
    ------
    ValueError
      If *xvals* is not a pair or a 2-D sequence of pairs.
    """
    # Validate before touching pl.gca(), so bad input never creates a figure.
    xvals = np.asanyarray(xvals)
    if xvals.size == 0:
        xvals = xvals.reshape(0, 2)
    elif xvals.ndim == 1:
        xvals = xvals[None, :]
    if xvals.ndim != 2 or xvals.shape[-1] != 2:
        raise ValueError('Invalid input')
    if ax is None:
        ax = pl.gca()

    coords = [[(x0, 0), (x0, 1), (x1, 1), (x1, 0)] for x0, x1 in xvals]
    # x in data coordinates, y in axes coordinates
    trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
    kwargs.update(facecolor=color, edgecolor=edgecolor, transform=trans, alpha=alpha)

    p = PolyCollection(coords, **kwargs)
    ax.add_collection(p)
    ax.autoscale_view()
    return p
Esempio n. 24
0
    def _set_lim_and_transforms(self):
        """
        Set up the (skewed) transforms for data, text and grids.

        Called once when the plot is created.  The base Axes transforms are
        built first; a skew of 30 degrees in x is then inserted after the
        scale/limits stage and before transAxes.  The skew therefore acts in
        axes coordinates, around the proper origin.
        """
        skew_angle = 30

        # Start from the standard transform setup of the Axes base class.
        Axes._set_lim_and_transforms(self)

        # Data -> skewed axes coordinates.  Kept as its own attribute because
        # other users (e.g. the spines, when finding bounds) need the
        # pre-transAxes part.
        self.transDataToAxes = (
            self.transScale
            + self.transLimits
            + transforms.Affine2D().skew_deg(skew_angle, 0)
        )

        # Full transform from data to pixels.
        self.transData = self.transDataToAxes + self.transAxes

        # The blended x-axis transform must also get the skew applied in
        # axes coordinates, using both axes, just like transData above.
        x_blend = transforms.blended_transform_factory(
            self.transScale + self.transLimits,
            transforms.IdentityTransform())
        self._xaxis_transform = (
            x_blend + transforms.Affine2D().skew_deg(skew_angle, 0)
        ) + self.transAxes
Esempio n. 25
0
 def __init__(self):
     """Set up span-selection state on ``self.ssm`` and a hidden selection rectangle.

     For every entry of ``self.ranges``, a span is created through
     ``self.makespan(rng[1], rng[0] - rng[1])``.  After ``self.redraw()``, an
     invisible, animated Rectangle is added to ``self.ax``.  Its x is in data
     coordinates and its y is in axes coordinates.
     """
     super().__init__()

     # Simple attribute container for the selection state.
     class Ssm:
         pass

     state = Ssm()
     self.ssm = state
     state.btn_add = 3
     state.btn_del = 3
     state.key_mod = 'control'
     state.minspan = 0
     state.rect = None
     state.rangespans = []
     state.rectprops = {'facecolor': '0.5', 'alpha': 0.2}
     state.ranges = self.ranges
     for rng in state.ranges:
         start = rng[1]
         state.rangespans.append(self.makespan(start, rng[0] - start))
     self.redraw()

     span_trans = blended_transform_factory(self.ax.transData, self.ax.transAxes)
     width, height = 0, 1
     state.rect = Rectangle([0, 0], width, height,
                            transform=span_trans,
                            visible=False,
                            animated=True,
                            **state.rectprops)
     self.ax.add_patch(state.rect)
Esempio n. 26
0
 def test_line_extent_compound_coords2(self):
     # x is in axes coordinates and y in (10x scale + data) coordinates.
     # Only the y component touches data space, so only the y extent of
     # dataLim is updated, and it reflects the 10x scaling.
     ax = plt.axes()
     scaled_data = mtrans.Affine2D().scale(10) + ax.transData
     trans = mtrans.blended_transform_factory(ax.transAxes, scaled_data)
     ax.plot([0.1, 1.2, 0.8], [35, -5, 18], transform=trans)
     expected = np.array([[np.inf, -50.], [-np.inf, 350.]])
     np.testing.assert_array_equal(ax.dataLim.get_points(), expected)
     plt.close()
Esempio n. 27
0
 def new_axes(self, ax, nrect):
     """Attach the selector to *ax* and create its hidden, animated artists.

     If *ax* lives on a different canvas, the old canvas events are
     disconnected (when there was one) and the default events are connected
     on the new canvas.  Then the following are created:

     * the span rectangle ``self.rect``;
     * for each group ``k``, ``nrect[k]`` "stay" rectangles styled with
       ``self.stay_rectprops[k]``, stored in ``self.stay_rects[k]``;
     * a vertical bar line ``self.bar`` styled with ``self.lineprops``.

     All of them are also collected in ``self.artists``.

     nrect : sequence of int
         Number of stay rectangles to create for each group.
     """
     self.ax = ax
     if self.canvas is not ax.figure.canvas:
         if self.canvas is not None:
             self.disconnect_events()
         self.canvas = ax.figure.canvas
         self.connect_default_events()
     # span: x in data coordinates, y spans the axes (0..1)
     trans = blended_transform_factory(self.ax.transData, self.ax.transAxes)
     w, h = 0, 1
     self.rect = Rectangle((0, 0), w, h, transform=trans, visible=False,
                           animated=True, **self.rectprops)
     self.ax.add_patch(self.rect)
     self.artists = [self.rect]
     # stay rects, one list per group
     self.stay_rects = []
     for group, count in enumerate(nrect):
         group_rects = []
         self.stay_rects.append(group_rects)
         for _ in range(count):
             stay_rect = Rectangle((0, 0), w, h, transform=trans, visible=False,
                                   animated=True, **self.stay_rectprops[group])
             self.ax.add_patch(stay_rect)
             group_rects.append(stay_rect)
         self.artists.extend(group_rects)
     # bar: vertical line at x=0 spanning the full axes height
     self.bar = ax.axvline(0, ymin=w, ymax=h, visible=False, **self.lineprops)
     self.artists.append(self.bar)
Esempio n. 28
0
def zoom_effect02(ax1, ax2, **kwargs):
    """
    Connect the zoomed axes *ax1* to the matching x-range of the main axes *ax2*.

    ax1 : the zoomed axes
    ax2 : the main axes

    Similar to zoom_effect01.  The xmin & xmax will be taken from the
    ax1.viewLim.  On *ax2*, the marked region spans that x-range over the
    full height of the axes.  Keyword arguments style the connecting lines;
    the patches use the same keywords with edge colour "none" and alpha 0.2.

    Returns (c1, c2, bbox_patch1, bbox_patch2, p) as built by connect_bbox.
    """

    # ax1 data coordinates -> ax2 axes coordinates (used for the y direction)
    tt = ax1.transScale + (ax1.transLimits + ax2.transAxes)
    trans = blended_transform_factory(ax2.transData, tt)

    mybbox1 = ax1.bbox
    mybbox2 = TransformedBbox(ax1.viewLim, trans)

    prop_patches = {**kwargs, "ec": "none", "alpha": 0.2}

    c1, c2, bbox_patch1, bbox_patch2, p = connect_bbox(
        mybbox1, mybbox2,
        loc1a=3, loc2a=2, loc1b=4, loc2b=1,
        prop_lines=kwargs, prop_patches=prop_patches)

    ax1.add_patch(bbox_patch1)
    ax2.add_patch(bbox_patch2)
    ax2.add_patch(c1)
    ax2.add_patch(c2)
    ax2.add_patch(p)

    return c1, c2, bbox_patch1, bbox_patch2, p
Esempio n. 29
0
def delayDistPlotRun(dtdict, sdfile, opts):
	"""Run delayDistPlot with a topography strip above it and save the figure.

	Distance and topography are read from columns 1 and 2 of
	``sdfile + '-topo'``.  Two x-shared stacked axes are handed to
	delayDistPlot, together with a thin topography axes on top.  If
	``opts.indexing`` is set, an "(indexing)" label is added to the
	topography axes.  The figure is saved under ``odir``.  Several
	plotting attributes of *opts* are overwritten.
	"""
	topo = loadtxt(sdfile + '-topo', usecols=(1,2))
	dists, topos = topo[:,0], topo[:,1]

	dddict, dtmean, sdlist, ftag = delayDistGet(dtdict, sdfile, opts)

	fig = figure(figsize=(7, 8))
	ax0 = fig.add_subplot(2, 1, 1)
	ax1 = fig.add_subplot(2, 1, 2, sharex=ax0)
	subplots_adjust(left=.1, right=.95, bottom=.08, top=0.86, hspace=.08)
	rcParams['legend.fontsize'] = 9

	# topography strip above the two main axes
	axtopo = fig.add_axes([0.1, 0.92, 0.85, 0.065], sharex=ax0)
	axs = [ax0, ax1, axtopo]
	axtopo.plot(dists, topos, 'k-')
	axtopo.fill_between(dists, zeros(len(dists)), topos, color='k', alpha=.3)
	tmin, tmax = min(topos), max(topos)
	margin = (tmax - tmin)*.2
	axtopo.set_ylim(tmin-margin, tmax+margin)
	axtopo.set_yticks([])
	axtopo.set_ylabel('Topo')

	# plotting options for the delay-time/distance panels
	opts.randcol = not opts.blackwhite
	opts.pcol = 'k'
	opts.pms = 4
	opts.mcol = 'k'
	opts.mms = 11
	opts.msym = '*'
	opts.mew = 2
	# NOTE(review): black/white mode used to set alpha .2 and then .4
	# straight away, so both modes end up with .4 -- confirm the intent.
	opts.alpha = .4
	opts.xlim = -100, sdlist[-1][1] + 100
	delayDistPlot(dddict, dtmean, sdlist, axs, opts)

	# panel index label, to the upper left of the topography strip
	if opts.indexing is not None:
		trans = transforms.blended_transform_factory(axtopo.transAxes, axtopo.transAxes)
		axtopo.text(-.06, 1.1, '(' + opts.indexing + ')', transform=trans, va='top', ha='right', size=16)

	key = '-'.join(sdfile.split('/')[-1].split('.'))
	prefix = 'ddline-rc-' if opts.randcol else 'ddline-kw-'
	saveFigure(odir + prefix + ftag + key + '.png', opts)
Esempio n. 30
0
 def test_line_extent_compound_coords1(self):
     # x is in axes coordinates and y in data coordinates: the x extent of
     # dataLim comes out as the axes span [0, 1], and the y extent is the
     # data span [-5, 35].
     ax = plt.axes()
     mixed = mtrans.blended_transform_factory(ax.transAxes, ax.transData)
     ax.plot([0.1, 1.2, 0.8], [35, -5, 18], transform=mixed)
     expected = np.array([[0., -5.], [1., 35.]])
     np.testing.assert_array_equal(ax.dataLim.get_points(), expected)
     plt.close()
Esempio n. 31
0
def _nan_curve(Z):
    """Return an all-NaN float array shaped like the depth curve *Z*."""
    return np.full(np.shape(Z), np.nan)


def _smoothed_nphi(logs, Z, window):
    """Return the rolling-median NPHI_SAN curve, or an all-NaN curve if it is absent."""
    try:
        return utils.rolling_median(logs.data['NPHI_SAN'], window)
    except (KeyError, ValueError):
        Notice.warning("No NPHI in this well")
        return _nan_curve(Z)


def plot_feature_well(tc, gs):
    """
    Plotting function for the feature well.

    Draws the curves of ``tc.feature_well`` into rows ``2:`` of the last
    five columns of *gs*, one track per column.  The first track also holds
    the optional striplog and a twinned axis (used for GR).  When
    ``tc.tops_file`` exists, tops are drawn as horizontal bars in every
    track, with one text box per top in the last track.

    Args:
        tc (TransectContainer): The container for the main plot.
        gs (GridSpec): A matplotlib gridspec.

    Returns:
        GridSpec: the *gs* argument.
    """
    fname = tc.settings['curve_display']

    logs = tc.log.get(tc.feature_well)

    if not logs:
        # There was no data for this well, so there won't be a feature plot.
        Notice.fail("There's no well data for feature well " + tc.feature_well)
        return gs

    Z = logs.data['DEPT']

    curves = [
        'GR', 'DT', 'DPHI_SAN', 'NPHI_SAN', 'DTS', 'RT_HRLT', 'RHOB', 'DRHO'
    ]
    # Look each curve's display parameters up once; reused below.
    curve_params = {curve: get_curve_params(curve, fname) for curve in curves}

    window = tc.settings.get('curve_smooth_window') or 51
    ntracks = 5
    lw = 1.0
    smooth = True

    # Track 0 has the main axis plus a twinned one; tracks 1-4 have one each.
    axss = plt.subplot(gs[2:, -5])
    axs = [
        [axss, axss.twiny()],
        [plt.subplot(gs[2:, -4])],
        [plt.subplot(gs[2:, -3])],
        [plt.subplot(gs[2:, -2])],
        [plt.subplot(gs[2:, -1])],
    ]

    if getattr(tc.log, 'striplog', None):
        legend = Legend.default()
        try:
            logs.striplog[tc.log.striplog].plot_axis(axs[0][0], legend=legend)
        except KeyError:
            # In fact, this striplog doesn't exist; move on without it.
            Notice.fail("There is no such striplog: " + tc.log.striplog)

    axs[0][0].set_ylim([Z[-1], 0])
    label_shift = np.zeros(len(axs))
    # NPHI_SAN, loaded on first use by the DPHI_SAN fill or the DRHO fill.
    nphi = None

    for curve in curves:
        try:
            values = logs.data[curve]
        except (KeyError, ValueError):
            Notice.warning("Curve not present: " + curve)
            values = _nan_curve(Z)

        params = curve_params[curve]
        i = params['track']
        # GR goes on the second (twinned) axis of its track.
        j = 1 if curve == 'GR' else 0
        ax = axs[i][j]

        label_shift[i] += 1

        linOrlog = params['logarithmic']

        sxticks = np.array(params['xticks'])
        xticks = np.array(sxticks, dtype=float)

        if linOrlog == 'log':
            xpos = np.log(np.mean(xticks))
            whichticks = 'minor'
        else:
            xpos = np.mean(xticks)
            whichticks = 'major'

        if smooth:
            values = utils.rolling_median(values, window)

        if curve == 'GR':
            label_shift[i] = 1
            if params['fill_left_cond']:
                # do the fill for the lithology track
                ax.fill_betweenx(Z,
                                 params['xleft'],
                                 values,
                                 facecolor=params['fill_left'],
                                 alpha=1.0,
                                 zorder=11)

        if (curve == 'DPHI_SAN') and params['fill_left_cond']:
            # do the fill for the neutron porosity track
            if nphi is None:
                nphi = _smoothed_nphi(logs, Z, window)
            ax.fill_betweenx(Z,
                             nphi,
                             values,
                             where=nphi >= values,
                             facecolor=params['fill_left'],
                             alpha=1.0,
                             zorder=11)

            ax.fill_betweenx(Z,
                             nphi,
                             values,
                             where=nphi <= values,
                             facecolor='#8C1717',
                             alpha=0.5,
                             zorder=12)

        if curve == 'DRHO':
            # The fill mask below uses nphi, even when no DPHI_SAN fill ran.
            if nphi is None:
                nphi = _smoothed_nphi(logs, Z, window)
            blk_drho = 3.2
            # Hack to get DRHO on the RHOB scale; not done in place, so the
            # log's own data is never modified.
            values = values + blk_drho
            ax.fill_betweenx(Z,
                             blk_drho,
                             values,
                             where=nphi <= values,
                             facecolor='#CCCCCC',
                             alpha=0.5,
                             zorder=12)

        # fill right
        if params['fill_right_cond']:
            ax.fill_betweenx(Z,
                             values,
                             params['xright'],
                             facecolor=params['fill_right'],
                             alpha=1.0,
                             zorder=12)

        # plot curve
        ax.plot(values, Z, color=params['hexcolor'], lw=lw, zorder=13)

        # set scale of curve
        ax.set_xlim([params['xleft'], params['xright']])

        # ------------------------------------------------- #
        # curve labels
        # ------------------------------------------------- #

        trans = transforms.blended_transform_factory(ax.transData,
                                                     ax.transData)

        magic = -Z[-1] / 12.
        ax.text(xpos,
                magic - (magic / 4) * (label_shift[i] - 1),
                curve,
                horizontalalignment='center',
                verticalalignment='bottom',
                fontsize=12,
                color=params['hexcolor'],
                transform=trans)
        # curve units, only for the first label in a track
        units = '${}$'.format(params['units'])
        if label_shift[i] <= 1:
            ax.text(xpos,
                    magic * 0.5,
                    units,
                    horizontalalignment='center',
                    verticalalignment='top',
                    fontsize=12,
                    color='k',
                    transform=trans)

        # ------------------------------------------------- #
        # scales and tickmarks
        # ------------------------------------------------- #

        ax.set_xscale(linOrlog)
        ax.set_ylim([Z[-1], 0])
        ax.axes.xaxis.set_ticks(xticks)
        ax.axes.xaxis.set_ticklabels(sxticks, fontsize=8)
        for label in ax.axes.xaxis.get_ticklabels():
            label.set_rotation(90)
        ax.tick_params(axis='x', direction='out')
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        ax.xaxis.grid(True,
                      which=whichticks,
                      linewidth=0.25,
                      linestyle='-',
                      color='0.75',
                      zorder=100)

        ax.yaxis.grid(True,
                      which=whichticks,
                      linewidth=0.25,
                      linestyle='-',
                      color='0.75',
                      zorder=100)
        ax.yaxis.set_ticks(np.arange(0, max(Z), 100))
        if i != 0:
            ax.set_yticklabels("")

    # ------------------------------------------------- #
    # End of curve loop
    # ------------------------------------------------- #

    # Add Depth label
    axs[0][0].text(0,
                   1.05,
                   'MD\n$m$',
                   fontsize='10',
                   horizontalalignment='center',
                   verticalalignment='center',
                   transform=axs[0][0].transAxes)

    axs[0][0].axes.xaxis.set_ticklabels('')

    for label in axs[0][0].axes.yaxis.get_ticklabels():
        label.set_rotation(90)
        label.set_fontsize(10)

    for label in axs[1][0].axes.xaxis.get_ticklabels():
        label.set_rotation(90)
        label.set_fontsize(10)

    # Add Tops
    try:
        if os.path.exists(tc.tops_file):
            tops = utils.get_tops(tc.tops_file)
            topmidpt = np.amax(curve_params['DT']['xright'])

            # draw horizontal bars at each top position, in every track
            for i in range(ntracks):
                for _, depth in tops.items():
                    axs[i][-1].axhline(y=depth,
                                       xmin=0.01,
                                       xmax=.99,
                                       color='b',
                                       lw=2,
                                       alpha=0.5,
                                       zorder=100)

            # one text box per top at the right edge of the last track
            # (drawn once rather than once per track, so the semi-transparent
            # boxes do not stack up)
            for mkr, depth in tops.items():
                axs[-1][-1].text(x=topmidpt,
                                 y=depth,
                                 s=mkr,
                                 alpha=0.5,
                                 color='k',
                                 fontsize='8',
                                 horizontalalignment='center',
                                 verticalalignment='center',
                                 zorder=10000,
                                 bbox=dict(facecolor='white',
                                           edgecolor='k',
                                           alpha=0.25,
                                           lw=0.25),
                                 weight='light')

    except AttributeError:
        Notice.warning("No tops for this well")
    except TypeError:
        # We didn't get a tops file so move along.
        print("No tops for this well")

    return gs
Esempio n. 32
0
    opts.ccpara = ccpara
    return gsac, opts


if __name__ == '__main__':
    # Plot the traces aligned on two reference times ("Predicted" and
    # "Measured"), each with its time window, and save as egalignp1.pdf.
    gsac, opts = load()
    saclist = gsac.saclist
    xxlim = -20, 20
    reltimes = [0, 3]
    npick = len(reltimes)
    axs = axes1(npick)
    twins = [(-10, 10), getwin(gsac, opts, 't2')]
    titles = ['Predicted', 'Measured']
    for k, reltime in enumerate(reltimes):
        opts.reltime = reltime
        ax = axs[k]
        sacp1(saclist, opts, ax)
        ax.set_xlim(xxlim)
        plotwin(ax, twins[k], opts.pppara)
        ax.set_title(titles[k])
    # panel tags "(a)", "(b)" to the left of each panel's top edge
    for ax, letter in zip(axs, 'ab'):
        trans = transforms.blended_transform_factory(ax.transAxes,
                                                     ax.transAxes)
        ax.text(-.05, 1, '(%s)' % letter, transform=trans, va='center', ha='right', size=16)
    plt.savefig('egalignp1.pdf', format='pdf')
    plt.show()
Esempio n. 33
0
def _fancy_barh(ax, values, data, val_fmt='', is_legend=False):
    """fancy-ish horizontal bar plot
       values must be same len as data, use np.nan if no values for sample
       :param is_legend: if True, will include legend like line segments
       :param val_fmt: is format string for value;
                       None don't show vlaue, '' just str()
    """
    assert (len(values) == len(data))
    names = []
    for d in data:
        names.append(d['name'])
    width = 0.90
    bar_pos = np.arange(len(names))
    rects = ax.barh(bar_pos, values, width)
    ax.set_yticks([])
    ax.set_ylim([
        min(bar_pos) - (1 - width) / 2.,
        max(bar_pos) + width + (1 - width) / 2.
    ])
    ax.invert_yaxis()
    # set bar colors and annotate with sample name : size
    xmax = ax.get_xlim()[1]
    for ii, rect in enumerate(rects):
        rect.set_facecolor(data[ii]['plot_color'])
        x = 0.01
        y = rect.get_y() + rect.get_height() / 2.
        if (val_fmt is None):
            label = str(names[ii])
        else:
            if (np.isnan(values[ii])):
                label = '%s : No Info' % (names[ii])
            elif (val_fmt == ''):
                label = '%s : %s' % (names[ii], str(values[ii]))
            else:
                label = ('%s : ' + val_fmt) % (names[ii], values[ii])
        ax.text(x,
                y,
                label,
                va='center',
                ha='left',
                transform=transforms.blended_transform_factory(
                    ax.transAxes, ax.transData),
                bbox=dict(boxstyle="round,pad=0.2", alpha=0.65, fc='w', lw=0))
        if (is_legend):
            # legend like line segment
            linex = [-0.12, -0.04]
            box_h = 0.5
            box_w_pad = 0.025
            ax.add_patch(
                MPL.patches.FancyBboxPatch(
                    (linex[0] - box_w_pad, y - box_h / 2.),
                    linex[1] - linex[0] + 2 * box_w_pad,
                    box_h,
                    ec='w',
                    fc='w',
                    boxstyle="square,pad=0",
                    transform=transforms.blended_transform_factory(
                        ax.transAxes, ax.transData),
                    clip_on=False))
            line, = ax.plot([-0.12, -0.04], [y] * 2,
                            '-',
                            color=data[ii]['plot_color'],
                            marker=data[ii]['plot_marker'],
                            transform=transforms.blended_transform_factory(
                                ax.transAxes, ax.transData),
                            clip_on=False,
                            lw=LINE_WIDTH)
    ax.set_xlim((0, xmax))
    ax.set_ylabel(
        ' \n \n ')  # @TCC hack - fake ylabel so tight_layout adds spacing
    return True
Esempio n. 34
0
def plot_model_comparison(result,
                          sort=False,
                          colors=None,
                          alpha=0.01,
                          test_pair_comparisons=True,
                          multiple_pair_testing='fdr',
                          test_above_0=True,
                          test_below_noise_ceil=True,
                          error_bars='sem'):
    """ Plots the results of RSA inference on a set of models as a bar graph
    with one bar for each model indicating its predictive performance. The
    function also shows the noise ceiling whose upper edge is an upper bound
    on the performance the true model could achieve (given noise and inter-
    subject variability) and whose lower edge is an estimate of a lower bound
    on the performance of the true model. In addition, all pairwise inferential
    model comparisons are shown in the upper part of the figure.
    The only mandatory input is a "result" object containing model evaluations
    for bootstrap samples and crossvalidation folds. These are used here to
    construct confidence intervals and perform the significance tests.

    Args (All strings case insensitive):
        result (pyrsa.inference.result.Result):
            model evaluation result. Its ``evaluations``, ``models``,
            ``noise_ceiling`` (may be None), ``method`` and ``cv_method``
            attributes are read.
        sort (Boolean or string):
            False (default): plot bars in the order passed
            'descend[ing]': plot bars in descending order of model performance
            'ascend[ing]': plot bars in ascending order of model performance
        colors (list of lists, numpy array, matplotlib colormap):
            None (default): default blue for all bars
            single color: list or numpy array of 3 or 4 values (RGB, RGBA)
                    specifying the color for all bars
            multiple colors: list of lists or numpy array (number of colors by
                    3 or 4 channels -- RGB, RGBA). If the number of colors
                    matches the number of models, each color is used for the
                    bar corresponding to one model (in the order of the models
                    as passed). If the number of colors does not match the
                    number of models, the list is linearly interpolated to
                    assign a color to each model (in the order of the models as
                    passed): the first model gets the first color, the last
                    model gets the last color. For example, two colors will
                    become a gradation, unless there are exactly two models.
                    Instead of a list of lists or numpy array, any matplotlib
                    colormap object may also be passed (e.g.
                    colors = cm.coolwarm or colors = cm.viridis).
        alpha (float):
            significance threshold (p threshold or FDR q threshold)
        test_pair_comparisons (Boolean or string):
            False or None: do not plot pairwise model comparison results
            True (default): plot pairwise model comparison results using
                default settings ('arrows')
            'arrows': plot results in arrows style, indicating pairs of sets
                between which all differences are significant
            'nili': plot results as Nili bars (Nili et al. 2014), indicating
                each significant difference by a horizontal line (or each
                nonsignificant difference if the string contains a '2', e.g.
                'nili2')
            'golan': plot results as Golan wings (Golan et al. 2020), with one
                wing (graphical element) indicating all dominance relationships
                for one model.
            'cliques': plot results as cliques of insignificant differences
        multiple_pair_testing (Boolean or string):
            False, None or 'none': do not adjust for multiple testing for the
                pairwise model comparisons ('uncorrected' is also accepted)
            True, 'FDR' or 'fdr' (default): control the false-discovery rate
                at q = alpha
            'FWER', 'fwer', or 'Bonferroni': control the familywise error rate
                using the Bonferroni method
        test_above_0 (Boolean or string):
            False or None: do not plot results of statistical comparison of
                each model performance against 0
            True (default): plot results of statistical comparison of each
                model performance against 0 using default settings ('dewdrops')
            'dewdrops': place circular "dewdrops" at the baseline to indicate
                models whose performance is significantly greater than 0
            'icicles': place triangular "icicles" at the baseline to indicate
                models whose performance is significantly greater than 0
            Tests are one-sided, use the global alpha threshold and are
            automatically Bonferroni-corrected for the number of models tested.
        test_below_noise_ceil (Boolean or string):
            False or None: do not plot results of statistical comparison of
                each model performance against the lower-bound estimate of the
                noise ceiling
            True (default): plot results of statistical comparison of each
                model performance against the lower-bound estimate of the noise
                ceiling using default settings ('dewdrops')
            'dewdrops': use circular "dewdrops" at the lower bound of the
                noise ceiling to indicate models whose performance is
                significantly below the lower-bound estimate of the noise
                ceiling
            'icicles': use triangular "icicles" at the lower bound of the noise
                ceiling to indicate models whose performance is significantly
                below the lower-bound estimate of the noise ceiling
            Tests are one-sided, use the global alpha threshold and are
            automatically Bonferroni-corrected for the number of models tested.
            The test is skipped when the result has no noise ceiling.
        error_bars (Boolean or string):
            False or None: do not plot error bars
            True (default) or 'SEM': plot the standard deviation of the
                bootstrap distribution (the bootstrap standard error)
            'CI': plot 95%-confidence intervals (excluding 2.5% on each side)
            'CI[x]': plot x%-confidence intervals (excluding (100-x)/2 % on
                each side), e.g. 'CI90'
            Confidence intervals are based on the bootstrap procedure,
            reflecting variability of the estimate across subjects and/or
            experimental conditions.

    Returns:
        None. The plot is drawn into a new matplotlib figure. Note that the
        global matplotlib rc setting for y-tick label size is changed.

    Raises:
        Exception: if a string option is not recognized, or if there are too
            few bootstrap samples for the requested confidence interval.
    """

    # Prepare and sort data
    evaluations = result.evaluations
    models = result.models
    noise_ceiling = result.noise_ceiling
    method = result.method
    if noise_ceiling is not None:
        noise_ceiling = np.asarray(noise_ceiling)

    # collapse any extra dimensions (e.g. crossvalidation folds) by averaging
    while len(evaluations.shape) > 2:
        evaluations = np.nanmean(evaluations, axis=-1)
    # drop bootstrap samples that produced no evaluation (NaN)
    valid = ~np.isnan(evaluations[:, 0])
    if noise_ceiling is not None and noise_ceiling.ndim > 1:
        noise_ceiling = noise_ceiling[:, valid]
    evaluations = evaluations[valid]
    perf = np.mean(evaluations, axis=0)
    n_bootstraps, n_models = evaluations.shape

    if sort is True:
        sort = 'descending'  # descending by default if sort is True
    elif sort is False:
        sort = 'unsorted'
    if sort != 'unsorted':  # 'descending' or 'ascending'
        # validate before reordering anything
        if not ('descend' in sort.lower() or 'ascend' in sort.lower()):
            raise Exception('plot_model_comparison: Argument ' +
                            'sort is incorrectly defined as ' + sort + '.')
        idx = np.argsort(perf)
        if 'descend' in sort.lower():
            idx = np.flip(idx)
        perf = perf[idx]
        evaluations = evaluations[:, idx]
        if models is not None:
            models = [models[i] for i in idx]

    # Prepare axes for bars and pairwise comparisons
    fs, fs2 = 18, 14  # axis label font sizes
    left, bottom, width, height = 0.15, 0.15, 0.8, 0.8
    fig = plt.figure(figsize=(12.5, 10))
    if test_pair_comparisons is True:
        test_pair_comparisons = 'arrows'
    if test_pair_comparisons:
        if test_pair_comparisons.lower() in ['arrows', 'cliques']:
            h_pair_tests = 0.25
        elif 'golan' in test_pair_comparisons.lower():
            h_pair_tests = 0.3
        elif 'nili' in test_pair_comparisons.lower():
            h_pair_tests = 0.4
        else:
            raise Exception(
                'plot_model_comparison: Argument ' +
                'test_pair_comparisons is incorrectly defined as ' +
                test_pair_comparisons + '.')
        ax = plt.axes((left, bottom, width, height * (1 - h_pair_tests)))
        axbar = plt.axes((left, bottom + height * (1 - h_pair_tests), width,
                          height * h_pair_tests * 0.7))
    else:
        ax = plt.axes((left, bottom, width, height))

    # Define the model colors
    if colors is None:  # no color passed...
        colors = [0, 0.4, 0.9, 1]  # use default blue
    elif isinstance(colors, cm.colors.Colormap):
        # any colormap (linear-segmented or listed) is sampled at 100 points;
        # calling the colormap directly avoids the removed cm.get_cmap
        colors = colors(np.linspace(0, 1, 100))[:, :3]
    colors = np.array([np.array(col) for col in colors])
    if colors.ndim == 1:  # one color passed...
        colors = colors.reshape(1, -1)
    else:  # multiple colors passed...
        n_col, n_chan = colors.shape
        if n_col == n_models:  # one color passed for each model...
            cols2 = colors
        else:  # number of colors passed does not match number of models...
            # interpolate so that the first model gets the first color and
            # the last model gets the last color
            cols2 = np.empty((n_models, n_chan))
            positions = np.linspace(0, n_col - 1, n_models)
            for c in range(n_chan):
                cols2[:, c] = np.interp(positions, np.arange(n_col),
                                        colors[:, c])
        # colors follow the models (in passed order) through the sort
        colors = cols2[idx, :] if sort != 'unsorted' else cols2
    if colors.shape[1] == 3:  # add an opaque alpha channel
        colors = np.concatenate((colors, np.ones((colors.shape[0], 1))),
                                axis=1)

    # Plot bars and error bars
    ax.bar(np.arange(n_models), perf, color=colors)
    if error_bars is True:
        error_bars = 'sem'
    elif not error_bars:
        error_bars = None  # False, None or '' -> no error bars
    CI_percent = None
    if error_bars is None:
        pass
    elif error_bars.lower() == 'sem':
        # std of the bootstrap distribution == bootstrap standard error
        errorbar_low = np.std(evaluations, axis=0)
        errorbar_high = np.std(evaluations, axis=0)
    elif error_bars[0:2].lower() == 'ci':
        if len(error_bars) == 2:
            CI_percent = 95
        else:
            CI_percent = int(error_bars[2:])
        prop_cut = (1 - CI_percent / 100) / 2
        # frame the samples with -inf/+inf so that a quantile which falls
        # outside the observed samples becomes infinite and is detected below
        framed_evals = np.concatenate(
            (np.tile(np.array((-np.inf, np.inf)).reshape(2, 1),
                     (1, n_models)), evaluations),
            axis=0)
        errorbar_low = -(np.quantile(framed_evals, prop_cut, axis=0) - perf)
        errorbar_high = (np.quantile(framed_evals, 1 - prop_cut, axis=0) -
                         perf)
        limits = np.concatenate((errorbar_low, errorbar_high))
        if np.isnan(limits).any() or (abs(limits) == np.inf).any():
            raise Exception(
                'plot_model_comparison: Too few bootstrap samples for the ' +
                'requested confidence interval: ' + error_bars + '.')
    else:
        raise Exception('plot_model_comparison: Argument ' +
                        'error_bars is incorrectly defined as ' + error_bars +
                        '.')
    if error_bars is not None:
        ax.errorbar(np.arange(n_models),
                    perf,
                    yerr=[errorbar_low, errorbar_high],
                    fmt='none',
                    ecolor='k',
                    capsize=0,
                    linewidth=3)

    # marker size shared by the "above 0" and "below noise ceiling" markers
    half_sym_size = 9

    # Test whether model performance exceeds 0 (one sided)
    if test_above_0 is True:
        test_above_0 = 'dewdrops'
    if test_above_0:
        p = ((evaluations < 0).sum(axis=0) + 1) / n_bootstraps
        model_significant = p < alpha / n_models  # Bonferroni over models
        if test_above_0.lower() == 'dewdrops':
            halfmoonup = Path.wedge(0, 180)
            ax.plot(model_significant.nonzero()[0],
                    np.tile(0, model_significant.sum()),
                    'w',
                    marker=halfmoonup,
                    markersize=half_sym_size,
                    linewidth=0)
        elif test_above_0.lower() == 'icicles':
            ax.plot(model_significant.nonzero()[0],
                    np.tile(0, model_significant.sum()),
                    'w',
                    marker=10,
                    markersize=half_sym_size,
                    linewidth=0)
        else:
            raise Exception('plot_model_comparison: Argument test_above_0' +
                            ' is incorrectly defined as ' + test_above_0 + '.')

    # Plot noise ceiling
    noise_ceil_col = [0.5, 0.5, 0.5, 0.2]
    noise_lower = noise_upper = None
    if noise_ceiling is not None:
        noise_lower = np.nanmean(noise_ceiling[0])
        noise_upper = np.nanmean(noise_ceiling[1])
        noiserect = patches.Rectangle((-0.5, noise_lower),
                                      len(perf),
                                      noise_upper - noise_lower,
                                      linewidth=0,
                                      facecolor=noise_ceil_col,
                                      zorder=1e6)
        ax.add_patch(noiserect)

    # Test whether model performance is below the noise ceiling's lower bound
    # (one sided)
    if test_below_noise_ceil is True:
        test_below_noise_ceil = 'dewdrops'
    if noise_ceiling is None:
        test_below_noise_ceil = False  # nothing to compare against
    if test_below_noise_ceil:
        # column vector: one lower bound per bootstrap sample, or a single
        # (1, 1) value when the noise ceiling has no bootstrap dimension
        noise_lower_bs = np.reshape(noise_ceiling[0], (-1, 1))
        diffs = noise_lower_bs - evaluations  # positive if below lower bound
        p = ((diffs < 0).sum(axis=0) + 1) / n_bootstraps
        model_below_lower_bound = p < alpha / n_models

        if test_below_noise_ceil.lower() == 'dewdrops':
            halfmoondown = Path.wedge(180, 360)
            ax.plot(model_below_lower_bound.nonzero()[0],
                    np.tile(noise_lower + 0.0000,
                            model_below_lower_bound.sum()),
                    color='none',
                    marker=halfmoondown,
                    markersize=half_sym_size,
                    markerfacecolor=noise_ceil_col,
                    markeredgecolor='none',
                    linewidth=0)
        elif test_below_noise_ceil.lower() == 'icicles':
            ax.plot(model_below_lower_bound.nonzero()[0],
                    np.tile(noise_lower + 0.0007,
                            model_below_lower_bound.sum()),
                    color='none',
                    marker=11,
                    markersize=half_sym_size,
                    markerfacecolor=noise_ceil_col,
                    markeredgecolor='none',
                    linewidth=0)
        else:
            raise Exception(
                'plot_model_comparison: Argument ' +
                'test_below_noise_ceil is incorrectly defined as ' +
                test_below_noise_ceil + '.')

    # Pairwise model comparisons
    if test_pair_comparisons:
        model_comp_descr = 'Model comparisons: two-tailed, '
        p_values = pair_tests(evaluations)
        n_tests = int((n_models**2 - n_models) / 2)
        if multiple_pair_testing is True:
            multiple_pair_testing = 'fdr'  # the default correction
        elif (not multiple_pair_testing
              or multiple_pair_testing.lower() == 'none'):
            multiple_pair_testing = 'uncorrected'
        if multiple_pair_testing.lower() in ['bonferroni', 'fwer']:
            significant = p_values < (alpha / n_tests)
            model_comp_descr = (model_comp_descr +
                                'p < {:<.5g}'.format(alpha) +
                                ', Bonferroni-corrected for ' + str(n_tests) +
                                ' model-pair comparisons')
        elif multiple_pair_testing.lower() == 'fdr':
            # Benjamini-Hochberg: largest k with p_(k) < alpha * k / m
            ps = batch_to_vectors(np.array([p_values]))[0][0]
            ps = np.sort(ps)
            criterion = alpha * (np.arange(ps.shape[0]) + 1) / ps.shape[0]
            k_ok = ps < criterion
            if np.any(k_ok):
                k_max = np.max(np.nonzero(k_ok)[0])
                crit = criterion[k_max]
            else:
                crit = 0
            significant = p_values < crit
            model_comp_descr = (model_comp_descr +
                                'FDR q < {:<.5g}'.format(alpha) + ' (' +
                                str(n_tests) + ' model-pair comparisons)')
        else:
            if 'uncorrected' not in multiple_pair_testing.lower():
                raise Exception(
                    'plot_model_comparison: Argument ' +
                    'multiple_pair_testing is incorrectly defined as ' +
                    multiple_pair_testing + '.')
            significant = p_values < alpha
            model_comp_descr = (model_comp_descr +
                                'p < {:<.5g}'.format(alpha) +
                                ', uncorrected (' + str(n_tests) +
                                ' model-pair comparisons)')
        if result.cv_method in [
                'bootstrap', 'bootstrap_rdm', 'bootstrap_pattern',
                'bootstrap_crossval'
        ]:
            model_comp_descr = model_comp_descr + \
                '\nInference by bootstrap resampling ' + \
                '({:<,.0f}'.format(n_bootstraps) + ' bootstrap samples) of '
        if result.cv_method == 'bootstrap_rdm':
            model_comp_descr = model_comp_descr + 'subjects. '
        elif result.cv_method == 'bootstrap_pattern':
            model_comp_descr = model_comp_descr + 'experimental conditions. '
        elif result.cv_method in ['bootstrap', 'bootstrap_crossval']:
            model_comp_descr = model_comp_descr + \
                'subjects and experimental conditions. '
        if error_bars is not None:
            model_comp_descr = model_comp_descr + 'Error bars indicate the'
            if error_bars[0:2].lower() == 'ci':
                model_comp_descr = (model_comp_descr + ' ' + str(CI_percent) +
                                    '% confidence interval.')
            elif error_bars.lower() == 'sem':
                model_comp_descr = (model_comp_descr +
                                    ' standard error of the mean.')
        if test_above_0 or test_below_noise_ceil:
            model_comp_descr = (
                model_comp_descr +
                '\nOne-sided comparisons of each model performance ')
        if test_above_0:
            model_comp_descr = model_comp_descr + 'against 0 '
        if test_above_0 and test_below_noise_ceil:
            model_comp_descr = model_comp_descr + 'and '
        if test_below_noise_ceil:
            model_comp_descr = (
                model_comp_descr +
                'against the lower-bound estimate of the noise ceiling ')
        if test_above_0 or test_below_noise_ceil:
            model_comp_descr = (model_comp_descr +
                                'are Bonferroni-corrected for ' +
                                str(n_models) + ' models.')

        fig.suptitle(model_comp_descr, fontsize=fs2 / 2)
        axbar.set_xlim(ax.get_xlim())
        # an optional digit in the style string selects a plot version
        digits = [d for d in test_pair_comparisons if d.isdigit()]
        v = int(digits[0]) if digits else None
        if 'nili' in test_pair_comparisons.lower():
            if v:
                plot_nili_bars(axbar, significant, version=v)
            else:
                plot_nili_bars(axbar, significant)
        elif 'golan' in test_pair_comparisons.lower():
            if v:
                plot_golan_wings(axbar,
                                 significant,
                                 perf,
                                 sort,
                                 colors,
                                 version=v)
            else:
                plot_golan_wings(axbar, significant, perf, sort, colors)
        elif 'arrows' in test_pair_comparisons.lower():
            plot_arrows(axbar, significant)
        elif 'cliques' in test_pair_comparisons.lower():
            plot_cliques(axbar, significant)

    # Floating axes: the y axis runs up to the noise ceiling's upper bound,
    # or up to the best model when no noise ceiling is available (max 1)
    y_top = noise_upper if noise_upper is not None else np.max(perf)
    ytoptick = np.floor(min(1, y_top) * 10) / 10
    ax.set_yticks(np.arange(0, ytoptick + 1e-6, step=0.1))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xticks(np.arange(n_models))
    ax.spines['left'].set_bounds(0, ytoptick)
    ax.spines['bottom'].set_bounds(0, n_models - 1)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    plt.rc('ytick', labelsize=fs2)  # NOTE: global rc side effect

    # Axis labels: x in figure coordinates, y in the axis' data coordinates
    ylabel_fig_x, ysublabel_fig_x = 0.07, 0.095
    trans = transforms.blended_transform_factory(fig.transFigure,
                                                 ax.get_yaxis_transform())
    ax.text(ylabel_fig_x,
            ytoptick / 2,
            'RDM prediction accuracy',
            horizontalalignment='center',
            verticalalignment='center',
            rotation='vertical',
            fontsize=fs,
            fontweight='bold',
            transform=trans)
    method_lower = method.lower()
    if method_lower == 'cosine':
        ax.set_ylabel('[across-subject mean of cosine similarity]',
                      fontsize=fs2)
    elif method_lower in ['cosine_cov', 'whitened cosine']:
        ax.set_ylabel('[across-subject mean of whitened-RDM cosine]',
                      fontsize=fs2)
    elif method_lower == 'spearman':
        ax.set_ylabel('[across-subject mean of Spearman r rank correlation]',
                      fontsize=fs2)
    elif method_lower in ['corr', 'pearson']:
        ax.text(ysublabel_fig_x,
                ytoptick / 2,
                '[across-subject mean of Pearson r correlation]',
                horizontalalignment='center',
                verticalalignment='center',
                rotation='vertical',
                fontsize=fs2,
                fontweight='normal',
                transform=trans)
    elif method_lower in ['whitened pearson', 'corr_cov']:
        ax.set_ylabel('[across-subject mean of whitened-RDM Pearson r ' +
                      'correlation]',
                      fontsize=fs2)
    elif method_lower in ['kendall', 'tau-b']:
        ax.set_ylabel('[across-subject mean of Kendall tau-b rank ' +
                      'correlation]',
                      fontsize=fs2)
    elif method_lower == 'tau-a':
        ax.set_ylabel('[across-subject mean of ' +
                      'Kendall tau-a rank correlation]',
                      fontsize=fs2)
    if models is not None:
        ax.set_xticklabels([m.name for m in models], fontsize=fs2, rotation=45)
Esempio n. 35
0
def _acc_weighted_velocity(vel, vel_err, acc):
    """Return (mean, error) of `vel` weighted by inverse variance and acceleration.

    Each sample gets weight exp(-|acc| / median(|acc|)) (normalised to sum to 1)
    times 1 / vel_err**2; samples with large |acceleration| are thereby
    down-weighted (presumably to favour steady-velocity segments). The error
    is the propagated uncertainty of that weighted mean.
    """
    acc_weights = np.exp(-np.abs(acc) / np.median(np.abs(acc)))
    acc_weights = acc_weights / np.sum(acc_weights)
    inv_var = 1 / vel_err**2
    norm = np.sum(acc_weights * inv_var)
    mean = np.sum(acc_weights * vel * inv_var) / norm
    err = np.sqrt(np.sum(acc_weights**2 * inv_var)) / norm
    return mean, err


def analyze_droplets(droplets):
    """Interactively measure each droplet's velocity around one voltage switch.

    For every droplet, the Savitzky-Golay smoothed vertical velocity (with
    error band) and acceleration are plotted together with the switch times
    (module-level ``split_times``/``line_colors``). The user then types a
    start and end time bracketing ONE switch; a negative value vetoes the
    droplet. The velocities left and right of the switch are averaged with
    `_acc_weighted_velocity` and sorted into "off" and "on" velocities
    depending on the switch's line color.

    Relies on module-level names: ymax, px2mm, dt, split_times, line_colors,
    split_voltages, unpack_droplet, savgol_filter, savgol_error,
    blended_transform_factory, plt.

    Args:
        droplets: iterable of droplet objects accepted by ``unpack_droplet``
            and exposing a ``dropletid`` attribute.

    Returns:
        tuple of numpy arrays (voffs, vons, voff_errs, von_errs, Vs); the
        velocities and errors are converted from mm/s to m/s, Vs holds the
        switch voltage of each accepted droplet.
    """
    voffs = []
    vons = []
    voff_errs = []
    von_errs = []
    Vs = []
    for droplet in droplets:
        fig, axes = plt.subplots(1, 2, sharex=True)

        xpos, ypos, xvel, yvel, cov, birth_time, lifetime = unpack_droplet(
            droplet)
        ypos = (ymax - ypos) * px2mm  # flip the pixel y axis, convert to mm
        yerr = np.sqrt(cov[:, 2, 2]) * px2mm

        # Savitzky-Golay derivatives of the position give velocity/acceleration
        yvel_ = savgol_filter(ypos, 31, 4, deriv=1, delta=dt)
        yacc = savgol_filter(ypos, 31, 2, deriv=2, delta=dt)
        yvel_err = savgol_error(ypos, yerr, 31, 4, deriv=1, delta=dt)

        time_array = np.linspace(birth_time, birth_time + lifetime, len(ypos))

        axes[0].plot(time_array, yvel_)
        axes[0].fill_between(time_array,
                             yvel_ - yvel_err,
                             yvel_ + yvel_err,
                             alpha=0.7)
        axes[1].plot(time_array, yacc)

        # mark the voltage switch times over the full height of both panels
        for axis in axes:
            transform = blended_transform_factory(axis.transData,
                                                  axis.transAxes)
            axis.vlines(split_times,
                        0,
                        1,
                        colors=line_colors,
                        linestyles='dashed',
                        transform=transform)

        plt.xlim(np.min(time_array), np.max(time_array))
        axes[0].set_ylim(np.min(yvel_), np.max(yvel_))
        axes[1].set_ylim(np.min(yacc), np.max(yacc))
        print('')
        print(
            "Please pick out an approximate start and end time around ONE voltage change, then close the plot. Be prepared to enter these numbers."
        )
        print(
            "You may veto any droplet by entering a negative start or end time."
        )

        axes[0].set_xlabel('time [s]')
        axes[0].set_ylabel('velocity [mm/s]')
        axes[1].set_ylabel('acceleration [mm/s^2]')
        fig.suptitle(f"droplet id: {droplet.dropletid}")
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure on non-blocking backends too

        tmin = float(input("time start: "))
        tmax = float(input("time end: "))
        print('')

        if tmin < 0 or tmax < 0:
            print(f'droplet {droplet.dropletid} vetoed!')
            continue

        # first switch strictly inside the chosen window
        candidates = np.argwhere((tmin < split_times) & (tmax > split_times))
        if candidates.size == 0:
            print(f'droplet {droplet.dropletid}: no voltage change between '
                  f'{tmin} and {tmax}, skipped!')
            continue
        switch_index = candidates[0][0]

        time_switch = split_times[switch_index]
        line_color = line_colors[switch_index]
        # a zero voltage after the switch means the field was switched off;
        # use the voltage before it. NOTE(review): switch_index == 0 would
        # wrap around to the last entry -- confirm this cannot happen.
        switch_voltage = split_voltages[switch_index] if split_voltages[
            switch_index] != 0 else split_voltages[switch_index - 1]

        left = (tmin <= time_array) & (time_switch >= time_array)
        right = (tmax >= time_array) & (time_switch <= time_array)
        if not left.any() or not right.any():
            print(f'droplet {droplet.dropletid}: no samples on one side of '
                  f'the voltage change, skipped!')
            continue

        v_left_mean, v_left_err = _acc_weighted_velocity(
            yvel_[left], yvel_err[left], yacc[left])
        v_right_mean, v_right_err = _acc_weighted_velocity(
            yvel_[right], yvel_err[right], yacc[right])

        # show the fitted mean velocities (with error bands) on both sides
        plt.plot(time_array, yvel_)
        plt.xlabel('time [s]')
        plt.ylabel('velocity [mm/s]')
        plt.title(f"droplet id: {droplet.dropletid}")
        plt.hlines([v_left_mean],
                   tmin,
                   time_switch,
                   colors='m',
                   linestyles=['dashed'])
        plt.fill_between([tmin, time_switch], [v_left_mean - v_left_err] * 2,
                         [v_left_mean + v_left_err] * 2,
                         color='m',
                         alpha=0.7)
        plt.hlines([v_right_mean],
                   time_switch,
                   tmax,
                   colors='g',
                   linestyles=['dashed'])
        plt.fill_between([time_switch, tmax], [v_right_mean - v_right_err] * 2,
                         [v_right_mean + v_right_err] * 2,
                         color='g',
                         alpha=0.7)

        plt.show()
        plt.close()

        # the switch color decides which side is "off": for a non-'k' switch
        # the field is off before (left) and on after (right); 'k' is reversed
        if line_color != 'k':
            voffs.append(v_left_mean)
            voff_errs.append(v_left_err)
            vons.append(v_right_mean)
            von_errs.append(v_right_err)
        else:
            voffs.append(v_right_mean)
            voff_errs.append(v_right_err)
            vons.append(v_left_mean)
            von_errs.append(v_left_err)

        Vs.append(switch_voltage)  #V/m

    voffs = np.array(voffs)
    vons = np.array(vons)
    voff_errs = np.array(voff_errs)
    von_errs = np.array(von_errs)
    Vs = np.array(Vs)
    return voffs / 1000, vons / 1000, voff_errs / 1000, von_errs / 1000, Vs  #converting from mm/s to m/s
Esempio n. 36
0
    def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
        """Plot partial dependence plots.

        Parameters
        ----------
        ax : Matplotlib axes or array-like of Matplotlib axes, default=None
            - If a single axis is passed in, it is treated as a bounding axes
                and a grid of partial dependence plots will be drawn within
                these bounds. The `n_cols` parameter controls the number of
                columns in the grid.
            - If an array-like of axes are passed in, the partial dependence
                plots will be drawn directly into these axes.
            - If `None`, a figure and a bounding axes is created and treated
                as the single axes case.

        n_cols : int, default=3
            The maximum number of columns in the grid plot. Only active when
            `ax` is a single axes or `None`.

        line_kw : dict, default=None
            Dict with keywords passed to the `matplotlib.pyplot.plot` call.
            For one-way partial dependence plots.

        contour_kw : dict, default=None
            Dict with keywords passed to the `matplotlib.pyplot.contourf`
            call for two-way partial dependence plots.

        Returns
        -------
        display: :class:`~sklearn.inspection.PartialDependenceDisplay`
        """

        check_matplotlib_support("plot_partial_dependence")
        import matplotlib.pyplot as plt  # noqa
        from matplotlib import transforms  # noqa
        from matplotlib.ticker import MaxNLocator  # noqa
        from matplotlib.ticker import ScalarFormatter  # noqa
        from matplotlib.gridspec import GridSpecFromSubplotSpec  # noqa

        if line_kw is None:
            line_kw = {}
        if contour_kw is None:
            contour_kw = {}

        if ax is None:
            _, ax = plt.subplots()

        default_contour_kws = {"alpha": 0.75}
        contour_kw = {**default_contour_kws, **contour_kw}

        n_features = len(self.features)

        if isinstance(ax, plt.Axes):
            # If ax was set off, it has most likely been set to off
            # by a previous call to plot.
            if not ax.axison:
                raise ValueError("The ax was already used in another plot "
                                 "function, please set ax=display.axes_ "
                                 "instead")

            ax.set_axis_off()
            self.bounding_ax_ = ax
            self.figure_ = ax.figure

            n_cols = min(n_cols, n_features)
            n_rows = int(np.ceil(n_features / float(n_cols)))

            self.axes_ = np.empty((n_rows, n_cols), dtype=np.object)
            self.lines_ = np.empty((n_rows, n_cols), dtype=np.object)
            self.contours_ = np.empty((n_rows, n_cols), dtype=np.object)

            axes_ravel = self.axes_.ravel()

            gs = GridSpecFromSubplotSpec(n_rows,
                                         n_cols,
                                         subplot_spec=ax.get_subplotspec())
            for i, spec in zip(range(n_features), gs):
                axes_ravel[i] = self.figure_.add_subplot(spec)

        else:  # array-like
            ax = np.asarray(ax, dtype=object)
            if ax.size != n_features:
                raise ValueError("Expected ax to have {} axes, got {}".format(
                    n_features, ax.size))

            if ax.ndim == 2:
                n_cols = ax.shape[1]
            else:
                n_cols = None

            self.bounding_ax_ = None
            self.figure_ = ax.ravel()[0].figure
            self.axes_ = ax
            self.lines_ = np.empty_like(ax, dtype=np.object)
            self.contours_ = np.empty_like(ax, dtype=np.object)

        # create contour levels for two-way plots
        if 2 in self.pdp_lim:
            Z_level = np.linspace(*self.pdp_lim[2], num=8)
        lines_ravel = self.lines_.ravel(order='C')
        contours_ravel = self.contours_.ravel(order='C')

        for i, axi, fx, (avg_preds, values) in zip(count(), self.axes_.ravel(),
                                                   self.features,
                                                   self.pd_results):
            if len(values) == 1:
                lines_ravel[i] = axi.plot(values[0],
                                          avg_preds[self.target_idx].ravel(),
                                          **line_kw)[0]
            else:
                # contour plot
                XX, YY = np.meshgrid(values[0], values[1])
                Z = avg_preds[self.target_idx].T
                CS = axi.contour(XX,
                                 YY,
                                 Z,
                                 levels=Z_level,
                                 linewidths=0.5,
                                 colors='k')
                contours_ravel[i] = axi.contourf(XX,
                                                 YY,
                                                 Z,
                                                 levels=Z_level,
                                                 vmax=Z_level[-1],
                                                 vmin=Z_level[0],
                                                 **contour_kw)
                axi.clabel(CS,
                           fmt='%2.2f',
                           colors='k',
                           fontsize=10,
                           inline=True)

            trans = transforms.blended_transform_factory(
                axi.transData, axi.transAxes)
            ylim = axi.get_ylim()
            axi.vlines(self.deciles[fx[0]],
                       0,
                       0.05,
                       transform=trans,
                       color='k')
            axi.set_ylim(ylim)

            # Set xlabel if it is not already set
            if not axi.get_xlabel():
                axi.set_xlabel(self.feature_names[fx[0]])

            if len(values) == 1:
                if n_cols is None or i % n_cols == 0:
                    axi.set_ylabel('Partial dependence')
                else:
                    axi.set_yticklabels([])
                axi.set_ylim(self.pdp_lim[1])
            else:
                # contour plot
                trans = transforms.blended_transform_factory(
                    axi.transAxes, axi.transData)
                xlim = axi.get_xlim()
                axi.hlines(self.deciles[fx[1]],
                           0,
                           0.05,
                           transform=trans,
                           color='k')
                # hline erases xlim
                axi.set_ylabel(self.feature_names[fx[1]])
                axi.set_xlim(xlim)
        return self
Esempio n. 37
0
    for (ra, dec, dist), color in zip(opts.radecdist, colors[1:]):
        theta = 0.5 * np.pi - np.deg2rad(dec)
        phi = np.deg2rad(ra)
        ipix = hp.ang2pix(nside, theta, phi)
        ax.fill_between(d,
                        scipy.stats.norm(mu[ipix], sigma[ipix]).pdf(d) *
                        norm[ipix] * np.square(d),
                        alpha=0.5,
                        color=color)
        ax.axvline(dist, color='black', linewidth=0.5)
        ax.plot([dist], [-0.15],
                marker=truth_marker,
                markeredgecolor=color,
                markerfacecolor='none',
                markeredgewidth=1,
                clip_on=False,
                transform=transforms.blended_transform_factory(
                    ax.transData, ax.transAxes))
        ax.axvline(dist, color='black', linewidth=0.5)

    # Scale axes
    ax.set_xticks([0, max_distance])
    ax.set_xticklabels(['0', "{0:d}\nMpc".format(int(np.round(max_distance)))],
                       fontsize=9)
    ax.set_yticks([])
    ax.set_xlim(0, max_distance)
    ax.set_ylim(0, ax.get_ylim()[1])

# Final progress stage: report that the result is being saved.
progress.update(-1, 'Saving')
# NOTE(review): opts.output() presumably writes the figure to the
# destination chosen on the command line — confirm against the parser.
opts.output()
Esempio n. 38
0
def plot_twopoints(univariate,
                   condition_label=None,
                   trefs=None,
                   ntrefs=4,
                   trange=(-100., 100.)):
    """Plot two-point functions.

    Three stacked panels are drawn, with one curve per reference time
    ``tref``. The 'master' dataset uses solid lines; ``condition_label``, if
    given, uses dashed lines:

    1. number of samples contributing to each (tref, t) pair,
    2. autocorrelation g(tref, t) normalised by the variance at tref,
    3. the same normalised autocorrelation against the shifted time t - tref.

    Parameters
    ----------
    univariate : :class:`Univariate` instance
    condition_label : str, optional
        Key of a conditioned dataset in ``univariate`` to overlay (dashed
        lines) on top of the 'master' dataset.
    trefs : sequence of floats, optional
        Times to use as references. If None or empty, reference times are
        picked automatically from the 'master' time array (at most
        ``ntrefs`` of them, evenly spaced in index).
    ntrefs : int
        If ``trefs`` is not given, the maximum number of reference times
        to display.
    trange : tuple of floats
        Currently unused; kept for backward compatibility.

    Returns
    -------
    fig : matplotlib Figure holding the three panels.
    """
    obs = univariate.obs
    fig, axs = plt.subplots(3, 1, figsize=(6, 9))

    # define time label
    if obs.mode != 'dynamics' and obs.timing == 'g':
        timelabel = 'Generations'
        if obs.tref is not None:
            timelabel += ' (since tref {})'.format(obs.tref)
    else:
        timelabel = 'Time (minutes)'

    # choice of index/indices for time of reference
    times = univariate['master'].time
    npoints = len(times)
    # `trefs` may be None (default), an empty list, or any array-like;
    # `not trefs` would raise for multi-element numpy arrays.
    if trefs is None or len(trefs) == 0:
        di = npoints // ntrefs + 1
        indices = np.arange(0, npoints, di, dtype=int)
        trefs = times[indices]

    for c_label in ['master', condition_label]:
        if c_label is None:
            continue
        lt = '-' if c_label == 'master' else '--'

        times = univariate[c_label].time
        counts = univariate[c_label].count_two
        corr = univariate[c_label].autocorr
        var = np.diagonal(corr)

        valid = counts != 0

        for tref in trefs:
            # this tref may not be in conditioned data (who knows)
            if np.amin(np.abs(times - tref)) > 1.:
                continue
            index = np.argmin(np.abs(times - tref))
            lab = '{:.0f} mins'.format(tref)

            # panel 0: sample counts for this tref
            ax = axs[0]
            ok = np.where(counts[index, :] > 0)
            dat, = ax.plot(times[ok],
                           counts[index, :][ok],
                           ls=lt,
                           label=r'$t_{{\mathrm{{ref}}}}=${}'.format(lab))
            color = dat.get_color()
            ax.plot((tref, tref), (0, counts[index, index]),
                    ls=':',
                    color=color)
            # add text to point to tref (left of it when tref is the last time)
            if index < len(times) - 1:
                ax.text(times[index + 1],
                        0,
                        lab,
                        color=color,
                        transform=ax.transData)
            else:
                ax.text(times[index - 2],
                        0,
                        lab,
                        color=color,
                        transform=ax.transData)
            ax.set_xlabel(timelabel)

            # panel 1: normalised autocorrelation against absolute time
            ax = axs[1]
            dat, = ax.plot(times[valid[index, :]],
                           corr[index, :][valid[index, :]] / var[index],
                           ls=lt)
            color = dat.get_color()
            # heterogeneous transform: x coords are data, y coords are axes
            trans = transforms.blended_transform_factory(
                ax.transData, ax.transAxes)
            ax.axvline(tref, ymin=0.1, ymax=0.9, ls=':', color=color)
            # add text to point to tref
            if index < len(times) - 1:
                ax.text(times[index + 1],
                        0.1,
                        lab,
                        color=color,
                        transform=trans)
            else:
                ax.text(times[index - 2],
                        0.1,
                        lab,
                        color=color,
                        transform=trans)
            xmin, xmax = ax.xaxis.get_data_interval()
            ax.axhline(0, ls='--', color='k')
            ax.set_xlabel(timelabel)

            # panel 2: same curve against the time shifted by tref; the
            # symmetric x-range covers every possible shift
            ax = axs[2]
            ax.plot(times[valid[index, :]] - tref,
                    corr[index, :][valid[index, :]] / var[index],
                    ls=lt)
            ax.axhline(0, ls='--', color='k')
            trangeleft, trangeright = xmin - xmax, xmax - xmin
            ax.set_xlim(left=trangeleft, right=trangeright)
            ax.set_xlabel('Delta' + timelabel)

    # legends
    axs[0].legend(loc=2)

    master_line = mlines.Line2D([], [],
                                color='C7',
                                ls='-',
                                label='all samples')
    handles = [master_line]
    if condition_label is not None:
        c_line = mlines.Line2D([], [],
                               color='C7',
                               ls='--',
                               label=condition_label)
        handles.append(c_line)
    axs[1].legend(handles=handles, loc=2)

    # ticks and labels
    # first axes locators are integers
    axs[0].yaxis.set_major_locator(MaxNLocator(integer=True))
    # Booleans, not 'on'/'off': recent Matplotlib no longer translates the
    # strings, and the truthy 'off' would leave the labels visible.
    axs[0].tick_params(axis='x', direction='out', top=True, labeltop=True)
    axs[0].tick_params(axis='x',
                       direction='in',
                       bottom=True,
                       labelbottom=False)
    axs[1].tick_params(axis='x', direction='in')

    axs[2].set_xlabel(timelabel,
                      x=.95,
                      horizontalalignment='right',
                      fontsize='large')
    axs[0].xaxis.set_label_position('top')
    axs[0].set_xlabel(timelabel,
                      x=.95,
                      horizontalalignment='right',
                      fontsize='large')

    # ylabels
    axs[0].set_ylabel(r'Samples $\langle t_{\mathrm{ref}} | t \rangle$',
                      fontsize='large')
    axs[1].set_ylabel(r'Autocorr. $g(t_{\mathrm{ref}}, t)$', fontsize='large')
    axs[2].set_ylabel(r'Shifted $g(t_{\mathrm{ref}}, t- t_{\mathrm{ref}})$',
                      fontsize='large')

    latex_obs = obs.as_latex_string()
    axs[0].text(0.5,
                1.3,
                r' Autocorrelation fcts for {}'.format(latex_obs),
                size='large',
                horizontalalignment='center',
                verticalalignment='bottom',
                transform=axs[0].transAxes)
    fig.subplots_adjust(hspace=.1)
    return fig
Esempio n. 39
0
                       top=0.975,
                       wspace=0.1,
                       hspace=0.1,
                       width_ratios=None,
                       height_ratios=[2, 1])

# Four panels on the grid `gs` built just above (two rows, top row twice as
# tall per height_ratios=[2, 1]). Every panel shares ax1's x-axis, so the
# x-limits stay linked across all four.
ax1 = plt.subplot(gs[0, 0])  # top left
ax2 = plt.subplot(gs[0, 1], sharex=ax1)  # top right
ax3 = plt.subplot(gs[1, 0], sharex=ax1)  # bottom left
ax4 = plt.subplot(gs[1, 1], sharex=ax1)  # bottom right

# Shared bold text styles; presumably passed as `fontdict=` to the title and
# axis-label calls further down — verify.
fontdict_titles = {'fontsize': 9, 'fontweight': 'bold'}
fontdict_axis = {'fontsize': 10, 'fontweight': 'bold'}

# Blended transform for event labels on ax1: x in data coordinates, y in axes
# coordinates (0 = bottom, 1 = top), so a label can sit at a fixed height
# whatever the y-range is.
event_label_transform = transforms.blended_transform_factory(
    ax1.transData, ax1.transAxes)

# Named colour for each trial-type code. The None key maps to black,
# presumably as the colour for trials without a type — confirm with callers.
trialtype_colors = {
    'SGHF': 'MediumBlue',
    'SGLF': 'Turquoise',
    'PGHF': 'DarkGreen',
    'PGLF': 'YellowGreen',
    'LFSG': 'Orange',
    'LFPG': 'Yellow',
    'HFSG': 'DarkRed',
    'HFPG': 'OrangeRed',
    'SGSG': 'SteelBlue',
    'PGPG': 'LimeGreen',
    None: 'black'
}
Esempio n. 40
0
def gsea_plot(rank_metric, enrich_term, hit_ind, nes, pval, fdr, RES, phenoPos,
              phenoNeg, figsize, format, outdir, module):
    """This is the main function for reproducing the gsea plot.

    Draws four vertically stacked panels on a 16-row grid and writes the
    figure to ``{outdir}/{enrich_term}.{module}.{format}``. The panels are:

    - the running enrichment score (rows 0-7),
    - the gene-set hit positions (rows 8-9),
    - a colour strip of the ranking (row 10),
    - the ranked metric itself (rows 11-15).

    :param rank_metric: pd.Series for rankings, rank_metric.values.
    :param enrich_term: gene_set name; used as the figure title and, with
        '/' and ':' replaced by '_', in the output file name.
    :param hit_ind: hits indices of rank_metric.index presented in gene set S.
    :param nes: Normalized enrichment scores.
    :param pval: nominal p-value.
    :param fdr: false discovery rate.
    :param RES: running enrichment scores.
    :param phenoPos: phenotype label, positive correlated.
    :param phenoNeg: phenotype label, negative correlated.
    :param figsize: matplotlib figsize.
    :param format: file extension of the output; the canvas infers the
        output format from it.
    :param outdir: directory the figure is written to.
    :param module: name of the calling module. For 'ssgsea', the log of the
        ranked metric is plotted and the score is labelled "ES" without
        p-value/FDR values.
    :return: None
    """
    # center color map at midpoint = 0
    norm = _MidpointNormalize(midpoint=0)

    # dataFrame of ranked matrix scores
    x = np.arange(len(rank_metric))
    rankings = rank_metric.values
    phenoP_label = phenoPos + ' (Positively Correlated)'
    phenoN_label = phenoNeg + ' (Negatively Correlated)'
    zero_score_ind = np.abs(rankings).argmin()
    z_score_label = 'Zero score at ' + str(zero_score_ind)
    nes_label = 'NES: ' + "{:.3f}".format(float(nes))
    pval_label = 'Pval: ' + "{:.3f}".format(float(pval))
    fdr_label = 'FDR: ' + "{:.3f}".format(float(fdr))
    im_matrix = np.tile(rankings, (2, 1))

    # output truetype
    plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
    # Draw on a bare Figure + canvas instead of pyplot so that many plots
    # can be produced (e.g. from the command line) without being displayed.

    # GSEA Plots
    gs = plt.GridSpec(16, 1)
    fig = Figure(figsize=figsize)
    canvas = FigureCanvas(fig)
    # Ranked Metric Scores Plot
    ax1 = fig.add_subplot(gs[11:])
    if module == 'ssgsea':
        nes_label = 'ES: ' + "{:.3f}".format(float(nes))
        pval_label = 'Pval: '
        fdr_label = 'FDR: '
        ax1.fill_between(x, y1=np.log(rankings), y2=0, color='#C9D3DB')
        ax1.set_ylabel("log ranked metric", fontsize=14)
    else:
        ax1.fill_between(x, y1=rankings, y2=0, color='#C9D3DB')
        ax1.set_ylabel("Ranked list metric", fontsize=14)
    ax1.text(.05,
             .9,
             phenoP_label,
             color='red',
             horizontalalignment='left',
             verticalalignment='top',
             transform=ax1.transAxes)
    ax1.text(.95,
             .05,
             phenoN_label,
             color='Blue',
             horizontalalignment='right',
             verticalalignment='bottom',
             transform=ax1.transAxes)
    # the x coords of this transformation are data, and the y coord are axes
    trans1 = transforms.blended_transform_factory(ax1.transData, ax1.transAxes)
    ax1.vlines(zero_score_ind,
               0,
               1,
               linewidth=.5,
               transform=trans1,
               linestyles='--',
               color='grey')
    ax1.text(zero_score_ind,
             0.5,
             z_score_label,
             horizontalalignment='right' if module == 'ssgsea' else 'center',
             verticalalignment='center',
             transform=trans1)
    ax1.set_xlabel("Rank in Ordered Dataset", fontsize=14)
    ax1.spines['top'].set_visible(False)
    # Booleans, not 'on'/'off': recent Matplotlib no longer translates those
    # strings, and the truthy 'off' would keep ticks/labels visible.
    ax1.tick_params(axis='both',
                    which='both',
                    top=False,
                    right=False,
                    left=False)
    ax1.locator_params(axis='y', nbins=5)
    ax1.yaxis.set_major_formatter(
        plt.FuncFormatter(
            lambda tick_loc, tick_num: '{:.1f}'.format(tick_loc)))

    # gene hits
    ax2 = fig.add_subplot(gs[8:10], sharex=ax1)

    # the x coords of this transformation are data, and the y coord are axes
    trans2 = transforms.blended_transform_factory(ax2.transData, ax2.transAxes)
    ax2.vlines(hit_ind, 0, 1, linewidth=.5, transform=trans2)
    ax2.spines['bottom'].set_visible(False)
    ax2.tick_params(axis='both',
                    which='both',
                    bottom=False,
                    top=False,
                    labelbottom=False,
                    right=False,
                    left=False,
                    labelleft=False)
    # colormap
    ax3 = fig.add_subplot(gs[10], sharex=ax1)
    ax3.imshow(im_matrix,
               aspect='auto',
               norm=norm,
               cmap=plt.cm.seismic,
               interpolation='none')  # cm.coolwarm
    ax3.spines['bottom'].set_visible(False)
    ax3.tick_params(axis='both',
                    which='both',
                    bottom=False,
                    top=False,
                    labelbottom=False,
                    right=False,
                    left=False,
                    labelleft=False)

    # Enrichment score plot
    ax4 = fig.add_subplot(gs[:8], sharex=ax1)
    ax4.plot(x, RES, linewidth=4, color='#88C544')
    ax4.text(.1, .1, fdr_label, transform=ax4.transAxes)
    ax4.text(.1, .2, pval_label, transform=ax4.transAxes)
    ax4.text(.1, .3, nes_label, transform=ax4.transAxes)

    # the y coords of this transformation are data, and the x coord are axes
    trans4 = transforms.blended_transform_factory(ax4.transAxes, ax4.transData)
    ax4.hlines(0, 0, 1, linewidth=.5, transform=trans4, color='grey')
    ax4.set_ylabel("Enrichment score (ES)", fontsize=14)
    ax4.set_xlim(min(x), max(x))
    ax4.tick_params(axis='both',
                    which='both',
                    bottom=False,
                    top=False,
                    labelbottom=False,
                    right=False)
    ax4.locator_params(axis='y', nbins=5)
    # FuncFormatter calls its function with (tick value, tick position);
    # only the value is needed to format the y tick labels.
    ax4.yaxis.set_major_formatter(
        plt.FuncFormatter(
            lambda tick_loc, tick_num: '{:.1f}'.format(tick_loc)))

    # fig adjustment
    fig.suptitle(enrich_term, fontsize=16)
    fig.subplots_adjust(hspace=0)
    # '/' and ':' are not safe in file names
    enrich_term = enrich_term.replace('/', '_').replace(":", "_")
    canvas.print_figure(
        '{0}/{1}.{2}.{3}'.format(outdir, enrich_term, module, format),
        bbox_inches='tight',
        dpi=300,
    )
    return
Esempio n. 41
0
# Add tick labels for the x-axis ticks on the 1st and 6th plots.
# `axs` is indexed with a list, so it is expected to be a NumPy array of Axes
# (e.g. a flattened plt.subplots grid). labelbottom=True turns the labels
# back on, presumably because shared-x subplots had them hidden.
# NOTE(review): set_xticklabels pairs the two strings with the existing ticks
# in order, so each axes is assumed to have exactly two x ticks.

for ax in axs[[0, 5]]:
    ax.tick_params(labelbottom=True)
    ax.set_xticklabels(["3 days\nbefore", "3 days\nafter"])

    # Center each label under its tick; per the original author they were
    # otherwise right aligned (that alignment may come from setup not shown).
    for label in ax.xaxis.get_ticklabels():
        label.set_horizontalalignment("center")

# Transforms and Lines

# Blended transform: x in figure coordinates (0 = left edge, 1 = right edge
# of the whole figure), y in data coordinates of the first axes. A line
# drawn with it spans the figure width at a given y data value.
trans = transforms.blended_transform_factory(
    fig.transFigure,  # goes across whole in x direction
    axs[0].transData,  # goes up with the y data in the first axis
)

for y in yticks:
    l = plt.Line2D(
        # x values found by trial and error
        [0.04, 0.985],
        [y, y],
        transform=trans,
        color="black",
        alpha=0.4,
        linewidth=0.5,
        zorder=0.1,
    )

    if y == 100:
Esempio n. 42
0
# Monthly histogram of pull-request dates with release markers.
# NOTE(review): judging by the helper names, dates are converted to seconds
# since the epoch and `bins` are month boundaries — confirm against the
# helpers defined elsewhere.
dates_f = seconds_from_epoch(dates)
bins = get_month_bins(dates)

fig, ax = plt.subplots(figsize=(7, 5))

# ax.hist returns (counts, bin edges, patches); `bins` is rebound to the
# edges actually used.
n, bins, _ = ax.hist(dates_f, bins=bins, color='blue', alpha=0.6)

# Format the numeric x ticks as dates and place a tick at every third bin
# edge, starting from the third.
ax.xaxis.set_major_formatter(FuncFormatter(date_formatter))
ax.set_xticks(bins[2:-1:3])  # Date label every 3 months.

# Rotate the tick labels by 40 degrees and set their font size to 10.
labels = ax.get_xticklabels()
for l in labels:
    l.set_rotation(40)
    l.set_size(10)

# x in data coordinates, y in axes coordinates: y=1 is the top edge of the
# axes regardless of the y-range.
mixed_transform = blended_transform_factory(ax.transData, ax.transAxes)

# Mark each release with a dotted vertical line and write its version name
# just above the top edge of the axes (y=1, va='bottom').
# `releases` presumably maps version string -> release date.
for version, date in releases.items():
    date = seconds_from_epoch([date])[0]
    ax.axvline(date, color='black', linestyle=':', label=version)
    ax.text(date,
            1,
            version,
            color='r',
            va='bottom',
            ha='center',
            transform=mixed_transform)

# Lift the title (y=1.05 in axes coordinates), presumably to clear the
# release labels drawn at y=1.
ax.set_title('Pull request activity').set_y(1.05)
ax.set_xlabel('Date')
ax.set_ylabel('PRs per month', color='blue')
Esempio n. 43
0
def plot_result(result, reference=None, names=None, filename=None, window_title=None, events=False, markers=False):
    """ Plot a collection of time series.

    With the 'plotly' plot library the figure is delegated to
    ``create_plotly_figure``. Otherwise one Matplotlib subplot per signal is
    drawn, all sharing the time axis.

    Parameters:
        result:       structured NumPy Array that contains the time series to plot where 'time' is the independent variable
        reference:    optional reference signals with the same structure as `result`; where present, the
                      validation band is drawn and outliers are shaded red
        names:        variables to plot (default: at most 20 scalar fields of `result`, skipping the first)
        filename:     when provided the plot is saved as `filename` instead of showing the figure;
                      missing parent directories are created
        window_title: title for the figure window (ignored by non-GUI backends)
        events:       draw vertical lines at events (time instants that repeat in `result['time']`)
        markers:      show markers (only used by the plotly backend)
    """

    from . import plot_library

    if plot_library == 'plotly':
        figure = create_plotly_figure(result, names=names, events=events, markers=markers)
        if filename is None:
            figure.show()
        else:
            figure.write_image(filename)
        return

    import matplotlib.pylab as pylab
    import matplotlib.pyplot as plt
    import matplotlib.transforms as mtransforms
    from matplotlib.ticker import MaxNLocator
    from collections.abc import Iterable

    params = {
        'legend.fontsize': 8,
        'axes.labelsize': 8,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'axes.linewidth': 0.5,
    }

    pylab.rcParams.update(params)

    time = result['time']

    if names is None:

        names = []

        # plot at most 20 one-dimensional signals; descr entries with a third
        # element (a shape) are array-valued fields and are skipped
        for d in result.dtype.descr[1:]:
            if len(d) < 3 and len(names) < 20:
                names.append(d[0])

    if len(names) > 0:

        # indent label 0.015 inch / character
        label_x = -0.015 * np.max(list(map(len, names)) + [8])

        fig, axes = plt.subplots(len(names), sharex=True)

        fig.set_facecolor('white')

        # a single subplot is returned as a bare Axes, not an array
        if not isinstance(axes, Iterable):
            axes = [axes]

        if events:
            # repeated time instants mark events; flatnonzero keeps the
            # result 1-D so every t_event entry is a scalar
            t_event = time[np.flatnonzero(np.diff(time) == 0)]

        for ax, name in zip(axes, names):

            y = result[name]

            # visibility passed positionally: the keyword was renamed from
            # `b` to `visible` in Matplotlib 3.5 and `b` was later removed
            ax.grid(True, which='both', color='0.8', linestyle='-', zorder=0)

            ax.tick_params(direction='in')

            if events:
                for t in t_event:
                    ax.axvline(x=t, color='y', linewidth=1)

            if reference is not None and name in reference.dtype.names:
                t_ref = reference[reference.dtype.names[0]]
                y_ref = reference[name]

                t_band, y_min, y_max, i_out = validate_signal(t=time, y=y, t_ref=t_ref, y_ref=y_ref)

                ax.fill_between(t_band, y_min, y_max, facecolor=(0, 0.5, 0), alpha=0.1)
                ax.plot(t_band, y_min, color=(0, 0.5, 0), linewidth=1, label='lower bound', zorder=101, alpha=0.5)
                ax.plot(t_band, y_max, color=(0, 0.5, 0), linewidth=1, label='upper bound', zorder=101, alpha=0.5)

                # mark the outliers
                # use the data coordinates for the x-axis and the axes coordinates for the y-axis
                trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
                ax.fill_between(time, 0, 1, where=i_out, facecolor='red', alpha=0.5, transform=trans)

            if y.dtype == np.float64:
                # find left indices of discontinuities and plot each
                # continuous segment separately so no vertical jumps are drawn
                i_disc = np.flatnonzero(np.diff(time) == 0)
                i_disc = np.append(i_disc + 1, len(time))
                i0 = 0
                for i1 in i_disc:
                    ax.plot(time[i0:i1], y[i0:i1], color='b', linewidth=0.9, label='result', zorder=101)
                    i0 = i1
            else:
                # discrete signals are piecewise constant: draw horizontal steps
                ax.hlines(y[:-1], time[:-1], time[1:], colors='b', linewidth=1, label='result', zorder=101)
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))

            if y.dtype == bool:
                # use fixed range and labels and fill area
                ax.set_ylim(-0.25, 1.25)
                ax.yaxis.set_ticks([0, 1])
                ax.yaxis.set_ticklabels(['false', 'true'])
                if y.ndim == 1:
                    ax.fill_between(time, y, 0, step='post', facecolor='b', alpha=0.1)
            else:
                ax.margins(x=0, y=0.05)

            if time.size < 200:
                ax.scatter(time, y, color='b', s=5, zorder=101)

            ax.set_ylabel(name, horizontalalignment='left', rotation=0)

            # align the y-labels
            ax.get_yaxis().set_label_coords(label_x, 0.5)

        # set the window title; FigureCanvasBase.set_window_title was removed
        # in Matplotlib 3.6, and non-GUI backends have no window manager
        if window_title is not None:
            manager = getattr(fig.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(window_title)

        def onresize(event):
            # operate on this function's own figure; plt.gcf() may be a
            # different figure when several are open
            w = fig.get_figwidth()

            # tight_layout() crashes on very small figures
            if w < 3:
                return

            x = label_x * (8.0 / w)

            # update label coordinates
            for axis in fig.get_axes():
                axis.get_yaxis().set_label_coords(x, 0.5)

            # update layout
            fig.tight_layout()

        # update layout when the plot is re-sized
        fig.canvas.mpl_connect('resize_event', onresize)

        fig.set_size_inches(w=8, h=1.5 * len(names), forward=True)

        fig.tight_layout()

        if filename is None:
            plt.show()
        else:
            directory = os.path.dirname(filename)
            # dirname is '' for a bare file name: nothing to create, and
            # os.makedirs('') would raise
            if directory and not os.path.isdir(directory):
                os.makedirs(directory)
            fig.savefig(filename)
            plt.close(fig)
Esempio n. 44
0
File: plot.py Progetto: jiawu/GSEApy
    def axes_rank(self, rect):
        """Add the ranked-metric panel as a new axes on ``self.fig``.

        The panel shares its x-axis with ``self.ax``. It fills the ranked
        metric (its natural log when ``self.module == 'ssgsea'``) and labels
        the positive and negative phenotype ends. It also marks the
        zero-score rank with a dashed vertical line and a text label.

        rect : sequence of float
               The dimensions [left, bottom, width, height] of the new axes. All
               quantities are in fractions of figure width and height.
        """
        # Ranked Metric Scores Plot
        ax1 = self.fig.add_axes(rect, sharex=self.ax)
        if self.module == 'ssgsea':
            ax1.fill_between(self._x,
                             y1=np.log(self.rankings),
                             y2=0,
                             color='#C9D3DB')
            ax1.set_ylabel("log ranked metric", fontsize=14)
        else:
            ax1.fill_between(self._x, y1=self.rankings, y2=0, color='#C9D3DB')
            ax1.set_ylabel("Ranked list metric", fontsize=14)

        ax1.text(.05,
                 .9,
                 self._pos_label,
                 color='red',
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform=ax1.transAxes)
        ax1.text(.95,
                 .05,
                 self._neg_label,
                 color='Blue',
                 horizontalalignment='right',
                 verticalalignment='bottom',
                 transform=ax1.transAxes)
        # the x coords of this transformation are data, and the y coord are axes
        trans1 = transforms.blended_transform_factory(ax1.transData,
                                                      ax1.transAxes)
        ax1.vlines(self._zero_score_ind,
                   0,
                   1,
                   linewidth=.5,
                   transform=trans1,
                   linestyles='--',
                   color='grey')

        # Pick the label alignment from where the zero crossing sits along
        # the ranking, so the label does not run off either side of the axes.
        # A zero span (single-element ranking) would divide by zero.
        x_span = max(self._x)
        hap = self._zero_score_ind / x_span if x_span else 0.5
        if hap < 0.25:
            ha = 'left'
        elif hap > 0.75:
            ha = 'right'
        else:
            ha = 'center'
        # Anchor the label at the zero-score rank in data coordinates, with
        # the same blended transform as the dashed line, so label and line
        # stay aligned whatever x-limits the shared axis ends up with.
        ax1.text(self._zero_score_ind,
                 0.5,
                 self._z_score_label,
                 horizontalalignment=ha,
                 verticalalignment='center',
                 transform=trans1)
        ax1.set_xlabel("Rank in Ordered Dataset", fontsize=14)
        ax1.spines['top'].set_visible(False)
        ax1.tick_params(axis='both',
                        which='both',
                        top=False,
                        right=False,
                        left=False)
        ax1.locator_params(axis='y', nbins=5)
        # FuncFormatter passes (tick value, tick position); only the value is used
        ax1.yaxis.set_major_formatter(
            plt.FuncFormatter(
                lambda tick_loc, tick_num: '{:.1f}'.format(tick_loc)))
Esempio n. 45
0
def plot_partial_dependence(gbrt, X, features, feature_names=None,
                            label=None, n_cols=3, grid_resolution=100,
                            percentiles=(0.05, 0.95), n_jobs=1,
                            verbose=0, ax=None, line_kw=None,
                            contour_kw=None, **fig_kw):
    """Partial dependence plots for ``features``.

    The ``len(features)`` plots are arranged in a grid with ``n_cols``
    columns. Two-way partial dependence plots are plotted as contour
    plots.

    Parameters
    ----------
    gbrt : BaseGradientBoosting
        A fitted gradient boosting model.
    X : array-like, shape=(n_samples, n_features)
        The data on which ``gbrt`` was trained.
    features : seq of tuples or ints
        If seq[i] is an int or a tuple with one int value, a one-way
        PDP is created; if seq[i] is a tuple of two ints, a two-way
        PDP is created. Feature names (str) listed in ``feature_names``
        may be used in place of the ints.
    feature_names : seq of str
        Name of each feature; feature_names[i] holds
        the name of the feature with index i. Defaults to the
        stringified feature indices.
    label : object
        The class label for which the PDPs should be computed.
        Only if gbrt is a multi-class model. Must be in ``gbrt.classes_``.
    n_cols : int
        The maximum number of columns in the grid plot (default: 3).
    percentiles : (low, high), default=(0.05, 0.95)
        The lower and upper percentile used to create the extreme values
        for the PDP axes.
    grid_resolution : int, default=100
        The number of equally spaced points on the axes.
    n_jobs : int
        The number of CPUs to use to compute the PDs. -1 means 'all CPUs'.
        Defaults to 1.
    verbose : int
        Verbose output during PD computations. Defaults to 0.
    ax : Matplotlib axis object, default None
        If given, the figure owning this axis is cleared and the PDP
        subplots are added to it (``fig_kw`` is then ignored).
    line_kw : dict
        Dict with keywords passed to the ``plot`` call for one-way
        partial dependence plots. Defaults to ``{'color': 'green'}``.
    contour_kw : dict
        Dict with keywords passed to the ``contourf`` call for two-way
        partial dependence plots.
    fig_kw : dict
        Dict with keywords passed to the figure() call (only used when
        ``ax`` is None). Note that all keywords not recognized above will
        be automatically included here.

    Returns
    -------
    fig : figure
        The Matplotlib Figure object.
    axs : list of Axes objects
        A list of Axes objects, one for each subplot.

    Raises
    ------
    ValueError
        If ``gbrt`` is not a fitted ``BaseGradientBoosting``, ``label`` is
        missing or not in ``gbrt.classes_`` for a multi-class model, ``X``
        has the wrong number of columns, or ``features`` is empty or
        contains invalid entries.

    Examples
    --------
    >>> from sklearn.datasets import make_friedman1
    >>> from sklearn.ensemble import GradientBoostingRegressor
    >>> X, y = make_friedman1()
    >>> clf = GradientBoostingRegressor(n_estimators=10).fit(X, y)
    >>> fig, axs = plot_partial_dependence(clf, X, [0, (0, 1)]) #doctest: +SKIP
    ...
    """
    import matplotlib.pyplot as plt
    from matplotlib import transforms
    from matplotlib.ticker import MaxNLocator
    from matplotlib.ticker import ScalarFormatter

    if not isinstance(gbrt, BaseGradientBoosting):
        raise ValueError('gbrt has to be an instance of BaseGradientBoosting')
    if gbrt.estimators_.shape[0] == 0:
        raise ValueError('Call %s.fit before partial_dependence' %
                         gbrt.__class__.__name__)

    # set label_idx for multi-class GBRT
    if hasattr(gbrt, 'classes_') and np.size(gbrt.classes_) > 2:
        if label is None:
            raise ValueError('label is not given for multi-class PDP')
        label_idx = np.searchsorted(gbrt.classes_, label)
        # searchsorted returns len(classes_) when label sorts after every
        # class; check the bound so an unknown label raises ValueError
        # rather than IndexError.
        if (label_idx >= np.size(gbrt.classes_) or
                gbrt.classes_[label_idx] != label):
            raise ValueError('label %s not in ``gbrt.classes_``' % str(label))
    else:
        # regression and binary classification
        label_idx = 0

    X = check_array(X, dtype=DTYPE, order='C')
    if gbrt.n_features != X.shape[1]:
        raise ValueError('X.shape[1] does not match gbrt.n_features')

    if line_kw is None:
        line_kw = {'color': 'green'}
    if contour_kw is None:
        contour_kw = {}

    # convert feature_names to list
    if feature_names is None:
        # if not feature_names use fx indices as name
        feature_names = [str(i) for i in range(gbrt.n_features)]
    elif isinstance(feature_names, np.ndarray):
        feature_names = feature_names.tolist()

    def convert_feature(fx):
        # map a feature name to its index; ints pass through unchanged
        if isinstance(fx, six.string_types):
            try:
                fx = feature_names.index(fx)
            except ValueError:
                raise ValueError('Feature %s not in feature_names' % fx)
        return fx

    # convert features into a seq of int tuples
    tmp_features = []
    for fxs in features:
        if isinstance(fxs, (numbers.Integral,) + six.string_types):
            fxs = (fxs,)
        try:
            fxs = np.array([convert_feature(fx) for fx in fxs], dtype=np.int32)
        except TypeError:
            raise ValueError('features must be either int, str, or tuple '
                             'of int/str')
        if not (1 <= np.size(fxs) <= 2):
            raise ValueError('target features must be either one or two')

        tmp_features.append(fxs)

    features = tmp_features
    if not features:
        # otherwise n_cols becomes 0 below and the grid computation divides
        # by zero (after a user-supplied figure has already been cleared)
        raise ValueError('features must contain at least one feature')

    names = []
    try:
        for fxs in features:
            fx_names = []
            # explicit loop so "i" is bound for exception below
            for i in fxs:
                fx_names.append(feature_names[i])
            names.append(fx_names)
    except IndexError:
        raise ValueError('features[i] must be in [0, n_features) '
                         'but was %d' % i)

    # compute PD functions
    pd_result = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(partial_dependence)(gbrt, fxs, X=X,
                                    grid_resolution=grid_resolution,
                                    percentiles=percentiles)
        for fxs in features)

    # get global min and max values of PD grouped by plot type, so all
    # one-way (resp. two-way) plots share the same value range
    pdp_lim = {}
    for pdp, axes in pd_result:
        min_pd, max_pd = pdp[label_idx].min(), pdp[label_idx].max()
        n_fx = len(axes)
        old_min_pd, old_max_pd = pdp_lim.get(n_fx, (min_pd, max_pd))
        min_pd = min(min_pd, old_min_pd)
        max_pd = max(max_pd, old_max_pd)
        pdp_lim[n_fx] = (min_pd, max_pd)

    # create contour levels for two-way plots
    if 2 in pdp_lim:
        Z_level = np.linspace(*pdp_lim[2], num=8)

    if ax is None:
        fig = plt.figure(**fig_kw)
    else:
        fig = ax.get_figure()
        fig.clear()

    n_cols = min(n_cols, len(features))
    n_rows = int(np.ceil(len(features) / float(n_cols)))
    axs = []
    for i, (fx, name, (pdp, axes)) in enumerate(zip(features, names,
                                                    pd_result)):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)

        if len(axes) == 1:
            ax.plot(axes[0], pdp[label_idx].ravel(), **line_kw)
        else:
            # make contour plot
            assert len(axes) == 2
            XX, YY = np.meshgrid(axes[0], axes[1])
            Z = pdp[label_idx].reshape(list(map(np.size, axes))).T
            CS = ax.contour(XX, YY, Z, levels=Z_level, linewidths=0.5,
                            colors='k')
            ax.contourf(XX, YY, Z, levels=Z_level, vmax=Z_level[-1],
                        vmin=Z_level[0], alpha=0.75, **contour_kw)
            ax.clabel(CS, fmt='%2.2f', colors='k', fontsize=10, inline=True)

        # plot data deciles + axes labels; x in data coords, y in axes coords
        deciles = mquantiles(X[:, fx[0]], prob=np.arange(0.1, 1.0, 0.1))
        trans = transforms.blended_transform_factory(ax.transData,
                                                     ax.transAxes)
        ylim = ax.get_ylim()
        ax.vlines(deciles, [0], 0.05, transform=trans, color='k')
        ax.set_xlabel(name[0])
        ax.set_ylim(ylim)

        # prevent x-axis ticks from overlapping
        ax.xaxis.set_major_locator(MaxNLocator(nbins=6, prune='lower'))
        tick_formatter = ScalarFormatter()
        tick_formatter.set_powerlimits((-3, 4))
        ax.xaxis.set_major_formatter(tick_formatter)

        if len(axes) > 1:
            # two-way PDP - y-axis deciles + labels
            deciles = mquantiles(X[:, fx[1]], prob=np.arange(0.1, 1.0, 0.1))
            trans = transforms.blended_transform_factory(ax.transAxes,
                                                         ax.transData)
            xlim = ax.get_xlim()
            ax.hlines(deciles, [0], 0.05, transform=trans, color='k')
            ax.set_ylabel(name[1])
            # hline erases xlim
            ax.set_xlim(xlim)
        else:
            ax.set_ylabel('Partial dependence')

        if len(axes) == 1:
            ax.set_ylim(pdp_lim[1])
        axs.append(ax)

    fig.subplots_adjust(bottom=0.15, top=0.7, left=0.1, right=0.95, wspace=0.4,
                        hspace=0.3)
    return fig, axs
Esempio n. 46
0
def paga_path(
    adata: AnnData,
    nodes: Sequence[Union[str, int]],
    keys: Sequence[str],
    use_raw: bool = True,
    annotations: Sequence[str] = ('dpt_pseudotime', ),
    color_map: Union[str, Colormap, None] = None,
    color_maps_annotations: Optional[Mapping[str, Union[str, Colormap]]] = None,
    palette_groups: Optional[Sequence[str]] = None,
    n_avg: int = 1,
    groups_key: Optional[str] = None,
    xlim: Tuple[Optional[int], Optional[int]] = (None, None),
    title: Optional[str] = None,
    left_margin=None,
    ytick_fontsize: Optional[int] = None,
    title_fontsize: Optional[int] = None,
    show_node_names: bool = True,
    show_yticks: bool = True,
    show_colorbar: bool = True,
    legend_fontsize: Optional[int] = None,
    legend_fontweight: Optional[str] = None,
    normalize_to_zero_one: bool = False,
    as_heatmap: bool = True,
    return_data: bool = False,
    show: Optional[bool] = None,
    save: Union[bool, str, None] = None,
    ax: Optional[Axes] = None,
) -> Optional[Axes]:
    """Gene expression and annotation changes along paths in the abstracted graph.

    Parameters
    ----------
    adata
        An annotated data matrix.
    nodes
        A path through nodes of the abstracted graph, that is, names or indices
        (within `.categories`) of groups that have been used to run PAGA.
    keys
        Either variables in `adata.var_names` or annotations in
        `adata.obs`. They are plotted using `color_map`.
    use_raw
        Use `adata.raw` for retrieving gene expressions if it has been set.
    annotations
        Plot these keys with `color_maps_annotations`. Need to be keys for
        `adata.obs`.
    color_map
        Matplotlib colormap.
    color_maps_annotations
        Color maps for plotting the annotations. Keys of the dictionary must
        appear in `annotations`. If `None`, defaults to
        `{'dpt_pseudotime': 'Greys'}`. Annotations without an entry use
        `'tab10'` if categorical, else `'Greys'`.
    palette_groups
        Usually, use the same `sc.pl.palettes...` as used for coloring the
        abstracted graph.
    n_avg
        Number of data points to include in computation of running average.
    groups_key
        Key of the grouping used to run PAGA. If `None`, defaults to
        `adata.uns['paga']['groups']`.
    xlim
        `(start, stop)` slice of each running-average series that is drawn
        when `as_heatmap` is `False`.
    title
        Title of the main axes.
    left_margin
        Left margin of the figure. Defaults to 0.2 for heatmaps (passed to
        `subplots_adjust`) and 0.4 otherwise (used to place the legend).
    ytick_fontsize
        Font size of the y tick labels.
    title_fontsize
        Font size of the title.
    as_heatmap
        Plot the timeseries as heatmap. If not plotting as heatmap,
        `annotations` have no effect.
    show_node_names
        Plot the node names on the nodes bar.
    show_colorbar
        Show the colorbar.
    show_yticks
        Show the y ticks.
    legend_fontsize
        Font size of the legend (line-plot mode with several `keys`).
    legend_fontweight
        Currently unused.
    normalize_to_zero_one
        Shift and scale the running average to [0, 1] per gene. A constant
        series is mapped to all zeros.
    return_data
        Return the timeseries data in addition to the axes if `True`.
    show
         Show the plot, do not return axis.
    save
        If `True` or a `str`, save the figure.
        A string is appended to the default filename.
        Infer the filetype if ending on \\{`'.pdf'`, `'.png'`, `'.svg'`\\}.
    ax
         A matplotlib axes object.

    Returns
    -------
    A :class:`~matplotlib.axes.Axes` object, if `ax` is `None` and the plot is
    not shown, else `None`.
    If `return_data`, return `(axes, data)` in the first case and only the
    timeseries :class:`~pandas.DataFrame` otherwise.
    """
    if color_maps_annotations is None:
        color_maps_annotations = {'dpt_pseudotime': 'Greys'}
    ax_was_none = ax is None

    if groups_key is None:
        if 'groups' not in adata.uns['paga']:
            raise KeyError(
                'Pass the key of the grouping with which you ran PAGA, '
                'using the parameter `groups_key`.')
        groups_key = adata.uns['paga']['groups']
    groups_names = adata.obs[groups_key].cat.categories

    if 'dpt_pseudotime' not in adata.obs.keys():
        raise ValueError(
            '`pl.paga_path` requires computation of a pseudotime `tl.dpt` '
            'for ordering at single-cell resolution')

    if palette_groups is None:
        utils.add_colors_for_categorical_sample_annotation(adata, groups_key)
        palette_groups = adata.uns[groups_key + '_colors']

    def moving_average(a):
        return sc_utils.moving_average(a, n_avg)

    ax = pl.gca() if ax is None else ax
    X = []
    x_tick_locs = [0]
    x_tick_labels = []
    groups = []
    anno_dict = {anno: [] for anno in annotations}
    if isinstance(nodes[0], str):
        nodes_ints = []
        groups_names_set = set(groups_names)
        for node in nodes:
            if node not in groups_names_set:
                raise ValueError(
                    'Each node/group needs to be one of {} (`groups_key`=\'{}\') not \'{}\'.'
                    .format(groups_names.tolist(), groups_key, node))
            nodes_ints.append(groups_names.get_loc(node))
        nodes_strs = nodes
    else:
        nodes_ints = nodes
        nodes_strs = [groups_names[node] for node in nodes]

    adata_X = adata
    if use_raw and adata.raw is not None:
        adata_X = adata.raw

    # loop-invariant lookups, hoisted out of the per-key/per-group loops
    group_labels = adata.obs[groups_key].values
    pseudotime = adata.obs['dpt_pseudotime'].values
    obs_keys = adata.obs_keys()

    for ikey, key in enumerate(keys):
        x = []
        for igroup, group in enumerate(nodes_ints):
            in_group = group_labels == nodes_strs[igroup]
            idcs = np.arange(adata.n_obs)[in_group]
            if len(idcs) == 0:
                raise ValueError(
                    'Did not find data points that match '
                    '`adata.obs[{}].values == str({})`.'
                    'Check whether adata.obs[{}] actually contains what you expect.'
                    .format(groups_key, group, groups_key))
            # order the cells of this group by pseudotime
            idcs = idcs[np.argsort(pseudotime[in_group])]
            if key in obs_keys:
                x += list(adata.obs[key].values[idcs])
            else:
                x += list(adata_X[:, key].X[idcs])
            if ikey == 0:
                # group membership, tick positions and annotations are the
                # same for every key, so collect them only once
                groups += [group] * len(idcs)
                x_tick_locs.append(len(x))
                for anno in annotations:
                    series = adata.obs[anno]
                    if is_categorical_dtype(series):
                        series = series.cat.codes
                    anno_dict[anno] += list(series.values[idcs])
        if n_avg > 1:
            x = moving_average(x)
            if ikey == 0:
                # use a distinct loop name: rebinding `key` here would
                # mislabel the line plotted for this key below
                for anno in annotations:
                    if not isinstance(anno_dict[anno][0], str):
                        anno_dict[anno] = moving_average(anno_dict[anno])
        if normalize_to_zero_one:
            # x may still be a plain list (n_avg == 1) or an int array, so
            # convert to float before shifting/scaling; avoid 0/0 for a
            # constant series
            x = np.asarray(x, dtype=float)
            x = x - np.min(x)
            x_max = np.max(x)
            if x_max > 0:
                x = x / x_max
        X.append(x)
        if not as_heatmap:
            ax.plot(x[xlim[0]:xlim[1]], label=key)
        if ikey == 0:
            for group in nodes:
                if len(groups_names) > 0 and group not in groups_names:
                    label = groups_names[group]
                else:
                    label = group
                x_tick_labels.append(label)
    X = np.array(X)
    if as_heatmap:
        img = ax.imshow(X,
                        aspect='auto',
                        interpolation='nearest',
                        cmap=color_map)
        if show_yticks:
            ax.set_yticks(range(len(X)))
            ax.set_yticklabels(keys, fontsize=ytick_fontsize)
        else:
            ax.set_yticks([])
        ax.set_frame_on(False)
        ax.set_xticks([])
        ax.tick_params(axis='both', which='both', length=0)
        ax.grid(False)
        if show_colorbar:
            pl.colorbar(img, ax=ax)
        left_margin = 0.2 if left_margin is None else left_margin
        pl.subplots_adjust(left=left_margin)
    else:
        left_margin = 0.4 if left_margin is None else left_margin
        if len(keys) > 1:
            # attach the legend to `ax`, which need not be the current axes
            ax.legend(frameon=False,
                      loc='center left',
                      bbox_to_anchor=(-left_margin, 0.5),
                      fontsize=legend_fontsize)
    xlabel = groups_key
    if not as_heatmap:
        ax.set_xlabel(xlabel)
        ax.set_yticks([])
        if len(keys) == 1:
            ax.set_ylabel(keys[0] + ' (a.u.)')
    else:
        import matplotlib.colors
        # groups bar
        ax_bounds = ax.get_position().bounds
        groups_axis = pl.axes([
            ax_bounds[0], ax_bounds[1] - ax_bounds[3] / len(keys),
            ax_bounds[2], ax_bounds[3] / len(keys)
        ])
        groups = np.array(groups)[None, :]
        groups_axis.imshow(
            groups,
            aspect='auto',
            interpolation="nearest",
            cmap=matplotlib.colors.ListedColormap(
                # the following line doesn't work because of normalization
                # adata.uns['paga_groups_colors'])
                palette_groups[np.min(groups).astype(int):],
                N=int(np.max(groups) + 1 - np.min(groups))))
        if show_yticks:
            groups_axis.set_yticklabels(['', xlabel, ''],
                                        fontsize=ytick_fontsize)
        else:
            groups_axis.set_yticks([])
        groups_axis.set_frame_on(False)
        if show_node_names:
            ypos = (groups_axis.get_ylim()[1] + groups_axis.get_ylim()[0]) / 2
            # center each node name between its start and end position
            x_tick_locs = sc_utils.moving_average(x_tick_locs, n=2)
            for ilabel, label in enumerate(x_tick_labels):
                groups_axis.text(x_tick_locs[ilabel],
                                 ypos,
                                 label,
                                 fontdict={
                                     'horizontalalignment': 'center',
                                     'verticalalignment': 'center'
                                 })
        groups_axis.set_xticks([])
        groups_axis.grid(False)
        groups_axis.tick_params(axis='both', which='both', length=0)
        # further annotations
        y_shift = ax_bounds[3] / len(keys)
        for ianno, anno in enumerate(annotations):
            if ianno > 0:
                y_shift = ax_bounds[3] / len(keys) / 2
            anno_axis = pl.axes([
                ax_bounds[0], ax_bounds[1] - (ianno + 2) * y_shift,
                ax_bounds[2], y_shift
            ])
            arr = np.array(anno_dict[anno])[None, :]
            if anno not in color_maps_annotations:
                # 'Vega10' was removed from matplotlib; 'tab10' replaces it
                color_map_anno = ('tab10' if is_categorical_dtype(
                    adata.obs[anno]) else 'Greys')
            else:
                color_map_anno = color_maps_annotations[anno]
            img = anno_axis.imshow(arr,
                                   aspect='auto',
                                   interpolation='nearest',
                                   cmap=color_map_anno)
            if show_yticks:
                anno_axis.set_yticklabels(['', anno, ''],
                                          fontsize=ytick_fontsize)
                anno_axis.tick_params(axis='both', which='both', length=0)
            else:
                anno_axis.set_yticks([])
            anno_axis.set_frame_on(False)
            anno_axis.set_xticks([])
            anno_axis.grid(False)
    if title is not None:
        ax.set_title(title, fontsize=title_fontsize)
    if show is None and not ax_was_none:
        show = False
    else:
        show = settings.autoshow if show is None else show
    utils.savefig_or_show('paga_path', show=show, save=save)
    # the axes are only handed back when this call created them and the
    # figure was not shown
    return_ax = ax_was_none and not show
    if return_data:
        df = pd.DataFrame(data=X.T, columns=keys)
        df['groups'] = moving_average(
            groups)  # groups is without moving average, yet
        if 'dpt_pseudotime' in anno_dict:
            # may still be a plain list when n_avg == 1
            df['distance'] = np.asarray(anno_dict['dpt_pseudotime'])
        return (ax, df) if return_ax else df
    return ax if return_ax else None
Esempio n. 47
0
def make_plot(data):
    """Plot blood-glucose readings and body weight from an HTML report.

    `data` is passed to :func:`pandas.read_html`. The first table matching
    'Weight' and the first matching 'Concentration' are used, with the
    module's `no_units` / `convert_dates` converters applied. The top panel
    shows glucose concentration per 'Event' with shaded bands at
    80-120 mg/dL (green) and 120-220 mg/dL (red). The bottom panel shows
    weight with a least-squares trendline annotated in lbs/week. The
    figure is saved to 'output/readings.png' and 'output/readings.pdf'
    and then shown.

    Note: `plt.rc` below changes the global matplotlib axes prop cycle.
    """
    converters = {'Concentration': no_units,
                  'Weight': no_units,
                  'Date/Time': convert_dates,
                  }
    weight = pd.read_html(data, match='Weight',
                          header=0,
                          converters=converters)[0]
    glucose = pd.read_html(data, match='Concentration',
                           header=0,
                           converters=converters)[0]
    glucose.Event = glucose.Event.apply(add_space)
    glucose.set_index('Date/Time', inplace=True)

    # set up plots
    plt.rc('axes', prop_cycle=(cycler('color', ['b', '#ee7600', 'magenta', 'g'])
                               + cycler('marker', ['o', 'o', 'o', 'o'])))
    fig = plt.figure()
    fig.set_size_inches(11, 8.5)
    ax1 = plt.subplot(211)
    plt.suptitle('Blood Glucose and Weight')
    plt.ylabel('Concentration (mg/dL)')
    ax1.xaxis.set_major_formatter(DateFormatter('%#m/%d'))
    ax1.xaxis.set_major_locator(AutoDateLocator())

    for name, event in glucose.groupby('Event'):
        ax1.plot(event.index, event.Concentration, linestyle='-', ms=5,
                 label=name, linewidth=0.5)
    ax1.legend(numpoints=1, loc='upper left', ncol=3)
    plt.margins(x=0.1, y=0.1)
    # x spans the full axes width (axes coords); y is in mg/dL (data coords)
    trans = transforms.blended_transform_factory(ax1.transAxes, ax1.transData)
    ax1.add_patch(patches.Rectangle((0, 120), width=1, height=100,
                                    transform=trans, color='r', alpha=0.1))
    ax1.add_patch(patches.Rectangle((0, 80), width=1, height=40,
                                    transform=trans, color='g', alpha=0.1))

    ax2 = plt.subplot(212, sharex=ax1)
    plt.ylabel('Weight (lbs)')
    when = weight['Date/Time']
    ax2.plot(when, weight['Weight'], marker='o', color='black', ms=5,
             linewidth=0.5)

    # a trend needs at least two readings (the slope below divides by the
    # elapsed time between the first and last one)
    if len(weight) > 1:
        # least-squares trendline over the dates converted to floats
        dates = when.values.astype(float).reshape(-1, 1)
        weights = weight['Weight'].values.reshape(-1, 1)
        regression = linear_model.LinearRegression()
        regression.fit(dates, weights)
        trend = regression.predict(dates)
        ax2.plot(when, trend, ':', color='#808080')

        # elapsed days between the first and last weight reading; subtract
        # only the date column (row-wise subtraction fails on text columns)
        run = (when.iloc[0] - when.iloc[-1]).total_seconds() / (24 * 60 * 60)
        if run:
            slope = (trend[0][0] - trend[-1][0]) / run  # lbs per day
            index = len(trend) // 2
            x = when.iloc[index]  # positional, independent of the row labels
            y = trend[index][0]
            ax2.annotate('%0.1f lbs/week' % (slope * 7), xy=(x, y),
                         xytext=(x, y), color='#808080',
                         rotation=round(degrees(atan(slope))))

    plt.margins(x=0.1, y=0.1)
    plt.savefig('output/readings.png', orientation='landscape')
    plt.savefig('output/readings.pdf', orientation='landscape')
    plt.show()
Esempio n. 48
0
    def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
        """Plot the accumulated local effects (ALE) stored on this display.

        Parameters
        ----------
        ax : matplotlib Axes or array-like of Axes, default=None
            If None, a new figure and axes are created. If a single Axes,
            it is turned off and its subplot area is split into a grid with
            at most ``n_cols`` columns, one subplot per feature. If
            array-like, it must hold exactly one Axes per feature.
        n_cols : int, default=3
            Maximum number of grid columns when ``ax`` is a single Axes.
        line_kw : dict, default=None
            Keyword arguments for the one-way ``plot`` calls, merged over
            ``{'color': 'C0'}``. Individual lines additionally get
            ``alpha=0.3`` and ``linewidth=0.5``.
        contour_kw : dict, default=None
            Keyword arguments for the two-way ``contourf`` calls, merged over
            ``{'alpha': 0.75}``.

        Returns
        -------
        self
            With ``bounding_ax_``, ``figure_``, ``axes_``, ``lines_``,
            ``contours_``, ``deciles_vlines_`` and ``deciles_hlines_`` set.

        Raises
        ------
        ValueError
            If ``ax`` was already turned off by a previous call, or if an
            array-like ``ax`` does not contain one Axes per feature.
        """
        if line_kw is None:
            line_kw = {}
        if contour_kw is None:
            contour_kw = {}

        if ax is None:
            _, ax = plt.subplots()

        default_contour_kws = {"alpha": 0.75}
        contour_kw = {**default_contour_kws, **contour_kw}

        default_line_kws = {'color': 'C0'}
        line_kw = {**default_line_kws, **line_kw}
        individual_line_kw = line_kw.copy()

        if self.kind == 'individual' or self.kind == 'both':
            individual_line_kw['alpha'] = 0.3
            individual_line_kw['linewidth'] = 0.5

        n_features = len(self.features)
        # number of line slots reserved per feature in ``self.lines_``
        n_sampled = 1
        if self.kind == 'individual':
            n_instances = len(self.ale_results[0].individual[0])
            n_sampled = self._get_sample_count(n_instances)
        elif self.kind == 'both':
            n_instances = len(self.ale_results[0].individual[0])
            # one extra slot per feature for the average line
            n_sampled = self._get_sample_count(n_instances) + 1

        if isinstance(ax, plt.Axes):
            # If ax was set off, it has most likely been set to off
            # by a previous call to plot.
            if not ax.axison:
                raise ValueError("The ax was already used in another plot "
                                 "function, please set ax=display.axes_ "
                                 "instead")

            ax.set_axis_off()
            self.bounding_ax_ = ax
            self.figure_ = ax.figure

            n_cols = min(n_cols, n_features)
            n_rows = int(np.ceil(n_features / float(n_cols)))

            self.axes_ = np.empty((n_rows, n_cols), dtype=object)
            if self.kind == 'average':
                self.lines_ = np.empty((n_rows, n_cols), dtype=object)
            else:
                self.lines_ = np.empty((n_rows, n_cols, n_sampled),
                                       dtype=object)
            self.contours_ = np.empty((n_rows, n_cols), dtype=object)

            axes_ravel = self.axes_.ravel()

            gs = GridSpecFromSubplotSpec(n_rows,
                                         n_cols,
                                         subplot_spec=ax.get_subplotspec())
            for i, spec in zip(range(n_features), gs):
                axes_ravel[i] = self.figure_.add_subplot(spec)

        else:  # array-like
            ax = np.asarray(ax, dtype=object)
            if ax.size != n_features:
                raise ValueError("Expected ax to have {} axes, got {}".format(
                    n_features, ax.size))

            if ax.ndim == 2:
                n_cols = ax.shape[1]
            else:
                n_cols = None

            self.bounding_ax_ = None
            self.figure_ = ax.ravel()[0].figure
            self.axes_ = ax
            if self.kind == 'average':
                self.lines_ = np.empty_like(ax, dtype=object)
            else:
                self.lines_ = np.empty(ax.shape + (n_sampled, ), dtype=object)
            self.contours_ = np.empty_like(ax, dtype=object)

        # create contour levels for two-way plots
        if 2 in self.ale_lim:
            Z_level = np.linspace(*self.ale_lim[2], num=8)

        self.deciles_vlines_ = np.empty_like(self.axes_, dtype=object)
        self.deciles_hlines_ = np.empty_like(self.axes_, dtype=object)

        # Create 1d views of these 2d arrays for easy indexing. For kinds
        # other than 'average', feature i owns the slots
        # [i * n_sampled, (i + 1) * n_sampled) of lines_ravel.
        lines_ravel = self.lines_.ravel(order='C')
        contours_ravel = self.contours_.ravel(order='C')
        vlines_ravel = self.deciles_vlines_.ravel(order='C')
        hlines_ravel = self.deciles_hlines_.ravel(order='C')

        for i, (axi, fx, ale_result) in enumerate(
                zip(self.axes_.ravel(), self.features, self.ale_results)):

            avg_preds = None
            preds = None
            values = ale_result
            # read the same attributes the sample count above is based on
            if self.kind == 'individual':
                preds = ale_result.individual
            elif self.kind == 'average':
                avg_preds = ale_result.average
            else:  # kind='both'
                avg_preds = ale_result.average
                preds = ale_result.individual
            if len(values) == 1:
                if self.kind == 'individual' or self.kind == 'both':
                    n_samples = self._get_sample_count(
                        len(preds[self.target_idx]))
                    ice_lines = preds[self.target_idx]
                    sampled = ice_lines[np.random.choice(
                        ice_lines.shape[0], n_samples, replace=False), :]
                    for j, ins in enumerate(sampled):
                        lines_ravel[i * n_sampled + j] = axi.plot(
                            values[0], ins.ravel(), **individual_line_kw)[0]
                if self.kind == 'average':
                    lines_ravel[i] = axi.plot(
                        values[0], avg_preds[self.target_idx].ravel(),
                        **line_kw)[0]
                elif self.kind == 'both':
                    # the average line takes this feature's last slot
                    lines_ravel[i * n_sampled + n_sampled - 1] = axi.plot(
                        values[0],
                        avg_preds[self.target_idx].ravel(),
                        label='average',
                        **line_kw)[0]
                    axi.legend()
            else:
                # contour plot
                XX, YY = np.meshgrid(values[0], values[1])
                Z = avg_preds[self.target_idx].T
                CS = axi.contour(XX,
                                 YY,
                                 Z,
                                 levels=Z_level,
                                 linewidths=0.5,
                                 colors='k')
                contours_ravel[i] = axi.contourf(XX,
                                                 YY,
                                                 Z,
                                                 levels=Z_level,
                                                 vmax=Z_level[-1],
                                                 vmin=Z_level[0],
                                                 **contour_kw)
                axi.clabel(CS,
                           fmt='%2.2f',
                           colors='k',
                           fontsize=10,
                           inline=True)

            # x-axis deciles: x in data coords, y in axes coords
            trans = transforms.blended_transform_factory(
                axi.transData, axi.transAxes)
            ylim = axi.get_ylim()
            vlines_ravel[i] = axi.vlines(self.deciles[fx[0]],
                                         0,
                                         0.05,
                                         transform=trans,
                                         color='k')
            axi.set_ylim(ylim)

            # Set xlabel if it is not already set
            if not axi.get_xlabel():
                axi.set_xlabel(self.feature_names[fx[0]])

            if len(values) == 1:
                if n_cols is None or i % n_cols == 0:
                    if not axi.get_ylabel():
                        axi.set_ylabel('ALE')
                else:
                    axi.set_yticklabels([])
                axi.set_ylim(self.ale_lim[1])
            else:
                # contour plot: the y axis shows the second feature, so its
                # deciles and label come from fx[1]
                trans = transforms.blended_transform_factory(
                    axi.transAxes, axi.transData)
                xlim = axi.get_xlim()
                hlines_ravel[i] = axi.hlines(self.deciles[fx[1]],
                                             0,
                                             0.05,
                                             transform=trans,
                                             color='k')
                # hline erases xlim
                axi.set_ylabel(self.feature_names[fx[1]])
                axi.set_xlim(xlim)
        return self
Esempio n. 49
0
def pick_droplets(saved_droplets):
    """Interactively review droplet tracks and collect the ones worth keeping.

    One droplet is shown at a time: tracked/smoothed position on the left
    panel, Savitzky-Golay velocity on the right. Only droplets whose lifetime
    strictly contains at least one entry of the module-level ``split_times``
    are displayed; the others are skipped. "Save Droplet" records the shown
    droplet and advances, "Toss Droplet" just advances. When no droplets are
    left the figure is closed.

    Relies on module-level names: ``split_times``, ``line_colors``, ``ymax``,
    ``dt``, ``unpack_droplet``, ``savgol_filter``, ``Button``,
    ``blended_transform_factory``, ``datetime``.

    Parameters
    ----------
    saved_droplets : list
        Droplets to review. NOTE: consumed in place (items are popped from
        the front as they are reviewed).

    Returns
    -------
    list
        The droplets the user chose to save, in the order they were saved.
    """
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    ax_button_save = fig.add_axes([0.1, 0.01, 0.2, 0.04])
    ax_button_toss = fig.add_axes([0.4, 0.01, 0.2, 0.04])

    button_save = Button(ax_button_save,
                         'Save Droplet',
                         color='lightcyan',
                         hovercolor='0.975')
    button_toss = Button(ax_button_toss,
                         'Toss Droplet',
                         color='red',
                         hovercolor='0.975')

    interesting_droplets = []

    pos_plot, = ax1.plot([], [], label='tracked position')
    pos_plot_smooth, = ax1.plot([], [], label='smoothed track')
    ax1.legend()
    ax1.set_xlabel('time [s]')
    ax1.set_ylabel('position [pixels]')

    vel_plot, = ax2.plot([], [], label='computed velocity')
    # Empty line artist kept for a smoothed velocity; it is never filled.
    vel_plot_smooth, = ax2.plot([], [])

    ax2.set_xlabel('time [s]')
    ax2.set_ylabel('velocity [pixels/sec]')

    # Dashed vertical line at every split time, spanning the full panel
    # height (x in data units, y in axes fraction 0..1).
    for ax in (ax1, ax2):
        ax.vlines(split_times,
                  0,
                  1,
                  colors=line_colors,
                  linestyles='dashed',
                  transform=blended_transform_factory(ax.transData,
                                                      ax.transAxes))

    def plot_next():
        """Show the next droplet overlapping a split time; close when none remain.

        Returns the droplet now displayed, or None once the list is exhausted.
        """
        # Module-level so the "Save Droplet" callback can read it.
        global droplet_container
        # Iterate rather than recurse: a long run of non-overlapping droplets
        # used to add one stack frame per droplet and could raise
        # RecursionError from inside the button callback.
        while len(saved_droplets) > 0:
            droplet = saved_droplets.pop(0)
            droplet_container = droplet
            _, ypos, _, _, _, birth_time, lifetime = unpack_droplet(droplet)
            death_time = birth_time + lifetime

            if not np.any((split_times > birth_time)
                          & (split_times < death_time)):
                continue

            # Presumably flips image-row coordinates so "up" is positive.
            ypos = ymax - ypos
            ysmooth = savgol_filter(ypos, 31, 2)
            # Velocity as the Savitzky-Golay first derivative of position.
            yvel_ = savgol_filter(ypos, 31, 4, deriv=1, delta=dt)

            time_array = np.linspace(birth_time, death_time, len(ypos))

            pos_plot.set_data(time_array, ypos)
            pos_plot_smooth.set_data(time_array, ysmooth)
            vel_plot.set_data(time_array, yvel_)

            ax1.set_xlim([birth_time, death_time])
            ax2.set_xlim([birth_time, death_time])

            ax1.set_ylim([np.min(ypos), np.max(ypos)])
            ax2.set_ylim([np.min(yvel_), np.max(yvel_)])

            fig.suptitle(
                f"droplet id: {droplet.dropletid}, time: {datetime.timedelta(seconds=droplet.birth_time)}"
            )

            fig.canvas.draw_idle()

            return droplet

        plt.close()
        print('')
        print("all done")
        return None

    def on_save(_event):
        """Record the currently displayed droplet, then advance."""
        print('saving droplet', droplet_container.dropletid)
        interesting_droplets.append(droplet_container)
        plot_next()

    button_save.on_clicked(on_save)
    button_toss.on_clicked(lambda e: plot_next())
    plt.tight_layout()
    plt.subplots_adjust(top=0.94, bottom=0.165)
    plot_next()
    plt.show()

    return interesting_droplets
Esempio n. 50
0
    def _plot_one_way_partial_dependence(
        self,
        preds,
        avg_preds,
        feature_values,
        feature_idx,
        n_ice_lines,
        ax,
        n_cols,
        pd_plot_idx,
        n_lines,
        individual_line_kw,
        line_kw,
    ):
        """Draw one one-way partial dependence panel (ICE lines and/or PD curve).

        Which curves are drawn depends on ``self.kind``: ICE lines for
        "individual"/"both", the averaged curve for "average"/"both". A decile
        rug is added along the bottom edge, and axis labels/legend are set.

        Parameters
        ----------
        preds : ndarray of shape \
                (n_instances, n_grid_points) or None
            Predictions at every point of `feature_values` for each sample
            in `X`; indexed with `self.target_idx` before use.
        avg_preds : ndarray of shape (n_grid_points,)
            Sample-averaged predictions at every point of `feature_values`;
            indexed with `self.target_idx` and flattened before use.
        feature_values : ndarray of shape (n_grid_points,)
            Grid of feature values at which predictions were computed.
        feature_idx : int
            Index of the plotted feature (its first entry is used).
        n_ice_lines : int
            How many ICE lines to draw.
        ax : Matplotlib axes
            Target axes for the ICE and PD lines.
        n_cols : int or None
            Number of columns in the grid of axes.
        pd_plot_idx : int
            Sequential plot index; unraveled into a 2D grid position.
        n_lines : int
            Total number of lines expected on this axes.
        individual_line_kw : dict
            Keyword arguments for the ICE lines.
        line_kw : dict
            Keyword arguments for the PD line.
        """
        from matplotlib import transforms  # noqa

        show_ice = self.kind in ("individual", "both")
        show_average = self.kind in ("average", "both")

        if show_ice:
            self._plot_ice_lines(
                preds[self.target_idx],
                feature_values,
                n_ice_lines,
                ax,
                pd_plot_idx,
                n_lines,
                individual_line_kw,
            )

        if show_average:
            # With ICE lines present, each panel owns n_lines slots and the
            # averaged curve is stored right after its ICE lines.
            if self.kind == "both":
                pd_line_idx = pd_plot_idx * n_lines + n_ice_lines
            else:
                pd_line_idx = pd_plot_idx
            self._plot_average_dependence(
                avg_preds[self.target_idx].ravel(),
                feature_values,
                ax,
                pd_line_idx,
                line_kw,
            )

        # Decile rug: x in data units, y from 0 to 5% of the axes height.
        rug_transform = transforms.blended_transform_factory(ax.transData,
                                                             ax.transAxes)
        slot = np.unravel_index(pd_plot_idx, self.deciles_vlines_.shape)
        feature = feature_idx[0]
        self.deciles_vlines_[slot] = ax.vlines(
            self.deciles[feature],
            0,
            0.05,
            transform=rug_transform,
            color="k",
        )
        # vlines changes the y-limits; restore those stored in self.pdp_lim[1].
        ax.set_ylim(self.pdp_lim[1])

        if not ax.get_xlabel():
            ax.set_xlabel(self.feature_names[feature])

        in_first_column = n_cols is None or pd_plot_idx % n_cols == 0
        if not in_first_column:
            ax.set_yticklabels([])
        elif not ax.get_ylabel():
            ax.set_ylabel('Partial dependence')

        if line_kw.get("label", None) and self.kind != 'individual':
            ax.legend()
Esempio n. 51
0
def plot_record_section(st,
                        origin_time,
                        source_location,
                        plot_celerity=None,
                        label_waveforms=True):
    """
    Plot a record section based upon user-provided source location and origin
    time. Optionally plot celerity for reference, with two plotting options.

    Args:
        st (:class:`~obspy.core.stream.Stream`): Any Stream object with
            `tr.stats.latitude`, `tr.stats.longitude` attached
        origin_time (:class:`~obspy.core.utcdatetime.UTCDateTime`): Origin time
            for record section
        source_location (tuple): Tuple of (`lat`, `lon`) specifying source
            location
        plot_celerity: Can be either `'range'`, a single celerity, or a
            sequence (list, tuple or NumPy array) of celerities. If `'range'`,
            plots a continuous swatch of celerities from 220-350 m/s.
            Otherwise, plots the specific celerities, sorted ascending (the
            caller's sequence is not modified). If `None`, does not plot any
            celerities (default: `None`)
        label_waveforms (bool): Toggle labeling waveforms with network and
            station codes (default: `True`)

    Returns:
        :class:`~matplotlib.figure.Figure`: Output figure
    """

    # Work on a copy so the caller's Stream is neither annotated nor trimmed.
    st_edit = st.copy()

    for tr in st_edit:
        tr.stats.distance, _, _ = gps2dist_azimuth(*source_location,
                                                   tr.stats.latitude,
                                                   tr.stats.longitude)

    st_edit.trim(origin_time)

    fig = plt.figure(figsize=(12, 8))

    st_edit.plot(fig=fig,
                 type='section',
                 orientation='horizontal',
                 fillcolors=('black', 'black'))

    ax = fig.axes[0]

    # x in axes fraction, y in data units (distance in km).
    trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)

    if label_waveforms:
        for tr in st_edit:
            ax.text(1.01,
                    tr.stats.distance / 1000,
                    f'{tr.stats.network}.{tr.stats.station}',
                    verticalalignment='center',
                    transform=trans,
                    fontsize=10)
        pad = 0.1  # Move colorbar to the right to make room for labels
    else:
        pad = 0.05  # Matplotlib default for vertical colorbars

    # Normalize tuples/arrays to a list. A NumPy array is ambiguous in the
    # truth test below, and a tuple would otherwise be treated as one value.
    if isinstance(plot_celerity, np.ndarray):
        plot_celerity = np.atleast_1d(plot_celerity).tolist()
    elif isinstance(plot_celerity, tuple):
        plot_celerity = list(plot_celerity)

    if plot_celerity:

        is_range = plot_celerity == 'range'

        # Check if user requested a continuous range of celerities
        if is_range:
            inc = 0.5  # [m/s]
            # [m/s] Includes all reasonable celerities
            celerity_list = np.arange(220, 350 + inc, inc)
            zorder = -1

        # Otherwise, they provided specific celerities
        else:
            if isinstance(plot_celerity, list):
                # sorted() returns a copy: never reorder the caller's list
                celerity_list = sorted(plot_celerity)
            else:
                celerity_list = [plot_celerity]
            zorder = None

        # Create colormap of appropriate length. pyplot.get_cmap(name, lut)
        # works on both old and new matplotlib (cm.get_cmap was removed).
        cmap = plt.get_cmap('rainbow', len(celerity_list))
        colors = [cmap(i) for i in range(cmap.N)]

        xlim = np.array(ax.get_xlim())
        y_max = ax.get_ylim()[1]  # Save this for re-scaling axis

        # Each celerity is a line distance = time * celerity (m -> km).
        for celerity, color in zip(celerity_list, colors):
            ax.plot(xlim,
                    xlim * celerity / 1000,
                    label=f'{celerity:g}',
                    color=color,
                    zorder=zorder)

        ax.set_ylim(top=y_max)  # Scale y-axis to pre-plotting extent

        # If plotting a continuous range, add a colorbar
        if is_range:
            mapper = plt.cm.ScalarMappable(cmap=cmap)
            mapper.set_array(celerity_list)
            cbar = fig.colorbar(mapper,
                                label='Celerity (m/s)',
                                pad=pad,
                                aspect=30)
            cbar.ax.minorticks_on()

        # If plotting discrete celerities, just add a legend
        else:
            ax.legend(title='Celerity (m/s)',
                      loc='lower right',
                      framealpha=1,
                      edgecolor='inherit')

    ax.set_ylim(bottom=0)  # Show all the way to zero offset

    time_round = np.datetime64(origin_time + 0.5,
                               's').astype(datetime)  # Nearest second
    ax.set_xlabel('Time (s) from {}'.format(time_round))
    ax.set_ylabel('Distance (km) from '
                  '({:.4f}, {:.4f})'.format(*source_location))

    fig.tight_layout()
    fig.show()

    return fig
def plot_station_data(date, site, sitetitle, df):
    '''Plot the last 3 days of ASOS observations for one station and save it as a PNG.

    Parameters:
    date (str): date in %Y%m%d format; it is only parsed (a malformed value
        raises ValueError) and does not otherwise affect the plot
    site (str): string of ASOS station ID
    sitetitle (str): string of ASOS station full name
    df (dataframe): dataframe containing last 72 hours (3 days) of ASOS station
        data, indexed by observation time (index named 'date_time'); an int
        is treated as "station not reporting" and nothing is plotted

    Returns:
    None

    *saves the plot to <plot_dir>/<YYYYMMDD of last ob>/ops.asos.<YYYYmmddHHMM>.<site>.png;
    plot_dir, wx_code_map, current_weather and StationPlot are module-level names*
    '''
    if isinstance(df, int):  # Returns if the station is not reporting
        return

    lower_site = site.lower()
    timestamp_end = str(df.index[-1].strftime('%Y%m%d%H%M'))
    dt = df.index[:]
    graphtimestamp_start = dt[0].strftime("%m/%d/%y")
    graphtimestamp = dt[-1].strftime("%m/%d/%y")
    # Parsed only so that a malformed `date` still fails loudly.
    datetime.strptime(date, '%Y%m%d')
    today_date = dt[-1].strftime('%Y%m%d')
    markersize = 1.5
    linewidth = 1.0

    # make figure and axes: six stacked panels if snow depth is present, else five
    n_rows = 6 if 'snow_depth_set_1' in df.keys() else 5
    fig = plt.figure()
    fig.set_size_inches(18, 10)
    ax1 = fig.add_subplot(n_rows, 1, 1)
    panel_axes = [ax1] + [fig.add_subplot(n_rows, 1, row, sharex=ax1)
                          for row in range(2, n_rows + 1)]
    ax2, ax3, ax4, ax5 = panel_axes[1:5]
    panel_axes[-1].set_xlabel('Time (UTC)')

    ax1.set_title(site + ' ' + sitetitle + ' ' + graphtimestamp_start + ' - ' + graphtimestamp)

    #------------------
    #plot airT and dewT
    #------------------
    # dropna() keeps the timestamp index, so the valid obs and their times stay
    # aligned without positional Series indexing (deprecated in pandas).
    if 'tmpc' in df.keys():
        airT = df['tmpc'].dropna()
        ax1.plot_date(list(airT.index), list(airT.values), 'o-', label="Temp",
                      color="blue", linewidth=linewidth, markersize=markersize)
    if 'dwpc' in df.keys():
        dewT = df['dwpc'].dropna()
        ax1.plot_date(list(dewT.index), list(dewT.values), 'o-', label="Dew Point",
                      color="black", linewidth=linewidth, markersize=markersize)
    if ax1.get_ylim()[0] < 0 < ax1.get_ylim()[1]:
        # light blue line at 0 degrees C, labelled to the left of the y axis
        ax1.axhline(0, linestyle='-', linewidth=1.0, color='deepskyblue')
        trans = transforms.blended_transform_factory(ax1.get_yticklabels()[0].get_transform(), ax1.transData)
        ax1.text(0, 0, '0C', color="deepskyblue", transform=trans, ha="right", va="center")
    ax1.set_ylabel(r'Temp ($^\circ$C)')
    ax1.legend(loc='best', ncol=2)
    axes = [ax1]  # panels that receive the shared formatting below

    #----------------------------
    #plotting wind speed and gust
    #----------------------------
    if 'sknt' in df.keys():
        wnd_spd = df['sknt']
        ax2.plot_date(dt, wnd_spd, 'o-', label='Speed', color="forestgreen",
                      linewidth=linewidth, markersize=markersize)
    if 'gust' in df.keys():
        wnd_gst = df['gust']
        max_wnd_gst = wnd_gst.max(skipna=True)
        ax2.plot_date(dt, wnd_gst, 'o-', label='Gust (Max ' + str(round(max_wnd_gst, 1)) + 'kt)',
                      color="red", linewidth=0.0, markersize=markersize)
    ax2.set_ylabel('Wind (kt)')
    ax2.legend(loc='best', ncol=2)
    axes.append(ax2)

    #-----------------------
    #plotting wind direction
    #-----------------------
    if 'drct' in df.keys():
        wnd_dir = df['drct']
        ax3.plot_date(dt, wnd_dir, 'o-', label='Direction', color="purple",
                      linewidth=0.2, markersize=markersize)
    ax3.set_ylim(-10, 370)
    ax3.set_ylabel('Wind Direction')
    ax3.set_yticks([0, 90, 180, 270, 360])
    axes.append(ax3)

    #-------------
    #plotting MSLP
    #-------------
    if 'mslp' in df.keys():
        mslp = df['mslp']
        mslp_valid = mslp.dropna()
        max_mslp = mslp.max(skipna=True)
        min_mslp = mslp.min(skipna=True)
        labelname = 'Min ' + str(round(min_mslp, 1)) + 'hPa, Max ' + str(round(max_mslp, 2)) + 'hPa'
        ax4.plot_date(list(mslp_valid.index), list(mslp_valid.values), 'o-', label=labelname,
                      color='darkorange', linewidth=linewidth, markersize=markersize)
    ax4.legend(loc='best')
    ax4.set_ylabel('MSLP (hPa)')
    # (no x label here: only the bottom panel carries 'Time (UTC)')
    axes.append(ax4)

    #-------------------------------------------
    #plotting precip accumulation & precip types
    #-------------------------------------------

    # Move date_time from index to column
    df = df.reset_index()

    # Defaults so the weather-symbol section below is safe without p01m data.
    precip_dt_list = []
    precip_accum_indices = []
    max_precip = 0.0

    # TODO: keep intermediate readings instead of only hourly accumulations;
    # that would make the plot less choppy.
    # Plot precip time series (use only values at minute 53)
    if 'p01m' in df.keys():
        precip_inc = df['p01m'].fillna(0).values
        precip_inc_dt = df['date_time'].values

        precip_accum = 0.0
        precip_accum_list = []
        for i, (increment, stamp) in enumerate(zip(precip_inc, precip_inc_dt)):
            # Presumably the routine hourly report, whose p01m is the full-hour total.
            if pd.to_datetime(stamp).minute == 53:
                precip_accum += increment
                precip_accum_list.append(precip_accum)
                precip_dt_list.append(stamp)
                precip_accum_indices.append(i)

        # default=0.0 avoids ValueError when no minute-53 report exists
        max_precip = max(precip_accum_list, default=0.0)
        labelname = 'Precip (' + str(round(max_precip, 2)) + 'mm)'
        ax5.plot_date(precip_dt_list, precip_accum_list, 'o-', label=labelname,
                      color='navy', linewidth=linewidth, markersize=markersize)
        if max_precip > 0:
            ax5.set_ylim(-0.1 * max_precip, max_precip + max_precip * 0.2)
        else:
            ax5.set_ylim(-0.5, 5)

    # Add weather_code info to plot, at the same times as the precip points.
    # Skipped when there are no precip points (previously a NameError when
    # p01m was absent).
    if 'wxcodes' in df.keys() and precip_accum_indices:
        raw_codes = df['wxcodes'].fillna('').values
        wxcodes = []
        for i in precip_accum_indices:
            tokens = raw_codes[i].split()
            # First reported code only; empty or unknown codes map to 0, the
            # value already used for an empty report.
            wxcodes.append(wx_code_map.get(tokens[0], 0) if tokens else 0)
        # Draw the symbols at 10% of the max accumulation.
        dummy_y_vals = np.ones(len(wxcodes)) * (0.10 * max_precip)

        sp = StationPlot(ax5, precip_dt_list, dummy_y_vals)
        sp.plot_symbol('C', wxcodes, current_weather, fontsize=14, color='red')

    ax5.legend(loc='best')
    ax5.set_ylabel('Precip (mm)')
    axes.append(ax5)

    # NOTE: snow-depth plotting is currently disabled; when 'snow_depth_set_1'
    # is present the sixth panel is created but left empty.

    # Axes formatting shared by every populated panel
    for ax in axes:
        # hide all four frame spines
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)
        # ticks at labeled times; tick_params expects booleans ('off' is truthy)
        ax.tick_params(axis='x', which='both', bottom=True, top=False)
        ax.tick_params(axis='y', which='both', left=True, right=False)

        ax.xaxis.set_major_locator(DayLocator())
        ax.xaxis.set_major_formatter(DateFormatter('%b-%d'))

        # HourLocator expects integer hours
        ax.xaxis.set_minor_locator(HourLocator([6, 12, 18]))
        ax.xaxis.set_minor_formatter(DateFormatter('%H'))
        ax.fmt_xdata = DateFormatter('%Y%m%d%H%M%S')
        ax.yaxis.grid(linestyle='--')
        ax.get_yaxis().set_label_coords(-0.06, 0.5)

    # Write plot to file
    plot_path = os.path.join(plot_dir, today_date)
    os.makedirs(plot_path, exist_ok=True)
    try:
        fig.savefig(plot_path + '/ops.asos.' + timestamp_end + '.' + lower_site + '.png', bbox_inches='tight')
    except Exception:
        print("Problem saving figure for %s. Usually a maxticks problem" % site)
    plt.close(fig)
Esempio n. 53
0
    def _plot_two_way_partial_dependence(
        self,
        avg_preds,
        feature_values,
        feature_idx,
        ax,
        pd_plot_idx,
        Z_level,
        contour_kw,
    ):
        """Plot 2-way partial dependence.

        Parameters
        ----------
        avg_preds : ndarray
            Average predictions, indexed first by `self.target_idx`; the
            selected slice has shape (len(feature_values[0]),
            len(feature_values[1])) and is transposed to match the meshgrid.
        feature_values : seq of 1d array
            A sequence of array of the feature values for which the predictions
            have been computed.
        feature_idx : tuple of int
            The indices of the target features
        ax : Matplotlib axes
            The axis on which to draw the contour plot and decile ticks.
        pd_plot_idx : int
            The sequential index of the plot. It will be unraveled to find the
            matching 2D position in the grid layout.
        Z_level : 1d array-like
            Contour levels; its first and last entries also set vmin/vmax.
        contour_kw : dict
            Dict with keywords passed when plotting the contours.
        """
        from matplotlib import transforms  # noqa

        XX, YY = np.meshgrid(feature_values[0], feature_values[1])
        Z = avg_preds[self.target_idx].T
        CS = ax.contour(XX, YY, Z, levels=Z_level, linewidths=0.5, colors="k")
        contour_idx = np.unravel_index(pd_plot_idx, self.contours_.shape)
        self.contours_[contour_idx] = ax.contourf(
            XX,
            YY,
            Z,
            levels=Z_level,
            vmax=Z_level[-1],
            vmin=Z_level[0],
            **contour_kw,
        )
        ax.clabel(CS, fmt="%2.2f", colors="k", fontsize=10, inline=True)

        xlim, ylim = ax.get_xlim(), ax.get_ylim()

        # create the decile line for the vertical axis:
        # x in data units, y spanning the bottom 5% of the axes
        vlines_trans = transforms.blended_transform_factory(ax.transData,
                                                            ax.transAxes)
        vlines_idx = np.unravel_index(pd_plot_idx, self.deciles_vlines_.shape)
        self.deciles_vlines_[vlines_idx] = ax.vlines(
            self.deciles[feature_idx[0]],
            0,
            0.05,
            transform=vlines_trans,
            color="k",
        )
        # create the decile line for the horizontal axis:
        # y in data units, x spanning the left 5% of the axes. This needs the
        # mirrored blend; reusing the vlines transform put the y deciles in
        # axes-fraction space and the x extent in data space.
        hlines_trans = transforms.blended_transform_factory(ax.transAxes,
                                                            ax.transData)
        hlines_idx = np.unravel_index(pd_plot_idx, self.deciles_hlines_.shape)
        self.deciles_hlines_[hlines_idx] = ax.hlines(
            self.deciles[feature_idx[1]],
            0,
            0.05,
            transform=hlines_trans,
            color="k",
        )
        # reset xlim and ylim since they are overwritten by hlines and vlines
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # set xlabel if it is not already set
        if not ax.get_xlabel():
            ax.set_xlabel(self.feature_names[feature_idx[0]])
        ax.set_ylabel(self.feature_names[feature_idx[1]])
Esempio n. 54
0
    def get_spine_transform(self):
        """Build the transform that places this spine at its configured position.

        The position is a ``(position_type, amount)`` pair, or one of the
        shorthands ``'center'`` / ``'zero'``; ``position_type`` must be
        ``'outward'``, ``'axes'`` or ``'data'``.
        """
        self._ensure_position_is_set()

        position = self._position
        if isinstance(position, str):
            # Expand shorthand strings; any other string is left unchanged
            # and then rejected by the validation below.
            position = {
                'center': ('axes', 0.5),
                'zero': ('data', 0),
            }.get(position, position)
        assert len(position) == 2, 'position should be 2-tuple'
        position_type, amount = position
        _api.check_in_list(['axes', 'outward', 'data'],
                           position_type=position_type)

        is_vertical = self.spine_type in ['left', 'right']
        if is_vertical:
            base_transform = self.axes.get_yaxis_transform(which='grid')
        elif self.spine_type in ['top', 'bottom']:
            base_transform = self.axes.get_xaxis_transform(which='grid')
        else:
            raise ValueError(f'unknown spine spine_type: {self.spine_type!r}')

        if position_type == 'outward':
            if amount == 0:  # short circuit commonest case
                return base_transform
            unit_offset = {
                'left': (-1, 0),
                'right': (1, 0),
                'bottom': (0, -1),
                'top': (0, 1),
            }[self.spine_type]
            # amount / 72: presumably points -> inches, which dpi_scale_trans
            # maps to display units.
            dx, dy = amount * np.array(unit_offset) / 72
            shift = mtransforms.ScaledTranslation(
                dx, dy, self.figure.dpi_scale_trans)
            return base_transform + shift

        if position_type == 'axes':
            # Collapse one coordinate to the constant `amount`, keep the other.
            if is_vertical:
                pin = mtransforms.Affine2D.from_values(0, 0, 0, 1, amount, 0)
            else:
                pin = mtransforms.Affine2D.from_values(1, 0, 0, 0, 0, amount)
            return pin + base_transform

        # Only 'data' is left after check_in_list above.
        if self.spine_type in ('right', 'top'):
            # right/top spines sit at axes position 1 by default, so shift
            # `amount` to be measured relative to 0 in data coordinates.
            amount -= 1
        data = self.axes.transData
        if is_vertical:
            return mtransforms.blended_transform_factory(
                mtransforms.Affine2D().translate(amount, 0) + data, data)
        return mtransforms.blended_transform_factory(
            data, mtransforms.Affine2D().translate(0, amount) + data)
# Finish the preceding figure (saved as Loss.png): legend, grid, save.
plt.legend(loc='best', prop={'size': 15})
plt.grid(True)
fig.savefig('Loss.png')

# New figure: training and validation accuracy per step, with the test
# accuracy drawn as a horizontal green reference line.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

plt.xlabel("Number of Steps", size=15)
plt.ylabel("Accuracy", size=15)
plt.title("Plot of Accuracy vs Number of Steps", size=15)

ax.plot(tmp, train_accuracy, 'b', label="Training Accuracy", linewidth=2.5)
ax.plot(tmp, val_accuracy, 'r', label="Validation Accuracy", linewidth=2.5)
ax.axhline(y=acc_test, color='g', linewidth=2.5, label='Test accuracy')

# Blended transform: x taken from the y tick-label transform (presumably
# placing the text in the tick-label column left of the axes), y in data
# units so the label sits exactly at acc_test.
trans = transforms.blended_transform_factory(
    ax.get_yticklabels()[0].get_transform(), ax.transData)
ax.text(0,
        acc_test,
        "{:.4f}".format(acc_test),
        color="g",
        transform=trans,
        ha="right",
        va="center",
        size=12)

plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
ax.xaxis.labelpad = 10
ax.yaxis.labelpad = 10
plt.legend(loc='best', prop={'size': 15})
plt.grid(True)
# NOTE(review): this accuracy figure is neither saved nor shown within this
# fragment -- confirm whether a savefig/show follows elsewhere.
Esempio n. 56
0
    def _add_gridline_label(self, value, axis, upper_end):
        """
        Create a Text artist on our axes for a gridline label.

        Parameters
        ----------
        value
            Coordinate value of this gridline.  The text contains this
            value, and is positioned centred at that point.
        axis
            Which axis the label is on: 'x' or 'y'.
        upper_end: bool
            If True, place at the maximum of the "other" coordinate (Axes
            coordinate == 1.0).  Else 'lower' end (Axes coord = 0.0).

        Raises
        ------
        ValueError
            If ``axis`` is neither 'x' nor 'y'.

        """
        transform = self._crs_transform()
        # Push the label away from the axes edge: up/right at the upper end,
        # down/left at the lower end.
        shift_scale = 1 if upper_end else -1
        if axis == 'x':
            x = value
            y = 1.0 if upper_end else 0.0
            h_align = 'center'
            v_align = 'bottom' if upper_end else 'top'
            tr_x = transform
            # padding / 72: presumably points -> inches for dpi_scale_trans
            tr_y = self.axes.transAxes + \
                mtrans.ScaledTranslation(
                    0.0,
                    shift_scale * self.xpadding * (1.0 / 72),
                    self.axes.figure.dpi_scale_trans)
            str_value = self.xformatter(value)
            user_label_style = self.xlabel_style
        elif axis == 'y':
            y = value
            x = 1.0 if upper_end else 0.0
            # Use 'center_baseline' from matplotlib 2 on. Compare the numeric
            # major version: the previous string comparison
            # (`__version__ > '2.0'`) is lexicographic and mis-orders
            # releases such as '10.0'.
            try:
                mpl_major = int(matplotlib.__version__.split('.')[0])
            except ValueError:
                mpl_major = 0
            v_align = 'center_baseline' if mpl_major >= 2 else 'center'
            h_align = 'left' if upper_end else 'right'
            tr_y = transform
            tr_x = self.axes.transAxes + \
                mtrans.ScaledTranslation(
                    shift_scale * self.ypadding * (1.0 / 72),
                    0.0,
                    self.axes.figure.dpi_scale_trans)
            str_value = self.yformatter(value)
            user_label_style = self.ylabel_style
        else:
            raise ValueError(
                "Unknown axis, {!r}, must be either 'x' or 'y'".format(axis))

        # Make a 'blended' transform for label text positioning.
        # One coord is geographic, and the other a plain Axes
        # coordinate with an appropriate offset.
        label_transform = mtrans.blended_transform_factory(x_transform=tr_x,
                                                           y_transform=tr_y)

        # Default alignment first; user-supplied style overrides it.
        label_style = {
            'verticalalignment': v_align,
            'horizontalalignment': h_align,
        }
        label_style.update(user_label_style)

        # Create and add a Text artist with these properties
        text_artist = mtext.Text(x,
                                 y,
                                 str_value,
                                 clip_on=False,
                                 transform=label_transform,
                                 **label_style)
        if axis == 'x':
            self.xlabel_artists.append(text_artist)
        elif axis == 'y':
            self.ylabel_artists.append(text_artist)
        self.axes.add_artist(text_artist)
Esempio n. 57
0
def plot_partial_dependence(estimator,
                            X,
                            features,
                            feature_names=None,
                            target=None,
                            response_method='auto',
                            n_cols=3,
                            grid_resolution=100,
                            percentiles=(0.05, 0.95),
                            method='auto',
                            n_jobs=None,
                            verbose=0,
                            fig=None,
                            line_kw=None,
                            contour_kw=None):
    """Partial dependence plots.

    The ``len(features)`` plots are arranged in a grid with ``n_cols``
    columns. Two-way partial dependence plots are plotted as contour plots.

    Read more in the :ref:`User Guide <partial_dependence>`.

    Parameters
    ----------
    estimator : BaseEstimator
        A fitted estimator object implementing `predict`, `predict_proba`,
        or `decision_function`. Multioutput-multiclass classifiers are not
        supported.
    X : array-like, shape (n_samples, n_features)
        The data to use to build the grid of values on which the dependence
        will be evaluated. This is usually the training data.
    features : list of {int, str, pair of int, pair of str}
        The target features for which to create the PDPs.
        If features[i] is an int or a string, a one-way PDP is created; if
        features[i] is a tuple, a two-way PDP is created. Each tuple must be
        of size 2.
        If any entry is a string, then it must be in ``feature_names``.
    feature_names : seq of str, shape (n_features,), optional
        Name of each feature; feature_names[i] holds the name of the feature
        with index i. By default, the name of the feature corresponds to
        their numerical index.
    target : int, optional (default=None)
        - In a multiclass setting, specifies the class for which the PDPs
          should be computed. Note that for binary classification, the
          positive class (index 1) is always used.
        - In a multioutput setting, specifies the task for which the PDPs
          should be computed. Must be in ``[0, n_tasks - 1]``.
        Ignored in binary classification or classical regression settings.
    response_method : 'auto', 'predict_proba' or 'decision_function', \
            optional (default='auto') :
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. For regressors
        this parameter is ignored and the response is always the output of
        :term:`predict`. By default, :term:`predict_proba` is tried first
        and we revert to :term:`decision_function` if it doesn't exist. If
        ``method`` is 'recursion', the response is always the output of
        :term:`decision_function`.
    n_cols : int, optional (default=3)
        The maximum number of columns in the grid plot.
    grid_resolution : int, optional (default=100)
        The number of equally spaced points on the axes of the plots, for each
        target feature.
    percentiles : tuple of float, optional (default=(0.05, 0.95))
        The lower and upper percentile used to create the extreme values
        for the PDP axes. Must be in [0, 1].
    method : str, optional (default='auto')
        The method to use to calculate the partial dependence predictions:

        - 'recursion' is only supported for gradient boosting estimator (namely
          :class:`GradientBoostingClassifier<mrex.ensemble.GradientBoostingClassifier>`,
          :class:`GradientBoostingRegressor<mrex.ensemble.GradientBoostingRegressor>`,
          :class:`HistGradientBoostingClassifier<mrex.ensemble.HistGradientBoostingClassifier>`,
          :class:`HistGradientBoostingRegressor<mrex.ensemble.HistGradientBoostingRegressor>`)
          but is more efficient in terms of speed.
          With this method, ``X`` is optional and is only used to build the
          grid and the partial dependences are computed using the training
          data. This method does not account for the ``init`` predictor of
          the boosting process, which may lead to incorrect values (see
          warning below). With this method, the target response of a
          classifier is always the decision function, not the predicted
          probabilities.

        - 'brute' is supported for any estimator, but is more
          computationally intensive.

        - 'auto':
          - 'recursion' is used for estimators that support it.
          - 'brute' is used for all other estimators.
    n_jobs : int, optional (default=None)
        The number of CPUs to use to compute the partial dependences.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.
    verbose : int, optional (default=0)
        Verbose output during PD computations.
    fig : Matplotlib figure object, optional (default=None)
        A figure object onto which the plots will be drawn, after the figure
        has been cleared. By default, a new one is created.
    line_kw : dict, optional
        Dict with keywords passed to the ``matplotlib.pyplot.plot`` call.
        For one-way partial dependence plots. Defaults to
        ``{'color': 'green'}``.
    contour_kw : dict, optional
        Dict with keywords passed to the ``matplotlib.pyplot.contourf`` call.
        For two-way partial dependence plots.

    Raises
    ------
    ValueError
        If ``target`` is missing or out of range for multiclass classifiers
        or multi-output regressors, if ``feature_names`` has duplicates, or
        if an entry of ``features`` is malformed, unknown or out of range.

    Examples
    --------
    >>> from mrex.datasets import make_friedman1
    >>> from mrex.ensemble import GradientBoostingRegressor
    >>> X, y = make_friedman1()
    >>> clf = GradientBoostingRegressor(n_estimators=10).fit(X, y)
    >>> plot_partial_dependence(clf, X, [0, (0, 1)]) #doctest: +SKIP

    See also
    --------
    mrex.inspection.partial_dependence: Return raw partial
      dependence values

    Warnings
    --------
    The 'recursion' method only works for gradient boosting estimators, and
    unlike the 'brute' method, it does not account for the ``init``
    predictor of the boosting process. In practice this will produce the
    same values as 'brute' up to a constant offset in the target response,
    provided that ``init`` is a constant estimator (which is the default).
    However, as soon as ``init`` is not a constant estimator, the partial
    dependence values are incorrect for 'recursion'. This is not relevant for
    :class:`HistGradientBoostingClassifier
    <mrex.ensemble.HistGradientBoostingClassifier>` and
    :class:`HistGradientBoostingRegressor
    <mrex.ensemble.HistGradientBoostingRegressor>`, which do not have an
    ``init`` parameter.
    """
    check_matplotlib_support('plot_partial_dependence')  # noqa
    import matplotlib.pyplot as plt  # noqa
    from matplotlib import transforms  # noqa
    from matplotlib.ticker import MaxNLocator  # noqa
    from matplotlib.ticker import ScalarFormatter  # noqa

    # set target_idx for multi-class estimators
    if hasattr(estimator, 'classes_') and np.size(estimator.classes_) > 2:
        if target is None:
            raise ValueError('target must be specified for multi-class')
        target_idx = np.searchsorted(estimator.classes_, target)
        # searchsorted returns an insertion point, so also verify the class
        # found there really is `target`.
        if (not (0 <= target_idx < len(estimator.classes_))
                or estimator.classes_[target_idx] != target):
            raise ValueError(
                'target not in est.classes_, got {}'.format(target))
    else:
        # regression and binary classification
        target_idx = 0

    X = check_array(X)
    n_features = X.shape[1]

    # convert feature_names to list
    if feature_names is None:
        # if feature_names is None, use feature indices as name
        feature_names = [str(i) for i in range(n_features)]
    elif isinstance(feature_names, np.ndarray):
        feature_names = feature_names.tolist()
    if len(set(feature_names)) != len(feature_names):
        raise ValueError('feature_names should not contain duplicates.')

    def convert_feature(fx):
        # Map a feature given by name to its integer index.
        if isinstance(fx, str):
            try:
                fx = feature_names.index(fx)
            except ValueError:
                raise ValueError('Feature %s not in feature_names' % fx)
        return int(fx)

    # convert features into a seq of int tuples
    tmp_features = []
    for fxs in features:
        if isinstance(fxs, (numbers.Integral, str)):
            fxs = (fxs, )
        try:
            fxs = [convert_feature(fx) for fx in fxs]
        except TypeError:
            raise ValueError('Each entry in features must be either an int, '
                             'a string, or an iterable of size at most 2.')
        if not (1 <= np.size(fxs) <= 2):
            raise ValueError('Each entry in features must be either an int, '
                             'a string, or an iterable of size at most 2.')

        tmp_features.append(fxs)

    features = tmp_features

    names = []
    try:
        for fxs in features:
            names_ = []
            # explicit loop so "i" is bound for exception below
            for i in fxs:
                names_.append(feature_names[i])
            names.append(names_)
    except IndexError:
        raise ValueError('All entries of features must be less than '
                         'len(feature_names) = {0}, got {1}.'.format(
                             len(feature_names), i))

    # compute averaged predictions
    pd_result = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(partial_dependence)(estimator,
                                    X,
                                    fxs,
                                    response_method=response_method,
                                    method=method,
                                    grid_resolution=grid_resolution,
                                    percentiles=percentiles)
        for fxs in features)

    # For multioutput regression, we can only check the validity of target
    # now that we have the predictions.
    # Also note: as multiclass-multioutput classifiers are not supported,
    # multiclass and multioutput scenario are mutually exclusive. So there is
    # no risk of overwriting target_idx here.
    avg_preds, _ = pd_result[0]  # checking the first result is enough
    n_tasks = avg_preds.shape[0]
    if is_regressor(estimator) and n_tasks > 1:
        if target is None:
            raise ValueError(
                'target must be specified for multi-output regressors')
        # Valid task indices are 0 .. n_tasks - 1; accepting n_tasks itself
        # would only fail later with an IndexError on avg_preds[target_idx].
        if not 0 <= target < n_tasks:
            raise ValueError(
                'target must be in [0, n_tasks - 1], got {}.'.format(target))
        target_idx = target

    # get global min and max values of PD grouped by plot type
    # (keyed by the number of features: 1 for one-way, 2 for two-way plots)
    pdp_lim = {}
    for avg_preds, values in pd_result:
        min_pd = avg_preds[target_idx].min()
        max_pd = avg_preds[target_idx].max()
        n_fx = len(values)
        old_min_pd, old_max_pd = pdp_lim.get(n_fx, (min_pd, max_pd))
        min_pd = min(min_pd, old_min_pd)
        max_pd = max(max_pd, old_max_pd)
        pdp_lim[n_fx] = (min_pd, max_pd)

    # create contour levels for two-way plots; shared by every contour plot
    # so that the colour scales are comparable across subplots
    if 2 in pdp_lim:
        Z_level = np.linspace(*pdp_lim[2], num=8)

    if fig is None:
        fig = plt.figure()
    else:
        fig.clear()

    if line_kw is None:
        line_kw = {'color': 'green'}
    if contour_kw is None:
        contour_kw = {}

    n_cols = min(n_cols, len(features))
    n_rows = int(np.ceil(len(features) / float(n_cols)))
    for i, (fx, name, (avg_preds, values)) in enumerate(
            zip(features, names, pd_result)):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)

        if len(values) == 1:
            ax.plot(values[0], avg_preds[target_idx].ravel(), **line_kw)
        else:
            # make contour plot
            assert len(values) == 2
            XX, YY = np.meshgrid(values[0], values[1])
            Z = avg_preds[target_idx].T
            CS = ax.contour(XX,
                            YY,
                            Z,
                            levels=Z_level,
                            linewidths=0.5,
                            colors='k')
            ax.contourf(XX,
                        YY,
                        Z,
                        levels=Z_level,
                        vmax=Z_level[-1],
                        vmin=Z_level[0],
                        alpha=0.75,
                        **contour_kw)
            ax.clabel(CS, fmt='%2.2f', colors='k', fontsize=10, inline=True)

        # plot data deciles + axes labels
        # (x in data coordinates, y in axes coordinates so the ticks sit at
        # the bottom 5% of the subplot regardless of the y range)
        deciles = mquantiles(X[:, fx[0]], prob=np.arange(0.1, 1.0, 0.1))
        trans = transforms.blended_transform_factory(ax.transData,
                                                     ax.transAxes)
        ylim = ax.get_ylim()
        ax.vlines(deciles, [0], 0.05, transform=trans, color='k')
        ax.set_xlabel(name[0])
        ax.set_ylim(ylim)

        # prevent x-axis ticks from overlapping
        ax.xaxis.set_major_locator(MaxNLocator(nbins=6, prune='lower'))
        tick_formatter = ScalarFormatter()
        tick_formatter.set_powerlimits((-3, 4))
        ax.xaxis.set_major_formatter(tick_formatter)

        if len(values) > 1:
            # two-way PDP - y-axis deciles + labels
            deciles = mquantiles(X[:, fx[1]], prob=np.arange(0.1, 1.0, 0.1))
            trans = transforms.blended_transform_factory(
                ax.transAxes, ax.transData)
            xlim = ax.get_xlim()
            ax.hlines(deciles, [0], 0.05, transform=trans, color='k')
            ax.set_ylabel(name[1])
            # hline erases xlim
            ax.set_xlim(xlim)
        else:
            ax.set_ylabel('Partial dependence')

        if len(values) == 1:
            # share the y range across all one-way plots
            ax.set_ylim(pdp_lim[1])

    fig.subplots_adjust(bottom=0.15,
                        top=0.7,
                        left=0.1,
                        right=0.95,
                        wspace=0.4,
                        hspace=0.3)
Esempio n. 58
0
def dot_plot(points, intervals=None, lines=None, sections=None,
             styles=None, marker_props=None, line_props=None,
             split_names=None, section_order=None, line_order=None,
             stacked=False, styles_order=None, striped=False,
             horizontal=True, show_names="both",
             fmt_left_name=None, fmt_right_name=None,
             show_section_titles=None, ax=None):
    """
    Produce a dotplot similar in style to those in Cleveland's
    "Visualizing Data" book.  These are also known as "forest plots".

    Parameters
    ----------
    points : array_like
        The quantitative values to be plotted as markers.
    intervals : array_like
        The intervals to be plotted around the points.  The elements
        of `intervals` are either scalars or sequences of length 2.  A
        scalar indicates the half width of a symmetric interval.  A
        sequence of length 2 contains the left and right half-widths
        (respectively) of a nonsymmetric interval.  If None, no
        intervals are drawn.
    lines : array_like
        A grouping variable indicating which points/intervals are
        drawn on a common line.  If None, each point/interval appears
        on its own line.
    sections : array_like
        A grouping variable indicating which lines are grouped into
        sections.  If None, everything is drawn in a single section.
    styles : array_like
        A grouping label defining the plotting style of the markers
        and intervals.
    marker_props : dict
        A dictionary mapping style codes (the values in `styles`) to
        dictionaries defining key/value pairs to be passed as keyword
        arguments to `plot` when plotting markers.  Useful keyword
        arguments are "color", "marker", and "ms" (marker size).
        The dictionaries are copied before defaults are filled in, so
        the caller's objects are not modified.
    line_props : dict
        A dictionary mapping style codes (the values in `styles`) to
        dictionaries defining key/value pairs to be passed as keyword
        arguments to `plot` when plotting interval lines.  Useful
        keyword arguments are "color", "linestyle", "solid_capstyle",
        and "linewidth".  The dictionaries are copied before defaults
        are filled in, so the caller's objects are not modified.
    split_names : string
        If not None, this is used to split the values of `lines` into
        substrings that are drawn in the left and right margins,
        respectively.  If None, the values of `lines` are drawn in the
        left margin.
    section_order : array_like
        The section labels in the order in which they appear in the
        dotplot.
    line_order : array_like
        The line labels in the order in which they appear in the
        dotplot.
    stacked : boolean
        If True, when multiple points or intervals are drawn on the
        same line, they are offset from each other.
    styles_order : array_like
        If stacked=True, this is the order in which the point styles
        on a given line are drawn from top to bottom (if horizontal
        is True) or from left to right (if horizontal is False).  If
        None (default), the order is lexical.
    striped : boolean
        If True, every other line is enclosed in a shaded box.
    horizontal : boolean
        If True (default), the lines are drawn horizontally, otherwise
        they are drawn vertically.
    show_names : string
        Determines whether labels (names) are shown in the left and/or
        right margins (bottom/top margins if `horizontal` is False).
        If 'both', labels are drawn in both margins, if 'left', labels
        are drawn in the left or bottom margin.  If 'right', labels are
        drawn in the right or top margin.
    fmt_left_name : function
        The left/bottom margin names are passed through this function
        before drawing on the plot.
    fmt_right_name : function
        The right/top margin names are passed through this function
        before drawing on the plot.  It is only called for lines that
        actually have a right margin name.
    show_section_titles : bool or None
        If None, section titles are drawn only if there is more than
        one section.  If False/True, section titles are never/always
        drawn, respectively.
    ax : matplotlib.axes
        The axes on which the dotplot is drawn.  If None, a new axes
        is created.

    Returns
    -------
    fig : Figure
        The figure given by `ax.figure` or a new instance.

    Notes
    -----
    `points`, `intervals`, `lines`, `sections`, `styles` must all have
    the same length whenever present.

    Examples
    --------
    This is a simple dotplot with one point per line:
    >>> dot_plot(points=point_values)

    This dotplot has labels on the lines (if elements in
    `label_values` are repeated, the corresponding points appear on
    the same line):
    >>> dot_plot(points=point_values, lines=label_values)

    References
    ----------
      * Cleveland, William S. (1993). "Visualizing Data". Hobart
        Press.
      * Jacoby, William G. (2006) "The Dot Plot: A Graphical Display
        for Labeled Quantitative Values." The Political Methodologist
        14(1): 6-14.
    """

    import matplotlib.transforms as transforms

    fig, ax = utils.create_mpl_ax(ax)

    # Convert to numpy arrays if that is not what we are given.
    points = np.asarray(points)
    asarray_or_none = lambda x : None if x is None else np.asarray(x)
    intervals = asarray_or_none(intervals)
    lines = asarray_or_none(lines)
    sections = asarray_or_none(sections)
    styles = asarray_or_none(styles)

    # Total number of points
    npoint = len(points)

    # Set default line values if needed
    if lines is None:
        lines = np.arange(npoint)

    # Set default section values if needed
    if sections is None:
        sections = np.zeros(npoint)

    # Set default style values if needed
    if styles is None:
        styles = np.zeros(npoint)

    # The vertical space (in inches) for a section title
    section_title_space = 0.5

    # The number of sections
    nsect = len(set(sections))
    if section_order is not None:
        nsect = len(set(section_order))

    # The number of section titles
    if show_section_titles == False:
        draw_section_titles = False
        nsect_title = 0
    elif show_section_titles == True:
        draw_section_titles = True
        nsect_title = nsect
    else:
        draw_section_titles = nsect > 1
        nsect_title = nsect if nsect > 1 else 0

    # The total vertical space devoted to section titles.
    section_space_total = section_title_space * nsect_title

    # Add a bit of room so that points that fall at the axis limits
    # are not cut in half.
    ax.set_xmargin(0.02)
    ax.set_ymargin(0.02)

    if section_order is None:
        lines0 = list(set(sections))
        lines0.sort()
    else:
        lines0 = section_order

    if line_order is None:
        lines1 = list(set(lines))
        lines1.sort()
    else:
        lines1 = line_order

    # A map from (section,line) codes to index positions.
    lines_map = {}
    for i in range(npoint):
        if section_order is not None and sections[i] not in section_order:
            continue
        if line_order is not None and lines[i] not in line_order:
            continue
        ky = (sections[i], lines[i])
        if ky not in lines_map:
            lines_map[ky] = []
        lines_map[ky].append(i)

    # Get the size of the axes on the parent figure in inches
    bbox = ax.get_window_extent().transformed(
        fig.dpi_scale_trans.inverted())
    awidth, aheight = bbox.width, bbox.height

    # The number of lines in the plot.
    nrows = len(lines_map)

    # The positions of the lowest and highest guideline in axes
    # coordinates (for horizontal dotplots), or the leftmost and
    # rightmost guidelines (for vertical dotplots).
    bottom, top = 0, 1

    if horizontal:
        # x coordinate is data, y coordinate is axes
        trans = transforms.blended_transform_factory(ax.transData,
                                                     ax.transAxes)
    else:
        # x coordinate is axes, y coordinate is data
        trans = transforms.blended_transform_factory(ax.transAxes,
                                                     ax.transData)

    # Space used for a section title, in axes coordinates
    title_space_axes = section_title_space / aheight

    # Space between lines
    if horizontal:
        dpos = (top - bottom - nsect_title*title_space_axes) /\
            float(nrows)
    else:
        dpos = (top - bottom) / float(nrows)

    # Determine the spacing for stacked points
    if styles_order is not None:
        style_codes = styles_order
    else:
        style_codes = list(set(styles))
        style_codes.sort()
    # Order is top to bottom for horizontal plots, so need to
    # flip.
    if horizontal:
        style_codes = style_codes[::-1]
    # nval is the maximum number of points on one line.
    nval = len(style_codes)
    if nval > 1:
        stackd = dpos / (2.5*(float(nval)-1))
    else:
        stackd = 0.

    # Map from style code to its integer position
    style_codes_map = {x: style_codes.index(x) for x in style_codes}

    # Setup default marker styles.  Work on copies of the per-style
    # dictionaries so that the caller's marker_props is never mutated.
    colors = ["r", "g", "b", "y", "k", "purple", "orange"]
    if marker_props is None:
        marker_props = {}
    marker_props = {k: dict(v) for k, v in marker_props.items()}
    for j, sc in enumerate(style_codes):
        props = marker_props.setdefault(sc, {})
        if "color" not in props:
            props["color"] = colors[j % len(colors)]
        if "marker" not in props:
            props["marker"] = "o"
        if "ms" not in props:
            props["ms"] = 10 if stackd == 0 else 6

    # Setup default line styles (again on copies, see above).
    if line_props is None:
        line_props = {}
    line_props = {k: dict(v) for k, v in line_props.items()}
    for sc in style_codes:
        props = line_props.setdefault(sc, {})
        if "color" not in props:
            props["color"] = "grey"
        if "linewidth" not in props:
            props["linewidth"] = 2 if stackd > 0 else 8

    if horizontal:
        # The vertical position of the first line.
        pos = top - dpos/2 if nsect == 1 else top
    else:
        # The horizontal position of the first line.
        pos = bottom + dpos/2

    # Which margins receive labels; normalized once instead of per line.
    names_mode = show_names.lower()
    show_left = names_mode in ("left", "both")
    show_right = names_mode in ("right", "both")

    # Points that have already been labeled
    labeled = set()

    # Positions of the y axis grid lines
    ticks = []

    # Loop through the sections
    for k0 in lines0:

        # Draw a section title
        if draw_section_titles:

            if horizontal:

                y0 = pos + dpos/2 if k0 == lines0[0] else pos

                ax.fill_between((0, 1), (y0,y0),
                                (pos-0.7*title_space_axes,
                                 pos-0.7*title_space_axes),
                                color='darkgrey',
                                transform=ax.transAxes,
                                zorder=1)

                txt = ax.text(0.5, pos - 0.35*title_space_axes, k0,
                              horizontalalignment='center',
                              verticalalignment='center',
                              transform=ax.transAxes)
                txt.set_fontweight("bold")
                pos -= title_space_axes

            else:

                m = len([k for k in lines_map if k[0] == k0])

                ax.fill_between((pos-dpos/2+0.01,
                                 pos+(m-1)*dpos+dpos/2-0.01),
                                (1.01,1.01), (1.06,1.06),
                                color='darkgrey',
                                transform=ax.transAxes,
                                zorder=1, clip_on=False)

                txt = ax.text(pos + (m-1)*dpos/2, 1.02, k0,
                              horizontalalignment='center',
                              verticalalignment='bottom',
                              transform=ax.transAxes)
                txt.set_fontweight("bold")

        jrow = 0
        for k1 in lines1:

            # No data to plot
            if (k0, k1) not in lines_map:
                continue

            # Draw the guideline
            if horizontal:
                ax.axhline(pos, color='grey')
            else:
                ax.axvline(pos, color='grey')

            # Set up the labels
            if split_names is not None:
                us = k1.split(split_names)
                if len(us) >= 2:
                    left_label, right_label = us[0], us[1]
                else:
                    left_label, right_label = k1, None
            else:
                left_label, right_label = k1, None

            if fmt_left_name is not None:
                left_label = fmt_left_name(left_label)

            # A missing right label stays None (and is not drawn); do not
            # hand None to a user formatter that expects a string.
            if fmt_right_name is not None and right_label is not None:
                right_label = fmt_right_name(right_label)

            # Draw the stripe
            if striped and jrow % 2 == 0:
                if horizontal:
                    ax.fill_between((0, 1), (pos-dpos/2, pos-dpos/2),
                                    (pos+dpos/2, pos+dpos/2),
                                    color='lightgrey',
                                    transform=ax.transAxes,
                                    zorder=0)
                else:
                    ax.fill_between((pos-dpos/2, pos+dpos/2),
                                    (0, 0), (1, 1),
                                    color='lightgrey',
                                    transform=ax.transAxes,
                                    zorder=0)

            jrow += 1

            # Draw the left (or bottom, for vertical plots) margin label
            if show_left:
                if horizontal:
                    ax.text(-0.1/awidth, pos, left_label,
                            horizontalalignment="right",
                            verticalalignment='center',
                            transform=ax.transAxes,
                            family='monospace')
                else:
                    ax.text(pos, -0.1/aheight, left_label,
                            horizontalalignment="center",
                            verticalalignment='top',
                            transform=ax.transAxes,
                            family='monospace')

            # Draw the right (or top, for vertical plots) margin label
            if show_right:
                if right_label is not None:
                    if horizontal:
                        ax.text(1 + 0.1/awidth, pos, right_label,
                                horizontalalignment="left",
                                verticalalignment='center',
                                transform=ax.transAxes,
                                family='monospace')
                    else:
                        ax.text(pos, 1 + 0.1/aheight, right_label,
                                horizontalalignment="center",
                                verticalalignment='bottom',
                                transform=ax.transAxes,
                                family='monospace')

            # Save the vertical position so that we can place the
            # tick marks
            ticks.append(pos)

            # Loop over the points in one line
            for jp in lines_map[(k0,k1)]:

                # Calculate the vertical offset
                yo = 0
                if stacked:
                    yo = -dpos/5 + style_codes_map[styles[jp]]*stackd

                pt = points[jp]

                # Plot the interval
                if intervals is not None:

                    # Symmetric interval
                    if np.isscalar(intervals[jp]):
                        lcb, ucb = pt - intervals[jp],\
                            pt + intervals[jp]

                    # Nonsymmetric interval
                    else:
                        lcb, ucb = pt - intervals[jp][0],\
                            pt + intervals[jp][1]

                    # Draw the interval
                    if horizontal:
                        ax.plot([lcb, ucb], [pos+yo, pos+yo], '-',
                                transform=trans,
                                **line_props[styles[jp]])
                    else:
                        ax.plot([pos+yo, pos+yo], [lcb, ucb], '-',
                                transform=trans,
                                **line_props[styles[jp]])


                # Plot the point; only the first point of each style gets
                # a legend label.
                sl = styles[jp]
                sll = sl if sl not in labeled else None
                labeled.add(sl)
                if horizontal:
                    ax.plot([pt,], [pos+yo,], ls='None',
                            transform=trans, label=sll,
                            **marker_props[sl])
                else:
                    ax.plot([pos+yo,], [pt,], ls='None',
                            transform=trans, label=sll,
                            **marker_props[sl])

            if horizontal:
                pos -= dpos
            else:
                pos += dpos

    # Set up the axis
    if horizontal:
        ax.xaxis.set_ticks_position("bottom")
        ax.yaxis.set_ticks_position("none")
        ax.set_yticklabels([])
        ax.spines['left'].set_color('none')
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['bottom'].set_position(('axes', -0.1/aheight))
        ax.set_ylim(0, 1)
        ax.yaxis.set_ticks(ticks)
        ax.autoscale_view(scaley=False, tight=True)
    else:
        ax.yaxis.set_ticks_position("left")
        ax.xaxis.set_ticks_position("none")
        ax.set_xticklabels([])
        ax.spines['bottom'].set_color('none')
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['left'].set_position(('axes', -0.1/awidth))
        ax.set_xlim(0, 1)
        ax.xaxis.set_ticks(ticks)
        ax.autoscale_view(scalex=False, tight=True)

    return fig
Esempio n. 59
0
    tmin=time.min()//1
    tmax=time.max()//1+1

    #Create a 12x8 inch figure
    fig = plt.figure(figsize=(12,8))
    # Create a subplot for the left eye, 2,1,1 means the subplot
    # grid will be 2 rows and 1 column, and we are about to create
    # the subplot for row 1 of 2
    #"Left Eye Position"
    left_axis = fig.add_subplot(2,1,1)
    left_axis.plot(time,left_gaze_x,label="X Gaze")
    left_axis.plot(time,left_gaze_y,label="Y Gaze")
    plt.xticks(np.arange(tmin,tmax,0.5),rotation='vertical')
    # Fill in missing eye data areas of the plot with a vertical bar the full
    # height of the sub plot.
    trans = mtransforms.blended_transform_factory(left_axis.transData, left_axis.transAxes)
    left_axis.fill_between(time, 0, 1, where=left_pupil_size==0,
            facecolor='DarkRed',
            alpha=0.5, transform=trans)
    #text(0.5, 0.95, 'test', transform=fig.transFigure, horizontalalignment='center')
    left_axis.set_ylabel('Position (pixels)')
    # Left Eye Sample Sub Plot
    left_axis.set_title("Left Eye Position", fontsize=12)

    # Resize the plot x axis by 85% so that the legend , which is outside
    # the plot, will still fit in the matplotlib window.
    #
    box = left_axis.get_position()
    left_axis.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    fontP = FontProperties()
    fontP.set_size('small')
Esempio n. 60
0
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.transforms import blended_transform_factory, ScaledTranslation

fig = plt.figure(figsize=(6, 4))

ax = fig.add_subplot(1, 1, 1, aspect=1)
ax.set_xlim(0, 10)
ax.set_xticks(range(11))
ax.set_ylim(0, 5)
# Integer ticks on the y axis to match the 0..5 y range.
ax.set_yticks(range(6))

# Physical downward offset of 1.5 font sizes.  ``point`` converts
# typographic points to inches, the unit used by ``fig.dpi_scale_trans``.
point = 1 / 72
fontsize = 12
dx, dy = 0, -1.5 * fontsize * point
offset = ScaledTranslation(dx, dy, fig.dpi_scale_trans)

# Blended transform: x in data coordinates, y in axes coordinates
# (0 = bottom edge of the axes) shifted by the fixed offset above.
transform = blended_transform_factory(ax.transData, ax.transAxes + offset)

# One arrow under each integer x position, just below the x axis.
for x in range(11):
    ax.text(x,
            0,
            "↑",
            transform=transform,
            ha="center",
            va="top",
            fontsize=fontsize)

plt.tight_layout()
plt.savefig("../../figures/coordinates/transforms-blend.pdf")
plt.show()