def _transpose_dict_list(dict_list):
    """Transpose a nested dict[list] into a list[nested dict]."""
    # 1. Unstack numpy arrays into list
    dict_list = utils.map_nested(_np_to_list, dict_list, dict_only=True)

    # 2. Extract the sequence length (and ensure the length is constant for all
    # elements)
    length = {'value': None}  # dict because `nonlocal` is Python3 only

    def update_length(elem):
        if length['value'] is None:
            length['value'] = len(elem)
        elif length['value'] != len(elem):
            raise ValueError(
                'The length of all elements of one sequence should be the same. '
                'Got {} != {}'.format(length['value'], len(elem)))
        return elem

    utils.map_nested(update_length, dict_list, dict_only=True)

    # 3. Extract each individual element
    return [
        utils.map_nested(lambda elem: elem[i], dict_list, dict_only=True)  # pylint: disable=cell-var-from-loop
        for i in range(length['value'])
    ]
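
For intuition, here is a minimal standalone sketch of the same dict[list] -> list[dict] transposition for flat (non-nested) dicts, without the TFDS `utils.map_nested` helper; the function name `transpose_flat_dict` is hypothetical:

```python
def transpose_flat_dict(dict_of_lists):
    """Turn {'a': [1, 2], 'b': [3, 4]} into [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]."""
    lengths = {len(v) for v in dict_of_lists.values()}
    if len(lengths) > 1:
        raise ValueError('All sequences must have the same length, got {}'.format(lengths))
    length = lengths.pop() if lengths else 0
    return [{k: v[i] for k, v in dict_of_lists.items()} for i in range(length)]

assert transpose_flat_dict({'a': [1, 2], 'b': [3, 4]}) == [
    {'a': 1, 'b': 3},
    {'a': 2, 'b': 4},
]
```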
    def assertAllEqualNested(self, d1, d2, *, atol: Optional[float] = None):
        """Same as assertAllEqual but compatible with nested dict.

        Args:
          d1: First element to compare
          d2: Second element to compare
          atol: If given, perform a close float comparison. Otherwise, perform an
            exact comparison
        """
        if isinstance(d1, dict):
            # assertAllEqual does not work well with dictionaries, so assert
            # on each individual element instead
            zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
            utils.map_nested(
                # recursively call assertAllEqualNested in case there is a dataset.
                lambda x: self.assertAllEqualNested(x[0], x[1], atol=atol),
                zipped_examples,
                dict_only=True,
            )
        elif isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
            # Checks length and elements of the dataset. At the moment, more than one
            # level of nested datasets is not supported.
            self.assertEqual(len(d1), len(d2))
            for ex1, ex2 in zip(d1, d2):
                self.assertAllEqualNested(ex1, ex2, atol=atol)
        elif atol:
            self.assertAllClose(d1, d2, atol=atol)
        else:
            self.assertAllEqual(d1, d2)
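
The helper above relies on `utils.zip_nested` and `utils.map_nested`; as a rough standalone sketch of the same recursion (assuming NumPy-comparable leaves and ignoring the nested-dataset branch), the idea is:

```python
import numpy as np

def assert_nested_equal(d1, d2, atol=None):
    """Recursively compare two nested dicts of array-like leaves."""
    if isinstance(d1, dict):
        assert set(d1) == set(d2), 'Keys differ: {} != {}'.format(set(d1), set(d2))
        for key in d1:
            assert_nested_equal(d1[key], d2[key], atol=atol)
    elif atol is not None:
        np.testing.assert_allclose(d1, d2, atol=atol)
    else:
        np.testing.assert_array_equal(d1, d2)

assert_nested_equal({'a': {'b': [1.0, 2.0]}}, {'a': {'b': [1.0, 2.000001]}}, atol=1e-3)
```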
Example #3
    def assertFeatureTest(self, fdict, test, feature, shape, dtype):
        """Test that encode=>decoding of a value works correctly."""

        # self._process_subtest_exp(e)
        input_value = {"inner": test.value}

        if test.raise_cls is not None:
            with self._subTest("raise"):
                if not test.raise_msg:
                    raise ValueError(
                        "test.raise_msg should be set with {} for test {}".
                        format(test.raise_cls, type(feature)))
                with self.assertRaisesWithPredicateMatch(
                        test.raise_cls, test.raise_msg):
                    features_encode_decode(fdict, input_value)
        else:
            # Test the serialization only
            if test.expected_serialized is not None:
                with self._subTest("out_serialize"):
                    self.assertEqual(
                        test.expected_serialized,
                        feature.encode_example(test.value),
                    )

            # Assert the returned type matches the expected one
            with self._subTest("out"):
                out = features_encode_decode(fdict,
                                             input_value,
                                             as_tensor=True)
                out = out["inner"]
                with self._subTest("dtype"):
                    out_dtypes = utils.map_nested(lambda s: s.dtype, out)
                    self.assertEqual(out_dtypes, feature.dtype)
                with self._subTest("shape"):
                    # For shape, because (None, 3) matches (5, 3), we use
                    # tf.TensorShape.assert_is_compatible_with on each of the elements
                    out_shapes = utils.zip_nested(out, feature.shape)
                    utils.map_nested(
                        lambda x: x[0].shape.assert_is_compatible_with(x[1]),
                        out_shapes)

            # Test serialization + decoding from disk
            with self._subTest("out_value"):
                decoded_examples = features_encode_decode(fdict, input_value)
                decoded_examples = decoded_examples["inner"]
                if isinstance(decoded_examples, dict):
                    # assertAllEqual does not work well with dictionaries, so assert
                    # on each individual element instead
                    zipped_examples = utils.zip_nested(
                        test.expected,
                        decoded_examples,
                        dict_only=True,
                    )
                    utils.map_nested(
                        lambda x: self.assertAllEqual(x[0], x[1]),
                        zipped_examples,
                        dict_only=True,
                    )
                else:
                    self.assertAllEqual(test.expected, decoded_examples)
    def assertFeatureTest(self, fdict, test, feature, shape, dtype):
        """Test that encode=>decoding of a value works correctly."""
        # Test that feature.encode_example can be pickled and unpickled for Beam.
        dill.loads(dill.dumps(feature.encode_example))

        input_value = {'inner': test.value}

        if test.raise_cls is not None:
            with self._subTest('raise'):
                if not test.raise_msg:
                    raise ValueError(
                        'test.raise_msg should be set with {} for test {}'.
                        format(test.raise_cls, type(feature)))
                with self.assertRaisesWithPredicateMatch(
                        test.raise_cls, test.raise_msg):
                    features_encode_decode(fdict,
                                           input_value,
                                           decoders=test.decoders)
        else:
            # Test the serialization only
            if test.expected_serialized is not None:
                with self._subTest('out_serialize'):
                    self.assertEqual(
                        test.expected_serialized,
                        feature.encode_example(test.value),
                    )

            # Test serialization + decoding from disk
            with self._subTest('out'):
                out_tensor, out_numpy = features_encode_decode(
                    fdict,
                    input_value,
                    decoders={'inner': test.decoders},
                )
                out_tensor = out_tensor['inner']
                out_numpy = out_numpy['inner']

                # Assert the returned type matches the expected one
                with self._subTest('dtype'):
                    out_dtypes = tf.nest.map_structure(lambda s: s.dtype,
                                                       out_tensor)
                    self.assertEqual(out_dtypes, test.dtype or feature.dtype)
                with self._subTest('shape'):
                    # For shape, because (None, 3) matches (5, 3), we use
                    # tf.TensorShape.assert_is_compatible_with on each of the elements
                    expected_shape = feature.shape if test.shape is None else test.shape
                    out_shapes = utils.zip_nested(out_tensor, expected_shape)
                    utils.map_nested(
                        lambda x: x[0].shape.assert_is_compatible_with(x[1]),
                        out_shapes)

                # Assert value
                with self._subTest('out_value'):
                    # Construct the tf.RaggedTensor if needed
                    expected = tf.nest.map_structure(
                        lambda t: t.build()
                        if isinstance(t, RaggedConstant) else t, test.expected)
                    self.assertAllEqualNested(out_numpy, expected)
Example #5
 def assertAllEqualNested(self, d1, d2):
     """Same as assertAllEqual but compatible with nested dict."""
     if isinstance(d1, dict):
         # assertAllEqual does not work well with dictionaries, so assert
         # on each individual element instead
         zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
         utils.map_nested(
             lambda x: self.assertAllEqual(x[0], x[1]),
             zipped_examples,
             dict_only=True,
         )
     else:
         self.assertAllEqual(d1, d2)
Example #6
def _map_promise(map_fn, all_inputs, async_):
    """Map the function into each element and resolve the promise."""
    all_promises = utils.map_nested(map_fn, all_inputs)  # Apply the function
    if async_:
        # TODO(tfds): Fix for nested case
        if isinstance(all_promises, dict):
            merged_promise = promise.Promise.for_dict(all_promises)
        elif isinstance(all_promises, list):
            merged_promise = promise.Promise.all(all_promises)
        else:
            merged_promise = all_promises
        return merged_promise

    return utils.map_nested(lambda p: p.get(), all_promises)
Example #7
    def get_serialized_info(self):
        """Return the tf-example features for the adapter, as stored on disk.

    This function indicates how this feature is encoded on file internally.
    The DatasetBuilder are written on disk as tf.train.Example proto.

    Ex:

    ```
    return {
        'image': tf.VarLenFeature(tf.uint8):
        'height': tf.FixedLenFeature((), tf.int32),
        'width': tf.FixedLenFeature((), tf.int32),
    }
    ```

    FeatureConnector which are not containers should return the feature proto
    directly:

    ```
    return tf.FixedLenFeature((64, 64), tf.uint8)
    ```

    If not defined, the retuned values are automatically deduced from the
    `get_tensor_info` function.

    Returns:
      features: Either a dict of feature proto object, or a feature proto object

    """
        return utils.map_nested(to_serialized_field, self.get_tensor_info())
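
To make the serialized specs in the docstring concrete, here is a hedged round-trip sketch using standard TensorFlow APIs. Note that tf.train.Example only stores int64/float/bytes values, so the sketch uses tf.int64 instead of the illustrative tf.int32 above; the spec dict itself is made up for the example:

```python
import tensorflow as tf

# Illustrative serialized-feature specs, analogous to what get_serialized_info returns.
specs = {
    'height': tf.io.FixedLenFeature((), tf.int64),
    'width': tf.io.FixedLenFeature((), tf.int64),
}

# Build and serialize a tf.train.Example proto with matching fields.
example = tf.train.Example(features=tf.train.Features(feature={
    'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[28])),
    'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[28])),
}))

# Parse it back with the specs; each value comes back as a scalar int64 tensor.
parsed = tf.io.parse_single_example(example.SerializeToString(), specs)
```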
Example #8
  def read(self, name, instructions, split_infos, shuffle_files=False):
    """Returns tf.data.Dataset instance(s).

    Args:
      name (str): name of the dataset.
      instructions (ReadInstruction, List[], Dict[]): instruction(s) to read.
        Instructions can be strings, in which case they are passed to the
        ReadInstruction constructor as-is.
      split_infos (list of SplitInfo proto): the available splits for dataset.
      shuffle_files (bool): defaults to False. If True, input files are shuffled
        before being read.

    Returns:
       a single tf.data.Dataset instance if instruction is a single
       ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
       corresponding to given instructions param shape.
    """
    name2shard_lengths = {info.name: info.shard_lengths for info in split_infos}
    name2len = {name: sum(lengths)
                for name, lengths in name2shard_lengths.items()}
    read_instruction = functools.partial(
        _read_single_instruction,
        parse_fn=self._parser.parse_example,
        name=name, path=self._path,
        name2len=name2len, name2shard_lengths=name2shard_lengths,
        shuffle_files=shuffle_files)
    datasets = utils.map_nested(read_instruction, instructions, map_tuple=True)
    return datasets
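
In practice, callers rarely construct ReadInstruction objects by hand; the same instruction strings surface through the public API, as in this usage sketch (assuming the standard `tfds.load` split syntax):

```python
import tensorflow_datasets as tfds

# A plain string split; internally converted to a ReadInstruction.
ds_train = tfds.load('mnist', split='train')

# Slicing syntax also parses into a ReadInstruction.
ds_head = tfds.load('mnist', split='train[:80%]')
```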
    def get_serialized_info(self):
        """See base class for details."""

        # Add the additional length dimension to every serialized features

        def add_length_dim(serialized_info):
            """Add the length dimension to the serialized_info.

            Args:
              serialized_info: One of tf.io.FixedLenFeature, tf.io.VarLenFeature,...

            Returns:
              new_serialized_info: serialized_info with extended first dimension
            """
            if isinstance(serialized_info, tf.io.FixedLenFeature):
                if self._length is not None:
                    return tf.io.FixedLenFeature(
                        shape=(self._length, ) + serialized_info.shape,
                        dtype=serialized_info.dtype,
                    )
                else:
                    return tf.io.FixedLenSequenceFeature(
                        shape=serialized_info.shape,
                        dtype=serialized_info.dtype,
                        allow_missing=True,
                    )
            elif isinstance(serialized_info, tf.io.VarLenFeature):
                return serialized_info
            else:
                raise ValueError(
                    'FixedLenSequenceFeature not supported inside Sequence')

        tensor_info = self._feature.get_serialized_info()
        return utils.map_nested(add_length_dim, tensor_info)
 def _get_dl_download_result(self, url):
   if self.DL_DOWNLOAD_RESULT is None:
     # This is only for backwards compatibility with the old approach.
     # In the future it will be replaced by using self.example_dir.
     return self._get_dl_extract_result(url)
   return utils.map_nested(lambda fname: os.path.join(self.example_dir, fname),
                           self.DL_DOWNLOAD_RESULT)
Example #11
 def _get_dl_extract_result(self, url):
   tf.nest.map_structure(self._add_url, url)
   del url
   if self.DL_EXTRACT_RESULT is None:
     return self.example_dir
   return utils.map_nested(lambda fname: os.path.join(self.example_dir, fname),
                           self.DL_EXTRACT_RESULT)
Example #12
def _parallel_run(function, input_struct, max_workers=1):
    """Run the function on each element of data_struct using a pool of workers."""

    # Distribute the work in a pool
    launch_thread_pool = concurrent.futures.ThreadPoolExecutor
    with launch_thread_pool(max_workers=max_workers) as executor:

        def launch_worker(value):
            return executor.submit(function, value)

        output_struct = utils.map_nested(launch_worker, input_struct)

    # Gather all results once all workers have finished
    def gather_results(value):
        return value.result()

    return utils.map_nested(gather_results, output_struct)
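
The same fan-out/fan-in pattern can be sketched without the internal `utils.map_nested` helper for a flat dict of inputs; `parallel_map_dict` is an illustrative name, not part of the library:

```python
import concurrent.futures

def parallel_map_dict(fn, inputs, max_workers=4):
    """Apply fn to every value of a flat dict using a thread pool."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {key: executor.submit(fn, value) for key, value in inputs.items()}
    # The `with` block waits for all workers, so every result is ready here.
    return {key: future.result() for key, future in futures.items()}

squares = parallel_map_dict(lambda x: x * x, {'a': 2, 'b': 3})
assert squares == {'a': 4, 'b': 9}
```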
Example #13
    def as_dataset(self,
                   split=None,
                   batch_size=None,
                   shuffle_files=None,
                   as_supervised=False):
        """Constructs a `tf.data.Dataset`.

    Callers must pass arguments as keyword arguments.

    Args:
      split: `tfds.core.SplitBase`, which subset(s) of the data to read. If None
        (default), returns all splits in a dict
        `<key: tfds.Split, value: tf.data.Dataset>`.
      batch_size: `int`, batch size. Note that variable-length features will
        be 0-padded if `batch_size` is set. Users that want more custom behavior
        should use `batch_size=None` and use the `tf.data` API to construct a
        custom pipeline. If `batch_size == -1`, will return feature
        dictionaries of the whole dataset with `tf.Tensor`s instead of a
        `tf.data.Dataset`.
      shuffle_files: `bool`, whether to shuffle the input files.
        Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
      as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
        will have a 2-tuple structure `(input, label)` according to
        `builder.info.supervised_keys`. If `False`, the default,
        the returned `tf.data.Dataset` will have a dictionary with all the
        features.

    Returns:
      `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
      tf.data.Dataset>`.

      If `batch_size` is -1, will return feature dictionaries containing
      the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
    """
        logging.info("Constructing tf.data.Dataset for split %s, from %s",
                     split, self._data_dir)
        if not tf.io.gfile.exists(self._data_dir):
            raise AssertionError((
                "Dataset %s: could not find data in %s. Please make sure to call "
                "dataset_builder.download_and_prepare(), or pass download=True to "
                "tfds.load() before trying to access the tf.data.Dataset object."
            ) % (self.name, self._data_dir_root))

        # By default, return all splits
        if split is None:
            split = {s: s for s in self.info.splits}

        # Create a dataset for each of the given splits
        build_single_dataset = functools.partial(
            self._build_single_dataset,
            shuffle_files=shuffle_files,
            batch_size=batch_size,
            as_supervised=as_supervised,
        )
        datasets = utils.map_nested(build_single_dataset,
                                    split,
                                    map_tuple=True)
        return datasets
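
A typical call path into this method goes through the public builder API; a usage sketch, with 'mnist' as an arbitrary example dataset:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.builder('mnist')
builder.download_and_prepare()

# Single split -> a single tf.data.Dataset.
ds_train = builder.as_dataset(split='train', shuffle_files=True)
assert isinstance(ds_train, tf.data.Dataset)

# No split -> a dict {split_name: tf.data.Dataset}, built via utils.map_nested.
ds_all = builder.as_dataset()
assert isinstance(ds_all, dict)
```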
    def assertFeature(self, specs, serialized_info, tests):
        """Test the TFRecordExampleAdapter encoding."""

        adapter = file_format_adapter.TFRecordExampleAdapter(specs)

        with self._subTest("serialized_info"):
            self.assertEqual(serialized_info,
                             adapter._parser._build_feature_specs())

        for i, test in enumerate(tests):
            with self._subTest(str(i)):

                if test.raise_cls is not None:
                    with self.assertRaisesWithPredicateMatch(
                            test.raise_cls, test.raise_msg):
                        adapter._serializer.serialize_example(test.value)
                    continue
                serialized = adapter._serializer.serialize_example(test.value)

                if test.expected_serialized is not None:
                    example_proto = tf.train.Example()
                    example_proto.ParseFromString(serialized)
                    expected_proto = tf.train.Example(
                        features=tf.train.Features(
                            feature=test.expected_serialized))
                    self.assertEqual(expected_proto, example_proto)

                example = _parse_example(serialized,
                                         adapter._parser.parse_example)

                with self._subTest("dtype"):
                    out_dtypes = utils.map_nested(lambda s: s.dtype, example)
                    expected_dtypes = utils.map_nested(lambda s: s.dtype,
                                                       specs)
                    self.assertEqual(out_dtypes, expected_dtypes)
                with self._subTest("shape"):
                    # For shape, because (None, 3) matches (5, 3), we use
                    # tf.TensorShape.assert_is_compatible_with on each of the elements
                    utils.map_nested(
                        lambda x: x[0].shape.assert_is_compatible_with(x[1].shape),
                        utils.zip_nested(example, specs))
                np_example = dataset_utils.as_numpy(example)
                self.assertAllEqualNested(np_example, test.expected)
Example #15
    def get_tensor_info(self):
        """See base class for details."""

        # Add the additional length dimension to every shape
        def add_length_dim(tensor_info):
            tensor_info = feature_lib.TensorInfo.copy_from(tensor_info)
            tensor_info.shape = (self._length, ) + tensor_info.shape
            return tensor_info

        tensor_info = self._feature.get_tensor_info()
        return utils.map_nested(add_length_dim, tensor_info)
Example #16
 def assertAllEqualNested(self, d1, d2):
     """Same as assertAllEqual but compatible with nested dict."""
     if isinstance(d1, dict):
         # assertAllEqual does not work well with dictionaries, so assert
         # on each individual element instead
         zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
         utils.map_nested(
             # recursively call assertAllEqualNested in case there is a dataset.
             lambda x: self.assertAllEqualNested(x[0], x[1]),
             zipped_examples,
             dict_only=True,
         )
     elif isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
         # Checks length and elements of the dataset. At the moment, more than one
         # level of nested datasets is not supported.
         self.assertEqual(len(d1), len(d2))
         for ex1, ex2 in zip(d1, d2):
             self.assertAllEqualNested(ex1, ex2)
     else:
         self.assertAllEqual(d1, d2)
Example #17
    def read(
        self,
        name,
        instructions,
        split_infos,
        read_config,
        shuffle_files,
    ):
        """Returns tf.data.Dataset instance(s).

    Args:
      name (str): name of the dataset.
      instructions (ReadInstruction, List[], Dict[]): instruction(s) to read.
        Instructions can be strings, in which case they are passed to the
        ReadInstruction constructor as-is.
      split_infos (list of SplitInfo proto): the available splits for dataset.
      read_config: `tfds.ReadConfig`, the input pipeline options
      shuffle_files (bool): If True, input files are shuffled before being read.

    Returns:
       a single tf.data.Dataset instance if instruction is a single
       ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
       corresponding to given instructions param shape.
    """
        def _read_instruction_to_file_instructions(instruction):
            file_instructions = make_file_instructions(name, split_infos,
                                                       instruction)
            files = file_instructions.file_instructions
            if not files:
                msg = 'Instruction "%s" corresponds to no data!' % instruction
                raise AssertionError(msg)
            return tuple(files)

        files = utils.map_nested(_read_instruction_to_file_instructions,
                                 instructions,
                                 map_tuple=False)
        return utils.map_nested(functools.partial(self.read_files,
                                                  read_config=read_config,
                                                  shuffle_files=shuffle_files),
                                files,
                                map_tuple=False)
  def get_tensor_info(self):
    """See base class for details."""
    # Add the additional length dimension to every shape

    def add_length_dim(tensor_info):
      return feature_lib.TensorInfo(
          shape=(self._length,) + tensor_info.shape,
          dtype=tensor_info.dtype,
      )

    tensor_info = super(SequenceDict, self).get_tensor_info()
    return utils.map_nested(add_length_dim, tensor_info)
Example #19
    def get_serialized_info(self):
        """See base class for details."""

        # Add the additional length dimension to every serialized features

        def add_length_dim(tensor_info):
            """Add the length dimension to the serialized_info."""
            return feature_lib.TensorInfo(
                shape=(self._length, ) + tensor_info.shape,
                dtype=tensor_info.dtype,
            )

        tensor_info = self._feature.get_serialized_info()
        return utils.map_nested(add_length_dim, tensor_info)
    def encode_example(self, example_dict):
        # Convert nested dict[list] into list[nested dict]
        sequence_elements = _transpose_dict_list(example_dict)

        # If the length is static, ensure that the given length matches
        if self._length is not None and len(sequence_elements) != self._length:
            raise ValueError(
                'Input sequence length does not match the defined one. Got {} != '
                '{}'.format(len(sequence_elements), self._length))

        # Empty sequences return empty arrays
        if not sequence_elements:

            def _build_empty_np(serialized_info):
                return np.empty(
                    shape=tuple(s if s else 0 for s in serialized_info.shape),
                    dtype=serialized_info.dtype.as_numpy_dtype,
                )

            return utils.map_nested(_build_empty_np,
                                    self.get_serialized_info())

        # Encode each individual element
        sequence_elements = [
            self.feature.encode_example(sequence_elem)
            for sequence_elem in sequence_elements
        ]

        # Then convert back list[nested dict] => nested dict[list]
        def _stack_nested(sequence_elements):
            """Recursivelly stack the tensors from the same dict field."""
            if isinstance(sequence_elements[0], dict):
                return {
                    # Stack along the first dimension
                    k: _stack_nested(sub_sequence)
                    for k, sub_sequence in utils.zip_dict(*sequence_elements)
                }
            # Note: As each field can be a nested ragged list, we don't check here
            # that all elements from the list have matching dtype/shape.
            # Checking is done in `example_serializer` when elements
            # are converted to numpy arrays and stacked together.
            return list(sequence_elements)

        return _stack_nested(sequence_elements)
Example #21
    def encode_example(self, example_dict):
        # Convert nested dict[list] into list[nested dict]
        sequence_elements = _transpose_dict_list(example_dict)

        # If the length is static, ensure that the given length matches
        if self._length is not None and len(sequence_elements) != self._length:
            raise ValueError(
                'Input sequence length does not match the defined one. Got {} != '
                '{}'.format(len(sequence_elements), self._length))

        # Empty sequences return empty arrays
        if not sequence_elements:

            def _build_empty_np(serialized_info):
                return np.empty(
                    shape=tuple(s if s else 0 for s in serialized_info.shape),
                    dtype=serialized_info.dtype.as_numpy_dtype,
                )

            return utils.map_nested(_build_empty_np,
                                    self.get_serialized_info())

        # Encode each individual element
        sequence_elements = [
            self.feature.encode_example(sequence_elem)
            for sequence_elem in sequence_elements
        ]

        # Then merge the elements back together
        def _stack_nested(sequence_elements):
            if isinstance(sequence_elements[0], dict):
                return {
                    # Stack along the first dimension
                    k: _stack_nested(sub_sequence)
                    for k, sub_sequence in utils.zip_dict(*sequence_elements)
                }
            return stack_arrays(*sequence_elements)

        return _stack_nested(sequence_elements)
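
The merge step above, turning list[nested dict] back into a dict of stacked arrays, can be sketched with NumPy alone, assuming rectangular data; the real `stack_arrays` helper handles more cases:

```python
import numpy as np

def stack_nested(elements):
    """Stack a list of (possibly nested) dicts of arrays along a new first axis."""
    if isinstance(elements[0], dict):
        return {
            key: stack_nested([elem[key] for elem in elements])
            for key in elements[0]
        }
    return np.stack(elements)

stacked = stack_nested([{'x': np.zeros((3,))}, {'x': np.ones((3,))}])
assert stacked['x'].shape == (2, 3)
```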
Example #22
def _map_promise(map_fn, all_inputs):
    """Map the function into each element and resolve the promise."""
    all_promises = utils.map_nested(map_fn, all_inputs)  # Apply the function
    res = utils.map_nested(_wait_on_promise, all_promises)
    return res
Example #23
    def _process_exp(self, exp):

        # Check the shape/dtype
        with self._subTest("shape"):
            self.assertEqual(exp.feature.shape, exp.shape)
        with self._subTest("dtype"):
            self.assertEqual(exp.feature.dtype, exp.dtype)

        # Check the serialized features
        if exp.serialized_features is not None:
            with self._subTest("serialized_features"):
                self.assertEqual(
                    exp.serialized_features,
                    exp.feature.get_serialized_features(),
                )

        # Create the feature dict
        fdict = features.FeaturesDict({exp.name: exp.feature})
        for i, test in enumerate(exp.tests):
            with self._subTest(str(i)):
                # self._process_subtest_exp(e)
                input_value = {exp.name: test.value}

                if test.raise_cls is not None:
                    with self._subTest("raise"):
                        if not test.raise_msg:
                            raise ValueError(
                                "test.raise_msg should be set with {}for test {}"
                                .format(test.raise_cls, exp.name))
                        with self.assertRaisesWithPredicateMatch(
                                test.raise_cls, test.raise_msg):
                            features_encode_decode(fdict, input_value)
                else:
                    # Test the serialization only
                    if test.expected_serialized is not None:
                        with self._subTest("out_serialize"):
                            self.assertEqual(
                                test.expected_serialized,
                                exp.feature.encode_sample(test.value),
                            )

                    # Assert the returned type matches the expected one
                    with self._subTest("out_extract"):
                        out = features_encode_decode(fdict,
                                                     input_value,
                                                     as_tensor=True)
                        out = out[exp.name]
                    with self._subTest("out_dtype"):
                        out_dtypes = utils.map_nested(lambda s: s.dtype, out)
                        self.assertEqual(out_dtypes, exp.feature.dtype)
                    with self._subTest("out_shape"):
                        # For shape, because (None, 3) matches (5, 3), we use
                        # tf.TensorShape.assert_is_compatible_with on each of the elements
                        out_shapes = utils.zip_nested(out, exp.feature.shape)
                        utils.map_nested(
                            lambda x: x[0].shape.assert_is_compatible_with(x[1]),
                            out_shapes)

                    # Test serialization + decoding from disk
                    with self._subTest("out_value"):
                        decoded_samples = features_encode_decode(
                            fdict, input_value)
                        self.assertAllEqual(test.expected,
                                            decoded_samples[exp.name])
Example #24
    def as_dataset(self,
                   split=None,
                   batch_size=None,
                   shuffle_files=False,
                   decoders=None,
                   as_supervised=False,
                   in_memory=None):
        # pylint: disable=line-too-long
        """Constructs a `tf.data.Dataset`.

    Callers must pass arguments as keyword arguments.

    The output types vary depending on the parameters. Examples:

    ```python
    builder = tfds.builder('imdb_reviews')
    builder.download_and_prepare()

    # Default parameters: Returns the dict of tf.data.Dataset
    ds_all_dict = builder.as_dataset()
    assert isinstance(ds_all_dict, dict)
    print(ds_all_dict.keys())  # ==> ['test', 'train', 'unsupervised']

    assert isinstance(ds_all_dict['test'], tf.data.Dataset)
    # Each dataset (test, train, unsup.) consists of dictionaries
    # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
    #  'text': <tf.Tensor: .. dtype=string, numpy=b"I've watched the movie ..">}
    # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
    #  'text': <tf.Tensor: .. dtype=string, numpy=b'If you love Japanese ..'>}

    # With as_supervised: tf.data.Dataset only contains (feature, label) tuples
    ds_all_supervised = builder.as_dataset(as_supervised=True)
    assert isinstance(ds_all_supervised, dict)
    print(ds_all_supervised.keys())  # ==> ['test', 'train', 'unsupervised']

    assert isinstance(ds_all_supervised['test'], tf.data.Dataset)
    # Each dataset (test, train, unsup.) consists of tuples (text, label)
    # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
    #  <tf.Tensor: ... dtype=int64, numpy=1>)
    # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
    #  <tf.Tensor: ... dtype=int64, numpy=1>)

    # Same as above plus requesting a particular split
    ds_test_supervised = builder.as_dataset(as_supervised=True, split='test')
    assert isinstance(ds_test_supervised, tf.data.Dataset)
    # The dataset consists of tuples (text, label)
    # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
    #  <tf.Tensor: ... dtype=int64, numpy=1>)
    # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
    #  <tf.Tensor: ... dtype=int64, numpy=1>)
    ```

    Args:
      split: `tfds.core.SplitBase`, which subset(s) of the data to read. If None
        (default), returns all splits in a dict
        `<key: tfds.Split, value: tf.data.Dataset>`.
      batch_size: `int`, batch size. Note that variable-length features will
        be 0-padded if `batch_size` is set. Users that want more custom behavior
        should use `batch_size=None` and use the `tf.data` API to construct a
        custom pipeline. If `batch_size == -1`, will return feature
        dictionaries of the whole dataset with `tf.Tensor`s instead of a
        `tf.data.Dataset`.
      shuffle_files: `bool`, whether to shuffle the input files. Defaults to
        `False`.
      decoders: Nested dict of `Decoder` objects which allow to customize the
        decoding. The structure should match the feature structure, but only
        customized feature keys need to be present. See
        [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
        for more info.
      as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
        will have a 2-tuple structure `(input, label)` according to
        `builder.info.supervised_keys`. If `False`, the default,
        the returned `tf.data.Dataset` will have a dictionary with all the
        features.
      in_memory: `bool`, if `True`, loads the dataset in memory which
        increases iteration speeds. Note that if `True` and the dataset has
        unknown dimensions, the features will be padded to the maximum
        size across the dataset.

    Returns:
      `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
      tf.data.Dataset>`.

      If `batch_size` is -1, will return feature dictionaries containing
      the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
    """
        # pylint: enable=line-too-long
        logging.info("Constructing tf.data.Dataset for split %s, from %s",
                     split, self._data_dir)
        if not tf.io.gfile.exists(self._data_dir):
            raise AssertionError((
                "Dataset %s: could not find data in %s. Please make sure to call "
                "dataset_builder.download_and_prepare(), or pass download=True to "
                "tfds.load() before trying to access the tf.data.Dataset object."
            ) % (self.name, self._data_dir_root))

        # By default, return all splits
        if split is None:
            split = {s: s for s in self.info.splits}

        # Create a dataset for each of the given splits
        build_single_dataset = functools.partial(
            self._build_single_dataset,
            shuffle_files=shuffle_files,
            batch_size=batch_size,
            decoders=decoders,
            as_supervised=as_supervised,
            in_memory=in_memory,
        )
        datasets = utils.map_nested(build_single_dataset,
                                    split,
                                    map_tuple=True)
        return datasets
Example #25
 def dtype(self):
     """Return the dtype (or dict of dtype) of this FeatureConnector."""
     return utils.map_nested(lambda t: t.dtype, self.get_tensor_info())
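
For a composite feature this mapping yields a nested structure of dtypes; a hedged usage sketch, assuming the tfds.features API of the same era as these snippets:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

features = tfds.features.FeaturesDict({
    'label': tf.int64,
    'text': tfds.features.Text(),
})
# dtype mirrors the feature structure: {'label': tf.int64, 'text': tf.string}.
print(features.dtype)
```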
 def _get_shape(s):
     if isinstance(s, tf.data.Dataset):
         return utils.map_nested(_get_shape, s.element_spec)
     else:
         return s.shape
 def get_serialized_info(self):
     """See base class for details."""
     # Add the additional length dimension to every serialized features
     tensor_info = self._feature.get_serialized_info()
     return utils.map_nested(self._add_length_dim, tensor_info)