Пример #1
0
    def save_object(self, obj=None):
        """
        Store a NeuroChaT dataset to the HDF5 file.

        The type name of the object is resolved first and the matching
        ``save_<type>`` method of this instance is then called with the
        object, but only if the file reported by ``obj.get_filename()``
        exists on disk.

        Parameters
        ----------
        obj
            One of the NeuroChaT data types. It must provide ``get_type()``
            and ``get_filename()``.

        Returns
        -------
        None

        """
        try:
            obj_type = obj.get_type()
        except Exception as e:
            log_exception(e, 'Object passed is not a neurochat data type')
            # Without a type name there is no save method to dispatch to;
            # stop here instead of failing again on an unbound obj_type.
            return

        try:
            if os.path.isfile(obj.get_filename()):
                fun = getattr(self, 'save_' + obj_type)
                fun(obj)
        except Exception as e:
            log_exception(e, 'Saving hdf5 dataset')
Пример #2
0
def save_dicts_to_csv(filename, in_dicts):
    """
    Save a list of dictionaries to a csv.

    The header starts with the keys of the dictionary that has the most keys
    (the first such dictionary wins ties). Keys that only appear in other
    dictionaries are appended afterwards in the order they are first seen.
    Each dictionary is written as one row of the csv, so it is assumed that
    the values in the dict are mostly floats / ints / etc.
    Keys missing from a dictionary are left empty in its row.

    Parameters
    ----------
    filename : str
        Path of the csv file to write. It is overwritten if it exists.
    in_dicts : list of dict
        The dictionaries to save. May be empty.

    Returns
    -------
    None

    """
    # Start from the widest dict so that its column order is kept.
    # A real list is needed here: dict.keys() is a view without append.
    header = []
    for in_dict in in_dicts:
        if len(in_dict) > len(header):
            header = list(in_dict)

    # Then append any keys that the widest dict does not have.
    known = set(header)
    for in_dict in in_dicts:
        for name in in_dict:
            if name not in known:
                known.add(name)
                header.append(name)

    try:
        with open(filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=header)
            writer.writeheader()
            for in_dict in in_dicts:
                writer.writerow(in_dict)

    except Exception as e:
        log_exception(e, "When {} saving to csv".format(filename))
def main(dir, bin_size, bin_bound):
    """
    Run a batch PSTH analysis over the Axona files found under a directory.

    Files are collected recursively from ``dir`` into an NDataContainer.
    For every unit in the container, ``event.psth`` is computed against the
    unit's spikes, a plot is saved to ``<dir>/psth_results/`` and one row of
    results is collected. All rows are finally written to
    ``<dir>/psth_results/psth.csv``.

    Parameters
    ----------
    dir : str
        The directory to search for Axona files and to save results under.
    bin_size
        Passed to ``event.psth`` as ``bins``.
    bin_bound
        Passed to ``event.psth`` as ``bound``.

    Returns
    -------
    None

    """
    container = NDataContainer(load_on_fly=True)
    container.add_axona_files_from_dir(dir, recursive=True)
    container.setup()
    print(container.string_repr(True))
    event = NEvent()
    last_stm_name = None
    dict_list = []

    for i in range(len(container)):
        result_dict = OrderedDict()
        try:
            data_index, _ = container._index_to_data_pos(i)
            stm_name = container.get_file_dict("STM")[data_index][0]
            data = container[i]

            # Add extra keys
            spike_file = data.spike.get_filename()
            spike_dir = os.path.dirname(spike_file)
            spike_name = os.path.basename(spike_file)
            spike_name_only, spike_ext = os.path.splitext(spike_name)
            spike_ext = spike_ext[1:]
            result_dict["Dir"] = spike_dir
            result_dict["Name"] = spike_name_only
            result_dict["Tet"] = int(spike_ext)
            result_dict["Unit"] = data.spike.get_unit_no()

            # Do analysis, only reloading the event file when it changes
            if last_stm_name != stm_name:
                event.load(stm_name, 'Axona')
                last_stm_name = stm_name

            graph_data = event.psth(data.spike, bins=bin_size, bound=bin_bound)
            spike_count = data.get_unit_spikes_count()
            result_dict["Num_Spikes"] = spike_count

            # Bin renaming
            for (b, v) in zip(graph_data["all_bins"][:-1], graph_data["psth"]):
                result_dict[str(b)] = v

            # Do plotting
            name = (spike_name_only + "_" + spike_ext + "_" +
                    str(result_dict["Unit"]) + ".png")
            plot_name = os.path.join(dir, "psth_results", name)
            make_dir_if_not_exists(plot_name)
            plot_psth(graph_data, plot_name)
            print("Saved psth to {}".format(plot_name))

        except Exception as e:
            log_exception(e, "During stimulation batch at {}".format(i))

        # Record exactly one (possibly partial) row per unit, whether or not
        # the analysis or the plotting failed part way through.
        dict_list.append(result_dict)

    fname = os.path.join(dir, "psth_results", "psth.csv")
    save_dicts_to_csv(fname, dict_list)
    print("Saved results to {}".format(fname))
Пример #4
0
def compare_two(file1, file2, index):
    """
    Compare the clusters of spike file1 against those of spike file2.

    Every unit of file1 is paired with every unit of file2 and
    ``cluster_similarity`` is evaluated for the pair, giving a BC and a HD
    value. Pairs that raise are logged and stored as NaN.

    Only the HD matrix is written out, one comma separated row per unit of
    file1. The output file name is built from index, for example
    output1.csv if index is 1.

    """
    first = NClust()
    second = NClust()
    first.load(file1, "Axona")
    second.load(file2, "Axona")
    units_a = first.get_unit_list()
    units_b = second.get_unit_list()

    shape = (len(units_a), len(units_b))
    bc_matrix = np.zeros(shape=shape, dtype=np.float32)
    hd_matrix = np.zeros(shape=shape, dtype=np.float32)

    for row, unit_a in enumerate(units_a):
        for col, unit_b in enumerate(units_b):
            try:
                bc, hd = first.cluster_similarity(second, unit_a, unit_b)
                print("({}, {}): BC {:.2f}, HD {:.2f}".format(
                    row + 1, col + 1, bc, hd))
                bc_matrix[row, col] = bc
                hd_matrix[row, col] = hd
            except Exception as e:
                log_exception(e, "at ({}, {})".format(row, col))
                bc_matrix[row, col] = np.nan
                hd_matrix[row, col] = np.nan

    out_filename = "output{}.csv".format(index)

    # The BC matrix is computed above but only the HD matrix is saved.
    with open(out_filename, "w") as f:
        for hd_row in hd_matrix:
            f.write(",".join(str(value) for value in hd_row) + "\n")
def save_dicts_to_csv(filename, in_dicts):
    """
    Save a list of dictionaries to a csv.

    The header starts with the keys of the dictionary that has the most keys
    (the first such dictionary wins ties). Keys that only appear in other
    dictionaries are appended afterwards in the order they are first seen.
    Each dictionary is written as one row, and keys missing from a
    dictionary are left empty in its row.

    Parameters
    ----------
    filename : str
        Path of the csv file to write. It is overwritten if it exists.
    in_dicts : list of dict
        The dictionaries to save. May be empty.

    Returns
    -------
    None

    """
    # find the dict with the most keys
    max_key = []
    for in_dict in in_dicts:
        if len(in_dict) > len(max_key):
            max_key = list(in_dict)

    # add keys the widest dict lacks, so DictWriter never sees an unknown key
    known = set(max_key)
    for in_dict in in_dicts:
        for name in in_dict:
            if name not in known:
                known.add(name)
                max_key.append(name)

    try:
        with open(filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=max_key)
            writer.writeheader()
            for in_dict in in_dicts:
                writer.writerow(in_dict)

    except Exception as e:
        log_exception(e, "When {} saving to csv".format(filename))
Пример #6
0
    def save_dataset(self, path=None, name=None, data=None, create_group=True):
        """
        Store a dataset to a specific path.

        An existing dataset with the same name in the group is deleted
        before the new one is created. Nothing is written if the path or
        the name is missing, or if the group does not exist and
        create_group is False; these cases are only logged.

        Parameters
        ----------
        path : str
            Path of a group in HDF5 file
        name : str
            Name of the new dataset
        data : ndarray or list of numbers
            Data to be stored. None entries in a list are replaced by NaN.
        create_group : bool
            If True, creates a new group if the 'path' is not in the file

        Returns
        -------
        None

        """
        valid = True
        if not path:
            logging.error('Invalid group path specified!')
            valid = False
        if not name:
            logging.error('Please provide a name for the dataset!')
            valid = False
        if not valid:
            # Continuing would only fail later with a less helpful error.
            return

        if path not in self.f and not create_group:
            logging.error('hdf5 file path cannot be created or restored!')
            return

        g = self.f.require_group(path)
        if name in g:
            del g[name]
        # None entries in a list are replaced by NaN before conversion
        if isinstance(data, list):
            data = [np.nan if item is None else item for item in data]
            try:
                data = np.array(data)
            except Exception:
                # Best effort: keep the list and let create_dataset below
                # report the problem if the data can not be stored.
                pass
        try:
            g.create_dataset(name=name, data=data)
        except Exception as e:
            log_exception(e, 'Saving ' + name + ' dataset to hdf5 file')
Пример #7
0
    def get_data_at(self, data_index, unit_index):
        """
        Return an NData object from the given indices.

        When loading on the fly, the most recently loaded object is reused
        if it belongs to data_index; otherwise a new NData is created and
        every file type except "STM" is loaded into it. Errors during
        loading are logged rather than raised.

        Parameters
        ----------
        data_index : int
            The index in the container to return data at.
        unit_index : int
            The unit number to set on the returned data.

        Returns
        -------
        NData or None
            The ndata object at data_index with unit number unit_index.
            None if loading on the fly failed before any object was created.

        """
        result = None
        if self._load_on_fly:
            try:
                if (data_index == self._last_data_pt[0]
                        and (self._last_data_pt[1] is not None)):
                    # Reuse the object loaded by the previous call
                    result = self._last_data_pt[1]
                else:
                    result = NData()
                    for key, vals in self.get_file_dict().items():
                        if key == "STM":
                            continue
                        descriptor = vals[data_index]
                        self._load(key,
                                   descriptor,
                                   idx=data_index,
                                   ndata=result)
                    self._last_data_pt = (data_index, result)
            except Exception as e:
                log_exception(e, "During loading data")
                if result is None:
                    # Nothing was created, so there is nothing to return;
                    # avoids an UnboundLocalError further down.
                    return None
        else:
            result = self.get_data(data_index)
        if len(self.get_units()) > 0:
            result.set_unit_no(self.get_units(data_index)[unit_index])
        return result
Пример #8
0
    def file(self):
        """
        Open the file, and returns the file object.

        Any currently open file is closed first. The file is then opened
        with h5py in mode 'a' and ``self.initialize()`` is called. Errors
        while opening or initializing are logged rather than raised, in
        which case ``self.f`` is returned as ``self.close()`` left it.

        Parameters
        ----------
        None

        Returns
        -------
        object
            h5py file object

        """
        self.close()
        try:
            self.f = h5py.File(self._filename, 'a')
            self.initialize()
        except Exception as e:
            # format() keeps the log call itself from failing when the
            # filename is not a string (for example None).
            log_exception(e, 'Opening hdf file {}'.format(self._filename))

        return self.f
Пример #9
0
def place_cell_summary(
        collection, dpi=150, out_dirname="nc_plots",
        filter_place_cells=False, filter_low_freq=False,
        opt_end="", base_dir=None, output_format="png",
        output=["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"],
        isi_bound=350, isi_bin_length=2, fixed_color=None,
        save_data=False, point_size=None, color_isi=True,
        burst_thresh=5, hd_predict=False):
    """
    Quick Png spatial information summary of each cell in collection.

    Units are processed one by one and their analysis results accumulated.
    When the last unit of a data file is reached, the accumulated results
    are plotted with print_place_cells and saved, and the accumulators are
    reset. Units with fewer than 5 spikes are skipped. Errors on a unit are
    logged and processing continues with the next unit.

    Parameters
    ----------
    collection : NDataCollection
        The collection to plot summaries of.
    dpi : int, default 150
        Dpi of the output figures.
    out_dirname : str, default "nc_plots"
        The relative name of the dir to save pngs to
    filter_place_cells: bool, default False
        Whether to filter out non spatial cells from the plots.
        A unit is rejected if its sparsity is >= 0.3 or
        its coherence is <= 0.55, as reported by data.loc_shuffle.
        Rejected units are plotted to a separate "bad" directory.
    filter_low_freq: bool, default False
        Reject cells whose spike count divided by duration
        is below 0.25 or above 7.
    opt_end : str, default ""
        A string to append to the file output just before the extension
    base_dir : str, default None
        An optional directory to save the files to.
        If None, files are saved next to the spike file.
    output_format : str, default png
        What format to save the output image in.
        For "pdf" and "svg" one figure is saved per unit.
    output : List of str,
        default ["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"]
        Input should be some subset and/or permutation of these.
        The default list is not modified by this function.
    isi_bound: int, default 350
        How long in ms to plot the ISI to
    isi_bin_length: int, default 2
        How long in ms the ISI bins should be
    fixed_color : default None
        Passed on to print_place_cells.
    save_data: bool, default False
        Whether to save out the information used for the plot
    point_size : default None
        Passed on to print_place_cells. If None, dpi / 7 is used.
    color_isi: bool, default True
        Whether the ISI should be black or blue
    burst_thresh: int, default 5
        How long in ms to consider the window for burst to be
    hd_predict: bool, default False
        Whether the head directional graph should be plotted with predicted HD.

    Returns
    -------
    None
    """

    def save_dicts_to_csv(filename, dicts_arr):
        """
        Save the last element of each arr in dicts_arr to file.

        Each key of each such dict becomes one line: the key with spaces
        replaced by underscores, followed by the comma separated value(s).
        Entries of dicts_arr that are None are ignored.
        """
        with open(filename, "w") as f:
            for d_arr in dicts_arr:
                if d_arr is not None:
                    d = d_arr[-1]
                    for k, v in d.items():
                        out_str = k.replace(" ", "_")
                        if isinstance(v, Iterable):
                            if isinstance(v, np.ndarray):
                                v = v.flatten()
                            else:
                                v = np.array(v).flatten()
                            str_arr = [str(x) for x in v]
                            out_str = out_str + "," + ",".join(str_arr)
                        else:
                            out_str += "," + str(v)
                        f.write(out_str + "\n")

    good_placedata = []
    good_graphdata = []
    good_wavedata = []
    good_headdata = []
    good_thetadata = []
    good_isidata = []
    good_units = []
    bad_placedata = []
    bad_graphdata = []
    bad_wavedata = []
    bad_headdata = []
    bad_thetadata = []
    bad_isidata = []
    bad_units = []
    skipped = 0

    if point_size is None:
        point_size = dpi / 7
    for i, data in enumerate(collection):
        # These names are only used for error reporting until they are
        # assigned below. Resetting them per unit keeps the except handlers
        # from raising on unbound names or reporting stale values.
        data_idx = unit_number = spike_name = main_dir = None
        try:
            data_idx, unit_idx = collection._index_to_data_pos(i)
            print("Working on {}, {}".format(unit_idx, len(
                collection.get_units(data_idx)) - 1))
            filename = collection.get_file_dict()["Spike"][data_idx][0]
            unit_number = collection.get_units(data_idx)[unit_idx]
            print("Working on {} unit {}".format(
                filename, unit_number))

            count = data.spike.get_unit_spikes_count()
            # Skip very low count cells
            if count < 5:
                skipped += 1
                print("Skipping as only {} spikes".format(count))
            else:
                duration = data.spike.get_duration()
                good = True

                # Place cell filtering is based on
                # https://www.nature.com/articles/ncomms11824
                # Activity-plasticity of hippocampus place maps
                # Schoenenberger et al, 2016
                if filter_low_freq:
                    if (count / duration) < 0.25 or (count / duration) > 7:
                        print("Reject spike frequency {}".format(
                            count / duration))
                        good = False

                if good and filter_place_cells:
                    skaggs = data.loc_shuffle(nshuff=1)
                    bad_sparsity = skaggs['refSparsity'] >= 0.3
                    bad_cohere = skaggs['refCoherence'] <= 0.55
                    first_str_part = "Accept "

                    if bad_sparsity or bad_cohere:
                        good = False
                        first_str_part = "Reject "

                    print((
                        first_str_part +
                        "Skaggs {:2f}, " +
                        "Sparsity {:2f}, " +
                        "Coherence {:2f}").format(
                        skaggs['refSkaggs'], skaggs['refSparsity'],
                        skaggs['refCoherence']))

                # Pick the accumulators this unit's results go to
                if good:
                    good_units.append(unit_idx)
                    placedata = good_placedata
                    graphdata = good_graphdata
                    wavedata = good_wavedata
                    headdata = good_headdata
                    thetadata = good_thetadata
                    isidata = good_isidata
                else:
                    bad_units.append(unit_idx)
                    placedata = bad_placedata
                    graphdata = bad_graphdata
                    wavedata = bad_wavedata
                    headdata = bad_headdata
                    thetadata = bad_thetadata
                    isidata = bad_isidata

                # Sanity check: reset everything and bail out of this unit
                # if more units were accumulated than the file contains.
                if (
                    (len(bad_units) + len(good_units)) >
                        len(collection.get_units(data_idx))):
                    save_bad = bad_units
                    save_good = good_units
                    good_placedata = []
                    good_graphdata = []
                    good_wavedata = []
                    good_headdata = []
                    good_thetadata = []
                    good_isidata = []
                    good_units = []
                    bad_placedata = []
                    bad_graphdata = []
                    bad_wavedata = []
                    bad_headdata = []
                    bad_thetadata = []
                    bad_isidata = []
                    bad_units = []
                    skipped = 0
                    print("ERROR: Found too many units in the collection")
                    raise Exception(
                        "Accumlated more units than possible " +
                        "bad {} good {} total {}".format(
                            save_bad, save_good, collection.get_units(data_idx)))

                # Analyses that are not requested are passed on as None
                if "Place" in output:
                    placedata.append(data.place())
                else:
                    bad_placedata = None
                    good_placedata = None
                if "LowAC" in output:
                    graphdata.append(data.isi_corr(bins=1, bound=[-10, 10]))
                else:
                    bad_graphdata = None
                    good_graphdata = None
                if "Wave" in output:
                    wavedata.append(data.wave_property())
                else:
                    bad_wavedata = None
                    good_wavedata = None
                if "HD" in output:
                    headdata.append(data.hd_rate())
                else:
                    bad_headdata = None
                    good_headdata = None
                if "Theta" in output:
                    thetadata.append(
                        data.theta_index(
                            bins=2, bound=[-350, 350]))
                else:
                    bad_thetadata = None
                    good_thetadata = None
                if "HighISI" in output:
                    isidata.append(
                        data.isi(bins=int(isi_bound / isi_bin_length),
                                 bound=[0, isi_bound]))
                else:
                    bad_isidata = None
                    good_isidata = None

                if save_data:
                    try:
                        spike_name = os.path.basename(filename)
                        parts = spike_name.split(".")
                        f_dir = os.path.dirname(filename)

                        data_basename = (
                            parts[0] + "_" + parts[1] + "_" +
                            str(unit_number) + opt_end + ".csv")
                        if base_dir is not None:
                            main_dir = base_dir
                            # Flatten the sub-directories into the file name
                            out_base = f_dir[len(base_dir + os.sep):]
                            if len(out_base) != 0:
                                out_base = ("--").join(out_base.split(os.sep))
                                data_basename = out_base + "--" + data_basename
                        else:
                            main_dir = f_dir
                        out_name = os.path.join(
                            main_dir, out_dirname, "data", data_basename)
                        make_dir_if_not_exists(out_name)
                        save_dicts_to_csv(
                            out_name,
                            [placedata, graphdata, wavedata,
                                headdata, thetadata, isidata])
                    except Exception as e:
                        log_exception(
                            e, "Occurred during place cell data saving on" +
                            " {} unit {} name {} in {}".format(
                                data_idx, unit_number, spike_name, main_dir))

            # Save the accumulated information
            if unit_idx == len(collection.get_units(data_idx)) - 1:
                if ((len(bad_units) + len(good_units)) !=
                        len(collection.get_units(data_idx)) - skipped):
                    print("ERROR: Did not cover all units in the collection")
                    print("Good {}, Bad {}, Total {}".format(
                        good_units, bad_units, collection.get_units(data_idx)))
                spike_name = os.path.basename(filename)
                parts = spike_name.split(".")
                f_dir = os.path.dirname(filename)

                out_basename = (
                    parts[0] + "_" + parts[1] + opt_end + "." + output_format)

                if base_dir is not None:
                    main_dir = base_dir
                    # Flatten the sub-directories into the file name
                    out_base = f_dir[len(base_dir + os.sep):]
                    if len(out_base) != 0:
                        out_base = ("--").join(out_base.split(os.sep))
                        out_basename = out_base + "--" + out_basename
                else:
                    main_dir = f_dir

                if filter_place_cells:
                    named_units = [
                        collection.get_units(data_idx)[j]
                        for j in good_units]
                    bad_named_units = [
                        collection.get_units(data_idx)[j]
                        for j in bad_units]
                else:
                    named_units = collection.get_units(data_idx)
                    bad_named_units = []

                # Save figures one by one if using pdf or svg
                one_by_one = (output_format == "pdf") or (
                    output_format == "svg")

                # Split once so per-unit names work for any extension length
                out_stem, out_ext = os.path.splitext(out_basename)

                if len(named_units) > 0:
                    if filter_place_cells:
                        print((
                            "Plotting summary for {} " +
                            "spatial units {}").format(
                            spike_name, named_units))
                    else:
                        print((
                            "Plotting summary for {} " +
                            "units {}").format(
                            spike_name, named_units))

                    fig = print_place_cells(
                        len(named_units), cols=len(output),
                        placedata=good_placedata, graphdata=good_graphdata,
                        wavedata=good_wavedata, headdata=good_headdata,
                        thetadata=good_thetadata, isidata=good_isidata,
                        size_multiplier=4, point_size=point_size,
                        units=named_units, fixed_color=fixed_color,
                        output=output, color_isi=color_isi,
                        burst_ms=burst_thresh, one_by_one=one_by_one,
                        raster=one_by_one, hd_predict=hd_predict)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = named_units[k]
                            iname = (
                                out_stem + "_" +
                                str(unit_number) + out_ext)
                            if filter_low_freq or filter_place_cells:
                                out_name = os.path.join(
                                    main_dir, out_dirname, "good", iname)
                            else:
                                out_name = os.path.join(
                                    main_dir, out_dirname, iname)

                            print("Saving place cell figure to {}".format(
                                out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi, format=output_format)

                    else:
                        if filter_low_freq or filter_place_cells:
                            out_name = os.path.join(
                                main_dir, out_dirname, "good", out_basename)
                        else:
                            out_name = os.path.join(
                                main_dir, out_dirname, out_basename)
                        print("Saving place cell figure to {}".format(
                            out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi, format=output_format)

                    close("all")
                    gc.collect()

                if len(bad_named_units) > 0:
                    print((
                        "Plotting bad summary for {} " +
                        "non-spatial units {}").format(
                        spike_name, bad_named_units))
                    fig = print_place_cells(
                        len(bad_named_units), cols=len(output),
                        placedata=bad_placedata, graphdata=bad_graphdata,
                        wavedata=bad_wavedata, headdata=bad_headdata,
                        thetadata=bad_thetadata, isidata=bad_isidata,
                        size_multiplier=4, point_size=point_size,
                        units=bad_named_units, fixed_color=fixed_color,
                        output=output, color_isi=color_isi,
                        burst_ms=burst_thresh, one_by_one=one_by_one,
                        raster=one_by_one, hd_predict=hd_predict)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = bad_named_units[k]
                            iname = (
                                out_stem + "_" +
                                str(unit_number) + out_ext)
                            out_name = os.path.join(
                                main_dir, out_dirname, "bad", iname)

                            print("Saving place cell figure to {}".format(
                                out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi,
                                      format=output_format)

                    else:
                        out_name = os.path.join(
                            main_dir, out_dirname, "bad", out_basename)
                        print("Saving place cell figure to {}".format(
                            out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi,
                                    format=output_format)

                    close("all")
                    gc.collect()

                # Start accumulating afresh for the next data file
                good_placedata = []
                good_graphdata = []
                good_wavedata = []
                good_headdata = []
                good_thetadata = []
                good_isidata = []
                good_units = []
                bad_placedata = []
                bad_graphdata = []
                bad_wavedata = []
                bad_headdata = []
                bad_thetadata = []
                bad_isidata = []
                bad_units = []
                skipped = 0

        except Exception as e:
            log_exception(
                e, "Occurred during place cell summary on data" +
                   " {} unit {} name {} in {}".format(
                       data_idx, unit_number, spike_name, main_dir))
    return
def place_cell_summary(
        collection,
        dpi=150,
        out_dirname="nc_plots",
        filter_place_cells=False,
        filter_low_freq=False,
        opt_end="",
        base_dir=None,
        output_format="png",
        output=["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"],
        isi_bound=350,
        isi_bin_length=2,
        fixed_color=None,
        save_data=False,
        point_size=None,
        color_isi=True,
        burst_thresh=5,
        hd_predict=False,
        one_by_one="auto"):
    """
    Perform spatial information summary of each cell in collection.

    The function is named as place_cell_summary as it can be a
    quick visual way to look for place cells.
    However, it can be used to look at other spatial properties,
    such as head directionality.

    The output image (any matplotlib format is supported) contains
    the information from and in the order of the output argument.

    Units with fewer than 5 spikes are skipped. The remaining units of a
    spike file are accumulated and plotted once the last unit of that file
    has been processed. Units accepted by the active filters are saved
    under a "good" sub-directory and rejected ones under "bad"; with no
    filter active, figures are saved directly in out_dirname.
    An exception raised while handling a unit is logged and the loop
    moves on to the next unit.

    Parameters
    ----------
    collection : NDataCollection
        The collection to plot spatial summaries of.
    dpi : int, default 150
        Dpi of the output figures if the output_format supports dpi.
    out_dirname : str, default "nc_plots"
        The relative name of the directory to save output to.
    filter_place_cells : bool, default False
        Whether to filter out non spatial place cells from the plots.
        A unit is considered a spatial place cell if:
        Sparsity < 0.3 and Coherence > 0.55.
        Recommended filter_low_freq=True if this flag is True.
        See https://www.nature.com/articles/ncomms11824.
    filter_low_freq : bool, default False
        Reject cells whose spike count / recording duration is below 0.25
        or above 7 (Hz, assuming the duration is reported in seconds).
    opt_end : str, default ""
        A string to append to the file output just before the extension
        Can be used if doing multiple runs to avoid overwriting output files.
    base_dir : str, default None
        An optional directory to save all the files to.
        If not provided, the files will save to the location of the input data.
    output_format : str, default "png"
        What format to save the output image in.
    output : List of str,
        default ["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"]
        Provided argument should be some subset and/or permutation of these.
        The list is only read here, never modified.
    isi_bound : int, default 350
        How long in ms to plot the ISI for.
    isi_bin_length : int, default 2
        How long in ms the ISI bins should be.
    fixed_color : str, default None
        If provided, will plot all units with this color instead of auto color.
        Can be any matplotlib compatible color.
    save_data : bool, default False
        Whether to save out the information used for the plot.
        If True, will produce a csv with all the data used for plotting.
    point_size : int, default None
        If provided, the size of the matplotlib points to use in spatial plots.
        If None, the point size is dpi / 7.
    color_isi : bool, default True
        Whether the ISI should be black (False) or blue (True).
    burst_thresh : int, default 5
        How long in ms to consider the window for burst to be
    hd_predict : bool, default False
        Whether the head directional graph should be plotted with predicted HD.
    one_by_one : bool or str, default "auto"
        Whether to plot all units in each tetrode file to single plot. Options:
        True   - plot each unit to a separate file.
        False  - plot all units on a tetrode to the same file.
        "auto" - Determine T/F based on output_format. pdf and svg are True.

    Returns
    -------
    None

    """
    # (name in ``output``, print_place_cells keyword, analysis to run).
    # The tuple order is the order the analyses are run in and the order
    # their results are written to the csv when save_data is True.
    analyses = (
        ("Place", "placedata", lambda d: d.place()),
        ("LowAC", "graphdata",
         lambda d: d.isi_corr(bins=1, bound=[-10, 10])),
        ("Wave", "wavedata", lambda d: d.wave_property()),
        ("HD", "headdata", lambda d: d.hd_rate()),
        ("Theta", "thetadata",
         lambda d: d.theta_index(bins=2, bound=[-350, 350])),
        ("HighISI", "isidata",
         lambda d: d.isi(bins=int(isi_bound / isi_bin_length),
                         bound=[0, isi_bound])),
    )

    def new_store():
        """Return empty accumulators; outputs not requested stay None."""
        return {
            key: ([] if name in output else None)
            for name, key, _ in analyses
        }

    # This function is used to save out the plotting data to a csv.
    def save_dicts_to_csv(filename, dicts_arr):
        """Save the last element of each arr in dicts_arr to file."""
        with open(filename, "w") as f:
            for d_arr in dicts_arr:
                # None: the output was not requested. Empty: nothing has
                # been accumulated, so there is no last element to write.
                if d_arr is None or len(d_arr) == 0:
                    continue
                for k, v in d_arr[-1].items():
                    out_str = k.replace(" ", "_")
                    if isinstance(v, Iterable):
                        flat = np.asarray(v).flatten()
                        out_str += "," + ",".join([str(x) for x in flat])
                    else:
                        out_str += "," + str(v)
                    f.write(out_str + "\n")

    # Set up the accumulators to hold the data for plotting.
    good_store, bad_store = new_store(), new_store()
    good_units, bad_units = [], []
    skipped = 0

    if point_size is None:
        point_size = dpi / 7

    # Save figures one by one if using pdf or svg
    if one_by_one == "auto":
        one_by_one = output_format in ("pdf", "svg")

    filtering = filter_low_freq or filter_place_cells
    if filter_place_cells:
        good_plot_msg = "Plotting summary for {} spatial units {}"
    else:
        good_plot_msg = "Plotting summary for {} units {}"
    bad_plot_msg = "Plotting bad summary for {} non-spatial units {}"

    for i, data in enumerate(collection):
        # Bound up front so the error handler below can always format its
        # message, and never reports values left over from another unit.
        data_idx = unit_number = spike_name = main_dir = None
        try:
            data_idx, unit_idx = collection._index_to_data_pos(i)
            filename = collection.get_file_dict()["Spike"][data_idx][0]
            file_units = collection.get_units(data_idx)
            unit_number = file_units[unit_idx]
            logging.info("Working on {} unit {} out of {}".format(
                filename, unit_number, len(file_units)))

            # Naming pieces shared by the csv and the figure files.
            spike_name = os.path.basename(filename)
            final_bname, final_ext = os.path.splitext(spike_name)
            final_ext = final_ext[1:]
            f_dir = os.path.dirname(filename)
            name_prefix = ""
            if base_dir is not None:
                main_dir = base_dir
                # Flatten the sub-directory below base_dir into the name.
                out_base = f_dir[len(base_dir + os.sep):]
                if len(out_base) != 0:
                    name_prefix = "--".join(out_base.split(os.sep)) + "--"
            else:
                main_dir = f_dir
            out_root = os.path.join(main_dir, out_dirname)

            count = data.spike.get_unit_spikes_count()
            # Skip very low count cells
            if count < 5:
                skipped += 1
                logging.warning("Skipping as only {} spikes".format(count))
            else:
                duration = data.spike.get_duration()
                good = True

                # Place cell filtering is based on
                # https://www.nature.com/articles/ncomms11824
                # Activity-plasticity of hippocampus place maps
                # Schoenenberger et al, 2016
                if filter_low_freq:
                    rate = count / duration
                    if rate < 0.25 or rate > 7:
                        logging.info(
                            "Reject spike frequency {}".format(rate))
                        good = False

                if good and filter_place_cells:
                    skaggs = data.loc_shuffle(nshuff=1)
                    bad_sparsity = skaggs['refSparsity'] >= 0.3
                    bad_cohere = skaggs['refCoherence'] <= 0.55
                    if bad_sparsity or bad_cohere:
                        good = False
                    first_str_part = "Accept " if good else "Reject "

                    logging.info(
                        (first_str_part + "Skaggs {:2f}, " +
                         "Sparsity {:2f}, " + "Coherence {:2f}").format(
                             skaggs['refSkaggs'], skaggs['refSparsity'],
                             skaggs['refCoherence']))

                if good:
                    good_units.append(unit_idx)
                    unit_store = good_store
                else:
                    bad_units.append(unit_idx)
                    unit_store = bad_store

                # In case somehow there is a double counting bug
                if len(bad_units) + len(good_units) > len(file_units):
                    save_bad, save_good = bad_units, good_units
                    good_store, bad_store = new_store(), new_store()
                    good_units, bad_units = [], []
                    skipped = 0
                    raise Exception("Accumlated more units than possible " +
                                    "bad {} good {} total {}".format(
                                        save_bad, save_good, file_units))

                for _, key, analyse in analyses:
                    if unit_store[key] is not None:
                        unit_store[key].append(analyse(data))

                if save_data:
                    try:
                        data_basename = (
                            name_prefix + final_bname + "_" + final_ext +
                            "_" + str(unit_number) + opt_end + ".csv")
                        out_name = os.path.join(
                            out_root, "data", data_basename)
                        make_dir_if_not_exists(out_name)
                        save_dicts_to_csv(
                            out_name,
                            [unit_store[key] for _, key, _ in analyses])
                    except Exception as e:
                        log_exception(
                            e, "Occurred during place cell data saving on" +
                            " {} unit {} name {} in {}".format(
                                data_idx, unit_number, spike_name, main_dir))

            # Save the accumulated information
            if unit_idx == len(file_units) - 1:
                if (len(bad_units) + len(good_units) !=
                        len(file_units) - skipped):
                    logging.error("Good {}, Bad {}, Total {}".format(
                        good_units, bad_units, file_units))
                    raise ValueError(
                        "Did not cover all units in the collection")

                # Keep stem and extension apart so per-unit names work for
                # extensions of any length (e.g. "jpeg", "ps").
                fig_stem = (name_prefix + final_bname + "_" + final_ext +
                            opt_end)
                fig_ext = "." + output_format
                out_basename = fig_stem + fig_ext

                # Always name units from the indices actually accumulated,
                # so the unit list matches the data lists even when units
                # were skipped or rejected by filter_low_freq alone.
                named_units = [file_units[j] for j in good_units]
                bad_named_units = [file_units[j] for j in bad_units]

                if filtering:
                    good_dir = os.path.join(out_root, "good")
                else:
                    good_dir = out_root
                plot_jobs = (
                    (named_units, good_store, good_dir, good_plot_msg,
                     "Saving place cell figure to {}"),
                    (bad_named_units, bad_store,
                     os.path.join(out_root, "bad"), bad_plot_msg,
                     "Saving place cell fig to {}"),
                )

                for (plot_units, plot_store, plot_dir, plot_msg,
                     save_msg) in plot_jobs:
                    if not plot_units:
                        continue
                    logging.info(plot_msg.format(spike_name, plot_units))

                    fig = print_place_cells(len(plot_units),
                                            cols=len(output),
                                            size_multiplier=4,
                                            point_size=point_size,
                                            units=plot_units,
                                            fixed_color=fixed_color,
                                            output=output,
                                            color_isi=color_isi,
                                            burst_ms=burst_thresh,
                                            one_by_one=one_by_one,
                                            raster=one_by_one,
                                            hd_predict=hd_predict,
                                            **plot_store)

                    if one_by_one:
                        for k, f in enumerate(fig):
                            unit_number = plot_units[k]
                            iname = (fig_stem + "_" + str(unit_number) +
                                     fig_ext)
                            out_name = os.path.join(plot_dir, iname)
                            logging.info(save_msg.format(out_name))
                            make_dir_if_not_exists(out_name)
                            f.savefig(out_name, dpi=dpi, format=output_format)
                    else:
                        out_name = os.path.join(plot_dir, out_basename)
                        logging.info(save_msg.format(out_name))
                        make_dir_if_not_exists(out_name)
                        fig.savefig(out_name, dpi=dpi, format=output_format)

                    close("all")
                    gc.collect()

                # Start afresh for the next spike file.
                good_store, bad_store = new_store(), new_store()
                good_units, bad_units = [], []
                skipped = 0

        except Exception as e:
            log_exception(
                e, "Occurred during place cell summary on data" +
                " {} unit {} name {} in {}".format(data_idx, unit_number,
                                                   spike_name, main_dir))
    return
Пример #11
0
def cell_classification_stats(in_dir,
                              container,
                              out_name,
                              should_plot=False,
                              opt_end="",
                              output_spaces=True):
    """
    Compute a csv of cell stats for each unit in a container

    A failure on one unit is printed and logged, and the remaining units
    are still processed.

    Params
    ------
    in_dir - the data output/input location
    container - the NDataContainer object to get stats for
    out_name - path of the csv file the statistics are saved to
    should_plot - whether to save some plots for this
    opt_end - string placed just before the extension of plot file names
    output_spaces - if False, results are requested with
        spaces_to_underscores=True
    """
    _results = []
    spike_names = container.get_file_dict()["Spike"]
    for i, ndata in enumerate(container):
        # Bound before the try so the error handler can always report,
        # and never shows the file or unit of a previous iteration.
        name = None
        note_dict = oDict()
        try:
            data_idx, unit_idx = container._index_to_data_pos(i)
            name = spike_names[data_idx][0]
            parts = os.path.basename(name).split(".")

            # Setup up identifier information
            dir_t = os.path.dirname(name)
            note_dict["Index"] = i
            note_dict["FullDir"] = dir_t
            if dir_t != in_dir:
                note_dict["RelDir"] = dir_t[len(in_dir + os.sep):]
            else:
                note_dict["RelDir"] = ""
            note_dict["Recording"] = parts[0]
            # The last dot-separated part of the file name is the tetrode
            note_dict["Tetrode"] = int(parts[-1])
            note_dict["Unit"] = ndata.get_unit_no()
            ndata.update_results(note_dict)

            # Calculate cell properties. Most return values are unused;
            # presumably each call records its statistics on ndata, which
            # get_results() then collects.
            ndata.wave_property()
            ndata.place()
            ndata.hd_rate()
            ndata.grid()
            ndata.border()
            ndata.multiple_regression()
            ndata.isi()
            ndata.burst(burst_thresh=6)
            phase_dist = ndata.phase_dist()
            ndata.theta_index()
            ndata.bandpower_ratio([5, 11], [1.5, 4],
                                  1.6,
                                  relative=True,
                                  first_name="Theta",
                                  second_name="Delta")
            result = copy(
                ndata.get_results(spaces_to_underscores=not output_spaces))
            _results.append(result)

            if should_plot:
                plot_loc = os.path.join(
                    in_dir, "nc_plots", parts[0] + "_" + parts[-1] + "_" +
                    str(ndata.get_unit_no()) + "_phase" + opt_end + ".png")
                make_dir_if_not_exists(plot_loc)
                # Only the second of the three returned figures is saved
                _, fig2, _ = nc_plot.spike_phase(phase_dist)
                fig2.savefig(plot_loc)
                plt.close("all")

                # One LFP spectrum per spike file, made on its last unit
                if unit_idx == len(container.get_units(data_idx)) - 1:
                    plot_loc = os.path.join(
                        in_dir, "nc_plots",
                        parts[0] + "_lfp" + opt_end + ".png")
                    make_dir_if_not_exists(plot_loc)

                    lfp_spectrum = ndata.spectrum()
                    fig = nc_plot.lfp_spectrum(lfp_spectrum)
                    fig.savefig(plot_loc)
                    plt.close(fig)

        except Exception as e:
            failed_file = os.path.basename(name) if name else "NA"
            failed_unit = note_dict.get("Unit", "NA")
            print("WARNING: Failed to analyse {} unit {}".format(
                failed_file, failed_unit))
            log_exception(
                e, "Failed on {} unit {}".format(failed_file, failed_unit))

    # Save the cell statistics
    make_dir_if_not_exists(out_name)
    save_dicts_to_csv(out_name, _results)
    _results.clear()
def cell_classification_stats(
        in_dir, container, out_name,
        should_plot=False, opt_end="", output_spaces=True,
        good_cells=None):
    """
    Compute a csv of cell stats for each unit in a container

    A failure on one unit is printed and logged, and the remaining units
    are still processed.

    Params
    ------
    in_dir - the data output/input location
    container - the NDataContainer object to get stats for
    out_name - path of the csv file the statistics are saved to
    should_plot - accepted but not used by this function
    opt_end - accepted but not used by this function
    output_spaces - if False, results are requested with
        spaces_to_underscores=True
    good_cells - optional collection of
        [spike file path relative to in_dir, unit number] pairs;
        when given, units whose pair is not in it are not analysed
    """
    _results = []
    spike_names = container.get_file_dict()["Spike"]
    overall_count = 0
    # Indexed rather than iterated so container[i] is only accessed for
    # the units that pass the good_cells check below.
    for i in range(len(container)):
        # Bound before the try so the error handler can always report,
        # and never shows the file or unit of a previous iteration.
        name = None
        note_dict = oDict()
        try:
            data_idx, unit_idx = container._index_to_data_pos(i)
            name = spike_names[data_idx][0]
            parts = os.path.basename(name).split(".")
            dir_t = os.path.dirname(name)
            rel_dir = dir_t[len(in_dir + os.sep):]
            o_name = os.path.join(rel_dir, parts[0])

            # Setup up identifier information
            note_dict["Index"] = i
            note_dict["FullDir"] = dir_t
            note_dict["RelDir"] = rel_dir if dir_t != in_dir else ""
            note_dict["Recording"] = parts[0]
            # The last dot-separated part of the file name is the tetrode
            note_dict["Tetrode"] = int(parts[-1])

            if good_cells is not None:
                check = [
                    os.path.normpath(name[len(in_dir + os.sep):]),
                    container.get_units(data_idx)[unit_idx]]
                if check not in good_cells:
                    continue

            ndata = container[i]
            overall_count += 1
            print("Working on unit {} of {}: {}, T{}, U{}".format(
                i + 1, len(container), o_name, parts[-1], ndata.get_unit_no()))

            note_dict["Unit"] = ndata.get_unit_no()
            ndata.update_results(note_dict)

            # Calculate cell properties. The return values are unused;
            # presumably each call records its statistics on ndata, which
            # get_results() then collects.
            ndata.wave_property()
            ndata.place()
            ndata.isi()
            ndata.burst(burst_thresh=6)
            ndata.theta_index()
            ndata._results["IsPyramidal"] = cell_type(ndata)
            result = copy(ndata.get_results(
                spaces_to_underscores=not output_spaces))
            _results.append(result)

        except Exception as e:
            failed_file = os.path.basename(name) if name else "NA"
            to_out = note_dict.get("Unit", "NA")
            print("WARNING: Failed to analyse {} unit {}".format(
                failed_file, to_out))
            log_exception(e, "Failed on {} unit {}".format(
                failed_file, to_out))

    # Save the cell statistics
    make_dir_if_not_exists(out_name)
    print("Analysed {} cells in total".format(overall_count))
    save_dicts_to_csv(out_name, _results)
    _results.clear()