Example #1
 def from_runner_api(proto, unused_context):
   assert proto.view_fn.spec.urn == urns.PICKLED_PYTHON_VIEWFN
   assert proto.window_mapping_fn.spec.urn == urns.PICKLED_WINDOW_MAPPING_FN
   return SideInputData(
       proto.access_pattern.urn,
       pickler.loads(proto.window_mapping_fn.spec.payload),
       *pickler.loads(proto.view_fn.spec.payload))
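The pattern in Example #1 is the core of every snippet on this page: a payload produced by pickler.dumps is turned back into a live Python object by pickler.loads. A minimal round-trip sketch, assuming apache_beam is installed and that the bare name pickler refers to apache_beam.internal.pickler (the excerpts never show the import):

from apache_beam.internal import pickler  # assumed import path

# Serialize an arbitrary callable (the pickler is dill-based, so lambdas work)...
payload = pickler.dumps(lambda w: [w])
# ...and reconstruct it, e.g. on a remote worker.
window_mapping_fn = pickler.loads(payload)
assert window_mapping_fn('w') == ['w']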
Example #2
  def __init__(self, fn, *args, **kwargs):
    if isinstance(fn, type) and issubclass(fn, WithTypeHints):
      # Don't treat Fn class objects as callables.
      raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__))
    self.fn = self.make_fn(fn)
    # Now that we figure out the label, initialize the super-class.
    super(PTransformWithSideInputs, self).__init__()

    if (any([isinstance(v, pvalue.PCollection) for v in args]) or
        any([isinstance(v, pvalue.PCollection) for v in kwargs.itervalues()])):
      raise error.SideInputError(
          'PCollection used directly as side input argument. Specify '
          'AsIter(pcollection) or AsSingleton(pcollection) to indicate how the '
          'PCollection is to be used.')
    self.args, self.kwargs, self.side_inputs = util.remove_objects_from_args(
        args, kwargs, pvalue.AsSideInput)
    self.raw_side_inputs = args, kwargs

    # Prevent name collisions with fns of the form '<function <lambda> at ...>'
    self._cached_fn = self.fn

    # Ensure fn and side inputs are picklable for remote execution.
    self.fn = pickler.loads(pickler.dumps(self.fn))
    self.args = pickler.loads(pickler.dumps(self.args))
    self.kwargs = pickler.loads(pickler.dumps(self.kwargs))

    # For type hints, because loads(dumps(class)) != class.
    self.fn = self._cached_fn
Example #3
 def test_nested_class(self):
   """Tests that a nested class object is pickled correctly."""
   self.assertEquals(
       'X:abc',
       loads(dumps(module_test.TopClass.NestedClass('abc'))).datum)
   self.assertEquals(
       'Y:abc',
       loads(dumps(module_test.TopClass.MiddleClass.NestedClass('abc'))).datum)
Example #4
 def visit_transform(self, transform_node):
   try:
     # Transforms must be picklable.
     pickler.loads(pickler.dumps(transform_node.transform,
                                 enable_trace=False),
                   enable_trace=False)
   except Exception:
     Visitor.ok = False
Example #5
 def visit_transform(self, transform_node):
   if transform_node.side_inputs:
     # No side inputs (yet).
     Visitor.ok = False
   try:
     # Transforms must be picklable.
     pickler.loads(pickler.dumps(transform_node.transform))
   except Exception:
     Visitor.ok = False
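Examples #4 and #5 use the round trip purely as a picklability smoke test inside a pipeline visitor. A standalone helper in the same spirit might look like the sketch below (is_picklable is a hypothetical name, not part of the SDK; pickler is assumed to be apache_beam.internal.pickler):

from apache_beam.internal import pickler  # assumed import path

def is_picklable(obj):
  # An object can be shipped to remote workers only if it survives
  # a full dumps/loads round trip.
  try:
    pickler.loads(pickler.dumps(obj))
    return True
  except Exception:  # pylint: disable=broad-except
    return False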
Example #6
  def _get_concat_source(self):
    if self._concat_source is None:
      pattern = self._pattern.get()

      single_file_sources = []
      match_result = FileSystems.match([pattern])[0]
      files_metadata = match_result.metadata_list

      # We create a reference for FileBasedSource that will be serialized along
      # with each _SingleFileSource. To prevent this FileBasedSource from having
      # a reference to ConcatSource (resulting in quadratic space complexity)
      # we clone it here.
      file_based_source_ref = pickler.loads(pickler.dumps(self))

      for file_metadata in files_metadata:
        file_name = file_metadata.path
        file_size = file_metadata.size_in_bytes
        if file_size == 0:
          continue  # Ignoring empty file.

        # We determine splittability of this specific file.
        splittable = (
            self.splittable and
            _determine_splittability_from_compression_type(
                file_name, self._compression_type))

        single_file_source = _SingleFileSource(
            file_based_source_ref, file_name,
            0,
            file_size,
            min_bundle_size=self._min_bundle_size,
            splittable=splittable)
        single_file_sources.append(single_file_source)
      self._concat_source = concat_source.ConcatSource(single_file_sources)
    return self._concat_source
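The comment in Example #6 explains why the source clones itself: pickler.loads(pickler.dumps(obj)) behaves like a deep copy, so the clone carries no reference back to the ConcatSource that will later hold it. A small sketch of that deep-copy behaviour, again assuming pickler is apache_beam.internal.pickler:

from apache_beam.internal import pickler  # assumed import path

original = {'pattern': '*.txt', 'sizes': [1, 2, 3]}
clone = pickler.loads(pickler.dumps(original))
# Equal in value, but sharing no references with the original.
assert clone == original and clone is not original
assert clone['sizes'] is not original['sizes']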
Example #7
 def __init__(self, operation_name, spec, counter_factory, state_sampler):
   super(PGBKCVOperation, self).__init__(
       operation_name, spec, counter_factory, state_sampler)
   # Combiners do not accept deferred side-inputs (the ignored fourth
   # argument) and therefore the code to handle the extra args/kwargs is
   # simpler than for the DoFn's of ParDo.
   fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
   self.combine_fn = curry_combine_fn(fn, args, kwargs)
   if (getattr(fn.add_input, 'im_func', None)
       is core.CombineFn.add_input.__func__):
     # Old versions of the SDK have CombineFns that don't implement add_input.
     self.combine_fn_add_input = (
         lambda a, e: self.combine_fn.add_inputs(a, [e]))
   else:
     self.combine_fn_add_input = self.combine_fn.add_input
   # Optimization for the (known tiny accumulator, often wide keyspace)
   # combine functions.
   # TODO(b/36567833): Bound by in-memory size rather than key count.
   self.max_keys = (
       1000 * 1000 if
       isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
       # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
       # combiners to the short list above.
       (isinstance(fn, core.CallableWrapperCombineFn) and
        fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
   self.key_count = 0
   self.table = {}
Example #8
  def run_pipeline(self, pipeline):
    """Execute test pipeline and verify test matcher"""
    options = pipeline._options.view_as(TestOptions)
    on_success_matcher = options.on_success_matcher
    is_streaming = options.view_as(StandardOptions).streaming

    # [BEAM-1889] Do not send this to remote workers also, there is no need to
    # send this option to remote executors.
    options.on_success_matcher = None

    self.result = super(TestDirectRunner, self).run_pipeline(pipeline)

    try:
      if not is_streaming:
        self.result.wait_until_finish()

      if on_success_matcher:
        from hamcrest import assert_that as hc_assert_that
        hc_assert_that(self.result, pickler.loads(on_success_matcher))
    finally:
      if not PipelineState.is_terminal(self.result.state):
        self.result.cancel()
        self.result.wait_until_finish()

    return self.result
Example #9
  def run_pipeline(self, pipeline):
    docker_image = (
        pipeline.options.view_as(PortableOptions).harness_docker_image
        or self.default_docker_image())
    job_endpoint = pipeline.options.view_as(PortableOptions).job_endpoint
    if not job_endpoint:
      raise ValueError(
          'job_endpoint should be provided while creating runner.')

    proto_context = pipeline_context.PipelineContext(
        default_environment_url=docker_image)
    proto_pipeline = pipeline.to_runner_api(context=proto_context)

    if not self.is_embedded_fnapi_runner:
      # Java has different expectations about coders
      # (windowed in Fn API, but *un*windowed in runner API), whereas the
      # embedded FnApiRunner treats them consistently, so we must guard this
      # for now, until FnApiRunner is fixed.
      # See also BEAM-2717.
      for pcoll in proto_pipeline.components.pcollections.values():
        if pcoll.coder_id not in proto_context.coders:
          # This is not really a coder id, but a pickled coder.
          coder = coders.registry.get_coder(pickler.loads(pcoll.coder_id))
          pcoll.coder_id = proto_context.coders.get_id(coder)
      proto_context.coders.populate_map(proto_pipeline.components.coders)

    # Some runners won't detect the GroupByKey transform unless it has no
    # subtransforms.  Remove all sub-transforms until BEAM-4605 is resolved.
    for _, transform_proto in list(
        proto_pipeline.components.transforms.items()):
      if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
        for sub_transform in transform_proto.subtransforms:
          del proto_pipeline.components.transforms[sub_transform]
        del transform_proto.subtransforms[:]

    # TODO: Define URNs for options.
    options = {'beam:option:' + k + ':v1': v
               for k, v in pipeline._options.get_all_options().iteritems()
               if v is not None}

    job_service = beam_job_api_pb2_grpc.JobServiceStub(
        grpc.insecure_channel(job_endpoint))
    prepare_response = job_service.Prepare(
        beam_job_api_pb2.PrepareJobRequest(
            job_name='job', pipeline=proto_pipeline,
            pipeline_options=job_utils.dict_to_struct(options)))
    if prepare_response.artifact_staging_endpoint.url:
      stager = portable_stager.PortableStager(
          grpc.insecure_channel(prepare_response.artifact_staging_endpoint.url),
          prepare_response.staging_session_token)
      retrieval_token, _ = stager.stage_job_resources(
          pipeline._options,
          staging_location='')
    else:
      retrieval_token = None
    run_response = job_service.Run(
        beam_job_api_pb2.RunJobRequest(
            preparation_id=prepare_response.preparation_id,
            retrieval_token=retrieval_token))
    return PipelineResult(job_service, run_response.job_id)
Example #10
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = pickler.loads(pickler.dumps(transform.dofn))

    pipeline_options = self._evaluation_context.pipeline_options
    if (pipeline_options is not None
        and pipeline_options.view_as(TypeOptions).runtime_type_check):
      dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
    self.runner = DoFnRunner(
        dofn, transform.args, transform.kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner.start()
Example #11
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_inputs_proto=None):

  if side_inputs_proto:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in side_inputs_proto.items()]
    tagged_side_inputs.sort(key=lambda tag_si: int(tag_si[0][4:]))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler,
            transform_id,
            tag,
            si,
            input_tags_to_coders[tag])
        for tag, si in tagged_side_inputs]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match out prefix injected by dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    side_input_tags = side_inputs_proto or ()
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in side_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps),
      transform_proto.unique_name,
      consumers,
      output_tags)
Example #12
  def run_pipeline(self, pipeline):
    """Execute test pipeline and verify test matcher"""
    options = pipeline._options.view_as(TestOptions)
    on_success_matcher = options.on_success_matcher
    wait_duration = options.wait_until_finish_duration
    is_streaming = options.view_as(StandardOptions).streaming

    # [BEAM-1889] Do not send this to remote workers also, there is no need to
    # send this option to remote executors.
    options.on_success_matcher = None

    self.result = super(TestDataflowRunner, self).run_pipeline(pipeline)
    if self.result.has_job:
      # TODO(markflyhigh)(BEAM-1890): Use print since Nose doesn't show logs
      # in some cases.
      print('Found: %s.' % self.build_console_url(pipeline.options))

    try:
      self.wait_until_in_state(PipelineState.RUNNING)

      if is_streaming and not wait_duration:
        logging.warning('Waiting indefinitely for streaming job.')
      self.result.wait_until_finish(duration=wait_duration)

      if on_success_matcher:
        from hamcrest import assert_that as hc_assert_that
        hc_assert_that(self.result, pickler.loads(on_success_matcher))
    finally:
      if not self.result.is_in_terminal_state():
        self.result.cancel()
        self.wait_until_in_state(PipelineState.CANCELLED, timeout=300)

    return self.result
Example #13
  def _get_concat_source(self):
    if self._concat_source is None:
      single_file_sources = []
      file_names = [f for f in fileio.ChannelFactory.glob(self._pattern)]
      sizes = FileBasedSource._estimate_sizes_of_files(file_names,
                                                       self._pattern)

      # We create a reference for FileBasedSource that will be serialized along
      # with each _SingleFileSource. To prevent this FileBasedSource from having
      # a reference to ConcatSource (resulting in quadratic space complexity)
      # we clone it here.
      file_based_source_ref = pickler.loads(pickler.dumps(self))

      for index, file_name in enumerate(file_names):
        if sizes[index] == 0:
          continue  # Ignoring empty file.

        # We determine splittability of this specific file.
        splittable = self.splittable
        if (splittable and
            self._compression_type == fileio.CompressionTypes.AUTO):
          compression_type = fileio.CompressionTypes.detect_compression_type(
              file_name)
          if compression_type != fileio.CompressionTypes.UNCOMPRESSED:
            splittable = False

        single_file_source = _SingleFileSource(
            file_based_source_ref, file_name,
            0,
            sizes[index],
            min_bundle_size=self._min_bundle_size,
            splittable=splittable)
        single_file_sources.append(single_file_source)
      self._concat_source = concat_source.ConcatSource(single_file_sources)
    return self._concat_source
Example #14
  def run(self, pipeline):
    """Execute test pipeline and verify test matcher"""
    options = pipeline.options.view_as(TestOptions)
    on_success_matcher = options.on_success_matcher

    # [BEAM-1889] Do not send this to remote workers also, there is no need to
    # send this option to remote executors.
    options.on_success_matcher = None

    self.result = super(TestDataflowRunner, self).run(pipeline)
    if self.result.has_job:
      project = pipeline.options.view_as(GoogleCloudOptions).project
      job_id = self.result.job_id()
      # TODO(markflyhigh)(BEAM-1890): Use print since Nose doesn't show logs
      # in some cases.
      print (
          'Found: https://console.cloud.google.com/dataflow/job/%s?project=%s' %
          (job_id, project))
    self.result.wait_until_finish()

    if on_success_matcher:
      from hamcrest import assert_that as hc_assert_that
      hc_assert_that(self.result, pickler.loads(on_success_matcher))

    return self.result
Example #15
 def test_get_coder_with_composite_custom_coder(self):
   typecoders.registry.register_coder(CustomClass, CustomCoder)
   coder = typecoders.registry.get_coder(typehints.KV[CustomClass, str])
   revived_coder = pickler.loads(pickler.dumps(coder))
   self.assertEqual(
       (CustomClass(123), 'abc'),
       revived_coder.decode(revived_coder.encode((CustomClass(123), 'abc'))))
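Example #15 checks that a registered coder still works after being pickled. A smaller hedged variant of the same idea, using a stock coder instead of the test's CustomClass/CustomCoder pair:

from apache_beam import coders
from apache_beam.internal import pickler  # assumed import path

coder = coders.VarIntCoder()
revived = pickler.loads(pickler.dumps(coder))
# The revived coder must still round-trip values correctly.
assert revived.decode(revived.encode(42)) == 42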
Example #16
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (pickler.loads(pickler.dumps(transform.dofn))
            if self._perform_dofn_pickle_test else transform.dofn)

    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.runner = DoFnRunner(
        dofn, args, kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner.start()
Example #17
 def test_append_verifier_in_extra_opt(self):
   extra_opt = {'matcher': SimpleMatcher()}
   opt_list = TestPipeline().get_full_options_as_args(**extra_opt)
   _, value = opt_list[0].split('=', 1)
   matcher = pickler.loads(value)
   self.assertTrue(isinstance(matcher, BaseMatcher))
   hc_assert_that(None, matcher)
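Examples #8, #12, #14 and #17 all follow the same on_success_matcher flow: a PyHamcrest matcher is pickled into a string option and revived with pickler.loads before being asserted against the pipeline result. A self-contained sketch of that flow; SimpleMatcher below is a stand-in defined here, not the class used by the original test, and pickler is assumed to be apache_beam.internal.pickler:

from hamcrest import assert_that as hc_assert_that
from hamcrest.core.base_matcher import BaseMatcher
from apache_beam.internal import pickler  # assumed import path

class SimpleMatcher(BaseMatcher):
  def _matches(self, item):
    return True  # accept any pipeline result

  def describe_to(self, description):
    description.append_text('any pipeline result')

# The matcher travels as an opaque pickled string (e.g. inside TestOptions)...
serialized = pickler.dumps(SimpleMatcher())
# ...and is revived and applied once the pipeline finishes.
hc_assert_that(None, pickler.loads(serialized))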
Example #18
 def _get_manifest_contents(self, manifest_file_handle):
   manifest_file_handle.seek(0)
   try:
     return pickler.loads(manifest_file_handle.read())
   except ValueError as e:
     tf.compat.v1.logging.error('Can\'t load cache manifest contents: %s',
                                str(e))
     return {}
Example #19
 def from_runner_api(proto, context):
   # Producer and tag will be filled in later, the key point is that the
   # same object is returned for the same pcollection id.
   return PCollection(
       None,
       element_type=pickler.loads(proto.coder_id),
       windowing=context.windowing_strategies.get_by_id(
           proto.windowing_strategy_id))
Example #20
  def test_lambda_with_globals(self):
    """Tests that the globals of a function are preserved."""

    # The point of the test is that the lambda being called after unpickling
    # relies on having the re module being loaded.
    self.assertEquals(
        ['abc', 'def'],
        loads(dumps(module_test.get_lambda_with_globals()))('abc def'))
Example #21
 def from_runner_api(proto, context):
   # Producer and tag will be filled in later, the key point is that the
   # same object is returned for the same pcollection id.
   return PCollection(
       None,
       element_type=pickler.loads(proto.coder_id),
       windowing=context.windowing_strategies.get_by_id(
           proto.windowing_strategy_id))
Example #22
    def test_lambda_with_globals(self):
        """Tests that the globals of a function are preserved."""

        # The point of the test is that the lambda being called after unpickling
        # relies on having the re module being loaded.
        self.assertEquals(
            ['abc', 'def'],
            loads(dumps(module_test.get_lambda_with_globals()))('abc def'))
Example #23
    def setup(self):
        with self.scoped_start_state:
            super(DoOperation, self).setup()

            # See fn_data in dataflow_runner.py
            fn, args, kwargs, tags_and_types, window_fn = (pickler.loads(
                self.spec.serialized_fn))

            state = common.DoFnState(self.counter_factory)
            state.step_name = self.name_context.logging_name()

            # Tag to output index map used to dispatch the side output values emitted
            # by the DoFn function to the appropriate receivers. The main output is
            # tagged with None and is associated with its corresponding index.
            self.tagged_receivers = _TaggedReceivers(
                self.counter_factory, self.name_context.logging_name())

            output_tag_prefix = PropertyNames.OUT + '_'
            for index, tag in enumerate(self.spec.output_tags):
                if tag == PropertyNames.OUT:
                    original_tag = None
                elif tag.startswith(output_tag_prefix):
                    original_tag = tag[len(output_tag_prefix):]
                else:
                    raise ValueError(
                        'Unexpected output name for operation: %s' % tag)
                self.tagged_receivers[original_tag] = self.receivers[index]

            if self.user_state_context:
                self.user_state_context.update_timer_receivers(
                    self.tagged_receivers)
                self.timer_specs = {
                    spec.name: spec
                    for spec in userstate.get_dofn_specs(fn)[1]
                }

            if self.side_input_maps is None:
                if tags_and_types:
                    self.side_input_maps = list(
                        self._read_side_inputs(tags_and_types))
                else:
                    self.side_input_maps = []

            self.dofn_runner = common.DoFnRunner(
                fn,
                args,
                kwargs,
                self.side_input_maps,
                window_fn,
                tagged_receivers=self.tagged_receivers,
                step_name=self.name_context.logging_name(),
                state=state,
                user_state_context=self.user_state_context,
                operation_name=self.name_context.metrics_name())

            self.dofn_receiver = (self.dofn_runner if isinstance(
                self.dofn_runner, Receiver) else DoFnRunnerReceiver(
                    self.dofn_runner))
Example #24
    def start_bundle(self):
        transform = self._applied_ptransform.transform

        self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
        if isinstance(self._applied_ptransform.parent.transform,
                      core._MultiParDo):  # pylint: disable=protected-access
            do_outputs_tuple = self._applied_ptransform.parent.outputs[0]
            assert isinstance(do_outputs_tuple, pvalue.DoOutputsTuple)
            main_output_pcollection = do_outputs_tuple[
                do_outputs_tuple._main_tag]  # pylint: disable=protected-access

            for side_output_tag in transform.side_output_tags:
                output_pcollection = do_outputs_tuple[side_output_tag]
                self._tagged_receivers[side_output_tag] = (
                    self._evaluation_context.create_bundle(output_pcollection))
                self._tagged_receivers[side_output_tag].tag = side_output_tag
        else:
            assert len(self._outputs) == 1
            main_output_pcollection = list(self._outputs)[0]

        self._tagged_receivers[None] = self._evaluation_context.create_bundle(
            main_output_pcollection)
        self._tagged_receivers[None].tag = None  # main_tag is None.

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = pickler.loads(pickler.dumps(transform.dofn))

        pipeline_options = self._evaluation_context.pipeline_options
        if (pipeline_options is not None
                and pipeline_options.view_as(TypeOptions).runtime_type_check):
            # TODO(sourabhbajaj): Remove this if-else
            if isinstance(dofn, core.NewDoFn):
                dofn = TypeCheckWrapperNewDoFn(dofn,
                                               transform.get_type_hints())
            else:
                dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())

        # TODO(sourabhbajaj): Remove this if-else
        if isinstance(dofn, core.NewDoFn):
            dofn = OutputCheckWrapperNewDoFn(
                dofn, self._applied_ptransform.full_label)
        else:
            dofn = OutputCheckWrapperDoFn(dofn,
                                          self._applied_ptransform.full_label)
        self.runner = DoFnRunner(
            dofn,
            transform.args,
            transform.kwargs,
            self._side_inputs,
            self._applied_ptransform.inputs[0].windowing,
            tagged_receivers=self._tagged_receivers,
            step_name=self._applied_ptransform.full_label,
            state=DoFnState(self._counter_factory),
            scoped_metrics_container=self.scoped_metrics_container)
        self.runner.start()
Example #25
 def __init__(self, name_context, spec, counter_factory, state_sampler):
     super(CombineOperation, self).__init__(name_context, spec,
                                            counter_factory, state_sampler)
     # Combiners do not accept deferred side-inputs (the ignored fourth argument)
     # and therefore the code to handle the extra args/kwargs is simpler than for
     # the DoFn's of ParDo.
     fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
     self.phased_combine_fn = (PhasedCombineFnExecutor(
         self.spec.phase, fn, args, kwargs))
Example #26
def _create_pardo_operation(factory,
                            transform_id,
                            transform_proto,
                            consumers,
                            serialized_fn,
                            side_inputs_proto=None):

    if side_inputs_proto:
        tagged_side_inputs = [
            (tag,
             beam.pvalue.SideInputData.from_runner_api(si, factory.context))
            for tag, si in side_inputs_proto.items()
        ]
        tagged_side_inputs.sort(key=lambda tag_si: int(tag_si[0][4:]))
        side_input_maps = [
            StateBackedSideInputMap(factory.state_handler, transform_id, tag,
                                    si) for tag, si in tagged_side_inputs
        ]
    else:
        side_input_maps = []

    output_tags = list(transform_proto.outputs.keys())

    # Hack to match out prefix injected by dataflow runner.
    def mutate_tag(tag):
        if 'None' in output_tags:
            if tag == 'None':
                return 'out'
            else:
                return 'out_' + tag
        else:
            return tag

    dofn_data = pickler.loads(serialized_fn)
    if not dofn_data[-1]:
        # Windowing not set.
        side_input_tags = side_inputs_proto or ()
        pcoll_id, = [
            pcoll for tag, pcoll in transform_proto.inputs.items()
            if tag not in side_input_tags
        ]
        windowing = factory.context.windowing_strategies.get_by_id(
            factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
        serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing, ))

    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=serialized_fn,
        output_tags=[mutate_tag(tag) for tag in output_tags],
        input=None,
        side_inputs=[],  # Obsoleted by side_input_maps.
        output_coders=[output_coders[tag] for tag in output_tags])
    return factory.augment_oldstyle_op(
        operations.DoOperation(transform_proto.unique_name, spec,
                               factory.counter_factory, factory.state_sampler,
                               side_input_maps), transform_proto.unique_name,
        consumers, output_tags)
Example #27
 def __init__(self, operation_name, spec, counter_factory, state_sampler):
   super(CombineOperation, self).__init__(
       operation_name, spec, counter_factory, state_sampler)
   # Combiners do not accept deferred side-inputs (the ignored fourth argument)
   # and therefore the code to handle the extra args/kwargs is simpler than for
   # the DoFn's of ParDo.
   fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
   self.phased_combine_fn = (
       PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
Example #28
 def register_pickle_urn(cls, pickle_urn):
     """Registers and implements the given urn via pickling.
 """
     inspect.currentframe().f_back.f_locals['to_runner_api_parameter'] = (
         lambda self, context:
         (pickle_urn, wrappers_pb2.BytesValue(value=pickler.dumps(self))))
     cls.register_urn(
         pickle_urn, wrappers_pb2.BytesValue,
         lambda proto, unused_context: pickler.loads(proto.value))
Example #29
def create(factory, transform_id, transform_proto, parameter, consumers):
    source = pickler.loads(parameter.value)
    spec = operation_specs.WorkerRead(
        iobase.SourceBundle(1.0, source, None, None),
        [WindowedValueCoder(source.default_output_coder())])
    return factory.augment_oldstyle_op(
        operations.ReadOperation(transform_proto.unique_name, spec,
                                 factory.counter_factory,
                                 factory.state_sampler),
        transform_proto.unique_name, consumers)
Example #30
def create(factory, transform_id, transform_proto, parameter, consumers):
    dofn_data = pickler.loads(parameter)
    if len(dofn_data) == 2:
        # Has side input data.
        serialized_fn, side_input_data = dofn_data
    else:
        # No side input data.
        serialized_fn, side_input_data = parameter, []
    return _create_pardo_operation(factory, transform_id, transform_proto,
                                   consumers, serialized_fn, side_input_data)
Example #31
File: urns.py Project: gamars/beam
 def register_pickle_urn(cls, pickle_urn):
   """Registers and implements the given urn via pickling.
   """
   inspect.currentframe().f_back.f_locals['to_runner_api_parameter'] = (
       lambda self, context: (
           pickle_urn, wrappers_pb2.BytesValue(value=pickler.dumps(self))))
   cls.register_urn(
       pickle_urn,
       wrappers_pb2.BytesValue,
       lambda proto, unused_context: pickler.loads(proto.value))
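Examples #28 and #31 show register_pickle_urn, which wires pickling into the to/from_runner_api machinery. Written out without the frame introspection, the two directions reduce to the pair below (a sketch under the assumption that pickler is apache_beam.internal.pickler; the standalone function names are illustrative):

from google.protobuf import wrappers_pb2
from apache_beam.internal import pickler  # assumed import path

def to_runner_api_parameter(obj, pickle_urn):
  # The payload is simply the pickled object wrapped in a BytesValue.
  return pickle_urn, wrappers_pb2.BytesValue(value=pickler.dumps(obj))

def from_runner_api_parameter(proto):
  # The reverse direction just unpickles the payload.
  return pickler.loads(proto.value)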
Example #32
    def run_pipeline(self, pipeline):
        docker_image = (
            pipeline.options.view_as(PortableOptions).harness_docker_image
            or self.default_docker_image())
        job_endpoint = pipeline.options.view_as(PortableOptions).job_endpoint
        if not job_endpoint:
            raise ValueError(
                'job_endpoint should be provided while creating runner.')

        proto_context = pipeline_context.PipelineContext(
            default_environment_url=docker_image)
        proto_pipeline = pipeline.to_runner_api(context=proto_context)

        if not self.is_embedded_fnapi_runner:
            # Java has different expectations about coders
            # (windowed in Fn API, but *un*windowed in runner API), whereas the
            # embedded FnApiRunner treats them consistently, so we must guard this
            # for now, until FnApiRunner is fixed.
            # See also BEAM-2717.
            for pcoll in proto_pipeline.components.pcollections.values():
                if pcoll.coder_id not in proto_context.coders:
                    # This is not really a coder id, but a pickled coder.
                    coder = coders.registry.get_coder(
                        pickler.loads(pcoll.coder_id))
                    pcoll.coder_id = proto_context.coders.get_id(coder)
            proto_context.coders.populate_map(proto_pipeline.components.coders)

        # Some runners won't detect the GroupByKey transform unless it has no
        # subtransforms.  Remove all sub-transforms until BEAM-4605 is resolved.
        for _, transform_proto in list(
                proto_pipeline.components.transforms.items()):
            if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
                for sub_transform in transform_proto.subtransforms:
                    del proto_pipeline.components.transforms[sub_transform]
                del transform_proto.subtransforms[:]

        job_service = beam_job_api_pb2_grpc.JobServiceStub(
            grpc.insecure_channel(job_endpoint))
        prepare_response = job_service.Prepare(
            beam_job_api_pb2.PrepareJobRequest(job_name='job',
                                               pipeline=proto_pipeline))
        if prepare_response.artifact_staging_endpoint.url:
            stager = portable_stager.PortableStager(
                grpc.insecure_channel(
                    prepare_response.artifact_staging_endpoint.url),
                prepare_response.staging_session_token)
            retrieval_token, _ = stager.stage_job_resources(
                pipeline._options, staging_location='')
        else:
            retrieval_token = None
        run_response = job_service.Run(
            beam_job_api_pb2.RunJobRequest(
                preparation_id=prepare_response.preparation_id,
                retrieval_token=retrieval_token))
        return PipelineResult(job_service, run_response.job_id)
Example #33
def create(factory, transform_id, transform_proto, parameter, consumers):
    # The Dataflow runner harness strips the base64 encoding.
    source = pickler.loads(base64.b64encode(parameter))
    spec = operation_specs.WorkerRead(
        iobase.SourceBundle(1.0, source, None, None),
        [WindowedValueCoder(source.default_output_coder())])
    return factory.augment_oldstyle_op(
        operations.ReadOperation(transform_proto.unique_name, spec,
                                 factory.counter_factory,
                                 factory.state_sampler),
        transform_proto.unique_name, consumers)
Example #34
    def start(self):
        with self.scoped_start_state:
            super(DoOperation, self).start()

            # See fn_data in dataflow_runner.py
            fn, args, kwargs, tags_and_types, window_fn = (pickler.loads(
                self.spec.serialized_fn))

            state = common.DoFnState(self.counter_factory)
            state.step_name = self.step_name

            # TODO(silviuc): What is the proper label here? PCollection being
            # processed?
            context = common.DoFnContext('label', state=state)
            # Tag to output index map used to dispatch the side output values emitted
            # by the DoFn function to the appropriate receivers. The main output is
            # tagged with None and is associated with its corresponding index.
            self.tagged_receivers = _TaggedReceivers(self.counter_factory,
                                                     self.step_name)

            output_tag_prefix = PropertyNames.OUT + '_'
            for index, tag in enumerate(self.spec.output_tags):
                if tag == PropertyNames.OUT:
                    original_tag = None
                elif tag.startswith(output_tag_prefix):
                    original_tag = tag[len(output_tag_prefix):]
                else:
                    raise ValueError(
                        'Unexpected output name for operation: %s' % tag)
                self.tagged_receivers[original_tag] = self.receivers[index]

            if self.side_input_maps is None:
                if tags_and_types:
                    self.side_input_maps = list(
                        self._read_side_inputs(tags_and_types))
                else:
                    self.side_input_maps = []

            self.dofn_runner = common.DoFnRunner(
                fn,
                args,
                kwargs,
                self.side_input_maps,
                window_fn,
                context,
                self.tagged_receivers,
                logger,
                self.step_name,
                scoped_metrics_container=self.scoped_metrics_container)
            self.dofn_receiver = (self.dofn_runner if isinstance(
                self.dofn_runner, Receiver) else DoFnRunnerReceiver(
                    self.dofn_runner))

            self.dofn_runner.start()
Example #35
  def run(self, pipeline):
    """Execute test pipeline and verify test matcher"""
    self.result = super(TestDataflowRunner, self).run(pipeline)
    self.result.wait_until_finish()

    options = pipeline.options.view_as(TestOptions)
    if options.on_success_matcher:
      from hamcrest import assert_that as hc_assert_that
      hc_assert_that(self.result, pickler.loads(options.on_success_matcher))

    return self.result
Example #36
def _create_sdf_operation(proxy_dofn, factory, transform_id, transform_proto,
                          parameter, consumers):

    dofn_data = pickler.loads(parameter.do_fn.spec.payload)
    dofn = dofn_data[0]
    restriction_provider = common.DoFnSignature(
        dofn).get_restriction_provider()
    serialized_fn = pickler.dumps((proxy_dofn(dofn, restriction_provider), ) +
                                  dofn_data[1:])
    return _create_pardo_operation(factory, transform_id, transform_proto,
                                   consumers, serialized_fn, parameter)
Example #37
    def setup(self):
        # type: () -> None
        with self.scoped_start_state:
            super(DoOperation, self).setup()

            # See fn_data in dataflow_runner.py
            fn, args, kwargs, tags_and_types, window_fn = (pickler.loads(
                self.spec.serialized_fn))

            state = common.DoFnState(self.counter_factory)
            state.step_name = self.name_context.logging_name()

            # Tag to output index map used to dispatch the output values emitted
            # by the DoFn function to the appropriate receivers. The main output is
            # either the only output or the output tagged with 'None' and is
            # associated with its corresponding index.
            self.tagged_receivers = _TaggedReceivers(
                self.counter_factory, self.name_context.logging_name())

            if len(self.spec.output_tags) == 1:
                self.tagged_receivers[None] = self.receivers[0]
                self.tagged_receivers[
                    self.spec.output_tags[0]] = self.receivers[0]
            else:
                for index, tag in enumerate(self.spec.output_tags):
                    self.tagged_receivers[tag] = self.receivers[index]
                    if tag == 'None':
                        self.tagged_receivers[None] = self.receivers[index]

            if self.user_state_context:
                self.timer_specs = {
                    spec.name: spec
                    for spec in userstate.get_dofn_specs(fn)[1]
                }

            if self.side_input_maps is None:
                if tags_and_types:
                    self.side_input_maps = list(
                        self._read_side_inputs(tags_and_types))
                else:
                    self.side_input_maps = []

            self.dofn_runner = common.DoFnRunner(
                fn,
                args,
                kwargs,
                self.side_input_maps,
                window_fn,
                tagged_receivers=self.tagged_receivers,
                step_name=self.name_context.logging_name(),
                state=state,
                user_state_context=self.user_state_context,
                operation_name=self.name_context.metrics_name())
            self.dofn_runner.setup()
Example #38
def create(factory, transform_id, transform_proto, parameter, consumers):
  dofn_data = pickler.loads(parameter)
  if len(dofn_data) == 2:
    # Has side input data.
    serialized_fn, side_input_data = dofn_data
  else:
    # No side input data.
    serialized_fn, side_input_data = parameter.value, []
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, side_input_data)
Example #39
def create(factory, transform_id, transform_proto, parameter, consumers):
    assert parameter.do_fn.spec.urn == urns.PICKLED_DO_FN_INFO
    serialized_fn = parameter.do_fn.spec.payload
    dofn_data = pickler.loads(serialized_fn)
    if len(dofn_data) == 2:
        # Has side input data.
        serialized_fn, side_input_data = dofn_data
    else:
        # No side input data.
        side_input_data = []
    return _create_pardo_operation(factory, transform_id, transform_proto,
                                   consumers, serialized_fn, side_input_data)
Example #40
def _create_sdf_operation(
    proxy_dofn,
    factory, transform_id, transform_proto, parameter, consumers):

  dofn_data = pickler.loads(parameter.do_fn.spec.payload)
  dofn = dofn_data[0]
  restriction_provider = common.DoFnSignature(dofn).get_restriction_provider()
  serialized_fn = pickler.dumps(
      (proxy_dofn(dofn, restriction_provider),) + dofn_data[1:])
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, parameter)
Example #41
def create(factory, transform_id, transform_proto, mapping_fn_spec, consumers):
    assert mapping_fn_spec.spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN
    window_mapping_fn = pickler.loads(mapping_fn_spec.spec.payload)

    class MapWindows(beam.DoFn):
        def process(self, element):
            key, window = element
            return [(key, window_mapping_fn(window))]

    return _create_simple_pardo_operation(factory, transform_id,
                                          transform_proto, consumers,
                                          MapWindows())
Example #42
  def start(self):
    with self.scoped_start_state:
      super(DoOperation, self).start()

      # See fn_data in dataflow_runner.py
      fn, args, kwargs, tags_and_types, window_fn = (
          pickler.loads(self.spec.serialized_fn))

      state = common.DoFnState(self.counter_factory)
      state.step_name = self.name_context.logging_name()

      # Tag to output index map used to dispatch the side output values emitted
      # by the DoFn function to the appropriate receivers. The main output is
      # tagged with None and is associated with its corresponding index.
      self.tagged_receivers = _TaggedReceivers(
          self.counter_factory, self.name_context.logging_name())

      output_tag_prefix = PropertyNames.OUT + '_'
      for index, tag in enumerate(self.spec.output_tags):
        if tag == PropertyNames.OUT:
          original_tag = None
        elif tag.startswith(output_tag_prefix):
          original_tag = tag[len(output_tag_prefix):]
        else:
          raise ValueError('Unexpected output name for operation: %s' % tag)
        self.tagged_receivers[original_tag] = self.receivers[index]

      if self.user_state_context:
        self.user_state_context.update_timer_receivers(self.tagged_receivers)
        self.timer_specs = {
            spec.name: spec
            for spec in userstate.get_dofn_specs(fn)[1]
        }

      if self.side_input_maps is None:
        if tags_and_types:
          self.side_input_maps = list(self._read_side_inputs(tags_and_types))
        else:
          self.side_input_maps = []

      self.dofn_runner = common.DoFnRunner(
          fn, args, kwargs, self.side_input_maps, window_fn,
          tagged_receivers=self.tagged_receivers,
          step_name=self.name_context.logging_name(),
          state=state,
          user_state_context=self.user_state_context,
          operation_name=self.name_context.metrics_name())

      self.dofn_receiver = (self.dofn_runner
                            if isinstance(self.dofn_runner, Receiver)
                            else DoFnRunnerReceiver(self.dofn_runner))

      self.dofn_runner.start()
Example #43
def create(factory, transform_id, transform_proto, mapping_fn_spec, consumers):
  assert mapping_fn_spec.spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN
  window_mapping_fn = pickler.loads(mapping_fn_spec.spec.payload)

  class MapWindows(beam.DoFn):

    def process(self, element):
      key, window = element
      return [(key, window_mapping_fn(window))]

  return _create_simple_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      MapWindows())
Example #44
    def test_create_uses_coder_for_pickling(self):
        coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
        create = Create([_Unpicklable(1), _Unpicklable(2), _Unpicklable(3)])
        unpickled_create = pickler.loads(pickler.dumps(create))
        self.assertEqual(
            sorted(create.values, key=lambda v: v.value),
            sorted(unpickled_create.values, key=lambda v: v.value))

        with self.assertRaises(NotImplementedError):
            # As there is no special coder for Union types, this will fall back to
            # FastPrimitivesCoder, which in turn falls back to pickling.
            create_mixed_types = Create([_Unpicklable(1), 2])
            pickler.dumps(create_mixed_types)
Example #45
def create(factory, transform_id, transform_proto, parameter, consumers):
  assert parameter.do_fn.spec.urn == urns.PICKLED_DO_FN_INFO
  serialized_fn = parameter.do_fn.spec.payload
  dofn_data = pickler.loads(serialized_fn)
  if len(dofn_data) == 2:
    # Has side input data.
    serialized_fn, side_input_data = dofn_data
  else:
    # No side input data.
    side_input_data = []
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, side_input_data)
Example #46
def create(factory, transform_id, transform_proto, parameter, consumers):
  # The Dataflow runner harness strips the base64 encoding.
  source = pickler.loads(base64.b64encode(parameter))
  spec = operation_specs.WorkerRead(
      iobase.SourceBundle(1.0, source, None, None),
      [WindowedValueCoder(source.default_output_coder())])
  return factory.augment_oldstyle_op(
      operations.ReadOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)
Example #47
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (
        pickler.loads(pickler.dumps(transform.dofn))
        if self._perform_dofn_pickle_test else transform.dofn)

    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      kv_type_hint = self._applied_ptransform.inputs[0].element_type
      if kv_type_hint and kv_type_hint != Any:
        coder = coders.registry.get_coder(kv_type_hint)
        self.key_coder = coder.key_coder()
      else:
        self.key_coder = coders.registry.get_coder(Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, all_timer_specs = get_dofn_specs(dofn)
      for timer_spec in all_timer_specs:
        self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.setup()
    self.runner.start()
Example #48
    def run_pipeline(self, pipeline):
        # Java has different expectations about coders
        # (windowed in Fn API, but *un*windowed in runner API), whereas the
        # FnApiRunner treats them consistently, so we must guard this.
        # See also BEAM-2717.
        proto_context = pipeline_context.PipelineContext(
            default_environment_url=self._docker_image)
        proto_pipeline = pipeline.to_runner_api(context=proto_context)
        if self._runner_api_address:
            for pcoll in proto_pipeline.components.pcollections.values():
                if pcoll.coder_id not in proto_context.coders:
                    coder = coders.registry.get_coder(
                        pickler.loads(pcoll.coder_id))
                    pcoll.coder_id = proto_context.coders.get_id(coder)
            proto_context.coders.populate_map(proto_pipeline.components.coders)

        # Some runners won't detect the GroupByKey transform unless it has no
        # subtransforms.  Remove all sub-transforms until BEAM-4605 is resolved.
        for _, transform_proto in list(
                proto_pipeline.components.transforms.items()):
            if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
                for sub_transform in transform_proto.subtransforms:
                    del proto_pipeline.components.transforms[sub_transform]
                del transform_proto.subtransforms[:]

        job_service = self._create_job_service()
        prepare_response = job_service.Prepare(
            beam_job_api_pb2.PrepareJobRequest(job_name='job',
                                               pipeline=proto_pipeline))
        if prepare_response.artifact_staging_endpoint.url:
            # Must commit something to get a retrieval token,
            # committing empty manifest for now.
            # TODO(BEAM-3883): Actually stage required files.
            artifact_service = beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(
                grpc.insecure_channel(
                    prepare_response.artifact_staging_endpoint.url))
            commit_manifest = artifact_service.CommitManifest(
                beam_artifact_api_pb2.CommitManifestRequest(
                    manifest=beam_artifact_api_pb2.Manifest(),
                    staging_session_token=prepare_response.
                    staging_session_token))
            retrieval_token = commit_manifest.retrieval_token
        else:
            retrieval_token = None
        run_response = job_service.Run(
            beam_job_api_pb2.RunJobRequest(
                preparation_id=prepare_response.preparation_id,
                retrieval_token=retrieval_token))
        return PipelineResult(job_service, run_response.job_id)
Example #49
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_input_data):
  def create_side_input(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(
            factory.state_handler,
            beam_fn_api_pb2.StateKey.MultimapSideInput(
                key=side_input_tag(transform_id, tag)),
            coder=coder))
  output_tags = list(transform_proto.outputs.keys())

  # Hack to match out prefix injected by dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag
  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    pcoll_id, = transform_proto.inputs.values()
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=[
          create_side_input(tag, coder) for tag, coder in side_input_data],
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers,
      output_tags)
Example #50
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_input_data):
  def create_side_input(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(
            factory.state_handler,
            beam_fn_api_pb2.StateKey.MultimapSideInput(
                key=side_input_tag(transform_id, tag)),
            coder=coder))
  output_tags = list(transform_proto.outputs.keys())

  # Hack to match out prefix injected by dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag
  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    pcoll_id, = transform_proto.inputs.values()
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=[
          create_side_input(tag, coder) for tag, coder in side_input_data],
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers,
      output_tags)
Example #51
  def start(self):
    with self.scoped_start_state:
      super(DoOperation, self).start()

      # See fn_data in dataflow_runner.py
      fn, args, kwargs, tags_and_types, window_fn = (
          pickler.loads(self.spec.serialized_fn))

      state = common.DoFnState(self.counter_factory)
      state.step_name = self.step_name

      # TODO(silviuc): What is the proper label here? PCollection being
      # processed?
      context = common.DoFnContext('label', state=state)
      # Tag to output index map used to dispatch the side output values emitted
      # by the DoFn function to the appropriate receivers. The main output is
      # tagged with None and is associated with its corresponding index.
      self.tagged_receivers = _TaggedReceivers(
          self.counter_factory, self.step_name)

      output_tag_prefix = PropertyNames.OUT + '_'
      for index, tag in enumerate(self.spec.output_tags):
        if tag == PropertyNames.OUT:
          original_tag = None
        elif tag.startswith(output_tag_prefix):
          original_tag = tag[len(output_tag_prefix):]
        else:
          raise ValueError('Unexpected output name for operation: %s' % tag)
        self.tagged_receivers[original_tag] = self.receivers[index]

      if self.side_input_maps is None:
        if tags_and_types:
          self.side_input_maps = list(self._read_side_inputs(tags_and_types))
        else:
          self.side_input_maps = []

      self.dofn_runner = common.DoFnRunner(
          fn, args, kwargs, self.side_input_maps,
          window_fn, context, self.tagged_receivers,
          logger, self.step_name,
          scoped_metrics_container=self.scoped_metrics_container)
      self.dofn_receiver = (self.dofn_runner
                            if isinstance(self.dofn_runner, Receiver)
                            else DoFnRunnerReceiver(self.dofn_runner))

      self.dofn_runner.start()
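
Note: the tag handling in start() is the inverse of the renaming sketched above: 'out' maps back to None (the main output) and 'out_<tag>' maps back to '<tag>'. A minimal sketch, assuming PropertyNames.OUT is the string 'out':

# Minimal sketch of the inverse tag mapping (assumes PropertyNames.OUT == 'out').
OUT = 'out'
OUT_PREFIX = OUT + '_'

def original_tag(tag):
  if tag == OUT:
    return None                      # main output
  if tag.startswith(OUT_PREFIX):
    return tag[len(OUT_PREFIX):]     # tagged side output
  raise ValueError('Unexpected output name for operation: %s' % tag)

assert original_tag('out') is None
assert original_tag('out_errors') == 'errors'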
Ejemplo n.º 52
0
    def run(self, pipeline):
        """Execute test pipeline and verify test matcher"""
        options = pipeline.options.view_as(TestOptions)
        on_success_matcher = options.on_success_matcher

        # [BEAM-1889] Clear this option here so it is not shipped to remote
        # workers; remote executors have no need for it.
        options.on_success_matcher = None

        self.result = super(TestDataflowRunner, self).run(pipeline)
        self.result.wait_until_finish()

        if on_success_matcher:
            from hamcrest import assert_that as hc_assert_that
            hc_assert_that(self.result, pickler.loads(on_success_matcher))

        return self.result
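
Note: the on_success_matcher option carries a pickled hamcrest matcher that is unpickled and asserted against the pipeline result. A small sketch of that round trip; equal_to is only a stand-in for a real pipeline verifier such as PipelineStateMatcher:

# Sketch of the matcher round trip relied on above (stand-in matcher only).
from hamcrest import assert_that as hc_assert_that, equal_to
from apache_beam.internal import pickler

matcher = equal_to('DONE')            # a real test would use a pipeline verifier
serialized = pickler.dumps(matcher)   # this is what on_success_matcher carries
hc_assert_that('DONE', pickler.loads(serialized))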
Ejemplo n.º 53
0
 def __init__(self,
              name_context,
              spec,
              counter_factory,
              state_sampler,
              windowing=None):
     super(PGBKCVOperation, self).__init__(name_context, spec,
                                           counter_factory, state_sampler)
     # Combiners do not accept deferred side-inputs (the ignored fourth
     # argument) and therefore the code to handle the extra args/kwargs is
     # simpler than for the DoFns of ParDo.
     fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
     self.combine_fn = curry_combine_fn(fn, args, kwargs)
     self.combine_fn_add_input = self.combine_fn.add_input
     base_compact = (core.CombineFn.compact if sys.version_info >=
                     (3, ) else core.CombineFn.compact.__func__)
     if self.combine_fn.compact.__func__ is base_compact:
         self.combine_fn_compact = None
     else:
         self.combine_fn_compact = self.combine_fn.compact
     if windowing:
         self.is_default_windowing = windowing.is_default()
         tsc_type = windowing.timestamp_combiner
         self.timestamp_combiner = (
             None if tsc_type == window.TimestampCombiner.OUTPUT_AT_EOW else
             window.TimestampCombiner.get_impl(tsc_type,
                                               windowing.windowfn))
     else:
         self.is_default_windowing = False  # unknown
         self.timestamp_combiner = None
     # Optimization for the (known tiny accumulator, often wide keyspace)
     # combine functions.
     # TODO(b/36567833): Bound by in-memory size rather than key count.
     self.max_keys = (
         1000 * 1000 if
         isinstance(fn,
                    (combiners.CountCombineFn, combiners.MeanCombineFn)) or
         # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
         # combiners to the short list above.
         (isinstance(fn, core.CallableWrapperCombineFn)
          and fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
     self.key_count = 0
     self.table = {}
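
Note: the base_compact comparison above is a general "was this method overridden?" check. A toy sketch of the same technique outside Beam (on Python 3 the base attribute is already a plain function, which is why the snippet branches on sys.version_info):

# Toy illustration of the method-override check used for CombineFn.compact.
class Base(object):
  def compact(self, accumulator):
    return accumulator               # default: no-op

class WithCompact(Base):
  def compact(self, accumulator):
    return sorted(accumulator)

def has_custom_compact(obj):
  # obj.compact.__func__ is the function defined on obj's class; comparing it
  # to the base implementation tells us whether compact was overridden.
  return obj.compact.__func__ is not Base.compact

assert not has_custom_compact(Base())
assert has_custom_compact(WithCompact())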
Ejemplo n.º 54
0
  def split(self, desired_bundle_size, start_offset=None, stop_offset=None):
    if start_offset is None:
      start_offset = self._start_offset
    if stop_offset is None:
      stop_offset = self._stop_offset

    if self._splittable:
      bundle_size = max(desired_bundle_size, self._min_bundle_size)

      bundle_start = start_offset
      while bundle_start < stop_offset:
        bundle_stop = min(bundle_start + bundle_size, stop_offset)
        yield iobase.SourceBundle(
            bundle_stop - bundle_start,
            _SingleFileSource(
                # Copying this so that each sub-source gets a fresh instance.
                pickler.loads(pickler.dumps(self._file_based_source)),
                self._file_name,
                bundle_start,
                bundle_stop,
                min_bundle_size=self._min_bundle_size,
                splittable=self._splittable),
            bundle_start,
            bundle_stop)
        bundle_start = bundle_stop
    else:
      # Returning a single sub-source with the end offset set to OFFSET_INFINITY
      # (so that all data of the source gets read) since this source is
      # unsplittable. Using the file size as the end offset would be wrong for
      # certain unsplittable sources, e.g. compressed sources.
      yield iobase.SourceBundle(
          stop_offset - start_offset,
          _SingleFileSource(
              self._file_based_source,
              self._file_name,
              start_offset,
              range_trackers.OffsetRangeTracker.OFFSET_INFINITY,
              min_bundle_size=self._min_bundle_size,
              splittable=self._splittable
          ),
          start_offset,
          range_trackers.OffsetRangeTracker.OFFSET_INFINITY
      )
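
Note: the splittable branch above is plain offset arithmetic: carve [start_offset, stop_offset) into chunks of at most bundle_size bytes. A hypothetical stand-alone version of just that arithmetic:

# Stand-alone sketch of the bundle-offset arithmetic used in split() above.
def bundle_ranges(start_offset, stop_offset, desired_bundle_size,
                  min_bundle_size=0):
  bundle_size = max(desired_bundle_size, min_bundle_size)
  bundle_start = start_offset
  while bundle_start < stop_offset:
    bundle_stop = min(bundle_start + bundle_size, stop_offset)
    yield (bundle_start, bundle_stop)
    bundle_start = bundle_stop

# A 25-byte range split into 10-byte bundles yields three sub-ranges.
assert list(bundle_ranges(0, 25, 10)) == [(0, 10), (10, 20), (20, 25)]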
Ejemplo n.º 55
0
    def _get_concat_source(self):
        if self._concat_source is None:
            pattern = self._pattern.get()

            single_file_sources = []
            if self._file_system is None:
                self._file_system = get_filesystem(pattern)
            match_result = self._file_system.match([pattern])[0]
            files_metadata = match_result.metadata_list

            # We create a reference for FileBasedSource that will be serialized along
            # with each _SingleFileSource. To prevent this FileBasedSource from having
            # a reference to ConcatSource (resulting in quadratic space complexity)
            # we clone it here.
            file_based_source_ref = pickler.loads(pickler.dumps(self))

            for file_metadata in files_metadata:
                file_name = file_metadata.path
                file_size = file_metadata.size_in_bytes
                if file_size == 0:
                    continue  # Ignoring empty file.

                # We determine splittability of this specific file.
                splittable = self.splittable
                if (splittable
                        and self._compression_type == CompressionTypes.AUTO):
                    compression_type = CompressionTypes.detect_compression_type(
                        file_name)
                    if compression_type != CompressionTypes.UNCOMPRESSED:
                        splittable = False

                single_file_source = _SingleFileSource(
                    file_based_source_ref,
                    file_name,
                    0,
                    file_size,
                    min_bundle_size=self._min_bundle_size,
                    splittable=splittable)
                single_file_sources.append(single_file_source)
            self._concat_source = concat_source.ConcatSource(
                single_file_sources)
        return self._concat_source
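
Note: the per-file splittability decision above can be exercised on its own. A sketch assuming CompressionTypes lives at its usual apache_beam.io.filesystem location:

# Sketch of the per-file splittability decision (import path is an assumption).
from apache_beam.io.filesystem import CompressionTypes

def is_splittable(file_name, source_splittable=True,
                  compression_type=CompressionTypes.AUTO):
  splittable = source_splittable
  if splittable and compression_type == CompressionTypes.AUTO:
    detected = CompressionTypes.detect_compression_type(file_name)
    if detected != CompressionTypes.UNCOMPRESSED:
      splittable = False               # e.g. a .gz file cannot be split
  return splittable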
Ejemplo n.º 56
0
  def expand(self, pvalue):
    cache_dict = {}

    for dataset_key in self._dataset_keys:

      dataset_cache_path = os.path.join(self._cache_base_dir,
                                        _make_dataset_key(dataset_key))
      if not tf.io.gfile.isdir(dataset_cache_path):
        continue
      cache_dict[dataset_key] = {}
      with tf.io.gfile.GFile(
          os.path.join(dataset_cache_path, _MANIFEST_FILE_NAME), 'r') as f:
        manifest = pickler.loads(f.read())
      for key, value in six.iteritems(manifest):
        cache_dict[dataset_key][key] = (
            pvalue.pipeline
            | 'ReadCache[{}][{}]'.format(dataset_key, value) >>
            self._source('{}{}'.format(
                os.path.join(dataset_cache_path, str(value)), '-*-of-*')))
    return cache_dict
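
Note: the manifest handled above is just a pickled dict stored on a filesystem that tf.io.gfile can reach. A minimal write/read sketch; the 'MANIFEST' name stands in for the _MANIFEST_FILE_NAME constant used in the snippet:

# Sketch of round-tripping a cache manifest through tf.io.gfile and the pickler.
import os
import tensorflow as tf
from apache_beam.internal import pickler

def write_manifest(cache_dir, manifest, manifest_file='MANIFEST'):
  with tf.io.gfile.GFile(os.path.join(cache_dir, manifest_file), 'wb') as f:
    f.write(pickler.dumps(manifest))

def read_manifest(cache_dir, manifest_file='MANIFEST'):
  with tf.io.gfile.GFile(os.path.join(cache_dir, manifest_file), 'rb') as f:
    return pickler.loads(f.read())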
Ejemplo n.º 57
0
 def check_coder(self, coder, *values, **kwargs):
   context = kwargs.pop('context', pipeline_context.PipelineContext())
   test_size_estimation = kwargs.pop('test_size_estimation', True)
   assert not kwargs
   self._observe(coder)
   for v in values:
     self.assertEqual(v, coder.decode(coder.encode(v)))
     if test_size_estimation:
       self.assertEqual(coder.estimate_size(v),
                        len(coder.encode(v)))
       self.assertEqual(coder.estimate_size(v),
                        coder.get_impl().estimate_size(v))
       self.assertEqual(coder.get_impl().get_estimated_size_and_observables(v),
                        (coder.get_impl().estimate_size(v), []))
   copy1 = pickler.loads(pickler.dumps(coder))
   copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
   for v in values:
     self.assertEqual(v, copy1.decode(copy2.encode(v)))
     if coder.is_deterministic():
       self.assertEqual(copy1.encode(v), copy2.encode(v))
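
Note: a concrete instance of the round trips this helper checks, using a standard Beam coder (VarIntCoder is just a convenient example):

# Concrete example of the coder round trips exercised by check_coder above.
from apache_beam import coders
from apache_beam.internal import pickler

coder = coders.VarIntCoder()
value = 42

assert coder.decode(coder.encode(value)) == value               # encode/decode
assert coder.estimate_size(value) == len(coder.encode(value))   # size estimate
cloned = pickler.loads(pickler.dumps(coder))                    # survives pickling
assert cloned.decode(coder.encode(value)) == value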
Ejemplo n.º 58
0
    def validate_test_matcher(self, view, arg_name):
        """Validates the on_success_matcher argument, if it is set.

        Checks that on_success_matcher can be unpickled and that the result is
        an instance of `hamcrest.core.base_matcher.BaseMatcher`.
        """
        # This is a test only method and requires hamcrest
        from hamcrest.core.base_matcher import BaseMatcher
        pickled_matcher = view.on_success_matcher
        errors = []
        try:
            matcher = pickler.loads(pickled_matcher)
            if not isinstance(matcher, BaseMatcher):
                errors.extend(
                    self._validate_error(self.ERR_INVALID_TEST_MATCHER_TYPE,
                                         matcher, arg_name))
        except:  # pylint: disable=bare-except
            errors.extend(
                self._validate_error(self.ERR_INVALID_TEST_MATCHER_UNPICKLABLE,
                                     pickled_matcher, arg_name))
        return errors
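
Note: a condensed stand-alone version of the same validation, treating anything that fails to unpickle as invalid (the helper name is illustrative only):

# Illustrative stand-alone version of the matcher validation above.
from hamcrest.core.base_matcher import BaseMatcher
from apache_beam.internal import pickler

def is_valid_pickled_matcher(pickled_matcher):
  try:
    return isinstance(pickler.loads(pickled_matcher), BaseMatcher)
  except Exception:  # unpicklable payloads are treated as invalid
    return False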
Ejemplo n.º 59
0
 def __init__(self, name_context, spec, counter_factory, state_sampler):
   super(PGBKCVOperation, self).__init__(
       name_context, spec, counter_factory, state_sampler)
   # Combiners do not accept deferred side-inputs (the ignored fourth
   # argument) and therefore the code to handle the extra args/kwargs is
   # simpler than for the DoFns of ParDo.
   fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
   self.combine_fn = curry_combine_fn(fn, args, kwargs)
   self.combine_fn_add_input = self.combine_fn.add_input
   # Optimization for the (known tiny accumulator, often wide keyspace)
   # combine functions.
   # TODO(b/36567833): Bound by in-memory size rather than key count.
   self.max_keys = (
       1000 * 1000 if
       isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
       # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
       # combiners to the short list above.
       (isinstance(fn, core.CallableWrapperCombineFn) and
        fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
   self.key_count = 0
   self.table = {}
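
Note: the flushing that max_keys bounds happens in PGBKCVOperation.process, which is not part of this snippet. The sketch below is a simplified, hypothetical illustration of how a key-count cap keeps such an in-memory combining table bounded:

# Hypothetical, simplified illustration of a key-count-bounded combining table.
from apache_beam.transforms import combiners

class BoundedCombineTable(object):
  def __init__(self, combine_fn, max_keys):
    self.combine_fn = combine_fn
    self.max_keys = max_keys
    self.table = {}

  def add(self, key, value, emit):
    if key not in self.table:
      if len(self.table) >= self.max_keys:
        self.flush(emit)             # cap reached: emit partial results early
      self.table[key] = self.combine_fn.create_accumulator()
    self.table[key] = self.combine_fn.add_input(self.table[key], value)

  def flush(self, emit):
    for key, accumulator in self.table.items():
      emit((key, self.combine_fn.extract_output(accumulator)))
    self.table = {}

out = []
table = BoundedCombineTable(combiners.CountCombineFn(), max_keys=2)
for k in ['a', 'b', 'a', 'c']:       # the third distinct key triggers a flush
  table.add(k, 1, out.append)
table.flush(out.append)
# out now holds per-key (possibly partial) counts: [('a', 2), ('b', 1), ('c', 1)]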