コード例 #1
0
ファイル: plots.py プロジェクト: cyang-2014/plastid
def stacked_bar(data,
                axes=None,
                labels=None,
                lighten_by=0.1,
                cmap=None,
                **kwargs):
    """Create a stacked bar graph

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Array of data, in which each row is a stack, each column a value in that stack.

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    labels : list, optional
        Labels for each stack. If `None`, stacks are labeled sequentially by number.
        (Default: `None`)

    lighten_by : float, optional
        Amount by which to lighten sequential blocks in each stack. (Default: 0.10)

    cmap : :class:`matplotlib.colors.Colormap`, optional
        Colormap from which to generate bar colors. If supplied, one color per
        stack is sampled evenly across the colormap, overriding any `color`
        attribute in `**kwargs`. (Default: `None`)

    **kwargs : keyword arguments
        Other keyword arguments to pass to :meth:`matplotlib.axes.Axes.bar`.
        `align` defaults to ``"center"`` and `width` to ``0.8`` unless given.

    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot
    """
    fig, ax = get_fig_axes(axes)
    rows, cols = data.shape
    labels = labels if labels is not None else range(rows)

    if cmap is not None:
        # one color per stack, spread over the whole colormap
        kwargs["color"] = cmap(numpy.linspace(0, 1.0, num=rows))
    elif kwargs.get("color", None) is None:
        kwargs["color"] = [next(get_color_cycle(ax)) for _ in range(rows)]

    kwargs.setdefault("align", "center")
    kwargs.setdefault("width", 0.8)

    # stacks are centered at 0.5, 1.5, ... rows - 0.5
    x = numpy.arange(rows) + 0.5
    ax.xaxis.set_ticks(x)
    ax.xaxis.set_ticklabels(labels)
    bottoms = numpy.zeros(rows)

    for i in range(cols):
        if i > 0:
            # each successive block is drawn lighter than the one beneath it
            kwargs["color"] = lighten(kwargs["color"], amt=lighten_by)

        heights = data[:, i]
        # draw on `ax` explicitly; pyplot's "current" axes need not be `ax`
        ax.bar(x, heights, bottom=bottoms, **kwargs)
        bottoms += heights

    ax.set_xlim(-0.5, rows + 0.5)

    return fig, ax
コード例 #2
0
ファイル: plots.py プロジェクト: cyang-2014/plastid
def triangle_plot(data,
                  axes=None,
                  fn="scatter",
                  vertex_labels=None,
                  grid=None,
                  clip=True,
                  do_setup=True,
                  **kwargs):
    """Plot data lying in a plane x + y + z = k in a homogenous triangular space.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Mx2 or Mx3 list or array of points in triangular space, where
        the first column is the first coordinate, the second column
        the second, and the third, the third.

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    fn : str, optional
        Name of plotting function. Must correspond to a method of
        :class:`matplotlib.axes.Axes` (e.g. 'scatter', 'plot', 'hexbin', etc.)
        that accepts the Cartesian coordinates of the points as separate
        positional sequences, e.g. ``fn(x, y, **kwargs)``. (Default: 'scatter')

    vertex_labels : list or None, optional
        Three labels, one per vertex. If `None`, vertices aren't labeled.
        (Default: `None`)

    grid : :class:`numpy.ndarray` or None, optional
        If not `None`, draw gridlines at intervals specified in `grid`,
        as long as the grid coordinate is > 1/3 (center of triangle)
        and <= 1.0 (border). Only used when `do_setup` is `True`.

    clip : bool, optional
        If `True` clipping masks corresponding to the triangle boundaries
        will be applied to all plot elements. This also works when
        `do_setup` is `False`. (Default: `True`)

    do_setup : bool, optional
        If `True`, the plot area will be prepared. A triangle will be drawn,
        gridlines drawn, etc. Specify `False` if plotting a second dataset
        on top of an already-prepared axes (Default: `True`)

    **kwargs : keyword arguments
        Other keyword arguments to pass to function specified by `fn`.

    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot

    """
    fig, ax = get_fig_axes(axes)
    mplrc = matplotlib.rcParams
    triverts = trianglize(_triverts)
    # the drawn triangle; stays None when the axes were prepared earlier
    tripatch = None

    if do_setup:
        tripatch = matplotlib.patches.Polygon(
            triverts,
            closed=True,
            facecolor=mplrc["axes.facecolor"],
            edgecolor=mplrc["axes.edgecolor"],
            linewidth=mplrc["axes.linewidth"],
            zorder=-10)
        ax.add_patch(tripatch)

        # format axes
        ax.set_xlim((0, 1))
        ax.set_ylim((0, _triverts[:, 1].max()))
        ax.set_frame_on(False)
        ax.set_xticks([])
        ax.set_yticks([])

        # label vertices, offset (in points) away from each corner
        if vertex_labels is not None:
            l1, l2, l3 = vertex_labels

            tkwargs = {"fig": fig, "units": "points"}
            p1trans = matplotlib.transforms.offset_copy(ax.transData,
                                                        x=0,
                                                        y=8,
                                                        **tkwargs)
            p2trans = matplotlib.transforms.offset_copy(ax.transData,
                                                        x=-10,
                                                        y=-12,
                                                        **tkwargs)
            p3trans = matplotlib.transforms.offset_copy(ax.transData,
                                                        x=10,
                                                        y=-12,
                                                        **tkwargs)
            ax.text(triverts[0, 0], triverts[0, 1], l1, transform=p1trans)
            ax.text(triverts[1, 0], triverts[1, 1], l2, transform=p2trans)
            ax.text(triverts[2, 0], triverts[2, 1], l3, transform=p3trans)

        # gridlines reuse the rcParams "grid.*" style, with the prefix stripped
        # so the keys become Line2D keyword arguments (e.g. "grid.color" -> "color")
        grid_kwargs = {
            K.replace("grid.", ""): V
            for (K, V) in mplrc.items() if K.startswith("grid")
        }
        if grid is not None:
            grid = numpy.array(grid)
            remainders = (1.0 - grid) / 2
            for i, r in zip(grid, remainders):
                # only values between the center and the border give a real line
                if 1.0 / 3 < i <= 1.0:
                    points = [numpy.array([i, r, r])]
                    for _ in range(3):
                        points.append(rotate.dot(points[-1]))

                    points = numpy.array(points)
                    points = trianglize(points[:, [0, 2]])

                    myline = matplotlib.lines.Line2D(points[:, 0],
                                                     points[:, 1],
                                                     **grid_kwargs)
                    ax.add_line(myline)

    # scale data
    data = trianglize(data)

    # plot data, collecting whatever artists the plotting function returns
    artists = []
    plot_fn = getattr(ax, fn)
    res = plot_fn(*zip(*data), **kwargs)
    if isinstance(res, Artist):
        artists.append(res)
    elif isinstance(res, list):
        artists.extend([X for X in res if isinstance(X, Artist)])

    # clip
    if clip and artists:
        if tripatch is None:
            # Axes were prepared by an earlier call, so build an equivalent
            # triangle used only as a clip path. It is not added to the axes,
            # so its transform must be set to data coordinates explicitly.
            tripatch = matplotlib.patches.Polygon(triverts,
                                                  closed=True,
                                                  transform=ax.transData)
        for artist in artists:
            artist.set_clip_path(tripatch)
            artist.set_clip_on(True)

    return fig, ax
コード例 #3
0
ファイル: plots.py プロジェクト: cyang-2014/plastid
def profile_heatmap(data,
                    profile=None,
                    x=None,
                    axes=None,
                    sort_fn=sort_max_position,
                    cmap=None,
                    nancolor="#666666",
                    im_args=None,
                    plot_args=None):
    """Create a dual-paned plot in which `profile` is displayed in a top
    panel, above a heatmap showing the intensities of each row of `data`,
    optionally sorted top-to-bottom by `sort_fn`.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Array of data, in which each row is an individual aligned vector of data
        for region of interest, and each column a position in that vector

    profile : :class:`numpy.ndarray` or None
        Reduced profile of data, often a column-wise median. If not
        supplied, the column-wise :func:`numpy.nanmedian` of `data` is used.

    x : :class:`numpy.ndarray`
        Array of values for X-axis. If `None`, column indices of `data` are used.

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    sort_fn : function, optional
        Sort rows in `data` by this function before plotting. Function must
        return a :class:`numpy.ndarray` of indices corresponding to rows in `data`.
        If `None`, rows are plotted in their original order.
        (Default: sort by ascending argmax of each row)

    cmap : :class:`~matplotlib.colors.Colormap`, optional
        Colormap to use in heatmap. If not `None`, overrides any value
        in `im_args`. A copy is used, so the caller's colormap is not
        modified. (Default: `None`)

    nancolor : str or matplotlib colorspec
        Color used for plotting `nan` and other illegal or masked values

    im_args : dict or None, optional
        Keyword arguments to pass to :func:`matplotlib.pyplot.imshow`.
        Not modified by this function.

    plot_args : dict or None, optional
        Keyword arguments to pass to :func:`matplotlib.pyplot.plot`
        for plotting metagene average


    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    dict
        Dictionary of :class:`matplotlib.axes.Axes`. "top" refers to the
        panel containing the summary profile. "main" refers to the heatmap
        of individual values
    """
    fig, ax = get_fig_axes(axes)
    axes = split_axes(ax, top_height=0.2)

    if sort_fn is None:
        sort_indices = numpy.arange(data.shape[0])
    else:
        sort_indices = sort_fn(data)

    if x is None:
        x = numpy.arange(0, data.shape[1])
    else:
        x = numpy.asarray(x)

    if profile is None:
        profile = numpy.nanmedian(data, axis=0)
    else:
        profile = numpy.asarray(profile)

    # work on a private copy so the caller's dict is never modified
    im_args = {} if im_args is None else copy.deepcopy(im_args)
    plot_args = {} if plot_args is None else plot_args

    # populate with defaults
    for k, v in _heatmap_defaults.items():
        if k not in im_args:
            im_args[k] = v

    if "extent" not in im_args:
        im_args["extent"] = [x.min(), x.max(), 0, data.shape[0]]
    if "vmin" not in im_args:
        im_args["vmin"] = numpy.nanmin(data)
    if "vmax" not in im_args:
        im_args["vmax"] = numpy.nanmax(data)

    # Resolve the colormap: explicit `cmap`, then im_args["cmap"] (name or
    # Colormap), then matplotlib's default. A copy is modified and passed to
    # imshow explicitly. This keeps `set_bad` from mutating a shared/registered
    # colormap, and from being lost when imshow re-resolves a colormap name.
    if cmap is None:
        cmap = im_args.get("cmap")
    cmap = copy.deepcopy(plt.get_cmap(cmap))
    cmap.set_bad(nancolor, 1.0)
    im_args["cmap"] = cmap

    # nanmax: a column that is entirely NaN must not make the y-limit NaN
    ymax = numpy.nanmax(profile)
    top = axes["top"]
    top.plot(x, profile, **plot_args)
    top.set_ylim(0, ymax)
    top.set_xlim(x.min(), x.max())
    top.set_yticks([0, ymax])
    top.xaxis.tick_bottom()
    top.grid(True, which="both")

    axes["main"].xaxis.tick_bottom()
    axes["main"].imshow(data[sort_indices, :], **im_args)

    return fig, axes
コード例 #4
0
ファイル: plots.py プロジェクト: cyang-2014/plastid
def kde_plot(data,
             axes=None,
             color=None,
             label=None,
             alpha=0.7,
             vert=False,
             log=False,
             base=10,
             points=500,
             bw_method="scott",
             rescale=False,
             zorder=None):
    """Plot a kernel density estimate of `data` on `axes`.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Array of data

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    color : matplotlib colorspec, optional
        Color to use for plotting (Default: use next in matplotlibrc)

    label : str, optional
        Name of data series (used for legend; default: `None`)

    alpha : float, optional
        Amount of alpha transparency to use (Default: 0.7)

    vert : bool, optional
        If true, plot kde vertically

    log : bool, optional
        If `True`, `data` is log-transformed before the kde is estimated.
        Data are converted back to non-log space afterwards.

    base : 2, 10, or :obj:`numpy.e`, optional
        If `log` is `True`, this serves as the base of the log space.
        If `log` is `False`, this is ignored. (Default: 10)

    points : int
        Number of points over which to evaluate kde. (Default: 500)

    bw_method : str
        Bandwith estimation method. See documentation for
        :obj:`scipy.stats.gaussian_kde`. (Default: "scott")

    rescale : bool, optional
        If `True`, divide the density by its maximum so the peak is 1.
        (Default: `False`)

    zorder : float or None, optional
        Drawing order for the filled area and its outline. If `None`,
        matplotlib's defaults are used. (Default: `None`)


    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot
    """
    fig, axes = get_fig_axes(axes)

    if color is None:
        color = next(get_color_cycle(axes))

    a, b = get_kde(data,
                   log=log,
                   base=base,
                   points=points,
                   bw_method=bw_method)

    if rescale:
        b /= b.max()

    # Only the outline drawn with `plot` carries the label. Labelling the fill
    # as well would give the series two legend entries.
    fbargs = {"alpha": alpha, "facecolor": lighten(color), "edgecolor": color}
    lineargs = {"color": color, "alpha": alpha, "label": label}
    if zorder is not None:
        # only pass zorder when given, so matplotlib defaults apply otherwise
        fbargs["zorder"] = zorder
        lineargs["zorder"] = zorder

    if vert:
        axes.fill_betweenx(a, b, 0, **fbargs)
        axes.plot(b, a, **lineargs)
        if log:
            axes.semilogy()
    else:
        axes.fill_between(a, b, 0, **fbargs)
        axes.plot(a, b, **lineargs)
        if log:
            axes.semilogx()

    return fig, axes
コード例 #5
0
ファイル: plots.py プロジェクト: joshuagryphon/plastid
def stacked_bar(data,
                axes=None,
                labels=None,
                lighten_by=0.1,
                cmap=None,
                **kwargs):
    """Create a stacked bar graph

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Array of data, in which each row is a stack, each column a value in that stack.

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    labels : list, optional
        Labels for each stack. If `None`, stacks are labeled sequentially by number.
        (Default: `None`)

    lighten_by : float, optional
        Amount by which to lighten sequential blocks in each stack. (Default: 0.10)

    cmap : :class:`matplotlib.colors.Colormap`, optional
        Colormap from which to generate bar colors. If supplied, one color per
        stack is sampled evenly across the colormap, overriding any `color`
        attribute in `**kwargs`. (Default: `None`)

    **kwargs : keyword arguments
        Other keyword arguments to pass to :meth:`matplotlib.axes.Axes.bar`.
        `align` defaults to ``"center"`` and `width` to ``0.8`` unless given.

    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot
    """
    fig, ax = get_fig_axes(axes)
    rows, cols = data.shape
    labels = labels if labels is not None else range(rows)

    if cmap is not None:
        # one color per stack, spread over the whole colormap
        kwargs["color"] = cmap(numpy.linspace(0, 1.0, num=rows))
    elif kwargs.get("color", None) is None:
        kwargs["color"] = [next(get_color_cycle(ax)) for _ in range(rows)]

    kwargs.setdefault("align", "center")
    kwargs.setdefault("width", 0.8)

    # stacks are centered at 0.5, 1.5, ... rows - 0.5
    x = numpy.arange(rows) + 0.5
    ax.xaxis.set_ticks(x)
    ax.xaxis.set_ticklabels(labels)
    bottoms = numpy.zeros(rows)

    for i in range(cols):
        if i > 0:
            # each successive block is drawn lighter than the one beneath it
            kwargs["color"] = lighten(kwargs["color"], amt=lighten_by)

        heights = data[:, i]
        # draw on `ax` explicitly; pyplot's "current" axes need not be `ax`
        ax.bar(x, heights, bottom=bottoms, **kwargs)
        bottoms += heights

    ax.set_xlim(-0.5, rows + 0.5)

    return fig, ax
コード例 #6
0
ファイル: plots.py プロジェクト: joshuagryphon/plastid
def profile_heatmap(data,
                    profile=None,
                    x=None,
                    axes=None,
                    sort_fn=sort_max_position,
                    cmap=None,
                    nancolor="#666666",
                    im_args=None,
                    plot_args=None):
    """Create a dual-paned plot in which `profile` is displayed in a top
    panel, above a heatmap showing the intensities of each row of `data`,
    optionally sorted top-to-bottom by `sort_fn`.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Array of data, in which each row is an individual aligned vector of data
        for region of interest, and each column a position in that vector

    profile : :class:`numpy.ndarray` or None
        Reduced profile of data, often a column-wise median. If not
        supplied, the column-wise :func:`numpy.nanmedian` of `data` is used.

    x : :class:`numpy.ndarray`
        Array of values for X-axis. If `None`, column indices of `data` are used.

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    sort_fn : function, optional
        Sort rows in `data` by this function before plotting. Function must
        return a :class:`numpy.ndarray` of indices corresponding to rows in `data`.
        If `None`, rows are plotted in their original order.
        (Default: sort by ascending argmax of each row)

    cmap : :class:`~matplotlib.colors.Colormap`, optional
        Colormap to use in heatmap. If not `None`, overrides any value
        in `im_args`. A copy is used, so the caller's colormap is not
        modified. (Default: `None`)

    nancolor : str or matplotlib colorspec
        Color used for plotting `nan` and other illegal or masked values

    im_args : dict or None, optional
        Keyword arguments to pass to :func:`matplotlib.pyplot.imshow`.
        Not modified by this function.

    plot_args : dict or None, optional
        Keyword arguments to pass to :func:`matplotlib.pyplot.plot`
        for plotting metagene average


    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    dict
        Dictionary of :class:`matplotlib.axes.Axes`. "top" refers to the
        panel containing the summary profile. "main" refers to the heatmap
        of individual values
    """
    fig, ax = get_fig_axes(axes)
    axes = split_axes(ax, top_height=0.2)

    if sort_fn is None:
        sort_indices = numpy.arange(data.shape[0])
    else:
        sort_indices = sort_fn(data)

    if x is None:
        x = numpy.arange(0, data.shape[1])
    else:
        x = numpy.asarray(x)

    if profile is None:
        profile = numpy.nanmedian(data, axis=0)
    else:
        profile = numpy.asarray(profile)

    # work on a private copy so the caller's dict is never modified
    im_args = {} if im_args is None else copy.deepcopy(im_args)
    plot_args = {} if plot_args is None else plot_args

    # populate with defaults
    for k, v in _heatmap_defaults.items():
        if k not in im_args:
            im_args[k] = v

    if "extent" not in im_args:
        im_args["extent"] = [x.min(), x.max(), 0, data.shape[0]]
    if "vmin" not in im_args:
        im_args["vmin"] = numpy.nanmin(data)
    if "vmax" not in im_args:
        im_args["vmax"] = numpy.nanmax(data)

    # Resolve the colormap: explicit `cmap`, then im_args["cmap"] (name or
    # Colormap), then matplotlib's default. A copy is modified and passed to
    # imshow explicitly. This keeps `set_bad` from mutating a shared/registered
    # colormap, and from being lost when imshow re-resolves a colormap name.
    if cmap is None:
        cmap = im_args.get("cmap")
    cmap = copy.deepcopy(plt.get_cmap(cmap))
    cmap.set_bad(nancolor, 1.0)
    im_args["cmap"] = cmap

    # nanmax: a column that is entirely NaN must not make the y-limit NaN
    ymax = numpy.nanmax(profile)
    top = axes["top"]
    top.plot(x, profile, **plot_args)
    top.set_ylim(0, ymax)
    top.set_xlim(x.min(), x.max())
    top.set_yticks([0, ymax])
    top.xaxis.tick_bottom()
    top.grid(True, which="both")

    axes["main"].xaxis.tick_bottom()
    axes["main"].imshow(data[sort_indices, :], **im_args)

    return fig, axes
コード例 #7
0
ファイル: plots.py プロジェクト: joshuagryphon/plastid
def triangle_plot(data,
                  axes=None,
                  fn="scatter",
                  vertex_labels=None,
                  grid=None,
                  clip=True,
                  do_setup=True,
                  **kwargs):
    """Plot data lying in a plane x + y + z = k in a homogenous triangular space.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Mx2 or Mx3 list or array of points in triangular space, where
        the first column is the first coordinate, the second column
        the second, and the third, the third.

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    fn : str, optional
        Name of plotting function. Must correspond to a method of
        :class:`matplotlib.axes.Axes` (e.g. 'scatter', 'plot', 'hexbin', etc.)
        that accepts the Cartesian coordinates of the points as separate
        positional sequences, e.g. ``fn(x, y, **kwargs)``. (Default: 'scatter')

    vertex_labels : list or None, optional
        Three labels, one per vertex. If `None`, vertices aren't labeled.
        (Default: `None`)

    grid : :class:`numpy.ndarray` or None, optional
        If not `None`, draw gridlines at intervals specified in `grid`,
        as long as the grid coordinate is > 1/3 (center of triangle)
        and <= 1.0 (border). Only used when `do_setup` is `True`.

    clip : bool, optional
        If `True` clipping masks corresponding to the triangle boundaries
        will be applied to all plot elements. This also works when
        `do_setup` is `False`. (Default: `True`)

    do_setup : bool, optional
        If `True`, the plot area will be prepared. A triangle will be drawn,
        gridlines drawn, etc. Specify `False` if plotting a second dataset
        on top of an already-prepared axes (Default: `True`)

    **kwargs : keyword arguments
        Other keyword arguments to pass to function specified by `fn`.

    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot

    """
    fig, ax = get_fig_axes(axes)
    mplrc = matplotlib.rcParams
    triverts = trianglize(_triverts)
    # the drawn triangle; stays None when the axes were prepared earlier
    tripatch = None

    if do_setup:
        tripatch = matplotlib.patches.Polygon(
            triverts,
            closed=True,
            facecolor=mplrc["axes.facecolor"],
            edgecolor=mplrc["axes.edgecolor"],
            linewidth=mplrc["axes.linewidth"],
            zorder=-10)
        ax.add_patch(tripatch)

        # format axes
        ax.set_xlim((0, 1))
        ax.set_ylim((0, _triverts[:, 1].max()))
        ax.set_frame_on(False)
        ax.set_xticks([])
        ax.set_yticks([])

        # label vertices, offset (in points) away from each corner
        if vertex_labels is not None:
            l1, l2, l3 = vertex_labels

            tkwargs = {"fig": fig, "units": "points"}
            p1trans = matplotlib.transforms.offset_copy(ax.transData,
                                                        x=0,
                                                        y=8,
                                                        **tkwargs)
            p2trans = matplotlib.transforms.offset_copy(ax.transData,
                                                        x=-10,
                                                        y=-12,
                                                        **tkwargs)
            p3trans = matplotlib.transforms.offset_copy(ax.transData,
                                                        x=10,
                                                        y=-12,
                                                        **tkwargs)
            ax.text(triverts[0, 0], triverts[0, 1], l1, transform=p1trans)
            ax.text(triverts[1, 0], triverts[1, 1], l2, transform=p2trans)
            ax.text(triverts[2, 0], triverts[2, 1], l3, transform=p3trans)

        # gridlines reuse the rcParams "grid.*" style, with the prefix stripped
        # so the keys become Line2D keyword arguments (e.g. "grid.color" -> "color")
        grid_kwargs = {
            K.replace("grid.", ""): V
            for (K, V) in mplrc.items() if K.startswith("grid")
        }
        if grid is not None:
            grid = numpy.array(grid)
            remainders = (1.0 - grid) / 2
            for i, r in zip(grid, remainders):
                # only values between the center and the border give a real line
                if 1.0 / 3 < i <= 1.0:
                    points = [numpy.array([i, r, r])]
                    for _ in range(3):
                        points.append(rotate.dot(points[-1]))

                    points = numpy.array(points)
                    points = trianglize(points[:, [0, 2]])

                    myline = matplotlib.lines.Line2D(points[:, 0],
                                                     points[:, 1],
                                                     **grid_kwargs)
                    ax.add_line(myline)

    # scale data
    data = trianglize(data)

    # plot data, collecting whatever artists the plotting function returns
    artists = []
    plot_fn = getattr(ax, fn)
    res = plot_fn(*zip(*data), **kwargs)
    if isinstance(res, Artist):
        artists.append(res)
    elif isinstance(res, list):
        artists.extend([X for X in res if isinstance(X, Artist)])

    # clip
    if clip and artists:
        if tripatch is None:
            # Axes were prepared by an earlier call, so build an equivalent
            # triangle used only as a clip path. It is not added to the axes,
            # so its transform must be set to data coordinates explicitly.
            tripatch = matplotlib.patches.Polygon(triverts,
                                                  closed=True,
                                                  transform=ax.transData)
        for artist in artists:
            artist.set_clip_path(tripatch)
            artist.set_clip_on(True)

    return fig, ax
コード例 #8
0
ファイル: plots.py プロジェクト: joshuagryphon/plastid
def kde_plot(data,
             axes=None,
             color=None,
             label=None,
             alpha=0.7,
             vert=False,
             log=False,
             base=10,
             points=500,
             bw_method="scott",
             rescale=False,
             zorder=None):
    """Plot a kernel density estimate of `data` on `axes`.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Array of data

    axes : :class:`matplotlib.axes.Axes` or `None`, optional
        Axes in which to place plot. If `None`, a new figure is generated.
        (Default: `None`)

    color : matplotlib colorspec, optional
        Color to use for plotting (Default: use next in matplotlibrc)

    label : str, optional
        Name of data series (used for legend; default: `None`)

    alpha : float, optional
        Amount of alpha transparency to use (Default: 0.7)

    vert : bool, optional
        If true, plot kde vertically

    log : bool, optional
        If `True`, `data` is log-transformed before the kde is estimated.
        Data are converted back to non-log space afterwards.

    base : 2, 10, or :obj:`numpy.e`, optional
        If `log` is `True`, this serves as the base of the log space.
        If `log` is `False`, this is ignored. (Default: 10)

    points : int
        Number of points over which to evaluate kde. (Default: 500)

    bw_method : str
        Bandwith estimation method. See documentation for
        :obj:`scipy.stats.gaussian_kde`. (Default: "scott")

    rescale : bool, optional
        If `True`, divide the density by its maximum so the peak is 1.
        (Default: `False`)

    zorder : float or None, optional
        Drawing order for the filled area and its outline. If `None`,
        matplotlib's defaults are used. (Default: `None`)


    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot
    """
    fig, axes = get_fig_axes(axes)

    if color is None:
        color = next(get_color_cycle(axes))

    a, b = get_kde(data,
                   log=log,
                   base=base,
                   points=points,
                   bw_method=bw_method)

    if rescale:
        b /= b.max()

    # Only the outline drawn with `plot` carries the label. Labelling the fill
    # as well would give the series two legend entries.
    fbargs = {"alpha": alpha, "facecolor": lighten(color), "edgecolor": color}
    lineargs = {"color": color, "alpha": alpha, "label": label}
    if zorder is not None:
        # only pass zorder when given, so matplotlib defaults apply otherwise
        fbargs["zorder"] = zorder
        lineargs["zorder"] = zorder

    if vert:
        axes.fill_betweenx(a, b, 0, **fbargs)
        axes.plot(b, a, **lineargs)
        if log:
            axes.semilogy()
    else:
        axes.fill_between(a, b, 0, **fbargs)
        axes.plot(a, b, **lineargs)
        if log:
            axes.semilogx()

    return fig, axes