async def invoke(self, comp, arg):
  """Compiles (if a pipeline is configured) and runs `comp` on `arg`.

  Args:
    comp: The computation to invoke. `comp.type_signature.check_function()` is
      called on it before anything else is done with it.
    arg: `None`, a value to ingest as the argument, or a coroutine that is
      awaited first to produce that value.

  Returns:
    The awaited result of `_invoke` for the (possibly compiled) `comp`, the
    ingested argument, and the result type recorded before compilation.
  """
  if asyncio.iscoroutine(arg):
    # Awaiting a coroutine argument here lets the async context be installed
    # and used together with ConcreteComputation's `__call__`.
    arg = await arg

  comp.type_signature.check_function()
  # Compilation currently drops container types, so the result type is
  # captured up front and handed to `_invoke` to restore them on the output.
  result_type = comp.type_signature.result

  if self._compiler_pipeline is not None:
    with tracing.span('ExecutionContext', 'Compile', span=True):
      comp = self._compiler_pipeline.compile(comp)

  with tracing.span('ExecutionContext', 'Invoke', span=True):
    has_arg = arg is not None
    cardinalities = (
        self._cardinality_inference_fn(arg, comp.type_signature.parameter)
        if has_arg else {})
    with self._reset_factory_on_error(self._executor_factory,
                                      cardinalities) as executor:
      py_typecheck.check_type(executor, executor_base.Executor)
      if has_arg:
        arg = await tracing.wrap_coroutine_in_current_trace_context(
            _ingest(executor, arg, comp.type_signature.parameter))
      return await tracing.wrap_coroutine_in_current_trace_context(
          _invoke(executor, comp, arg, result_type))
def test_nested_non_async_span(self):
  """Three nested synchronous spans are recorded outer-to-inner."""
  mock = set_mock_trace()
  # A multi-item `with` is equivalent to nesting: each manager is created and
  # entered left to right, then exited right to left.
  with tracing.span('outer', 'osub'), tracing.span('middle', 'msub'), \
      tracing.span('inner', 'isub'):
    pass
  expected_scopes = ['outer', 'middle', 'inner']
  expected_sub_scopes = ['osub', 'msub', 'isub']
  self.assertEqual(mock.scopes, expected_scopes)
  self.assertEqual(mock.sub_scopes, expected_sub_scopes)
  # NOTE(review): entries appear to be the index of each span's parent span.
  self.assertEqual(mock.parent_span_yields, [None, 0, 1])
def test_sibling_spans(self):
  """Sibling spans share a parent; a top-level span afterwards has none."""
  mock = set_mock_trace()
  with tracing.span('parent', ''):
    for child_name in ('child1', 'child2'):
      with tracing.span(child_name, ''):
        pass
  with tracing.span('parentless', ''):
    pass
  self.assertEqual(mock.scopes, ['parent', 'child1', 'child2', 'parentless'])
  # NOTE(review): entries appear to be the index of each span's parent span.
  self.assertEqual(mock.parent_span_yields, [None, 0, 0, None])
def invoke(self, comp, arg):
  """Compiles (if a pipeline is configured) and synchronously runs `comp`.

  Each call drives the work on a brand-new event loop whose task factory is
  `tracing.propagate_trace_context_task_factory`. That loop is closed before
  this method returns or raises, so per-call loops do not accumulate.

  Args:
    comp: The computation to invoke. `comp.type_signature.check_function()` is
      called on it first.
    arg: `None`, or an `ExecutionContextValue` to pass as the argument.

  Returns:
    The result of running `_invoke` for the (possibly compiled) `comp`, the
    ingested argument, and the result type recorded before compilation.
  """
  comp.type_signature.check_function()
  # Save the type signature before compiling. Compilation currently loses
  # container types, so we must remember them here so that they can be
  # restored in the output.
  result_type = comp.type_signature.result
  if self._compiler_pipeline is not None:
    with tracing.span('ExecutionContext', 'Compile', span=True):
      comp = self._compiler_pipeline.compile(comp)

  with tracing.span('ExecutionContext', 'Invoke', span=True):

    @contextlib.contextmanager
    def executor_closer(ex_factory, cardinalities):
      """Wraps an Executor into a closeable resource."""
      # TODO(b/168744510): The lifecycles embedded here are confusing; unify
      # or clarify the need for them.
      ex = ex_factory.create_executor(cardinalities)
      try:
        yield ex
      except Exception:
        ex_factory.clean_up_executors()
        raise
      finally:
        ex.close()

    def new_traced_event_loop():
      """Returns a new loop using the trace-propagating task factory."""
      new_loop = asyncio.new_event_loop()
      new_loop.set_task_factory(tracing.propagate_trace_context_task_factory)
      return new_loop

    if arg is not None:
      py_typecheck.check_type(arg, ExecutionContextValue)
      unwrapped_arg = _unwrap_execution_context_value(arg)
      cardinalities = cardinalities_utils.infer_cardinalities(
          unwrapped_arg, arg.type_signature)
    else:
      cardinalities = {}

    event_loop = new_traced_event_loop()
    try:
      with executor_closer(self._executor_factory, cardinalities) as executor:
        py_typecheck.check_type(executor, executor_base.Executor)
        if arg is not None:
          arg = event_loop.run_until_complete(
              tracing.wrap_coroutine_in_current_trace_context(
                  _ingest(executor, unwrapped_arg, arg.type_signature)))
        return event_loop.run_until_complete(
            tracing.wrap_coroutine_in_current_trace_context(
                _invoke(executor, comp, arg, result_type)))
    finally:
      # The loop is private to this call. Previously it was never closed, so
      # its selector and file descriptors lived until garbage collection.
      # Close it only after `executor_closer` has exited (and closed the
      # executor), so executor teardown still sees an open loop.
      event_loop.close()
def test_parenting_non_async_to_async_to_nested_async(self):
  """A coroutine run on another thread's loop nests under the caller's span."""
  mock = set_mock_trace()
  loop = asyncio.new_event_loop()
  loop.set_task_factory(tracing.propagate_trace_context_task_factory)

  def run_loop():
    loop.run_forever()
    loop.close()

  # `run_loop` takes no arguments, so it is passed directly; no
  # `functools.partial` wrapper is needed.
  thread = threading.Thread(target=run_loop, daemon=True)
  thread.start()

  @tracing.trace
  async def middle():
    with tracing.span('inner', ''):
      pass

  try:
    with tracing.span('outer', ''):
      # This sends the coroutine over to another thread,
      # keeping the current trace context.
      coro_with_trace_ctx = tracing.wrap_coroutine_in_current_trace_context(
          middle())
      asyncio.run_coroutine_threadsafe(coro_with_trace_ctx, loop).result()
  finally:
    # Stop and join the background loop even if the coroutine raised, so the
    # loop thread does not outlive this test.
    loop.call_soon_threadsafe(loop.stop)
    thread.join()

  self.assertEqual(mock.parent_span_yields, [None, 0, 1])
  self.assertEqual(mock.scopes, ['outer', '<locals>', 'inner'])
  self.assertEqual(mock.sub_scopes, ['', 'middle', ''])
def invoke(self, comp, arg):
  """Compiles (if a pipeline is configured) and synchronously runs `comp`.

  All asynchronous work is driven on `self._event_loop`. NOTE(review): that
  loop is presumably created once elsewhere on this object — confirm.

  Args:
    comp: The computation to invoke. `comp.type_signature.check_function()` is
      called on it first.
    arg: `None`, or an `ExecutionContextValue` to pass as the argument.

  Returns:
    The result of running `_invoke` for the (possibly compiled) `comp`, the
    ingested argument, and the result type recorded before compilation.
  """
  comp.type_signature.check_function()
  # Save the type signature before compiling. Compilation currently loses
  # container types, so we must remember them here so that they can be
  # restored in the output.
  result_type = comp.type_signature.result
  if self._compiler_pipeline is not None:
    with tracing.span('ExecutionContext', 'Compile', span=True):
      comp = self._compiler_pipeline.compile(comp)

  with tracing.span('ExecutionContext', 'Invoke', span=True):

    @contextlib.contextmanager
    def managed_executor(ex_factory, cardinalities):
      """Yields an executor from `ex_factory`; cleans up the factory on error.

      If the managed block raises, `ex_factory.clean_up_executors()` is called
      and the exception is re-raised. The executor is NOT closed on exit.
      NOTE(review): presumably the factory owns the executor's lifetime —
      confirm.
      """
      ex = ex_factory.create_executor(cardinalities)
      try:
        yield ex
      except Exception:
        ex_factory.clean_up_executors()
        raise

    if arg is not None:
      py_typecheck.check_type(arg, ExecutionContextValue)
      unwrapped_arg = _unwrap_execution_context_value(arg)
      cardinalities = cardinalities_utils.infer_cardinalities(
          unwrapped_arg, arg.type_signature)
    else:
      cardinalities = {}

    with managed_executor(self._executor_factory, cardinalities) as executor:
      py_typecheck.check_type(executor, executor_base.Executor)
      if arg is not None:
        arg = self._event_loop.run_until_complete(
            tracing.wrap_coroutine_in_current_trace_context(
                _ingest(executor, unwrapped_arg, arg.type_signature)))
      return self._event_loop.run_until_complete(
          tracing.wrap_coroutine_in_current_trace_context(
              _invoke(executor, comp, arg, result_type)))
def test_basic_span(self):
  """A single span records its scope, sub-scope, options and result."""
  mock = set_mock_trace()
  with tracing.span('scope', 'sub_scope', options='some_option'):
    pass
  first = 0
  # Pairs of (recorded list, expected value of its first entry), checked in
  # order.
  expectations = [
      (mock.scopes, 'scope'),
      (mock.sub_scopes, 'sub_scope'),
      (mock.parent_span_yields, None),
      (mock.fn_argss, None),
      (mock.fn_kwargss, None),
      (mock.trace_optss, {'options': 'some_option'}),
  ]
  for recorded, expected in expectations:
    self.assertEqual(recorded[first], expected)
  self.assertIsInstance(mock.trace_results[first], tracing.TracedSpan)
def invoke(self, comp, arg):
  """Synchronously runs `comp` on `arg` using a fresh executor and event loop.

  A new executor is created from `self._executor_factory` and closed when the
  call finishes. A new event loop, whose task factory is
  `tracing.propagate_trace_context_task_factory`, is created for the call and
  closed before this method returns or raises.

  Args:
    comp: The computation to invoke.
    arg: `None`, or an `ExecutionContextValue` to pass as the argument.

  Returns:
    The result of running `_invoke(executor, comp, arg)` on the new loop.
  """
  with tracing.span('ExecutionContext', 'Invoke'):

    @contextlib.contextmanager
    def executor_closer(wrapped_executor):
      """Yields `wrapped_executor` and always closes it afterwards."""
      try:
        yield wrapped_executor
      finally:
        wrapped_executor.close()

    def new_traced_event_loop():
      """Returns a new loop using the trace-propagating task factory."""
      new_loop = asyncio.new_event_loop()
      new_loop.set_task_factory(tracing.propagate_trace_context_task_factory)
      return new_loop

    if arg is not None:
      py_typecheck.check_type(arg, ExecutionContextValue)
      unwrapped_arg = _unwrap_execution_context_value(arg)
      cardinalities = cardinalities_utils.infer_cardinalities(
          unwrapped_arg, arg.type_signature)
    else:
      cardinalities = {}

    event_loop = new_traced_event_loop()
    try:
      with executor_closer(
          self._executor_factory.create_executor(cardinalities)) as executor:
        py_typecheck.check_type(executor, executor_base.Executor)
        if arg is not None:
          arg = event_loop.run_until_complete(
              tracing.run_coroutine_in_ambient_trace_context(
                  _ingest(executor, unwrapped_arg, arg.type_signature)))
        return event_loop.run_until_complete(
            tracing.run_coroutine_in_ambient_trace_context(
                _invoke(executor, comp, arg)))
    finally:
      # The loop is private to this call. Previously it was never closed, so
      # its resources lived until garbage collection. Close it after the
      # executor has been closed.
      event_loop.close()
def CreateValue(
    self,
    request: executor_pb2.CreateValueRequest,
    context: grpc.ServicerContext,
) -> executor_pb2.CreateValueResponse:
  """Creates a value embedded in the executor.

  Deserializes `request.value` and schedules its creation on the executor
  returned by `self.executor(request, context)`. The resulting future is
  stored under a freshly generated UUID, which is returned to the caller.
  Error handling is delegated to `self._try_handle_request_context`.

  Args:
    request: The `CreateValueRequest` carrying the serialized value.
    context: The gRPC servicer context for this call.

  Returns:
    A `CreateValueResponse` whose `value_ref.id` names the stored value.
  """
  py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
  with self._try_handle_request_context(request, context,
                                        executor_pb2.CreateValueResponse):
    with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
      value, value_type = value_serialization.deserialize_value(request.value)
    new_id = str(uuid.uuid4())
    target_executor = self.executor(request, context)
    pending_value = self._run_coro_threadsafe_with_tracing(
        target_executor.create_value(value, value_type))
    with self._lock:
      self._values[new_id] = pending_value
    ref = executor_pb2.ValueRef(id=new_id)
    return executor_pb2.CreateValueResponse(value_ref=ref)
def CreateValue(
    self,
    request: executor_pb2.CreateValueRequest,
    context: grpc.ServicerContext,
) -> executor_pb2.CreateValueResponse:
  """Creates a value embedded in the executor.

  Deserializes `request.value`, schedules its creation on `self._executor`,
  and stores the resulting future under a freshly generated UUID.

  Args:
    request: The `CreateValueRequest` carrying the serialized value.
    context: The gRPC servicer context for this call.

  Returns:
    A `CreateValueResponse` referencing the stored value. If any step raises
    `ValueError` or `TypeError`, `_set_invalid_arg_err` is called on
    `context` and an empty `CreateValueResponse` is returned instead.
  """
  py_typecheck.check_type(request, executor_pb2.CreateValueRequest)
  try:
    with tracing.span('ExecutorService.CreateValue', 'deserialize_value'):
      value, value_type = executor_service_utils.deserialize_value(
          request.value)
    new_id = str(uuid.uuid4())
    pending_value = self._run_coro_threadsafe_with_tracing(
        self._executor.create_value(value, value_type))
    with self._lock:
      self._values[new_id] = pending_value
    ref = executor_pb2.ValueRef(id=new_id)
    return executor_pb2.CreateValueResponse(value_ref=ref)
  except (ValueError, TypeError) as err:
    _set_invalid_arg_err(context, err)
    return executor_pb2.CreateValueResponse()
def _call_embedded_tf(*, arg, param_fns, result_fns, result_type, wrapped_fn,
                      destroy_before_invocation, destroy_after_invocation):
  """Runs an embedded TF function for `EagerTFExecutor.create_call`.

  This runs fully synchronously and carries the bulk of the program's work,
  so it should stay a thin wrapper that delegates to the eager TensorFlow
  runtime.

  Args:
    arg: The argument to invoke the embedded function on, or `None`.
    param_fns: One function per flattened element of `arg`. Each is applied
      to its element to prepare it for the eager TF runtime before being
      passed to `wrapped_fn`.
    result_fns: Functions applied pairwise to the outputs of `wrapped_fn`
      before they are re-embedded as EagerTFExecutor values.
    result_type: The TFF type signature of the result of `wrapped_fn`, used
      to re-pack the flat results.
    wrapped_fn: The result of `tf.compat.v1.wrap_function`, executed in the
      eager TF runtime.
    destroy_before_invocation: Eager TF resources destroyed before
      `wrapped_fn` runs (for example, hashtable resources).
    destroy_after_invocation: Eager TF resources destroyed after `wrapped_fn`
      runs (for example, resource variables).

  Returns:
    A `structure.Struct` holding the result of invoking `wrapped_fn` on `arg`.

  Raises:
    RuntimeError: If the flattened `arg` and `param_fns` differ in length.
  """
  # TODO(b/166479382): This cleanup-before-invocation pattern is a workaround
  # to square the circle of TF data expecting to lazily reference this
  # resource on iteration, as well as usages that expect to reinitialize a
  # table with new data. Revisit the semantics implied by this cleanup
  # pattern.
  with tracing.span('EagerTFExecutor.create_call',
                    'resource_cleanup_before_invocation', span=True):
    for handle in destroy_before_invocation:
      tf.raw_ops.DestroyResourceOp(resource=handle)

  if arg is None:
    param_elements = []
  else:
    with tracing.span('EagerTFExecutor.create_call', 'arg_ingestion',
                      span=True):
      flat_args = structure.flatten(arg)
      expected_count = len(param_fns)
      found_count = len(flat_args)
      if found_count != expected_count:
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            expected_count, found_count))
      param_elements = [
          prepare(element) for element, prepare in zip(flat_args, param_fns)
      ]

  result_parts = wrapped_fn(*param_elements)

  # There is a tf.wrap_function(...) issue b/144127474 that variables created
  # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
  # destroyed. So get all the variables from `wrapped_fn` and destroy
  # manually.
  # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
  # is fixed.
  with tracing.span('EagerTFExecutor.create_call',
                    'resource_cleanup_after_invocation', span=True):
    for handle in destroy_after_invocation:
      tf.raw_ops.DestroyResourceOp(resource=handle)

  with tracing.span('EagerTFExecutor.create_call', 'result_packing',
                    span=True):
    result_elements = [
        convert(part) for part, convert in zip(result_parts, result_fns)
    ]
    return structure.pack_sequence_as(result_type, result_elements)
async def middle():
  """Opens and immediately closes a span named 'inner' with an empty sub-scope."""
  inner_span = tracing.span('inner', '')
  with inner_span:
    return None
def transform_to_native_form(
    comp: computation_base.Computation,
    transform_math_to_tf: bool = False,
    grappler_config: Optional[tf.compat.v1.ConfigProto] = None
) -> computation_base.Computation:
  """Compiles a computation for execution in the TFF native runtime.

  The proto underlying `comp` is transformed to call-dominant form (see
  `tff.framework.transform_to_call_dominant` for the definition). Then, in
  order:

  * if `transform_math_to_tf` is set, local computations are compiled to
    TensorFlow;
  * if `grappler_config` is given, Grappler is run over the TensorFlow graphs;
  * TensorFlow call ops are rewritten with
    `transform_tf_call_ops_to_disable_grappler`;
  * `transform_tf_add_ids` is applied.

  Compilation is best-effort. If any of these steps raises `ValueError`, the
  error is logged at debug level and the original `comp` is returned
  uncompiled.

  Args:
    comp: Instance of `computation_base.Computation` to compile.
    transform_math_to_tf: Whether to additionally transform math to
      TensorFlow graphs. Necessary if running on an executor stack without
      ReferenceResolvingExecutors underneath FederatingExecutors.
    grappler_config: Configuration for Grappler optimizations to perform on
      the TensorFlow computations. If `None`, Grappler will not be run and no
      optimizations will be applied.

  Returns:
    A new `computation_base.Computation` representing the compiled version of
    `comp`, or `comp` itself if compilation raised `ValueError`.
  """
  proto = computation_impl.ComputationImpl.get_proto(comp)
  computation_building_block = (
      building_blocks.ComputationBuildingBlock.from_proto(proto))
  try:
    logging.debug('Compiling TFF computation to CDF.')
    with tracing.span(
        'transform_to_native_form', 'transform_to_call_dominant', span=True):
      call_dominant_form, _ = transformations.transform_to_call_dominant(
          computation_building_block)
    logging.debug('Computation compiled to:')
    logging.debug(call_dominant_form.formatted_representation())
    if transform_math_to_tf:
      logging.debug('Compiling local computations to TensorFlow.')
      with tracing.span(
          'transform_to_native_form',
          'compile_local_computations_to_tensorflow',
          span=True):
        call_dominant_form, _ = (
            transformations.compile_local_computations_to_tensorflow(
                call_dominant_form))
      logging.debug('Computation compiled to:')
      logging.debug(call_dominant_form.formatted_representation())
    if grappler_config is not None:
      with tracing.span(
          'transform_to_native_form', 'optimize_tf_graphs', span=True):
        call_dominant_form, _ = transformations.optimize_tensorflow_graphs(
            call_dominant_form, grappler_config)
    with tracing.span(
        'transform_to_native_form',
        'transform_tf_call_ops_disable_grappler',
        span=True):
      grappler_disabled_form, _ = (
          tree_transformations.transform_tf_call_ops_to_disable_grappler(
              call_dominant_form))
    with tracing.span(
        'transform_to_native_form', 'transform_tf_add_ids', span=True):
      form_with_ids, _ = tree_transformations.transform_tf_add_ids(
          grappler_disabled_form)
    return computation_wrapper_instances.building_block_to_computation(
        form_with_ids)
  except ValueError as e:
    # Best-effort fallback: keep running with the uncompiled computation.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s',
                  computation_building_block.compact_representation())
    return comp