Code example #1
import numpy as np
import kachery as ka      # assumed import; provides store_npy / load_npy
import kachery_p2p as kp  # assumed import; provides load_text


def prepare_recording(*, bin_uri, bin_file_size, raw_num_channels, mat_uri,
                      meta_uri, single_only):
    # load_info_from_mat is assumed to be defined elsewhere in this module
    (samplerate, chanmap, xcoords, ycoords, spike_times, spike_labels,
     unit_notes) = load_info_from_mat(mat_uri)

    # exclude clusters 0 and -1
    spike_inds = np.where(spike_labels > 0)[0]
    spike_times = spike_times[spike_inds]
    spike_labels = spike_labels[spike_inds]

    if single_only:
        okay_to_use = np.zeros(len(spike_times))
        for unit_id, notes in unit_notes.items():
            if 'single' in notes:
                okay_to_use[np.where(spike_labels == unit_id)[0]] = 1
        spike_inds = np.where(okay_to_use)[0]
        print(
            f'Using {len(spike_inds)} of {len(spike_times)} events (single units only)'
        )
        spike_times = spike_times[spike_inds]
        spike_labels = spike_labels[spike_inds]

    times_npy_uri = ka.store_npy(spike_times)
    labels_npy_uri = ka.store_npy(spike_labels)

    sorting_object = dict(sorting_format='npy1',
                          data=dict(times_npy_uri=times_npy_uri,
                                    labels_npy_uri=labels_npy_uri,
                                    samplerate=samplerate))

    # samples are int16 (2 bytes each), so frames = bytes / (channels * 2)
    num_frames = bin_file_size / (raw_num_channels * 2)
    assert num_frames == int(num_frames), f'unexpected bin file size: {bin_file_size}'
    num_frames = int(num_frames)

    meta_lines = kp.load_text(meta_uri).split('\n')  # currently unused; may be needed in the future
    num_channels = len(chanmap)
    print(f'Number of channels: {num_channels}')

    channel_ids = list(range(num_channels))
    channel_map = {
        str(c): int(chanmap[i])
        for i, c in enumerate(channel_ids)
    }
    channel_positions = {
        str(c): [float(xcoords[i]), float(ycoords[i])]
        for i, c in enumerate(channel_ids)
    }

    recording_object = dict(recording_format='bin1',
                            data=dict(raw=bin_uri,
                                      raw_num_channels=raw_num_channels,
                                      num_frames=num_frames,
                                      samplerate=samplerate,
                                      channel_ids=channel_ids,
                                      channel_map=channel_map,
                                      channel_positions=channel_positions))

    return recording_object, sorting_object, unit_notes
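
A hypothetical invocation sketch, for reference; every URI and numeric value below is a placeholder, not taken from the original code:

# Placeholder values: e.g. 385 int16 channels, 60 s of data at 30 kHz.
recording_object, sorting_object, unit_notes = prepare_recording(
    bin_uri='sha1://<hash>/raw.bin',
    bin_file_size=385 * 2 * 30000 * 60,
    raw_num_channels=385,
    mat_uri='sha1://<hash>/info.mat',
    meta_uri='sha1://<hash>/raw.meta',
    single_only=True)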
Code example #2
def _create_npy1_sorting_object(*, sorting):
    # 'sorting' is a SpikeInterface-style sorting extractor;
    # np and ka imports as in code example #1
    unit_ids = sorting.get_unit_ids()
    spike_trains = [
        np.array(sorting.get_unit_spike_train(unit_id=unit_id)).squeeze()
        for unit_id in unit_ids
    ]
    spike_labels = [
        unit_id * np.ones((len(spike_trains[ii]), ))
        for ii, unit_id in enumerate(unit_ids)
    ]
    all_times = np.concatenate(spike_trains)
    all_labels = np.concatenate(spike_labels)
    sort_inds = np.argsort(all_times)
    all_times = all_times[sort_inds]
    all_labels = all_labels[sort_inds]
    return dict(sorting_format='npy1',
                data=dict(times_npy_uri=ka.store_npy(all_times),
                          labels_npy_uri=ka.store_npy(all_labels),
                          samplerate=sorting.get_sampling_frequency()))
Code example #3
def _create_sorting_object(sorting):
    unit_ids = sorting.get_unit_ids()
    times_list = []
    labels_list = []
    for unit in unit_ids:
        times = sorting.get_unit_spike_train(unit_id=unit)
        times_list.append(times)
        labels_list.append(np.ones(times.shape) * unit)
    all_times = np.concatenate(times_list)
    all_labels = np.concatenate(labels_list)
    sort_inds = np.argsort(all_times)
    all_times = all_times[sort_inds]
    all_labels = all_labels[sort_inds]
    times_npy_uri = ka.store_npy(all_times)
    labels_npy_uri = ka.store_npy(all_labels)
    return dict(sorting_format='npy1',
                data=dict(times_npy_uri=times_npy_uri,
                          labels_npy_uri=labels_npy_uri,
                          samplerate=30000))  # note: sample rate is hardcoded here
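
The npy1 objects produced by examples #2 and #3 can be sanity-checked by loading the stored arrays back out. A minimal sketch, using ka.load_npy as in code example #7 (the helper name here is ours, not from the source):

def _load_npy1_sorting(sorting_object):
    # Round-trip the stored arrays back out of kachery.
    data = sorting_object['data']
    times = ka.load_npy(data['times_npy_uri'])
    labels = ka.load_npy(data['labels_npy_uri'])
    return times, labels, data['samplerate']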
Code example #4
import numpy as np
import hither as hi   # assumed import; provides TemporaryDirectory as used below
import kachery as ka  # assumed import


def mountainsort4(*,
                  recording_object,
                  detect_sign=-1,
                  clip_size=50,
                  adjacency_radius=-1,
                  detect_threshold=3,
                  detect_interval=10,
                  num_workers=None,
                  verbose=True):
    from ml_ms4alg.mountainsort4 import MountainSort4

    # LabboxEphysRecordingExtractor and _get_geom_from_recording are assumed
    # to be imported/defined elsewhere in this module
    recording = LabboxEphysRecordingExtractor(recording_object)
    MS4 = MountainSort4()
    MS4.setRecording(recording)
    geom = _get_geom_from_recording(recording)
    MS4.setGeom(geom)
    MS4.setSortingOpts(clip_size=clip_size,
                       adjacency_radius=adjacency_radius,
                       detect_sign=detect_sign,
                       detect_interval=detect_interval,
                       detect_threshold=detect_threshold,
                       verbose=verbose)
    if num_workers is not None:
        MS4.setNumWorkers(num_workers)
    with hi.TemporaryDirectory() as tmpdir:
        MS4.setTemporaryDirectory(tmpdir)
        MS4.sort()
        times, labels, channels = MS4.eventTimesLabelsChannels()
        sorting_object = {
            'sorting_format': 'npy1',
            'data': {
                'samplerate': recording.get_sampling_frequency(),
                'times_npy_uri': ka.store_npy(times.astype(np.float64)),
                'labels_npy_uri': ka.store_npy(labels.astype(np.int32))
            }
        }
        return sorting_object
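
A hypothetical call, chaining from a recording object such as the one built in code example #1 (the parameter values are illustrative only):

sorting_object = mountainsort4(
    recording_object=recording_object,  # e.g. from prepare_recording above
    detect_sign=-1,                     # spikes are negative-going
    detect_threshold=3,
    num_workers=4)                      # illustrative worker count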
Code example #5
    def write_sorting(sorting, save_path, write_primary_channels=False):
        # _concatenate and _json_serialize are helpers assumed to be
        # defined elsewhere in this module
        unit_ids = sorting.get_unit_ids()
        times_list = []
        labels_list = []
        primary_channels_list = []
        for unit_id in unit_ids:
            times = sorting.get_unit_spike_train(unit_id=unit_id)
            times_list.append(times)
            labels_list.append(np.ones(times.shape) * unit_id)
            if write_primary_channels:
                if 'max_channel' in sorting.get_unit_property_names(unit_id):
                    primary_channels_list.append(
                        [sorting.get_unit_property(unit_id, 'max_channel')] *
                        times.shape[0])
                else:
                    raise ValueError(
                        "Unable to write primary channels: 'max_channel' property not set for unit "
                        + str(unit_id))
            else:
                primary_channels_list.append(np.zeros(times.shape))
        all_times = _concatenate(times_list)
        all_labels = _concatenate(labels_list)
        all_primary_channels = _concatenate(primary_channels_list)
        sort_inds = np.argsort(all_times)
        all_times = all_times[sort_inds]
        all_labels = all_labels[sort_inds]
        all_primary_channels = all_primary_channels[sort_inds]
        L = len(all_times)
        # MountainSort 'firings' convention: row 0 = primary channels,
        # row 1 = event times, row 2 = unit labels
        firings = np.zeros((3, L))
        firings[0, :] = all_primary_channels
        firings[1, :] = all_times
        firings[2, :] = all_labels

        firings_path = ka.store_npy(array=firings, basename='firings.npy')
        sorting_obj = _json_serialize(
            dict(
                firings=firings_path,
                samplerate=sorting.get_sampling_frequency(),
                unit_ids=unit_ids,
            ))
        if save_path is not None:
            with open(save_path, 'w') as f:
                json.dump(sorting_obj, f, indent=4)
        return sorting_obj
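
A sketch of how the returned object might be consumed downstream (assumes ka.load_npy as in code example #7):

firings = ka.load_npy(sorting_obj['firings'])  # shape (3, n_events)
primary_channels, times, labels = firings[0], firings[1], firings[2]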
Code example #6
from typing import Union  # np and ka imports as in code example #1


def store_npy(array: np.ndarray, basename: Union[str, None] = None):
    return ka.store_npy(array, basename=basename)
Code example #7
def _test_store_npy(val: np.ndarray):
    # store the array in kachery, then load it back and verify equality
    x = ka.store_npy(val)
    assert x
    val2 = ka.load_npy(x)
    assert np.array_equal(val, val2)
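
Invocation sketch (the array contents are arbitrary):

_test_store_npy(np.arange(12, dtype=np.float64).reshape(3, 4))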