Example #1
import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.comparison as sc
import spikeinterface.widgets as sw

# generate_erroneous_sorting() is defined elsewhere in this example
# (a fragment of it appears in Example #7 below).

def make_comparison_figures():
    
    gt_sorting, tested_sorting = generate_erroneous_sorting()
    
    comp = sc.compare_sorter_to_ground_truth(
        gt_sorting, tested_sorting,
        gt_name=None, tested_name=None,
        delta_time=0.4, sampling_frequency=None,
        min_accuracy=0.5, exhaustive_gt=True,
        match_mode='hungarian', n_jobs=-1,
        bad_redundant_threshold=0.2,
        compute_labels=False, verbose=False)
    
    print(comp.hungarian_match_12)
    
    fig, ax = plt.subplots()
    im = ax.matshow(comp.match_event_count, cmap='Greens')
    ax.set_xticks(np.arange(0, comp.match_event_count.shape[1]))
    ax.set_yticks(np.arange(0, comp.match_event_count.shape[0]))
    ax.xaxis.tick_bottom()
    ax.set_yticklabels(comp.match_event_count.index, fontsize=12)
    ax.set_xticklabels(comp.match_event_count.columns, fontsize=12)
    fig.colorbar(im)
    fig.savefig('spikecomparison_match_count.png')
    
    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax, ordered=False)
    im = ax.get_images()[0]
    fig.colorbar(im)
    fig.savefig('spikecomparison_agreement_unordered.png')

    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax)
    im = ax.get_images()[0]
    fig.colorbar(im)
    fig.savefig('spikecomparison_agreement.png')
    
    fig, ax = plt.subplots()
    sw.plot_confusion_matrix(comp, ax=ax)
    im = ax.get_images()[0]
    fig.colorbar(im)
    fig.savefig('spikecomparison_confusion.png')

    plt.show()
Example #2
# The :code:`compare_two_sorters` function allows us to compare the spike
# sorting output. It returns a :code:`SortingComparison` object, with methods
# to inspect the comparison output easily. The comparison matches the
# units by comparing the agreement between unit spike trains.
#
# Let’s see how to inspect and access this matching.

cmp_HS_TDC = sc.compare_two_sorters(sorting1=sorting_HS,
                                    sorting2=sorting_TDC,
                                    sorting1_name='HS',
                                    sorting2_name='TDC')

#############################################################################
# We can check the agreement matrix to inspect the matching.

sw.plot_agreement_matrix(cmp_HS_TDC)

#############################################################################
# Some useful internal dataframes, such as **match_event_count** and
# **agreement_scores**, help to check the match and the event counts.

print(cmp_HS_TDC.match_event_count)
print(cmp_HS_TDC.agreement_scores)

#############################################################################
# In order to check which units were matched, the :code:`get_matching`
# method can be used. If units are not matched they are listed as -1.

hs_to_tdc, tdc_to_hs = cmp_HS_TDC.get_matching()

print('matching HS to TDC')
print(hs_to_tdc)
Example #3
def plot_agreement_matrix(self):
    sw.plot_agreement_matrix(self.cmp_sorters)
Example #5

##############################################################################
# run herdingspikes on it

sorting_HS = ss.run_herdingspikes(recording)

##############################################################################

cmp_gt_HS = sc.compare_sorter_to_ground_truth(sorting_true, sorting_HS, exhaustive_gt=True)


##############################################################################
# To get an overview of the match, we can use the unordered agreement matrix

sw.plot_agreement_matrix(cmp_gt_HS, ordered=False)

##############################################################################
# or ordered

sw.plot_agreement_matrix(cmp_gt_HS, ordered=True)

##############################################################################
# This function first matches the ground-truth and spike sorted units, and
# then it computes several performance metrics.
# 
# Once the spike trains are matched, each spike is labelled as: 
# 
# - true positive (tp): spike found both in :code:`gt_sorting` and :code:`tested_sorting`
# - false negative (fn): spike found in :code:`gt_sorting`, but not in :code:`tested_sorting` 
# - false positive (fp): spike found in :code:`tested_sorting`, but not in :code:`gt_sorting` 
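
##############################################################################
# From these labels the comparison object derives counts and performance
# metrics. A minimal sketch, using methods that appear in the other examples
# below (:code:`print_summary`, :code:`count_score`, :code:`get_performance`):

cmp_gt_HS.print_summary()
print(cmp_gt_HS.count_score)  # raw tp/fn/fp counts
perf = cmp_gt_HS.get_performance()
print(perf)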
Example #6
import csv
import os
import time

import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.toolkit as st
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw


def manual(recording_folder):
    #Folder with tetrode data
    #recording_folder='/home/adrian/Documents/SpikeSorting/Adrian_test_data/Irene_data/test_without_zero_main_channels/Tetrode_9_CH';

    os.chdir(recording_folder)
    """
    Adding Matlab-based sorters to path
    
    """

    #IronClust
    iron_path = "~/Documents/SpikeSorting/ironclust"
    ss.IronClustSorter.set_ironclust_path(os.path.expanduser(iron_path))
    ss.IronClustSorter.ironclust_path

    #If sorter has already been run skip it.
    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]
    #if ('phy_KL' in subfolders) & ('phy_IC' in subfolders) & ('phy_Waveclus' in subfolders) & ('phy_SC' in subfolders) & ('phy_MS4' in subfolders) & ('phy_HS' in subfolders) & ('phy_TRI' in subfolders):
    if ('phy_KL' in subfolders) & ('phy_IC' in subfolders) & (
            'phy_SC' in subfolders) & ('phy_MS4' in subfolders) & (
                'phy_HS' in subfolders) & ('phy_TRI' in subfolders):
        print('Tetrode ' + recording_folder.split('_')[-1] +
              ' was previously manually sorted. Skipping')
        return

    #Check if the recording has been preprocessed before and load it.
    # Else proceed with preprocessing.
    arr = os.listdir()

    #Load .continuous files
    recording = se.OpenEphysRecordingExtractor(recording_folder)
    channel_ids = recording.get_channel_ids()
    fs = recording.get_sampling_frequency()
    num_chan = recording.get_num_channels()

    print('Channel ids:', channel_ids)
    print('Sampling frequency:', fs)
    print('Number of channels:', num_chan)

    #!cat tetrode9.prb #Asks for prb file
    # os.system('cat /home/adrian/Documents/SpikeSorting/Adrian_test_data/Irene_data/test_without_zero_main_channels/Tetrode_9_CH/tetrode9.prb')
    recording_prb = recording.load_probe_file(os.getcwd() + '/tetrode.prb')

    print('Channels after loading the probe file:',
          recording_prb.get_channel_ids())
    print('Channel groups after loading the probe file:',
          recording_prb.get_channel_groups())

    #For testing only: Reduce recording.
    #recording_prb = se.SubRecordingExtractor(recording_prb, start_frame=100*fs, end_frame=420*fs)

    #Bandpass filter
    recording_cmr = st.preprocessing.bandpass_filter(recording_prb,
                                                     freq_min=300,
                                                     freq_max=6000)
    recording_cache = se.CacheRecordingExtractor(recording_cmr)

    print(recording_cache.get_channel_ids())
    print(recording_cache.get_channel_groups())
    print(recording_cache.get_num_frames() /
          recording_cache.get_sampling_frequency())

    #View installed sorters
    #ss.installed_sorters()
    #mylist = [f for f in glob.glob("*.txt")]

    #%% Run all channels. There are only single tetrode channels anyway.

    #Create a sub-recording to avoid saving the whole recording (required by NWB to save sorter data).
    recording_sub = se.SubRecordingExtractor(recording_cache,
                                             start_frame=200 * fs,
                                             end_frame=320 * fs)
    # Sorters2CompareLabel=['KL','IC','Waveclus','HS','MS4','SC','TRI'];
    Sorters2CompareLabel = ['KL', 'IC', 'HS', 'MS4', 'SC', 'TRI']
    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]

    for num in range(len(Sorters2CompareLabel)):

        i = Sorters2CompareLabel[num]
        print(i)
        if 'phy_' + i in subfolders:
            print('Sorter already used for curation. Skipping')
            continue
        else:

            if 'KL' in i:
                #Klusta
                if 'sorting_KL_all.nwb' in arr:
                    print('Loading Klusta')
                    sorting_KL_all = se.NwbSortingExtractor(
                        'sorting_KL_all.nwb')

                else:
                    t = time.time()
                    sorting_KL_all = ss.run_klusta(
                        recording_cache,
                        output_folder='results_all_klusta',
                        delete_output_folder=True)
                    print('Found', len(sorting_KL_all.get_unit_ids()), 'units')
                    print(time.time() - t)
                    #Save Klusta
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_KL_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_KL_all, 'sorting_KL_all.nwb')
                sorter = sorting_KL_all

            if 'IC' in i:
                #Ironclust
                if 'sorting_IC_all.nwb' in arr:
                    print('Loading Ironclust')
                    sorting_IC_all = se.NwbSortingExtractor(
                        'sorting_IC_all.nwb')

                else:
                    t = time.time()
                    sorting_IC_all = ss.run_ironclust(
                        recording_cache,
                        output_folder='results_all_ic',
                        delete_output_folder=True,
                        filter=False)
                    print('Found', len(sorting_IC_all.get_unit_ids()), 'units')
                    print(time.time() - t)
                    #Save IC
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_IC_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_IC_all, 'sorting_IC_all.nwb')
                sorter = sorting_IC_all

            # if 'Waveclus' in i:
            #     #Waveclust
            #     if 'sorting_waveclus_all.nwb' in arr:
            #         print('Loading waveclus')
            #         sorting_waveclus_all=se.NwbSortingExtractor('sorting_waveclus_all.nwb');

            #     else:
            #         t = time.time()
            #         sorting_waveclus_all = ss.run_waveclus(recording_cache, output_folder='results_all_waveclus',delete_output_folder=True)
            #         print('Found', len(sorting_waveclus_all.get_unit_ids()), 'units')
            #         print(time.time() - t)
            #         #Save waveclus
            #         se.NwbRecordingExtractor.write_recording(recording_sub, 'sorting_waveclus_all.nwb')
            #         se.NwbSortingExtractor.write_sorting(sorting_waveclus_all, 'sorting_waveclus_all.nwb')
            #     sorter=sorting_waveclus_all;

            if 'HS' in i:
                #Herdingspikes
                if 'sorting_herdingspikes_all.nwb' in arr:
                    print('Loading herdingspikes')
                    sorting_herdingspikes_all = se.NwbSortingExtractor(
                        'sorting_herdingspikes_all.nwb')
                    sorter = sorting_herdingspikes_all

                else:
                    t = time.time()
                    try:
                        sorting_herdingspikes_all = ss.run_herdingspikes(
                            recording_cache,
                            output_folder='results_all_herdingspikes',
                            delete_output_folder=True)
                        print('Found',
                              len(sorting_herdingspikes_all.get_unit_ids()),
                              'units')
                        print(time.time() - t)
                        #Save herdingspikes
                        se.NwbRecordingExtractor.write_recording(
                            recording_sub, 'sorting_herdingspikes_all.nwb')
                        try:
                            se.NwbSortingExtractor.write_sorting(
                                sorting_herdingspikes_all,
                                'sorting_herdingspikes_all.nwb')
                        except TypeError:
                            print(
                                "No units detected.  Can't save HerdingSpikes")
                            os.remove("sorting_herdingspikes_all.nwb")
                        sorter = sorting_herdingspikes_all
                    except Exception:
                        print('Herdingspikes has failed')
                        sorter = []

            if 'MS4' in i:
                #Mountainsort4
                if 'sorting_mountainsort4_all.nwb' in arr:
                    print('Loading mountainsort4')
                    sorting_mountainsort4_all = se.NwbSortingExtractor(
                        'sorting_mountainsort4_all.nwb')

                else:
                    t = time.time()
                    sorting_mountainsort4_all = ss.run_mountainsort4(
                        recording_cache,
                        output_folder='results_all_mountainsort4',
                        delete_output_folder=True,
                        filter=False)
                    print('Found',
                          len(sorting_mountainsort4_all.get_unit_ids()),
                          'units')
                    print(time.time() - t)
                    #Save mountainsort4
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_mountainsort4_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_mountainsort4_all,
                        'sorting_mountainsort4_all.nwb')
                sorter = sorting_mountainsort4_all

            if 'SC' in i:
                #Spykingcircus
                if 'sorting_spykingcircus_all.nwb' in arr:
                    print('Loading spykingcircus')
                    sorting_spykingcircus_all = se.NwbSortingExtractor(
                        'sorting_spykingcircus_all.nwb', filter=False)

                else:
                    t = time.time()
                    sorting_spykingcircus_all = ss.run_spykingcircus(
                        recording_cache,
                        output_folder='results_all_spykingcircus',
                        delete_output_folder=True)
                    print('Found',
                          len(sorting_spykingcircus_all.get_unit_ids()),
                          'units')
                    print(time.time() - t)
                    #Save sorting_spykingcircus
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_spykingcircus_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_spykingcircus_all,
                        'sorting_spykingcircus_all.nwb')
                sorter = sorting_spykingcircus_all

            if 'TRI' in i:
                #Tridesclous
                if 'sorting_tridesclous_all.nwb' in arr:
                    print('Loading tridesclous')
                    try:
                        sorting_tridesclous_all = se.NwbSortingExtractor(
                            'sorting_tridesclous_all.nwb')
                    except AttributeError:
                        print(
                            "No units detected.  Can't load Tridesclous so will run it."
                        )
                        t = time.time()
                        sorting_tridesclous_all = ss.run_tridesclous(
                            recording_cache,
                            output_folder='results_all_tridesclous',
                            delete_output_folder=True)
                        print('Found',
                              len(sorting_tridesclous_all.get_unit_ids()),
                              'units')
                        print(time.time() - t)
                        os.remove("sorting_tridesclous_all.nwb")
                        #Save sorting_tridesclous
                        se.NwbRecordingExtractor.write_recording(
                            recording_sub, 'sorting_tridesclous_all.nwb')
                        se.NwbSortingExtractor.write_sorting(
                            sorting_tridesclous_all,
                            'sorting_tridesclous_all.nwb')

                else:
                    t = time.time()
                    sorting_tridesclous_all = ss.run_tridesclous(
                        recording_cache,
                        output_folder='results_all_tridesclous',
                        delete_output_folder=True)
                    print('Found', len(sorting_tridesclous_all.get_unit_ids()),
                          'units')
                    print(time.time() - t)
                    #Save sorting_tridesclous
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_tridesclous_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_tridesclous_all, 'sorting_tridesclous_all.nwb')
                sorter = sorting_tridesclous_all

            #Check if the sorter failed.
            if not sorter:
                continue

            st.postprocessing.export_to_phy(recording_cache,
                                            sorter,
                                            output_folder='phy_' + i,
                                            grouping_property='group',
                                            verbose=True,
                                            recompute_info=True)

            #Open phy interface
            os.system('phy template-gui phy_' + i + '/params.py')

            #Remove detections curated as noise.
            sorting_phy_curated = se.PhySortingExtractor(
                'phy_' + i + '/', exclude_cluster_groups=['noise'])

            #Print waveforms of units
            w_wf = sw.plot_unit_templates(sorting=sorting_phy_curated,
                                          recording=recording_cache)
            plt.savefig('manual_' + i + '_unit_templates.pdf',
                        bbox_inches='tight')
            plt.savefig('manual_' + i + '_unit_templates.png',
                        bbox_inches='tight')
            plt.close()

            #Compute agreement matrix wrt consensus-based sorting.
            sorting_phy_consensus = se.PhySortingExtractor(
                'phy_AGR/', exclude_cluster_groups=['noise'])
            cmp = sc.compare_sorter_to_ground_truth(sorting_phy_curated,
                                                    sorting_phy_consensus)
            sw.plot_agreement_matrix(cmp)
            plt.savefig('agreement_matrix_' + i + '.pdf', bbox_inches='tight')
            plt.savefig('agreement_matrix_' + i + '.png', bbox_inches='tight')
            plt.close()

            #Access unit ID and firing rate.
            os.chdir('phy_' + i)
            spike_times = np.load('spike_times.npy')
            spike_clusters = np.load('spike_clusters.npy')
            #Find units curated as 'noise'
            noise_id = []
            with open("cluster_group.tsv") as fd:
                rd = csv.reader(fd, delimiter="\t", quotechar='"')
                for row in rd:
                    if row[1] == 'noise':
                        noise_id.append(int(row[0]))
            #Create a list with the unit IDs and remove those labeled as 'noise'
            some_list = np.unique(spike_clusters)
            some_list = some_list.tolist()
            for x in noise_id:
                print(x)
                some_list.remove(x)

            #Bin the data into 25 ms bins over the 45 minute recording.
            bins = np.arange(start=0, stop=45 * 60 * fs + 1, step=.025 * fs)
            NData = np.zeros([
                np.unique(spike_clusters).shape[0] - len(noise_id),
                bins.shape[0] - 1
            ])

            cont = 0
            for x in some_list:
                #print(x)
                ind = (spike_clusters == x)
                fi = spike_times[ind]
                inds = np.histogram(fi, bins=bins)
                inds1 = inds[0]
                NData[cont, :] = inds1
                cont = cont + 1

            #Save activation matrix
            os.chdir("..")
            a = os.path.split(os.getcwd())[1]
            np.save('actmat_manual_' + i + '_' + a.split('_')[1], NData)
            np.save('unit_id_manual_' + i + '_' + a.split('_')[1], some_list)

    #End of for loop
    print("Stop the code here")
Example #7
    st8 = sorting_true.get_unit_spike_train(8)
    st80 = np.sort(np.random.choice(st8, size=int(st8.size*0.65), replace=False))
    st81 = np.sort(np.random.choice(st8, size=int(st8.size*0.6), replace=False))
    st82 = np.sort(np.random.choice(st8, size=int(st8.size*0.55), replace=False))
    units_err[80] = st80
    units_err[81] = st81
    units_err[82] = st82
    
    # unit 9 is missing
    
    # some units do not exist in the ground truth: 15, 16 and 17
    nframes = rec.get_num_frames(segment_index=0)
    for u in [15,16,17]:
        st = np.sort(np.random.randint(0, high=nframes, size=35))
        units_err[u] = st
    sorting_err = se.NumpySorting.from_dict(units_err, sampling_frequency)

    return sorting_true, sorting_err

if __name__ == '__main__':
    # just for check
    sorting_true, sorting_err = generate_erroneous_sorting()
    comp = sc.compare_sorter_to_ground_truth(sorting_true, sorting_err, exhaustive_gt=True)
    sw.plot_agreement_matrix(comp, ordered=True)
    plt.show()

Example #8
def test_agreement(self):
    sw.plot_agreement_matrix(self._gt_comp, count_text=True)
Example #9
# We can either (1) compare the spike sorting results with the ground-truth
# sorting :code:`sorting_true`, (2) compare the output of two sorters
# (HerdingSpikes and Tridesclous), or (3) compare the output of multiple sorters:

comp_gt_TDC = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC)
comp_TDC_HS = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_HS)
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_TDC, sorting_HS],
                                         name_list=['tdc', 'hs'])

##############################################################################
# When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion
# matrix

comp_gt_TDC.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_TDC)
w_agr = sw.plot_agreement_matrix(comp_gt_TDC)

##############################################################################
# When comparing two sorters (2), we can see the matching of units between sorters.
# Units that are not matched have -1 as unit id:

comp_TDC_HS.hungarian_match_12

##############################################################################
# or the reverse:

comp_TDC_HS.hungarian_match_21

##############################################################################
# When comparing multiple sorters (3), you can extract a :code:`SortingExtractor` object with units in agreement
# between sorters. You can also plot a graph showing how the units are matched between the sorters.
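
##############################################################################
# A minimal sketch of (3), assuming the :code:`MultiSortingComparison` API
# with :code:`get_agreement_sorting` (here :code:`minimum_agreement_count=2`
# keeps only units detected by both sorters):

agr_sorting = comp_multi.get_agreement_sorting(minimum_agreement_count=2)
print('Units in agreement:', agr_sorting.get_unit_ids())
w_multi = sw.plot_multicomp_graph(comp_multi)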
Example #10
# So you can access all of the individual results in detail.
#
# Note that :code:`exhaustive_gt=True` should be used when you know exactly
# how many units are in the ground truth (as for synthetic datasets).

study.run_comparisons(exhaustive_gt=True)

for (rec_name, sorter_name), comp in study.comparisons.items():
    print('*' * 10)
    print(rec_name, sorter_name)
    print(comp.count_score)  # raw counting of tp/fp/...
    comp.print_summary()
    perf_unit = comp.get_performance(method='by_unit')
    perf_avg = comp.get_performance(method='pooled_with_average')
    m = comp.get_confusion_matrix()
    w_comp = sw.plot_agreement_matrix(comp)
    w_comp.ax.set_title(rec_name + ' - ' + sorter_name)

##############################################################################
# Collect synthetic dataframes and display
# ----------------------------------------
# 
# As shown previously, the performance is returned as a pandas dataframe.
# The :code:`aggregate_dataframes` method gathers all the outputs in
# the study folder and merges them into a single dataframe.

dataframes = study.aggregate_dataframes()

##############################################################################
# Pandas dataframes can be nicely displayed as tables in the notebook.
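# For instance, a minimal sketch (assuming :code:`aggregate_dataframes`
# returns a dict of dataframes, as in the study API used above):
print(dataframes.keys())
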
w_comp_MS4 = sw.plot_confusion_matrix(comp_MS4, count_text=True)
plt.show()

comp_SP = sc.compare_sorter_to_ground_truth(gtOutput,
                                            sorting_SP,
                                            sampling_frequency=sampleRate,
                                            delta_time=3,
                                            match_score=0.5,
                                            chance_score=0.1,
                                            well_detected_score=0.1,
                                            exhaustive_gt=True)
w_comp_SP = sw.plot_confusion_matrix(comp_SP, count_text=True)
plt.show()

#Compute the agreement matrix and some benchmarking metrics
sw.plot_agreement_matrix(comp_MATLAB, ordered=True, count_text=True)
perf_MATLAB = comp_MATLAB.get_performance()
plt.show()
print(perf_MATLAB)

#Compare the sorting algorithms
#We will compare all three sorters
mcmp = sc.compare_multiple_sorters(
    sorting_list=[sortingPipeline, sorting_MS4, sorting_SP],
    name_list=['Our', 'MS4', 'SP'],
    verbose=True)

sw.plot_multicomp_graph(mcmp)
plt.show()

#Pairwise
Example #12
comp_MS4 = sc.compare_sorter_to_ground_truth(sorting_true, sorting_MS4)
comp_KL = sc.compare_sorter_to_ground_truth(sorting_true, sorting_KL)

##############################################################################
# plot_confusion_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_comp_MS4 = sw.plot_confusion_matrix(comp_MS4, count_text=False)
w_comp_KL = sw.plot_confusion_matrix(comp_KL, count_text=False)

##############################################################################
# plot_agreement_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_agr_MS4 = sw.plot_agreement_matrix(comp_MS4, count_text=False)

##############################################################################
# plot_sorting_performance()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can also plot a performance metric (e.g. accuracy, recall, precision) with respect to a quality metric, for
# example signal-to-noise ratio. Quality metrics can be computed using the :code:`toolkit.validation` submodule

import spikeinterface.toolkit as st

snrs = st.validation.compute_snrs(sorting_true,
                                  recording,
                                  save_as_property=True)

# plot accuracy against the 'snr' property saved above
w_perf = sw.plot_sorting_performance(comp_MS4, property_name='snr',
                                     metric='accuracy')
Example #13
st1 = sortingPipeline.get_unit_spike_train(unit_id=1)
print('Num. events for unit 1 = {}'.format(len(st1)))

#We also set up the unit spike features associated with each waveform
ID1_features = A_snippets_reference[cluster == 1, :]
sortingPipeline.set_unit_spike_features(unit_id=1,
                                        feature_name='unitId1',
                                        value=ID1_features)
print("Spike feature names: " +
      str(sortingPipeline.get_unit_spike_feature_names(unit_id=1)))

#Comparing sorter with ground truth
cmp_gt_SP = sc.compare_sorter_to_ground_truth(gtOutput,
                                              sortingPipeline,
                                              exhaustive_gt=True)
sw.plot_agreement_matrix(cmp_gt_SP, ordered=True)

#Some more comparison metrics
perf = cmp_gt_SP.get_performance()
#print('well_detected', cmp_gt_SP.get_well_detected_units(well_detected_score=0))
print(perf)
#We will try to get the SNR and firing rates

#firing_rates = st.validation.compute_firing_rates(sortingPipeline, duration_in_frames=recordingInput.get_num_frames())

#Raster plots

w_rs_gt = sw.plot_rasters(sortingPipeline, sampling_frequency=sampleRate)

w_wf_gt = sw.plot_unit_waveforms(recordingInput, sortingPipeline)
Example #14
##############################################################################

recording, sorting_true = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)

sorting_MS4 = ss.run_mountainsort4(recording)

##############################################################################

cmp_gt_MS4 = sc.compare_sorter_to_ground_truth(sorting_true, sorting_MS4, exhaustive_gt=True)


##############################################################################
# To get an overview of the match, we can use the unordered agreement matrix

sw.plot_agreement_matrix(cmp_gt_MS4, ordered=False)

##############################################################################
# or ordered

sw.plot_agreement_matrix(cmp_gt_MS4, ordered=True)

##############################################################################
# This function first matches the ground-truth and spike sorted units, and
# then it computes several performance metrics.
# 
# Once the spike trains are matched, each spike is labelled as: 
# 
# - true positive (tp): spike found both in :code:`gt_sorting` and :code:`tested_sorting`
# - false negative (fn): spike found in :code:`gt_sorting`, but not in :code:`tested_sorting` 
# - false positive (fp): spike found in :code:`tested_sorting`, but not in :code:`gt_sorting` 
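
##############################################################################
# A short sketch of what can be derived from these labels, using attributes
# shown in the other examples above (:code:`count_score`,
# :code:`get_well_detected_units`, :code:`get_performance`):

print(cmp_gt_MS4.count_score)  # raw tp/fn/fp counts per unit
print(cmp_gt_MS4.get_well_detected_units())
print(cmp_gt_MS4.get_performance())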
Example #15
# The :code:`compare_two_sorters` function allows us to compare the spike
# sorting output. It returns a :code:`SortingComparison` object, with methods
# to inspect the comparison output easily. The comparison matches the
# units by comparing the agreement between unit spike trains.
#
# Let’s see how to inspect and access this matching.

cmp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL,
                                    sorting2=sorting_MS4,
                                    sorting1_name='klusta',
                                    sorting2_name='ms4')

#############################################################################
# We can check the agreement matrix to inspect the matching.

sw.plot_agreement_matrix(cmp_KL_MS4)

#############################################################################
# Some useful internal dataframes, such as **match_event_count** and
# **agreement_scores**, help to check the match and the event counts.

print(cmp_KL_MS4.match_event_count)
print(cmp_KL_MS4.agreement_scores)

#############################################################################
# In order to check which units were matched, the :code:`get_mapped_sorting`
# methods can be used. If units are not matched they are listed as -1.

# units matched to klusta units
mapped_sorting_klusta = cmp_KL_MS4.get_mapped_sorting1()
print('Klusta units:', sorting_KL.get_unit_ids())
# assuming the MappedSortingExtractor API: unmatched units are listed as -1
print('Mapped Mountainsort4 units:', mapped_sorting_klusta.get_mapped_unit_ids())