def __init__(self, *, raw, raw_num_channels, num_frames, samplerate, channel_ids, channel_map, channel_positions, p2p, download=False):
    se.RecordingExtractor.__init__(self)
    self._raw = raw
    self._num_frames = num_frames
    self._samplerate = samplerate
    self._raw_num_channels = raw_num_channels
    self._channel_ids = channel_ids
    self._channel_map = channel_map
    self._channel_positions = channel_positions
    self._p2p = p2p
    if download:
        kp.load_file(self._raw)
    for id in self._channel_ids:
        pos = self._channel_positions[str(id)]
        self.set_channel_property(id, 'location', pos)

def __init__(self, arg, samplerate=None):
    super().__init__()
    if (isinstance(arg, dict)) and ('sorting_format' in arg):
        obj = dict(arg)
    else:
        obj = _create_object_for_arg(arg, samplerate=samplerate)
        assert obj is not None, f'Unable to create sorting from arg: {arg}'
    self._object: dict = obj
    if 'firings' in self._object:
        sorting_format = 'mda'
        data = {'firings': self._object['firings'], 'samplerate': self._object.get('samplerate', 30000)}
    else:
        sorting_format = self._object['sorting_format']
        data: dict = self._object['data']
    if sorting_format == 'mda':
        firings_path = kp.load_file(data['firings'])
        assert firings_path is not None, f'Unable to load firings file: {data["firings"]}'
        self._sorting: se.SortingExtractor = MdaSortingExtractor(firings_file=firings_path, samplerate=data['samplerate'])
    elif sorting_format == 'h5_v1':
        h5_path = kp.load_file(data['h5_path'])
        self._sorting = H5SortingExtractorV1(h5_path=h5_path)
    elif sorting_format == 'npy1':
        times_npy = kp.load_npy(data['times_npy_uri'])
        labels_npy = kp.load_npy(data['labels_npy_uri'])
        samplerate = data['samplerate']
        S = se.NumpySortingExtractor()
        S.set_sampling_frequency(samplerate)
        S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
        self._sorting = S
    elif sorting_format == 'snippets1':
        S = Snippets1SortingExtractor(snippets_h5_uri=data['snippets_h5_uri'], p2p=True)
        self._sorting = S
    elif sorting_format == 'npy2':
        npz = kp.load_npy(data['npz_uri'])
        times_npy = npz['spike_indexes']
        labels_npy = npz['spike_labels']
        samplerate = float(npz['sampling_frequency'])
        S = se.NumpySortingExtractor()
        S.set_sampling_frequency(samplerate)
        S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
        self._sorting = S
    elif sorting_format == 'nwb':
        from .nwbextractors import NwbSortingExtractor
        path0 = kp.load_file(data['path'])
        self._sorting: se.SortingExtractor = NwbSortingExtractor(path0)
    elif sorting_format == 'in_memory':
        S = get_in_memory_object(data)
        if S is None:
            raise Exception('Unable to find in-memory object for sorting')
        self._sorting = S
    else:
        raise Exception(f'Unexpected sorting format: {sorting_format}')
    self.copy_unit_properties(sorting=self._sorting)

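# --- Hedged usage sketch (added for illustration; not part of the original
# module). Shows the object shape consumed by the 'mda' branch above. The
# URI is a placeholder and the 30000 Hz samplerate is an assumption.
def _example_load_mda_sorting():
    sorting_object = dict(
        sorting_format='mda',
        data=dict(
            firings='sha1://<hash>/firings.mda',  # placeholder kachery URI
            samplerate=30000  # assumed sampling rate
        )
    )
    return LabboxEphysSortingExtractor(sorting_object)
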
def create_recording_object_from_spikeforest_recdir(recdir, label):
    raw_path = kp.load_file(recdir + '/raw.mda')
    raw_path = kp.store_file(raw_path, basename=label + '-raw.mda')  # store with manifest
    print(raw_path)
    params = kp.load_object(recdir + '/params.json')
    geom_path = kp.load_file(recdir + '/geom.csv')
    geom = _load_geom_from_csv(geom_path)
    recording_object = dict(
        recording_format='mda',
        data=dict(
            raw=raw_path,
            geom=geom,
            params=params
        )
    )
    return recording_object

def __init__(self, probe_file, xml_file, nrs_file, dat_file):
    se.RecordingExtractor.__init__(self)
    # info = check_load_nrs(dirpath)
    # assert info is not None
    probe_obj = kp.load_object(probe_file)
    xml_file = kp.load_file(xml_file)
    # nrs_file = kp.load_file(nrs_file)
    dat_file = kp.load_file(dat_file)

    from xml.etree import ElementTree as ET
    xml = ET.parse(xml_file)
    root_element = xml.getroot()
    try:
        txt = root_element.find('acquisitionSystem/samplingRate').text
        assert txt is not None
        self._samplerate = float(txt)
    except:
        raise Exception('Unable to load acquisitionSystem/samplingRate')
    try:
        txt = root_element.find('acquisitionSystem/nChannels').text
        assert txt is not None
        self._nChannels = int(txt)
    except:
        raise Exception('Unable to load acquisitionSystem/nChannels')
    try:
        txt = root_element.find('acquisitionSystem/nBits').text
        assert txt is not None
        self._nBits = int(txt)
    except:
        raise Exception('Unable to load acquisitionSystem/nBits')

    if self._nBits == 16:
        dtype = np.int16
    elif self._nBits == 32:
        dtype = np.int32
    else:
        raise Exception(f'Unexpected nBits: {self._nBits}')

    self._rec = se.BinDatRecordingExtractor(
        dat_file,
        sampling_frequency=self._samplerate,
        numchan=self._nChannels,
        dtype=dtype
    )

    self._channel_ids = probe_obj['channel']
    for ii in range(len(probe_obj['channel'])):
        channel = probe_obj['channel'][ii]
        x = probe_obj['x'][ii]
        y = probe_obj['y'][ii]
        z = probe_obj['z'][ii]
        group = probe_obj.get('group', probe_obj.get('shank'))[ii]
        self.set_channel_property(channel, 'location', [x, y, z])
        self.set_channel_property(channel, 'group', group)

def reup_file(object: Any):
    thekey = recording_key if recording_key['key_field'] in object else sorting_key
    if VERBOSE:
        print(
            f"Executing: object['{thekey['key_field']}'] = kp.store_file(kp.load_file(object['{thekey['key_field']}']), basename={thekey['basename']})"
        )
    if DRY_RUN:
        return
    # Turns out that kachery doesn't handle big files without manifests all that well.
    # Which was the point of this exercise.
    # So let's take advantage of being on the same filesystem to do a little magic.
    raw = object[thekey["key_field"]]
    print(f'Got raw: {raw}')
    if ('sha1dir' in raw):
        key_field = object[thekey['key_field']]
        sha1dir = key_field.split('/')[2]
        print(f'Got dir: {sha1dir}')
        kp.load_file(f'sha1://{sha1dir}')
        print(f"Fetching hash for file: {key_field}")
        reformed_field = trim_dir_annotation(f"{key_field}")
        if VERBOSE:
            print(f"(using reformed field {reformed_field})")
        try:
            sha1 = ka.get_file_hash(reformed_field)
        except:
            if FORCE:
                print(f"\t** Trimmed lookup didn't work, falling back to kp.load_file({key_field})")
                kp.load_file(key_field)
                sha1 = ka.get_file_hash(key_field)
            else:
                print(f"Error on ka.get_file_hash({reformed_field}) -- aborting")
                exit()
    else:
        # sha1 = '/'.join(raw.split('/')[2:])
        sha1 = raw.split('/')[2]
    print(f'Got sha1: {sha1}')
    src_path = f'/mnt/ceph/users/magland/kachery-storage/sha1/{sha1[0]}{sha1[1]}/{sha1[2]}{sha1[3]}/{sha1[4]}{sha1[5]}/{sha1}'
    dest_path = f'/mnt/ceph/users/jsoules/kachery-storage/sha1/{sha1[0]}{sha1[1]}/{sha1[2]}{sha1[3]}/{sha1[4]}{sha1[5]}/{sha1}'
    if VERBOSE:
        print(f"Executing: shutil.copyfile({src_path}, {dest_path})")
    if not exists(dest_path):
        pathlib.Path('/'.join(dest_path.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
        copyfile(src_path, dest_path)
        print("\tCompleted copy operation.")
    object[thekey['key_field']] = kp.store_file(
        kp.load_file(object[thekey['key_field']]),
        basename=f"{thekey['basename']}"
    )

def get_sorting_unit_info(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get('channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     if np.isnan(sampling_frequency):
    #         print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
    #         sampling_frequency = 30000
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     print(unit_waveforms_channel_ids)
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(h5_path, unit_id)
    channel_locations_2 = []
    for ch_id in unit_waveforms_channel_ids:
        ind = np.where(unit_waveforms_channel_ids == ch_id)[0]
        channel_locations_2.append(channel_locations0[ind].ravel().tolist())
    return dict(
        channel_ids=unit_waveforms_channel_ids.astype(np.int32),
        channel_locations=channel_locations_2,
        sampling_frequency=sampling_frequency
    )

def get_sorting_unit_snippets(snippets_h5, unit_id, time_range, max_num_snippets):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get('channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     if np.isnan(sampling_frequency):
    #         print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
    #         sampling_frequency = 30000
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     print(unit_waveforms_channel_ids)
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(h5_path, unit_id)
    snippets = [
        {
            'index': j,
            'unitId': unit_id,
            'waveform': unit_waveforms[j].astype(np.float32),
            'timepoint': float(unit_spike_train[j])
        }
        for j in range(unit_waveforms.shape[0])
        if time_range['min'] <= unit_spike_train[j] < time_range['max']
    ]
    return dict(
        channel_ids=unit_waveforms_channel_ids.astype(np.int32),
        channel_locations=channel_locations0.astype(np.float32),
        sampling_frequency=sampling_frequency,
        snippets=snippets[:max_num_snippets]
    )

def fetch_spike_amplitudes(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(h5_path, unit_id)
    average_waveform = np.mean(unit_waveforms, axis=0)
    peak_channel_index = _compute_peak_channel_index_from_average_waveform(average_waveform)
    maxs = [np.max(unit_waveforms[i][peak_channel_index, :]) for i in range(unit_waveforms.shape[0])]
    mins = [np.min(unit_waveforms[i][peak_channel_index, :]) for i in range(unit_waveforms.shape[0])]
    peak_amplitudes = np.array([maxs[i] - mins[i] for i in range(len(mins))])
    timepoints = unit_spike_train.astype(np.float32)
    amplitudes = peak_amplitudes.astype(np.float32)
    sort_inds = np.argsort(timepoints)
    timepoints = timepoints[sort_inds]
    amplitudes = amplitudes[sort_inds]
    return dict(timepoints=timepoints, amplitudes=amplitudes)

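# --- Self-contained sketch (synthetic data; added for illustration) of the
# peak-to-peak amplitude computation used in fetch_spike_amplitudes above:
# max minus min on the peak channel, one value per snippet.
def _example_peak_to_peak_amplitudes():
    snippets = np.random.randn(20, 8, 30)  # n snippets x M channels x T timepoints
    peak_channel_index = 3  # assumed peak channel for this sketch
    maxs = snippets[:, peak_channel_index, :].max(axis=1)
    mins = snippets[:, peak_channel_index, :].min(axis=1)
    return maxs - mins  # peak-to-peak amplitude per snippet
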
def cat_file(uri, start, end, exp_nop2p, exp_file_server_url):
    old_stdout = sys.stdout
    sys.stdout = sys.stderr
    kp._experimental_config(nop2p=exp_nop2p, file_server_urls=list(exp_file_server_url))
    if start is None and end is None:
        path1 = kp.load_file(uri)
        if not path1:
            raise Exception('Error loading file for cat.')
        sys.stdout = old_stdout
        with open(path1, 'rb') as f:
            while True:
                data = os.read(f.fileno(), 4096)
                if len(data) == 0:
                    break
                os.write(sys.stdout.fileno(), data)
    else:
        assert start is not None and end is not None
        start = int(start)
        end = int(end)
        assert start <= end
        if start == end:
            return
        sys.stdout = old_stdout
        kp.load_bytes(uri=uri, start=start, end=end, write_to_stdout=True)

def fetch_spike_waveforms(snippets_h5, unit_ids, spike_indices):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    spikes = []
    with h5py.File(h5_path, 'r') as f:
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
            sampling_frequency = 30000
        for ii, unit_id in enumerate(unit_ids):
            unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
            unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            unit_waveforms_spike_train = np.array(f.get(f'unit_waveforms/{unit_id}/spike_train'))
            average_waveform = np.mean(unit_waveforms, axis=0)
            channel_maximums = np.max(np.abs(average_waveform), axis=1)
            maxchan_index = np.argmax(channel_maximums)
            maxchan_id = unit_waveforms_channel_ids[maxchan_index]
            for spike_index in spike_indices[ii]:
                spikes.append(dict(
                    unit_id=unit_id,
                    spike_index=spike_index,
                    spike_time=unit_waveforms_spike_train[spike_index],
                    channel_id=maxchan_id,
                    waveform=unit_waveforms[spike_index, maxchan_index, :].squeeze().tolist()
                ))
    return {'sampling_frequency': sampling_frequency, 'spikes': spikes}

def fetch_average_waveform_plot_data(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
            sampling_frequency = 30000
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(unit_waveforms_channel_ids)
    average_waveform = np.mean(unit_waveforms, axis=0)
    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = unit_waveforms_channel_ids[maxchan_index]
    return dict(
        channel_id=int(maxchan_id),
        sampling_frequency=sampling_frequency,
        average_waveform=average_waveform[maxchan_index, :].astype(np.float32)
    )

def _try_mda_create_object(arg: Union[str, dict]) -> Union[None, dict]:
    if isinstance(arg, str):
        path = arg
        if path.startswith('sha1dir') or path.startswith('/'):
            dd = kp.read_dir(path)
            if dd is not None:
                if 'raw.mda' in dd['files'] and 'params.json' in dd['files'] and 'geom.csv' in dd['files']:
                    raw_path = path + '/raw.mda'
                    params_path = path + '/params.json'
                    geom_path = path + '/geom.csv'
                    geom_path_resolved = kp.load_file(geom_path)
                    assert geom_path_resolved is not None, f'Unable to load geom.csv from: {geom_path}'
                    params = kp.load_object(params_path)
                    assert params is not None, f'Unable to load params.json from: {params_path}'
                    geom = _load_geom_from_csv(geom_path_resolved)
                    return dict(recording_format='mda', data=dict(raw=raw_path, geom=geom, params=params))
    if isinstance(arg, dict):
        if ('raw' in arg) and ('geom' in arg) and ('params' in arg) and (type(arg['geom']) == list) and (type(arg['params']) == dict):
            return dict(recording_format='mda', data=dict(raw=arg['raw'], geom=arg['geom'], params=arg['params']))
    return None

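# --- Hedged sketch (added for illustration): the two argument shapes that
# _try_mda_create_object above accepts. URIs and values are placeholders.
def _example_mda_recording_args():
    # A directory URI whose listing contains raw.mda, params.json and geom.csv:
    arg1 = 'sha1dir://<hash>.some-recording'
    # Or an explicit dict with the raw timeseries, geometry and params inline:
    arg2 = dict(
        raw='sha1://<hash>/raw.mda',
        geom=[[0.0, 0.0], [0.0, 20.0]],  # one [x, y] entry per channel
        params=dict(samplerate=30000)  # assumed sampling rate
    )
    return _try_mda_create_object(arg1), _try_mda_create_object(arg2)
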
def __init__(self, *, snippets_h5_uri: str, p2p: bool = False):
    # Note: this extractor exposes unit ids and calls set_sampling_frequency,
    # which is SortingExtractor API, so the base init is SortingExtractor.
    se.SortingExtractor.__init__(self)
    snippets_h5_path = kp.load_file(snippets_h5_uri, p2p=p2p)
    self._snippets_h5_path: str = snippets_h5_path
    channel_ids_set: Set[int] = set()
    max_timepoint: int = 0
    with h5py.File(self._snippets_h5_path, 'r') as f:
        sampling_frequency: float = np.array(f.get('sampling_frequency'))[0]
        if np.isnan(sampling_frequency):
            print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
            sampling_frequency = 30000
        self.set_sampling_frequency(sampling_frequency)
        self._unit_ids: List[int] = np.array(f.get('unit_ids')).astype(int).tolist()
        for unit_id in self._unit_ids:
            unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
            max_timepoint = int(max(max_timepoint, np.max(unit_spike_train)))
            # unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
            unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            for id in unit_waveforms_channel_ids:
                channel_ids_set.add(int(id))
    self._channel_ids: List[int] = sorted(list(channel_ids_set))
    self._num_frames: int = max_timepoint + 1

def fetch_pca_features(snippets_h5, unit_ids):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    with h5py.File(h5_path, 'r') as f:
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
            sampling_frequency = 30000
        x = [
            dict(
                unit_id=unit_id,
                unit_waveforms_spike_train=np.array(f.get(f'unit_waveforms/{unit_id}/spike_train')),
                # unit_waveforms_spike_train=_subsample(np.array(f.get(f'unit_spike_trains/{unit_id}')), 1000),
                unit_waveforms=np.array(f.get(f'unit_waveforms/{unit_id}/waveforms')),
                unit_waveforms_channel_ids=np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            )
            for unit_id in unit_ids
        ]
    channel_ids = _intersect_channel_ids([a['unit_waveforms_channel_ids'] for a in x])
    assert len(channel_ids) > 0, 'No channel ids in intersection'
    for a in x:
        unit_waveforms = a['unit_waveforms']
        unit_waveforms_channel_ids = a['unit_waveforms_channel_ids']
        inds = [np.where(unit_waveforms_channel_ids == ch_id)[0][0] for ch_id in channel_ids]
        a['unit_waveforms_2'] = unit_waveforms[:, inds, :]
        a['labels'] = np.ones((unit_waveforms.shape[0],)) * a['unit_id']
    unit_waveforms = np.concatenate([a['unit_waveforms_2'] for a in x], axis=0)
    spike_train = np.concatenate([a['unit_waveforms_spike_train'] for a in x])
    labels = np.concatenate([a['labels'] for a in x]).astype(int)

    from sklearn.decomposition import PCA
    nf = 5  # number of features
    W = unit_waveforms  # ntot x M x T
    X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))  # ntot x MT
    pca = PCA(n_components=nf)
    pca.fit(X)
    features = pca.transform(X)  # ntot x nf
    return dict(
        times=(spike_train / sampling_frequency).tolist(),
        features=[features[:, ii].squeeze().tolist() for ii in range(nf)],
        labels=labels.tolist()
    )

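# --- Minimal, self-contained sketch (synthetic data; added for illustration)
# of the PCA step used in fetch_pca_features above: flatten L x M x T
# snippets to L x (M*T) rows and project onto the first nf components.
def _example_pca_on_snippets():
    from sklearn.decomposition import PCA
    L, M, T, nf = 50, 4, 30, 5
    W = np.random.randn(L, M, T)  # L snippets, M channels, T timepoints
    X = W.reshape((L, M * T))  # one flattened row per snippet
    pca = PCA(n_components=nf)
    pca.fit(X)
    return pca.transform(X)  # L x nf feature matrix
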
def _keep_good_units(sorting_obj, cluster_groups_csv_uri):
    sorting = LabboxEphysSortingExtractor(sorting_obj)
    df = pd.read_csv(kp.load_file(cluster_groups_csv_uri), delimiter='\t')
    df_good = df.loc[df['group'] == 'good']
    good_unit_ids = df_good['cluster_id'].to_numpy().tolist()
    sorting_good = se.SubSortingExtractor(parent_sorting=sorting, unit_ids=good_unit_ids)
    return _create_npy1_sorting_object(sorting=sorting_good)

def create_sorting_object_from_spikeforest_recdir(recdir, label):
    params = kp.load_object(recdir + '/params.json')
    firings_path = kp.load_file(recdir + '/firings_true.mda')
    firings_path = ka.store_file(firings_path, basename=label + '-firings.mda')
    sorting_object = dict(
        sorting_format='mda',
        data=dict(
            firings=firings_path,
            samplerate=params['samplerate']
        )
    )
    print(sorting_object)
    return sorting_object

def load_chanmap_data_from_mat(uri_mat):
    m = sio.loadmat(kp.load_file(uri_mat))
    chanmap = m['chanMap0ind'].squeeze()
    xcoords = m['xcoords'].squeeze()
    ycoords = m['ycoords'].squeeze()
    num_chan = len(chanmap)
    assert len(xcoords) == num_chan
    assert len(ycoords) == num_chan
    return chanmap, xcoords, ycoords

def __init__(self, firings_file, samplerate):
    SortingExtractor.__init__(self)
    self._firings_path = kp.load_file(firings_file)
    if not self._firings_path:
        raise Exception('Unable to load firings file: ' + firings_file)
    self._firings = readmda(self._firings_path)
    self._sampling_frequency = samplerate
    self._times = self._firings[1, :]
    self._labels = self._firings[2, :]
    self._unit_ids = np.unique(self._labels).astype(int)

def __init__(self, arg, samplerate=None):
    super().__init__()
    if (isinstance(arg, dict)) and ('sorting_format' in arg):
        obj = dict(arg)
    else:
        obj = _create_object_for_arg(arg, samplerate=samplerate)
        assert obj is not None, f'Unable to create sorting from arg: {arg}'
    self._object: dict = obj
    sorting_format = self._object['sorting_format']
    data: dict = self._object['data']
    if sorting_format == 'mda':
        firings_path = kp.load_file(data['firings'])
        assert firings_path is not None, f'Unable to load firings file: {data["firings"]}'
        self._sorting: se.SortingExtractor = MdaSortingExtractor(firings_file=firings_path, samplerate=data['samplerate'])
    elif sorting_format == 'h5_v1':
        h5_path = kp.load_file(data['h5_path'])
        self._sorting = H5SortingExtractorV1(h5_path=h5_path)
    elif sorting_format == 'npy1':
        times_npy = kp.load_npy(data['times_npy_uri'])
        labels_npy = kp.load_npy(data['labels_npy_uri'])
        samplerate = data['samplerate']
        S = se.NumpySortingExtractor()
        S.set_sampling_frequency(samplerate)
        S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
        self._sorting = S
    elif sorting_format == 'npy2':
        npz = kp.load_npy(data['npz_uri'])
        times_npy = npz['spike_indexes']
        labels_npy = npz['spike_labels']
        samplerate = float(npz['sampling_frequency'])
        S = se.NumpySortingExtractor()
        S.set_sampling_frequency(samplerate)
        S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
        self._sorting = S
    else:
        raise Exception(f'Unexpected sorting format: {sorting_format}')
    self.copy_unit_properties(sorting=self._sorting)

def get_unit_snrs(snippets_h5):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    ret = {}
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        for unit_id in unit_ids:
            unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))  # n x M x T
            ret[str(unit_id)] = _compute_unit_snr_from_waveforms(unit_waveforms)
    return ret

def fetch_average_waveform_2(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(h5_path, unit_id)
    average_waveform = np.mean(unit_waveforms, axis=0)
    return dict(
        average_waveform=average_waveform.astype(np.float32),
        channel_ids=unit_waveforms_channel_ids.astype(np.int32),
        channel_locations=channel_locations0.astype(np.float32),
        sampling_frequency=sampling_frequency
    )

def __init__(self, arg: Union[str, dict], download: bool = False):
    super().__init__()
    obj = _create_object_for_arg(arg)
    assert obj is not None
    self._object: dict = obj
    recording_format = self._object['recording_format']
    data: dict = self._object['data']
    if recording_format == 'mda':
        self._recording: se.RecordingExtractor = MdaRecordingExtractor(timeseries_path=data['raw'], samplerate=data['params']['samplerate'], geom=np.array(data['geom']), download=download)
    elif recording_format == 'nrs':
        self._recording: se.RecordingExtractor = NrsRecordingExtractor(**data)
    elif recording_format == 'nwb':
        path0 = kp.load_file(data['path'])
        self._recording: se.RecordingExtractor = NwbRecordingExtractor(path0, electrical_series_name='e-series')
    elif recording_format == 'bin1':
        self._recording: se.RecordingExtractor = Bin1RecordingExtractor(**data, p2p=True, download=download)
    elif recording_format == 'subrecording':
        R = LabboxEphysRecordingExtractor(data['recording'], download=download)
        if 'channel_ids' in data:
            channel_ids = np.array(data['channel_ids'])
        elif 'group' in data:
            channel_ids = np.array(R.get_channel_ids())
            groups = R.get_channel_groups(channel_ids=R.get_channel_ids())
            group = int(data['group'])
            inds = np.where(np.array(groups) == group)[0]
            channel_ids = channel_ids[inds]
        elif 'groups' in data:
            raise Exception('This case not yet handled.')
        else:
            channel_ids = None
        if 'start_frame' in data:
            start_frame = data['start_frame']
            end_frame = data['end_frame']
        else:
            start_frame = None
            end_frame = None
        self._recording: se.RecordingExtractor = se.SubRecordingExtractor(
            parent_recording=R,
            channel_ids=channel_ids,
            start_frame=start_frame,
            end_frame=end_frame
        )
    elif recording_format == 'filtered':
        R = LabboxEphysRecordingExtractor(data['recording'], download=download)
        self._recording: se.RecordingExtractor = _apply_filters(recording=R, filters=data['filters'])
    else:
        raise Exception(f'Unexpected recording format: {recording_format}')
    self.copy_channel_properties(recording=self._recording)

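# --- Hedged usage sketch (added for illustration; not part of the original
# module). Wraps an existing recording object in a 'subrecording' to select
# one channel group and a frame range, matching the branch handled above.
# The group index and frame values are placeholders.
def _example_subrecording(recording_object):
    sub_object = dict(
        recording_format='subrecording',
        data=dict(
            recording=recording_object,
            group=0,  # keep only channels whose 'group' property equals 0
            start_frame=0,
            end_frame=30000  # first second at an assumed 30 kHz
        )
    )
    return LabboxEphysRecordingExtractor(sub_object, download=False)
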
def _try_mda_create_object(arg: Union[str, dict], samplerate=None) -> Union[None, dict]:
    if isinstance(arg, str):
        path = arg
        if not kp.load_file(path):
            return None
        return dict(sorting_format='mda', data=dict(firings=path, samplerate=samplerate))
    if isinstance(arg, dict):
        if 'firings' in arg:
            return dict(sorting_format='mda', data=dict(firings=arg['firings'], samplerate=arg.get('samplerate', None)))
    return None

def get_peak_channels(snippets_h5):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    ret = {}
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        for unit_id in unit_ids:
            unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))  # n x M x T
            channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))  # M
            peak_channel_index = _compute_peak_channel_index_from_waveforms(unit_waveforms)
            ret[str(unit_id)] = int(channel_ids[peak_channel_index])
    return ret

def __init__(self, *, snippets_h5_uri: str, p2p: bool = False):
    se.RecordingExtractor.__init__(self)
    snippets_h5_path = kp.load_file(snippets_h5_uri, p2p=p2p)
    self._snippets_h5_path: str = snippets_h5_path
    channel_ids_set: Set[int] = set()
    max_timepoint: int = 0
    with h5py.File(self._snippets_h5_path, 'r') as f:
        self._sampling_frequency: float = np.array(f.get('sampling_frequency'))[0]
        if np.isnan(self._sampling_frequency):
            print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
            self._sampling_frequency = 30000
        self._unit_ids: List[int] = np.array(f.get('unit_ids')).astype(int).tolist()
        for unit_id in self._unit_ids:
            unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
            max_timepoint = int(max(max_timepoint, np.max(unit_spike_train)))
            # unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
            unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            for id in unit_waveforms_channel_ids:
                channel_ids_set.add(int(id))
        self._channel_ids: List[int] = sorted(list(channel_ids_set))
        try:
            self._num_frames = f.get('num_frames')[0].item()
        except:
            print('Unable to load num_frames. Please update snippets file.')
            self._num_frames: int = max_timepoint + 1
        try:
            channel_locations = np.array(f.get('channel_locations'))
            self.set_channel_locations(channel_locations)
        except:
            print('WARNING: using [0, 0] for channel locations. Please adjust snippets file')
            for channel_id in self._channel_ids:
                self.set_channel_property(channel_id, 'location', [0, 0])

def individual_cluster_features(snippets_h5, unit_id, max_num_events=1000):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get('channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))  # L x M x T
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     if len(unit_spike_train) > max_num_events:
    #         inds = subsample_inds(len(unit_spike_train), max_num_events)
    #         unit_spike_train = unit_spike_train[inds]
    #         unit_waveforms = unit_waveforms[inds]
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(h5_path, unit_id, max_num_events=max_num_events)

    from sklearn.decomposition import PCA
    nf = 2  # number of features
    # L = number of waveforms (number of spikes)
    # M = number of electrodes in nbhd
    # T = num. timepoints in the snippet
    W = unit_waveforms  # L x M x T
    # subtract the mean for each channel and waveform
    for i in range(W.shape[0]):
        for m in range(W.shape[1]):
            W[i, m, :] = W[i, m, :] - np.mean(W[i, m, :])
    X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))  # L x MT
    pca = PCA(n_components=nf)
    pca.fit(X)
    features = pca.transform(X)  # L x nf
    return dict(
        timepoints=unit_spike_train.astype(np.float32),
        x=features[:, 0].squeeze().astype(np.float32),
        y=features[:, 1].squeeze().astype(np.float32)
    )

def test1():
    f = kp.create_feed('f1')
    f2 = kp.load_feed('f1')
    assert f.get_uri() == f2.get_uri()
    sf = f.get_subfeed('sf1')
    sf.append_message({'m': 1})
    assert sf.get_num_messages() == 1
    x = kp.store_text('abc')
    sf.set_access_rules({'rules': []})
    r = sf.get_access_rules()
    try:
        a = kp.load_file('sha1://e25f95079381fe07651aa7d37c2f4e8bda19727c/file.txt')
        raise Exception('Did not get expected error')
    except LoadFileError as err:
        pass  # expected
    except Exception as err:
        raise err

def _download_files_in_item(x):
    if type(x) == str:
        if x.startswith('sha1://') or x.startswith('sha1dir://'):
            if not ka.get_file_info(x, fr=dict(url=None)):
                a = kp.load_file(x)
                assert a is not None, f'Unable to download file: {x}'
        return
    elif type(x) == dict:
        for _, val in x.items():
            _download_files_in_item(val)
        return
    elif type(x) == list:
        for y in x:
            _download_files_in_item(y)
        return
    elif type(x) == tuple:
        for y in x:
            _download_files_in_item(y)
        return
    else:
        return

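# --- Hedged usage sketch (added for illustration): _download_files_in_item
# above recurses through nested dicts/lists/tuples and loads any sha1:// or
# sha1dir:// string it encounters. The URIs below are placeholders.
def _example_download_nested():
    item = dict(
        raw='sha1://<hash>/raw.mda',
        extras=[dict(geom='sha1://<hash>/geom.csv')]
    )
    _download_files_in_item(item)
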
def readmda(path):
    if file_extension(path) == '.npy':
        return readnpy(path)
    path = kp.load_file(path)
    H = _read_header(path)
    if H is None:
        print("Problem reading header of: {}".format(path))
        return None
    ret = np.array([])
    f = open(path, "rb")
    try:
        f.seek(H.header_size)
        # This is how I do the column-major order
        ret = np.fromfile(f, dtype=H.dt, count=H.dimprod)
        ret = np.reshape(ret, H.dims, order='F')
        f.close()
        return ret
    except Exception as e:  # catch *all* exceptions
        print(e)
        f.close()
        return None

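# --- Hedged usage sketch (added for illustration): reading an .mda array via
# readmda above. The URI is a placeholder; kp.load_file resolves it to a
# local path inside readmda itself.
def _example_readmda():
    X = readmda('sha1://<hash>/raw.mda')  # ndarray on success, None on error
    if X is not None:
        print(X.shape)
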
def load_info_from_mat(uri_mat):
    m = sio.loadmat(kp.load_file(uri_mat))
    spike_times = m['spikeTimes'].squeeze()
    spike_labels = m['spikeClusters'].squeeze()
    cluster_notes = m['clusterNotes'].squeeze()
    samplerate = m['SampleRate'][0][0]
    siteMap = m['siteMap'].squeeze()
    xcoords = m['xcoords'].squeeze()
    ycoords = m['ycoords'].squeeze()
    chanmap = siteMap - 1
    xcoords = xcoords[chanmap]
    ycoords = ycoords[chanmap]
    num_chan = len(chanmap)
    assert len(xcoords) == num_chan
    assert len(ycoords) == num_chan
    unit_notes = {}
    for j in range(len(cluster_notes)):
        notes = [note for note in cluster_notes[j] if isinstance(note, str)]
        if len(notes) > 0:
            unit_notes[j + 1] = notes
    return samplerate, chanmap, xcoords, ycoords, spike_times, spike_labels, unit_notes