示例#1
0
    def test_raises_cardinality_mismatch(self):
        """Checks that contradictory inferred cardinalities surface an error."""
        execution_context = async_execution_context.AsyncExecutionContext(
            python_executor_stacks.local_executor_factory(),
            cardinality_inference_fn=None or self._noop_placeholder
            if False else _CARDINALITY_PLACEHOLDER)
    def __init__(
        self,
        executor_fn: executor_factory.ExecutorFactory,
        compiler_fn: Optional[Callable[[computation_base.Computation],
                                       Any]] = None,
        *,
        cardinality_inference_fn: cardinalities_utils.
        CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities):
        """Initializes a synchronous execution context which retries invocations.

    Args:
      executor_fn: Instance of `executor_factory.ExecutorFactory`.
      compiler_fn: A Python function that will be used to compile a computation.
      cardinality_inference_fn: A Python function specifying how to infer
        cardinalities from arguments (and their associated types). The value
        returned by this function will be passed to the `create_executor` method
        of `executor_fn` to construct a `tff.framework.Executor` instance.
    """
        py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
        self._executor_factory = executor_fn
        self._async_context = async_execution_context.AsyncExecutionContext(
            executor_fn=executor_fn,
            compiler_fn=compiler_fn,
            cardinality_inference_fn=cardinality_inference_fn)
        self._async_runner = async_utils.AsyncThreadRunner()
示例#3
0
    def test_install_and_execute_in_context(self):
        factory = python_executor_stacks.local_executor_factory()
        context = async_execution_context.AsyncExecutionContext(factory)

        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        with get_context_stack.get_context_stack().install(context):
            val_coro = add_one(1)
            self.assertTrue(asyncio.iscoroutine(val_coro))
            self.assertEqual(asyncio.run(val_coro), 2)
def _make_basic_python_execution_context(*, executor_fn, compiler_fn,
                                         asynchronous):
    """Wires executor function and compiler into sync or async context."""

    if not asynchronous:
        context = sync_execution_context.ExecutionContext(
            executor_fn=executor_fn, compiler_fn=compiler_fn)
    else:
        context = async_execution_context.AsyncExecutionContext(
            executor_fn=executor_fn, compiler_fn=compiler_fn)

    return context
示例#5
0
    def test_runs_cardinality_free(self):
        factory = python_executor_stacks.local_executor_factory()
        context = async_execution_context.AsyncExecutionContext(
            factory, cardinality_inference_fn=(lambda x, y: {}))

        @federated_computation.federated_computation(tf.int32)
        def identity(x):
            return x

        with get_context_stack.get_context_stack().install(context):
            data = 0
            # This computation is independent of cardinalities
            val_coro = identity(data)
            self.assertTrue(asyncio.iscoroutine(val_coro))
            self.assertEqual(asyncio.run(val_coro), 0)
示例#6
0
    def __init__(
        self,
        executor_factories: Sequence[executor_factory.ExecutorFactory],
        compiler_fn: Optional[Callable[[computation_base.Computation],
                                       MergeableCompForm]] = None):
        self._async_runner = async_utils.AsyncThreadRunner()
        self._async_execution_contexts = [
            async_execution_context.AsyncExecutionContext(ex_factory)
            for ex_factory in executor_factories
        ]

        if compiler_fn is not None:
            self._compiler_pipeline = compiler_pipeline.CompilerPipeline(
                compiler_fn)
        else:
            self._compiler_pipeline = None
示例#7
0
    def test_install_and_execute_computations_with_different_cardinalities(
            self):
        factory = python_executor_stacks.local_executor_factory()
        context = async_execution_context.AsyncExecutionContext(factory)

        @federated_computation.federated_computation(
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        def repackage_arg(x):
            return [x, x]

        with get_context_stack.get_context_stack().install(context):
            single_val_coro = repackage_arg([1])
            second_val_coro = repackage_arg([1, 2])
            self.assertTrue(asyncio.iscoroutine(single_val_coro))
            self.assertTrue(asyncio.iscoroutine(second_val_coro))
            self.assertEqual(
                [asyncio.run(single_val_coro),
                 asyncio.run(second_val_coro)], [[[1], [1]], [[1, 2], [1, 2]]])
示例#8
0
    def __init__(self,
                 executor_fn: executor_factory.ExecutorFactory,
                 compiler_fn: Optional[Callable[[computation_base.Computation],
                                                Any]] = None):
        """Initializes a synchronous execution context which retries invocations.

    Args:
      executor_fn: Instance of `executor_factory.ExecutorFactory`.
      compiler_fn: A Python function that will be used to compile a computation.
    """
        py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
        self._executor_factory = executor_fn
        self._async_context = async_execution_context.AsyncExecutionContext(
            executor_fn=executor_fn, compiler_fn=compiler_fn)

        self._event_loop = asyncio.new_event_loop()
        self._event_loop.set_task_factory(
            tracing.propagate_trace_context_task_factory)