Example #1
    def __init__(self,
                 *,
                 recording_directory=None,
                 timeseries_path=None,
                 download=False,
                 samplerate=None,
                 geom=None,
                 geom_path=None,
                 params_path=None):
        RecordingExtractor.__init__(self)
        if recording_directory:
            timeseries_path = recording_directory + '/raw.mda'
            geom_path = recording_directory + '/geom.csv'
            params_path = recording_directory + '/params.json'
        self._timeseries_path = timeseries_path
        if params_path:
            self._dataset_params = ka.load_object(params_path)
            self._samplerate = self._dataset_params['samplerate']
        else:
            self._dataset_params = dict(samplerate=samplerate)
            self._samplerate = samplerate

        if download:
            path0 = ka.load_file(path=self._timeseries_path)
            if not path0:
                raise Exception('Unable to realize file: ' +
                                self._timeseries_path)
            self._timeseries_path = path0

        self._timeseries = DiskReadMda(self._timeseries_path)
        if self._timeseries is None:
            raise Exception('Unable to load timeseries: {}'.format(
                self._timeseries_path))
        X = self._timeseries
        if geom is not None:
            self._geom = geom
        elif geom_path:
            geom_path2 = ka.load_file(geom_path)
            self._geom = np.genfromtxt(geom_path2, delimiter=',')
        else:
            self._geom = np.zeros((X.N1(), 2))

        if self._geom.shape[0] != X.N1():
            # raise Exception(
            #    'Incompatible dimensions between geom.csv and timeseries file {} <> {}'.format(self._geom.shape[0], X.N1()))
            print(
                'WARNING: Incompatible dimensions between geom.csv and timeseries file {} <> {}'
                .format(self._geom.shape[0], X.N1()))
            self._geom = np.zeros((X.N1(), 2))

        self._hash = ka.get_object_hash(
            dict(timeseries=ka.get_file_hash(self._timeseries_path),
                 samplerate=self._samplerate,
                 geom=_json_serialize(self._geom)))

        self._num_channels = X.N1()
        self._num_timepoints = X.N2()
        for m in range(self._num_channels):
            self.set_channel_property(m, 'location', self._geom[m, :])
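The geometry fallback above can be exercised on its own. Below is a minimal sketch (plain numpy, with a hypothetical geom.csv path) of loading an electrode geometry file and validating it against the channel count, mirroring the warning-and-zeros behavior in the constructor:

import numpy as np

def load_geom_or_default(geom_csv_path, num_channels):
    # Load an M x 2 array of electrode locations; fall back to zeros on mismatch.
    geom = np.genfromtxt(geom_csv_path, delimiter=',')
    if geom.ndim == 1:
        geom = geom.reshape((1, -1))  # single-channel file
    if geom.shape[0] != num_channels:
        print('WARNING: geom has {} rows but recording has {} channels'.format(
            geom.shape[0], num_channels))
        geom = np.zeros((num_channels, 2))
    return geom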
Example #2
def _load_file_from_file_server(*, uri, dest, file_server_url):
    protocol, algorithm, hash0, additional_path, query = _parse_kachery_uri(
        uri)
    if query.get('manifest'):
        manifest = load_object(f'sha1://{query["manifest"][0]}')
        if manifest is None:
            print('Unable to load manifest')
            return None
        assert manifest['sha1'] == hash0, 'Manifest sha1 does not match expected.'
        chunk_local_paths = []
        for ii, ch in enumerate(manifest['chunks']):
            if len(manifest['chunks']) > 1:
                print(
                    f'load_bytes: Loading chunk {ii + 1} of {len(manifest["chunks"])}'
                )
            a = load_file(
                uri=f'sha1://{ch["sha1"]}?chunkOf={hash0}~{ch["start"]}~{ch["end"]}')
            if a is None:
                print('Unable to load data from chunk')
                return None
            chunk_local_paths.append(a)
        with TemporaryDirectory() as tmpdir:
            concat_fname = f'{tmpdir}/concat_{hash0}'
            print('Concatenating chunks...')
            sha1_concat = _concatenate_files_and_compute_sha1(
                paths=chunk_local_paths, dest=concat_fname)
            assert sha1_concat == hash0, f'Unexpected sha1 of concatenated file: {sha1_concat} <> {hash0}'
            ka.core._store_local_file_in_cache(path=concat_fname,
                                               hash=sha1_concat,
                                               algorithm='sha1',
                                               config=ka.core._load_config())
            ff = ka.load_file('sha1://' + hash0)
            assert ff is not None, f'Unexpected problem. Unable to load file after storing in local cache: sha1://{hash0}'
            return ff
    if query.get('chunkOf'):
        chunkOf_str = query.get('chunkOf')[0]
    else:
        chunkOf_str = None

    with TemporaryDirectory() as tmpdir:
        tmp_fname = tmpdir + f'/download_{hash0}'
        url = f'{file_server_url}/sha1/{hash0}'
        if chunkOf_str is not None:
            url = url + f'?chunkOf={chunkOf_str}'
        sha1 = _download_file_and_compute_sha1(url=url, fname=tmp_fname)
        assert hash0 == sha1, f'Unexpected sha1 of downloaded file: {sha1} <> {hash0}'
        # todo: think about how to do this without calling internal (private) function
        ka.core._store_local_file_in_cache(path=tmp_fname,
                                           hash=sha1,
                                           algorithm='sha1',
                                           config=ka.core._load_config())
        ff = ka.load_file('sha1://' + sha1)
        assert ff is not None, f'Unexpected problem. Unable to load file after storing in local cache: sha1://{hash0}'
        return ff
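The manifest branch relies on the private helper `_concatenate_files_and_compute_sha1`. As a rough sketch of what such a helper might look like (an assumption, not this library's actual code), one can stream each chunk into the destination file while updating a running SHA-1:

import hashlib

def concatenate_files_and_compute_sha1(paths, dest):
    # Append each chunk to dest while computing the SHA-1 of the concatenation.
    sha1 = hashlib.sha1()
    with open(dest, 'wb') as out:
        for p in paths:
            with open(p, 'rb') as f:
                while True:
                    buf = f.read(1024 * 1024)
                    if not buf:
                        break
                    sha1.update(buf)
                    out.write(buf)
    return sha1.hexdigest()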
Example #3
def fetch_spike_waveforms(snippets_h5, unit_ids, spike_indices):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    spikes = []
    with h5py.File(h5_path, 'r') as f:
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        for ii, unit_id in enumerate(unit_ids):
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))
            unit_waveforms_channel_ids = np.array(
                f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            unit_waveforms_spike_train = np.array(
                f.get(f'unit_waveforms/{unit_id}/spike_train'))
            average_waveform = np.mean(unit_waveforms, axis=0)
            channel_maximums = np.max(np.abs(average_waveform), axis=1)
            maxchan_index = np.argmax(channel_maximums)
            maxchan_id = unit_waveforms_channel_ids[maxchan_index]
            for spike_index in spike_indices[ii]:
                spikes.append(
                    dict(unit_id=unit_id,
                         spike_index=spike_index,
                         spike_time=unit_waveforms_spike_train[spike_index],
                         channel_id=maxchan_id,
                         waveform=unit_waveforms[
                             spike_index,
                             maxchan_index, :].squeeze().tolist()))
    return {'sampling_frequency': sampling_frequency, 'spikes': spikes}
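The peak-channel logic used above (mean over spikes, absolute maximum over time, argmax over channels) is plain numpy and can be isolated as a small sketch, assuming waveforms are shaped (num_spikes, num_channels, num_timepoints):

import numpy as np

def peak_channel_index(waveforms):
    # waveforms: num_spikes x num_channels x num_timepoints
    average_waveform = np.mean(waveforms, axis=0)                 # channels x time
    channel_maximums = np.max(np.abs(average_waveform), axis=1)   # per-channel peak
    return int(np.argmax(channel_maximums))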
Example #4
def get_similar_units(snippets_h5):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        unit_infos = {}
        for unit_id in unit_ids:
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))
            channel_ids = np.array(
                f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            average_waveform = np.mean(unit_waveforms, axis=0)
            channel_maximums = np.max(np.abs(average_waveform), axis=1)
            maxchan_index = np.argmax(channel_maximums)
            maxchan_id = int(channel_ids[maxchan_index])
            unit_infos[int(unit_id)] = dict(unit_id=unit_id,
                                            average_waveform=average_waveform,
                                            channel_ids=channel_ids,
                                            peak_channel_id=maxchan_id)
        ret = {}
        for id1 in unit_ids:
            x = [
                dict(unit_id=id2,
                     similarity=_get_similarity_score(unit_infos[int(id1)],
                                                      unit_infos[int(id2)]))
                for id2 in unit_ids if (id2 != id1)
            ]
            x.sort(key=lambda a: a['similarity'], reverse=True)
            x = [a for a in x if a['similarity'] >= 0.2]
            ret[str(id1)] = x
        return ret
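`_get_similarity_score` is not shown in this example. A plausible, purely hypothetical implementation compares the average waveforms of two units on their common channels via a normalized inner product:

import numpy as np

def get_similarity_score(unit_info_1, unit_info_2):
    # Hypothetical: cosine-style similarity of average waveforms on shared channels.
    common = np.intersect1d(unit_info_1['channel_ids'], unit_info_2['channel_ids'])
    if common.size == 0:
        return 0.0
    inds1 = [np.where(unit_info_1['channel_ids'] == ch)[0][0] for ch in common]
    inds2 = [np.where(unit_info_2['channel_ids'] == ch)[0][0] for ch in common]
    w1 = unit_info_1['average_waveform'][inds1, :].ravel()
    w2 = unit_info_2['average_waveform'][inds2, :].ravel()
    denom = np.linalg.norm(w1) * np.linalg.norm(w2)
    return float(np.dot(w1, w2) / denom) if denom > 0 else 0.0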
Example #5
def fetch_spike_amplitudes(snippets_h5, unit_id):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    # with h5py.File(h5_path, 'r') as f:
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))

    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)
    average_waveform = np.mean(unit_waveforms, axis=0)
    peak_channel_index = _compute_peak_channel_index_from_average_waveform(
        average_waveform)
    maxs = [
        np.max(unit_waveforms[i][peak_channel_index, :])
        for i in range(unit_waveforms.shape[0])
    ]
    mins = [
        np.min(unit_waveforms[i][peak_channel_index, :])
        for i in range(unit_waveforms.shape[0])
    ]
    peak_amplitudes = np.array([maxs[i] - mins[i] for i in range(len(mins))])

    timepoints = unit_spike_train.astype(np.float32)
    amplitudes = peak_amplitudes.astype(np.float32)

    sort_inds = np.argsort(timepoints)
    timepoints = timepoints[sort_inds]
    amplitudes = amplitudes[sort_inds]

    return dict(timepoints=timepoints, amplitudes=amplitudes)
Example #6
def get_sorting_unit_snippets(snippets_h5, unit_id, time_range,
                              max_num_snippets):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get(f'channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     if np.isnan(sampling_frequency):
    #         print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
    #         sampling_frequency = 30000
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     print(unit_waveforms_channel_ids)
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)

    snippets = [{
        'index': j,
        'unitId': unit_id,
        'waveform': unit_waveforms[j].astype(np.float32),
        'timepoint': float(unit_spike_train[j])
    } for j in range(unit_waveforms.shape[0])
                if time_range['min'] <= unit_spike_train[j]
                and unit_spike_train[j] < time_range['max']]

    return dict(channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations0.astype(np.float32),
                sampling_frequency=sampling_frequency,
                snippets=snippets[:max_num_snippets])
Example #7
 def __init__(self, arg, samplerate=None):
     super().__init__()
     self._hash = None
     if isinstance(arg, se.SortingExtractor):
         self._sorting = arg
         self.copy_unit_properties(sorting=self._sorting)
     else:
         self._sorting = None
         if type(arg) == str:
             arg = dict(path=arg, samplerate=samplerate)
         if type(arg) == dict:
             if 'kachery_config' in arg:
                 ka.set_config(**arg['kachery_config'])
             if 'path' in arg:
                 path = arg['path']
                 if ka.get_file_info(path):
                     file_path = ka.load_file(path)
                     if not file_path:
                         raise Exception(
                             'Unable to realize file: {}'.format(path))
                     self._init_from_file(file_path,
                                          original_path=path,
                                          kwargs=arg)
                 else:
                     raise Exception('Not a file: {}'.format(path))
             else:
                 raise Exception('Unable to initialize sorting extractor')
         else:
             raise Exception(
                 'Unable to initialize sorting extractor (unexpected type)')
Example #8
 def javascript_state_changed(self, prev_state, state):
     self.set_python_state(dict(status='running', status_message='Running'))
     path = state.get('path', None)
     if path:
         self.set_python_state(dict(status_message='Realizing file: {}'.format(path)))
         if path.endswith('.csv'):
             path2 = ka.load_file(path)
             if not path2:
                 self.set_python_state(dict(
                     status='error',
                     status_message='Unable to realize file: {}'.format(path)
                 ))
                 return
             self.set_python_state(dict(status_message='Loading locations'))
             x = np.genfromtxt(path2, delimiter=',')
             locations = x.T
             num_elec = x.shape[0]
             labels = ['{}'.format(a) for a in range(1, num_elec + 1)]
         else:
             raise Exception('Unexpected file type for {}'.format(path))
     else:
         locations = [[0, 0], [1, 0], [1, 1], [2, 1]]
         labels = ['1', '2', '3', '4']
     state = dict()
     state['locations'] = locations
     state['labels'] = labels
     state['status'] = 'finished'
     self.set_python_state(state)
Example #9
def get_sorting_unit_info(snippets_h5, unit_id):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        channel_ids = np.array(f.get('channel_ids'))
        channel_locations = np.array(f.get(f'channel_locations'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        unit_waveforms_channel_ids = np.array(
            f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(unit_waveforms_channel_ids)

    channel_locations0 = []
    for ch_id in unit_waveforms_channel_ids:
        ind = np.where(channel_ids == ch_id)[0]
        channel_locations0.append(channel_locations[ind, :].ravel().tolist())

    return dict(channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations0,
                sampling_frequency=sampling_frequency)
Example #10
def _internal_deserialize_result(obj):
    import kachery as ka
    result = Result()

    result.runtime_info = obj['runtime_info']
    result.runtime_info['console_out'] = ka.load_object(
        result.runtime_info.get('console_out', ''))
    if result.runtime_info['console_out'] is None:
        return None

    output_files = obj['output_files']
    for oname, path in output_files.items():
        if path is not None:
            path2 = ka.load_file(path)
            if path2 is None:
                print('Unable to find file when deserializing result.')
                return None
        else:
            path2 = None
        setattr(result.outputs, oname, File(path2))
        result._output_names.append(oname)

    result.retval = obj['retval']
    result.success = obj.get('success', False)
    result.version = obj.get('version', None)
    result.container = obj.get('container', None)
    result.hash_object = obj['hash_object']
    result.status = obj['status']
    return result
Example #11
def fetch_average_waveform_plot_data(snippets_h5, unit_id):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(
            f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(unit_waveforms_channel_ids)

    average_waveform = np.mean(unit_waveforms, axis=0)
    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = unit_waveforms_channel_ids[maxchan_index]

    return dict(channel_id=int(maxchan_id),
                sampling_frequency=sampling_frequency,
                average_waveform=average_waveform[maxchan_index, :].astype(
                    float).tolist())
Example #12
def individual_cluster_features(snippets_h5, unit_id, max_num_events=1000):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        channel_ids = np.array(f.get('channel_ids'))
        channel_locations = np.array(f.get(f'channel_locations'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(
            f.get(f'unit_waveforms/{unit_id}/waveforms'))  # L x M x T
        unit_waveforms_channel_ids = np.array(
            f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        if len(unit_spike_train) > max_num_events:
            inds = subsample_inds(len(unit_spike_train), max_num_events)
            unit_spike_train = unit_spike_train[inds]
            unit_waveforms = unit_waveforms[inds]

    from sklearn.decomposition import PCA
    nf = 2  # number of features
    W = unit_waveforms  # L x M x T
    # subtract mean for each channel and waveform
    for i in range(W.shape[0]):
        for m in range(W.shape[1]):
            W[i, m, :] = W[i, m, :] - np.mean(W[i, m, :])
    X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))  # L x MT
    pca = PCA(n_components=nf)
    pca.fit(X)
    features = pca.transform(X)  # L x nf

    return dict(timepoints=unit_spike_train.astype(np.float32),
                x=features[:, 0].squeeze().astype(np.float32),
                y=features[:, 1].squeeze().astype(np.float32))
Example #13
def fetch_pca_features(snippets_h5, unit_ids):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        x = [
            dict(
                unit_id=unit_id,
                unit_waveforms_spike_train=np.array(
                    f.get(f'unit_waveforms/{unit_id}/spike_train')),
                # unit_waveforms_spike_train=_subsample(np.array(f.get(f'unit_spike_trains/{unit_id}')), 1000),
                unit_waveforms=np.array(
                    f.get(f'unit_waveforms/{unit_id}/waveforms')),
                unit_waveforms_channel_ids=np.array(
                    f.get(f'unit_waveforms/{unit_id}/channel_ids')))
            for unit_id in unit_ids
        ]
        channel_ids = _intersect_channel_ids(
            [a['unit_waveforms_channel_ids'] for a in x])
        assert len(channel_ids) > 0, 'No channel ids in intersection'
        for a in x:
            unit_waveforms = a['unit_waveforms']
            unit_waveforms_channel_ids = a['unit_waveforms_channel_ids']
            inds = [
                np.where(unit_waveforms_channel_ids == ch_id)[0][0]
                for ch_id in channel_ids
            ]
            a['unit_waveforms_2'] = unit_waveforms[:, inds, :]
            a['labels'] = np.ones((unit_waveforms.shape[0], )) * a['unit_id']

    unit_waveforms = np.concatenate([a['unit_waveforms_2'] for a in x], axis=0)
    spike_train = np.concatenate([a['unit_waveforms_spike_train'] for a in x])
    labels = np.concatenate([a['labels'] for a in x]).astype(int)

    from sklearn.decomposition import PCA

    nf = 5  # number of features

    # list of arrays
    W = unit_waveforms  # ntot x M x T

    # ntot x MT
    X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))

    pca = PCA(n_components=nf)
    pca.fit(X)

    features = pca.transform(X)  # n x nf

    return dict(
        times=(spike_train / sampling_frequency).tolist(),
        features=[features[:, ii].squeeze().tolist() for ii in range(nf)],
        labels=labels.tolist())
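The PCA step above flattens each snippet into one row before fitting. A self-contained sketch on synthetic data (shapes chosen arbitrarily) shows the same reshape-then-fit pattern:

import numpy as np
from sklearn.decomposition import PCA

num_spikes, num_channels, num_timepoints = 200, 8, 50
W = np.random.randn(num_spikes, num_channels, num_timepoints)   # synthetic snippets
X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))            # num_spikes x (channels*time)
pca = PCA(n_components=5)
features = pca.fit_transform(X)                                 # num_spikes x 5
print(features.shape)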
Example #14
def _handle_temporary_outputs(outputs):
    import kachery as ka
    for output in outputs:
        if output._is_temporary:
            old_path = output._path
            new_path = ka.load_file(ka.store_file(old_path))
            output._path = new_path
            output._is_temporary = False
            os.unlink(old_path)
Example #15
def main():
    parser = argparse.ArgumentParser(description=help_txt, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('path', help='Path to the assembled website data in .json format')

    args = parser.parse_args()

    print('Loading spike-front results object...')
    with open(args.path, 'r') as f:
        obj = json.load(f)

    SortingResults = obj['SortingResults']

    for sr in SortingResults:
        print(sr['studyName'], sr['recordingName'])
        console_out_fname = ka.load_file(sr['consoleOut'])
        mt.createSnapshot(path=console_out_fname, upload_to='spikeforest.public')
        if sr.get('firings', None) is not None:
            firings_fname = ka.load_file(sr['firings'])
            mt.createSnapshot(path=firings_fname, upload_to='spikeforest.public')
Example #16
def _handle_temporary_outputs(outputs: List[File]):
    import kachery as ka
    for output in outputs:
        if not output._exists:
            old_path = output._path
            new_path = ka.load_file(ka.store_file(old_path))
            output._path = new_path
            output._is_temporary = False
            output._exists = True
            os.unlink(old_path)
Example #17
def test_1(tmp_path, datajoint_server):
    from nwb_datajoint.data_import import insert_sessions
    from nwb_datajoint.common import Session, Device, Probe
    tmpdir = str(tmp_path)
    os.environ['NWB_DATAJOINT_BASE_DIR'] = tmpdir + '/nwb-data'
    os.environ['KACHERY_STORAGE_DIR'] = tmpdir + '/nwb-data/kachery-storage'
    os.mkdir(os.environ['NWB_DATAJOINT_BASE_DIR'])
    os.mkdir(os.environ['KACHERY_STORAGE_DIR'])

    nwb_fname = os.environ['NWB_DATAJOINT_BASE_DIR'] + '/test.nwb'

    with ka.config(fr='default_readonly'):
        ka.load_file(
            'sha1://8ed68285c327b3766402ee75730d87994ac87e87/beans20190718_no_eseries_no_behavior.nwb',
            dest=nwb_fname)

    with pynwb.NWBHDF5IO(path=nwb_fname, mode='r') as io:
        nwbf = io.read()

    insert_sessions(['test.nwb'])

    x = (Session() & {'nwb_file_name': 'test.nwb'}).fetch1()
    assert x['nwb_file_name'] == 'test.nwb'
    assert x['subject_id'] == 'Beans'
    assert x['institution_name'] == 'University of California, San Francisco'
    assert x['lab_name'] == 'Loren Frank'
    assert x['session_id'] == 'beans_01'
    assert x['session_description'] == 'Reinforcement leaarning'
    assert x['session_start_time'] == datetime(2019, 7, 18, 15, 29, 47)
    assert x['timestamps_reference_time'] == datetime(1970, 1, 1, 0, 0)
    assert x['experiment_description'] == 'Reinforcement learning'

    x = Device().fetch()
    # No devices?
    assert len(x) == 0

    x = Probe().fetch()
    assert len(x) == 1
    assert x[0]['probe_type'] == '128c-4s8mm6cm-20um-40um-sl'
    assert x[0]['probe_description'] == '128 channel polyimide probe'
    assert x[0]['num_shanks'] == 4
    assert x[0]['contact_side_numbering'] == 'True'
Example #18
    def __init__(self, firings_file, samplerate):
        SortingExtractor.__init__(self)
        self._firings_path = ka.load_file(firings_file)
        if not self._firings_path:
            raise Exception('Unable to load firings file: ' + firings_file)

        self._firings = readmda(self._firings_path)
        self._sampling_frequency = samplerate
        self._times = self._firings[1, :]
        self._labels = self._firings[2, :]
        self._unit_ids = np.unique(self._labels).astype(int)
Example #19
def register_recording(*, recdir, output_fname, label, to):
    with ka.config(to=to):
        raw_path = ka.store_file(recdir + '/raw.mda')
        obj = dict(raw=raw_path,
                   params=ka.load_object(recdir + '/params.json'),
                   geom=np.genfromtxt(ka.load_file(recdir + '/geom.csv'),
                                      delimiter=',').tolist())
        obj['self_reference'] = ka.store_object(
            obj, basename='{}.json'.format(label))
        with open(output_fname, 'w') as f:
            json.dump(obj, f, indent=4)
Example #20
def get_unit_snrs(snippets_h5):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    ret = {}
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        for unit_id in unit_ids:
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))  # n x M x T
            ret[str(unit_id)] = _compute_unit_snr_from_waveforms(
                unit_waveforms)
    return ret
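`_compute_unit_snr_from_waveforms` is not defined in this snippet. One common convention (an assumption here, not necessarily what this repository uses) is the peak-to-peak amplitude of the average waveform divided by a robust, MAD-based noise estimate:

import numpy as np

def compute_unit_snr_from_waveforms(unit_waveforms):
    # unit_waveforms: n x M x T. Hypothetical SNR: peak-to-peak of the mean waveform
    # over a MAD-based estimate of the noise standard deviation.
    average_waveform = np.mean(unit_waveforms, axis=0)
    signal = np.max(average_waveform) - np.min(average_waveform)
    residuals = unit_waveforms - average_waveform
    noise = 1.4826 * np.median(np.abs(residuals - np.median(residuals)))
    return float(signal / noise) if noise > 0 else 0.0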
Example #21
def get_peak_channels(snippets_h5):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    ret = {}
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        for unit_id in unit_ids:
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))  # n x M x T
            channel_ids = np.array(
                f.get(f'unit_waveforms/{unit_id}/channel_ids'))  # n
            peak_channel_index = _compute_peak_channel_index_from_waveforms(
                unit_waveforms)
            ret[str(unit_id)] = int(channel_ids[peak_channel_index])
    return ret
Example #22
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running clustering')

        alg_name = state.get('alg_name', 'none')
        alg_arguments = state.get('alg_arguments', dict())
        kachery_config = state.get('kachery_config', None)
        args0 = alg_arguments.get(alg_name, {})

        if kachery_config:
            ka.set_config(**kachery_config)

        dirname = os.path.dirname(os.path.realpath(__file__))
        fname = os.path.join(dirname, 'clustering_datasets.json')
        with open(fname, 'r') as f:
            datasets = json.load(f)
        timer = time.time()
        for ds in datasets['datasets']:
            self._set_status('running', 'Running: {}'.format(ds['path']))
            print('Loading {}'.format(ds['path']))
            path2 = ka.load_file(ds['path'])
            ka.store_file(path2)
            if path2:
                ds['data'] = self._load_dataset_data(path2)
                if alg_name:
                    print('Clustering...')
                    ds['algName'] = alg_name
                    ds['algArgs'] = args0
                    timer0 = time.time()
                    ds['labels'] = self._do_clustering(
                        ds['data'], alg_name, args0,
                        dict(true_num_clusters=ds['trueNumClusters']))
                    elapsed0 = time.time() - timer0
                    if alg_name != 'none':
                        ds['elapsed'] = elapsed0
            else:
                print('Unable to realize file: {}'.format(ds['path']))
            elapsed = time.time() - timer
            if elapsed > 0.1:
                self._set_state(algorithms=self._algorithms, datasets=datasets)

        self._set_state(algorithms=self._algorithms, datasets=datasets)

        self._set_status('finished', 'Finished clustering')
Example #23
def readmda(path):
    if (file_extension(path) == '.npy'):
        return readnpy(path)
    path = ka.load_file(path)
    H = _read_header(path)
    if (H is None):
        print("Problem reading header of: {}".format(path))
        return None
    ret = np.array([])
    f = open(path, "rb")
    try:
        f.seek(H.header_size)
        # This is how I do the column-major order
        ret = np.fromfile(f, dtype=H.dt, count=H.dimprod)
        ret = np.reshape(ret, H.dims, order='F')
        f.close()
        return ret
    except Exception as e:  # catch *all* exceptions
        print(e)
        f.close()
        return None
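The column-major read in `readmda` is the key detail: MDA files store arrays in Fortran order, so the flat buffer has to be reshaped with order='F'. A minimal sketch of that step alone:

import numpy as np

def read_fortran_order_array(path, dtype, dims, offset):
    # Read a flat buffer and reshape it in column-major (Fortran) order, as MDA files require.
    with open(path, 'rb') as f:
        f.seek(offset)
        flat = np.fromfile(f, dtype=dtype, count=int(np.prod(dims)))
    return flat.reshape(dims, order='F')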
Example #24
def get_sorting_unit_snippets(snippets_h5, unit_id, time_range,
                              max_num_snippets):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        channel_ids = np.array(f.get('channel_ids'))
        channel_locations = np.array(f.get(f'channel_locations'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(
            f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(unit_waveforms_channel_ids)

    snippets = [{
        'index': j,
        'unitId': unit_id,
        'waveform': unit_waveforms[j, :, :].astype(np.float32),
        'timepoint': float(unit_spike_train[j])
    } for j in range(unit_waveforms.shape[0])]
    snippets = [
        s for s in snippets if time_range['min'] <= s['timepoint']
        and s['timepoint'] < time_range['max']
    ]
    channel_locations0 = []
    for ch_id in unit_waveforms_channel_ids:
        ind = np.where(channel_ids == ch_id)[0]
        channel_locations0.append(channel_locations[ind, :].ravel().tolist())

    return dict(channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations0,
                sampling_frequency=sampling_frequency,
                snippets=snippets[:max_num_snippets])
Example #25
def fetch_spike_amplitudes(snippets_h5, unit_id):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        average_waveform = np.mean(unit_waveforms, axis=0)
        peak_channel_index = _compute_peak_channel_index_from_average_waveform(
            average_waveform)
        maxs = [
            np.max(unit_waveforms[i][peak_channel_index, :])
            for i in range(unit_waveforms.shape[0])
        ]
        mins = [
            np.min(unit_waveforms[i][peak_channel_index, :])
            for i in range(unit_waveforms.shape[0])
        ]
        peak_amplitudes = np.array(
            [maxs[i] - mins[i] for i in range(len(mins))])

    return dict(timepoints=unit_spike_train.astype(np.float32),
                amplitudes=peak_amplitudes.astype(np.float32))
Example #26
def fetch_average_waveform_2(snippets_h5, unit_id, visible_channel_ids):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        channel_ids = np.array(f.get('channel_ids'))
        channel_locations = np.array(f.get(f'channel_locations'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(
            f.get(f'unit_waveforms/{unit_id}/channel_ids'))

        if visible_channel_ids is not None:
            inds = [
                ii for ii in range(len(unit_waveforms_channel_ids))
                if unit_waveforms_channel_ids[ii] in visible_channel_ids
            ]
            unit_waveforms_channel_ids = unit_waveforms_channel_ids[inds]
            unit_waveforms = unit_waveforms[:, inds, :]

        print(unit_waveforms_channel_ids)

    average_waveform = np.mean(unit_waveforms, axis=0)
    channel_locations0 = []
    for ch_id in unit_waveforms_channel_ids:
        ind = np.where(channel_ids == ch_id)[0]
        channel_locations0.append(channel_locations[ind, :].ravel().tolist())

    return dict(average_waveform=average_waveform.astype(np.float32),
                channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations0,
                sampling_frequency=sampling_frequency)
Example #27
def _deserialize_result(obj):
    import kachery as ka
    result = Result()

    result.runtime_info = obj['runtime_info']
    result.runtime_info['stdout'] = ka.load_text(result.runtime_info['stdout'])
    result.runtime_info['stderr'] = ka.load_text(result.runtime_info['stderr'])
    if result.runtime_info['stdout'] is None:
        return None
    if result.runtime_info['stderr'] is None:
        return None

    output_files = obj['output_files']
    for oname, path in output_files.items():
        path2 = ka.load_file(path)
        if path2 is None:
            return None
        setattr(result.outputs, oname, File(path2))
        result._output_names.append(oname)

    result.retval = obj['retval']
    result.hash_object = obj['hash_object']
    return result
Example #28
def get_sorting_unit_info(snippets_h5, unit_id):
    import h5py
    h5_path = ka.load_file(snippets_h5)
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get(f'channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     if np.isnan(sampling_frequency):
    #         print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
    #         sampling_frequency = 30000
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     print(unit_waveforms_channel_ids)
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)

    channel_locations_2 = []
    for ch_id in unit_waveforms_channel_ids:
        ind = np.where(unit_waveforms_channel_ids == ch_id)[0]
        channel_locations_2.append(channel_locations0[ind].ravel().tolist())

    return dict(channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations_2,
                sampling_frequency=sampling_frequency)
Example #29
recording_path = 'sha1dir://ed0fe4de4ef2c54b7c9de420c87f9df200721b24.synth_visapy/mea_c30/set4'
sorting_true_path = 'sha1dir://ed0fe4de4ef2c54b7c9de420c87f9df200721b24.synth_visapy/mea_c30/set4/firings_true.mda'

sorter_name = 'kilosort2'
sorter = getattr(sorters, sorter_name)
params = {}

# Determine whether we are going to use gpu based on the name of the sorter
gpu = sorter_name in ['kilosort2', 'kilosort', 'tridesclous', 'ironclust']

# In the future we will check whether we have the correct version of the wrapper here
# Version: 0.1.5-w1

# Download the data (if needed)
ka.set_config(fr='default_readonly')
ka.load_file(recording_path + '/raw.mda')

# Run the spike sorting
with hither.config(container='docker://magland/sf-kilosort2:0.1.5', gpu=gpu):
    sorting_result = sorter.run(recording_path=recording_path,
                                sorting_out=hither.File(),
                                **params)
assert sorting_result.success
sorting_path = sorting_result.outputs.sorting_out

# Compare with ground truth
with hither.config(container='default'):
    compare_result = processing.compare_with_truth.run(
        sorting_path=sorting_path,
        sorting_true_path=sorting_true_path,
        json_out=hither.File())
Example #30
        def execute(_force_run=False, _container=None, **kwargs):
            import kachery as ka
            hash_object = dict(api_version='0.1.0',
                               name=name,
                               version=version,
                               input_files=dict(),
                               output_files=dict(),
                               parameters=dict())
            resolved_kwargs = dict()
            hither_input_files = getattr(f, '_hither_input_files', [])
            hither_output_files = getattr(f, '_hither_output_files', [])
            hither_parameters = getattr(f, '_hither_parameters', [])

            # Let's make sure the input and output files are all coming in as File objects
            for input_file in hither_input_files:
                iname = input_file['name']
                if iname in kwargs:
                    if type(kwargs[iname]) == str:
                        kwargs[iname] = File(kwargs[iname])
            for output_file in hither_output_files:
                oname = output_file['name']
                if oname in kwargs:
                    if type(kwargs[oname]) == str:
                        kwargs[oname] = File(kwargs[oname])

            input_file_keys = []
            for input_file in hither_input_files:
                iname = input_file['name']
                if iname not in kwargs or kwargs[iname] is None:
                    if input_file['required']:
                        raise Exception(
                            'Missing required input file: {}'.format(iname))
                else:
                    x = kwargs[iname]
                    # a hither File object
                    if x._path is None:
                        raise Exception(
                            'Input file has no path: {}'.format(iname))
                    # we really want the path
                    x2 = x._path
                    if _is_hash_url(x2):
                        # a hash url
                        y = ka.load_file(x2)
                        if y is None:
                            raise Exception(
                                'Unable to load input file {}: {}'.format(
                                    iname, x))
                        x2 = y
                    info0 = ka.get_file_info(x2)
                    if info0 is None:
                        raise Exception(
                            'Unable to get info for input file {}: {}'.format(
                                iname, x2))
                    tmp0 = dict()
                    for field0 in ['sha1', 'md5']:
                        if field0 in info0:
                            tmp0[field0] = info0[field0]
                    hash_object['input_files'][iname] = tmp0
                    input_file_keys.append(iname)
                    resolved_kwargs[iname] = x2

            output_file_keys = []
            for output_file in hither_output_files:
                oname = output_file['name']
                if oname not in kwargs or kwargs[oname] is None:
                    if output_file['required']:
                        raise Exception(
                            'Missing required output file: {}'.format(oname))
                else:
                    x = kwargs[oname]
                    x2 = x._path
                    if _is_hash_url(x2):
                        raise Exception(
                            'Output file {} cannot be a hash URI: {}'.format(
                                oname, x2))
                    resolved_kwargs[oname] = x2
                    if oname in resolved_kwargs:
                        hash_object['output_files'][oname] = True
                        output_file_keys.append(oname)

            for parameter in hither_parameters:
                pname = parameter['name']
                if pname not in kwargs or kwargs[pname] is None:
                    if parameter['required']:
                        raise Exception(
                            'Missing required parameter: {}'.format(pname))
                    if 'default' in parameter:
                        resolved_kwargs[pname] = parameter['default']
                else:
                    resolved_kwargs[pname] = kwargs[pname]
                hash_object['parameters'][pname] = resolved_kwargs[pname]

            for k, v in kwargs.items():
                if k not in resolved_kwargs:
                    hash_object['parameters'][k] = v
                    resolved_kwargs[k] = v

            if not _force_run:
                result_serialized: Union[dict, None] = _load_result(
                    hash_object=hash_object)
                if result_serialized is not None:
                    result0 = _deserialize_result(result_serialized)
                    if result0 is not None:
                        for output_file in hither_output_files:
                            oname = output_file['name']
                            if oname in resolved_kwargs:
                                shutil.copyfile(
                                    getattr(result0.outputs, oname)._path,
                                    resolved_kwargs[oname])
                        _handle_temporary_outputs([
                            getattr(result0.outputs, oname)
                            for oname in output_file_keys
                        ])
                        if result0.runtime_info['stdout']:
                            sys.stdout.write(result0.runtime_info['stdout'])
                        if result0.runtime_info['stderr']:
                            sys.stderr.write(result0.runtime_info['stderr'])
                        print(
                            '===== Hither: using cached result for {}'.format(
                                name))
                        return result0

            with ConsoleCapture() as cc:
                if _container is None:
                    returnval = f(**resolved_kwargs)
                else:
                    if hasattr(f, '_hither_containers'):
                        if _container in getattr(f, '_hither_containers'):
                            _container = getattr(
                                f, '_hither_containers')[_container]
                    returnval = run_function_in_container(
                        name=name,
                        function=f,
                        input_file_keys=input_file_keys,
                        output_file_keys=output_file_keys,
                        container=_container,
                        keyword_args=resolved_kwargs)

            result = Result()
            result.outputs = Outputs()
            for oname in hash_object['output_files'].keys():
                setattr(result.outputs, oname, kwargs[oname])
                result._output_names.append(oname)
            result.runtime_info = cc.runtime_info()
            result.hash_object = hash_object
            result.retval = returnval
            _handle_temporary_outputs(
                [getattr(result.outputs, oname) for oname in output_file_keys])
            _store_result(serialized_result=_serialize_result(result))
            return result
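The caching in `execute` keys results on `hash_object`, a plain dict describing the call. A sketch of how such an object can be hashed deterministically (assuming JSON-serializable contents; the real `_load_result`/`_store_result` helpers are not shown here):

import hashlib
import json

def hash_of_object(obj):
    # Deterministic hash of a JSON-serializable dict: canonical JSON, then SHA-1.
    canonical = json.dumps(obj, sort_keys=True, separators=(',', ':'))
    return hashlib.sha1(canonical.encode('utf-8')).hexdigest()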