def test_source_split(self):
    """Check that ``SdkWorker.initial_source_split`` agrees with ``split``.

    A ``RangeSource(0, 100)`` is wrapped in a ``SourceBundle``, packed into a
    primitive transform under the id ``"src"`` and registered with the
    worker.  The worker is then asked for an initial split with a desired
    bundle size of 30, and both the returned sources and their relative
    sizes are compared with what ``source.split(30)`` yields directly.
    """
    source = RangeSource(0, 100)
    expected_splits = list(source.split(30))

    worker = sdk_harness.SdkWorker(
        None, data_plane.GrpcClientDataChannelFactory())

    # Build the registration request from the inside out.
    packed_source = sdk_harness.serialize_and_pack_py_fn(
        SourceBundle(1.0, source, None, None),
        sdk_harness.PYTHON_SOURCE_URN,
        id="src")
    transform = beam_fn_api_pb2.PrimitiveTransform(function_spec=packed_source)
    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        primitive_transform=[transform])
    worker.register(
        beam_fn_api_pb2.RegisterRequest(process_bundle_descriptor=[descriptor]))

    split_request = beam_fn_api_pb2.InitialSourceSplitRequest(
        desired_bundle_size_bytes=30, source_reference="src")
    split_response = worker.initial_source_split(split_request)

    actual_sources = [
        sdk_harness.unpack_and_deserialize_py_fn(split.source)
        for split in split_response.splits
    ]
    self.assertEqual(expected_splits, actual_sources)

    expected_weights = [bundle.weight for bundle in expected_splits]
    actual_sizes = [split.relative_size for split in split_response.splits]
    self.assertEqual(expected_weights, actual_sizes)
def test_source_split(self):
    """Verify the worker's initial source split against a direct split.

    Registers a process bundle descriptor whose single primitive transform
    carries a serialized ``SourceBundle`` for ``RangeSource(0, 100)`` (id
    ``"src"``), requests an initial split with ``desired_bundle_size_bytes``
    of 30, and asserts the response matches ``source.split(30)`` in both the
    deserialized sources and the weights.
    """
    source = RangeSource(0, 100)
    expected_splits = list(source.split(30))

    worker = sdk_harness.SdkWorker(
        None, data_plane.GrpcClientDataChannelFactory())

    function_spec = sdk_harness.serialize_and_pack_py_fn(
        SourceBundle(1.0, source, None, None),
        sdk_harness.PYTHON_SOURCE_URN,
        id="src")
    register_request = beam_fn_api_pb2.RegisterRequest(
        process_bundle_descriptor=[
            beam_fn_api_pb2.ProcessBundleDescriptor(
                primitive_transform=[
                    beam_fn_api_pb2.PrimitiveTransform(
                        function_spec=function_spec),
                ]),
        ])
    worker.register(register_request)

    split_response = worker.initial_source_split(
        beam_fn_api_pb2.InitialSourceSplitRequest(
            desired_bundle_size_bytes=30, source_reference="src"))
    returned_splits = split_response.splits

    # Sources must round-trip through serialization unchanged.
    self.assertEqual(
        expected_splits,
        [sdk_harness.unpack_and_deserialize_py_fn(item.source)
         for item in returned_splits])
    # Each returned relative size must equal the matching bundle weight.
    self.assertEqual(
        [bundle.weight for bundle in expected_splits],
        [item.relative_size for item in returned_splits])
def setUp(self):
    """Prepare the range bounds, range source and restriction provider.

    Stores the initial range ``[0, 4)`` bounds, a ``RangeSource`` built from
    them, and an ``_SDFBoundedSourceRestrictionProvider`` configured with a
    desired chunk size of 2.
    """
    start, stop = 0, 4
    self.initial_range_start = start
    self.initial_range_stop = stop
    self.initial_range_source = RangeSource(start, stop)
    self.sdf_restriction_provider = (
        iobase._SDFBoundedSourceRestrictionProvider(desired_chunk_size=2))
def test_source_split_via_instruction(self):
    """Exercise the initial source split through the control service.

    A ``BeamFnControlServicer`` is given two instruction requests: one that
    registers a serialized ``SourceBundle`` for ``RangeSource(0, 100)`` under
    the id ``"src"``, and one (``"split_request"``) asking for an initial
    split with a desired bundle size of 30.  The servicer is exposed on a
    local gRPC server, an ``SdkHarness`` is run against it, and the response
    recorded for ``"split_request"`` is compared with ``source.split(30)``.

    The gRPC server is always stopped once the harness has finished so the
    test does not leave the server and its thread pool running.
    """
    source = RangeSource(0, 100)
    expected_splits = list(source.split(30))

    register_instruction = beam_fn_api_pb2.InstructionRequest(
        instruction_id="register_request",
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[
                beam_fn_api_pb2.ProcessBundleDescriptor(
                    primitive_transform=[
                        beam_fn_api_pb2.PrimitiveTransform(
                            function_spec=sdk_harness.serialize_and_pack_py_fn(
                                SourceBundle(1.0, source, None, None),
                                sdk_harness.PYTHON_SOURCE_URN,
                                id="src")),
                    ]),
            ]))
    split_instruction = beam_fn_api_pb2.InstructionRequest(
        instruction_id="split_request",
        initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
            desired_bundle_size_bytes=30, source_reference="src"))
    test_controller = BeamFnControlServicer(
        [register_instruction, split_instruction])

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
        test_controller, server)
    # Port 0 lets the OS pick a free port; the chosen port is returned.
    test_port = server.add_insecure_port("[::]:0")
    server.start()
    try:
        channel = grpc.insecure_channel("localhost:%s" % test_port)
        harness = sdk_harness.SdkHarness(channel)
        harness.run()
    finally:
        # Shut the server down even if the harness fails, so the test does
        # not leak the server or its executor threads.
        server.stop(None)

    split_response = test_controller.responses[
        "split_request"].initial_source_split

    self.assertEqual(
        expected_splits,
        [sdk_harness.unpack_and_deserialize_py_fn(s.source)
         for s in split_response.splits])
    self.assertEqual(
        [s.weight for s in expected_splits],
        [s.relative_size for s in split_response.splits])
def test_source_split_via_instruction(self):
    """Drive an initial source split via instruction requests over gRPC.

    Two instruction requests are queued on a ``BeamFnControlServicer``:
    ``"register_request"`` registers a serialized ``SourceBundle`` wrapping
    ``RangeSource(0, 100)`` with id ``"src"``; ``"split_request"`` asks for an
    initial split with ``desired_bundle_size_bytes`` of 30.  After running an
    ``SdkHarness`` against a local server hosting the servicer, the recorded
    split response must match ``source.split(30)`` in sources and weights.

    The server is stopped in a ``finally`` clause so that neither the server
    nor its thread pool outlives the test.
    """
    source = RangeSource(0, 100)
    expected_splits = list(source.split(30))

    packed_source = sdk_harness.serialize_and_pack_py_fn(
        SourceBundle(1.0, source, None, None),
        sdk_harness.PYTHON_SOURCE_URN,
        id="src")
    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        primitive_transform=[
            beam_fn_api_pb2.PrimitiveTransform(function_spec=packed_source),
        ])
    test_controller = BeamFnControlServicer([
        beam_fn_api_pb2.InstructionRequest(
            instruction_id="register_request",
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[descriptor])),
        beam_fn_api_pb2.InstructionRequest(
            instruction_id="split_request",
            initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
                desired_bundle_size_bytes=30, source_reference="src")),
    ])

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
    test_port = server.add_insecure_port("[::]:0")
    server.start()
    try:
        channel = grpc.insecure_channel("localhost:%s" % test_port)
        harness = sdk_harness.SdkHarness(channel)
        harness.run()
    finally:
        # Always release the server; the original left it running.
        server.stop(None)

    split_response = test_controller.responses[
        "split_request"].initial_source_split
    returned_splits = split_response.splits

    self.assertEqual(
        expected_splits,
        [sdk_harness.unpack_and_deserialize_py_fn(s.source)
         for s in returned_splits])
    self.assertEqual(
        [s.weight for s in expected_splits],
        [s.relative_size for s in returned_splits])
def test_try_split_with_any_exception(self):
    """``try_split`` yields ``None`` for a bundle bounded by OFFSET_INFINITY.

    Replaces ``self.sdf_restriction_tracker`` with a tracker over a
    ``RangeSource`` whose stop position (and the bundle's weight and stop) is
    ``OffsetRangeTracker.OFFSET_INFINITY``, claims position 0, and asserts
    that splitting at fraction 0.5 returns ``None``.
    """
    # NOTE(review): per the test name, the split presumably fails internally
    # with an exception that the tracker swallows -- confirm against iobase.
    infinity = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
    unbounded_source = RangeSource(0, infinity)
    source_bundle = SourceBundle(infinity, unbounded_source, 0, infinity)
    restriction = iobase._SDFBoundedSourceRestriction(source_bundle)
    self.sdf_restriction_tracker = (
        iobase._SDFBoundedSourceRestrictionTracker(restriction))

    self.sdf_restriction_tracker.try_claim(0)
    split_result = self.sdf_restriction_tracker.try_split(0.5)
    self.assertIsNone(split_result)
def setUp(self):
    """Build a restriction tracker over ``RangeSource(0, 4)`` for each test.

    Stores the start/stop positions on the test case and wraps a
    ``SourceBundle`` (weight ``stop - start``) covering that range in an
    ``_SDFBoundedSourceRestrictionTracker``.
    """
    start, stop = 0, 4
    self.initial_start_pos = start
    self.initial_stop_pos = stop

    range_source = RangeSource(start, stop)
    source_bundle = SourceBundle(stop - start, range_source, start, stop)
    restriction = iobase._SDFBoundedSourceRestriction(source_bundle)
    self.sdf_restriction_tracker = (
        iobase._SDFBoundedSourceRestrictionTracker(restriction))
def test_create_tracker(self):
    """The provider's ``create_tracker`` returns a tracker over the bundle.

    Creates a restriction for a ``SourceBundle`` spanning positions 1 to 3,
    asks ``self.sdf_restriction_provider`` for a tracker, and checks both the
    tracker's type and its reported start/stop positions.
    """
    expected_start = 1
    expected_stop = 3
    # Build the source from the same bounds that the assertions check, so
    # the two cannot drift apart.
    source_bundle = SourceBundle(
        expected_stop - expected_start,
        RangeSource(expected_start, expected_stop),
        expected_start,
        expected_stop)
    restriction_tracker = self.sdf_restriction_provider.create_tracker(
        iobase._SDFBoundedSourceRestriction(source_bundle))

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)) which only says "False is not true".
    self.assertIsInstance(
        restriction_tracker, iobase._SDFBoundedSourceRestrictionTracker)
    self.assertEqual(expected_start, restriction_tracker.start_pos())
    self.assertEqual(expected_stop, restriction_tracker.stop_pos())
def test_initialization(self):
    """``ThreadsafeRestrictionTracker`` rejects a ``RangeSource`` argument.

    Constructing the wrapper around ``RangeSource(0, 1)`` must raise
    ``ValueError``.
    """
    # NOTE(review): presumably rejected because a source is not a
    # restriction tracker -- confirm against the iobase implementation.
    with self.assertRaises(ValueError):
        source = RangeSource(0, 1)
        iobase.ThreadsafeRestrictionTracker(source)
def test_sdf_wrap_range_source(self):
    """Run the SDF wrapper pipeline helper on ``RangeSource(0, 4)``.

    Passes the source together with ``[0, 1, 2, 3]`` to
    ``self._run_sdf_wrapper_pipeline``.
    """
    source = RangeSource(0, 4)
    # Presumably the elements the pipeline is expected to emit; the helper
    # is defined outside this view.
    expected_elements = [0, 1, 2, 3]
    self._run_sdf_wrapper_pipeline(source, expected_elements)
def test_sdf_wrapper_overrides_read(self, sdf_wrapper_mock_expand):
    """Run the wrapper pipeline with the SDF wrapper's expand stubbed out.

    The supplied mock is given a side effect that applies
    ``beam.Create(['fake'])`` to its ``pbegin`` argument; the pipeline helper
    is then run on ``RangeSource(0, 4)`` with ``['fake']`` as its second
    argument.

    Args:
      sdf_wrapper_mock_expand: mock object whose ``side_effect`` is set here.
        Presumably injected by a patch decorator outside this view.
    """
    sdf_wrapper_mock_expand.side_effect = (
        lambda pbegin: pbegin | beam.Create(['fake']))
    self._run_sdf_wrapper_pipeline(RangeSource(0, 4), ['fake'])