class WorkerFailureTest(parameterized.TestCase):
  """Checks that remote execution keeps working across server restarts."""

  @parameterized.named_parameters(
      ('native_remote_request_reply',
       remote_runtime_test_utils.create_localhost_remote_context(
           _WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS)),
      ('native_remote_streaming',
       remote_runtime_test_utils.create_localhost_remote_context(
           _WORKER_PORTS, rpc_mode='STREAMING'),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(
           _AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(
           _WORKER_PORTS, _AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(
           _WORKER_PORTS, _AGGREGATOR_PORTS)),
  )
  def test_computations_run_with_worker_restarts(self, context, first_contexts,
                                                 second_contexts):
    """Runs the same federated computation against two server lifetimes.

    Args:
      context: The controller-side execution context to install.
      first_contexts: Server contexts entered for the first run.
      second_contexts: Server contexts entered for the second run, after every
        context in `first_contexts` has been exited.
    """

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    def run_with_servers(server_contexts):
      # Every server is entered before the call and all of them are exited
      # together when the ExitStack unwinds.
      with contextlib.ExitStack() as servers:
        for server in server_contexts:
          servers.enter_context(server)
        self.assertEqual(map_add_one([0, 1]), [1, 2])

    with tff.framework.get_context_stack().install(context):
      run_with_servers(first_contexts)
      # Closing and re-entering the server contexts serves to simulate failures
      # and restarts at the workers. Restarts leave the workers in a state that
      # needs initialization again; entering the second context ensures that the
      # servers need to be reinitialized by the controller.
      run_with_servers(second_contexts)
def _get_all_contexts():
  """Constructs the named execution contexts used by the parameterized tests.

  Returns:
    A list of tuples. Element 0 is the test-case name and element 1 is an
    already-constructed execution context. Remote entries carry a third
    element holding the server contexts to enter alongside that context.
  """
  native = tff.backends.native
  remote = remote_runtime_test_utils

  all_contexts = []
  all_contexts.append(
      ('native_local', native.create_local_execution_context()))
  all_contexts.append(
      ('native_remote',
       remote.create_localhost_remote_context(_WORKER_PORTS),
       remote.create_localhost_worker_contexts(_WORKER_PORTS)))
  all_contexts.append(
      ('native_remote_intermediate_aggregator',
       remote.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote.create_localhost_aggregator_contexts(_WORKER_PORTS,
                                                   _AGGREGATOR_PORTS)))
  all_contexts.append(
      ('native_sizing', native.create_sizing_execution_context()))
  all_contexts.append(
      ('native_thread_debug',
       native.create_thread_debugging_execution_context()))
  all_contexts.append(
      ('reference', tff.backends.reference.create_reference_context()))
  return all_contexts
def test_computations_run_with_worker_restarts(self):
  """Verifies in-process workers can be torn down and replaced between calls."""
  controller = remote_runtime_test_utils.create_localhost_remote_context(
      _WORKER_PORTS)
  initial_workers = remote_runtime_test_utils.create_inprocess_worker_contexts(
      _WORKER_PORTS)
  restarted_workers = (
      remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS))

  @tff.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @tff.federated_computation(tff.type_at_clients(tf.int32))
  def map_add_one(federated_arg):
    return tff.federated_map(add_one, federated_arg)

  def run_with_workers(worker_contexts):
    # All workers are live for the call and are exited together afterwards.
    with contextlib.ExitStack() as live_workers:
      for worker in worker_contexts:
        live_workers.enter_context(worker)
      self.assertEqual(map_add_one([0, 1]), [1, 2])

  with tff.framework.get_context_stack().install(controller):
    run_with_workers(initial_workers)
    # Closing and re-entering the server contexts serves to simulate failures
    # and restarts at the workers. Restarts leave the workers in a state that
    # needs initialization again; entering the second context ensures that the
    # servers need to be reinitialized by the controller.
    run_with_workers(restarted_workers)
def _get_all_contexts():
  """Constructs the named execution contexts used by the parameterized tests.

  Returns:
    A list of tuples. Element 0 is the test-case name and element 1 is an
    already-constructed execution context. Remote entries carry a third
    element holding the server contexts to enter alongside that context.
  """
  native = tff.backends.native
  remote = remote_runtime_test_utils

  all_contexts = []
  all_contexts.append(
      ('native_local', native.create_local_execution_context()))
  all_contexts.append(
      ('native_local_caching', _create_native_local_caching_context()))
  all_contexts.append(
      ('native_remote_request_reply',
       remote.create_localhost_remote_context(
           _WORKER_PORTS, rpc_mode='REQUEST_REPLY'),
       remote.create_localhost_worker_contexts(_WORKER_PORTS)))
  all_contexts.append(
      ('native_remote_streaming',
       remote.create_localhost_remote_context(
           _WORKER_PORTS, rpc_mode='STREAMING'),
       remote.create_localhost_worker_contexts(_WORKER_PORTS)))
  all_contexts.append(
      ('native_remote_intermediate_aggregator',
       remote.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote.create_localhost_aggregator_contexts(_WORKER_PORTS,
                                                   _AGGREGATOR_PORTS)))
  all_contexts.append(
      ('native_sizing', native.create_sizing_execution_context()))
  all_contexts.append(
      ('native_thread_debug',
       native.create_thread_debugging_execution_context()))
  all_contexts.append(
      ('reference', tff.backends.reference.create_reference_context()))
  all_contexts.append(
      ('test', tff.backends.test.create_test_execution_context()))
  return all_contexts
def test_worker_going_down_with_fixed_clients_per_round(self):
  """Checks a fixed client count survives the second worker shutting down."""
  tff_context = remote_runtime_test_utils.create_localhost_remote_context(
      _WORKER_PORTS, default_num_clients=10)
  workers = remote_runtime_test_utils.create_inprocess_worker_contexts(
      _WORKER_PORTS)

  @tff.federated_computation(tff.type_at_server(tf.int32))
  def sum_arg(x):
    return tff.federated_sum(tff.federated_broadcast(x))

  # The result should be 10 because the context is configured with
  # default_num_clients=10: broadcasting 1 to 10 clients and summing gives 10.
  expected = 10
  with tff.framework.get_context_stack().install(tff_context):
    with workers[0]:
      with workers[1]:
        # With both workers live, we should get 10 back.
        self.assertEqual(sum_arg(1), expected)
      # Leaving the inner context kills the second worker, but should leave
      # the result untouched.
      self.assertEqual(sum_arg(1), expected)
def test_computations_run_with_worker_restarts_and_aggregation(self):
  """Replaces the workers while the aggregation servers stay up."""
  context = remote_runtime_test_utils.create_localhost_remote_context(
      _AGGREGATOR_PORTS)
  # TODO(b/180524229): Swap for inprocess aggregator when mutex
  # corruption on shutdown is understood.
  aggregation_contexts = (
      remote_runtime_test_utils
      .create_standalone_subprocess_aggregator_contexts(
          _WORKER_PORTS, _AGGREGATOR_PORTS))
  first_worker_contexts = (
      remote_runtime_test_utils.create_inprocess_worker_contexts(
          _WORKER_PORTS))
  second_worker_contexts = (
      remote_runtime_test_utils.create_inprocess_worker_contexts(
          _WORKER_PORTS))

  @tff.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @tff.federated_computation(tff.type_at_clients(tf.int32))
  def map_add_one(federated_arg):
    return tff.federated_map(add_one, federated_arg)

  def run_on_workers(worker_contexts):
    # The workers are entered for a single call and exited together afterwards.
    with contextlib.ExitStack() as workers:
      for worker in worker_contexts:
        workers.enter_context(worker)
      self.assertEqual(map_add_one([0, 1]), [1, 2])

  with tff.framework.get_context_stack().install(context):
    with contextlib.ExitStack() as aggregators:
      for aggregator in aggregation_contexts:
        aggregators.enter_context(aggregator)
      run_on_workers(first_worker_contexts)
      # Reinitializing the workers without leaving the aggregation context
      # simulates a worker failure, while the aggregator keeps running.
      run_on_workers(second_worker_contexts)
def test_runs_computation_streaming_with_intermediate_agg(self):
  """Maps and sums client values through aggregators using STREAMING RPCs."""

  @tff.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @tff.federated_computation(tff.type_at_clients(tf.int32))
  def map_add_one_and_sum(federated_arg):
    return tff.federated_sum(tff.federated_map(add_one, federated_arg))

  utils = remote_runtime_test_utils
  execution_context = utils.create_localhost_remote_context(
      _AGGREGATOR_PORTS, rpc_mode='STREAMING')
  aggregator_contexts = utils.create_localhost_aggregator_contexts(
      _WORKER_PORTS, _AGGREGATOR_PORTS, rpc_mode='STREAMING')

  stack_of_contexts = tff.framework.get_context_stack()
  with stack_of_contexts.install(execution_context):
    with contextlib.ExitStack() as servers:
      for server in aggregator_contexts:
        servers.enter_context(server)
      # (0 + 1) + (1 + 1) == 3
      total = map_add_one_and_sum([0, 1])
      self.assertEqual(total, 3)
def test_computations_run_with_partially_available_workers(self):
  """Runs a computation when only the first configured worker is served."""
  tff_context = remote_runtime_test_utils.create_localhost_remote_context(
      _WORKER_PORTS)
  # Only the first port gets a server in this test; the controller is still
  # configured with every port in _WORKER_PORTS.
  only_first_port = [_WORKER_PORTS[0]]
  server_contexts = remote_runtime_test_utils.create_inprocess_worker_contexts(
      only_first_port)

  @tff.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @tff.federated_computation(tff.type_at_clients(tf.int32))
  def map_add_one(federated_arg):
    return tff.federated_map(add_one, federated_arg)

  with tff.framework.get_context_stack().install(tff_context):
    with contextlib.ExitStack() as live_servers:
      for server in server_contexts:
        live_servers.enter_context(server)
      mapped = map_add_one([0, 1])
      self.assertEqual(mapped, [1, 2])
def _get_all_contexts(): """Returns a list containing a (name, context_fn) tuple for each context.""" # pylint: disable=unnecessary-lambda # pyformat: disable return [ ('native_local_python', lambda: tff.backends.native.create_local_python_execution_context()), ('native_mergeable', lambda: _create_local_mergeable_comp_context()), ('native_remote', lambda: remote_runtime_test_utils. create_localhost_remote_context(WORKER_PORTS), lambda: remote_runtime_test_utils.create_inprocess_worker_contexts( WORKER_PORTS)), ('native_remote_intermediate_aggregator', lambda: remote_runtime_test_utils.create_localhost_remote_context( AGGREGATOR_PORTS), lambda: remote_runtime_test_utils. create_inprocess_aggregator_contexts(WORKER_PORTS, AGGREGATOR_PORTS)), ('native_sizing', lambda: tff.backends.native.create_sizing_execution_context()), ('native_thread_debug', lambda: tff.backends.native. create_thread_debugging_execution_context()), ('test_python', lambda: tff.backends.test.create_test_python_execution_context()), ]
class TensorFlowComputationTest(parameterized.TestCase):
  """Exercises `tff.tf_computation` across the execution contexts under test."""

  def _infinite_dataset_contexts():  # pylint: disable=no-method-argument
    """Builds the explicit context parameters for the infinite-dataset tests.

    This function is only called while the class body executes, and it is
    deleted at the end of the body, so it never becomes a method. Each call
    constructs a fresh set of contexts. As a result, every decorated test
    receives its own contexts, exactly as when the list was written inline
    twice.

    Returns:
      A tuple of `(name, context[, server_contexts])` tuples for
      `with_contexts`.
    """
    # pyformat: disable
    return (
        ('native_local', tff.backends.native.create_local_execution_context()),
        ('native_local_caching', _create_native_local_caching_context()),
        ('native_remote',
         remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
         remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS)),
        ('native_remote_intermediate_aggregator',
         remote_runtime_test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
         remote_runtime_test_utils.create_localhost_aggregator_contexts(_WORKER_PORTS, _AGGREGATOR_PORTS)),
        ('native_sizing', tff.backends.native.create_sizing_execution_context()),
        ('native_thread_debug', tff.backends.native.create_thread_debugging_execution_context()),
    )
    # pyformat: enable

  @with_contexts
  def test_returns_constant(self):

    @tff.tf_computation
    def foo():
      return 10

    result = foo()
    self.assertEqual(result, 10)

  @with_contexts
  def test_returns_empty_tuple(self):

    @tff.tf_computation
    def foo():
      return ()

    result = foo()
    self.assertEqual(result, ())

  @with_contexts
  def test_returns_variable(self):

    @tff.tf_computation
    def foo():
      return tf.Variable(10, name='var')

    result = foo()
    self.assertEqual(result, 10)

  @with_contexts(*_infinite_dataset_contexts())
  def test_takes_infinite_dataset(self):

    @tff.tf_computation
    def foo(ds):
      return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    ds = tf.data.Dataset.range(10).repeat()
    actual_result = foo(ds)

    expected_result = ds.take(10).reduce(np.int64(0), lambda x, y: x + y)
    self.assertEqual(actual_result, expected_result)

  @with_contexts(*_infinite_dataset_contexts())
  def test_returns_infinite_dataset(self):

    @tff.tf_computation
    def foo():
      return tf.data.Dataset.range(10).repeat()

    actual_result = foo()

    expected_result = tf.data.Dataset.range(10).repeat()
    # Only a finite prefix of each infinite dataset can be compared.
    self.assertEqual(
        actual_result.take(100).reduce(np.int64(0), lambda x, y: x + y),
        expected_result.take(100).reduce(np.int64(0), lambda x, y: x + y))

  @with_contexts
  def test_returns_result_with_typed_fn(self):

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    result = foo(1, 2)
    self.assertEqual(result, 3)

  @with_contexts
  def test_raises_type_error_with_typed_fn(self):

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    # Float arguments do not match the declared int32 parameter types.
    with self.assertRaises(TypeError):
      foo(1.0, 2.0)

  @with_contexts
  def test_returns_result_with_polymorphic_fn(self):

    @tff.tf_computation
    def foo(x, y):
      return x + y

    result = foo(1, 2)
    self.assertEqual(result, 3)
    result = foo(1.0, 2.0)
    self.assertEqual(result, 3.0)

  # The factory is only needed while the class body runs; remove it so it is
  # not left behind as an attribute of the test class.
  del _infinite_dataset_contexts
worker_contexts = remote_runtime_test_utils.create_localhost_aggregator_contexts( _WORKER_PORTS, _AGGREGATOR_PORTS, rpc_mode='STREAMING') context_stack = tff.framework.get_context_stack() with context_stack.install(execution_context): with contextlib.ExitStack() as stack: for server_context in worker_contexts: stack.enter_context(server_context) result = map_add_one_and_sum([0, 1]) self.assertEqual(result, 3) @parameterized.named_parameters(( 'native_remote_request_reply', remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS), remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS), ), ( 'native_remote_streaming', remote_runtime_test_utils.create_localhost_remote_context( _WORKER_PORTS, rpc_mode='STREAMING'), remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS), ), ( 'native_remote_intermediate_aggregator', remote_runtime_test_utils.create_localhost_remote_context( _AGGREGATOR_PORTS), remote_runtime_test_utils.create_localhost_aggregator_contexts( _WORKER_PORTS, _AGGREGATOR_PORTS), )) class RemoteRuntimeConfigurationChangeTest(absltest.TestCase): def test_computations_run_with_changing_clients(self, context,
context_stack = tff.framework.get_context_stack() with context_stack.install(tff_context): with worker_contexts[0]: with worker_contexts[1]: # With both workers live, we should get 10 back. self.assertEqual(sum_arg(1), 10) # Leaving the inner context kills the second worker, but should leave # the result untouched. self.assertEqual(sum_arg(1), 10) @parameterized.named_parameters(( 'native_remote', remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS), remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS), ), ( 'native_remote_intermediate_aggregator', remote_runtime_test_utils.create_localhost_remote_context( _AGGREGATOR_PORTS), remote_runtime_test_utils.create_inprocess_aggregator_contexts( _WORKER_PORTS, _AGGREGATOR_PORTS), )) class RemoteRuntimeConfigurationChangeTest(parameterized.TestCase): def test_computations_run_with_changing_clients(self, context, server_contexts): @tff.tf_computation(tf.int32) @tf.function
class RemoteRuntimeIntegrationTest(parameterized.TestCase):
  """Integration tests for a controller talking to localhost servers."""

  @parameterized.named_parameters(
      ('native_remote',
       remote_runtime_test_utils.create_localhost_remote_context(
           _WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(
           _AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(
           _WORKER_PORTS, _AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(
           _WORKER_PORTS, _AGGREGATOR_PORTS)),
  )
  def test_computations_run_with_worker_restarts(self, context, first_contexts,
                                                 second_contexts):
    """Runs the same computation against two successive sets of servers.

    Args:
      context: The controller-side execution context to install.
      first_contexts: Server contexts entered for the first run.
      second_contexts: Server contexts entered for the second run, after every
        context in `first_contexts` has been exited.
    """

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    def run_with_servers(server_contexts):
      with contextlib.ExitStack() as servers:
        for server in server_contexts:
          servers.enter_context(server)
        self.assertEqual(map_add_one([0, 1]), [1, 2])

    with tff.framework.get_context_stack().install(context):
      run_with_servers(first_contexts)
      # Closing and re-entering the server contexts serves to simulate failures
      # and restarts at the workers. Restarts leave the workers in a state that
      # needs initialization again; entering the second context ensures that the
      # servers need to be reinitialized by the controller.
      run_with_servers(second_contexts)

  @parameterized.named_parameters(
      ('native_remote',
       remote_runtime_test_utils.create_localhost_remote_context(
           _WORKER_PORTS),
       remote_runtime_test_utils.create_localhost_worker_contexts(
           _WORKER_PORTS),
       list(range(len(_WORKER_PORTS) - 1))),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(
           _AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_localhost_aggregator_contexts(
           _WORKER_PORTS, _AGGREGATOR_PORTS),
       list(range(len(_AGGREGATOR_PORTS) - 1))),
  )
  def test_computations_run_with_fewer_clients_than_remote_connections(
      self, context, serving_contexts, arg):
    """Maps over one fewer client value than there are configured ports.

    Currently skipped (b/170315887).

    Args:
      context: The controller-side execution context to install.
      serving_contexts: Server contexts to enter for the call.
      arg: Client values; its length is one less than the number of ports.
    """
    self.skipTest('b/170315887')

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    with tff.framework.get_context_stack().install(context):
      with contextlib.ExitStack() as servers:
        for server in serving_contexts:
          servers.enter_context(server)
        mapped = map_add_one(arg)
        expected = [value + 1 for value in arg]
        self.assertEqual(mapped, expected)
class TensorFlowComputationTest(tf.test.TestCase, parameterized.TestCase):
  """Exercises `tff.tf_computation` under the contexts from `test_contexts`."""

  def _infinite_dataset_contexts():  # pylint: disable=no-method-argument
    """Returns the context factories used by the infinite-dataset tests.

    This function is only called while the class body executes, and it is
    deleted at the end of the body, so it never becomes a method. Each call
    builds a new tuple of lambdas, exactly as when the list was written out
    inline for each decorator.

    Returns:
      A tuple of `(name, context_fn[, server_contexts_fn])` tuples for
      `test_contexts.with_contexts`.
    """
    # pylint: disable=unnecessary-lambda
    # pyformat: disable
    return (
        ('native_local',
         lambda: tff.backends.native.create_local_python_execution_context()),
        ('native_remote',
         lambda: remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.WORKER_PORTS),
         lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
             test_contexts.WORKER_PORTS)),
        ('native_remote_intermediate_aggregator',
         lambda: remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.AGGREGATOR_PORTS),
         lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(
             test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
        ('native_sizing',
         lambda: tff.backends.native.create_sizing_execution_context()),
        ('native_thread_debug',
         lambda: tff.backends.native.create_thread_debugging_execution_context()),
    )
    # pyformat: enable

  @test_contexts.with_contexts
  def test_create_call_take_two_from_stateful_dataset(self):

    vocab = ['a', 'b', 'c', 'd', 'e', 'f']

    @tff.tf_computation(tff.SequenceType(tf.string))
    def take_two(ds):
      # The lookup table is created inside the computation and used by the
      # dataset's map function.
      table = tf.lookup.StaticVocabularyTable(
          tf.lookup.KeyValueTensorInitializer(
              vocab, tf.range(len(vocab), dtype=tf.int64)),
          num_oov_buckets=1)
      ds = ds.map(table.lookup)
      return ds.take(2)

    ds = tf.data.Dataset.from_tensor_slices(vocab)
    result = take_two(ds)
    self.assertCountEqual([x.numpy() for x in result], [0, 1])

  @test_contexts.with_contexts
  def test_twice_used_variable_keeps_separate_state(self):

    def count_one_body():
      variable = tf.Variable(initial_value=0, name='var_of_interest')
      with tf.control_dependencies([variable.assign_add(1)]):
        return variable.read_value()

    count_one_1 = tff.tf_computation(count_one_body)
    count_one_2 = tff.tf_computation(count_one_body)

    @tff.tf_computation
    def count_one_twice():
      return count_one_1(), count_one_1(), count_one_2()

    # Expecting (1, 1, 1) asserts that no call observes another call's
    # increment.
    self.assertEqual((1, 1, 1), count_one_twice())

  @test_contexts.with_contexts
  def test_dynamic_lookup_table(self):

    @tff.tf_computation(
        tff.TensorType(shape=[None], dtype=tf.string),
        tff.TensorType(shape=[None], dtype=tf.string))
    def comp(table_args, to_lookup):
      values = tf.range(tf.shape(table_args)[0])
      initializer = tf.lookup.KeyValueTensorInitializer(table_args, values)
      table = tf.lookup.StaticHashTable(initializer, default_value=101)
      return table.lookup(to_lookup)

    # 'z' is not a key, so it maps to the default value 101.
    result = comp(tf.constant(['a', 'b', 'c']), tf.constant(['a', 'z', 'c']))
    self.assertAllEqual(result, [0, 101, 2])

  @test_contexts.with_contexts
  def test_reinitialize_dynamic_lookup_table(self):

    @tff.tf_computation(
        tff.TensorType(shape=[None], dtype=tf.string),
        tff.TensorType(shape=[], dtype=tf.string))
    def comp(table_args, to_lookup):
      values = tf.range(tf.shape(table_args)[0])
      initializer = tf.lookup.KeyValueTensorInitializer(table_args, values)
      table = tf.lookup.StaticHashTable(initializer, default_value=101)
      return table.lookup(to_lookup)

    # The second call supplies a larger key set; finding 'd' at index 3 shows
    # the table was rebuilt from the new arguments.
    expected_zero = comp(tf.constant(['a', 'b', 'c']), tf.constant('a'))
    expected_three = comp(tf.constant(['a', 'b', 'c', 'd']), tf.constant('d'))

    self.assertEqual(expected_zero, 0)
    self.assertEqual(expected_three, 3)

  @test_contexts.with_contexts
  def test_returns_constant(self):

    @tff.tf_computation
    def foo():
      return 10

    result = foo()
    self.assertEqual(result, 10)

  @test_contexts.with_contexts
  def test_returns_empty_tuple(self):

    @tff.tf_computation
    def foo():
      return ()

    result = foo()
    self.assertEqual(result, ())

  @test_contexts.with_contexts
  def test_returns_variable(self):

    @tff.tf_computation
    def foo():
      return tf.Variable(10, name='var')

    result = foo()
    self.assertEqual(result, 10)

  @test_contexts.with_contexts(*_infinite_dataset_contexts())
  def test_takes_infinite_dataset(self):

    @tff.tf_computation
    def foo(ds):
      return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    ds = tf.data.Dataset.range(10).repeat()
    actual_result = foo(ds)

    expected_result = ds.take(10).reduce(np.int64(0), lambda x, y: x + y)
    self.assertEqual(actual_result, expected_result)

  @test_contexts.with_contexts(*_infinite_dataset_contexts())
  def test_returns_infinite_dataset(self):

    @tff.tf_computation
    def foo():
      return tf.data.Dataset.range(10).repeat()

    actual_result = foo()

    expected_result = tf.data.Dataset.range(10).repeat()
    # Only a finite prefix of each infinite dataset can be compared.
    self.assertEqual(
        actual_result.take(100).reduce(np.int64(0), lambda x, y: x + y),
        expected_result.take(100).reduce(np.int64(0), lambda x, y: x + y))

  @test_contexts.with_contexts
  def test_returns_result_with_typed_fn(self):

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    result = foo(1, 2)
    self.assertEqual(result, 3)

  @test_contexts.with_contexts
  def test_raises_type_error_with_typed_fn(self):

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    # Float arguments do not match the declared int32 parameter types.
    with self.assertRaises(TypeError):
      foo(1.0, 2.0)

  @test_contexts.with_contexts
  def test_returns_result_with_polymorphic_fn(self):

    @tff.tf_computation
    def foo(x, y):
      return x + y

    result = foo(1, 2)
    self.assertEqual(result, 3)
    result = foo(1.0, 2.0)
    self.assertEqual(result, 3.0)

  # The factory is only needed while the class body runs; remove it so it is
  # not left behind as an attribute of the test class.
  del _infinite_dataset_contexts