def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields an executor/tracer pair served through a local gRPC server.

  Starts a gRPC server (backed by a one-worker pool) that hosts an
  `ExecutorService` wrapping a tracing local executor with 3 clients, then
  connects a `RemoteExecutor` to it and installs a `LambdaExecutor` over that
  remote executor as the default executor.

  On exit the default executor is reset, the client channel is closed when the
  channel object supports it, and the server is stopped.

  Args:
    rpc_mode: The RPC mode forwarded to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the installed `LambdaExecutor`) and
    `tracer` (the `TracingExecutor` wrapping the server-side executor).
  """
  free_port = portpicker.pick_unused_port()
  grpc_server = grpc.server(logging_pool.pool(max_workers=1))
  grpc_server.add_insecure_port('[::]:{}'.format(free_port))

  # Server side: a traced local stack exposed through the executor service.
  local_stack = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracing_executor = executor_test_utils.TracingExecutor(local_stack)
  executor_pb2_grpc.add_ExecutorServicer_to_server(
      executor_service.ExecutorService(tracing_executor), grpc_server)
  grpc_server.start()

  # Client side: a remote executor talking to the server above.
  client_channel = grpc.insecure_channel('localhost:{}'.format(free_port))
  lambda_ex = lambda_executor.LambdaExecutor(
      remote_executor.RemoteExecutor(client_channel, rpc_mode))
  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(lambda _: lambda_ex))

  context_type = collections.namedtuple('_', 'executor tracer')
  try:
    yield context_type(lambda_ex, tracing_executor)
  finally:
    set_default_executor.set_default_executor()
    try:
      client_channel.close()
    except AttributeError:
      pass  # The public gRPC channel API does not provide close().
    finally:
      grpc_server.stop(None)
def worker_pool_executor_factory(executors,
                                 max_fanout=100
                                ) -> executor_factory.ExecutorFactory:
  """Builds an executor factory that fans work out to a pool of workers.

  Args:
    executors: A non-empty list of `tff.framework.Executor` instances that
      forward work to workers in the worker pool. Any executor type is
      accepted, though these are typically `tff.framework.RemoteExecutor`
      instances. Each one is completed into its own stack.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy.
      When more stacks are aggregated than `max_fanout` allows at one level,
      the resulting executor stack has several levels of aggregators, with a
      height on the order of `log(number of stacks) / log(max_fanout)`.

  Returns:
    An `executor_factory.ExecutorFactory` whose stack function ignores the
    requested cardinalities and returns the aggregation of the worker stacks.

  Raises:
    ValueError: If `executors` is empty or `max_fanout` is less than 2.
  """
  py_typecheck.check_type(executors, list)
  py_typecheck.check_type(max_fanout, int)
  if not executors:
    raise ValueError('The list executors cannot be empty.')
  if max_fanout < 2:
    raise ValueError('Max fanout must be greater than 1.')

  # Completed eagerly, once, at factory-construction time.
  worker_stacks = list(map(_complete_stack, executors))

  def _build_aggregated_stack(cardinalities):
    del cardinalities  # The worker pool's shape is fixed; requests are ignored.
    return _aggregate_stacks(worker_stacks, max_fanout)

  return executor_factory.ExecutorFactoryImpl(
      executor_stack_fn=_build_aggregated_stack)
def test_cleanup_succeeds_without_init(self):
  """Cleaning up a factory that never created an executor must not raise."""
  never_used_factory = executor_factory.ExecutorFactoryImpl(
      lambda unused_cardinalities: eager_executor.EagerExecutor())
  never_used_factory.clean_up_executors()
def test_concrete_class_instantiates_stack_fn(self):
  """Constructing the concrete factory from a stack function succeeds."""
  factory_cls = executor_factory.ExecutorFactoryImpl
  built = factory_cls(lambda unused_cardinalities: eager_executor.EagerExecutor())
  self.assertIsInstance(built, factory_cls)
def _create_inferred_cardinality_factory(
    max_fanout, stack_func,
    clients_per_thread) -> executor_factory.ExecutorFactory:
  """Creates an executor factory whose client count comes from each request.

  Args:
    max_fanout: Forwarded unchanged to `_create_full_stack`.
    stack_func: Forwarded unchanged to `_create_full_stack`.
    clients_per_thread: Forwarded unchanged to `_create_full_stack`.

  Returns:
    An `executor_factory.ExecutorFactory` whose stack function validates the
    requested cardinalities and builds a full stack sized by the `CLIENTS`
    cardinality (0 when `CLIENTS` is absent).
  """

  def _build_for_cardinalities(cardinalities):
    """Validates `cardinalities` and constructs the matching executor stack.

    Raises:
      ValueError: On a placement other than CLIENTS/SERVER, or on a
        cardinality that is not positive.
    """
    py_typecheck.check_type(cardinalities, dict)
    allowed_placements = (placement_literals.CLIENTS,
                          placement_literals.SERVER)
    for placement, count in cardinalities.items():
      py_typecheck.check_type(placement, placement_literals.PlacementLiteral)
      if placement not in allowed_placements:
        raise ValueError('Unsupported placement: {}.'.format(placement))
      if count <= 0:
        raise ValueError(
            'Cardinality must be at '
            'least one; you have passed {} for placement {}.'.format(
                count, placement))
    client_count = cardinalities.get(placement_literals.CLIENTS, 0)
    return _create_full_stack(client_count, max_fanout, stack_func,
                              clients_per_thread)

  return executor_factory.ExecutorFactoryImpl(
      executor_stack_fn=_build_for_cardinalities)
def test_call_constructs_executor(self):
  """`create_executor` returns an `executor_base.Executor` instance."""
  factory = executor_factory.ExecutorFactoryImpl(
      lambda unused_cardinalities: eager_executor.EagerExecutor())
  self.assertIsInstance(factory.create_executor({}), executor_base.Executor)
def test_cleanup_calls_close(self):
  """`clean_up_executors` closes the executor it created exactly once."""
  eager = eager_executor.EagerExecutor()
  eager.close = mock.MagicMock()
  factory = executor_factory.ExecutorFactoryImpl(
      lambda unused_cardinalities: eager)
  factory.create_executor({})
  factory.clean_up_executors()
  eager.close.assert_called_once()
def _create_explicit_cardinality_factory(
    num_clients, max_fanout, stack_func,
    clients_per_thread) -> executor_factory.ExecutorFactory:
  """Creates an executor factory pinned to a fixed number of clients.

  Args:
    num_clients: The number of clients every constructed stack will have.
    max_fanout: Forwarded unchanged to `_create_full_stack`.
    stack_func: Forwarded unchanged to `_create_full_stack`.
    clients_per_thread: Forwarded unchanged to `_create_full_stack`.

  Returns:
    An `executor_factory.ExecutorFactory` whose stack function builds a
    `num_clients`-client stack. The stack function raises `ValueError` when a
    request names a different `CLIENTS` cardinality; requests that omit
    `CLIENTS` are accepted.
  """

  def _build_fixed_stack(cardinalities):
    requested_clients = cardinalities.get(placement_literals.CLIENTS)
    # A request without a CLIENTS entry is compatible with any fixed size.
    if requested_clients is not None and requested_clients != num_clients:
      raise ValueError('Expected to construct an executor with {} clients, '
                       'but executor is hardcoded for {}'.format(
                           requested_clients, num_clients))
    return _create_full_stack(num_clients, max_fanout, stack_func,
                              clients_per_thread)

  return executor_factory.ExecutorFactoryImpl(
      executor_stack_fn=_build_fixed_stack)
def setUp(self):
  """Installs a default executor built from nested middle/worker stacks."""
  super().setUp()
  # 2 clients per worker stack * 3 worker stacks * 2 middle stacks
  self._num_clients = 12

  def _middle_over_workers():
    # One middle stack aggregating three freshly built worker stacks.
    return _create_middle_stack([_create_worker_stack() for _ in range(3)])

  def _stack_fn(cardinalities):
    del cardinalities  # The hierarchy has a fixed shape.
    return _create_middle_stack([_middle_over_workers() for _ in range(2)])

  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(_stack_fn))
def test_construction_with_multiple_cardinalities_reuses_existing_stacks(
    self):
  """Repeated requests for the same cardinalities must not rebuild a stack.

  Two distinct cardinality dicts are each requested twice; the stack function
  is expected to run only twice in total.
  """
  ex = eager_executor.EagerExecutor()
  ex.close = mock.MagicMock()
  stack_fn_calls = []

  def _stack_fn(cardinalities):
    stack_fn_calls.append(cardinalities)
    return ex

  factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
  for _ in range(2):
    factory.create_executor({})
    factory.create_executor({placement_literals.SERVER: 1})
  self.assertEqual(len(stack_fn_calls), 2)
def test_basic_functionality(self):
  """Runs a sequence reduction on an eager default executor, then resets it.

  The computation sums the first 5 elements of an endless dataset of 5s, so
  the expected result is 25. After the default executor is reset, the current
  context's type name must contain 'ExecutionContext'.
  """

  @computations.tf_computation(computation_types.SequenceType(tf.int32))
  def comp(ds):
    return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)

  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(
          lambda _: eager_executor.EagerExecutor()))
  try:
    ds = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
    v = comp(ds)
    self.assertEqual(v, 25)
  finally:
    # Always restore the default, even if the computation or the assertion
    # fails, so the eager executor does not leak into later tests.
    set_default_executor.set_default_executor()
  self.assertIn(
      'ExecutionContext',
      str(type(context_stack_impl.context_stack.current).__name__))
def test_end_to_end(self):
  """Runs a TF computation twice on a default `ConcurrentExecutor`."""

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return tf.add(x, 1)

  ex = concurrent_executor.ConcurrentExecutor(eager_executor.EagerExecutor())
  set_default_executor.set_default_executor(
      executor_factory.ExecutorFactoryImpl(lambda _: ex))
  try:
    self.assertEqual(add_one(7), 8)
    # After this invocation, the ConcurrentExecutor has been closed, and needs
    # to be re-initialized.
    self.assertEqual(add_one(8), 9)
  finally:
    # Reset even when an assertion fails, so this executor does not remain
    # the global default for subsequent tests.
    set_default_executor.set_default_executor()
def test_runs_tf(test_obj, executor):
  """Checks that `executor` can run a minimal TF computation.

  Installs `executor` as the default (via a factory that always returns it)
  and asserts through `test_obj` that `_dummy_tf_computation()` yields 10.
  The default executor is left installed afterwards; resetting it is up to
  the caller.

  Raises:
    TypeError: (from `py_typecheck.check_type`) if `executor` is not an
      `executor_base.Executor`.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  always_this_executor = executor_factory.ExecutorFactoryImpl(
      lambda unused_cardinalities: executor)
  set_default_executor.set_default_executor(always_this_executor)
  result = _dummy_tf_computation()
  test_obj.assertEqual(result, 10)