예제 #1
0
    def __init__(self,
                 ax: plt.Axes = None,
                 set_ax=False,
                 img_path: str = None,
                 **kwargs):
        """
            Renders an image of the hairpin maze on a matplotlib axis.

            Arguments:
                ax: axis to draw on. When falsy (e.g. None) the current axis
                    returned by `plt.gca()` is used.
                set_ax: if True, also set axis limits and ticks to the image
                    bounds (self.x_0, self.x_1, self.y_0, self.y_1) and label
                    both axes "cm".
                img_path: forwarded unchanged to `self.get_image`.
                **kwargs: extra keyword arguments forwarded to `ax.imshow`.
        """
        if not ax:
            ax = plt.gca()

        picture = self.get_image(img_path)

        # stretch the image over the bounds stored on the instance; the very
        # low zorder draws it beneath artists with a higher zorder
        ax.imshow(
            picture,
            extent=[self.x_0, self.x_1, self.y_0, self.y_1],
            origin="lower",
            zorder=-100,
            **kwargs,
        )

        if set_ax:
            x_bounds = [self.x_0, self.x_1]
            y_bounds = [self.y_0, self.y_1]
            ax.set(
                xlim=x_bounds,
                ylim=y_bounds,
                xlabel="cm",
                ylabel="cm",
                xticks=list(x_bounds),
                yticks=list(y_bounds),
            )

        clean_axes(ax.figure)
        ax.axis("equal")
def plot_n_units_per_channel(rname, units, rsites, TARGETS):
    """
        Plot the number of units on each channel from a single recording,
        highlighting channels in specific regions.

        Left axis: the probe electrodes (via `plot_probe_electrodes`).
        Right axis: one dot per channel that has units, at x = number of units
        (with a small random horizontal jitter) and y = the channel's probe
        coordinate, with a horizontal line from 0 to the count.

        Arguments:
            rname: recording name, used as the figure title.
            units: DataFrame of units with at least "site_id" and "name" columns.
            rsites: DataFrame of recording sites with at least "site_id",
                "color", "brain_region" and "probe_coordinates" columns.
            TARGETS: brain regions whose channels keep their own site color.
                Channels in "unknown"/"OUT" are drawn black, all others blue-grey.

        Returns:
            the matplotlib figure (its `_save_name` is set).
    """
    logger.info(f"Plotting n units per channel for {rname}")
    f, axes = plt.subplots(figsize=(12, 12), ncols=2, sharey=True)
    f.suptitle(rname)
    f._save_name = "activity_units_per_channel"

    # draw probe
    plot_probe_electrodes(rsites, axes[0], TARGETS)

    # number of units per channel (count of non-null "name" per site_id)
    counts = units.groupby("site_id").count()["name"]

    # first rsites row of every channel that has units: scan rsites once per
    # channel instead of once per looked-up column
    site_rows = [rsites.loc[rsites.site_id == n].iloc[0] for n in counts.index]
    probe_coords = [row["probe_coordinates"] for row in site_rows]

    colors = []
    for row in site_rows:
        region = row["brain_region"]
        if region in TARGETS:
            colors.append(row["color"])
        elif region in ("unknown", "OUT"):
            colors.append("k")
        else:
            colors.append(blue_grey)

    # jitter x slightly so channels with equal counts don't overlap exactly
    axes[1].scatter(
        counts.values + np.random.normal(0, 0.02, size=len(counts.values)),
        probe_coords,
        color=colors,
        s=100,
        lw=1,
        ec="k",
    )

    for x, y in zip(counts.values, probe_coords):
        axes[1].plot([0, x], [y, y], color=[0.2, 0.2, 0.2], lw=2, zorder=-1)

    # cleanup and save
    axes[0].set(
        ylabel="Probe position (um)",
        xticks=[],
        xlim=[0.5, 1.5],
        ylim=[0, 8000],
    )
    axes[1].set(xlabel="# units per channel", ylim=[0, 8000])

    clean_axes(f)
    return f
예제 #3
0
def _plot_event_aligned_activity(spikes, firing_rate, event_frames, raster_ax,
                                 rate_ax, mode, window):
    """Draw a spike raster on `raster_ax` and the event-aligned firing rate on `rate_ax`."""
    visuals.plot_raster(spikes, event_frames, raster_ax, window=window)
    visuals.plot_aligned(
        firing_rate,
        event_frames,
        rate_ax,
        mode,
        color=blue_grey,
        lw=1,
        alpha=0.85,
        window=window,
    )


def _plot_firing_rate_vs(data, x_key, ax, color):
    """Overlay firing rate binned by `x_key` on a faint 2D density of `x_key` vs firing rate."""
    visuals.plot_bin_x_by_y(
        data,
        "firing_rate",
        x_key,
        ax,
        colors=color,
        bins=10,
        min_count=10,
        s=50,
    )
    visuals.plot_heatmap_2d(
        data,
        x_key=x_key,
        y_key="firing_rate",
        ax=ax,
        vmax=None,
        zorder=-10,
        alpha=0.5,
        cmap="inferno",
        linewidths=0,
        gridsize=20,
    )


def plot_unit(
    mouse_id: str,
    session_name: str,
    tracking: pd.DataFrame,
    bouts: pd.DataFrame,
    out_bouts: pd.DataFrame,
    in_bouts: pd.DataFrame,
    bouts_stacked: pd.DataFrame,
    units: pd.DataFrame,
    unit_id: Union[int, str],
    tone_onsets: np.ndarray,
    WINDOW: int,
):
    """
        Plot an activity summary figure for one unit (or for every unit).

        For each selected unit a figure with a 6x7 subplot mosaic is created.
        It shows:
            - spike locations and a firing-rate heatmap in the arena;
            - spikes over the speed and angular-velocity traces;
            - rasters and aligned firing rate around tone onsets and bout
              onsets/offsets;
            - firing rate binned by speed, angular velocity and global
              coordinate (all data, in-bouts and out-bouts);
            - the probe electrodes the unit was recorded on;
            - a speed x angular-velocity firing-rate heatmap restricted to
              walking frames.

        Arguments:
            mouse_id, session_name: identify the session (used to query the
                unit's probe sites and in the figure title).
            tracking: tracking DataFrame. NOTE: a "firing_rate" column is
                written into it (overwritten for each unit), so the caller's
                DataFrame is modified.
            bouts: bouts with "start_frame"/"end_frame" columns.
            out_bouts, in_bouts: outbound / inbound bouts.
            bouts_stacked: unused; kept for interface compatibility.
            units: units table; rows need unit_id, brain_region, spikes,
                firing_rate and probe_configuration.
            unit_id: an integer (Python or numpy) unit id plots only that
                unit; any non-integer (e.g. a string) plots every unit.
            tone_onsets: frames of the tone onsets.
            WINDOW: window forwarded to the raster / aligned plots.

        Raises:
            ValueError: if an integer `unit_id` is not in `units.unit_id`.
    """
    # x-axis ticks every 60 * 60 * 5 samples, labelled in minutes
    # (NOTE(review): assumes 60 samples per second — confirm)
    frames = np.arange(0, len(tracking.x), 60 * 60 * 5)
    time = (frames / 60 / 60).astype(np.int32)

    # numpy integers (e.g. taken from units.unit_id.values) also select a single unit
    single_unit = isinstance(unit_id, (int, np.integer))
    if single_unit and unit_id not in units.unit_id.values:
        raise ValueError(f"Unit id {unit_id} not in units list:\n{units}")

    for i, unit in units.iterrows():
        if single_unit and unit.unit_id != unit_id:
            continue
        logger.info(
            f'Showing activity summary for unit {i+1}/{len(units)} (id: {unit.unit_id} - in: "{unit.brain_region}")'
        )

        # store this unit's firing rate in `tracking` so the helpers below can
        # refer to it by column name, then get the tracking at each spike
        tracking["firing_rate"] = unit.firing_rate
        unit_tracking = data_utils.select_by_indices(tracking, unit.spikes)
        unit_tracking["spikes"] = unit.spikes

        # shared color-scale ceiling for the firing-rate heatmaps
        unit_vmax_frate = np.percentile(unit.firing_rate, 98)

        # recomputed per unit on purpose: they are built from `tracking`,
        # which now holds this unit's firing rate
        out_bouts_stacked = data_utils.get_bouts_tracking_stacked(
            tracking, out_bouts)
        in_bouts_stacked = data_utils.get_bouts_tracking_stacked(
            tracking, in_bouts)

        # create figure
        f = plt.figure(figsize=(24, 12))
        axes = f.subplot_mosaic("""
            ARBCDUV
            ARBCEUV
            FSGHIXX
            FSGHLXX
            MTNOPZY
            MTNOQZY
        """)
        f.suptitle(f"{session_name} unit {unit.unit_id} {unit.brain_region}")
        # backslashes in the region name are replaced (presumably because the
        # save name becomes a file name — TODO confirm)
        f._save_name = f"unit_{unit.unit_id}_{unit.brain_region}".replace(
            "\\", "_")

        # plot spikes against tracking, speed and angular velocity
        visuals.plot_heatmap_2d(unit_tracking,
                                "spikes",
                                axes["A"],
                                cmap="inferno",
                                vmax=None)

        axes["B"].plot(tracking.speed, color=blue_grey, lw=2)
        axes["B"].scatter(
            unit.spikes,
            unit_tracking.speed,
            color=colors.speed,
            s=5,
            zorder=11,
        )

        axes["C"].plot(tracking.dmov_velocity, color=blue_grey, lw=2)
        axes["C"].scatter(
            unit.spikes,
            unit_tracking.dmov_velocity,
            color=colors.dmov_velocity,
            s=5,
            zorder=11,
        )

        # plot firing rate heatmap
        visuals.plot_heatmap_2d(
            unit_tracking,
            "firing_rate",
            axes["R"],
            cmap="inferno",
            vmax=unit_vmax_frate,
        )

        # spike rasters + aligned firing rate around tone onsets, bout
        # onsets and bout offsets
        _plot_event_aligned_activity(unit.spikes, tracking.firing_rate,
                                     tone_onsets, axes["D"], axes["E"], "aft",
                                     WINDOW)
        _plot_event_aligned_activity(unit.spikes, tracking.firing_rate,
                                     bouts.start_frame, axes["I"], axes["L"],
                                     "aft", WINDOW)
        _plot_event_aligned_activity(unit.spikes, tracking.firing_rate,
                                     bouts.end_frame, axes["P"], axes["Q"],
                                     "pre", WINDOW)

        # firing rate binned by speed and angular velocity (all data)
        _plot_firing_rate_vs(tracking, "speed", axes["U"], colors.speed)
        _plot_firing_rate_vs(tracking, "dmov_velocity", axes["V"],
                             colors.dmov_velocity)
        # zero angular-velocity reference on the angular-velocity panel
        axes["V"].axvline(0, ls=":", lw=2, color=[0.2, 0.2, 0.2], zorder=101)

        # plot probe electrodes in which there is the unit
        visuals.plot_probe_electrodes(
            db_tables.Unit.get_unit_sites(mouse_id, session_name,
                                          unit["unit_id"],
                                          unit['probe_configuration']),
            axes["Z"],
            annotate_every=1,
            TARGETS=None,
            x_shift=False,
            s=100,
            lw=2,
        )

        # plot firing rate based on movements
        visuals.plot_avg_firing_rate_based_on_movement(tracking, unit,
                                                       axes["X"])

        # heatmap of firing rate over speed x angular velocity (walking frames only)
        walking = tracking.walking == 1
        trk = dict(
            speed=tracking.speed[walking],
            dmov_velocity=tracking.dmov_velocity[walking],
            firing_rate=tracking.firing_rate[walking],
        )
        visuals.plot_heatmap_2d(
            trk,
            key="firing_rate",
            ax=axes["Y"],
            x_key="speed",
            y_key="dmov_velocity",
            vmax=None,
        )

        # --------------------------------- in bouts --------------------------------- #
        visuals.plot_bouts_heatmap_2d(tracking,
                                      in_bouts,
                                      "firing_rate",
                                      axes["F"],
                                      vmax=unit_vmax_frate)
        _plot_firing_rate_vs(in_bouts_stacked, "global_coord", axes["S"],
                             colors.global_coord)
        _plot_firing_rate_vs(in_bouts_stacked, "speed", axes["G"],
                             colors.speed)
        _plot_firing_rate_vs(in_bouts_stacked, "dmov_velocity", axes["H"],
                             colors.dmov_velocity)
        axes["H"].axvline(0, ls=":", lw=2, color=[0.2, 0.2, 0.2], zorder=101)

        # --------------------------------- out bouts -------------------------------- #
        visuals.plot_bouts_heatmap_2d(tracking,
                                      out_bouts,
                                      "firing_rate",
                                      axes["M"],
                                      vmax=unit_vmax_frate)
        _plot_firing_rate_vs(out_bouts_stacked, "global_coord", axes["T"],
                             colors.global_coord)
        _plot_firing_rate_vs(out_bouts_stacked, "speed", axes["N"],
                             colors.speed)
        _plot_firing_rate_vs(out_bouts_stacked, "dmov_velocity", axes["O"],
                             colors.dmov_velocity)
        axes["O"].axvline(0, ls=":", lw=2, color=[0.2, 0.2, 0.2], zorder=101)

        # ----------------------------- cleanup and save ----------------------------- #
        clean_axes(f)
        set_figure_subplots_aspect(wspace=0.5, hspace=0.6, left=0.3)
        move_figure(f, 50, 50)
        f.tight_layout()
        axes["A"].set(
            xlabel="xpos (cm)",
            ylabel="ypos (cm)",
            title=
            f"unit {unit.unit_id} {unit.brain_region} | {len(unit.spikes)} spikes",
        )
        axes["B"].set(
            xticks=frames,
            xticklabels=time,
            xlabel="time (min)",
            ylabel="speed (cm/s)",
        )
        axes["C"].set(
            xticks=frames,
            xticklabels=time,
            xlabel="time (min)",
            ylabel="ang vel (deg/s)",
        )
        axes["D"].set(title="tone onset")
        axes["F"].set(title="firing rate", ylabel="IN BOUTS")
        axes["G"].set(ylabel="firing rate", xlabel="speed")
        axes["H"].set(ylabel="firing rate",
                      xlabel="angular velocity",
                      xlim=[-350, 350])
        axes["I"].set(title="bout ONSET")
        axes["M"].set(title="firing rate", ylabel="OUT BOUTS")
        axes["N"].set(ylabel="firing rate", xlabel="speed")
        axes["O"].set(ylabel="firing rate",
                      xlabel="angular velocity",
                      xlim=[-350, 350])
        axes["P"].set(title="bouts OFFSET")
        axes["R"].set(title="firing rate")
        axes["S"].set(ylabel="firing rate", xlabel="global coord")
        axes["T"].set(ylabel="firing rate", xlabel="global coord")
        axes["U"].set(ylabel="firing rate", xlabel="speed")
        axes["V"].set(ylabel="firing rate",
                      xlabel="angular velocity",
                      xlim=[-350, 350])
        axes["Z"].set(xticks=[])

        axes["Y"].set(
            title="bouts firing rate heatmap",
            xlabel="speed (cm/s)",
            ylabel="ang vel (deg/s)",
        )

        # arena panels share the same equal-aspect framing
        for ax in "ARFM":
            axes[ax].axis("equal")
            axes[ax].set(xlim=[-5, 45],
                         xticks=[0, 40],
                         ylim=[-5, 65],
                         yticks=[0, 60])
예제 #4
0
def plot_hairpin_tracking(
    session_name: str,
    tracking: pd.DataFrame,
    downsampled_tracking: pd.DataFrame,
    bouts: pd.DataFrame,
    out_bouts: pd.DataFrame,
    in_bouts: pd.DataFrame,
    bouts_stacked: pd.DataFrame,
    out_bouts_stacked: pd.DataFrame,
    in_bouts_stacked: pd.DataFrame,
):
    """
        Build a summary figure of a session's tracking data.

        Panels (subplot-mosaic letters in brackets):
            - 2D [A] and linearized [B] tracking with the bouts drawn on top;
            - speed aligned to bout onset [C] and offset [D];
            - histograms of in/out bout durations [E];
            - distributions of speed [X] and angular velocity [Y] with the
              movement/turning thresholds;
            - speed x angular-velocity heatmaps for moving frames [R] and for
              bouts [W];
            - speed / orientation / angular-velocity maps for in-bouts
              [F, G, P] and out-bouts [L, M, Q];
            - speed and angular velocity binned by global coordinate
              [H, I, N, O], plus their histograms [S, T, U, V].

        Nothing is returned; the figure's `_save_name` is "tracking_data_2d".
    """
    fig = plt.figure(figsize=(24, 12))
    axes = fig.subplot_mosaic("""
        ABCDERX
        ABCDERX
        FGPHHSY
        FGPIITY
        LMQNNUW
        LMQOOVW
    """)
    fig.suptitle(session_name)
    fig._save_name = "tracking_data_2d"

    # ---- whole-session tracking (2D and linearized) with bouts on top ----
    visuals.plot_tracking_xy(
        downsampled_tracking,
        ax=axes["A"],
        plot=True,
        color=[0.6, 0.6, 0.6],
        alpha=0.5,
    )
    visuals.plot_bouts_2d(tracking, bouts, axes["A"], lw=2, zorder=100)
    visuals.plot_tracking_linearized(
        downsampled_tracking,
        ax=axes["B"],
        plot=True,
        color=[0.6, 0.6, 0.6],
    )
    visuals.plot_bouts_1d(tracking, bouts, axes["B"], lw=3, zorder=100)

    # ---- speed around bout starts / ends ----
    for frame_column, panel, mode in (
        ("start_frame", "C", "after"),
        ("end_frame", "D", "pre"),
    ):
        visuals.plot_aligned(tracking.speed, bouts[frame_column], axes[panel],
                             mode, alpha=0.5)

    # ---- bout durations ----
    duration_style = dict(
        bins=15,
        alpha=0.7,
        ec=[0.2, 0.2, 0.2],
        histtype="stepfilled",
        lw=2,
    )
    for bout_table, bout_color, bout_label in (
        (out_bouts, colors.outbound, "out"),
        (in_bouts, colors.inbound, "in"),
    ):
        axes["E"].hist(bout_table.duration,
                       color=bout_color,
                       label=bout_label,
                       **duration_style)

    # ---- speed / angular velocity distributions with thresholds ----
    threshold_style = dict(lw=2, ls="--", color=grey_dark, zorder=100)
    axes["X"].hist(tracking.speed, bins=50, color=colors.speed, density=True)
    axes["X"].axvline(db_tables.Movement.moving_threshold, **threshold_style)
    axes["Y"].hist(
        tracking.angular_velocity,
        bins=50,
        color=colors.angular_velocity,
        density=True,
    )
    turning = db_tables.Movement.turning_threshold
    for threshold in (turning, -turning):
        axes["Y"].axvline(threshold, **threshold_style)

    # ---- speed vs angular velocity: moving frames and bouts only ----
    is_moving = tracking.moving == 1
    moving_tracking = dict(
        speed=tracking.speed[is_moving],
        angular_velocity=tracking.angular_velocity[is_moving],
    )
    for data, panel in ((moving_tracking, "R"), (bouts_stacked, "W")):
        visuals.plot_heatmap_2d(
            data,
            key=None,
            ax=axes[panel],
            x_key="speed",
            y_key="angular_velocity",
            vmax=None,
        )

    # ---- speed, orientation and angular velocity maps during bouts ----
    speed_style = dict(alpha=1, vmax=30, cmap="inferno")
    orientation_style = dict(alpha=1, vmin=0, vmax=360, edgecolors=grey_darker)
    avel_style = dict(alpha=1, vmin=-45, vmax=45, edgecolors=grey_darker)
    bout_maps = (
        (in_bouts_stacked, "speed", "F", speed_style),
        (out_bouts_stacked, "speed", "L", speed_style),
        (in_bouts_stacked, "orientation", "G", orientation_style),
        (out_bouts_stacked, "orientation", "M", orientation_style),
        (in_bouts_stacked, "angular_velocity", "P", avel_style),
        (out_bouts_stacked, "angular_velocity", "Q", avel_style),
    )
    for data, key, panel, style in bout_maps:
        visuals.plot_heatmap_2d(data, key, ax=axes[panel], **style)

    # ---- speed / angular velocity binned by global coordinate ----
    nbins = 25
    speed_palette = make_palette(grey_light, grey_dark, nbins - 1)
    avel_palette = make_palette(amber_light, amber_darker, nbins - 1)
    for data, y_key, panel, palette in (
        (in_bouts_stacked, "speed", "H", speed_palette),
        (in_bouts_stacked, "angular_velocity", "I", avel_palette),
        (out_bouts_stacked, "speed", "N", speed_palette),
        (out_bouts_stacked, "angular_velocity", "O", avel_palette),
    ):
        visuals.plot_bin_x_by_y(
            data,
            y_key,
            "global_coord",
            axes[panel],
            bins=np.linspace(0, 1, nbins),
            colors=palette,
        )

    # ---- speed / angular velocity histograms for in and out bouts ----
    for data, speed_panel, avel_panel in (
        (in_bouts_stacked, "S", "T"),
        (out_bouts_stacked, "U", "V"),
    ):
        axes[speed_panel].hist(data.speed, bins=20, color=colors.speed)
        axes[avel_panel].hist(
            data.angular_velocity,
            bins=20,
            color=colors.angular_velocity,
        )

    # zero reference on the binned angular-velocity panels
    for panel in "IO":
        axes[panel].axhline(0, lw=1, color=[0.6, 0.6, 0.6], zorder=-1)

    # ---- cleanup and labels ----
    clean_axes(fig)
    move_figure(fig, 50, 50)
    fig.tight_layout()

    panel_labels = {
        "A": dict(
            xlabel="xpos (cm)",
            ylabel="ypos (cm)",
            title=f"{len(in_bouts)} IN - {len(out_bouts)} OUT",
        ),
        "B": dict(ylabel="time in exp", xlabel="arena position"),
        "C": dict(
            ylabel="speed (cm/s)",
            xticks=[0, 60, 120],
            xticklabels=[-1, 0, 1],
            xlabel="time (s)",
            title="bout onset",
        ),
        "D": dict(
            ylabel="speed (cm/s)",
            xticks=[0, 60, 120],
            xticklabels=[-1, 0, 1],
            xlabel="time (s)",
            title="bout offset",
        ),
        "E": dict(ylabel="counts", xlabel="duration (s)", title="Bouts duration"),
        "F": dict(ylabel="IN", xticks=[], yticks=[], title="speed"),
        "G": dict(xticks=[], yticks=[], title="orientation"),
        "H": dict(ylabel="speed (cm/s)", xticks=[]),
        "I": dict(ylabel="ang vel (deg/s)", xticks=[]),
        "L": dict(ylabel="OUT", xticks=[], yticks=[], xlabel="speed"),
        "M": dict(xticks=[], yticks=[], xlabel="orientation"),
        "N": dict(ylabel="speed (cm/s)", xticks=[]),
        "O": dict(ylabel="ang vel (deg/s)", xticks=np.linspace(0, 1, 11)),
        "P": dict(xticks=[], yticks=[], title="angular velocity"),
        "Q": dict(xticks=[], yticks=[], xlabel="ang vel (deg/s)"),
        "R": dict(xlabel="speed (cm/s)", ylabel="ang vel (deg/s)", title="moving"),
        "S": dict(xlabel="speed (cm/s)", title="in bouts"),
        "T": dict(xlabel="angular velocity (deg/s)"),
        "U": dict(xlabel="speed (cm/s)", title="out bouts"),
        "V": dict(xlabel="angular velocity (deg/s)"),
        "X": dict(xlabel="speed (cm/s)", title="all tracking"),
        "Y": dict(xlabel="angular velocity (deg/s)"),
        "W": dict(
            xlabel="speed (cm/s)",
            ylabel="ang vel (deg/s)",
            title="bouts heatmap",
        ),
    }
    for panel, props in panel_labels.items():
        axes[panel].set(**props)

    # arena-shaped panels share the same equal-aspect framing
    for panel in "AFGLMPQ":
        axes[panel].axis("equal")
        axes[panel].set(xlim=[-5, 45],
                        ylim=[-5, 65],
                        xticks=[0, 40],
                        yticks=[0, 60])
예제 #5
0
                               zorder=100,
                               alpha=0.5)

        # plot against speed
        axes_dict["C"].plot(time, speed, color=blue_grey)
        axes_dict["C"].scatter(
            time[spike_times],
            speed[spike_times],
            color=color,
            zorder=100,
            alpha=0.5,
        )

        # plot against speed
        axes_dict["D"].plot(time, avel, color=blue_grey)
        axes_dict["D"].scatter(
            time[spike_times],
            avel[spike_times],
            color=color,
            zorder=100,
            alpha=0.5,
        )

        # cleanup and save
        axes_dict["A"].set(xlabel="xpos (cm)", ylabel="ypos (cm)")
        axes_dict["C"].set(xlabel="time (frames)", ylabel="speed (cm/s)")
        clean_axes(f)

        plt.show()
        # break
예제 #6
0
def plot_unit_firing_rate(unit: pd.Series, end: int = 60, kernel_std=100):
    """
        For a single unit, plot its spike times and its firing rate trace(s).

        Top axis: density histogram of spike times plus a rug of the individual
        spikes (jittered just below zero). Bottom axis: the firing rate trace(s)
        fetched from the `FiringRate` table, one per kernel std (`firing_rate_std`).

        Arguments:
            unit: series describing the unit; needs "name", `unit_id` and `spikes`.
            end: how much data to plot. Spikes and firing rates are cut at
                `end * 60` samples; the x-axis labels divide sample indices
                by 60 (NOTE(review): assumes 60 samples per second — confirm).
            kernel_std: only plot the firing rate computed with this
                `firing_rate_std` value (default 100, the value previously
                hard-coded). Pass None to plot every available window.

        Raises:
            ValueError: if `unit` is an empty series.
    """
    if unit.empty:
        raise ValueError(
            "An empty pandas series was passed, perhaps the unit ID was invalid."
        )

    # get the unit data (presumably one row per firing-rate window)
    name = unit["name"]
    data = pd.DataFrame((Unit * Unit.Spikes * FiringRate
                         & f'name="{name}"'
                         & f"unit_id={unit.unit_id}").fetch())
    frate_windows = data.firing_rate_std.values
    n_frate_windows = len(frate_windows)
    logger.info(
        f"Found {n_frate_windows} firing rate windows to plot for unit {unit.unit_id}: {frate_windows}"
    )

    # create figure
    f, axes = plt.subplots(nrows=2, sharex=True, figsize=(16, 9))
    f.suptitle(f"Unit {unit.unit_id} firing rate")
    # the other plotting functions in this file name figures via `_save_name`;
    # `_save_title` is still set for any code that relied on it
    f._save_name = f"unit_{unit.unit_id}_firing_rate"
    f._save_title = f._save_name

    # plot spikes
    spikes = unit.spikes[unit.spikes < end * 60]
    axes[0].hist(spikes, bins=end * 10, color=[0.5, 0.5, 0.5], density=True)
    axes[0].scatter(
        spikes,
        np.random.uniform(-0.05, -0.001, size=len(spikes)),
        color=black,
        s=25,
        zorder=1,
        label="spike times",
    )

    # plot the firing rate for the selected kernel std(s); the palette spans
    # all windows so colors stay the same whichever subset is plotted
    palette = make_palette(orange_dark, blue_light, n_frate_windows)
    n_plotted = 0
    for frate, color in zip(frate_windows, palette):
        if kernel_std is not None and frate != kernel_std:
            continue
        frate_data = (
            data.loc[data.firing_rate_std == frate].iloc[0].firing_rate[:end *
                                                                        60])
        axes[1].plot(frate_data,
                     lw=2,
                     color=color,
                     label=f"kernel std: {frate} ms")
        n_plotted += 1

    # cleanup
    clean_axes(f)

    time = np.arange(0, (end + 1) * 60, 60 * 2)
    if n_plotted:
        axes[1].legend()
    else:
        # avoid an empty legend (matplotlib warns when there are no labels)
        logger.warning(
            f"No firing rate trace with kernel std {kernel_std} for unit {unit.unit_id}"
        )
    axes[1].set(
        xticks=time,
        xticklabels=(time / 60).astype(np.int32),
        xlabel="time (s)",
        ylabel="firing rate",
    )