Example #1
  def process_bundle(self, inputs, expected_outputs):
    # Unique id for the instruction processing this bundle.
    BundleManager._uid_counter += 1
    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter

    # Register the bundle descriptor, if needed.
    if self._registered:
      registration_future = None
    else:
      process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
          register=beam_fn_api_pb2.RegisterRequest(
              process_bundle_descriptor=[self._bundle_descriptor]))
      registration_future = self._controller.control_handler.push(
          process_bundle_registration)
      self._registered = True

    # Write all the input data to the channel.
    for (transform_id, name), elements in inputs.items():
      data_out = self._controller.data_plane_handler.output_stream(
          process_bundle_id, beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=name))
      for element_data in elements:
        data_out.write(element_data)
      data_out.close()

    # Actually start the bundle.
    if registration_future and registration_future.get().error:
      raise RuntimeError(registration_future.get().error)
    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=process_bundle_id,
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=self._bundle_descriptor.id))
    result_future = self._controller.control_handler.push(process_bundle)

    with ProgressRequester(
        self._controller, process_bundle_id, self._progress_frequency):
      # Gather all output data.
      expected_targets = [
          beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                                 name=output_name)
          for (transform_id, output_name), _ in expected_outputs.items()]
      logging.debug('Gather all output data from %s.', expected_targets)
      for output in self._controller.data_plane_handler.input_elements(
          process_bundle_id,
          expected_targets,
          abort_callback=lambda: (result_future.is_done()
                                  and result_future.get().error)):
        target_tuple = (
            output.target.primitive_transform_reference, output.target.name)
        if target_tuple in expected_outputs:
          self._get_buffer(expected_outputs[target_tuple]).append(output.data)

      logging.debug('Wait for the bundle to finish.')
      result = result_future.get()

    if result.error:
      raise RuntimeError(result.error)
    return result
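
For orientation, here is a hypothetical illustration (the transform ids and buffer name are made up) of the argument shapes process_bundle expects: both mappings are keyed by (transform_id, name) tuples, i.e. the two fields of a beam_fn_api_pb2.Target.

# inputs: encoded element chunks to push through the data plane for each
# data-input transform of the bundle descriptor (names below are invented).
inputs = {
    ('data_input_transform', 'out'): [b'<encoded windowed values>'],
}

# expected_outputs: for each data-output transform, the buffer identifier
# that self._get_buffer() resolves to a list-like sink for the output chunks.
expected_outputs = {
    ('data_output_transform', 'out'): 'materialize:some_pcollection_id',
}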
Example #2
    def _run_map_task(self, map_task, control_handler, state_handler,
                      data_plane_handler, data_operation_spec):
        registration, sinks, input_data = self._map_task_registration(
            map_task, state_handler, data_operation_spec)
        control_handler.push(registration)
        process_bundle = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_reference=registration.register.
                process_bundle_descriptor[0].id))

        for (transform_id, name), elements in input_data.items():
            data_out = data_plane_handler.output_stream(
                process_bundle.instruction_id,
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_id, name=name))
            data_out.write(elements)
            data_out.close()

        control_handler.push(process_bundle)
        while True:
            result = control_handler.pull()
            if result.instruction_id == process_bundle.instruction_id:
                if result.error:
                    raise RuntimeError(result.error)
                expected_targets = [
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_id,
                        name=output_name)
                    for (transform_id, output_name), _ in sinks.items()
                ]
                for output in data_plane_handler.input_elements(
                        process_bundle.instruction_id, expected_targets):
                    target_tuple = (
                        output.target.primitive_transform_reference,
                        output.target.name)
                    if target_tuple not in sinks:
                        # Unconsumed output.
                        continue
                    sink_op = sinks[target_tuple]
                    coder = sink_op.output_coders[0]
                    input_stream = create_InputStream(output.data)
                    elements = []
                    while input_stream.size() > 0:
                        elements.append(coder.get_impl().decode_from_stream(
                            input_stream, True))
                    if not sink_op.write_windowed_values:
                        elements = [e.value for e in elements]
                    for e in elements:
                        sink_op.output_buffer.append(e)
                return
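
The decode loop above relies on the coder stream API; a minimal, self-contained sketch of that round trip (VarIntCoder is used purely as a stand-in coder) might look like this:

from apache_beam.coders import VarIntCoder
from apache_beam.coders.coder_impl import create_InputStream, create_OutputStream

coder_impl = VarIntCoder().get_impl()

# Pack several elements into a single payload, as the runner does per channel.
out = create_OutputStream()
for value in [1, 2, 3]:
    coder_impl.encode_to_stream(value, out, True)  # nested=True, as above
payload = out.get()

# Unpack them again with the same while-loop pattern used in _run_map_task.
decoded = []
in_stream = create_InputStream(payload)
while in_stream.size() > 0:
    decoded.append(coder_impl.decode_from_stream(in_stream, True))
assert decoded == [1, 2, 3]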
Example #3
def create(factory, transform_id, transform_proto, grpc_port, consumers):
  # Timers are the one special case where we don't want to call the
  # (unlabeled) operation.process() method, which we detect here.
  # TODO(robertwb): Consider generalizing if there are any more cases.
  output_pcoll = only_element(transform_proto.outputs.values())
  output_consumers = only_element(consumers.values())
  if (len(output_consumers) == 1
      and isinstance(only_element(output_consumers), operations.DoOperation)):
    do_op = only_element(output_consumers)
    for tag, pcoll_id in do_op.timer_inputs.items():
      if pcoll_id == output_pcoll:
        output_consumers[:] = [TimerConsumer(tag, do_op)]
        break

  target = beam_fn_api_pb2.Target(
      primitive_transform_reference=transform_id,
      name=only_element(list(transform_proto.outputs.keys())))
  if grpc_port.coder_id:
    output_coder = factory.get_coder(grpc_port.coder_id)
  else:
    logging.error(
        'Missing required coder_id on grpc_port for %s; '
        'using deprecated fallback.',
        transform_id)
    output_coder = factory.get_only_output_coder(transform_proto)
  return DataInputOperation(
      transform_proto.unique_name,
      transform_proto.unique_name,
      consumers,
      factory.counter_factory,
      factory.state_sampler,
      output_coder,
      input_target=target,
      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
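
Note: only_element, used here and throughout these snippets, is the usual single-item unwrap helper; Example #12 below defines it inline, equivalent to:

def only_element(iterable):
  # Unpack the sole element; raises ValueError if there are zero or several.
  element, = iterable
  return element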
Example #4
 def as_target(op_input):
     input_op_index, input_output_index = op_input
     input_op = map_task[input_op_index][1]
     return {
         'ignored_input_tag':
         beam_fn_api_pb2.Target.List(target=[
             beam_fn_api_pb2.Target(
                 primitive_transform_reference=transform_index_to_id[
                     input_op_index],
                 name=output_tags(input_op)[input_output_index])
         ])
     }
Example #5
def create(factory, transform_id, transform_proto, grpc_port, consumers):
  target = beam_fn_api_pb2.Target(
      primitive_transform_reference=transform_id,
      name=only_element(transform_proto.outputs.keys()))
  return DataInputOperation(
      transform_proto.unique_name,
      transform_proto.unique_name,
      consumers,
      factory.counter_factory,
      factory.state_sampler,
      factory.get_only_output_coder(transform_proto),
      input_target=target,
      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
Example #6
def create(factory, transform_id, transform_proto, grpc_port, consumers):
  target = beam_fn_api_pb2.Target(
      primitive_transform_reference=transform_id,
      name=only_element(transform_proto.inputs.keys()))
  return DataOutputOperation(
      transform_proto.unique_name,
      transform_proto.unique_name,
      consumers,
      factory.counter_factory,
      factory.state_sampler,
      # TODO(robertwb): Perhaps this could be distinct from the input coder?
      factory.get_only_input_coder(transform_proto),
      target=target,
      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
Example #7
    def _data_channel_test_one_direction(self, from_channel, to_channel):
        def send(instruction_id, target, data):
            stream = from_channel.output_stream(instruction_id, target)
            stream.write(data)
            stream.close()

        target_1 = beam_fn_api_pb2.Target(primitive_transform_reference='1',
                                          name='out')
        target_2 = beam_fn_api_pb2.Target(primitive_transform_reference='2',
                                          name='out')

        # Single write.
        send('0', target_1, 'abc')
        self.assertEqual(list(to_channel.input_elements('0', [target_1])), [
            beam_fn_api_pb2.Elements.Data(
                instruction_reference='0', target=target_1, data='abc')
        ])

        # Multiple interleaved writes to multiple instructions.
        send('1', target_1, 'abc')
        send('2', target_1, 'def')
        self.assertEqual(list(to_channel.input_elements('1', [target_1])), [
            beam_fn_api_pb2.Elements.Data(
                instruction_reference='1', target=target_1, data='abc')
        ])
        send('2', target_2, 'ghi')
        self.assertEqual(
            list(to_channel.input_elements('2', [target_1, target_2])), [
                beam_fn_api_pb2.Elements.Data(
                    instruction_reference='2', target=target_1, data='def'),
                beam_fn_api_pb2.Elements.Data(
                    instruction_reference='2', target=target_2, data='ghi')
            ])
Example #8
def create(factory, transform_id, transform_proto, grpc_port, consumers):
  target = beam_fn_api_pb2.Target(
      primitive_transform_reference=transform_id,
      name=only_element(list(transform_proto.inputs.keys())))
  if grpc_port.coder_id:
    output_coder = factory.get_coder(grpc_port.coder_id)
  else:
    logging.error(
        'Missing required coder_id on grpc_port for %s; '
        'using deprecated fallback.',
        transform_id)
    output_coder = factory.get_only_input_coder(transform_proto)
  return DataOutputOperation(
      transform_proto.unique_name,
      transform_proto.unique_name,
      consumers,
      factory.counter_factory,
      factory.state_sampler,
      output_coder,
      target=target,
      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
Example #9
    def run_stage(self, controller, pipeline_components, stage, pcoll_buffers,
                  safe_coders):

        context = pipeline_context.PipelineContext(pipeline_components)
        data_operation_spec = controller.data_operation_spec()

        def extract_endpoints(stage):
            # Returns maps of transform names to PCollection identifiers.
            # Also mutates IO stages to point to the data_operation_spec.
            data_input = {}
            data_side_input = {}
            data_output = {}
            for transform in stage.transforms:
                if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                          bundle_processor.DATA_OUTPUT_URN):
                    pcoll_id = transform.spec.payload
                    if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                        target = transform.unique_name, only_element(
                            transform.outputs)
                        data_input[target] = pcoll_id
                    elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                        target = transform.unique_name, only_element(
                            transform.inputs)
                        data_output[target] = pcoll_id
                    else:
                        raise NotImplementedError
                    if data_operation_spec:
                        transform.spec.payload = (
                            data_operation_spec.SerializeToString())
                    else:
                        transform.spec.payload = ""
                elif transform.spec.urn == urns.PARDO_TRANSFORM:
                    payload = proto_utils.parse_Bytes(
                        transform.spec.payload,
                        beam_runner_api_pb2.ParDoPayload)
                    for tag, si in payload.side_inputs.items():
                        data_side_input[transform.unique_name, tag] = (
                            'materialize:' + transform.inputs[tag],
                            beam.pvalue.SideInputData.from_runner_api(
                                si, None))
            return data_input, data_side_input, data_output

        logging.info('Running %s', stage.name)
        logging.debug('       %s', stage)
        data_input, data_side_input, data_output = extract_endpoints(stage)

        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            transforms={
                transform.unique_name: transform
                for transform in stage.transforms
            },
            pcollections=dict(pipeline_components.pcollections.items()),
            coders=dict(pipeline_components.coders.items()),
            windowing_strategies=dict(
                pipeline_components.windowing_strategies.items()),
            environments=dict(pipeline_components.environments.items()))

        process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[process_bundle_descriptor]))

        process_bundle = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_reference=process_bundle_descriptor.
                id))

        # Write all the input data to the channel.
        for (transform_id, name), pcoll_id in data_input.items():
            data_out = controller.data_plane_handler.output_stream(
                process_bundle.instruction_id,
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_id, name=name))
            for element_data in pcoll_buffers[pcoll_id]:
                data_out.write(element_data)
            data_out.close()

        # Store the required side inputs into state.
        for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
            elements_by_window = _WindowGroupingBuffer(si)
            for element_data in pcoll_buffers[pcoll_id]:
                elements_by_window.append(element_data)
            for window, elements_data in elements_by_window.items():
                state_key = beam_fn_api_pb2.StateKey(
                    multimap_side_input=beam_fn_api_pb2.StateKey.
                    MultimapSideInput(ptransform_id=transform_id,
                                      side_input_id=tag,
                                      window=window))
                controller.state_handler.blocking_append(
                    state_key, elements_data, process_bundle.instruction_id)

        # Register and start running the bundle.
        logging.debug('Register and start running the bundle')
        controller.control_handler.push(process_bundle_registration)
        controller.control_handler.push(process_bundle)

        # Wait for the bundle to finish.
        logging.debug('Wait for the bundle to finish.')
        while True:
            result = controller.control_handler.pull()
            if result and result.instruction_id == process_bundle.instruction_id:
                if result.error:
                    raise RuntimeError(result.error)
                break

        expected_targets = [
            beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                                   name=output_name)
            for (transform_id, output_name), _ in data_output.items()
        ]

        # Gather all output data.
        logging.debug('Gather all output data from %s.', expected_targets)

        for output in controller.data_plane_handler.input_elements(
                process_bundle.instruction_id, expected_targets):
            target_tuple = (output.target.primitive_transform_reference,
                            output.target.name)
            if target_tuple in data_output:
                pcoll_id = data_output[target_tuple]
                if pcoll_id.startswith('materialize:'):
                    # Just store the data chunks for replay.
                    pcoll_buffers[pcoll_id].append(output.data)
                elif pcoll_id.startswith('group:'):
                    # This is a grouping write, create a grouping buffer if needed.
                    if pcoll_id not in pcoll_buffers:
                        original_gbk_transform = pcoll_id.split(':', 1)[1]
                        transform_proto = pipeline_components.transforms[
                            original_gbk_transform]
                        input_pcoll = only_element(
                            transform_proto.inputs.values())
                        output_pcoll = only_element(
                            transform_proto.outputs.values())
                        pre_gbk_coder = context.coders[
                            safe_coders[pipeline_components.
                                        pcollections[input_pcoll].coder_id]]
                        post_gbk_coder = context.coders[
                            safe_coders[pipeline_components.
                                        pcollections[output_pcoll].coder_id]]
                        windowing_strategy = context.windowing_strategies[
                            pipeline_components.pcollections[output_pcoll].
                            windowing_strategy_id]
                        pcoll_buffers[pcoll_id] = _GroupingBuffer(
                            pre_gbk_coder, post_gbk_coder, windowing_strategy)
                    pcoll_buffers[pcoll_id].append(output.data)
                else:
                    # These should be the only two identifiers we produce for now,
                    # but special side input writes may go here.
                    raise NotImplementedError(pcoll_id)
        return result
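
As a rough sketch (assuming the usual defaultdict-backed setup for the buffer map, which is an assumption, not shown above), pcoll_buffers keys chunks of encoded data by PCollection identifier: a 'materialize:' prefix marks chunks that are replayed verbatim as input to later stages, while a 'group:' prefix marks buffers that run_stage lazily replaces with a _GroupingBuffer.

import collections

pcoll_buffers = collections.defaultdict(list)

# Raw encoded chunks collected from an earlier stage, replayed as-is when the
# corresponding PCollection is consumed as data_input by a later stage.
pcoll_buffers['materialize:pcoll_a'].append(b'<encoded elements chunk>')

# 'group:<gbk transform id>' entries are created on demand inside run_stage
# itself, so they need no special preparation here.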
Example #10
    def _map_task_registration(self, map_task, state_handler,
                               data_operation_spec):
        input_data = {}
        runner_sinks = {}
        transforms = []
        transform_index_to_id = {}

        # Maps coders to new coder objects and references.
        coders = {}

        def coder_id(coder):
            if coder not in coders:
                coders[coder] = beam_fn_api_pb2.Coder(
                    function_spec=sdk_worker.pack_function_spec_data(
                        json.dumps(coder.as_cloud_object()),
                        sdk_worker.PYTHON_CODER_URN,
                        id=self._next_uid()))

            return coders[coder].function_spec.id

        def output_tags(op):
            return getattr(op, 'output_tags', ['out'])

        def as_target(op_input):
            input_op_index, input_output_index = op_input
            input_op = map_task[input_op_index][1]
            return {
                'ignored_input_tag':
                beam_fn_api_pb2.Target.List(target=[
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_index_to_id[
                            input_op_index],
                        name=output_tags(input_op)[input_output_index])
                ])
            }

        def outputs(op):
            return {
                tag:
                beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
                for tag, coder in zip(output_tags(op), op.output_coders)
            }

        for op_ix, (stage_name, operation) in enumerate(map_task):
            transform_id = transform_index_to_id[op_ix] = self._next_uid()
            if isinstance(operation, operation_specs.WorkerInMemoryWrite):
                # Write this data back to the runner.
                fn = beam_fn_api_pb2.FunctionSpec(
                    urn=sdk_worker.DATA_OUTPUT_URN, id=self._next_uid())
                if data_operation_spec:
                    fn.data.Pack(data_operation_spec)
                inputs = as_target(operation.input)
                side_inputs = {}
                runner_sinks[(transform_id, 'out')] = operation

            elif isinstance(operation, operation_specs.WorkerRead):
                # A Read is either translated to a direct injection of windowed values
                # into the sdk worker, or an injection of the source object into the
                # sdk worker as data followed by an SDF that reads that source.
                if (isinstance(operation.source.source,
                               maptask_executor_runner.InMemorySource)
                        and isinstance(
                            operation.source.source.default_output_coder(),
                            WindowedValueCoder)):
                    output_stream = create_OutputStream()
                    element_coder = (operation.source.source.
                                     default_output_coder().get_impl())
                    # Re-encode the elements in the nested context and
                    # concatenate them together
                    for element in operation.source.source.read(None):
                        element_coder.encode_to_stream(element, output_stream,
                                                       True)
                    target_name = self._next_uid()
                    input_data[(transform_id,
                                target_name)] = output_stream.get()
                    fn = beam_fn_api_pb2.FunctionSpec(
                        urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid())
                    if data_operation_spec:
                        fn.data.Pack(data_operation_spec)
                    inputs = {target_name: beam_fn_api_pb2.Target.List()}
                    side_inputs = {}
                else:
                    # Read the source object from the runner.
                    source_coder = beam.coders.DillCoder()
                    input_transform_id = self._next_uid()
                    output_stream = create_OutputStream()
                    source_coder.get_impl().encode_to_stream(
                        GlobalWindows.windowed_value(operation.source),
                        output_stream, True)
                    target_name = self._next_uid()
                    input_data[(input_transform_id,
                                target_name)] = output_stream.get()
                    input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
                        id=input_transform_id,
                        function_spec=beam_fn_api_pb2.FunctionSpec(
                            urn=sdk_worker.DATA_INPUT_URN,
                            id=self._next_uid()),
                        # TODO(robertwb): Possible name collision.
                        step_name=stage_name + '/inject_source',
                        inputs={target_name: beam_fn_api_pb2.Target.List()},
                        outputs={
                            'out':
                            beam_fn_api_pb2.PCollection(
                                coder_reference=coder_id(source_coder))
                        })
                    if data_operation_spec:
                        input_ptransform.function_spec.data.Pack(
                            data_operation_spec)
                    transforms.append(input_ptransform)

                    # Read the elements out of the source.
                    fn = sdk_worker.pack_function_spec_data(
                        OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
                        sdk_worker.PYTHON_DOFN_URN,
                        id=self._next_uid())
                    inputs = {
                        'ignored_input_tag':
                        beam_fn_api_pb2.Target.List(target=[
                            beam_fn_api_pb2.Target(
                                primitive_transform_reference=
                                input_transform_id,
                                name='out')
                        ])
                    }
                    side_inputs = {}

            elif isinstance(operation, operation_specs.WorkerDoFn):
                fn = sdk_worker.pack_function_spec_data(
                    operation.serialized_fn,
                    sdk_worker.PYTHON_DOFN_URN,
                    id=self._next_uid())
                inputs = as_target(operation.input)
                # Store the contents of each side input for state access.
                for si in operation.side_inputs:
                    assert isinstance(si.source, iobase.BoundedSource)
                    element_coder = si.source.default_output_coder()
                    view_id = self._next_uid()
                    # TODO(robertwb): Actually flesh out the ViewFn API.
                    side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
                        view_fn=sdk_worker.serialize_and_pack_py_fn(
                            element_coder,
                            urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
                            id=view_id))
                    # Re-encode the elements in the nested context and
                    # concatenate them together
                    output_stream = create_OutputStream()
                    for element in si.source.read(
                            si.source.get_range_tracker(None, None)):
                        element_coder.get_impl().encode_to_stream(
                            element, output_stream, True)
                    elements_data = output_stream.get()
                    state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
                        key=view_id)
                    state_handler.Clear(state_key)
                    state_handler.Append(state_key, elements_data)

            elif isinstance(operation, operation_specs.WorkerFlatten):
                fn = sdk_worker.pack_function_spec_data(
                    operation.serialized_fn,
                    sdk_worker.IDENTITY_DOFN_URN,
                    id=self._next_uid())
                inputs = {
                    'ignored_input_tag':
                    beam_fn_api_pb2.Target.List(target=[
                        beam_fn_api_pb2.Target(
                            primitive_transform_reference=
                            transform_index_to_id[input_op_index],
                            name=output_tags(map_task[input_op_index]
                                             [1])[input_output_index]) for
                        input_op_index, input_output_index in operation.inputs
                    ])
                }
                side_inputs = {}

            else:
                raise TypeError(operation)

            ptransform = beam_fn_api_pb2.PrimitiveTransform(
                id=transform_id,
                function_spec=fn,
                step_name=stage_name,
                inputs=inputs,
                side_inputs=side_inputs,
                outputs=outputs(operation))
            transforms.append(ptransform)

        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            coders=coders.values(),
            primitive_transform=transforms)
        return beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[
                    process_bundle_descriptor])), runner_sinks, input_data
Example #11
  def run_stage(
      self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

    coders = pipeline_context.PipelineContext(pipeline_components).coders
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data_operation_spec.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        pcoll_id = transform.spec.payload
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            data_input[target] = pcoll_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
          else:
            raise NotImplementedError
          if data_operation_spec:
            transform.spec.payload = data_operation_spec.SerializeToString()
            transform.spec.any_param.Pack(data_operation_spec)
          else:
            transform.spec.payload = ""
            transform.spec.any_param.Clear()
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)
    if data_side_input:
      raise NotImplementedError('Side inputs.')

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[process_bundle_descriptor]))

    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=
            process_bundle_descriptor.id))

    # Write all the input data to the channel.
    for (transform_id, name), pcoll_id in data_input.items():
      data_out = controller.data_plane_handler.output_stream(
          process_bundle.instruction_id, beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=name))
      for element_data in pcoll_buffers[pcoll_id]:
        data_out.write(element_data)
      data_out.close()

    # Register and start running the bundle.
    controller.control_handler.push(process_bundle_registration)
    controller.control_handler.push(process_bundle)

    # Wait for the bundle to finish.
    while True:
      result = controller.control_handler.pull()
      if result and result.instruction_id == process_bundle.instruction_id:
        if result.error:
          raise RuntimeError(result.error)
        break

    # Gather all output data.
    expected_targets = [
        beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                               name=output_name)
        for (transform_id, output_name), _ in data_output.items()]
    for output in controller.data_plane_handler.input_elements(
        process_bundle.instruction_id, expected_targets):
      target_tuple = (
          output.target.primitive_transform_reference, output.target.name)
      if target_tuple in data_output:
        pcoll_id = data_output[target_tuple]
        if pcoll_id.startswith('materialize:'):
          # Just store the data chunks for replay.
          pcoll_buffers[pcoll_id].append(output.data)
        elif pcoll_id.startswith('group:'):
          # This is a grouping write, create a grouping buffer if needed.
          if pcoll_id not in pcoll_buffers:
            original_gbk_transform = pcoll_id.split(':', 1)[1]
            transform_proto = pipeline_components.transforms[
                original_gbk_transform]
            input_pcoll = only_element(transform_proto.inputs.values())
            output_pcoll = only_element(transform_proto.outputs.values())
            pre_gbk_coder = coders[safe_coders[
                pipeline_components.pcollections[input_pcoll].coder_id]]
            post_gbk_coder = coders[safe_coders[
                pipeline_components.pcollections[output_pcoll].coder_id]]
            pcoll_buffers[pcoll_id] = _GroupingBuffer(
                pre_gbk_coder, post_gbk_coder)
          pcoll_buffers[pcoll_id].append(output.data)
        else:
          # These should be the only two identifiers we produce for now,
          # but special side input writes may go here.
          raise NotImplementedError(pcoll_id)
Example #12
    def create_execution_tree_from_fn_api(self, descriptor):
        # TODO(vikasrk): Add an id field to Coder proto and use that instead.
        coders = {
            coder.function_spec.id: operation_specs.get_coder_from_spec(
                json.loads(unpack_function_spec_data(coder.function_spec)))
            for coder in descriptor.coders
        }

        counter_factory = counters.CounterFactory()
        # TODO(robertwb): Figure out the correct prefix to use for output counters
        # from StateSampler.
        state_sampler = statesampler.StateSampler(
            'fnapi-step%s-' % descriptor.id, counter_factory)
        consumers = collections.defaultdict(
            lambda: collections.defaultdict(list))
        ops_by_id = {}
        reversed_ops = []

        for transform in reversed(descriptor.primitive_transform):
            # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
            # "s3") from the service through the FnAPI so that msec counters can be
            # reported and correctly plumbed through the service and the UI.
            operation_name = 'fnapis%s' % transform.id

            def only_element(iterable):
                element, = iterable
                return element

            if transform.function_spec.urn == DATA_OUTPUT_URN:
                target = beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform.id,
                    name=only_element(transform.outputs.keys()))

                op = DataOutputOperation(
                    operation_name, transform.step_name,
                    consumers[transform.id], counter_factory, state_sampler,
                    coders[only_element(
                        transform.outputs.values()).coder_reference], target,
                    self.data_channel_factory.create_data_channel(
                        transform.function_spec))

            elif transform.function_spec.urn == DATA_INPUT_URN:
                target = beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform.id,
                    name=only_element(transform.inputs.keys()))
                op = DataInputOperation(
                    operation_name, transform.step_name,
                    consumers[transform.id], counter_factory, state_sampler,
                    coders[only_element(
                        transform.outputs.values()).coder_reference], target,
                    self.data_channel_factory.create_data_channel(
                        transform.function_spec))

            elif transform.function_spec.urn == PYTHON_DOFN_URN:

                def create_side_input(tag, si):
                    # TODO(robertwb): Extract windows (and keys) out of element data.
                    return operation_specs.WorkerSideInputSource(
                        tag=tag,
                        source=SideInputSource(
                            self.state_handler,
                            beam_fn_api_pb2.StateKey.MultimapSideInput(
                                key=si.view_fn.id.encode('utf-8')),
                            coder=unpack_and_deserialize_py_fn(si.view_fn)))

                output_tags = list(transform.outputs.keys())
                spec = operation_specs.WorkerDoFn(
                    serialized_fn=unpack_function_spec_data(
                        transform.function_spec),
                    output_tags=output_tags,
                    input=None,
                    side_inputs=[
                        create_side_input(tag, si)
                        for tag, si in transform.side_inputs.items()
                    ],
                    output_coders=[
                        coders[transform.outputs[out].coder_reference]
                        for out in output_tags
                    ])

                op = operations.DoOperation(operation_name, spec,
                                            counter_factory, state_sampler)
                # TODO(robertwb): Move these to the constructor.
                op.step_name = transform.step_name
                for tag, op_consumers in consumers[transform.id].items():
                    for consumer in op_consumers:
                        op.add_receiver(consumer, output_tags.index(tag))

            elif transform.function_spec.urn == IDENTITY_DOFN_URN:
                op = operations.FlattenOperation(operation_name, None,
                                                 counter_factory,
                                                 state_sampler)
                # TODO(robertwb): Move these to the constructor.
                op.step_name = transform.step_name
                for tag, op_consumers in consumers[transform.id].items():
                    for consumer in op_consumers:
                        op.add_receiver(consumer, 0)

            elif transform.function_spec.urn == PYTHON_SOURCE_URN:
                source = load_compressed(
                    unpack_function_spec_data(transform.function_spec))
                # TODO(vikasrk): Remove this once custom source is implemented with
                # splittable dofn via the data plane.
                spec = operation_specs.WorkerRead(
                    iobase.SourceBundle(1.0, source, None, None),
                    [WindowedValueCoder(source.default_output_coder())])
                op = operations.ReadOperation(operation_name, spec,
                                              counter_factory, state_sampler)
                op.step_name = transform.step_name
                output_tags = list(transform.outputs.keys())
                for tag, op_consumers in consumers[transform.id].items():
                    for consumer in op_consumers:
                        op.add_receiver(consumer, output_tags.index(tag))

            else:
                raise NotImplementedError

            # Record consumers.
            for _, inputs in transform.inputs.items():
                for target in inputs.target:
                    consumers[target.primitive_transform_reference][
                        target.name].append(op)

            reversed_ops.append(op)
            ops_by_id[transform.id] = op

        return list(reversed(reversed_ops))
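
The consumer wiring built above can be pictured with a small stand-alone sketch (operations are replaced by plain strings for illustration): consumers[transform_id][output_name] lists the downstream operations that should receive elements produced on that (transform, output) pair.

import collections

consumers = collections.defaultdict(lambda: collections.defaultdict(list))

# Mirror of the "Record consumers" loop above: the op built for a downstream
# transform registers itself as a consumer of output 'out' of transform 't1'.
consumers['t1']['out'].append('<op for t2>')

assert consumers['t1']['out'] == ['<op for t2>']
assert consumers['t9']['out'] == []  # unseen targets default to an empty list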