Example #1
import os

import hither
import kachery as ka


def sort(algorithm: str,
         recording_path: str,
         sorting_out: str = None,
         params: dict = None,
         container: str = 'default',
         git_annex_mode=True,
         use_singularity: bool = False,
         job_timeout: float = 3600) -> str:

    from spikeforest2 import sorters
    HITHER_USE_SINGULARITY = os.getenv('HITHER_USE_SINGULARITY')
    if HITHER_USE_SINGULARITY is None:
        HITHER_USE_SINGULARITY = False
    print('HITHER_USE_SINGULARITY: {}'.format(HITHER_USE_SINGULARITY))
    if not hasattr(sorters, algorithm):
        raise Exception('Sorter not found: {}'.format(algorithm))
    sorter = getattr(sorters, algorithm)
    if algorithm in [
            'kilosort2', 'kilosort', 'ironclust', 'tridesclous', 'jrclust'
    ]:
        gpu = True
    else:
        gpu = False
    if not sorting_out:
        sorting_out = hither.File()
    # Only store local paths; leave sha1:// and sha1dir:// URIs as-is.
    if not recording_path.startswith(
            'sha1dir://') and not recording_path.startswith('sha1://'):
        if os.path.isfile(recording_path):
            recording_path = ka.store_file(recording_path)
        elif os.path.isdir(recording_path):
            recording_path = ka.store_dir(recording_path,
                                          git_annex_mode=git_annex_mode)
    if params is None:
        params = dict()
    params_hither = dict(gpu=gpu, container=container)
    if job_timeout is not None:
        params_hither['job_timeout'] = job_timeout
    with hither.config(**params_hither):
        result = sorter.run(recording_path=recording_path,
                            sorting_out=sorting_out,
                            **params)
    print('SORTING')
    print('==============================================')
    return ka.store_file(result.outputs.sorting_out._path,
                         basename='firings.mda')


# def set_params(sorter, params_file):
#     params = {}
#     names_float = ['detection_thresh']
#     with open(params_file, 'r') as myfile:
#         for line in myfile:
#             name, var = line.partition("=")[::2]
#             name = name.strip()

#             params[name.strip()] = var
#     sorter.set_params(**params)
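A minimal usage sketch for the sort() wrapper above, assuming spikeforest2, hither and kachery are installed and configured; the algorithm name and recording URI are placeholders.

# Hypothetical call to sort(); 'mountainsort4' and the sha1dir:// URI are placeholders.
firings_uri = sort(
    algorithm='mountainsort4',
    recording_path='sha1dir://<hash>.paired_boyden32c/419_1_7',
    job_timeout=7200)
print('Sorting stored at: {}'.format(firings_uri))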
Example #2
 def make(self, key):
     print('Computing SHA-1 and storing in kachery...')
     nwb_file_abs_path = Nwbfile.get_abs_path(key['nwb_file_name'])
     with ka.config(use_hard_links=True):
         kachery_path = ka.store_file(nwb_file_abs_path)
         key['nwb_file_sha1'] = ka.get_file_hash(kachery_path)
     self.insert1(key)
Example #3
def _internal_serialize_result(result):
    from copy import deepcopy
    from typing import Any, Dict
    import kachery as ka
    ret: Dict[str, Any] = dict(
        output_files=dict()
    )
    ret['name'] = 'hither_result'

    ret['runtime_info'] = deepcopy(result.runtime_info)
    ret['runtime_info']['console_out'] = ka.store_object(ret['runtime_info'].get('console_out', ''))

    for oname in result._output_names:
        path = getattr(result.outputs, oname)._path
        if path is not None:
            ret['output_files'][oname] = ka.store_file(path)
        else:
            ret['output_files'][oname] = None

    ret['retval'] = result.retval
    ret['success'] = result.success
    ret['version'] = result.version
    ret['container'] = result.container
    ret['hash_object'] = result.hash_object
    ret['hash'] = ka.get_object_hash(result.hash_object)
    ret['status'] = result.status
    return ret
Example #4
def spikeinterface_recording_dict_to_labbox_dict(x):
    c = x['class']
    if c == 'spiketoolkit.preprocessing.bandpass_filter.BandpassFilterRecording':
        kwargs = x['kwargs']
        recording = spikeinterface_recording_dict_to_labbox_dict(
            kwargs['recording'])
        freq_min = kwargs['freq_min']
        freq_max = kwargs['freq_max']
        freq_wid = kwargs['freq_wid']
        return _make_json_safe({
            'recording_format': 'filtered',
            'data': {
                'filters': [{
                    'type': 'bandpass_filter',
                    'freq_min': freq_min,
                    'freq_max': freq_max,
                    'freq_wid': freq_wid
                }],
                'recording':
                recording
            }
        })
    elif c == 'spikeextractors.subrecordingextractor.SubRecordingExtractor':
        kwargs = x['kwargs']
        recording = spikeinterface_recording_dict_to_labbox_dict(
            kwargs['parent_recording'])
        channel_ids = kwargs['channel_ids']
        renamed_channel_ids = kwargs.get('renamed_channel_ids', None)
        start_frame = kwargs['start_frame']
        end_frame = kwargs['end_frame']
        if renamed_channel_ids is not None:
            raise Exception('renamed_channel_ids field not supported')
        return _make_json_safe({
            'recording_format': 'subrecording',
            'data': {
                'recording': recording,
                'channel_ids': channel_ids,
                'start_frame': start_frame,
                'end_frame': end_frame
            }
        })
    elif c == 'spikeextractors.extractors.mdaextractors.mdaextractors.MdaRecordingExtractor':
        kwargs = x['kwargs']
        path = kwargs['folder_path']
        raw_path = ka.store_file(path + '/raw.mda')
        params_path = path + '/params.json'
        geom_path = path + '/geom.csv'
        params = ka.load_object(params_path)
        assert params is not None, f'Unable to load params.json from: {params_path}'
        geom = _load_geom_from_csv(geom_path)
        return _make_json_safe({
            'recording_format': 'mda',
            'data': {
                'raw': raw_path,
                'geom': geom,
                'params': params
            }
        })
    else:
        raise Exception(f'Unsupported class: {c}')
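For reference, a sketch of the kind of input the converter above expects: the nested 'class'/'kwargs' dictionaries produced when a SpikeInterface recording is serialized. The folder path is hypothetical and would need to contain raw.mda, params.json and geom.csv.

# Hypothetical serialized recording: a bandpass filter wrapped around an
# MdaRecordingExtractor. '/path/to/recording_dir' is a placeholder.
si_dict = {
    'class': 'spiketoolkit.preprocessing.bandpass_filter.BandpassFilterRecording',
    'kwargs': {
        'recording': {
            'class': 'spikeextractors.extractors.mdaextractors.mdaextractors.MdaRecordingExtractor',
            'kwargs': {'folder_path': '/path/to/recording_dir'}
        },
        'freq_min': 300,
        'freq_max': 6000,
        'freq_wid': 1000
    }
}
labbox_dict = spikeinterface_recording_dict_to_labbox_dict(si_dict)
print(labbox_dict['recording_format'])  # 'filtered'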
Example #5
def register_groundtruth(*, recdir, output_fname, label, to):
    with ka.config(to=to):
        firings_true_path = ka.store_file(recdir + '/firings_true.mda')
        obj = dict(firings=firings_true_path)
        obj['self_reference'] = ka.store_object(
            obj, basename='{}.json'.format(label))
        with open(output_fname, 'w') as f:
            json.dump(obj, f, indent=4)
Example #6
 def from_memory(recording: se.RecordingExtractor,
                 serialize=False,
                 serialize_dtype=None):
     if serialize:
         if serialize_dtype is None:
             raise Exception(
                 'You must specify the serialize_dtype when serializing recording extractor in from_memory()'
             )
         with hi.TemporaryDirectory() as tmpdir:
             fname = tmpdir + '/' + _random_string(10) + '_recording.mda'
             se.BinDatRecordingExtractor.write_recording(
                 recording=recording,
                 save_path=fname,
                 time_axis=0,
                 dtype=serialize_dtype)
             with ka.config(use_hard_links=True):
                 uri = ka.store_file(fname, basename='raw.mda')
             num_channels = recording.get_num_channels()
             channel_ids = [int(a) for a in recording.get_channel_ids()]
             xcoords = [
                 recording.get_channel_property(a, 'location')[0]
                 for a in channel_ids
             ]
             ycoords = [
                 recording.get_channel_property(a, 'location')[1]
                 for a in channel_ids
             ]
             recording = LabboxEphysRecordingExtractor({
                 'recording_format': 'bin1',
                 'data': {
                     'raw': uri,
                     'raw_num_channels': num_channels,
                     'num_frames': int(recording.get_num_frames()),
                     'samplerate': float(recording.get_sampling_frequency()),
                     'channel_ids': channel_ids,
                     'channel_map': dict(
                         zip([str(c) for c in channel_ids],
                             [int(i) for i in range(num_channels)])),
                     'channel_positions': dict(
                         zip([str(c) for c in channel_ids],
                             [[float(xcoords[i]), float(ycoords[i])]
                              for i in range(num_channels)]))
                 }
             })
             return recording
     obj = {
         'recording_format': 'in_memory',
         'data': register_in_memory_object(recording)
     }
     return LabboxEphysRecordingExtractor(obj)
Example #7
    def make(self, key):
        print('Computing SHA-1 and storing in kachery...')
        analysis_file_abs_path = AnalysisNwbfile().get_abs_path(key['analysis_file_name'])
        with ka.config(use_hard_links=True):
            kachery_path = ka.store_file(analysis_file_abs_path)
            key['analysis_file_sha1'] = ka.get_file_hash(kachery_path)
        self.insert1(key)

    #TODO: load from kachery and fetch_nwb
Example #8
def _handle_temporary_outputs(outputs):
    import kachery as ka
    for output in outputs:
        if output._is_temporary:
            old_path = output._path
            new_path = ka.load_file(ka.store_file(old_path))
            output._path = new_path
            output._is_temporary = False
            os.unlink(old_path)
Example #9
def create_sorting_object_from_spikeforest_recdir(recdir, label):
    params = kp.load_object(recdir + '/params.json')
    firings_path = kp.load_file(recdir + '/firings_true.mda')
    firings_path = ka.store_file(firings_path, basename=label + '-firings.mda')
    sorting_object = dict(sorting_format='mda',
                          data=dict(firings=firings_path,
                                    samplerate=params['samplerate']))
    print(sorting_object)
    return sorting_object
Example #10
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running clustering')

        alg_name = state.get('alg_name', 'none')
        alg_arguments = state.get('alg_arguments', dict())
        kachery_config = state.get('kachery_config', None)
        args0 = alg_arguments.get(alg_name, {})

        if kachery_config:
            ka.set_config(**kachery_config)

        dirname = os.path.dirname(os.path.realpath(__file__))
        fname = os.path.join(dirname, 'clustering_datasets.json')
        with open(fname, 'r') as f:
            datasets = json.load(f)
        timer = time.time()
        for ds in datasets['datasets']:
            self._set_status('running', 'Running: {}'.format(ds['path']))
            print('Loading {}'.format(ds['path']))
            path2 = ka.load_file(ds['path'])
            if path2:
                # Only re-store and process the file if it could be loaded.
                ka.store_file(path2)
                ds['data'] = self._load_dataset_data(path2)
                if alg_name:
                    print('Clustering...')
                    ds['algName'] = alg_name
                    ds['algArgs'] = args0
                    timer0 = time.time()
                    ds['labels'] = self._do_clustering(
                        ds['data'], alg_name, args0,
                        dict(true_num_clusters=ds['trueNumClusters']))
                    elapsed0 = time.time() - timer0
                    if alg_name != 'none':
                        ds['elapsed'] = elapsed0
            else:
                print('Unable to realize file: {}'.format(ds['path']))
            elapsed = time.time() - timer
            if elapsed > 0.1:
                self._set_state(algorithms=self._algorithms, datasets=datasets)

        self._set_state(algorithms=self._algorithms, datasets=datasets)

        self._set_status('finished', 'Finished clustering')
Example #11
def _handle_temporary_outputs(outputs: List[File]):
    import kachery as ka
    for output in outputs:
        if not output._exists:
            old_path = output._path
            new_path = ka.load_file(ka.store_file(old_path))
            output._path = new_path
            output._is_temporary = False
            output._exists = True
            os.unlink(old_path)
Example #12
def register_recording(*, recdir, output_fname, label, to):
    with ka.config(to=to):
        raw_path = ka.store_file(recdir + '/raw.mda')
        obj = dict(raw=raw_path,
                   params=ka.load_object(recdir + '/params.json'),
                   geom=np.genfromtxt(ka.load_file(recdir + '/geom.csv'),
                                      delimiter=',').tolist())
        obj['self_reference'] = ka.store_object(
            obj, basename='{}.json'.format(label))
        with open(output_fname, 'w') as f:
            json.dump(obj, f, indent=4)
Example #13
def filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000):
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st
    rx = AutoRecordingExtractor(recobj)
    rx2 = st.preprocessing.bandpass_filter(recording=rx, freq_min=freq_min, freq_max=freq_max, freq_wid=freq_wid)
    recobj2 = recobj.copy()
    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        if not writemda32(rx2.get_traces(), raw_fname):
            raise Exception('Unable to write output file.')
        recobj2['raw'] = ka.store_file(raw_fname)
        return recobj2
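A short sketch of calling filter_recording() directly; recobj is assumed to be a recording object with 'raw', 'params' and 'geom' entries, e.g. as written by register_recording() in Example #12.

# Assumes `recobj` was built like the object in register_recording() above.
recobj_filtered = filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000)
print(recobj_filtered['raw'])  # kachery URI of the filtered raw.mda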
Example #14
def sort(algorithm: str, recording_path: str):
    from spikeforest2 import sorters
    if not hasattr(sorters, algorithm):
        raise Exception('Sorter not found: {}'.format(algorithm))
    sorter = getattr(sorters, algorithm)
    if algorithm in ['kilosort2', 'ironclust']:
        gpu = True
    else:
        gpu = False
    with hither.config(gpu=gpu):
        result = sorter.run(recording_path=recording_path,
                            sorting_out=hither.File())
    print('SORTING')
    print('==============================================')
    return ka.store_file(result.outputs.sorting_out._path,
                         basename='firings.mda')
Example #15
def register_study(*,
                   path_from,
                   path_to,
                   studySetName,
                   studyName,
                   to='default_readwrite'):
    list_rec = [
        str(f) for f in os.listdir(path_from)
        if os.path.isdir(os.path.join(path_from, f))
    ]
    print('# files: {}'.format(len(list_rec)))
    study_obj = dict(name=studyName, studySetName=studySetName, recordings=[])
    mkdir_(path_to)
    for rec1 in list_rec:
        print(f'Uploading {rec1}')
        path_rec1 = os.path.join(path_from, rec1)
        register_groundtruth(recdir=path_rec1,
                             output_fname=os.path.join(
                                 path_to, rec1 + '.firings_true.json'),
                             label=rec1)
        rec = MdaRecordingExtractor(recording_directory=path_rec1)
        sorting = MdaSortingExtractor(firings_file=path_rec1 +
                                      '/firings_true.mda',
                                      samplerate=rec.get_sampling_frequency())
        recording_obj = dict(
            name=rec1,
            studyName=studyName,
            studySetName=studySetName,
            directory=ka.store_dir(path_rec1),
            firingsTrue=ka.store_file(os.path.join(path_to, rec1 +
                                                   '.firings_true.json'),
                                      basename='firings_true.json'),
            sampleRateHz=rec.get_sampling_frequency(),
            numChannels=len(rec.get_channel_ids()),
            durationSec=rec.get_num_frames() / rec.get_sampling_frequency(),
            numTrueUnits=len(sorting.get_unit_ids()),
            spikeSign=-1  # TODO: get this from params.json
        )
        study_obj['recordings'].append(recording_obj)
        # update .json files
        register_recording(recdir=path_rec1,
                           output_fname=os.path.join(path_to, rec1 + '.json'),
                           label=rec1)
    study_obj['self_reference'] = ka.store_object(study_obj)
    with open(os.path.join(path_to, studyName + '.json'), 'w') as f:
        json.dump(study_obj, f, indent=4)
    return study_obj
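A hedged usage sketch for register_study(); both directory paths are hypothetical, and each subdirectory of path_from is assumed to be an MDA recording folder containing a firings_true.mda.

study_obj = register_study(
    path_from='/data/studies/synth_magland_noise10',    # hypothetical input folder
    path_to='/data/studies_out/synth_magland_noise10',  # hypothetical output folder for the .json files
    studySetName='SYNTH_MAGLAND',
    studyName='synth_magland_noise10')
print('Registered {} recordings'.format(len(study_obj['recordings'])))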
Example #16
def _serialize_result(result):
    import kachery as ka
    ret = dict(output_files=dict())
    ret['name'] = 'hither_result'

    ret['runtime_info'] = result.runtime_info
    ret['runtime_info']['stdout'] = ka.store_text(
        ret['runtime_info']['stdout'])
    ret['runtime_info']['stderr'] = ka.store_text(
        ret['runtime_info']['stderr'])

    for oname in result._output_names:
        path = getattr(result.outputs, oname)._path
        ret['output_files'][oname] = ka.store_file(path)

    ret['retval'] = result.retval
    ret['hash_object'] = result.hash_object
    ret['hash'] = ka.get_object_hash(result.hash_object)
    return ret
Example #17
 def from_memory(sorting: se.SortingExtractor, serialize=False):
     if serialize:
         with hi.TemporaryDirectory() as tmpdir:
             fname = tmpdir + '/' + _random_string(10) + '_firings.mda'
             MdaSortingExtractor.write_sorting(sorting=sorting, save_path=fname)
             with ka.config(use_hard_links=True):
                 uri = ka.store_file(fname, basename='firings.mda')
             sorting = LabboxEphysSortingExtractor({
                 'sorting_format': 'mda',
                 'data': {
                     'firings': uri,
                     'samplerate': sorting.get_sampling_frequency()
                 }
             })
             return sorting
     obj = {
         'sorting_format': 'in_memory',
         'data': register_in_memory_object(sorting)
     }
     return LabboxEphysSortingExtractor(obj)
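A toy sketch for from_memory() above, assuming it is exposed as a static method of LabboxEphysSortingExtractor (as the leading indentation suggests) and that the old spikeextractors NumPy-based extractors are available; the spike times and labels are made up.

import numpy as np
import spikeextractors as se

# Build a small in-memory sorting with made-up spike times and unit labels.
sorting0 = se.NumpySortingExtractor()
sorting0.set_times_labels(times=np.array([100, 250, 900, 1200]),
                          labels=np.array([1, 2, 1, 2]))
sorting0.set_sampling_frequency(30000)

# Serialize to firings.mda, store it in kachery, and wrap it again.
le_sorting = LabboxEphysSortingExtractor.from_memory(sorting0, serialize=True)
print(le_sorting.get_unit_ids())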
Example #18
def spykingcircus(*,
    recording_object,
    detect_sign=-1,
    adjacency_radius=100,
    detect_threshold=6,
    template_width_ms=3,
    filter=True,
    merge_spikes=True,
    auto_merge=0.75,
    num_workers=None,
    whitening_max_elts=1000,
    clustering_max_elts=10000
):
    import spikesorters as ss

    recording = LabboxEphysRecordingExtractor(recording_object)

    sorting_params = ss.get_default_params('spykingcircus')
    sorting_params['detect_sign'] = detect_sign
    sorting_params['adjacency_radius'] = adjacency_radius
    sorting_params['detect_threshold'] = detect_threshold
    sorting_params['template_width_ms'] = template_width_ms
    sorting_params['filter'] = filter
    sorting_params['merge_spikes'] = merge_spikes
    sorting_params['auto_merge'] = auto_merge
    sorting_params['num_workers'] = num_workers
    sorting_params['whitening_max_elts'] = whitening_max_elts
    sorting_params['clustering_max_elts'] = clustering_max_elts
    print('Using sorting parameters:', sorting_params)
    with hi.TemporaryDirectory() as tmpdir:
        sorting = ss.run_spykingcircus(recording, output_folder=tmpdir + '/sc_output', delete_output_folder=False, verbose=True, **sorting_params)
        h5_output_fname = tmpdir + '/sorting.h5'
        H5SortingExtractorV1.write_sorting(sorting=sorting, save_path=h5_output_fname)
        return {
            'sorting_format': 'h5_v1',
            'data': {
                'h5_path': ka.store_file(h5_output_fname)
            }
        }
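A hedged usage sketch; recording_object is assumed to be a labbox-style recording dict such as the ones produced by the converters in the earlier examples, and SpyKING CIRCUS itself must be installed.

sorting_object = spykingcircus(
    recording_object=recording_object,  # assumed labbox-style recording dict
    detect_sign=-1,
    detect_threshold=6)
print(sorting_object['sorting_format'])   # 'h5_v1'
print(sorting_object['data']['h5_path'])  # kachery URI of sorting.h5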
Example #19
def create_subrecording_object(recording_object, channels, start_frame,
                               end_frame):
    from .bin1recordingextractor import Bin1RecordingExtractor
    recording_format = recording_object['recording_format']
    assert recording_format == 'bin1', f'Unsupported recording format: {recording_format}'
    d = recording_object['data']
    rec = Bin1RecordingExtractor(raw=d['raw'],
                                 num_frames=d['num_frames'],
                                 raw_num_channels=d['raw_num_channels'],
                                 channel_ids=d['channel_ids'],
                                 samplerate=d['samplerate'],
                                 channel_map=d['channel_map'],
                                 channel_positions=d['channel_positions'],
                                 p2p=True)
    rec2 = se.SubRecordingExtractor(parent_recording=rec,
                                    channel_ids=channels,
                                    start_frame=start_frame,
                                    end_frame=end_frame)
    with hi.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.bin'
        rec2.get_traces().astype('int16').tofile(raw_fname)
        new_bin_uri = ka.store_file(raw_fname)
        new_channel_map = dict()
        new_channel_positions = dict()
        for ii, id in enumerate(rec2.get_channel_ids()):
            new_channel_map[str(id)] = ii
            new_channel_positions[str(id)] = rec2.get_channel_locations(
                channel_ids=[id])[0]
        return dict(recording_format='bin1',
                    data=dict(raw=new_bin_uri,
                              raw_num_channels=len(rec2.get_channel_ids()),
                              num_frames=end_frame - start_frame,
                              samplerate=rec2.get_sampling_frequency(),
                              channel_ids=_listify_ndarray(
                                  rec2.get_channel_ids()),
                              channel_map=new_channel_map,
                              channel_positions=new_channel_positions))
Example #20
def prepare_snippets_h5(recording_object,
                        sorting_object,
                        start_frame=None,
                        end_frame=None,
                        max_events_per_unit=None,
                        max_neighborhood_size=15):
    if recording_object['recording_format'] == 'snippets1':
        return recording_object['data']['snippets_h5_uri']

    import labbox_ephys as le
    recording = le.LabboxEphysRecordingExtractor(recording_object)
    sorting = le.LabboxEphysSortingExtractor(sorting_object)

    with hi.TemporaryDirectory() as tmpdir:
        save_path = tmpdir + '/snippets.h5'
        prepare_snippets_h5_from_extractors(
            recording=recording,
            sorting=sorting,
            output_h5_path=save_path,
            start_frame=start_frame,
            end_frame=end_frame,
            max_events_per_unit=max_events_per_unit,
            max_neighborhood_size=max_neighborhood_size)
        return ka.store_file(save_path)
Example #21
    end_frame=None,
    max_events_per_unit=1000,
    max_neighborhood_size=2
)

# Example display some contents of the file
with h5py.File(output_h5_path, 'r') as f:
    unit_ids = np.array(f.get('unit_ids'))
    sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    if np.isnan(sampling_frequency):
        print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
        sampling_frequency = 30000
    print(f'Unit IDs: {unit_ids}')
    print(f'Sampling freq: {sampling_frequency}')
    for unit_id in unit_ids:
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(f'Unit {unit_id} | Tot num events: {len(unit_spike_train)} | shape of subsampled snippets: {unit_waveforms.shape}')

recording = le.LabboxEphysRecordingExtractor({
    'recording_format': 'snippets1',
    'data': {
        'snippets_h5_uri': ka.store_file(output_h5_path)
    }
})
print(f'Channel IDs: {recording.get_channel_ids()}')
print(f'Num. frames: {recording.get_num_frames()}')
for channel_id in recording.get_channel_ids():
    print(f'Channel {channel_id}: {recording.get_channel_property(channel_id, "location")}')
Example #22
def main():
    import spikeextractors as se
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    sw.init_electron()

    # bandpass filter
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval
    
    detect_threshold = 3
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    num_events = 1000
    snippet_len = (200, 200)
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100
    k = 20
    ncomp = 4
    
    R = AutoRecordingExtractor(recobj2)

    X = R.get_traces()
    
    sig = X.copy()
    if detect_sign < 0:
        sig = -sig
    elif detect_sign == 0:
        sig = np.abs(sig)
    sig = np.max(sig, axis=0)
    noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
    times_reference = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval_reference, detect_sign=1, margin=1000)
    times_reference = times_reference[times_reference <= max_t]
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2/(2*window_frac**2))
    for j in range(snippets_reference.shape[0]):
        for m in range(snippets_reference.shape[1]):
            snippets_reference[j, m, :] = snippets_reference[j, m, :] * window0
    A_snippets_reference = snippets_reference.reshape(snippets_reference.shape[0], snippets_reference.shape[1] * snippets_reference.shape[2])

    print('PCA...')
    u, s, vh = np.linalg.svd(A_snippets_reference)
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        sig = X.copy()
        if detect_sign < 0:
            sig = -sig
        elif detect_sign == 0:
            sig = np.abs(sig)
        sig = np.max(sig, axis=0)
        noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
        times = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval, detect_sign=1, margin=1000)
        times = times[times <= max_t]
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        for j in range(snippets.shape[0]):
            for m in range(snippets.shape[1]):
                snippets[j, m, :] = snippets[j, m, :] * window0
        A_snippets = snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])
        features = A_snippets @ components_reference
        
        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            inds0 = np.squeeze(indices[j, :])
            inds0 = inds0[1:] # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :].T)
            features2[j, :] = pls.predict(f_neighbors.T).T
        A_snippets_denoised = features2 @ components_reference.T
        
        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)

        for j in range(len(times)):
            t0 = times[j]
            snippet_denoised_0 = np.squeeze(snippets_denoised[j, :, :])
            X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] = X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] + snippet_denoised_0
            X[:, t0-snippet_len[0]:t0+snippet_len[1]] = X[:, t0-snippet_len[0]:t0+snippet_len[1]] - snippet_denoised_0

    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        writemda32(S, raw_fname)
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)
    
    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()
Example #23
        return None

# studysets_to_include = ['PAIRED_BOYDEN', 'PAIRED_CRCNS_HC1', 'PAIRED_MEA64C_YGER', 'PAIRED_KAMPFF', 'PAIRED_MONOTRODE', 'SYNTH_MONOTRODE', 'SYNTH_MAGLAND', 'SYNTH_MEAREC_NEURONEXUS', 'SYNTH_MEAREC_TETRODE', 'SYNTH_MONOTRODE', 'SYNTH_VISAPY', 'HYBRID_JANELIA', 'MANUAL_FRANKLAB']
studysets_to_include = ['SYNTH_BIONET']
fnames = ['geom.csv', 'params.json', 'raw.mda', 'firings_true.mda']
# fnames = ['geom.csv', 'params.json', 'firings_true.mda']
# fnames = ['geom.csv', 'params.json']
for studyset in X['StudySets']:
    print('STUDYSET: {}'.format(studyset['name']))
    if studyset['name'] in studysets_to_include:
        for study in studyset['studies']:
            study_name = study['name']
            print('STUDY: {}'.format(study_name))
            for recording in study['recordings']:
                recname = recording['name']
                recdir = recording['directory']
                print('RECORDING: {}'.format(recname), recdir)
                sha1 = get_sha1_part_of_sha1dir(recdir)
                if sha1:
                    ff = mt.realizeFile('sha1://' + sha1)
                    print('Storing directory index file: {} for sha1={}'.format(ff, sha1))
                    ka.store_file(ff)
                for fname in fnames:
                    print('Realizing file: {}'.format(recdir + '/' + fname))
                    ff = mt.realizeFile(path=recdir + '/' + fname)
                    if ff:
                        print('Storing file: {}'.format(ff))
                        ka.store_file(ff)
                    else:
                        print('WARNING: could not realize file: {}'.format(recdir + '/' + fname))
Example #24
def store_file(path: str, basename: Union[str, None] = None):
    return ka.store_file(path, basename=basename)
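A round-trip sketch for the thin wrapper above: store a small local file in kachery, then resolve the returned URI back to a cached local path. The file name and contents are hypothetical.

import kachery as ka

with open('/tmp/example.txt', 'w') as f:
    f.write('hello kachery\n')
uri = store_file('/tmp/example.txt', basename='example.txt')
print(uri)                      # e.g. sha1://.../example.txt
local_path = ka.load_file(uri)  # resolves the URI back to a local file
print(local_path)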
Example #25
print('# files: {}'.format(len(list_rec)))
study_obj = dict(name=study_name, studySetName=studyset_name, recordings=[])
for rec1 in list_rec:
    print(f'Uploading {rec1}')
    path_rec1 = os.path.join(path_from, rec1)
    rec = MdaRecordingExtractor(recording_directory=path_rec1)
    sorting = MdaSortingExtractor(firings_file=path_rec1 + '/firings_true.mda',
                                  samplerate=rec.get_sampling_frequency())
    recording_obj = dict(
        name=rec1,
        studyName=study_name,
        studySetName=studyset_name,
        directory=ka.store_dir(path_rec1),
        firingsTrue=ka.store_file(os.path.join(path_to,
                                               rec1 + '.firings_true.json'),
                                  basename='firings_true.json'),
        sampleRateHz=rec.get_sampling_frequency(),
        numChannels=len(rec.get_channel_ids()),
        durationSec=rec.get_num_frames() / rec.get_sampling_frequency(),
        numTrueUnits=len(sorting.get_unit_ids()),
        spikeSign=-1  # TODO: get this from params.json
    )
    study_obj['recordings'].append(recording_obj)

study_obj['self_reference'] = ka.store_object(study_obj)
with open(os.path.join(path_to, study_name + '.json'), 'w') as f:
    json.dump(study_obj, f, indent=4)

studyset_obj = dict(name=studyset_name,
                    info=dict(label=studyset_name,
Example #26
def main():
    from spikeforest2 import sorters
    from spikeforest2 import processing

    parser = argparse.ArgumentParser(
        description='Run the SpikeForest2 main analysis')
    # parser.add_argument('analysis_file', help='Path to the analysis specification file (.json format).')
    # parser.add_argument('--config', help='Configuration file', required=True)
    # parser.add_argument('--output', help='Analysis output file (.json format)', required=True)
    # parser.add_argument('--slurm', help='Optional SLURM configuration file (.json format)', required=False, default=None)
    # parser.add_argument('--verbose', help='Provide some additional verbose output.', action='store_true')
    parser.add_argument(
        'spec',
        help='Path to the .json file containing the analysis specification')
    parser.add_argument('--output',
                        '-o',
                        help='The output .json file',
                        required=True)
    parser.add_argument('--force-run',
                        help='Force rerunning of all spike sorting',
                        action='store_true')
    parser.add_argument(
        '--force-run-all',
        help='Force rerunning of all spike sorting and other processing',
        action='store_true')
    parser.add_argument('--parallel',
                        help='Optional number of parallel jobs',
                        required=False,
                        default='0')
    parser.add_argument('--slurm',
                        help='Path to slurm config file',
                        required=False,
                        default=None)
    parser.add_argument('--cache',
                        help='The cache database to use',
                        required=False,
                        default=None)
    parser.add_argument('--rerun-failing',
                        help='Rerun sorting jobs that previously failed',
                        action='store_true')
    parser.add_argument('--test', help='Only run a few.', action='store_true')
    parser.add_argument('--job-timeout',
                        help='Timeout for sorting jobs',
                        required=False,
                        default=600)
    parser.add_argument('--log-file',
                        help='Log file for analysis progress',
                        required=False,
                        default=None)

    args = parser.parse_args()
    force_run_all = args.force_run_all

    # the following apply to sorting jobs only
    force_run = args.force_run or args.force_run_all
    job_timeout = float(args.job_timeout)
    cache_failing = True
    rerun_failing = args.rerun_failing

    with open(args.spec, 'r') as f:
        spec = json.load(f)

    # clear the log file
    if args.log_file is not None:
        with open(args.log_file, 'w'):
            pass

    studysets_path = spec['studysets']
    studyset_names = spec['studyset_names']
    spike_sorters = spec['spike_sorters']

    ka.set_config(fr='default_readonly')

    print(f'Loading study sets object from: {studysets_path}')
    studysets_obj = ka.load_object(studysets_path)
    if not studysets_obj:
        raise Exception(f'Unable to load: {studysets_path}')

    all_study_sets = studysets_obj['StudySets']
    study_sets = []
    for studyset in all_study_sets:
        if studyset['name'] in studyset_names:
            study_sets.append(studyset)

    if int(args.parallel) > 0:
        job_handler = hither.ParallelJobHandler(int(args.parallel))
        job_handler_gpu = job_handler
        job_handler_ks = job_handler
    elif args.slurm:
        with open(args.slurm, 'r') as f:
            slurm_config = json.load(f)
        job_handler = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                             **slurm_config['cpu'])
        job_handler_gpu = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                                 **slurm_config['gpu'])
        job_handler_ks = hither.SlurmJobHandler(working_dir='tmp_slurm',
                                                **slurm_config['ks'])
    else:
        job_handler = None
        job_handler_gpu = None
        job_handler_ks = None

    with hither.config(container='default',
                       cache=args.cache,
                       force_run=force_run_all,
                       job_handler=job_handler,
                       log_path=args.log_file), hither.job_queue():
        studies = []
        recordings = []
        for studyset in study_sets:
            studyset_name = studyset['name']
            print(f'================ STUDY SET: {studyset_name}')
            studies0 = studyset['studies']
            if args.test:
                studies0 = studies0[:1]
                studyset['studies'] = studies0
            for study in studies0:
                study['study_set'] = studyset_name
                study_name = study['name']
                print(f'======== STUDY: {study_name}')
                recordings0 = study['recordings']
                if args.test:
                    recordings0 = recordings0[:2]
                    study['recordings'] = recordings0
                for recording in recordings0:
                    recording['study'] = study_name
                    recording['study_set'] = studyset_name
                    recording['firings_true'] = recording['firingsTrue']
                    recordings.append(recording)
                studies.append(study)

        # Download recordings
        for recording in recordings:
            ka.load_file(recording['directory'] + '/raw.mda')
            ka.load_file(recording['directory'] + '/firings_true.mda')

        # Attach results objects
        for recording in recordings:
            recording['results'] = dict()

        # Summarize recordings
        for recording in recordings:
            recording_path = recording['directory']
            sorting_true_path = recording['firingsTrue']
            recording['results'][
                'computed-info'] = processing.compute_recording_info.run(
                    _label=
                    f'compute-recording-info:{recording["study"]}/{recording["name"]}',
                    recording_path=recording_path,
                    json_out=hither.File())
            recording['results'][
                'true-units-info'] = processing.compute_units_info.run(
                    _label=
                    f'compute-units-info:{recording["study"]}/{recording["name"]}',
                    recording_path=recording_path,
                    sorting_path=sorting_true_path,
                    json_out=hither.File())

        # Spike sorting
        for sorter in spike_sorters:
            for recording in recordings:
                if recording['study_set'] in sorter['studysets']:
                    recording_path = recording['directory']
                    sorting_true_path = recording['firingsTrue']

                    algorithm = sorter['processor_name']
                    if not hasattr(sorters, algorithm):
                        raise Exception(
                            f'No such sorting algorithm: {algorithm}')
                    Sorter = getattr(sorters, algorithm)

                    if algorithm in ['ironclust']:
                        gpu = True
                        jh = job_handler_gpu
                    elif algorithm in ['kilosort', 'kilosort2']:
                        gpu = True
                        jh = job_handler_ks
                    else:
                        gpu = False
                        jh = job_handler
                    with hither.config(gpu=gpu,
                                       force_run=force_run,
                                       exception_on_fail=False,
                                       cache_failing=cache_failing,
                                       rerun_failing=rerun_failing,
                                       job_handler=jh,
                                       job_timeout=job_timeout):
                        sorting_result = Sorter.run(
                            _label=
                            f'{algorithm}:{recording["study"]}/{recording["name"]}',
                            recording_path=recording['directory'],
                            sorting_out=hither.File())
                        recording['results']['sorting-' +
                                             sorter['name']] = sorting_result
                    recording['results'][
                        'comparison-with-truth-' +
                        sorter['name']] = processing.compare_with_truth.run(
                            _label=
                            f'comparison-with-truth:{algorithm}:{recording["study"]}/{recording["name"]}',
                            sorting_path=sorting_result.outputs.sorting_out,
                            sorting_true_path=sorting_true_path,
                            json_out=hither.File())
                    recording['results'][
                        'units-info-' +
                        sorter['name']] = processing.compute_units_info.run(
                            _label=
                            f'units-info:{algorithm}:{recording["study"]}/{recording["name"]}',
                            recording_path=recording_path,
                            sorting_path=sorting_result.outputs.sorting_out,
                            json_out=hither.File())

    # Assemble all of the results
    print('')
    print('=======================================================')
    print('Assembling results...')
    for recording in recordings:
        print(
            f'Assembling recording: {recording["study"]}/{recording["name"]}')
        recording['summary'] = dict(
            plots=dict(),
            computed_info=ka.load_object(
                recording['results']['computed-info'].outputs.json_out._path),
            true_units_info=ka.store_file(
                recording['results']
                ['true-units-info'].outputs.json_out._path))
    sorting_results = []
    for sorter in spike_sorters:
        for recording in recordings:
            if recording['study_set'] in sorter['studysets']:
                print(
                    f'Assembling sorting: {sorter["processor_name"]} {recording["study"]}/{recording["name"]}'
                )
                sorting_result = recording['results']['sorting-' +
                                                      sorter['name']]
                comparison_result = recording['results'][
                    'comparison-with-truth-' + sorter['name']]
                units_info_result = recording['results']['units-info-' +
                                                         sorter['name']]
                console_out_str = _console_out_to_str(
                    sorting_result.runtime_info['console_out'])
                console_out_path = ka.store_text(console_out_str)
                sr = dict(
                    recording=recording,
                    sorter=sorter,
                    firings_true=recording['directory'] + '/firings_true.mda',
                    processor_name=sorter['processor_name'],
                    processor_version=sorting_result.version,
                    execution_stats=dict(
                        start_time=sorting_result.runtime_info['start_time'],
                        end_time=sorting_result.runtime_info['end_time'],
                        elapsed_sec=sorting_result.runtime_info['end_time'] -
                        sorting_result.runtime_info['start_time'],
                        retcode=0 if sorting_result.success else -1,
                        timed_out=sorting_result.runtime_info.get(
                            'timed_out', False)),
                    container=sorting_result.container,
                    console_out=console_out_path)
                if sorting_result.success:
                    sr['firings'] = ka.store_file(
                        sorting_result.outputs.sorting_out._path)
                    sr['comparison_with_truth'] = dict(json=ka.store_file(
                        comparison_result.outputs.json_out._path))
                    sr['sorted_units_info'] = ka.store_file(
                        units_info_result.outputs.json_out._path)
                else:
                    sr['firings'] = None
                    sr['comparison_with_truth'] = None
                    sr['sorted_units_info'] = None
                sorting_results.append(sr)

    # Delete results from recordings
    for recording in recordings:
        del recording['results']

    # Aggregate sorting results
    print('')
    print('=======================================================')
    print('Aggregating sorting results...')
    aggregated_sorting_results = aggregate_sorting_results(
        studies, recordings, sorting_results)

    # Show output summary
    for sr in aggregated_sorting_results['study_sorting_results']:
        study_name = sr['study']
        sorter_name = sr['sorter']
        n1 = np.array(sr['num_matches'])
        n2 = np.array(sr['num_false_positives'])
        n3 = np.array(sr['num_false_negatives'])
        accuracies = n1 / (n1 + n2 + n3)
        avg_accuracy = np.mean(accuracies)
        txt = 'STUDY: {}, SORTER: {}, AVG ACCURACY: {}'.format(
            study_name, sorter_name, avg_accuracy)
        print(txt)

    output_object = dict(studies=studies,
                         recordings=recordings,
                         study_sets=study_sets,
                         sorting_results=sorting_results,
                         aggregated_sorting_results=ka.store_object(
                             aggregated_sorting_results,
                             basename='aggregated_sorting_results.json'))

    print(f'Writing output to {args.output}...')
    with open(args.output, 'w') as f:
        json.dump(output_object, f, indent=4)
    print('Done.')
Example #27
    def dumpDataset(self, uuid):
        exclude_groups = self.options.get('exclude_groups', None) or []
        response = {}
        self.log.info("dumpDataset: " + uuid)
        item = self.db.getDatasetItemByUuid(uuid)
        if 'alias' in item:
            alias = item['alias']
            if alias:
                self.log.info("dumpDataset alias: [" + alias[0] + "]")
                for a in alias:
                    for e in exclude_groups:
                        if e == a or a.startswith(e + '/'):
                            return None
            response['alias'] = item['alias']
        else:
            alias = None

        typeItem = item['type']
        response['type'] = getTypeResponse(typeItem)
        shapeItem = item['shape']
        shape_rsp = {}
        num_elements = 1
        shape_rsp['class'] = shapeItem['class']
        if 'dims' in shapeItem:
            shape_rsp['dims'] = shapeItem['dims']
            for dim in shapeItem['dims']:
                num_elements *= dim
        if 'maxdims' in shapeItem:
            maxdims = []
            for dim in shapeItem['maxdims']:
                if dim == 0:
                    maxdims.append('H5S_UNLIMITED')
                else:
                    maxdims.append(dim)
            shape_rsp['maxdims'] = maxdims
        response['shape'] = shape_rsp

        if 'creationProperties' in item:
            response['creationProperties'] = item['creationProperties']

        attributes = self.dumpAttributes('datasets', uuid)
        if attributes:
            response['attributes'] = attributes

        shape_class = shapeItem['class']
        # if (not (self.options.D or self.options.d)) or (shape_class == 'H5S_SCALAR'):
        alias_name = None
        if alias:
            alias_name = alias[0].split('/')[-1]
        include_value = False
        write_binary = False
        include_dataset_names = self.options.get('include_dataset_names',
                                                 None) or []
        if (self.options.get('include_datasets', False)) or (
                alias_name in include_dataset_names) or (shape_class
                                                         == 'H5S_SCALAR'):
            include_value = True
        else:
            dset = self.db.getDatasetObjByUuid(uuid)
            type_info = getTypeItem(dset.dtype)
            type_class = type_info['class']
            if type_class == 'H5T_ARRAY':
                write_binary = True
            elif type_class == 'H5T_FLOAT':
                write_binary = True
            elif type_class == 'H5T_INTEGER':
                write_binary = True
            ###################################################################
            elif type_class == 'H5T_STRING':
                include_value = True
            elif type_class == 'H5T_ENUM':
                include_value = True
            elif type_class == 'H5T_REFERENCE':
                include_value = True
            ###################################################################
            elif type_class == 'H5T_COMPOUND':
                raise Exception('Unsupported type class {}: {}'.format(
                    type_class, alias))
            elif type_class == 'H5T_VLEN':
                raise Exception('Unsupported type class {}: {}'.format(
                    type_class, alias))
            elif type_class == 'H5T_OPAQUE':
                raise Exception('Unsupported type class {}: {}'.format(
                    type_class, alias))
            else:
                raise Exception('Unsupported type class {}: {}'.format(
                    type_class, alias))

        if include_value:
            if num_elements > 0:
                value = self.db.getDatasetValuesByUuid(uuid)
                response[
                    'value'] = value  # dump values unless header flag was passed
            else:
                response['value'] = []  # empty list
        if write_binary:
            if self.options['data_dir']:
                print('Writing {} with shape {}'.format(
                    alias, shapeItem.get('dims')))
                fname = self.options['data_dir'] + '/' + _random_string(
                    10) + '.dat.tmp'
                with open(fname, 'wb') as f:
                    self.db.getDatasetValuesByUuid(uuid,
                                                   format='binary',
                                                   tofile=f)
                sha1 = _compute_file_hash(path=fname, algorithm='sha1')
                fname2 = self.options['data_dir'] + '/sha1_' + sha1 + '.dat'
                os.rename(fname, fname2)
                if self.options['use_kachery']:
                    try:
                        import kachery as ka
                    except ImportError:
                        raise Exception(
                            'Kachery is not installed. Try "pip install --upgrade kachery".'
                        )
                    ka.store_file(fname2)
                response['valueHash'] = dict(sha1=sha1)
            else:
                if self.options['use_kachery']:
                    raise Exception('Cannot use kachery without data_dir')
        return response
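If the binary values were pushed to kachery (the use_kachery branch above), the sha1 recorded in valueHash can later be resolved back to a local file. A minimal sketch, assuming response is the dict returned by dumpDataset() for such a dataset and that kachery is configured the same way.

import kachery as ka

sha1 = response['valueHash']['sha1']
local_path = ka.load_file('sha1://' + sha1)
print('Binary dataset values available at: {}'.format(local_path))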