def test_fn_registration(self):
  """Registered FunctionSpecs and ProcessBundleDescriptors land in worker.fns.

  Serves a single RegisterRequest carrying four ProcessBundleDescriptors
  (ids "100".."103"), each wrapping one FunctionSpec (ids "0".."3"). Once
  ``harness.run()`` returns, it checks that ``harness.worker.fns`` maps
  every one of those ids to its proto.
  """
  fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]
  process_bundle_descriptors = [
      beam_fn_api_pb2.ProcessBundleDescriptor(
          id=str(100 + ix),
          primitive_transform=[
              beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)])
      for ix, fn in enumerate(fns)]

  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=process_bundle_descriptors))])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()
  try:
    channel = grpc.insecure_channel("localhost:%s" % test_port)
    harness = sdk_worker.SdkHarness(channel)
    harness.run()
  finally:
    # Always shut the server down, even if the harness raises, so a failing
    # test does not leave a live gRPC server behind for the rest of the run.
    server.stop(None)

  self.assertEqual(
      harness.worker.fns,
      {item.id: item for item in fns + process_bundle_descriptors})
def __init__(self):
  """Start local control and data gRPC servers plus an SDK worker thread.

  Both servers bind '[::]:0', and the ports returned by
  ``add_insecure_port`` are kept on ``control_port`` / ``data_port``. The
  SDK harness connects to the control server over an insecure localhost
  channel, and its ``run`` method executes on ``worker_thread``.
  """
  self.state_handler = FnApiRunner.SimpleState()

  def new_server():
    # Each server gets its own 10-thread executor.
    return grpc.server(futures.ThreadPoolExecutor(max_workers=10))

  control_server = new_server()
  self.control_server = control_server
  self.control_port = control_server.add_insecure_port('[::]:0')

  data_server = new_server()
  self.data_server = data_server
  self.data_port = data_server.add_insecure_port('[::]:0')

  control_handler = streaming_rpc_handler(
      beam_fn_api_pb2.BeamFnControlServicer, 'Control')
  self.control_handler = control_handler
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
      control_handler, control_server)

  data_plane_handler = data_plane.GrpcServerDataChannel()
  self.data_plane_handler = data_plane_handler
  beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
      data_plane_handler, data_server)

  logging.info('starting control server on port %s', self.control_port)
  logging.info('starting data server on port %s', self.data_port)
  # Start order: data server first, then control server.
  data_server.start()
  control_server.start()

  control_address = 'localhost:%s' % self.control_port
  self.worker = sdk_worker.SdkHarness(grpc.insecure_channel(control_address))
  self.worker_thread = threading.Thread(target=self.worker.run)
  logging.info('starting worker')
  self.worker_thread.start()
def test_source_split_via_instruction(self):
  """An InitialSourceSplitRequest returns the same splits as source.split().

  Registers a SourceBundle wrapping RangeSource(0, 100) under id "src",
  then sends an InitialSourceSplitRequest for 30-byte bundles referencing
  "src". The response's splits must deserialize to exactly
  ``list(source.split(30))``, and their relative sizes must match the
  expected bundles' weights.
  """
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))
  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="register_request",
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[
                  beam_fn_api_pb2.ProcessBundleDescriptor(
                      primitive_transform=[
                          beam_fn_api_pb2.PrimitiveTransform(
                              function_spec=sdk_harness.serialize_and_pack_py_fn(
                                  SourceBundle(1.0, source, None, None),
                                  sdk_harness.PYTHON_SOURCE_URN,
                                  id="src"))])])),
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="split_request",
          initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
              desired_bundle_size_bytes=30,
              source_reference="src"))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()
  try:
    channel = grpc.insecure_channel("localhost:%s" % test_port)
    harness = sdk_harness.SdkHarness(channel)
    harness.run()
  finally:
    # Always shut the server down, even if the harness raises, so a failing
    # test does not leave a live gRPC server behind for the rest of the run.
    server.stop(None)

  split_response = test_controller.responses[
      "split_request"].initial_source_split
  self.assertEqual(
      expected_splits,
      [sdk_harness.unpack_and_deserialize_py_fn(s.source)
       for s in split_response.splits])
  self.assertEqual(
      [s.weight for s in expected_splits],
      [s.relative_size for s in split_response.splits])
def test_source_split_via_instruction(self):
  """An InitialSourceSplitRequest returns the same splits as source.split().

  Registers a SourceBundle wrapping RangeSource(0, 100) under id "src",
  then sends an InitialSourceSplitRequest for 30-byte bundles referencing
  "src". The response's splits must deserialize to exactly
  ``list(source.split(30))``, and their relative sizes must match the
  expected bundles' weights.
  """
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))
  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="register_request",
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[
                  beam_fn_api_pb2.ProcessBundleDescriptor(
                      primitive_transform=[
                          beam_fn_api_pb2.PrimitiveTransform(
                              function_spec=sdk_harness.serialize_and_pack_py_fn(
                                  SourceBundle(1.0, source, None, None),
                                  sdk_harness.PYTHON_SOURCE_URN,
                                  id="src"))])])),
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="split_request",
          initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
              desired_bundle_size_bytes=30,
              source_reference="src"))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()
  try:
    channel = grpc.insecure_channel("localhost:%s" % test_port)
    harness = sdk_harness.SdkHarness(channel)
    harness.run()
  finally:
    # Always shut the server down, even if the harness raises, so a failing
    # test does not leave a live gRPC server behind for the rest of the run.
    server.stop(None)

  split_response = test_controller.responses[
      "split_request"].initial_source_split
  self.assertEqual(
      expected_splits,
      [sdk_harness.unpack_and_deserialize_py_fn(s.source)
       for s in split_response.splits])
  self.assertEqual(
      [s.weight for s in expected_splits],
      [s.relative_size for s in split_response.splits])
def test_fn_registration(self):
  """Registered FunctionSpecs and ProcessBundleDescriptors land in worker.fns.

  Serves a single RegisterRequest carrying four ProcessBundleDescriptors
  (ids "100".."103"), each wrapping one FunctionSpec (ids "0".."3"). Once
  ``harness.run()`` returns, it checks that ``harness.worker.fns`` maps
  every one of those ids to its proto.
  """
  fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]
  process_bundle_descriptors = [
      beam_fn_api_pb2.ProcessBundleDescriptor(
          id=str(100 + ix),
          primitive_transform=[
              beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)])
      for ix, fn in enumerate(fns)]

  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=process_bundle_descriptors))])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()
  try:
    channel = grpc.insecure_channel("localhost:%s" % test_port)
    harness = sdk_worker.SdkHarness(channel)
    harness.run()
  finally:
    # Always shut the server down, even if the harness raises, so a failing
    # test does not leave a live gRPC server behind for the rest of the run.
    server.stop(None)

  self.assertEqual(
      harness.worker.fns,
      {item.id: item for item in fns + process_bundle_descriptors})