Example #1
 def test_multicomp_graph(self):
     """Smoke-test the multi-comparison widgets on three identical sortings."""
     sortings = [self._sorting, self._sorting, self._sorting]
     multi_comp = sc.compare_multiple_sorters(sortings)
     # Graph view with explicit colormaps and node labels suppressed.
     sw.plot_multicomp_graph(multi_comp, edge_cmap='viridis',
                             node_cmap='rainbow', draw_labels=False)
     # Agreement summaries with default axes.
     sw.plot_multicomp_agreement(multi_comp)
     sw.plot_multicomp_agreement_by_sorter(multi_comp)
     # Re-plot the per-sorter agreement onto caller-provided axes.
     fig, axes = plt.subplots(len(multi_comp.sorting_list), 1)
     sw.plot_multicomp_agreement_by_sorter(multi_comp, axes=axes)
Example #2
def test_compare_multiple_sorters():
    """Check multi-sorter agreement on three synthetic sortings.

    sorting1 has units 0-2, sorting2 adds units 3-4, sorting3 adds unit 5;
    spike times are near-coincident so the shared units should match.
    Also verifies that agreement counts are invariant to the order in which
    the sortings are passed, and that the comparison round-trips to disk.
    """
    # simple match
    sorting1, sorting2, sorting3 = make_sorting(
        [100, 200, 300, 400, 500, 600, 700, 800, 900],
        [0, 1, 2, 0, 1, 2, 0, 1, 2],
        [101, 201, 301, 400, 501, 598, 702, 801, 899, 1000, 1100, 2000, 3000],
        [0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 4, 4],
        [
            101, 201, 301, 400, 500, 600, 700, 800, 900, 1000, 1100, 2000,
            3000, 3100, 3200, 3300
        ],
        [0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 4, 4, 5, 5, 5],
    )
    msc = compare_multiple_sorters([sorting1, sorting2, sorting3],
                                   verbose=True)
    msc_shuffle = compare_multiple_sorters([sorting3, sorting1, sorting2])

    agr = msc._do_agreement_matrix()
    agr_shuffle = msc_shuffle._do_agreement_matrix()

    print(agr)
    print(agr_shuffle)

    # 3 units found by all three sorters, 5 by at least two, 6 in total.
    assert len(
        msc.get_agreement_sorting(
            minimum_agreement_count=3).get_unit_ids()) == 3
    assert len(
        msc.get_agreement_sorting(
            minimum_agreement_count=2).get_unit_ids()) == 5
    assert len(msc.get_agreement_sorting().get_unit_ids()) == 6
    # Agreement must not depend on the order of the input sortings.
    assert len(msc.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids()) == \
           len(msc_shuffle.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids())
    assert len(msc.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids()) == \
           len(msc_shuffle.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids())
    assert len(msc.get_agreement_sorting().get_unit_ids()) == len(
        msc_shuffle.get_agreement_sorting().get_unit_ids())
    agreement_2 = msc.get_agreement_sorting(minimum_agreement_count=2,
                                            minimum_agreement_count_only=True)
    # Bug fix: the original asserted np.all(<generator>), which is always
    # truthy (a generator is a 0-d object array), and compared a one-element
    # list to 2 (always False). Check each unit's property directly instead.
    assert all(agreement_2.get_unit_property(u, 'agreement_number') == 2
               for u in agreement_2.get_unit_ids())

    # Round-trip through the on-disk representation.
    msc.save_to_folder('saved_multisorting_comparison')

    msc = MultiSortingComparison.load_from_folder(
        'saved_multisorting_comparison')
def compare_sorters(sort1, sort2):
    """Compare two sorting outputs pairwise and via multi-sorter agreement.

    Parameters
    ----------
    sort1, sort2 : SortingExtractor-like
        Here the Klusta and Mountainsort4 outputs, respectively.

    Prints the unit mapping from sort1 to sort2, prints the units both
    sorters agree on, plots the multi-comparison graph and shows the figure.
    """
    comp_KL_MS4 = sc.compare_two_sorters(sorting1=sort1, sorting2=sort2)
    mapped_units = comp_KL_MS4.get_mapped_sorting1().get_mapped_unit_ids()

    print('Klusta units:', sort1.get_unit_ids())
    print('Mapped Mountainsort4 units:', mapped_units)

    comp_multi = sc.compare_multiple_sorters(sorting_list=[sort1, sort2],
                                             name_list=['klusta', 'ms4'])

    # Consistency fix: every other call in this file uses the
    # 'minimum_agreement_count' keyword; 'minimum_matching' was the
    # pre-rename spelling of the same parameter.
    sorting_agreement = comp_multi.get_agreement_sorting(
        minimum_agreement_count=2)

    print('Units in agreement between Klusta and Mountainsort4:',
          sorting_agreement.get_unit_ids())

    w_multi = sw.plot_multicomp_graph(comp_multi)
    plt.show()
                                                     duration=20,
                                                     seed=0)

#############################################################################
# Then run 3 spike sorters and compare their output.

sorting_KL = sorters.run_klusta(recording)
sorting_MS4 = sorters.run_mountainsort4(recording)
sorting_TDC = sorters.run_tridesclous(recording)

#############################################################################
# Compare multiple spike sorter outputs
# -------------------------------------------

mcmp = sc.compare_multiple_sorters(
    sorting_list=[sorting_KL, sorting_MS4, sorting_TDC],
    name_list=['KL', 'MS4', 'TDC'],
    verbose=True)

#############################################################################
# The multiple sorters comparison internally computes pairwise comparison,
# that can be accessed as follows:

print(mcmp.comparisons[0].sorting1, mcmp.comparisons[0].sorting2)
mcmp.comparisons[0].get_mapped_sorting1().get_mapped_unit_ids()

#############################################################################
print(mcmp.comparisons[1].sorting1, mcmp.comparisons[1].sorting2)
# Bug fix: this line inspected comparisons[0] again; it should inspect the
# second pairwise comparison printed just above.
mcmp.comparisons[1].get_mapped_sorting1().get_mapped_unit_ids()

#############################################################################
# The global multi comparison can be visualized with this graph
sorting_curated_snr = st.curation.threshold_snr(sorting_KL,
                                                recording,
                                                threshold=5)
snrs_above = st.validation.compute_snrs(sorting_curated_snr, recording_cmr)

print('Curated SNR', snrs_above)

##############################################################################
# The final part of this tutorial deals with comparing spike sorting outputs.
# We can either (1) compare the spike sorting results with the ground-truth sorting :code:`sorting_true`, (2) compare
# the output of two (Klusta and Mountainsort4), or (3) compare the output of multiple sorters:

comp_gt_KL = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true,
                                               tested_sorting=sorting_KL)
comp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4)
comp_multi = sc.compare_multiple_sorters(
    sorting_list=[sorting_MS4, sorting_KL], name_list=['klusta', 'ms4'])

##############################################################################
# When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion
# matrix

comp_gt_KL.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_KL)

##############################################################################
# When comparing two sorters (2), we can see the matching of units between sorters. For example, this is how to extract
# the unit ids of Mountainsort4 (sorting2) mapped to the units of Klusta (sorting1). Units which are not mapped have -1
# as unit id.

mapped_units = comp_KL_MS4.get_mapped_sorting1().get_mapped_unit_ids()
Example #6
def auto(recording_folder):
    """Run several spike sorters on one Open Ephys tetrode recording, build a
    consensus sorting, export it to Phy, and save a binned activation matrix.

    Parameters
    ----------
    recording_folder : str
        Path to the Open Ephys recording folder for a single tetrode. A
        'tetrode.prb' probe file is expected inside it.

    Side effects: changes the working directory, writes .nwb files, figures
    (consensus*.pdf/png, unit_templates.*) and actmat/unit_id .npy files
    into the folder, and creates a 'phy_AGR' (or fallback 'phy_MS4') export.
    Returns early (None) if a Phy export already exists.
    """
    os.chdir(recording_folder)

    # If this tetrode already has a Phy export, it was sorted before: skip.
    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]
    if ('phy_AGR' in subfolders) or ('phy_MS4' in subfolders):
        print('Tetrode ' + recording_folder.split('_')[-1] +
              ' was previously sorted. Skipping')
        return

    # --- Matlab-based sorters -------------------------------------------

    # IronClust
    iron_path = "~/Documents/SpikeSorting/ironclust"
    ss.IronClustSorter.set_ironclust_path(os.path.expanduser(iron_path))
    ss.IronClustSorter.ironclust_path

    # HDSort and Waveclus are currently disabled; set their paths here to
    # enable them (see version-control history for the removed stubs).

    # Snapshot of the folder contents, used to detect previously saved
    # per-sorter .nwb results below.
    arr = os.listdir()

    # Load .continuous files
    recording = se.OpenEphysRecordingExtractor(recording_folder)
    channel_ids = recording.get_channel_ids()
    fs = recording.get_sampling_frequency()
    num_chan = recording.get_num_channels()

    print('Channel ids:', channel_ids)
    print('Sampling frequency:', fs)
    print('Number of channels:', num_chan)

    # Attach tetrode geometry/grouping from the probe file.
    recording_prb = recording.load_probe_file('tetrode.prb')

    print('Channels after loading the probe file:',
          recording_prb.get_channel_ids())
    print('Channel groups after loading the probe file:',
          recording_prb.get_channel_groups())

    # Bandpass filter to the spike band, then cache so the filtered traces
    # are computed only once.
    recording_cmr = st.preprocessing.bandpass_filter(recording_prb,
                                                     freq_min=300,
                                                     freq_max=6000)
    recording_cache = se.CacheRecordingExtractor(recording_cmr)

    print(recording_cache.get_channel_ids())
    print(recording_cache.get_channel_groups())
    print(recording_cache.get_num_frames() /
          recording_cache.get_sampling_frequency())

    # Run all channels. There are only a single tetrode's channels anyway.

    # Sub-recording (200-320 s) used only when writing NWB files; NWB
    # requires a recording to be saved alongside each sorting and saving
    # the whole recording would be wasteful.
    recording_sub = se.SubRecordingExtractor(recording_cache,
                                             start_frame=200 * fs,
                                             end_frame=320 * fs)

    Sorters2Compare = []       # sortings that produced at least one unit
    Sorters2CompareLabel = []  # matching short labels
    SortersCount = []          # number of detected units per sorter

    # Klusta (no try/except: a Klusta failure should abort the run)
    if 'sorting_KL_all.nwb' in arr:
        print('Loading Klusta')
        sorting_KL_all = se.NwbSortingExtractor('sorting_KL_all.nwb')
        if sorting_KL_all.get_unit_ids():
            Sorters2Compare.append(sorting_KL_all)
            Sorters2CompareLabel.append('KL')
    else:
        t = time.time()
        sorting_KL_all = ss.run_klusta(recording_cache,
                                       output_folder='results_all_klusta',
                                       delete_output_folder=True)
        print('Found', len(sorting_KL_all.get_unit_ids()), 'units')
        time.time() - t
        # Save Klusta
        se.NwbRecordingExtractor.write_recording(recording_sub,
                                                 'sorting_KL_all.nwb')
        se.NwbSortingExtractor.write_sorting(sorting_KL_all,
                                             'sorting_KL_all.nwb')
        if sorting_KL_all.get_unit_ids():
            Sorters2Compare.append(sorting_KL_all)
            Sorters2CompareLabel.append('KL')
    SortersCount.append(len(sorting_KL_all.get_unit_ids()))

    # Ironclust
    if 'sorting_IC_all.nwb' in arr:
        print('Loading Ironclust')
        sorting_IC_all = se.NwbSortingExtractor('sorting_IC_all.nwb')
        if sorting_IC_all.get_unit_ids():
            Sorters2Compare.append(sorting_IC_all)
            Sorters2CompareLabel.append('IC')
        SortersCount.append(len(sorting_IC_all.get_unit_ids()))
    else:
        try:
            t = time.time()
            sorting_IC_all = ss.run_ironclust(recording_cache,
                                              output_folder='results_all_ic',
                                              delete_output_folder=True,
                                              filter=False)
            print('Found', len(sorting_IC_all.get_unit_ids()), 'units')
            time.time() - t
            # Save IC
            se.NwbRecordingExtractor.write_recording(recording_sub,
                                                     'sorting_IC_all.nwb')
            se.NwbSortingExtractor.write_sorting(sorting_IC_all,
                                                 'sorting_IC_all.nwb')
            if sorting_IC_all.get_unit_ids():
                Sorters2Compare.append(sorting_IC_all)
                Sorters2CompareLabel.append('IC')
            SortersCount.append(len(sorting_IC_all.get_unit_ids()))
        except Exception:
            print('Ironclust has failed')

    # (Waveclus block removed: it was fully commented out.)

    # Herdingspikes
    if 'sorting_herdingspikes_all.nwb' in arr:
        print('Loading herdingspikes')
        sorting_herdingspikes_all = se.NwbSortingExtractor(
            'sorting_herdingspikes_all.nwb')
        if sorting_herdingspikes_all.get_unit_ids():
            Sorters2Compare.append(sorting_herdingspikes_all)
            Sorters2CompareLabel.append('HS')
        SortersCount.append(len(sorting_herdingspikes_all.get_unit_ids()))
    else:
        try:
            t = time.time()
            sorting_herdingspikes_all = ss.run_herdingspikes(
                recording_cache,
                output_folder='results_all_herdingspikes',
                delete_output_folder=True)

            print('Found', len(sorting_herdingspikes_all.get_unit_ids()),
                  'units')
            time.time() - t
            # Save herdingspikes
            se.NwbRecordingExtractor.write_recording(
                recording_sub, 'sorting_herdingspikes_all.nwb')
            try:
                se.NwbSortingExtractor.write_sorting(
                    sorting_herdingspikes_all, 'sorting_herdingspikes_all.nwb')
            except TypeError:
                # write_sorting raises TypeError when there are no units;
                # drop the half-written file so it is not loaded next run.
                print("No units detected.  Can't save HerdingSpikes")
                os.remove("sorting_herdingspikes_all.nwb")
            if sorting_herdingspikes_all.get_unit_ids():
                Sorters2Compare.append(sorting_herdingspikes_all)
                Sorters2CompareLabel.append('HS')
            SortersCount.append(len(sorting_herdingspikes_all.get_unit_ids()))
        except Exception:
            print('Herdingspikes has failed')

    # Clean up any leftover herdingspikes output folder. Bug fix: this
    # block was duplicated verbatim, and the "Removed leftover ..." message
    # was printed only when removal FAILED (folder absent). Report success,
    # stay silent when there is nothing to remove.
    try:
        rmtree("results_all_herdingspikes")
        print('Removed leftover herdingspikes files')
    except Exception:
        pass

    # Mountainsort4 (no try/except: it is the fallback sorter when no
    # consensus is reached, so it must exist)
    if 'sorting_mountainsort4_all.nwb' in arr:
        print('Loading mountainsort4')
        sorting_mountainsort4_all = se.NwbSortingExtractor(
            'sorting_mountainsort4_all.nwb')
        if sorting_mountainsort4_all.get_unit_ids():
            Sorters2Compare.append(sorting_mountainsort4_all)
            Sorters2CompareLabel.append('MS4')
    else:
        t = time.time()
        sorting_mountainsort4_all = ss.run_mountainsort4(
            recording_cache,
            output_folder='results_all_mountainsort4',
            delete_output_folder=True,
            filter=False)
        print('Found', len(sorting_mountainsort4_all.get_unit_ids()), 'units')
        time.time() - t
        # Save mountainsort4
        se.NwbRecordingExtractor.write_recording(
            recording_sub, 'sorting_mountainsort4_all.nwb')
        se.NwbSortingExtractor.write_sorting(sorting_mountainsort4_all,
                                             'sorting_mountainsort4_all.nwb')
        if sorting_mountainsort4_all.get_unit_ids():
            Sorters2Compare.append(sorting_mountainsort4_all)
            Sorters2CompareLabel.append('MS4')
    SortersCount.append(len(sorting_mountainsort4_all.get_unit_ids()))

    # Spykingcircus
    if 'sorting_spykingcircus_all.nwb' in arr:
        print('Loading spykingcircus')
        sorting_spykingcircus_all = se.NwbSortingExtractor(
            'sorting_spykingcircus_all.nwb')
        if sorting_spykingcircus_all.get_unit_ids():
            Sorters2Compare.append(sorting_spykingcircus_all)
            Sorters2CompareLabel.append('SC')
        SortersCount.append(len(sorting_spykingcircus_all.get_unit_ids()))
    else:
        try:
            t = time.time()
            sorting_spykingcircus_all = ss.run_spykingcircus(
                recording_cache,
                output_folder='results_all_spykingcircus',
                delete_output_folder=True,
                filter=False)
            print('Found', len(sorting_spykingcircus_all.get_unit_ids()),
                  'units')
            time.time() - t
            # Save sorting_spykingcircus
            se.NwbRecordingExtractor.write_recording(
                recording_sub, 'sorting_spykingcircus_all.nwb')
            se.NwbSortingExtractor.write_sorting(
                sorting_spykingcircus_all, 'sorting_spykingcircus_all.nwb')
            if sorting_spykingcircus_all.get_unit_ids():
                Sorters2Compare.append(sorting_spykingcircus_all)
                Sorters2CompareLabel.append('SC')
            SortersCount.append(len(sorting_spykingcircus_all.get_unit_ids()))
        except Exception:
            print('Spykingcircus has failed')

    # Tridesclous
    if 'sorting_tridesclous_all.nwb' in arr:
        print('Loading tridesclous')
        try:
            sorting_tridesclous_all = se.NwbSortingExtractor(
                'sorting_tridesclous_all.nwb')
        except AttributeError:
            # A unit-less NWB file cannot be loaded; re-run the sorter and
            # rewrite the file.
            print("No units detected.  Can't load Tridesclous so will run it.")
            t = time.time()
            sorting_tridesclous_all = ss.run_tridesclous(
                recording_cache,
                output_folder='results_all_tridesclous',
                delete_output_folder=True)
            print('Found', len(sorting_tridesclous_all.get_unit_ids()),
                  'units')
            time.time() - t
            os.remove("sorting_tridesclous_all.nwb")
            # Save sorting_tridesclous
            se.NwbRecordingExtractor.write_recording(
                recording_sub, 'sorting_tridesclous_all.nwb')
            se.NwbSortingExtractor.write_sorting(
                sorting_tridesclous_all, 'sorting_tridesclous_all.nwb')
        if sorting_tridesclous_all.get_unit_ids():
            Sorters2Compare.append(sorting_tridesclous_all)
            Sorters2CompareLabel.append('TRI')
        SortersCount.append(len(sorting_tridesclous_all.get_unit_ids()))
    else:
        try:
            t = time.time()
            sorting_tridesclous_all = ss.run_tridesclous(
                recording_cache,
                output_folder='results_all_tridesclous',
                delete_output_folder=True)
            print('Found', len(sorting_tridesclous_all.get_unit_ids()),
                  'units')
            time.time() - t
            # Save sorting_tridesclous
            se.NwbRecordingExtractor.write_recording(
                recording_sub, 'sorting_tridesclous_all.nwb')
            se.NwbSortingExtractor.write_sorting(
                sorting_tridesclous_all, 'sorting_tridesclous_all.nwb')
            if sorting_tridesclous_all.get_unit_ids():
                Sorters2Compare.append(sorting_tridesclous_all)
                Sorters2CompareLabel.append('TRI')
            SortersCount.append(len(sorting_tridesclous_all.get_unit_ids()))
        except Exception:
            print('Tridesclous failed')

    # Clean up any leftover tridesclous output folder (same message fix as
    # for herdingspikes above).
    try:
        rmtree("results_all_tridesclous")
        print('Removed leftover tridesclous files')
    except Exception:
        pass

    # --- Consensus-based curation ---------------------------------------
    print(Sorters2CompareLabel)
    print('Comparing sorters agreement. Please wait...')
    mcmp = sc.compare_multiple_sorters(Sorters2Compare, Sorters2CompareLabel)
    w = sw.plot_multicomp_agreement_by_sorter(mcmp)
    plt.savefig('consensus.pdf', bbox_inches='tight')
    plt.savefig('consensus.png', bbox_inches='tight')
    plt.close()

    w = sw.plot_multicomp_agreement(mcmp)
    plt.savefig('consensus_spikes.pdf', bbox_inches='tight')
    plt.savefig('consensus_spikes.png', bbox_inches='tight')
    plt.close()

    # Keep units on which at least 2 sorters agree.
    agreement_sorting = mcmp.get_agreement_sorting(minimum_agreement_count=2)

    print(agreement_sorting.get_unit_ids())
    phy_folder_name = 'phy_AGR'
    if not (agreement_sorting.get_unit_ids()):  # If there is no agreement.
        print('No consensus. Using detections from MountainSort4')
        agreement_sorting = sorting_mountainsort4_all
        phy_folder_name = 'phy_MS4'

    # Export the chosen sorting to Phy for manual curation.
    st.postprocessing.export_to_phy(recording_cache,
                                    agreement_sorting,
                                    output_folder=phy_folder_name,
                                    grouping_property='group',
                                    verbose=True,
                                    recompute_info=True)

    w_wf = sw.plot_unit_templates(sorting=agreement_sorting,
                                  recording=recording_cache)
    plt.savefig('unit_templates.pdf', bbox_inches='tight')
    plt.savefig('unit_templates.png', bbox_inches='tight')
    plt.close()

    # Access unit IDs and spike times from the Phy export.
    os.chdir(phy_folder_name)
    spike_times = np.load('spike_times.npy')
    spike_clusters = np.load('spike_clusters.npy')

    # Unit IDs present in the export.
    some_list = np.unique(spike_clusters).tolist()

    # Bin spike counts in 25 ms bins over 45 minutes.
    bins = np.arange(start=0, stop=45 * 60 * fs + 1, step=.025 * fs)
    NData = np.zeros([len(some_list), bins.shape[0] - 1])

    for cont, unit in enumerate(some_list):
        unit_spikes = spike_times[spike_clusters == unit]
        NData[cont, :] = np.histogram(unit_spikes, bins=bins)[0]

    # Save the activation matrix, named after the tetrode folder suffix.
    os.chdir("..")
    a = os.path.split(os.getcwd())[1]
    np.save('actmat_auto_' + a.split('_')[1], NData)
    np.save('unit_id_auto_' + a.split('_')[1], some_list)
# Re-compare sorters after manual Phy curation: load each curated
# 'phy_<label>' folder, skip sorters whose curation left no units, then
# rebuild the multi-sorter agreement from the curated sortings.
# NOTE(review): this fragment iterates `Sorters2CompareLabel` but mutates
# and prints `Sorters2label` -- presumably one is a copy of the other made
# elsewhere; confirm both names are defined before running. Also note the
# label is removed from one list while the sorting is appended to
# `Sorters2Compare`, so the two lists can fall out of step -- verify.
for num in range(len(Sorters2CompareLabel)):
     i=Sorters2CompareLabel[num];
#     print(i)
     if 'phy_'+i in subfolders:
         # Load the Phy output, excluding clusters labelled noise or mua.
         sorting_curated = se.PhySortingExtractor('phy_'+i+'/', exclude_cluster_groups=['noise','mua']);
         if not sorting_curated.get_unit_ids():
             Sorters2label.remove(i)
         else:
                 Sorters2Compare.append(sorting_curated);


#Consensus based curation.
print(Sorters2label)
print('Comparing sorters agreement. Please wait...')
mcmp = sc.compare_multiple_sorters(Sorters2Compare, Sorters2label)
w = sw.plot_multicomp_agreement_by_sorter(mcmp)


plt.savefig('consensus_curation.pdf', bbox_inches='tight');
plt.savefig('consensus_curation.png', bbox_inches='tight');
plt.close()

w = sw.plot_multicomp_agreement(mcmp)
plt.savefig('consensus_curation_spikes.pdf', bbox_inches='tight');
plt.savefig('consensus_curation_spikes.png', bbox_inches='tight');
plt.close()

#Consider at least 2 sorters agreeing.
agreement_sorting=mcmp.get_agreement_sorting(minimum_agreement_count=2);
Example #8
# Unit ids whose boolean keep_mask entry is True.
# NOTE(review): keep_mask appears to be a pandas boolean Series indexed by
# unit id -- confirm against the (unseen) code that built it.
keep_unit_ids = keep_mask[keep_mask].index.values
print(keep_unit_ids)

# Restrict the Tridesclous sorting to the curated units.
curated_sorting = sorting_TDC.select_units(keep_unit_ids)
print(curated_sorting)

##############################################################################
# The final part of this tutorial deals with comparing spike sorting outputs.
# We can either (1) compare the spike sorting results with the ground-truth 
# sorting :code:`sorting_true`, (2) compare the output of two (HerdingSpikes
# and Tridesclous), or (3) compare the output of multiple sorters:

comp_gt_TDC = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC)
comp_TDC_HS = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_HS)
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_TDC, sorting_HS],
                                         name_list=['tdc', 'hs'])

##############################################################################
# When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion
# matrix

comp_gt_TDC.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_TDC)
w_agr = sw.plot_agreement_matrix(comp_gt_TDC)

##############################################################################
# When comparing two sorters (2), we can see the matching of units between sorters.
# Units which are not matched have -1 as unit id:

comp_TDC_HS.hungarian_match_12
                                            chance_score=0.1,
                                            well_detected_score=0.1,
                                            exhaustive_gt=True)
# NOTE(review): this fragment begins after a truncated comparison call that
# produced `comp_SP`; the stray keyword lines immediately above belong to it.
w_comp_SP = sw.plot_confusion_matrix(comp_SP, count_text=True)
plt.show()

#Computing some metrics for benchmarking-agreement matrix
sw.plot_agreement_matrix(comp_MATLAB, ordered=True, count_text=True)
perf_MATLAB = comp_MATLAB.get_performance()
plt.show()
print(perf_MATLAB)

#comparing the sorting algos
#We will try to compare all the three sorters
mcmp = sc.compare_multiple_sorters(
    sorting_list=[sortingPipeline, sorting_MS4, sorting_SP],
    name_list=['Our', 'MS4', 'SP'],
    verbose=True)

# Graph view of the three-way agreement.
sw.plot_multicomp_graph(mcmp)
plt.show()

#Pairwise

cmp_MS4_Our = sc.compare_two_sorters(sorting1=sorting_MS4,
                                     sorting2=sortingPipeline,
                                     sorting1_name='MS4',
                                     sorting2_name='Our')
sw.plot_agreement_matrix(cmp_MS4_Our, ordered=True, count_text=True)
plt.show()
cmp_SP_Our = sc.compare_two_sorters(sorting1=sorting_SP,
Example #10
# example signal-to-noise ratio. Quality metrics can be computed using the :code:`toolkit.validation` submodule

import spikeinterface.toolkit as st

# Compute per-unit SNRs and store them on the sorting as a unit property,
# so they can be referenced by name in the performance plot below.
snrs = st.validation.compute_snrs(sorting_true,
                                  recording,
                                  save_as_property=True)

# Plot each unit's accuracy against its stored 'snr' property.
w_perf = sw.plot_sorting_performance(comp_MS4,
                                     property_name='snr',
                                     metric='accuracy')

##############################################################################
# Widgets using MultiSortingComparison
# -------------------------------------
#
# We can also compare all three SortingExtractor objects, obtaining a :code:`MultiSortingComparison` object.

multicomp = sc.compare_multiple_sorters(
    [sorting_true, sorting_MS4, sorting_KL])

##############################################################################
# plot_multicomp_graph()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_multi = sw.plot_multicomp_graph(multicomp,
                                  edge_cmap='coolwarm',
                                  node_cmap='viridis',
                                  draw_labels=False,
                                  colorbar=True)