def Expand(self, request, context=None):
  try:
    pipeline = beam_pipeline.Pipeline(options=self._options)

    def with_pipeline(component, pcoll_id=None):
      component.pipeline = pipeline
      if pcoll_id:
        component.producer, component.tag = producers[pcoll_id]
        # We need the lookup to resolve back to this id.
        context.pcollections._obj_to_id[component] = pcoll_id
      return component

    context = pipeline_context.PipelineContext(
        request.components,
        default_environment=self._default_environment,
        namespace=request.namespace)
    producers = {
        pcoll_id: (context.transforms.get_by_id(t_id), pcoll_tag)
        for t_id, t_proto in request.components.transforms.items()
        for pcoll_tag, pcoll_id in t_proto.outputs.items()
    }
    transform = with_pipeline(
        ptransform.PTransform.from_runner_api(request.transform, context))
    inputs = transform._pvaluish_from_dict({
        tag: with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id)
        for tag, pcoll_id in request.transform.inputs.items()
    })
    if not inputs:
      inputs = pipeline
    with external.ExternalTransform.outer_namespace(request.namespace):
      result = pipeline.apply(
          transform, inputs, request.transform.unique_name)
    expanded_transform = pipeline._root_transform().parts[-1]
    # TODO(BEAM-1833): Use named outputs internally.
    if isinstance(result, dict):
      expanded_transform.outputs = result
    pipeline_proto = pipeline.to_runner_api(context=context)
    # TODO(BEAM-1833): Use named inputs internally.
    expanded_transform_id = context.transforms.get_id(expanded_transform)
    expanded_transform_proto = pipeline_proto.components.transforms.pop(
        expanded_transform_id)
    expanded_transform_proto.inputs.clear()
    expanded_transform_proto.inputs.update(request.transform.inputs)
    for transform_id in pipeline_proto.root_transform_ids:
      del pipeline_proto.components.transforms[transform_id]
    return beam_expansion_api_pb2.ExpansionResponse(
        components=pipeline_proto.components,
        transform=expanded_transform_proto,
        requirements=pipeline_proto.requirements)
  except Exception:  # pylint: disable=broad-except
    return beam_expansion_api_pb2.ExpansionResponse(
        error=traceback.format_exc())
def test_serialization(self):
  context = pipeline_context.PipelineContext()
  float_coder_ref = context.coders.get_id(coders.FloatCoder())
  bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
  proto = context.to_runner_api()
  context2 = pipeline_context.PipelineContext.from_runner_api(proto)
  self.assertEqual(
      coders.FloatCoder(), context2.coders.get_by_id(float_coder_ref))
  self.assertEqual(
      coders.BytesCoder(), context2.coders.get_by_id(bytes_coder_ref))
def parse_coder(self, spec):
  context = pipeline_context.PipelineContext()
  coder_id = str(hash(str(spec)))
  component_ids = [
      context.coders.get_id(self.parse_coder(c))
      for c in spec.get('components', ())
  ]
  context.coders.put_proto(
      coder_id,
      beam_runner_api_pb2.Coder(
          spec=beam_runner_api_pb2.FunctionSpec(
              urn=spec['urn'],
              payload=spec.get('payload', '').encode('latin1')),
          component_coder_ids=component_ids))
  return context.coders.get_by_id(coder_id)
def run_pipeline(self, pipeline):
  docker_image = (
      pipeline.options.view_as(PortableOptions).harness_docker_image
      or self.default_docker_image())
  job_endpoint = pipeline.options.view_as(PortableOptions).job_endpoint
  if not job_endpoint:
    raise ValueError(
        'job_endpoint should be provided while creating runner.')

  proto_context = pipeline_context.PipelineContext(
      default_environment_url=docker_image)
  proto_pipeline = pipeline.to_runner_api(context=proto_context)

  if not self.is_embedded_fnapi_runner:
    # Java has different expectations about coders
    # (windowed in Fn API, but *un*windowed in runner API), whereas the
    # embedded FnApiRunner treats them consistently, so we must guard this
    # for now, until FnApiRunner is fixed.
    # See also BEAM-2717.
    for pcoll in proto_pipeline.components.pcollections.values():
      if pcoll.coder_id not in proto_context.coders:
        # This is not really a coder id, but a pickled coder.
        coder = coders.registry.get_coder(pickler.loads(pcoll.coder_id))
        pcoll.coder_id = proto_context.coders.get_id(coder)
    proto_context.coders.populate_map(proto_pipeline.components.coders)

  # Some runners won't detect the GroupByKey transform unless it has no
  # subtransforms. Remove all sub-transforms until BEAM-4605 is resolved.
  for _, transform_proto in list(
      proto_pipeline.components.transforms.items()):
    if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
      for sub_transform in transform_proto.subtransforms:
        del proto_pipeline.components.transforms[sub_transform]
      del transform_proto.subtransforms[:]

  job_service = beam_job_api_pb2_grpc.JobServiceStub(
      grpc.insecure_channel(job_endpoint))
  prepare_response = job_service.Prepare(
      beam_job_api_pb2.PrepareJobRequest(
          job_name='job', pipeline=proto_pipeline))
  if prepare_response.artifact_staging_endpoint.url:
    stager = portable_stager.PortableStager(
        grpc.insecure_channel(
            prepare_response.artifact_staging_endpoint.url),
        prepare_response.staging_session_token)
    retrieval_token, _ = stager.stage_job_resources(
        pipeline._options, staging_location='')
  else:
    retrieval_token = None
  run_response = job_service.Run(
      beam_job_api_pb2.RunJobRequest(
          preparation_id=prepare_response.preparation_id,
          retrieval_token=retrieval_token))
  return PipelineResult(job_service, run_response.job_id)
def deserialize_windowing_strategy(cls, serialized_data):
  # Imported here to avoid circular dependencies.
  # pylint: disable=wrong-import-order, wrong-import-position
  from apache_beam.runners import pipeline_context
  from apache_beam.portability.api import beam_runner_api_pb2
  from apache_beam.transforms.core import Windowing
  proto = beam_runner_api_pb2.MessageWithComponents()
  proto.ParseFromString(cls.json_string_to_byte_array(serialized_data))
  return Windowing.from_runner_api(
      proto.windowing_strategy,
      pipeline_context.PipelineContext(proto.components))
def to_runner_api(self):
  from apache_beam.runners import pipeline_context
  from apache_beam.runners.api import beam_runner_api_pb2
  context = pipeline_context.PipelineContext()
  # Mutates context; placing inline would force dependence on
  # argument evaluation order.
  root_transform_id = context.transforms.get_id(self._root_transform())
  proto = beam_runner_api_pb2.Pipeline(
      root_transform_id=root_transform_id,
      components=context.to_runner_api())
  return proto
def to_runner_api(self): """For internal use only; no backwards-compatibility guarantees.""" from apache_beam.runners import pipeline_context from apache_beam.portability.api import beam_runner_api_pb2 context = pipeline_context.PipelineContext() # Mutates context; placing inline would force dependence on # argument evaluation order. root_transform_id = context.transforms.get_id(self._root_transform()) proto = beam_runner_api_pb2.Pipeline( root_transform_ids=[root_transform_id], components=context.to_runner_api()) return proto
def test_windowing_encoding(self):
  for windowing in (
      Windowing(GlobalWindows()),
      Windowing(FixedWindows(1, 3), AfterCount(6),
                accumulation_mode=AccumulationMode.ACCUMULATING),
      Windowing(SlidingWindows(10, 15, 21), AfterCount(28),
                timestamp_combiner=TimestampCombiner.OUTPUT_AT_LATEST,
                accumulation_mode=AccumulationMode.DISCARDING)):
    context = pipeline_context.PipelineContext()
    self.assertEqual(
        windowing,
        Windowing.from_runner_api(windowing.to_runner_api(context), context))
def from_runner_api(proto, runner, options):
  p = Pipeline(runner=runner, options=options)
  from apache_beam.runners import pipeline_context
  context = pipeline_context.PipelineContext(proto.components)
  root_transform_id, = proto.root_transform_ids
  p.transforms_stack = [context.transforms.get_by_id(root_transform_id)]
  # TODO(robertwb): These are only needed to continue construction. Omit?
  p.applied_labels = set(
      [t.unique_name for t in proto.components.transforms.values()])
  for id in proto.components.pcollections:
    context.pcollections.get_by_id(id).pipeline = p
  return p
def to_runner_api(self,
                  return_context=False,
                  context=None,
                  use_fake_coders=False,
                  default_environment=None):
  """For internal use only; no backwards-compatibility guarantees."""
  from apache_beam.runners import pipeline_context
  from apache_beam.portability.api import beam_runner_api_pb2
  if context is None:
    context = pipeline_context.PipelineContext(
        use_fake_coders=use_fake_coders,
        default_environment=default_environment)
  elif default_environment is not None:
    raise ValueError(
        'Only one of context or default_environment may be specified.')

  # The RunnerAPI spec requires certain transforms to have KV inputs
  # (and corresponding outputs).
  # Currently we only upgrade to KV pairs. If there is a need for more
  # general shapes, potential conflicts will have to be resolved.
  # We also only handle single-input, and (for fixing the output) single
  # output, which is sufficient.
  class ForceKvInputTypes(PipelineVisitor):
    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      if (transform_node.transform and
          transform_node.transform.runner_api_requires_keyed_input()):
        pcoll = transform_node.inputs[0]
        pcoll.element_type = typehints.coerce_to_kv_type(
            pcoll.element_type, transform_node.full_label)
        if len(transform_node.outputs) == 1:
          # The runner often has expectations about the output types as well.
          output, = transform_node.outputs.values()
          output.element_type = transform_node.transform.infer_output_type(
              pcoll.element_type)

  self.visit(ForceKvInputTypes())

  # Mutates context; placing inline would force dependence on
  # argument evaluation order.
  root_transform_id = context.transforms.get_id(self._root_transform())
  proto = beam_runner_api_pb2.Pipeline(
      root_transform_ids=[root_transform_id],
      components=context.to_runner_api())
  proto.components.transforms[root_transform_id].unique_name = (
      root_transform_id)
  if return_context:
    return proto, context
  else:
    return proto
def test_trigger_encoding(self):
  for trigger_fn in (DefaultTrigger(),
                     AfterAll(AfterCount(1), AfterCount(10)),
                     AfterAny(AfterCount(10), AfterCount(100)),
                     AfterWatermark(early=AfterCount(1000)),
                     AfterWatermark(early=AfterCount(1000),
                                    late=AfterCount(1)),
                     Repeatedly(AfterCount(100)),
                     trigger.OrFinally(AfterCount(3), AfterCount(10))):
    context = pipeline_context.PipelineContext()
    self.assertEqual(
        trigger_fn,
        TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context))
def run_pipeline(self, pipeline):
  # Java has different expectations about coders
  # (windowed in Fn API, but *un*windowed in runner API), whereas the
  # FnApiRunner treats them consistently, so we must guard this.
  # See also BEAM-2717.
  proto_context = pipeline_context.PipelineContext(
      default_environment_url=self._docker_image)
  proto_pipeline = pipeline.to_runner_api(context=proto_context)
  if self._runner_api_address:
    for pcoll in proto_pipeline.components.pcollections.values():
      if pcoll.coder_id not in proto_context.coders:
        coder = coders.registry.get_coder(pickler.loads(pcoll.coder_id))
        pcoll.coder_id = proto_context.coders.get_id(coder)
    proto_context.coders.populate_map(proto_pipeline.components.coders)

  # Some runners won't detect the GroupByKey transform unless it has no
  # subtransforms. Remove all sub-transforms until BEAM-4605 is resolved.
  for _, transform_proto in list(
      proto_pipeline.components.transforms.items()):
    if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
      for sub_transform in transform_proto.subtransforms:
        del proto_pipeline.components.transforms[sub_transform]
      del transform_proto.subtransforms[:]

  job_service = self._create_job_service()
  prepare_response = job_service.Prepare(
      beam_job_api_pb2.PrepareJobRequest(
          job_name='job', pipeline=proto_pipeline))
  if prepare_response.artifact_staging_endpoint.url:
    # Must commit something to get a retrieval token,
    # committing empty manifest for now.
    # TODO(BEAM-3883): Actually stage required files.
    artifact_service = beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(
        grpc.insecure_channel(
            prepare_response.artifact_staging_endpoint.url))
    commit_manifest = artifact_service.CommitManifest(
        beam_artifact_api_pb2.CommitManifestRequest(
            manifest=beam_artifact_api_pb2.Manifest(),
            staging_session_token=prepare_response.staging_session_token))
    retrieval_token = commit_manifest.retrieval_token
  else:
    retrieval_token = None
  run_response = job_service.Run(
      beam_job_api_pb2.RunJobRequest(
          preparation_id=prepare_response.preparation_id,
          retrieval_token=retrieval_token))
  return PipelineResult(job_service, run_response.job_id)
def __init__(self, descriptor, data_channel_factory, counter_factory,
             state_sampler, state_handler):
  self.descriptor = descriptor
  self.data_channel_factory = data_channel_factory
  self.counter_factory = counter_factory
  self.state_sampler = state_sampler
  self.state_handler = state_handler
  self.context = pipeline_context.PipelineContext(
      descriptor,
      iterable_state_read=lambda token, element_coder_impl:
      _StateBackedIterable(
          state_handler,
          beam_fn_api_pb2.StateKey(
              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
          element_coder_impl))
def _test_runner_api_round_trip(self, transform, urn):
  context = pipeline_context.PipelineContext()
  proto = transform.to_runner_api(context)
  self.assertEqual(urn, proto.urn)
  payload = (
      proto_utils.parse_Bytes(
          proto.payload, beam_runner_api_pb2.GroupIntoBatchesPayload))
  self.assertEqual(transform.params.batch_size, payload.batch_size)
  self.assertEqual(
      transform.params.max_buffering_duration_secs * 1000,
      payload.max_buffering_duration_millis)

  transform_from_proto = (
      transform.__class__.from_runner_api_parameter(None, payload, None))
  self.assertIsInstance(transform_from_proto, transform.__class__)
  self.assertEqual(transform.params, transform_from_proto.params)
def test_no_output_coder(self):
  external_transform = beam.ExternalTransform(
      'map_to_union_types', None,
      expansion_service.ExpansionServiceServicer())
  with beam.Pipeline() as p:
    res = (p | beam.Create([2, 2], reshuffle=False) | external_transform)
    assert_that(res, equal_to([2, 2]))

  context = pipeline_context.PipelineContext(
      external_transform._expanded_components)
  self.assertEqual(len(external_transform._expanded_transform.outputs), 1)
  for _, pcol_id in external_transform._expanded_transform.outputs.items():
    pcol = context.pcollections.get_by_id(pcol_id)
    self.assertEqual(pcol.element_type, typehints.Any)
def run_pipeline(self, pipeline):
  docker_image = (
      pipeline.options.view_as(PortableOptions).harness_docker_image
      or self.default_docker_image())
  job_endpoint = pipeline.options.view_as(PortableOptions).job_endpoint
  if not job_endpoint:
    raise ValueError(
        'job_endpoint should be provided while creating runner.')

  proto_context = pipeline_context.PipelineContext(
      default_environment_url=docker_image)
  proto_pipeline = pipeline.to_runner_api(context=proto_context)

  # Some runners won't detect the GroupByKey transform unless it has no
  # subtransforms. Remove all sub-transforms until BEAM-4605 is resolved.
  for _, transform_proto in list(
      proto_pipeline.components.transforms.items()):
    if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
      for sub_transform in transform_proto.subtransforms:
        del proto_pipeline.components.transforms[sub_transform]
      del transform_proto.subtransforms[:]

  job_service = beam_job_api_pb2_grpc.JobServiceStub(
      grpc.insecure_channel(job_endpoint))
  prepare_response = job_service.Prepare(
      beam_job_api_pb2.PrepareJobRequest(
          job_name='job', pipeline=proto_pipeline))
  if prepare_response.artifact_staging_endpoint.url:
    # Must commit something to get a retrieval token,
    # committing empty manifest for now.
    # TODO(BEAM-3883): Actually stage required files.
    artifact_service = beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(
        grpc.insecure_channel(
            prepare_response.artifact_staging_endpoint.url))
    commit_manifest = artifact_service.CommitManifest(
        beam_artifact_api_pb2.CommitManifestRequest(
            manifest=beam_artifact_api_pb2.Manifest(),
            staging_session_token=prepare_response.staging_session_token))
    retrieval_token = commit_manifest.retrieval_token
  else:
    retrieval_token = None
  run_response = job_service.Run(
      beam_job_api_pb2.RunJobRequest(
          preparation_id=prepare_response.preparation_id,
          retrieval_token=retrieval_token))
  return PipelineResult(job_service, run_response.job_id)
def check_coder(self, coder, *values):
  self._observe(coder)
  for v in values:
    self.assertEqual(v, coder.decode(coder.encode(v)))
    self.assertEqual(coder.estimate_size(v), len(coder.encode(v)))
    self.assertEqual(coder.estimate_size(v),
                     coder.get_impl().estimate_size(v))
    self.assertEqual(coder.get_impl().get_estimated_size_and_observables(v),
                     (coder.get_impl().estimate_size(v), []))
  copy1 = dill.loads(dill.dumps(coder))
  context = pipeline_context.PipelineContext()
  copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
  for v in values:
    self.assertEqual(v, copy1.decode(copy2.encode(v)))
    if coder.is_deterministic():
      self.assertEqual(copy1.encode(v), copy2.encode(v))
def check_coder(self, coder, *values, **kwargs):
  context = kwargs.pop('context', pipeline_context.PipelineContext())
  test_size_estimation = kwargs.pop('test_size_estimation', True)
  assert not kwargs
  self._observe(coder)
  for v in values:
    self.assertEqual(v, coder.decode(coder.encode(v)))
    if test_size_estimation:
      self.assertEqual(coder.estimate_size(v), len(coder.encode(v)))
      self.assertEqual(coder.estimate_size(v),
                       coder.get_impl().estimate_size(v))
      self.assertEqual(
          coder.get_impl().get_estimated_size_and_observables(v),
          (coder.get_impl().estimate_size(v), []))
  copy1 = pickler.loads(pickler.dumps(coder))
  copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
  for v in values:
    self.assertEqual(v, copy1.decode(copy2.encode(v)))
    if coder.is_deterministic():
      self.assertEqual(copy1.encode(v), copy2.encode(v))
def test_environment_encoding(self):
  for environment in (DockerEnvironment(),
                      DockerEnvironment(container_image='img'),
                      ProcessEnvironment('run.sh'),
                      ProcessEnvironment('run.sh', os='linux', arch='amd64',
                                         env={'k1': 'v1'}),
                      ExternalEnvironment('localhost:8080'),
                      ExternalEnvironment('localhost:8080',
                                          params={'k1': 'v1'}),
                      EmbeddedPythonEnvironment(),
                      EmbeddedPythonGrpcEnvironment(),
                      EmbeddedPythonGrpcEnvironment(state_cache_size=0),
                      SubprocessSDKEnvironment(command_string=u'foö')):
    context = pipeline_context.PipelineContext()
    self.assertEqual(
        environment,
        Environment.from_runner_api(
            environment.to_runner_api(context), context))
def test_equal_environments_are_deduplicated_when_fetched_by_obj_or_proto(
    self):
  context = pipeline_context.PipelineContext()

  env = environments.SubprocessSDKEnvironment(command_string="foo")
  env_proto = env.to_runner_api(None)
  id_from_proto = context.environments.get_by_proto(env_proto)
  id_from_obj = context.environments.get_id(env)
  self.assertEqual(id_from_obj, id_from_proto)
  self.assertEqual(
      context.environments.get_by_id(id_from_obj).command_string, "foo")

  env = environments.SubprocessSDKEnvironment(command_string="bar")
  env_proto = env.to_runner_api(None)
  id_from_obj = context.environments.get_id(env)
  id_from_proto = context.environments.get_by_proto(
      env_proto, deduplicate=True)
  self.assertEqual(id_from_obj, id_from_proto)
  self.assertEqual(
      context.environments.get_by_id(id_from_obj).command_string, "bar")
def __init__(
    self,
    stages,  # type: List[translations.Stage]
    worker_handler_manager,  # type: worker_handlers.WorkerHandlerManager
    pipeline_components,  # type: beam_runner_api_pb2.Components
    safe_coders: translations.SafeCoderMapping,
    data_channel_coders: Dict[str, str],
) -> None:
  """
  :param worker_handler_manager: This class manages the set of worker
      handlers, and the communication with state / control APIs.
  :param pipeline_components: (beam_runner_api_pb2.Components)
  :param safe_coders: A map from Coder ID to Safe Coder ID.
  :param data_channel_coders: A map from PCollection ID to the ID of the
      Coder for that PCollection.
  """
  self.stages = stages
  self.side_input_descriptors_by_stage = (
      self._build_data_side_inputs_map(stages))
  self.pcoll_buffers = {}  # type: MutableMapping[bytes, PartitionableBuffer]
  self.timer_buffers = {}  # type: MutableMapping[bytes, ListBuffer]
  self.worker_handler_manager = worker_handler_manager
  self.pipeline_components = pipeline_components
  self.safe_coders = safe_coders
  self.data_channel_coders = data_channel_coders

  self.input_transform_to_buffer_id = {
      t.unique_name: t.spec.payload
      for s in stages
      for t in s.transforms
      if t.spec.urn == bundle_processor.DATA_INPUT_URN
  }
  self.watermark_manager = WatermarkManager(stages)
  self.pipeline_context = pipeline_context.PipelineContext(
      self.pipeline_components,
      iterable_state_write=self._iterable_state_write)
  self._last_uid = -1
  self.safe_windowing_strategies = {
      id: self._make_safe_windowing_strategy(id)
      for id in self.pipeline_components.windowing_strategies.keys()
  }
def from_runner_api(proto,  # type: beam_runner_api_pb2.Pipeline
                    runner,  # type: PipelineRunner
                    options,  # type: PipelineOptions
                    return_context=False,
                    allow_proto_holders=False
                   ):
  # type: (...) -> Pipeline
  """For internal use only; no backwards-compatibility guarantees."""
  p = Pipeline(runner=runner, options=options)
  from apache_beam.runners import pipeline_context
  context = pipeline_context.PipelineContext(
      proto.components,
      allow_proto_holders=allow_proto_holders,
      requirements=proto.requirements)
  root_transform_id, = proto.root_transform_ids
  p.transforms_stack = [context.transforms.get_by_id(root_transform_id)]
  # TODO(robertwb): These are only needed to continue construction. Omit?
  p.applied_labels = set(
      [t.unique_name for t in proto.components.transforms.values()])

  for id in proto.components.pcollections:
    pcollection = context.pcollections.get_by_id(id)
    pcollection.pipeline = p
    if not pcollection.producer:
      raise ValueError('No producer for %s' % id)

  # Inject PBegin input where necessary.
  from apache_beam.io.iobase import Read
  from apache_beam.transforms.core import Create
  has_pbegin = [Read, Create]
  for id in proto.components.transforms:
    transform = context.transforms.get_by_id(id)
    if not transform.inputs and transform.transform.__class__ in has_pbegin:
      transform.inputs = (pvalue.PBegin(p), )

  if return_context:
    return p, context  # type: ignore  # too complicated for now
  else:
    return p
def test_environment_encoding(self):
  for environment in (DockerEnvironment(),
                      DockerEnvironment(container_image='img'),
                      DockerEnvironment(capabilities=['x, y, z']),
                      ProcessEnvironment('run.sh'),
                      ProcessEnvironment('run.sh', os='linux', arch='amd64',
                                         env={'k1': 'v1'}),
                      ExternalEnvironment('localhost:8080'),
                      ExternalEnvironment('localhost:8080',
                                          params={'k1': 'v1'}),
                      EmbeddedPythonEnvironment(),
                      EmbeddedPythonGrpcEnvironment(),
                      EmbeddedPythonGrpcEnvironment(
                          state_cache_size=0, data_buffer_time_limit_ms=0),
                      SubprocessSDKEnvironment(command_string=u'foö')):
    context = pipeline_context.PipelineContext()
    proto = environment.to_runner_api(context)
    reconstructed = Environment.from_runner_api(proto, context)
    self.assertEqual(environment, reconstructed)
    self.assertEqual(proto, reconstructed.to_runner_api(context))
def __init__(self,
             worker_handler_manager,  # type: worker_handlers.WorkerHandlerManager
             pipeline_components,  # type: beam_runner_api_pb2.Components
             safe_coders,
             data_channel_coders,
            ):
  """
  :param worker_handler_manager: This class manages the set of worker
      handlers, and the communication with state / control APIs.
  :param pipeline_components: (beam_runner_api_pb2.Components): TODO
  :param safe_coders:
  :param data_channel_coders:
  """
  self.pcoll_buffers = {}  # type: MutableMapping[bytes, PartitionableBuffer]
  self.worker_handler_manager = worker_handler_manager
  self.pipeline_components = pipeline_components
  self.safe_coders = safe_coders
  self.data_channel_coders = data_channel_coders
  self.pipeline_context = pipeline_context.PipelineContext(
      self.pipeline_components,
      iterable_state_write=self._iterable_state_write)
  self._last_uid = -1
def run_pipeline(self, pipeline):
  # Java has different expectations about coders
  # (windowed in Fn API, but *un*windowed in runner API), whereas the
  # FnApiRunner treats them consistently, so we must guard this.
  # See also BEAM-2717.
  proto_context = pipeline_context.PipelineContext(
      default_environment_url=self._docker_image)
  proto_pipeline = pipeline.to_runner_api(context=proto_context)
  if self._runner_api_address:
    for pcoll in proto_pipeline.components.pcollections.values():
      if pcoll.coder_id not in proto_context.coders:
        coder = coders.registry.get_coder(pickler.loads(pcoll.coder_id))
        pcoll.coder_id = proto_context.coders.get_id(coder)
    proto_context.coders.populate_map(proto_pipeline.components.coders)

  job_service = self._get_job_service()
  prepare_response = job_service.Prepare(
      beam_job_api_pb2.PrepareJobRequest(
          job_name='job', pipeline=proto_pipeline))
  run_response = job_service.Run(
      beam_job_api_pb2.RunJobRequest(
          preparation_id=prepare_response.preparation_id))
  return PipelineResult(job_service, run_response.job_id)
def from_runner_api(proto, runner, options): """For internal use only; no backwards-compatibility guarantees.""" p = Pipeline(runner=runner, options=options) from apache_beam.runners import pipeline_context context = pipeline_context.PipelineContext(proto.components) root_transform_id, = proto.root_transform_ids p.transforms_stack = [context.transforms.get_by_id(root_transform_id)] # TODO(robertwb): These are only needed to continue construction. Omit? p.applied_labels = set( [t.unique_name for t in proto.components.transforms.values()]) for id in proto.components.pcollections: pcollection = context.pcollections.get_by_id(id) pcollection.pipeline = p # Inject PBegin input where necessary. from apache_beam.io.iobase import Read from apache_beam.transforms.core import Create has_pbegin = [Read, Create] for id in proto.components.transforms: transform = context.transforms.get_by_id(id) if not transform.inputs and transform.transform.__class__ in has_pbegin: transform.inputs = (pvalue.PBegin(p), ) return p
def test_runner_api_transformation_properties_none(self, unused_mock_pubsub):
  # Confirming that properties stay None after a runner API transformation.
  source = _PubSubSource(topic='projects/fakeprj/topics/a_topic',
                         with_attributes=True)
  transform = Read(source)

  context = pipeline_context.PipelineContext()
  proto_transform_spec = transform.to_runner_api(context)
  self.assertEqual(common_urns.composites.PUBSUB_READ.urn,
                   proto_transform_spec.urn)

  pubsub_read_payload = (
      proto_utils.parse_Bytes(
          proto_transform_spec.payload,
          beam_runner_api_pb2.PubSubReadPayload))

  proto_transform = beam_runner_api_pb2.PTransform(
      unique_name="dummy_label", spec=proto_transform_spec)

  transform_from_proto = Read.from_runner_api_parameter(
      proto_transform, pubsub_read_payload, None)

  self.assertIsNone(transform_from_proto.source.full_subscription)
  self.assertIsNone(transform_from_proto.source.id_label)
  self.assertIsNone(transform_from_proto.source.timestamp_attribute)
def run_stage(self, controller, pipeline_components, stage, pcoll_buffers,
              safe_coders):
  context = pipeline_context.PipelineContext(pipeline_components)
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          data_input[target] = pcoll_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
        else:
          transform.spec.payload = ""
      elif transform.spec.urn == urns.PARDO_TRANSFORM:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              'materialize:' + transform.inputs[tag],
              beam.pvalue.SideInputData.from_runner_api(si, None))
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug(' %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={
          transform.unique_name: transform for transform in stage.transforms
      },
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      register=beam_fn_api_pb2.RegisterRequest(
          process_bundle_descriptor=[process_bundle_descriptor]))

  process_bundle = beam_fn_api_pb2.InstructionRequest(
      instruction_id=self._next_uid(),
      process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
          process_bundle_descriptor_reference=process_bundle_descriptor.id))

  # Write all the input data to the channel.
  for (transform_id, name), pcoll_id in data_input.items():
    data_out = controller.data_plane_handler.output_stream(
        process_bundle.instruction_id,
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=name))
    for element_data in pcoll_buffers[pcoll_id]:
      data_out.write(element_data)
    data_out.close()

  # Store the required side inputs into state.
  for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
    elements_by_window = _WindowGroupingBuffer(si)
    for element_data in pcoll_buffers[pcoll_id]:
      elements_by_window.append(element_data)
    for window, elements_data in elements_by_window.items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window))
      controller.state_handler.blocking_append(
          state_key, elements_data, process_bundle.instruction_id)

  # Register and start running the bundle.
  logging.debug('Register and start running the bundle')
  controller.control_handler.push(process_bundle_registration)
  controller.control_handler.push(process_bundle)

  # Wait for the bundle to finish.
  logging.debug('Wait for the bundle to finish.')
  while True:
    result = controller.control_handler.pull()
    if result and result.instruction_id == process_bundle.instruction_id:
      if result.error:
        raise RuntimeError(result.error)
      break

  expected_targets = [
      beam_fn_api_pb2.Target(
          primitive_transform_reference=transform_id, name=output_name)
      for (transform_id, output_name), _ in data_output.items()
  ]

  # Gather all output data.
  logging.debug('Gather all output data from %s.', expected_targets)
  for output in controller.data_plane_handler.input_elements(
      process_bundle.instruction_id, expected_targets):
    target_tuple = (
        output.target.primitive_transform_reference, output.target.name)
    if target_tuple in data_output:
      pcoll_id = data_output[target_tuple]
      if pcoll_id.startswith('materialize:'):
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id].append(output.data)
      elif pcoll_id.startswith('group:'):
        # This is a grouping write, create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components.pcollections[
                  output_pcoll].windowing_strategy_id]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
        pcoll_buffers[pcoll_id].append(output.data)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
  return result
def create_stages(self, pipeline_proto):

  # First define a couple of helpers.

  def union(a, b):
    # Minimize the number of distinct sets.
    if not a or a == b:
      return b
    elif not b:
      return a
    else:
      return frozenset.union(a, b)

  class Stage(object):
    """A set of Transforms that can be sent to the worker for processing."""

    def __init__(self, name, transforms,
                 downstream_side_inputs=None, must_follow=frozenset()):
      self.name = name
      self.transforms = transforms
      self.downstream_side_inputs = downstream_side_inputs
      self.must_follow = must_follow

    def __repr__(self):
      must_follow = ', '.join(prev.name for prev in self.must_follow)
      downstream_side_inputs = ', '.join(
          str(si) for si in self.downstream_side_inputs)
      return "%s\n %s\n must follow: %s\n downstream_side_inputs: %s" % (
          self.name,
          '\n'.join([
              "%s:%s" % (transform.unique_name, transform.spec.urn)
              for transform in self.transforms
          ]),
          must_follow,
          downstream_side_inputs)

    def can_fuse(self, consumer):
      def no_overlap(a, b):
        return not a.intersection(b)
      return (not self in consumer.must_follow
              and not self.is_flatten()
              and not consumer.is_flatten()
              and no_overlap(self.downstream_side_inputs,
                             consumer.side_inputs()))

    def fuse(self, other):
      return Stage(
          "(%s)+(%s)" % (self.name, other.name),
          self.transforms + other.transforms,
          union(self.downstream_side_inputs, other.downstream_side_inputs),
          union(self.must_follow, other.must_follow))

    def is_flatten(self):
      return any(transform.spec.urn == urns.FLATTEN_TRANSFORM
                 for transform in self.transforms)

    def side_inputs(self):
      for transform in self.transforms:
        if transform.spec.urn == urns.PARDO_TRANSFORM:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for side_input in payload.side_inputs:
            yield transform.inputs[side_input]

    def has_as_main_input(self, pcoll):
      for transform in self.transforms:
        if transform.spec.urn == urns.PARDO_TRANSFORM:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          local_side_inputs = payload.side_inputs
        else:
          local_side_inputs = {}
        for local_id, pipeline_id in transform.inputs.items():
          if pcoll == pipeline_id and local_id not in local_side_inputs:
            return True

    def deduplicate_read(self):
      seen_pcolls = set()
      new_transforms = []
      for transform in self.transforms:
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          pcoll = only_element(transform.outputs.items())[1]
          if pcoll in seen_pcolls:
            continue
          seen_pcolls.add(pcoll)
        new_transforms.append(transform)
      self.transforms = new_transforms

  # Now define the "optimization" phases.

  safe_coders = {}

  def expand_gbk(stages):
    """Transforms each GBK into a write followed by a read.
    """
    good_coder_urns = set(beam.coders.Coder._known_urns.keys()) - set(
        [urns.PICKLED_CODER])
    coders = pipeline_components.coders

    for coder_id, coder_proto in coders.items():
      if coder_proto.spec.spec.urn == urns.BYTES_CODER:
        bytes_coder_id = coder_id
        break
    else:
      bytes_coder_id = unique_name(coders, 'bytes_coder')
      pipeline_components.coders[bytes_coder_id].CopyFrom(
          beam.coders.BytesCoder().to_runner_api(None))

    coder_substitutions = {}

    def wrap_unknown_coders(coder_id, with_bytes):
      if (coder_id, with_bytes) not in coder_substitutions:
        wrapped_coder_id = None
        coder_proto = coders[coder_id]
        if coder_proto.spec.spec.urn == urns.LENGTH_PREFIX_CODER:
          coder_substitutions[coder_id, with_bytes] = (
              bytes_coder_id if with_bytes else coder_id)
        elif coder_proto.spec.spec.urn in good_coder_urns:
          wrapped_components = [
              wrap_unknown_coders(c, with_bytes)
              for c in coder_proto.component_coder_ids
          ]
          if wrapped_components == list(coder_proto.component_coder_ids):
            # Use as is.
            coder_substitutions[coder_id, with_bytes] = coder_id
          else:
            wrapped_coder_id = unique_name(
                coders,
                coder_id + ("_bytes" if with_bytes else "_len_prefix"))
            coders[wrapped_coder_id].CopyFrom(coder_proto)
            coders[wrapped_coder_id].component_coder_ids[:] = [
                wrap_unknown_coders(c, with_bytes)
                for c in coder_proto.component_coder_ids
            ]
            coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
        else:
          # Not a known coder.
          if with_bytes:
            coder_substitutions[coder_id, with_bytes] = bytes_coder_id
          else:
            wrapped_coder_id = unique_name(coders, coder_id + "_len_prefix")
            len_prefix_coder_proto = beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.SdkFunctionSpec(
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=urns.LENGTH_PREFIX_CODER)),
                component_coder_ids=[coder_id])
            coders[wrapped_coder_id].CopyFrom(len_prefix_coder_proto)
            coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
        # This operation is idempotent.
        if wrapped_coder_id:
          coder_substitutions[wrapped_coder_id, with_bytes] = wrapped_coder_id
      return coder_substitutions[coder_id, with_bytes]

    def fix_pcoll_coder(pcoll):
      new_coder_id = wrap_unknown_coders(pcoll.coder_id, False)
      safe_coders[new_coder_id] = wrap_unknown_coders(pcoll.coder_id, True)
      pcoll.coder_id = new_coder_id

    for stage in stages:
      assert len(stage.transforms) == 1
      transform = stage.transforms[0]
      if transform.spec.urn == urns.GROUP_BY_KEY_TRANSFORM:
        for pcoll_id in transform.inputs.values():
          fix_pcoll_coder(pipeline_components.pcollections[pcoll_id])
        for pcoll_id in transform.outputs.values():
          fix_pcoll_coder(pipeline_components.pcollections[pcoll_id])

        # This is used later to correlate the read and write.
        param = str("group:%s" % stage.name)
        gbk_write = Stage(
            transform.unique_name + '/Write',
            [beam_runner_api_pb2.PTransform(
                unique_name=transform.unique_name + '/Write',
                inputs=transform.inputs,
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=bundle_processor.DATA_OUTPUT_URN,
                    payload=param))],
            downstream_side_inputs=frozenset(),
            must_follow=stage.must_follow)
        yield gbk_write

        yield Stage(
            transform.unique_name + '/Read',
            [beam_runner_api_pb2.PTransform(
                unique_name=transform.unique_name + '/Read',
                outputs=transform.outputs,
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=bundle_processor.DATA_INPUT_URN,
                    payload=param))],
            downstream_side_inputs=frozenset(),
            must_follow=union(frozenset([gbk_write]), stage.must_follow))
      else:
        yield stage

  def sink_flattens(stages):
    """Sink flattens and remove them from the graph.

    A flatten that cannot be sunk/fused away becomes multiple writes (to the
    same logical sink) followed by a read.
    """
    # TODO(robertwb): Actually attempt to sink rather than always materialize.
    # TODO(robertwb): Possibly fuse this into one of the stages.
    pcollections = pipeline_components.pcollections
    for stage in stages:
      assert len(stage.transforms) == 1
      transform = stage.transforms[0]
      if transform.spec.urn == urns.FLATTEN_TRANSFORM:
        # This is used later to correlate the read and writes.
        param = str("materialize:%s" % transform.unique_name)
        output_pcoll_id, = transform.outputs.values()
        output_coder_id = pcollections[output_pcoll_id].coder_id
        flatten_writes = []
        for local_in, pcoll_in in transform.inputs.items():

          if pcollections[pcoll_in].coder_id != output_coder_id:
            # Flatten inputs must all be written with the same coder as is
            # used to read them.
            pcollections[pcoll_in].coder_id = output_coder_id
            transcoded_pcollection = (
                transform.unique_name + '/Transcode/' + local_in + '/out')
            yield Stage(
                transform.unique_name + '/Transcode/' + local_in,
                [beam_runner_api_pb2.PTransform(
                    unique_name=transform.unique_name + '/Transcode/' +
                    local_in,
                    inputs={local_in: pcoll_in},
                    outputs={'out': transcoded_pcollection},
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.IDENTITY_DOFN_URN))],
                downstream_side_inputs=frozenset(),
                must_follow=stage.must_follow)
            pcollections[transcoded_pcollection].CopyFrom(
                pcollections[pcoll_in])
            pcollections[transcoded_pcollection].coder_id = output_coder_id
          else:
            transcoded_pcollection = pcoll_in

          flatten_write = Stage(
              transform.unique_name + '/Write/' + local_in,
              [beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Write/' + local_in,
                  inputs={local_in: transcoded_pcollection},
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_OUTPUT_URN,
                      payload=param))],
              downstream_side_inputs=frozenset(),
              must_follow=stage.must_follow)
          flatten_writes.append(flatten_write)
          yield flatten_write

        yield Stage(
            transform.unique_name + '/Read',
            [beam_runner_api_pb2.PTransform(
                unique_name=transform.unique_name + '/Read',
                outputs=transform.outputs,
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=bundle_processor.DATA_INPUT_URN,
                    payload=param))],
            downstream_side_inputs=frozenset(),
            must_follow=union(frozenset(flatten_writes), stage.must_follow))
      else:
        yield stage

  def annotate_downstream_side_inputs(stages):
    """Annotate each stage with fusion-prohibiting information.

    Each stage is annotated with the (transitive) set of pcollections that
    depend on this stage that are also used later in the pipeline as a
    side input.

    While theoretically this could result in O(n^2) annotations, the size of
    each set is bounded by the number of side inputs (typically much smaller
    than the number of total nodes) and the number of *distinct* side-input
    sets is also generally small (and shared due to the use of union defined
    above).

    This representation is also amenable to simple recomputation on fusion.
    """
    consumers = collections.defaultdict(list)
    all_side_inputs = set()
    for stage in stages:
      for transform in stage.transforms:
        for input in transform.inputs.values():
          consumers[input].append(stage)
      for si in stage.side_inputs():
        all_side_inputs.add(si)
    all_side_inputs = frozenset(all_side_inputs)

    downstream_side_inputs_by_stage = {}

    def compute_downstream_side_inputs(stage):
      if stage not in downstream_side_inputs_by_stage:
        downstream_side_inputs = frozenset()
        for transform in stage.transforms:
          for output in transform.outputs.values():
            if output in all_side_inputs:
              downstream_side_inputs = union(
                  downstream_side_inputs, frozenset([output]))
            for consumer in consumers[output]:
              downstream_side_inputs = union(
                  downstream_side_inputs,
                  compute_downstream_side_inputs(consumer))
        downstream_side_inputs_by_stage[stage] = downstream_side_inputs
      return downstream_side_inputs_by_stage[stage]

    for stage in stages:
      stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
    return stages

  def greedily_fuse(stages):
    """Places transforms sharing an edge in the same stage, whenever possible.
    """
    producers_by_pcoll = {}
    consumers_by_pcoll = collections.defaultdict(list)

    # Used to always reference the correct stage as the producer and
    # consumer maps are not updated when stages are fused away.
    replacements = {}

    def replacement(s):
      old_ss = []
      while s in replacements:
        old_ss.append(s)
        s = replacements[s]
      for old_s in old_ss[:-1]:
        replacements[old_s] = s
      return s

    def fuse(producer, consumer):
      fused = producer.fuse(consumer)
      replacements[producer] = fused
      replacements[consumer] = fused

    # First record the producers and consumers of each PCollection.
    for stage in stages:
      for transform in stage.transforms:
        for input in transform.inputs.values():
          consumers_by_pcoll[input].append(stage)
        for output in transform.outputs.values():
          producers_by_pcoll[output] = stage

    logging.debug('consumers\n%s', consumers_by_pcoll)
    logging.debug('producers\n%s', producers_by_pcoll)

    # Now try to fuse away all pcollections.
    for pcoll, producer in producers_by_pcoll.items():
      pcoll_as_param = str("materialize:%s" % pcoll)
      write_pcoll = None
      for consumer in consumers_by_pcoll[pcoll]:
        producer = replacement(producer)
        consumer = replacement(consumer)
        # Update consumer.must_follow set, as it's used in can_fuse.
        consumer.must_follow = frozenset(
            replacement(s) for s in consumer.must_follow)
        if producer.can_fuse(consumer):
          fuse(producer, consumer)
        else:
          # If we can't fuse, do a read + write.
          if write_pcoll is None:
            write_pcoll = Stage(
                pcoll + '/Write',
                [beam_runner_api_pb2.PTransform(
                    unique_name=pcoll + '/Write',
                    inputs={'in': pcoll},
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_OUTPUT_URN,
                        payload=pcoll_as_param))])
            fuse(producer, write_pcoll)
          if consumer.has_as_main_input(pcoll):
            read_pcoll = Stage(
                pcoll + '/Read',
                [beam_runner_api_pb2.PTransform(
                    unique_name=pcoll + '/Read',
                    outputs={'out': pcoll},
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_INPUT_URN,
                        payload=pcoll_as_param))],
                must_follow=frozenset([write_pcoll]))
            fuse(read_pcoll, consumer)
          else:
            consumer.must_follow = union(
                consumer.must_follow, frozenset([write_pcoll]))

    # Everything that was originally a stage or a replacement, but wasn't
    # replaced, should be in the final graph.
    final_stages = frozenset(stages).union(
        replacements.values()).difference(replacements.keys())

    for stage in final_stages:
      # Update all references to their final values before throwing
      # the replacement data away.
      stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
      # Two reads of the same stage may have been fused. This is unneeded.
      stage.deduplicate_read()
    return final_stages

  def sort_stages(stages):
    """Order stages suitable for sequential execution.
    """
    seen = set()
    ordered = []

    def process(stage):
      if stage not in seen:
        seen.add(stage)
        for prev in stage.must_follow:
          process(prev)
        ordered.append(stage)
    for stage in stages:
      process(stage)
    return ordered

  # Now actually apply the operations.

  pipeline_components = copy.deepcopy(pipeline_proto.components)

  # Reify coders.
  # TODO(BEAM-2717): Remove once Coders are already in proto.
  coders = pipeline_context.PipelineContext(pipeline_components).coders
  for pcoll in pipeline_components.pcollections.values():
    if pcoll.coder_id not in coders:
      window_coder = coders[
          pipeline_components.windowing_strategies[
              pcoll.windowing_strategy_id].window_coder_id]
      coder = WindowedValueCoder(
          registry.get_coder(pickler.loads(pcoll.coder_id)),
          window_coder=window_coder)
      pcoll.coder_id = coders.get_id(coder)
  coders.populate_map(pipeline_components.coders)

  known_composites = set([urns.GROUP_BY_KEY_TRANSFORM])

  def leaf_transforms(root_ids):
    for root_id in root_ids:
      root = pipeline_proto.components.transforms[root_id]
      if root.spec.urn in known_composites or not root.subtransforms:
        yield root_id
      else:
        for leaf in leaf_transforms(root.subtransforms):
          yield leaf

  # Initial set of stages are singleton leaf transforms.
  stages = [
      Stage(name, [pipeline_proto.components.transforms[name]])
      for name in leaf_transforms(pipeline_proto.root_transform_ids)
  ]

  # Apply each phase in order.
  for phase in [
      annotate_downstream_side_inputs, expand_gbk, sink_flattens,
      greedily_fuse, sort_stages]:
    logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
    stages = list(phase(stages))
    logging.debug('Stages: %s', [str(s) for s in stages])

  # Return the (possibly mutated) context and ordered set of stages.
  return pipeline_components, stages, safe_coders
def _map_task_to_protos(self, map_task, data_operation_spec):
  input_data = {}
  side_input_data = {}
  runner_sinks = {}

  context = pipeline_context.PipelineContext()
  transform_protos = {}
  used_pcollections = {}

  def uniquify(*names):
    # An injective mapping from string* to string.
    return ':'.join("%s:%d" % (name, len(name)) for name in names)

  def pcollection_id(op_ix, out_ix):
    if (op_ix, out_ix) not in used_pcollections:
      used_pcollections[op_ix, out_ix] = uniquify(
          map_task[op_ix][0], 'out', str(out_ix))
    return used_pcollections[op_ix, out_ix]

  def get_inputs(op):
    if hasattr(op, 'inputs'):
      inputs = op.inputs
    elif hasattr(op, 'input'):
      inputs = [op.input]
    else:
      inputs = []
    return {'in%s' % ix: pcollection_id(*input)
            for ix, input in enumerate(inputs)}

  def get_outputs(op_ix):
    op = map_task[op_ix][1]
    return {tag: pcollection_id(op_ix, out_ix)
            for out_ix, tag in enumerate(
                getattr(op, 'output_tags', ['out']))}

  for op_ix, (stage_name, operation) in enumerate(map_task):
    transform_id = uniquify(stage_name)

    if isinstance(operation, operation_specs.WorkerInMemoryWrite):
      # Write this data back to the runner.
      target_name = only_element(get_inputs(operation).keys())
      runner_sinks[(transform_id, target_name)] = operation
      transform_spec = beam_runner_api_pb2.FunctionSpec(
          urn=bundle_processor.DATA_OUTPUT_URN,
          any_param=proto_utils.pack_Any(data_operation_spec),
          payload=data_operation_spec.SerializeToString() \
              if data_operation_spec is not None else None)

    elif isinstance(operation, operation_specs.WorkerRead):
      # A Read from an in-memory source is done over the data plane.
      if (isinstance(operation.source.source,
                     maptask_executor_runner.InMemorySource)
          and isinstance(operation.source.source.default_output_coder(),
                         WindowedValueCoder)):
        target_name = only_element(get_outputs(op_ix).keys())
        input_data[(transform_id, target_name)] = self._reencode_elements(
            operation.source.source.read(None),
            operation.source.source.default_output_coder())
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.DATA_INPUT_URN,
            any_param=proto_utils.pack_Any(data_operation_spec),
            payload=data_operation_spec.SerializeToString() \
                if data_operation_spec is not None else None)
      else:
        # Otherwise serialize the source and execute it there.
        # TODO: Use SDFs with an initial impulse.
        # The Dataflow runner harness strips the base64 encoding. do the same
        # here until we get the same thing back that we sent in.
        source_bytes = base64.b64decode(
            pickler.dumps(operation.source.source))
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.PYTHON_SOURCE_URN,
            any_param=proto_utils.pack_Any(
                wrappers_pb2.BytesValue(value=source_bytes)),
            payload=source_bytes)

    elif isinstance(operation, operation_specs.WorkerDoFn):
      # Record the contents of each side input for access via the state api.
      side_input_extras = []
      for si in operation.side_inputs:
        assert isinstance(si.source, iobase.BoundedSource)
        element_coder = si.source.default_output_coder()
        # TODO(robertwb): Actually flesh out the ViewFn API.
        side_input_extras.append((si.tag, element_coder))
        side_input_data[
            bundle_processor.side_input_tag(transform_id, si.tag)] = (
                self._reencode_elements(
                    si.source.read(si.source.get_range_tracker(None, None)),
                    element_coder))
      augmented_serialized_fn = pickler.dumps(
          (operation.serialized_fn, side_input_extras))
      transform_spec = beam_runner_api_pb2.FunctionSpec(
          urn=bundle_processor.PYTHON_DOFN_URN,
          any_param=proto_utils.pack_Any(
              wrappers_pb2.BytesValue(value=augmented_serialized_fn)),
          payload=augmented_serialized_fn)

    elif isinstance(operation, operation_specs.WorkerFlatten):
      # Flatten is nice and simple.
      transform_spec = beam_runner_api_pb2.FunctionSpec(
          urn=bundle_processor.IDENTITY_DOFN_URN)

    else:
      raise NotImplementedError(operation)

    transform_protos[transform_id] = beam_runner_api_pb2.PTransform(
        unique_name=stage_name,
        spec=transform_spec,
        inputs=get_inputs(operation),
        outputs=get_outputs(op_ix))

  pcollection_protos = {
      name: beam_runner_api_pb2.PCollection(
          unique_name=name,
          coder_id=context.coders.get_id(
              map_task[op_id][1].output_coders[out_id]))
      for (op_id, out_id), name in used_pcollections.items()
  }
  # Must follow creation of pcollection_protos to capture used coders.
  context_proto = context.to_runner_api()
  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms=transform_protos,
      pcollections=pcollection_protos,
      coders=dict(context_proto.coders.items()),
      windowing_strategies=dict(
          context_proto.windowing_strategies.items()),
      environments=dict(context_proto.environments.items()))
  return input_data, side_input_data, runner_sinks, process_bundle_descriptor