コード例 #1
0
    def test_executor_service_create_one_arg_computation_value_and_call(self):
        """Calls a one-argument TF computation through the executor service.

        Registers `comp` (x -> x + 1) and the int32 argument 10 with
        `CreateValue`, invokes the computation with `CreateCall`, and checks
        that the value fetched via `env.get_value` is 11.
        """
        ex_factory = executor_test_utils.BasicTestExFactory(
            eager_tf_executor.EagerTFExecutor())
        env = TestEnv(ex_factory)
        # `del env` lives in `finally` so the environment is dropped even when
        # an assertion below fails; as a trailing statement it was skipped on
        # failure and `env` stayed referenced by the failing frame.
        # (Presumably TestEnv releases its service resources when collected —
        # confirm against TestEnv.)
        try:

            @tensorflow_computation.tf_computation(tf.int32)
            def comp(x):
                return tf.add(x, 1)

            value_proto, _ = value_serialization.serialize_value(comp)
            response = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                                value=value_proto))
            self.assertIsInstance(response, executor_pb2.CreateValueResponse)
            comp_ref = response.value_ref

            value_proto, _ = value_serialization.serialize_value(10, tf.int32)
            response = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                                value=value_proto))
            self.assertIsInstance(response, executor_pb2.CreateValueResponse)
            arg_ref = response.value_ref

            response = env.stub.CreateCall(
                executor_pb2.CreateCallRequest(executor=env.executor_pb,
                                               function_ref=comp_ref,
                                               argument_ref=arg_ref))
            self.assertIsInstance(response, executor_pb2.CreateCallResponse)
            value_id = str(response.value_ref.id)
            value = env.get_value(value_id)
            self.assertEqual(value, 11)
        finally:
            del env
コード例 #2
0
    def test_executor_service_create_one_arg_computation_value_and_call(self):
        """Calls a one-argument TF computation through the executor service.

        Registers `comp` (x -> x + 1) and the int32 argument 10 with
        `CreateValue`, invokes the computation with `CreateCall`, and checks
        that the value fetched via `env.get_value` is 11.
        """
        env = TestEnv(eager_executor.EagerExecutor())
        # `del env` lives in `finally` so the environment is dropped even when
        # an assertion below fails; as a trailing statement it was skipped on
        # failure and `env` stayed referenced by the failing frame.
        # (Presumably TestEnv releases its service resources when collected —
        # confirm against TestEnv.)
        try:

            @computations.tf_computation(tf.int32)
            def comp(x):
                return tf.add(x, 1)

            value_proto, _ = executor_service_utils.serialize_value(comp)
            response = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=value_proto))
            self.assertIsInstance(response, executor_pb2.CreateValueResponse)
            comp_ref = response.value_ref

            value_proto, _ = executor_service_utils.serialize_value(
                10, tf.int32)
            response = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=value_proto))
            self.assertIsInstance(response, executor_pb2.CreateValueResponse)
            arg_ref = response.value_ref

            response = env.stub.CreateCall(
                executor_pb2.CreateCallRequest(function_ref=comp_ref,
                                               argument_ref=arg_ref))
            self.assertIsInstance(response, executor_pb2.CreateCallResponse)
            value_id = str(response.value_ref.id)
            value = env.get_value(value_id)
            self.assertEqual(value, 11)
        finally:
            del env
コード例 #3
0
    def test_create_call_reraises_type_error(self, mock_executor_grpc_stub):
        """A TypeError from the gRPC stub's `CreateCall` propagates unchanged."""
        mock_executor_grpc_stub.return_value.CreateCall = mock.Mock(
            side_effect=TypeError)
        executor_stub = create_stub()

        with self.assertRaises(TypeError):
            executor_stub.create_call(
                request=executor_pb2.CreateCallRequest())
コード例 #4
0
    def test_create_call_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_executor_grpc_stub):
        """A `CreateCall` failure from `_raise_grpc_error_unavailable` surfaces
        as `executors_errors.RetryableError`."""
        failing_create_call = mock.Mock(
            side_effect=_raise_grpc_error_unavailable)
        mock_executor_grpc_stub.return_value.CreateCall = failing_create_call
        executor_stub = create_stub()

        with self.assertRaises(executors_errors.RetryableError):
            executor_stub.create_call(
                request=executor_pb2.CreateCallRequest())
コード例 #5
0
    def test_create_call_returns_remote_value(self, mock_executor_grpc_stub):
        """`create_call` hands back the response produced by the mocked stub."""
        expected = executor_pb2.CreateCallResponse()
        grpc_instance = mock_executor_grpc_stub.return_value
        # A one-element side_effect list: the first call yields `expected`.
        grpc_instance.CreateCall = mock.Mock(side_effect=[expected])
        actual = create_stub().create_call(
            request=executor_pb2.CreateCallRequest())

        grpc_instance.CreateCall.assert_called_once()
        self.assertEqual(actual, expected)
コード例 #6
0
    def test_create_call_reraises_grpc_error(self, mock_executor_grpc_stub):
        """A non-retryable gRPC error is re-raised as `grpc.RpcError`; the test
        expects its status code to be ABORTED."""
        mock_executor_grpc_stub.return_value.CreateCall = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)
        executor_stub = create_stub()

        with self.assertRaises(grpc.RpcError) as raised:
            executor_stub.create_call(
                request=executor_pb2.CreateCallRequest())

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
コード例 #7
0
 async def create_call(self, comp, arg=None):
   """Invokes `comp` (optionally on `arg`) via a `CreateCall` on `self._stub`.

   Args:
     comp: A `RemoteValue` whose `type_signature` must be a
       `computation_types.FunctionType` (checked with `py_typecheck`).
     arg: Optional `RemoteValue` supplied as the call's argument; when
       `None`, the request carries no argument reference.

   Returns:
     A `RemoteValue` wrapping the `value_ref` from the service's
     `CreateCallResponse`, typed with the function's result type.
   """
   py_typecheck.check_type(comp, RemoteValue)
   fn_type = comp.type_signature
   py_typecheck.check_type(fn_type, computation_types.FunctionType)
   if arg is None:
     argument_ref = None
   else:
     py_typecheck.check_type(arg, RemoteValue)
     argument_ref = arg.value_ref
   request = executor_pb2.CreateCallRequest(
       function_ref=comp.value_ref, argument_ref=argument_ref)
   # NOTE(review): this is a blocking stub call inside a coroutine.
   reply = self._stub.CreateCall(request)
   py_typecheck.check_type(reply, executor_pb2.CreateCallResponse)
   return RemoteValue(reply.value_ref, fn_type.result, self)
コード例 #8
0
 async def create_call(self, comp, arg=None):
   """Invokes `comp` (optionally on `arg`) on the remote service.

   When `self._bidi_stream` is truthy, the request is sent over that stream
   wrapped in an `ExecuteRequest`. Otherwise it goes through the unary
   `self._stub.CreateCall`.

   Args:
     comp: A `RemoteValue` whose `type_signature` must be a
       `computation_types.FunctionType` (checked with `py_typecheck`).
     arg: Optional `RemoteValue` supplied as the call's argument.

   Returns:
     A `RemoteValue` wrapping the `value_ref` from the `CreateCallResponse`,
     typed with the function's result type.
   """
   py_typecheck.check_type(comp, RemoteValue)
   fn_type = comp.type_signature
   py_typecheck.check_type(fn_type, computation_types.FunctionType)
   if arg is not None:
     py_typecheck.check_type(arg, RemoteValue)
   request = executor_pb2.CreateCallRequest(
       function_ref=comp.value_ref,
       argument_ref=None if arg is None else arg.value_ref)
   if self._bidi_stream:
     # Stream path: wrap the request, then unwrap the `create_call` reply.
     envelope = executor_pb2.ExecuteRequest(create_call=request)
     stream_reply = await self._bidi_stream.send_request(envelope)
     reply = stream_reply.create_call
   else:
     reply = self._stub.CreateCall(request)
   py_typecheck.check_type(reply, executor_pb2.CreateCallResponse)
   return RemoteValue(reply.value_ref, fn_type.result, self)
コード例 #9
0
    def test_executor_service_create_no_arg_computation_value_and_call(self):
        """Calls a no-argument TF computation through the executor service.

        Registers `comp` (returning the constant 10) with `CreateValue`,
        invokes it with `CreateCall` and no argument, and checks that the
        value fetched via `env.get_value` is 10.
        """
        env = TestEnv(eager_executor.EagerExecutor())
        # `del env` lives in `finally` so the environment is dropped even when
        # an assertion below fails; as a trailing statement it was skipped on
        # failure and `env` stayed referenced by the failing frame.
        # (Presumably TestEnv releases its service resources when collected —
        # confirm against TestEnv.)
        try:

            @computations.tf_computation
            def comp():
                return tf.constant(10)

            value_proto, _ = executor_service_utils.serialize_value(comp)
            response = env.stub.CreateValue(
                executor_pb2.CreateValueRequest(value=value_proto))
            self.assertIsInstance(response, executor_pb2.CreateValueResponse)
            response = env.stub.CreateCall(
                executor_pb2.CreateCallRequest(
                    function_ref=response.value_ref))
            self.assertIsInstance(response, executor_pb2.CreateCallResponse)
            value_id = str(response.value_ref.id)
            value = env.get_value(value_id)
            self.assertEqual(value, 10)
        finally:
            del env
コード例 #10
0
    def iterator(self):
        """Yields a create-value, a create-call, then a compute request.

        The first request registers a no-argument computation returning 1.
        Before each later request, the next item is taken from `self.queue`,
        presumably the server's reply to the previous request (confirm
        against the caller). The `value_ref` in that item feeds the next
        request.
        """

        @computations.tf_computation()
        def comp():
            return 1

        serialized, _ = executor_serialization.serialize_value(comp)
        yield executor_pb2.ExecuteRequest(
            create_value=executor_pb2.CreateValueRequest(value=serialized))

        value_reply = self.queue.get()
        yield executor_pb2.ExecuteRequest(
            create_call=executor_pb2.CreateCallRequest(
                function_ref=value_reply.create_value.value_ref,
                argument_ref=None))

        call_reply = self.queue.get()
        yield executor_pb2.ExecuteRequest(
            compute=executor_pb2.ComputeRequest(
                value_ref=call_reply.create_call.value_ref))
コード例 #11
0
  def test_executor_service_create_no_arg_computation_value_and_call(self):
    """Calls a no-argument TF computation through the executor service.

    Registers `comp` (returning the constant 10) with `CreateValue`, invokes
    it with `CreateCall` and no argument, and checks that the value fetched
    via `env.get_value` is 10.
    """
    ex_factory = executor_stacks.ResourceManagingExecutorFactory(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)
    # `del env` lives in `finally` so the environment is dropped even when an
    # assertion below fails; as a trailing statement it was skipped on failure
    # and `env` stayed referenced by the failing frame. (Presumably TestEnv
    # releases its service resources when collected — confirm against TestEnv.)
    try:

      @computations.tf_computation
      def comp():
        return tf.constant(10)

      value_proto, _ = executor_serialization.serialize_value(comp)
      response = env.stub.CreateValue(
          executor_pb2.CreateValueRequest(value=value_proto))
      self.assertIsInstance(response, executor_pb2.CreateValueResponse)
      response = env.stub.CreateCall(
          executor_pb2.CreateCallRequest(function_ref=response.value_ref))
      self.assertIsInstance(response, executor_pb2.CreateCallResponse)
      value_id = str(response.value_ref.id)
      value = env.get_value(value_id)
      self.assertEqual(value, 10)
    finally:
      del env