def test_executor_service_create_one_arg_computation_value_and_call(self):
  """Uploads a one-argument computation and an int, then calls one on the other.

  The computation adds 1 to its input, so calling it on 10 through the
  service stub must yield 11.
  """
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  @computations.tf_computation(tf.int32)
  def comp(x):
    return tf.add(x, 1)

  # Register the computation with the service.
  comp_proto, _ = executor_serialization.serialize_value(comp)
  comp_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=comp_proto))
  self.assertIsInstance(comp_response, executor_pb2.CreateValueResponse)
  fn_ref = comp_response.value_ref

  # Register the argument with the service.
  arg_proto, _ = executor_serialization.serialize_value(10, tf.int32)
  arg_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=arg_proto))
  self.assertIsInstance(arg_response, executor_pb2.CreateValueResponse)
  operand_ref = arg_response.value_ref

  # Invoke the computation on the argument and fetch the result.
  call_response = env.stub.CreateCall(
      executor_pb2.CreateCallRequest(
          function_ref=fn_ref, argument_ref=operand_ref))
  self.assertIsInstance(call_response, executor_pb2.CreateCallResponse)
  result = env.get_value(str(call_response.value_ref.id))
  self.assertEqual(result, 11)
  del env
def test_executor_service_execute_failure_in_processing(self):
  """Checks that a server-side failure during Execute surfaces as RpcError.

  The executor's `create_value` always raises, and the test expects three
  `ExecuteResponse` messages to be received before `grpc.RpcError` is raised.
  """

  class _RaisingExecutor(eager_tf_executor.EagerTFExecutor):

    async def create_value(self, *args, **kwargs):
      # Unknown exception on server
      raise Exception

  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: _RaisingExecutor()))

  request_source = CallNoArgFnIterator()
  responses = env.stub.Execute(request_source.iterator())
  num_responses = 0
  with self.assertRaises(grpc.RpcError):  # pylint: disable=g-error-prone-assert-raises
    # We disable the linter here because we should raise on the final
    # iteration. The num_responses assertion below ensures that we have
    # as many return values as expected.
    for reply in responses:
      request_source.queue.put(reply)
      self.assertIsInstance(reply, executor_pb2.ExecuteResponse)
      num_responses += 1
  self.assertEqual(num_responses, 3)
def test_executors_persisted_is_capped(self):
  """Requests executors for 100 distinct cardinalities and checks the cache.

  After the requests, the factory's `_executors` collection must hold fewer
  than 20 entries.
  """
  shared_executor = eager_tf_executor.EagerTFExecutor()
  factory = executor_stacks.ResourceManagingExecutorFactory(
      lambda _: shared_executor)
  for client_count in range(100):
    cardinalities = {placements.CLIENTS: client_count}
    factory.create_executor(cardinalities)
  num_persisted = len(factory._executors)
  self.assertLess(num_persisted, 20)
def test_executor_service_execute_create_value(self):
  """Streams a single create_value request through the Execute RPC.

  Exactly one `ExecuteResponse` is expected, and its `create_value` field
  must be a `CreateValueResponse`.
  """
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  def _iterator():
    # Yields the one request sent over the stream.

    @computations.tf_computation(tf.int32)
    def comp(x):
      return tf.add(x, 1)

    serialized_comp, _ = executor_serialization.serialize_value(comp)
    yield executor_pb2.ExecuteRequest(
        create_value=executor_pb2.CreateValueRequest(value=serialized_comp))

  num_responses = 0
  for reply in env.stub.Execute(_iterator()):
    self.assertIsInstance(reply, executor_pb2.ExecuteResponse)
    self.assertIsInstance(reply.create_value,
                          executor_pb2.CreateValueResponse)
    num_responses += 1
  self.assertEqual(num_responses, 1)
def test_failure_on_construction_fails_as_expected(self, mock_service):
  """Expects `server_context` to raise ValueError.

  Args:
    mock_service: Unused in the body; presumably injected by a patch decorator
      outside this view that makes construction fail -- TODO confirm.
  """
  eager_executor = eager_tf_executor.EagerTFExecutor()
  factory = executor_stacks.ResourceManagingExecutorFactory(
      lambda _: eager_executor)

  with self.assertRaises(ValueError):
    port = portpicker.pick_unused_port()
    with server_utils.server_context(factory, 1, port):
      time.sleep(1)
def test_ensure_closed_closes_executor_passed_at_initialization(
    self, mock_ex):
  """Checks that executors listed in `ensure_closed` are closed on cleanup.

  Args:
    mock_ex: Mock executor handed to the factory through `ensure_closed`; its
      `close` method must be called exactly once by `clean_up_executors`.
  """

  def _make_executor(unused_cardinalities):
    return ExecutorMock()

  factory = executor_stacks.ResourceManagingExecutorFactory(
      _make_executor, ensure_closed=[mock_ex])
  factory.clean_up_executors()
  mock_ex.close.assert_called_once()
def test_executor_service_raises_after_cleanup_without_configuration(self):
  """Expects CreateValue to fail with RpcError right after ClearExecutor."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))
  env.stub.ClearExecutor(executor_pb2.ClearExecutorRequest())

  ten = tf.constant(10.0).numpy()
  serialized_ten, _ = executor_serialization.serialize_value(ten, tf.float32)
  with self.assertRaises(grpc.RpcError):
    request = executor_pb2.CreateValueRequest(value=serialized_ten)
    env.stub.CreateValue(request)
def test_executor_service_slowly_create_tensor_value(self):
  """Checks that CreateValue returns before a slow executor has finished.

  The executor's `create_value` blocks on the `done` event, which the test
  only sets after the stub's CreateValue call has already returned and the
  executor has reported itself 'busy'.
  """

  class SlowExecutorValue(executor_value_base.ExecutorValue):
    """Holds a payload and its type; `compute` returns the payload."""

    def __init__(self, v, t):
      self._v = v
      self._t = t

    @property
    def type_signature(self):
      return self._t

    async def compute(self):
      return self._v

  class SlowExecutor(executor_base.Executor):
    """Executor whose `create_value` waits until the test sets `done`."""

    def __init__(self):
      self.status = 'idle'
      self.busy = threading.Event()
      self.done = threading.Event()

    async def create_value(self, value, type_spec=None):
      self.status = 'busy'
      # Tell the test that work has started, then wait for its go-ahead.
      self.busy.set()
      self.done.wait()
      self.status = 'done'
      return SlowExecutorValue(value, type_spec)

    async def create_call(self, comp, arg=None):
      raise NotImplementedError

    async def create_struct(self, elements):
      raise NotImplementedError

    async def create_selection(self, source, index=None, name=None):
      raise NotImplementedError

    def close(self):
      pass

  slow_executor = SlowExecutor()
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: slow_executor))
  self.assertEqual(slow_executor.status, 'idle')

  serialized_ten, _ = executor_serialization.serialize_value(10, tf.int32)
  create_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized_ten))

  # The RPC has returned; the executor must still be parked in create_value.
  slow_executor.busy.wait()
  self.assertEqual(slow_executor.status, 'busy')

  # Release the executor and fetch the finished value.
  slow_executor.done.set()
  result = env.get_value(create_response.value_ref.id)
  self.assertEqual(slow_executor.status, 'done')
  self.assertEqual(result, 10)
def test_as_default_context(self):
  """Installs an IREE-backed execution context as default and runs x + 1.0."""
  iree_executor = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
  context = execution_context.ExecutionContext(
      executor_stacks.ResourceManagingExecutorFactory(
          executor_stack_fn=lambda _: iree_executor))
  set_default_context.set_default_context(context)

  @computations.tf_computation(tf.float32)
  def comp(x):
    return x + 1.0

  result = comp(10.0)
  self.assertEqual(result, 11.0)
def test_executor_service_create_tensor_value(self):
  """Creates a float tensor value via the stub and reads it back as 10.0."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  ten = tf.constant(10.0).numpy()
  serialized_ten, _ = executor_serialization.serialize_value(ten, tf.float32)
  create_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized_ten))
  self.assertIsInstance(create_response, executor_pb2.CreateValueResponse)

  result = env.get_value(str(create_response.value_ref.id))
  self.assertEqual(result, 10.0)
  del env
def test_executor_service_execute_failure_in_connection(self):
  """Expects RpcError when the channel is closed while Execute is streaming."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  request_source = CallNoArgFnIterator()
  responses = env.stub.Execute(request_source.iterator())

  with self.assertRaises(grpc.RpcError):
    for reply in responses:
      request_source.queue.put(reply)
      # NOTE(review): the channel is closed inside the loop so that the
      # stream fails while it is being consumed -- confirm against the
      # original indentation.
      env.close_channel()
def test_server_context_shuts_down_uncaught_exception(
    self, mock_logging_info):
  """Checks that `server_context` logs shutdown when its body raises.

  A TypeError raised inside the context must propagate to the caller, and the
  only logged message must be 'Shutting down server.'.

  Args:
    mock_logging_info: Mock whose calls are inspected; presumably replaces the
      info-level logging function via a patch decorator outside this view.
  """
  eager_executor = eager_tf_executor.EagerTFExecutor()
  factory = executor_stacks.ResourceManagingExecutorFactory(
      lambda _: eager_executor)

  with self.assertRaises(TypeError):
    port = portpicker.pick_unused_port()
    with server_utils.server_context(factory, 1, port) as server:
      time.sleep(1)
      raise TypeError

  mock_logging_info.assert_called_once_with('Shutting down server.')
def test_server_context_shuts_down_under_keyboard_interrupt(
    self, mock_logging_info):
  """Checks the messages logged when the context body is interrupted.

  A KeyboardInterrupt raised inside `server_context` is expected not to escape
  the `with` block; afterwards the interrupt message and the shutdown message
  must have been logged in that order.

  Args:
    mock_logging_info: Mock whose calls are inspected; presumably replaces the
      info-level logging function via a patch decorator outside this view.
  """
  eager_executor = eager_tf_executor.EagerTFExecutor()
  factory = executor_stacks.ResourceManagingExecutorFactory(
      lambda _: eager_executor)

  port = portpicker.pick_unused_port()
  with server_utils.server_context(factory, 1, port) as server:
    time.sleep(1)
    raise KeyboardInterrupt

  expected_log_calls = [
      mock.call('Server stopped by KeyboardInterrupt.'),
      mock.call('Shutting down server.'),
  ]
  mock_logging_info.assert_has_calls(expected_log_calls)
def test_executor_factory(
    num_clients: Optional[int] = None,
    clients_per_thread: int = 1,
    *,
    default_num_clients: int = 0) -> executor_factory.ExecutorFactory:
  """Builds a test execution stack for running local computations.

  The stack resembles `tff.framework.thread_debugging_executor_factory`, but
  federated intrinsics are delegated to
  `federated_strategy.TestFederatedStrategy`. It is useful for testing
  federated algorithms that need their own implementations of the intrinsics
  provided by TFF.

  Args:
    num_clients: (Deprecated) The number of clients. Intended to fix the
      number of clients of the returned factory exactly; if `None`, the
      cardinalities of all placements are to be inferred from the values
      passed. See the NOTE(review) in the body regarding how this value is
      currently used.
    clients_per_thread: Integer number of clients for each of TFF's threads to
      run in sequence. Increasing `clients_per_thread` therefore reduces the
      concurrency of the TFF runtime, which can be useful if client work is
      very lightweight or models are very large and multiple copies cannot fit
      in memory.
    default_num_clients: The number of clients to run by default if
      cardinality cannot be inferred from arguments.

  Returns:
    An `executor_factory.ExecutorFactory`.
  """
  unplaced_ex_factory = executor_stacks.UnplacedExecutorFactory(
      use_caching=False, can_resolve_references=True)

  # NOTE(review): the normalized result is bound to `num_clients` and never
  # read again, while the un-normalized `default_num_clients` is what gets
  # forwarded below. Confirm whether the normalized value was meant to be
  # passed as `default_num_clients`.
  num_clients = executor_stacks.normalize_num_clients_and_default_num_clients(
      num_clients, default_num_clients)

  strategy_factory = federated_strategy.TestFederatedStrategy.factory
  federating_ex_factory = executor_stacks.FederatingExecutorFactory(
      clients_per_thread=clients_per_thread,
      unplaced_ex_factory=unplaced_ex_factory,
      default_num_clients=default_num_clients,
      use_sizing=False,
      federated_strategy_factory=strategy_factory)

  return executor_stacks.ResourceManagingExecutorFactory(
      federating_ex_factory.create_executor)
def test_executor_service_create_and_select_from_tuple(self):
  """Builds the struct <a=10,b=20> remotely and selects from it.

  Both elements are selected by name and by index, and every selection must
  evaluate to the matching element value.
  """
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  # First element: 10.
  ten_proto, _ = executor_serialization.serialize_value(10, tf.int32)
  ten_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=ten_proto))
  self.assertIsInstance(ten_response, executor_pb2.CreateValueResponse)
  ten_ref = ten_response.value_ref
  self.assertEqual(env.get_value(ten_ref.id), 10)

  # Second element: 20.
  twenty_proto, _ = executor_serialization.serialize_value(20, tf.int32)
  twenty_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=twenty_proto))
  self.assertIsInstance(twenty_response, executor_pb2.CreateValueResponse)
  twenty_ref = twenty_response.value_ref
  self.assertEqual(env.get_value(twenty_ref.id), 20)

  # Combine both references into a named struct.
  struct_elements = [
      executor_pb2.CreateStructRequest.Element(name='a', value_ref=ten_ref),
      executor_pb2.CreateStructRequest.Element(name='b', value_ref=twenty_ref),
  ]
  struct_response = env.stub.CreateStruct(
      executor_pb2.CreateStructRequest(element=struct_elements))
  self.assertIsInstance(struct_response, executor_pb2.CreateStructResponse)
  tuple_ref = struct_response.value_ref
  self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

  # Each case is (selector field, selector value, expected result).
  selection_cases = [
      ('name', 'a', 10),
      ('name', 'b', 20),
      ('index', 0, 10),
      ('index', 1, 20),
  ]
  for selector_field, selector, expected in selection_cases:
    selection_request = executor_pb2.CreateSelectionRequest(
        source_ref=tuple_ref, **{selector_field: selector})
    selection_response = env.stub.CreateSelection(selection_request)
    self.assertIsInstance(selection_response,
                          executor_pb2.CreateSelectionResponse)
    selected_ref = selection_response.value_ref
    self.assertEqual(env.get_value(selected_ref.id), expected)
  del env
def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a remote executor wired to an in-process gRPC executor service.

  Starts a gRPC server on an unused port hosting an `ExecutorService` whose
  executors are wrapped in `TracingExecutor`s, sets the CLIENTS cardinality to
  3 on the service, and yields a namedtuple with:
    * `executor`: a `ReferenceResolvingExecutor` around a `RemoteExecutor`
      connected to that server.
    * `tracers`: the list of `TracingExecutor`s created so far by the service
      (the list is appended to as the service builds executors).

  Teardown closes the executor, the tracers, the channel and finally stops
  the server. Each step runs even if an earlier step (or the setup after
  `server.start()`) raised, so a failure cannot leak the running server.

  Args:
    rpc_mode: RPC mode forwarded to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` and `tracers`.
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_factory = executor_stacks.local_executor_factory(num_clients=3)
  tracers = []

  def _tracer_fn(cardinalities):
    # Wraps every executor handed to the service so tests can inspect it.
    tracer = executor_test_utils.TracingExecutor(
        target_factory.create_executor(cardinalities))
    tracers.append(tracer)
    return tracer

  service = executor_service.ExecutorService(
      executor_stacks.ResourceManagingExecutorFactory(_tracer_fn))
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      stub = executor_pb2_grpc.ExecutorStub(channel)
      serialized_cards = executor_service_utils.serialize_cardinalities(
          {placement_literals.CLIENTS: 3})
      stub.SetCardinalities(
          executor_pb2.SetCardinalitiesRequest(
              cardinalities=serialized_cards))
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = reference_resolving_executor.ReferenceResolvingExecutor(
          remote_exec)
      try:
        yield collections.namedtuple('_', 'executor tracers')(executor,
                                                              tracers)
      finally:
        try:
          executor.close()
        finally:
          # Tracers are closed even when closing the executor failed.
          for tracer in tracers:
            tracer.close()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    # Always stop the server once it has been started.
    server.stop(None)
def test_executor_service_create_no_arg_computation_value_and_call(self):
  """Uploads a no-argument computation, calls it, and expects the value 10."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  @computations.tf_computation
  def comp():
    return tf.constant(10)

  comp_proto, _ = executor_service_utils.serialize_value(comp)
  create_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=comp_proto))
  self.assertIsInstance(create_response, executor_pb2.CreateValueResponse)

  # Call the computation without an argument reference.
  call_response = env.stub.CreateCall(
      executor_pb2.CreateCallRequest(function_ref=create_response.value_ref))
  self.assertIsInstance(call_response, executor_pb2.CreateCallResponse)

  result = env.get_value(str(call_response.value_ref.id))
  self.assertEqual(result, 10)
  del env
def test_executor_service_value_unavailable_after_dispose(self):
  """Checks that a disposed value can no longer be looked up.

  The value is created, looked up once successfully, disposed through the
  Dispose RPC, and a second lookup must raise KeyError.
  """
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))
  ten = tf.constant(10.0).numpy()
  serialized_ten, _ = executor_serialization.serialize_value(ten, tf.float32)

  # Create the value
  create_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized_ten))
  self.assertIsInstance(create_response, executor_pb2.CreateValueResponse)
  value_id = str(create_response.value_ref.id)

  # Check that the value appears in the _values map
  env.get_value_future_directly(value_id)

  # Dispose of the value
  dispose_request = executor_pb2.DisposeRequest()
  dispose_request.value_ref.append(create_response.value_ref)
  dispose_response = env.stub.Dispose(dispose_request)
  self.assertIsInstance(dispose_response, executor_pb2.DisposeResponse)

  # Check that the value is gone from the _values map
  # get_value_future_directly is used here so that we can catch the
  # exception rather than having it occur on the GRPC thread.
  with self.assertRaises(KeyError):
    env.get_value_future_directly(value_id)
def test_clear_executor_calls_cleanup(self, mock_cleanup):
  """Checks that the ClearExecutor RPC triggers cleanup exactly once.

  Args:
    mock_cleanup: Mock expected to be called once; presumably patches the
      factory's cleanup method via a decorator outside this view.
  """
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))
  clear_request = executor_pb2.ClearExecutorRequest()
  env.stub.ClearExecutor(clear_request)
  mock_cleanup.assert_called_once()
def _create_test_executor_factory():
  """Returns a factory that hands out one shared `EagerTFExecutor`.

  The executor is created once here; the factory's stack function ignores its
  argument and returns that same instance every time.
  """
  shared_executor = eager_tf_executor.EagerTFExecutor()

  def _stack_fn(_):
    return shared_executor

  return executor_stacks.ResourceManagingExecutorFactory(_stack_fn)
def create_test_executor_factory():
  """Returns a factory serving one shared three-layer executor stack.

  The stack is, from the inside out: `EagerTFExecutor`, `CachingExecutor`,
  `ReferenceResolvingExecutor`. It is built once, and the factory's stack
  function returns that same instance regardless of its argument.
  """
  stack = reference_resolving_executor.ReferenceResolvingExecutor(
      caching_executor.CachingExecutor(eager_tf_executor.EagerTFExecutor()))

  def _stack_fn(_):
    return stack

  return executor_stacks.ResourceManagingExecutorFactory(_stack_fn)