Exemple #1
0
def plot_smoothed(
        ax: plt.axis,
        timesteps: Iterable[int],
        values: Sequence[float], smoothed_values: Sequence[float],
        label: str,
        alpha: float = 0.2) \
        -> None:
    """Plot a smoothed series on top of a faint copy of the raw series.

    The smoothed curve is drawn first without an explicit colour, so it takes
    the next colour of ``ax``'s colour cycle; the raw values are then drawn in
    that same colour with opacity ``alpha``.

    :param ax: axes to draw on.
    :param timesteps: x coordinates shared by both series.
    :param values: raw (unsmoothed) y values, one per timestep.
    :param smoothed_values: smoothed y values, one per timestep.
    :param label: legend label, attached to the smoothed curve only.
    :param alpha: opacity of the raw curve.
    """
    smoothed_line = ax.plot(timesteps, smoothed_values, label=label)[0]
    ax.plot(timesteps, values, color=smoothed_line.get_color(), alpha=alpha)
def plot_peaks(axis: plt.axis, collection: PeakCollection,
               max_peaks: int = 1000) -> plt.axis:
    """Mark every relevant peak of ``collection`` with a red 'x'.

    :param axis: axes to draw on; also returned for chaining.
    :param collection: iterable of peaks; each item is read through its
        ``relevant``, ``get_x`` and ``get_y`` attributes.
    :param max_peaks: safety cap on how many items of ``collection`` are
        inspected (relevant or not), so a huge collection cannot stall
        plotting.
    :return: the same ``axis``.
    """
    for i, peak_data in enumerate(collection):
        if i >= max_peaks:  # safety break: at most ``max_peaks`` items
            break
        if peak_data.relevant:
            axis.plot(peak_data.get_x,
                      peak_data.get_y,
                      'x',
                      color='r',
                      alpha=1)
    return axis
Exemple #3
0
def plot_percentiles(
        ax: plt.axis, df: pd.DataFrame, percentiles: Tuple[float, float],
        alpha: float) -> None:
    """Shade the band between two row-wise quantiles of ``df``.

    Quantiles are taken across columns (``axis=1``): each row of ``df`` is one
    x position (taken from ``df.index``) and each column one sample there.

    :param ax: axes to draw on.
    :param df: data to summarise; its index gives the x coordinates.
    :param percentiles: ``(low, high)`` quantiles, each in [0, 1].
    :param alpha: opacity of the shaded band, in [0, 1].
    :raises ValueError: if ``alpha`` or ``percentiles`` are out of range or
        ``percentiles`` is not a pair.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
    if len(percentiles) != 2:
        raise ValueError("percentiles must contain exactly two values")
    if not all(0 <= p <= 1 for p in percentiles):
        raise ValueError("percentiles must be between 0 and 1")

    pct_low = df.quantile(percentiles[0], axis=1)
    pct_high = df.quantile(percentiles[1], axis=1)
    ax.fill_between(df.index, pct_low, pct_high, alpha=alpha)
Exemple #4
0
    def draw(self, ax: plt.axis):
        """Draw the segment ``p0 -> p1`` onto ``ax``.

        A disc (size 200, grey edge) is placed on ``p0`` and, only when
        ``self.last`` is truthy, on ``p1`` as well. The segment itself is two
        overlaid strokes: a width-6 grey stroke underneath a width-5 stroke in
        ``self.color``, so the bar appears with a thin dark outline.
        """
        endpoints = [self.p0, self.p1] if self.last else [self.p0]
        for point in endpoints:
            ax.scatter(*point, s=200, lw=1, ec=[0.3, 0.3, 0.3],
                       color=self.color, zorder=100)

        # (line width, stroke colour, z-order): outline first, fill on top.
        strokes = ((6, [0.3, 0.3, 0.3], 98),
                   (5, self.color, 99))
        for width, stroke_color, layer in strokes:
            ax.plot([self.p0[0], self.p1[0]],
                    [self.p0[1], self.p1[1]],
                    lw=width,
                    color=stroke_color,
                    zorder=layer)
Exemple #5
0
    def plot_areas_society_progress(ax: plt.axis, time_array: List,
                                    society_snapshot: Dict,
                                    society_progress: Dict):
        """Draw one filled band per ``Status`` and label each band's width.

        For every status (in ``Status`` iteration order) the current snapshot
        value is appended to ``society_progress[status.name]``; this mutates
        the caller's dict. Each band is filled between the previous status's
        history (0 for the first status) and this status's history. A text
        label at the last time point shows ``snapshot - lower edge``.
        NOTE(review): this presumes the snapshot values are cumulative across
        statuses; confirm with the caller.
        """
        previous_status = None
        ax.set_xlabel('time [days]')
        ax.set_ylabel('population percentage')
        for st in Status:
            society_progress[st.name].append(society_snapshot[st.name])

            # ``is not None`` rather than truthiness: an Enum member mixed
            # with int/str takes that type's truth value and could be falsy.
            if previous_status is not None:
                lower_limit = society_progress[previous_status.name]
            else:
                lower_limit = [0]

            ax.fill_between(x=time_array,
                            y1=lower_limit,
                            y2=society_progress[st.name],
                            color=st.value,
                            label=st.name,
                            alpha=0.25)

            # Label sits vertically in the middle of the band's latest slice.
            ax.text(x=time_array[-1],
                    y=1 / 2 * (lower_limit[-1] + society_snapshot[st.name]),
                    s=r"{0:.2f}".format(society_snapshot[st.name] -
                                        lower_limit[-1]),
                    size=10,
                    color=st.value)

            previous_status = st
Exemple #6
0
    def plot_table_params(ax: plt.axis, params_dict: Dict):
        """Render ``params_dict`` as a centred one-column table on ``ax``.

        The axis decorations are switched off. Each key becomes a row label
        and its value the single cell of that row, in fixed 10-pt font.
        """
        cell_text = [[value] for value in params_dict.values()]

        ax.axis('off')
        row_labels = list(params_dict.keys())
        param_table = ax.table(
            cellText=cell_text,
            cellLoc='center',
            rowLabels=row_labels,
            rowLoc='center',
            colWidths=[0.1],
            loc='center',
        )

        param_table.auto_set_font_size(False)
        param_table.set_fontsize(10)
Exemple #7
0
def plot_tracking_linearized(
    tracking: Union[dict, pd.DataFrame],
    ax: plt.axis = None,
    plot: bool = True,
    **kwargs,
):
    """Plot linearized track position against a normalised time axis.

    ``tracking["global_coord"]`` is placed on x; y runs linearly from 1 (first
    sample) down to 0 (last sample).

    :param tracking: mapping / DataFrame with a ``"global_coord"`` entry.
    :param ax: axes to draw on; a new 9x9 figure is created when ``None``.
    :param plot: draw a line when True, a scatter when False.
    :param kwargs: forwarded to ``ax.plot`` / ``ax.scatter``.
    :return: the axes that were drawn on (needed when ``ax`` was created here).
    """
    if ax is None:
        ax = plt.subplots(figsize=(9, 9))[1]

    x = tracking["global_coord"]
    y = np.linspace(1, 0, len(x))

    if plot:
        ax.plot(x, y, **kwargs)
    else:
        ax.scatter(x, y, **kwargs)
    return ax
Exemple #8
0
def draw_obstacle(ax: plt.axis, obst_pts, *args, **kwargs):
    """Add a solid black rectangle spanning two opposite corners to ``ax``.

    ``obst_pts`` is read as ``(x_a, y_a, x_b, y_b)``; the corners may come in
    any order, since the lower-left corner is taken as the element-wise
    minimum. Extra positional/keyword arguments are accepted but ignored.

    :return: the ``patches.Rectangle`` that was added.
    """
    corner_xs = np.float32((obst_pts[0], obst_pts[2]))
    corner_ys = np.float32((obst_pts[1], obst_pts[3]))
    lower_left = np.float32((corner_xs.min(), corner_ys.min()))
    upper_right = np.float32((corner_xs.max(), corner_ys.max()))
    width, height = upper_right - lower_left

    rectangle = patches.Rectangle(
        lower_left,
        width,
        height,
        linewidth=1,
        edgecolor='k',
        facecolor='k',
    )
    ax.add_patch(rectangle)
    return rectangle
def display_overlay(rgb_image: np.ndarray, binary_mask: np.ndarray, ax: plt.axis, check_mask=False):
    """
    Display a copy of ``rgb_image`` in which pixels where ``binary_mask`` is
    False are darkened to half their intensity.

    :param rgb_image: image array whose leading dimensions match the mask;
        it is copied, not modified. Integer images are halved with floor
        division, float images with true division.
    :param binary_mask: mask selecting the pixels kept at full brightness.
    :param ax: axes on which the overlay is shown with ``imshow``.
    :param check_mask: when True, a non-boolean mask triggers a warning and is
        converted with ``astype(bool)``.
    :return: None
    """
    binary_mask = np.asarray(binary_mask)
    # ``!=`` (not ``is not``): a numpy dtype compares equal to ``bool`` but is
    # never the same object as the builtin type.
    if check_mask and binary_mask.dtype != bool:
        warn("`binary_mask` is non-boolean, attempting conversion.")
        binary_mask = binary_mask.astype(bool)

    overlay = rgb_image.copy()
    darkened = ~binary_mask
    if np.issubdtype(overlay.dtype, np.integer):
        overlay[darkened] //= 2
    else:
        # Floor division would turn float intensities in [0, 1) into 0.
        overlay[darkened] = overlay[darkened] / 2
    ax.imshow(overlay)
Exemple #10
0
def plot_bouts_x(
    tracking_data: pd.DataFrame,
    bouts: pd.DataFrame,
    ax: plt.axis,
    variable: str,
    color: str = blue_grey,
    **kwargs,
):
    """
        Plots a variable from the tracking data for each bout

        Each row of ``bouts`` contributes one line: ``tracking_data[variable]``
        sliced ``[start_frame:end_frame]``, drawn in ``color``. Extra keyword
        arguments are forwarded to ``ax.plot``.
    """
    for _, bout in bouts.iterrows():
        start, end = bout.start_frame, bout.end_frame
        segment = tracking_data[variable][start:end]
        ax.plot(segment, color=color, **kwargs)
Exemple #11
0
def plot_cluster_sizes(cluster_labels: list,
                       ax: plt.axis = None,
                       show: bool = True) -> np.ndarray:
    """
    Plots cluster sizes using a histogram and returns the distinct cluster
    sizes ordered from most to least frequent.

    Also prints a one-line summary: number of clusters, largest and smallest
    cluster size, and their ratio.

    Parameters
    ----------
    cluster_labels : list
        Cluster label of every clustered item (e.g. every word)
    ax : plt.axis
        Matplotlib axis (default None, in which case a new figure is created)
    show : bool
        Whether to call ``plt.show()`` after plotting (default True)

    Returns
    -------
    most_common_cluster_sizes : np.ndarray
        Distinct cluster sizes sorted by how many clusters have that size,
        most frequent first. Empty (and nothing is plotted) when
        ``cluster_labels`` is empty.
    """
    labels_unique, labels_counts = np.unique(cluster_labels,
                                             return_counts=True)
    cluster_sizes, cluster_size_counts = np.unique(labels_counts,
                                                   return_counts=True)
    if labels_counts.size == 0:
        # Nothing to summarise; max()/min() below would raise on empty input.
        return cluster_sizes

    if ax is None:
        _, ax = plt.subplots()

    # Print cluster size ratio (max / min)
    num_clusters = len(labels_unique)
    max_cluster_size = labels_counts.max()
    min_cluster_size = labels_counts.min()
    cluster_size_ratio = max_cluster_size / min_cluster_size
    print(
        f"{num_clusters} clusters: max={max_cluster_size}, min={min_cluster_size}, ratio={cluster_size_ratio}"
    )

    # Histogram of cluster sizes: x = cluster size, y = clusters of that size
    sns.histplot(labels_counts, bins=max_cluster_size, ax=ax)
    ax.set_xlabel("Cluster size")
    ax.set_ylabel("Number of clusters")
    if show:
        plt.show()

    # Sort cluster sizes by frequency, most common first
    most_common_cluster_sizes = cluster_sizes[np.argsort(cluster_size_counts)
                                              [::-1]]

    return most_common_cluster_sizes
Exemple #12
0
def plot_bouts_x_by_y(
    tracking_data: pd.DataFrame,
    bouts: pd.DataFrame,
    ax: plt.axis,
    x: str,
    y: str,
    color: str = blue_grey,
    **kwargs,
):
    """
        Plots two tracking variables one against the other for each bout

        For every row of ``bouts`` the ``[start_frame:end_frame]`` slice of
        ``tracking_data[x]`` (horizontal) and ``tracking_data[y]`` (vertical)
        is drawn as one line in ``color``; extra keyword arguments go to
        ``ax.plot``.
    """
    for _, bout in bouts.iterrows():
        frames = slice(bout.start_frame, bout.end_frame)
        horizontal = tracking_data[x][frames]
        vertical = tracking_data[y][frames]
        ax.plot(horizontal, vertical, color=color, **kwargs)
Exemple #13
0
    def mpl_plot(self,
                 ax: plt.axis = None,
                 temp_unit: str = "GK",
                 **kwargs) -> plt.axis:
        """Plot this rate on log-log axes and return the axes.

        The rate is evaluated on 1000 log-spaced temperatures from 10**-2 to
        10**1 (presumably GK -- confirm against ``self.rate``). ``temp_unit``
        picks the x values: ``"GK"`` uses them directly and ``"KeV"`` converts
        them with ``Temperature(t).kev``; any other value draws no curve. A
        legend is added and the parent class's ``mpl_plot`` is applied to the
        same axes.

        :param ax: axes to draw on; defaults to ``plt.gca()``.
        :param temp_unit: ``"GK"`` or ``"KeV"``.
        :param kwargs: forwarded to ``ax.loglog``.
        :return: the axes returned by the parent ``mpl_plot``.
        """
        ax = ax or plt.gca()

        t = np.logspace(-2, 1, 1000)
        # ``==`` rather than ``is``: identity of equal strings depends on
        # interning, and ``is`` with a literal is a SyntaxWarning on 3.8+.
        if temp_unit == "GK":
            ax.loglog(
                t,
                self.rate(t),
                color=self.color,
                label="Reaclib-" + self.label + " " + self.__str__(),
                **kwargs,
            )
        elif temp_unit == "KeV":
            ax.loglog(
                Temperature(t).kev,
                self.rate(t),
                color=self.color,
                label=self.label + " " + self.__str__(),
                **kwargs,
            )

        ax.legend()

        ax = super().mpl_plot(ax=ax, temp_unit=temp_unit)

        return ax
Exemple #14
0
def plot_tensorboard(
        ax: plt.axis,
        tag: str,
        tb_data: Iterable[Iterable[Iterable[Path]]],
        names: Iterable[str],
        percentiles: Optional[Tuple[float, float]] = (0.25, 0.75),
        alpha: float = 0.1,
        use_data_cache: bool = False,
        data_cache_fname: str = ".rlgear_data.p",
        show_same_num_timesteps: bool = False,
        max_step: Optional[int] = None) -> None:
    """Plot the mean of a tensorboard tag (plus an optional percentile band)
    for several groups of runs.

    :param ax: axes to draw on.
    :param tag: tensorboard scalar tag to extract.
    :param tb_data: one entry per curve; each entry is handed to ``tb_to_df``.
    :param names: legend label per entry of ``tb_data``. The two are zipped,
        so surplus entries on either side are silently ignored.
    :param percentiles: ``(low, high)`` quantiles for a shaded band, or a
        falsy value to skip the band.
    :param alpha: opacity of the band.
    :param use_data_cache: load the per-group DataFrames from
        ``data_cache_fname`` instead of reading ``tb_data``.
    :param data_cache_fname: pickle file used as cache; it is rewritten on
        every call that does not use the cache.
    :param show_same_num_timesteps: apply ``shorten_dfs`` to the DataFrames
        before plotting.
    :param max_step: forwarded to ``tb_to_df``.
    """
    if use_data_cache:
        # SECURITY: pickle.load can execute arbitrary code; only point
        # ``data_cache_fname`` at caches written by this function.
        with open(data_cache_fname, 'rb') as f:
            out_dfs = pickle.load(f)
    else:
        out_dfs = [
            tb_to_df(grouped_files, tag, max_step=max_step,
                     only_complete_data=True)
            for grouped_files in tb_data
        ]

        with open(data_cache_fname, 'wb') as f:
            pickle.dump(out_dfs, f)

    if show_same_num_timesteps:
        # Return value unused: presumably shortens the DataFrames in place.
        shorten_dfs(out_dfs)

    for name, df in zip(names, out_dfs):
        ax.plot(df.index, df.mean(axis=1), label=name)

        if percentiles:
            plot_percentiles(ax, df, percentiles, alpha)

    # https://stackoverflow.com/a/25750438
    ax.xaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
    ax.set_xlabel('Training Step')
    plt.tight_layout()
Exemple #15
0
def plot_heatmap_2d(
    data: Union[dict, pd.DataFrame],
    key: str = None,
    ax: plt.axis = None,
    x_key: str = "x",
    y_key: str = "y",
    cmap: str = "inferno",
    vmin: int = 0,
    vmax: int = None,
    gridsize: int = 30,
    mincnt: int = 1,
    **kwargs,
):
    """Draw a hexagonal-bin heatmap of ``data[x_key]`` vs ``data[y_key]``.

    When ``key`` is given, ``data[key]`` supplies per-point values, passed to
    ``hexbin`` as its third argument ``C``; otherwise bins show point counts.
    If the first ``hexbin`` call raises ``ValueError`` (e.g. each entry holds a
    list of arrays), the entries are flattened with ``np.hstack`` and the call
    is retried once.

    :param ax: axes to draw on; a new 9x9 figure is created when ``None``.
    :param kwargs: forwarded to ``ax.hexbin``.
    :return: whatever ``ax.hexbin`` returns (the hexbin collection, e.g. for
        a colorbar).
    """
    if ax is None:
        ax = plt.subplots(figsize=(9, 9))[1]

    def _hexbin(x, y, c):
        # Shared styling arguments for both attempts below.
        return ax.hexbin(
            x,
            y,
            c,
            cmap=cmap,
            gridsize=gridsize,
            vmin=vmin,
            vmax=vmax,
            mincnt=mincnt,
            **kwargs,
        )

    # bin data in 2d
    try:
        return _hexbin(
            data[x_key],
            data[y_key],
            data[key] if key is not None else None,
        )
    except ValueError:  # likely the data was nested arrays
        return _hexbin(
            np.hstack(data[x_key]),
            np.hstack(data[y_key]),
            np.hstack(data[key]) if key is not None else None,
        )
Exemple #16
0
    def _sub_plot_single_from_array(self, ax: plt.axis, data: np.array, x_title: str, title: str, legend=None):
        '''
        Create a sub-plot for a given array of values.

        Parameters
        ----------
        ax : plt.axis
            figure axis
        data : np.array
            data to plot (solid line, width 1)
        x_title : str
            x-axis label
        title : str
            sub-plot title
        legend : str, optional
            line label for the legend; no label is set when None
        '''
        line_kwargs = {'linestyle': '-', 'linewidth': 1}
        # ``is None``: ``==`` would compare element-wise for array-like labels.
        if legend is not None:
            line_kwargs['label'] = legend
        ax.plot(data, **line_kwargs)
        ax.set_xlabel(x_title)
        ax.set_title(title)
Exemple #17
0
    def _sub_plot_multiple(self, axis: plt.axis, columns: tuple):
        '''
        Create a sub-plot overlaying the given columns of ``self.m_data_df``.

        Parameters
        ----------
        axis : plt.axis
            figure axis
        columns : tuple
            column names; each is drawn as its own labelled line (width 0.1)
            and the title lists them separated by ", "
        '''
        columns = tuple(columns)  # iterated twice below
        for col in columns:
            axis.plot(self.m_data_df[col], label=col, linestyle='-', linewidth=0.1)
        # join() places separators only between names, so no trailing ", ".
        axis.set_title(', '.join(columns))
        axis.set_xlabel('Rollouts')
        axis.legend()
Exemple #18
0
def plot_tracking_xy(
    tracking: Union[dict, pd.DataFrame],
    key: str = None,
    skip_frames: int = 1,
    ax: plt.axis = None,
    plot: bool = False,
    **kwargs,
):
    """Plot the x/y trajectory in ``tracking``, keeping every
    ``skip_frames``-th sample.

    * ``key is None`` and ``plot`` False: grey scatter.
    * ``key is None`` and ``plot`` True: line plot.
    * ``key`` given: scatter coloured by ``tracking[key]``. If ``key``
      contains "orientation" or "angle", a ring of 16 markers of radius 2
      around (25, 2) in data coordinates, coloured by their angle in degrees,
      is added as a colour key.

    :param ax: axes to draw on; a new 9x9 figure is created when ``None``.
    :param kwargs: forwarded to every ``scatter`` / ``plot`` call.
    :return: the axes that were drawn on.
    """
    if ax is None:
        ax = plt.subplots(figsize=(9, 9))[1]

    x = tracking["x"][::skip_frames]
    y = tracking["y"][::skip_frames]

    if key is None:
        if plot:
            ax.plot(x, y, **kwargs)
        else:
            ax.scatter(x, y, color=[0.3, 0.3, 0.3], **kwargs)
        return ax

    ax.scatter(x, y, c=tracking[key][::skip_frames], **kwargs)

    if "orientation" in key or "angle" in key:
        # ring of markers showing the angle -> colour mapping
        angles = np.linspace(0, 2 * np.pi, 16)
        ring_x = 2 * np.cos(angles[::-1] + np.pi / 2) + 25
        ring_y = 2 * np.sin(angles + np.pi / 2) + 2
        ax.scatter(ring_x,
                   ring_y,
                   s=80,
                   zorder=50,
                   c=np.degrees(angles),
                   alpha=1,
                   **kwargs)
    return ax
Exemple #19
0
def plot_bouts_1d(
    tracking: Union[dict, pd.DataFrame],
    bouts: pd.DataFrame,
    ax: plt.axis,
    direction: bool = None,
    zorder: int = 100,
    lw: float = 2,
    alpha: float = 1,
    **kwargs,
):
    """Overlay bouts on a linearized-track plot.

    x is ``tracking["global_coord"]`` and y runs linearly from 1 to 0 over all
    frames. Each bout (only those with ``bout.direction == direction`` when
    ``direction`` is given) is drawn over its ``[start_frame:end_frame]``
    segment in ``colors.bout_direction_colors[bout.direction]``, with a white
    marker at the first sample and a dark-grey marker at the last sample.
    ``kwargs`` are forwarded to every plot/scatter call.
    """
    if direction is not None:
        bouts = bouts.loc[bouts.direction == direction]

    track_x = tracking["global_coord"]
    track_y = np.linspace(1, 0, len(track_x))

    for _, bout in bouts.iterrows():
        seg_x = track_x[bout.start_frame:bout.end_frame]
        seg_y = track_y[bout.start_frame:bout.end_frame]
        bout_color = colors.bout_direction_colors[bout.direction]

        ax.plot(seg_x, seg_y, color=bout_color, zorder=zorder, lw=lw,
                alpha=alpha, **kwargs)

        # (sample position, face colour): start marker, then end marker
        for pos, face in ((0, "white"), (-1, [0.2, 0.2, 0.2])):
            ax.scatter(seg_x[pos], seg_y[pos], color=face, lw=1,
                       ec=bout_color, s=30, zorder=101, alpha=0.85,
                       **kwargs)
Exemple #20
0
    def plot_lines_society_progress(ax: plt.axis, time_array: List,
                                    society_snapshot: Dict,
                                    society_progress: Dict):
        """Draw one line per ``Status`` showing its share of the total.

        For each status the fraction ``snapshot[name] / snapshot["Total"]``
        is appended to ``society_progress[name]`` (the caller's dict is
        mutated). The full history is plotted against ``time_array`` and the
        latest fraction is written next to the line's end.
        """
        ax.set_xlabel('time [days]')
        ax.set_ylabel('population percentage')
        for status in Status:
            history = society_progress[status.name]
            history.append(society_snapshot[status.name] /
                           society_snapshot["Total"])
            latest_share = history[-1]

            ax.plot(time_array, history, c=status.value, ls='-',
                    label=status.name, alpha=0.5)
            ax.text(x=time_array[-1], y=latest_share,
                    s=r"{0:.2f}".format(latest_share), size=10,
                    color=status.value)
Exemple #21
0
def get_auto_ylims(ax: plt.axis,
                   hMC,
                   *,
                   hdata=None,
                   log_y=False,
                   yaxis_scale='auto'
                   ) -> Tuple[Union[int, float], Union[int, float]]:
    """Return the minimum and maximum values for the y axis.

    The minimum is always the current lower y limit of ``ax``.

    If ``yaxis_scale`` is ``'auto'``, the maximum is chosen so that the
    histogram stays in bounds and avoids the label and legend: the left 2/3
    of the bins stay out of the top 1/4 of the axis, and the right 1/3 stay
    out of the top 1/3 (fractions measured from 0 on a linear axis, and
    applied to the exponent relative to 1 on a log axis).

    Otherwise ``yaxis_scale`` is a number that scales the current upper limit
    (multiplicatively, or as an exponent when ``log_y``).

    :param ax: axes whose current y limits are read.
    :param hMC: bin contents of the MC histogram. Either one 1-D sequence, or
        a sequence of them (e.g. a stack), in which case the last one is used.
    :param hdata: optional 1-D bin contents of the data histogram, also kept
        clear of label and legend.
    :param log_y: whether the y axis is logarithmic.
    :param yaxis_scale: ``'auto'`` or a numeric scale factor.
    :return: ``(ymin, ymax)``.
    """
    ymin, current_ymax = ax.get_ylim()

    if yaxis_scale != 'auto':
        if log_y:
            return ymin, current_ymax**yaxis_scale
        return ymin, current_ymax * yaxis_scale

    # Reduce nested input to its last histogram. np.ndim treats int and
    # float32 bin contents as scalars, which isinstance(..., float) does not.
    if np.ndim(hMC[0]) > 0:
        hMC = hMC[-1]

    labelFraction = 1 / 4
    legendFraction = 1 / 3

    split = len(hMC) * 2 // 3
    maxBelowLabel = max(hMC[:split])
    maxBelowLegend = max(hMC[split:])
    if hdata is not None:
        data_split = len(hdata) * 2 // 3
        maxBelowLabel = max(maxBelowLabel, max(hdata[:data_split]))
        maxBelowLegend = max(maxBelowLegend, max(hdata[data_split:]))

    if log_y:
        ymax = max(maxBelowLabel**(1 / (1 - labelFraction)),
                   maxBelowLegend**(1 / (1 - legendFraction)))
    else:
        ymax = max(maxBelowLabel / (1 - labelFraction),
                   maxBelowLegend / (1 - legendFraction))

    return ymin, ymax
Exemple #22
0
    def _sub_plot_single(self, ax: plt.axis, column: str, legend=None):
        '''
        Create a sub-plot for a given column of ``self.m_data_df``.

        Parameters
        ----------
        ax : plt.axis
            figure axis
        column : str
            column name; also used as the sub-plot title
        legend : str, optional
            line label for the legend; no label is set when None
        '''
        line_kwargs = {'linestyle': '-', 'linewidth': 0.1}
        # ``is None``: ``==`` would compare element-wise for array-like labels.
        if legend is not None:
            line_kwargs['label'] = legend
        ax.plot(self.m_data_df[column], **line_kwargs)

        ax.set_xlabel('Rollouts')
        # NOTE(review): an "(approx. run time: ...)" title suffix for the
        # 'Num_timesteps' column (via self.estimated_run_time()) was disabled
        # here; restore it if that information is wanted again.
        ax.set_title(column)
Exemple #23
0
def plot_balls_errors(
    x: np.ndarray,
    y: np.ndarray,
    yerr: np.ndarray,
    ax: plt.axis,
    s: int = 150,
    colors: Union[list, str] = None,
):
    """
        Given a series of XY values and Y errors, plot a scatter marker for
        each XY point, a line joining the points and, unless ``yerr`` is None,
        a vertical bar from ``y - yerr`` to ``y + yerr`` at every point.

        ``colors`` may be None (every point ``blue_grey``), a single colour
        string, or one colour per point. The joining line uses the first
        colour; each error bar is a grey stroke under a thinner coloured one.
    """
    if colors is None:
        point_colors = [blue_grey] * len(x)
    elif isinstance(colors, str):
        point_colors = [colors] * len(x)
    else:
        point_colors = colors

    ax.scatter(x, y, s=s, c=point_colors, zorder=100, lw=1, ec=[0.3, 0.3, 0.3])
    ax.plot(x, y, lw=3, color=point_colors[0], zorder=-1)

    if yerr is None:
        return

    for n in range(len(x)):
        low = y[n] - yerr[n]
        high = y[n] + yerr[n]
        # outline stroke
        ax.plot([x[n], x[n]], [low, high], lw=4, color=[0.3, 0.3, 0.3],
                zorder=96, solid_capstyle="round")
        # coloured stroke on top
        ax.plot([x[n], x[n]], [low, high], lw=2, color=point_colors[n],
                zorder=98, solid_capstyle="round")
Exemple #24
0
def plot_probe_electrodes(
    rsites: pd.DataFrame,
    ax: plt.axis,
    TARGETS: list = (),
    annotate_every: Union[int, bool] = 5,
    x_shift: bool = True,
    s: int = 25,
    lw: float = 0.25,
    x_pos_delta: float = 0,
):
    """Draw probe recording sites as square markers along probe depth.

    :param rsites: one row per site; the ``probe_coordinates``,
        ``brain_region``, ``color`` and ``site_id`` columns are read.
    :param ax: axes to draw on.
    :param TARGETS: brain regions drawn in their own ``color``. Other sites
        are dark grey when the region is "unknown" or "OUT" and ``blue_grey``
        otherwise. Pass ``None`` to draw every site in its own colour.
    :param annotate_every: label every n-th site with "site_id - region";
        a falsy value disables labels.
    :param x_shift: choose between two staggered 4-position x layouts
        around 1.025 (both shifted by ``x_pos_delta``).
    :param s: marker size.
    :param lw: marker edge width.
    :param x_pos_delta: horizontal offset applied to all sites.
    """
    n_sites = len(rsites)
    x = np.ones(n_sites) * 1.025 + x_pos_delta
    if x_shift:
        x[::2] = 1.025 + x_pos_delta - 0.05
        x[2::4] = 1.025 + x_pos_delta - 0.025
        x[1::4] = 1.025 + x_pos_delta + 0.025
    else:
        # builtin int(): the np.int alias was removed in NumPy 1.24.
        n_repeats = int(np.ceil(n_sites / 4))
        x = x_pos_delta + np.tile([0.975, 1.025, 1.0, 1.05],
                                  n_repeats)[:n_sites]

    if TARGETS is not None:
        colors = [
            rs.color if rs.brain_region in TARGETS else
            ([0.3, 0.3, 0.3] if rs.brain_region in ("unknown",
                                                    "OUT") else blue_grey)
            for i, rs in rsites.iterrows()
        ]
    else:
        colors = [rs.color for i, rs in rsites.iterrows()]

    ax.scatter(
        x,
        rsites.probe_coordinates,
        s=s,
        lw=lw,
        ec=[0.3, 0.3, 0.3],
        marker="s",
        c=colors,
    )

    if annotate_every:
        for i in np.arange(len(x))[::annotate_every]:
            ax.annotate(
                f"{rsites.site_id.iloc[i]} - {rsites.brain_region.iloc[i]}",
                (0.6, rsites.probe_coordinates.iloc[i]),
                color=colors[i],
            )
    ax.set(xlim=[0.5, 1.25], ylabel="probe coordinates (um)")
Exemple #25
0
def plot_raster(
    spikes: np.ndarray,
    events: Union[np.ndarray, list],
    ax: plt.axis,
    window: int = 120,
    s=5,
    color=grey_darker,
    kde: bool = True,
    bw: int = 6,
    **kwargs,
):
    """
        Plots a raster plot of spikes aligned to timestamps events

        It assumes that all event and spike times are in frames and framerate is 60

        Row ``n`` shows every spike within ``window / 2`` (inclusive) of
        ``events[n]``, shifted so that the event is at 0; a dotted vertical
        line marks 0.

        :param s: marker size of each spike dot.
        :param color: spike marker colour.
        :param kde: the KDE overlay is not implemented. When True (the
            default) the raster is drawn and then ``NotImplementedError`` is
            raised.
        :param bw: intended KDE bandwidth; currently unused.
        :param kwargs: forwarded to ``ax.scatter``.
    """
    half_window = window / 2
    yticks_step = int(np.ceil(len(events) / 8)) if len(events) > 8 else 2
    for n, event in enumerate(events):
        in_window = ((spikes >= event - half_window)
                     & (spikes <= event + half_window))
        event_spikes = spikes[in_window] - event
        y = np.ones_like(event_spikes) * n
        ax.scatter(event_spikes, y, s=s, color=color, **kwargs)
    ax.axvline(0, ls=":", color="k", lw=0.75)

    # plot KDE
    if kde:
        # TODO: restore the KDE of all aligned spike times (bandwidth ``bw``)
        # once the plotting environment supports it.
        raise NotImplementedError("KDE env setup incorrect")

    # set x axis properties
    ax.set(
        yticks=np.arange(0, len(events), yticks_step),
        ylabel="event number",
        **get_window_ticks(window),
    )
Exemple #26
0
def draw_model_prediction(fig: plt.figure,
                          ax: plt.axis,
                          model: nn.Module,
                          gw: int = 0,
                          min_point: (float, float) = (0, 0),
                          max_point: (float, float) = (500, 500),
                          resolution: float = 1,
                          apply_ploss: bool = True,
                          colorbar: bool = True,
                          detection_flag: bool = False,
                          detection_threshold: float = -140,
                          device='cpu',
                          *args,
                          **kwargs):
    """Evaluate ``model`` on a regular 2-D grid and show the result as an image.

    The grid covers ``[min_point, max_point)`` with step ``resolution`` along
    each axis (``np.mgrid`` semantics). ``model(points, apply_ploss)`` is
    called once on all grid points, given as an ``(N, 2)`` float32 tensor on
    ``device``; output ``[:, gw, 0]`` is displayed. With ``detection_flag``
    (exactly ``False`` disables it), points whose ``[:, gw, 1]`` output is
    below 0.9 are set to ``detection_threshold``. All values are then clipped
    from below at ``detection_threshold`` and drawn with a gamma-2 power
    norm; optionally a vertical colorbar labelled 'dBm' is attached.

    :return: the image returned by ``ax.imshow``.
    """
    x_mesh, y_mesh = np.mgrid[min_point[0]:max_point[0]:resolution,
                              min_point[1]:max_point[1]:resolution]
    grid_points = np.stack([x_mesh.ravel(), y_mesh.ravel()],
                           axis=1).astype(np.float32)

    # Only values are needed: skip building an autograd graph for the grid.
    with torch.no_grad():
        output = model(torch.from_numpy(grid_points).to(device), apply_ploss)

    if detection_flag is False:
        z = output.detach()[:, gw, 0].cpu().numpy()
    else:
        scores = output.detach()[:, gw, :].cpu().numpy()
        z = scores[:, 0].copy()
        z[scores[:, 1] < 0.9] = detection_threshold

    # Use the mesh's own shape so any ``resolution`` (not only 1) reshapes
    # correctly back onto the grid.
    z = z.reshape(x_mesh.shape).transpose()
    z[z < detection_threshold] = detection_threshold
    im = ax.imshow(z, norm=colors.PowerNorm(gamma=2))
    if colorbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical', label='dBm')
    return im
def set_inset_spectrum(axis: plt.axis, data: np.ndarray, current_index: int,
                       peak_collection: PeakCollection) -> plt.axis:
    """Draw ``data`` with its relevant peaks, zoomed around ``current_index``.

    A red vertical line marks ``current_index`` and the x range is limited to
    ``current_index`` ± 50 samples, clamped to ``[0, len(data) - 1]``.

    :param axis: axes to draw on.
    :param data: intensity spectrum to plot.
    :param current_index: sample index to centre on and mark.
    :param peak_collection: peaks passed to ``plot_peaks``.
    :return: the axis returned by ``plot_peaks``, with x limits applied.
    """
    axis.plot(data)  # Show intensity spectrum
    axis = plot_peaks(axis=axis, collection=peak_collection)  # Show peaks
    margin = 50
    x_lim = [
        max(0, current_index - margin),
        min(len(data) - 1, current_index + margin),
    ]
    # TODO: optionally shade peak clusters inside x_lim (an earlier draft used
    # peak_collection.get_clusters with axvspan and mode-id text labels).
    axis.axvline(x=current_index, color='r')
    axis.set_xlim(x_lim)
    return axis
def add_descriptions_to_plot(
    ax: plt.axis,
    experiment: Union[str, None] = None,
    luminosity: Union[str, None] = None,
    additional_info: Union[str, None] = None,
):
    """Add the standard captions to ``ax``.

    * ``experiment``: bold, size-16 title on the left.
    * ``luminosity``: plain title on the right.
    * ``additional_info``: bold text anchored at the top-left corner inside
      the axes, at axes fraction (0.02, 0.98), nudged 4 points right and 4
      points down.

    Every argument is passed on even when it is None.
    """
    bold_title_font = {
        'size': 16,
        'style': 'normal',
        'weight': 'bold',
    }
    ax.set_title(experiment, loc="left", fontdict=bold_title_font)
    ax.set_title(luminosity, loc="right")

    top_left = (0.02, 0.98)
    ax.annotate(
        additional_info,
        top_left,
        xytext=(4, -4),
        xycoords='axes fraction',
        textcoords='offset points',
        fontweight='bold',
        ha='left',
        va='top',
    )
Exemple #29
0
    def mpl_plot(self,
                 ax: plt.axis = None,
                 temp_unit: str = "GK",
                 **kwargs) -> plt.axis:
        """Set the common title and axis labels of a reaction-rate plot.

        Parameters
        ----------
        ax : mpl.Axis
            mpl Axis to plot onto; if none is provided the current axis is used.
        temp_unit : str {"GK", "KeV"}
            Temperature units for the x axis. Only "GK" currently sets an
            x label.
        kwargs : key word args for mpl.plot
            Accepted for signature compatibility; not used here.

        Returns
        -------
        ax : mpl.Axis
            The axis that was labelled.
        """
        ax = ax or plt.gca()
        ax.set_title("Reaction Rate")
        ax.set_ylabel(r"Rate ($cm^3\;mol^{-1}\;sec^{-1}$)")

        # ``==`` rather than ``is``: identity of equal strings depends on
        # interning, and ``is`` with a literal is a SyntaxWarning on 3.8+.
        if temp_unit == "GK":
            ax.set_xlabel("Temperature ($GK$)")

        return ax
Exemple #30
0
def plot_aligned(
    x: np.ndarray,
    indices: Union[np.ndarray, list],
    ax: plt.axis,
    mode: str,
    window: int = 120,
    mean_kwargs: dict = None,
    **kwargs,
):
    """
        Given a 1d array and a series of indices it plots the values of
        the array aligned to the timestamps.

        Each trace spans ``int(window / 2)`` samples before and after an
        index; sample ``idx`` sits at x = ``int(window / 2)``, which is also
        marked by a vertical line. ``mode == "pre"`` highlights the part
        before the index (salmon, thicker, shaded span); any other mode
        highlights the part after it. ``color`` / ``lw`` in ``kwargs`` override
        the highlight style, and the remaining ``kwargs`` go to every trace.
        Indices too close to the array edges are skipped with a warning. The
        mean and std of the kept traces are drawn with
        ``plot_mean_and_error`` using ``mean_kwargs``.
    """
    pre, aft = int(window / 2), int(window / 2)
    mean_kwargs = mean_kwargs or dict(lw=4, zorder=100, color=pink_dark)

    # get plotting params
    if mode == "pre":
        pre_c, pre_lw = kwargs.pop("color", "salmon"), kwargs.pop("lw", 2)
        aft_c, aft_lw = blue_grey, 1
        ax.axvspan(0, pre, fc=blue, alpha=0.25, zorder=-20)
    else:
        aft_c, aft_lw = kwargs.pop("color", "salmon"), kwargs.pop("lw", 2)
        pre_c, pre_lw = blue_grey, 1
        ax.axvspan(aft, window, fc=blue, alpha=0.25, zorder=-20)

    # The "after" segment starts one sample early (x[idx - 1], drawn at
    # pre - 1) so it joins the "before" segment; x[idx] then lands at
    # x = pre, matching the mean trace and the vertical marker.
    aft_positions = np.arange(pre - 1, pre + aft)

    # plot each trace
    X = []  # collect data to plot mean
    for idx in indices:
        x_pre = x[idx - pre:idx]
        x_aft = x[idx - 1:idx + aft]

        if len(x_pre) != pre or len(x_aft) != aft + 1:
            logger.warning(f"Could not plot data aligned to index: {idx}")
            continue
        X.append(x[idx - pre:idx + aft])

        ax.plot(x_pre, color=pre_c, lw=pre_lw, **kwargs)
        ax.plot(aft_positions,
                x_aft,
                color=aft_c,
                lw=aft_lw,
                **kwargs)

    # plot mean and line
    if X:
        X = np.vstack(X)
        plot_mean_and_error(np.mean(X, axis=0), np.std(X, axis=0), ax,
                            **mean_kwargs)
    else:
        # np.vstack([]) would raise; there is simply nothing to average.
        logger.warning("No trace could be aligned; skipping the mean trace")
    ax.axvline(pre, lw=2, color=blue_grey_dark, zorder=-1)

    ax.set(**get_window_ticks(window, shifted=False))