Exemplo n.º 1
0
def plot_avg_firing_rate_based_on_movement(tracking: pd.DataFrame,
                                           unit: pd.Series, ax: plt.axis):
    """
        Plot a unit's average firing rate split by behavioural state.

        For each state column of `tracking` ("moving", "walking",
        "turning_left", "turning_right", "tone_on") two points are drawn
        side by side:
            - left  (x = n - 0.25): mean of unit.firing_rate where column == 0
            - right (x = n + 0.25): mean of unit.firing_rate where column == 1
        Error bars come from `sem` (presumably the standard error of the
        mean, e.g. scipy.stats.sem - confirm against the file's imports).
        Colors are taken from `colors.movements[state]`.
    """
    states = ("moving", "walking", "turning_left", "turning_right", "tone_on")
    firing = unit.firing_rate

    for position, state in enumerate(states):
        # frames where the state flag is off / on
        off_rates = firing[tracking[state] == 0]
        on_rates = firing[tracking[state] == 1]

        plot_balls_errors(
            [position - 0.25, position + 0.25],
            [off_rates.mean(), on_rates.mean()],
            [sem(off_rates), sem(on_rates)],
            ax,
            s=150,
            colors=colors.movements[state],
        )

    ax.set(xticks=np.arange(len(states)),
           xticklabels=states,
           ylabel="avg firing rate")
Exemplo n.º 2
0
def plot_speeds(
    body_speed,
    start: int = 0,
    end: int = -1,
    is_moving: np.ndarray = None,
    ax: plt.axis = None,
    show: bool = True,
    **other_speeds,
):
    """
        Just plot some speeds for debugging stuff.

        Arguments:
            body_speed: speed trace drawn as "body" (salmon).
            start, end: slice bounds applied to every trace
                (trace[start:end]). Note that the default end=-1 follows
                normal slicing rules and therefore excludes the last sample.
            is_moving: optional trace drawn as a thick black line. It is
                sliced with the same [start:end] as the speeds so the two
                line up; it is presumably a movement mask with the same
                length as body_speed - confirm with callers.
            ax: axes to draw on; a new 14x8 figure is created when None.
            show: if True, call plt.show() at the end.
            **other_speeds: extra traces, each labelled by its keyword name.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(14, 8))

    frames = slice(start, end)
    ax.plot(body_speed[frames], label="body", color="salmon")
    for name, speed in other_speeds.items():
        ax.plot(speed[frames], label=name)

    if is_moving is not None:
        # Slice like the speed traces; plotting the full array would offset it
        # by `start` samples relative to them.
        ax.plot(is_moving[frames], lw=4, color="k")

    ax.legend()
    ax.set(xlabel="time (frames)", ylabel="speed (cm/s)")
    if show:
        plt.show()
Exemplo n.º 3
0
def plot_aligned(
    x: np.ndarray,
    indices: Union[np.ndarray, list],
    ax: plt.axis,
    mode: str,
    window: int = 120,
    mean_kwargs: dict = None,
    **kwargs,
):
    """
        Given a 1d array and a series of indices it plots the values of
        the array aligned to the timestamps, plus their mean and std.

        For each idx, the `pre` samples before it and the `aft` samples
        from it onwards are drawn as two segments (pre = aft = window / 2).
        The event sample x[idx] sits at x-position `pre`, marked by a
        vertical line, the same position it has in the mean trace.

        Arguments:
            x: 1d array to slice.
            indices: alignment indices into x. Indices too close to either
                end of x to yield a full window are skipped with a warning.
            ax: axes to draw on.
            mode: "pre" emphasises (and shades) the pre-event segment; any
                other value emphasises the post-event segment.
            window: total number of samples per aligned snippet.
            mean_kwargs: plot kwargs for the mean trace
                (default: lw=4, zorder=100, color=pink_dark).
            **kwargs: forwarded to ax.plot for every snippet. "color" and
                "lw", if present, style only the emphasised segment.
    """
    pre, aft = int(window / 2), int(window / 2)
    mean_kwargs = mean_kwargs or dict(lw=4, zorder=100, color=pink_dark)

    # get plotting params: the emphasised half takes the user color/lw
    if mode == "pre":
        pre_c, pre_lw = kwargs.pop("color", "salmon"), kwargs.pop("lw", 2)
        aft_c, aft_lw = blue_grey, 1
        ax.axvspan(0, pre, fc=blue, alpha=0.25, zorder=-20)
    else:
        aft_c, aft_lw = kwargs.pop("color", "salmon"), kwargs.pop("lw", 2)
        pre_c, pre_lw = blue_grey, 1
        ax.axvspan(aft, window, fc=blue, alpha=0.25, zorder=-20)

    # x-positions of the post segment: it starts one sample early (at x[idx - 1],
    # drawn at pre - 1 just like the last pre sample) so the two segments join
    # and x[idx] lands exactly on the event line at `pre`.
    aft_positions = np.arange(aft + 1) + pre - 1

    # plot each trace
    X = []  # collect data to plot mean
    for idx in indices:
        x_pre = x[idx - pre:idx]
        x_aft = x[idx - 1:idx + aft]

        # short slices mean idx is too close to an edge of x (this also
        # catches negative start indices, which wrap to empty slices)
        if len(x_pre) != pre or len(x_aft) != aft + 1:
            logger.warning(f"Could not plot data aligned to index: {idx}")
            continue
        X.append(x[idx - pre:idx + aft])

        ax.plot(x_pre, color=pre_c, lw=pre_lw, **kwargs)
        ax.plot(aft_positions, x_aft, color=aft_c, lw=aft_lw, **kwargs)

    # plot mean (np.vstack would raise on an empty list) and event line
    if X:
        X = np.vstack(X)
        plot_mean_and_error(np.mean(X, axis=0), np.std(X, axis=0), ax,
                            **mean_kwargs)
    else:
        logger.warning("No index yielded a full window: skipping mean trace")
    ax.axvline(pre, lw=2, color=blue_grey_dark, zorder=-1)

    ax.set(**get_window_ticks(window, shifted=False))
Exemplo n.º 4
0
def plot_probe_electrodes(
    rsites: pd.DataFrame,
    ax: plt.axis,
    TARGETS: Union[list, tuple, None] = (),
    annotate_every: Union[int, bool] = 5,
    x_shift: bool = True,
    s: int = 25,
    lw: float = 0.25,
    x_pos_delta: float = 0,
):
    """
        Scatter a probe's recording sites as squares at their probe coordinates.

        Arguments:
            rsites: one row per site. Reads the columns probe_coordinates,
                color, brain_region and site_id.
            ax: axes to draw on.
            TARGETS: brain regions whose sites keep their own `color`. Other
                sites are drawn dark grey if their region is "unknown" or
                "OUT", and blue_grey otherwise. With None, every site uses
                its own `color`. The default (empty) highlights nothing.
            annotate_every: label every n-th site with
                "<site_id> - <brain_region>". Falsy disables labels, and
                True is treated as 1.
            x_shift: if True, use a staggered period-4 column layout around
                x = 1.025. Otherwise use the tiled offsets
                [0.975, 1.025, 1.0, 1.05].
            s, lw: marker size and edge width.
            x_pos_delta: horizontal offset added to all site positions.
    """
    x = np.ones(len(rsites)) * 1.025 + x_pos_delta
    if x_shift:
        # later assignments override earlier ones, giving a period-4 pattern
        x[::2] = 1.025 + x_pos_delta - 0.05
        x[2::4] = 1.025 + x_pos_delta - 0.025
        x[1::4] = 1.025 + x_pos_delta + 0.025
    else:
        # np.int was removed in NumPy 1.24; the builtin int is equivalent here
        n_repeats = int(np.ceil(len(rsites) / 4))
        x = (x_pos_delta +
             np.tile([0.975, 1.025, 1.0, 1.05], n_repeats)[:len(rsites)])

    if TARGETS is not None:
        colors = [
            rs.color if rs.brain_region in TARGETS else
            ([0.3, 0.3, 0.3] if rs.brain_region in ("unknown",
                                                    "OUT") else blue_grey)
            for i, rs in rsites.iterrows()
        ]
    else:
        colors = [rs.color for i, rs in rsites.iterrows()]

    ax.scatter(
        x,
        rsites.probe_coordinates,
        s=s,
        lw=lw,
        ec=[0.3, 0.3, 0.3],
        marker="s",
        c=colors,
    )

    if annotate_every:
        # int() makes the bool case (True -> step 1) explicit
        for i in np.arange(len(x))[::int(annotate_every)]:
            ax.annotate(
                f"{rsites.site_id.iloc[i]} - {rsites.brain_region.iloc[i]}",
                (0.6, rsites.probe_coordinates.iloc[i]),
                color=colors[i],
            )
    ax.set(xlim=[0.5, 1.25], ylabel="probe coordinates (um)")
Exemplo n.º 5
0
def plot_raster(
    spikes: np.ndarray,
    events: Union[np.ndarray, list],
    ax: plt.axis,
    window: int = 120,
    s=5,
    color=grey_darker,
    kde: bool = True,
    bw: int = 6,
    **kwargs,
):
    """
        Plots a raster plot of spikes aligned to timestamps events.

        It assumes that all event and spike times are in frames and framerate is 60.
        Row n shows the spikes within +/- window / 2 of events[n], drawn at
        their offset from that event. A dotted line marks time 0.

        Arguments:
            spikes: array of spike times.
            events: event times to align to.
            ax: axes to draw on.
            window: total width of the alignment window (same units as spikes).
            s: marker size of the spike dots.
            color: marker color.
            kde: draw a KDE of the aligned spikes. Not currently implemented:
                passing True (note: the default) raises NotImplementedError.
            bw: KDE bandwidth; only used by the disabled KDE code.
            **kwargs: forwarded to ax.scatter.

        Raises:
            NotImplementedError: if kde is True.
    """
    if kde:
        # Fail before drawing anything so the axes are not left half-populated.
        # TODO: restore once the KDE environment works. The previous call,
        # with data = all aligned spike offsets across events, was:
        # plot_kde(
        #     ax=ax,
        #     z=-len(events) / 4,
        #     data=X,
        #     normto=len(events) / 5,
        #     color=blue_grey_dark,
        #     kde_kwargs=dict(bw=bw, cut=0),
        #     alpha=0.6,
        #     invert=False,
        # )
        raise NotImplementedError("KDE env setup incorrect")

    half_window = window / 2
    yticks_step = int(np.ceil(len(events) / 8)) if len(events) > 8 else 2
    for n, event in enumerate(events):
        in_window = ((spikes >= event - half_window)
                     & (spikes <= event + half_window))
        event_spikes = spikes[in_window] - event
        y = np.ones_like(event_spikes) * n
        # honour the `s` argument (it used to be hard-coded to 5)
        ax.scatter(event_spikes, y, s=s, color=color, **kwargs)
    ax.axvline(0, ls=":", color="k", lw=0.75)

    # set x axis properties
    ax.set(
        yticks=np.arange(0, len(events), yticks_step),
        ylabel="event number",
        **get_window_ticks(window),
    )