def test_fn_registration(self):
  fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]

  process_bundle_descriptors = [
      beam_fn_api_pb2.ProcessBundleDescriptor(
          id=str(100 + ix),
          primitive_transform=[
              beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)
          ]) for ix, fn in enumerate(fns)
  ]

  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=process_bundle_descriptors))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()

  channel = grpc.insecure_channel("localhost:%s" % test_port)
  harness = sdk_worker.SdkHarness(channel)
  harness.run()
  self.assertEqual(
      harness.worker.fns,
      {item.id: item for item in fns + process_bundle_descriptors})
def test_source_split_via_instruction(self):
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))

  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="register_request",
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[
                  beam_fn_api_pb2.ProcessBundleDescriptor(
                      primitive_transform=[
                          beam_fn_api_pb2.PrimitiveTransform(
                              function_spec=sdk_worker.serialize_and_pack_py_fn(
                                  SourceBundle(1.0, source, None, None),
                                  sdk_worker.PYTHON_SOURCE_URN,
                                  id="src"))
                      ])
              ])),
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="split_request",
          initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
              desired_bundle_size_bytes=30,
              source_reference="src"))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")
  server.start()

  channel = grpc.insecure_channel("localhost:%s" % test_port)
  harness = sdk_worker.SdkHarness(channel)
  harness.run()

  # The splits reported by the harness should match splitting the source
  # directly, both in content and in relative size.
  split_response = test_controller.responses[
      "split_request"].initial_source_split
  self.assertEqual(expected_splits, [
      sdk_worker.unpack_and_deserialize_py_fn(s.source)
      for s in split_response.splits
  ])
  self.assertEqual([s.weight for s in expected_splits],
                   [s.relative_size for s in split_response.splits])
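# Both tests above rely on a BeamFnControlServicer test double whose
# definition is not part of this section. The sketch below is only an
# illustration of what such a helper could look like: it assumes the Fn API
# Control RPC is a bidirectional stream in which the runner side yields
# InstructionRequests and consumes the harness's InstructionResponses, and
# the `requests`/`responses` attribute names simply mirror how the tests use
# the object.
class BeamFnControlServicer(beam_fn_api_pb2.BeamFnControlServicer):

  def __init__(self, requests):
    self.requests = requests
    self.responses = {}

  def Control(self, response_iterator, context):
    # Feed the pre-canned instruction requests to the connected SDK harness.
    for request in self.requests:
      yield request
    # Record each response keyed by the instruction id it answers, so tests
    # can look up e.g. responses["split_request"].
    for response in response_iterator:
      self.responses[response.instruction_id] = response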
def _run_map_task(
    self, map_task, control_handler, state_handler, data_plane_handler,
    data_operation_spec):
  registration, sinks, input_data = self._map_task_registration(
      map_task, state_handler, data_operation_spec)
  control_handler.push(registration)
  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=registration.register.
          process_bundle_descriptor[0].id))

  # Write the input data for each injected transform over the data plane
  # before asking the harness to process the bundle.
  for (transform_id, name), elements in input_data.items():
    data_out = data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    data_out.write(elements)
    data_out.close()

  control_handler.push(process_bundle)
  while True:
    result = control_handler.pull()
    if result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      # Read the bundle's output back over the data plane and decode it into
      # the runner-side sink operations.
      expected_targets = [
          beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=output_name)
          for (transform_id, output_name), _ in sinks.items()]
      for output in data_plane_handler.input_elements(
          process_bundle.instruction_id, expected_targets):
        target_tuple = (
            output.target.primitive_transform_reference, output.target.name)
        if target_tuple not in sinks:
          # Unconsumed output.
          continue
        sink_op = sinks[target_tuple]
        coder = sink_op.output_coders[0]
        input_stream = create_InputStream(output.data)
        elements = []
        while input_stream.size() > 0:
          elements.append(
              coder.get_impl().decode_from_stream(input_stream, True))
        if not sink_op.write_windowed_values:
          elements = [e.value for e in elements]
        for e in elements:
          sink_op.output_buffer.append(e)
      return
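# _run_map_task only needs a push()/pull() pair from its control handler and
# output_stream()/input_elements() from its data plane handler. A purely
# hypothetical in-process control handler, assuming a worker object that
# exposes a do_instruction(request) method returning an InstructionResponse,
# could look like this (names here are illustrative, not the runner's actual
# classes):
class DirectControlHandler(object):

  def __init__(self, worker):
    self._worker = worker
    self._responses = []

  def push(self, request):
    # Execute the instruction synchronously and queue its response for pull().
    self._responses.append(self._worker.do_instruction(request))

  def pull(self):
    return self._responses.pop(0)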
def _map_task_registration(self, map_task, state_handler,
                           data_operation_spec):
  input_data = {}
  runner_sinks = {}
  transforms = []
  transform_index_to_id = {}

  # Maps coders to new coder objects and references.
  coders = {}

  def coder_id(coder):
    if coder not in coders:
      coders[coder] = beam_fn_api_pb2.Coder(
          function_spec=sdk_worker.pack_function_spec_data(
              json.dumps(coder.as_cloud_object()),
              sdk_worker.PYTHON_CODER_URN, id=self._next_uid()))
    return coders[coder].function_spec.id

  def output_tags(op):
    return getattr(op, 'output_tags', ['out'])

  def as_target(op_input):
    input_op_index, input_output_index = op_input
    input_op = map_task[input_op_index][1]
    return {
        'ignored_input_tag':
            beam_fn_api_pb2.Target.List(target=[
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_index_to_id[
                        input_op_index],
                    name=output_tags(input_op)[input_output_index])
            ])
    }

  def outputs(op):
    return {
        tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
        for tag, coder in zip(output_tags(op), op.output_coders)
    }

  for op_ix, (stage_name, operation) in enumerate(map_task):
    transform_id = transform_index_to_id[op_ix] = self._next_uid()

    if isinstance(operation, operation_specs.WorkerInMemoryWrite):
      # Write this data back to the runner.
      fn = beam_fn_api_pb2.FunctionSpec(
          urn=sdk_worker.DATA_OUTPUT_URN, id=self._next_uid())
      if data_operation_spec:
        fn.data.Pack(data_operation_spec)
      inputs = as_target(operation.input)
      side_inputs = {}
      runner_sinks[(transform_id, 'out')] = operation

    elif isinstance(operation, operation_specs.WorkerRead):
      # A Read is either translated to a direct injection of windowed values
      # into the sdk worker, or an injection of the source object into the
      # sdk worker as data followed by an SDF that reads that source.
      if (isinstance(operation.source.source,
                     worker_runner_base.InMemorySource)
          and isinstance(operation.source.source.default_output_coder(),
                         WindowedValueCoder)):
        output_stream = create_OutputStream()
        element_coder = (
            operation.source.source.default_output_coder().get_impl())
        # Re-encode the elements in the nested context and
        # concatenate them together.
        for element in operation.source.source.read(None):
          element_coder.encode_to_stream(element, output_stream, True)
        target_name = self._next_uid()
        input_data[(transform_id, target_name)] = output_stream.get()
        fn = beam_fn_api_pb2.FunctionSpec(
            urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid())
        if data_operation_spec:
          fn.data.Pack(data_operation_spec)
        inputs = {target_name: beam_fn_api_pb2.Target.List()}
        side_inputs = {}
      else:
        # Read the source object from the runner.
        source_coder = beam.coders.DillCoder()
        input_transform_id = self._next_uid()
        output_stream = create_OutputStream()
        source_coder.get_impl().encode_to_stream(
            GlobalWindows.windowed_value(operation.source),
            output_stream, True)
        target_name = self._next_uid()
        input_data[(input_transform_id, target_name)] = output_stream.get()
        input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
            id=input_transform_id,
            function_spec=beam_fn_api_pb2.FunctionSpec(
                urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid()),
            # TODO(robertwb): Possible name collision.
            step_name=stage_name + '/inject_source',
            inputs={target_name: beam_fn_api_pb2.Target.List()},
            outputs={
                'out': beam_fn_api_pb2.PCollection(
                    coder_reference=coder_id(source_coder))
            })
        if data_operation_spec:
          input_ptransform.function_spec.data.Pack(data_operation_spec)
        transforms.append(input_ptransform)

        # Read the elements out of the source.
        fn = sdk_worker.pack_function_spec_data(
            OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
            sdk_worker.PYTHON_DOFN_URN,
            id=self._next_uid())
        inputs = {
            'ignored_input_tag':
                beam_fn_api_pb2.Target.List(target=[
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=input_transform_id,
                        name='out')
                ])
        }
        side_inputs = {}

    elif isinstance(operation, operation_specs.WorkerDoFn):
      fn = sdk_worker.pack_function_spec_data(
          operation.serialized_fn,
          sdk_worker.PYTHON_DOFN_URN,
          id=self._next_uid())
      inputs = as_target(operation.input)
      side_inputs = {}
      # Store the contents of each side input for state access.
      for si in operation.side_inputs:
        assert isinstance(si.source, iobase.BoundedSource)
        element_coder = si.source.default_output_coder()
        view_id = self._next_uid()
        # TODO(robertwb): Actually flesh out the ViewFn API.
        side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
            view_fn=sdk_worker.serialize_and_pack_py_fn(
                element_coder,
                urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
                id=view_id))
        # Re-encode the elements in the nested context and
        # concatenate them together.
        output_stream = create_OutputStream()
        for element in si.source.read(
            si.source.get_range_tracker(None, None)):
          element_coder.get_impl().encode_to_stream(
              element, output_stream, True)
        elements_data = output_stream.get()
        state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id)
        state_handler.Clear(state_key)
        state_handler.Append(
            beam_fn_api_pb2.SimpleStateAppendRequest(
                state_key=state_key, data=[elements_data]))

    elif isinstance(operation, operation_specs.WorkerFlatten):
      fn = sdk_worker.pack_function_spec_data(
          operation.serialized_fn,
          sdk_worker.IDENTITY_DOFN_URN,
          id=self._next_uid())
      inputs = {
          'ignored_input_tag':
              beam_fn_api_pb2.Target.List(target=[
                  beam_fn_api_pb2.Target(
                      primitive_transform_reference=transform_index_to_id[
                          input_op_index],
                      name=output_tags(
                          map_task[input_op_index][1])[input_output_index])
                  for input_op_index, input_output_index in operation.inputs
              ])
      }
      side_inputs = {}

    else:
      raise TypeError(operation)

    ptransform = beam_fn_api_pb2.PrimitiveTransform(
        id=transform_id,
        function_spec=fn,
        step_name=stage_name,
        inputs=inputs,
        side_inputs=side_inputs,
        outputs=outputs(operation))
    transforms.append(ptransform)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      coders=coders.values(),
      primitive_transform=transforms)
  return beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor])
      ), runner_sinks, input_data
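# OLDE_SOURCE_SPLITTABLE_DOFN_DATA, used in the source-injection branch above,
# is a pre-serialized DoFn payload that is not shown in this section.
# Conceptually it expands an injected SourceBundle into the elements that
# source contains; a rough, hypothetical equivalent of that DoFn body is
# sketched below (the helper name and the assumption that the element is a
# SourceBundle are illustrative only):
def _read_source_bundle(source_bundle):
  # source_bundle.source is a BoundedSource; reading its full assigned range
  # yields the elements the injected source contributes to the bundle.
  range_tracker = source_bundle.source.get_range_tracker(
      source_bundle.start_position, source_bundle.stop_position)
  for element in source_bundle.source.read(range_tracker):
    yield element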