Example 1
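The imports are not shown in this snippet; the "# noqa: F405" markers suggest the original module relied on star imports. Assuming a pynwb release in which NWBFile no longer takes a source argument, an explicit equivalent might be:

import os
import unittest
from datetime import datetime

import numpy as np
import six
from dateutil.tz import tzlocal, tzutc

from pynwb import NWBFile, NWBHDF5IO, TimeSeries
from pynwb.file import ElectrodeTable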
class NWBFileTest(unittest.TestCase):
    def setUp(self):
        self.start = datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())
        self.ref_time = datetime(1979, 1, 1, 0, tzinfo=tzutc())
        self.create = [
            datetime(2017, 5, 1, 12, tzinfo=tzlocal()),
            datetime(2017, 5, 2, 13, 0, 0, 1, tzinfo=tzutc()),
            datetime(2017, 5, 2, 14, tzinfo=tzutc())
        ]
        self.path = 'nwbfile_test.h5'
        self.nwbfile = NWBFile(
            'a test session description for a test NWBFile',
            'FILE123',
            self.start,
            file_create_date=self.create,
            timestamps_reference_time=self.ref_time,
            experimenter='A test experimenter',
            lab='a test lab',
            institution='a test institution',
            experiment_description='a test experiment description',
            session_id='test1',
            notes='my notes',
            pharmacology='drugs',
            protocol='protocol',
            related_publications='my pubs',
            slices='my slices',
            surgery='surgery',
            virus='a virus',
            source_script='noscript',
            source_script_file_name='nofilename',
            stimulus_notes='test stimulus notes',
            data_collection='test data collection notes',
            keywords=('these', 'are', 'keywords'))

    def test_constructor(self):
        self.assertEqual(self.nwbfile.session_description,
                         'a test session description for a test NWBFile')
        self.assertEqual(self.nwbfile.identifier, 'FILE123')
        self.assertEqual(self.nwbfile.session_start_time, self.start)
        self.assertEqual(self.nwbfile.file_create_date, self.create)
        self.assertEqual(self.nwbfile.lab, 'a test lab')
        self.assertEqual(self.nwbfile.experimenter, 'A test experimenter')
        self.assertEqual(self.nwbfile.institution, 'a test institution')
        self.assertEqual(self.nwbfile.experiment_description,
                         'a test experiment description')
        self.assertEqual(self.nwbfile.session_id, 'test1')
        self.assertEqual(self.nwbfile.stimulus_notes, 'test stimulus notes')
        self.assertEqual(self.nwbfile.data_collection,
                         'test data collection notes')
        self.assertEqual(self.nwbfile.source_script, 'noscript')
        self.assertEqual(self.nwbfile.source_script_file_name, 'nofilename')
        self.assertEqual(self.nwbfile.keywords, ('these', 'are', 'keywords'))
        self.assertEqual(self.nwbfile.timestamps_reference_time, self.ref_time)

    def test_create_electrode_group(self):
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        d = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, d)
        self.assertEqual(elecgrp.description, desc)
        self.assertEqual(elecgrp.location, loc)
        self.assertIs(elecgrp.device, d)

    def test_create_electrode_group_invalid_index(self):
        """
        Test that creating an electrode table region with indices that are out
        of range for the number of electrodes added raises an error.
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i)
        with self.assertRaises(IndexError):
            nwbfile.create_electrode_table_region(list(range(6)), 'test')

    def test_access_group_after_io(self):
        """
        Motivated by #739
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        nwbfile.add_electrode(np.nan,
                              np.nan,
                              np.nan,
                              np.nan,
                              'a',
                              'a',
                              elecgrp,
                              id=0)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i + 1)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        os.remove("electrodes_mwe.nwb")

    def test_epoch_tags(self):
        tags1 = ['t1', 't2']
        tags2 = ['t3', 't4']
        tstamps = np.arange(1.0, 100.0, 0.1, dtype=float)  # np.float was removed in NumPy 1.24
        ts = TimeSeries("test_ts",
                        list(range(len(tstamps))),
                        'unit',
                        timestamps=tstamps)
        expected_tags = tags1 + tags2
        self.nwbfile.add_epoch(0.0, 1.0, tags1, ts)
        self.nwbfile.add_epoch(0.0, 1.0, tags2, ts)
        tags = self.nwbfile.epoch_tags
        six.assertCountEqual(self, expected_tags, tags)

    def test_add_acquisition(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.acquisition), 1)

    def test_add_stimulus(self):
        self.nwbfile.add_stimulus(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus), 1)

    def test_add_stimulus_template(self):
        self.nwbfile.add_stimulus_template(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus_template), 1)

    def test_add_acquisition_check_dups(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        with self.assertRaises(ValueError):
            self.nwbfile.add_acquisition(
                TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                           'grams',
                           timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))

    def test_get_acquisition_empty(self):
        with self.assertRaisesRegex(ValueError,
                                    "acquisition of NWBFile 'root' is empty"):
            self.nwbfile.get_acquisition()

    def test_get_acquisition_multiple_elements(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "more than one element in acquisition of NWBFile 'root' -- must specify a name"
        with self.assertRaisesRegex(ValueError, msg):
            self.nwbfile.get_acquisition()

    def test_add_acquisition_invalid_name(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "'TEST_TS' not found in acquisition of NWBFile 'root'"
        with self.assertRaisesRegex(KeyError, msg):
            self.nwbfile.get_acquisition("TEST_TS")

    def test_set_electrode_table(self):
        table = ElectrodeTable()  # noqa: F405
        dev1 = self.nwbfile.create_device('dev1')  # noqa: F405
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-1.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-2.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-3.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-4.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        self.nwbfile.set_electrode_table(table)

        self.assertIs(self.nwbfile.ec_electrodes, self.nwbfile.electrodes)
        self.assertIs(self.nwbfile.electrodes, table)
        self.assertIs(table.parent, self.nwbfile)

    def test_add_unit_column(self):
        self.nwbfile.add_unit_column('unit_type', 'the type of unit')
        self.assertEqual(self.nwbfile.units.colnames, ('unit_type', ))

    def test_add_unit(self):
        self.nwbfile.add_unit(id=1)
        self.assertEqual(len(self.nwbfile.units), 1)
        self.nwbfile.add_unit(id=2)
        self.nwbfile.add_unit(id=3)
        self.assertEqual(len(self.nwbfile.units), 3)

    def test_add_trial_column(self):
        self.nwbfile.add_trial_column('trial_type', 'the type of trial')
        self.assertEqual(self.nwbfile.trials.colnames,
                         ('start_time', 'stop_time', 'trial_type'))

    def test_add_trial(self):
        self.nwbfile.add_trial(start_time=10.0, stop_time=20.0)
        self.assertEqual(len(self.nwbfile.trials), 1)
        self.nwbfile.add_trial(start_time=30.0, stop_time=40.0)
        self.nwbfile.add_trial(start_time=50.0, stop_time=70.0)
        self.assertEqual(len(self.nwbfile.trials), 3)

    def test_add_invalid_times_column(self):
        self.nwbfile.add_invalid_times_column(
            'comments', 'description of reason for omitting time')
        self.assertEqual(self.nwbfile.invalid_times.colnames,
                         ('start_time', 'stop_time', 'comments'))

    def test_add_invalid_time_interval(self):

        self.nwbfile.add_invalid_time_interval(start_time=0.0, stop_time=12.0)
        self.assertEqual(len(self.nwbfile.invalid_times), 1)
        self.nwbfile.add_invalid_time_interval(start_time=15.0, stop_time=16.0)
        self.nwbfile.add_invalid_time_interval(start_time=17.0, stop_time=20.5)
        self.assertEqual(len(self.nwbfile.invalid_times), 3)

    def test_add_invalid_time_w_ts(self):
        ts = TimeSeries(name='name', data=[1.2], rate=1.0, unit='na')
        self.nwbfile.add_invalid_time_interval(start_time=18.0,
                                               stop_time=20.6,
                                               timeseries=ts,
                                               tags=('hi', 'there'))

    def test_add_electrode(self):
        dev1 = self.nwbfile.create_device('dev1')  # noqa: F405
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        self.nwbfile.add_electrode(1.0,
                                   2.0,
                                   3.0,
                                   -1.0,
                                   'CA1',
                                   'none',
                                   group=group,
                                   id=1)
        self.assertEqual(self.nwbfile.electrodes[0][0], 1)
        self.assertEqual(self.nwbfile.electrodes[0][1], 1.0)
        self.assertEqual(self.nwbfile.electrodes[0][2], 2.0)
        self.assertEqual(self.nwbfile.electrodes[0][3], 3.0)
        self.assertEqual(self.nwbfile.electrodes[0][4], -1.0)
        self.assertEqual(self.nwbfile.electrodes[0][5], 'CA1')
        self.assertEqual(self.nwbfile.electrodes[0][6], 'none')
        self.assertEqual(self.nwbfile.electrodes[0][7], group)

    def test_all_children(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        device = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, device)
        children = self.nwbfile.all_children()
        self.assertIn(ts1, children)
        self.assertIn(ts2, children)
        self.assertIn(device, children)
        self.assertIn(elecgrp, children)

    def test_fail_if_source_script_file_name_without_source_script(self):
        with self.assertRaises(ValueError):
            # <-- source_script_file_name without source_script is not allowed
            NWBFile('a test session description for a test NWBFile',
                    'FILE123',
                    self.start,
                    source_script=None,
                    source_script_file_name='nofilename')

    def test_get_neurodata_type(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        p1 = ts1.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p1, self.nwbfile)
        p2 = ts2.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p2, self.nwbfile)
Example 2
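The imports are again omitted. This example targets an older pynwb API in which constructors took a source argument, so the exact import locations are an assumption; a plausible set covering the names used below is:

import unittest
from datetime import datetime

import numpy as np
import six

from pynwb import NWBFile, TimeSeries
from pynwb.file import ElectrodeTable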
class NWBFileTest(unittest.TestCase):
    def setUp(self):
        self.start = datetime(2017, 5, 1, 12, 0, 0)
        self.path = 'nwbfile_test.h5'
        self.nwbfile = NWBFile(
            'a fake source',
            'a test session description for a test NWBFile',
            'FILE123',
            self.start,
            experimenter='A test experimenter',
            lab='a test lab',
            institution='a test institution',
            experiment_description='a test experiment description',
            session_id='test1',
            notes='my notes',
            pharmacology='drugs',
            protocol='protocol',
            related_publications='my pubs',
            slices='my slices',
            surgery='surgery',
            virus='a virus',
            source_script='noscript',
            source_script_file_name='nofilename',
            stimulus_notes='test stimulus notes',
            data_collection='test data collection notes')

    def test_constructor(self):
        self.assertEqual(self.nwbfile.session_description,
                         'a test session description for a test NWBFile')
        self.assertEqual(self.nwbfile.identifier, 'FILE123')
        self.assertEqual(self.nwbfile.session_start_time, self.start)
        self.assertEqual(self.nwbfile.lab, 'a test lab')
        self.assertEqual(self.nwbfile.experimenter, 'A test experimenter')
        self.assertEqual(self.nwbfile.institution, 'a test institution')
        self.assertEqual(self.nwbfile.experiment_description,
                         'a test experiment description')
        self.assertEqual(self.nwbfile.session_id, 'test1')
        self.assertEqual(self.nwbfile.stimulus_notes, 'test stimulus notes')
        self.assertEqual(self.nwbfile.data_collection,
                         'test data collection notes')
        self.assertEqual(self.nwbfile.source_script, 'noscript')
        self.assertEqual(self.nwbfile.source_script_file_name, 'nofilename')

    def test_create_electrode_group(self):
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        d = self.nwbfile.create_device('a fake device', 'a fake source')
        elecgrp = self.nwbfile.create_electrode_group(name, 'a fake source',
                                                      desc, loc, d)
        self.assertEqual(elecgrp.description, desc)
        self.assertEqual(elecgrp.location, loc)
        self.assertIs(elecgrp.device, d)

    def test_create_electrode_group_invalid_index(self):
        """
        Test that creating an electrode table region with indices that are out
        of range for the number of electrodes added raises an error.
        """
        nwbfile = NWBFile('a', 'b', 'c', datetime.now())
        device = nwbfile.create_device('a', 'b')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 'c',
                                                 device=device,
                                                 location='a')
        for i in range(4):
            nwbfile.add_electrode(i,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  group=elecgrp,
                                  location='a',
                                  filtering='a',
                                  description='a')
        with self.assertRaises(IndexError) as err:
            nwbfile.create_electrode_table_region(list(range(6)), 'test')
        self.assertIn('out of range', str(err.exception))

    def test_epoch_tags(self):
        tags1 = ['t1', 't2']
        tags2 = ['t3', 't4']
        tstamps = np.arange(1.0, 100.0, 0.1, dtype=float)  # np.float was removed in NumPy 1.24
        ts = TimeSeries("test_ts",
                        "a hypothetical source",
                        list(range(len(tstamps))),
                        'unit',
                        timestamps=tstamps)
        expected_tags = tags1 + tags2
        self.nwbfile.create_epoch('a fake epoch', 0.0, 1.0, tags1, ts)
        self.nwbfile.create_epoch('a second fake epoch', 0.0, 1.0, tags2, ts)
        tags = self.nwbfile.epoch_tags
        six.assertCountEqual(self, expected_tags, tags)

    def test_add_acquisition(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.acquisition), 1)

    def test_add_stimulus(self):
        self.nwbfile.add_stimulus(
            TimeSeries('test_ts',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus), 1)

    def test_add_stimulus_template(self):
        self.nwbfile.add_stimulus_template(
            TimeSeries('test_ts',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus_template), 1)

    def test_add_acquisition_check_dups(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        with self.assertRaises(ValueError):
            self.nwbfile.add_acquisition(
                TimeSeries('test_ts',
                           'unit test test_add_acquisition',
                           [0, 1, 2, 3, 4, 5],
                           'grams',
                           timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))

    def test_get_acquisition_empty(self):
        with self.assertRaisesRegex(ValueError,
                                    "acquisition of NWBFile 'root' is empty"):
            self.nwbfile.get_acquisition()

    def test_get_acquisition_multiple_elements(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts1',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts2',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "more than one element in acquisition of NWBFile 'root' -- must specify a name"
        with self.assertRaisesRegex(ValueError, msg):
            self.nwbfile.get_acquisition()

    def test_add_acquisition_invalid_name(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts',
                       'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "'TEST_TS' not found in acquisition of NWBFile 'root'"
        with self.assertRaisesRegex(KeyError, msg):
            self.nwbfile.get_acquisition("TEST_TS")

    def test_set_electrode_table(self):
        table = ElectrodeTable('test_table')  # noqa: F405
        dev1 = self.nwbfile.create_device('dev1',
                                          'a test source')  # noqa: F405
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'a test source',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        table.add_row(1, 1.0, 2.0, 3.0, -1.0, 'CA1', 'none',
                      'first channel of tetrode', group)
        table.add_row(2, 1.0, 2.0, 3.0, -2.0, 'CA1', 'none',
                      'second channel of tetrode', group)
        table.add_row(3, 1.0, 2.0, 3.0, -3.0, 'CA1', 'none',
                      'third channel of tetrode', group)
        table.add_row(4, 1.0, 2.0, 3.0, -4.0, 'CA1', 'none',
                      'fourth channel of tetrode', group)
        self.nwbfile.set_electrode_table(table)
        self.assertIs(self.nwbfile.ec_electrodes, table)
        self.assertIs(table.parent, self.nwbfile)

    def test_add_unit_column(self):
        self.nwbfile.add_unit_column('unit_type', 'the type of unit')
        self.assertEqual(self.nwbfile.units.colnames, ('id', 'unit_type'))

    def test_add_unit(self):
        self.nwbfile.add_unit({'id': 1})
        self.assertEqual(len(self.nwbfile.units), 1)
        self.nwbfile.add_unit({'id': 2})
        self.nwbfile.add_unit({'id': 3})
        self.assertEqual(len(self.nwbfile.units), 3)

    def test_add_trial_column(self):
        self.nwbfile.add_trial_column('trial_type', 'the type of trial')
        self.assertEqual(self.nwbfile.trials.colnames,
                         ('start', 'end', 'trial_type'))

    def test_add_trial(self):
        self.nwbfile.add_trial({'start': 10, 'end': 20})
        self.assertEqual(len(self.nwbfile.trials), 1)
        self.nwbfile.add_trial({'start': 30, 'end': 40})
        self.nwbfile.add_trial({'start': 50, 'end': 70})
        self.assertEqual(len(self.nwbfile.trials), 3)

    def test_add_electrode(self):
        dev1 = self.nwbfile.create_device('dev1',
                                          'a test source')  # noqa: F405
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'a test source',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        self.nwbfile.add_electrode(1, 1.0, 2.0, 3.0, -1.0, 'CA1', 'none',
                                   'first channel of tetrode', group)
        self.assertEqual(self.nwbfile.ec_electrodes[0][0], 1)
        self.assertEqual(self.nwbfile.ec_electrodes[0][1], 1.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][2], 2.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][3], 3.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][4], -1.0)
        self.assertEqual(self.nwbfile.ec_electrodes[0][5], 'CA1')
        self.assertEqual(self.nwbfile.ec_electrodes[0][6], 'none')
        self.assertEqual(self.nwbfile.ec_electrodes[0][7],
                         'first channel of tetrode')
        self.assertEqual(self.nwbfile.ec_electrodes[0][8], group)

    def test_all_children(self):
        ts1 = TimeSeries('test_ts1',
                         'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2',
                         'unit test test_add_acquisition', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        device = self.nwbfile.create_device('a fake device', 'a fake source')
        elecgrp = self.nwbfile.create_electrode_group(name, 'a fake source',
                                                      desc, loc, device)
        children = self.nwbfile.all_children()
        self.assertIn(ts1, children)
        self.assertIn(ts2, children)
        self.assertIn(device, children)
        self.assertIn(elecgrp, children)

    def test_fail_if_source_script_file_name_without_source_script(self):
        with self.assertRaises(ValueError):
            # <-- source_script_file_name without source_script is not allowed
            NWBFile('a fake source',
                    'a test session description for a test NWBFile',
                    'FILE123',
                    self.start,
                    source_script=None,
                    source_script_file_name='nofilename')
Example 3
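The imports are omitted here as well. Assuming a recent pynwb release that ships the pynwb.testing helpers, the names used below could be imported as:

from datetime import datetime

import numpy as np
import pandas as pd
from dateutil.tz import tzlocal, tzutc

from pynwb import NWBFile, NWBHDF5IO, TimeSeries
from pynwb.ecephys import ElectricalSeries
from pynwb.epoch import TimeIntervals
from pynwb.file import ElectrodeTable
from pynwb.testing import TestCase, remove_test_file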
class NWBFileTest(TestCase):
    def setUp(self):
        self.start = datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())
        self.ref_time = datetime(1979, 1, 1, 0, tzinfo=tzutc())
        self.create = [
            datetime(2017, 5, 1, 12, tzinfo=tzlocal()),
            datetime(2017, 5, 2, 13, 0, 0, 1, tzinfo=tzutc()),
            datetime(2017, 5, 2, 14, tzinfo=tzutc())
        ]
        self.path = 'nwbfile_test.h5'
        self.nwbfile = NWBFile(
            'a test session description for a test NWBFile',
            'FILE123',
            self.start,
            file_create_date=self.create,
            timestamps_reference_time=self.ref_time,
            experimenter='A test experimenter',
            lab='a test lab',
            institution='a test institution',
            experiment_description='a test experiment description',
            session_id='test1',
            notes='my notes',
            pharmacology='drugs',
            protocol='protocol',
            related_publications='my pubs',
            slices='my slices',
            surgery='surgery',
            virus='a virus',
            source_script='noscript',
            source_script_file_name='nofilename',
            stimulus_notes='test stimulus notes',
            data_collection='test data collection notes',
            keywords=('these', 'are', 'keywords'))

    def test_constructor(self):
        self.assertEqual(self.nwbfile.session_description,
                         'a test session description for a test NWBFile')
        self.assertEqual(self.nwbfile.identifier, 'FILE123')
        self.assertEqual(self.nwbfile.session_start_time, self.start)
        self.assertEqual(self.nwbfile.file_create_date, self.create)
        self.assertEqual(self.nwbfile.lab, 'a test lab')
        self.assertEqual(self.nwbfile.experimenter, ('A test experimenter', ))
        self.assertEqual(self.nwbfile.institution, 'a test institution')
        self.assertEqual(self.nwbfile.experiment_description,
                         'a test experiment description')
        self.assertEqual(self.nwbfile.session_id, 'test1')
        self.assertEqual(self.nwbfile.stimulus_notes, 'test stimulus notes')
        self.assertEqual(self.nwbfile.data_collection,
                         'test data collection notes')
        self.assertEqual(self.nwbfile.related_publications, ('my pubs', ))
        self.assertEqual(self.nwbfile.source_script, 'noscript')
        self.assertEqual(self.nwbfile.source_script_file_name, 'nofilename')
        self.assertEqual(self.nwbfile.keywords, ('these', 'are', 'keywords'))
        self.assertEqual(self.nwbfile.timestamps_reference_time, self.ref_time)

    def test_create_electrode_group(self):
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        d = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, d)
        self.assertEqual(elecgrp.description, desc)
        self.assertEqual(elecgrp.location, loc)
        self.assertIs(elecgrp.device, d)

    def test_create_custom_intervals(self):
        df_words = pd.DataFrame({
            'start_time': [.1, 2.],
            'stop_time': [.8, 2.3],
            'label': ['hello', 'there']
        })
        words = TimeIntervals.from_dataframe(df_words, name='words')
        self.nwbfile.add_time_intervals(words)
        self.assertEqual(self.nwbfile.intervals['words'], words)

    def test_create_electrode_group_invalid_index(self):
        """
        Test that creating an electrode table region with indices that are out
        of range for the number of electrodes added raises an error.
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i)
        with self.assertRaises(IndexError):
            nwbfile.create_electrode_table_region(list(range(6)), 'test')

    def test_access_group_after_io(self):
        """
        Motivated by #739
        """
        nwbfile = NWBFile('a', 'b', datetime.now(tzlocal()))
        device = nwbfile.create_device('a')
        elecgrp = nwbfile.create_electrode_group('a',
                                                 'b',
                                                 device=device,
                                                 location='a')
        nwbfile.add_electrode(np.nan,
                              np.nan,
                              np.nan,
                              np.nan,
                              'a',
                              'a',
                              elecgrp,
                              id=0)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        for i in range(4):
            nwbfile.add_electrode(np.nan,
                                  np.nan,
                                  np.nan,
                                  np.nan,
                                  'a',
                                  'a',
                                  elecgrp,
                                  id=i + 1)

        with NWBHDF5IO('electrodes_mwe.nwb', 'w') as io:
            io.write(nwbfile)

        with NWBHDF5IO('electrodes_mwe.nwb', 'a') as io:
            nwbfile_i = io.read()
            for aa, bb in zip(nwbfile_i.electrodes['group'][:],
                              nwbfile.electrodes['group'][:]):
                self.assertEqual(aa.name, bb.name)

        remove_test_file("electrodes_mwe.nwb")

    def test_access_processing(self):
        self.nwbfile.create_processing_module('test_mod', 'test_description')
        # test deprecate .modules
        with self.assertWarnsWith(
                DeprecationWarning,
                'NWBFile.modules has been replaced by NWBFile.processing.'):
            modules = self.nwbfile.modules['test_mod']
        self.assertIs(self.nwbfile.processing['test_mod'], modules)

    def test_epoch_tags(self):
        tags1 = ['t1', 't2']
        tags2 = ['t3', 't4']
        tstamps = np.arange(1.0, 100.0, 0.1, dtype=float)  # np.float was removed in NumPy 1.24
        ts = TimeSeries("test_ts",
                        list(range(len(tstamps))),
                        'unit',
                        timestamps=tstamps)
        expected_tags = tags1 + tags2
        self.nwbfile.add_epoch(0.0, 1.0, tags1, ts)
        self.nwbfile.add_epoch(0.0, 1.0, tags2, ts)
        tags = self.nwbfile.epoch_tags
        self.assertEqual(set(expected_tags), set(tags))

    def test_add_acquisition(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.acquisition), 1)

    def test_add_stimulus(self):
        self.nwbfile.add_stimulus(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus), 1)

    def test_add_stimulus_template(self):
        self.nwbfile.add_stimulus_template(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.stimulus_template), 1)

    def test_add_analysis(self):
        self.nwbfile.add_analysis(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.assertEqual(len(self.nwbfile.analysis), 1)

    def test_add_acquisition_check_dups(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        with self.assertRaises(ValueError):
            self.nwbfile.add_acquisition(
                TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                           'grams',
                           timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))

    def test_get_acquisition_empty(self):
        with self.assertRaisesWith(ValueError,
                                   "acquisition of NWBFile 'root' is empty."):
            self.nwbfile.get_acquisition()

    def test_get_acquisition_multiple_elements(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "More than one element in acquisition of NWBFile 'root' -- must specify a name."
        with self.assertRaisesWith(ValueError, msg):
            self.nwbfile.get_acquisition()

    def test_add_acquisition_invalid_name(self):
        self.nwbfile.add_acquisition(
            TimeSeries('test_ts', [0, 1, 2, 3, 4, 5],
                       'grams',
                       timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]))
        msg = "\"'TEST_TS' not found in acquisition of NWBFile 'root'.\""
        with self.assertRaisesWith(KeyError, msg):
            self.nwbfile.get_acquisition("TEST_TS")

    def test_set_electrode_table(self):
        table = ElectrodeTable()
        dev1 = self.nwbfile.create_device('dev1')
        group = self.nwbfile.create_electrode_group('tetrode1',
                                                    'tetrode description',
                                                    'tetrode location', dev1)
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-1.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-2.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-3.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        table.add_row(x=1.0,
                      y=2.0,
                      z=3.0,
                      imp=-4.0,
                      location='CA1',
                      filtering='none',
                      group=group,
                      group_name='tetrode1')
        self.nwbfile.set_electrode_table(table)

        self.assertIs(self.nwbfile.electrodes, table)
        self.assertIs(table.parent, self.nwbfile)

    def test_add_unit_column(self):
        self.nwbfile.add_unit_column('unit_type', 'the type of unit')
        self.assertEqual(self.nwbfile.units.colnames, ('unit_type', ))

    def test_add_unit(self):
        self.nwbfile.add_unit(id=1)
        self.assertEqual(len(self.nwbfile.units), 1)
        self.nwbfile.add_unit(id=2)
        self.nwbfile.add_unit(id=3)
        self.assertEqual(len(self.nwbfile.units), 3)

    def test_add_trial_column(self):
        self.nwbfile.add_trial_column('trial_type', 'the type of trial')
        self.assertEqual(self.nwbfile.trials.colnames,
                         ('start_time', 'stop_time', 'trial_type'))

    def test_add_trial(self):
        self.nwbfile.add_trial(start_time=10.0, stop_time=20.0)
        self.assertEqual(len(self.nwbfile.trials), 1)
        self.nwbfile.add_trial(start_time=30.0, stop_time=40.0)
        self.nwbfile.add_trial(start_time=50.0, stop_time=70.0)
        self.assertEqual(len(self.nwbfile.trials), 3)

    def test_add_invalid_times_column(self):
        self.nwbfile.add_invalid_times_column(
            'comments', 'description of reason for omitting time')
        self.assertEqual(self.nwbfile.invalid_times.colnames,
                         ('start_time', 'stop_time', 'comments'))

    def test_add_invalid_time_interval(self):

        self.nwbfile.add_invalid_time_interval(start_time=0.0, stop_time=12.0)
        self.assertEqual(len(self.nwbfile.invalid_times), 1)
        self.nwbfile.add_invalid_time_interval(start_time=15.0, stop_time=16.0)
        self.nwbfile.add_invalid_time_interval(start_time=17.0, stop_time=20.5)
        self.assertEqual(len(self.nwbfile.invalid_times), 3)

    def test_add_invalid_time_w_ts(self):
        ts = TimeSeries(name='name', data=[1.2], rate=1.0, unit='na')
        self.nwbfile.add_invalid_time_interval(start_time=18.0,
                                               stop_time=20.6,
                                               timeseries=ts,
                                               tags=('hi', 'there'))

    def test_add_electrode(self):
        dev1 = self.nwbfile.create_device(name='dev1')
        group = self.nwbfile.create_electrode_group(
            name='tetrode1',
            description='tetrode description',
            location='tetrode location',
            device=dev1)
        self.nwbfile.add_electrode(x=1.0,
                                   y=2.0,
                                   z=3.0,
                                   imp=-1.0,
                                   location='CA1',
                                   filtering='none',
                                   group=group,
                                   id=1)
        elec = self.nwbfile.electrodes[0]
        self.assertEqual(elec.index[0], 1)
        self.assertEqual(elec.iloc[0]['x'], 1.0)
        self.assertEqual(elec.iloc[0]['y'], 2.0)
        self.assertEqual(elec.iloc[0]['z'], 3.0)
        self.assertEqual(elec.iloc[0]['location'], 'CA1')
        self.assertEqual(elec.iloc[0]['filtering'], 'none')
        self.assertEqual(elec.iloc[0]['group'], group)

    def test_add_electrode_some_opt(self):
        dev1 = self.nwbfile.create_device(name='dev1')
        group = self.nwbfile.create_electrode_group(
            name='tetrode1',
            description='tetrode description',
            location='tetrode location',
            device=dev1)
        self.nwbfile.add_electrode(x=1.0,
                                   y=2.0,
                                   z=3.0,
                                   imp=-1.0,
                                   location='CA1',
                                   filtering='none',
                                   group=group,
                                   id=1,
                                   rel_x=4.0,
                                   rel_y=5.0,
                                   rel_z=6.0,
                                   reference='ref1')
        self.nwbfile.add_electrode(x=1.0,
                                   y=2.0,
                                   z=3.0,
                                   imp=-1.0,
                                   location='CA1',
                                   filtering='none',
                                   group=group,
                                   id=2,
                                   rel_x=7.0,
                                   rel_y=8.0,
                                   rel_z=9.0,
                                   reference='ref2')
        elec = self.nwbfile.electrodes[0]
        self.assertEqual(elec.iloc[0]['rel_x'], 4.0)
        self.assertEqual(elec.iloc[0]['rel_y'], 5.0)
        self.assertEqual(elec.iloc[0]['rel_z'], 6.0)
        self.assertEqual(elec.iloc[0]['reference'], 'ref1')
        elec = self.nwbfile.electrodes[1]
        self.assertEqual(elec.iloc[0]['rel_x'], 7.0)
        self.assertEqual(elec.iloc[0]['rel_y'], 8.0)
        self.assertEqual(elec.iloc[0]['rel_z'], 9.0)
        self.assertEqual(elec.iloc[0]['reference'], 'ref2')

    def test_all_children(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        name = 'example_electrode_group'
        desc = 'An example electrode'
        loc = 'an example location'
        device = self.nwbfile.create_device('a fake device')
        elecgrp = self.nwbfile.create_electrode_group(name, desc, loc, device)
        children = self.nwbfile.all_children()
        self.assertIn(ts1, children)
        self.assertIn(ts2, children)
        self.assertIn(device, children)
        self.assertIn(elecgrp, children)

    def test_fail_if_source_script_file_name_without_source_script(self):
        with self.assertRaises(ValueError):
            # <-- source_script_file_name without source_script is not allowed
            NWBFile('a test session description for a test NWBFile',
                    'FILE123',
                    self.start,
                    source_script=None,
                    source_script_file_name='nofilename')

    def test_get_neurodata_type(self):
        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = TimeSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        p1 = ts1.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p1, self.nwbfile)
        p2 = ts2.get_ancestor(neurodata_type='NWBFile')
        self.assertIs(p2, self.nwbfile)

    def test_print_units(self):
        self.nwbfile.add_unit(spike_times=[1., 2., 3.])
        expected = """units pynwb.misc.Units at 0x%d
Fields:
  colnames: ['spike_times']
  columns: (
    spike_times_index <class 'hdmf.common.table.VectorIndex'>,
    spike_times <class 'hdmf.common.table.VectorData'>
  )
  description: Autogenerated by NWBFile
  id: id <class 'hdmf.common.table.ElementIdentifiers'>
  waveform_unit: volts
"""
        expected = expected % id(self.nwbfile.units)
        self.assertEqual(str(self.nwbfile.units), expected)

    def test_copy(self):
        self.nwbfile.add_unit(spike_times=[1., 2., 3.])
        device = self.nwbfile.create_device('a')
        elecgrp = self.nwbfile.create_electrode_group('a',
                                                      'b',
                                                      device=device,
                                                      location='a')
        self.nwbfile.add_electrode(np.nan,
                                   np.nan,
                                   np.nan,
                                   np.nan,
                                   'a',
                                   'a',
                                   elecgrp,
                                   id=0)
        self.nwbfile.add_electrode(np.nan, np.nan, np.nan, np.nan, 'b', 'b',
                                   elecgrp)
        elec_region = self.nwbfile.create_electrode_table_region([1], 'name')

        ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5],
                         'grams',
                         timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        ts2 = ElectricalSeries('test_ts2', [0, 1, 2, 3, 4, 5],
                               electrodes=elec_region,
                               timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
        self.nwbfile.add_acquisition(ts1)
        self.nwbfile.add_acquisition(ts2)
        self.nwbfile.add_trial(start_time=50.0, stop_time=70.0)
        self.nwbfile.add_invalid_times_column(
            'comments', 'description of reason for omitting time')
        self.nwbfile.create_processing_module('test_mod', 'test_description')
        self.nwbfile.create_time_intervals('custom_interval',
                                           'a custom time interval')
        self.nwbfile.intervals['custom_interval'].add_interval(start_time=10.,
                                                               stop_time=20.)
        newfile = self.nwbfile.copy()

        # test dictionaries
        self.assertIs(self.nwbfile.devices['a'], newfile.devices['a'])
        self.assertIs(self.nwbfile.acquisition['test_ts1'],
                      newfile.acquisition['test_ts1'])
        self.assertIs(self.nwbfile.acquisition['test_ts2'],
                      newfile.acquisition['test_ts2'])
        self.assertIs(self.nwbfile.processing['test_mod'],
                      newfile.processing['test_mod'])

        # test dynamic tables
        self.assertIsNot(self.nwbfile.electrodes, newfile.electrodes)
        self.assertIs(self.nwbfile.electrodes['x'], newfile.electrodes['x'])
        self.assertIsNot(self.nwbfile.units, newfile.units)
        self.assertIs(self.nwbfile.units['spike_times'],
                      newfile.units['spike_times'])
        self.assertIsNot(self.nwbfile.trials, newfile.trials)
        self.assertIsNot(self.nwbfile.trials.parent, newfile.trials.parent)
        self.assertIs(self.nwbfile.trials.id, newfile.trials.id)
        self.assertIs(self.nwbfile.trials['start_time'],
                      newfile.trials['start_time'])
        self.assertIs(self.nwbfile.trials['stop_time'],
                      newfile.trials['stop_time'])
        self.assertIsNot(self.nwbfile.invalid_times, newfile.invalid_times)
        self.assertTupleEqual(self.nwbfile.invalid_times.colnames,
                              newfile.invalid_times.colnames)
        self.assertIsNot(self.nwbfile.intervals['custom_interval'],
                         newfile.intervals['custom_interval'])
        self.assertTupleEqual(
            self.nwbfile.intervals['custom_interval'].colnames,
            newfile.intervals['custom_interval'].colnames)
        self.assertIs(self.nwbfile.intervals['custom_interval']['start_time'],
                      newfile.intervals['custom_interval']['start_time'])
        self.assertIs(self.nwbfile.intervals['custom_interval']['stop_time'],
                      newfile.intervals['custom_interval']['stop_time'])

    def test_multi_experimenters(self):
        self.nwbfile = NWBFile('a test session description for a test NWBFile',
                               'FILE123',
                               self.start,
                               experimenter=('experimenter1', 'experimenter2'))
        self.assertTupleEqual(self.nwbfile.experimenter,
                              ('experimenter1', 'experimenter2'))

    def test_multi_publications(self):
        self.nwbfile = NWBFile('a test session description for a test NWBFile',
                               'FILE123',
                               self.start,
                               related_publications=('pub1', 'pub2'))
        self.assertTupleEqual(self.nwbfile.related_publications,
                              ('pub1', 'pub2'))
Example 4
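This function's imports are not shown either. A plausible set is sketched below; note that NOData is passed in by the caller and getEpilepsyDx is assumed to be a project-specific helper rather than part of pynwb:

import configparser
import os
from datetime import datetime

import numpy as np

from pynwb import NWBFile
from pynwb.file import Subject
from pynwb.misc import AnnotationSeries

# getEpilepsyDx is assumed to come from a helper module in the same project,
# mapping a numeric epilepsy diagnosis code to a text label.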
def no2nwb(NOData, session_use, subjects_ini, path_to_data):
    '''
    Import the data and associated metadata from the new/old recognition dataset into an
    NWB file. Each feature of the dataset, such as the events (i.e., TTLs) or the mean
    waveforms, is stored in the appropriate component of the NWB file.
    '''
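    # Parameters (inferred from how they are used below; exact types depend on the
    # project's NO-data loader, which is not part of pynwb):
    #   NOData       - dataset object exposing .sessions, ._get_event_data() and .ls_cells()
    #   session_use  - integer ID of the session to export
    #   subjects_ini - path to the INI file holding per-session subject metadata
    #   path_to_data - root directory of the raw data (used later in the function)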

    # Time scaling factor (convert microseconds to seconds for the NWB file)
    TIME_SCALING = 10**6

    # Prepare the NO data that will be converted to the NWB format

    session = NOData.sessions[session_use]
    events = NOData._get_event_data(session_use, experiment_type='All')
    cell_ids = NOData.ls_cells(session_use)
    experiment_id_learn = session['experiment_id_learn']
    experiment_id_recog = session['experiment_id_recog']
    task_descr = session['task_descr']

    # Get the metadata for the subject by reading the config file
    # (the subjects INI file is the config file).

    # Check that the config file path exists
    filename = subjects_ini
    if not os.path.exists(filename):
        print('This file does not exist: {}'.format(filename))
        print("Check filename/and or directory")

    # Read the config file
    try:
        # initialize the ConfigParser() class
        config = configparser.ConfigParser()
        # read .ini file
        config.read(filename)
    except Exception:  # avoid a bare except clause
        print('Failed to read the config file...')
        print('Does this file exist: {}'.format(os.path.exists(filename)))

    #  Read Meta-data from INI file.
    for section in config.sections():
        if session_use == int(section):
            session_id = int(section)  # the New/Old ID for the session
            # Read the remaining metadata values for this session
            for value in config[section]:
                if value.lower() == 'nosessions.age':
                    age = int(config[section][value])
                if value.lower() == 'nosessions.diagnosiscode':
                    epilepsyDxCode = config[section][value]
                    epilepsyDx = getEpilepsyDx(int(epilepsyDxCode))
                if value.lower() == 'nosessions.sex':
                    sex = config[section][value].strip("'")
                if value.lower() == 'nosessions.id':
                    ID = config[section][value].strip("'")
                if value.lower() == 'nosessions.session':
                    pt_session = config[section][value].strip("'")
                if value.lower() == 'nosessions.date':
                    unformattedDate = config[section][value].strip("'")
                    date = datetime.strptime(unformattedDate, '%Y-%m-%d')
                    finaldate = date.replace(hour=0, minute=0)
                if value.lower() == 'nosessions.institution':
                    institution = config[section][value].strip("'")
                if value.lower() == 'nosessions.la':
                    LA = config[section][value].strip("'").split(',')
                    if LA[0] == 'NaN':
                        LA_x = np.nan
                        LA_y = np.nan
                        LA_z = np.nan
                    else:
                        LA_x = float(LA[0])
                        LA_y = float(LA[1])
                        LA_z = float(LA[2])
                if value.lower() == 'nosessions.ra':
                    RA = config[section][value].strip("'").split(',')
                    if RA[0] == 'NaN':
                        RA_x = np.nan
                        RA_y = np.nan
                        RA_z = np.nan
                    else:
                        RA_x = float(RA[0])
                        RA_y = float(RA[1])
                        RA_z = float(RA[2])
                if value.lower() == 'nosessions.lh':
                    LH = config[section][value].strip("'").split(',')
                    if LH[0] == 'NaN':
                        LH_x = np.nan
                        LH_y = np.nan
                        LH_z = np.nan
                    else:
                        LH_x = float(LH[0])
                        LH_y = float(LH[1])
                        LH_z = float(LH[2])
                if value.lower() == 'nosessions.rh':
                    RH = config[section][value].strip("'").split(',')
                    if RH[0] == 'NaN':
                        RH_x = np.nan
                        RH_y = np.nan
                        RH_z = np.nan
                    else:
                        RH_x = float(RH[0])
                        RH_y = float(RH[1])
                        RH_z = float(RH[2])
                if value.lower() == 'nosessions.system':
                    signalSystem = config[section][value].strip("'")
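
    # A hypothetical sketch of the subjects.ini section layout the loop above
    # expects (placeholder values only, not taken from the real dataset); the
    # section name is the numeric session id matched against session_use, and the
    # LA/RA/LH/RH keys hold either 'NaN' or comma-separated 'x,y,z' MNI coordinates:
    #
    #   [101]
    #   NOsessions.ID = 'P00'
    #   NOsessions.session = 'P00_session1'
    #   NOsessions.age = 30
    #   NOsessions.sex = 'M'
    #   NOsessions.diagnosisCode = 1
    #   NOsessions.date = '2017-05-01'
    #   NOsessions.institution = 'Example Hospital'
    #   NOsessions.LA = 'NaN'
    #   NOsessions.RA = '24.1,-5.2,-18.0'
    #   NOsessions.LH = 'NaN'
    #   NOsessions.RH = 'NaN'
    #   NOsessions.system = 'ExampleAcquisitionSystem'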

    # =================================================================

    print(
        '======================================================================='
    )
    print('session use: {}'.format(session_id))
    print('age: {}'.format(age))
    print('epilepsy_diagnosis: {}'.format(epilepsyDx))

    nwb_subject = Subject(age=str(age),
                          description=epilepsyDx,
                          sex=sex,
                          species='Human',
                          subject_id=pt_session)
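
    # Note: pynwb's Subject stores age as a free-text string, hence the
    # str(age) conversion above.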

    # Create the NWB file
    nwbfile = NWBFile(
        #source='https://datadryad.org/bitstream/handle/10255/dryad.163179/RecogMemory_MTL_release_v2.zip',
        session_description='New/Old recognition task for ID: {}. '.format(
            session_id),
        identifier='{}_{}'.format(ID, session_use),
        session_start_time=finaldate,  #default session start time
        file_create_date=datetime.now(),
        experiment_description=
        'The data contained within this file describes a new/old recognition task performed in '
        'patients with intractable epilepsy implanted with depth electrodes and Behnke-Fried '
        'microwires in the human Medial Temporal Lobe (MTL).',
        institution=institution,
        keywords=[
            'Intracranial Recordings', 'Intractable Epilepsy',
            'Single-Unit Recordings', 'Cognitive Neuroscience', 'Learning',
            'Memory', 'Neurosurgery'
        ],
        related_publications=
        'Faraut et al. 2018, Scientific Data; Rutishauser et al. 2015, Nat Neurosci;',
        lab='Rutishauser',
        subject=nwb_subject,
        data_collection='learning: {}, recognition: {}'.format(
            session['experiment_id_learn'], session['experiment_id_recog']))

    # Add events and experiment_id acquisition
    events_description = (
        """ The events coorespond to the TTL markers for each trial. For the learning trials, the TTL markers 
            are the following: 55 = start of the experiment, 1 = stimulus ON, 2 = stimulus OFF, 3 = Question Screen Onset [“Is this an animal?”], 
            20 = Yes (21 = NO) during learning, 6 = End of Delay after Response, 66 = End of Experiment. For the recognition trials, 
            the TTL markers are the following: 55 = start of experiment, 1 = stimulus ON, 2 = stimulus OFF, 3 = Question Screen Onset [“Have you seen this image before?”], 
            31:36 = Confidence (Yes vs. No) response [31 (new, confident), 32 (new, probably), 33 (new, guess), 34 (old, guess), 
            35 (old, probably), 36 (old, confident)], 66 = End of Experiment"""
    )
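
    # Note: the indexing below assumes `events` is a pandas DataFrame in which
    # column 0 holds the timestamps (converted to seconds by dividing by
    # TIME_SCALING), column 1 the TTL marker value, and column 2 the experiment id.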

    event_ts = AnnotationSeries(name='events',
                                data=np.asarray(events[1].values).astype(str),
                                timestamps=np.asarray(events[0].values) /
                                TIME_SCALING,
                                description=events_description)

    experiment_ids_description = (
        """The experiment_ids coorespond to the encoding (i.e., learning) or recogniton trials. The learning trials are demarcated by: {}. The recognition trials are demarcated by: {}. """
        .format(experiment_id_learn, experiment_id_recog))

    experiment_ids = TimeSeries(name='experiment_ids',
                                unit='NA',
                                data=np.asarray(events[2]),
                                timestamps=np.asarray(events[0].values) /
                                TIME_SCALING,
                                description=experiment_ids_description)

    nwbfile.add_acquisition(event_ts)
    nwbfile.add_acquisition(experiment_ids)

    # Add stimuli to the NWB file
    # Get the first cell from the cell list
    cell = NOData.pop_cell(session_use,
                           NOData.ls_cells(session_use)[0], path_to_data)
    trials = cell.trials
    stimuli_recog_path = [trial.file_path_recog for trial in trials]
    stimuli_learn_path = [trial.file_path_learn for trial in trials]

    # Add trials: store the start and stop times for each stimulus presentation

    # First extract the category ids and names that we need
    # The metadata for each trial will be stored in the trial table

    cat_id_recog = [trial.category_recog for trial in trials]
    cat_name_recog = [trial.category_name_recog for trial in trials]
    cat_id_learn = [trial.category_learn for trial in trials]
    cat_name_learn = [trial.category_name_learn for trial in trials]

    # Extract the event timestamps
    events_learn_stim_on = events[(events[2] == experiment_id_learn) &
                                  (events[1] == NOData.markers['stimulus_on'])]
    events_learn_stim_off = events[(events[2] == experiment_id_learn) & (
        events[1] == NOData.markers['stimulus_off'])]
    events_learn_delay1_off = events[(events[2] == experiment_id_learn) & (
        events[1] == NOData.markers['delay1_off'])]
    events_learn_delay2_off = events[(events[2] == experiment_id_learn) & (
        events[1] == NOData.markers['delay2_off'])]
    events_learn = events[(events[2] == experiment_id_learn)]
    events_learn_response = []
    events_learn_response_time = []
    for i in range(len(events_learn[0])):
        if (events_learn.iloc[i, 1]
                == NOData.markers['response_learning_animal']) or (
                    events_learn.iloc[i, 1]
                    == NOData.markers['response_learning_non_animal']):
            events_learn_response.append(events_learn.iloc[i, 1] - 20)
            events_learn_response_time.append(events_learn.iloc[i, 0])

    events_recog_stim_on = events[(events[2] == experiment_id_recog) &
                                  (events[1] == NOData.markers['stimulus_on'])]
    events_recog_stim_off = events[(events[2] == experiment_id_recog) & (
        events[1] == NOData.markers['stimulus_off'])]
    events_recog_delay1_off = events[(events[2] == experiment_id_recog) & (
        events[1] == NOData.markers['delay1_off'])]
    events_recog_delay2_off = events[(events[2] == experiment_id_recog) & (
        events[1] == NOData.markers['delay2_off'])]
    events_recog = events[(events[2] == experiment_id_recog)]
    events_recog_response = []
    events_recog_response_time = []
    for i in range(len(events_recog[0])):
        if ((events_recog.iloc[i, 1] == NOData.markers['response_1'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_2'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_3'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_4'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_5'])
                or (events_recog.iloc[i, 1] == NOData.markers['response_6'])):
            events_recog_response.append(events_recog.iloc[i, 1])
            events_recog_response_time.append(events_recog.iloc[i, 0])

    # Extract new_old label
    new_old_recog = [trial.new_old_recog for trial in trials]
    # Create the trial tables

    nwbfile.add_trial_column('stim_on_time',
                             'The Time when the Stimulus is Shown')
    nwbfile.add_trial_column('stim_off_time',
                             'The Time when the Stimulus is Off')
    nwbfile.add_trial_column('delay1_time', 'The Time when Delay1 is Off')
    nwbfile.add_trial_column('delay2_time', 'The Time when Delay2 is Off')
    nwbfile.add_trial_column('stim_phase',
                             'Learning/Recognition Phase During the Trial')
    nwbfile.add_trial_column('stimCategory', 'The Category ID of the Stimulus')
    nwbfile.add_trial_column('category_name',
                             'The Category Name of the Stimulus')
    nwbfile.add_trial_column('external_image_file',
                             'The File Path to the Stimulus')
    nwbfile.add_trial_column(
        'new_old_labels_recog',
        '''The ground truth labels for new or old stimulus: 0 == old stimuli
        (presented during the learning phase), 1 == new stimuli (not seen during
        the learning phase)'''
    )
    nwbfile.add_trial_column('response_value',
                             'The Response for Each Stimulus')
    nwbfile.add_trial_column('response_time',
                             'The Response Time for each Stimulus')

    range_recog = np.amin([
        len(events_recog_stim_on),
        len(events_recog_stim_off),
        len(events_recog_delay1_off),
        len(events_recog_delay2_off)
    ])
    range_learn = np.amin([
        len(events_learn_stim_on),
        len(events_learn_stim_off),
        len(events_learn_delay1_off),
        len(events_learn_delay2_off)
    ])
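    # Using the minimum count across the four event types keeps the per-trial
    # fields aligned if any marker is missing at the end of a session.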

    # Iterate over the events and add the information for each trial to the trial table
    for i in range(range_learn):

        nwbfile.add_trial(
            start_time=(events_learn_stim_on.iloc[i][0]) / (TIME_SCALING),
            stop_time=(events_learn_delay2_off.iloc[i][0]) / (TIME_SCALING),
            stim_on_time=(events_learn_stim_on.iloc[i][0]) / (TIME_SCALING),
            stim_off_time=(events_learn_stim_off.iloc[i][0]) / (TIME_SCALING),
            delay1_time=(events_learn_delay1_off.iloc[i][0]) / (TIME_SCALING),
            delay2_time=(events_learn_delay2_off.iloc[i][0]) / (TIME_SCALING),
            stim_phase='learn',
            stimCategory=cat_id_learn[i],
            category_name=cat_name_learn[i],
            external_image_file=stimuli_learn_path[i],
            new_old_labels_recog='NA',
            response_value=events_learn_response[i],
            response_time=(events_learn_response_time[i]) / (TIME_SCALING))

    for i in range(range_recog):

        nwbfile.add_trial(
            start_time=events_recog_stim_on.iloc[i][0] / (TIME_SCALING),
            stop_time=events_recog_delay2_off.iloc[i][0] / (TIME_SCALING),
            stim_on_time=events_recog_stim_on.iloc[i][0] / (TIME_SCALING),
            stim_off_time=events_recog_stim_off.iloc[i][0] / (TIME_SCALING),
            delay1_time=events_recog_delay1_off.iloc[i][0] / (TIME_SCALING),
            delay2_time=events_recog_delay2_off.iloc[i][0] / (TIME_SCALING),
            stim_phase='recog',
            stimCategory=cat_id_recog[i],
            category_name=cat_name_recog[i],
            external_image_file=stimuli_recog_path[i],
            new_old_labels_recog=new_old_recog[i],
            response_value=events_recog_response[i],
            response_time=events_recog_response_time[i] / (TIME_SCALING))

    # Add the waveform clustering and the spike data.
    # Get the unique channel ids that we will iterate over
    channel_ids = np.unique([cell_id[0] for cell_id in cell_ids])

    # unique unit id
    unit_id = 0

    # Create unit columns
    nwbfile.add_unit_column('origClusterID', 'The original cluster id')
    nwbfile.add_unit_column('waveform_mean_encoding',
                            'The mean waveform for encoding phase.')
    nwbfile.add_unit_column('waveform_mean_recognition',
                            'The mean waveform for the recognition phase.')
    nwbfile.add_unit_column('IsolationDist', 'IsolDist')
    nwbfile.add_unit_column('SNR', 'SNR')
    nwbfile.add_unit_column('waveform_mean_sampling_rate',
                            'The Sampling Rate of Waveform')

    # Add stimuli
    stimuli_presentation = []

    # Add stimuli from the learning phase
    for path in stimuli_learn_path:
        if path == 'NA':
            continue
        folders = path.split('\\')

        path = os.path.join(path_to_data, 'Stimuli', folders[0], folders[1],
                            folders[2])
        img = cv2.imread(path)
        # cv2.resize takes (width, height), so each image becomes a 400 x 300 x 3 array
        resized_image = cv2.resize(img, (300, 400))
        stimuli_presentation.append(resized_image)

    # Add stimuli from the recognition phase
    for path in stimuli_recog_path:
        folders = path.split('\\')
        path = os.path.join(path_to_data, 'Stimuli', folders[0], folders[1],
                            folders[2])
        img = cv2.imread(path)
        resized_image = cv2.resize(img, (300, 400))
        stimuli_presentation.append(resized_image)

    # Add stimuli to OpticalSeries
    stimulus_presentation_on_time = []

    for n in range(0, len(events_learn_stim_on)):
        stimulus_presentation_on_time.append(events_learn_stim_on.iloc[n][0] /
                                             (TIME_SCALING))

    for n in range(0, len(events_recog_stim_on)):
        stimulus_presentation_on_time.append(events_recog_stim_on.iloc[n][0] /
                                             (TIME_SCALING))

    name = 'StimulusPresentation'
    stimulus = OpticalSeries(name=name,
                             data=stimuli_presentation,
                             timestamps=stimulus_presentation_on_time[:],
                             orientation='lower left',
                             format='raw',
                             unit='meters',
                             field_of_view=[.2, .3, .7],
                             distance=0.7,
                             dimension=[300, 400, 3])

    nwbfile.add_stimulus(stimulus)

    # Get Unit data
    all_spike_cluster_ids = []
    all_selected_time_stamps = []
    all_IsolDist = []
    all_SNR = []
    all_selected_mean_waveform_learn = []
    all_selected_mean_waveform_recog = []
    all_mean_waveform = []
    all_channel_id = []
    all_oriClusterIDs = []
    all_channel_numbers = []
    all_brain_area = []
    # Iterate the channel list

    # load brain area file
    brain_area_file_path = os.path.join(path_to_data, 'Data', 'events',
                                        session['session'], task_descr,
                                        'brainArea.mat')

    try:
        brain_area_mat = loadmat(brain_area_file_path)
    except FileNotFoundError:
        print("brain_area_mat file not found")

    for channel_id in channel_ids:
        cell_name = 'A' + str(channel_id) + '_cells.mat'
        cell_file_path = os.path.join(path_to_data, 'Data', 'sorted',
                                      session['session'], task_descr,
                                      cell_name)

        try:
            cell_mat = loadmat(cell_file_path)
        except FileNotFoundError:
            print("cell mat file not found")
            continue

        spikes = cell_mat['spikes']
        meanWaveform_recog = cell_mat['meanWaveform_recog']
        meanWaveform_learn = cell_mat['meanWaveform_learn']
        IsolDist_SNR = cell_mat['IsolDist_SNR']

        spike_cluster_id = np.asarray([spike[1] for spike in spikes
                                       ])  # Each Cluster ID of the spike
        spike_timestamps = np.asarray(
            [spike[2]
             for spike in spikes])  # Timestamps of spikes for each ClusterID
        unique_cluster_ids = np.unique(spike_cluster_id)

        # Iterate over each cluster on this channel (there may be more than one)
        for id in unique_cluster_ids:

            # Grab brain area
            brain_area = extra_brain_area(brain_area_mat, channel_id)

            selected_spike_timestamps = spike_timestamps[spike_cluster_id ==
                                                         id]
            IsolDist, SNR = extract_IsolDist_SNR_by_cluster_id(
                IsolDist_SNR, id)
            selected_mean_waveform_learn = extra_mean_waveform(
                meanWaveform_learn, id)
            selected_mean_waveform_recog = extra_mean_waveform(
                meanWaveform_recog, id)

            # If the mean waveform does not have 256 elements, set it to all zeros
            if len(selected_mean_waveform_learn) != 256:
                selected_mean_waveform_learn = np.zeros(256)
            if len(selected_mean_waveform_recog) != 256:
                selected_mean_waveform_recog = np.zeros(256)

            mean_waveform = np.hstack(
                [selected_mean_waveform_learn, selected_mean_waveform_recog])

            # Append unit data
            all_spike_cluster_ids.append(id)
            all_selected_time_stamps.append(selected_spike_timestamps)
            all_IsolDist.append(IsolDist)
            all_SNR.append(SNR)
            all_selected_mean_waveform_learn.append(
                selected_mean_waveform_learn)
            all_selected_mean_waveform_recog.append(
                selected_mean_waveform_recog)
            all_mean_waveform.append(mean_waveform)
            all_channel_id.append(channel_id)
            all_oriClusterIDs.append(int(id))
            all_channel_numbers.append(channel_id)
            all_brain_area.append(brain_area)

            unit_id += 1
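
    # At this point each of the parallel all_* lists holds one entry per unit;
    # they are written into the units table further below.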

    nwbfile.add_electrode_column(
        name='origChannel',
        description='The original channel ID for the channel')

    #Add Device
    device = nwbfile.create_device(name=signalSystem)

    # Add Electrodes (brain Area Locations, MNI coordinates for microwires)
    length_all_spike_cluster_ids = len(all_spike_cluster_ids)
    for electrodeNumber in range(0, len(channel_ids)):

        brainArea_location = extra_brain_area(brain_area_mat,
                                              channel_ids[electrodeNumber])

        if brainArea_location == 'RH':  #  Right Hippocampus
            full_brainArea_Location = 'Right Hippocampus'

            electrode_name = '{}-microwires-{}'.format(
                signalSystem, channel_ids[electrodeNumber])
            description = "Behnke Fried/Micro Inner Wire Bundle (Behnke-Fried BF08R-SP05X-000 and WB09R-SP00X-0B6; Ad-Tech Medical)"
            location = full_brainArea_Location

            # Add electrode group
            electrode_group = nwbfile.create_electrode_group(
                electrode_name,
                description=description,
                location=location,
                device=device)

            #Add Electrode
            nwbfile.add_electrode([channel_ids[electrodeNumber]],
                                  x=RH_x,
                                  y=RH_y,
                                  z=RH_z,
                                  imp=np.nan,
                                  location=full_brainArea_Location,
                                  filtering='300-3000Hz',
                                  group=electrode_group,
                                  origChannel=channel_ids[electrodeNumber])

        if brainArea_location == 'LH':
            full_brainArea_Location = 'Left Hippocampus'

            electrode_name = '{}-microwires-{}'.format(
                signalSystem, channel_ids[electrodeNumber])
            description = "Behnke Fried/Micro Inner Wire Bundle (Behnke-Fried BF08R-SP05X-000 and WB09R-SP00X-0B6; Ad-Tech Medical)"
            location = full_brainArea_Location

            # Add electrode group
            electrode_group = nwbfile.create_electrode_group(
                electrode_name,
                description=description,
                location=location,
                device=device)

            nwbfile.add_electrode([all_channel_id[electrodeNumber]],
                                  x=LH_x,
                                  y=LH_y,
                                  z=LH_z,
                                  imp=np.nan,
                                  location=full_brainArea_Location,
                                  filtering='300-3000Hz',
                                  group=electrode_group,
                                  origChannel=channel_ids[electrodeNumber])
        if brainArea_location == 'RA':
            full_brainArea_Location = 'Right Amygdala'

            electrode_name = '{}-microwires-{}'.format(
                signalSystem, channel_ids[electrodeNumber])
            description = "Behnke Fried/Micro Inner Wire Bundle (Behnke-Fried BF08R-SP05X-000 and WB09R-SP00X-0B6; Ad-Tech Medical)"
            location = full_brainArea_Location

            # Add electrode group
            electrode_group = nwbfile.create_electrode_group(
                electrode_name,
                description=description,
                location=location,
                device=device)

            nwbfile.add_electrode([all_channel_id[electrodeNumber]],
                                  x=RA_x,
                                  y=RA_y,
                                  z=RA_z,
                                  imp=np.nan,
                                  location=full_brainArea_Location,
                                  filtering='300-3000Hz',
                                  group=electrode_group,
                                  origChannel=channel_ids[electrodeNumber])
        if brainArea_location == 'LA':
            full_brainArea_Location = 'Left Amygdala'

            electrode_name = '{}-microwires-{}'.format(
                signalSystem, channel_ids[electrodeNumber])
            description = "Behnke Fried/Micro Inner Wire Bundle (Behnke-Fried BF08R-SP05X-000 and WB09R-SP00X-0B6; Ad-Tech Medical)"
            location = full_brainArea_Location

            # Add electrode group
            electrode_group = nwbfile.create_electrode_group(
                electrode_name,
                description=description,
                location=location,
                device=device)

            nwbfile.add_electrode([all_channel_id[electrodeNumber]],
                                  x=LA_x,
                                  y=LA_y,
                                  z=LA_z,
                                  imp=np.nan,
                                  location=full_brainArea_Location,
                                  filtering='300-3000Hz',
                                  group=electrode_group,
                                  origChannel=channel_ids[electrodeNumber])

    # Create channel list index: map each unit to the index of its channel
    # in the sorted list of unique channel ids
    channel_list = list(range(length_all_spike_cluster_ids))
    unique_channel_ids = np.unique(all_channel_id)
    for channel_index, channel in enumerate(unique_channel_ids):
        unit_indices = np.where(np.asarray(all_channel_id) == channel)[0]
        for unit_index in unit_indices:
            channel_list[unit_index] = channel_index
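
    # For example, if all_channel_id were [17, 17, 23], channel_list would become
    # [0, 0, 1]: each unit points at the position of its channel in unique_channel_ids.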

    # Add the waveform sampling rate (98.4 kHz) for every unit
    waveform_mean_sampling_rate = [98.4 * 10**3]
    waveform_mean_sampling_rate_matrix = [waveform_mean_sampling_rate
                                          ] * (length_all_spike_cluster_ids)

    # Add Units to NWB file
    for index_id in range(0, length_all_spike_cluster_ids):
        nwbfile.add_unit(
            id=index_id,
            spike_times=all_selected_time_stamps[index_id],
            origClusterID=all_oriClusterIDs[index_id],
            IsolationDist=all_IsolDist[index_id],
            SNR=all_SNR[index_id],
            waveform_mean_encoding=all_selected_mean_waveform_learn[index_id],
            waveform_mean_recognition=all_selected_mean_waveform_recog[
                index_id],
            electrodes=[channel_list[index_id]],
            waveform_mean_sampling_rate=waveform_mean_sampling_rate_matrix[
                index_id])

    return nwbfile


def main():
    # Get absolute path to file
    filename = Path(sys.argv[1]).resolve().as_posix()

    # TODO use with / as context manager - for handling open/closing files
    # TODO update the plexon API to use exceptions instead of this clunky
    # if result == 0 code
    file_reader = PyPL2FileReader()
    file_handle = file_reader.pl2_open_file(filename)

    if file_handle == 0:
        print_error(file_reader)

    # create the PL2FileInfo instance containing basic header information
    file_info = PL2FileInfo()
    res = file_reader.pl2_get_file_info(file_handle, file_info)

    if res == 0:
        print_error(file_reader)

    # USER NEEDS TO INPUT:

    # create the NWBFile instance
    session_description = 'Pulvinar recording from McCartney'
    id = 'M20170127'
    session_start_time = tm_to_datetime(file_info.m_CreatorDateTime)
    timezone = pytz.timezone("America/New_York")
    experimenter = 'Ryan Ly'
    lab = 'Kastner Lab'
    institution = 'Princeton University'
    experiment_description = 'Neural correlates of visual attention across the pulvinar'
    session_id = id
    data_collection = file_info.m_CreatorComment.decode('ascii')

    session_start_time = timezone.localize(session_start_time)
    nwbfile = NWBFile(session_description=session_description,
                      identifier=id,
                      session_start_time=session_start_time,
                      experimenter=experimenter,
                      lab=lab,
                      institution=institution,
                      experiment_description=experiment_description,
                      session_id=session_id,
                      data_collection=data_collection)

    # TODO add in the reprocessor metadata from file_info

    # create a recording device instance
    plexon_device_name = file_info.m_CreatorSoftwareName.decode('ascii') + '_v' + \
                         file_info.m_CreatorSoftwareVersion.decode('ascii')
    plexon_device = nwbfile.create_device(name=plexon_device_name)

    eye_trac_device_name = 'ASL_Eye-trac_6_via_' + plexon_device_name
    eye_trac_device = nwbfile.create_device(name=eye_trac_device_name)

    lever_device_name = 'Manual_lever_via_' + plexon_device_name
    lever_device = nwbfile.create_device(name=lever_device_name)
    # TODO update Device to take in metadata information about the Device

    # create electrode groups representing single shanks or other kinds of data
    # create a list of metadata dicts used below to build the electrode groups;
    # each entry carries a name, description, location, device, and channel ids
    electrode_group_metadata = []
    electrode_group_metadata.append({'name': '32ch-array',
                                    'description': '32-channel_array',
                                    'location': 'Pulvinar',
                                    'device': plexon_device,
                                    'channel_ids': range(1, 33)})
    electrode_group_metadata.append({'name': 'test_electrode_1',
                                    'description': 'test_electrode_1',
                                    'location': 'unknown',
                                    'device': plexon_device,
                                    'channel_ids': [97]})
    electrode_group_metadata.append({'name': 'test_electrode_2',
                                    'description': 'test_electrode_2',
                                    'location': 'unknown',
                                    'device': plexon_device,
                                    'channel_ids': [98]})
    electrode_group_metadata.append({'name': 'test_electrode_3',
                                    'description': 'test_electrode_3',
                                    'location': 'unknown',
                                    'device': plexon_device,
                                    'channel_ids': [99]})
    non_electrode_ts_metadata = []
    non_electrode_ts_metadata.append({'name': 'eyetracker_x_voltage',
                                     'description': 'eyetracker_x_voltage',
                                     'location': 'n/a',
                                     'device': eye_trac_device,
                                     'channel_ids': [126]})
    non_electrode_ts_metadata.append({'name': 'eyetracker_y_voltage',
                                     'description': 'eyetracker_y_voltage',
                                     'location': 'n/a',
                                     'device': eye_trac_device,
                                     'channel_ids': [127]})
    non_electrode_ts_metadata.append({'name': 'lever_voltage',
                                     'description': 'lever_voltage',
                                     'location': 'n/a',
                                     'device': lever_device,
                                     'channel_ids': [128]})

    # make an electrode group for every group of channel IDs specified
    electrode_groups = []
    map_electrode_group_to_channel_ids = []
    for egm in electrode_group_metadata:
        print(f'Creating electrode group named "{egm["name"]}"')
        eg = nwbfile.create_electrode_group(name=egm['name'],
                                            description=egm['description'],
                                            location=egm['location'],
                                            device=egm['device'])
        electrode_groups.append(eg)
        map_electrode_group_to_channel_ids.append(egm['channel_ids'])

    # group indices of analog channels in the Plexon file by type, then source
    wb_src_chans = {}
    fp_src_chans = {}
    spkc_src_chans = {}
    ai_src_chans = {}
    aif_src_chans = {}

    for pl2_ind in range(file_info.m_TotalNumberOfAnalogChannels):
        achannel_info = PL2AnalogChannelInfo()
        res = file_reader.pl2_get_analog_channel_info(file_handle, pl2_ind, achannel_info)
        if res == 0:
            print_error(file_reader)
            break

        if (achannel_info.m_ChannelEnabled and
                achannel_info.m_ChannelRecordingEnabled and
                achannel_info.m_NumberOfValues > 0 and
                achannel_info.m_MaximumNumberOfFragments > 0):
            # store zero-based channel index and electrode channel id
            achan_name = achannel_info.m_Name.decode('ascii')
            if achan_name.startswith('WB'):
                src_chans = wb_src_chans
            elif achan_name.startswith('FP') and get_channel_id(achan_name) <= 125:
                src_chans = fp_src_chans
            elif achan_name.startswith('FP') and get_channel_id(achan_name) > 125:
                src_chans = ai_src_chans
            elif achan_name.startswith('SPKC'):
                src_chans = spkc_src_chans
            elif achan_name.startswith('AIF'):
                # check 'AIF' before 'AI' so filtered auxiliary channels are not
                # captured by the shorter 'AI' prefix match
                src_chans = aif_src_chans
            elif achan_name.startswith('AI'):
                src_chans = ai_src_chans
            else:
                warnings.warn('Unrecognized analog channel: ' + achan_name)
                break

            channel_id = get_channel_id(achan_name)

            # src_chans is a dict {Plexon source ID : list of dict ...
            # {'pl2_ind': analog channel ind, 'channel_id': electrode channel ID}}
            chans = {'pl2_ind': pl2_ind, 'channel_id': channel_id}
            if not achannel_info.m_Source in src_chans:
                src_chans[achannel_info.m_Source] = [chans]
            else:
                src_chans[achannel_info.m_Source].append(chans)
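
    # As a hypothetical illustration, after this loop wb_src_chans could look like
    #   {1: [{'pl2_ind': 0, 'channel_id': 1}, {'pl2_ind': 1, 'channel_id': 2}]},
    # i.e. Plexon source id -> list of {'pl2_ind', 'channel_id'} dicts; the other
    # *_src_chans dicts share this layout.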

    # create electrodes and create a region of indices in the electrode table
    # corresponding to the electrodes used for this type of analog data (WB, FP,
    # SPKC, AI, AIF)
    if wb_src_chans:
        for src, chans in wb_src_chans.items():
            channel_ids_all = [ch['channel_id'] for ch in chans]
            channel_ids_by_group = get_channel_ids_by_metadata_list(
                    channel_ids_all, electrode_group_metadata)
            for channel_ids, eg in zip(channel_ids_by_group, electrode_groups):
                if channel_ids:
                    # find the mapped pl2 index again
                    group_chans = [c for c in chans if c['channel_id'] in channel_ids]
                    add_electrodes(nwbfile, channel_ids, eg)
                    wb_es = pl2_create_electrode_timeseries(nwbfile,
                                                            file_reader,
                                                            file_handle,
                                                            file_info.m_TimestampFrequency,
                                                            group_chans,
                                                            'Wideband_voltages_' + eg.name,
                                                            'Wideband electrodes, group ' + eg.name)
                    nwbfile.add_acquisition(wb_es)

    if fp_src_chans or spkc_src_chans:
        ecephys_module = nwbfile.create_processing_module(name='ecephys',
                                                          description='Processed extracellular electrophysiology data')

    if fp_src_chans:
        # LFP signal can come from multiple Plexon "sources"
        # for src, chans in fp_src_chans.items():
        #     channel_ids_all = [ch['channel_id'] for ch in chans]
        #     channel_ids_by_group = get_channel_ids_by_metadata_list(
        #             channel_ids_all, electrode_group_metadata)
        #
        #     # channel_ids_by_group is a list that parallels
        #     # electrode_groups. it contains a list of lists of channel_ids that
        #     # are used for this type and source of analog data which are part of
        #     # the corresponding electrode group
        #     # for example, if the ElectrodeGroup at index 2 contains electrodes
        #     # with the channel IDs 1-32, and there is analog channel data for
        #     # channel ID 4-6, then channel_ids_by_group[2] would have
        #     # [4, 5, 6].
        #     d = []
        #     for c in channel_ids_all:
        #         for i in range(len(map_channel_ids_to_electrode_groups)):
        #             if c in get_channel_ids_by_metadata_list[i]:
        #                 d[i].append(c)
        #
        #     for channel_ids, eg in zip(channel_ids_by_group, electrode_groups):
        for src, chans in fp_src_chans.items():
            channel_ids = [ch['channel_id'] for ch in chans]
            electrode_group_to_channel_ids = get_channel_ids_by_metadata_list(
                    channel_ids, electrode_group_metadata)
            for channel_ids, eg, egm in zip(electrode_group_to_channel_ids,
                                                  electrode_groups,
                                                  electrode_group_metadata):
                if channel_ids:
                    group_chans = [c for c in chans if c['channel_id'] in channel_ids]
                    add_electrodes(nwbfile, channel_ids, eg)
                    lfp_es = pl2_create_electrode_timeseries(nwbfile,
                                                             file_reader,
                                                             file_handle,
                                                             file_info.m_TimestampFrequency,
                                                             group_chans,
                                                             'LFP_voltages_' + eg.name,
                                                             ('LFP electrodes, group ' + eg.name +
                                                              '. Low-pass filtering at 200 Hz done online by Plexon data acquisition system.'))

                    print('Adding LFP processing module with electrical series for channel ids [' +
                          ', '.join(str(x) for x in channel_ids) + '] for electrode group ' +
                          eg.name)
                    # TODO add LFP filter properties, though these are not stored in the PL2
                    # file
                    lfp = LFP(lfp_es, name='LFP_' + egm['name'])
                    ecephys_module.add(lfp)

    spkc_es = None
    if spkc_src_chans:
        for src, chans in spkc_src_chans.items():
            channel_ids = [ch['channel_id'] for ch in chans]
            electrode_group_to_channel_ids = get_channel_ids_by_metadata_list(
                    channel_ids, electrode_group_metadata)
            for channel_ids, eg, egm in zip(electrode_group_to_channel_ids,
                                                  electrode_groups,
                                                  electrode_group_metadata):
                if channel_ids:
                    group_chans = [c for c in chans if c['channel_id'] in channel_ids]
                    add_electrodes(nwbfile, channel_ids, eg)
                    spkc_es = pl2_create_electrode_timeseries(nwbfile,
                                                              file_reader,
                                                              file_handle,
                                                              file_info.m_TimestampFrequency,
                                                              group_chans,
                                                              'High-pass_filtered_voltages_' + egm['name'],
                                                              ('High-pass filtered ("SPKC") electrodes, group ' + egm['name'] +
                                                               '. High-pass filtering at 300 Hz done online by Plexon data acquisition system.'))

                    print('Adding SPKC processing module with electrical series for channel ids [' +
                          ', '.join(str(x) for x in channel_ids) + '] for electrode group ' +
                          egm['name'])
                    # TODO add SPKC filter properties, though these are not stored in the PL2
                    # file
                    filt_ephys = FilteredEphys(spkc_es, name='SPKC_' + egm['name'])
                    ecephys_module.add(filt_ephys)

    if ai_src_chans:
        for src, chans in ai_src_chans.items():
            channel_ids_all = [ch['channel_id'] for ch in chans]
            channel_ids_by_group = get_channel_ids_by_metadata_list(
                    channel_ids_all, non_electrode_ts_metadata)
            for channel_ids, gm in zip(channel_ids_by_group, non_electrode_ts_metadata):
                if channel_ids:
                    # find the mapped pl2 index again
                    pl2_inds = [c['pl2_ind'] for c in chans if c['channel_id'] in channel_ids]
                    ai_es = pl2_create_timeseries(nwbfile,
                                                  file_reader,
                                                  file_handle,
                                                  file_info.m_TimestampFrequency,
                                                  pl2_inds,
                                                  ('Auxiliary_input_' + str(src) +
                                                   '_' + gm['name']),
                                                  ('Auxiliary input, source ' + str(src) +
                                                   ', ' + gm['name']))
                    nwbfile.add_acquisition(ai_es)

    if aif_src_chans:
        for src, chans in aif_src_chans.items():
            channel_ids_all = [ch['channel_id'] for ch in chans]
            channel_ids_by_group = get_channel_ids_by_metadata_list(
                    channel_ids_all, non_electrode_ts_metadata)
            for channel_ids, gm in zip(channel_ids_by_group, non_electrode_ts_metadata):
                if channel_ids:
                    # find the mapped pl2 index again
                    pl2_inds = [c['pl2_ind'] for c in chans if c['channel_id'] in channel_ids]
                    aif_es = pl2_create_timeseries(nwbfile,
                                                  file_reader,
                                                  file_handle,
                                                  file_info.m_TimestampFrequency,
                                                  pl2_inds,
                                                  ('Filtered_auxiliary_input_' + str(src) +
                                                   '_' + gm['name']),
                                                  ('Filtered auxiliary input, source ' + str(src) +
                                                   ', ' + gm['name']))
                    nwbfile.add_acquisition(aif_es)

    #### Spikes ####

    # add these columns to unit table
    nwbfile.add_unit_column('pre_threshold_samples', 'number of samples before threshold')
    nwbfile.add_unit_column('num_samples', 'number of samples for each spike waveform')
    nwbfile.add_unit_column('num_spikes', 'number of spikes')
    nwbfile.add_unit_column('Fs', 'sampling frequency')
    nwbfile.add_unit_column('plx_sort_method', 'sorting method encoded by Plexon')
    nwbfile.add_unit_column('plx_sort_range', 'range of sample indices used in Plexon sorting')
    nwbfile.add_unit_column('plx_sort_threshold', 'voltage threshold used by Plexon sorting')
    nwbfile.add_unit_column('is_unsorted', 'whether this unit is the set of unsorted waveforms')
    nwbfile.add_unit_column('channel_id', 'original recording channel ID')

    # since waveforms are not a 1:1 mapping per unit, use table indexing

    nwbfile.add_unit_column('waveforms', 'waveforms for each spike', index=True)
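    # (index=True makes 'waveforms' a ragged, indexed column, so each unit can
    # store a different number of spike waveforms)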

    # add a unit for each spike channel
    for i in range(file_info.m_TotalNumberOfSpikeChannels):
        pl2_add_units(nwbfile, file_reader, file_handle, i)

    # if spkc_series is not None:
    #     # Plexon does not save the indices of the spike times in the
    #     # high-pass filtered data. So work backwards from the spike times
    #     # first convert spike times to sample indices, accounting for imprecision
    #     spike_inds = spike_ts * schannel_info.m_SamplesPerSecond
    #
    #     # TODO this can be SUPER INEFFICIENT
    #     if not all([math.isclose(x, np.round(x)) for x in spike_inds]):
    #         raise InconsistentInputException()
    #
    #     spike_inds = np.round(spike_inds) # need to account for fragments TODO
    #
    #     ed_module = nwbfile.create_processing_module(name='Plexon online sorted units - all',
    #                                                  description='All units detected online')
    #     print('Adding Event Detection processing module for Electrical Series ' +
    #           f'named {spkc_series.name}')
    #     ed = EventDetection(detection_method="xx", # TODO allow user input
    #                         source_electricalseries=spkc_series,
    #                         source_idx=spike_inds,
    #                         times=spike_ts)
    #     ed_module.add(ed)

    # write NWB file to disk
    out_file = './output/nwb_test.nwb'
    with NWBHDF5IO(out_file, 'w') as io:
        print('Writing to file: ' + out_file)
        io.write(nwbfile)

    # Close the PL2 file
    file_reader.pl2_close_file(file_handle)