コード例 #1
0
    def test_executor_service_create_one_arg_computation_value_and_call(self):
        """Calls a one-arg TF computation (x + 1) on 10 through the service.

        Both the computation and its int32 argument are registered with
        CreateValue, combined with CreateCall, and the result must be 11.
        """
        factory = executor_test_utils.BasicTestExFactory(
            eager_tf_executor.EagerTFExecutor())
        env = TestEnv(factory)

        @tensorflow_computation.tf_computation(tf.int32)
        def comp(x):
            return tf.add(x, 1)

        def register(service_env, *serialize_args):
            # Serializes the payload, sends CreateValue, returns the new ref.
            proto, _ = value_serialization.serialize_value(*serialize_args)
            reply = service_env.stub.CreateValue(
                executor_pb2.CreateValueRequest(
                    executor=service_env.executor_pb, value=proto))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            return reply.value_ref

        function_ref = register(env, comp)
        argument_ref = register(env, 10, tf.int32)

        call_reply = env.stub.CreateCall(
            executor_pb2.CreateCallRequest(executor=env.executor_pb,
                                           function_ref=function_ref,
                                           argument_ref=argument_ref))
        self.assertIsInstance(call_reply, executor_pb2.CreateCallResponse)
        self.assertEqual(env.get_value(str(call_reply.value_ref.id)), 11)
        del env
コード例 #2
0
    def test_raising_failed_precondition_destroys_executor(self):
        """A FAILED_PRECONDITION from the executor tears the executor down."""

        # Minimal stand-in for the RpcError subclasses gRPC raises at runtime.
        class GrpcFailedPrecondition(grpc.RpcError):

            def code(self):
                return grpc.StatusCode.FAILED_PRECONDITION

        def _return_error():
            return GrpcFailedPrecondition('Raising failed precondition')

        raising_executor = executor_test_utils.RaisingExecutor(_return_error)
        factory = executor_test_utils.BasicTestExFactory(raising_executor)
        env = TestEnv(factory)
        value_proto, _ = value_serialization.serialize_value(10, tf.int32)
        first_response = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                            value=value_proto))
        # The service works future-to-future, so the executor's error only
        # surfaces once the value is actually forced.
        with self.assertRaises(grpc.RpcError):
            env.get_value(first_response.value_ref.id)

        # The executor is gone now, so any follow-up call fails eagerly.
        with self.assertRaises(grpc.RpcError) as raised:
            env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                                value=value_proto))
        error = raised.exception
        self.assertEqual(error.code(), grpc.StatusCode.FAILED_PRECONDITION)
        self.assertIn('No executor found', error.exception().details())
コード例 #3
0
    def test_executor_service_create_one_arg_computation_value_and_call(self):
        """Calls a one-arg TF computation (x + 1) on 10 through the service.

        Both the computation and its int32 argument are registered with
        CreateValue, combined with CreateCall, and the result must be 11.
        """
        env = TestEnv(eager_executor.EagerExecutor())

        @computations.tf_computation(tf.int32)
        def comp(x):
            return tf.add(x, 1)

        def register(service_env, *serialize_args):
            # Serializes the payload, sends CreateValue, returns the new ref.
            proto, _ = executor_service_utils.serialize_value(*serialize_args)
            reply = service_env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=proto))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            return reply.value_ref

        function_ref = register(env, comp)
        argument_ref = register(env, 10, tf.int32)

        call_reply = env.stub.CreateCall(
            executor_pb2.CreateCallRequest(function_ref=function_ref,
                                           argument_ref=argument_ref))
        self.assertIsInstance(call_reply, executor_pb2.CreateCallResponse)
        self.assertEqual(env.get_value(str(call_reply.value_ref.id)), 11)
        del env
コード例 #4
0
ファイル: remote_executor.py プロジェクト: mysqlsc/federated
 async def create_value(self, value, type_spec=None):
     """Serializes `value` and registers it with the remote executor service.

     Args:
       value: The value to send.
       type_spec: Optional type forwarded to serialize_value.

     Returns:
       A RemoteValue wrapping the service's value ref and the type returned
       by serialization.
     """
     serialized = executor_service_utils.serialize_value(value, type_spec)
     value_proto, type_spec = serialized
     # NOTE(review): the stub call is synchronous inside a coroutine, so it
     # presumably blocks the event loop for the RPC's duration — confirm.
     response = self._stub.CreateValue(
         executor_pb2.CreateValueRequest(value=value_proto))
     py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
     return RemoteValue(response.value_ref, type_spec, self)
コード例 #5
0
    def test_create_value_reraises_type_error(self, mock_executor_grpc_stub):
        """A TypeError from the gRPC CreateValue call propagates unchanged."""
        mock_executor_grpc_stub.return_value.CreateValue = mock.Mock(
            side_effect=TypeError)
        stub = create_stub()

        with self.assertRaises(TypeError):
            stub.create_value(request=executor_pb2.CreateValueRequest())
コード例 #6
0
  def test_dispose_does_not_trigger_cleanup(self):
    """Disposing a value must not call the factory's clean_up_executors."""

    class MockFactory(executor_factory.ExecutorFactory, mock.MagicMock):

      def create_executor(self, *args, **kwargs):
        return mock.MagicMock()

      def clean_up_executors(self):
        return

    ex_factory = MockFactory()
    # Swap the no-op for a recording mock so its calls can be asserted on.
    ex_factory.clean_up_executors = mock.MagicMock()

    env = TestEnv(ex_factory)
    value_proto, _ = executor_serialization.serialize_value(
        tf.constant(10.0).numpy(), tf.float32)
    created = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(created, executor_pb2.CreateValueResponse)
    value_id = str(created.value_ref.id)
    # Checks that the value appears in the service's _values map.
    env.get_value_future_directly(value_id)

    dispose_request = executor_pb2.DisposeRequest()
    dispose_request.value_ref.append(created.value_ref)
    env.stub.Dispose(dispose_request)
    # Propagating close down the executor stack on Dispose would close the
    # bidi stream and hang the streaming case with intermediate aggregation;
    # Python GC pushes Dispose requests from the aggregators to the workers.
    ex_factory.clean_up_executors.assert_not_called()
コード例 #7
0
 def test_create_value_returns_value(self, mock_executor_grpc_stub):
     """create_value hands back exactly what gRPC CreateValue returned."""
     expected = executor_pb2.CreateValueResponse()
     grpc_stub = mock_executor_grpc_stub.return_value
     grpc_stub.CreateValue = mock.Mock(side_effect=[expected])
     stub = create_stub()
     actual = stub.create_value(request=executor_pb2.CreateValueRequest())
     self.assertEqual(actual, expected)
コード例 #8
0
    def test_executor_service_create_and_select_from_tuple(self):
        """Builds <a=10,b=20> through the service and selects by index.

        Each int32 is registered with CreateValue and read back. The two refs
        are combined with CreateStruct, and CreateSelection on indices 0 and 1
        must yield the original elements.
        """
        factory = executor_test_utils.BasicTestExFactory(
            eager_tf_executor.EagerTFExecutor())
        env = TestEnv(factory)

        element_refs = []
        for number in (10, 20):
            proto, _ = value_serialization.serialize_value(number, tf.int32)
            reply = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                                value=proto))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            self.assertEqual(env.get_value(reply.value_ref.id), number)
            element_refs.append(reply.value_ref)
        ten_ref, twenty_ref = element_refs

        struct_reply = env.stub.CreateStruct(
            executor_pb2.CreateStructRequest(
                executor=env.executor_pb,
                element=[
                    executor_pb2.CreateStructRequest.Element(name=label,
                                                             value_ref=ref)
                    for label, ref in (('a', ten_ref), ('b', twenty_ref))
                ]))
        self.assertIsInstance(struct_reply, executor_pb2.CreateStructResponse)
        tuple_ref = struct_reply.value_ref
        self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

        for position, expected in enumerate((10, 20)):
            selection = env.stub.CreateSelection(
                executor_pb2.CreateSelectionRequest(executor=env.executor_pb,
                                                    source_ref=tuple_ref,
                                                    index=position))
            self.assertIsInstance(selection,
                                  executor_pb2.CreateSelectionResponse)
            self.assertEqual(env.get_value(selection.value_ref.id), expected)

        del env
コード例 #9
0
 def test_executor_service_create_tensor_value(self):
     """A float32 tensor sent via CreateValue is stored with value 10.0."""
     tensor = tf.constant(10.0).numpy()
     # NOTE(review): this assumes the serialize_value variant that returns a
     # bare proto (other callers unpack a (proto, type) tuple) — confirm.
     request = executor_pb2.CreateValueRequest(
         value=executor_service_utils.serialize_value(tensor, tf.float32))
     response = self._stub.CreateValue(request)
     self.assertIsInstance(response, executor_pb2.CreateValueResponse)
     stored = self._extract_value_from_service(str(response.value_ref.id))
     self.assertEqual(stored.internal_representation.numpy(), 10.0)
コード例 #10
0
    def test_create_value_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_executor_grpc_stub):
        """create_value maps the stubbed UNAVAILABLE error to RetryableError."""
        grpc_stub = mock_executor_grpc_stub.return_value
        grpc_stub.CreateValue = mock.Mock(
            side_effect=_raise_grpc_error_unavailable)
        stub = create_stub()

        with self.assertRaises(executors_errors.RetryableError):
            stub.create_value(request=executor_pb2.CreateValueRequest())
コード例 #11
0
    def test_create_value_reraises_grpc_error(self, mock_executor_grpc_stub):
        """A non-retryable gRPC error is re-raised with its ABORTED status."""
        mock_executor_grpc_stub.return_value.CreateValue = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)
        stub = create_stub()

        with self.assertRaises(grpc.RpcError) as raised:
            stub.create_value(request=executor_pb2.CreateValueRequest())

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
コード例 #12
0
 def test_executor_service_raises_after_cleanup_without_configuration(self):
     """After ClearExecutor, with no reconfiguration, CreateValue raises."""
     factory = executor_stacks.ResourceManagingExecutorFactory(
         lambda _: eager_tf_executor.EagerTFExecutor())
     env = TestEnv(factory)
     env.stub.ClearExecutor(executor_pb2.ClearExecutorRequest())
     tensor = tf.constant(10.0).numpy()
     value_proto, _ = executor_serialization.serialize_value(tensor, tf.float32)
     request = executor_pb2.CreateValueRequest(value=value_proto)
     with self.assertRaises(grpc.RpcError):
         env.stub.CreateValue(request)
コード例 #13
0
        def _iterator():
            """Yields one ExecuteRequest that registers an int32 x + 1 comp."""

            @computations.tf_computation(tf.int32)
            def comp(x):
                return tf.add(x, 1)

            value_proto, _ = executor_serialization.serialize_value(comp)
            yield executor_pb2.ExecuteRequest(
                create_value=executor_pb2.CreateValueRequest(value=value_proto))
コード例 #14
0
    def test_executor_service_create_and_select_from_tuple(self):
        """Builds <a=10,b=20> via the service; selects by name and by index."""
        factory = executor_stacks.ResourceManagingExecutorFactory(
            lambda _: eager_tf_executor.EagerTFExecutor())
        env = TestEnv(factory)

        element_refs = []
        for number in (10, 20):
            proto, _ = executor_serialization.serialize_value(number, tf.int32)
            reply = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=proto))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            self.assertEqual(env.get_value(reply.value_ref.id), number)
            element_refs.append(reply.value_ref)
        ten_ref, twenty_ref = element_refs

        struct_reply = env.stub.CreateStruct(
            executor_pb2.CreateStructRequest(element=[
                executor_pb2.CreateStructRequest.Element(name=label,
                                                         value_ref=ref)
                for label, ref in (('a', ten_ref), ('b', twenty_ref))
            ]))
        self.assertIsInstance(struct_reply, executor_pb2.CreateStructResponse)
        tuple_ref = struct_reply.value_ref
        self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

        # Every element must be addressable both by name and by position.
        selectors = (('name', 'a', 10), ('name', 'b', 20),
                     ('index', 0, 10), ('index', 1, 20))
        for field, key, expected in selectors:
            selection = env.stub.CreateSelection(
                executor_pb2.CreateSelectionRequest(source_ref=tuple_ref,
                                                    **{field: key}))
            self.assertIsInstance(selection,
                                  executor_pb2.CreateSelectionResponse)
            self.assertEqual(env.get_value(selection.value_ref.id), expected)

        del env
コード例 #15
0
 def test_executor_service_create_tensor_value(self):
     """A float32 tensor registered via CreateValue reads back as 10.0."""
     env = TestEnv(eager_executor.EagerExecutor())
     tensor = tf.constant(10.0).numpy()
     value_proto, _ = executor_service_utils.serialize_value(tensor, tf.float32)
     reply = env.stub.CreateValue(
         executor_pb2.CreateValueRequest(value=value_proto))
     self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
     self.assertEqual(env.get_value(str(reply.value_ref.id)), 10.0)
     del env
コード例 #16
0
ファイル: remote_executor.py プロジェクト: ali-yaz/federated
    async def create_value(self, value, type_spec=None):
        """Serializes `value` and registers it with the remote executor.

        Serialization runs inside a traced helper; the CreateValue RPC is
        issued through `_request`.

        Returns:
          A RemoteValue wrapping the returned value ref and serialized type.
        """

        @tracing.trace
        def serialize_value():
            return executor_serialization.serialize_value(value, type_spec)

        value_proto, type_spec = serialize_value()
        response = _request(self._stub.CreateValue,
                            executor_pb2.CreateValueRequest(value=value_proto))
        py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
        return RemoteValue(response.value_ref, type_spec, self)
コード例 #17
0
  def test_executor_service_slowly_create_tensor_value(self):
    """CreateValue replies before a slow executor finishes creating the value.

    The stub executor blocks inside create_value until the test releases it.
    This lets the test observe the 'busy' state after the RPC has returned a
    value ref. The wait for invocation is bounded, and the executor is always
    released, so a regression fails the test instead of hanging it.
    """
    # Upper bound, in seconds, on waiting for create_value to be invoked.
    wait_timeout_secs = 10

    class SlowExecutorValue(executor_value_base.ExecutorValue):

      def __init__(self, v, t):
        self._v = v
        self._t = t

      @property
      def type_signature(self):
        return self._t

      async def compute(self):
        return self._v

    class SlowExecutor(executor_base.Executor):

      def __init__(self):
        self.status = 'idle'
        self.busy = threading.Event()
        self.done = threading.Event()

      async def create_value(self, value, type_spec=None):
        self.status = 'busy'
        self.busy.set()
        # Deliberately blocks until the test sets `done`.
        self.done.wait()
        self.status = 'done'
        return SlowExecutorValue(value, type_spec)

      async def create_call(self, comp, arg=None):
        raise NotImplementedError

      async def create_struct(self, elements):
        raise NotImplementedError

      async def create_selection(self, source, index=None, name=None):
        raise NotImplementedError

      def close(self):
        pass

    ex = SlowExecutor()
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(lambda _: ex)
    env = TestEnv(ex_factory)
    self.assertEqual(ex.status, 'idle')
    value_proto, _ = executor_serialization.serialize_value(10, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    try:
      # The RPC has already returned; the executor must now be mid-creation.
      self.assertTrue(
          ex.busy.wait(timeout=wait_timeout_secs),
          'SlowExecutor.create_value was not invoked in time')
      self.assertEqual(ex.status, 'busy')
    finally:
      # Release create_value even if an assertion above failed, so the call
      # blocked in done.wait() is not left hanging forever.
      ex.done.set()
    value = env.get_value(response.value_ref.id)
    self.assertEqual(ex.status, 'done')
    self.assertEqual(value, 10)
コード例 #18
0
 async def create_value(self, value, type_spec=None):
   """Serializes `value` and registers it with the remote executor.

   Sends the request over the bidirectional stream when one is set (truthy),
   otherwise via a direct CreateValue stub call.

   Returns:
     A RemoteValue wrapping the returned value ref and serialized type.
   """
   value_proto, type_spec = executor_service_utils.serialize_value(
       value, type_spec)
   create_value_request = executor_pb2.CreateValueRequest(value=value_proto)
   if self._bidi_stream:
     execute_response = await self._bidi_stream.send_request(
         executor_pb2.ExecuteRequest(create_value=create_value_request))
     response = execute_response.create_value
   else:
     response = self._stub.CreateValue(create_value_request)
   py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
   return RemoteValue(response.value_ref, type_spec, self)
コード例 #19
0
 def test_executor_service_create_tensor_value(self):
   """A float32 tensor registered via CreateValue reads back as 10.0."""
   factory = executor_stacks.ResourceManagingExecutorFactory(
       lambda _: eager_tf_executor.EagerTFExecutor())
   env = TestEnv(factory)
   tensor = tf.constant(10.0).numpy()
   value_proto, _ = executor_serialization.serialize_value(tensor, tf.float32)
   reply = env.stub.CreateValue(
       executor_pb2.CreateValueRequest(value=value_proto))
   self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
   self.assertEqual(env.get_value(str(reply.value_ref.id)), 10.0)
   del env
コード例 #20
0
    def test_executor_service_create_computation_value(self):
        """A no-arg TF computation stored via CreateValue stays callable."""

        @computations.tf_computation
        def comp():
            return tf.constant(10)

        # NOTE(review): this assumes the serialize_value variant that returns
        # a bare proto (other callers unpack a (proto, type) tuple) — confirm.
        request = executor_pb2.CreateValueRequest(
            value=executor_service_utils.serialize_value(comp))
        response = self._stub.CreateValue(request)
        self.assertIsInstance(response, executor_pb2.CreateValueResponse)
        stored = self._extract_value_from_service(str(response.value_ref.id))
        self.assertTrue(callable(stored.internal_representation))
        self.assertEqual(stored.internal_representation().numpy(), 10.0)
コード例 #21
0
    async def create_value(self, value, type_spec=None):
        """Serializes `value` and registers it with the remote executor.

        Uses the bidirectional stream when one is configured; otherwise it
        issues a CreateValue stub call through `_request`.

        Returns:
          A RemoteValue wrapping the returned value ref and serialized type.
        """

        @tracing.trace
        def serialize_value():
            return executor_serialization.serialize_value(value, type_spec)

        value_proto, type_spec = serialize_value()
        create_value_request = executor_pb2.CreateValueRequest(
            value=value_proto)
        if self._bidi_stream is not None:
            execute_response = await self._bidi_stream.send_request(
                executor_pb2.ExecuteRequest(create_value=create_value_request))
            response = execute_response.create_value
        else:
            response = _request(self._stub.CreateValue, create_value_request)
        py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
        return RemoteValue(response.value_ref, type_spec, self)
コード例 #22
0
    def test_executor_service_create_no_arg_computation_value_and_call(self):
        """Registers a no-arg TF computation, calls it, and expects 10."""
        env = TestEnv(eager_executor.EagerExecutor())

        @computations.tf_computation
        def comp():
            return tf.constant(10)

        value_proto, _ = executor_service_utils.serialize_value(comp)
        created = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(value=value_proto))
        self.assertIsInstance(created, executor_pb2.CreateValueResponse)
        # No argument_ref: the computation takes no parameters.
        called = env.stub.CreateCall(
            executor_pb2.CreateCallRequest(function_ref=created.value_ref))
        self.assertIsInstance(called, executor_pb2.CreateCallResponse)
        self.assertEqual(env.get_value(str(called.value_ref.id)), 10)
        del env
コード例 #23
0
    def iterator(self):
        """Yields a create_value, then create_call, then compute request.

        After each yield, the previous step's reply is read from
        `self.queue` to obtain the value ref that the next request uses.
        """

        @computations.tf_computation()
        def comp():
            return 1

        value_proto, _ = executor_serialization.serialize_value(comp)
        yield executor_pb2.ExecuteRequest(
            create_value=executor_pb2.CreateValueRequest(value=value_proto))

        created = self.queue.get()
        yield executor_pb2.ExecuteRequest(
            create_call=executor_pb2.CreateCallRequest(
                function_ref=created.create_value.value_ref,
                argument_ref=None))

        called = self.queue.get()
        yield executor_pb2.ExecuteRequest(
            compute=executor_pb2.ComputeRequest(
                value_ref=called.create_call.value_ref))
コード例 #24
0
  def test_executor_service_create_no_arg_computation_value_and_call(self):
    """Registers a no-arg TF computation, calls it, and expects 10."""
    factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(factory)

    @computations.tf_computation
    def comp():
      return tf.constant(10)

    value_proto, _ = executor_serialization.serialize_value(comp)
    created = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(created, executor_pb2.CreateValueResponse)
    # No argument_ref: the computation takes no parameters.
    called = env.stub.CreateCall(
        executor_pb2.CreateCallRequest(function_ref=created.value_ref))
    self.assertIsInstance(called, executor_pb2.CreateCallResponse)
    self.assertEqual(env.get_value(str(called.value_ref.id)), 10)
    del env
コード例 #25
0
 def test_executor_service_value_unavailable_after_dispose(self):
     """After Dispose, the service's _values map no longer holds the value."""
     env = TestEnv(eager_tf_executor.EagerTFExecutor())
     tensor = tf.constant(10.0).numpy()
     value_proto, _ = executor_service_utils.serialize_value(tensor, tf.float32)
     created = env.stub.CreateValue(
         executor_pb2.CreateValueRequest(value=value_proto))
     self.assertIsInstance(created, executor_pb2.CreateValueResponse)
     value_id = str(created.value_ref.id)
     # Present before disposal: this lookup must succeed.
     env.get_value_future_directly(value_id)

     dispose_request = executor_pb2.DisposeRequest()
     dispose_request.value_ref.append(created.value_ref)
     disposed = env.stub.Dispose(dispose_request)
     self.assertIsInstance(disposed, executor_pb2.DisposeResponse)

     # Looked up directly (not over gRPC) so the KeyError is raised on this
     # thread rather than on the gRPC thread.
     with self.assertRaises(KeyError):
         env.get_value_future_directly(value_id)