Example #1
    def _info(self):
        return dataset_info.DatasetInfo(
            builder=self,
            features=features.FeaturesDict({"x": tf.int64}),
            supervised_keys=("x", "x"),
        )
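
A note on `supervised_keys`: the pair names the (input, target) features that
are returned when the dataset is loaded with `as_supervised=True`. A minimal
sketch; the dataset name `my_dataset` is hypothetical:

import tensorflow_datasets as tfds

# With supervised_keys=("x", "x"), as_supervised=True yields (input, target)
# tuples instead of feature dicts; here both elements come from feature "x".
ds = tfds.load("my_dataset", split="train", as_supervised=True)
for x_input, x_target in ds.take(1):
    print(x_input.numpy(), x_target.numpy())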
Example #2
    def test_writing(self):
        # First, read in the existing DatasetInfo.
        mnist_builder = mnist.MNIST(data_dir=tempfile.mkdtemp(
            dir=self.get_temp_dir()))

        info = dataset_info.DatasetInfo(builder=mnist_builder,
                                        features=mnist_builder.info.features)
        info.read_from_directory(_INFO_DIR)

        # Read the existing JSON file into a dict.
        with tf.io.gfile.GFile(info._dataset_info_path(_INFO_DIR)) as f:
            existing_json = json.load(f)

        # Now write to a temp directory.
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            info.write_to_directory(tmp_dir)

            # Read the newly written JSON file into a dict.
            with tf.io.gfile.GFile(info._dataset_info_path(tmp_dir)) as f:
                new_json = json.load(f)

            # Read the newly written LICENSE file into a string.
            with tf.io.gfile.GFile(info._license_path(tmp_dir)) as f:
                license_ = f.read()

        # Assert that what was read, then written out and read back, is identical.
        self.assertEqual(existing_json, new_json)

        # Assert correct license was written.
        self.assertEqual(existing_json["redistributionInfo"]["license"],
                         license_)

        if six.PY3:
            # Only test on Python 3 to avoid u'' formatting issues
            self.assertEqual(repr(info), INFO_STR)
Example #3
    def test_reading(self):
        info = dataset_info.DatasetInfo(builder=self._builder)
        info.read_from_directory(_INFO_DIR)

        # Assert that we read the file and initialized DatasetInfo.
        self.assertTrue(info.initialized)
        self.assertEqual("mnist", info.name)
        self.assertEqual("mnist/3.0.1", info.full_name)

        # Test splits are initialized properly.
        split_dict = info.splits

        # Assert that there are exactly two splits.
        self.assertLen(split_dict, 2)

        # Assert which splits are present.
        self.assertIn("train", split_dict)
        self.assertIn("test", split_dict)

        # Assert that the example and byte totals are computed correctly.
        self.assertEqual(40, info.splits.total_num_examples)
        self.assertEqual(11594722, info.dataset_size)

        self.assertEqual("image", info.supervised_keys[0])
        self.assertEqual("label", info.supervised_keys[1])
        self.assertEqual(info.module_name,
                         "tensorflow_datasets.image_classification.mnist")
        self.assertFalse(info.disable_shuffling)

        self.assertEqual(info.version, utils.Version("3.0.1"))
        self.assertEqual(info.release_notes, {})
Example #4
    def test_writing(self):
        # First, read in the existing DatasetInfo.
        mnist_builder = mnist.MNIST(data_dir=tempfile.mkdtemp(
            dir=self.get_temp_dir()))

        info = dataset_info.DatasetInfo(builder=mnist_builder,
                                        features=mnist_builder.info.features)
        info.read_from_directory(_INFO_DIR)

        # Read the existing JSON file into a dict.
        with tf.io.gfile.GFile(info._dataset_info_path(_INFO_DIR)) as f:
            existing_json = json.load(f)

        # Now write to a temp directory.
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            info.write_to_directory(tmp_dir)

            # Read the newly written JSON file into a dict.
            with tf.io.gfile.GFile(info._dataset_info_path(tmp_dir)) as f:
                new_json = json.load(f)

            # Read the newly written LICENSE file into a string.
            with tf.io.gfile.GFile(info._license_path(tmp_dir)) as f:
                license_ = f.read()

        # Assert that what was read, then written out and read back, is identical.
        self.assertEqual(existing_json, new_json)

        # Assert correct license was written.
        self.assertEqual(existing_json["redistributionInfo"]["license"],
                         license_)

        # Do not check the full string, as it displays the generated path.
        self.assertEqual(_INFO_STR % mnist_builder.data_dir, repr(info))
        self.assertIn("'test': <SplitInfo num_examples=", repr(info))
Example #5
 def info(self) -> dataset_info.DatasetInfo:
     # HF supports Sequences defined as lists: {'video': [{'image': Image()}]}
     with _mock_list_as_sequence():
         info = self._info()
     # HF DatasetInfo does not have a `builder` arg, so we insert it here.
     return dataset_info.DatasetInfo(builder=self, **info.kwargs)
Example #6
 def _info(self):
     return dataset_info.DatasetInfo(
         builder=self,
         features=features.FeaturesDict({"im": features.Image()}),
         supervised_keys=("im", "im"),
         metadata=dataset_info.MetadataDict(),
     )
Example #7
    def test_reading(self):
        info = dataset_info.DatasetInfo(builder=self._builder)
        info.read_from_directory(_INFO_DIR)

        # Assert that we read the file and initialized DatasetInfo.
        self.assertTrue(info.initialized)
        self.assertEqual("dummy_dataset_shared_generator", info.name)
        self.assertEqual("dummy_dataset_shared_generator/1.0.0",
                         info.full_name)

        # Test splits are initialized properly.
        split_dict = info.splits

        # Assert that there are exactly two splits.
        self.assertLen(split_dict, 2)

        # Assert which splits are present.
        self.assertIn("train", split_dict)
        self.assertIn("test", split_dict)

        # Assert that the total number of examples is computed correctly.
        self.assertEqual(40, info.splits.total_num_examples)

        self.assertEqual("image", info.supervised_keys[0])
        self.assertEqual("label", info.supervised_keys[1])
Example #8
 def _info(self) -> dataset_info.DatasetInfo:
   return dataset_info.DatasetInfo(
       builder=self,
       description='Generic text translation dataset.',
       features=features_lib.FeaturesDict({
           lang: features_lib.Text() for lang in self._languages
       }),
   )
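
The dict comprehension above builds one Text feature per language. A small
sketch, using an illustrative value for `self._languages`:

languages = ('en', 'fr')  # Illustrative stand-in for self._languages.
features_lib.FeaturesDict({lang: features_lib.Text() for lang in languages})
# -> FeaturesDict({'en': Text(), 'fr': Text()}); each example then carries
#    one text string per language.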
Example #9
 def _info(self):
     return dataset_info.DatasetInfo(
         builder=self,
         features=features.FeaturesDict({
             'id': tf.int64,
         }),
         description='Minimal DatasetBuilder.',
     )
Example #10
 def _info(self):
   return dataset_info.DatasetInfo(
       builder=self,
       features=features.FeaturesDict({
           "image": features.Image(shape=(28, 28, 1)),
           "label": features.ClassLabel(num_classes=10),
       }),
   )
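
The declared features can be inspected without downloading any data, because
`builder.info` is restored from pre-computed metadata. A minimal sketch using
the real MNIST builder:

import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
print(builder.info.features["image"].shape)        # (28, 28, 1)
print(builder.info.features["label"].num_classes)  # 10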
Example #11
 def _info(self):
     mnist_shape = (_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1)
     return dataset_info.DatasetInfo(
         specs=features.SpecDict({
             "input": features.Image(shape=mnist_shape),
             "target": tf.int64,
         }),
     )
Example #12
 def _info(self):
   return dataset_info.DatasetInfo(
       features=features.FeaturesDict({
           "x": tf.int64,
           "y": tf.int64,
           "z": tf.string,
       }),
   )
Example #13
 def _info(self):
   cifar_shape = (_CIFAR_IMAGE_SIZE, _CIFAR_IMAGE_SIZE, 3)
   return dataset_info.DatasetInfo(
       specs=features.SpecDict({
           "input": features.Image(shape=cifar_shape),
           "target": tf.int64,  # Could replace by features.Label()
       }),
   )
Example #14
    def test_reading_from_gcs_bucket(self):
        mnist_builder = mnist.MNIST(data_dir=tempfile.mkdtemp(
            dir=self.get_temp_dir()))
        info = dataset_info.DatasetInfo(builder=mnist_builder)
        # The builder's `info` property returns the restored DatasetInfo,
        # overwriting the freshly constructed one above.
        info = mnist_builder.info

        # A nominal check to see if we read it.
        self.assertTrue(info.initialized)
        self.assertEqual(10000, info.splits["test"].num_examples)
Example #15
 def _info(self):
   return dataset_info.DatasetInfo(
       builder=self,
       features=features.FeaturesDict({
           'image': features.Image(shape=(28, 28, 1)),
           'label': features.ClassLabel(num_classes=10),
       }),
       description='Mnist description.',
   )
Example #16
 def test_non_existent_dir(self):
      # The error message raised on Windows differs from the one on Unix.
     if os.name == "nt":
         err = "The system cannot find the path specified"
     else:
         err = "No such file or dir"
     info = dataset_info.DatasetInfo(builder=self._builder)
     with self.assertRaisesWithPredicateMatch(tf.errors.NotFoundError, err):
         info.read_from_directory(_NON_EXISTENT_DIR)
Example #17
 def test_set_file_format_override_fails(self):
     info = dataset_info.DatasetInfo(builder=self._builder)
     info.set_file_format(file_adapters.FileFormat.TFRECORD)
     self.assertEqual(info.file_format, file_adapters.FileFormat.TFRECORD)
      with pytest.raises(
          ValueError,
          match="File format is already set to FileFormat.TFRECORD. Got FileFormat.RIEGELI",
      ):
         info.set_file_format(file_adapters.FileFormat.RIEGELI)
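
As the test shows, re-setting a different file format raises ValueError. A
hedged sketch of the resulting usage pattern, reusing `info` from above:

try:
    info.set_file_format(file_adapters.FileFormat.TFRECORD)
except ValueError:
    pass  # A different format was already set; keep the existing one.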
Example #18
    def test_reading_from_package_data(self):
        # We have mnist's 1.0.0 checked in the package data, so this should work.
        mnist_builder = mnist.MNIST(data_dir=tempfile.mkdtemp(
            dir=self.get_temp_dir()))
        info = dataset_info.DatasetInfo(builder=mnist_builder)
        # The builder's `info` property returns the restored DatasetInfo,
        # overwriting the freshly constructed one above.
        info = mnist_builder.info

        # A nominal check to see if we read it.
        self.assertTrue(info.initialized)
        self.assertEqual(10000, info.splits["test"].num_examples)
Example #19
 def test_reading_different_version(self):
     info = dataset_info.DatasetInfo(builder=self._builder)
     info._identity.version = utils.Version("2.0.0")
      with pytest.raises(
          AssertionError,
          match="The constructed DatasetInfo instance and the restored proto version do not match",
      ):
         # The dataset in _INFO_DIR has version 3.0.1 whereas the builder is 2.0.0
         info.read_from_directory(_INFO_DIR)
Example #20
 def _info(self):
   return dataset_info.DatasetInfo(
       builder=self,
       features=features.FeaturesDict({
           "image": features.Image(shape=(16, 16, 1)),
           "label": features.ClassLabel(names=["dog", "cat"]),
           "id": tf.int32,
       }),
       supervised_keys=("x", "x"),
       metadata=dataset_info.BeamMetadataDict(),
   )
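
`BeamMetadataDict` (like `MetadataDict`) behaves as a plain dict whose
contents TFDS serializes to a metadata.json file next to dataset_info.json.
A hedged sketch with a hypothetical key, assuming `builder` is an instance
of the class above:

builder.info.metadata["num_skipped_examples"] = 12  # Written to metadata.json.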
Example #21
 def _info(self):
   return dataset_info.DatasetInfo(
       builder=self,
       features=features.FeaturesDict({
           "frames": features.Sequence({
               "coordinates": features.Sequence(
                   features.Tensor(shape=(2,), dtype=tf.int32)
               ),
           }),
       }),
   )
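
The nested Sequence above describes ragged, per-frame coordinate lists. A
hedged sketch of one example value as `_generate_examples` might yield it
(dict-of-lists form; lengths may differ between frames):

example = {
    "frames": {
        "coordinates": [
            [[0, 1], [2, 3]],  # Frame 0 has two (2,) coordinates.
            [[4, 5]],          # Frame 1 has a single (2,) coordinate.
        ],
    },
}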
Example #22
 def _info(self):
   cifar_shape = (_CIFAR_IMAGE_SIZE, _CIFAR_IMAGE_SIZE, 3)
   label_to_use = "coarse_labels" if self._use_coarse_labels else "fine_labels"
   return dataset_info.DatasetInfo(
       specs=features.SpecDict({
           "input": features.Image(shape=cifar_shape),
           "target": features.OneOf(choice=label_to_use, feature_dict={
               "coarse_labels": tf.int64,
               "fine_labels": tf.int64,
           }),
       }),
   )
Example #23
    def test_reading_from_gcs_bucket(self):
        # The base TestCase prevents GCS access, so we explicitly ask it to restore
        # access here.
        with self.gcs_access():
            mnist_builder = mnist.MNIST(data_dir=tempfile.mkdtemp(
                dir=self.get_temp_dir()))
            info = dataset_info.DatasetInfo(builder=mnist_builder)
            # The builder's `info` property returns the restored DatasetInfo,
            # overwriting the freshly constructed one above.
            info = mnist_builder.info

            # A nominal check to see if we read it.
            self.assertTrue(info.initialized)
            self.assertEqual(10000, info.splits["test"].num_examples)
Example #24
 def _info(self) -> dataset_info.DatasetInfo:
     return dataset_info.DatasetInfo(
         builder=self,
         description='Generic image classification dataset.',
          features=features_lib.FeaturesDict({
              'image': features_lib.Image(),
              'label': features_lib.ClassLabel(),
              'image/filename': features_lib.Text(),
          }),
         supervised_keys=('image', 'label'),
     )
Example #25
 def _info(self):
     return dataset_info.DatasetInfo(
         builder=self,
          features=features.FeaturesDict({
              'image': features.Image(shape=(16, 16, 1)),
              'label': features.ClassLabel(names=['dog', 'cat']),
              'id': tf.int32,
          }),
         supervised_keys=('x', 'x'),
         metadata=dataset_info.BeamMetadataDict(),
     )
Example #26
def _dataset_info(data_dir: str,
                  name: str = 'test_dataset') -> dataset_info.DatasetInfo:
    return dataset_info.DatasetInfo(
        builder=dataset_info.DatasetIdentity(
            name=name,
            data_dir=data_dir,
            module_name='',
            version='1.0.0',
        ),
        description='Test builder',
        features=tfds.features.FeaturesDict({
            'a': tf.int32,
            'b': tf.int32,
        }),
    )
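
A short usage sketch for the factory above; the directory is illustrative:

info = _dataset_info('/tmp/tfds_test')
print(info.name, info.version)   # test_dataset 1.0.0
info.write_to_directory('/tmp/tfds_test')  # Writes dataset_info.json there.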
Example #27
 def test_set_splits_normal(self):
     info = dataset_info.DatasetInfo(builder=self._builder)
     split_info1 = splits_lib.SplitInfo(name="train",
                                        shard_lengths=[1, 2],
                                        num_bytes=0)
     split_info2 = splits_lib.SplitInfo(name="test",
                                        shard_lengths=[1],
                                        num_bytes=0)
     split_dict = splits_lib.SplitDict(
         split_infos=[split_info1, split_info2])
     info.set_splits(split_dict)
     self.assertEqual(str(info.splits), str(split_dict))
     self.assertEqual(str(info.as_proto.splits),
                      str([split_info1.to_proto(),
                           split_info2.to_proto()]))
Example #28
 def test_set_splits_incorrect_dataset_name(self):
     info = dataset_info.DatasetInfo(builder=self._builder)
     split_info1 = splits_lib.SplitInfo(
         name="train",
         shard_lengths=[1, 2],
         num_bytes=0,
         filename_template=naming.ShardedFileTemplate(
             dataset_name="some_other_dataset",
             split="train",
             data_dir=info.data_dir,
             filetype_suffix="tfrecord"))
     split_dict = splits_lib.SplitDict(split_infos=[split_info1])
     with pytest.raises(AssertionError,
                        match="SplitDict contains SplitInfo for split"):
         info.set_splits(split_dict)
Example #29
    def test_updates_on_bucket_info(self):
        info = dataset_info.DatasetInfo(builder=self._builder,
                                        description="won't be updated")
        # No statistics in the above.
        self.assertEqual(0, info.splits.total_num_examples)
        self.assertEqual(0, len(info.as_proto.schema.feature))

        # Partial update will happen here.
        info.read_from_directory(_INFO_DIR)

        # Assert that description (things specified in the code) didn't change
        # but statistics are updated.
        self.assertEqual("won't be updated", info.description)

        # These are dynamically computed, so will be updated.
        self.assertEqual(40, info.splits.total_num_examples)
        self.assertEqual(2, len(info.as_proto.schema.feature))
Example #30
  def test_writing(self):
    # First, read in the existing DatasetInfo.
    info = dataset_info.DatasetInfo()
    info.read_from_directory(_TESTDATA)

    # Read the existing JSON file into a dict.
    with tf.gfile.Open(info._dataset_info_filename(_TESTDATA)) as f:
      existing_json = json.load(f)

    # Now write to a temp directory.
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      info.write_to_directory(tmp_dir)

      # Read the newly written JSON file into a dict.
      with tf.gfile.Open(info._dataset_info_filename(tmp_dir)) as f:
        new_json = json.load(f)

    # Assert that what was read, then written out and read back, is identical.
    self.assertEqual(existing_json, new_json)