  def test_executor_service_create_one_arg_computation_value_and_call(self):
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)

    @computations.tf_computation(tf.int32)
    def comp(x):
      return tf.add(x, 1)

    value_proto, _ = executor_serialization.serialize_value(comp)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    comp_ref = response.value_ref

    value_proto, _ = executor_serialization.serialize_value(10, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    arg_ref = response.value_ref

    response = env.stub.CreateCall(
        executor_pb2.CreateCallRequest(
            function_ref=comp_ref, argument_ref=arg_ref))
    self.assertIsInstance(response, executor_pb2.CreateCallResponse)
    value_id = str(response.value_ref.id)
    value = env.get_value(value_id)
    self.assertEqual(value, 11)
    del env
  def test_executor_service_execute_failure_in_processing(self):

    class _RaisingExecutor(eager_tf_executor.EagerTFExecutor):

      async def create_value(self, *args, **kwargs):
        # Unknown exception on server
        raise Exception

    ex_factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: _RaisingExecutor())
    env = TestEnv(ex_factory)

    iter_obj = CallNoArgFnIterator()
    response_iterator = env.stub.Execute(iter_obj.iterator())

    return_value_count = 0
    with self.assertRaises(grpc.RpcError):  # pylint: disable=g-error-prone-assert-raises
      # The linter is disabled above because the RpcError is expected to be
      # raised only on the final iteration; the return_value_count assertion
      # below ensures we received as many responses as expected before then.
      for response in response_iterator:
        iter_obj.queue.put(response)
        self.assertIsInstance(response, executor_pb2.ExecuteResponse)
        return_value_count += 1

    self.assertEqual(return_value_count, 3)
    def test_executors_persisted_is_capped(self):
        ex = eager_tf_executor.EagerTFExecutor()

        factory = executor_stacks.ResourceManagingExecutorFactory(lambda _: ex)
        for num_clients in range(100):
            factory.create_executor({placements.CLIENTS: num_clients})
        self.assertLess(len(factory._executors), 20)
    def test_executor_service_execute_create_value(self):
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)

        def _iterator():
            @computations.tf_computation(tf.int32)
            def comp(x):
                return tf.add(x, 1)

            value_proto, _ = executor_serialization.serialize_value(comp)
            request = executor_pb2.ExecuteRequest(
                create_value=executor_pb2.CreateValueRequest(
                    value=value_proto))
            yield request

        response_iterator = env.stub.Execute(_iterator())
        return_value_count = 0
        for response in response_iterator:
            self.assertIsInstance(response, executor_pb2.ExecuteResponse)
            self.assertIsInstance(response.create_value,
                                  executor_pb2.CreateValueResponse)
            return_value_count += 1

        self.assertEqual(return_value_count, 1)
    def test_failure_on_construction_fails_as_expected(self, mock_service):

        ex = eager_tf_executor.EagerTFExecutor()
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: ex)

        with self.assertRaises(ValueError):
            with server_utils.server_context(ex_factory, 1,
                                             portpicker.pick_unused_port()):
                time.sleep(1)
    def test_ensure_closed_closes_executor_passed_at_initialization(
            self, mock_ex):
        def _stack_fn(x):
            del x  # Unused
            return ExecutorMock()

        resource_manager = executor_stacks.ResourceManagingExecutorFactory(
            _stack_fn, ensure_closed=[mock_ex])
        resource_manager.clean_up_executors()
        mock_ex.close.assert_called_once()
    def test_executor_service_raises_after_cleanup_without_configuration(self):
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)
        env.stub.ClearExecutor(executor_pb2.ClearExecutorRequest())
        value_proto, _ = executor_serialization.serialize_value(
            tf.constant(10.0).numpy(), tf.float32)
        with self.assertRaises(grpc.RpcError):
            env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=value_proto))
  def test_executor_service_slowly_create_tensor_value(self):
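    # Verifies that the CreateValue RPC returns while the executor is still
    # processing the request (the service creates values asynchronously); the
    # value is only retrieved after the executor is unblocked.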

    class SlowExecutorValue(executor_value_base.ExecutorValue):

      def __init__(self, v, t):
        self._v = v
        self._t = t

      @property
      def type_signature(self):
        return self._t

      async def compute(self):
        return self._v

    class SlowExecutor(executor_base.Executor):

      def __init__(self):
        self.status = 'idle'
        self.busy = threading.Event()
        self.done = threading.Event()

      async def create_value(self, value, type_spec=None):
        self.status = 'busy'
        self.busy.set()
        self.done.wait()
        self.status = 'done'
        return SlowExecutorValue(value, type_spec)

      async def create_call(self, comp, arg=None):
        raise NotImplementedError

      async def create_struct(self, elements):
        raise NotImplementedError

      async def create_selection(self, source, index=None, name=None):
        raise NotImplementedError

      def close(self):
        pass

    ex = SlowExecutor()
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(lambda _: ex)
    env = TestEnv(ex_factory)
    self.assertEqual(ex.status, 'idle')
    value_proto, _ = executor_serialization.serialize_value(10, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    ex.busy.wait()
    self.assertEqual(ex.status, 'busy')
    ex.done.set()
    value = env.get_value(response.value_ref.id)
    self.assertEqual(ex.status, 'done')
    self.assertEqual(value, 10)
    def test_as_default_context(self):
        ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        factory = executor_stacks.ResourceManagingExecutorFactory(
            executor_stack_fn=lambda _: ex)
        context = execution_context.ExecutionContext(factory)
        set_default_context.set_default_context(context)

        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        self.assertEqual(comp(10.0), 11.0)
  def test_executor_service_create_tensor_value(self):
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)
    value_proto, _ = executor_serialization.serialize_value(
        tf.constant(10.0).numpy(), tf.float32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    value_id = str(response.value_ref.id)
    value = env.get_value(value_id)
    self.assertEqual(value, 10.0)
    del env
    def test_executor_service_execute_failure_in_connection(self):

        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)

        iter_obj = CallNoArgFnIterator()
        response_iterator = env.stub.Execute(iter_obj.iterator())

        with self.assertRaises(grpc.RpcError):
            for response in response_iterator:
                iter_obj.queue.put(response)
                env.close_channel()
    def test_server_context_shuts_down_uncaught_exception(
            self, mock_logging_info):

        ex = eager_tf_executor.EagerTFExecutor()
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: ex)

        with self.assertRaises(TypeError):
            with server_utils.server_context(
                    ex_factory, 1, portpicker.pick_unused_port()) as server:
                time.sleep(1)
                raise TypeError

        mock_logging_info.assert_called_once_with('Shutting down server.')
    def test_server_context_shuts_down_under_keyboard_interrupt(
            self, mock_logging_info):

        ex = eager_tf_executor.EagerTFExecutor()
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: ex)

        with server_utils.server_context(
                ex_factory, 1, portpicker.pick_unused_port()) as server:
            time.sleep(1)
            raise KeyboardInterrupt

        mock_logging_info.assert_has_calls([
            mock.call('Server stopped by KeyboardInterrupt.'),
            mock.call('Shutting down server.')
        ])
def test_executor_factory(
    num_clients: Optional[int] = None,
    clients_per_thread: int = 1,
    *,
    default_num_clients: int = 0) -> executor_factory.ExecutorFactory:
  """Constructs a test execution stack to execute local computations.

  This factory is similar to `tff.framework.thread_debugging_executor_factory`
  except that it is configured to delegate the implementation of federated
  intrinsics to a `federated_strategy.TestFederatedStrategy`.

  This execution stack can be useful when testing federated algorithms that
  require unique implementations for the intrinsics provided by TFF.

  Args:
    num_clients: (Deprecated) The number of clients. If specified, the executor
      factory returned by this function will be configured to have exactly
      `num_clients` clients. If unspecified (`None`), the returned factory will
      attempt to infer cardinalities of all placements for which it is passed
      values.
    clients_per_thread: Integer number of clients for each of TFF's threads to
      run in sequence. Increasing `clients_per_thread` therefore reduces the
      concurrency of the TFF runtime, which can be useful if client work is very
      lightweight or models are very large and multiple copies cannot fit in
      memory.
    default_num_clients: The number of clients to run by default if cardinality
      cannot be inferred from arguments.

  Returns:
    An `executor_factory.ExecutorFactory`.
  """
  unplaced_ex_factory = executor_stacks.UnplacedExecutorFactory(
      use_caching=False,
      can_resolve_references=True,
  )
  default_num_clients = (
      executor_stacks.normalize_num_clients_and_default_num_clients(
          num_clients, default_num_clients))
  federating_executor_factory = executor_stacks.FederatingExecutorFactory(
      clients_per_thread=clients_per_thread,
      unplaced_ex_factory=unplaced_ex_factory,
      default_num_clients=default_num_clients,
      use_sizing=False,
      federated_strategy_factory=federated_strategy.TestFederatedStrategy
      .factory)

  return executor_stacks.ResourceManagingExecutorFactory(
      federating_executor_factory.create_executor)
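
# Illustrative usage sketch: installing the factory returned by
# `test_executor_factory` as the default execution context, following the same
# pattern as `test_as_default_context` above. The `execution_context`,
# `set_default_context`, and `computations` modules are assumed to be the TFF
# modules used elsewhere in these examples.
def _example_use_test_executor_factory():
  factory = test_executor_factory(default_num_clients=3)
  context = execution_context.ExecutionContext(factory)
  set_default_context.set_default_context(context)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return tf.add(x, 1)

  # With the test stack installed, the computation executes through the test
  # executor stack, whose federated intrinsics are delegated to
  # `federated_strategy.TestFederatedStrategy`.
  assert add_one(10) == 11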
    def test_executor_service_create_and_select_from_tuple(self):
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)

        value_proto, _ = executor_serialization.serialize_value(10, tf.int32)
        response = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(value=value_proto))
        self.assertIsInstance(response, executor_pb2.CreateValueResponse)
        ten_ref = response.value_ref
        self.assertEqual(env.get_value(ten_ref.id), 10)

        value_proto, _ = executor_serialization.serialize_value(20, tf.int32)
        response = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(value=value_proto))
        self.assertIsInstance(response, executor_pb2.CreateValueResponse)
        twenty_ref = response.value_ref
        self.assertEqual(env.get_value(twenty_ref.id), 20)

        response = env.stub.CreateStruct(
            executor_pb2.CreateStructRequest(element=[
                executor_pb2.CreateStructRequest.Element(name='a',
                                                         value_ref=ten_ref),
                executor_pb2.CreateStructRequest.Element(name='b',
                                                         value_ref=twenty_ref)
            ]))
        self.assertIsInstance(response, executor_pb2.CreateStructResponse)
        tuple_ref = response.value_ref
        self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

        for arg_name, arg_val, result_val in [('name', 'a', 10),
                                              ('name', 'b', 20),
                                              ('index', 0, 10),
                                              ('index', 1, 20)]:
            response = env.stub.CreateSelection(
                executor_pb2.CreateSelectionRequest(source_ref=tuple_ref,
                                                    **{arg_name: arg_val}))
            self.assertIsInstance(response,
                                  executor_pb2.CreateSelectionResponse)
            selection_ref = response.value_ref
            self.assertEqual(env.get_value(selection_ref.id), result_val)

        del env
# Starts an in-process gRPC executor service backed by tracing executors and
# yields an (executor, tracers) namedtuple, where `executor` is a
# reference-resolving remote executor connected to the service. Assumes
# `contextlib` is imported at module scope; the decorator is needed so the
# yield/try/finally pattern below behaves as a context manager.
@contextlib.contextmanager
def test_context(rpc_mode='REQUEST_REPLY'):
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_factory = executor_stacks.local_executor_factory(num_clients=3)
    tracers = []

    def _tracer_fn(cardinalities):
        tracer = executor_test_utils.TracingExecutor(
            target_factory.create_executor(cardinalities))
        tracers.append(tracer)
        return tracer

    service = executor_service.ExecutorService(
        executor_stacks.ResourceManagingExecutorFactory(_tracer_fn))
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()

    channel = grpc.insecure_channel('localhost:{}'.format(port))
    stub = executor_pb2_grpc.ExecutorStub(channel)
    serialized_cards = executor_service_utils.serialize_cardinalities(
        {placement_literals.CLIENTS: 3})
    stub.SetCardinalities(
        executor_pb2.SetCardinalitiesRequest(cardinalities=serialized_cards))

    remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_exec)
    try:
        yield collections.namedtuple('_', 'executor tracers')(executor,
                                                              tracers)
    finally:
        executor.close()
        for tracer in tracers:
            tracer.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            server.stop(None)
    def test_executor_service_create_no_arg_computation_value_and_call(self):
        ex_factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)

        @computations.tf_computation
        def comp():
            return tf.constant(10)

        value_proto, _ = executor_service_utils.serialize_value(comp)
        response = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(value=value_proto))
        self.assertIsInstance(response, executor_pb2.CreateValueResponse)
        response = env.stub.CreateCall(
            executor_pb2.CreateCallRequest(function_ref=response.value_ref))
        self.assertIsInstance(response, executor_pb2.CreateCallResponse)
        value_id = str(response.value_ref.id)
        value = env.get_value(value_id)
        self.assertEqual(value, 10)
        del env
  def test_executor_service_value_unavailable_after_dispose(self):
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)
    value_proto, _ = executor_serialization.serialize_value(
        tf.constant(10.0).numpy(), tf.float32)
    # Create the value
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    value_id = str(response.value_ref.id)
    # Check that the value appears in the _values map
    env.get_value_future_directly(value_id)
    # Dispose of the value
    dispose_request = executor_pb2.DisposeRequest()
    dispose_request.value_ref.append(response.value_ref)
    response = env.stub.Dispose(dispose_request)
    self.assertIsInstance(response, executor_pb2.DisposeResponse)
    # Check that the value is gone from the _values map
    # get_value_future_directly is used here so that we can catch the
    # exception rather than having it occur on the GRPC thread.
    with self.assertRaises(KeyError):
      env.get_value_future_directly(value_id)

  def test_clear_executor_calls_cleanup(self, mock_cleanup):
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)
    env.stub.ClearExecutor(executor_pb2.ClearExecutorRequest())
    mock_cleanup.assert_called_once()
def _create_test_executor_factory():
    executor = eager_tf_executor.EagerTFExecutor()
    return executor_stacks.ResourceManagingExecutorFactory(lambda _: executor)
def create_test_executor_factory():
    executor = eager_tf_executor.EagerTFExecutor()
    executor = caching_executor.CachingExecutor(executor)
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        executor)
    return executor_stacks.ResourceManagingExecutorFactory(lambda _: executor)