def test_split_overwrite(self):
    """Assigning `shard_lengths` on a split must be reflected in its proto."""
    reference = splits.SplitDict("ds_name")
    reference_train = tfds.core.SplitInfo(name="train", shard_lengths=[15])
    reference.add(reference_train)

    mutated = splits.SplitDict("ds_name")
    mutated_train = tfds.core.SplitInfo(name="train", shard_lengths=[15])
    mutated.add(mutated_train)

    # Both dicts start out with an identical single "train" split.
    self.assertTrue(splits.check_splits_equals(reference, mutated))

    # Overwriting shard_lengths should also update the underlying proto.
    mutated["train"].shard_lengths = [5, 5, 5]
    self.assertEqual(mutated["train"].shard_lengths, [5, 5, 5])
    train_proto = mutated["train"].get_proto()
    self.assertEqual(train_proto.shard_lengths, [5, 5, 5])

    # After the overwrite the two dicts must no longer compare equal.
    self.assertFalse(splits.check_splits_equals(reference, mutated))
def test_split_overwrite(self):
    """Assigning `num_shards` on a split must be reflected in its proto."""
    reference = splits.SplitDict()
    reference_train = tfds.core.SplitInfo(name="train", num_shards=15)
    reference.add(reference_train)

    mutated = splits.SplitDict()
    mutated_train = tfds.core.SplitInfo(name="train", num_shards=15)
    mutated.add(mutated_train)

    # Both dicts start out with an identical single "train" split.
    self.assertTrue(splits.check_splits_equals(reference, mutated))

    # Overwriting num_shards should also update the underlying proto.
    mutated["train"].num_shards = 10
    self.assertEqual(mutated["train"].num_shards, 10)
    train_proto = mutated["train"].get_proto()
    self.assertEqual(train_proto.num_shards, 10)

    # After the overwrite the two dicts must no longer compare equal.
    self.assertFalse(splits.check_splits_equals(reference, mutated))
def test_check_splits_equals(self):
    """Exercise `check_splits_equals` on equal and differing split dicts."""

    def make_split_dict(*shards_by_name):
        # Build a "ds_name" SplitDict, adding one SplitInfo per
        # (name, num_shards) pair in the order given.
        split_dict = splits.SplitDict("ds_name")
        for split_name, shard_count in shards_by_name:
            split_dict.add(
                tfds.core.SplitInfo(name=split_name, num_shards=shard_count))
        return split_dict

    base = make_split_dict(("train", 10), ("test", 3))
    identical = make_split_dict(("train", 10), ("test", 3))
    extra_split = make_split_dict(("train", 10), ("test", 3), ("valid", 0))
    other_shards = make_split_dict(("train", 11), ("test", 3))

    # A dict equals itself and any dict with the same names and shards.
    self.assertTrue(splits.check_splits_equals(base, base))
    self.assertTrue(splits.check_splits_equals(base, identical))

    # An additional "valid" split means the names differ.
    self.assertFalse(splits.check_splits_equals(base, extra_split))
    # Same names, but "train" has a different number of shards.
    self.assertFalse(splits.check_splits_equals(base, other_shards))
def update_splits_if_different(self, split_dict):
    """Replace the stored splits unless they already match `split_dict`.

    * When no splits are defined yet, or the defined ones differ from the
      new ones (ex: different number of shards), the new split dict is
      adopted. This will trigger stats computation during
      download_and_prepare.
    * When splits are already defined in DatasetInfo and are similar (same
      names and shards), the restored splits are kept, since they contain
      the statistics (restored from GCS or file).

    Args:
      split_dict: `tfds.core.SplitDict`, the new split
    """
    assert isinstance(split_dict, splits_lib.SplitDict)

    # Overwrite only when nothing is stored yet or the stored splits differ;
    # identical splits are left untouched.
    if not self._splits or not splits_lib.check_splits_equals(
            self._splits, split_dict):
        self._set_splits(split_dict)