def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
             shuffle_row_groups=True, shuffle_row_drop_partitions=1,
             predicate=None, rowgroup_selector=None, reader_pool=None, num_epochs=1,
             cur_shard=None, shard_count=None, cache=None, worker_class=None,
             transform_spec=None):
    """Initializes a reader object.

    :param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified,
        then a default one will be selected based on the url (only for ``hdfs://`` or ``file://``; for
        ``s3://`` support, use ``make_reader``). The default hdfs driver is ``libhdfs3``. If you want
        to use ``libhdfs``, use
        ``pyarrow_filesystem=pyarrow.hdfs.connect('hdfs:///some/path', driver='libhdfs')``.
    :param dataset_path: filepath to a parquet directory on the specified filesystem.
        e.g. ``'/user/yevgeni/parquet8'``, or ``'/tmp/mydataset'``.
    :param schema_fields: Either list of unischema fields to subset, or ``None`` to read all fields.
        OR an NGram object, then it will return an NGram of the specified properties.
    :param shuffle_row_groups: Whether to shuffle row groups (the order in which full row groups are read)
    :param shuffle_row_drop_partitions: This is a positive integer which determines how many partitions to
        break up a row group into for increased shuffling in exchange for worse performance (extra reads).
        For example if you specify 2 each row group read will drop half of the rows within every row group and
        read the remaining rows in separate reads. It is recommended to keep this number below the regular row
        group size in order to not waste reads which drop all rows.
    :param predicate: instance of predicate object to filter rows to be returned by reader.
    :param rowgroup_selector: instance of row group selector object to select row groups to be read
    :param reader_pool: parallelization pool. ``ThreadPool(10)`` (10 threads) is used by default.
        This pool is a custom implementation used to parallelize reading data from the dataset.
        Any object from workers_pool package can be used
        (e.g. :class:`petastorm.workers_pool.process_pool.ProcessPool`).
    :param num_epochs: An epoch is a single pass over all rows in the dataset. Setting ``num_epochs`` to
        ``None`` will result in an infinite number of epochs.
    :param cur_shard: An int denoting the current shard number used. Each reader instance should
        pass in a unique shard number in the range ``[0, shard_count)``.
        ``shard_count`` must be supplied as well. Defaults to None
    :param shard_count: An int denoting the number of shard partitions there are. Defaults to None
    :param cache: An object conforming to :class:`.CacheBase` interface. Before loading row groups from a
        parquet file the Reader will attempt to load these values from cache. Caching is useful when communication
        to the main data store is either slow or expensive and the local machine has large enough storage
        to store entire dataset (or a partition of a dataset if shards are used).
        By default, use the :class:`.NullCache` implementation.
    :param worker_class: This is the class that will be instantiated on a different thread/process. Its
        responsibility is to load and filter the data.
    """
    # Overall flow:
    # 1. Open the parquet storage (dataset)
    # 2. Get a list of all groups
    # 3. Filter rowgroups
    #    a. predicates
    #    b. row-group selector (our indexing mechanism)
    #    c. partition: used to get a subset of data for distributed training
    # 4. Create a rowgroup ventilator object
    # 5. Start workers pool

    # `collections.Iterable` was removed in Python 3.10; use the abc module explicitly.
    from collections.abc import Iterable

    if not (isinstance(schema_fields, (Iterable, NGram)) or schema_fields is None):
        raise ValueError('Fields must be either None, an iterable collection of Unischema fields '
                         'or an NGram object.')

    self.ngram = schema_fields if isinstance(schema_fields, NGram) else None

    # By default, use original method of working with list of dictionaries and not arrow tables
    worker_class = worker_class or PyDictReaderWorker
    self._results_queue_reader = worker_class.new_results_queue_reader()

    if self.ngram and not self.ngram.timestamp_overlap and shuffle_row_drop_partitions > 1:
        raise NotImplementedError('Using timestamp_overlap=False is not implemented with'
                                  ' shuffle_options.shuffle_row_drop_partitions > 1')

    cache = cache or NullCache()

    self._workers_pool = reader_pool or ThreadPool(10)

    # 1. Resolve dataset path (hdfs://, file://) and open the parquet storage (dataset)
    self.dataset = pq.ParquetDataset(dataset_path, filesystem=pyarrow_filesystem,
                                     validate_schema=False)

    stored_schema = infer_or_load_unischema(self.dataset)

    # Make a schema view (a view is a Unischema containing only a subset of fields).
    # Will raise an exception if invalid schema fields are in schema_fields
    if self.ngram:
        fields = self.ngram.get_field_names_at_all_timesteps()
    else:
        fields = schema_fields if isinstance(schema_fields, Iterable) else None
    storage_schema = stored_schema.create_schema_view(fields) if fields else stored_schema
    # The schema exposed to the user reflects any user-supplied transform.
    if transform_spec:
        self.schema = transform_schema(storage_schema, transform_spec)
    else:
        self.schema = storage_schema

    # 2. Get a list of all row groups
    row_groups = dataset_metadata.load_row_groups(self.dataset)

    # 3. Filter rowgroups (predicates, row-group selector, sharding)
    filtered_row_group_indexes, worker_predicate = self._filter_row_groups(self.dataset, row_groups, predicate,
                                                                           rowgroup_selector, cur_shard,
                                                                           shard_count)

    # 4. Create a rowgroup ventilator object
    normalized_shuffle_row_drop_partitions = \
        self._normalize_shuffle_options(shuffle_row_drop_partitions, self.dataset)
    self.ventilator = self._create_ventilator(filtered_row_group_indexes, shuffle_row_groups,
                                              normalized_shuffle_row_drop_partitions, num_epochs, worker_predicate,
                                              self._workers_pool.workers_count + _VENTILATE_EXTRA_ROWGROUPS)

    # 5. Start workers pool
    self._workers_pool.start(worker_class, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram, row_groups,
                                            cache, transform_spec),
                             ventilator=self.ventilator)
    logger.debug('Workers pool started')

    self.last_row_consumed = False
def test_noop_transform():
    """An identity transform (no edits, no removals) must leave the schema's fields untouched."""
    identity_spec = TransformSpec(lambda x: x, edit_fields=None, removed_fields=None)
    result_schema = transform_schema(TestSchema, identity_spec)
    assert result_schema.fields == TestSchema.fields