Example #1
  def test_load_from_gcs(self):
    from tensorflow_datasets.image_classification import mnist  # pylint:disable=import-outside-toplevel,g-import-not-at-top
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      with mock.patch.object(
          mnist.MNIST, "_download_and_prepare",
          side_effect=NotImplementedError):
        # Make sure the dataset cannot be generated.
        with self.assertRaises(NotImplementedError):
          load.load(
              name="mnist",
              data_dir=tmp_dir)
        # Enable GCS access so that the dataset will be loaded from GCS.
        with self.gcs_access():
          _, info = load.load(
              name="mnist",
              data_dir=tmp_dir,
              with_info=True)
      self.assertSetEqual(
          set(["dataset_info.json",
               "image.image.json",
               "mnist-test.tfrecord-00000-of-00001",
               "mnist-train.tfrecord-00000-of-00001",
              ]),
          set(tf.io.gfile.listdir(os.path.join(tmp_dir, "mnist/3.0.1"))))

      self.assertEqual(set(info.splits.keys()), set(["train", "test"]))
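For context, the GCS fallback exercised by this test is reachable through the public API via the `try_gcs` flag of `tfds.load`, which reads the already-prepared dataset from the public TFDS bucket. A minimal sketch, not part of the test above:

import tensorflow_datasets as tfds

# `try_gcs=True` lets `tfds.load` read the prepared dataset from the
# public TFDS GCS bucket instead of generating it locally.
ds, info = tfds.load('mnist', split='train', with_info=True, try_gcs=True)
assert set(info.splits) == {'train', 'test'}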
Example #2
def test_builder_code_not_found(code_builder: dataset_builder.DatasetBuilder):
    """If the code isn't found, use files instead."""

    # Patch `tfds.builder_cls` to emulate that the dataset isn't registered
    with mock.patch.object(
            load,
            'builder_cls',
            side_effect=registered.DatasetNotFoundError(code_builder.name),
    ):
        # When the code isn't found, loading a dataset requires an explicit config name:
        # tfds.builder('ds/config')
        config_name = code_builder.name
        if code_builder.builder_config:
            config_name = f'{config_name}/{code_builder.builder_config.name}'

        # Files exist but not code: load from files
        builder = load.builder(config_name)
        assert isinstance(builder, read_only_builder.ReadOnlyBuilder)
        load.load(config_name, split=[])  # Dataset found -> no error

        # Neither code nor files found: raise DatasetNotFoundError
        with pytest.raises(registered.DatasetNotFoundError):
            load.builder(config_name, data_dir='/tmp/non-existing/tfds/dir')

        with pytest.raises(registered.DatasetNotFoundError):
            load.load(config_name,
                      split=[],
                      data_dir='/tmp/non-existing/tfds/dir')
Example #3
def test_builder_code_not_found(code_builder: dataset_builder.DatasetBuilder):
  """If the code isn't found, use files instead."""

  # Patch `tfds.builder_cls` to emulate that the dataset isn't registered
  with mock.patch.object(
      load,
      'builder_cls',
      side_effect=registered.DatasetNotFoundError(code_builder.name),
  ):
    # Files exist but not code: load from files
    builder = load.builder(code_builder.name)
    assert isinstance(builder, read_only_builder.ReadOnlyBuilder)
    load.load(code_builder.name, split=[])  # Dataset found -> no error

    if code_builder.builder_config:
      # When the code isn't found, the default config is inferred from `.config/`
      assert builder.builder_config.name == code_builder.BUILDER_CONFIGS[0].name

      # Explicitly passing a config should work too.
      config_name = f'{code_builder.name}/{code_builder.builder_config.name}'
      builder = load.builder(config_name)
      assert isinstance(builder, read_only_builder.ReadOnlyBuilder)

    # Neither code nor files found: raise DatasetNotFoundError
    with pytest.raises(registered.DatasetNotFoundError):
      load.builder(code_builder.name, data_dir='/tmp/non-existing/tfds/dir')

    with pytest.raises(registered.DatasetNotFoundError):
      load.load(
          code_builder.name, split=[], data_dir='/tmp/non-existing/tfds/dir'
      )
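From the user's side, the behaviour these two tests pin down is: when the builder code is not importable but the generated files exist under `data_dir`, `tfds.builder` falls back to a read-only, file-backed builder. A minimal sketch, where 'my_dataset/my_config' and `my_data_dir` are hypothetical placeholders:

import tensorflow_datasets as tfds

# Sketch only: with no registered code for 'my_dataset', the builder is
# reconstructed from the files previously generated under `my_data_dir`.
builder = tfds.builder('my_dataset/my_config', data_dir='my_data_dir')
ds = builder.as_dataset(split='train')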
Example #4
 def test_invalid_split_dataset(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         with self.assertRaisesWithPredicateMatch(ValueError,
                                                  "`all` is a special"):
             # Raise error during .download_and_prepare()
             load.load(
                 name="invalid_split_dataset",
                 data_dir=tmp_dir,
             )
Example #5
  def test_load_with_config(self):
    data_dir = "foo"
    name = "empty_dataset_builder/bar/k1=1"
    # EmptyDatasetBuilder returns self from as_dataset
    builder = load.load(name=name, split=splits.Split.TEST, data_dir=data_dir)
    expected = dict(data_dir=data_dir, k1=1, config="bar")
    self.assertEqual(expected, builder.kwargs)

    name = "empty_dataset_builder/bar"
    builder = load.load(name=name, split=splits.Split.TEST, data_dir=data_dir)
    self.assertEqual(dict(data_dir=data_dir, config="bar"), builder.kwargs)
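The name string passed to `tfds.load` above does double duty: 'empty_dataset_builder/bar/k1=1' selects the 'bar' config and forwards `k1=1` as a builder keyword argument. A minimal sketch of the forms this test exercises (the dataset name is a test-only placeholder):

import tensorflow_datasets as tfds

# 'ds'             -> default config
# 'ds/config'      -> named config
# 'ds/config/k=v'  -> named config plus builder kwargs
ds = tfds.load('empty_dataset_builder/bar/k1=1', split='test', data_dir='foo')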
Example #6
 def test_show_examples_supervised(self, _):
     with testing.mock_data(num_examples=20):
         ds, ds_info = load.load('imagenet2012',
                                 split='train',
                                 with_info=True,
                                 as_supervised=True)
     visualization.show_examples(ds, ds_info)
Example #7
 def test_load(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
         dataset = load.load(name="dummy_dataset_with_configs",
                             data_dir=tmp_dir,
                             download=True,
                             split=splits_lib.Split.TRAIN)
         data = list(dataset_utils.as_numpy(dataset))
         self.assertEqual(20, len(data))
         self.assertLess(data[0]["x"], 30)
Example #8
 def test_mocking_lm1b(self):
   with mocking.mock_data():
     ds = load.load('lm1b/bytes', split='train')
     self.assertEqual(ds.element_spec, {
         'text': tf.TensorSpec(shape=(None,), dtype=tf.int64),
     })
     for ex in ds.take(10):
       self.assertEqual(ex['text'].dtype, tf.int64)
       ex['text'].shape.assert_is_compatible_with((None,))
Example #9
 def test_mocking_imagenet(self):
   with mocking.mock_data():
     ds = load.load('imagenet2012', split='train')
     self.assertEqual(ds.element_spec, {
         'file_name': tf.TensorSpec(shape=(), dtype=tf.string),
         'image': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
         'label': tf.TensorSpec(shape=(), dtype=tf.int64),
     })
     list(ds.take(3))  # Iteration should work
Example #10
def _as_df(ds_name: str, **kwargs) -> pandas.DataFrame:
    """Loads the dataset as `pandas.DataFrame`."""
    with testing.mock_data(num_examples=3):
        ds, ds_info = load.load(ds_name,
                                split='train',
                                with_info=True,
                                **kwargs)
    df = as_dataframe.as_dataframe(ds, ds_info)
    return df
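A sketch of how the helper above might be called in a test body ('mnist' stands in for any mocked dataset):

df = _as_df('mnist')
assert len(df) == 3  # one row per mocked example, one column per feature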
Example #11
    def test_show_examples_graph_with_colors_and_labels(self, _):
        with testing.mock_data(num_examples=20):
            ds, ds_info = load.load('ogbg_molpcba',
                                    split='train',
                                    with_info=True)

        # Dictionaries used to map nodes and edges to colors.
        atomic_numbers_to_elements = {
            6: 'C',
            7: 'N',
            8: 'O',
            9: 'F',
            14: 'Si',
            15: 'P',
            16: 'S',
            17: 'Cl',
            35: 'Br',
        }
        elements_to_colors = {
            element: f'C{index}'
            for index, element in enumerate(
                atomic_numbers_to_elements.values())
        }
        bond_types_to_colors = {num: f'C{num}' for num in range(4)}

        # Node colors are atomic numbers.
        def node_color_fn(graph):
            atomic_numbers = 1 + graph['node_feat'][:, 0].numpy()
            return {
                index:
                elements_to_colors[atomic_numbers_to_elements[atomic_number]]
                for index, atomic_number in enumerate(atomic_numbers)
            }

        # Node labels are element names.
        def node_label_fn(graph):
            atomic_numbers = 1 + graph['node_feat'][:, 0].numpy()
            return {
                index: atomic_numbers_to_elements[atomic_number]
                for index, atomic_number in enumerate(atomic_numbers)
            }

        # Edge colors are bond types.
        def edge_color_fn(graph):
            bonds = graph['edge_index'].numpy()
            bond_types = graph['edge_feat'][:, 0].numpy()
            return {
                tuple(bond): bond_types_to_colors[bond_type]
                for bond, bond_type in zip(bonds, bond_types)
            }

        visualization.show_examples(ds,
                                    ds_info,
                                    node_color_fn=node_color_fn,
                                    node_label_fn=node_label_fn,
                                    edge_color_fn=edge_color_fn)
Example #12
    def test_load(self):
        name = "empty_dataset_builder/k1=1"
        data_dir = "foo"
        as_dataset_kwargs = dict(a=1, b=2)

        # EmptyDatasetBuilder returns self from as_dataset
        builder = load.load(name=name,
                            split=splits.Split.TEST,
                            data_dir=data_dir,
                            download=False,
                            as_dataset_kwargs=as_dataset_kwargs)
        self.assertTrue(builder.as_dataset_called)
        self.assertFalse(builder.download_called)
        self.assertEqual(splits.Split.TEST,
                         builder.as_dataset_kwargs.pop("split"))
        self.assertIsNone(builder.as_dataset_kwargs.pop("batch_size"))
        self.assertFalse(builder.as_dataset_kwargs.pop("as_supervised"))
        self.assertFalse(builder.as_dataset_kwargs.pop("decoders"))
        self.assertIsNone(builder.as_dataset_kwargs.pop("read_config"))
        self.assertFalse(builder.as_dataset_kwargs.pop("shuffle_files"))
        self.assertEqual(builder.as_dataset_kwargs, as_dataset_kwargs)
        self.assertEqual(dict(data_dir=data_dir, k1=1), builder.kwargs)

        builder = load.load(name,
                            split=splits.Split.TRAIN,
                            data_dir=data_dir,
                            download=True,
                            as_dataset_kwargs=as_dataset_kwargs)
        self.assertTrue(builder.as_dataset_called)
        self.assertTrue(builder.download_called)

        # Tests for different batch_size
        # By default batch_size=None
        builder = load.load(name=name,
                            split=splits.Split.TEST,
                            data_dir=data_dir)
        self.assertIsNone(builder.as_dataset_kwargs.pop("batch_size"))
        # Setting batch_size=1
        builder = load.load(name=name,
                            split=splits.Split.TEST,
                            data_dir=data_dir,
                            batch_size=1)
        self.assertEqual(1, builder.as_dataset_kwargs.pop("batch_size"))
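The `batch_size` semantics checked here match the public API: `None` (the default) yields one example per element, while an integer pre-batches the pipeline. A minimal sketch against a real dataset:

import tensorflow_datasets as tfds

ds = tfds.load('mnist', split='train')  # unbatched: one example per element
ds_batched = tfds.load('mnist', split='train', batch_size=32)  # leading batch dim of 32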
Example #13
  def test_custom_as_dataset(self):
    def _as_dataset(self, *args, **kwargs):  # pylint: disable=unused-argument
      return tf.data.Dataset.from_generator(
          lambda: ({  # pylint: disable=g-long-lambda
              'text': t,
          } for t in ['some sentence', 'some other sentence']),
          output_types=self.info.features.dtype,
          output_shapes=self.info.features.shape,
      )

    with mocking.mock_data(as_dataset_fn=_as_dataset):
      ds = load.load('lm1b', split='train')
      out = [ex['text'] for ex in dataset_utils.as_numpy(ds)]
      self.assertEqual(out, [b'some sentence', b'some other sentence'])
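Because `as_dataset_fn` replaces the builder's `as_dataset` wholesale, any `tf.data.Dataset` can be injected, which is what makes assertions on exact values possible. A hypothetical variant of the function above, assuming the same `tf` import as the test module:

def _as_dataset(self, *args, **kwargs):  # hypothetical variant; self is unused
  del args, kwargs  # unused, as in the test above
  return tf.data.Dataset.from_tensor_slices(
      {'text': ['some sentence', 'some other sentence']})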
Example #14
    def test_load_data_dir(self):
        """Ensure that `tfds.load` also supports multiple data_dir."""
        constants.add_data_dir(self.other_data_dir)

        class MultiDirDataset(DummyDatasetSharedGenerator):  # pylint: disable=unused-variable
            VERSION = utils.Version("1.2.0")

        data_dir = os.path.join(self.other_data_dir, "multi_dir_dataset",
                                "1.2.0")
        tf.io.gfile.makedirs(data_dir)

        with mock.patch.object(dataset_info.DatasetInfo,
                               "read_from_directory"):
            _, info = load.load("multi_dir_dataset", split=[], with_info=True)
        self.assertEqual(info.data_dir, data_dir)
Example #15
 def test_max_values(self):
   with mocking.mock_data(num_examples=50):
     ds = load.load('mnist', split='train')
     self.assertEqual(ds.element_spec, {
         'image': tf.TensorSpec(shape=(28, 28, 1), dtype=tf.uint8),
         'label': tf.TensorSpec(shape=(), dtype=tf.int64),
     })
     for ex in ds.take(50):
       self.assertLessEqual(tf.math.reduce_max(ex['label']).numpy(), 10)
     self.assertEqual(  # Test determinism
         [ex['label'].numpy() for ex in ds.take(5)],
         [1, 9, 2, 5, 3],
     )
     self.assertEqual(  # Iterating twice should yield the same samples
         [ex['label'].numpy() for ex in ds.take(5)],
         [1, 9, 2, 5, 3],
     )
Example #16
    def test_nested_sequence(self):
        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            ds_train, ds_info = load.load(name="nested_sequence_builder",
                                          data_dir=tmp_dir,
                                          split="train",
                                          with_info=True,
                                          shuffle_files=False)
            ex0, ex1, ex2 = [
                ex["frames"]["coordinates"]
                for ex in dataset_utils.as_numpy(ds_train)
            ]
            self.assertAllEqual(
                ex0,
                tf.ragged.constant([
                    [[0, 1], [2, 3], [4, 5]],
                    [],
                    [[6, 7]],
                ],
                                   inner_shape=(2, )))
            self.assertAllEqual(ex1, tf.ragged.constant([], ragged_rank=1))
            self.assertAllEqual(
                ex2,
                tf.ragged.constant([
                    [[10, 11]],
                    [[12, 13], [14, 15]],
                ],
                                   inner_shape=(2, )))

            self.assertEqual(
                ds_info.features.dtype,
                {"frames": {
                    "coordinates": tf.int32
                }},
            )
            self.assertEqual(
                ds_info.features.shape,
                {"frames": {
                    "coordinates": (None, None, 2)
                }},
            )
            nested_tensor_info = ds_info.features.get_tensor_info()
            self.assertEqual(
                nested_tensor_info["frames"]["coordinates"].sequence_rank,
                2,
            )
Example #17
 def test_mocking_imagenet_decoders(self):
   with mocking.mock_data():
     ds, ds_info = load.load(
         'imagenet2012',
         split='train',
         decoders={'image': decode.SkipDecoding()},
         with_info=True,
     )
     self.assertEqual(ds.element_spec, {
         'file_name': tf.TensorSpec(shape=(), dtype=tf.string),
         'image': tf.TensorSpec(shape=(), dtype=tf.string),  # Encoded images
         'label': tf.TensorSpec(shape=(), dtype=tf.int64),
     })
     for ex in ds.take(10):
        # Image decoding should work
       image = ds_info.features['image'].decode_example(ex['image'])
       image.shape.assert_is_compatible_with((None, None, 3))
       self.assertEqual(image.dtype, tf.uint8)
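`decode.SkipDecoding` keeps the feature as its serialized bytes so decoding can be done selectively later; the public spelling is `tfds.decode.SkipDecoding`. A minimal sketch, assuming the dataset is available locally:

import tensorflow_datasets as tfds

ds, info = tfds.load(
    'imagenet2012',
    split='train',
    decoders={'image': tfds.decode.SkipDecoding()},  # keep encoded bytes
    with_info=True,
)
# Decode on demand, exactly as the test does:
# image = info.features['image'].decode_example(example['image'])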
Example #18
 def test_load_all_splits(self):
     name = "empty_dataset_builder"
     # EmptyDatasetBuilder returns self from as_dataset
     builder = load.load(name=name, data_dir="foo")
     self.assertTrue(builder.as_dataset_called)
     self.assertIsNone(builder.as_dataset_kwargs.pop("split"))
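When `split` is omitted, `tfds.load` returns a dict keyed by split name rather than a single dataset, which is why the forwarded `split` kwarg is asserted to be `None` here. A minimal sketch:

import tensorflow_datasets as tfds

ds_dict = tfds.load('mnist')  # no split -> dict of all splits
train_ds, test_ds = ds_dict['train'], ds_dict['test']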
Example #19
 def test_show_examples_missing_sample(self, _):
     with testing.mock_data(num_examples=3):
         ds, ds_info = load.load('imagenet2012',
                                 split='train',
                                 with_info=True)
     visualization.show_examples(ds.take(3), ds_info)
Example #20
 def test_show_examples_graph(self, _):
     with testing.mock_data(num_examples=20):
         ds, ds_info = load.load('ogbg_molpcba',
                                 split='train',
                                 with_info=True)
     visualization.show_examples(ds, ds_info)