def test_get_performance():
    """Check the performance tables of a ground-truth comparison.

    Two scenarios are covered: a simple match where one ground-truth spike
    is missed, and two units firing at exactly the same sample times.
    """
    # --- scenario 1: simple match ---
    gt_sorting, tested_sorting = make_sorting(
        [100, 200, 300, 400], [0, 0, 1, 0],
        [101, 201, 301], [0, 0, 5])
    comparison = compare_sorter_to_ground_truth(
        gt_sorting, tested_sorting, exhaustive_gt=True, delta_time=0.3)

    raw_counts = comparison.get_performance('raw_count')
    # (unit id, column) -> expected count, checked in the listed order.
    expected_counts = (
        ((0, 'tp'), 2),
        ((1, 'tp'), 1),
        ((0, 'fn'), 1),
        ((1, 'fn'), 0),
        ((0, 'fp'), 0),
        ((1, 'fp'), 0),
    )
    for (unit_id, column), expected in expected_counts:
        assert raw_counts.loc[unit_id, column] == expected

    pooled = comparison.get_performance('pooled_with_average')
    assert pooled['miss_rate'] == 1 / 6

    by_unit = comparison.get_performance('by_unit')
    assert by_unit.loc[0, 'accuracy'] == 2 / 3.
    assert by_unit.loc[0, 'miss_rate'] == 1 / 3.

    # --- scenario 2: two units fire at the same time ---
    gt_sorting, tested_sorting = make_sorting(
        [100, 100, 200, 200, 300], [0, 1, 0, 1, 0],
        [100, 100, 200, 200, 300], [0, 1, 0, 1, 0],
    )
    comparison = compare_sorter_to_ground_truth(
        gt_sorting, tested_sorting, exhaustive_gt=True)

    raw_counts = comparison.get_performance('raw_count')
    for column, expected in (('tp', 3), ('fn', 0), ('fp', 0),
                             ('num_gt', 3), ('num_tested', 3)):
        assert raw_counts.loc[0, column] == expected

    pooled = comparison.get_performance('pooled_with_average')
    assert pooled['accuracy'] == 1.
    assert pooled['miss_rate'] == 0.
def _compare_one_sorter(sorter_name, sortings_pre, agg_sortings, gts, comparisons):
    """Print before/after-recovery performance for every recording of one sorter.

    Parameters
    ----------
    sorter_name : str
        Sorter whose entries are reported.
    sortings_pre : dict
        Only its keys are used; each key is indexed as
        ``(recording_name, sorter_name)``.
    agg_sortings : dict
        Sortings after recovery, indexed by the same keys.
    gts : dict
        Ground-truth sortings indexed by recording name (``key[0]``).
    comparisons : dict
        Comparisons computed before recovery, indexed by the same keys.
    """
    print(f'Performance comparisons for {sorter_name}')
    for key in sortings_pre:
        if key[1] != sorter_name:
            continue
        print(f'Recording name: {key[0]}')
        comparison_post = sc.compare_sorter_to_ground_truth(
            tested_sorting=agg_sortings[key], gt_sorting=gts[key[0]])
        # print_performance() writes the report itself; wrapping it in print()
        # additionally emitted its return value (a stray "None" line).
        print('Before recovery:')
        comparisons[key].print_performance()
        print('After recovery:')
        comparison_post.print_performance()
def setUp(self):
    """Build the fixtures shared by the tests from the MEArec test dataset.

    An earlier variant of this fixture used ``se.toy_example`` and a
    ``./toy_example`` waveform folder; the MEArec file is used instead.
    """
    dataset_path = download_dataset(remote_path='mearec/mearec_test_10s.h5')

    # The same file provides both the recording and the sorting.
    self._rec = se.MEArecRecordingExtractor(dataset_path)
    self._sorting = se.MEArecSortingExtractor(dataset_path)
    recording, sorting = self._rec, self._sorting

    self.num_units = len(sorting.get_unit_ids())

    # Waveforms are reloaded from './mearec_test' when that folder exists.
    self._we = extract_waveforms(recording, sorting, './mearec_test',
                                 load_if_exists=True)
    self._amplitudes = st.get_spike_amplitudes(self._we,
                                               peak_sign='neg',
                                               outputs='by_unit')

    # Comparing the sorting with itself.
    self._gt_comp = sc.compare_sorter_to_ground_truth(sorting, sorting)
def make_comparison_figures():
    """Render and save the four spike-comparison example figures.

    Saves the raw match-count matrix, the unordered and ordered agreement
    matrices and the confusion matrix as PNG files in the current working
    directory, then shows all figures.
    """
    gt_sorting, tested_sorting = generate_erroneous_sorting()

    comp = sc.compare_sorter_to_ground_truth(gt_sorting,
                                             tested_sorting,
                                             gt_name=None,
                                             tested_name=None,
                                             delta_time=0.4,
                                             sampling_frequency=None,
                                             min_accuracy=0.5,
                                             exhaustive_gt=True,
                                             match_mode='hungarian',
                                             n_jobs=-1,
                                             bad_redundant_threshold=0.2,
                                             compute_labels=False,
                                             verbose=False)

    print(comp.hungarian_match_12)

    # Figure 1: match counts drawn by hand with matshow.
    figure, axis = plt.subplots()
    image = axis.matshow(comp.match_event_count, cmap='Greens')
    axis.set_xticks(np.arange(0, comp.match_event_count.shape[1]))
    axis.set_yticks(np.arange(0, comp.match_event_count.shape[0]))
    axis.xaxis.tick_bottom()
    axis.set_yticklabels(comp.match_event_count.index, fontsize=12)
    axis.set_xticklabels(comp.match_event_count.columns, fontsize=12)
    figure.colorbar(image)
    figure.savefig('spikecomparison_match_count.png')

    # Figures 2-4: widget plots that all follow the same recipe.
    widget_figures = (
        (sw.plot_agreement_matrix, {'ordered': False},
         'spikecomparison_agreement_unordered.png'),
        (sw.plot_agreement_matrix, {}, 'spikecomparison_agreement.png'),
        (sw.plot_confusion_matrix, {}, 'spikecomparison_confusion.png'),
    )
    for plot_widget, extra_kwargs, filename in widget_figures:
        figure, axis = plt.subplots()
        plot_widget(comp, ax=axis, **extra_kwargs)
        # The widget draws a single image on the axis; reuse it for the colorbar.
        figure.colorbar(axis.get_images()[0])
        figure.savefig(filename)

    plt.show()
# Fetch the MEArec test file; the same file is opened once as a recording and
# once as the ground-truth sorting.
local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording = se.MEArecRecordingExtractor(local_path)
sorting_true = se.MEArecSortingExtractor(local_path)
print(recording)
print(sorting_true)

##############################################################################
# run herdingspikes on it

sorting_HS = ss.run_herdingspikes(recording)

##############################################################################
# Compare the HerdingSpikes output with the ground-truth sorting.

cmp_gt_HS = sc.compare_sorter_to_ground_truth(sorting_true, sorting_HS, exhaustive_gt=True)

##############################################################################
# To have an overview of the match we can use the unordered agreement matrix

sw.plot_agreement_matrix(cmp_gt_HS, ordered=False)

##############################################################################
# or ordered

sw.plot_agreement_matrix(cmp_gt_HS, ordered=True)

##############################################################################
# This function first matches the ground-truth and spike sorted units, and
# then it computes several performance metrics.
def test_compare_sorter_to_ground_truth():
    """Exercise ground-truth comparison matching, scoring and unit classification.

    The first part runs the comparison with both match modes and checks event
    counts, matches, ordered agreement scores and the count table. The second
    part uses a 'hungarian' comparison to check the well-detected,
    false-positive, redundant and bad unit queries and their counters.
    """
    # simple match
    gt_sorting, tested_sorting = make_sorting(
        [100, 200, 300, 400, 500, 600, 700],
        [0, 0, 1, 0, 1, 1, 1],
        [101, 201, 301, 302, 401, 501, 502, 601, 900],
        [0, 0, 5, 6, 0, 5, 6, 5, 11])

    methods = ['raw_count', 'by_unit', 'pooled_with_average']

    for match_mode in ('hungarian', 'best'):
        # Labels are only computed for the 'hungarian' mode.
        compute_labels = (match_mode == 'hungarian')
        sc = compare_sorter_to_ground_truth(gt_sorting,
                                            tested_sorting,
                                            exhaustive_gt=True,
                                            match_mode=match_mode,
                                            compute_labels=compute_labels)

        assert_array_equal(sc.event_counts1.values, [3, 4])
        assert_array_equal(sc.event_counts2.values, [3, 3, 2, 1])

        assert_array_equal(sc.possible_match_12[1], [5, 6])
        assert_array_equal(sc.best_match_12[1], 5)
        assert_array_equal(sc.hungarian_match_12[1], 5)

        # Re-indexing the raw scores with the ordered labels must reproduce
        # the ordered score table.
        scores = sc.agreement_scores
        ordered_scores = sc.get_ordered_agreement_scores()
        assert_array_equal(
            scores.loc[ordered_scores.index, ordered_scores.columns],
            ordered_scores)

        assert sc.count_score.at[0, 'tp'] == 3
        assert sc.count_score.at[1, 'tp'] == 3
        assert sc.count_score.at[1, 'fn'] == 1

        sc._do_confusion_matrix()

        # Smoke checks: every reporting path must run without raising.
        for method in methods:
            sc.get_performance(method=method)

        for method in methods:
            sc.print_performance(method=method)

        sc.print_summary()

    sc = compare_sorter_to_ground_truth(gt_sorting,
                                        tested_sorting,
                                        exhaustive_gt=True,
                                        match_mode='hungarian')

    # test well detected units depending on thresholds
    good_units = sc.get_well_detected_units()  # default threshold
    print(good_units)
    assert_array_equal(good_units, [0])
    good_units = sc.get_well_detected_units(well_detected_score=0.95)
    assert_array_equal(good_units, [0])
    good_units = sc.get_well_detected_units(well_detected_score=.6)
    assert_array_equal(good_units, [0, 5])
    # Disabled checks kept for reference (they use threshold keywords that
    # the calls above do not):
    # good_units = sc.get_well_detected_units(false_discovery_rate=0.05)
    # assert_array_equal(good_units, [0, 1])
    # good_units = sc.get_well_detected_units(accuracy=0.95, false_discovery_rate=.05)  # combine thresh
    # assert_array_equal(good_units, [0])

    # count
    num_ok = sc.count_well_detected_units(well_detected_score=0.95)
    assert num_ok == 1

    # false_positive_units [11]
    fpu_ids = sc.get_false_positive_units()
    assert_array_equal(fpu_ids, [11])
    num_fpu = sc.count_false_positive_units()
    assert num_fpu == 1

    # redundant_units [6]
    redundant_ids = sc.get_redundant_units()
    assert_array_equal(redundant_ids, [6])

    # bad_units [6, 11]
    bad_ids = sc.get_bad_units()
    assert_array_equal(bad_ids, [6, 11])
    # The counter was computed but never checked before.
    num_bad = sc.count_bad_units()
    assert num_bad == 2

    # bad units is union of false_positive_units + redundant_units
    fpu_ids = sc.get_false_positive_units()
    redundant_ids = sc.get_redundant_units()
    bad_ids = sc.get_bad_units()
    # list(...) makes "+" a concatenation whether the getters return lists or
    # arrays (array + array would add element-wise).
    assert_array_equal(bad_ids, sorted(list(fpu_ids) + list(redundant_ids)))
def _run_sorter_timed(run_sorter, recording, output_folder, **sorter_kwargs):
    """Run one sorter, then print the number of units found and the elapsed seconds."""
    start = time.time()
    sorting = run_sorter(recording,
                         output_folder=output_folder,
                         delete_output_folder=True,
                         **sorter_kwargs)
    print('Found', len(sorting.get_unit_ids()), 'units')
    print(time.time() - start)
    return sorting


def _save_sorting_to_nwb(recording_sub, sorting, nwb_name):
    """Write the sub-recording, then the sorting, into the NWB file ``nwb_name``."""
    se.NwbRecordingExtractor.write_recording(recording_sub, nwb_name)
    se.NwbSortingExtractor.write_sorting(sorting, nwb_name)


def _load_or_run_sorter(loading_message, nwb_name, existing_files, run_sorter,
                        output_folder, recording_cache, recording_sub,
                        load_kwargs=None, **sorter_kwargs):
    """Load a saved sorting from ``nwb_name`` if listed, else run the sorter and save it."""
    if nwb_name in existing_files:
        print(loading_message)
        return se.NwbSortingExtractor(nwb_name, **(load_kwargs or {}))
    sorting = _run_sorter_timed(run_sorter, recording_cache, output_folder,
                                **sorter_kwargs)
    _save_sorting_to_nwb(recording_sub, sorting, nwb_name)
    return sorting


def _read_noise_cluster_ids(cluster_group_path):
    """Return the cluster ids labelled 'noise' in a phy ``cluster_group.tsv`` file."""
    noise_ids = []
    with open(cluster_group_path) as fd:
        for row in csv.reader(fd, delimiter="\t", quotechar='"'):
            # Skip blank/short rows instead of failing on row[1].
            if len(row) > 1 and row[1] == 'noise':
                noise_ids.append(int(row[0]))
    return noise_ids


def manual(recording_folder):
    """Run each sorter on one tetrode folder and curate its output manually in phy.

    For every sorter label (KL, IC, HS, MS4, SC, TRI) that has no ``phy_<label>``
    folder yet, the sorting is loaded from its NWB file or computed, exported
    to phy, opened in the phy GUI for manual curation, compared with the
    sorting stored in ``phy_AGR/``, and finally binned into an activation
    matrix that is saved with the matching unit ids.

    Side effects: changes the working directory to ``recording_folder`` and
    writes NWB, phy, PDF/PNG and ``.npy`` files there.

    Parameters
    ----------
    recording_folder : str
        Folder with the tetrode data (Open Ephys files and ``tetrode.prb``).
        Its name is split on ``'_'`` to label messages and output files.
    """
    sorter_labels = ['KL', 'IC', 'HS', 'MS4', 'SC', 'TRI']

    os.chdir(recording_folder)

    # IronClust is Matlab-based and has to be told where its sources live.
    iron_path = "~/Documents/SpikeSorting/ironclust"
    ss.IronClustSorter.set_ironclust_path(os.path.expanduser(iron_path))

    # If every sorter has already been curated there is nothing left to do.
    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]
    if all('phy_' + label in subfolders for label in sorter_labels):
        print('Tetrode ' + recording_folder.split('_')[-1] +
              ' was previously manually sorted. Skipping')
        return

    # Files already present; used to decide between loading and running.
    arr = os.listdir()

    # Load .continuous files
    recording = se.OpenEphysRecordingExtractor(recording_folder)
    channel_ids = recording.get_channel_ids()
    fs = recording.get_sampling_frequency()
    num_chan = recording.get_num_channels()

    print('Channel ids:', channel_ids)
    print('Sampling frequency:', fs)
    print('Number of channels:', num_chan)

    recording_prb = recording.load_probe_file(os.getcwd() + '/tetrode.prb')
    print('Channels after loading the probe file:',
          recording_prb.get_channel_ids())
    print('Channel groups after loading the probe file:',
          recording_prb.get_channel_groups())

    # Bandpass filter
    recording_cmr = st.preprocessing.bandpass_filter(recording_prb,
                                                     freq_min=300,
                                                     freq_max=6000)
    recording_cache = se.CacheRecordingExtractor(recording_cmr)

    print(recording_cache.get_channel_ids())
    print(recording_cache.get_channel_groups())
    print(recording_cache.get_num_frames() /
          recording_cache.get_sampling_frequency())

    # Only a short sub-recording is stored next to each sorting, to avoid
    # saving the whole recording in every NWB file.
    recording_sub = se.SubRecordingExtractor(recording_cache,
                                             start_frame=200 * fs,
                                             end_frame=320 * fs)

    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]

    for label in sorter_labels:
        print(label)
        if 'phy_' + label in subfolders:
            print('Sorter already used for curation. Skipping')
            continue

        if label == 'KL':  # Klusta
            sorter = _load_or_run_sorter(
                'Loading Klusta', 'sorting_KL_all.nwb', arr, ss.run_klusta,
                'results_all_klusta', recording_cache, recording_sub)
        elif label == 'IC':  # Ironclust
            sorter = _load_or_run_sorter(
                'Loading Ironclust', 'sorting_IC_all.nwb', arr,
                ss.run_ironclust, 'results_all_ic', recording_cache,
                recording_sub, filter=False)
        elif label == 'HS':  # Herdingspikes
            nwb_name = 'sorting_herdingspikes_all.nwb'
            if nwb_name in arr:
                print('Loading herdingspikes')
                sorter = se.NwbSortingExtractor(nwb_name)
            else:
                try:
                    sorter = _run_sorter_timed(ss.run_herdingspikes,
                                               recording_cache,
                                               'results_all_herdingspikes')
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, nwb_name)
                    try:
                        se.NwbSortingExtractor.write_sorting(sorter, nwb_name)
                    except TypeError:
                        print("No units detected. Can't save HerdingSpikes")
                        os.remove(nwb_name)
                except Exception:
                    # Best effort: a failing HerdingSpikes run must not stop
                    # the other sorters, but Ctrl-C still interrupts.
                    print('Herdingspikes has failed')
                    sorter = []
        elif label == 'MS4':  # Mountainsort4
            sorter = _load_or_run_sorter(
                'Loading mountainsort4', 'sorting_mountainsort4_all.nwb', arr,
                ss.run_mountainsort4, 'results_all_mountainsort4',
                recording_cache, recording_sub, filter=False)
        elif label == 'SC':  # Spykingcircus
            # NOTE(review): `filter=False` is forwarded to NwbSortingExtractor
            # when loading, exactly as before. It looks copy-pasted from a
            # run_* call; confirm the extractor accepts it.
            sorter = _load_or_run_sorter(
                'Loading spykingcircus', 'sorting_spykingcircus_all.nwb', arr,
                ss.run_spykingcircus, 'results_all_spykingcircus',
                recording_cache, recording_sub,
                load_kwargs={'filter': False})
        else:  # 'TRI' - Tridesclous
            nwb_name = 'sorting_tridesclous_all.nwb'
            if nwb_name in arr:
                print('Loading tridesclous')
                try:
                    sorter = se.NwbSortingExtractor(nwb_name)
                except AttributeError:
                    print(
                        "No units detected. Can't load Tridesclous so will run it."
                    )
                    sorter = _run_sorter_timed(ss.run_tridesclous,
                                               recording_cache,
                                               'results_all_tridesclous')
                    # Replace the unreadable file with a fresh one.
                    os.remove(nwb_name)
                    _save_sorting_to_nwb(recording_sub, sorter, nwb_name)
            else:
                sorter = _run_sorter_timed(ss.run_tridesclous,
                                           recording_cache,
                                           'results_all_tridesclous')
                _save_sorting_to_nwb(recording_sub, sorter, nwb_name)

        # Check if sorter failed
        if not sorter:
            continue

        phy_folder = 'phy_' + label
        st.postprocessing.export_to_phy(recording_cache,
                                        sorter,
                                        output_folder=phy_folder,
                                        grouping_property='group',
                                        verbose=True,
                                        recompute_info=True)
        # Open phy interface (blocks until the GUI is closed).
        os.system('phy template-gui phy_' + label + '/params.py')

        # Remove detections curated as noise.
        sorting_phy_curated = se.PhySortingExtractor(
            phy_folder + '/', exclude_cluster_groups=['noise'])

        # Plot the unit templates of the curated sorting.
        sw.plot_unit_templates(sorting=sorting_phy_curated,
                               recording=recording_cache)
        plt.savefig('manual_' + label + '_unit_templates.pdf',
                    bbox_inches='tight')
        plt.savefig('manual_' + label + '_unit_templates.png',
                    bbox_inches='tight')
        plt.close()

        # Compute agreement matrix wrt consensus-based sorting.
        sorting_phy_consensus = se.PhySortingExtractor(
            'phy_AGR/', exclude_cluster_groups=['noise'])
        cmp = sc.compare_sorter_to_ground_truth(sorting_phy_curated,
                                                sorting_phy_consensus)
        sw.plot_agreement_matrix(cmp)
        plt.savefig('agreement_matrix_' + label + '.pdf', bbox_inches='tight')
        plt.savefig('agreement_matrix_' + label + '.png', bbox_inches='tight')
        plt.close()

        # Read the phy output through explicit paths instead of chdir-ing into
        # the phy folder, so a failure cannot leave the process there.
        spike_times = np.load(os.path.join(phy_folder, 'spike_times.npy'))
        spike_clusters = np.load(
            os.path.join(phy_folder, 'spike_clusters.npy'))

        # Find units curated as 'noise'
        noise_id = _read_noise_cluster_ids(
            os.path.join(phy_folder, 'cluster_group.tsv'))
        for x in noise_id:
            print(x)

        # Unit ids that remain after dropping the noise clusters. Filtering
        # (rather than list.remove) tolerates noise ids that have no spikes.
        noise_lookup = set(noise_id)
        some_list = [
            unit for unit in np.unique(spike_clusters).tolist()
            if unit not in noise_lookup
        ]

        # Bin data in bins of 25ms over 45 minutes
        bins = np.arange(start=0, stop=45 * 60 * fs + 1, step=.025 * fs)
        NData = np.zeros([len(some_list), bins.shape[0] - 1])

        for row, unit in enumerate(some_list):
            unit_spike_times = spike_times[spike_clusters == unit]
            NData[row, :] = np.histogram(unit_spike_times, bins=bins)[0]

        # Save activation matrix
        a = os.path.split(os.getcwd())[1]
        np.save('actmat_manual_' + label + '_' + a.split('_')[1], NData)
        np.save('unit_id_manual_' + label + '_' + a.split('_')[1], some_list)

    # End of for loop
    print("Stop the code here")
import matplotlib.pyplot as plt
import seaborn as sns

import spikeinterface.extractors as se
import spikeinterface.sorters as sorters
import spikeinterface.comparison as sc

##############################################################################

# Toy recording with its ground-truth sorting, then a Mountainsort4 run on it.
recording, sorting_true = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)

sorting_MS4 = sorters.run_mountainsort4(recording)

##############################################################################

cmp_gt_MS4 = sc.compare_sorter_to_ground_truth(sorting_true, sorting_MS4, exhaustive_gt=True)

##############################################################################
# This function first matches the ground-truth and spike sorted units, and
# then it computes several performance metrics.
#
# Once the spike trains are matched, each spike is labelled as:
#
# - true positive (tp): spike found both in :code:`gt_sorting` and
#   :code:`tested_sorting`
# - false negative (fn): spike found in :code:`gt_sorting`, but not in
#   :code:`tested_sorting`
# - false positive (fp): spike found in :code:`tested_sorting`, but not in
#   :code:`gt_sorting`
# - misclassification errors (cl): spike found in :code:`gt_sorting`, not in
#   :code:`tested_sorting`, found in another matched spike train of
#   :code:`tested_sorting`, and not labelled as true positives
#
# From the counts of these labels the following performance measures are
def _run_in_pool(func, arg_tuples):
    """Call ``func(*args)`` for each tuple in a worker pool; return results in order."""
    pool = Pool(min(globvar.n_multiprocess, len(arg_tuples)))
    pending = [pool.apply_async(func, args=args) for args in arg_tuples]
    pool.close()
    pool.join()
    return [job.get() for job in pending]


def main():
    """Run the full sorting pipeline on one dataset and log its performance.

    Stages: load and preprocess the data, threshold detection and grouping,
    per-group PCA, clustering, template merging, template matching, and
    comparison with the ground truth. The pooled performance, unit counts,
    summed stage time and dataset name are appended to ``perf_logger.csv``.
    """
    # np.inf replaces the np.infty alias, which NumPy 2.0 removed.
    np.set_printoptions(threshold=np.inf)

    ######################################################
    # data extract & preprocess
    ######################################################
    recording, sorting_true = load_data(globvar.h5data_path)
    extrt_dict, recording_bp = preprocessing(recording, sorting_true)
    # Same list object as in extrt_dict: entries are deleted in place below.
    channel_location: list = extrt_dict['channel_loc']

    ######################################################
    # threshold detection
    ######################################################
    detec = Detection(extrt_dict)
    split_list = globvar.h5filename.split('_')
    time_buff = []  # seconds spent in each timed stage

    start = time.time()
    detct_rst = detec.mc_run()  # detection
    # divide-and-conquer for multi-electrode
    merged_frames, grouped_frames, grouped_frameids = detec.get_frame_groups(
        detct_rst)
    time_buff.append(time.time() - start)

    # evaluate detection results
    fn_frames, fp_frames, evalinfo_dict, according_labels = detec.evaluate(
        merged_frames)
    print('----------------detection fnished----------------')
    print(
        f"precision: {evalinfo_dict['precision']}, recall: {evalinfo_dict['recall']}, F1: {evalinfo_dict['F1_score']}"
    )

    start = time.time()
    snipgroups = detec.get_grouped_snippets(grouped_frames)
    # Remove empty groups; iterate backwards so deletions keep indices valid.
    for i in reversed(range(len(grouped_frames))):
        if len(grouped_frames[i]) == 0:
            del grouped_frames[i]
            del grouped_frameids[i]
            del snipgroups[i]
            del channel_location[i]
    time_buff.append(time.time() - start)

    for i, frames in enumerate(grouped_frames):
        print(f"group{i}, frames{frames.shape}")

    ######################################################
    # KNN outlier detection (NOT USED)
    ######################################################
    # No outlier detector runs: every mark stays 0, so all samples count as
    # benign (mark != -1) further down.
    snippet_width = (extrt_dict['snip_frame_before'] +
                     extrt_dict['snip_frame_after'] + 1)
    all_snippets = np.vstack([np.zeros([0, snippet_width]), *snipgroups])
    # Builtin int replaces the np.int alias, which NumPy 1.24 removed.
    outlier_marks = np.zeros(len(all_snippets), dtype=int)
    grouped_outlier_marks = []
    cnt_sample = 0
    for group in snipgroups:
        n_samples = len(group)
        grouped_outlier_marks.append(outlier_marks[cnt_sample:cnt_sample +
                                                   n_samples])
        cnt_sample += n_samples

    ######################################################
    # PCA for each group
    ######################################################
    start = time.time()
    pca_snipgroups = _run_in_pool(
        decomposition,
        [(group, globvar.group_min_samples) for group in snipgroups])
    time_buff.append(time.time() - start)

    # get benign data
    pca_benign_snipgroups = []
    benign_snipgroups = []
    benign_framegroups = []
    for i, marks in enumerate(grouped_outlier_marks):
        ids = np.where(marks != -1)[0]
        benign_snipgroups.append((snipgroups[i])[ids])
        benign_framegroups.append((grouped_frames[i])[ids])
        pca_benign_snipgroups.append((pca_snipgroups[i])[ids])

    ######################################################
    # cluster
    ######################################################
    start = time.time()
    cluster_results = _run_in_pool(
        cluster, list(zip(benign_snipgroups, pca_benign_snipgroups)))
    # Each result is indexed as (labels, centroids).
    benign_labelgroup_list = [result[0] for result in cluster_results]
    tempgroups_list = [result[1] for result in cluster_results]

    # Labels for all samples, one array per group.
    labelgroup_list = copy.deepcopy(grouped_outlier_marks)
    for i in range(len(labelgroup_list)):
        benign_ids = np.where(grouped_outlier_marks[i] != -1)[0]
        labelgroup_list[i][benign_ids] = benign_labelgroup_list[i]
    time_buff.append(time.time() - start)

    # Print the label histogram of every group.
    for labels in labelgroup_list:
        statistical = Counter(labels)
        print(statistical)

    ######################################################
    # postprocessing
    ######################################################
    # ----------------templates merging------------------
    start = time.time()
    merged_centroids, merged_labelgroup_list = templates_merging(
        tempgroups_list,
        labelgroup_list,
        channel_location=channel_location,
        channel_type=(split_list[2].split("-"))[0],
        sigma=globvar.sigma,
        peek_diff=globvar.peek_diff)
    time_buff.append(time.time() - start)

    ######################################################
    # template matching
    ######################################################
    start = time.time()
    threshold = np.mean(extrt_dict['stdvir_channels']) * globvar.detect_th
    matching_results = _run_in_pool(
        template_matching,
        [(snipgroups[i_group], merged_labelgroup_list[i_group],
          grouped_frames[i_group], merged_centroids, threshold)
         for i_group in range(len(merged_labelgroup_list))])
    # Each result is indexed as (labels, frames).
    recovered_label_list = [result[0] for result in matching_results]
    recovered_frame_list = [result[1] for result in matching_results]
    time_buff.append(time.time() - start)

    ######################################################
    # get performance
    ######################################################
    # One concatenation replaces repeated np.append calls; like np.append,
    # every piece is flattened and the float seed array is kept.
    recovered_labels = np.concatenate(
        [np.array([])] + [np.ravel(labels) for labels in recovered_label_list])
    recovered_frames = np.concatenate(
        [np.array([])] + [np.ravel(frames) for frames in recovered_frame_list])

    sorting = se.NumpySortingExtractor()
    sorting.set_sampling_frequency(extrt_dict['sample_freq'])
    sorting.set_times_labels(recovered_frames, recovered_labels)

    # compare to ground-truth
    comp: sc.GroundTruthComparison = sc.compare_sorter_to_ground_truth(
        sorting_true, sorting, delta_time=0.4, match_mode='best')

    # get_performance
    comp.print_performance(method='by_unit')
    comp.print_performance()
    gt_num_units = len(sorting_true.get_unit_ids())
    tested_num_units = len(sorting.get_unit_ids())
    print(f"GT num_units: {gt_num_units}")
    print(f"tested num_units: {tested_num_units}")

    perf = comp.get_performance(method='pooled_with_average', output='pandas')
    perf['gt_num_units'] = gt_num_units
    perf['tested_num_units'] = tested_num_units
    perf['time'] = sum(time_buff)
    perf['dataset'] = globvar.h5filename

    # output to csv (appended, without a header row)
    perf.to_csv("perf_logger.csv", mode='a', header=False)
#Mountainsort with ka.config(fr='default_readonly'): #with hither.config(cache='default_readwrite'): with hither.config(container='default'): result_MS4 = sorters.mountainsort4.run(recording_path=recordingPath, sorting_out=hither.File()) #Aggregating the output of the sorters sorting_MS4 = AutoSortingExtractor(result_MS4.outputs.sorting_out._path) sorting_SP = AutoSortingExtractor( result_spyKingCircus.outputs.sorting_out._path) #Comparing each to ground truth-confusion matrix comp_MATLAB = sc.compare_sorter_to_ground_truth(gtOutput, sortingPipeline, sampling_frequency=sampleRate, delta_time=3, match_score=0.5, chance_score=0.1, well_detected_score=0.1, exhaustive_gt=True) w_comp_MATLAB = sw.plot_confusion_matrix(comp_MATLAB, count_text=True) plt.show() comp_MS4 = sc.compare_sorter_to_ground_truth(gtOutput, sorting_MS4, sampling_frequency=sampleRate, delta_time=3, match_score=0.5, chance_score=0.1, well_detected_score=0.1, exhaustive_gt=True) w_comp_MS4 = sw.plot_confusion_matrix(comp_MS4, count_text=True)
import spikeinterface.sorters as ss

# Run two sorters on the same recording so that both can be compared with the
# ground truth below.
sorting_MS4 = ss.run_mountainsort4(recording)
sorting_KL = ss.run_klusta(recording)

##############################################################################
# Widgets using SortingComparison
# ---------------------------------
#
# We can compare the spike sorting output to the ground-truth sorting :code:`sorting_true` using the
# :code:`comparison` module. :code:`comp_MS4` and :code:`comp_KL` are :code:`SortingComparison` objects

import spikeinterface.comparison as sc

comp_MS4 = sc.compare_sorter_to_ground_truth(sorting_true, sorting_MS4)
comp_KL = sc.compare_sorter_to_ground_truth(sorting_true, sorting_KL)

##############################################################################
# plot_confusion_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_comp_MS4 = sw.plot_confusion_matrix(comp_MS4, count_text=False)
w_comp_KL = sw.plot_confusion_matrix(comp_KL, count_text=False)

##############################################################################
# plot_agreement_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_agr_MS4 = sw.plot_agreement_matrix(comp_MS4, count_text=False)
# Spike train of unit 1 from the pipeline sorting.
st = sortingPipeline.get_unit_spike_train(unit_id=1)
print('Num. events for unit 1 = {}'.format(len(st)))
# NOTE(review): this call is identical to the one above (no frame limits are
# passed), so st1 is the full spike train again, not the "first second" that
# the message below announces - confirm the intended arguments.
st1 = sortingPipeline.get_unit_spike_train(unit_id=1)
print('Num. events for first second of unit 1 = {}'.format(len(st1)))

#We are also going to be setting up the unit spike features associated with each waveform
# Rows of A_snippets_reference whose cluster label equals 1.
ID1_features = A_snippets_reference[cluster == 1, :]
sortingPipeline.set_unit_spike_features(unit_id=1, feature_name='unitId1', value=ID1_features)
print("Spike feature names: " + str(sortingPipeline.get_unit_spike_feature_names(unit_id=1)))

#Comparing sorter with ground truth
cmp_gt_SP = sc.compare_sorter_to_ground_truth(gtOutput, sortingPipeline, exhaustive_gt=True)
sw.plot_agreement_matrix(cmp_gt_SP, ordered=True)

#Some more comparision metrics
perf = cmp_gt_SP.get_performance()
#print('well_detected', cmp_gt_SP.get_well_detected_units(well_detected_score=0))
print(perf)

#We will try to get the SNR and firing rates
#firing_rates = st.validation.compute_firing_rates(sortingPipeline, duration_in_frames=recordingInput.get_num_frames())

#Raster plots
w_rs_gt = sw.plot_rasters(sortingPipeline, sampling_frequency=sampleRate)
Created on Mon Sep 30 13:04:03 2019

@author: Jasper Wouters
"""

import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc

# full filenames to both the hybrid recording and ground truth
# (placeholder paths: replace them before running)
recording_fn = '/path/to/recording.bin'
gt_fn = '/path/to/hybrid_GT.csv'

# create extractor object for both the recording data and ground truth labels
recording_ex = se.SHYBRIDRecordingExtractor(recording_fn)
sorting_ex = se.SHYBRIDSortingExtractor(gt_fn)

# perform spike sorting (e.g., using spyking circus)
# the sorter's default parameters are fetched and passed on unchanged
sc_params = ss.SpykingcircusSorter.default_params()
sorting_sc = ss.run_spykingcircus(recording=recording_ex, **sc_params, output_folder='tmp_sc')

# calculate spike sorting performance
# note: exhaustive_gt is set to False, because the hybrid approach generates
# partial ground truth only
comparison = sc.compare_sorter_to_ground_truth(sorting_ex, sorting_sc, exhaustive_gt=False)
print(comparison.get_performance())
def _do_recovery_loop(task_args):
    """Run one template-subtraction / ICA back-projection pass for one key.

    ``task_args`` is a flat 15-tuple (see the unpacking below) so that the
    whole job can be dispatched as a single argument.  The recording, the
    pre-recovery sorting and (optionally) the ground truth are loaded from
    sub-folders of ``output_folder / 'back_recording'``; the back-projected
    recording is saved to ``output_folder / 'back_recording' / key[0] / key[1]``.
    Nothing is returned.

    Units whose templates get subtracted are chosen as follows:

    * ``compare is True``: units reported as well detected by a comparison
      against the stored ground truth (threshold ``well_detected_score``);
    * otherwise: units whose ISI-violation value is below ``isi_thr`` and
      whose firing rate lies strictly between ``fr_thr[0]`` and ``fr_thr[1]``.
    """
    (key, well_detected_score, isi_thr, fr_thr, sample_window_ms,
     percentage_spikes, balance_spikes, detect_threshold, method, skew_thr,
     n_jobs, we_params, compare, output_folder, job_kwargs) = task_args

    # presumably key == (recording name, sorter name) -- verify against caller
    first, second = key[0], key[1]

    # ---- load the staged inputs ------------------------------------------
    recording = load_extractor(output_folder / 'back_recording' / second / first)
    gt = (load_extractor(output_folder / 'back_recording' / second / (first + '_gt'))
          if compare is True else None)
    sorting = load_extractor(output_folder / 'back_recording' / first / (second + '_pre'))

    we = extract_waveforms(
        recording,
        sorting,
        folder=output_folder / 'waveforms' / first / second,
        load_if_exists=we_params['load_if_exists'],
        ms_before=we_params['ms_before'],
        ms_after=we_params['ms_after'],
        max_spikes_per_unit=we_params['max_spikes_per_unit'],
        return_scaled=we_params['return_scaled'],
        dtype=we_params['dtype'],
        overwrite=True,
        **job_kwargs,
    )

    # ---- choose the units to subtract ------------------------------------
    if gt is None:
        # No ground truth: fall back to ISI-violation and firing-rate limits.
        isi_violation = st.compute_isi_violations(we)[0]
        isi_values = np.array(list(isi_violation.values()))
        good_isi = np.argwhere(isi_values < isi_thr)[:, 0]

        firing_rate = st.compute_firing_rate(we)
        rate_values = np.array(list(firing_rate.values()))
        good_fr_idx_up = np.argwhere(rate_values < fr_thr[1])[:, 0]
        good_fr_idx_down = np.argwhere(rate_values > fr_thr[0])[:, 0]

        # The candidates are positions 0..num_units-1, matched against the
        # positional indices returned by argwhere above.
        selected_units = []
        for unit in range(sorting.get_num_units()):
            if unit in good_fr_idx_up and unit in good_fr_idx_down and unit in good_isi:
                selected_units.append(unit)
    else:
        comparison = sc.compare_sorter_to_ground_truth(tested_sorting=sorting,
                                                       gt_sorting=gt)
        selected_units = comparison.get_well_detected_units(well_detected_score)
        print(second[:-1])
        if second == 'hdsort':
            # NOTE(review): looks like an id-offset correction specific to
            # 'hdsort' -- confirm where the 1000 comes from.
            selected_units = [unit - 1000 for unit in selected_units]

    # ---- subtract the selected templates ---------------------------------
    templates = we.get_all_templates()
    templates_dict = {}
    for unit in selected_units:
        # NOTE(review): `unit - 1` assumes 1-based unit ids; in the branch
        # without ground truth the ids start at 0, so unit 0 would pick
        # templates[-1] -- confirm.
        templates_dict[str(unit)] = templates[unit - 1]

    recording_subtracted = subtract_templates(recording, sorting, templates_dict,
                                              we.nbefore, selected_units)

    # ---- ICA on the residual, clean the sources, project back -------------
    sorter = SpyICASorter(recording_subtracted)
    sorter.mask_traces(
        sample_window_ms=sample_window_ms,
        percent_spikes=percentage_spikes,
        balance_spikes_on_channel=balance_spikes,
        detect_threshold=detect_threshold,
        method=method,
        **job_kwargs,
    )
    sorter.compute_ica(n_comp='all')

    # NOTE(review): n_jobs is passed explicitly next to **job_kwargs; this
    # raises TypeError if job_kwargs also carries 'n_jobs' -- verify callers.
    cleaning_result = clean_correlated_sources(
        recording,
        sorter.W_ica,
        skew_thresh=skew_thr,
        n_jobs=n_jobs,
        chunk_size=recording.get_num_samples(0) // n_jobs,
        **job_kwargs,
    )
    kept_idx = cleaning_result[0]
    flipped_idx = cleaning_result[1]

    # Negate the rows listed in cleaning_result[1] in both matrices, then
    # keep only the rows listed in cleaning_result[0].
    sorter.A_ica[flipped_idx] = -sorter.A_ica[flipped_idx]
    sorter.W_ica[flipped_idx] = -sorter.W_ica[flipped_idx]
    sorter.source_idx = kept_idx
    sorter.cleaned_A_ica = sorter.A_ica[kept_idx]
    sorter.cleaned_W_ica = sorter.W_ica[kept_idx]

    # Map into the cleaned source space and back again, then persist.
    ica_recording = st.preprocessing.lin_map(recording_subtracted,
                                             sorter.cleaned_W_ica)
    recording_back = st.preprocessing.lin_map(ica_recording,
                                              sorter.cleaned_A_ica.T)
    recording_back.save_to_folder(
        folder=output_folder / 'back_recording' / first / second)
#Open phy interface os.system('phy template-gui phy_'+i+'/params.py') #Remove detections curated as noise. sorting_phy_curated = se.PhySortingExtractor('phy_'+i+'/', exclude_cluster_groups=['noise']); #Print waveforms of units w_wf = sw.plot_unit_templates(sorting=sorting_phy_curated, recording=recording_cache) plt.savefig('manual_'+i+'_unit_templates.pdf', bbox_inches='tight'); plt.savefig('manual_'+i+'_unit_templates.png', bbox_inches='tight'); plt.close() #Compute agreement matrix wrt consensus-based sorting. sorting_phy_consensus = se.PhySortingExtractor('phy_AGR/', exclude_cluster_groups=['noise']); cmp=sc.compare_sorter_to_ground_truth(sorting_phy_curated,sorting_phy_consensus) sw.plot_agreement_matrix(cmp) plt.savefig('agreement_matrix_'+i+'.pdf', bbox_inches='tight'); plt.savefig('agreement_matrix_'+i+'.png', bbox_inches='tight'); plt.close() #Access unit ID and firing rate. os.chdir('phy_'+i) spike_times=np.load('spike_times.npy'); spike_clusters=np.load('spike_clusters.npy'); #Find units curated as 'noise' noise_id=[]; with open("cluster_group.tsv") as fd: rd = csv.reader(fd, delimiter="\t", quotechar='"') for row in rd:
st8 = sorting_true.get_unit_spike_train(8) st80 = np.sort(np.random.choice(st8, size=int(st8.size*0.65), replace=False)) st81 = np.sort(np.random.choice(st8, size=int(st8.size*0.6), replace=False)) st82 = np.sort(np.random.choice(st8, size=int(st8.size*0.55), replace=False)) units_err[80] = st80 units_err[81] = st81 units_err[81] = st82 # unit 9 is missing # there are some units that do not exist 15 16 and 17 nframes = rec.get_num_frames(segment_index=0) for u in [15,16,17]: st = np.sort(np.random.randint(0, high=nframes, size=35)) units_err[u] = st sorting_err = se.NumpySorting.from_dict(units_err, sampling_frequency) return sorting_true, sorting_err if __name__ == '__main__': # just for check sorting_true, sorting_err = generate_erroneous_sorting() comp = sc.compare_sorter_to_ground_truth(sorting_true, sorting_err, exhaustive_gt=True) sw.plot_agreement_matrix(comp, ordered=True) plt.show()
# Quality metrics can also be used to automatically curate the spike sorting output. For example, you can select
# sorted units with a SNR above a certain threshold:

# NOTE(review): the threshold is applied using `recording`, while the SNRs printed below are computed on
# `recording_cmr` -- confirm that using two different recordings here is intended.
sorting_curated_snr = st.curation.threshold_snr(sorting_KL, recording, threshold=5)
snrs_above = st.validation.compute_snrs(sorting_curated_snr, recording_cmr)

print('Curated SNR', snrs_above)

##############################################################################
# The final part of this tutorial deals with comparing spike sorting outputs.
# We can either (1) compare the spike sorting results with the ground-truth sorting :code:`sorting_true`, (2) compare
# the output of two sorters (Klusta and Mountainsort4), or (3) compare the output of multiple sorters:

comp_gt_KL = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_KL)
comp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4)
# name_list must follow the order of sorting_list, otherwise each sorting is labelled with the other sorter's name.
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_MS4, sorting_KL],
                                         name_list=['ms4', 'klusta'])


##############################################################################
# When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion
# matrix

comp_gt_KL.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_KL)

##############################################################################
# When comparing two sorters (2), we can see the matching of units between sorters. For example, this is how to extract
# the unit ids of Mountainsort4 (sorting2) mapped to the units of Klusta (sorting1). Units which are not mapped have -1
# as unit id.
def run(self):
    """Run the recovery pipeline for every key in ``self._sortings_pre``.

    Steps:

    1. Save each recording, pre-recovery sorting and (when available) ground
       truth under ``self._output_folder / 'back_recording'`` so that
       ``_do_recovery_loop`` can reload them from disk.
    2. Run ``_do_recovery_loop`` once per key, through joblib when
       ``self._params_dict['parallel']`` is truthy, sequentially otherwise.
    3. Reload the back-projected recordings, grouped by ``key[1]``.
    4. Re-run each sorter on its back-projected recordings and aggregate the
       new units with the pre-recovery sorting; when ground truth is
       available, also store the comparison of the pre-recovery sorting
       against it.

    Results are stored in ``self._recordings_backprojected``,
    ``self._sortings_post``, ``self._aggregated_sortings`` and
    ``self._comparisons``; nothing is returned.

    Ground truth is optional: with ``self._gt is None`` nothing is saved for
    it and no comparisons are computed (previously this raised ``TypeError``).
    """
    params = self._params_dict
    has_gt = self._gt is not None

    # 1. Stage the inputs on disk and build one argument tuple per key.
    #    The folder layout must match what _do_recovery_loop loads.
    task_args_list = []
    for key in self._sortings_pre:
        back_dir = self._output_folder / 'back_recording'
        self._recordings[key[0]].save_to_folder(
            folder=back_dir / key[1] / key[0])
        self._sortings_pre[key].save_to_folder(
            folder=back_dir / key[0] / (key[1] + '_pre'))
        if has_gt:
            self._gt[key[0]].save_to_folder(
                folder=back_dir / key[1] / (key[0] + '_gt'))
        task_args_list.append(
            (key, params['wd_score'], params['isi_thr'], params['fr_thr'],
             params['sample_window_ms'], params['percentage_spikes'],
             params['balance_spikes'], params['detect_threshold'],
             params['method'], params['skew_thr'], params['n_jobs'],
             self._we_params, self._compare, self._output_folder,
             params['job_kwargs']))

    # 2. Back-project every key.
    if params['parallel']:
        # joblib is imported lazily so it is only needed for parallel runs.
        from joblib import Parallel, delayed
        Parallel(n_jobs=params['n_jobs'], backend='loky')(
            delayed(_do_recovery_loop)(task_args)
            for task_args in task_args_list)
    else:
        for task_args in task_args_list:
            _do_recovery_loop(task_args)

    # 3. Collect the back-projected recordings, grouped by key[1].
    for key in self._sortings_pre:
        back_recording = load_extractor(
            self._output_folder / 'back_recording' / key[0] / key[1])
        self._recordings_backprojected.setdefault(key[1], []).append(
            back_recording)

    # 4. Re-sort, aggregate with the pre-recovery units and compare.
    for sorter in self._recordings_backprojected:
        self._sortings_post[sorter] = ss.run_sorters(
            sorter,
            self._recordings_backprojected[sorter],
            working_folder=self._output_folder / 'sortings_post' / sorter,
            sorter_params=self._sorters_params['sorters_params'],
            mode_if_folder_exists='overwrite',
            engine=self._sorters_params['engine'],
            engine_kwargs=self._sorters_params['engine_kwargs'],
            verbose=self._sorters_params['verbose'],
            with_output=self._sorters_params['with_output'])
        for key in self._sortings_post[sorter]:
            self._aggregated_sortings[key] = aggregate_units([
                self._sortings_post[sorter][key],
                self._sortings_pre[key],
            ])
            if has_gt:
                self._comparisons[key] = sc.compare_sorter_to_ground_truth(
                    tested_sorting=self._sortings_pre[key],
                    gt_sorting=self._gt[key[0]])