Example #1
 def _config_value(cls, obj, typehint):
   """
   Helper to create a ConfigValue with an encoded value.
   """
   coder = registry.get_coder(typehint)
   urns = list(iter_urns(coder))
   if 'beam:coder:pickled_python:v1' in urns:
     raise RuntimeError("Found non-portable coder for %s" % (typehint, ))
   return ConfigValue(
       coder_urn=urns, payload=coder.get_impl().encode_nested(obj))
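
A minimal standalone sketch of the idea behind the URN check above; ToyCoder
and the iter_urns defined here are illustrative stand-ins, not Beam's real
classes:

PICKLED_URN = 'beam:coder:pickled_python:v1'

class ToyCoder(object):
  """Illustrative stand-in for a coder with component coders."""
  def __init__(self, urn, components=()):
    self.urn = urn
    self.components = components

def iter_urns(coder):
  # Yield the coder's URN followed by those of its components, recursively.
  yield coder.urn
  for component in coder.components:
    for urn in iter_urns(component):
      yield urn

kv = ToyCoder('beam:coder:kv:v1',
              [ToyCoder('beam:coder:varint:v1'), ToyCoder(PICKLED_URN)])
# The pickled coder appears among the transitive URNs, so _config_value above
# would raise rather than emit a non-portable ConfigValue.
assert PICKLED_URN in list(iter_urns(kv))
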
Example #2
    def create_stages(self, pipeline_proto):

        # First define a couple of helpers.

        def union(a, b):
            # Minimize the number of distinct sets.
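            # For example, union(frozenset(), s) and union(s, s) both return s
            # itself rather than building a new set, so identical side-input
            # sets end up shared between stages.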
            if not a or a == b:
                return b
            elif not b:
                return a
            else:
                return frozenset.union(a, b)

        class Stage(object):
            """A set of Transforms that can be sent to the worker for processing."""
            def __init__(self,
                         name,
                         transforms,
                         downstream_side_inputs=None,
                         must_follow=frozenset()):
                self.name = name
                self.transforms = transforms
                self.downstream_side_inputs = downstream_side_inputs
                self.must_follow = must_follow

            def __repr__(self):
                must_follow = ', '.join(prev.name for prev in self.must_follow)
                downstream_side_inputs = ', '.join(
                    str(si) for si in self.downstream_side_inputs)
                return "%s\n  %s\n  must follow: %s\n  downstream_side_inputs: %s" % (
                    self.name, '\n'.join([
                        "%s:%s" % (transform.unique_name, transform.spec.urn)
                        for transform in self.transforms
                    ]), must_follow, downstream_side_inputs)

            def can_fuse(self, consumer):
                def no_overlap(a, b):
                    return not a.intersection(b)

                return (not self in consumer.must_follow
                        and not self.is_flatten()
                        and not consumer.is_flatten()
                        and no_overlap(self.downstream_side_inputs,
                                       consumer.side_inputs()))

            def fuse(self, other):
                return Stage(
                    "(%s)+(%s)" % (self.name, other.name),
                    self.transforms + other.transforms,
                    union(self.downstream_side_inputs,
                          other.downstream_side_inputs),
                    union(self.must_follow, other.must_follow))

            def is_flatten(self):
                return any(transform.spec.urn == urns.FLATTEN_TRANSFORM
                           for transform in self.transforms)

            def side_inputs(self):
                for transform in self.transforms:
                    if transform.spec.urn == urns.PARDO_TRANSFORM:
                        payload = proto_utils.parse_Bytes(
                            transform.spec.payload,
                            beam_runner_api_pb2.ParDoPayload)
                        for side_input in payload.side_inputs:
                            yield transform.inputs[side_input]

            def has_as_main_input(self, pcoll):
                for transform in self.transforms:
                    if transform.spec.urn == urns.PARDO_TRANSFORM:
                        payload = proto_utils.parse_Bytes(
                            transform.spec.payload,
                            beam_runner_api_pb2.ParDoPayload)
                        local_side_inputs = payload.side_inputs
                    else:
                        local_side_inputs = {}
                    for local_id, pipeline_id in transform.inputs.items():
                        if pcoll == pipeline_id and local_id not in local_side_inputs:
                            return True

            def deduplicate_read(self):
                seen_pcolls = set()
                new_transforms = []
                for transform in self.transforms:
                    if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                        pcoll = only_element(transform.outputs.items())[1]
                        if pcoll in seen_pcolls:
                            continue
                        seen_pcolls.add(pcoll)
                    new_transforms.append(transform)
                self.transforms = new_transforms

        # Now define the "optimization" phases.

        safe_coders = {}

        def expand_gbk(stages):
            """Transforms each GBK into a write followed by a read.
      """
            good_coder_urns = set(beam.coders.Coder._known_urns.keys()) - set(
                [urns.PICKLED_CODER])
            coders = pipeline_components.coders

            for coder_id, coder_proto in coders.items():
                if coder_proto.spec.spec.urn == urns.BYTES_CODER:
                    bytes_coder_id = coder_id
                    break
            else:
                bytes_coder_id = unique_name(coders, 'bytes_coder')
                pipeline_components.coders[bytes_coder_id].CopyFrom(
                    beam.coders.BytesCoder().to_runner_api(None))

            coder_substitutions = {}

            def wrap_unknown_coders(coder_id, with_bytes):
                if (coder_id, with_bytes) not in coder_substitutions:
                    wrapped_coder_id = None
                    coder_proto = coders[coder_id]
                    if coder_proto.spec.spec.urn == urns.LENGTH_PREFIX_CODER:
                        coder_substitutions[coder_id, with_bytes] = (
                            bytes_coder_id if with_bytes else coder_id)
                    elif coder_proto.spec.spec.urn in good_coder_urns:
                        wrapped_components = [
                            wrap_unknown_coders(c, with_bytes)
                            for c in coder_proto.component_coder_ids
                        ]
                        if wrapped_components == list(
                                coder_proto.component_coder_ids):
                            # Use as is.
                            coder_substitutions[coder_id,
                                                with_bytes] = coder_id
                        else:
                            wrapped_coder_id = unique_name(
                                coders, coder_id +
                                ("_bytes" if with_bytes else "_len_prefix"))
                            coders[wrapped_coder_id].CopyFrom(coder_proto)
                            coders[wrapped_coder_id].component_coder_ids[:] = [
                                wrap_unknown_coders(c, with_bytes)
                                for c in coder_proto.component_coder_ids
                            ]
                            coder_substitutions[coder_id,
                                                with_bytes] = wrapped_coder_id
                    else:
                        # Not a known coder.
                        if with_bytes:
                            coder_substitutions[coder_id,
                                                with_bytes] = bytes_coder_id
                        else:
                            wrapped_coder_id = unique_name(
                                coders, coder_id + "_len_prefix")
                            len_prefix_coder_proto = beam_runner_api_pb2.Coder(
                                spec=beam_runner_api_pb2.SdkFunctionSpec(
                                    spec=beam_runner_api_pb2.FunctionSpec(
                                        urn=urns.LENGTH_PREFIX_CODER)),
                                component_coder_ids=[coder_id])
                            coders[wrapped_coder_id].CopyFrom(
                                len_prefix_coder_proto)
                            coder_substitutions[coder_id,
                                                with_bytes] = wrapped_coder_id
                    # This operation is idempotent.
                    if wrapped_coder_id:
                        coder_substitutions[wrapped_coder_id,
                                            with_bytes] = wrapped_coder_id
                return coder_substitutions[coder_id, with_bytes]

            def fix_pcoll_coder(pcoll):
                new_coder_id = wrap_unknown_coders(pcoll.coder_id, False)
                safe_coders[new_coder_id] = wrap_unknown_coders(
                    pcoll.coder_id, True)
                pcoll.coder_id = new_coder_id

            for stage in stages:
                assert len(stage.transforms) == 1
                transform = stage.transforms[0]
                if transform.spec.urn == urns.GROUP_BY_KEY_TRANSFORM:
                    for pcoll_id in transform.inputs.values():
                        fix_pcoll_coder(
                            pipeline_components.pcollections[pcoll_id])
                    for pcoll_id in transform.outputs.values():
                        fix_pcoll_coder(
                            pipeline_components.pcollections[pcoll_id])

                    # This is used later to correlate the read and write.
                    param = str("group:%s" % stage.name)
                    gbk_write = Stage(transform.unique_name + '/Write', [
                        beam_runner_api_pb2.PTransform(
                            unique_name=transform.unique_name + '/Write',
                            inputs=transform.inputs,
                            spec=beam_runner_api_pb2.FunctionSpec(
                                urn=bundle_processor.DATA_OUTPUT_URN,
                                payload=param))
                    ],
                                      downstream_side_inputs=frozenset(),
                                      must_follow=stage.must_follow)
                    yield gbk_write

                    yield Stage(transform.unique_name + '/Read', [
                        beam_runner_api_pb2.PTransform(
                            unique_name=transform.unique_name + '/Read',
                            outputs=transform.outputs,
                            spec=beam_runner_api_pb2.FunctionSpec(
                                urn=bundle_processor.DATA_INPUT_URN,
                                payload=param))
                    ],
                                downstream_side_inputs=frozenset(),
                                must_follow=union(frozenset([gbk_write]),
                                                  stage.must_follow))
                else:
                    yield stage

        def sink_flattens(stages):
            """Sink flattens and remove them from the graph.

      A flatten that cannot be sunk/fused away becomes multiple writes (to the
      same logical sink) followed by a read.
      """
            # TODO(robertwb): Actually attempt to sink rather than always materialize.
            # TODO(robertwb): Possibly fuse this into one of the stages.
            pcollections = pipeline_components.pcollections
            for stage in stages:
                assert len(stage.transforms) == 1
                transform = stage.transforms[0]
                if transform.spec.urn == urns.FLATTEN_TRANSFORM:
                    # This is used later to correlate the read and writes.
                    param = str("materialize:%s" % transform.unique_name)
                    output_pcoll_id, = transform.outputs.values()
                    output_coder_id = pcollections[output_pcoll_id].coder_id
                    flatten_writes = []
                    for local_in, pcoll_in in transform.inputs.items():

                        if pcollections[pcoll_in].coder_id != output_coder_id:
                            # Flatten inputs must all be written with the same coder as is
                            # used to read them.
                            pcollections[pcoll_in].coder_id = output_coder_id
                            transcoded_pcollection = (transform.unique_name +
                                                      '/Transcode/' +
                                                      local_in + '/out')
                            yield Stage(
                                transform.unique_name + '/Transcode/' +
                                local_in, [
                                    beam_runner_api_pb2.PTransform(
                                        unique_name=transform.unique_name +
                                        '/Transcode/' + local_in,
                                        inputs={local_in: pcoll_in},
                                        outputs={
                                            'out': transcoded_pcollection
                                        },
                                        spec=beam_runner_api_pb2.FunctionSpec(
                                            urn=bundle_processor.
                                            IDENTITY_DOFN_URN))
                                ],
                                downstream_side_inputs=frozenset(),
                                must_follow=stage.must_follow)
                            pcollections[transcoded_pcollection].CopyFrom(
                                pcollections[pcoll_in])
                            pcollections[
                                transcoded_pcollection].coder_id = output_coder_id
                        else:
                            transcoded_pcollection = pcoll_in

                        flatten_write = Stage(
                            transform.unique_name + '/Write/' + local_in, [
                                beam_runner_api_pb2.PTransform(
                                    unique_name=transform.unique_name +
                                    '/Write/' + local_in,
                                    inputs={local_in: transcoded_pcollection},
                                    spec=beam_runner_api_pb2.FunctionSpec(
                                        urn=bundle_processor.DATA_OUTPUT_URN,
                                        payload=param))
                            ],
                            downstream_side_inputs=frozenset(),
                            must_follow=stage.must_follow)
                        flatten_writes.append(flatten_write)
                        yield flatten_write

                    yield Stage(transform.unique_name + '/Read', [
                        beam_runner_api_pb2.PTransform(
                            unique_name=transform.unique_name + '/Read',
                            outputs=transform.outputs,
                            spec=beam_runner_api_pb2.FunctionSpec(
                                urn=bundle_processor.DATA_INPUT_URN,
                                payload=param))
                    ],
                                downstream_side_inputs=frozenset(),
                                must_follow=union(frozenset(flatten_writes),
                                                  stage.must_follow))

                else:
                    yield stage

        def annotate_downstream_side_inputs(stages):
            """Annotate each stage with fusion-prohibiting information.

      Each stage is annotated with the (transitive) set of pcollections that
      depend on this stage that are also used later in the pipeline as a
      side input.

      While theoretically this could result in O(n^2) annotations, the size of
      each set is bounded by the number of side inputs (typically much smaller
      than the number of total nodes) and the number of *distinct* side-input
      sets is also generally small (and shared due to the use of union
      defined above).

      This representation is also amenable to simple recomputation on fusion.
      """
            consumers = collections.defaultdict(list)
            all_side_inputs = set()
            for stage in stages:
                for transform in stage.transforms:
                    for input in transform.inputs.values():
                        consumers[input].append(stage)
                for si in stage.side_inputs():
                    all_side_inputs.add(si)
            all_side_inputs = frozenset(all_side_inputs)

            downstream_side_inputs_by_stage = {}

            def compute_downstream_side_inputs(stage):
                if stage not in downstream_side_inputs_by_stage:
                    downstream_side_inputs = frozenset()
                    for transform in stage.transforms:
                        for output in transform.outputs.values():
                            if output in all_side_inputs:
                                downstream_side_inputs = union(
                                    downstream_side_inputs,
                                    frozenset([output]))
                            for consumer in consumers[output]:
                                downstream_side_inputs = union(
                                    downstream_side_inputs,
                                    compute_downstream_side_inputs(consumer))
                    downstream_side_inputs_by_stage[
                        stage] = downstream_side_inputs
                return downstream_side_inputs_by_stage[stage]

            for stage in stages:
                stage.downstream_side_inputs = compute_downstream_side_inputs(
                    stage)
            return stages

        def greedily_fuse(stages):
            """Places transforms sharing an edge in the same stage, whenever possible.
      """
            producers_by_pcoll = {}
            consumers_by_pcoll = collections.defaultdict(list)

            # Used to always reference the correct stage as the producer and
            # consumer maps are not updated when stages are fused away.
            replacements = {}

            def replacement(s):
                old_ss = []
                while s in replacements:
                    old_ss.append(s)
                    s = replacements[s]
                for old_s in old_ss[:-1]:
                    replacements[old_s] = s
                return s

            def fuse(producer, consumer):
                fused = producer.fuse(consumer)
                replacements[producer] = fused
                replacements[consumer] = fused

            # First record the producers and consumers of each PCollection.
            for stage in stages:
                for transform in stage.transforms:
                    for input in transform.inputs.values():
                        consumers_by_pcoll[input].append(stage)
                    for output in transform.outputs.values():
                        producers_by_pcoll[output] = stage

            logging.debug('consumers\n%s', consumers_by_pcoll)
            logging.debug('producers\n%s', producers_by_pcoll)

            # Now try to fuse away all pcollections.
            for pcoll, producer in producers_by_pcoll.items():
                pcoll_as_param = str("materialize:%s" % pcoll)
                write_pcoll = None
                for consumer in consumers_by_pcoll[pcoll]:
                    producer = replacement(producer)
                    consumer = replacement(consumer)
                    # Update consumer.must_follow set, as it's used in can_fuse.
                    consumer.must_follow = frozenset(
                        replacement(s) for s in consumer.must_follow)
                    if producer.can_fuse(consumer):
                        fuse(producer, consumer)
                    else:
                        # If we can't fuse, do a read + write.
                        if write_pcoll is None:
                            write_pcoll = Stage(pcoll + '/Write', [
                                beam_runner_api_pb2.PTransform(
                                    unique_name=pcoll + '/Write',
                                    inputs={'in': pcoll},
                                    spec=beam_runner_api_pb2.FunctionSpec(
                                        urn=bundle_processor.DATA_OUTPUT_URN,
                                        payload=pcoll_as_param))
                            ])
                            fuse(producer, write_pcoll)
                        if consumer.has_as_main_input(pcoll):
                            read_pcoll = Stage(pcoll + '/Read', [
                                beam_runner_api_pb2.PTransform(
                                    unique_name=pcoll + '/Read',
                                    outputs={'out': pcoll},
                                    spec=beam_runner_api_pb2.FunctionSpec(
                                        urn=bundle_processor.DATA_INPUT_URN,
                                        payload=pcoll_as_param))
                            ],
                                               must_follow=frozenset(
                                                   [write_pcoll]))
                            fuse(read_pcoll, consumer)
                        else:
                            consumer.must_follow = union(
                                consumer.must_follow, frozenset([write_pcoll]))

            # Everything that was originally a stage or a replacement, but wasn't
            # replaced, should be in the final graph.
            final_stages = frozenset(stages).union(
                replacements.values()).difference(replacements.keys())

            for stage in final_stages:
                # Update all references to their final values before throwing
                # the replacement data away.
                stage.must_follow = frozenset(
                    replacement(s) for s in stage.must_follow)
                # Two reads of the same stage may have been fused.  This is unneeded.
                stage.deduplicate_read()
            return final_stages

        def sort_stages(stages):
            """Order stages suitable for sequential execution.
      """
            seen = set()
            ordered = []

            def process(stage):
                if stage not in seen:
                    seen.add(stage)
                    for prev in stage.must_follow:
                        process(prev)
                    ordered.append(stage)

            for stage in stages:
                process(stage)
            return ordered

        # Now actually apply the operations.

        pipeline_components = copy.deepcopy(pipeline_proto.components)

        # Reify coders.
        # TODO(BEAM-2717): Remove once Coders are already in proto.
        coders = pipeline_context.PipelineContext(pipeline_components).coders
        for pcoll in pipeline_components.pcollections.values():
            if pcoll.coder_id not in coders:
                window_coder = coders[pipeline_components.windowing_strategies[
                    pcoll.windowing_strategy_id].window_coder_id]
                coder = WindowedValueCoder(registry.get_coder(
                    pickler.loads(pcoll.coder_id)),
                                           window_coder=window_coder)
                pcoll.coder_id = coders.get_id(coder)
        coders.populate_map(pipeline_components.coders)

        known_composites = set([urns.GROUP_BY_KEY_TRANSFORM])

        def leaf_transforms(root_ids):
            for root_id in root_ids:
                root = pipeline_proto.components.transforms[root_id]
                if root.spec.urn in known_composites or not root.subtransforms:
                    yield root_id
                else:
                    for leaf in leaf_transforms(root.subtransforms):
                        yield leaf

        # The initial set of stages consists of singleton leaf transforms.
        stages = [
            Stage(name, [pipeline_proto.components.transforms[name]])
            for name in leaf_transforms(pipeline_proto.root_transform_ids)
        ]

        # Apply each phase in order.
        for phase in [
                annotate_downstream_side_inputs, expand_gbk, sink_flattens,
                greedily_fuse, sort_stages
        ]:
            logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
            stages = list(phase(stages))
            logging.debug('Stages: %s', [str(s) for s in stages])

        # Return the (possibly mutated) context and ordered set of stages.
        return pipeline_components, stages, safe_coders
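
Before the next listing, a minimal standalone sketch of the replacements
bookkeeping used in greedily_fuse above: chains of fused stages are chased and
compressed much like union-find path compression (the stage names below are
made up):

replacements = {}

def replacement(s):
  # Follow the chain of replacements, compressing it along the way.
  old_ss = []
  while s in replacements:
    old_ss.append(s)
    s = replacements[s]
  for old_s in old_ss[:-1]:
    replacements[old_s] = s
  return s

def fuse(a, b, fused):
  replacements[a] = fused
  replacements[b] = fused

fuse('Read', 'ParDo', '(Read)+(ParDo)')
fuse('(Read)+(ParDo)', 'Map', '((Read)+(ParDo))+(Map)')
assert replacement('Read') == '((Read)+(ParDo))+(Map)'
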
Example #3

  def create_stages(self, pipeline_proto):

    # First define a couple of helpers.

    def union(a, b):
      # Minimize the number of distinct sets.
      if not a or a == b:
        return b
      elif not b:
        return a
      else:
        return frozenset.union(a, b)

    class Stage(object):
      """A set of Transforms that can be sent to the worker for processing."""
      def __init__(self, name, transforms,
                   downstream_side_inputs=None, must_follow=frozenset()):
        self.name = name
        self.transforms = transforms
        self.downstream_side_inputs = downstream_side_inputs
        self.must_follow = must_follow

      def __repr__(self):
        must_follow = ', '.join(prev.name for prev in self.must_follow)
        return "%s\n    %s\n    must follow: %s" % (
            self.name,
            '\n'.join(["%s:%s" % (transform.unique_name, transform.spec.urn)
                       for transform in self.transforms]),
            must_follow)

      def can_fuse(self, consumer):
        def no_overlap(a, b):
          return not a.intersection(b)
        return (
            not self in consumer.must_follow
            and not self.is_flatten() and not consumer.is_flatten()
            and no_overlap(self.downstream_side_inputs, consumer.side_inputs()))

      def fuse(self, other):
        return Stage(
            "(%s)+(%s)" % (self.name, other.name),
            self.transforms + other.transforms,
            union(self.downstream_side_inputs, other.downstream_side_inputs),
            union(self.must_follow, other.must_follow))

      def is_flatten(self):
        return any(transform.spec.urn == urns.FLATTEN_TRANSFORM
                   for transform in self.transforms)

      def side_inputs(self):
        for transform in self.transforms:
          if transform.spec.urn == urns.PARDO_TRANSFORM:
            payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
            for side_input in payload.side_inputs:
              yield transform.inputs[side_input]

      def has_as_main_input(self, pcoll):
        for transform in self.transforms:
          if transform.spec.urn == urns.PARDO_TRANSFORM:
            payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
            local_side_inputs = payload.side_inputs
          else:
            local_side_inputs = {}
          for local_id, pipeline_id in transform.inputs.items():
            if pcoll == pipeline_id and local_id not in local_side_inputs:
              return True

      def deduplicate_read(self):
        seen_pcolls = set()
        new_transforms = []
        for transform in self.transforms:
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            pcoll = only_element(transform.outputs.items())[1]
            if pcoll in seen_pcolls:
              continue
            seen_pcolls.add(pcoll)
          new_transforms.append(transform)
        self.transforms = new_transforms

    # Now define the "optimization" phases.

    safe_coders = {}

    def expand_gbk(stages):
      """Transforms each GBK into a write followed by a read.
      """
      good_coder_urns = set(beam.coders.Coder._known_urns.keys()) - set([
          urns.PICKLED_CODER])
      coders = pipeline_components.coders

      for coder_id, coder_proto in coders.items():
        if coder_proto.spec.spec.urn == urns.BYTES_CODER:
          bytes_coder_id = coder_id
          break
      else:
        bytes_coder_id = unique_name(coders, 'bytes_coder')
        pipeline_components.coders[bytes_coder_id].CopyFrom(
            beam.coders.BytesCoder().to_runner_api(None))

      coder_substitutions = {}

      def wrap_unknown_coders(coder_id, with_bytes):
        if (coder_id, with_bytes) not in coder_substitutions:
          wrapped_coder_id = None
          coder_proto = coders[coder_id]
          if coder_proto.spec.spec.urn == urns.LENGTH_PREFIX_CODER:
            coder_substitutions[coder_id, with_bytes] = (
                bytes_coder_id if with_bytes else coder_id)
          elif coder_proto.spec.spec.urn in good_coder_urns:
            wrapped_components = [wrap_unknown_coders(c, with_bytes)
                                  for c in coder_proto.component_coder_ids]
            if wrapped_components == list(coder_proto.component_coder_ids):
              # Use as is.
              coder_substitutions[coder_id, with_bytes] = coder_id
            else:
              wrapped_coder_id = unique_name(
                  coders,
                  coder_id + ("_bytes" if with_bytes else "_len_prefix"))
              coders[wrapped_coder_id].CopyFrom(coder_proto)
              coders[wrapped_coder_id].component_coder_ids[:] = [
                  wrap_unknown_coders(c, with_bytes)
                  for c in coder_proto.component_coder_ids]
              coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
          else:
            # Not a known coder.
            if with_bytes:
              coder_substitutions[coder_id, with_bytes] = bytes_coder_id
            else:
              wrapped_coder_id = unique_name(coders, coder_id +  "_len_prefix")
              len_prefix_coder_proto = beam_runner_api_pb2.Coder(
                  spec=beam_runner_api_pb2.SdkFunctionSpec(
                      spec=beam_runner_api_pb2.FunctionSpec(
                          urn=urns.LENGTH_PREFIX_CODER)),
                  component_coder_ids=[coder_id])
              coders[wrapped_coder_id].CopyFrom(len_prefix_coder_proto)
              coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
          # This operation is idempotent.
          if wrapped_coder_id:
            coder_substitutions[wrapped_coder_id, with_bytes] = wrapped_coder_id
        return coder_substitutions[coder_id, with_bytes]

      def fix_pcoll_coder(pcoll):
        new_coder_id = wrap_unknown_coders(pcoll.coder_id, False)
        safe_coders[new_coder_id] = wrap_unknown_coders(pcoll.coder_id, True)
        pcoll.coder_id = new_coder_id

      for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == urns.GROUP_BY_KEY_ONLY_TRANSFORM:
          for pcoll_id in transform.inputs.values():
            fix_pcoll_coder(pipeline_components.pcollections[pcoll_id])
          for pcoll_id in transform.outputs.values():
            fix_pcoll_coder(pipeline_components.pcollections[pcoll_id])

          # This is used later to correlate the read and write.
          param = str("group:%s" % stage.name)
          gbk_write = Stage(
              transform.unique_name + '/Write',
              [beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Write',
                  inputs=transform.inputs,
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_OUTPUT_URN,
                      any_param=proto_utils.pack_Any(
                          wrappers_pb2.BytesValue(value=param)),
                      payload=param))],
              downstream_side_inputs=frozenset(),
              must_follow=stage.must_follow)
          yield gbk_write

          yield Stage(
              transform.unique_name + '/Read',
              [beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Read',
                  outputs=transform.outputs,
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_INPUT_URN,
                      any_param=proto_utils.pack_Any(
                          wrappers_pb2.BytesValue(value=param)),
                      payload=param))],
              downstream_side_inputs=frozenset(),
              must_follow=union(frozenset([gbk_write]), stage.must_follow))
        else:
          yield stage

    def sink_flattens(stages):
      """Sink flattens and remove them from the graph.

      A flatten that cannot be sunk/fused away becomes multiple writes (to the
      same logical sink) followed by a read.
      """
      # TODO(robertwb): Actually attempt to sink rather than always materialize.
      # TODO(robertwb): Possibly fuse this into one of the stages.
      pcollections = pipeline_components.pcollections
      for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == urns.FLATTEN_TRANSFORM:
          # This is used later to correlate the read and writes.
          param = str("materialize:%s" % transform.unique_name)
          output_pcoll_id, = transform.outputs.values()
          output_coder_id = pcollections[output_pcoll_id].coder_id
          flatten_writes = []
          for local_in, pcoll_in in transform.inputs.items():

            if pcollections[pcoll_in].coder_id != output_coder_id:
              # Flatten inputs must all be written with the same coder as is
              # used to read them.
              pcollections[pcoll_in].coder_id = output_coder_id
              transcoded_pcollection = (
                  transform.unique_name + '/Transcode/' + local_in + '/out')
              yield Stage(
                  transform.unique_name + '/Transcode/' + local_in,
                  [beam_runner_api_pb2.PTransform(
                      unique_name=
                      transform.unique_name + '/Transcode/' + local_in,
                      inputs={local_in: pcoll_in},
                      outputs={'out': transcoded_pcollection},
                      spec=beam_runner_api_pb2.FunctionSpec(
                          urn=bundle_processor.IDENTITY_DOFN_URN))],
                  downstream_side_inputs=frozenset(),
                  must_follow=stage.must_follow)
              pcollections[transcoded_pcollection].CopyFrom(
                  pcollections[pcoll_in])
              pcollections[transcoded_pcollection].coder_id = output_coder_id
            else:
              transcoded_pcollection = pcoll_in

            flatten_write = Stage(
                transform.unique_name + '/Write/' + local_in,
                [beam_runner_api_pb2.PTransform(
                    unique_name=transform.unique_name + '/Write/' + local_in,
                    inputs={local_in: transcoded_pcollection},
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_OUTPUT_URN,
                        any_param=proto_utils.pack_Any(
                            wrappers_pb2.BytesValue(
                                value=param)),
                        payload=param))],
                downstream_side_inputs=frozenset(),
                must_follow=stage.must_follow)
            flatten_writes.append(flatten_write)
            yield flatten_write

          yield Stage(
              transform.unique_name + '/Read',
              [beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Read',
                  outputs=transform.outputs,
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_INPUT_URN,
                      any_param=proto_utils.pack_Any(
                          wrappers_pb2.BytesValue(
                              value=param)),
                      payload=param))],
              downstream_side_inputs=frozenset(),
              must_follow=union(frozenset(flatten_writes), stage.must_follow))

        else:
          yield stage

    def annotate_downstream_side_inputs(stages):
      """Annotate each stage with fusion-prohibiting information.

      Each stage is annotated with the (transitive) set of pcollections that
      depend on this stage that are also used later in the pipeline as a
      side input.

      While theoretically this could result in O(n^2) annotations, the size of
      each set is bounded by the number of side inputs (typically much smaller
      than the number of total nodes) and the number of *distinct* side-input
      sets is also generally small (and shared due to the use of union
      defined above).

      This representation is also amenable to simple recomputation on fusion.
      """
      consumers = collections.defaultdict(list)
      all_side_inputs = set()
      for stage in stages:
        for transform in stage.transforms:
          for input in transform.inputs.values():
            consumers[input].append(stage)
        for si in stage.side_inputs():
          all_side_inputs.add(si)
      all_side_inputs = frozenset(all_side_inputs)

      downstream_side_inputs_by_stage = {}

      def compute_downstream_side_inputs(stage):
        if stage not in downstream_side_inputs_by_stage:
          downstream_side_inputs = frozenset()
          for transform in stage.transforms:
            for output in transform.outputs.values():
              if output in all_side_inputs:
                downstream_side_inputs = union(
                    downstream_side_inputs, frozenset([output]))
              for consumer in consumers[output]:
                downstream_side_inputs = union(
                    downstream_side_inputs,
                    compute_downstream_side_inputs(consumer))
          downstream_side_inputs_by_stage[stage] = downstream_side_inputs
        return downstream_side_inputs_by_stage[stage]

      for stage in stages:
        stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
      return stages

    def greedily_fuse(stages):
      """Places transforms sharing an edge in the same stage, whenever possible.
      """
      producers_by_pcoll = {}
      consumers_by_pcoll = collections.defaultdict(list)

      # Used to always reference the correct stage as the producer and
      # consumer maps are not updated when stages are fused away.
      replacements = {}

      def replacement(s):
        old_ss = []
        while s in replacements:
          old_ss.append(s)
          s = replacements[s]
        for old_s in old_ss[:-1]:
          replacements[old_s] = s
        return s

      def fuse(producer, consumer):
        fused = producer.fuse(consumer)
        replacements[producer] = fused
        replacements[consumer] = fused

      # First record the producers and consumers of each PCollection.
      for stage in stages:
        for transform in stage.transforms:
          for input in transform.inputs.values():
            consumers_by_pcoll[input].append(stage)
          for output in transform.outputs.values():
            producers_by_pcoll[output] = stage

      logging.debug('consumers\n%s', consumers_by_pcoll)
      logging.debug('producers\n%s', producers_by_pcoll)

      # Now try to fuse away all pcollections.
      for pcoll, producer in producers_by_pcoll.items():
        pcoll_as_param = str("materialize:%s" % pcoll)
        write_pcoll = None
        for consumer in consumers_by_pcoll[pcoll]:
          producer = replacement(producer)
          consumer = replacement(consumer)
          # Update consumer.must_follow set, as it's used in can_fuse.
          consumer.must_follow = frozenset(
              replacement(s) for s in consumer.must_follow)
          if producer.can_fuse(consumer):
            fuse(producer, consumer)
          else:
            # If we can't fuse, do a read + write.
            if write_pcoll is None:
              write_pcoll = Stage(
                  pcoll + '/Write',
                  [beam_runner_api_pb2.PTransform(
                      unique_name=pcoll + '/Write',
                      inputs={'in': pcoll},
                      spec=beam_runner_api_pb2.FunctionSpec(
                          urn=bundle_processor.DATA_OUTPUT_URN,
                          any_param=proto_utils.pack_Any(
                              wrappers_pb2.BytesValue(
                                  value=pcoll_as_param)),
                          payload=pcoll_as_param))])
              fuse(producer, write_pcoll)
            if consumer.has_as_main_input(pcoll):
              read_pcoll = Stage(
                  pcoll + '/Read',
                  [beam_runner_api_pb2.PTransform(
                      unique_name=pcoll + '/Read',
                      outputs={'out': pcoll},
                      spec=beam_runner_api_pb2.FunctionSpec(
                          urn=bundle_processor.DATA_INPUT_URN,
                          any_param=proto_utils.pack_Any(
                              wrappers_pb2.BytesValue(
                                  value=pcoll_as_param)),
                          payload=pcoll_as_param))],
                  must_follow=frozenset([write_pcoll]))
              fuse(read_pcoll, consumer)
            else:
              consumer.must_follow = union(
                  consumer.must_follow, frozenset([write_pcoll]))

      # Everything that was originally a stage or a replacement, but wasn't
      # replaced, should be in the final graph.
      final_stages = frozenset(stages).union(replacements.values()).difference(
          replacements.keys())

      for stage in final_stages:
        # Update all references to their final values before throwing
        # the replacement data away.
        stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
        # Two reads of the same stage may have been fused.  This is unneeded.
        stage.deduplicate_read()
      return final_stages

    def sort_stages(stages):
      """Order stages suitable for sequential execution.
      """
      seen = set()
      ordered = []

      def process(stage):
        if stage not in seen:
          seen.add(stage)
          for prev in stage.must_follow:
            process(prev)
          ordered.append(stage)
      for stage in stages:
        process(stage)
      return ordered

    # Now actually apply the operations.

    pipeline_components = copy.deepcopy(pipeline_proto.components)

    # Reify coders.
    # TODO(BEAM-2717): Remove once Coders are already in proto.
    coders = pipeline_context.PipelineContext(pipeline_components).coders
    for pcoll in pipeline_components.pcollections.values():
      if pcoll.coder_id not in coders:
        window_coder = coders[
            pipeline_components.windowing_strategies[
                pcoll.windowing_strategy_id].window_coder_id]
        coder = WindowedValueCoder(
            registry.get_coder(pickler.loads(pcoll.coder_id)),
            window_coder=window_coder)
        pcoll.coder_id = coders.get_id(coder)
    coders.populate_map(pipeline_components.coders)

    # The initial set of stages consists of singleton transforms.
    stages = [
        Stage(name, [transform])
        for name, transform in pipeline_proto.components.transforms.items()
        if not transform.subtransforms]

    # Apply each phase in order.
    for phase in [
        annotate_downstream_side_inputs, expand_gbk, sink_flattens,
        greedily_fuse, sort_stages]:
      logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
      stages = list(phase(stages))
      logging.debug('Stages: %s', [str(s) for s in stages])

    # Return the (possibly mutated) context and ordered set of stages.
    return pipeline_components, stages, safe_coders
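
Finally, note how the driver at the end of each listing threads the phases
together: every phase is a function from an iterable of stages to an iterable
of stages, so stages = list(phase(stages)) chains them. A tiny standalone
sketch of that pattern (the phases below are made-up placeholders):

def annotate(stages):
  # Placeholder phase: pass stages through unchanged.
  for stage in stages:
    yield stage

def split_writes(stages):
  # Placeholder phase: expand each stage into a Write/Read pair.
  for stage in stages:
    yield stage + '/Write'
    yield stage + '/Read'

stages = ['GBK']
for phase in [annotate, split_writes]:
  stages = list(phase(stages))
print(stages)  # ['GBK/Write', 'GBK/Read']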