Example #1
  def GetArtifact(self, request, context=None):
    if request.artifact.type_urn == common_urns.artifact_types.FILE.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.ArtifactFilePayload)
      read_handle = self._file_reader(payload.path)
    elif request.artifact.type_urn == common_urns.artifact_types.URL.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload, beam_runner_api_pb2.ArtifactUrlPayload)
      # TODO(Py3): Remove the unneeded contextlib wrapper.
      read_handle = contextlib.closing(urlopen(payload.url))
    elif request.artifact.type_urn == common_urns.artifact_types.EMBEDDED.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.EmbeddedFilePayload)
      read_handle = BytesIO(payload.data)
    else:
      raise NotImplementedError(request.artifact.type_urn)

    with read_handle as fin:
      while True:
        chunk = fin.read(self._chunk_size)
        if not chunk:
          break
        yield beam_artifact_api_pb2.GetArtifactResponse(data=chunk)
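Every example on this page follows the same core pattern: a proto message is serialized into an opaque bytes field and later recovered with proto_utils.parse_Bytes. Below is a minimal, self-contained sketch of that round trip (message and field names taken from beam_runner_api.proto; the literal data is illustrative):

from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.utils import proto_utils

# Serialize a payload proto into raw bytes, as an SDK or runner does when
# building a pipeline proto...
embedded = beam_runner_api_pb2.EmbeddedFilePayload(data=b'hello artifact')
serialized = embedded.SerializeToString()

# ...and recover the typed message from those bytes with parse_Bytes.
parsed = proto_utils.parse_Bytes(
    serialized, beam_runner_api_pb2.EmbeddedFilePayload)
assert parsed.data == b'hello artifact'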
Example #2
 def _extract_environment(transform):
   if transform.spec.urn in PAR_DO_URNS:
     pardo_payload = proto_utils.parse_Bytes(
         transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
     return pardo_payload.do_fn.environment_id
   elif transform.spec.urn in COMBINE_URNS:
     combine_payload = proto_utils.parse_Bytes(
         transform.spec.payload, beam_runner_api_pb2.CombinePayload)
     return combine_payload.combine_fn.environment_id
   else:
     return None
Example #3
 def _extract_environment(transform):
     if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
         pardo_payload = proto_utils.parse_Bytes(
             transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
         return pardo_payload.do_fn.environment_id
     elif transform.spec.urn in (
             common_urns.composites.COMBINE_PER_KEY.urn,
             common_urns.combine_components.COMBINE_PGBKCV.urn,
             common_urns.combine_components.COMBINE_MERGE_ACCUMULATORS.urn,
             common_urns.combine_components.COMBINE_EXTRACT_OUTPUTS.urn):
         combine_payload = proto_utils.parse_Bytes(
             transform.spec.payload, beam_runner_api_pb2.CombinePayload)
         return combine_payload.combine_fn.environment_id
     else:
         return None
Example #4
  def test_sdk_harness_container_image_overrides(self):
    test_environment = DockerEnvironment(
        container_image='dummy_container_image')
    proto_pipeline, _ = Pipeline().to_runner_api(
      return_context=True, default_environment=test_environment)

    pipeline_options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp'
    ])

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {'.*dummy.*': 'new_dummy_container_image'},
        pipeline_options)

    self.assertEqual(1, len(proto_pipeline.components.environments))
    env = list(proto_pipeline.components.environments.values())[0]

    from apache_beam.utils import proto_utils
    docker_payload = proto_utils.parse_Bytes(
        env.payload, beam_runner_api_pb2.DockerPayload)

    # Container image should be overridden by the given override.
    self.assertEqual(
        docker_payload.container_image, 'new_dummy_container_image')
Example #5
  def test_pipeline_sdk_not_overridden(self):
    pipeline_options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

    dummy_env = DockerEnvironment(
        container_image='dummy_prefix/dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    self.assertEqual(2, len(proto_pipeline.components.environments))

    from apache_beam.utils import proto_utils
    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertFalse(found_override)
Example #6
    def _build_data_side_inputs_map(stages):
        # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]
        """Builds an index mapping stages to side input descriptors.

    A side input descriptor is a map of side input IDs to side input access
    patterns for all of the outputs of a stage that will be consumed as a
    side input.
    """
        transform_consumers = collections.defaultdict(
            list
        )  # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
        stage_consumers = collections.defaultdict(
            list)  # type: DefaultDict[str, List[translations.Stage]]

        def get_all_side_inputs():
            # type: () -> Set[str]
            all_side_inputs = set()  # type: Set[str]
            for stage in stages:
                for transform in stage.transforms:
                    for input in transform.inputs.values():
                        transform_consumers[input].append(transform)
                        stage_consumers[input].append(stage)
                for si in stage.side_inputs():
                    all_side_inputs.add(si)
            return all_side_inputs

        all_side_inputs = frozenset(get_all_side_inputs())
        data_side_inputs_by_producing_stage = {
        }  # type: Dict[str, DataSideInput]

        producing_stages_by_pcoll = {}

        for s in stages:
            data_side_inputs_by_producing_stage[s.name] = {}
            for transform in s.transforms:
                for o in transform.outputs.values():
                    if o in s.side_inputs():
                        continue
                    if o in producing_stages_by_pcoll:
                        continue
                    producing_stages_by_pcoll[o] = s

        for side_pc in all_side_inputs:
            for consuming_transform in transform_consumers[side_pc]:
                if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
                    continue
                producing_stage = producing_stages_by_pcoll[side_pc]
                payload = proto_utils.parse_Bytes(
                    consuming_transform.spec.payload,
                    beam_runner_api_pb2.ParDoPayload)
                for si_tag in payload.side_inputs:
                    if consuming_transform.inputs[si_tag] == side_pc:
                        side_input_id = (consuming_transform.unique_name,
                                         si_tag)
                        data_side_inputs_by_producing_stage[
                            producing_stage.name][side_input_id] = (
                                translations.create_buffer_id(side_pc),
                                payload.side_inputs[si_tag].access_pattern)

        return data_side_inputs_by_producing_stage
Example #7
 def extract_endpoints(stage):
   # Returns maps of transform names to PCollection identifiers.
   # Also mutates IO stages to point to the data_operation_spec.
   data_input = {}
   data_side_input = {}
   data_output = {}
   for transform in stage.transforms:
     if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                               bundle_processor.DATA_OUTPUT_URN):
       pcoll_id = transform.spec.payload
       if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
         target = transform.unique_name, only_element(transform.outputs)
         data_input[target] = pcoll_buffers[pcoll_id]
       elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
         target = transform.unique_name, only_element(transform.inputs)
         data_output[target] = pcoll_id
       else:
         raise NotImplementedError
       if data_operation_spec:
         transform.spec.payload = data_operation_spec.SerializeToString()
       else:
         transform.spec.payload = b""
     elif transform.spec.urn == urns.PARDO_TRANSFORM:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for tag, si in payload.side_inputs.items():
         data_side_input[transform.unique_name, tag] = (
             'materialize:' + transform.inputs[tag],
             beam.pvalue.SideInputData.from_runner_api(si, None))
   return data_input, data_side_input, data_output
Example #8
 def side_inputs(self):
   for transform in self.transforms:
     if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for side_input in payload.side_inputs:
         yield transform.inputs[side_input]
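The side_inputs() helpers above only read transform protos. For context, here is a hedged sketch of the producing side they assume: a ParDoPayload serialized into FunctionSpec.payload under the PAR_DO urn, then recovered with parse_Bytes (the transform name and the empty payload are illustrative):

from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.utils import proto_utils

# A ParDo transform carries its ParDoPayload serialized inside spec.payload.
pardo_payload = beam_runner_api_pb2.ParDoPayload()
transform = beam_runner_api_pb2.PTransform(
    unique_name='MyParDo',
    spec=beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.primitives.PAR_DO.urn,
        payload=pardo_payload.SerializeToString()))

# Inspection helpers like side_inputs() recover the payload from those bytes.
recovered = proto_utils.parse_Bytes(
    transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
assert not recovered.side_inputs  # this sketch declares no side inputs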
Example #9
    def test_runner_api_transformation_with_subscription(
            self, unused_mock_pubsub):
        source = _PubSubSource(
            topic=None,
            subscription='projects/fakeprj/subscriptions/a_subscription',
            id_label='a_label',
            timestamp_attribute='b_label',
            with_attributes=True)
        transform = Read(source)

        context = pipeline_context.PipelineContext()
        proto_transform_spec = transform.to_runner_api(context)
        self.assertEqual(common_urns.composites.PUBSUB_READ.urn,
                         proto_transform_spec.urn)

        pubsub_read_payload = (proto_utils.parse_Bytes(
            proto_transform_spec.payload,
            beam_runner_api_pb2.PubSubReadPayload))
        self.assertEqual('projects/fakeprj/subscriptions/a_subscription',
                         pubsub_read_payload.subscription)
        self.assertEqual('a_label', pubsub_read_payload.id_attribute)
        self.assertEqual('b_label', pubsub_read_payload.timestamp_attribute)
        self.assertEqual('', pubsub_read_payload.topic)
        self.assertTrue(pubsub_read_payload.with_attributes)

        proto_transform = beam_runner_api_pb2.PTransform(
            unique_name="dummy_label", spec=proto_transform_spec)

        transform_from_proto = Read.from_runner_api_parameter(
            proto_transform, pubsub_read_payload, None)
        self.assertTrue(isinstance(transform_from_proto, Read))
        self.assertTrue(isinstance(transform_from_proto.source, _PubSubSource))
        self.assertTrue(transform_from_proto.source.with_attributes)
        self.assertEqual('projects/fakeprj/subscriptions/a_subscription',
                         transform_from_proto.source.full_subscription)
Example #10
 def extract_endpoints(stage):
     # Returns maps of transform names to PCollection identifiers.
     # Also mutates IO stages to point to the data_operation_spec.
     data_input = {}
     data_side_input = {}
     data_output = {}
     for transform in stage.transforms:
         if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                   bundle_processor.DATA_OUTPUT_URN):
             pcoll_id = transform.spec.payload
             if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                 target = transform.unique_name, only_element(
                     transform.outputs)
                 data_input[target] = pcoll_id
             elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                 target = transform.unique_name, only_element(
                     transform.inputs)
                 data_output[target] = pcoll_id
             else:
                 raise NotImplementedError
             if data_operation_spec:
                 transform.spec.payload = data_operation_spec.SerializeToString(
                 )
             else:
                 transform.spec.payload = b""
         elif transform.spec.urn == urns.PARDO_TRANSFORM:
             payload = proto_utils.parse_Bytes(
                 transform.spec.payload,
                 beam_runner_api_pb2.ParDoPayload)
             for tag, si in payload.side_inputs.items():
                 data_side_input[transform.unique_name, tag] = (
                     'materialize:' + transform.inputs[tag],
                     beam.pvalue.SideInputData.from_runner_api(
                         si, None))
     return data_input, data_side_input, data_output
Example #11
    def test_runner_api_transformation_properties_none(self,
                                                       unused_mock_pubsub):
        # Confirming that properties stay None after a runner API transformation.
        sink = _PubSubSink(
            topic='projects/fakeprj/topics/a_topic',
            id_label=None,
            with_attributes=True,
            # We expect encoded PubSub write transform to always return attributes.
            timestamp_attribute=None)
        transform = Write(sink)

        context = pipeline_context.PipelineContext()
        proto_transform_spec = transform.to_runner_api(context)
        self.assertEqual(common_urns.composites.PUBSUB_WRITE.urn,
                         proto_transform_spec.urn)

        pubsub_write_payload = (proto_utils.parse_Bytes(
            proto_transform_spec.payload,
            beam_runner_api_pb2.PubSubWritePayload))
        proto_transform = beam_runner_api_pb2.PTransform(
            unique_name="dummy_label", spec=proto_transform_spec)
        transform_from_proto = Write.from_runner_api_parameter(
            proto_transform, pubsub_write_payload, None)

        self.assertTrue(isinstance(transform_from_proto, Write))
        self.assertTrue(isinstance(transform_from_proto.sink, _PubSubSink))
        self.assertTrue(transform_from_proto.sink.with_attributes)
        self.assertIsNone(transform_from_proto.sink.id_label)
        self.assertIsNone(transform_from_proto.sink.timestamp_attribute)
Example #12
 def from_runner_api(cls, proto, context):
   if proto is None or not proto.urn:
     return None
   parameter_type, constructor = cls._known_urns[proto.urn]
   return constructor(
       proto_utils.parse_Bytes(proto.payload, parameter_type),
       context)
Example #13
 def side_inputs(self):
   for transform in self.transforms:
     if transform.spec.urn == urns.PARDO_TRANSFORM:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for side_input in payload.side_inputs:
         yield transform.inputs[side_input]
Example #14
 def side_inputs(self):
   for transform in self.transforms:
     if transform.spec.urn == urns.PARDO_TRANSFORM:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for side_input in payload.side_inputs:
         yield transform.inputs[side_input]
Example #15
 def side_inputs(self):
     for transform in self.transforms:
         if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
             payload = proto_utils.parse_Bytes(
                 transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
             for side_input in payload.side_inputs:
                 yield transform.inputs[side_input]
Example #16
  def extract_bundle_inputs_and_outputs(self):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Returns:
      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping (transform_name, output_name) to a
        PCollection buffer; `data_output` is a dictionary mapping
        (transform_name, output_name) to a PCollection ID.
        `expected_timer_output` is a dictionary mapping transform_id and
        timer family ID to a buffer id for timers.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id) : buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in self.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = self.execution_context.pipeline_context.coders[
              self.execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in self.execution_context.pcoll_buffers:
              self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                self.execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = self.data_api_service_descriptor()
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_output, expected_timer_output
Example #17
 def create_operation(self, transform_id, consumers):
   transform_proto = self.descriptor.transforms[transform_id]
   if not transform_proto.unique_name:
     logging.debug("No unique name set for transform %s" % transform_id)
     transform_proto.unique_name = transform_id
   creator, parameter_type = self._known_urns[transform_proto.spec.urn]
   payload = proto_utils.parse_Bytes(
       transform_proto.spec.payload, parameter_type)
   return creator(self, transform_id, transform_proto, payload, consumers)
Example #18
    def from_runner_api(cls, fn_proto, context):
        """Converts from an SdkFunctionSpec to a Fn object.

    Prefer registering a urn with its parameter type and constructor.
    """
        parameter_type, constructor = cls._known_urns[fn_proto.spec.urn]
        return constructor(
            proto_utils.parse_Bytes(fn_proto.spec.payload, parameter_type),
            context)
Example #19
File: urns.py Project: xubii/beam
    def from_runner_api(cls, fn_proto, context):
        # type: (beam_runner_api_pb2.FunctionSpec, PipelineContext) -> Any
        """Converts from an FunctionSpec to a Fn object.

    Prefer registering a urn with its parameter type and constructor.
    """
        parameter_type, constructor = cls._known_urns[fn_proto.urn]
        return constructor(
            proto_utils.parse_Bytes(fn_proto.payload, parameter_type), context)
Example #20
  def from_runner_api(cls, fn_proto, context):
    """Converts from an SdkFunctionSpec to a Fn object.

    Prefer registering a urn with its parameter type and constructor.
    """
    parameter_type, constructor = cls._known_urns[fn_proto.spec.urn]
    return constructor(
        proto_utils.parse_Bytes(fn_proto.spec.payload, parameter_type),
        context)
Example #21
 def create_operation(self, transform_id, consumers):
   transform_proto = self.descriptor.transforms[transform_id]
   if not transform_proto.unique_name:
     logging.warning("No unique name set for transform %s" % transform_id)
     transform_proto.unique_name = transform_id
   creator, parameter_type = self._known_urns[transform_proto.spec.urn]
   payload = proto_utils.parse_Bytes(
       transform_proto.spec.payload, parameter_type)
   return creator(self, transform_id, transform_proto, payload, consumers)
Example #22
def annotate_stateful_dofns_as_roots(stages, pipeline_context):
  for stage in stages:
    for transform in stage.transforms:
      if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
        pardo_payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        if pardo_payload.state_specs or pardo_payload.timer_specs:
          stage.forced_root = True
    yield stage
Example #23
  def from_runner_api(cls, coder_proto, context):
    """Converts from an SdkFunctionSpec to a Fn object.

    Prefer registering a urn with its parameter type and constructor.
    """
    parameter_type, constructor = cls._known_urns[coder_proto.spec.spec.urn]
    return constructor(
        proto_utils.parse_Bytes(coder_proto.spec.spec.payload, parameter_type),
        [context.coders.get_by_id(c) for c in coder_proto.component_coder_ids],
        context)
Example #24
 def GetArtifact(self, request):
     if request.artifact.type_urn == common_urns.artifact_types.EMBEDDED.urn:
         content = proto_utils.parse_Bytes(
             request.artifact.type_payload,
             beam_runner_api_pb2.EmbeddedFilePayload).data
         for k in range(0, len(content), 13):
             yield beam_artifact_api_pb2.GetArtifactResponse(
                 data=content[k:k + 13])
     else:
         raise NotImplementedError
Example #25
  def from_runner_api(cls, coder_proto, context):
    """Converts from an SdkFunctionSpec to a Fn object.

    Prefer registering a urn with its parameter type and constructor.
    """
    parameter_type, constructor = cls._known_urns[coder_proto.spec.spec.urn]
    return constructor(
        proto_utils.parse_Bytes(coder_proto.spec.spec.payload, parameter_type),
        [context.coders.get_by_id(c) for c in coder_proto.component_coder_ids],
        context)
Example #26
 def has_as_main_input(self, pcoll):
   for transform in self.transforms:
     if transform.spec.urn == urns.PARDO_TRANSFORM:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       local_side_inputs = payload.side_inputs
     else:
       local_side_inputs = {}
     for local_id, pipeline_id in transform.inputs.items():
       if pcoll == pipeline_id and local_id not in local_side_inputs:
         return True
Example #27
 def _build_timer_coders_id_map(self):
   timer_coder_ids = {}
   for transform_id, transform_proto in (self._process_bundle_descriptor
       .transforms.items()):
     if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
       pardo_payload = proto_utils.parse_Bytes(
           transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for id, timer_family_spec in pardo_payload.timer_family_specs.items():
         timer_coder_ids[(transform_id, id)] = (
             timer_family_spec.timer_family_coder_id)
   return timer_coder_ids
Example #28
    def test_build_payload_with_builder_methods(self):
        payload_builder = JavaClassLookupPayloadBuilder('dummy_class_name')
        payload_builder.with_constructor('abc',
                                         123,
                                         str_field='def',
                                         int_field=456)
        payload_builder.add_builder_method('builder_method1',
                                           'abc1',
                                           1234,
                                           str_field1='abc2',
                                           int_field1=2345)
        payload_builder.add_builder_method('builder_method2',
                                           'abc3',
                                           3456,
                                           str_field2='abc4',
                                           int_field2=4567)
        payload_bytes = payload_builder.payload()
        payload_from_bytes = proto_utils.parse_Bytes(
            payload_bytes, external_transforms_pb2.JavaClassLookupPayload)
        self.assertTrue(
            isinstance(payload_from_bytes,
                       external_transforms_pb2.JavaClassLookupPayload))
        self._verify_row(payload_from_bytes.constructor_schema,
                         payload_from_bytes.constructor_payload, {
                             'ignore0': 'abc',
                             'ignore1': 123,
                             'str_field': 'def',
                             'int_field': 456
                         })
        self.assertEqual(2, len(payload_from_bytes.builder_methods))
        builder_method = payload_from_bytes.builder_methods[0]
        self.assertTrue(
            isinstance(builder_method, external_transforms_pb2.BuilderMethod))
        self.assertEqual('builder_method1', builder_method.name)

        self._verify_row(
            builder_method.schema, builder_method.payload, {
                'ignore0': 'abc1',
                'ignore1': 1234,
                'str_field1': 'abc2',
                'int_field1': 2345
            })

        builder_method = payload_from_bytes.builder_methods[1]
        self.assertTrue(
            isinstance(builder_method, external_transforms_pb2.BuilderMethod))
        self.assertEqual('builder_method2', builder_method.name)
        self._verify_row(
            builder_method.schema, builder_method.payload, {
                'ignore0': 'abc3',
                'ignore1': 3456,
                'str_field2': 'abc4',
                'int_field2': 4567
            })
Example #29
 def has_as_main_input(self, pcoll):
   for transform in self.transforms:
     if transform.spec.urn == urns.PARDO_TRANSFORM:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       local_side_inputs = payload.side_inputs
     else:
       local_side_inputs = {}
     for local_id, pipeline_id in transform.inputs.items():
       if pcoll == pipeline_id and local_id not in local_side_inputs:
         return True
Example #30
 def create(
         cls,
         environment,  # type: beam_runner_api_pb2.Environment
         state,  # type: sdk_worker.StateHandler
         provision_info,  # type: ExtendedProvisionInfo
         grpc_server  # type: GrpcServer
 ):
     # type: (...) -> WorkerHandler
     constructor, payload_type = cls._registered_environments[
         environment.urn]
     return constructor(
         proto_utils.parse_Bytes(environment.payload, payload_type), state,
         provision_info, grpc_server)
Example #31
    def from_runner_api(
        cls,
        proto,  # type: Optional[beam_runner_api_pb2.PTransform]
        context  # type: PipelineContext
    ):
        # type: (...) -> Optional[PTransform]
        if proto is None or proto.spec is None or not proto.spec.urn:
            return None
        parameter_type, constructor = cls._known_urns[proto.spec.urn]

        return constructor(
            proto, proto_utils.parse_Bytes(proto.spec.payload, parameter_type),
            context)
Example #32
    def from_runner_api(
        cls,
        proto,  # type: Optional[beam_runner_api_pb2.Environment]
        context  # type: PipelineContext
    ):
        # type: (...) -> Optional[Environment]
        if proto is None or not proto.urn:
            return None
        parameter_type, constructor = cls._known_urns[proto.urn]

        return constructor(
            proto_utils.parse_Bytes(proto.payload, parameter_type),
            proto.capabilities, proto.dependencies, context)
Example #33
    def from_runner_api(cls, proto, context):
        if proto is None or not proto.urn:
            return None
        parameter_type, constructor = cls._known_urns[proto.urn]

        try:
            return constructor(
                proto_utils.parse_Bytes(proto.payload, parameter_type),
                context)
        except Exception:
            if context.allow_proto_holders:
                return RunnerAPIEnvironmentHolder(proto)
            raise
Example #34
def create_data_stream_keyed_process_function(factory, transform_id, transform_proto, parameter,
                                              consumers):
    urn = parameter.do_fn.urn
    payload = proto_utils.parse_Bytes(
        parameter.do_fn.payload, flink_fn_execution_pb2.UserDefinedDataStreamFunction)
    if urn == datastream_operations.DATA_STREAM_STATELESS_FUNCTION_URN:
        return _create_user_defined_function_operation(
            factory, transform_proto, consumers, payload,
            beam_operations.StatelessFunctionOperation,
            datastream_operations.StatelessOperation)
    else:
        return _create_user_defined_function_operation(
            factory, transform_proto, consumers, payload,
            beam_operations.StatefulFunctionOperation,
            datastream_operations.StatefulOperation)
Example #35
    def test_implicit_builder_with_constructor(self):
        constructor_transform = (JavaExternalTransform('org.pkg.MyTransform')(
            'abc').withIntProperty(5))

        payload_bytes = constructor_transform._payload_builder.payload()
        payload_from_bytes = proto_utils.parse_Bytes(payload_bytes,
                                                     JavaClassLookupPayload)
        self.assertEqual('org.pkg.MyTransform', payload_from_bytes.class_name)
        self._verify_row(payload_from_bytes.constructor_schema,
                         payload_from_bytes.constructor_payload,
                         {'ignore0': 'abc'})
        builder_method = payload_from_bytes.builder_methods[0]
        self.assertEqual('withIntProperty', builder_method.name)
        self._verify_row(builder_method.schema, builder_method.payload,
                         {'ignore0': 5})
Example #36
 def _check_requirements(self, pipeline_proto):
   """Check that this runner can satisfy all pipeline requirements."""
   supported_requirements = set(self.supported_requirements())
   for requirement in pipeline_proto.requirements:
     if requirement not in supported_requirements:
       raise ValueError(
           'Unable to run pipeline with requirement: %s' % requirement)
   for transform in pipeline_proto.components.transforms.values():
     if transform.spec.urn == common_urns.primitives.TEST_STREAM.urn:
       raise NotImplementedError(transform.spec.urn)
     elif transform.spec.urn in translations.PAR_DO_URNS:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for timer in payload.timer_family_specs.values():
         if timer.time_domain != beam_runner_api_pb2.TimeDomain.EVENT_TIME:
           raise NotImplementedError(timer.time_domain)
Example #37
    def from_runner_api(cls, proto, context):
        if proto is None or not proto.urn:
            return None
        parameter_type, constructor = cls._known_urns[proto.urn]

        try:
            return constructor(
                proto_utils.parse_Bytes(proto.payload, parameter_type),
                context)
        except Exception:
            if context.allow_proto_holders:
                # For external transforms we cannot build a Python ParDo object so
                # we build a holder transform instead.
                from apache_beam.transforms.core import RunnerAPIPTransformHolder
                return RunnerAPIPTransformHolder(proto)
            raise
Example #38
 def extract_endpoints(stage):
   # Returns maps of transform names to PCollection identifiers.
   # Also mutates IO stages to point to the data ApiServiceDescriptor.
   data_input = {}
   data_side_input = {}
   data_output = {}
   for transform in stage.transforms:
     if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                               bundle_processor.DATA_OUTPUT_URN):
       pcoll_id = transform.spec.payload
       if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
         target = transform.unique_name, only_element(transform.outputs)
         if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
           data_input[target] = [ENCODED_IMPULSE_VALUE]
         else:
           data_input[target] = pcoll_buffers[pcoll_id]
         coder_id = pipeline_components.pcollections[
             only_element(transform.outputs.values())].coder_id
       elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
         target = transform.unique_name, only_element(transform.inputs)
         data_output[target] = pcoll_id
         coder_id = pipeline_components.pcollections[
             only_element(transform.inputs.values())].coder_id
       else:
         raise NotImplementedError
       data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
       if data_api_service_descriptor:
         data_spec.api_service_descriptor.url = (
             data_api_service_descriptor.url)
       transform.spec.payload = data_spec.SerializeToString()
     elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
       payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       for tag, si in payload.side_inputs.items():
         data_side_input[transform.unique_name, tag] = (
             create_buffer_id(transform.inputs[tag]), si.access_pattern)
   return data_input, data_side_input, data_output
Example #39
def inject_timer_pcollections(stages, pipeline_context):
  """Create PCollections for fired timers and to-be-set timers.

  At execution time, fired timers and timers-to-set are represented as
  PCollections that are managed by the runner.  This phase adds the
  necessary collections, with their reads and writes, to any stages using
  timers.
  """
  for stage in stages:
    for transform in list(stage.transforms):
      if transform.spec.urn in PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, spec in payload.timer_specs.items():
          if len(transform.inputs) > 1:
            raise NotImplementedError('Timers and side inputs.')
          input_pcoll = pipeline_context.components.pcollections[
              next(iter(transform.inputs.values()))]
          # Create the appropriate coder for the timer PCollection.
          key_coder_id = input_pcoll.coder_id
          if (pipeline_context.components.coders[key_coder_id].spec.spec.urn
              == common_urns.coders.KV.urn):
            key_coder_id = pipeline_context.components.coders[
                key_coder_id].component_coder_ids[0]
          key_timer_coder_id = pipeline_context.add_or_get_coder_id(
              beam_runner_api_pb2.Coder(
                  spec=beam_runner_api_pb2.SdkFunctionSpec(
                      spec=beam_runner_api_pb2.FunctionSpec(
                          urn=common_urns.coders.KV.urn)),
                  component_coder_ids=[key_coder_id, spec.timer_coder_id]))
          # Inject the read and write pcollections.
          timer_read_pcoll = unique_name(
              pipeline_context.components.pcollections,
              '%s_timers_to_read_%s' % (transform.unique_name, tag))
          timer_write_pcoll = unique_name(
              pipeline_context.components.pcollections,
              '%s_timers_to_write_%s' % (transform.unique_name, tag))
          pipeline_context.components.pcollections[timer_read_pcoll].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=timer_read_pcoll,
                  coder_id=key_timer_coder_id,
                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
                  is_bounded=input_pcoll.is_bounded))
          pipeline_context.components.pcollections[timer_write_pcoll].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=timer_write_pcoll,
                  coder_id=key_timer_coder_id,
                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
                  is_bounded=input_pcoll.is_bounded))
          stage.transforms.append(
              beam_runner_api_pb2.PTransform(
                  unique_name=timer_read_pcoll + '/Read',
                  outputs={'out': timer_read_pcoll},
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_INPUT_URN,
                      payload=create_buffer_id(
                          timer_read_pcoll, kind='timers'))))
          stage.transforms.append(
              beam_runner_api_pb2.PTransform(
                  unique_name=timer_write_pcoll + '/Write',
                  inputs={'in': timer_write_pcoll},
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=bundle_processor.DATA_OUTPUT_URN,
                      payload=create_buffer_id(
                          timer_write_pcoll, kind='timers'))))
          assert tag not in transform.inputs
          transform.inputs[tag] = timer_read_pcoll
          assert tag not in transform.outputs
          transform.outputs[tag] = timer_write_pcoll
          stage.timer_pcollections.append(
              (timer_read_pcoll + '/Read', timer_write_pcoll))
    yield stage
Example #40
    def lift_combiners(stages):
      """Expands CombinePerKey into pre- and post-grouping stages.

      ... -> CombinePerKey -> ...

      becomes

      ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
      """
      def add_or_get_coder_id(coder_proto):
        for coder_id, coder in pipeline_components.coders.items():
          if coder == coder_proto:
            return coder_id
        new_coder_id = unique_name(pipeline_components.coders, 'coder')
        pipeline_components.coders[new_coder_id].CopyFrom(coder_proto)
        return new_coder_id

      def windowed_coder_id(coder_id):
        proto = beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.SdkFunctionSpec(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=urns.WINDOWED_VALUE_CODER)),
            component_coder_ids=[coder_id, window_coder_id])
        return add_or_get_coder_id(proto)

      for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == urns.COMBINE_PER_KEY_TRANSFORM:
          combine_payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.CombinePayload)

          input_pcoll = pipeline_components.pcollections[only_element(
              transform.inputs.values())]
          output_pcoll = pipeline_components.pcollections[only_element(
              transform.outputs.values())]

          windowed_input_coder = pipeline_components.coders[
              input_pcoll.coder_id]
          element_coder_id, window_coder_id = (
              windowed_input_coder.component_coder_ids)
          element_coder = pipeline_components.coders[element_coder_id]
          key_coder_id, _ = element_coder.component_coder_ids
          accumulator_coder_id = combine_payload.accumulator_coder_id

          key_accumulator_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_coder_id])
          key_accumulator_coder_id = add_or_get_coder_id(key_accumulator_coder)

          accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.ITERABLE_CODER)),
              component_coder_ids=[accumulator_coder_id])
          accumulator_iter_coder_id = add_or_get_coder_id(
              accumulator_iter_coder)

          key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
          key_accumulator_iter_coder_id = add_or_get_coder_id(
              key_accumulator_iter_coder)

          precombined_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[precombined_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Precombine.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
                  is_bounded=input_pcoll.is_bounded))

          grouped_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[grouped_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Group.out',
                  coder_id=windowed_coder_id(key_accumulator_iter_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          merged_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[merged_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Merge.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          def make_stage(base_stage, transform):
            return Stage(
                transform.unique_name,
                [transform],
                downstream_side_inputs=base_stage.downstream_side_inputs,
                must_follow=base_stage.must_follow)

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Precombine',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.PRECOMBINE_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs=transform.inputs,
                  outputs={'out': precombined_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Group',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.GROUP_BY_KEY_TRANSFORM),
                  inputs={'in': precombined_pcoll_id},
                  outputs={'out': grouped_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Merge',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.MERGE_ACCUMULATORS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': grouped_pcoll_id},
                  outputs={'out': merged_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/ExtractOutputs',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.EXTRACT_OUTPUTS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': merged_pcoll_id},
                  outputs=transform.outputs))

        else:
          yield stage
Example #41
 def is_side_input(transform_proto, tag):
   if transform_proto.spec.urn == urns.PARDO_TRANSFORM:
     return tag in proto_utils.parse_Bytes(
         transform_proto.spec.payload,
         beam_runner_api_pb2.ParDoPayload).side_inputs
Example #42
 def create_operation(self, transform_id, consumers):
   transform_proto = self.descriptor.transforms[transform_id]
   creator, parameter_type = self._known_urns[transform_proto.spec.urn]
   payload = proto_utils.parse_Bytes(
       transform_proto.spec.payload, parameter_type)
   return creator(self, transform_id, transform_proto, payload, consumers)
Example #43
  def executable_stage_transform(
      self, known_runner_urns, all_consumers, components):
    if (len(self.transforms) == 1
        and self.transforms[0].spec.urn in known_runner_urns):
      return self.transforms[0]

    else:
      all_inputs = set(
          pcoll for t in self.transforms for pcoll in t.inputs.values())
      all_outputs = set(
          pcoll for t in self.transforms for pcoll in t.outputs.values())
      internal_transforms = set(id(t) for t in self.transforms)
      external_outputs = [pcoll for pcoll in all_outputs
                          if all_consumers[pcoll] - internal_transforms]

      stage_components = beam_runner_api_pb2.Components()
      stage_components.CopyFrom(components)

      # Only keep the referenced PCollections.
      for pcoll_id in stage_components.pcollections.keys():
        if pcoll_id not in all_inputs and pcoll_id not in all_outputs:
          del stage_components.pcollections[pcoll_id]

      # Only keep the transforms in this stage.
      # Also gather up payload data as we iterate over the transforms.
      stage_components.transforms.clear()
      main_inputs = set()
      side_inputs = []
      user_states = []
      timers = []
      for ix, transform in enumerate(self.transforms):
        transform_id = 'transform_%d' % ix
        if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
          payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
          for tag in payload.side_inputs.keys():
            side_inputs.append(
                beam_runner_api_pb2.ExecutableStagePayload.SideInputId(
                    transform_id=transform_id,
                    local_name=tag))
          for tag in payload.state_specs.keys():
            user_states.append(
                beam_runner_api_pb2.ExecutableStagePayload.UserStateId(
                    transform_id=transform_id,
                    local_name=tag))
          for tag in payload.timer_specs.keys():
            timers.append(
                beam_runner_api_pb2.ExecutableStagePayload.TimerId(
                    transform_id=transform_id,
                    local_name=tag))
          main_inputs.update(
              pcoll_id
              for tag, pcoll_id in transform.inputs.items()
              if tag not in payload.side_inputs)
        else:
          main_inputs.update(transform.inputs.values())
        stage_components.transforms[transform_id].CopyFrom(transform)

      main_input_id = only_element(main_inputs - all_outputs)
      named_inputs = dict({
          '%s:%s' % (side.transform_id, side.local_name):
          stage_components.transforms[side.transform_id].inputs[side.local_name]
          for side in side_inputs
      }, main_input=main_input_id)
      payload = beam_runner_api_pb2.ExecutableStagePayload(
          environment=components.environments[self.environment],
          input=main_input_id,
          outputs=external_outputs,
          transforms=stage_components.transforms.keys(),
          components=stage_components,
          side_inputs=side_inputs,
          user_states=user_states,
          timers=timers)

      return beam_runner_api_pb2.PTransform(
          unique_name=unique_name(None, self.name),
          spec=beam_runner_api_pb2.FunctionSpec(
              urn='beam:runner:executable_stage:v1',
              payload=payload.SerializeToString()),
          inputs=named_inputs,
          outputs={'output_%d' % ix: pcoll
                   for ix, pcoll in enumerate(external_outputs)})
Example #44
def lift_combiners(stages, context):
  """Expands CombinePerKey into pre- and post-grouping stages.

  ... -> CombinePerKey -> ...

  becomes

  ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
  """
  for stage in stages:
    assert len(stage.transforms) == 1
    transform = stage.transforms[0]
    if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
      combine_payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.CombinePayload)

      input_pcoll = context.components.pcollections[only_element(
          list(transform.inputs.values()))]
      output_pcoll = context.components.pcollections[only_element(
          list(transform.outputs.values()))]

      element_coder_id = input_pcoll.coder_id
      element_coder = context.components.coders[element_coder_id]
      key_coder_id, _ = element_coder.component_coder_ids
      accumulator_coder_id = combine_payload.accumulator_coder_id

      key_accumulator_coder = beam_runner_api_pb2.Coder(
          spec=beam_runner_api_pb2.SdkFunctionSpec(
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.coders.KV.urn)),
          component_coder_ids=[key_coder_id, accumulator_coder_id])
      key_accumulator_coder_id = context.add_or_get_coder_id(
          key_accumulator_coder)

      accumulator_iter_coder = beam_runner_api_pb2.Coder(
          spec=beam_runner_api_pb2.SdkFunctionSpec(
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.coders.ITERABLE.urn)),
          component_coder_ids=[accumulator_coder_id])
      accumulator_iter_coder_id = context.add_or_get_coder_id(
          accumulator_iter_coder)

      key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
          spec=beam_runner_api_pb2.SdkFunctionSpec(
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.coders.KV.urn)),
          component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
      key_accumulator_iter_coder_id = context.add_or_get_coder_id(
          key_accumulator_iter_coder)

      precombined_pcoll_id = unique_name(
          context.components.pcollections, 'pcollection')
      context.components.pcollections[precombined_pcoll_id].CopyFrom(
          beam_runner_api_pb2.PCollection(
              unique_name=transform.unique_name + '/Precombine.out',
              coder_id=key_accumulator_coder_id,
              windowing_strategy_id=input_pcoll.windowing_strategy_id,
              is_bounded=input_pcoll.is_bounded))

      grouped_pcoll_id = unique_name(
          context.components.pcollections, 'pcollection')
      context.components.pcollections[grouped_pcoll_id].CopyFrom(
          beam_runner_api_pb2.PCollection(
              unique_name=transform.unique_name + '/Group.out',
              coder_id=key_accumulator_iter_coder_id,
              windowing_strategy_id=output_pcoll.windowing_strategy_id,
              is_bounded=output_pcoll.is_bounded))

      merged_pcoll_id = unique_name(
          context.components.pcollections, 'pcollection')
      context.components.pcollections[merged_pcoll_id].CopyFrom(
          beam_runner_api_pb2.PCollection(
              unique_name=transform.unique_name + '/Merge.out',
              coder_id=key_accumulator_coder_id,
              windowing_strategy_id=output_pcoll.windowing_strategy_id,
              is_bounded=output_pcoll.is_bounded))

      def make_stage(base_stage, transform):
        return Stage(
            transform.unique_name,
            [transform],
            downstream_side_inputs=base_stage.downstream_side_inputs,
            must_follow=base_stage.must_follow,
            parent=base_stage.name,
            environment=base_stage.environment)

      yield make_stage(
          stage,
          beam_runner_api_pb2.PTransform(
              unique_name=transform.unique_name + '/Precombine',
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.combine_components
                  .COMBINE_PER_KEY_PRECOMBINE.urn,
                  payload=transform.spec.payload),
              inputs=transform.inputs,
              outputs={'out': precombined_pcoll_id}))

      yield make_stage(
          stage,
          beam_runner_api_pb2.PTransform(
              unique_name=transform.unique_name + '/Group',
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.primitives.GROUP_BY_KEY.urn),
              inputs={'in': precombined_pcoll_id},
              outputs={'out': grouped_pcoll_id}))

      yield make_stage(
          stage,
          beam_runner_api_pb2.PTransform(
              unique_name=transform.unique_name + '/Merge',
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.combine_components
                  .COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
                  payload=transform.spec.payload),
              inputs={'in': grouped_pcoll_id},
              outputs={'out': merged_pcoll_id}))

      yield make_stage(
          stage,
          beam_runner_api_pb2.PTransform(
              unique_name=transform.unique_name + '/ExtractOutputs',
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.combine_components
                  .COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
                  payload=transform.spec.payload),
              inputs={'in': merged_pcoll_id},
              outputs=transform.outputs))

    else:
      yield stage
Example #45
def expand_sdf(stages, context):
  """Transforms splitable DoFns into pair+split+read."""
  for stage in stages:
    assert len(stage.transforms) == 1
    transform = stage.transforms[0]
    if transform.spec.urn == common_urns.primitives.PAR_DO.urn:

      pardo_payload = proto_utils.parse_Bytes(
          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)

      if pardo_payload.splittable:

        def copy_like(protos, original, suffix='_copy', **kwargs):
          if isinstance(original, str):
            key = original
            original = protos[original]
          else:
            key = 'component'
          new_id = unique_name(protos, key + suffix)
          protos[new_id].CopyFrom(original)
          proto = protos[new_id]
          for name, value in kwargs.items():
            if isinstance(value, dict):
              getattr(proto, name).clear()
              getattr(proto, name).update(value)
            elif isinstance(value, list):
              del getattr(proto, name)[:]
              getattr(proto, name).extend(value)
            elif name == 'urn':
              proto.spec.urn = value
            else:
              setattr(proto, name, value)
          return new_id

        def make_stage(base_stage, transform_id, extra_must_follow=()):
          transform = context.components.transforms[transform_id]
          return Stage(
              transform.unique_name,
              [transform],
              base_stage.downstream_side_inputs,
              union(base_stage.must_follow, frozenset(extra_must_follow)),
              parent=base_stage,
              environment=base_stage.environment)

        main_input_tag = only_element(tag for tag in transform.inputs.keys()
                                      if tag not in pardo_payload.side_inputs)
        main_input_id = transform.inputs[main_input_tag]
        element_coder_id = context.components.pcollections[
            main_input_id].coder_id
        paired_coder_id = context.add_or_get_coder_id(
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.SdkFunctionSpec(
                    spec=beam_runner_api_pb2.FunctionSpec(
                        urn=common_urns.coders.KV.urn)),
                component_coder_ids=[element_coder_id,
                                     pardo_payload.restriction_coder_id]))

        paired_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_paired',
            coder_id=paired_coder_id)
        pair_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/PairWithRestriction',
            urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
            outputs={'out': paired_pcoll_id})

        split_pcoll_id = copy_like(
            context.components.pcollections,
            main_input_id,
            '_split',
            coder_id=paired_coder_id)
        split_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/SplitRestriction',
            urn=common_urns.sdf_components.SPLIT_RESTRICTION.urn,
            inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
            outputs={'out': split_pcoll_id})

        process_transform_id = copy_like(
            context.components.transforms,
            transform,
            unique_name=transform.unique_name + '/Process',
            urn=common_urns.sdf_components.PROCESS_ELEMENTS.urn,
            inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}))

        yield make_stage(stage, pair_transform_id)
        split_stage = make_stage(stage, split_transform_id)
        yield split_stage
        yield make_stage(
            stage, process_transform_id, extra_must_follow=[split_stage])

      else:
        yield stage

    else:
      yield stage
Example #46
 def is_side_input(transform_proto, tag):
   if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
     return tag in proto_utils.parse_Bytes(
         transform_proto.spec.payload,
         beam_runner_api_pb2.ParDoPayload).side_inputs
Example #47
 def create(cls, environment, state, provision_info):
   constructor, payload_type = cls._registered_environments[environment.urn]
   return constructor(
       proto_utils.parse_Bytes(environment.payload, payload_type),
       state,
       provision_info)
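Examples #30 and #47 look up a constructor and payload type by environment.urn; the Dataflow tests in Examples #4 and #5 show the matching DockerPayload check. Below is a minimal hedged sketch tying the two together (the container image string is illustrative, and common_urns.environments.DOCKER is assumed available as in the SDK's environments module):

from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.utils import proto_utils

# An environment proto carries a serialized DockerPayload in its payload field.
env = beam_runner_api_pb2.Environment(
    urn=common_urns.environments.DOCKER.urn,
    payload=beam_runner_api_pb2.DockerPayload(
        container_image='apache/beam_python3.9_sdk:latest').SerializeToString())

# The registries in Examples #30/#47 map env.urn to (constructor, payload_type);
# parse_Bytes then turns the opaque bytes back into a typed DockerPayload.
docker_payload = proto_utils.parse_Bytes(
    env.payload, beam_runner_api_pb2.DockerPayload)
assert docker_payload.container_image.startswith('apache/beam')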