Example #1
def test_random_shuffling_buffer_stream_through():
    """Feed a 0:99 sequence through a RandomShufflingBuffer. Check that the order has changed."""
    input_sequence = range(100)
    a = _feed_a_sequence_through_the_queue(RandomShufflingBuffer(10, 3), input_sequence)
    b = _feed_a_sequence_through_the_queue(RandomShufflingBuffer(10, 3), input_sequence)
    assert len(a) == len(input_sequence)
    assert set(a) == set(b)
    assert a != b
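The helper `_feed_a_sequence_through_the_queue` is referenced above but not shown. Below is a minimal sketch of what such a helper could look like, assuming only the `add_many`/`can_retrieve`/`retrieve`/`finish` interface used throughout these examples; the name and the exact draining strategy are assumptions, not the library's implementation.

def _feed_a_sequence_through_the_queue(queue, input_sequence):
    """Hypothetical helper: push a sequence through a shuffling buffer and collect the output order."""
    retrieved = []
    for value in input_sequence:
        queue.add_many([value])
        # Pull an element whenever the buffer holds enough samples to allow retrieval.
        if queue.can_retrieve():
            retrieved.append(queue.retrieve())
    # No more input: finish() lets the buffer be drained below min_after_retrieve.
    queue.finish()
    while queue.can_retrieve():
        retrieved.append(queue.retrieve())
    return retrieved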
Example #2
    def _iter_impl(self):
        """
        The data loader iterator stops the for-loop when the reader runs out of samples.
        """
        # As we iterate over incoming samples, we are going to store them in `self._batch_acc`, until we have a batch of
        # the requested batch_size ready.

        keys = None
        if self.shuffling_queue_capacity > 0:
            # We can not know what a reasonable value for the extra capacity would be, so we set a huge number
            # and effectively give up on the unbounded-growth protection mechanism.
            min_after_dequeue = self.shuffling_queue_capacity - 1
            self._shuffling_buffer = RandomShufflingBuffer(
                self.shuffling_queue_capacity,
                min_after_retrieve=min_after_dequeue,
                extra_capacity=100000000)
        else:
            self._shuffling_buffer = NoopShufflingBuffer()

        for row in self.reader:
            # The default collate does not work nicely on namedtuples and treats them as lists.
            # Using dicts will result in the yielded structures being dicts as well.
            row_as_dict = row._asdict()

            keys = row_as_dict.keys()

            # Promote some types that are incompatible with pytorch to be pytorch friendly.
            _sanitize_pytorch_types(row_as_dict)

            # Add rows to shuffling buffer
            if not self.reader.is_batched_reader:
                self._shuffling_buffer.add_many([row_as_dict])
            else:
                # Transposition:
                #   row_as_dict:        {'a': [1,2,3], 'b':[4,5,6]}
                #   row_group_as_tuple: [(1, 4), (2, 5), (3, 6)]
                # The order within a tuple is defined by key order in 'keys'
                row_group_as_tuple = list(zip(*(row_as_dict[k] for k in keys)))

                # Add the data row-by-row into the shuffling buffer. This is a fairly slow
                # implementation; a faster way to shuffle could probably be found, perhaps
                # at the expense of larger memory consumption...
                self._shuffling_buffer.add_many(row_group_as_tuple)

            # _yield_batches will emit as many batches as the shuffling buffer allows (RandomShufflingBuffer
            # avoids underflowing below a certain number of samples to guarantee some sample decorrelation).
            for batch in self._yield_batches(keys):
                yield batch

        # Once the reader can not read new rows, we might still have a bunch of rows waiting in the shuffling buffer.
        # Telling the shuffling buffer that we are finished allows it to be depleted completely, regardless of its
        # min_after_dequeue setting.
        self._shuffling_buffer.finish()

        for batch in self._yield_batches(keys):
            yield batch

        # Yield the last and partial batch
        if self._batch_acc:
            yield self.collate_fn(self._batch_acc)
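_yield_batches is referenced in _iter_impl but not shown in this snippet. Below is a minimal sketch of what it could look like, assuming it drains the shuffling buffer into self._batch_acc and collates full batches of batch_size; the tuple-to-dict reconstruction mirrors the transposition above, and the exact implementation is an assumption.

    def _yield_batches(self, keys):
        # Hypothetical sketch: drain the shuffling buffer and emit collated batches of batch_size.
        while self._shuffling_buffer.can_retrieve():
            post_shuffled_row = self._shuffling_buffer.retrieve()
            if not isinstance(post_shuffled_row, dict):
                # Batched readers stored rows as tuples; rebuild a dict using the caller's key order.
                post_shuffled_row = dict(zip(keys, post_shuffled_row))

            self._batch_acc.append(post_shuffled_row)

            # A full batch is ready: collate it and start accumulating the next one.
            if len(self._batch_acc) == self.batch_size:
                yield self.collate_fn(self._batch_acc)
                self._batch_acc = []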
Example #3
def test_longer_random_sequence_of_queue_ops():
    """A long random sequence of added and retrieved values"""
    q = RandomShufflingBuffer(100, 80)

    for _ in six.moves.xrange(10000):
        if q.can_add():
            q.add_many(np.random.random((np.random.randint(1, 10),)))
        assert q.size < 100 + 10
        for _ in range(np.random.randint(1, 10)):
            if not q.can_retrieve():
                break
            # Make sure we never go below `min_after_retrieve` elements
            assert 80 <= q.size
            q.retrieve()
Example #4
    def __init__(self,
                 reader,
                 batch_size=1,
                 collate_fn=decimal_friendly_collate,
                 shuffling_queue_capacity=0):
        """
        Initializes a data loader object, with a default collate.

        Number of epochs is defined by the configuration of the reader argument.

        An optional shuffling queue is created if shuffling_queue_capacity is greater than 0. No samples will be
        returned to a user by the ``DataLoader`` until the queue is full. After that, batches of `batch_size`
        will be created by uniformly sampling the shuffling queue. Once no more samples are available from the data
        reader, the shuffling queue is allowed to be consumed until no further samples are available.

        Note that the last returned batch could have fewer than ``batch_size`` samples.

        NOTE: if you are using ``make_batch_reader``, this shuffling queue will randomize the order of entire
        batches and will not change the order of elements within a batch. This is likely not what you intend to do.

        :param reader: petastorm Reader instance
        :param batch_size: the number of items to return per batch; factored into the len() of this reader
        :param collate_fn: an optional callable to merge a list of samples to form a mini-batch.
        :param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`RandomShufflingBuffer`
          instance. If set to 0, no shuffling will be done.
        """
        self.reader = reader
        self.batch_size = batch_size
        self.collate_fn = collate_fn

        # _batch_acc accumulates samples for a single batch.
        self._batch_acc = []
        if shuffling_queue_capacity > 0:
            # We can not know what a reasonable value for the extra capacity would be, so we set a huge number
            # and effectively give up on the unbounded-growth protection mechanism.
            min_after_dequeue = shuffling_queue_capacity - 1
            self._shuffling_buffer = RandomShufflingBuffer(
                shuffling_queue_capacity,
                min_after_retrieve=min_after_dequeue,
                extra_capacity=100000000)
        else:
            self._shuffling_buffer = NoopShufflingBuffer()
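A hedged usage sketch of this loader, assuming the usual petastorm.pytorch.DataLoader import path; the dataset URL is a placeholder.

# Minimal usage sketch; the URL below is a placeholder for an existing Petastorm dataset.
from petastorm import make_reader
from petastorm.pytorch import DataLoader

with make_reader('file:///tmp/mydataset', num_epochs=1) as reader:
    loader = DataLoader(reader, batch_size=16, shuffling_queue_capacity=1000)
    for batch in loader:
        # Each batch is a dict mapping column names to collated values.
        print(list(batch.keys()))
        break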
Example #5
def make_reader(dataset_url,
                schema_fields=None,
                reader_pool_type='thread',
                workers_count=10,
                pyarrow_serialize=False,
                results_queue_size=50,
                shuffle_row_groups=True,
                shuffle_row_drop_partitions=1,
                predicate=None,
                rowgroup_selector=None,
                num_epochs=1,
                cur_shard=None,
                shard_count=None,
                cache_type='null',
                cache_location=None,
                cache_size_limit=None,
                cache_row_size_estimate=None,
                cache_extra_settings=None,
                hdfs_driver='libhdfs3',
                reader_engine='reader_v1',
                reader_engine_params=None):
    """
    Creates an instance of Reader for reading Petastorm datasets. A Petastorm dataset is a dataset generated using
    :func:`~petastorm.etl.dataset_metadata.materialize_dataset` context manager as explained
    `here <https://petastorm.readthedocs.io/en/latest/readme_include.html#generating-a-dataset>`_.

    See :func:`~petastorm.make_batch_reader` to read from a Parquet store that was not generated using
    :func:`~petastorm.etl.dataset_metadata.materialize_dataset`.

    :param dataset_url: a file path or a URL to a parquet directory,
        e.g. ``'hdfs://some_hdfs_cluster/user/yevgeni/parquet8'``, or ``'file:///tmp/mydataset'``
        or ``'s3://bucket/mydataset'``.
    :param schema_fields: Can be: a list of unischema fields and/or regex pattern strings; ``None`` to read all fields;
            or an NGram object, in which case an NGram of the specified fields is returned.
    :param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'process', 'dummy']
        denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread'
    :param workers_count: An int for the number of workers to use in the reader pool. This is only used for the
        thread or process pool. Defaults to 10
    :param pyarrow_serialize: Whether to use pyarrow for serialization. Currently only applicable to process pool.
        Defaults to False.
    :param results_queue_size: Size of the results queue to store prefetched rows. Currently only applicable to
        thread reader pool type.
    :param shuffle_row_groups: Whether to shuffle row groups (the order in which full row groups are read)
    :param shuffle_row_drop_partitions: This is a positive integer which determines how many partitions to
        break up a row group into for increased shuffling in exchange for worse performance (extra reads).
        For example, if you specify 2, each row group read will drop half of the rows within every row group and
        read the remaining rows in separate reads. It is recommended to keep this number below the regular row
        group size in order to not waste reads which drop all rows.
    :param predicate: instance of :class:`.PredicateBase` object to filter rows to be returned by reader. The predicate
        will be passed a single row and must return a boolean value indicating whether to include it in the results.
    :param rowgroup_selector: instance of row group selector object to select row groups to be read
    :param num_epochs: An epoch is a single pass over all rows in the dataset. Setting ``num_epochs`` to
        ``None`` will result in an infinite number of epochs.
    :param cur_shard: An int denoting the current shard number. Each node reading a shard should
        pass in a unique shard number in the range [0, shard_count). shard_count must be supplied as well.
        Defaults to None
    :param shard_count: An int denoting the number of shards to break this dataset into. Defaults to None
    :param cache_type: A string denoting the cache type, if desired. Options are [None, 'null', 'local-disk'] to
        either have a null/noop cache or a cache implemented using diskcache. Caching is useful when communication
        to the main data store is either slow or expensive and the local machine has large enough storage
        to store the entire dataset (or a partition of a dataset if shard_count is used). By default a null cache is used.
    :param cache_location: A string denoting the location or path of the cache.
    :param cache_size_limit: An int specifying the size limit of the cache in bytes
    :param cache_row_size_estimate: An int specifying the estimated size of a row in the dataset
    :param cache_extra_settings: A dictionary of extra settings to pass to the cache implementation.
    :param hdfs_driver: A string denoting the hdfs driver to use (if using a dataset on hdfs). Current choices are
        libhdfs (java through JNI) or libhdfs3 (C++)
    :param reader_engine: Multiple engine implementations exist ('reader_v1' and 'experimental_reader_v2'). 'reader_v1'
        (the default value) selects a stable reader implementation.
    :param reader_engine_params: For advanced usage: a dictionary with arguments passed directly to a reader
        implementation constructor chosen by the ``reader_engine`` argument. You should not use this parameter
        unless you are fine-tuning the reader.
    :return: A :class:`Reader` object
    """

    if dataset_url is None or not isinstance(dataset_url, six.string_types):
        raise ValueError("""dataset_url must be a string""")

    dataset_url = dataset_url[:-1] if dataset_url[-1] == '/' else dataset_url
    logger.debug('dataset_url: %s', dataset_url)

    resolver = FilesystemResolver(dataset_url, hdfs_driver=hdfs_driver)
    filesystem = resolver.filesystem()
    dataset_path = resolver.get_dataset_path()

    if cache_type is None or cache_type == 'null':
        cache = NullCache()
    elif cache_type == 'local-disk':
        cache = LocalDiskCache(cache_location, cache_size_limit,
                               cache_row_size_estimate, **cache_extra_settings
                               or {})
    else:
        raise ValueError('Unknown cache_type: {}'.format(cache_type))

    # Fail if this is a non-petastorm dataset. Typically, a Parquet store will have hundreds of thousands of rows in a
    # single rowgroup. Using the PyDictReaderWorker or ReaderV2 implementation is very inefficient as it processes data
    # on a row by row basis. ArrowReaderWorker (used by make_batch_reader) is much more efficient in these cases.
    try:
        dataset_metadata.get_schema_from_dataset_url(dataset_url)
    except PetastormMetadataError:
        raise RuntimeError(
            'Currently make_reader supports reading only Petastorm datasets. '
            'To read from a non-Petastorm Parquet store use make_batch_reader')

    if reader_engine == 'reader_v1':
        if reader_pool_type == 'thread':
            reader_pool = ThreadPool(workers_count, results_queue_size)
        elif reader_pool_type == 'process':
            if pyarrow_serialize:
                serializer = PyArrowSerializer()
            else:
                serializer = PickleSerializer()
            reader_pool = ProcessPool(workers_count, serializer)
        elif reader_pool_type == 'dummy':
            reader_pool = DummyPool()
        else:
            raise ValueError(
                'Unknown reader_pool_type: {}'.format(reader_pool_type))

        # Create a dictionary with all Reader (v1) constructor parameters, so we can merge with reader_engine_params if specified
        kwargs = {
            'schema_fields': schema_fields,
            'reader_pool': reader_pool,
            'shuffle_row_groups': shuffle_row_groups,
            'shuffle_row_drop_partitions': shuffle_row_drop_partitions,
            'predicate': predicate,
            'rowgroup_selector': rowgroup_selector,
            'num_epochs': num_epochs,
            'cur_shard': cur_shard,
            'shard_count': shard_count,
            'cache': cache,
        }

        if reader_engine_params:
            kwargs.update(reader_engine_params)

        try:
            return Reader(filesystem,
                          dataset_path,
                          worker_class=PyDictReaderWorker,
                          **kwargs)
        except PetastormMetadataError as e:
            logger.error('Unexpected exception: %s', str(e))
            raise RuntimeError(
                'make_reader has failed. If you were trying to open a Parquet store that was not '
                'created using Petastorm materialize_dataset and it contains only scalar columns, '
                'you may use make_batch_reader to read it.\n'
                'Inner exception: %s' % str(e))

    elif reader_engine == 'experimental_reader_v2':
        if reader_pool_type == 'thread':
            decoder_pool = ThreadPoolExecutor(workers_count)
        elif reader_pool_type == 'process':
            decoder_pool = ProcessPoolExecutor(workers_count)
        elif reader_pool_type == 'dummy':
            decoder_pool = SameThreadExecutor()
        else:
            raise ValueError(
                'Unknown reader_pool_type: {}'.format(reader_pool_type))

        # TODO(yevgeni): once ReaderV2 is ready to be out of experimental status, we should extend
        # the make_reader interfaces to take shuffling buffer parameters explicitly
        shuffling_queue = RandomShufflingBuffer(
            1000, 800) if shuffle_row_groups else NoopShufflingBuffer()

        # Create a dictionary with all ReaderV2 parameters, so we can merge with reader_engine_params if specified
        kwargs = {
            'schema_fields': schema_fields,
            'predicate': predicate,
            'rowgroup_selector': rowgroup_selector,
            'num_epochs': num_epochs,
            'cur_shard': cur_shard,
            'shard_count': shard_count,
            'cache': cache,
            'decoder_pool': decoder_pool,
            'shuffling_queue': shuffling_queue,
            'shuffle_row_groups': shuffle_row_groups,
            'shuffle_row_drop_partitions': shuffle_row_drop_partitions,
        }

        if reader_engine_params:
            kwargs.update(reader_engine_params)

        return ReaderV2(dataset_url, **kwargs)

    else:
        raise ValueError(
            'Unexpected value of reader_engine argument \'%s\'. '
            'Supported reader_engine values are \'reader_v1\' and \'experimental_reader_v2\''
            % reader_engine)
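A hedged usage sketch of make_reader as documented above; the URL and the 'id' field are placeholders, not part of any real dataset.

# Minimal usage sketch; 'id' is a hypothetical field of the dataset's Unischema.
with make_reader('file:///tmp/mydataset',
                 schema_fields=['id'],        # a list of field names and/or regex patterns
                 shuffle_row_groups=True,
                 num_epochs=1) as reader:
    for row in reader:
        # Rows are yielded as namedtuples built from the dataset's Unischema.
        print(row.id)
        break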
Example #6
def reader_v2_throughput(dataset_url, field_regex=None, warmup_cycles_count=300, measure_cycles_count=1000,
                         pool_type=WorkerPoolType.THREAD, loaders_count=3, decoders_count=3,
                         read_method=ReadMethod.PYTHON, shuffling_queue_size=500, min_after_dequeue=400,
                         reader_extra_args=None, pyarrow_serialize=False, spawn_new_process=True):
    """Constructs a ReaderV2 instance and uses it to performs throughput measurements.

    The function will spawn a new process if ``spawn_new_process`` is set. This is needed to make memory footprint
    measurements accurate.

    :param dataset_url: A url of the dataset to be used for measurements.
    :param field_regex:  A list of regular expressions. Only fields that match one of the regex patterns will be used
      during the benchmark.
    :param warmup_cycles_count: Number of warmup cycles. During warmup cycles no measurements are being recorded.
    :param measure_cycles_count: Number of measurement cycles. Only time elapsed during measurement cycles is used
      in throughput calculations.
    :param pool_type: :class:`WorkerPoolType` enum value.
    :param loaders_count: Number of IO threads.
    :param decoders_count: Number of threads or processes used for decoding. ``pool_type`` parameter defines
      whether multiple processes or threads are used for parallel decoding.
    :param read_method: A :class:`ReadMethod` enum value that defines whether samples are read via plain Python
      (``ReadMethod.PYTHON``) or via the TensorFlow path (``ReadMethod.TF``).
    :param shuffling_queue_size: Maximum number of elements in the shuffling queue.
    :param min_after_dequeue: Minimum number of elements in a shuffling queue before entries can be read from it.
    :param reader_extra_args: Extra arguments that would be passed to Reader constructor.
    :param pyarrow_serialize: When True, pyarrow.serialize library will be used for serializing decoded payloads.
    :param spawn_new_process: This function will respawn itself in a new process if the argument is True. Spawning
      a new process is needed to get an accurate memory footprint.

    :return: An instance of ``BenchmarkResult`` namedtuple with the results of the benchmark. The namedtuple has
      the following fields: `time_mean`, `samples_per_second`, `memory_info` and `cpu`
    """
    if not reader_extra_args:
        reader_extra_args = dict()

    if spawn_new_process:
        args = copy.deepcopy(locals())
        args['spawn_new_process'] = False
        executor = ProcessPoolExecutor(1)
        future = executor.submit(reader_v2_throughput, **args)
        return future.result()

    logger.info('Arguments: %s', locals())

    if 'schema_fields' not in reader_extra_args:
        unischema_fields = match_unischema_fields(get_schema_from_dataset_url(dataset_url), field_regex)
        reader_extra_args['schema_fields'] = unischema_fields

    logger.info('Fields used in the benchmark: %s', str(reader_extra_args['schema_fields']))

    decoder_pool_executor = _create_concurrent_executor(pool_type, decoders_count)

    with ReaderV2(dataset_url, num_epochs=None,
                  loader_pool=ThreadPoolExecutor(loaders_count),
                  decoder_pool=decoder_pool_executor,
                  shuffling_queue=RandomShufflingBuffer(shuffling_queue_size, min_after_dequeue),
                  **reader_extra_args) as reader:

        if read_method == ReadMethod.PYTHON:
            result = _time_warmup_and_work(reader, warmup_cycles_count, measure_cycles_count)
        elif read_method == ReadMethod.TF:
            result = _time_warmup_and_work_tf(reader, warmup_cycles_count, measure_cycles_count, 0, 0)
        else:
            raise RuntimeError('Unexpected read_method value: %s' % str(read_method))

    return result
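A hedged invocation sketch of the benchmark above; the URL, regex, and cycle counts are placeholders.

# Hypothetical benchmark run; spawn_new_process=False keeps it in-process for a quick local check.
result = reader_v2_throughput('file:///tmp/mydataset',
                              field_regex=[r'^id$'],
                              warmup_cycles_count=100,
                              measure_cycles_count=500,
                              pool_type=WorkerPoolType.THREAD,
                              read_method=ReadMethod.PYTHON,
                              shuffling_queue_size=500,
                              min_after_dequeue=400,
                              spawn_new_process=False)
print('%.1f samples/sec' % result.samples_per_second)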
Example #7
def make_reader(dataset_url,
                schema_fields=None,
                reader_pool_type='thread', workers_count=10, pyarrow_serialize=False,
                shuffle_row_groups=True, shuffle_row_drop_partitions=1,
                predicate=None,
                rowgroup_selector=None,
                num_epochs=1,
                cur_shard=None, shard_count=None,
                cache_type='null', cache_location=None, cache_size_limit=None,
                cache_row_size_estimate=None, cache_extra_settings=None,
                hdfs_driver='libhdfs3',
                infer_schema=False,
                reader_engine='reader_v1', reader_engine_params=None):
    """
    Factory convenience method for :class:`Reader`.

    :param dataset_url: a file path or a URL to a parquet directory,
        e.g. ``'hdfs://some_hdfs_cluster/user/yevgeni/parquet8'``, or ``'file:///tmp/mydataset'``
        or ``'s3://bucket/mydataset'``.
    :param schema_fields: Either a list of unischema fields to subset, ``None`` to read all fields,
            or an NGram object, in which case an NGram of the specified properties is returned.
    :param reader_pool_type: A string denoting the reader pool type. Should be one of ['thread', 'process', 'dummy']
        denoting a thread pool, process pool, or running everything in the master thread. Defaults to 'thread'
    :param workers_count: An int for the number of workers to use in the reader pool. This is only used for the
        thread or process pool. Defaults to 10
    :param pyarrow_serialize: Whether to use pyarrow for serialization. Currently only applicable to process pool.
        Defaults to False.
    :param shuffle_row_groups: Whether to shuffle row groups (the order in which full row groups are read)
    :param shuffle_row_drop_partitions: This is a positive integer which determines how many partitions to
        break up a row group into for increased shuffling in exchange for worse performance (extra reads).
        For example, if you specify 2, each row group read will drop half of the rows within every row group and
        read the remaining rows in separate reads. It is recommended to keep this number below the regular row
        group size in order to not waste reads which drop all rows.
    :param predicate: instance of :class:`.PredicateBase` object to filter rows to be returned by reader.
    :param rowgroup_selector: instance of row group selector object to select row groups to be read
    :param num_epochs: An epoch is a single pass over all rows in the dataset. Setting ``num_epochs`` to
        ``None`` will result in an infinite number of epochs.
    :param cur_shard: An int denoting the current shard number. Each node reading a shard should
        pass in a unique shard number in the range [0, shard_count). shard_count must be supplied as well.
        Defaults to None
    :param shard_count: An int denoting the number of shards to break this dataset into. Defaults to None
    :param cache_type: A string denoting the cache type, if desired. Options are [None, 'null', 'local-disk'] to
        either have a null/noop cache or a cache implemented using diskcache. Caching is useful when communication
        to the main data store is either slow or expensive and the local machine has large enough storage
        to store the entire dataset (or a partition of a dataset if shard_count is used). By default a null cache is used.
    :param cache_location: A string denoting the location or path of the cache.
    :param cache_size_limit: An int specifying the size limit of the cache in bytes
    :param cache_row_size_estimate: An int specifying the estimated size of a row in the dataset
    :param cache_extra_settings: A dictionary of extra settings to pass to the cache implementation.
    :param hdfs_driver: A string denoting the hdfs driver to use (if using a dataset on hdfs). Current choices are
        libhdfs (java through JNI) or libhdfs3 (C++)
    :param infer_schema: Whether to infer the unischema object from the parquet schema.
            Only works for schemas containing certain scalar types. This option allows getting around explicitly
            generating petastorm metadata using :func:`petastorm.etl.dataset_metadata.materialize_dataset` or
            petastorm-generate-metadata.py
    :param reader_engine: Multiple engine implementations exist ('reader_v1' and 'experimental_reader_v2'). 'reader_v1'
        (the default value) selects a stable reader implementation.
    :param reader_engine_params: For advanced usage: a dictionary with arguments passed directly to a reader
        implementation constructor chosen by the ``reader_engine`` argument. You should not use this parameter
        unless you are fine-tuning the reader.
    :return: A :class:`Reader` object
    """

    if dataset_url is None or not isinstance(dataset_url, six.string_types):
        raise ValueError("""dataset_url must be a string""")

    dataset_url = dataset_url[:-1] if dataset_url[-1] == '/' else dataset_url
    logger.debug('dataset_url: %s', dataset_url)

    resolver = FilesystemResolver(dataset_url, hdfs_driver=hdfs_driver)
    filesystem = resolver.filesystem()
    dataset_path = resolver.get_dataset_path()

    if cache_type is None or cache_type == 'null':
        cache = NullCache()
    elif cache_type == 'local-disk':
        cache = LocalDiskCache(cache_location, cache_size_limit, cache_row_size_estimate, **cache_extra_settings or {})
    else:
        raise ValueError('Unknown cache_type: {}'.format(cache_type))

    if reader_engine == 'reader_v1':
        if reader_pool_type == 'thread':
            reader_pool = ThreadPool(workers_count)
        elif reader_pool_type == 'process':
            reader_pool = ProcessPool(workers_count, pyarrow_serialize=pyarrow_serialize)
        elif reader_pool_type == 'dummy':
            reader_pool = DummyPool()
        else:
            raise ValueError('Unknown reader_pool_type: {}'.format(reader_pool_type))

        # Create a dictionary with all Reader (v1) constructor parameters, so we can merge with reader_engine_params if specified
        kwargs = {
            'schema_fields': schema_fields,
            'reader_pool': reader_pool,
            'shuffle_row_groups': shuffle_row_groups,
            'shuffle_row_drop_partitions': shuffle_row_drop_partitions,
            'predicate': predicate,
            'rowgroup_selector': rowgroup_selector,
            'num_epochs': num_epochs,
            'cur_shard': cur_shard,
            'shard_count': shard_count,
            'cache': cache,
            'infer_schema': infer_schema,
        }

        if reader_engine_params:
            kwargs.update(reader_engine_params)

        return Reader(filesystem, dataset_path, **kwargs)
    elif reader_engine == 'experimental_reader_v2':
        if reader_pool_type == 'thread':
            decoder_pool = ThreadPoolExecutor(workers_count)
        elif reader_pool_type == 'process':
            decoder_pool = ProcessPoolExecutor(workers_count)
        elif reader_pool_type == 'dummy':
            decoder_pool = SameThreadExecutor()
        else:
            raise ValueError('Unknown reader_pool_type: {}'.format(reader_pool_type))

        # TODO(yevgeni): once ReaderV2 is ready to be out of experimental status, we should extend
        # the make_reader interfaces to take shuffling buffer parameters explicitly
        shuffling_queue = RandomShufflingBuffer(1000, 800) if shuffle_row_groups else NoopShufflingBuffer()

        # Create a dictionary with all ReaderV2 parameters, so we can merge with reader_engine_params if specified
        kwargs = {
            'schema_fields': schema_fields,
            'predicate': predicate,
            'rowgroup_selector': rowgroup_selector,
            'num_epochs': num_epochs,
            'cur_shard': cur_shard,
            'shard_count': shard_count,
            'cache': cache,
            'decoder_pool': decoder_pool,
            'shuffling_queue': shuffling_queue,
            'shuffle_row_groups': shuffle_row_groups,
            'shuffle_row_drop_partitions': shuffle_row_drop_partitions,
            'infer_schema': infer_schema,
        }

        if reader_engine_params:
            kwargs.update(reader_engine_params)

        return ReaderV2(dataset_url, **kwargs)

    else:
        raise ValueError('Unexpected value of reader_engine argument \'%s\'. '
                         'Supported reader_engine values are \'reader_v1\' and \'experimental_reader_v2\''
                         % reader_engine)
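The main difference from the make_reader variant in Example #5 is the infer_schema flag. Below is a hedged sketch of how a scalar-only, non-Petastorm Parquet store could be opened with it; the path is a placeholder.

# Hypothetical call; infer_schema only works for stores containing certain scalar types.
reader = make_reader('file:///tmp/scalar_only_parquet',
                     infer_schema=True,
                     num_epochs=1)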
Example #8
def test_random_shuffling_buffer_can_add_retrieve_flags():
    """Check can_add/can_retrieve flags at all possible states"""
    q = RandomShufflingBuffer(5, 3)

    # Empty buffer. Can start adding, nothing to retrieve yet
    assert q.size == 0
    assert q.can_add()
    assert not q.can_retrieve()

    # Under min_after_retrieve elements, so can not retrieve just yet
    q.add_many([1, 2])
    assert q.can_add()
    assert not q.can_retrieve()
    assert q.size == 2

    # Got to min_after_retrieve elements, can start retrieving
    q.add_many([3])
    assert q.can_retrieve()
    assert q.size == 3

    # But when we retrieve we are again under min_after_retrieve, so can not retrieve again
    q.retrieve()
    assert not q.can_retrieve()
    assert q.size == 2

    # Getting back to the retrievable state with enough items in the buffer
    q.add_many([4, 5])
    assert q.can_add()
    assert q.can_retrieve()
    assert q.size == 4

    # We can overrun the capacity (as long as we stay below extra_capacity), but can not add more once we are
    # at or above shuffling_buffer_capacity
    q.add_many([6, 7, 8, 9])
    assert not q.can_add()
    assert q.can_retrieve()
    assert q.size == 8

    # Getting one out. Still have more than shuffling_buffer_capacity
    q.retrieve()
    assert not q.can_add()
    assert q.can_retrieve()
    assert q.size == 7

    # Retrieve enough to get back to addable state
    [q.retrieve() for _ in range(4)]
    assert q.can_add()
    assert q.can_retrieve()
    assert q.size == 3

    # Retrieve one more element so we drop below min_after_retrieve and can not retrieve any more
    q.retrieve()
    assert q.can_add()
    assert not q.can_retrieve()
    with pytest.raises(RuntimeError):
        q.retrieve()

    assert q.size == 2

    # finish() will allow us to deplete the buffer completely
    q.finish()
    assert not q.can_add()
    assert q.can_retrieve()
    assert q.size == 2

    q.retrieve()
    assert not q.can_add()
    assert q.can_retrieve()
    assert q.size == 1

    q.retrieve()
    assert not q.can_add()
    assert not q.can_retrieve()
    assert q.size == 0