    def write_sorting(sorting, save_path, **nwbfile_kwargs):
        """

        Parameters
        ----------
        sorting: SortingExtractor
        save_path: str
        nwbfile_kwargs: optional, pynwb.NWBFile args
        """
        assert HAVE_NWB, "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n"
        ids = sorting.get_unit_ids()
        fs = sorting.get_sampling_frequency()

        if os.path.exists(save_path):
            io = NWBHDF5IO(save_path, 'r+')
            nwbfile = io.read()
        else:
            io = NWBHDF5IO(save_path, mode='w')
            if 'session_description' not in nwbfile_kwargs:
                nwbfile_kwargs['session_description'] = 'No description'
            if 'identifier' not in nwbfile_kwargs:
                nwbfile_kwargs['identifier'] = str(uuid.uuid4())
            input_nwbfile_kwargs = {'session_start_time': datetime.now()}
            input_nwbfile_kwargs.update(nwbfile_kwargs)
            nwbfile = NWBFile(**input_nwbfile_kwargs)

        # Stores spike times for each detected cell (unit)
        for id in ids:
            spkt = sorting.get_unit_spike_train(unit_id=id) / fs
            nwbfile.add_unit(id=id, spike_times=spkt)
            # 'waveform_mean' and 'waveform_sd' are interesting args to include later

        io.write(nwbfile)
        io.close()
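
# A hypothetical usage sketch (not part of the snippet above): any
# spikeextractors-style SortingExtractor should work here; NumpySortingExtractor
# is used purely for illustration.
import spikeextractors as se

sorting = se.NumpySortingExtractor()
sorting.set_times_labels(times=[100, 200, 300], labels=[0, 0, 1])
sorting.set_sampling_frequency(30000.)
write_sorting(sorting, 'sorting_demo.nwb', session_description='demo session')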
Example #2
    def run_conversion(self, nwbfile: NWBFile, metadata: dict):
        file_path = self.source_data['file_path']

        file = h5py.File(file_path, 'r')
        cell_info = file['cell_info']

        cell_ids = [
            ''.join(chr(c[0]) for c in file[ref[0]])
            for ref in cell_info['cell_id']
        ]

        times = get_data(file, 'time')
        body_pos = get_data(file, 'body_position')
        body_speed = get_data(file, 'body_speed')
        horizontal_eye_pos = get_data(file, 'horizontal_eye_position')
        vertical_eye_pos = get_data(file, 'vertical_eye_position')
        # the misspelled key below is left as-is on the assumption that it
        # matches the key actually present in the source file
        horizontal_eye_vel = get_data(file, 'horiztonal_eye_velocity')
        vertical_eye_vel = get_data(file, 'vertical_eye_velocity')

        all_spike_times = [
            file[x[0]][:].ravel() for x in cell_info['spike_times']
        ]

        behavior = nwbfile.create_processing_module(
            name='behavior', description='contains processed behavioral data')

        spatial_series = SpatialSeries(
            name='position',
            data=body_pos,
            timestamps=times,
            conversion=.01,
            reference_frame='on track. Position is in VR.')

        behavior.add(Position(spatial_series=spatial_series))

        behavior.add(
            TimeSeries(name='body_speed',
                       data=body_speed,
                       timestamps=spatial_series,
                       unit='cm/s'))

        behavior.add(
            EyeTracking(
                spatial_series=SpatialSeries(name='eye_position',
                                             data=np.c_[horizontal_eye_pos,
                                                        vertical_eye_pos],
                                             timestamps=spatial_series,
                                             reference_frame='unknown')))

        behavior.add(
            TimeSeries(name='eye_velocity',
                       data=np.c_[horizontal_eye_vel, vertical_eye_vel],
                       timestamps=spatial_series,
                       unit='unknown'))

        for spike_times, cell_id in zip(all_spike_times, cell_ids):
            id_ = int(cell_id.split('_')[-1])
            nwbfile.add_unit(spike_times=spike_times, id=id_)

        return nwbfile
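
# Note on the pattern above: pynwb lets one series reuse another's clock by
# passing the existing TimeSeries as `timestamps`, so the timestamps array is
# stored once and linked. A minimal self-contained sketch:
from pynwb import TimeSeries

position_ts = TimeSeries(name='position', data=[0.0, 1.0, 2.0],
                         timestamps=[0.1, 0.2, 0.3], unit='cm')
speed_ts = TimeSeries(name='speed', data=[5.0, 5.0, 5.0],
                      timestamps=position_ts, unit='cm/s')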
Example #3
    def _create_ecephys_spiking(self, nwbfile: NWBFile,
                                metadata_ecephys: dict):
        """Add spiking data"""
        print('Converting spiking data...')
        with h5py.File(self.source_data['path_ecephys_processed'], 'r') as f:
            spike_times = np.where(f['spk'][0])[0] * f['dte'][0]
            nwbfile.add_unit(spike_times=spike_times)
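
# Toy illustration of the conversion above, assuming 'spk' holds a per-sample
# 0/1 spike indicator and 'dte' the sampling interval in seconds:
import numpy as np

spk = np.array([0, 1, 0, 0, 1, 1])    # spikes at samples 1, 4 and 5
dte = 0.001                           # i.e. 1 kHz sampling
spike_times = np.where(spk)[0] * dte  # array([0.001, 0.004, 0.005])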
Example #4
def add_units(nwbfile: NWBFile,
              session_path: str,
              custom_cols: Optional[List[dict]] = None,
              max_shanks: Optional[int] = 8):
    """Add the spiking unit information to the NWBFile.

    Parameters
    ----------
    nwbfile: pynwb.NWBFile
    session_path: str
    custom_cols: list(dict), optional
        [{name, description, data, kwargs}]
    max_shanks: int, optional
        only take the first <max_shanks> channel groups

    Returns
    -------
    nwbfile
    """
    nwbfile.add_unit_column('shank_id', '0-indexed id of the cluster within its shank')
    nshanks = len(get_shank_channels(session_path))
    nshanks = min(max_shanks, nshanks)

    for shankn in range(1, nshanks + 1):
        df = get_clusters_single_shank(session_path, shankn)
        electrode_group = nwbfile.electrode_groups['shank' + str(shankn)]
        for shank_id, idf in df.groupby('id'):
            nwbfile.add_unit(spike_times=idf['time'].values,
                             shank_id=shank_id,
                             electrode_group=electrode_group)

    if custom_cols:
        for col in custom_cols:
            nwbfile.add_unit_column(**col)

    return nwbfile
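
# Self-contained sketch of the custom_cols mechanism used above: pynwb's
# add_unit_column can attach a whole data column after units already exist
# (the names and values here are made up).
from datetime import datetime
from dateutil.tz import tzlocal
from pynwb import NWBFile

demo = NWBFile('demo session', 'demo-id', datetime.now(tzlocal()))
demo.add_unit(spike_times=[0.1, 0.2])
demo.add_unit(spike_times=[0.3])
demo.add_unit_column(name='quality', description='sorting quality',
                     data=[0.9, 0.8])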
Example #5
    def write_sorting(sorting, save_path, nwbfile_kwargs=None):
        """

        Parameters
        ----------
        sorting: SortingExtractor
        save_path: str
        nwbfile_kwargs: optional, dict with optional args of pynwb.NWBFile
        """
        try:
            from pynwb import NWBHDF5IO
            from pynwb import NWBFile
            from pynwb.ecephys import ElectricalSeries

        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "To use the Nwb extractors, install pynwb: \n\n"
                "pip install pynwb\n\n")

        ids = sorting.get_unit_ids()
        fs = sorting.get_sampling_frequency()

        if os.path.exists(save_path):
            io = NWBHDF5IO(save_path, 'r+')
            nwbfile = io.read()
        else:
            io = NWBHDF5IO(save_path, mode='w')
            input_nwbfile_kwargs = {
                'session_start_time': datetime.now(),
                'identifier': '',
                'session_description': ''
            }
            if nwbfile_kwargs is not None:
                input_nwbfile_kwargs.update(nwbfile_kwargs)
            nwbfile = NWBFile(**input_nwbfile_kwargs)

        # Stores spike times for each detected cell (unit)
        for id in ids:
            spkt = sorting.get_unit_spike_train(unit_id=id) / fs
            nwbfile.add_unit(id=id, spike_times=spkt)
            # 'waveform_mean' and 'waveform_sd' are interesting args to include later

        io.write(nwbfile)
        io.close()
Example #6
class ShowPSTHTestCase(unittest.TestCase):
    def setUp(self):
        """
        Trials must exist.
        """
        start_time = datetime(2017, 4, 3, 11, tzinfo=tzlocal())
        create_date = datetime(2017, 4, 15, 12, tzinfo=tzlocal())

        self.nwbfile = NWBFile(session_description='NWBFile for PSTH',
                               identifier='NWB123',
                               session_start_time=start_time,
                               file_create_date=create_date)

        self.nwbfile.add_unit_column('location',
                                     'the anatomical location of this unit')
        self.nwbfile.add_unit_column(
            'quality', 'the quality for the inference of this unit')

        self.nwbfile.add_unit(spike_times=[2.2, 3.0, 4.5],
                              obs_intervals=[[1, 10]],
                              location='CA1',
                              quality=0.95)
        self.nwbfile.add_unit(spike_times=[2.2, 3.0, 25.0, 26.0],
                              obs_intervals=[[1, 10], [20, 30]],
                              location='CA3',
                              quality=0.85)
        self.nwbfile.add_unit(spike_times=[1.2, 2.3, 3.3, 4.5],
                              obs_intervals=[[1, 10], [20, 30]],
                              location='CA1',
                              quality=0.90)

        self.nwbfile.add_trial_column(
            name='stim', description='the visual stimuli during the trial')

        self.nwbfile.add_trial(start_time=0.0, stop_time=2.0, stim='person')
        self.nwbfile.add_trial(start_time=3.0, stop_time=5.0, stim='ocean')
        self.nwbfile.add_trial(start_time=6.0, stop_time=8.0, stim='desert')

    def test_psth_widget(self):
        assert isinstance(psth_widget(self.nwbfile.units), widgets.Widget)

    def test_raster_widget(self):
        assert isinstance(raster_widget(self.nwbfile.units), widgets.Widget)

    def test_show_session_raster(self):
        assert isinstance(show_session_raster(self.nwbfile.units), plt.Figure)

    def test_raster_grid_widget(self):
        assert isinstance(raster_grid_widget(self.nwbfile.units),
                          widgets.Widget)

    def test_raster_grid(self):
        trials = self.nwbfile.units.get_ancestor('NWBFile').trials
        assert isinstance(
            raster_grid(self.nwbfile.units,
                        trials=trials,
                        index=0,
                        before=0.5,
                        after=20.0), plt.Figure)
Example #7
def neuroh5_to_nwb(fpath, out_path=None):
    """

    Parameters
    ----------
    fpath: str | path
        path of neuroh5 file
    out_path: str (optional)
        where the NWB file is saved

    """

    if out_path is None:
        out_path = fpath[:-3] + '.nwb'

    fname = os.path.split(fpath)[1]
    identifier = fname[:-4]

    nwbfile = NWBFile(session_description='session_description',
                      identifier=identifier,
                      session_start_time=datetime.now().astimezone(),
                      institution='Stanford University',
                      lab='Soltesz')

    with File(fpath, 'r') as f:
        write_position(nwbfile, f)

        nwbfile.add_unit_column('cell_type', 'cell type')
        nwbfile.add_unit_column('pop_id', 'cell number within population')

        for unit_dict in tqdm(get_neuroh5_cell_data(f),
                              total=38000 + 34000,
                              desc='reading cell data'):
            nwbfile.add_unit(**unit_dict)

    with NWBHDF5IO(out_path, 'w') as io:
        io.write(nwbfile)
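
# The dicts yielded by get_neuroh5_cell_data are unpacked straight into
# add_unit, so each presumably carries the two custom columns declared above
# plus the spike times (the values below are made up):
unit_dict = {'spike_times': [0.5, 1.2, 3.4],
             'cell_type': 'granule cell',
             'pop_id': 17}
# nwbfile.add_unit(**unit_dict)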
Example #8
class UnitsTrialsTestCase(unittest.TestCase):
    def setUp(self):
        start_time = datetime(2017, 4, 3, 11, tzinfo=tzlocal())
        create_date = datetime(2017, 4, 15, 12, tzinfo=tzlocal())

        self.nwbfile = NWBFile(
            session_description="NWBFile for PSTH",
            identifier="NWB123",
            session_start_time=start_time,
            file_create_date=create_date,
        )

        self.nwbfile.add_unit_column("location",
                                     "the anatomical location of this unit")
        self.nwbfile.add_unit_column(
            "quality", "the quality for the inference of this unit")

        self.nwbfile.add_unit(
            id=1,
            spike_times=[2.2, 3.0, 4.5],
            obs_intervals=[[1, 10]],
            location="CA1",
            quality=0.95,
        )
        self.nwbfile.add_unit(
            id=2,
            spike_times=[2.2, 3.0, 25.0, 26.0],
            obs_intervals=[[1, 10], [20, 30]],
            location="CA3",
            quality=0.85,
        )
        self.nwbfile.add_unit(
            id=3,
            spike_times=[1.2, 2.3, 3.3, 4.5],
            obs_intervals=[[1, 10], [20, 30]],
            location="CA1",
            quality=0.90,
        )

        self.nwbfile.add_trial_column(
            name="stim", description="the visual stimuli during the trial")

        self.nwbfile.add_trial(start_time=0.0, stop_time=2.0, stim="person")
        self.nwbfile.add_trial(start_time=3.0, stop_time=5.0, stim="ocean")
        self.nwbfile.add_trial(start_time=6.0, stop_time=8.0, stim="desert")
        self.nwbfile.add_trial(start_time=8.0, stop_time=12.0, stim="person")
        self.nwbfile.add_trial(start_time=13.0, stop_time=15.0, stim="ocean")
        self.nwbfile.add_trial(start_time=16.0, stop_time=18.0, stim="desert")
Example #9
# It is also useful to store the intervals during which each unit was being
# observed, so that one can tell when a unit was silent as opposed to when it
# was not being recorded (and thus correctly compute firing rates, for example). This information
# should be provided as a list of [start, end] time pairs in the `obs_intervals` field. If `obs_intervals` is
# provided, then all entries in `spike_times` should occur within one of the listed intervals. In the example
# below, all 3 units are observed during the time period from 1 to 10 s and fired spikes during that period.
# Units 2 and 3 were also observed during the time period from 20 to 30 s, but only unit 2 fired spikes in that
# period.
#
# Let's specify some unit metadata and then add some units:

nwbfile.add_unit_column('location', 'the anatomical location of this unit')
nwbfile.add_unit_column('quality',
                        'the quality for the inference of this unit')

nwbfile.add_unit(id=1,
                 spike_times=[2.2, 3.0, 4.5],
                 obs_intervals=[[1, 10]],
                 location='CA1',
                 quality=0.95)
nwbfile.add_unit(id=2,
                 spike_times=[2.2, 3.0, 25.0, 26.0],
                 obs_intervals=[[1, 10], [20, 30]],
                 location='CA3',
                 quality=0.85)
nwbfile.add_unit(id=3,
                 spike_times=[1.2, 2.3, 3.3, 4.5],
                 obs_intervals=[[1, 10], [20, 30]],
                 location='CA1',
                 quality=0.90)
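
# Why `obs_intervals` matters here: a correct firing rate divides the spike
# count by the time the unit was actually observed, not the whole session.
# For unit 2 above:

n_spikes = 4                          # len([2.2, 3.0, 25.0, 26.0])
observed_s = (10 - 1) + (30 - 20)     # 19 s observed across both intervals
firing_rate = n_spikes / observed_s   # ~0.21 Hz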

####################
# .. _units_fields_ref:
Example #10
def convert(
        input_file,
        session_start_time,
        subject_date_of_birth,
        subject_id='I5',
        subject_description='naive',
        subject_genotype='wild-type',
        subject_sex='M',
        subject_weight='11.6g',
        subject_species='Mus musculus',
        subject_brain_region='Medial Entorhinal Cortex',
        surgery='Probe: +/-3.3mm ML, 0.2mm A of sinus, then as deep as possible',
        session_id='npI5_0417_baseline_1',
        experimenter='Kei Masuda',
        experiment_description='Virtual Hallway Task',
        institution='Stanford University School of Medicine',
        lab_name='Giocomo Lab'):
    """
    Read in the .mat file specified by input_file and convert to .nwb format.

    Parameters
    ----------
    input_file : str
        path to the .mat file to be converted
    subject_id : string
        the unique subject ID number for the subject of the experiment
    subject_date_of_birth : datetime ISO 8601
        the date and time the subject was born
    subject_description : string
        important information specific to this subject that differentiates it from other members of its species
    subject_genotype : string
        the genetic strain of this species.
    subject_sex : string
        Male or Female
    subject_weight : string
        the weight of the subject around the time of the experiment
    subject_species : string
        the name of the species of the subject
    subject_brain_region : string
        the name of the brain region where the electrode probe is recording from
    surgery : str
        information about the subject's surgery to implant electrodes
    session_id: string
        human-readable ID# for the experiment session that has a one-to-one relationship with a recording session
    session_start_time : datetime
        date and time that the experiment started
    experimenter : string
        who ran the experiment, first and last name
    experiment_description : string
        what task was being run during the session
    institution : string
        what institution was the experiment performed in
    lab_name : string
        the lab where the experiment was performed

    Returns
    -------
    nwbfile : NWBFile
        The contents of the .mat file converted into the NWB format. The nwbfile is saved to disk using NWBHDF5IO.
    """

    # input matlab data
    matfile = hdf5storage.loadmat(input_file)

    # output path for nwb data
    def replace_last(source_string, replace_what, replace_with):
        head, _sep, tail = source_string.rpartition(replace_what)
        return head + replace_with + tail

    outpath = replace_last(input_file, '.mat', '.nwb')

    create_date = datetime.today()
    timezone_cali = pytz.timezone('US/Pacific')
    create_date_tz = timezone_cali.localize(create_date)

    # if loading data from config.yaml, convert string dates into datetime
    if isinstance(session_start_time, str):
        session_start_time = datetime.strptime(session_start_time,
                                               '%B %d, %Y %I:%M%p')
        session_start_time = timezone_cali.localize(session_start_time)

    if isinstance(subject_date_of_birth, str):
        subject_date_of_birth = datetime.strptime(subject_date_of_birth,
                                                  '%B %d, %Y %I:%M%p')
        subject_date_of_birth = timezone_cali.localize(subject_date_of_birth)

    # create unique identifier for this experimental session
    uuid_identifier = uuid.uuid1()

    # Create NWB file
    nwbfile = NWBFile(
        session_description=experiment_description,  # required
        identifier=uuid_identifier.hex,  # required
        session_id=session_id,
        experiment_description=experiment_description,
        experimenter=experimenter,
        surgery=surgery,
        institution=institution,
        lab=lab_name,
        session_start_time=session_start_time,  # required
        file_create_date=create_date_tz)  # optional

    # add information about the subject of the experiment
    experiment_subject = Subject(subject_id=subject_id,
                                 species=subject_species,
                                 description=subject_description,
                                 genotype=subject_genotype,
                                 date_of_birth=subject_date_of_birth,
                                 weight=subject_weight,
                                 sex=subject_sex)
    nwbfile.subject = experiment_subject

    # adding constants via LabMetaData container
    # constants
    sample_rate = float(matfile['sp'][0]['sample_rate'][0][0][0])
    n_channels_dat = int(matfile['sp'][0]['n_channels_dat'][0][0][0])
    dat_path = matfile['sp'][0]['dat_path'][0][0][0]
    offset = int(matfile['sp'][0]['offset'][0][0][0])
    data_dtype = matfile['sp'][0]['dtype'][0][0][0]
    hp_filtered = bool(matfile['sp'][0]['hp_filtered'][0][0][0])
    vr_session_offset = matfile['sp'][0]['vr_session_offset'][0][0][0]
    # container
    lab_metadata = LabMetaData_ext(name='LabMetaData',
                                   acquisition_sampling_rate=sample_rate,
                                   number_of_electrodes=n_channels_dat,
                                   file_path=dat_path,
                                   bytes_to_skip=offset,
                                   raw_data_dtype=data_dtype,
                                   high_pass_filtered=hp_filtered,
                                   movie_start_time=vr_session_offset)
    nwbfile.add_lab_meta_data(lab_metadata)

    # Adding trial information
    nwbfile.add_trial_column(
        'trial_contrast',
        'visual contrast of the maze through which the mouse is running')
    trial = np.ravel(matfile['trial'])
    trial_nums = np.unique(trial)
    position_time = np.ravel(matfile['post'])
    # matlab trial numbers start at 1. To correctly index the trial_contrast vector,
    # subtract 1 from 'num' so the index starts at 0
    for num in trial_nums:
        trial_times = position_time[trial == num]
        nwbfile.add_trial(start_time=trial_times[0],
                          stop_time=trial_times[-1],
                          trial_contrast=matfile['trial_contrast'][num - 1][0])

    # Add mouse position inside:
    position = Position()
    position_virtual = np.ravel(matfile['posx'])
    # position inside the virtual environment
    sampling_rate = 1 / (position_time[1] - position_time[0])
    position.create_spatial_series(
        name='Position',
        data=position_virtual,
        starting_time=position_time[0],
        rate=sampling_rate,
        reference_frame='The start of the trial, which begins at the start '
        'of the virtual hallway.',
        conversion=0.01,
        description='Subject position in the virtual hallway.',
        comments='The values should be >0 and <400cm. Values greater than '
        '400cm mean that the mouse briefly exited the maze.',
    )

    # physical position on the mouse wheel
    physical_posx = position_virtual
    trial_gain = np.ravel(matfile['trial_gain'])
    for num in trial_nums:
        physical_posx[trial ==
                      num] = physical_posx[trial == num] / trial_gain[num - 1]

    position.create_spatial_series(
        name='PhysicalPosition',
        data=physical_posx,
        starting_time=position_time[0],
        rate=sampling_rate,
        reference_frame='Location on wheel re-referenced to zero '
        'at the start of each trial.',
        conversion=0.01,
        description='Physical location on the wheel measured '
        'since the beginning of the trial.',
        comments='Physical location found by dividing the '
        'virtual position by the "trial_gain"')
    nwbfile.add_acquisition(position)

    # Add timing of lick events, as well as mouse's virtual position during lick event
    lick_events = BehavioralEvents()
    lick_events.create_timeseries(
        'LickEvents',
        data=np.ravel(matfile['lickx']),
        timestamps=np.ravel(matfile['lickt']),
        unit='centimeter',
        description='Subject position in virtual hallway during the lick.')
    nwbfile.add_acquisition(lick_events)

    # Add information on the visual stimulus that was shown to the subject
    # Assumed rate=60 [Hz]. Update if necessary
    # Update external_file to link to Unity environment file
    visualization = ImageSeries(
        name='ImageSeries',
        unit='seconds',
        format='external',
        external_file=['https://unity.com/VR-and-AR-corner'],
        starting_time=vr_session_offset,
        starting_frame=[0],
        rate=60.0,
        description='virtual Unity environment that the mouse navigates through'
    )
    nwbfile.add_stimulus(visualization)

    # Add the recording device, a neuropixel probe
    recording_device = nwbfile.create_device(name='neuropixel_probes')
    electrode_group_description = 'single neuropixels probe http://www.open-ephys.org/neuropixelscorded'
    electrode_group_name = 'probe1'

    electrode_group = nwbfile.create_electrode_group(
        electrode_group_name,
        description=electrode_group_description,
        location=subject_brain_region,
        device=recording_device)

    # Add information about each electrode
    xcoords = np.ravel(matfile['sp'][0]['xcoords'][0])
    ycoords = np.ravel(matfile['sp'][0]['ycoords'][0])
    data_filtered_flag = matfile['sp'][0]['hp_filtered'][0][0]
    if data_filtered_flag:
        filter_desc = 'The raw voltage signals from the electrodes were high-pass filtered'
    else:
        filter_desc = 'The raw voltage signals from the electrodes were not high-pass filtered'

    num_recording_electrodes = xcoords.shape[0]
    recording_electrodes = range(0, num_recording_electrodes)

    # create electrode columns for the x,y location on the neuropixel  probe
    # the standard x,y,z locations are reserved for Allen Brain Atlas location
    nwbfile.add_electrode_column('rel_x', 'electrode x-location on the probe')
    nwbfile.add_electrode_column('rel_y', 'electrode y-location on the probe')

    for idx in recording_electrodes:
        nwbfile.add_electrode(id=idx,
                              x=np.nan,
                              y=np.nan,
                              z=np.nan,
                              rel_x=float(xcoords[idx]),
                              rel_y=float(ycoords[idx]),
                              imp=np.nan,
                              location='medial entorhinal cortex',
                              filtering=filter_desc,
                              group=electrode_group)

    # Add information about each unit, termed 'cluster' in giocomo data
    # create new columns in unit table
    nwbfile.add_unit_column(
        'quality',
        'labels given to clusters during manual sorting in phy (1=MUA, '
        '2=Good, 3=Unsorted)')

    # cluster information
    cluster_ids = matfile['sp'][0]['cids'][0][0]
    cluster_quality = matfile['sp'][0]['cgs'][0][0]
    # spikes in time
    spike_times = np.ravel(matfile['sp'][0]['st'][0])  # the time of each spike
    spike_cluster = np.ravel(
        matfile['sp'][0]['clu'][0])  # the cluster_id that spiked at that time

    for i, cluster_id in enumerate(cluster_ids):
        unit_spike_times = spike_times[spike_cluster == cluster_id]
        waveforms = matfile['sp'][0]['temps'][0][cluster_id]
        nwbfile.add_unit(id=int(cluster_id),
                         spike_times=unit_spike_times,
                         quality=cluster_quality[i],
                         waveform_mean=waveforms,
                         electrode_group=electrode_group)

    # Trying to add another Units table to hold the results of the automatic spike sorting
    # create TemplateUnits units table
    template_units = Units(
        name='TemplateUnits',
        description='units assigned during automatic spike sorting')
    template_units.add_column(
        'tempScalingAmps',
        'scaling amplitude applied to the template when extracting spike',
        index=True)

    # information on extracted spike templates
    spike_templates = np.ravel(matfile['sp'][0]['spikeTemplates'][0])
    spike_template_ids = np.unique(spike_templates)
    # template scaling amplitudes
    temp_scaling_amps = np.ravel(matfile['sp'][0]['tempScalingAmps'][0])

    for i, spike_template_id in enumerate(spike_template_ids):
        template_spike_times = spike_times[spike_templates ==
                                           spike_template_id]
        temp_scaling_amps_per_template = temp_scaling_amps[spike_templates ==
                                                           spike_template_id]
        template_units.add_unit(id=int(spike_template_id),
                                spike_times=template_spike_times,
                                electrode_group=electrode_group,
                                tempScalingAmps=temp_scaling_amps_per_template)

    # create ecephys processing module
    spike_template_module = nwbfile.create_processing_module(
        name='ecephys',
        description='units assigned during automatic spike sorting')

    # add template_units table to processing module
    spike_template_module.add(template_units)

    print(nwbfile)
    print('converted to NWB:N')
    print('saving ...')

    with NWBHDF5IO(outpath, 'w') as io:
        io.write(nwbfile)
        print('saved', outpath)
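
# Sketch of the string-date handling used above: dates in the
# '%B %d, %Y %I:%M%p' format parse with strptime and are then localized
# (the example date is made up).
from datetime import datetime
import pytz

tz = pytz.timezone('US/Pacific')
start = tz.localize(datetime.strptime('April 17, 2019 11:00AM',
                                      '%B %d, %Y %I:%M%p'))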
Example #11
class NWBFileTest(TestCase):
    def setUp(self):
        self.start = datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())
        self.ref_time = datetime(1979, 1, 1, 0, tzinfo=tzutc())
        self.create = [
            datetime(2017, 5, 1, 12, tzinfo=tzlocal()),
            datetime(2017, 5, 2, 13, 0, 0, 1, tzinfo=tzutc()),
            datetime(2017, 5, 2, 14, tzinfo=tzutc())
        ]
        self.path = 'nwbfile_test.h5'
        self.nwbfile = NWBFile(
            'a test session description for a test NWBFile',
            'FILE123',
            self.start,
            file_create_date=self.create,
            timestamps_reference_time=self.ref_time,
            experimenter='A test experimenter',
            lab='a test lab',
            institution='a test institution',
            experiment_description='a test experiment description',
            session_id='test1',
            notes='my notes',
            pharmacology='drugs',
            protocol='protocol',
            related_publications='my pubs',
            slices='my slices',
            surgery='surgery',
            virus='a virus',
            source_script='noscript',
            source_script_file_name='nofilename',
            stimulus_notes='test stimulus notes',
            data_collection='test data collection notes',
            keywords=('these', 'are', 'keywords'))

    def test_constructor(self):
        self.assertEqual(self.nwbfile.session_description,
                         'a test session description for a test NWBFile')
        self.assertEqual(self.nwbfile.identifier, 'FILE123')
        self.assertEqual(self.nwbfile.session_start_time, self.start)
        self.assertEqual(self.nwbfile.file_create_date, self.create)
        self.assertEqual(self.nwbfile.lab, 'a test lab')
        self.assertEqual(self.nwbfile.experimenter, ('A test experimenter', ))
        self.assertEqual(self.nwbfile.institution, 'a test institution')
        self.assertEqual(self.nwbfile.experiment_description,
                         'a test experiment description')
        self.assertEqual(self.nwbfile.session_id, 'test1')
        self.assertEqual(self.nwbfile.stimulus_notes, 'test stimulus notes')
        self.assertEqual(self.nwbfile.data_collection,
                         'test data collection notes')
        self.assertEqual(self.nwbfile.related_publications, ('my pubs', ))
        self.assertEqual(self.nwbfile.source_script, 'noscript')
        self.assertEqual(self.nwbfile.source_script_file_name, 'nofilename')
        self.assertEqual(self.nwbfile.keywords, ('these', 'are', 'keywords'))
        self.assertEqual(self.nwbfile.timestamps_reference_time, self.ref_time)

    def test_create_electrode_group(self):
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        d = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, d)
        self.assertEqual(elecgrp.description, desc)
        self.assertEqual(elecgrp.location, loc)
        self.assertIs(elecgrp.device, d)

    def test_create_custom_intervals(self):
        df_words = pd.DataFrame({
            'start_time': [.1, 2.],
            'stop_time': [.8, 2.3],
            'label': ['hello', 'there']
        })
        words = TimeIntervals.from_dataframe(df_words, name='words')
        self.nwbfile.add_time_intervals(words)
        self.assertEqual(self.nwbfile.intervals['words'], words)

    def test_create_electrode_group_invalid_index(self):
        """
        Test the case where the user creates an electrode table region with
        indexes that are out of range of the amount of electrodes added.
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i)
        with self.assertRaises(IndexError):
            nwbfile.create_electrode_table_region(list(range(6)), 'test')

    def test_access_group_after_io(self):
        """
        Motivated by #739
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        nwbfile.add_electrode(np.nan,
                              np.nan,
                              np.nan,
                              np.nan,
                              'a',
                              'a',
                              elecgrp,
                              id=0)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i + 1)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        remove_test_file("electrodes_mwe.nwb")

    def test_access_processing(self):
        self.nwbfile.create_processing_module('test_mod', 'test_description')
        # test deprecate .modules
        with self.assertWarnsWith(DeprecationWarning,
                                  'replaced by NWBFile.processing'):
            modules = self.nwbfile.modules['test_mod']
        self.assertIs(self.nwbfile.processing['test_mod'], modules)

    def test_epoch_tags(self):
        tags1 = ['t1', 't2']
        tags2 = ['t3', 't4']
        tstamps = np.arange(1.0, 100.0, 0.1, dtype=float)
        ts = TimeSeries("test_ts",
                        list(range(len(tstamps))),
                        'unit',
                        timestamps=tstamps)
        expected_tags = tags1 + tags2
        self.nwbfile.add_epoch(0.0, 1.0, tags1, ts)
        self.nwbfile.add_epoch(0.0, 1.0, tags2, ts)
        tags = self.nwbfile.epoch_tags
        self.assertEqual(set(expected_tags), set(tags))

    def test_add_acquisition(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.acquisition), 1)

    def test_add_stimulus(self):
        self.nwbfile.add_stimulus(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus), 1)

    def test_add_stimulus_template(self):
        self.nwbfile.add_stimulus_template(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus_template), 1)

    def test_add_analysis(self):
        self.nwbfile.add_analysis(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.analysis), 1)

    def test_add_acquisition_check_dups(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        with self.assertRaises(ValueError):
            self.nwbfile.add_acquisition(
                TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                           'grams',
                           timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))

    def test_get_acquisition_empty(self):
        with self.assertRaisesWith(ValueError,
                                   "acquisition of NWBFile 'root' is empty"):
            self.nwbfile.get_acquisition()

    def test_get_acquisition_multiple_elements(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "more than one element in acquisition of NWBFile 'root' -- must specify a name"
        with self.assertRaisesWith(ValueError, msg):
            self.nwbfile.get_acquisition()

    def test_add_acquisition_invalid_name(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "\"'TEST_TS' not found in acquisition of NWBFile 'root'\""
        with self.assertRaisesWith(KeyError, msg):
            self.nwbfile.get_acquisition("TEST_TS")

    def test_set_electrode_table(self):
        table = ElectrodeTable()
        dev1 = self.nwbfile.create_device('dev1')
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-1.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-2.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-3.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-4.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        self.nwbfile.set_electrode_table(table)

        self.assertIs(self.nwbfile.electrodes, table)
        self.assertIs(table.parent, self.nwbfile)

    def test_add_unit_column(self):
        self.nwbfile.add_unit_column('unit_type', 'the type of unit')
        self.assertEqual(self.nwbfile.units.colnames, ('unit_type', ))

    def test_add_unit(self):
        self.nwbfile.add_unit(id=1)
        self.assertEqual(len(self.nwbfile.units), 1)
        self.nwbfile.add_unit(id=2)
        self.nwbfile.add_unit(id=3)
        self.assertEqual(len(self.nwbfile.units), 3)

    def test_add_trial_column(self):
        self.nwbfile.add_trial_column('trial_type', 'the type of trial')
        self.assertEqual(self.nwbfile.trials.colnames,
                         ('start_time', 'stop_time', 'trial_type'))

    def test_add_trial(self):
        self.nwbfile.add_trial(start_time=10.0, stop_time=20.0)
        self.assertEqual(len(self.nwbfile.trials), 1)
        self.nwbfile.add_trial(start_time=30.0, stop_time=40.0)
        self.nwbfile.add_trial(start_time=50.0, stop_time=70.0)
        self.assertEqual(len(self.nwbfile.trials), 3)

    def test_add_invalid_times_column(self):
        self.nwbfile.add_invalid_times_column(
            'comments', 'description of reason for omitting time')
        self.assertEqual(self.nwbfile.invalid_times.colnames,
                         ('start_time', 'stop_time', 'comments'))

    def test_add_invalid_time_interval(self):

        self.nwbfile.add_invalid_time_interval(start_time=0.0, stop_time=12.0)
        self.assertEqual(len(self.nwbfile.invalid_times), 1)
        self.nwbfile.add_invalid_time_interval(start_time=15.0, stop_time=16.0)
        self.nwbfile.add_invalid_time_interval(start_time=17.0, stop_time=20.5)
        self.assertEqual(len(self.nwbfile.invalid_times), 3)

    def test_add_invalid_time_w_ts(self):
        ts = TimeSeries(name='name', data=[1.2], rate=1.0, unit='na')
        self.nwbfile.add_invalid_time_interval(start_time=18.0,
                                               stop_time=20.6,
                                               timeseries=ts,
                                               tags=('hi', 'there'))

    def test_add_electrode(self):
        dev1 = self.nwbfile.create_device('dev1')
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        self.nwbfile.add_electrode(1.0,
                                   2.0,
                                   3.0,
                                   -1.0,
                                   'CA1',
                                   'none',
                                   group=group,
                                   id=1)
        elec = self.nwbfile.electrodes[0]
        self.assertEqual(elec.index[0], 1)
        self.assertEqual(elec.iloc[0]['x'], 1.0)
        self.assertEqual(elec.iloc[0]['y'], 2.0)
        self.assertEqual(elec.iloc[0]['z'], 3.0)
        self.assertEqual(elec.iloc[0]['location'], 'CA1')
        self.assertEqual(elec.iloc[0]['filtering'], 'none')
        self.assertEqual(elec.iloc[0]['group'], group)

    def test_all_children(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        device = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, device)
        children = self.nwbfile.all_children()
        self.assertIn(ts1, children)
        self.assertIn(ts2, children)
        self.assertIn(device, children)
        self.assertIn(elecgrp, children)

    def test_fail_if_source_script_file_name_without_source_script(self):
        with self.assertRaises(ValueError):
            # <-- source_script_file_name without source_script is not allowed
            NWBFile('a test session description for a test NWBFile',
                    'FILE123',
                    self.start,
                    source_script=None,
                    source_script_file_name='nofilename')

    def test_get_neurodata_type(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        p1 = ts1.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p1, self.nwbfile)
        p2 = ts2.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p2, self.nwbfile)

    def test_print_units(self):
        self.nwbfile.add_unit(spike_times=[1., 2., 3.])
        expected = """units pynwb.misc.Units at 0x%d
Fields:
  colnames: ['spike_times']
  columns: (
    spike_times_index <class 'hdmf.common.table.VectorIndex'>,
    spike_times <class 'hdmf.common.table.VectorData'>
  )
  description: Autogenerated by NWBFile
  id: id <class 'hdmf.common.table.ElementIdentifiers'>
"""
        expected = expected % id(self.nwbfile.units)
        self.assertEqual(str(self.nwbfile.units), expected)

    def test_copy(self):
        self.nwbfile.add_unit(spike_times=[1., 2., 3.])
        device = self.nwbfile.create_device('a')
        elecgrp = self.nwbfile.create_electrode_group('a',
                                                      'b',
                                                      device=device,
                                                      location='a')
        self.nwbfile.add_electrode(np.nan,
                                   np.nan,
                                   np.nan,
                                   np.nan,
                                   'a',
                                   'a',
                                   elecgrp,
                                   id=0)
        self.nwbfile.add_electrode(np.nan, np.nan, np.nan, np.nan, 'b', 'b',
                                   elecgrp)
        elec_region = self.nwbfile.create_electrode_table_region([1], 'name')

        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = ElectricalSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                               electrodes=elec_region,
                               timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        self.nwbfile.add_trial(start_time=50.0, stop_time=70.0)
        self.nwbfile.add_invalid_times_column(
            'comments', 'description of reason for omitting time')
        self.nwbfile.create_processing_module('test_mod', 'test_description')
        self.nwbfile.create_time_intervals('custom_interval',
                                           'a custom time interval')
        self.nwbfile.intervals['custom_interval'].add_interval(start_time=10.,
                                                               stop_time=20.)
        newfile = self.nwbfile.copy()

        # test dictionaries
        self.assertIs(self.nwbfile.devices['a'], newfile.devices['a'])
        self.assertIs(self.nwbfile.acquisition['test_ts1'],
                      newfile.acquisition['test_ts1'])
        self.assertIs(self.nwbfile.acquisition['test_ts2'],
                      newfile.acquisition['test_ts2'])
        self.assertIs(self.nwbfile.processing['test_mod'],
                      newfile.processing['test_mod'])

        # test dynamic tables
        self.assertIsNot(self.nwbfile.electrodes, newfile.electrodes)
        self.assertIs(self.nwbfile.electrodes['x'], newfile.electrodes['x'])
        self.assertIsNot(self.nwbfile.units, newfile.units)
        self.assertIs(self.nwbfile.units['spike_times'],
                      newfile.units['spike_times'])
        self.assertIsNot(self.nwbfile.trials, newfile.trials)
        self.assertIsNot(self.nwbfile.trials.parent, newfile.trials.parent)
        self.assertIs(self.nwbfile.trials.id, newfile.trials.id)
        self.assertIs(self.nwbfile.trials['start_time'],
                      newfile.trials['start_time'])
        self.assertIs(self.nwbfile.trials['stop_time'],
                      newfile.trials['stop_time'])
        self.assertIsNot(self.nwbfile.invalid_times, newfile.invalid_times)
        self.assertTupleEqual(self.nwbfile.invalid_times.colnames,
                              newfile.invalid_times.colnames)
        self.assertIsNot(self.nwbfile.intervals['custom_interval'],
                         newfile.intervals['custom_interval'])
        self.assertTupleEqual(
            self.nwbfile.intervals['custom_interval'].colnames,
            newfile.intervals['custom_interval'].colnames)
        self.assertIs(self.nwbfile.intervals['custom_interval']['start_time'],
                      newfile.intervals['custom_interval']['start_time'])
        self.assertIs(self.nwbfile.intervals['custom_interval']['stop_time'],
                      newfile.intervals['custom_interval']['stop_time'])

    def test_multi_experimenters(self):
        self.nwbfile = NWBFile('a test session description for a test NWBFile',
                               'FILE123',
                               self.start,
                               experimenter=('experimenter1', 'experimenter2'))
        self.assertTupleEqual(self.nwbfile.experimenter,
                              ('experimenter1', 'experimenter2'))

    def test_multi_publications(self):
        self.nwbfile = NWBFile('a test session description for a test NWBFile',
                               'FILE123',
                               self.start,
                               related_publications=('pub1', 'pub2'))
        self.assertTupleEqual(self.nwbfile.related_publications,
                              ('pub1', 'pub2'))
Example #12
    def write_sorting(sorting, save_path, **nwbfile_kwargs):
        """

        Parameters
        ----------
        sorting: SortingExtractor
        save_path: str
        nwbfile_kwargs: optional, pynwb.NWBFile args
        """
        check_nwb_install()

        ids = sorting.get_unit_ids()
        fs = sorting.get_sampling_frequency()
        if hasattr(sorting, '_t0'):
            t0 = sorting._t0
        else:
            t0 = 0.

        (all_properties, all_features) = find_all_unit_property_names(
            properties_dict=sorting._properties,
            features_dict=sorting._features)

        if os.path.exists(save_path):
            read_mode = 'r+'
        else:
            read_mode = 'w'

        with NWBHDF5IO(save_path, mode=read_mode) as io:
            if read_mode == 'r+':
                nwbfile = io.read()
            else:
                kwargs = {
                    'session_description': 'No description',
                    'identifier': str(uuid.uuid4()),
                    'session_start_time': datetime.now()
                }
                kwargs.update(**nwbfile_kwargs)
                nwbfile = NWBFile(**kwargs)

            # If no units are present in the NWB file yet, add them all
            if nwbfile.units is None:
                for id in ids:
                    spkt = sorting.get_unit_spike_train(unit_id=id) / fs
                    nwbfile.add_unit(id=id, spike_times=spkt)

            # Units properties
            for pr in all_properties:
                unit_ids = [
                    int(k) for k, v in sorting._properties.items() if pr in v
                ]
                vals = [
                    v[pr] for k, v in sorting._properties.items() if pr in v
                ]
                set_dynamic_table_property(dynamic_table=nwbfile.units,
                                           row_ids=unit_ids,
                                           property_name=pr,
                                           values=vals,
                                           default_value=np.nan,
                                           description='no description')

            # # Stores average and std of spike traces
            # if 'waveforms' in sorting.get_unit_spike_feature_names(unit_id=id):
            #     wf = sorting.get_unit_spike_features(unit_id=id,
            #                                          feature_name='waveforms')
            #     relevant_ch = most_relevant_ch(wf)
            #     # Spike traces on the most relevant channel
            #     traces = wf[:, relevant_ch, :]
            #     traces_avg = np.mean(traces, axis=0)
            #     traces_std = np.std(traces, axis=0)
            #     nwbfile.add_unit(
            #         id=id,
            #         spike_times=spkt,
            #         waveform_mean=traces_avg,
            #         waveform_sd=traces_std
            #     )

            # Units spike features
            nspikes = {k: get_nspikes(nwbfile.units, int(k)) for k in ids}
            for ft in all_features:
                vals = [
                    v[ft] if ft in v else [np.nan] * nspikes[int(k)]
                    for k, v in sorting._features.items()
                ]
                flatten_vals = [item for sublist in vals for item in sublist]
                nspks_list = [sp for sp in nspikes.values()]
                spikes_index = np.cumsum(nspks_list).tolist()
                set_dynamic_table_property(
                    dynamic_table=nwbfile.units,
                    row_ids=ids,
                    property_name=ft,
                    values=flatten_vals,
                    index=spikes_index,
                )

            io.write(nwbfile)
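
# The flatten + cumsum pattern above is how a ragged (indexed) column is
# built: the per-unit value lists are concatenated and the index records
# where each unit's run ends. A toy sketch:
import numpy as np

vals = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]          # one list per unit
flat = [x for sub in vals for x in sub]              # six values end to end
index = np.cumsum([len(v) for v in vals]).tolist()   # [2, 3, 6]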
Example #13
class ShowPSTHTestCase(unittest.TestCase):
    def setUp(self):
        """
        Trials must exist.
        """
        start_time = datetime(2017, 4, 3, 11, tzinfo=tzlocal())
        create_date = datetime(2017, 4, 15, 12, tzinfo=tzlocal())

        self.nwbfile = NWBFile(
            session_description="NWBFile for PSTH",
            identifier="NWB123",
            session_start_time=start_time,
            file_create_date=create_date,
        )

        self.nwbfile.add_unit_column("location", "the anatomical location of this unit")
        self.nwbfile.add_unit_column(
            "quality", "the quality for the inference of this unit"
        )

        self.nwbfile.add_unit(
            spike_times=[2.2, 3.0, 4.5],
            obs_intervals=[[1, 10]],
            location="CA1",
            quality=0.95,
        )
        self.nwbfile.add_unit(
            spike_times=[2.2, 3.0, 25.0, 26.0],
            obs_intervals=[[1, 10], [20, 30]],
            location="CA3",
            quality=0.85,
        )
        self.nwbfile.add_unit(
            spike_times=[1.2, 2.3, 3.3, 4.5],
            obs_intervals=[[1, 10], [20, 30]],
            location="CA1",
            quality=0.90,
        )

        self.nwbfile.add_trial_column(
            name="stim", description="the visual stimuli during the trial"
        )

        self.nwbfile.add_trial(start_time=0.0, stop_time=1.0, stim="person")
        self.nwbfile.add_trial(start_time=0.1, stop_time=2.0, stim="person")
        self.nwbfile.add_trial(start_time=3.0, stop_time=4.0, stim="ocean")
        self.nwbfile.add_trial(start_time=4.0, stop_time=5.0, stim="ocean")
        self.nwbfile.add_trial(start_time=5.0, stop_time=6.0, stim="desert")
        self.nwbfile.add_trial(start_time=6.0, stop_time=8.0, stim="desert")

    def test_psth_widget(self):
        widget = PSTHWidget(self.nwbfile.units)
        assert isinstance(widget, widgets.Widget)

        widget.psth_type_radio = "gaussian"
        widget.trial_event_controller.value = ("start_time", "stop_time")
        widget.unit_controller.value = 1
        widget.gas.group_dd.value = "stim"
        widget.gas.group_dd.value = None

    def test_multipsth_widget(self):
        psth_widget = PSTHWidget(self.nwbfile.units)
        assert isinstance(psth_widget, widgets.Widget)
        start_labels = ('start_time', 'stop_time')
        fig = psth_widget.update(index=0, start_labels=start_labels)
        assert len(fig.axes) == 2 * len(start_labels)
        
    def test_raster_widget(self):
        assert isinstance(RasterWidget(self.nwbfile.units), widgets.Widget)

    def test_show_session_raster(self):
        assert isinstance(show_session_raster(self.nwbfile.units), plt.Axes)

    def test_raster_grid_widget(self):
        assert isinstance(RasterGridWidget(self.nwbfile.units), widgets.Widget)

    def test_raster_grid(self):
        trials = self.nwbfile.units.get_ancestor("NWBFile").trials
        assert isinstance(
            raster_grid(
                self.nwbfile.units,
                time_intervals=trials,
                index=0,
                start=-0.5,
                end=20.0,
            ),
            plt.Figure,
        )
Example #14
#
# By default, NWBFile only requires a unique identifier for each unit. Additional columns
# can be added using :py:func:`~pynwb.file.NWBFile.add_unit_column`. Like
# :py:func:`~pynwb.file.NWBFile.add_trial_column`, this method takes a name
# for the column and a description of what the column stores; no data type is needed.
# Once all columns have been added, unit data can be populated using :py:func:`~pynwb.file.NWBFile.add_unit`.
# Again, like :py:func:`~pynwb.file.NWBFile.add_trial`, this method takes keyword arguments that correspond
# to column names.
#
# Let's specify some unit metadata and then add some units:

nwbfile.add_unit_column('location', 'the anatomical location of this unit')
nwbfile.add_unit_column('quality',
                        'the quality for the inference of this unit')

nwbfile.add_unit(id=1, location='CA1', quality=0.95)
nwbfile.add_unit(id=2, location='CA3', quality=0.85)
nwbfile.add_unit(id=3, location='CA1', quality=0.90)

####################
# .. _basic_writing:
#
# Writing an NWB file
# -------------------
#
# NWB I/O is carried out using the :py:class:`~pynwb.NWBHDF5IO` class [#]_. This class is responsible
# for mapping an :py:class:`~pynwb.file.NWBFile` object into HDF5 according to the NWB schema.
#
# To write an :py:class:`~pynwb.file.NWBFile`, use the :py:func:`~pynwb.form.backends.io.FORMIO.write` method.

from pynwb import NWBHDF5IO
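
####################
# For example, a minimal write of the file built above might look like the
# following sketch (the file name ``basic_units.nwb`` is illustrative, not part
# of the original tutorial):

with NWBHDF5IO('basic_units.nwb', mode='w') as io:
    io.write(nwbfile)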
Example #15
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=False):

    this_session = (experiment.Session & session_key).fetch1()
    print(f'Exporting to NWB 2.0 for session: {this_session}...')
    # ===============================================================================
    # ============================== META INFORMATION ===============================
    # ===============================================================================

    sess_desc = session_description_mapper[(
        experiment.ProjectSession & session_key).fetch1('project_name')]

    # -- NWB file - an NWB 2.0 file for each session
    nwbfile = NWBFile(
        identifier='_'.join([
            'ANM' + str(this_session['subject_id']),
            this_session['session_date'].strftime('%Y-%m-%d'),
            str(this_session['session'])
        ]),
        session_description='',
        session_start_time=datetime.combine(this_session['session_date'],
                                            zero_zero_time),
        file_create_date=datetime.now(tzlocal()),
        experimenter=this_session['username'],
        institution=institution,
        experiment_description=sess_desc['experiment_description'],
        related_publications=sess_desc['related_publications'],
        keywords=sess_desc['keywords'])

    # -- subject
    subj = (lab.Subject & session_key).aggr(
        lab.Subject.Strain, ...,
        strains='GROUP_CONCAT(animal_strain)').fetch1()
    nwbfile.subject = pynwb.file.Subject(
        subject_id=str(this_session['subject_id']),
        description=
        f'source: {subj["animal_source"]}; strains: {subj["strains"]}',
        genotype=' x '.join((lab.Subject.GeneModification
                             & subj).fetch('gene_modification')),
        sex=subj['sex'],
        species=subj['species'],
        date_of_birth=datetime.combine(subj['date_of_birth'], zero_zero_time)
        if subj['date_of_birth'] else None)
    # -- virus
    nwbfile.virus = json.dumps([{
        k: str(v)
        for k, v in virus_injection.items() if k not in subj
    } for virus_injection in virus.VirusInjection * virus.Virus & session_key])

    # ===============================================================================
    # ======================== EXTRACELLULAR & CLUSTERING ===========================
    # ===============================================================================
    """
    In the event of multiple probe recording (i.e. multiple probe insertions), the clustering results 
    (and the associated units) are associated with the corresponding probe. 
    Each probe insertion is associated with one ElectrodeConfiguration (which may define multiple electrode groups)
    """

    dj_insert_location = ephys.ProbeInsertion.InsertionLocation.aggr(
        ephys.ProbeInsertion.RecordableBrainRegion.proj(
            brain_region='CONCAT(hemisphere, " ", brain_area)'), ...,
        brain_regions='GROUP_CONCAT(brain_region)')

    for probe_insertion in ephys.ProbeInsertion & session_key:
        electrode_config = (lab.ElectrodeConfig & probe_insertion).fetch1()

        electrode_groups = {}
        for electrode_group in lab.ElectrodeConfig.ElectrodeGroup & electrode_config:
            electrode_groups[electrode_group[
                'electrode_group']] = nwbfile.create_electrode_group(
                    name=electrode_config['electrode_config_name'] + '_g' +
                    str(electrode_group['electrode_group']),
                    description='N/A',
                    device=nwbfile.create_device(
                        name=electrode_config['probe']),
                    location=json.dumps({
                        k: str(v)
                        for k, v in (dj_insert_location
                                     & session_key).fetch1().items()
                        if k not in dj_insert_location.primary_key
                    }))

        for chn in (lab.ElectrodeConfig.Electrode * lab.Probe.Electrode
                    & electrode_config).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['electrode'],
                group=electrode_groups[chn['electrode_group']],
                filtering=hardware_filter,
                imp=-1.,
                x=chn['x_coord'] if chn['x_coord'] else np.nan,
                y=chn['y_coord'] if chn['y_coord'] else np.nan,
                z=chn['z_coord'] if chn['z_coord'] else np.nan,
                location=electrode_groups[chn['electrode_group']].location)

        # --- unit spike times ---
        nwbfile.add_unit_column(
            name='sampling_rate',
            description='Sampling rate of the raw voltage traces (Hz)')
        nwbfile.add_unit_column(name='quality',
                                description='unit quality from clustering')
        nwbfile.add_unit_column(
            name='posx',
            description=
            'estimated x position of the unit relative to probe (0,0) (um)')
        nwbfile.add_unit_column(
            name='posy',
            description=
            'estimated y position of the unit relative to probe (0,0) (um)')
        nwbfile.add_unit_column(
            name='cell_type',
            description='cell type (e.g. fast spiking or pyramidal)')

        for unit_key in (ephys.Unit * ephys.UnitCellType
                         & probe_insertion).fetch('KEY'):

            unit = (ephys.Unit * ephys.UnitCellType & probe_insertion
                    & unit_key).proj(..., '-spike_times').fetch1()
            if ephys.TrialSpikes & unit_key:
                obs_intervals = np.array(
                    list(
                        zip(*(ephys.TrialSpikes * experiment.SessionTrial
                              & unit_key).fetch('start_time',
                                                'stop_time')))).astype(float)
                tr_spike_times, tr_start, tr_go = (
                    ephys.TrialSpikes * experiment.SessionTrial *
                    (experiment.TrialEvent & 'trial_event_type = "go"')
                    & unit_key).fetch('spike_times', 'start_time',
                                      'trial_event_time')
                spike_times = np.hstack([
                    spks + float(t_start) + float(t_go)
                    for spks, t_start, t_go in zip(tr_spike_times, tr_start, tr_go)
                ])
            else:  # the case of unavailable `TrialSpikes`
                spike_times = (ephys.Unit & unit_key).fetch1('spike_times')
                obs_intervals = np.array(
                    list(
                        zip(*(experiment.SessionTrial & unit_key).fetch(
                            'start_time', 'stop_time')))).astype(float)
                obs_intervals = [
                    interval for interval in obs_intervals
                    if np.logical_and(spike_times >= interval[0],
                                      spike_times <= interval[-1]).any()
                ]

            # make an electrode table region (which electrode(s) is this unit coming from)
            nwbfile.add_unit(
                id=unit['unit'],
                electrodes=np.where(
                    np.array(nwbfile.electrodes.id.data) ==
                    unit['electrode'])[0],
                electrode_group=electrode_groups[unit['electrode_group']],
                obs_intervals=obs_intervals,
                sampling_rate=ecephys_fs,
                quality=unit['unit_quality'],
                posx=unit['unit_posx'],
                posy=unit['unit_posy'],
                cell_type=unit['cell_type'],
                spike_times=spike_times,
                waveform_mean=np.mean(unit['waveform'], axis=0),
                waveform_sd=np.std(unit['waveform'], axis=0))

    # ===============================================================================
    # ============================= BEHAVIOR TRACKING ===============================
    # ===============================================================================

    if tracking.LickTrace * experiment.SessionTrial & session_key:
        # re-concatenating trialized tracking traces
        lick_traces, time_vecs, trial_starts = (
            tracking.LickTrace * experiment.SessionTrial & session_key).fetch(
                'lick_trace', 'lick_trace_timestamps', 'start_time')
        behav_acq = pynwb.behavior.BehavioralTimeSeries(
            name='BehavioralTimeSeries')
        nwbfile.add_acquisition(behav_acq)
        behav_acq.create_timeseries(
            name='lick_trace',
            unit='a.u.',
            conversion=1.0,
            data=np.hstack(lick_traces),
            description=
            "Time-series of the animal's tongue movement when licking",
            timestamps=np.hstack(time_vecs + trial_starts.astype(float)))

    # ===============================================================================
    # ============================= PHOTO-STIMULATION ===============================
    # ===============================================================================
    stim_sites = {}
    for photostim in experiment.Photostim * experiment.PhotostimBrainRegion * lab.PhotostimDevice & session_key:

        stim_device = (nwbfile.get_device(photostim['photostim_device'])
                       if photostim['photostim_device'] in nwbfile.devices else
                       nwbfile.create_device(
                           name=photostim['photostim_device']))

        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name=photostim['stim_laterality'] + ' ' +
            photostim['stim_brain_area'],
            device=stim_device,
            excitation_lambda=float(photostim['excitation_wavelength']),
            location=json.dumps([{
                k: v
                for k, v in stim_locs.items()
                if k not in experiment.Photostim.primary_key
            } for stim_locs in (experiment.Photostim.PhotostimLocation.proj(
                ..., '-brain_area')
                                & photostim).fetch(as_dict=True)],
                                default=str),
            description='')
        nwbfile.add_ogen_site(stim_site)
        stim_sites[photostim['photo_stim']] = stim_site

    # re-concatenating trialized photostim traces
    dj_photostim = (experiment.PhotostimTrace * experiment.SessionTrial *
                    experiment.PhotostimEvent * experiment.Photostim
                    & session_key)

    for photo_stim, stim_site in stim_sites.items():
        if dj_photostim & {'photo_stim': photo_stim}:
            aom_input_trace, laser_power, time_vecs, trial_starts = (
                dj_photostim & {
                    'photo_stim': photo_stim
                }).fetch('aom_input_trace', 'laser_power',
                         'photostim_timestamps', 'start_time')

            aom_series = pynwb.ogen.OptogeneticSeries(
                name=stim_site.name + '_aom_input_trace',
                site=stim_site,
                conversion=1e-3,
                data=np.hstack(aom_input_trace),
                timestamps=np.hstack(time_vecs + trial_starts.astype(float)))
            laser_series = pynwb.ogen.OptogeneticSeries(
                name=stim_site.name + '_laser_power',
                site=stim_site,
                conversion=1e-3,
                data=np.hstack(laser_power),
                timestamps=np.hstack(time_vecs + trial_starts.astype(float)))

            nwbfile.add_stimulus(aom_series)
            nwbfile.add_stimulus(laser_series)

    # ===============================================================================
    # =============================== BEHAVIOR TRIALS ===============================
    # ===============================================================================

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes:
    #                                                               'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added to the trial table as additional columns (with column name
    # and column description)

    dj_trial = experiment.SessionTrial * experiment.BehaviorTrial
    skip_adding_columns = experiment.Session.primary_key + [
        'trial_uid', 'trial'
    ]

    if experiment.SessionTrial & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo
        trial_columns = [{
            'name':
            tag,
            'description':
            re.sub(r'\s+:|\s+', ' ',
                   re.search(f'(?<={tag})(.*)',
                             str(dj_trial.heading)).group()).strip()
        } for tag in dj_trial.heading.names if tag not in skip_adding_columns +
                         ['start_time', 'stop_time']]

        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns:
            nwbfile.add_trial_column(**c)

        # Add entry to the trial-table
        for trial in (dj_trial & session_key).fetch(as_dict=True):
            trial['start_time'] = float(trial['start_time'])
            trial['stop_time'] = float(
                trial['stop_time']) if trial['stop_time'] else np.nan
            trial['id'] = trial['trial']  # rename 'trial_id' to 'id'
            for k in skip_adding_columns:
                trial.pop(k)
            nwbfile.add_trial(**trial)

    # ===============================================================================
    # =============================== BEHAVIOR TRIAL EVENTS ==========================
    # ===============================================================================

    behav_event = pynwb.behavior.BehavioralEvents(name='BehavioralEvents')
    nwbfile.add_acquisition(behav_event)

    for trial_event_type in (experiment.TrialEventType & experiment.TrialEvent
                             & session_key).fetch('trial_event_type'):
        event_times, trial_starts = (
            experiment.TrialEvent * experiment.SessionTrial
            & session_key & {
                'trial_event_type': trial_event_type
            }).fetch('trial_event_time', 'start_time')
        if len(event_times) > 0:
            event_times = np.hstack(
                event_times.astype(float) + trial_starts.astype(float))
            behav_event.create_timeseries(name=trial_event_type,
                                          unit='a.u.',
                                          conversion=1.0,
                                          data=np.full_like(event_times, 1),
                                          timestamps=event_times)

    photostim_event_time, trial_starts, photo_stim, power, duration = (
        experiment.PhotostimEvent * experiment.SessionTrial
        & session_key).fetch('photostim_event_time', 'start_time',
                             'photo_stim', 'power', 'duration')

    if len(photostim_event_time) > 0:
        behav_event.create_timeseries(
            name='photostim_start_time',
            unit='a.u.',
            conversion=1.0,
            data=power,
            timestamps=photostim_event_time.astype(float) +
            trial_starts.astype(float),
            control=photo_stim.astype('uint8'),
            control_description=stim_sites)
        behav_event.create_timeseries(
            name='photostim_stop_time',
            unit='a.u.',
            conversion=1.0,
            data=np.full_like(photostim_event_time, 0),
            timestamps=photostim_event_time.astype(float) +
            duration.astype(float) + trial_starts.astype(float),
            control=photo_stim.astype('uint8'),
            control_description=stim_sites)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        if not overwrite and os.path.exists(
                os.path.join(nwb_output_dir, save_file_name)):
            return nwbfile
        with NWBHDF5IO(os.path.join(nwb_output_dir, save_file_name),
                       mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
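
# Hypothetical usage (not part of the original pipeline), assuming session keys
# are drawn from the DataJoint `experiment.Session` table:
#
#     for session_key in experiment.Session.fetch('KEY'):
#         export_to_nwb(session_key, save=True)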
Example #16
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=True):
    this_session = (acquisition.Session & session_key).fetch1()

    identifier = '_'.join([
        this_session['subject_id'],
        this_session['session_time'].strftime('%Y-%m-%d'),
        this_session['session_id']
    ])

    # =============== General ====================
    # -- NWB file - an NWB 2.0 file for each session
    nwbfile = NWBFile(session_description=this_session['session_note'],
                      identifier=identifier,
                      session_start_time=this_session['session_time'],
                      file_create_date=datetime.now(tzlocal()),
                      experimenter='; '.join(
                          (acquisition.Session.Experimenter
                           & session_key).fetch('experimenter')),
                      institution=institution,
                      experiment_description=experiment_description,
                      related_publications=related_publications,
                      keywords=keywords)
    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    nwbfile.subject = pynwb.file.Subject(
        subject_id=this_session['subject_id'],
        description=subj['subject_description'],
        genotype=' x '.join(
            (subject.Subject.Allele & session_key).fetch('allele')),
        sex=subj['sex'],
        species=subj['species'])

    # =============== Intracellular ====================
    cell = ((intracellular.Cell
             & session_key).fetch1() if len(intracellular.Cell
                                            & session_key) == 1 else None)
    if cell:
        # metadata
        whole_cell_device = nwbfile.create_device(name=cell['device_name'])
        ic_electrode = nwbfile.create_ic_electrode(
            name=cell['cell_id'],
            device=whole_cell_device,
            description='N/A',
            filtering='N/A',
            location='; '.join([
                f'{k}: {str(v)}'
                for k, v in dict((reference.BrainLocation & cell).fetch1(),
                                 depth=cell['cell_depth']).items()
            ]))
        # acquisition - membrane potential
        mp, mp_wo_spike, mp_start_time, mp_fs = (
            intracellular.MembranePotential & cell).fetch1(
                'membrane_potential', 'membrane_potential_wo_spike',
                'membrane_potential_start_time',
                'membrane_potential_sampling_rate')
        nwbfile.add_acquisition(
            pynwb.icephys.PatchClampSeries(name='PatchClampSeries',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp,
                                           starting_time=mp_start_time,
                                           rate=mp_fs))
        # acquisition - current injection
        if (intracellular.CurrentInjection & cell):
            current_injection, ci_start_time, ci_fs = (
                intracellular.CurrentInjection & cell).fetch1(
                    'current_injection', 'current_injection_start_time',
                    'current_injection_sampling_rate')
            nwbfile.add_stimulus(
                pynwb.icephys.CurrentClampStimulusSeries(
                    name='CurrentClampStimulus',
                    electrode=ic_electrode,
                    conversion=1e-9,
                    gain=1.0,
                    data=current_injection,
                    starting_time=ci_start_time,
                    rate=ci_fs))

        # analysis - membrane potential without spike
        mp_rmv_spike = nwbfile.create_processing_module(
            name='icephys', description='Spike removal')
        mp_rmv_spike.add_data_interface(
            pynwb.icephys.PatchClampSeries(name='icephys',
                                           electrode=ic_electrode,
                                           unit='mV',
                                           conversion=1e-3,
                                           gain=1.0,
                                           data=mp_wo_spike,
                                           starting_time=mp_start_time,
                                           rate=mp_fs))

    # =============== Extracellular ====================
    probe_insertion = ((extracellular.ProbeInsertion
                        & session_key).fetch1() if extracellular.ProbeInsertion
                       & session_key else None)
    if probe_insertion:
        probe = nwbfile.create_device(name=probe_insertion['probe_name'])
        electrode_group = nwbfile.create_electrode_group(
            name=f'{probe_insertion["probe_name"]}: {probe_insertion["channel_counts"]}',
            description='N/A',
            device=probe,
            location='; '.join([
                f'{k}: {str(v)}'
                for k, v in (reference.BrainLocation
                             & probe_insertion).fetch1().items()
            ]))

        for chn in (reference.Probe.Channel
                    & probe_insertion).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['channel_id'],
                group=electrode_group,
                filtering=hardware_filter,
                imp=-1.,
                x=0.0,  # not available from data
                y=0.0,  # not available from data
                z=0.0,  # not available from data
                location=electrode_group.location)

        # --- unit spike times ---
        nwbfile.add_unit_column(
            name='sampling_rate',
            description='Sampling rate of the raw voltage traces (Hz)')
        nwbfile.add_unit_column(name='depth',
                                description='depth of this unit (mm)')
        nwbfile.add_unit_column(name='spike_width',
                                description='spike width of this unit (ms)')
        nwbfile.add_unit_column(
            name='cell_type',
            description='cell type (e.g. wide width, narrow width spiking)')

        for unit in (extracellular.UnitSpikeTimes
                     & probe_insertion).fetch(as_dict=True):
            # make an electrode table region (which electrode(s) is this unit coming from)
            nwbfile.add_unit(
                id=unit['unit_id'],
                electrodes=(unit['channel_id'] if isinstance(
                    unit['channel_id'], np.ndarray) else [unit['channel_id']]),
                depth=unit['unit_depth'],
                sampling_rate=ecephys_fs,
                spike_width=unit['unit_spike_width'],
                cell_type=unit['unit_cell_type'],
                spike_times=unit['spike_times'],
                waveform_mean=unit['spike_waveform'])

    # =============== Behavior ====================
    # Note: for this study, raw behavioral data were not available, only trialized data were provided
    # here, we reconstruct raw behavioral data by concatenation
    trial_seg_setting = (analysis.TrialSegmentationSetting
                         & 'trial_seg_setting=0').fetch1()
    seg_behav_query = (
        behavior.TrialSegmentedLickTrace * acquisition.TrialSet.Trial *
        (analysis.RealignedEvent.RealignedEventTime
         & 'trial_event="trial_start"')
        & session_key & trial_seg_setting)

    if seg_behav_query:
        behav_acq = pynwb.behavior.BehavioralTimeSeries(name='lick_times')
        nwbfile.add_acquisition(behav_acq)
        seg_behav = pd.DataFrame(
            seg_behav_query.fetch('start_time', 'realigned_event_time',
                                  'segmented_lick_left_on',
                                  'segmented_lick_left_off',
                                  'segmented_lick_right_on',
                                  'segmented_lick_right_off')).T
        seg_behav.columns = [
            'start_time', 'realigned_event_time', 'segmented_lick_left_on',
            'segmented_lick_left_off', 'segmented_lick_right_on',
            'segmented_lick_right_off'
        ]
        for behav_name in [
                'lick_left_on', 'lick_left_off', 'lick_right_on',
                'lick_right_off'
        ]:
            lick_times = np.hstack([r['segmented_' + behav_name] -
                                    r.realigned_event_time + r.start_time
                                    for _, r in seg_behav.iterrows()])
            behav_acq.create_timeseries(name=behav_name,
                                        unit='a.u.',
                                        conversion=1.0,
                                        data=np.full_like(lick_times, 1),
                                        timestamps=lick_times)

    # =============== Photostimulation ====================
    photostim = ((stimulation.PhotoStimulation
                  & session_key).fetch1() if stimulation.PhotoStimulation
                 & session_key else None)
    if photostim:
        photostim_device = (stimulation.PhotoStimDevice & photostim).fetch1()
        stim_device = nwbfile.create_device(
            name=photostim_device['device_name'])
        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name='-'.join([photostim['hemisphere'],
                           photostim['brain_region']]),
            device=stim_device,
            excitation_lambda=float(
                (stimulation.PhotoStimProtocol
                 & photostim).fetch1('photo_stim_excitation_lambda')),
            location='; '.join([
                f'{k}: {str(v)}' for k, v in (reference.ActionLocation
                                              & photostim).fetch1().items()
            ]),
            description=(stimulation.PhotoStimProtocol
                         & photostim).fetch1('photo_stim_notes'))
        nwbfile.add_ogen_site(stim_site)

        if photostim['photostim_timeseries'] is not None:
            nwbfile.add_stimulus(
                pynwb.ogen.OptogeneticSeries(
                    name='_'.join([
                        'photostim_on',
                        photostim['photostim_datetime'].strftime(
                            '%Y-%m-%d_%H-%M-%S')
                    ]),
                    site=stim_site,
                    resolution=0.0,
                    conversion=1e-3,
                    data=photostim['photostim_timeseries'],
                    starting_time=photostim['photostim_start_time'],
                    rate=photostim['photostim_sampling_rate']))

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes:
    #                                                                       'id', 'start_time' and 'stop_time'.
    # Other trial-related information needs to be added in to the trial-table as additional columns (with column name
    # and column description)
    if acquisition.TrialSet & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo - remove 'trial_' prefix (if any)
        trial_columns = [
            {
                'name':
                tag.replace('trial_', ''),
                'description':
                re.search(
                    f'(?<={tag})(.*)#(.*)',
                    str((acquisition.TrialSet.Trial *
                         stimulation.TrialPhotoStimParam
                         ).heading)).groups()[-1].strip()
            } for tag in (acquisition.TrialSet.Trial *
                          stimulation.TrialPhotoStimParam).heading.names
            if tag not in (acquisition.TrialSet.Trial
                           & stimulation.TrialPhotoStimParam).primary_key +
            ['start_time', 'stop_time']
        ]

        # Trial Events - discard 'trial_start' and 'trial_stop' as we already have start_time and stop_time
        # also add `_time` suffix to all events
        trial_events = set(((acquisition.TrialSet.EventTime & session_key) -
                            [{
                                'trial_event': 'trial_start'
                            }, {
                                'trial_event': 'trial_stop'
                            }]).fetch('trial_event'))
        event_names = [{
            'name': e + '_time',
            'description': d + ' - (s) relative to trial start time'
        } for e, d in zip(*(reference.ExperimentalEvent & [{
            'event': k
        } for k in trial_events]).fetch('event', 'description'))]
        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns + event_names:
            nwbfile.add_trial_column(**c)

        photostim_tag_default = {
            tag: ''
            for tag in stimulation.TrialPhotoStimParam.heading.names
            if tag not in stimulation.TrialPhotoStimParam.primary_key
        }

        # Add entry to the trial-table
        for trial in (acquisition.TrialSet.Trial
                      & session_key).fetch(as_dict=True):
            events = dict(
                zip(*(acquisition.TrialSet.EventTime & trial
                      & [{
                          'trial_event': e
                      } for e in trial_events]
                      ).fetch('trial_event', 'event_time')))

            trial_tag_value = ({
                **trial,
                **events,
                **(stimulation.TrialPhotoStimParam & trial).fetch1()
            } if (stimulation.TrialPhotoStimParam & trial) else {
                **trial,
                **events,
                **photostim_tag_default
            })

            trial_tag_value['id'] = trial_tag_value[
                'trial_id']  # rename 'trial_id' to 'id'
            for k in acquisition.TrialSet.Trial.primary_key:
                trial_tag_value.pop(k)

            # convert None to np.nan since NWB fields do not take None
            for k, v in trial_tag_value.items():
                trial_tag_value[k] = v if v is not None else np.nan

            trial_tag_value['delay_duration'] = float(
                trial_tag_value['delay_duration'])  # convert Decimal to float

            # Final tweaks: i) add '_time' suffix and ii) remove 'trial_' prefix
            events = {k + '_time': trial_tag_value.pop(k) for k in events}
            trial_attrs = {
                k.replace('trial_', ''): trial_tag_value.pop(k)
                for k in
                [n for n in trial_tag_value if n.startswith('trial_')]
            }

            nwbfile.add_trial(**trial_tag_value, **events, **trial_attrs)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        if not overwrite and os.path.exists(
                os.path.join(nwb_output_dir, save_file_name)):
            return nwbfile
        with NWBHDF5IO(os.path.join(nwb_output_dir, save_file_name),
                       mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
Example #17
class NWBFileTest(unittest.TestCase):
    def setUp(self):
        self.start = datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())
        self.ref_time = datetime(1979, 1, 1, 0, tzinfo=tzutc())
        self.create = [
            datetime(2017, 5, 1, 12, tzinfo=tzlocal()),
            datetime(2017, 5, 2, 13, 0, 0, 1, tzinfo=tzutc()),
            datetime(2017, 5, 2, 14, tzinfo=tzutc())
        ]
        self.path = 'nwbfile_test.h5'
        self.nwbfile = NWBFile(
            'a test session description for a test NWBFile',
            'FILE123',
            self.start,
            file_create_date=self.create,
            timestamps_reference_time=self.ref_time,
            experimenter='A test experimenter',
            lab='a test lab',
            institution='a test institution',
            experiment_description='a test experiment description',
            session_id='test1',
            notes='my notes',
            pharmacology='drugs',
            protocol='protocol',
            related_publications='my pubs',
            slices='my slices',
            surgery='surgery',
            virus='a virus',
            source_script='noscript',
            source_script_file_name='nofilename',
            stimulus_notes='test stimulus notes',
            data_collection='test data collection notes',
            keywords=('these', 'are', 'keywords'))

    def test_constructor(self):
        self.assertEqual(self.nwbfile.session_description,
                         'a test session description for a test NWBFile')
        self.assertEqual(self.nwbfile.identifier, 'FILE123')
        self.assertEqual(self.nwbfile.session_start_time, self.start)
        self.assertEqual(self.nwbfile.file_create_date, self.create)
        self.assertEqual(self.nwbfile.lab, 'a test lab')
        self.assertEqual(self.nwbfile.experimenter, 'A test experimenter')
        self.assertEqual(self.nwbfile.institution, 'a test institution')
        self.assertEqual(self.nwbfile.experiment_description,
                         'a test experiment description')
        self.assertEqual(self.nwbfile.session_id, 'test1')
        self.assertEqual(self.nwbfile.stimulus_notes, 'test stimulus notes')
        self.assertEqual(self.nwbfile.data_collection,
                         'test data collection notes')
        self.assertEqual(self.nwbfile.source_script, 'noscript')
        self.assertEqual(self.nwbfile.source_script_file_name, 'nofilename')
        self.assertEqual(self.nwbfile.keywords, ('these', 'are', 'keywords'))
        self.assertEqual(self.nwbfile.timestamps_reference_time, self.ref_time)

    def test_create_electrode_group(self):
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        d = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, d)
        self.assertEqual(elecgrp.description, desc)
        self.assertEqual(elecgrp.location, loc)
        self.assertIs(elecgrp.device, d)

    def test_create_electrode_group_invalid_index(self):
        """
        Test the case where the user creates an electrode table region with
        indexes that are out of range of the amount of electrodes added.
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i)
        with self.assertRaises(IndexError):
            nwbfile.create_electrode_table_region(list(range(6)), 'test')

    def test_access_group_after_io(self):
        """
        Motivated by #739
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        nwbfile.add_electrode(np.nan,
                              np.nan,
                              np.nan,
                              np.nan,
                              'a',
                              'a',
                              elecgrp,
                              id=0)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i + 1)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        os.remove("electrodes_mwe.nwb")

    def test_epoch_tags(self):
        tags1 = ['t1', 't2']
        tags2 = ['t3', 't4']
        tstamps = np.arange(1.0, 100.0, 0.1, dtype=float)
        ts = TimeSeries("test_ts",
                        list(range(len(tstamps))),
                        'unit',
                        timestamps=tstamps)
        expected_tags = tags1 + tags2
        self.nwbfile.add_epoch(0.0, 1.0, tags1, ts)
        self.nwbfile.add_epoch(0.0, 1.0, tags2, ts)
        tags = self.nwbfile.epoch_tags
        six.assertCountEqual(self, expected_tags, tags)

    def test_add_acquisition(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.acquisition), 1)

    def test_add_stimulus(self):
        self.nwbfile.add_stimulus(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus), 1)

    def test_add_stimulus_template(self):
        self.nwbfile.add_stimulus_template(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus_template), 1)

    def test_add_acquisition_check_dups(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        with self.assertRaises(ValueError):
            self.nwbfile.add_acquisition(
                TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                           'grams',
                           timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))

    def test_get_acquisition_empty(self):
        with self.assertRaisesRegex(ValueError,
                                    "acquisition of NWBFile 'root' is empty"):
            self.nwbfile.get_acquisition()

    def test_get_acquisition_multiple_elements(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "more than one element in acquisition of NWBFile 'root' -- must specify a name"
        with self.assertRaisesRegex(ValueError, msg):
            self.nwbfile.get_acquisition()

    def test_add_acquisition_invalid_name(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "'TEST_TS' not found in acquisition of NWBFile 'root'"
        with self.assertRaisesRegex(KeyError, msg):
            self.nwbfile.get_acquisition("TEST_TS")

    def test_set_electrode_table(self):
        table = ElectrodeTable()  # noqa: F405
        dev1 = self.nwbfile.create_device('dev1')  # noqa: F405
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-1.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-2.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-3.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-4.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        self.nwbfile.set_electrode_table(table)
        self.assertIs(self.nwbfile.electrodes, table)
        self.assertIs(table.parent, self.nwbfile)

    def test_add_unit_column(self):
        self.nwbfile.add_unit_column('unit_type', 'the type of unit')
        self.assertEqual(self.nwbfile.units.colnames, ('unit_type', ))

    def test_add_unit(self):
        self.nwbfile.add_unit(id=1)
        self.assertEqual(len(self.nwbfile.units), 1)
        self.nwbfile.add_unit(id=2)
        self.nwbfile.add_unit(id=3)
        self.assertEqual(len(self.nwbfile.units), 3)

    def test_add_trial_column(self):
        self.nwbfile.add_trial_column('trial_type', 'the type of trial')
        self.assertEqual(self.nwbfile.trials.colnames,
                         ('start_time', 'stop_time', 'trial_type'))

    def test_add_trial(self):
        self.nwbfile.add_trial(start_time=10.0, stop_time=20.0)
        self.assertEqual(len(self.nwbfile.trials), 1)
        self.nwbfile.add_trial(start_time=30.0, stop_time=40.0)
        self.nwbfile.add_trial(start_time=50.0, stop_time=70.0)
        self.assertEqual(len(self.nwbfile.trials), 3)

    def test_add_invalid_times_column(self):
        self.nwbfile.add_invalid_times_column(
            'comments', 'description of reason for omitting time')
        self.assertEqual(self.nwbfile.invalid_times.colnames,
                         ('start_time', 'stop_time', 'comments'))

    def test_add_invalid_time_interval(self):

        self.nwbfile.add_invalid_time_interval(start_time=0.0, stop_time=12.0)
        self.assertEqual(len(self.nwbfile.invalid_times), 1)
        self.nwbfile.add_invalid_time_interval(start_time=15.0, stop_time=16.0)
        self.nwbfile.add_invalid_time_interval(start_time=17.0, stop_time=20.5)
        self.assertEqual(len(self.nwbfile.invalid_times), 3)

    def test_add_invalid_time_w_ts(self):
        ts = TimeSeries(name='name', data=[1.2], rate=1.0, unit='na')
        self.nwbfile.add_invalid_time_interval(start_time=18.0,
                                               stop_time=20.6,
                                               timeseries=ts,
                                               tags=('hi', 'there'))

    def test_add_electrode(self):
        dev1 = self.nwbfile.create_device('dev1')  # noqa: F405
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        self.nwbfile.add_electrode(1.0,
                                   2.0,
                                   3.0,
                                   -1.0,
                                   'CA1',
                                   'none',
                                   group=group,
                                   id=1)
        self.assertEqual(self.nwbfile.ec_electrodes[0][0], 1)
        self.assertEqual(self.nwbfile.ec_electrodes[0][1], 1.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][2], 2.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][3], 3.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][4], -1.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][5], 'CA1')
        self.assertEqual(self.nwbfile.ec_electrodes[0][6], 'none')
        self.assertEqual(self.nwbfile.ec_electrodes[0][7], group)

    def test_all_children(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        device = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, device)
        children = self.nwbfile.all_children()
        self.assertIn(ts1, children)
        self.assertIn(ts2, children)
        self.assertIn(device, children)
        self.assertIn(elecgrp, children)

    def test_fail_if_source_script_file_name_without_source_script(self):
        with self.assertRaises(ValueError):
            # <-- source_script_file_name without source_script is not allowed
            NWBFile('a test session description for a test NWBFile',
                    'FILE123',
                    self.start,
                    source_script=None,
                    source_script_file_name='nofilename')

    def test_get_neurodata_type(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        p1 = ts1.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p1, self.nwbfile)
        p2 = ts2.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p2, self.nwbfile)
Example #18
from pynwb import NWBFile
from pynwb import NWBHDF5IO
import demoHelper

# Create the NWB file
nwb = NWBFile(...)
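# (NWBFile minimally requires a session_description, an identifier, and a
#  session_start_time; the remaining arguments are elided in this demo.)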

#session: 'TWH103_111918'
NOID = 144

#Get all Spike Times
allSpikeTimes = demoHelper.getSpikeTimesAllCells(...)

for cluster in allSpikeTimes.keys():
    # Add spike times to Units Table
    nwb.add_unit(id=int(cluster), spike_times=allSpikeTimes[cluster][0])

#Add Trial(s) information (Learning)
events_learn_stim_on = [4.218311e+09, 4.222039e+09, ...]
events_learn_stim_off = [4.219308e+09, 4.223034e+09, ...]
...

for i in range(len(events_learn_stim_on)):
    nwb.add_trial(stim_on=events_learn_stim_on[i],
                  stim_off=events_learn_stim_off[i], ...)

#Add Trial(s) information (Recognition)
events_recog_stim_on = [5.222291e+09, 5.229686e+09, ...]
events_recog_stim_off = [5.223292e+09, 5.230683e+09, ...]
...
Example #19
import datetime
from collections import defaultdict

import numpy as np
import pandas as pd
import pynwb
from pynwb import NWBFile

# NOTE: `sample_spikes` and `get_modelpath` are assumed to be defined elsewhere
# in the surrounding module (a sketch of `sample_spikes` follows this example).


def run_network(model, env, run_trial, num_trial=1000, file=None):
    """Run trained networks for analysis on trial-based tasks.

    Args:
        model: model of arbitrary format, must provide a run_one_trial function
            that works with it
        env: neurogym environment
        run_trial: function handle for running model for one trial,
            takes (model, env) as inputs and 
            returns (model, env, activity, trial_info), where activity has 
            shape (N_time, N_unit)
        num_trial: int, number of trials to run
        file: str or None, file name to save

    Returns:
        activity: a list of activity matrices, each of shape (N_time, N_unit)
        info: pandas dataframe, each row is information of a trial
        config: dict of network and training configurations
    """
    env.reset(no_step=True)

    # Make NWB file
    nwbfile = NWBFile(
        session_description=str(env),  # required
        identifier='NWB_default',  # required
        session_start_time=datetime.datetime.now(),  # required
        file_create_date=datetime.datetime.now())

    info = pd.DataFrame()

    spike_times = defaultdict(list)
    start_time = 0.
    for i in range(num_trial):
        model, env, hidden, trial_info = run_trial(model, env)

        # Log trial info
        for key, val in env.start_t.items():
            # NWB time default unit is second, ngym default is ms
            trial_info['start_' + key] = val / 1000. + start_time
        for key, val in env.end_t.items():
            trial_info['end_' + key] = val / 1000. + start_time

        info = pd.concat([info, pd.DataFrame([trial_info])], ignore_index=True)

        # Store results to NWB file
        if i == 0:
            for key in trial_info.keys():
                nwbfile.add_trial_column(name=key, description=key)

        stop_time = start_time + hidden.shape[0] * env.dt / 1000.

        # Generate simulated spikes from rates
        scale_rate = 10.
        for j in range(hidden.shape[-1]):
            spikes = sample_spikes(hidden[:, j] * scale_rate,
                                   dt=env.dt / 1000.) + start_time
            spike_times[j].append(spikes)

        nwbfile.add_trial(start_time=start_time,
                          stop_time=stop_time,
                          **trial_info)
        start_time = stop_time  # Assuming continuous trials

    try:
        print('Average performance', np.mean(info['correct']))
    except KeyError:  # not every task reports a 'correct' field
        pass

    for j in range(hidden.shape[-1]):  # For each neuron
        nwbfile.add_unit(id=j, spike_times=np.concatenate(spike_times[j]))
    # TODO: Check why the file.units['spike_times'] is weird

    if file is None:
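        # NOTE: `envid` is not a parameter of this function; it is assumed to
        # be defined in the enclosing module scope.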
        file = str(get_modelpath(envid) / (envid + '.nwb'))
    with pynwb.NWBHDF5IO(file, 'w') as io:
        io.write(nwbfile)
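
# `sample_spikes` is not shown in this example. A minimal sketch of what it
# could look like, assuming a Poisson spiking model (the details below are
# assumptions, not the original implementation):

def sample_spikes(rate, dt):
    """Sample spike times (s) from a firing-rate trace (Hz) binned at dt (s)."""
    rate = np.asarray(rate)
    # One Bernoulli draw per time bin approximates a Poisson process when rate * dt << 1
    spike_bins = np.random.rand(len(rate)) < rate * dt
    return np.nonzero(spike_bins)[0] * dt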
Example #20
import os
import configparser
from datetime import datetime

import numpy as np
from pynwb import NWBFile
from pynwb.file import Subject
from pynwb.misc import AnnotationSeries

# NOTE: `getEpilepsyDx` is assumed to be defined elsewhere in the surrounding
# module.


def no2nwb(NOData, session_use, subjects_ini, path_to_data):
    '''
    Purpose:
        Import the data and associated meta-data from the new/old recognition dataset into an
        NWB file. Each of the features of the dataset, such as the events (i.e., TTLs) or mean waveform, are
        compartmentalized to the appropriate component of the NWB file.
    '''

    # Time scaling (convert µs -> s for the NWB file)
    TIME_SCALING = 10**6

    # Prepare the NO data that will be converted to the NWB format

    session = NOData.sessions[session_use]
    events = NOData._get_event_data(session_use, experiment_type='All')
    cell_ids = NOData.ls_cells(session_use)
    experiment_id_learn = session['experiment_id_learn']
    experiment_id_recog = session['experiment_id_recog']
    task_descr = session['task_descr']

    # Get the metadata for the subject
    # ============ Read config file ============
    # load config file (subjects == config file)

    #  Check config file path
    filename = subjects_ini
    if not os.path.exists(filename):
        print('This file does not exist: {}'.format(filename))
        print("Check the filename and/or directory")

    # Read the config file
    try:
        # initialize the ConfigParser() class
        config = configparser.ConfigParser()
        # read .ini file
        config.read(filename)
    except configparser.Error:
        print('Failed to read the config file.')
        print('Does this file exist: {}'.format(os.path.exists(filename)))

    #  Read Meta-data from INI file.
    for section in config.sections():
        if session_use == int(section):
            session_id = int(section)  #  The New/Old ID for the session
            #Get the session ID
            for value in config[section]:
                if value.lower() == 'nosessions.age':
                    age = int(config[section][value])
                if value.lower() == 'nosessions.diagnosiscode':
                    epilepsyDxCode = config[section][value]
                    epilepsyDx = getEpilepsyDx(int(epilepsyDxCode))
                if value.lower() == 'nosessions.sex':
                    sex = config[section][value].strip("'")
                if value.lower() == 'nosessions.id':
                    ID = config[section][value].strip("'")
                if value.lower() == 'nosessions.session':
                    pt_session = config[section][value].strip("'")
                if value.lower() == 'nosessions.date':
                    unformattedDate = config[section][value].strip("'")
                    date = datetime.strptime(unformattedDate, '%Y-%m-%d')
                    finaldate = date.replace(hour=0, minute=0)
                if value.lower() == 'nosessions.institution':
                    institution = config[section][value].strip("'")
                if value.lower() == 'nosessions.la':
                    LA = config[section][value].strip("'").split(',')
                    if LA[0] == 'NaN':
                        LA_x = np.nan
                        LA_y = np.nan
                        LA_z = np.nan
                    else:
                        LA_x = float(LA[0])
                        LA_y = float(LA[1])
                        LA_z = float(LA[2])
                if value.lower() == 'nosessions.ra':
                    RA = config[section][value].strip("'").split(',')
                    if RA[0] == 'NaN':
                        RA_x = np.nan
                        RA_y = np.nan
                        RA_z = np.nan
                    else:
                        RA_x = float(RA[0])
                        RA_y = float(RA[1])
                        RA_z = float(RA[2])
                if value.lower() == 'nosessions.lh':
                    LH = config[section][value].strip("'").split(',')
                    if LH[0] == 'NaN':
                        LH_x = np.nan
                        LH_y = np.nan
                        LH_z = np.nan
                    else:
                        LH_x = float(LH[0])
                        LH_y = float(LH[1])
                        LH_z = float(LH[2])
                if value.lower() == 'nosessions.rh':
                    RH = config[section][value].strip("'").split(',')
                    if RH[0] == 'NaN':
                        RH_x = np.nan
                        RH_y = np.nan
                        RH_z = np.nan
                    else:
                        RH_x = float(RH[0])
                        RH_y = float(RH[1])
                        RH_z = float(RH[2])
                if value.lower() == 'nosessions.system':
                    signalSystem = config[section][value].strip("'")

    # =================================================================

    print('=======================================================================')
    print('session use: {}'.format(session_id))
    print('age: {}'.format(age))
    print('epilepsy_diagnosis: {}'.format(epilepsyDx))

    nwb_subject = Subject(age=str(age),
                          description=epilepsyDx,
                          sex=sex,
                          species='Human',
                          subject_id=pt_session[:pt_session.find('_')])

    # Create the NWB file
    nwbfile = NWBFile(
        #source='https://datadryad.org/bitstream/handle/10255/dryad.163179/RecogMemory_MTL_release_v2.zip',
        session_description='New/Old recognition task for ID: {}. '.format(
            session_id),
        identifier='{}_{}'.format(ID, session_use),
        session_start_time=finaldate,  #default session start time
        file_create_date=datetime.now(),
        experiment_description=
        'The data contained within this file describes a new/old recognition task performed in '
        'patients with intractable epilepsy implanted with depth electrodes and Behnke-Fried '
        'microwires in the human Medial Temporal Lobe (MTL).',
        institution=institution,
        keywords=[
            'Intracranial Recordings', 'Intractable Epilepsy',
            'Single-Unit Recordings', 'Cognitive Neuroscience', 'Learning',
            'Memory', 'Neurosurgery'
        ],
        related_publications=
        'Faraut et al. 2018, Scientific Data; Rutishauser et al. 2015, Nat Neurosci;',
        lab='Rutishauser',
        subject=nwb_subject,
        data_collection='learning: {}, recognition: {}'.format(
            session['experiment_id_learn'], session['experiment_id_recog']))

    # Add events and experiment_id acquisition
    events_description = (
        """ The events coorespond to the TTL markers for each trial. For the learning trials, the TTL markers 
            are the following: 55 = start of the experiment, 1 = stimulus ON, 2 = stimulus OFF, 3 = Question Screen Onset [“Is this an animal?”], 
            20 = Yes (21 = NO) during learning, 6 = End of Delay after Response, 66 = End of Experiment. For the recognition trials, 
            the TTL markers are the following: 55 = start of experiment, 1 = stimulus ON, 2 = stimulus OFF, 3 = Question Screen Onset [“Have you seen this image before?”], 
            31:36 = Confidence (Yes vs. No) response [31 (new, confident), 32 (new, probably), 33 (new, guess), 34 (old, guess), 
            35 (old, probably), 36 (old, confident)], 66 = End of Experiment"""
    )

    event_ts = AnnotationSeries(name='events',
                                data=np.asarray(events[1].values).astype(str),
                                timestamps=np.asarray(events[0].values) /
                                TIME_SCALING,
                                description=events_description)

    experiment_ids_description = (
        """The experiment_ids coorespond to the encoding (i.e., learning) or recogniton trials. The learning trials are demarcated by: {}. The recognition trials are demarcated by: {}. """
        .format(experiment_id_learn, experiment_id_recog))

    experiment_ids = TimeSeries(name='experiment_ids',
                                unit='NA',
                                data=np.asarray(events[2]),
                                timestamps=np.asarray(events[0].values) /
                                TIME_SCALING,
                                description=experiment_ids_description)

    nwbfile.add_acquisition(event_ts)
    nwbfile.add_acquisition(experiment_ids)
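
    # Illustrative sketch (not part of the conversion itself): with the marker
    # scheme described above, the stored annotations could later be decoded
    # into human-readable labels, e.g.
    #
    #   marker_labels = {'55': 'experiment start', '1': 'stimulus on',
    #                    '2': 'stimulus off', '3': 'question screen',
    #                    '20': 'response: yes', '21': 'response: no',
    #                    '6': 'end of delay', '66': 'experiment end'}
    #   labels = [marker_labels.get(m, 'unknown') for m in event_ts.data]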

    # Add stimuli to the NWB file
    # Get the first cell from the cell list
    cell = NOData.pop_cell(session_use,
                           NOData.ls_cells(session_use)[0], path_to_data)
    trials = cell.trials
    stimuli_recog_path = [trial.file_path_recog for trial in trials]
    stimuli_learn_path = [trial.file_path_learn for trial in trials]

    # Add epochs and trials: storing start and end times for a stimulus

    # First extract the category ids and names that we need
    # The metadata for each trial will be stored in the trial table

    cat_id_recog = [trial.category_recog for trial in trials]
    cat_name_recog = [trial.category_name_recog for trial in trials]
    cat_id_learn = [trial.category_learn for trial in trials]
    cat_name_learn = [trial.category_name_learn for trial in trials]

    # Extract the event timestamps
    events_learn_stim_on = events[(events[2] == experiment_id_learn) &
                                  (events[1] == NOData.markers['stimulus_on'])]
    events_learn_stim_off = events[(events[2] == experiment_id_learn) & (
        events[1] == NOData.markers['stimulus_off'])]
    events_learn_delay1_off = events[(events[2] == experiment_id_learn) & (
        events[1] == NOData.markers['delay1_off'])]
    events_learn_delay2_off = events[(events[2] == experiment_id_learn) & (
        events[1] == NOData.markers['delay2_off'])]
    events_learn = events[(events[2] == experiment_id_learn)]
    events_learn_response = []
    events_learn_response_time = []
    for i in range(len(events_learn[0])):
        if (events_learn.iloc[i, 1]
                == NOData.markers['response_learning_animal']) or (
                    events_learn.iloc[i, 1]
                    == NOData.markers['response_learning_non_animal']):
            events_learn_response.append(events_learn.iloc[i, 1] - 20)
            events_learn_response_time.append(events_learn.iloc[i, 0])

    events_recog_stim_on = events[(events[2] == experiment_id_recog) &
                                  (events[1] == NOData.markers['stimulus_on'])]
    events_recog_stim_off = events[(events[2] == experiment_id_recog) & (
        events[1] == NOData.markers['stimulus_off'])]
    events_recog_delay1_off = events[(events[2] == experiment_id_recog) & (
        events[1] == NOData.markers['delay1_off'])]
    events_recog_delay2_off = events[(events[2] == experiment_id_recog) & (
        events[1] == NOData.markers['delay2_off'])]
    events_recog = events[(events[2] == experiment_id_recog)]
    events_recog_response = []
    events_recog_response_time = []
    for i in range(len(events_recog[0])):
        if ((events_recog.iloc[i, 1] == NOData.markers['response_1'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_2'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_3'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_4'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_5'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_6'])):
            events_recog_response.append(events_recog.iloc[i, 1])
            events_recog_response_time.append(events_recog.iloc[i, 0])

    # Extract new_old label
    new_old_recog = [trial.new_old_recog for trial in trials]
    # Create the trial tables

    nwbfile.add_trial_column('stim_on_time',
                             'The Time when the Stimulus is Shown')
    nwbfile.add_trial_column('stim_off_time',
                             'The Time when the Stimulus is Off')
    nwbfile.add_trial_column('delay1_time', 'The Time when Delay1 is Off')
    nwbfile.add_trial_column('delay2_time', 'The Time when Delay2 is Off')
    nwbfile.add_trial_column('stim_phase',
                             'Learning/Recognition Phase During the Trial')
    nwbfile.add_trial_column('stimCategory', 'The Category ID of the Stimulus')
    nwbfile.add_trial_column('category_name',
                             'The Category Name of the Stimulus')
    nwbfile.add_trial_column('external_image_file',
                             'The File Path to the Stimulus')
    nwbfile.add_trial_column(
        'new_old_labels_recog',
        'The ground-truth labels for new vs. old stimuli: 0 = old stimuli '
        '(presented during the learning phase), 1 = new stimuli (not seen '
        'during the learning phase)')
    nwbfile.add_trial_column('response_value',
                             'The Response for Each Stimulus')
    nwbfile.add_trial_column('response_time',
                             'The Response Time for each Stimulus')

    range_recog = np.amin([
        len(events_recog_stim_on),
        len(events_recog_stim_off),
        len(events_recog_delay1_off),
        len(events_recog_delay2_off)
    ])
    range_learn = np.amin([
        len(events_learn_stim_on),
        len(events_learn_stim_off),
        len(events_learn_delay1_off),
        len(events_learn_delay2_off)
    ])

    # Iterate the event list and add information into each epoch and trial table
    for i in range(range_learn):

        nwbfile.add_trial(
            start_time=(events_learn_stim_on.iloc[i][0]) / (TIME_SCALING),
            stop_time=(events_learn_delay2_off.iloc[i][0]) / (TIME_SCALING),
            stim_on_time=(events_learn_stim_on.iloc[i][0]) / (TIME_SCALING),
            stim_off_time=(events_learn_stim_off.iloc[i][0]) / (TIME_SCALING),
            delay1_time=(events_learn_delay1_off.iloc[i][0]) / (TIME_SCALING),
            delay2_time=(events_learn_delay2_off.iloc[i][0]) / (TIME_SCALING),
            stim_phase='learn',
            stimCategory=cat_id_learn[i],
            category_name=cat_name_learn[i],
            external_image_file=stimuli_learn_path[i],
            new_old_labels_recog='NA',
            response_value=events_learn_response[i],
            response_time=(events_learn_response_time[i]) / (TIME_SCALING))

    for i in range(range_recog):

        nwbfile.add_trial(
            start_time=events_recog_stim_on.iloc[i][0] / (TIME_SCALING),
            stop_time=events_recog_delay2_off.iloc[i][0] / (TIME_SCALING),
            stim_on_time=events_recog_stim_on.iloc[i][0] / (TIME_SCALING),
            stim_off_time=events_recog_stim_off.iloc[i][0] / (TIME_SCALING),
            delay1_time=events_recog_delay1_off.iloc[i][0] / (TIME_SCALING),
            delay2_time=events_recog_delay2_off.iloc[i][0] / (TIME_SCALING),
            stim_phase='recog',
            stimCategory=cat_id_recog[i],
            category_name=cat_name_recog[i],
            external_image_file=stimuli_recog_path[i],
            new_old_labels_recog=new_old_recog[i],
            response_value=events_recog_response[i],
            response_time=events_recog_response_time[i] / (TIME_SCALING))

    # Add the waveform clustering and the spike data.
    # Get the unique channel ids that we will iterate over
    channel_ids = np.unique([cell_id[0] for cell_id in cell_ids])

    # unique unit id
    unit_id = 0

    # Create unit columns
    nwbfile.add_unit_column('origClusterID', 'The original cluster id')
    nwbfile.add_unit_column('waveform_mean_encoding',
                            'The mean waveform for encoding phase.')
    nwbfile.add_unit_column('waveform_mean_recognition',
                            'The mean waveform for the recognition phase.')
    nwbfile.add_unit_column('IsolationDist', 'IsolDist')
    nwbfile.add_unit_column('SNR', 'SNR')
    nwbfile.add_unit_column('waveform_mean_sampling_rate',
                            'The Sampling Rate of Waveform')

    # Add stimuli: load, resize, and collect the presented images
    stimuli_presentation = []

    # Stimuli shown during the learning phase
    for path in stimuli_learn_path:
        if path == 'NA':
            continue
        folders = path.split('\\')
        path = os.path.join(path_to_data, 'Stimuli', folders[0], folders[1],
                            folders[2])
        img = cv2.imread(path)
        resized_image = cv2.resize(img, (300, 400))
        stimuli_presentation.append(resized_image)

    # Stimuli shown during the recognition phase
    for path in stimuli_recog_path:
        folders = path.split('\\')
        path = os.path.join(path_to_data, 'Stimuli', folders[0], folders[1],
                            folders[2])
        img = cv2.imread(path)
        resized_image = cv2.resize(img, (300, 400))
        stimuli_presentation.append(resized_image)

    # Add stimuli to OpticalSeries
    stimulus_presentation_on_time = []

    for n in range(0, len(events_learn_stim_on)):
        stimulus_presentation_on_time.append(events_learn_stim_on.iloc[n][0] /
                                             (TIME_SCALING))

    for n in range(0, len(events_recog_stim_on)):
        stimulus_presentation_on_time.append(events_recog_stim_on.iloc[n][0] /
                                             (TIME_SCALING))

    name = 'StimulusPresentation'
    stimulus = OpticalSeries(name=name,
                             data=stimuli_presentation,
                             timestamps=stimulus_presentation_on_time[:],
                             orientation='lower left',
                             format='raw',
                             unit='meters',
                             field_of_view=[.2, .3, .7],
                             distance=0.7,
                             dimension=[300, 400, 3])

    nwbfile.add_stimulus(stimulus)

    # Get Unit data
    all_spike_cluster_ids = []
    all_selected_time_stamps = []
    all_IsolDist = []
    all_SNR = []
    all_selected_mean_waveform_learn = []
    all_selected_mean_waveform_recog = []
    all_mean_waveform = []
    all_channel_id = []
    all_oriClusterIDs = []
    all_channel_numbers = []
    all_brain_area = []
    # Iterate the channel list

    # load brain area file
    brain_area_file_path = os.path.join(path_to_data, 'Data', 'events',
                                        session['session'], task_descr,
                                        'brainArea.mat')

    try:
        brain_area_mat = loadmat(brain_area_file_path)
    except FileNotFoundError:
        print("brain_area_mat file not found")

    for channel_id in channel_ids:
        cell_name = 'A' + str(channel_id) + '_cells.mat'
        cell_file_path = os.path.join(path_to_data, 'Data', 'sorted',
                                      session['session'], task_descr,
                                      cell_name)

        try:
            cell_mat = loadmat(cell_file_path)
        except FileNotFoundError:
            print("cell mat file not found")
            continue

        spikes = cell_mat['spikes']
        meanWaveform_recog = cell_mat['meanWaveform_recog']
        meanWaveform_learn = cell_mat['meanWaveform_learn']
        IsolDist_SNR = cell_mat['IsolDist_SNR']

        spike_cluster_id = np.asarray([spike[1] for spike in spikes
                                       ])  # Each Cluster ID of the spike
        spike_timestamps = (np.asarray([spike[2] for spike in spikes])) / (
            TIME_SCALING)  # Timestamps of spikes for each ClusterID
        unique_cluster_ids = np.unique(spike_cluster_id)

        # Iterate over each cluster detected on this channel
        for id in unique_cluster_ids:

            # Grab brain area
            brain_area = extra_brain_area(brain_area_mat, channel_id)

            selected_spike_timestamps = spike_timestamps[spike_cluster_id ==
                                                         id]
            IsolDist, SNR = extract_IsolDist_SNR_by_cluster_id(
                IsolDist_SNR, id)
            selected_mean_waveform_learn = extra_mean_waveform(
                meanWaveform_learn, id)
            selected_mean_waveform_recog = extra_mean_waveform(
                meanWaveform_recog, id)

            # If the mean waveform does not have 256 elements, set the mean waveform to all zeros
            if len(selected_mean_waveform_learn) != 256:
                selected_mean_waveform_learn = np.zeros(256)
            if len(selected_mean_waveform_recog) != 256:
                selected_mean_waveform_recog = np.zeros(256)

            mean_waveform = np.hstack(
                [selected_mean_waveform_learn, selected_mean_waveform_recog])

            # Append unit data
            all_spike_cluster_ids.append(id)
            all_selected_time_stamps.append(selected_spike_timestamps)
            all_IsolDist.append(IsolDist)
            all_SNR.append(SNR)
            all_selected_mean_waveform_learn.append(
                selected_mean_waveform_learn)
            all_selected_mean_waveform_recog.append(
                selected_mean_waveform_recog)
            all_mean_waveform.append(mean_waveform)
            all_channel_id.append(channel_id)
            all_oriClusterIDs.append(int(id))
            all_channel_numbers.append(channel_id)
            all_brain_area.append(brain_area)

            unit_id += 1

    nwbfile.add_electrode_column(
        name='origChannel',
        description='The original channel ID for the channel')

    # Add Device
    device = nwbfile.create_device(name=signalSystem)

    # Add electrodes (brain area locations, MNI coordinates for microwires).
    # All four site codes share the same group/electrode logic, so map each
    # code to its full location name and coordinates.
    brain_area_map = {
        'RH': ('Right Hippocampus', (RH_x, RH_y, RH_z)),
        'LH': ('Left Hippocampus', (LH_x, LH_y, LH_z)),
        'RA': ('Right Amygdala', (RA_x, RA_y, RA_z)),
        'LA': ('Left Amygdala', (LA_x, LA_y, LA_z)),
    }

    length_all_spike_cluster_ids = len(all_spike_cluster_ids)
    for electrodeNumber in range(len(channel_ids)):

        brainArea_location = extra_brain_area(brain_area_mat,
                                              channel_ids[electrodeNumber])
        if brainArea_location not in brain_area_map:
            continue
        full_brainArea_Location, (x, y, z) = brain_area_map[brainArea_location]

        electrode_name = '{}-microwires-{}'.format(
            signalSystem, channel_ids[electrodeNumber])
        description = ('Behnke Fried/Micro Inner Wire Bundle (Behnke-Fried '
                       'BF08R-SP05X-000 and WB09R-SP00X-0B6; Ad-Tech Medical)')

        # Add electrode group
        electrode_group = nwbfile.create_electrode_group(
            electrode_name,
            description=description,
            location=full_brainArea_Location,
            device=device)

        # Add electrode; index per channel (channel_ids), not per unit
        # (all_channel_id), so each channel gets exactly one table row
        nwbfile.add_electrode([channel_ids[electrodeNumber]],
                              x=x,
                              y=y,
                              z=z,
                              imp=np.nan,
                              location=full_brainArea_Location,
                              filtering='300-3000Hz',
                              group=electrode_group,
                              origChannel=channel_ids[electrodeNumber])

    # Map each unit to the electrode-table row index of its channel
    channel_list = list(range(length_all_spike_cluster_ids))
    unique_channel_ids = np.unique(all_channel_id)
    for yy in range(len(unique_channel_ids)):
        for i in np.where(unique_channel_ids[yy] == np.asarray(all_channel_id))[0]:
            channel_list[i] = yy

    # Waveform mean sampling rate (98.4 kHz), repeated once per unit
    waveform_mean_sampling_rate = [98.4 * 10**3]
    waveform_mean_sampling_rate_matrix = (
        [waveform_mean_sampling_rate] * length_all_spike_cluster_ids)

    # Add Units to NWB file
    for index_id in range(0, length_all_spike_cluster_ids):
        nwbfile.add_unit(
            id=index_id,
            spike_times=all_selected_time_stamps[index_id],
            origClusterID=all_oriClusterIDs[index_id],
            IsolationDist=all_IsolDist[index_id],
            SNR=all_SNR[index_id],
            waveform_mean_encoding=all_selected_mean_waveform_learn[index_id],
            waveform_mean_recognition=all_selected_mean_waveform_recog[
                index_id],
            electrodes=[channel_list[index_id]],
            waveform_mean_sampling_rate=waveform_mean_sampling_rate_matrix[
                index_id])

    return nwbfile
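
# The function above builds and returns the NWBFile without writing it to
# disk. A minimal, hypothetical persistence sketch with pynwb (the output
# file name is a placeholder, and no2nwb stands in for whatever name the
# conversion function above was given):
from pynwb import NWBHDF5IO

nwbfile = no2nwb(NOData, session_use, subjects_ini, path_to_data)  # hypothetical call
with NWBHDF5IO('RecogMemory_session.nwb', mode='w') as io:
    io.write(nwbfile)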
Example #21
def export_to_nwb(session_key,
                  nwb_output_dir=default_nwb_output_dir,
                  save=False,
                  overwrite=True):
    this_session = (acquisition.Session & session_key).fetch1()

    identifier = '_'.join([
        this_session['subject_id'],
        this_session['session_time'].strftime('%Y-%m-%d'),
        str(this_session['session_id'])
    ])
    # =============== General ====================
    # -- NWB file - a NWB2.0 file for each session
    nwbfile = NWBFile(session_description=this_session['session_note'],
                      identifier=identifier,
                      session_start_time=this_session['session_time'],
                      file_create_date=datetime.now(tzlocal()),
                      experimenter='; '.join(
                          (acquisition.Session.Experimenter
                           & session_key).fetch('experimenter')),
                      institution=institution,
                      experiment_description=experiment_description,
                      related_publications=related_publications,
                      keywords=keywords)
    # -- subject
    subj = (subject.Subject & session_key).fetch1()
    nwbfile.subject = pynwb.file.Subject(
        subject_id=this_session['subject_id'],
        description=subj['subject_description'],
        sex=subj['sex'],
        species=subj['species'])

    # =============== Extracellular ====================
    probe_insertion = ((extracellular.ProbeInsertion
                        & session_key).fetch1() if extracellular.ProbeInsertion
                       & session_key else None)
    if probe_insertion:
        probe = nwbfile.create_device(name=probe_insertion['probe_name'])
        electrode_group = nwbfile.create_electrode_group(
            name=f'{probe_insertion["probe_name"]}: {probe_insertion["channel_counts"]}',
            description='N/A',
            device=probe,
            location='; '.join(
                f'{k}: {v}' for k, v in
                (reference.BrainLocation & probe_insertion).fetch1().items()))

        for chn in (reference.Probe.Channel
                    & probe_insertion).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['channel_id'],
                group=electrode_group,
                filtering=hardware_filter,
                imp=np.nan,
                x=np.nan,  # not available from data
                y=np.nan,  # not available from data
                z=np.nan,  # not available from data
                location=electrode_group.location)

        # --- unit spike times ---
        nwbfile.add_unit_column(name='depth',
                                description='depth of this unit (um)')
        nwbfile.add_unit_column(
            name='quality',
            description='quality of the spike-sorted unit '
                        '(e.g. excellent, good, poor, fair, etc.)')
        nwbfile.add_unit_column(
            name='cell_type', description='cell type (e.g. PTlower, PTupper)')

        for unit in (extracellular.UnitSpikeTimes
                     & probe_insertion).fetch(as_dict=True):
            # make an electrode table region (which electrode(s) is this unit coming from)
            unit_chn = unit['channel_id'] if isinstance(
                unit['channel_id'], np.ndarray) else [unit['channel_id']]

            nwbfile.add_unit(id=unit['unit_id'],
                             electrodes=np.where(
                                 np.in1d(np.array(nwbfile.electrodes.id.data),
                                         unit_chn))[0],
                             depth=unit['unit_depth'],
                             quality=unit['unit_quality'],
                             cell_type=unit['unit_cell_type'],
                             spike_times=unit['spike_times'])

    # =============== Behavior ====================
    behavior_data = ((behavior.LickTimes
                      & session_key).fetch1() if behavior.LickTimes
                     & session_key else None)
    if behavior_data:
        behav_acq = pynwb.behavior.BehavioralEvents(name='lick_times')
        nwbfile.add_acquisition(behav_acq)
        for k in behavior.LickTimes.primary_key:
            behavior_data.pop(k)
        for b_k, b_v in behavior_data.items():
            behav_acq.create_timeseries(name=b_k,
                                        unit='a.u.',
                                        conversion=1.0,
                                        data=np.full_like(b_v, 1).astype(bool),
                                        timestamps=b_v)

    # =============== TrialSet ====================
    # NWB 'trial' (a dynamic table) comes with three mandatory attributes by
    # default: 'id', 'start_time' and 'stop_time'. Any other trial-related
    # information needs to be added to the trial table as additional columns
    # (with a column name and a column description).

    # adjust trial event times to be relative to session's start time
    q_trial_event = (acquisition.TrialSet.EventTime *
                     acquisition.TrialSet.Trial.proj('start_time')).proj(
                         event_time='event_time + start_time')

    if acquisition.TrialSet & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo
        trial_columns = [{
            'name':
            tag.replace('trial_', ''),
            'description':
            re.search(
                f'(?<={tag})(.*)#(.*)',
                str(acquisition.TrialSet.Trial.heading)).groups()[-1].strip()
        } for tag in acquisition.TrialSet.Trial.heading.names
                         if tag not in acquisition.TrialSet.Trial.primary_key +
                         ['start_time', 'stop_time']]

        # Trial Events - discard 'trial_start' and 'trial_stop' as we already have start_time and stop_time
        # also add `_time` suffix to all events

        trial_events = set(((acquisition.TrialSet.EventTime & session_key) -
                            [{
                                'trial_event': 'trial_start'
                            }, {
                                'trial_event': 'trial_stop'
                            }]).fetch('trial_event'))
        event_names = [{
            'name': e + '_time',
            'description': d
        } for e, d in zip(*(reference.ExperimentalEvent & [{
            'event': k
        } for k in trial_events]).fetch('event', 'description'))]
        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns + event_names:
            nwbfile.add_trial_column(**c)

        # Add entry to the trial-table
        for trial in (acquisition.TrialSet.Trial
                      & session_key).fetch(as_dict=True):
            events = dict(
                zip(*(q_trial_event & trial & [{
                    'trial_event': e
                } for e in trial_events]).fetch('trial_event', 'event_time')))

            trial_tag_value = {
                **trial,
                **events, 'stop_time': np.nan
            }  # No stop_time available for this dataset

            trial_tag_value['id'] = trial_tag_value[
                'trial_id']  # rename 'trial_id' to 'id'
            # convert None to np.nan since NWB fields do not take None
            for k, v in trial_tag_value.items():
                trial_tag_value[k] = v if v is not None else np.nan
            for k in acquisition.TrialSet.Trial.primary_key:
                trial_tag_value.pop(k)

            # Final tweaks: i) add '_time' suffix and ii) remove 'trial_' prefix
            events = {k + '_time': trial_tag_value.pop(k) for k in events}
            trial_attrs = {
                k.replace('trial_', ''): trial_tag_value.pop(k)
                for k in
                [n for n in trial_tag_value if n.startswith('trial_')]
            }

            nwbfile.add_trial(**trial_tag_value, **events, **trial_attrs)

    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        if not overwrite and os.path.exists(
                os.path.join(nwb_output_dir, save_file_name)):
            return nwbfile
        with NWBHDF5IO(os.path.join(nwb_output_dir, save_file_name),
                       mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
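
# Hedged usage sketch for the converter above: session_key is a DataJoint
# restriction identifying one acquisition.Session; the key fields below are
# hypothetical placeholders.
session_key = {'subject_id': 'subject01', 'session_id': 1}
nwbfile = export_to_nwb(session_key, save=True, overwrite=False)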
Example #22
    description="Random numbers generated with numpy.random.rand")
nwbfile.add_acquisition(ephys_ts)

####################
# .. _units_electrode:
#
# Associate electrodes with units
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The :ref:`PyNWB Basics tutorial <basics>` demonstrates how to add data about units and specifying custom metadata
# about units. As mentioned :ref:`here <units_fields_ref>`, there are some optional fields for units, one of these
# is *electrodes*. This field takes a list of indices into the electrode table for the electrodes that the unit
# corresponds to. For example, if two units were inferred from the first electrode (*id* = 1, index = 0), you would
# specify that like so:

nwbfile.add_unit(id=1, electrodes=[0])
nwbfile.add_unit(id=2, electrodes=[0])

#######################
# Designating electrophysiology data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# As mentioned above, :py:class:`~pynwb.ecephys.ElectricalSeries` and :py:class:`~pynwb.ecephys.SpikeEventSeries`
# are meant for storing specific types of extracellular recordings. In addition to these two
# :py:class:`~pynwb.base.TimeSeries` classes, NWB provides some :ref:`data interfaces <basic_data_interfaces>`
# for designating the type of data you are storing. We will briefly discuss them here, and refer the reader to
# :py:mod:`API documentation <pynwb.ecephys>` and :ref:`PyNWB Basics tutorial <basics>` for more details on
# using these objects.
#
# For storing spike data, there are two options. Which one you choose depends on what data you have available.
# If you need to store the complete, continuous raw voltage traces, you should store the traces with
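
####################
# (A hedged sketch, not from the original tutorial: assuming an
# ``electrode_table_region`` was created earlier, as in the full ecephys
# tutorial, continuous raw traces could be wrapped in an
# :py:class:`~pynwb.ecephys.ElectricalSeries` like this.)

from pynwb.ecephys import ElectricalSeries

raw_data = np.random.rand(50, 4)  # 50 samples x 4 channels of toy data
raw_sketch_ts = ElectricalSeries(name='raw_sketch',
                                 data=raw_data,
                                 electrodes=electrode_table_region,
                                 starting_time=0.0,
                                 rate=10.0)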
Example #23
def create_nwb_file(Sess, start_time):

    sr = 30000  # sampling rate, 30 kHz
    if sys.platform == 'win32':
        SaveDir = os.path.join(r'C:\Users\slashchevam\Desktop\NPx\Results',
                               Sess)
        RawDataDir = r'C:\Users\slashchevam\Desktop\NPx'
        ExcelInfoPath = RawDataDir

        PathToUpload = os.path.join(RawDataDir, Sess)

    if sys.platform == 'linux':
        SaveDir = os.path.join('/mnt/gs/departmentN4/Marina/NPx/Results', Sess)
        RawDataDir = '/mnt/gs/projects/OWVinckNatIm/NPx_recordings/'
        PathToAnalyzed = '/experiment1/recording1/continuous/Neuropix-PXI-100.0/'
        MatlabOutput = '/mnt/gs/projects/OWVinckNatIm/NPx_processed/Lev0_condInfo/'
        ExcelInfoPath = '/mnt/gs/departmentN4/Marina/'

        PathToUpload = RawDataDir + Sess + PathToAnalyzed

        if not os.path.exists(SaveDir):
            os.makedirs(SaveDir)
        os.chdir(SaveDir)

    # Upload all the data
    spike_stamps = np.load(os.path.join(PathToUpload, "spike_times.npy"))
    spike_times = spike_stamps / sr
    spike_clusters = np.load(os.path.join(PathToUpload, "spike_clusters.npy"))
    cluster_group = pd.read_csv(os.path.join(PathToUpload,
                                             "cluster_group.tsv"),
                                sep="\t")
    cluster_info = pd.read_csv(os.path.join(PathToUpload, "cluster_info.tsv"),
                               sep="\t")

    if len(cluster_group) != len(cluster_info):
        print('Cluster group (manual labeling) and cluster info do not match!')

    excel_info = pd.read_excel(os.path.join(ExcelInfoPath,
                                            'Recordings_Marina_NPx.xlsx'),
                               sheet_name=Sess)

    # Select spikes from good clusters only
    # Have to add the depth of the clusters
    good_clus_info = cluster_info[cluster_info['group'] ==
                                  'good']  # has depth info
    good_clus = good_clus_info[['id', 'group']]
    print("Found", len(good_clus), ' good clusters')

    good_spikes_ind = [x in good_clus['id'].values for x in spike_clusters]
    spike_clus_good = spike_clusters[good_spikes_ind]
    spike_times_good = spike_times[good_spikes_ind]
    # spike_stamps_good = spike_stamps[good_spikes_ind]

    if excel_info['Area'][0] == 'V1':
        good_clus_info['area'] = 'V1'
    else:
        # Clusters in the top 1000 um of the covered span
        # (depth > max(depth) - 1000) are labeled V1, the rest HPC
        good_clus_info['area'] = good_clus_info['depth'] > np.max(
            good_clus_info['depth']) - 1000
        good_clus_info['area'] = good_clus_info['area'].replace(True, 'V1')
        good_clus_info['area'] = good_clus_info['area'].replace(False, 'HPC')

    del spike_clusters, spike_times, spike_stamps, good_spikes_ind

    # Now reading digitals from condInfo
    # This has to be checked carefully again, especially when only a few stimuli were shown in the session

    # cond class contains the following:
    #   'spontaneous_brightness': dict_keys(['name', 'time', 'timestamps', 'trl_list', 'conf'])
    #   'natural_images': dict_keys(['name', 'time', 'timestamps', 'trl_list', 'conf', 'img_order', 'img_name'])

    class condInfo:
        pass

    if sys.platform == 'linux':
        mat = scipy.io.loadmat(
            os.path.join((MatlabOutput + Sess), 'condInfo_01.mat'))
    if sys.platform == 'win32':
        mat = scipy.io.loadmat(os.path.join(PathToUpload, 'condInfo_01.mat'))

    SC_stim_labels = mat['StimClass'][0][0][0][0]
    SC_stim_present = np.where(mat['StimClass'][0][0][1][0] == 1)[0]
    SC_stim_labels_present = SC_stim_labels[SC_stim_present]

    cond = [condInfo() for i in range(len(SC_stim_labels_present))]

    for stim in range(len(SC_stim_labels_present)):
        cond[stim].name = SC_stim_labels_present[stim][0]
        cond[stim].stiminfo = mat['StimClass'][0][0][3][
            0, SC_stim_present[stim]][0][0][0][1]  # image indices are here

        # sorting out digitals for spontaneous activity
        # Need this loop in case there are several periods of spontaneous activity recorded as separate blocks
        if SC_stim_labels_present[stim][0] == 'spontaneous_brightness':
            cond[stim].time = []
            cond[stim].timestamps = []
            for block in range(
                    len(mat['StimClass'][0][0][3][0, SC_stim_present[stim]][0]
                        [0])):
                print(block)
                cond[stim].time.append(mat['StimClass'][0][0][3][
                    0, SC_stim_present[stim]][0][0][block][2])
                cond[stim].timestamps.append(mat['StimClass'][0][0][3][
                    0, SC_stim_present[stim]][0][0][block][3])

        cond[stim].trl_list = mat['StimClass'][0][0][3][
            0, SC_stim_present[stim]][1]
        cond[stim].conf = mat['StimClass'][0][0][2][
            0,
            SC_stim_present[stim]]  # config is very likely wrong and useless

        # sorting out digitals for natural images
        if SC_stim_labels_present[stim][0] == 'natural_images':
            cond[stim].time = mat['StimClass'][0][0][3][
                0, SC_stim_present[stim]][0][0][0][2]
            cond[stim].timestamps = mat['StimClass'][0][0][3][
                0, SC_stim_present[stim]][0][0][0][3]
            img_order = []
            for i in range(len(cond[stim].stiminfo)):
                img_order.append(int(cond[stim].stiminfo[i][2]))
            cond[stim].img_order = img_order
            cond[stim].img_name = cond[stim].conf[0][0][0][10][
                0]  # currently not used but might be needed later

        # sorting out digitals for drifting gratings
        if SC_stim_labels_present[stim][0] == 'drifting_gratings':
            cond[stim].time = mat['StimClass'][0][0][3][
                0, SC_stim_present[stim]][0][0][0][2]
            cond[stim].timestamps = mat['StimClass'][0][0][3][
                0, SC_stim_present[stim]][0][0][0][3]
            dg_orient = []
            for i in range(len(cond[stim].stiminfo)):
                dg_orient.append(int(cond[stim].stiminfo[i][2]))
            cond[stim].dg_orient = dg_orient

    # Now create the NWB file
    # (start_time is passed in, e.g. datetime(2020, 2, 27, 14, 36, 7, tzinfo=tzlocal()))
    nwb_subject = Subject(description="Pretty nice girl",
                          sex='F',
                          species='mouse',
                          subject_id=excel_info['Mouse'].values[0],
                          genotype=excel_info['Genotype'].values[0])

    nwbfile = NWBFile(
        session_description=
        "NPx recording of Natural images and spontaneous activity",
        session_id=Sess,
        identifier='NWB123',
        session_start_time=start_time,
        experimenter='Marina Slashcheva',
        institution='ESI, Frankfurt',
        lab='Martin Vinck',
        notes=' | '.join(
            [x for x in list(excel_info['Note'].values) if str(x) != 'nan']),
        protocol=' | '.join([
            x for x in list(excel_info['experiment'].values) if str(x) != 'nan'
        ]),
        data_collection=
        'Ref: {}, Probe_angle: {}, Depth: {}, APcoord: {}, MLcoord: {}, Recday: {}, Hemi: {}'
        .format(excel_info['refCh'].values[0],
                excel_info['Angle_probe'].values[0],
                excel_info['Depth'].values[0],
                excel_info['anteroposterior'].values[0],
                excel_info['mediolateral'].values[0],
                excel_info['Recday'].values[0],
                excel_info['Hemisphere'].values[0]),
        subject=nwb_subject)

    # Not added for the moment; later, add running as a TimeSeries and store it in the HDF5 file as a binary parameter
    # test_ts = TimeSeries(name='test_timeseries', data=data, unit='m', timestamps=timestamps)

    # Add units
    nwbfile.add_unit_column(
        'location',
        'the anatomical location of this unit')  # to be added and CHECKED
    nwbfile.add_unit_column('depth', 'depth on the NPx probe')
    nwbfile.add_unit_column('channel', 'channel on the NPx probe')
    nwbfile.add_unit_column('fr', 'average firing rate according to Kilosort (KS)')

    for un in good_clus_info['id']:
        info_tmp = good_clus_info[good_clus_info['id'] == un]
        spike_times_tmp = spike_times_good[spike_clus_good == un]

        nwbfile.add_unit(id=un,
                         spike_times=np.transpose(spike_times_tmp)[0],
                         location=info_tmp['area'].values[0],
                         depth=info_tmp['depth'].values[0],
                         channel=info_tmp['ch'].values[0],
                         fr=info_tmp['fr'].values[0])
        del spike_times_tmp

    # Add epochs
    for ep in range(len(cond)):
        if cond[ep].name == 'spontaneous_brightness':
            #if len(cond[ep].time) > 1:
            for bl in range(len(cond[ep].time)):
                nwbfile.add_epoch(cond[ep].time[bl][0][0],
                                  cond[ep].time[bl][0][1], cond[ep].name)
            #else:
            #    nwbfile.add_epoch(cond[ep].time[0][0], cond[ep].time[0][1], cond[ep].name)

        if cond[ep].name == 'natural_images':
            nwbfile.add_epoch(cond[ep].time[0][0], cond[ep].time[-1][1],
                              cond[ep].name)

        if cond[ep].name == 'drifting_gratings':
            nwbfile.add_epoch(cond[ep].time[0][0], cond[ep].time[-1][1],
                              cond[ep].name)

    # Add trials
    # Images names can be also added here
    nwbfile.add_trial_column(
        name='start', description='start time relative to the stimulus onset')
    nwbfile.add_trial_column(
        name='stimset',
        description='the visual stimulus type during the trial')
    nwbfile.add_trial_column(name='img_id',
                             description='image ID for Natural Images')

    for ep in range(len(cond)):
        if cond[ep].name == 'spontaneous_brightness':
            #if len(cond[ep].time) > 1:
            for tr in range(len(cond[ep].time)):
                nwbfile.add_trial(start_time=cond[ep].time[tr][0][0],
                                  stop_time=cond[ep].time[tr][0][1],
                                  start=cond[ep].time[tr][0][2],
                                  stimset=(cond[ep].name).encode('utf8'),
                                  img_id=('gray').encode('utf8'))

#            else:
#                nwbfile.add_trial(start_time = cond[ep].time[0][0], stop_time = cond[ep].time[0][1],
#                                  start = cond[ep].time[0][2],
#                                  stimset = (cond[ep].name).encode('utf8'),
#                                  img_id = ('gray').encode('utf8'))

        if cond[ep].name == 'natural_images':
            for tr in range(len(cond[ep].time)):
                nwbfile.add_trial(start_time=cond[ep].time[tr][0],
                                  stop_time=cond[ep].time[tr][1],
                                  start=cond[ep].time[tr][2],
                                  stimset=(cond[ep].name).encode('utf8'),
                                  img_id=(str(
                                      cond[ep].img_order[tr])).encode('utf8'))

        if cond[ep].name == 'drifting_gratings':
            for tr in range(len(cond[ep].time)):
                nwbfile.add_trial(start_time=cond[ep].time[tr][0],
                                  stop_time=cond[ep].time[tr][1],
                                  start=cond[ep].time[tr][2],
                                  stimset=(cond[ep].name).encode('utf8'),
                                  img_id=(str(
                                      cond[ep].dg_orient[tr])).encode('utf8'))

    # Write NWB file
    os.chdir(SaveDir)
    name_to_save = Sess + '.nwb'
    io = NWBHDF5IO(name_to_save, manager=get_manager(), mode='w')
    io.write(nwbfile)
    io.close()

    del nwbfile
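
# Hedged usage sketch for create_nwb_file: the session name is a placeholder,
# and the start time mirrors the commented example above (tzlocal from
# dateutil.tz).
from datetime import datetime
from dateutil.tz import tzlocal

create_nwb_file('Sess01', datetime(2020, 2, 27, 14, 36, 7, tzinfo=tzlocal()))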
Example #24
                                    description=description,
                                    comments=comments)
        
        # Store spike waveform data
        nwb.add_acquisition(ephys_ts_S1)


# Check the stored data
print(nwb.acquisition)


# Associate electrodes with units

# M1
for j in np.arange(96):
    nwb.add_unit(electrodes=[j],spike_times=np.ravel(f_info['spikes'][j,1]),electrode_group=electrode_group_M1)
    nwb.add_unit(electrodes=[j],spike_times=np.ravel(f_info['spikes'][j,2]),electrode_group=electrode_group_M1)

# S1
for j in np.arange(96,192):
    nwb.add_unit(electrodes=[j],spike_times=np.ravel(f_info['spikes'][j,1]),electrode_group=electrode_group_S1)
    nwb.add_unit(electrodes=[j],spike_times=np.ravel(f_info['spikes'][j,2]),electrode_group=electrode_group_S1)



# Add behavioral information


# SpatialSeries and Position data interfaces to store cursor_pos
cursor_pos = SpatialSeries(name='cursor_position', data=f_info['cursor_pos'], 
                           reference_frame='0,0', conversion=1e-3, resolution=1e-17, 
Example #25
class Alyx2NWBConverter:
    def __init__(self,
                 saveloc=None,
                 nwb_metadata_file=None,
                 metadata_obj: Alyx2NWBMetadata = None,
                 one_object: ONE = None,
                 save_raw=False,
                 save_camera_raw=False,
                 complevel=4,
                 shuffle=False,
                 buffer_size=1):
        """
        Retrieve all Alyx session, subject metadata, raw data for eid using the one apis load method
        Map that to nwb supported datatypes and create an nwb file.
        Parameters
        ----------
        saveloc: str, Path
            save location of nwbfile
        nwb_metadata_file: [dict, str]
            output of Alyx2NWBMetadata as a dict/json location str
        metadata_obj: Alyx2NWBMetadata
        one_object: ONE()
        save_raw: bool
            will load and save large raw files: ecephys.raw.ap/lf.cbin to nwb
        save_camera_raw: bool
            will load and save mice camera movie .mp4: _iblrig_Camera.raw
        complevel: int
            level of compression to apply to raw datasets
            (0-9)>(low,high). https://docs.h5py.org/en/latest/high/dataset.html
        shuffle: bool
            Enable shuffle I/O filter. http://docs.h5py.org/en/latest/high/dataset.html#dataset-shuffle
        """
        self.buffer_size = buffer_size
        self.complevel = complevel
        self.shuffle = shuffle
        if nwb_metadata_file is not None:
            if isinstance(nwb_metadata_file, dict):
                self.nwb_metadata = nwb_metadata_file
            elif isinstance(nwb_metadata_file, str):
                with open(nwb_metadata_file, 'r') as f:
                    self.nwb_metadata = json.load(f)
        elif metadata_obj is not None:
            self.nwb_metadata = metadata_obj.complete_metadata
        else:
            raise ValueError(
                'one of the arguments nwb_metadata_file or metadata_obj is required')
        if one_object is not None:
            self.one_object = one_object
        elif metadata_obj is not None:
            self.one_object = metadata_obj.one_obj
        else:
            warnings.warn('creating a ONE object and continuing')
            self.one_object = ONE()
        if saveloc is None:
            warnings.warn('saving nwb file in current working directory')
            self.saveloc = str(Path.cwd())
        else:
            self.saveloc = str(saveloc)
        self.eid = self.nwb_metadata["eid"]
        if not isinstance(self.nwb_metadata['NWBFile']['session_start_time'],
                          datetime):
            self.nwb_metadata['NWBFile']['session_start_time'] = \
                datetime.strptime(self.nwb_metadata['NWBFile']['session_start_time'], '%Y-%m-%dT%X').replace(
                    tzinfo=pytz.utc)
            self.nwb_metadata['IBLSubject']['date_of_birth'] = \
                datetime.strptime(self.nwb_metadata['IBLSubject']['date_of_birth'], '%Y-%m-%dT%X').replace(
                    tzinfo=pytz.utc)
        # create nwbfile:
        self.initialize_nwbfile()
        self.no_probes = len(self.nwb_metadata['Probes'])
        if self.no_probes == 0:
            warnings.warn(
                'could not find probe information, will create trials, behavior, acquisition'
            )
        self.electrode_table_exist = False
        self._one_data = _OneData(self.one_object,
                                  self.eid,
                                  self.no_probes,
                                  self.nwb_metadata,
                                  save_raw=save_raw,
                                  save_camera_raw=save_camera_raw)
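
    # Hedged construction sketch (paths and metadata file are placeholders):
    #
    #   converter = Alyx2NWBConverter(saveloc='session.nwb',
    #                                 nwb_metadata_file='metadata.json',
    #                                 one_object=ONE())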

    def initialize_nwbfile(self):
        """
        Creates self.nwbfile, devices and electrode group of nwb file.
        """
        nwbfile_args = dict(identifier=str(uuid.uuid4()), )
        nwbfile_args.update(**self.nwb_metadata['NWBFile'])
        self.nwbfile = NWBFile(**nwbfile_args)
        # create devices
        [
            self.nwbfile.create_device(**idevice_meta)
            for idevice_meta in self.nwb_metadata['Ecephys']['Device']
        ]
        if 'ElectrodeGroup' in self.nwb_metadata['Ecephys']:
            self.create_electrode_groups(self.nwb_metadata['Ecephys'])

    def create_electrode_groups(self, metadata_ecephys):
        """
        This method is called at __init__.
        Use metadata to create ElectrodeGroup object(s) in the NWBFile

        Parameters
        ----------
        metadata_ecephys : dict
            Dict with key:value pairs for defining the Ecephys group from where this
            ElectrodeGroup belongs. This should contain keys for required groups
            such as 'Device', 'ElectrodeGroup', etc.
        """
        for metadata_elec_group in metadata_ecephys['ElectrodeGroup']:
            eg_name = metadata_elec_group['name']
            # Tests if ElectrodeGroup already exists
            aux = [i.name == eg_name for i in self.nwbfile.children]
            if any(aux):
                print(eg_name + ' already exists in current NWBFile.')
            else:
                device_name = metadata_elec_group['device']
                if device_name in self.nwbfile.devices:
                    device = self.nwbfile.devices[device_name]
                else:
                    print('Device ', device_name, ' for ElectrodeGroup ',
                          eg_name, ' does not exist.')
                    print('Make sure ', device_name,
                          ' is defined in metadata.')
                    continue  # skip this group: its device is undefined

                eg_description = metadata_elec_group['description']
                eg_location = metadata_elec_group['location']
                self.nwbfile.create_electrode_group(name=eg_name,
                                                    location=eg_location,
                                                    device=device,
                                                    description=eg_description)

    def check_module(self, name, description=None):
        """
        Check if processing module exists. If not, create it. Then return module

        Parameters
        ----------
        name: str
        description: str | None (optional)

        Returns
        -------
        pynwb.module

        """

        if name in self.nwbfile.processing:
            return self.nwbfile.processing[name]
        else:
            if description is None:
                description = name
            return self.nwbfile.create_processing_module(name, description)
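
    # Illustrative note (not in the original source): check_module is a
    # get-or-create helper, so repeated calls are safe, e.g.
    #   behavior_mod = self.check_module('behavior', 'processed behavioral data')
    #   assert behavior_mod is self.check_module('behavior')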

    def create_stimulus(self):
        """
        Creates stimulus data in nwbfile
        """
        stimulus_list = self._get_data(
            self.nwb_metadata['Stimulus'].get('time_series'))
        for i in stimulus_list:
            self.nwbfile.add_stimulus(pynwb.TimeSeries(**i))

    def create_units(self):
        """
        Units table in nwbfile
        """
        if self.no_probes == 0:
            return
        if not self.electrode_table_exist:
            self.create_electrode_table_ecephys()
        unit_table_list = self._get_data(self.nwb_metadata['Units'])
        # no required arguments for units table. Below are default columns in the table.
        default_args = [
            'id', 'waveform_mean', 'electrodes', 'electrode_group',
            'spike_times', 'obs_intervals'
        ]
        default_ids = _get_default_column_ids(
            default_args, [i['name'] for i in unit_table_list])
        if len(default_ids) != len(default_args):
            warnings.warn(f'could not find all of the default unit columns: {default_args}')
        non_default_ids = list(
            set(range(len(unit_table_list))).difference(set(default_ids)))
        default_dict = {
            unit_table_list[id]['name']: unit_table_list[id]['data']
            for id in default_ids
        }
        for cluster_no in range(len(unit_table_list[0]['data'])):
            add_dict = dict()
            for ibl_dataset_name in default_dict:
                if ibl_dataset_name == 'electrodes':
                    add_dict.update({
                        ibl_dataset_name:
                        [default_dict[ibl_dataset_name][cluster_no]]
                    })
                elif ibl_dataset_name == 'spike_times':
                    add_dict.update({
                        ibl_dataset_name:
                        default_dict[ibl_dataset_name][cluster_no]
                    })
                elif ibl_dataset_name == 'obs_intervals':  # common across all clusters
                    add_dict.update(
                        {ibl_dataset_name: default_dict[ibl_dataset_name]})
                elif ibl_dataset_name == 'electrode_group':
                    add_dict.update({
                        ibl_dataset_name:
                        self.nwbfile.electrode_groups[self.nwb_metadata[
                            'Probes'][default_dict[ibl_dataset_name]
                                      [cluster_no]]['name']]
                    })
                elif ibl_dataset_name == 'id':
                    if cluster_no >= self._one_data.data_attrs_dump[
                            'unit_table_length'][0]:
                        add_dict.update({
                            ibl_dataset_name:
                            default_dict[ibl_dataset_name][cluster_no] +
                            self._one_data.data_attrs_dump['unit_table_length']
                            [0]
                        })
                    else:
                        add_dict.update({
                            ibl_dataset_name:
                            default_dict[ibl_dataset_name][cluster_no]
                        })
                elif ibl_dataset_name == 'waveform_mean':
                    add_dict.update({
                        ibl_dataset_name:
                        np.mean(default_dict[ibl_dataset_name][cluster_no],
                                axis=1)
                    })  # mean waveform across all channels of the cluster
            self.nwbfile.add_unit(**add_dict)

        for id in non_default_ids:
            if isinstance(unit_table_list[id]['data'], np.ndarray):
                unit_table_list[id]['data'] = unit_table_list[id][
                    'data'].tolist()  # convert numpy (e.g. string) arrays to lists
            self.nwbfile.add_unit_column(
                name=unit_table_list[id]['name'],
                description=unit_table_list[id]['description'],
                data=unit_table_list[id]['data'])

    def create_electrode_table_ecephys(self):
        """
        Creates electrode table
        """
        if self.no_probes == 0:
            return
        if self.electrode_table_exist:
            return
        electrode_table_list = self._get_data(
            self.nwb_metadata['ElectrodeTable'])
        # electrode table has required arguments:
        required_args = ['group', 'x', 'y']
        default_ids = _get_default_column_ids(
            required_args, [i['name'] for i in electrode_table_list])
        non_default_ids = list(
            set(range(len(electrode_table_list))).difference(set(default_ids)))
        default_dict = {
            electrode_table_list[id]['name']: electrode_table_list[id]['data']
            for id in default_ids
        }
        if 'group' in default_dict:
            group_labels = default_dict['group']
        else:  # otherwise derive group labels from the per-probe electrode counts
            group_labels = np.concatenate([
                np.ones(self._one_data.
                        data_attrs_dump['electrode_table_length'][i],
                        dtype=int) * i for i in range(self.no_probes)
            ])
        for electrode_no in range(len(electrode_table_list[0]['data'])):
            if 'x' in default_dict:
                x = default_dict['x'][electrode_no][0]
                y = default_dict['y'][electrode_no][1]
            else:
                x = float('NaN')
                y = float('NaN')
            group_data = self.nwbfile.electrode_groups[self.nwb_metadata[
                'Probes'][group_labels[electrode_no]]['name']]
            self.nwbfile.add_electrode(x=x,
                                       y=y,
                                       z=float('NaN'),
                                       imp=float('NaN'),
                                       location='None',
                                       group=group_data,
                                       filtering='none')
        for id in non_default_ids:
            self.nwbfile.add_electrode_column(
                name=electrode_table_list[id]['name'],
                description=electrode_table_list[id]['description'],
                data=electrode_table_list[id]['data'])
        # create probes specific DynamicTableRegion:
        self.probe_dt_region = [
            self.nwbfile.create_electrode_table_region(region=list(
                range(self._one_data.data_attrs_dump['electrode_table_length']
                      [j])),
                                                       description=i['name'])
            for j, i in enumerate(self.nwb_metadata['Probes'])
        ]
        self.probe_dt_region_all = self.nwbfile.create_electrode_table_region(
            region=list(
                range(
                    sum(self._one_data.
                        data_attrs_dump['electrode_table_length']))),
            description='AllProbes')
        self.electrode_table_exist = True

    def create_timeseries_ecephys(self):
        """
        Create SpikeEventSeries, ElectricalSeries, and Spectrum data types within nwbfile > processing > ecephys
        """
        if self.no_probes == 0:
            return
        if not self.electrode_table_exist:
            self.create_electrode_table_ecephys()
        if 'ecephys' not in self.nwbfile.processing:
            mod = self.nwbfile.create_processing_module(
                'ecephys', 'Processed electrophysiology data of IBL')
        else:
            mod = self.nwbfile.get_processing_module('ecephys')
        for neurodata_type_name, neurodata_type_args_list in self.nwb_metadata[
                'Ecephys']['Ecephys'].items():
            data_retrieved_args_list = self._get_data(
                neurodata_type_args_list
            )  # list of dicts with keys as argument names
            for no, neurodata_type_args in enumerate(data_retrieved_args_list):
                ibl_dataset_name = neurodata_type_args_list[no]['data']
                if 'ElectricalSeries' in neurodata_type_name:
                    timestamps_names = self._one_data.data_attrs_dump[
                        '_iblqc_ephysTimeRms.timestamps']
                    data_names = self._one_data.data_attrs_dump[
                        '_iblqc_ephysTimeRms.rms']
                    for data_idx, data in enumerate(
                            neurodata_type_args['data']):
                        probe_no = [
                            j for j in range(self.no_probes)
                            if self.nwb_metadata['Probes'][j]['name'] in
                            data_names[data_idx]
                        ][0]
                        if data.shape[1] > self._one_data.data_attrs_dump[
                                'electrode_table_length'][probe_no]:
                            if 'channels.rawInd' in self._one_data.loaded_datasets:
                                channel_idx = self._one_data.loaded_datasets[
                                    'channels.rawInd'][probe_no].data.astype(
                                        'int')
                            else:
                                warnings.warn('could not find channels.rawInd')
                                break
                        else:
                            channel_idx = slice(None)
                        mod.add(
                            ElectricalSeries(
                                name=data_names[data_idx],
                                description=neurodata_type_args['description'],
                                timestamps=neurodata_type_args['timestamps']
                                [timestamps_names.index(data_names[data_idx])],
                                data=data[:, channel_idx],
                                electrodes=self.probe_dt_region[probe_no]))
                elif 'Spectrum' in neurodata_type_name:
                    if ibl_dataset_name == '_iblqc_ephysSpectralDensity.power':
                        freqs_names = self._one_data.data_attrs_dump[
                            '_iblqc_ephysSpectralDensity.freqs']
                        data_names = self._one_data.data_attrs_dump[
                            '_iblqc_ephysSpectralDensity.power']
                        for data_idx, data in enumerate(
                                neurodata_type_args['data']):
                            mod.add(
                                Spectrum(name=data_names[data_idx],
                                         frequencies=neurodata_type_args[
                                             'frequencies'][freqs_names.index(
                                                 data_names[data_idx])],
                                         power=data))
                elif 'SpikeEventSeries' in neurodata_type_name:
                    neurodata_type_args.update(
                        dict(electrodes=self.probe_dt_region_all))
                    mod.add(
                        pynwb.ecephys.SpikeEventSeries(**neurodata_type_args))

    def create_behavior(self):
        """
        Create behavior processing module
        """
        self.check_module('behavior')
        for behavior_datatype in self.nwb_metadata['Behavior']:
            if behavior_datatype == 'Position':
                position_cont = pynwb.behavior.Position()
                time_series_list_details = self._get_data(
                    self.nwb_metadata['Behavior'][behavior_datatype]
                    ['spatial_series'])
                if len(time_series_list_details) == 0:
                    continue
                # rate_list = [150.0,60.0,60.0] # based on the google doc for _iblrig_body/left/rightCamera.raw,
                dataname_list = self._one_data.data_attrs_dump['camera.dlc']
                data_list = time_series_list_details[0]['data']
                timestamps_list = time_series_list_details[0]['timestamps']
                for dataname, data, timestamps in zip(dataname_list, data_list,
                                                      timestamps_list):
                    colnames = data.columns
                    data_np = data.to_numpy()
                    x_column_ids = [
                        n for n, k in enumerate(colnames) if 'x' in k
                    ]
                    for x_column_id in x_column_ids:
                        data_loop = data_np[:, x_column_id:x_column_id + 2]
                        position_cont.create_spatial_series(
                            name=dataname + colnames[x_column_id][:-2],
                            data=data_loop,
                            reference_frame='none',
                            timestamps=timestamps,
                            conversion=1e-3)
                self.nwbfile.processing['behavior'].add(position_cont)
            elif behavior_datatype != 'BehavioralEpochs':
                time_series_func = pynwb.TimeSeries
                time_series_list_details = self._get_data(
                    self.nwb_metadata['Behavior'][behavior_datatype]
                    ['time_series'])
                if len(time_series_list_details) == 0:
                    continue
                time_series_list_obj = []
                for i in time_series_list_details:
                    unit = 'radians/sec' if 'velocity' in i[
                        'name'] else 'radians'
                    time_series_list_obj.append(
                        time_series_func(**i, unit=unit))
                func = getattr(pynwb.behavior, behavior_datatype)
                self.nwbfile.processing['behavior'].add(
                    func(time_series=time_series_list_obj))
            else:
                time_series_func = pynwb.epoch.TimeIntervals
                time_series_list_details = self._get_data(
                    self.nwb_metadata['Behavior'][behavior_datatype]
                    ['time_intervals'])
                if len(time_series_list_details) == 0:
                    continue
                for k in time_series_list_details:
                    time_intervals = time_series_func('BehavioralEpochs')
                    for time_interval in k['timestamps']:
                        time_intervals.add_interval(
                            start_time=time_interval[0],
                            stop_time=time_interval[1])
                    time_intervals.add_column(k['name'],
                                              k['description'],
                                              data=k['data'])
                    self.nwbfile.processing['behavior'].add(time_intervals)

    def create_acquisition(self):
        """
        Acquisition data such as the audio spectrogram (raw behavioral data), nidq (raw ephys data), and raw camera data.
        These are independent of probe type.
        """
        for neurodata_type_name, neurodata_type_args_list in self.nwb_metadata[
                'Acquisition'].items():
            data_retrieved_args_list = self._get_data(neurodata_type_args_list)
            for neurodata_type_args in data_retrieved_args_list:
                if neurodata_type_name == 'ImageSeries':
                    for types, times in zip(neurodata_type_args['data'],
                                            neurodata_type_args['timestamps']):
                        customargs = dict(name='camera_raw',
                                          external_file=[str(types)],
                                          format='external',
                                          timestamps=times,
                                          unit='n.a.')
                        self.nwbfile.add_acquisition(ImageSeries(**customargs))
                elif neurodata_type_name == 'DecompositionSeries':
                    neurodata_type_args['bands'] = np.squeeze(
                        neurodata_type_args['bands'])
                    freqs = DynamicTable(
                        'bands',
                        'spectrogram frequencies',
                        id=np.arange(neurodata_type_args['bands'].shape[0]))
                    freqs.add_column('freq',
                                     'frequency value',
                                     data=neurodata_type_args['bands'])
                    neurodata_type_args.update(dict(bands=freqs))
                    temp = neurodata_type_args['data'][:, :, np.newaxis]
                    neurodata_type_args['data'] = np.moveaxis(
                        temp, [0, 1, 2], [0, 2, 1])
                    ts = neurodata_type_args.pop('timestamps')
                    starting_time = ts[0][0] if isinstance(
                        ts[0], np.ndarray) else ts[0]
                    neurodata_type_args.update(
                        dict(starting_time=np.float64(starting_time),
                             rate=1 / np.mean(np.diff(ts.squeeze())),
                             unit='sec'))
                    self.nwbfile.add_acquisition(
                        DecompositionSeries(**neurodata_type_args))
                elif neurodata_type_name == 'ElectricalSeries':
                    if not self.electrode_table_exist:
                        self.create_electrode_table_ecephys()
                    if neurodata_type_args['name'] in ['raw.lf', 'raw.ap']:
                        for probe_no in range(self.no_probes):
                            if neurodata_type_args['data'][probe_no].shape[
                                    1] > self._one_data.data_attrs_dump[
                                        'electrode_table_length'][probe_no]:
                                if 'channels.rawInd' in self._one_data.loaded_datasets:
                                    channel_idx = self._one_data.loaded_datasets[
                                        'channels.rawInd'][
                                            probe_no].data.astype('int')
                                else:
                                    warnings.warn(
                                        'could not find channels.rawInd')
                                    break
                            else:
                                channel_idx = slice(None)
                            self.nwbfile.add_acquisition(
                                ElectricalSeries(
                                    name=neurodata_type_args['name'] + '_' +
                                    self.nwb_metadata['Probes'][probe_no]
                                    ['name'],
                                    starting_time=np.abs(
                                        np.round(
                                            neurodata_type_args['timestamps']
                                            [probe_no][0, 1], 2)
                                    ),  # round starting times; raw values differ only on the order of 1e-5 s
                                    rate=neurodata_type_args['data']
                                    [probe_no].fs,
                                    data=H5DataIO(
                                        DataChunkIterator(
                                            _iter_datasetview(
                                                neurodata_type_args['data']
                                                [probe_no],
                                                channel_ids=channel_idx),
                                            buffer_size=self.buffer_size),
                                        compression=True,
                                        shuffle=self.shuffle,
                                        compression_opts=self.complevel),
                                    electrodes=self.probe_dt_region[probe_no],
                                    channel_conversion=neurodata_type_args[
                                        'data']
                                    [probe_no].channel_conversion_sample2v[
                                        neurodata_type_args['data']
                                        [probe_no].type][channel_idx]))
                    elif neurodata_type_args['name'] in ['raw.nidq']:
                        self.nwbfile.add_acquisition(
                            ElectricalSeries(**neurodata_type_args))
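
    # For reference, a self-contained sketch of the compressed-write pattern used
    # above (DataChunkIterator streaming into H5DataIO); the array and the
    # electrode region are placeholders, not data produced by this converter:
    #
    #     import numpy as np
    #     from hdmf.data_utils import DataChunkIterator
    #     from hdmf.backends.hdf5 import H5DataIO
    #     from pynwb.ecephys import ElectricalSeries
    #
    #     raw = np.random.randn(100_000, 384)  # placeholder recording
    #     wrapped = H5DataIO(DataChunkIterator(data=raw, buffer_size=16384),
    #                        compression=True, shuffle=True)  # compress on write
    #     es = ElectricalSeries(name='raw.ap_probe00', data=wrapped,
    #                           starting_time=0.0, rate=30000.0,
    #                           electrodes=electrode_region)  # a DynamicTableRegion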

    def create_probes(self):
        """
        Fills in all the probe metadata using the custom NeuroPixels extension.
        """
        for i in self.nwb_metadata['Probes']:
            self.nwbfile.add_device(IblProbes(**i))

    def create_iblsubject(self):
        """
        Populates the custom subject extension for IBL mice data
        """
        self.nwbfile.subject = IblSubject(**self.nwb_metadata['IBLSubject'])

    def create_lab_meta_data(self):
        """
        Populates the custom lab_meta_data extension for IBL sessions data
        """
        self.nwbfile.add_lab_meta_data(
            IblSessionData(**self.nwb_metadata['IBLSessionsData']))

    def create_trials(self):
        table_data = self._get_data(self.nwb_metadata['Trials'])
        required_fields = ['start_time', 'stop_time']
        required_data = [i for i in table_data if i['name'] in required_fields]
        optional_data = [
            i for i in table_data if i['name'] not in required_fields
        ]
        if len(required_fields) != len(required_data):
            warnings.warn(
                'could not find required datasets trials.start_time and trials.stop_time; '
                'skipping trials table')
            return
        for start_time, stop_time in zip(required_data[0]['data'][:, 0],
                                         required_data[1]['data'][:, 1]):
            self.nwbfile.add_trial(start_time=start_time, stop_time=stop_time)
        for op_data in optional_data:
            if op_data['data'].shape[0] == required_data[0]['data'].shape[0]:
                self.nwbfile.add_trial_column(
                    name=op_data['name'],
                    description=op_data['description'],
                    data=op_data['data'])
            else:
                warnings.warn(
                    f'shape of trials.{op_data["name"]} does not match other trials.* datasets'
                )

    def _get_data(self, sub_metadata):
        """
        Uses OneData class to query ONE datasets on server and download them locally
        Parameters
        ----------
        sub_metadata: [list, dict]
            list of metadata dicts, each containing a 'data' key whose value is a dataset-type string used to retrieve the data (npy, tsv, etc.)

        Returns
        -------
        out_dict_trim: list
            list of dicts with the actual data loaded into the 'data' field
        """
        include_idx = []
        out_dict_trim = []
        alt_datatypes = ['bands', 'power', 'frequencies', 'timestamps']
        if isinstance(sub_metadata, list):
            out_dict = deepcopy(sub_metadata)
        elif isinstance(sub_metadata, dict):
            out_dict = deepcopy(list(sub_metadata))
        else:
            return []
        req_datatypes = ['data']
        for count, neurodata_type_args in enumerate(out_dict):
            for alt_names in alt_datatypes:
                if neurodata_type_args.get(
                        alt_names
                ):  # in the case of DecompositionSeries, Spectrum
                    neurodata_type_args[
                        alt_names] = self._one_data.download_dataset(
                            neurodata_type_args[alt_names],
                            neurodata_type_args['name'])
                    req_datatypes.append(alt_names)
            if neurodata_type_args[
                    'name'] == 'id':  # applies only to the units table
                neurodata_type_args['data'] = self._one_data.download_dataset(
                    neurodata_type_args['data'], 'cluster_id')
            else:
                out_dict[count]['data'] = self._one_data.download_dataset(
                    neurodata_type_args['data'], neurodata_type_args['name'])
            if all([out_dict[count][i] is not None for i in req_datatypes]):
                include_idx.extend([count])
        out_dict_trim.extend([out_dict[j0] for j0 in include_idx])
        return out_dict_trim

    def run_conversion(self):
        """
        Single method to create all datasets and metadata in nwbfile in one go
        Returns
        -------

        """
        execute_list = [
            self.create_stimulus, self.create_trials,
            self.create_electrode_table_ecephys,
            self.create_timeseries_ecephys, self.create_units,
            self.create_behavior, self.create_probes, self.create_iblsubject,
            self.create_lab_meta_data, self.create_acquisition
        ]
        t = tqdm(execute_list)
        for i in t:
            t.set_postfix(current='creating nwb ' + i.__name__.split('_')[-1])
            i()
        print('done converting')

    def write_nwb(self, read_check=True):
        """
        After run_conversion(), write the in-memory nwbfile to disk
        Parameters
        ----------
        read_check: bool
            if True, re-open the written file as a round-trip verification
        """
        print('Saving to file, please wait...')
        with NWBHDF5IO(self.saveloc, 'w') as io:
            io.write(self.nwbfile)
            print('File successfully saved at: ', str(self.saveloc))

        if read_check:
            with NWBHDF5IO(self.saveloc, 'r') as io:
                io.read()
                print('Read check: OK')
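
# A hypothetical end-to-end driver for the converter class above; the class
# name and constructor arguments here are assumptions, not a confirmed API:
#
#     converter = IblToNwbConverter(nwb_metadata, saveloc='ibl_session.nwb')
#     converter.run_conversion()            # builds every table and series above
#     converter.write_nwb(read_check=True)  # writes to disk, then round-trips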
Example #26
if COMPRESS:
    mp_data = H5DataIO(mp_data, compression='gzip')
ts = TimeSeries('membrane_potential', mp_data, unit='mV', rate=10000.)
nwbfile.add_acquisition(ts)

# convert spike data

for spk_file in sorted(glob(os.path.join(run_dir, 'spiketimes*'))):
    with open(spk_file, 'rb') as file:
        data = pickle.load(file)
    cell_type = spk_file.split('_')[-2]
    for row in data:
        cell_id = row[0]
        spikes = np.array(row[1:], dtype=float) / 1000  # ms -> s
        nwbfile.add_unit(spike_times=spikes,
                         cell_type=cell_type,
                         cell_type_id=cell_id)

print(nwbfile.units['cell_type'].data)
print(nwbfile.units.get_unit_spike_times(21))

# write NWB file

with NWBHDF5IO(run_dir + '.nwb', 'w') as io:
    io.write(nwbfile)

# access data

read_io = NWBHDF5IO(run_dir + '.nwb', 'r')
nwbfile = read_io.read()
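# the IO handle must stay open while `nwbfile` is in use; call
# read_io.close() once you are finished with the data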
Example #27
def export_to_nwb(session_key, nwb_output_dir=default_nwb_output_dir, save=False, overwrite=True):

    this_session = (acquisition.Session & session_key).fetch1()

    # ===============================================================================
    # ============================== META INFORMATION ===============================
    # ===============================================================================

    # -- NWB file - a NWB2.0 file for each session
    experimenter = (acquisition.Session.Experimenter & session_key).fetch1('experimenter')
    nwbfile = NWBFile(identifier='_'.join(
        [this_session['subject'],
         this_session['session_time'].strftime('%Y-%m-%d %H:%M:%S')]),
        related_publications='https://www.nature.com/articles/s41586-018-0633-x',
        experiment_description='Extracellular recording in ALM',
        session_description='',
        session_start_time=this_session['session_time'],
        file_create_date=datetime.now(tzlocal()),
        experimenter=experimenter,
        institution=institution,
        keywords=['motor planning', 'anterior lateral cortex',
                  'ALM', 'Extracellular recording', 'optogenetics'])
    # -- subject
    subj = (subject.Subject & session_key).fetch1()

    nwbfile.subject = pynwb.file.Subject(
        subject_id=str(this_session['subject']),
        genotype=' x '.join((subject.Zygosity &
                             subj).fetch('allele')) \
                 if len(subject.Zygosity & subj) else 'unknown',
        sex=subj['sex'],
        species=subj['species'],
        date_of_birth=datetime.combine(subj['date_of_birth'], zero_zero_time) if subj['date_of_birth'] else None)
    # -- virus
    nwbfile.virus = json.dumps([{k: str(v) for k, v in virus_injection.items() if k not in subj}
                                for virus_injection in action.VirusInjection * reference.Virus & session_key])

    # ===============================================================================
    # ======================== EXTRACELLULAR & CLUSTERING ===========================
    # ===============================================================================

    """
    In the event of multiple probe recording (i.e. multiple probe insertions), the clustering results
    (and the associated units) are associated with the corresponding probe.
    Each probe insertion is associated with one ElectrodeConfiguration (which may define multiple electrode groups)
    """

    dj_insert_location = ephys.ProbeInsertion

    for probe_insertion in ephys.ProbeInsertion & session_key:
        probe = (reference.Probe & probe_insertion).fetch1()
        electrode_group = nwbfile.create_electrode_group(
                name=probe['probe_type'] + '_g1',
                description='N/A',
                device=nwbfile.create_device(name=probe['probe_type']),
                location=json.dumps({k: str(v) for k, v in (dj_insert_location & session_key).fetch1().items()
                                     if k not in dj_insert_location.primary_key}))

        for chn in (reference.Probe.Channel & probe).fetch(as_dict=True):
            nwbfile.add_electrode(
                id=chn['channel_id']-1,
                group=electrode_group,
                filtering=hardware_filter,
                imp=-1.,
                x=np.nan,
                y=np.nan,
                z=np.nan,
                location=(dj_insert_location & session_key).fetch1('brain_location'))


        # --- unit spike times ---
        nwbfile.add_unit_column(name='cell_type', description='cell type (e.g. fast spiking or pyramidal)')
        nwbfile.add_unit_column(name='sampling_rate', description='sampling rate of the waveform, Hz')

        spk_times_all = np.hstack((ephys.UnitSpikeTimes & probe_insertion).fetch('spike_times'))

        obs_min = np.min(spk_times_all)
        obs_max = np.max(spk_times_all)

        for unit in (ephys.UnitSpikeTimes & probe_insertion).fetch(as_dict=True):
            nwbfile.add_unit(id=unit['unit_id'],
                             electrodes=[unit['channel']-1],
                             electrode_group=electrode_group,
                             cell_type=unit['unit_cell_type'],
                             spike_times=unit['spike_times'],
                             obs_intervals=np.array([[obs_min - 0.001, obs_max + 0.001]]),
                             waveform_mean=np.mean(unit['spike_waveform'], axis=0),
                             waveform_sd=np.std(unit['spike_waveform'], axis=0),
                             sampling_rate=20000)

    # ===============================================================================
    # ============================= PHOTO-STIMULATION ===============================
    # ===============================================================================
    stim_sites = {}
    for photostim in acquisition.PhotoStim * reference.BrainLocation & session_key:

        stim_device = (nwbfile.get_device(photostim['photo_stim_method'])
                       if photostim['photo_stim_method'] in nwbfile.devices
                       else nwbfile.create_device(name=photostim['photo_stim_method']))

        stim_site = pynwb.ogen.OptogeneticStimulusSite(
            name=photostim['brain_location'],
            device=stim_device,
            excitation_lambda=float(photostim['photo_stim_wavelength']),
            location=json.dumps({k: str(v) for k, v in photostim.items()
                                if k in acquisition.PhotoStim.heading.names and k not in acquisition.PhotoStim.primary_key + ['photo_stim_method', 'photo_stim_wavelength']}),
            description='')
        nwbfile.add_ogen_site(stim_site)

    # ===============================================================================
    # =============================== BEHAVIOR TRIALS ===============================
    # ===============================================================================

    # =============== TrialSet ====================
    # NWB 'trial' (of type dynamic table) by default comes with three mandatory attributes: 'start_time' and 'stop_time'
    # Other trial-related information needs to be added in to the trial-table as additional columns (with column name
    # and column description)

    dj_trial = acquisition.Session * behavior.TrialSet.Trial
    skip_adding_columns = acquisition.Session.primary_key + \
        ['trial_id', 'trial_start_idx', 'trial_end_idx', 'trial_start_time', 'session_note']

    if behavior.TrialSet.Trial & session_key:
        # Get trial descriptors from TrialSet.Trial and TrialStimInfo
        trial_columns = [{'name': tag.replace('trial_', ''),
                          'description': re.sub(r'\s+:|\s+', ' ', re.search(
                              f'(?<={tag})(.*)', str(dj_trial.heading)).group()).strip()}
                         for tag in dj_trial.heading.names
                         if tag not in skip_adding_columns]


        # Add new table columns to nwb trial-table for trial-label
        for c in trial_columns:
            nwbfile.add_trial_column(**c)

        # Add entry to the trial-table
        for trial in (dj_trial & session_key).fetch(as_dict=True):
            trial['start_time'] = float(trial['trial_start_time'])
            trial['stop_time'] = float(trial['trial_start_time']) + 5.0
            trial['trial_pole_in_time'] = trial['start_time'] + trial['trial_pole_in_time']
            trial['trial_pole_out_time'] = trial['start_time'] + trial['trial_pole_out_time']
            trial['trial_cue_time'] = trial['start_time'] + trial['trial_cue_time']
            for k in skip_adding_columns:
                trial.pop(k)
            for k in list(trial):  # iterate over a copy: the dict is mutated below
                if 'trial_' in k:
                    trial[k.replace('trial_', '')] = trial.pop(k)
            nwbfile.add_trial(**trial)


    # =============== Write NWB 2.0 file ===============
    if save:
        save_file_name = ''.join([nwbfile.identifier, '.nwb'])
        if not os.path.exists(nwb_output_dir):
            os.makedirs(nwb_output_dir)
        if not overwrite and os.path.exists(os.path.join(nwb_output_dir, save_file_name)):
            return nwbfile
        with NWBHDF5IO(os.path.join(nwb_output_dir, save_file_name), mode='w') as io:
            io.write(nwbfile)
            print(f'Write NWB 2.0 file: {save_file_name}')

    return nwbfile
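
# A hypothetical invocation of the exporter above; the key fields are
# placeholders whose names depend on the DataJoint schema in use:
#
#     session_key = {'subject': 'ANM001',
#                    'session_time': datetime(2018, 1, 1, 12, 0, 0)}
#     nwbfile = export_to_nwb(session_key, save=True, overwrite=False)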
Example #28
class ShowPSTHTestCase(unittest.TestCase):
    def setUp(self):

        start_time = datetime(2017, 4, 3, 11, tzinfo=tzlocal())
        create_date = datetime(2017, 4, 15, 12, tzinfo=tzlocal())

        self.nwbfile = NWBFile(session_description='NWBFile for PSTH',
                               identifier='NWB123',
                               session_start_time=start_time,
                               file_create_date=create_date)

        self.nwbfile.add_unit_column('location',
                                     'the anatomical location of this unit')
        self.nwbfile.add_unit_column(
            'quality', 'the quality for the inference of this unit')

        self.nwbfile.add_unit(id=1,
                              spike_times=[2.2, 3.0, 4.5],
                              obs_intervals=[[1, 10]],
                              location='CA1',
                              quality=0.95)
        self.nwbfile.add_unit(id=2,
                              spike_times=[2.2, 3.0, 25.0, 26.0],
                              obs_intervals=[[1, 10], [20, 30]],
                              location='CA3',
                              quality=0.85)
        self.nwbfile.add_unit(id=3,
                              spike_times=[1.2, 2.3, 3.3, 4.5],
                              obs_intervals=[[1, 10], [20, 30]],
                              location='CA1',
                              quality=0.90)

        self.nwbfile.add_trial_column(
            name='stim', description='the visual stimuli during the trial')

        self.nwbfile.add_trial(start_time=0.0, stop_time=2.0, stim='person')
        self.nwbfile.add_trial(start_time=3.0, stop_time=5.0, stim='ocean')
        self.nwbfile.add_trial(start_time=6.0, stop_time=8.0, stim='desert')

    def test_get_min_spike_time(self):
        assert (get_min_spike_time(self.nwbfile.units) == 1.2)

    def test_align_by_trials(self):
        ComparetoAT = [
            np.array([2.2, 3.0, 25.0, 26.0]),
            np.array([-0.8, 0., 22., 23.]),
            np.array([-3.8, -3., 19., 20.])
        ]

        AT = align_by_trials(self.nwbfile.units,
                             index=1,
                             before=20.,
                             after=30.)

        np.testing.assert_allclose(AT, ComparetoAT, rtol=1e-02)

    def test_align_by_time_intervals_Nonetrials_select(self):
        time_intervals = TimeIntervals(name='Test Time Interval')
        time_intervals.add_interval(start_time=21.0, stop_time=28.0)
        time_intervals.add_interval(start_time=22.0, stop_time=26.0)
        time_intervals.add_interval(start_time=22.0, stop_time=28.4)

        ATI = align_by_time_intervals(self.nwbfile.units,
                                      index=1,
                                      intervals=time_intervals,
                                      stop_label=None,
                                      before=20.,
                                      after=30.)

        ComparedtoATI = [
            np.array([-18.8, -18., 4., 5.]),
            np.array([-19.8, -19., 3., 4.]),
            np.array([-19.8, -19., 3., 4.])
        ]

        np.testing.assert_array_equal(ATI, ComparedtoATI)

    def test_align_by_time_intervals(self):
        time_intervals = TimeIntervals(name='Test Time Interval')
        time_intervals.add_interval(start_time=21.0, stop_time=28.0)
        time_intervals.add_interval(start_time=22.0, stop_time=26.0)
        time_intervals.add_interval(start_time=22.0, stop_time=28.4)

        ATI = align_by_time_intervals(self.nwbfile.units,
                                      index=1,
                                      intervals=time_intervals,
                                      stop_label=None,
                                      before=20.,
                                      after=30.,
                                      rows_select=[0, 1])

        ComparedtoATI = [
            np.array([-18.8, -18., 4., 5.]),
            np.array([-19.8, -19., 3., 4.])
        ]

        np.testing.assert_array_equal(ATI, ComparedtoATI)
Example #29
# Add units
nwbfile.add_unit_column(
    'location',
    'the anatomical location of this unit')  # to be added and CHECKED
nwbfile.add_unit_column('depth', 'depth on the NPx probe')
nwbfile.add_unit_column('channel', 'channel on the NPx probe')
nwbfile.add_unit_column('fr', 'average FR according to KS')

for un in good_clus_info['id']:
    info_tmp = good_clus_info[good_clus_info['id'] == un]
    spike_times_tmp = spike_times_good[spike_clus_good == un]

    nwbfile.add_unit(id=un,
                     spike_times=np.transpose(spike_times_tmp)[0],
                     location=info_tmp['area'].values[0],
                     depth=info_tmp['depth'].values[0],
                     channel=info_tmp['ch'].values[0],
                     fr=info_tmp['fr'].values[0])
    del spike_times_tmp

# Add epochs
for ep in range(len(cond)):
    if cond[ep].name == 'spontaneous_brightness':
        nwbfile.add_epoch(cond[ep].time[0][0], cond[ep].time[0][1],
                          cond[ep].name)
    if cond[ep].name == 'natural_images':
        nwbfile.add_epoch(cond[ep].time[0][0], cond[ep].time[-1][1],
                          cond[ep].name)

# Add trials
# Image names can also be added here, as sketched below
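
# A hypothetical sketch of that trials step; the `trial_tuples` iterable and
# its (start, stop, image name) layout are placeholders, not data from `cond`:
#
#     nwbfile.add_trial_column('img_name', 'name of the presented natural image')
#     for tr_start, tr_stop, img in trial_tuples:
#         nwbfile.add_trial(start_time=tr_start, stop_time=tr_stop, img_name=img)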
Example #30
# Add Units by cluster
for i in cluster_info:
    c = cluster_info[i]
    times = np.array(spike_times[c])
    annotations = phy_annotations[i]
    annotations = annotations.astype(int)
    channel = cluster_channel[i]
    channel = channel.astype(int)
    duration = waveform_duration[i]
    duration = duration.astype(int)

    nwb_file.add_unit(spike_times=np.ravel(times),
                      electrodes=waveform_chans[i, :],
                      electrode_group=electrode_groups[cluster_probe[i]],
                      waveform_mean=waveform[i, :, :],
                      id=i,
                      phy_annotations=annotations,
                      peak_channel=channel,
                      waveform_duration=duration,
                      cluster_depths=cluster_depths[i],
                      sampling_rate=30000.0)

# Add spike amps and depths
amps = {}
depths = {}

for c in cluster_info.keys():
    amps[c] = spike_amps[cluster_info[c]]
    depths[c] = spike_depths[cluster_info[c]]

add_ragged_data_to_dynamic_table(
    table=nwb_file.units,