def test_multi_sub_recording_extractor(self):
        """Exercise the multi-recording extractors and sub-extraction from them.

        Three scenarios are covered, in order:

        1. A ``MultiRecordingTimeExtractor`` built from three copies of
           ``self.RX``; epoch ``'C'`` must compare equal to ``self.RX``.
        2. A ``MultiRecordingChannelExtractor`` built from ``self.RX``,
           ``self.RX2`` and ``self.RX3`` with groups 1, 2, 3; channels 4-7,
           renamed to 0-3, must compare equal to ``self.RX2``.
        3. A ``MultiRecordingChannelExtractor`` with groups 0, 0, 1, checking
           that channel ids, groups, a channel property and a channel location
           set on the inputs are visible through the combined extractor.
        """
        # Scenario 1: the same recording three times, one named epoch each.
        time_multi = se.MultiRecordingTimeExtractor(
            recordings=[self.RX] * 3,
            epoch_names=['A', 'B', 'C'],
        )
        epoch_c = time_multi.get_epoch('C')
        check_recordings_equal(self.RX, epoch_c)
        for position in range(3):
            check_recordings_equal(self.RX, time_multi.recordings[position])
        self.assertEqual(4, len(epoch_c.get_channel_ids()))

        # Scenario 2: three recordings side by side, one group per recording.
        expected_parts = (self.RX, self.RX2, self.RX3)
        chan_multi = se.MultiRecordingChannelExtractor(
            recordings=list(expected_parts),
            groups=[1, 2, 3],
        )
        middle_part = se.SubRecordingExtractor(
            chan_multi,
            channel_ids=list(range(4, 8)),
            renamed_channel_ids=list(range(4)),
        )
        check_recordings_equal(self.RX2, middle_part)
        for position, expected in enumerate(expected_parts):
            check_recordings_equal(expected, chan_multi.recordings[position])
        self.assertEqual([2] * 4, list(middle_part.get_channel_groups()))
        self.assertEqual(12, len(chan_multi.get_channel_ids()))

        # Scenario 3: tag the inputs, then read the tags back via the combined ids.
        self.RX2.set_channel_property(0, "foo", 100)
        self.RX3.set_channel_locations([11, 11], channel_ids=0)
        grouped_multi = se.MultiRecordingChannelExtractor(
            recordings=[self.RX, self.RX2, self.RX3],
            groups=[0, 0, 1],
        )
        self.assertTrue(np.array_equal(list(range(12)), grouped_multi.get_channel_ids()))
        self.assertTrue(np.array_equal([0] * 8 + [1] * 4, grouped_multi.get_channel_groups()))

        # Channel 0 of the second input is addressed as channel 4 in the combined extractor.
        expected_foo = self.RX2.get_channel_property(0, "foo")
        self.assertEqual(expected_foo, grouped_multi.get_channel_property(4, "foo"))

        # Channel 0 of the third input is addressed as channel 8 in the combined extractor.
        expected_location = self.RX3.get_channel_locations([0])[0]
        combined_location = grouped_multi.get_channel_locations([8])[0]
        self.assertTrue(np.array_equal(expected_location, combined_location))
Example #2
0
    def test_multi_sub_recording_extractor(self):
        """Check epoch and channel sub-extraction from multi-recording extractors.

        First builds a ``MultiRecordingTimeExtractor`` from three copies of
        ``self.RX`` and verifies that epoch ``'C'`` compares equal to
        ``self.RX`` and exposes 4 channel ids.

        Then builds a ``MultiRecordingChannelExtractor`` from ``self.RX``,
        ``self.RX2`` and ``self.RX3`` with groups 1, 2, 3 and verifies that
        channels 4-7 (renamed to 0-3) compare equal to ``self.RX2``, all carry
        group 2, and that the combined extractor exposes 12 channel ids.
        """
        RX_multi = se.MultiRecordingTimeExtractor(
            recordings=[self.RX, self.RX, self.RX],
            epoch_names=['A', 'B', 'C'])
        RX_sub = RX_multi.get_epoch('C')
        self._check_recordings_equal(self.RX, RX_sub)
        self.assertEqual(4, len(RX_sub.get_channel_ids()))

        RX_multi = se.MultiRecordingChannelExtractor(
            recordings=[self.RX, self.RX2, self.RX3], groups=[1, 2, 3])
        RX_sub = se.SubRecordingExtractor(RX_multi,
                                          channel_ids=[4, 5, 6, 7],
                                          renamed_channel_ids=[0, 1, 2, 3])
        self._check_recordings_equal(self.RX2, RX_sub)
        # Normalise to a plain list: comparing a list against an array-like
        # return value with assertEqual would raise instead of asserting.
        self.assertEqual([2, 2, 2, 2], list(RX_sub.get_channel_groups()))
        self.assertEqual(12, len(RX_multi.get_channel_ids()))
    def test_dump_load_multi_sub_extractor(self):
        """Run ``check_dumping`` on multi/sub extractors backed by MDA files.

        The ``self.RX`` and ``self.SX`` fixtures are first written under
        ``self.test_dir`` in MDA format and reloaded, and the multi/sub
        extractors under test are built from those reloaded objects.
        """
        # Write the fixtures out and reload them as MDA extractors.
        mda_dir = self.test_dir + '/mda'
        firings_file = mda_dir + '/firings_true.mda'
        se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
        se.MdaSortingExtractor.write_sorting(self.SX, firings_file)
        mda_recording = se.MdaRecordingExtractor(mda_dir)
        mda_sorting = se.MdaSortingExtractor(firings_file)

        # Recording-side wrappers: channel-wise multi, time-wise multi, then a sub.
        channel_multi = se.MultiRecordingChannelExtractor(recordings=[mda_recording] * 3)
        check_dumping(channel_multi)
        time_multi = se.MultiRecordingTimeExtractor(recordings=[mda_recording] * 3)
        check_dumping(time_multi)
        # Rebinding the same local keeps the first wrapper's lifetime unchanged.
        channel_multi = se.SubRecordingExtractor(mda_recording, channel_ids=[0, 1])
        check_dumping(channel_multi)

        # Sorting-side wrappers: a sub-sorting, then a multi-sorting.
        sub_sorting = se.SubSortingExtractor(mda_sorting, unit_ids=[1, 2])
        check_dumping(sub_sorting)
        multi_sorting = se.MultiSortingExtractor(sortings=[mda_sorting] * 3)
        check_dumping(multi_sorting)