Example 1
0
class WorkerFailureTest(parameterized.TestCase):
    """Checks that remote computations keep working across worker restarts."""

    # Each case supplies: the controller-side execution context, then two
    # independent collections of server contexts. The first collection is
    # entered and exited before the second is entered.
    @parameterized.named_parameters(
        ('native_remote_request_reply',
         # Pass the RPC mode explicitly so this case really exercises
         # request/reply, rather than whatever the helper's default is.
         remote_runtime_test_utils.create_localhost_remote_context(
             _WORKER_PORTS, rpc_mode='REQUEST_REPLY'),
         remote_runtime_test_utils.create_localhost_worker_contexts(
             _WORKER_PORTS),
         remote_runtime_test_utils.create_localhost_worker_contexts(
             _WORKER_PORTS)),
        ('native_remote_streaming',
         remote_runtime_test_utils.create_localhost_remote_context(
             _WORKER_PORTS, rpc_mode='STREAMING'),
         remote_runtime_test_utils.create_localhost_worker_contexts(
             _WORKER_PORTS),
         remote_runtime_test_utils.create_localhost_worker_contexts(
             _WORKER_PORTS)),
        ('native_remote_intermediate_aggregator',
         remote_runtime_test_utils.create_localhost_remote_context(
             _AGGREGATOR_PORTS),
         remote_runtime_test_utils.create_localhost_aggregator_contexts(
             _WORKER_PORTS, _AGGREGATOR_PORTS),
         remote_runtime_test_utils.create_localhost_aggregator_contexts(
             _WORKER_PORTS, _AGGREGATOR_PORTS)),
    )
    def test_computations_run_with_worker_restarts(self, context,
                                                   first_contexts,
                                                   second_contexts):
        """Runs the same computation before and after the servers restart."""
        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        def run_with_servers(server_contexts):
            # Enter every server context for the duration of one computation.
            # Leaving the ExitStack exits them all again, in reverse order.
            with contextlib.ExitStack() as stack:
                for server_context in server_contexts:
                    stack.enter_context(server_context)
                result = map_add_one([0, 1])
                self.assertEqual(result, [1, 2])

        context_stack = tff.framework.get_context_stack()
        with context_stack.install(context):
            run_with_servers(first_contexts)
            # Closing and re-entering the server contexts serves to simulate
            # failures and restarts at the workers. Restarts leave the workers
            # in a state that needs initialization again; entering the second
            # set of contexts ensures that the servers need to be
            # reinitialized by the controller.
            run_with_servers(second_contexts)
Example 2
0
def _get_all_contexts():
  """Returns the named execution contexts that tests are parameterized over.

  Each entry is a tuple whose first element is the test-case name and whose
  second element is the execution context. The remote entries carry a third
  element: the server contexts created for those ports (presumably entered by
  the consuming test decorator while the test runs -- verify against caller).

  Every call constructs fresh context objects.
  """
  # pyformat: disable
  return [
      ('native_local', tff.backends.native.create_local_execution_context()),
      ('native_remote',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(_WORKER_PORTS, _AGGREGATOR_PORTS)),
      ('native_sizing', tff.backends.native.create_sizing_execution_context()),
      ('native_thread_debug',
       tff.backends.native.create_thread_debugging_execution_context()),
      ('reference', tff.backends.reference.create_reference_context()),
  ]
  # pyformat: enable
Example 3
0
def _get_all_contexts():
  """Returns the named execution contexts that tests are parameterized over.

  Each entry is a tuple whose first element is the test-case name and whose
  second element is the execution context. The remote entries carry a third
  element: the server contexts created for those ports (presumably entered by
  the consuming test decorator while the test runs -- verify against caller).
  Remote contexts are included in both request/reply and streaming RPC modes.

  Every call constructs fresh context objects.
  """
  # pyformat: disable
  return [
      ('native_local', tff.backends.native.create_local_execution_context()),
      ('native_local_caching', _create_native_local_caching_context()),
      ('native_remote_request_reply',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS, rpc_mode='REQUEST_REPLY'),
       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)),
      ('native_remote_streaming',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS, rpc_mode='STREAMING'),
       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(_WORKER_PORTS, _AGGREGATOR_PORTS)),
      ('native_sizing', tff.backends.native.create_sizing_execution_context()),
      ('native_thread_debug',
       tff.backends.native.create_thread_debugging_execution_context()),
      ('reference', tff.backends.reference.create_reference_context()),
      ('test', tff.backends.test.create_test_execution_context()),
  ]
  # pyformat: enable
Example 4
0
    def test_runs_computation_streaming_with_intermediate_agg(self):
        """Sums incremented client values through streaming aggregators."""

        @tff.tf_computation(tf.int32)
        def increment(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def increment_and_sum(federated_arg):
            incremented = tff.federated_map(increment, federated_arg)
            return tff.federated_sum(incremented)

        # Both the controller and the aggregator tier use streaming RPCs.
        streaming_context = (
            remote_runtime_test_utils.create_localhost_remote_context(
                _AGGREGATOR_PORTS, rpc_mode='STREAMING'))
        aggregator_servers = (
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS, rpc_mode='STREAMING'))

        with tff.framework.get_context_stack().install(streaming_context):
            with contextlib.ExitStack() as servers_running:
                for server in aggregator_servers:
                    servers_running.enter_context(server)
                # Clients hold 0 and 1; after incrementing, 1 + 2 == 3.
                total = increment_and_sum([0, 1])
                self.assertEqual(total, 3)
Example 5
0
class TensorFlowComputationTest(parameterized.TestCase):
  """Invokes `tff.tf_computation`s under the contexts chosen by `with_contexts`."""

  @with_contexts
  def test_returns_constant(self):

    @tff.tf_computation
    def return_ten():
      return 10

    self.assertEqual(return_ten(), 10)

  @with_contexts
  def test_returns_empty_tuple(self):

    @tff.tf_computation
    def return_empty():
      return ()

    self.assertEqual(return_empty(), ())

  @with_contexts
  def test_returns_variable(self):

    @tff.tf_computation
    def return_variable():
      return tf.Variable(10, name='var')

    # The variable's value, not the variable object, comes back.
    self.assertEqual(return_variable(), 10)

  # The two dataset tests below pass an explicit list of contexts to
  # `with_contexts` instead of using it bare.
  # pyformat: disable
  @with_contexts(
      ('native_local', tff.backends.native.create_local_execution_context()),
      ('native_local_caching', _create_native_local_caching_context()),
      ('native_remote',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(_WORKER_PORTS, _AGGREGATOR_PORTS)),
      ('native_sizing', tff.backends.native.create_sizing_execution_context()),
      ('native_thread_debug',
       tff.backends.native.create_thread_debugging_execution_context()),
  )
  # pyformat: enable
  def test_takes_infinite_dataset(self):

    @tff.tf_computation
    def sum_first_ten(ds):
      return ds.take(10).reduce(np.int64(0), lambda acc, elem: acc + elem)

    endless = tf.data.Dataset.range(10).repeat()
    actual = sum_first_ten(endless)

    expected = endless.take(10).reduce(
        np.int64(0), lambda acc, elem: acc + elem)
    self.assertEqual(actual, expected)

  # pyformat: disable
  @with_contexts(
      ('native_local', tff.backends.native.create_local_execution_context()),
      ('native_local_caching', _create_native_local_caching_context()),
      ('native_remote',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(_WORKER_PORTS, _AGGREGATOR_PORTS)),
      ('native_sizing', tff.backends.native.create_sizing_execution_context()),
      ('native_thread_debug',
       tff.backends.native.create_thread_debugging_execution_context()),
  )
  # pyformat: enable
  def test_returns_infinite_dataset(self):

    @tff.tf_computation
    def make_endless_dataset():
      return tf.data.Dataset.range(10).repeat()

    actual = make_endless_dataset()
    expected = tf.data.Dataset.range(10).repeat()

    # Infinite datasets cannot be compared directly, so compare the sums of
    # a finite prefix of each.
    def sum_first_hundred(dataset):
      return dataset.take(100).reduce(
          np.int64(0), lambda acc, elem: acc + elem)

    self.assertEqual(sum_first_hundred(actual), sum_first_hundred(expected))

  @with_contexts
  def test_returns_result_with_typed_fn(self):

    @tff.tf_computation(tf.int32, tf.int32)
    def add_ints(x, y):
      return x + y

    self.assertEqual(add_ints(1, 2), 3)

  @with_contexts
  def test_raises_type_error_with_typed_fn(self):

    @tff.tf_computation(tf.int32, tf.int32)
    def add_ints(x, y):
      return x + y

    # Float arguments do not match the declared int32 parameter types.
    self.assertRaises(TypeError, add_ints, 1.0, 2.0)

  @with_contexts
  def test_returns_result_with_polymorphic_fn(self):

    @tff.tf_computation
    def add_any(x, y):
      return x + y

    # The same untyped computation is invoked with ints, then with floats.
    for call_args, want in (((1, 2), 3), ((1.0, 2.0), 3.0)):
      self.assertEqual(add_any(*call_args), want)
Example 6
0

@parameterized.named_parameters((
    'native_remote_request_reply',
    # Pass the RPC mode explicitly so this case really exercises
    # request/reply, rather than whatever the helper's default is.
    remote_runtime_test_utils.create_localhost_remote_context(
        _WORKER_PORTS, rpc_mode='REQUEST_REPLY'),
    remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
), (
    'native_remote_streaming',
    remote_runtime_test_utils.create_localhost_remote_context(
        _WORKER_PORTS, rpc_mode='STREAMING'),
    remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
), (
    'native_remote_intermediate_aggregator',
    remote_runtime_test_utils.create_localhost_remote_context(
        _AGGREGATOR_PORTS),
    remote_runtime_test_utils.create_localhost_aggregator_contexts(
        _WORKER_PORTS, _AGGREGATOR_PORTS),
))
class RemoteRuntimeConfigurationChangeTest(absltest.TestCase):
    """Remote-runtime tests parameterized at class level.

    Every test method receives the controller-side execution context and the
    server contexts for one of the named configurations above.
    """

    def test_computations_run_with_changing_clients(self, context,
                                                    server_contexts):
        # Disabled pending the referenced bug. `skipTest` raises, so nothing
        # below runs and `context`/`server_contexts` are currently unused.
        self.skipTest('b/175155128')

        @tff.tf_computation(tf.int32)
        @tf.function
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)
Example 7
0
class RemoteRuntimeIntegrationTest(parameterized.TestCase):
    """Integration tests that drive computations through localhost servers."""

    @parameterized.named_parameters(
        (
            'native_remote',
            remote_runtime_test_utils.create_localhost_remote_context(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
        ),
        (
            'native_remote_intermediate_aggregator',
            remote_runtime_test_utils.create_localhost_remote_context(
                _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
        ),
    )
    def test_computations_run_with_worker_restarts(self, context,
                                                   first_contexts,
                                                   second_contexts):
        """Runs the same computation before and after the servers restart."""

        @tff.tf_computation(tf.int32)
        def increment(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_increment(federated_arg):
            return tff.federated_map(increment, federated_arg)

        def check_with_servers(servers):
            # Keep every server context entered for one computation only;
            # leaving the ExitStack exits them all again.
            with contextlib.ExitStack() as running:
                for server in servers:
                    running.enter_context(server)
                self.assertEqual(map_increment([0, 1]), [1, 2])

        with tff.framework.get_context_stack().install(context):
            check_with_servers(first_contexts)
            # Closing and re-entering the server contexts serves to simulate
            # failures and restarts at the workers. Restarts leave the workers
            # in a state that needs initialization again; entering the second
            # set of contexts ensures that the servers need to be
            # reinitialized by the controller.
            check_with_servers(second_contexts)

    @parameterized.named_parameters(
        (
            'native_remote',
            remote_runtime_test_utils.create_localhost_remote_context(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
            list(range(len(_WORKER_PORTS) - 1)),
        ),
        (
            'native_remote_intermediate_aggregator',
            remote_runtime_test_utils.create_localhost_remote_context(
                _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
            list(range(len(_AGGREGATOR_PORTS) - 1)),
        ),
    )
    def test_computations_run_with_fewer_clients_than_remote_connections(
            self, context, serving_contexts, arg):
        """Feeds one fewer client value than there are configured ports."""
        # Disabled pending the referenced bug; `skipTest` raises, so the rest
        # of this method does not run.
        self.skipTest('b/170315887')

        @tff.tf_computation(tf.int32)
        def increment(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_increment(federated_arg):
            return tff.federated_map(increment, federated_arg)

        with tff.framework.get_context_stack().install(context):
            with contextlib.ExitStack() as running:
                for server in serving_contexts:
                    running.enter_context(server)
                self.assertEqual(map_increment(arg),
                                 [value + 1 for value in arg])