def test_start_reader(self):
        """Start one reader and check that it services the queued batches.

        Mock read requests are written to the 'in' channel, a ``FileReader``
        is started, and the 'out' channel is polled until ``max_batches``
        results have arrived or the 100 second deadline expires.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
            mock_batches, mock_expected_malloc_requests
        # Work on private copies so the module-level mock fixtures stay intact.
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_batches = copy.deepcopy(mock_batches)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=500,
                                             read_multiplier=2,
                                             make_one_hot=True,
                                             make_class_index=True,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()

        # inserting mock data into the queue
        for batch in mock_batches:
            success = tmp_data_provider.comm_driver.write('in', batch)
            self.assertTrue(success)

        # No filler is started here; the malloc requests are assigned directly
        # from the mock fixture.
        tmp_data_provider.malloc_requests = mock_expected_malloc_requests

        reader_id = tmp_data_provider.start_reader(FileReader,
                                                   io_ctlr=IOController())
        max_batches = 5

        out = list()
        end_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=100)
        # Non-blocking poll of the 'out' channel; None results are skipped.
        while datetime.datetime.utcnow() < end_time and len(
                out) != max_batches:
            batch = tmp_data_provider.comm_driver.read('out', block=False)
            if batch is not None:
                out.append(batch)

        for item in out:
            # Each result unpacks into (bucket_index, data_sets, batch_id);
            # only the bucket index is needed below.
            tmp_bucket_index, _data_sets, _batch_id = item
            # Check that each bucket was successfully updated
            self.assertEqual(
                tmp_data_provider.shared_memory[tmp_bucket_index][0].value, 1)

        # check for correct reader_id assignment
        self.assertEqual(len(tmp_data_provider.readers), reader_id + 1)

        # check that mock data successfully was processed
        # (assertEqual: the assertEquals alias is deprecated and removed in 3.12)
        self.assertEqual(
            len(out), max_batches,
            "test consumed {0} of {1} expected batches from the out_queue".
            format(len(out), max_batches))
Beispiel #2
0
    def test_fill_with_wrap(self):
        """
        Fill with ``wrap_examples=True`` and validate every read request.

        testing with 3 files, 3 classes, w/ shape (1000,5)

        The filler is expected to publish ``max_batches`` requests on the
        'in' channel, each one covering exactly ``read_size`` examples.

        :return:
        """
        from mock_config import mock_classes, mock_class_index_map, \
            mock_data_sets, mock_file_index_list
        # Private copies keep the shared mock fixtures unmodified.
        mock_classes = copy.deepcopy(mock_classes)
        mock_class_index_map = copy.deepcopy(mock_class_index_map)
        mock_data_sets = copy.deepcopy(mock_data_sets)
        mock_file_index_list = copy.deepcopy(mock_file_index_list)

        max_batches = 10
        batch_size = 500
        read_size = batch_size * 2

        comm_driver = QueueCommDriver({'ctl': 10, 'in': 100, 'malloc': 100})

        filler = FileFiller(classes=mock_classes,
                            class_index_map=mock_class_index_map,
                            comm_driver=comm_driver,
                            worker_id=1,
                            read_size=read_size,
                            max_batches=max_batches,
                            wrap_examples=True,
                            data_sets=mock_data_sets,
                            file_index_list=mock_file_index_list,
                            io_ctlr=IOController())

        filler.fill()

        out = 0
        end_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=10)
        while datetime.datetime.utcnow() < end_time:
            batch = comm_driver.read('in', block=False)

            if batch is not None:
                out += 1
                # presumably item[3] is a (start, stop) index pair, so the
                # differences add up to the number of examples requested
                # -- TODO confirm against FileFiller
                count = sum(item[3][1] - item[3][0] for item in batch)

                self.assertEqual(
                    count, read_size,
                    "read_batch was returned with read_size {0}, "
                    "instead of read_size {1} \n {2}".format(
                        count, read_size, batch))
            if out == max_batches:
                break

        self.assertEqual(
            out, max_batches,
            "test consumed {0} of {1} expected batches from the in_queue".
            format(out, max_batches))
    def test_generator_one_hot_and_class_index(self):
        """
        Drive a BaseGenerator off a running FileReader with the one-hot and
        class-index outputs enabled, and verify that every generated payload
        has four entries with the expected shapes.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
            mock_batches, mock_expected_malloc_requests, mock_file_index_list, mock_class_index_map
        # private copies of every fixture used below
        class_index_map = copy.deepcopy(mock_class_index_map)
        sample_spec = copy.deepcopy(mock_sample_specification)
        read_requests = copy.deepcopy(mock_batches)
        malloc_requests = copy.deepcopy(mock_expected_malloc_requests)
        file_index_list = copy.deepcopy(mock_file_index_list)

        batch_size = 500
        provider = FileDataProvider(sample_spec,
                                    batch_size=batch_size,
                                    read_multiplier=2,
                                    make_one_hot=True,
                                    make_class_index=True,
                                    sleep_duration=sleep_duration)
        provider.start_queues()

        # Count only the requests the 'in' channel accepted; that count is
        # what the generator is asked to produce.
        accepted = 0
        for request in read_requests:
            if provider.comm_driver.write('in', request, block=False):
                accepted += 1

        provider.malloc_requests = malloc_requests
        provider.start_reader(FileReader, io_ctlr=IOController())

        generator = BaseGenerator(
            comm_driver=provider.comm_driver,
            batch_size=provider.config.batch_size,
            read_size=provider.config.read_size,
            max_batches=accepted,
            shared_memory_pointer=provider.shared_memory,
            file_index_list=file_index_list,
            class_index_map=class_index_map)

        # expected shape of each payload entry, by position
        expected_shapes = [
            (batch_size, 5),
            (batch_size, 5),
            (batch_size, ),
            (batch_size, 3),
        ]

        produced = 0
        for payload in generator.generate():
            produced += 1
            self.assertEqual(len(payload), len(expected_shapes))
            for position, shape in enumerate(expected_shapes):
                self.assertEqual(payload[position].shape, shape)

        self.assertEqual(produced, accepted)
    def test_filler_to_malloc(self):
        """Run the filler, allocate shared memory and check its layout.

        After the malloc requests have been processed and
        ``make_shared_malloc`` has run, the test checks the number of buckets,
        the number of slots per bucket, the number of data sets and the shape
        of the first data set.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
            mock_expected_malloc_requests
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        batch_size = 100

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=batch_size,
                                             n_readers=4,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(FileFiller,
                                       io_ctlr=IOController(),
                                       shape_reader_io_ctlr=None)

        # Stop the provider even when an assertion below fails, so a failing
        # run does not leave the started workers behind.
        try:
            tmp_data_provider.process_malloc_requests()

            tmp_data_provider.worker_count = 1

            tmp_data_provider.make_shared_malloc()

            bucket = 0
            data_set = 0
            lock = 0
            # TODO !!possible off by one error, needs logic check!!

            # check to make sure the shared mem was extended
            # self.assertEqual(len(tmp_data_provider.shared_memory),
            #                  tmp_data_provider.worker_count + 1,
            #                  "shared memory was not extended correctly ")

            # check to make sure all buckets were allocated
            self.assertEqual(
                len(tmp_data_provider.shared_memory), 10,
                "shared memory buckets were not allocated correctly")

            # Multiple Generator start and end locks
            self.assertEqual(
                len(tmp_data_provider.shared_memory[bucket]), 4,
                "shared memory locks were not set correctly")

            # check to make sure all data sets were allocated
            # (the data sets are read from the slot right after index `lock`)
            self.assertEqual(
                len(tmp_data_provider.shared_memory[bucket][lock + 1]),
                len(mock_expected_malloc_requests),
                "shared memory data sets were not allocated correctly")

            # check to make sure the shape matches our expected value
            self.assertEqual(
                tmp_data_provider.shared_memory[bucket][lock + 1]
                [data_set].shape, (100, 5),
                "shared memory shape doesn't match")
        finally:
            tmp_data_provider.hard_stop()
    def test_end_to_end(self):
        """
        Run filler, reader and generator together through a FileDataProvider
        and verify both the layout of each generated payload and that exactly
        ``max_batches`` payloads are produced.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        sample_spec = copy.deepcopy(mock_sample_specification)

        max_batches = 5
        batch_size = 5
        provider = FileDataProvider(sample_spec,
                                    batch_size=batch_size,
                                    read_multiplier=2,
                                    make_one_hot=True,
                                    make_class_index=True,
                                    make_file_index=True,
                                    n_readers=4,
                                    sleep_duration=sleep_duration,
                                    max_batches=max_batches)

        # bring the pipeline up stage by stage
        provider.start_queues()
        provider.start_filler(FileFiller, io_ctlr=IOController())
        provider.process_malloc_requests()
        provider.start_reader(FileReader, io_ctlr=IOController())

        generator_id = provider.start_generator(BaseGenerator)
        generator = provider.generators[generator_id]

        # expected shape of the first four payload entries, by position
        expected_shapes = [
            (batch_size, 5),
            (batch_size, 5),
            (batch_size, ),
            (batch_size, 3),
        ]

        produced = 0
        for payload in generator.generate():
            produced += 1
            self.assertEqual(len(payload), 5)
            for position, shape in enumerate(expected_shapes):
                self.assertEqual(payload[position].shape, shape)
            # the fifth entry is checked by the length of its first element
            self.assertEqual(len(payload[4][0]), 2)

        leak_message = "we have a leaky generator here {0} != {1}".format(
            produced, max_batches)
        self.assertEqual(produced, max_batches, leak_message)
    def test_generator(self):
        """Feed mock requests through a reader and pull them from a generator.

        With neither one-hot nor class-index output requested, every generated
        payload is expected to hold exactly two entries of shape
        ``(batch_size, 5)``.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, mock_batches, \
            mock_expected_malloc_requests, mock_file_index_list, mock_class_index_map
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_batches = copy.deepcopy(mock_batches)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)
        mock_file_index_list = copy.deepcopy(mock_file_index_list)
        mock_class_index_map = copy.deepcopy(mock_class_index_map)

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=500,
                                             read_multiplier=2,
                                             n_readers=3,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        max_batches = len(mock_batches)
        # inserting mock data into the queue (should probably be a function at this point)
        for batch in mock_batches:
            success = tmp_data_provider.comm_driver.write('in',
                                                          batch,
                                                          block=False)
            # A rejected write lowers the number of batches the generator
            # will be asked for.
            if not success:
                max_batches -= 1

        tmp_data_provider.malloc_requests = mock_expected_malloc_requests

        tmp_data_provider.start_reader(FileReader, io_ctlr=IOController())

        this = BaseGenerator(
            comm_driver=tmp_data_provider.comm_driver,
            batch_size=tmp_data_provider.config.batch_size,
            read_size=tmp_data_provider.config.read_size,
            max_batches=max_batches,
            shared_memory_pointer=tmp_data_provider.shared_memory,
            file_index_list=mock_file_index_list,
            class_index_map=mock_class_index_map)

        batch_size = tmp_data_provider.config.batch_size
        self.assertGreater(max_batches, 0)
        for _ in range(max_batches):
            # As before, generate() is called anew on every iteration; the
            # builtin next() replaces the Python-2-only .next() method.
            tmp = next(this.generate())
            self.assertEqual(len(tmp), 2)
            self.assertEqual(tmp[0].shape, (batch_size, 5))
            self.assertEqual(tmp[1].shape, (batch_size, 5))
    def test_writer(self):
        """Run a provider with a ``BaseWriter`` and check that output appears.

        The provider is started with one writer configuration; the test waits
        up to 10 seconds for the writers to stop, asserts that ``data_dst``
        exists, then stops the provider and removes the output.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification

        mock_sample_specification = copy.deepcopy(mock_sample_specification)

        batch_size = 100
        max_batches = 5
        n_generators = 5
        n_readers = 5
        n_buckets = 10
        q_multiplier = 5

        data_dst = '/{}/{}'.format(os.getcwd(), int(time.time()))
        tmp_data_provider = WatchedH5FileDataProvider(
            mock_sample_specification,
            batch_size=batch_size,
            # was misspelled ``read_multipler``; every other test in this file
            # passes ``read_multiplier``
            read_multiplier=2,
            make_one_hot=True,
            make_class_index=True,
            n_readers=n_readers,
            n_generators=n_generators,
            q_multiplier=q_multiplier,
            n_buckets=n_buckets,
            sleep_duration=sleep_duration,
            max_batches=max_batches,
            writer_config=[{
                'data_dst': data_dst,
                'io_ctlr': IOController(),
                # 'pre_write_function':''
            }])

        success = tmp_data_provider.start(writer_class=BaseWriter)

        self.assertTrue(success,
                        msg='DataProvider was not started successfully.')

        end_time = datetime.timedelta(seconds=10) + datetime.datetime.utcnow()

        while datetime.datetime.utcnow() < end_time:
            # Hold the writer's stop lock only for the status check; sleeping
            # while holding it would keep the lock taken for the whole wait.
            with tmp_data_provider.writers[0].stop.get_lock():
                stopped = tmp_data_provider.writers_have_stopped()
            if stopped:
                break
            time.sleep(1)

        exists = os.path.exists(data_dst)
        self.assertTrue(exists)

        tmp_data_provider.hard_stop()
        os.remove(data_dst)
Beispiel #8
0
    def test_inform_data_provider(self):
        """
        Run a FileFiller and verify that the requests it publishes on the
        'malloc' channel match the expected mock malloc requests, pairwise
        and in order.
        """
        from mock_config import mock_class_index_map, mock_classes, \
            mock_data_sets, mock_expected_malloc_requests, mock_file_index_list
        # private copies of the fixtures handed to the filler
        classes = copy.deepcopy(mock_classes)
        class_index_map = copy.deepcopy(mock_class_index_map)
        data_sets = copy.deepcopy(mock_data_sets)
        file_index_list = copy.deepcopy(mock_file_index_list)

        batch_size = 200
        max_batches = 5

        comm_driver = QueueCommDriver({'ctl': 10, 'in': 100, 'malloc': 100})

        filler = FileFiller(classes=classes,
                            class_index_map=class_index_map,
                            comm_driver=comm_driver,
                            worker_id=1,
                            read_size=batch_size * 2,
                            max_batches=max_batches,
                            data_sets=data_sets,
                            file_index_list=file_index_list,
                            io_ctlr=IOController())
        filler.fill()

        received = []
        expected_total = len(mock_expected_malloc_requests)
        deadline = datetime.datetime.utcnow() + datetime.timedelta(seconds=10)
        # poll without blocking until everything arrived or time runs out
        while datetime.datetime.utcnow() < deadline:
            chunk = comm_driver.read('malloc', block=False)
            if chunk is not None:
                received.extend(chunk)
            if len(received) == expected_total:
                break

        self.assertEqual(len(received), expected_total)
        for actual, expected in zip(received, mock_expected_malloc_requests):
            self.assertEqual(actual, expected)
    def test_process_malloc_requests(self):
        """Check the malloc requests collected from a started filler.

        After ``process_malloc_requests`` the provider's ``malloc_requests``
        are compared pairwise with the expected mock requests.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification, \
            mock_expected_malloc_requests
        mock_sample_specification = copy.deepcopy(mock_sample_specification)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=100,
                                             n_readers=4,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(FileFiller, io_ctlr=IOController())

        # The sibling tests that start a filler finish with hard_stop(); do
        # the same here, and do it even when an assertion fails.
        try:
            tmp_data_provider.process_malloc_requests()

            # NOTE(review): zip() stops at the shorter sequence, so a length
            # mismatch is not detected here -- confirm whether that is intended.
            for request, expected_request in zip(
                    tmp_data_provider.malloc_requests,
                    mock_expected_malloc_requests):
                self.assertEqual(request, expected_request)
        finally:
            tmp_data_provider.hard_stop()
Beispiel #10
0
    def test_start_filler(self):
        """Start a filler through the provider and validate its read requests.

        The 'in' channel is polled for up to 10 seconds; each request received
        must cover exactly ``config.read_size`` examples and ``max_batches``
        requests must arrive in total.
        """
        from adlkit.data_provider.tests.mock_config import mock_sample_specification
        mock_sample_specification = copy.deepcopy(mock_sample_specification)

        tmp_data_provider = FileDataProvider(mock_sample_specification,
                                             batch_size=100,
                                             read_multiplier=2,
                                             n_readers=4,
                                             sleep_duration=sleep_duration)

        tmp_data_provider.start_queues()
        tmp_data_provider.start_filler(FileFiller, io_ctlr=IOController())

        max_batches = 10

        # Stop the provider even when an assertion fails.
        try:
            out = list()
            end_time = (datetime.datetime.utcnow() +
                        datetime.timedelta(seconds=10))
            while datetime.datetime.utcnow() < end_time and len(
                    out) != max_batches:
                batch = tmp_data_provider.comm_driver.read('in', block=False)

                if batch is not None:
                    out.append(batch)
                    # presumably item[3] is a (start, stop) index pair
                    # -- TODO confirm against FileFiller
                    count = sum(item[3][1] - item[3][0] for item in batch)

                    # The message now reports the value actually compared
                    # against (read_size, not batch_size).
                    self.assertEqual(
                        count, tmp_data_provider.config.read_size,
                        "batch was returned with read_size {0}, instead of read_size {1}"
                        .format(count, tmp_data_provider.config.read_size))

            self.assertEqual(
                len(out), max_batches,
                "test consumed {0} of {1} expected batches from the in_queue".
                format(len(out), max_batches))
        finally:
            tmp_data_provider.hard_stop()
Beispiel #11
0
    def test_init(self):
        """
        Write random data through a BaseWriter into 'hello.h5', then reopen
        the file with an H5DataIODriver and verify that every key in it has
        the shape of the generated data.
        """
        shape = [5, 10, 80, 40]
        expected_shape = tuple(shape)
        data_dst = 'hello.h5'

        comm_driver = QueueCommDriver({'ctl': 10})

        # start from a clean slate; a missing file is fine
        try:
            os.remove(data_dst)
        except OSError:
            pass

        def _random_rows():
            # iterate the first axis of a random array of the given shape
            for row in np.random.rand(*shape):
                yield row

        writer = BaseWriter(
            worker_id=1,
            max_batches=shape[0],
            comm_driver=comm_driver,
            data_src=_random_rows(),
            data_dst=data_dst,
            io_ctlr=IOController())
        writer.write()

        driver = H5DataIODriver()
        with driver:
            handle = driver.get(data_dst)
            for key in handle.keys():
                stored_shape = handle[key].shape
                print(stored_shape)
                self.assertEqual(stored_shape, expected_shape)

        os.remove(data_dst)
Beispiel #12
0
    def test_make_file_index(self):
        """Run a ``FileReader`` with ``make_file_index=True`` on mock requests.

        Read requests are queued on the 'in' channel, hand-built shared memory
        buckets are given to the reader, and the 'out' channel is polled until
        one result per request has arrived or 10 seconds have passed.
        """
        from mock_config import mock_batches, mock_expected_malloc_requests, \
            mock_class_index_map, mock_file_index_malloc, mock_file_index_list

        mock_batches = copy.deepcopy(mock_batches)
        mock_class_index_map = copy.deepcopy(mock_class_index_map)
        mock_file_index_malloc = copy.deepcopy(mock_file_index_malloc)
        mock_file_index_list = copy.deepcopy(mock_file_index_list)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)
        # Append the file-index malloc entries so buffers are built for them too.
        mock_expected_malloc_requests.extend(mock_file_index_malloc)

        max_batches = len(mock_batches)
        batch_size = 500
        read_size = 2 * batch_size
        bucket_length = 10
        comm_driver = QueueCommDriver({'ctl': 10, 'in': 100, 'out': 100})

        # building queue up with read requests
        for batch in mock_batches:
            success = comm_driver.write('in', batch)
            self.assertTrue(success)

        reader_id = 0

        # Build the buckets in a real list: item assignment into range() only
        # works on Python 2, where range() returns a list.
        shared_data_pointer = []
        for _ in range(bucket_length):
            data_sets = []
            for request in mock_expected_malloc_requests:
                # TODO requests are not ordered!!!
                # TODO not sure it matters as long as its consistent
                # reshape the requested shape to match the read_size
                shape = (read_size, ) + request[1]

                # multiprocessing.Array treats a non-int second argument as an
                # initializer sequence, so convert the NumPy scalar explicitly.
                shared_array_base = multiprocessing.Array(ctypes.c_double,
                                                          int(np.prod(shape)),
                                                          lock=False)
                shared_array = np.ctypeslib.as_array(shared_array_base)
                shared_array = shared_array.reshape(shape)
                data_sets.append(shared_array)

            state = multiprocessing.Value('i', 0)
            generator_start_counter = multiprocessing.Value('i', 0)
            generator_end_counter = multiprocessing.Value('i', 0)
            shared_data_pointer.append([
                state, data_sets, generator_start_counter,
                generator_end_counter
            ])

        reader = FileReader(worker_id=reader_id,
                            comm_driver=comm_driver,
                            shared_memory_pointer=shared_data_pointer,
                            max_batches=max_batches,
                            read_size=read_size,
                            class_index_map=mock_class_index_map,
                            make_one_hot=True,
                            make_class_index=True,
                            make_file_index=True,
                            file_index_list=mock_file_index_list,
                            io_ctlr=IOController())

        reader.read()

        out = list()
        end_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=10)
        while datetime.datetime.utcnow() < end_time and len(
                out) != max_batches:
            batch = comm_driver.read('out', block=False)
            if batch is not None:
                out.append(batch)

        # Results are read from the 'out' channel, so the message names it.
        self.assertEqual(
            len(out), max_batches,
            "test consumed {0} of {1} expected batches from the out_queue".
            format(len(out), max_batches))
Beispiel #13
0
    def test_file_caching(self):
        """
        Read mock requests with a ``FileReader`` and check handle cleanup.

        max_batches and batch_size are directly correlated to test input data

        make sure gen_rand_data has the correct shape

        After all results have been consumed, every handle still listed in the
        IO controller's ``file_handle_holder`` is closed once more; each of
        those close calls is expected to raise.

        :return:
        """

        from mock_config import mock_batches, mock_expected_malloc_requests, \
            mock_class_index_map, mock_file_index_list

        mock_batches = copy.deepcopy(mock_batches)
        mock_expected_malloc_requests = copy.deepcopy(
            mock_expected_malloc_requests)
        mock_class_index_map = copy.deepcopy(mock_class_index_map)
        mock_file_index_list = copy.deepcopy(mock_file_index_list)

        max_batches = len(mock_batches)
        batch_size = 500
        read_size = 2 * batch_size
        bucket_length = 10

        comm_driver = QueueCommDriver({'ctl': 10, 'in': 100, 'out': 100})

        # building queue up with read requests
        for batch in mock_batches:
            success = comm_driver.write('in', batch)
            self.assertTrue(success)

        reader_id = 0

        # Build the buckets in a real list: item assignment into range() only
        # works on Python 2, where range() returns a list.
        shared_data_pointer = []
        for _ in range(bucket_length):
            data_sets = []
            for request in mock_expected_malloc_requests:
                # reshape the requested shape to match the read_size
                shape = (read_size, ) + request[1]

                # multiprocessing.Array treats a non-int second argument as an
                # initializer sequence, so convert the NumPy scalar explicitly.
                shared_array_base = multiprocessing.Array(ctypes.c_double,
                                                          int(np.prod(shape)),
                                                          lock=False)
                shared_array = np.ctypeslib.as_array(shared_array_base)
                shared_array = shared_array.reshape(shape)
                data_sets.append(shared_array)

            state = multiprocessing.Value('i', 0)
            generator_start_counter = multiprocessing.Value('i', 0)
            generator_end_counter = multiprocessing.Value('i', 0)
            shared_data_pointer.append([
                state, data_sets, generator_start_counter,
                generator_end_counter
            ])

        reader = FileReader(worker_id=reader_id,
                            comm_driver=comm_driver,
                            shared_memory_pointer=shared_data_pointer,
                            max_batches=max_batches,
                            read_size=read_size,
                            class_index_map=mock_class_index_map,
                            file_index_list=mock_file_index_list,
                            io_ctlr=IOController())

        reader.read()

        out = list()
        end_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=10)
        while datetime.datetime.utcnow() < end_time and len(
                out) != max_batches:
            batch = comm_driver.read('out', block=False)
            if batch is not None:
                out.append(batch)

        # Results are read from the 'out' channel, so the message names it.
        self.assertEqual(
            len(out), max_batches,
            "test consumed {0} of {1} expected batches from the out_queue".
            format(len(out), max_batches))

        target = 0
        self.assertGreaterEqual(len(reader.io_ctlr('').file_handle_holder), 0)
        for handle in reader.io_ctlr('').file_handle_holder:
            # Here we try to close the file again, this will raise an Exception and thus means the caching cleaned up
            # successfully. if 0 then we closed everything successfully with the __exit__ function
            target += 1
            try:
                handle.close()
            except Exception:
                # Deliberately broad: the raise itself is the signal, and the
                # exception type depends on the handle implementation.
                target -= 1

        self.assertEqual(target, 0)
Beispiel #14
0
    def test_fill_without_wrap(self):
        """
        Fill with ``wrap_examples=False`` and validate every read request.

        testing with 3 files, 3 classes, w/ shape (1000,5)

        expected_batches is directly correlated to input batches and expected examples per
        class

        Although ``max_batches`` is 4, only ``expected_batches`` (2) requests
        are expected on the 'in' channel, each covering exactly ``read_size``
        examples.

        :return:
        """
        from adlkit.data_provider.tests.mock_config import mock_classes, mock_class_index_map, \
            mock_data_sets, mock_file_index_list
        # Private copies keep the shared mock fixtures unmodified.
        mock_classes = copy.deepcopy(mock_classes)
        mock_class_index_map = copy.deepcopy(mock_class_index_map)
        mock_data_sets = copy.deepcopy(mock_data_sets)
        mock_file_index_list = copy.deepcopy(mock_file_index_list)

        expected_batches = 2
        max_batches = 4
        batch_size = 500
        read_size = batch_size * 2

        comm_driver = QueueCommDriver({'ctl': 10, 'in': 100, 'malloc': 100})

        filler = FileFiller(classes=mock_classes,
                            class_index_map=mock_class_index_map,
                            comm_driver=comm_driver,
                            worker_id=1,
                            read_size=read_size,
                            max_batches=max_batches,
                            wrap_examples=False,
                            data_sets=mock_data_sets,
                            file_index_list=mock_file_index_list,
                            io_ctlr=IOController())

        filler.fill()

        out = 0
        end_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=10)
        while datetime.datetime.utcnow() < end_time:
            batch = comm_driver.read('in', block=False)
            if batch is not None:
                out += 1
                # presumably item[3] is a (start, stop) index pair, so the
                # differences add up to the number of examples requested
                # -- TODO confirm against FileFiller
                count = sum(item[3][1] - item[3][0] for item in batch)

                self.assertEqual(
                    count, read_size,
                    "read_batch was returned with read_size {0}, "
                    "instead of read_size {1} \n {2}".format(
                        count, read_size, batch))
            if out == expected_batches:
                break

        self.assertEqual(
            out, expected_batches,
            "test consumed {0} of {1} expected batches from the in_queue".
            format(out, expected_batches))