Example #1
0
 def test_electrode_group(self):
     """Adding a unit with an electrode_group stores that group in the table."""
     units_table = Units()
     group = ElectrodeGroup('test_electrode_group', 'description',
                            'location', Device('test_device'))
     units_table.add_unit(electrode_group=group)
     self.assertEqual(units_table['electrode_group'][0], group)
Example #2
0
 def setUpContainer(self):
     """ Return the test Units to read/write """
     units_table = Units(name='UnitsTest',
                         description='a simple table for testing Units')
     for times, intervals in (([0, 1, 2], [[0, 1], [2, 3]]),
                              ([3, 4, 5], [[2, 5], [6, 7]])):
         units_table.add_unit(spike_times=times, obs_intervals=intervals)
     return units_table
Example #3
0
 def test_add_spike_times(self):
     """Spike times of several units land in one ragged (indexed) column."""
     units_table = Units()
     for times in ([0, 1, 2], [3, 4, 5]):
         units_table.add_unit(spike_times=times)
     self.assertEqual(units_table.id.data, [0, 1])
     # Flat target data holds all spikes; index column marks per-unit ends.
     self.assertEqual(units_table['spike_times'].target.data,
                      [0, 1, 2, 3, 4, 5])
     self.assertEqual(units_table['spike_times'].data, [3, 6])
     self.assertEqual(units_table['spike_times'][0], [0, 1, 2])
     self.assertEqual(units_table['spike_times'][1], [3, 4, 5])
Example #4
0
 def test_get_obs_intervals(self):
     """get_unit_obs_intervals returns the per-unit interval arrays."""
     units_table = Units()
     units_table.add_unit(obs_intervals=[[0, 1]])
     units_table.add_unit(obs_intervals=[[2, 3], [4, 5]])
     for index, expected in ((0, np.array([[0, 1]])),
                             (1, np.array([[2, 3], [4, 5]]))):
         self.assertTrue(
             np.all(units_table.get_unit_obs_intervals(index) == expected))
Example #5
0
    def __init__(self, units: Units):
        """
        Visualize PSTH with the ability to select electrodes from clicking within the grid.

        Parameters
        ----------
        units : Units
            Units Table of an NWBFile.
        """
        super().__init__()
        self.electrodes = units.get_ancestor("NWBFile").electrodes

        self.tuning_curve = TuningCurveWidget(units)
        # First child of the tuning-curve widget carries the alignment controls.
        psth = self.tuning_curve.children[0]
        self.electrode_position_selector = ElectrodePositionSelector(
            self.electrodes,
            pre_alignment_window=[0, -psth.start_ft.value],
            post_alignment_window=[0, psth.end_ft.value],
            event_name=psth.trial_event_controller.value)
        self.electrode_position_selector.scatter.on_click(self.update_point)

        self.children = [self.tuning_curve, self.electrode_position_selector]
        # Re-render when the unit selection or any alignment parameter changes.
        psth.unit_controller.observe(self.handle_unit_controller, "value")
        for control in (psth.start_ft, psth.end_ft,
                        psth.trial_event_controller):
            control.observe(self.handle_response, "value")
Example #6
0
def raster_grid_widget(units: Units):
    """Build an interactive raster-grid view of unit spikes grouped by trial columns.

    Returns an HTML widget if the NWBFile has no trials table; otherwise a
    VBox holding the controls and the interactive figure output.
    """
    trials = units.get_ancestor('NWBFile').trials
    if trials is None:
        return widgets.HTML('No trials present')

    groups = infer_categorical_columns(trials)

    rows_controller = widgets.Dropdown(options=[None] + list(groups),
                                       description='rows: ',
                                       layout=Layout(width='95%'))
    cols_controller = widgets.Dropdown(options=[None] + list(groups),
                                       description='cols: ',
                                       layout=Layout(width='95%'))
    trial_event_controller = make_trial_event_controller(trials)
    unit_controller = int_controller(len(units['spike_times'].data) - 1)
    before_slider = widgets.FloatSlider(.5,
                                        min=0,
                                        max=5.,
                                        description='before (s)',
                                        continuous_update=False)
    after_slider = widgets.FloatSlider(2.,
                                       min=0,
                                       max=5.,
                                       description='after (s)',
                                       continuous_update=False)

    # Assemble the control column in the same order the controls were created.
    control_widgets = widgets.VBox(children=[
        rows_controller, cols_controller, trial_event_controller,
        unit_controller, before_slider, after_slider
    ])

    controls = {
        'units': fixed(units),
        'trials': fixed(trials),
        'index': unit_controller.children[0],
        'after': after_slider,
        'before': before_slider,
        'align_by': trial_event_controller,
        'rows_label': rows_controller,
        'cols_label': cols_controller
    }

    out_fig = widgets.interactive_output(raster_grid, controls)
    return widgets.VBox(children=[control_widgets, out_fig])
Example #7
0
 def test_obs_intervals(self):
     """Observation intervals round-trip through the ragged obs_intervals column."""
     units_table = Units()
     units_table.add_unit(obs_intervals=[[0, 1]])
     units_table.add_unit(obs_intervals=[[2, 3], [4, 5]])
     for index, expected in ((0, np.array([[0, 1]])),
                             (1, np.array([[2, 3], [4, 5]]))):
         self.assertTrue(
             np.all(units_table['obs_intervals'][index] == expected))
Example #8
0
 def test_times_and_intervals(self):
     """Spike times and observation intervals can be set together per unit."""
     units_table = Units()
     units_table.add_unit(spike_times=[0, 1, 2], obs_intervals=[[0, 2]])
     units_table.add_unit(spike_times=[3, 4, 5],
                          obs_intervals=[[2, 3], [4, 5]])
     self.assertTrue(
         all(units_table['spike_times'][0] == np.array([0, 1, 2])))
     self.assertTrue(
         all(units_table['spike_times'][1] == np.array([3, 4, 5])))
     self.assertTrue(
         np.all(units_table['obs_intervals'][0] == np.array([[0, 2]])))
     self.assertTrue(
         np.all(units_table['obs_intervals'][1] == np.array([[2, 3],
                                                             [4, 5]])))
Example #9
0
 def setUpContainer(self):
     """ Return the test Units to read/write """
     units_table = Units(name='UnitsTest',
                         description='a simple table for testing Units')
     # Both units share the same waveform statistics; only times/intervals vary.
     for times, intervals in (([0, 1, 2], [[0, 1], [2, 3]]),
                              ([3, 4, 5], [[2, 5], [6, 7]])):
         units_table.add_unit(spike_times=times,
                              obs_intervals=intervals,
                              waveform_mean=[1., 2., 3.],
                              waveform_sd=[4., 5., 6.])
     units_table.waveform_rate = 40000.
     units_table.resolution = 1 / 40000
     return units_table
Example #10
0
 def test_get_spike_times_interval():
     """get_unit_spike_times with an (in, out) window keeps only in-window spikes."""
     units_table = Units()
     units_table.add_unit(spike_times=[0, 1, 2])
     units_table.add_unit(spike_times=[3, 4, 5])
     for window, expected in (((.5, 3), [1, 2]), ((-.5, 1.1), [0, 1])):
         np.testing.assert_array_equal(
             units_table.get_unit_spike_times(0, window), expected)
Example #11
0
 def test_add_waveforms(self):
     """Per-electrode waveforms are stored through a doubly-indexed column."""
     units_table = Units()
     # wf1: 2 electrodes x 3 spikes x 3 samples each.
     wf1 = [[[1, 2, 3] for _spike in range(3)] for _elec in range(2)]
     # wf2: 3 electrodes x 4 spikes x 3 samples each.
     wf2 = [[[1, 2, 3] for _spike in range(4)] for _elec in range(3)]
     units_table.add_unit(waveforms=wf1)
     units_table.add_unit(waveforms=wf2)
     self.assertEqual(units_table.id.data, [0, 1])
     # Cumulative spike counts per electrode across both units.
     self.assertEqual(units_table['waveforms'].target.data,
                      [3, 6, 10, 14, 18])
     # Cumulative electrode counts per unit.
     self.assertEqual(units_table['waveforms'].data, [2, 5])
     self.assertListEqual(units_table['waveforms'][0], wf1)
     self.assertListEqual(units_table['waveforms'][1], wf2)
Example #12
0
 def setUpContainer(self):
     """ Return the test Units to read/write """
     units_table = Units(name='UnitsTest',
                         description='a simple table for testing Units')
     # First unit: waveforms as nested lists, 2 electrodes x 3 spikes x 3 samples.
     units_table.add_unit(
         spike_times=[0, 1, 2],
         obs_intervals=[[0, 1], [2, 3]],
         waveform_mean=[1., 2., 3.],
         waveform_sd=[4., 5., 6.],
         waveforms=[[[1, 2, 3] for _spike in range(3)]
                    for _elec in range(2)])
     # Second unit: waveforms as a numpy array, 3 electrodes x 4 spikes x 3 samples;
     # every spike waveform is the same [1, 2, 3] sample vector.
     units_table.add_unit(
         spike_times=[3, 4, 5],
         obs_intervals=[[2, 5], [6, 7]],
         waveform_mean=[1., 2., 3.],
         waveform_sd=[4., 5., 6.],
         waveforms=np.tile([1, 2, 3], (3, 4, 1)))
     units_table.waveform_rate = 40000.
     units_table.resolution = 1 / 40000
     return units_table
Example #13
0
 def setUpContainer(self):
     """Build a minimal two-unit Units table for round-trip testing."""
     units_table = Units(name='UnitsTest',
                         description='a simple table for testing Units')
     for times in ([0, 1, 2], [3, 4, 5]):
         units_table.add_unit(spike_times=times)
     return units_table
Example #14
0
def psth_widget(units: Units,
                unit_controller=None,
                after_slider=None,
                before_slider=None,
                trial_event_controller=None,
                trial_order_controller=None,
                trial_group_controller=None,
                sigma_in_secs=.05,
                ntt=1000):
    """Build an interactive PSTH figure with controls for unit, alignment and grouping.

    Parameters
    ----------
    units: pynwb.misc.Units
    unit_controller
    after_slider
    before_slider
    trial_event_controller
    trial_order_controller
    trial_group_controller
    sigma_in_secs: float
    ntt: int
        Number of timepoints to use for smoothed PSTH

    Returns
    -------
    widgets.Widget
        An HTML widget when no trials exist; otherwise a VBox of controls
        and the interactive figure output.
    """

    trials = units.get_ancestor('NWBFile').trials
    if trials is None:
        return widgets.HTML('No trials present')

    control_widgets = widgets.VBox(children=[])

    def _append_control(control):
        # VBox.children is a tuple; rebuild it with the new control appended.
        control_widgets.children = list(control_widgets.children) + [control]

    # Each control is created (and shown) only when the caller did not supply one.
    if unit_controller is None:
        nunits = len(units['spike_times'].data)
        unit_controller = widgets.Dropdown(options=list(range(nunits)),
                                           description='unit: ')
        _append_control(unit_controller)

    if trial_event_controller is None:
        trial_event_controller = make_trial_event_controller(trials)
        _append_control(trial_event_controller)

    if trial_order_controller is None:
        trials = units.get_ancestor('NWBFile').trials
        trial_order_controller = widgets.Dropdown(options=trials.colnames,
                                                  value='start_time',
                                                  description='order by: ')
        _append_control(trial_order_controller)

    if trial_group_controller is None:
        trials = units.get_ancestor('NWBFile').trials
        trial_group_controller = widgets.Dropdown(
            options=[None] + list(trials.colnames),
            description='group by')
        _append_control(trial_group_controller)

    if before_slider is None:
        before_slider = widgets.FloatSlider(.5,
                                            min=0,
                                            max=5.,
                                            description='before (s)',
                                            continuous_update=False)
        _append_control(before_slider)

    if after_slider is None:
        after_slider = widgets.FloatSlider(2.,
                                           min=0,
                                           max=5.,
                                           description='after (s)',
                                           continuous_update=False)
        _append_control(after_slider)

    controls = {
        'units': fixed(units),
        'sigma_in_secs': fixed(sigma_in_secs),
        'ntt': fixed(ntt),
        'index': unit_controller,
        'after': after_slider,
        'before': before_slider,
        'start_label': trial_event_controller,
        'order_by': trial_order_controller,
        'group_by': trial_group_controller
    }

    out_fig = widgets.interactive_output(trials_psth, controls)
    return widgets.VBox(children=[control_widgets, out_fig])
Example #15
0
 def test_get_spike_times_multi(self):
     """Passing a sequence of unit indices returns spike times for each unit."""
     units_table = Units()
     units_table.add_unit(spike_times=[0, 1, 2])
     units_table.add_unit(spike_times=[3, 4, 5])
     np.testing.assert_array_equal(
         units_table.get_unit_spike_times((0, 1)), [[0, 1, 2], [3, 4, 5]])
Example #16
0
 def test_get_spike_times_multi_interval(self):
     """Multi-unit spike-time queries honor the time-interval restriction."""
     units_table = Units()
     units_table.add_unit(spike_times=[0, 1, 2])
     units_table.add_unit(spike_times=[3, 4, 5])
     np.testing.assert_array_equal(
         units_table.get_unit_spike_times((0, 1), (1.5, 3.5)), [[2], [3]])
Example #17
0
def convert(
        input_file,
        session_start_time,
        subject_date_of_birth,
        subject_id='I5',
        subject_description='naive',
        subject_genotype='wild-type',
        subject_sex='M',
        subject_weight='11.6g',
        subject_species='Mus musculus',
        subject_brain_region='Medial Entorhinal Cortex',
        surgery='Probe: +/-3.3mm ML, 0.2mm A of sinus, then as deep as possible',
        session_id='npI5_0417_baseline_1',
        experimenter='Kei Masuda',
        experiment_description='Virtual Hallway Task',
        institution='Stanford University School of Medicine',
        lab_name='Giocomo Lab'):
    """
    Read in the .mat file specified by input_file and convert to .nwb format.

    The output .nwb file is written next to the input, with the last '.mat'
    suffix replaced by '.nwb'. Nothing is returned.

    Parameters
    ----------
    input_file : np.ndarray (..., n_channels, n_time)
        the .mat file to be converted
    subject_id : string
        the unique subject ID number for the subject of the experiment
    subject_date_of_birth : datetime ISO 8601
        the date and time the subject was born
    subject_description : string
        important information specific to this subject that differentiates it from other members of it's species
    subject_genotype : string
        the genetic strain of this species.
    subject_sex : string
        Male or Female
    subject_weight :
        the weight of the subject around the time of the experiment
    subject_species : string
        the name of the species of the subject
    subject_brain_region : basestring
        the name of the brain region where the electrode probe is recording from
    surgery : str
        information about the subject's surgery to implant electrodes
    session_id: string
        human-readable ID# for the experiment session that has a one-to-one relationship with a recording session
    session_start_time : datetime
        date and time that the experiment started
    experimenter : string
        who ran the experiment, first and last name
    experiment_description : string
        what task was being run during the session
    institution : string
        what institution was the experiment performed in
    lab_name : string
        the lab where the experiment was performed

    Returns
    -------
    nwbfile : NWBFile
        The contents of the .mat file converted into the NWB format.  The nwbfile is saved to disk using NDWHDF5
    """

    # input matlab data
    matfile = hdf5storage.loadmat(input_file)

    # output path for nwb data
    def replace_last(source_string, replace_what, replace_with):
        # Replace only the LAST occurrence of replace_what (rpartition splits
        # at the final match), so a '.mat' earlier in the path is untouched.
        head, _sep, tail = source_string.rpartition(replace_what)
        return head + replace_with + tail

    outpath = replace_last(input_file, '.mat', '.nwb')

    # Timestamp the file-creation date in the lab's local timezone.
    create_date = datetime.today()
    timezone_cali = pytz.timezone('US/Pacific')
    create_date_tz = timezone_cali.localize(create_date)

    # if loading data from config.yaml, convert string dates into datetime
    # (e.g. 'April 17, 2019 3:25PM'); datetimes passed directly are kept as-is.
    if isinstance(session_start_time, str):
        session_start_time = datetime.strptime(session_start_time,
                                               '%B %d, %Y %I:%M%p')
        session_start_time = timezone_cali.localize(session_start_time)

    if isinstance(subject_date_of_birth, str):
        subject_date_of_birth = datetime.strptime(subject_date_of_birth,
                                                  '%B %d, %Y %I:%M%p')
        subject_date_of_birth = timezone_cali.localize(subject_date_of_birth)

    # create unique identifier for this experimental session
    uuid_identifier = uuid.uuid1()

    # Create NWB file
    nwbfile = NWBFile(
        session_description=experiment_description,  # required
        identifier=uuid_identifier.hex,  # required
        session_id=session_id,
        experiment_description=experiment_description,
        experimenter=experimenter,
        surgery=surgery,
        institution=institution,
        lab=lab_name,
        session_start_time=session_start_time,  # required
        file_create_date=create_date_tz)  # optional

    # add information about the subject of the experiment
    experiment_subject = Subject(subject_id=subject_id,
                                 species=subject_species,
                                 description=subject_description,
                                 genotype=subject_genotype,
                                 date_of_birth=subject_date_of_birth,
                                 weight=subject_weight,
                                 sex=subject_sex)
    nwbfile.subject = experiment_subject

    # adding constants via LabMetaData container
    # constants
    # NOTE(review): hdf5storage wraps matlab scalars in nested 1-element
    # arrays; the repeated [0] indexing unwraps them — confirm against the
    # actual .mat layout if the loader version changes.
    sample_rate = float(matfile['sp'][0]['sample_rate'][0][0][0])
    n_channels_dat = int(matfile['sp'][0]['n_channels_dat'][0][0][0])
    dat_path = matfile['sp'][0]['dat_path'][0][0][0]
    offset = int(matfile['sp'][0]['offset'][0][0][0])
    data_dtype = matfile['sp'][0]['dtype'][0][0][0]
    hp_filtered = bool(matfile['sp'][0]['hp_filtered'][0][0][0])
    vr_session_offset = matfile['sp'][0]['vr_session_offset'][0][0][0]
    # container
    lab_metadata = LabMetaData_ext(name='LabMetaData',
                                   acquisition_sampling_rate=sample_rate,
                                   number_of_electrodes=n_channels_dat,
                                   file_path=dat_path,
                                   bytes_to_skip=offset,
                                   raw_data_dtype=data_dtype,
                                   high_pass_filtered=hp_filtered,
                                   movie_start_time=vr_session_offset)
    nwbfile.add_lab_meta_data(lab_metadata)

    # Adding trial information
    nwbfile.add_trial_column(
        'trial_contrast',
        'visual contrast of the maze through which the mouse is running')
    trial = np.ravel(matfile['trial'])
    trial_nums = np.unique(trial)
    position_time = np.ravel(matfile['post'])
    # matlab trial numbers start at 1. To correctly index trial_contract vector,
    # subtracting 1 from 'num' so index starts at 0
    for num in trial_nums:
        # Trial start/stop come from the first/last position timestamp of the trial.
        trial_times = position_time[trial == num]
        nwbfile.add_trial(start_time=trial_times[0],
                          stop_time=trial_times[-1],
                          trial_contrast=matfile['trial_contrast'][num - 1][0])

    # Add mouse position inside:
    position = Position()
    position_virtual = np.ravel(matfile['posx'])
    # position inside the virtual environment
    # Sampling rate inferred from the first two position timestamps;
    # assumes uniform sampling throughout the session.
    sampling_rate = 1 / (position_time[1] - position_time[0])
    position.create_spatial_series(
        name='Position',
        data=position_virtual,
        starting_time=position_time[0],
        rate=sampling_rate,
        reference_frame='The start of the trial, which begins at the start '
        'of the virtual hallway.',
        conversion=0.01,
        description='Subject position in the virtual hallway.',
        comments='The values should be >0 and <400cm. Values greater than '
        '400cm mean that the mouse briefly exited the maze.',
    )

    # physical position on the mouse wheel
    # NOTE(review): physical_posx aliases position_virtual (no copy), so the
    # in-place division below also mutates the array already handed to the
    # 'Position' series — verify this is intentional.
    physical_posx = position_virtual
    trial_gain = np.ravel(matfile['trial_gain'])
    for num in trial_nums:
        physical_posx[trial ==
                      num] = physical_posx[trial == num] / trial_gain[num - 1]

    position.create_spatial_series(
        name='PhysicalPosition',
        data=physical_posx,
        starting_time=position_time[0],
        rate=sampling_rate,
        reference_frame='Location on wheel re-referenced to zero '
        'at the start of each trial.',
        conversion=0.01,
        description='Physical location on the wheel measured '
        'since the beginning of the trial.',
        comments='Physical location found by dividing the '
        'virtual position by the "trial_gain"')
    nwbfile.add_acquisition(position)

    # Add timing of lick events, as well as mouse's virtual position during lick event
    lick_events = BehavioralEvents()
    lick_events.create_timeseries(
        'LickEvents',
        data=np.ravel(matfile['lickx']),
        timestamps=np.ravel(matfile['lickt']),
        unit='centimeter',
        description='Subject position in virtual hallway during the lick.')
    nwbfile.add_acquisition(lick_events)

    # Add information on the visual stimulus that was shown to the subject
    # Assumed rate=60 [Hz]. Update if necessary
    # Update external_file to link to Unity environment file
    visualization = ImageSeries(
        name='ImageSeries',
        unit='seconds',
        format='external',
        external_file=list(['https://unity.com/VR-and-AR-corner']),
        starting_time=vr_session_offset,
        starting_frame=[[0]],
        rate=float(60),
        description='virtual Unity environment that the mouse navigates through'
    )
    nwbfile.add_stimulus(visualization)

    # Add the recording device, a neuropixel probe
    recording_device = nwbfile.create_device(name='neuropixel_probes')
    electrode_group_description = 'single neuropixels probe http://www.open-ephys.org/neuropixelscorded'
    electrode_group_name = 'probe1'

    electrode_group = nwbfile.create_electrode_group(
        electrode_group_name,
        description=electrode_group_description,
        location=subject_brain_region,
        device=recording_device)

    # Add information about each electrode
    xcoords = np.ravel(matfile['sp'][0]['xcoords'][0])
    ycoords = np.ravel(matfile['sp'][0]['ycoords'][0])
    data_filtered_flag = matfile['sp'][0]['hp_filtered'][0][0]
    if data_filtered_flag:
        filter_desc = 'The raw voltage signals from the electrodes were high-pass filtered'
    else:
        filter_desc = 'The raw voltage signals from the electrodes were not high-pass filtered'

    num_recording_electrodes = xcoords.shape[0]
    recording_electrodes = range(0, num_recording_electrodes)

    # create electrode columns for the x,y location on the neuropixel  probe
    # the standard x,y,z locations are reserved for Allen Brain Atlas location
    nwbfile.add_electrode_column('rel_x', 'electrode x-location on the probe')
    nwbfile.add_electrode_column('rel_y', 'electrode y-location on the probe')

    for idx in recording_electrodes:
        # Allen-Atlas coordinates (x/y/z) and impedance are unknown -> NaN.
        nwbfile.add_electrode(id=idx,
                              x=np.nan,
                              y=np.nan,
                              z=np.nan,
                              rel_x=float(xcoords[idx]),
                              rel_y=float(ycoords[idx]),
                              imp=np.nan,
                              location='medial entorhinal cortex',
                              filtering=filter_desc,
                              group=electrode_group)

    # Add information about each unit, termed 'cluster' in giocomo data
    # create new columns in unit table
    nwbfile.add_unit_column(
        'quality',
        'labels given to clusters during manual sorting in phy (1=MUA, '
        '2=Good, 3=Unsorted)')

    # cluster information
    cluster_ids = matfile['sp'][0]['cids'][0][0]
    cluster_quality = matfile['sp'][0]['cgs'][0][0]
    # spikes in time
    spike_times = np.ravel(matfile['sp'][0]['st'][0])  # the time of each spike
    spike_cluster = np.ravel(
        matfile['sp'][0]['clu'][0])  # the cluster_id that spiked at that time

    for i, cluster_id in enumerate(cluster_ids):
        # Select this cluster's spikes by masking the global spike train.
        unit_spike_times = spike_times[spike_cluster == cluster_id]
        waveforms = matfile['sp'][0]['temps'][0][cluster_id]
        nwbfile.add_unit(id=int(cluster_id),
                         spike_times=unit_spike_times,
                         quality=cluster_quality[i],
                         waveform_mean=waveforms,
                         electrode_group=electrode_group)

    # Trying to add another Units table to hold the results of the automatic spike sorting
    # create TemplateUnits units table
    template_units = Units(
        name='TemplateUnits',
        description='units assigned during automatic spike sorting')
    template_units.add_column(
        'tempScalingAmps',
        'scaling amplitude applied to the template when extracting spike',
        index=True)

    # information on extracted spike templates
    spike_templates = np.ravel(matfile['sp'][0]['spikeTemplates'][0])
    spike_template_ids = np.unique(spike_templates)
    # template scaling amplitudes
    temp_scaling_amps = np.ravel(matfile['sp'][0]['tempScalingAmps'][0])

    for i, spike_template_id in enumerate(spike_template_ids):
        template_spike_times = spike_times[spike_templates ==
                                           spike_template_id]
        temp_scaling_amps_per_template = temp_scaling_amps[spike_templates ==
                                                           spike_template_id]
        template_units.add_unit(id=int(spike_template_id),
                                spike_times=template_spike_times,
                                electrode_group=electrode_group,
                                tempScalingAmps=temp_scaling_amps_per_template)

    # create ecephys processing module
    spike_template_module = nwbfile.create_processing_module(
        name='ecephys',
        description='units assigned during automatic spike sorting')

    # add template_units table to processing module
    spike_template_module.add(template_units)

    print(nwbfile)
    print('converted to NWB:N')
    print('saving ...')

    # Write the assembled NWBFile to disk as HDF5.
    with NWBHDF5IO(outpath, 'w') as io:
        io.write(nwbfile)
        print('saved', outpath)
def conversion_function(source_paths,
                        f_nwb,
                        metadata,
                        add_spikeglx=False,
                        add_processed=False):
    """
    Copy data stored in a set of .npz files to a single NWB file.

    Parameters
    ----------
    source_paths : dict
        Dictionary with paths to source files/directories. e.g.:
        {'spikeglx data': {'type': 'file', 'path': ''}
         'processed data': {'type': 'file', 'path': ''}}
    f_nwb : str
        Path to output NWB file, e.g. 'my_file.nwb'.
    metadata : dict
        Dictionary containing metadata. It is not modified by this function.
    add_spikeglx : bool
        If True, add the raw SpikeGLX acquisition data via the Spikeglx2NWB
        extractor, which also takes over writing the output file.
    add_processed : bool
        If True, add behavior and spike-sorting results read from the
        processed .mat file.
    """

    # Resolve source file paths; an empty path means "not provided".
    npx_file_path = None
    mat_file_path = None
    for k, v in source_paths.items():
        if v['path'] != '':
            if k == 'spikeglx data':
                npx_file_path = v['path']
            if k == 'processed data':
                mat_file_path = v['path']

    # Remove lab_meta_data from metadata, it will be added later.
    # Deep copy so the caller's metadata dict is left untouched.
    metadata0 = copy.deepcopy(metadata)
    metadata0['NWBFile'].pop('lab_meta_data', None)

    # Create nwb
    nwbfile = pynwb.NWBFile(**metadata0['NWBFile'])

    # If adding processed data
    if add_processed:
        # Source matlab data
        matfile = hdf5storage.loadmat(mat_file_path)

        # Adding trial information
        nwbfile.add_trial_column(
            name='trial_contrast',
            description=
            'visual contrast of the maze through which the mouse is running')
        trial = np.ravel(matfile['trial'])
        trial_nums = np.unique(trial)
        position_time = np.ravel(matfile['post'])
        # matlab trial numbers start at 1. To correctly index the
        # trial_contrast vector, subtract 1 from 'num' so the index starts at 0
        for num in trial_nums:
            trial_times = position_time[trial == num]
            nwbfile.add_trial(start_time=trial_times[0],
                              stop_time=trial_times[-1],
                              trial_contrast=matfile['trial_contrast'][num -
                                                                       1][0])

        # create behavior processing module
        behavior = nwbfile.create_processing_module(
            name='behavior', description='behavior processing module')

        # Add mouse position
        position = Position(name=metadata['Behavior']['Position']['name'])
        meta_pos_names = [
            sps['name']
            for sps in metadata['Behavior']['Position']['spatial_series']
        ]

        # Position inside the virtual environment
        pos_vir_meta_ind = meta_pos_names.index('VirtualPosition')
        meta_vir = metadata['Behavior']['Position']['spatial_series'][
            pos_vir_meta_ind]
        position_virtual = np.ravel(matfile['posx'])
        # assumes uniform sampling -- rate derived from the first interval
        sampling_rate = 1 / (position_time[1] - position_time[0])
        position.create_spatial_series(
            name=meta_vir['name'],
            data=position_virtual,
            starting_time=position_time[0],
            rate=sampling_rate,
            reference_frame=meta_vir['reference_frame'],
            conversion=meta_vir['conversion'],
            description=meta_vir['description'],
            comments=meta_vir['comments'])

        # Physical position on the mouse wheel
        pos_phys_meta_ind = meta_pos_names.index('PhysicalPosition')
        meta_phys = metadata['Behavior']['Position']['spatial_series'][
            pos_phys_meta_ind]
        # BUGFIX: copy before the in-place gain correction below. The original
        # code aliased position_virtual, so dividing by trial_gain also
        # corrupted the data already attached to the VirtualPosition series.
        physical_posx = position_virtual.copy()
        trial_gain = np.ravel(matfile['trial_gain'])
        for num in trial_nums:
            physical_posx[trial == num] = physical_posx[
                trial == num] / trial_gain[num - 1]
        position.create_spatial_series(
            name=meta_phys['name'],
            data=physical_posx,
            starting_time=position_time[0],
            rate=sampling_rate,
            reference_frame=meta_phys['reference_frame'],
            conversion=meta_phys['conversion'],
            description=meta_phys['description'],
            comments=meta_phys['comments'])

        behavior.add(position)

        # Add timing of lick events, as well as mouse's virtual position
        # during lick event. Shallow-copy the time_series metadata so the
        # injected data/timestamps arrays do not mutate the caller's dict.
        lick_events = BehavioralEvents(
            name=metadata['Behavior']['BehavioralEvents']['name'])
        meta_ts = dict(metadata['Behavior']['BehavioralEvents']['time_series'])
        meta_ts['data'] = np.ravel(matfile['lickx'])
        meta_ts['timestamps'] = np.ravel(matfile['lickt'])
        lick_events.create_timeseries(**meta_ts)

        behavior.add(lick_events)

        # Add the recording device, a neuropixel probe
        recording_device = nwbfile.create_device(
            name=metadata['Ecephys']['Device'][0]['name'])

        # Add ElectrodeGroup
        electrode_group = nwbfile.create_electrode_group(
            name=metadata['Ecephys']['ElectrodeGroup'][0]['name'],
            description=metadata['Ecephys']['ElectrodeGroup'][0]
            ['description'],
            location=metadata['Ecephys']['ElectrodeGroup'][0]['location'],
            device=recording_device)

        # Add information about each electrode
        xcoords = np.ravel(matfile['sp'][0]['xcoords'][0])
        ycoords = np.ravel(matfile['sp'][0]['ycoords'][0])
        # NOTE(review): data_filtered_flag is read from the .mat file but not
        # used; the filter description below trusts the metadata flag instead.
        # Confirm the two sources agree.
        data_filtered_flag = matfile['sp'][0]['hp_filtered'][0][0]
        if metadata['NWBFile']['lab_meta_data']['high_pass_filtered']:
            filter_desc = 'The raw voltage signals from the electrodes were high-pass filtered'
        else:
            filter_desc = 'The raw voltage signals from the electrodes were not high-pass filtered'
        num_recording_electrodes = xcoords.shape[0]
        recording_electrodes = range(0, num_recording_electrodes)

        # create electrode columns for the x,y location on the neuropixel probe
        # the standard x,y,z locations are reserved for Allen Brain Atlas location
        nwbfile.add_electrode_column('relativex',
                                     'electrode x-location on the probe')
        nwbfile.add_electrode_column('relativey',
                                     'electrode y-location on the probe')
        for idx in recording_electrodes:
            nwbfile.add_electrode(id=idx,
                                  x=np.nan,
                                  y=np.nan,
                                  z=np.nan,
                                  relativex=float(xcoords[idx]),
                                  relativey=float(ycoords[idx]),
                                  imp=np.nan,
                                  location='medial entorhinal cortex',
                                  filtering=filter_desc,
                                  group=electrode_group)

        # Add information about each unit, termed 'cluster' in giocomo data
        # create new columns in unit table
        nwbfile.add_unit_column(
            name='quality',
            description='labels given to clusters during manual sorting in phy '
            '(1=MUA, 2=Good, 3=Unsorted)')

        # cluster information
        cluster_ids = matfile['sp'][0]['cids'][0][0]
        cluster_quality = matfile['sp'][0]['cgs'][0][0]
        # spikes in time
        spike_times = np.ravel(
            matfile['sp'][0]['st'][0])  # the time of each spike
        spike_cluster = np.ravel(
            matfile['sp'][0]['clu']
            [0])  # the cluster_id that spiked at that time
        for i, cluster_id in enumerate(cluster_ids):
            unit_spike_times = spike_times[spike_cluster == cluster_id]
            waveforms = matfile['sp'][0]['temps'][0][cluster_id]
            nwbfile.add_unit(id=int(cluster_id),
                             spike_times=unit_spike_times,
                             quality=cluster_quality[i],
                             waveform_mean=waveforms,
                             electrode_group=electrode_group)

        # A second Units table holds the results of the automatic spike
        # sorting (the manual curation lives in nwbfile.units above).
        template_units = Units(
            name='TemplateUnits',
            description='units assigned during automatic spike sorting')
        template_units.add_column(
            name='tempScalingAmps',
            description=
            'scaling amplitude applied to the template when extracting spike',
            index=True)
        # information on extracted spike templates
        spike_templates = np.ravel(matfile['sp'][0]['spikeTemplates'][0])
        spike_template_ids = np.unique(spike_templates)
        # template scaling amplitudes
        temp_scaling_amps = np.ravel(matfile['sp'][0]['tempScalingAmps'][0])
        for spike_template_id in spike_template_ids:
            template_mask = spike_templates == spike_template_id
            template_units.add_unit(
                id=int(spike_template_id),
                spike_times=spike_times[template_mask],
                electrode_group=electrode_group,
                tempScalingAmps=temp_scaling_amps[template_mask])

        # create ecephys processing module
        spike_template_module = nwbfile.create_processing_module(
            name='ecephys',
            description='units assigned during automatic spike sorting')
        # add template_units table to processing module
        spike_template_module.add(template_units)

    # Add other fields
    # Add lab_meta_data
    if 'lab_meta_data' in metadata['NWBFile']:
        lmd = metadata['NWBFile']['lab_meta_data']
        lab_metadata = LabMetaData_ext(
            name=lmd['name'],
            acquisition_sampling_rate=lmd['acquisition_sampling_rate'],
            number_of_electrodes=lmd['number_of_electrodes'],
            file_path=lmd['file_path'],
            bytes_to_skip=lmd['bytes_to_skip'],
            raw_data_dtype=lmd['raw_data_dtype'],
            high_pass_filtered=lmd['high_pass_filtered'],
            movie_start_time=lmd['movie_start_time'],
            subject_brain_region=lmd['subject_brain_region'])
        nwbfile.add_lab_meta_data(lab_metadata)

    # add information about the subject of the experiment
    if 'Subject' in metadata:
        subject_meta = metadata['Subject']
        experiment_subject = Subject(
            subject_id=subject_meta['subject_id'],
            species=subject_meta['species'],
            description=subject_meta['description'],
            genotype=subject_meta['genotype'],
            date_of_birth=subject_meta['date_of_birth'],
            weight=subject_meta['weight'],
            sex=subject_meta['sex'])
        nwbfile.subject = experiment_subject

    # If adding SpikeGLX data
    if add_spikeglx:
        # Create extractor for SpikeGLX data
        extractor = Spikeglx2NWB(nwbfile=nwbfile,
                                 metadata=metadata0,
                                 npx_file=npx_file_path)
        # Add acquisition data
        extractor.add_acquisition(es_name='ElectricalSeries',
                                  metadata=metadata['Ecephys'])
        # Run spike sorting method
        #extractor.run_spike_sorting()
        # Save content to NWB file; the extractor owns the write here.
        extractor.save(to_path=f_nwb)
    else:
        # Write to nwb file
        with pynwb.NWBHDF5IO(f_nwb, 'w') as io:
            io.write(nwbfile)
            print(nwbfile)

    # Check file was saved and inform on screen
    print('File saved at:')
    print(f_nwb)
    print('Size: ', os.stat(f_nwb).st_size / 1e6, ' mb')
Exemple #19
0
 def test_times(self):
     """Per-unit spike times should round-trip through ragged indexing."""
     units_table = Units()
     for times in ([0, 1, 2], [3, 4, 5]):
         units_table.add_unit(spike_times=times)
     self.assertTrue(all(units_table['spike_times'][0] == np.array([0, 1, 2])))
     self.assertTrue(all(units_table['spike_times'][1] == np.array([3, 4, 5])))
Exemple #20
0
 def test_get_spike_times(self):
     """get_unit_spike_times should return the times added for each row."""
     units_table = Units()
     for times in ([0, 1, 2], [3, 4, 5]):
         units_table.add_unit(spike_times=times)
     self.assertTrue(all(units_table.get_unit_spike_times(0) == np.array([0, 1, 2])))
     self.assertTrue(all(units_table.get_unit_spike_times(1) == np.array([3, 4, 5])))
Exemple #21
0
 def test_waveform_attrs(self):
     """waveform_rate passed to the constructor is stored; unit defaults to volts."""
     units_table = Units(waveform_rate=40000.)
     self.assertEqual(units_table.waveform_rate, 40000.)
     self.assertEqual(units_table.waveform_unit, 'volts')
Exemple #22
0
 def setUpContainer(self):
     """ Return placeholder Units object. Tested units are added directly to the NWBFile in addContainer """
     # the placeholder table itself is ignored by the round-trip test
     placeholder_table = Units('placeholder')
     return placeholder_table
Exemple #23
0
 def setUpContainer(self):
     """Return a placeholder Units table (ignored by the round-trip test)."""
     placeholder_table = Units('placeholder_units')
     return placeholder_table
Exemple #24
0
 def test_init(self):
     """A default-constructed table is named 'Units' and starts with no columns."""
     units_table = Units()
     self.assertEqual(units_table.name, 'Units')
     self.assertFalse(units_table.columns)