def test_source_split(self):
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))

  worker = sdk_harness.SdkWorker(
      None, data_plane.GrpcClientDataChannelFactory())
  worker.register(
      beam_fn_api_pb2.RegisterRequest(process_bundle_descriptor=[
          beam_fn_api_pb2.ProcessBundleDescriptor(primitive_transform=[
              beam_fn_api_pb2.PrimitiveTransform(
                  function_spec=sdk_harness.serialize_and_pack_py_fn(
                      SourceBundle(1.0, source, None, None),
                      sdk_harness.PYTHON_SOURCE_URN,
                      id="src"))
          ])
      ]))
  split_response = worker.initial_source_split(
      beam_fn_api_pb2.InitialSourceSplitRequest(
          desired_bundle_size_bytes=30, source_reference="src"))

  self.assertEqual(expected_splits, [
      sdk_harness.unpack_and_deserialize_py_fn(s.source)
      for s in split_response.splits
  ])
  self.assertEqual([s.weight for s in expected_splits],
                   [s.relative_size for s in split_response.splits])
def test_try_split_with_any_exception(self):
  source_bundle = SourceBundle(
      range_trackers.OffsetRangeTracker.OFFSET_INFINITY,
      RangeSource(0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY),
      0,
      range_trackers.OffsetRangeTracker.OFFSET_INFINITY)
  self.sdf_restriction_tracker = (
      iobase._SDFBoundedSourceRestrictionTracker(
          iobase._SDFBoundedSourceRestriction(source_bundle)))
  self.sdf_restriction_tracker.try_claim(0)
  self.assertIsNone(self.sdf_restriction_tracker.try_split(0.5))
def setUp(self):
  self.initial_start_pos = 0
  self.initial_stop_pos = 4
  source_bundle = SourceBundle(
      self.initial_stop_pos - self.initial_start_pos,
      RangeSource(self.initial_start_pos, self.initial_stop_pos),
      self.initial_start_pos,
      self.initial_stop_pos)
  self.sdf_restriction_tracker = (
      iobase._SDFBoundedSourceRestrictionTracker(
          iobase._SDFBoundedSourceRestriction(source_bundle)))
def split(self, desired_bundle_size, start_position=0, stop_position=None):
  """Implements method `apache_beam.io.iobase.BoundedSource.split`.

  `BillboardSource` is unsplittable, so only a single source is returned.
  """
  stop_position = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
  yield SourceBundle(
      weight=1,
      source=self,
      start_position=start_position,
      stop_position=stop_position)
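# A minimal usage sketch (not from the source): because the split() above
# ignores desired_bundle_size and forces stop_position to OFFSET_INFINITY,
# it always yields exactly one bundle covering the whole source. The
# no-argument BillboardSource() constructor is an assumption for
# illustration.
source = BillboardSource()
bundles = list(source.split(desired_bundle_size=1 << 20))
assert len(bundles) == 1
assert bundles[0].source is source
assert (bundles[0].stop_position ==
        range_trackers.OffsetRangeTracker.OFFSET_INFINITY)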
def test_create_tracker(self):
  expected_start = 1
  expected_stop = 3
  source_bundle = SourceBundle(
      expected_stop - expected_start,
      RangeSource(1, 3),
      expected_start,
      expected_stop)
  restriction_tracker = (
      self.sdf_restriction_provider.create_tracker(
          iobase._SDFBoundedSourceRestriction(source_bundle)))
  self.assertIsInstance(
      restriction_tracker, iobase._SDFBoundedSourceRestrictionTracker)
  self.assertEqual(expected_start, restriction_tracker.start_pos())
  self.assertEqual(expected_stop, restriction_tracker.stop_pos())
def split(self,
          desired_bundle_size=None,
          start_position=None,
          stop_position=None):
  """Splits the source into a set of bundles, using the row_set if it is set.

  Note: at this point, only splitting an entire table into bundles based on
  the sample row keys is supported.

  :param desired_bundle_size: the desired size (in bytes) of the bundles
    returned.
  :param start_position: if specified, the position must be used as the
    starting position of the first bundle.
  :param stop_position: if specified, the position must be used as the
    ending position of the last bundle.

  Returns:
    an iterator of objects of type 'SourceBundle' that gives information
    about the generated bundles.
  """
  if (desired_bundle_size is not None or start_position is not None or
      stop_position is not None):
    raise NotImplementedError

  # TODO: Use the desired bundle size to split accordingly.
  # TODO: Allow users to provide their own row sets.
  sample_row_keys = list(self.get_sample_row_keys())
  bundles = []
  if sample_row_keys and sample_row_keys[0].row_key != b'':
    # The first sample key is not the start of the table, so emit a bundle
    # covering everything before it.
    bundles.append(
        SourceBundle(sample_row_keys[0].offset_bytes, self, b'',
                     sample_row_keys[0].row_key))
  for i in range(1, len(sample_row_keys)):
    pos_start = sample_row_keys[i - 1].offset_bytes
    pos_stop = sample_row_keys[i].offset_bytes
    bundles.append(
        SourceBundle(pos_stop - pos_start, self,
                     sample_row_keys[i - 1].row_key,
                     sample_row_keys[i].row_key))
  # Shuffle so that reads start from different locations of the table,
  # which spreads the load for better efficiency.
  shuffle(bundles)
  return bundles
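# A hypothetical walk-through (not from the source) of the weighting logic
# in the split() above. The namedtuple stands in for whatever objects
# get_sample_row_keys() actually yields; only the row_key and offset_bytes
# attributes matter here.
from collections import namedtuple

SampleRowKey = namedtuple('SampleRowKey', ['row_key', 'offset_bytes'])
keys = [SampleRowKey(b'k1', 100), SampleRowKey(b'k2', 250)]
# The first bundle covers [b'', b'k1') and weighs 100 bytes; the second
# covers [b'k1', b'k2') and weighs the 150-byte delta between the offsets.
weights = [keys[0].offset_bytes] + [
    keys[i].offset_bytes - keys[i - 1].offset_bytes
    for i in range(1, len(keys))
]
assert weights == [100, 150]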
def split(self, desired_bundle_size, start_offset=None, stop_offset=None):
  if start_offset is None:
    start_offset = self._start_offset
  if stop_offset is None:
    stop_offset = self._stop_offset

  if self._splittable:
    splits = OffsetRange(start_offset, stop_offset).split(
        desired_bundle_size, self._min_bundle_size)
    for split in splits:
      yield SourceBundle(
          split.stop - split.start,
          _SingleFileSource(
              # Copying this so that each sub-source gets a fresh instance.
              pickler.loads(pickler.dumps(self._file_based_source)),
              self._file_name,
              split.start,
              split.stop,
              min_bundle_size=self._min_bundle_size,
              splittable=self._splittable),
          split.start,
          split.stop)
  else:
    # Return a single sub-source with the end offset set to OFFSET_INFINITY
    # (so that all data of the source gets read) since this source is
    # unsplittable. Choosing the size of the file as the end offset would be
    # wrong for certain unsplittable sources, e.g., compressed sources.
    yield SourceBundle(
        stop_offset - start_offset,
        _SingleFileSource(
            self._file_based_source,
            self._file_name,
            start_offset,
            OffsetRangeTracker.OFFSET_INFINITY,
            min_bundle_size=self._min_bundle_size,
            splittable=self._splittable),
        start_offset,
        OffsetRangeTracker.OFFSET_INFINITY)
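# A small sketch of the OffsetRange.split() behavior the splittable branch
# above relies on, assuming the OffsetRange from
# apache_beam.io.restriction_trackers: a 100-byte range with a 30-byte
# desired bundle size yields contiguous sub-ranges covering the full range.
from apache_beam.io.restriction_trackers import OffsetRange

splits = list(OffsetRange(0, 100).split(30, 1))
assert [(s.start, s.stop) for s in splits] == [
    (0, 30), (30, 60), (60, 90), (90, 100)]
# Each SourceBundle weight above is then just split.stop - split.start.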
def test_source_split_via_instruction(self):
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))

  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="register_request",
          register=beam_fn_api_pb2.RegisterRequest(process_bundle_descriptor=[
              beam_fn_api_pb2.ProcessBundleDescriptor(primitive_transform=[
                  beam_fn_api_pb2.PrimitiveTransform(
                      function_spec=sdk_harness.serialize_and_pack_py_fn(
                          SourceBundle(1.0, source, None, None),
                          sdk_harness.PYTHON_SOURCE_URN,
                          id="src"))
              ])
          ])),
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="split_request",
          initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
              desired_bundle_size_bytes=30, source_reference="src"))
  ])

  # Serve the scripted control instructions over an in-process gRPC server.
  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()

  # The harness connects to the control channel, executes both instructions,
  # and reports its responses back to the test controller.
  channel = grpc.insecure_channel("localhost:%s" % test_port)
  harness = sdk_harness.SdkHarness(channel)
  harness.run()

  split_response = test_controller.responses[
      "split_request"].initial_source_split
  self.assertEqual(expected_splits, [
      sdk_harness.unpack_and_deserialize_py_fn(s.source)
      for s in split_response.splits
  ])
  self.assertEqual([s.weight for s in expected_splits],
                   [s.relative_size for s in split_response.splits])
def split(self, desired_bundle_size, start_position=None, stop_position=None):
  if self.split_result is None:
    bq = bigquery_tools.BigQueryWrapper()

    # For queries, materialize the results into a temporary table first.
    if self.query is not None:
      self._setup_temporary_dataset(bq)
      self.table_reference = self._execute_query(bq)

    # Export the table to files and wrap each exported file in a TextSource.
    schema, metadata_list = self._export_files(bq)
    self.split_result = [
        TextSource(metadata.path, 0, CompressionTypes.UNCOMPRESSED, True,
                   self.coder(schema)) for metadata in metadata_list
    ]

    if self.query is not None:
      bq.clean_up_temporary_dataset(self.project.get())

  for source in self.split_result:
    yield SourceBundle(0, source, None, None)