Пример #1
0
    def test_raises_cardinality_mismatch(self):
        """Surfaces a CardinalityError when inferred and actual sizes clash.

        The context's cardinality-inference function always reports a single
        CLIENTS member, while the argument actually supplies two client values.
        """

        def _always_one_client(x, y):
            # Ignores its inputs and always reports exactly one CLIENTS member.
            del x, y  # Unused
            return {placements.CLIENTS: 1}

        executor_factory = python_executor_stacks.local_executor_factory()
        async_context = async_execution_context.AsyncExecutionContext(
            executor_factory, cardinality_inference_fn=_always_one_client)

        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)

        @federated_computation.federated_computation(clients_int)
        def identity(x):
            return x

        two_client_values = [0, 1]
        with get_context_stack.get_context_stack().install(async_context):
            # Two values disagree with the single client reported above, so
            # running the coroutine must surface an error.
            pending = identity(two_client_values)
            self.assertTrue(asyncio.iscoroutine(pending))
            with self.assertRaises(executors_errors.CardinalityError):
                asyncio.run(pending)
Пример #2
0
def check_in_federated_context() -> None:
  """Checks if the current context is a `tff.program.FederatedContext`.

  Raises:
    ValueError: If the `current` context of the context stack returned by
      `get_context_stack()` is not a `FederatedContext`.
  """
  current_context = get_context_stack.get_context_stack().current
  if isinstance(current_context, FederatedContext):
    return
  raise ValueError(
      'Expected the current context to be a `tff.program.FederatedContext`, '
      f'found \'{type(current_context)}\'.')
Пример #3
0
    def test_basic_functionality(self):
        """Nested `install` calls push and pop contexts on the context stack."""
        stack = get_context_stack.get_context_stack()
        self.assertIsInstance(stack, context_stack_impl.ContextStackImpl)
        self.assertIsInstance(stack.current,
                              execution_context.ExecutionContext)

        def current():
            # Fetches the stack anew on each call rather than reusing `stack`.
            return get_context_stack.get_context_stack().current

        with stack.install(context_stack_test_utils.TestContext('foo')):
            self.assertIsInstance(current(),
                                  context_stack_test_utils.TestContext)
            self.assertEqual(current().name, 'foo')

            with stack.install(context_stack_test_utils.TestContext('bar')):
                self.assertIsInstance(current(),
                                      context_stack_test_utils.TestContext)
                self.assertEqual(current().name, 'bar')

            # Leaving the inner install brings back the outer context.
            self.assertEqual(current().name, 'foo')

        self.assertIsInstance(current(), execution_context.ExecutionContext)
 def test_set_dafault_context(self):
     """`set_default_context` replaces the default context of the stack.

     Calling `set_default_context` with no argument is expected to restore an
     `ExecutionContext` as the current context.
     """
     # NOTE(review): "dafault" is a typo; the name is kept so test selection
     # and any references elsewhere keep working.
     ctx_stack = get_context_stack.get_context_stack()
     self.assertIsInstance(ctx_stack.current,
                           execution_context.ExecutionContext)
     foo = context_stack_test_utils.TestContext('foo')
     set_default_context.set_default_context(foo)
     try:
         self.assertIs(ctx_stack.current, foo)
     finally:
         # Restore the default even if the assertion above fails, so the test
         # context does not leak past this test.
         set_default_context.set_default_context()
     self.assertIsInstance(ctx_stack.current,
                           execution_context.ExecutionContext)
Пример #5
0
  def test_install_and_execute_in_context(self):
    """Under the async C++ context, a tf_computation call returns a coroutine."""
    async_context = (
        cpp_execution_contexts.create_local_async_cpp_execution_context())

    @tensorflow_computation.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    with get_context_stack.get_context_stack().install(async_context):
      pending = add_one(1)
      self.assertTrue(asyncio.iscoroutine(pending))
      result = asyncio.run(pending)
      self.assertEqual(result, 2)
Пример #6
0
  def test_returns_same_python_structure(self):
    """An OrderedDict argument comes back from an identity as an OrderedDict."""
    arg_spec = collections.OrderedDict(a=tf.int32, b=tf.float32)

    @federated_computation.federated_computation(arg_spec)
    def identity(x):
      return x

    sync_context = cpp_execution_contexts.create_local_cpp_execution_context()
    with get_context_stack.get_context_stack().install(sync_context):
      result = identity(collections.OrderedDict(a=0, b=1.))

    self.assertIsInstance(result, collections.OrderedDict)
Пример #7
0
    def test_install_and_execute_in_context(self):
        """Under AsyncExecutionContext, a tf_computation returns a coroutine."""
        async_context = async_execution_context.AsyncExecutionContext(
            python_executor_stacks.local_executor_factory())

        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        with get_context_stack.get_context_stack().install(async_context):
            pending = add_one(1)
            self.assertTrue(asyncio.iscoroutine(pending))
            self.assertEqual(asyncio.run(pending), 2)
    def test_stack_resets_on_none_returned(self):
        """The context stack is unchanged after a None-returning body is rejected.

        Wrapping a function that returns None must raise
        `ComputationReturnedNoneError`. Afterwards the current context must
        still be the `RuntimeErrorContext` that was current beforehand.
        """
        stack = get_context_stack.get_context_stack()
        self.assertIsInstance(stack.current,
                              runtime_error_context.RuntimeErrorContext)

        # Using assertRaises (rather than try/except) makes the test fail if
        # the wrapper unexpectedly accepts the function, instead of passing
        # without checking anything.
        with self.assertRaises(computation_wrapper.ComputationReturnedNoneError):

            @computation_wrapper_instances.federated_computation_wrapper()
            def _():
                pass

        self.assertIsInstance(stack.current,
                              runtime_error_context.RuntimeErrorContext)
Пример #9
0
  def test_runs_tensorflow(self):
    """A tf_computation over an OrderedDict argument multiplies its fields."""
    arg_spec = collections.OrderedDict(x=tf.int32, y=tf.int32)

    @tensorflow_computation.tf_computation(arg_spec)
    def multiply(ordered_dict):
      return ordered_dict['x'] * ordered_dict['y']

    sync_context = cpp_execution_contexts.create_local_cpp_execution_context()
    with get_context_stack.get_context_stack().install(sync_context):
      zero, one = [
          multiply(collections.OrderedDict(x=x_val, y=1)) for x_val in (0, 1)
      ]

    self.assertEqual(zero, 0)
    self.assertEqual(one, 1)
Пример #10
0
    def test_runs_cardinality_free(self):
        """A computation that needs no cardinality information still runs."""

        def _no_cardinalities(x, y):
            # Reports no cardinalities at all, whatever the inputs.
            del x, y  # Unused
            return {}

        async_context = async_execution_context.AsyncExecutionContext(
            python_executor_stacks.local_executor_factory(),
            cardinality_inference_fn=_no_cardinalities)

        @federated_computation.federated_computation(tf.int32)
        def identity(x):
            return x

        with get_context_stack.get_context_stack().install(async_context):
            # This computation is independent of cardinalities
            pending = identity(0)
            self.assertTrue(asyncio.iscoroutine(pending))
            self.assertEqual(asyncio.run(pending), 0)
Пример #11
0
  def test_install_and_execute_computations_with_different_cardinalities(self):
    """The same CLIENTS computation runs with one and with two client values."""
    async_context = (
        cpp_execution_contexts.create_local_async_cpp_execution_context())

    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

    @federated_computation.federated_computation(clients_int)
    def repackage_arg(x):
      return [x, x]

    with get_context_stack.get_context_stack().install(async_context):
      pending = [repackage_arg(arg) for arg in ([1], [1, 2])]
      for coro in pending:
        self.assertTrue(asyncio.iscoroutine(coro))
      results = [asyncio.run(coro) for coro in pending]
      self.assertEqual(results, [[[1], [1]], [[1, 2], [1, 2]]])
Пример #12
0
    def test_get_size_info(self, num_clients):
        """The sizing executor records broadcast and aggregate traffic.

        Args:
          num_clients: Number of clients the sizing executor is built for.
        """
        client_sequences = type_factory.at_clients(
            computation_types.SequenceType(tf.float32))
        server_float = type_factory.at_server(tf.float32)

        @computations.federated_computation(client_sequences, server_float)
        def comp(temperatures, threshold):
            broadcast_threshold = intrinsics.federated_broadcast(threshold)
            zipped = intrinsics.federated_zip(
                [temperatures, broadcast_threshold])
            result_map = intrinsics.federated_map(count_over, zipped)
            count_map = intrinsics.federated_map(count_total, temperatures)
            return intrinsics.federated_mean(result_map, count_map)

        sizing_factory = executor_stacks.sizing_executor_factory(
            num_clients=num_clients)
        sizing_context = execution_context.ExecutionContext(sizing_factory)
        with get_context_stack.get_context_stack().install(sizing_context):
            one_dataset = tf.data.Dataset.range(10).map(
                lambda x: tf.cast(x, tf.float32))
            comp([one_dataset] * num_clients, 15.0)

            # Each client receives a tf.float32 and uploads two tf.float32 values.
            clients_key = (('CLIENTS', num_clients),)
            float_record = [1, tf.float32]
            size_info = sizing_factory.get_size_info()

            self.assertEqual({clients_key: [float_record] * num_clients},
                             size_info.broadcast_history)
            self.assertEqual({clients_key: [float_record] * num_clients * 2},
                             size_info.aggregate_history)
            self.assertEqual([num_clients * 32], size_info.broadcast_bits)
            self.assertEqual([num_clients * 32 * 2], size_info.aggregate_bits)
Пример #13
0
  def test_returns_datasets(self):
    """Datasets come back from unplaced, federated and struct computations."""
    int64_scalar = tf.TensorSpec(shape=[], dtype=tf.int64)

    @tensorflow_computation.tf_computation
    def create_dataset():
      return tf.data.Dataset.range(5)

    sync_context = cpp_execution_contexts.create_local_cpp_execution_context()
    with get_context_stack.get_context_stack().install(sync_context):
      with self.subTest('unplaced'):
        unplaced = create_dataset()
        self.assertEqual(unplaced.element_spec, int64_scalar)
        self.assertEqual(tf.data.experimental.cardinality(unplaced), 5)

      with self.subTest('federated'):

        @federated_computation.federated_computation
        def create_federated_dataset():
          return intrinsics.federated_eval(create_dataset, placements.SERVER)

        federated = create_federated_dataset()
        self.assertEqual(federated.element_spec, int64_scalar)
        self.assertEqual(tf.data.experimental.cardinality(federated), 5)

      with self.subTest('struct'):

        @tensorflow_computation.tf_computation()
        def create_struct_of_datasets():
          return (create_dataset(), create_dataset())

        pair = create_struct_of_datasets()
        self.assertLen(pair, 2)
        self.assertEqual([d.element_spec for d in pair],
                         [int64_scalar, int64_scalar])
        self.assertEqual(
            [tf.data.experimental.cardinality(d) for d in pair], [5, 5])
Пример #14
0
 def test_returns_context(self):
     """Checks that `get_context_stack()` returns a `ContextStackImpl`."""
     context_stack = get_context_stack.get_context_stack()
     self.assertIsInstance(context_stack,
                           context_stack_impl.ContextStackImpl)