def _transpose_dict_list(dict_list):
  """Transpose a nested dict[list] into a list[nested dict]."""
  # 1. Unstack numpy arrays into list
  dict_list = utils.map_nested(_np_to_list, dict_list, dict_only=True)

  # 2. Extract the sequence length (and ensure the length is constant for all
  # elements)
  length = {'value': None}  # dict because `nonlocal` is Python3 only

  def update_length(elem):
    if length['value'] is None:
      length['value'] = len(elem)
    elif length['value'] != len(elem):
      raise ValueError(
          'The length of all elements of one sequence should be the same. '
          'Got {} != {}'.format(length['value'], len(elem)))
    return elem

  utils.map_nested(update_length, dict_list, dict_only=True)

  # 3. Extract each individual element
  return [
      utils.map_nested(lambda elem: elem[i], dict_list, dict_only=True)  # pylint: disable=cell-var-from-loop
      for i in range(length['value'])
  ]

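# The helper above relies on TFDS-internal utilities (`utils.map_nested`,
# `_np_to_list`). Purely as an illustration (not TFDS code), here is a minimal
# standalone sketch of the same dict[list] -> list[dict] transpose, assuming a
# flat dict whose values are plain, equal-length Python lists.
def _transpose_flat_dict_list_sketch(dict_list):
  lengths = {len(v) for v in dict_list.values()}
  if len(lengths) != 1:
    raise ValueError('All sequence fields must have the same length.')
  (length,) = lengths
  return [{k: v[i] for k, v in dict_list.items()} for i in range(length)]

assert _transpose_flat_dict_list_sketch({'a': [1, 2], 'b': [3, 4]}) == [
    {'a': 1, 'b': 3},
    {'a': 2, 'b': 4},
]
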
def assertAllEqualNested(self, d1, d2, *, atol: Optional[float] = None):
  """Same as assertAllEqual but compatible with nested dicts.

  Args:
    d1: First element to compare
    d2: Second element to compare
    atol: If given, perform a close float comparison. Otherwise, perform an
      exact comparison.
  """
  if isinstance(d1, dict):
    # assertAllEqual does not work well with dictionaries, so assert on each
    # individual element instead.
    zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
    utils.map_nested(
        # Recursively call assertAllEqualNested in case there is a dataset.
        lambda x: self.assertAllEqualNested(x[0], x[1], atol=atol),
        zipped_examples,
        dict_only=True,
    )
  elif isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
    # Check the length and elements of the dataset. At the moment, more than
    # one level of nested datasets is not supported.
    self.assertEqual(len(d1), len(d2))
    for ex1, ex2 in zip(d1, d2):
      self.assertAllEqualNested(ex1, ex2, atol=atol)
  elif atol:
    self.assertAllClose(d1, d2, atol=atol)
  else:
    self.assertAllEqual(d1, d2)

def assertFeatureTest(self, fdict, test, feature, shape, dtype):
  """Test that encode => decode of a value works correctly."""
  # self._process_subtest_exp(e)
  input_value = {"inner": test.value}

  if test.raise_cls is not None:
    with self._subTest("raise"):
      if not test.raise_msg:
        raise ValueError(
            "test.raise_msg should be set with {} for test {}".format(
                test.raise_cls, type(feature)))
      with self.assertRaisesWithPredicateMatch(test.raise_cls,
                                               test.raise_msg):
        features_encode_decode(fdict, input_value)
  else:
    # Test the serialization only
    if test.expected_serialized is not None:
      with self._subTest("out_serialize"):
        self.assertEqual(
            test.expected_serialized,
            feature.encode_example(test.value),
        )

    # Assert the returned type matches the expected one
    with self._subTest("out"):
      out = features_encode_decode(fdict, input_value, as_tensor=True)
      out = out["inner"]
      with self._subTest("dtype"):
        out_dtypes = utils.map_nested(lambda s: s.dtype, out)
        self.assertEqual(out_dtypes, feature.dtype)
      with self._subTest("shape"):
        # For shape, because (None, 3) matches with (5, 3), we use
        # tf.TensorShape.assert_is_compatible_with on each of the elements
        out_shapes = utils.zip_nested(out, feature.shape)
        utils.map_nested(
            lambda x: x[0].shape.assert_is_compatible_with(x[1]),
            out_shapes)

    # Test serialization + decoding from disk
    with self._subTest("out_value"):
      decoded_examples = features_encode_decode(fdict, input_value)
      decoded_examples = decoded_examples["inner"]
      if isinstance(decoded_examples, dict):
        # assertAllEqual does not work well with dictionaries, so assert on
        # each individual element instead.
        zipped_examples = utils.zip_nested(
            test.expected,
            decoded_examples,
            dict_only=True,
        )
        utils.map_nested(
            lambda x: self.assertAllEqual(x[0], x[1]),
            zipped_examples,
            dict_only=True,
        )
      else:
        self.assertAllEqual(test.expected, decoded_examples)

def assertFeatureTest(self, fdict, test, feature, shape, dtype):
  """Test that encode => decode of a value works correctly."""
  # Test that feature.encode_example can be pickled and unpickled for Beam.
  dill.loads(dill.dumps(feature.encode_example))

  input_value = {'inner': test.value}

  if test.raise_cls is not None:
    with self._subTest('raise'):
      if not test.raise_msg:
        raise ValueError(
            'test.raise_msg should be set with {} for test {}'.format(
                test.raise_cls, type(feature)))
      with self.assertRaisesWithPredicateMatch(test.raise_cls,
                                               test.raise_msg):
        features_encode_decode(fdict, input_value, decoders=test.decoders)
  else:
    # Test the serialization only
    if test.expected_serialized is not None:
      with self._subTest('out_serialize'):
        self.assertEqual(
            test.expected_serialized,
            feature.encode_example(test.value),
        )

    # Test serialization + decoding from disk
    with self._subTest('out'):
      out_tensor, out_numpy = features_encode_decode(
          fdict,
          input_value,
          decoders={'inner': test.decoders},
      )
      out_tensor = out_tensor['inner']
      out_numpy = out_numpy['inner']

      # Assert the returned type matches the expected one
      with self._subTest('dtype'):
        out_dtypes = tf.nest.map_structure(lambda s: s.dtype, out_tensor)
        self.assertEqual(out_dtypes, test.dtype or feature.dtype)
      with self._subTest('shape'):
        # For shape, because (None, 3) matches with (5, 3), we use
        # tf.TensorShape.assert_is_compatible_with on each of the elements
        expected_shape = feature.shape if test.shape is None else test.shape
        out_shapes = utils.zip_nested(out_tensor, expected_shape)
        utils.map_nested(
            lambda x: x[0].shape.assert_is_compatible_with(x[1]),
            out_shapes)

      # Assert the value
      with self._subTest('out_value'):
        # Eventually construct the tf.RaggedTensor
        expected = tf.nest.map_structure(
            lambda t: t.build() if isinstance(t, RaggedConstant) else t,
            test.expected)
        self.assertAllEqualNested(out_numpy, expected)

def assertAllEqualNested(self, d1, d2):
  """Same as assertAllEqual but compatible with nested dicts."""
  if isinstance(d1, dict):
    # assertAllEqual does not work well with dictionaries, so assert on each
    # individual element instead.
    zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
    utils.map_nested(
        lambda x: self.assertAllEqual(x[0], x[1]),
        zipped_examples,
        dict_only=True,
    )
  else:
    self.assertAllEqual(d1, d2)

def _map_promise(map_fn, all_inputs, async_):
  """Map the function into each element and resolve the promise."""
  all_promises = utils.map_nested(map_fn, all_inputs)  # Apply the function
  if async_:
    # TODO(tfds): Fix for nested case
    if isinstance(all_promises, dict):
      merged_promise = promise.Promise.for_dict(all_promises)
    elif isinstance(all_promises, list):
      merged_promise = promise.Promise.all(all_promises)
    else:
      merged_promise = all_promises
    return merged_promise

  return utils.map_nested(lambda p: p.get(), all_promises)

def get_serialized_info(self):
  """Return the tf-example features for the adapter, as stored on disk.

  This function indicates how this feature is encoded on file internally.
  The DatasetBuilders are written on disk as tf.train.Example proto.

  Ex:

  ```
  return {
      'image': tf.VarLenFeature(tf.uint8),
      'height': tf.FixedLenFeature((), tf.int32),
      'width': tf.FixedLenFeature((), tf.int32),
  }
  ```

  FeatureConnectors which are not containers should return the feature proto
  directly:

  ```
  return tf.FixedLenFeature((64, 64), tf.uint8)
  ```

  If not defined, the returned values are automatically deduced from the
  `get_tensor_info` function.

  Returns:
    features: Either a dict of feature proto objects, or a feature proto
      object
  """
  return utils.map_nested(to_serialized_field, self.get_tensor_info())

def read(self, name, instructions, split_infos, shuffle_files=False):
  """Returns tf.data.Dataset instance(s).

  Args:
    name (str): name of the dataset.
    instructions (ReadInstruction, List[], Dict[]): instruction(s) to read.
      Instructions can be strings, in which case they are passed to the
      Instruction constructor as is.
    split_infos (list of SplitInfo proto): the available splits for dataset.
    shuffle_files (bool): defaults to False. If True, input files are
      shuffled before being read.

  Returns:
    a single tf.data.Dataset instance if instruction is a single
    ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
    corresponding to given instructions param shape.
  """
  name2shard_lengths = {
      info.name: info.shard_lengths for info in split_infos
  }
  name2len = {
      name: sum(lengths) for name, lengths in name2shard_lengths.items()
  }
  read_instruction = functools.partial(
      _read_single_instruction,
      parse_fn=self._parser.parse_example,
      name=name,
      path=self._path,
      name2len=name2len,
      name2shard_lengths=name2shard_lengths,
      shuffle_files=shuffle_files)
  datasets = utils.map_nested(read_instruction, instructions, map_tuple=True)
  return datasets

def get_serialized_info(self):
  """See base class for details."""
  # Add the additional length dimension to every serialized feature

  def add_length_dim(serialized_info):
    """Add the length dimension to the serialized_info.

    Args:
      serialized_info: One of tf.io.FixedLenFeature, tf.io.VarLenFeature,...

    Returns:
      new_serialized_info: serialized_info with extended first dimension
    """
    if isinstance(serialized_info, tf.io.FixedLenFeature):
      if self._length is not None:
        return tf.io.FixedLenFeature(
            shape=(self._length,) + serialized_info.shape,
            dtype=serialized_info.dtype,
        )
      else:
        return tf.io.FixedLenSequenceFeature(
            shape=serialized_info.shape,
            dtype=serialized_info.dtype,
            allow_missing=True,
        )
    elif isinstance(serialized_info, tf.io.VarLenFeature):
      return serialized_info
    else:
      raise ValueError(
          'FixedLenSequenceFeature not supported inside Sequence')

  tensor_info = self._feature.get_serialized_info()
  return utils.map_nested(add_length_dim, tensor_info)

def _get_dl_download_result(self, url):
  if self.DL_DOWNLOAD_RESULT is None:
    # This is only to be backwards compatible with old approach.
    # In the future it will be replaced with using self.example_dir.
    return self._get_dl_extract_result(url)
  return utils.map_nested(
      lambda fname: os.path.join(self.example_dir, fname),
      self.DL_DOWNLOAD_RESULT)

def _get_dl_extract_result(self, url):
  tf.nest.map_structure(self._add_url, url)
  del url
  if self.DL_EXTRACT_RESULT is None:
    return self.example_dir
  return utils.map_nested(
      lambda fname: os.path.join(self.example_dir, fname),
      self.DL_EXTRACT_RESULT)

def _parallel_run(function, input_struct, max_workers=1):
  """Run the function on each element of input_struct using a pool of workers."""
  # Distribute the work in a pool
  launch_thread_pool = concurrent.futures.ThreadPoolExecutor
  with launch_thread_pool(max_workers=max_workers) as executor:

    def launch_worker(value):
      return executor.submit(function, value)

    output_struct = utils.map_nested(launch_worker, input_struct)

  # Gather all results once all workers have finished
  def gather_results(value):
    return value.result()

  return utils.map_nested(gather_results, output_struct)

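# Usage illustration (hypothetical, not part of the library): `_parallel_run`
# submits one future per leaf of the nested input and gathers results back in
# the same shape. The standalone sketch below shows the same submit-then-gather
# pattern for a plain list, using only `concurrent.futures`.
import concurrent.futures

def _parallel_map_sketch(function, values, max_workers=4):
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(function, v) for v in values]
  # Exiting the `with` block waits for all workers, so results are ready.
  return [f.result() for f in futures]

assert _parallel_map_sketch(lambda x: x * x, [1, 2, 3]) == [1, 4, 9]
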
def as_dataset(self,
               split=None,
               batch_size=None,
               shuffle_files=None,
               as_supervised=False):
  """Constructs a `tf.data.Dataset`.

  Callers must pass arguments as keyword arguments.

  Args:
    split: `tfds.core.SplitBase`, which subset(s) of the data to read. If
      None (default), returns all splits in a dict
      `<key: tfds.Split, value: tf.data.Dataset>`.
    batch_size: `int`, batch size. Note that variable-length features will
      be 0-padded if `batch_size` is set. Users that want more custom
      behavior should use `batch_size=None` and use the `tf.data` API to
      construct a custom pipeline. If `batch_size == -1`, will return feature
      dictionaries of the whole dataset with `tf.Tensor`s instead of a
      `tf.data.Dataset`.
    shuffle_files: `bool`, whether to shuffle the input files. Defaults to
      `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
    as_supervised: `bool`, if `True`, the returned `tf.data.Dataset` will
      have a 2-tuple structure `(input, label)` according to
      `builder.info.supervised_keys`. If `False`, the default, the returned
      `tf.data.Dataset` will have a dictionary with all the features.

  Returns:
    `tf.data.Dataset`, or if `split=None`, a
    `dict<key: tfds.Split, value: tf.data.Dataset>`.

    If `batch_size` is -1, will return feature dictionaries containing the
    entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
  """
  logging.info("Constructing tf.data.Dataset for split %s, from %s",
               split, self._data_dir)
  if not tf.io.gfile.exists(self._data_dir):
    raise AssertionError(
        ("Dataset %s: could not find data in %s. Please make sure to call "
         "dataset_builder.download_and_prepare(), or pass download=True to "
         "tfds.load() before trying to access the tf.data.Dataset object."
        ) % (self.name, self._data_dir_root))

  # By default, return all splits
  if split is None:
    split = {s: s for s in self.info.splits}

  # Create a dataset for each of the given splits
  build_single_dataset = functools.partial(
      self._build_single_dataset,
      shuffle_files=shuffle_files,
      batch_size=batch_size,
      as_supervised=as_supervised,
  )
  datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
  return datasets

def assertFeature(self, specs, serialized_info, tests):
  """Test the TFRecordExampleAdapter encoding."""
  adapter = file_format_adapter.TFRecordExampleAdapter(specs)
  with self._subTest("serialized_info"):
    self.assertEqual(serialized_info,
                     adapter._parser._build_feature_specs())
  for i, test in enumerate(tests):
    with self._subTest(str(i)):
      if test.raise_cls is not None:
        with self.assertRaisesWithPredicateMatch(test.raise_cls,
                                                 test.raise_msg):
          adapter._serializer.serialize_example(test.value)
        continue
      serialized = adapter._serializer.serialize_example(test.value)
      if test.expected_serialized is not None:
        example_proto = tf.train.Example()
        example_proto.ParseFromString(serialized)
        expected_proto = tf.train.Example(
            features=tf.train.Features(feature=test.expected_serialized))
        self.assertEqual(expected_proto, example_proto)
      example = _parse_example(serialized, adapter._parser.parse_example)
      with self._subTest("dtype"):
        out_dtypes = utils.map_nested(lambda s: s.dtype, example)
        expected_dtypes = utils.map_nested(lambda s: s.dtype, specs)
        self.assertEqual(out_dtypes, expected_dtypes)
      with self._subTest("shape"):
        # For shape, because (None, 3) matches with (5, 3), we use
        # tf.TensorShape.assert_is_compatible_with on each of the elements
        utils.map_nested(
            lambda x: x[0].shape.assert_is_compatible_with(x[1].shape),
            utils.zip_nested(example, specs))
      np_example = dataset_utils.as_numpy(example)
      self.assertAllEqualNested(np_example, test.expected)

def get_tensor_info(self):
  """See base class for details."""
  # Add the additional length dimension to every shape

  def add_length_dim(tensor_info):
    tensor_info = feature_lib.TensorInfo.copy_from(tensor_info)
    tensor_info.shape = (self._length,) + tensor_info.shape
    return tensor_info

  tensor_info = self._feature.get_tensor_info()
  return utils.map_nested(add_length_dim, tensor_info)

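# Illustration with hypothetical values (not library code): prepending the
# sequence length to each per-element shape is a plain tuple concatenation,
# e.g. a Sequence of length 4 over an image of shape (64, 64, 3) becomes
# (4, 64, 64, 3), and a scalar label becomes (4,).
length = 4
element_shapes = {'image': (64, 64, 3), 'label': ()}
sequence_shapes = {k: (length,) + s for k, s in element_shapes.items()}
assert sequence_shapes == {'image': (4, 64, 64, 3), 'label': (4,)}
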
def assertAllEqualNested(self, d1, d2):
  """Same as assertAllEqual but compatible with nested dicts."""
  if isinstance(d1, dict):
    # assertAllEqual does not work well with dictionaries, so assert on each
    # individual element instead.
    zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
    utils.map_nested(
        # Recursively call assertAllEqualNested in case there is a dataset.
        lambda x: self.assertAllEqualNested(x[0], x[1]),
        zipped_examples,
        dict_only=True,
    )
  elif isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
    # Check the length and elements of the dataset. At the moment, more than
    # one level of nested datasets is not supported.
    self.assertEqual(len(d1), len(d2))
    for ex1, ex2 in zip(d1, d2):
      self.assertAllEqualNested(ex1, ex2)
  else:
    self.assertAllEqual(d1, d2)

def read(
    self,
    name,
    instructions,
    split_infos,
    read_config,
    shuffle_files,
):
  """Returns tf.data.Dataset instance(s).

  Args:
    name (str): name of the dataset.
    instructions (ReadInstruction, List[], Dict[]): instruction(s) to read.
      Instructions can be strings, in which case they are passed to the
      Instruction constructor as is.
    split_infos (list of SplitInfo proto): the available splits for dataset.
    read_config: `tfds.ReadConfig`, the input pipeline options.
    shuffle_files (bool): If True, input files are shuffled before being
      read.

  Returns:
    a single tf.data.Dataset instance if instruction is a single
    ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
    corresponding to given instructions param shape.
  """

  def _read_instruction_to_file_instructions(instruction):
    file_instructions = make_file_instructions(name, split_infos,
                                               instruction)
    files = file_instructions.file_instructions
    if not files:
      msg = 'Instruction "%s" corresponds to no data!' % instruction
      raise AssertionError(msg)
    return tuple(files)

  files = utils.map_nested(
      _read_instruction_to_file_instructions, instructions, map_tuple=False)
  return utils.map_nested(
      functools.partial(
          self.read_files,
          read_config=read_config,
          shuffle_files=shuffle_files),
      files,
      map_tuple=False)

def get_tensor_info(self):
  """See base class for details."""
  # Add the additional length dimension to every shape

  def add_length_dim(tensor_info):
    return feature_lib.TensorInfo(
        shape=(self._length,) + tensor_info.shape,
        dtype=tensor_info.dtype,
    )

  tensor_info = super(SequenceDict, self).get_tensor_info()
  return utils.map_nested(add_length_dim, tensor_info)

def get_serialized_info(self):
  """See base class for details."""
  # Add the additional length dimension to every serialized feature

  def add_length_dim(tensor_info):
    """Add the length dimension to the serialized_info."""
    return feature_lib.TensorInfo(
        shape=(self._length,) + tensor_info.shape,
        dtype=tensor_info.dtype,
    )

  tensor_info = self._feature.get_serialized_info()
  return utils.map_nested(add_length_dim, tensor_info)

def encode_example(self, example_dict):
  # Convert nested dict[list] into list[nested dict]
  sequence_elements = _transpose_dict_list(example_dict)

  # If length is static, ensure that the given length matches
  if self._length is not None and len(sequence_elements) != self._length:
    raise ValueError(
        'Input sequence length does not match the defined one. Got {} != '
        '{}'.format(len(sequence_elements), self._length))

  # Empty sequences return empty arrays
  if not sequence_elements:

    def _build_empty_np(serialized_info):
      return np.empty(
          shape=tuple(s if s else 0 for s in serialized_info.shape),
          dtype=serialized_info.dtype.as_numpy_dtype,
      )

    return utils.map_nested(_build_empty_np, self.get_serialized_info())

  # Encode each individual element
  sequence_elements = [
      self.feature.encode_example(sequence_elem)
      for sequence_elem in sequence_elements
  ]

  # Then convert back list[nested dict] => nested dict[list]
  def _stack_nested(sequence_elements):
    """Recursively stack the tensors from the same dict field."""
    if isinstance(sequence_elements[0], dict):
      return {
          # Stack along the first dimension
          k: _stack_nested(sub_sequence)
          for k, sub_sequence in utils.zip_dict(*sequence_elements)
      }
    # Note: As each field can be a nested ragged list, we don't check here
    # that all elements from the list have matching dtype/shape.
    # Checking is done in `example_serializer` when elements are converted
    # to numpy arrays and stacked together.
    return list(sequence_elements)

  return _stack_nested(sequence_elements)

def encode_example(self, example_dict):
  # Convert nested dict[list] into list[nested dict]
  sequence_elements = _transpose_dict_list(example_dict)

  # If length is static, ensure that the given length matches
  if self._length is not None and len(sequence_elements) != self._length:
    raise ValueError(
        'Input sequence length does not match the defined one. Got {} != '
        '{}'.format(len(sequence_elements), self._length))

  # Empty sequences return empty arrays
  if not sequence_elements:

    def _build_empty_np(serialized_info):
      return np.empty(
          shape=tuple(s if s else 0 for s in serialized_info.shape),
          dtype=serialized_info.dtype.as_numpy_dtype,
      )

    return utils.map_nested(_build_empty_np, self.get_serialized_info())

  # Encode each individual element
  sequence_elements = [
      self.feature.encode_example(sequence_elem)
      for sequence_elem in sequence_elements
  ]

  # Then merge the elements back together
  def _stack_nested(sequence_elements):
    if isinstance(sequence_elements[0], dict):
      return {
          # Stack along the first dimension
          k: _stack_nested(sub_sequence)
          for k, sub_sequence in utils.zip_dict(*sequence_elements)
      }
    return stack_arrays(*sequence_elements)

  return _stack_nested(sequence_elements)

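# Illustration with hypothetical numpy data (not library code): the merge step
# reverses the transpose by stacking each field's per-step values along a new
# first dimension, roughly the role `stack_arrays` plays above for array
# leaves.
import numpy as np

encoded_steps = [
    {'frame': np.zeros((2, 2)), 'label': np.int64(0)},
    {'frame': np.ones((2, 2)), 'label': np.int64(1)},
]
merged = {
    key: np.stack([step[key] for step in encoded_steps])
    for key in encoded_steps[0]
}
assert merged['frame'].shape == (2, 2, 2)
assert merged['label'].shape == (2,)
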
def _map_promise(map_fn, all_inputs):
  """Map the function into each element and resolve the promise."""
  all_promises = utils.map_nested(map_fn, all_inputs)  # Apply the function
  res = utils.map_nested(_wait_on_promise, all_promises)
  return res

def _process_exp(self, exp):
  # Check the shape/dtype
  with self._subTest("shape"):
    self.assertEqual(exp.feature.shape, exp.shape)
  with self._subTest("dtype"):
    self.assertEqual(exp.feature.dtype, exp.dtype)

  # Check the serialized features
  if exp.serialized_features is not None:
    with self._subTest("serialized_features"):
      self.assertEqual(
          exp.serialized_features,
          exp.feature.get_serialized_features(),
      )

  # Create the feature dict
  fdict = features.FeaturesDict({exp.name: exp.feature})
  for i, test in enumerate(exp.tests):
    with self._subTest(str(i)):
      # self._process_subtest_exp(e)
      input_value = {exp.name: test.value}

      if test.raise_cls is not None:
        with self._subTest("raise"):
          if not test.raise_msg:
            raise ValueError(
                "test.raise_msg should be set with {} for test {}".format(
                    test.raise_cls, exp.name))
          with self.assertRaisesWithPredicateMatch(test.raise_cls,
                                                   test.raise_msg):
            features_encode_decode(fdict, input_value)
      else:
        # Test the serialization only
        if test.expected_serialized is not None:
          with self._subTest("out_serialize"):
            self.assertEqual(
                test.expected_serialized,
                exp.feature.encode_sample(test.value),
            )

        # Assert the returned type matches the expected one
        with self._subTest("out_extract"):
          out = features_encode_decode(fdict, input_value, as_tensor=True)
          out = out[exp.name]
        with self._subTest("out_dtype"):
          out_dtypes = utils.map_nested(lambda s: s.dtype, out)
          self.assertEqual(out_dtypes, exp.feature.dtype)
        with self._subTest("out_shape"):
          # For shape, because (None, 3) matches with (5, 3), we use
          # tf.TensorShape.assert_is_compatible_with on each of the elements
          out_shapes = utils.zip_nested(out, exp.feature.shape)
          utils.map_nested(
              lambda x: x[0].shape.assert_is_compatible_with(x[1]),
              out_shapes)

        # Test serialization + decoding from disk
        with self._subTest("out_value"):
          decoded_samples = features_encode_decode(fdict, input_value)
          self.assertAllEqual(test.expected, decoded_samples[exp.name])

def as_dataset(self,
               split=None,
               batch_size=None,
               shuffle_files=False,
               decoders=None,
               as_supervised=False,
               in_memory=None):
  # pylint: disable=line-too-long
  """Constructs a `tf.data.Dataset`.

  Callers must pass arguments as keyword arguments.

  The output types vary depending on the parameters. Examples:

  ```python
  builder = tfds.builder('imdb_reviews')
  builder.download_and_prepare()

  # Default parameters: Returns the dict of tf.data.Dataset
  ds_all_dict = builder.as_dataset()
  assert isinstance(ds_all_dict, dict)
  print(ds_all_dict.keys())  # ==> ['test', 'train', 'unsupervised']

  assert isinstance(ds_all_dict['test'], tf.data.Dataset)
  # Each dataset (test, train, unsup.) consists of dictionaries
  # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
  #  'text': <tf.Tensor: .. dtype=string, numpy=b"I've watched the movie ..">}
  # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
  #  'text': <tf.Tensor: .. dtype=string, numpy=b'If you love Japanese ..'>}

  # With as_supervised: tf.data.Dataset only contains (feature, label) tuples
  ds_all_supervised = builder.as_dataset(as_supervised=True)
  assert isinstance(ds_all_supervised, dict)
  print(ds_all_supervised.keys())  # ==> ['test', 'train', 'unsupervised']

  assert isinstance(ds_all_supervised['test'], tf.data.Dataset)
  # Each dataset (test, train, unsup.) consists of tuples (text, label)
  # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)
  # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)

  # Same as above plus requesting a particular split
  ds_test_supervised = builder.as_dataset(as_supervised=True, split='test')
  assert isinstance(ds_test_supervised, tf.data.Dataset)
  # The dataset consists of tuples (text, label)
  # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)
  # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)
  ```

  Args:
    split: `tfds.core.SplitBase`, which subset(s) of the data to read. If
      None (default), returns all splits in a dict
      `<key: tfds.Split, value: tf.data.Dataset>`.
    batch_size: `int`, batch size. Note that variable-length features will
      be 0-padded if `batch_size` is set. Users that want more custom
      behavior should use `batch_size=None` and use the `tf.data` API to
      construct a custom pipeline. If `batch_size == -1`, will return feature
      dictionaries of the whole dataset with `tf.Tensor`s instead of a
      `tf.data.Dataset`.
    shuffle_files: `bool`, whether to shuffle the input files. Defaults to
      `False`.
    decoders: Nested dict of `Decoder` objects which allow to customize the
      decoding. The structure should match the feature structure, but only
      customized feature keys need to be present. See
      [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
      for more info.
    as_supervised: `bool`, if `True`, the returned `tf.data.Dataset` will
      have a 2-tuple structure `(input, label)` according to
      `builder.info.supervised_keys`. If `False`, the default, the returned
      `tf.data.Dataset` will have a dictionary with all the features.
    in_memory: `bool`, if `True`, loads the dataset in memory which increases
      iteration speeds. Note that if `True` and the dataset has unknown
      dimensions, the features will be padded to the maximum size across the
      dataset.

  Returns:
    `tf.data.Dataset`, or if `split=None`, a
    `dict<key: tfds.Split, value: tf.data.Dataset>`.

    If `batch_size` is -1, will return feature dictionaries containing the
    entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
  """
  # pylint: enable=line-too-long
  logging.info("Constructing tf.data.Dataset for split %s, from %s",
               split, self._data_dir)
  if not tf.io.gfile.exists(self._data_dir):
    raise AssertionError(
        ("Dataset %s: could not find data in %s. Please make sure to call "
         "dataset_builder.download_and_prepare(), or pass download=True to "
         "tfds.load() before trying to access the tf.data.Dataset object."
        ) % (self.name, self._data_dir_root))

  # By default, return all splits
  if split is None:
    split = {s: s for s in self.info.splits}

  # Create a dataset for each of the given splits
  build_single_dataset = functools.partial(
      self._build_single_dataset,
      shuffle_files=shuffle_files,
      batch_size=batch_size,
      decoders=decoders,
      as_supervised=as_supervised,
      in_memory=in_memory,
  )
  datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
  return datasets

def dtype(self):
  """Return the dtype (or dict of dtype) of this FeatureConnector."""
  return utils.map_nested(lambda t: t.dtype, self.get_tensor_info())

def _get_shape(s):
  if isinstance(s, tf.data.Dataset):
    return utils.map_nested(_get_shape, s.element_spec)
  else:
    return s.shape

def get_serialized_info(self):
  """See base class for details."""
  # Add the additional length dimension to every serialized feature
  tensor_info = self._feature.get_serialized_info()
  return utils.map_nested(self._add_length_dim, tensor_info)