    def start_bundle(self):
        transform = self._applied_ptransform.transform

        self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
        for output_tag in self._applied_ptransform.outputs:
            output_pcollection = pvalue.PCollection(None, tag=output_tag)
            output_pcollection.producer = self._applied_ptransform
            self._tagged_receivers[output_tag] = (
                self._evaluation_context.create_bundle(output_pcollection))
            self._tagged_receivers[output_tag].tag = output_tag

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = (pickler.loads(pickler.dumps(transform.dofn))
                if self._perform_dofn_pickle_test else transform.dofn)

        args = transform.args if hasattr(transform, 'args') else []
        kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

        self.runner = DoFnRunner(
            dofn,
            args,
            kwargs,
            self._side_inputs,
            self._applied_ptransform.inputs[0].windowing,
            tagged_receivers=self._tagged_receivers,
            step_name=self._applied_ptransform.full_label,
            state=DoFnState(self._counter_factory),
            scoped_metrics_container=self.scoped_metrics_container)
        self.runner.start()
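The pickle round-trip above (pickler.loads(pickler.dumps(transform.dofn))) is a serializability check: the direct runner clones the DoFn the way a distributed runner would, so unpicklable state fails fast. A minimal standalone sketch of the same idea; AddOneDoFn is an illustrative name, not taken from the snippet:

from apache_beam import DoFn
from apache_beam.internal import pickler

class AddOneDoFn(DoFn):
  def process(self, element):
    yield element + 1

# Round-tripping through the pickler catches unpicklable DoFn state early,
# before any bundle runs.
clone = pickler.loads(pickler.dumps(AddOneDoFn()))
assert list(clone.process(1)) == [2]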
Example #2
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = pickler.loads(pickler.dumps(transform.dofn))

    pipeline_options = self._evaluation_context.pipeline_options
    if (pipeline_options is not None
        and pipeline_options.view_as(TypeOptions).runtime_type_check):
      dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())

    dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
    self.runner = DoFnRunner(
        dofn, transform.args, transform.kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        scoped_metrics_container=self.scoped_metrics_container)
    self.runner.start()
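The runtime_type_check branch above is driven by an ordinary pipeline option. A short sketch of how that flag reaches the evaluator, assuming the current apache_beam.options module layout; only the view_as(TypeOptions) lookup mirrors the snippet:

from apache_beam.options.pipeline_options import PipelineOptions, TypeOptions

# view_as(TypeOptions) is the same lookup the evaluator performs above.
options = PipelineOptions(['--runtime_type_check'])
assert options.view_as(TypeOptions).runtime_type_check

Example #3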
    def __init__(
            self,
            pipeline_options,
            bundle_factory,  # type: BundleFactory
            root_transforms,
            value_to_consumers,
            step_names,
            views,  # type: Iterable[pvalue.AsSideInput]
            clock):
        self.pipeline_options = pipeline_options
        self._bundle_factory = bundle_factory
        self._root_transforms = root_transforms
        self._value_to_consumers = value_to_consumers
        self._step_names = step_names
        self.views = views
        self._pcollection_to_views = collections.defaultdict(
            list
        )  # type: DefaultDict[pvalue.PCollection, List[pvalue.AsSideInput]]
        for view in views:
            self._pcollection_to_views[view.pvalue].append(view)
        self._transform_keyed_states = self._initialize_keyed_states(
            root_transforms, value_to_consumers)
        self._side_inputs_container = _SideInputsContainer(views)
        self._watermark_manager = WatermarkManager(
            clock, root_transforms, value_to_consumers,
            self._transform_keyed_states)
        self._pending_unblocked_tasks = [
        ]  # type: List[Tuple[TransformExecutor, Timestamp]]
        self._counter_factory = counters.CounterFactory()
        self._metrics = DirectMetrics()

        self._lock = threading.Lock()
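Example #4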
    def test_create_process_wide(self):
        sampler = statesampler.StateSampler('', counters.CounterFactory())
        statesampler.set_current_tracker(sampler)
        state1 = sampler.scoped_state(
            'mystep', 'myState', metrics_container=MetricsContainer('mystep'))

        try:
            sampler.start()
            with state1:
                urn = "my:custom:urn"
                labels = {'key': 'value'}
                counter = InternalMetrics.counter(urn=urn,
                                                  labels=labels,
                                                  process_wide=True)
                # If process_wide is set, the value should land on the
                # process-wide container rather than the current one.
                counter.inc(10)
                self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))

                del counter

                metric_name = MetricName(None, None, urn=urn, labels=labels)
                # Expect the value to be set on the process-wide container.
                self.assertEqual(
                    MetricsEnvironment.process_wide_container().get_counter(
                        metric_name).get_cumulative(), 10)
                # Expect no value set on the current container.
                self.assertEqual(
                    MetricsEnvironment.current_container().get_counter(
                        metric_name).get_cumulative(), 0)
        finally:
            sampler.stop()
Example #5
    def start_bundle(self):
        transform = self._applied_ptransform.transform

        self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
        if isinstance(self._applied_ptransform.parent.transform,
                      core._MultiParDo):  # pylint: disable=protected-access
            do_outputs_tuple = self._applied_ptransform.parent.outputs[0]
            assert isinstance(do_outputs_tuple, pvalue.DoOutputsTuple)
            main_output_pcollection = do_outputs_tuple[
                do_outputs_tuple._main_tag]  # pylint: disable=protected-access

            for side_output_tag in transform.side_output_tags:
                output_pcollection = do_outputs_tuple[side_output_tag]
                self._tagged_receivers[side_output_tag] = (
                    self._evaluation_context.create_bundle(output_pcollection))
                self._tagged_receivers[side_output_tag].tag = side_output_tag
        else:
            assert len(self._outputs) == 1
            main_output_pcollection = list(self._outputs)[0]

        self._tagged_receivers[None] = self._evaluation_context.create_bundle(
            main_output_pcollection)
        self._tagged_receivers[None].tag = None  # main_tag is None.

        self._counter_factory = counters.CounterFactory()

        # TODO(aaltay): Consider storing the serialized form as an optimization.
        dofn = pickler.loads(pickler.dumps(transform.dofn))

        pipeline_options = self._evaluation_context.pipeline_options
        if (pipeline_options is not None
                and pipeline_options.view_as(TypeOptions).runtime_type_check):
            # TODO(sourabhbajaj): Remove this if-else
            if isinstance(dofn, core.NewDoFn):
                dofn = TypeCheckWrapperNewDoFn(dofn,
                                               transform.get_type_hints())
            else:
                dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())

        # TODO(sourabhbajaj): Remove this if-else
        if isinstance(dofn, core.NewDoFn):
            dofn = OutputCheckWrapperNewDoFn(
                dofn, self._applied_ptransform.full_label)
        else:
            dofn = OutputCheckWrapperDoFn(dofn,
                                          self._applied_ptransform.full_label)
        self.runner = DoFnRunner(
            dofn,
            transform.args,
            transform.kwargs,
            self._side_inputs,
            self._applied_ptransform.inputs[0].windowing,
            tagged_receivers=self._tagged_receivers,
            step_name=self._applied_ptransform.full_label,
            state=DoFnState(self._counter_factory),
            scoped_metrics_container=self.scoped_metrics_container)
        self.runner.start()
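Example #6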
  def __init__(
      self, process_bundle_descriptor, state_handler, data_channel_factory):
    self.process_bundle_descriptor = process_bundle_descriptor
    self.state_handler = state_handler
    self.data_channel_factory = data_channel_factory
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    self.counter_factory = counters.CounterFactory()
    self.state_sampler = statesampler.StateSampler(
        'fnapi-step-%s' % self.process_bundle_descriptor.id,
        self.counter_factory)
    self.ops = self.create_execution_tree(self.process_bundle_descriptor)
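Example #7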
  def run(self):
    state_sampler = statesampler.StateSampler('', counters.CounterFactory())
    statesampler.set_current_tracker(state_sampler)
    while not self.shutdown_requested:
      task = self._get_task_or_none()
      if task:
        try:
          if not self.shutdown_requested:
            self._update_name(task)
            task.call(state_sampler)
            self._update_name()
        finally:
          self.queue.task_done()
Example #8
  def start_bundle(self):
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (
        pickler.loads(pickler.dumps(transform.dofn))
        if self._perform_dofn_pickle_test else transform.dofn)

    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      kv_type_hint = self._applied_ptransform.inputs[0].element_type
      if kv_type_hint and kv_type_hint != Any:
        coder = coders.registry.get_coder(kv_type_hint)
        self.key_coder = coder.key_coder()
      else:
        self.key_coder = coders.registry.get_coder(Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, all_timer_specs = get_dofn_specs(dofn)
      for timer_spec in all_timer_specs:
        self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.setup()
    self.runner.start()
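The is_stateful_dofn branch above fires only for DoFns that declare state or timer specs. A minimal sketch of such a DoFn, assuming the public apache_beam.transforms.userstate API; CountingDoFn and its spec name are illustrative:

from apache_beam import DoFn
from apache_beam.transforms.userstate import CombiningValueStateSpec

class CountingDoFn(DoFn):
  # A keyed combining state cell; declaring it is what get_dofn_specs detects.
  COUNT = CombiningValueStateSpec('count', combine_fn=sum)

  def process(self, element, count=DoFn.StateParam(COUNT)):
    key, value = element
    count.add(value)
    yield key, count.read()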
Example #9
    def __init__(self, pipeline_options, bundle_factory, root_transforms,
                 value_to_consumers, step_names, views):
        self.pipeline_options = pipeline_options
        self._bundle_factory = bundle_factory
        self._root_transforms = root_transforms
        self._value_to_consumers = value_to_consumers
        self._step_names = step_names
        self.views = views

        # AppliedPTransform -> Evaluator specific state objects
        self._application_state_interals = {}
        self._watermark_manager = WatermarkManager(Clock(), root_transforms,
                                                   value_to_consumers)
        self._side_inputs_container = _SideInputsContainer(views)
        self._pending_unblocked_tasks = []
        self._counter_factory = counters.CounterFactory()
        self._cache = None
        self._metrics = DirectMetrics()

        self._lock = threading.Lock()
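Example #10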
    def create_execution_tree(self, descriptor):
        # TODO(robertwb): Figure out the correct prefix to use for output counters
        # from StateSampler.
        counter_factory = counters.CounterFactory()
        state_sampler = statesampler.StateSampler(
            'fnapi-step%s' % descriptor.id, counter_factory)

        transform_factory = BeamTransformFactory(descriptor,
                                                 self.data_channel_factory,
                                                 counter_factory,
                                                 state_sampler,
                                                 self.state_handler)

        pcoll_consumers = collections.defaultdict(list)
        for transform_id, transform_proto in descriptor.transforms.items():
            for pcoll_id in transform_proto.inputs.values():
                pcoll_consumers[pcoll_id].append(transform_id)

        @memoize
        def get_operation(transform_id):
            transform_consumers = {
                tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
                for tag, pcoll_id in
                descriptor.transforms[transform_id].outputs.items()
            }
            return transform_factory.create_operation(transform_id,
                                                      transform_consumers)

        # Operations must be started (hence returned) in order.
        @memoize
        def topological_height(transform_id):
            return 1 + max([0] + [
                topological_height(consumer) for pcoll in
                descriptor.transforms[transform_id].outputs.values()
                for consumer in pcoll_consumers[pcoll]
            ])

        return [
            get_operation(transform_id) for transform_id in sorted(
                descriptor.transforms, key=topological_height, reverse=True)
        ]
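The ordering above works because topological_height places every producer strictly above all of its consumers, so sorting by height in descending order starts upstream operations first. A self-contained sketch with an illustrative three-stage graph; functools.lru_cache stands in for the snippet's memoize decorator, which is not shown on this page:

import functools

# Maps each stage to the stages that consume its output.
graph = {'read': ['parse'], 'parse': ['write'], 'write': []}

@functools.lru_cache(maxsize=None)
def topological_height(stage):
  # Sinks have height 1; a producer sits one level above its tallest consumer.
  return 1 + max([0] + [topological_height(c) for c in graph[stage]])

# Producers (greater height) come before their consumers.
assert sorted(graph, key=topological_height, reverse=True) == [
    'read', 'parse', 'write']

Example #11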
    def __init__(self, pipeline_options, bundle_factory, root_transforms,
                 value_to_consumers, step_names, views):
        self.pipeline_options = pipeline_options
        self._bundle_factory = bundle_factory
        self._root_transforms = root_transforms
        self._value_to_consumers = value_to_consumers
        self._step_names = step_names
        self.views = views
        self._pcollection_to_views = collections.defaultdict(list)
        for view in views:
            self._pcollection_to_views[view.pvalue].append(view)
        self._transform_keyed_states = self._initialize_keyed_states(
            root_transforms, value_to_consumers)
        self._watermark_manager = WatermarkManager(Clock(), root_transforms,
                                                   value_to_consumers)
        self._side_inputs_container = _SideInputsContainer(views)
        self._pending_unblocked_tasks = []
        self._counter_factory = counters.CounterFactory()
        self._cache = None
        self._metrics = DirectMetrics()

        self._lock = threading.Lock()
Example #12
  def __init__(
      self, process_bundle_descriptor, state_handler, data_channel_factory):
    """Initialize a bundle processor.

    Args:
      process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``):
        a description of the stage that this ``BundleProcessor`` is to execute.
      state_handler (``beam_fn_api_pb2_grpc.BeamFnStateServicer``).
      data_channel_factory (``data_plane.DataChannelFactory``).
    """
    self.process_bundle_descriptor = process_bundle_descriptor
    self.state_handler = state_handler
    self.data_channel_factory = data_channel_factory
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    self.counter_factory = counters.CounterFactory()
    self.state_sampler = statesampler.StateSampler(
        'fnapi-step-%s' % self.process_bundle_descriptor.id,
        self.counter_factory)
    self.ops = self.create_execution_tree(self.process_bundle_descriptor)
    for op in self.ops.values():
      op.setup()
    self.splitting_lock = threading.Lock()
Example #13
    def test_create_counter_distribution(self):
        sampler = statesampler.StateSampler('', counters.CounterFactory())
        statesampler.set_current_tracker(sampler)
        state1 = sampler.scoped_state(
            'mystep', 'myState', metrics_container=MetricsContainer('mystep'))

        try:
            sampler.start()
            with state1:
                counter_ns = 'aCounterNamespace'
                distro_ns = 'aDistributionNamespace'
                name = 'a_name'
                counter = Metrics.counter(counter_ns, name)
                distro = Metrics.distribution(distro_ns, name)
                counter.inc(10)
                counter.dec(3)
                distro.update(10)
                distro.update(2)
                self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
                self.assertTrue(
                    isinstance(distro, Metrics.DelegatingDistribution))

                del distro
                del counter

                container = MetricsEnvironment.current_container()
                self.assertEqual(
                    container.get_counter(MetricName(counter_ns,
                                                     name)).get_cumulative(),
                    7)
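                # DistributionData fields are (sum, count, min, max).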
                self.assertEqual(
                    container.get_distribution(MetricName(
                        distro_ns, name)).get_cumulative(),
                    DistributionData(12, 2, 2, 10))
        finally:
            sampler.stop()
Example #14
  def setUp(self):
    self.counter_factory = counters.CounterFactory()
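Beyond being wired into samplers and runners as in the examples above, a CounterFactory can be exercised directly. A sketch of what a test built on this setUp might do, assuming the CounterName, Counter.SUM, get_counter, and update APIs of apache_beam.utils.counters; the counter name and values are illustrative:

from apache_beam.utils import counters
from apache_beam.utils.counters import Counter, CounterName

factory = counters.CounterFactory()
counter = factory.get_counter(
    CounterName('elements', step_name='s1'), Counter.SUM)
counter.update(3)
counter.update(4)
# get_counter is keyed by name: asking again returns the same counter object.
assert counter is factory.get_counter(
    CounterName('elements', step_name='s1'), Counter.SUM)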
Example #15
  def create_execution_tree(self, descriptor):
    # TODO(vikasrk): Add an id field to Coder proto and use that instead.
    coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
        json.loads(unpack_function_spec_data(coder.function_spec)))
              for coder in descriptor.coders}

    counter_factory = counters.CounterFactory()
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    state_sampler = statesampler.StateSampler(
        'fnapi-step%s-' % descriptor.id, counter_factory)
    consumers = collections.defaultdict(lambda: collections.defaultdict(list))
    ops_by_id = {}
    reversed_ops = []

    for transform in reversed(descriptor.primitive_transform):
      # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
      # "s3") from the service through the FnAPI so that msec counters can be
      # reported and correctly plumbed through the service and the UI.
      operation_name = 'fnapis%s' % transform.id

      def only_element(iterable):
        element, = iterable
        return element

      if transform.function_spec.urn == DATA_OUTPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=only_element(transform.outputs.keys()))

        op = DataOutputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[only_element(transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif transform.function_spec.urn == DATA_INPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=only_element(transform.inputs.keys()))
        op = DataInputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[only_element(transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif transform.function_spec.urn == PYTHON_DOFN_URN:
        def create_side_input(tag, si):
          # TODO(robertwb): Extract windows (and keys) out of element data.
          return operation_specs.WorkerSideInputSource(
              tag=tag,
              source=SideInputSource(
                  self.state_handler,
                  beam_fn_api_pb2.StateKey(
                      function_spec_reference=si.view_fn.id),
                  coder=unpack_and_deserialize_py_fn(si.view_fn)))
        output_tags = list(transform.outputs.keys())
        spec = operation_specs.WorkerDoFn(
            serialized_fn=unpack_function_spec_data(transform.function_spec),
            output_tags=output_tags,
            input=None,
            side_inputs=[create_side_input(tag, si)
                         for tag, si in transform.side_inputs.items()],
            output_coders=[coders[transform.outputs[out].coder_reference]
                           for out in output_tags])

        op = operations.DoOperation(operation_name, spec, counter_factory,
                                    state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(
                consumer, output_tags.index(tag))

      elif transform.function_spec.urn == IDENTITY_DOFN_URN:
        op = operations.FlattenOperation(operation_name, None, counter_factory,
                                         state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(consumer, 0)

      elif transform.function_spec.urn == PYTHON_SOURCE_URN:
        source = load_compressed(unpack_function_spec_data(
            transform.function_spec))
        # TODO(vikasrk): Remove this once custom source is implemented with
        # splittable dofn via the data plane.
        spec = operation_specs.WorkerRead(
            iobase.SourceBundle(1.0, source, None, None),
            [WindowedValueCoder(source.default_output_coder())])
        op = operations.ReadOperation(operation_name, spec, counter_factory,
                                      state_sampler)
        op.step_name = transform.step_name
        output_tags = list(transform.outputs.keys())
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(
                consumer, output_tags.index(tag))

      else:
        raise NotImplementedError

      # Record consumers.
      for _, inputs in transform.inputs.items():
        for target in inputs.target:
          consumers[target.primitive_transform_reference][target.name].append(
              op)

      reversed_ops.append(op)
      ops_by_id[transform.id] = op

    return list(reversed(reversed_ops)), ops_by_id