예제 #1
0
def save_figure(
        fig_object: matplotlib.figure.Figure,
        file_path: Text) -> None:
    """Persist a figure both as a PDF and as a pickle.

    The PDF is the rendered output. The pickle stores the figure object
    itself so that it can be loaded again later and modified.

    Args:
        fig_object: Figure to save (for example the result of
            ``plt.figure()``).
        file_path: Destination path without a file extension; ``.pdf`` and
            ``.pkl`` are appended to it.

    Returns:
        None.
    """
    pdf_path = file_path + '.pdf'
    pickle_path = file_path + '.pkl'

    # Render the PDF at the figure's own resolution.
    fig_object.savefig(pdf_path, dpi=fig_object.dpi)

    # Serialise the figure object with the newest available pickle protocol.
    with open(pickle_path, 'wb') as pickle_file:
        pk.dump(fig_object, pickle_file, protocol=pk.HIGHEST_PROTOCOL)
예제 #2
0
def set_dim(
    fig: matplotlib.figure.Figure,
    width: float = 398.3386,
    fraction_of_line_width: float = 1,
    ratio: float = (5 ** 0.5 - 1) / 2,
) -> None:
    """Resize a figure so it needs no rescaling when included in LaTeX.

    The size itself is computed by ``get_dim``; this function only applies
    the result to ``fig``.

    Args:
        fig (matplotlib.figure.Figure): Figure object to resize (modified in
            place).
        width (float): Text width of the report, so that font sizes match.
            Defaults to 398.3386 (documented upstream as
            `src.constants.REPORT_WIDTH`).
        fraction_of_line_width (float, optional): Fraction of the document
            width the figure should occupy. Defaults to 1.
        ratio (float, optional): Figure height as a fraction of figure width.
            Defaults to the golden ratio, (5 ** 0.5 - 1) / 2.

    Returns:
        None. The figure passed in is altered.

    Example:
        Here is an example of using this function::
            >>> set_dim(fig, fraction_of_line_width=1, ratio=(5 ** 0.5 - 1) / 2)
    """
    size_in_inches = get_dim(
        width=width,
        fraction_of_line_width=fraction_of_line_width,
        ratio=ratio,
    )
    fig.set_size_inches(size_in_inches)
예제 #3
0
def update_figure(fig: mpl.figure.Figure,
                  params: Dict[str, str] = False) -> mpl.figure.Figure:
    """
    Apply a small set of matplotlib settings to ``fig`` and return it.

    With the default ``params=False`` the y scale is set to ``"log"`` and
    the global ``figure.dpi`` rcParam is set to 300. Otherwise ``params`` is
    read as a mapping; each of the recognised keys below is applied only if
    present:

        "rcparams"   -> forwarded as keyword arguments to mpl.rcParams.update
        "set_xlim"   -> fig.set_xlim(value)
        "set_ylim"   -> fig.set_ylim(value)
        "set_xscale" -> fig.set_xscale(value)
        "set_yscale" -> fig.set_yscale(value)

    The setters are looked up on ``fig`` itself, so the object passed in must
    provide them.
    NOTE(review): these are Axes-style setters - confirm which object type
    callers actually pass.

    ``plt.tight_layout()`` is always called before returning.

    Inputs:
        fig: object to update (annotated as a matplotlib Figure)
        params: False for the defaults, or a dictionary as described above

    Returns the same ``fig`` object.
    """
    if params is False:
        fig.set_yscale("log")
        mpl.rcParams.update({"figure.dpi": 300})
    else:
        if "rcparams" in params:
            mpl.rcParams.update(**params["rcparams"])
        # Each remaining key is the name of the setter to call on `fig`.
        for setter_name in ("set_xlim", "set_ylim", "set_xscale", "set_yscale"):
            if setter_name in params:
                getattr(fig, setter_name)(params[setter_name])

    plt.tight_layout()
    return fig
예제 #4
0
def hfss_plot_convergences_report(convergence_t: pd.core.frame.DataFrame,
                                  convergence_f: pd.core.frame.DataFrame,
                                  fig: mpl.figure.Figure = None,
                                  _display=True):
    """Draw a three-panel convergence report on a newly created figure.

    Nothing is drawn when a figure is supplied: the plots are only produced
    when ``fig`` is None, in which case an 11 x 3 inch figure with three
    side-by-side panels is created:

    * left: ``plot_convergence_f_vspass`` with ``convergence_f``;
    * middle: ``plot_convergence_max_df`` with column 1 of ``convergence_t``
      and, on a twin y axis, ``plot_convergence_solved_elem`` with column 0;
    * right: ``plot_convergence_maxdf_vs_sol`` with columns 1 and 0 of
      ``convergence_t``.

    Args:
        convergence_t (pandas.core.frame.DataFrame): Convergence table; only
            its first two columns are used, by position.
        convergence_f (pandas.core.frame.DataFrame): Convergence of the
            eigenmode frequencies vs. pass number.
        fig (matplotlib.figure.Figure, optional): A mpl figure. Defaults to None.
        _display (bool, optional): Not used by this function. Defaults to True.
    """
    if fig is not None:
        # A caller-supplied figure is left untouched.
        return

    fig = plt.figure(figsize=(11, 3.))

    # One row, three columns of unequal width.
    grid = mpl.gridspec.GridSpec(1, 3, width_ratios=[1.2, 1.5, 1])
    ax_left, ax_middle, ax_right = (fig.add_subplot(grid[col])
                                    for col in range(3))

    # The middle panel carries two quantities on separate y axes.
    ax_middle_twin = ax_middle.twinx()

    plot_convergence_f_vspass(ax_left, convergence_f)
    plot_convergence_max_df(ax_middle, convergence_t.iloc[:, 1])
    plot_convergence_solved_elem(ax_middle_twin, convergence_t.iloc[:, 0])
    plot_convergence_maxdf_vs_sol(ax_right, convergence_t.iloc[:, 1],
                                  convergence_t.iloc[:, 0])

    fig.tight_layout(w_pad=0.1)
예제 #5
0
def _save_fig(dst_path: str, figure: matplotlib.figure.Figure) -> None:
    """Save `figure` as a PNG at `dst_path`, then close the figure.

    The image is first written to a local temporary directory and then
    copied to `dst_path` with `tf.io.gfile.copy`, because `savefig` does not
    support GCS paths. Whether an existing destination is replaced is
    controlled by `FLAGS.overwrite`.

    Args:
        dst_path: Destination path of the PNG file.
        figure: Figure to save. It is closed before returning, including
            when saving or copying raises.
    """
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, 'tmp.png')
            figure.savefig(tmp_path)
            tf.io.gfile.copy(tmp_path, dst_path, overwrite=FLAGS.overwrite)
    finally:
        # Always release the figure so a failed save does not leak it.
        plt.close(figure)
예제 #6
0
def fig_to_image(fig: mpl.figure.Figure, **kwargs: dict) -> IO:
    """Render a figure as SVG into an in-memory binary stream.

    Args:
        fig: Figure to render.
        **kwargs: Extra keyword arguments forwarded to ``fig.savefig``.
            ``bbox_inches`` defaults to ``"tight"`` when not supplied. The
            output format is always ``"svg"``.

    Returns:
        A ``BytesIO`` holding the SVG data, rewound to position 0.
    """
    kwargs.setdefault("bbox_inches", "tight")

    buffer = BytesIO()
    fig.savefig(buffer, format="svg", **kwargs)

    # Rewind so the caller can read from the start.
    buffer.seek(0)
    return buffer
예제 #7
0
def svg_from_mpl_axes(fig: mpl.figure.Figure) -> QByteArray:
    """Serialise a matplotlib figure as SVG and return it as a Qt byte array.

    The figure is closed through pyplot once it has been rendered, so it
    should not be used again after this call.
    """
    with io.BytesIO() as svg_buffer:
        fig.savefig(svg_buffer, format="svg")
        plt.close(fig)
        svg_bytes = svg_buffer.getvalue()

    return QByteArray(svg_bytes)
예제 #8
0
def fig_square(fig: mpl.figure.Figure) -> None:
    """Give the figure a 1:1 aspect ratio and apply a tight layout.

    The new size is whatever ``mpl.figure.figaspect(1)`` returns.

    Parameters
    ----------
    fig : mpl.figure.Figure
        `Figure` instance; modified in place.
    """
    square_size = mpl.figure.figaspect(1)
    fig.set_size_inches(square_size)
    fig.tight_layout()
예제 #9
0
def save_figure(filename: Union[str, IO[bytes]],
                fig: matplotlib.figure.Figure,
                quality: Optional[int] = 90,
                optimize: Optional[bool] = True) -> Tuple[int, int]:
    """Use Agg to save a figure to an image file and return its dimensions.

    Args:
        filename: Path or binary file object to write the image to.
        fig: Figure to render and save.
        quality: Encoder quality setting, forwarded to Pillow. ``None``
            omits the option.
        optimize: Encoder optimize flag, forwarded to Pillow. ``None`` omits
            the option.

    Returns:
        ``(width, height)`` of the figure's bounding box after drawing,
        truncated to integers.
    """
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    _, _, width, height = canvas.figure.bbox.bounds

    # `savefig` no longer accepts `quality=` / `optimize=` directly; Pillow
    # encoder options are passed through `pil_kwargs` instead. Options set
    # to None are left out so Pillow falls back to its own defaults.
    pil_options = {
        key: value
        for key, value in (("quality", quality), ("optimize", optimize))
        if value is not None
    }
    fig.savefig(filename, pil_kwargs=pil_options)
    return int(width), int(height)
예제 #10
0
def figure_to_base64str(fig: matplotlib.figure.Figure) -> str:
    """Encode a Matplotlib figure as a base64 string of its PNG rendering.

    Args:
        fig: A matplotlib Figure.

    Returns:
        The base64 encoding (ASCII text) of the PNG bytes produced by
        ``fig.savefig`` with a tight bounding box.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, bbox_inches='tight', format='png')
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode('ascii')
예제 #11
0
def save_fig(fig: mpl.figure.Figure,
             filename: str,
             directory: Optional[str] = None,
             add_date: bool = False,
             **savefig_kwargs):
    """Saves a Matplotlib figure to disk.

    Args:
        fig: Figure to save.
        filename: Name of the output file.
        directory: Directory to save into. When None (the default),
            `filename` is used as the path unchanged.
        add_date: If True, prefix `filename` with today's date as
            ``YYYYMMDD_``.
        **savefig_kwargs: Forwarded to ``fig.savefig``. ``bbox_inches``
            defaults to ``"tight"`` and ``pad_inches`` to 0.02 unless given.
    """
    if add_date:
        filename = "{}_{}".format(time.strftime("%Y%m%d"), filename)
    savefig_kwargs.setdefault("bbox_inches", "tight")
    savefig_kwargs.setdefault("pad_inches", 0.02)

    # os.path.join(None, ...) raises TypeError, so the documented default of
    # `directory=None` must bypass the join.
    if directory is None:
        save_path = filename
    else:
        save_path = os.path.join(directory, filename)
    fig.savefig(save_path, **savefig_kwargs)
예제 #12
0
def multi_nyquist_title(fig: matplotlib.figure.Figure, params: paramsTup,
                        settings) -> matplotlib.figure.Figure:
    """Set the super-title of a Nyquist-plot figure and return the figure.

    ``params`` must unpack into exactly six items:
    ``(filename_strip, solv, elec, ref_elec, work_elec, voltage)``. Only the
    solvent, reference electrode and working electrode are used in the
    title. ``settings`` is accepted but not used.
    """
    _filename_strip, solv, _elec, ref_elec, work_elec, _voltage = params

    # NOTE(review): the electrolyte supplied in `params` is discarded in
    # favour of this literal - confirm that the override is intended.
    elec = "TBFTFB"

    title = f"Nyquist plots of {work_elec} in {elec} ({solv}) vs {ref_elec} reference"
    fig.suptitle(title, y=0.98)
    return fig
예제 #13
0
def setsize(fig: matplotlib.figure.Figure,
            width: float = 11.75,
            height: float = 8.25) -> None:
    """
    Set the size of a figure.

    Without extra arguments the figure is set to 11.75 x 8.25 inches, the
    size this function has always used (labelled "a4" by the original
    author).

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be resized (modified in place).
    width : float
        Figure width in inches.
    height : float
        Figure height in inches.
    """
    fig.set_size_inches(width, height)
예제 #14
0
    def subplot_data(fig: matplotlib.figure.Figure,
                     subplot_params: tuple,
                     timestamps: list,
                     data: np.array,
                     x_label: str,
                     y_label: str,
                     event_timestamps: list = None,
                     data_label: str = None,
                     fontsize: int = 30,
                     event_annotation_color='g') -> None:
        '''Plots a given time series data on a subplot of the given figure.

        All drawing is done on the axes returned by ``fig.add_subplot``, so
        the data ends up on ``fig`` even when another figure is pyplot's
        current figure.

        Keyword arguments:
        fig -- matplotlib figure handle
        subplot_params -- a tuple with three subplot parameters: number of rows,
                          number of columns, and the number of the current subplot
        timestamps -- a list of timestamps (plotted over the x-axis)
        data -- a one-dimensional numpy array of data items (plotted over the y-axis)
        x_label -- x axis label
        y_label -- y axis label
        event_timestamps -- optional list of timestamps for data
                            annotation; the annotations are plotted as
                            lines at the given timestamps
                            (default None, in which case there are no annotations)
        data_label -- optional label for the plotted data; if a label is given,
                      a legend is added to the plot (default None)
        fontsize -- fontsize for the x and y axes labels (default 30)
        event_annotation_color -- color of event annotation lines (default 'g')

        '''
        ax = fig.add_subplot(*subplot_params)

        # the label is only passed along when one was given
        if data_label:
            ax.plot(timestamps, data, label=data_label)
        else:
            ax.plot(timestamps, data)

        # one vertical line per event, spanning the (NaN-ignoring) data range
        if event_timestamps is not None and len(event_timestamps) > 0:
            ax.plot([event_timestamps, event_timestamps],
                    [np.nanmin(data), np.nanmax(data)],
                    event_annotation_color)

        ax.set_xlabel(x_label, fontsize=fontsize)
        ax.set_ylabel(y_label, fontsize=fontsize)

        # a legend only makes sense if a data label is assigned
        if data_label:
            ax.legend()
예제 #15
0
def savefig(fig: matplotlib.figure.Figure, filename: str) -> None:
    """
    Save a single figure to file at 300 dpi.

    A progress message is printed first, and pyplot's ``show`` is invoked in
    non-blocking mode before the file is written.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be saved.
    filename : str
        Name of the file to be written.
    """
    print(f"saving figure {filename}")
    matplotlib.pyplot.show(block=False)
    fig.savefig(filename, dpi=300)
예제 #16
0
def style_figure(fig: matplotlib.figure.Figure, title: str, bottom: float = 0.2) -> None:
    """Apply the standard layout, title and legend to a matplotlib Figure.

    Args:
        fig: Figure to style (modified in place).
        title: Text for the figure's super-title.
        bottom: Bottom margin passed to ``subplots_adjust``; pass None to
            leave the margins produced by ``tight_layout`` untouched.
    """
    legend_options = dict(
        ncol=10,
        handlelength=0.75,
        handletextpad=0.25,
        columnspacing=0.5,
        loc='lower left',
    )

    fig.tight_layout()
    if bottom is not None:
        fig.subplots_adjust(bottom=bottom)
    fig.suptitle(title)
    fig.legend(**legend_options)
예제 #17
0
def _save_figure(fig: mpl.figure.Figure,
                 path: pathlib.Path,
                 close_figure: bool = False) -> None:
    """Write a figure to ``path``, creating parent directories as needed.

    Args:
        fig (matplotlib.figure.Figure): figure to save
        path (pathlib.Path): path where figure should be saved; missing
            parent directories are created
        close_figure (bool, optional): whether to close figure after saving, defaults to
            False
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    log.debug(f"saving figure as {path}")
    fig.savefig(path)

    if not close_figure:
        return
    plt.close(fig)
예제 #18
0
def add_parameter_details(f: matplotlib.figure.Figure, details: str, y: float):
    """Write ``key = value`` pairs from a JSON object onto a figure.

    ``details`` is decoded with ``json.loads`` and each item is rendered as
    ``"key = value"``. Items are comma separated; a line break is inserted
    once the running line length (item length plus two for the separator)
    exceeds 60 characters. The text is placed at ``x = 0.05`` and the given
    ``y``.

    Nothing is drawn when ``details`` is empty/None or decodes to an object
    without items.

    Args:
        f: Figure to annotate.
        details: JSON-encoded object with the parameters to show.
        y: Vertical position passed to ``f.text``.
    """
    if not details:
        return
    elements = ["{} = {}".format(k, v) for k, v in json.loads(details).items()]
    if not elements:
        # e.g. details == "{}": there is nothing to show, and indexing the
        # last element below would raise IndexError.
        return

    pieces = []
    line_length = 0
    for el in elements[:-1]:
        line_length += len(el) + 2
        if line_length > 60:
            pieces.append(el + ',\n')
            line_length = 0
        else:
            pieces.append(el + ', ')
    # The last item carries no trailing separator.
    pieces.append(elements[-1])
    f.text(0.05, y, "".join(pieces))
예제 #19
0
파일: rdm_plot.py 프로젝트: caiw/pyrsa
def _rdm_colorbar(
    mappable: matplotlib.cm.ScalarMappable = None,
    fig: matplotlib.figure.Figure = None,
    ax: Axes = None,
    title: str = None,
) -> matplotlib.colorbar.Colorbar:
    """Add a small colorbar with three ticks to an rdm figure.

    The colorbar is created with ``fig.colorbar`` using ``shrink=0.25`` and
    ``aspect=5``; its title is left aligned with normal font weight. Used
    internally by show_rdm.

    Args:
        mappable (matplotlib.cm.ScalarMappable): Typically plt.imshow instance.
        fig (matplotlib.figure.Figure): Matplotlib figure handle. Although it
            defaults to None, a figure is required because the colorbar is
            created through it.
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle, passed
            through to ``fig.colorbar``.
        title (str): Title string for the colorbar (positioned top, left aligned).

    Returns:
        matplotlib.colorbar.Colorbar: Matplotlib handle.
    """
    tick_locator = matplotlib.ticker.LinearLocator(numticks=3)
    colorbar = fig.colorbar(
        mappable=mappable,
        ax=ax,
        shrink=0.25,
        aspect=5,
        ticks=tick_locator,
    )

    title_font = {"fontweight": "normal"}
    colorbar.ax.set_title(title, loc="left", fontdict=title_font)
    return colorbar
예제 #20
0
def setup_ax(fig: matplotlib.figure.Figure,
             pos: tuple[int, int, int] = (1, 1, 1),
             ) -> matplotlib.axes.Axes:
    '''
    Adds a PlateCarree map axes to the figure and decorates it with land,
    country borders, coastline, lakes, ocean, rivers and state/province
    borders. Gridlines are drawn with labels on the bottom and left only.

    Args
    - fig: matplotlib.figure.Figure
    - pos: 3-tuple of int, axes position (nrows, ncols, index)

    Returns: matplotlib.axes.Axes
    '''
    ax = fig.add_subplot(*pos, projection=ccrs.PlateCarree())

    def add_natural_earth(category, name, scale, **style):
        # Build one Natural Earth feature and attach it to the axes.
        ax.add_feature(cfeature.NaturalEarthFeature(
            category=category, name=name, scale=scale, **style))

    # land uses the stock cartopy feature
    ax.add_feature(cfeature.LAND)

    # borders of countries
    add_natural_earth('cultural', 'admin_0_boundary_lines_land', '10m',
                      edgecolor='black', facecolor='none')

    # coastline
    add_natural_earth('physical', 'coastline', '10m',
                      edgecolor='black', facecolor='none')

    # lakes
    add_natural_earth('physical', 'lakes', '10m',
                      edgecolor='face', facecolor=cfeature.COLORS['water'])

    # ocean (coarser scale, drawn with zorder=-1)
    add_natural_earth('physical', 'ocean', '50m',
                      edgecolor='face', facecolor=cfeature.COLORS['water'],
                      zorder=-1)

    # rivers
    add_natural_earth('physical', 'rivers_lake_centerlines', '10m',
                      edgecolor=cfeature.COLORS['water'], facecolor='none')

    # borders of states/provinces internal to a country
    add_natural_earth('cultural', 'admin_1_states_provinces_lines', '50m',
                      edgecolor='gray', facecolor='none')

    ax.set_aspect('equal')

    # labelled gridlines, without the top and right labels
    grid = ax.gridlines(draw_labels=True)
    grid.top_labels = False
    grid.right_labels = False
    return ax
예제 #21
0
def add_metric_description_title(df_plot: pd.DataFrame,
                                 fig: mpl.figure.Figure,
                                 y: float = 1.0):
    """Adds a suptitle to the figure, describing the metric used.

    The title is built from the ``binning_scheme``, ``num_bins`` and ``norm``
    columns of ``df_plot``, each of which is reduced to a single value by
    ``utils.assert_and_get_constant``. Internal identifiers in the finished
    title are then replaced by display names (e.g. "adaptive" becomes
    "equal-mass").

    Args:
        df_plot: DataFrame describing exactly one metric.
        fig: Figure that receives the suptitle.
        y: Vertical position of the suptitle; the text is bottom-aligned.
    """
    assert df_plot.Metric.nunique() == 1, "More than one metric in DataFrame."

    binning_scheme = utils.assert_and_get_constant(df_plot.binning_scheme)
    num_bins = utils.assert_and_get_constant(df_plot.num_bins)
    norm = utils.assert_and_get_constant(df_plot.norm)

    title = "ECE variant: {} binning, {:.0f} bins, {} norm".format(
        binning_scheme, num_bins, norm)

    # Applied in this order to the whole title string.
    replacements = (
        ("adaptive", "equal-mass"),
        ("even", "equal-width"),
        ("l1", "L1"),
        ("l2", "L2"),
    )
    for internal_name, display_name in replacements:
        title = title.replace(internal_name, display_name)

    fig.suptitle(title, y=y, verticalalignment="bottom")
예제 #22
0
def save_plot_impl(fig: matplotlib.figure.Figure,
                   output_prefix: str, output_name: str,
                   printing_extensions: Sequence[str]) -> List[str]:
    """ Save a figure once per requested file extension.

    Each output path is ``output_prefix`` joined with
    ``output_name + "." + extension``.

    Args:
        fig: Figure on which the plot was drawn.
        output_prefix: Directory path where the files should be saved.
        output_name: Filename under which the plot should be saved, but without the file extension.
        printing_extensions: File extensions under which the plot should be saved. They should
            not contain the dot!
    Returns:
        Filenames under which the plot was saved, in the order of ``printing_extensions``.
    """
    saved_paths = []
    for ext in printing_extensions:
        target = os.path.join(output_prefix, output_name + "." + ext)
        logger.debug(f"Saving matplotlib figure to \"{target}\"")
        fig.savefig(target)
        saved_paths.append(target)
    return saved_paths
예제 #23
0
    def plot_convergences(self,
                          convergence_t: pd.DataFrame = None,
                          convergence_f: pd.DataFrame = None,
                          fig: mpl.figure.Figure = None,
                          _display: bool = True):
        """Draw a three-panel convergence report on a newly created figure.

        Nothing is drawn when a figure is supplied: the plots are only
        produced when ``fig`` is None, in which case an 11 x 3 inch figure
        with three side-by-side panels is created:

        * left: ``plot_convergence_f_vspass`` with ``convergence_f``;
        * middle: ``plot_convergence_max_df`` with column 1 of
          ``convergence_t`` and, on a twin y axis,
          ``plot_convergence_solved_elem`` with column 0;
        * right: ``plot_convergence_maxdf_vs_sol`` with columns 1 and 0 of
          ``convergence_t``.

        Args:
            convergence_t (pd.DataFrame): Convergence table; only its first
                two columns are used, by position. Defaults to
                ``self.convergence_t``.
            convergence_f (pd.DataFrame): Convergence of the eigenmode
                frequencies vs. pass number. Defaults to
                ``self.convergence_f``.
            fig (matplotlib.figure.Figure, optional): A mpl figure. Defaults to None.
            _display (bool, optional): Not used by this method. Defaults to True.
        """
        # Fall back to the tables stored on the instance.
        convergence_t = self.convergence_t if convergence_t is None else convergence_t
        convergence_f = self.convergence_f if convergence_f is None else convergence_f

        if fig is not None:
            # A caller-supplied figure is left untouched.
            return

        fig = plt.figure(figsize=(11, 3.))

        # One row, three columns of unequal width.
        grid = mpl.gridspec.GridSpec(1, 3, width_ratios=[1.2, 1.5, 1])
        ax_left, ax_middle, ax_right = (fig.add_subplot(grid[col])
                                        for col in range(3))

        # The middle panel carries two quantities on separate y axes.
        ax_middle_twin = ax_middle.twinx()

        plot_convergence_f_vspass(ax_left, convergence_f)
        plot_convergence_max_df(ax_middle, convergence_t.iloc[:, 1])
        plot_convergence_solved_elem(ax_middle_twin, convergence_t.iloc[:, 0])
        plot_convergence_maxdf_vs_sol(ax_right, convergence_t.iloc[:, 1],
                                      convergence_t.iloc[:, 0])

        fig.tight_layout(w_pad=0.1)
예제 #24
0
 def add_stimulus(self, fig: matplotlib.figure.Figure, stim_path: Path):
     """
     Add a stimulus image to a figure.

     For ``.mp4`` files the first frame of the video is used; any other
     path is loaded as an image through ``get_sample_data`` and
     ``plt.imread``. The image is shown, without axis decorations, in a new
     axes placed at ``[0.38, 0, 0.2, 0.4]`` (figure coordinates) behind the
     existing content (``zorder=-1``).

     Arguments:
         fig {matplotlib.figure.Figure} -- figure to add the stimulus to
         stim_path {Path} -- path of the stimulus file (video or image)
     Returns:
         matplotlib.figure.Figure -- the same figure that was passed in
     Raises:
         ValueError -- if no frame could be read from an ``.mp4`` file
     """
     stim_path = Path(stim_path)
     if stim_path.suffix == ".mp4":
         vidcap = cv2.VideoCapture(str(stim_path))
         try:
             success, image = vidcap.read()
         finally:
             # Release the capture handle even if reading raises.
             vidcap.release()
         if not success:
             raise ValueError(f"Could not read a frame from video: {stim_path}")
         # OpenCV delivers frames in BGR channel order; imshow expects RGB.
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
     else:
         image = plt.imread(get_sample_data(stim_path))
     newax = fig.add_axes([0.38, 0, 0.2, 0.4], zorder=-1)
     newax.imshow(image)
     newax.axis("off")
     return fig
예제 #25
0
    def apply(self, fig: matplotlib.figure.Figure) -> None:
        """Align the y labels and tighten the layout of ``fig``.

        After ``tight_layout``, the subplot spacing is set to zero and the
        outer margins to fixed defaults; entries in ``self.edge_padding``
        override those defaults.
        """
        # Aligning is harmless even when there is only a single label.
        fig.align_ylabels()
        fig.tight_layout()

        layout = {
            # No spacing between subplots.
            "hspace": 0,
            "wspace": 0,
            # Reduced external margins.
            "left": 0.10,
            "bottom": 0.105,
            "right": 0.98,
            "top": 0.98,
        }
        layout.update(self.edge_padding)
        fig.subplots_adjust(**layout)
예제 #26
0
def savefig(f: matplotlib.figure.Figure,
            fpath: str,
            tight: bool = True,
            details: str = None,
            space: float = 0.0,
            **kwargs):
    """Save a figure, optionally annotated with parameter details.

    Args:
        f: Figure to save.
        fpath: Output path. Paths ending in ``png`` or ``svg`` are passed to
            the matching ``add_tags_to_*_file`` helper after saving.
        tight: Save with ``bbox_inches='tight'`` when True.
        details: Parameter details handed to ``add_parameter_details``;
            nothing is added when empty or None.
        space: Offset added to the vertical position of the details text;
            only used when ``tight`` is True.
        **kwargs: Forwarded to ``f.savefig``.
    """
    if details:
        if tight:
            add_parameter_details(f, details, -0.4 + space)
        else:
            # Make room for the details text below the plots.
            f.subplots_adjust(bottom=0.2)
            add_parameter_details(f, details, 0.1)

    if tight:
        f.savefig(fpath, bbox_inches='tight', **kwargs)
    else:
        f.savefig(fpath, **kwargs)

    if fpath.endswith('png'):
        add_tags_to_png_file(fpath)
    elif fpath.endswith('svg'):
        add_tags_to_svg_file(fpath)
예제 #27
0
def _trends_helper(
    adata: AnnData,
    models: Dict[str, Dict[str, Any]],
    gene: str,
    ln_key: str,
    lineage_names: Optional[Sequence[str]] = None,
    same_plot: bool = False,
    sharey: Union[str, bool] = False,
    show_ylabel: bool = True,
    show_lineage: bool = True,
    show_xticks_and_label: bool = True,
    lineage_cmap=None,
    abs_prob_cmap=cm.viridis,
    gene_as_title: bool = False,
    legend_loc: Optional[str] = "best",
    fig: mpl.figure.Figure = None,
    axes: Union[mpl.axes.Axes, Sequence[mpl.axes.Axes]] = None,
    **kwargs,
) -> None:
    """
    Plot an expression gene for some lineages.

    For every lineage name, ``models[gene][name].plot`` is called on the
    matching axes. Depending on the options, a shared colorbar or a legend
    is added afterwards.

    Parameters
    ----------
    %(adata)s
    %(model)s
    gene
        Name of the gene in ``adata.var_names``.
    ln_key
        Key in ``adata.obsm`` where to find the lineages.
    lineage_names
        Names of lineages to plot.
    same_plot
        Whether to plot all lineages in the same plot or separately.
    sharey
        Whether the y-axis is being shared.
    show_ylabel
        Whether to show y-label on the y-axis. Usually, only the first column will contain the y-label.
    show_lineage
        Whether to show the lineage as the title. Usually, only first row will contain the lineage names.
    show_xticks_and_label
        Whether to show x-ticks and x-label. Usually, only the last row will show this.
    lineage_cmap
        Colormap to use when coloring the lineage. If `None`, use colors from ``adata.uns``.
    abs_prob_cmap
        Colormap to use when coloring in the absorption probabilities, if they are being plotted.
    gene_as_title
        Whether to use the gene names as titles (with lineage names as well) or on the y-axis.
    legend_loc
        Location of the legend. If `None`, don't show any legend.
    fig
        Figure to use.
    axes
        Axes to use: a single axes (reused for every lineage when ``same_plot`` is set)
        or a sequence with one axes per lineage.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
        The keys ``perc``, ``hide_cells`` and ``show_cbar`` are consumed here.

    Returns
    -------
    %(just_plots)s
    """

    # NOTE(review): `lineage_names` defaults to None, but len(None) raises
    # TypeError - callers presumably always pass it; confirm.
    n_lineages = len(lineage_names)
    if same_plot:
        # every lineage is drawn on the same axes
        axes = [axes] * len(lineage_names)

    # NOTE(review): `fig` also defaults to None but is used unconditionally.
    fig.tight_layout()
    axes = np.ravel(axes)

    # normalize `perc` to a list holding one entry per lineage
    percs = kwargs.pop("perc", None)
    if percs is None or not isinstance(percs[0], (tuple, list)):
        percs = [percs]

    same_perc = False  # we need to show colorbar always if percs differ
    if len(percs) != n_lineages or n_lineages == 1:
        if len(percs) != 1:
            raise ValueError(
                f"Percentile must be a collection of size `1` or `{n_lineages}`, got `{len(percs)}`."
            )
        same_perc = True
        # broadcast the single percentile entry to all lineages
        percs = percs * n_lineages

    hide_cells = kwargs.pop("hide_cells", False)
    show_cbar = kwargs.pop("show_cbar", True)

    if same_plot:
        # a sequence of colors, indexed per lineage in the loop below
        lineage_colors = (
            lineage_cmap.colors
            if lineage_cmap is not None and hasattr(lineage_cmap, "colors")
            else adata.uns.get(f"{_colors(ln_key)}", cm.Set1.colors)
        )
    else:
        # a single color shared by all lineages
        lineage_colors = (
            "black" if not mcolors.is_color_like(lineage_cmap) else lineage_cmap
        )

    for i, (name, ax, perc) in enumerate(zip(lineage_names, axes, percs)):
        if same_plot:
            if gene_as_title:
                title = gene
                ylabel = "expression" if show_ylabel else None
            else:
                title = ""
                ylabel = gene
        else:
            if gene_as_title:
                title = None
                ylabel = "expression" if i == 0 else None
            else:
                title = (
                    (name if name is not None else "no lineage") if show_lineage else ""
                )
                ylabel = gene if i == 0 else None

        # the per-model colorbar is always disabled here; a shared one may be
        # added after the loop
        models[gene][name].plot(
            ax=ax,
            fig=fig,
            perc=perc,
            show_cbar=False,
            title=title,
            hide_cells=hide_cells or (same_plot and i == n_lineages - 1),
            same_plot=same_plot,
            lineage_color=lineage_colors[i]
            if same_plot and name is not None
            else lineage_colors,
            abs_prob_cmap=abs_prob_cmap,
            ylabel=ylabel,
            **kwargs,
        )
        if sharey in ("row", "all", True) and i == 0:
            plt.setp(ax.get_yticklabels(), visible=True)

        if show_xticks_and_label:
            plt.setp(ax.get_xticklabels(), visible=True)
        else:
            ax.set_xlabel(None)

    if not same_plot and same_perc and show_cbar and not hide_cells:
        # common color normalization over the `w_all` values of all lineages
        vmin = np.min([models[gene][ln].w_all for ln in lineage_names])
        vmax = np.max([models[gene][ln].w_all for ln in lineage_names])
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

        for ax in axes:
            # apply the shared norm to the first PathCollection child of each axes
            [
                c
                for c in ax.get_children()
                if isinstance(c, mpl.collections.PathCollection)
            ][0].set_norm(norm)

        # `ax` is the last axes of the loop above; the colorbar is attached
        # to its right-hand side
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2.5%", pad=0.1)
        _ = mpl.colorbar.ColorbarBase(
            cax, norm=norm, cmap=abs_prob_cmap, label="absorption probability"
        )

    if same_plot and lineage_names != [None] and legend_loc is not None:
        # `ax` still refers to the last axes used in the plotting loop
        ax.legend(lineage_names, loc=legend_loc)
예제 #28
0
def plot_correlations(
    fig: matplotlib.figure.Figure,
    ax: matplotlib.axes.Axes,
    r2: float,
    slope: float,
    y_inter: float,
    corr_vals: np.ndarray,
    vis_vals: np.ndarray,
    scale_factor: Union[float, int],
    corr_bname: str,
    vis_bname: str,
    odir: Union[Path, str],
):
    """
    Plot the correlations between NIR band and the visible bands for
    the Hedley et al. (2005) sunglint correction method

    A 2D point-density histogram of ``corr_vals`` against ``vis_vals`` is
    drawn together with the supplied regression line, and the result is
    saved as ``Correlation_<corr_bname>_vs_<vis_bname>.png`` in ``odir``.
    Every artist created here (including the inset colour bar axes) is
    removed again before returning so that ``fig`` and ``ax`` can be reused
    by the next call.

    Parameters
    ----------
    fig : matplotlib.figure object
        Reusing a matplotlib.figure object to avoid the creation of many
        fig instances

    ax : matplotlib.axes._subplots object
        Reusing the axes object; it is cleared at the start of the call

    r2 : float
        The correlation coefficient squared of the linear regression
        between NIR and a VIS band

    slope : float
        The slope/gradient of the linear regression between NIR and
        a VIS band

    y_inter : float
        The intercept of the linear regression between NIR and a
        VIS band

    corr_vals : numpy.ndarray
        1D array containing the NIR values from the ROI

    vis_vals : numpy.ndarray
        1D array containing the VIS values from the ROI

    scale_factor : float, int or None
        The scale factor used to convert integers to reflectances
        that range [0...1]. It is appended to the axis labels when
        it is greater than 1.

    corr_bname : str
        The NIR band number

    vis_bname : str
        The VIS band number

    odir : str
        Directory where the correlation plots are saved

    """
    # clear previous plot
    ax.clear()

    # ----------------------------------- #
    #   Create a unique cmap for hist2d   #
    # ----------------------------------- #
    ncolours = 256

    # get the jet colormap
    colour_array = plt.get_cmap("jet")(range(ncolours))  # 256 x 4

    # We want only the first few colours to have low alpha
    # as they would represent low density [meshgrid] bins
    # which we are not interested in, and hence would want
    # them to appear as a white colour (alpha ~ 0)
    num_alpha = 25
    colour_array[0:num_alpha, -1] = np.linspace(0.0, 1.0, num_alpha)
    colour_array[num_alpha:, -1] = 1

    # create a colormap object
    cmap = LinearSegmentedColormap.from_list(name="jet_alpha",
                                             colors=colour_array)

    # ----------------------------------- #
    #  Plot density using np.histogram2d  #
    # ----------------------------------- #
    # Linear interpolation is np.percentile's default; the deprecated
    # `interpolation=` keyword is therefore not passed.
    xbin_low, xbin_high = np.percentile(corr_vals, (1, 99))
    ybin_low, ybin_high = np.percentile(vis_vals, (1, 99))

    # one bin per integer step between the 1st and 99th percentiles
    nbins = [int(xbin_high - xbin_low), int(ybin_high - ybin_low)]

    bin_range = [[int(xbin_low), int(xbin_high)],
                 [int(ybin_low), int(ybin_high)]]

    hist2d, xedges, yedges = np.histogram2d(x=corr_vals,
                                            y=vis_vals,
                                            bins=nbins,
                                            range=bin_range)

    # normalised hist to range [0...1] then rotate and flip
    hist2d = np.flipud(np.rot90(hist2d / hist2d.max()))

    # Mask zeros
    hist_masked = np.ma.masked_where(hist2d == 0, hist2d)

    # use pcolormesh to plot the hist2D
    qm = ax.pcolormesh(xedges, yedges, hist_masked, cmap=cmap)

    # create a colour bar axes within ax
    cbaxes = inset_axes(
        ax,
        width="3%",
        height="30%",
        bbox_to_anchor=(0.37, 0.03, 1, 1),
        loc="lower center",
        bbox_transform=ax.transAxes,
    )

    # Add a colour bar inside the axes
    fig.colorbar(
        cm.ScalarMappable(cmap=cmap),
        cax=cbaxes,
        ticks=[0.0, 1],
        orientation="vertical",
        label="Point Density",
    )

    # ----------------------------------- #
    #     Plot linear regression line     #
    # ----------------------------------- #
    x_range = np.array([xbin_low, xbin_high])
    (ln, ) = ax.plot(
        x_range,
        slope * (x_range) + y_inter,
        color="k",
        linestyle="-",
        label="linear regr.",
    )

    # ----------------------------------- #
    #          Format the figure          #
    # ----------------------------------- #
    # add legend (top left)
    lgnd = ax.legend(loc=2, fontsize=10)

    # add annotation; only the adjacent literals are formatted, the raw
    # "$r^{2}$" prefix is concatenated unformatted on purpose
    ann_str = (r"$r^{2}$" + " = {0:0.2f}\n"
               "slope = {1:0.2f}\n"
               "y-inter = {2:0.2f}".format(r2, slope, y_inter))
    ann = ax.annotate(ann_str,
                      xy=(0.02, 0.76),
                      xycoords="axes fraction",
                      fontsize=10)

    # Add labels to figure
    xlabel = f"Reflectance ({corr_bname})"
    ylabel = f"Reflectance ({vis_bname})"

    if scale_factor is not None:
        if scale_factor > 1:
            xlabel += " " + r"$\times$" + " {0}".format(int(scale_factor))
            ylabel += " " + r"$\times$" + " {0}".format(int(scale_factor))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Save figure
    png_file = os.path.join(
        odir, "Correlation_{0}_vs_{1}.png".format(corr_bname, vis_bname))

    fig.savefig(png_file,
                format="png",
                bbox_inches="tight",
                pad_inches=0.1,
                dpi=300)

    # delete everything drawn by this call from the figure,
    # so it can be reused in the next iteration
    qm.remove()
    ln.remove()
    ann.remove()
    lgnd.remove()
    # The inset colour bar axes belongs to the figure, not to `ax`, so
    # ax.clear() does not remove it; without this, a new colour bar axes
    # would pile up on the figure with every call.
    cbaxes.remove()
예제 #29
0
파일: Plot.py 프로젝트: tttapa/EAGLE
def save(fig: matplotlib.figure.Figure, filename: str):
    """Write ``fig`` to ``filename`` using the figure's own ``savefig``."""
    fig.savefig(filename)
예제 #30
0
def save_figure(fig: matplotlib.figure.Figure, name: str) -> None:
    """Save ``fig`` to ``name`` and report on stdout how long it took.

    A progress message is printed without a trailing newline before saving;
    the elapsed time in seconds completes the line afterwards.
    """
    started = time()
    print(f'Saving figure {name}...', end='')
    fig.savefig(name)
    elapsed = time() - started
    print(f' ({elapsed:.2f} s)')