Example No. 1
0
    def test_start(self):
        """Smoke-test the full H5 pipeline: start it, then pull batches.

        Each batch yielded by the generator is expected to hold four
        payloads (make_one_hot and make_class_index add two outputs on
        top of the base data payloads).
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        # Deep-copy so mutations inside the provider cannot leak into other tests.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)

        max_batches = 100
        batch_size = 5

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             read_multiplier=2,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             n_readers=4,
                                             wrap_examples=True,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start(filler_class=H5Filler,
                                reader_class=H5Reader,
                                generator_class=BaseGenerator)

        # I could make this a function like the above, but it makes sense not to limit
        # the full pipeline with max_batches. Tentatively there is no way to pass parameters
        # to certain parts of the pipeline. Everyone gets every kwarg, for better or for worse.
        generator_id = 0
        for _ in range(max_batches):
            # TODO better checks
            # Use the next() builtin (works on Python 2.6+ and 3.x); the
            # gen.next() method is Python-2-only.
            this = next(tmp_data_provider.generators[generator_id].generate())
            self.assertEqual(len(this), 4)

        # Shut the worker processes down, matching the cleanup the other
        # pipeline tests perform, so readers are not leaked between tests.
        tmp_data_provider.hard_stop()
Example No. 2
0
    def test_class_index_map(self):
        """
        Exercise the pipeline with an explicit class_index_map, then verify
        that hard_stop() actually kills the reader processes.

        if this fails, it's most likely due to a lack of allowed file descriptors
        :return:
        """

        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        # Deep-copy so mutations inside the provider cannot leak into other tests.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)

        batch_size = 5

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             read_batches_per_epoch=1000,
                                             read_multiplier=1,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             wrap_examples=True,
                                             n_readers=5,
                                             n_buckets=2,
                                             q_multipler=3,
                                             sleep_duration=sleep_duration,
                                             # Negative indices are deliberate: the map
                                             # only has to be consistent, not contiguous.
                                             class_index_map={
                                                 'class_1': -3,
                                                 'class_2': -2,
                                                 'class_10': -1
                                             })

        tmp_data_provider.start(filler_class=FileFiller,
                                reader_class=FileReader,
                                generator_class=BaseGenerator)
        generator_id = 0

        for _ in range(100):
            # for _ in range(10000000):
            # Use the next() builtin (works on Python 2.6+ and 3.x); the
            # gen.next() method is Python-2-only.
            tmp = next(tmp_data_provider.generators[generator_id].generate())
            # TODO better checks
            self.assertEqual(len(tmp), 4)

        tmp_data_provider.hard_stop()

        # we are checking the that readers are killing themselves
        # by checking that the pid no longer exists
        # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
        for reader_process in tmp_data_provider.readers:
            this = None
            try:
                # Signal 0 performs permission/existence checks without
                # actually delivering a signal.
                os.kill(reader_process.pid, 0)
                this = True
            except OSError:
                this = False
            finally:
                self.assertEqual(this, False)
Example No. 3
0
    def test_multiple_generators(self):
        """Run several generators against one provider (with a watcher) and
        check that every generator yields the same number of batches and that
        all generators see identical payloads for the same batch index.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        # Deep-copy so mutations inside the provider cannot leak into other tests.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)

        max_batches = 10
        batch_size = 100
        n_generators = 5

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             read_multiplier=2,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             n_readers=3,
                                             n_generators=n_generators,
                                             wrap_examples=True,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start(filler_class=FileFiller,
                                reader_class=FileReader,
                                generator_class=BaseGenerator,
                                watcher_class=BaseWatcher)

        # ugly, but mostly necessary
        generator_counter = [[] for _ in range(n_generators)]
        total_count = 0
        for loop in range(max_batches):
            for gen_index in range(len(tmp_data_provider.generators)):
                # Use the next() builtin (works on Python 2.6+ and 3.x); the
                # gen.next() method is Python-2-only.
                this = next(
                    tmp_data_provider.generators[gen_index].generate())

                self.assertEqual(len(this), 4)
                generator_counter[gen_index].append(this)
                total_count += 1

        # Every generator must have produced exactly max_batches batches.
        for generator_count in generator_counter:
            self.assertEqual(len(generator_count), max_batches)

        # All generators should have seen byte-identical payloads: compare
        # every generator's output against generator 0's, entry by entry.
        for p_index, payload in enumerate(generator_counter[0]):
            for e_index, entry in enumerate(payload):
                for generator_output in generator_counter:
                    self.assertTrue(
                        np.array_equal(generator_output[p_index][e_index],
                                       entry))

        self.assertEqual(total_count, max_batches * n_generators)

        tmp_data_provider.hard_stop()
Example No. 4
0
    def test_multiple_generators(self):
        """Run several generators against one H5-backed provider (with a
        watcher) and check that batches are distributed evenly: each generator
        yields exactly max_batches batches of the expected shape.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        # Deep-copy so mutations inside the provider cannot leak into other tests.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)

        max_batches = 100
        batch_size = 100
        n_generators = 5

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             read_multiplier=2,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             n_readers=3,
                                             n_generators=n_generators,
                                             wrap_examples=True,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start(filler_class=H5Filler,
                                reader_class=H5Reader,
                                generator_class=BaseGenerator,
                                watcher_class=BaseWatcher)

        # ugly, but mostly necessary
        generator_counter = [0] * n_generators
        total_count = 0
        for loop in range(max_batches):
            for gen_index in range(len(tmp_data_provider.generators)):
                # Use the next() builtin (works on Python 2.6+ and 3.x); the
                # gen.next() method is Python-2-only.
                this = next(
                    tmp_data_provider.generators[gen_index].generate())

                self.assertEqual(len(this), 4)
                generator_counter[gen_index] += 1
                total_count += 1

        # Every generator must have produced exactly max_batches batches.
        for generator_count in generator_counter:
            self.assertEqual(generator_count, max_batches)

        self.assertEqual(total_count, max_batches * n_generators)

        # Shut the worker processes down, matching the cleanup the other
        # pipeline tests perform, so readers are not leaked between tests.
        tmp_data_provider.hard_stop()