Example #1
    def test_fn_registration(self):
        process_bundle_descriptors = [
            beam_fn_api_pb2.ProcessBundleDescriptor(
                id=str(100 + ix),
                transforms={
                    str(ix):
                    beam_runner_api_pb2.PTransform(unique_name=str(ix))
                }) for ix in range(4)
        ]

        test_controller = BeamFnControlServicer([
            beam_fn_api_pb2.InstructionRequest(
                register=beam_fn_api_pb2.RegisterRequest(
                    process_bundle_descriptor=process_bundle_descriptors))
        ])

        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
            test_controller, server)
        test_port = server.add_insecure_port("[::]:0")
        server.start()

        channel = grpc.insecure_channel("localhost:%s" % test_port)
        harness = sdk_worker.SdkHarness(channel)
        harness.run()
        self.assertEqual(
            harness.worker.fns,
            {item.id: item
             for item in process_bundle_descriptors})
Example #2
    def test_source_split(self):
        source = RangeSource(0, 100)
        expected_splits = list(source.split(30))

        worker = sdk_harness.SdkWorker(
            None, data_plane.GrpcClientDataChannelFactory())
        worker.register(
            beam_fn_api_pb2.RegisterRequest(process_bundle_descriptor=[
                beam_fn_api_pb2.ProcessBundleDescriptor(primitive_transform=[
                    beam_fn_api_pb2.PrimitiveTransform(
                        function_spec=sdk_harness.serialize_and_pack_py_fn(
                            SourceBundle(1.0, source, None, None),
                            sdk_harness.PYTHON_SOURCE_URN,
                            id="src"))
                ])
            ]))
        split_response = worker.initial_source_split(
            beam_fn_api_pb2.InitialSourceSplitRequest(
                desired_bundle_size_bytes=30, source_reference="src"))

        self.assertEqual(expected_splits, [
            sdk_harness.unpack_and_deserialize_py_fn(s.source)
            for s in split_response.splits
        ])

        self.assertEqual([s.weight for s in expected_splits],
                         [s.relative_size for s in split_response.splits])
Example #3
    def test_fn_registration(self):
        fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]

        process_bundle_descriptors = [
            beam_fn_api_pb2.ProcessBundleDescriptor(
                id=str(100 + ix),
                primitive_transform=[
                    beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)
                ]) for ix, fn in enumerate(fns)
        ]

        test_controller = BeamFnControlServicer([
            beam_fn_api_pb2.InstructionRequest(
                register=beam_fn_api_pb2.RegisterRequest(
                    process_bundle_descriptor=process_bundle_descriptors))
        ])

        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
            test_controller, server)
        test_port = server.add_insecure_port("[::]:0")
        server.start()

        channel = grpc.insecure_channel("localhost:%s" % test_port)
        harness = sdk_worker.SdkHarness(channel)
        harness.run()
        self.assertEqual(
            harness.worker.fns,
            {item.id: item
             for item in fns + process_bundle_descriptors})
Example #4
 def _get_process_bundles(self, prefix, size):
   return [
       beam_fn_api_pb2.ProcessBundleDescriptor(
           id=str(str(prefix) + "-" + str(ix)),
           transforms={
               str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))
           }) for ix in range(size)
   ]
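Note: the short sketch below is not from the Beam source; it only shows how descriptors built the way Example #4 builds them can be wrapped in a RegisterRequest, mirroring Examples #1 and #3. The import path and the "demo-" ids are assumptions and may vary by SDK version.

# Minimal sketch (assumed import path): build two ProcessBundleDescriptors
# and wrap them in a RegisterRequest, as the test harnesses above do.
from apache_beam.portability.api import beam_fn_api_pb2, beam_runner_api_pb2

descriptors = [
    beam_fn_api_pb2.ProcessBundleDescriptor(
        id="demo-%d" % ix,
        transforms={
            str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))
        }) for ix in range(2)
]
register_request = beam_fn_api_pb2.InstructionRequest(
    register=beam_fn_api_pb2.RegisterRequest(
        process_bundle_descriptor=descriptors))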
Example #5
    def test_source_split_via_instruction(self):

        source = RangeSource(0, 100)
        expected_splits = list(source.split(30))

        test_controller = BeamFnControlServicer([
            beam_fn_api_pb2.InstructionRequest(
                instruction_id="register_request",
                register=beam_fn_api_pb2.RegisterRequest(
                    process_bundle_descriptor=[
                        beam_fn_api_pb2.ProcessBundleDescriptor(
                            primitive_transform=[
                                beam_fn_api_pb2.PrimitiveTransform(
                                    function_spec=sdk_harness.
                                    serialize_and_pack_py_fn(
                                        SourceBundle(1.0, source, None, None),
                                        sdk_harness.PYTHON_SOURCE_URN,
                                        id="src"))
                            ])
                    ])),
            beam_fn_api_pb2.InstructionRequest(
                instruction_id="split_request",
                initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
                    desired_bundle_size_bytes=30, source_reference="src"))
        ])

        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
            test_controller, server)
        test_port = server.add_insecure_port("[::]:0")
        server.start()

        channel = grpc.insecure_channel("localhost:%s" % test_port)
        harness = sdk_harness.SdkHarness(channel)
        harness.run()

        split_response = test_controller.responses[
            "split_request"].initial_source_split

        self.assertEqual(expected_splits, [
            sdk_harness.unpack_and_deserialize_py_fn(s.source)
            for s in split_response.splits
        ])

        self.assertEqual([s.weight for s in expected_splits],
                         [s.relative_size for s in split_response.splits])
Example #6
 def _build_process_bundle_descriptor(self):
   res = beam_fn_api_pb2.ProcessBundleDescriptor(
       id=self.bundle_uid,
       transforms={
           transform.unique_name: transform
           for transform in self.stage.transforms
       },
       pcollections=dict(
           self.execution_context.pipeline_components.pcollections.items()),
       coders=dict(self.execution_context.pipeline_components.coders.items()),
       windowing_strategies=dict(
           self.execution_context.pipeline_components.windowing_strategies.
           items()),
       environments=dict(
           self.execution_context.pipeline_components.environments.items()),
       state_api_service_descriptor=self.state_api_service_descriptor())
   return res
Example #7
 def _build_process_bundle_descriptor(self):
     # Cannot be invoked until *after* _extract_endpoints is called.
     return beam_fn_api_pb2.ProcessBundleDescriptor(
         id=self.bundle_uid,
         transforms={
             transform.unique_name: transform
             for transform in self.stage.transforms
         },
         pcollections=dict(self.execution_context.pipeline_components.
                           pcollections.items()),
         coders=dict(
             self.execution_context.pipeline_components.coders.items()),
         windowing_strategies=dict(
             self.execution_context.pipeline_components.
             windowing_strategies.items()),
         environments=dict(self.execution_context.pipeline_components.
                           environments.items()),
         state_api_service_descriptor=self.state_api_service_descriptor())
Example #8
    def run_stage(self, controller, pipeline_components, stage, pcoll_buffers,
                  safe_coders):

        context = pipeline_context.PipelineContext(pipeline_components)
        data_operation_spec = controller.data_operation_spec()

        def extract_endpoints(stage):
            # Returns maps of transform names to PCollection identifiers.
            # Also mutates IO stages to point to the data_operation_spec.
            data_input = {}
            data_side_input = {}
            data_output = {}
            for transform in stage.transforms:
                if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                          bundle_processor.DATA_OUTPUT_URN):
                    pcoll_id = transform.spec.payload
                    if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                        target = transform.unique_name, only_element(
                            transform.outputs)
                        data_input[target] = pcoll_id
                    elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                        target = transform.unique_name, only_element(
                            transform.inputs)
                        data_output[target] = pcoll_id
                    else:
                        raise NotImplementedError
                    if data_operation_spec:
                        transform.spec.payload = (
                            data_operation_spec.SerializeToString())
                    else:
                        transform.spec.payload = ""
                elif transform.spec.urn == urns.PARDO_TRANSFORM:
                    payload = proto_utils.parse_Bytes(
                        transform.spec.payload,
                        beam_runner_api_pb2.ParDoPayload)
                    for tag, si in payload.side_inputs.items():
                        data_side_input[transform.unique_name, tag] = (
                            'materialize:' + transform.inputs[tag],
                            beam.pvalue.SideInputData.from_runner_api(
                                si, None))
            return data_input, data_side_input, data_output

        logging.info('Running %s', stage.name)
        logging.debug('       %s', stage)
        data_input, data_side_input, data_output = extract_endpoints(stage)

        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            transforms={
                transform.unique_name: transform
                for transform in stage.transforms
            },
            pcollections=dict(pipeline_components.pcollections.items()),
            coders=dict(pipeline_components.coders.items()),
            windowing_strategies=dict(
                pipeline_components.windowing_strategies.items()),
            environments=dict(pipeline_components.environments.items()))

        process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[process_bundle_descriptor]))

        process_bundle = beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
                process_bundle_descriptor_reference=process_bundle_descriptor.
                id))

        # Write all the input data to the channel.
        for (transform_id, name), pcoll_id in data_input.items():
            data_out = controller.data_plane_handler.output_stream(
                process_bundle.instruction_id,
                beam_fn_api_pb2.Target(
                    primitive_transform_reference=transform_id, name=name))
            for element_data in pcoll_buffers[pcoll_id]:
                data_out.write(element_data)
            data_out.close()

        # Store the required side inputs into state.
        for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
            elements_by_window = _WindowGroupingBuffer(si)
            for element_data in pcoll_buffers[pcoll_id]:
                elements_by_window.append(element_data)
            for window, elements_data in elements_by_window.items():
                state_key = beam_fn_api_pb2.StateKey(
                    multimap_side_input=beam_fn_api_pb2.StateKey.
                    MultimapSideInput(ptransform_id=transform_id,
                                      side_input_id=tag,
                                      window=window))
                controller.state_handler.blocking_append(
                    state_key, elements_data, process_bundle.instruction_id)

        # Register and start running the bundle.
        logging.debug('Register and start running the bundle')
        controller.control_handler.push(process_bundle_registration)
        controller.control_handler.push(process_bundle)

        # Wait for the bundle to finish.
        logging.debug('Wait for the bundle to finish.')
        while True:
            result = controller.control_handler.pull()
            if result and result.instruction_id == process_bundle.instruction_id:
                if result.error:
                    raise RuntimeError(result.error)
                break

        expected_targets = [
            beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                                   name=output_name)
            for (transform_id, output_name), _ in data_output.items()
        ]

        # Gather all output data.
        logging.debug('Gather all output data from %s.', expected_targets)

        for output in controller.data_plane_handler.input_elements(
                process_bundle.instruction_id, expected_targets):
            target_tuple = (output.target.primitive_transform_reference,
                            output.target.name)
            if target_tuple in data_output:
                pcoll_id = data_output[target_tuple]
                if pcoll_id.startswith('materialize:'):
                    # Just store the data chunks for replay.
                    pcoll_buffers[pcoll_id].append(output.data)
                elif pcoll_id.startswith('group:'):
                    # This is a grouping write; create a grouping buffer if needed.
                    if pcoll_id not in pcoll_buffers:
                        original_gbk_transform = pcoll_id.split(':', 1)[1]
                        transform_proto = pipeline_components.transforms[
                            original_gbk_transform]
                        input_pcoll = only_element(
                            transform_proto.inputs.values())
                        output_pcoll = only_element(
                            transform_proto.outputs.values())
                        pre_gbk_coder = context.coders[
                            safe_coders[pipeline_components.
                                        pcollections[input_pcoll].coder_id]]
                        post_gbk_coder = context.coders[
                            safe_coders[pipeline_components.
                                        pcollections[output_pcoll].coder_id]]
                        windowing_strategy = context.windowing_strategies[
                            pipeline_components.pcollections[output_pcoll].
                            windowing_strategy_id]
                        pcoll_buffers[pcoll_id] = _GroupingBuffer(
                            pre_gbk_coder, post_gbk_coder, windowing_strategy)
                    pcoll_buffers[pcoll_id].append(output.data)
                else:
                    # These should be the only two identifiers we produce for now,
                    # but special side input writes may go here.
                    raise NotImplementedError(pcoll_id)
        return result
Example #9
  def _map_task_to_protos(self, map_task, data_operation_spec):
    input_data = {}
    side_input_data = {}
    runner_sinks = {}

    context = pipeline_context.PipelineContext()
    transform_protos = {}
    used_pcollections = {}

    def uniquify(*names):
      # An injective mapping from string* to string.
      return ':'.join("%s:%d" % (name, len(name)) for name in names)

    def pcollection_id(op_ix, out_ix):
      if (op_ix, out_ix) not in used_pcollections:
        used_pcollections[op_ix, out_ix] = uniquify(
            map_task[op_ix][0], 'out', str(out_ix))
      return used_pcollections[op_ix, out_ix]

    def get_inputs(op):
      if hasattr(op, 'inputs'):
        inputs = op.inputs
      elif hasattr(op, 'input'):
        inputs = [op.input]
      else:
        inputs = []
      return {'in%s' % ix: pcollection_id(*input)
              for ix, input in enumerate(inputs)}

    def get_outputs(op_ix):
      op = map_task[op_ix][1]
      return {tag: pcollection_id(op_ix, out_ix)
              for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))}

    for op_ix, (stage_name, operation) in enumerate(map_task):
      transform_id = uniquify(stage_name)

      if isinstance(operation, operation_specs.WorkerInMemoryWrite):
        # Write this data back to the runner.
        target_name = only_element(get_inputs(operation).keys())
        runner_sinks[(transform_id, target_name)] = operation
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.DATA_OUTPUT_URN,
            any_param=proto_utils.pack_Any(data_operation_spec),
            payload=data_operation_spec.SerializeToString() \
                if data_operation_spec is not None else None)

      elif isinstance(operation, operation_specs.WorkerRead):
        # A Read from an in-memory source is done over the data plane.
        if (isinstance(operation.source.source,
                       maptask_executor_runner.InMemorySource)
            and isinstance(operation.source.source.default_output_coder(),
                           WindowedValueCoder)):
          target_name = only_element(get_outputs(op_ix).keys())
          input_data[(transform_id, target_name)] = self._reencode_elements(
              operation.source.source.read(None),
              operation.source.source.default_output_coder())
          transform_spec = beam_runner_api_pb2.FunctionSpec(
              urn=bundle_processor.DATA_INPUT_URN,
              any_param=proto_utils.pack_Any(data_operation_spec),
              payload=data_operation_spec.SerializeToString() \
                  if data_operation_spec is not None else None)

        else:
          # Otherwise serialize the source and execute it there.
          # TODO: Use SDFs with an initial impulse.
          # The Dataflow runner harness strips the base64 encoding. Do the same
          # here until we get the same thing back that we sent in.
          source_bytes = base64.b64decode(
              pickler.dumps(operation.source.source))
          transform_spec = beam_runner_api_pb2.FunctionSpec(
              urn=bundle_processor.PYTHON_SOURCE_URN,
              any_param=proto_utils.pack_Any(
                  wrappers_pb2.BytesValue(
                      value=source_bytes)),
              payload=source_bytes)

      elif isinstance(operation, operation_specs.WorkerDoFn):
        # Record the contents of each side input for access via the state api.
        side_input_extras = []
        for si in operation.side_inputs:
          assert isinstance(si.source, iobase.BoundedSource)
          element_coder = si.source.default_output_coder()
          # TODO(robertwb): Actually flesh out the ViewFn API.
          side_input_extras.append((si.tag, element_coder))
          side_input_data[
              bundle_processor.side_input_tag(transform_id, si.tag)] = (
                  self._reencode_elements(
                      si.source.read(si.source.get_range_tracker(None, None)),
                      element_coder))
        augmented_serialized_fn = pickler.dumps(
            (operation.serialized_fn, side_input_extras))
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.PYTHON_DOFN_URN,
            any_param=proto_utils.pack_Any(
                wrappers_pb2.BytesValue(value=augmented_serialized_fn)),
            payload=augmented_serialized_fn)

      elif isinstance(operation, operation_specs.WorkerFlatten):
        # Flatten is nice and simple.
        transform_spec = beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.IDENTITY_DOFN_URN)

      else:
        raise NotImplementedError(operation)

      transform_protos[transform_id] = beam_runner_api_pb2.PTransform(
          unique_name=stage_name,
          spec=transform_spec,
          inputs=get_inputs(operation),
          outputs=get_outputs(op_ix))

    pcollection_protos = {
        name: beam_runner_api_pb2.PCollection(
            unique_name=name,
            coder_id=context.coders.get_id(
                map_task[op_id][1].output_coders[out_id]))
        for (op_id, out_id), name in used_pcollections.items()
    }
    # Must follow creation of pcollection_protos to capture used coders.
    context_proto = context.to_runner_api()
    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms=transform_protos,
        pcollections=pcollection_protos,
        coders=dict(context_proto.coders.items()),
        windowing_strategies=dict(context_proto.windowing_strategies.items()),
        environments=dict(context_proto.environments.items()))
    return input_data, side_input_data, runner_sinks, process_bundle_descriptor
Example #10
    def _map_task_registration(self, map_task, state_handler,
                               data_operation_spec):
        input_data = {}
        runner_sinks = {}
        transforms = []
        transform_index_to_id = {}

        # Maps coders to new coder objects and references.
        coders = {}

        def coder_id(coder):
            if coder not in coders:
                coders[coder] = beam_fn_api_pb2.Coder(
                    function_spec=sdk_worker.pack_function_spec_data(
                        json.dumps(coder.as_cloud_object()),
                        sdk_worker.PYTHON_CODER_URN,
                        id=self._next_uid()))

            return coders[coder].function_spec.id

        def output_tags(op):
            return getattr(op, 'output_tags', ['out'])

        def as_target(op_input):
            input_op_index, input_output_index = op_input
            input_op = map_task[input_op_index][1]
            return {
                'ignored_input_tag':
                beam_fn_api_pb2.Target.List(target=[
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_index_to_id[
                            input_op_index],
                        name=output_tags(input_op)[input_output_index])
                ])
            }

        def outputs(op):
            return {
                tag:
                beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
                for tag, coder in zip(output_tags(op), op.output_coders)
            }

        for op_ix, (stage_name, operation) in enumerate(map_task):
            transform_id = transform_index_to_id[op_ix] = self._next_uid()
            if isinstance(operation, operation_specs.WorkerInMemoryWrite):
                # Write this data back to the runner.
                fn = beam_fn_api_pb2.FunctionSpec(
                    urn=sdk_worker.DATA_OUTPUT_URN, id=self._next_uid())
                if data_operation_spec:
                    fn.data.Pack(data_operation_spec)
                inputs = as_target(operation.input)
                side_inputs = {}
                runner_sinks[(transform_id, 'out')] = operation

            elif isinstance(operation, operation_specs.WorkerRead):
                # A Read is either translated to a direct injection of windowed values
                # into the sdk worker, or an injection of the source object into the
                # sdk worker as data followed by an SDF that reads that source.
                if (isinstance(operation.source.source,
                               maptask_executor_runner.InMemorySource)
                        and isinstance(
                            operation.source.source.default_output_coder(),
                            WindowedValueCoder)):
                    output_stream = create_OutputStream()
                    element_coder = (operation.source.source.
                                     default_output_coder().get_impl())
                    # Re-encode the elements in the nested context and
                    # concatenate them together
                    for element in operation.source.source.read(None):
                        element_coder.encode_to_stream(element, output_stream,
                                                       True)
                    target_name = self._next_uid()
                    input_data[(transform_id,
                                target_name)] = output_stream.get()
                    fn = beam_fn_api_pb2.FunctionSpec(
                        urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid())
                    if data_operation_spec:
                        fn.data.Pack(data_operation_spec)
                    inputs = {target_name: beam_fn_api_pb2.Target.List()}
                    side_inputs = {}
                else:
                    # Read the source object from the runner.
                    source_coder = beam.coders.DillCoder()
                    input_transform_id = self._next_uid()
                    output_stream = create_OutputStream()
                    source_coder.get_impl().encode_to_stream(
                        GlobalWindows.windowed_value(operation.source),
                        output_stream, True)
                    target_name = self._next_uid()
                    input_data[(input_transform_id,
                                target_name)] = output_stream.get()
                    input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
                        id=input_transform_id,
                        function_spec=beam_fn_api_pb2.FunctionSpec(
                            urn=sdk_worker.DATA_INPUT_URN,
                            id=self._next_uid()),
                        # TODO(robertwb): Possible name collision.
                        step_name=stage_name + '/inject_source',
                        inputs={target_name: beam_fn_api_pb2.Target.List()},
                        outputs={
                            'out':
                            beam_fn_api_pb2.PCollection(
                                coder_reference=coder_id(source_coder))
                        })
                    if data_operation_spec:
                        input_ptransform.function_spec.data.Pack(
                            data_operation_spec)
                    transforms.append(input_ptransform)

                    # Read the elements out of the source.
                    fn = sdk_worker.pack_function_spec_data(
                        OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
                        sdk_worker.PYTHON_DOFN_URN,
                        id=self._next_uid())
                    inputs = {
                        'ignored_input_tag':
                        beam_fn_api_pb2.Target.List(target=[
                            beam_fn_api_pb2.Target(
                                primitive_transform_reference=
                                input_transform_id,
                                name='out')
                        ])
                    }
                    side_inputs = {}

            elif isinstance(operation, operation_specs.WorkerDoFn):
                fn = sdk_worker.pack_function_spec_data(
                    operation.serialized_fn,
                    sdk_worker.PYTHON_DOFN_URN,
                    id=self._next_uid())
                inputs = as_target(operation.input)
                # Store the contents of each side input for state access.
                for si in operation.side_inputs:
                    assert isinstance(si.source, iobase.BoundedSource)
                    element_coder = si.source.default_output_coder()
                    view_id = self._next_uid()
                    # TODO(robertwb): Actually flesh out the ViewFn API.
                    side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
                        view_fn=sdk_worker.serialize_and_pack_py_fn(
                            element_coder,
                            urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
                            id=view_id))
                    # Re-encode the elements in the nested context and
                    # concatenate them together
                    output_stream = create_OutputStream()
                    for element in si.source.read(
                            si.source.get_range_tracker(None, None)):
                        element_coder.get_impl().encode_to_stream(
                            element, output_stream, True)
                    elements_data = output_stream.get()
                    state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
                        key=view_id)
                    state_handler.Clear(state_key)
                    state_handler.Append(state_key, elements_data)

            elif isinstance(operation, operation_specs.WorkerFlatten):
                fn = sdk_worker.pack_function_spec_data(
                    operation.serialized_fn,
                    sdk_worker.IDENTITY_DOFN_URN,
                    id=self._next_uid())
                inputs = {
                    'ignored_input_tag':
                    beam_fn_api_pb2.Target.List(target=[
                        beam_fn_api_pb2.Target(
                            primitive_transform_reference=
                            transform_index_to_id[input_op_index],
                            name=output_tags(map_task[input_op_index]
                                             [1])[input_output_index]) for
                        input_op_index, input_output_index in operation.inputs
                    ])
                }
                side_inputs = {}

            else:
                raise TypeError(operation)

            ptransform = beam_fn_api_pb2.PrimitiveTransform(
                id=transform_id,
                function_spec=fn,
                step_name=stage_name,
                inputs=inputs,
                side_inputs=side_inputs,
                outputs=outputs(operation))
            transforms.append(ptransform)

        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            coders=coders.values(),
            primitive_transform=transforms)
        return beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[process_bundle_descriptor
                                           ])), runner_sinks, input_data
Example #11
  def run_stage(
      self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

    coders = pipeline_context.PipelineContext(pipeline_components).coders
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data_operation_spec.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        pcoll_id = transform.spec.payload
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            data_input[target] = pcoll_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
          else:
            raise NotImplementedError
          if data_operation_spec:
            transform.spec.payload = data_operation_spec.SerializeToString()
            transform.spec.any_param.Pack(data_operation_spec)
          else:
            transform.spec.payload = ""
            transform.spec.any_param.Clear()
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)
    if data_side_input:
      raise NotImplementedError('Side inputs.')

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[process_bundle_descriptor]))

    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=
            process_bundle_descriptor.id))

    # Write all the input data to the channel.
    for (transform_id, name), pcoll_id in data_input.items():
      data_out = controller.data_plane_handler.output_stream(
          process_bundle.instruction_id, beam_fn_api_pb2.Target(
              primitive_transform_reference=transform_id, name=name))
      for element_data in pcoll_buffers[pcoll_id]:
        data_out.write(element_data)
      data_out.close()

    # Register and start running the bundle.
    controller.control_handler.push(process_bundle_registration)
    controller.control_handler.push(process_bundle)

    # Wait for the bundle to finish.
    while True:
      result = controller.control_handler.pull()
      if result and result.instruction_id == process_bundle.instruction_id:
        if result.error:
          raise RuntimeError(result.error)
        break

    # Gather all output data.
    expected_targets = [
        beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
                               name=output_name)
        for (transform_id, output_name), _ in data_output.items()]
    for output in controller.data_plane_handler.input_elements(
        process_bundle.instruction_id, expected_targets):
      target_tuple = (
          output.target.primitive_transform_reference, output.target.name)
      if target_tuple in data_output:
        pcoll_id = data_output[target_tuple]
        if pcoll_id.startswith('materialize:'):
          # Just store the data chunks for replay.
          pcoll_buffers[pcoll_id].append(output.data)
        elif pcoll_id.startswith('group:'):
          # This is a grouping write; create a grouping buffer if needed.
          if pcoll_id not in pcoll_buffers:
            original_gbk_transform = pcoll_id.split(':', 1)[1]
            transform_proto = pipeline_components.transforms[
                original_gbk_transform]
            input_pcoll = only_element(transform_proto.inputs.values())
            output_pcoll = only_element(transform_proto.outputs.values())
            pre_gbk_coder = coders[safe_coders[
                pipeline_components.pcollections[input_pcoll].coder_id]]
            post_gbk_coder = coders[safe_coders[
                pipeline_components.pcollections[output_pcoll].coder_id]]
            pcoll_buffers[pcoll_id] = _GroupingBuffer(
                pre_gbk_coder, post_gbk_coder)
          pcoll_buffers[pcoll_id].append(output.data)
        else:
          # These should be the only two identifiers we produce for now,
          # but special side input writes may go here.
          raise NotImplementedError(pcoll_id)
Example #12
    def make_process_bundle_descriptor(self, data_api_service_descriptor,
                                       state_api_service_descriptor):
        # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor
        """Creates a ProcessBundleDescriptor for invoking the WindowFn's
        merge operation.
        """
        def make_channel_payload(coder_id):
            # type: (str) -> bytes
            data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
            if data_api_service_descriptor:
                data_spec.api_service_descriptor.url = (
                    data_api_service_descriptor.url)
            return data_spec.SerializeToString()

        pipeline_context = self._execution_context_ref().pipeline_context
        global_windowing_strategy_id = self.uid('global_windowing_strategy')
        global_windowing_strategy_proto = core.Windowing(
            window.GlobalWindows()).to_runner_api(pipeline_context)
        coders = dict(pipeline_context.coders.get_id_to_proto_map())

        def make_coder(urn, *components):
            # type: (str, str) -> str
            coder_proto = beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
                component_coder_ids=components)
            coder_id = self.uid('coder')
            coders[coder_id] = coder_proto
            pipeline_context.coders.put_proto(coder_id, coder_proto)
            return coder_id

        bytes_coder_id = make_coder(common_urns.coders.BYTES.urn)
        window_coder_id = self._windowing_strategy_proto.window_coder_id
        global_window_coder_id = make_coder(
            common_urns.coders.GLOBAL_WINDOW.urn)
        iter_window_coder_id = make_coder(common_urns.coders.ITERABLE.urn,
                                          window_coder_id)
        input_coder_id = make_coder(common_urns.coders.KV.urn, bytes_coder_id,
                                    iter_window_coder_id)
        output_coder_id = make_coder(
            common_urns.coders.KV.urn, bytes_coder_id,
            make_coder(
                common_urns.coders.KV.urn, iter_window_coder_id,
                make_coder(
                    common_urns.coders.ITERABLE.urn,
                    make_coder(common_urns.coders.KV.urn, window_coder_id,
                               iter_window_coder_id))))
        windowed_input_coder_id = make_coder(
            common_urns.coders.WINDOWED_VALUE.urn, input_coder_id,
            global_window_coder_id)
        windowed_output_coder_id = make_coder(
            common_urns.coders.WINDOWED_VALUE.urn, output_coder_id,
            global_window_coder_id)

        self.windowed_input_coder_impl = pipeline_context.coders[
            windowed_input_coder_id].get_impl()
        self.windowed_output_coder_impl = pipeline_context.coders[
            windowed_output_coder_id].get_impl()

        self._bundle_processor_id = self.uid('merge_windows')
        return beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._bundle_processor_id,
            transforms={
                self.TO_SDK_TRANSFORM:
                beam_runner_api_pb2.PTransform(
                    unique_name='MergeWindows/Read',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_INPUT_URN,
                        payload=make_channel_payload(windowed_input_coder_id)),
                    outputs={'input': 'input'}),
                'Merge':
                beam_runner_api_pb2.PTransform(
                    unique_name='MergeWindows/Merge',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.primitives.MERGE_WINDOWS.urn,
                        payload=self._windowing_strategy_proto.window_fn.
                        SerializeToString()),
                    inputs={'input': 'input'},
                    outputs={'output': 'output'}),
                self.FROM_SDK_TRANSFORM:
                beam_runner_api_pb2.PTransform(
                    unique_name='MergeWindows/Write',
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=bundle_processor.DATA_OUTPUT_URN,
                        payload=make_channel_payload(
                            windowed_output_coder_id)),
                    inputs={'output': 'output'}),
            },
            pcollections={
                'input':
                beam_runner_api_pb2.PCollection(
                    unique_name='input',
                    windowing_strategy_id=global_windowing_strategy_id,
                    coder_id=input_coder_id),
                'output':
                beam_runner_api_pb2.PCollection(
                    unique_name='output',
                    windowing_strategy_id=global_windowing_strategy_id,
                    coder_id=output_coder_id),
            },
            coders=coders,
            windowing_strategies={
                global_windowing_strategy_id: global_windowing_strategy_proto,
            },
            environments=dict(self._execution_context_ref().
                              pipeline_components.environments.items()),
            state_api_service_descriptor=state_api_service_descriptor,
            timer_api_service_descriptor=data_api_service_descriptor)
Example #13
  def run_stage(
      self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):

    context = pipeline_context.PipelineContext(pipeline_components)
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data_operation_spec.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          pcoll_id = transform.spec.payload
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            data_input[target] = pcoll_buffers[pcoll_id]
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
          else:
            raise NotImplementedError
          if data_operation_spec:
            transform.spec.payload = data_operation_spec.SerializeToString()
          else:
            transform.spec.payload = ""
        elif transform.spec.urn == urns.PARDO_TRANSFORM:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag, si in payload.side_inputs.items():
            data_side_input[transform.unique_name, tag] = (
                'materialize:' + transform.inputs[tag],
                beam.pvalue.SideInputData.from_runner_api(si, None))
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    # Store the required side inputs into state.
    for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
      elements_by_window = _WindowGroupingBuffer(si)
      for element_data in pcoll_buffers[pcoll_id]:
        elements_by_window.append(element_data)
      for window, elements_data in elements_by_window.items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window))
        controller.state_handler.blocking_append(state_key, elements_data, None)

    def get_buffer(pcoll_id):
      if pcoll_id.startswith('materialize:'):
        if pcoll_id not in pcoll_buffers:
          # Just store the data chunks for replay.
          pcoll_buffers[pcoll_id] = list()
      elif pcoll_id.startswith('group:'):
        # This is a grouping write; create a grouping buffer if needed.
        if pcoll_id not in pcoll_buffers:
          original_gbk_transform = pcoll_id.split(':', 1)[1]
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(transform_proto.inputs.values())
          output_pcoll = only_element(transform_proto.outputs.values())
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[pcoll_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(pcoll_id)
      return pcoll_buffers[pcoll_id]

    return BundleManager(
        controller, get_buffer, process_bundle_descriptor,
        self._progress_frequency).process_bundle(data_input, data_output)
Example #14
  def run_stage(
      self,
      worker_handler_factory,
      pipeline_components,
      stage,
      pcoll_buffers,
      safe_coders):

    def iterable_state_write(values, element_coder_impl):
      token = unique_name(None, 'iter').encode('ascii')
      out = create_OutputStream()
      for element in values:
        element_coder_impl.encode_to_stream(element, out, True)
      controller.state.blocking_append(
          beam_fn_api_pb2.StateKey(
              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
          out.get())
      return token

    controller = worker_handler_factory(stage.environment)
    context = pipeline_context.PipelineContext(
        pipeline_components, iterable_state_write=iterable_state_write)
    data_api_service_descriptor = controller.data_api_service_descriptor()

    def extract_endpoints(stage):
      # Returns maps of transform names to PCollection identifiers.
      # Also mutates IO stages to point to the data ApiServiceDescriptor.
      data_input = {}
      data_side_input = {}
      data_output = {}
      for transform in stage.transforms:
        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                  bundle_processor.DATA_OUTPUT_URN):
          pcoll_id = transform.spec.payload
          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
            target = transform.unique_name, only_element(transform.outputs)
            if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
              data_input[target] = [ENCODED_IMPULSE_VALUE]
            else:
              data_input[target] = pcoll_buffers[pcoll_id]
            coder_id = pipeline_components.pcollections[
                only_element(transform.outputs.values())].coder_id
          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
            target = transform.unique_name, only_element(transform.inputs)
            data_output[target] = pcoll_id
            coder_id = pipeline_components.pcollections[
                only_element(transform.inputs.values())].coder_id
          else:
            raise NotImplementedError
          data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
          if data_api_service_descriptor:
            data_spec.api_service_descriptor.url = (
                data_api_service_descriptor.url)
          transform.spec.payload = data_spec.SerializeToString()
        elif transform.spec.urn == common_urns.primitives.PAR_DO.urn:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag, si in payload.side_inputs.items():
            data_side_input[transform.unique_name, tag] = (
                create_buffer_id(transform.inputs[tag]), si.access_pattern)
      return data_input, data_side_input, data_output

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={transform.unique_name: transform
                    for transform in stage.transforms},
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    if controller.state_api_service_descriptor():
      process_bundle_descriptor.state_api_service_descriptor.url = (
          controller.state_api_service_descriptor().url)

    # Store the required side inputs into state.
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = context.coders[safe_coders[
          pipeline_components.pcollections[pcoll_id].coder_id]]
      elements_by_window = _WindowGroupingBuffer(si, value_coder)
      for element_data in pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                ptransform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        controller.state.blocking_append(state_key, elements_data)

    def get_buffer(buffer_id):
      kind, name = split_buffer_id(buffer_id)
      if kind in ('materialize', 'timers'):
        if buffer_id not in pcoll_buffers:
          # Just store the data chunks for replay.
          pcoll_buffers[buffer_id] = list()
      elif kind == 'group':
        # This is a grouping write; create a grouping buffer if needed.
        if buffer_id not in pcoll_buffers:
          original_gbk_transform = name
          transform_proto = pipeline_components.transforms[
              original_gbk_transform]
          input_pcoll = only_element(list(transform_proto.inputs.values()))
          output_pcoll = only_element(list(transform_proto.outputs.values()))
          pre_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[input_pcoll].coder_id]]
          post_gbk_coder = context.coders[safe_coders[
              pipeline_components.pcollections[output_pcoll].coder_id]]
          windowing_strategy = context.windowing_strategies[
              pipeline_components
              .pcollections[output_pcoll].windowing_strategy_id]
          pcoll_buffers[buffer_id] = _GroupingBuffer(
              pre_gbk_coder, post_gbk_coder, windowing_strategy)
      else:
        # These should be the only two identifiers we produce for now,
        # but special side input writes may go here.
        raise NotImplementedError(buffer_id)
      return pcoll_buffers[buffer_id]

    for k in range(self._bundle_repeat):
      try:
        controller.state.checkpoint()
        BundleManager(
            controller, lambda pcoll_id: [], process_bundle_descriptor,
            self._progress_frequency, k).process_bundle(data_input, data_output)
      finally:
        controller.state.restore()

    result = BundleManager(
        controller, get_buffer, process_bundle_descriptor,
        self._progress_frequency).process_bundle(data_input, data_output)

    while True:
      timer_inputs = {}
      for transform_id, timer_writes in stage.timer_pcollections:
        windowed_timer_coder_impl = context.coders[
            pipeline_components.pcollections[timer_writes].coder_id].get_impl()
        written_timers = get_buffer(
            create_buffer_id(timer_writes, kind='timers'))
        if written_timers:
          # Keep only the "last" timer set per key and window.
          timers_by_key_and_window = {}
          for elements_data in written_timers:
            input_stream = create_InputStream(elements_data)
            while input_stream.size() > 0:
              windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                  input_stream, True)
              key, _ = windowed_key_timer.value
              # TODO: Explode and merge windows.
              assert len(windowed_key_timer.windows) == 1
              timers_by_key_and_window[
                  key, windowed_key_timer.windows[0]] = windowed_key_timer
          out = create_OutputStream()
          for windowed_key_timer in timers_by_key_and_window.values():
            windowed_timer_coder_impl.encode_to_stream(
                windowed_key_timer, out, True)
          timer_inputs[transform_id, 'out'] = [out.get()]
          written_timers[:] = []
      if timer_inputs:
        # The worker will be waiting on these inputs as well.
        for other_input in data_input:
          if other_input not in timer_inputs:
            timer_inputs[other_input] = []
        # TODO(robertwb): merge results
        BundleManager(
            controller,
            get_buffer,
            process_bundle_descriptor,
            self._progress_frequency,
            True).process_bundle(timer_inputs, data_output)
      else:
        break

    return result
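
A final minimal sketch, grounded in the field names used in Examples #8 and #11 (newer SDK versions may name these fields differently): a descriptor is first registered, then executed by a ProcessBundleRequest that references it by id.

# Minimal sketch (assumed import path): pair a RegisterRequest with the
# ProcessBundleRequest that runs the registered descriptor.
from apache_beam.portability.api import beam_fn_api_pb2

descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(id="bundle-1")
register = beam_fn_api_pb2.InstructionRequest(
    instruction_id="register-1",
    register=beam_fn_api_pb2.RegisterRequest(
        process_bundle_descriptor=[descriptor]))
process = beam_fn_api_pb2.InstructionRequest(
    instruction_id="process-1",
    process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
        process_bundle_descriptor_reference=descriptor.id))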