Exemple #1
0
    def test_fn_registration(self):
        """Check that a register request populates the worker's fn table.

        Four function specs are each wrapped in their own process bundle
        descriptor, and the servicer is constructed with a single register
        request carrying all descriptors. After the harness has run,
        ``harness.worker.fns`` is expected to map the id of every function
        spec and every descriptor to the corresponding proto.
        """
        fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]

        # One descriptor per function spec, with ids offset by 100 so they do
        # not collide with the function spec ids in the expected mapping.
        process_bundle_descriptors = [
            beam_fn_api_pb2.ProcessBundleDescriptor(
                id=str(100 + ix),
                primitive_transform=[
                    beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)
                ]) for ix, fn in enumerate(fns)
        ]

        test_controller = BeamFnControlServicer([
            beam_fn_api_pb2.InstructionRequest(
                register=beam_fn_api_pb2.RegisterRequest(
                    process_bundle_descriptor=process_bundle_descriptors))
        ])

        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
            test_controller, server)
        # Port 0 lets the OS pick a free port; the chosen port is returned.
        test_port = server.add_insecure_port("[::]:0")
        server.start()

        try:
            channel = grpc.insecure_channel("localhost:%s" % test_port)
            harness = sdk_worker.SdkHarness(channel)
            harness.run()
            self.assertEqual(
                harness.worker.fns,
                {item.id: item
                 for item in fns + process_bundle_descriptors})
        finally:
            # Always shut the server down, including when the harness or the
            # assertion fails, so it does not outlive this test.
            server.stop(0)
Exemple #2
0
def pack_function_spec_data(value, urn, id=None):
    """Builds a function spec proto whose data field wraps the given value.

    Args:
      value: payload placed in a ``BytesValue`` wrapper and packed into the
        spec's ``data`` field.
      urn: URN assigned to the spec.
      id: optional spec id; it is only assigned when truthy.

    Returns:
      The populated ``beam_fn_api_pb2.FunctionSpec``.
    """
    payload = wrappers_pb2.BytesValue(value=value)
    spec = beam_fn_api_pb2.FunctionSpec(urn=urn)
    if id:
        spec.id = id
    spec.data.Pack(payload)
    return spec
    def _map_task_registration(self, map_task, state_handler,
                               data_operation_spec):
        """Translates a map task into a Fn API register request.

        Each operation in the map task becomes a ``PrimitiveTransform``
        (a source-based read becomes two: one that injects the encoded source
        and one that reads from it). Coders referenced by transform outputs
        are collected once each and attached to the resulting descriptor.

        Args:
          map_task: indexable sequence of ``(stage_name, operation)`` pairs.
            Operation inputs refer to other operations by their index in
            this sequence.
          state_handler: object exposing ``Clear(state_key)`` and
            ``Append(state_key, data)``; it receives the encoded contents of
            every DoFn side input.
          data_operation_spec: proto message packed into the ``data`` field
            of each data input/output function spec; skipped when falsy.

        Returns:
          A 3-tuple ``(request, runner_sinks, input_data)`` where
          ``request`` is an ``InstructionRequest`` registering a single
          ``ProcessBundleDescriptor``, ``runner_sinks`` maps
          ``(transform_id, 'out')`` to the in-memory write operation that
          should receive that output, and ``input_data`` maps
          ``(transform_id, target_name)`` to the encoded elements to feed
          into that target.

        Raises:
          TypeError: if an operation is not one of the handled
            ``operation_specs`` types.
        """
        input_data = {}
        runner_sinks = {}
        transforms = []
        transform_index_to_id = {}

        # Maps coders to new coder objects and references.
        coders = {}

        def coder_id(coder):
            # Register each distinct coder once and hand back its spec id.
            if coder not in coders:
                coders[coder] = beam_fn_api_pb2.Coder(
                    function_spec=sdk_worker.pack_function_spec_data(
                        json.dumps(coder.as_cloud_object()),
                        sdk_worker.PYTHON_CODER_URN,
                        id=self._next_uid()))

            return coders[coder].function_spec.id

        def output_tags(op):
            # Operations without explicit tags have a single 'out' output.
            return getattr(op, 'output_tags', ['out'])

        def as_target(op_input):
            # op_input is an (operation index, output index) pair pointing at
            # an earlier operation in map_task.
            input_op_index, input_output_index = op_input
            input_op = map_task[input_op_index][1]
            return {
                'ignored_input_tag':
                beam_fn_api_pb2.Target.List(target=[
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_index_to_id[
                            input_op_index],
                        name=output_tags(input_op)[input_output_index])
                ])
            }

        def outputs(op):
            # One PCollection per output tag, paired with its coder by
            # position.
            return {
                tag:
                beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
                for tag, coder in zip(output_tags(op), op.output_coders)
            }

        for op_ix, (stage_name, operation) in enumerate(map_task):
            transform_id = transform_index_to_id[op_ix] = self._next_uid()
            if isinstance(operation, operation_specs.WorkerInMemoryWrite):
                # Write this data back to the runner.
                fn = beam_fn_api_pb2.FunctionSpec(
                    urn=sdk_worker.DATA_OUTPUT_URN, id=self._next_uid())
                if data_operation_spec:
                    fn.data.Pack(data_operation_spec)
                inputs = as_target(operation.input)
                side_inputs = {}
                runner_sinks[(transform_id, 'out')] = operation

            elif isinstance(operation, operation_specs.WorkerRead):
                # A Read is either translated to a direct injection of windowed values
                # into the sdk worker, or an injection of the source object into the
                # sdk worker as data followed by an SDF that reads that source.
                if (isinstance(operation.source.source,
                               maptask_executor_runner.InMemorySource)
                        and isinstance(
                            operation.source.source.default_output_coder(),
                            WindowedValueCoder)):
                    output_stream = create_OutputStream()
                    element_coder = (operation.source.source.
                                     default_output_coder().get_impl())
                    # Re-encode the elements in the nested context and
                    # concatenate them together
                    for element in operation.source.source.read(None):
                        element_coder.encode_to_stream(element, output_stream,
                                                       True)
                    target_name = self._next_uid()
                    input_data[(transform_id,
                                target_name)] = output_stream.get()
                    fn = beam_fn_api_pb2.FunctionSpec(
                        urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid())
                    if data_operation_spec:
                        fn.data.Pack(data_operation_spec)
                    inputs = {target_name: beam_fn_api_pb2.Target.List()}
                    side_inputs = {}
                else:
                    # Read the source object from the runner.
                    source_coder = beam.coders.DillCoder()
                    input_transform_id = self._next_uid()
                    output_stream = create_OutputStream()
                    source_coder.get_impl().encode_to_stream(
                        GlobalWindows.windowed_value(operation.source),
                        output_stream, True)
                    target_name = self._next_uid()
                    input_data[(input_transform_id,
                                target_name)] = output_stream.get()
                    input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
                        id=input_transform_id,
                        function_spec=beam_fn_api_pb2.FunctionSpec(
                            urn=sdk_worker.DATA_INPUT_URN,
                            id=self._next_uid()),
                        # TODO(robertwb): Possible name collision.
                        step_name=stage_name + '/inject_source',
                        inputs={target_name: beam_fn_api_pb2.Target.List()},
                        outputs={
                            'out':
                            beam_fn_api_pb2.PCollection(
                                coder_reference=coder_id(source_coder))
                        })
                    if data_operation_spec:
                        input_ptransform.function_spec.data.Pack(
                            data_operation_spec)
                    transforms.append(input_ptransform)

                    # Read the elements out of the source.
                    fn = sdk_worker.pack_function_spec_data(
                        OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
                        sdk_worker.PYTHON_DOFN_URN,
                        id=self._next_uid())
                    inputs = {
                        'ignored_input_tag':
                        beam_fn_api_pb2.Target.List(target=[
                            beam_fn_api_pb2.Target(
                                primitive_transform_reference=
                                input_transform_id,
                                name='out')
                        ])
                    }
                    side_inputs = {}

            elif isinstance(operation, operation_specs.WorkerDoFn):
                fn = sdk_worker.pack_function_spec_data(
                    operation.serialized_fn,
                    sdk_worker.PYTHON_DOFN_URN,
                    id=self._next_uid())
                inputs = as_target(operation.input)
                # Start from a fresh mapping: without this the dict left over
                # from the previous operation would be reused, leaking an
                # earlier DoFn's side inputs into this transform (or failing
                # outright when a DoFn is the first operation).
                side_inputs = {}
                # Store the contents of each side input for state access.
                for si in operation.side_inputs:
                    assert isinstance(si.source, iobase.BoundedSource)
                    element_coder = si.source.default_output_coder()
                    view_id = self._next_uid()
                    # TODO(robertwb): Actually flesh out the ViewFn API.
                    side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
                        view_fn=sdk_worker.serialize_and_pack_py_fn(
                            element_coder,
                            urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
                            id=view_id))
                    # Re-encode the elements in the nested context and
                    # concatenate them together
                    output_stream = create_OutputStream()
                    for element in si.source.read(
                            si.source.get_range_tracker(None, None)):
                        element_coder.get_impl().encode_to_stream(
                            element, output_stream, True)
                    elements_data = output_stream.get()
                    state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
                        key=view_id)
                    # Replace whatever is stored under this key with the
                    # freshly encoded elements.
                    state_handler.Clear(state_key)
                    state_handler.Append(state_key, elements_data)

            elif isinstance(operation, operation_specs.WorkerFlatten):
                fn = sdk_worker.pack_function_spec_data(
                    operation.serialized_fn,
                    sdk_worker.IDENTITY_DOFN_URN,
                    id=self._next_uid())
                # A flatten has several upstream outputs, all listed under
                # one input tag.
                inputs = {
                    'ignored_input_tag':
                    beam_fn_api_pb2.Target.List(target=[
                        beam_fn_api_pb2.Target(
                            primitive_transform_reference=
                            transform_index_to_id[input_op_index],
                            name=output_tags(map_task[input_op_index]
                                             [1])[input_output_index]) for
                        input_op_index, input_output_index in operation.inputs
                    ])
                }
                side_inputs = {}

            else:
                raise TypeError(operation)

            ptransform = beam_fn_api_pb2.PrimitiveTransform(
                id=transform_id,
                function_spec=fn,
                step_name=stage_name,
                inputs=inputs,
                side_inputs=side_inputs,
                outputs=outputs(operation))
            transforms.append(ptransform)

        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            coders=coders.values(),
            primitive_transform=transforms)
        return beam_fn_api_pb2.InstructionRequest(
            instruction_id=self._next_uid(),
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[process_bundle_descriptor
                                           ])), runner_sinks, input_data