예제 #1
0
def _explained_variance_plot(model, ax: mpl.axes.Axes, cutoff: float = 0.9):
    n = len(model.explained_variance_ratio_)
    _x = np.arange(1, n + 1)
    _ycum = np.cumsum(model.explained_variance_ratio_)
    best_index = np.where(_ycum > cutoff)[0]
    # modify in case we dont have one
    best_index = best_index[0] if best_index.shape[0] > 0 else n - 1
    # calculate AUC
    auc = np.trapz(_ycum, _x / n)
    # plot
    ax.plot(_x, _ycum, "x-")
    # plot best point
    ax.scatter(
        [_x[best_index]],
        [_ycum[best_index]],
        facecolors="None",
        edgecolors="red",
        s=100,
        label="n=%d, auc=%.3f" % (_x[best_index], auc),
    )
    # plot 0 to 1 line
    ax.plot([1, n], [0, 1], "k--")
    ax.set_xlabel("N\n(Best proportion: %.3f)" % (_x[best_index] / (n + 1)))
    ax.set_ylabel("Explained variance (ratio)\n(cutoff=%.2f)" % cutoff)
    ax.grid()
    ax.legend()
예제 #2
0
def plot_2pcf(ax: matplotlib.axes.Axes,
              dr: np.ndarray,
              xi0: np.ndarray,
              xi1: np.ndarray) -> None:
    """
    Draw the x- and y-component two-point correlation functions of the
    astrometric residual field against the separation between points.

    Parameters
    ----------
    ax:
        Matplotlib axis that receives the plot
    dr:
        separations at which both correlation functions were evaluated
    xi0:
        two-point correlation function of the residual field's x-component
    xi1:
        two-point correlation function of the residual field's y-component

    Returns
    -------
    None
    """
    # Dashed zero reference line.
    ax.axhline(y=0, ls='--', lw=1, c='gray')

    # Both curves share the same marker/line styling; only the data and label differ.
    curve_style = dict(marker='o', ms=5, ls='-', lw=1)
    for xi, curve_label in ((xi0, r'$\xi_{xx}$'), (xi1, r'$\xi_{yy}$')):
        ax.plot(dr, xi, label=curve_label, **curve_style)
    ax.legend()

    label_size = 12
    ax.set_xlabel(r'$\Delta$ [degrees]', fontsize=label_size)
    ax.set_ylabel(r'$\xi(\Delta)$ [degrees$^2$]', fontsize=label_size)
예제 #3
0
def plot_prediction(train_data,
                    test_data,
                    prediction,
                    ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes:
    """Plot case counts as a step line, with alarms and outbreaks as triangles.

    The concatenated training + test ``n_cases`` series is drawn as a
    mid-aligned step line. Rows of ``prediction`` with ``alarm == 1`` get a
    green upward triangle at y=0. Rows of ``test_data`` selected by its
    ``outbreak`` column get a red downward triangle at their
    ``n_outbreak_cases`` value.

    Args:
        train_data: DataFrame with an ``n_cases`` column.
        test_data: DataFrame with ``n_cases``, ``outbreak`` and
            ``n_outbreak_cases`` columns.
        prediction: DataFrame with an ``alarm`` column.
        ax: Axes to draw on; a new 12x8 figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    label_size = 20
    marker_size = 12

    history = pd.concat((train_data, test_data), sort=False)
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 8))

    ax.step(x=history.index, y=history.n_cases, where="mid", color="blue", label="_nolegend_")

    alarm_rows = prediction.query("alarm == 1")
    ax.plot(alarm_rows.index, [0] * len(alarm_rows), "g^", label="alarm", markersize=marker_size)

    outbreak_rows = test_data.query("outbreak")
    ax.plot(outbreak_rows.index, outbreak_rows.n_outbreak_cases, "rv", label="outbreak",
            markersize=marker_size)

    for set_label, text in ((ax.set_xlabel, "time"), (ax.set_ylabel, "cases")):
        set_label(text, fontsize=label_size)
    ax.legend(fontsize="xx-large")

    return ax
예제 #4
0
def plot_shots_per_camera(subplot: matplotlib.axes.Axes,
                          data: pd.DataFrame) -> None:
    """
    Barplot of the number of shots per camera, on the provided subplot. Acts in place.

    One horizontal bar group per ``Camera`` value, ordered by
    ``data.Camera.value_counts()``, with bars coloured by ``Brand``.

    Args:
        subplot: the subplot matplotlib.axes.Axes on which to plot.
        data: the pandas DataFrame with your exif data. Must contain the
            ``Camera`` and ``Brand`` columns.

    Returns:
        Nothing, plots in place.
    """
    logger.debug("Plotting shots per camera")
    sns.countplot(y="Camera",
                  hue="Brand",
                  data=data,
                  ax=subplot,
                  order=data.Camera.value_counts().index)
    subplot.set_title("Number of Shots per Camera Model", fontsize=25)
    subplot.tick_params(axis="both", which="major", labelsize=13)
    subplot.set_xlabel("Number of Shots", fontsize=20)
    subplot.set_ylabel("Camera Model", fontsize=20)
    subplot.legend(loc="lower right", fontsize=18, title_fontsize=22)
예제 #5
0
def plot_shots_per_focal_length(subplot: matplotlib.axes.Axes,
                                data: pd.DataFrame) -> None:
    """
    Draw, in place on ``subplot``, a bar chart counting shots per focal range
    (full-frame equivalent), with bars split by lens.

    Args:
        subplot: the matplotlib.axes.Axes to draw on.
        data: the pandas DataFrame with your exif data; its ``Focal_Range``
            and ``Lens`` columns are used.

    Returns:
        Nothing, plots in place.
    """
    logger.debug("Plotting shots per focal length")
    range_order = data.Focal_Range.value_counts().index
    sns.countplot(x="Focal_Range", hue="Lens", data=data, ax=subplot, order=range_order)

    chart_title = "Number of shots per Focal Length (FF equivalent)"
    subplot.set_title(chart_title, fontsize=25)
    subplot.tick_params(axis="both", which="major", labelsize=13)
    for set_label, text in ((subplot.set_xlabel, "Focal Length"),
                            (subplot.set_ylabel, "Number of Shots")):
        set_label(text, fontsize=20)
    subplot.legend(loc="upper center", fontsize=15, title_fontsize=21)
예제 #6
0
    def apply(
        self,
        ax: matplotlib.axes.Axes,
        legend_handles: Optional[Sequence[matplotlib.container.ErrorbarContainer]] = None,
        legend_labels: Optional[Sequence[str]] = None,
    ) -> None:
        """Draw a frameless legend on ``ax`` using this object's settings.

        Nothing is drawn when ``self.location`` is falsy. Non-empty
        ``legend_handles`` / ``legend_labels`` are forwarded to ``ax.legend``;
        empty or missing ones are simply not passed. When ``self.anchor`` is
        set, both border paddings are zeroed so the legend sits exactly at
        the anchor point; otherwise they are passed as ``None``.
        """
        if not self.location:
            return

        explicit = {}
        if legend_handles:
            explicit["handles"] = legend_handles
        if legend_labels:
            explicit["labels"] = legend_labels

        # Remove extra padding when anchoring so placement is exact.
        padding = 0 if self.anchor else None
        ax.legend(
            loc=self.location,
            bbox_to_anchor=self.anchor,
            borderaxespad=padding,
            borderpad=padding,
            frameon=False,
            fontsize=self.font_size,
            ncol=self.ncol,
            handletextpad=self.marker_label_spacing,
            labelspacing=self.label_spacing,
            **explicit,
        )
예제 #7
0
    def plot_p_and_q(self, ax: mpl.axes.Axes):
        """Plot p(x) and q(x) for Sturm-Liouville models.

        Draws the true p(x) (``self.p_x``) and the inferred
        phi(x) = -1 / (regression coefficient of 'f'). Both curves are
        shifted by the mean of the true p(x). The inferred phi is stored on
        ``self.inferred_phi``.

        When a 'u' term is present in ``self.reg_coeffs``, the inferred
        q(x) = (coefficient of 'u') * phi(x) is also drawn next to the true
        q(x) (``self.q_x``). The q curves are shifted by ``-mean(q_x)`` plus
        a vertical offset of at least 1, and the inferred q is stored on
        ``self.inferred_q``.

        Keyword arguments
        ax -- matplotlib.axes.Axes to plot on

        Raises
        ValueError -- if the model is not a Sturm-Liouville operator
        """
        if not self.is_sturm_liouville:
            raise ValueError("This method only applies to Sturm-Liouville operators.")

        # Compute learned parametric coefficient phi from the 'f' coefficient.
        learned_f_coeff = self.reg_coeffs['f']
        inferred_phi = -1 * np.reciprocal(learned_f_coeff)
        # Set NaN entries to 0 (none are expected).
        inferred_phi[np.isnan(inferred_phi)] = 0
        self.inferred_phi = inferred_phi

        # For S-L u_xx (or u_xxxx) regressions, phi(x) is p(x).
        if self.lhs_term in ('d^{2}u/dx^{2}', 'd^{4}u/dx^{4}'):
            phi_label = "Inferred p(x)"
        else:
            phi_label = r"Inferred $\phi(x)$"

        # True and inferred p are both shifted by the mean of the true p(x).
        p_x_plotted = self.p_x - np.mean(self.p_x)
        ax.plot(self.true_x_vector, p_x_plotted, color=self.coeff_colors[0],
                label='True $p(x)$', **self.true_opts)
        inferred_p_plotted = inferred_phi - np.mean(self.p_x)
        ax.plot(self.reg_x_vector, inferred_p_plotted, marker=self.markers[0],
                markevery=self.npts, label=phi_label, **self.reg_opts)

        # If 'u' is found in the model, there is a q(x) that can be computed.
        if 'u' in self.reg_coeffs:
            inferred_q = self.reg_coeffs['u'] * inferred_phi
            # Vertical offset (at least 1) for the q curves on the same axes,
            # presumably to keep them visually apart from the p curves.
            offset = max([ceil(max(p_x_plotted) + abs(min(self.q_x))), 1])
            true_q_plotted = self.q_x - np.mean(self.q_x) + offset
            ax.plot(self.true_x_vector, true_q_plotted, color=self.coeff_colors[1],
                    label="True $q(x)$", **self.true_opts)
            ax.plot(self.reg_x_vector, inferred_q - np.mean(self.q_x) + offset,
                    marker=self.markers[1], label="Inferred $q(x)$",
                    markevery=self.npts, **self.reg_opts)
            self.inferred_q = inferred_q

        if self.show_legends:
            leglines = ax.get_lines()
            leglabels = [line.get_label() for line in leglines]
            ax.legend(leglines, leglabels, **self.legend_opts)
예제 #8
0
    def _toggle_legend(self, ax: matplotlib.axes.Axes,  pad_h: float) -> None:
        '''
        Update legends according to those curves on the canvas.
        Clear and remove all legend instance if no curves on the canvas.
        Set legend to invisible but still keeping all legend instances existed
        if user choose to hide legend in the axes parameter dialog.
        Those customized setting are stored in ``parameter``.
        '''
        handles, labels = self._sort_legend(ax)
        if ax == self.ax_main:
            loc = "upper left"
            leg_bottom = 0.5
        else:
            loc = "lower left"
            leg_bottom = 0

        if labels:
            ax.legend(handles, labels, bbox_to_anchor=(1+pad_h, leg_bottom, 1, .5),
                      loc=loc, borderaxespad=0)
            legend_visible = self.parameter["General"]["Legend"]["visible"]
            print(ax.get_title(), legend_visible)
            ax.get_legend().set_visible(legend_visible)
        else:
            leg = ax.legend([])
            leg.remove()
예제 #9
0
def hist2d_scatter(x,
                   y,
                   bg_x,
                   bg_y,
                   axis: matplotlib.axes.Axes,
                   dataset_name: str,
                   normalize=True,
                   marker_color: str = 'k',
                   bins=300,
                   colormap_name: str = 'jet',
                   color_bar=False,
                   hist2dkw=None,
                   scatterkw=None):
    """Plots 2D histogram with two parameters (e.g. BxGSM and ByGSM).

    x, y (array_like): Values of the TPAs for the parameter that will be plotted on the x- or y-axis.
                       This data will correspond to the dots in the plot. Pairs where
                       either value is NaN are dropped.
    bg_x, bg_y (array_like): Values of the IMF over the period of the dataset.
                             These will form the background (colored tiles) of the plot.
                             Pairs where either value is NaN are dropped.
    axis: Axes to draw on.
    dataset_name: legend label of the scatter points.
    normalize: forwarded to ``hist2d`` as ``density``.
    marker_color: a single colour, or an ndarray of per-point colours aligned
                  with x/y (filtered with the same NaN mask).
    bins: number of histogram bins, forwarded to ``hist2d``.
    colormap_name: name of a registered matplotlib colormap.
    color_bar: if True, a colour bar labelled 'IMF probability distribution' is added.
    hist2dkw, scatterkw: extra keyword arguments for ``hist2d`` / ``scatter``
                         (``None`` means no extras).

    Returns (scatter, counts, xedges, yedges, im).
    """
    # None defaults avoid sharing one mutable dict across calls.
    hist2dkw = {} if hist2dkw is None else hist2dkw
    scatterkw = {} if scatterkw is None else scatterkw

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; prefer the
    # colormap registry (matplotlib >= 3.5) and fall back for older releases.
    registry = getattr(matplotlib, "colormaps", None)
    colormap = registry[colormap_name] if registry is not None else cm.get_cmap(colormap_name)

    axis.axhline(0, color='grey', zorder=1)
    axis.axvline(0, color='grey', zorder=1)
    omit_index = np.isnan(x) | np.isnan(y)
    x = x[~omit_index]
    y = y[~omit_index]
    bg_omit_index = np.isnan(bg_x) | np.isnan(bg_y)
    bg_x = bg_x[~bg_omit_index]
    bg_y = bg_y[~bg_omit_index]
    counts, xedges, yedges, im = axis.hist2d(bg_x,
                                             bg_y,
                                             bins=bins,
                                             cmap=colormap,
                                             density=normalize,
                                             zorder=0,
                                             **hist2dkw)

    if color_bar:
        cbar = plt.colorbar(im, ax=axis)
        cbar.set_label('IMF probability distribution',
                       rotation=270,
                       labelpad=10)

    # Per-point colours must be filtered with the same NaN mask as x/y.
    point_colors = marker_color[~omit_index] if isinstance(marker_color, np.ndarray) else marker_color
    scatter = axis.scatter(x,
                           y,
                           s=30,
                           marker='P',
                           edgecolors='w',
                           linewidth=0.5,
                           label=dataset_name,
                           c=point_colors,
                           zorder=2,
                           **scatterkw)

    # Background outside the histogram uses the colormap's lowest colour.
    axis.set_facecolor(colormap(0))
    axis.legend(loc='upper left')
    return scatter, counts, xedges, yedges, im
예제 #10
0
파일: plot.py 프로젝트: kapilsh/pyqstrat
    def _draw(self, ax: mpl.axes.Axes, plot_timestamps: np.ndarray, date_formatter: Optional[DateFormatter]) -> None:
        """Render this subplot's data series and reference lines onto ``ax``.

        Series whose name is listed in ``self.secondary_y`` are drawn on a
        twin y-axis, which is created only when ``secondary_y`` is non-empty.
        Everything else, including date/horizontal/vertical reference lines,
        is drawn on ``ax``.

        Afterwards ``self.legend_names`` holds every series name, followed by
        the names of the named reference lines. The y-limits, log scale,
        tick format, title and axis labels are then applied from the
        instance settings.

        Args:
            ax: Axes to draw on.
            plot_timestamps: timestamps used to reindex time plots and place date lines.
            date_formatter: x-axis major formatter for time plots, if any.
        """
        if self.time_plot:
            self._reindex(plot_timestamps)
            if date_formatter is not None:
                ax.xaxis.set_major_formatter(date_formatter)

        use_twin = self.secondary_y is not None and len(self.secondary_y)
        twin_ax = ax.twinx() if use_twin else None

        # One entry per data series (possibly None), then one per *named*
        # reference line, so positions line up with self.legend_names.
        handles = []
        for series in self.data_list:
            if _VERBOSE:
                print(f'plotting data: {series.name}')
            on_twin = twin_ax and series.name in self.secondary_y
            handles.append(_plot_data(twin_ax if on_twin else ax, series))

        for marker in self.date_lines:  # vertical lines on time plot
            handle = draw_date_line(ax, plot_timestamps, marker.date, marker.line_type, marker.color)
            if marker.name is not None:
                handles.append(handle)

        for marker in self.horizontal_lines:
            handle = draw_horizontal_line(ax, marker.y, marker.line_type, marker.color)
            if marker.name is not None:
                handles.append(handle)

        for marker in self.vertical_lines:
            handle = draw_vertical_line(ax, marker.x, marker.line_type, marker.color)
            if marker.name is not None:
                handles.append(handle)

        reference_lines = [*self.date_lines, *self.horizontal_lines, *self.vertical_lines]
        self.legend_names = [series.name for series in self.data_list] + [
            ref.name for ref in reference_lines if ref.name is not None]

        if self.ylim:
            ax.set_ylim(self.ylim)
        wants_legend = len(self.data_list) > 1 or len(self.date_lines)
        if wants_legend and self.display_legend:
            kept = [i for i, handle in enumerate(handles) if handle is not None]
            ax.legend([handles[i] for i in kept],
                      [self.legend_names[i] for i in kept], loc=self.legend_loc)

        if self.log_y:
            ax.set_yscale('log')
            ax.yaxis.set_major_locator(mtick.AutoLocator())
            ax.yaxis.set_minor_locator(mtick.NullLocator())
        if self.y_tick_format:
            ax.yaxis.set_major_formatter(mtick.StrMethodFormatter(self.y_tick_format))

        if self.title:
            ax.set_title(self.title)
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.zlabel:
            ax.set_zlabel(self.zlabel)

        ax.relim()
        ax.autoscale_view()
예제 #11
0
def _position_legend(ax: mpl.axes.Axes, legend_loc: str, **kwargs):
    # modified from scVelo
    if legend_loc == "upper right":
        return ax.legend(loc="upper left", bbox_to_anchor=(1, 1), **kwargs)
    if legend_loc == "lower right":
        return ax.legend(loc="lower left", bbox_to_anchor=(1, 0), **kwargs)
    if "right" in legend_loc:  # 'right', 'center right', 'right margin'
        return ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), **kwargs)
    if legend_loc != "none":
        return ax.legend(loc=legend_loc, **kwargs)

    raise ValueError(f"Invalid legend location `{legend_loc!r}`.")
예제 #12
0
def plot_feature_importance_over_time(models: List[LGBMClassifier],
                                      feature_names: List[str],
                                      ax: mpl.axes.Axes) -> mpl.axes.Axes:
    """Stacked-area plot of normalized feature importances across models.

    Each model's ``feature_importances_`` are divided by their sum, so every
    row sums to 1. The rows are plotted as a stacked area chart indexed by
    model name, with the legend placed to the right of the axes.

    Args:
        models: despite the ``List`` annotation, this must be a mapping of
            name -> fitted model, because ``.items()`` is called on it. The
            names become the x-axis index.
        feature_names: column labels, one per feature.
        ax: Axes to draw on.

    Returns:
        The axes returned by ``DataFrame.plot.area``.
    """
    model_names = []
    normalized_rows = []
    for model_name, model in models.items():
        raw = model.feature_importances_
        normalized_rows.append((raw / raw.sum()).reshape(1, -1))
        model_names.append(model_name)

    importances_df = pd.DataFrame(np.concatenate(normalized_rows))
    importances_df.columns = feature_names
    importances_df.index = model_names

    ax = importances_df.plot.area(legend=False, cmap='tab20c', ax=ax)
    ax.set_ylabel('Importance')
    ax.set_ylim((0, 1))
    # Legend to the right of the axes, vertically centred.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    return ax
예제 #13
0
def add_legend(ax: mpl.axes.Axes, labels: list, cmap1="cool", cmap2="bwr", s=20):
    """Add a two-column, full-width legend above ``ax`` using proxy markers.

    Up to four labels are used. They are styled in order as:
    ``cmap1`` at 1.0 with '+', ``cmap1`` at 0.0 with '+', ``cmap2`` at 1.0
    with 'x', and ``cmap2`` at 0.0 with 'x'. Labels beyond the fourth are
    ignored. The proxies are empty scatters, so no data is drawn.

    Args:
        ax: Axes to decorate.
        labels: legend labels (at most four are used).
        cmap1, cmap2: colormap names (or Colormap instances) providing the colours.
        s: base marker size; '+' markers use ``2 * s``, 'x' markers ``s + 5``.
    """
    def _resolve_cmap(cmap):
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; prefer the
        # colormap registry (matplotlib >= 3.5) and fall back for older releases.
        if isinstance(cmap, mpl.colors.Colormap):
            return cmap
        registry = getattr(mpl, "colormaps", None)
        if registry is not None:
            return registry[cmap]
        return mpl.cm.get_cmap(cmap)

    cmp1 = _resolve_cmap(cmap1)
    cmp2 = _resolve_cmap(cmap2)
    colors = (cmp1(1.0), cmp1(0.0), cmp2(1.0), cmp2(0.0))
    markers = ("+", "+", "x", "x")
    sizes = (s * 2, s * 2, s + 5, s + 5)
    # zip stops at the shortest sequence, so at most four proxies are created.
    for label, c, m, size in zip(labels, colors, markers, sizes):
        ax.scatter([], [], color=c, marker=m, label=label, s=size)
    ax.legend(
        bbox_to_anchor=(0.0, 1.2, 1.0, 0.102),
        loc="lower left",
        ncol=2,
        mode="expand",
        borderaxespad=0.0,
    )
예제 #14
0
    def plot_ode_solutions(self, ax: mpl.axes.Axes, number: int = 3,
                           start_idc: int = 0, shift: int = 0):
        """Plot solutions to differential equations used to build dataset.

        Solutions are read from ``self.ode_sols`` (objects exposing ``t`` and
        ``y``, e.g. results of scipy's ``solve_ivp``). For each one, ``y[0]``
        is drawn solid and ``y[1]`` dashed, in the same colour.

        Keyword arguments
        ax -- matplotlib.axes.Axes to plot on
        number -- maximum number of solutions to plot
        start_idc -- index in ``self.ode_sols`` of the first solution to plot
        shift -- x-axis offset added to each successive solution plotted;
        useful if the solutions overlap significantly.
        """
        # ODE line colors and line properties
        ode_sols = self.ode_sols
        lcolors = self.ode_colors
        lprops = self.ode_opts

        # Never index past the end of ode_sols when starting at start_idc > 0.
        available = len(ode_sols) - max(start_idc, 0)
        num_sols_to_plot = max(0, min(available, number))
        lines = []
        for i in range(num_sols_to_plot):
            sol = ode_sols[i + start_idc]
            label_u = self.dependent_var + ' solution {}'.format(i + 1)
            label_dudx = 'd{}/d{}, solution {}'.format(self.dependent_var,
                                                       self.independent_var,
                                                       i + 1)
            x_values = sol.t + (i * shift)
            if i < len(lcolors):
                line, = ax.plot(x_values, sol.y[0], color=lcolors[i],
                                linestyle='-', label=label_u, **lprops)
                dline, = ax.plot(x_values, sol.y[1], color=lcolors[i],
                                 linestyle='--', label=label_dudx, **lprops)
            else:
                # Out of configured colours: reuse the auto-assigned colour for the derivative.
                line, = ax.plot(x_values, sol.y[0], linestyle='-',
                                label=label_u, **lprops)
                dline, = ax.plot(x_values, sol.y[1],
                                 color=line.get_color(), linestyle='--',
                                 label=label_dudx, **lprops)
            lines.append(line)
            lines.append(dline)

        # The legend only distinguishes solid (u) from dashed (du/dx); skip it if nothing was drawn.
        if self.show_legends and len(lines) >= 2:
            leg_lines = [lines[0], lines[1]]
            leg_labels = ['u', '$du/dx$']
            ax.legend(leg_lines, leg_labels, **self.legend_opts)
예제 #15
0
def display_segmented_image(y: np.ndarray,
                            threshold: float = 0.5,
                            input_image: np.ndarray = None,
                            alpha_input_image: float = 0.2,
                            title: str = '',
                            ax: matplotlib.axes.Axes = None) -> None:
    """Display segmented image.

    Each class channel of ``y`` whose value exceeds ``threshold`` is painted
    in its own colour from the ``jet`` colormap on a white background. When
    several classes pass the threshold at one pixel, the highest class index
    wins. A legend mapping colours to class indices is placed to the right of
    the image. This is useful for getting a rapid view of the performance of
    the model on a few examples.

    Parameters:
        y: The array containing the prediction, indexed as ``y[row, col, class]``,
            i.e. of shape (height, width, num_classes).
        threshold: The threshold used on the predictions.
        input_image: If provided, its first channel (``input_image[..., 0]``)
            is overlaid using the ``binary`` colormap.
        alpha_input_image: If an input_image is provided, the transparency of
            the input_image.
        title: Title of the axes.
        ax: Axes to draw on. When omitted, the current axes (``plt.gca()``)
            are used and ``plt.show()`` is called at the end.
    """
    # Remember whether we own the figure before ax is replaced below;
    # testing ``ax`` afterwards would always be truthy.
    created_here = ax is None
    if created_here:
        ax = plt.gca()

    base_array = np.ones((y.shape[0], y.shape[1], 3))
    legend_handles = []

    num_classes = y.shape[-1]
    for i in range(num_classes):
        # Retrieve a color (without the transparency value).
        colour = plt.cm.jet(i / num_classes)[:-1]
        base_array[y[..., i] > threshold] = colour
        legend_handles.append(mpatches.Patch(color=colour, label=str(i)))

    ax.imshow(base_array)
    ax.legend(handles=legend_handles, bbox_to_anchor=(1, 1), loc='upper left')
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_title(title)

    if input_image is not None:
        ax.imshow(input_image[..., 0],
                  cmap=plt.cm.binary,
                  alpha=alpha_input_image)

    if created_here:
        plt.show()
예제 #16
0
파일: Plot.py 프로젝트: tttapa/EAGLE
def plot_step_analyzer(axes: matplotlib.axes.Axes, result: dict, title: str,
                       legends: list, colorset: int):
    """Plot step-response states together with their analysis markers.

    For every state ``i`` in ``result['states']``, the trajectory is drawn
    against ``result['times']``.

    When the state's reference is non-zero, horizontal lines mark:
    - the reference (dashed),
    - reference +/- threshold (dash-dot),
    - reference + overshoot (dotted).

    Positive rise and settling times are marked with vertical dashed and
    dash-dot lines respectively.

    Args:
        axes: Axes to draw on.
        result: dict with ``'times'``, ``'states'``, ``'references'``,
            ``'thresholds'``, ``'overshoots'``, ``'risetimes'`` and
            ``'settletimes'`` entries, the latter six indexed per state.
        title: axes title; skipped when empty.
        legends: per-state labels; when empty/None no labels or legend are drawn.
        colorset: index of the colour palette (0 or 1). Colours repeat when
            there are more states than colours in the palette.
    """
    palettes = [
        ('red', 'green', 'blue'),
        ('tomato', 'lightgreen', 'steelblue'),
    ]
    palette = palettes[colorset]
    times = result['times']

    for i in range(len(result['states'])):
        # Cycle through the palette so more than three states no longer raise IndexError.
        color = palette[i % len(palette)]

        if legends:
            axes.plot(times, result['states'][i], color=color, label=legends[i])
        else:
            axes.plot(times, result['states'][i], color=color)

        reference = result['references'][i]
        if reference != 0.0:
            threshold = result['thresholds'][i]
            axes.axhline(reference, linestyle='--', color=color)
            axes.axhline(reference - threshold, linestyle='-.', color=color)
            axes.axhline(reference + threshold, linestyle='-.', color=color)
            axes.axhline(reference + result['overshoots'][i], linestyle=':', color=color)
        if result['risetimes'][i] > 0.0:
            axes.axvline(result['risetimes'][i], linestyle='--', color=color)
        if result['settletimes'][i] > 0.0:
            axes.axvline(result['settletimes'][i], linestyle='-.', color=color)

    if legends:
        axes.legend()

    if title:
        axes.set_title(title)

    axes.set_xlim(times[0], times[-1])
    axes.figure.tight_layout()
예제 #17
0
def plot_testcount_forecast(
    result: pandas.Series,
    m: preprocessing.fbprophet.Prophet,
    forecast: pandas.DataFrame,
    considered_holidays: preprocessing.NamedDates, *,
    ax: matplotlib.axes.Axes=None
) -> matplotlib.axes.Axes:
    """ Helper function for plotting the detailed testcount forecasting result.

    Parameters
    ----------
    result : pandas.Series
        the date-indexed series of smoothed/predicted testcounts
    m : fbprophet.Prophet
        the prophet model
    forecast : pandas.DataFrame
        contains the prophet model prediction; only rows from the first
        date in ``m.history`` onwards are plotted
    considered_holidays : preprocessing.NamedDates
        the holidays that were used in the model, drawn as vertical lines
        via ``plot_vlines`` (presumably a mapping of date -> name; confirm
        against ``preprocessing.NamedDates``)
    ax : optional, matplotlib.axes.Axes
        an existing subplot to use; a new 13.4x6 figure is created when omitted

    Returns
    -------
    ax : matplotlib.axes.Axes
        the (created) subplot that was plotted into

    Notes
    -----
    The x-axis always starts at 2020-03-01 and the y-axis at 0.
    """
    if ax is None:
        _, ax = pyplot.subplots(figsize=(13.4, 6))
    first_history_date = m.history.set_index('ds').index[0]
    m.plot(forecast[forecast.ds >= first_history_date], ax=ax)
    ax.set_ylim(bottom=0)
    ax.set_xlim(pandas.to_datetime('2020-03-01'))
    plot_vlines(ax, considered_holidays, alignment='bottom')
    # Empty scatter/plot proxies label the training data and prediction drawn by
    # m.plot; the result series itself is drawn here in orange.
    ax.legend(frameon=False, loc='upper left', handles=[
        ax.scatter([], [], color='black', label='training data'),
        ax.plot([], [], color='blue', label='prediction')[0],
        ax.plot(result.index, result.values, color='orange', label='result')[0],
    ])
    ax.set_ylabel('total tests')
    ax.set_xlabel('')
    return ax
예제 #18
0
def fix_pandas_multiplot_legend(ax:         matplotlib.axes.Axes,
                                legend_loc: Tuple[Union[int, float], Union[int, float]]):
    """
    When plotting multiple pieces of data, the Pandas-generated
    plot legend will often look like "(metric, place)" (e.g.,
    "(Deaths, Connecticut)").

    This function corrects the legend by extracting just the place.
    The metric may contain spaces but not commas.

    Parameters:

    ax         - the plot axis
    legend_loc - the desired location for the legend, passed to ``ax.legend`` as ``loc``

    Raises:

    ValueError - if a legend label does not have the "(metric, place)" form
    """
    patches, labels = ax.get_legend_handles_labels()
    # "(metric, place)" -> place; everything up to the first comma is the metric.
    pat = re.compile(r'^\([^,]+,\s+(.*)\)$')
    places = []
    for label in labels:
        match = pat.match(label)
        # An explicit raise (unlike assert) is not stripped under ``python -O``.
        if match is None:
            raise ValueError(f"Legend label {label!r} does not look like '(metric, place)'")
        places.append(match.group(1))
    ax.legend(patches, places, loc=legend_loc)
예제 #19
0
파일: main.py 프로젝트: andreasfuhr/sciplot
def set_legend(ax: matplotlib.axes.Axes,
               plot_tpl: Tuple[matplotlib.artist.Artist],
               label_tpl: Tuple[str],
               loc: str = 'lower left',
               outside_plot: bool = False,
               handle_scale_factor: float = 5.):
    """Add a legend for ``plot_tpl``/``label_tpl`` and resize its marker handles.

    Args:
        ax: Axes to add the legend to.
        plot_tpl: artists to list in the legend.
        label_tpl: labels, one per artist.
        loc: matplotlib legend location string.
        outside_plot: when True, 'right' and 'left' in ``loc`` are swapped and
            the legend is anchored at a horizontal position of 1.04 (right),
            0.94 (left) or 0.5 (otherwise). The vertical anchor is 1, 0 or
            0.5 for 'upper', 'lower' or neither.
        handle_scale_factor: marker size applied to every legend handle.

    Returns:
        The created legend.
    """
    if outside_plot:
        if 'right' in loc:
            horizontal_anchor = 1.04
            loc = loc.replace('right', 'left')
        elif 'left' in loc:
            # NOTE(review): 0.94 lies inside the axes (0..1); -0.04 (mirroring
            # 1.04 above) may have been intended for an outside-left legend — confirm.
            horizontal_anchor = 0.94
            loc = loc.replace('left', 'right')
        else:
            horizontal_anchor = 0.5

        if 'upper' in loc:
            vertical_anchor = 1.
        elif 'lower' in loc:
            vertical_anchor = 0.
        else:
            vertical_anchor = 0.5

        lgnd = ax.legend(plot_tpl,
                         label_tpl,
                         scatterpoints=1,
                         loc=loc,
                         bbox_to_anchor=(horizontal_anchor, vertical_anchor))
    else:
        lgnd = ax.legend(
            plot_tpl,
            label_tpl,
            scatterpoints=1,
            loc=loc,
        )

    # Legend.legendHandles was renamed legend_handles in matplotlib 3.7 and the
    # old name removed in 3.9; support both.
    handles = getattr(lgnd, "legend_handles", None)
    if handles is None:
        handles = lgnd.legendHandles
    for lgnd_handle in handles:
        # Artists without a legend handler show up as None.
        if lgnd_handle is None:
            continue
        if hasattr(lgnd_handle, "set_sizes"):
            lgnd_handle.set_sizes([handle_scale_factor])
        else:
            lgnd_handle._sizes = [handle_scale_factor]
    return lgnd
예제 #20
0
def plot_shots_per_lens(subplot: matplotlib.axes.Axes,
                        data: pd.DataFrame) -> None:
    """
    Draw, in place on ``subplot``, a bar chart counting shots per lens, with
    one bar row per ``Lens`` value and bars split by brand.

    Args:
        subplot: the matplotlib.axes.Axes to draw on.
        data: the pandas DataFrame with your exif data; its ``Lens`` and
            ``Brand`` columns are used.

    Returns:
        Nothing, plots in place.
    """
    logger.debug("Plotting shots per lens")
    lens_order = data.Lens.value_counts().index
    sns.countplot(y="Lens", hue="Brand", data=data, ax=subplot, order=lens_order)

    subplot.set_title("Number of Shots per Lens Model", fontsize=25)
    subplot.tick_params(axis="both", which="major", labelsize=13)
    for set_label, text in ((subplot.set_xlabel, "Number of Shots"),
                            (subplot.set_ylabel, "Lens Model")):
        set_label(text, fontsize=20)
    subplot.legend(loc="lower right", fontsize=18, title_fontsize=25)
예제 #21
0
def _position_legend(ax: mpl.axes.Axes, legend_loc: str, **kwargs) -> mpl.legend.Legend:
    """
    Position legend in- or outside the figure.

    Parameters
    ----------
    ax
        Ax where to position the legend.
    legend_loc
        Position of legend.
    **kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.

    Returns
    -------
    :class: `matplotlib.legend.Legend`
        The created legend.
    """

    if legend_loc == "center center out":
        raise ValueError(
            "Invalid option: `'center center out'`. Doesn't really make sense, does it?"
        )
    if legend_loc == "best":
        return ax.legend(loc="best", **kwargs)

    tmp, loc = legend_loc.split(" "), ""

    if len(tmp) == 1:
        height, rest = tmp[0], []
        width = "right" if height in ("upper", "top", "center") else "left"
    else:
        height, width, *rest = legend_loc.split(" ")
        if rest:
            if len(rest) != 1:
                raise ValueError(
                    f"Expected only 1 additional modifier ('in' or 'out'), found `{list(rest)}`."
                )
            elif rest[0] not in ("in", "out"):
                raise ValueError(
                    f"Invalid modifier `{rest[0]!r}`. Valid options are: `'in', 'out'`."
                )
            if rest[0] == "in":  # ignore in, it's default
                rest = []

    if height in ("upper", "top"):
        y = 1.55 if width == "center" else 1.025
        loc += "upper"
    elif height == "center":
        y = 0.5
        loc += "center"
    elif height in ("lower", "bottom"):
        y = -0.55 if width == "center" else -0.025
        loc += "lower"
    else:
        raise ValueError(
            f"Invalid legend position on y-axis: `{height!r}`. "
            f"Valid options are: `'upper', 'top', 'center', 'lower', 'bottom'`."
        )

    if width == "left":
        x = -0.05
        loc += " right" if rest else " left"
    elif width == "center":
        x = 0.5
        if height != "center":  # causes to be like top center
            loc += " center"
    elif width == "right":
        x = 1.05
        loc += " left" if rest else " right"
    else:
        raise ValueError(
            f"Invalid legend position on x-axis: `{width!r}`. "
            f"Valid options are: `'left', 'center', 'right'`."
        )

    if rest:
        kwargs["bbox_to_anchor"] = (x, y)

    return ax.legend(loc=loc, **kwargs)
예제 #22
0
def plot_lookahead_acceptance_rates(sampler_df: Union[pd.DataFrame, str],
                                    t_min: int = 0,
                                    title: str = "Acceptance rates",
                                    size: tuple = None,
                                    ax: mpl.axes.Axes = None):
    """Plot acceptance rates for look-ahead vs ordinary samples.
    The ratios are relative to all accepted particles, including eventually
    discarded ones.

    Three curves are drawn per generation:
    - Combined: ``n_accepted / n_evaluated``.
    - Actual: the same ratio with the look-ahead counts subtracted from both
      numerator and denominator.
    - Look-ahead: ``n_lookahead_accepted / n_lookahead``.

    Parameters
    ----------
    sampler_df:
        Dataframe or file as generated via
        `RedisEvalParallelSampler(log_file=...)`.
    t_min:
        The minimum generation to show. E.g. a value of 1 omits the first
        generation.
    title:
        Plot title.
    size:
        The size of the plot in inches.
    ax:
        The axis object to use.

    Returns
    -------
    ax: Axis of the generated plot.
    """
    if isinstance(sampler_df, str):
        sampler_df = pd.read_csv(sampler_df, sep=',')

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()

    # keep only generations t >= t_min
    sampler_df = sampler_df[sampler_df.t >= t_min]
    t = sampler_df.t

    n_la_acc, n_la = sampler_df.n_lookahead_accepted, sampler_df.n_lookahead
    n_all_acc, n_all = sampler_df.n_accepted, sampler_df.n_evaluated

    combined_rate = n_all_acc / n_all
    # "actual" proposals are everything that was not a look-ahead sample
    actual_rate = (n_all_acc - n_la_acc) / (n_all - n_la)
    lookahead_rate = n_la_acc / n_la

    ax.plot(t, combined_rate, linestyle='--', marker='o', color='black', label="Combined")
    ax.plot(t, actual_rate, marker='o', label="Actual")
    ax.plot(t, lookahead_rate, marker='o', label="Look-ahead")

    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Population index")
    ax.set_ylabel("Acceptance rate")
    ax.set_ylim(bottom=0)
    # generation indices are integers
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    if size is not None:
        fig.set_size_inches(size)

    return ax
예제 #23
0
def plot_lookahead_final_acceptance_fractions(
        sampler_df: Union[pd.DataFrame, str],
        population_sizes: Union[np.ndarray, History],
        relative: bool = False,
        fill: bool = False,
        alpha: float = None,
        t_min: int = 0,
        title: str = "Composition of final acceptances",
        size: tuple = None,
        ax: mpl.axes.Axes = None):
    """Plot fraction of look-ahead samples in final acceptances,
    over generations.

    Parameters
    ----------
    sampler_df:
        Dataframe or file as generated via
        `RedisEvalParallelSampler(log_file=...)`.
    population_sizes:
        The sizes of the populations of accepted particles. If a History is
        passed, those values are extracted automatically, otherwise should
        be for the same time values as `sampler_df`. Array input is
        converted to a float array; the caller's array is not modified.
    relative:
        Whether to normalize the total evaluations for each generation to 1.
    fill:
        If True, instead of lines, filled areas are drawn that sum up to the
        totals.
    alpha:
        Alpha value for lines or areas. Defaults to 0.7 if `fill`, else 1.0.
    t_min:
        The minimum generation to show. E.g. a value of 1 omits the first
        generation.
    title:
        Plot title.
    size:
        The size of the plot in inches.
    ax:
        The axis object to use.

    Returns
    -------
    ax: Axis of the generated plot.
    """
    # process input
    if isinstance(sampler_df, str):
        sampler_df = pd.read_csv(sampler_df, sep=',')
    if alpha is None:
        alpha = 0.7 if fill else 1.0

    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # get numbers of final acceptances
    if isinstance(population_sizes, History):
        pop = population_sizes.get_all_populations()

        population_sizes = np.array(
            [pop.loc[pop.t == t, 'particles'] for t in sampler_df.t],
            dtype=float).flatten()
    # use floats throughout: an integer array could not hold the relative
    # (fractional) values computed below
    population_sizes = np.asarray(population_sizes, dtype=float)

    # restrict to t >= t_min (boolean indexing copies, so the caller's
    # array and dataframe stay untouched)
    keep = np.asarray(sampler_df.t >= t_min)
    population_sizes = population_sizes[keep]
    sampler_df = sampler_df[keep]

    # extract variables
    t = sampler_df.t

    # actual look-ahead acceptances cannot be more than requested
    n_la_acc = np.minimum(
        np.asarray(sampler_df.n_lookahead_accepted, dtype=float),
        population_sizes)

    # actual acceptances are the remaining ones, as these are always later
    n_act_acc = population_sizes - n_la_acc

    # normalize (out of place, so no in-place integer division can fail)
    if relative:
        n_la_acc = n_la_acc / population_sizes
        n_act_acc = n_act_acc / population_sizes
        population_sizes = population_sizes / population_sizes

    # plot
    if fill:
        ax.fill_between(t,
                        n_la_acc,
                        population_sizes,
                        alpha=alpha,
                        label="Actual")
        ax.fill_between(t, 0, n_la_acc, alpha=alpha, label="Look-ahead")
    else:
        ax.plot(t,
                population_sizes,
                linestyle='--',
                marker='o',
                color='black',
                alpha=alpha,
                label="Population size")
        ax.plot(t, n_act_acc, marker='o', alpha=alpha, label="Actual")
        ax.plot(t, n_la_acc, marker='o', alpha=alpha, label="Look-ahead")

    # prettify plot
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Population index")
    ax.set_ylabel("Final acceptances")
    ax.set_ylim(bottom=0)
    # enforce integer ticks
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    if size is not None:
        fig.set_size_inches(size)

    return ax
예제 #24
0
def plot_lookahead_evaluations(sampler_df: Union[pd.DataFrame, str],
                               relative: bool = False,
                               fill: bool = False,
                               alpha: float = None,
                               t_min: int = 0,
                               title: str = "Total evaluations",
                               size: tuple = None,
                               ax: mpl.axes.Axes = None):
    """Plot total vs look-ahead evaluations over the generations.

    The "actual" evaluations are the total evaluations minus the look-ahead
    evaluations of each generation.

    Parameters
    ----------
    sampler_df:
        Dataframe or file as generated via
        `RedisEvalParallelSampler(log_file=...)`.
    relative:
        Whether to normalize the total evaluations for each generation to 1.
    fill:
        If True, instead of lines, filled areas are drawn that sum up to the
        totals.
    alpha:
        Alpha value for lines or areas. Defaults to 0.7 if `fill`, else 1.0.
    t_min:
        The minimum generation to show. E.g. a value of 1 omits the first
        generation.
    title:
        Plot title.
    size:
        The size of the plot in inches.
    ax:
        The axis object to use.

    Returns
    -------
    ax: Axis of the generated plot.
    """
    # accept a path to the sampler log as well as a dataframe
    if isinstance(sampler_df, str):
        sampler_df = pd.read_csv(sampler_df, sep=',')
    if alpha is None:
        alpha = 0.7 if fill else 1.0

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # keep only generations t >= t_min
    df = sampler_df[sampler_df.t >= t_min]

    t = df.t
    total = df.n_evaluated
    lookahead = df.n_lookahead
    actual = total - lookahead

    if relative:
        # all right-hand sides use the un-normalized totals
        lookahead, actual, total = (lookahead / total,
                                    actual / total,
                                    total / total)

    if fill:
        ax.fill_between(t, lookahead, total, alpha=alpha, label="Actual")
        ax.fill_between(t, 0, lookahead, alpha=alpha, label="Look-ahead")
    else:
        line_specs = [
            (total, dict(linestyle='--', color='black', label="Total")),
            (actual, dict(label="Actual")),
            (lookahead, dict(label="Look-ahead")),
        ]
        for values, style in line_specs:
            ax.plot(t, values, marker='o', alpha=alpha, **style)

    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Population index")
    ax.set_ylabel("Evaluations")
    ax.set_ylim(bottom=0)
    # generations are integers, so only show integer ticks
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    if size is not None:
        fig.set_size_inches(size)

    return ax
예제 #25
0
def plot_acceptance_rates_trajectory(histories: Union[List, History],
                                     labels: Union[List, str] = None,
                                     title: str = "Acceptance rates",
                                     yscale: str = 'lin',
                                     size: tuple = None,
                                     ax: mpl.axes.Axes = None,
                                     colors: List[str] = None,
                                     normalize_by_ess: bool = False):
    """
    Plot of acceptance rates over all iterations, i.e. one trajectory
    per history.

    The acceptance rate of a generation is its population size (or its
    effective sample size, see `normalize_by_ess`) divided by the number of
    samples required for that generation.

    Parameters
    ----------
    histories:
        The histories to plot from. History ids must be set correctly.
    labels:
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    title:
        Title for the plot.
    yscale:
        The scale on which to plot the counts. Can be one of 'lin', 'log'
        (basis e) or 'log10'
    size:
        The size of the plot in inches.
    ax:
        The axis object to use.
    colors: List[str], optional
        Colors to use for the trajectories, one per history. If None, the
        matplotlib default colors are used.
    normalize_by_ess: bool, optional (default = False)
        Indicator to use effective sample size for the acceptance rate in
        place of the population size.

    Returns
    -------
    ax: Axis of the generated plot.
    """
    # preprocess input
    histories = to_lists(histories)
    labels = get_labels(labels, len(histories))
    if colors is None:
        colors = [None] * len(histories)
    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # extract sample numbers
    times = []
    samples = []
    pop_sizes = []
    for history in histories:
        # note: the first entry of time -1 is trivial and is thus ignored here
        h_info = history.get_all_populations()
        h_times = np.array(h_info['t'])[1:]
        times.append(h_times)
        if normalize_by_ess:
            # index by position, not by `t - 1`: the retained times follow
            # the t = -1 entry, so the first one is t = 0 and `ess[t - 1]`
            # would wrap it to the last slot and shift all other values
            ess = np.zeros(len(h_times))
            for i_t, t in enumerate(h_times):
                w = history.get_weighted_distances(t=t)['w']
                ess[i_t] = effective_sample_size(w)
            pop_sizes.append(ess)
        else:
            pop_sizes.append(
                np.array(history.get_nr_particles_per_population().values[1:]))
        samples.append(np.array(h_info['samples'])[1:])

    # compute acceptance rates
    rates = [pop_size / sample for sample, pop_size in zip(samples, pop_sizes)]

    # apply scale
    ylabel = "Acceptance rate"
    if yscale == 'log':
        rates = [np.log(rate) for rate in rates]
        ylabel = "log(" + ylabel + ")"
    elif yscale == 'log10':
        rates = [np.log10(rate) for rate in rates]
        ylabel = "log10(" + ylabel + ")"

    # plot
    for t, rate, label, color in zip(times, rates, labels, colors):
        ax.plot(t, rate, 'x-', label=label, color=color)

    # add labels
    ax.legend()
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Population index $t$")
    # set size
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax
예제 #26
0
def plot_sample_numbers_trajectory(histories: Union[List, History],
                                   labels: Union[List, str] = None,
                                   rotation: int = 0,
                                   title: str = "Required samples",
                                   yscale: str = 'lin',
                                   size: tuple = None,
                                   ax: mpl.axes.Axes = None):
    """
    Plot of required sample number over all iterations, i.e. one trajectory
    per history.

    Parameters
    ----------
    histories: Union[List, History]
        The histories to plot from. History ids must be set correctly.
    labels: Union[List ,str], optional
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    rotation: int, optional (default = 0)
        Rotation to apply to the plot's x tick labels. For longer labels,
        a tilting of 45 or even 90 can be preferable.
    title: str, optional (default = "Required samples")
        Title for the plot.
    yscale: str, optional (default = 'lin')
        The scale on which to plot the counts. Can be one of 'lin', 'log'
        (basis e) or 'log10'
    size: tuple of float, optional
        The size of the plot in inches.
    ax: matplotlib.axes.Axes, optional
        The axis object to use.

    Returns
    -------
    ax: Axis of the generated plot.
    """
    # preprocess input
    histories = to_lists(histories)
    labels = get_labels(labels, len(histories))

    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # extract sample numbers
    times = []
    samples = []
    for history in histories:
        # note: the first entry corresponds to the calibration and should
        # be included here to be fair against methods not requiring
        # calibration
        h_info = history.get_all_populations()
        times.append(np.array(h_info['t']))
        samples.append(np.array(h_info['samples']))

    # apply scale
    ylabel = "Samples"
    if yscale == 'log':
        samples = [np.log(sample) for sample in samples]
        ylabel = "log(" + ylabel + ")"
    elif yscale == 'log10':
        samples = [np.log10(sample) for sample in samples]
        ylabel = "log10(" + ylabel + ")"

    # plot
    for t, sample, label in zip(times, samples, labels):
        ax.plot(t, sample, 'x-', label=label)

    # add labels
    if any(lab is not None for lab in labels):
        ax.legend()
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Population index $t$")
    # apply the documented x tick label rotation (previously ignored);
    # tick_params keeps it for tick labels created at draw time as well
    ax.tick_params(axis='x', labelrotation=rotation)
    # set size
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax
예제 #27
0
def plot_sample_numbers(histories: Union[List, History],
                        labels: Union[List, str] = None,
                        rotation: int = 0,
                        title: str = "Required samples",
                        size: tuple = None,
                        ax: mpl.axes.Axes = None):
    """
    Stacked bar plot of required numbers of samples over all iterations.

    One bar is drawn per history; each bar is stacked from the per-generation
    sample numbers, with the calibration entry (labeled generation -1) at
    the bottom.

    Parameters
    ----------

    histories: Union[List, History]
        The histories to plot from. History ids must be set correctly.
    labels: Union[List ,str], optional
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    rotation: int, optional (default = 0)
        Rotation to apply to the plot's x tick labels. For longer labels,
        a tilting of 45 or even 90 can be preferable.
    title: str, optional (default = "Required samples")
        Title for the plot.
    size: tuple of float, optional
        The size of the plot in inches.
    ax: matplotlib.axes.Axes, optional
        The axis object to use.

    Returns
    -------

    ax: Axis of the generated plot.
    """
    # preprocess input
    histories = to_lists(histories)
    labels = get_labels(labels, len(histories))

    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    n_run = len(histories)

    # extract sample numbers
    # note: the first entry corresponds to the calibration and should
    # be included here to be fair against methods not requiring
    # calibration
    samples = [np.array(history.get_all_populations()['samples'])
               for history in histories]

    # create matrix (generation x run); shorter runs are padded with zeros
    n_pop = max(len(sample) for sample in samples)
    matrix = np.zeros((n_pop, n_run))
    for i_sample, sample in enumerate(samples):
        matrix[:len(sample), i_sample] = sample

    # bottom of each generation's segment = sum of all earlier generations
    bottoms = np.cumsum(matrix, axis=0) - matrix

    # plot bars; iterating from the last generation down makes the legend
    # (which lists entries in plotting order) match the stack top-to-bottom
    for i_pop in reversed(range(n_pop)):
        ax.bar(x=np.arange(n_run),
               height=matrix[i_pop, :],
               bottom=bottoms[i_pop, :],
               label=f"Generation {i_pop-1}")

    # add labels
    ax.set_xticks(np.arange(n_run))
    ax.set_xticklabels(labels, rotation=rotation)
    ax.set_title(title)
    ax.set_ylabel("Samples")
    ax.set_xlabel("Run")
    ax.legend()
    # set size
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax
예제 #28
0
def plot_effective_sample_sizes(
    histories: Union[List, History],
    labels: Union[List, str] = None,
    rotation: int = 0,
    title: str = "Effective sample size",
    relative: bool = False,
    colors: List = None,
    size: tuple = None,
    ax: mpl.axes.Axes = None,
):
    """
    Plot effective sample sizes over all iterations.

    For each history, the effective sample size of the weights of every
    generation ``t = 0, ..., history.max_t`` is plotted as one trajectory.

    Parameters
    ----------

    histories: Union[List, History]
        The histories to plot from. History ids must be set correctly.
    labels: Union[List ,str], optional
        Labels corresponding to the histories. If None are provided,
        indices are used as labels.
    rotation: int, optional (default = 0)
        Rotation to apply to the plot's x tick labels. For longer labels,
        a tilting of 45 or even 90 can be preferable.
    title: str, optional (default = "Effective sample size")
        Title for the plot.
    relative: bool, optional (default = False)
        Whether to show relative sizes (to 1) or w.r.t. the real number
        of particles.
    colors: List, optional
        Colors to use for the lines. If None, then the matplotlib
        default values are used.
    size: tuple of float, optional
        The size of the plot in inches.
    ax: matplotlib.axes.Axes, optional
        The axis object to use. A new one is created if None.

    Returns
    -------

    ax: Axis of the generated plot.
    """
    # preprocess input
    histories = to_lists(histories)
    labels = get_labels(labels, len(histories))
    if colors is None:
        colors = [None for _ in range(len(histories))]

    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # extract effective sample sizes: one list per history, one entry per t
    essss = []  # :)
    for history in histories:
        esss = []
        for t in range(0, history.max_t + 1):
            # we need the weights not normalized to 1 for each model
            w = history.get_weighted_distances(t=t)['w']
            ess = effective_sample_size(w)
            if relative:
                # relative to the number of particles of this generation
                ess /= len(w)
            esss.append(ess)
        essss.append(esss)

    # plot
    for esss, label, color in zip(essss, labels, colors):
        ax.plot(range(0, len(esss)), esss, 'x-', label=label, color=color)

    # format
    ax.set_xlabel("Population index")
    ax.set_ylabel("ESS")
    if any(lab is not None for lab in labels):
        ax.legend()
    ax.set_title(title)
    # enforce integer ticks
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    # set size
    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax
예제 #29
0
def plot_correlations(
    fig: matplotlib.figure.Figure,
    ax: matplotlib.axes.Axes,
    r2: float,
    slope: float,
    y_inter: float,
    corr_vals: np.ndarray,
    vis_vals: np.ndarray,
    scale_factor: Union[float, int],
    corr_bname: str,
    vis_bname: str,
    odir: Union[Path, str],
):
    """
    Plot the correlations between NIR band and the visible bands for
    the Hedley et al. (2005) sunglint correction method.

    The point density is drawn as a 2D histogram with the fitted linear
    regression line on top. The figure is saved as
    ``Correlation_<corr_bname>_vs_<vis_bname>.png`` in `odir`. Everything
    added to the figure here is removed again after saving, so the same
    `fig`/`ax` can be reused for the next band pair.

    Parameters
    ----------
    fig : matplotlib.figure object
        Reusing a matplotlib.figure object to avoid the creation of many
        fig instances

    ax : matplotlib.axes._subplots object
        Reusing the axes object

    r2 : float
        The correlation coefficient squared of the linear regression
        between NIR and a VIS band

    slope : float
        The slope/gradient of the linear regression between NIR and
        a VIS band

    y_inter : float
        The intercept of the linear regression between NIR and a
        VIS band

    corr_vals : numpy.ndarray
        1D array containing the NIR values from the ROI

    vis_vals : numpy.ndarray
        1D array containing the VIS values from the ROI

    scale_factor : int or None
        The scale factor used to convert integers to reflectances
        that range [0...1]

    corr_bname : str
        The NIR band number

    vis_bname : str
        The VIS band number

    odir : str or Path
        Directory where the correlation plots are saved

    """
    # clear previous plot
    ax.clear()

    # ----------------------------------- #
    #   Create a unique cmap for hist2d   #
    # ----------------------------------- #
    ncolours = 256

    # get the jet colormap
    colour_array = plt.get_cmap("jet")(range(ncolours))  # 256 x 4

    # We want only the first few colours to have low alpha
    # as they would represent low density [meshgrid] bins
    # which we are not interested in, and hence would want
    # them to appear as a white colour (alpha ~ 0)
    num_alpha = 25
    colour_array[0:num_alpha, -1] = np.linspace(0.0, 1.0, num_alpha)
    colour_array[num_alpha:, -1] = 1

    # create a colormap object
    cmap = LinearSegmentedColormap.from_list(name="jet_alpha",
                                             colors=colour_array)

    # ----------------------------------- #
    #  Plot density using np.histogram2d  #
    # ----------------------------------- #
    # linear interpolation is np.percentile's default; the `interpolation`
    # keyword formerly used to request it is deprecated since NumPy 1.22
    xbin_low, xbin_high = np.percentile(corr_vals, (1, 99))
    ybin_low, ybin_high = np.percentile(vis_vals, (1, 99))

    # roughly one bin per unit of the percentile range; at least one bin,
    # because np.histogram2d rejects zero bins (percentile spread < 1)
    nbins = [max(int(xbin_high - xbin_low), 1),
             max(int(ybin_high - ybin_low), 1)]

    bin_range = [[int(xbin_low), int(xbin_high)],
                 [int(ybin_low), int(ybin_high)]]

    hist2d, xedges, yedges = np.histogram2d(x=corr_vals,
                                            y=vis_vals,
                                            bins=nbins,
                                            range=bin_range)

    # normalised hist to range [0...1] then rotate and flip
    # (histogram2d puts x along the first axis; pcolormesh expects rows = y)
    hist2d = np.flipud(np.rot90(hist2d / hist2d.max()))

    # Mask zeros
    hist_masked = np.ma.masked_where(hist2d == 0, hist2d)

    # use pcolormesh to plot the hist2D
    qm = ax.pcolormesh(xedges, yedges, hist_masked, cmap=cmap)

    # create a colour bar axes within ax
    cbaxes = inset_axes(
        ax,
        width="3%",
        height="30%",
        bbox_to_anchor=(0.37, 0.03, 1, 1),
        loc="lower center",
        bbox_transform=ax.transAxes,
    )

    # Add a colour bar inside the axes
    fig.colorbar(
        cm.ScalarMappable(cmap=cmap),
        cax=cbaxes,
        ticks=[0.0, 1],
        orientation="vertical",
        label="Point Density",
    )

    # ----------------------------------- #
    #     Plot linear regression line     #
    # ----------------------------------- #
    x_range = np.array([xbin_low, xbin_high])
    (ln, ) = ax.plot(
        x_range,
        slope * (x_range) + y_inter,
        color="k",
        linestyle="-",
        label="linear regr.",
    )

    # ----------------------------------- #
    #          Format the figure          #
    # ----------------------------------- #
    # add legend (top left)
    lgnd = ax.legend(loc=2, fontsize=10)

    # add annotation
    ann_str = (r"$r^{2}$" + " = {0:0.2f}\n"
               "slope = {1:0.2f}\n"
               "y-inter = {2:0.2f}".format(r2, slope, y_inter))
    ann = ax.annotate(ann_str,
                      xy=(0.02, 0.76),
                      xycoords="axes fraction",
                      fontsize=10)

    # Add labels to figure
    xlabel = f"Reflectance ({corr_bname})"
    ylabel = f"Reflectance ({vis_bname})"

    if scale_factor is not None:
        if scale_factor > 1:
            xlabel += " " + r"$\times$" + " {0}".format(int(scale_factor))
            ylabel += " " + r"$\times$" + " {0}".format(int(scale_factor))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Save figure
    png_file = os.path.join(
        odir, "Correlation_{0}_vs_{1}.png".format(corr_bname, vis_bname))

    fig.savefig(png_file,
                format="png",
                bbox_inches="tight",
                pad_inches=0.1,
                dpi=300)

    # delete all lines and annotations from figure,
    # so it can be reused in the next iteration
    qm.remove()
    ln.remove()
    ann.remove()
    lgnd.remove()
    # the colour-bar inset is a separate axes of the figure, not an artist
    # of `ax`, so `ax.clear()` does not remove it; without this every call
    # would stack one more colour bar onto the reused figure
    cbaxes.remove()
예제 #30
0
def hist1d(foreground,
           background,
           axis: matplotlib.axes.Axes,
           dataset_name: str,
           normalize=True,
           nbins: np.ndarray = np.linspace(-20, 20, 40),
           norm_ymax=10,
           log=False,
           bg_label='total IMF'):
    """Plots 1D histogram of e.g. BxGSM.

    Both inputs are drawn as step histograms whose bar heights are the
    fraction of (non-NaN) samples per bin. Optionally, the ratio of the
    foreground to the background histogram is drawn on a twin y-axis.

    Parameters
    ----------
    foreground (array_like): data for the TPAs. NaNs are ignored.
    background (array_like): all the data for the IMF during the period of
        the dataset. NaNs are ignored.
    axis (matplotlib.axes.Axes): axes to draw the histograms on.
    dataset_name (str): legend label of the foreground histogram.
    normalize (bool): if True, also plot foreground / background per bin on
        a second y-axis sharing the x-axis (bins with an empty background
        are masked).
    nbins (int or array_like): bins passed to ``Axes.hist``; the default
        is 40 edges from -20 to 20 (the default array is never modified).
    norm_ymax (float): upper y-limit of the normalized axis.
    log (bool): if True, use a logarithmic x-axis.
    bg_label (str): legend label of the background histogram.

    Returns
    -------
    fg_hist_values, bg_hist_values, bins
    """
    # accept any array_like (e.g. lists), which cannot be indexed by a mask
    foreground = np.asarray(foreground, dtype=float)
    background = np.asarray(background, dtype=float)
    foreground = foreground[~np.isnan(foreground)]
    background = background[~np.isnan(background)]

    axis.axvline(0, color="grey", lw=1, zorder=-1)
    bg_hist_values, _, _ = axis.hist(background,
                                     bins=nbins,
                                     weights=np.ones_like(background) /
                                     len(background),
                                     label=bg_label,
                                     histtype='step',
                                     zorder=0)
    fg_hist_values, bins, _ = axis.hist(foreground,
                                        bins=nbins,
                                        weights=np.ones_like(foreground) /
                                        len(foreground),
                                        label=dataset_name,
                                        histtype='step',
                                        zorder=1)
    if normalize:
        # instantiate a second axes that shares the same x-axis
        normalized_axis = axis.twinx()
        normalized_axis.axhline(1, ls="--", color='lightgrey', lw=1)
        # mask empty background bins to avoid division by zero
        masked_bg_hist_values = np.ma.masked_where(bg_hist_values == 0,
                                                   bg_hist_values)
        normalized_axis.plot((bins[1:] + bins[:-1]) / 2,
                             fg_hist_values / masked_bg_hist_values,
                             c='g',
                             label='IMF normalized',
                             zorder=2)
        normalized_axis.set_ylim(0, norm_ymax)
        normalized_axis.set_ylabel('IMF normalized TPA distribution',
                                   color='g')

    axis.legend(loc='upper left')

    if log:
        axis.set_xscale('log')

    axis.set_ylabel('Probability Distribution')
    axis.minorticks_on()

    return fg_hist_values, bg_hist_values, bins