def test_test_matcher(self):
    def get_validator(matcher):
      options = ['--project=example:example',
                 '--job_name=job',
                 '--staging_location=gs://foo/bar',
                 '--temp_location=gs://foo/bar',]
      if matcher:
        options.append('--on_success_matcher=' + matcher)

      pipeline_options = PipelineOptions(options)
      runner = MockRunners.TestDataflowRunner()
      return PipelineOptionsValidator(pipeline_options, runner)

    test_case = [
        {'on_success_matcher': None,
         'errors': []},
        {'on_success_matcher': pickler.dumps(AlwaysPassMatcher()),
         'errors': []},
        {'on_success_matcher': 'abc',
         'errors': ['on_success_matcher']},
        {'on_success_matcher': pickler.dumps(object),
         'errors': ['on_success_matcher']},
    ]

    for case in test_case:
      errors = get_validator(case['on_success_matcher']).validate()
      self.assertEqual(
          self.check_errors_for_arguments(errors, case['errors']), [])
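The validator above only accepts an --on_success_matcher value that unpickles to a hamcrest matcher. A minimal sketch of how such a value is produced and consumed, assuming apache_beam with its test extras is installed (PipelineStateMatcher stands in for any matcher):

from apache_beam.internal import pickler
from apache_beam.testing.pipeline_verifiers import PipelineStateMatcher

# Serialize the matcher into an option value (dumps returns base64 bytes
# under Python 3, hence the decode)...
serialized_matcher = pickler.dumps(PipelineStateMatcher())
option = '--on_success_matcher=' + serialized_matcher.decode('ascii')
# ...and recover it the way the test runner would.
matcher = pickler.loads(serialized_matcher)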
Example #2
  def __init__(self, fn, *args, **kwargs):
    if isinstance(fn, type) and issubclass(fn, WithTypeHints):
      # Don't treat Fn class objects as callables.
      raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__))
    self.fn = self.make_fn(fn)
    # Now that we've figured out the label, initialize the super-class.
    super(PTransformWithSideInputs, self).__init__()

    if (any(isinstance(v, pvalue.PCollection) for v in args) or
        any(isinstance(v, pvalue.PCollection) for v in kwargs.values())):
      raise error.SideInputError(
          'PCollection used directly as side input argument. Specify '
          'AsIter(pcollection) or AsSingleton(pcollection) to indicate how the '
          'PCollection is to be used.')
    self.args, self.kwargs, self.side_inputs = util.remove_objects_from_args(
        args, kwargs, pvalue.AsSideInput)
    self.raw_side_inputs = args, kwargs

    # Prevent name collisions with fns of the form '<function <lambda> at ...>'
    self._cached_fn = self.fn

    # Ensure fn and side inputs are picklable for remote execution.
    self.fn = pickler.loads(pickler.dumps(self.fn))
    self.args = pickler.loads(pickler.dumps(self.args))
    self.kwargs = pickler.loads(pickler.dumps(self.kwargs))

    # For type hints, because loads(dumps(class)) != class.
    self.fn = self._cached_fn
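The loads(dumps(...)) round trips above are a fail-fast check: anything that cannot survive Beam's dill-based pickler at construction time would also fail during remote execution. A small illustration of the same idiom (check_picklable is just a local helper for this sketch, not a Beam API):

from apache_beam.internal import pickler

def check_picklable(value):
  # Raises at construction time if `value` cannot be pickled for shipping
  # to workers; returns an equivalent, freshly unpickled copy otherwise.
  return pickler.loads(pickler.dumps(value))

copied_args = check_picklable((lambda x: x + 1, {'threshold': 3}))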
Example #3
 def test_nested_class(self):
   """Tests that a nested class object is pickled correctly."""
   self.assertEqual(
       'X:abc',
       loads(dumps(module_test.TopClass.NestedClass('abc'))).datum)
   self.assertEqual(
       'Y:abc',
       loads(dumps(module_test.TopClass.MiddleClass.NestedClass('abc'))).datum)
Example #4
 def to_runner_api(self, unused_context):
   return beam_runner_api_pb2.SideInput(
       access_pattern=beam_runner_api_pb2.FunctionSpec(
           urn=self.access_pattern),
       view_fn=beam_runner_api_pb2.SdkFunctionSpec(
           spec=beam_runner_api_pb2.FunctionSpec(
               urn=python_urns.PICKLED_VIEWFN,
               payload=pickler.dumps((self.view_fn, self.coder)))),
       window_mapping_fn=beam_runner_api_pb2.SdkFunctionSpec(
           spec=beam_runner_api_pb2.FunctionSpec(
               urn=python_urns.PICKLED_WINDOW_MAPPING_FN,
               payload=pickler.dumps(self.window_mapping_fn))))
Example #5
 def to_runner_api(self, context):
   return beam_runner_api_pb2.SideInput(
       access_pattern=beam_runner_api_pb2.FunctionSpec(
           urn=self.access_pattern),
       view_fn=beam_runner_api_pb2.SdkFunctionSpec(
           environment_id=context.default_environment_id(),
           spec=beam_runner_api_pb2.FunctionSpec(
               urn=python_urns.PICKLED_VIEWFN,
               payload=pickler.dumps(self.view_fn))),
       window_mapping_fn=beam_runner_api_pb2.SdkFunctionSpec(
           environment_id=context.default_environment_id(),
           spec=beam_runner_api_pb2.FunctionSpec(
               urn=python_urns.PICKLED_WINDOW_MAPPING_FN,
               payload=pickler.dumps(self.window_mapping_fn))))
Example #6
  def run_CombineValues(self, transform_node):
    transform = transform_node.transform
    input_tag = transform_node.inputs[0].tag
    input_step = self._cache.get_pvalue(transform_node.inputs[0])
    step = self._add_step(
        TransformNames.COMBINE, transform_node.full_label, transform_node)
    # Combiner functions do not take deferred side inputs (i.e. PValues), so
    # the code to handle extra args/kwargs is simpler than for the DoFns of
    # the ParDo transform. The last, empty argument is where side input
    # information would go.
    fn_data = (transform.fn, transform.args, transform.kwargs, ())
    step.add_property(PropertyNames.SERIALIZED_FN,
                      pickler.dumps(fn_data))
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
    # Note that the accumulator must not have a WindowedValue encoding, while
    # the output of this step does in fact have a WindowedValue encoding.
    accumulator_encoding = self._get_cloud_encoding(
        transform_node.transform.fn.get_accumulator_coder())
    output_encoding = self._get_encoded_output_coder(transform_node)

    step.encoding = output_encoding
    step.add_property(PropertyNames.ENCODING, accumulator_encoding)
    # Generate description for main output 'out'.
    outputs = []
    # Add the main output to the description.
    outputs.append(
        {PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
         PropertyNames.ENCODING: step.encoding,
         PropertyNames.OUTPUT_NAME: PropertyNames.OUT})
    step.add_property(PropertyNames.OUTPUT_INFO, outputs)
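The SERIALIZED_FN property written above is just a pickled (combine_fn, args, kwargs, side_inputs) tuple. A minimal round-trip sketch of that payload format (CountCombineFn is only a convenient stand-in):

from apache_beam.internal import pickler
from apache_beam.transforms.combiners import CountCombineFn

fn_data = (CountCombineFn(), (), {}, ())   # last slot: side input info (empty here)
payload = pickler.dumps(fn_data)
combine_fn, args, kwargs, side_inputs = pickler.loads(payload)
assert isinstance(combine_fn, CountCombineFn)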
Example #7
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = pickler.loads(pickler.dumps(transform.dofn))

    pipeline_options = self._evaluation_context.pipeline_options
    if (pipeline_options is not None
        and pipeline_options.view_as(TypeOptions).runtime_type_check):
      dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
    self.runner = DoFnRunner(
        dofn, transform.args, transform.kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner.start()
Example #8
  def _get_concat_source(self):
    if self._concat_source is None:
      pattern = self._pattern.get()

      single_file_sources = []
      match_result = FileSystems.match([pattern])[0]
      files_metadata = match_result.metadata_list

      # We create a reference for FileBasedSource that will be serialized along
      # with each _SingleFileSource. To prevent this FileBasedSource from having
      # a reference to ConcatSource (resulting in quadratic space complexity)
      # we clone it here.
      file_based_source_ref = pickler.loads(pickler.dumps(self))

      for file_metadata in files_metadata:
        file_name = file_metadata.path
        file_size = file_metadata.size_in_bytes
        if file_size == 0:
          continue  # Ignoring empty file.

        # We determine splittability of this specific file.
        splittable = (
            self.splittable and
            _determine_splittability_from_compression_type(
                file_name, self._compression_type))

        single_file_source = _SingleFileSource(
            file_based_source_ref, file_name,
            0,
            file_size,
            min_bundle_size=self._min_bundle_size,
            splittable=splittable)
        single_file_sources.append(single_file_source)
      self._concat_source = concat_source.ConcatSource(single_file_sources)
    return self._concat_source
Example #9
 def test_get_coder_with_composite_custom_coder(self):
   typecoders.registry.register_coder(CustomClass, CustomCoder)
   coder = typecoders.registry.get_coder(typehints.KV[CustomClass, str])
   revived_coder = pickler.loads(pickler.dumps(coder))
   self.assertEqual(
       (CustomClass(123), 'abc'),
       revived_coder.decode(revived_coder.encode((CustomClass(123), 'abc'))))
Example #10
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_inputs_proto=None):

  if side_inputs_proto:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in side_inputs_proto.items()]
    tagged_side_inputs.sort(key=lambda tag_si: int(tag_si[0][4:]))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler,
            transform_id,
            tag,
            si,
            input_tags_to_coders[tag])
        for tag, si in tagged_side_inputs]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    side_input_tags = side_inputs_proto or ()
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in side_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps),
      transform_proto.unique_name,
      consumers,
      output_tags)
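The serialized_fn handled above is a pickled (dofn, args, kwargs, side_inputs, windowing) tuple; when the last slot is empty, the code re-pickles the tuple with the main input's windowing strategy filled in. A rough stand-alone sketch of that rewrite (_IdentityDoFn and the global windowing are stand-ins for whatever the pipeline graph actually carries):

import apache_beam as beam
from apache_beam.internal import pickler
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.window import GlobalWindows

class _IdentityDoFn(beam.DoFn):
  def process(self, element):
    yield element

serialized_fn = pickler.dumps((_IdentityDoFn(), (), {}, [], None))  # windowing unset
dofn_data = pickler.loads(serialized_fn)
if not dofn_data[-1]:
  # Same patch-up as above, with a fixed windowing instead of a looked-up one.
  serialized_fn = pickler.dumps(dofn_data[:-1] + (Windowing(GlobalWindows()),))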
Example #11
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (pickler.loads(pickler.dumps(transform.dofn))
            if self._perform_dofn_pickle_test else transform.dofn)

    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.runner = DoFnRunner(
        dofn, args, kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner.start()
Example #12
  def _get_concat_source(self):
    if self._concat_source is None:
      single_file_sources = []
      file_names = [f for f in fileio.ChannelFactory.glob(self._pattern)]
      sizes = FileBasedSource._estimate_sizes_of_files(file_names,
                                                       self._pattern)

      # We create a reference for FileBasedSource that will be serialized along
      # with each _SingleFileSource. To prevent this FileBasedSource from having
      # a reference to ConcatSource (resulting in quadratic space complexity)
      # we clone it here.
      file_based_source_ref = pickler.loads(pickler.dumps(self))

      for index, file_name in enumerate(file_names):
        if sizes[index] == 0:
          continue  # Ignoring empty file.

        # We determine splittability of this specific file.
        splittable = self.splittable
        if (splittable and
            self._compression_type == fileio.CompressionTypes.AUTO):
          compression_type = fileio.CompressionTypes.detect_compression_type(
              file_name)
          if compression_type != fileio.CompressionTypes.UNCOMPRESSED:
            splittable = False

        single_file_source = _SingleFileSource(
            file_based_source_ref, file_name,
            0,
            sizes[index],
            min_bundle_size=self._min_bundle_size,
            splittable=splittable)
        single_file_sources.append(single_file_source)
      self._concat_source = concat_source.ConcatSource(single_file_sources)
    return self._concat_source
Example #13
def _create_simple_pardo_operation(
    factory, transform_id, transform_proto, consumers, dofn):
  serialized_fn = pickler.dumps((dofn, (), {}, [], None))
  side_input_data = []
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, side_input_data)
Example #14
  def run_ParDo(self, transform_node):
    transform = transform_node.transform
    input_tag = transform_node.inputs[0].tag
    input_step = self._cache.get_pvalue(transform_node.inputs[0])

    # Attach side inputs.
    si_dict = {}
    # We must call self._cache.get_pvalue exactly once due to refcounting.
    si_labels = {}
    for side_pval in transform_node.side_inputs:
      si_labels[side_pval] = self._cache.get_pvalue(side_pval).step_name
    lookup_label = lambda side_pval: si_labels[side_pval]
    for side_pval in transform_node.side_inputs:
      assert isinstance(side_pval, PCollectionView)
      si_label = lookup_label(side_pval)
      si_dict[si_label] = {
          '@type': 'OutputReference',
          PropertyNames.STEP_NAME: si_label,
          PropertyNames.OUTPUT_NAME: PropertyNames.OUT}

    # Now create the step for the ParDo transform being handled.
    step = self._add_step(
        TransformNames.DO, transform_node.full_label, transform_node,
        transform_node.transform.side_output_tags)
    fn_data = self._pardo_fn_data(transform_node, lookup_label)
    step.add_property(PropertyNames.SERIALIZED_FN, pickler.dumps(fn_data))
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
    # Add side inputs if any.
    step.add_property(PropertyNames.NON_PARALLEL_INPUTS, si_dict)

    # Generate description for main output and side outputs. The output names
    # will be 'out' for main output and 'out_<tag>' for a tagged output.
    # Using 'out' as a tag will not clash with the name for main since it will
    # be transformed into 'out_out' internally.
    outputs = []
    step.encoding = self._get_encoded_output_coder(transform_node)

    # Add the main output to the description.
    outputs.append(
        {PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
         PropertyNames.ENCODING: step.encoding,
         PropertyNames.OUTPUT_NAME: PropertyNames.OUT})
    for side_tag in transform.side_output_tags:
      # The assumption here is that side outputs will have the same typehint
      # and coder as the main output. This is certainly the case right now
      # but conceivably it could change in the future.
      outputs.append(
          {PropertyNames.USER_NAME: (
              '%s.%s' % (transform_node.full_label, side_tag)),
           PropertyNames.ENCODING: step.encoding,
           PropertyNames.OUTPUT_NAME: (
               '%s_%s' % (PropertyNames.OUT, side_tag))})

    step.add_property(PropertyNames.OUTPUT_INFO, outputs)
Example #15
 def to_runner_api(self, context):
   return beam_runner_api_pb2.PCollection(
       unique_name='%d%s.%s' % (
           len(self.producer.full_label), self.producer.full_label, self.tag),
       coder_id=pickler.dumps(self.element_type),
       is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED,
       windowing_strategy_id=context.windowing_strategies.get_id(
           self.windowing))
Example #16
 def visit_transform(self, transform_node):
   try:
     # Transforms must be picklable.
     pickler.loads(pickler.dumps(transform_node.transform,
                                 enable_trace=False),
                   enable_trace=False)
   except Exception:
     Visitor.ok = False
Example #17
  def test_lambda_with_globals(self):
    """Tests that the globals of a function are preserved."""

    # The point of the test is that the lambda, when called after unpickling,
    # relies on the re module having been loaded.
    self.assertEqual(
        ['abc', 'def'],
        loads(dumps(module_test.get_lambda_with_globals()))('abc def'))
Example #18
 def run_PartialGroupByKeyCombineValues(self, transform_node):
   element_coder = self._get_coder(transform_node.outputs[None])
   _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
   combine_op = operation_specs.WorkerPartialGroupByKey(
       combine_fn=pickler.dumps(
           (transform_node.transform.combine_fn, (), {}, ())),
       output_coders=[element_coder],
       input=(producer_index, output_index))
   self._run_as_op(transform_node, combine_op)
Example #19
 def visit_transform(self, transform_node):
   if transform_node.side_inputs:
     # No side inputs (yet).
     Visitor.ok = False
   try:
     # Transforms must be picklable.
     pickler.loads(pickler.dumps(transform_node.transform))
   except Exception:
     Visitor.ok = False
Example #20
File: urns.py Project: gamars/beam
 def register_pickle_urn(cls, pickle_urn):
   """Registers and implements the given urn via pickling.
   """
   inspect.currentframe().f_back.f_locals['to_runner_api_parameter'] = (
       lambda self, context: (
           pickle_urn, wrappers_pb2.BytesValue(value=pickler.dumps(self))))
   cls.register_urn(
       pickle_urn,
       wrappers_pb2.BytesValue,
       lambda proto, unused_context: pickler.loads(proto.value))
Example #21
 def to_runner_api(self, context):
   from apache_beam.runners.api import beam_runner_api_pb2
   from apache_beam.internal import pickler
   return beam_runner_api_pb2.PCollection(
       unique_name='%d%s.%s' % (
           len(self.producer.full_label), self.producer.full_label, self.tag),
       coder_id=pickler.dumps(self.element_type),
       is_bounded=beam_runner_api_pb2.BOUNDED,
       windowing_strategy_id=context.windowing_strategies.get_id(
           self.windowing))
Example #22
 def _run_combine_transform(self, transform_node, phase):
   transform = transform_node.transform
   element_coder = self._get_coder(transform_node.outputs[None])
   _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
   combine_op = operation_specs.WorkerCombineFn(
       serialized_fn=pickler.dumps(
           (transform.combine_fn, (), {}, ())),
       phase=phase,
       output_coders=[element_coder],
       input=(producer_index, output_index))
   self._run_as_op(transform_node, combine_op)
Example #23
def _create_sdf_operation(
    proxy_dofn,
    factory, transform_id, transform_proto, parameter, consumers):

  dofn_data = pickler.loads(parameter.do_fn.spec.payload)
  dofn = dofn_data[0]
  restriction_provider = common.DoFnSignature(dofn).get_restriction_provider()
  serialized_fn = pickler.dumps(
      (proxy_dofn(dofn, restriction_provider),) + dofn_data[1:])
  return _create_pardo_operation(
      factory, transform_id, transform_proto, consumers,
      serialized_fn, parameter)
Example #24
def create(factory, transform_id, transform_proto, payload, consumers):
  serialized_combine_fn = pickler.dumps(
      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
       [], {}))
  return factory.augment_oldstyle_op(
      operations.PGBKCVOperation(
          transform_proto.unique_name,
          operation_specs.WorkerPartialGroupByKey(
              serialized_combine_fn,
              None,
              [factory.get_only_output_coder(transform_proto)]),
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)
Example #25
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_input_data):
  def create_side_input(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(
            factory.state_handler,
            beam_fn_api_pb2.StateKey.MultimapSideInput(
                key=side_input_tag(transform_id, tag)),
            coder=coder))
  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag
  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    pcoll_id, = transform_proto.inputs.values()
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=[
          create_side_input(tag, coder) for tag, coder in side_input_data],
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers,
      output_tags)
Example #26
 def to_runner_api(self, context):
   from apache_beam.runners.api import beam_runner_api_pb2
   return beam_runner_api_pb2.PTransform(
       unique_name=self.full_label,
       spec=beam_runner_api_pb2.FunctionSpec(
           urn=urns.PICKLED_TRANSFORM,
           parameter=proto_utils.pack_Any(
               wrappers_pb2.BytesValue(value=pickler.dumps(self.transform)))),
       subtransforms=[context.transforms.get_id(part) for part in self.parts],
       # TODO(BEAM-115): Side inputs.
       inputs={tag: context.pcollections.get_id(pc)
               for tag, pc in self.named_inputs().items()},
       outputs={str(tag): context.pcollections.get_id(out)
                for tag, out in self.named_outputs().items()},
       # TODO(BEAM-115): display_data
       display_data=None)
Example #27
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (pickler.loads(pickler.dumps(transform.dofn))
            if self._perform_dofn_pickle_test else transform.dofn)

    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      kv_type_hint = self._applied_ptransform.inputs[0].element_type
      if kv_type_hint and kv_type_hint != typehints.Any:
        coder = coders.registry.get_coder(kv_type_hint)
        self.key_coder = coder.key_coder()
      else:
        self.key_coder = coders.registry.get_coder(typehints.Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, all_timer_specs = get_dofn_specs(dofn)
      for timer_spec in all_timer_specs:
        self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

    self.runner = DoFnRunner(
        dofn, args, kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.start()
Example #28
  def get_full_options_as_args(self, **extra_opts):
    """Get full pipeline options as an argument list.

    Appends extra pipeline options to the existing option list if provided.
    A test verifier (if present in the extra options) is pickled before being
    appended and will be unpickled later in the TestRunner.
    """
    options = list(self.options_list)
    for k, v in extra_opts.items():
      if not v:
        continue
      elif isinstance(v, bool) and v:
        options.append('--%s' % k)
      elif 'matcher' in k:
        options.append('--%s=%s' % (k, pickler.dumps(v)))
      else:
        options.append('--%s=%s' % (k, v))
    return options
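A rough usage sketch of the helper above, assuming it lives on a TestPipeline-style object such as apache_beam.testing.test_pipeline.TestPipeline:

from apache_beam.testing.pipeline_verifiers import PipelineStateMatcher
from apache_beam.testing.test_pipeline import TestPipeline

test_pipeline = TestPipeline()
argv = test_pipeline.get_full_options_as_args(
    on_success_matcher=PipelineStateMatcher(),  # 'matcher' in key -> pickled value
    streaming=True,                             # truthy bool -> bare flag
    job_name='example-it-job')                  # anything else -> --key=value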
Example #29
def _create_combine_phase_operation(
    factory, transform_proto, payload, consumers, phase):
  # This is where support for combine fn side inputs would go.
  serialized_combine_fn = pickler.dumps(
      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
       [], {}))
  return factory.augment_oldstyle_op(
      operations.CombineOperation(
          transform_proto.unique_name,
          operation_specs.WorkerCombineFn(
              serialized_combine_fn,
              phase,
              None,
              [factory.get_only_output_coder(transform_proto)]),
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)
Example #30
  def split(self, desired_bundle_size, start_offset=None, stop_offset=None):
    if start_offset is None:
      start_offset = self._start_offset
    if stop_offset is None:
      stop_offset = self._stop_offset

    if self._splittable:
      bundle_size = max(desired_bundle_size, self._min_bundle_size)

      bundle_start = start_offset
      while bundle_start < stop_offset:
        bundle_stop = min(bundle_start + bundle_size, stop_offset)
        yield iobase.SourceBundle(
            bundle_stop - bundle_start,
            _SingleFileSource(
                # Copying this so that each sub-source gets a fresh instance.
                pickler.loads(pickler.dumps(self._file_based_source)),
                self._file_name,
                bundle_start,
                bundle_stop,
                min_bundle_size=self._min_bundle_size,
                splittable=self._splittable),
            bundle_start,
            bundle_stop)
        bundle_start = bundle_stop
    else:
      # Return a single sub-source with the end offset set to OFFSET_INFINITY
      # (so that all data of the source gets read), since this source is
      # unsplittable. Choosing the size of the file as the end offset would be
      # wrong for certain unsplittable sources, e.g., compressed sources.
      yield iobase.SourceBundle(
          stop_offset - start_offset,
          _SingleFileSource(
              self._file_based_source,
              self._file_name,
              start_offset,
              range_trackers.OffsetRangeTracker.OFFSET_INFINITY,
              min_bundle_size=self._min_bundle_size,
              splittable=self._splittable
          ),
          start_offset,
          range_trackers.OffsetRangeTracker.OFFSET_INFINITY
      )
Example #31
 def coder_id_from_element_type(self, element_type):
     # type: (Any) -> str
     if self.use_fake_coders:
         return pickler.dumps(element_type).decode('ascii')
     else:
         return self.coders.get_id(coders.registry.get_coder(element_type))
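The decode('ascii') above works because pickler.dumps returns base64-encoded bytes, so a pickled element type can double as an opaque string id. A small sketch of that property:

from apache_beam.internal import pickler

fake_coder_id = pickler.dumps(int).decode('ascii')   # printable base64 text
assert pickler.loads(fake_coder_id.encode('ascii')) is int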
Example #32
    def _map_task_to_runner_protos(self, map_task, data_operation_spec):
        input_data = {}
        side_input_data = {}
        runner_sinks = {}

        context = pipeline_context.PipelineContext()
        transform_protos = {}
        used_pcollections = {}

        def uniquify(*names):
            # An injective mapping from string* to string.
            return ':'.join("%s:%d" % (name, len(name)) for name in names)

        def pcollection_id(op_ix, out_ix):
            if (op_ix, out_ix) not in used_pcollections:
                used_pcollections[op_ix,
                                  out_ix] = uniquify(map_task[op_ix][0], 'out',
                                                     str(out_ix))
            return used_pcollections[op_ix, out_ix]

        def get_inputs(op):
            if hasattr(op, 'inputs'):
                inputs = op.inputs
            elif hasattr(op, 'input'):
                inputs = [op.input]
            else:
                inputs = []
            return {
                'in%s' % ix: pcollection_id(*input)
                for ix, input in enumerate(inputs)
            }

        def get_outputs(op_ix):
            op = map_task[op_ix][1]
            return {
                tag: pcollection_id(op_ix, out_ix)
                for out_ix, tag in enumerate(
                    getattr(op, 'output_tags', ['out']))
            }

        for op_ix, (stage_name, operation) in enumerate(map_task):
            transform_id = uniquify(stage_name)

            if isinstance(operation, operation_specs.WorkerInMemoryWrite):
                # Write this data back to the runner.
                runner_sinks[(transform_id, 'out')] = operation
                transform_spec = beam_runner_api_pb2.FunctionSpec(
                    urn=sdk_worker.DATA_OUTPUT_URN,
                    parameter=proto_utils.pack_Any(data_operation_spec))

            elif isinstance(operation, operation_specs.WorkerRead):
                # A Read from an in-memory source is done over the data plane.
                if (isinstance(operation.source.source,
                               maptask_executor_runner.InMemorySource)
                        and isinstance(
                            operation.source.source.default_output_coder(),
                            WindowedValueCoder)):
                    input_data[(
                        transform_id, 'input')] = self._reencode_elements(
                            operation.source.source.read(None),
                            operation.source.source.default_output_coder())
                    transform_spec = beam_runner_api_pb2.FunctionSpec(
                        urn=sdk_worker.DATA_INPUT_URN,
                        parameter=proto_utils.pack_Any(data_operation_spec))

                else:
                    # Otherwise serialize the source and execute it there.
                    # TODO: Use SDFs with an initial impulse.
                    # The Dataflow runner harness strips the base64 encoding;
                    # do the same here until we get back what we sent in.
                    transform_spec = beam_runner_api_pb2.FunctionSpec(
                        urn=sdk_worker.PYTHON_SOURCE_URN,
                        parameter=proto_utils.pack_Any(
                            wrappers_pb2.BytesValue(value=base64.b64decode(
                                pickler.dumps(operation.source.source)))))

            elif isinstance(operation, operation_specs.WorkerDoFn):
                # Record the contents of each side input for access via the state api.
                side_input_extras = []
                for si in operation.side_inputs:
                    assert isinstance(si.source, iobase.BoundedSource)
                    element_coder = si.source.default_output_coder()
                    # TODO(robertwb): Actually flesh out the ViewFn API.
                    side_input_extras.append((si.tag, element_coder))
                    side_input_data[sdk_worker.side_input_tag(
                        transform_id, si.tag)] = (self._reencode_elements(
                            si.source.read(
                                si.source.get_range_tracker(None, None)),
                            element_coder))
                augmented_serialized_fn = pickler.dumps(
                    (operation.serialized_fn, side_input_extras))
                transform_spec = beam_runner_api_pb2.FunctionSpec(
                    urn=sdk_worker.PYTHON_DOFN_URN,
                    parameter=proto_utils.pack_Any(
                        wrappers_pb2.BytesValue(
                            value=augmented_serialized_fn)))

            elif isinstance(operation, operation_specs.WorkerFlatten):
                # Flatten is nice and simple.
                transform_spec = beam_runner_api_pb2.FunctionSpec(
                    urn=sdk_worker.IDENTITY_DOFN_URN)

            else:
                raise NotImplementedError(operation)

            transform_protos[transform_id] = beam_runner_api_pb2.PTransform(
                unique_name=stage_name,
                spec=transform_spec,
                inputs=get_inputs(operation),
                outputs=get_outputs(op_ix))

        pcollection_protos = {
            name: beam_runner_api_pb2.PCollection(
                unique_name=name,
                coder_id=context.coders.get_id(
                    map_task[op_id][1].output_coders[out_id]))
            for (op_id, out_id), name in used_pcollections.items()
        }
        # Must follow creation of pcollection_protos to capture used coders.
        context_proto = context.to_runner_api()
        process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self._next_uid(),
            transforms=transform_protos,
            pcollections=pcollection_protos,
            codersyyy=dict(context_proto.coders.items()),
            windowing_strategies=dict(
                context_proto.windowing_strategies.items()),
            environments=dict(context_proto.environments.items()))
        return input_data, side_input_data, runner_sinks, process_bundle_descriptor
Example #33
 def coder_id_from_element_type(self, element_type):
     if self.use_fake_coders:
         return pickler.dumps(element_type)
     else:
         return self.coders.get_id(coders.registry.get_coder(element_type))
Example #34
def serialize_coder(coder):
    from apache_beam.internal import pickler
    return b'%s$%s' % (coder.__class__.__name__.encode('utf-8'),
                       pickler.dumps(coder))
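The helper above prefixes the pickled coder with its class name and a '$' separator; the prefix is informational and the payload after the separator is what matters. An illustrative inverse, using serialize_coder as defined above (not necessarily the SDK's own deserializer):

from apache_beam import coders
from apache_beam.internal import pickler

blob = serialize_coder(coders.VarIntCoder())
_class_name, payload = blob.split(b'$', 1)   # base64 payload contains no '$'
restored = pickler.loads(payload)
assert isinstance(restored, coders.VarIntCoder)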
Example #35
    def test_row_coder_picklable(self):
        # Occasionally coders get pickled; RowCoder should be able to handle it.
        coder = coders_registry.get_coder(Person)
        roundtripped = pickler.loads(pickler.dumps(coder))

        self.assertEqual(roundtripped, coder)
Example #36
 def to_runner_api_pickled(self, unused_context):
     # type: (PipelineContext) -> Tuple[str, bytes]
     return (python_urns.PICKLED_TRANSFORM, pickler.dumps(self))
Example #37
 def to_runner_api_pickled(self, unused_context):
     return (python_urns.PICKLED_TRANSFORM, pickler.dumps(self))
Example #38
def _create_pardo_operation(factory,
                            transform_id,
                            transform_proto,
                            consumers,
                            serialized_fn,
                            pardo_proto=None):

    if pardo_proto and pardo_proto.side_inputs:
        input_tags_to_coders = factory.get_input_coders(transform_proto)
        tagged_side_inputs = [
            (tag,
             beam.pvalue.SideInputData.from_runner_api(si, factory.context))
            for tag, si in pardo_proto.side_inputs.items()
        ]
        tagged_side_inputs.sort(key=lambda tag_si: int(
            re.match('side([0-9]+)(-.*)?$', tag_si[0]).group(1)))
        side_input_maps = [
            StateBackedSideInputMap(factory.state_handler, transform_id, tag,
                                    si, input_tags_to_coders[tag])
            for tag, si in tagged_side_inputs
        ]
    else:
        side_input_maps = []

    output_tags = list(transform_proto.outputs.keys())

    # Hack to match the 'out' prefix injected by the Dataflow runner.
    def mutate_tag(tag):
        if 'None' in output_tags:
            if tag == 'None':
                return 'out'
            else:
                return 'out_' + tag
        else:
            return tag

    dofn_data = pickler.loads(serialized_fn)
    if not dofn_data[-1]:
        # Windowing not set.
        if pardo_proto:
            other_input_tags = set.union(set(pardo_proto.side_inputs),
                                         set(pardo_proto.timer_specs))
        else:
            other_input_tags = ()
        pcoll_id, = [
            pcoll for tag, pcoll in transform_proto.inputs.items()
            if tag not in other_input_tags
        ]
        windowing = factory.context.windowing_strategies.get_by_id(
            factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
        serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing, ))

    if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs):
        main_input_coder = None
        timer_inputs = {}
        for tag, pcoll_id in transform_proto.inputs.items():
            if tag in pardo_proto.timer_specs:
                timer_inputs[tag] = pcoll_id
            elif tag in pardo_proto.side_inputs:
                pass
            else:
                # Must be the main input
                assert main_input_coder is None
                main_input_coder = factory.get_windowed_coder(pcoll_id)
        assert main_input_coder is not None

        user_state_context = FnApiUserStateContext(
            factory.state_handler,
            transform_id,
            main_input_coder.key_coder(),
            main_input_coder.window_coder,
            timer_specs=pardo_proto.timer_specs)
    else:
        user_state_context = None
        timer_inputs = None

    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=serialized_fn,
        output_tags=[mutate_tag(tag) for tag in output_tags],
        input=None,
        side_inputs=None,  # Fn API uses proto definitions and the Fn State API
        output_coders=[output_coders[tag] for tag in output_tags])

    return factory.augment_oldstyle_op(
        operations.DoOperation(transform_proto.unique_name,
                               spec,
                               factory.counter_factory,
                               factory.state_sampler,
                               side_input_maps,
                               user_state_context,
                               timer_inputs=timer_inputs),
        transform_proto.unique_name, consumers, output_tags)
Example #39
def _create_simple_pardo_operation(factory, transform_id, transform_proto,
                                   consumers, dofn):
    serialized_fn = pickler.dumps((dofn, (), {}, [], None))
    return _create_pardo_operation(factory, transform_id, transform_proto,
                                   consumers, serialized_fn)
Example #40
def serialize_and_pack_py_fn(fn, urn, id=None):
    """Returns serialized and packed function in a function spec proto."""
    return pack_function_spec_data(pickler.dumps(fn), urn, id)
Example #41
    def process(self, source):
        if isinstance(source, iobase.SourceBundle):
            for value in source.source.read(
                    source.source.get_range_tracker(source.start_position,
                                                    source.stop_position)):
                yield value
        else:
            # Dataflow native source
            with source.reader() as reader:
                for value in reader:
                    yield value


# See DataflowRunner._pardo_fn_data
OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
    (OldeSourceSplittableDoFn(), (), {}, [],
     beam.transforms.core.Windowing(GlobalWindows())))


class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
    def __init__(self, use_runner_protos=False):
        super(FnApiRunner, self).__init__()
        self._last_uid = -1
        if use_runner_protos:
            self._map_task_to_protos = self._map_task_to_runner_protos
        else:
            self._map_task_to_protos = self._map_task_to_fn_protos

    def has_metrics_support(self):
        return False