Example #1
def test_get_performance():
    ######
    # simple match
    gt_sorting, tested_sorting = make_sorting(
        [100, 200, 300, 400], [0, 0, 1, 0],
        [101, 201, 301], [0, 0, 5])
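    # delta_time is the coincidence window used to decide whether a ground-truth
    # spike and a tested spike count as the same event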
    sc = compare_sorter_to_ground_truth(gt_sorting,
                                        tested_sorting,
                                        exhaustive_gt=True,
                                        delta_time=0.3)

    perf = sc.get_performance('raw_count')
    assert perf.loc[0, 'tp'] == 2
    assert perf.loc[1, 'tp'] == 1
    assert perf.loc[0, 'fn'] == 1
    assert perf.loc[1, 'fn'] == 0
    assert perf.loc[0, 'fp'] == 0
    assert perf.loc[1, 'fp'] == 0

    perf = sc.get_performance('pooled_with_average')
    assert perf['miss_rate'] == 1 / 6

    perf = sc.get_performance('by_unit')

    assert perf.loc[0, 'accuracy'] == 2 / 3.
    assert perf.loc[0, 'miss_rate'] == 1 / 3.

    ######
    # match when 2 units fire at same time
    gt_sorting, tested_sorting = make_sorting(
        [100, 100, 200, 200, 300],
        [0, 1, 0, 1, 0],
        [100, 100, 200, 200, 300],
        [0, 1, 0, 1, 0],
    )
    sc = compare_sorter_to_ground_truth(gt_sorting,
                                        tested_sorting,
                                        exhaustive_gt=True)

    perf = sc.get_performance('raw_count')
    assert perf.loc[0, 'tp'] == 3
    assert perf.loc[0, 'fn'] == 0
    assert perf.loc[0, 'fp'] == 0
    assert perf.loc[0, 'num_gt'] == 3
    assert perf.loc[0, 'num_tested'] == 3

    perf = sc.get_performance('pooled_with_average')
    assert perf['accuracy'] == 1.
    assert perf['miss_rate'] == 0.
Example #2
def _compare_one_sorter(sorter_name, sortings_pre, agg_sortings, gts,
                        comparisons):
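    # assumes the usual aliases from the surrounding examples,
    # e.g. import spikeinterface.comparison as sc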
    print(f'Performance comparisons for {sorter_name}')
    for key in sortings_pre.keys():
        if key[1] == sorter_name:
            print(f'Recording name: {key[0]}')
            comparison_post = sc.compare_sorter_to_ground_truth(
                tested_sorting=agg_sortings[key], gt_sorting=gts[key[0]])
            print('Before recovery:')
            print(comparisons[key].print_performance())
            print('After recovery:')
            print(comparison_post.print_performance())
Example #3
    def setUp(self):
        #~ self._rec, self._sorting = se.toy_example(num_channels=10, duration=10, num_segments=1)
        #~ self._rec = self._rec.save()
        #~ self._sorting = self._sorting.save()
        local_path = download_dataset(remote_path='mearec/mearec_test_10s.h5')
        self._rec = se.MEArecRecordingExtractor(local_path)

        self._sorting = se.MEArecSortingExtractor(local_path)

        self.num_units = len(self._sorting.get_unit_ids())
        #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True)
        self._we = extract_waveforms(self._rec, self._sorting, './mearec_test', load_if_exists=True)

        self._amplitudes = st.get_spike_amplitudes(self._we, peak_sign='neg', outputs='by_unit')
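        # comparing the ground-truth sorting against itself gives a perfect-match baseline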
        self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting)
Example #4
def make_comparison_figures():
    
    gt_sorting, tested_sorting = generate_erroneous_sorting()
    
    comp = sc.compare_sorter_to_ground_truth(
        gt_sorting, tested_sorting, gt_name=None, tested_name=None,
        delta_time=0.4, sampling_frequency=None, min_accuracy=0.5,
        exhaustive_gt=True, match_mode='hungarian', n_jobs=-1,
        bad_redundant_threshold=0.2, compute_labels=False, verbose=False)
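    # hungarian_match_12 maps each ground-truth unit to its matched tested unit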
    
    print(comp.hungarian_match_12)
    
    fig, ax = plt.subplots()
    im = ax.matshow(comp.match_event_count, cmap='Greens')
    ax.set_xticks(np.arange(0, comp.match_event_count.shape[1]))
    ax.set_yticks(np.arange(0, comp.match_event_count.shape[0]))
    ax.xaxis.tick_bottom()
    ax.set_yticklabels(comp.match_event_count.index, fontsize=12)
    ax.set_xticklabels(comp.match_event_count.columns, fontsize=12)
    fig.colorbar(im)
    fig.savefig('spikecomparison_match_count.png')
    
    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax, ordered=False)
    im = ax.get_images()[0]
    fig.colorbar(im)
    fig.savefig('spikecomparison_agreement_unordered.png')

    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax)
    im = ax.get_images()[0]
    fig.colorbar(im)
    fig.savefig('spikecomparison_agreement.png')
    
    fig, ax = plt.subplots()
    sw.plot_confusion_matrix(comp, ax=ax)
    im = ax.get_images()[0]
    fig.colorbar(im)
    fig.savefig('spikecomparison_confusion.png')
    
    plt.show()
Example #5
local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording = se.MEArecRecordingExtractor(local_path)
sorting_true = se.MEArecSortingExtractor(local_path)
print(recording)
print(sorting_true)


##############################################################################
# run herdingspikes on it

sorting_HS = ss.run_herdingspikes(recording)

##############################################################################

cmp_gt_HS = sc.compare_sorter_to_ground_truth(sorting_true, sorting_HS, exhaustive_gt=True)


##############################################################################
# To have an overview of the match we can use the unordered agreement matrix

sw.plot_agreement_matrix(cmp_gt_HS, ordered=False)

##############################################################################
# or ordered

sw.plot_agreement_matrix(cmp_gt_HS, ordered=True)

##############################################################################
# This function first matches the ground-truth and spike sorted units, and
# then it computes several performance metrics.
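
##############################################################################
# For instance (a minimal sketch, reusing calls shown in the other examples
# of this collection):

perf = cmp_gt_HS.get_performance()
print(perf)
cmp_gt_HS.print_summary()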
Example #6
def test_compare_sorter_to_ground_truth():
    # simple match
    gt_sorting, tested_sorting = make_sorting(
        [100, 200, 300, 400, 500, 600, 700], [0, 0, 1, 0, 1, 1, 1],
        [101, 201, 301, 302, 401, 501, 502, 601, 900],
        [0, 0, 5, 6, 0, 5, 6, 5, 11])

    for match_mode in ('hungarian', 'best'):

        compute_labels = (match_mode == 'hungarian')

        sc = compare_sorter_to_ground_truth(gt_sorting,
                                            tested_sorting,
                                            exhaustive_gt=True,
                                            match_mode=match_mode,
                                            compute_labels=compute_labels)

        assert_array_equal(sc.event_counts1.values, [3, 4])
        assert_array_equal(sc.event_counts2.values, [3, 3, 2, 1])

        assert_array_equal(sc.possible_match_12[1], [5, 6])

        assert_array_equal(sc.best_match_12[1], 5)
        assert_array_equal(sc.hungarian_match_12[1], 5)
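        # gt unit 1 overlaps tested units 5 and 6; both 'best' and 'hungarian'
        # matching resolve it to tested unit 5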

        # ~ print(sc.agreement_scores)
        scores = sc.agreement_scores
        ordered_scores = sc.get_ordered_agreement_scores()
        assert_array_equal(
            scores.loc[ordered_scores.index, ordered_scores.columns],
            ordered_scores)

        assert sc.count_score.at[0, 'tp'] == 3
        assert sc.count_score.at[1, 'tp'] == 3
        assert sc.count_score.at[1, 'fn'] == 1

        sc._do_confusion_matrix()
        # print(sc._confusion_matrix)

        methods = [
            'raw_count',
            'by_unit',
            'pooled_with_average',
        ]
        for method in methods:
            perf = sc.get_performance(method=method)
            # ~ print(perf)

        for method in methods:
            sc.print_performance(method=method)

        sc.print_summary()

    sc = compare_sorter_to_ground_truth(gt_sorting,
                                        tested_sorting,
                                        exhaustive_gt=True,
                                        match_mode='hungarian')

    # test well detected units depending on thresholds
    good_units = sc.get_well_detected_units()  # well_detected_score=0.95 is the default
    print(good_units)
    assert_array_equal(good_units, [0])
    good_units = sc.get_well_detected_units(well_detected_score=0.95)
    assert_array_equal(good_units, [0])
    good_units = sc.get_well_detected_units(well_detected_score=.6)
    assert_array_equal(good_units, [0, 5])
    # good_units = sc.get_well_detected_units(false_discovery_rate=0.05)
    # assert_array_equal(good_units, [0, 1])
    # good_units = sc.get_well_detected_units(accuracy=0.95, false_discovery_rate=.05)  # combine thresh
    # assert_array_equal(good_units, [0])

    # count
    num_ok = sc.count_well_detected_units(well_detected_score=0.95)
    assert num_ok == 1

    # false_positive_units [11]
    fpu_ids = sc.get_false_positive_units()
    assert_array_equal(fpu_ids, [11])
    num_fpu = sc.count_false_positive_units()
    assert num_fpu == 1

    # redundant_units [6]
    redundant_ids = sc.get_redundant_units()
    assert_array_equal(redundant_ids, [6])

    # bad_units [11]
    bad_ids = sc.get_bad_units()
    assert_array_equal(bad_ids, [6, 11])
    num_bad = sc.count_bad_units()

    # bad units is union of false_positive_units + redundant_units
    fpu_ids = sc.get_false_positive_units()
    redundant_ids = sc.get_redundant_units()
    bad_ids = sc.get_bad_units()
    assert_array_equal(bad_ids, sorted(fpu_ids + redundant_ids))
Example #7
def manual(recording_folder):
    #Folder with tetrode data
    #recording_folder='/home/adrian/Documents/SpikeSorting/Adrian_test_data/Irene_data/test_without_zero_main_channels/Tetrode_9_CH';

    os.chdir(recording_folder)
    """
    Adding Matlab-based sorters to path
    
    """

    #IronClust
    iron_path = "~/Documents/SpikeSorting/ironclust"
    ss.IronClustSorter.set_ironclust_path(os.path.expanduser(iron_path))
    ss.IronClustSorter.ironclust_path

    #If the sorters have already been run and curated, skip this recording.
    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]
    #if ('phy_KL' in subfolders) & ('phy_IC' in subfolders) & ('phy_Waveclus' in subfolders) & ('phy_SC' in subfolders) & ('phy_MS4' in subfolders) & ('phy_HS' in subfolders) & ('phy_TRI' in subfolders):
    if ('phy_KL' in subfolders) & ('phy_IC' in subfolders) & (
            'phy_SC' in subfolders) & ('phy_MS4' in subfolders) & (
                'phy_HS' in subfolders) & ('phy_TRI' in subfolders):
        print('Tetrode ' + recording_folder.split('_')[-1] +
              ' was previously manually sorted. Skipping')
        return

    #Check if the recording has been preprocessed before and load it.
    # Else proceed with preprocessing.
    arr = os.listdir()

    #Load .continuous files
    recording = se.OpenEphysRecordingExtractor(recording_folder)
    channel_ids = recording.get_channel_ids()
    fs = recording.get_sampling_frequency()
    num_chan = recording.get_num_channels()

    print('Channel ids:', channel_ids)
    print('Sampling frequency:', fs)
    print('Number of channels:', num_chan)

    #!cat tetrode9.prb #Asks for prb file
    # os.system('cat /home/adrian/Documents/SpikeSorting/Adrian_test_data/Irene_data/test_without_zero_main_channels/Tetrode_9_CH/tetrode9.prb')
    recording_prb = recording.load_probe_file(os.getcwd() + '/tetrode.prb')

    print('Channels after loading the probe file:',
          recording_prb.get_channel_ids())
    print('Channel groups after loading the probe file:',
          recording_prb.get_channel_groups())

    #For testing only: Reduce recording.
    #recording_prb = se.SubRecordingExtractor(recording_prb, start_frame=100*fs, end_frame=420*fs)

    #Bandpass filter
    recording_cmr = st.preprocessing.bandpass_filter(recording_prb,
                                                     freq_min=300,
                                                     freq_max=6000)
    recording_cache = se.CacheRecordingExtractor(recording_cmr)

    print(recording_cache.get_channel_ids())
    print(recording_cache.get_channel_groups())
    print(recording_cache.get_num_frames() /
          recording_cache.get_sampling_frequency())

    #View installed sorters
    #ss.installed_sorters()
    #mylist = [f for f in glob.glob("*.txt")]

    #%% Run all channels. There are only single tetrode channels anyway.

    #Create a sub-recording to avoid saving the whole recording. Required by NWB to allow saving sorter data.
    recording_sub = se.SubRecordingExtractor(recording_cache,
                                             start_frame=200 * fs,
                                             end_frame=320 * fs)
    # Sorters2CompareLabel=['KL','IC','Waveclus','HS','MS4','SC','TRI'];
    Sorters2CompareLabel = ['KL', 'IC', 'HS', 'MS4', 'SC', 'TRI']
    subfolders = [f.name for f in os.scandir(recording_folder) if f.is_dir()]

    for num in range(len(Sorters2CompareLabel)):

        i = Sorters2CompareLabel[num]
        print(i)
        if 'phy_' + i in subfolders:
            print('Sorter already used for curation. Skipping')
            continue
        else:

            if 'KL' in i:
                #Klusta
                if 'sorting_KL_all.nwb' in arr:
                    print('Loading Klusta')
                    sorting_KL_all = se.NwbSortingExtractor(
                        'sorting_KL_all.nwb')

                else:
                    t = time.time()
                    sorting_KL_all = ss.run_klusta(
                        recording_cache,
                        output_folder='results_all_klusta',
                        delete_output_folder=True)
                    print('Found', len(sorting_KL_all.get_unit_ids()), 'units')
                    print(time.time() - t)
                    #Save Klusta
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_KL_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_KL_all, 'sorting_KL_all.nwb')
                sorter = sorting_KL_all

            if 'IC' in i:
                #Ironclust
                if 'sorting_IC_all.nwb' in arr:
                    print('Loading Ironclust')
                    sorting_IC_all = se.NwbSortingExtractor(
                        'sorting_IC_all.nwb')

                else:
                    t = time.time()
                    sorting_IC_all = ss.run_ironclust(
                        recording_cache,
                        output_folder='results_all_ic',
                        delete_output_folder=True,
                        filter=False)
                    print('Found', len(sorting_IC_all.get_unit_ids()), 'units')
                    print(time.time() - t)
                    #Save IC
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_IC_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_IC_all, 'sorting_IC_all.nwb')
                sorter = sorting_IC_all

            # if 'Waveclus' in i:
            #     #Waveclust
            #     if 'sorting_waveclus_all.nwb' in arr:
            #         print('Loading waveclus')
            #         sorting_waveclus_all=se.NwbSortingExtractor('sorting_waveclus_all.nwb');

            #     else:
            #         t = time.time()
            #         sorting_waveclus_all = ss.run_waveclus(recording_cache, output_folder='results_all_waveclus',delete_output_folder=True)
            #         print('Found', len(sorting_waveclus_all.get_unit_ids()), 'units')
            #         print(time.time() - t)
            #         #Save waveclus
            #         se.NwbRecordingExtractor.write_recording(recording_sub, 'sorting_waveclus_all.nwb')
            #         se.NwbSortingExtractor.write_sorting(sorting_waveclus_all, 'sorting_waveclus_all.nwb')
            #     sorter=sorting_waveclus_all;

            if 'HS' in i:
                #Herdingspikes
                if 'sorting_herdingspikes_all.nwb' in arr:
                    print('Loading herdingspikes')
                    sorting_herdingspikes_all = se.NwbSortingExtractor(
                        'sorting_herdingspikes_all.nwb')
                    sorter = sorting_herdingspikes_all

                else:
                    t = time.time()
                    try:
                        sorting_herdingspikes_all = ss.run_herdingspikes(
                            recording_cache,
                            output_folder='results_all_herdingspikes',
                            delete_output_folder=True)
                        print('Found',
                              len(sorting_herdingspikes_all.get_unit_ids()),
                              'units')
                        print(time.time() - t)
                        #Save herdingspikes
                        se.NwbRecordingExtractor.write_recording(
                            recording_sub, 'sorting_herdingspikes_all.nwb')
                        try:
                            se.NwbSortingExtractor.write_sorting(
                                sorting_herdingspikes_all,
                                'sorting_herdingspikes_all.nwb')
                        except TypeError:
                            print(
                                "No units detected.  Can't save HerdingSpikes")
                            os.remove("sorting_herdingspikes_all.nwb")
                        sorter = sorting_herdingspikes_all
                    except:
                        print('Herdingspikes has failed')
                        sorter = []

            if 'MS4' in i:
                #Mountainsort4
                if 'sorting_mountainsort4_all.nwb' in arr:
                    print('Loading mountainsort4')
                    sorting_mountainsort4_all = se.NwbSortingExtractor(
                        'sorting_mountainsort4_all.nwb')

                else:
                    t = time.time()
                    sorting_mountainsort4_all = ss.run_mountainsort4(
                        recording_cache,
                        output_folder='results_all_mountainsort4',
                        delete_output_folder=True,
                        filter=False)
                    print('Found',
                          len(sorting_mountainsort4_all.get_unit_ids()),
                          'units')
                    print(time.time() - t)
                    #Save mountainsort4
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_mountainsort4_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_mountainsort4_all,
                        'sorting_mountainsort4_all.nwb')
                sorter = sorting_mountainsort4_all

            if 'SC' in i:
                #Spykingcircus
                if 'sorting_spykingcircus_all.nwb' in arr:
                    print('Loading spykingcircus')
                    sorting_spykingcircus_all = se.NwbSortingExtractor(
                        'sorting_spykingcircus_all.nwb', filter=False)

                else:
                    t = time.time()
                    sorting_spykingcircus_all = ss.run_spykingcircus(
                        recording_cache,
                        output_folder='results_all_spykingcircus',
                        delete_output_folder=True)
                    print('Found',
                          len(sorting_spykingcircus_all.get_unit_ids()),
                          'units')
                    print(time.time() - t)
                    #Save sorting_spykingcircus
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_spykingcircus_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_spykingcircus_all,
                        'sorting_spykingcircus_all.nwb')
                sorter = sorting_spykingcircus_all

            if 'TRI' in i:
                #Tridesclous
                if 'sorting_tridesclous_all.nwb' in arr:
                    print('Loading tridesclous')
                    try:
                        sorting_tridesclous_all = se.NwbSortingExtractor(
                            'sorting_tridesclous_all.nwb')
                    except AttributeError:
                        print(
                            "No units detected.  Can't load Tridesclous so will run it."
                        )
                        t = time.time()
                        sorting_tridesclous_all = ss.run_tridesclous(
                            recording_cache,
                            output_folder='results_all_tridesclous',
                            delete_output_folder=True)
                        print('Found',
                              len(sorting_tridesclous_all.get_unit_ids()),
                              'units')
                        print(time.time() - t)
                        os.remove("sorting_tridesclous_all.nwb")
                        #Save sorting_tridesclous
                        se.NwbRecordingExtractor.write_recording(
                            recording_sub, 'sorting_tridesclous_all.nwb')
                        se.NwbSortingExtractor.write_sorting(
                            sorting_tridesclous_all,
                            'sorting_tridesclous_all.nwb')

                else:
                    t = time.time()
                    sorting_tridesclous_all = ss.run_tridesclous(
                        recording_cache,
                        output_folder='results_all_tridesclous',
                        delete_output_folder=True)
                    print('Found', len(sorting_tridesclous_all.get_unit_ids()),
                          'units')
                    print(time.time() - t)
                    #Save sorting_tridesclous
                    se.NwbRecordingExtractor.write_recording(
                        recording_sub, 'sorting_tridesclous_all.nwb')
                    se.NwbSortingExtractor.write_sorting(
                        sorting_tridesclous_all, 'sorting_tridesclous_all.nwb')
                sorter = sorting_tridesclous_all

            #Check if sorter failed
            if not sorter:
                continue

            st.postprocessing.export_to_phy(recording_cache,
                                            sorter,
                                            output_folder='phy_' + i,
                                            grouping_property='group',
                                            verbose=True,
                                            recompute_info=True)

            #Open phy interface
            os.system('phy template-gui phy_' + i + '/params.py')

            #Remove detections curated as noise.
            sorting_phy_curated = se.PhySortingExtractor(
                'phy_' + i + '/', exclude_cluster_groups=['noise'])

            #Print waveforms of units
            w_wf = sw.plot_unit_templates(sorting=sorting_phy_curated,
                                          recording=recording_cache)
            plt.savefig('manual_' + i + '_unit_templates.pdf',
                        bbox_inches='tight')
            plt.savefig('manual_' + i + '_unit_templates.png',
                        bbox_inches='tight')
            plt.close()

            #Compute agreement matrix wrt consensus-based sorting.
            sorting_phy_consensus = se.PhySortingExtractor(
                'phy_AGR/', exclude_cluster_groups=['noise'])
            cmp = sc.compare_sorter_to_ground_truth(sorting_phy_curated,
                                                    sorting_phy_consensus)
            sw.plot_agreement_matrix(cmp)
            plt.savefig('agreement_matrix_' + i + '.pdf', bbox_inches='tight')
            plt.savefig('agreement_matrix_' + i + '.png', bbox_inches='tight')
            plt.close()

            #Access unit ID and firing rate.
            os.chdir('phy_' + i)
            spike_times = np.load('spike_times.npy')
            spike_clusters = np.load('spike_clusters.npy')
            #Find units curated as 'noise'
            noise_id = []
            with open("cluster_group.tsv") as fd:
                rd = csv.reader(fd, delimiter="\t", quotechar='"')
                for row in rd:
                    if row[1] == 'noise':
                        noise_id.append(int(row[0]))
            #Create a list with the unit IDs and remove those labeled as 'noise'
            some_list = np.unique(spike_clusters)
            some_list = some_list.tolist()
            for x in noise_id:
                print(x)
                some_list.remove(x)

            #Bin data in bins of 25ms
            #45 minutes
            bins = np.arange(start=0, stop=45 * 60 * fs + 1, step=.025 * fs)
            NData = np.zeros([
                np.unique(spike_clusters).shape[0] - len(noise_id),
                bins.shape[0] - 1
            ])

            cont = 0
            for x in some_list:
                #print(x)
                ind = (spike_clusters == x)
                fi = spike_times[ind]
                inds = np.histogram(fi, bins=bins)
                inds1 = inds[0]
                NData[cont, :] = inds1
                cont = cont + 1

            #Save activation matrix
            os.chdir("..")
            a = os.path.split(os.getcwd())[1]
            np.save('actmat_manual_' + i + '_' + a.split('_')[1], NData)
            np.save('unit_id_manual_' + i + '_' + a.split('_')[1], some_list)

    #End of for loop
    print("Stop the code here")
Example #8
import matplotlib.pyplot as plt
import seaborn as sns

import spikeinterface.extractors as se
import spikeinterface.sorters as sorters
import spikeinterface.comparison as sc

##############################################################################

recording, sorting_true = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)

sorting_MS4 = sorters.run_mountainsort4(recording)

##############################################################################

cmp_gt_MS4 = sc.compare_sorter_to_ground_truth(sorting_true, sorting_MS4, exhaustive_gt=True)

##############################################################################
# This function first matches the ground-truth and spike sorted units, and
# then it computes several performance metrics.
# 
# Once the spike trains are matched, each spike is labelled as:
#
# - true positive (tp): spike found both in :code:`gt_sorting` and :code:`tested_sorting`
# - false negative (fn): spike found in :code:`gt_sorting`, but not in :code:`tested_sorting`
# - false positive (fp): spike found in :code:`tested_sorting`, but not in :code:`gt_sorting`
# - misclassification error (cl): spike found in :code:`gt_sorting`, not in
#   :code:`tested_sorting`, but found in another matched spike train of
#   :code:`tested_sorting`, and not labelled as a true positive
#
# From the counts of these labels, performance measures such as accuracy and
# miss rate are computed.
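#
# A minimal sketch of retrieving them, with the same calls used elsewhere in
# this collection:

cmp_gt_MS4.print_performance(method='pooled_with_average')
perf = cmp_gt_MS4.get_performance(method='by_unit')
print(perf)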
Example #9
def main():
    np.set_printoptions(threshold=np.infty)

    ######################################################
    # data extract & preprocess
    ######################################################
    recording, sorting_true = load_data(globvar.h5data_path)
    extrt_dict, recording_bp = preprocessing(recording, sorting_true)
    channel_location: list = extrt_dict['channel_loc']

    ######################################################
    # threshold detection
    ######################################################
    detec = Detection(extrt_dict)
    split_list = globvar.h5filename.split('_')
    time_buff = []

    start = time.time()
    detct_rst = detec.mc_run()  # detection
    # divide-and-conquer for multi-electrode
    merged_frames, grouped_frames, grouped_frameids = detec.get_frame_groups(
        detct_rst)
    time_buff.append(time.time() - start)

    # evaluate detection results
    fn_frames, fp_frames, evalinfo_dict, according_labels = detec.evaluate(
        merged_frames)
    print('----------------detection finished----------------')
    print(
        f"precision: {evalinfo_dict['precision']}, recall: {evalinfo_dict['recall']}, F1: {evalinfo_dict['F1_score']}"
    )

    start = time.time()
    snipgroups = detec.get_grouped_snippets(grouped_frames)
    # -----remove empty array-----
    for i in np.arange(len(grouped_frames))[::-1]:
        if len(grouped_frames[i]) == 0:
            del grouped_frames[i]
            del grouped_frameids[i]
            del snipgroups[i]
            del channel_location[i]
    time_buff.append(time.time() - start)
    for i in range(len(grouped_frames)):
        print(f"group{i}, frames{grouped_frames[i].shape}")

    ######################################################
    # KNN outlier detection (NOT USED)
    ######################################################
    all_snippets = np.zeros([
        0, extrt_dict['snip_frame_before'] + extrt_dict['snip_frame_after'] + 1
    ])
    for i in range(len(snipgroups)):
        all_snippets = np.vstack((all_snippets, snipgroups[i]))
    outlier_marks = np.zeros(len(all_snippets), dtype=int)
    grouped_outlier_marks = []
    cnt_sample = 0
    for i in range(len(snipgroups)):
        n_samples = len(snipgroups[i])
        grouped_outlier_marks.append(outlier_marks[cnt_sample:cnt_sample +
                                                   n_samples])
        cnt_sample += n_samples

    ######################################################
    # PCA for each group
    ######################################################
    start = time.time()
    pca_snipgroups_rst = []
    p = Pool(min(globvar.n_multiprocess, len(snipgroups)))
    for i in range(len(snipgroups)):
        pca_snipgroups_rst.append(
            p.apply_async(decomposition,
                          args=(
                              snipgroups[i],
                              globvar.group_min_samples,
                          )))
    p.close()
    p.join()
    pca_snipgroups = []
    for each in pca_snipgroups_rst:
        pca_snipgroups.append(each.get())
    time_buff.append(time.time() - start)

    # get benign data
    pca_benign_snipgroups = []
    benign_snipgroups = []
    benign_framegroups = []
    for i in range(len(grouped_outlier_marks)):
        ids = np.where(grouped_outlier_marks[i] != -1)[0]
        benign_snipgroups.append((snipgroups[i])[ids])
        benign_framegroups.append((grouped_frames[i])[ids])
        pca_benign_snipgroups.append((pca_snipgroups[i])[ids])

    ######################################################
    # cluster
    ######################################################
    start = time.time()
    # multiprocess pool
    p = Pool(min(globvar.n_multiprocess, len(benign_snipgroups)))
    mp_cluster_rst = []
    for i_group in range(len(benign_snipgroups)):
        mp_cluster_rst.append(
            p.apply_async(cluster,
                          args=(
                              benign_snipgroups[i_group],
                              pca_benign_snipgroups[i_group],
                          )))
    p.close()
    p.join()
    benign_labelgroup_list = []  # labels of benign samples
    tempgroups_list = []  # cluster centroids
    for i_group in range(len(benign_snipgroups)):
        i_cluster_rtn = (mp_cluster_rst[i_group]).get()
        benign_labelgroup_list.append(i_cluster_rtn[0])  # get labels
        tempgroups_list.append(i_cluster_rtn[1])  # get centroids

    labelgroup_list = copy.deepcopy(grouped_outlier_marks)  # ★★★labels
    for i in range(len(labelgroup_list)):
        benign_ids = np.where(grouped_outlier_marks[i] != -1)[0]
        labelgroup_list[i][benign_ids] = benign_labelgroup_list[
            i]  # fill labels
    time_buff.append(time.time() - start)
    #
    for i in range(len(labelgroup_list)):
        statistical = Counter(labelgroup_list[i])
        print(statistical)

    ######################################################
    # postprocessing
    ######################################################
    # ----------------templates merging------------------
    start = time.time()
    # merging
    merged_centroids, merged_labelgroup_list = templates_merging(
        tempgroups_list,
        labelgroup_list,
        channel_location=channel_location,
        channel_type=(split_list[2].split("-"))[0],
        sigma=globvar.sigma,
        peek_diff=globvar.peek_diff)
    time_buff.append(time.time() - start)

    ######################################################
    # template matching
    ######################################################
    start = time.time()
    p = Pool(min(globvar.n_multiprocess, len(merged_labelgroup_list)))
    aftertm_list = []
    threshold = np.mean(extrt_dict['stdvir_channels']) * globvar.detect_th
    for i_group in range(len(merged_labelgroup_list)):
        aftertm_list.append(
            p.apply_async(template_matching,
                          args=(
                              snipgroups[i_group],
                              merged_labelgroup_list[i_group],
                              grouped_frames[i_group],
                              merged_centroids,
                              threshold,
                          )))
    p.close()
    p.join()
    recovered_label_list = []
    recovered_frame_list = []
    # recovered_frame_list = grouped_frames
    for i_group in range(len(aftertm_list)):
        # recovered_label_list.append(aftertm_list[i_group])
        tm_rtn = (aftertm_list[i_group]).get()
        recovered_label_list.append(tm_rtn[0])
        recovered_frame_list.append(tm_rtn[1])
    time_buff.append(time.time() - start)
    ######################################################
    # get performance
    ######################################################
    recovered_labels = np.array([])
    recovered_frames = np.array([])
    for i in range(len(recovered_label_list)):
        recovered_labels = np.append(recovered_labels, recovered_label_list[i])
        recovered_frames = np.append(recovered_frames, recovered_frame_list[i])

    sorting = se.NumpySortingExtractor()
    sorting.set_sampling_frequency(extrt_dict['sample_freq'])
    sorting.set_times_labels(recovered_frames, recovered_labels)
    n_units = len(sorting.get_unit_ids())

    # compare to ground-truth
    comp: sc.GroundTruthComparison = sc.compare_sorter_to_ground_truth(
        sorting_true, sorting, delta_time=0.4, match_mode='best')
    # get_performance
    comp.print_performance(method='by_unit')
    comp.print_performance()
    gt_num_units = len(sorting_true.get_unit_ids())
    tested_num_units = len(sorting.get_unit_ids())
    print(f"GT num_units: {gt_num_units}")
    print(f"tested num_units: {tested_num_units}")
    perf = comp.get_performance(method='pooled_with_average', output='pandas')
    perf['gt_num_units'] = gt_num_units
    perf['tested_num_units'] = tested_num_units
    perf['time'] = sum(time_buff)
    perf['dataset'] = globvar.h5filename
    # output to csv
    perf.to_csv("perf_logger.csv", mode='a', header=False)

    assert True  # debug
Example #10
#Mountainsort
with ka.config(fr='default_readonly'):
    #with hither.config(cache='default_readwrite'):
    with hither.config(container='default'):
        result_MS4 = sorters.mountainsort4.run(recording_path=recordingPath,
                                               sorting_out=hither.File())
#Aggregating the output of the sorters
sorting_MS4 = AutoSortingExtractor(result_MS4.outputs.sorting_out._path)
sorting_SP = AutoSortingExtractor(
    result_spyKingCircus.outputs.sorting_out._path)

#Compare each sorting to the ground truth (confusion matrix)
comp_MATLAB = sc.compare_sorter_to_ground_truth(gtOutput,
                                                sortingPipeline,
                                                sampling_frequency=sampleRate,
                                                delta_time=3,
                                                match_score=0.5,
                                                chance_score=0.1,
                                                well_detected_score=0.1,
                                                exhaustive_gt=True)
w_comp_MATLAB = sw.plot_confusion_matrix(comp_MATLAB, count_text=True)
plt.show()

comp_MS4 = sc.compare_sorter_to_ground_truth(gtOutput,
                                             sorting_MS4,
                                             sampling_frequency=sampleRate,
                                             delta_time=3,
                                             match_score=0.5,
                                             chance_score=0.1,
                                             well_detected_score=0.1,
                                             exhaustive_gt=True)
w_comp_MS4 = sw.plot_confusion_matrix(comp_MS4, count_text=True)
Example #11
import spikeinterface.sorters as ss

sorting_MS4 = ss.run_mountainsort4(recording)
sorting_KL = ss.run_klusta(recording)

##############################################################################
# Widgets using SortingComparison
# ---------------------------------
#
# We can compare the spike sorting output to the ground-truth sorting :code:`sorting_true` using the
# :code:`comparison` module. :code:`comp_MS4` and :code:`comp_KL` are :code:`SortingComparison` objects

import spikeinterface.comparison as sc

comp_MS4 = sc.compare_sorter_to_ground_truth(sorting_true, sorting_MS4)
comp_KL = sc.compare_sorter_to_ground_truth(sorting_true, sorting_KL)

##############################################################################
# plot_confusion_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_comp_MS4 = sw.plot_confusion_matrix(comp_MS4, count_text=False)
w_comp_KL = sw.plot_confusion_matrix(comp_KL, count_text=False)

##############################################################################
# plot_agreement_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~

w_agr_MS4 = sw.plot_agreement_matrix(comp_MS4, count_text=False)
Example #12
st = sortingPipeline.get_unit_spike_train(unit_id=1)
print('Num. events for unit 1 = {}'.format(len(st)))
st1 = sortingPipeline.get_unit_spike_train(unit_id=1, start_frame=0, end_frame=int(sampleRate))
print('Num. events for first second of unit 1 = {}'.format(len(st1)))

#We are also going to be setting up the unit spike features associated with each waveform
ID1_features = A_snippets_reference[cluster == 1, :]
sortingPipeline.set_unit_spike_features(unit_id=1,
                                        feature_name='unitId1',
                                        value=ID1_features)
print("Spike feature names: " +
      str(sortingPipeline.get_unit_spike_feature_names(unit_id=1)))

#Comparing sorter with ground truth
cmp_gt_SP = sc.compare_sorter_to_ground_truth(gtOutput,
                                              sortingPipeline,
                                              exhaustive_gt=True)
sw.plot_agreement_matrix(cmp_gt_SP, ordered=True)

#Some more comparison metrics
perf = cmp_gt_SP.get_performance()
#print('well_detected', cmp_gt_SP.get_well_detected_units(well_detected_score=0))
print(perf)
#We will try to get the SNR and firing rates

#firing_rates = st.validation.compute_firing_rates(sortingPipeline, duration_in_frames=recordingInput.get_num_frames())

#Raster plots

w_rs_gt = sw.plot_rasters(sortingPipeline, sampling_frequency=sampleRate)
Example #13
"""
Created on Mon Sep 30 13:04:03 2019

@author: Jasper Wouters
"""

import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc

# full filenames to both the hybrid recording and ground truth
recording_fn = '/path/to/recording.bin'
gt_fn = '/path/to/hybrid_GT.csv'

# create extractor object for both the recording data and ground truth labels
recording_ex = se.SHYBRIDRecordingExtractor(recording_fn)
sorting_ex = se.SHYBRIDSortingExtractor(gt_fn)

# perform spike sorting (e.g., using spyking circus)
sc_params = ss.SpykingcircusSorter.default_params()
sorting_sc = ss.run_spykingcircus(recording=recording_ex,
                                  **sc_params,
                                  output_folder='tmp_sc')

# calculate spike sorting performance
# note: exhaustive_gt is set to False, because the hybrid approach generates
#       partial ground truth only
comparison = sc.compare_sorter_to_ground_truth(sorting_ex,
                                               sorting_sc,
                                               exhaustive_gt=False)
print(comparison.get_performance())
Example #14
def _do_recovery_loop(task_args):

    key, well_detected_score, isi_thr, fr_thr, sample_window_ms, \
    percentage_spikes, balance_spikes, detect_threshold, method, skew_thr, n_jobs, we_params, compare, \
    output_folder, job_kwargs = task_args
    recording = load_extractor(output_folder / 'back_recording' / key[1] /
                               key[0])
    if compare is True:
        gt = load_extractor(output_folder / 'back_recording' / key[1] /
                            (key[0] + '_gt'))
    else:
        gt = None
    sorting = load_extractor(output_folder / 'back_recording' / key[0] /
                             (key[1] + '_pre'))
    we = extract_waveforms(
        recording,
        sorting,
        folder=output_folder / 'waveforms' / key[0] / key[1],
        load_if_exists=we_params['load_if_exists'],
        ms_before=we_params['ms_before'],
        ms_after=we_params['ms_after'],
        max_spikes_per_unit=we_params['max_spikes_per_unit'],
        return_scaled=we_params['return_scaled'],
        dtype=we_params['dtype'],
        overwrite=True,
        **job_kwargs)
    if gt is not None:
        comparison = sc.compare_sorter_to_ground_truth(tested_sorting=sorting,
                                                       gt_sorting=gt)
        selected_units = comparison.get_well_detected_units(
            well_detected_score)
        print(key[1][:-1])
        if key[1] == 'hdsort':
            selected_units = [unit - 1000 for unit in selected_units]
    else:
        isi_violation = st.compute_isi_violations(we)[0]
        good_isi = np.argwhere(
            np.array(list(isi_violation.values())) < isi_thr)[:, 0]

        firing_rate = st.compute_firing_rate(we)
        good_fr_idx_up = np.argwhere(
            np.array(list(firing_rate.values())) < fr_thr[1])[:, 0]
        good_fr_idx_down = np.argwhere(
            np.array(list(firing_rate.values())) > fr_thr[0])[:, 0]

        selected_units = [
            unit for unit in range(sorting.get_num_units())
            if unit in good_fr_idx_up and unit in good_fr_idx_down
            and unit in good_isi
        ]

    templates = we.get_all_templates()
    templates_dict = {
        str(unit): templates[unit - 1]
        for unit in selected_units
    }

    recording_subtracted = subtract_templates(recording, sorting,
                                              templates_dict, we.nbefore,
                                              selected_units)
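    # recording_subtracted is the recording with the selected units' templates
    # removed; the steps below mask its traces and decompose them with ICA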

    sorter = SpyICASorter(recording_subtracted)
    sorter.mask_traces(sample_window_ms=sample_window_ms,
                       percent_spikes=percentage_spikes,
                       balance_spikes_on_channel=balance_spikes,
                       detect_threshold=detect_threshold,
                       method=method,
                       **job_kwargs)
    sorter.compute_ica(n_comp='all')
    cleaning_result = clean_correlated_sources(
        recording,
        sorter.W_ica,
        skew_thresh=skew_thr,
        n_jobs=n_jobs,
        chunk_size=recording.get_num_samples(0) // n_jobs,
        **job_kwargs)
    sorter.A_ica[cleaning_result[1]] = -sorter.A_ica[cleaning_result[1]]
    sorter.W_ica[cleaning_result[1]] = -sorter.W_ica[cleaning_result[1]]
    sorter.source_idx = cleaning_result[0]
    sorter.cleaned_A_ica = sorter.A_ica[cleaning_result[0]]
    sorter.cleaned_W_ica = sorter.W_ica[cleaning_result[0]]

    ica_recording = st.preprocessing.lin_map(recording_subtracted,
                                             sorter.cleaned_W_ica)
    recording_back = st.preprocessing.lin_map(ica_recording,
                                              sorter.cleaned_A_ica.T)
    recording_back.save_to_folder(folder=output_folder / 'back_recording' /
                                  key[0] / key[1])
Example #15
  #Open phy interface
  os.system('phy template-gui phy_' + i + '/params.py')

  #Remove detections curated as noise.
  sorting_phy_curated = se.PhySortingExtractor('phy_' + i + '/', exclude_cluster_groups=['noise'])

  #Print waveforms of units
  w_wf = sw.plot_unit_templates(sorting=sorting_phy_curated, recording=recording_cache)
  plt.savefig('manual_' + i + '_unit_templates.pdf', bbox_inches='tight')
  plt.savefig('manual_' + i + '_unit_templates.png', bbox_inches='tight')
  plt.close()

  #Compute agreement matrix wrt consensus-based sorting.
  sorting_phy_consensus = se.PhySortingExtractor('phy_AGR/', exclude_cluster_groups=['noise'])
  cmp = sc.compare_sorter_to_ground_truth(sorting_phy_curated, sorting_phy_consensus)
  sw.plot_agreement_matrix(cmp)
  plt.savefig('agreement_matrix_' + i + '.pdf', bbox_inches='tight')
  plt.savefig('agreement_matrix_' + i + '.png', bbox_inches='tight')
  plt.close()

  #Access unit ID and firing rate.
  os.chdir('phy_' + i)
  spike_times = np.load('spike_times.npy')
  spike_clusters = np.load('spike_clusters.npy')
  #Find units curated as 'noise'
  noise_id = []
  with open("cluster_group.tsv") as fd:
      rd = csv.reader(fd, delimiter="\t", quotechar='"')
      for row in rd:
Example #16
    st8 = sorting_true.get_unit_spike_train(8)
    st80 = np.sort(np.random.choice(st8, size=int(st8.size*0.65), replace=False))
    st81 = np.sort(np.random.choice(st8, size=int(st8.size*0.6), replace=False))
    st82 = np.sort(np.random.choice(st8, size=int(st8.size*0.55), replace=False))
    units_err[80] = st80
    units_err[81] = st81
    units_err[82] = st82
    
    # unit 9 is missing
    
    # there are some units that do not exist: 15, 16 and 17
    nframes = rec.get_num_frames(segment_index=0)
    for u in [15,16,17]:
        st = np.sort(np.random.randint(0, high=nframes, size=35))
        units_err[u] = st
    sorting_err = se.NumpySorting.from_dict(units_err, sampling_frequency)
    
    
    return sorting_true, sorting_err
    


    
if __name__ == '__main__':
    # just for check
    sorting_true, sorting_err = generate_erroneous_sorting()
    comp = sc.compare_sorter_to_ground_truth(sorting_true, sorting_err, exhaustive_gt=True)
    sw.plot_agreement_matrix(comp, ordered=True)
    plt.show()

Example #17
# Quality metrics can also be used to automatically curate the spike sorting output. For example, you can select
# sorted units with a SNR above a certain threshold:

sorting_curated_snr = st.curation.threshold_snr(sorting_KL,
                                                recording,
                                                threshold=5)
snrs_above = st.validation.compute_snrs(sorting_curated_snr, recording_cmr)

print('Curated SNR', snrs_above)

##############################################################################
# The final part of this tutorial deals with comparing spike sorting outputs.
# We can either (1) compare the spike sorting results with the ground-truth sorting :code:`sorting_true`, (2) compare
# the output of two sorters (Klusta and Mountainsort4), or (3) compare the output of multiple sorters:

comp_gt_KL = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true,
                                               tested_sorting=sorting_KL)
comp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4)
comp_multi = sc.compare_multiple_sorters(
    sorting_list=[sorting_MS4, sorting_KL], name_list=['klusta', 'ms4'])

##############################################################################
# When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion
# matrix

comp_gt_KL.get_performance()
w_conf = sw.plot_confusion_matrix(comp_gt_KL)

##############################################################################
# When comparing two sorters (2), we can see the matching of units between sorters. For example, this is how to extract
# the unit ids of Mountainsort4 (sorting2) mapped to the units of Klusta (sorting1). Units which are not mapped have -1
# as unit id (see the sketch below).
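
##############################################################################
# A minimal sketch (assuming the two-sorter comparison exposes the same
# :code:`hungarian_match_12` mapping used by the ground-truth comparisons above):

print('Klusta units:', sorting_KL.get_unit_ids())
print('Matched Mountainsort4 units:', comp_KL_MS4.hungarian_match_12)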
Example #18
    def run(self):

        task_args_list = []
        for key in self._sortings_pre.keys():
            # recording_dict = self._recordings[key[0]].to_dict()
            # sorting_dict = self._sortings_pre[key].to_dict()
            # gt_dict = self._gt[key[0]].to_dict() if self._gt is not None else None
            # comparison = sc.compare_sorter_to_ground_truth(tested_sorting=self._sortings_pre[key], gt_sorting=self._gt[key[0]])
            # self._comparisons[key] = comparison
            # task_args_list.append((recording_dict, gt_dict, sorting_dict, key,
            #                        self._params_dict['wd_score'], self._params_dict['isi_thr'],
            #                        self._params_dict['fr_thr'], self._params_dict['sample_window_ms'],
            #                        self._params_dict['percentage_spikes'], self._params_dict['balance_spikes'],
            #                        self._params_dict['detect_threshold'], self._params_dict['method'],
            #                        self._params_dict['skew_thr'], self._params_dict['n_jobs'], self._we_params,
            #                        comparison, self._output_folder, self._params_dict['job_kwargs']))
            self._recordings[key[0]].save_to_folder(
                folder=self._output_folder / 'back_recording' / key[1] /
                key[0])
            self._sortings_pre[key].save_to_folder(folder=self._output_folder /
                                                   'back_recording' / key[0] /
                                                   (key[1] + '_pre'))
            self._gt[key[0]].save_to_folder(folder=self._output_folder /
                                            'back_recording' / key[1] /
                                            (key[0] + '_gt'))
            task_args_list.append(
                (key, self._params_dict['wd_score'],
                 self._params_dict['isi_thr'], self._params_dict['fr_thr'],
                 self._params_dict['sample_window_ms'],
                 self._params_dict['percentage_spikes'],
                 self._params_dict['balance_spikes'],
                 self._params_dict['detect_threshold'],
                 self._params_dict['method'], self._params_dict['skew_thr'],
                 self._params_dict['n_jobs'], self._we_params, self._compare,
                 self._output_folder, self._params_dict['job_kwargs']))

        if self._params_dict['parallel']:
            # raise NotImplementedError()
            from joblib import Parallel, delayed
            Parallel(n_jobs=self._params_dict['n_jobs'],
                     backend='loky')(delayed(_do_recovery_loop)(task_args)
                                     for task_args in task_args_list)
        else:
            for task_args in task_args_list:
                _do_recovery_loop(task_args)

        for key in self._sortings_pre.keys():
            if key[1] in self._recordings_backprojected.keys():
                self._recordings_backprojected[key[1]].append(
                    load_extractor(self._output_folder / 'back_recording' /
                                   key[0] / key[1]))
            else:
                self._recordings_backprojected[key[1]] = \
                    [load_extractor(self._output_folder / 'back_recording' / key[0] / key[1])]

        for sorter in self._recordings_backprojected.keys():
            self._sortings_post[sorter] = ss.run_sorters(
                sorter,
                self._recordings_backprojected[sorter],
                working_folder=self._output_folder / 'sortings_post' / sorter,
                sorter_params=self._sorters_params['sorters_params'],
                mode_if_folder_exists='overwrite',
                engine=self._sorters_params['engine'],
                engine_kwargs=self._sorters_params['engine_kwargs'],
                verbose=self._sorters_params['verbose'],
                with_output=self._sorters_params['with_output'])
            for key in self._sortings_post[sorter].keys():
                self._aggregated_sortings[key] = aggregate_units([
                    self._sortings_post[sorter][key], self._sortings_pre[key]
                ])
                self._comparisons[key] = sc.compare_sorter_to_ground_truth(
                    tested_sorting=self._sortings_pre[key],
                    gt_sorting=self._gt[key[0]])