def __init__(
    self,
    example_specs,
    path,
    hash_salt,
    disable_shuffling: bool,
    file_format=file_adapters.DEFAULT_FILE_FORMAT,
):
  """Initializes Writer.

  Args:
    example_specs: spec used to build the ExampleSerializer.
    path (str): path where records should be written in.
    hash_salt (str or bytes): salt used when hashing keys.
    disable_shuffling (bool): whether record shuffling is disabled.
    file_format (FileFormat): format of the record files in which the dataset
      should be written in.
  """
  self._path = path
  self._file_format = file_format
  self._num_examples = 0
  self._example_specs = example_specs
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  # The shuffler stores its temporary buckets next to the output file.
  shuffler_dir = os.path.dirname(path)
  self._shuffler = shuffle.Shuffler(shuffler_dir, hash_salt, disable_shuffling)
def test_all_mem(self):
  """Checks reported size and iteration order when all items fit in memory."""
  sh = shuffle.Shuffler(self.get_temp_dir())
  for key, payload in _ITEMS:
    sh.add(key, payload)
  self.assertEqual(sh.size, _TOTAL_SIZE)
  self.assertEqual(list(sh), _ORDERED_ITEMS)
def _test_items(self, salt, expected_order):
  """Adds all _ITEMS under `salt` and checks size and shuffled order."""
  sh = shuffle.Shuffler(self.get_temp_dir(), salt)
  for key, payload in _ITEMS:
    sh.add(key, payload)
  self.assertEqual(sh.size, _TOTAL_SIZE)
  self.assertEqual(list(sh), expected_order)
def __init__(
    self,
    example_specs,
    filename_template: naming.ShardedFileTemplate,
    hash_salt,
    disable_shuffling: bool,
    file_format=file_adapters.DEFAULT_FILE_FORMAT,
):
  """Initializes Writer.

  Args:
    example_specs: spec used to build the ExampleSerializer.
    filename_template: template used to format sharded filenames.
    hash_salt (str or bytes): salt used when hashing keys.
    disable_shuffling (bool): whether record shuffling is disabled.
    file_format (FileFormat): format of the record files in which the dataset
      should be written in.
  """
  self._filename_template = filename_template
  self._file_format = file_format
  self._num_examples = 0
  self._example_specs = example_specs
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._shuffler = shuffle.Shuffler(
      dirpath=filename_template.data_dir,
      hash_salt=hash_salt,
      disable_shuffling=disable_shuffling,
  )
def test_duplicate_key(self):
  """A second record added under an existing key raises DuplicatedKeysError."""
  sh = shuffle.Shuffler(self.get_temp_dir(), 'split1')
  for key, payload in ((1, b'a'), (2, b'b'), (1, b'c')):
    sh.add(key, payload)
  it = iter(sh)
  # The first record comes out fine; the duplicate surfaces on the next pull.
  self.assertEqual(next(it), b'a')
  with self.assertRaises(shuffle.DuplicatedKeysError):
    next(it)
def test_duplicate_key(self):
  """Re-using a key triggers an AssertionError while iterating."""
  sh = shuffle.Shuffler(self.get_temp_dir())
  for key, payload in ((1, b'a'), (2, b'b'), (1, b'c')):
    sh.add(key, payload)
  it = iter(sh)
  # The first record comes out fine; the duplicate surfaces on the next pull.
  self.assertEqual(next(it), b'a')
  with self.assertRaisesWithPredicateMatch(
      AssertionError, 'Two records share the same hashed key!'):
    next(it)
def test_duplicate_key(self):
  """A duplicate key raises DuplicatedKeysError; items yield (hash, payload)."""
  sh = shuffle.Shuffler(self.get_temp_dir(), 'split1')
  for key, payload in ((1, b'a'), (2, b'b'), (1, b'c')):
    sh.add(key, payload)
  it = iter(sh)
  # First yielded item is the (hashed key, payload) pair for key 1.
  self.assertEqual(
      next(it), (86269847664267119453139349052967691808, b'a'))
  with self.assertRaises(shuffle.DuplicatedKeysError):
    next(it)
def _test_items(self, salt, expected_order):
  """Adds all _ITEMS under `salt` and checks size, bucket files and order.

  Args:
    salt: salt passed to the Shuffler.
    expected_order: payloads in the order iteration is expected to yield them.
  """
  shuffler = shuffle.Shuffler(self.get_temp_dir(), salt)
  for key, item in _ITEMS:
    shuffler.add(key, item)
  self.assertEqual(shuffler.size, _TOTAL_SIZE)
  if not shuffler._in_memory:
    # Check size of temporary bucket files. Each stored record appears to be
    # a 16-byte hashed key + 8-byte length prefix + the payload itself
    # (matches the (16 + 8) per-item overhead asserted here).
    expected_size = (16 + 8) * len(_ITEMS) + sum(len(t[1]) for t in _ITEMS)
    size = 0
    for bucket in shuffler._buckets:
      if not bucket._fobj:
        continue
      bucket._fobj.close()
      # Use a context manager so the read handle is closed; the original
      # `len(open(path, 'rb').read())` leaked the file object.
      with open(bucket._path, 'rb') as f:
        size += len(f.read())
    self.assertEqual(size, expected_size)
  # Check records can be read as expected:
  records = list(iter(shuffler))
  self.assertEqual(records, expected_order)
def _test_items(self, salt, items, expected_order, disable_shuffling=False):
  """Adds `items` under `salt` and checks size, bucket files and order."""
  sh = shuffle.Shuffler(self.get_temp_dir(), salt, disable_shuffling)
  for key, payload in items:
    sh.add(key, payload)
  self.assertEqual(sh.size, _TOTAL_SIZE)
  if not sh._in_memory:
    # Check size of temporary bucket files: a (16 + 8)-byte per-record
    # overhead plus the raw payload bytes.
    payload_bytes = sum(len(pair[1]) for pair in items)
    expected_size = (16 + 8) * len(items) + payload_bytes
    on_disk = 0
    for bucket in sh._buckets:
      if not bucket._fobj:
        continue
      bucket._fobj.close()
      with open(bucket._path, 'rb') as f:
        on_disk += len(f.read())
    self.assertEqual(on_disk, expected_size)
  # Iteration yields (hashed_key, example) pairs; keep only the examples.
  records = [ex for _, ex in sh]
  self.assertEqual(records, expected_order)
def test_nonbytes(self):
  """Adding non-bytes payloads (str, int) raises an AssertionError."""
  sh = shuffle.Shuffler(self.get_temp_dir())
  for bad_payload in (u'a', 123):
    with self.assertRaisesWithPredicateMatch(AssertionError, 'Only bytes'):
      sh.add(1, bad_payload)
def __init__(self, example_specs, path, hash_salt):
  """Initializes Writer.

  Args:
    example_specs: spec used to build the ExampleSerializer.
    path (str): path where records should be written in.
    hash_salt (str or bytes): salt used when hashing keys.
  """
  self._path = path
  self._num_examples = 0
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  # The shuffler stores its temporary buckets next to the output file.
  self._shuffler = shuffle.Shuffler(os.path.dirname(path), hash_salt)