def test_watcher_multicast(self):
    """Check that BaseWatcher copies each batch from out_queue to every multicast queue.

    Batches from the mock config are put on the provider's ``out_queue``
    without blocking; every batch rejected with ``Queue.Full`` lowers the
    expected batch count by one. After ``watch()`` returns, each multicast
    queue is drained ``max_batches`` times and must have yielded exactly
    ``max_batches`` non-None items.

    NOTE(review): another method in this file is also named
    ``test_watcher_multicast``. If both are defined in the same class body,
    Python keeps only the later definition and this test never runs.
    Confirm, then rename or delete this queue-based variant.
    """
    from adlkit.data_provider.tests.mock_config import mock_read_batches, \
        mock_sample_specification, mock_expected_malloc_requests
    mock_read_batches = copy.deepcopy(mock_read_batches)
    mock_sample_specification = copy.deepcopy(mock_sample_specification)
    mock_expected_malloc_requests = copy.deepcopy(
        mock_expected_malloc_requests)

    batch_size = 100
    max_batches = 5
    n_generators = 5
    n_readers = 5
    # Upper bound (seconds) on waiting for one multicast item. Without a
    # timeout, get() blocks forever and a missing batch hangs the suite
    # instead of failing the assertion below.
    read_timeout = 1.0

    tmp_data_provider = FileDataProvider(mock_sample_specification,
                                         batch_size=batch_size,
                                         read_multipler=2,
                                         make_one_hot=True,
                                         make_class_index=True,
                                         n_readers=n_readers,
                                         n_generators=n_generators,
                                         n_buckets=2,
                                         sleep_duration=sleep_duration)
    tmp_data_provider.start_queues()
    try:
        tmp_data_provider.malloc_requests = mock_expected_malloc_requests
        tmp_data_provider.make_shared_malloc(0)

        for batch in mock_read_batches:
            try:
                tmp_data_provider.out_queue.put(batch, timeout=0)
            except Queue.Full:
                # This batch never reaches the watcher; expect one fewer.
                max_batches -= 1

        tmp_watcher = BaseWatcher(
            worker_id=0,
            shared_memory_pointer=tmp_data_provider.shared_memory,
            multicast_queues=tmp_data_provider.multicast_queues,
            max_batches=max_batches,
            out_queue=tmp_data_provider.out_queue)
        tmp_watcher.watch()

        multicast_queues = tmp_data_provider.multicast_queues
        self.assertEqual(len(multicast_queues), n_generators,
                         "expected one multicast queue per generator")

        gen_counter = [0] * len(multicast_queues)
        for _ in range(max_batches):
            for gen_index, multicast_queue in enumerate(multicast_queues):
                try:
                    item = multicast_queue.get(timeout=read_timeout)
                except Queue.Empty:
                    continue
                if item is not None:
                    gen_counter[gen_index] += 1

        for gen_index, gen_count in enumerate(gen_counter):
            self.assertEqual(
                gen_count, max_batches,
                "generator %d received %d of %d batches"
                % (gen_index, gen_count, max_batches))
    finally:
        # Always shut down what start_queues() started, even on failure.
        tmp_data_provider.hard_stop()
def test_filler_to_malloc(self):
    """Check the shared-memory layout built after the filler processes malloc requests.

    Starts the provider's queues and a ``FileFiller``, then runs
    ``process_malloc_requests()`` and ``make_shared_malloc()``. It asserts
    on the bucket, lock and data-set layout of ``shared_memory``.
    ``hard_stop()`` runs in a ``finally`` block, so the started queues and
    filler are shut down even when an assertion fails.
    """
    from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
        mock_expected_malloc_requests
    mock_sample_specification = copy.deepcopy(mock_sample_specification)
    mock_expected_malloc_requests = copy.deepcopy(
        mock_expected_malloc_requests)

    batch_size = 100
    tmp_data_provider = FileDataProvider(mock_sample_specification,
                                         batch_size=batch_size,
                                         n_readers=4,
                                         sleep_duration=sleep_duration)
    tmp_data_provider.start_queues()
    try:
        tmp_data_provider.start_filler(FileFiller, io_ctlr=IOController(),
                                       shape_reader_io_ctlr=None)
        tmp_data_provider.process_malloc_requests()
        tmp_data_provider.worker_count = 1
        tmp_data_provider.make_shared_malloc()

        bucket = 0
        data_set = 0
        lock = 0  # data sets are read from the entry at lock + 1

        # TODO: possible off-by-one error, needs a logic check. Once it is
        # resolved, also assert that shared memory was extended by one slot:
        #     len(shared_memory) == worker_count + 1
        # NOTE(review): this test indexes shared_memory[bucket] directly,
        # while make_shared_malloc(worker_id) is checked elsewhere as
        # shared_memory[worker][bucket]. Confirm which layout the
        # no-argument call produces.
        shared_memory = tmp_data_provider.shared_memory

        # check to make sure all buckets were allocated
        self.assertEqual(len(shared_memory), 10,
                         "shared memory buckets were not allocated correctly")
        # Multiple Generator start and end locks
        self.assertEqual(len(shared_memory[bucket]), 4,
                         "shared memory locks were not set correctly")
        # check to make sure all data sets were allocated
        self.assertEqual(
            len(shared_memory[bucket][lock + 1]),
            len(mock_expected_malloc_requests),
            "shared memory data sets were not allocated correctly")
        # check to make sure the shape matches our expected value
        self.assertEqual(
            shared_memory[bucket][lock + 1][data_set].shape,
            (100, 5),
            "shared memory shape doesn't match")
    finally:
        # Stop the queues/filler even if an assertion above failed.
        tmp_data_provider.hard_stop()
def test_make_shared_malloc(self):
    """Verify make_shared_malloc(worker_id) allocates a single worker slot.

    With ``worker_count`` set to 1 and worker 0 allocated, the provider's
    ``shared_memory`` must hold one entry. That entry must contain 10
    buckets, and the first data set after the lock in bucket 1 must have
    shape (100, 5).
    """
    from adlkit.data_provider.tests.mock_config import (
        mock_sample_specification as spec_template,
        mock_expected_malloc_requests as malloc_template,
    )
    spec = copy.deepcopy(spec_template)
    requests = copy.deepcopy(malloc_template)

    provider = FileDataProvider(spec, batch_size=100, n_readers=4)
    provider.malloc_requests = requests
    provider.worker_count = 1

    worker = 0
    provider.make_shared_malloc(worker)

    self.assertEqual(len(provider.shared_memory), 1,
                     "shared memory was not extended correctly ")

    worker_slot = provider.shared_memory[worker]
    # check to make sure all buckets were allocated
    self.assertEqual(len(worker_slot), 10,
                     "shared memory buckets were not allocated correctly")

    # bucket 1 -> entry just past lock index 0 -> first data set
    first_data_set = worker_slot[1][0 + 1][0]
    self.assertEqual(first_data_set.shape, (100, 5),
                     "shared memory shape doesn't match")
def test_watcher_multicast(self):
    """Check that BaseWatcher forwards 'out' batches to every proxy comm driver.

    Batches from the mock config are written to the provider's comm driver
    on the ``'out'`` channel with ``block=False``. Every write that reports
    failure lowers the expected batch count by one. After ``watch()``
    returns, each proxy comm driver is read ``max_batches`` times
    (non-blocking) and must have produced exactly ``max_batches`` non-None
    items. ``hard_stop()`` runs in a ``finally`` block so the provider is
    shut down even when an assertion fails.
    """
    from adlkit.data_provider.tests.mock_config import mock_read_batches, \
        mock_sample_specification, mock_expected_malloc_requests
    mock_read_batches = copy.deepcopy(mock_read_batches)
    mock_sample_specification = copy.deepcopy(mock_sample_specification)
    mock_expected_malloc_requests = copy.deepcopy(
        mock_expected_malloc_requests)

    batch_size = 100
    max_batches = 5
    n_generators = 5
    n_readers = 5

    tmp_data_provider = FileDataProvider(mock_sample_specification,
                                         batch_size=batch_size,
                                         read_multipler=2,
                                         make_one_hot=True,
                                         make_class_index=True,
                                         n_readers=n_readers,
                                         n_generators=n_generators,
                                         n_buckets=2,
                                         sleep_duration=sleep_duration)
    tmp_data_provider.start_queues()
    try:
        tmp_data_provider.malloc_requests = mock_expected_malloc_requests
        tmp_data_provider.make_shared_malloc()

        for batch in mock_read_batches:
            success = tmp_data_provider.comm_driver.write('out', batch,
                                                          block=False)
            # A falsy result is treated as "not enqueued" (presumably the
            # channel was full), so the watcher will see one batch fewer.
            if not success:
                max_batches -= 1

        tmp_watcher = BaseWatcher(
            worker_id=0,
            comm_driver=tmp_data_provider.comm_driver,
            proxy_comm_drivers=tmp_data_provider.proxy_comm_drivers,
            shared_memory_pointer=tmp_data_provider.shared_memory,
            max_batches=max_batches)
        tmp_watcher.watch()

        proxy_comm_drivers = tmp_data_provider.proxy_comm_drivers
        # One counter per driver; the old hard-coded [0] * 5 raised
        # IndexError or failed obscurely when the count differed.
        self.assertEqual(len(proxy_comm_drivers), n_generators,
                         "expected one proxy comm driver per generator")

        gen_counter = [0] * len(proxy_comm_drivers)
        for _ in range(max_batches):
            for index, comm_driver in enumerate(proxy_comm_drivers):
                if comm_driver.read('out', block=False) is not None:
                    gen_counter[index] += 1

        for index, gen_count in enumerate(gen_counter):
            self.assertEqual(
                gen_count, max_batches,
                "generator %d received %d of %d batches"
                % (index, gen_count, max_batches))
    finally:
        # Always shut down what start_queues() started, even on failure.
        tmp_data_provider.hard_stop()