def test_raises_cardinality_mismatch(self):
    """Checks that disagreeing inferred cardinalities surface a CardinalityError."""
    ex_factory = python_executor_stacks.local_executor_factory()

    def _cardinality_fn(x, y):
        # Ignores its inputs and always reports a CLIENTS cardinality of 1.
        del x, y  # Unused
        return {placements.CLIENTS: 1}

    ctx = async_execution_context.AsyncExecutionContext(
        ex_factory, cardinality_inference_fn=_cardinality_fn)
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

    @federated_computation.federated_computation(clients_int)
    def identity(x):
        return x

    with get_context_stack.get_context_stack().install(ctx):
        # Two client values contradict the single client reported by
        # `_cardinality_fn`, so running the invocation must raise.
        two_clients = [0, 1]
        pending = identity(two_clients)
        self.assertTrue(asyncio.iscoroutine(pending))
        with self.assertRaises(executors_errors.CardinalityError):
            asyncio.run(pending)
def __init__(
        self,
        executor_fn: executor_factory.ExecutorFactory,
        compiler_fn: Optional[
            Callable[[computation_base.Computation], Any]] = None,
        *,
        cardinality_inference_fn: (
            cardinalities_utils.CardinalityInferenceFnType
        ) = cardinalities_utils.infer_cardinalities):
    """Builds a synchronous execution context on top of an asynchronous one.

    All real work is delegated to an `AsyncExecutionContext` configured with
    the same executor factory, compiler, and cardinality-inference function.
    NOTE(review): an earlier docstring said this context "retries
    invocations"; no retry logic is visible here — confirm where it lives.

    Args:
      executor_fn: Instance of `executor_factory.ExecutorFactory`; it is
        type-checked with `py_typecheck.check_type` before anything else.
      compiler_fn: Optional Python function used to compile a computation
        before execution.
      cardinality_inference_fn: Python function that infers cardinalities
        from arguments (and their types). Its result is handed to
        `executor_fn.create_executor` to build a `tff.framework.Executor`.
    """
    py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
    self._executor_factory = executor_fn
    async_context_kwargs = dict(
        executor_fn=executor_fn,
        compiler_fn=compiler_fn,
        cardinality_inference_fn=cardinality_inference_fn,
    )
    self._async_context = async_execution_context.AsyncExecutionContext(
        **async_context_kwargs)
    # Presumably used to drive the async context's coroutines from
    # synchronous callers; its behavior is defined in `async_utils`.
    self._async_runner = async_utils.AsyncThreadRunner()
def test_install_and_execute_in_context(self):
    """Invoking under an installed async context yields an awaitable result."""
    ex_factory = python_executor_stacks.local_executor_factory()
    ctx = async_execution_context.AsyncExecutionContext(ex_factory)

    @tensorflow_computation.tf_computation(tf.int32)
    def add_one(x):
        return x + 1

    with get_context_stack.get_context_stack().install(ctx):
        pending = add_one(1)
        self.assertTrue(asyncio.iscoroutine(pending))
        result = asyncio.run(pending)
        self.assertEqual(result, 2)
def _make_basic_python_execution_context(*, executor_fn, compiler_fn,
                                         asynchronous):
    """Builds a sync or async execution context from an executor and compiler.

    Args:
      executor_fn: Executor factory forwarded to the context constructor.
      compiler_fn: Compiler function forwarded to the context constructor.
      asynchronous: Truthy selects `AsyncExecutionContext`; falsy selects the
        synchronous `ExecutionContext`.

    Returns:
      The newly constructed execution context.
    """
    context_cls = (
        async_execution_context.AsyncExecutionContext
        if asynchronous else sync_execution_context.ExecutionContext)
    return context_cls(executor_fn=executor_fn, compiler_fn=compiler_fn)
def test_runs_cardinality_free(self):
    """A computation needing no cardinalities runs with an empty inference."""
    ex_factory = python_executor_stacks.local_executor_factory()

    def _no_cardinalities(x, y):
        # Reports no cardinalities at all, regardless of the arguments.
        del x, y
        return {}

    ctx = async_execution_context.AsyncExecutionContext(
        ex_factory, cardinality_inference_fn=_no_cardinalities)

    @federated_computation.federated_computation(tf.int32)
    def identity(x):
        return x

    with get_context_stack.get_context_stack().install(ctx):
        # A plain (non-federated) int32 does not depend on any cardinality.
        pending = identity(0)
        self.assertTrue(asyncio.iscoroutine(pending))
        self.assertEqual(asyncio.run(pending), 0)
def __init__(
        self,
        executor_factories: Sequence[executor_factory.ExecutorFactory],
        compiler_fn: Optional[
            Callable[[computation_base.Computation],
                     MergeableCompForm]] = None):
    """Creates one async execution context per supplied executor factory.

    Args:
      executor_factories: Executor factories; each one is wrapped in its own
        `AsyncExecutionContext`, preserving the given order.
      compiler_fn: Optional compiler (annotated as producing a
        `MergeableCompForm`). When provided it is wrapped in a
        `compiler_pipeline.CompilerPipeline`; otherwise the pipeline is
        `None`.
    """
    self._async_runner = async_utils.AsyncThreadRunner()
    contexts = []
    for factory in executor_factories:
        contexts.append(async_execution_context.AsyncExecutionContext(factory))
    self._async_execution_contexts = contexts
    self._compiler_pipeline = (
        None if compiler_fn is None else
        compiler_pipeline.CompilerPipeline(compiler_fn))
def test_install_and_execute_computations_with_different_cardinalities(
        self):
    """One context can run invocations that differ in their client count."""
    ctx = async_execution_context.AsyncExecutionContext(
        python_executor_stacks.local_executor_factory())
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

    @federated_computation.federated_computation(clients_int)
    def repackage_arg(x):
        return [x, x]

    with get_context_stack.get_context_stack().install(ctx):
        # One-client and two-client invocations of the same computation.
        pending = [repackage_arg([1]), repackage_arg([1, 2])]
        for coro in pending:
            self.assertTrue(asyncio.iscoroutine(coro))
        results = [asyncio.run(coro) for coro in pending]
        self.assertEqual(results, [[[1], [1]], [[1, 2], [1, 2]]])
def __init__(
        self,
        executor_fn: executor_factory.ExecutorFactory,
        compiler_fn: Optional[
            Callable[[computation_base.Computation], Any]] = None):
    """Builds a synchronous context that owns a private asyncio event loop.

    Invocations are delegated to an `AsyncExecutionContext` built from the
    same executor factory and compiler. NOTE(review): an earlier docstring
    said this context "retries invocations"; no retry logic is visible
    here — confirm where it lives.

    Args:
      executor_fn: Instance of `executor_factory.ExecutorFactory`; it is
        type-checked with `py_typecheck.check_type` before anything else.
      compiler_fn: Optional Python function used to compile a computation.
    """
    py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
    self._executor_factory = executor_fn
    self._async_context = async_execution_context.AsyncExecutionContext(
        executor_fn=executor_fn, compiler_fn=compiler_fn)
    # A dedicated loop is stored on the instance before configuration.
    self._event_loop = loop = asyncio.new_event_loop()
    # Presumably propagates tracing context into tasks spawned on this loop
    # (inferred from the factory's name; defined in `tracing`).
    loop.set_task_factory(tracing.propagate_trace_context_task_factory)