def test_fn_registration(self):
  """Registers ProcessBundleDescriptors over the Fn control channel and
  verifies the harness worker stores them keyed by descriptor id."""
  process_bundle_descriptors = [
      beam_fn_api_pb2.ProcessBundleDescriptor(
          id=str(100 + ix),
          transforms={
              str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))
          }) for ix in range(4)
  ]

  # A control service that replies with a single register instruction
  # carrying all of the descriptors above.
  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=process_bundle_descriptors))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")  # bind an ephemeral port
  server.start()

  channel = grpc.insecure_channel("localhost:%s" % test_port)
  harness = sdk_worker.SdkHarness(channel)
  harness.run()
  # The worker should have indexed every descriptor by its id.
  self.assertEqual(
      harness.worker.fns,
      {item.id: item for item in process_bundle_descriptors})
def test_source_split(self):
  """Registers a serialized RangeSource with an SdkWorker, requests an
  initial split, and checks the returned sub-sources and relative sizes."""
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))

  # No control channel is needed for a direct worker call.
  worker = sdk_harness.SdkWorker(
      None, data_plane.GrpcClientDataChannelFactory())
  worker.register(
      beam_fn_api_pb2.RegisterRequest(process_bundle_descriptor=[
          beam_fn_api_pb2.ProcessBundleDescriptor(primitive_transform=[
              beam_fn_api_pb2.PrimitiveTransform(
                  function_spec=sdk_harness.serialize_and_pack_py_fn(
                      SourceBundle(1.0, source, None, None),
                      sdk_harness.PYTHON_SOURCE_URN,
                      id="src"))
          ])
      ]))

  split_response = worker.initial_source_split(
      beam_fn_api_pb2.InitialSourceSplitRequest(
          desired_bundle_size_bytes=30, source_reference="src"))

  # Splits round-trip through serialization; weights map to relative_size.
  self.assertEqual(expected_splits, [
      sdk_harness.unpack_and_deserialize_py_fn(s.source)
      for s in split_response.splits
  ])
  self.assertEqual([s.weight for s in expected_splits],
                   [s.relative_size for s in split_response.splits])
def test_fn_registration(self):
  """Registers descriptors built from FunctionSpecs over the control channel
  and verifies the worker indexes both the specs and the descriptors by id."""
  fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]

  process_bundle_descriptors = [
      beam_fn_api_pb2.ProcessBundleDescriptor(
          id=str(100 + ix),
          primitive_transform=[
              beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)
          ]) for ix, fn in enumerate(fns)
  ]

  # A control service that replies with a single register instruction.
  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=process_bundle_descriptors))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")  # bind an ephemeral port
  server.start()

  channel = grpc.insecure_channel("localhost:%s" % test_port)
  harness = sdk_worker.SdkHarness(channel)
  harness.run()
  # Both the nested FunctionSpecs and the descriptors end up in worker.fns.
  self.assertEqual(
      harness.worker.fns,
      {item.id: item for item in fns + process_bundle_descriptors})
def _get_process_bundles(self, prefix, size):
  """Builds `size` ProcessBundleDescriptors with ids '<prefix>-<ix>'.

  Args:
    prefix: any value usable with %s formatting; becomes the id prefix.
    size: number of descriptors to build.

  Returns:
    A list of ProcessBundleDescriptor protos, each containing one trivial
    PTransform named after its index.
  """
  # Note: the original wrapped the concatenation in a redundant outer
  # str(); "%s-%s" produces the identical id string directly.
  return [
      beam_fn_api_pb2.ProcessBundleDescriptor(
          id="%s-%s" % (prefix, ix),
          transforms={
              str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))
          }) for ix in range(size)
  ]
def test_source_split_via_instruction(self):
  """End-to-end split test: registers a Python source via a control-channel
  instruction, issues an InitialSourceSplitRequest, and checks the splits."""
  source = RangeSource(0, 100)
  expected_splits = list(source.split(30))

  # Two instructions: one register, then one split referencing "src".
  test_controller = BeamFnControlServicer([
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="register_request",
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[
                  beam_fn_api_pb2.ProcessBundleDescriptor(
                      primitive_transform=[
                          beam_fn_api_pb2.PrimitiveTransform(
                              function_spec=sdk_harness.
                              serialize_and_pack_py_fn(
                                  SourceBundle(1.0, source, None, None),
                                  sdk_harness.PYTHON_SOURCE_URN,
                                  id="src"))
                      ])
              ])),
      beam_fn_api_pb2.InstructionRequest(
          instruction_id="split_request",
          initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
              desired_bundle_size_bytes=30, source_reference="src"))
  ])

  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
  beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
  test_port = server.add_insecure_port("[::]:0")  # bind an ephemeral port
  server.start()

  channel = grpc.insecure_channel("localhost:%s" % test_port)
  harness = sdk_harness.SdkHarness(channel)
  harness.run()

  # Responses are recorded by instruction_id on the test controller.
  split_response = test_controller.responses[
      "split_request"].initial_source_split
  self.assertEqual(expected_splits, [
      sdk_harness.unpack_and_deserialize_py_fn(s.source)
      for s in split_response.splits
  ])
  self.assertEqual([s.weight for s in expected_splits],
                   [s.relative_size for s in split_response.splits])
def _build_process_bundle_descriptor(self):
  """Assembles the ProcessBundleDescriptor for this stage.

  Copies the stage's transforms plus all pipeline components (pcollections,
  coders, windowing strategies, environments) into one descriptor, and
  attaches the state API endpoint so the SDK worker can reach runner state.
  """
  res = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self.bundle_uid,
      transforms={
          transform.unique_name: transform
          for transform in self.stage.transforms
      },
      pcollections=dict(
          self.execution_context.pipeline_components.pcollections.items()),
      coders=dict(self.execution_context.pipeline_components.coders.items()),
      windowing_strategies=dict(
          self.execution_context.pipeline_components.windowing_strategies.
          items()),
      environments=dict(
          self.execution_context.pipeline_components.environments.items()),
      state_api_service_descriptor=self.state_api_service_descriptor())
  return res
def _build_process_bundle_descriptor(self):
  """Assembles the ProcessBundleDescriptor for this stage.

  Copies the stage's transforms plus all pipeline components into a single
  descriptor and attaches the state API endpoint.
  """
  # Cannot be invoked until *after* _extract_endpoints is called.
  return beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self.bundle_uid,
      transforms={
          transform.unique_name: transform
          for transform in self.stage.transforms
      },
      pcollections=dict(self.execution_context.pipeline_components.
                        pcollections.items()),
      coders=dict(
          self.execution_context.pipeline_components.coders.items()),
      windowing_strategies=dict(
          self.execution_context.pipeline_components.
          windowing_strategies.items()),
      environments=dict(self.execution_context.pipeline_components.
                        environments.items()),
      state_api_service_descriptor=self.state_api_service_descriptor())
def run_stage(self, controller, pipeline_components, stage, pcoll_buffers,
              safe_coders):
  """Runs a single fused stage as one bundle on the given controller.

  Extracts the stage's data endpoints, registers and executes the bundle,
  feeds input and side-input data, and collects the bundle's output into
  pcoll_buffers. Returns the final InstructionResponse for the bundle.
  """
  context = pipeline_context.PipelineContext(pipeline_components)
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        # The payload currently holds the pcoll id; it is replaced below
        # with the data channel spec the worker should connect to.
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          data_input[target] = pcoll_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
        else:
          transform.spec.payload = ""
      elif transform.spec.urn == urns.PARDO_TRANSFORM:
        # Record each side input so its data can be staged into state.
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              'materialize:' + transform.inputs[tag],
              beam.pvalue.SideInputData.from_runner_api(si, None))
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={
          transform.unique_name: transform
          for transform in stage.transforms
      },
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor]))

  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=process_bundle_descriptor.id))

  # Write all the input data to the channel.
  for (transform_id, name), pcoll_id in data_input.items():
    data_out = controller.data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    for element_data in pcoll_buffers[pcoll_id]:
      data_out.write(element_data)
    data_out.close()

  # Store the required side inputs into state.
  for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
    elements_by_window = _WindowGroupingBuffer(si)
    for element_data in pcoll_buffers[pcoll_id]:
      elements_by_window.append(element_data)
    for window, elements_data in elements_by_window.items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window))
      controller.state_handler.blocking_append(
          state_key, elements_data, process_bundle.instruction_id)

  # Register and start running the bundle.
  logging.debug('Register and start running the bundle')
  controller.control_handler.push(process_bundle_registration)
  controller.control_handler.push(process_bundle)

  # Wait for the bundle to finish.
  logging.debug('Wait for the bundle to finish.')
  while True:
    result = controller.control_handler.pull()
    if result and result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      break

  expected_targets = [
      beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                             name=output_name)
      for (transform_id, output_name), _ in data_output.items()
  ]

  # Gather all output data.
  logging.debug('Gather all output data from %s.', expected_targets)
  for output in controller.data_plane_handler.input_elements(
      process_bundle.instruction_id, expected_targets):
    target_tuple = (output.target.primitive_transform_reference,
                    output.target.name)
    if target_tuple in data_output:
      pcoll_id = data_output[target_tuple]
      if pcoll_id.startswith('materialize:'):
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id].append(output.data)
      elif pcoll_id.startswith('group:'):
        # This is a grouping write, create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = context.coders[
              safe_coders[pipeline_components.
                          pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[
              safe_coders[pipeline_components.
                          pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components.pcollections[output_pcoll].
              windowing_strategy_id]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
        pcoll_buffers[pcoll_id].append(output.data)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
  return result
def _map_task_to_protos(self, map_task, data_operation_spec):
  """Translates a legacy map task into Fn API runner protos.

  Returns a tuple (input_data, side_input_data, runner_sinks,
  process_bundle_descriptor) describing the data to inject, the side-input
  bytes to stage into state, the in-memory write sinks to collect from, and
  the descriptor to register with the SDK worker.
  """
  input_data = {}
  side_input_data = {}
  runner_sinks = {}

  context = pipeline_context.PipelineContext()
  transform_protos = {}
  used_pcollections = {}

  def uniquify(*names):
    # An injective mapping from string* to string.
    return ':'.join("%s:%d" % (name, len(name)) for name in names)

  def pcollection_id(op_ix, out_ix):
    # Lazily assigns a stable unique id per (operation, output) pair.
    if (op_ix, out_ix) not in used_pcollections:
      used_pcollections[op_ix, out_ix] = uniquify(
          map_task[op_ix][0], 'out', str(out_ix))
    return used_pcollections[op_ix, out_ix]

  def get_inputs(op):
    # Operations expose either 'inputs' (many), 'input' (one), or none.
    if hasattr(op, 'inputs'):
      inputs = op.inputs
    elif hasattr(op, 'input'):
      inputs = [op.input]
    else:
      inputs = []
    return {'in%s' % ix: pcollection_id(*input)
            for ix, input in enumerate(inputs)}

  def get_outputs(op_ix):
    op = map_task[op_ix][1]
    return {
        tag: pcollection_id(op_ix, out_ix)
        for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))
    }

  for op_ix, (stage_name, operation) in enumerate(map_task):
    transform_id = uniquify(stage_name)

    if isinstance(operation, operation_specs.WorkerInMemoryWrite):
      # Write this data back to the runner.
      target_name = only_element(get_inputs(operation).keys())
      runner_sinks[(transform_id, target_name)] = operation
      transform_spec = beam_runner_api_pb2.FunctionSpec(
          urn=bundle_processor.DATA_OUTPUT_URN,
          any_param=proto_utils.pack_Any(data_operation_spec),
          payload=data_operation_spec.SerializeToString() \
              if data_operation_spec is not None else None)

    elif isinstance(operation, operation_specs.WorkerRead):
      # A Read from an in-memory source is done over the data plane.
      if (isinstance(operation.source.source,
                     maptask_executor_runner.InMemorySource)
          and isinstance(operation.source.source.default_output_coder(),
                         WindowedValueCoder)):
        target_name = only_element(get_outputs(op_ix).keys())
        input_data[(transform_id, target_name)] = self._reencode_elements(
            operation.source.source.read(None),
            operation.source.source.default_output_coder())
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.DATA_INPUT_URN,
            any_param=proto_utils.pack_Any(data_operation_spec),
            payload=data_operation_spec.SerializeToString() \
                if data_operation_spec is not None else None)
      else:
        # Otherwise serialize the source and execute it there.
        # TODO: Use SDFs with an initial impulse.
        # The Dataflow runner harness strips the base64 encoding. do the same
        # here until we get the same thing back that we sent in.
        source_bytes = base64.b64decode(
            pickler.dumps(operation.source.source))
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.PYTHON_SOURCE_URN,
            any_param=proto_utils.pack_Any(
                wrappers_pb2.BytesValue(value=source_bytes)),
            payload=source_bytes)

    elif isinstance(operation, operation_specs.WorkerDoFn):
      # Record the contents of each side input for access via the state api.
      side_input_extras = []
      for si in operation.side_inputs:
        assert isinstance(si.source, iobase.BoundedSource)
        element_coder = si.source.default_output_coder()
        # TODO(robertwb): Actually flesh out the ViewFn API.
        side_input_extras.append((si.tag, element_coder))
        side_input_data[
            bundle_processor.side_input_tag(transform_id, si.tag)] = (
                self._reencode_elements(
                    si.source.read(si.source.get_range_tracker(None, None)),
                    element_coder))
      # The serialized fn carries the side-input metadata along with it.
      augmented_serialized_fn = pickler.dumps(
          (operation.serialized_fn, side_input_extras))
      transform_spec = beam_runner_api_pb2.FunctionSpec(
          urn=bundle_processor.PYTHON_DOFN_URN,
          any_param=proto_utils.pack_Any(
              wrappers_pb2.BytesValue(value=augmented_serialized_fn)),
          payload=augmented_serialized_fn)

    elif isinstance(operation, operation_specs.WorkerFlatten):
      # Flatten is nice and simple.
      transform_spec = beam_runner_api_pb2.FunctionSpec(
          urn=bundle_processor.IDENTITY_DOFN_URN)

    else:
      raise NotImplementedError(operation)

    transform_protos[transform_id] = beam_runner_api_pb2.PTransform(
        unique_name=stage_name,
        spec=transform_spec,
        inputs=get_inputs(operation),
        outputs=get_outputs(op_ix))

  pcollection_protos = {
      name: beam_runner_api_pb2.PCollection(
          unique_name=name,
          coder_id=context.coders.get_id(
              map_task[op_id][1].output_coders[out_id]))
      for (op_id, out_id), name in used_pcollections.items()
  }
  # Must follow creation of pcollection_protos to capture used coders.
  context_proto = context.to_runner_api()
  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms=transform_protos,
      pcollections=pcollection_protos,
      coders=dict(context_proto.coders.items()),
      windowing_strategies=dict(context_proto.windowing_strategies.items()),
      environments=dict(context_proto.environments.items()))
  return input_data, side_input_data, runner_sinks, process_bundle_descriptor
def _map_task_registration(self, map_task, state_handler,
                           data_operation_spec):
  """Translates a legacy map task into a RegisterRequest instruction.

  Returns a tuple (register_instruction, runner_sinks, input_data): the
  InstructionRequest registering one ProcessBundleDescriptor, the in-memory
  write operations to collect output from, and the encoded input bytes to
  push over the data plane, keyed by (transform_id, target_name).
  """
  input_data = {}
  runner_sinks = {}
  transforms = []
  transform_index_to_id = {}

  # Maps coders to new coder objects and references.
  coders = {}

  def coder_id(coder):
    # Registers the coder once and returns its function spec id.
    if coder not in coders:
      coders[coder] = beam_fn_api_pb2.Coder(
          function_spec=sdk_worker.pack_function_spec_data(
              json.dumps(coder.as_cloud_object()),
              sdk_worker.PYTHON_CODER_URN,
              id=self._next_uid()))
    return coders[coder].function_spec.id

  def output_tags(op):
    return getattr(op, 'output_tags', ['out'])

  def as_target(op_input):
    input_op_index, input_output_index = op_input
    input_op = map_task[input_op_index][1]
    return {
        'ignored_input_tag':
            beam_fn_api_pb2.Target.List(target=[
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_index_to_id[
                        input_op_index],
                    name=output_tags(input_op)[input_output_index])
            ])
    }

  def outputs(op):
    return {
        tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
        for tag, coder in zip(output_tags(op), op.output_coders)
    }

  for op_ix, (stage_name, operation) in enumerate(map_task):
    transform_id = transform_index_to_id[op_ix] = self._next_uid()
    if isinstance(operation, operation_specs.WorkerInMemoryWrite):
      # Write this data back to the runner.
      fn = beam_fn_api_pb2.FunctionSpec(
          urn=sdk_worker.DATA_OUTPUT_URN, id=self._next_uid())
      if data_operation_spec:
        fn.data.Pack(data_operation_spec)
      inputs = as_target(operation.input)
      side_inputs = {}
      runner_sinks[(transform_id, 'out')] = operation

    elif isinstance(operation, operation_specs.WorkerRead):
      # A Read is either translated to a direct injection of windowed values
      # into the sdk worker, or an injection of the source object into the
      # sdk worker as data followed by an SDF that reads that source.
      if (isinstance(operation.source.source,
                     maptask_executor_runner.InMemorySource)
          and isinstance(operation.source.source.default_output_coder(),
                         WindowedValueCoder)):
        output_stream = create_OutputStream()
        element_coder = (
            operation.source.source.default_output_coder().get_impl())
        # Re-encode the elements in the nested context and
        # concatenate them together
        for element in operation.source.source.read(None):
          element_coder.encode_to_stream(element, output_stream, True)
        target_name = self._next_uid()
        input_data[(transform_id, target_name)] = output_stream.get()
        fn = beam_fn_api_pb2.FunctionSpec(
            urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid())
        if data_operation_spec:
          fn.data.Pack(data_operation_spec)
        inputs = {target_name: beam_fn_api_pb2.Target.List()}
        side_inputs = {}
      else:
        # Read the source object from the runner.
        source_coder = beam.coders.DillCoder()
        input_transform_id = self._next_uid()
        output_stream = create_OutputStream()
        source_coder.get_impl().encode_to_stream(
            GlobalWindows.windowed_value(operation.source),
            output_stream, True)
        target_name = self._next_uid()
        input_data[(input_transform_id, target_name)] = output_stream.get()
        input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
            id=input_transform_id,
            function_spec=beam_fn_api_pb2.FunctionSpec(
                urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid()),
            # TODO(robertwb): Possible name collision.
            step_name=stage_name + '/inject_source',
            inputs={target_name: beam_fn_api_pb2.Target.List()},
            outputs={
                'out': beam_fn_api_pb2.PCollection(
                    coder_reference=coder_id(source_coder))
            })
        if data_operation_spec:
          input_ptransform.function_spec.data.Pack(data_operation_spec)
        transforms.append(input_ptransform)

        # Read the elements out of the source.
        fn = sdk_worker.pack_function_spec_data(
            OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
            sdk_worker.PYTHON_DOFN_URN,
            id=self._next_uid())
        inputs = {
            'ignored_input_tag':
                beam_fn_api_pb2.Target.List(target=[
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=input_transform_id,
                        name='out')
                ])
        }
        side_inputs = {}

    elif isinstance(operation, operation_specs.WorkerDoFn):
      fn = sdk_worker.pack_function_spec_data(
          operation.serialized_fn,
          sdk_worker.PYTHON_DOFN_URN,
          id=self._next_uid())
      inputs = as_target(operation.input)
      # BUG FIX: side_inputs was previously never initialized in this
      # branch, raising NameError when the first operation is a WorkerDoFn
      # and otherwise mutating the dict left over from a prior iteration.
      side_inputs = {}
      # Store the contents of each side input for state access.
      for si in operation.side_inputs:
        assert isinstance(si.source, iobase.BoundedSource)
        element_coder = si.source.default_output_coder()
        view_id = self._next_uid()
        # TODO(robertwb): Actually flesh out the ViewFn API.
        side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
            view_fn=sdk_worker.serialize_and_pack_py_fn(
                element_coder,
                urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
                id=view_id))
        # Re-encode the elements in the nested context and
        # concatenate them together
        output_stream = create_OutputStream()
        for element in si.source.read(
            si.source.get_range_tracker(None, None)):
          element_coder.get_impl().encode_to_stream(
              element, output_stream, True)
        elements_data = output_stream.get()
        state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(key=view_id)
        state_handler.Clear(state_key)
        state_handler.Append(state_key, elements_data)

    elif isinstance(operation, operation_specs.WorkerFlatten):
      fn = sdk_worker.pack_function_spec_data(
          operation.serialized_fn,
          sdk_worker.IDENTITY_DOFN_URN,
          id=self._next_uid())
      inputs = {
          'ignored_input_tag':
              beam_fn_api_pb2.Target.List(target=[
                  beam_fn_api_pb2.Target(
                      primitive_transform_reference=transform_index_to_id[
                          input_op_index],
                      name=output_tags(
                          map_task[input_op_index][1])[input_output_index])
                  for input_op_index, input_output_index in operation.inputs
              ])
      }
      side_inputs = {}

    else:
      raise TypeError(operation)

    ptransform = beam_fn_api_pb2.PrimitiveTransform(
        id=transform_id,
        function_spec=fn,
        step_name=stage_name,
        inputs=inputs,
        side_inputs=side_inputs,
        outputs=outputs(operation))
    transforms.append(ptransform)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      coders=coders.values(),
      primitive_transform=transforms)
  return beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor
                                    ])), runner_sinks, input_data
def run_stage(
    self, controller, pipeline_components, stage, pcoll_buffers,
    safe_coders):
  """Runs a single fused stage as one bundle (variant without side inputs).

  Rewrites IO transforms to carry the data channel spec, registers and runs
  the bundle, then drains the bundle's output into pcoll_buffers.
  """
  coders = pipeline_context.PipelineContext(pipeline_components).coders
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      pcoll_id = transform.spec.payload
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          data_input[target] = pcoll_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        # Replace the pcoll-id payload with the channel spec (both the
        # payload bytes and the typed any_param are kept in sync).
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
          transform.spec.any_param.Pack(data_operation_spec)
        else:
          transform.spec.payload = ""
          transform.spec.any_param.Clear()
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)
  if data_side_input:
    raise NotImplementedError('Side inputs.')

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={transform.unique_name: transform
                  for transform in stage.transforms},
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))
  process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor]))
  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=
          process_bundle_descriptor.id))

  # Write all the input data to the channel.
  for (transform_id, name), pcoll_id in data_input.items():
    data_out = controller.data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    for element_data in pcoll_buffers[pcoll_id]:
      data_out.write(element_data)
    data_out.close()

  # Register and start running the bundle.
  controller.control_handler.push(process_bundle_registration)
  controller.control_handler.push(process_bundle)

  # Wait for the bundle to finish.
  while True:
    result = controller.control_handler.pull()
    if result and result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      break

  # Gather all output data.
  expected_targets = [
      beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                             name=output_name)
      for (transform_id, output_name), _ in data_output.items()]
  for output in controller.data_plane_handler.input_elements(
      process_bundle.instruction_id, expected_targets):
    target_tuple = (
        output.target.primitive_transform_reference, output.target.name)
    if target_tuple in data_output:
      pcoll_id = data_output[target_tuple]
      if pcoll_id.startswith('materialize:'):
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id].append(output.data)
      elif pcoll_id.startswith('group:'):
        # This is a grouping write, create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder)
        pcoll_buffers[pcoll_id].append(output.data)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
def make_process_bundle_descriptor(self,
                                   data_api_service_descriptor,
                                   state_api_service_descriptor):
  # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor
  """Creates a ProcessBundleDescriptor for invoking the WindowFn's
  merge operation.
  """
  def make_channel_payload(coder_id):
    # type: (str) -> bytes
    # Serialized RemoteGrpcPort telling the worker which data channel
    # and coder to use for the Read/Write transforms below.
    data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
    if data_api_service_descriptor:
      data_spec.api_service_descriptor.url = (
          data_api_service_descriptor.url)
    return data_spec.SerializeToString()

  pipeline_context = self._execution_context_ref().pipeline_context
  global_windowing_strategy_id = self.uid('global_windowing_strategy')
  global_windowing_strategy_proto = core.Windowing(
      window.GlobalWindows()).to_runner_api(pipeline_context)
  coders = dict(pipeline_context.coders.get_id_to_proto_map())

  def make_coder(urn, *components):
    # type: (str, str) -> str
    # Registers a new coder proto both locally and in the pipeline
    # context, returning its fresh id.
    coder_proto = beam_runner_api_pb2.Coder(
        spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
        component_coder_ids=components)
    coder_id = self.uid('coder')
    coders[coder_id] = coder_proto
    pipeline_context.coders.put_proto(coder_id, coder_proto)
    return coder_id

  bytes_coder_id = make_coder(common_urns.coders.BYTES.urn)
  window_coder_id = self._windowing_strategy_proto.window_coder_id
  global_window_coder_id = make_coder(common_urns.coders.GLOBAL_WINDOW.urn)
  iter_window_coder_id = make_coder(common_urns.coders.ITERABLE.urn,
                                    window_coder_id)
  # Input: KV[bytes, iterable[window]]; output: KV[bytes,
  # KV[iterable[window], iterable[KV[window, iterable[window]]]]].
  input_coder_id = make_coder(common_urns.coders.KV.urn, bytes_coder_id,
                              iter_window_coder_id)
  output_coder_id = make_coder(
      common_urns.coders.KV.urn,
      bytes_coder_id,
      make_coder(
          common_urns.coders.KV.urn,
          iter_window_coder_id,
          make_coder(
              common_urns.coders.ITERABLE.urn,
              make_coder(common_urns.coders.KV.urn, window_coder_id,
                         iter_window_coder_id))))
  windowed_input_coder_id = make_coder(
      common_urns.coders.WINDOWED_VALUE.urn,
      input_coder_id,
      global_window_coder_id)
  windowed_output_coder_id = make_coder(
      common_urns.coders.WINDOWED_VALUE.urn,
      output_coder_id,
      global_window_coder_id)

  # Cache the coder impls used to encode/decode the merge requests.
  self.windowed_input_coder_impl = pipeline_context.coders[
      windowed_input_coder_id].get_impl()
  self.windowed_output_coder_impl = pipeline_context.coders[
      windowed_output_coder_id].get_impl()

  self._bundle_processor_id = self.uid('merge_windows')
  return beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._bundle_processor_id,
      transforms={
          self.TO_SDK_TRANSFORM:
              beam_runner_api_pb2.PTransform(
                  unique_name='MergeWindows/Read',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_INPUT_URN,
                      payload=make_channel_payload(windowed_input_coder_id)),
                  outputs={'input': 'input'}),
          'Merge':
              beam_runner_api_pb2.PTransform(
                  unique_name='MergeWindows/Merge',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=common_urns.primitives.MERGE_WINDOWS.urn,
                      payload=self._windowing_strategy_proto.window_fn.
                      SerializeToString()),
                  inputs={'input': 'input'},
                  outputs={'output': 'output'}),
          self.FROM_SDK_TRANSFORM:
              beam_runner_api_pb2.PTransform(
                  unique_name='MergeWindows/Write',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_OUTPUT_URN,
                      payload=make_channel_payload(
                          windowed_output_coder_id)),
                  inputs={'output': 'output'}),
      },
      pcollections={
          'input':
              beam_runner_api_pb2.PCollection(
                  unique_name='input',
                  windowing_strategy_id=global_windowing_strategy_id,
                  coder_id=input_coder_id),
          'output':
              beam_runner_api_pb2.PCollection(
                  unique_name='output',
                  windowing_strategy_id=global_windowing_strategy_id,
                  coder_id=output_coder_id),
      },
      coders=coders,
      windowing_strategies={
          global_windowing_strategy_id: global_windowing_strategy_proto,
      },
      environments=dict(self._execution_context_ref().
                        pipeline_components.environments.items()),
      state_api_service_descriptor=state_api_service_descriptor,
      timer_api_service_descriptor=data_api_service_descriptor)
def run_stage(
    self, controller, pipeline_components, stage, pcoll_buffers,
    safe_coders):
  """Runs a single fused stage via a BundleManager.

  Extracts the stage's endpoints, builds the descriptor, stages side-input
  data into runner state, and delegates bundle execution to BundleManager.
  """
  context = pipeline_context.PipelineContext(pipeline_components)
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          # Note: inputs map directly to the buffered data, not the id.
          data_input[target] = pcoll_buffers[pcoll_id]
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
        else:
          transform.spec.payload = ""
      elif transform.spec.urn == urns.PARDO_TRANSFORM:
        # Record each side input so its data can be staged into state.
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              'materialize:' + transform.inputs[tag],
              beam.pvalue.SideInputData.from_runner_api(si, None))
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={transform.unique_name: transform
                  for transform in stage.transforms},
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  # Store the required side inputs into state.
  for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
    elements_by_window = _WindowGroupingBuffer(si)
    for element_data in pcoll_buffers[pcoll_id]:
      elements_by_window.append(element_data)
    for window, elements_data in elements_by_window.items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window))
      controller.state_handler.blocking_append(state_key,
                                               elements_data,
                                               None)

  def get_buffer(pcoll_id):
    # Lazily creates the output buffer for a pcollection id: a plain list
    # for materialized output, a _GroupingBuffer for group-by-key output.
    if pcoll_id.startswith('materialize:'):
      if pcoll_id not in pcoll_buffers:
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id] = list()
    elif pcoll_id.startswith('group:'):
      # This is a grouping write, create a grouping buffer if needed.
      if pcoll_id not in pcoll_buffers:
        original_gbk_transform = pcoll_id.split(':', 1)[1]
        transform_proto = pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(transform_proto.inputs.values())
        output_pcoll = only_element(transform_proto.outputs.values())
        pre_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[input_pcoll].coder_id]]
        post_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[output_pcoll].coder_id]]
        windowing_strategy = context.windowing_strategies[
            pipeline_components
            .pcollections[output_pcoll].windowing_strategy_id]
        pcoll_buffers[pcoll_id] = _GroupingBuffer(
            pre_gbk_coder, post_gbk_coder, windowing_strategy)
    else:
      # These should be the only two identifiers we produce for now,
      # but special side input writes may go here.
      raise NotImplementedError(pcoll_id)
    return pcoll_buffers[pcoll_id]

  return BundleManager(
      controller, get_buffer, process_bundle_descriptor,
      self._progress_frequency).process_bundle(data_input, data_output)
def run_stage(
    self, worker_handler_factory, pipeline_components, stage, pcoll_buffers,
    safe_coders):
  """Executes a single fused stage, including warm-up repeats and timers.

  Args:
    worker_handler_factory: callable mapping a stage environment to a
      worker controller (data/state/provision services for that worker).
    pipeline_components: runner-API components for the whole pipeline.
    stage: the fused stage to run; data-IO transforms' ``spec.payload``
      fields are rewritten in place to RemoteGrpcPort specs below.
    pcoll_buffers: mutable mapping from buffer ids to buffered element
      data; read for inputs and extended (via get_buffer) for outputs.
    safe_coders: mapping from coder ids to ids of runner-decodable coders.

  Returns:
    The result of the main ``BundleManager.process_bundle`` call
    (timer-triggered re-executions are run but their results are not yet
    merged — see the TODO below).
  """

  def iterable_state_write(values, element_coder_impl):
    # Encodes the given values and appends them under a fresh runner state
    # token, returning the token. NOTE: closes over `controller`, which is
    # assigned *after* this def — safe because closures resolve names at
    # call time and this is only invoked once `controller` exists.
    token = unique_name(None, 'iter').encode('ascii')
    out = create_OutputStream()
    for element in values:
      element_coder_impl.encode_to_stream(element, out, True)
    controller.state.blocking_append(
        beam_fn_api_pb2.StateKey(
            runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
        out.get())
    return token

  controller = worker_handler_factory(stage.environment)
  context = pipeline_context.PipelineContext(
      pipeline_components, iterable_state_write=iterable_state_write)
  data_api_service_descriptor = controller.data_api_service_descriptor()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data ApiServiceDescriptor.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        # The payload initially names the buffer; it is replaced below
        # with a serialized RemoteGrpcPort for the worker's data channel.
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
            # Impulse inputs have no upstream buffer; feed the single
            # pre-encoded impulse element.
            data_input[target] = [ENCODED_IMPULSE_VALUE]
          else:
            data_input[target] = pcoll_buffers[pcoll_id]
          coder_id = pipeline_components.pcollections[
              only_element(transform.outputs.values())].coder_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
          coder_id = pipeline_components.pcollections[
              only_element(transform.inputs.values())].coder_id
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn == common_urns.primitives.PAR_DO.urn:
        # Record each side input as (buffer id, access pattern) keyed by
        # (consuming transform name, side-input tag).
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  # Must run before building the descriptor: it rewrites the data-IO
  # transforms' payloads in place.
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={transform.unique_name: transform
                  for transform in stage.transforms},
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  if controller.state_api_service_descriptor():
    process_bundle_descriptor.state_api_service_descriptor.url = (
        controller.state_api_service_descriptor().url)

  # Store the required side inputs into state.
  for (transform_id, tag), (buffer_id, si) in data_side_input.items():
    _, pcoll_id = split_buffer_id(buffer_id)
    value_coder = context.coders[safe_coders[
        pipeline_components.pcollections[pcoll_id].coder_id]]
    # Group already-materialized elements by key and window, then append
    # each group under a keyed multimap side-input state key.
    elements_by_window = _WindowGroupingBuffer(si, value_coder)
    for element_data in pcoll_buffers[buffer_id]:
      elements_by_window.append(element_data)
    for key, window, elements_data in elements_by_window.encoded_items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window,
              key=key))
      controller.state.blocking_append(state_key, elements_data)

  def get_buffer(buffer_id):
    # Lazily creates (and caches in pcoll_buffers) the output buffer for
    # the given id; the buffer-id kind selects the buffer type.
    kind, name = split_buffer_id(buffer_id)
    if kind in ('materialize', 'timers'):
      if buffer_id not in pcoll_buffers:
        # Just store the data chunks for replay.
        pcoll_buffers[buffer_id] = list()
    elif kind == 'group':
      # This is a grouping write, create a grouping buffer if needed.
      if buffer_id not in pcoll_buffers:
        original_gbk_transform = name
        transform_proto = pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(list(transform_proto.inputs.values()))
        output_pcoll = only_element(list(transform_proto.outputs.values()))
        # Use the runner-safe coders so the runner itself can decode the
        # keyed data while performing the grouping.
        pre_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[input_pcoll].coder_id]]
        post_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[output_pcoll].coder_id]]
        windowing_strategy = context.windowing_strategies[
            pipeline_components
            .pcollections[output_pcoll].windowing_strategy_id]
        pcoll_buffers[buffer_id] = _GroupingBuffer(
            pre_gbk_coder, post_gbk_coder, windowing_strategy)
    else:
      # These should be the only two identifiers we produce for now,
      # but special side input writes may go here.
      raise NotImplementedError(buffer_id)
    return pcoll_buffers[buffer_id]

  # Optional warm-up/stability runs: execute the bundle with empty outputs
  # against a state checkpoint, then roll the state back each time.
  for k in range(self._bundle_repeat):
    try:
      controller.state.checkpoint()
      BundleManager(
          controller, lambda pcoll_id: [], process_bundle_descriptor,
          self._progress_frequency, k).process_bundle(data_input, data_output)
    finally:
      controller.state.restore()

  result = BundleManager(
      controller, get_buffer, process_bundle_descriptor,
      self._progress_frequency).process_bundle(data_input, data_output)

  # Re-execute the bundle with any timers that fired, repeating until no
  # further timers are produced.
  while True:
    timer_inputs = {}
    for transform_id, timer_writes in stage.timer_pcollections:
      windowed_timer_coder_impl = context.coders[
          pipeline_components.pcollections[timer_writes].coder_id].get_impl()
      written_timers = get_buffer(
          create_buffer_id(timer_writes, kind='timers'))
      if written_timers:
        # Keep only the "last" timer set per key and window.
        timers_by_key_and_window = {}
        for elements_data in written_timers:
          input_stream = create_InputStream(elements_data)
          while input_stream.size() > 0:
            windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                input_stream, True)
            key, _ = windowed_key_timer.value
            # TODO: Explode and merge windows.
            assert len(windowed_key_timer.windows) == 1
            timers_by_key_and_window[
                key, windowed_key_timer.windows[0]] = windowed_key_timer
        # Re-encode the deduplicated timers as the next bundle's input.
        out = create_OutputStream()
        for windowed_key_timer in timers_by_key_and_window.values():
          windowed_timer_coder_impl.encode_to_stream(
              windowed_key_timer, out, True)
        timer_inputs[transform_id, 'out'] = [out.get()]
        # Clear in place so the shared buffer doesn't replay these timers.
        written_timers[:] = []
    if timer_inputs:
      # The worker will be waiting on these inputs as well.
      for other_input in data_input:
        if other_input not in timer_inputs:
          timer_inputs[other_input] = []
      # TODO(robertwb): merge results
      BundleManager(
          controller, get_buffer, process_bundle_descriptor,
          self._progress_frequency, True).process_bundle(
              timer_inputs, data_output)
    else:
      break
  return result