def save_object(self, obj=None):
    """
    Store a NeuroChaT dataset to the HDF5 file.

    The type name reported by the object selects the writer: the method
    ``save_<type>`` is looked up on this instance and called with the
    object. Nothing is written unless ``obj.get_filename()`` names an
    existing file. Failures are reported through ``log_exception`` and are
    not raised to the caller.

    Parameters
    ----------
    obj
        One of the NeuroChaT data types. It must provide ``get_type()``
        and ``get_filename()``.

    Returns
    -------
    None

    """
    try:
        obj_type = obj.get_type()
    except Exception as e:
        log_exception(e, 'Object passed is not a neurochat data type')
        # Without a type there is no writer to dispatch to; stop here
        # instead of failing a second time on the unbound name below.
        return

    try:
        if os.path.isfile(obj.get_filename()):
            fun = getattr(self, 'save_' + obj_type)
            fun(obj)
    except Exception as e:
        log_exception(e, 'Saving hdf5 dataset')
def save_dicts_to_csv(filename, in_dicts):
    """
    Save a list of dictionaries to a csv.

    The header is the union of the keys of all dictionaries: the keys of
    the widest dictionary come first (the first one on ties), followed by
    any remaining keys in the order they are first seen. Each dictionary is
    written as one row; keys it does not have are left blank. An empty
    list produces a file holding only an empty header line.

    Errors raised while writing are reported through ``log_exception``
    and not re-raised.

    Parameters
    ----------
    filename : str
        Path of the csv file to write (overwritten if it exists).
    in_dicts : list of dict
        The rows to save.

    Returns
    -------
    None

    """
    # Copy the keys into a list: a dict_keys view cannot be appended to.
    fieldnames = list(max(in_dicts, key=len, default={}))

    # Add keys the widest dict is missing, keeping first-seen order.
    seen = set(fieldnames)
    for in_dict in in_dicts:
        for name in in_dict:
            if name not in seen:
                seen.add(name)
                fieldnames.append(name)

    try:
        with open(filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for in_dict in in_dicts:
                writer.writerow(in_dict)
    except Exception as e:
        log_exception(e, "When {} saving to csv".format(filename))
def main(dir, bin_size, bin_bound):
    """
    Compute and plot a PSTH for every unit found under a directory.

    Axona files are collected recursively from ``dir``. For each unit the
    PSTH against the matching stimulation file is plotted to
    ``<dir>/psth_results/<name>_<tetrode>_<unit>.png`` and one row of
    results is collected. All rows are finally written to
    ``<dir>/psth_results/psth.csv``.

    A failure on one unit is logged and does not stop the batch; whatever
    was gathered for that unit before the failure is still written as a
    (partial) row.

    Parameters
    ----------
    dir : str
        Directory to search for recordings and to save results under.
    bin_size
        Passed to ``NEvent.psth`` as ``bins``.
    bin_bound
        Passed to ``NEvent.psth`` as ``bound``.

    Returns
    -------
    None

    """
    container = NDataContainer(load_on_fly=True)
    container.add_axona_files_from_dir(dir, recursive=True)
    container.setup()
    print(container.string_repr(True))

    event = NEvent()
    last_stm_name = None
    dict_list = []

    for i in range(len(container)):
        # Bound before the try so the handler can always refer to it.
        result_dict = OrderedDict()
        recorded = False
        try:
            data_index, _ = container._index_to_data_pos(i)
            stm_name = container.get_file_dict("STM")[data_index][0]
            data = container[i]

            # Add extra keys
            spike_file = data.spike.get_filename()
            spike_dir = os.path.dirname(spike_file)
            spike_name = os.path.basename(spike_file)
            spike_name_only, spike_ext = os.path.splitext(spike_name)
            spike_ext = spike_ext[1:]
            result_dict["Dir"] = spike_dir
            result_dict["Name"] = spike_name_only
            result_dict["Tet"] = int(spike_ext)
            result_dict["Unit"] = data.spike.get_unit_no()

            # Do analysis; the event file is only reloaded when it changes
            if last_stm_name != stm_name:
                event.load(stm_name, 'Axona')
                last_stm_name = stm_name
            graph_data = event.psth(data.spike, bins=bin_size, bound=bin_bound)
            spike_count = data.get_unit_spikes_count()
            result_dict["Num_Spikes"] = spike_count

            # Bin renaming: one column per bin, named by the bin's left edge
            for (b, v) in zip(graph_data["all_bins"][:-1], graph_data["psth"]):
                result_dict[str(b)] = v
            dict_list.append(result_dict)
            recorded = True

            # Do plotting
            name = (spike_name_only + "_" + spike_ext + "_" +
                    str(result_dict["Unit"]) + ".png")
            plot_name = os.path.join(dir, "psth_results", name)
            make_dir_if_not_exists(plot_name)
            plot_psth(graph_data, plot_name)
            print("Saved psth to {}".format(plot_name))

        except Exception as e:
            log_exception(e, "During stimulation batch at {}".format(i))
            # Keep the partial row, but never write the same row twice when
            # the failure happened after it was recorded (e.g. in plotting).
            if not recorded:
                dict_list.append(result_dict)

    fname = os.path.join(dir, "psth_results", "psth.csv")
    # The folder does not exist yet if no plot was ever saved.
    make_dir_if_not_exists(fname)
    save_dicts_to_csv(fname, dict_list)
    print("Saved results to {}".format(fname))
def compare_two(file1, file2, index):
    """
    Compare the clusters of spike file1 against those of spike file2.

    Every unit of file1 is compared with every unit of file2 through
    ``NClust.cluster_similarity``, which yields a BC and an HD value per
    pair. A pair that fails is logged and recorded as NaN. Only the HD
    matrix is written out (rows follow file1 units, columns file2 units);
    the BC matrix is filled in but not saved.

    index is used for file naming, and saves to:
    output1.csv if index is 1, for example.
    """
    first = NClust()
    second = NClust()
    first.load(file1, "Axona")
    second.load(file2, "Axona")

    first_units = first.get_unit_list()
    second_units = second.get_unit_list()
    shape = (len(first_units), len(second_units))
    bc_matrix = np.zeros(shape=shape, dtype=np.float32)
    hd_matrix = np.zeros(shape=shape, dtype=np.float32)

    # Walk all (row, column) pairs in row-major order.
    for i, j in np.ndindex(*shape):
        try:
            bc, hd = first.cluster_similarity(
                second, first_units[i], second_units[j])
            print("({}, {}): BC {:.2f}, HD {:.2f}".format(
                i + 1, j + 1, bc, hd))
            bc_matrix[i, j], hd_matrix[i, j] = bc, hd
        except Exception as e:
            log_exception(e, "at ({}, {})".format(i, j))
            bc_matrix[i, j], hd_matrix[i, j] = np.nan, np.nan

    out_filename = "output{}.csv".format(index)
    with open(out_filename, "w") as f:
        # Save the HD, one comma separated line per unit of file1
        for hd_row in hd_matrix:
            f.write(",".join(str(value) for value in hd_row) + "\n")
def save_dicts_to_csv(filename, in_dicts):
    """
    Save a list of dictionaries to a csv, one dictionary per row.

    The header is the union of all keys: the keys of the widest dictionary
    first (the first one on ties), then any other keys in the order they
    are first seen. Keys missing from a dictionary are left blank in its
    row. An empty list produces a file holding only an empty header line.

    Errors raised while writing are reported through ``log_exception``
    and not re-raised.

    Parameters
    ----------
    filename : str
        Path of the csv file to write (overwritten if it exists).
    in_dicts : list of dict
        The rows to save.

    Returns
    -------
    None

    """
    # find the dict with the most keys
    fieldnames = list(max(in_dicts, key=len, default={}))

    # DictWriter rejects rows holding keys outside the header, so collect
    # every key that the widest dict does not already provide.
    seen = set(fieldnames)
    for in_dict in in_dicts:
        for name in in_dict:
            if name not in seen:
                seen.add(name)
                fieldnames.append(name)

    try:
        with open(filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for in_dict in in_dicts:
                writer.writerow(in_dict)
    except Exception as e:
        log_exception(e, "When {} saving to csv".format(filename))
def save_dataset(self, path=None, name=None, data=None, create_group=True):
    """
    Store a dataset to a specific path.

    An existing dataset of the same name in the group is replaced. ``None``
    items of a list are stored as NaN. Invalid arguments are logged and
    nothing is written; a failure while creating the dataset is reported
    through ``log_exception`` and not raised.

    Parameters
    ----------
    path : str
        Path of a group in HDF5 file
    name : str
        Name of the new dataset
    data : ndarray or list of numbers
        Data to be stored
    create_group : bool
        If True, creates a new group if the 'path' is not in the file

    Returns
    -------
    None

    """
    if not path:
        logging.error('Invalid group path specified!')
        return
    if not name:
        logging.error('Please provide a name for the dataset!')
        return
    if path not in self.f and not create_group:
        logging.error('hdf5 file path cannot be created or restored!')
        return

    g = self.f.require_group(path)
    if name in g:
        # Replace rather than fail on an already existing dataset
        del g[name]

    # This conditional restricts the None data to store, need to change
    if isinstance(data, list):
        data = [np.nan if item is None else item for item in data]
    try:
        data = np.array(data)
    except Exception:
        # Best effort: hand the data over unchanged if it cannot be
        # converted; create_dataset below reports any real problem.
        pass

    try:
        g.create_dataset(name=name, data=data)
    except Exception as e:
        log_exception(e, 'Saving ' + name + ' dataset to hdf5 file')
def get_data_at(self, data_index, unit_index):
    """
    Fetch the NData object for one unit of the container.

    When the container loads on the fly, the most recently loaded object
    is reused if it belongs to the same data index; otherwise a fresh
    NData is filled from every file type except "STM" and remembered as
    the latest one. Loading errors are logged, not raised. Without
    on-the-fly loading the stored object is returned.

    Parameters
    ----------
    data_index : int
        The index in the container to return data at.
    unit_index : int
        Position in the unit list of that data index; the unit number
        found there is set on the returned data.

    Returns
    -------
    NData
        The ndata object at data_index with its unit number set, when the
        container holds any units.

    """
    if not self._load_on_fly:
        result = self.get_data(data_index)
    else:
        try:
            cache_hit = (
                data_index == self._last_data_pt[0]
                and self._last_data_pt[1] is not None)
            if cache_hit:
                result = self._last_data_pt[1]
            else:
                result = NData()
                file_dict = self.get_file_dict()
                for file_type, descriptors in file_dict.items():
                    # Stimulation files are not loaded into the NData
                    if file_type != "STM":
                        self._load(
                            file_type, descriptors[data_index],
                            idx=data_index, ndata=result)
                self._last_data_pt = (data_index, result)
        except Exception as e:
            log_exception(e, "During loading data")

    has_units = len(self.get_units()) > 0
    if has_units:
        unit_no = self.get_units(data_index)[unit_index]
        result.set_unit_no(unit_no)
    return result
def file(self):
    """
    Open the file, and returns the file object.

    Any previously open handle is closed first. The file is opened in
    append mode ('a') and ``initialize()`` is called on success. A failure
    is reported through ``log_exception`` instead of being raised, in
    which case ``self.f`` is whatever ``close()`` left behind.

    Parameters
    ----------
    None

    Returns
    -------
    object
        h5py file object

    """
    self.close()
    try:
        self.f = h5py.File(self._filename, 'a')
        self.initialize()
    except Exception as e:
        # format() keeps the handler itself from failing on a filename
        # that is not a str, and separates the name from the message.
        log_exception(e, 'Opening hdf file {}'.format(self._filename))

    return self.f
def place_cell_summary(
        collection, dpi=150, out_dirname="nc_plots",
        filter_place_cells=False, filter_low_freq=False,
        opt_end="", base_dir=None, output_format="png",
        output=["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"],
        isi_bound=350, isi_bin_length=2, fixed_color=None,
        save_data=False, point_size=None, color_isi=True,
        burst_thresh=5, hd_predict=False):
    """
    Quick spatial information summary figure of each cell in collection.

    Units are analysed one by one and gathered per spike file. When the
    last unit of a file has been handled, the gathered units are plotted
    with ``print_place_cells`` and the figure(s) are saved. An error on
    one unit is logged and the loop continues; such a unit is left out of
    the figure.

    Parameters
    ----------
    collection : NDataCollection
        The collection to plot summaries of.
    dpi : int, default 150
        Dpi of the output figures.
    out_dirname : str, default "nc_plots"
        The relative name of the dir to save figures to.
    filter_place_cells : bool, default False
        Whether to separate non spatial cells from the plots. A unit is
        rejected if its sparsity is >= 0.3 or its coherence is <= 0.55.
        Rejected units are plotted to the "bad" sub directory.
    filter_low_freq : bool, default False
        Reject units whose spike count / duration is below 0.25 or
        above 7.
    opt_end : str, default ""
        A string to append to the file output just before the extension.
    base_dir : str, default None
        An optional directory to save the files to. The location of the
        spike file relative to it is folded into the file name.
    output_format : str, default "png"
        What format to save the output image in. "pdf" and "svg" save one
        figure per unit.
    output : List of str, default
        ["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"]
        Input should be some subset and/or permutation of these.
        The list is only read, never modified.
    isi_bound : int, default 350
        How long in ms to plot the ISI to.
    isi_bin_length : int, default 2
        How long in ms the ISI bins should be.
    fixed_color : default None
        Passed through to print_place_cells.
    save_data : bool, default False
        Whether to save out the information used for the plot, as one csv
        per unit in the "data" sub directory.
    point_size : default None
        Passed through to print_place_cells; dpi / 7 if None.
    color_isi : bool, default True
        Whether the ISI should be black or blue.
    burst_thresh : int, default 5
        How long in ms to consider the window for burst to be.
    hd_predict : bool, default False
        Whether the head directional graph should be plotted with
        predicted HD.

    Returns
    -------
    None

    """
    # (output label, print_place_cells keyword, analysis), listed in the
    # order the analyses are run. "Path" needs no analysis of its own here.
    analyses = (
        ("Place", "placedata", lambda d: d.place()),
        ("LowAC", "graphdata",
         lambda d: d.isi_corr(bins=1, bound=[-10, 10])),
        ("Wave", "wavedata", lambda d: d.wave_property()),
        ("HD", "headdata", lambda d: d.hd_rate()),
        ("Theta", "thetadata",
         lambda d: d.theta_index(bins=2, bound=[-350, 350])),
        ("HighISI", "isidata",
         lambda d: d.isi(bins=int(isi_bound / isi_bin_length),
                         bound=[0, isi_bound])),
    )

    def new_accumulator():
        """Return empty storage for one group of units of one file."""
        # An analysis that was not requested is passed on as None.
        acc = {key: ([] if label in output else None)
               for label, key, _ in analyses}
        acc["units"] = []
        return acc

    def save_dicts_to_csv(filename, dicts_arr):
        """Saves the last element of each arr in dicts_arr to file"""
        with open(filename, "w") as f:
            for d_arr in dicts_arr:
                # Nothing to write for a deselected (None) or empty list
                if not d_arr:
                    continue
                for k, v in d_arr[-1].items():
                    out_str = k.replace(" ", "_")
                    if isinstance(v, Iterable):
                        if isinstance(v, np.ndarray):
                            v = v.flatten()
                        else:
                            v = np.array(v).flatten()
                        out_str += "," + ",".join([str(x) for x in v])
                    else:
                        out_str += "," + str(v)
                    f.write(out_str + "\n")

    def add_relative_dir(filename, basename):
        """Prefix basename with the file's folder relative to base_dir."""
        if base_dir is None:
            return basename
        out_base = os.path.dirname(filename)[len(base_dir + os.sep):]
        if len(out_base) != 0:
            basename = "--".join(out_base.split(os.sep)) + "--" + basename
        return basename

    def plot_group(acc, units, subdir, main_dir, out_basename, one_by_one):
        """Plot one group of units and save the resulting figure(s)."""
        fig = print_place_cells(
            len(units), cols=len(output),
            size_multiplier=4, point_size=point_size, units=units,
            fixed_color=fixed_color, output=output, color_isi=color_isi,
            burst_ms=burst_thresh, one_by_one=one_by_one,
            raster=one_by_one, hd_predict=hd_predict,
            **{key: acc[key] for _, key, _ in analyses})
        out_dir = os.path.join(main_dir, out_dirname)
        if subdir is not None:
            out_dir = os.path.join(out_dir, subdir)
        if one_by_one:
            # Insert the unit number in front of ".<output_format>"
            ext_len = len(output_format) + 1
            stem, ext = out_basename[:-ext_len], out_basename[-ext_len:]
            to_save = [
                (f, os.path.join(out_dir, stem + "_" + str(units[k]) + ext))
                for k, f in enumerate(fig)]
        else:
            to_save = [(fig, os.path.join(out_dir, out_basename))]
        for f, out_name in to_save:
            print("Saving place cell figure to {}".format(out_name))
            make_dir_if_not_exists(out_name)
            f.savefig(out_name, dpi=dpi, format=output_format)
        close("all")
        gc.collect()

    good_acc, bad_acc = new_accumulator(), new_accumulator()
    skipped = 0
    current_data_idx = None
    if point_size is None:
        point_size = dpi / 7
    # Save figures one by one if using pdf or svg
    one_by_one = output_format in ("pdf", "svg")

    for i, data in enumerate(collection):
        # Bound up front so the error report below can never itself fail
        data_idx = unit_number = spike_name = main_dir = None
        try:
            data_idx, unit_idx = collection._index_to_data_pos(i)
            if data_idx != current_data_idx:
                # New spike file. Normally the accumulators are already
                # empty; this drops leftovers when the final unit of the
                # previous file failed before its figure was saved.
                good_acc, bad_acc = new_accumulator(), new_accumulator()
                skipped = 0
                current_data_idx = data_idx

            units = collection.get_units(data_idx)
            print("Working on {}, {}".format(unit_idx, len(units) - 1))
            filename = collection.get_file_dict()["Spike"][data_idx][0]
            unit_number = units[unit_idx]
            print("Working on {} unit {}".format(filename, unit_number))
            spike_name = os.path.basename(filename)
            parts = spike_name.split(".")
            if base_dir is not None:
                main_dir = base_dir
            else:
                main_dir = os.path.dirname(filename)

            count = data.spike.get_unit_spikes_count()
            # Skip very low count cells
            if count < 5:
                skipped += 1
                print("Skipping as only {} spikes".format(count))
            else:
                duration = data.spike.get_duration()
                is_good = True

                # Place cell filtering is based on
                # https://www.nature.com/articles/ncomms11824
                # Activity-plasticity of hippocampus place maps
                # Schoenenberger et al, 2016
                if filter_low_freq:
                    if (count / duration) < 0.25 or (count / duration) > 7:
                        print("Reject spike frequency {}".format(
                            count / duration))
                        is_good = False

                if is_good and filter_place_cells:
                    skaggs = data.loc_shuffle(nshuff=1)
                    bad_sparsity = skaggs['refSparsity'] >= 0.3
                    bad_cohere = skaggs['refCoherence'] <= 0.55
                    first_str_part = "Accept "
                    if bad_sparsity or bad_cohere:
                        is_good = False
                        first_str_part = "Reject "
                    print((
                        first_str_part + "Skaggs {:2f}, " +
                        "Sparsity {:2f}, " +
                        "Coherence {:2f}").format(
                            skaggs['refSkaggs'], skaggs['refSparsity'],
                            skaggs['refCoherence']))

                # Run every requested analysis before recording anything,
                # so a failing unit cannot leave the lists out of step
                # with the list of units.
                computed = [
                    (key, compute(data))
                    for label, key, compute in analyses if label in output]

                target = good_acc if is_good else bad_acc
                target["units"].append(unit_idx)

                if (len(bad_acc["units"]) + len(good_acc["units"])
                        > len(units)):
                    save_bad = bad_acc["units"]
                    save_good = good_acc["units"]
                    good_acc, bad_acc = new_accumulator(), new_accumulator()
                    skipped = 0
                    print("ERROR: Found too many units in the collection")
                    raise Exception(
                        "Accumlated more units than possible " +
                        "bad {} good {} total {}".format(
                            save_bad, save_good, units))

                for key, value in computed:
                    target[key].append(value)

                if save_data:
                    try:
                        data_basename = add_relative_dir(
                            filename,
                            parts[0] + "_" + parts[1] + "_" +
                            str(unit_number) + opt_end + ".csv")
                        out_name = os.path.join(
                            main_dir, out_dirname, "data", data_basename)
                        make_dir_if_not_exists(out_name)
                        save_dicts_to_csv(
                            out_name,
                            [target[key] for _, key, _ in analyses])
                    except Exception as e:
                        log_exception(
                            e, "Occurred during place cell data saving on" +
                            " {} unit {} name {} in {}".format(
                                data_idx, unit_number, spike_name, main_dir))

            # Save the accumulated information
            if unit_idx == len(units) - 1:
                good_units = good_acc["units"]
                bad_units = bad_acc["units"]
                if len(bad_units) + len(good_units) != len(units) - skipped:
                    print("ERROR: Did not cover all units in the collection")
                    print("Good {}, Bad {}, Total {}".format(
                        good_units, bad_units, units))

                out_basename = add_relative_dir(
                    filename,
                    parts[0] + "_" + parts[1] + opt_end + "." +
                    output_format)

                # Label exactly the units that have data recorded;
                # skipped or failed units have none.
                named_units = [units[j] for j in good_units]
                if filter_place_cells:
                    bad_named_units = [units[j] for j in bad_units]
                else:
                    bad_named_units = []

                if len(named_units) > 0:
                    if filter_place_cells:
                        print((
                            "Plotting summary for {} " +
                            "spatial units {}").format(
                                spike_name, named_units))
                    else:
                        print((
                            "Plotting summary for {} " +
                            "units {}").format(
                                spike_name, named_units))
                    if filter_low_freq or filter_place_cells:
                        good_dir = "good"
                    else:
                        good_dir = None
                    plot_group(good_acc, named_units, good_dir,
                               main_dir, out_basename, one_by_one)

                if len(bad_named_units) > 0:
                    print((
                        "Plotting bad summary for {} " +
                        "non-spatial units {}").format(
                            spike_name, bad_named_units))
                    plot_group(bad_acc, bad_named_units, "bad",
                               main_dir, out_basename, one_by_one)

                good_acc, bad_acc = new_accumulator(), new_accumulator()
                skipped = 0

        except Exception as e:
            log_exception(
                e, "Occurred during place cell summary on data" +
                " {} unit {} name {} in {}".format(
                    data_idx, unit_number, spike_name, main_dir))
    return
def place_cell_summary(
        collection, dpi=150, out_dirname="nc_plots",
        filter_place_cells=False, filter_low_freq=False,
        opt_end="", base_dir=None, output_format="png",
        output=["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"],
        isi_bound=350, isi_bin_length=2, fixed_color=None,
        save_data=False, point_size=None, color_isi=True,
        burst_thresh=5, hd_predict=False, one_by_one="auto"):
    """
    Perform spatial information summary of each cell in collection.

    The function is named as place_cell_summary as it can be a quick
    visual way to look for place cells. However, it can be used to look
    at other spatial properties, such as head directionality. The output
    image contains the information from and in the order of the output
    argument.

    Units are analysed one by one and gathered per spike file; the figure
    for a file is produced once its last unit has been handled. An error
    on one unit is logged and the loop continues. If not every unit of a
    file could be gathered, no figure is produced for that file.

    Parameters
    ----------
    collection : NDataCollection
        The collection to plot spatial summaries of.
    dpi : int, default 150
        Dpi of the output figures if the output_format supports dpi.
    out_dirname : str, default "nc_plots"
        The relative name of the directory to save output to.
    filter_place_cells : bool, default False
        Whether to separate non spatial place cells from the plots.
        A unit is considered a spatial place cell if:
        Sparsity < 0.3 and Coherence > 0.55.
        Other units are plotted to the "bad" sub directory.
        Recommended filter_low_freq=True if this flag is True.
        See https://www.nature.com/articles/ncomms11824.
    filter_low_freq : bool, default False
        Reject units whose spike count / duration is below 0.25 or
        above 7.
    opt_end : str, default ""
        A string to append to the file output just before the extension.
        Can be used if doing multiple runs to avoid overwriting output
        files.
    base_dir : str, default None
        An optional directory to save all the files to.
        If not provided, the files will save to the location of the input
        data.
    output_format : str, default "png"
        What format to save the output image in.
    output : List of str, default
        ["Wave", "Path", "Place", "HD", "LowAC", "Theta", "HighISI"]
        Provided argument should be some subset and/or permutation of
        these. The list is only read, never modified.
    isi_bound : int, default 350
        How long in ms to plot the ISI for.
    isi_bin_length : int, default 2
        How long in ms the ISI bins should be.
    fixed_color : str, default None
        If provided, will plot all units with this color instead of auto
        color. Can be any matplotlib compatible color.
    save_data : bool, default False
        Whether to save out the information used for the plot.
        If True, will produce a csv per unit with the data used for
        plotting, in the "data" sub directory.
    point_size : int, default None
        If provided, the size of the matplotlib points to use in spatial
        plots. If None, the point size is dpi / 7.
    color_isi : bool, default True
        Whether the ISI should be black (False) or blue (True).
    burst_thresh : int, default 5
        How long in ms to consider the window for burst to be.
    hd_predict : bool, default False
        Whether the head directional graph should be plotted with
        predicted HD.
    one_by_one : bool or str, default "auto"
        Whether to plot all units in each tetrode file to single plot.
        Options:
        True - plot each unit to a separate file.
        False - plot all units on a tetrode to the same file.
        "auto" - Determine T/F based on output_format.
        pdf and svg are True.

    Returns
    -------
    None

    """
    # (output label, print_place_cells keyword, analysis), listed in the
    # order the analyses are run. "Path" needs no analysis of its own here.
    analyses = (
        ("Place", "placedata", lambda d: d.place()),
        ("LowAC", "graphdata",
         lambda d: d.isi_corr(bins=1, bound=[-10, 10])),
        ("Wave", "wavedata", lambda d: d.wave_property()),
        ("HD", "headdata", lambda d: d.hd_rate()),
        ("Theta", "thetadata",
         lambda d: d.theta_index(bins=2, bound=[-350, 350])),
        ("HighISI", "isidata",
         lambda d: d.isi(bins=int(isi_bound / isi_bin_length),
                         bound=[0, isi_bound])),
    )

    def new_accumulator():
        """Return empty storage for one group of units of one file."""
        # An analysis that was not requested is passed on as None.
        acc = {key: ([] if label in output else None)
               for label, key, _ in analyses}
        acc["units"] = []
        return acc

    # This function is used to save out the plotting data to a csv.
    def save_dicts_to_csv(filename, dicts_arr):
        """Save the last element of each arr in dicts_arr to file."""
        with open(filename, "w") as f:
            for d_arr in dicts_arr:
                # Nothing to write for a deselected (None) or empty list
                if not d_arr:
                    continue
                for k, v in d_arr[-1].items():
                    out_str = k.replace(" ", "_")
                    if isinstance(v, Iterable):
                        if isinstance(v, np.ndarray):
                            v = v.flatten()
                        else:
                            v = np.array(v).flatten()
                        out_str += "," + ",".join([str(x) for x in v])
                    else:
                        out_str += "," + str(v)
                    f.write(out_str + "\n")

    def add_relative_dir(filename, basename):
        """Prefix basename with the file's folder relative to base_dir."""
        if base_dir is None:
            return basename
        out_base = os.path.dirname(filename)[len(base_dir + os.sep):]
        if len(out_base) != 0:
            basename = "--".join(out_base.split(os.sep)) + "--" + basename
        return basename

    def plot_group(acc, units, subdir, main_dir, out_basename, save_msg):
        """Plot one group of units and save the resulting figure(s)."""
        fig = print_place_cells(
            len(units), cols=len(output),
            size_multiplier=4, point_size=point_size, units=units,
            fixed_color=fixed_color, output=output, color_isi=color_isi,
            burst_ms=burst_thresh, one_by_one=one_by_one,
            raster=one_by_one, hd_predict=hd_predict,
            **{key: acc[key] for _, key, _ in analyses})
        out_dir = os.path.join(main_dir, out_dirname)
        if subdir is not None:
            out_dir = os.path.join(out_dir, subdir)
        if one_by_one:
            # Insert the unit number in front of ".<output_format>";
            # the extension is not always three characters long.
            ext_len = len(output_format) + 1
            stem, ext = out_basename[:-ext_len], out_basename[-ext_len:]
            to_save = [
                (f, os.path.join(out_dir, stem + "_" + str(units[k]) + ext))
                for k, f in enumerate(fig)]
        else:
            to_save = [(fig, os.path.join(out_dir, out_basename))]
        for f, out_name in to_save:
            logging.info(save_msg.format(out_name))
            make_dir_if_not_exists(out_name)
            f.savefig(out_name, dpi=dpi, format=output_format)
        close("all")
        gc.collect()

    # Set up the storage to hold the data for plotting.
    good_acc, bad_acc = new_accumulator(), new_accumulator()
    skipped = 0
    current_data_idx = None
    if point_size is None:
        point_size = dpi / 7
    # Save figures one by one if using pdf or svg
    if one_by_one == "auto":
        one_by_one = (output_format == "pdf") or (output_format == "svg")

    for i, data in enumerate(collection):
        # Bound up front so the error report below can never itself fail
        data_idx = unit_number = spike_name = main_dir = None
        try:
            data_idx, unit_idx = collection._index_to_data_pos(i)
            if data_idx != current_data_idx:
                # New spike file. Normally the accumulators are already
                # empty; this drops leftovers when the previous file ended
                # in an error before its figure was saved.
                good_acc, bad_acc = new_accumulator(), new_accumulator()
                skipped = 0
                current_data_idx = data_idx

            units = collection.get_units(data_idx)
            filename = collection.get_file_dict()["Spike"][data_idx][0]
            unit_number = units[unit_idx]
            logging.info("Working on {} unit {} out of {}".format(
                filename, unit_number, len(units)))
            spike_name = os.path.basename(filename)
            final_bname, final_ext = os.path.splitext(spike_name)
            final_ext = final_ext[1:]
            if base_dir is not None:
                main_dir = base_dir
            else:
                main_dir = os.path.dirname(filename)

            count = data.spike.get_unit_spikes_count()
            # Skip very low count cells
            if count < 5:
                skipped += 1
                logging.warning("Skipping as only {} spikes".format(count))
            else:
                duration = data.spike.get_duration()
                is_good = True

                # Place cell filtering is based on
                # https://www.nature.com/articles/ncomms11824
                # Activity-plasticity of hippocampus place maps
                # Schoenenberger et al, 2016
                if filter_low_freq:
                    if (count / duration) < 0.25 or (count / duration) > 7:
                        logging.info("Reject spike frequency {}".format(
                            count / duration))
                        is_good = False

                if is_good and filter_place_cells:
                    skaggs = data.loc_shuffle(nshuff=1)
                    bad_sparsity = skaggs['refSparsity'] >= 0.3
                    bad_cohere = skaggs['refCoherence'] <= 0.55
                    first_str_part = "Accept "
                    if bad_sparsity or bad_cohere:
                        is_good = False
                        first_str_part = "Reject "
                    logging.info(
                        (first_str_part + "Skaggs {:2f}, " +
                         "Sparsity {:2f}, " + "Coherence {:2f}").format(
                             skaggs['refSkaggs'], skaggs['refSparsity'],
                             skaggs['refCoherence']))

                # Run every requested analysis before recording anything,
                # so a failing unit cannot leave the lists out of step
                # with the list of units.
                computed = [
                    (key, compute(data))
                    for label, key, compute in analyses if label in output]

                target = good_acc if is_good else bad_acc
                target["units"].append(unit_idx)

                # In case somehow there is a double counting bug
                if (len(bad_acc["units"]) + len(good_acc["units"])
                        > len(units)):
                    save_bad = bad_acc["units"]
                    save_good = good_acc["units"]
                    good_acc, bad_acc = new_accumulator(), new_accumulator()
                    skipped = 0
                    raise Exception(
                        "Accumlated more units than possible " +
                        "bad {} good {} total {}".format(
                            save_bad, save_good, units))

                for key, value in computed:
                    target[key].append(value)

                if save_data:
                    try:
                        data_basename = add_relative_dir(
                            filename,
                            final_bname + "_" + final_ext + "_" +
                            str(unit_number) + opt_end + ".csv")
                        out_name = os.path.join(
                            main_dir, out_dirname, "data", data_basename)
                        make_dir_if_not_exists(out_name)
                        save_dicts_to_csv(
                            out_name,
                            [target[key] for _, key, _ in analyses])
                    except Exception as e:
                        log_exception(
                            e, "Occurred during place cell data saving on" +
                            " {} unit {} name {} in {}".format(
                                data_idx, unit_number, spike_name, main_dir))

            # Save the accumulated information
            if unit_idx == len(units) - 1:
                good_units = good_acc["units"]
                bad_units = bad_acc["units"]
                if len(bad_units) + len(good_units) != len(units) - skipped:
                    logging.error("Good {}, Bad {}, Total {}".format(
                        good_units, bad_units, units))
                    raise ValueError(
                        "Did not cover all units in the collection")

                out_basename = add_relative_dir(
                    filename,
                    final_bname + "_" + final_ext + opt_end + "." +
                    output_format)

                # Label exactly the units that have data recorded;
                # skipped units have none.
                named_units = [units[j] for j in good_units]
                if filter_place_cells:
                    bad_named_units = [units[j] for j in bad_units]
                else:
                    bad_named_units = []

                if len(named_units) > 0:
                    if filter_place_cells:
                        logging.info(("Plotting summary for {} " +
                                      "spatial units {}").format(
                                          spike_name, named_units))
                    else:
                        logging.info(
                            ("Plotting summary for {} " + "units {}").format(
                                spike_name, named_units))
                    if filter_low_freq or filter_place_cells:
                        good_dir = "good"
                    else:
                        good_dir = None
                    plot_group(good_acc, named_units, good_dir,
                               main_dir, out_basename,
                               "Saving place cell figure to {}")

                if len(bad_named_units) > 0:
                    logging.info(("Plotting bad summary for {} " +
                                  "non-spatial units {}").format(
                                      spike_name, bad_named_units))
                    plot_group(bad_acc, bad_named_units, "bad",
                               main_dir, out_basename,
                               "Saving place cell fig to {}")

                good_acc, bad_acc = new_accumulator(), new_accumulator()
                skipped = 0

        except Exception as e:
            log_exception(
                e, "Occurred during place cell summary on data" +
                " {} unit {} name {} in {}".format(data_idx, unit_number,
                                                   spike_name, main_dir))
    return
def cell_classification_stats(in_dir, container, out_name,
                              should_plot=False, opt_end="",
                              output_spaces=True):
    """
    Compute a csv of cell stats for each unit in a container

    Every unit is analysed in turn and its results form one row of the
    csv. A unit that fails is reported and left out; the remaining units
    are still processed.

    Params
    ------
    in_dir - the data output/input location
    container - the NDataContainer object to get stats for
    out_name - the path of the csv file to write
    should_plot - whether to save some plots for this
        (a spike phase plot per unit and an lfp spectrum per spike file,
        both under in_dir/nc_plots)
    opt_end - a string to put in the plot names before the extension
    output_spaces - if False, the results are requested with spaces
        converted to underscores
    """
    _results = []
    spike_names = container.get_file_dict()["Spike"]

    for i, ndata in enumerate(container):
        # Known values for the failure report, whatever line raises
        name = ""
        unit_label = "NA"
        try:
            data_idx, unit_idx = container._index_to_data_pos(i)
            name = spike_names[data_idx][0]
            parts = os.path.basename(name).split(".")

            # Setup up identifier information
            note_dict = oDict()
            dir_t = os.path.dirname(name)
            note_dict["Index"] = i
            note_dict["FullDir"] = dir_t
            if dir_t != in_dir:
                note_dict["RelDir"] = dir_t[len(in_dir + os.sep):]
            else:
                note_dict["RelDir"] = ""
            note_dict["Recording"] = parts[0]
            note_dict["Tetrode"] = int(parts[-1])
            unit_label = ndata.get_unit_no()
            note_dict["Unit"] = unit_label
            ndata.update_results(note_dict)

            # Caculate cell properties; the calls are made for the
            # results they record on ndata, collected further below
            ndata.wave_property()
            ndata.place()
            ndata.hd_rate()
            ndata.grid()
            ndata.border()
            ndata.multiple_regression()
            ndata.isi()
            ndata.burst(burst_thresh=6)
            phase_dist = ndata.phase_dist()
            ndata.theta_index()
            ndata.bandpower_ratio(
                [5, 11], [1.5, 4], 1.6, relative=True,
                first_name="Theta", second_name="Delta")

            result = copy(
                ndata.get_results(spaces_to_underscores=not output_spaces))
            _results.append(result)

            if should_plot:
                plot_loc = os.path.join(
                    in_dir, "nc_plots",
                    parts[0] + "_" + parts[-1] + "_" +
                    str(ndata.get_unit_no()) + "_phase" + opt_end + ".png")
                make_dir_if_not_exists(plot_loc)
                # Only the second of the three returned figures is saved
                _, fig2, _ = nc_plot.spike_phase(phase_dist)
                fig2.savefig(plot_loc)
                plt.close("all")

                # One lfp spectrum per spike file, made on its last unit
                if unit_idx == len(container.get_units(data_idx)) - 1:
                    plot_loc = os.path.join(
                        in_dir, "nc_plots",
                        parts[0] + "_lfp" + opt_end + ".png")
                    make_dir_if_not_exists(plot_loc)
                    lfp_spectrum = ndata.spectrum()
                    fig = nc_plot.lfp_spectrum(lfp_spectrum)
                    fig.savefig(plot_loc)
                    plt.close(fig)

        except Exception as e:
            print("WARNING: Failed to analyse {} unit {}".format(
                os.path.basename(name), unit_label))
            log_exception(
                e, "Failed on {} unit {}".format(
                    os.path.basename(name), unit_label))

    # Save the cell statistics
    make_dir_if_not_exists(out_name)
    save_dicts_to_csv(out_name, _results)
    _results.clear()
def cell_classification_stats(
        in_dir, container, out_name,
        should_plot=False, opt_end="", output_spaces=True,
        good_cells=None):
    """
    Compute a csv of cell stats for each unit in a container

    Every selected unit is loaded and analysed in turn and its results
    form one row of the csv. A unit that fails is reported and left out;
    the remaining units are still processed.

    Params
    ------
    in_dir - the data output/input location
    container - the NDataContainer object to get stats for
    out_name - the path of the csv file to write
    should_plot - whether to save some plots for this
        (not used by this function)
    opt_end - not used by this function
    output_spaces - if False, the results are requested with spaces
        converted to underscores
    good_cells - optional collection of [relative spike file path, unit]
        pairs; when given, only the units listed in it are analysed
    """
    _results = []
    spike_names = container.get_file_dict()["Spike"]
    overall_count = 0

    for i in range(len(container)):
        # Known values for the failure report, whatever line raises;
        # also keeps a unit from the previous iteration out of the report
        name = ""
        unit_label = "NA"
        try:
            data_idx, unit_idx = container._index_to_data_pos(i)
            name = spike_names[data_idx][0]
            parts = os.path.basename(name).split(".")
            dir_t = os.path.dirname(name)
            o_name = os.path.join(dir_t[len(in_dir + os.sep):], parts[0])

            # Setup up identifier information
            note_dict = oDict()
            note_dict["Index"] = i
            note_dict["FullDir"] = dir_t
            if dir_t != in_dir:
                note_dict["RelDir"] = dir_t[len(in_dir + os.sep):]
            else:
                note_dict["RelDir"] = ""
            note_dict["Recording"] = parts[0]
            note_dict["Tetrode"] = int(parts[-1])

            # Filter before indexing the container, so that units which
            # are not wanted are never loaded
            if good_cells is not None:
                check = [
                    os.path.normpath(name[len(in_dir + os.sep):]),
                    container.get_units(data_idx)[unit_idx]]
                if check not in good_cells:
                    continue

            ndata = container[i]
            overall_count += 1
            unit_label = ndata.get_unit_no()
            print("Working on unit {} of {}: {}, T{}, U{}".format(
                i + 1, len(container), o_name, parts[-1], unit_label))

            note_dict["Unit"] = unit_label
            ndata.update_results(note_dict)

            # Caculate cell properties; the calls are made for the
            # results they record on ndata, collected further below
            ndata.wave_property()
            ndata.place()
            ndata.isi()
            ndata.burst(burst_thresh=6)
            ndata.theta_index()
            ndata._results["IsPyramidal"] = cell_type(ndata)
            result = copy(ndata.get_results(
                spaces_to_underscores=not output_spaces))
            _results.append(result)

        except Exception as e:
            print("WARNING: Failed to analyse {} unit {}".format(
                os.path.basename(name), unit_label))
            log_exception(e, "Failed on {} unit {}".format(
                os.path.basename(name), unit_label))

    # Save the cell statistics
    make_dir_if_not_exists(out_name)
    print("Analysed {} cells in total".format(overall_count))
    save_dicts_to_csv(out_name, _results)
    _results.clear()