Ejemplo n.º 1
0
def test_run_sorters_with_list():
    """Run tridesclous with the loop engine on a list of two toy recordings."""
    # Importing the sorter package explicitly makes the test fail loudly on
    # CI (GitHub) when the import itself is broken.
    import tridesclous

    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_list'

    # start from a clean slate: drop leftovers of a previous run
    for stale_folder in (cache_folder, working_folder):
        if os.path.exists(stale_folder):
            shutil.rmtree(stale_folder)

    rec_tetrode, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
    rec_octotrode, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)

    # saving into the global tmp folder makes the recordings dumpable
    set_global_tmp_folder(cache_folder)
    recordings = [rec_tetrode.save(name='rec0'), rec_octotrode.save(name='rec1')]

    run_sorters(['tridesclous'], recordings, working_folder,
                engine='loop', verbose=False, with_output=False)
Ejemplo n.º 2
0
def setup_module():
    """Create the on-disk toy datasets shared by the tests of this module.

    Two datasets are built, first a 2-segment one and then a 1-segment one
    with 12 channels. For each, the recording goes to ``toy_rec_<suffix>``,
    the sorting to ``toy_sorting_<suffix>`` and the extracted waveforms to
    ``toy_waveforms_<suffix>``.
    """
    stale_folders = [f'{kind}_{suffix}'
                     for suffix in ('1seg', '2seg')
                     for kind in ('toy_rec', 'toy_sorting', 'toy_waveforms')]
    for folder in stale_folders:
        if Path(folder).is_dir():
            shutil.rmtree(folder)

    datasets = (
        ('2seg', dict(num_segments=2, num_units=10)),
        ('1seg', dict(num_segments=1, num_units=10, num_channels=12)),
    )
    for suffix, toy_kwargs in datasets:
        recording, sorting = toy_example(**toy_kwargs)
        recording = recording.save(folder=f'toy_rec_{suffix}')
        sorting = sorting.save(folder=f'toy_sorting_{suffix}')
        extract_waveforms(recording, sorting, f'toy_waveforms_{suffix}',
                          ms_before=3., ms_after=4., max_spikes_per_unit=500,
                          n_jobs=1, chunk_size=30000)
Ejemplo n.º 3
0
def _setup_comparison_study():
    """Create a GroundTruthStudy from two toy recordings (4 and 32 channels).

    ``study_folder`` is not defined in this function; it is looked up in the
    module's global scope.
    """
    gt_dict = {}
    for study_key, num_channels in (('toy_tetrode', 4), ('toy_probe32', 32)):
        recording, gt_sorting = toy_example(num_channels=num_channels, duration=30,
                                            seed=0, num_segments=1)
        gt_dict[study_key] = (recording, gt_sorting)
    GroundTruthStudy.create(study_folder, gt_dict)
Ejemplo n.º 4
0
def test_run_sorters_with_dict():
    """Run two sorters on a dict of recordings, then re-run in 'keep' mode."""
    # Importing the sorter packages explicitly makes the test fail loudly on
    # CI (GitHub) when one of the imports is broken.
    import tridesclous
    import circus

    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dict'

    # start from a clean slate: drop leftovers of a previous run
    for stale_folder in (cache_folder, working_folder):
        if os.path.exists(stale_folder):
            shutil.rmtree(stale_folder)

    rec_tetrode, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
    rec_octotrode, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)

    # saving into the global tmp folder makes the recordings dumpable
    set_global_tmp_folder(cache_folder)
    recording_dict = {
        'toy_tetrode': rec_tetrode.save(name='rec0'),
        'toy_octotrode': rec_octotrode.save(name='rec1'),
    }

    sorter_list = ['tridesclous', 'spykingcircus']
    sorter_params = {sorter_name: dict(detect_threshold=5.6) for sorter_name in sorter_list}

    # first pass: simple loop, output folders must not exist yet
    start = time.perf_counter()
    results = run_sorters(sorter_list, recording_dict, working_folder,
                          engine='loop', sorter_params=sorter_params,
                          with_output=True, mode_if_folder_exists='raise')
    print(time.perf_counter() - start)
    print(results)

    # second pass in 'keep' mode after deleting one sorter output folder
    shutil.rmtree(os.path.join(working_folder, 'toy_tetrode', 'tridesclous'))
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='loop', sorter_params=sorter_params,
                with_output=False, mode_if_folder_exists='keep')
Ejemplo n.º 5
0
def test_run_sorters_joblib():
    """Run tridesclous on eight toy recordings with the joblib engine (4 jobs)."""
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_joblib'
    for stale_folder in (cache_folder, working_folder):
        if os.path.exists(stale_folder):
            shutil.rmtree(stale_folder)

    # recordings saved below go to cache_folder, which makes them dumpable
    set_global_tmp_folder(cache_folder)

    recording_dict = {}
    for idx in range(8):
        key = f'rec_{idx}'
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        recording_dict[key] = rec.save(name=key)

    start = time.perf_counter()
    run_sorters(['tridesclous'], recording_dict, working_folder,
                engine='joblib', engine_kwargs={'n_jobs': 4},
                with_output=False, mode_if_folder_exists='keep')
    elapsed = time.perf_counter() - start
    print(elapsed)
Ejemplo n.º 6
0
 def setUp(self):
     """Build a 4-channel, single-segment toy recording (duration=60) saved as binary."""
     toy_recording, _ = toy_example(num_channels=4, duration=60, seed=0,
                                    num_segments=1)
     self.recording = toy_recording.save(verbose=False, format='binary')
     print(self.recording)
Ejemplo n.º 7
0
def generate_erroneous_sorting():
    """Build a ground-truth sorting and a deliberately erroneous copy of it.

    The ground truth comes from ``se.toy_example`` (10 units, 1 segment) with
    unit ids remapped to 1..10. The erroneous sorting contains:

    * units 1, 2: exact copies of the true spike trains
    * units 3, 4, 10: random subsets (80 %, 75 %, 30 % of the spikes)
    * unit 56: over-merge of units 5 and 6 (70 % of their union)
    * units 70, 71: over-split of unit 7 (even spikes / 90 % of odd spikes)
    * units 80, 81, 82: three redundant subsets of unit 8 (65 %, 60 %, 55 %)
    * unit 9: missing
    * units 15, 16, 17: 35 random frames each, matching no true unit

    Returns
    -------
    sorting_true, sorting_err
    """
    rec, sorting_true = se.toy_example(num_channels=4, num_units=10, duration=10, seed=10, num_segments=1)

    # artificially remap unit ids to one-based (1..10)
    sorting_true = sorting_true.select_units(unit_ids=None,
                                             renamed_unit_ids=np.arange(10, dtype='int64') + 1)

    sampling_frequency = sorting_true.get_sampling_frequency()

    def _random_subset(spike_train, fraction):
        # sorted random subset (drawn without replacement) holding `fraction` of the spikes
        size = int(spike_train.size * fraction)
        return np.sort(np.random.choice(spike_train, size=size, replace=False))

    units_err = {}

    # sorting_true has 10 units; fix the global seed so the errors are reproducible
    np.random.seed(0)

    # units 1 and 2 are perfect
    for u in [1, 2]:
        units_err[u] = sorting_true.get_unit_spike_train(u)

    # units 3, 4 (medium) and 10 (low) have medium to low agreement
    for u, score in [(3, 0.8), (4, 0.75), (10, 0.3)]:
        units_err[u] = _random_subset(sorting_true.get_unit_spike_train(u), score)

    # units 5 and 6 are over-merged into unit 56
    st5 = sorting_true.get_unit_spike_train(5)
    st6 = sorting_true.get_unit_spike_train(6)
    units_err[56] = _random_subset(np.unique(np.concatenate([st5, st6])), 0.7)

    # unit 7 is over-split in 2 parts (70 and 71)
    st7 = sorting_true.get_unit_spike_train(7)
    units_err[70] = st7[::2]
    units_err[71] = _random_subset(st7[1::2], 0.9)

    # unit 8 is redundant 3 times (80, 81 and 82)
    st8 = sorting_true.get_unit_spike_train(8)
    units_err[80] = _random_subset(st8, 0.65)
    units_err[81] = _random_subset(st8, 0.6)
    # Previously stored under key 81, which overwrote the second copy and left
    # only two redundant units.
    units_err[82] = _random_subset(st8, 0.55)

    # unit 9 is missing

    # units 15, 16 and 17 do not exist in the ground truth
    nframes = rec.get_num_frames(segment_index=0)
    for u in [15, 16, 17]:
        units_err[u] = np.sort(np.random.randint(0, high=nframes, size=35))

    sorting_err = se.NumpySorting.from_dict(units_err, sampling_frequency)

    return sorting_true, sorting_err
Ejemplo n.º 8
0
def test_run_sorters_dask():
    """Run tridesclous on eight toy recordings through dask on a SLURM cluster.

    Requires ``dask.distributed``, ``dask_jobqueue`` and access to a SLURM
    queue. The Python interpreter passed to the SLURM cluster defaults to the
    original author's virtualenv. It can be overridden with the
    ``SPIKEINTERFACE_DASK_PYTHON`` environment variable.
    """
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dask'
    if os.path.exists(cache_folder):
        shutil.rmtree(cache_folder)
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)

    # Save recordings into cache_folder, as the other run_sorters tests do.
    # Without this call the folder wiped above was never used.
    set_global_tmp_folder(cache_folder)

    # create recording
    recording_dict = {}
    for i in range(8):
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        # make dumpable
        rec = rec.save(name=f'rec_{i}')
        recording_dict[f'rec_{i}'] = rec

    sorter_list = ['tridesclous', ]

    # create a dask Client for a slurm queue
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster

    # The default interpreter path is user-specific, hence the override.
    python = os.environ.get('SPIKEINTERFACE_DASK_PYTHON',
                            '/home/samuel.garcia/.virtualenvs/py36/bin/python3.6')
    cluster = SLURMCluster(processes=1, cores=1, memory="12GB", python=python, walltime='12:00:00', )
    try:
        cluster.scale(5)
        client = Client(cluster)
        try:
            # dask
            t0 = time.perf_counter()
            run_sorters(sorter_list, recording_dict, working_folder,
                        engine='dask', engine_kwargs={'client': client},
                        with_output=False,
                        mode_if_folder_exists='keep')
            t1 = time.perf_counter()
            print(t1 - t0)
        finally:
            # close the client even when sorting fails
            client.close()
    finally:
        # shut the cluster down even when the client or sorting fails
        cluster.close()
Ejemplo n.º 9
0
def test_export_to_phy_by_property():
    """Export to phy with 'by_property' sparsity, with and without a removed channel."""
    num_units = 4
    recording, sorting = se.toy_example(num_channels=8, duration=10,
                                        num_units=num_units, num_segments=1)
    # two channel groups of four channels, two units per group
    recording.set_channel_groups([0] * 4 + [1] * 4)
    sorting.set_property("group", [0] * 2 + [1] * 2)

    waveform_folder, waveform_folder_rm = Path('waveforms'), Path('waveforms_rm')
    output_folder, output_folder_rm = Path('phy_output'), Path('phy_output_rm')
    rec_folder, sort_folder = Path("rec"), Path("sort")

    for folder in (waveform_folder, waveform_folder_rm, output_folder,
                   output_folder_rm, rec_folder, sort_folder):
        if folder.is_dir():
            shutil.rmtree(folder)

    recording = recording.save(folder=rec_folder)
    sorting = sorting.save(folder=sort_folder)

    def _export_and_load_template_inds(wf_extractor, out_folder):
        # export with the shared settings, then load the exported template_ind.npy
        export_to_phy(wf_extractor, out_folder,
                      compute_pc_features=True, compute_amplitudes=True,
                      max_channels_per_template=8,
                      sparsity_dict=dict(method="by_property", by_property="group"),
                      n_jobs=1, chunk_size=10000, progress_bar=True)
        return np.load(out_folder / "template_ind.npy")

    template_inds = _export_and_load_template_inds(
        extract_waveforms(recording, sorting, waveform_folder), output_folder)
    assert template_inds.shape == (num_units, 4)

    # keep every channel except 1; the exported indices must then contain -1 entries
    recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7])
    template_inds_rm = _export_and_load_template_inds(
        extract_waveforms(recording_rm, sorting, waveform_folder_rm), output_folder_rm)
    assert template_inds_rm.shape == (num_units, 4)
    assert np.any(template_inds_rm == -1)
Ejemplo n.º 10
0
def test_toy_example():
    """Check segment/unit counts of toy_example outputs and print them."""
    rec, sorting = toy_example(num_segments=2, num_units=10)
    assert rec.get_num_segments() == sorting.get_num_segments() == 2
    assert sorting.get_num_units() == 10

    rec, sorting = toy_example(num_segments=1, num_channels=16, num_columns=2)
    assert rec.get_num_segments() == sorting.get_num_segments() == 1
    for obj in (rec, sorting):
        print(obj)

    # print the probe attached to the generated recording
    print(rec.get_probe())
def setup_module():
    """Create a 2-segment toy dataset on disk and extract its waveforms."""
    for folder in map(Path, ('toy_rec', 'toy_sorting', 'toy_waveforms')):
        if folder.is_dir():
            shutil.rmtree(folder)

    rec, sort = toy_example(num_segments=2, num_units=10)
    rec = rec.save(folder='toy_rec')
    sort = sort.save(folder='toy_sorting')

    waveform_extractor = WaveformExtractor.create(rec, sort, 'toy_waveforms')
    waveform_extractor.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    waveform_extractor.run_extract_waveforms(n_jobs=1, chunk_size=30000)
Ejemplo n.º 12
0
def test_mda_extractors():
    """Round-trip a toy recording and sorting through the MDA format."""
    rec, sort = toy_example(num_segments=1, num_units=10)
    mda_folder = "mdatest"
    firings_path = "mdatest/firings.mda"

    MdaRecordingExtractor.write_recording(rec, mda_folder)
    rec_mda = MdaRecordingExtractor(mda_folder)
    rec_mda.get_probe()  # only checks that get_probe() can be called
    check_recordings_equal(rec, rec_mda, return_scaled=False)

    MdaSortingExtractor.write_sorting(sort, firings_path)
    sort_mda = MdaSortingExtractor(firings_path,
                                   sampling_frequency=sort.get_sampling_frequency())
    check_sortings_equal(sort, sort_mda)
Ejemplo n.º 13
0
def setup_module():
    """Create a 2-segment toy dataset on disk, extract waveforms and run a PCA."""
    for folder in map(Path, ('toy_rec', 'toy_sorting', 'toy_waveforms')):
        if folder.is_dir():
            shutil.rmtree(folder)

    rec, sort = toy_example(num_segments=2, num_units=10)
    rec = rec.save(folder='toy_rec')
    sort = sort.save(folder='toy_sorting')

    waveform_extractor = WaveformExtractor.create(rec, sort, 'toy_waveforms')
    waveform_extractor.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    waveform_extractor.run(n_jobs=1, chunk_size=30000)

    # 5 components, mode='by_channel_local'
    principal_components = WaveformPrincipalComponent(waveform_extractor)
    principal_components.set_params(n_components=5, mode='by_channel_local')
    principal_components.run()
def test_shybrid_extractors():
    """Round-trip a toy sorting and recording through the SHYBRID format."""
    rec, sort = toy_example(num_segments=1, num_units=10)
    sorting_csv = "shybridtest/initial_sorting.csv"

    SHYBRIDSortingExtractor.write_sorting(sort, "shybridtest")
    sort_shybrid = SHYBRIDSortingExtractor(sorting_csv,
                                           sampling_frequency=sort.get_sampling_frequency())
    check_sortings_equal(sort, sort_shybrid)

    # the recording writer is given the sorting CSV written above
    SHYBRIDRecordingExtractor.write_recording(rec, "shybridtest",
                                              initial_sorting_fn=sorting_csv)
    rec_shybrid = SHYBRIDRecordingExtractor("shybridtest/recording.yml")
    rec_shybrid.get_probe()  # only checks that get_probe() can be called
    check_recordings_equal(rec, rec_shybrid, return_scaled=False)
def test_compare_multiple_templates():
    """Compare templates across three consecutive slices of one recording.

    The same recording/sorting pair is split into three equal time slices.
    Every slice keeps the original unit ids, so matched units must share the
    same id in both the paired and the multi-session comparison.
    """
    import shutil

    test_dir = Path("temp_comp_test")
    # remove leftovers of a previous run so that save(folder=...) starts clean
    if test_dir.is_dir():
        shutil.rmtree(test_dir)

    duration = 60
    num_channels = 8

    rec, sort = toy_example(duration=duration,
                            num_segments=1,
                            num_channels=num_channels)
    rec = rec.save(folder=test_dir / "rec")
    sort = sort.save(folder=test_dir / "sort")

    # Split recording and sorting in 3 equal slices. Frame indices must be
    # integers, so round the float products of duration and sampling rate.
    fs = rec.get_sampling_frequency()
    bounds = [int(round(k * duration * fs / 3)) for k in range(4)]
    slices = list(zip(bounds[:-1], bounds[1:]))
    recs = [rec.frame_slice(start_frame=start, end_frame=end) for start, end in slices]
    sorts = [sort.frame_slice(start_frame=start, end_frame=end) for start, end in slices]

    # compute waveforms
    we1, we2, we3 = [extract_waveforms(r, s, test_dir / f"wf{i}", n_jobs=1)
                     for i, (r, s) in enumerate(zip(recs, sorts), start=1)]

    # paired comparison
    temp_cmp = compare_templates(we1, we2)

    for u1 in temp_cmp.hungarian_match_12.index.values:
        u2 = temp_cmp.hungarian_match_12[u1]
        if u2 != -1:
            assert u1 == u2

    # multi-comparison
    temp_mcmp = compare_multiple_templates([we1, we2, we3])
    # assert unit ids are the same across sessions (because of initial slicing)
    for unit_dict in temp_mcmp.units.values():
        # Materialize the dict view. np.unique on a raw dict_values object sees
        # a single 0-d object element, which made this check always pass.
        unit_ids = list(unit_dict["unit_ids"].values())
        if len(unit_ids) > 1:
            assert len(np.unique(unit_ids)) == 1
Ejemplo n.º 16
0
def test_run_sorter_by_property():
    """Sort an 8-channel toy recording split into two channel groups."""
    cache_folder = './local_cache'
    working_folder = 'test_run_sorter_by_property'

    for stale_folder in (cache_folder, working_folder):
        if os.path.exists(stale_folder):
            shutil.rmtree(stale_folder)

    rec0, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
    # first four channels -> group "0", last four -> group "1"
    rec0.set_channel_groups(["0"] * 4 + ["1"] * 4)

    # saving into the global tmp folder makes the recording dumpable
    set_global_tmp_folder(cache_folder)
    rec0 = rec0.save(name='rec0')

    sorting = run_sorter_by_property('tridesclous', rec0, "group", working_folder,
                                     engine='loop', verbose=False)
    assert "group" in sorting.get_property_keys()
Ejemplo n.º 17
0
'''
SortingExtractor Widgets Gallery
===================================

Here is a gallery of all the available widgets using SortingExtractor objects.
'''
import matplotlib.pyplot as plt

import spikeinterface.extractors as se
import spikeinterface.widgets as sw

##############################################################################
# First, let's create a toy example with the `extractors` module:

recording, sorting = se.toy_example(num_segments=1, seed=0,
                                    num_channels=4, duration=10)

##############################################################################
# plot_rasters()
# ~~~~~~~~~~~~~~~~~

w_rs = sw.plot_rasters(sorting)

##############################################################################
# plot_isi_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~

#TODO : @alessio: this is for you
#w_isi = sw.plot_isi_distribution(sorting, bins=10, window=1)
Ejemplo n.º 18
0
SpikeInterface internally uses `probeinterface <https://probeinterface.readthedocs.io/>`_ to handle
probes or probe groups for recordings.

Depending on the dataset, the `Probe` object may already be included, or you may have to set it manually.

Here's how!
'''
import matplotlib.pyplot as plt
import numpy as np
import spikeinterface.extractors as se

##############################################################################
# First, let's create a toy example:

recording, sorting_true = se.toy_example(num_segments=2, seed=0,
                                         num_channels=32, duration=10)
print(recording)

###############################################################################
# This generator already attaches a probe object to the recording: you can
# retrieve it directly and plot it.

probe = recording.get_probe()
print(probe)
from probeinterface.plotting import plot_probe
plot_probe(probe)

###############################################################################
# You can also change the probe.
# In that case you need to set up the wiring manually.
# Let's use a 32-channel probe from Cambridge NeuroTech.
Ejemplo n.º 19
0
import matplotlib.pyplot as plt
import seaborn as sns

import spikeinterface.extractors as se
import spikeinterface.widgets as sw
from spikeinterface.comparison import GroundTruthStudy

##############################################################################
# Setup study folder and run all sorters
# --------------------------------------
# 
# We first generate the study folder.
# This can take some time because recordings are copied inside the folder.


# the two recordings only differ by their random seed
rec0, gt_sorting0 = se.toy_example(num_segments=1, seed=10, duration=10, num_channels=4)
rec1, gt_sorting1 = se.toy_example(num_segments=1, seed=0, duration=10, num_channels=4)
gt_dict = dict(rec0=(rec0, gt_sorting0), rec1=(rec1, gt_sorting1))

study_folder = 'a_study_folder'
study = GroundTruthStudy.create(study_folder, gt_dict)

##############################################################################
# Then just run all sorters on all recordings with a single function call.

# sorter_list = st.sorters.available_sorters()  # this gets all sorters.
sorter_list = ['herdingspikes', 'tridesclous']
study.run_sorters(sorter_list, mode_if_folder_exists="keep")
##############################################################################
#  Sometimes, you might want to sort your data depending on a specific property of your recording channels.
#
# For example, when using multiple tetrodes, a good idea is to sort each tetrode separately. In this case, channels
# belonging to the same tetrode will be in the same 'group'. Alternatively, for long silicon probes, such as
# Neuropixels, you could sort different areas separately, for example hippocampus and thalamus.
#
# All this can be done by sorting by 'property'. Properties can be loaded to the recording channels either manually
# (using the :code:`set_channel_property` method), or by using a probe file. In this example we will create a 16 channel
# recording and split it in four channel groups (tetrodes).
#
# Let's create a toy example with 16 channels (calling :code:`save()` on the recording dumps it to a folder and makes
# it "dumpable", which is required for parallel sorting):

recording, sorting_true = se.toy_example(num_channels=16, num_segments=1,
                                         duration=[10.])
# saving the recording makes it dumpable (required for parallel sorting)
recording = recording.save()

##############################################################################
# Initially all channels are in the same group.

channel_groups = recording.get_channel_groups()
print(channel_groups)

##############################################################################
# Let's now change the probe mapping and assign 4 tetrodes to this recording.
# For this we will use the `probeinterface` module and create a `ProbeGroup` containing four dummy tetrodes.

from probeinterface import generate_tetrode, ProbeGroup
Ejemplo n.º 21
0
"""
Use the spike sorting launcher
==============================

This example shows how to use the spike sorting launcher. The launcher lets you parameterize the sorter name and
run several sorters on one or multiple recordings.

"""

import spikeinterface.extractors as se
import spikeinterface.sorters as ss

##############################################################################
# First, let's create the usual toy example:

recording, sorting_true = se.toy_example(num_segments=1, seed=0, duration=10)
for toy_object in (recording, sorting_true):
    print(toy_object)

##############################################################################
# Let's cache this recording to make it "dumpable"

recording = recording.save(name='toy')
print(recording)

##############################################################################
# The launcher lets you call any spike sorter with the same functions: :code:`run_sorter` and :code:`run_sorters`.
# For running multiple sorters on the same recording extractor or a collection of them, the :code:`run_sorters`
# function can be used.
#
# Let's first see how to run a single sorter, for example, Klusta:
Ejemplo n.º 22
0
Before spike sorting, you may need to preprocess your signals in order to improve the spike sorting performance.
You can do that in SpikeInterface using the :code:`toolkit.preprocessing` submodule.

"""

import numpy as np
import matplotlib.pylab as plt
import scipy.signal

import spikeinterface.extractors as se
import spikeinterface.toolkit as st

##############################################################################
# First, let's create a toy example:

recording, sorting = se.toy_example(seed=0, duration=10, num_channels=4)

##############################################################################
# Apply filters
# -----------------
#
# Now apply a bandpass filter and a notch filter (separately) to the
# recording extractor. Filters are also RecordingExtractor objects.
# Note that these operations are **lazy**: the computation is done on the fly
# with `rec.get_traces()`.

recording_bp = st.preprocessing.bandpass_filter(recording, freq_min=300, freq_max=6000)
print(recording_bp)
recording_notch = st.preprocessing.notch_filter(recording, q=30, freq=2000)