import json
from typing import Union

import numpy as np
import hither as hi
import kachery as ka
import kachery_p2p as kp

# load_info_from_mat, LabboxEphysRecordingExtractor, _get_geom_from_recording,
# _concatenate and _json_serialize are assumed to be defined elsewhere in this repository.


def prepare_recording(*, bin_uri, bin_file_size, raw_num_channels, mat_uri, meta_uri, single_only):
    samplerate, chanmap, xcoords, ycoords, spike_times, spike_labels, unit_notes = load_info_from_mat(mat_uri)

    # Exclude clusters 0 and -1 (noise / unassigned events)
    spike_inds = np.where(spike_labels > 0)[0]
    spike_times = spike_times[spike_inds]
    spike_labels = spike_labels[spike_inds]

    if single_only:
        # Keep only events belonging to units annotated as single units
        okay_to_use = np.zeros((len(spike_times),))
        for unit_id, notes in unit_notes.items():
            if 'single' in notes:
                okay_to_use[np.where(spike_labels == unit_id)[0]] = 1
        spike_inds = np.where(okay_to_use)[0]
        print(f'Using {len(spike_inds)} of {len(spike_times)} events (single units only)')
        spike_times = spike_times[spike_inds]
        spike_labels = spike_labels[spike_inds]

    # Store the curated spike times and labels in kachery and build the npy1 sorting object
    times_npy_uri = ka.store_npy(spike_times)
    labels_npy_uri = ka.store_npy(spike_labels)
    sorting_object = dict(
        sorting_format='npy1',
        data=dict(
            times_npy_uri=times_npy_uri,
            labels_npy_uri=labels_npy_uri,
            samplerate=samplerate
        )
    )

    # The raw binary file contains int16 samples: 2 bytes per channel per frame
    num_frames = bin_file_size / (raw_num_channels * 2)
    assert num_frames == int(num_frames), 'bin file size is not a multiple of raw_num_channels * 2'
    num_frames = int(num_frames)
    print(f'Number of frames: {num_frames}')

    meta_lines = kp.load_text(meta_uri).split('\n')  # perhaps use in future

    num_channels = len(chanmap)
    print(f'Number of channels: {num_channels}')
    channel_ids = [int(i) for i in range(num_channels)]
    channel_map = dict(zip(
        [str(c) for c in channel_ids],
        [int(chanmap[i]) for i in range(num_channels)]
    ))
    channel_positions = dict(zip(
        [str(c) for c in channel_ids],
        [[float(xcoords[i]), float(ycoords[i])] for i in range(num_channels)]
    ))
    recording_object = dict(
        recording_format='bin1',
        data=dict(
            raw=bin_uri,
            raw_num_channels=raw_num_channels,
            num_frames=num_frames,
            samplerate=samplerate,
            channel_ids=channel_ids,
            channel_map=channel_map,
            channel_positions=channel_positions
        )
    )
    return recording_object, sorting_object, unit_notes

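# A minimal usage sketch for prepare_recording above; the sha1:// URIs and file size
# below are hypothetical placeholders for files already stored in kachery.
def _example_prepare_recording():
    recording_object, sorting_object, unit_notes = prepare_recording(
        bin_uri='sha1://<raw-bin-hash>/raw.bin',      # hypothetical placeholder URI
        bin_file_size=768000000,                      # size of the raw .bin file in bytes
        raw_num_channels=384,
        mat_uri='sha1://<mat-hash>/curation.mat',     # hypothetical placeholder URI
        meta_uri='sha1://<meta-hash>/raw.meta',       # hypothetical placeholder URI
        single_only=True
    )
    return recording_object, sorting_object, unit_notes
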
def _create_npy1_sorting_object(*, sorting):
    # Concatenate the spike trains of all units into a single pair of
    # (times, labels) arrays, sorted by time, and store them in kachery.
    unit_ids = sorting.get_unit_ids()
    spike_trains = [
        np.array(sorting.get_unit_spike_train(unit_id=unit_id)).squeeze()
        for unit_id in unit_ids
    ]
    spike_labels = [
        unit_id * np.ones((len(spike_trains[ii]),))
        for ii, unit_id in enumerate(unit_ids)
    ]
    all_times = np.concatenate(spike_trains)
    all_labels = np.concatenate(spike_labels)
    sort_inds = np.argsort(all_times)
    all_times = all_times[sort_inds]
    all_labels = all_labels[sort_inds]
    return dict(
        sorting_format='npy1',
        data=dict(
            times_npy_uri=ka.store_npy(all_times),
            labels_npy_uri=ka.store_npy(all_labels),
            samplerate=sorting.get_sampling_frequency()
        )
    )

def _create_sorting_object(sorting):
    # Same idea as _create_npy1_sorting_object, but with a hard-coded sample rate.
    unit_ids = sorting.get_unit_ids()
    times_list = []
    labels_list = []
    for unit_id in unit_ids:
        times = sorting.get_unit_spike_train(unit_id=unit_id)
        times_list.append(times)
        labels_list.append(np.ones(times.shape) * unit_id)
    all_times = np.concatenate(times_list)
    all_labels = np.concatenate(labels_list)
    sort_inds = np.argsort(all_times)
    all_times = all_times[sort_inds]
    all_labels = all_labels[sort_inds]
    times_npy_uri = ka.store_npy(all_times)
    labels_npy_uri = ka.store_npy(all_labels)
    return dict(
        sorting_format='npy1',
        data=dict(
            times_npy_uri=times_npy_uri,
            labels_npy_uri=labels_npy_uri,
            samplerate=30000  # note: sample rate is hard-coded here
        )
    )

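# A minimal usage sketch for _create_sorting_object, assuming spikeextractors is
# available and that the function accepts any SortingExtractor-like object; the
# example spike data below is hypothetical.
def _example_create_sorting_object():
    import spikeextractors as se
    sorting = se.NumpySortingExtractor()
    times = np.array([100, 250, 400, 900], dtype=np.float64)  # spike times in frames
    labels = np.array([1, 2, 1, 2])                           # unit id of each spike
    sorting.set_times_labels(times=times, labels=labels)
    sorting.set_sampling_frequency(30000)
    sorting_object = _create_sorting_object(sorting)
    print(sorting_object['sorting_format'])  # 'npy1'
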
def mountainsort4(*, recording_object, detect_sign=-1, clip_size=50, adjacency_radius=-1, detect_threshold=3, detect_interval=10, num_workers=None, verbose=True):
    from ml_ms4alg.mountainsort4 import MountainSort4

    # Wrap the recording object in a recording extractor and configure MountainSort4
    recording = LabboxEphysRecordingExtractor(recording_object)
    MS4 = MountainSort4()
    MS4.setRecording(recording)
    geom = _get_geom_from_recording(recording)
    MS4.setGeom(geom)
    MS4.setSortingOpts(
        clip_size=clip_size,
        adjacency_radius=adjacency_radius,
        detect_sign=detect_sign,
        detect_interval=detect_interval,
        detect_threshold=detect_threshold,
        verbose=verbose
    )
    if num_workers is not None:
        MS4.setNumWorkers(num_workers)

    # Run the sort in a temporary directory and collect the detected events
    with hi.TemporaryDirectory() as tmpdir:
        MS4.setTemporaryDirectory(tmpdir)
        MS4.sort()
        times, labels, channels = MS4.eventTimesLabelsChannels()
        sorting_object = {
            'sorting_format': 'npy1',
            'data': {
                'samplerate': recording.get_sampling_frequency(),
                'times_npy_uri': ka.store_npy(times.astype(np.float64)),
                'labels_npy_uri': ka.store_npy(labels.astype(np.int32))
            }
        }
        return sorting_object

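# A minimal sketch of running the mountainsort4 wrapper above on a bin1 recording
# object such as the one produced by prepare_recording; the parameter values here
# are illustrative, not recommended defaults.
def _example_run_mountainsort4(recording_object):
    return mountainsort4(
        recording_object=recording_object,
        detect_sign=-1,        # detect negative-going spike peaks
        adjacency_radius=50,   # only nearby channels are sorted together
        detect_threshold=3,
        num_workers=4
    )
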
def write_sorting(sorting, save_path, write_primary_channels=False):
    print('write sorting')
    unit_ids = sorting.get_unit_ids()
    times_list = []
    labels_list = []
    primary_channels_list = []
    for unit_id in unit_ids:
        times = sorting.get_unit_spike_train(unit_id=unit_id)
        times_list.append(times)
        labels_list.append(np.ones(times.shape) * unit_id)
        if write_primary_channels:
            if 'max_channel' in sorting.get_unit_property_names(unit_id):
                primary_channels_list.append(
                    [sorting.get_unit_property(unit_id, 'max_channel')] * times.shape[0]
                )
            else:
                raise ValueError(
                    "Unable to write primary channels because 'max_channel' spike feature not set in unit "
                    + str(unit_id)
                )
        else:
            primary_channels_list.append(np.zeros(times.shape))

    # Build the firings array: row 0 = primary channels, row 1 = times, row 2 = labels,
    # with events sorted by time
    all_times = _concatenate(times_list)
    all_labels = _concatenate(labels_list)
    all_primary_channels = _concatenate(primary_channels_list)
    sort_inds = np.argsort(all_times)
    all_times = all_times[sort_inds]
    all_labels = all_labels[sort_inds]
    all_primary_channels = all_primary_channels[sort_inds]
    L = len(all_times)
    firings = np.zeros((3, L))
    firings[0, :] = all_primary_channels
    firings[1, :] = all_times
    firings[2, :] = all_labels

    # Store the firings array in kachery and serialize the sorting object
    firings_path = ka.store_npy(array=firings, basename='firings.npy')
    sorting_obj = _json_serialize(dict(
        firings=firings_path,
        samplerate=sorting.get_sampling_frequency(),
        unit_ids=unit_ids,
    ))
    if save_path is not None:
        with open(save_path, 'w') as f:
            json.dump(sorting_obj, f, indent=4)
    print('-----------------', sorting_obj)
    return sorting_obj

def store_npy(array: np.ndarray, basename: Union[str, None] = None):
    # Thin wrapper around kachery's store_npy
    return ka.store_npy(array, basename=basename)

def _test_store_npy(val: np.ndarray):
    # Round-trip check: store the array in kachery, then load it back and compare
    x = ka.store_npy(val)
    assert x
    val2 = ka.load_npy(x)
    assert np.array_equal(val, val2)

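# A small sketch of using the store_npy wrapper above, assuming a kachery storage
# directory is configured locally; the basename is just a readable label for the URI.
def _example_store_npy_usage():
    spike_times = np.array([120.0, 340.0, 980.0])
    uri = store_npy(spike_times, basename='spike_times.npy')
    print(uri)  # a kachery URI pointing to the stored .npy file
    assert np.array_equal(ka.load_npy(uri), spike_times)
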