Exemple #1
0
def hist2d_scatter(x,
                   y,
                   bg_x,
                   bg_y,
                   axis: matplotlib.axes.Axes,
                   dataset_name: str,
                   normalize=True,
                   marker_color: str = 'k',
                   bins=300,
                   colormap_name: str = 'jet',
                   color_bar=False,
                   hist2dkw=None,
                   scatterkw=None):
    """Plot a 2D background histogram with foreground points scattered on top.

    The background ``(bg_x, bg_y)`` is binned with ``axis.hist2d`` and drawn as
    coloured tiles; the foreground ``(x, y)`` is drawn as 'P' (filled plus)
    markers. Grey lines mark x = 0 and y = 0. Pairs with a NaN in either
    coordinate are dropped (separately for foreground and background).

    Parameters
    ----------
    x, y : array_like
        Values of the TPAs for the parameter plotted on the x- or y-axis.
        This data corresponds to the dots in the plot.
    bg_x, bg_y : array_like
        Values of the IMF over the period of the dataset. These form the
        background (coloured tiles) of the plot.
    axis : matplotlib.axes.Axes
        Axes to draw into.
    dataset_name : str
        Legend label of the scatter markers.
    normalize : bool
        Forwarded as ``density`` to ``hist2d``.
    marker_color : str or numpy.ndarray
        Marker colour(s). An ndarray is filtered with the same NaN mask as
        ``x``/``y`` (so it presumably has one entry per point — confirm with
        callers).
    bins : int or sequence
        Forwarded to ``hist2d``.
    colormap_name : str
        Colormap of the background histogram. Its lowest colour is also used
        as the axes face colour.
    color_bar : bool
        If True, attach a colour bar for the background histogram.
    hist2dkw, scatterkw : dict, optional
        Extra keyword arguments for ``hist2d`` / ``scatter``. Keys given here
        override the defaults chosen by this function (e.g. ``s``, ``marker``).

    Returns
    -------
    scatter, counts, xedges, yedges, im
    """
    # pyplot.get_cmap works across matplotlib versions; cm.get_cmap was
    # removed in matplotlib 3.9.
    colormap = plt.get_cmap(colormap_name)
    axis.axhline(0, color='grey', zorder=1)
    axis.axvline(0, color='grey', zorder=1)

    # asanyarray lets plain lists be boolean-mask indexed below.
    x = np.asanyarray(x)
    y = np.asanyarray(y)
    omit_index = np.isnan(x) | np.isnan(y)
    x = x[~omit_index]
    y = y[~omit_index]

    bg_x = np.asanyarray(bg_x)
    bg_y = np.asanyarray(bg_y)
    bg_omit_index = np.isnan(bg_x) | np.isnan(bg_y)
    bg_x = bg_x[~bg_omit_index]
    bg_y = bg_y[~bg_omit_index]

    hist2d_options = dict(bins=bins,
                          cmap=colormap,
                          density=normalize,
                          zorder=0)
    hist2d_options.update(hist2dkw or {})
    counts, xedges, yedges, im = axis.hist2d(bg_x, bg_y, **hist2d_options)

    if color_bar:
        cbar = plt.colorbar(im, ax=axis)
        cbar.set_label('IMF probability distribution',
                       rotation=270,
                       labelpad=10)

    if isinstance(marker_color, np.ndarray):
        # Per-point colours must be filtered with the same mask as x and y.
        marker_color = marker_color[~omit_index]
    scatter_options = dict(s=30,
                           marker='P',
                           edgecolors='w',
                           linewidth=0.5,
                           label=dataset_name,
                           c=marker_color,
                           zorder=2)
    scatter_options.update(scatterkw or {})
    scatter = axis.scatter(x, y, **scatter_options)

    # Fill empty areas with the colormap's lowest colour so they blend in
    # with empty histogram bins.
    axis.set_facecolor(colormap(0))
    axis.legend(loc='upper left')
    return scatter, counts, xedges, yedges, im
Exemple #2
0
def draw_date_line(ax: mpl.axes.Axes,
                   plot_timestamps: np.ndarray,
                   date: np.datetime64,
                   linestyle: str,
                   color: Optional[str]) -> mpl.lines.Line2D:
    '''Draw a vertical line at the timestamp in ``plot_timestamps`` nearest to ``date``.

    The line is placed at the *index* of that nearest timestamp rather than at
    the date itself, so the subplot's x axis is presumably plotted against
    integer positions — confirm with callers.

    Returns the line created by ``ax.axvline``.
    '''
    distance_to_date = np.abs(plot_timestamps - date)
    nearest_position = distance_to_date.argmin()
    return ax.axvline(x=nearest_position, linestyle=linestyle, color=color)
Exemple #3
0
def plot_step_analyzer(axes: matplotlib.axes.Axes, result: dict, title: str,
                       legends: list, colorset: int):
    """Plot step-response traces and their analysis markers on ``axes``.

    For each state ``i`` the trace ``result['states'][i]`` is plotted against
    ``result['times']``. In the same colour it also draws:

    * if ``result['references'][i]`` is non-zero: a dashed line at the
      reference, dash-dot lines at reference -/+ ``result['thresholds'][i]``,
      and a dotted line at reference + ``result['overshoots'][i]``;
    * a dashed vertical line at ``result['risetimes'][i]`` if it is positive;
    * a dash-dot vertical line at ``result['settletimes'][i]`` if it is positive.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        Axes to draw into.
    result : dict
        Must contain 'times', 'states', 'references', 'thresholds',
        'overshoots', 'risetimes' and 'settletimes'. The per-state entries are
        indexed in parallel by state number.
    title : str
        Axes title; not set when empty.
    legends : list
        One label per state; no labels or legend when empty/None.
    colorset : int
        Index of the palette to use. Palettes have three colours, which are
        reused cyclically when there are more than three states.
    """
    palettes = [
        ('red', 'green', 'blue'),
        ('tomato', 'lightgreen', 'steelblue'),
    ]
    palette = palettes[colorset]

    times = result['times']
    for i, state in enumerate(result['states']):
        # Cycle the palette so more than three states no longer raise IndexError.
        color = palette[i % len(palette)]

        line_options = {'color': color}
        if legends:
            line_options['label'] = legends[i]
        axes.plot(times, state, **line_options)

        reference = result['references'][i]
        if reference != 0.0:
            threshold = result['thresholds'][i]
            axes.axhline(reference, linestyle='--', color=color)
            axes.axhline(reference - threshold, linestyle='-.', color=color)
            axes.axhline(reference + threshold, linestyle='-.', color=color)
            axes.axhline(reference + result['overshoots'][i],
                         linestyle=':',
                         color=color)
        if result['risetimes'][i] > 0.0:
            axes.axvline(result['risetimes'][i], linestyle='--', color=color)
        if result['settletimes'][i] > 0.0:
            axes.axvline(result['settletimes'][i], linestyle='-.', color=color)

    if legends:
        axes.legend()

    if title:
        axes.set_title(title)

    axes.set_xlim(times[0], times[-1])
    axes.figure.tight_layout()
Exemple #4
0
def plot_ece_samples(ax: mpl.axes.Axes,
                     ground_truth_ece: float,
                     frequentist_ece,
                     samples_posterior: np.ndarray,
                     plot_kwargs: Dict[str, Any] = {}) -> mpl.axes.Axes:
    """Compare frequentist, Bayesian and ground-truth ECE estimates on ``ax``.

    :param ax: axes to draw into.
    :param ground_truth_ece: drawn as a black vertical line.
    :param frequentist_ece: a scalar is drawn as a blue vertical line; an
        array of values is drawn as a blue histogram.
    :param samples_posterior: posterior ECE samples, drawn as a red histogram.
    :param plot_kwargs: overrides merged into ``DEFAULT_PLOT_KWARGS``. The
        merged dict is passed to both ``axvline`` and ``hist``. It is only
        read and never mutated, so the shared ``{}`` default is safe.
    :return: ``ax``, with x limits fixed to [0, 0.3].
    """
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs)
    # np.ndim == 0 accepts every scalar (Python float/int, np.float32, 0-d
    # arrays). isinstance(..., float) sent some scalars to hist, producing a
    # one-value histogram instead of a line.
    if np.ndim(frequentist_ece) == 0:
        ax.axvline(x=frequentist_ece,
                   label='Frequentist',
                   color='blue',
                   **_plot_kwargs)
    else:
        ax.hist(frequentist_ece,
                color='blue',
                alpha=0.7,
                label='Frequentist',
                **_plot_kwargs)
    ax.hist(samples_posterior,
            color='red',
            label='Bayesian',
            alpha=0.7,
            **_plot_kwargs)
    ax.axvline(x=ground_truth_ece,
               label='Ground truth',
               color='black',
               **_plot_kwargs)

    ax.set_xlim(0, 0.3)
    ax.set_xticks([0.0, 0.1, 0.2, 0.3])

    return ax
Exemple #5
0
def _draw_date_gap_lines(ax: mpl.axes.Axes, plot_timestamps: np.ndarray) -> None:
    '''
    Draw vertical lines wherever there are gaps between two timestamps,
    i.e. wherever the step between two adjacent timestamps is larger than the
    smallest step in the series.

    Lines are placed at ``i + 0.5``, halfway between positions ``i`` and
    ``i + 1``, so the x axis is presumably index-based (0, 1, 2, ...).
    Nothing is drawn when there are fewer than two timestamps or more than
    20 gaps.

    Raises
    ------
    ValueError
        If the smallest step between adjacent timestamps is not positive, so no
        date frequency can be inferred.
    '''
    timestamps = np.atleast_1d(mdates.date2num(plot_timestamps))
    steps = np.diff(timestamps)
    if steps.size == 0:
        return  # fewer than two timestamps: there can be no gaps
    freq = np.nanmin(steps)
    if freq <= 0:
        raise ValueError('could not infer date frequency')

    # Small tolerance so floating-point noise from date2num is not a "gap".
    gap_positions = np.flatnonzero(steps > freq + 0.000000001) + 0.5

    if len(gap_positions) > 20:
        return  # Too many lines will clutter the graph

    for x in gap_positions:
        ax.axvline(x, linestyle='dashed', color='0.5')
Exemple #6
0
def plot_vlines(
    ax: matplotlib.axes.Axes,
    vlines: preprocessing.NamedDates,
    alignment: str,
) -> None:
    """ Helper function for marking special events with labeled vertical lines.

    Only dates inside the current x limits are drawn. Each gets a gray dotted
    vertical line and a rotated gray label, shortened to at most 20 characters.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        the subplot to draw into
    vlines : dict of { datetime : label }
        the dates and labels for the lines
    alignment : str
        one of { "top", "bottom" }; where along the y range the labels sit

    Raises
    ------
    ValueError
        If ``alignment`` is not "top" or "bottom". This is checked before
        anything is drawn.
    """
    ymin, ymax = ax.get_ylim()
    # Validate once, up front: previously a bad value went unnoticed when no
    # date was in range, and otherwise raised only after a line was drawn.
    if alignment == 'top':
        y = ymin + 0.98 * (ymax - ymin)
    elif alignment == 'bottom':
        y = ymin + 0.02 * (ymax - ymin)
    else:
        raise ValueError(f"Unsupported alignment: '{alignment}'")

    xmin, xmax = ax.get_xlim()
    for x, label in vlines.items():
        # Compare in axis units so datetimes work against the numeric limits.
        if not xmin <= ax.xaxis.convert_xunits(x) <= xmax:
            continue
        short_label = textwrap.shorten(label, width=20, placeholder="...")
        ax.axvline(x, color="gray", linestyle=":")
        ax.text(
            x, y,
            s=f'{short_label}\n',
            color="gray",
            rotation=90,
            horizontalalignment="center",
            verticalalignment=alignment,
        )
Exemple #7
0
 def _draw_rt_ref_lines(self, ax: matplotlib.axes.Axes) -> None:
     """Mark the retention-time reference positions on ``ax``.

     Every value in ``self.rt_range`` gets a black vertical line, then
     ``self.rt_peak`` gets a red one.
     """
     range_marker_color = "black"
     peak_marker_color = "red"
     for boundary in self.rt_range:
         ax.axvline(boundary, color=range_marker_color)
     ax.axvline(self.rt_peak, color=peak_marker_color)
Exemple #8
0
def hist1d(foreground,
           background,
           axis: matplotlib.axes.Axes,
           dataset_name: str,
           normalize=True,
           nbins: np.ndarray = np.linspace(-20, 20, 40),
           norm_ymax=10,
           log=False,
           bg_label='total IMF'):
    """Plots 1D histogram of e.g. BxGSM.

    Both histograms are drawn as outlines ('step'). Each is weighted by
    ``1 / len(data)``, so its bins sum to one. NaNs are dropped first.

    Parameters
    ----------
    foreground (array_like): data for the TPAs.
    background (array_like): all the data for the IMF during the period of the dataset.
    axis (matplotlib.axes.Axes): axes to draw into.
    dataset_name (str): legend label for the foreground histogram.
    normalize (bool): if True, add a twin y-axis showing the bin-wise ratio
        foreground / background (bins with an empty background are masked),
        plus a dashed reference line at 1.
    nbins: forwarded as ``bins`` to both ``axis.hist`` calls.
    norm_ymax (float): upper y-limit of the ratio axis.
    log (bool): if True, use a logarithmic x scale.
    bg_label (str): legend label for the background histogram.

    Returns
    -------
    fg_hist_values, bg_hist_values, bins
    """
    # asarray lets plain lists be boolean-mask indexed.
    foreground = np.asarray(foreground)
    background = np.asarray(background)
    foreground = foreground[~np.isnan(foreground)]
    background = background[~np.isnan(background)]

    axis.axvline(0, color="grey", lw=1, zorder=-1)
    bg_hist_values, _, _ = axis.hist(background,
                                     bins=nbins,
                                     weights=np.ones_like(background) /
                                     len(background),
                                     label=bg_label,
                                     histtype='step',
                                     zorder=0)
    fg_hist_values, bins, _ = axis.hist(foreground,
                                        bins=nbins,
                                        weights=np.ones_like(foreground) /
                                        len(foreground),
                                        label=dataset_name,
                                        histtype='step',
                                        zorder=1)

    handles, labels = axis.get_legend_handles_labels()
    if normalize:
        # Second y-axis sharing the same x-axis.
        normalized_axis = axis.twinx()
        normalized_axis.axhline(1, ls="--", color='lightgrey', lw=1)
        # Mask empty background bins so the ratio does not divide by zero.
        masked_bg_hist_values = np.ma.masked_where(bg_hist_values == 0,
                                                   bg_hist_values)
        normalized_axis.plot((bins[1:] + bins[:-1]) / 2,
                             fg_hist_values / masked_bg_hist_values,
                             c='g',
                             label='IMF normalized',
                             zorder=2)
        normalized_axis.set_ylim(0, norm_ymax)
        normalized_axis.set_ylabel('IMF normalized TPA distribution',
                                   color='g')
        # axis.legend() alone only sees artists on `axis`, so the
        # 'IMF normalized' curve on the twin axis never appeared in the legend.
        twin_handles, twin_labels = normalized_axis.get_legend_handles_labels()
        handles = handles + twin_handles
        labels = labels + twin_labels

    axis.legend(handles, labels, loc='upper left')

    if log:
        axis.set_xscale('log')

    axis.set_ylabel('Probability Distribution')
    axis.minorticks_on()

    return fg_hist_values, bg_hist_values, bins
Exemple #9
0
def draw_vertical_line(ax: mpl.axes.Axes, x: float, linestyle: str,
                       color: Optional[str]) -> mpl.lines.Line2D:
    '''Add a full-height vertical line at data coordinate ``x`` to ``ax``.

    Returns the line object created by ``ax.axvline``.
    '''
    line_style_options = {'linestyle': linestyle, 'color': color}
    vertical_line = ax.axvline(x=x, **line_style_options)
    return vertical_line
Exemple #10
0
def plot_eps_walltime_lowlevel(
    end_times: List,
    eps: List,
    labels: Union[List, str] = None,
    colors: List[Any] = None,
    group_by_label: bool = True,
    indicate_end: bool = True,
    unit: str = 's',
    xscale: str = 'linear',
    yscale: str = 'log',
    title: str = "Epsilon over walltime",
    size: tuple = None,
    ax: mpl.axes.Axes = None,
) -> mpl.axes.Axes:
    """Low-level access to `plot_eps_walltime`.
    Directly define `end_times` and `eps`. Note that both should be arrays of
    the same length and at the beginning include a value for the calibration
    iteration. This is just what `pyabc.History.get_all_populations()` returns.
    The first time is used as the base time differences to which are plotted.
    The first epsilon is ignored.

    Parameters
    ----------
    end_times, eps:
        Per run: population end times and epsilons, calibration entry first.
    labels:
        Run labels, normalised by ``get_labels``.
    colors:
        One colour per run. If None and ``group_by_label`` is set, runs sharing
        a label share a colour ("C0", "C1", ...).
    group_by_label:
        Show each label only once in the legend (later duplicates get None).
    indicate_end:
        Draw a dashed vertical line at the last walltime of each run.
    unit:
        Time unit of the x axis; must be in ``TIME_UNITS``.
    xscale, yscale, title, size:
        Axis scales, title and optional figure size in inches.
    ax:
        Axes to draw into; a new figure is created if None.

    Returns
    -------
    The axes that were drawn into.

    Raises
    ------
    AssertionError
        If ``unit`` is not in ``TIME_UNITS``.
    """
    # preprocess input
    end_times = to_lists(end_times)
    labels = get_labels(labels, len(end_times))
    n_run = len(end_times)

    if group_by_label:
        if colors is None:
            # Assign a new colour cycle index only on a label's first occurrence.
            colors = []
            color_ix = -1
            for ix, label in enumerate(labels):
                if label not in labels[:ix]:
                    color_ix += 1
                colors.append(f"C{color_ix}")

        # Keep each label only on its first run so the legend has no duplicates.
        labels = [
            x if x not in labels[:ix] else None for ix, x in enumerate(labels)
        ]
    if colors is None:
        colors = [None] * n_run

    # check time unit
    if unit not in TIME_UNITS:
        raise AssertionError(f"`unit` must be in {TIME_UNITS}")

    # create figure
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # extract relative walltimes (in seconds) w.r.t. the calibration end time
    walltimes = []
    for end_ts in end_times:
        diffs = end_ts[1:] - end_ts[0]
        walltimes.append([diff.total_seconds() for diff in diffs])

    # disregard calibration epsilon (inf)
    eps = [ep[1:] for ep in eps]

    for wt, ep, label, color in zip(walltimes, eps, labels, colors):
        # float dtype so the in-place unit division below is always valid
        wt = np.asarray(wt, dtype=float)
        # apply time unit
        if unit == MINUTE:
            wt /= 60
        elif unit == HOUR:
            wt /= 60 * 60
        elif unit == DAY:
            wt /= 60 * 60 * 24
        # plot
        ax.plot(wt, ep, label=label, marker='o', color=color)
        # A run with only the calibration population has no walltimes;
        # indexing wt[-1] would raise IndexError.
        if indicate_end and wt.size > 0:
            ax.axvline(wt[-1], linestyle='dashed', color=color)

    # prettify plot
    if n_run > 1:
        ax.legend()
    ax.set_title(title)
    ax.set_xlabel(f"Time [{unit}]")
    ax.set_ylabel("Epsilon")
    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    # enforce integer ticks
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    if size is not None:
        fig.set_size_inches(size)
    fig.tight_layout()

    return ax