def _serialize_shard(self, key_example):
  """Returns (shard#, (hkey, serialized_example)) for one input pair."""
  raw_key, raw_example = key_example
  # Serialize first, then derive the shuffle key from the record key.
  payload = self._serializer.serialize_example(raw_example)
  hashed_key = self._hasher.hash_key(raw_key)
  shard_id = shuffle.get_bucket_number(hashed_key, _BEAM_NUM_TEMP_SHARDS)
  # py2 compat: normalize both ints to long before handing off to Beam.
  return (_long_for_py2(shard_id), (_long_for_py2(hashed_key), payload))
def test_order(self):
  """Checks get_bucket_number is monotonic in key and uses every shard."""
  shards_number = 10
  shards = [shuffle.get_bucket_number(k, shards_number) for k in range(1024)]
  # Check max(shard_x) < min(shard_y) if x < y.
  previous_shard = 0
  for shard in shards:
    self.assertGreaterEqual(shard, previous_shard)
    previous_shard = shard
  # Check distribution: all shards are used.
  counts = collections.Counter(shards)
  self.assertEqual(len(counts), shards_number)
  # And all shards contain nearly the same number of elements (102 or 103 in
  # that case, since 1024 keys over 10 shards), hence exactly 2 distinct sizes.
  self.assertEqual(len(set(counts.values())), 2)
def _serialize_shard(
    self,
    key_example: Tuple[hashing.HashKey, Example],
) -> Tuple[int, Tuple[Any, bytes]]:
  """Maps a (key, example) pair to (shard#, (hkey, serialized_example))."""
  record_key, record = key_example
  payload = self._serializer.serialize_example(record)
  # When shuffling is disabled the raw key is used directly as sort key;
  # otherwise it is replaced by its hash.
  sort_key = (
      record_key if self._disable_shuffling
      else self._hasher.hash_key(record_key)
  )
  shard_id = shuffle.get_bucket_number(sort_key, _BEAM_NUM_TEMP_SHARDS)
  # py2 compat: coerce both ints to long before returning.
  return (_long_for_py2(shard_id), (_long_for_py2(sort_key), payload))