def test_filler_to_malloc(self):
    """Check the shared-memory layout built from the filler's malloc requests.

    Starts the queues and a ``FileFiller``, processes the malloc requests,
    builds the shared memory with ``make_shared_malloc`` and then asserts on
    the number of buckets, the per-bucket entry count, the number of data
    sets and the shape of the first data set.
    """
    from adlkit.data_provider.tests.mock_config import (
        mock_expected_malloc_requests,
        mock_sample_specification,
    )

    # Work on private copies so this test cannot mutate the shared fixtures.
    mock_sample_specification = copy.deepcopy(mock_sample_specification)
    mock_expected_malloc_requests = copy.deepcopy(
        mock_expected_malloc_requests)

    batch_size = 100

    tmp_data_provider = FileDataProvider(mock_sample_specification,
                                         batch_size=batch_size,
                                         n_readers=4,
                                         sleep_duration=sleep_duration)

    # Everything after construction runs under try/finally so the provider
    # is stopped even when a setup step or an assertion fails.
    try:
        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(FileFiller,
                                       io_ctlr=IOController(),
                                       shape_reader_io_ctlr=None)
        tmp_data_provider.process_malloc_requests()

        tmp_data_provider.worker_count = 1
        tmp_data_provider.make_shared_malloc()

        bucket = 0
        data_set = 0
        lock = 0

        # TODO !!possible off by one error, needs logic check!!
        # check to make sure the shared mem was extended
        # self.assertEqual(len(tmp_data_provider.shared_memory),
        #                  tmp_data_provider.worker_count + 1,
        #                  "shared memory was not extended correctly ")

        # check to make sure all buckets were allocated
        self.assertEqual(len(tmp_data_provider.shared_memory),
                         10,
                         "shared memory buckets were not allocated correctly")

        # Multiple Generator start and end locks
        self.assertEqual(len(tmp_data_provider.shared_memory[bucket]),
                         4,
                         "shared memory locks were not set correctly")

        # check to make sure all data sets were allocated
        self.assertEqual(
            len(tmp_data_provider.shared_memory[bucket][lock + 1]),
            len(mock_expected_malloc_requests),
            "shared memory data sets were not allocated correctly")

        # check to make sure the shape matches our expected value
        self.assertEqual(
            tmp_data_provider.shared_memory[bucket][lock + 1][data_set].shape,
            (batch_size, 5),
            "shared memory shape doesn't match")
    finally:
        tmp_data_provider.hard_stop()
def test_process_malloc_requests(self):
    """Check that the processed malloc requests match the expected fixture.

    Starts the queues and an ``H5Filler``, calls
    ``process_malloc_requests`` and compares the provider's
    ``malloc_requests`` with ``mock_expected_malloc_requests`` item by item.
    """
    from adlkit.data_provider.tests.mock_config import (
        mock_expected_malloc_requests,
        mock_sample_specification,
    )

    # Work on private copies so this test cannot mutate the shared fixtures.
    mock_sample_specification = copy.deepcopy(mock_sample_specification)
    mock_expected_malloc_requests = copy.deepcopy(
        mock_expected_malloc_requests)

    tmp_data_provider = FileDataProvider(mock_sample_specification,
                                         batch_size=100,
                                         n_readers=4,
                                         sleep_duration=sleep_duration)

    # Stop the provider on every exit path, mirroring the hard_stop() call
    # in test_filler_to_malloc.
    try:
        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(H5Filler)
        tmp_data_provider.process_malloc_requests()

        requests = list(tmp_data_provider.malloc_requests)

        # zip() stops at the shorter input, so without this check a missing
        # or empty request list would make the loop below pass vacuously.
        self.assertEqual(len(requests), len(mock_expected_malloc_requests))

        for request, expected_request in zip(requests,
                                             mock_expected_malloc_requests):
            self.assertEqual(request, expected_request)
    finally:
        tmp_data_provider.hard_stop()
def test_end_to_end(self):
    """Run filler, readers and a generator together and check the batches.

    Starts the queues, an ``H5Filler``, ``H5Reader`` readers and a
    ``BaseGenerator``, then iterates the generator's ``generate()`` output.
    Each yielded item must have five entries with the asserted shapes, and
    the number of yielded items must equal ``max_batches``.
    """
    from adlkit.data_provider.tests.mock_config import mock_sample_specification

    # Work on a private copy so this test cannot mutate the shared fixture.
    mock_sample_specification = copy.deepcopy(mock_sample_specification)

    batch_size = 5
    max_batches = 5

    tmp_data_provider = FileDataProvider(mock_sample_specification,
                                         batch_size=batch_size,
                                         read_multiplier=2,
                                         make_one_hot=True,
                                         make_class_index=True,
                                         make_file_index=True,
                                         n_readers=4,
                                         sleep_duration=sleep_duration,
                                         max_batches=max_batches)

    # Stop the provider on every exit path, mirroring the hard_stop() call
    # in test_filler_to_malloc.
    try:
        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(H5Filler)
        tmp_data_provider.process_malloc_requests()
        tmp_data_provider.start_reader(H5Reader)

        generator_id = tmp_data_provider.start_generator(BaseGenerator)

        count = 0
        for tmp in tmp_data_provider.generators[generator_id].generate():
            count += 1
            self.assertEqual(len(tmp), 5)
            self.assertEqual(tmp[0].shape, (batch_size, 5))
            self.assertEqual(tmp[1].shape, (batch_size, 5))
            self.assertEqual(tmp[2].shape, (batch_size, ))
            self.assertEqual(tmp[3].shape, (batch_size, 3))
            self.assertEqual(len(tmp[4][0]), 2)

        # The generator must yield exactly max_batches items.
        self.assertEqual(
            count, max_batches,
            "we have a leaky generator here {0} != {1}".format(
                count, max_batches))
    finally:
        tmp_data_provider.hard_stop()