Example No. 1
  def __getitem__(self, window):
    target_window = self._side_input_data.window_mapping_fn(window)
    if target_window not in self._cache:
      state_handler = self._state_handler
      access_pattern = self._side_input_data.access_pattern

      if access_pattern == common_urns.side_inputs.ITERABLE.urn:
        state_key = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id=self._transform_id,
                side_input_id=self._tag,
                window=self._target_window_coder.encode(target_window)))
        raw_view = _StateBackedIterable(
            state_handler, state_key, self._element_coder)

      elif access_pattern == common_urns.side_inputs.MULTIMAP.urn:
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id=self._transform_id,
                side_input_id=self._tag,
                window=self._target_window_coder.encode(target_window),
                key=b''))
        cache = {}
        key_coder_impl = self._element_coder.key_coder().get_impl()
        value_coder = self._element_coder.value_coder()

        class MultiMap(object):
          def __getitem__(self, key):
            if key not in cache:
              keyed_state_key = beam_fn_api_pb2.StateKey()
              keyed_state_key.CopyFrom(state_key)
              keyed_state_key.multimap_side_input.key = (
                  key_coder_impl.encode_nested(key))
              cache[key] = _StateBackedIterable(
                  state_handler, keyed_state_key, value_coder)
            return cache[key]

          def __reduce__(self):
            # TODO(robertwb): Figure out how to support this.
            raise TypeError(common_urns.side_inputs.MULTIMAP.urn)

        raw_view = MultiMap()

      else:
        raise ValueError(
            "Unknown access pattern: '%s'" % access_pattern)

      self._cache[target_window] = self._side_input_data.view_fn(raw_view)
    return self._cache[target_window]
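
A minimal standalone sketch (not part of the original snippet, with hypothetical identifiers) of the iterable-pattern StateKey constructed above; in the real code the transform id, tag and window bytes come from self._transform_id, self._tag and self._target_window_coder.

from apache_beam.portability.api import beam_fn_api_pb2

# Hypothetical values standing in for the instance attributes used above.
iterable_key = beam_fn_api_pb2.StateKey(
    iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
        transform_id='pardo_1',
        side_input_id='side0',
        window=b'\x00'))
print(iterable_key.SerializeToString())  # serialized key as it would appear in a StateRequest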
Example No. 2
 def _create_internal_map_state(self, name, encoded_namespace,
                                map_key_coder, map_value_coder, ttl_config,
                                cache_type):
     # Currently the `beam_fn_api.proto` does not support MapState, so we use
     # the `MultimapSideInput` message to mark the state as a MapState for now.
     from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor
     state_proto = StateDescriptor()
     state_proto.state_name = name
     if ttl_config is not None:
         state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto())
     state_key = beam_fn_api_pb2.StateKey(
         multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
             transform_id="",
             window=encoded_namespace,
             side_input_id=base64.b64encode(
                 state_proto.SerializeToString()),
             key=self._encoded_current_key))
     if cache_type == SynchronousKvRuntimeState.CacheType.DISABLE_CACHE:
         write_cache_size = 0
     else:
         write_cache_size = self._map_state_write_cache_size
     return InternalSynchronousMapRuntimeState(self._map_state_handler,
                                               state_key, map_key_coder,
                                               map_value_coder,
                                               write_cache_size)
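
A small sketch (not part of the original snippet) showing how the StateDescriptor smuggled through side_input_id above could be recovered, assuming state_key and name from the function are in scope.

import base64
from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor

# Inverse of the encoding above: the descriptor travels base64-encoded in the
# side_input_id field of the MultimapSideInput key.
recovered = StateDescriptor.FromString(
    base64.b64decode(state_key.multimap_side_input.side_input_id))
assert recovered.state_name == name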
Example No. 3
    def __getitem__(self, window):
        target_window = self._side_input_data.window_mapping_fn(window)
        if target_window not in self._cache:
            state_key = beam_fn_api_pb2.StateKey(
                multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                    ptransform_id=self._transform_id,
                    side_input_id=self._tag,
                    window=self._target_window_coder.encode(target_window),
                    key=''))
            state_handler = self._state_handler
            access_pattern = self._side_input_data.access_pattern

            class AllElements(object):
                def __init__(self, state_key, coder):
                    self._state_key = state_key
                    self._coder_impl = coder.get_impl()

                def __iter__(self):
                    # TODO(robertwb): Support pagination.
                    input_stream = coder_impl.create_InputStream(
                        state_handler.blocking_get(self._state_key, None))
                    while input_stream.size() > 0:
                        yield self._coder_impl.decode_from_stream(
                            input_stream, True)

                def __reduce__(self):
                    return list, (list(self), )

            if access_pattern == common_urns.ITERABLE_SIDE_INPUT:
                raw_view = AllElements(state_key, self._element_coder)

            elif access_pattern == common_urns.MULTIMAP_SIDE_INPUT:
                cache = {}
                key_coder_impl = self._element_coder.key_coder().get_impl()
                value_coder = self._element_coder.value_coder()

                class MultiMap(object):
                    def __getitem__(self, key):
                        if key not in cache:
                            keyed_state_key = beam_fn_api_pb2.StateKey()
                            keyed_state_key.CopyFrom(state_key)
                            keyed_state_key.multimap_side_input.key = (
                                key_coder_impl.encode_nested(key))
                            cache[key] = AllElements(keyed_state_key,
                                                     value_coder)
                        return cache[key]

                    def __reduce__(self):
                        # TODO(robertwb): Figure out how to support this.
                        raise TypeError(common_urns.MULTIMAP_SIDE_INPUT)

                raw_view = MultiMap()

            else:
                raise ValueError("Unknown access pattern: '%s'" %
                                 access_pattern)

            self._cache[target_window] = self._side_input_data.view_fn(
                raw_view)
        return self._cache[target_window]
Example No. 4
    def __getitem__(self, window):
        target_window = self._side_input_data.window_mapping_fn(window)
        if target_window not in self._cache:
            state_key = beam_fn_api_pb2.StateKey(
                multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                    ptransform_id=self._transform_id,
                    side_input_id=self._tag,
                    window=self._target_window_coder.encode(target_window)))
            element_coder_impl = self._element_coder.get_impl()
            state_handler = self._state_handler

            class AllElements(object):
                def __iter__(self):
                    # TODO(robertwb): Support pagination.
                    input_stream = coder_impl.create_InputStream(
                        state_handler.blocking_get(state_key, None))
                    while input_stream.size() > 0:
                        yield element_coder_impl.decode_from_stream(
                            input_stream, True)

                def __reduce__(self):
                    return list, (list(self), )

            self._cache[target_window] = self._side_input_data.view_fn(
                AllElements())
        return self._cache[target_window]
Example No. 5
    def test_extend_fetches_initial_state(self):
        coder = VarIntCoder()
        coder_impl = coder.get_impl()

        class UnderlyingStateHandler(object):
            """Simply returns an incremented counter as the state "value."
      """
            def set_value(self, value):
                self._encoded_values = coder.encode(value)

            def get_raw(self, *args):
                return self._encoded_values, None

            def append_raw(self, _key, data):
                self._encoded_values += data

            def clear(self, *args):
                self._encoded_values = bytes()

            @contextlib.contextmanager
            def process_instruction_id(self, bundle_id):
                yield

        underlying_state_handler = UnderlyingStateHandler()
        state_cache = statecache.StateCache(100)
        handler = sdk_worker.CachingStateHandler(state_cache,
                                                 underlying_state_handler)

        state = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        def get():
            return list(handler.blocking_get(state, coder_impl, True))

        def append(value):
            handler.extend(state, coder_impl, [value], True)

        def clear():
            handler.clear(state, True)

        # Initialize state
        underlying_state_handler.set_value(42)
        with handler.process_instruction_id('bundle', [cache_token]):
            # Append without reading beforehand
            append(43)
            self.assertEqual(get(), [42, 43])
            clear()
            self.assertEqual(get(), [])
            append(44)
            self.assertEqual(get(), [44])
Example No. 6
 def iterable_state_write(values, element_coder_impl):
   token = unique_name(None, 'iter').encode('ascii')
   out = create_OutputStream()
   for element in values:
     element_coder_impl.encode_to_stream(element, out, True)
   controller.state.blocking_append(
       beam_fn_api_pb2.StateKey(
           runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
       out.get())
   return token
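
A hypothetical call of the helper above, assuming a controller with a state servicer is in scope as in the surrounding runner code; VarIntCoder is a real Beam coder used purely for illustration.

from apache_beam.coders import VarIntCoder

# Writes three encoded integers under a runner-owned key; the returned token can
# later be resolved via StateKey(runner=StateKey.Runner(key=token)).
token = iterable_state_write([1, 2, 3], VarIntCoder().get_impl())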
Example No. 7
 def _create_state(self, state_spec: userstate.StateSpec) -> userstate.AccumulatingRuntimeState:
     if isinstance(state_spec, userstate.BagStateSpec):
         bag_state = SynchronousBagRuntimeState(
             self._state_handler,
             state_key=beam_fn_api_pb2.StateKey(
                 bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                     transform_id="",
                     user_state_id=state_spec.name,
                     key=self._encoded_current_key)),
             value_coder=state_spec.coder)
         return bag_state
     else:
         raise NotImplementedError(state_spec)
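
A hedged usage sketch; it assumes userstate refers to Beam's apache_beam.transforms.userstate module (as the type hints suggest) and that the call happens inside the object exposing _create_state above. The state name 'buffer' is hypothetical.

from apache_beam.coders import VarIntCoder
from apache_beam.transforms import userstate

# A BagStateSpec like the one _create_state above expects.
buffer_spec = userstate.BagStateSpec('buffer', VarIntCoder())
# bag_state = self._create_state(buffer_spec)  # yields a SynchronousBagRuntimeState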
Example No. 8
  def _store_side_inputs_in_state(self,
                                  runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                                  data_side_input,  # type: DataSideInput
                                 ):
    # type: (...) -> None
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = runner_execution_context.pipeline_context.coders[
          runner_execution_context.safe_coders[
              runner_execution_context.data_channel_coders[pcoll_id]]]
      elements_by_window = WindowGroupingBuffer(si, value_coder)
      if buffer_id not in runner_execution_context.pcoll_buffers:
        runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer(
            coder_impl=value_coder.get_impl())
      for element_data in runner_execution_context.pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)

      if si.urn == common_urns.side_inputs.ITERABLE.urn:
        for _, window, elements_data in elements_by_window.encoded_items():
          state_key = beam_fn_api_pb2.StateKey(
              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                  transform_id=transform_id, side_input_id=tag, window=window))
          (
              runner_execution_context.worker_handler_manager.state_servicer.
              append_raw(state_key, elements_data))
      elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
        for key, window, elements_data in elements_by_window.encoded_items():
          state_key = beam_fn_api_pb2.StateKey(
              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                  transform_id=transform_id,
                  side_input_id=tag,
                  window=window,
                  key=key))
          (
              runner_execution_context.worker_handler_manager.state_servicer.
              append_raw(state_key, elements_data))
      else:
        raise ValueError("Unknown access pattern: '%s'" % si.urn)
Example No. 9
    def test_caching(self):

        coder = VarIntCoder()
        coder_impl = coder.get_impl()

        class FakeUnderlyingState(object):
            """Simply returns an incremented counter as the state "value."
      """
            def set_counter(self, n):
                self._counter = n

            def get_raw(self, *args):
                self._counter += 1
                return coder.encode(self._counter), None

            @contextlib.contextmanager
            def process_instruction_id(self, bundle_id):
                yield

        underlying_state = FakeUnderlyingState()
        state_cache = statecache.StateCache(100)
        caching_state_handler = sdk_worker.CachingStateHandler(
            state_cache, underlying_state)

        state1 = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))
        state2 = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state2'))
        side1 = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id='transform', side_input_id='side1'))
        side2 = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id='transform', side_input_id='side2'))

        state_token1 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())
        state_token2 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token2',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())
        side1_token1 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'side1_token1',
            side_input=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            SideInput(transform_id='transform', side_input_id='side1'))
        side1_token2 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'side1_token2',
            side_input=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            SideInput(transform_id='transform', side_input_id='side1'))

        def get_as_list(key):
            return list(caching_state_handler.blocking_get(key, coder_impl))

        underlying_state.set_counter(100)
        with caching_state_handler.process_instruction_id('bundle1', []):
            self.assertEqual(get_as_list(state1), [101])  # uncached
            self.assertEqual(get_as_list(state2), [102])  # uncached
            self.assertEqual(get_as_list(state1), [101])  # cached on bundle
            self.assertEqual(get_as_list(side1), [103])  # uncached
            self.assertEqual(get_as_list(side2), [104])  # uncached

        underlying_state.set_counter(200)
        with caching_state_handler.process_instruction_id(
                'bundle2', [state_token1, side1_token1]):
            self.assertEqual(get_as_list(state1), [201])  # uncached
            self.assertEqual(get_as_list(state2), [202])  # uncached
            self.assertEqual(get_as_list(state1),
                             [201])  # cached on state token1
            self.assertEqual(get_as_list(side1), [203])  # uncached
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side2), [204])  # uncached
            self.assertEqual(get_as_list(side2), [204])  # cached on bundle

        underlying_state.set_counter(300)
        with caching_state_handler.process_instruction_id(
                'bundle3', [state_token1, side1_token1]):
            self.assertEqual(get_as_list(state1),
                             [201])  # cached on state token1
            self.assertEqual(get_as_list(state2),
                             [202])  # cached on state token1
            self.assertEqual(get_as_list(state1),
                             [201])  # cached on state token1
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side2), [301])  # uncached
            self.assertEqual(get_as_list(side2), [301])  # cached on bundle

        underlying_state.set_counter(400)
        with caching_state_handler.process_instruction_id(
                'bundle4', [state_token2, side1_token1]):
            self.assertEqual(get_as_list(state1), [401])  # uncached
            self.assertEqual(get_as_list(state2), [402])  # uncached
            self.assertEqual(get_as_list(state1),
                             [401])  # cached on state token2
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side2), [403])  # uncached
            self.assertEqual(get_as_list(side2), [403])  # cached on bundle

        underlying_state.set_counter(500)
        with caching_state_handler.process_instruction_id(
                'bundle5', [state_token2, side1_token2]):
            self.assertEqual(get_as_list(state1),
                             [401])  # cached on state token2
            self.assertEqual(get_as_list(state2),
                             [402])  # cached on state token2
            self.assertEqual(get_as_list(state1),
                             [401])  # cached on state token2
            self.assertEqual(get_as_list(side1), [501])  # uncached
            self.assertEqual(get_as_list(side1),
                             [501])  # cached on side1_token2
            self.assertEqual(get_as_list(side2), [502])  # uncached
            self.assertEqual(get_as_list(side2), [502])  # cached on bundle
Example No. 10
  def run_stage(
      self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

    context = pipeline_context.PipelineContext(pipeline_components)
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data_operation_spec.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          pcoll_id = transform.spec.payload
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            data_input[target] = pcoll_buffers[pcoll_id]
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
          else:
            raise NotImplementedError
          if data_operation_spec:
            transform.spec.payload = data_operation_spec.SerializeToString()
          else:
            transform.spec.payload = ""
        elif transform.spec.urn == urns.PARDO_TRANSFORM:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag, si in payload.side_inputs.items():
            data_side_input[transform.unique_name, tag] = (
                'materialize:' + transform.inputs[tag],
                beam.pvalue.SideInputData.from_runner_api(si, None))
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    # Store the required side inputs into state.
    for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
      elements_by_window = _WindowGroupingBuffer(si)
      for element_data in pcoll_buffers[pcoll_id]:
        elements_by_window.append(element_data)
      for window, elements_data in elements_by_window.items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window))
        controller.state_handler.blocking_append(state_key, elements_data, None)

    def get_buffer(pcoll_id):
      if pcoll_id.startswith('materialize:'):
        if pcoll_id not in pcoll_buffers:
          # Just store the data chunks for replay.
          pcoll_buffers[pcoll_id] = list()
      elif pcoll_id.startswith('group:'):
        # This is a grouping write, create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
      return pcoll_buffers[pcoll_id]

    return BundleManager(
        controller, get_buffer, process_bundle_descriptor,
        self._progress_frequency).process_bundle(data_input, data_output)
Example No. 11
  def run_stage(
      self,
      worker_handler_factory,
      pipeline_components,
      stage,
      pcoll_buffers,
      safe_coders):

    def iterable_state_write(values, element_coder_impl):
      token = unique_name(None, 'iter').encode('ascii')
      out = create_OutputStream()
      for element in values:
        element_coder_impl.encode_to_stream(element, out, True)
      controller.state.blocking_append(
          beam_fn_api_pb2.StateKey(
              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
          out.get())
      return token

    controller = worker_handler_factory(stage.environment)
    context = pipeline_context.PipelineContext(
        pipeline_components, iterable_state_write=iterable_state_write)
    data_api_service_descriptor = controller.data_api_service_descriptor()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data ApiServiceDescriptor.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          pcoll_id = transform.spec.payload
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
              data_input[target] = [ENCODED_IMPULSE_VALUE]
            else:
              data_input[target] = pcoll_buffers[pcoll_id]
            coder_id = pipeline_components.pcollections[
                only_element(transform.outputs.values())].coder_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
            coder_id = pipeline_components.pcollections[
                only_element(transform.inputs.values())].coder_id
          else:
            raise NotImplementedError
          data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
          if data_api_service_descriptor:
            data_spec.api_service_descriptor.url = (
                data_api_service_descriptor.url)
          transform.spec.payload = data_spec.SerializeToString()
        elif transform.spec.urn == common_urns.primitives.PAR_DO.urn:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag, si in payload.side_inputs.items():
            data_side_input[transform.unique_name, tag] = (
                create_buffer_id(transform.inputs[tag]), si.access_pattern)
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    if controller.state_api_service_descriptor():
      process_bundle_descriptor.state_api_service_descriptor.url = (
          controller.state_api_service_descriptor().url)

    # Store the required side inputs into state.
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = context.coders[safe_coders[
          pipeline_components.pcollections[pcoll_id].coder_id]]
      elements_by_window = _WindowGroupingBuffer(si, value_coder)
      for element_data in pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        controller.state.blocking_append(state_key, elements_data)

    def get_buffer(buffer_id):
      kind, name = split_buffer_id(buffer_id)
      if kind in ('materialize', 'timers'):
        if buffer_id not in pcoll_buffers:
          # Just store the data chunks for replay.
          pcoll_buffers[buffer_id] = list()
      elif kind == 'group':
        # This is a grouping write, create a grouping buffer if needed.
        if buffer_id not in pcoll_buffers:
          original_gbk_transform = name
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(list(transform_proto.inputs.values()))
          output_pcoll = only_element(list(transform_proto.outputs.values()))
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[buffer_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
      # These should be the only identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(buffer_id)
      return pcoll_buffers[buffer_id]

    for k in range(self._bundle_repeat):
      try:
        controller.state.checkpoint()
        BundleManager(
            controller, lambda pcoll_id: [], process_bundle_descriptor,
            self._progress_frequency, k).process_bundle(data_input, data_output)
      finally:
        controller.state.restore()

    result = BundleManager(
        controller, get_buffer, process_bundle_descriptor,
        self._progress_frequency).process_bundle(data_input, data_output)

    while True:
      timer_inputs = {}
      for transform_id, timer_writes in stage.timer_pcollections:
        windowed_timer_coder_impl = context.coders[
            pipeline_components.pcollections[timer_writes].coder_id].get_impl()
        written_timers = get_buffer(
            create_buffer_id(timer_writes, kind='timers'))
        if written_timers:
          # Keep only the "last" timer set per key and window.
          timers_by_key_and_window = {}
          for elements_data in written_timers:
            input_stream = create_InputStream(elements_data)
            while input_stream.size() > 0:
              windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                  input_stream, True)
              key, _ = windowed_key_timer.value
              # TODO: Explode and merge windows.
              assert len(windowed_key_timer.windows) == 1
              timers_by_key_and_window[
                  key, windowed_key_timer.windows[0]] = windowed_key_timer
          out = create_OutputStream()
          for windowed_key_timer in timers_by_key_and_window.values():
            windowed_timer_coder_impl.encode_to_stream(
                windowed_key_timer, out, True)
          timer_inputs[transform_id, 'out'] = [out.get()]
          written_timers[:] = []
      if timer_inputs:
        # The worker will be waiting on these inputs as well.
        for other_input in data_input:
          if other_input not in timer_inputs:
            timer_inputs[other_input] = []
        # TODO(robertwb): merge results
        BundleManager(
            controller,
            get_buffer,
            process_bundle_descriptor,
            self._progress_frequency,
            True).process_bundle(timer_inputs, data_output)
      else:
        break

    return result