Exemple #1
0
 def from_memory(recording: se.RecordingExtractor,
                 serialize=False,
                 serialize_dtype=None):
     """Wrap a spikeextractors recording as a LabboxEphysRecordingExtractor.

     When ``serialize`` is False the recording is registered with
     ``register_in_memory_object`` and wrapped with
     ``recording_format='in_memory'``.

     When ``serialize`` is True the traces are written with
     ``se.BinDatRecordingExtractor.write_recording`` (time_axis=0, dtype
     ``serialize_dtype``) into a temporary directory, stored with
     ``ka.store_file`` (basename ``'raw.mda'``), and wrapped as a ``'bin1'``
     recording. Its channel map and (x, y) channel positions are taken from
     ``recording``.

     Args:
         recording: The recording extractor to wrap.
         serialize: Whether to persist the traces via ``ka.store_file``.
         serialize_dtype: dtype for the serialized traces. Required when
             ``serialize`` is True.

     Returns:
         LabboxEphysRecordingExtractor

     Raises:
         ValueError: If ``serialize`` is True and ``serialize_dtype`` is None.
             ValueError subclasses Exception, so existing ``except Exception``
             callers are unaffected.
     """
     if serialize:
         if serialize_dtype is None:
             raise ValueError(
                 'You must specify the serialize_dtype when serializing recording extractor in from_memory()'
             )
         with hi.TemporaryDirectory() as tmpdir:
             # NOTE(review): the file is written by the BinDat (raw binary)
             # writer despite the '.mda' suffix; the names are kept because
             # the stored basename presumably feeds into the returned URI.
             fname = tmpdir + '/' + _random_string(10) + '_recording.mda'
             se.BinDatRecordingExtractor.write_recording(
                 recording=recording,
                 save_path=fname,
                 time_axis=0,
                 dtype=serialize_dtype)
             with ka.config(use_hard_links=True):
                 uri = ka.store_file(fname, basename='raw.mda')
             num_channels = recording.get_num_channels()
             channel_ids = [int(a) for a in recording.get_channel_ids()]
             # Fetch each channel's 'location' once; entries [0] and [1] are
             # used as the x and y coordinates.
             locations = [
                 recording.get_channel_property(a, 'location')
                 for a in channel_ids
             ]
             # Channel id (as str) -> row index in the raw binary file.
             channel_map = {str(c): i for i, c in enumerate(channel_ids)}
             channel_positions = {
                 str(c): [float(loc[0]), float(loc[1])]
                 for c, loc in zip(channel_ids, locations)
             }
             return LabboxEphysRecordingExtractor({
                 'recording_format': 'bin1',
                 'data': {
                     'raw': uri,
                     'raw_num_channels': num_channels,
                     'num_frames': int(recording.get_num_frames()),
                     'samplerate': float(recording.get_sampling_frequency()),
                     'channel_ids': channel_ids,
                     'channel_map': channel_map,
                     'channel_positions': channel_positions,
                 }
             })
     obj = {
         'recording_format': 'in_memory',
         'data': register_in_memory_object(recording)
     }
     return LabboxEphysRecordingExtractor(obj)
Exemple #2
0
def mountainsort4b(recording_object: dict,
                   detect_sign=-1,
                   adjacency_radius=50,
                   clip_size=50,
                   detect_threshold=3,
                   detect_interval=10,
                   freq_min=300,
                   freq_max=6000,
                   whiten=True,
                   curation=False,
                   filter=True):
    """Run spikesorters' Mountainsort4Sorter on a labbox-ephys recording object.

    This wrapper duplicates code from spikeforest2 because running shared
    code inside containers is tricky.

    Args:
        recording_object: Labbox-ephys recording description, wrapped with
            ``LabboxEphysRecordingExtractor``.
        detect_sign, adjacency_radius, clip_size, detect_threshold,
        detect_interval, curation, whiten, filter, freq_min, freq_max:
            Forwarded unchanged to ``sorter.set_params``. (``filter`` shadows
            the builtin but is part of the public keyword interface.)

    The worker count comes from the ``NUM_WORKERS`` environment variable.
    It is 0 when the variable is unset or empty.

    Returns:
        dict with a single key ``'sorting_object'`` holding
        ``_create_sorting_object(sorter.get_result())``.
    """
    import spikesorters as ss
    import labbox_ephys as le

    rec = le.LabboxEphysRecordingExtractor(recording_object)

    print('Sorting...')
    with hi.TemporaryDirectory() as tmpdir:
        sorter = ss.Mountainsort4Sorter(recording=rec,
                                        output_folder=tmpdir,
                                        delete_output_folder=False)

        workers_env = os.environ.get('NUM_WORKERS', None)
        worker_count = int(workers_env) if workers_env else 0

        params = dict(
            detect_sign=detect_sign,
            adjacency_radius=adjacency_radius,
            clip_size=clip_size,
            detect_threshold=detect_threshold,
            detect_interval=detect_interval,
            num_workers=worker_count,
            curation=curation,
            whiten=whiten,
            filter=filter,
            freq_min=freq_min,
            freq_max=freq_max,
        )
        sorter.set_params(**params)
        elapsed = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
        result = _create_sorting_object(sorter.get_result())
        return dict(sorting_object=result)
def filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000):
    """Bandpass-filter a recording object and store the filtered traces.

    Args:
        recobj: Recording description accepted by ``AutoRecordingExtractor``.
            It must be a dict-like object with ``.copy()``.
        freq_min, freq_max, freq_wid: Forwarded to
            ``st.preprocessing.bandpass_filter``.

    Returns:
        A shallow copy of ``recobj``. Its ``'raw'`` entry is replaced by the
        ``ka.store_file`` result for the filtered traces, which are written
        with ``writemda32``.

    Raises:
        Exception: If ``writemda32`` reports failure.
    """
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st

    source = AutoRecordingExtractor(recobj)
    filtered = st.preprocessing.bandpass_filter(
        recording=source,
        freq_min=freq_min,
        freq_max=freq_max,
        freq_wid=freq_wid,
    )
    result = recobj.copy()
    with hither.TemporaryDirectory() as tmpdir:
        out_path = tmpdir + '/raw.mda'
        written_ok = writemda32(filtered.get_traces(), out_path)
        if not written_ok:
            raise Exception('Unable to write output file.')
        result['raw'] = ka.store_file(out_path)
    return result
Exemple #4
0
 def from_memory(sorting: se.SortingExtractor, serialize=False):
     """Wrap a spikeextractors sorting as a LabboxEphysSortingExtractor.

     When ``serialize`` is False the sorting is registered with
     ``register_in_memory_object`` and wrapped with
     ``sorting_format='in_memory'``.

     When ``serialize`` is True the sorting is written with
     ``MdaSortingExtractor.write_sorting`` to a temporary firings file. That
     file is stored via ``ka.store_file`` (basename ``'firings.mda'``) and
     wrapped with ``sorting_format='mda'``.

     Args:
         sorting: The sorting extractor to wrap.
         serialize: Whether to persist the sorting via ``ka.store_file``.

     Returns:
         LabboxEphysSortingExtractor
     """
     if serialize:
         with hi.TemporaryDirectory() as tmpdir:
             fname = tmpdir + '/' + _random_string(10) + '_firings.mda'
             MdaSortingExtractor.write_sorting(sorting=sorting, save_path=fname)
             with ka.config(use_hard_links=True):
                 uri = ka.store_file(fname, basename='firings.mda')
             # Store a plain Python float, as the recording-side from_memory
             # does, rather than whatever numeric type the extractor returns
             # (e.g. a numpy scalar). None is passed through unchanged.
             samplerate = sorting.get_sampling_frequency()
             if samplerate is not None:
                 samplerate = float(samplerate)
             return LabboxEphysSortingExtractor({
                 'sorting_format': 'mda',
                 'data': {
                     'firings': uri,
                     'samplerate': samplerate
                 }
             })
     obj = {
         'sorting_format': 'in_memory',
         'data': register_in_memory_object(sorting)
     }
     return LabboxEphysSortingExtractor(obj)
Exemple #5
0
def spykingcircus(*,
    recording_object,
    detect_sign=-1,
    adjacency_radius=100,
    detect_threshold=6,
    template_width_ms=3,
    filter=True,
    merge_spikes=True,
    auto_merge=0.75,
    num_workers=None,
    whitening_max_elts=1000,
    clustering_max_elts=10000
):
    """Run SpyKING CIRCUS (via spikesorters) and store the result as an H5 sorting.

    The default parameters from ``ss.get_default_params('spykingcircus')`` are
    overridden by every keyword argument below, printed, and passed to
    ``ss.run_spykingcircus``. The output sorting is written with
    ``H5SortingExtractorV1.write_sorting`` and stored with ``ka.store_file``.

    Returns:
        ``{'sorting_format': 'h5_v1', 'data': {'h5_path': <stored uri>}}``
    """
    import spikesorters as ss

    recording = LabboxEphysRecordingExtractor(recording_object)

    sorting_params = ss.get_default_params('spykingcircus')
    # Keyword order mirrors the original assignment order, so any keys not
    # already present in the defaults are inserted in the same order.
    sorting_params.update(
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        detect_threshold=detect_threshold,
        template_width_ms=template_width_ms,
        filter=filter,
        merge_spikes=merge_spikes,
        auto_merge=auto_merge,
        num_workers=num_workers,
        whitening_max_elts=whitening_max_elts,
        clustering_max_elts=clustering_max_elts,
    )
    print('Using sorting parameters:', sorting_params)

    with hi.TemporaryDirectory() as tmpdir:
        sorting = ss.run_spykingcircus(
            recording,
            output_folder=tmpdir + '/sc_output',
            delete_output_folder=False,
            verbose=True,
            **sorting_params,
        )
        h5_path = tmpdir + '/sorting.h5'
        H5SortingExtractorV1.write_sorting(sorting=sorting, save_path=h5_path)
        h5_uri = ka.store_file(h5_path)
    return {
        'sorting_format': 'h5_v1',
        'data': {
            'h5_path': h5_uri
        }
    }
def create_subrecording_object(recording_object, channels, start_frame,
                               end_frame):
    """Extract a channel/time subset of a 'bin1' recording as a new 'bin1' object.

    The subset is taken with ``se.SubRecordingExtractor``. Its traces are cast
    to int16, written to a temporary raw binary file, and stored with
    ``ka.store_file``.

    Args:
        recording_object: dict with ``recording_format == 'bin1'`` and a
            ``'data'`` dict providing raw, num_frames, raw_num_channels,
            channel_ids, samplerate, channel_map and channel_positions.
        channels: Channel ids to keep. Passed as ``channel_ids`` to
            SubRecordingExtractor.
        start_frame, end_frame: Frame range passed to SubRecordingExtractor.

    Returns:
        dict describing the new 'bin1' recording. ``num_frames`` is the number
        of frames actually written, and ``channel_positions`` maps
        ``str(channel_id)`` to a list of floats.

    Raises:
        AssertionError: If the input recording format is not 'bin1'. This is
            an ``assert``, so it is skipped under ``python -O``.
    """
    from .bin1recordingextractor import Bin1RecordingExtractor
    recording_format = recording_object['recording_format']
    assert recording_format == 'bin1', f'Unsupported recording format: {recording_format}'
    d = recording_object['data']
    rec = Bin1RecordingExtractor(raw=d['raw'],
                                 num_frames=d['num_frames'],
                                 raw_num_channels=d['raw_num_channels'],
                                 channel_ids=d['channel_ids'],
                                 samplerate=d['samplerate'],
                                 channel_map=d['channel_map'],
                                 channel_positions=d['channel_positions'],
                                 p2p=True)
    rec2 = se.SubRecordingExtractor(parent_recording=rec,
                                    channel_ids=channels,
                                    start_frame=start_frame,
                                    end_frame=end_frame)
    sub_channel_ids = rec2.get_channel_ids()
    with hi.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.bin'
        # NOTE(review): traces are cast to int16 regardless of the source
        # dtype; presumably this is the sample type bin1 expects -- confirm.
        rec2.get_traces().astype('int16').tofile(raw_fname)
        new_bin_uri = ka.store_file(raw_fname)
    new_channel_map = dict()
    new_channel_positions = dict()
    for ii, ch_id in enumerate(sub_channel_ids):
        # Row index of this channel in the newly written raw file.
        new_channel_map[str(ch_id)] = ii
        # Plain float list (not a numpy array), matching how bin1 channel
        # positions are built elsewhere.
        location = rec2.get_channel_locations(channel_ids=[ch_id])[0]
        new_channel_positions[str(ch_id)] = [float(v) for v in location]
    return dict(recording_format='bin1',
                data=dict(raw=new_bin_uri,
                          raw_num_channels=len(sub_channel_ids),
                          # Ask the sub-extractor instead of computing
                          # end_frame - start_frame, which fails when either
                          # bound is None (SubRecordingExtractor fills those in).
                          num_frames=int(rec2.get_num_frames()),
                          samplerate=float(rec2.get_sampling_frequency()),
                          channel_ids=_listify_ndarray(sub_channel_ids),
                          channel_map=new_channel_map,
                          channel_positions=new_channel_positions))
def mountainsort4(*,
                  recording_object,
                  detect_sign=-1,
                  clip_size=50,
                  adjacency_radius=-1,
                  detect_threshold=3,
                  detect_interval=10,
                  num_workers=None,
                  verbose=True):
    """Run ml_ms4alg's MountainSort4 on a labbox-ephys recording object.

    The geometry comes from ``_get_geom_from_recording``. ``num_workers`` is
    applied only when it is not None. Sorting runs inside a temporary
    directory.

    Returns:
        ``{'sorting_format': 'npy1', 'data': {...}}``. The data holds the
        sampling frequency plus ``ka.store_npy`` URIs for the event times
        (float64) and labels (int32).
    """
    from ml_ms4alg.mountainsort4 import MountainSort4
    recording = LabboxEphysRecordingExtractor(recording_object)

    sorter = MountainSort4()
    sorter.setRecording(recording)
    sorter.setGeom(_get_geom_from_recording(recording))
    sorter.setSortingOpts(clip_size=clip_size,
                          adjacency_radius=adjacency_radius,
                          detect_sign=detect_sign,
                          detect_interval=detect_interval,
                          detect_threshold=detect_threshold,
                          verbose=verbose)
    if num_workers is not None:
        sorter.setNumWorkers(num_workers)

    with hi.TemporaryDirectory() as tmpdir:
        sorter.setTemporaryDirectory(tmpdir)
        sorter.sort()
        # The per-event channels are not needed for the npy1 format.
        times, labels, _channels = sorter.eventTimesLabelsChannels()
        samplerate = recording.get_sampling_frequency()
        times_uri = ka.store_npy(times.astype(np.float64))
        labels_uri = ka.store_npy(labels.astype(np.int32))
    return {
        'sorting_format': 'npy1',
        'data': {
            'samplerate': samplerate,
            'times_npy_uri': times_uri,
            'labels_npy_uri': labels_uri
        }
    }
Exemple #8
0
def prepare_snippets_h5(recording_object,
                        sorting_object,
                        start_frame=None,
                        end_frame=None,
                        max_events_per_unit=None,
                        max_neighborhood_size=15):
    """Return a stored snippets H5 file for a recording/sorting pair.

    If ``recording_object`` already has ``recording_format == 'snippets1'``,
    its ``data['snippets_h5_uri']`` is returned directly. Otherwise both
    objects are wrapped as labbox-ephys extractors, and
    ``prepare_snippets_h5_from_extractors`` writes ``snippets.h5`` into a
    temporary directory. The result of ``ka.store_file`` on that file is
    returned.
    """
    if recording_object['recording_format'] == 'snippets1':
        return recording_object['data']['snippets_h5_uri']

    import labbox_ephys as le

    rec = le.LabboxEphysRecordingExtractor(recording_object)
    sor = le.LabboxEphysSortingExtractor(sorting_object)
    options = dict(start_frame=start_frame,
                   end_frame=end_frame,
                   max_events_per_unit=max_events_per_unit,
                   max_neighborhood_size=max_neighborhood_size)

    with hi.TemporaryDirectory() as tmpdir:
        h5_path = tmpdir + '/snippets.h5'
        prepare_snippets_h5_from_extractors(recording=rec,
                                            sorting=sor,
                                            output_h5_path=h5_path,
                                            **options)
        snippets_uri = ka.store_file(h5_path)
    return snippets_uri
Exemple #9
0
 def write_sorting(sorting, save_path):
     """Write ``sorting`` to ``save_path`` using ``H5SortingExtractorV1.write_sorting``.

     The H5 file is first written inside a temporary directory and then copied
     to ``save_path``. A failure while writing the H5 file therefore leaves
     ``save_path`` untouched.

     Args:
         sorting: The sorting extractor to write.
         save_path: Destination file path.
     """
     import shutil
     with hi.TemporaryDirectory() as tmpdir:
         tmp_path = tmpdir + '/' + _random_string(10) + '_sorting.h5'
         H5SortingExtractorV1.write_sorting(sorting=sorting, save_path=tmp_path)
         # Copy out before the temporary directory is removed on exit;
         # otherwise the written file would be discarded and save_path never
         # created.
         shutil.copyfile(tmp_path, save_path)
Exemple #10
0
import json
import spikeextractors as se
import spiketoolkit as st
import labbox_ephys as le
import hither as hi

# Create a temporary directory where SpikeInterface (SI) will dump the data.
# NOTE(review): remove=True presumably deletes the directory on exit -- confirm
# against hither.TemporaryDirectory.
with hi.TemporaryDirectory(remove=True) as tmpdir:
    # Create a dumpable SpikeInterface toy recording (R) and sorting (S, not
    # used below), dumped into tmpdir with a fixed seed for reproducibility.
    R, S = se.example_datasets.toy_example(dumpable=True, dump_folder=tmpdir, seed=1)
    # Sub-recording starting at frame 10, then bandpass filtered with
    # spiketoolkit's default filter parameters.
    R2 = se.SubRecordingExtractor(parent_recording=R, start_frame=10)
    R3 = st.preprocessing.bandpass_filter(R2)

    # Convert the SpikeInterface extractor chain to a labbox-ephys recording extractor
    R_le = le.LabboxEphysRecordingExtractor.from_spikeinterface(R3)

    # Print the labbox-ephys object description as indented JSON
    print(json.dumps(
        R_le.object(), indent=4
    ))

# expected output:
# {
#     "recording_format": "filtered",
#     "data": {
#         "filters": [
#             {
#                 "type": "bandpass_filter",
#                 "freq_min": 300,
#                 "freq_max": 6000,
#                 "freq_wid": 1000
def _detect_event_times(X, detect_sign, detect_threshold, detect_interval):
    """Detect events on the per-frame maximum over axis 0 of the sign-adjusted traces.

    The traces are negated when ``detect_sign < 0``, rectified when
    ``detect_sign == 0``, and used as-is otherwise. The threshold is
    ``detect_threshold`` times a MAD-based noise estimate of that maximum.
    Detection is delegated to ``detect_on_channel`` with detect_sign=1 and
    margin=1000. ``X`` is not modified.
    """
    # No copy of X is needed: negation, abs and max all return new arrays.
    if detect_sign < 0:
        sig = -X
    elif detect_sign == 0:
        sig = np.abs(X)
    else:
        sig = X
    sig = np.max(sig, axis=0)
    noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
    return detect_on_channel(sig, detect_threshold=noise_level * detect_threshold,
                             detect_interval=detect_interval, detect_sign=1, margin=1000)


def _apply_window(snippets, window):
    """Multiply each (event, channel) waveform of 3-D ``snippets`` by ``window`` in place.

    Assigning through ``snippets[...]`` keeps the array's dtype, exactly like
    the per-(event, channel) slice assignments it replaces.
    """
    snippets[...] = snippets * window


def main():
    """Iteratively denoise detected spike snippets and display the result.

    Outline of what the code below does:

    1. Bandpass-filter ``recobj`` by running the hither function
       ``filter_recording``. ``recobj`` is not defined here; presumably it is
       a module-level recording object (confirm).
    2. Detect reference events (interval ``detect_interval_reference``) and
       extract snippets tapered by a Gaussian window. Build a basis from the
       top ``npca`` right-singular vectors of the flattened snippets, plus a
       ball-tree nearest-neighbour index over the reference features.
    3. Repeat for up to ``num_passes`` passes, stopping early when no events
       are found:
       - detect events on the residual ``X``;
       - project their windowed snippets onto the basis;
       - replace each feature vector by its PLS fit (``ncomp`` components),
         using its ``k`` nearest reference neighbours as regressors (the
         first neighbour returned is dropped);
       - add the reconstructed snippet to ``X_signal`` and subtract it
         from ``X``.
    4. Stack ``X_signal``, the residual ``X`` and the original traces along
       axis 0. Write the result with ``writemda32``, store it, and show it
       with ``spikeforest_widgets.TimeseriesView``.

    Raises:
        Exception: If ``writemda32`` reports failure (same convention as
            ``filter_recording``).
    """
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    sw.init_electron()

    # bandpass filter
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval

    detect_threshold = 3
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    snippet_len = (200, 200)
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100  # frames; presumably 100 s at 30 kHz -- confirm
    k = 20
    ncomp = 4

    R = AutoRecordingExtractor(recobj2)

    # X has the same shape as X_signal below (num_channels x num_frames) and
    # becomes the residual as denoised snippets are subtracted from it.
    X = R.get_traces()

    times_reference = _detect_event_times(X, detect_sign, detect_threshold, detect_interval_reference)
    times_reference = times_reference[times_reference <= max_t]
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    # Gaussian taper over the snippet's time axis; window_frac is its width on a [-1, 1] scale.
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2 / (2 * window_frac**2))
    _apply_window(snippets_reference, window0)
    A_snippets_reference = snippets_reference.reshape(
        snippets_reference.shape[0], snippets_reference.shape[1] * snippets_reference.shape[2])

    print('PCA...')
    # NOTE(review): with the default full_matrices=True the discarded U factor
    # is (num_events x num_events); full_matrices=False would save memory, but
    # check it still yields npca rows when there are fewer events than npca.
    _, _, vh = np.linalg.svd(A_snippets_reference)
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        times = _detect_event_times(X, detect_sign, detect_threshold, detect_interval)
        times = times[times <= max_t]
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        _apply_window(snippets, window0)
        A_snippets = snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])
        features = A_snippets @ components_reference

        print('Finding nearest neighbors...')
        _, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            inds0 = np.squeeze(indices[j, :])
            inds0 = inds0[1:] # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :].T)
            features2[j, :] = pls.predict(f_neighbors.T).T
        A_snippets_denoised = features2 @ components_reference.T

        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)

        # Move each denoised snippet from the residual X into X_signal.
        # Plain (not augmented) assignment keeps the original casting
        # behaviour if the array dtypes differ.
        for j, t0 in enumerate(times):
            start, stop = t0 - snippet_len[0], t0 + snippet_len[1]
            snippet_denoised_0 = np.squeeze(snippets_denoised[j, :, :])
            X_signal[:, start:stop] = X_signal[:, start:stop] + snippet_denoised_0
            X[:, start:stop] = X[:, start:stop] - snippet_denoised_0

    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        if not writemda32(S, raw_fname):
            raise Exception('Unable to write output file.')
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)

    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()