Пример #1
0
    def test_split_overwrite(self):
        """Overwriting `shard_lengths` on a split is reflected in its proto."""
        reference = splits.SplitDict("ds_name")
        reference.add(tfds.core.SplitInfo(name="train", shard_lengths=[15]))

        mutated = splits.SplitDict("ds_name")
        mutated.add(tfds.core.SplitInfo(name="train", shard_lengths=[15]))

        # Both dicts were built from the same description, so they compare equal.
        self.assertTrue(splits.check_splits_equals(reference, mutated))

        # Overwrite shard_lengths on one of them: the new value must be visible
        # both on the split itself and on the proto it exposes.
        updated_lengths = [5, 5, 5]
        mutated["train"].shard_lengths = updated_lengths
        self.assertEqual(mutated["train"].shard_lengths, updated_lengths)
        self.assertEqual(
            mutated["train"].get_proto().shard_lengths, updated_lengths
        )

        # After the overwrite the two dicts must no longer compare equal.
        self.assertFalse(splits.check_splits_equals(reference, mutated))
Пример #2
0
  def test_split_overwrite(self):
    """Overwriting `num_shards` on a split is reflected in its proto."""
    reference = splits.SplitDict()
    reference.add(tfds.core.SplitInfo(name="train", num_shards=15))

    mutated = splits.SplitDict()
    mutated.add(tfds.core.SplitInfo(name="train", num_shards=15))

    # Both dicts were built from the same description, so they compare equal.
    self.assertTrue(splits.check_splits_equals(reference, mutated))

    # Overwrite num_shards on one of them: the new value must be visible both
    # on the split itself and on the proto it exposes.
    updated_num_shards = 10
    mutated["train"].num_shards = updated_num_shards
    self.assertEqual(mutated["train"].num_shards, updated_num_shards)
    self.assertEqual(
        mutated["train"].get_proto().num_shards, updated_num_shards)

    # After the overwrite the two dicts must no longer compare equal.
    self.assertFalse(splits.check_splits_equals(reference, mutated))
Пример #3
0
    def test_check_splits_equals(self):
        """`check_splits_equals` compares split names and shard counts."""

        def build_split_dict(*names_and_shards):
            # Creates a "ds_name" SplitDict and adds one SplitInfo per
            # (name, num_shards) pair, in the order given.
            split_dict = splits.SplitDict("ds_name")
            for split_name, shard_count in names_and_shards:
                split_dict.add(
                    tfds.core.SplitInfo(name=split_name, num_shards=shard_count)
                )
            return split_dict

        base = build_split_dict(("train", 10), ("test", 3))
        same_as_base = build_split_dict(("train", 10), ("test", 3))
        extra_split = build_split_dict(("train", 10), ("test", 3), ("valid", 0))
        other_shards = build_split_dict(("train", 11), ("test", 3))

        # Equal: the very same object, and a separately built identical dict.
        self.assertTrue(splits.check_splits_equals(base, base))
        self.assertTrue(splits.check_splits_equals(base, same_as_base))

        # Not equal: an additional split name is present.
        self.assertFalse(splits.check_splits_equals(base, extra_split))
        # Not equal: "train" has a different number of shards.
        self.assertFalse(splits.check_splits_equals(base, other_shards))
Пример #4
0
  def update_splits_if_different(self, split_dict):
    """Replace the stored splits unless they already match `split_dict`.

    * When no splits are defined yet, or the current ones differ from
      `split_dict` (ex: different number of shards), `split_dict` becomes the
      new split dict. This will trigger stats computation during
      download_and_prepare.
    * When splits are already defined and similar (same names and shards),
      the existing ones are kept, since the restored split contains the
      statistics (restored from GCS or file).

    Args:
      split_dict: `tfds.core.SplitDict`, the candidate replacement splits.
    """
    assert isinstance(split_dict, splits_lib.SplitDict)

    # An update is needed when there is nothing stored yet, or when what is
    # stored does not compare equal to the candidate.
    needs_update = (
        not self._splits
        or not splits_lib.check_splits_equals(self._splits, split_dict))
    if needs_update:
      self._set_splits(split_dict)