Exemplo n.º 1
0
    def __init__(
        self,
        example_specs,
        path,
        hash_salt,
        disable_shuffling: bool,
        file_format=file_adapters.DEFAULT_FILE_FORMAT,
    ):
        """Creates the Writer and its backing shuffler.

    Args:
      example_specs: spec to build ExampleSerializer.
      path (str): path where records should be written in.
      hash_salt (str or bytes): salt to hash keys.
      disable_shuffling (bool): Specifies whether to shuffle the records.
      file_format (FileFormat): format of the record files in which the dataset
        should be written in.
    """
        self._path = path
        self._file_format = file_format
        self._num_examples = 0
        self._example_specs = example_specs
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        # The shuffler stages records in the directory containing `path`.
        self._shuffler = shuffle.Shuffler(
            dirpath=os.path.dirname(path),
            hash_salt=hash_salt,
            disable_shuffling=disable_shuffling,
        )
Exemplo n.º 2
0
 def test_all_mem(self):
   """All items fit in memory; size and iteration order are as expected."""
   shuffler = shuffle.Shuffler(self.get_temp_dir())
   for key, value in _ITEMS:
     shuffler.add(key, value)
   self.assertEqual(shuffler.size, _TOTAL_SIZE)
   self.assertEqual(list(shuffler), _ORDERED_ITEMS)
Exemplo n.º 3
0
 def _test_items(self, salt, expected_order):
     """Feeds _ITEMS through a Shuffler salted with `salt`.

     Asserts the total size and that iteration yields `expected_order`.
     """
     shuffler = shuffle.Shuffler(self.get_temp_dir(), salt)
     for key, value in _ITEMS:
         shuffler.add(key, value)
     self.assertEqual(shuffler.size, _TOTAL_SIZE)
     self.assertEqual(list(shuffler), expected_order)
Exemplo n.º 4
0
    def __init__(
        self,
        example_specs,
        filename_template: naming.ShardedFileTemplate,
        hash_salt,
        disable_shuffling: bool,
        file_format=file_adapters.DEFAULT_FILE_FORMAT,
    ):
        """Creates the Writer and its backing shuffler.

    Args:
      example_specs: spec to build ExampleSerializer.
      filename_template: template to format sharded filenames.
      hash_salt (str or bytes): salt to hash keys.
      disable_shuffling (bool): Specifies whether to shuffle the records.
      file_format (FileFormat): format of the record files in which the dataset
        should be written in.
    """
        self._filename_template = filename_template
        self._file_format = file_format
        self._num_examples = 0
        self._example_specs = example_specs
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        # Stage shuffled records inside the template's data directory.
        self._shuffler = shuffle.Shuffler(
            dirpath=filename_template.data_dir,
            hash_salt=hash_salt,
            disable_shuffling=disable_shuffling,
        )
 def test_duplicate_key(self):
     """A repeated key raises DuplicatedKeysError when iteration reaches it."""
     shuffler = shuffle.Shuffler(self.get_temp_dir(), 'split1')
     for key, value in [(1, b'a'), (2, b'b'), (1, b'c')]:
         shuffler.add(key, value)
     it = iter(shuffler)
     self.assertEqual(next(it), b'a')
     with self.assertRaises(shuffle.DuplicatedKeysError):
         next(it)
Exemplo n.º 6
0
 def test_duplicate_key(self):
   """Colliding hashed keys fail with an AssertionError during iteration."""
   shuffler = shuffle.Shuffler(self.get_temp_dir())
   for key, value in [(1, b'a'), (2, b'b'), (1, b'c')]:
     shuffler.add(key, value)
   it = iter(shuffler)
   self.assertEqual(next(it), b'a')
   with self.assertRaisesWithPredicateMatch(
       AssertionError, 'Two records share the same hashed key!'):
     next(it)
Exemplo n.º 7
0
 def test_duplicate_key(self):
     """Duplicate keys: first record yields its (hashed_key, payload) pair,
     then DuplicatedKeysError is raised on the colliding record."""
     shuffler = shuffle.Shuffler(self.get_temp_dir(), 'split1')
     for key, value in [(1, b'a'), (2, b'b'), (1, b'c')]:
         shuffler.add(key, value)
     it = iter(shuffler)
     self.assertEqual(
         next(it), (86269847664267119453139349052967691808, b'a'))
     with self.assertRaises(shuffle.DuplicatedKeysError):
         next(it)
Exemplo n.º 8
0
 def _test_items(self, salt, expected_order):
   """Feeds _ITEMS through a salted Shuffler and verifies its behavior.

   Checks the reported size, the on-disk size of temporary bucket files
   when the shuffler spilled to disk, and the final iteration order.

   Args:
     salt: salt forwarded to the Shuffler (determines the hash order).
     expected_order: records in the order iteration is expected to yield.
   """
   shuffler = shuffle.Shuffler(self.get_temp_dir(), salt)
   for key, item in _ITEMS:
     shuffler.add(key, item)
   self.assertEqual(shuffler.size, _TOTAL_SIZE)
   if not shuffler._in_memory:  # Check size of temporary bucket files
     # Per record: 16 + 8 bytes of overhead (presumably hashed key + length
     # prefix -- matches the sibling disk-size check) plus the payload.
     expected_size = (16 + 8) * len(_ITEMS) + sum(len(t[1]) for t in _ITEMS)
     size = 0
     for bucket in shuffler._buckets:
       if not bucket._fobj:
         continue
       bucket._fobj.close()
       # Context manager so the bucket file handle is closed, not leaked.
       with open(bucket._path, 'rb') as f:
         size += len(f.read())
     self.assertEqual(size, expected_size)
   # Check records can be read as expected:
   records = list(iter(shuffler))
   self.assertEqual(records, expected_order)
Exemplo n.º 9
0
 def _test_items(self,
                 salt,
                 items,
                 expected_order,
                 disable_shuffling=False):
     """Feeds `items` through a Shuffler and verifies its behavior.

     Checks the reported size, the on-disk size of temporary bucket files
     when spilled to disk, and that the payloads come back in
     `expected_order`.
     """
     shuffler = shuffle.Shuffler(self.get_temp_dir(), salt,
                                 disable_shuffling)
     for key, value in items:
         shuffler.add(key, value)
     self.assertEqual(shuffler.size, _TOTAL_SIZE)
     if not shuffler._in_memory:  # Check size of temporary bucket files
         expected_size = ((16 + 8) * len(items) +
                          sum(len(payload) for _, payload in items))
         size = 0
         for bucket in shuffler._buckets:
             if not bucket._fobj:
                 continue
             bucket._fobj.close()
             with open(bucket._path, 'rb') as f:
                 size += len(f.read())
         self.assertEqual(size, expected_size)
     # Iteration yields (hashed_key, example) pairs; keep the examples only.
     records = [example for _, example in shuffler]
     self.assertEqual(records, expected_order)
Exemplo n.º 10
0
 def test_nonbytes(self):
   """Non-bytes payloads (text or int) are rejected with an AssertionError."""
   shuffler = shuffle.Shuffler(self.get_temp_dir())
   for bad_value in (u'a', 123):
     with self.assertRaisesWithPredicateMatch(AssertionError, 'Only bytes'):
       shuffler.add(1, bad_value)
Exemplo n.º 11
0
 def __init__(self, example_specs, path, hash_salt):
     """Creates the writer: a serializer plus a shuffler staged next to `path`."""
     self._num_examples = 0
     self._path = path
     self._serializer = example_serializer.ExampleSerializer(example_specs)
     self._shuffler = shuffle.Shuffler(os.path.dirname(path), hash_salt)