Exemple #1
0
    def test_init_delayed_bands(self):
        """Construct a DecompositionSeries with no bands table, add bands afterwards, and check its fields."""
        expected_band_names = ['alpha', 'beta', 'gamma']
        source = TimeSeries(
            name='dummy timeseries',
            description='desc',
            data=np.ones((3, 3)),
            unit='Volts',
            timestamps=np.ones((3, )),
        )
        decomposition = DecompositionSeries(
            name='LFPSpectralAnalysis',
            description='my description',
            data=np.ones((3, 3, 3)),
            timestamps=np.ones((3, )),
            source_timeseries=source,
            metric='amplitude',
        )
        # Bands are appended one by one after construction instead of being
        # supplied through the ``bands`` constructor argument.
        for label in expected_band_names:
            decomposition.add_band(
                band_name=label,
                band_limits=(1., 1.),
                band_mean=1.,
                band_stdev=1.,
            )

        # Constructor arguments must round-trip unchanged.
        self.assertEqual(decomposition.name, 'LFPSpectralAnalysis')
        self.assertEqual(decomposition.description, 'my description')
        np.testing.assert_equal(decomposition.data, np.ones((3, 3, 3)))
        np.testing.assert_equal(decomposition.timestamps, np.ones((3, )))
        # The three add_band calls must show up as rows of the bands table.
        self.assertEqual(decomposition.bands['band_name'].data,
                         expected_band_names)
        np.testing.assert_equal(decomposition.bands['band_limits'].data,
                                np.ones((3, 2)))
        self.assertEqual(decomposition.source_timeseries, source)
        self.assertEqual(decomposition.metric, 'amplitude')
Exemple #2
0
    def test_init(self):
        """Construct a DecompositionSeries with a prebuilt bands table and check every field."""
        source = TimeSeries(
            name='dummy timeseries',
            description='desc',
            data=np.ones((3, 3)),
            unit='Volts',
            timestamps=np.ones((3, )),
        )
        # The bands table is assembled up front from two explicit columns.
        name_column = VectorData(
            name='band_name',
            description='name of bands',
            data=['alpha', 'beta', 'gamma'],
        )
        limits_column = VectorData(
            name='band_limits',
            description='low and high cutoffs in Hz',
            data=np.ones((3, 2)),
        )
        band_table = DynamicTable(
            name='bands',
            description='band info for LFPSpectralAnalysis',
            columns=[name_column, limits_column],
        )
        decomposition = DecompositionSeries(
            name='LFPSpectralAnalysis',
            description='my description',
            data=np.ones((3, 3, 3)),
            timestamps=np.ones((3, )),
            source_timeseries=source,
            metric='amplitude',
            bands=band_table,
        )

        # Constructor arguments must round-trip unchanged.
        self.assertEqual(decomposition.name, 'LFPSpectralAnalysis')
        self.assertEqual(decomposition.description, 'my description')
        np.testing.assert_equal(decomposition.data, np.ones((3, 3, 3)))
        np.testing.assert_equal(decomposition.timestamps, np.ones((3, )))
        # Band columns are reachable through the series' bands table.
        self.assertEqual(decomposition.bands['band_name'].data,
                         ['alpha', 'beta', 'gamma'])
        np.testing.assert_equal(decomposition.bands['band_limits'].data,
                                np.ones((3, 2)))
        self.assertEqual(decomposition.source_timeseries, source)
        self.assertEqual(decomposition.metric, 'amplitude')
    def setUp(self):
        """Create the ``self.ds`` fixture: an amplitude DecompositionSeries over random (160, 2, 3) data."""
        self.ds = DecompositionSeries(
            name='Test Decomposition',
            data=np.random.rand(160, 2, 3),
            metric='amplitude',
            rate=1.0,
        )
Exemple #4
0
    def setUpContainer(self):
        """Build and return the DecompositionSeries under test.

        The source TimeSeries is kept on ``self.timeseries``; only the
        DecompositionSeries is returned.
        """
        self.timeseries = TimeSeries(
            name='dummy timeseries',
            description='desc',
            data=np.ones((3, 3)),
            unit='flibs',
            timestamps=np.ones((3, )),
        )
        name_column = VectorData(
            name='band_name',
            description='name of bands',
            data=['alpha', 'beta', 'gamma'],
        )
        limits_column = VectorData(
            name='band_limits',
            description='low and high cutoffs in Hz',
            data=np.ones((3, 2)),
        )
        band_table = DynamicTable(
            name='bands',
            description='band info for LFPSpectralAnalysis',
            columns=[name_column, limits_column],
        )
        return DecompositionSeries(
            name='LFPSpectralAnalysis',
            description='my description',
            data=np.ones((3, 3, 3)),
            timestamps=np.ones((3, )),
            source_timeseries=self.timeseries,
            metric='amplitude',
            bands=band_table,
        )
Exemple #5
0
 def test_add_decomposition_series(self):
     """Check that a DecompositionSeries can be added to an LFP container.

     There are no assertions: the test passes when nothing raises.
     """
     container = LFP()
     source = TimeSeries(
         name='dummy timeseries',
         description='desc',
         data=np.ones((3, 3)),
         unit='Volts',
         timestamps=np.ones((3, )),
     )
     decomposition = DecompositionSeries(
         name='LFPSpectralAnalysis',
         description='my description',
         data=np.ones((3, 3, 3)),
         timestamps=np.ones((3, )),
         source_timeseries=source,
         metric='amplitude',
     )
     container.add_decomposition_series(decomposition)
Exemple #6
0
    def test_init_with_source_channels(self):
        """Check that ``source_channels`` given to the constructor is exposed as the very same object."""
        # The region below reads ``self.table`` right after this call.
        self.make_electrode_table(self)
        channel_region = DynamicTableRegion(
            name='source_channels',
            data=[0, 2],
            description='the first and third electrodes',
            table=self.table,
        )
        series = DecompositionSeries(
            name='test_DS',
            data=np.random.randn(100, 2, 30),
            source_channels=channel_region,
            timestamps=np.arange(100)/100,
            metric='amplitude',
        )

        # Identity, not just equality: the region must not be copied.
        self.assertIs(series.source_channels, channel_region)
Exemple #7
0
 def setUpContainer(self):
     """ Return the test DecompositionSeries (with source channels) to read/write """
     # The region below reads ``self.table`` right after this call.
     self.make_electrode_table(self)
     channel_region = DynamicTableRegion(
         name='source_channels',
         data=[0, 2],
         description='the first and third electrodes',
         table=self.table,
     )
     return DecompositionSeries(
         name='test_DS',
         data=np.random.randn(100, 2, 30),
         source_channels=channel_region,
         timestamps=np.arange(100) / 100,
         metric='amplitude',
     )
    def convert_data(self,
                     nwbfile: NWBFile,
                     metadata_dict: dict,
                     stub_test: bool = False):
        """Write LFP, special-electrode traces, LFP phase bands and spike waveforms to ``nwbfile``.

        Parameters
        ----------
        nwbfile : NWBFile
            File object that receives all converted data.
        metadata_dict : dict
            Must provide the keys ``all_shank_channels``, ``special_electrodes``
            (an iterable of dicts with ``name``, ``description`` and ``channel``),
            ``lfp_channel``, ``lfp_sampling_rate``, ``spikes_nsamples``,
            ``shank_channels`` and ``lfp`` (with ``name`` and ``description``).
            When ``lfp_channel`` is not None, ``lfp_decomposition`` (with
            ``name`` and ``description``) is required as well.
        stub_test : bool
            Forwarded to ``read_lfp`` and ``write_spike_waveforms``.

        Raises
        ------
        KeyError
            If one of the required metadata keys is missing.
        """
        session_path = self.input_args['folder_path']
        # TODO: check/enforce format?
        # All required metadata is read up front so a missing key fails before any I/O.
        all_shank_channels = metadata_dict['all_shank_channels']
        special_electrodes = metadata_dict['special_electrodes']
        lfp_channel = metadata_dict['lfp_channel']
        lfp_sampling_rate = metadata_dict['lfp_sampling_rate']
        spikes_nsamples = metadata_dict['spikes_nsamples']
        shank_channels = metadata_dict['shank_channels']

        _, all_channels_lfp_data = read_lfp(session_path, stub=stub_test)
        lfp_data = all_channels_lfp_data[:, all_shank_channels]
        lfp_ts = write_lfp(nwbfile,
                           lfp_data,
                           lfp_sampling_rate,
                           name=metadata_dict['lfp']['name'],
                           description=metadata_dict['lfp']['description'],
                           electrode_inds=None)

        # TODO: error checking on format?
        # Special electrodes are taken from the full recording, not the shank subset.
        for special_electrode in special_electrodes:
            nwbfile.add_acquisition(
                TimeSeries(
                    name=special_electrode['name'],
                    description=special_electrode['description'],
                    data=all_channels_lfp_data[:, special_electrode['channel']],
                    rate=lfp_sampling_rate,
                    unit='V',
                    resolution=np.nan))

        # TODO: discuss/consider more robust checking well prior to this
        # when missing experimental sheets for a subject, the lfp_channel cannot be determined(?)
        # which causes uninformative downstream errors at this step because lfp_channel is None
        # (get_reference_electrode does throw a warning, though)
        if lfp_channel is not None:
            # The reference trace does not depend on the passband, so it is
            # extracted once instead of once per band.
            # NOTE(review): the boolean mask assumes all_shank_channels is a
            # NumPy array - confirm against the metadata producer.
            reference_trace = lfp_data[:, all_shank_channels == lfp_channel].ravel()
            all_lfp_phases = []
            for passband in ('theta', 'gamma'):
                lfp_fft = filter_lfp(reference_trace,
                                     lfp_sampling_rate,
                                     passband=passband)
                lfp_phase, _ = hilbert_lfp(lfp_fft)
                all_lfp_phases.append(lfp_phase[:, np.newaxis])
            # Stack the per-band columns along a new third axis.
            decomp_series_data = np.dstack(all_lfp_phases)

            # TODO: should units or metrics be metadata?
            decomp_series = DecompositionSeries(
                name=metadata_dict['lfp_decomposition']['name'],
                description=metadata_dict['lfp_decomposition']['description'],
                data=decomp_series_data,
                rate=lfp_sampling_rate,
                source_timeseries=lfp_ts,
                metric='phase',
                unit='radians')
            # TODO: the band limits should be extracted from parse_passband in band_analysis?
            decomp_series.add_band(band_name='theta', band_limits=(4, 10))
            decomp_series.add_band(band_name='gamma', band_limits=(30, 80))

            check_module(
                nwbfile, 'ecephys',
                'contains processed extracellular electrophysiology data'
            ).add_data_interface(decomp_series)

        write_spike_waveforms(nwbfile,
                              session_path,
                              spikes_nsamples=spikes_nsamples,
                              shank_channels=shank_channels,
                              stub_test=stub_test)
Exemple #9
0
def store_wavelet_transform(elec_series, processing, filters='rat', hg_only=True, X_fft_h=None,
                            abs_only=True, npad=1000, post_resample_rate=None):
    """Apply a wavelet transform using a prespecified set of filters. Results are stored in the
    NWB file as a `DecompositionSeries`.

    Calculates the center frequencies and bandwidths for the wavelets and applies them along with
    a Heaviside function to the fft of the signal before performing an inverse fft. The center
    frequencies and bandwidths are also stored in the NWB file.

    Parameters
    ----------
    elec_series : ElectricalSeries
        ElectricalSeries to process.
    processing : Processing module
        NWB Processing module to save processed data.
    filters : str (optional)
        Which type of filters to use. Options are
        'rat': center frequencies spanning 2-1200 Hz, constant Q, 54 bands
        'human': center frequencies spanning 4-200 Hz, constant Q, 40 bands
        'changlab': center frequencies spanning 4-200 Hz, variable Q, 40 bands
    hg_only : bool
        If True, only the amplitudes in the high gamma range [70-150 Hz] are computed.
    X_fft_h : ndarray (n_time, n_channels)
        Precomputed product of X_fft and Heaviside.
    abs_only : bool
        If True, only the amplitude is stored.
    npad : int
        Padding to add to beginning and end of timeseries. Default 1000.
    post_resample_rate : float
        If not `None`, resample the computed wavelet amplitudes to this rate.
        The stored amplitude series then holds the resampled data at this rate.

    Returns
    -------
    X_wvlt : ndarray, complex
        Complex wavelet coefficients (never resampled).
    series : list of DecompositionSeries
        List of NWB objects that were added to `processing`.

    Raises
    ------
    ValueError
        If `abs_only` is False while `post_resample_rate` is given; the phase
        cannot be resampled. Raised before any computation is done.
    """
    # Fail fast: this combination is rejected, so do not run the transform first.
    if not abs_only and post_resample_rate is not None:
        raise ValueError('Wavelet phase should not be resampled.')

    X = elec_series.data[:]
    rate = elec_series.rate
    X_wvlt, _, cfs, sds = wavelet_transform(X, rate, filters=filters, X_fft_h=X_fft_h,
                                            hg_only=hg_only, npad=npad)
    amplitude = abs(X_wvlt)
    if post_resample_rate is not None:
        amplitude = resample(amplitude, post_resample_rate, rate)
        rate = post_resample_rate

    # Spelling of 'Wavlet' is kept as-is because it is written into the file.
    description = 'Wavlet: ' + elec_series.description

    # Store `amplitude` (possibly resampled) so that data and `rate` agree;
    # storing abs(X_wvlt) here would pair un-resampled data with the new rate.
    elec_series_wvlt_amp = DecompositionSeries('wvlt_amp_' + elec_series.name,
                                               amplitude,
                                               metric='amplitude',
                                               source_timeseries=elec_series,
                                               starting_time=elec_series.starting_time,
                                               rate=rate,
                                               description=description)
    series = [elec_series_wvlt_amp]
    if not abs_only:
        # No resampling happened on this path, so `rate` is the source rate.
        elec_series_wvlt_phase = DecompositionSeries('wvlt_phase_' + elec_series.name,
                                                     np.angle(X_wvlt),
                                                     metric='phase',
                                                     source_timeseries=elec_series,
                                                     starting_time=elec_series.starting_time,
                                                     rate=rate,
                                                     description=description)
        series.append(elec_series_wvlt_phase)

    for es in series:
        # One band row per wavelet; band_limits is always stored as (-1, -1),
        # presumably a placeholder since only center and width are known here.
        for ii, (cf, sd) in enumerate(zip(cfs, sds)):
            es.add_band(band_name=str(ii), band_mean=cf,
                        band_stdev=sd, band_limits=(-1, -1))

        processing.add(es)
    return X_wvlt, series
Exemple #10
0
def spectral_decomposition(block_path, bands_vals):
    """
    Estimate the band-wise analytic amplitude of the preprocessed LFP and
    store it in the NWB file as a DecompositionSeries.

    Each channel of the 'preprocessed' ElectricalSeries (processing module
    'ecephys', interface 'LFP') is combined with one Gaussian kernel per band
    in ``hilbert_transform``; the absolute value of the result is kept.

    Parameters
    ----------
    block_path : str
        Path of the NWB file; it is opened in 'r+' mode and written back.
    bands_vals : [2,nBands] numpy array with Gaussian filter parameters, where:
        bands_vals[0,:] = filter centers [Hz]
        bands_vals[1,:] = filter sigmas [Hz]

    Returns
    -------
    None. The result is added to the 'ecephys' processing module under the
    name 'DecompositionSeries' and the file is written in place.

    Notes
    -----
    NOTE(review): no check for an already existing 'DecompositionSeries'
    container is made here - confirm how callers guard against re-runs.
    """

    # Split the filter parameters into their two rows.
    centers = bands_vals[0, :]
    sigmas = bands_vals[1, :]

    with NWBHDF5IO(block_path, 'r+', load_namespaces=True) as io:
        nwb = io.read()
        ecephys_module = nwb.processing['ecephys']
        lfp = ecephys_module.data_interfaces['LFP'].electrical_series[
            'preprocessed']
        rate = lfp.rate

        n_bands = len(centers)
        n_samples, n_channels = lfp.data.shape[0], lfp.data.shape[1]
        # Filled as (n_bands, n_channels, n_samples); axes are swapped below.
        amplitudes = np.zeros((n_bands, n_channels, n_samples))

        # Apply Hilbert transform ----------------------------------------------
        print('Running Spectral Decomposition...')
        start = time.time()
        for ch in np.arange(n_channels):
            # 1e6 scaling helps with numerical accuracy; signal is (1, n_samples).
            # NOTE(review): the result is stored with unit='V' although the
            # input was scaled by 1e6 - confirm the intended units.
            signal = (lfp.data[:, ch] * 1e6).reshape(1, -1).astype('float32')
            # The second value returned by hilbert_transform is fed back in
            # for the next band; it is reset for every channel.
            fft_cache = None
            for band_idx, (center, sigma) in enumerate(zip(centers, sigmas)):
                kernel = gaussian(signal, rate, center, sigma)
                analytic, fft_cache = hilbert_transform(signal,
                                                        rate,
                                                        kernel,
                                                        phase=None,
                                                        X_fft_h=fft_cache)
                amplitudes[band_idx, ch, :] = abs(analytic).astype('float32')
        print('Spectral Decomposition finished in {} seconds'.format(
            time.time() - start))

        # data: (ndarray) dims: num_times * num_channels * num_bands
        amplitudes = np.swapaxes(amplitudes, 0, 2)

        # bands: (DynamicTable) one column per row of bands_vals
        band_columns = [
            VectorData(name='filter_param_0',
                       description='frequencies for bandpass filters',
                       data=centers),
            VectorData(name='filter_param_1',
                       description='frequencies for bandpass filters',
                       data=sigmas),
        ]
        band_table = DynamicTable(
            name='bands',
            description='Series of filters used for Hilbert transform.',
            columns=band_columns,
            colnames=['filter_param_0', 'filter_param_1'])
        decomposition = DecompositionSeries(
            name='DecompositionSeries',
            data=amplitudes,
            description='Analytic amplitude estimated with Hilbert transform.',
            metric='amplitude',
            unit='V',
            bands=band_table,
            rate=rate,
            source_timeseries=lfp)

        # Storage of spectral decomposition on NWB file ------------------------
        ecephys_module.add_data_interface(decomposition)
        io.write(nwb)
        print('Spectral decomposition saved in ' + block_path)
Exemple #11
0
def _copy_electrical_series(obj_old, nwb_new):
    """Rebuild an ElectricalSeries whose electrodes point into ``nwb_new.electrodes``."""
    # The channel count is taken from the old electrode table's 'x' column.
    n_channels = obj_old.electrodes.table['x'].data.shape[0]
    region = nwb_new.electrodes.create_region(
        name='electrodes',
        region=np.arange(n_channels).tolist(),
        description=''
    )
    return ElectricalSeries(
        name=obj_old.name,
        data=obj_old.data[:],
        electrodes=region,
        rate=obj_old.rate,
        description=obj_old.description
    )


def _copy_dynamic_table(obj_old):
    """Create a DynamicTable with the same name, description and columns."""
    # The column objects themselves are passed on, not duplicated.
    return DynamicTable(
        name=obj_old.name,
        description=obj_old.description,
        colnames=obj_old.colnames,
        columns=obj_old.columns,
    )


def _copy_lfp(obj_old, nwb_new):
    """Rebuild an LFP container holding a copy of its single ElectricalSeries."""
    lfp_new = LFP(name=obj_old.name)
    assert len(obj_old.electrical_series) == 1, (
            'Expected precisely one electrical series, got %i!' %
            len(obj_old.electrical_series))
    els = list(obj_old.electrical_series.values())[0]
    n_channels = els.data.shape[1]

    # Prefer a table of the same name among the new file's 'ecephys' data
    # interfaces; otherwise fall back to the new file's electrodes table.
    interfaces = nwb_new.processing['ecephys'].data_interfaces
    table_name = els.electrodes.table.name
    region_table = (interfaces[table_name]
                    if table_name in interfaces else nwb_new.electrodes)

    region = region_table.create_region(
        name='electrodes',
        region=list(range(n_channels)),
        description=els.electrodes.description
    )

    # The series is attached to lfp_new by the call itself.
    lfp_new.create_electrical_series(
        name=els.name,
        comments=els.comments,
        conversion=els.conversion,
        data=els.data[:],
        description=els.description,
        electrodes=region,
        rate=els.rate,
        resolution=els.resolution,
        starting_time=els.starting_time
    )
    return lfp_new


def _copy_time_series(obj_old):
    """Rebuild a plain TimeSeries from the old object's fields."""
    return TimeSeries(
        name=obj_old.name,
        description=obj_old.description,
        data=obj_old.data[:],
        rate=obj_old.rate,
        resolution=obj_old.resolution,
        conversion=obj_old.conversion,
        starting_time=obj_old.starting_time,
        unit=obj_old.unit
    )


def _copy_decomposition_series(obj_old):
    """Rebuild a DecompositionSeries together with a fresh copy of its bands table."""
    band_columns = [
        VectorData(
            name=column.name,
            description=column.description,
            data=column.data[:]
        )
        for column in obj_old.bands.columns
    ]
    band_table = DynamicTable(
        name=obj_old.bands.name,
        description=obj_old.bands.description,
        columns=band_columns,
        colnames=obj_old.bands.colnames
    )
    # source_timeseries is deliberately not carried over.
    return DecompositionSeries(
        name=obj_old.name,
        data=obj_old.data[:],
        description=obj_old.description,
        metric=obj_old.metric,
        unit=obj_old.unit,
        rate=obj_old.rate,
        bands=band_table,
    )


def _copy_spectrum(obj_old, nwb_new):
    """Rebuild a Spectrum whose electrodes cover every row of ``nwb_new.electrodes``."""
    file_elecs = nwb_new.electrodes
    n_channels = len(file_elecs['x'].data[:])
    region = file_elecs.create_region(
        name='electrodes',
        region=np.arange(n_channels).tolist(),
        description=''
    )
    # Note: ``power`` is handed over as-is while ``frequencies`` is sliced.
    return Spectrum(
        name=obj_old.name,
        frequencies=obj_old.frequencies[:],
        power=obj_old.power,
        electrodes=region
    )


def copy_obj(obj_old, nwb_old, nwb_new):
    """ Creates a copy of obj_old.

    Dispatch is on the exact type of ``obj_old`` (subclasses do not match).
    Handled types are ElectricalSeries, DynamicTable, LFP, TimeSeries,
    DecompositionSeries and Spectrum; any other type yields None.
    ``nwb_old`` is accepted but not used.
    """
    kind = type(obj_old)
    if kind is ElectricalSeries:
        return _copy_electrical_series(obj_old, nwb_new)
    if kind is DynamicTable:
        return _copy_dynamic_table(obj_old)
    if kind is LFP:
        return _copy_lfp(obj_old, nwb_new)
    if kind is TimeSeries:
        return _copy_time_series(obj_old)
    if kind is DecompositionSeries:
        return _copy_decomposition_series(obj_old)
    if kind is Spectrum:
        return _copy_spectrum(obj_old, nwb_new)
    return None
Exemple #12
0
    def run_conversion(self, nwbfile: NWBFile, metadata: dict, stub_test: bool = False):
        """Run the parent conversion, then add special-electrode traces and LFP phase bands.

        The session folder is the parent of ``self.source_data["file_path"]``;
        channel count, LFP sampling rate and shank channel groups are read
        from the ``<session_id>.xml`` file inside it. The theta/gamma phase
        DecompositionSeries is only written when ``get_reference_elec``
        returns a channel (not None).
        """
        super().run_conversion(nwbfile=nwbfile, metadata=metadata, stub_test=stub_test)

        session_path = Path(self.source_data["file_path"]).parent
        session_id = session_path.name
        subject_path = session_path.parent

        # Recording layout from the session XML.
        root = et.parse(str(session_path / f"{session_id}.xml")).getroot()
        n_total_channels = int(root.find("acquisitionSystem").find("nChannels").text)
        lfp_sampling_rate = float(root.find("fieldPotentials").find("lfpSamplingRate").text)
        channel_groups = root.find("spikeDetection").find("channelGroups").findall("group")
        all_shank_channels = np.concatenate(
            [[int(channel.text) for channel in group.find("channels")] for group in channel_groups]
        )  # Flattened

        # Special electrodes: keep only those whose channel index exists in this recording.
        special_electrode_mapping = {
            "ch_wait": 79,
            "ch_arm": 78,
            "ch_solL": 76,
            "ch_solR": 77,
            "ch_dig1": 65,
            "ch_dig2": 68,
            "ch_entL": 72,
            "ch_entR": 71,
            "ch_SsolL": 73,
            "ch_SsolR": 70,
        }
        recorded_special = [
            (electrode_name, channel)
            for electrode_name, channel in special_electrode_mapping.items()
            if channel < n_total_channels
        ]
        _, all_channels_lfp_data = read_lfp(session_path, stub=stub_test)
        for electrode_name, channel in recorded_special:
            nwbfile.add_acquisition(
                TimeSeries(
                    name=electrode_name,
                    description="Environmental electrode recorded inline with neural data.",
                    data=all_channels_lfp_data[:, channel],
                    rate=lfp_sampling_rate,
                    unit="V",
                    resolution=np.nan,
                )
            )

        # DecompositionSeries
        mouse_number = session_id[-9:-7]
        subject_xls = str(subject_path / f"DGProject/YM{mouse_number} exp_sheet.xlsx")
        hilus_csv_path = str(subject_path / "DGProject/early_session_hilus_chans.csv")
        session_start = metadata["NWBFile"]["session_start_time"]
        # ``b`` is True exactly when the session id contains no dash.
        lfp_channel = get_reference_elec(
            subject_xls, hilus_csv_path, session_start, session_id, b="-" not in session_id
        )
        if lfp_channel is None:
            return

        lfp_data = all_channels_lfp_data[:, all_shank_channels]
        all_lfp_phases = []
        for passband in ("theta", "gamma"):
            lfp_fft = filter_lfp(
                lfp_data[:, all_shank_channels == lfp_channel].ravel(),
                lfp_sampling_rate,
                passband=passband,
            )
            lfp_phase, _ = hilbert_lfp(lfp_fft)
            all_lfp_phases.append(lfp_phase[:, np.newaxis])
        # One column per passband, stacked along a new third axis.
        decomp_series_data = np.dstack(all_lfp_phases)

        ecephys_mod = check_module(
            nwbfile,
            "ecephys",
            "Intermediate data from extracellular electrophysiology recordings, e.g., LFP.",
        )
        decomp_series = DecompositionSeries(
            name="LFPDecompositionSeries",
            description="Theta and Gamma phase for reference LFP",
            data=decomp_series_data,
            rate=lfp_sampling_rate,
            source_timeseries=ecephys_mod.data_interfaces["LFP"]["LFP"],
            metric="phase",
            unit="radians",
        )
        # TODO: the band limits should be extracted from parse_passband in band_analysis?
        for band_name, band_limits in (("theta", (4, 10)), ("gamma", (30, 80))):
            decomp_series.add_band(band_name=band_name, band_limits=band_limits)
        check_module(
            nwbfile,
            "ecephys",
            "Contains processed extracellular electrophysiology data.",
        ).add(decomp_series)
Exemple #13
0
def copy_obj(obj_old, nwb_old, nwb_new):
    """ Creates a copy of obj_old.

    The branch is selected by the class *name* of ``obj_old``. Handled names
    are 'ElectricalSeries', 'LFP', 'TimeSeries', 'DecompositionSeries' and
    'Spectrum'; any other object yields None. ``nwb_old`` is accepted but
    not used.
    """
    kind = type(obj_old).__name__

    def _leading_electrodes_region(n_channels):
        # Region over rows 0..n_channels-1 of the new file's electrodes table.
        return nwb_new.electrodes.create_region(
            name='electrodes',
            region=np.arange(n_channels).tolist(),
            description='')

    #ElectricalSeries ----------------------------------------------------------
    if kind == 'ElectricalSeries':
        # Channel count comes from the old electrode table's 'x' column.
        region = _leading_electrodes_region(
            obj_old.electrodes.table['x'].data.shape[0])
        return ElectricalSeries(name=obj_old.name,
                                data=obj_old.data[:],
                                electrodes=region,
                                rate=obj_old.rate,
                                description=obj_old.description)

    #LFP -----------------------------------------------------------------------
    if kind == 'LFP':
        lfp_new = LFP(name=obj_old.name)
        # Only the first electrical series of the container is copied.
        first_name = list(obj_old.electrical_series.keys())[0]
        els = obj_old.electrical_series[first_name]
        region = _leading_electrodes_region(els.data.shape[1])
        lfp_new.create_electrical_series(name=els.name,
                                         comments=els.comments,
                                         conversion=els.conversion,
                                         data=els.data[:],
                                         description=els.description,
                                         electrodes=region,
                                         rate=els.rate,
                                         resolution=els.resolution,
                                         starting_time=els.starting_time)
        return lfp_new

    #TimeSeries ----------------------------------------------------------------
    if kind == 'TimeSeries':
        return TimeSeries(name=obj_old.name,
                          description=obj_old.description,
                          data=obj_old.data[:],
                          rate=obj_old.rate,
                          resolution=obj_old.resolution,
                          conversion=obj_old.conversion,
                          starting_time=obj_old.starting_time,
                          unit=obj_old.unit)

    #DecompositionSeries -------------------------------------------------------
    if kind == 'DecompositionSeries':
        band_columns = [
            VectorData(name=column.name,
                       description=column.description,
                       data=column.data[:])
            for column in obj_old.bands.columns
        ]
        band_table = DynamicTable(name=obj_old.bands.name,
                                  description=obj_old.bands.description,
                                  columns=band_columns,
                                  colnames=obj_old.bands.colnames)
        # source_timeseries is deliberately not carried over.
        return DecompositionSeries(
            name=obj_old.name,
            data=obj_old.data[:],
            description=obj_old.description,
            metric=obj_old.metric,
            unit=obj_old.unit,
            rate=obj_old.rate,
            bands=band_table,
        )

    #Spectrum ------------------------------------------------------------------
    if kind == 'Spectrum':
        # Here the channel count comes from the NEW file's electrodes table.
        region = _leading_electrodes_region(
            len(nwb_new.electrodes['x'].data[:]))
        return Spectrum(name=obj_old.name,
                        frequencies=obj_old.frequencies[:],
                        power=obj_old.power,
                        electrodes=region)

    # Unhandled type.
    return None
Exemple #14
0
 def create_acquisition(self):
     """
     Add every entry of ``self.nwb_metadata['Acquisition']`` to ``self.nwbfile``.

     Acquisition data like audiospectrogram(raw beh data), nidq(raw ephys data), raw camera data.
     These are independent of probe type.

     Handling depends on the neurodata type name of each metadata entry:

     * ``'ImageSeries'``: one external-file ``ImageSeries`` named
       ``'camera_raw'`` is added per (data, timestamps) pair.
     * ``'DecompositionSeries'``: the ``bands`` values become the ``'freq'``
       column of a ``DynamicTable``, the data gets a singleton middle axis,
       and the timestamps are replaced by a starting time and a sampling
       rate derived from the mean timestamp difference.
     * ``'ElectricalSeries'``: ``'raw.lf'`` / ``'raw.ap'`` are written once
       per probe through a chunked, compressed iterator; ``'raw.nidq'`` is
       passed to ``ElectricalSeries`` unchanged. Other series names are
       skipped.

     Any other neurodata type name is ignored.
     """
     for neurodata_type_name, neurodata_type_args_list in self.nwb_metadata[
             'Acquisition'].items():
         # Resolve the metadata argument dicts through self._get_data before
         # building any NWB objects from them.
         data_retrieved_args_list = self._get_data(neurodata_type_args_list)
         for neurodata_type_args in data_retrieved_args_list:
             if neurodata_type_name == 'ImageSeries':
                 # 'data' and 'timestamps' are iterated in lockstep; each data
                 # item is stringified and used as the external file entry.
                 # NOTE(review): every series gets the same name 'camera_raw';
                 # confirm that more than one pair is never expected here.
                 for types, times in zip(neurodata_type_args['data'],
                                         neurodata_type_args['timestamps']):
                     customargs = dict(name='camera_raw',
                                       external_file=[str(types)],
                                       format='external',
                                       timestamps=times,
                                       unit='n.a.')
                     self.nwbfile.add_acquisition(ImageSeries(**customargs))
             elif neurodata_type_name == 'DecompositionSeries':
                 # Drop singleton dimensions so the band values can be used as
                 # a table column with one row per entry along axis 0.
                 neurodata_type_args['bands'] = np.squeeze(
                     neurodata_type_args['bands'])
                 freqs = DynamicTable(
                     'bands',
                     'spectogram frequencies',
                     id=np.arange(neurodata_type_args['bands'].shape[0]))
                 freqs.add_column('freq',
                                  'frequency value',
                                  data=neurodata_type_args['bands'])
                 # Replace the raw band values with the table built above.
                 neurodata_type_args.update(dict(bands=freqs))
                 # Append a new last axis, then swap the last two axes: for
                 # 2-D input of shape (a, b) the result has shape (a, 1, b).
                 # assumes the incoming data is 2-D -- TODO confirm
                 temp = neurodata_type_args['data'][:, :, np.newaxis]
                 neurodata_type_args['data'] = np.moveaxis(
                     temp, [0, 1, 2], [0, 2, 1])
                 # Timestamps are removed from the arguments and replaced by
                 # starting_time + rate below.
                 ts = neurodata_type_args.pop('timestamps')
                 # The first timestamp may itself be an array (e.g. a 2-D
                 # timestamps array); in that case take its first element.
                 starting_time = ts[0][0] if isinstance(
                     ts[0], np.ndarray) else ts[0]
                 # rate is the reciprocal of the mean difference between
                 # consecutive (squeezed) timestamps.
                 neurodata_type_args.update(
                     dict(starting_time=np.float64(starting_time),
                          rate=1 / np.mean(np.diff(ts.squeeze())),
                          unit='sec'))
                 self.nwbfile.add_acquisition(
                     DecompositionSeries(**neurodata_type_args))
             elif neurodata_type_name == 'ElectricalSeries':
                 # The electrode table must exist before series referencing
                 # self.probe_dt_region can be created.
                 if not self.electrode_table_exist:
                     self.create_electrode_table_ecephys()
                 if neurodata_type_args['name'] in ['raw.lf', 'raw.ap']:
                     for probe_no in range(self.no_probes):
                         # More data columns than electrode-table rows: select
                         # columns using the 'channels.rawInd' dataset.
                         if neurodata_type_args['data'][probe_no].shape[
                                 1] > self._one_data.data_attrs_dump[
                                     'electrode_table_length'][probe_no]:
                             if 'channels.rawInd' in self._one_data.loaded_datasets:
                                 channel_idx = self._one_data.loaded_datasets[
                                     'channels.rawInd'][
                                         probe_no].data.astype('int')
                             else:
                                 # NOTE(review): break leaves the probe loop, so
                                 # the remaining probes are skipped as well.
                                 warnings.warn(
                                     'could not find channels.rawInd')
                                 break
                         else:
                             # Column count fits the table: keep all channels.
                             channel_idx = slice(None)
                         # One series per probe, named
                         # '<series name>_<probe name>'. The data is wrapped in
                         # DataChunkIterator + H5DataIO so it is iterated in
                         # buffers and written compressed.
                         self.nwbfile.add_acquisition(
                             ElectricalSeries(
                                 name=neurodata_type_args['name'] + '_' +
                                 self.nwb_metadata['Probes'][probe_no]
                                 ['name'],
                                 starting_time=np.abs(
                                     np.round(
                                         neurodata_type_args['timestamps']
                                         [probe_no][0, 1], 2)
                                 ),  # round starting times of the order of 1e-5
                                 rate=neurodata_type_args['data']
                                 [probe_no].fs,
                                 data=H5DataIO(
                                     DataChunkIterator(
                                         _iter_datasetview(
                                             neurodata_type_args['data']
                                             [probe_no],
                                             channel_ids=channel_idx),
                                         buffer_size=self.buffer_size),
                                     compression=True,
                                     shuffle=self.shuffle,
                                     compression_opts=self.complevel),
                                 electrodes=self.probe_dt_region[probe_no],
                                 channel_conversion=neurodata_type_args[
                                     'data']
                                 [probe_no].channel_conversion_sample2v[
                                     neurodata_type_args['data']
                                     [probe_no].type][channel_idx]))
                 elif neurodata_type_args['name'] in ['raw.nidq']:
                     # nidq arguments are forwarded to ElectricalSeries as-is.
                     self.nwbfile.add_acquisition(
                         ElectricalSeries(**neurodata_type_args))
Exemple #15
0
def chang2nwb(blockpath,
              outpath=None,
              session_start_time=None,
              session_description=None,
              identifier=None,
              anin4=False,
              ecog_format='auto',
              external_subject=True,
              include_pitch=False,
              include_intensity=False,
              speakers=True,
              mic=False,
              mini=False,
              hilb=False,
              verbose=False,
              imaging_path=None,
              parse_transcript=False,
              include_cortical_surfaces=True,
              include_electrodes=True,
              include_ekg=True,
              subject_image_list=None,
              rest_period=None,
              load_warped=False,
              **kwargs):
    """
    Convert one recording block into an NWB file and write it to disk.

    Besides its arguments, the function reads several names that are not
    defined inside it (``IMAGING_PATH``, ``tdt_data_path``, ``hilb_dir``,
    ``transcript_path``, ``manager``); they are presumably module-level
    globals -- verify against the rest of the module.

    Parameters
    ----------
    blockpath: str
        Path of the block directory. Its last component is the block name.
    outpath: None | str
        if None, output = [blockpath].nwb
    session_start_time: datetime.datetime
        default: datetime(1900, 1, 1)
    session_description: str
        default: blockname
    identifier: str
        default: blockname
    anin4: False | str
        Whether or not to convert ANIN4. ANIN4 is used as an extra channel for
        things like button presses, and is usually unused. If a string is
        supplied, that is used as the name of the timeseries.
    ecog_format: str
        'auto' (default), 'htk', 'mat' or 'raw'
    external_subject: bool (optional)
        True: (default) the subject is saved in an external file
            ([subject_id].nwb next to the output file) and read back from
            there. This is useful if you have multiple sessions for a single
            subject.
        False: the subject is saved normally
    include_pitch: bool (optional)
        add pitch data. Default: False
    include_intensity: bool (optional)
        add intensity data. Default: False
    speakers: bool (optional)
        add the speaker (audio stimulus) channels. Default: True
    mic: bool (optional)
        add the room microphone channel. Default: False
    mini: bool (optional)
        only save data stub (first 2000 samples). Used for testing
    hilb: bool
        include Hilbert Transform data. Default: False
    verbose: bool (optional)
    imaging_path: str (optional)
        None: use IMAGING_PATH/[subject_id]
        'local': use [parent directory of blockpath]/imaging
        else: use [supplied string]/[subject_id]
    parse_transcript: False | str (optional)
        One of 'CV', 'singing', 'emphasis' or 'MOCHA'. Any other truthy
        value adds no trials.
    include_cortical_surfaces: bool (optional)
    include_electrodes: bool (optional)
        False: a placeholder table of 256 electrodes with NaN positions is
        created instead of reading the electrode metadata file.
    include_ekg: bool (optional)
    subject_image_list: list (optional)
        List of paths of images to include
    rest_period: None | array-like
        (start, stop) of a rest period, stored as an epoch
    load_warped: bool (optional)
        forwarded to add_electrodes
    kwargs: dict
        passed to pynwb.NWBFile

    Returns
    -------
    None. The NWB file is written to ``outpath`` and read back once as a
    sanity check.

    Raises
    ------
    ValueError
        If ``ecog_format`` is not one of the recognized values.
    """

    behav_module = None

    basepath, blockname = os.path.split(blockpath)
    subject_id = get_subject_id(blockname)
    if identifier is None:
        identifier = blockname

    if session_description is None:
        session_description = blockname

    if outpath is None:
        outpath = blockpath + '.nwb'
    out_base_path = os.path.split(outpath)[0]

    if session_start_time is None:
        session_start_time = datetime(1900, 1, 1).astimezone(timezone('UTC'))

    if imaging_path is None:
        subj_imaging_path = path.join(IMAGING_PATH, subject_id)
    elif imaging_path == 'local':
        subj_imaging_path = path.join(basepath, 'imaging')
    else:
        subj_imaging_path = os.path.join(imaging_path, subject_id)

    # file paths
    bad_time_file = path.join(blockpath, 'Artifacts', 'badTimeSegments.mat')
    ecog_path = path.join(blockpath, 'RawHTK')
    ecog400_path = path.join(blockpath, 'ecog400', 'ecog.mat')
    elec_metadata_file = path.join(subj_imaging_path, 'elecs',
                                   'TDT_elecs_all.mat')
    mesh_path = path.join(subj_imaging_path, 'Meshes')
    pial_files = glob.glob(path.join(mesh_path, '*pial.mat'))

    # Create the NWB file object
    nwbfile = NWBFile(session_description,
                      identifier,
                      session_start_time,
                      datetime.now().astimezone(),
                      session_id=identifier,
                      institution='University of California, San Francisco',
                      lab='Chang Lab',
                      **kwargs)

    nwbfile.add_electrode_column('bad', 'electrode identified as too noisy')

    bad_elecs_inds = get_bad_elecs(blockpath)

    if include_electrodes:
        add_electrodes(nwbfile,
                       elec_metadata_file,
                       bad_elecs_inds,
                       load_warped=load_warped)
    else:
        # No electrode metadata requested: register a placeholder 256-channel
        # grid with unknown (NaN) positions so the data can still be linked.
        device = nwbfile.create_device('256Grid')
        electrode_group = nwbfile.create_electrode_group(
            name='256Grid electrodes',
            description='auto_group',
            location='location',
            device=device)

        for elec_counter in range(256):
            bad = elec_counter in bad_elecs_inds
            nwbfile.add_electrode(id=elec_counter + 1,
                                  x=np.nan,
                                  y=np.nan,
                                  z=np.nan,
                                  imp=np.nan,
                                  location=' ',
                                  filtering='none',
                                  group=electrode_group,
                                  bad=bad)
    ecog_elecs = list(range(len(nwbfile.electrodes)))
    ecog_elecs_region = nwbfile.create_electrode_table_region(
        ecog_elecs, 'ECoG electrodes on brain')

    # Read electrophysiology data from HTK files and add them to NWB file
    if ecog_format == 'auto':
        # Forward the caller's verbosity instead of always silencing it.
        ecog_rate, data, ecog_path = auto_ecog(blockpath,
                                               ecog_elecs,
                                               verbose=verbose)
    elif ecog_format == 'htk':
        if verbose:
            print('reading htk acquisition...', flush=True)
        ecog_rate, data = readhtks(ecog_path, ecog_elecs)
        data = data.squeeze()
        if verbose:
            print('done', flush=True)

    elif ecog_format == 'mat':
        with File(ecog400_path, 'r') as f:
            data = f['ecogDS']['data'][:, ecog_elecs]
            ecog_rate = f['ecogDS']['sampFreq'][:].ravel()[0]
        ecog_path = ecog400_path

    elif ecog_format == 'raw':
        ecog_path = os.path.join(tdt_data_path, subject_id, blockname,
                                 'raw.mat')
        ecog_rate, data = load_wavs(ecog_path)

    else:
        raise ValueError('unrecognized argument: ecog_format')

    ts_desc = "all Wav data"

    if mini:
        data = data[:2000]

    ecog_ts = ElectricalSeries(name='ElectricalSeries',
                               data=H5DataIO(data, compression='gzip'),
                               electrodes=ecog_elecs_region,
                               rate=ecog_rate,
                               description=ts_desc,
                               conversion=0.001)
    nwbfile.add_acquisition(ecog_ts)

    if include_ekg:
        ekg_elecs = find_ekg_elecs(elec_metadata_file)
        if len(ekg_elecs):
            add_ekg(nwbfile, ecog_path, ekg_elecs)

    if mic:
        # Add microphone recording from room
        fs, data = get_analog(blockpath, 1)
        nwbfile.add_acquisition(
            TimeSeries('microphone',
                       data,
                       'audio unit',
                       rate=fs,
                       description="audio recording from microphone in room"))
    if speakers:
        fs, data = get_analog(blockpath, 2)
        # Add audio stimulus 1
        nwbfile.add_stimulus(
            TimeSeries('speaker 1',
                       data,
                       'NA',
                       rate=fs,
                       description="audio stimulus 1"))

        # Add audio stimulus 2
        fs, data = get_analog(blockpath, 3)
        if fs is not None:
            nwbfile.add_stimulus(
                TimeSeries('speaker 2',
                           data,
                           'NA',
                           rate=fs,
                           description='the second stimulus source'))

    if anin4:
        fs, data = get_analog(blockpath, 4)
        nwbfile.add_acquisition(
            TimeSeries(anin4,
                       data,
                       'aux unit',
                       rate=fs,
                       description="aux analog recording"))

    # Add bad time segments
    if os.path.exists(bad_time_file) and os.stat(bad_time_file).st_size:
        bad_time = sio.loadmat(bad_time_file)['badTimeSegments']
        for row in bad_time:
            nwbfile.add_invalid_time_interval(start_time=row[0],
                                              stop_time=row[1],
                                              tags=('ECoG artifact', ),
                                              timeseries=ecog_ts)

    if rest_period is not None:
        nwbfile.add_epoch_column(name='label', description='label')
        nwbfile.add_epoch(start_time=rest_period[0],
                          stop_time=rest_period[1],
                          label='rest_period')

    # The Hilbert HDF5 file and the external subject reader feed data into
    # nwbfile lazily, so both must stay open until the file has been written.
    # They are closed in the ``finally`` below even when a later step raises.
    hilb_file = None
    subj_read_io = None
    try:
        if hilb:
            block_hilb_path = os.path.join(hilb_dir, subject_id, blockname,
                                           blockname + '_AA.h5')
            hilb_file = File(block_hilb_path, 'r')

            data = transpose_iter(
                hilb_file['X'])  # transposes data during iterative write
            filter_center = hilb_file['filter_center'][:]
            filter_sigma = hilb_file['filter_sigma'][:]

            data = H5DataIO(DataChunkIterator(tqdm(
                data, desc='writing hilbert data'),
                                              buffer_size=400 * 20),
                            compression='gzip')

            decomp_series = DecompositionSeries(
                name='LFPDecompositionSeries',
                description='Gaussian band Hilbert transform',
                data=data,
                rate=400.,
                source_timeseries=ecog_ts,
                metric='amplitude')

            for band_mean, band_stdev in zip(filter_center, filter_sigma):
                decomp_series.add_band(band_mean=band_mean,
                                       band_stdev=band_stdev)

            hilb_mod = nwbfile.create_processing_module(
                name='ecephys', description='holds hilbert analysis results')
            hilb_mod.add_container(decomp_series)

        if include_cortical_surfaces:
            subject = ECoGSubject(subject_id=subject_id)
            subject.cortical_surfaces = create_cortical_surfaces(
                pial_files, subject_id)
        else:
            subject = Subject(subject_id=subject_id, species='Homo sapiens')

        if subject_image_list is not None:
            subject = add_images_to_subject(subject, subject_image_list)

        if external_subject:
            # Write the subject once to its own file, then always read it
            # back so that nwbfile.subject refers to the external copy.
            subj_fpath = path.join(out_base_path, subject_id + '.nwb')
            if not os.path.isfile(subj_fpath):
                subj_nwbfile = NWBFile(session_description=subject_id,
                                       identifier=subject_id,
                                       subject=subject,
                                       session_start_time=datetime(
                                           1900, 1,
                                           1).astimezone(timezone('UTC')))
                with NWBHDF5IO(subj_fpath, manager=manager,
                               mode='w') as subj_io:
                    subj_io.write(subj_nwbfile)
            subj_read_io = NWBHDF5IO(subj_fpath, manager=manager, mode='r')
            subj_nwbfile = subj_read_io.read()
            subject = subj_nwbfile.subject

        nwbfile.subject = subject

        if parse_transcript:
            if parse_transcript == 'CV':
                parseout = parse(blockpath, blockname)
                df = make_df(parseout, 0, subject_id, align_pos=1)
                nwbfile.add_trial_column('cv_transition_time',
                                         'time of CV transition in seconds')
                nwbfile.add_trial_column(
                    'speak',
                    'if True, subject is speaking. If False, subject is listening'
                )
                nwbfile.add_trial_column('condition', 'syllable spoken')
                for _, row in df.iterrows():
                    nwbfile.add_trial(start_time=row['start'],
                                      stop_time=row['stop'],
                                      cv_transition_time=row['align'],
                                      speak=row['mode'] == 'speak',
                                      condition=row['label'])
            elif parse_transcript == 'singing':
                parseout = parse(blockpath, blockname)
                df = make_df(parseout, 0, subject_id, align_pos=0)
                if not len(df):
                    df = pd.DataFrame(parseout)
                    df['mode'] = 'speak'

                # handle empty labels
                df = df.loc[df['label'].astype('bool'), :]
                nwbfile.add_trial_column(
                    'speak',
                    'if True, subject is speaking. If False, subject is listening'
                )
                nwbfile.add_trial_column('condition', 'syllable spoken')
                for _, row in df.iterrows():
                    nwbfile.add_trial(start_time=row['start'],
                                      stop_time=row['stop'],
                                      speak=row['mode'] == 'speak',
                                      condition=row['label'])
            elif parse_transcript == 'emphasis':
                parseout = parse(blockpath, blockname)
                try:
                    df = make_df(parseout, 0, subject_id, align_pos=0)
                except Exception:
                    # Best-effort fallback to the raw parse output; unlike a
                    # bare ``except`` this lets KeyboardInterrupt/SystemExit
                    # propagate.
                    df = pd.DataFrame(parseout)
                if not len(df):
                    df = pd.DataFrame(parseout)
                # handle empty labels
                df = df.loc[df['label'].astype('bool'), :]
                nwbfile.add_trial_column('condition', 'word emphasized')
                nwbfile.add_trial_column(
                    'speak',
                    'if True, subject is speaking. If False, subject is listening'
                )
                for _, row in df.iterrows():
                    nwbfile.add_trial(start_time=row['start'],
                                      stop_time=row['stop'],
                                      speak=True,
                                      condition=row['label'])
            elif parse_transcript == 'MOCHA':
                nwbfile = create_transcription(nwbfile, transcript_path,
                                               blockname)

        # behavior
        if include_pitch:
            if behav_module is None:
                behav_module = nwbfile.create_processing_module(
                    'behavior', 'processing about behavior')
            if os.path.isfile(
                    os.path.join(blockpath, 'pitch_' + blockname + '.mat')):
                fs, data = load_pitch(blockpath)
                pitch_ts = TimeSeries(
                    data=data,
                    rate=fs,
                    unit='Hz',
                    name='pitch',
                    description=
                    'Pitch as extracted from Praat. NaNs mark unvoiced regions.'
                )
                behav_module.add_container(
                    BehavioralTimeSeries(name='pitch', time_series=pitch_ts))
            else:
                print('No pitch file for ' + blockname)

        if include_intensity:
            if behav_module is None:
                behav_module = nwbfile.create_processing_module(
                    'behavior', 'processing about behavior')
            if os.path.isfile(
                    os.path.join(blockpath,
                                 'intensity_' + blockname + '.mat')):
                # NOTE(review): the intensity branch loads its data with
                # load_pitch; confirm whether a dedicated intensity loader
                # was intended here.
                fs, data = load_pitch(blockpath)
                intensity_ts = TimeSeries(
                    data=data,
                    rate=fs,
                    unit='dB',
                    name='intensity',
                    description=
                    'Intensity of speech in dB extracted from Praat.')
                behav_module.add_container(
                    BehavioralTimeSeries(name='intensity',
                                         time_series=intensity_ts))
            else:
                print('No intensity file for ' + blockname)

        # Export the NWB file
        with NWBHDF5IO(outpath, manager=manager, mode='w') as io:
            io.write(nwbfile)
    finally:
        if subj_read_io is not None:
            subj_read_io.close()
        if hilb_file is not None:
            hilb_file.close()

    # read check
    with NWBHDF5IO(outpath, manager=manager, mode='r') as io:
        io.read()
def transform(block_path, filter='default', bands_vals=None):
    """
    Takes raw LFP data and does the standard Hilbert algorithm:
    1) CAR
    2) notch filters
    3) Hilbert transform on different bands

    Takes about 20 minutes to run on 1 10-min block.

    Parameters
    ----------
    block_path : str
        subject file path
    filter: str, optional
        Frequency bands to filter the signal.
        'default' for Chang lab default values (Gaussian filters)
        'high_gamma' for the Chang lab default filters whose center
            frequency lies strictly between 70 and 150
        'custom' for user defined (Gaussian filters)
    bands_vals: 2D array, necessary only if filter='custom'
        [2,nBands] numpy array with Gaussian filter parameters, where:
        bands_vals[0,:] = filter centers [Hz]
        bands_vals[1,:] = filter sigmas [Hz]

    Returns
    -------
    Saves preprocessed signals (LFP) and spectral power (DecompositionSeries) in
    the current NWB file. Only if containers for these data do not exist in the
    file.

    Raises
    ------
    ValueError
        If ``filter`` is not one of the recognized names, or if
        ``filter='custom'`` is requested without ``bands_vals``.
    """
    # Define filter parameters. This is validated before the file is opened
    # so a bad argument fails immediately instead of after preprocessing.
    if filter == 'default':
        band_param_0 = bands.chang_lab['cfs']
        band_param_1 = bands.chang_lab['sds']
    elif filter == 'high_gamma':
        center_freqs = bands.chang_lab['cfs']
        in_band = (center_freqs > 70) & (center_freqs < 150)
        band_param_0 = center_freqs[in_band]
        band_param_1 = bands.chang_lab['sds'][in_band]
    elif filter == 'custom':
        if bands_vals is None:
            raise ValueError("bands_vals is required when filter='custom'")
        bands_vals = np.asarray(bands_vals)
        band_param_0 = bands_vals[0, :]
        band_param_1 = bands_vals[1, :]
    else:
        raise ValueError('unrecognized argument: filter={!r}'.format(filter))

    # Target sampling rate of the preprocessed signal; replaced by the stored
    # rate when an LFP container already exists in the file.
    rate = 400.

    block_name = os.path.splitext(block_path)[0]

    start = time.time()

    with NWBHDF5IO(block_path, 'a') as io:
        nwb = io.read()

        # Storage of processed signals on NWB file -----------------------------
        if 'ecephys' not in nwb.modules:
            # Add module to NWB file
            nwb.create_processing_module(
                name='ecephys',
                description='Extracellular electrophysiology data.')
        ecephys_module = nwb.modules['ecephys']

        # LFP: Downsampled and power line signal removed
        if 'LFP' in nwb.modules['ecephys'].data_interfaces:
            # Reuse the preprocessed signal already stored in the file.
            lfp_ts = nwb.modules['ecephys'].data_interfaces[
                'LFP'].electrical_series['preprocessed']
            X = lfp_ts.data[:].T
            rate = lfp_ts.rate
        else:

            # 1e6 scaling helps with numerical accuracy
            X = nwb.acquisition['ECoG'].data[:].T * 1e6
            fs = nwb.acquisition['ECoG'].rate
            bad_elects = load_bad_electrodes(nwb)
            print('Load time for h5 {}: {} seconds'.format(
                block_name,
                time.time() - start))
            print('rates {}: {} {}'.format(block_name, rate, fs))
            if not np.allclose(rate, fs):
                assert rate < fs
                start = time.time()
                X = resample(X, rate, fs)
                print('resample time for {}: {} seconds'.format(
                    block_name,
                    time.time() - start))

            if bad_elects.sum() > 0:
                X[bad_elects] = np.nan

            # Subtract CAR
            start = time.time()
            X = subtract_CAR(X)
            print('CAR subtract time for {}: {} seconds'.format(
                block_name,
                time.time() - start))

            # Apply Notch filters
            start = time.time()
            X = linenoise_notch(X, rate)
            print('Notch filter time for {}: {} seconds'.format(
                block_name,
                time.time() - start))

            lfp = LFP()
            # Add preprocessed downsampled signals as an electrical_series
            lfp_ts = lfp.create_electrical_series(
                name='preprocessed',
                data=X.T,
                electrodes=nwb.acquisition['ECoG'].electrodes,
                rate=rate,
                description='')
            ecephys_module.add_data_interface(lfp)

        # Spectral band power
        if 'Bandpower_' + filter not in nwb.modules['ecephys'].data_interfaces:

            # Apply Hilbert transform
            X = X.astype('float32')  # signal (nChannels,nSamples)
            nChannels = X.shape[0]
            nSamples = X.shape[1]
            nBands = len(band_param_0)
            Xp = np.zeros((nBands, nChannels,
                           nSamples))  # power (nBands,nChannels,nSamples)
            # hilbert_transform returns X_fft_h, which is fed back in on the
            # next iteration (presumably so the FFT of X is computed once --
            # verify against hilbert_transform).
            X_fft_h = None
            for ii, (bp0, bp1) in enumerate(zip(band_param_0, band_param_1)):
                kernel = gaussian(X, rate, bp0, bp1)
                X_analytic, X_fft_h = hilbert_transform(X,
                                                        rate,
                                                        kernel,
                                                        phase=None,
                                                        X_fft_h=X_fft_h)
                Xp[ii] = abs(X_analytic).astype('float32')

            # Scales signals back to Volt
            # NOTE(review): X is the float32 copy made by astype() above, so
            # this rescaling does not reach the array handed to the LFP
            # series, and X is not used afterwards -- confirm the intent.
            X /= 1e6

            band_param_0V = VectorData(
                name='filter_param_0',
                description='frequencies for bandpass filters',
                data=band_param_0)
            band_param_1V = VectorData(
                name='filter_param_1',
                description='frequencies for bandpass filters',
                data=band_param_1)
            bandsTable = DynamicTable(
                name='bands',
                description='Series of filters used for Hilbert transform.',
                columns=[band_param_0V, band_param_1V],
                colnames=['filter_param_0', 'filter_param_1'])

            # data: (ndarray) dims: num_times * num_channels * num_bands
            Xp = np.swapaxes(Xp, 0, 2)
            decs = DecompositionSeries(
                name='Bandpower_' + filter,
                data=Xp,
                description='Band power estimated with Hilbert transform.',
                metric='power',
                unit='V**2/Hz',
                bands=bandsTable,
                rate=rate,
                source_timeseries=lfp_ts)
            ecephys_module.add_data_interface(decs)
        io.write(nwb)

        print('done', flush=True)
Exemple #17
0
def copy_obj(obj_old, nwb_old, nwb_new):
    """Create a copy of ``obj_old`` that can be added to ``nwb_new``.

    The object is dispatched on its exact type. Supported are
    ElectricalSeries, DynamicTable, LFP, TimeSeries, DecompositionSeries,
    Spectrum and objects whose ``neurodata_type`` is ``'SurveyTable'``.

    Parameters
    ----------
    obj_old :
        Object to copy, read from the old file.
    nwb_old :
        The file ``obj_old`` belongs to. Not used by this function; kept in
        the signature for existing callers.
    nwb_new :
        The file the copy is meant for. Its electrodes table (or bipolar
        scheme table / 'ecephys' data interfaces) is what copied electrode
        regions refer to.

    Returns
    -------
    The new object, or None when ``obj_old`` is of no supported kind. For
    survey tables the rows are appended to the corresponding
    ``*_survey_table`` object (not defined in this function, presumably
    module-level) and that object is returned.
    """
    # Exact type comparisons are kept on purpose: in pynwb, ElectricalSeries
    # and DecompositionSeries subclass TimeSeries, so isinstance() would
    # change which branch handles them.

    # ElectricalSeries --------------------------------------------------------
    if type(obj_old) is ElectricalSeries:
        # If reference electrodes table is bipolar scheme
        if isinstance(obj_old.electrodes.table, BipolarSchemeTable):
            bst_old = obj_old.electrodes.table
            bst_old_df = bst_old.to_dataframe()
            bst_new = nwb_new.lab_meta_data['ecephys_ext'].bipolar_scheme_table

            # Re-create each bipolar pair from the electrode indices of the
            # old rows, then point both regions at the new electrodes table.
            for _, row in bst_old_df.iterrows():
                bst_new.add_row(anodes=row['anodes'].index.tolist(),
                                cathodes=row['cathodes'].index.tolist())
            bst_new.anodes.table = nwb_new.electrodes
            bst_new.cathodes.table = nwb_new.electrodes

            # if there are custom columns
            default_cols = ('anodes', 'cathodes')
            custom_cols = [
                col for col in bst_old_df.columns if col not in default_cols
            ]
            for col in custom_cols:
                bst_new.add_column(name=str(col),
                                   description=str(bst_old[col].description),
                                   data=list(bst_old[col].data[:]))

            elecs_region = DynamicTableRegion(name='electrodes',
                                              data=bst_old_df.index.tolist(),
                                              description='desc',
                                              table=bst_new)
        else:
            # Translate the old region's row positions into electrode ids.
            region = np.array(obj_old.electrodes.table.id[:])[
                obj_old.electrodes.data[:]].tolist()
            elecs_region = nwb_new.create_electrode_table_region(
                name='electrodes', region=region, description='')
        return ElectricalSeries(name=obj_old.name,
                                data=obj_old.data[:],
                                electrodes=elecs_region,
                                rate=obj_old.rate,
                                description=obj_old.description)

    # DynamicTable ------------------------------------------------------------
    if type(obj_old) is DynamicTable:
        return DynamicTable(
            name=obj_old.name,
            description=obj_old.description,
            colnames=obj_old.colnames,
            columns=obj_old.columns,
        )

    # LFP ---------------------------------------------------------------------
    if type(obj_old) is LFP:
        obj = LFP(name=obj_old.name)
        assert len(obj_old.electrical_series) == 1, (
            'Expected precisely one electrical series, got %i!' %
            len(obj_old.electrical_series))
        els = list(obj_old.electrical_series.values())[0]

        # first check for a table among the new file's data_interfaces
        ecephys_interfaces = nwb_new.processing['ecephys'].data_interfaces
        if els.electrodes.table.name in ecephys_interfaces:
            LFP_dynamic_table = ecephys_interfaces[els.electrodes.table.name]
        else:
            # otherwise use the electrodes as the table
            LFP_dynamic_table = nwb_new.electrodes

        region = np.array(
            els.electrodes.table.id[:])[els.electrodes.data[:]].tolist()
        elecs_region = LFP_dynamic_table.create_region(
            name='electrodes',
            region=region,
            description=els.electrodes.description)

        obj.create_electrical_series(name=els.name,
                                     comments=els.comments,
                                     conversion=els.conversion,
                                     data=els.data[:],
                                     description=els.description,
                                     electrodes=elecs_region,
                                     rate=els.rate,
                                     resolution=els.resolution,
                                     starting_time=els.starting_time)

        return obj

    # TimeSeries --------------------------------------------------------------
    if type(obj_old) is TimeSeries:
        return TimeSeries(name=obj_old.name,
                          description=obj_old.description,
                          data=obj_old.data[:],
                          rate=obj_old.rate,
                          resolution=obj_old.resolution,
                          conversion=obj_old.conversion,
                          starting_time=obj_old.starting_time,
                          unit=obj_old.unit)

    # DecompositionSeries -----------------------------------------------------
    if type(obj_old) is DecompositionSeries:
        # Rebuild every band column from its data so the new table does not
        # share column objects with the old one.
        list_columns = [
            VectorData(name=item.name,
                       description=item.description,
                       data=item.data[:]) for item in obj_old.bands.columns
        ]
        bandsTable = DynamicTable(name=obj_old.bands.name,
                                  description=obj_old.bands.description,
                                  columns=list_columns,
                                  colnames=obj_old.bands.colnames)
        # source_timeseries is not carried over to the copy.
        return DecompositionSeries(
            name=obj_old.name,
            data=obj_old.data[:],
            description=obj_old.description,
            metric=obj_old.metric,
            unit=obj_old.unit,
            rate=obj_old.rate,
            bands=bandsTable,
        )

    # Spectrum ----------------------------------------------------------------
    if type(obj_old) is Spectrum:
        # The copy references every row of the new file's electrodes table.
        file_elecs = nwb_new.electrodes
        nChannels = len(file_elecs['x'].data[:])
        elecs_region = file_elecs.create_region(
            name='electrodes',
            region=np.arange(nChannels).tolist(),
            description='')
        return Spectrum(name=obj_old.name,
                        frequencies=obj_old.frequencies[:],
                        power=obj_old.power,
                        electrodes=elecs_region)

    # Survey tables ------------------------------------------------------------
    if obj_old.neurodata_type == 'SurveyTable':
        # Pick the destination table and the column whose length gives the
        # row count. The table names are resolved only in the matching
        # branch, as before.
        if obj_old.name == 'nrs_survey_table':
            target_table = nrs_survey_table
            length_col = 'nrs_pain_intensity_rating'
        elif obj_old.name == 'vas_survey_table':
            target_table = vas_survey_table
            length_col = 'vas_pain_intensity_rating'
        elif obj_old.name == 'mpq_survey_table':
            target_table = mpq_survey_table
            length_col = 'throbbing'
        else:
            return None

        for i in range(len(obj_old[length_col].data)):
            target_table.add_row(
                **{c: obj_old[c][i]
                   for c in obj_old.colnames})
        return target_table

    # Unsupported object kind.
    return None
Exemple #18
0
def yuta2nwb(session_path='/Users/bendichter/Desktop/Buzsaki/SenzaiBuzsaki2017/YutaMouse41/YutaMouse41-150903',
             subject_xls=None, include_spike_waveforms=True, stub=True):
    """Convert one YutaMouse session directory into an NWB file, then read it back.

    The session id is the last path component of ``session_path`` and the
    subject directory is its parent. The output file is written next to the
    session directory as ``session_path + '.nwb'`` (``'_stub.nwb'`` when
    ``stub`` is true).

    Relies on the module-level names ``max_shanks``, ``special_electrode_dict``,
    ``task_types`` and ``celltype_dict``.

    Parameters
    ----------
    session_path : str
        Path to the session directory. The session id is split on ``'-'``
        (or on ``'b'`` when no dash is present) into subject id and date text.
    subject_xls : str, optional
        Path to the subject's experiment spreadsheet. When ``None`` the file
        ``'YM<nn> exp_sheet.xlsx'`` inside the subject directory is used; when
        the value does not end in ``'xlsx'`` it is treated as the directory
        containing that file.
    include_spike_waveforms : bool
        When true, ``ns.write_spike_waveforms`` is called for shanks 1 to 8.
    stub : bool
        Forwarded to ``ns.read_lfp`` and ``ns.write_spike_waveforms``; also
        selects the ``'_stub.nwb'`` output suffix.
    """
    subject_path, session_id = os.path.split(session_path)
    fpath_base = os.path.split(subject_path)[0]
    identifier = session_id
    mouse_number = session_id[9:11]
    if '-' in session_id:
        subject_id, date_text = session_id.split('-')
        b = False
    else:
        subject_id, date_text = session_id.split('b')
        b = True

    sheet_name = 'YM' + mouse_number + ' exp_sheet.xlsx'
    if subject_xls is None:
        subject_xls = os.path.join(subject_path, sheet_name)
    elif not subject_xls.endswith('xlsx'):
        # A directory was given rather than the spreadsheet itself.
        subject_xls = os.path.join(subject_xls, sheet_name)

    session_start_time = dateparse(date_text, yearfirst=True)

    df = pd.read_excel(subject_xls)

    # First column holds the field names, second column the values.
    names = df.iloc[:, 0]
    subject_data = {}
    for key in ['genotype', 'DOB', 'implantation', 'Probe', 'Surgery', 'virus injection', 'mouseID']:
        if key in names.values:
            subject_data[key] = df.iloc[np.argmax(names == key), 1]

    # Only keys found in the sheet are present in subject_data, so look the
    # optional ones up with .get() instead of risking a KeyError.
    dob = subject_data.get('DOB')
    if isinstance(dob, datetime):
        age = session_start_time - dob
    else:
        age = None

    # NOTE(review): age is stringified even when unknown, which stores the
    # literal text 'None' -- confirm whether that is intended.
    subject = Subject(subject_id=subject_id, age=str(age),
                      genotype=subject_data.get('genotype'),
                      species='mouse')

    nwbfile = NWBFile(session_description='mouse in open exploration and theta maze',
                      identifier=identifier,
                      session_start_time=session_start_time.astimezone(),
                      file_create_date=datetime.now().astimezone(),
                      experimenter='Yuta Senzai',
                      session_id=session_id,
                      institution='NYU',
                      lab='Buzsaki',
                      subject=subject,
                      related_publications='DOI:10.1016/j.neuron.2016.12.011')

    print('reading and writing raw position data...', end='', flush=True)
    ns.add_position_data(nwbfile, session_path)

    shank_channels = ns.get_shank_channels(session_path)[:8]
    all_shank_channels = np.concatenate(shank_channels)

    print('setting up electrodes...', end='', flush=True)
    hilus_csv_path = os.path.join(fpath_base, 'early_session_hilus_chans.csv')
    lfp_channel = get_reference_elec(subject_xls, hilus_csv_path, session_start_time, session_id, b=b)
    print(lfp_channel)
    # Boolean column flagging the electrode used as the LFP reference.
    custom_column = [{'name': 'theta_reference',
                      'description': 'this electrode was used to calculate LFP canonical bands',
                      'data': all_shank_channels == lfp_channel}]
    ns.write_electrode_table(nwbfile, session_path, custom_columns=custom_column, max_shanks=max_shanks)

    print('reading LFPs...', end='', flush=True)
    lfp_fs, all_channels_data = ns.read_lfp(session_path, stub=stub)

    # Columns are channels: keep only those that belong to a shank.
    lfp_data = all_channels_data[:, all_shank_channels]
    print('writing LFPs...', flush=True)
    lfp_ts = ns.write_lfp(nwbfile, lfp_data, lfp_fs, name='lfp',
                          description='lfp signal for all shank electrodes')

    for name, channel in special_electrode_dict.items():
        # Select the channel along axis 1, consistent with the shank-channel
        # selection above; indexing axis 0 would return a single time sample.
        ts = TimeSeries(name=name, description='environmental electrode recorded inline with neural data',
                        data=all_channels_data[:, channel], rate=lfp_fs, unit='V', conversion=np.nan,
                        resolution=np.nan)
        nwbfile.add_acquisition(ts)

    # compute filtered LFP
    print('filtering LFP...', end='', flush=True)
    all_lfp_phases = []
    for passband in ('theta', 'gamma'):
        lfp_fft = filter_lfp(lfp_data[:, all_shank_channels == lfp_channel].ravel(), lfp_fs, passband=passband)
        lfp_phase, _ = hilbert_lfp(lfp_fft)
        all_lfp_phases.append(lfp_phase[:, np.newaxis])
    # One slice per passband along the last axis.
    data = np.dstack(all_lfp_phases)
    print('done.', flush=True)

    if include_spike_waveforms:
        print('writing waveforms...', end='', flush=True)
        for shankn in np.arange(1, 9, dtype=int):
            ns.write_spike_waveforms(nwbfile, session_path, shankn, stub=stub)
        print('done.', flush=True)

    decomp_series = DecompositionSeries(name='LFPDecompositionSeries',
                                        description='Theta and Gamma phase for reference LFP',
                                        data=data, rate=lfp_fs,
                                        source_timeseries=lfp_ts,
                                        metric='phase', unit='radians')
    decomp_series.add_band(band_name='theta', band_limits=(4, 10))
    decomp_series.add_band(band_name='gamma', band_limits=(30, 80))

    check_module(nwbfile, 'ecephys', 'contains processed extracellular electrophysiology data').add_data_interface(decomp_series)

    for event in ns.get_events(session_path):
        nwbfile.add_stimulus(event)

    # create epochs corresponding to experiments/environments for the mouse

    sleep_state_fpath = os.path.join(session_path, '{}--StatePeriod.mat'.format(session_id))

    exist_pos_data = any(os.path.isfile(os.path.join(session_path, '{}__{}.mat'.format(session_id, task_type['name'])))
                         for task_type in task_types)

    if exist_pos_data:
        nwbfile.add_epoch_column('label', 'name of epoch')

    for task_type in task_types:
        label = task_type['name']

        pos_fpath = os.path.join(session_path, session_id + '__' + label + '.mat')
        if not os.path.isfile(pos_fpath):
            continue

        print('loading position for ' + label + '...', end='', flush=True)

        pos_obj = Position(name=label + '_position')

        matin = loadmat(pos_fpath)
        # Column 0 of 'twhl_norm' is used as the timestamps.
        tt = matin['twhl_norm'][:, 0]
        exp_times = find_discontinuities(tt)

        conversion = task_type['conversion'] if 'conversion' in task_type else np.nan

        for pos_type in ('twhl_norm', 'twhl_linearized'):
            if pos_type in matin:
                pos_data_norm = matin[pos_type][:, 1:]

                spatial_series_object = SpatialSeries(
                    name=label + '_{}_spatial_series'.format(pos_type),
                    data=H5DataIO(pos_data_norm, compression='gzip'),
                    reference_frame='unknown', conversion=conversion,
                    resolution=np.nan,
                    timestamps=H5DataIO(tt, compression='gzip'))
                pos_obj.add_spatial_series(spatial_series_object)

        check_module(nwbfile, 'behavior', 'contains processed behavioral data').add_data_interface(pos_obj)
        for i, window in enumerate(exp_times):
            nwbfile.add_epoch(start_time=window[0], stop_time=window[1],
                              label=label + '_' + str(i))
        print('done.')

    # there are occasional mismatches between the matlab struct and the neuroscope files
    # regions: 3: 'CA3', 4: 'DG'

    df_unit_features = get_UnitFeatureCell_features(fpath_base, session_id, session_path)

    celltype_names = []
    for celltype_id, region_id in zip(df_unit_features['fineCellType'].values,
                                      df_unit_features['region'].values):
        if celltype_id == 1:
            # Cell type 1 is named according to the region it was recorded in.
            if region_id == 3:
                celltype_names.append('pyramidal cell')
            elif region_id == 4:
                celltype_names.append('granule cell')
            else:
                raise Exception('unknown type')
        elif not np.isfinite(celltype_id):
            celltype_names.append('missing')
        else:
            celltype_names.append(celltype_dict[celltype_id])

    custom_unit_columns = [
        {
            'name': 'cell_type',
            'description': 'name of cell type',
            'data': celltype_names},
        {
            'name': 'global_id',
            'description': 'global id for cell for entire experiment',
            'data': df_unit_features['unitID'].values},
        {
            'name': 'max_electrode',
            'description': 'electrode that has the maximum amplitude of the waveform',
            'data': get_max_electrodes(nwbfile, session_path),
            'table': nwbfile.electrodes
        }]

    ns.add_units(nwbfile, session_path, custom_unit_columns, max_shanks=max_shanks)

    trialdata_path = os.path.join(session_path, session_id + '__EightMazeRun.mat')
    if os.path.isfile(trialdata_path):
        trials_data = loadmat(trialdata_path)['EightMazeRun']

        trialdatainfo_path = os.path.join(fpath_base, 'EightMazeRunInfo.mat')
        trialdatainfo = [x[0] for x in loadmat(trialdatainfo_path)['EightMazeRunInfo'][0]]

        features = trialdatainfo[:7]
        features[:2] = 'start_time', 'stop_time',
        for column_name in features[4:] + ['condition']:
            nwbfile.add_trial_column(column_name, 'description')

        for trial_data in trials_data:
            cond = 'run_left' if trial_data[3] else 'run_right'
            nwbfile.add_trial(start_time=trial_data[0], stop_time=trial_data[1], condition=cond,
                              error_run=trial_data[4], stim_run=trial_data[5], both_visit=trial_data[6])

    # NOTE: monosynaptic connection data ('<session_id>-MonoSynConvClick.mat',
    # fields 'FinalExcMonoSynID' / 'FinalInhMonoSynID') is not converted.

    if os.path.isfile(sleep_state_fpath):
        matin = loadmat(sleep_state_fpath)['StatePeriod']

        table = TimeIntervals(name='states', description='sleep states of animal')
        table.add_column(name='label', description='sleep state')

        # Gather the intervals of every state, then insert them ordered by
        # start time.
        state_rows = []
        for name in matin.dtype.names:
            for row in matin[name][0][0]:
                state_rows.append({'start_time': row[0], 'stop_time': row[1], 'label': name})
        for state_row in sorted(state_rows, key=lambda x: x['start_time']):
            table.add_row(**state_row)

        check_module(nwbfile, 'behavior', 'contains behavioral data').add_data_interface(table)

    if stub:
        out_fname = session_path + '_stub.nwb'
    else:
        out_fname = session_path + '.nwb'

    print('writing NWB file...', end='', flush=True)
    with NWBHDF5IO(out_fname, mode='w') as io:
        io.write(nwbfile)
    print('done.')

    print('testing read...', end='', flush=True)
    # test read
    with NWBHDF5IO(out_fname, mode='r') as io:
        io.read()
    print('done.')
Exemple #19
0
    def convert_data(self,
                     nwbfile: NWBFile,
                     metadata: dict,
                     stub_test: bool = False):
        """Add LFP, special-electrode, phase-decomposition and waveform data to ``nwbfile``.

        Reads the session folder from ``self.input_args["folder_path"]`` and
        these ``metadata`` keys: ``all_shank_channels``, ``lfp_channels``,
        ``lfp_sampling_rate``, ``spikes_nsamples``, ``shank_channels``,
        ``n_total_channels``, ``lfp``, ``lfp_decomposition`` and, optionally,
        ``special_electrodes``.

        ``stub_test`` is forwarded to ``read_lfp`` and ``write_spike_waveforms``.
        """
        folder = self.input_args["folder_path"]
        # TODO: check/enforce format?
        shank_channel_ids = metadata["all_shank_channels"]
        special_electrodes = metadata.get("special_electrodes", [])
        reference_channels = metadata["lfp_channels"]
        sampling_rate = metadata["lfp_sampling_rate"]
        n_waveform_samples = metadata["spikes_nsamples"]
        channels_per_shank = metadata["shank_channels"]
        channel_count = metadata["n_total_channels"]

        # Currently unused below.
        _subject_path, _session_id = os.path.split(folder)

        _, raw_lfp = read_lfp(folder, stub=stub_test, n_channels=channel_count)

        # Fall back to every channel when the shank selection cannot be indexed.
        try:
            shank_lfp = raw_lfp[:, shank_channel_ids]
        except IndexError:
            shank_lfp = raw_lfp

        lfp_meta = metadata["lfp"]
        lfp_series = write_lfp(nwbfile,
                               shank_lfp,
                               sampling_rate,
                               name=lfp_meta["name"],
                               description=lfp_meta["description"],
                               electrode_inds=None)

        # TODO: error checking on format?
        for electrode in special_electrodes:
            nwbfile.add_acquisition(
                TimeSeries(name=electrode["name"],
                           description=electrode["description"],
                           data=raw_lfp[:, electrode["channel"]],
                           rate=sampling_rate,
                           unit="V",
                           resolution=np.nan))

        for ref_name, ref_channel in reference_channels.items():
            try:
                reference_mask = shank_channel_ids == ref_channel
                band_phases = []
                for band in ("theta", "gamma"):
                    filtered = filter_lfp(shank_lfp[:, reference_mask].ravel(),
                                          sampling_rate,
                                          passband=band)
                    phase, _ = hilbert_lfp(filtered)
                    band_phases.append(phase[:, np.newaxis])
                # One slice per passband along the last axis.
                stacked_phases = np.dstack(band_phases)

                # TODO: should units or metrics be metadata?
                decomposition_meta = metadata["lfp_decomposition"][ref_name]
                phase_series = DecompositionSeries(
                    name=decomposition_meta["name"],
                    description=decomposition_meta["description"],
                    data=stacked_phases,
                    rate=sampling_rate,
                    source_timeseries=lfp_series,
                    metric="phase",
                    unit="radians")
                # TODO: the band limits should be extracted from parse_passband in band_analysis?
                phase_series.add_band(band_name="theta", band_limits=(4, 10))
                phase_series.add_band(band_name="gamma", band_limits=(30, 80))

                ecephys_module = check_module(
                    nwbfile, "ecephys",
                    "contains processed extracellular electrophysiology data")
                ecephys_module.add_data_interface(phase_series)
            except IndexError:
                # Best effort: a reference that cannot be indexed is skipped.
                print(
                    "Unable to index lfp data for decomposition series - skipping"
                )

        write_spike_waveforms(nwbfile,
                              folder,
                              spikes_nsamples=n_waveform_samples,
                              shank_channels=channels_per_shank,
                              stub_test=stub_test)
def store_wavelet_transform(elec_series,
                            processing,
                            npad=None,
                            filters='default',
                            X_fft_h=None,
                            abs_only=True,
                            constant_Q=True):
    """
    Wavelet-transform an electrical series and store the result in
    ``processing`` as one or two DecompositionSeries.

    Parameters
    ----------
    elec_series : time series object
        Source series. Its ``data``, ``rate``, ``name``, ``description`` and
        ``starting_time`` attributes are read.
    processing : container
        Receives every created series through ``processing.add``.
    npad : int (optional)
        Forwarded to ``wavelet_transform``. Defaults to ``int(elec_series.rate)``.
    filters : 'default', or filter(s) accepted by ``wavelet_transform`` (optional)
        Forwarded to ``wavelet_transform``. Band metadata is only attached to
        the stored series when this is the string ``'default'``.
    X_fft_h : optional
        Forwarded to ``wavelet_transform``.
    abs_only : bool
        When True only the amplitude (``abs``) series is stored; otherwise a
        phase (``np.angle``) series is stored as well.
    constant_Q : bool
        Forwarded to ``wavelet_transform``. Must be True when ``filters`` is
        ``'default'``.

    Returns
    -------
    X_wvlt : first value returned by ``wavelet_transform``
    series : list of DecompositionSeries
        The series that were added to ``processing``.

    Raises
    ------
    NotImplementedError
        If ``filters`` is ``'default'`` and ``constant_Q`` is False.
    """
    # Guard with isinstance: comparing an ndarray of filters against a string
    # does not yield a single unambiguous boolean.
    use_default_bands = isinstance(filters, str) and filters == 'default'
    if use_default_bands and not constant_Q:
        # Fail before the transform instead of after the series are built.
        raise NotImplementedError

    X = elec_series.data[:]
    rate = elec_series.rate
    if npad is None:
        npad = int(rate)
    X_wvlt, _ = wavelet_transform(X,
                                  rate,
                                  filters=filters,
                                  X_fft_h=X_fft_h,
                                  npad=npad,
                                  constant_Q=constant_Q)

    description = 'Wavlet: ' + elec_series.description
    series = [
        DecompositionSeries('wvlt_amp_' + elec_series.name,
                            abs(X_wvlt),
                            metric='amplitude',
                            source_timeseries=elec_series,
                            starting_time=elec_series.starting_time,
                            rate=rate,
                            description=description)
    ]
    if not abs_only:
        series.append(
            DecompositionSeries('wvlt_phase_' + elec_series.name,
                                np.angle(X_wvlt),
                                metric='phase',
                                source_timeseries=elec_series,
                                starting_time=elec_series.starting_time,
                                rate=rate,
                                description=description))

    # Band centres/widths are the same for every series: compute them once.
    if use_default_bands:
        cfs = log_spaced_cfs(4.0749286538265, 200, 40)
        sds = const_Q_sds(cfs)

    for es in series:
        if use_default_bands:
            for ii, (cf, sd) in enumerate(zip(cfs, sds)):
                # Bands are described by mean/stdev; the limits are passed
                # as (-1, -1).
                es.add_band(band_name=str(ii),
                            band_mean=cf,
                            band_stdev=sd,
                            band_limits=(-1, -1))

        processing.add(es)
    return X_wvlt, series
Exemple #21
0
def yuta2nwb(
        session_path='D:/BuzsakiData/SenzaiY/YutaMouse41/YutaMouse41-150903',
        # '/Users/bendichter/Desktop/Buzsaki/SenzaiBuzsaki2017/YutaMouse41/YutaMouse41-150903',
        subject_xls=None,
        include_spike_waveforms=True,
        stub=True,
        cache_spec=True):
    """Convert one YutaMouse session directory into an NWB file, then read it back.

    Raw electrode data and sorted units are written through the ``se``
    recording/sorting extractors; LFP, positions, trials and sleep states are
    written through ``ns`` helpers and pynwb objects. The output file is
    ``session_path + '.nwb'`` (``'_stub.nwb'`` when ``stub`` is true).

    Relies on the module-level names ``max_shanks``, ``special_electrode_dict``,
    ``task_types`` and ``celltype_dict``.

    Parameters
    ----------
    session_path : str
        Path to the session directory. Its last component is the session id,
        which is split on ``'-'`` (or on ``'b'`` when no dash is present) into
        subject id and date text.
    subject_xls : str, optional
        Path to the subject's experiment spreadsheet. When ``None`` the file
        ``'YM<nn> exp_sheet.xlsx'`` inside the subject directory is used; when
        the value does not end in ``'xlsx'`` it is treated as the directory
        containing that file.
    include_spike_waveforms : bool
        When true, ``ns.write_spike_waveforms`` is called once per shank.
    stub : bool
        When true, randomly generated recording and sorting data replace the
        Neuroscope extractors; also forwarded to ``ns.read_lfp`` and
        ``ns.write_spike_waveforms`` and selects the ``'_stub.nwb'`` suffix.
    cache_spec : bool
        Forwarded to ``NWBHDF5IO.write``.
    """
    subject_path, session_id = os.path.split(session_path)
    fpath_base = os.path.split(subject_path)[0]
    identifier = session_id
    mouse_number = session_id[9:11]
    if '-' in session_id:
        subject_id, date_text = session_id.split('-')
        b = False
    else:
        subject_id, date_text = session_id.split('b')
        b = True

    sheet_name = 'YM' + mouse_number + ' exp_sheet.xlsx'
    if subject_xls is None:
        subject_xls = os.path.join(subject_path, sheet_name)
    elif not subject_xls.endswith('xlsx'):
        # A directory was given rather than the spreadsheet itself.
        subject_xls = os.path.join(subject_xls, sheet_name)

    session_start_time = dateparse(date_text, yearfirst=True)

    df = pd.read_excel(subject_xls)

    # First column holds the field names, second column the values.
    names = df.iloc[:, 0]
    subject_data = {}
    for key in [
            'genotype', 'DOB', 'implantation', 'Probe', 'Surgery',
            'virus injection', 'mouseID'
    ]:
        if key in names.values:
            subject_data[key] = df.iloc[np.argmax(names == key), 1]

    # Only keys found in the sheet are present in subject_data, so look the
    # optional ones up with .get() instead of risking a KeyError.
    dob = subject_data.get('DOB')
    if isinstance(dob, datetime):
        age = session_start_time - dob
    else:
        age = None

    # NOTE(review): age is stringified even when unknown, which stores the
    # literal text 'None' -- confirm whether that is intended.
    subject = Subject(subject_id=subject_id,
                      age=str(age),
                      genotype=subject_data.get('genotype'),
                      species='mouse')

    nwbfile = NWBFile(
        session_description='mouse in open exploration and theta maze',
        identifier=identifier,
        session_start_time=session_start_time.astimezone(),
        file_create_date=datetime.now().astimezone(),
        experimenter='Yuta Senzai',
        session_id=session_id,
        institution='NYU',
        lab='Buzsaki',
        subject=subject,
        related_publications='DOI:10.1016/j.neuron.2016.12.011')

    print('reading and writing raw position data...', end='', flush=True)
    ns.add_position_data(nwbfile, session_path)

    shank_channels = ns.get_shank_channels(session_path)[:8]
    nshanks = len(shank_channels)
    all_shank_channels = np.concatenate(shank_channels)

    print('setting up electrodes...', end='', flush=True)
    hilus_csv_path = os.path.join(fpath_base, 'early_session_hilus_chans.csv')
    lfp_channel = get_reference_elec(subject_xls,
                                     hilus_csv_path,
                                     session_start_time,
                                     session_id,
                                     b=b)

    # Boolean column flagging the electrode used as the LFP reference.
    custom_column = [{
        'name': 'theta_reference',
        'description':
        'this electrode was used to calculate LFP canonical bands',
        'data': all_shank_channels == lfp_channel
    }]
    ns.write_electrode_table(nwbfile,
                             session_path,
                             custom_columns=custom_column,
                             max_shanks=max_shanks)

    print('reading raw electrode data...', end='', flush=True)
    if stub:
        # example recording extractor for fast testing
        xml_filepath = os.path.join(session_path, session_id + '.xml')
        xml_root = et.parse(xml_filepath).getroot()
        acq_sampling_frequency = float(
            xml_root.find('acquisitionSystem').find('samplingRate').text)
        num_channels = 4
        num_frames = 10000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        sre = se.NumpyRecordingExtractor(
            timeseries=X, sampling_frequency=acq_sampling_frequency, geom=geom)
    else:
        nre = se.NeuroscopeRecordingExtractor('{}/{}.dat'.format(
            session_path, session_id))
        sre = se.SubRecordingExtractor(nre, channel_ids=all_shank_channels)

    print('writing raw electrode data...', end='', flush=True)
    se.NwbRecordingExtractor.add_electrical_series(sre, nwbfile)
    print('done.')

    print('reading spiking units...', end='', flush=True)
    if stub:
        # Number of random spike times drawn for each fake unit.
        spike_times = [200, 300, 400]
        num_frames = 10000
        allshanks = []
        for _ in range(nshanks):
            SX = se.NumpySortingExtractor()
            for j, n_spikes in enumerate(spike_times):
                SX.add_unit(unit_id=j + 1,
                            times=np.sort(
                                np.random.uniform(0, num_frames, n_spikes)))
            allshanks.append(SX)
        se_allshanks = se.MultiSortingExtractor(allshanks)
        se_allshanks.set_sampling_frequency(acq_sampling_frequency)
    else:
        se_allshanks = se.NeuroscopeMultiSortingExtractor(session_path,
                                                          keep_mua_units=False)

    # One electrode-group entry per unit, in shank order.
    electrode_group = []
    for shankn in np.arange(1, nshanks + 1, dtype=int):
        for _unit_id in se_allshanks.sortings[shankn - 1].get_unit_ids():
            electrode_group.append(nwbfile.electrode_groups['shank' +
                                                            str(shankn)])

    df_unit_features = get_UnitFeatureCell_features(fpath_base, session_id,
                                                    session_path)

    # regions: 3: 'CA3', 4: 'DG'
    celltype_names = []
    for celltype_id, region_id in zip(df_unit_features['fineCellType'].values,
                                      df_unit_features['region'].values):
        if celltype_id == 1:
            # Cell type 1 is named according to the region it was recorded in.
            if region_id == 3:
                celltype_names.append('pyramidal cell')
            elif region_id == 4:
                celltype_names.append('granule cell')
            else:
                raise Exception('unknown type')
        elif not np.isfinite(celltype_id):
            celltype_names.append('missing')
        else:
            celltype_names.append(celltype_dict[celltype_id])

    # Add custom column data into the SortingExtractor so it can be written by the converter
    # Note there is currently a hidden assumption that the way in which the NeuroscopeSortingExtractor
    # merges the cluster IDs matches one-to-one with the get_UnitFeatureCell_features extraction
    property_descriptions = {
        'cell_type': 'name of cell type',
        'global_id': 'global id for cell for entire experiment',
        'shank_id': '0-indexed id of cluster of shank',
        'electrode_group': 'the electrode group that each spike unit came from'
    }
    property_values = {
        'cell_type': celltype_names,
        'global_id': df_unit_features['unitID'].values,
        # - 2 b/c the get_UnitFeatureCell_features removes 0 and 1 IDs from each shank
        'shank_id': [x - 2 for x in df_unit_features['unitIDshank'].values],
        'electrode_group': electrode_group
    }
    for unit_id in se_allshanks.get_unit_ids():
        for property_name in property_descriptions:
            # Values are looked up by using the unit id as a position.
            se_allshanks.set_unit_property(
                unit_id, property_name,
                property_values[property_name][unit_id])

    se.NwbSortingExtractor.write_sorting(
        se_allshanks,
        nwbfile=nwbfile,
        property_descriptions=property_descriptions)
    print('done.')

    # Read and write LFP's
    print('reading LFPs...', end='', flush=True)
    lfp_fs, all_channels_lfp_data = ns.read_lfp(session_path, stub=stub)

    # Columns are channels: keep only those that belong to a shank.
    lfp_data = all_channels_lfp_data[:, all_shank_channels]
    print('writing LFPs...', flush=True)
    lfp_ts = ns.write_lfp(nwbfile,
                          lfp_data,
                          lfp_fs,
                          name='lfp',
                          description='lfp signal for all shank electrodes')

    # Read and add special environmental electrodes
    for name, channel in special_electrode_dict.items():
        ts = TimeSeries(
            name=name,
            description=
            'environmental electrode recorded inline with neural data',
            data=all_channels_lfp_data[:, channel],
            rate=lfp_fs,
            unit='V',
            resolution=np.nan)
        nwbfile.add_acquisition(ts)

    # compute filtered LFP
    print('filtering LFP...', end='', flush=True)
    all_lfp_phases = []
    for passband in ('theta', 'gamma'):
        lfp_fft = filter_lfp(
            lfp_data[:, all_shank_channels == lfp_channel].ravel(),
            lfp_fs,
            passband=passband)
        lfp_phase, _ = hilbert_lfp(lfp_fft)
        all_lfp_phases.append(lfp_phase[:, np.newaxis])
    # One slice per passband along the last axis.
    data = np.dstack(all_lfp_phases)
    print('done.', flush=True)

    if include_spike_waveforms:
        print('writing waveforms...', end='', flush=True)
        n_waveform_shanks = min(
            (max_shanks, len(ns.get_shank_channels(session_path))))

        for shankn in np.arange(n_waveform_shanks, dtype=int) + 1:
            # Get spike activty from .spk file on a per-shank and per-sample basis
            ns.write_spike_waveforms(nwbfile, session_path, shankn, stub=stub)
        print('done.', flush=True)

    # Get the LFP Decomposition Series
    decomp_series = DecompositionSeries(
        name='LFPDecompositionSeries',
        description='Theta and Gamma phase for reference LFP',
        data=data,
        rate=lfp_fs,
        source_timeseries=lfp_ts,
        metric='phase',
        unit='radians')
    decomp_series.add_band(band_name='theta', band_limits=(4, 10))
    decomp_series.add_band(band_name='gamma', band_limits=(30, 80))

    check_module(nwbfile, 'ecephys',
                 'contains processed extracellular electrophysiology data'
                 ).add_data_interface(decomp_series)

    for event in ns.get_events(session_path):
        nwbfile.add_stimulus(event)

    # create epochs corresponding to experiments/environments for the mouse

    sleep_state_fpath = os.path.join(session_path,
                                     '{}--StatePeriod.mat'.format(session_id))

    exist_pos_data = any(
        os.path.isfile(
            os.path.join(session_path, '{}__{}.mat'.format(
                session_id, task_type['name']))) for task_type in task_types)

    if exist_pos_data:
        nwbfile.add_epoch_column('label', 'name of epoch')

    for task_type in task_types:
        label = task_type['name']

        pos_fpath = os.path.join(session_path,
                                 session_id + '__' + label + '.mat')
        if not os.path.isfile(pos_fpath):
            continue

        print('loading position for ' + label + '...', end='', flush=True)

        pos_obj = Position(name=label + '_position')

        matin = loadmat(pos_fpath)
        # Column 0 of 'twhl_norm' is used as the timestamps.
        tt = matin['twhl_norm'][:, 0]
        exp_times = find_discontinuities(tt)

        conversion = task_type['conversion'] if 'conversion' in task_type else np.nan

        for pos_type in ('twhl_norm', 'twhl_linearized'):
            if pos_type in matin:
                pos_data_norm = matin[pos_type][:, 1:]

                spatial_series_object = SpatialSeries(
                    name=label + '_{}_spatial_series'.format(pos_type),
                    data=H5DataIO(pos_data_norm, compression='gzip'),
                    reference_frame='unknown',
                    conversion=conversion,
                    resolution=np.nan,
                    timestamps=H5DataIO(tt, compression='gzip'))
                pos_obj.add_spatial_series(spatial_series_object)

        check_module(
            nwbfile, 'behavior',
            'contains processed behavioral data').add_data_interface(pos_obj)
        for i, window in enumerate(exp_times):
            nwbfile.add_epoch(start_time=window[0],
                              stop_time=window[1],
                              label=label + '_' + str(i))
        print('done.')

    trialdata_path = os.path.join(session_path,
                                  session_id + '__EightMazeRun.mat')
    if os.path.isfile(trialdata_path):
        trials_data = loadmat(trialdata_path)['EightMazeRun']

        trialdatainfo_path = os.path.join(fpath_base, 'EightMazeRunInfo.mat')
        trialdatainfo = [
            x[0] for x in loadmat(trialdatainfo_path)['EightMazeRunInfo'][0]
        ]

        features = trialdatainfo[:7]
        features[:2] = 'start_time', 'stop_time',
        for column_name in features[4:] + ['condition']:
            nwbfile.add_trial_column(column_name, 'description')

        for trial_data in trials_data:
            cond = 'run_left' if trial_data[3] else 'run_right'
            nwbfile.add_trial(start_time=trial_data[0],
                              stop_time=trial_data[1],
                              condition=cond,
                              error_run=trial_data[4],
                              stim_run=trial_data[5],
                              both_visit=trial_data[6])

    # NOTE: monosynaptic connection data ('<session_id>-MonoSynConvClick.mat',
    # fields 'FinalExcMonoSynID' / 'FinalInhMonoSynID') is not converted.

    if os.path.isfile(sleep_state_fpath):
        matin = loadmat(sleep_state_fpath)['StatePeriod']

        table = TimeIntervals(name='states',
                              description='sleep states of animal')
        table.add_column(name='label', description='sleep state')

        # Gather the intervals of every state, then insert them ordered by
        # start time.
        state_rows = []
        for name in matin.dtype.names:
            for row in matin[name][0][0]:
                state_rows.append({
                    'start_time': row[0],
                    'stop_time': row[1],
                    'label': name
                })
        for state_row in sorted(state_rows, key=lambda x: x['start_time']):
            table.add_row(**state_row)

        check_module(nwbfile, 'behavior',
                     'contains behavioral data').add_data_interface(table)

    print('writing NWB file...', end='', flush=True)
    if stub:
        out_fname = session_path + '_stub.nwb'
    else:
        out_fname = session_path + '.nwb'

    with NWBHDF5IO(out_fname, mode='w') as io:
        io.write(nwbfile, cache_spec=cache_spec)
    print('done.')

    print('testing read...', end='', flush=True)
    # test read
    with NWBHDF5IO(out_fname, mode='r') as io:
        io.read()
    print('done.')