def _config_value(cls, obj, typehint):
  """Helper to create a ConfigValue with an encoded value."""
  coder = registry.get_coder(typehint)
  urns = list(iter_urns(coder))
  if 'beam:coder:pickled_python:v1' in urns:
    raise RuntimeError("Found non-portable coder for %s" % (typehint,))
  return ConfigValue(
      coder_urn=urns,
      payload=coder.get_impl().encode_nested(obj))
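# Illustrative note (not part of the original module): for a simple typehint
# such as int, registry.get_coder typically resolves to VarIntCoder, whose
# URN chain contains only portable URNs, so the helper above succeeds; a
# typehint that falls back to the pickle coder would surface
# 'beam:coder:pickled_python:v1' in iter_urns() and raise instead.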
def create_stages(self, pipeline_proto):

  # First define a couple of helpers.

  def union(a, b):
    # Minimize the number of distinct sets.
    if not a or a == b:
      return b
    elif not b:
      return a
    else:
      return frozenset.union(a, b)

  class Stage(object):
    """A set of Transforms that can be sent to the worker for processing."""

    def __init__(self, name, transforms,
                 downstream_side_inputs=None, must_follow=frozenset()):
      self.name = name
      self.transforms = transforms
      self.downstream_side_inputs = downstream_side_inputs
      self.must_follow = must_follow

    def __repr__(self):
      must_follow = ', '.join(prev.name for prev in self.must_follow)
      downstream_side_inputs = ', '.join(
          str(si) for si in self.downstream_side_inputs)
      return "%s\n %s\n must follow: %s\n downstream_side_inputs: %s" % (
          self.name,
          '\n'.join(["%s:%s" % (transform.unique_name, transform.spec.urn)
                     for transform in self.transforms]),
          must_follow,
          downstream_side_inputs)

    def can_fuse(self, consumer):
      def no_overlap(a, b):
        return not a.intersection(b)
      return (not self in consumer.must_follow
              and not self.is_flatten() and not consumer.is_flatten()
              and no_overlap(self.downstream_side_inputs,
                             consumer.side_inputs()))

    def fuse(self, other):
      return Stage(
          "(%s)+(%s)" % (self.name, other.name),
          self.transforms + other.transforms,
          union(self.downstream_side_inputs, other.downstream_side_inputs),
          union(self.must_follow, other.must_follow))

    def is_flatten(self):
      return any(transform.spec.urn == urns.FLATTEN_TRANSFORM
                 for transform in self.transforms)

    def side_inputs(self):
      for transform in self.transforms:
        if transform.spec.urn == urns.PARDO_TRANSFORM:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for side_input in payload.side_inputs:
            yield transform.inputs[side_input]

    def has_as_main_input(self, pcoll):
      for transform in self.transforms:
        if transform.spec.urn == urns.PARDO_TRANSFORM:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          local_side_inputs = payload.side_inputs
        else:
          local_side_inputs = {}
        for local_id, pipeline_id in transform.inputs.items():
          if pcoll == pipeline_id and local_id not in local_side_inputs:
            return True

    def deduplicate_read(self):
      seen_pcolls = set()
      new_transforms = []
      for transform in self.transforms:
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          pcoll = only_element(transform.outputs.items())[1]
          if pcoll in seen_pcolls:
            continue
          seen_pcolls.add(pcoll)
        new_transforms.append(transform)
      self.transforms = new_transforms

  # Now define the "optimization" phases.

  safe_coders = {}
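  # Sketch of how the helpers above interact (illustrative, not new logic):
  # if stage A produces a PCollection that some downstream stage B consumes
  # as a side input, that PCollection ends up in A.downstream_side_inputs,
  # so A.can_fuse(B) returns False and a materialization boundary is forced;
  # must_follow then records that B may only run after the corresponding
  # write stage.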
""" good_coder_urns = set(beam.coders.Coder._known_urns.keys()) - set( [urns.PICKLED_CODER]) coders = pipeline_components.coders for coder_id, coder_proto in coders.items(): if coder_proto.spec.spec.urn == urns.BYTES_CODER: bytes_coder_id = coder_id break else: bytes_coder_id = unique_name(coders, 'bytes_coder') pipeline_components.coders[bytes_coder_id].CopyFrom( beam.coders.BytesCoder().to_runner_api(None)) coder_substitutions = {} def wrap_unknown_coders(coder_id, with_bytes): if (coder_id, with_bytes) not in coder_substitutions: wrapped_coder_id = None coder_proto = coders[coder_id] if coder_proto.spec.spec.urn == urns.LENGTH_PREFIX_CODER: coder_substitutions[coder_id, with_bytes] = ( bytes_coder_id if with_bytes else coder_id) elif coder_proto.spec.spec.urn in good_coder_urns: wrapped_components = [ wrap_unknown_coders(c, with_bytes) for c in coder_proto.component_coder_ids ] if wrapped_components == list( coder_proto.component_coder_ids): # Use as is. coder_substitutions[coder_id, with_bytes] = coder_id else: wrapped_coder_id = unique_name( coders, coder_id + ("_bytes" if with_bytes else "_len_prefix")) coders[wrapped_coder_id].CopyFrom(coder_proto) coders[wrapped_coder_id].component_coder_ids[:] = [ wrap_unknown_coders(c, with_bytes) for c in coder_proto.component_coder_ids ] coder_substitutions[coder_id, with_bytes] = wrapped_coder_id else: # Not a known coder. if with_bytes: coder_substitutions[coder_id, with_bytes] = bytes_coder_id else: wrapped_coder_id = unique_name( coders, coder_id + "_len_prefix") len_prefix_coder_proto = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( urn=urns.LENGTH_PREFIX_CODER)), component_coder_ids=[coder_id]) coders[wrapped_coder_id].CopyFrom( len_prefix_coder_proto) coder_substitutions[coder_id, with_bytes] = wrapped_coder_id # This operation is idempotent. if wrapped_coder_id: coder_substitutions[wrapped_coder_id, with_bytes] = wrapped_coder_id return coder_substitutions[coder_id, with_bytes] def fix_pcoll_coder(pcoll): new_coder_id = wrap_unknown_coders(pcoll.coder_id, False) safe_coders[new_coder_id] = wrap_unknown_coders( pcoll.coder_id, True) pcoll.coder_id = new_coder_id for stage in stages: assert len(stage.transforms) == 1 transform = stage.transforms[0] if transform.spec.urn == urns.GROUP_BY_KEY_TRANSFORM: for pcoll_id in transform.inputs.values(): fix_pcoll_coder( pipeline_components.pcollections[pcoll_id]) for pcoll_id in transform.outputs.values(): fix_pcoll_coder( pipeline_components.pcollections[pcoll_id]) # This is used later to correlate the read and write. param = str("group:%s" % stage.name) gbk_write = Stage(transform.unique_name + '/Write', [ beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Write', inputs=transform.inputs, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor.DATA_OUTPUT_URN, payload=param)) ], downstream_side_inputs=frozenset(), must_follow=stage.must_follow) yield gbk_write yield Stage(transform.unique_name + '/Read', [ beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Read', outputs=transform.outputs, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor.DATA_INPUT_URN, payload=param)) ], downstream_side_inputs=frozenset(), must_follow=union(frozenset([gbk_write]), stage.must_follow)) else: yield stage def sink_flattens(stages): """Sink flattens and remove them from the graph. A flatten that cannot be sunk/fused away becomes multiple writes (to the same logical sink) followed by a read. 
""" # TODO(robertwb): Actually attempt to sink rather than always materialize. # TODO(robertwb): Possibly fuse this into one of the stages. pcollections = pipeline_components.pcollections for stage in stages: assert len(stage.transforms) == 1 transform = stage.transforms[0] if transform.spec.urn == urns.FLATTEN_TRANSFORM: # This is used later to correlate the read and writes. param = str("materialize:%s" % transform.unique_name) output_pcoll_id, = transform.outputs.values() output_coder_id = pcollections[output_pcoll_id].coder_id flatten_writes = [] for local_in, pcoll_in in transform.inputs.items(): if pcollections[pcoll_in].coder_id != output_coder_id: # Flatten inputs must all be written with the same coder as is # used to read them. pcollections[pcoll_in].coder_id = output_coder_id transcoded_pcollection = (transform.unique_name + '/Transcode/' + local_in + '/out') yield Stage( transform.unique_name + '/Transcode/' + local_in, [ beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Transcode/' + local_in, inputs={local_in: pcoll_in}, outputs={ 'out': transcoded_pcollection }, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor. IDENTITY_DOFN_URN)) ], downstream_side_inputs=frozenset(), must_follow=stage.must_follow) pcollections[transcoded_pcollection].CopyFrom( pcollections[pcoll_in]) pcollections[ transcoded_pcollection].coder_id = output_coder_id else: transcoded_pcollection = pcoll_in flatten_write = Stage( transform.unique_name + '/Write/' + local_in, [ beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Write/' + local_in, inputs={local_in: transcoded_pcollection}, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor.DATA_OUTPUT_URN, payload=param)) ], downstream_side_inputs=frozenset(), must_follow=stage.must_follow) flatten_writes.append(flatten_write) yield flatten_write yield Stage(transform.unique_name + '/Read', [ beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Read', outputs=transform.outputs, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor.DATA_INPUT_URN, payload=param)) ], downstream_side_inputs=frozenset(), must_follow=union(frozenset(flatten_writes), stage.must_follow)) else: yield stage def annotate_downstream_side_inputs(stages): """Annotate each stage with fusion-prohibiting information. Each stage is annotated with the (transitive) set of pcollections that depend on this stage that are also used later in the pipeline as a side input. While theoretically this could result in O(n^2) annotations, the size of each set is bounded by the number of side inputs (typically much smaller than the number of total nodes) and the number of *distinct* side-input sets is also generally small (and shared due to the use of union defined above). This representation is also amenable to simple recomputation on fusion. 
""" consumers = collections.defaultdict(list) all_side_inputs = set() for stage in stages: for transform in stage.transforms: for input in transform.inputs.values(): consumers[input].append(stage) for si in stage.side_inputs(): all_side_inputs.add(si) all_side_inputs = frozenset(all_side_inputs) downstream_side_inputs_by_stage = {} def compute_downstream_side_inputs(stage): if stage not in downstream_side_inputs_by_stage: downstream_side_inputs = frozenset() for transform in stage.transforms: for output in transform.outputs.values(): if output in all_side_inputs: downstream_side_inputs = union( downstream_side_inputs, frozenset([output])) for consumer in consumers[output]: downstream_side_inputs = union( downstream_side_inputs, compute_downstream_side_inputs(consumer)) downstream_side_inputs_by_stage[ stage] = downstream_side_inputs return downstream_side_inputs_by_stage[stage] for stage in stages: stage.downstream_side_inputs = compute_downstream_side_inputs( stage) return stages def greedily_fuse(stages): """Places transforms sharing an edge in the same stage, whenever possible. """ producers_by_pcoll = {} consumers_by_pcoll = collections.defaultdict(list) # Used to always reference the correct stage as the producer and # consumer maps are not updated when stages are fused away. replacements = {} def replacement(s): old_ss = [] while s in replacements: old_ss.append(s) s = replacements[s] for old_s in old_ss[:-1]: replacements[old_s] = s return s def fuse(producer, consumer): fused = producer.fuse(consumer) replacements[producer] = fused replacements[consumer] = fused # First record the producers and consumers of each PCollection. for stage in stages: for transform in stage.transforms: for input in transform.inputs.values(): consumers_by_pcoll[input].append(stage) for output in transform.outputs.values(): producers_by_pcoll[output] = stage logging.debug('consumers\n%s', consumers_by_pcoll) logging.debug('producers\n%s', producers_by_pcoll) # Now try to fuse away all pcollections. for pcoll, producer in producers_by_pcoll.items(): pcoll_as_param = str("materialize:%s" % pcoll) write_pcoll = None for consumer in consumers_by_pcoll[pcoll]: producer = replacement(producer) consumer = replacement(consumer) # Update consumer.must_follow set, as it's used in can_fuse. consumer.must_follow = frozenset( replacement(s) for s in consumer.must_follow) if producer.can_fuse(consumer): fuse(producer, consumer) else: # If we can't fuse, do a read + write. if write_pcoll is None: write_pcoll = Stage(pcoll + '/Write', [ beam_runner_api_pb2.PTransform( unique_name=pcoll + '/Write', inputs={'in': pcoll}, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor.DATA_OUTPUT_URN, payload=pcoll_as_param)) ]) fuse(producer, write_pcoll) if consumer.has_as_main_input(pcoll): read_pcoll = Stage(pcoll + '/Read', [ beam_runner_api_pb2.PTransform( unique_name=pcoll + '/Read', outputs={'out': pcoll}, spec=beam_runner_api_pb2.FunctionSpec( urn=bundle_processor.DATA_INPUT_URN, payload=pcoll_as_param)) ], must_follow=frozenset( [write_pcoll])) fuse(read_pcoll, consumer) else: consumer.must_follow = union( consumer.must_follow, frozenset([write_pcoll])) # Everything that was originally a stage or a replacement, but wasn't # replaced, should be in the final graph. final_stages = frozenset(stages).union( replacements.values()).difference(replacements.keys()) for stage in final_stages: # Update all references to their final values before throwing # the replacement data away. 
  def sort_stages(stages):
    """Order stages suitable for sequential execution."""
    seen = set()
    ordered = []

    def process(stage):
      if stage not in seen:
        seen.add(stage)
        for prev in stage.must_follow:
          process(prev)
        ordered.append(stage)

    for stage in stages:
      process(stage)
    return ordered

  # Now actually apply the operations.

  pipeline_components = copy.deepcopy(pipeline_proto.components)

  # Reify coders.
  # TODO(BEAM-2717): Remove once Coders are already in proto.
  coders = pipeline_context.PipelineContext(pipeline_components).coders
  for pcoll in pipeline_components.pcollections.values():
    if pcoll.coder_id not in coders:
      window_coder = coders[
          pipeline_components.windowing_strategies[
              pcoll.windowing_strategy_id].window_coder_id]
      coder = WindowedValueCoder(
          registry.get_coder(pickler.loads(pcoll.coder_id)),
          window_coder=window_coder)
      pcoll.coder_id = coders.get_id(coder)
  coders.populate_map(pipeline_components.coders)

  known_composites = set([urns.GROUP_BY_KEY_TRANSFORM])

  def leaf_transforms(root_ids):
    for root_id in root_ids:
      root = pipeline_proto.components.transforms[root_id]
      if root.spec.urn in known_composites or not root.subtransforms:
        yield root_id
      else:
        for leaf in leaf_transforms(root.subtransforms):
          yield leaf

  # Initial set of stages are singleton leaf transforms.
  stages = [
      Stage(name, [pipeline_proto.components.transforms[name]])
      for name in leaf_transforms(pipeline_proto.root_transform_ids)]

  # Apply each phase in order.
  for phase in [
      annotate_downstream_side_inputs, expand_gbk, sink_flattens,
      greedily_fuse, sort_stages]:
    logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
    stages = list(phase(stages))
    logging.debug('Stages: %s', [str(s) for s in stages])

  # Return the (possibly mutated) context and ordered set of stages.
  return pipeline_components, stages, safe_coders
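# Note on the returned triple (summary of the code above, not new behaviour):
# pipeline_components is the possibly rewritten deep copy of the proto
# components, stages is the list ordered by sort_stages so that every stage
# appears after everything in its must_follow set, and safe_coders maps each
# length-prefix-wrapped coder id introduced by expand_gbk to a byte-equivalent
# coder id, presumably so the runner can handle grouped elements whose
# component coders it cannot itself decode.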