Example #1
0
def std_plot(ax,
             xlabel,
             ylabel,
             zlabel,
             title=None,
             legendtitle=None,
             bbox_to_anchor=None,
             labelspacing=1.2,
             borderpad=1,
             handletextpad=0.5,
             legendsort=False,
             markerscale=None,
             xlim=None,
             ylim=None,
             xbins=None,
             ybins=None,
             cbar=None,
             cbarlabel=None,
             moveyaxis=False,
             sns=False,
             left=True,
             rotation=None,
             xticklabel=None,
             yticklabel=None,
             zticklabel=None,
             fontscale=1):
    """Apply the house plotting style to ``ax`` and return it.

    Font sizes, spine/tick widths and label paddings are multiplied by
    ``fontscale``. The base font dicts (``fonttitle``, ``fontlabel``,
    ``fontticklabel``, ``fontlegend``, ``fontcbarlabel``,
    ``fontcbarticklabel``) and the base ``fontsize`` are module-level names
    defined outside this function. Each font dict must contain a ``'size'``
    key. These dicts are only read, never modified.

    Args:
        ax: Axes to style. When ``zlabel`` is given it must expose the
            z-axis methods (``set_zticks``/``set_zlabel``).
        xlabel, ylabel: Axis labels.
        zlabel: z-axis label, or None to skip z-label styling.
        title: Optional axes title.
        legendtitle: When not None, the legend is (re)built with this title.
        bbox_to_anchor, labelspacing, borderpad, handletextpad, markerscale:
            Forwarded to ``ax.legend``.
        legendsort: Re-order legend entries by ``int(label)``.
        xlim, ylim: Optional axis limits.
        xbins, ybins: Optional ``MaxNLocator`` bin counts for the ticks.
        cbar: Optional colorbar to style. ``cbarlabel`` is its label.
        moveyaxis: Move the left spine to x = 0 in data coordinates.
        sns: Drop the first legend handle/label pair (presumably the hue
            title entry that seaborn inserts -- confirm with callers).
        left: Show the left spine (True) or the right spine (False).
        rotation: Optional rotation applied to x, y and z tick labels.
        xticklabel, yticklabel, zticklabel: Optional explicit tick labels.
        fontscale: Scale factor. The default 1 means "derive it from the
            figure height", i.e. height_in_inches / 2.

    Returns:
        The same ``ax``.

    Raises:
        TypeError: if the scale must be derived and ``ax`` is neither a
            matplotlib Figure nor Axes.
    """

    def autoscale(fig):
        """Return the font scale of ``fig``: its height relative to 2 in."""
        if isinstance(fig, matplotlib.figure.Figure):
            width, height = fig.get_size_inches()
        elif isinstance(fig, matplotlib.axes.Axes):
            width, height = fig.figure.get_size_inches()
        else:
            # Previously this fell through to an UnboundLocalError.
            raise TypeError("std_plot expects a matplotlib Figure or Axes, "
                            "got %s" % type(fig).__name__)
        scale = height / 2
        if width / scale > 8:
            warnings.warn(
                "Please reset fig's width. When scaling the height to 2 in, the scaled width '%.2f' is large than 8"
                % (width / scale), UserWarning)
        return scale

    def scaled(fontdict):
        """Return a copy of ``fontdict`` with its 'size' scaled by fontscale."""
        # Work on a copy: scaling the shared module-level dicts in place and
        # dividing back at the end leaked scaled sizes whenever something
        # raised in between, drifted through float round-off, and scaled
        # aliased dicts more than once.
        fontdict = dict(fontdict)
        fontdict['size'] = fontdict['size'] * fontscale
        return fontdict

    def style_ticklabels(set_labels, get_labels, labels, **text_kwargs):
        """Re-set tick labels (explicit ones, else the current ones)."""
        if labels is None:
            labels = get_labels()
        # Only touch the rotation when the caller asked for one.
        if rotation is not None:
            text_kwargs['rotation'] = rotation
        set_labels(labels, **text_kwargs)

    def styled_legend(*handles_labels):
        """(Re)build the legend of ``ax`` in the house style and return it."""
        legend = ax.legend(*handles_labels,
                           title=legendtitle,
                           prop=font_legend,
                           bbox_to_anchor=bbox_to_anchor,
                           labelspacing=labelspacing,
                           borderpad=borderpad,
                           handletextpad=handletextpad,
                           edgecolor="#000000",
                           fancybox=False,
                           markerscale=markerscale)
        legend.get_frame().set_linewidth(0.5 * fontscale)
        legend.get_title().set_fontweight('normal')
        legend.get_title().set_fontsize(fontscale * fontsize)
        return legend

    # Reference point noted by the author: height = 2 in <-> font = 6.5.
    if fontscale == 1:
        fontscale = autoscale(ax)
    font_title = scaled(fonttitle)
    font_label = scaled(fontlabel)
    font_ticklabel = scaled(fontticklabel)
    font_legend = scaled(fontlegend)
    font_cbarlabel = scaled(fontcbarlabel)
    font_cbarticklabel = scaled(fontcbarticklabel)
    labelpad = (fontsize - 1) * fontscale

    # Limits and tick locations.
    pyplot.draw()
    if xlim is not None:
        ax.set(xlim=xlim)
    if ylim is not None:
        ax.set(ylim=ylim)
    if xbins is not None:
        locator = MaxNLocator(nbins=xbins)
        locator.set_axis(ax.xaxis)
        ax.set_xticks(locator())
    if ybins is not None:
        locator = MaxNLocator(nbins=ybins)
        locator.set_axis(ax.yaxis)
        ax.set_yticks(locator())
    pyplot.draw()
    # Freeze the current tick positions so the labels set below stay aligned.
    ax.set_xticks(ax.get_xticks())
    ax.set_yticks(ax.get_yticks())

    # Axis labels.
    ax.set_xlabel(xlabel, fontdict=font_label, labelpad=labelpad)
    ax.set_ylabel(ylabel, fontdict=font_label, labelpad=labelpad)
    if zlabel is not None:
        ax.set_zticks(ax.get_zticks())
        ax.set_zlabel(zlabel, fontdict=font_label, labelpad=labelpad)

    # Tick labels.
    style_ticklabels(ax.set_xticklabels, ax.get_xticklabels, xticklabel,
                     fontdict=font_ticklabel)
    style_ticklabels(ax.set_yticklabels, ax.get_yticklabels, yticklabel,
                     fontdict=font_ticklabel)
    # z tick labels are best effort: 2-D axes have no z-axis methods, so the
    # attribute lookup raises AttributeError and this step is skipped.
    try:
        if zticklabel is None and rotation is None:
            # Unlike x/y, the default z labels are sized from the base fontsize.
            ax.set_zticklabels(ax.get_zticklabels(), size=fontsize * fontscale)
        else:
            style_ticklabels(ax.set_zticklabels, ax.get_zticklabels,
                             zticklabel, fontdict=font_ticklabel)
    except AttributeError:
        pass

    # Spines and ticks.
    if moveyaxis:
        ax.spines['left'].set_position(('data', 0))
    ax.spines['left'].set_visible(left)
    ax.spines['right'].set_visible(not left)
    ax.spines['top'].set_visible(False)
    for side in ('right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(0.5 * fontscale)
    for side in ('bottom', 'left', 'right'):
        ax.spines[side].set_color('k')
    ax.tick_params(direction='out', pad=2 * fontscale, width=0.5 * fontscale)

    if title is not None:
        ax.set_title(title, fontdict=font_title)

    # Legend.
    if legendtitle is not None:
        styled_legend()
        if legendsort:
            handles, labels = ax.get_legend_handles_labels()
            # Labels are expected to be integer-like strings.
            labels, handles = zip(*sorted(zip(labels, handles),
                                          key=lambda pair: int(pair[0])))
            styled_legend(handles, labels)
        if sns:
            handles, labels = ax.get_legend_handles_labels()
            styled_legend(handles[1:], labels[1:])

    # Colorbar (its tick labels are styled as y tick labels).
    if cbar is not None:
        cbar.ax.tick_params(direction='out',
                            pad=3 * fontscale,
                            width=0,
                            length=0)
        # Lower-case 'rotation': capitalized Artist kwargs are rejected by
        # current matplotlib versions.
        cbar.set_label(cbarlabel,
                       fontdict=font_cbarlabel,
                       rotation=270,
                       labelpad=fontscale * (fontsize + 1))
        cbar.ax.set_yticks(cbar.ax.get_yticks())
        cbar.ax.set_yticklabels(cbar.ax.get_yticklabels(),
                                fontdict=font_cbarticklabel)
    return ax