Example #1
 def check_cython_coder(self, python_field_coders, cython_field_coders,
                        data):
     from apache_beam.coders.coder_impl import create_InputStream, create_OutputStream
     from pyflink.fn_execution.fast_coder_impl import InputStreamAndFunctionWrapper
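     # Encode the test row with the pure-Python coder, decode and re-encode it
     # with the Cython coder, then decode the result again for comparison.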
     py_flatten_row_coder = coder_impl.FlattenRowCoderImpl(
         python_field_coders)
     internal = py_flatten_row_coder.encode(data)
     input_stream = create_InputStream(internal)
     output_stream = create_OutputStream()
     cy_flatten_row_coder = fast_coder_impl.FlattenRowCoderImpl(
         cython_field_coders)
     value = cy_flatten_row_coder.decode_from_stream(input_stream, False)
     wrapper_func_input_element = InputStreamAndFunctionWrapper(
         lambda v: [v[i] for i in range(len(v))], value)
     cy_flatten_row_coder.encode_to_stream(wrapper_func_input_element,
                                           output_stream, False)
     generator_result = py_flatten_row_coder.decode_from_stream(
         create_InputStream(output_stream.get()), False)
     result = []
     for item in generator_result:
         result.append(item)
     try:
         self.assertEqual(result, data)
     except AssertionError:
         self.assertEqual(len(result), len(data))
         self.assertEqual(len(result[0]), len(data[0]))
         for i in range(len(data[0])):
             if isinstance(data[0][i], float):
                 from pyflink.table.tests.test_udf import float_equal
                 assert float_equal(data[0][i], result[0][i], 1e-6)
             else:
                 self.assertEqual(data[0][i], result[0][i])
Example #2
 def check_cython_coder(self, python_field_coders, cython_field_coders,
                        data):
     from apache_beam.coders.coder_impl import create_InputStream, create_OutputStream
     from pyflink.fn_execution.beam.beam_stream import BeamInputStream, BeamOutputStream
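     # Same round trip as above, but the Cython coder reads from and writes to
     # Beam streams wrapped in BeamInputStream/BeamOutputStream.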
     py_flatten_row_coder = coder_impl.FlattenRowCoderImpl(
         python_field_coders)
     internal = py_flatten_row_coder.encode(data)
     beam_input_stream = create_InputStream(internal)
     input_stream = BeamInputStream(beam_input_stream,
                                    beam_input_stream.size())
     beam_output_stream = create_OutputStream()
     cy_flatten_row_coder = coder_impl_fast.FlattenRowCoderImpl(
         cython_field_coders)
     value = cy_flatten_row_coder.decode_from_stream(input_stream)
     output_stream = BeamOutputStream(beam_output_stream)
     cy_flatten_row_coder.encode_to_stream(value, output_stream)
     output_stream.flush()
     generator_result = py_flatten_row_coder.decode_from_stream(
         create_InputStream(beam_output_stream.get()), False)
     result = []
     for item in generator_result:
         result.append(item)
     try:
         self.assertEqual(result, [data])
     except AssertionError:
         data = [data]
         self.assertEqual(len(result), len(data))
         self.assertEqual(len(result[0]), len(data[0]))
         for i in range(len(data[0])):
             if isinstance(data[0][i], float):
                 from pyflink.table.tests.test_udf import float_equal
                 assert float_equal(data[0][i], result[0][i], 1e-6)
             else:
                 self.assertEqual(data[0][i], result[0][i])
Example #3
  def _partially_cached_iterable(
      self,
      state_key,  # type: beam_fn_api_pb2.StateKey
      coder  # type: coder_impl.CoderImpl
    ):
    # type: (...) -> Iterable[Any]

    """Materialized the first page of data, concatenated with a lazy iterable
    of the rest, if any.
    """
    data, continuation_token = self._underlying.get_raw(state_key, None)
    head = []
    input_stream = coder_impl.create_InputStream(data)
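    # Eagerly decode every element contained in the first page of state data.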
    while input_stream.size() > 0:
      head.append(coder.decode_from_stream(input_stream, True))

    if continuation_token is None:
      return head
    else:

      def iter_func():
        for item in head:
          yield item
        if continuation_token:
          for item in self._lazy_iterator(state_key, coder, continuation_token):
            yield item

      return _IterableFromIterator(iter_func)
Example #4
 def append(self, elements_data):
   input_stream = create_InputStream(elements_data)
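   # Decode windowed values until the stream is exhausted and bucket each
   # value under every window it belongs to.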
   while input_stream.size() > 0:
     windowed_value = self._windowed_value_coder.get_impl(
         ).decode_from_stream(input_stream, True)
     for window in windowed_value.windows:
       self._values_by_window[window].append(windowed_value.value)
Example #5
def _decode_gauge(coder, payload):
    """Returns a tuple of (timestamp, value)."""
    timestamp_coder = coders.VarIntCoder().get_impl()
    stream = coder_impl.create_InputStream(payload)
    time_ms = timestamp_coder.decode_from_stream(stream, True)
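    # The encoded timestamp is in milliseconds; convert it to seconds.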
    return (time_ms / 1000.0,
            coder.get_impl().decode_from_stream(stream, True))
Example #6
 def append(self, elements_data):
     input_stream = create_InputStream(elements_data)
     while input_stream.size() > 0:
         windowed_value = self._windowed_value_coder.get_impl(
         ).decode_from_stream(input_stream, True)
         for window in windowed_value.windows:
             self._values_by_window[window].append(windowed_value.value)
Example #7
 def __iter__(self):
     # TODO(robertwb): Support pagination.
     input_stream = coder_impl.create_InputStream(
         state_handler.blocking_get(self._state_key))
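     # Decode elements one by one until the fetched state blob is exhausted.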
     while input_stream.size() > 0:
         yield self._coder_impl.decode_from_stream(
             input_stream, True)
Example #8
 def iterable_state_read(state_token, elem_coder):
     state = spec.get('state').get(state_token.decode('latin1'))
     if state is None:
         state = ''
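     # The test spec stores state as latin1 strings; convert back to bytes
     # before decoding the elements.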
     input_stream = coder_impl.create_InputStream(
         state.encode('latin1'))
     while input_stream.size() > 0:
         yield elem_coder.decode_from_stream(input_stream, True)
Example #9
 def append(self, elements_data):
   input_stream = create_InputStream(elements_data)
   while input_stream.size() > 0:
     windowed_key_value = self._pre_grouped_coder.get_impl(
         ).decode_from_stream(input_stream, True)
     key = windowed_key_value.value[0]
     windowed_value = windowed_key_value.with_value(
         windowed_key_value.value[1])
     self._table[self._key_coder.encode(key)].append(windowed_value)
Example #10
def _decode_distribution(value_coder, payload):
    """Returns a tuple of (count, sum, min, max)."""
    count_coder = coders.VarIntCoder().get_impl()
    value_coder = value_coder.get_impl()
    stream = coder_impl.create_InputStream(payload)
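    # The payload encodes the count followed by the sum, min and max values.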
    return (count_coder.decode_from_stream(stream, True),
            value_coder.decode_from_stream(stream, True),
            value_coder.decode_from_stream(stream, True),
            value_coder.decode_from_stream(stream, True))
Example #11
 def append(self, elements_data):
   input_stream = create_InputStream(elements_data)
   while input_stream.size() > 0:
     windowed_key_value = self._pre_grouped_coder.get_impl(
         ).decode_from_stream(input_stream, True)
     key = windowed_key_value.value[0]
     windowed_value = windowed_key_value.with_value(
         windowed_key_value.value[1])
     self._table[self._key_coder.encode(key)].append(windowed_value)
Example #12
 def process_encoded(self, encoded_windowed_values):
   input_stream = coder_impl.create_InputStream(encoded_windowed_values)
   while input_stream.size() > 0:
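     # Check under the splitting lock whether a split request has capped the
     # number of elements this operation may emit.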
     with self.splitting_lock:
       if self.index == self.stop - 1:
         return
       self.index += 1
     decoded_value = self.windowed_coder_impl.decode_from_stream(
         input_stream, True)
     self.output(decoded_value)
Example #13
 def process_encoded(self, encoded_windowed_values):
   input_stream = coder_impl.create_InputStream(encoded_windowed_values)
   while input_stream.size() > 0:
     with self.splitting_lock:
       if self.index == self.stop - 1:
         return
       self.index += 1
     decoded_value = self.windowed_coder_impl.decode_from_stream(
         input_stream, True)
     self.output(decoded_value)
Example #14
 def append(self, elements_data):
     # type: (bytes) -> None
     input_stream = create_InputStream(elements_data)
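     # Decode each windowed value, split it into key and value, and group the
     # value by (key, window).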
     while input_stream.size() > 0:
         windowed_val_coder_impl = self._windowed_value_coder.get_impl(
         )  # type: WindowedValueCoderImpl
         windowed_value = windowed_val_coder_impl.decode_from_stream(
             input_stream, True)
         key, value = self._kv_extractor(windowed_value.value)
         for window in windowed_value.windows:
             self._values_by_window[key, window].append(value)
Example #15
 def __iter__(self):
   data, continuation_token = self._state_handler.blocking_get(self._state_key)
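    # Keep requesting pages from the state handler until no continuation token
    # is returned, decoding every element of each page.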
   while True:
     input_stream = coder_impl.create_InputStream(data)
     while input_stream.size() > 0:
       yield self._coder_impl.decode_from_stream(input_stream, True)
     if not continuation_token:
       break
     else:
       data, continuation_token = self._state_handler.blocking_get(
           self._state_key, continuation_token)
Example #16
 def __iter__(self):
   data, continuation_token = self._state_handler.blocking_get(self._state_key)
   while True:
     input_stream = coder_impl.create_InputStream(data)
     while input_stream.size() > 0:
       yield self._coder_impl.decode_from_stream(input_stream, True)
     if not continuation_token:
       break
     else:
       data, continuation_token = self._state_handler.blocking_get(
           self._state_key, continuation_token)
Example #17
 def append(self, elements_data):
   input_stream = create_InputStream(elements_data)
   coder_impl = self._pre_grouped_coder.get_impl()
   key_coder_impl = self._key_coder.get_impl()
   # TODO(robertwb): We could optimize this even more by using a
   # window-dropping coder for the data plane.
   is_trivial_windowing = self._windowing.is_default()
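    # With default (global) windowing the window can be dropped and bare values
    # stored; otherwise keep the full windowed value.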
   while input_stream.size() > 0:
     windowed_key_value = coder_impl.decode_from_stream(input_stream, True)
     key, value = windowed_key_value.value
     self._table[key_coder_impl.encode(key)].append(
         value if is_trivial_windowing
         else windowed_key_value.with_value(value))
Example #18
 def _materialize_iter(self, state_key, coder):
     """Materializes the state lazily, one element at a time.
     :return: A generator which returns the next element if advanced.
     """
     continuation_token = None
     while True:
         data, continuation_token = \
             self._underlying.get_raw(state_key, continuation_token)
         input_stream = coder_impl.create_InputStream(data)
         while input_stream.size() > 0:
             yield coder.decode_from_stream(input_stream, True)
         if not continuation_token:
             break
Example #19
 def append(self, elements_data):
   input_stream = create_InputStream(elements_data)
   coder_impl = self._pre_grouped_coder.get_impl()
   key_coder_impl = self._key_coder.get_impl()
   # TODO(robertwb): We could optimize this even more by using a
   # window-dropping coder for the data plane.
   is_trivial_windowing = self._windowing.is_default()
   while input_stream.size() > 0:
     windowed_key_value = coder_impl.decode_from_stream(input_stream, True)
     key, value = windowed_key_value.value
     self._table[key_coder_impl.encode(key)].append(
         value if is_trivial_windowing
         else windowed_key_value.with_value(value))
Example #20
    def _run_map_task(self, map_task, control_handler, state_handler,
                      data_plane_handler, data_operation_spec):
        registration, sinks, input_data = self._map_task_registration(
            map_task, state_handler, data_operation_spec)
        control_handler.push(registration)
        process_bundle = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_reference=registration.register.
                process_bundle_descriptor[0].id))

        for (transform_id, name), elements in input_data.items():
            data_out = data_plane_handler.output_stream(
                process_bundle.instruction_id,
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_id, name=name))
            data_out.write(elements)
            data_out.close()

        control_handler.push(process_bundle)
        while True:
            result = control_handler.pull()
            if result.instruction_id == process_bundle.instruction_id:
                if result.error:
                    raise RuntimeError(result.error)
                expected_targets = [
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_id,
                        name=output_name)
                    for (transform_id, output_name), _ in sinks.items()
                ]
                for output in data_plane_handler.input_elements(
                        process_bundle.instruction_id, expected_targets):
                    target_tuple = (
                        output.target.primitive_transform_reference,
                        output.target.name)
                    if target_tuple not in sinks:
                        # Unconsumed output.
                        continue
                    sink_op = sinks[target_tuple]
                    coder = sink_op.output_coders[0]
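                    # Decode the sink's output bytes back into elements before
                    # appending them to its in-memory output buffer.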
                    input_stream = create_InputStream(output.data)
                    elements = []
                    while input_stream.size() > 0:
                        elements.append(coder.get_impl().decode_from_stream(
                            input_stream, True))
                    if not sink_op.write_windowed_values:
                        elements = [e.value for e in elements]
                    for e in elements:
                        sink_op.output_buffer.append(e)
                return
Example #21
 def _iterate_raw(self, state_key, iterate_type, iterator_token,
                  map_key_coder, map_value_coder):
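     # Request payload: the ITERATE flag, the iterate type, and either the
     # Java-side iterator token or a zero length when starting a new iteration.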
     output_stream = coder_impl.create_OutputStream()
     output_stream.write_byte(self.ITERATE_FLAG)
     output_stream.write_byte(iterate_type.value)
     if not isinstance(iterator_token, IteratorToken):
         # The iterator token represents a Java iterator
         output_stream.write_bigendian_int32(len(iterator_token))
         output_stream.write(iterator_token)
     else:
         output_stream.write_bigendian_int32(0)
     continuation_token = output_stream.get()
     data, response_token = self._underlying.get_raw(
         state_key, continuation_token)
     if len(response_token) != 0:
          # The new iterator token is a UUID which represents a cached iterator
          # on the Java side.
         new_iterator_token = response_token
         if iterator_token == IteratorToken.NOT_START:
              # This is the first request but not the last request of the
              # current state. It means a new iterator has been created and
              # cached on the Java side.
             self._inc_cached_iterators_num()
     else:
         new_iterator_token = IteratorToken.FINISHED
         if iterator_token != IteratorToken.NOT_START:
              # This is not the first request but the last request of the
              # current state. It means the cached iterator created on the Java
              # side has been removed, as the current iteration has finished.
             self._dec_cached_iterators_num()
     input_stream = coder_impl.create_InputStream(data)
     if iterate_type == IterateType.ITEMS or iterate_type == IterateType.VALUES:
         # decode both key and value
         current_batch = {}
         while input_stream.size() > 0:
             key = map_key_coder.decode_from_stream(input_stream, True)
             is_not_none = input_stream.read_byte()
             if is_not_none:
                 value = map_value_coder.decode_from_stream(
                     input_stream, True)
             else:
                 value = None
             current_batch[key] = value
     else:
         # only decode key
         current_batch = []
         while input_stream.size() > 0:
             key = map_key_coder.decode_from_stream(input_stream, True)
             current_batch.append(key)
     return current_batch, new_iterator_token
Example #22
  def _run_map_task(
      self, map_task, control_handler, state_handler, data_plane_handler,
      data_operation_spec):
    registration, sinks, input_data = self._map_task_registration(
        map_task, state_handler, data_operation_spec)
    control_handler.push(registration)
    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=registration.register.
            process_bundle_descriptor[0].id))

    for (transform_id, name), elements in input_data.items():
      data_out = data_plane_handler.output_stream(
          process_bundle.instruction_id, beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=name))
      data_out.write(elements)
      data_out.close()

    control_handler.push(process_bundle)
    while True:
      result = control_handler.pull()
      if result.instruction_id == process_bundle.instruction_id:
        if result.error:
          raise RuntimeError(result.error)
        expected_targets = [
            beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                                   name=output_name)
            for (transform_id, output_name), _ in sinks.items()]
        for output in data_plane_handler.input_elements(
            process_bundle.instruction_id, expected_targets):
          target_tuple = (
              output.target.primitive_transform_reference, output.target.name)
          if target_tuple not in sinks:
            # Unconsumed output.
            continue
          sink_op = sinks[target_tuple]
          coder = sink_op.output_coders[0]
          input_stream = create_InputStream(output.data)
          elements = []
          while input_stream.size() > 0:
            elements.append(coder.get_impl().decode_from_stream(
                input_stream, True))
          if not sink_op.write_windowed_values:
            elements = [e.value for e in elements]
          for e in elements:
            sink_op.output_buffer.append(e)
        return
Example #23
 def _get_raw(self, state_key, map_key, map_key_coder, map_value_coder):
     output_stream = coder_impl.create_OutputStream()
     output_stream.write_byte(self.GET_FLAG)
     map_key_coder.encode_to_stream(map_key, output_stream, True)
     continuation_token = output_stream.get()
     data, response_token = self._underlying.get_raw(state_key, continuation_token)
     input_stream = coder_impl.create_InputStream(data)
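     # The first byte of the response indicates whether the key exists and
     # whether its value is None.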
     result_flag = input_stream.read_byte()
     if result_flag == self.EXIST_FLAG:
         return True, map_value_coder.decode_from_stream(input_stream, True)
     elif result_flag == self.IS_NONE_FLAG:
         return True, None
     elif result_flag == self.NOT_EXIST_FLAG:
         return False, None
     else:
         raise Exception("Unknown response flag: " + str(result_flag))
Example #24
 def _lazy_iterator(
     self,
     state_key,  # type: beam_fn_api_pb2.StateKey
     coder,  # type: coder_impl.CoderImpl
     continuation_token=None  # type: Optional[bytes]
 ):
     # type: (...) -> Iterator[Any]
     """Materializes the state lazily, one element at a time.
     :return: A generator which returns the next element if advanced.
     """
     while True:
         data, continuation_token = (self._underlying.get_raw(
             state_key, continuation_token))
         input_stream = coder_impl.create_InputStream(data)
         while input_stream.size() > 0:
             yield coder.decode_from_stream(input_stream, True)
         if not continuation_token:
             break
Example #25
 def partition(self, n):
   # type: (int) -> List[List[bytes]]
   if self.cleared:
     raise RuntimeError('Trying to partition a cleared ListBuffer.')
   if len(self._inputs) >= n or len(self._inputs) == 0:
     return [self._inputs[k::n] for k in range(n)]
   else:
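      # Too few raw chunks to split directly: re-decode the elements and
      # distribute them round-robin across n output streams.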
     if not self._grouped_output:
       output_stream_list = [create_OutputStream() for _ in range(n)]
       idx = 0
       for input in self._inputs:
         input_stream = create_InputStream(input)
         while input_stream.size() > 0:
           decoded_value = self._coder_impl.decode_from_stream(
               input_stream, True)
           self._coder_impl.encode_to_stream(
               decoded_value, output_stream_list[idx], True)
           idx = (idx + 1) % n
       self._grouped_output = [[output_stream.get()]
                               for output_stream in output_stream_list]
     return self._grouped_output
Example #26
def decode_nested(coder, encoded, nested=True):
    return coder.get_impl().decode_from_stream(
        coder_impl.create_InputStream(encoded), nested)
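For context, a minimal sketch of how a helper like this could be exercised with a standard Beam coder (assuming apache_beam is available); the VarIntCoder round trip below is purely illustrative and not part of the original example:

from apache_beam import coders
from apache_beam.coders import coder_impl

coder = coders.VarIntCoder()
out = coder_impl.create_OutputStream()
# Encode a value in nested form, then decode it back with decode_nested.
coder.get_impl().encode_to_stream(42, out, True)
assert decode_nested(coder, out.get()) == 42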
Example #27
 def process_encoded(self, encoded_windowed_values):
   input_stream = coder_impl.create_InputStream(encoded_windowed_values)
   while input_stream.size() > 0:
     decoded_value = self.windowed_coder_impl.decode_from_stream(
         input_stream, True)
     self.output(decoded_value)
Example #28
  def run_stage(
      self,
      worker_handler_factory,
      pipeline_components,
      stage,
      pcoll_buffers,
      safe_coders):

    def iterable_state_write(values, element_coder_impl):
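      # Encode the values and store them in runner state under a freshly
      # generated token; the token is returned to the caller.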
      token = unique_name(None, 'iter').encode('ascii')
      out = create_OutputStream()
      for element in values:
        element_coder_impl.encode_to_stream(element, out, True)
      controller.state.blocking_append(
          beam_fn_api_pb2.StateKey(
              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
          out.get())
      return token

    controller = worker_handler_factory(stage.environment)
    context = pipeline_context.PipelineContext(
        pipeline_components, iterable_state_write=iterable_state_write)
    data_api_service_descriptor = controller.data_api_service_descriptor()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data ApiServiceDescriptor.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          pcoll_id = transform.spec.payload
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
              data_input[target] = [ENCODED_IMPULSE_VALUE]
            else:
              data_input[target] = pcoll_buffers[pcoll_id]
            coder_id = pipeline_components.pcollections[
                only_element(transform.outputs.values())].coder_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
            coder_id = pipeline_components.pcollections[
                only_element(transform.inputs.values())].coder_id
          else:
            raise NotImplementedError
          data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
          if data_api_service_descriptor:
            data_spec.api_service_descriptor.url = (
                data_api_service_descriptor.url)
          transform.spec.payload = data_spec.SerializeToString()
        elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag, si in payload.side_inputs.items():
            data_side_input[transform.unique_name, tag] = (
                create_buffer_id(transform.inputs[tag]), si.access_pattern)
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    if controller.state_api_service_descriptor():
      process_bundle_descriptor.state_api_service_descriptor.url = (
          controller.state_api_service_descriptor().url)

    # Store the required side inputs into state.
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = context.coders[safe_coders[
          pipeline_components.pcollections[pcoll_id].coder_id]]
      elements_by_window = _WindowGroupingBuffer(si, value_coder)
      for element_data in pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        controller.state.blocking_append(state_key, elements_data)

    def get_buffer(buffer_id):
      kind, name = split_buffer_id(buffer_id)
      if kind in ('materialize', 'timers'):
        if buffer_id not in pcoll_buffers:
          # Just store the data chunks for replay.
          pcoll_buffers[buffer_id] = list()
      elif kind == 'group':
        # This is a grouping write, create a grouping buffer if needed.
        if buffer_id not in pcoll_buffers:
          original_gbk_transform = name
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(list(transform_proto.inputs.values()))
          output_pcoll = only_element(list(transform_proto.outputs.values()))
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[buffer_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(buffer_id)
      return pcoll_buffers[buffer_id]

    for k in range(self._bundle_repeat):
      try:
        controller.state.checkpoint()
        BundleManager(
            controller, lambda pcoll_id: [], process_bundle_descriptor,
            self._progress_frequency, k).process_bundle(data_input, data_output)
      finally:
        controller.state.restore()

    result = BundleManager(
        controller, get_buffer, process_bundle_descriptor,
        self._progress_frequency).process_bundle(
            data_input, data_output)

    last_result = result
    while True:
      deferred_inputs = collections.defaultdict(list)
      for transform_id, timer_writes in stage.timer_pcollections:

        # Queue any set timers as new inputs.
        windowed_timer_coder_impl = context.coders[
            pipeline_components.pcollections[timer_writes].coder_id].get_impl()
        written_timers = get_buffer(
            create_buffer_id(timer_writes, kind='timers'))
        if written_timers:
          # Keep only the "last" timer set per key and window.
          timers_by_key_and_window = {}
          for elements_data in written_timers:
            input_stream = create_InputStream(elements_data)
            while input_stream.size() > 0:
              windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                  input_stream, True)
              key, _ = windowed_key_timer.value
              # TODO: Explode and merge windows.
              assert len(windowed_key_timer.windows) == 1
              timers_by_key_and_window[
                  key, windowed_key_timer.windows[0]] = windowed_key_timer
          out = create_OutputStream()
          for windowed_key_timer in timers_by_key_and_window.values():
            windowed_timer_coder_impl.encode_to_stream(
                windowed_key_timer, out, True)
          deferred_inputs[transform_id, 'out'] = [out.get()]
          written_timers[:] = []

      # Queue any delayed bundle applications.
      for delayed_application in last_result.process_bundle.residual_roots:
        # Find the io transform that feeds this transform.
        # TODO(SDF): Memoize?
        application = delayed_application.application
        input_pcoll = process_bundle_descriptor.transforms[
            application.ptransform_id].inputs[application.input_id]
        for input_id, proto in process_bundle_descriptor.transforms.items():
          if (proto.spec.urn == bundle_processor.DATA_INPUT_URN
              and input_pcoll in proto.outputs.values()):
            deferred_inputs[input_id, 'out'].append(application.element)
            break
        else:
          raise RuntimeError(
              'No IO transform feeds %s' % application.ptransform_id)

      if deferred_inputs:
        # The worker will be waiting on these inputs as well.
        for other_input in data_input:
          if other_input not in deferred_inputs:
            deferred_inputs[other_input] = []
        # TODO(robertwb): merge results
        last_result = BundleManager(
            controller,
            get_buffer,
            process_bundle_descriptor,
            self._progress_frequency,
            True).process_bundle(deferred_inputs, data_output)
      else:
        break

    return result
Example #29
 def append(self, elements_data):
   input_stream = create_InputStream(elements_data)
   while input_stream.size() > 0:
     key, value = self._pre_grouped_coder.get_impl().decode_from_stream(
         input_stream, True).value
     self._table[self._key_coder.encode(key)].append(value)
Example #30
def decode_nested(coder, encoded, nested=True):
  return coder.get_impl().decode_from_stream(
      coder_impl.create_InputStream(encoded), nested)
Example #31
 def __iter__(self):
   # TODO(robertwb): Support pagination.
   input_stream = coder_impl.create_InputStream(
       state_handler.blocking_get(state_key, None))
   while input_stream.size() > 0:
     yield element_coder_impl.decode_from_stream(input_stream, True)
Example #32
 def __iter__(self):
   # TODO(robertwb): Support pagination.
   input_stream = coder_impl.create_InputStream(
       self._state_handler.Get(self._state_key).data)
   while input_stream.size() > 0:
     yield self._coder.get_impl().decode_from_stream(input_stream, True)
Example #33
 def process_encoded(self, encoded_windowed_values):
   input_stream = coder_impl.create_InputStream(encoded_windowed_values)
   while input_stream.size() > 0:
     decoded_value = self.windowed_coder.get_impl().decode_from_stream(
         input_stream, True)
     self.output(decoded_value)
Example #34
  def run_stage(
      self,
      worker_handler_factory,
      pipeline_components,
      stage,
      pcoll_buffers,
      safe_coders):

    def iterable_state_write(values, element_coder_impl):
      token = unique_name(None, 'iter').encode('ascii')
      out = create_OutputStream()
      for element in values:
        element_coder_impl.encode_to_stream(element, out, True)
      controller.state.blocking_append(
          beam_fn_api_pb2.StateKey(
              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
          out.get())
      return token

    controller = worker_handler_factory(stage.environment)
    context = pipeline_context.PipelineContext(
        pipeline_components, iterable_state_write=iterable_state_write)
    data_api_service_descriptor = controller.data_api_service_descriptor()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data ApiServiceDescriptor.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          pcoll_id = transform.spec.payload
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
              data_input[target] = [ENCODED_IMPULSE_VALUE]
            else:
              data_input[target] = pcoll_buffers[pcoll_id]
            coder_id = pipeline_components.pcollections[
                only_element(transform.outputs.values())].coder_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
            coder_id = pipeline_components.pcollections[
                only_element(transform.inputs.values())].coder_id
          else:
            raise NotImplementedError
          data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
          if data_api_service_descriptor:
            data_spec.api_service_descriptor.url = (
                data_api_service_descriptor.url)
          transform.spec.payload = data_spec.SerializeToString()
        elif transform.spec.urn == common_urns.primitives.PAR_DO.urn:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag, si in payload.side_inputs.items():
            data_side_input[transform.unique_name, tag] = (
                create_buffer_id(transform.inputs[tag]), si.access_pattern)
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    if controller.state_api_service_descriptor():
      process_bundle_descriptor.state_api_service_descriptor.url = (
          controller.state_api_service_descriptor().url)

    # Store the required side inputs into state.
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = context.coders[safe_coders[
          pipeline_components.pcollections[pcoll_id].coder_id]]
      elements_by_window = _WindowGroupingBuffer(si, value_coder)
      for element_data in pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        controller.state.blocking_append(state_key, elements_data)

    def get_buffer(buffer_id):
      kind, name = split_buffer_id(buffer_id)
      if kind in ('materialize', 'timers'):
        if buffer_id not in pcoll_buffers:
          # Just store the data chunks for replay.
          pcoll_buffers[buffer_id] = list()
      elif kind == 'group':
        # This is a grouping write, create a grouping buffer if needed.
        if buffer_id not in pcoll_buffers:
          original_gbk_transform = name
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(list(transform_proto.inputs.values()))
          output_pcoll = only_element(list(transform_proto.outputs.values()))
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[buffer_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(buffer_id)
      return pcoll_buffers[buffer_id]

    for k in range(self._bundle_repeat):
      try:
        controller.state.checkpoint()
        BundleManager(
            controller, lambda pcoll_id: [], process_bundle_descriptor,
            self._progress_frequency, k).process_bundle(data_input, data_output)
      finally:
        controller.state.restore()

    result = BundleManager(
        controller, get_buffer, process_bundle_descriptor,
        self._progress_frequency).process_bundle(data_input, data_output)

    while True:
      timer_inputs = {}
      for transform_id, timer_writes in stage.timer_pcollections:
        windowed_timer_coder_impl = context.coders[
            pipeline_components.pcollections[timer_writes].coder_id].get_impl()
        written_timers = get_buffer(
            create_buffer_id(timer_writes, kind='timers'))
        if written_timers:
          # Keep only the "last" timer set per key and window.
          timers_by_key_and_window = {}
          for elements_data in written_timers:
            input_stream = create_InputStream(elements_data)
            while input_stream.size() > 0:
              windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                  input_stream, True)
              key, _ = windowed_key_timer.value
              # TODO: Explode and merge windows.
              assert len(windowed_key_timer.windows) == 1
              timers_by_key_and_window[
                  key, windowed_key_timer.windows[0]] = windowed_key_timer
          out = create_OutputStream()
          for windowed_key_timer in timers_by_key_and_window.values():
            windowed_timer_coder_impl.encode_to_stream(
                windowed_key_timer, out, True)
          timer_inputs[transform_id, 'out'] = [out.get()]
          written_timers[:] = []
      if timer_inputs:
        # The worker will be waiting on these inputs as well.
        for other_input in data_input:
          if other_input not in timer_inputs:
            timer_inputs[other_input] = []
        # TODO(robertwb): merge results
        BundleManager(
            controller,
            get_buffer,
            process_bundle_descriptor,
            self._progress_frequency,
            True).process_bundle(timer_inputs, data_output)
      else:
        break

    return result