Ejemplo n.º 1
0
def style_figure(fig: matplotlib.figure.Figure, title: str, bottom: float = 0.2) -> None:
    """Apply the standard layout, title and figure-level legend to ``fig`` in place.

    Args:
        fig: Figure to modify.
        title: Text handed to ``fig.suptitle``.
        bottom: Bottom margin passed to ``fig.subplots_adjust`` after the tight
            layout has been computed. ``None`` skips the adjustment and keeps
            whatever ``tight_layout`` chose.
    """
    fig.tight_layout()
    # Presumably reserves room under the axes for the lower-left legend below.
    if bottom is not None:
        fig.subplots_adjust(bottom=bottom)
    fig.suptitle(title)

    # Compact single-row legend (up to 10 columns) anchored at the lower left.
    legend_style = dict(ncol=10, handlelength=0.75, handletextpad=0.25, columnspacing=0.5)
    fig.legend(loc='lower left', **legend_style)
Ejemplo n.º 2
0
    def apply(self, fig: matplotlib.figure.Figure) -> None:
        """Align y labels, tighten the layout, then apply subplot spacing to ``fig``.

        The defaults pack subplots edge to edge (``hspace``/``wspace`` of 0)
        with small outer margins. Any entries in ``self.edge_padding`` override
        those defaults before ``fig.subplots_adjust`` is called.
        """
        # Aligning is harmless even when the figure has a single y label.
        fig.align_ylabels()
        fig.tight_layout()

        spacing = {
            # No gap between neighbouring subplots.
            "hspace": 0,
            "wspace": 0,
            # Thin margins around the whole grid.
            "left": 0.10,
            "bottom": 0.105,
            "right": 0.98,
            "top": 0.98,
        }
        # Caller-configured padding takes precedence over the defaults.
        spacing.update(self.edge_padding)
        fig.subplots_adjust(**spacing)
Ejemplo n.º 3
0
def savefig(f: matplotlib.figure.Figure,
            fpath: str,
            tight: bool = True,
            details: str = None,
            space: float = 0.0,
            **kwargs):
    """Save figure ``f`` to ``fpath``, then add file tags for PNG/SVG outputs.

    Args:
        f: Figure to save.
        fpath: Destination path. Its extension (compared case-insensitively)
            decides whether ``add_tags_to_png_file`` or ``add_tags_to_svg_file``
            runs after saving.
        tight: If True, save with ``bbox_inches='tight'`` unless the caller
            passes its own ``bbox_inches`` in ``kwargs``. If False, save with
            the figure's current layout.
        details: Optional text drawn on the figure via ``add_parameter_details``
            before saving; ``None`` or empty skips it.
        space: Offset added to the details position (base -0.4) in tight mode
            only; ignored when ``tight`` is False.
        **kwargs: Forwarded to ``Figure.savefig``.
    """
    if tight:
        if details:
            add_parameter_details(f, details, -0.4 + space)
        # Respect an explicit bbox_inches instead of failing with
        # "got multiple values for keyword argument 'bbox_inches'".
        kwargs.setdefault('bbox_inches', 'tight')
    elif details:
        # Without tight cropping, presumably leave room at the bottom so the
        # details text (placed at 0.1) does not overlap the axes.
        f.subplots_adjust(bottom=0.2)
        add_parameter_details(f, details, 0.1)
    f.savefig(fpath, **kwargs)

    # Match the real extension (with its dot) case-insensitively, like
    # matplotlib does when inferring the format: "plot.PNG" is tagged,
    # "notapng" is not.
    lowered = str(fpath).lower()
    if lowered.endswith('.png'):
        add_tags_to_png_file(fpath)
    elif lowered.endswith('.svg'):
        add_tags_to_svg_file(fpath)
Ejemplo n.º 4
0
def _plot_decay_sim_only(axes, t_sim: np.ndarray, sim_data: np.ndarray,
                         color, label: str, atol: float) -> None:
    """Plot one simulated decay on a log y axis when no experimental data exists.

    The y range gets a margin around the data and is clipped at ``atol``. When
    clipping applies and the curve's last crossing goes below ``atol``, the x
    range is cut at that crossing.
    """
    axes.semilogy(t_sim * 1000, sim_data, color=color, label=label)

    axes.axis('tight')
    axes.set_xlim(left=t_sim[0] * 1000.0)
    # add some white space above and below
    margin_factor = np.array([0.7, 1.3])
    axes.set_ylim(*np.array(axes.get_ylim()) * margin_factor)
    if axes.get_ylim()[0] < atol:
        axes.set_ylim(bottom=atol)  # don't show noise below atol
        # detect when the simulation goes above and below atol
        above = sim_data > atol
        change_indices = np.where(np.roll(above, 1) != above)[0]
        # only trim when the last change is the population going BELOW atol
        if change_indices.size > 1 and sim_data[
                change_indices[-1]] < atol:  # pragma: no cover
            # show sim_data until it (last) falls below atol
            axes.set_xlim(right=t_sim[change_indices[-1]] * 1000)
    # Ensure bottom <= top: setting bottom=atol above may exceed the current top.
    min_y, max_y = sorted(axes.get_ylim())
    axes.set_ylim(bottom=min_y, top=max_y)


def _plot_decay_sim_with_exp(axes, t_sim: np.ndarray, sim_data: np.ndarray,
                             exp_data: np.ndarray, sim_color, exp_color,
                             label: str, marker: str, markersize: float) -> None:
    """Overlay experimental points on a simulated decay (log y axis)."""
    axes.semilogy(t_sim * 1000, sim_data, color=sim_color, label=label, zorder=10)
    # exp_data[:, 0] is time (converted to ms); exp_data[:, 1] is scaled by the
    # simulation's peak (presumably normalized to 1 — TODO confirm).
    axes.semilogy(exp_data[:, 0] * 1000,
                  exp_data[:, 1] * np.max(sim_data),
                  color=exp_color,
                  marker=marker,
                  linewidth=0,
                  markersize=markersize,
                  zorder=1)
    axes.axis('tight')
    axes.set_ylim(top=axes.get_ylim()[1] * 1.2)  # add some white space on top
    # Start at whichever series begins first; don't show beyond the exp data.
    tmin = min(exp_data[0, 0], t_sim[0])
    axes.set_xlim(left=tmin * 1000.0, right=exp_data[-1, 0] * 1000)


def _relabel_decay_legend(axes, state_label: str) -> None:
    """Redraw the legend of ``axes`` showing only the concentration part of each label.

    A label such as ``'ES_0.5S_1.0A'`` (state label + concentration suffix)
    becomes ``'0.5S, 1.0A'``.
    """
    handles, labels = axes.get_legend_handles_labels()
    prefix = state_label + '_'
    new_labels = [
        (label[len(prefix):] if label.startswith(prefix) else label).replace('_', ', ')
        for label in labels
    ]
    axes.legend(handles,
                new_labels,
                markerscale=5,
                loc="best",
                fontsize='small')


def plot_avg_decay_data(t_sol: Union[np.ndarray, List[np.array]],
                        list_sim_data: List[np.array],
                        list_exp_data: List[np.array] = None,
                        state_labels: List[str] = None,
                        concentration: Conc = None,
                        atol: float = A_TOL,
                        colors: Union[str, Tuple[ColorMap, ColorMap]] = 'rk',
                        fig: mpl.figure.Figure = None,
                        title: str = '') -> None:
    ''' Plot the list of simulated and experimental data (optional) against time in t_sol.

        t_sol is either a single time vector shared by all plots or one time
        vector per plot (times in seconds; plotted in ms).
        If concentration is given, the legend of the first subplot shows the
        concentrations.
        colors holds two entries: the first is the sim color, the second the
        exp data color.
        If fig already has axes they are reused (so repeated calls overlay
        curves); otherwise a 3-row grid of subplots is created.
    '''
    num_plots = len(list_sim_data)
    num_rows = 3
    num_cols = int(np.ceil(num_plots / num_rows))

    # optional lists default to a list of None / empty labels
    list_exp_data = list_exp_data or [None] * num_plots
    state_labels = state_labels or [''] * num_plots

    # A plain 1-D numeric array is a single time vector shared by every plot,
    # even if its length happens to equal num_plots.
    shared_time = (isinstance(t_sol, np.ndarray) and t_sol.ndim == 1
                   and t_sol.dtype != object)
    if shared_time or len(t_sol) != num_plots:
        list_t_sim = [t_sol] * num_plots  # type: List[np.array]
    else:
        list_t_sim = t_sol

    if concentration:
        conc_str = '_' + str(concentration.S_conc) + 'S_' + str(
            concentration.A_conc) + 'A'
    else:
        conc_str = ''

    sim_color = colors[0]
    exp_color = colors[1]
    exp_size = 2  # marker size
    exp_marker = '.'

    if fig is None:
        fig = plt.figure()

    fig.suptitle(title + '. Time in ms.')

    list_axes = fig.get_axes()  # type: List
    if not list_axes:
        for num in range(num_plots):
            fig.add_subplot(num_rows, num_cols, num + 1)
        list_axes = fig.get_axes()

    for sim_data, t_sim, exp_data, state_label, axes in zip(
            list_sim_data, list_t_sim, list_exp_data, state_labels, list_axes):

        if state_label:
            axes.set_title(
                state_label.replace('_', ' '), {
                    'horizontalalignment': 'center',
                    'verticalalignment': 'center',
                    'fontweight': 'bold',
                    'fontsize': 10
                })

        # nothing sensible to draw on a log axis
        if sim_data is None or np.isnan(sim_data).any() or not np.any(
                sim_data > 0):
            continue

        label = state_label + conc_str
        # no exp data: either a GS or simply no exp data available (None or a
        # scalar 0 placeholder). Compare by value, not identity.
        no_exp_data = exp_data is None or (np.isscalar(exp_data) and exp_data == 0)
        if no_exp_data:
            _plot_decay_sim_only(axes, t_sim, sim_data, sim_color, label, atol)
        else:
            _plot_decay_sim_with_exp(axes, t_sim, sim_data, exp_data,
                                     sim_color, exp_color, label,
                                     exp_marker, exp_size)

    if conc_str and list_axes:
        _relabel_decay_legend(list_axes[0], state_labels[0])

    fig.subplots_adjust(top=0.918,
                        bottom=0.041,
                        left=0.034,
                        right=0.99,
                        hspace=0.275,
                        wspace=0.12)
Ejemplo n.º 5
0
def southern_ocean_axes_setup(
    ax: matplotlib.axes.Axes,
    fig: matplotlib.figure.Figure,
    add_gridlines: bool = True,
) -> None:
    """
    Set up the subplot as a cartopy map of the Southern Ocean.

    The function does the following:

    - Restricts the extent to longitudes -180..180 and latitudes -90..-30
      (PlateCarree).
    - Adjusts the figure's subplot margins.
    - Clips the axes to a circle.
    - Draws coastlines, and optionally gridlines.
    - Overlays the ``cst.MAX_DEPTH`` isobath computed from the depth field in
      ``cst.SALT_FILE``.

    Returns None: ``ax`` and ``fig`` are modified in place.

    Args:
        ax (matplotlib.axes.Axes): The axis object to add the map to. It is
            presumably a cartopy GeoAxes, since set_extent, set_boundary,
            coastlines and gridlines are called on it.
        fig (matplotlib.figure.Figure): The figure object for the figure in general.
        add_gridlines (bool): whether or not to add gridlines to the plot.
    """
    carree = ccrs.PlateCarree()
    ax.set_extent([-180, 180, -90, -30], carree)
    fig.subplots_adjust(bottom=0.05,
                        top=0.95,
                        left=0.04,
                        right=0.95,
                        wspace=0.02)

    def plot_boundary() -> None:
        """
        Makes SO plot boundary into a nice circle
        of the right size.
        """
        theta = np.linspace(0, 2 * np.pi, 100)
        center, radius = [0.5, 0.5], 0.45
        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
        circle = mpath.Path(verts * radius + center)
        ax.set_boundary(circle, transform=ax.transAxes)

    plot_boundary()

    # add coastlines and gridlines
    ax.coastlines(resolution="50m", linewidth=0.3)

    if add_gridlines:
        ax.gridlines(linewidth=0.5)

    # Add 2000m isobath (or whatever the max depth is).

    # NOTE(review): being nested, find_isobath gets a fresh jit dispatcher on
    # every call; confirm cache=True avoids recompiling, or move it to module level.
    @jit(cache=True)  # significant performance enhancement.
    def find_isobath(tmp_bathymetry: np.ndarray,
                     crit_depth=cst.MAX_DEPTH) -> List[list]:
        """
        Find isobath cells.

        A cell is reported once if it is at least ``crit_depth`` deep and at
        least one of its four neighbours is shallower.

        Rows/columns 0 .. shape-2 are scanned. For row/column 0 the "previous"
        neighbour index is -1, which wraps to the last row/column.
        NOTE(review): wrap-around is plausible for longitude but probably not
        for latitude — confirm the grid orientation.

        Args:
            tmp_bathymetry (np.ndarray): 2-D bathymetry array.
            crit_depth ([type], optional): Critical depth. Defaults to cst.MAX_DEPTH.

        Returns:
            List[list]: List of [row, col] index pairs.
        """
        isobath_index_list = []
        shape_bathymetry = np.shape(tmp_bathymetry)
        for i in range(0, shape_bathymetry[0] - 1):
            for j in range(0, shape_bathymetry[1] - 1):
                if tmp_bathymetry[i, j] >= crit_depth and (
                        tmp_bathymetry[i - 1, j] < crit_depth
                        or tmp_bathymetry[i, j - 1] < crit_depth
                        or tmp_bathymetry[i + 1, j] < crit_depth
                        or tmp_bathymetry[i, j + 1] < crit_depth):
                    isobath_index_list.append([i, j])
        return isobath_index_list

    # Open the file once and make sure it is closed again.
    with xr.open_dataset(cst.SALT_FILE) as salt_ds:
        bathymetry = salt_ds[cst.DEPTH_NAME].values
        x_values = salt_ds[cst.X_COORD].values
        y_values = salt_ds[cst.Y_COORD].values

    i_list = find_isobath(bathymetry, crit_depth=cst.MAX_DEPTH)

    # Nothing at or below crit_depth borders shallower water: no isobath to draw.
    # (np.array([])[:, 1] would otherwise raise IndexError.)
    if len(i_list) == 0:
        return

    index_npa: np.ndarray = np.array(i_list)
    lons = x_values[index_npa[:, 1]]
    lats = y_values[index_npa[:, 0]]

    ax.plot(lons,
            lats,
            ",",
            markersize=0.4,
            color="grey",
            transform=ccrs.Geodetic())