Пример #1
0
    def test_filler_to_malloc(self):
        """Check the shared-memory layout built from the filler's malloc requests.

        Starts the queues and a ``FileFiller``, processes the malloc
        requests, calls ``make_shared_malloc()`` and then asserts on the
        resulting ``shared_memory``: 10 buckets, 4 entries in the first
        bucket, one data set per expected malloc request, and a
        ``(batch_size, 5)`` shape for the first data set.
        """
        from adlkit.data_provider.tests.mock_config import (
            mock_sample_specification,
            mock_expected_malloc_requests,
        )
        # Work on private copies so this test cannot mutate the shared mocks.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        batch_size = 100

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             n_readers=4,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(FileFiller,
                                       io_ctlr=IOController(),
                                       shape_reader_io_ctlr=None)

        # Everything after the filler is started runs under try/finally so
        # that hard_stop() is still called when an assertion fails; before,
        # a failing assertion skipped the teardown entirely.
        try:
            tmp_data_provider.process_malloc_requests()

            tmp_data_provider.worker_count = 1

            tmp_data_provider.make_shared_malloc()

            bucket = 0
            data_set = 0
            lock = 0
            # TODO !!possible off by one error, needs logic check!!

            # NOTE: the check that shared memory was extended is disabled:
            # self.assertEqual(len(tmp_data_provider.shared_memory),
            #                  tmp_data_provider.worker_count + 1,
            #                  "shared memory was not extended correctly ")

            # check to make sure all buckets were allocated
            self.assertEqual(
                len(tmp_data_provider.shared_memory), 10,
                "shared memory buckets were not allocated correctly")

            # Multiple Generator start and end locks
            self.assertEqual(
                len(tmp_data_provider.shared_memory[bucket]), 4,
                "shared memory locks were not set correctly")

            # check to make sure all data sets were allocated
            self.assertEqual(
                len(tmp_data_provider.shared_memory[bucket][lock + 1]),
                len(mock_expected_malloc_requests),
                "shared memory data sets were not allocated correctly")

            # check to make sure the shape matches our expected value; the
            # leading dimension is the batch size configured above
            self.assertEqual(
                tmp_data_provider.shared_memory[bucket][lock + 1][data_set]
                .shape,
                (batch_size, 5), "shared memory shape doesn't match")
        finally:
            tmp_data_provider.hard_stop()
Пример #2
0
    def test_process_malloc_requests(self):
        """Check that processed malloc requests match the expected mock requests.

        Starts the queues and an ``H5Filler``, calls
        ``process_malloc_requests()`` and compares
        ``tmp_data_provider.malloc_requests`` against
        ``mock_expected_malloc_requests``, first by length and then element
        by element.
        """
        from adlkit.data_provider.tests.mock_config import (
            mock_sample_specification,
            mock_expected_malloc_requests,
        )
        # Work on private copies so this test cannot mutate the shared mocks.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=100,
                                             n_readers=4,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(H5Filler)

        tmp_data_provider.process_malloc_requests()

        # zip() stops at the shorter input, so without this length check an
        # empty or truncated malloc_requests would pass the loop below
        # without comparing anything.
        malloc_requests = list(tmp_data_provider.malloc_requests)
        self.assertEqual(
            len(malloc_requests), len(mock_expected_malloc_requests),
            "unexpected number of malloc requests {0} != {1}".format(
                len(malloc_requests), len(mock_expected_malloc_requests)))

        for request, expected_request in zip(malloc_requests,
                                             mock_expected_malloc_requests):
            self.assertEqual(request, expected_request)
Пример #3
0
    def test_end_to_end(self):
        """Run filler, reader and generator together and check every batch.

        Each batch yielded by the generator must have 5 entries: two arrays
        of shape ``(batch_size, 5)``, one of shape ``(batch_size,)``, one of
        shape ``(batch_size, 3)``, and a final entry whose first element has
        length 2. The generator must yield exactly ``max_batches`` batches.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        sample_specification = copy.deepcopy(mock_sample_specification)

        batch_size = 5
        max_batches = 5
        provider_options = dict(batch_size=batch_size,
                                read_multiplier=2,
                                make_one_hot=True,
                                make_class_index=True,
                                make_file_index=True,
                                n_readers=4,
                                sleep_duration=sleep_duration,
                                max_batches=max_batches)
        provider = FileDataProvider(sample_specification, **provider_options)

        # Bring the pipeline up in order: queues, filler, malloc, reader.
        provider.start_queues()
        provider.start_filler(H5Filler)
        provider.process_malloc_requests()
        provider.start_reader(H5Reader)

        # lg.debug("I am expecting to write to {0}".format(hex(id(provider.shared_memory[0][0][1][0]))))

        generator_id = provider.start_generator(BaseGenerator)
        generator = provider.generators[generator_id]

        # Expected shapes of the first four entries of a batch, by position.
        expected_shapes = (
            (batch_size, 5),
            (batch_size, 5),
            (batch_size, ),
            (batch_size, 3),
        )

        # count stays 0 if the generator yields nothing at all.
        count = 0
        for count, batch in enumerate(generator.generate(), 1):
            self.assertEqual(len(batch), 5)
            for position, shape in enumerate(expected_shapes):
                self.assertEqual(batch[position].shape, shape)
            self.assertEqual(len(batch[4][0]), 2)

        leak_message = "we have a leaky generator here {0} != {1}".format(
            count, max_batches)
        self.assertEqual(count, max_batches, leak_message)