예제 #1
0
파일: pdf.py 프로젝트: kkleidal/running
    def add_figure(self, figure: plt.Figure, caption: Optional[str] = None):
        """Render ``figure`` to PNG and queue it as an inline figure element.

        The figure is resized to 6x4 inches at 300 dpi, saved as PNG into an
        in-memory buffer, and appended to the element list as a base64
        ``data:image/png`` URL.

        Args:
            figure: Matplotlib figure to embed. As a side effect it is
                resized, its dpi is changed, and it is cleared afterwards.
            caption: Optional caption text. When given, it is stored under
                the ``"caption"`` key of the queued element.
        """
        figure.set_size_inches(6, 4)
        figure.set_dpi(300)

        with io.BytesIO() as buf:
            figure.savefig(buf, format="png")
            encoded = base64.b64encode(buf.getvalue()).decode("utf8")

        element = {"kind": "figure", "url": "data:image/png;base64,%s" % encoded}
        if caption is not None:
            # The caption used to be accepted and then silently discarded.
            element["caption"] = caption
        self.__elements.append(element)

        # Clear the figure that was just rendered rather than whichever
        # figure pyplot currently considers "current".
        figure.clf()
예제 #2
0
def _mpl_figure_to_rgb_img(fig: plt.Figure, height, width):
    """Rasterize a matplotlib figure into an RGB ``uint8`` image array.

    The figure is set to 100 dpi and resized so that it renders at roughly
    ``width`` x ``height`` pixels, drawn on its canvas, and then closed.

    Args:
        fig: Figure to render. It is resized and closed as a side effect.
        height: Requested image height in pixels.
        width: Requested image width in pixels.

    Returns:
        A writable array of shape ``(height, width, 3)`` and dtype ``uint8``,
        where height/width are the pixel dimensions the figure actually
        ended up with after resizing.
    """
    fig.set_dpi(100)
    fig.set_size_inches(width / 100, height / 100)

    canvas = fig.canvas
    canvas.draw()
    # Re-derive the pixel size from the figure itself so the reshape below
    # matches the rendered buffer even if the requested size was rounded.
    width, height = np.round(fig.get_size_inches() * fig.get_dpi()).astype(int)

    if hasattr(canvas, "tostring_rgb"):
        # np.frombuffer replaces the deprecated binary mode of np.fromstring;
        # .copy() keeps the result writable, as it was before.
        img = np.frombuffer(canvas.tostring_rgb(),
                            dtype='uint8').reshape(height, width, 3).copy()
    else:
        # Newer matplotlib releases no longer provide tostring_rgb; take the
        # RGBA buffer instead and drop the alpha channel.
        img = np.asarray(canvas.buffer_rgba(), dtype='uint8')[..., :3].copy()
    plt.close(fig)
    return img
예제 #3
0
 def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_groups: int) -> None:
     """Prepare the matplotlib figure before the plots are drawn.

     Clears the current output, sizes the figure from ``max_cols`` and
     ``cell_size``, and hides the surplus axes in the last row of ``axes``.

     Args:
         fig: matplotlib Figure
         axes: array of matplotlib Axes; its last row is iterated, so a
             2-D grid is presumably expected (confirm against caller)
         num_of_log_groups: number of log groups
     """
     clear_output(wait=True)

     width = self.max_cols * self.cell_size[0]
     row_count = (num_of_log_groups + 1) // self.max_cols + 1
     fig.set_size_inches(width, row_count * self.cell_size[1])

     # Nothing to hide when the log groups fill every available axis.
     if num_of_log_groups >= axes.size:
         return

     # Column index in the last row from which axes are left unused.
     first_unused = (num_of_log_groups + len(self.extra_plots)) % self.max_cols
     for column, last_row_ax in enumerate(axes[-1]):
         if column < first_unused:
             continue
         last_row_ax.set_visible(False)
예제 #4
0
def histogram2d(
    data: pd.DataFrame,
    column1: str,
    column2: str,
    fig: plt.Figure = None,
    ax: plt.Axes = None,
    fig_width: int = 6,
    fig_height: int = 6,
    trend_line: str = "auto",
    lower_quantile1: float = 0,
    upper_quantile1: float = 1,
    lower_quantile2: float = 0,
    upper_quantile2: float = 1,
    transform1: str = "identity",
    transform2: str = "identity",
    equalize_axes: bool = False,
    reference_line: bool = False,
    plot_density: bool = False,
) -> Tuple[plt.Figure, plt.Axes, p9.ggplot]:
    """
    Creates an EDA plot (2d histogram) for two continuous variables.

    Args:
        data: pandas DataFrame containing data to be plotted
        column1: name of column to plot on the x axis
        column2: name of column to plot on the y axis
        fig: matplotlib Figure generated from blank ggplot to plot onto. If specified, must also specify ax
        ax: matplotlib axes generated from blank ggplot to plot onto. If specified, must also specify fig
        fig_width: figure width in inches
        fig_height: figure height in inches. Ignored when ``equalize_axes`` is set, in which case the
            figure is sized ``fig_width`` x ``fig_width``.
        trend_line: Trend line to plot over data. Pass ``'none'`` to plot no trend line. Any other value
            (default ``'auto'``) is passed as ``method`` to
            `geom_smooth <https://plotnine.readthedocs.io/en/stable/generated/plotnine.geoms.geom_smooth.html>`_.
        lower_quantile1: Lower quantile of column1 data to remove before plotting for ignoring outliers
        upper_quantile1: Upper quantile of column1 data to remove before plotting for ignoring outliers
        lower_quantile2: Lower quantile of column2 data to remove before plotting for ignoring outliers
        upper_quantile2: Upper quantile of column2 data to remove before plotting for ignoring outliers
        transform1: Transformation to apply to the column1 data for plotting:

         - **'identity'**: no transformation
         - **'log'**: apply a logarithmic transformation with small constant added in case of zero values
         - **'log_exclude0'**: apply a logarithmic transformation with zero values removed
         - **'sqrt'**: apply a square root transformation
        transform2: Transformation to apply to the column2 data for plotting. Same options as for column1.
        equalize_axes: Square the aspect ratio and match the axis limits
        reference_line: Add a y = x reference line
        plot_density: Overlay a 2d density on the given plot

    Returns:
        Tuple containing matplotlib figure and axes along with the plotnine ggplot object

    Raises:
        ValueError: If only one of ``fig`` and ``ax`` is provided.

    Examples:
        .. plot::

            import pandas as pd
            import intedact
            data = pd.read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2018/2018-09-11/cats_vs_dogs.csv")
            intedact.histogram2d(data, 'n_dog_households', 'n_cat_households', equalize_axes=True, reference_line=True);
    """
    # Fail fast with a clear message: drawing onto an existing figure needs
    # both objects, and a lone fig/ax would otherwise fail deep inside plotnine.
    if (fig is None) != (ax is None):
        raise ValueError("fig and ax must either both be specified or both be omitted")

    data = trim_quantiles(data,
                          column1,
                          lower_quantile=lower_quantile1,
                          upper_quantile=upper_quantile1)
    data = trim_quantiles(data,
                          column2,
                          lower_quantile=lower_quantile2,
                          upper_quantile=upper_quantile2)
    data = preprocess_transformations(data, column1, transform=transform1)
    data = preprocess_transformations(data, column2, transform=transform2)

    # draw the 2d histogram
    gg = p9.ggplot(data, p9.aes(x=column1, y=column2)) + p9.geom_bin2d()

    # overlay density
    if plot_density:
        gg += p9.geom_density_2d()

    # add reference line
    if reference_line:
        gg += p9.geom_abline(color="black")

    # add trend line
    if trend_line != "none":
        gg += p9.geom_smooth(method=trend_line, color="red")

    # blank out the fill legend title
    gg += p9.labs(fill="")

    # handle axes transforms
    gg, xlabel = transform_axis(gg, column1, transform1, xaxis=True)
    gg, ylabel = transform_axis(gg, column2, transform2, xaxis=False)

    if fig is None:
        # No target supplied: let plotnine create the figure and pick it up
        # from pyplot afterwards.
        gg.draw()
        fig = plt.gcf()
        ax = fig.axes[0]
    else:
        _ = gg._draw_using_figure(fig, [ax])

    if equalize_axes:
        fig, ax, gg = match_axes(fig, ax, gg)
        # Square figure: the width is deliberately used for both dimensions.
        fig.set_size_inches(fig_width, fig_width)
    else:
        fig.set_size_inches(fig_width, fig_height)

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    return fig, ax, gg
예제 #5
0
파일: ui.py 프로젝트: jiejohn/paddle
class MplView(FigureCanvas, QWidget):
    """
    Base class for matplotlib based views. This handles graph canvas setup, toolbar initialisation
    and figure save options.

    Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).
    """
    is_floatable_view = True
    is_mpl_toolbar_enabled = True

    def __init__(self, parent, width=5, height=5, dpi=96, **kwargs):
        """Create the figure with a single subplot and attach the canvas to ``parent``.

        Args:
            parent: Parent widget; also kept as ``self.v``.
            width: Figure width in inches.
            height: Figure height in inches.
            dpi: Figure resolution in dots per inch.
            **kwargs: Accepted for subclass compatibility; not used here.
        """
        self.v = parent

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.ax = self.fig.add_subplot(111)
        FigureCanvas.__init__(self, self.fig)

        self.setParent(parent)

        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Install navigation handler; we need to provide a Qt interface that can handle multiple
        # plots in a window under separate tabs
        # self.navigation = MplNavigationHandler(self)

        self._current_axis_bounds = None

    def generate(self):
        """Hook for subclasses to build the plot; does nothing in the base class."""
        pass

    def saveAsImage(self, settings): # Size, dots per metre (for print), resample (redraw) image
        """Ask for a file name and save the figure at the configured print size and dpi.

        The figure is temporarily resized to the print size from ``settings`` and
        restored to its previous size afterwards, even if saving fails.

        Args:
            settings: Object providing ``get_print_size('in')`` and
                ``get_dots_per_inch()``.
        """
        # Build the filter from separate entries; backslash continuations inside
        # one literal would embed the source indentation in every entry.
        file_filter = ";;".join((
            "Tagged Image File Format (*.tif)",
            "Portable Document File (*.pdf)",
            "Encapsulated Postscript File (*.eps)",
            "Scalable Vector Graphics (*.svg)",
            "Portable Network Graphics (*.png)",
        ))
        filename, _ = QFileDialog.getSaveFileName(self, 'Save current figure', '', file_filter)

        if not filename:
            return

        size = settings.get_print_size('in')
        dpi = settings.get_dots_per_inch()
        prev_size = self.fig.get_size_inches()
        self.fig.set_size_inches(*size)
        try:
            self.fig.savefig(filename, dpi=dpi)
        finally:
            # Always put the on-screen figure back, even when savefig raises.
            self.fig.set_size_inches(*prev_size)
            self.redraw()

    def get_size_inches(self, dpi):
        """Return the widget size as ``(width, height)`` in inches for the given dpi."""
        s = self.size()
        return s.width()/dpi, s.height()/dpi

    def redraw(self):
        """Force a canvas refresh by shrinking and re-growing the widget by one pixel."""
        #FIXME: Ugly hack to refresh the canvas
        self.resize( self.size() - QSize(1,1) )
        self.resize( self.size() + QSize(1,1) )

    def resizeEvent(self,e):
        """Delegate resize handling to the matplotlib canvas."""
        FigureCanvas.resizeEvent(self,e)

    def get_text_bbox_screen_coords(self, t):
        """Return the corner points of artist ``t``'s window extent."""
        bbox = t.get_window_extent(self.get_renderer())
        return bbox.get_points()

    def get_text_bbox_data_coords(self, t):
        """Return the corner points of artist ``t``'s extent in ``self.ax`` data coordinates."""
        bbox = t.get_window_extent(self.get_renderer())
        axbox = bbox.transformed(self.ax.transData.inverted())
        return axbox.get_points()

    def extend_limits(self, a, b):
        """Extend the limits ``a`` so that they also cover ``b``.

        Args:
            a: Pair ``(xlim, ylim)``, each a ``(low, high)`` sequence.
            b: Indexed as ``b[:, 0]`` / ``b[:, 1]`` for x and y values;
                presumably a 2x2 array of bbox points (confirm against caller).

        Returns:
            ``[xlim, ylim]`` as lists, widened where ``b`` reaches further.
        """
        # Extend a to meet b where applicable
        xlim, ylim = list(a[0]), list(a[1])
        bx, by = b[:, 0], b[:, 1]

        xlim[0] = min(xlim[0], bx[0])
        xlim[1] = max(xlim[1], bx[1])

        ylim[0] = min(ylim[0], by[0])
        ylim[1] = max(ylim[1], by[1])

        return [xlim, ylim]

    def adjust_tight_bbox(self, pad=0.1, extra_artists=None):
        """Apply ``tight_layout`` with a rect that accounts for extra artists.

        Args:
            pad: Padding passed to ``Bbox.padded`` on the combined bounding box;
                skipped when falsy.
            extra_artists: Additional artists to include. Defaults to the
                legends of all axes in the figure.
        """
        # Use get_renderer() like the other methods of this class instead of
        # reading the renderer attribute directly.
        renderer = self.get_renderer()
        bbox_inches = self.figure.get_tightbbox(renderer)
        bbox_artists = self.figure.get_default_bbox_extra_artists()

        if extra_artists is None:
            extra_artists = [ax.get_legend() for ax in self.figure.axes if ax.get_legend()]

        bbox_artists.extend(extra_artists)
        bbox_filtered = []
        for a in bbox_artists:
            bbox = a.get_window_extent(renderer)
            if a.get_clip_on():
                # Restrict the artist's extent to its clip box / clip path.
                clip_box = a.get_clip_box()
                if clip_box is not None:
                    bbox = Bbox.intersection(bbox, clip_box)
                clip_path = a.get_clip_path()
                if clip_path is not None and bbox is not None:
                    clip_path = clip_path.get_fully_transformed_path()
                    bbox = Bbox.intersection(bbox,
                                             clip_path.get_extents())
            # Skip artists whose extent is missing or has no area at all.
            if bbox is not None and (bbox.width != 0 or
                                     bbox.height != 0):
                bbox_filtered.append(bbox)

        if bbox_filtered:
            _bbox = Bbox.union(bbox_filtered)
            # Scale the artists' extents by 1/dpi before merging with bbox_inches.
            trans = Affine2D().scale(1.0 / self.figure.dpi)
            bbox_extra = TransformedBbox(_bbox, trans)
            bbox_inches = Bbox.union([bbox_inches, bbox_extra])

        if pad:
            bbox_inches = bbox_inches.padded(pad)

        # Bounds divided pairwise by the figure size, flattened to 4 values.
        rect = (np.array(bbox_inches.bounds).reshape(-1,2) / self.figure.get_size_inches()).flatten()

        # Adjust the rect; values <0 to +; + to zero
        xpad = -np.min((rect[0], (1-rect[2])))
        xpad = 0 if xpad < 0 else xpad
        ypad = -np.min((rect[1], (1-rect[3])))
        ypad = 0 if ypad < 0 else ypad
        rect = np.array([ xpad, ypad, 1-xpad, 1-ypad ])

        self.figure.tight_layout(rect=np.abs(rect))
예제 #6
0
def save_fig(fig: plt.Figure, name: str, size: Tuple[float, float] = (3.5, 2)):
    """Save ``fig`` as both ``<name>.pgf`` and ``<name>.pdf`` in the figure directory.

    Args:
        fig: Figure to save; it is resized as a side effect.
        name: File name without extension.
        size: ``(width, height)`` of the saved figure in inches.
    """
    # Resize first so tight_layout computes its margins for the final size.
    fig.set_size_inches(w=size[0], h=size[1])
    fig.tight_layout()

    target_dir = figure_dir()
    for extension in (".pgf", ".pdf"):
        fig.savefig(os.path.join(target_dir, name + extension))
예제 #7
0
    def reset_figure(fig: pyplot.Figure, ax: pyplot.Axes) -> None:
        """Clear the canvas and reset it to its initial state.

        Clears the figure's current axes, restores the title and axis labels,
        resizes the figure to 8x6 inches, repositions ``ax`` and attaches a
        legend of placeholder handles to its right.
        """
        current_axes = fig.gca()
        current_axes.cla()
        current_axes.set_title('Wireless Sensor Networks')
        current_axes.set_xlabel('x')
        current_axes.set_ylabel('y')
        fig.set_size_inches(8, 6)
        ax.set_position((0.1, 0.11, 0.6, 0.8))

        # (color, label) for the dot markers, one per node state.
        node_styles = (
            ('red', 'source'),
            ('green', 'alive'),
            ('orange', 'received'),
            ('yellow', 'replied'),
            ('blue', 'sending'),
            ('black', 'dead'),
        )
        # (color, label) for the translucent signal-range patches.
        range_styles = (
            ('red', 'range of signal\n(source node)'),
            ('green', 'range of signal\n(alive node)'),
            ('orange', 'range of signal\n(received node)'),
            ('yellow', 'range of signal\n(replied node)'),
            ('blue', 'range of signal\n(sending node)'),
        )

        # Empty artists: they exist only to provide legend entries.
        node_handles = tuple(
            pyplot.Line2D(xdata=[],
                          ydata=[],
                          marker='.',
                          linewidth=0,
                          color=node_color,
                          label=node_label)
            for node_color, node_label in node_styles
        )
        range_handles = tuple(
            pyplot.Circle(xy=(0, 0),
                          radius=0,
                          alpha=0.4,
                          color=range_color,
                          label=range_label)
            for range_color, range_label in range_styles
        )

        ax.legend(handles=node_handles + range_handles,
                  loc='upper left',
                  bbox_to_anchor=(1.02, 1),
                  borderaxespad=0)