def __init__(self,
                 *,
                 raw,
                 raw_num_channels,
                 num_frames,
                 samplerate,
                 channel_ids,
                 channel_map,
                 channel_positions,
                 p2p,
                 download=False):
        se.RecordingExtractor.__init__(self)

        self._raw = raw
        self._num_frames = num_frames
        self._samplerate = samplerate
        self._raw_num_channels = raw_num_channels
        self._channel_ids = channel_ids
        self._channel_map = channel_map
        self._channel_positions = channel_positions
        self._p2p = p2p

        if download:
            kp.load_file(self._raw)

        for channel_id in self._channel_ids:
            pos = self._channel_positions[str(channel_id)]
            self.set_channel_property(channel_id, 'location', pos)
Example #2
    def __init__(self, arg, samplerate=None):
        super().__init__()
        if (isinstance(arg, dict)) and ('sorting_format' in arg):
            obj = dict(arg)
        else:
            obj = _create_object_for_arg(arg, samplerate=samplerate)
            assert obj is not None, f'Unable to create sorting from arg: {arg}'
        self._object: dict = obj

        if 'firings' in self._object:
            sorting_format = 'mda'
            data = {'firings': self._object['firings'],
                    'samplerate': self._object.get('samplerate', 30000)}
        else:
            sorting_format = self._object['sorting_format']
            data: dict = self._object['data']
        if sorting_format == 'mda':
            firings_path = kp.load_file(data['firings'])
            assert firings_path is not None, f'Unable to load firings file: {data["firings"]}'
            self._sorting: se.SortingExtractor = MdaSortingExtractor(firings_file=firings_path, samplerate=data['samplerate'])
        elif sorting_format == 'h5_v1':
            h5_path = kp.load_file(data['h5_path'])
            self._sorting = H5SortingExtractorV1(h5_path=h5_path)
        elif sorting_format == 'npy1':
            times_npy = kp.load_npy(data['times_npy_uri'])
            labels_npy = kp.load_npy(data['labels_npy_uri'])
            samplerate = data['samplerate']
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        elif sorting_format == 'snippets1':
            S = Snippets1SortingExtractor(snippets_h5_uri=data['snippets_h5_uri'], p2p=True)
            self._sorting = S
        elif sorting_format == 'npy2':
            npz = kp.load_npy(data['npz_uri'])
            times_npy = npz['spike_indexes']
            labels_npy = npz['spike_labels']
            samplerate = float(npz['sampling_frequency'])
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        elif sorting_format == 'nwb':
            from .nwbextractors import NwbSortingExtractor
            path0 = kp.load_file(data['path'])
            self._sorting: se.SortingExtractor = NwbSortingExtractor(path0)
        elif sorting_format == 'in_memory':
            S = get_in_memory_object(data)
            if S is None:
                raise Exception('Unable to find in-memory object for sorting')
            self._sorting = S
        else:
            raise Exception(f'Unexpected sorting format: {sorting_format}')

        self.copy_unit_properties(sorting=self._sorting)
def create_recording_object_from_spikeforest_recdir(recdir, label):
    raw_path = kp.load_file(recdir + '/raw.mda')
    raw_path = kp.store_file(raw_path, basename=label +
                             '-raw.mda')  # store with manifest
    print(raw_path)
    params = kp.load_object(recdir + '/params.json')
    geom_path = kp.load_file(recdir + '/geom.csv')
    geom = _load_geom_from_csv(geom_path)
    recording_object = dict(recording_format='mda',
                            data=dict(raw=raw_path, geom=geom, params=params))
    return recording_object
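
# Hedged usage sketch (added for illustration, not part of the original
# source; the URI below is a placeholder). Assumes a SpikeForest-style
# directory containing raw.mda, params.json and geom.csv reachable via kp.
recording_object = create_recording_object_from_spikeforest_recdir(
    'sha1dir://<hash>.synth-recording', label='synth-recording')
print(recording_object['recording_format'])  # -> 'mda'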
Example #4
    def __init__(self, probe_file, xml_file, nrs_file, dat_file):
        se.RecordingExtractor.__init__(self)
        # info = check_load_nrs(dirpath)
        # assert info is not None
        probe_obj = kp.load_object(probe_file)
        xml_file = kp.load_file(xml_file)
        # nrs_file = kp.load_file(nrs_file)
        dat_file = kp.load_file(dat_file)

        from xml.etree import ElementTree as ET
        xml = ET.parse(xml_file)
        root_element = xml.getroot()
        try:
            txt = root_element.find('acquisitionSystem/samplingRate').text
            assert txt is not None
            self._samplerate = float(txt)
        except Exception:
            raise Exception('Unable to load acquisitionSystem/samplingRate')
        try:
            txt = root_element.find('acquisitionSystem/nChannels').text
            assert txt is not None
            self._nChannels = int(txt)
        except Exception:
            raise Exception('Unable to load acquisitionSystem/nChannels')
        try:
            txt = root_element.find('acquisitionSystem/nBits').text
            assert txt is not None
            self._nBits = int(txt)
        except Exception:
            raise Exception('Unable to load acquisitionSystem/nBits')

        if self._nBits == 16:
            dtype = np.int16
        elif self._nBits == 32:
            dtype = np.int32
        else:
            raise Exception(f'Unexpected nBits: {self._nBits}')

        self._rec = se.BinDatRecordingExtractor(
            dat_file,
            sampling_frequency=self._samplerate,
            numchan=self._nChannels,
            dtype=dtype)

        self._channel_ids = probe_obj['channel']
        for ii in range(len(probe_obj['channel'])):
            channel = probe_obj['channel'][ii]
            x = probe_obj['x'][ii]
            y = probe_obj['y'][ii]
            z = probe_obj['z'][ii]
            group = probe_obj.get('group', probe_obj.get('shank'))[ii]
            self.set_channel_property(channel, 'location', [x, y, z])
            self.set_channel_property(channel, 'group', group)
def reup_file(object: Any):
    thekey = recording_key if recording_key['key_field'] in object else sorting_key
    if VERBOSE:
        print(
            f"Executing: object['{thekey['key_field']}'] = kp.store_file(kp.load_file(object['{thekey['key_field']}']), basename={thekey['basename']})"
        )
    if DRY_RUN: return

    # Turns out that kachery doesn't handle big files without manifests all that well. Which was the point of this exercise.
    # So let's take advantage of being on the same filesystem to do a little magic.
    raw = object[thekey["key_field"]]
    print(f'Got raw: {raw}')
    if ('sha1dir' in raw):
        key_field = object[thekey['key_field']]
        sha1dir = key_field.split('/')[2]
        print(f'Got dir: {sha1dir}')
        kp.load_file(f'sha1://{sha1dir}')
        print(f"Fetching hash for file: {key_field}")
        reformed_field = trim_dir_annotation(f"{key_field}")
        if VERBOSE: print(f"(using reformed field {reformed_field})")
        try:
            sha1 = ka.get_file_hash(reformed_field)
        except Exception:
            if FORCE:
                print(
                    f"\t** Trimmed lookup didn't work, falling back to kp.load_file({key_field})"
                )
                kp.load_file(key_field)
                sha1 = ka.get_file_hash(key_field)
            else:
                print(
                    f"Error on ka.get_file_hash({reformed_field}) -- aborting")
                exit()
    else:
        #sha1 = '/'.join(raw.split('/')[2:])
        sha1 = raw.split('/')[2]
    print(f'Got sha1: {sha1}')
    src_path = f'/mnt/ceph/users/magland/kachery-storage/sha1/{sha1[0]}{sha1[1]}/{sha1[2]}{sha1[3]}/{sha1[4]}{sha1[5]}/{sha1}'
    dest_path = f'/mnt/ceph/users/jsoules/kachery-storage/sha1/{sha1[0]}{sha1[1]}/{sha1[2]}{sha1[3]}/{sha1[4]}{sha1[5]}/{sha1}'
    if VERBOSE: print(f"Executing: shutil.copyfile({src_path}, {dest_path})")
    if not exists(dest_path):
        pathlib.Path(dest_path).parent.mkdir(parents=True, exist_ok=True)
        copyfile(src_path, dest_path)
        print("\tCompleted copy operation.")
    object[thekey['key_field']] = kp.store_file(
        kp.load_file(object[thekey['key_field']]),
        basename=f"{thekey['basename']}")
def get_sorting_unit_info(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get(f'channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     if np.isnan(sampling_frequency):
    #         print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
    #         sampling_frequency = 30000
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     print(unit_waveforms_channel_ids)
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)

    # channel_locations0 rows are aligned with unit_waveforms_channel_ids;
    # look up each channel id's location by its position in that array
    channel_locations_2 = []
    for ch_id in unit_waveforms_channel_ids:
        ind = np.where(unit_waveforms_channel_ids == ch_id)[0]
        channel_locations_2.append(channel_locations0[ind].ravel().tolist())

    return dict(channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations_2,
                sampling_frequency=sampling_frequency)
def get_sorting_unit_snippets(snippets_h5, unit_id, time_range,
                              max_num_snippets):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get(f'channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     if np.isnan(sampling_frequency):
    #         print('WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.')
    #         sampling_frequency = 30000
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     print(unit_waveforms_channel_ids)
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)

    snippets = [{
        'index': j,
        'unitId': unit_id,
        'waveform': unit_waveforms[j].astype(np.float32),
        'timepoint': float(unit_spike_train[j])
    } for j in range(unit_waveforms.shape[0])
                if time_range['min'] <= unit_spike_train[j] < time_range['max']]

    return dict(channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations0.astype(np.float32),
                sampling_frequency=sampling_frequency,
                snippets=snippets[:max_num_snippets])
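
# Hedged usage sketch (illustrative only; the URI is a placeholder).
# time_range is a dict with 'min'/'max' in sample units; at most
# max_num_snippets snippets are returned for the requested unit.
result = get_sorting_unit_snippets(
    snippets_h5='sha1://<hash>/snippets.h5',
    unit_id=1,
    time_range={'min': 0, 'max': 30000 * 60},  # first minute at 30 kHz
    max_num_snippets=100)
print(len(result['snippets']), 'snippets on', len(result['channel_ids']), 'channels')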
def fetch_spike_amplitudes(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))

    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)
    average_waveform = np.mean(unit_waveforms, axis=0)
    peak_channel_index = _compute_peak_channel_index_from_average_waveform(
        average_waveform)
    maxs = [
        np.max(unit_waveforms[i][peak_channel_index, :])
        for i in range(unit_waveforms.shape[0])
    ]
    mins = [
        np.min(unit_waveforms[i][peak_channel_index, :])
        for i in range(unit_waveforms.shape[0])
    ]
    peak_amplitudes = np.array([maxs[i] - mins[i] for i in range(len(mins))])

    timepoints = unit_spike_train.astype(np.float32)
    amplitudes = peak_amplitudes.astype(np.float32)

    sort_inds = np.argsort(timepoints)
    timepoints = timepoints[sort_inds]
    amplitudes = amplitudes[sort_inds]

    return dict(timepoints=timepoints, amplitudes=amplitudes)
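
# Hedged usage sketch (placeholder URI): timepoints and amplitudes come back
# sorted by spike time, so they can be plotted directly, e.g. with matplotlib.
res = fetch_spike_amplitudes('sha1://<hash>/snippets.h5', unit_id=1)
# import matplotlib.pyplot as plt
# plt.plot(res['timepoints'], res['amplitudes'], '.')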
Example #9
def cat_file(uri, start, end, exp_nop2p, exp_file_server_url):
    old_stdout = sys.stdout
    sys.stdout = sys.stderr

    kp._experimental_config(nop2p=exp_nop2p, file_server_urls=list(exp_file_server_url))

    if start is None and end is None:
        path1 = kp.load_file(uri)
        if not path1:
            raise Exception('Error loading file for cat.')
        sys.stdout = old_stdout
        with open(path1, 'rb') as f:
            while True:
                data = os.read(f.fileno(), 4096)
                if len(data) == 0:
                    break
                os.write(sys.stdout.fileno(), data)
    else:
        assert start is not None and end is not None
        start = int(start)
        end = int(end)
        assert start <= end
        if start == end:
            return
        sys.stdout = old_stdout
        kp.load_bytes(uri=uri, start=start, end=end, write_to_stdout=True)
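
# Hedged usage sketch (placeholder URI): stream the first kilobyte of a file
# to stdout, with p2p disabled and no file-server fallback.
cat_file('sha1://<hash>/file.dat', start=0, end=1024,
         exp_nop2p=True, exp_file_server_url=[])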
Example #10
def fetch_spike_waveforms(snippets_h5, unit_ids, spike_indices):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    spikes = []
    with h5py.File(h5_path, 'r') as f:
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        for ii, unit_id in enumerate(unit_ids):
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))
            unit_waveforms_channel_ids = np.array(
                f.get(f'unit_waveforms/{unit_id}/channel_ids'))
            unit_waveforms_spike_train = np.array(
                f.get(f'unit_waveforms/{unit_id}/spike_train'))
            average_waveform = np.mean(unit_waveforms, axis=0)
            channel_maximums = np.max(np.abs(average_waveform), axis=1)
            maxchan_index = np.argmax(channel_maximums)
            maxchan_id = unit_waveforms_channel_ids[maxchan_index]
            for spike_index in spike_indices[ii]:
                spikes.append(
                    dict(unit_id=unit_id,
                         spike_index=spike_index,
                         spike_time=unit_waveforms_spike_train[spike_index],
                         channel_id=maxchan_id,
                         waveform=unit_waveforms[
                             spike_index,
                             maxchan_index, :].squeeze().tolist()))
    return {'sampling_frequency': sampling_frequency, 'spikes': spikes}
def fetch_average_waveform_plot_data(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
        unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
        unit_waveforms_channel_ids = np.array(
            f.get(f'unit_waveforms/{unit_id}/channel_ids'))
        print(unit_waveforms_channel_ids)

    average_waveform = np.mean(unit_waveforms, axis=0)
    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = unit_waveforms_channel_ids[maxchan_index]

    return dict(channel_id=int(maxchan_id),
                sampling_frequency=sampling_frequency,
                average_waveform=average_waveform[maxchan_index, :].astype(
                    np.float32))
Example #12
def _try_mda_create_object(arg: Union[str, dict]) -> Union[None, dict]:
    if isinstance(arg, str):
        path = arg
        if path.startswith('sha1dir') or path.startswith('/'):
            dd = kp.read_dir(path)
            if dd is not None:
                if all(k in dd['files'] for k in ('raw.mda', 'params.json', 'geom.csv')):
                    raw_path = path + '/raw.mda'
                    params_path = path + '/params.json'
                    geom_path = path + '/geom.csv'
                    geom_path_resolved = kp.load_file(geom_path)
                    assert geom_path_resolved is not None, f'Unable to load geom.csv from: {geom_path}'
                    params = kp.load_object(params_path)
                    assert params is not None, f'Unable to load params.json from: {params_path}'
                    geom = _load_geom_from_csv(geom_path_resolved)
                    return dict(recording_format='mda',
                                data=dict(raw=raw_path,
                                          geom=geom,
                                          params=params))

    if isinstance(arg, dict):
        if ('raw' in arg) and ('geom' in arg) and ('params' in arg) \
                and isinstance(arg['geom'], list) and isinstance(arg['params'], dict):
            return dict(recording_format='mda',
                        data=dict(raw=arg['raw'],
                                  geom=arg['geom'],
                                  params=arg['params']))

    return None
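
# Hedged usage sketch (placeholder path): yields a recording object dict when
# the directory holds raw.mda, params.json and geom.csv, otherwise None.
obj = _try_mda_create_object('sha1dir://<hash>.my-recording')
if obj is not None:
    print(obj['recording_format'])  # -> 'mda'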
Example #13
    def __init__(self, *, snippets_h5_uri: str, p2p: bool = False):
        se.RecordingExtractor.__init__(self)

        snippets_h5_path = kp.load_file(snippets_h5_uri, p2p=p2p)

        self._snippets_h5_path: str = snippets_h5_path

        channel_ids_set: Set[int] = set()
        max_timepoint: int = 0
        with h5py.File(self._snippets_h5_path, 'r') as f:
            sampling_frequency: float = np.array(
                f.get('sampling_frequency'))[0]
            if np.isnan(sampling_frequency):
                print(
                    'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
                )
                sampling_frequency = 30000
            self.set_sampling_frequency(sampling_frequency)
            self._unit_ids: List[int] = np.array(
                f.get('unit_ids')).astype(int).tolist()
            for unit_id in self._unit_ids:
                unit_spike_train = np.array(
                    f.get(f'unit_spike_trains/{unit_id}'))
                max_timepoint = int(
                    max(max_timepoint, np.max(unit_spike_train)))
                # unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
                unit_waveforms_channel_ids = np.array(
                    f.get(f'unit_waveforms/{unit_id}/channel_ids'))
                for ch_id in unit_waveforms_channel_ids:
                    channel_ids_set.add(int(ch_id))
        self._channel_ids: List[int] = sorted(list(channel_ids_set))
        self._num_frames: int = max_timepoint + 1
Example #14
def fetch_pca_features(snippets_h5, unit_ids):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    with h5py.File(h5_path, 'r') as f:
        sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
        if np.isnan(sampling_frequency):
            print(
                'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
            )
            sampling_frequency = 30000
        x = [
            dict(
                unit_id=unit_id,
                unit_waveforms_spike_train=np.array(
                    f.get(f'unit_waveforms/{unit_id}/spike_train')),
                # unit_waveforms_spike_train=_subsample(np.array(f.get(f'unit_spike_trains/{unit_id}')), 1000),
                unit_waveforms=np.array(
                    f.get(f'unit_waveforms/{unit_id}/waveforms')),
                unit_waveforms_channel_ids=np.array(
                    f.get(f'unit_waveforms/{unit_id}/channel_ids')))
            for unit_id in unit_ids
        ]
        channel_ids = _intersect_channel_ids(
            [a['unit_waveforms_channel_ids'] for a in x])
        assert len(channel_ids) > 0, 'No channel ids in intersection'
        for a in x:
            unit_waveforms = a['unit_waveforms']
            unit_waveforms_channel_ids = a['unit_waveforms_channel_ids']
            inds = [
                np.where(unit_waveforms_channel_ids == ch_id)[0][0]
                for ch_id in channel_ids
            ]
            a['unit_waveforms_2'] = unit_waveforms[:, inds, :]
            a['labels'] = np.ones((unit_waveforms.shape[0], )) * a['unit_id']

    unit_waveforms = np.concatenate([a['unit_waveforms_2'] for a in x], axis=0)
    spike_train = np.concatenate([a['unit_waveforms_spike_train'] for a in x])
    labels = np.concatenate([a['labels'] for a in x]).astype(int)

    from sklearn.decomposition import PCA

    nf = 5  # number of features

    # list of arrays
    W = unit_waveforms  # ntot x M x T

    # ntot x MT
    X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))

    pca = PCA(n_components=nf)
    pca.fit(X)

    features = pca.transform(X)  # n x nf

    return dict(
        times=(spike_train / sampling_frequency).tolist(),
        features=[features[:, ii].squeeze().tolist() for ii in range(nf)],
        labels=labels.tolist())
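
# Hedged usage sketch (placeholder URI): project the waveforms of two units
# onto their shared channels and return nf=5 PCA features per spike.
feats = fetch_pca_features('sha1://<hash>/snippets.h5', unit_ids=[1, 2])
# feats['features'] holds 5 lists (one per component), each with one value
# per spike; feats['labels'][i] gives the unit id of spike i.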
Example #15
def _keep_good_units(sorting_obj, cluster_groups_csv_uri):
    sorting = LabboxEphysSortingExtractor(sorting_obj)
    df = pd.read_csv(kp.load_file(cluster_groups_csv_uri), delimiter='\t')
    df_good = df.loc[df['group'] == 'good']
    good_unit_ids = df_good['cluster_id'].to_numpy().tolist()
    sorting_good = se.SubSortingExtractor(parent_sorting=sorting,
                                          unit_ids=good_unit_ids)
    return _create_npy1_sorting_object(sorting=sorting_good)
def create_sorting_object_from_spikeforest_recdir(recdir, label):
    params = kp.load_object(recdir + '/params.json')
    firings_path = kp.load_file(recdir + '/firings_true.mda')
    firings_path = ka.store_file(firings_path, basename=label + '-firings.mda')
    sorting_object = dict(sorting_format='mda',
                          data=dict(firings=firings_path,
                                    samplerate=params['samplerate']))
    print(sorting_object)
    return sorting_object
Example #17
def load_chanmap_data_from_mat(uri_mat):
    m = sio.loadmat(kp.load_file(uri_mat))
    chanmap = m['chanMap0ind'].squeeze()
    xcoords = m['xcoords'].squeeze()
    ycoords = m['ycoords'].squeeze()
    num_chan = len(chanmap)
    assert len(xcoords) == num_chan
    assert len(ycoords) == num_chan
    return chanmap, xcoords, ycoords
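
# Hedged usage sketch (placeholder URI): assemble an N x 2 geom list from the
# 0-indexed channel map of a Kilosort-style chanMap .mat file.
chanmap, xcoords, ycoords = load_chanmap_data_from_mat('sha1://<hash>/chanMap.mat')
geom = [[float(x), float(y)] for x, y in zip(xcoords, ycoords)]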
Example #18
    def __init__(self, firings_file, samplerate):
        SortingExtractor.__init__(self)
        self._firings_path = kp.load_file(firings_file)
        if not self._firings_path:
            raise Exception('Unable to load firings file: ' + firings_file)

        self._firings = readmda(self._firings_path)
        self._sampling_frequency = samplerate
        self._times = self._firings[1, :]
        self._labels = self._firings[2, :]
        self._unit_ids = np.unique(self._labels).astype(int)
    def __init__(self, arg, samplerate=None):
        super().__init__()
        if (isinstance(arg, dict)) and ('sorting_format' in arg):
            obj = dict(arg)
        else:
            obj = _create_object_for_arg(arg, samplerate=samplerate)
            assert obj is not None, f'Unable to create sorting from arg: {arg}'
        self._object: dict = obj

        sorting_format = self._object['sorting_format']
        data: dict = self._object['data']
        if sorting_format == 'mda':
            firings_path = kp.load_file(data['firings'])
            assert firings_path is not None, f'Unable to load firings file: {data["firings"]}'
            self._sorting: se.SortingExtractor = MdaSortingExtractor(
                firings_file=firings_path, samplerate=data['samplerate'])
        elif sorting_format == 'h5_v1':
            h5_path = kp.load_file(data['h5_path'])
            self._sorting = H5SortingExtractorV1(h5_path=h5_path)
        elif sorting_format == 'npy1':
            times_npy = kp.load_npy(data['times_npy_uri'])
            labels_npy = kp.load_npy(data['labels_npy_uri'])
            samplerate = data['samplerate']
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        elif sorting_format == 'npy2':
            npz = kp.load_npy(data['npz_uri'])
            times_npy = npz['spike_indexes']
            labels_npy = npz['spike_labels']
            samplerate = float(npz['sampling_frequency'])
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        else:
            raise Exception(f'Unexpected sorting format: {sorting_format}')

        self.copy_unit_properties(sorting=self._sorting)
def get_unit_snrs(snippets_h5):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    ret = {}
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        for unit_id in unit_ids:
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))  # n x M x T
            ret[str(unit_id)] = _compute_unit_snr_from_waveforms(
                unit_waveforms)
    return ret
def fetch_average_waveform_2(snippets_h5, unit_id):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id)

    average_waveform = np.mean(unit_waveforms, axis=0)

    return dict(average_waveform=average_waveform.astype(np.float32),
                channel_ids=unit_waveforms_channel_ids.astype(np.int32),
                channel_locations=channel_locations0.astype(np.float32),
                sampling_frequency=sampling_frequency)
Example #22
    def __init__(self, arg: Union[str, dict], download: bool=False):
        super().__init__()
        obj = _create_object_for_arg(arg)
        assert obj is not None
        self._object: dict = obj
        
        recording_format = self._object['recording_format']
        data: dict = self._object['data']
        if recording_format == 'mda':
            self._recording: se.RecordingExtractor = MdaRecordingExtractor(
                timeseries_path=data['raw'],
                samplerate=data['params']['samplerate'],
                geom=np.array(data['geom']),
                download=download)
        elif recording_format == 'nrs':
            self._recording: se.RecordingExtractor = NrsRecordingExtractor(**data)
        elif recording_format == 'nwb':
            path0 = kp.load_file(data['path'])
            self._recording: se.RecordingExtractor = NwbRecordingExtractor(path0, electrical_series_name='e-series')
        elif recording_format == 'bin1':
            self._recording: se.RecordingExtractor = Bin1RecordingExtractor(**data, p2p=True, download=download)
        elif recording_format == 'subrecording':
            R = LabboxEphysRecordingExtractor(data['recording'], download=download)
            if 'channel_ids' in data:
                channel_ids = np.array(data['channel_ids'])
            elif 'group' in data:
                channel_ids = np.array(R.get_channel_ids())
                groups = R.get_channel_groups(channel_ids=R.get_channel_ids())
                group = int(data['group'])
                inds = np.where(np.array(groups) == group)[0]
                channel_ids = channel_ids[inds]
            elif 'groups' in data:
                raise Exception('This case not yet handled.')
            else:
                channel_ids = None
            if 'start_frame' in data:
                start_frame = data['start_frame']
                end_frame = data['end_frame']
            else:
                start_frame = None
                end_frame = None
            self._recording: se.RecordingExtractor = se.SubRecordingExtractor(
                parent_recording=R,
                channel_ids=channel_ids,
                start_frame=start_frame,
                end_frame=end_frame
            )
        elif recording_format == 'filtered':
            R = LabboxEphysRecordingExtractor(data['recording'], download=download)
            self._recording: se.RecordingExtractor = _apply_filters(recording=R, filters=data['filters'])
        else:
            raise Exception(f'Unexpected recording format: {recording_format}')

        self.copy_channel_properties(recording=self._recording)
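
# Hedged usage sketch (not from the source): illustrates the nested
# 'subrecording' object consumed above. parent_object stands in for any valid
# recording object, e.g. an 'mda' dict like the ones built elsewhere here.
parent_object = dict(recording_format='mda',
                     data=dict(raw='sha1://<hash>/raw.mda',
                               geom=[[0, 0]],
                               params=dict(samplerate=30000)))
sub_object = dict(recording_format='subrecording',
                  data=dict(recording=parent_object,
                            channel_ids=[0],
                            start_frame=0,
                            end_frame=30000 * 60))
R = LabboxEphysRecordingExtractor(sub_object)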
def _try_mda_create_object(arg: Union[str, dict],
                           samplerate=None) -> Union[None, dict]:
    if isinstance(arg, str):
        path = arg
        if not kp.load_file(path):
            return None
        return dict(sorting_format='mda',
                    data=dict(firings=path, samplerate=samplerate))

    if isinstance(arg, dict):
        if 'firings' in arg:
            return dict(sorting_format='mda',
                        data=dict(firings=arg['firings'],
                                  samplerate=arg.get('samplerate', None)))

    return None
Example #24
def get_peak_channels(snippets_h5):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    ret = {}
    with h5py.File(h5_path, 'r') as f:
        unit_ids = np.array(f.get('unit_ids'))
        for unit_id in unit_ids:
            unit_waveforms = np.array(
                f.get(f'unit_waveforms/{unit_id}/waveforms'))  # n x M x T
            channel_ids = np.array(
                f.get(f'unit_waveforms/{unit_id}/channel_ids'))  # n
            peak_channel_index = _compute_peak_channel_index_from_waveforms(
                unit_waveforms)
            ret[str(unit_id)] = int(channel_ids[peak_channel_index])
    return ret
Example #25
    def __init__(self, *, snippets_h5_uri: str, p2p: bool = False):
        se.RecordingExtractor.__init__(self)

        snippets_h5_path = kp.load_file(snippets_h5_uri, p2p=p2p)

        self._snippets_h5_path: str = snippets_h5_path

        channel_ids_set: Set[int] = set()
        max_timepoint: int = 0
        with h5py.File(self._snippets_h5_path, 'r') as f:
            self._sampling_frequency: float = np.array(
                f.get('sampling_frequency'))[0]
            if np.isnan(self._sampling_frequency):
                print(
                    'WARNING: sampling frequency is nan. Using 30000 for now. Please correct the snippets file.'
                )
                self._sampling_frequency = 30000
            self._unit_ids: List[int] = np.array(
                f.get('unit_ids')).astype(int).tolist()
            for unit_id in self._unit_ids:
                unit_spike_train = np.array(
                    f.get(f'unit_spike_trains/{unit_id}'))
                max_timepoint = int(
                    max(max_timepoint, np.max(unit_spike_train)))
                # unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms'))
                unit_waveforms_channel_ids = np.array(
                    f.get(f'unit_waveforms/{unit_id}/channel_ids'))
                for ch_id in unit_waveforms_channel_ids:
                    channel_ids_set.add(int(ch_id))
            self._channel_ids: List[int] = sorted(list(channel_ids_set))
            try:
                self._num_frames = f.get('num_frames')[0].item()
            except Exception:
                print(
                    'Unable to load num_frames. Please update snippets file.')
                self._num_frames: int = max_timepoint + 1
            try:
                channel_locations = np.array(f.get('channel_locations'))
                self.set_channel_locations(channel_locations)
            except Exception:
                print(
                    'WARNING: using [0, 0] for channel locations. Please adjust snippets file'
                )
                for channel_id in self._channel_ids:
                    self.set_channel_property(channel_id, 'location', [0, 0])
def individual_cluster_features(snippets_h5, unit_id, max_num_events=1000):
    import h5py
    h5_path = kp.load_file(snippets_h5, p2p=False)
    assert h5_path is not None
    # with h5py.File(h5_path, 'r') as f:
    #     unit_ids = np.array(f.get('unit_ids'))
    #     channel_ids = np.array(f.get('channel_ids'))
    #     channel_locations = np.array(f.get(f'channel_locations'))
    #     sampling_frequency = np.array(f.get('sampling_frequency'))[0].item()
    #     unit_spike_train = np.array(f.get(f'unit_spike_trains/{unit_id}'))
    #     unit_waveforms = np.array(f.get(f'unit_waveforms/{unit_id}/waveforms')) # L x M x T
    #     unit_waveforms_channel_ids = np.array(f.get(f'unit_waveforms/{unit_id}/channel_ids'))
    #     if len(unit_spike_train) > max_num_events:
    #         inds = subsample_inds(len(unit_spike_train), max_num_events)
    #         unit_spike_train = unit_spike_train[inds]
    #         unit_waveforms = unit_waveforms[inds]

    unit_waveforms, unit_waveforms_channel_ids, channel_locations0, sampling_frequency, unit_spike_train = le.get_unit_waveforms_from_snippets_h5(
        h5_path, unit_id, max_num_events=max_num_events)

    from sklearn.decomposition import PCA
    nf = 2  # number of features

    # L = number of waveforms (number of spikes)
    # M = number of electrodes in nbhd
    # T = num. timepoints in the snippet
    W = unit_waveforms  # L x M x T

    # subtract mean for each channel and waveform
    for i in range(W.shape[0]):
        for m in range(W.shape[1]):
            W[i, m, :] = W[i, m, :] - np.mean(W[i, m, :])
    X = W.reshape((W.shape[0], W.shape[1] * W.shape[2]))  # L x MT
    pca = PCA(n_components=nf)
    pca.fit(X)

    # L = number of waveforms (number of spikes)
    # nf = number of features
    features = pca.transform(X)  # L x nf

    return dict(timepoints=unit_spike_train.astype(np.float32),
                x=features[:, 0].squeeze().astype(np.float32),
                y=features[:, 1].squeeze().astype(np.float32))
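
# Hedged usage sketch (placeholder URI): x and y are the first two PCA
# features per spike, suitable for a 2-D cluster scatter plot.
fc = individual_cluster_features('sha1://<hash>/snippets.h5', unit_id=3)
# import matplotlib.pyplot as plt
# plt.scatter(fc['x'], fc['y'], s=2)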
Example #27
def test1():
    f = kp.create_feed('f1')
    f2 = kp.load_feed('f1')
    assert (f.get_uri() == f2.get_uri())
    sf = f.get_subfeed('sf1')
    sf.append_message({'m': 1})
    assert (sf.get_num_messages() == 1)
    x = kp.store_text('abc')
    sf.set_access_rules({'rules': []})
    r = sf.get_access_rules()

    try:
        a = kp.load_file(
            'sha1://e25f95079381fe07651aa7d37c2f4e8bda19727c/file.txt')
        raise Exception('Did not get expected error')
    except LoadFileError:
        pass  # expected: the file should not be loadable
Example #28
def _download_files_in_item(x):
    if isinstance(x, str):
        if x.startswith('sha1://') or x.startswith('sha1dir://'):
            if not ka.get_file_info(x, fr=dict(url=None)):
                a = kp.load_file(x)
                assert a is not None, f'Unable to download file: {x}'
    elif isinstance(x, dict):
        for val in x.values():
            _download_files_in_item(val)
    elif isinstance(x, (list, tuple)):
        for y in x:
            _download_files_in_item(y)
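
# Hedged usage sketch: recursively pre-fetch every sha1:// or sha1dir:// URI
# found in a nested object (the dict below is a made-up placeholder).
_download_files_in_item({
    'recording': {'raw': 'sha1://<hash>/raw.mda'},
    'sortings': ['sha1://<hash>/firings.mda'],
})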
Example #29
def readmda(path):
    if file_extension(path) == '.npy':
        return readnpy(path)
    path = kp.load_file(path)
    H = _read_header(path)
    if H is None:
        print("Problem reading header of: {}".format(path))
        return None
    try:
        with open(path, "rb") as f:
            f.seek(H.header_size)
            # Read in column-major (Fortran) order, as stored in .mda files
            ret = np.fromfile(f, dtype=H.dt, count=H.dimprod)
            return np.reshape(ret, H.dims, order='F')
    except Exception as e:  # catch any read/reshape failure
        print(e)
        return None
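
# Hedged usage sketch (placeholder URI): readmda returns a numpy array shaped
# as recorded in the .mda header, or None on failure.
X = readmda('sha1://<hash>/firings.mda')
if X is not None:
    print(X.shape)  # e.g. (3, num_events) for a firings file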
Example #30
def load_info_from_mat(uri_mat):
    m = sio.loadmat(kp.load_file(uri_mat))
    spike_times = m['spikeTimes'].squeeze()
    spike_labels = m['spikeClusters'].squeeze()
    cluster_notes = m['clusterNotes'].squeeze()
    samplerate = m['SampleRate'][0][0]
    siteMap = m['siteMap'].squeeze()
    xcoords = m['xcoords'].squeeze()
    ycoords = m['ycoords'].squeeze()
    chanmap = siteMap - 1
    xcoords = xcoords[chanmap]
    ycoords = ycoords[chanmap]
    num_chan = len(chanmap)
    assert len(xcoords) == num_chan
    assert len(ycoords) == num_chan

    unit_notes = {}
    for j in range(len(cluster_notes)):
        notes = [note for note in cluster_notes[j] if isinstance(note, str)]
        if len(notes) > 0:
            unit_notes[j + 1] = notes

    return samplerate, chanmap, xcoords, ycoords, spike_times, spike_labels, unit_notes
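
# Hedged usage sketch (placeholder URI): unpack what appears to be a
# JRCLUST-style .mat export and assemble a geom list from the remapped
# coordinates.
(samplerate, chanmap, xcoords, ycoords,
 spike_times, spike_labels, unit_notes) = load_info_from_mat('sha1://<hash>/results.mat')
geom = list(zip(xcoords.tolist(), ycoords.tolist()))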