def _test_generator_based_builder(self, builder_cls):
        with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = builder_cls(data_dir=tmp_dir)
            builder.download_and_prepare()
            train_dataset = builder.as_dataset(
                split=dataset_builder.Split.TRAIN)
            valid_dataset = builder.as_dataset(
                split=dataset_builder.Split.VALIDATION)
            test_dataset = builder.as_dataset(split=dataset_builder.Split.TEST)

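            # The dummy examples satisfy y == -x and int(z) == x; check that x
            # stays within [min_val, max_val) and, optionally, covers the range.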
            def validate_dataset(dataset, min_val, max_val, test_range=False):
                els = []
                for el in dataset:
                    x, y, z = el["x"].numpy(), el["y"].numpy(), el["z"].numpy()
                    self.assertEqual(-x, y)
                    self.assertEqual(x, int(z))
                    self.assertGreaterEqual(x, min_val)
                    self.assertLess(x, max_val)
                    els.append(x)
                if test_range:
                    self.assertEqual(list(range(min_val, max_val)),
                                     sorted(els))

            validate_dataset(train_dataset, 0, 30)
            validate_dataset(valid_dataset, 0, 30)
            validate_dataset(test_dataset, 30, 40, True)
Example No. 2
  def test_file_backed_with_args(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      # Set all the args to non-default values, including Tokenizer
      tokenizer = text_encoder.Tokenizer(
          reserved_tokens=['<FOOBAR>'], alphanum_only=False)
      encoder = text_encoder.TokenTextEncoder(
          vocab_list=['hi', 'bye', ZH_HELLO],
          lowercase=True,
          oov_buckets=2,
          oov_token='ZOO',
          tokenizer=tokenizer)

      vocab_fname = os.path.join(tmp_dir, 'vocab')
      encoder.save_to_file(vocab_fname)

      file_backed_encoder = text_encoder.TokenTextEncoder.load_from_file(
          vocab_fname)
      self.assertEqual(encoder.tokens, file_backed_encoder.tokens)
      self.assertEqual(encoder.vocab_size, file_backed_encoder.vocab_size)
      self.assertEqual(encoder.lowercase, file_backed_encoder.lowercase)
      self.assertEqual(encoder.oov_token, file_backed_encoder.oov_token)
      self.assertEqual(encoder.tokenizer.alphanum_only,
                       file_backed_encoder.tokenizer.alphanum_only)
      self.assertEqual(encoder.tokenizer.reserved_tokens,
                       file_backed_encoder.tokenizer.reserved_tokens)
Example No. 3
    def test_shared_generator(self):
        with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()

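            # Collect every file the builder actually wrote to its data_dir.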
            written_filepaths = [
                os.path.join(builder._data_dir, fname)
                for fname in tf.gfile.ListDirectory(builder._data_dir)
            ]
            # The expected files are the recorded filepaths of every split.
            expected_filepaths = []
            for split in builder.splits:
                expected_filepaths.extend(split.filepaths)
            self.assertEqual(sorted(expected_filepaths),
                             sorted(written_filepaths))

            splits = [dataset_builder.Split.TRAIN, dataset_builder.Split.TEST]
            datasets = [builder.as_dataset(split=split) for split in splits]
            data = [[el["x"].numpy() for el in dataset]
                    for dataset in datasets]

            train_data, test_data = data
            self.assertEqual(20, len(train_data))
            self.assertEqual(10, len(test_data))
            self.assertEqual(list(range(30)), sorted(train_data + test_data))
Example No. 4
    def test_shared_generator(self):
        with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()

            written_filepaths = [
                os.path.join(builder._data_dir, fname)
                for fname in tf.io.gfile.listdir(builder._data_dir)
            ]
            # Expected: the split files plus the dataset_info.json metadata file.
            expected_filepaths = builder._build_split_filenames(
                split_info_list=builder.info.splits.values())
            expected_filepaths.append(
                os.path.join(builder._data_dir, "dataset_info.json"))
            self.assertEqual(sorted(expected_filepaths),
                             sorted(written_filepaths))

            splits_list = [splits_lib.Split.TRAIN, splits_lib.Split.TEST]
            train_data, test_data = [[
                el["x"] for el in dataset_utils.dataset_as_numpy(
                    builder.as_dataset(split=split))
            ] for split in splits_list]

            self.assertEqual(20, len(train_data))
            self.assertEqual(10, len(test_data))
            self.assertEqual(list(range(30)), sorted(train_data + test_data))

            # Builder's info should also have the above information.
            self.assertTrue(builder.info.initialized)
            self.assertEqual(
                20, builder.info.splits[splits_lib.Split.TRAIN].num_examples)
            self.assertEqual(
                10, builder.info.splits[splits_lib.Split.TEST].num_examples)
            self.assertEqual(30, builder.info.splits.total_num_examples)
Example No. 5
  def test_with_configs(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder1 = DummyDatasetWithConfigs(config="plus1", data_dir=tmp_dir)
      builder2 = DummyDatasetWithConfigs(config="plus2", data_dir=tmp_dir)
      # Test that builder.builder_config is the correct config
      self.assertIs(builder1.builder_config,
                    DummyDatasetWithConfigs.builder_configs["plus1"])
      self.assertIs(builder2.builder_config,
                    DummyDatasetWithConfigs.builder_configs["plus2"])
      builder1.download_and_prepare()
      builder2.download_and_prepare()
      data_dir1 = os.path.join(tmp_dir, builder1.name, "plus1", "0.0.1")
      data_dir2 = os.path.join(tmp_dir, builder2.name, "plus2", "0.0.2")
      # Test that subdirectories were created per config
      self.assertTrue(tf.gfile.Exists(data_dir1))
      self.assertTrue(tf.gfile.Exists(data_dir2))
      # 2 train shards, 1 test shard, plus metadata files
      self.assertGreater(len(tf.gfile.ListDirectory(data_dir1)), 3)
      self.assertGreater(len(tf.gfile.ListDirectory(data_dir2)), 3)

      # Test that the config was used and they didn't collide.
      splits_list = [splits_lib.Split.TRAIN, splits_lib.Split.TEST]
      for builder, incr in [(builder1, 1), (builder2, 2)]:
        train_data, test_data = [
            [el["x"] for el in
             dataset_utils.dataset_as_numpy(builder.as_dataset(split=split))]
            for split in splits_list
        ]

        self.assertEqual(20, len(train_data))
        self.assertEqual(10, len(test_data))
        self.assertEqual([incr + el for el in range(30)],
                         sorted(train_data + test_data))
Example No. 6
  def test_determinism(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      ds = registered.load(
          name="dummy_dataset_shared_generator",
          data_dir=tmp_dir,
          split=splits_lib.Split.TRAIN,
          as_dataset_kwargs=dict(shuffle_files=False))
      ds_values = list(dataset_utils.dataset_as_numpy(ds))

      # Ensure determinism. If this test fails, it means that the numpy random
      # module isn't always deterministic (maybe across versions, architectures,
      # ...), and so our datasets aren't guaranteed to be either.
      l = list(range(20))
      np.random.RandomState(42).shuffle(l)
      self.assertEqual(l, [
          0, 17, 15, 1, 8, 5, 11, 3, 18, 16, 13, 2, 9, 19, 4, 12, 7, 10, 14, 6
      ])

      # Ensure determinism. If this test fails, it means the dataset is not
      # deterministically generated.
      self.assertEqual(
          [e["x"] for e in ds_values],
          [24, 1, 3, 4, 15, 25, 0, 16, 21, 10, 6, 13, 27, 22, 12, 28, 9, 19,
           18, 7],
      )
Example No. 7
    def test_download(self):
        with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
            dm = download_manager.DownloadManager(tmp_dir)

            output_filenames = ["a", "b", "c"]
            urls = ["http://%s.com/foo" % fname for fname in output_filenames]

            retrieved = []

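            # Fake urlretrieve: record the requested URL and write the expected
            # output filename as the downloaded file's contents.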
            def mock_urlretrieve(url, filepath):
                retrieved.append(url)
                with tf.gfile.Open(filepath, "w") as f:
                    f.write(output_filenames[len(retrieved) - 1])

            with tf.test.mock.patch("six.moves.urllib.request.urlretrieve",
                                    mock_urlretrieve):
                output_paths = dm.download(urls,
                                           output_filenames,
                                           num_threads=1)
                self.assertEqual(retrieved, urls)
                self.assertEqual([
                    os.path.join(tmp_dir, fname) for fname in output_filenames
                ], output_paths)
                for fname, path in zip(output_filenames, output_paths):
                    with tf.gfile.Open(path) as f:
                        self.assertEqual(fname, f.read())
Example No. 8
  def test_invalid_split_dataset(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      with self.assertRaisesWithPredicateMatch(ValueError, "ALL is a special"):
        # Raise error during .download_and_prepare()
        registered.load(
            name="invalid_split_dataset",
            data_dir=tmp_dir,
        )
Example No. 9
  def test_load(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      dataset = registered.load(name="dummy_dataset_shared_generator",
                                data_dir=tmp_dir,
                                download=True,
                                split=dataset_builder.Split.TRAIN)
      data = list(dataset)
      self.assertEqual(20, len(data))
Example No. 10
  def test_numpy_iterator(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
      builder.download_and_prepare()
      items = []
      for item in builder.numpy_iterator(split=splits.Split.TRAIN):
        items.append(item)
      self.assertEqual(20, len(items))
Example No. 11
  def test_load(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      dataset = registered.load(name="dummy_dataset_shared_generator",
                                data_dir=tmp_dir,
                                download=True,
                                split=splits_lib.Split.TRAIN)
      data = list(dataset_utils.dataset_as_numpy(dataset))
      self.assertEqual(20, len(data))
      self.assertLess(data[0]["x"], 30)
Example No. 12
  def test_file_backed(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      vocab_fname = os.path.join(tmp_dir, 'vocab')
      encoder = text_encoder.TokenTextEncoder(
          vocab_list=['hi', 'bye', ZH_HELLO])
      encoder.save_to_file(vocab_fname)
      file_backed_encoder = text_encoder.TokenTextEncoder.load_from_file(
          vocab_fname)
      self.assertEqual(encoder.tokens, file_backed_encoder.tokens)
Example No. 13
  def test_get_data_dir_with_config(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      config_name = "plus1"
      builder = DummyDatasetWithConfigs(config=config_name, data_dir=tmp_dir)

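      # With a config, the layout is <data_dir>/<name>/<config_name>/<version>.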
      builder_data_dir = os.path.join(tmp_dir, builder.name, config_name)
      version_data_dir = os.path.join(builder_data_dir, "0.0.1")

      tf.gfile.MakeDirs(version_data_dir)
      self.assertEqual(builder._build_data_dir(), version_data_dir)
Example No. 14
  def test_file_backed(self, additional_tokens):
    encoder = text_encoder.ByteTextEncoder(additional_tokens=additional_tokens)
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      vocab_fname = os.path.join(tmp_dir, 'vocab')
      encoder.save_to_file(vocab_fname)

      file_backed_encoder = text_encoder.ByteTextEncoder.load_from_file(
          vocab_fname)
      self.assertEqual(encoder.vocab_size, file_backed_encoder.vocab_size)
      self.assertEqual(encoder.additional_tokens,
                       file_backed_encoder.additional_tokens)
Example No. 15
  def test_config_construction(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      self.assertSetEqual(
          set(["plus1", "plus2"]),
          set(DummyDatasetWithConfigs.builder_configs.keys()))
      plus1_config = DummyDatasetWithConfigs.builder_configs["plus1"]
      builder = DummyDatasetWithConfigs(config="plus1", data_dir=tmp_dir)
      self.assertIs(plus1_config, builder.builder_config)
      builder = DummyDatasetWithConfigs(config=plus1_config, data_dir=tmp_dir)
      self.assertIs(plus1_config, builder.builder_config)
      self.assertIs(builder.builder_config,
                    DummyDatasetWithConfigs.BUILDER_CONFIGS[0])
Example No. 16
    def test_statistics_generation_variable_sizes(self):
        with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = RandomShapedImageGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()

            # Check the feature's entry in the generated schema.
            schema_feature = builder.info.as_proto.schema.feature[0]
            self.assertEqual("im", schema_feature.name)

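            # A size of -1 in the schema marks a dimension of variable size.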
            self.assertEqual(-1, schema_feature.shape.dim[0].size)
            self.assertEqual(-1, schema_feature.shape.dim[1].size)
            self.assertEqual(3, schema_feature.shape.dim[2].size)
Example No. 17
    def test_statistics_generation(self):
        with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
            builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
            builder.download_and_prepare()

            # Overall
            self.assertEqual(30, builder.info.splits.total_num_examples)

            # Per split.
            test_split = builder.info.splits["test"].get_proto()
            train_split = builder.info.splits["train"].get_proto()
            self.assertEqual(10, test_split.statistics.num_examples)
            self.assertEqual(20, train_split.statistics.num_examples)
Example No. 18
  def test_multi_split(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      ds_train, ds_test = registered.load(
          name="dummy_dataset_shared_generator",
          data_dir=tmp_dir,
          split=[splits_lib.Split.TRAIN, splits_lib.Split.TEST],
          as_dataset_kwargs=dict(shuffle_files=False))

      data = list(dataset_utils.as_numpy(ds_train))
      self.assertEqual(20, len(data))

      data = list(dataset_utils.as_numpy(ds_test))
      self.assertEqual(10, len(data))
Example No. 19
    def test_save_load_metadata(self):
        text_f = features.Text(encoder=text_encoder.ByteTextEncoder(
            additional_tokens=['HI']))
        text = u'HI 你好'
        ids = text_f.str2ints(text)
        self.assertEqual(1, ids[0])

        with test_utils.tmp_dir(self.get_temp_dir()) as data_dir:
            feature_name = 'dummy'
            text_f.save_metadata(data_dir, feature_name)

            new_f = features.Text()
            new_f.load_metadata(data_dir, feature_name)
            self.assertEqual(ids, new_f.str2ints(text))
Example No. 20
  def test_build_data_dir(self):
    # Test that the dataset loads the data_dir for the builder's version
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
      self.assertEqual(str(builder.info.version), "1.0.0")
      builder_data_dir = os.path.join(tmp_dir, builder.name)
      version_dir = os.path.join(builder_data_dir, "1.0.0")

      # The dataset folder contains multiple other versions
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "14.0.0.invalid"))
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "10.0.0"))
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "9.0.0"))
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "0.1.0"))

      # The builder's version dir is chosen
      self.assertEqual(builder._build_data_dir(), version_dir)
Example No. 21
  def test_get_data_dir(self):
    # Test that the dataset loads the most recent dir
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
      builder_data_dir = os.path.join(tmp_dir, builder.name)

      # The dataset folder contains multiple versions
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "14.0.0.invalid"))
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "10.0.0"))
      tf.gfile.MakeDirs(os.path.join(builder_data_dir, "9.0.0"))

      # The last valid version is chosen by default
      most_recent_dir = os.path.join(builder_data_dir, "10.0.0")
      v9_dir = os.path.join(builder_data_dir, "9.0.0")
      self.assertEqual(builder._get_data_dir(), most_recent_dir)
      self.assertEqual(builder._get_data_dir(version="9.0.0"), v9_dir)
Example No. 22
  def test_save_load(self):
    labels1 = features.ClassLabel(names=['label3', 'label1', 'label2'])
    labels2 = features.ClassLabel(num_classes=None)
    labels3 = features.ClassLabel(num_classes=1)

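    # Save labels1's metadata, then load it into labels2 (names unset) and
    # labels3 (whose num_classes conflicts with the saved names).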
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      labels1.save_metadata(tmp_dir, 'test-labels')
      labels2.load_metadata(tmp_dir, 'test-labels')
      with self.assertRaisesWithPredicateMatch(
          ValueError, 'number of names do not match the defined num_classes'):
        labels3.load_metadata(tmp_dir, 'test-labels')

    # labels2 should have been copied from labels1
    self.assertEqual(3, labels2.num_classes)
    self.assertEqual(labels2.names, [
        'label3',
        'label1',
        'label2',
    ])
Example No. 23
  def test_writing(self):
    # First, read existing dataset info from the test data directory.
    info = dataset_info.DatasetInfo()
    info.read_from_directory(_TESTDATA)

    # Load the existing dataset_info JSON file.
    with tf.gfile.Open(info._dataset_info_filename(_TESTDATA)) as f:
      existing_json = json.load(f)

    # Now write to a temp directory.
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      info.write_to_directory(tmp_dir)

      # Load the newly written JSON file.
      with tf.gfile.Open(info._dataset_info_filename(tmp_dir)) as f:
        new_json = json.load(f)

    # Assert that what was read, written out, and read back is identical.
    self.assertEqual(existing_json, new_json)