def with_state_iterables(self, coder_id):
    """Return a coder id in which every ITERABLE coder is state-backed.

    Walks the coder graph rooted at ``coder_id`` in
    ``self.components.coders``.  Each ITERABLE coder is copied under a fresh
    name derived from ``coder_id + '_state_backed'``, with its URN switched
    to STATE_BACKED_ITERABLE, its payload set to ``b'1'`` and its element
    coder (component 0) rewritten recursively.  Any other coder is copied
    (same spec, rewritten components) only when at least one component
    changed; otherwise ``coder_id`` itself is returned.
    """
    registry = self.components.coders
    original = registry[coder_id]

    if original.spec.spec.urn != common_urns.coders.ITERABLE.urn:
        rewritten = [self.with_state_iterables(child)
                     for child in original.component_coder_ids]
        if rewritten == original.component_coder_ids:
            # Nothing underneath needed replacing.
            return coder_id
        replacement_id = unique_name(registry, coder_id + '_state_backed')
        registry[replacement_id].CopyFrom(
            beam_runner_api_pb2.Coder(
                spec=original.spec, component_coder_ids=rewritten))
        return replacement_id

    replacement_id = unique_name(registry, coder_id + '_state_backed')
    replacement = registry[replacement_id]
    replacement.CopyFrom(original)
    replacement.spec.spec.urn = common_urns.coders.STATE_BACKED_ITERABLE.urn
    replacement.spec.spec.payload = b'1'
    # The element coder may itself contain iterables, so recurse into it.
    element_id = self.with_state_iterables(original.component_coder_ids[0])
    replacement.component_coder_ids[0] = element_id
    return replacement_id
Exemplo n.º 2
0
    def parse_coder(self, spec):
        """Build a coder object from a plain-dict description.

        ``spec['urn']`` names the coder.  The optional ``'payload'`` (a str,
        latin-1 encoded into the proto) and ``'components'`` (nested specs,
        parsed recursively) complete it.  When ``spec`` has a truthy
        ``'state'`` mapping, the fresh PipelineContext is given an
        ``iterable_state_read`` generator.  That generator decodes elements
        from the state entry named by the latin-1 decoded token; a missing
        entry reads as empty.
        """
        context = pipeline_context.PipelineContext()
        coder_id = str(hash(str(spec)))
        component_ids = []
        for component_spec in spec.get('components', ()):
            component_coder = self.parse_coder(component_spec)
            component_ids.append(context.coders.get_id(component_coder))

        if spec.get('state'):

            def iterable_state_read(state_token, elem_coder):
                encoded = spec.get('state').get(state_token.decode('latin1'))
                raw = (encoded if encoded is not None else '').encode('latin1')
                stream = coder_impl.create_InputStream(raw)
                while stream.size() > 0:
                    yield elem_coder.decode_from_stream(stream, True)

            context.iterable_state_read = iterable_state_read

        coder_proto = beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
            component_coder_ids=component_ids)
        context.coders.put_proto(coder_id, coder_proto)
        return context.coders.get_by_id(coder_id)
Exemplo n.º 3
0
 def windowed_coder_id(coder_id):
   """Return the id of a WINDOWED_VALUE coder wrapping ``coder_id``.

   The second component is the free variable ``window_coder_id``, resolved
   from an enclosing scope.  The built proto is passed to
   ``add_or_get_coder_id``, whose result is returned.
   """
   windowed_spec = beam_runner_api_pb2.SdkFunctionSpec(
       spec=beam_runner_api_pb2.FunctionSpec(urn=urns.WINDOWED_VALUE_CODER))
   return add_or_get_coder_id(
       beam_runner_api_pb2.Coder(
           spec=windowed_spec,
           component_coder_ids=[coder_id, window_coder_id]))
 def windowed_coder_id(coder_id, window_coder_id):
     """Return the id of a WINDOWED_VALUE coder over ``coder_id``.

     The coder's components are ``[coder_id, window_coder_id]``.  It is
     handed to ``pipeline_context.add_or_get_coder_id``, whose result is
     returned.  The second argument, ``coder_id + '_windowed'``, is
     presumably a preferred id.
     """
     function_spec = beam_runner_api_pb2.FunctionSpec(
         urn=common_urns.coders.WINDOWED_VALUE.urn)
     windowed_proto = beam_runner_api_pb2.Coder(
         spec=beam_runner_api_pb2.SdkFunctionSpec(spec=function_spec),
         component_coder_ids=[coder_id, window_coder_id])
     suggested_id = coder_id + '_windowed'
     return pipeline_context.add_or_get_coder_id(windowed_proto, suggested_id)
Exemplo n.º 5
0
 def to_runner_api(self, context):
   """Serialize this coder to a ``beam_runner_api_pb2.Coder`` proto.

   ``to_runner_api_parameter(context)`` supplies the URN, a typed parameter
   and the component coders.  The parameter may be:

   * a proto message, which is serialized with ``SerializeToString()``;
   * already-serialized ``bytes``, which are used as the payload as-is;
   * ``None``, which is passed through unchanged.

   Component coders are referenced by the ids returned from
   ``context.coders.get_id``.
   """
   urn, typed_param, components = self.to_runner_api_parameter(context)
   if typed_param is None or isinstance(typed_param, bytes):
     # Already serialized (or absent). Calling SerializeToString() on bytes
     # would raise AttributeError, so pass it through untouched.
     payload = typed_param
   else:
     payload = typed_param.SerializeToString()
   return beam_runner_api_pb2.Coder(
       spec=beam_runner_api_pb2.SdkFunctionSpec(
           spec=beam_runner_api_pb2.FunctionSpec(urn=urn, payload=payload)),
       component_coder_ids=[context.coders.get_id(c) for c in components])
Exemplo n.º 6
0
 def to_runner_api(self, context):
   # type: (PipelineContext) -> beam_runner_api_pb2.Coder
   """Build the runner-API ``Coder`` proto describing this coder.

   ``to_runner_api_parameter(context)`` yields the URN, a typed parameter
   and the component coders.  A ``bytes`` or ``None`` parameter becomes the
   payload unchanged; any other parameter is serialized with
   ``SerializeToString()``.  Components are referenced by the ids returned
   from ``context.coders.get_id``.
   """
   urn, typed_param, components = self.to_runner_api_parameter(context)
   already_encoded = isinstance(typed_param, (bytes, type(None)))
   payload = typed_param if already_encoded else typed_param.SerializeToString()
   function_spec = beam_runner_api_pb2.FunctionSpec(urn=urn, payload=payload)
   component_ids = [context.coders.get_id(c) for c in components]
   return beam_runner_api_pb2.Coder(
       spec=function_spec, component_coder_ids=component_ids)
Exemplo n.º 7
0
 def make_coder(urn, *components):
     # type: (str, *str) -> str
     """Register a coder with ``urn`` over the given component coder ids.

     The proto is stored under a fresh ``self.uid('coder')`` id, both in
     ``coders`` and in ``pipeline_context.coders`` (via ``put_proto``).

     Args:
       urn: URN of the coder's FunctionSpec.
       *components: ids of the component coders, in order.

     Returns:
       The newly allocated coder id.
     """
     coder_proto = beam_runner_api_pb2.Coder(
         spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
         component_coder_ids=components)
     coder_id = self.uid('coder')
     coders[coder_id] = coder_proto
     pipeline_context.coders.put_proto(coder_id, coder_proto)
     return coder_id
Exemplo n.º 8
0
 def parse_coder(self, spec):
   """Build a coder object from the dict description ``spec``.

   ``spec['urn']`` is required.  ``'payload'`` (a str, latin-1 encoded) and
   ``'components'`` (nested specs, parsed recursively) are optional.  The
   coder is assembled in a fresh PipelineContext under an id derived from
   ``hash(str(spec))`` and returned via ``get_by_id``.
   """
   context = pipeline_context.PipelineContext()
   coder_id = str(hash(str(spec)))
   component_ids = []
   for child_spec in spec.get('components', ()):
     component_ids.append(context.coders.get_id(self.parse_coder(child_spec)))
   function_spec = beam_runner_api_pb2.FunctionSpec(
       urn=spec['urn'], payload=spec.get('payload', '').encode('latin1'))
   context.coders.put_proto(
       coder_id,
       beam_runner_api_pb2.Coder(
           spec=function_spec, component_coder_ids=component_ids))
   return context.coders.get_by_id(coder_id)
 def length_prefixed_and_safe_coder(self, coder_id):
     """Return ``(length_prefixed_id, safe_id)`` for the coder ``coder_id``.

     * A LENGTH_PREFIX coder is returned as-is, paired with
       ``self.bytes_coder_id``.
     * A coder whose URN is in ``self._KNOWN_CODER_URNS`` keeps its spec.
       Its components are mapped through ``self.length_prefixed_coder`` and
       the results are then looked up in ``self.safe_coders``.  A new
       ``<id>_length_prefixed`` or ``<id>_safe`` coder is registered only
       when the mapped components differ from the original ones; otherwise
       ``coder_id`` is reused.
     * Any other coder is wrapped in a new LENGTH_PREFIX coder, paired with
       ``self.bytes_coder_id``.
     """
     registry = self.components.coders
     original = registry[coder_id]
     urn = original.spec.spec.urn

     if urn == common_urns.coders.LENGTH_PREFIX.urn:
         return coder_id, self.bytes_coder_id

     if urn not in self._KNOWN_CODER_URNS:
         wrapper_id = unique_name(registry, coder_id + '_length_prefixed')
         registry[wrapper_id].CopyFrom(
             beam_runner_api_pb2.Coder(
                 spec=beam_runner_api_pb2.SdkFunctionSpec(
                     spec=beam_runner_api_pb2.FunctionSpec(
                         urn=common_urns.coders.LENGTH_PREFIX.urn)),
                 component_coder_ids=[coder_id]))
         return wrapper_id, self.bytes_coder_id

     def same_spec_with(suffix, component_ids):
         # Register a copy of the original spec over `component_ids`.
         derived_id = unique_name(registry, coder_id + suffix)
         registry[derived_id].CopyFrom(
             beam_runner_api_pb2.Coder(
                 spec=original.spec, component_coder_ids=component_ids))
         return derived_id

     prefixed_ids = [
         self.length_prefixed_coder(child)
         for child in original.component_coder_ids
     ]
     if prefixed_ids == original.component_coder_ids:
         prefixed_coder_id = coder_id
     else:
         prefixed_coder_id = same_spec_with('_length_prefixed', prefixed_ids)

     safe_ids = [self.safe_coders[child] for child in prefixed_ids]
     if safe_ids == original.component_coder_ids:
         safe_coder_id = coder_id
     else:
         safe_coder_id = same_spec_with('_safe', safe_ids)
     return prefixed_coder_id, safe_coder_id
Exemplo n.º 10
0
 def to_runner_api(self, context):
     """Serialize this coder to a ``beam_runner_api_pb2.Coder`` proto.

     The URN, typed parameter and component coders come from
     ``to_runner_api_parameter(context)``.  A ``bytes`` or ``None``
     parameter is used as the payload directly; anything else is serialized
     with ``SerializeToString()``.  The spec's environment id is
     ``context.default_environment_id()`` when ``context`` is truthy, and
     ``None`` otherwise.
     """
     urn, typed_param, components = self.to_runner_api_parameter(context)
     environment_id = context.default_environment_id() if context else None
     if isinstance(typed_param, (bytes, type(None))):
         payload = typed_param
     else:
         payload = typed_param.SerializeToString()
     function_spec = beam_runner_api_pb2.FunctionSpec(urn=urn, payload=payload)
     sdk_spec = beam_runner_api_pb2.SdkFunctionSpec(
         environment_id=environment_id, spec=function_spec)
     # NOTE(review): `context` is tolerated as falsy above, but it is used
     # unconditionally here whenever there are components -- confirm intent.
     component_ids = [context.coders.get_id(c) for c in components]
     return beam_runner_api_pb2.Coder(
         spec=sdk_spec, component_coder_ids=component_ids)
Exemplo n.º 11
0
 def to_runner_api(self, context):
     """For internal use only; no backwards-compatibility guarantees.

     Encodes this coder as a ``urns.PICKLED_CODER`` spec.  The bytes from
     ``serialize_coder(self)`` are stored as the ``payload``, and a
     ``BytesValue`` wrapping them is packed into ``any_param``.  ``context``
     is not used.
     """
     # TODO(BEAM-115): Use specialized URNs and components.
     pickled = serialize_coder(self)
     packed_param = proto_utils.pack_Any(
         google.protobuf.wrappers_pb2.BytesValue(value=pickled))
     function_spec = beam_runner_api_pb2.FunctionSpec(
         urn=urns.PICKLED_CODER, any_param=packed_param, payload=pickled)
     return beam_runner_api_pb2.Coder(
         spec=beam_runner_api_pb2.SdkFunctionSpec(spec=function_spec))
Exemplo n.º 12
0
 def wrap_unknown_coders(coder_id, with_bytes):
     """Return a substitute for ``coder_id`` built only from known coders.

     Results are memoised in ``coder_substitutions`` keyed by
     ``(coder_id, with_bytes)``:

     * A LENGTH_PREFIX coder becomes ``bytes_coder_id`` when ``with_bytes``
       is true, and is kept as-is otherwise.
     * A coder whose URN is in ``good_coder_urns`` is kept when none of its
       components change.  Otherwise a copy named ``<id>_bytes`` or
       ``<id>_len_prefix`` is registered in ``coders`` with the substituted
       components.
     * Any other coder becomes ``bytes_coder_id`` when ``with_bytes`` is
       true, or is wrapped in a new LENGTH_PREFIX coder otherwise.

     Every newly created coder is also mapped to itself, so substituting it
     again is a no-op.
     """
     if (coder_id, with_bytes) not in coder_substitutions:
         wrapped_coder_id = None
         coder_proto = coders[coder_id]
         if coder_proto.spec.spec.urn == urns.LENGTH_PREFIX_CODER:
             coder_substitutions[coder_id, with_bytes] = (
                 bytes_coder_id if with_bytes else coder_id)
         elif coder_proto.spec.spec.urn in good_coder_urns:
             wrapped_components = [
                 wrap_unknown_coders(c, with_bytes)
                 for c in coder_proto.component_coder_ids
             ]
             if wrapped_components == list(coder_proto.component_coder_ids):
                 # Use as is.
                 coder_substitutions[coder_id, with_bytes] = coder_id
             else:
                 wrapped_coder_id = unique_name(
                     coders, coder_id +
                     ("_bytes" if with_bytes else "_len_prefix"))
                 coders[wrapped_coder_id].CopyFrom(coder_proto)
                 # Reuse the substitutions computed above instead of
                 # recursing over every component a second time.
                 coders[wrapped_coder_id].component_coder_ids[:] = (
                     wrapped_components)
                 coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
         else:
             # Not a known coder.
             if with_bytes:
                 coder_substitutions[coder_id, with_bytes] = bytes_coder_id
             else:
                 wrapped_coder_id = unique_name(
                     coders, coder_id + "_len_prefix")
                 len_prefix_coder_proto = beam_runner_api_pb2.Coder(
                     spec=beam_runner_api_pb2.SdkFunctionSpec(
                         spec=beam_runner_api_pb2.FunctionSpec(
                             urn=urns.LENGTH_PREFIX_CODER)),
                     component_coder_ids=[coder_id])
                 coders[wrapped_coder_id].CopyFrom(len_prefix_coder_proto)
                 coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
         # This operation is idempotent: a freshly built substitute maps to
         # itself.
         if wrapped_coder_id:
             coder_substitutions[wrapped_coder_id,
                                 with_bytes] = wrapped_coder_id
     return coder_substitutions[coder_id, with_bytes]
Exemplo n.º 13
0
def inject_timer_pcollections(stages, pipeline_context):
    """Create PCollections for fired timers and to-be-set timers.

    At execution time, fired timers and timers-to-set are represented as
    PCollections that are managed by the runner.  This phase adds the
    necessary collections, with their reads and writes, to any stages using
    timers.

    For each timer of a ParDo:

    * a "timers to read" PCollection is created, produced by an injected
      data-input transform and added as a ParDo input under the timer tag;
    * a "timers to write" PCollection is created, consumed by an injected
      data-output transform and added as a ParDo output under the timer tag.

    Both use a KV<key, timer> coder, where the key coder is the KV key
    component of the ParDo's input coder (or the whole input coder if it is
    not a KV coder).

    Args:
      stages: iterable of stages.  Each stage is yielded back, possibly with
        transforms appended and ``timer_pcollections`` extended.
      pipeline_context: provides ``components`` (pcollections/coders maps)
        and ``add_or_get_coder_id``.

    Raises:
      NotImplementedError: if a ParDo that uses timers has more than one
        input (e.g. side inputs).
    """
    for stage in stages:
        for transform in list(stage.transforms):
            if transform.spec.urn != common_urns.primitives.PAR_DO.urn:
                continue
            payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
            if not payload.timer_specs:
                continue
            # Validate and resolve the main input ONCE, before any timer
            # inputs are attached below.  Each timer adds an entry to
            # transform.inputs, so re-checking per timer would wrongly reject
            # a ParDo with several timers, and next(iter(...)) could pick a
            # timer PCollection instead of the real input.
            if len(transform.inputs) > 1:
                raise NotImplementedError('Timers and side inputs.')
            input_pcoll = pipeline_context.components.pcollections[
                next(iter(transform.inputs.values()))]
            windowing_strategy_id = input_pcoll.windowing_strategy_id
            is_bounded = input_pcoll.is_bounded
            # Create the appropriate key coder for the timer PCollections.
            key_coder_id = input_pcoll.coder_id
            input_coder = pipeline_context.components.coders[key_coder_id]
            if input_coder.spec.spec.urn == common_urns.coders.KV.urn:
                key_coder_id = input_coder.component_coder_ids[0]

            for tag, spec in payload.timer_specs.items():
                key_timer_coder_id = pipeline_context.add_or_get_coder_id(
                    beam_runner_api_pb2.Coder(
                        spec=beam_runner_api_pb2.SdkFunctionSpec(
                            spec=beam_runner_api_pb2.FunctionSpec(
                                urn=common_urns.coders.KV.urn)),
                        component_coder_ids=[
                            key_coder_id, spec.timer_coder_id
                        ]))
                # Inject the read and write pcollections.
                timer_read_pcoll = unique_name(
                    pipeline_context.components.pcollections,
                    '%s_timers_to_read_%s' % (transform.unique_name, tag))
                timer_write_pcoll = unique_name(
                    pipeline_context.components.pcollections,
                    '%s_timers_to_write_%s' % (transform.unique_name, tag))
                pipeline_context.components.pcollections[
                    timer_read_pcoll].CopyFrom(
                        beam_runner_api_pb2.PCollection(
                            unique_name=timer_read_pcoll,
                            coder_id=key_timer_coder_id,
                            windowing_strategy_id=windowing_strategy_id,
                            is_bounded=is_bounded))
                pipeline_context.components.pcollections[
                    timer_write_pcoll].CopyFrom(
                        beam_runner_api_pb2.PCollection(
                            unique_name=timer_write_pcoll,
                            coder_id=key_timer_coder_id,
                            windowing_strategy_id=windowing_strategy_id,
                            is_bounded=is_bounded))
                stage.transforms.append(
                    beam_runner_api_pb2.PTransform(
                        unique_name=timer_read_pcoll + '/Read',
                        outputs={'out': timer_read_pcoll},
                        spec=beam_runner_api_pb2.FunctionSpec(
                            urn=bundle_processor.DATA_INPUT_URN,
                            payload=create_buffer_id(timer_read_pcoll,
                                                     kind='timers'))))
                stage.transforms.append(
                    beam_runner_api_pb2.PTransform(
                        unique_name=timer_write_pcoll + '/Write',
                        inputs={'in': timer_write_pcoll},
                        spec=beam_runner_api_pb2.FunctionSpec(
                            urn=bundle_processor.DATA_OUTPUT_URN,
                            payload=create_buffer_id(timer_write_pcoll,
                                                     kind='timers'))))
                # Wire the timer collections into the ParDo under its tag.
                assert tag not in transform.inputs
                transform.inputs[tag] = timer_read_pcoll
                assert tag not in transform.outputs
                transform.outputs[tag] = timer_write_pcoll
                stage.timer_pcollections.append(
                    (timer_read_pcoll + '/Read', timer_write_pcoll))
        yield stage
Exemplo n.º 14
0
def lift_combiners(stages, context):
    """Expand each CombinePerKey stage into pre- and post-grouping stages.

    ... -> CombinePerKey -> ...

    becomes

    ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...

    Each incoming stage must hold exactly one transform (this is asserted).
    Stages whose transform is not a CombinePerKey are yielded unchanged.

    For a CombinePerKey, three intermediate PCollections (and the KV and
    iterable coders they need) are registered in ``context.components``.
    Four new single-transform stages are then yielded in pipeline order.
    The Precombine, Merge and ExtractOutputs transforms all reuse the
    original transform's CombinePayload bytes.

    Args:
      stages: iterable of stages to rewrite.
      context: object exposing ``components`` (with ``pcollections`` and
        ``coders`` maps) and ``add_or_get_coder_id``.

    Yields:
      The (possibly expanded) stages, in order.
    """
    for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
            combine_payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.CombinePayload)

            # The transform is expected to have exactly one input and one
            # output (only_element presumably enforces this).
            input_pcoll = context.components.pcollections[only_element(
                list(transform.inputs.values()))]
            output_pcoll = context.components.pcollections[only_element(
                list(transform.outputs.values()))]

            # The input element coder must have exactly two components; the
            # first is taken as the key coder (presumably a KV coder).
            element_coder_id = input_pcoll.coder_id
            element_coder = context.components.coders[element_coder_id]
            key_coder_id, _ = element_coder.component_coder_ids
            accumulator_coder_id = combine_payload.accumulator_coder_id

            # KV<key, accumulator>: coder for the Precombine and Merge outputs.
            key_accumulator_coder = beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.SdkFunctionSpec(
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.coders.KV.urn)),
                component_coder_ids=[key_coder_id, accumulator_coder_id])
            key_accumulator_coder_id = context.add_or_get_coder_id(
                key_accumulator_coder)

            # Iterable<accumulator>: the grouped values after the GBK.
            accumulator_iter_coder = beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.SdkFunctionSpec(
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.coders.ITERABLE.urn)),
                component_coder_ids=[accumulator_coder_id])
            accumulator_iter_coder_id = context.add_or_get_coder_id(
                accumulator_iter_coder)

            # KV<key, Iterable<accumulator>>: coder for the Group output.
            key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.SdkFunctionSpec(
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.coders.KV.urn)),
                component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
            key_accumulator_iter_coder_id = context.add_or_get_coder_id(
                key_accumulator_iter_coder)

            # Precombine output keeps the INPUT's windowing/boundedness.
            precombined_pcoll_id = unique_name(context.components.pcollections,
                                               'pcollection')
            context.components.pcollections[precombined_pcoll_id].CopyFrom(
                beam_runner_api_pb2.PCollection(
                    unique_name=transform.unique_name + '/Precombine.out',
                    coder_id=key_accumulator_coder_id,
                    windowing_strategy_id=input_pcoll.windowing_strategy_id,
                    is_bounded=input_pcoll.is_bounded))

            # Group and Merge outputs take the OUTPUT's windowing/boundedness.
            grouped_pcoll_id = unique_name(context.components.pcollections,
                                           'pcollection')
            context.components.pcollections[grouped_pcoll_id].CopyFrom(
                beam_runner_api_pb2.PCollection(
                    unique_name=transform.unique_name + '/Group.out',
                    coder_id=key_accumulator_iter_coder_id,
                    windowing_strategy_id=output_pcoll.windowing_strategy_id,
                    is_bounded=output_pcoll.is_bounded))

            merged_pcoll_id = unique_name(context.components.pcollections,
                                          'pcollection')
            context.components.pcollections[merged_pcoll_id].CopyFrom(
                beam_runner_api_pb2.PCollection(
                    unique_name=transform.unique_name + '/Merge.out',
                    coder_id=key_accumulator_coder_id,
                    windowing_strategy_id=output_pcoll.windowing_strategy_id,
                    is_bounded=output_pcoll.is_bounded))

            def make_stage(base_stage, transform):
                """Wrap `transform` in a Stage inheriting base_stage's
                downstream side inputs and must-follow set, with
                base_stage's name recorded as parent."""
                return Stage(
                    transform.unique_name, [transform],
                    downstream_side_inputs=base_stage.downstream_side_inputs,
                    must_follow=base_stage.must_follow,
                    parent=base_stage.name)

            # Precombine consumes the original inputs.
            yield make_stage(
                stage,
                beam_runner_api_pb2.PTransform(
                    unique_name=transform.unique_name + '/Precombine',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.combine_components.
                        COMBINE_PER_KEY_PRECOMBINE.urn,
                        payload=transform.spec.payload),
                    inputs=transform.inputs,
                    outputs={'out': precombined_pcoll_id}))

            yield make_stage(
                stage,
                beam_runner_api_pb2.PTransform(
                    unique_name=transform.unique_name + '/Group',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.primitives.GROUP_BY_KEY.urn),
                    inputs={'in': precombined_pcoll_id},
                    outputs={'out': grouped_pcoll_id}))

            yield make_stage(
                stage,
                beam_runner_api_pb2.PTransform(
                    unique_name=transform.unique_name + '/Merge',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.combine_components.
                        COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
                        payload=transform.spec.payload),
                    inputs={'in': grouped_pcoll_id},
                    outputs={'out': merged_pcoll_id}))

            # ExtractOutputs produces the original outputs.
            yield make_stage(
                stage,
                beam_runner_api_pb2.PTransform(
                    unique_name=transform.unique_name + '/ExtractOutputs',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.combine_components.
                        COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
                        payload=transform.spec.payload),
                    inputs={'in': merged_pcoll_id},
                    outputs=transform.outputs))

        else:
            yield stage
Exemplo n.º 15
0
def expand_sdf(stages, context):
  """Expand splittable DoFns into pair + split + process stages.

  Each incoming stage must hold exactly one transform (this is asserted).
  A ParDo whose payload is marked ``splittable`` is replaced by three
  stages, yielded in this order:

    1. PairWithRestriction: main input -> KV<element, restriction>.
    2. SplitRestriction: KV<element, restriction> -> a PCollection with the
       same KV coder.
    3. Process: consumes the split output; it must follow the split stage.

  Each new transform is a copy of the original ParDo with its URN (and
  inputs/outputs as needed) overridden.  New PCollections and transforms
  are registered in ``context.components``.  All other stages are yielded
  unchanged.
  """
  for stage in stages:
    assert len(stage.transforms) == 1
    transform = stage.transforms[0]
    if transform.spec.urn == common_urns.primitives.PAR_DO.urn:

      pardo_payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)

      if pardo_payload.splittable:

        def copy_like(protos, original, suffix='_copy', **kwargs):
          """Copy `original` into `protos` under a fresh id; return that id.

          `original` is either an id into `protos` (the new id is derived
          from it plus `suffix`) or a proto message (the new id is derived
          from 'component' plus `suffix`).  Each kwarg then overrides a
          field of the copy: a dict replaces a map field, a list replaces a
          repeated field, `urn` sets `spec.urn`, and anything else is set
          with setattr.
          """
          # NOTE(review): `unicode` is a Python 2 builtin; presumably
          # supplied by a compatibility import elsewhere -- confirm.
          # NOTE(review): fields not overridden (e.g. a PCollection's
          # unique_name) are kept from the original, so copies share it --
          # confirm this is intended.
          if isinstance(original, (str, unicode)):
            key = original
            original = protos[original]
          else:
            key = 'component'
          new_id = unique_name(protos, key + suffix)
          protos[new_id].CopyFrom(original)
          proto = protos[new_id]
          for name, value in kwargs.items():
            if isinstance(value, dict):
              getattr(proto, name).clear()
              getattr(proto, name).update(value)
            elif isinstance(value, list):
              del getattr(proto, name)[:]
              getattr(proto, name).extend(value)
            elif name == 'urn':
              proto.spec.urn = value
            else:
              setattr(proto, name, value)
          return new_id

        def make_stage(base_stage, transform_id, extra_must_follow=()):
          """Wrap the registered transform `transform_id` in a new Stage.

          The stage inherits downstream side inputs, must-follow stages
          (plus `extra_must_follow`) and environment from `base_stage`,
          which is also recorded as its parent.
          """
          transform = context.components.transforms[transform_id]
          return Stage(
              transform.unique_name,
              [transform],
              base_stage.downstream_side_inputs,
              union(base_stage.must_follow, frozenset(extra_must_follow)),
              parent=base_stage,
              environment=base_stage.environment)

        # The main input is the single input tag that is not a side input.
        main_input_tag = only_element(tag for tag in transform.inputs.keys()
                                      if tag not in pardo_payload.side_inputs)
        main_input_id = transform.inputs[main_input_tag]
        element_coder_id = context.components.pcollections[
            main_input_id].coder_id
        # KV<element, restriction>: coder of both intermediate PCollections.
        paired_coder_id = context.add_or_get_coder_id(
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.SdkFunctionSpec(
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.coders.KV.urn)),
                component_coder_ids=[element_coder_id,
                                     pardo_payload.restriction_coder_id]))

        paired_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_paired',
            coder_id=paired_coder_id)
        # Pair keeps the original inputs and emits the paired PCollection.
        pair_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/PairWithRestriction',
            urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
            outputs={'out': paired_pcoll_id})

        split_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_split',
            coder_id=paired_coder_id)
        # Split reads the paired PCollection in place of the main input.
        split_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/SplitRestriction',
            urn=common_urns.sdf_components.SPLIT_RESTRICTION.urn,
            inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
            outputs={'out': split_pcoll_id})

        # Process reads the split PCollection and keeps the original outputs.
        process_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/Process',
            urn=common_urns.sdf_components.PROCESS_ELEMENTS.urn,
            inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}))

        yield make_stage(stage, pair_transform_id)
        split_stage = make_stage(stage, split_transform_id)
        yield split_stage
        yield make_stage(
            stage, process_transform_id, extra_must_follow=[split_stage])

      else:
        yield stage

    else:
      yield stage
Exemplo n.º 16
0
    def lift_combiners(stages):
      """Expands CombinePerKey into pre- and post-grouping stages.

      ... -> CombinePerKey -> ...

      becomes

      ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...

      Each incoming stage must hold exactly one transform (this is asserted).
      Non-CombinePerKey stages are yielded unchanged.  The input PCollection's
      coder must have exactly two components (element coder, window coder),
      and the element coder exactly two (key coder, value coder).  The new
      intermediate PCollections use windowed versions of the KV coders built
      here.  New coders and PCollections are added to the enclosing
      ``pipeline_components``.
      """
      def add_or_get_coder_id(coder_proto):
        """Return the id of a coder equal to `coder_proto`, adding it if new.

        Performs a linear scan over all existing coders; a new coder is
        registered under a fresh id derived from 'coder'.
        """
        for coder_id, coder in pipeline_components.coders.items():
          if coder == coder_proto:
            return coder_id
        new_coder_id = unique_name(pipeline_components.coders, 'coder')
        pipeline_components.coders[new_coder_id].CopyFrom(coder_proto)
        return new_coder_id

      def windowed_coder_id(coder_id):
        """Return the id of a WINDOWED_VALUE coder over `coder_id`.

        Uses the closure variable `window_coder_id`, which is only assigned
        inside the loop below (per CombinePerKey transform), so this must
        not be called before that assignment.
        """
        proto = beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.SdkFunctionSpec(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=urns.WINDOWED_VALUE_CODER)),
            component_coder_ids=[coder_id, window_coder_id])
        return add_or_get_coder_id(proto)

      for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == urns.COMBINE_PER_KEY_TRANSFORM:
          combine_payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.CombinePayload)

          input_pcoll = pipeline_components.pcollections[only_element(
              transform.inputs.values())]
          output_pcoll = pipeline_components.pcollections[only_element(
              transform.outputs.values())]

          # Unwrap the windowed input coder: (element coder, window coder).
          # This also binds window_coder_id for windowed_coder_id above.
          windowed_input_coder = pipeline_components.coders[
              input_pcoll.coder_id]
          element_coder_id, window_coder_id = (
              windowed_input_coder.component_coder_ids)
          element_coder = pipeline_components.coders[element_coder_id]
          key_coder_id, _ = element_coder.component_coder_ids
          accumulator_coder_id = combine_payload.accumulator_coder_id

          # KV<key, accumulator>: Precombine and Merge outputs.
          key_accumulator_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_coder_id])
          key_accumulator_coder_id = add_or_get_coder_id(key_accumulator_coder)

          # Iterable<accumulator>: grouped values after the GBK.
          accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.ITERABLE_CODER)),
              component_coder_ids=[accumulator_coder_id])
          accumulator_iter_coder_id = add_or_get_coder_id(
              accumulator_iter_coder)

          # KV<key, Iterable<accumulator>>: Group output.
          key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
          key_accumulator_iter_coder_id = add_or_get_coder_id(
              key_accumulator_iter_coder)

          # Precombine output keeps the INPUT's windowing/boundedness.
          precombined_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[precombined_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Precombine.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
                  is_bounded=input_pcoll.is_bounded))

          # Group and Merge outputs take the OUTPUT's windowing/boundedness.
          grouped_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[grouped_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Group.out',
                  coder_id=windowed_coder_id(key_accumulator_iter_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          merged_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[merged_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Merge.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          def make_stage(base_stage, transform):
            """Wrap `transform` in a Stage inheriting base_stage's downstream
            side inputs and must-follow set."""
            return Stage(
                transform.unique_name,
                [transform],
                downstream_side_inputs=base_stage.downstream_side_inputs,
                must_follow=base_stage.must_follow)

          # Precombine consumes the original inputs.
          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Precombine',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.PRECOMBINE_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs=transform.inputs,
                  outputs={'out': precombined_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Group',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.GROUP_BY_KEY_TRANSFORM),
                  inputs={'in': precombined_pcoll_id},
                  outputs={'out': grouped_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Merge',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.MERGE_ACCUMULATORS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': grouped_pcoll_id},
                  outputs={'out': merged_pcoll_id}))

          # ExtractOutputs produces the original outputs.
          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/ExtractOutputs',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.EXTRACT_OUTPUTS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': merged_pcoll_id},
                  outputs=transform.outputs))

        else:
          yield stage