Ejemplo n.º 1
0
    async def invoke(self, comp, arg):
        """Compiles `comp` (if a compiler pipeline is set) and runs it on `arg`.

        Args:
          comp: The computation to execute. `comp.type_signature` must pass
            `check_function()`.
          arg: The argument for `comp`, `None` for no argument, or a coroutine
            that is awaited to obtain the argument.

        Returns:
          The awaited result of `_invoke(executor, comp, arg, result_type)`.
        """
        # A coroutine argument is resolved first; this allows the async context
        # to be installed and used in conjunction with ConcreteComputations'
        # implementation of __call__.
        if asyncio.iscoroutine(arg):
            arg = await arg

        signature = comp.type_signature
        signature.check_function()
        # Compilation currently loses container types, so the result type is
        # remembered before compiling and handed to `_invoke` so the containers
        # can be restored in the output.
        result_type = signature.result

        pipeline = self._compiler_pipeline
        if pipeline is not None:
            with tracing.span('ExecutionContext', 'Compile', span=True):
                comp = pipeline.compile(comp)

        with tracing.span('ExecutionContext', 'Invoke', span=True):
            has_arg = arg is not None
            cardinalities = (
                self._cardinality_inference_fn(
                    arg, comp.type_signature.parameter)
                if has_arg else {})

            factory_guard = self._reset_factory_on_error(
                self._executor_factory, cardinalities)
            with factory_guard as executor:
                py_typecheck.check_type(executor, executor_base.Executor)

                if has_arg:
                    # Replace the raw argument with its ingested counterpart.
                    ingest_coro = _ingest(executor, arg,
                                          comp.type_signature.parameter)
                    arg = await tracing.wrap_coroutine_in_current_trace_context(
                        ingest_coro)

                invoke_coro = _invoke(executor, comp, arg, result_type)
                return await tracing.wrap_coroutine_in_current_trace_context(
                    invoke_coro)
Ejemplo n.º 2
0
 def test_nested_non_async_span(self):
     # Opens three spans, each inside the previous one, and checks what the
     # mock trace recorded for scopes, sub-scopes and parent spans.
     recorder = set_mock_trace()

     # A multi-item `with` enters the managers left to right, exactly like
     # three nested `with` statements.
     with tracing.span('outer', 'osub'), \
             tracing.span('middle', 'msub'), \
             tracing.span('inner', 'isub'):
         pass

     expected_scopes = ['outer', 'middle', 'inner']
     expected_sub_scopes = ['osub', 'msub', 'isub']
     # The outermost span has no parent; the others report index 0 and 1.
     expected_parents = [None, 0, 1]

     self.assertEqual(recorder.scopes, expected_scopes)
     self.assertEqual(recorder.sub_scopes, expected_sub_scopes)
     self.assertEqual(recorder.parent_span_yields, expected_parents)
Ejemplo n.º 3
0
  def test_sibling_spans(self):
    # Two spans opened one after another inside 'parent', then one more span
    # opened after 'parent' has exited.
    recorder = set_mock_trace()

    with tracing.span('parent', ''):
      for child_scope in ('child1', 'child2'):
        with tracing.span(child_scope, ''):
          pass

    with tracing.span('parentless', ''):
      pass

    expected_scopes = ['parent', 'child1', 'child2', 'parentless']
    # Both children report parent index 0; the first and last spans report
    # no parent.
    expected_parents = [None, 0, 0, None]

    self.assertEqual(recorder.scopes, expected_scopes)
    self.assertEqual(recorder.parent_span_yields, expected_parents)
Ejemplo n.º 4
0
  def invoke(self, comp, arg):
    """Compiles `comp` (if a compiler pipeline is set) and runs it on `arg`.

    A fresh asyncio event loop is created for each invocation, used to drive
    `_ingest` and `_invoke` to completion, and closed before returning.

    Args:
      comp: The computation to execute. `comp.type_signature` must pass
        `check_function()`.
      arg: `None`, or an `ExecutionContextValue` holding the argument.

    Returns:
      The result of running `_invoke(executor, comp, arg, result_type)` to
      completion.
    """
    comp.type_signature.check_function()
    # Save the type signature before compiling. Compilation currently loses
    # container types, so we must remember them here so that they can be
    # restored in the output.
    result_type = comp.type_signature.result
    if self._compiler_pipeline is not None:
      with tracing.span('ExecutionContext', 'Compile', span=True):
        comp = self._compiler_pipeline.compile(comp)

    with tracing.span('ExecutionContext', 'Invoke', span=True):

      @contextlib.contextmanager
      def executor_closer(ex_factory, cardinalities):
        """Wraps an Executor into a closeable resource."""
        # TODO(b/168744510): The lifecycles embedded here are confusing; unify
        # or clarify the need for them.
        ex = ex_factory.create_executor(cardinalities)
        try:
          yield ex
        except Exception:
          ex_factory.clean_up_executors()
          # Bare `raise` re-raises the active exception without adding an
          # extra traceback entry.
          raise
        finally:
          ex.close()

      if arg is not None:
        py_typecheck.check_type(arg, ExecutionContextValue)
        unwrapped_arg = _unwrap_execution_context_value(arg)
        cardinalities = cardinalities_utils.infer_cardinalities(
            unwrapped_arg, arg.type_signature)
      else:
        cardinalities = {}

      with executor_closer(self._executor_factory, cardinalities) as executor:
        py_typecheck.check_type(executor, executor_base.Executor)

        event_loop = asyncio.new_event_loop()
        try:
          event_loop.set_task_factory(
              tracing.propagate_trace_context_task_factory)

          if arg is not None:
            arg = event_loop.run_until_complete(
                tracing.wrap_coroutine_in_current_trace_context(
                    _ingest(executor, unwrapped_arg, arg.type_signature)))

          return event_loop.run_until_complete(
              tracing.wrap_coroutine_in_current_trace_context(
                  _invoke(executor, comp, arg, result_type)))
        finally:
          # The loop is private to this call; close it explicitly so its
          # resources are released now rather than whenever the loop object
          # happens to be garbage collected.
          event_loop.close()
Ejemplo n.º 5
0
    def test_parenting_non_async_to_async_to_nested_async(self):
        # Runs a traced coroutine on an event loop owned by another thread
        # while a synchronous span is open on this thread, then checks the
        # parent chain the mock trace recorded.
        recorder = set_mock_trace()
        background_loop = asyncio.new_event_loop()
        background_loop.set_task_factory(
            tracing.propagate_trace_context_task_factory)

        def serve_until_stopped():
            background_loop.run_forever()
            background_loop.close()

        worker = threading.Thread(target=serve_until_stopped, daemon=True)
        worker.start()

        # NOTE: the name and nesting of `middle` matter; the assertions below
        # expect '<locals>' and 'middle' to be recorded for it.
        @tracing.trace
        async def middle():
            with tracing.span('inner', ''):
                pass

        with tracing.span('outer', ''):
            # Hand the coroutine to the other thread's loop while keeping the
            # current trace context, and block until it has finished.
            traced_coro = tracing.wrap_coroutine_in_current_trace_context(
                middle())
            pending = asyncio.run_coroutine_threadsafe(traced_coro,
                                                       background_loop)
            pending.result()

        background_loop.call_soon_threadsafe(background_loop.stop)
        worker.join()

        expected_parents = [None, 0, 1]
        expected_scopes = ['outer', '<locals>', 'inner']
        expected_sub_scopes = ['', 'middle', '']
        self.assertEqual(recorder.parent_span_yields, expected_parents)
        self.assertEqual(recorder.scopes, expected_scopes)
        self.assertEqual(recorder.sub_scopes, expected_sub_scopes)
Ejemplo n.º 6
0
    def invoke(self, comp, arg):
        """Compiles `comp` (if a compiler pipeline is set) and runs it on `arg`.

        Ingestion and invocation are driven to completion on
        `self._event_loop`.

        Args:
          comp: The computation to execute. `comp.type_signature` must pass
            `check_function()`.
          arg: `None`, or an `ExecutionContextValue` holding the argument.

        Returns:
          The result of running `_invoke(executor, comp, arg, result_type)` to
          completion.
        """
        comp.type_signature.check_function()
        # Compilation currently loses container types, so the result type is
        # remembered before compiling and handed to `_invoke` so the containers
        # can be restored in the output.
        result_type = comp.type_signature.result
        compiler = self._compiler_pipeline
        if compiler is not None:
            with tracing.span('ExecutionContext', 'Compile', span=True):
                comp = compiler.compile(comp)

        with tracing.span('ExecutionContext', 'Invoke', span=True):

            @contextlib.contextmanager
            def executor_closer(ex_factory, cardinalities):
                """Yields a new executor; cleans up the factory on error."""
                created = ex_factory.create_executor(cardinalities)
                try:
                    yield created
                except Exception as err:
                    ex_factory.clean_up_executors()
                    raise err

            if arg is None:
                cardinalities = {}
            else:
                py_typecheck.check_type(arg, ExecutionContextValue)
                unwrapped_arg = _unwrap_execution_context_value(arg)
                cardinalities = cardinalities_utils.infer_cardinalities(
                    unwrapped_arg, arg.type_signature)

            with executor_closer(self._executor_factory,
                                 cardinalities) as executor:
                py_typecheck.check_type(executor, executor_base.Executor)

                if arg is not None:
                    ingest_coro = _ingest(executor, unwrapped_arg,
                                          arg.type_signature)
                    arg = self._event_loop.run_until_complete(
                        tracing.wrap_coroutine_in_current_trace_context(
                            ingest_coro))

                invoke_coro = _invoke(executor, comp, arg, result_type)
                return self._event_loop.run_until_complete(
                    tracing.wrap_coroutine_in_current_trace_context(
                        invoke_coro))
Ejemplo n.º 7
0
 def test_basic_span(self):
     # Opens a single span with one extra keyword option and checks the first
     # entry of everything the mock trace recorded.
     recorder = set_mock_trace()
     span_options = {'options': 'some_option'}

     with tracing.span('scope', 'sub_scope', options='some_option'):
         pass

     first = 0
     self.assertEqual(recorder.scopes[first], 'scope')
     self.assertEqual(recorder.sub_scopes[first], 'sub_scope')
     self.assertEqual(recorder.parent_span_yields[first], None)
     self.assertEqual(recorder.fn_argss[first], None)
     self.assertEqual(recorder.fn_kwargss[first], None)
     self.assertEqual(recorder.trace_optss[first], span_options)
     self.assertIsInstance(recorder.trace_results[first], tracing.TracedSpan)
Ejemplo n.º 8
0
    def invoke(self, comp, arg):
        """Runs `comp` on `arg` using a freshly created executor.

        A new asyncio event loop is created for each invocation, used to drive
        `_ingest` and `_invoke` to completion, and closed before returning.

        Args:
          comp: The computation to execute.
          arg: `None`, or an `ExecutionContextValue` holding the argument.

        Returns:
          The result of running `_invoke(executor, comp, arg)` to completion.
        """

        with tracing.span('ExecutionContext', 'Invoke'):

            @contextlib.contextmanager
            def executor_closer(wrapped_executor):
                """Wraps an Executor into a closeable resource."""
                try:
                    yield wrapped_executor
                finally:
                    wrapped_executor.close()

            if arg is not None:
                py_typecheck.check_type(arg, ExecutionContextValue)
                unwrapped_arg = _unwrap_execution_context_value(arg)
                cardinalities = cardinalities_utils.infer_cardinalities(
                    unwrapped_arg, arg.type_signature)
            else:
                cardinalities = {}

            with executor_closer(
                    self._executor_factory.create_executor(
                        cardinalities)) as executor:
                py_typecheck.check_type(executor, executor_base.Executor)

                event_loop = asyncio.new_event_loop()
                try:
                    event_loop.set_task_factory(
                        tracing.propagate_trace_context_task_factory)

                    if arg is not None:
                        arg = event_loop.run_until_complete(
                            tracing.run_coroutine_in_ambient_trace_context(
                                _ingest(executor, unwrapped_arg,
                                        arg.type_signature)))

                    return event_loop.run_until_complete(
                        tracing.run_coroutine_in_ambient_trace_context(
                            _invoke(executor, comp, arg)))
                finally:
                    # The loop is private to this call; close it explicitly
                    # so its resources are released now rather than whenever
                    # the loop object happens to be garbage collected.
                    event_loop.close()
Ejemplo n.º 9
0
 def CreateValue(
     self,
     request: executor_pb2.CreateValueRequest,
     context: grpc.ServicerContext,
 ) -> executor_pb2.CreateValueResponse:
     """Deserializes the request's value and embeds it in the executor.

     The future for the pending creation is stored in `self._values` under a
     freshly generated UUID, and that id is returned as the value reference.

     Args:
       request: The `CreateValueRequest` carrying the serialized value.
       context: The gRPC servicer context for this call.

     Returns:
       A `CreateValueResponse` whose `value_ref.id` is the new value id.
     """
     py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
     request_guard = self._try_handle_request_context(
         request, context, executor_pb2.CreateValueResponse)
     with request_guard:
         with tracing.span('ExecutorService.CreateValue',
                           'deserialize_value'):
             deserialized = value_serialization.deserialize_value(
                 request.value)
             value, value_type = deserialized
         value_id = str(uuid.uuid4())
         target_executor = self.executor(request, context)
         future_val = self._run_coro_threadsafe_with_tracing(
             target_executor.create_value(value, value_type))
         # `self._values` is only written while holding `self._lock`.
         with self._lock:
             self._values[value_id] = future_val
         reference = executor_pb2.ValueRef(id=value_id)
         return executor_pb2.CreateValueResponse(value_ref=reference)
Ejemplo n.º 10
0
 def CreateValue(
     self,
     request: executor_pb2.CreateValueRequest,
     context: grpc.ServicerContext,
 ) -> executor_pb2.CreateValueResponse:
   """Deserializes the request's value and embeds it in the executor.

   The future for the pending creation is stored in `self._values` under a
   freshly generated UUID, and that id is returned as the value reference.

   Args:
     request: The `CreateValueRequest` carrying the serialized value.
     context: The gRPC servicer context for this call.

   Returns:
     A `CreateValueResponse` whose `value_ref.id` is the new value id, or an
     empty `CreateValueResponse` if a `ValueError` or `TypeError` was raised
     (in which case `_set_invalid_arg_err` is called with the error).
   """
   py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
   try:
     with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
       deserialized = executor_service_utils.deserialize_value(request.value)
       value, value_type = deserialized
     value_id = str(uuid.uuid4())
     creation = self._executor.create_value(value, value_type)
     future_val = self._run_coro_threadsafe_with_tracing(creation)
     # `self._values` is only written while holding `self._lock`.
     with self._lock:
       self._values[value_id] = future_val
     reference = executor_pb2.ValueRef(id=value_id)
     return executor_pb2.CreateValueResponse(value_ref=reference)
   except (ValueError, TypeError) as err:
     _set_invalid_arg_err(context, err)
     return executor_pb2.CreateValueResponse()
Ejemplo n.º 11
0
def _call_embedded_tf(*, arg, param_fns, result_fns, result_type, wrapped_fn,
                      destroy_before_invocation, destroy_after_invocation):
    """Function to be run upon EagerTFExecutor.create_call invocation.

    As this function is run completely synchronously, and
    `EagerTFExecutor.create_call` invocations represent the main work of the
    program, this function should be kept as thin a wrapper around delegation
    to the eager TensorFlow runtime as possible.

    Args:
      arg: Argument on which to invoke the embedded function, or `None`.
      param_fns: Functions applied, one per flattened element of `arg`, before
        the elements are passed to `wrapped_fn`, to prepare these arguments
        for ingestion by the eager TF runtime.
      result_fns: Functions applied, one per element, to the results of
        calling `wrapped_fn` before re-embedding as EagerTFExecutor values.
      result_type: TFF type signature of the result of `wrapped_fn`.
      wrapped_fn: Result of `tf.compat.v1.wrap_function` to run in the eager
        TF runtime.
      destroy_before_invocation: Eager TF runtime resources which should be
        destroyed before invoking `wrapped_fn`. Examples might include
        hashtable resources.
      destroy_after_invocation: Eager TF runtime resources which should be
        destroyed after invoking `wrapped_fn`. Examples might include
        resource variables.

    Returns:
      The value of `structure.pack_sequence_as(result_type, ...)` applied to
      the processed results of invoking `wrapped_fn` on `arg`.

    Raises:
      RuntimeError: If `arg` and `param_fns` have different numbers of
        elements.
    """
    scope = 'EagerTFExecutor.create_call'

    def _destroy_all(resources):
        # Issues one DestroyResourceOp per resource, in iteration order.
        for handle in resources:
            tf.raw_ops.DestroyResourceOp(resource=handle)

    # TODO(b/166479382): This cleanup-before-invocation pattern is a workaround
    # to square the circle of TF data expecting to lazily reference this
    # resource on iteration, as well as usages that expect to reinitialize a
    # table with new data. Revisit the semantics implied by this cleanup
    # pattern.
    with tracing.span(scope, 'resource_cleanup_before_invocation', span=True):
        _destroy_all(destroy_before_invocation)

    param_elements = []
    if arg is not None:
        with tracing.span(scope, 'arg_ingestion', span=True):
            arg_parts = structure.flatten(arg)
            found = len(arg_parts)
            expected = len(param_fns)
            if found != expected:
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    expected, found))
            param_elements = [
                prepare(part) for part, prepare in zip(arg_parts, param_fns)
            ]
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue b/144127474 where variables
    # created from tf.import_graph_def(...) inside tf.wrap_function(...) are
    # not destroyed, so the caller-supplied resources are destroyed manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    with tracing.span(scope, 'resource_cleanup_after_invocation', span=True):
        _destroy_all(destroy_after_invocation)

    with tracing.span(scope, 'result_packing', span=True):
        result_elements = [
            finish(part) for part, finish in zip(result_parts, result_fns)
        ]
        return structure.pack_sequence_as(result_type, result_elements)
Ejemplo n.º 12
0
 async def middle():
     # Opens and immediately leaves a span named 'inner' with an empty
     # sub-scope; the coroutine does no other work.
     inner_span = tracing.span('inner', '')
     with inner_span:
         pass
Ejemplo n.º 13
0
def transform_to_native_form(
    comp: computation_base.Computation,
    transform_math_to_tf: bool = False,
    grappler_config: Optional[tf.compat.v1.ConfigProto] = None
) -> computation_base.Computation:
  """Compiles a computation for execution in the TFF native runtime.

  The proto underlying `comp` is transformed to call-dominant form (see
  `tff.framework.transform_to_call_dominant` for the definition). Local
  computations are then optionally compiled to TensorFlow and the TensorFlow
  graphs optionally optimized. Finally
  `transform_tf_call_ops_to_disable_grappler` and `transform_tf_add_ids` are
  applied.

  Args:
    comp: Instance of `computation_base.Computation` to compile.
    transform_math_to_tf: Whether to additionally transform math to TensorFlow
      graphs. Necessary if running on an execution stack without
      ReferenceResolvingExecutors underneath FederatingExecutors.
    grappler_config: Configuration for Grappler optimizations to perform on
      the TensorFlow computations. If `None`, Grappler will not be run and no
      optimizations will be applied.

  Returns:
    A new `computation_base.Computation` representing the compiled version of
    `comp`. If any step inside the compilation raises `ValueError`, the
    failure is logged at debug level and `comp` itself is returned.
  """
  span_scope = 'transform_to_native_form'
  building_block = building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(comp))
  try:
    logging.debug('Compiling TFF computation to CDF.')
    with tracing.span(span_scope, 'transform_to_call_dominant', span=True):
      form, _ = transformations.transform_to_call_dominant(building_block)
    logging.debug('Computation compiled to:')
    logging.debug(form.formatted_representation())

    if transform_math_to_tf:
      logging.debug('Compiling local computations to TensorFlow.')
      with tracing.span(
          span_scope, 'compile_local_computations_to_tensorflow', span=True):
        form, _ = transformations.compile_local_computations_to_tensorflow(
            form)
      logging.debug('Computation compiled to:')
      logging.debug(form.formatted_representation())

    if grappler_config is not None:
      with tracing.span(span_scope, 'optimize_tf_graphs', span=True):
        form, _ = transformations.optimize_tensorflow_graphs(
            form, grappler_config)

    with tracing.span(
        span_scope, 'transform_tf_call_ops_disable_grappler', span=True):
      grappler_disabled_form, _ = (
          tree_transformations.transform_tf_call_ops_to_disable_grappler(form))

    with tracing.span(span_scope, 'transform_tf_add_ids', span=True):
      form_with_ids, _ = tree_transformations.transform_tf_add_ids(
          grappler_disabled_form)

    return computation_wrapper_instances.building_block_to_computation(
        form_with_ids)
  except ValueError as e:
    # Compilation is best effort: fall back to the uncompiled computation.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s', building_block.compact_representation())
    return comp