Example #1
    def _create_example(self, seed):
        channel_ids = [0, 1, 2, 3]
        num_channels = 4
        num_frames = 10000
        sampling_frequency = 30000
        X = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, num_frames))
        geom = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        SX = se.NumpySortingExtractor()
        spike_times = [200, 300, 400]
        train1 = np.sort(np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[0])).astype(int))
        SX.add_unit(unit_id=1, times=train1)
        SX.add_unit(unit_id=2, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[1])))
        SX.add_unit(unit_id=3, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[2])))
        SX.set_unit_property(unit_id=1, property_name='stability', value=80)
        SX.set_sampling_frequency(sampling_frequency)
        SX2 = se.NumpySortingExtractor()
        spike_times2 = [100, 150, 450]
        train2 = np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[0])).astype(int)
        SX2.add_unit(unit_id=3, times=train2)
        SX2.add_unit(unit_id=4, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[1]))
        SX2.add_unit(unit_id=5, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[2]))
        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
        SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))
        RX.set_channel_locations([0, 0], channel_ids=0)
        for i, unit_id in enumerate(SX2.get_unit_ids()):
            SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i)
            SX2.set_unit_spike_features(unit_id=unit_id, feature_name='shared_unit_feature',
                                        value=np.asarray([i] * spike_times2[i]))
        for i, channel_id in enumerate(RX.get_channel_ids()):
            RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i)

        SX3 = se.NumpySortingExtractor()
        train3 = np.asarray([1, 20, 21, 35, 38, 45, 46, 47])
        SX3.add_unit(unit_id=0, times=train3)
        features3 = np.asarray([0, 5, 10, 15, 20, 25, 30, 35])
        features4 = np.asarray([0, 10, 20, 30])
        feature4_idx = np.asarray([0, 2, 4, 6])
        SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3)
        SX3.set_unit_spike_features(unit_id=0, feature_name='dummy2', value=features4, indexes=feature4_idx)

        example_info = dict(
            channel_ids=channel_ids,
            num_channels=num_channels,
            num_frames=num_frames,
            sampling_frequency=sampling_frequency,
            unit_ids=[1, 2, 3],
            train1=train1,
            train2=train2,
            train3=train3,
            features3=features3,
            unit_prop=80,
            channel_prop=(0, 0)
        )

        return (RX, RX2, RX3, SX, SX2, SX3, example_info)
Example #2
def compute_units_info(*, recording, sorting, channel_ids=[], unit_ids=[]):
    if (channel_ids) and (len(channel_ids) > 0):
        recording = si.SubRecordingExtractor(parent_recording=recording,
                                             channel_ids=channel_ids)

    # load into memory
    print('Loading recording into RAM...')
    recording = si.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())

    # do filtering
    print('Filtering...')
    recording = bandpass_filter(recording=recording,
                                freq_min=300,
                                freq_max=6000)
    recording = si.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())

    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.get_unit_ids()

    print('Computing channel noise levels...')
    channel_noise_levels = compute_channel_noise_levels(recording=recording)

    # No longer use subset to compute the templates
    print('Computing unit templates...')
    templates = compute_unit_templates(recording=recording,
                                       sorting=sorting,
                                       unit_ids=unit_ids,
                                       max_num=100)

    print(recording.get_channel_ids())

    ret = []
    for i, unit_id in enumerate(unit_ids):
        print('Unit {} of {} (id={})'.format(i + 1, len(unit_ids), unit_id))
        template = templates[i]
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template,
                                                                     axis=1)
        peak_channel_index = np.argmax(max_p2p_amps_on_channels)
        peak_channel = recording.get_channel_ids()[peak_channel_index]
        peak_signal = np.max(np.abs(template[peak_channel_index, :]))
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = peak_signal / channel_noise_levels[peak_channel_index]
        info0['peak_channel'] = int(peak_channel)  # peak_channel already holds the channel id
        train = sorting.get_unit_spike_train(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(
            len(train) /
            (recording.get_num_frames() / recording.get_sampling_frequency()))
        ret.append(info0)
    return ret
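
A minimal usage sketch for compute_units_info (not part of the original snippet): it assumes the toy_example generator from spikeextractors.example_datasets and the helper functions referenced above (bandpass_filter, compute_unit_templates, compute_channel_noise_levels) are importable in the same module.

import spikeextractors as se

# hypothetical driver: generate a small toy dataset and summarize each unit
rec, sort = se.example_datasets.toy_example(duration=2, num_channels=4, seed=0)
units_info = compute_units_info(recording=rec, sorting=sort, channel_ids=[], unit_ids=[])
for info in units_info:
    # keys set by compute_units_info above: unit_id, snr, peak_channel, num_events, firing_rate
    print(info['unit_id'], round(info['snr'], 2), round(info['firing_rate'], 2))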
Example #3
    def _create_example(self):
        channel_ids = [0, 1, 2, 3]
        num_channels = 4
        num_frames = 10000
        samplerate = 30000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        RX = se.NumpyRecordingExtractor(timeseries=X,
                                        samplerate=samplerate,
                                        geom=geom)
        RX2 = se.NumpyRecordingExtractor(timeseries=X,
                                         samplerate=samplerate,
                                         geom=geom)
        RX3 = se.NumpyRecordingExtractor(timeseries=X,
                                         samplerate=samplerate,
                                         geom=geom)
        SX = se.NumpySortingExtractor()
        spike_times = [200, 300, 400]
        train1 = np.sort(
            np.rint(np.random.uniform(0, num_frames,
                                      spike_times[0])).astype(int))
        SX.add_unit(unit_id=1, times=train1)
        SX.add_unit(unit_id=2,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[1])))
        SX.add_unit(unit_id=3,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[2])))
        SX.set_unit_property(unit_id=1, property_name='stability', value=80)
        SX.set_sampling_frequency(samplerate)
        SX2 = se.NumpySortingExtractor()
        spike_times2 = [100, 150, 450]
        train2 = np.rint(np.random.uniform(0, num_frames,
                                           spike_times2[0])).astype(int)
        SX2.add_unit(unit_id=3, times=train2)
        SX2.add_unit(unit_id=4,
                     times=np.random.uniform(0, num_frames, spike_times2[1]))
        SX2.add_unit(unit_id=5,
                     times=np.random.uniform(0, num_frames, spike_times2[2]))
        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
        RX.set_channel_property(channel_id=0,
                                property_name='location',
                                value=(0, 0))
        example_info = dict(channel_ids=channel_ids,
                            num_channels=num_channels,
                            num_frames=num_frames,
                            samplerate=samplerate,
                            unit_ids=[1, 2, 3],
                            train1=train1,
                            unit_prop=80,
                            channel_prop=(0, 0))

        return (RX, RX2, RX3, SX, SX2, example_info)
Example #4
def test_remove_bad_channels():
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)
    rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
    assert 0 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
    assert 1 not in rec_rm.get_channel_ids() and 2 not in rec_rm.get_channel_ids()

    check_dumping(rec_rm)
    shutil.rmtree('test')

    timeseries = np.random.randn(4, 60000)
    timeseries[1] = 10 * timeseries[1]

    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=30000)
    rec_np.set_channel_locations(np.ones((rec_np.get_num_channels(), 2)))
    se.MdaRecordingExtractor.write_recording(rec_np, 'test')
    rec = se.MdaRecordingExtractor('test')
    rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2)
    assert 1 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)

    rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2, seconds=0.1)
    assert 1 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)

    rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2, seconds=10)
    assert 1 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)
    shutil.rmtree('test')
Example #5
 def setUp(self):
     M = 4
     N = 10000
     N_ttl = 50
     seed = 0
     sampling_frequency = 30000
     X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))
     geom = np.random.RandomState(seed=seed).normal(0, 1, (M, 2))
     self._X = X
     self._geom = geom
     self._sampling_frequency = sampling_frequency
     self.RX = se.NumpyRecordingExtractor(
         timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
     self._ttl_frames = np.sort(np.random.permutation(N)[:N_ttl])
     self.RX.set_ttls(self._ttl_frames)
     self.SX = se.NumpySortingExtractor()
     L = 200
     self._train1 = np.rint(
         np.random.RandomState(seed=seed).uniform(0, N, L)).astype(int)
     self.SX.add_unit(unit_id=1, times=self._train1)
     self.SX.add_unit(unit_id=2,
                      times=np.random.RandomState(seed=seed).uniform(
                          0, N, L))
     self.SX.add_unit(unit_id=3,
                      times=np.random.RandomState(seed=seed).uniform(
                          0, N, L))
Example #6
def gen_synth_datasets(datasets, *, outdir, samplerate=32000):
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        spiketrains = gen_spiketrains(duration=ds['duration'],
                                      n_exc=ds['n_exc'],
                                      n_inh=ds['n_inh'],
                                      f_exc=ds['f_exc'],
                                      f_inh=ds['f_inh'],
                                      min_rate=ds['min_rate'],
                                      st_exc=ds['st_exc'],
                                      st_inh=ds['st_inh'])
        OX = NeoSpikeTrainsOutputExtractor(spiketrains=spiketrains,
                                           samplerate=samplerate)
        X, geom = gen_recording(templates=ds['templates'],
                                output_extractor=OX,
                                noise_level=ds['noise_level'],
                                samplerate=samplerate,
                                duration=ds['duration'])
        IX = si.NumpyRecordingExtractor(timeseries=X,
                                        samplerate=samplerate,
                                        geom=geom)
        si.MdaRecordingExtractor.writeRecording(IX,
                                                outdir + '/{}'.format(ds_name))
        si.MdaSortingExtractor.writeSorting(
            OX, outdir + '/{}/firings_true.mda'.format(ds_name))
    print('Done.')
Example #7
def toy_example(duration=10,
                num_channels=4,
                samplerate=30000.0,
                K=10,
                seed=None):
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(K=K,
                                              duration=duration,
                                              samplerate=samplerate,
                                              seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              samplerate=samplerate,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac)
    SX.set_sampling_frequency(samplerate)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    samplerate=samplerate,
                                    geom=geom)
    return (RX, SX)
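
A hedged usage sketch for the toy_example above (the call and printed values are assumptions based on its signature): it returns an in-memory NumpyRecordingExtractor / NumpySortingExtractor pair that can be inspected directly.

RX, SX = toy_example(duration=10, num_channels=4, samplerate=30000.0, K=10, seed=0)
print(RX.get_num_channels(), RX.get_sampling_frequency())  # 4 channels at 30 kHz
print(SX.get_unit_ids())  # ids of the K synthesized units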
Example #8
def test_remove_bad_channels():
    rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4)
    rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
    assert 0 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
    assert 1 not in rec_rm.get_channel_ids(
    ) and 2 not in rec_rm.get_channel_ids()

    timeseries = np.random.randn(4, 60000)
    timeseries[1] = 3 * timeseries[1]

    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries,
                                        sampling_frequency=30000)
    rec_rm = remove_bad_channels(rec_np, bad_channel_ids=None, bad_threshold=2)
    assert 1 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec_np,
                                 bad_channel_ids=None,
                                 bad_threshold=2,
                                 seconds=0.1)
    assert 1 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec_np,
                                 bad_channel_ids=None,
                                 bad_threshold=2,
                                 seconds=10)
    assert 1 not in rec_rm.get_channel_ids()
Example #9
def create_simulated_recording(size,
                               num_frames=1000,
                               sampling_frequency=30000,
                               seed=0):
    #TODO if centered at 0, 0: two channels at pos 0 if even number
    # channel_pos = [int(coord-(size-1)/2) for coord in range(0, size)]
    channel_pos = [coord for coord in range(0, size)]
    geom = []
    for k in channel_pos:
        for j in channel_pos:
            geom.append([j, k, 0])

    geom = np.asarray(geom)
    channel_ids = np.arange(0, size * size)
    num_channels = len(channel_ids)

    X = np.random.RandomState(seed=seed).normal(0, 1,
                                                (num_channels, num_frames))
    X = (X * 100).astype(int)
    X, spike_frame_channel_array = add_artificial_spikes(X)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)

    return geom, RX, spike_frame_channel_array
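
A brief usage sketch (an assumption; it also relies on add_artificial_spikes, which is referenced above but not shown here): create_simulated_recording lays channels on a size x size grid and returns the geometry, the recording extractor, and the injected spike positions.

geom, RX, spike_frame_channel_array = create_simulated_recording(size=3)
print(geom.shape)             # (9, 3): one x/y/z position per channel
print(RX.get_num_channels())  # 9 channels on the 3 x 3 grid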
Example #10
 def __init__(self, file_path, acquisition_name=None):
     assert HAVE_NWB, "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n"
     self._path = file_path
     self._acquisition_name = acquisition_name
     with NWBHDF5IO(file_path, 'r') as io:
         nwbfile = io.read()
         if acquisition_name is None:
             a_names = list(nwbfile.acquisition.keys())
             if len(a_names) > 1:
                 raise Exception('More than one acquisition found. You must specify acquisition_name.')
             if len(a_names) == 0:
                 raise Exception('No acquisitions found in the .nwb file.')
             acquisition_name = a_names[0]
         ts = nwbfile.acquisition[acquisition_name]
         self._nwb_timeseries = ts
         M = np.array(ts.data).shape[1]
         if M != len(ts.electrodes):
             raise Exception(
                 'Number of electrodes does not match the shape of the data {}<>{}'.format(M, len(ts.electrodes)))
         geom = np.zeros((M, 3))
         for m in range(M):
             geom[m, :] = [ts.electrodes[m][1], ts.electrodes[m][2], ts.electrodes[m][3]]
         if hasattr(ts, 'timestamps') and ts.timestamps:
             sampling_frequency = 1 / (ts.timestamps[1] - ts.timestamps[0])  # there's probably a better way
         else:
             sampling_frequency = ts.rate * 1000
         data = np.copy(np.transpose(ts.data))
         NRX = se.NumpyRecordingExtractor(timeseries=data, sampling_frequency=sampling_frequency, geom=geom)
         CopyRecordingExtractor.__init__(self, NRX)
Example #11
 def setUp(self):
     M = 32
     N = 10000
     samplerate = 30000
     X = np.random.normal(0, 1, (M, N))
     self._X = X
     self._samplerate = samplerate
     self.RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate)
     self.test_dir = tempfile.mkdtemp()
Example #12
 def createSession(self):
     recording = SFMdaRecordingExtractor(
         dataset_directory=self._recording_directory, download=False)
     recording = se.SubRecordingExtractor(
         parent_recording=recording, start_frame=0, end_frame=10000)
     recording = se.NumpyRecordingExtractor(
         timeseries=recording.get_traces(), samplerate=recording.get_sampling_frequency())
     W = SFW.TimeseriesWidget(recording=recording)
     _make_full_browser(W)
     return W
Example #13
 def setUp(self):
     M = 32
     N = 10000
     seed = 0
     sampling_frequency = 30000
     X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))
     self._X = X
     self._sampling_frequency = sampling_frequency
     self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency)
     self.RX.set_channel_locations(np.random.randn(32, 3))
     self.test_dir = tempfile.mkdtemp()
Example #14
def toy_example1(duration=10, num_channels=4, samplerate=30000, K=10, firing_rates=None, noise_level=10):
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac)
    times, labels = synthesize_random_firings(K=K, duration=duration, samplerate=samplerate, firing_rates=firing_rates)
    labels = labels.astype(np.int64)
    OX = se.NumpySortingExtractor()
    OX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=OX, waveforms=waveforms, noise_level=noise_level, samplerate=samplerate, duration=duration,
                              waveform_upsamplefac=upsamplefac)

    IX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    return (IX, OX)
Example #15
def test_highpass_filter():
    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=2,
                                                num_channels=4,
                                                seed=0)

    rec_fft = highpass_filter(rec, freq_min=5000, filter_type='fft')

    assert check_signal_power_signal1_below_signal2(
        rec_fft.get_traces(),
        rec.get_traces(),
        freq_range=[1000, 5000],
        fs=rec.get_sampling_frequency())

    rec_sci = bandpass_filter(rec,
                              freq_min=3000,
                              freq_max=6000,
                              filter_type='butter',
                              order=3)

    assert check_signal_power_signal1_below_signal2(
        rec_sci.get_traces(),
        rec.get_traces(),
        freq_range=[1000, 3000],
        fs=rec.get_sampling_frequency())

    traces = rec.get_traces().astype('uint16')
    rec_u = se.NumpyRecordingExtractor(
        traces, sampling_frequency=rec.get_sampling_frequency())
    rec_fu = bandpass_filter(rec_u,
                             freq_min=5000,
                             freq_max=10000,
                             filter_type='fft')

    assert check_signal_power_signal1_below_signal2(
        rec_fu.get_traces(),
        rec_u.get_traces(),
        freq_range=[1000, 5000],
        fs=rec.get_sampling_frequency())
    assert check_signal_power_signal1_below_signal2(
        rec_fu.get_traces(),
        rec_u.get_traces(),
        freq_range=[10000, 15000],
        fs=rec.get_sampling_frequency())
    assert not str(rec_fu.get_dtype()).startswith('u')

    check_dumping(rec_fft)
    shutil.rmtree('test')
Example #16
 def setUp(self):
     M = 4
     N = 10000
     samplerate = 30000
     X = np.random.normal(0, 1, (M, N))
     geom = np.random.normal(0, 1, (M, 2))
     self._X = X
     self._geom = geom
     self._samplerate = samplerate
     self.RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
     self.SX = se.NumpySortingExtractor()
     L = 200
     self._train1 = np.rint(np.random.uniform(0, N, L)).astype(int)
     self.SX.add_unit(unit_id=1, times=self._train1)
     self.SX.add_unit(unit_id=2, times=np.random.uniform(0, N, L))
     self.SX.add_unit(unit_id=3, times=np.random.uniform(0, N, L))
Example #17
def toy_example(duration=10,
                num_channels=4,
                sampling_frequency=30000.0,
                K=10,
                dumpable=False,
                dump_folder=None,
                seed=None):
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K,
        duration=duration,
        sampling_frequency=sampling_frequency,
        seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac,
                              seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
Example #18
def gen_synth_datasets(datasets,
                       *,
                       outdir,
                       num_channels=4,
                       upsamplefac=13,
                       samplerate=30000,
                       average_peak_amplitude=-100):
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        K = ds['K']
        duration = ds['duration']
        noise_level = ds['noise_level']
        waveforms, geom = synthesize_random_waveforms(
            K=K,
            M=num_channels,
            average_peak_amplitude=average_peak_amplitude,
            upsamplefac=upsamplefac)
        times, labels = synthesize_random_firings(K=K,
                                                  duration=duration,
                                                  samplerate=samplerate)
        labels = labels.astype(np.int64)
        OX = si.NumpySortingExtractor()
        OX.setTimesLabels(times, labels)
        X = synthesize_timeseries(output_extractor=OX,
                                  waveforms=waveforms,
                                  noise_level=noise_level,
                                  samplerate=samplerate,
                                  duration=duration,
                                  waveform_upsamplefac=upsamplefac)
        IX = si.NumpyRecordingExtractor(timeseries=X,
                                        samplerate=samplerate,
                                        geom=geom)
        si.MdaRecordingExtractor.writeRecording(IX,
                                                outdir + '/{}'.format(ds_name))
        si.MdaSortingExtractor.writeSorting(
            OX, outdir + '/{}/firings_true.mda'.format(ds_name))
        # write json with two fields

    print('Done.')
Example #19
    def _create_example(self):
        channel_ids = [0, 1, 2, 3]
        num_channels = 4
        num_frames = 10000
        sampling_frequency = 30000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        RX = se.NumpyRecordingExtractor(timeseries=X,
                                        sampling_frequency=sampling_frequency,
                                        geom=geom)
        RX2 = se.NumpyRecordingExtractor(timeseries=X,
                                         sampling_frequency=sampling_frequency,
                                         geom=geom)
        RX3 = se.NumpyRecordingExtractor(timeseries=X,
                                         sampling_frequency=sampling_frequency,
                                         geom=geom)
        SX = se.NumpySortingExtractor()
        spike_times = [200, 300, 400]
        train1 = np.sort(
            np.rint(np.random.uniform(0, num_frames,
                                      spike_times[0])).astype(int))
        SX.add_unit(unit_id=1, times=train1)
        SX.add_unit(unit_id=2,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[1])))
        SX.add_unit(unit_id=3,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[2])))
        SX.set_unit_property(unit_id=1, property_name='stability', value=80)
        SX.set_sampling_frequency(sampling_frequency)
        SX2 = se.NumpySortingExtractor()
        spike_times2 = [100, 150, 450]
        train2 = np.rint(np.random.uniform(0, num_frames,
                                           spike_times2[0])).astype(int)
        SX2.add_unit(unit_id=3, times=train2)
        SX2.add_unit(unit_id=4,
                     times=np.random.uniform(0, num_frames, spike_times2[1]))
        SX2.add_unit(unit_id=5,
                     times=np.random.uniform(0, num_frames, spike_times2[2]))
        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
        SX2.set_unit_spike_features(unit_id=3,
                                    feature_name='widths',
                                    value=np.asarray([3] * spike_times2[0]))
        RX.set_channel_property(channel_id=0,
                                property_name='location',
                                value=(0, 0))
        for i, unit_id in enumerate(SX2.get_unit_ids()):
            SX2.set_unit_property(unit_id=unit_id,
                                  property_name='shared_unit_prop',
                                  value=i)
            SX2.set_unit_spike_features(unit_id=unit_id,
                                        feature_name='shared_unit_feature',
                                        value=np.asarray([i] *
                                                         spike_times2[i]))
        for i, channel_id in enumerate(RX.get_channel_ids()):
            RX.set_channel_property(channel_id=channel_id,
                                    property_name='shared_channel_prop',
                                    value=i)
        example_info = dict(channel_ids=channel_ids,
                            num_channels=num_channels,
                            num_frames=num_frames,
                            sampling_frequency=sampling_frequency,
                            unit_ids=[1, 2, 3],
                            train1=train1,
                            unit_prop=80,
                            channel_prop=(0, 0))

        return (RX, RX2, RX3, SX, SX2, example_info)
Example #20
def sort_main(task, overwrite_flag=0):
    try:
        save_path = Path(task['save_path'], task['task_type'])

        if (not (save_path / 'recording.dat').exists()) or overwrite_flag:
            # load task data
            data = np.load(task['file_path'])
            with open(task['file_header_path'], 'rb') as f:
                data_info = pickle.load(f)

            # prepare filter
            sos, _ = pp.get_sos_filter_bank(['Sp'], fs=data_info['fs'])
            spk_data = np.zeros_like(data)
            assert data_info['n_chans'] == spk_data.shape[0], "Inconsistent formatting in the data files. Aborting."

            # spk filter (high pass)
            t0 = time.time()
            for ch in range(data_info['n_chans']):
                spk_data[ch] = scipy.signal.sosfiltfilt(sos, data[ch])
                print('', end='.')
            t1 = time.time()
            print('\nTime to spk filter data {0:0.2f}s'.format(t1 - t0))

            chan_masks = pp.create_chan_masks(data_info['Raw']['ClippedSegs'], data_info['n_samps'])
            chan_mad = pp.get_signals_mad(spk_data, chan_masks)
            data_info['Spk'] = {'mad': chan_mad}

            # convert data to spikeinterface format
            spk_data_masked = se.NumpyRecordingExtractor(timeseries=spk_data * chan_masks, geom=data_info['tt_geom'],
                                                         sampling_frequency=data_info['fs'])

            # sort data
            sort = sort_data(spk_data_masked, save_path, sorter=task['task_type'])
            if sort is not None:
                # export data to phy
                st.postprocessing.export_to_phy(recording=spk_data_masked, sorting=sort,
                                                output_folder=str(save_path),
                                                compute_pc_features=False, compute_amplitudes=False,
                                                max_channels_per_template=4)

                # get cluster stats
                spk_times_list = sort.get_units_spike_train()
                cluster_stats = get_cluster_stats(spk_times_list, spk_data_masked.get_traces(), data_info)
                cluster_stats_file_path = Path(save_path, 'cluster_stats.csv')
                cluster_stats.to_csv(cluster_stats_file_path)

                print('Successful sort.')
            else:
                print('Unsuccessful sort.')

            # save header
            updated_file_header_path = Path(task['save_path'], Path(task['file_header_path']).name)
            with updated_file_header_path.open(mode='wb') as file_handle:
                pickle.dump(data_info, file_handle, protocol=pickle.HIGHEST_PROTOCOL)

        else:
            print('Sorting Done and overwrite flag is False, skipping this sort.')
    except KeyboardInterrupt:
        print('Keyboard Interrupt Detected. Aborting Task Processing.')
        sys.exit()

    except:
        print("Error", sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2].tb_lineno)
        traceback.print_exc(file=sys.stdout)
Example #21
    def run(self):
        # This temporary file will automatically be removed even in the case of a python exception
        with TemporaryDirectory() as tmpdir:
            # names of files for the temporary/intermediate data
            filt = tmpdir + '/filt.mda'
            filt2 = tmpdir + '/filt2.mda'
            pre = tmpdir + '/pre.mda'

            print('Bandpass filtering raw -> filt...')
            _bandpass_filter(self.recording_file_in, filt)

            if self.mask_out_artifacts:
                print('Masking out artifacts filt -> filt2...')
                _mask_out_artifacts(filt, filt2)
            else:
                print('Copying filt -> filt2...')
                filt2 = filt

            if self.whiten:
                print('Whitening filt2 -> pre...')
                _whiten(filt2, pre)
            else:
                pre = filt2

            # read the preprocessed timeseries into RAM (maybe we'll do it differently later)
            X = sf.mdaio.readmda(pre)

            # handle the geom
            if isinstance(self.geom_in, str):
                print('Using geom.csv from a file', self.geom_in)
                geom = read_geom_csv(self.geom_in)
            else:
                # no geom file was provided as input
                num_channels = X.shape[0]
                if num_channels > 6:
                    raise Exception(
                        'For more than six channels, we require that a geom.csv be provided')
                # otherwise make a trivial geometry file
                print('Making a trivial geom file.')
                geom = np.zeros((X.shape[0], 2))

            # Now represent the preprocessed recording using a RecordingExtractor
            recording = se.NumpyRecordingExtractor(
                X, samplerate=30000, geom=geom)

            # hard-code this for now -- idea: run many simultaneous jobs, each using only 2 cores
            # important to set certain environment variables in the .sh script that calls this .py script
            num_workers = 2

            # Call MountainSort4
            sorting = ml_ms4alg.mountainsort4(
                recording=recording,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                clip_size=self.clip_size,
                detect_threshold=self.detect_threshold,
                detect_interval=self.detect_interval,
                num_workers=num_workers,
            )

            # Write the firings.mda
            print('Writing firings.mda...')
            sf.SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)

            print('Computing cluster metrics...')
            cluster_metrics_path = tmpdir + '/cluster_metrics.json'
            _cluster_metrics(pre, self.firings_out, cluster_metrics_path)

            print('Computing isolation metrics...')
            isolation_metrics_path = tmpdir + '/isolation_metrics.json'
            pair_metrics_path = tmpdir + '/pair_metrics.json'
            _isolation_metrics(pre, self.firings_out,
                               isolation_metrics_path, pair_metrics_path)

            print('Combining metrics...')
            metrics_path = tmpdir + '/metrics.json'
            _combine_metrics(cluster_metrics_path,
                             isolation_metrics_path, metrics_path)

            # copy metrics.json to the output location
            shutil.copy(metrics_path, self.metrics_out)

            print('Creating label map...')
            label_map_path = tmpdir + '/label_map.mda'
            create_label_map(metrics=metrics_path,
                             label_map_out=label_map_path)

            print('Applying label map...')
            apply_label_map(firings=self.firings_out, label_map=label_map_path,
                            firings_out=self.firings_curated_out)
Example #22
def create_signal_with_known_waveforms(n_channels=4,
                                       n_waveforms=2,
                                       n_wf_samples=100,
                                       duration=5,
                                       fs=30000):
    '''
    Creates a stereotyped recording and sorting with known waveforms, templates, and max channels.
    '''
    a_min = [-200, -50]
    a_max = [10, 50]
    wfs = []

    # gen waveforms
    for w in range(n_waveforms):
        amp_min = np.random.randint(a_min[0], a_min[1])
        amp_max = np.random.randint(a_max[0], a_max[1])

        wf = create_wf(amp_min, amp_max, n_wf_samples)
        wfs.append(wf)

    # gen templates
    templates = []
    max_chans = []
    for wf in wfs:
        found = False
        while not found:
            template, amps, found = generate_template_with_random_amps(
                n_channels, wf)
        templates.append(template)
        max_chans.append(np.argmax(amps))

    templates = np.array(templates)
    n_samples = int(fs * duration)

    # gen spiketrains
    interval = 10 * n_wf_samples
    times = np.arange(interval, duration * fs - interval, interval).astype(int)
    labels = np.zeros(len(times)).astype(int)
    for i, wf in enumerate(wfs):
        labels[i::len(wfs)] = i

    timeseries = np.zeros((n_channels, n_samples))
    waveforms = []
    amplitudes = []
    for i, tem in enumerate(templates):
        idxs = np.where(labels == i)
        wav = []
        amps = []
        for t in times[idxs]:
            rand_val = np.random.randn() * 0.01 + 1
            timeseries[:, t - n_wf_samples // 2:t +
                       n_wf_samples // 2] = rand_val * tem
            wav.append(rand_val * tem)
            amps.append(np.min(rand_val * tem))
        wav = np.array(wav)
        amps = np.array(amps)
        waveforms.append(wav)
        amplitudes.append(amps)

    rec = se.NumpyRecordingExtractor(timeseries=timeseries,
                                     sampling_frequency=fs)
    sort = se.NumpySortingExtractor()
    sort.set_times_labels(times=times, labels=labels)
    sort.set_sampling_frequency(fs)

    return rec, sort, waveforms, templates, max_chans, amplitudes
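
A hedged usage sketch for create_signal_with_known_waveforms (it assumes create_wf and generate_template_with_random_amps, referenced above, are defined in the same module): the returned ground-truth waveforms, templates, max channels and amplitudes correspond unit-by-unit to the NumpySortingExtractor.

rec, sort, waveforms, templates, max_chans, amps = create_signal_with_known_waveforms(
    n_channels=4, n_waveforms=2, n_wf_samples=100, duration=5, fs=30000)
print(rec.get_num_channels(), sort.get_unit_ids())  # 4 channels; unit ids 0 and 1
print(templates.shape)  # expected (n_waveforms, n_channels, n_wf_samples)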
Example #23
def toy_example(duration: float = 10.,
                num_channels: int = 4,
                sampling_frequency: float = 30000.,
                K: int = 10,
                dumpable: bool = False,
                dump_folder: Optional[Union[str, Path]] = None,
                seed: Optional[int] = None):
    """
    Create toy recording and sorting extractors.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, objects are dumped to file and become 'dumpable'
    dump_folder: str or Path
        Path to dump folder (if None, 'toy_example' is used)
    seed: int
        Seed for random initialization

    Returns
    -------
    recording: RecordingExtractor
        The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an
        MdaRecordingExtractor
    sorting: SortingExtractor
        The output sorting extractor. If dumpable is False it's a NumpySortingExtractor, otherwise it's an
        NpzSortingExtractor
    """
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K,
        duration=duration,
        sampling_frequency=sampling_frequency,
        seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac,
                              seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
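
To illustrate the dumpable path described in the docstring, a short sketch (the dump folder name is an assumption): with dumpable=True the extractors are written to disk and reloaded as MdaRecordingExtractor / NpzSortingExtractor, as the tests elsewhere on this page do.

RX, SX = toy_example(duration=2, num_channels=4, dumpable=True, dump_folder='toy_example_dump', seed=0)
print(type(RX).__name__, type(SX).__name__)  # MdaRecordingExtractor NpzSortingExtractor
# remove the dump folder when done, e.g. shutil.rmtree('toy_example_dump')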
Example #24
def test_accuracy_of_denoising():
    # Test the accuracy of denoising
    duration = 10
    num_channels = 4
    sampling_frequency = 30000
    K = 10
    seed = None

    upsamplefac = 13

    waveforms, geom = example_datasets.synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac, seed=seed)
    times, labels = example_datasets.synthesize_random_firings(
        K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)

    SX.set_sampling_frequency(sampling_frequency)

    recordings = []
    for noise_level in [0, 10]:
        X = example_datasets.synthesize_timeseries(
            sorting=SX,
            waveforms=waveforms,
            noise_level=noise_level,
            sampling_frequency=sampling_frequency,
            duration=duration,
            waveform_upsamplefac=upsamplefac,
            seed=seed
        )

        RX = se.NumpyRecordingExtractor(
            timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        
        recordings.append(RX)
    
    recording_without_noise = recordings[0]
    recording_with_noise = recordings[1]
    
    opts = ephys_nlm_v1_opts(
        multi_neighborhood=False,
        block_size_sec=30,
        clip_size=30,
        sigma='auto',
        sigma_scale_factor=1,
        whitening='auto',
        whitening_pctvar=90,
        denom_threshold=30
    )

    recording_denoised, runtime_info = ephys_nlm_v1(
        recording=recording_with_noise,
        opts=opts,
        device=None, # detect from the EPHYS_NLM_DEVICE environment variable
        verbose=2
    )

    traces_with_noise = recording_with_noise.get_traces()
    traces_without_noise = recording_without_noise.get_traces()
    traces_denoised = recording_denoised.get_traces()

    std_noise_before = np.sqrt(np.var(traces_without_noise - traces_with_noise))
    std_noise_after = np.sqrt(np.var(traces_without_noise - traces_denoised))
    print(f'std_noise_before = {std_noise_before}; std_noise_after = {std_noise_after};')

    assert std_noise_after < 0.3 * std_noise_before
Example #25
    def run_conversion(self,
                       nwbfile: NWBFile,
                       metadata: dict,
                       stub_test: bool = False,
                       write_ecephys_metadata: bool = False):
        if 'UnitProperties' not in metadata:
            metadata['UnitProperties'] = []
        if write_ecephys_metadata and 'Ecephys' in metadata:
            n_channels = max(
                [len(x['data']) for x in metadata['Ecephys']['Electrodes']])
            recording = se.NumpyRecordingExtractor(timeseries=np.array(
                range(n_channels)),
                                                   sampling_frequency=1)
            se.NwbRecordingExtractor.add_devices(recording=recording,
                                                 nwbfile=nwbfile,
                                                 metadata=metadata)
            se.NwbRecordingExtractor.add_electrode_groups(recording=recording,
                                                          nwbfile=nwbfile,
                                                          metadata=metadata)
            se.NwbRecordingExtractor.add_electrodes(recording=recording,
                                                    nwbfile=nwbfile,
                                                    metadata=metadata)

        property_descriptions = dict()
        if stub_test:
            max_min_spike_time = max([
                min(x) for y in self.sorting_extractor.get_unit_ids()
                for x in [self.sorting_extractor.get_unit_spike_train(y)]
                if any(x)
            ])
            stub_sorting_extractor = se.SubSortingExtractor(
                self.sorting_extractor,
                unit_ids=self.sorting_extractor.get_unit_ids(),
                start_frame=0,
                end_frame=1.1 * max_min_spike_time)
            sorting_extractor = stub_sorting_extractor
        else:
            sorting_extractor = self.sorting_extractor

        for metadata_column in metadata['UnitProperties']:
            assert len(metadata_column['data']) == len(sorting_extractor.get_unit_ids()), \
                f"The metadata_column '{metadata_column['name']}' data must have the same dimension as the sorting IDs!"

            property_descriptions.update(
                {metadata_column['name']: metadata_column['description']})
            for unit_idx, unit_id in enumerate(
                    sorting_extractor.get_unit_ids()):
                if metadata_column['name'] == 'electrode_group':
                    if nwbfile.electrode_groups:
                        data = nwbfile.electrode_groups[metadata_column['data']
                                                        [unit_idx]]
                        sorting_extractor.set_unit_property(
                            unit_id, metadata_column['name'], data)
                else:
                    data = metadata_column['data'][unit_idx]
                    sorting_extractor.set_unit_property(
                        unit_id, metadata_column['name'], data)

        se.NwbSortingExtractor.write_sorting(
            sorting_extractor,
            property_descriptions=property_descriptions,
            nwbfile=nwbfile)
Example #26
def yuta2nwb(
        session_path='D:/BuzsakiData/SenzaiY/YutaMouse41/YutaMouse41-150903',
        # '/Users/bendichter/Desktop/Buzsaki/SenzaiBuzsaki2017/YutaMouse41/YutaMouse41-150903',
        subject_xls=None,
        include_spike_waveforms=True,
        stub=True,
        cache_spec=True):

    subject_path, session_id = os.path.split(session_path)
    fpath_base = os.path.split(subject_path)[0]
    identifier = session_id
    mouse_number = session_id[9:11]
    if '-' in session_id:
        subject_id, date_text = session_id.split('-')
        b = False
    else:
        subject_id, date_text = session_id.split('b')
        b = True

    if subject_xls is None:
        subject_xls = os.path.join(subject_path,
                                   'YM' + mouse_number + ' exp_sheet.xlsx')
    else:
        if not subject_xls[-4:] == 'xlsx':
            subject_xls = os.path.join(subject_xls,
                                       'YM' + mouse_number + ' exp_sheet.xlsx')

    session_start_time = dateparse(date_text, yearfirst=True)

    df = pd.read_excel(subject_xls)

    subject_data = {}
    for key in [
            'genotype', 'DOB', 'implantation', 'Probe', 'Surgery',
            'virus injection', 'mouseID'
    ]:
        names = df.iloc[:, 0]
        if key in names.values:
            subject_data[key] = df.iloc[np.argmax(names == key), 1]

    if isinstance(subject_data['DOB'], datetime):
        age = session_start_time - subject_data['DOB']
    else:
        age = None

    subject = Subject(subject_id=subject_id,
                      age=str(age),
                      genotype=subject_data['genotype'],
                      species='mouse')

    nwbfile = NWBFile(
        session_description='mouse in open exploration and theta maze',
        identifier=identifier,
        session_start_time=session_start_time.astimezone(),
        file_create_date=datetime.now().astimezone(),
        experimenter='Yuta Senzai',
        session_id=session_id,
        institution='NYU',
        lab='Buzsaki',
        subject=subject,
        related_publications='DOI:10.1016/j.neuron.2016.12.011')

    print('reading and writing raw position data...', end='', flush=True)
    ns.add_position_data(nwbfile, session_path)

    shank_channels = ns.get_shank_channels(session_path)[:8]
    nshanks = len(shank_channels)
    all_shank_channels = np.concatenate(shank_channels)

    print('setting up electrodes...', end='', flush=True)
    hilus_csv_path = os.path.join(fpath_base, 'early_session_hilus_chans.csv')
    lfp_channel = get_reference_elec(subject_xls,
                                     hilus_csv_path,
                                     session_start_time,
                                     session_id,
                                     b=b)

    custom_column = [{
        'name': 'theta_reference',
        'description':
        'this electrode was used to calculate LFP canonical bands',
        'data': all_shank_channels == lfp_channel
    }]
    ns.write_electrode_table(nwbfile,
                             session_path,
                             custom_columns=custom_column,
                             max_shanks=max_shanks)

    print('reading raw electrode data...', end='', flush=True)
    if stub:
        # example recording extractor for fast testing
        xml_filepath = os.path.join(session_path, session_id + '.xml')
        xml_root = et.parse(xml_filepath).getroot()
        acq_sampling_frequency = float(
            xml_root.find('acquisitionSystem').find('samplingRate').text)
        num_channels = 4
        num_frames = 10000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        sre = se.NumpyRecordingExtractor(
            timeseries=X, sampling_frequency=acq_sampling_frequency, geom=geom)
    else:
        nre = se.NeuroscopeRecordingExtractor('{}/{}.dat'.format(
            session_path, session_id))
        sre = se.SubRecordingExtractor(nre, channel_ids=all_shank_channels)

    print('writing raw electrode data...', end='', flush=True)
    se.NwbRecordingExtractor.add_electrical_series(sre, nwbfile)
    print('done.')

    print('reading spiking units...', end='', flush=True)
    if stub:
        spike_times = [200, 300, 400]
        num_frames = 10000
        allshanks = []
        for k in range(nshanks):
            SX = se.NumpySortingExtractor()
            for j in range(len(spike_times)):
                SX.add_unit(unit_id=j + 1,
                            times=np.sort(
                                np.random.uniform(0, num_frames,
                                                  spike_times[j])))
            allshanks.append(SX)
        se_allshanks = se.MultiSortingExtractor(allshanks)
        se_allshanks.set_sampling_frequency(acq_sampling_frequency)
    else:
        se_allshanks = se.NeuroscopeMultiSortingExtractor(session_path,
                                                          keep_mua_units=False)

    electrode_group = []
    for shankn in np.arange(1, nshanks + 1, dtype=int):
        for id in se_allshanks.sortings[shankn - 1].get_unit_ids():
            electrode_group.append(nwbfile.electrode_groups['shank' +
                                                            str(shankn)])

    df_unit_features = get_UnitFeatureCell_features(fpath_base, session_id,
                                                    session_path)

    celltype_names = []
    for celltype_id, region_id in zip(df_unit_features['fineCellType'].values,
                                      df_unit_features['region'].values):
        if celltype_id == 1:
            if region_id == 3:
                celltype_names.append('pyramidal cell')
            elif region_id == 4:
                celltype_names.append('granule cell')
            else:
                raise Exception('unknown type')
        elif not np.isfinite(celltype_id):
            celltype_names.append('missing')
        else:
            celltype_names.append(celltype_dict[celltype_id])

    # Add custom column data into the SortingExtractor so it can be written by the converter
    # Note there is currently a hidden assumption that the way in which the NeuroscopeSortingExtractor
    # merges the cluster IDs matches one-to-one with the get_UnitFeatureCell_features extraction
    property_descriptions = {
        'cell_type': 'name of cell type',
        'global_id': 'global id for cell for entire experiment',
        'shank_id': '0-indexed id of cluster of shank',
        'electrode_group': 'the electrode group that each spike unit came from'
    }
    property_values = {
        'cell_type': celltype_names,
        'global_id': df_unit_features['unitID'].values,
        'shank_id': [x - 2 for x in df_unit_features['unitIDshank'].values],
        # - 2 b/c the get_UnitFeatureCell_features removes 0 and 1 IDs from each shank
        'electrode_group': electrode_group
    }
    for unit_id in se_allshanks.get_unit_ids():
        for property_name in property_descriptions.keys():
            se_allshanks.set_unit_property(
                unit_id, property_name,
                property_values[property_name][unit_id])

    se.NwbSortingExtractor.write_sorting(
        se_allshanks,
        nwbfile=nwbfile,
        property_descriptions=property_descriptions)
    print('done.')

    # Read and write LFP's
    print('reading LFPs...', end='', flush=True)
    lfp_fs, all_channels_lfp_data = ns.read_lfp(session_path, stub=stub)

    lfp_data = all_channels_lfp_data[:, all_shank_channels]
    print('writing LFPs...', flush=True)
    # lfp_data[:int(len(lfp_data)/4)]
    lfp_ts = ns.write_lfp(nwbfile,
                          lfp_data,
                          lfp_fs,
                          name='lfp',
                          description='lfp signal for all shank electrodes')

    # Read and add special environmental electrodes
    for name, channel in special_electrode_dict.items():
        ts = TimeSeries(
            name=name,
            description=
            'environmental electrode recorded inline with neural data',
            data=all_channels_lfp_data[:, channel],
            rate=lfp_fs,
            unit='V',
            #conversion=np.nan,
            resolution=np.nan)
        nwbfile.add_acquisition(ts)

    # compute filtered LFP
    print('filtering LFP...', end='', flush=True)
    all_lfp_phases = []
    for passband in ('theta', 'gamma'):
        lfp_fft = filter_lfp(
            lfp_data[:, all_shank_channels == lfp_channel].ravel(),
            lfp_fs,
            passband=passband)
        lfp_phase, _ = hilbert_lfp(lfp_fft)
        all_lfp_phases.append(lfp_phase[:, np.newaxis])
    data = np.dstack(all_lfp_phases)
    print('done.', flush=True)

    if include_spike_waveforms:
        print('writing waveforms...', end='', flush=True)
        nshanks = min((max_shanks, len(ns.get_shank_channels(session_path))))

        for shankn in np.arange(nshanks, dtype=int) + 1:
            # Get spike activty from .spk file on a per-shank and per-sample basis
            ns.write_spike_waveforms(nwbfile, session_path, shankn, stub=stub)
        print('done.', flush=True)

    # Get the LFP Decomposition Series
    decomp_series = DecompositionSeries(
        name='LFPDecompositionSeries',
        description='Theta and Gamma phase for reference LFP',
        data=data,
        rate=lfp_fs,
        source_timeseries=lfp_ts,
        metric='phase',
        unit='radians')
    decomp_series.add_band(band_name='theta', band_limits=(4, 10))
    decomp_series.add_band(band_name='gamma', band_limits=(30, 80))

    check_module(nwbfile, 'ecephys',
                 'contains processed extracellular electrophysiology data'
                 ).add_data_interface(decomp_series)

    [nwbfile.add_stimulus(x) for x in ns.get_events(session_path)]

    # create epochs corresponding to experiments/environments for the mouse

    sleep_state_fpath = os.path.join(session_path,
                                     '{}--StatePeriod.mat'.format(session_id))

    exist_pos_data = any(
        os.path.isfile(
            os.path.join(session_path, '{}__{}.mat'.format(
                session_id, task_type['name']))) for task_type in task_types)

    if exist_pos_data:
        nwbfile.add_epoch_column('label', 'name of epoch')

    for task_type in task_types:
        label = task_type['name']

        file = os.path.join(session_path, session_id + '__' + label + '.mat')
        if os.path.isfile(file):
            print('loading position for ' + label + '...', end='', flush=True)

            pos_obj = Position(name=label + '_position')

            matin = loadmat(file)
            tt = matin['twhl_norm'][:, 0]
            exp_times = find_discontinuities(tt)

            if 'conversion' in task_type:
                conversion = task_type['conversion']
            else:
                conversion = np.nan

            for pos_type in ('twhl_norm', 'twhl_linearized'):
                if pos_type in matin:
                    pos_data_norm = matin[pos_type][:, 1:]

                    spatial_series_object = SpatialSeries(
                        name=label + '_{}_spatial_series'.format(pos_type),
                        data=H5DataIO(pos_data_norm, compression='gzip'),
                        reference_frame='unknown',
                        conversion=conversion,
                        resolution=np.nan,
                        timestamps=H5DataIO(tt, compression='gzip'))
                    pos_obj.add_spatial_series(spatial_series_object)

            check_module(
                nwbfile, 'behavior',
                'contains processed behavioral data').add_data_interface(
                    pos_obj)
            for i, window in enumerate(exp_times):
                nwbfile.add_epoch(start_time=window[0],
                                  stop_time=window[1],
                                  label=label + '_' + str(i))
            print('done.')

    # there are occasional mismatches between the matlab struct and the neuroscope files
    # regions: 3: 'CA3', 4: 'DG'

    trialdata_path = os.path.join(session_path,
                                  session_id + '__EightMazeRun.mat')
    if os.path.isfile(trialdata_path):
        trials_data = loadmat(trialdata_path)['EightMazeRun']

        trialdatainfo_path = os.path.join(fpath_base, 'EightMazeRunInfo.mat')
        trialdatainfo = [
            x[0] for x in loadmat(trialdatainfo_path)['EightMazeRunInfo'][0]
        ]

        features = trialdatainfo[:7]
        features[:2] = 'start_time', 'stop_time',
        [
            nwbfile.add_trial_column(x, 'description')
            for x in features[4:] + ['condition']
        ]

        for trial_data in trials_data:
            if trial_data[3]:
                cond = 'run_left'
            else:
                cond = 'run_right'
            nwbfile.add_trial(start_time=trial_data[0],
                              stop_time=trial_data[1],
                              condition=cond,
                              error_run=trial_data[4],
                              stim_run=trial_data[5],
                              both_visit=trial_data[6])
    """
    mono_syn_fpath = os.path.join(session_path, session_id+'-MonoSynConvClick.mat')

    matin = loadmat(mono_syn_fpath)
    exc = matin['FinalExcMonoSynID']
    inh = matin['FinalInhMonoSynID']

    #exc_obj = CatCellInfo(name='excitatory_connections',
    #                      indices_values=[], cell_index=exc[:, 0] - 1, indices=exc[:, 1] - 1)
    #module_cellular.add_container(exc_obj)
    #inh_obj = CatCellInfo(name='inhibitory_connections',
    #                      indices_values=[], cell_index=inh[:, 0] - 1, indices=inh[:, 1] - 1)
    #module_cellular.add_container(inh_obj)
    """

    if os.path.isfile(sleep_state_fpath):
        matin = loadmat(sleep_state_fpath)['StatePeriod']

        table = TimeIntervals(name='states',
                              description='sleep states of animal')
        table.add_column(name='label', description='sleep state')

        data = []
        for name in matin.dtype.names:
            for row in matin[name][0][0]:
                data.append({
                    'start_time': row[0],
                    'stop_time': row[1],
                    'label': name
                })
        [
            table.add_row(**row)
            for row in sorted(data, key=lambda x: x['start_time'])
        ]

        check_module(nwbfile, 'behavior',
                     'contains behavioral data').add_data_interface(table)

    print('writing NWB file...', end='', flush=True)
    if stub:
        out_fname = session_path + '_stub.nwb'
    else:
        out_fname = session_path + '.nwb'

    with NWBHDF5IO(out_fname, mode='w') as io:
        io.write(nwbfile, cache_spec=cache_spec)
    print('done.')

    print('testing read...', end='', flush=True)
    # test read
    with NWBHDF5IO(out_fname, mode='r') as io:
        io.read()
    print('done.')