def create_thread_debugging_execution_context(default_num_clients: int = 0,
                                              clients_per_thread=1):
  """Builds a local `ExecutionContext` backed by the thread-debugging stack.

  NOTE(review): this module defines `create_thread_debugging_execution_context`
  a second time further down (taking `num_clients` instead of
  `default_num_clients`). At import time that later definition replaces this
  one. Confirm which version is intended.

  Args:
    default_num_clients: Forwarded unchanged to
      `executor_stacks.thread_debugging_executor_factory` as
      `default_num_clients`.
    clients_per_thread: Forwarded unchanged to the same factory.

  Returns:
    An `execution_context.ExecutionContext` whose executor function is the
    thread-debugging factory and whose compiler function lowers computations
    via `compiler.transform_to_native_form(..., transform_math_to_tf=True)`.
  """

  def _compile_for_debugging(computation):
    # The `transform_math_to_tf` flag presumably also rewrites mathematical
    # functions into TensorFlow as part of native-form lowering.
    return compiler.transform_to_native_form(
        computation, transform_math_to_tf=True)

  debugging_factory = executor_stacks.thread_debugging_executor_factory(
      default_num_clients=default_num_clients,
      clients_per_thread=clients_per_thread,
  )
  return execution_context.ExecutionContext(
      executor_fn=debugging_factory, compiler_fn=_compile_for_debugging)
def create_thread_debugging_execution_context(num_clients=None,
                                              clients_per_thread=1):
  """Builds a local `ExecutionContext` backed by the thread-debugging stack.

  NOTE(review): an earlier definition with this same name exists in this
  module (taking `default_num_clients`). Because this one comes later, it is
  the definition that survives at import time. Confirm that is intended.

  Args:
    num_clients: Forwarded unchanged to
      `executor_stacks.thread_debugging_executor_factory` as `num_clients`.
    clients_per_thread: Forwarded unchanged to the same factory.

  Returns:
    An `execution_context.ExecutionContext` whose compiler function first
    lowers a computation to native form and then applies
    `compiler.transform_mathematical_functions_to_tensorflow` to the result.
  """

  def _lower_for_debugging(computation):
    # Two-stage lowering: native form first, then the math-to-TF rewrite.
    return compiler.transform_mathematical_functions_to_tensorflow(
        compiler.transform_to_native_form(computation))

  debugging_factory = executor_stacks.thread_debugging_executor_factory(
      num_clients=num_clients,
      clients_per_thread=clients_per_thread,
  )
  return execution_context.ExecutionContext(
      executor_fn=debugging_factory, compiler_fn=_lower_for_debugging)
def _create_concurrent_maxthread_tuples():
  """Returns `(test_name, executor_factory, clients_per_thread)` triples.

  For each `clients_per_thread` value from 1 through 4, three triples are
  produced, in this order: a local executor factory, a sizing executor factory
  and a thread-debugging executor factory. Each factory is constructed with
  that `clients_per_thread`. The result is suitable for
  `parameterized.named_parameters`.
  """
  # (test-name template, factory function) in the per-level emission order.
  name_templates_and_factories = (
      ('local_executor_{}_clients_per_thread',
       executor_stacks.local_executor_factory),
      ('sizing_executor_{}_client_thread',
       executor_stacks.sizing_executor_factory),
      ('debug_executor_{}_client_thread',
       executor_stacks.thread_debugging_executor_factory),
  )
  return [
      (template.format(threads), make_factory(clients_per_thread=threads),
       threads)
      for threads in range(1, 5)
      for template, make_factory in name_templates_and_factories
  ]
def test_thread_debugging_executor_constructs_exactly_one_reference_resolving_executor(
    self, executor_mock):
  """Checks that the debugging stack constructs the mocked executor exactly once.

  NOTE(review): this top-level copy has no visible `mock.patch` decorator, so
  `executor_mock` must be supplied by whoever calls it. The copy inside
  `ExecutorStacksTest` patches `ReferenceResolvingExecutor` to provide it.
  """
  debugging_factory = executor_stacks.thread_debugging_executor_factory()
  cardinalities = {placements.CLIENTS: 10}
  debugging_factory.create_executor(cardinalities)
  executor_mock.assert_called_once()
class ExecutorStacksTest(parameterized.TestCase):
  """Tests for the executor factory functions in `executor_stacks`."""

  @parameterized.named_parameters(
      ('local_executor', executor_stacks.local_executor_factory),
      ('sizing_executor', executor_stacks.sizing_executor_factory),
      ('debug_executor', executor_stacks.thread_debugging_executor_factory),
  )
  def test_construction_with_no_args(self, executor_factory_fn):
    """Each factory function returns a ResourceManagingExecutorFactory."""
    executor_factory_impl = executor_factory_fn()
    self.assertIsInstance(executor_factory_impl,
                          executor_stacks.ResourceManagingExecutorFactory)

  @parameterized.named_parameters(
      ('local_executor', executor_stacks.local_executor_factory),
      ('sizing_executor', executor_stacks.sizing_executor_factory),
  )
  def test_construction_raises_with_max_fanout_one(self, executor_factory_fn):
    """`max_fanout=1` is rejected with a `ValueError`."""
    with self.assertRaises(ValueError):
      executor_factory_fn(max_fanout=1)

  @parameterized.named_parameters(
      ('local_executor_none_clients', executor_stacks.local_executor_factory()),
      ('sizing_executor_none_clients',
       executor_stacks.sizing_executor_factory()),
      ('local_executor_three_clients',
       executor_stacks.local_executor_factory(num_clients=3)),
      ('sizing_executor_three_clients',
       executor_stacks.sizing_executor_factory(num_clients=3)),
  )
  @test_utils.skip_test_for_multi_gpu
  def test_execution_of_temperature_sensor_example(self, executor):
    """Runs the temperature-sensor example on three client datasets."""
    comp = _temperature_sensor_example_next_fn()
    to_float = lambda x: tf.cast(x, tf.float32)
    temperatures = [
        tf.data.Dataset.range(10).map(to_float),
        tf.data.Dataset.range(20).map(to_float),
        tf.data.Dataset.range(30).map(to_float),
    ]
    threshold = 15.0
    with executor_test_utils.install_executor(executor):
      result = comp(temperatures, threshold)
    self.assertAlmostEqual(result, 8.333, places=3)

  @parameterized.named_parameters(
      ('local_executor', executor_stacks.local_executor_factory),
      ('sizing_executor', executor_stacks.sizing_executor_factory),
  )
  def test_execution_with_inferred_clients_larger_than_fanout(
      self, executor_factory_fn):
    """Ten client values are summed correctly with `max_fanout=3`."""

    @computations.federated_computation(computation_types.at_clients(tf.int32))
    def foo(x):
      return intrinsics.federated_sum(x)

    executor = executor_factory_fn(max_fanout=3)
    with executor_test_utils.install_executor(executor):
      result = foo([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    self.assertEqual(result, 55)

  @parameterized.named_parameters(
      ('local_executor_none_clients', executor_stacks.local_executor_factory()),
      ('sizing_executor_none_clients',
       executor_stacks.sizing_executor_factory()),
      ('debug_executor_none_clients',
       executor_stacks.thread_debugging_executor_factory()),
      ('local_executor_one_client',
       executor_stacks.local_executor_factory(num_clients=1)),
      ('sizing_executor_one_client',
       executor_stacks.sizing_executor_factory(num_clients=1)),
      ('debug_executor_one_client',
       executor_stacks.thread_debugging_executor_factory(num_clients=1)),
  )
  def test_execution_of_tensorflow(self, executor):
    """A no-argument TF computation evaluates to 10 on every executor."""

    @computations.tf_computation
    def comp():
      return tf.math.add(5, 5)

    with executor_test_utils.install_executor(executor):
      result = comp()
    self.assertEqual(result, 10)

  @parameterized.named_parameters(*_create_concurrent_maxthread_tuples())
  # This patch supplies the trailing `tf_executor_mock` argument. Without it,
  # the method is called one positional argument short and every
  # parameterization fails with a TypeError before asserting anything.
  @mock.patch(
      'tensorflow_federated.python.core.impl.executors.eager_tf_executor.EagerTFExecutor',
      return_value=ExecutorMock())
  def test_limiting_concurrency_constructs_one_eager_executor(
      self, ex_factory, clients_per_thread, tf_executor_mock):
    """Counts eager TF executor constructions for 10 clients.

    `ex_factory` and `clients_per_thread` come from
    `_create_concurrent_maxthread_tuples`. The expected count is
    `ceil(10 / clients_per_thread) + 2`.
    """
    num_clients = 10
    ex_factory.create_executor({placements.CLIENTS: num_clients})
    concurrency_level = math.ceil(num_clients / clients_per_thread)
    args_list = tf_executor_mock.call_args_list
    # One for server executor, one for unplaced executor, concurrency_level for
    # clients.
    self.assertLen(args_list, concurrency_level + 2)

  @mock.patch(
      'tensorflow_federated.python.core.impl.executors.reference_resolving_executor.ReferenceResolvingExecutor',
      return_value=ExecutorMock())
  def test_thread_debugging_executor_constructs_exactly_one_reference_resolving_executor(
      self, executor_mock):
    """The debugging stack constructs one ReferenceResolvingExecutor."""
    executor_stacks.thread_debugging_executor_factory().create_executor(
        {placements.CLIENTS: 10})
    executor_mock.assert_called_once()

  @parameterized.named_parameters(
      ('local_executor', executor_stacks.local_executor_factory),
      ('sizing_executor', executor_stacks.sizing_executor_factory),
      ('debug_executor', executor_stacks.thread_debugging_executor_factory),
  )
  def test_create_executor_raises_with_wrong_cardinalities(
      self, executor_factory_fn):
    """A factory built for 5 clients rejects cardinalities with 1 client."""
    executor_factory_impl = executor_factory_fn(num_clients=5)
    cardinalities = {
        placements.SERVER: 1,
        None: 1,
        placements.CLIENTS: 1,
    }
    with self.assertRaises(ValueError):
      executor_factory_impl.create_executor(cardinalities)