Example #1
    def test_watcher_multicast(self):
        from adlkit.data_provider.tests.mock_config import mock_read_batches, \
            mock_sample_specification, mock_expected_malloc_requests
        mock_read_batches = copy.deepcopy(mock_read_batches)
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        batch_size = 100
        max_batches = 5
        n_generators = 5
        n_readers = 5
        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             read_multipler=2,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             n_readers=n_readers,
                                             n_generators=n_generators,
                                             n_buckets=2,
                                             sleep_duration=sleep_duration)

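        # Bypass the filler entirely (a sketch of the setup's intent): start
        # the queues, hand the provider the expected malloc requests directly,
        # and allocate shared memory for worker 0.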
        tmp_data_provider.start_queues()
        tmp_data_provider.malloc_requests = mock_expected_malloc_requests
        tmp_data_provider.make_shared_malloc(0)

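        # Pre-load the mock batches onto the out queue; any batch rejected by
        # a full queue lowers the number of batches the watcher is expected
        # to multicast.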
        for batch in mock_read_batches:
            try:
                tmp_data_provider.out_queue.put(batch, timeout=0)
            except Queue.Full:
                max_batches -= 1

        tmp_watcher = BaseWatcher(
            worker_id=0,
            shared_memory_pointer=tmp_data_provider.shared_memory,
            multicast_queues=tmp_data_provider.multicast_queues,
            max_batches=max_batches,
            out_queue=tmp_data_provider.out_queue)

        tmp_watcher.watch()

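        # Every generator's multicast queue should have received a copy of
        # each batch, so drain the queues and count arrivals per generator.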
        gen_counter = [0] * 5

        for loop in range(max_batches):
            for gen_index in range(len(tmp_data_provider.multicast_queues)):
                this = None
                try:
                    this = tmp_data_provider.multicast_queues[gen_index].get()
                except Queue.Empty:
                    pass
                finally:
                    if this is not None:
                        gen_counter[gen_index] += 1

        for gen_count in gen_counter:
            self.assertEqual(gen_count, max_batches)
Example #2
    def test_filler_to_malloc(self):
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
            mock_expected_malloc_requests
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        batch_size = 100

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             n_readers=4,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(FileFiller,
                                       io_ctlr=IOController(),
                                       shape_reader_io_ctlr=None)

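        # Presumably the filler posts malloc requests as it runs; collect
        # them here before allocating the shared memory blocks.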
        tmp_data_provider.process_malloc_requests()

        tmp_data_provider.worker_count = worker_id = 1

        tmp_data_provider.make_shared_malloc()

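        # Indices into the shared-memory layout exercised by the assertions
        # below: shared_memory[bucket] holds the generator locks, and slot
        # `lock + 1` holds the allocated data sets.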
        bucket = 0
        data_set = 0
        lock = 0
        # TODO !!possible off by one error, needs logic check!!

        # check to make sure the shared mem was extended
        # self.assertEqual(len(tmp_data_provider.shared_memory),
        #                  tmp_data_provider.worker_count + 1,
        #                  "shared memory was not extended correctly ")

        # check to make sure all buckets were allocated
        self.assertEqual(len(tmp_data_provider.shared_memory), 10,
                         "shared memory buckets were not allocated correctly")

        # Multiple Generator start and end locks
        self.assertEqual(len(tmp_data_provider.shared_memory[bucket]), 4,
                         "shared memory locks were not set correctly")

        # check to make sure all data sets were allocated
        self.assertEqual(
            len(tmp_data_provider.shared_memory[bucket][lock + 1]),
            len(mock_expected_malloc_requests),
            "shared memory data sets were not allocated correctly")

        # check to make sure the shape matches our expected value
        self.assertEqual(
            tmp_data_provider.shared_memory[bucket][lock + 1][data_set].shape,
            (100, 5), "shared memory shape doesn't match")
        tmp_data_provider.hard_stop()
Example #3
    def test_make_shared_malloc(self):
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
            mock_expected_malloc_requests
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        batch_size = 100

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             n_readers=4)

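        # Skip the filler: inject the expected malloc requests directly and
        # allocate shared memory for a single worker.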
        tmp_data_provider.malloc_requests = mock_expected_malloc_requests

        tmp_data_provider.worker_count = 1

        worker_id = 0

        tmp_data_provider.make_shared_malloc(worker_id)

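        # Here the first index is the worker id, so the layout checked below
        # is shared_memory[worker_id][bucket][lock + 1][data_set].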
        bucket = 1
        data_set = 0
        lock = 0

        self.assertEqual(len(tmp_data_provider.shared_memory), 1,
                         "shared memory was not extended correctly ")

        # check to make sure all buckets were allocated
        self.assertEqual(len(tmp_data_provider.shared_memory[worker_id]), 10,
                         "shared memory buckets were not allocated correctly")

        # check to make sure the shape matches our expected value

        self.assertEqual(
            tmp_data_provider.shared_memory[worker_id][bucket][lock + 1]
            [data_set].shape, (100, 5), "shared memory shape doesn't match")
Example #4
    def test_watcher_multicast(self):
        from adlkit.data_provider.tests.mock_config import mock_read_batches, \
            mock_sample_specification, mock_expected_malloc_requests
        mock_read_batches = copy.deepcopy(mock_read_batches)
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        batch_size = 100
        max_batches = 5
        n_generators = 5
        n_readers = 5
        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             read_multipler=2,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             n_readers=n_readers,
                                             n_generators=n_generators,
                                             n_buckets=2,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.malloc_requests = mock_expected_malloc_requests
        tmp_data_provider.make_shared_malloc()

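        # Same priming step as before, but via the comm_driver API:
        # non-blocking writes to the 'out' channel, lowering the expected
        # batch count for anything that was not accepted.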
        for batch in mock_read_batches:
            success = tmp_data_provider.comm_driver.write('out',
                                                          batch,
                                                          block=False)
            if not success:
                max_batches -= 1

        tmp_watcher = BaseWatcher(
            worker_id=0,
            comm_driver=tmp_data_provider.comm_driver,
            proxy_comm_drivers=tmp_data_provider.proxy_comm_drivers,
            shared_memory_pointer=tmp_data_provider.shared_memory,
            max_batches=max_batches)

        tmp_watcher.watch()

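        # Drain each generator's proxy comm driver and count the batches
        # that were multicast to it.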
        gen_counter = [0] * 5

        for loop in range(max_batches):
            for index, comm_driver in enumerate(
                    tmp_data_provider.proxy_comm_drivers):
                this = comm_driver.read('out', block=False)
                if this is not None:
                    gen_counter[index] += 1

        for gen_count in gen_counter:
            self.assertEqual(gen_count, max_batches)

        tmp_data_provider.hard_stop()