def test_create_value_reraises_type_error(self, mock_stub):
    """Checks that a `TypeError` from the stub's `create_value` propagates."""
    mock_stub.create_value = mock.Mock(side_effect=TypeError)
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    # Both creating and running the coroutine happen inside the
    # `assertRaises` scope so the error is caught wherever it surfaces.
    with self.assertRaises(TypeError):
        pending = remote_ex.create_value(1, tf.int32)
        asyncio.run(pending)
def _configure_remote_workers(default_num_clients, stubs,
                              thread_pool_executor, dispose_batch_size):
    """Spreads `default_num_clients` over the ready members of `stubs`.

    Clients are divided as evenly as integer division allows; stubs that would
    receive no clients get no executor at all.

    Args:
      default_num_clients: Total number of clients to distribute.
      stubs: Objects exposing an `is_ready` attribute; only those for which it
        is truthy are used.
      thread_pool_executor: Forwarded to each `remote_executor.RemoteExecutor`.
      dispose_batch_size: Forwarded to each `remote_executor.RemoteExecutor`.

    Returns:
      A list holding one `_wrap_executor_in_threading_stack(...)` result for
      every stub that was assigned at least one client.

    Raises:
      executors_errors.RetryableError: If none of `stubs` is ready.
    """
    ready_stubs = [candidate for candidate in stubs if candidate.is_ready]
    num_ready = len(ready_stubs)
    logging.info('%s TFF workers available out of a total of %s.', num_ready,
                 len(stubs))
    if num_ready == 0:
        raise executors_errors.RetryableError(
            'No workers are ready; try again to reconnect.')

    clients_left = default_num_clients
    hosting_executors = []
    # `stubs_left` counts down from `num_ready` to 1, so dividing the clients
    # still unassigned by it hands each stub an even share of the remainder.
    for stubs_left, stub in zip(range(num_ready, 0, -1), ready_stubs):
        share = clients_left // stubs_left
        clients_left -= share
        if share <= 0:
            continue
        worker = remote_executor.RemoteExecutor(stub, thread_pool_executor,
                                                dispose_batch_size)
        worker.set_cardinalities({placements.CLIENTS: share})
        hosting_executors.append(worker)

    return [
        _wrap_executor_in_threading_stack(worker, can_resolve_references=False)
        for worker in hosting_executors
    ]
def test_context(rpc_mode='REQUEST_REPLY'):
    """Yields an executor/tracer pair backed by an in-process gRPC service.

    A gRPC server is started on a freshly picked port and serves an
    `ExecutorService` that wraps a `TracingExecutor` around a local executor
    created with `num_clients=3`. A `RemoteExecutor` connected to that port is
    wrapped in a `ReferenceResolvingExecutor` and yielded with the tracer.

    Args:
      rpc_mode: Forwarded to `remote_executor.RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` and `tracer`.
    """
    port = portpicker.pick_unused_port()

    # Server side.
    grpc_server = grpc.server(logging_pool.pool(max_workers=1))
    grpc_server.add_insecure_port('[::]:{}'.format(port))
    local_factory = executor_stacks.local_executor_factory(num_clients=3)
    tracer = executor_test_utils.TracingExecutor(
        local_factory.create_executor({}))
    executor_pb2_grpc.add_ExecutorServicer_to_server(
        executor_service.ExecutorService(tracer), grpc_server)
    grpc_server.start()

    # Client side.
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_executor.RemoteExecutor(channel, rpc_mode))

    try:
        context_type = collections.namedtuple('_', 'executor tracer')
        yield context_type(executor, tracer)
    finally:
        executor.close()
        tracer.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            server_to_stop = grpc_server
            server_to_stop.stop(None)
def test_set_cardinalities_returns_none(self, mock_stub):
    """Checks that `RemoteExecutor.set_cardinalities` returns `None`."""
    executor_id = executor_pb2.ExecutorId(id='test_id')
    mock_stub.get_executor.return_value = executor_pb2.GetExecutorResponse(
        executor=executor_id)
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    outcome = remote_ex.set_cardinalities({placements.CLIENTS: 3})

    self.assertIsNone(outcome)
def test_create_call_reraises_type_error(self, mock_stub):
    """Checks that a `TypeError` from the stub's `create_call` propagates."""
    mock_stub.create_call = mock.Mock(side_effect=TypeError)
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    # A no-argument function returning int32, held by the executor under test.
    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_value = remote_executor.RemoteValue(executor_pb2.ValueRef(), fn_type,
                                           remote_ex)

    with self.assertRaises(TypeError):
        pending = remote_ex.create_call(fn_value)
        asyncio.run(pending)
def test_create_value_reraises_grpc_error(self, mock_stub):
    """Checks that a gRPC error from the stub's `create_value` propagates.

    The stub's side effect is `_raise_non_retryable_grpc_error`; the test
    expects a `grpc.RpcError` whose status code is `ABORTED`.
    """
    failing_call = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
    mock_stub.create_value = failing_call
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    with self.assertRaises(grpc.RpcError) as raised:
        asyncio.run(remote_ex.create_value(1, tf.int32))

    status_code = raised.exception.code()
    self.assertEqual(status_code, grpc.StatusCode.ABORTED)
def test_create_value_returns_remote_value(self, mock_stub):
    """Checks that `create_value` calls the stub once and yields a `RemoteValue`."""
    canned_response = executor_pb2.CreateValueResponse()
    mock_stub.create_value.return_value = canned_response
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    created = asyncio.run(remote_ex.create_value(1, tf.int32))

    mock_stub.create_value.assert_called_once()
    self.assertIsInstance(created, remote_executor.RemoteValue)
def test_create_selection_reraises_type_error(self, mock_stub):
    """Checks that a `TypeError` from the stub's `create_selection` propagates."""
    mock_stub.create_selection = mock.Mock(side_effect=TypeError)
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    # A two-element struct of int32s to select from.
    struct_type = computation_types.StructType([tf.int32, tf.int32])
    struct_value = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                               struct_type, remote_ex)

    with self.assertRaises(TypeError):
        pending = remote_ex.create_selection(struct_value, 0)
        asyncio.run(pending)
def test_create_struct_reraises_type_error(self, mock_stub):
    """Checks that a `TypeError` from the stub's `create_struct` propagates."""
    mock_stub.create_struct = mock.Mock(side_effect=TypeError)
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    int_type = computation_types.TensorType(tf.int32)
    first = remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        remote_ex)
    second = remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                         remote_ex)

    with self.assertRaises(TypeError):
        pending = remote_ex.create_struct([first, second])
        asyncio.run(pending)
def test_create_call_returns_remote_value(self, mock_stub):
    """Checks that `create_call` calls the stub once and yields a `RemoteValue`."""
    canned_response = executor_pb2.CreateCallResponse()
    mock_stub.create_call.return_value = canned_response
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    # A no-argument function returning int32, invoked with a `None` argument.
    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_value = remote_executor.RemoteValue(executor_pb2.ValueRef(), fn_type,
                                           remote_ex)

    called = asyncio.run(remote_ex.create_call(fn_value, None))

    mock_stub.create_call.assert_called_once()
    self.assertIsInstance(called, remote_executor.RemoteValue)
def test_create_call_reraises_grpc_error(self, mock_stub):
    """Checks that a gRPC error from the stub's `create_call` propagates.

    The stub's side effect is `_raise_non_retryable_grpc_error`; the test
    expects a `grpc.RpcError` whose status code is `ABORTED`.
    """
    failing_call = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
    mock_stub.create_call = failing_call
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_value = remote_executor.RemoteValue(executor_pb2.ValueRef(), fn_type,
                                           remote_ex)

    with self.assertRaises(grpc.RpcError) as raised:
        asyncio.run(remote_ex.create_call(fn_value, None))

    status_code = raised.exception.code()
    self.assertEqual(status_code, grpc.StatusCode.ABORTED)
def test_create_selection_returns_remote_value(self, mock_stub):
    """Checks that `create_selection` calls the stub once and yields a `RemoteValue`."""
    canned_response = executor_pb2.CreateSelectionResponse()
    mock_stub.create_selection.return_value = canned_response
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    # A two-element struct of int32s; element 0 is selected below.
    struct_type = computation_types.StructType([tf.int32, tf.int32])
    struct_value = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                               struct_type, remote_ex)

    selected = asyncio.run(remote_ex.create_selection(struct_value, 0))

    mock_stub.create_selection.assert_called_once()
    self.assertIsInstance(selected, remote_executor.RemoteValue)
def test_create_struct_returns_remote_value(self, mock_stub):
    """Checks that `create_struct` calls the stub once and yields a `RemoteValue`."""
    canned_response = executor_pb2.CreateStructResponse()
    mock_stub.create_struct.return_value = canned_response
    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)

    int_type = computation_types.TensorType(tf.int32)
    first = remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        remote_ex)
    second = remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                         remote_ex)

    packed = asyncio.run(remote_ex.create_struct([first, second]))

    mock_stub.create_struct.assert_called_once()
    self.assertIsInstance(packed, remote_executor.RemoteValue)
def test_compute_returns_result(self, mock_stub):
    """Checks that `RemoteValue.compute` returns the value served by the stub.

    The stub's `compute` is primed with a response carrying the tensor proto
    for `1` packed into an `Any`; the test expects `compute()` to equal `1`
    after exactly one stub call.
    """
    packed_tensor = any_pb2.Any()
    packed_tensor.Pack(tf.make_tensor_proto(1))
    served_value = executor_pb2.Value(tensor=packed_tensor)
    mock_stub.compute.return_value = executor_pb2.ComputeResponse(
        value=served_value)

    remote_ex = remote_executor.RemoteExecutor(mock_stub)
    _set_cardinalities_with_mock(remote_ex, mock_stub)
    remote_ex.set_cardinalities({placements.CLIENTS: 3})

    fn_type = computation_types.FunctionType(None, tf.int32)
    remote_value = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                               fn_type, remote_ex)

    computed = asyncio.run(remote_value.compute())

    mock_stub.compute.assert_called_once()
    self.assertEqual(computed, 1)
def test_context(rpc_mode='REQUEST_REPLY'):
    """Yields an executor plus the tracers created by an in-process service.

    A gRPC server is started on a freshly picked port. Its `ExecutorService`
    builds executors through a `ResourceManagingExecutorFactory` whose stack
    function wraps each local executor in a `TracingExecutor` and records it in
    a shared list. Before the remote executor is built, cardinalities of three
    clients are sent to the service through a raw `ExecutorStub`.

    Args:
      rpc_mode: Forwarded to `remote_executor.RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` and `tracers`; `tracers` is the live
      list that the service-side factory appends to.
    """
    port = portpicker.pick_unused_port()
    grpc_server = grpc.server(logging_pool.pool(max_workers=1))
    grpc_server.add_insecure_port('[::]:{}'.format(port))

    local_factory = executor_stacks.local_executor_factory(num_clients=3)
    tracers = []

    def _make_traced_executor(cardinalities):
        # Record every tracer so the caller can inspect and close it later.
        traced = executor_test_utils.TracingExecutor(
            local_factory.create_executor(cardinalities))
        tracers.append(traced)
        return traced

    traced_factory = executor_stacks.ResourceManagingExecutorFactory(
        _make_traced_executor)
    executor_pb2_grpc.add_ExecutorServicer_to_server(
        executor_service.ExecutorService(traced_factory), grpc_server)
    grpc_server.start()

    channel = grpc.insecure_channel('localhost:{}'.format(port))
    raw_stub = executor_pb2_grpc.ExecutorStub(channel)
    wire_cardinalities = executor_service_utils.serialize_cardinalities(
        {placement_literals.CLIENTS: 3})
    set_request = executor_pb2.SetCardinalitiesRequest(
        cardinalities=wire_cardinalities)
    raw_stub.SetCardinalities(set_request)

    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_executor.RemoteExecutor(channel, rpc_mode))

    try:
        context_type = collections.namedtuple('_', 'executor tracers')
        yield context_type(executor, tracers)
    finally:
        executor.close()
        for traced in tracers:
            traced.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            grpc_server.stop(None)
def create_remote_execution_context(channels,
                                    rpc_mode='REQUEST_REPLY',
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100):
    """Builds an execution context that runs computations on remote workers.

    One `RemoteExecutor` is created per entry of `channels`; together they back
    a worker-pool executor factory used by the returned context.

    Args:
      channels: Iterable of channels, one per remote worker.
      rpc_mode: Forwarded to each `RemoteExecutor`.
      thread_pool_executor: Forwarded to each `RemoteExecutor`.
      dispose_batch_size: Forwarded to each `RemoteExecutor`.
      max_fanout: Forwarded to `executor_stacks.worker_pool_executor_factory`.

    Returns:
      An `execution_context.ExecutionContext` using the worker-pool factory and
      `compiler.transform_to_native_form` as its compiler function.
    """
    # TODO(b/166634524): Reparameterize worker_pool_executor_factory to
    # construct remote executors, rename to remote_executor_factory or something
    # similar.
    worker_executors = []
    for channel in channels:
        worker_executors.append(
            remote_executor.RemoteExecutor(channel, rpc_mode,
                                           thread_pool_executor,
                                           dispose_batch_size))

    pool_factory = executor_stacks.worker_pool_executor_factory(
        executors=worker_executors,
        max_fanout=max_fanout,
    )
    return execution_context.ExecutionContext(
        executor_fn=pool_factory,
        compiler_fn=compiler.transform_to_native_form)
def test_context():
    """Yields an executor plus the tracers created by an in-process service.

    A gRPC server is started on a freshly picked port. Its `ExecutorService`
    builds executors through a `BasicTestExFactory` whose function wraps each
    executor from a `LocalTestExecutorFactory` in a `TracingExecutor` and
    records it in a shared list. The client is a `RemoteExecutor` over a
    `RemoteExecutorGrpcStub`, configured for three clients and wrapped in a
    `ReferenceResolvingExecutor`.

    Yields:
      A namedtuple with fields `executor` and `tracers`; `tracers` is the live
      list that the service-side factory appends to.
    """
    port = portpicker.pick_unused_port()
    grpc_server = grpc.server(logging_pool.pool(max_workers=1))
    grpc_server.add_insecure_port('[::]:{}'.format(port))

    local_factory = executor_test_utils.LocalTestExecutorFactory(
        default_num_clients=3)
    tracers = []

    def _make_traced_executor(cardinalities):
        # Record every tracer so the caller can inspect and close it later.
        traced = executor_test_utils.TracingExecutor(
            local_factory.create_executor(cardinalities))
        tracers.append(traced)
        return traced

    traced_factory = executor_test_utils.BasicTestExFactory(
        _make_traced_executor)
    executor_pb2_grpc.add_ExecutorGroupServicer_to_server(
        executor_service.ExecutorService(traced_factory), grpc_server)
    grpc_server.start()

    channel = grpc.insecure_channel('localhost:{}'.format(port))
    grpc_stub = remote_executor_grpc_stub.RemoteExecutorGrpcStub(channel)
    remote_ex = remote_executor.RemoteExecutor(grpc_stub)
    remote_ex.set_cardinalities({placements.CLIENTS: 3})
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_ex)

    try:
        context_type = collections.namedtuple('_', 'executor tracers')
        yield context_type(executor, tracers)
    finally:
        executor.close()
        for traced in tracers:
            traced.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            grpc_server.stop(None)
def remote_executor_factory(channels,
                            rpc_mode='REQUEST_REPLY',
                            thread_pool_executor=None,
                            dispose_batch_size=20,
                            max_fanout=100) -> executor_factory.ExecutorFactory:
    """Create an executor backed by remote workers.

    Args:
      channels: A list of `grpc.Channels` hosting services which can execute
        TFF work.
      rpc_mode: A string specifying the connection mode between the local host
        and `channels`.
      thread_pool_executor: Optional concurrent.futures.Executor used to wait
        for the reply to a streaming RPC message. Uses the default Executor if
        not specified.
      dispose_batch_size: The batch size for requests to dispose of remote
        worker values. Lower values will result in more requests to the remote
        worker, but will result in values being cleaned up sooner and therefore
        may result in lower memory usage on the remote worker.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.

    Raises:
      ValueError: If `channels` is an empty list.
    """
    py_typecheck.check_type(channels, list)
    if not channels:
        raise ValueError('The list of channels cannot be empty.')

    remote_executors = []
    for channel in channels:
        remote_executors.append(
            remote_executor.RemoteExecutor(channel, rpc_mode,
                                           thread_pool_executor,
                                           dispose_batch_size))

    def _apply_cardinalities(worker, cardinalities, loop):
        """Runs `worker.set_cardinalities(cardinalities)` to completion."""
        loop.run_until_complete(worker.set_cardinalities(cardinalities))

    def _build_flat_stack(cardinalities):
        """Configures every remote executor and returns them wrapped."""
        loop = asyncio.new_event_loop()
        try:
            if not cardinalities.get(placement_literals.CLIENTS):
                # No client count requested: pass the cardinalities through
                # unchanged to every worker.
                for worker in remote_executors:
                    _apply_cardinalities(worker, cardinalities, loop)
                return [
                    _wrap_executor_in_threading_stack(worker)
                    for worker in remote_executors
                ]

            clients_left = cardinalities[placement_literals.CLIENTS]
            even_share = clients_left // len(remote_executors)
            *leading_workers, final_worker = remote_executors
            for worker in leading_workers:
                _apply_cardinalities(
                    worker, {placement_literals.CLIENTS: even_share}, loop)
                clients_left -= even_share
            # The last worker absorbs whatever the even split left over.
            _apply_cardinalities(
                final_worker, {placement_literals.CLIENTS: clients_left}, loop)
        finally:
            loop.close()
        return [
            _wrap_executor_in_threading_stack(worker)
            for worker in remote_executors
        ]

    composing_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=UnplacedExecutorFactory(use_caching=False),
        flat_stack_fn=_build_flat_stack,
    )
    return ResourceManagingExecutorFactory(
        executor_stack_fn=composing_factory.create_executor)
def create_remote_executor():
    """Returns a `RemoteExecutor` over an insecure channel to a free local port.

    The executor uses the `'REQUEST_REPLY'` RPC mode. Nothing in this function
    starts a server on the chosen port.
    """
    target = 'localhost:{}'.format(portpicker.pick_unused_port())
    insecure_channel = grpc.insecure_channel(target)
    return remote_executor.RemoteExecutor(insecure_channel, 'REQUEST_REPLY')
def remote_executor_factory(
    channels: List[grpc.Channel],
    rpc_mode: str = 'REQUEST_REPLY',
    thread_pool_executor: Optional[futures.Executor] = None,
    dispose_batch_size: int = 20,
    max_fanout: int = 100) -> executor_factory.ExecutorFactory:
    """Create an executor backed by remote workers.

    Args:
      channels: A list of `grpc.Channels` hosting services which can execute
        TFF work.
      rpc_mode: A string specifying the connection mode between the local host
        and `channels`.
      thread_pool_executor: Optional concurrent.futures.Executor used to wait
        for the reply to a streaming RPC message. Uses the default Executor if
        not specified.
      dispose_batch_size: The batch size for requests to dispose of remote
        worker values. Lower values will result in more requests to the remote
        worker, but will result in values being cleaned up sooner and therefore
        may result in lower memory usage on the remote worker.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.

    Raises:
      ValueError: If `channels` is an empty list.
    """
    py_typecheck.check_type(channels, list)
    if not channels:
        raise ValueError('The list of channels cannot be empty.')
    py_typecheck.check_type(rpc_mode, str)
    if thread_pool_executor is not None:
        py_typecheck.check_type(thread_pool_executor, futures.Executor)
    py_typecheck.check_type(dispose_batch_size, int)
    py_typecheck.check_type(max_fanout, int)

    remote_executors = [
        remote_executor.RemoteExecutor(
            channel=channel,
            rpc_mode=rpc_mode,
            thread_pool_executor=thread_pool_executor,
            dispose_batch_size=dispose_batch_size) for channel in channels
    ]

    def _acquire_event_loop():
        """Returns `(loop, created_here)`, making a loop if none is usable."""
        created_here = False
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                loop = asyncio.new_event_loop()
                created_here = True
        except RuntimeError:
            loop = asyncio.new_event_loop()
            created_here = True
        return loop, created_here

    def _apply_cardinalities(worker, cardinalities, loop):
        """Schedules or runs `worker.set_cardinalities(cardinalities)`."""
        if not loop.is_running():
            loop.run_until_complete(worker.set_cardinalities(cardinalities))
            return
        # NOTE(review): the future returned here is discarded, so this branch
        # does not wait for the call to finish -- confirm that is intended.
        asyncio.run_coroutine_threadsafe(
            worker.set_cardinalities(cardinalities), loop)

    def _build_flat_stack(cardinalities):
        """Configures the remote executors and returns the used ones wrapped."""
        loop, owns_loop = _acquire_event_loop()
        try:
            if not cardinalities.get(placement_literals.CLIENTS):
                # No client count requested: pass the cardinalities through
                # unchanged to every worker.
                for worker in remote_executors:
                    _apply_cardinalities(worker, cardinalities, loop)
                return [
                    _wrap_executor_in_threading_stack(worker)
                    for worker in remote_executors
                ]

            clients_left = cardinalities[placement_literals.CLIENTS]
            num_workers = len(remote_executors)
            live_workers = []
            # `workers_left` counts down to 1 so each worker receives an even
            # share of the clients that are still unassigned.
            for workers_left, worker in zip(
                range(num_workers, 0, -1), remote_executors):
                share = clients_left // workers_left
                clients_left -= share
                if share <= 0:
                    continue
                _apply_cardinalities(
                    worker, {placement_literals.CLIENTS: share}, loop)
                live_workers.append(worker)
        finally:
            # Only tear down a loop that this factory created itself.
            if owns_loop:
                loop.stop()
                loop.close()
        return [
            _wrap_executor_in_threading_stack(worker)
            for worker in live_workers
        ]

    composing_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=UnplacedExecutorFactory(use_caching=False),
        flat_stack_fn=_build_flat_stack,
    )
    return ResourceManagingExecutorFactory(
        executor_stack_fn=composing_factory.create_executor,
        ensure_closed=remote_executors)
def remote_executor_factory(
    channels: List[grpc.Channel],
    thread_pool_executor: Optional[futures.Executor] = None,
    dispose_batch_size: int = 20,
    max_fanout: int = 100,
    default_num_clients: int = 0,
) -> executor_factory.ExecutorFactory:
    """Create an executor backed by remote workers.

    Args:
      channels: A list of `grpc.Channels` hosting services which can execute
        TFF work.
      thread_pool_executor: Optional concurrent.futures.Executor used to wait
        for the reply to a streaming RPC message. Uses the default Executor if
        not specified.
      dispose_batch_size: The batch size for requests to dispose of remote
        worker values. Lower values will result in more requests to the remote
        worker, but will result in values being cleaned up sooner and therefore
        may result in lower memory usage on the remote worker.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.
      default_num_clients: The number of clients to use for simulations where
        the number of clients cannot be inferred. Usually the number of clients
        will be inferred from the number of values passed to computations which
        accept client-placed values. However, when this inference isn't
        possible (such as in the case of a no-argument or non-federated
        computation) this default will be used instead.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.

    Raises:
      ValueError: If `channels` is an empty list.
    """
    py_typecheck.check_type(channels, list)
    if not channels:
        raise ValueError('The list of channels cannot be empty.')
    if thread_pool_executor is not None:
        py_typecheck.check_type(thread_pool_executor, futures.Executor)
    py_typecheck.check_type(dispose_batch_size, int)
    py_typecheck.check_type(max_fanout, int)
    py_typecheck.check_type(default_num_clients, int)

    remote_executors = [
        remote_executor.RemoteExecutor(
            channel=channel,
            thread_pool_executor=thread_pool_executor,
            dispose_batch_size=dispose_batch_size) for channel in channels
    ]

    def _build_flat_stack(cardinalities):
        # Fall back to `default_num_clients` when no client count is given.
        client_count = cardinalities.get(placements.CLIENTS,
                                         default_num_clients)
        return _configure_remote_workers(client_count, remote_executors)

    composing_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=UnplacedExecutorFactory(use_caching=False),
        flat_stack_fn=_build_flat_stack,
    )
    readiness_query = _CardinalitiesOrReadyListChanged(
        maybe_ready_list=remote_executors)
    return ReconstructOnChangeExecutorFactory(
        underlying_stack=composing_factory,
        ensure_closed=remote_executors,
        change_query=readiness_query)