Beispiel #1
0
 def to_runner_api(self, context):
     """Build the runner-API ``PCollection`` proto for this collection.

     The coder id and windowing-strategy id are looked up through
     ``context``; the collection is always reported as bounded.
     """
     name = self._unique_name()
     element_coder_id = context.coder_id_from_element_type(self.element_type)
     strategy_id = context.windowing_strategies.get_id(self.windowing)
     return beam_runner_api_pb2.PCollection(
         unique_name=name,
         coder_id=element_coder_id,
         is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED,
         windowing_strategy_id=strategy_id)
def read_to_impulse(stages, pipeline_context):
  """Rewrites deprecated Read transforms as Impulse plus an impulse-read.

  A transform whose URN is the deprecated READ primitive is removed from its
  stage and two transforms are appended instead: an Impulse writing to a
  newly registered, bytes-coded PCollection, and a transform with the
  ``IMPULSE_READ_TRANSFORM`` URN that consumes that PCollection, carries the
  original payload and writes to the original output PCollection.

  Args:
    stages: iterable of stages; each ``stage.transforms`` is edited in place.
    pipeline_context: supplies ``components.pcollections`` (which gains the
      new impulse PCollections) and ``bytes_coder_id``.

  Yields:
    Every input stage, after its Read transforms have been rewritten.
  """
  for stage in stages:
    # Walk a snapshot, since the stage's transform list is edited below.
    for read_transform in tuple(stage.transforms):
      if read_transform.spec.urn != common_urns.deprecated_primitives.READ.urn:
        continue

      pcollections = pipeline_context.components.pcollections
      output_id = only_element(read_transform.outputs.values())
      output_proto = pcollections[output_id]

      # The impulse output borrows windowing and boundedness from the
      # original read output.
      impulse_id = unique_name(pcollections, 'Impulse')
      impulse_slot = pcollections[impulse_id]
      impulse_slot.CopyFrom(
          beam_runner_api_pb2.PCollection(
              unique_name=impulse_id,
              coder_id=pipeline_context.bytes_coder_id,
              windowing_strategy_id=output_proto.windowing_strategy_id,
              is_bounded=output_proto.is_bounded))

      stage.transforms.remove(read_transform)
      # TODO(robertwb): If this goes multi-process before fn-api
      # read is default, expand into split + reshuffle + read.
      impulse_spec = beam_runner_api_pb2.FunctionSpec(
          urn=common_urns.primitives.IMPULSE.urn)
      stage.transforms.append(
          beam_runner_api_pb2.PTransform(
              unique_name=read_transform.unique_name + '/Impulse',
              spec=impulse_spec,
              outputs={'out': impulse_id}))
      triggered_read_spec = beam_runner_api_pb2.FunctionSpec(
          urn=python_urns.IMPULSE_READ_TRANSFORM,
          payload=read_transform.spec.payload)
      stage.transforms.append(
          beam_runner_api_pb2.PTransform(
              unique_name=read_transform.unique_name,
              spec=triggered_read_spec,
              inputs={'in': impulse_id},
              outputs={'out': output_id}))

    yield stage
Beispiel #3
0
 def to_runner_api(self, context):
     """Build the runner-API ``PCollection`` proto for this collection.

     The unique name is the length of the producer's full label followed by
     the label itself, a dot and this collection's tag.  The ``coder_id``
     field is filled with the pickled element type, and the collection is
     always reported as bounded.
     """
     producer_label = self.producer.full_label
     name = '%d%s.%s' % (len(producer_label), producer_label, self.tag)
     pickled_element_type = pickler.dumps(self.element_type)
     strategy_id = context.windowing_strategies.get_id(self.windowing)
     return beam_runner_api_pb2.PCollection(
         unique_name=name,
         coder_id=pickled_element_type,
         is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED,
         windowing_strategy_id=strategy_id)
Beispiel #4
0
 def to_runner_api(self, context):
     # type: (PipelineContext) -> beam_runner_api_pb2.PCollection
     """Build the runner-API ``PCollection`` proto for this collection.

     The coder id is derived from the element type (honouring
     ``requires_deterministic_key_coder``), boundedness follows
     ``self.is_bounded``, and the windowing-strategy id is looked up
     through ``context``.
     """
     name = self._unique_name()
     element_coder_id = context.coder_id_from_element_type(
         self.element_type, self.requires_deterministic_key_coder)
     if self.is_bounded:
         boundedness = beam_runner_api_pb2.IsBounded.BOUNDED
     else:
         boundedness = beam_runner_api_pb2.IsBounded.UNBOUNDED
     strategy_id = context.windowing_strategies.get_id(self.windowing)
     return beam_runner_api_pb2.PCollection(
         unique_name=name,
         coder_id=element_coder_id,
         is_bounded=boundedness,
         windowing_strategy_id=strategy_id)
Beispiel #5
0
 def to_runner_api(self, context):
   """Build the runner-API ``PCollection`` proto for this collection.

   The unique name is the length of the producer's full label followed by
   the label itself, a dot and this collection's tag.  The ``coder_id``
   field is filled with the pickled element type, and the collection is
   always reported as bounded.
   """
   # Imported at call time, exactly as before.
   from apache_beam.portability.api import beam_runner_api_pb2
   from apache_beam.internal import pickler
   producer_label = self.producer.full_label
   name = '%d%s.%s' % (len(producer_label), producer_label, self.tag)
   pickled_element_type = pickler.dumps(self.element_type)
   strategy_id = context.windowing_strategies.get_id(self.windowing)
   return beam_runner_api_pb2.PCollection(
       unique_name=name,
       coder_id=pickled_element_type,
       is_bounded=beam_runner_api_pb2.BOUNDED,
       windowing_strategy_id=strategy_id)
Beispiel #6
0
  def _map_task_to_protos(self, map_task, data_operation_spec):
    """Translates a map task into a ProcessBundleDescriptor plus runner data.

    Each (stage_name, operation) entry of the map task becomes one PTransform
    proto; every operation output that is referenced gets a PCollection proto.

    Args:
      map_task: indexable sequence of (stage_name, operation) pairs.
      data_operation_spec: spec attached (packed into ``any_param`` and
        serialized into ``payload``) to the data input/output transforms.
        When it is None the payload is None.

    Returns:
      A tuple ``(input_data, side_input_data, runner_sinks,
      process_bundle_descriptor)`` where
        * input_data maps (transform_id, output tag) to the re-encoded
          elements of in-memory sources,
        * side_input_data maps side-input tags to re-encoded side-input
          elements,
        * runner_sinks maps (transform_id, input tag) to the
          WorkerInMemoryWrite operation that receives that data,
        * process_bundle_descriptor is the assembled proto.

    Raises:
      NotImplementedError: if an operation is not a WorkerInMemoryWrite,
        WorkerRead, WorkerDoFn or WorkerFlatten.
    """
    input_data = {}
    side_input_data = {}
    runner_sinks = {}

    context = pipeline_context.PipelineContext()
    transform_protos = {}
    # Maps (operation index, output index) -> generated PCollection id.
    # Filled lazily by pcollection_id() as inputs and outputs are resolved.
    used_pcollections = {}

    def uniquify(*names):
      # An injective mapping from string* to string.
      return ':'.join("%s:%d" % (name, len(name)) for name in names)

    def pcollection_id(op_ix, out_ix):
      # Returns the id for the given output, allocating it on first use.
      if (op_ix, out_ix) not in used_pcollections:
        used_pcollections[op_ix, out_ix] = uniquify(
            map_task[op_ix][0], 'out', str(out_ix))
      return used_pcollections[op_ix, out_ix]

    def get_inputs(op):
      # Builds the {'in<N>': pcollection id} map for an operation that has
      # either an ``inputs`` list, a single ``input``, or neither.
      # Each input is unpacked into pcollection_id(), so it is presumably an
      # (op_ix, out_ix) pair -- confirm against the operation specs.
      if hasattr(op, 'inputs'):
        inputs = op.inputs
      elif hasattr(op, 'input'):
        inputs = [op.input]
      else:
        inputs = []
      return {'in%s' % ix: pcollection_id(*input)
              for ix, input in enumerate(inputs)}

    def get_outputs(op_ix):
      # Builds the {tag: pcollection id} map for the operation at op_ix;
      # operations without ``output_tags`` get the single tag 'out'.
      op = map_task[op_ix][1]
      return {tag: pcollection_id(op_ix, out_ix)
              for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))}

    for op_ix, (stage_name, operation) in enumerate(map_task):
      transform_id = uniquify(stage_name)

      if isinstance(operation, operation_specs.WorkerInMemoryWrite):
        # Write this data back to the runner.
        target_name = only_element(get_inputs(operation).keys())
        runner_sinks[(transform_id, target_name)] = operation
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.DATA_OUTPUT_URN,
            any_param=proto_utils.pack_Any(data_operation_spec),
            payload=data_operation_spec.SerializeToString() \
                if data_operation_spec is not None else None)

      elif isinstance(operation, operation_specs.WorkerRead):
        # A Read from an in-memory source is done over the data plane.
        if (isinstance(operation.source.source,
                       maptask_executor_runner.InMemorySource)
            and isinstance(operation.source.source.default_output_coder(),
                           WindowedValueCoder)):
          target_name = only_element(get_outputs(op_ix).keys())
          # The source is read here and its elements are re-encoded so they
          # can be handed over as input data for this transform.
          input_data[(transform_id, target_name)] = self._reencode_elements(
              operation.source.source.read(None),
              operation.source.source.default_output_coder())
          transform_spec = beam_runner_api_pb2.FunctionSpec(
              urn=bundle_processor.DATA_INPUT_URN,
              any_param=proto_utils.pack_Any(data_operation_spec),
              payload=data_operation_spec.SerializeToString() \
                  if data_operation_spec is not None else None)

        else:
          # Otherwise serialize the source and execute it there.
          # TODO: Use SDFs with an initial impulse.
          # The Dataflow runner harness strips the base64 encoding. do the same
          # here until we get the same thing back that we sent in.
          source_bytes = base64.b64decode(
              pickler.dumps(operation.source.source))
          transform_spec = beam_runner_api_pb2.FunctionSpec(
              urn=bundle_processor.PYTHON_SOURCE_URN,
              any_param=proto_utils.pack_Any(
                  wrappers_pb2.BytesValue(
                      value=source_bytes)),
              payload=source_bytes)

      elif isinstance(operation, operation_specs.WorkerDoFn):
        # Record the contents of each side input for access via the state api.
        side_input_extras = []
        for si in operation.side_inputs:
          assert isinstance(si.source, iobase.BoundedSource)
          element_coder = si.source.default_output_coder()
          # TODO(robertwb): Actually flesh out the ViewFn API.
          side_input_extras.append((si.tag, element_coder))
          side_input_data[
              bundle_processor.side_input_tag(transform_id, si.tag)] = (
                  self._reencode_elements(
                      si.source.read(si.source.get_range_tracker(None, None)),
                      element_coder))
        # The (tag, coder) pairs are pickled together with the serialized
        # DoFn and shipped as the transform payload.
        augmented_serialized_fn = pickler.dumps(
            (operation.serialized_fn, side_input_extras))
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.PYTHON_DOFN_URN,
            any_param=proto_utils.pack_Any(
                wrappers_pb2.BytesValue(value=augmented_serialized_fn)),
            payload=augmented_serialized_fn)

      elif isinstance(operation, operation_specs.WorkerFlatten):
        # Flatten is nice and simple.
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.IDENTITY_DOFN_URN)

      else:
        raise NotImplementedError(operation)

      # Resolving inputs/outputs here also records every PCollection this
      # transform touches in used_pcollections.
      transform_protos[transform_id] = beam_runner_api_pb2.PTransform(
          unique_name=stage_name,
          spec=transform_spec,
          inputs=get_inputs(operation),
          outputs=get_outputs(op_ix))

    # One proto per referenced output; the coder id is looked up in the
    # context from the producing operation's output coder.
    pcollection_protos = {
        name: beam_runner_api_pb2.PCollection(
            unique_name=name,
            coder_id=context.coders.get_id(
                map_task[op_id][1].output_coders[out_id]))
        for (op_id, out_id), name in used_pcollections.items()
    }
    # Must follow creation of pcollection_protos to capture used coders.
    context_proto = context.to_runner_api()
    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms=transform_protos,
        pcollections=pcollection_protos,
        coders=dict(context_proto.coders.items()),
        windowing_strategies=dict(context_proto.windowing_strategies.items()),
        environments=dict(context_proto.environments.items()))
    return input_data, side_input_data, runner_sinks, process_bundle_descriptor
def inject_timer_pcollections(stages, pipeline_context):
    """Create PCollections for fired timers and to-be-set timers.

    At execution time, fired timers and timers-to-set are represented as
    PCollections that are managed by the runner.  This phase adds the
    necessary collections, with their reads and writes, to any stages using
    timers.

    For every timer declared by a ParDo, a read PCollection (fired timers)
    and a write PCollection (timers to set) are registered, data input/output
    transforms for them are appended to the stage, the ParDo gains an input
    and an output named after the timer tag, and the pair is recorded in
    ``stage.timer_pcollections``.

    Args:
        stages: iterable of stages; ``stage.transforms`` and
            ``stage.timer_pcollections`` are extended in place.
        pipeline_context: provides ``components.pcollections``,
            ``components.coders`` and ``add_or_get_coder_id``.

    Yields:
        Each input stage, after timer collections have been injected.

    Raises:
        NotImplementedError: if a ParDo that declares timers has more than
            one input before any timer input has been added (i.e. it also
            has side inputs).
    """
    for stage in stages:
        # Iterate over a snapshot: timer read/write transforms are appended
        # to the stage while it is being walked.
        for transform in list(stage.transforms):
            if transform.spec.urn != common_urns.primitives.PAR_DO.urn:
                continue
            payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
            timer_specs = list(payload.timer_specs.items())
            if not timer_specs:
                continue

            # Check for side inputs and resolve the main input once, before
            # anything is injected.  Every timer adds an entry to
            # transform.inputs below, so repeating this per timer would
            # reject any ParDo that declares more than one timer.
            if len(transform.inputs) > 1:
                raise NotImplementedError('Timers and side inputs.')
            pcollections = pipeline_context.components.pcollections
            coders = pipeline_context.components.coders
            input_pcoll = pcollections[next(iter(transform.inputs.values()))]
            # Copy out the scalar fields now so they do not depend on the
            # proto reference staying usable while the map is extended.
            windowing_strategy_id = input_pcoll.windowing_strategy_id
            is_bounded = input_pcoll.is_bounded

            # Create the appropriate coder for the timer PCollection: when
            # the input is KV-coded, timers are keyed by its first component.
            key_coder_id = input_pcoll.coder_id
            if coders[key_coder_id].spec.spec.urn == common_urns.coders.KV.urn:
                key_coder_id = coders[key_coder_id].component_coder_ids[0]

            for tag, spec in timer_specs:
                key_timer_coder_id = pipeline_context.add_or_get_coder_id(
                    beam_runner_api_pb2.Coder(
                        spec=beam_runner_api_pb2.SdkFunctionSpec(
                            spec=beam_runner_api_pb2.FunctionSpec(
                                urn=common_urns.coders.KV.urn)),
                        component_coder_ids=[
                            key_coder_id, spec.timer_coder_id
                        ]))

                # Inject the read and write pcollections.
                timer_read_pcoll = unique_name(
                    pcollections,
                    '%s_timers_to_read_%s' % (transform.unique_name, tag))
                timer_write_pcoll = unique_name(
                    pcollections,
                    '%s_timers_to_write_%s' % (transform.unique_name, tag))
                pcollections[timer_read_pcoll].CopyFrom(
                    beam_runner_api_pb2.PCollection(
                        unique_name=timer_read_pcoll,
                        coder_id=key_timer_coder_id,
                        windowing_strategy_id=windowing_strategy_id,
                        is_bounded=is_bounded))
                pcollections[timer_write_pcoll].CopyFrom(
                    beam_runner_api_pb2.PCollection(
                        unique_name=timer_write_pcoll,
                        coder_id=key_timer_coder_id,
                        windowing_strategy_id=windowing_strategy_id,
                        is_bounded=is_bounded))

                # Fired timers enter through a data-input transform and
                # timers to set leave through a data-output transform, both
                # addressed by a 'timers' buffer id.
                stage.transforms.append(
                    beam_runner_api_pb2.PTransform(
                        unique_name=timer_read_pcoll + '/Read',
                        outputs={'out': timer_read_pcoll},
                        spec=beam_runner_api_pb2.FunctionSpec(
                            urn=bundle_processor.DATA_INPUT_URN,
                            payload=create_buffer_id(timer_read_pcoll,
                                                     kind='timers'))))
                stage.transforms.append(
                    beam_runner_api_pb2.PTransform(
                        unique_name=timer_write_pcoll + '/Write',
                        inputs={'in': timer_write_pcoll},
                        spec=beam_runner_api_pb2.FunctionSpec(
                            urn=bundle_processor.DATA_OUTPUT_URN,
                            payload=create_buffer_id(timer_write_pcoll,
                                                     kind='timers'))))

                # Wire the timer collections into the ParDo under the
                # timer's tag.
                assert tag not in transform.inputs
                transform.inputs[tag] = timer_read_pcoll
                assert tag not in transform.outputs
                transform.outputs[tag] = timer_write_pcoll
                stage.timer_pcollections.append(
                    (timer_read_pcoll + '/Read', timer_write_pcoll))
        yield stage
def lift_combiners(stages, context):
    """Split every CombinePerKey stage into four single-transform stages.

    ... -> CombinePerKey -> ...

    is rewritten as

    ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...

    The three intermediate PCollections and the coders they use are
    registered on ``context`` as a side effect.  A stage whose single
    transform is not a CombinePerKey is passed through untouched.

    Args:
        stages: iterable of stages, each holding exactly one transform.
        context: provides ``components.pcollections``,
            ``components.coders`` and ``add_or_get_coder_id``.

    Yields:
        The original stage, or its four replacement stages in pipeline order.
    """
    def coder_of(urn, component_ids):
        # Coder proto with the given URN over the given component coder ids.
        return beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.SdkFunctionSpec(
                spec=beam_runner_api_pb2.FunctionSpec(urn=urn)),
            component_coder_ids=component_ids)

    def new_pcollection(name, coder_id, like):
        # Registers a PCollection that takes its windowing strategy and
        # boundedness from ``like``; returns the newly allocated id.
        pcolls = context.components.pcollections
        pcoll_id = unique_name(pcolls, 'pcollection')
        pcolls[pcoll_id].CopyFrom(
            beam_runner_api_pb2.PCollection(
                unique_name=name,
                coder_id=coder_id,
                windowing_strategy_id=like.windowing_strategy_id,
                is_bounded=like.is_bounded))
        return pcoll_id

    def derived_stage(base_stage, name, spec, inputs, outputs):
        # Wraps one replacement transform in a stage that inherits the
        # scheduling constraints of the stage it was derived from.
        piece = beam_runner_api_pb2.PTransform(
            unique_name=name, spec=spec, inputs=inputs, outputs=outputs)
        return Stage(
            piece.unique_name, [piece],
            downstream_side_inputs=base_stage.downstream_side_inputs,
            must_follow=base_stage.must_follow,
            parent=base_stage.name)

    for stage in stages:
        assert len(stage.transforms) == 1
        combine = stage.transforms[0]
        if combine.spec.urn != common_urns.composites.COMBINE_PER_KEY.urn:
            yield stage
            continue

        combine_payload = proto_utils.parse_Bytes(
            combine.spec.payload, beam_runner_api_pb2.CombinePayload)

        source_pcoll = context.components.pcollections[only_element(
            list(combine.inputs.values()))]
        sink_pcoll = context.components.pcollections[only_element(
            list(combine.outputs.values()))]

        # The input coder must have exactly two components; the first is
        # the key coder.
        input_coder = context.components.coders[source_pcoll.coder_id]
        key_coder_id, _ = input_coder.component_coder_ids
        accumulator_coder_id = combine_payload.accumulator_coder_id

        # KV<key, accumulator>, Iterable<accumulator> and
        # KV<key, Iterable<accumulator>>, registered in that order.
        kv_urn = common_urns.coders.KV.urn
        key_acc_coder_id = context.add_or_get_coder_id(
            coder_of(kv_urn, [key_coder_id, accumulator_coder_id]))
        acc_iter_coder_id = context.add_or_get_coder_id(
            coder_of(common_urns.coders.ITERABLE.urn, [accumulator_coder_id]))
        key_acc_iter_coder_id = context.add_or_get_coder_id(
            coder_of(kv_urn, [key_coder_id, acc_iter_coder_id]))

        base_name = combine.unique_name
        precombined_id = new_pcollection(
            base_name + '/Precombine.out', key_acc_coder_id, source_pcoll)
        grouped_id = new_pcollection(
            base_name + '/Group.out', key_acc_iter_coder_id, sink_pcoll)
        merged_id = new_pcollection(
            base_name + '/Merge.out', key_acc_coder_id, sink_pcoll)

        # Each replacement transform is built only when its stage is
        # about to be yielded.
        yield derived_stage(
            stage,
            combine.unique_name + '/Precombine',
            beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.combine_components.
                COMBINE_PER_KEY_PRECOMBINE.urn,
                payload=combine.spec.payload),
            combine.inputs,
            {'out': precombined_id})

        yield derived_stage(
            stage,
            combine.unique_name + '/Group',
            beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.primitives.GROUP_BY_KEY.urn),
            {'in': precombined_id},
            {'out': grouped_id})

        yield derived_stage(
            stage,
            combine.unique_name + '/Merge',
            beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.combine_components.
                COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
                payload=combine.spec.payload),
            {'in': grouped_id},
            {'out': merged_id})

        yield derived_stage(
            stage,
            combine.unique_name + '/ExtractOutputs',
            beam_runner_api_pb2.FunctionSpec(
                urn=common_urns.combine_components.
                COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
                payload=combine.spec.payload),
            {'in': merged_id},
            combine.outputs)
Beispiel #9
0
    def make_process_bundle_descriptor(self, data_api_service_descriptor,
                                       state_api_service_descriptor):
        # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor
        """Creates a ProcessBundleDescriptor for invoking the WindowFn's
        merge operation.

        The descriptor holds three transforms: a data-input read into the
        'input' PCollection, a MERGE_WINDOWS transform from 'input' to
        'output', and a data-output write of 'output'.  Both PCollections use
        a global windowing strategy created here.

        Side effects: the new coders are registered on the pipeline context,
        and ``self.windowed_input_coder_impl``,
        ``self.windowed_output_coder_impl`` and ``self._bundle_processor_id``
        are set.

        Args:
            data_api_service_descriptor: if truthy, its URL is copied into the
                gRPC port payload of the read and write transforms; it is also
                passed as the descriptor's timer API service descriptor.
            state_api_service_descriptor: passed through to the descriptor.

        Returns:
            The assembled ``ProcessBundleDescriptor``.
        """
        def make_channel_payload(coder_id):
            # type: (str) -> bytes
            # Serialized RemoteGrpcPort for the given coder; the service URL
            # is only filled in when a data service descriptor was supplied.
            data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
            if data_api_service_descriptor:
                data_spec.api_service_descriptor.url = (
                    data_api_service_descriptor.url)
            return data_spec.SerializeToString()

        pipeline_context = self._execution_context_ref().pipeline_context
        global_windowing_strategy_id = self.uid('global_windowing_strategy')
        global_windowing_strategy_proto = core.Windowing(
            window.GlobalWindows()).to_runner_api(pipeline_context)
        # Start from the coders already known to the pipeline context; the
        # coders created below are added to this dict as well.
        coders = dict(pipeline_context.coders.get_id_to_proto_map())

        def make_coder(urn, *components):
            # type: (str, str) -> str
            # Creates a coder proto with a fresh id, records it both in the
            # local ``coders`` dict and on the pipeline context, and returns
            # the id.
            coder_proto = beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
                component_coder_ids=components)
            coder_id = self.uid('coder')
            coders[coder_id] = coder_proto
            pipeline_context.coders.put_proto(coder_id, coder_proto)
            return coder_id

        bytes_coder_id = make_coder(common_urns.coders.BYTES.urn)
        window_coder_id = self._windowing_strategy_proto.window_coder_id
        global_window_coder_id = make_coder(
            common_urns.coders.GLOBAL_WINDOW.urn)
        iter_window_coder_id = make_coder(common_urns.coders.ITERABLE.urn,
                                          window_coder_id)
        # Input elements: KV<bytes, Iterable<window>>.
        input_coder_id = make_coder(common_urns.coders.KV.urn, bytes_coder_id,
                                    iter_window_coder_id)
        # Output elements:
        # KV<bytes, KV<Iterable<window>,
        #              Iterable<KV<window, Iterable<window>>>>>.
        output_coder_id = make_coder(
            common_urns.coders.KV.urn, bytes_coder_id,
            make_coder(
                common_urns.coders.KV.urn, iter_window_coder_id,
                make_coder(
                    common_urns.coders.ITERABLE.urn,
                    make_coder(common_urns.coders.KV.urn, window_coder_id,
                               iter_window_coder_id))))
        # Both directions are wrapped as windowed values using the global
        # window coder.
        windowed_input_coder_id = make_coder(
            common_urns.coders.WINDOWED_VALUE.urn, input_coder_id,
            global_window_coder_id)
        windowed_output_coder_id = make_coder(
            common_urns.coders.WINDOWED_VALUE.urn, output_coder_id,
            global_window_coder_id)

        # Keep the coder implementations on the instance.
        self.windowed_input_coder_impl = pipeline_context.coders[
            windowed_input_coder_id].get_impl()
        self.windowed_output_coder_impl = pipeline_context.coders[
            windowed_output_coder_id].get_impl()

        self._bundle_processor_id = self.uid('merge_windows')
        return beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._bundle_processor_id,
            transforms={
                self.TO_SDK_TRANSFORM:
                beam_runner_api_pb2.PTransform(
                    unique_name='MergeWindows/Read',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_INPUT_URN,
                        payload=make_channel_payload(windowed_input_coder_id)),
                    outputs={'input': 'input'}),
                'Merge':
                beam_runner_api_pb2.PTransform(
                    unique_name='MergeWindows/Merge',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.primitives.MERGE_WINDOWS.urn,
                        payload=self._windowing_strategy_proto.window_fn.
                        SerializeToString()),
                    inputs={'input': 'input'},
                    outputs={'output': 'output'}),
                self.FROM_SDK_TRANSFORM:
                beam_runner_api_pb2.PTransform(
                    unique_name='MergeWindows/Write',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_OUTPUT_URN,
                        payload=make_channel_payload(
                            windowed_output_coder_id)),
                    inputs={'output': 'output'}),
            },
            pcollections={
                'input':
                beam_runner_api_pb2.PCollection(
                    unique_name='input',
                    windowing_strategy_id=global_windowing_strategy_id,
                    coder_id=input_coder_id),
                'output':
                beam_runner_api_pb2.PCollection(
                    unique_name='output',
                    windowing_strategy_id=global_windowing_strategy_id,
                    coder_id=output_coder_id),
            },
            coders=coders,
            windowing_strategies={
                global_windowing_strategy_id: global_windowing_strategy_proto,
            },
            environments=dict(self._execution_context_ref().
                              pipeline_components.environments.items()),
            state_api_service_descriptor=state_api_service_descriptor,
            # NOTE: the timer service descriptor is the data service
            # descriptor, not a separate argument.
            timer_api_service_descriptor=data_api_service_descriptor)
Beispiel #10
0
    def lift_combiners(stages):
      """Expands CombinePerKey into pre- and post-grouping stages.

      ... -> CombinePerKey -> ...

      becomes

      ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...

      The intermediate PCollections and their coders are added to the
      enclosing scope's ``pipeline_components`` as a side effect.  Stages
      whose single transform is not a CombinePerKey are yielded unchanged.

      Args:
        stages: iterable of stages, each holding exactly one transform.

      Yields:
        The original stage, or the four stages that replace it.
      """
      def add_or_get_coder_id(coder_proto):
        # Reuse the id of an equal coder if one is already registered
        # (linear scan over all coders); otherwise register a copy under a
        # new unique id.
        for coder_id, coder in pipeline_components.coders.items():
          if coder == coder_proto:
            return coder_id
        new_coder_id = unique_name(pipeline_components.coders, 'coder')
        pipeline_components.coders[new_coder_id].CopyFrom(coder_proto)
        return new_coder_id

      def windowed_coder_id(coder_id):
        # Wraps the given coder in a windowed-value coder.
        # NOTE: ``window_coder_id`` is not a parameter; it is the variable
        # assigned inside the loop below, so this helper must only be called
        # after that assignment has run for the current stage.
        proto = beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.SdkFunctionSpec(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=urns.WINDOWED_VALUE_CODER)),
            component_coder_ids=[coder_id, window_coder_id])
        return add_or_get_coder_id(proto)

      for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == urns.COMBINE_PER_KEY_TRANSFORM:
          combine_payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.CombinePayload)

          input_pcoll = pipeline_components.pcollections[only_element(
              transform.inputs.values())]
          output_pcoll = pipeline_components.pcollections[only_element(
              transform.outputs.values())]

          # The input coder is unpacked as (element coder, window coder),
          # and the element coder as (key coder, value coder); both
          # unpackings require exactly two components.
          windowed_input_coder = pipeline_components.coders[
              input_pcoll.coder_id]
          element_coder_id, window_coder_id = (
              windowed_input_coder.component_coder_ids)
          element_coder = pipeline_components.coders[element_coder_id]
          key_coder_id, _ = element_coder.component_coder_ids
          accumulator_coder_id = combine_payload.accumulator_coder_id

          # KV<key, accumulator>: output of Precombine and of Merge.
          key_accumulator_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_coder_id])
          key_accumulator_coder_id = add_or_get_coder_id(key_accumulator_coder)

          # Iterable<accumulator>.
          accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.ITERABLE_CODER)),
              component_coder_ids=[accumulator_coder_id])
          accumulator_iter_coder_id = add_or_get_coder_id(
              accumulator_iter_coder)

          # KV<key, Iterable<accumulator>>: output of the group-by-key.
          key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
          key_accumulator_iter_coder_id = add_or_get_coder_id(
              key_accumulator_iter_coder)

          # Precombine output takes windowing/boundedness from the input
          # PCollection; the grouped and merged outputs take them from the
          # original output PCollection.
          precombined_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[precombined_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Precombine.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
                  is_bounded=input_pcoll.is_bounded))

          grouped_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[grouped_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Group.out',
                  coder_id=windowed_coder_id(key_accumulator_iter_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          merged_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[merged_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Merge.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          def make_stage(base_stage, transform):
            # Single-transform stage named after the transform, inheriting
            # the side-input and ordering constraints of ``base_stage``.
            return Stage(
                transform.unique_name,
                [transform],
                downstream_side_inputs=base_stage.downstream_side_inputs,
                must_follow=base_stage.must_follow)

          # Precombine reads the original inputs.
          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Precombine',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.PRECOMBINE_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs=transform.inputs,
                  outputs={'out': precombined_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Group',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.GROUP_BY_KEY_TRANSFORM),
                  inputs={'in': precombined_pcoll_id},
                  outputs={'out': grouped_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Merge',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.MERGE_ACCUMULATORS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': grouped_pcoll_id},
                  outputs={'out': merged_pcoll_id}))

          # ExtractOutputs writes to the original outputs.
          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/ExtractOutputs',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.EXTRACT_OUTPUTS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': merged_pcoll_id},
                  outputs=transform.outputs))

        else:
          yield stage