Esempio n. 1
0
    def do_tight_layout(fig, axes, suptitle, **kwargs):
        """Compute tight-layout subplot parameters for ``fig`` and apply them.

        Parameters
        ----------
        fig : figure
            Figure whose subplot parameters are computed and adjusted.
        axes : mapping
            Mapping whose values are the axes to take into account.
        suptitle : object
            If truthy, the computed ``top`` is scaled by 0.95 (presumably to
            leave room for a figure-level title).
        **kwargs
            Optional overrides for ``left``, ``bottom``, ``right``, ``top``,
            ``wspace`` and ``hspace``; a given value replaces the computed one.
        """
        renderer = get_renderer(fig)
        axeslist = list(axes.values())
        subplots_list = list(get_subplotspec_list(axeslist))
        computed = get_tight_layout_figure(fig,
                                           axeslist,
                                           subplots_list,
                                           renderer,
                                           pad=1.08,
                                           h_pad=0,
                                           w_pad=0,
                                           rect=None)

        # The 0.95 shrink only applies to the computed top; an explicit
        # ``top`` keyword is used as given.
        top = computed['top']
        if suptitle:
            top = top * .95

        # Adjust the figure that was measured, not whichever figure pyplot
        # currently considers active.
        fig.subplots_adjust(
            left=kwargs.get('left', computed['left']),
            bottom=kwargs.get('bottom', computed['bottom']),
            right=kwargs.get('right', computed['right']),
            top=kwargs.get('top', top),
            # Computed spacings get a 10% margin; missing ones count as 0.
            wspace=kwargs.get('wspace', computed.get('wspace', 0) * 1.1),
            hspace=kwargs.get('hspace', computed.get('hspace', 0) * 1.1))
Esempio n. 2
0
    def do_tight_layout(fig, axes, suptitle, **kwargs):
        """Compute tight-layout subplot parameters for ``fig`` and apply them.

        Parameters
        ----------
        fig : figure
            Figure whose subplot parameters are computed and adjusted.
        axes : mapping
            Mapping whose values are the axes to take into account.
        suptitle : object
            If truthy, the computed ``top`` is scaled by 0.95 (presumably to
            leave room for a figure-level title).
        **kwargs
            Optional overrides for ``left``, ``bottom``, ``right``, ``top``,
            ``wspace`` and ``hspace``; a given value replaces the computed one.
        """
        renderer = get_renderer(fig)
        axeslist = list(axes.values())
        subplots_list = list(get_subplotspec_list(axeslist))
        # ``pad`` is deliberately not passed, so the helper's own default is used.
        computed = get_tight_layout_figure(
            fig,
            axeslist,
            subplots_list,
            renderer,
            h_pad=0,
            w_pad=0,
            rect=None,
        )

        # The 0.95 shrink only applies to the computed top; an explicit
        # ``top`` keyword is used as given.
        top = computed["top"]
        if suptitle:
            top = top * 0.95

        # Adjust the figure that was measured, not whichever figure pyplot
        # currently considers active.
        fig.subplots_adjust(
            left=kwargs.get("left", computed["left"]),
            bottom=kwargs.get("bottom", computed["bottom"]),
            right=kwargs.get("right", computed["right"]),
            top=kwargs.get("top", top),
            # Computed spacings get a 10% margin; missing ones count as 0.
            wspace=kwargs.get("wspace", computed.get("wspace", 0) * 1.1),
            hspace=kwargs.get("hspace", computed.get("hspace", 0) * 1.1),
        )
Esempio n. 3
0
    def tight_layout(self, figure, renderer=None,
                     pad=1.08, h_pad=None, w_pad=None, rect=None):
        """
        Update this grid so its subplots fit *figure* with the given padding.

        Parameters
        ----------
        figure : figure
            Figure whose ``axes`` are laid out.
        renderer : renderer, optional
            Renderer handed to the layout helper.  When None, it is obtained
            with ``tight_layout.get_renderer(figure)``.
        pad : float
            Padding between the figure edge and the edges of subplots, as a
            fraction of the font-size.
        h_pad, w_pad : float, optional
            Padding (height/width) between edges of adjacent subplots;
            forwarded unchanged to the layout helper.
        rect : tuple of 4 floats, optional
            (left, bottom, right, top) rectangle in normalized figure
            coordinates that the whole subplots area (including labels) will
            fit into; forwarded unchanged to the layout helper.
        """
        specs = tight_layout.get_subplotspec_list(figure.axes, grid_spec=self)
        if None in specs:
            # Some axes have no subplot spec for this grid; warn but carry on.
            message = ("This figure includes Axes that are not compatible "
                       "with tight_layout, so results might be incorrect.")
            warnings.warn(message)

        active_renderer = renderer
        if active_renderer is None:
            active_renderer = tight_layout.get_renderer(figure)

        padding = dict(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect)
        layout = tight_layout.get_tight_layout_figure(
            figure, figure.axes, specs, active_renderer, **padding)
        self.update(**layout)
Esempio n. 4
0
    def tight_layout(self, figure, renderer=None,
                     pad=1.08, h_pad=None, w_pad=None, rect=None):
        """
        Update this grid so its subplots fit *figure* with the given padding.

        Parameters
        ----------
        figure : figure
            Figure whose ``axes`` are laid out.
        renderer : renderer, optional
            Renderer handed to the layout helper.  When None, it is obtained
            with ``tight_layout.get_renderer(figure)``.
        pad : float
            Padding between the figure edge and the edges of subplots, as a
            fraction of the font-size.
        h_pad, w_pad : float, optional
            Padding (height/width) between edges of adjacent subplots;
            forwarded unchanged to the layout helper.
        rect : tuple of 4 floats, optional
            (left, bottom, right, top) rectangle in normalized figure
            coordinates that the whole subplots area (including labels) will
            fit into; forwarded unchanged to the layout helper.
        """
        specs = tight_layout.get_subplotspec_list(figure.axes, grid_spec=self)
        if None in specs:
            # Some axes have no subplot spec for this grid; warn but carry on.
            message = ("This figure includes Axes that are not compatible "
                       "with tight_layout, so results might be incorrect.")
            warnings.warn(message)

        active_renderer = renderer
        if active_renderer is None:
            active_renderer = tight_layout.get_renderer(figure)

        padding = dict(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect)
        layout = tight_layout.get_tight_layout_figure(
            figure, figure.axes, specs, active_renderer, **padding)
        self.update(**layout)
Esempio n. 5
0
def save_axis_in_file(fig, ax, dirname, filename):
    """Save the region of ``fig`` covered by ``ax`` as PNG, SVG and PDF files.

    The files are written to ``dirname`` as ``filename`` plus the extension,
    each cropped to the tight bounding box of ``ax``.

    Parameters
    ----------
    fig : figure
        Figure that contains ``ax``; this is the figure that gets saved.
    ax : axes
        Axes whose tight bounding box defines the saved region.
    dirname : str
        Directory the files are written to.
    filename : str
        File name without extension.
    """
    # (extension, extra savefig options); the PNG is rendered at 1000 dpi.
    targets = (
        ('.png', {'dpi': 1000}),
        ('.svg', {}),
        ('.pdf', {}),
    )
    for suffix, options in targets:
        # The extent is queried again before every save instead of being
        # reused, so each file is cropped from a fresh renderer lookup.
        renderer = tight_layout.get_renderer(fig)
        tight_bbox = ax.get_tightbbox(renderer)
        # bbox_inches is mapped through the inverse of the figure's dpi scale.
        extent = tight_bbox.transformed(fig.dpi_scale_trans.inverted())
        # Save the figure the bbox was measured on, not pyplot's current one.
        fig.savefig(os.path.join(dirname, filename + suffix),
                    bbox_inches=extent,
                    **options)
Esempio n. 6
0
    def make_grid(self, fig=None):
        """Build the grid layout, sized to leave room for the label text.

        Parameters
        ----------
        fig : figure, optional
            Figure to measure (and, when ``self._element_size`` is set,
            resize).  Defaults to ``plt.gcf()``.

        Returns
        -------
        dict
            Grid slices under the keys ``'intersections'``, ``'matrix'``,
            ``'shading'`` and ``'totals'``, plus the grid itself under
            ``'gs'``.
        """
        if fig is None:
            fig = plt.gcf()

        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        # Render the category names once, off to the side, just to measure
        # how wide the label column has to be; the probe is removed again.
        renderer = get_renderer(fig)
        probe = fig.text(0, 0, '\n'.join(self.totals.index.values))
        label_width = probe.get_window_extent(renderer=renderer).width
        probe.remove()

        MAGIC_MARGIN = 10  # FIXME
        fig_width = self._reorient(
            fig.get_window_extent(renderer=renderer)).width
        # Number of grid columns taken by the plots (everything but text).
        plot_cols = n_inters + self._totals_plot_elements

        if self._element_size is not None:
            # Fixed element size: derive the column width from it and resize
            # the figure to match.
            fig = self._reorient(fig)
            render_ratio = fig_width / fig.get_figwidth()
            col_width = self._element_size / 72 * render_ratio
            fig_width = col_width * plot_cols + MAGIC_MARGIN + label_width
            fig.set_figwidth(fig_width / render_ratio)
            fig.set_figheight(
                (col_width *
                 (n_cats + self._intersection_plot_elements)) / render_ratio)
        else:
            # Free element size: split what is left after the text evenly.
            col_width = (fig_width - label_width - MAGIC_MARGIN) / plot_cols

        # Whole grid columns needed to hold the label text.
        text_cols = int(np.ceil(fig_width / col_width - plot_cols))

        grid_factory = self._reorient(matplotlib.gridspec.GridSpec)
        grid_shape = self._swapaxes(
            n_cats + self._intersection_plot_elements,
            n_inters + text_cols + self._totals_plot_elements)
        gridspec = grid_factory(*grid_shape, hspace=1)

        if not self._horizontal:
            return {
                'intersections': gridspec[-n_inters:, n_cats:],
                'matrix': gridspec[-n_inters:, :n_cats],
                'shading': gridspec[:, :n_cats],
                'totals': gridspec[:self._totals_plot_elements, :n_cats],
                'gs': gridspec
            }
        return {
            'intersections': gridspec[:-n_cats, -n_inters:],
            'matrix': gridspec[-n_cats:, -n_inters:],
            'shading': gridspec[-n_cats:, :],
            'totals': gridspec[-n_cats:, :self._totals_plot_elements],
            'gs': gridspec
        }
Esempio n. 7
0
def save_axis_in_file(fig, ax, dirname, filename):
    """Save the region of ``fig`` covered by ``ax`` as a PDF file.

    The title of ``ax`` is cleared first.  Unless ``dirname`` contains the
    substring ``'sst'``, the x and y labels are replaced by a single space
    as well.  Note that these changes to ``ax`` are not undone afterwards.

    Parameters
    ----------
    fig : figure
        Figure that contains ``ax``; this is the figure that gets saved.
    ax : axes
        Axes whose tight bounding box defines the saved region.
    dirname : str
        Directory the file is written to.
    filename : str
        File name without extension; ``'.pdf'`` is appended.
    """
    ax.set_title("")
    if 'sst' not in dirname:
        ax.set_xlabel(" ")
        ax.set_ylabel(" ")

    renderer = tight_layout.get_renderer(fig)
    tight_bbox = ax.get_tightbbox(renderer)
    # bbox_inches is mapped through the inverse of the figure's dpi scale.
    extent = tight_bbox.transformed(fig.dpi_scale_trans.inverted())
    # Save the figure the bbox was measured on, not pyplot's current one.
    fig.savefig(dirname + '/' + filename + '.pdf', bbox_inches=extent)
Esempio n. 8
0
def save_axis_in_file(fig, ax, dirname, filename):
    """Save the region of ``fig`` covered by ``ax`` as PNG, SVG and PDF files.

    The title of ``ax`` is cleared before anything is saved.  The PNG
    (1000 dpi) and SVG are written with the axis labels as they are; before
    the PDF is written the x and y labels are replaced by a single space,
    unless ``dirname`` contains ``'sst'`` or ``'readmission'``.  The changes
    to ``ax`` are not undone afterwards.

    Parameters
    ----------
    fig : figure
        Figure that contains ``ax``; this is the figure that gets saved.
    ax : axes
        Axes whose tight bounding box defines the saved region.
    dirname : str
        Directory the files are written to.
    filename : str
        File name without extension.
    """
    def _save_cropped(suffix, **options):
        # Measure the axes right before saving so label changes made
        # between saves are reflected in the cropped region.
        renderer = tight_layout.get_renderer(fig)
        tight_bbox = ax.get_tightbbox(renderer)
        extent = tight_bbox.transformed(fig.dpi_scale_trans.inverted())
        # Save the figure the bbox was measured on, not pyplot's current one.
        fig.savefig(os.path.join(dirname, filename + suffix),
                    bbox_inches=extent,
                    **options)

    ax.set_title("")

    _save_cropped('.png', dpi=1000)
    _save_cropped('.svg')

    # Only the PDF is written without axis labels.
    if 'sst' not in dirname and 'readmission' not in dirname:
        ax.set_xlabel(" ")
        ax.set_ylabel(" ")

    _save_cropped('.pdf')
Esempio n. 9
0
def get_yticklabel_width(fig, ax):
    """Return the width of the widest y tick label of ``ax``.

    The value is the largest window-extent width among the y tick labels,
    divided by the dpi of the figure ``ax`` belongs to.  ``fig`` is only used
    to obtain the renderer.  Raises ``ValueError`` (from ``max``) when ``ax``
    has no y tick labels.
    """
    # With some backends, getting the renderer like this may trigger a warning
    # and cause matplotlib to drop down to the Agg backend.
    from matplotlib import tight_layout
    renderer = tight_layout.get_renderer(fig)

    label_widths = [
        tick_label.get_window_extent(renderer).width
        for tick_label in ax.get_yticklabels()
    ]
    widest = max(label_widths)

    return widest / ax.get_figure().get_dpi()
Esempio n. 10
0
def horizontal_center(fig, pad=1.08):
    """Apply matplotlib's tight_layout to the left margin while keeping the plot
    contents centered.

    This is useful when setting the size of a figure to a document's full
    column width: the left and right margins are made equal, so the plot area
    itself is centered instead of the group [y-axis label, tick labels, plot
    area] as a whole.

    The margins are recomputed up to 11 times, because changing them can
    change the layout that tight_layout computes on the next pass.

    Parameters
    ----------
    fig : Figure
        The matplotlib figure object, the content of which will be centered
    pad : float
        Padding between the edge of the figure and the axis labels, as a
        multiple of font size

    Returns
    -------
    i : int or None
        The number of iterations to converge (the computed margins don't change
        between iterations) or None if it does not converge, in which case a
        warning is issued as well.

    Examples
    --------
    Plot some data and save it as a PNG. The center of the x axis will be
    centered within the figure.

    >>> import mplpub
    >>> import matplotlib.pyplot as plt
    >>> fig = plt.figure()
    >>> plt.plot([1, 2, 3], [1, 4, 9])
    >>> plt.ylabel('y axis')
    >>> fig.set_size_inches(4, 1)
    >>> mplpub.horizontal_center(fig)
    >>> fig.savefig('plot.png')
    """
    import math

    # Margins are fractions of the figure size, so an absolute tolerance far
    # below anything visible is enough to absorb floating-point rounding.
    tolerance = 1e-12

    for i in range(11):
        adjust_kwargs = get_tight_layout_figure(
            fig, fig.axes, get_subplotspec_list(fig.axes),
            get_renderer(fig), pad=pad)

        # Use the larger of the two required margins on both sides.
        margin = max(1 - adjust_kwargs['right'], adjust_kwargs['left'])

        # Converged when the figure already has this margin left and right.
        # Exact float equality could fail here because the right margin is
        # stored as ``1 - margin`` and adding ``margin`` back may not give
        # exactly 1.
        current = fig.subplotpars
        left_settled = math.isclose(
            margin, current.left, rel_tol=0.0, abs_tol=tolerance)
        right_settled = math.isclose(
            margin + current.right, 1.0, rel_tol=0.0, abs_tol=tolerance)
        if left_settled and right_settled:
            return i

        fig.subplots_adjust(left=margin,
                            right=1 - margin,
                            wspace=adjust_kwargs.get('wspace'))

    warnings.warn("horizontal_center did not converge")
    return None
    
Esempio n. 11
0
File: mpl.py Progetto: yanzewu/line
def update_subfigure(m_state: state.GlobalState):
    """ Redraw the current subfigure of ``m_state``.

    If the current figure has no backend, or its backend figure number no
    longer exists in pyplot, the whole figure is rebuilt through
    ``update_figure`` instead.
    """
    backend = m_state.cur_figure().backend

    # figure is closed -> redraw the figure.
    figure_is_open = backend is not None and plt.fignum_exists(backend.number)
    if not figure_is_open:
        update_figure(m_state)
        return

    figure_name = m_state.cur_figurename
    logger.debug('Updating figure %s, subfigure %d' %
                 (figure_name, m_state.cur_figure().cur_subfigure))

    fig = plt.figure(m_state.cur_figurename)
    subfigure = m_state.cur_subfigure()
    renderer = tight_layout.get_renderer(fig)
    _update_subfigure(subfigure, renderer)
Esempio n. 12
0
    def fix_before_drawing(self, *args):
        """
            Re-positions the plot and its labels so that long labels still fit.
            Meant to run after an initial draw event, because only then can
            the real label dimensions be measured.

            Returns False when there is no renderer or the canvas is not
            realized yet; otherwise returns the result of
            `position_setup.to_string()` for the new position.
        """
        renderer = get_renderer(self.figure)
        if not (renderer and self._canvas.get_realized()):
            return False

        setup = self.position_setup

        # Wide specimen labels push the left edge of the plot inwards:
        if len(self.labels) > 0:
            labels_box = self._get_joint_bbox(self.labels, renderer)
            if labels_box is not None:
                setup.left = setup.default_left + labels_box.width

        # High marker labels push the top edge of the plot downwards
        # (only the flagged marker labels are measured):
        if len(self.marker_lbls) > 0:
            flagged = [
                lbl for lbl, enabled, _ in self.marker_lbls if enabled
            ]
            markers_box = self._get_joint_bbox(flagged, renderer)
            if markers_box is not None:
                setup.top = setup.default_top - markers_box.height

        # The x-axis label pushes the bottom edge of the plot upwards:
        x_label = self.plot.axis["bottom"].label
        if x_label is not None:
            x_label_box = self._get_joint_bbox([x_label], renderer)
            if x_label_box is not None:
                label_height = x_label_box.ymax - x_label_box.ymin
                # somehow we need twice the label height here?
                setup.bottom = setup.default_bottom + label_height * 2.0

        # Apply the recalculated plot position:
        plot_pos = setup.position
        self.plot.set_position(plot_pos)

        # Specimen labels sit just left of the plot
        for specimen_label in self.labels:
            specimen_label.set_x(plot_pos[0] - 0.025)

        # Flagged marker labels sit relative to the top of the plot
        for marker_label, enabled, y_offset in self.marker_lbls:
            if not enabled:
                continue
            marker_label.set_y(plot_pos[1] + plot_pos[3] + y_offset - 0.025)

        return setup.to_string()
Esempio n. 13
0
def scale_width(table, shape):
    """Give every cell of a column the width of that column's widest cell.

    Parameters
    ----------
    table : table
        Object providing ``get_celld()`` (cells keyed by ``(row, col)``) and
        ``get_figure()``.
    shape : sequence
        ``(n_rows, n_cols)`` of the cell grid to process.
    """
    n_rows, n_cols = shape[0], shape[1]
    cells = table.get_celld()

    # The renderer does not depend on the cell, so look it up once instead
    # of once per cell.
    renderer = tight_layout.get_renderer(table.get_figure())

    # First pass: widest required width per column, never below 0.
    column_widths = {
        col: max([0.] + [cells[(row, col)].get_required_width(renderer)
                         for row in range(n_rows)])
        for col in range(n_cols)
    }

    # Second pass: apply the widths only after all columns were measured.
    for col, width in column_widths.items():
        for row in range(n_rows):
            cells[(row, col)].set_width(width)
Esempio n. 14
0
    def fix_before_drawing(self, *args):
        """
            Re-positions the plot and its labels so that long labels still fit.
            Meant to run after an initial draw event, because only then can
            the real label dimensions be measured.

            Returns False when there is no renderer or the canvas is not
            realized yet; otherwise returns the result of
            `position_setup.to_string()` for the new position.
        """
        renderer = get_renderer(self.figure)
        if not (renderer and self._canvas.get_realized()):
            return False

        setup = self.position_setup

        # Wide specimen labels push the left edge of the plot inwards:
        if len(self.labels) > 0:
            labels_box = self._get_joint_bbox(self.labels, renderer)
            if labels_box is not None:
                setup.left = setup.default_left + labels_box.width

        # High marker labels push the top edge of the plot downwards
        # (only the flagged marker labels are measured):
        if len(self.marker_lbls) > 0:
            flagged = [
                lbl for lbl, enabled, _ in self.marker_lbls if enabled
            ]
            markers_box = self._get_joint_bbox(flagged, renderer)
            if markers_box is not None:
                setup.top = setup.default_top - markers_box.height

        # The x-axis label pushes the bottom edge of the plot upwards:
        x_label = self.plot.axis["bottom"].label
        if x_label is not None:
            x_label_box = self._get_joint_bbox([x_label], renderer)
            if x_label_box is not None:
                label_height = x_label_box.ymax - x_label_box.ymin
                # somehow we need twice the label height here?
                setup.bottom = setup.default_bottom + label_height * 2.0

        # Apply the recalculated plot position:
        plot_pos = setup.position
        self.plot.set_position(plot_pos)

        # Specimen labels sit just left of the plot
        for specimen_label in self.labels:
            specimen_label.set_x(plot_pos[0] - 0.025)

        # Flagged marker labels sit relative to the top of the plot
        for marker_label, enabled, y_offset in self.marker_lbls:
            if not enabled:
                continue
            marker_label.set_y(plot_pos[1] + plot_pos[3] + y_offset - 0.025)

        return setup.to_string()
Esempio n. 15
0
def fit_axes(ax):
    """ Redimension the given axes to have labels fitting.

    The left edge of the axes is moved right when it is closer to the figure
    edge than 1.1 times the width of the y axis tight bounding box, and the
    bottom edge is moved up in the same way using the height of the x axis
    tight bounding box.  An axis that reports no tight bounding box is left
    alone.

    Parameters
    ----------
    ax : axes
        The axes to reposition; modified in place through ``set_position``.
    """
    fig = ax.get_figure()
    renderer = get_renderer(fig)

    # ``Bbox.inverse_transformed`` is gone from recent matplotlib releases;
    # ``transformed(transform.inverted())`` is the equivalent spelling.
    y_bbox = ax.yaxis.get_tightbbox(renderer)
    if y_bbox is not None:
        ylabel_width = y_bbox.transformed(
            ax.figure.transFigure.inverted()).width
        if ax.get_position().xmin < 1.1 * ylabel_width:
            # we need to move it over
            new_position = ax.get_position()
            new_position.x0 = 1.1 * ylabel_width  # pad a little
            ax.set_position(new_position)

    x_bbox = ax.xaxis.get_tightbbox(renderer)
    if x_bbox is not None:
        xlabel_height = x_bbox.transformed(
            ax.figure.transFigure.inverted()).height
        if ax.get_position().ymin < 1.1 * xlabel_height:
            # we need to move it up
            new_position = ax.get_position()
            new_position.y0 = 1.1 * xlabel_height  # pad a little
            ax.set_position(new_position)
Esempio n. 16
0
def fit_axes(ax):
    """ Redimension the given axes to have labels fitting.

    The left edge of the axes is moved right when it is closer to the figure
    edge than 1.1 times the width of the y axis tight bounding box, and the
    bottom edge is moved up in the same way using the height of the x axis
    tight bounding box.  An axis that reports no tight bounding box is left
    alone.

    Parameters
    ----------
    ax : axes
        The axes to reposition; modified in place through ``set_position``.
    """
    fig = ax.get_figure()
    renderer = get_renderer(fig)

    # ``Bbox.inverse_transformed`` is gone from recent matplotlib releases;
    # ``transformed(transform.inverted())`` is the equivalent spelling.
    y_bbox = ax.yaxis.get_tightbbox(renderer)
    if y_bbox is not None:
        ylabel_width = y_bbox.transformed(
            ax.figure.transFigure.inverted()).width
        if ax.get_position().xmin < 1.1 * ylabel_width:
            # we need to move it over
            new_position = ax.get_position()
            new_position.x0 = 1.1 * ylabel_width  # pad a little
            ax.set_position(new_position)

    x_bbox = ax.xaxis.get_tightbbox(renderer)
    if x_bbox is not None:
        xlabel_height = x_bbox.transformed(
            ax.figure.transFigure.inverted()).height
        if ax.get_position().ymin < 1.1 * xlabel_height:
            # we need to move it up
            new_position = ax.get_position()
            new_position.y0 = 1.1 * xlabel_height  # pad a little
            ax.set_position(new_position)
Esempio n. 17
0
def plotMatrix(hm,
               outFileName,
               colorMapDict={
                   'colorMap': ['binary'],
                   'missingDataColor': 'black',
                   'alpha': 1.0
               },
               plotTitle='',
               xAxisLabel='',
               yAxisLabel='',
               regionsLabel='',
               zMin=None,
               zMax=None,
               yMin=None,
               yMax=None,
               averageType='median',
               reference_point_label='TSS',
               startLabel='TSS',
               endLabel="TES",
               heatmapHeight=25,
               heatmapWidth=7.5,
               perGroup=False,
               whatToShow='plot, heatmap and colorbar',
               image_format=None,
               legend_location='upper-left',
               box_around_heatmaps=True,
               label_rotation=0.0,
               dpi=200,
               interpolation_method='auto'):

    hm.reference_point_label = reference_point_label
    hm.startLabel = startLabel
    hm.endLabel = endLabel

    matrix_flatten = None
    if zMin is None:
        matrix_flatten = hm.matrix.flatten()
        # try to avoid outliers by using np.percentile
        zMin = np.percentile(matrix_flatten, 1.0)
        if np.isnan(zMin):
            zMin = [None]
        else:
            zMin = [zMin]  # convert to list to support multiple entries

    if zMax is None:
        if matrix_flatten is None:
            matrix_flatten = hm.matrix.flatten()
        # try to avoid outliers by using np.percentile
        zMax = np.percentile(matrix_flatten, 98.0)
        if np.isnan(zMax) or zMax <= zMin[0]:
            zMax = [None]
        else:
            zMax = [zMax]

    if yMin is None:
        yMin = [None]
    if yMax is None:
        yMax = [None]
    if not isinstance(yMin, list):
        yMin = [yMin]
    if not isinstance(yMax, list):
        yMax = [yMax]

    plt.rcParams['font.size'] = 8.0
    fontP = FontProperties()

    showSummaryPlot = False
    showColorbar = False

    if whatToShow == 'plot and heatmap':
        showSummaryPlot = True
    elif whatToShow == 'heatmap and colorbar':
        showColorbar = True
    elif whatToShow == 'plot, heatmap and colorbar':
        showSummaryPlot = True
        showColorbar = True

    # colormap for the heatmap
    if colorMapDict['colorMap']:
        cmap = []
        for color_map in colorMapDict['colorMap']:
            cmap.append(plt.get_cmap(color_map))
            cmap[-1].set_bad(colorMapDict['missingDataColor']
                             )  # nans are printed using this color

    if colorMapDict['colorList'] and len(colorMapDict['colorList']) > 0:
        # make a cmap for each color list given
        cmap = []
        for color_list in colorMapDict['colorList']:
            cmap.append(
                matplotlib.colors.LinearSegmentedColormap.from_list(
                    'my_cmap',
                    color_list.replace(' ', '').split(","),
                    N=colorMapDict['colorNumber']))
            cmap[-1].set_bad(colorMapDict['missingDataColor']
                             )  # nans are printed using this color

    if len(cmap) > 1 or len(zMin) > 1 or len(zMax) > 1:
        # position color bar below heatmap when more than one
        # heatmap color is given
        colorbar_position = 'below'
    else:
        colorbar_position = 'side'

    grids = prepare_layout(hm.matrix, (heatmapWidth, heatmapHeight),
                           showSummaryPlot, showColorbar, perGroup,
                           colorbar_position)

    # figsize: w,h tuple in inches
    figwidth = heatmapWidth / 2.54
    figheight = heatmapHeight / 2.54
    if showSummaryPlot:
        # the summary plot ocupies a height
        # equal to the fig width
        figheight += figwidth

    numsamples = hm.matrix.get_num_samples()
    if perGroup:
        num_cols = hm.matrix.get_num_groups()
    else:
        num_cols = numsamples
    total_figwidth = figwidth * num_cols
    if showColorbar:
        if colorbar_position == 'below':
            figheight += 1 / 2.54
        else:
            total_figwidth += 1 / 2.54

    fig = plt.figure(figsize=(total_figwidth, figheight))
    fig.suptitle(plotTitle, y=1 - (0.06 / figheight))

    # color map for the summary plot (profile) on top of the heatmap
    cmap_plot = plt.get_cmap('jet')
    numgroups = hm.matrix.get_num_groups()
    if perGroup:
        color_list = cmap_plot(
            np.arange(hm.matrix.get_num_samples()) /
            hm.matrix.get_num_samples())
    else:
        color_list = cmap_plot(np.arange(numgroups) / numgroups)
    alpha = colorMapDict['alpha']

    if image_format == 'plotly':
        return plotlyMatrix(hm,
                            outFileName,
                            yMin=yMin,
                            yMax=yMax,
                            zMin=zMin,
                            zMax=zMax,
                            showSummaryPlot=showSummaryPlot,
                            showColorbar=showColorbar,
                            cmap=cmap,
                            colorList=color_list,
                            colorBarPosition=colorbar_position,
                            perGroup=perGroup,
                            averageType=averageType,
                            plotTitle=plotTitle,
                            xAxisLabel=xAxisLabel,
                            yAxisLabel=yAxisLabel,
                            label_rotation=label_rotation)

    # check if matrix is reference-point based using the upstream >0 value
    # and is sorted by region length. If this is
    # the case, prepare the data to plot a border at the regions end
    regions_length_in_bins = [None] * len(hm.parameters['upstream'])
    if hm.matrix.sort_using == 'region_length' and hm.matrix.sort_method != 'no':
        for idx in range(len(hm.parameters['upstream'])):
            if hm.paramters['upstream'] > 0:
                _regions = hm.matrix.get_regions()
                foo = []
                for _group in _regions:
                    _reg_len = []
                    for ind_reg in _group:
                        if isinstance(ind_reg, dict):
                            _len = ind_reg['end'] - ind_reg['start']
                        else:
                            _len = sum([x[1] - x[0] for x in ind_reg[1]])
                        _reg_len.append(
                            (hm.parameters['upstream'][idx] + _len) /
                            hm.parameters['bin size'][idx])
                    foo.append(_reg_len)
                regions_length_in_bins[idx] = foo

    # plot the profiles on top of the heatmaps
    if showSummaryPlot:
        if perGroup:
            iterNum = numgroups
            iterNum2 = hm.matrix.get_num_samples()
        else:
            iterNum = hm.matrix.get_num_samples()
            iterNum2 = numgroups
        ax_list = addProfilePlot(hm, plt, fig, grids, iterNum, iterNum2,
                                 perGroup, averageType, yAxisLabel, color_list,
                                 yMin, yMax, None, None, colorbar_position,
                                 label_rotation)
        if len(yMin) > 1 or len(yMax) > 1:
            # replot with a tight layout
            import matplotlib.tight_layout as tl
            specList = tl.get_subplotspec_list(fig.axes, grid_spec=grids)
            renderer = tl.get_renderer(fig)
            kwargs = tl.get_tight_layout_figure(fig,
                                                fig.axes,
                                                specList,
                                                renderer,
                                                pad=1.08)

            for ax in ax_list:
                fig.delaxes(ax)

            ax_list = addProfilePlot(hm, plt, fig, grids, iterNum, iterNum2,
                                     perGroup, averageType, yAxisLabel,
                                     color_list, yMin, yMax, kwargs['wspace'],
                                     kwargs['hspace'], colorbar_position,
                                     label_rotation)

        if legend_location != 'none':
            ax_list[-1].legend(loc=legend_location.replace('-', ' '),
                               ncol=1,
                               prop=fontP,
                               frameon=False,
                               markerscale=0.5)

    first_group = 0  # helper variable to place the title per sample/group
    for sample in range(hm.matrix.get_num_samples()):
        sample_idx = sample
        for group in range(numgroups):
            group_idx = group
            # add the respective profile to the
            # summary plot
            sub_matrix = hm.matrix.get_matrix(group, sample)
            if showSummaryPlot:
                if perGroup:
                    sample_idx = sample + 2  # plot + spacer
                else:
                    group += 2  # plot + spacer
                first_group = 1

            if perGroup:
                ax = fig.add_subplot(grids[sample_idx, group])
                # the remainder (%) is used to iterate
                # over the available color maps (cmap).
                # if the user only provided, lets say two
                # and there are 10 groups, colormaps they are reused every
                # two groups.
                cmap_idx = group_idx % len(cmap)
                zmin_idx = group_idx % len(zMin)
                zmax_idx = group_idx % len(zMax)
            else:
                ax = fig.add_subplot(grids[group, sample])
                # see above for the use of '%'
                cmap_idx = sample % len(cmap)
                zmin_idx = sample % len(zMin)
                zmax_idx = sample % len(zMax)

            if group == first_group and not showSummaryPlot and not perGroup:
                title = hm.matrix.sample_labels[sample]
                ax.set_title(title)

            if box_around_heatmaps is False:
                # Turn off the boxes around the individual heatmaps
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['bottom'].set_visible(False)
                ax.spines['left'].set_visible(False)
            rows, cols = sub_matrix['matrix'].shape
            # if the number of rows is too large, then the 'nearest' method simply
            # drops rows. A better solution is to relate the threshold to the DPI of the image
            if interpolation_method == 'auto':
                if rows >= 1000:
                    interpolation_method = 'bilinear'
                else:
                    interpolation_method = 'nearest'

            # if np.clip is not used, then values of the matrix that exceed the zmax limit are
            # highlighted. Usually, a significant amount of pixels are equal or above the zmax and
            # the default behaviour produces images full of large highlighted dots.
            # If interpolation='nearest' is used, this has no effect
            sub_matrix['matrix'] = np.clip(sub_matrix['matrix'],
                                           zMin[zmin_idx], zMax[zmax_idx])
            img = ax.imshow(sub_matrix['matrix'],
                            aspect='auto',
                            interpolation=interpolation_method,
                            origin='upper',
                            vmin=zMin[zmin_idx],
                            vmax=zMax[zmax_idx],
                            cmap=cmap[cmap_idx],
                            alpha=alpha,
                            extent=[0, cols, rows, 0])
            img.set_rasterized(True)
            # plot border at the end of the regions
            # if ordered by length
            if regions_length_in_bins[sample] is not None:
                x_lim = ax.get_xlim()
                y_lim = ax.get_ylim()

                ax.plot(regions_length_in_bins[sample][group_idx],
                        np.arange(
                            len(regions_length_in_bins[sample][group_idx])),
                        '--',
                        color='black',
                        linewidth=0.5,
                        dashes=(3, 2))
                ax.set_xlim(x_lim)
                ax.set_ylim(y_lim)

            if perGroup:
                ax.axes.set_xlabel(sub_matrix['group'])
                if sample < hm.matrix.get_num_samples() - 1:
                    ax.axes.get_xaxis().set_visible(False)
            else:
                ax.axes.get_xaxis().set_visible(False)
                ax.axes.set_xlabel(xAxisLabel)
            ax.axes.set_yticks([])
            if perGroup and group == 0:
                ax.axes.set_ylabel(sub_matrix['sample'])
            elif not perGroup and sample == 0:
                ax.axes.set_ylabel(sub_matrix['group'])

            # add labels to last block in a column
            if (perGroup and sample == numsamples - 1) or \
               (not perGroup and group_idx == numgroups - 1):

                # add xticks to the bottom heatmap (last group)
                ax.axes.get_xaxis().set_visible(True)
                xticks_heat, xtickslabel_heat = hm.getTicks(sample)
                if np.ceil(max(xticks_heat)) != float(
                        sub_matrix['matrix'].shape[1]):
                    tickscale = float(
                        sub_matrix['matrix'].shape[1]) / max(xticks_heat)
                    xticks_heat_use = [x * tickscale for x in xticks_heat]
                    ax.axes.set_xticks(xticks_heat_use)
                else:
                    ax.axes.set_xticks(xticks_heat)
                ax.axes.set_xticklabels(xtickslabel_heat, size=8)

                # align the first and last label
                # such that they don't fall off
                # the heatmap sides
                ticks = ax.xaxis.get_major_ticks()
                ticks[0].label1.set_horizontalalignment('left')
                ticks[-1].label1.set_horizontalalignment('right')

                ax.get_xaxis().set_tick_params(which='both',
                                               top='off',
                                               direction='out')

                if showColorbar and colorbar_position == 'below':
                    # draw a colormap per each heatmap below the last block
                    if perGroup:
                        col = group_idx
                    else:
                        col = sample
                    ax = fig.add_subplot(grids[-1, col])
                    tick_locator = ticker.MaxNLocator(nbins=3)
                    cbar = fig.colorbar(img,
                                        cax=ax,
                                        alpha=alpha,
                                        orientation='horizontal',
                                        ticks=tick_locator)
                    labels = cbar.ax.get_xticklabels()
                    ticks = cbar.ax.get_xticks()
                    if ticks[0] == 0:
                        # if the label is at the start of the colobar
                        # move it a bit inside to avoid overlapping
                        # with other labels
                        labels[0].set_horizontalalignment('left')
                    if ticks[-1] == 1:
                        # if the label is at the end of the colobar
                        # move it a bit inside to avoid overlapping
                        # with other labels
                        labels[-1].set_horizontalalignment('right')
                    # cbar.ax.set_xticklabels(labels, rotation=90)

    if showColorbar and colorbar_position != 'below':
        if showSummaryPlot:
            # we don't want to colorbar to extend
            # over the profiles and spacer top rows
            grid_start = 2
        else:
            grid_start = 0

        ax = fig.add_subplot(grids[grid_start:, -1])
        fig.colorbar(img, cax=ax, alpha=alpha)

    if box_around_heatmaps:
        plt.subplots_adjust(wspace=0.10,
                            hspace=0.025,
                            top=0.85,
                            bottom=0,
                            left=0.04,
                            right=0.96)
    else:
        #  When no box is plotted the space between heatmaps is reduced
        plt.subplots_adjust(wspace=0.05,
                            hspace=0.01,
                            top=0.85,
                            bottom=0,
                            left=0.04,
                            right=0.96)

    plt.savefig(outFileName,
                bbox_inches='tight',
                pdd_inches=0,
                dpi=dpi,
                format=image_format)
    plt.close()
Esempio n. 18
0
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width.

        Parameters
        ----------
        fig : matplotlib figure, optional
            Figure to lay out.  Defaults to ``plt.gcf()``.  When
            ``self._element_size`` is set, the figure's width and height are
            changed so that each grid cell has that size.

        Returns
        -------
        dict
            Maps ``'matrix'``, ``'shading'`` and ``'totals'`` to slices of the
            grid, ``'gs'`` to the grid itself, and the ``'id'`` of every entry
            in ``self._subset_plots`` to the slice reserved for that plot.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)
        # Grid cells taken by the intersection columns plus the totals plot.
        n_data_elems = n_inters + self._totals_plot_elements

        if fig is None:
            fig = plt.gcf()

        # Determine text size to determine figure size / spacing: render the
        # category labels once, measure them, and discard the temporary text.
        r = get_renderer(fig)
        t = fig.text(0, 0, '\n'.join(self.totals.index.values))
        textw = t.get_window_extent(renderer=r).width
        t.remove()

        MAGIC_MARGIN = 10  # FIXME
        # NOTE(review): _reorient / _swapaxes are defined elsewhere; presumably
        # they swap the width/height roles for the non-horizontal layout.
        figw = self._reorient(fig.get_window_extent(renderer=r)).width

        sizes = np.asarray([p['elements'] for p in self._subset_plots])

        if self._element_size is None:
            # Fit the cells into the space left over by the label text.
            colw = (figw - textw - MAGIC_MARGIN) / n_data_elems
        else:
            # Fixed cell size (given in points): resize the figure instead.
            fig = self._reorient(fig)
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * n_data_elems + MAGIC_MARGIN + textw
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) /
                              render_ratio)

        # Number of grid cells needed to hold the label text.
        text_nelems = int(np.ceil(figw / colw - n_data_elems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        # With no subset plots ``sizes`` is an empty float array whose sum is
        # 0.0; ``or 0`` turns that back into an integer 0.
        gridspec = GS(*self._swapaxes(n_cats + (sizes.sum() or 0),
                                      n_data_elems + text_nelems),
                      hspace=1)
        if self._horizontal:
            out = {'matrix': gridspec[-n_cats:, -n_inters:],
                   'shading': gridspec[-n_cats:, :],
                   'totals': gridspec[-n_cats:, :self._totals_plot_elements],
                   'gs': gridspec}
            # Subset plots are stacked in reverse order, starting at row 0.
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(np.hstack([[0], cumsizes]), cumsizes,
                                         self._subset_plots[::-1]):
                out[plot['id']] = gridspec[start:stop, -n_inters:]
        else:
            out = {'matrix': gridspec[-n_inters:, :n_cats],
                   'shading': gridspec[:, :n_cats],
                   'totals': gridspec[:self._totals_plot_elements, :n_cats],
                   'gs': gridspec}
            # Subset plots are placed side by side after the category columns.
            cumsizes = np.cumsum(sizes)
            for start, stop, plot in zip(np.hstack([[0], cumsizes]), cumsizes,
                                         self._subset_plots):
                out[plot['id']] = gridspec[-n_inters:,
                                           start + n_cats:stop + n_cats]
        return out
Esempio n. 19
0
def vertical_aspect(fig, aspect, ax_idx=0, pad=1.08,
                    nonoverlapping_extra_artists=(),
                    overlapping_extra_artists=()):
    """Adjust figure height and vertical spacing so a sub-plot plotting area has
    a specified aspect ratio and the overall figure has top/bottom margins from
    tight_layout.

    Parameters
    ----------
    fig : Figure
        The matplotlib figure object to be updated
    aspect : float
        The aspect ratio (W:H) desired for the subplot of ax_idx; must be
        positive
    ax_idx : int
        The index (of fig.axes) for the axes to have the desired aspect ratio
    pad : float
        Padding between the edge of the figure and the axis labels, as a
        multiple of font size
    nonoverlapping_extra_artists : iterable of artists
        Iterable of artists that should not overlap with the subplots but
        should be accounted for in the layout; e.g., suptitle
    overlapping_extra_artists : iterable of artists
        Iterable of artists that may overlap with the subplots but should be
        accounted for in the layout; e.g., legends

    Returns
    -------
    i : int or float
        The number of iterations to converge (be within one pixel by DPI) of the
        desired aspect ratio or if it does not converge, the current aspect
        ratio of the fig.axes[ax_idx]

    Raises
    ------
    ValueError
        If `aspect` is not a positive number.

    Examples
    --------
    Plot some data and save it as a PNG. The figure height is adjusted so the
    plotting area has the requested aspect ratio.

    >>> import mplpub
    >>> import matplotlib.pyplot as plt
    >>> fig = plt.figure()
    >>> plt.plot([1, 2, 3], [1, 4, 9])
    >>> plt.ylabel('y axis')
    >>> fig.set_size_inches(4, 1)
    >>> mplpub.vertical_aspect(fig, mplpub.golden_ratio)
    >>> fig.savefig('plot.png')

    Plays well with subplots

    >>> import mplpub
    >>> import matplotlib.pyplot as plt
    >>> plt.ion()
    >>> fig = plt.figure()
    >>> for spec in ((1,2,1), (2,2,2), (2,2,4)):
    ...     plt.subplot(*spec)
    ...     plt.plot([1, 2, 3], [1, 4, 9])
    ...     plt.ylabel('y axis')
    >>> fig.set_size_inches(8, 8)
    >>> fig.suptitle("super title")
    >>> print("center iter",mplpub.horizontal_center(fig))

    The aspect ratio of any subplot can be set

    >>> print("vert iter",mplpub.vertical_aspect(fig, 1, 1))
    >>> print("vert iter",mplpub.vertical_aspect(fig, 0.5, 0))

    """
    # The new height is computed by dividing by `aspect`, so reject values
    # that cannot describe a width:height ratio before touching the figure.
    if not aspect > 0:
        raise ValueError(
            "aspect must be a positive number, got %r" % (aspect,))

    ax = fig.axes[ax_idx]
    w, h = fig.get_size_inches()

    nrows = get_subplotspec_list(fig.axes)[ax_idx].get_geometry()[0]

    pad_inches = pad * FontProperties(
        size=rcParams["font.size"]).get_size_in_points() / 144

    def figure_bbox(artist):
        # Bounding box of `artist` expressed in figure-fraction coordinates.
        return TransformedBbox(
            artist.get_window_extent(get_renderer(fig)),
            fig.transFigure.inverted()
        )

    # Space (in inches) reserved above/below the subplots for artists that
    # must not overlap them; each artist is assigned to the nearer edge.
    non_overlapping_inches = {'top': 0, 'bottom': 0}
    for artist in nonoverlapping_extra_artists:
        artist_bbox = figure_bbox(artist)
        if artist_bbox.ymax < 0.5:
            side = 'bottom'
        else:
            side = 'top'
        non_overlapping_inches[side] += artist_bbox.height*h + pad_inches

    for i in range(11):
        # Vertical extent covered by the artists that may overlap the
        # subplots (re-measured every pass because the figure is resized).
        overlapping_maxy = 0
        overlapping_miny = 1
        for artist in overlapping_extra_artists:
            artist_bbox = figure_bbox(artist)
            if artist_bbox.ymax > overlapping_maxy:
                overlapping_maxy = artist_bbox.ymax
            if artist_bbox.ymin < overlapping_miny:
                overlapping_miny = artist_bbox.ymin

        # How far (in inches) those artists stick into the padding band at
        # the top / bottom of the figure.
        if overlapping_maxy > (1 - pad_inches/h):
            overlapping_top_adjust_inches = (
                overlapping_maxy*h - (h - pad_inches))
        else:
            overlapping_top_adjust_inches = 0
        if overlapping_miny < pad_inches/h:
            overlapping_bottom_adjust_inches = pad_inches - overlapping_miny*h
        else:
            overlapping_bottom_adjust_inches = 0

        bbox = ax.get_position()
        current_aspect = ((bbox.x1 - bbox.x0)*w)/((bbox.y1 - bbox.y0)*h)
        aspect_diff = abs(current_aspect - aspect)*w * fig.get_dpi()
        # NOTE(review): the 3.14159 factor is part of the original
        # convergence tolerance; its derivation is not visible here.
        if ((aspect_diff*3.14159 < 1) and
                (not overlapping_top_adjust_inches) and
                (not overlapping_bottom_adjust_inches)):
            return i

        old_h = h

        adjust_kwargs = get_tight_layout_figure(
            fig, fig.axes,
            get_subplotspec_list(fig.axes), get_renderer(fig), pad=pad,
            rect=(0,
                  (non_overlapping_inches['bottom']
                   + overlapping_bottom_adjust_inches)/h,
                  1,
                  1 - (
                      non_overlapping_inches['top']
                      + overlapping_top_adjust_inches
                  )/h)
        )

        # Margins proposed by tight_layout, converted to inches at the
        # current height so they can be re-applied after resizing.
        tight_top_inches = (1-adjust_kwargs['top'])*old_h
        tight_bottom_inches = adjust_kwargs['bottom']*old_h

        hspace = adjust_kwargs.get('hspace', 0)
        h = (bbox.width*w*(nrows + hspace*(nrows-1))/aspect +
             (adjust_kwargs['bottom'] + 1 - adjust_kwargs['top'])*old_h +
             overlapping_top_adjust_inches +
             overlapping_bottom_adjust_inches +
             non_overlapping_inches['top'] +
             non_overlapping_inches['bottom'])

        fig.set_size_inches((w, h))

        fig.subplots_adjust(
            top=1-(tight_top_inches)/h,
            bottom=(tight_bottom_inches)/h,
            hspace=adjust_kwargs.get('hspace', None)
        )

    warnings.warn("vertical_aspect did not converge")
    return current_aspect
Esempio n. 20
0
def subtract_background(df, filename):
    """Subtract a per-barcode-pair background level from read counts and
    write before/after histograms to a PDF.

    For every (``barcode_fw``, ``barcode_rev``) group a two-component Gaussian
    mixture is fitted to ``log10(counts)``.  The component whose upper 97.5 %
    quantile is smallest is treated as background, and its mean (back in
    linear scale) is subtracted from the counts, clipping at zero.  Groups
    with at most one row are returned unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``counts``, ``barcode_fw`` and
        ``barcode_rev``; the two barcode columns are accessed through ``.cat``
        and therefore have to be categorical.
        NOTE(review): counts are passed through ``log10`` — presumably they
        are strictly positive; confirm against the caller.
    filename : path-like
        Destination handed to ``PdfPages``; one page is written for the raw
        counts and one for the background-subtracted counts.

    Returns
    -------
    pandas.DataFrame
        The result of applying the background subtraction to every group.
    """
    mix = GaussianMixture(n_components=2, tol=1e-8, max_iter=int(1e4))

    def bgsubt(g):
        if g['counts'].size <= 1:
            return g
        mix.fit(np.log10(g['counts'].values.reshape(-1, 1)))
        means = mix.means_.reshape(-1)
        # covariances_ holds variances for this one-feature fit, whereas
        # norm.ppf expects a standard deviation as its scale argument.
        stds = np.sqrt(mix.covariances_.reshape(-1))
        ubounds = 10**norm.ppf(0.975, means, stds)
        bg_comp = ubounds.argmin()
        # Return a modified copy: mutating the frame handed in by
        # groupby.apply is not supported by pandas.
        return g.assign(
            counts=np.maximum(0, g['counts'].values - 10**means[bg_comp]))

    subt = df.groupby(['barcode_fw', 'barcode_rev']).apply(bgsubt)

    dfs = (df, subt)
    titles = ('raw read counts', 'background-subtracted read counts')
    with PdfPages(filename) as pdf:
        for cdf, title in zip(dfs, titles):
            fw = cdf['barcode_fw'].cat.categories.sort_values()
            rev = cdf['barcode_rev'].cat.categories.sort_values()
            grouped = cdf.query("counts > 0").groupby(
                ['barcode_fw', 'barcode_rev'], observed=True)
            fig, axs = plt.subplots(nrows=fw.size,
                                    ncols=rev.size,
                                    sharex=True,
                                    sharey=True,
                                    squeeze=False)
            plt.subplots_adjust(left=0.125,
                                right=0.9,
                                bottom=0.1,
                                top=0.9,
                                wspace=0,
                                hspace=0)

            # 2 x 0.5 inch per panel, enlarged so that the panels (not the
            # whole figure including margins) end up with that size.
            w = 2 * rev.size
            h = 0.5 * fw.size
            w *= 1 / (fig.subplotpars.right - fig.subplotpars.left)
            h *= 1 / (fig.subplotpars.top - fig.subplotpars.bottom)

            fig.set_size_inches(w, h)

            for y, fw_code in enumerate(fw):
                for x, rev_code in enumerate(rev):
                    panel = axs[y, x]
                    panel.set_xscale("log")
                    panel.xaxis.get_major_locator().set_params(numticks=10)
                    panel.xaxis.get_minor_locator().set_params(numticks=1000)
                    try:
                        g = grouped.get_group((fw_code, rev_code))
                    except KeyError:
                        # No positive counts for this barcode pair.
                        continue
                    c = g['counts']
                    nbins = max(min(100, int(c.size / 10)), 1)
                    # nbins bins need nbins + 1 edges; passing nbins edges
                    # left small groups with no bin at all.
                    edges = np.logspace(np.log10(c.min()),
                                        np.log10(c.max()), nbins + 1)
                    panel.hist(c,
                               bins=edges,
                               color="#000000",
                               alpha=0.66,
                               edgecolor='none')

            # Forward barcodes label the rows on the right-hand side ...
            maxx = rev.size - 1
            for y, fw_code in enumerate(fw):
                axs[y, maxx].get_yaxis().set_label_position('right')
                axs[y, maxx].set_ylabel(fw_code,
                                        rotation="horizontal",
                                        ha="left",
                                        va="center")
            # ... and reverse barcodes label the columns along the top.
            for x, rev_code in enumerate(rev):
                axs[0, x].get_xaxis().set_label_position('top')
                axs[0, x].set_xlabel(rev_code)

            # Place the shared axis labels and the title just outside the
            # tight bounding box of everything drawn so far.
            box = fig.get_tightbbox(get_renderer(fig))
            fig.text(0.5, (box.ymin - 0.1) / h, "read counts", ha="center")
            fig.text((box.xmin - 0.2) / w,
                     0.5,
                     "frequency",
                     va="center",
                     rotation="vertical")
            fig.suptitle(title, y=(box.ymax + 0.2) / h)

            pdf.savefig(fig, bbox_inches="tight")
            plt.close(fig)
    return subt
Esempio n. 21
0
def wrap_suptitle(fig, suptitle_words=(), hpad=None, **kwargs):
    """Add a wrapped suptitle so the width does not extend into the padding of
    the figure.

    Lines are filled greedily: each line takes as many of the remaining words
    as fit into the allowed width, and always at least one word, so a single
    over-long word ends up on a line of its own.

    Parameters
    ----------
    fig : Figure
        The matplotlib figure object to be updated
    suptitle_words : sequence of strings
        The title split into chunks that must not be broken across lines;
        chunks are joined with single spaces
    hpad : float or None
        Horizontal padding between the edge of the figure and the axis labels,
        as a multiple of font size. If none, will use the figure's `subplotpars`
        left and right values to keep the text over the plot area
    **kwargs : dict
        Additional keyword args to be passed to suptitle

    Returns
    -------
    suptitle : matplotlib.text.Text
        Passes return title from suptitle

    Examples
    --------
    >>> import mplpub
    >>> import matplotlib.pyplot as plt
    >>> plt.ion()
    >>> fig = plt.figure()
    >>> for spec in ((1,2,1), (2,2,2), (2,2,4)):
    ...     plt.subplot(*spec)
    ...     plt.plot([1, 2, 3], [1, 4, 9])
    ...     plt.ylabel('y axis')
    >>> fig.set_size_inches(4,4)
    >>> t = "This is a really long string that I'd rather have wrapped so that"\
    ... " it doesn't go outside of the figure, but if it's long enough it will"\
    ... " go off the top or bottom!"
    >>> mplpub.wrap_suptitle(fig,t.split(" "))

    """
    w, h = fig.get_size_inches()

    # Widths are in figure fractions.  The margin is counted on both sides,
    # matching a title centred at x=0.5 (the suptitle default).
    if hpad is None:
        max_width = 1 - 2 * max(fig.subplotpars.left,
                                1 - fig.subplotpars.right)
    else:
        max_width = 1 - 2 * hpad * FontProperties(
            size=rcParams["font.size"]).get_size_in_points() / (144 * w)

    remaining = list(suptitle_words)
    words_in_lines = []
    while remaining:
        # Try the longest candidate first and drop one word at a time; if
        # nothing fits, the loop ends with a single word on the line.
        for n_words in range(len(remaining), 0, -1):
            # Rendering the candidate as the suptitle is how it is measured.
            candidate = fig.suptitle(" ".join(remaining[:n_words]), **kwargs)
            candidate_width = TransformedBbox(
                candidate.get_window_extent(get_renderer(fig)),
                fig.transFigure.inverted()
            ).width
            if candidate_width <= max_width:
                break
        words_in_lines.append(remaining[:n_words])
        remaining = remaining[n_words:]

    suptitle_text = "\n".join(" ".join(line) for line in words_in_lines)
    return fig.suptitle(suptitle_text, **kwargs)
Esempio n. 22
0
File: mpl.py Progetto: yanzewu/line
def _update_figure(m_fig: state.Figure, name: str, redraw_subfigures=True):
    """Synchronise the matplotlib figure behind `m_fig` with its attributes.

    Creates the matplotlib figure when `m_fig` has none (or its window no
    longer exists), resizes and clears it, positions one axes per subfigure,
    optionally redraws every subfigure, and finally draws the figure title and
    figure-level legend.  The pixel frames of the figure, the title and the
    legend are stored in the respective ``computed_style['frame']``.

    Parameters
    ----------
    m_fig : state.Figure
        Figure model; its ``backend`` attribute holds the matplotlib figure.
    name : str
        Name passed to ``plt.figure`` when a new figure has to be created.
    redraw_subfigures : bool
        When true, ``_update_subfigure`` is called for every subfigure.
    """
    dpi = m_fig.attr('dpi')
    size = m_fig.attr('size')
    size_inches = (size[0] / dpi, size[1] / dpi)

    if m_fig.backend is None or not plt.fignum_exists(m_fig.backend.number):
        m_fig.backend = plt.figure(name, figsize=size_inches, dpi=dpi)
        logger.debug('Creating new figure: %s' % name)

        # Axes of the previous figure cannot be reused in the new one.
        for subfig in m_fig.subfigures:
            subfig.backend = None

    # m_fig.backend.set_dpi(dpi) # TODO incorrect behavior in Windows
    m_fig.backend.set_size_inches(*size_inches)

    m_plt_fig = m_fig.backend
    m_plt_fig.clear()
    margin = m_fig.attr('margin')

    def scale(x):
        # Map an (x, y, width, height) rectangle into the area that remains
        # inside the figure margins.
        # NOTE(review): margin appears to be (left, bottom, right, top) as
        # figure fractions — confirm against the style definition.
        usable_w = 1 - margin[2] - margin[0]
        usable_h = 1 - margin[3] - margin[1]
        return (x[0] * usable_w + margin[0], x[1] * usable_h + margin[1],
                x[2] * usable_w, x[3] * usable_h)

    for subfig in m_fig.subfigures:

        pos = subfig.attr('rpos')
        rsize = subfig.attr('rsize')
        padding = subfig.attr('padding')

        ax = subfig.backend
        frame = scale((pos[0] + padding[0], pos[1] + padding[1],
                       rsize[0] - padding[0] - padding[2],
                       rsize[1] - padding[1] - padding[3]))
        if ax is None:
            ax = plt.Axes(m_plt_fig, frame)
            subfig.backend = ax
        else:
            ax.set_position(frame)

        # The figure was cleared above, so existing axes are re-attached too.
        m_plt_fig.add_axes(ax)
        logger.debug('Subfigure found at %s' % str(ax.get_position().bounds))

    renderer = tight_layout.get_renderer(m_fig.backend)
    if redraw_subfigures:
        for subfig in m_fig.subfigures:
            _update_subfigure(subfig, renderer)
            logger.debug('Updated subfigure %s' % subfig.name)

    m_fig.computed_style['frame'] = style.Rect(
        *m_fig.backend.get_window_extent(renderer).bounds)

    if m_fig.title.attr('text') and m_fig.title.attr('visible'):
        st = m_fig.backend.suptitle(
            m_fig.title.attr('text'),
            fontproperties=font_manager.FontProperties(
                family=m_fig.title.attr('fontfamily'),
                **m_fig.title.attr('fontprops').export()))
        # Unpack the bounds like the figure and legend frames do, so all
        # three frames are built the same way.
        m_fig.title.computed_style['frame'] = style.Rect(
            *st.get_window_extent(renderer).bounds)
        # TODO in mpl figure title does not care legend position so they may overlap

    if m_fig.legend.attr('source') and m_fig.legend.attr('visible'):

        m_subfig = [
            s for s in m_fig.subfigures
            if s.name == m_fig.legend.attr('source')
        ][0]

        m_style = m_fig.legend.computed_style

        legend_pos = m_fig.legend.attr('pos')
        if legend_pos == style.FloatingPos.AUTO:
            p = 'upper center'  # best does not really make sense here...
        else:
            # only inner positions are allowed, so the anchor box returned by
            # _translate_loc is not used
            p, _ = _translate_loc(*legend_pos)

        # bbox. Usually it's just figure bbox; The only exception is top,center with title,
        # where we have to consider the space of title
        # NOTE(review): in the AUTO case p is the string 'upper center', which
        # never equals 9, so the title-aware anchor is skipped there — confirm
        # that this is intended.
        b = None
        if p == 9 and m_fig.title.attr('text') and m_fig.title.attr(
                'visible') and m_fig.computed_style['frame'][3] > 0:
            b = (0, 0, 1.0, m_fig.title.computed_style['frame'][1] /
                 m_fig.computed_style['frame'][3])

        legend = m_fig.backend.legend(
            [lc[0] for lc in m_subfig._legend_candidates],
            [lc[1] for lc in m_subfig._legend_candidates],
            fancybox=False,
            facecolor=m_style['color'],
            edgecolor=m_style['linecolor'],
            prop=font_manager.FontProperties(family=m_style['fontfamily'],
                                             **m_style['fontprops'].export()),
            loc=p,
            bbox_to_anchor=b,
            ncol=m_style['column'],
            frameon=True,
            framealpha=m_style['alpha'],
        )

        lt = m_style['linetype'].to_str()
        frame = legend.get_frame()
        frame.set_linewidth(m_style['linewidth'])
        frame.set_linestyle(lt if lt else 'None')
        frame.set_zorder(m_style['zindex'])

        for t in legend.get_texts():
            t.set_fontfamily(m_style['fontfamily'])

        m_fig.legend.computed_style['frame'] = style.Rect(
            *legend.get_window_extent(renderer).bounds)
Esempio n. 23
0
def plotMatrix(hm, outFileName,
               colorMapDict={'colorMap': ['binary'], 'missingDataColor': 'black', 'alpha': 1.0},
               plotTitle='',
               xAxisLabel='', yAxisLabel='', regionsLabel='',
               zMin=None, zMax=None,
               yMin=None, yMax=None,
               averageType='median',
               reference_point_label='TSS',
               startLabel='TSS', endLabel="TES",
               heatmapHeight=25,
               heatmapWidth=7.5,
               perGroup=False, whatToShow='plot, heatmap and colorbar',
               image_format=None,
               legend_location='upper-left',
               box_around_heatmaps=True,
               dpi=200):

    matrix_flatten = None
    if zMin is None:
        matrix_flatten = hm.matrix.flatten()
        # try to avoid outliers by using np.percentile
        zMin = np.percentile(matrix_flatten, 1.0)
        if np.isnan(zMin):
            zMin = [None]
        else:
            zMin = [zMin]  # convert to list to support multiple entries

    if zMax is None:
        if matrix_flatten is None:
            matrix_flatten = hm.matrix.flatten()
        # try to avoid outliers by using np.percentile
        zMax = np.percentile(matrix_flatten, 98.0)
        if np.isnan(zMax) or zMax <= zMin[0]:
            zMax = [None]
        else:
            zMax = [zMax]

    if yMin is None:
        yMin = [None]
    if yMax is None:
        yMax = [None]
    if not isinstance(yMin, list):
        yMin = [yMin]
    if not isinstance(yMax, list):
        yMax = [yMax]

    plt.rcParams['font.size'] = 8.0
    fontP = FontProperties()

    showSummaryPlot = False
    showColorbar = False

    if whatToShow == 'plot and heatmap':
        showSummaryPlot = True
    elif whatToShow == 'heatmap and colorbar':
        showColorbar = True
    elif whatToShow == 'plot, heatmap and colorbar':
        showSummaryPlot = True
        showColorbar = True

    # colormap for the heatmap
    if colorMapDict['colorMap']:
        cmap = []
        for color_map in colorMapDict['colorMap']:
            cmap.append(plt.get_cmap(color_map))
            cmap[-1].set_bad(colorMapDict['missingDataColor'])  # nans are printed using this color

    if colorMapDict['colorList'] and len(colorMapDict['colorList']) > 0:
        # make a cmap for each color list given
        cmap = []
        for color_list in colorMapDict['colorList']:
            cmap.append(matplotlib.colors.LinearSegmentedColormap.from_list(
                'my_cmap', color_list.replace(' ', '').split(","), N=colorMapDict['colorNumber']))
            cmap[-1].set_bad(colorMapDict['missingDataColor'])  # nans are printed using this color

    if len(cmap) > 1 or len(zMin) > 1 or len(zMax) > 1:
        # position color bar below heatmap when more than one
        # heatmap color is given
        colorbar_position = 'below'
    else:
        colorbar_position = 'side'

    grids = prepare_layout(hm.matrix, (heatmapWidth, heatmapHeight),
                           showSummaryPlot, showColorbar, perGroup, colorbar_position)

    # figsize: w,h tuple in inches
    figwidth = heatmapWidth / 2.54
    figheight = heatmapHeight / 2.54
    if showSummaryPlot:
        # the summary plot ocupies a height
        # equal to the fig width
        figheight += figwidth

    numsamples = hm.matrix.get_num_samples()
    if perGroup:
        num_cols = hm.matrix.get_num_groups()
    else:
        num_cols = numsamples
    total_figwidth = figwidth * num_cols
    if showColorbar:
        if colorbar_position == 'below':
            figheight += 1 / 2.54
        else:
            total_figwidth += 1 / 2.54

    fig = plt.figure(figsize=(total_figwidth, figheight))

    xticks, xtickslabel = getProfileTicks(hm, reference_point_label, startLabel, endLabel)

    xticks_heat, xtickslabel_heat = get_heatmap_ticks(hm, reference_point_label, startLabel, endLabel)
    fig.suptitle(plotTitle, y=1 - (0.06 / figheight))

    # color map for the summary plot (profile) on top of the heatmap
    cmap_plot = plt.get_cmap('jet')
    numgroups = hm.matrix.get_num_groups()
    if perGroup:
        color_list = cmap_plot(np.arange(hm.matrix.get_num_samples()) / hm.matrix.get_num_samples())
    else:
        color_list = cmap_plot(np.arange(numgroups) / numgroups)
    alpha = colorMapDict['alpha']

    # check if matrix is reference-point based using the upstream >0 value
    # and is sorted by region length. If this is
    # the case, prepare the data to plot a border at the regions end
    if hm.parameters['upstream'] > 0 and \
            hm.matrix.sort_using == 'region_length' and \
            hm.matrix.sort_method != 'no':

            _regions = hm.matrix.get_regions()
            regions_length_in_bins = []
            for _group in _regions:
                _reg_len = []
                for ind_reg in _group:
                    if isinstance(ind_reg, dict):
                        _len = ind_reg['end'] - ind_reg['start']
                    else:
                        _len = sum([x[1] - x[0] for x in ind_reg[1]])
                    _reg_len.append((hm.parameters['upstream'] + _len) / hm.parameters['bin size'])
                regions_length_in_bins.append(_reg_len)
    else:
        regions_length_in_bins = None

    # plot the profiles on top of the heatmaps
    if showSummaryPlot:
        if perGroup:
            iterNum = numgroups
            iterNum2 = hm.matrix.get_num_samples()
        else:
            iterNum = hm.matrix.get_num_samples()
            iterNum2 = numgroups
        ax_list = addProfilePlot(hm, plt, fig, grids, iterNum, iterNum2, perGroup, averageType, xticks, xtickslabel, yAxisLabel, color_list, yMin, yMax, None, None, colorbar_position)
        if len(yMin) > 1 or len(yMax) > 1:
            # replot with a tight layout
            import matplotlib.tight_layout as tl
            specList = tl.get_subplotspec_list(fig.axes, grid_spec=grids)
            renderer = tl.get_renderer(fig)
            kwargs = tl.get_tight_layout_figure(fig, fig.axes, specList, renderer, pad=1.08)

            for ax in ax_list:
                fig.delaxes(ax)

            ax_list = addProfilePlot(hm, plt, fig, grids, iterNum, iterNum2, perGroup, averageType, xticks, xtickslabel, yAxisLabel, color_list, yMin, yMax, kwargs['wspace'], kwargs['hspace'], colorbar_position)

        # reduce the number of yticks by half
        num_ticks = len(ax_list[0].get_yticks())
        yticks = [ax_list[0].get_yticks()[i] for i in range(1, num_ticks, 2)]
        ax_list[0].set_yticks(yticks)
        if legend_location != 'none':
            ax_list[-1].legend(loc=legend_location.replace('-', ' '), ncol=1, prop=fontP,
                               frameon=False, markerscale=0.5)

    first_group = 0  # helper variable to place the title per sample/group
    for sample in range(hm.matrix.get_num_samples()):
        sample_idx = sample
        for group in range(numgroups):
            group_idx = group
            # add the respective profile to the
            # summary plot
            sub_matrix = hm.matrix.get_matrix(group, sample)
            if showSummaryPlot:
                if perGroup:
                    sample_idx = sample + 2  # plot + spacer
                else:
                    group += 2  # plot + spacer
                first_group = 1

            if perGroup:
                ax = fig.add_subplot(grids[sample_idx, group])
                # the remainder (%) is used to iterate
                # over the available color maps (cmap).
                # if the user only provided, lets say two
                # and there are 10 groups, colormaps they are reused every
                # two groups.
                cmap_idx = group_idx % len(cmap)
                zmin_idx = group_idx % len(zMin)
                zmax_idx = group_idx % len(zMax)
            else:
                ax = fig.add_subplot(grids[group, sample])
                # see above for the use of '%'
                cmap_idx = sample % len(cmap)
                zmin_idx = sample % len(zMin)
                zmax_idx = sample % len(zMax)

            if group == first_group and not showSummaryPlot and not perGroup:
                title = hm.matrix.sample_labels[sample]
                ax.set_title(title)

            if box_around_heatmaps is False:
                # Turn off the boxes around the individual heatmaps
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['bottom'].set_visible(False)
                ax.spines['left'].set_visible(False)
            rows, cols = sub_matrix['matrix'].shape
            interpolation_type = None if rows >= 1000 and cols >= 200 else 'nearest'
            img = ax.imshow(sub_matrix['matrix'],
                            aspect='auto',
                            interpolation=interpolation_type,
                            origin='upper',
                            vmin=zMin[zmin_idx],
                            vmax=zMax[zmax_idx],
                            cmap=cmap[cmap_idx],
                            alpha=alpha,
                            extent=[0, cols, rows, 0])
            img.set_rasterized(True)
            # plot border at the end of the regions
            # if ordered by length
            if regions_length_in_bins is not None:
                x_lim = ax.get_xlim()
                y_lim = ax.get_ylim()

                ax.plot(regions_length_in_bins[group_idx],
                        np.arange(len(regions_length_in_bins[group_idx])),
                        '--', color='black', linewidth=0.5, dashes=(3, 2))
                ax.set_xlim(x_lim)
                ax.set_ylim(y_lim)

            if perGroup:
                ax.axes.set_xlabel(sub_matrix['group'])
                if sample < hm.matrix.get_num_samples() - 1:
                    ax.axes.get_xaxis().set_visible(False)
            else:
                ax.axes.get_xaxis().set_visible(False)
                ax.axes.set_xlabel(xAxisLabel)
            ax.axes.set_yticks([])
            if perGroup and group == 0:
                ax.axes.set_ylabel(sub_matrix['sample'])
            elif not perGroup and sample == 0:
                ax.axes.set_ylabel(sub_matrix['group'])

            # add labels to last block in a column
            if (perGroup and sample == numsamples - 1) or \
               (not perGroup and group_idx == numgroups - 1):

                # add xticks to the bottom heatmap (last group)
                ax.axes.get_xaxis().set_visible(True)
                if np.ceil(max(xticks_heat)) != float(sub_matrix['matrix'].shape[1]):
                    tickscale = float(sub_matrix['matrix'].shape[1]) / max(xticks_heat)
                    xticks_heat_use = [x * tickscale for x in xticks_heat]
                    ax.axes.set_xticks(xticks_heat_use)
                else:
                    ax.axes.set_xticks(xticks_heat)
                ax.axes.set_xticklabels(xtickslabel_heat, size=8)

                # align the first and last label
                # such that they don't fall off
                # the heatmap sides
                ticks = ax.xaxis.get_major_ticks()
                ticks[0].label1.set_horizontalalignment('left')
                ticks[-1].label1.set_horizontalalignment('right')

                ax.get_xaxis().set_tick_params(
                    which='both',
                    top='off',
                    direction='out')

                if showColorbar and colorbar_position == 'below':
                    # draw a colorbar for each heatmap below the last block
                    if perGroup:
                        col = group_idx
                    else:
                        col = sample
                    ax = fig.add_subplot(grids[-1, col])
                    tick_locator = ticker.MaxNLocator(nbins=3)
                    cbar = fig.colorbar(img, cax=ax, alpha=alpha, orientation='horizontal', ticks=tick_locator)
                    labels = cbar.ax.get_xticklabels()
                    ticks = cbar.ax.get_xticks()
                    if ticks[0] == 0:
                        # if the label is at the start of the colorbar
                        # move it a bit inside to avoid overlapping
                        # with other labels
                        labels[0].set_horizontalalignment('left')
                    if ticks[-1] == 1:
                        # if the label is at the end of the colorbar
                        # move it a bit inside to avoid overlapping
                        # with other labels
                        labels[-1].set_horizontalalignment('right')
                    # cbar.ax.set_xticklabels(labels, rotation=90)

    if showColorbar and colorbar_position != 'below':
        if showSummaryPlot:
            # we don't want the colorbar to extend
            # over the profiles and spacer top rows
            grid_start = 2
        else:
            grid_start = 0

        ax = fig.add_subplot(grids[grid_start:, -1])
        fig.colorbar(img, cax=ax, alpha=alpha)

    if box_around_heatmaps:
        plt.subplots_adjust(wspace=0.10, hspace=0.025, top=0.85, bottom=0, left=0.04, right=0.96)
    else:
        #  When no box is plotted the space between heatmaps is reduced
        plt.subplots_adjust(wspace=0.05, hspace=0.01, top=0.85, bottom=0, left=0.04, right=0.96)

    plt.savefig(outFileName, bbox_inches='tight', pdd_inches=0, dpi=dpi, format=image_format)
    plt.close()