def compare_lfp(load_loc, out_loc, ch=32):
    """
    Compare every ordered pair of LFP channels and write the scores as CSV.

    For each pair of channel indices (a, b) drawn from range(ch), both
    channels are loaded with load_lfp and their samples are scored with
    get_normalised_diff.

    Parameters
    ----------
    load_loc : str
        Location handed to load_lfp for every channel.
    out_loc : str
        Path of the CSV file to write.
    ch : int, optional
        Number of channels to compare. Defaults to 32, the value that
        used to be hard-coded.

    Returns
    -------
    numpy.ndarray
        Flat float32 array of length ch * ch in row-major order: the
        entry at position a * ch + b is the score of the pair (a, b).

    """
    ndata1 = NData()
    ndata2 = NData()
    grid = np.meshgrid(np.arange(ch), np.arange(ch), indexing='ij')
    pairs = np.stack(grid, 2).reshape(-1, 2)
    result_a = np.zeros(shape=pairs.shape[0], dtype=np.float32)

    # NOTE(review): both channels are reloaded for every pair. Hoisting the
    # first load out of the inner iteration is only safe if
    # get_normalised_diff leaves its inputs untouched - confirm before
    # optimising.
    for i, (first, second) in enumerate(pairs):
        load_lfp(load_loc, first, ndata1)
        load_lfp(load_loc, second, ndata2)
        result_a[i] = get_normalised_diff(
            ndata1.lfp.get_samples(), ndata2.lfp.get_samples())

    with open(out_loc, "w") as f:
        # Header row holds the 1-based channel numbers.
        f.write(",".join(str(i) for i in range(1, ch + 1)) + "\n")
        # One row per first channel. Every value is followed by a comma,
        # so rows end with a trailing comma exactly as before.
        for row in result_a.reshape(ch, ch):
            f.write("".join("{:.2f},".format(val) for val in row) + "\n")
    return result_a
def list_all_units(self):
    """Print the unit list of every data object in the container."""
    if not self._load_on_fly:
        # Everything is already held in memory, so just walk the container.
        for data in self._container:
            print("units are {}".format(data.get_unit_list()))
        return

    # Loading on the fly: only the "Spike" entries are loaded, one
    # temporary NData per descriptor.
    for key, descriptors in self.get_file_dict().items():
        if key != "Spike":
            continue
        for descriptor in descriptors:
            loaded = NData()
            self._load(key, descriptor, ndata=loaded)
            print("units are {}".format(loaded.get_unit_list()))
def get_data(self, index=None):
    """
    Return the NData objects in this collection, or a specific object.

    Do not call this with no index if loading data on the fly.

    Parameters
    ----------
    index : int
        Optional index to get data at
        Defaults to None, in which case all data is returned.

    Returns
    -------
    NData or list of NData objects
        None is returned (after logging an error) when index is out of
        range, or when no index is given while loading on the fly.

    """
    if self._load_on_fly:
        if index is None:
            logging.error("Can't load all data when loading on the fly")
            # Stop here: there is no single descriptor to load, and
            # indexing the descriptor lists with None would raise.
            return None
        # Build a fresh object from the descriptor at this index of
        # every entry in the file dictionary.
        result = NData()
        for key, vals in self.get_file_dict().items():
            self._load(key, vals[index], ndata=result)
        return result

    if index is None:
        return self._container
    if index >= self.get_num_data():
        logging.error("Input index to get_data out of range")
        return None
    return self._container[index]
def _load_all_data(self): """Intended private function which loads all the data.""" if self._load_on_fly: logging.error( "Don't load all the data in container if loading on the fly") for key, vals in self.get_file_dict().items(): for idx, _ in enumerate(vals): if idx >= self.get_num_data(): self.add_data(NData()) for idx, descriptor in enumerate(vals): self._load(key, descriptor, idx=idx)
def get_data_at(self, data_index, unit_index):
    """
    Return an NData object from the given indices.

    When loading on the fly, the most recently loaded object is kept in
    self._last_data_pt and reused if the same data_index is asked for
    again. Entries under the "STM" key are not loaded by this method.

    Parameters
    ----------
    data_index : int
        The index in the container to return data at.
    unit_index : int
        The unit number to set on the returned data.

    Returns
    -------
    NData
        The ndata object at data_index with unit number unit_index.
        None if no object could be obtained for data_index.

    """
    # Start from None so a failure before any object exists is reported
    # as "no data" instead of an unbound local variable.
    result = None
    if self._load_on_fly:
        try:
            if (data_index == self._last_data_pt[0]
                    and self._last_data_pt[1] is not None):
                result = self._last_data_pt[1]
            else:
                result = NData()
                for key, vals in self.get_file_dict().items():
                    if key == "STM":
                        continue
                    self._load(
                        key, vals[data_index], idx=data_index, ndata=result)
                # Only remember the object once every file loaded.
                self._last_data_pt = (data_index, result)
        except Exception as e:
            # Loading problems are logged, not raised. An object that
            # was created before the failure is still returned below.
            log_exception(e, "During loading data")
    else:
        result = self.get_data(data_index)

    if result is None:
        return None
    if len(self.get_units()) > 0:
        result.set_unit_no(self.get_units(data_index)[unit_index])
    return result
def set_units(self, units='all'):
    """
    Set the list of units for the collection.

    Parameters
    ----------
    units : list or str:
        If a list, there is one entry per data object stored in the
        collection. Each entry may be an int (a single unit), a list of
        units, or the string "all" (every unit available for that
        object). An entry of any other type is logged as an error and
        skipped.
        Otherwise, the string "all" is expected, which sets all available
        units for every data object.

    Raises
    ------
    ValueError
        If the file dictionary of the collection is empty.

    """
    self._units = []
    if self.get_file_dict() == {}:
        raise ValueError("Can't set units for empty collection")

    def units_from_spike_file(descriptor):
        # Load one "Spike" descriptor into a temporary NData and return
        # the unit list it reports.
        loaded = NData()
        self._load("Spike", descriptor, ndata=loaded)
        return loaded.get_unit_list()

    if units == 'all':
        if self._load_on_fly:
            for descriptor in self.get_file_dict()["Spike"]:
                self._units.append(units_from_spike_file(descriptor))
        else:
            for data in self.get_data():
                self._units.append(data.get_unit_list())
    elif isinstance(units, list):
        for idx, unit in enumerate(units):
            if unit == 'all':
                if self._load_on_fly:
                    descriptor = self.get_file_dict()["Spike"][idx]
                    found = units_from_spike_file(descriptor)
                else:
                    found = self.get_data(idx).get_unit_list()
                self._units.append(found)
            elif isinstance(unit, int):
                self._units.append([unit])
            elif isinstance(unit, list):
                self._units.append(unit)
            else:
                logging.error(
                    "Unrecognised type {} passed to set units".format(
                        type(unit)))
    else:
        logging.error("Unrecognised type {} passed to set units".format(
            type(units)))

    self._unit_count = self._count_num_units()
def main(url, file_name, verbose=False):
    """
    Run an example analysis on data stored in a hdf5 file.

    Parameters
    ----------
    url : str
        Download url of the hdf5 file. Only used when file_name does not
        exist on disk yet.
    file_name : str
        Local disk path the hdf5 file is stored at.
    verbose : bool
        If True, information about the hdf5 file is printed.

    """
    # Reuse the copy on disk when there is one, otherwise download it.
    if os.path.exists(file_name):
        print("Using {}".format(file_name))
    else:
        print("Downloading file from {} to {}".format(url, file_name))
        urllib.request.urlretrieve(url, file_name)

    if verbose:
        from skm_pyutils.py_print import print_h5
        print_h5(file_name)

    # Data inside the file is addressed as filename + "+" + path_in_hdf5.
    hdf_prefix = file_name + "+"
    spike_path = "/processing/Shank/7"
    pos_path = "/processing/Behavioural/Position"
    lfp_path = "/processing/Neural Continuous/LFP/eeg"

    ndata = NData()
    ndata.set_data_format("NWB")
    ndata.set_spatial_file(hdf_prefix + pos_path)
    ndata.set_spike_file(hdf_prefix + spike_path)
    ndata.set_lfp_file(hdf_prefix + lfp_path)
    ndata.load()

    # The user picks the unit to analyse from those found in the file.
    print("Units are:", ndata.get_unit_list())
    chosen = input("Unit to use:\n").strip()
    ndata.set_unit_no(int(chosen))
    print("Loaded:", ndata)

    # Run the analyses, then print the results with both key styles.
    ndata.place()
    ndata.wave_property()
    print(ndata.get_results(spaces_to_underscores=True))
    print(ndata.get_results(spaces_to_underscores=False))
def load_axona_data():
    """
    Load the example recording from a fixed local directory.

    Returns
    -------
    NData
        Object with the spike, spatial and lfp files loaded and the unit
        number set to 7.

    """
    data_dir = r'C:\Users\smartin5\Recordings\recording_example'

    ndata = NData()
    ndata.set_spike_file(
        os.path.join(data_dir, "010416b-LS3-50Hz10V5ms.2"))
    ndata.set_spatial_file(
        os.path.join(data_dir, "010416b-LS3-50Hz10V5ms_2.txt"))
    ndata.set_lfp_file(
        os.path.join(data_dir, "010416b-LS3-50Hz10V5ms.eeg"))
    ndata.load()
    ndata.set_unit_no(7)
    return ndata
def load_h5_data():
    """
    Load the example hdf5 recording from a fixed local directory.

    Returns
    -------
    NData
        Object in "NWB" data format with the spatial, spike and lfp
        paths loaded and the unit number set to 3.

    """
    data_dir = r'C:\Users\smartin5\Recordings\NC_eg'
    main_file = "040513_1.hdf5"
    # Each data path is the hdf5 file location, a "+", then the path
    # inside the file.
    h5_prefix = os.path.join(data_dir, main_file) + "+"

    ndata = NData()
    ndata.set_data_format(data_format='NWB')
    ndata.set_spatial_file(h5_prefix + "/processing/Behavioural/Position")
    ndata.set_spike_file(h5_prefix + "/processing/Shank/6")
    ndata.set_lfp_file(h5_prefix + "/processing/Neural Continuous/LFP/eeg")
    ndata.load()
    ndata.set_unit_no(3)
    return ndata
def compare_lfp(fname, out_base_dir=None, ch=16):
    """
    Compare every ordered pair of LFP channels; save a CSV and a heatmap.

    Channels are read from fname + ".eeg" with load_lfp and each pair is
    scored with get_normalised_diff. The scores are written to
    "<basename>_SI.csv" and plotted to "<basename>_LFP_SI.png", both
    inside out_base_dir.

    Parameters
    ----------
    fname : str
        full path name without extension
    out_base_dir : str, None
        Path for desired output location.
        Default - Saves output to folder named !LFP in base directory.
    ch : int
        Number of LFP channels in session

    Returns
    -------
    numpy.ndarray
        Flat float32 array of length ch * ch in row-major order: the
        entry at position a * ch + b is the score of the pair (a, b).

    """
    if out_base_dir is None:
        out_base_dir = os.path.join(os.path.dirname(fname), "!LFP")
    make_dir_if_not_exists(out_base_dir)

    base_name = os.path.basename(fname)
    load_loc = fname + ".eeg"
    out_loc = os.path.join(out_base_dir, base_name + "_SI.csv")

    ndata1 = NData()
    ndata2 = NData()
    grid = np.meshgrid(np.arange(ch), np.arange(ch), indexing='ij')
    pairs = np.stack(grid, 2).reshape(-1, 2)
    result_a = np.zeros(shape=pairs.shape[0], dtype=np.float32)
    for i, (first, second) in enumerate(pairs):
        load_lfp(load_loc, first, ndata1)
        load_lfp(load_loc, second, ndata2)
        result_a[i] = get_normalised_diff(
            ndata1.lfp.get_samples(), ndata2.lfp.get_samples())

    # Row a of the grid holds the scores of channel a against all channels.
    reshaped = result_a.reshape(ch, ch)

    with open(out_loc, "w") as f:
        # Header row holds the 1-based channel numbers.
        f.write(",".join(str(i) for i in range(1, ch + 1)) + "\n")
        # Every value is followed by a comma, so rows end with a trailing
        # comma exactly as before.
        for row in reshaped:
            f.write("".join("{:.2f},".format(val) for val in row) + "\n")

    # NOTE(review): the heatmap is drawn on whatever the current matplotlib
    # axes are and the figure is left open after saving - confirm that
    # callers processing several files start from a fresh figure.
    sns.heatmap(reshaped)
    # Ticks sit at cell centres and are labelled with 1-based channels.
    tick_positions = np.arange(0.5, ch + 0.5)
    tick_labels = np.arange(1, ch + 1)
    plt.xticks(tick_positions, labels=tick_labels, fontsize=8)
    plt.xlabel('LFP Channels')
    plt.yticks(tick_positions, labels=tick_labels, fontsize=8)
    plt.ylabel('LFP Channels')
    plt.title('Raw LFP Similarity Index')

    fig_path = os.path.join(out_base_dir, base_name + "_LFP_SI.png")
    print("Saving figure to {}".format(fig_path))
    plt.savefig(fig_path, dpi=200, bbox_inches='tight', pad_inches=0.1)
    return result_a
def load_data(dir, spike_name, pos_name, lfp_name, unit_no=6):
    """
    Load spike, spatial and lfp files from one directory into an NData.

    Parameters
    ----------
    dir : str
        Directory containing the three files.
    spike_name : str
        File name of the spike file inside dir.
    pos_name : str
        File name of the spatial (position) file inside dir.
    lfp_name : str
        File name of the lfp file inside dir.
    unit_no : int, optional
        Unit number to set on the loaded data. Defaults to 6, the value
        that used to be hard-coded.

    Returns
    -------
    NData
        The loaded object with its unit number set to unit_no.

    """
    ndata = NData()
    ndata.set_spike_file(os.path.join(dir, spike_name))
    ndata.set_spatial_file(os.path.join(dir, pos_name))
    ndata.set_lfp_file(os.path.join(dir, lfp_name))
    ndata.load()
    # The unit is selected only after the files have been loaded.
    ndata.set_unit_no(unit_no)
    return ndata