Example #1
def main(dir):
    save_dir = os.path.join(dir, "plots", "phase")
    container = NDataContainer(load_on_fly=True)
    container.add_axona_files_from_dir(dir)
    container.setup()
    spike_names = container.get_file_dict()["Spike"]

    for i, ndata in enumerate(container):
        name = spike_names[container._index_to_data_pos(i)[0]]
        results = ndata.phase_at_spikes(should_filter=True)
        positions = results["positions"]
        phases = results["phases"]
        good_place = results["good_place"]
        directions = results["directions"]
        co_ords = {}
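        # Head-direction bins in degrees: north is [45, 135), south is [225, 315].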
        co_ords["north"] = np.nonzero((45 <= directions) & (directions < 135))
        co_ords["south"] = np.nonzero((225 <= directions)
                                      & (directions <= 315))
        if (phases.size != 0) and good_place:
            for direction in "north", "south":
                dim_pos = positions[1][co_ords[direction]]
                directional_phases = phases[co_ords[direction]]
                fig, ax = plt.subplots()
                ax.scatter(dim_pos, directional_phases)
                # ax.hist2d(dim_pos, directional_phases, bins=[10, 90])
                parts = os.path.basename(name[0]).split(".")
                end_name = (parts[0] + "_unit" + str(ndata.get_unit_no()) +
                            "_" + direction + ".png")
                out_name = os.path.join(save_dir, end_name)
                make_dir_if_not_exists(out_name)
                fig.savefig(out_name)
                plt.close(fig)
Example #2
def single_data_phase(dir, spike_name, pos_name, lfp_name):
    ndata = load_data(dir, spike_name, pos_name, lfp_name)
    spike_file = os.path.join(dir, spike_name)
    save_dir = os.path.join(dir, "plots", "phase")
    results = ndata.phase_at_spikes(should_filter=True)
    positions = results["positions"]
    phases = results["phases"]
    good_place = results["good_place"]
    directions = results["directions"]
    co_ords = {}
    co_ords["north"] = np.nonzero((45 <= directions) & (directions < 135))
    co_ords["south"] = np.nonzero((225 <= directions) & (directions <= 315))
    if (phases.size != 0) and good_place:
        for direction in "north", "south":
            dim_pos = positions[1][co_ords[direction]]
            directional_phases = phases[co_ords[direction]]
            fig, ax = plt.subplots()
            ax.scatter(dim_pos, directional_phases)
            # ax.hist2d(dim_pos, directional_phases, bins=[10, 90])
            # spike_file is a path string here, so take its basename directly.
            parts = os.path.basename(spike_file).split(".")
            end_name = (parts[0] + "_unit" + str(ndata.get_unit_no()) + "_" +
                        direction + ".png")
            out_name = os.path.join(save_dir, end_name)
            make_dir_if_not_exists(out_name)
            fig.savefig(out_name)
            plt.close(fig)
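
A hedged usage sketch for single_data_phase; the file names below are hypothetical Axona-style names (spike ".2" for tetrode 2, position ".txt", LFP ".eeg"), and load_data is a helper not shown in this snippet:

single_data_phase(
    r"/path/to/session", "recording1.2", "recording1_pos.txt", "recording1.eeg")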
Example #3
def main(dir, bin_size, bin_bound):
    container = NDataContainer(load_on_fly=True)
    container.add_axona_files_from_dir(dir, recursive=True)
    container.setup()
    print(container.string_repr(True))
    event = NEvent()
    last_stm_name = None
    dict_list = []

    for i in range(len(container)):
        try:
            result_dict = OrderedDict()
            data_index, _ = container._index_to_data_pos(i)
            stm_name = container.get_file_dict("STM")[data_index][0]
            data = container[i]

            # Add extra keys
            spike_file = data.spike.get_filename()
            spike_dir = os.path.dirname(spike_file)
            spike_name = os.path.basename(spike_file)
            spike_name_only, spike_ext = os.path.splitext(spike_name)
            spike_ext = spike_ext[1:]
            result_dict["Dir"] = spike_dir
            result_dict["Name"] = spike_name_only
            result_dict["Tet"] = int(spike_ext)
            result_dict["Unit"] = data.spike.get_unit_no()

            # Do analysis
            if last_stm_name != stm_name:
                event.load(stm_name, 'Axona')
                last_stm_name = stm_name

            graph_data = event.psth(data.spike, bins=bin_size, bound=bin_bound)
            spike_count = data.get_unit_spikes_count()
            result_dict["Num_Spikes"] = spike_count

            # Record the psth value for each bin, keyed by the left bin edge
            for (b, v) in zip(graph_data["all_bins"][:-1], graph_data["psth"]):
                result_dict[str(b)] = v
            dict_list.append(result_dict)

            # Do plotting
            name = (spike_name_only + "_" + spike_ext + "_" +
                    str(result_dict["Unit"]) + ".png")
            plot_name = os.path.join(dir, "psth_results", name)
            make_dir_if_not_exists(plot_name)
            plot_psth(graph_data, plot_name)
            print("Saved psth to {}".format(plot_name))

        except Exception as e:
            log_exception(e, "During stimulation batch at {}".format(i))
            dict_list.append(result_dict)

    fname = os.path.join(dir, "psth_results", "psth.csv")
    save_dicts_to_csv(fname, dict_list)
    print("Saved results to {}".format(fname))
Example #4
    def get_name_at_idx(self,
                        idx,
                        ext,
                        opt_end="",
                        base_dir=None,
                        out_dirname="nc_plots"):
        """
        Get the filename to save an index in the collection to.

        Parameters
        ----------
        idx : int
            The index of the collection to get the filename for.
        ext : str
            The extension to append to the filename.
        opt_end : str, optional
            Appended to the default name: default_name + opt_end + ext
        base_dir : str, optional
            One can specify a directory that all files originated from.
            It is used like so:
            say data1 is in test/foo/data.txt
            and data2 is in test/bar/data.txt,
            then passing base_dir as test
            would set the names to
            out_dirname/foo--data.txt
            out_dirname/bar--data.txt
        out_dirname : str, optional
            The directory to save the plots to.
            This is relative to the directory of the filename
            if base_dir is None.

        Returns
        -------
        str

        """
        data_idx, unit_idx = self._index_to_data_pos(idx)
        filename = self.get_file_dict()["Spike"][data_idx][0]
        unit_number = self.get_units(data_idx)[unit_idx]
        spike_name = os.path.basename(filename)
        final_bname, final_ext = os.path.splitext(spike_name)
        final_ext = final_ext[1:]
        f_dir = os.path.dirname(filename)
        data_basename = (final_bname + "_" + final_ext + "_" +
                         str(unit_number) + opt_end + "." + ext)
        if base_dir is not None:
            main_dir = base_dir
            out_base = f_dir[len(base_dir + os.sep):]
            if len(out_base) != 0:
                out_base = ("--").join(out_base.split(os.sep))
                data_basename = out_base + "--" + data_basename
        else:
            main_dir = f_dir
        out_name = os.path.join(main_dir, out_dirname, data_basename)
        make_dir_if_not_exists(out_name)
        return out_name
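
A short sketch of the naming scheme described in the docstring, using hypothetical paths; if index 0 resolves to the spike file test/foo/data.3 with unit 5, then:

out_name = container.get_name_at_idx(0, "png", base_dir="test")
# out_name == os.path.join("test", "nc_plots", "foo--data_3_5.png")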
Example #5
def pca_clustering(container,
                   in_dir,
                   n_isi_comps=3,
                   n_auto_comps=2,
                   opt_end="",
                   s_color=False):
    """
    Wraps up other functions to do PCA clustering on a container.

    Computes PCA for ISI and AC and then clusters based on these.

    Parameters
    ----------
    container : NDataContainer
        The input container to consider.
    in_dir : str
        The directory to save information to.
    n_isi_comps : int
        The number of principal components for the ISI histogram.
    n_auto_comps : int
        The number of principal components for the autocorrelation.
    opt_end : str, optional
        Appended to output filenames just before the extension.
    s_color : bool, optional
        Passed through to the plotting helpers.
    """
    print("Considering ISIH PCA")
    make_dir_if_not_exists(os.path.join(in_dir, "nc_plots", "dummy.txt"))
    isi_hist_matrix = calculate_isi_hist(container,
                                         in_dir,
                                         opt_end=opt_end,
                                         s_color=s_color)
    isi_after_pca, isi_pca = perform_pca(isi_hist_matrix, n_isi_comps, True)
    print("Considering ACH PCA")
    auto_corr_matrix = calculate_auto_corr(container,
                                           in_dir,
                                           opt_end=opt_end,
                                           s_color=s_color)
    corr_after_pca, corr_pca = perform_pca(auto_corr_matrix, n_auto_comps,
                                           True)
    joint_pca = np.empty((len(container), n_isi_comps + n_auto_comps),
                         dtype=float)
    joint_pca[:, :n_isi_comps] = isi_after_pca
    joint_pca[:, n_isi_comps:n_isi_comps + n_auto_comps] = corr_after_pca
    clust, dend = ward_clustering(joint_pca,
                                  in_dir,
                                  0,
                                  3,
                                  opt_end=opt_end,
                                  s_color=s_color)
    plot_clustering(container,
                    in_dir,
                    isi_hist_matrix,
                    auto_corr_matrix,
                    dend,
                    clust,
                    opt_end=opt_end)
    fname = os.path.join(in_dir, "nc_results",
                         "PCA_results" + opt_end + ".csv")
    save_pca_res(container, fname, n_isi_comps, n_auto_comps, isi_pca,
                 corr_pca, clust, dend, joint_pca)
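
perform_pca itself is not shown in these snippets. A minimal sketch of what it could look like, assuming scikit-learn and inferring the signature from the call sites above (a feature matrix, the number of components, and a flag that presumably standardises the data first):

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def perform_pca(matrix, n_components, should_scale=False):
    """Return (transformed data, fitted PCA object) for the given matrix."""
    if should_scale:
        # Standardise each feature before projecting.
        matrix = StandardScaler().fit_transform(matrix)
    pca = PCA(n_components=n_components)
    transformed = pca.fit_transform(matrix)
    return transformed, pca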
Example #6
def place_cell_summary(
        collection, dpi=150, out_dirname="nc_plots",
        filter_place_cells=False, filter_low_freq=False,
        opt_end="", base_dir=None, output_format="png",
        output=["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"],
        isi_bound=350, isi_bin_length=2, fixed_color=None,
        save_data=False, point_size=None, color_isi=True,
        burst_thresh=5, hd_predict=False):
    """
    Quick PNG spatial information summary of each cell in the collection.

    Parameters
    ----------
    collection : NDataCollection
        The collection to plot summaries of.
    dpi : int, default 150
        Dpi of the output figures.
    out_dirname : str, default "nc_plots"
        The relative name of the dir to save pngs to.
    filter_place_cells : bool, default False
        Whether to filter out non-spatial cells from the plots.
        A unit is kept if Sparsity < 0.3 and Coherence > 0.55.
    filter_low_freq : bool, default False
        Filter out cells with spike freq outside the range 0.25Hz to 7Hz.
    opt_end : str, default ""
        A string to append to the file output just before the extension
    base_dir : str, default None
        An optional directory to save the files to
    output_format : str, default "png"
        What format to save the output image in.
    output : List of str,
        default ["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"]
        Input should be some subset and/or permutation of these
    isi_bound: int, default 350
        How long in ms to plot the ISI to
    isi_bin_length: int, default 2
        How long in ms the ISI bins should be.
    fixed_color: str, default None
        If provided, will plot all units with this color instead of auto color.
    save_data: bool, default False
        Whether to save out the information used for the plot.
    point_size: int, default None
        The size of the matplotlib points in spatial plots.
        If None, the point size is dpi / 7.
    color_isi: bool, default True
        Whether the ISI should be black (False) or blue (True).
    burst_thresh: int, default 5
        How long in ms the burst detection window should be.
    hd_predict: bool, default False
        Whether the head directional graph should be plotted with predicted HD.

    Returns
    -------
    None
    """

    def save_dicts_to_csv(filename, dicts_arr):
        """Saves the last element of each arr in dicts_arr to file"""
        with open(filename, "w") as f:
            for d_arr in dicts_arr:
                if d_arr is not None:
                    d = d_arr[-1]
                    for k, v in d.items():
                        out_str = k.replace(" ", "_")
                        if isinstance(v, Iterable):
                            if isinstance(v, np.ndarray):
                                v = v.flatten()
                            else:
                                v = np.array(v).flatten()
                            str_arr = [str(x) for x in v]
                            out_str = out_str + "," + ",".join(str_arr)
                        else:
                            out_str += "," + str(v)
                        f.write(out_str + "\n")

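    # Accumulators for per-unit plot data, split into good and bad cells.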
    good_placedata = []
    good_graphdata = []
    good_wavedata = []
    good_headdata = []
    good_thetadata = []
    good_isidata = []
    good_units = []
    bad_placedata = []
    bad_graphdata = []
    bad_wavedata = []
    bad_headdata = []
    bad_thetadata = []
    bad_isidata = []
    bad_units = []
    skipped = 0

    if point_size is None:
        point_size = dpi / 7
    for i, data in enumerate(collection):
        try:
            data_idx, unit_idx = collection._index_to_data_pos(i)
            print("Working on {}, {}".format(unit_idx, len(
                collection.get_units(data_idx)) - 1))
            filename = collection.get_file_dict()["Spike"][data_idx][0]
            unit_number = collection.get_units(data_idx)[unit_idx]
            print("Working on {} unit {}".format(
                filename, unit_number))

            count = data.spike.get_unit_spikes_count()
            # Skip very low count cells
            if count < 5:
                skipped += 1
                print("Skipping as only {} spikes".format(count))
            else:
                duration = data.spike.get_duration()
                good = True

                # Place cell filtering is based on
                # https://www.nature.com/articles/ncomms11824
                # Activity-plasticity of hippocampus place maps
                # Schoenenberger et al, 2016
                if filter_low_freq:
                    if (count / duration) < 0.25 or (count / duration) > 7:
                        print("Reject spike frequency {}".format(
                            count / duration))
                        good = False

                if good and filter_place_cells:
                    skaggs = data.loc_shuffle(nshuff=1)
                    bad_sparsity = skaggs['refSparsity'] >= 0.3
                    bad_cohere = skaggs['refCoherence'] <= 0.55
                    first_str_part = "Accept "

                    if bad_sparsity or bad_cohere:
                        good = False
                        first_str_part = "Reject "

                    print((
                        first_str_part +
                        "Skaggs {:.2f}, " +
                        "Sparsity {:.2f}, " +
                        "Coherence {:.2f}").format(
                        skaggs['refSkaggs'], skaggs['refSparsity'],
                        skaggs['refCoherence']))

                if good:
                    good_units.append(unit_idx)
                    placedata = good_placedata
                    graphdata = good_graphdata
                    wavedata = good_wavedata
                    headdata = good_headdata
                    thetadata = good_thetadata
                    isidata = good_isidata
                else:
                    bad_units.append(unit_idx)
                    placedata = bad_placedata
                    graphdata = bad_graphdata
                    wavedata = bad_wavedata
                    headdata = bad_headdata
                    thetadata = bad_thetadata
                    isidata = bad_isidata

                if (
                    (len(bad_units) + len(good_units)) >
                        len(collection.get_units(data_idx))):
                    save_bad = bad_units
                    save_good = good_units
                    good_placedata = []
                    good_graphdata = []
                    good_wavedata = []
                    good_headdata = []
                    good_thetadata = []
                    good_isidata = []
                    good_units = []
                    bad_placedata = []
                    bad_graphdata = []
                    bad_wavedata = []
                    bad_headdata = []
                    bad_thetadata = []
                    bad_isidata = []
                    bad_units = []
                    skipped = 0
                    print("ERROR: Found too many units in the collection")
                    raise Exception(
                        "Accumulated more units than possible " +
                        "bad {} good {} total {}".format(
                            save_bad, save_good, collection.get_units(data_idx)))

                if "Place" in output:
                    placedata.append(data.place())
                else:
                    bad_placedata = None
                    good_placedata = None
                if "LowAC" in output:
                    graphdata.append(data.isi_corr(bins=1, bound=[-10, 10]))
                else:
                    bad_graphdata = None
                    good_graphdata = None
                if "Wave" in output:
                    wavedata.append(data.wave_property())
                else:
                    bad_wavedata = None
                    good_wavedata = None
                if "HD" in output:
                    headdata.append(data.hd_rate())
                else:
                    bad_headdata = None
                    good_headdata = None
                if "Theta" in output:
                    thetadata.append(
                        data.theta_index(
                            bins=2, bound=[-350, 350]))
                else:
                    bad_thetadata = None
                    good_thetadata = None
                if "HighISI" in output:
                    isidata.append(
                        data.isi(bins=int(isi_bound / isi_bin_length),
                                 bound=[0, isi_bound]))
                else:
                    bad_isidata = None
                    good_isidata = None

                if save_data:
                    try:
                        spike_name = os.path.basename(filename)
                        parts = spike_name.split(".")
                        f_dir = os.path.dirname(filename)

                        data_basename = (
                            parts[0] + "_" + parts[1] + "_" +
                            str(unit_number) + opt_end + ".csv")
                        if base_dir is not None:
                            main_dir = base_dir
                            out_base = f_dir[len(base_dir + os.sep):]
                            if len(out_base) != 0:
                                out_base = ("--").join(out_base.split(os.sep))
                                data_basename = out_base + "--" + data_basename
                        else:
                            main_dir = f_dir
                        out_name = os.path.join(
                            main_dir, out_dirname, "data", data_basename)
                        make_dir_if_not_exists(out_name)
                        save_dicts_to_csv(
                            out_name,
                            [placedata, graphdata, wavedata,
                                headdata, thetadata, isidata])
                    except Exception as e:
                        log_exception(
                            e, "Occurred during place cell data saving on" +
                            " {} unit {} name {} in {}".format(
                                data_idx, unit_number, spike_name, main_dir))

            # Save the accumulated information
            if unit_idx == len(collection.get_units(data_idx)) - 1:
                if ((len(bad_units) + len(good_units)) !=
                        len(collection.get_units(data_idx)) - skipped):
                    print("ERROR: Did not cover all units in the collection")
                    print("Good {}, Bad {}, Total {}".format(
                        bad_units, good_units, collection.get_units(data_idx)))
                spike_name = os.path.basename(filename)
                parts = spike_name.split(".")
                f_dir = os.path.dirname(filename)

                out_basename = (
                    parts[0] + "_" + parts[1] + opt_end + "." + output_format)

                if base_dir is not None:
                    main_dir = base_dir
                    out_base = f_dir[len(base_dir + os.sep):]
                    if len(out_base) != 0:
                        out_base = ("--").join(out_base.split(os.sep))
                        out_basename = out_base + "--" + out_basename
                else:
                    main_dir = f_dir

                if filter_place_cells:
                    named_units = [
                        collection.get_units(data_idx)[j]
                        for j in good_units]
                    bad_named_units = [
                        collection.get_units(data_idx)[j]
                        for j in bad_units]
                else:
                    named_units = collection.get_units(data_idx)
                    bad_named_units = []

                # Save figures one by one if using pdf or svg
                one_by_one = (output_format == "pdf") or (
                    output_format == "svg")

                if len(named_units) > 0:
                    if filter_place_cells:
                        print((
                            "Plotting summary for {} " +
                            "spatial units {}").format(
                            spike_name, named_units))
                    else:
                        print((
                            "Plotting summary for {} " +
                            "units {}").format(
                            spike_name, named_units))

                    fig = print_place_cells(
                        len(named_units), cols=len(output),
                        placedata=good_placedata, graphdata=good_graphdata,
                        wavedata=good_wavedata, headdata=good_headdata,
                        thetadata=good_thetadata, isidata=good_isidata,
                        size_multiplier=4, point_size=point_size,
                        units=named_units, fixed_color=fixed_color,
                        output=output, color_isi=color_isi,
                        burst_ms=burst_thresh, one_by_one=one_by_one,
                        raster=one_by_one, hd_predict=hd_predict)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = named_units[k]
                            iname = (
                                out_basename[:-4] + "_" +
                                str(unit_number) + out_basename[-4:])
                            if filter_low_freq or filter_place_cells:
                                out_name = os.path.join(
                                    main_dir, out_dirname, "good", iname)
                            else:
                                out_name = os.path.join(
                                    main_dir, out_dirname, iname)

                            print("Saving place cell figure to {}".format(
                                out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi, format=output_format)

                    else:
                        if filter_low_freq or filter_place_cells:
                            out_name = os.path.join(
                                main_dir, out_dirname, "good", out_basename)
                        else:
                            out_name = os.path.join(
                                main_dir, out_dirname, out_basename)
                        print("Saving place cell figure to {}".format(
                            out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi, format=output_format)

                    close("all")
                    gc.collect()

                if len(bad_named_units) > 0:
                    print((
                        "Plotting bad summary for {} " +
                        "non-spatial units {}").format(
                        spike_name, bad_named_units))
                    fig = print_place_cells(
                        len(bad_named_units), cols=len(output),
                        placedata=bad_placedata, graphdata=bad_graphdata,
                        wavedata=bad_wavedata, headdata=bad_headdata,
                        thetadata=bad_thetadata, isidata=bad_isidata,
                        size_multiplier=4, point_size=point_size,
                        units=bad_named_units, fixed_color=fixed_color,
                        output=output, color_isi=color_isi,
                        burst_ms=burst_thresh, one_by_one=one_by_one,
                        raster=one_by_one, hd_predict=hd_predict)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = bad_named_units[k]
                            iname = (
                                out_basename[:-4] + "_" +
                                str(unit_number) + out_basename[-4:])
                            out_name = os.path.join(
                                main_dir, out_dirname, "bad", iname)

                            print("Saving place cell figure to {}".format(
                                out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi,
                                      format=output_format)

                    else:
                        out_name = os.path.join(
                            main_dir, out_dirname, "bad", out_basename)
                        print("Saving place cell figure to {}".format(
                            out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi,
                                    format=output_format)

                    close("all")
                    gc.collect()

                good_placedata = []
                good_graphdata = []
                good_wavedata = []
                good_headdata = []
                good_thetadata = []
                good_isidata = []
                good_units = []
                bad_placedata = []
                bad_graphdata = []
                bad_wavedata = []
                bad_headdata = []
                bad_thetadata = []
                bad_isidata = []
                bad_units = []
                skipped = 0

        except Exception as e:
            log_exception(
                e, "Occurred during place cell summary on data" +
                   " {} unit {} name {} in {}".format(
                       data_idx, unit_number, spike_name, main_dir))
    return
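
A hedged usage sketch with a hypothetical path; place_cell_summary expects a set-up container, built the same way as in the earlier examples:

container = NDataContainer(load_on_fly=True)
container.add_axona_files_from_dir(r"/path/to/recordings", recursive=True)
container.setup()
place_cell_summary(container, filter_place_cells=True, filter_low_freq=True)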
Example #7
import os
import shutil

# NOTE: this snippet is truncated in the source; the signature and docstring
# head below are a plausible reconstruction from the body that follows
# (make_dir_if_not_exists and main are defined elsewhere).
def copy_file_with_exts(src_dir, dest_dir, fname, exts, verbose=False,
                        re_filter=None):
    """
    Copy fname with each extension in exts from src_dir to dest_dir.

    Parameters
    ----------
    src_dir : str
        The absolute path to the source directory
    dest_dir : str
        The absolute path to the destination directory.
    fname : str
        The relative path of the file to copy, without extension.
    exts : list
        The extension of files to get.
    verbose : bool, optional. Defaults to False.
        Whether to print the destination of files.
    re_filter : str, optional. Defaults to None.
        A regex string for matching filenames.

    """

    basename = os.path.basename(fname)
    for ext in exts:
        dest = shutil.copy(os.path.join(src_dir, fname + ext),
                           os.path.join(dest_dir, basename + ext))
        if verbose:
            # Print each file's destination as it is copied.
            print(dest)


if __name__ == "__main__":
    src_dir = r"F:\Ham Data\A9_CAR-SA1"

    dest_dir = r"G:\PhD (Shane O'Mara)\Operant Data\Recordings"
    make_dir_if_not_exists(dest_dir)

    # Analysis flags:
    # 0 - Convert .bin -> .inp in src_dir
    # 1 - Transfer .inp from src_dir to dest_dir
    analysis_flags = [1, 0]

    main(src_dir, dest_dir, analysis_flags)
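
make_dir_if_not_exists (imported from neurochat.nc_utils in the next example) is called throughout these snippets with an output file path just before writing to it, so it presumably creates the file's parent directory. A sketch under that assumption:

import os

def make_dir_if_not_exists(fname):
    # Create the parent directory of fname if it does not already exist.
    dirname = os.path.dirname(fname)
    if dirname:
        os.makedirs(dirname, exist_ok=True)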
Example #8
import datetime
import logging
import time
import os
import traceback
import sys

from PyQt5 import QtWidgets

from neurochat.nc_ui import NeuroChaT_Ui
from neurochat.nc_utils import make_dir_if_not_exists

default_write = sys.stdout.write
default_loc = os.path.join(
    os.path.expanduser("~"), ".nc_saved", "nc_errorlog.txt")
make_dir_if_not_exists(default_loc)
this_logger = logging.getLogger(__name__)
handler = logging.FileHandler(default_loc)
this_logger.addHandler(handler)


def excepthook(exc_type, exc_value, exc_traceback):
    """Log any uncaught exceptions from here."""
    # Don't catch CTRL+C exceptions
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    # The snippet is truncated here; logging the error as below is a
    # plausible completion, consistent with the docstring.
    this_logger.error(
        "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
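
For the hook to take effect it must be installed on sys; the installation line is not shown in this snippet, but the standard mechanism is:

sys.excepthook = excepthook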
Example #9
def place_cell_summary(
        collection,
        dpi=150,
        out_dirname="nc_plots",
        filter_place_cells=False,
        filter_low_freq=False,
        opt_end="",
        base_dir=None,
        output_format="png",
        output=["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"],
        isi_bound=350,
        isi_bin_length=2,
        fixed_color=None,
        save_data=False,
        point_size=None,
        color_isi=True,
        burst_thresh=5,
        hd_predict=False,
        one_by_one="auto"):
    """
    Perform a spatial information summary of each cell in the collection.

    The function is named place_cell_summary because it offers a
    quick visual way to look for place cells.
    However, it can also be used to look at other spatial properties,
    such as head directionality.

    The output image (any matplotlib format is supported) contains
    the plots listed in the output argument, in that order.

    Parameters
    ----------
    collection : NDataCollection
        The collection to plot spatial summaries of.
    dpi : int, default 150
        Dpi of the output figures if the output_format supports dpi.
    out_dirname : str, default "nc_plots"
        The relative name of the directory to save output to.
    filter_place_cells : bool, default False
        Whether to filter out non-spatial place cells from the plots.
        A unit is considered a spatial place cell if:
        Sparsity < 0.3 and Coherence > 0.55.
        Recommended filter_low_freq=True if this flag is True.
        See https://www.nature.com/articles/ncomms11824.
    filter_low_freq : bool, default False
        Filter out cells with spike freq outside the range 0.25Hz to 7Hz.
    opt_end : str, default ""
        A string to append to the file output just before the extension
        Can be used if doing multiple runs to avoid overwriting output files.
    base_dir : str, default None
        An optional directory to save all the files to.
        If not provided, the files will save to the location of the input data.
    output_format : str, default "png"
        What format to save the output image in.
    output : List of str,
        default ["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"]
        Provided argument should be some subset and/or permutation of these.
    isi_bound : int, default 350
        How long in ms to plot the ISI for.
    isi_bin_length : int, default 2
        How long in ms the ISI bins should be.
    fixed_color : str, default None
        If provided, will plot all units with this color instead of auto color.
        Can be any matplotlib compatible color.
    save_data : bool, default False
        Whether to save out the information used for the plot.
        If True, will produce a csv with all the data used for plotting.
    point_size : int, default None
        If provided, the size of the matplotlib points to use in spatial plots.
        If None, the point size is dpi / 7.
    color_isi : bool, default True
        Whether the ISI should be black (False) or blue (True).
    burst_thresh : int, default 5
        How long in ms the burst detection window should be.
    hd_predict : bool, default False
        Whether the head directional graph should be plotted with predicted HD.
    one_by_one : bool or str, default "auto"
        Whether to plot each unit in a tetrode file to its own file. Options:
        True   - plot each unit to a separate file.
        False  - plot all units on a tetrode to the same file.
        "auto" - determine True/False from output_format; pdf and svg are True.

    Returns
    -------
    None

    """

    # This function is used to save out the plotting data to a csv.
    def save_dicts_to_csv(filename, dicts_arr):
        """Save the last element of each arr in dicts_arr to file."""
        with open(filename, "w") as f:
            for d_arr in dicts_arr:
                if d_arr is not None:
                    d = d_arr[-1]
                    for k, v in d.items():
                        out_str = k.replace(" ", "_")
                        if isinstance(v, Iterable):
                            if isinstance(v, np.ndarray):
                                v = v.flatten()
                            else:
                                v = np.array(v).flatten()
                            str_arr = [str(x) for x in v]
                            out_str = out_str + "," + ",".join(str_arr)
                        else:
                            out_str += "," + str(v)
                        f.write(out_str + "\n")

    # Set up the arrays to hold the data for plotting.
    good_placedata = []
    good_graphdata = []
    good_wavedata = []
    good_headdata = []
    good_thetadata = []
    good_isidata = []
    good_units = []
    bad_placedata = []
    bad_graphdata = []
    bad_wavedata = []
    bad_headdata = []
    bad_thetadata = []
    bad_isidata = []
    bad_units = []
    skipped = 0

    if point_size is None:
        point_size = dpi / 7

    for i, data in enumerate(collection):
        try:
            data_idx, unit_idx = collection._index_to_data_pos(i)
            filename = collection.get_file_dict()["Spike"][data_idx][0]
            unit_number = collection.get_units(data_idx)[unit_idx]
            logging.info("Working on {} unit {} out of {}".format(
                filename, unit_number, len(collection.get_units(data_idx))))

            count = data.spike.get_unit_spikes_count()
            # Skip very low count cells
            if count < 5:
                skipped += 1
                logging.warning("Skipping as only {} spikes".format(count))
            else:
                duration = data.spike.get_duration()
                good = True

                # Place cell filtering is based on
                # https://www.nature.com/articles/ncomms11824
                # Activity-plasticity of hippocampus place maps
                # Schoenenberger et al, 2016
                if filter_low_freq:
                    if (count / duration) < 0.25 or (count / duration) > 7:
                        logging.info("Reject spike frequency {}".format(
                            count / duration))
                        good = False

                if good and filter_place_cells:
                    skaggs = data.loc_shuffle(nshuff=1)
                    bad_sparsity = skaggs['refSparsity'] >= 0.3
                    bad_cohere = skaggs['refCoherence'] <= 0.55
                    first_str_part = "Accept "

                    if bad_sparsity or bad_cohere:
                        good = False
                        first_str_part = "Reject "

                    logging.info(
                        (first_str_part + "Skaggs {:.2f}, " +
                         "Sparsity {:.2f}, " + "Coherence {:.2f}").format(
                             skaggs['refSkaggs'], skaggs['refSparsity'],
                             skaggs['refCoherence']))

                if good:
                    good_units.append(unit_idx)
                    placedata = good_placedata
                    graphdata = good_graphdata
                    wavedata = good_wavedata
                    headdata = good_headdata
                    thetadata = good_thetadata
                    isidata = good_isidata
                else:
                    bad_units.append(unit_idx)
                    placedata = bad_placedata
                    graphdata = bad_graphdata
                    wavedata = bad_wavedata
                    headdata = bad_headdata
                    thetadata = bad_thetadata
                    isidata = bad_isidata

                # In case somehow there is a double counting bug
                if ((len(bad_units) + len(good_units)) > len(
                        collection.get_units(data_idx))):
                    save_bad = bad_units
                    save_good = good_units
                    good_placedata = []
                    good_graphdata = []
                    good_wavedata = []
                    good_headdata = []
                    good_thetadata = []
                    good_isidata = []
                    good_units = []
                    bad_placedata = []
                    bad_graphdata = []
                    bad_wavedata = []
                    bad_headdata = []
                    bad_thetadata = []
                    bad_isidata = []
                    bad_units = []
                    skipped = 0
                    raise Exception("Accumlated more units than possible " +
                                    "bad {} good {} total {}".format(
                                        save_bad, save_good,
                                        collection.get_units(data_idx)))

                if "Place" in output:
                    placedata.append(data.place())
                else:
                    bad_placedata = None
                    good_placedata = None
                if "LowAC" in output:
                    graphdata.append(data.isi_corr(bins=1, bound=[-10, 10]))
                else:
                    bad_graphdata = None
                    good_graphdata = None
                if "Wave" in output:
                    wavedata.append(data.wave_property())
                else:
                    bad_wavedata = None
                    good_wavedata = None
                if "HD" in output:
                    headdata.append(data.hd_rate())
                else:
                    bad_headdata = None
                    good_headdata = None
                if "Theta" in output:
                    thetadata.append(
                        data.theta_index(bins=2, bound=[-350, 350]))
                else:
                    bad_thetadata = None
                    good_thetadata = None
                if "HighISI" in output:
                    isidata.append(
                        data.isi(bins=int(isi_bound / isi_bin_length),
                                 bound=[0, isi_bound]))
                else:
                    bad_isidata = None
                    good_isidata = None

                if save_data:
                    try:
                        spike_name = os.path.basename(filename)
                        final_bname, final_ext = os.path.splitext(spike_name)
                        final_ext = final_ext[1:]
                        f_dir = os.path.dirname(filename)

                        data_basename = (final_bname + "_" + final_ext + "_" +
                                         str(unit_number) + opt_end + ".csv")
                        if base_dir is not None:
                            main_dir = base_dir
                            out_base = f_dir[len(base_dir + os.sep):]
                            if len(out_base) != 0:
                                out_base = ("--").join(out_base.split(os.sep))
                                data_basename = out_base + "--" + data_basename
                        else:
                            main_dir = f_dir
                        out_name = os.path.join(main_dir, out_dirname, "data",
                                                data_basename)
                        make_dir_if_not_exists(out_name)
                        save_dicts_to_csv(out_name, [
                            placedata, graphdata, wavedata, headdata,
                            thetadata, isidata
                        ])
                    except Exception as e:
                        log_exception(
                            e, "Occurred during place cell data saving on" +
                            " {} unit {} name {} in {}".format(
                                data_idx, unit_number, spike_name, main_dir))

            # Save the accumulated information
            if unit_idx == len(collection.get_units(data_idx)) - 1:
                if ((len(bad_units) + len(good_units)) !=
                        len(collection.get_units(data_idx)) - skipped):
                    logging.error("Good {}, Bad {}, Total {}".format(
                        good_units, bad_units, collection.get_units(data_idx)))
                    raise ValueError(
                        "Did not cover all units in the collection")
                spike_name = os.path.basename(filename)
                final_bname, final_ext = os.path.splitext(spike_name)
                final_ext = final_ext[1:]
                f_dir = os.path.dirname(filename)

                out_basename = (final_bname + "_" + final_ext + opt_end + "." +
                                output_format)

                if base_dir is not None:
                    main_dir = base_dir
                    out_base = f_dir[len(base_dir + os.sep):]
                    if len(out_base) != 0:
                        out_base = ("--").join(out_base.split(os.sep))
                        out_basename = out_base + "--" + out_basename
                else:
                    main_dir = f_dir

                if filter_place_cells:
                    named_units = [
                        collection.get_units(data_idx)[j] for j in good_units
                    ]
                    bad_named_units = [
                        collection.get_units(data_idx)[j] for j in bad_units
                    ]
                else:
                    named_units = collection.get_units(data_idx)
                    bad_named_units = []

                # Save figures one by one if using pdf or svg
                if one_by_one == "auto":
                    one_by_one = output_format in ("pdf", "svg")

                if len(named_units) > 0:
                    if filter_place_cells:
                        logging.info(("Plotting summary for {} " +
                                      "spatial units {}").format(
                                          spike_name, named_units))
                    else:
                        logging.info(
                            ("Plotting summary for {} " + "units {}").format(
                                spike_name, named_units))

                    fig = print_place_cells(len(named_units),
                                            cols=len(output),
                                            placedata=good_placedata,
                                            graphdata=good_graphdata,
                                            wavedata=good_wavedata,
                                            headdata=good_headdata,
                                            thetadata=good_thetadata,
                                            isidata=good_isidata,
                                            size_multiplier=4,
                                            point_size=point_size,
                                            units=named_units,
                                            fixed_color=fixed_color,
                                            output=output,
                                            color_isi=color_isi,
                                            burst_ms=burst_thresh,
                                            one_by_one=one_by_one,
                                            raster=one_by_one,
                                            hd_predict=hd_predict)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = named_units[k]
                            iname = (out_basename[:-4] + "_" +
                                     str(unit_number) + out_basename[-4:])
                            if filter_low_freq or filter_place_cells:
                                out_name = os.path.join(
                                    main_dir, out_dirname, "good", iname)
                            else:
                                out_name = os.path.join(
                                    main_dir, out_dirname, iname)

                            logging.info(
                                "Saving place cell figure to {}".format(
                                    out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi, format=output_format)

                    else:
                        if filter_low_freq or filter_place_cells:
                            out_name = os.path.join(main_dir, out_dirname,
                                                    "good", out_basename)
                        else:
                            out_name = os.path.join(main_dir, out_dirname,
                                                    out_basename)
                        logging.info(
                            "Saving place cell figure to {}".format(out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi, format=output_format)

                    close("all")
                    gc.collect()

                if len(bad_named_units) > 0:
                    logging.info(("Plotting bad summary for {} " +
                                  "non-spatial units {}").format(
                                      spike_name, bad_named_units))
                    fig = print_place_cells(len(bad_named_units),
                                            cols=len(output),
                                            placedata=bad_placedata,
                                            graphdata=bad_graphdata,
                                            wavedata=bad_wavedata,
                                            headdata=bad_headdata,
                                            thetadata=bad_thetadata,
                                            isidata=bad_isidata,
                                            size_multiplier=4,
                                            point_size=point_size,
                                            units=bad_named_units,
                                            fixed_color=fixed_color,
                                            output=output,
                                            color_isi=color_isi,
                                            burst_ms=burst_thresh,
                                            one_by_one=one_by_one,
                                            raster=one_by_one,
                                            hd_predict=hd_predict)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = bad_named_units[k]
                            iname = (out_basename[:-4] + "_" +
                                     str(unit_number) + out_basename[-4:])
                            out_name = os.path.join(main_dir, out_dirname,
                                                    "bad", iname)

                            logging.info(
                                "Saving place cell fig to {}".format(out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi, format=output_format)

                    else:
                        out_name = os.path.join(main_dir, out_dirname, "bad",
                                                out_basename)
                        logging.info(
                            "Saving place cell fig to {}".format(out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi, format=output_format)

                    close("all")
                    gc.collect()

                good_placedata = []
                good_graphdata = []
                good_wavedata = []
                good_headdata = []
                good_thetadata = []
                good_isidata = []
                good_units = []
                bad_placedata = []
                bad_graphdata = []
                bad_wavedata = []
                bad_headdata = []
                bad_thetadata = []
                bad_isidata = []
                bad_units = []
                skipped = 0

        except Exception as e:
            log_exception(
                e, "Occurred during place cell summary on data" +
                " {} unit {} name {} in {}".format(data_idx, unit_number,
                                                   spike_name, main_dir))
    return
Example #10
def cell_classification_stats(in_dir,
                              container,
                              out_name,
                              should_plot=False,
                              opt_end="",
                              output_spaces=True):
    """
    Compute a csv of cell stats for each unit in a container

    Params
    ------
    in_dir - the data output/input location
    container - the NDataContainer object to get stats for
    should_plot - whether to save some plots for this
    """
    _results = []
    spike_names = container.get_file_dict()["Spike"]
    for i, ndata in enumerate(container):
        try:
            data_idx, unit_idx = container._index_to_data_pos(i)
            name = spike_names[data_idx][0]
            parts = os.path.basename(name).split(".")

            # Setup up identifier information
            note_dict = oDict()
            dir_t = os.path.dirname(name)
            note_dict["Index"] = i
            note_dict["FullDir"] = dir_t
            if dir_t != in_dir:
                note_dict["RelDir"] = os.path.dirname(name)[len(in_dir +
                                                                os.sep):]
            else:
                note_dict["RelDir"] = ""
            note_dict["Recording"] = parts[0]
            note_dict["Tetrode"] = int(parts[-1])
            note_dict["Unit"] = ndata.get_unit_no()
            ndata.update_results(note_dict)

            # Calculate cell properties
            ndata.wave_property()
            ndata.place()
            ndata.hd_rate()
            ndata.grid()
            ndata.border()
            ndata.multiple_regression()
            isi = ndata.isi()
            ndata.burst(burst_thresh=6)
            phase_dist = ndata.phase_dist()
            theta_index = ndata.theta_index()
            ndata.bandpower_ratio([5, 11], [1.5, 4],
                                  1.6,
                                  relative=True,
                                  first_name="Theta",
                                  second_name="Delta")
            result = copy(
                ndata.get_results(spaces_to_underscores=not output_spaces))
            _results.append(result)

            if should_plot:
                plot_loc = os.path.join(
                    in_dir, "nc_plots", parts[0] + "_" + parts[-1] + "_" +
                    str(ndata.get_unit_no()) + "_phase" + opt_end + ".png")
                make_dir_if_not_exists(plot_loc)
                fig1, fig2, fig3 = nc_plot.spike_phase(phase_dist)
                fig2.savefig(plot_loc)
                plt.close("all")

                if unit_idx == len(container.get_units(data_idx)) - 1:
                    plot_loc = os.path.join(
                        in_dir, "nc_plots",
                        parts[0] + "_lfp" + opt_end + ".png")
                    make_dir_if_not_exists(plot_loc)

                    lfp_spectrum = ndata.spectrum()
                    fig = nc_plot.lfp_spectrum(lfp_spectrum)
                    fig.savefig(plot_loc)
                    plt.close(fig)

        except Exception as e:
            to_out = note_dict.get("Unit", "NA")
            print("WARNING: Failed to analyse {} unit {}".format(
                os.path.basename(name), to_out))
            log_exception(
                e, "Failed on {} unit {}".format(os.path.basename(name),
                                                 to_out))

    # Save the cell statistics
    make_dir_if_not_exists(out_name)
    save_dicts_to_csv(out_name, _results)
    _results.clear()
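
Assuming a container built as in the earlier examples, a hypothetical call (the csv name and the nc_results folder follow the conventions used elsewhere in these snippets):

out_name = os.path.join(in_dir, "nc_results", "cell_stats.csv")
cell_classification_stats(in_dir, container, out_name, should_plot=True)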
Example #11
    def add_axona_files_from_dir(self,
                                 directory,
                                 recursive=False,
                                 verbose=False,
                                 **kwargs):
        """
        Go through a directory, extracting Axona files from it automatically.

        Parameters
        ----------
        directory : str
            The directory to parse through
        recursive : bool, optional.
            Defaults to False.
            Whether to recurse through dirs
        verbose: bool, optional.
            Defaults to False.
            Whether to print the files being added.

        **kwargs: keyword arguments
            tetrode_list : list
                list of tetrodes to consider
                default is 1 to 16
            data_extension : str
                default .set
            cluster_extension : str
                default .cut
            clu_extension : str
                default .clu.X, where X is replaced by the tetrode number
            pos_extension : str
                default .txt
            lfp_extension : str
                default .eeg
            stm_extension : str
                default .stm
            re_filter : str
                default None - no regex performed
                regex string for matching filenames
            save_result : bool
                default True
                should save the resulting collection to a file
            unit_cutoff : tuple of ints
                don't consider any recordings with units outside this range
                e.g. if the upper cutoff is set at 10, any clustering
                containing 11 or more units will not be considered valid
                and won't be added to the container.

        Returns
        -------
        list or str
            If save_result is True,
            returns a string indicating where the result was saved.
            Otherwise returns a list of the cluster files which were used.

        """
        default_tetrode_list = list(range(1, 17))
        tetrode_list = kwargs.get("tetrode_list", default_tetrode_list)
        data_extension = kwargs.get("data_extension", ".set")
        cluster_extension = kwargs.get("cluster_extension", ".cut")
        clu_extension = kwargs.get("clu_extension", ".clu.X")
        pos_extension = kwargs.get("pos_extension", ".txt")
        lfp_extension = kwargs.get("lfp_extension", ".eeg")
        stm_extension = kwargs.get("stm_extension", ".stm")
        re_filter = kwargs.get("re_filter", None)
        save_result = kwargs.get("save_result", True)
        unit_cutoff = kwargs.get("unit_cutoff", None)

        if verbose:
            print("Finding set files:")
        files = get_all_files_in_dir(directory,
                                     data_extension,
                                     recursive=recursive,
                                     verbose=verbose,
                                     re_filter=re_filter,
                                     return_absolute=True)
        if verbose:
            print("Finding txt files:")
        txt_files = get_all_files_in_dir(directory,
                                         pos_extension,
                                         recursive=recursive,
                                         verbose=verbose,
                                         re_filter=re_filter,
                                         return_absolute=True)
        if verbose:
            print("Finding stm files:")
        stm_files = get_all_files_in_dir(directory,
                                         stm_extension,
                                         recursive=recursive,
                                         verbose=verbose,
                                         re_filter=re_filter,
                                         return_absolute=True)

        num_found = 0
        cluster_files = []
        for filename in files:
            filename = filename[:-len(data_extension)]
            for tetrode in tetrode_list:
                spike_name = filename + '.' + str(tetrode)
                cut_name = filename + '_' + str(tetrode) + cluster_extension
                clu_name = filename + clu_extension[:-1] + str(tetrode)
                lfp_name = filename + lfp_extension
                stm_name = ""

                if not os.path.isfile(os.path.join(directory, spike_name)):
                    continue
                # Don't consider files that have not been clustered
                if not (os.path.isfile(os.path.join(directory, cut_name))
                        or os.path.isfile(os.path.join(directory, clu_name))):
                    logging.info(
                        "Skipping tetrode {} - no clust file {} or {}".format(
                            tetrode, cut_name, os.path.basename(clu_name)))
                    continue

                for fname in txt_files:
                    if fname[:(len(filename) + 1)] == filename + "_":
                        pos_name = fname
                        break
                else:
                    logging.info(
                        "Skipping tetrode {} - no position file {}".format(
                            tetrode, filename))
                    continue

                for fname in stm_files:
                    if fname[:len(filename)] == filename:
                        stm_name = fname
                        break

                if os.path.isfile(os.path.join(directory, cut_name)):
                    cluster_name = cut_name
                else:
                    cluster_name = clu_name

                cluster_files.append(cluster_name)
                logging.info(
                    "Adding tetrode {} with spikes {}, clusters {}, positions {}"
                    .format(tetrode, os.path.basename(spike_name),
                            os.path.basename(cluster_name),
                            os.path.basename(pos_name)))
                num_found += 1
                self.add_files(NDataContainer.EFileType.Spike, [spike_name])
                self.add_files(NDataContainer.EFileType.Position, [pos_name])
                self.add_files(NDataContainer.EFileType.LFP, [lfp_name])
                self.add_files(NDataContainer.EFileType.STM, [stm_name])

        if num_found == 0:
            logging.warning("Did not find any Axona files to add")
            return

        self.set_units()

        if unit_cutoff:
            self.remove_recordings_units(unit_cutoff[0],
                                         unit_cutoff[1],
                                         verbose=verbose)

        if save_result:
            friendly_re = ""
            if re_filter:
                friendly_re = "_" + \
                    "-".join(re.findall("[a-zA-Z0-9_]+", re_filter))
            name = ("file_list_" + os.path.basename(directory) + friendly_re +
                    ".txt")
            out_loc = os.path.join(directory, "nc_results", name)
            make_dir_if_not_exists(out_loc)
            with open(out_loc, 'w') as f:
                f.write(str(self))
            logging.info(
                "Wrote list of files considered to {}".format(out_loc))
            return out_loc

        return cluster_files
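
A hypothetical call exercising some of the kwargs documented above; the regex filter and unit cutoff values are illustrative only:

container = NDataContainer(load_on_fly=True)
out_loc = container.add_axona_files_from_dir(
    r"/path/to/data", recursive=True, verbose=True,
    tetrode_list=[1, 2, 3, 4], re_filter="CAR", unit_cutoff=(0, 10))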
Example #12
def cell_classification_stats(
        in_dir, container, out_name,
        should_plot=False, opt_end="", output_spaces=True,
        good_cells=None):
    """
    Compute a csv of cell stats for each unit in a container

    Params
    ------
    in_dir - the data output/input location
    container - the NDataContainer object to get stats for
    should_plot - whether to save some plots for this
    """
    _results = []
    spike_names = container.get_file_dict()["Spike"]
    overall_count = 0
    for i in range(len(container)):
        try:
            data_idx, unit_idx = container._index_to_data_pos(i)
            name = spike_names[data_idx][0]
            parts = os.path.basename(name).split(".")
            o_name = os.path.join(
                os.path.dirname(name)[len(in_dir + os.sep):],
                parts[0])
            note_dict = oDict()
            # Setup up identifier information
            dir_t = os.path.dirname(name)
            note_dict["Index"] = i
            note_dict["FullDir"] = dir_t
            if dir_t != in_dir:
                note_dict["RelDir"] = os.path.dirname(
                    name)[len(in_dir + os.sep):]
            else:
                note_dict["RelDir"] = ""
            note_dict["Recording"] = parts[0]
            note_dict["Tetrode"] = int(parts[-1])
            if good_cells is not None:
                check = [
                    os.path.normpath(name[len(in_dir + os.sep):]),
                    container.get_units(data_idx)[unit_idx]]
                if check not in good_cells:
                    continue
            ndata = container[i]
            overall_count += 1
            print("Working on unit {} of {}: {}, T{}, U{}".format(
                i + 1, len(container), o_name, parts[-1], ndata.get_unit_no()))

            note_dict["Unit"] = ndata.get_unit_no()
            ndata.update_results(note_dict)

            # Calculate cell properties
            ndata.wave_property()
            ndata.place()
            isi = ndata.isi()
            ndata.burst(burst_thresh=6)
            theta_index = ndata.theta_index()
            ndata._results["IsPyramidal"] = cell_type(ndata)
            result = copy(ndata.get_results(
                spaces_to_underscores=not output_spaces))
            _results.append(result)

        except Exception as e:
            to_out = note_dict.get("Unit", "NA")
            print("WARNING: Failed to analyse {} unit {}".format(
                os.path.basename(name), to_out))
            log_exception(e, "Failed on {} unit {}".format(
                os.path.basename(name), to_out))

    # Save the cell statistics
    make_dir_if_not_exists(out_name)
    print("Analysed {} cells in total".format(overall_count))
    save_dicts_to_csv(out_name, _results)
    _results.clear()
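
The good_cells filter expects entries matching the check built inside the loop: a [relative spike path, unit number] pair, normalised with os.path.normpath. A hypothetical example:

good_cells = [
    [os.path.normpath("session1/recording1.2"), 3],
    [os.path.normpath("session2/recording2.4"), 1],
]
cell_classification_stats(in_dir, container, out_name, good_cells=good_cells)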